1
0
mirror of https://github.com/malarinv/jasper-asr.git synced 2026-03-08 10:32:35 +00:00

Compare commits

...

40 Commits

Author SHA1 Message Date
ae5586be72 added evaluation command 2020-07-09 14:36:51 +05:30
069392d098 1. added a test generator and slu evaluator
2. ui dump now include gcp results
3. showing default option for more args validation process commands
2020-06-29 14:24:56 +05:30
515e9c1037 1. split extract all data types in one shot with --extraction-type all flag
2. add notes about diffing split extracted and original data
3. add a nlu conv generator to generate conv data based on nlu utterances and entities
4. add task uid support for dumping corrections
5. abstracted generate date fn
2020-06-25 11:03:09 +05:30
e76ccda5dd 1. fix update-correction to use ui_dump instead of manifest
2. update training params no of checkpoints on chpk frequency
2020-06-19 14:16:04 +05:30
000853b600 1. added option to strip silent chunks
2. computing caller quality based on task-id of corrections
2020-06-17 21:42:20 +05:30
ac0e04c226 stripping silence on call chunk 2020-06-17 19:43:25 +05:30
62eefb9294 fix 11st to 11th in ordinal 2020-06-17 19:30:12 +05:30
8e238c254e 1. added start delay arg in call recycler
2. implement ui_dump/manifest  writer in call_recycler itself
3. refactored call data point plotter
4. added sample-ui task-ui  on the validation process
5. implemented call-quality stats using corrections from mongo
6. support deleting cursors on mongo
7. implement multiple task support on validation ui based on task_id mongo field
2020-06-17 19:11:15 +05:30
7dbb04dcbf 1. added conv data generator
2. more utils
2020-06-16 15:38:07 +05:30
7472b6457d handling non-pnr cases without parens in text data 2020-06-16 11:02:53 +05:30
120302aad3 added support for name/dates/cities call data extraction and more logs 2020-06-15 10:24:38 +05:30
a7a25e9b07 1. using dataname args for update/fill annotations
2. rename to dump_ui
2020-06-10 14:55:59 +05:30
6d149d282d 1. added a data extraction type argument
2. cleanup/refactor
2020-06-09 19:16:24 +05:30
8db1be0083 refactor validation process arguments and logging 2020-06-05 16:32:08 +05:30
bca227a7d7 1. removed the transcriber_pretrained/speller from utils
2. introduced get_mongo_coll to get the collection object directly from mongo uri
3. removed processing of correction entries to remove space/upper casing
2020-06-04 17:49:16 +05:30
e3a01169c2 skipping invalid data points 2020-06-02 17:21:30 +05:30
3a5ce069ab parallelize data loading from remote 2020-05-29 12:14:14 +05:30
9f9cb62b60 show duration on validation of dataset 2020-05-28 11:35:31 +05:30
de21952349 1. refactored wav chunk processing method
2. renamed streamlit to validation_ui
2020-05-28 11:18:39 +05:30
d87369c8fe don't load audio for annotation only ui and keep spoken as text for normal asr validation 2020-05-27 15:57:42 +05:30
41af0a87de respect verbose flag 2020-05-27 15:54:16 +05:30
6f395af10d fix skipping null audio and add more verbose logs 2020-05-27 15:49:58 +05:30
a38789d0c3 added option to disable plots during validation 2020-05-27 15:43:03 +05:30
7ff2db3e2e cleanup rev recycle 2020-05-27 15:33:22 +05:30
1acf9e403c 1. added support for mono/dual channel rev transcripts
2. handle errors when extracting datapoints from rev meta data
3. added support for annotation only task when dumping ui data
2020-05-27 15:19:25 +05:30
1f2bedc156 1. enabled silence stripping in chunks when recycling audio from asr logs
2. limit asr recycling to 1 min of start audio to get reliable alignments and ignoring agent channel
3. added rev recycler for generating asr dataset from rev transcripts and audio
4. update pydub dependency for silence stripping fn and removing threadpool hardcoded worker count
2020-05-27 14:22:44 +05:30
fca9c1aeb3 refactored module structure 2020-05-21 19:13:44 +05:30
2d5b720284 1. added utility command to export call logs
2. mongo conn accepts port
2020-05-21 10:43:26 +05:30
8e79bbb571 1. implement dataset augmentation and validation in process
2. added option to skip 'incorrect' annotations in validation data
3. added confirmation on clearing mongo collection
4. added an option to navigate to a given text in the validation ui
5. added a dataset and remote option to trainer to load dataset from directory and remote rpyc service
2020-05-20 11:16:22 +05:30
83db445a6f 1. added training utils with custom data loaders with remote rpyc dataservice support
2. fix validation correction dump path
3. cache dataset for precaching before training to memory
4. update dependencies
2020-05-14 15:39:44 +05:30
d4aef4088d 1. clean-up unused data process code
2. fix invalid sample no from mongo
3. data loader service return remote netref
2020-05-13 14:02:46 +05:30
fdccea6b23 unlink temporary files after transcribing 2020-05-12 23:38:31 +05:30
c06a0814b9 1. added a tool to extract asr data from gcp transcripts logs
2. implement a function to export all call logs in a mongodb to a caller-id based yaml file
3. clean-up leaderboard duration logic
4. added a wip dataloader service
5. made the asr_data_writer util more generic with verbose flags and unique filename
6. added extendedpath util class with json support and mongo_conn function to connect to a mongo node
7. refactored the validation post processing to dump a ui config for validation
8. included utility functions to correct, fill update and clear annotations from mongodb data
9. refactored the ui logic to be more generic for any asr data
10. updated setup.py dependencies to support the above features
2020-05-12 23:38:06 +05:30
a7da729c0b add validation ui and post processing to correct using validation data 2020-05-06 12:18:34 +05:30
aae03a6ae4 refresh to next entry on submit and comment out mongo clearing code for safety :P 2020-04-29 22:52:46 +05:30
4fd05a56d0 1. refactored streamlit code
2. fixed issues in data manifest handling
2020-04-29 17:22:45 +05:30
41074a1bca 1. added streamlit based validation ui with mongodb datastore integration
2. fix asr wrong sample rate inference
3. update requirements
2020-04-29 14:26:11 +05:30
61048f855e implement call audio data recycler for asr 2020-04-27 10:53:14 +05:30
2c15b00da3 fix module packaging issue 2020-04-08 20:45:38 +05:30
d22a99a4f6 1. integrated data generator using google tts
2. added training script
2020-04-08 18:53:49 +05:30
30 changed files with 3816 additions and 5 deletions

41
.gitignore vendored
View File

@@ -1,3 +1,11 @@
/data/
/model/
/train/
.env*
*.yaml
*.yml
*.json
# Created by https://www.gitignore.io/api/python
# Edit at https://www.gitignore.io/?templates=python
@@ -108,3 +116,36 @@ dmypy.json
.pyre/
# End of https://www.gitignore.io/api/python
# Created by https://www.gitignore.io/api/macos
# Edit at https://www.gitignore.io/?templates=macos
### macOS ###
# General
.DS_Store
.AppleDouble
.LSOverride
# Icon must end with two \r
Icon
# Thumbnails
._*
# Files that might appear in the root of a volume
.DocumentRevisions-V100
.fseventsd
.Spotlight-V100
.TemporaryItems
.Trashes
.VolumeIcon.icns
.com.apple.timemachine.donotpresent
# Directories potentially created on remote AFP share
.AppleDB
.AppleDesktop
Network Trash Folder
Temporary Items
.apdisk
# End of https://www.gitignore.io/api/macos

5
Notes.md Normal file
View File

@@ -0,0 +1,5 @@
> Diff after splitting based on type
```
diff <(cat data/asr_data/call_upwork_test_cnd_*/manifest.json |sort) <(cat data/asr_data/call_upwork_test_cnd/manifest.json |sort)
```

View File

@@ -62,7 +62,7 @@ class JasperASR(object):
wf = wave.open(audio_file_path, "w")
wf.setnchannels(1)
wf.setsampwidth(2)
wf.setframerate(16000)
wf.setframerate(24000)
wf.writeframesraw(audio_data)
wf.close()
manifest = {"audio_filepath": audio_file_path, "duration": 60, "text": "todo"}
@@ -108,6 +108,8 @@ class JasperASR(object):
tensors = self.neural_factory.infer(tensors=eval_tensors)
prediction = post_process_predictions(tensors[0], self.labels)
prediction_text = ". ".join(prediction)
os.unlink(manifest_file.name)
os.unlink(audio_file.name)
return prediction_text
def transcribe_file(self, audio_file, *args, **kwargs):

21
jasper/client.py Normal file
View File

@@ -0,0 +1,21 @@
import os
import logging
import rpyc
from functools import lru_cache
# Configure logging at import time: INFO level, with timestamp, logger name
# and level name prepended to every message.
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Host and port of the ASR rpyc server. Both can be overridden through the
# JASPER_ASR_RPYC_HOST / JASPER_ASR_RPYC_PORT environment variables; the port
# is read as a string and converted to int.
ASR_HOST = os.environ.get("JASPER_ASR_RPYC_HOST", "localhost")
ASR_PORT = int(os.environ.get("JASPER_ASR_RPYC_PORT", "8045"))
@lru_cache()
def transcribe_gen(asr_host=ASR_HOST, asr_port=ASR_PORT):
    """Connect to the ASR rpyc server and return its remote ``transcribe``.

    The result is memoised by ``lru_cache`` on the ``(asr_host, asr_port)``
    arguments, so a repeated call with the same arguments returns the
    callable obtained from the first connection instead of reconnecting.
    """
    logger.info(f"connecting to asr server at {asr_host}:{asr_port}")
    connection = rpyc.connect(asr_host, asr_port)
    remote_service = connection.root
    logger.info(f"connected to asr server successfully")
    return remote_service.transcribe

1
jasper/data/__init__.py Normal file
View File

@@ -0,0 +1 @@

104
jasper/data/asr_recycler.py Normal file
View File

@@ -0,0 +1,104 @@
import typer
from itertools import chain
from io import BytesIO
from pathlib import Path
app = typer.Typer()
@app.command()
def extract_data(
    call_audio_dir: Path = Path("/dataset/png_prod/call_audio"),
    call_meta_dir: Path = Path("/dataset/png_prod/call_metadata"),
    output_dir: Path = Path("./data"),
    dataset_name: str = "png_gcp_2jan",
    verbose: bool = False,
):
    """Build an ASR dataset from call recordings and their ASR event logs.

    Every ``*.wav`` below ``call_audio_dir`` is paired with the ``.json`` file
    at the same relative path below ``call_meta_dir``. The ASR events of
    channel 0 are turned into (transcript, duration, wav bytes) data points
    and written with ``asr_data_writer`` to ``output_dir/asr_data``.

    Args:
        call_audio_dir: root directory searched recursively for wav files.
        call_meta_dir: root directory holding the matching json event logs.
        output_dir: directory under which ``asr_data`` is created.
        dataset_name: name passed to ``asr_data_writer``.
        verbose: print progress and the reason for each skipped chunk.
    """
    from pydub import AudioSegment
    from .utils import ExtendedPath, asr_data_writer, strip_silence
    from lenses import lens

    call_asr_data: Path = output_dir / Path("asr_data")
    call_asr_data.mkdir(exist_ok=True, parents=True)

    def wav_event_generator(call_audio_dir):
        """Yield (audio, wav path, events json) for every wav file found."""
        for wav_path in call_audio_dir.glob("**/*.wav"):
            if verbose:
                typer.echo(f"loading events for file {wav_path}")
            call_wav = AudioSegment.from_file_using_temporary_files(wav_path)
            rel_meta_path = wav_path.with_suffix(".json").relative_to(call_audio_dir)
            meta_path = call_meta_dir / rel_meta_path
            events = ExtendedPath(meta_path).read_json()
            yield call_wav, wav_path, events

    def contains_asr(x):
        return "AsrResult" in x

    def channel(n):
        """Return a predicate selecting ASR events that belong to channel ``n``."""

        def filter_func(ev):
            # events without an explicit channel are treated as channel 0
            return (
                ev["AsrResult"]["Channel"] == n
                if "Channel" in ev["AsrResult"]
                else n == 0
            )

        return filter_func

    def compute_endtime(call_wav, state):
        """Yield (transcript, duration, wav bytes) for each usable ASR event.

        A chunk ends where the next event starts; the last chunk ends with
        the recording.
        """
        for (i, st) in enumerate(state):
            start_time = st["AsrResult"]["Alternatives"][0].get("StartTime", 0)
            transcript = st["AsrResult"]["Alternatives"][0]["Transcript"]
            # only starting 1 min audio has reliable alignment ignore rest;
            # checked before any audio processing so no chunk is sliced,
            # stripped and encoded just to be thrown away
            if start_time > 60:
                if verbose:
                    print("start time over 60 seconds of audio skipping.")
                break
            if i + 1 < len(state):
                end_time = state[i + 1]["AsrResult"]["Alternatives"][0]["StartTime"]
            else:
                end_time = call_wav.duration_seconds
            full_code_seg = call_wav[start_time * 1000 : end_time * 1000]
            code_seg = strip_silence(full_code_seg)
            # only if some reasonable audio data is present yield it
            if code_seg.duration_seconds < 0.5:
                if verbose:
                    print(f'transcript chunk "{transcript}" contains no audio skipping.')
                continue
            code_fb = BytesIO()
            code_seg.export(code_fb, format="wav")
            yield transcript, code_seg.duration_seconds, code_fb.getvalue()

    def asr_data_generator(call_wav, call_wav_fname, events):
        """Return the data point generator for channel 0 of one call."""
        # Only channel 0 is used (agent channel is ignored). Indexing instead
        # of unpacking two values keeps mono recordings working as well.
        call_wav_0 = call_wav.split_to_mono()[0]
        asr_events = lens["Events"].Each()["Event"].Filter(contains_asr)
        call_evs_0 = asr_events.Filter(channel(0)).collect()(events)
        if verbose:
            typer.echo(f"processing data points on {call_wav_fname}")
        return compute_endtime(call_wav_0, call_evs_0)

    def generate_call_asr_data():
        """Collect the per-call generators and write all data points."""
        full_asr_data = []
        total_duration = 0
        for wav, wav_path, ev in wav_event_generator(call_audio_dir):
            asr_data = asr_data_generator(wav, wav_path, ev)
            total_duration += wav.duration_seconds
            full_asr_data.append(asr_data)
        typer.echo(f"loaded {len(full_asr_data)} calls of duration {total_duration}s")
        n_dps = asr_data_writer(call_asr_data, dataset_name, chain(*full_asr_data))
        typer.echo(f"written {n_dps} data points")

    generate_call_asr_data()
def main():
    """Run the typer command-line application of this module."""
    app()


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,509 @@
import typer
from pathlib import Path
from enum import Enum
app = typer.Typer()
@app.command()
def export_all_logs(
    call_logs_file: Path = typer.Option(Path("./call_logs.yaml"), show_default=True),
    domain: str = typer.Option("sia-data.agaralabs.com", show_default=True),
):
    """Export every call in the mongo collection to a caller-grouped YAML file.

    Each call becomes the URI ``http://<domain>/calls/<SystemID>`` and is
    listed under its ``Caller``. The output has the shape
    ``{"users": [{"name": caller, "calls": [uri, ...]}, ...]}``.
    """
    from .utils import get_mongo_conn
    from collections import defaultdict
    from ruamel.yaml import YAML

    yaml_writer = YAML()
    calls_by_caller = defaultdict(list)
    for call_doc in get_mongo_conn().find():
        uri = f"http://{domain}/calls/{call_doc['SystemID']}"
        calls_by_caller[call_doc["Caller"]].append(uri)
    users = [
        {"name": caller, "calls": uris} for caller, uris in calls_by_caller.items()
    ]
    typer.echo(f"exporting call logs to yaml file at {call_logs_file}")
    with call_logs_file.open("w") as yf:
        yaml_writer.dump({"users": users}, yf)
@app.command()
def export_calls_between(
    start_cid: str,
    end_cid: str,
    call_logs_file: Path = typer.Option(Path("./call_logs.yaml"), show_default=True),
    domain: str = typer.Option("sia-data.agaralabs.com", show_default=True),
    mongo_port: int = 27017,
):
    """Export the calls made between two reference calls to a YAML file.

    The calls whose ``StartTS`` is not before the start call's ``StartTS`` and
    whose ``EndTS`` is not after the end call's ``EndTS`` are grouped by
    ``Caller`` and written as
    ``{"users": [{"name": caller, "calls": [uri, ...]}, ...]}``.

    Args:
        start_cid: ``SystemID`` of the call opening the time window.
        end_cid: ``SystemID`` of the call closing the time window.
        call_logs_file: destination YAML file.
        domain: host used to build the ``http://<domain>/calls/<SystemID>`` URIs.
        mongo_port: port passed to ``get_mongo_conn``.

    Raises:
        typer.Exit: with code 1 when either reference call is not found.
    """
    from collections import defaultdict
    from ruamel.yaml import YAML
    from .utils import get_mongo_conn

    yaml = YAML()
    mongo_coll = get_mongo_conn(port=mongo_port)
    start_meta = mongo_coll.find_one({"SystemID": start_cid})
    end_meta = mongo_coll.find_one({"SystemID": end_cid})
    # find_one returns None for an unknown id; report that explicitly instead
    # of failing later with "'NoneType' object is not subscriptable"
    missing = [
        cid
        for cid, meta in ((start_cid, start_meta), (end_cid, end_meta))
        if meta is None
    ]
    if missing:
        typer.echo(f"no call found with SystemID: {', '.join(missing)}", err=True)
        raise typer.Exit(code=1)
    caller_calls = defaultdict(list)
    call_query = mongo_coll.find(
        {
            "StartTS": {"$gte": start_meta["StartTS"]},
            "EndTS": {"$lte": end_meta["EndTS"]},
        }
    )
    for call in call_query:
        sysid = call["SystemID"]
        call_uri = f"http://{domain}/calls/{sysid}"
        caller_calls[call["Caller"]].append(call_uri)
    caller_list = [
        {"name": caller, "calls": calls} for caller, calls in caller_calls.items()
    ]
    output_yaml = {"users": caller_list}
    typer.echo(f"exporting call logs to yaml file at {call_logs_file}")
    with call_logs_file.open("w") as yf:
        yaml.dump(output_yaml, yf)
@app.command()
def copy_metas(
    call_logs_file: Path = typer.Option(Path("./call_logs.yaml"), show_default=True),
    output_dir: Path = Path("./data"),
    meta_dir: Path = Path("/tmp/call_metas"),
):
    """Copy the saved metadata json of every listed call into ``meta_dir``.

    For each call URI under ``users[*].calls[*]`` of ``call_logs_file`` the
    file ``output_dir/call_metas/<call id>.json`` is copied to
    ``meta_dir/<call id>.json``. Calls without a saved metadata file are
    reported and skipped.
    """
    from lenses import lens
    from ruamel.yaml import YAML
    from urllib.parse import urlsplit
    from shutil import copy2

    yaml = YAML()
    call_logs = yaml.load(call_logs_file.read_text())
    call_meta_dir: Path = output_dir / Path("call_metas")
    call_meta_dir.mkdir(exist_ok=True, parents=True)
    meta_dir.mkdir(exist_ok=True, parents=True)

    def get_cid(uri):
        # the call id is the last path component of the call URI
        return Path(urlsplit(uri).path).stem

    def copy_meta(uri):
        cid = get_cid(uri)
        saved_meta_path = call_meta_dir / Path(f"{cid}.json")
        dest_meta_path = meta_dir / Path(f"{cid}.json")
        if not saved_meta_path.exists():
            print(f"{saved_meta_path} not found")
            # nothing to copy; copy2 would raise FileNotFoundError and abort
            # the remaining calls
            return
        copy2(saved_meta_path, dest_meta_path)

    def download_meta_audio():
        call_lens = lens["users"].Each()["calls"].Each()
        call_lens.modify(copy_meta)(call_logs)

    download_meta_audio()
class ExtractionType(str, Enum):
    """Extraction modes accepted by the ``analyze`` command.

    The members are also plain strings, so they compare equal to their
    values ("flow" / "data").
    """

    flow = "flow"
    data = "data"
@app.command()
def analyze(
    leaderboard: bool = False,
    plot_calls: bool = False,
    extract_data: bool = False,
    extraction_type: ExtractionType = typer.Option(
        ExtractionType.data, show_default=True
    ),
    start_delay: float = 1.5,
    download_only: bool = False,
    strip_silent_chunks: bool = True,
    call_logs_file: Path = typer.Option(Path("./call_logs.yaml"), show_default=True),
    output_dir: Path = Path("./data"),
    data_name: str = None,
    mongo_uri: str = typer.Option(
        "mongodb://localhost:27017/test.calls", show_default=True
    ),
):
    """Analyze the calls listed in a call-logs YAML file.

    Every call URI under ``users[*].calls[*]`` of ``call_logs_file`` is looked
    up by ``SystemID`` in the mongo collection at ``mongo_uri``. The call's
    event log is obtained with ``get_call_logs`` and its recording is
    downloaded from S3 unless it already exists in ``output_dir/call_wavs``.
    Data points (start time, end time, text) are derived from the event log.

    Args:
        leaderboard: print per-user sample counts and call durations.
        plot_calls: save one waveform plot per call in ``output_dir/plots``.
        extract_data: cut each data point out of the call audio and pass the
            result to ``ui_dump_manifest_writer``.
        extraction_type: ``data`` uses the CONV_RESULT message as text and
            events in groups of three; ``flow`` uses the final ASR transcript
            of prompts containing "full name" and events in groups of four.
        start_delay: seconds subtracted from the STARTED_SPEAKING offset.
        download_only: only fetch metadata/audio of the calls, then return.
        strip_silent_chunks: run ``strip_silence`` on each chunk and skip
            chunks shorter than 0.5 seconds.
        call_logs_file: YAML file listing users and their call URIs.
        output_dir: directory receiving wavs, metas, plots and asr data.
        data_name: dataset name; defaults to the stem of ``call_logs_file``.
        mongo_uri: mongo URI handed to ``get_mongo_coll``.
    """
    from urllib.parse import urlsplit
    from functools import reduce
    import boto3
    from io import BytesIO
    import json
    from ruamel.yaml import YAML
    import re
    from google.protobuf.timestamp_pb2 import Timestamp
    from datetime import timedelta
    import librosa
    import librosa.display
    from lenses import lens
    from pprint import pprint
    import pandas as pd
    import matplotlib.pyplot as plt
    import matplotlib
    from tqdm import tqdm
    from .utils import ui_dump_manifest_writer, strip_silence, get_mongo_coll, get_call_logs
    from pydub import AudioSegment
    from natural.date import compress

    matplotlib.rcParams["agg.path.chunksize"] = 10000
    matplotlib.use("agg")

    yaml = YAML()
    s3 = boto3.client("s3")
    mongo_collection = get_mongo_coll(mongo_uri)
    call_media_dir: Path = output_dir / Path("call_wavs")
    call_media_dir.mkdir(exist_ok=True, parents=True)
    call_meta_dir: Path = output_dir / Path("call_metas")
    call_meta_dir.mkdir(exist_ok=True, parents=True)
    call_plot_dir: Path = output_dir / Path("plots")
    call_plot_dir.mkdir(exist_ok=True, parents=True)
    call_asr_data: Path = output_dir / Path("asr_data")
    call_asr_data.mkdir(exist_ok=True, parents=True)
    dataset_name = call_logs_file.stem if not data_name else data_name

    call_logs = yaml.load(call_logs_file.read_text())

    def gen_ev_fev_timedelta(fev):
        """Return a function giving an event's offset from event ``fev``."""
        fev_p = Timestamp()
        fev_p.FromJsonString(fev["CreatedTS"])
        fev_dt = fev_p.ToDatetime()
        td_0 = timedelta()

        def get_timedelta(ev):
            ev_p = Timestamp()
            ev_p.FromJsonString(value=ev["CreatedTS"])
            ev_dt = ev_p.ToDatetime()
            delta = ev_dt - fev_dt
            # events created before the reference event are clamped to zero
            return delta if delta > td_0 else td_0

        return get_timedelta

    def chunk_n(evs, n):
        """Split ``evs`` into consecutive lists of at most ``n`` items."""
        return [evs[i * n : (i + 1) * n] for i in range((len(evs) + n - 1) // n)]

    def has_event_types(evs, expected_types):
        """Tell whether ``evs`` is exactly the expected sequence of event types.

        Explicit comparison instead of ``assert`` so the validation also
        runs when python is started with ``-O``.
        """
        return [ev["Type"] for ev in evs] == expected_types

    if extraction_type == ExtractionType.data:
        # CONV_RESULT, STARTED_SPEAKING, STOPPED_SPEAKING
        events_per_sample = 3

        def is_utter_event(ev):
            return (
                (ev["Author"] == "CONV" or ev["Author"] == "ASR")
                and (ev["Type"] != "DEBUG")
                and ev["Type"] != "ASR_RESULT"
            )

        def get_data_points(utter_events, td_fn):
            expected_types = ["CONV_RESULT", "STARTED_SPEAKING", "STOPPED_SPEAKING"]
            data_points = []
            for evs in chunk_n(utter_events, events_per_sample):
                if not has_event_types(evs, expected_types):
                    # skipping invalid data_points
                    continue
                start_time = td_fn(evs[1]).total_seconds() - start_delay
                end_time = td_fn(evs[2]).total_seconds()
                spoken = evs[0]["Msg"]
                data_points.append(
                    {"start_time": start_time, "end_time": end_time, "code": spoken}
                )
            return data_points

        def text_extractor(spoken):
            # take the single-quoted part of longer messages, otherwise the
            # message itself
            return (
                re.search(r"'(.*)'", spoken).groups(0)[0]
                if len(spoken) > 6 and re.search(r"'(.*)'", spoken)
                else spoken
            )

    elif extraction_type == ExtractionType.flow:
        # CONV_RESULT, STARTED_SPEAKING, ASR_RESULT, STOPPED_SPEAKING
        events_per_sample = 4

        def is_final_asr_event_or_spoken(ev):
            pld = json.loads(ev["Payload"])
            return (
                pld["AsrResult"]["Results"][0]["IsFinal"]
                if ev["Type"] == "ASR_RESULT"
                else True
            )

        def is_utter_event(ev):
            return (
                ev["Author"] == "CONV"
                or (ev["Author"] == "ASR" and is_final_asr_event_or_spoken(ev))
            ) and (ev["Type"] != "DEBUG")

        def get_data_points(utter_events, td_fn):
            expected_types = [
                "CONV_RESULT",
                "STARTED_SPEAKING",
                "ASR_RESULT",
                "STOPPED_SPEAKING",
            ]
            data_points = []
            for evs in chunk_n(utter_events, events_per_sample):
                if not has_event_types(evs, expected_types):
                    # skipping invalid data_points
                    continue
                start_time = td_fn(evs[1]).total_seconds() - start_delay
                end_time = td_fn(evs[2]).total_seconds()
                conv_msg = evs[0]["Msg"]
                if "full name" in conv_msg.lower():
                    pld = json.loads(evs[2]["Payload"])
                    spoken = pld["AsrResult"]["Results"][0]["Alternatives"][0][
                        "Transcript"
                    ]
                    data_points.append(
                        {
                            "start_time": start_time,
                            "end_time": end_time,
                            "code": spoken,
                        }
                    )
            return data_points

        def text_extractor(spoken):
            return spoken

    def process_call(call_obj):
        """Fetch a call's event log and audio and select its utterance events."""
        call_meta = get_call_logs(call_obj, s3, call_meta_dir)
        call_events = call_meta["Events"]

        def is_writer_uri_event(ev):
            return ev["Author"] == "AUDIO_WRITER" and "s3://" in ev["Msg"]

        writer_events = list(filter(is_writer_uri_event, call_events))
        s3_wav_url = re.search(r"(s3://.*)", writer_events[0]["Msg"]).groups(0)[0]
        s3_wav_url_p = urlsplit(s3_wav_url)

        def is_first_audio_ev(state, ev):
            # state is (found, event); once found it is carried through unchanged
            if state[0]:
                return state
            else:
                return (ev["Author"] == "GATEWAY" and ev["Type"] == "AUDIO", ev)

        (_, first_audio_ev) = reduce(is_first_audio_ev, call_events, (False, {}))

        get_ev_fev_timedelta = gen_ev_fev_timedelta(first_audio_ev)

        uevs = list(filter(is_utter_event, call_events))
        ev_count = len(uevs)
        # keep whole samples only; the group size depends on the extraction
        # type (3 for data, 4 for flow) and must match get_data_points
        utter_events = uevs[: ev_count - ev_count % events_per_sample]
        saved_wav_path = call_media_dir / Path(Path(s3_wav_url_p.path).name)
        if not saved_wav_path.exists():
            print(f"downloading : {saved_wav_path} from {s3_wav_url}")
            s3.download_file(
                s3_wav_url_p.netloc, s3_wav_url_p.path[1:], str(saved_wav_path)
            )

        return {
            "wav_path": saved_wav_path,
            "num_samples": len(utter_events) // events_per_sample,
            "meta": call_obj,
            "first_event_fn": get_ev_fev_timedelta,
            "utter_events": utter_events,
        }

    def get_cid(uri):
        # the call id is the last path component of the call URI
        return Path(urlsplit(uri).path).stem

    def ensure_call(uri):
        cid = get_cid(uri)
        meta = mongo_collection.find_one({"SystemID": cid})
        process_meta = process_call(meta)
        return process_meta

    def retrieve_processed_callmeta(uri):
        cid = get_cid(uri)
        meta = mongo_collection.find_one({"SystemID": cid})
        duration = meta["EndTS"] - meta["StartTS"]
        process_meta = process_call(meta)
        data_points = get_data_points(
            process_meta["utter_events"], process_meta["first_event_fn"]
        )
        process_meta["data_points"] = data_points
        return {"url": uri, "meta": meta, "duration": duration, "process": process_meta}

    def retrieve_callmeta(call_uri):
        uri = call_uri["call_uri"]
        name = call_uri["name"]
        cid = get_cid(uri)
        meta = mongo_collection.find_one({"SystemID": cid})
        duration = meta["EndTS"] - meta["StartTS"]
        process_meta = process_call(meta)
        data_points = get_data_points(
            process_meta["utter_events"], process_meta["first_event_fn"]
        )
        process_meta["data_points"] = data_points
        return {
            "url": uri,
            "name": name,
            "meta": meta,
            "duration": duration,
            "process": process_meta,
        }

    def download_meta_audio():
        call_lens = lens["users"].Each()["calls"].Each()
        call_lens.modify(ensure_call)(call_logs)

    def plot_calls_data():
        def plot_data_points(y, sr, data_points, file_path):
            fig = plt.figure(figsize=(16, 12))
            try:
                librosa.display.waveplot(y=y, sr=sr)
                for dp in data_points:
                    start, end, code = dp["start_time"], dp["end_time"], dp["code"]
                    plt.axvspan(start, end, color="green", alpha=0.2)
                    text_pos = (start + end) / 2
                    plt.text(
                        text_pos,
                        0.25,
                        f"{code}",
                        rotation=90,
                        horizontalalignment="center",
                        verticalalignment="center",
                    )
                plt.title("Datapoints")
                plt.savefig(file_path, format="png")
            finally:
                # pyplot keeps every figure alive until it is closed; without
                # this one figure per plotted call accumulates in memory
                plt.close(fig)
            return file_path

        def plot_call(call_obj):
            saved_wav_path, data_points, sys_id = (
                call_obj["process"]["wav_path"],
                call_obj["process"]["data_points"],
                call_obj["meta"]["SystemID"],
            )
            file_path = call_plot_dir / Path(sys_id).with_suffix(".png")
            # existing plots are not regenerated
            if not file_path.exists():
                print(f"plotting: {file_path}")
                (y, sr) = librosa.load(saved_wav_path)
                plot_data_points(y, sr, data_points, str(file_path))
            return file_path

        call_lens = lens["users"].Each()["calls"].Each()
        call_stats = call_lens.modify(retrieve_processed_callmeta)(call_logs)
        call_plots = call_lens.modify(plot_call)(call_stats)
        pprint(call_plots)

    def extract_data_points():
        if strip_silent_chunks:

            def audio_process(seg):
                return strip_silence(seg)

        else:

            def audio_process(seg):
                return seg

        def gen_data_values(saved_wav_path, data_points, caller_name):
            call_seg = (
                AudioSegment.from_wav(saved_wav_path)
                .set_channels(1)
                .set_sample_width(2)
                .set_frame_rate(24000)
            )
            for dp_id, dp in enumerate(data_points):
                start, end, spoken = dp["start_time"], dp["end_time"], dp["code"]
                spoken_seg = audio_process(call_seg[start * 1000 : end * 1000])
                spoken_fb = BytesIO()
                spoken_seg.export(spoken_fb, format="wav")
                spoken_wav = spoken_fb.getvalue()
                # search for actual pnr code and handle plain codes as well
                extracted_code = text_extractor(spoken)
                if strip_silent_chunks and spoken_seg.duration_seconds < 0.5:
                    print(f'transcript chunk "{spoken}" contains no audio skipping.')
                    continue
                yield extracted_code, spoken_seg.duration_seconds, spoken_wav, caller_name, spoken_seg

        call_lens = lens["users"].Each()["calls"].Each()

        def assign_user_call(uc):
            # attach the user's name to each of the user's call URIs
            return (
                lens["calls"]
                .Each()
                .modify(lambda c: {"call_uri": c, "name": uc["name"]})(uc)
            )

        user_call_logs = lens["users"].Each().modify(assign_user_call)(call_logs)
        call_stats = call_lens.modify(retrieve_callmeta)(user_call_logs)
        call_objs = call_lens.collect()(call_stats)

        def data_source():
            for call_obj in tqdm(call_objs):
                saved_wav_path, data_points, name = (
                    call_obj["process"]["wav_path"],
                    call_obj["process"]["data_points"],
                    call_obj["name"],
                )
                for dp in gen_data_values(saved_wav_path, data_points, name):
                    yield dp

        ui_dump_manifest_writer(call_asr_data, dataset_name, data_source())

    def show_leaderboard():
        def compute_user_stats(call_stat):
            n_samples = (
                lens["calls"].Each()["process"]["num_samples"].get_monoid()(call_stat)
            )
            n_duration = lens["calls"].Each()["duration"].get_monoid()(call_stat)
            return {
                "num_samples": n_samples,
                "duration": n_duration.total_seconds(),
                "samples_rate": n_samples / n_duration.total_seconds(),
                "duration_str": compress(n_duration, pad=" "),
                "name": call_stat["name"],
            }

        call_lens = lens["users"].Each()["calls"].Each()
        call_stats = call_lens.modify(retrieve_processed_callmeta)(call_logs)
        user_stats = lens["users"].Each().modify(compute_user_stats)(call_stats)
        leader_df = (
            pd.DataFrame(user_stats["users"])
            .sort_values(by=["duration"], ascending=False)
            .reset_index(drop=True)
        )
        leader_df["rank"] = leader_df.index + 1
        leader_board = leader_df.rename(
            columns={
                "rank": "Rank",
                "num_samples": "Count",
                "name": "Name",
                "samples_rate": "SpeechRate",
                "duration_str": "Duration",
            }
        )[["Rank", "Name", "Count", "Duration"]]
        print(
            """ASR Dataset Leaderboard :
---------------------------------"""
        )
        print(leader_board.to_string(index=False))

    if download_only:
        download_meta_audio()
        return
    if leaderboard:
        show_leaderboard()
    if plot_calls:
        plot_calls_data()
    if extract_data:
        extract_data_points()
def main():
    """Run the typer command-line application of this module."""
    app()


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,27 @@
import typer
from pathlib import Path
from .utils import generate_dates
app = typer.Typer()
@app.command()
def export_conv_json(
    conv_src: Path = typer.Option(Path("./conv_data.json"), show_default=True),
    conv_dest: Path = typer.Option(Path("./data/conv_data.json"), show_default=True),
):
    """Copy the conversation json, setting its "dates" entry to generated dates.

    Reads ``conv_src``, stores the result of ``generate_dates()`` under the
    key "dates" and writes the whole structure to ``conv_dest``.
    """
    from .utils import ExtendedPath

    payload = ExtendedPath(conv_src).read_json()
    payload.update({"dates": generate_dates()})
    ExtendedPath(conv_dest).write_json(payload)
def main():
    """Run the typer command-line application of this module."""
    app()


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,98 @@
from pathlib import Path
import typer
import pandas as pd
from ruamel.yaml import YAML
from itertools import product
from .utils import generate_dates
app = typer.Typer()
def unique_entity_list(entity_template_tags, entity_data):
    """Return the distinct utterances that contain at least one template tag.

    Args:
        entity_template_tags: substrings marking an entity slot in a template.
        entity_data: table indexable by the column names
            ``Answer.utterance-1`` .. ``Answer.utterance-4``, each column
            providing ``tolist()``.

    Returns:
        Sorted list of the unique matching utterances, so the result is the
        same on every run regardless of set ordering.
    """
    utterance_columns = [f"Answer.utterance-{n}" for n in range(1, 5)]
    unique_entity_set = {
        utterance
        for column in utterance_columns
        for utterance in entity_data[column].tolist()
        # empty csv cells are loaded as NaN/None, not text; the substring
        # test below would raise TypeError on them
        if isinstance(utterance, str)
        and any(tag in utterance for tag in entity_template_tags)
    }
    return sorted(unique_entity_set)
def nlu_entity_reader(nlu_data_file: Path = Path("./nlu_data.yaml")):
    """Yield ``(entity name, template tags, data)`` for each configured entity.

    The YAML file lists csv files under ``csv_files``; each has a ``fname``
    and ``entities``. When an entity defines ``filter`` only the rows whose
    ``filter_key`` column (named on the csv entry) equals that value are
    yielded, otherwise the whole csv table is.
    """
    config = YAML().load(nlu_data_file.read_text())
    for csv_spec in config["csv_files"]:
        frame = pd.read_csv(csv_spec["fname"])
        for entity in csv_spec["entities"]:
            name, tags = entity["name"], entity["tags"]
            selected = (
                frame[frame[csv_spec["filter_key"]] == entity["filter"]]
                if "filter" in entity
                else frame
            )
            yield name, tags, selected
def nlu_samples_reader(nlu_data_file: Path = Path("./nlu_data.yaml")):
    """Return the ``samples_per_entity`` entries of the YAML file keyed by name."""
    config = YAML().load(nlu_data_file.read_text())
    samples_by_name = {}
    for spec in config["samples_per_entity"]:
        samples_by_name[spec["name"]] = spec
    return samples_by_name
@app.command()
def compute_unique_nlu_stats(
    nlu_data_file: Path = typer.Option(Path("./nlu_data.yaml"), show_default=True),
):
    """Print, per entity, the number of unique tagged utterances (tab separated)."""
    for name, tags, frame in nlu_entity_reader(nlu_data_file):
        unique_count = len(unique_entity_list(tags, frame))
        print(f"{name}\t{unique_count}")
def replace_entity(tmpl, value, tags):
    """Return ``tmpl`` with every occurrence of each tag replaced by ``value``.

    Tags are substituted one after another in the given order, each on the
    result of the previous substitution.
    """
    filled = tmpl
    for tag in tags:
        filled = filled.replace(tag, value)
    return filled
@app.command()
def export_nlu_conv_json(
    conv_src: Path = typer.Option(Path("./conv_data.json"), show_default=True),
    conv_dest: Path = typer.Option(Path("./data/conv_data.json"), show_default=True),
    nlu_data_file: Path = typer.Option(Path("./nlu_data.yaml"), show_default=True),
):
    """Generate conversation data by filling nlu templates with entity values.

    For each entity read from ``nlu_data_file``, ``test_size`` values are
    sampled from the matching list in ``conv_src`` (whose "Dates" list is
    replaced by ``generate_dates()``). For every sampled value, ``samples``
    unique tagged utterances are sampled and the value substituted for the
    tags. The resulting lists, keyed by entity name, go to ``conv_dest``.
    """
    from .utils import ExtendedPath
    from random import sample

    sample_specs = nlu_samples_reader(nlu_data_file)
    conv_data = ExtendedPath(conv_src).read_json()
    conv_data["Dates"] = generate_dates()
    generated = {}
    data_count = 0
    for name, tags, frame in nlu_entity_reader(nlu_data_file):
        spec = sample_specs[name]
        variants = sample(conv_data[name], spec["test_size"])
        templates = unique_entity_list(tags, frame)
        # a fresh set of templates is drawn for every sampled value
        generated[name] = [
            replace_entity(template, variant, tags)
            for variant in variants
            for template in sample(templates, sample_specs[name]["samples"])
        ]
        data_count += len(generated[name])
    print(f"Total of {data_count} variants generated")
    ExtendedPath(conv_dest).write_json(generated)
def main():
    """Run the typer command-line application of this module."""
    app()


if __name__ == "__main__":
    main()

77
jasper/data/process.py Normal file
View File

@@ -0,0 +1,77 @@
import json
from pathlib import Path
from sklearn.model_selection import train_test_split
from .utils import asr_manifest_reader, asr_manifest_writer
from typing import List
from itertools import chain
import typer
app = typer.Typer()
@app.command()
def fixate_data(dataset_path: Path):
    """Write ``abs_manifest.json`` with audio paths prefixed by ``dataset_path``.

    Every entry of ``dataset_path/manifest.json`` is rewritten so that its
    ``audio_filepath`` becomes ``dataset_path / audio_filepath``.
    """
    source_manifest = dataset_path / Path("manifest.json")
    target_manifest = dataset_path / Path("abs_manifest.json")

    def entries_with_dataset_prefix():
        for entry in asr_manifest_reader(source_manifest):
            prefixed = dataset_path / Path(entry["audio_filepath"])
            entry["audio_filepath"] = str(prefixed)
            yield entry

    asr_manifest_writer(target_manifest, entries_with_dataset_prefix())
@app.command()
def augment_data(src_dataset_paths: List[Path], dest_dataset_path: Path):
    """Concatenate the ``abs_manifest.json`` of several datasets into one.

    The destination directory is created when missing and receives an
    ``abs_manifest.json`` holding the entries of all sources in the order
    the source datasets were given.
    """
    manifest_name = Path("abs_manifest.json")
    readers = [
        asr_manifest_reader(source_path / manifest_name)
        for source_path in src_dataset_paths
    ]
    dest_dataset_path.mkdir(parents=True, exist_ok=True)
    asr_manifest_writer(
        dest_dataset_path / manifest_name, chain.from_iterable(readers)
    )
@app.command()
def split_data(dataset_path: Path, test_size: float = 0.1):
    """Split ``abs_manifest.json`` into train and test manifests.

    The entries are divided with ``train_test_split`` using ``test_size`` and
    written next to the source manifest as ``train_manifest.json`` and
    ``test_manifest.json``.
    """
    manifest_path = dataset_path / Path("abs_manifest.json")
    entries = list(asr_manifest_reader(manifest_path))
    splits = train_test_split(entries, test_size=test_size)
    file_names = ("train_manifest.json", "test_manifest.json")
    for file_name, subset in zip(file_names, splits):
        asr_manifest_writer(manifest_path.with_name(file_name), subset)
@app.command()
def validate_data(dataset_path: Path):
    """Check the train and test manifests of a dataset and report the result.

    Each line of ``train_manifest.json`` and ``test_manifest.json`` must be a
    json object with a ``duration`` and an ``audio_filepath`` that exists
    relative to the manifest's directory. Every failing line is printed with
    its index; the summary states whether failures were found and the total
    duration of the valid entries.
    """
    from natural.date import compress
    from datetime import timedelta

    for mf_type in ["train_manifest.json", "test_manifest.json"]:
        data_file = dataset_path / Path(mf_type)
        print(f"validating {data_file}.")
        with Path(data_file).open("r") as pf:
            pnr_jsonl = pf.readlines()
        duration = 0
        error_count = 0
        for (i, s) in enumerate(pnr_jsonl):
            try:
                d = json.loads(s)
                audio_file = data_file.parent / Path(d["audio_filepath"])
                if not audio_file.exists():
                    raise OSError(f"File {audio_file} not found")
                # counted only after the entry proved valid, so the reported
                # total covers audio that is actually present
                duration += d["duration"]
            # malformed json (ValueError), missing keys, wrong value types and
            # missing files are data errors; anything else (e.g. Ctrl-C)
            # must propagate
            except (ValueError, KeyError, TypeError, OSError) as e:
                error_count += 1
                print(f'failed on {i} with "{e}"')
        duration_str = compress(timedelta(seconds=duration), pad=" ")
        if error_count:
            print(
                f"found {error_count} invalid entries in {mf_type}. valid entries contain {duration_str}sec of audio"
            )
        else:
            print(
                f"no errors found. seems like a valid {mf_type}. contains {duration_str}sec of audio"
            )
def main():
    """Run the typer command-line application of this module."""
    app()


if __name__ == "__main__":
    main()

175
jasper/data/rev_recycler.py Normal file
View File

@@ -0,0 +1,175 @@
import typer
from itertools import chain
from io import BytesIO
from pathlib import Path
import re
app = typer.Typer()
@app.command()
def extract_data(
    call_audio_dir: Path = typer.Option(Path("/dataset/rev/wavs"), show_default=True),
    call_meta_dir: Path = typer.Option(Path("/dataset/rev/jsons"), show_default=True),
    output_dir: Path = typer.Option(Path("./data"), show_default=True),
    dataset_name: str = typer.Option("rev_transribed", show_default=True),
    verbose: bool = False,
):
    """Extract per-monologue ASR data points from transcribed call recordings.

    Every ``*.wav`` under ``call_audio_dir`` is paired with the same-named
    ``.json`` under ``call_meta_dir``. Each monologue in the json is cut out of
    the audio using its first ``timestamp`` and last ``end_timestamp``,
    silence-stripped and written through ``asr_data_writer`` into
    ``output_dir/asr_data/dataset_name``.
    """
    from pydub import AudioSegment
    from .utils import ExtendedPath, asr_data_writer, strip_silence
    from lenses import lens
    import datetime

    call_asr_data: Path = output_dir / Path("asr_data")
    call_asr_data.mkdir(exist_ok=True, parents=True)

    def wav_event_generator(call_audio_dir):
        # Yield (audio, wav_path, metadata) for every wav that has a metadata json.
        for wav_path in call_audio_dir.glob("**/*.wav"):
            rel_meta_path = wav_path.with_suffix(".json").relative_to(call_audio_dir)
            meta_path = call_meta_dir / rel_meta_path
            # Check for metadata first so audio without a transcript is never decoded.
            if not meta_path.exists():
                if verbose:
                    typer.echo(f"missing json corresponding to {wav_path}")
                continue
            if verbose:
                typer.echo(f"loading events for file {wav_path}")
            call_wav = AudioSegment.from_file_using_temporary_files(wav_path)
            events = ExtendedPath(meta_path).read_json()
            yield call_wav, wav_path, events

    def time_to_msecs(time_str):
        # "HH:MM:SS,fff" -> milliseconds; strptime defaults the date to 1900-01-01.
        return (
            datetime.datetime.strptime(time_str, "%H:%M:%S,%f")
            - datetime.datetime(1900, 1, 1)
        ).total_seconds() * 1000

    def process_utterance_chunk(wav_seg, start_time, end_time, monologue):
        # Returns (wav bytes, duration in seconds, cleaned transcript) for one monologue.
        # offset by 1sec left side to include vad? discarded audio; clamp at 0
        # because a negative start would be interpreted relative to the END of
        # the segment for utterances beginning in the first second.
        start_ms = max(0, time_to_msecs(start_time) - 1000)
        full_tscript_wav_seg = wav_seg[start_ms : time_to_msecs(end_time)]
        tscript_wav_seg = strip_silence(full_tscript_wav_seg)
        tscript_wav_fb = BytesIO()
        tscript_wav_seg.export(tscript_wav_fb, format="wav")
        tscript_wav = tscript_wav_fb.getvalue()
        text = "".join(lens["elements"].Each()["value"].collect()(monologue))
        # Non-greedy: drop each [tag] on its own instead of everything between
        # the first "[" and the last "]".
        text_clean = re.sub(r"\[.*?\]", "", text)
        return tscript_wav, tscript_wav_seg.duration_seconds, text_clean

    def monologue_time_span(monologue):
        # First start and last end timestamp; IndexError when either is absent.
        elements = lens["elements"].Each()
        start_time = (
            elements.Filter(lambda x: "timestamp" in x)["timestamp"].collect()(
                monologue
            )[0]
        )
        end_time = (
            elements.Filter(lambda x: "end_timestamp" in x)["end_timestamp"].collect()(
                monologue
            )[-1]
        )
        return start_time, end_time

    def monologue_asr_data(audio_for_monologue, wav_path, meta):
        # Shared mono/dual loop: yields (text, duration, wav bytes) per usable monologue.
        for monologue in lens["monologues"].Each().collect()(meta):
            speaker_audio = audio_for_monologue(monologue)
            if speaker_audio is None:
                if verbose:
                    print(
                        f'unknown speaker tag {monologue["speaker_name"]} in wav:{wav_path} skipping.'
                    )
                continue
            try:
                start_time, end_time = monologue_time_span(monologue)
            except IndexError:
                if verbose:
                    print(
                        f"error when loading timestamp events in wav:{wav_path} skipping."
                    )
                continue
            tscript_wav, seg_dur, text_clean = process_utterance_chunk(
                speaker_audio, start_time, end_time, monologue
            )
            if seg_dur < 0.5:
                if verbose:
                    print(
                        f'transcript chunk "{text_clean}" contains no audio in {wav_path} skipping.'
                    )
                continue
            yield text_clean, seg_dur, tscript_wav

    def dual_asr_data_generator(wav_seg, wav_path, meta):
        # Stereo call: the speaker name selects the channel to cut from.
        left_audio, right_audio = wav_seg.split_to_mono()
        channel_map = {"Agent": right_audio, "Client": left_audio}
        yield from monologue_asr_data(
            lambda monologue: channel_map.get(monologue["speaker_name"]),
            wav_path,
            meta,
        )

    def mono_asr_data_generator(wav_seg, wav_path, meta):
        # Mono call: every monologue is cut from the single channel.
        yield from monologue_asr_data(lambda monologue: wav_seg, wav_path, meta)

    def generate_rev_asr_data():
        full_asr_data = []
        total_duration = 0
        for wav, wav_path, ev in wav_event_generator(call_audio_dir):
            if wav.channels > 2:
                print(f"skipping many channel audio {wav_path}")
                # Actually skip: split_to_mono() would not unpack into two channels.
                continue
            asr_data_generator = (
                mono_asr_data_generator
                if wav.channels == 1
                else dual_asr_data_generator
            )
            asr_data = asr_data_generator(wav, wav_path, ev)
            total_duration += wav.duration_seconds
            full_asr_data.append(asr_data)
        typer.echo(f"loaded {len(full_asr_data)} calls of duration {total_duration}s")
        n_dps = asr_data_writer(call_asr_data, dataset_name, chain(*full_asr_data))
        typer.echo(f"written {n_dps} data points")

    generate_rev_asr_data()
def main():
    """Run the typer CLI application defined in this module."""
    app()


if __name__ == "__main__":
    main()

57
jasper/data/server.py Normal file
View File

@@ -0,0 +1,57 @@
import os
from pathlib import Path
import typer
import rpyc
from rpyc.utils.server import ThreadedServer
import nemo
import pickle
# import nemo.collections.asr as nemo_asr
from nemo.collections.asr.parts.segment import AudioSegment
app = typer.Typer()
# Instantiate NeMo's NeuralModuleFactory (PyTorch backend, CPU placement) at
# import time; the returned object is discarded.
# NOTE(review): presumably done for its global side effects before
# AudioSegment is used by the service below — confirm.
nemo.core.NeuralModuleFactory(
    backend=nemo.core.Backend.PyTorch, placement=nemo.core.DeviceType.CPU
)
class ASRDataService(rpyc.Service):
    """rpyc service that loads audio samples and reads raw files for remote callers."""

    def exposed_get_path_samples(
        self, file_path, target_sr, int_values, offset, duration, trim
    ):
        """Load ``file_path`` through nemo's AudioSegment and return its samples pickled."""
        print(f"loading.. {file_path}")
        load_options = {
            "target_sr": target_sr,
            "int_values": int_values,
            "offset": offset,
            "duration": duration,
            "trim": trim,
        }
        segment = AudioSegment.from_file(file_path, **load_options)
        # Samples are pickled so they cross the rpyc boundary as a single bytes value.
        return pickle.dumps(segment.samples)

    def exposed_read_path(self, file_path):
        """Return the raw bytes of ``file_path``."""
        target = Path(file_path)
        return target.read_bytes()
@app.command()
def run_server(port: int = 0):
    """Start the rpyc ASR data server on the given or environment-configured port."""
    if port:
        listen_port = port
    else:
        # Fall back to the environment variable, then to 8064.
        listen_port = int(os.environ.get("ASR_DARA_RPYC_PORT", "8064"))
    server = ThreadedServer(
        ASRDataService(),
        port=listen_port,
        protocol_config={"allow_all_attrs": True},
    )
    typer.echo(f"starting asr server on {listen_port}...")
    server.start()
def main():
    """Run the typer CLI application defined in this module."""
    app()


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,180 @@
import typer
from pathlib import Path
import json
# from .utils import generate_dates, asr_test_writer
app = typer.Typer()
def run_test(reg_path, coll, s3, call_meta_dir, city_code, test_path):
    """Run one regression test file and score the entity recognised by the NLU.

    Arguments:
        reg_path -- path of the regression runner script, invoked through bash
        coll -- mongo collection holding call metadata keyed by ``CallID``
        s3 -- s3 client handed to ``get_call_logs``
        call_meta_dir -- directory handed to ``get_call_logs``
        city_code -- mapping from expected entity text to its code
        test_path -- ``.reg`` test file; a sibling ``.result.json`` holds the
            expected data

    Returns:
        A dict with one result row (``result`` is "Success", "Fail", "Error"
        or "Empty"), or ``None`` when the runner exits with a non-zero status.
    """
    from time import sleep
    import subprocess
    from .utils import ExtendedPath, get_call_logs

    call_query = {"CallID": test_path.name}
    # Remove stale records so the poll below only sees this run's call.
    coll.delete_many(call_query)
    test_output = subprocess.run(
        ["/bin/bash", "-c", f"{str(reg_path)} --addr [::]:15400 -f {str(test_path)}"]
    )
    if test_output.returncode != 0:
        print(f"Error running test {test_path}")
        return

    # Poll every 2 seconds until the call record appears. Iterative on purpose:
    # a recursive retry hits the interpreter recursion limit on long waits.
    call_meta = coll.find_one(call_query)
    while not call_meta:
        sleep(2)
        call_meta = coll.find_one(call_query)

    call_logs = get_call_logs(call_meta, s3, call_meta_dir)
    call_events = call_logs["Events"]
    test_data_path = test_path.with_suffix(".result.json")
    test_data = ExtendedPath(test_data_path).read_json()

    def is_final_asr_event_or_spoken(ev):
        pld = json.loads(ev["Payload"])
        return (
            pld["AsrResult"]["Results"][0]["IsFinal"]
            if ev["Type"] == "ASR_RESULT"
            else False
        )

    def is_test_event(ev):
        return (
            ev["Author"] == "NLU"
            or (ev["Author"] == "ASR" and is_final_asr_event_or_spoken(ev))
        ) and (ev["Type"] != "DEBUG")

    def blank_result(status):
        # Row carrying only the expected data, used when nothing could be scored.
        expected = test_data[0]
        return {
            "expected_entity": expected["entity"],
            "text": expected["text"],
            "audio_filepath": expected["audio_filepath"],
            "pretrained_asr": expected["pretrained_asr"],
            "entity_asr": "",
            "google_asr": "",
            "nlu_result": "",
            "asr_entity": "",
            "nlu_entity": "",
            "result": status,
        }

    test_evs = list(filter(is_test_event, call_events))
    # Exactly two events are scored: the first is read as the ASR result, the
    # second as the NLU result.
    if len(test_evs) != 2:
        return blank_result("Empty")
    try:
        asr_payload = test_evs[0]["Payload"]
        asr_result = json.loads(asr_payload)["AsrResult"]["Results"][0]
        alt_tscripts = [alt["Transcript"] for alt in asr_result["Alternatives"]]
        gcp_result = "|".join(alt_tscripts)
        entity_asr = asr_result["AsrDynamicResults"][0]["Candidate"]["Transcript"]
        nlu_payload = test_evs[1]["Payload"]
        nlu_result_payload = json.loads(nlu_payload)["NluResults"]
        entity = test_data[0]["entity"]
        text = test_data[0]["text"]
        audio_filepath = test_data[0]["audio_filepath"]
        pretrained_asr = test_data[0]["pretrained_asr"]
        nlu_entity = list(json.loads(nlu_result_payload)["Entities"].values())[0]
        asr_entity = city_code[entity] if entity in city_code else "UNKNOWN"
        entities_match = asr_entity == nlu_entity
        result = "Success" if entities_match else "Fail"
        return {
            "expected_entity": entity,
            "text": text,
            "audio_filepath": audio_filepath,
            "pretrained_asr": pretrained_asr,
            "entity_asr": entity_asr,
            "google_asr": gcp_result,
            "nlu_result": nlu_result_payload,
            "asr_entity": asr_entity,
            "nlu_entity": nlu_entity,
            "result": result,
        }
    except Exception:
        # Best effort: a malformed payload becomes an "Error" row, not a crash.
        return blank_result("Error")
@app.command()
def evaluate_slu(
    data_name: str = typer.Option("call_upwork_test_cnd_cities", show_default=True),
    dump_dir: Path = Path("./data/asr_data"),
    call_meta_dir: Path = Path("./data/call_metas"),
    test_file_pref: str = "asr_test",
    mongo_uri: str = typer.Option(
        "mongodb://localhost:27017/test.calls", show_default=True
    ),
    test_results: Path = Path("./data/results.csv"),
    airport_codes: Path = Path("./airports_code.csv"),
    reg_path: Path = Path("../saas_reg/regression/run.sh"),
    test_id: str = "5ef481f27031edf6910e94e0",
):
    """Run every matching .reg test file and write one CSV row per scored test.

    Test files are ``<test_file_pref>*.reg`` under ``dump_dir/data_name``. They
    are executed through ``run_test`` on a pool of 8 threads and the returned
    rows are written to ``test_results``. Tests for which ``run_test`` returns
    nothing (the runner failed) are left out of the CSV.
    """
    from .utils import get_mongo_coll
    import pandas as pd
    import boto3
    from concurrent.futures import ThreadPoolExecutor
    from functools import partial
    import csv
    from tqdm import tqdm

    s3 = boto3.client("s3")
    df = pd.read_csv(airport_codes)[["iata", "city"]]
    # city name -> iata code
    city_code = pd.Series(df["iata"].values, index=df["city"]).to_dict()

    test_files = list((dump_dir / data_name).glob(test_file_pref + "*.reg"))
    coll = get_mongo_coll(mongo_uri)
    # newline="" is what the csv module requires of files it writes to.
    with test_results.open("w", newline="") as csvfile:
        fieldnames = [
            "expected_entity",
            "text",
            "audio_filepath",
            "pretrained_asr",
            "entity_asr",
            "google_asr",
            "nlu_result",
            "asr_entity",
            "nlu_entity",
            "result",
        ]
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()
        with ThreadPoolExecutor(max_workers=8) as exe:
            print("starting all loading tasks")
            for test_result in tqdm(
                exe.map(
                    partial(run_test, reg_path, coll, s3, call_meta_dir, city_code),
                    test_files,
                ),
                position=0,
                leave=True,
                total=len(test_files),
            ):
                # run_test returns None when the runner failed; DictWriter
                # cannot write that, so skip it.
                if test_result is None:
                    continue
                writer.writerow(test_result)
def main():
    """Run the typer CLI application defined in this module."""
    app()


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,99 @@
import typer
from pathlib import Path
from .utils import generate_dates, asr_test_writer
app = typer.Typer()
@app.command()
def export_test_reg(
    conv_src: Path = typer.Option(Path("./conv_data.json"), show_default=True),
    data_name: str = typer.Option("call_upwork_test_cnd_cities", show_default=True),
    extraction_key: str = "Cities",
    dump_dir: Path = Path("./data/asr_data"),
    dump_file: Path = Path("ui_dump.json"),
    manifest_file: Path = Path("manifest.json"),
    test_file: Path = Path("asr_test.reg"),
):
    """Transcribe a dataset with several ASR services and export it as a .reg test.

    Every manifest entry is merged with its ui-dump record, transcribed by the
    GCP and jasper transcribers, tagged with the longest entity from
    ``conv_src[extraction_key]`` found in its text, and written through
    ``asr_test_writer``. Entries in which no entity is found are left out.
    """
    from .utils import (
        ExtendedPath,
        asr_manifest_reader,
        gcp_transcribe_gen,
        parallel_apply,
    )
    from ..client import transcribe_gen
    from pydub import AudioSegment

    # ASR service port for each extraction type
    jasper_map = {
        "PNRs": 8045,
        "Cities": 8046,
        "Names": 8047,
        "Dates": 8048,
    }
    transcriber_gcp = gcp_transcribe_gen()
    transcriber_trained = transcribe_gen(asr_port=jasper_map[extraction_key])
    transcriber_all_trained = transcribe_gen(asr_port=8050)
    transcriber_libri_all_trained = transcribe_gen(asr_port=8051)

    def find_ent(dd, conv_data):
        # Longest entity contained in the text (ties: lexicographically
        # smallest); None when nothing matches so the caller can skip the entry.
        matches = [ent for ent in conv_data if ent in dd["text"]]
        if not matches:
            return None
        return min(matches, key=lambda ent: (-len(ent), ent))

    def process_data(d):
        orig_seg = AudioSegment.from_wav(d["audio_path"])
        jas_seg = orig_seg.set_channels(1).set_sample_width(2).set_frame_rate(24000)
        gcp_seg = orig_seg.set_channels(1).set_sample_width(2).set_frame_rate(16000)
        deepgram_file = Path("/home/shubham/voice_auto/pnrs/wav/") / Path(
            d["audio_path"].stem + ".txt"
        )
        if deepgram_file.exists():
            d["deepgram"] = "".join(
                [s.replace("CHANNEL 0:", "") for s in deepgram_file.read_text().split("\n")]
            )
        else:
            d["deepgram"] = 'Not Found'
        d["audio_path"] = str(d["audio_path"])
        d["gcp_transcript"] = transcriber_gcp(gcp_seg.raw_data)
        d["jasper_trained"] = transcriber_trained(jas_seg.raw_data)
        d["jasper_all"] = transcriber_all_trained(jas_seg.raw_data)
        d["jasper_libri"] = transcriber_libri_all_trained(jas_seg.raw_data)
        return d

    conv_data = ExtendedPath(conv_src).read_json()
    conv_data["Dates"] = generate_dates()

    dump_data_path = dump_dir / Path(data_name) / dump_file
    ui_dump_data = ExtendedPath(dump_data_path).read_json()["data"]
    ui_dump_map = {i["utterance_id"]: i for i in ui_dump_data}
    manifest_path = dump_dir / Path(data_name) / manifest_file
    test_points = list(asr_manifest_reader(manifest_path))
    # Manifest fields take precedence over the ui-dump fields.
    test_data_objs = [{**(ui_dump_map[t["audio_path"].stem]), **t} for t in test_points]
    test_data = parallel_apply(process_data, test_data_objs)
    test_path = dump_dir / Path(data_name) / test_file

    def dd_gen(dump_data):
        for dd in dump_data:
            ent = find_ent(dd, conv_data[extraction_key])
            dd["entity"] = ent
            if ent:
                yield dd

    asr_test_writer(test_path, dd_gen(test_data))
def main():
    """Run the typer CLI application defined in this module."""
    app()


if __name__ == "__main__":
    main()

View File

View File

@@ -0,0 +1,52 @@
from logging import getLogger
from google.cloud import texttospeech
LOGGER = getLogger("googletts")
class GoogleTTS(object):
    """Wrapper around the Google Cloud Text-to-Speech client."""

    def __init__(self):
        self.client = texttospeech.TextToSpeechClient()

    def text_to_speech(self, text: str, params: dict) -> bytes:
        """Synthesize SSML ``text`` as LINEAR16 audio.

        ``params`` must provide the ``language``, ``name`` and ``sample_rate``
        keys.
        """
        synthesis_input = texttospeech.types.SynthesisInput(ssml=text)
        voice_selection = texttospeech.types.VoiceSelectionParams(
            language_code=params["language"], name=params["name"]
        )
        output_config = texttospeech.types.AudioConfig(
            audio_encoding=texttospeech.enums.AudioEncoding.LINEAR16,
            sample_rate_hertz=params["sample_rate"],
        )
        return self.client.synthesize_speech(
            synthesis_input, voice_selection, output_config
        ).audio_content

    @classmethod
    def voice_list(cls):
        """Lists the available voices that support an English language code."""
        listing = cls().client.list_voices()
        english_voices = []
        for voice in listing.voices:
            english_codes = [code for code in voice.language_codes if code[:2] == "en"]
            if len(english_codes) == 0:
                # voice has no English language code
                continue
            gender = texttospeech.enums.SsmlVoiceGender(voice.ssml_gender)
            english_voices.append(
                {
                    "name": voice.name,
                    "language": ",".join(english_codes),
                    "gender": gender.name,
                    "engine": "wavenet" if "Wav" in voice.name else "standard",
                    "sample_rate": voice.natural_sample_rate_hertz,
                }
            )
        return english_voices

View File

@@ -0,0 +1,26 @@
"""
TTSClient Abstract Class
"""
from abc import ABC, abstractmethod
class TTSClient(ABC):
    """
    Base class for TTS
    """

    @abstractmethod
    def text_to_speech(self, text: str, num_channels: int, sample_rate: int,
                       audio_encoding) -> bytes:
        """
        convert text to audio bytes

        Arguments:
            text {str} -- text to convert
            num_channels {int} -- channel setting of the output audio bytes
            sample_rate {int} -- sample rate of the output audio bytes
            audio_encoding -- encoding of the output audio bytes

        Returns:
            bytes -- the synthesized audio
        """

View File

@@ -0,0 +1,62 @@
# import io
# import sys
# import json
import argparse
import logging
from pathlib import Path
from .utils import random_pnr_generator, asr_data_writer
from .tts.googletts import GoogleTTS
from tqdm import tqdm
import random
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
def pnr_tts_streamer(count):
    """Yield (pnr_code, duration_seconds, wav_bytes) for ``count`` random PNR codes.

    Each code is spoken verbatim by a randomly chosen Google voice at 24 kHz.
    """
    voices = GoogleTTS.voice_list()
    tts_client = GoogleTTS()
    for pnr_code in tqdm(random_pnr_generator(count)):
        ssml = f'<speak><say-as interpret-as="verbatim">{pnr_code}</say-as></speak>'
        voice = random.choice(voices)
        # The chosen voice dict is updated in place.
        voice.update(sample_rate=24000, num_channels=1)
        wav_data = tts_client.text_to_speech(text=ssml, params=voice)
        # Skip the first 44 bytes, then 2 bytes per sample at 24000 samples/sec.
        payload_size = len(wav_data[44:])
        yield pnr_code, payload_size / (2 * 24000), wav_data
def generate_asr_data_fromtts(output_dir, dataset_name, count):
    """Synthesize ``count`` PNR utterances and write them as an ASR dataset."""
    tts_stream = pnr_tts_streamer(count)
    asr_data_writer(output_dir, dataset_name, tts_stream)
def arg_parser():
    """Build the command line parser for the TTS data generator."""
    parser = argparse.ArgumentParser(
        prog=Path(__file__).stem, description="generates asr training data"
    )
    # (flag, add_argument keyword arguments), registered in this order
    options = (
        (
            "--output_dir",
            dict(
                type=Path,
                default=Path("./train/asr_data"),
                help="directory to output asr data",
            ),
        ),
        (
            "--count",
            dict(type=int, default=3, help="number of datapoints to generate"),
        ),
        (
            "--dataset_name",
            dict(type=str, default="pnr_data", help="name of the dataset"),
        ),
    )
    for flag, spec in options:
        parser.add_argument(flag, **spec)
    return parser
def main():
    """Parse the command line and generate the TTS dataset."""
    cli_args = arg_parser().parse_args()
    generate_asr_data_fromtts(**vars(cli_args))


if __name__ == "__main__":
    main()

485
jasper/data/utils.py Normal file
View File

@@ -0,0 +1,485 @@
import io
import os
import json
import wave
from pathlib import Path
from itertools import product
from functools import partial
from math import floor
from uuid import uuid4
from urllib.parse import urlsplit
from concurrent.futures import ThreadPoolExecutor
import numpy as np
import pymongo
from slugify import slugify
from num2words import num2words
from jasper.client import transcribe_gen
from nemo.collections.asr.metrics import word_error_rate
import matplotlib.pyplot as plt
import librosa
import librosa.display
from tqdm import tqdm
def manifest_str(path, dur, text):
    """Return one newline-terminated JSON manifest line; duration is rounded to 0.1."""
    record = {"audio_filepath": path, "duration": round(dur, 1), "text": text}
    return f"{json.dumps(record)}\n"
def wav_bytes(audio_bytes, frame_rate=24000):
    """Wrap raw PCM bytes in a mono, 16-bit WAV container and return the file bytes."""
    buffer = io.BytesIO()
    with wave.open(buffer, mode="w") as wav_out:
        # mono, 2 bytes per sample
        wav_out.setnchannels(1)
        wav_out.setsampwidth(2)
        wav_out.setframerate(frame_rate)
        wav_out.writeframesraw(audio_bytes)
    return buffer.getvalue()
def random_pnr_generator(count=10000):
    """Return ``count`` random codes of three upper-case letters and three digits.

    The letter/digit positions are shuffled once per call, so every code in a
    batch shares the same layout.
    """
    per_kind = 3
    letters = np.array(list("ABCDEFGHIJKLMNOPQRSTUVWXYZ"), dtype="|S1")
    digits = np.array(list("0123456789"), dtype="|S1")
    letter_block = np.random.choice(letters, [count, per_kind])
    digit_block = np.random.choice(digits, [count, per_kind])
    # Rows of the transposed matrix are character positions; shuffling them
    # permutes the positions identically for all codes.
    positions = np.concatenate((letter_block, digit_block), axis=1).T
    np.random.shuffle(positions)
    return [b"".join(row).decode("utf-8") for row in positions.T]
def alnum_to_asr_tokens(text):
    """Space out the characters of ``text``, spell digits as words and lower-case it."""
    spaced = " ".join(list(text))
    spoken = "".join(
        num2words(char) if "0" <= char <= "9" else char for char in spaced
    )
    return spoken.lower()
def tscript_uuid_fname(transcript):
    """Return a file stem made of a fresh uuid4 and a slug (max 8 chars) of the transcript."""
    slug = slugify(transcript, max_length=8)
    return f"{uuid4()}_{slug}"
def asr_data_writer(output_dir, dataset_name, asr_data_source, verbose=False):
    """Write (transcript, duration, wav_bytes) triples as wav files plus a manifest.

    Audio goes to ``output_dir/dataset_name/wav`` and one manifest line per
    data point to ``output_dir/dataset_name/manifest.json``. Returns the number
    of data points written.
    """
    dataset_dir = output_dir / Path(dataset_name)
    wav_dir = dataset_dir / Path("wav")
    wav_dir.mkdir(parents=True, exist_ok=True)
    asr_manifest = dataset_dir / Path("manifest.json")
    written = 0
    with asr_manifest.open("w") as mf:
        print(f"writing manifest to {asr_manifest}")
        for written, datapoint in enumerate(asr_data_source, start=1):
            transcript, audio_dur, wav_data = datapoint
            audio_file = wav_dir / Path(tscript_uuid_fname(transcript)).with_suffix(
                ".wav"
            )
            audio_file.write_bytes(wav_data)
            # The manifest stores the path relative to the dataset directory.
            rel_path = str(audio_file.relative_to(dataset_dir))
            mf.write(manifest_str(rel_path, audio_dur, transcript))
            if verbose:
                print(f"writing '{transcript}' of duration {audio_dur}")
    return written
def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=False):
    """Write a dataset (wavs + manifest.json) together with its ui_dump.json.

    ``asr_data_source`` yields ``(transcript, audio_dur, wav_data, caller_name,
    aud_seg)`` tuples. Audio and manifest lines are written first; afterwards
    each data point is transcribed with the pretrained and GCP transcribers
    and plotted, in parallel, to build the ui dump. Returns the number of data
    points written.
    """
    dataset_dir = output_dir / Path(dataset_name)
    (dataset_dir / Path("wav")).mkdir(parents=True, exist_ok=True)
    ui_dump_file = dataset_dir / Path("ui_dump.json")
    (dataset_dir / Path("wav_plots")).mkdir(parents=True, exist_ok=True)
    asr_manifest = dataset_dir / Path("manifest.json")
    num_datapoints = 0
    ui_dump = {
        "use_domain_asr": False,
        "annotation_only": False,
        "enable_plots": True,
        "data": [],
    }
    data_funcs = []
    transcriber_gcp = gcp_transcribe_gen()
    transcriber_pretrained = transcribe_gen(asr_port=8044)

    def data_fn(
        transcript,
        audio_dur,
        caller_name,
        aud_seg,
        fname,
        audio_path,
        real_idx,
        rel_pnr_path,
    ):
        # Build the ui-dump record for one already-written data point.
        pretrained_result = transcriber_pretrained(aud_seg.raw_data)
        gcp_seg = aud_seg.set_frame_rate(16000)
        gcp_result = transcriber_gcp(gcp_seg.raw_data)
        pretrained_wer = word_error_rate([transcript], [pretrained_result])
        wav_plot_path = (
            dataset_dir / Path("wav_plots") / Path(fname).with_suffix(".png")
        )
        if not wav_plot_path.exists():
            plot_seg(wav_plot_path, audio_path)
        return {
            "audio_filepath": str(rel_pnr_path),
            "duration": round(audio_dur, 1),
            "text": transcript,
            "real_idx": real_idx,
            "audio_path": audio_path,
            "spoken": transcript,
            "caller": caller_name,
            "utterance_id": fname,
            "gcp_asr": gcp_result,
            "pretrained_asr": pretrained_result,
            "pretrained_wer": pretrained_wer,
            "plot_path": str(wav_plot_path),
        }

    with asr_manifest.open("w") as mf:
        print(f"writing manifest to {asr_manifest}")
        for transcript, audio_dur, wav_data, caller_name, aud_seg in asr_data_source:
            # Same naming scheme as asr_data_writer.
            fname = tscript_uuid_fname(transcript)
            audio_file = dataset_dir / Path("wav") / Path(fname).with_suffix(".wav")
            audio_file.write_bytes(wav_data)
            audio_path = str(audio_file)
            rel_pnr_path = audio_file.relative_to(dataset_dir)
            manifest = manifest_str(str(rel_pnr_path), audio_dur, transcript)
            mf.write(manifest)
            # wav_data is not captured: it is already on disk and data_fn never
            # reads it, so the bytes can be released.
            data_funcs.append(
                partial(
                    data_fn,
                    transcript,
                    audio_dur,
                    caller_name,
                    aud_seg,
                    fname,
                    audio_path,
                    num_datapoints,
                    rel_pnr_path,
                )
            )
            num_datapoints += 1
    dump_data = parallel_apply(lambda x: x(), data_funcs)
    ui_dump["data"] = dump_data
    ExtendedPath(ui_dump_file).write_json(ui_dump)
    return num_datapoints
def asr_manifest_reader(data_manifest_path: Path):
    """Yield the entries of a JSON-lines manifest.

    Each entry gains an ``audio_path`` (``audio_filepath`` resolved against
    the manifest's directory) and has its ``text`` stripped of surrounding
    whitespace.
    """
    print(f"reading manifest from {data_manifest_path}")
    with data_manifest_path.open("r") as manifest_file:
        # Parse every line up front, before the first entry is yielded.
        entries = [json.loads(line) for line in manifest_file.readlines()]
    base_dir = data_manifest_path.parent
    for entry in entries:
        entry["audio_path"] = base_dir / Path(entry["audio_filepath"])
        entry["text"] = entry["text"].strip()
        yield entry
def asr_manifest_writer(asr_manifest_path: Path, manifest_str_source):
    """Write one manifest line per dict (audio_filepath, duration, text) in the source."""
    with asr_manifest_path.open("w") as manifest_file:
        print(f"opening {asr_manifest_path} for writing manifest")
        for entry in manifest_str_source:
            line = manifest_str(
                entry["audio_filepath"], entry["duration"], entry["text"]
            )
            manifest_file.write(line)
def asr_test_writer(out_file_path: Path, source):
    """Write a .reg test script playing every entry's audio, plus a .result.json.

    The script gets a PAUSE/PLAY/PAUSE step per entry and ends with DO_HANGUP;
    the entries themselves are dumped to ``out_file_path`` with the suffix
    ``.result.json``.
    """

    def play_step(entry):
        audio = entry["audio_filepath"]
        return f"PAUSE 2\nPLAY {audio}\nPAUSE 60\n\n"

    result_path = out_file_path.with_suffix(".result.json")
    written_entries = []
    with out_file_path.open("w") as script:
        print(f"opening {out_file_path} for writing test")
        for entry in source:
            written_entries.append(entry)
            script.write(play_step(entry))
        script.write("DO_HANGUP\n")
    ExtendedPath(result_path).write_json(written_entries)
def batch(iterable, n=1):
    """Split a sliceable sequence into consecutive chunks of at most ``n`` items."""
    total = len(iterable)
    # Slicing clamps at the end of the sequence, so the last chunk may be shorter.
    return [iterable[start : start + n] for start in range(0, total, n)]
class ExtendedPath(type(Path())):
    """Concrete Path subclass with JSON read/write helpers."""

    def read_json(self):
        """Parse this file as JSON and return the result."""
        print(f"reading json from {self}")
        with self.open("r") as json_file:
            content = json_file.read()
        return json.loads(content)

    def write_json(self, data):
        """Write ``data`` to this file as indented JSON, creating parent directories."""
        print(f"writing json to {self}")
        self.parent.mkdir(parents=True, exist_ok=True)
        with self.open("w") as json_file:
            json.dump(data, json_file, indent=2)
def get_mongo_coll(uri="mongodb://localhost:27017/test.calls"):
    """Return the collection named by the database and collection parts of ``uri``."""
    parsed = pymongo.uri_parser.parse_uri(uri)
    client = pymongo.MongoClient(uri)
    database = client[parsed["database"]]
    return database[parsed["collection"]]
def get_mongo_conn(host="", port=27017, db="test", col="calls"):
    """Return collection ``col`` of database ``db``.

    The host is ``host`` if given, else the MONGO_HOST environment variable,
    else localhost.
    """
    if host:
        mongo_host = host
    else:
        mongo_host = os.environ.get("MONGO_HOST", "localhost")
    client = pymongo.MongoClient(f"mongodb://{mongo_host}:{port}/")
    return client[db][col]
def strip_silence(sound):
    """Return ``sound`` with its leading and trailing silence cut off."""
    from pydub.silence import detect_leading_silence

    lead_ms = detect_leading_silence(sound)
    # Trailing silence is the leading silence of the reversed segment.
    tail_ms = detect_leading_silence(sound.reverse())
    return sound[lead_ms : len(sound) - tail_ms]
def plot_seg(wav_plot_path, audio_path):
    """Render the waveform of ``audio_path`` as a PNG at ``wav_plot_path``."""
    figure = plt.Figure()
    axes = figure.add_subplot()
    samples, sample_rate = librosa.load(audio_path)
    librosa.display.waveplot(y=samples, sr=sample_rate, ax=axes)
    with wav_plot_path.open("wb") as png_file:
        figure.set_tight_layout(True)
        figure.savefig(png_file, format="png", dpi=50)
def generate_dates():
    """Return spoken variants of every day (1-31) / month combination.

    Each combination gives six strings, e.g. "1st January", "January 1st",
    "1st of January", "January the 1st", "1 January" and "January 1". Days are
    not checked against month lengths.
    """
    month_names = [
        "January",
        "February",
        "March",
        "April",
        "May",
        "June",
        "July",
        "August",
        "September",
        "October",
        "November",
        "December",
    ]

    def ordinal(n):
        # 11, 12 and 13 (tens digit 1) take "th"; otherwise the last digit decides.
        if n // 10 % 10 == 1:
            suffix = "th"
        else:
            suffix = {1: "st", 2: "nd", 3: "rd"}.get(n % 10, "th")
        return "%d%s" % (n, suffix)

    def canon_vars(d, m):
        nth = ordinal(d)
        plain = str(d)
        return [
            nth + " " + m,
            m + " " + nth,
            nth + " of " + m,
            m + " the " + nth,
            plain + " " + m,
            m + " " + plain,
        ]

    all_dates = []
    for day, month in product(range(1, 32), month_names):
        all_dates.extend(canon_vars(day, month))
    return all_dates
def get_call_logs(call_obj, s3, call_meta_dir):
    """Return the parsed call-metadata JSON for ``call_obj``, downloading it if needed.

    The metadata is located by ``call_obj["DataURI"]``: the URI host is used
    as the bucket and the URI path (without its leading slash) as the key. It
    is cached in ``call_meta_dir`` under the file name of the URI path, and
    ``s3.download_file`` is only called when that cached file does not exist.
    """
    meta_s3_uri = call_obj["DataURI"]
    s3_event_url_p = urlsplit(meta_s3_uri)
    saved_meta_path = call_meta_dir / Path(Path(s3_event_url_p.path).name)
    if not saved_meta_path.exists():
        print(f"downloading : {saved_meta_path} from {meta_s3_uri}")
        # The download needs an existing target directory.
        saved_meta_path.parent.mkdir(parents=True, exist_ok=True)
        s3.download_file(
            s3_event_url_p.netloc, s3_event_url_p.path[1:], str(saved_meta_path)
        )
    # Context manager so the file handle is closed deterministically.
    with saved_meta_path.open() as meta_file:
        return json.load(meta_file)
def gcp_transcribe_gen():
    """Build a Google Cloud Speech transcriber function.

    The returned callable takes raw audio bytes (configured as LINEAR16 at
    16 kHz) and returns the alternatives of the first recognition result
    joined with "/", or "" when there are no results.
    """
    from google.cloud import speech_v1
    from google.cloud.speech_v1 import enums

    client = speech_v1.SpeechClient()

    # Phrase hints, grouped by purpose; passed as speech_contexts in this order.
    class_token_phrases = [
        "$OOV_CLASS_ALPHANUMERIC_SEQUENCE",
        "$OOV_CLASS_DIGIT_SEQUENCE",
        "$TIME",
        "$YEAR",
    ]
    letter_phrases = list("ABCDEFGHIJKLMNOPQRSTUVWXYZ")
    pnr_phrases = [
        "PNR is $OOV_CLASS_ALPHANUMERIC_SEQUENCE",
        "my PNR is $OOV_CLASS_ALPHANUMERIC_SEQUENCE",
        "my PNR number is $OOV_CLASS_ALPHANUMERIC_SEQUENCE",
        "PNR number is $OOV_CLASS_ALPHANUMERIC_SEQUENCE",
        "It's $OOV_CLASS_ALPHANUMERIC_SEQUENCE",
        "$OOV_CLASS_ALPHANUMERIC_SEQUENCE is my PNR",
    ]
    name_phrases = [
        "John Smith",
        "Carina Hu",
        "Travis Lim",
        "Marvin Tan",
        "Samuel Tan",
        "Dawn Mathew",
        "Dawn",
        "Mathew",
    ]
    sample_entity_phrases = [
        "Beijing",
        "Tokyo",
        "London",
        "19 August",
        "7 October",
        "11 December",
        "17 September",
        "19th August",
        "7th October",
        "11th December",
        "17th September",
        "ABC123",
        "KWXUNP",
        "XLU5K1",
        "WL2JV6",
        "KBS651",
    ]
    selection_phrases = [
        "first flight",
        "second flight",
        "third flight",
        "first option",
        "second option",
        "third option",
        "first one",
        "second one",
        "third one",
    ]
    speech_contexts = [
        {"phrases": class_token_phrases},
        {"phrases": letter_phrases},
        {"phrases": pnr_phrases},
        {"phrases": ["my name is"]},
        {"phrases": ["Number $ORDINAL", "Numeral $ORDINAL"]},
        {"phrases": name_phrases},
        {"phrases": sample_entity_phrases},
        {"phrases": selection_phrases},
    ]

    config = {
        # The language of the supplied audio
        "language_code": "en-US",
        # Sample rate in Hertz of the audio data sent
        "sample_rate_hertz": 16000,
        # Encoding of audio data sent, set explicitly
        "encoding": enums.RecognitionConfig.AudioEncoding.LINEAR16,
        "model": "phone_call",
        "enable_automatic_punctuation": True,
        "max_alternatives": 10,
        # used to detect start and end time of utterances
        "enable_word_time_offsets": True,
        "speech_contexts": speech_contexts,
        "metadata": {
            "industry_naics_code_of_audio": 481111,
            "interaction_type": enums.RecognitionMetadata.InteractionType.PHONE_CALL,
        },
    }

    def sample_recognize(content):
        """
        Transcribe short audio using synchronous speech recognition
        Args:
          content Raw audio bytes matching the encoding and sample rate in config
        """
        response = client.recognize(config, {"content": content})
        # Only the first result is used; all of its alternatives are joined.
        first_result = next(iter(response.results), None)
        if first_result is None:
            return ""
        return "/".join([alt.transcript for alt in first_result.alternatives])

    return sample_recognize
def parallel_apply(fn, iterable, workers=8):
    """Apply ``fn`` to every item on a thread pool and return the results in order.

    ``iterable`` must support ``len()`` because it sizes the progress bar.
    """
    with ThreadPoolExecutor(max_workers=workers) as pool:
        print(f"parallelly applying {fn}")
        progress = tqdm(
            pool.map(fn, iterable), position=0, leave=True, total=len(iterable)
        )
        return list(progress)
def main():
    """Print a default-sized batch of random PNR codes, one per line."""
    for code in random_pnr_generator():
        print(code)


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1,425 @@
import json
import shutil
from pathlib import Path
from enum import Enum
import typer
from tqdm import tqdm
from ..utils import (
alnum_to_asr_tokens,
ExtendedPath,
asr_manifest_reader,
asr_manifest_writer,
tscript_uuid_fname,
get_mongo_conn,
plot_seg,
)
app = typer.Typer()
def preprocess_datapoint(
    idx, rel_root, sample, use_domain_asr, annotation_only, enable_plots
):
    """Build the validation-UI record for one manifest entry.

    Arguments:
        idx -- position of the entry in the manifest, stored as ``real_idx``
        rel_root -- directory that ``sample["audio_filepath"]`` is relative to
        sample -- manifest entry (dict with ``audio_filepath`` and ``text``)
        use_domain_asr -- also spell out ``text`` as ASR tokens and transcribe
            with the domain ASR (port 8045)
        annotation_only -- skip all ASR transcription and WER computation
        enable_plots -- create the waveform plot if it does not exist yet

    Returns:
        The enriched copy of ``sample``, or ``None`` when processing failed
        (the failure is printed).
    """
    from pydub import AudioSegment
    from nemo.collections.asr.metrics import word_error_rate
    from jasper.client import transcribe_gen

    try:
        res = dict(sample)
        res["real_idx"] = idx
        audio_path = rel_root / Path(sample["audio_filepath"])
        res["audio_path"] = str(audio_path)
        if use_domain_asr:
            res["spoken"] = alnum_to_asr_tokens(res["text"])
        else:
            res["spoken"] = res["text"]
        res["utterance_id"] = audio_path.stem
        if not annotation_only:
            transcriber_pretrained = transcribe_gen(asr_port=8044)

            aud_seg = (
                AudioSegment.from_file_using_temporary_files(audio_path)
                .set_channels(1)
                .set_sample_width(2)
                .set_frame_rate(24000)
            )
            res["pretrained_asr"] = transcriber_pretrained(aud_seg.raw_data)
            res["pretrained_wer"] = word_error_rate(
                [res["text"]], [res["pretrained_asr"]]
            )
            if use_domain_asr:
                transcriber_speller = transcribe_gen(asr_port=8045)
                res["domain_asr"] = transcriber_speller(aud_seg.raw_data)
                # Score the domain transcript itself, not the pretrained one.
                res["domain_wer"] = word_error_rate(
                    [res["spoken"]], [res["domain_asr"]]
                )
        if enable_plots:
            wav_plot_path = (
                rel_root / Path("wav_plots") / Path(audio_path.name).with_suffix(".png")
            )
            if not wav_plot_path.exists():
                plot_seg(wav_plot_path, audio_path)
            res["plot_path"] = str(wav_plot_path)
        return res
    # Exception (not BaseException) so interrupts and SystemExit still propagate.
    except Exception as e:
        print(f'failed on {idx}: {sample["audio_filepath"]} with {e}')
        return None
@app.command()
def dump_ui(
    data_name: str = typer.Option("call_alphanum", show_default=True),
    dataset_dir: Path = Path("./data/asr_data"),
    dump_dir: Path = Path("./data/valiation_data"),
    dump_fname: Path = typer.Option(Path("ui_dump.json"), show_default=True),
    use_domain_asr: bool = False,
    annotation_only: bool = False,
    enable_plots: bool = True,
):
    """Preprocess a dataset manifest into the JSON dump used by the validation UI."""
    from concurrent.futures import ThreadPoolExecutor
    from functools import partial

    data_manifest_path = dataset_dir / Path(data_name) / Path("manifest.json")
    dump_path: Path = dump_dir / Path(data_name) / dump_fname
    plot_dir = data_manifest_path.parent / Path("wav_plots")
    plot_dir.mkdir(parents=True, exist_ok=True)
    typer.echo(f"Using data manifest:{data_manifest_path}")
    with data_manifest_path.open("r") as manifest_file:
        manifest_lines = manifest_file.readlines()

    # One deferred preprocess call per manifest line.
    tasks = [
        partial(
            preprocess_datapoint,
            line_no,
            data_manifest_path.parent,
            json.loads(line),
            use_domain_asr,
            annotation_only,
            enable_plots,
        )
        for line_no, line in enumerate(manifest_lines)
    ]

    with ThreadPoolExecutor() as exe:
        print("starting all preprocess tasks")
        outcomes = list(
            tqdm(
                exe.map(lambda task: task(), tasks),
                position=0,
                leave=True,
                total=len(tasks),
            )
        )
    # Failed datapoints come back falsy and are dropped.
    datapoints = [outcome for outcome in outcomes if outcome]

    if annotation_only:
        result = datapoints
    else:
        # Highest word error rate first.
        wer_key = "domain_wer" if use_domain_asr else "pretrained_wer"
        result = sorted(datapoints, key=lambda x: x[wer_key], reverse=True)
    ExtendedPath(dump_path).write_json(
        {
            "use_domain_asr": use_domain_asr,
            "annotation_only": annotation_only,
            "enable_plots": enable_plots,
            "data": result,
        }
    )
@app.command()
def sample_ui(
    data_name: str = typer.Option("call_upwork_train_cnd", show_default=True),
    dump_dir: Path = Path("./data/asr_data"),
    dump_file: Path = Path("ui_dump.json"),
    sample_count: int = typer.Option(80, show_default=True),
    sample_file: Path = Path("sample_dump.json"),
):
    """Write a random, per-caller balanced sample of a UI dump.

    The dump at ``<dump_dir>/<data_name>/<dump_file>`` is grouped by its
    ``caller`` field and up to ``sample_count // number_of_callers`` entries
    are drawn from every caller. The sampled entries get a fresh ``real_idx``
    and are written, with the rest of the dump untouched, to
    ``<dump_dir>/<data_name>/<sample_file>``.
    """
    import pandas as pd

    data_dir = dump_dir / Path(data_name)
    processed_data = ExtendedPath(data_dir / dump_file).read_json()
    df = pd.DataFrame(processed_data["data"])
    samples_per_caller = sample_count // len(df["caller"].unique())
    # A caller with fewer utterances than the quota contributes all it has;
    # DataFrame.sample raises when asked for more rows than the group holds.
    caller_samples = pd.concat(
        [
            group.sample(min(samples_per_caller, len(group)))
            for (_, group) in df.groupby("caller")
        ]
    )
    caller_samples = caller_samples.reset_index(drop=True)
    caller_samples["real_idx"] = caller_samples.index
    sample_data = caller_samples.to_dict("records")
    processed_data["data"] = sample_data
    # Report what is actually written, which can be less than sample_count.
    typer.echo(f"sampling {len(sample_data)} datapoints")
    ExtendedPath(data_dir / sample_file).write_json(processed_data)
@app.command()
def task_ui(
    data_name: str = typer.Option("call_upwork_train_cnd", show_default=True),
    dump_dir: Path = Path("./data/asr_data"),
    dump_file: Path = Path("ui_dump.json"),
    task_count: int = typer.Option(4, show_default=True),
    task_file: str = "task_dump",
):
    """Shuffle a UI dump and split it into ``task_count`` task files.

    Each part is written to ``<dump_dir>/<data_name>/<task_file>-<n>.json``
    with the same top-level fields as the source dump and a ``real_idx``
    that restarts at zero within the part.
    """
    import pandas as pd
    import numpy as np

    data_dir = dump_dir / Path(data_name)
    processed_data = ExtendedPath(data_dir / dump_file).read_json()
    shuffled = pd.DataFrame(processed_data["data"]).sample(frac=1)
    shuffled = shuffled.reset_index(drop=True)
    parts = np.array_split(shuffled, task_count)
    for part_no, part in enumerate(parts):
        part = part.reset_index(drop=True)
        part["real_idx"] = part.index
        processed_data["data"] = part.to_dict("records")
        part_path = data_dir / Path(f"{task_file}-{part_no}.json")
        ExtendedPath(part_path).write_json(processed_data)
@app.command()
def dump_corrections(
    task_uid: str,
    data_name: str = typer.Option("call_alphanum", show_default=True),
    dump_dir: Path = Path("./data/asr_data"),
    dump_fname: Path = Path("corrections.json"),
):
    """Export the corrections of one validation task from mongo to JSON.

    ``task_uid`` is the part of a stored ``task_id`` after its last ``-``.
    The matching corrections (without their mongo ``_id``) are written to
    ``<dump_dir>/<data_name>/<dump_fname>``.

    Raises:
        typer.Exit: with code 1 when no stored task id ends in ``task_uid``.
    """
    dump_path = dump_dir / Path(data_name) / dump_fname
    col = get_mongo_conn(col="asr_validation")

    def has_uid(candidate):
        # Same test as rsplit("-", 1)[1] == task_uid, but ids without a dash
        # simply do not match instead of raising.
        _, sep, tail = str(candidate).rpartition("-")
        return bool(sep) and tail == task_uid

    task_id = next((t for t in col.distinct("task_id") if has_uid(t)), None)
    if task_id is None:
        typer.echo(f"no task found for uid {task_uid}", err=True)
        raise typer.Exit(code=1)

    # Only the selected task is fetched; the collection is not scanned twice.
    cursor_obj = col.find(
        {"type": "correction", "task_id": task_id}, projection={"_id": False}
    )
    ExtendedPath(dump_path).write_json(list(cursor_obj))
@app.command()
def caller_quality(
    task_uid: str,
    data_name: str = typer.Option("call_upwork_train_cnd", show_default=True),
    dump_dir: Path = Path("./data/asr_data"),
    dump_fname: Path = Path("ui_dump.json"),
    correction_fname: Path = Path("corrections.json"),
):
    """Print, per caller, the share of utterances marked "Correct".

    Corrections whose ``task_id`` ends in ``-<task_uid>`` are joined with the
    UI dump through ``code`` / ``utterance_id``; a correction counts as valid
    when its status is ``"Correct"``.
    """
    import copy
    import pandas as pd

    data_dir = dump_dir / Path(data_name)
    dump_data = ExtendedPath(data_dir / dump_fname).read_json()
    correction_data = ExtendedPath(data_dir / correction_fname).read_json()
    dump_map = {d["utterance_id"]: d for d in dump_data["data"]}

    def belongs_to_task(correction):
        # Corrections can be stored without a task_id (fill_unannotated sets
        # only "value"), so a missing field means "not this task".
        _, sep, tail = str(correction.get("task_id", "")).rpartition("-")
        return bool(sep) and tail == task_uid

    def correction_dp(c):
        dp = copy.deepcopy(dump_map[c["code"]])
        dp["valid"] = c["value"]["status"] == "Correct"
        return dp

    corrected_dump = [
        correction_dp(c) for c in correction_data if belongs_to_task(c)
    ]
    df = pd.DataFrame(corrected_dump)
    print(f"Total samples: {len(df)}")
    if df.empty:
        # An empty frame has no "caller" column to group by.
        return
    for (caller, group) in df.groupby("caller"):
        total = len(group)
        valid = int(group["valid"].sum())
        valid_rate = valid * 100 / total
        print(f"Caller: {caller} Valid%:{valid_rate:.2f} of {total} samples")
@app.command()
def fill_unannotated(
    data_name: str = typer.Option("call_alphanum", show_default=True),
    dump_dir: Path = Path("./data/valiation_data"),
    dump_file: Path = Path("ui_dump.json"),
    corrections_file: Path = Path("corrections.json"),
):
    """Mark every data point without a correction as "Inaudible" in mongo.

    Codes present in the UI dump but absent from the corrections file are
    upserted into the ``asr_validation`` collection with status
    ``"Inaudible"`` and an empty correction.
    """
    processed_data_path = dump_dir / Path(data_name) / dump_file
    corrections_path = dump_dir / Path(data_name) / corrections_file
    with processed_data_path.open() as dump_f:
        processed_data = json.load(dump_f)
    with corrections_path.open() as corrections_f:
        corrections = json.load(corrections_f)

    # dump_ui writes a dict with the data points under "data"; a bare list of
    # data points is still accepted.
    if isinstance(processed_data, dict):
        entries = processed_data["data"]
    else:
        entries = processed_data
    # NOTE(review): corrections are keyed by "utterance_id" in the validation
    # UI; "gold_chars" is kept as a fallback for dumps that carry only that
    # key - confirm which dumps are still in use.
    all_codes = {
        e["utterance_id"] if "utterance_id" in e else e["gold_chars"]
        for e in entries
    }
    annotated_codes = {c["code"] for c in corrections}
    unann_codes = all_codes - annotated_codes

    mongo_conn = get_mongo_conn(col="asr_validation")
    for code in unann_codes:
        mongo_conn.find_one_and_update(
            {"type": "correction", "code": code},
            {"$set": {"value": {"status": "Inaudible", "correction": ""}}},
            upsert=True,
        )
class ExtractionType(str, Enum):
    """Kinds of data ``split_extract`` can pull out of a dataset."""

    # Each value is used by split_extract as a key into the conversation
    # data file; "all" makes it iterate every key of that file instead.
    date = "dates"
    city = "cities"
    name = "names"
    all = "all"
@app.command()
def split_extract(
    data_name: str = typer.Option("call_alphanum", show_default=True),
    dump_dir: Path = Path("./data/asr_data"),
    dump_file: Path = Path("ui_dump.json"),
    manifest_file: Path = Path("manifest.json"),
    corrections_file: str = typer.Option("corrections.json", show_default=True),
    conv_data_path: Path = typer.Option(Path("./data/conv_data.json"), show_default=True),
    extraction_type: ExtractionType = ExtractionType.all,
):
    """Split a dataset into per-type datasets (dates, cities, names, ...).

    For every extraction key, the entries whose ``text`` is listed under that
    key in ``conv_data_path`` are copied into a sibling dataset named
    ``<data_name>_<key lowercased>``: audio files, manifest, UI dump and,
    when a corrections file exists, the matching corrections. With
    ``ExtractionType.all`` every key of the conversation data is extracted.
    """
    import shutil

    source_dir = dump_dir / Path(data_name)
    data_manifest_path = source_dir / manifest_file
    conv_data = ExtendedPath(conv_data_path).read_json()

    def extract_data_of_type(extraction_key):
        extraction_vals = conv_data[extraction_key]
        dest_data_name = data_name + "_" + extraction_key.lower()
        dest_data_dir = dump_dir / Path(dest_data_name)
        dest_data_dir.mkdir(exist_ok=True, parents=True)
        (dest_data_dir / Path("wav")).mkdir(exist_ok=True, parents=True)

        def extract_manifest(manifest_entries):
            # Copy the audio as a side effect of streaming matching entries.
            for entry in manifest_entries:
                if entry["text"] in extraction_vals:
                    shutil.copy(
                        entry["audio_path"],
                        dest_data_dir / Path(entry["audio_filepath"]),
                    )
                    yield entry

        asr_manifest_writer(
            dest_data_dir / manifest_file,
            extract_manifest(asr_manifest_reader(data_manifest_path)),
        )

        orig_ui_data = ExtendedPath(source_dir / dump_file).read_json()
        ui_data = orig_ui_data["data"]
        file_ui_map = {Path(u["audio_filepath"]).stem: u for u in ui_data}
        final_data = []
        for new_idx, ui_entry in enumerate(
            u for u in ui_data if u["text"] in extraction_vals
        ):
            ui_entry["real_idx"] = new_idx
            final_data.append(ui_entry)
        orig_ui_data["data"] = final_data
        ExtendedPath(dest_data_dir / dump_file).write_json(orig_ui_data)

        if not corrections_file:
            return
        corrections_path = source_dir / corrections_file
        if not corrections_path.exists():
            # A dataset that was never annotated has nothing to carry over.
            typer.echo(f"no corrections found at {corrections_path}, skipping")
            return
        with corrections_path.open() as corrections_f:
            corrections = json.load(corrections_f)
        extracted_corrections = [
            c
            for c in corrections
            if c["code"] in file_ui_map
            and file_ui_map[c["code"]]["text"] in extraction_vals
        ]
        ExtendedPath(dest_data_dir / corrections_file).write_json(
            extracted_corrections
        )

    if extraction_type.value == "all":
        for ext_key in conv_data.keys():
            extract_data_of_type(ext_key)
    else:
        extract_data_of_type(extraction_type.value)
@app.command()
def update_corrections(
    data_name: str = typer.Option("call_alphanum", show_default=True),
    dump_dir: Path = Path("./data/asr_data"),
    manifest_file: Path = Path("manifest.json"),
    corrections_file: Path = Path("corrections.json"),
    ui_dump_file: Path = Path("ui_dump.json"),
    skip_incorrect: bool = typer.Option(True, show_default=True),
):
    """Rewrite a dataset manifest according to the reviewers' corrections.

    The dataset directory is first copied to ``<name>.bkp`` (unless that
    backup already exists). Then, for every entry of the UI dump:

    * status "Correct": the entry is kept as is;
    * status "Incorrect": the entry is dropped from the manifest when
      ``skip_incorrect`` is set, otherwise its audio file is renamed after
      the corrected text and the entry is kept with that text;
    * anything else (including entries without a correction): the audio
      file is deleted and the entry is dropped.

    The result replaces the original manifest.
    """
    # Imported here because this function cannot rely on a module-level
    # import being present (split_extract imports shutil locally as well).
    import shutil

    data_manifest_path = dump_dir / Path(data_name) / manifest_file
    corrections_path = dump_dir / Path(data_name) / corrections_file
    ui_dump_path = dump_dir / Path(data_name) / ui_dump_file

    def correct_manifest(ui_dump_path, corrections_path):
        corrections = ExtendedPath(corrections_path).read_json()
        ui_data = ExtendedPath(ui_dump_path).read_json()["data"]
        correct_set = {
            c["code"] for c in corrections if c["value"]["status"] == "Correct"
        }
        correction_map = {
            c["code"]: c["value"]["correction"]
            for c in corrections
            if c["value"]["status"] == "Incorrect"
        }
        for d in ui_data:
            utterance_id = d["utterance_id"]
            if utterance_id in correct_set:
                yield {
                    "audio_filepath": d["audio_filepath"],
                    "duration": d["duration"],
                    "text": d["text"],
                }
            elif utterance_id in correction_map:
                correct_text = correction_map[utterance_id]
                if skip_incorrect:
                    print(
                        f'skipping incorrect {d["audio_path"]} corrected to {correct_text}'
                    )
                else:
                    # Rename the audio after the corrected transcript.
                    orig_audio_path = Path(d["audio_path"])
                    new_name = str(
                        Path(tscript_uuid_fname(correct_text)).with_suffix(".wav")
                    )
                    orig_audio_path.replace(orig_audio_path.with_name(new_name))
                    new_filepath = str(
                        Path(d["audio_filepath"]).with_name(new_name)
                    )
                    yield {
                        "audio_filepath": new_filepath,
                        "duration": d["duration"],
                        "text": correct_text,
                    }
            else:
                # Neither confirmed nor corrected: remove the audio.
                Path(d["audio_path"]).unlink()

    typer.echo(f"Using data manifest:{data_manifest_path}")
    dataset_dir = data_manifest_path.parent
    backup_dir = dataset_dir.with_name(dataset_dir.name + ".bkp")
    if not backup_dir.exists():
        typer.echo(f"backing up to :{backup_dir}")
        shutil.copytree(str(dataset_dir), str(backup_dir))

    # Write to a temporary name first, then swap it in over the old manifest.
    corrected_manifest = correct_manifest(ui_dump_path, corrections_path)
    new_data_manifest_path = data_manifest_path.with_name("manifest.new")
    asr_manifest_writer(new_data_manifest_path, corrected_manifest)
    new_data_manifest_path.replace(data_manifest_path)
@app.command()
def clear_mongo_corrections():
    """Delete all corrections and cursors from mongo after confirmation."""
    confirmed = typer.confirm("are you sure you want to clear mongo collection it?")
    if not confirmed:
        typer.echo("Aborted")
        return
    col = get_mongo_conn(col="asr_validation")
    for record_type in ("correction", "current_cursor"):
        col.delete_many({"type": record_type})
    typer.echo("deleted mongo collection.")
def main():
    """Run the typer command-line application."""
    app()


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,38 @@
import streamlit.ReportThread as ReportThread
from streamlit.ScriptRequestQueue import RerunData
from streamlit.ScriptRunner import RerunException
from streamlit.server.Server import Server
def rerun():
    """Rerun a Streamlit app from the top!

    Never returns: the rerun is requested by raising ``RerunException``
    carrying the current widget states.
    """
    raise RerunException(RerunData(_get_widget_states()))
def _get_widget_states():
    """Return the widget states of the Streamlit session running this script.

    Hack: the session is found by matching the ``enqueue`` callable of the
    current report context against every session known to the server.

    Raises:
        RuntimeError: when no session matches the current report context.
    """
    ctx = ReportThread.get_report_ctx()
    current_server = Server.get_current()
    if hasattr(current_server, "_session_infos"):
        # Streamlit < 0.56
        session_infos = current_server._session_infos.values()
    else:
        session_infos = current_server._session_info_by_id.values()

    session = None
    for session_info in session_infos:
        if session_info.session.enqueue == ctx.enqueue:
            session = session_info.session

    if session is None:
        # The two sentences used to run together ("...objectAre you...").
        raise RuntimeError(
            "Oh noes. Couldn't get your Streamlit Session object. "
            "Are you doing something fancy with threads?"
        )
    # Got the session object!
    return session._widget_states

View File

@@ -0,0 +1,158 @@
from pathlib import Path
import streamlit as st
import typer
from uuid import uuid4
from ..utils import ExtendedPath, get_mongo_conn
from .st_rerun import rerun
app = typer.Typer()

# One-time setup, guarded by a marker attribute stored on the `st` module:
# connect to mongo, create a task id and attach the persistence helpers to
# `st` so that main() can reach them as st.<helper>.
if not hasattr(st, "mongo_connected"):
    st.mongoclient = get_mongo_conn(col="asr_validation")
    mongo_conn = st.mongoclient
    st.task_id = str(uuid4())

    def current_cursor_fn():
        """Return the stored cursor value for the current task."""
        # mongo_conn = st.mongoclient
        cursor_obj = mongo_conn.find_one(
            {"type": "current_cursor", "task_id": st.task_id}
        )
        cursor_val = cursor_obj["cursor"]
        return cursor_val

    def update_cursor_fn(val=0):
        """Store ``val`` as the task's cursor (upsert), then rerun the app."""
        mongo_conn.find_one_and_update(
            {"type": "current_cursor", "task_id": st.task_id},
            {"$set": {"type": "current_cursor", "task_id": st.task_id, "cursor": val}},
            upsert=True,
        )
        rerun()

    def get_correction_entry_fn(code):
        """Return the correction stored for ``code`` (without ``_id``)."""
        return mongo_conn.find_one(
            {"type": "correction", "code": code}, projection={"_id": False}
        )

    def update_entry_fn(code, value):
        """Upsert the correction for ``code`` and tag it with the task id."""
        mongo_conn.find_one_and_update(
            {"type": "correction", "code": code},
            {"$set": {"value": value, "task_id": st.task_id}},
            upsert=True,
        )

    def set_task_fn(mf_path):
        """Create ``task-<task_id>.lck`` next to ``mf_path`` if it is missing."""
        task_path = mf_path.parent / Path(f"task-{st.task_id}.lck")
        if not task_path.exists():
            print(f"creating task lock at {task_path}")
            task_path.touch()

    st.get_current_cursor = current_cursor_fn
    st.update_cursor = update_cursor_fn
    st.get_correction_entry = get_correction_entry_fn
    st.update_entry = update_entry_fn
    st.set_task = set_task_fn
    st.mongo_connected = True
    # Make sure a cursor record exists for this task before main() reads it.
    cursor_obj = mongo_conn.find_one({"type": "current_cursor", "task_id": st.task_id})
    if not cursor_obj:
        update_cursor_fn(0)
@st.cache()
def load_ui_data(validation_ui_data_path: Path):
    """Load the validation-UI JSON dump from ``validation_ui_data_path``."""
    typer.echo(f"Using validation ui data from {validation_ui_data_path}")
    ui_data = ExtendedPath(validation_ui_data_path).read_json()
    return ui_data
@app.command()
def main(manifest: Path):
    """Render the validation / annotation page for the current sample.

    ``manifest`` is the UI dump to review. The sample under the task's
    stored cursor is shown with its audio and, unless the dump is
    annotation-only, the ASR results. The reviewer's verdict is stored
    through ``st.update_entry`` and navigation goes through
    ``st.update_cursor``.
    """
    st.set_task(manifest)
    ui_config = load_ui_data(manifest)
    asr_data = ui_config["data"]
    use_domain_asr = ui_config.get("use_domain_asr", True)
    annotation_only = ui_config.get("annotation_only", False)
    enable_plots = ui_config.get("enable_plots", True)

    sample_no = st.get_current_cursor()
    if len(asr_data) - 1 < sample_no or sample_no < 0:
        print("Invalid samplno resetting to 0")
        st.update_cursor(0)
    sample = asr_data[sample_no]

    title_type = "Speller " if use_domain_asr else ""
    task_uid = st.task_id.rsplit("-", 1)[1]
    if annotation_only:
        st.title(f"ASR Annotation - # {task_uid}")
    else:
        st.title(f"ASR {title_type}Validation - # {task_uid}")
    addl_text = f"spelled *{sample['spoken']}*" if use_domain_asr else ""
    st.markdown(f"{sample_no+1} of {len(asr_data)} : **{sample['text']}**" + addl_text)

    new_sample = st.number_input(
        "Go To Sample:", value=sample_no + 1, min_value=1, max_value=len(asr_data)
    )
    if new_sample != sample_no + 1:
        st.update_cursor(new_sample - 1)

    st.sidebar.title(f"Details: [{sample['real_idx']}]")
    st.sidebar.markdown(f"Gold Text: **{sample['text']}**")
    if not annotation_only:
        if use_domain_asr:
            st.sidebar.markdown(f"Expected Spelled: *{sample['spoken']}*")
        st.sidebar.title("Results:")
        st.sidebar.markdown(f"Pretrained: **{sample['pretrained_asr']}**")
        if "caller" in sample:
            st.sidebar.markdown(f"Caller: **{sample['caller']}**")
        if use_domain_asr:
            st.sidebar.markdown(f"Domain: **{sample['domain_asr']}**")
            st.sidebar.title(f"Speller WER: {sample['domain_wer']:.2f}%")
        else:
            st.sidebar.title(f"Pretrained WER: {sample['pretrained_wer']:.2f}%")
    if enable_plots:
        st.sidebar.image(Path(sample["plot_path"]).read_bytes())
    # Hand over the bytes rather than an open file that is never closed.
    st.audio(Path(sample["audio_path"]).read_bytes())

    # set default to text
    corrected = sample["text"]
    correction_entry = st.get_correction_entry(sample["utterance_id"])
    selected_idx = 0
    options = ("Correct", "Incorrect", "Inaudible")
    # if correction entry is present set the corresponding ui defaults
    if correction_entry:
        selected_idx = options.index(correction_entry["value"]["status"])
        corrected = correction_entry["value"]["correction"]
    selected = st.radio("The Audio is", options, index=selected_idx)
    if selected == "Incorrect":
        corrected = st.text_input("Actual:", value=corrected)
    if selected == "Inaudible":
        corrected = ""
    if st.button("Submit"):
        st.update_entry(
            sample["utterance_id"], {"status": selected, "correction": corrected}
        )
        st.update_cursor(sample_no + 1)
    if correction_entry:
        st.markdown(
            f'Your Response: **{correction_entry["value"]["status"]}** Correction: **{correction_entry["value"]["correction"]}**'
        )

    text_sample = st.text_input("Go to Text:", value="")
    if text_sample != "":
        # "spoken" is read with .get so dumps without that key can still be
        # searched by text.
        match_idx = next(
            (
                i
                for (i, p) in enumerate(asr_data)
                if p["text"] == text_sample or p.get("spoken") == text_sample
            ),
            None,
        )
        if match_idx is not None:
            st.update_cursor(match_idx)

    real_idx = st.number_input(
        "Go to real-index",
        value=sample["real_idx"],
        min_value=0,
        max_value=len(asr_data) - 1,
    )
    if real_idx != int(sample["real_idx"]):
        # Stay on the current sample when no entry carries that real index.
        target_idx = next(
            (i for (i, p) in enumerate(asr_data) if p["real_idx"] == real_idx),
            None,
        )
        if target_idx is not None:
            st.update_cursor(target_idx)
if __name__ == "__main__":
    # Run the typer app and swallow the SystemExit it finishes with.
    # NOTE(review): presumably this keeps the Streamlit script run from being
    # reported as failed when the command exits - confirm.
    try:
        app()
    except SystemExit:
        pass

359
jasper/evaluate.py Normal file
View File

@@ -0,0 +1,359 @@
# Copyright (c) 2019 NVIDIA Corporation
import argparse
import copy
# import math
import os
from pathlib import Path
from functools import partial
from ruamel.yaml import YAML
import nemo
import nemo.collections.asr as nemo_asr
import nemo.utils.argparse as nm_argparse
from nemo.collections.asr.helpers import (
# monitor_asr_train_progress,
process_evaluation_batch,
process_evaluation_epoch,
)
# from nemo.utils.lr_policies import CosineAnnealing
from training.data_loaders import RpycAudioToTextDataLayer
logging = nemo.logging
def parse_args():
    """Parse the command line of the evaluation script.

    Builds on NeMo's stock argument parser, installs project defaults and
    registers the script's own options; options that already exist on the
    parent parser are replaced (``conflict_handler="resolve"``).

    Raises:
        ValueError: when neither ``max_steps`` nor ``num_epochs`` is set.
    """
    parser = argparse.ArgumentParser(
        parents=[nm_argparse.NemoArgParser()],
        description="Jasper",
        conflict_handler="resolve",
    )
    project_defaults = {
        "checkpoint_dir": None,
        "optimizer": "novograd",
        "batch_size": 64,
        "eval_batch_size": 64,
        "lr": 0.002,
        "amp_opt_level": "O1",
        "create_tb_writer": True,
        "model_config": "./train/jasper10x5dr.yaml",
        "work_dir": "./train/work",
        "num_epochs": 300,
        "weight_decay": 0.005,
        "checkpoint_save_freq": 100,
        "eval_freq": 100,
        "load_dir": "./train/models/jasper/",
        "warmup_steps": 3,
        "exp_name": "jasper-speller",
    }
    parser.set_defaults(**project_defaults)

    # Registered in this exact order: first the overrides of parent options,
    # then the options specific to this script.
    option_specs = (
        (
            "--max_steps",
            dict(
                type=int,
                default=None,
                required=False,
                help="max number of steps to train",
            ),
        ),
        (
            "--num_epochs",
            dict(type=int, required=False, help="number of epochs to train"),
        ),
        (
            "--model_config",
            dict(
                type=str,
                required=False,
                help="model configuration file: model.yaml",
            ),
        ),
        (
            "--encoder_checkpoint",
            dict(
                type=str,
                required=True,
                help="encoder checkpoint file: JasperEncoder.pt",
            ),
        ),
        (
            "--decoder_checkpoint",
            dict(
                type=str,
                required=True,
                help="decoder checkpoint file: JasperDecoderForCTC.pt",
            ),
        ),
        (
            "--remote_data",
            dict(
                type=str,
                required=False,
                default="",
                help="remote dataloader endpoint",
            ),
        ),
        (
            "--dataset",
            dict(
                type=str,
                required=False,
                default="",
                help="dataset directory containing train/test manifests",
            ),
        ),
        ("--exp_name", dict(default="Jasper", type=str)),
        ("--beta1", dict(default=0.95, type=float)),
        ("--beta2", dict(default=0.25, type=float)),
        ("--warmup_steps", dict(default=0, type=int)),
        (
            "--load_dir",
            dict(
                default=None,
                type=str,
                help="directory with pre-trained checkpoint",
            ),
        ),
    )
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)

    args = parser.parse_args()
    if all(limit is None for limit in (args.max_steps, args.num_epochs)):
        raise ValueError("Either max_steps or num_epochs should be provided.")
    return args
def construct_name(
    name, lr, batch_size, max_steps, num_epochs, wd, optimizer, iter_per_step
):
    """Build an experiment name that encodes the main hyper-parameters.

    The name carries ``s_<max_steps>`` when ``max_steps`` is given and
    ``e_<num_epochs>`` otherwise.
    """
    prefix = f"{name}-lr_{lr}-bs_{batch_size}"
    suffix = f"wd_{wd}-opt_{optimizer}-ips_{iter_per_step}"
    if max_steps is None:
        return f"{prefix}-e_{num_epochs}-{suffix}"
    return f"{prefix}-s_{max_steps}-{suffix}"
def create_all_dags(args, neural_factory):
    """Build the evaluation graphs and return their evaluator callbacks.

    The model configuration is read from ``args.model_config``; encoder and
    decoder weights are restored from ``args.encoder_checkpoint`` and
    ``args.decoder_checkpoint``. One graph (data layer -> preprocessor ->
    encoder -> decoder -> greedy decoding and CTC loss) is assembled per
    entry of ``args.eval_datasets``.

    When ``args.dataset`` is set, missing ``args.train_dataset`` /
    ``args.eval_datasets`` are filled in from that directory.

    Returns:
        list: one ``EvaluatorCallback`` per evaluation dataset; empty (with a
        warning logged) when no evaluation dataset was given.
    """
    yaml = YAML(typ="safe")
    with open(args.model_config) as f:
        jasper_params = yaml.load(f)
    vocab = jasper_params["labels"]
    sample_rate = jasper_params["sample_rate"]

    # Calculate num_workers for dataloader; cpu_count() may return None.
    total_cpus = os.cpu_count() or 1
    cpu_per_traindl = max(int(total_cpus / neural_factory.world_size), 1)

    if args.dataset:
        d_path = Path(args.dataset)
        if not args.train_dataset:
            args.train_dataset = str(d_path / Path("train_manifest.json"))
        if not args.eval_datasets:
            args.eval_datasets = [str(d_path / Path("test_manifest.json"))]

    eval_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"])
    eval_dl_params.update(jasper_params["AudioToTextDataLayer"]["eval"])
    del eval_dl_params["train"]
    del eval_dl_params["eval"]

    data_loader_layer = nemo_asr.AudioToTextDataLayer
    if args.remote_data:
        eval_dl_params["rpyc_host"] = args.remote_data
        data_loader_layer = RpycAudioToTextDataLayer

    data_preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor(
        sample_rate=sample_rate, **jasper_params["AudioToMelSpectrogramPreprocessor"]
    )

    # Same handling as the training script: no datasets is a warning, not a
    # TypeError from iterating None.
    eval_datasets = args.eval_datasets or []
    if not eval_datasets:
        logging.warning("There were no val datasets passed")
    data_layers_eval = [
        data_loader_layer(
            manifest_filepath=eval_dataset,
            sample_rate=sample_rate,
            labels=vocab,
            batch_size=args.eval_batch_size,
            num_workers=cpu_per_traindl,
            **eval_dl_params,
        )
        for eval_dataset in eval_datasets
    ]

    jasper_encoder = nemo_asr.JasperEncoder(
        feat_in=jasper_params["AudioToMelSpectrogramPreprocessor"]["features"],
        **jasper_params["JasperEncoder"],
    )
    jasper_encoder.restore_from(args.encoder_checkpoint, local_rank=0)
    jasper_decoder = nemo_asr.JasperDecoderForCTC(
        feat_in=jasper_params["JasperEncoder"]["jasper"][-1]["filters"],
        num_classes=len(vocab),
    )
    jasper_decoder.restore_from(args.decoder_checkpoint, local_rank=0)
    ctc_loss = nemo_asr.CTCLossNM(num_classes=len(vocab))
    greedy_decoder = nemo_asr.GreedyCTCDecoder()

    callbacks = []
    # assemble eval DAGs
    for eval_dataset, eval_dl in zip(eval_datasets, data_layers_eval):
        (audio_signal_e, a_sig_length_e, transcript_e, transcript_len_e) = eval_dl()
        processed_signal_e, p_length_e = data_preprocessor(
            input_signal=audio_signal_e, length=a_sig_length_e
        )
        encoded_e, encoded_len_e = jasper_encoder(
            audio_signal=processed_signal_e, length=p_length_e
        )
        log_probs_e = jasper_decoder(encoder_output=encoded_e)
        predictions_e = greedy_decoder(log_probs=log_probs_e)
        loss_e = ctc_loss(
            log_probs=log_probs_e,
            targets=transcript_e,
            input_length=encoded_len_e,
            target_length=transcript_len_e,
        )

        # create corresponding eval callback, tagged with the manifest's
        # file name up to its first dot
        tagname = os.path.basename(eval_dataset).split(".")[0]
        eval_callback = nemo.core.EvaluatorCallback(
            eval_tensors=[loss_e, predictions_e, transcript_e, transcript_len_e],
            user_iter_callback=partial(process_evaluation_batch, labels=vocab),
            user_epochs_done_callback=partial(process_evaluation_epoch, tag=tagname),
            eval_step=args.eval_freq,
            tb_writer=neural_factory.tb_writer,
        )
        callbacks.append(eval_callback)
    return callbacks
def main():
    """Evaluate the restored Jasper model on the configured datasets."""
    args = parse_args()

    # instantiate Neural Factory with supported backend
    factory_options = {
        "placement": nemo.core.DeviceType.GPU,
        "backend": nemo.core.Backend.PyTorch,
    }
    neural_factory = nemo.core.NeuralModuleFactory(**factory_options)
    args.num_gpus = neural_factory.world_size

    if args.local_rank is not None:
        logging.info("Doing ALL GPU")

    # build dags, then evaluate model
    eval_callbacks = create_all_dags(args, neural_factory)
    neural_factory.eval(callbacks=eval_callbacks)


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1 @@

366
jasper/training/cli.py Normal file
View File

@@ -0,0 +1,366 @@
# Copyright (c) 2019 NVIDIA Corporation
import argparse
import copy
import math
import os
from pathlib import Path
from functools import partial
from ruamel.yaml import YAML
import nemo
import nemo.collections.asr as nemo_asr
import nemo.utils.argparse as nm_argparse
from nemo.collections.asr.helpers import (
monitor_asr_train_progress,
process_evaluation_batch,
process_evaluation_epoch,
)
from nemo.utils.lr_policies import CosineAnnealing
from .data_loaders import RpycAudioToTextDataLayer
logging = nemo.logging
def parse_args():
    """Parse the command line of the training script.

    Extends NeMo's stock argument parser with project defaults and the
    script's own options; options already defined by the parent parser are
    replaced (``conflict_handler="resolve"``).

    Raises:
        ValueError: when neither ``max_steps`` nor ``num_epochs`` is set.
    """
    parser = argparse.ArgumentParser(
        parents=[nm_argparse.NemoArgParser()],
        description="Jasper",
        conflict_handler="resolve",
    )
    parser.set_defaults(
        **{
            "checkpoint_dir": None,
            "optimizer": "novograd",
            "batch_size": 64,
            "eval_batch_size": 64,
            "lr": 0.002,
            "amp_opt_level": "O1",
            "create_tb_writer": True,
            "model_config": "./train/jasper10x5dr.yaml",
            "work_dir": "./train/work",
            "num_epochs": 300,
            "weight_decay": 0.005,
            "checkpoint_save_freq": 100,
            "eval_freq": 100,
            "load_dir": "./train/models/jasper/",
            "warmup_steps": 3,
            "exp_name": "jasper-speller",
        }
    )
    add = parser.add_argument

    # Overwrite default args
    add(
        "--max_steps",
        type=int,
        default=None,
        required=False,
        help="max number of steps to train",
    )
    add("--num_epochs", type=int, required=False, help="number of epochs to train")
    add(
        "--model_config",
        type=str,
        required=False,
        help="model configuration file: model.yaml",
    )
    add(
        "--remote_data",
        type=str,
        required=False,
        default="",
        help="remote dataloader endpoint",
    )
    add(
        "--dataset",
        type=str,
        required=False,
        default="",
        help="dataset directory containing train/test manifests",
    )

    # Create new args
    add("--exp_name", default="Jasper", type=str)
    add("--beta1", default=0.95, type=float)
    add("--beta2", default=0.25, type=float)
    add("--warmup_steps", default=0, type=int)
    add(
        "--load_dir",
        default=None,
        type=str,
        help="directory with pre-trained checkpoint",
    )

    args = parser.parse_args()
    no_step_limit = args.max_steps is None
    no_epoch_limit = args.num_epochs is None
    if no_step_limit and no_epoch_limit:
        raise ValueError("Either max_steps or num_epochs should be provided.")
    return args
def construct_name(
    name, lr, batch_size, max_steps, num_epochs, wd, optimizer, iter_per_step
):
    """Return the experiment name derived from the training hyper-parameters.

    The run length is encoded as ``s_<max_steps>`` when ``max_steps`` is not
    None, otherwise as ``e_<num_epochs>``.
    """
    if max_steps is not None:
        length_tag, length_value = "s", max_steps
    else:
        length_tag, length_value = "e", num_epochs
    return (
        f"{name}-lr_{lr}-bs_{batch_size}-{length_tag}_{length_value}"
        f"-wd_{wd}-opt_{optimizer}-ips_{iter_per_step}"
    )
def create_all_dags(args, neural_factory):
    """Build the training graph and one evaluation graph per eval dataset.

    The model configuration is read from ``args.model_config``. When
    ``args.dataset`` is set, missing ``args.train_dataset`` /
    ``args.eval_datasets`` are filled in from that directory. With
    ``args.remote_data`` set, ``RpycAudioToTextDataLayer`` is used in place
    of NeMo's ``AudioToTextDataLayer``.

    Returns:
        tuple: ``(loss_t, callbacks, steps_per_epoch)`` - the training loss
        tensor, the list of callbacks (loss logger, checkpointing and one
        evaluator per eval dataset) and the number of optimizer steps per
        epoch.
    """
    yaml = YAML(typ="safe")
    with open(args.model_config) as f:
        jasper_params = yaml.load(f)
    vocab = jasper_params["labels"]
    sample_rate = jasper_params["sample_rate"]
    # Calculate num_workers for dataloader
    total_cpus = os.cpu_count()
    cpu_per_traindl = max(int(total_cpus / neural_factory.world_size), 1)
    # perturb_config = jasper_params.get('perturb', None)
    # Data-layer parameters: the shared section overlaid with its "train"
    # sub-section; the nested sub-sections themselves are then dropped.
    train_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"])
    train_dl_params.update(jasper_params["AudioToTextDataLayer"]["train"])
    del train_dl_params["train"]
    del train_dl_params["eval"]
    # del train_dl_params["normalize_transcripts"]
    if args.dataset:
        d_path = Path(args.dataset)
        if not args.train_dataset:
            args.train_dataset = str(d_path / Path("train_manifest.json"))
        if not args.eval_datasets:
            args.eval_datasets = [str(d_path / Path("test_manifest.json"))]
    data_loader_layer = nemo_asr.AudioToTextDataLayer
    if args.remote_data:
        train_dl_params["rpyc_host"] = args.remote_data
        data_loader_layer = RpycAudioToTextDataLayer
    data_layer = data_loader_layer(
        manifest_filepath=args.train_dataset,
        sample_rate=sample_rate,
        labels=vocab,
        batch_size=args.batch_size,
        num_workers=cpu_per_traindl,
        **train_dl_params,
        # normalize_transcripts=False
    )
    N = len(data_layer)
    # One step consumes batch_size * iter_per_step * num_gpus examples.
    steps_per_epoch = math.ceil(
        N / (args.batch_size * args.iter_per_step * args.num_gpus)
    )
    logging.info("Have {0} examples to train on.".format(N))
    data_preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor(
        sample_rate=sample_rate, **jasper_params["AudioToMelSpectrogramPreprocessor"]
    )
    # Optional modules: created only when the config has a section for them.
    multiply_batch_config = jasper_params.get("MultiplyBatch", None)
    if multiply_batch_config:
        multiply_batch = nemo_asr.MultiplyBatch(**multiply_batch_config)
    spectr_augment_config = jasper_params.get("SpectrogramAugmentation", None)
    if spectr_augment_config:
        data_spectr_augmentation = nemo_asr.SpectrogramAugmentation(
            **spectr_augment_config
        )
    # Same overlay as for training, using the "eval" sub-section.
    eval_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"])
    eval_dl_params.update(jasper_params["AudioToTextDataLayer"]["eval"])
    if args.remote_data:
        eval_dl_params["rpyc_host"] = args.remote_data
    del eval_dl_params["train"]
    del eval_dl_params["eval"]
    data_layers_eval = []
    if args.eval_datasets:
        for eval_datasets in args.eval_datasets:
            data_layer_eval = data_loader_layer(
                manifest_filepath=eval_datasets,
                sample_rate=sample_rate,
                labels=vocab,
                batch_size=args.eval_batch_size,
                num_workers=cpu_per_traindl,
                **eval_dl_params,
            )
            data_layers_eval.append(data_layer_eval)
    else:
        logging.warning("There were no val datasets passed")
    jasper_encoder = nemo_asr.JasperEncoder(
        feat_in=jasper_params["AudioToMelSpectrogramPreprocessor"]["features"],
        **jasper_params["JasperEncoder"],
    )
    # Decoder input width is the filter count of the last encoder block.
    jasper_decoder = nemo_asr.JasperDecoderForCTC(
        feat_in=jasper_params["JasperEncoder"]["jasper"][-1]["filters"],
        num_classes=len(vocab),
    )
    ctc_loss = nemo_asr.CTCLossNM(num_classes=len(vocab))
    greedy_decoder = nemo_asr.GreedyCTCDecoder()
    logging.info("================================")
    logging.info(f"Number of parameters in encoder: {jasper_encoder.num_weights}")
    logging.info(f"Number of parameters in decoder: {jasper_decoder.num_weights}")
    logging.info(
        f"Total number of parameters in model: "
        f"{jasper_decoder.num_weights + jasper_encoder.num_weights}"
    )
    logging.info("================================")
    # Train DAG
    (audio_signal_t, a_sig_length_t, transcript_t, transcript_len_t) = data_layer()
    processed_signal_t, p_length_t = data_preprocessor(
        input_signal=audio_signal_t, length=a_sig_length_t
    )
    if multiply_batch_config:
        (
            processed_signal_t,
            p_length_t,
            transcript_t,
            transcript_len_t,
        ) = multiply_batch(
            in_x=processed_signal_t,
            in_x_len=p_length_t,
            in_y=transcript_t,
            in_y_len=transcript_len_t,
        )
    if spectr_augment_config:
        processed_signal_t = data_spectr_augmentation(input_spec=processed_signal_t)
    encoded_t, encoded_len_t = jasper_encoder(
        audio_signal=processed_signal_t, length=p_length_t
    )
    log_probs_t = jasper_decoder(encoder_output=encoded_t)
    predictions_t = greedy_decoder(log_probs=log_probs_t)
    loss_t = ctc_loss(
        log_probs=log_probs_t,
        targets=transcript_t,
        input_length=encoded_len_t,
        target_length=transcript_len_t,
    )
    # Callbacks needed to print info to console and Tensorboard
    train_callback = nemo.core.SimpleLossLoggerCallback(
        tensors=[loss_t, predictions_t, transcript_t, transcript_len_t],
        print_func=partial(monitor_asr_train_progress, labels=vocab),
        get_tb_values=lambda x: [("loss", x[0])],
        tb_writer=neural_factory.tb_writer,
    )
    chpt_callback = nemo.core.CheckpointCallback(
        folder=neural_factory.checkpoint_dir,
        load_from_folder=args.load_dir,
        step_freq=args.checkpoint_save_freq,
        checkpoints_to_keep=30,
    )
    callbacks = [train_callback, chpt_callback]
    # assemble eval DAGs
    for i, eval_dl in enumerate(data_layers_eval):
        (audio_signal_e, a_sig_length_e, transcript_e, transcript_len_e) = eval_dl()
        processed_signal_e, p_length_e = data_preprocessor(
            input_signal=audio_signal_e, length=a_sig_length_e
        )
        encoded_e, encoded_len_e = jasper_encoder(
            audio_signal=processed_signal_e, length=p_length_e
        )
        log_probs_e = jasper_decoder(encoder_output=encoded_e)
        predictions_e = greedy_decoder(log_probs=log_probs_e)
        loss_e = ctc_loss(
            log_probs=log_probs_e,
            targets=transcript_e,
            input_length=encoded_len_e,
            target_length=transcript_len_e,
        )
        # create corresponding eval callback
        # (tagged with the manifest's file name up to its first dot)
        tagname = os.path.basename(args.eval_datasets[i]).split(".")[0]
        eval_callback = nemo.core.EvaluatorCallback(
            eval_tensors=[loss_e, predictions_e, transcript_e, transcript_len_e],
            user_iter_callback=partial(process_evaluation_batch, labels=vocab),
            user_epochs_done_callback=partial(process_evaluation_epoch, tag=tagname),
            eval_step=args.eval_freq,
            tb_writer=neural_factory.tb_writer,
        )
        callbacks.append(eval_callback)
    return loss_t, callbacks, steps_per_epoch
def main():
    """Entry point: parse arguments, assemble the graphs and start training.

    The experiment name returned by ``construct_name`` is used as the log
    directory, nested under ``args.work_dir`` when that option is set. The
    learning-rate schedule runs for ``args.max_steps`` when given, otherwise
    for ``args.num_epochs * steps_per_epoch``.
    """
    args = parse_args()
    name = construct_name(
        args.exp_name,
        args.lr,
        args.batch_size,
        args.max_steps,
        args.num_epochs,
        args.weight_decay,
        args.optimizer,
        args.iter_per_step,
    )
    log_dir = os.path.join(args.work_dir, name) if args.work_dir else name

    # instantiate Neural Factory with supported backend
    factory_options = dict(
        backend=nemo.core.Backend.PyTorch,
        local_rank=args.local_rank,
        optimization_level=args.amp_opt_level,
        log_dir=log_dir,
        checkpoint_dir=args.checkpoint_dir,
        create_tb_writer=args.create_tb_writer,
        files_to_copy=[args.model_config, __file__],
        cudnn_benchmark=args.cudnn_benchmark,
        tensorboard_dir=args.tensorboard_dir,
    )
    neural_factory = nemo.core.NeuralModuleFactory(**factory_options)
    args.num_gpus = neural_factory.world_size

    # Read from the factory exactly as before; the value is not used below.
    checkpoint_dir = neural_factory.checkpoint_dir
    if args.local_rank is not None:
        logging.info("Doing ALL GPU")

    # build dags
    train_loss, callbacks, steps_per_epoch = create_all_dags(args, neural_factory)

    # An explicit step budget takes precedence over the epoch-derived one.
    if args.max_steps is not None:
        total_steps = args.max_steps
    else:
        total_steps = args.num_epochs * steps_per_epoch
    schedule = CosineAnnealing(total_steps, warmup_steps=args.warmup_steps)

    optimization_params = {
        "num_epochs": args.num_epochs,
        "max_steps": args.max_steps,
        "lr": args.lr,
        "betas": (args.beta1, args.beta2),
        "weight_decay": args.weight_decay,
        "grad_norm_clip": None,
    }

    # train model
    neural_factory.train(
        tensors_to_optimize=[train_loss],
        callbacks=callbacks,
        lr_policy=schedule,
        optimizer=args.optimizer,
        optimization_params=optimization_params,
        batches_per_step=args.iter_per_step,
    )


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,334 @@
from functools import partial
import tempfile
# from typing import Any, Dict, List, Optional
import torch
import nemo
# import nemo.collections.asr as nemo_asr
from nemo.backends.pytorch import DataLayerNM
from nemo.core import DeviceType
# from nemo.core.neural_types import *
from nemo.core.neural_types import NeuralType, AudioSignal, LengthsType, LabelsType
from nemo.utils.decorators import add_port_docs
from nemo.collections.asr.parts.dataset import (
# AudioDataset,
# AudioLabelDataset,
# KaldiFeatureDataset,
# TranscriptDataset,
parsers,
collections,
seq_collate_fn,
)
# from functools import lru_cache
import rpyc
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
from .featurizer import RpycWaveformFeaturizer
# from nemo.collections.asr.parts.features import WaveformFeaturizer
# from nemo.collections.asr.parts.perturb import AudioAugmentor, perturbation_types
logging = nemo.logging
class CachedAudioDataset(torch.utils.data.Dataset):
    """
    Manifest-driven audio/transcript dataset that keeps featurizer output
    in memory.

    The manifest is a json-lines file: each line describes one sample with
    the path to its audio, its transcript and its duration (in seconds).
    Example below:
    {"audio_filepath": "/path/to/audio.wav", "text_filepath":
    "/path/to/audio.txt", "duration": 23.147}
    ...
    {"audio_filepath": "/path/to/audio.wav", "text": "the
    transcription", "offset": 301.75, "duration": 0.82, "utt":
    "utterance_id", "ctm_utt": "en_4156", "side": "A"}

    Every sample is requested once while the dataset is being constructed
    (through a thread pool), so the features returned by the featurizer are
    already stored in ``index_feature_map`` when construction finishes.

    Args:
        manifest_filepath: Path to manifest json as described above. Can
            be comma-separated paths.
        labels: String containing all the possible characters to map to
        featurizer: Object whose ``process`` method turns an audio path
            into a feature tensor
        max_duration: If audio exceeds this length, do not include in dataset
        min_duration: If audio is less than this length, do not include
            in dataset
        max_utts: Limit number of utterances
        blank_index: blank character index, default = -1
        unk_index: unk_character index, default = -1
        normalize: whether to normalize transcript text, default = True
        trim: forwarded to ``featurizer.process`` as ``trim``
        bos_id: Id of beginning of sequence symbol to prepend if not None
        eos_id: Id of end of sequence symbol to append if not None
        load_audio: Boolean flag indicating whether or not to load audio
        parser: name handed to ``parsers.make_parser``, default = "en"
    """

    def __init__(
        self,
        manifest_filepath,
        labels,
        featurizer,
        max_duration=None,
        min_duration=None,
        max_utts=0,
        blank_index=-1,
        unk_index=-1,
        normalize=True,
        trim=False,
        bos_id=None,
        eos_id=None,
        load_audio=True,
        parser="en",
    ):
        manifest_files = manifest_filepath.split(",")
        text_parser = parsers.make_parser(
            labels=labels,
            name=parser,
            unk_id=unk_index,
            blank_id=blank_index,
            do_normalize=normalize,
        )
        self.collection = collections.ASRAudioText(
            manifests_files=manifest_files,
            parser=text_parser,
            min_duration=min_duration,
            max_duration=max_duration,
            max_number=max_utts,
        )
        # sample index -> features returned by the featurizer
        self.index_feature_map = {}
        self.featurizer = featurizer
        self.trim = trim
        self.eos_id = eos_id
        self.bos_id = bos_id
        self.load_audio = load_audio
        print(f"initializing dataset {manifest_filepath}")

        # Touch every index once so the cache is fully populated up front.
        sample_count = len(self.collection)
        with ThreadPoolExecutor() as pool:
            print("starting all loading tasks")
            progress = tqdm(
                pool.map(self.__getitem__, range(sample_count)),
                position=0,
                leave=True,
                total=sample_count,
            )
            for _ in progress:
                pass
        print("initializing complete")

    def __getitem__(self, index):
        """Return ``(features, features_len, tokens, tokens_len)`` for a sample.

        The first two entries are ``None`` when ``load_audio`` is false.
        """
        sample = self.collection[index]

        features = None
        features_len = None
        if self.load_audio:
            features = self.index_feature_map.get(index)
            if features is None:
                # Cache miss: ask the featurizer and remember the answer.
                features = self.featurizer.process(
                    sample.audio_file,
                    offset=0,
                    duration=sample.duration,
                    trim=self.trim,
                )
                self.index_feature_map[index] = features
            features_len = torch.tensor(features.shape[0]).long()

        tokens = sample.text_tokens
        if self.bos_id is not None:
            tokens = [self.bos_id] + tokens
        if self.eos_id is not None:
            tokens = tokens + [self.eos_id]

        return (
            features,
            features_len,
            torch.tensor(tokens).long(),
            torch.tensor(len(tokens)).long(),
        )

    def __len__(self):
        """Number of samples in the underlying collection."""
        return len(self.collection)
class RpycAudioToTextDataLayer(DataLayerNM):
    """Data layer for ASR tasks whose manifests and audio are served over rpyc.

    The layer connects to an rpyc server on ``rpyc_host`` (port 8064), copies
    each comma-separated remote manifest into a local temporary file, builds
    a ``CachedAudioDataset`` from those copies and wraps it in a PyTorch
    ``DataLoader``. The temporary manifest copies are removed again once the
    dataset has been built.

    JSON manifest files should be of the following format::

        {"audio_filepath": path_to_wav_0, "duration": time_in_sec_0, "text": \
transcript_0}
        ...
        {"audio_filepath": path_to_wav_n, "duration": time_in_sec_n, "text": \
transcript_n}

    Args:
        manifest_filepath (str): Comma-separated manifest paths as seen by
            the rpyc server; each is fetched with ``read_path``.
        labels (list): Dataset parameter.
            List of characters that can be output by the ASR model.
        batch_size (int): batch size. ``-1`` means one batch containing the
            whole dataset.
        sample_rate (int): Target sampling rate, forwarded to the remote
            featurizer.
            Defaults to 16000.
        int_values (bool): Forwarded to the remote featurizer; indicates
            whether the audio file is saved as int data or float data.
            Defaults to False.
        bos_id (id): Dataset parameter.
            Beginning of string symbol id.
            Defaults to None.
        eos_id (id): Dataset parameter.
            End of string symbol id used for seq2seq models.
            Defaults to None.
        pad_id (id): Token value used to pad transcripts when collating.
            Defaults to None, which is treated as 0.
        min_duration (float): Dataset parameter.
            Lower duration bound handed to the dataset.
            Defaults to 0.1.
        max_duration (float): Dataset parameter.
            Upper duration bound handed to the dataset.
            Defaults to None.
        normalize_transcripts (bool): Dataset parameter.
            Whether to use automatic text cleaning.
            Defaults to True.
        trim_silence (bool): Forwarded to the featurizer as ``trim``.
            Defaults to False.
        load_audio (bool): Dataset parameter.
            Controls whether the dataset loads the audio signal and
            transcript or just the transcript.
            Defaults to True.
        rpyc_host (str): Host name of the rpyc server.
            Defaults to "".
        drop_last (bool): See PyTorch DataLoader.
            Defaults to False.
        shuffle (bool): See PyTorch DataLoader. Ignored when a distributed
            sampler is in use.
            Defaults to True.
        num_workers (int): Accepted for interface compatibility but not
            forwarded: the DataLoader is always created with one worker.
            Defaults to 0.
    """

    @property
    @add_port_docs()
    def output_ports(self):
        """Returns definitions of module output ports.
        """
        return {
            "audio_signal": NeuralType(
                ("B", "T"),
                AudioSignal(freq=self._sample_rate)
                if self is not None and self._sample_rate is not None
                else AudioSignal(),
            ),
            "a_sig_length": NeuralType(tuple("B"), LengthsType()),
            "transcripts": NeuralType(("B", "T"), LabelsType()),
            "transcript_length": NeuralType(tuple("B"), LengthsType()),
        }

    def __init__(
        self,
        manifest_filepath,
        labels,
        batch_size,
        sample_rate=16000,
        int_values=False,
        bos_id=None,
        eos_id=None,
        pad_id=None,
        min_duration=0.1,
        max_duration=None,
        normalize_transcripts=True,
        trim_silence=False,
        load_audio=True,
        rpyc_host="",
        drop_last=False,
        shuffle=True,
        num_workers=0,
    ):
        # Local import: only needed to remove the temporary manifest copies.
        import os

        super().__init__()
        self._sample_rate = sample_rate

        # Generous request timeout (seconds) for slow remote reads.
        rpyc_conn = rpyc.connect(
            rpyc_host, 8064, config={"sync_request_timeout": 600}
        ).root
        self._featurizer = RpycWaveformFeaturizer(
            sample_rate=self._sample_rate,
            int_values=int_values,
            augmentor=None,
            rpyc_conn=rpyc_conn,
        )

        local_manifests = []
        try:
            # Mirror every remote manifest into a local temp file so the
            # dataset can be given ordinary file paths.
            for remote_path in manifest_filepath.split(","):
                manifest_data = rpyc_conn.read_path(remote_path)
                with tempfile.NamedTemporaryFile(
                    dir="/tmp", prefix="jasper_manifest.", delete=False
                ) as manifest_file:
                    # Register the name first so a failed write is cleaned up.
                    local_manifests.append(manifest_file.name)
                    manifest_file.write(manifest_data)

            self._dataset = CachedAudioDataset(
                manifest_filepath=",".join(local_manifests),
                labels=labels,
                featurizer=self._featurizer,
                max_duration=max_duration,
                min_duration=min_duration,
                normalize=normalize_transcripts,
                trim=trim_silence,
                bos_id=bos_id,
                eos_id=eos_id,
                load_audio=load_audio,
            )
        finally:
            # The dataset keeps its samples in memory after construction, so
            # the local copies are no longer needed; previously they were
            # left behind in /tmp on every instantiation.
            for local_path in local_manifests:
                try:
                    os.remove(local_path)
                except OSError as err:
                    logging.warning(
                        f"could not remove temporary manifest {local_path}: {err}"
                    )

        self._batch_size = batch_size

        # Set up data loader
        if self._placement == DeviceType.AllGpu:
            logging.info("Parallelizing Datalayer.")
            sampler = torch.utils.data.distributed.DistributedSampler(self._dataset)
        else:
            sampler = None

        if batch_size == -1:
            batch_size = len(self._dataset)

        pad_id = 0 if pad_id is None else pad_id
        self._dataloader = torch.utils.data.DataLoader(
            dataset=self._dataset,
            batch_size=batch_size,
            collate_fn=partial(seq_collate_fn, token_pad_value=pad_id),
            drop_last=drop_last,
            # A sampler and shuffle=True are mutually exclusive in DataLoader.
            shuffle=shuffle if sampler is None else False,
            sampler=sampler,
            # NOTE(review): the num_workers argument is not used here; the
            # worker count is fixed at 1. Confirm whether that is intended.
            num_workers=1,
        )

    def __len__(self):
        """Number of samples in the wrapped dataset."""
        return len(self._dataset)

    @property
    def dataset(self):
        """Always None; batches are exposed through ``data_iterator``."""
        return None

    @property
    def data_iterator(self):
        """The PyTorch DataLoader built in ``__init__``."""
        return self._dataloader

View File

@@ -0,0 +1,51 @@
# import math
# import librosa
import torch
import pickle
# import torch.nn as nn
# from torch_stft import STFT
# from nemo import logging
from nemo.collections.asr.parts.perturb import AudioAugmentor
# from nemo.collections.asr.parts.segment import AudioSegment
class RpycWaveformFeaturizer(object):
    """Waveform featurizer that obtains audio samples from an rpyc service.

    ``process`` delegates file loading to the remote ``get_path_samples``
    callable and converts the returned pickled payload to a float tensor.

    Args:
        sample_rate: Sent to the remote loader as ``target_sr``.
        int_values: Sent to the remote loader unchanged.
        augmentor: Object with ``perturb`` / ``max_augmentation_length``;
            a default ``AudioAugmentor()`` is created when None.
        rpyc_conn: Object exposing ``get_path_samples``. May be None, in
            which case only ``process_segment`` is usable.
    """

    def __init__(
        self, sample_rate=16000, int_values=False, augmentor=None, rpyc_conn=None
    ):
        self.augmentor = augmentor if augmentor is not None else AudioAugmentor()
        self.sample_rate = sample_rate
        self.int_values = int_values
        # rpyc_conn is optional in the signature, so do not dereference None
        # here; process() reports the missing connection instead.
        self.remote_path_samples = (
            rpyc_conn.get_path_samples if rpyc_conn is not None else None
        )

    def max_augmentation_length(self, length):
        """Delegate to the augmentor's ``max_augmentation_length``."""
        return self.augmentor.max_augmentation_length(length)

    def process(self, file_path, offset=0, duration=0, trim=False):
        """Fetch samples for ``file_path`` remotely and return a float tensor.

        Raises:
            RuntimeError: if the featurizer was built without ``rpyc_conn``.
        """
        if self.remote_path_samples is None:
            raise RuntimeError(
                "RpycWaveformFeaturizer.process needs an rpyc connection; "
                "pass rpyc_conn when constructing the featurizer"
            )
        payload = self.remote_path_samples(
            file_path,
            target_sr=self.sample_rate,
            int_values=self.int_values,
            offset=offset,
            duration=duration,
            trim=trim,
        )
        # SECURITY: pickle.loads can execute arbitrary code contained in the
        # payload. Only connect this featurizer to a trusted rpyc server.
        return torch.tensor(pickle.loads(payload), dtype=torch.float)

    def process_segment(self, audio_segment):
        """Apply the augmentor to ``audio_segment`` and return a float tensor.

        Objects exposing a ``samples`` attribute are converted from that
        attribute; anything else is handed to ``torch.tensor`` directly.
        """
        self.augmentor.perturb(audio_segment)
        # presumably segment objects carry their data in `.samples`
        # (as the upstream featurizer this is modelled on does) — confirm.
        samples = getattr(audio_segment, "samples", audio_segment)
        return torch.tensor(samples, dtype=torch.float)

    @classmethod
    def from_config(cls, input_config, perturbation_configs=None, rpyc_conn=None):
        """Build a featurizer from a config mapping.

        Args:
            input_config: Mapping read for ``sample_rate`` (default 16000)
                and ``int_values`` (default False).
            perturbation_configs: When not None, used to build the augmentor
                via ``AudioAugmentor.from_config``.
            rpyc_conn: Connection forwarded to the constructor. Previously it
                could not be supplied here, so construction always failed.
        """
        if perturbation_configs is not None:
            aa = AudioAugmentor.from_config(perturbation_configs)
        else:
            aa = None
        sample_rate = input_config.get("sample_rate", 16000)
        int_values = input_config.get("int_values", False)
        return cls(
            sample_rate=sample_rate,
            int_values=int_values,
            augmentor=aa,
            rpyc_conn=rpyc_conn,
        )

View File

@@ -1,11 +1,52 @@
from setuptools import setup
from setuptools import setup, find_packages
# Core dependencies; handed to setup() as install_requires.
requirements = [
"ruamel.yaml",
"torch==1.4.0",
"torchvision==0.5.0",
# NeMo is installed straight from GitHub at a fixed commit.
"nemo_toolkit[asr] @ git+https://github.com/NVIDIA/NeMo.git@09e3ba4dfe333f86d6c5c1048e07210924294be9#egg=nemo_toolkit",
]
# NOTE(review): this assignment is immediately overwritten by the mapping
# assigned on the next statement, so only the larger mapping takes effect.
extra_requirements = {"server": ["rpyc==4.1.4"]}
# Optional dependency groups keyed by extra name; handed to setup() as
# extras_require.
extra_requirements = {
"server": ["rpyc~=4.1.4", "tqdm~=4.39.0"],
"data": [
"google-cloud-texttospeech~=1.0.1",
"tqdm~=4.39.0",
"pydub~=0.24.0",
"scikit_learn~=0.22.1",
"pandas~=1.0.3",
"boto3~=1.12.35",
# NOTE(review): ruamel.yaml is pinned here but unpinned in `requirements`.
"ruamel.yaml==0.16.10",
"pymongo==3.10.1",
"librosa==0.7.2",
"matplotlib==3.2.1",
# NOTE(review): pandas is listed twice in this group (~=1.0.3 above).
"pandas==1.0.3",
"tabulate==0.8.7",
"natural==0.2.0",
"num2words==0.5.10",
"typer[all]==0.1.1",
"python-slugify==4.0.0",
# lenses is installed straight from GitHub at a fixed commit.
"lenses @ git+https://github.com/ingolemo/python-lenses.git@b2a2a9aa5b61540992d70b2cf36008d0121e8948#egg=lenses",
],
"validation": [
"rpyc~=4.1.4",
"pymongo==3.10.1",
"typer[all]==0.1.1",
"tqdm~=4.39.0",
"librosa==0.7.2",
"matplotlib==3.2.1",
"pydub~=0.24.0",
"streamlit==0.58.0",
"natural==0.2.0",
"stringcase==1.2.0",
"google-cloud-speech~=1.3.1",
]
# A "train" extra is sketched below but currently disabled.
# "train": [
# "torchaudio==0.5.0",
# "torch-stft==0.1.4",
# ]
}
# Package list produced by setuptools' find_packages(); handed to setup()
# as packages.
packages = find_packages()
setup(
name="jasper-asr",
@@ -17,11 +58,24 @@ setup(
license="MIT",
install_requires=requirements,
extras_require=extra_requirements,
packages=["."],
packages=packages,
entry_points={
"console_scripts": [
"jasper_transcribe = jasper.transcribe:main",
"jasper_asr_rpyc_server = jasper.server:main",
"jasper_server = jasper.server:main",
"jasper_trainer = jasper.training.cli:main",
"jasper_evaluator = jasper.evaluate:main",
"jasper_data_tts_generate = jasper.data.tts_generator:main",
"jasper_data_conv_generate = jasper.data.conv_generator:main",
"jasper_data_nlu_generate = jasper.data.nlu_generator:main",
"jasper_data_test_generate = jasper.data.test_generator:main",
"jasper_data_call_recycle = jasper.data.call_recycler:main",
"jasper_data_asr_recycle = jasper.data.asr_recycler:main",
"jasper_data_rev_recycle = jasper.data.rev_recycler:main",
"jasper_data_server = jasper.data.server:main",
"jasper_data_validation = jasper.data.validation.process:main",
"jasper_data_preprocess = jasper.data.process:main",
"jasper_data_slu_evaluate = jasper.data.slu_evaluator:main",
]
},
zip_safe=False,

3
validation_ui.py Normal file
View File

@@ -0,0 +1,3 @@
import runpy

# Launcher shim: run the packaged validation UI module as though it had been
# started as a script (its __name__ is "__main__" and sys is adjusted).
runpy.run_module(
    mod_name="jasper.data.validation.ui",
    run_name="__main__",
    alter_sys=True,
)