1. integrated data generator using google tts

2. added training script

fix module packaging issue

implement call audio data recycler for asr

1. added streamlit based validation ui with mongodb datastore integration
2. fix asr inference using wrong sample rate
3. update requirements

1. refactored streamlit code
2. fixed issues in data manifest handling

refresh to next entry on submit and comment out mongo clearing code for safety :P

add validation ui and post processing to correct using validation data

1. added a tool to extract asr data from gcp transcripts logs
2. implement a function to export all call logs in a mongodb to a caller-id based yaml file
3. clean-up leaderboard duration logic
4. added a wip dataloader service
5. made the asr_data_writer util more generic with verbose flags and unique filename
6. added extendedpath util class with json support and mongo_conn function to connect to a mongo node
7. refactored the validation post processing to dump a ui config for validation
8. included utility functions to correct, fill, update and clear annotations from mongodb data
9. refactored the ui logic to be more generic for any asr data
10. updated setup.py dependencies to support the above features

unlink temporary files after transcribing

1. clean-up unused data process code
2. fix invalid sample no from mongo
3. data loader service returns remote netref

1. added training utils with custom data loaders with remote rpyc dataservice support
2. fix validation correction dump path
3. cache dataset for precaching before training to memory
4. update dependencies

1. implement dataset augmentation and validation in the data process module
2. added option to skip 'incorrect' annotations in validation data
3. added confirmation on clearing mongo collection
4. added an option to navigate to a given text in the validation ui
5. added a dataset and remote option to trainer to load dataset from directory and remote rpyc service

1. added utility command to export call logs
2. mongo conn accepts port

refactored module structure

1. enabled silence stripping in chunks when recycling audio from asr logs
2. limit asr recycling to the first 1 min of audio to get reliable alignments and ignore the agent channel
3. added rev recycler for generating asr dataset from rev transcripts and audio
4. update pydub dependency for the silence stripping fn and remove the hardcoded threadpool worker count

1. added support for mono/dual channel rev transcripts
2. handle errors when extracting datapoints from rev meta data
3. added support for annotation-only task when dumping ui data

cleanup rev recycle

added option to disable plots during validation

fix skipping null audio and add more verbose logs

respect verbose flag

don't load audio for annotation-only ui and keep spoken as text for normal asr validation

1. refactored wav chunk processing method
2. renamed streamlit to validation_ui

show dataset duration during validation

parallelize data loading from remote

skipping invalid data points

1. removed the transcriber_pretrained/speller from utils
2. introduced get_mongo_coll to get the collection object directly from mongo uri
3. removed processing of correction entries to remove space/upper casing

refactor validation process arguments and logging

1. added a data extraction type argument
2. cleanup/refactor

1. using dataname args for update/fill annotations
2. rename to dump_ui

added support for name/dates/cities call data extraction and more logs

handling non-pnr cases without parens in text data

1. added conv data generator
2. more utils

1. added start delay arg in call recycler
2. implement ui_dump/manifest writer in call_recycler itself
3. refactored call data point plotter
4. added sample-ui and task-ui commands to the validation process
5. implemented call-quality stats using corrections from mongo
6. support deleting cursors on mongo
7. implement multiple task support on validation ui based on task_id mongo field

fix 11st to 11th in ordinal

stripping silence on call chunk

1. added option to strip silent chunks
2. computing caller quality based on task-id of corrections

1. fix update-correction to use ui_dump instead of manifest
2. update training params: number of checkpoints to keep and checkpoint save frequency

1. split-extract all data types in one shot with the --extraction-type all flag
2. add notes about diffing split extracted and original data
3. add an nlu conv generator to generate conv data based on nlu utterances and entities
4. add task uid support for dumping corrections
5. abstracted generate date fn

1. added a test generator and slu evaluator
2. ui dump now includes gcp results
3. show default values for more args of the validation process commands

added evaluation command

clean-up
Malar Kannan 2020-04-08 17:26:27 +05:30
parent f7ebd8e90a
commit e24a8cf9d0
19 changed files with 2228 additions and 5 deletions

.gitignore

@@ -1,3 +1,11 @@
/data/
/model/
/train/
.env*
*.yaml
*.yml
*.json
# Created by https://www.gitignore.io/api/python
# Edit at https://www.gitignore.io/?templates=python
@@ -108,3 +116,36 @@ dmypy.json
.pyre/
# End of https://www.gitignore.io/api/python
# Created by https://www.gitignore.io/api/macos
# Edit at https://www.gitignore.io/?templates=macos
### macOS ###
# General
.DS_Store
.AppleDouble
.LSOverride
# Icon must end with two \r
Icon
# Thumbnails
._*
# Files that might appear in the root of a volume
.DocumentRevisions-V100
.fseventsd
.Spotlight-V100
.TemporaryItems
.Trashes
.VolumeIcon.icns
.com.apple.timemachine.donotpresent
# Directories potentially created on remote AFP share
.AppleDB
.AppleDesktop
Network Trash Folder
Temporary Items
.apdisk
# End of https://www.gitignore.io/api/macos

Notes.md

@@ -0,0 +1,5 @@
> Diff after splitting based on type
```
diff <(cat data/asr_data/call_upwork_test_cnd_*/manifest.json |sort) <(cat data/asr_data/call_upwork_test_cnd/manifest.json |sort)
```


@@ -62,7 +62,7 @@ class JasperASR(object):
wf = wave.open(audio_file_path, "w")
wf.setnchannels(1)
wf.setsampwidth(2)
wf.setframerate(16000)
wf.setframerate(24000)
wf.writeframesraw(audio_data)
wf.close()
manifest = {"audio_filepath": audio_file_path, "duration": 60, "text": "todo"}
@@ -108,6 +108,8 @@ class JasperASR(object):
tensors = self.neural_factory.infer(tensors=eval_tensors)
prediction = post_process_predictions(tensors[0], self.labels)
prediction_text = ". ".join(prediction)
os.unlink(manifest_file.name)
os.unlink(audio_file.name)
return prediction_text
def transcribe_file(self, audio_file, *args, **kwargs):

jasper/client.py

@@ -0,0 +1,21 @@
import os
import logging
import rpyc
from functools import lru_cache
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
ASR_HOST = os.environ.get("JASPER_ASR_RPYC_HOST", "localhost")
ASR_PORT = int(os.environ.get("JASPER_ASR_RPYC_PORT", "8045"))
@lru_cache()
def transcribe_gen(asr_host=ASR_HOST, asr_port=ASR_PORT):
logger.info(f"connecting to asr server at {asr_host}:{asr_port}")
asr = rpyc.connect(asr_host, asr_port).root
logger.info(f"connected to asr server successfully")
return asr.transcribe
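
A minimal usage sketch of this client, assuming an ASR RPyC server is reachable on the configured host/port and that `transcribe` accepts raw 16-bit mono PCM bytes (the validation code below passes `aud_seg.raw_data`); the input file path is hypothetical:
```python
# Sketch only: assumes a Jasper ASR RPyC server on JASPER_ASR_RPYC_HOST/PORT.
from pydub import AudioSegment
from jasper.client import transcribe_gen

transcribe = transcribe_gen(asr_port=8044)   # pretrained-model port used elsewhere in this PR
seg = (
    AudioSegment.from_file("sample.wav")     # hypothetical input file
    .set_channels(1)
    .set_sample_width(2)
    .set_frame_rate(24000)
)
print(transcribe(seg.raw_data))              # raw PCM bytes, mirroring preprocess_datapoint()
```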

jasper/data/__init__.py

@@ -0,0 +1 @@

jasper/data/process.py

@@ -0,0 +1,77 @@
import json
from pathlib import Path
from sklearn.model_selection import train_test_split
from .utils import asr_manifest_reader, asr_manifest_writer
from typing import List
from itertools import chain
import typer
app = typer.Typer()
@app.command()
def fixate_data(dataset_path: Path):
manifest_path = dataset_path / Path("manifest.json")
real_manifest_path = dataset_path / Path("abs_manifest.json")
def fix_path():
for i in asr_manifest_reader(manifest_path):
i["audio_filepath"] = str(dataset_path / Path(i["audio_filepath"]))
yield i
asr_manifest_writer(real_manifest_path, fix_path())
@app.command()
def augment_data(src_dataset_paths: List[Path], dest_dataset_path: Path):
reader_list = []
abs_manifest_path = Path("abs_manifest.json")
for dataset_path in src_dataset_paths:
manifest_path = dataset_path / abs_manifest_path
reader_list.append(asr_manifest_reader(manifest_path))
dest_dataset_path.mkdir(parents=True, exist_ok=True)
dest_manifest_path = dest_dataset_path / abs_manifest_path
asr_manifest_writer(dest_manifest_path, chain(*reader_list))
@app.command()
def split_data(dataset_path: Path, test_size: float = 0.1):
manifest_path = dataset_path / Path("abs_manifest.json")
asr_data = list(asr_manifest_reader(manifest_path))
train_data, test_data = train_test_split(asr_data, test_size=test_size)
asr_manifest_writer(manifest_path.with_name("train_manifest.json"), train_data)
asr_manifest_writer(manifest_path.with_name("test_manifest.json"), test_data)
@app.command()
def validate_data(dataset_path: Path):
from natural.date import compress
from datetime import timedelta
for mf_type in ["train_manifest.json", "test_manifest.json"]:
data_file = dataset_path / Path(mf_type)
print(f"validating {data_file}.")
with Path(data_file).open("r") as pf:
data_jsonl = pf.readlines()
duration = 0
for (i, s) in enumerate(data_jsonl):
try:
d = json.loads(s)
duration += d["duration"]
audio_file = data_file.parent / Path(d["audio_filepath"])
if not audio_file.exists():
raise OSError(f"File {audio_file} not found")
except BaseException as e:
print(f'failed on {i} with "{e}"')
duration_str = compress(timedelta(seconds=duration), pad=" ")
print(
f"no errors found. seems like a valid {mf_type}. contains {duration_str}sec of audio"
)
def main():
app()
if __name__ == "__main__":
main()
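
A sketch of how these commands chain together when called as plain functions (the typer CLI can be used equivalently); the dataset directory names below are assumptions:
```python
# Hypothetical workflow over two source datasets using the commands defined above.
from pathlib import Path
from jasper.data.process import fixate_data, augment_data, split_data, validate_data

src_a = Path("data/asr_data/call_upwork_test_cnd")   # existing dataset (see Notes.md)
src_b = Path("data/asr_data/rev_batch_1")            # hypothetical second dataset
merged = Path("data/asr_data/merged")                # hypothetical output dataset

fixate_data(src_a)                      # write abs_manifest.json with absolute audio paths
fixate_data(src_b)
augment_data([src_a, src_b], merged)    # concatenate the abs manifests into one dataset
split_data(merged, test_size=0.1)       # emit train_manifest.json / test_manifest.json
validate_data(merged)                   # check durations and that audio files exist
```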

jasper/data/server.py

@@ -0,0 +1,57 @@
import os
from pathlib import Path
import typer
import rpyc
from rpyc.utils.server import ThreadedServer
import nemo
import pickle
# import nemo.collections.asr as nemo_asr
from nemo.collections.asr.parts.segment import AudioSegment
app = typer.Typer()
nemo.core.NeuralModuleFactory(
backend=nemo.core.Backend.PyTorch, placement=nemo.core.DeviceType.CPU
)
class ASRDataService(rpyc.Service):
def exposed_get_path_samples(
self, file_path, target_sr, int_values, offset, duration, trim
):
print(f"loading.. {file_path}")
audio = AudioSegment.from_file(
file_path,
target_sr=target_sr,
int_values=int_values,
offset=offset,
duration=duration,
trim=trim,
)
# print(f"returning.. {len(audio.samples)} items of type{type(audio.samples)}")
return pickle.dumps(audio.samples)
def exposed_read_path(self, file_path):
# print(f"reading path.. {file_path}")
return Path(file_path).read_bytes()
@app.command()
def run_server(port: int = 0):
listen_port = port if port else int(os.environ.get("ASR_DARA_RPYC_PORT", "8064"))
service = ASRDataService()
t = ThreadedServer(
service, port=listen_port, protocol_config={"allow_all_attrs": True}
)
typer.echo(f"starting asr server on {listen_port}...")
t.start()
def main():
app()
if __name__ == "__main__":
main()
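
A client-side sketch for this data service, assuming it is running on localhost:8064 as started above; the audio path must exist on the server machine, and duration=0 is assumed to mean the whole file:
```python
# Sketch only: fetch featurizer samples and raw bytes from ASRDataService over RPyC.
import pickle
import rpyc

conn = rpyc.connect("localhost", 8064, config={"allow_all_attrs": True})
samples = pickle.loads(
    conn.root.get_path_samples(
        "/abs/path/to/sample.wav",   # hypothetical path on the server host
        target_sr=24000,
        int_values=False,
        offset=0,
        duration=0,
        trim=False,
    )
)
raw_bytes = conn.root.read_path("/abs/path/to/sample.wav")
```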

jasper/data/utils.py

@@ -0,0 +1,236 @@
import io
import os
import json
import wave
from pathlib import Path
from functools import partial
from uuid import uuid4
from concurrent.futures import ThreadPoolExecutor
import pymongo
from slugify import slugify
from jasper.client import transcribe_gen
from nemo.collections.asr.metrics import word_error_rate
import matplotlib.pyplot as plt
import librosa
import librosa.display
from tqdm import tqdm
def manifest_str(path, dur, text):
return (
json.dumps({"audio_filepath": path, "duration": round(dur, 1), "text": text})
+ "\n"
)
def wav_bytes(audio_bytes, frame_rate=24000):
wf_b = io.BytesIO()
with wave.open(wf_b, mode="w") as wf:
wf.setnchannels(1)
wf.setframerate(frame_rate)
wf.setsampwidth(2)
wf.writeframesraw(audio_bytes)
return wf_b.getvalue()
def tscript_uuid_fname(transcript):
return str(uuid4()) + "_" + slugify(transcript, max_length=8)
def asr_data_writer(output_dir, dataset_name, asr_data_source, verbose=False):
dataset_dir = output_dir / Path(dataset_name)
(dataset_dir / Path("wav")).mkdir(parents=True, exist_ok=True)
asr_manifest = dataset_dir / Path("manifest.json")
num_datapoints = 0
with asr_manifest.open("w") as mf:
print(f"writing manifest to {asr_manifest}")
for transcript, audio_dur, wav_data in asr_data_source:
fname = tscript_uuid_fname(transcript)
audio_file = dataset_dir / Path("wav") / Path(fname).with_suffix(".wav")
audio_file.write_bytes(wav_data)
rel_data_path = audio_file.relative_to(dataset_dir)
manifest = manifest_str(str(rel_data_path), audio_dur, transcript)
mf.write(manifest)
if verbose:
print(f"writing '{transcript}' of duration {audio_dur}")
num_datapoints += 1
return num_datapoints
def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=False):
dataset_dir = output_dir / Path(dataset_name)
(dataset_dir / Path("wav")).mkdir(parents=True, exist_ok=True)
ui_dump_file = dataset_dir / Path("ui_dump.json")
(dataset_dir / Path("wav_plots")).mkdir(parents=True, exist_ok=True)
asr_manifest = dataset_dir / Path("manifest.json")
num_datapoints = 0
ui_dump = {
"use_domain_asr": False,
"annotation_only": False,
"enable_plots": True,
"data": [],
}
data_funcs = []
transcriber_pretrained = transcribe_gen(asr_port=8044)
with asr_manifest.open("w") as mf:
print(f"writing manifest to {asr_manifest}")
def data_fn(
transcript,
audio_dur,
wav_data,
caller_name,
aud_seg,
fname,
audio_path,
num_datapoints,
rel_data_path,
):
pretrained_result = transcriber_pretrained(aud_seg.raw_data)
pretrained_wer = word_error_rate([transcript], [pretrained_result])
wav_plot_path = (
dataset_dir / Path("wav_plots") / Path(fname).with_suffix(".png")
)
if not wav_plot_path.exists():
plot_seg(wav_plot_path, audio_path)
return {
"audio_filepath": str(rel_data_path),
"duration": round(audio_dur, 1),
"text": transcript,
"real_idx": num_datapoints,
"audio_path": audio_path,
"spoken": transcript,
"caller": caller_name,
"utterance_id": fname,
"pretrained_asr": pretrained_result,
"pretrained_wer": pretrained_wer,
"plot_path": str(wav_plot_path),
}
for transcript, audio_dur, wav_data, caller_name, aud_seg in asr_data_source:
fname = str(uuid4()) + "_" + slugify(transcript, max_length=8)
audio_file = dataset_dir / Path("wav") / Path(fname).with_suffix(".wav")
audio_file.write_bytes(wav_data)
audio_path = str(audio_file)
rel_data_path = audio_file.relative_to(dataset_dir)
manifest = manifest_str(str(rel_data_path), audio_dur, transcript)
mf.write(manifest)
data_funcs.append(
partial(
data_fn,
transcript,
audio_dur,
wav_data,
caller_name,
aud_seg,
fname,
audio_path,
num_datapoints,
rel_data_path,
)
)
num_datapoints += 1
dump_data = parallel_apply(lambda x: x(), data_funcs)
# dump_data = [x() for x in tqdm(data_funcs)]
ui_dump["data"] = dump_data
ExtendedPath(ui_dump_file).write_json(ui_dump)
return num_datapoints
def asr_manifest_reader(data_manifest_path: Path):
print(f"reading manifest from {data_manifest_path}")
with data_manifest_path.open("r") as pf:
data_jsonl = pf.readlines()
data_data = [json.loads(v) for v in data_jsonl]
for p in data_data:
p["audio_path"] = data_manifest_path.parent / Path(p["audio_filepath"])
p["text"] = p["text"].strip()
yield p
def asr_manifest_writer(asr_manifest_path: Path, manifest_str_source):
with asr_manifest_path.open("w") as mf:
print(f"opening {asr_manifest_path} for writing manifest")
for mani_dict in manifest_str_source:
manifest = manifest_str(
mani_dict["audio_filepath"], mani_dict["duration"], mani_dict["text"]
)
mf.write(manifest)
def asr_test_writer(out_file_path: Path, source):
def dd_str(dd, idx):
path = dd["audio_filepath"]
# dur = dd["duration"]
# return f"SAY {idx}\nPAUSE 3\nPLAY {path}\nPAUSE 3\n\n"
return f"PAUSE 2\nPLAY {path}\nPAUSE 60\n\n"
res_file = out_file_path.with_suffix(".result.json")
with out_file_path.open("w") as of:
print(f"opening {out_file_path} for writing test")
results = []
idx = 0
for ui_dd in source:
results.append(ui_dd)
out_str = dd_str(ui_dd, idx)
of.write(out_str)
idx += 1
of.write("DO_HANGUP\n")
ExtendedPath(res_file).write_json(results)
def batch(iterable, n=1):
ls = len(iterable)
return [iterable[ndx : min(ndx + n, ls)] for ndx in range(0, ls, n)]
class ExtendedPath(type(Path())):
"""docstring for ExtendedPath."""
def read_json(self):
print(f"reading json from {self}")
with self.open("r") as jf:
return json.load(jf)
def write_json(self, data):
print(f"writing json to {self}")
self.parent.mkdir(parents=True, exist_ok=True)
with self.open("w") as jf:
return json.dump(data, jf, indent=2)
def get_mongo_conn(host="", port=27017, db="test", col="calls"):
mongo_host = host if host else os.environ.get("MONGO_HOST", "localhost")
mongo_uri = f"mongodb://{mongo_host}:{port}/"
return pymongo.MongoClient(mongo_uri)[db][col]
def strip_silence(sound):
from pydub.silence import detect_leading_silence
start_trim = detect_leading_silence(sound)
end_trim = detect_leading_silence(sound.reverse())
duration = len(sound)
return sound[start_trim : duration - end_trim]
def plot_seg(wav_plot_path, audio_path):
fig = plt.Figure()
ax = fig.add_subplot()
(y, sr) = librosa.load(audio_path)
librosa.display.waveplot(y=y, sr=sr, ax=ax)
with wav_plot_path.open("wb") as wav_plot_f:
fig.set_tight_layout(True)
fig.savefig(wav_plot_f, format="png", dpi=50)
def parallel_apply(fn, iterable, workers=8):
with ThreadPoolExecutor(max_workers=workers) as exe:
print(f"parallelly applying {fn}")
return [
res
for res in tqdm(
exe.map(fn, iterable), position=0, leave=True, total=len(iterable)
)
]
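
For reference, a small sketch of the helpers defined above (paths and values are illustrative):
```python
# Illustrative use of manifest_str, ExtendedPath and wav_bytes from this module.
from jasper.data.utils import ExtendedPath, manifest_str, wav_bytes

line = manifest_str("wav/abc123_hello.wav", 2.34, "hello world")
# '{"audio_filepath": "wav/abc123_hello.wav", "duration": 2.3, "text": "hello world"}\n'

ExtendedPath("data/example/ui_dump.json").write_json({"data": []})  # creates parent dirs
cfg = ExtendedPath("data/example/ui_dump.json").read_json()

pcm = b"\x00\x00" * 24000                  # one second of 16-bit mono silence at 24 kHz
wav = wav_bytes(pcm, frame_rate=24000)     # wrap raw PCM in a WAV container
```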


@@ -0,0 +1 @@


@@ -0,0 +1,418 @@
import json
import shutil
from pathlib import Path
from enum import Enum
import typer
from tqdm import tqdm
from ..utils import (
alnum_to_asr_tokens,
ExtendedPath,
asr_manifest_reader,
asr_manifest_writer,
tscript_uuid_fname,
get_mongo_conn,
plot_seg,
)
app = typer.Typer()
def preprocess_datapoint(
idx, rel_root, sample, use_domain_asr, annotation_only, enable_plots
):
from pydub import AudioSegment
from nemo.collections.asr.metrics import word_error_rate
from jasper.client import transcribe_gen
try:
res = dict(sample)
res["real_idx"] = idx
audio_path = rel_root / Path(sample["audio_filepath"])
res["audio_path"] = str(audio_path)
if use_domain_asr:
res["spoken"] = alnum_to_asr_tokens(res["text"])
else:
res["spoken"] = res["text"]
res["utterance_id"] = audio_path.stem
if not annotation_only:
transcriber_pretrained = transcribe_gen(asr_port=8044)
aud_seg = (
AudioSegment.from_file_using_temporary_files(audio_path)
.set_channels(1)
.set_sample_width(2)
.set_frame_rate(24000)
)
res["pretrained_asr"] = transcriber_pretrained(aud_seg.raw_data)
res["pretrained_wer"] = word_error_rate(
[res["text"]], [res["pretrained_asr"]]
)
if use_domain_asr:
transcriber_speller = transcribe_gen(asr_port=8045)
res["domain_asr"] = transcriber_speller(aud_seg.raw_data)
res["domain_wer"] = word_error_rate(
[res["spoken"]], [res["pretrained_asr"]]
)
if enable_plots:
wav_plot_path = (
rel_root / Path("wav_plots") / Path(audio_path.name).with_suffix(".png")
)
if not wav_plot_path.exists():
plot_seg(wav_plot_path, audio_path)
res["plot_path"] = str(wav_plot_path)
return res
except BaseException as e:
print(f'failed on {idx}: {sample["audio_filepath"]} with {e}')
@app.command()
def dump_ui(
data_name: str = typer.Option("dataname", show_default=True),
dataset_dir: Path = Path("./data/asr_data"),
dump_dir: Path = Path("./data/valiation_data"),
dump_fname: Path = typer.Option(Path("ui_dump.json"), show_default=True),
use_domain_asr: bool = False,
annotation_only: bool = False,
enable_plots: bool = True,
):
from concurrent.futures import ThreadPoolExecutor
from functools import partial
data_manifest_path = dataset_dir / Path(data_name) / Path("manifest.json")
dump_path: Path = dump_dir / Path(data_name) / dump_fname
plot_dir = data_manifest_path.parent / Path("wav_plots")
plot_dir.mkdir(parents=True, exist_ok=True)
typer.echo(f"Using data manifest:{data_manifest_path}")
with data_manifest_path.open("r") as pf:
data_jsonl = pf.readlines()
data_funcs = [
partial(
preprocess_datapoint,
i,
data_manifest_path.parent,
json.loads(v),
use_domain_asr,
annotation_only,
enable_plots,
)
for i, v in enumerate(data_jsonl)
]
def exec_func(f):
return f()
with ThreadPoolExecutor() as exe:
print("starting all preprocess tasks")
data_final = filter(
None,
list(
tqdm(
exe.map(exec_func, data_funcs),
position=0,
leave=True,
total=len(data_funcs),
)
),
)
if annotation_only:
result = list(data_final)
else:
wer_key = "domain_wer" if use_domain_asr else "pretrained_wer"
result = sorted(data_final, key=lambda x: x[wer_key], reverse=True)
ui_config = {
"use_domain_asr": use_domain_asr,
"annotation_only": annotation_only,
"enable_plots": enable_plots,
"data": result,
}
ExtendedPath(dump_path).write_json(ui_config)
@app.command()
def sample_ui(
data_name: str = typer.Option("dataname", show_default=True),
dump_dir: Path = Path("./data/asr_data"),
dump_file: Path = Path("ui_dump.json"),
sample_count: int = typer.Option(80, show_default=True),
sample_file: Path = Path("sample_dump.json"),
):
import pandas as pd
processed_data_path = dump_dir / Path(data_name) / dump_file
sample_path = dump_dir / Path(data_name) / sample_file
processed_data = ExtendedPath(processed_data_path).read_json()
df = pd.DataFrame(processed_data["data"])
samples_per_caller = sample_count // len(df["caller"].unique())
caller_samples = pd.concat(
[g.sample(samples_per_caller) for (c, g) in df.groupby("caller")]
)
caller_samples = caller_samples.reset_index(drop=True)
caller_samples["real_idx"] = caller_samples.index
sample_data = caller_samples.to_dict("records")
processed_data["data"] = sample_data
typer.echo(f"sampling {sample_count} datapoints")
ExtendedPath(sample_path).write_json(processed_data)
@app.command()
def task_ui(
data_name: str = typer.Option("dataname", show_default=True),
dump_dir: Path = Path("./data/asr_data"),
dump_file: Path = Path("ui_dump.json"),
task_count: int = typer.Option(4, show_default=True),
task_file: str = "task_dump",
):
import pandas as pd
import numpy as np
processed_data_path = dump_dir / Path(data_name) / dump_file
processed_data = ExtendedPath(processed_data_path).read_json()
df = pd.DataFrame(processed_data["data"]).sample(frac=1).reset_index(drop=True)
for t_idx, task_f in enumerate(np.array_split(df, task_count)):
task_f = task_f.reset_index(drop=True)
task_f["real_idx"] = task_f.index
task_data = task_f.to_dict("records")
processed_data["data"] = task_data
task_path = dump_dir / Path(data_name) / Path(task_file + f"-{t_idx}.json")
ExtendedPath(task_path).write_json(processed_data)
@app.command()
def dump_corrections(
task_uid: str,
data_name: str = typer.Option("dataname", show_default=True),
dump_dir: Path = Path("./data/asr_data"),
dump_fname: Path = Path("corrections.json"),
):
dump_path = dump_dir / Path(data_name) / dump_fname
col = get_mongo_conn(col="asr_validation")
task_id = [c for c in col.distinct("task_id") if c.rsplit("-", 1)[1] == task_uid][0]
cursor_obj = col.find({"type": "correction", "task_id": task_id}, projection={"_id": False})
corrections = list(cursor_obj)
ExtendedPath(dump_path).write_json(corrections)
@app.command()
def caller_quality(
task_uid: str,
data_name: str = typer.Option("dataname", show_default=True),
dump_dir: Path = Path("./data/asr_data"),
dump_fname: Path = Path("ui_dump.json"),
correction_fname: Path = Path("corrections.json"),
):
import copy
import pandas as pd
dump_path = dump_dir / Path(data_name) / dump_fname
correction_path = dump_dir / Path(data_name) / correction_fname
dump_data = ExtendedPath(dump_path).read_json()
dump_map = {d["utterance_id"]: d for d in dump_data["data"]}
correction_data = ExtendedPath(correction_path).read_json()
def correction_dp(c):
dp = copy.deepcopy(dump_map[c["code"]])
dp["valid"] = c["value"]["status"] == "Correct"
return dp
corrected_dump = [
correction_dp(c)
for c in correction_data
if c["task_id"].rsplit("-", 1)[1] == task_uid
]
df = pd.DataFrame(corrected_dump)
print(f"Total samples: {len(df)}")
for (c, g) in df.groupby("caller"):
total = len(g)
valid = len(g[g["valid"] == True])
valid_rate = valid * 100 / total
print(f"Caller: {c} Valid%:{valid_rate:.2f} of {total} samples")
@app.command()
def fill_unannotated(
data_name: str = typer.Option("dataname", show_default=True),
dump_dir: Path = Path("./data/valiation_data"),
dump_file: Path = Path("ui_dump.json"),
corrections_file: Path = Path("corrections.json"),
):
processed_data_path = dump_dir / Path(data_name) / dump_file
corrections_path = dump_dir / Path(data_name) / corrections_file
processed_data = json.load(processed_data_path.open())
corrections = json.load(corrections_path.open())
annotated_codes = {c["code"] for c in corrections}
all_codes = {c["gold_chars"] for c in processed_data}
unann_codes = all_codes - annotated_codes
mongo_conn = get_mongo_conn(col="asr_validation")
for c in unann_codes:
mongo_conn.find_one_and_update(
{"type": "correction", "code": c},
{"$set": {"value": {"status": "Inaudible", "correction": ""}}},
upsert=True,
)
@app.command()
def split_extract(
data_name: str = typer.Option("dataname", show_default=True),
# dest_data_name: str = typer.Option("call_aldata_namephanum_date", show_default=True),
# dump_dir: Path = Path("./data/valiation_data"),
dump_dir: Path = Path("./data/asr_data"),
dump_file: Path = Path("ui_dump.json"),
manifest_file: Path = Path("manifest.json"),
corrections_file: str = typer.Option("corrections.json", show_default=True),
conv_data_path: Path = typer.Option(Path("./data/conv_data.json"), show_default=True),
extraction_type: str = "all",
):
import shutil
data_manifest_path = dump_dir / Path(data_name) / manifest_file
conv_data = ExtendedPath(conv_data_path).read_json()
def extract_data_of_type(extraction_key):
extraction_vals = conv_data[extraction_key]
dest_data_name = data_name + "_" + extraction_key.lower()
manifest_gen = asr_manifest_reader(data_manifest_path)
dest_data_dir = dump_dir / Path(dest_data_name)
dest_data_dir.mkdir(exist_ok=True, parents=True)
(dest_data_dir / Path("wav")).mkdir(exist_ok=True, parents=True)
dest_manifest_path = dest_data_dir / manifest_file
dest_ui_path = dest_data_dir / dump_file
def extract_manifest(mg):
for m in mg:
if m["text"] in extraction_vals:
shutil.copy(m["audio_path"], dest_data_dir / Path(m["audio_filepath"]))
yield m
asr_manifest_writer(dest_manifest_path, extract_manifest(manifest_gen))
ui_data_path = dump_dir / Path(data_name) / dump_file
orig_ui_data = ExtendedPath(ui_data_path).read_json()
ui_data = orig_ui_data["data"]
file_ui_map = {Path(u["audio_filepath"]).stem: u for u in ui_data}
extracted_ui_data = list(filter(lambda u: u["text"] in extraction_vals, ui_data))
final_data = []
for i, d in enumerate(extracted_ui_data):
d['real_idx'] = i
final_data.append(d)
orig_ui_data['data'] = final_data
ExtendedPath(dest_ui_path).write_json(orig_ui_data)
if corrections_file:
dest_correction_path = dest_data_dir / corrections_file
corrections_path = dump_dir / Path(data_name) / corrections_file
corrections = json.load(corrections_path.open())
extracted_corrections = list(
filter(
lambda c: c["code"] in file_ui_map
and file_ui_map[c["code"]]["text"] in extraction_vals,
corrections,
)
)
ExtendedPath(dest_correction_path).write_json(extracted_corrections)
if extraction_type == "all":
for ext_key in conv_data.keys():
extract_data_of_type(ext_key)
else:
extract_data_of_type(extraction_type)
@app.command()
def update_corrections(
data_name: str = typer.Option("dataname", show_default=True),
dump_dir: Path = Path("./data/asr_data"),
manifest_file: Path = Path("manifest.json"),
corrections_file: Path = Path("corrections.json"),
ui_dump_file: Path = Path("ui_dump.json"),
skip_incorrect: bool = typer.Option(True, show_default=True),
):
data_manifest_path = dump_dir / Path(data_name) / manifest_file
corrections_path = dump_dir / Path(data_name) / corrections_file
ui_dump_path = dump_dir / Path(data_name) / ui_dump_file
def correct_manifest(ui_dump_path, corrections_path):
corrections = ExtendedPath(corrections_path).read_json()
ui_data = ExtendedPath(ui_dump_path).read_json()['data']
correct_set = {
c["code"] for c in corrections if c["value"]["status"] == "Correct"
}
# incorrect_set = {c["code"] for c in corrections if c["value"]["status"] == "Inaudible"}
correction_map = {
c["code"]: c["value"]["correction"]
for c in corrections
if c["value"]["status"] == "Incorrect"
}
# for d in manifest_data_gen:
# if d["chars"] in incorrect_set:
# d["audio_path"].unlink()
# renamed_set = set()
for d in ui_data:
if d["utterance_id"] in correct_set:
yield {
"audio_filepath": d["audio_filepath"],
"duration": d["duration"],
"text": d["text"],
}
elif d["utterance_id"] in correction_map:
correct_text = correction_map[d["utterance_id"]]
if skip_incorrect:
print(
f'skipping incorrect {d["audio_path"]} corrected to {correct_text}'
)
else:
orig_audio_path = Path(d["audio_path"])
new_name = str(Path(tscript_uuid_fname(correct_text)).with_suffix(".wav"))
new_audio_path = orig_audio_path.with_name(new_name)
orig_audio_path.replace(new_audio_path)
new_filepath = str(Path(d["audio_filepath"]).with_name(new_name))
yield {
"audio_filepath": new_filepath,
"duration": d["duration"],
"text": correct_text,
}
else:
orig_audio_path = Path(d["audio_path"])
# don't delete if another correction points to an old file
# if d["text"] not in renamed_set:
orig_audio_path.unlink()
# else:
# print(f'skipping deletion of correction:{d["text"]}')
typer.echo(f"Using data manifest:{data_manifest_path}")
dataset_dir = data_manifest_path.parent
dataset_name = dataset_dir.name
backup_dir = dataset_dir.with_name(dataset_name + ".bkp")
if not backup_dir.exists():
typer.echo(f"backing up to :{backup_dir}")
shutil.copytree(str(dataset_dir), str(backup_dir))
# manifest_gen = asr_manifest_reader(data_manifest_path)
corrected_manifest = correct_manifest(ui_dump_path, corrections_path)
new_data_manifest_path = data_manifest_path.with_name("manifest.new")
asr_manifest_writer(new_data_manifest_path, corrected_manifest)
new_data_manifest_path.replace(data_manifest_path)
@app.command()
def clear_mongo_corrections():
delete = typer.confirm("are you sure you want to clear the mongo collection?")
if delete:
col = get_mongo_conn(col="asr_validation")
col.delete_many({"type": "correction"})
col.delete_many({"type": "current_cursor"})
typer.echo("deleted mongo collection.")
return
typer.echo("Aborted")
def main():
app()
if __name__ == "__main__":
main()
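
All of the commands above read and write correction documents in the `asr_validation` Mongo collection; based on how this module and the validation UI query them, a document looks roughly like this (field values are illustrative):
```python
# Inferred shape of a correction document (see update_entry_fn in the UI and
# dump_corrections / update_corrections above); values are made up.
correction_doc = {
    "type": "correction",
    "code": "3f2b9c1e-..._hello",          # utterance_id, i.e. the wav file stem
    "task_id": "a1b2c3d4-...-deadbeef",    # UI session id; rsplit("-", 1)[1] is the task uid
    "value": {
        "status": "Incorrect",             # "Correct" | "Incorrect" | "Inaudible"
        "correction": "hello world",       # corrected transcript when status is "Incorrect"
    },
}
```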


@@ -0,0 +1,38 @@
import streamlit.ReportThread as ReportThread
from streamlit.ScriptRequestQueue import RerunData
from streamlit.ScriptRunner import RerunException
from streamlit.server.Server import Server
def rerun():
"""Rerun a Streamlit app from the top!"""
widget_states = _get_widget_states()
raise RerunException(RerunData(widget_states))
def _get_widget_states():
# Hack to get the session object from Streamlit.
ctx = ReportThread.get_report_ctx()
session = None
current_server = Server.get_current()
if hasattr(current_server, '_session_infos'):
# Streamlit < 0.56
session_infos = Server.get_current()._session_infos.values()
else:
session_infos = Server.get_current()._session_info_by_id.values()
for session_info in session_infos:
if session_info.session.enqueue == ctx.enqueue:
session = session_info.session
if session is None:
raise RuntimeError(
"Oh noes. Couldn't get your Streamlit Session object"
"Are you doing something fancy with threads?"
)
# Got the session object!
return session._widget_states


@@ -0,0 +1,158 @@
from pathlib import Path
import streamlit as st
import typer
from uuid import uuid4
from ..utils import ExtendedPath, get_mongo_conn
from .st_rerun import rerun
app = typer.Typer()
if not hasattr(st, "mongo_connected"):
st.mongoclient = get_mongo_conn(col="asr_validation")
mongo_conn = st.mongoclient
st.task_id = str(uuid4())
def current_cursor_fn():
# mongo_conn = st.mongoclient
cursor_obj = mongo_conn.find_one(
{"type": "current_cursor", "task_id": st.task_id}
)
cursor_val = cursor_obj["cursor"]
return cursor_val
def update_cursor_fn(val=0):
mongo_conn.find_one_and_update(
{"type": "current_cursor", "task_id": st.task_id},
{"$set": {"type": "current_cursor", "task_id": st.task_id, "cursor": val}},
upsert=True,
)
rerun()
def get_correction_entry_fn(code):
return mongo_conn.find_one(
{"type": "correction", "code": code}, projection={"_id": False}
)
def update_entry_fn(code, value):
mongo_conn.find_one_and_update(
{"type": "correction", "code": code},
{"$set": {"value": value, "task_id": st.task_id}},
upsert=True,
)
def set_task_fn(mf_path):
task_path = mf_path.parent / Path(f"task-{st.task_id}.lck")
if not task_path.exists():
print(f"creating task lock at {task_path}")
task_path.touch()
st.get_current_cursor = current_cursor_fn
st.update_cursor = update_cursor_fn
st.get_correction_entry = get_correction_entry_fn
st.update_entry = update_entry_fn
st.set_task = set_task_fn
st.mongo_connected = True
cursor_obj = mongo_conn.find_one({"type": "current_cursor", "task_id": st.task_id})
if not cursor_obj:
update_cursor_fn(0)
@st.cache()
def load_ui_data(validation_ui_data_path: Path):
typer.echo(f"Using validation ui data from {validation_ui_data_path}")
return ExtendedPath(validation_ui_data_path).read_json()
@app.command()
def main(manifest: Path):
st.set_task(manifest)
ui_config = load_ui_data(manifest)
asr_data = ui_config["data"]
use_domain_asr = ui_config.get("use_domain_asr", True)
annotation_only = ui_config.get("annotation_only", False)
enable_plots = ui_config.get("enable_plots", True)
sample_no = st.get_current_cursor()
if len(asr_data) - 1 < sample_no or sample_no < 0:
print("Invalid samplno resetting to 0")
st.update_cursor(0)
sample = asr_data[sample_no]
title_type = "Speller " if use_domain_asr else ""
task_uid = st.task_id.rsplit("-", 1)[1]
if annotation_only:
st.title(f"ASR Annotation - # {task_uid}")
else:
st.title(f"ASR {title_type}Validation - # {task_uid}")
addl_text = f"spelled *{sample['spoken']}*" if use_domain_asr else ""
st.markdown(f"{sample_no+1} of {len(asr_data)} : **{sample['text']}**" + addl_text)
new_sample = st.number_input(
"Go To Sample:", value=sample_no + 1, min_value=1, max_value=len(asr_data)
)
if new_sample != sample_no + 1:
st.update_cursor(new_sample - 1)
st.sidebar.title(f"Details: [{sample['real_idx']}]")
st.sidebar.markdown(f"Gold Text: **{sample['text']}**")
if not annotation_only:
if use_domain_asr:
st.sidebar.markdown(f"Expected Spelled: *{sample['spoken']}*")
st.sidebar.title("Results:")
st.sidebar.markdown(f"Pretrained: **{sample['pretrained_asr']}**")
if "caller" in sample:
st.sidebar.markdown(f"Caller: **{sample['caller']}**")
if use_domain_asr:
st.sidebar.markdown(f"Domain: **{sample['domain_asr']}**")
st.sidebar.title(f"Speller WER: {sample['domain_wer']:.2f}%")
else:
st.sidebar.title(f"Pretrained WER: {sample['pretrained_wer']:.2f}%")
if enable_plots:
st.sidebar.image(Path(sample["plot_path"]).read_bytes())
st.audio(Path(sample["audio_path"]).open("rb"))
# set default to text
corrected = sample["text"]
correction_entry = st.get_correction_entry(sample["utterance_id"])
selected_idx = 0
options = ("Correct", "Incorrect", "Inaudible")
# if correction entry is present set the corresponding ui defaults
if correction_entry:
selected_idx = options.index(correction_entry["value"]["status"])
corrected = correction_entry["value"]["correction"]
selected = st.radio("The Audio is", options, index=selected_idx)
if selected == "Incorrect":
corrected = st.text_input("Actual:", value=corrected)
if selected == "Inaudible":
corrected = ""
if st.button("Submit"):
st.update_entry(
sample["utterance_id"], {"status": selected, "correction": corrected}
)
st.update_cursor(sample_no + 1)
if correction_entry:
st.markdown(
f'Your Response: **{correction_entry["value"]["status"]}** Correction: **{correction_entry["value"]["correction"]}**'
)
text_sample = st.text_input("Go to Text:", value="")
if text_sample != "":
candidates = [
i
for (i, p) in enumerate(asr_data)
if p["text"] == text_sample or p["spoken"] == text_sample
]
if len(candidates) > 0:
st.update_cursor(candidates[0])
real_idx = st.number_input(
"Go to real-index",
value=sample["real_idx"],
min_value=0,
max_value=len(asr_data) - 1,
)
if real_idx != int(sample["real_idx"]):
idx = [i for (i, p) in enumerate(asr_data) if p["real_idx"] == real_idx][0]
st.update_cursor(idx)
if __name__ == "__main__":
try:
app()
except SystemExit:
pass

jasper/evaluate.py

@@ -0,0 +1,359 @@
# Copyright (c) 2019 NVIDIA Corporation
import argparse
import copy
# import math
import os
from pathlib import Path
from functools import partial
from ruamel.yaml import YAML
import nemo
import nemo.collections.asr as nemo_asr
import nemo.utils.argparse as nm_argparse
from nemo.collections.asr.helpers import (
# monitor_asr_train_progress,
process_evaluation_batch,
process_evaluation_epoch,
)
# from nemo.utils.lr_policies import CosineAnnealing
from jasper.training.data_loaders import RpycAudioToTextDataLayer
logging = nemo.logging
def parse_args():
parser = argparse.ArgumentParser(
parents=[nm_argparse.NemoArgParser()],
description="Jasper",
conflict_handler="resolve",
)
parser.set_defaults(
checkpoint_dir=None,
optimizer="novograd",
batch_size=64,
eval_batch_size=64,
lr=0.002,
amp_opt_level="O1",
create_tb_writer=True,
model_config="./train/jasper10x5dr.yaml",
work_dir="./train/work",
num_epochs=300,
weight_decay=0.005,
checkpoint_save_freq=100,
eval_freq=100,
load_dir="./train/models/jasper/",
warmup_steps=3,
exp_name="jasper-speller",
)
# Overwrite default args
parser.add_argument(
"--max_steps",
type=int,
default=None,
required=False,
help="max number of steps to train",
)
parser.add_argument(
"--num_epochs", type=int, required=False, help="number of epochs to train"
)
parser.add_argument(
"--model_config",
type=str,
required=False,
help="model configuration file: model.yaml",
)
parser.add_argument(
"--encoder_checkpoint",
type=str,
required=True,
help="encoder checkpoint file: JasperEncoder.pt",
)
parser.add_argument(
"--decoder_checkpoint",
type=str,
required=True,
help="decoder checkpoint file: JasperDecoderForCTC.pt",
)
parser.add_argument(
"--remote_data",
type=str,
required=False,
default="",
help="remote dataloader endpoint",
)
parser.add_argument(
"--dataset",
type=str,
required=False,
default="",
help="dataset directory containing train/test manifests",
)
# Create new args
parser.add_argument("--exp_name", default="Jasper", type=str)
parser.add_argument("--beta1", default=0.95, type=float)
parser.add_argument("--beta2", default=0.25, type=float)
parser.add_argument("--warmup_steps", default=0, type=int)
parser.add_argument(
"--load_dir",
default=None,
type=str,
help="directory with pre-trained checkpoint",
)
args = parser.parse_args()
if args.max_steps is None and args.num_epochs is None:
raise ValueError("Either max_steps or num_epochs should be provided.")
return args
def construct_name(
name, lr, batch_size, max_steps, num_epochs, wd, optimizer, iter_per_step
):
if max_steps is not None:
return "{0}-lr_{1}-bs_{2}-s_{3}-wd_{4}-opt_{5}-ips_{6}".format(
name, lr, batch_size, max_steps, wd, optimizer, iter_per_step
)
else:
return "{0}-lr_{1}-bs_{2}-e_{3}-wd_{4}-opt_{5}-ips_{6}".format(
name, lr, batch_size, num_epochs, wd, optimizer, iter_per_step
)
def create_all_dags(args, neural_factory):
yaml = YAML(typ="safe")
with open(args.model_config) as f:
jasper_params = yaml.load(f)
vocab = jasper_params["labels"]
sample_rate = jasper_params["sample_rate"]
# Calculate num_workers for dataloader
total_cpus = os.cpu_count()
cpu_per_traindl = max(int(total_cpus / neural_factory.world_size), 1)
# perturb_config = jasper_params.get('perturb', None)
train_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"])
train_dl_params.update(jasper_params["AudioToTextDataLayer"]["train"])
del train_dl_params["train"]
del train_dl_params["eval"]
# del train_dl_params["normalize_transcripts"]
if args.dataset:
d_path = Path(args.dataset)
if not args.train_dataset:
args.train_dataset = str(d_path / Path("train_manifest.json"))
if not args.eval_datasets:
args.eval_datasets = [str(d_path / Path("test_manifest.json"))]
data_loader_layer = nemo_asr.AudioToTextDataLayer
if args.remote_data:
train_dl_params["rpyc_host"] = args.remote_data
data_loader_layer = RpycAudioToTextDataLayer
# data_layer = data_loader_layer(
# manifest_filepath=args.train_dataset,
# sample_rate=sample_rate,
# labels=vocab,
# batch_size=args.batch_size,
# num_workers=cpu_per_traindl,
# **train_dl_params,
# # normalize_transcripts=False
# )
#
# N = len(data_layer)
# steps_per_epoch = math.ceil(
# N / (args.batch_size * args.iter_per_step * args.num_gpus)
# )
# logging.info("Have {0} examples to train on.".format(N))
#
data_preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor(
sample_rate=sample_rate, **jasper_params["AudioToMelSpectrogramPreprocessor"]
)
# multiply_batch_config = jasper_params.get("MultiplyBatch", None)
# if multiply_batch_config:
# multiply_batch = nemo_asr.MultiplyBatch(**multiply_batch_config)
#
# spectr_augment_config = jasper_params.get("SpectrogramAugmentation", None)
# if spectr_augment_config:
# data_spectr_augmentation = nemo_asr.SpectrogramAugmentation(
# **spectr_augment_config
# )
#
eval_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"])
eval_dl_params.update(jasper_params["AudioToTextDataLayer"]["eval"])
if args.remote_data:
eval_dl_params["rpyc_host"] = args.remote_data
del eval_dl_params["train"]
del eval_dl_params["eval"]
data_layers_eval = []
# if args.eval_datasets:
for eval_datasets in args.eval_datasets:
data_layer_eval = data_loader_layer(
manifest_filepath=eval_datasets,
sample_rate=sample_rate,
labels=vocab,
batch_size=args.eval_batch_size,
num_workers=cpu_per_traindl,
**eval_dl_params,
)
data_layers_eval.append(data_layer_eval)
# else:
# logging.warning("There were no val datasets passed")
jasper_encoder = nemo_asr.JasperEncoder(
feat_in=jasper_params["AudioToMelSpectrogramPreprocessor"]["features"],
**jasper_params["JasperEncoder"],
)
jasper_encoder.restore_from(args.encoder_checkpoint, local_rank=0)
jasper_decoder = nemo_asr.JasperDecoderForCTC(
feat_in=jasper_params["JasperEncoder"]["jasper"][-1]["filters"],
num_classes=len(vocab),
)
jasper_decoder.restore_from(args.decoder_checkpoint, local_rank=0)
ctc_loss = nemo_asr.CTCLossNM(num_classes=len(vocab))
greedy_decoder = nemo_asr.GreedyCTCDecoder()
# logging.info("================================")
# logging.info(f"Number of parameters in encoder: {jasper_encoder.num_weights}")
# logging.info(f"Number of parameters in decoder: {jasper_decoder.num_weights}")
# logging.info(
# f"Total number of parameters in model: "
# f"{jasper_decoder.num_weights + jasper_encoder.num_weights}"
# )
# logging.info("================================")
#
# # Train DAG
# (audio_signal_t, a_sig_length_t, transcript_t, transcript_len_t) = data_layer()
# processed_signal_t, p_length_t = data_preprocessor(
# input_signal=audio_signal_t, length=a_sig_length_t
# )
#
# if multiply_batch_config:
# (
# processed_signal_t,
# p_length_t,
# transcript_t,
# transcript_len_t,
# ) = multiply_batch(
# in_x=processed_signal_t,
# in_x_len=p_length_t,
# in_y=transcript_t,
# in_y_len=transcript_len_t,
# )
#
# if spectr_augment_config:
# processed_signal_t = data_spectr_augmentation(input_spec=processed_signal_t)
#
# encoded_t, encoded_len_t = jasper_encoder(
# audio_signal=processed_signal_t, length=p_length_t
# )
# log_probs_t = jasper_decoder(encoder_output=encoded_t)
# predictions_t = greedy_decoder(log_probs=log_probs_t)
# loss_t = ctc_loss(
# log_probs=log_probs_t,
# targets=transcript_t,
# input_length=encoded_len_t,
# target_length=transcript_len_t,
# )
#
# # Callbacks needed to print info to console and Tensorboard
# train_callback = nemo.core.SimpleLossLoggerCallback(
# tensors=[loss_t, predictions_t, transcript_t, transcript_len_t],
# print_func=partial(monitor_asr_train_progress, labels=vocab),
# get_tb_values=lambda x: [("loss", x[0])],
# tb_writer=neural_factory.tb_writer,
# )
#
# chpt_callback = nemo.core.CheckpointCallback(
# folder=neural_factory.checkpoint_dir,
# load_from_folder=args.load_dir,
# step_freq=args.checkpoint_save_freq,
# checkpoints_to_keep=30,
# )
#
# callbacks = [train_callback, chpt_callback]
callbacks = []
# assemble eval DAGs
for i, eval_dl in enumerate(data_layers_eval):
(audio_signal_e, a_sig_length_e, transcript_e, transcript_len_e) = eval_dl()
processed_signal_e, p_length_e = data_preprocessor(
input_signal=audio_signal_e, length=a_sig_length_e
)
encoded_e, encoded_len_e = jasper_encoder(
audio_signal=processed_signal_e, length=p_length_e
)
log_probs_e = jasper_decoder(encoder_output=encoded_e)
predictions_e = greedy_decoder(log_probs=log_probs_e)
loss_e = ctc_loss(
log_probs=log_probs_e,
targets=transcript_e,
input_length=encoded_len_e,
target_length=transcript_len_e,
)
# create corresponding eval callback
tagname = os.path.basename(args.eval_datasets[i]).split(".")[0]
eval_callback = nemo.core.EvaluatorCallback(
eval_tensors=[loss_e, predictions_e, transcript_e, transcript_len_e],
user_iter_callback=partial(process_evaluation_batch, labels=vocab),
user_epochs_done_callback=partial(process_evaluation_epoch, tag=tagname),
eval_step=args.eval_freq,
tb_writer=neural_factory.tb_writer,
)
callbacks.append(eval_callback)
return callbacks
def main():
args = parse_args()
# name = construct_name(
# args.exp_name,
# args.lr,
# args.batch_size,
# args.max_steps,
# args.num_epochs,
# args.weight_decay,
# args.optimizer,
# args.iter_per_step,
# )
# log_dir = name
# if args.work_dir:
# log_dir = os.path.join(args.work_dir, name)
# instantiate Neural Factory with supported backend
neural_factory = nemo.core.NeuralModuleFactory(
placement=nemo.core.DeviceType.GPU,
backend=nemo.core.Backend.PyTorch,
# local_rank=args.local_rank,
# optimization_level=args.amp_opt_level,
# log_dir=log_dir,
# checkpoint_dir=args.checkpoint_dir,
# create_tb_writer=args.create_tb_writer,
# files_to_copy=[args.model_config, __file__],
# cudnn_benchmark=args.cudnn_benchmark,
# tensorboard_dir=args.tensorboard_dir,
)
args.num_gpus = neural_factory.world_size
# checkpoint_dir = neural_factory.checkpoint_dir
if args.local_rank is not None:
logging.info("Doing ALL GPU")
# build dags
callbacks = create_all_dags(args, neural_factory)
# evaluate model
neural_factory.eval(callbacks=callbacks)
if __name__ == "__main__":
main()


@ -0,0 +1 @@

jasper/training/cli.py

@@ -0,0 +1,366 @@
# Copyright (c) 2019 NVIDIA Corporation
import argparse
import copy
import math
import os
from pathlib import Path
from functools import partial
from ruamel.yaml import YAML
import nemo
import nemo.collections.asr as nemo_asr
import nemo.utils.argparse as nm_argparse
from nemo.collections.asr.helpers import (
monitor_asr_train_progress,
process_evaluation_batch,
process_evaluation_epoch,
)
from nemo.utils.lr_policies import CosineAnnealing
from .data_loaders import RpycAudioToTextDataLayer
logging = nemo.logging
def parse_args():
parser = argparse.ArgumentParser(
parents=[nm_argparse.NemoArgParser()],
description="Jasper",
conflict_handler="resolve",
)
parser.set_defaults(
checkpoint_dir=None,
optimizer="novograd",
batch_size=64,
eval_batch_size=64,
lr=0.002,
amp_opt_level="O1",
create_tb_writer=True,
model_config="./train/jasper10x5dr.yaml",
work_dir="./train/work",
num_epochs=300,
weight_decay=0.005,
checkpoint_save_freq=100,
eval_freq=100,
load_dir="./train/models/jasper/",
warmup_steps=3,
exp_name="jasper-speller",
)
# Overwrite default args
parser.add_argument(
"--max_steps",
type=int,
default=None,
required=False,
help="max number of steps to train",
)
parser.add_argument(
"--num_epochs",
type=int,
required=False,
help="number of epochs to train",
)
parser.add_argument(
"--model_config",
type=str,
required=False,
help="model configuration file: model.yaml",
)
parser.add_argument(
"--remote_data",
type=str,
required=False,
default="",
help="remote dataloader endpoint",
)
parser.add_argument(
"--dataset",
type=str,
required=False,
default="",
help="dataset directory containing train/test manifests",
)
# Create new args
parser.add_argument("--exp_name", default="Jasper", type=str)
parser.add_argument("--beta1", default=0.95, type=float)
parser.add_argument("--beta2", default=0.25, type=float)
parser.add_argument("--warmup_steps", default=0, type=int)
parser.add_argument(
"--load_dir",
default=None,
type=str,
help="directory with pre-trained checkpoint",
)
args = parser.parse_args()
if args.max_steps is None and args.num_epochs is None:
raise ValueError("Either max_steps or num_epochs should be provided.")
return args
def construct_name(
name, lr, batch_size, max_steps, num_epochs, wd, optimizer, iter_per_step
):
if max_steps is not None:
return "{0}-lr_{1}-bs_{2}-s_{3}-wd_{4}-opt_{5}-ips_{6}".format(
name, lr, batch_size, max_steps, wd, optimizer, iter_per_step
)
else:
return "{0}-lr_{1}-bs_{2}-e_{3}-wd_{4}-opt_{5}-ips_{6}".format(
name, lr, batch_size, num_epochs, wd, optimizer, iter_per_step
)
def create_all_dags(args, neural_factory):
yaml = YAML(typ="safe")
with open(args.model_config) as f:
jasper_params = yaml.load(f)
vocab = jasper_params["labels"]
sample_rate = jasper_params["sample_rate"]
# Calculate num_workers for dataloader
total_cpus = os.cpu_count()
cpu_per_traindl = max(int(total_cpus / neural_factory.world_size), 1)
# perturb_config = jasper_params.get('perturb', None)
train_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"])
train_dl_params.update(jasper_params["AudioToTextDataLayer"]["train"])
del train_dl_params["train"]
del train_dl_params["eval"]
# del train_dl_params["normalize_transcripts"]
if args.dataset:
d_path = Path(args.dataset)
if not args.train_dataset:
args.train_dataset = str(d_path / Path("train_manifest.json"))
if not args.eval_datasets:
args.eval_datasets = [str(d_path / Path("test_manifest.json"))]
data_loader_layer = nemo_asr.AudioToTextDataLayer
if args.remote_data:
train_dl_params["rpyc_host"] = args.remote_data
data_loader_layer = RpycAudioToTextDataLayer
data_layer = data_loader_layer(
manifest_filepath=args.train_dataset,
sample_rate=sample_rate,
labels=vocab,
batch_size=args.batch_size,
num_workers=cpu_per_traindl,
**train_dl_params,
# normalize_transcripts=False
)
N = len(data_layer)
steps_per_epoch = math.ceil(
N / (args.batch_size * args.iter_per_step * args.num_gpus)
)
logging.info("Have {0} examples to train on.".format(N))
data_preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor(
sample_rate=sample_rate, **jasper_params["AudioToMelSpectrogramPreprocessor"]
)
multiply_batch_config = jasper_params.get("MultiplyBatch", None)
if multiply_batch_config:
multiply_batch = nemo_asr.MultiplyBatch(**multiply_batch_config)
spectr_augment_config = jasper_params.get("SpectrogramAugmentation", None)
if spectr_augment_config:
data_spectr_augmentation = nemo_asr.SpectrogramAugmentation(
**spectr_augment_config
)
eval_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"])
eval_dl_params.update(jasper_params["AudioToTextDataLayer"]["eval"])
if args.remote_data:
eval_dl_params["rpyc_host"] = args.remote_data
del eval_dl_params["train"]
del eval_dl_params["eval"]
data_layers_eval = []
if args.eval_datasets:
for eval_datasets in args.eval_datasets:
data_layer_eval = data_loader_layer(
manifest_filepath=eval_datasets,
sample_rate=sample_rate,
labels=vocab,
batch_size=args.eval_batch_size,
num_workers=cpu_per_traindl,
**eval_dl_params,
)
data_layers_eval.append(data_layer_eval)
else:
logging.warning("There were no val datasets passed")
jasper_encoder = nemo_asr.JasperEncoder(
feat_in=jasper_params["AudioToMelSpectrogramPreprocessor"]["features"],
**jasper_params["JasperEncoder"],
)
jasper_decoder = nemo_asr.JasperDecoderForCTC(
feat_in=jasper_params["JasperEncoder"]["jasper"][-1]["filters"],
num_classes=len(vocab),
)
ctc_loss = nemo_asr.CTCLossNM(num_classes=len(vocab))
greedy_decoder = nemo_asr.GreedyCTCDecoder()
logging.info("================================")
logging.info(f"Number of parameters in encoder: {jasper_encoder.num_weights}")
logging.info(f"Number of parameters in decoder: {jasper_decoder.num_weights}")
logging.info(
f"Total number of parameters in model: "
f"{jasper_decoder.num_weights + jasper_encoder.num_weights}"
)
logging.info("================================")
# Train DAG
(audio_signal_t, a_sig_length_t, transcript_t, transcript_len_t) = data_layer()
processed_signal_t, p_length_t = data_preprocessor(
input_signal=audio_signal_t, length=a_sig_length_t
)
if multiply_batch_config:
(
processed_signal_t,
p_length_t,
transcript_t,
transcript_len_t,
) = multiply_batch(
in_x=processed_signal_t,
in_x_len=p_length_t,
in_y=transcript_t,
in_y_len=transcript_len_t,
)
if spectr_augment_config:
processed_signal_t = data_spectr_augmentation(input_spec=processed_signal_t)
encoded_t, encoded_len_t = jasper_encoder(
audio_signal=processed_signal_t, length=p_length_t
)
log_probs_t = jasper_decoder(encoder_output=encoded_t)
predictions_t = greedy_decoder(log_probs=log_probs_t)
loss_t = ctc_loss(
log_probs=log_probs_t,
targets=transcript_t,
input_length=encoded_len_t,
target_length=transcript_len_t,
)
# Callbacks needed to print info to console and Tensorboard
train_callback = nemo.core.SimpleLossLoggerCallback(
tensors=[loss_t, predictions_t, transcript_t, transcript_len_t],
print_func=partial(monitor_asr_train_progress, labels=vocab),
get_tb_values=lambda x: [("loss", x[0])],
tb_writer=neural_factory.tb_writer,
)
chpt_callback = nemo.core.CheckpointCallback(
folder=neural_factory.checkpoint_dir,
load_from_folder=args.load_dir,
step_freq=args.checkpoint_save_freq,
checkpoints_to_keep=30,
)
callbacks = [train_callback, chpt_callback]
# assemble eval DAGs
for i, eval_dl in enumerate(data_layers_eval):
(audio_signal_e, a_sig_length_e, transcript_e, transcript_len_e) = eval_dl()
processed_signal_e, p_length_e = data_preprocessor(
input_signal=audio_signal_e, length=a_sig_length_e
)
encoded_e, encoded_len_e = jasper_encoder(
audio_signal=processed_signal_e, length=p_length_e
)
log_probs_e = jasper_decoder(encoder_output=encoded_e)
predictions_e = greedy_decoder(log_probs=log_probs_e)
loss_e = ctc_loss(
log_probs=log_probs_e,
targets=transcript_e,
input_length=encoded_len_e,
target_length=transcript_len_e,
)
# create corresponding eval callback
tagname = os.path.basename(args.eval_datasets[i]).split(".")[0]
eval_callback = nemo.core.EvaluatorCallback(
eval_tensors=[loss_e, predictions_e, transcript_e, transcript_len_e],
user_iter_callback=partial(process_evaluation_batch, labels=vocab),
user_epochs_done_callback=partial(process_evaluation_epoch, tag=tagname),
eval_step=args.eval_freq,
tb_writer=neural_factory.tb_writer,
)
callbacks.append(eval_callback)
return loss_t, callbacks, steps_per_epoch
def main():
args = parse_args()
name = construct_name(
args.exp_name,
args.lr,
args.batch_size,
args.max_steps,
args.num_epochs,
args.weight_decay,
args.optimizer,
args.iter_per_step,
)
log_dir = name
if args.work_dir:
log_dir = os.path.join(args.work_dir, name)
# instantiate Neural Factory with supported backend
neural_factory = nemo.core.NeuralModuleFactory(
backend=nemo.core.Backend.PyTorch,
local_rank=args.local_rank,
optimization_level=args.amp_opt_level,
log_dir=log_dir,
checkpoint_dir=args.checkpoint_dir,
create_tb_writer=args.create_tb_writer,
files_to_copy=[args.model_config, __file__],
cudnn_benchmark=args.cudnn_benchmark,
tensorboard_dir=args.tensorboard_dir,
)
args.num_gpus = neural_factory.world_size
checkpoint_dir = neural_factory.checkpoint_dir
if args.local_rank is not None:
logging.info("Doing ALL GPU")
# build dags
train_loss, callbacks, steps_per_epoch = create_all_dags(args, neural_factory)
# train model
neural_factory.train(
tensors_to_optimize=[train_loss],
callbacks=callbacks,
lr_policy=CosineAnnealing(
args.max_steps
if args.max_steps is not None
else args.num_epochs * steps_per_epoch,
warmup_steps=args.warmup_steps,
),
optimizer=args.optimizer,
optimization_params={
"num_epochs": args.num_epochs,
"max_steps": args.max_steps,
"lr": args.lr,
"betas": (args.beta1, args.beta2),
"weight_decay": args.weight_decay,
"grad_norm_clip": None,
},
batches_per_step=args.iter_per_step,
)
if __name__ == "__main__":
main()
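
A hypothetical invocation of this training CLI with the new --dataset and --remote_data options (the module path, dataset directory and host are assumptions):
```python
# Sketch only: launch training against a dataset directory and a remote RPyC data service.
import subprocess

subprocess.run(
    [
        "python", "-m", "jasper.training.cli",
        "--model_config", "./train/jasper10x5dr.yaml",
        "--dataset", "data/asr_data/merged",     # contains train_manifest.json / test_manifest.json
        "--remote_data", "dataserver.local",     # host running jasper.data.server
        "--num_epochs", "100",
    ],
    check=True,
)
```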


@@ -0,0 +1,334 @@
from functools import partial
import tempfile
# from typing import Any, Dict, List, Optional
import torch
import nemo
# import nemo.collections.asr as nemo_asr
from nemo.backends.pytorch import DataLayerNM
from nemo.core import DeviceType
# from nemo.core.neural_types import *
from nemo.core.neural_types import NeuralType, AudioSignal, LengthsType, LabelsType
from nemo.utils.decorators import add_port_docs
from nemo.collections.asr.parts.dataset import (
# AudioDataset,
# AudioLabelDataset,
# KaldiFeatureDataset,
# TranscriptDataset,
parsers,
collections,
seq_collate_fn,
)
# from functools import lru_cache
import rpyc
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
from .featurizer import RpycWaveformFeaturizer
# from nemo.collections.asr.parts.features import WaveformFeaturizer
# from nemo.collections.asr.parts.perturb import AudioAugmentor, perturbation_types
logging = nemo.logging
class CachedAudioDataset(torch.utils.data.Dataset):
"""
    Dataset that loads tensors via a json file containing paths to audio
    files, transcripts, and durations (in seconds). Each new line is a
    different sample. When load_audio is enabled, features for every sample
    are computed once at construction time and cached in memory, so later
    __getitem__ calls are served from the cache. Example below:
{"audio_filepath": "/path/to/audio.wav", "text_filepath":
"/path/to/audio.txt", "duration": 23.147}
...
{"audio_filepath": "/path/to/audio.wav", "text": "the
transcription", offset": 301.75, "duration": 0.82, "utt":
"utterance_id", "ctm_utt": "en_4156", "side": "A"}
Args:
manifest_filepath: Path to manifest json as described above. Can
be comma-separated paths.
labels: String containing all the possible characters to map to
featurizer: Initialized featurizer class that converts paths of
audio to feature tensors
max_duration: If audio exceeds this length, do not include in dataset
min_duration: If audio is less than this length, do not include
in dataset
max_utts: Limit number of utterances
blank_index: blank character index, default = -1
unk_index: unk_character index, default = -1
        normalize: whether to normalize transcript text. Defaults to True.
bos_id: Id of beginning of sequence symbol to append if not None
eos_id: Id of end of sequence symbol to append if not None
        load_audio: Boolean flag indicating whether or not to load audio
"""
def __init__(
self,
manifest_filepath,
labels,
featurizer,
max_duration=None,
min_duration=None,
max_utts=0,
blank_index=-1,
unk_index=-1,
normalize=True,
trim=False,
bos_id=None,
eos_id=None,
load_audio=True,
parser="en",
):
self.collection = collections.ASRAudioText(
manifests_files=manifest_filepath.split(","),
parser=parsers.make_parser(
labels=labels,
name=parser,
unk_id=unk_index,
blank_id=blank_index,
do_normalize=normalize,
),
min_duration=min_duration,
max_duration=max_duration,
max_number=max_utts,
)
self.index_feature_map = {}
self.featurizer = featurizer
self.trim = trim
self.eos_id = eos_id
self.bos_id = bos_id
self.load_audio = load_audio
print(f"initializing dataset {manifest_filepath}")
def exec_func(i):
return self[i]
task_count = len(self.collection)
with ThreadPoolExecutor() as exe:
print("starting all loading tasks")
list(
tqdm(
exe.map(exec_func, range(task_count)),
position=0,
leave=True,
total=task_count,
)
)
print(f"initializing complete")
def __getitem__(self, index):
sample = self.collection[index]
if self.load_audio:
cached_features = self.index_feature_map.get(index)
if cached_features is not None:
features = cached_features
else:
features = self.featurizer.process(
sample.audio_file,
offset=0,
duration=sample.duration,
trim=self.trim,
)
self.index_feature_map[index] = features
f, fl = features, torch.tensor(features.shape[0]).long()
else:
f, fl = None, None
t, tl = sample.text_tokens, len(sample.text_tokens)
if self.bos_id is not None:
t = [self.bos_id] + t
tl += 1
if self.eos_id is not None:
t = t + [self.eos_id]
tl += 1
return f, fl, torch.tensor(t).long(), torch.tensor(tl).long()
def __len__(self):
return len(self.collection)
class RpycAudioToTextDataLayer(DataLayerNM):
"""Data Layer for general ASR tasks.
Module which reads ASR labeled data. It accepts comma-separated
JSON manifest files describing the correspondence between wav audio files
and their transcripts. JSON files should be of the following format::
{"audio_filepath": path_to_wav_0, "duration": time_in_sec_0, "text": \
transcript_0}
...
{"audio_filepath": path_to_wav_n, "duration": time_in_sec_n, "text": \
transcript_n}
Args:
manifest_filepath (str): Dataset parameter.
Path to JSON containing data.
labels (list): Dataset parameter.
List of characters that can be output by the ASR model.
For Jasper, this is the 28 character set {a-z '}. The CTC blank
symbol is automatically added later for models using ctc.
batch_size (int): batch size
sample_rate (int): Target sampling rate for data. Audio files will be
resampled to sample_rate if it is not already.
Defaults to 16000.
int_values (bool): Bool indicating whether the audio file is saved as
int data or float data.
Defaults to False.
eos_id (id): Dataset parameter.
End of string symbol id used for seq2seq models.
Defaults to None.
min_duration (float): Dataset parameter.
All training files which have a duration less than min_duration
are dropped. Note: Duration is read from the manifest JSON.
Defaults to 0.1.
max_duration (float): Dataset parameter.
All training files which have a duration more than max_duration
are dropped. Note: Duration is read from the manifest JSON.
Defaults to None.
normalize_transcripts (bool): Dataset parameter.
Whether to use automatic text cleaning.
It is highly recommended to manually clean text for best results.
Defaults to True.
        trim_silence (bool): Whether to trim silence from the beginning and end
            of the audio signal using librosa.effects.trim().
Defaults to False.
load_audio (bool): Dataset parameter.
Controls whether the dataloader loads the audio signal and
transcript or just the transcript.
Defaults to True.
drop_last (bool): See PyTorch DataLoader.
Defaults to False.
shuffle (bool): See PyTorch DataLoader.
Defaults to True.
num_workers (int): See PyTorch DataLoader.
Defaults to 0.
perturb_config (dict): Currently disabled.
"""
@property
@add_port_docs()
def output_ports(self):
"""Returns definitions of module output ports.
"""
return {
# 'audio_signal': NeuralType({0: AxisType(BatchTag), 1: AxisType(TimeTag)}),
# 'a_sig_length': NeuralType({0: AxisType(BatchTag)}),
# 'transcripts': NeuralType({0: AxisType(BatchTag), 1: AxisType(TimeTag)}),
# 'transcript_length': NeuralType({0: AxisType(BatchTag)}),
"audio_signal": NeuralType(
("B", "T"),
AudioSignal(freq=self._sample_rate)
                if self._sample_rate is not None
else AudioSignal(),
),
"a_sig_length": NeuralType(tuple("B"), LengthsType()),
"transcripts": NeuralType(("B", "T"), LabelsType()),
"transcript_length": NeuralType(tuple("B"), LengthsType()),
}
def __init__(
self,
manifest_filepath,
labels,
batch_size,
sample_rate=16000,
int_values=False,
bos_id=None,
eos_id=None,
pad_id=None,
min_duration=0.1,
max_duration=None,
normalize_transcripts=True,
trim_silence=False,
load_audio=True,
rpyc_host="",
drop_last=False,
shuffle=True,
num_workers=0,
):
super().__init__()
self._sample_rate = sample_rate
def rpyc_root_fn():
return rpyc.connect(
rpyc_host, 8064, config={"sync_request_timeout": 600}
).root
rpyc_conn = rpyc_root_fn()
self._featurizer = RpycWaveformFeaturizer(
sample_rate=self._sample_rate,
int_values=int_values,
augmentor=None,
rpyc_conn=rpyc_conn,
)
def read_remote_manifests():
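            # Fetch each remote manifest over rpyc and stage it in a local temp file
            # so the standard ASRAudioText collection can parse it from disk.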
local_mp = []
for mrp in manifest_filepath.split(","):
md = rpyc_conn.read_path(mrp)
mf = tempfile.NamedTemporaryFile(
dir="/tmp", prefix="jasper_manifest.", delete=False
)
mf.write(md)
mf.close()
local_mp.append(mf.name)
return ",".join(local_mp)
local_manifest_filepath = read_remote_manifests()
dataset_params = {
"manifest_filepath": local_manifest_filepath,
"labels": labels,
"featurizer": self._featurizer,
"max_duration": max_duration,
"min_duration": min_duration,
"normalize": normalize_transcripts,
"trim": trim_silence,
"bos_id": bos_id,
"eos_id": eos_id,
"load_audio": load_audio,
}
self._dataset = CachedAudioDataset(**dataset_params)
self._batch_size = batch_size
# Set up data loader
if self._placement == DeviceType.AllGpu:
logging.info("Parallelizing Datalayer.")
sampler = torch.utils.data.distributed.DistributedSampler(self._dataset)
else:
sampler = None
if batch_size == -1:
batch_size = len(self._dataset)
pad_id = 0 if pad_id is None else pad_id
self._dataloader = torch.utils.data.DataLoader(
dataset=self._dataset,
batch_size=batch_size,
collate_fn=partial(seq_collate_fn, token_pad_value=pad_id),
drop_last=drop_last,
shuffle=shuffle if sampler is None else False,
sampler=sampler,
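            # NOTE: the num_workers argument is accepted but not forwarded; the loader is pinned to a single worker here.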
num_workers=1,
)
def __len__(self):
return len(self._dataset)
@property
def dataset(self):
return None
@property
def data_iterator(self):
return self._dataloader

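A rough usage sketch of the remote data layer above, mirroring how the training script pulls tensors out of its data layers. The import path, host, manifest path, and vocabulary are placeholders assumed for illustration, and a NeuralModuleFactory is expected to exist first (as in main() above):

import nemo
# Import path is an assumption; point it at wherever RpycAudioToTextDataLayer lives in this package.
from jasper.training.dataloader import RpycAudioToTextDataLayer

nemo.core.NeuralModuleFactory(backend=nemo.core.Backend.PyTorch)

vocab = list("abcdefghijklmnopqrstuvwxyz '")  # placeholder label set
train_dl = RpycAudioToTextDataLayer(
    manifest_filepath="/data/manifests/train.json",  # placeholder remote manifest path
    labels=vocab,
    batch_size=32,
    sample_rate=16000,
    rpyc_host="dataservice.example.com",  # placeholder data-service host
)
# As in create_all_dags above, calling the layer yields its four output-port tensors.
audio_signal, a_sig_length, transcripts, transcript_length = train_dl()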

@ -0,0 +1,51 @@
# import math
# import librosa
import torch
import pickle
# import torch.nn as nn
# from torch_stft import STFT
# from nemo import logging
from nemo.collections.asr.parts.perturb import AudioAugmentor
# from nemo.collections.asr.parts.segment import AudioSegment
class RpycWaveformFeaturizer(object):
def __init__(
self, sample_rate=16000, int_values=False, augmentor=None, rpyc_conn=None
):
self.augmentor = augmentor if augmentor is not None else AudioAugmentor()
self.sample_rate = sample_rate
self.int_values = int_values
self.remote_path_samples = rpyc_conn.get_path_samples
def max_augmentation_length(self, length):
return self.augmentor.max_augmentation_length(length)
def process(self, file_path, offset=0, duration=0, trim=False):
audio = self.remote_path_samples(
file_path,
target_sr=self.sample_rate,
int_values=self.int_values,
offset=offset,
duration=duration,
trim=trim,
)
return torch.tensor(pickle.loads(audio), dtype=torch.float)
def process_segment(self, audio_segment):
self.augmentor.perturb(audio_segment)
return torch.tensor(audio_segment, dtype=torch.float)
@classmethod
def from_config(cls, input_config, perturbation_configs=None):
if perturbation_configs is not None:
aa = AudioAugmentor.from_config(perturbation_configs)
else:
aa = None
sample_rate = input_config.get("sample_rate", 16000)
int_values = input_config.get("int_values", False)
return cls(sample_rate=sample_rate, int_values=int_values, augmentor=aa)

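For context, a minimal sketch of the rpyc data service the two modules above talk to. This is an assumption-laden illustration rather than the repository's jasper.data.server implementation: it shows one plausible way to expose the read_path and get_path_samples calls used above (librosa-based loading, pickled float samples, port 8064).

import pickle

import librosa
import rpyc
from rpyc.utils.server import ThreadedServer


class AudioDataService(rpyc.Service):
    """Hypothetical service matching the remote calls made by the modules above."""

    def exposed_read_path(self, path):
        # Return raw manifest bytes so the client can stage them in a local temp file.
        with open(path, "rb") as f:
            return f.read()

    def exposed_get_path_samples(
        self, path, target_sr=16000, int_values=False, offset=0, duration=0, trim=False
    ):
        # int_values is accepted for signature compatibility but ignored in this sketch.
        samples, _ = librosa.load(
            path, sr=target_sr, offset=offset, duration=duration or None
        )
        if trim:
            samples, _ = librosa.effects.trim(samples)
        # Pickle the float samples; the client unpickles them into a torch tensor.
        return pickle.dumps(samples)


if __name__ == "__main__":
    ThreadedServer(AudioDataService, port=8064).start()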

@ -1,11 +1,52 @@
from setuptools import setup
from setuptools import setup, find_packages
requirements = [
"ruamel.yaml",
"torch==1.4.0",
"torchvision==0.5.0",
"nemo_toolkit[asr] @ git+https://github.com/NVIDIA/NeMo.git@09e3ba4dfe333f86d6c5c1048e07210924294be9#egg=nemo_toolkit",
]
extra_requirements = {"server": ["rpyc==4.1.4"]}
extra_requirements = {
"server": ["rpyc~=4.1.4", "tqdm~=4.39.0"],
"data": [
"google-cloud-texttospeech~=1.0.1",
"tqdm~=4.39.0",
"pydub~=0.24.0",
"scikit_learn~=0.22.1",
"pandas~=1.0.3",
"boto3~=1.12.35",
"ruamel.yaml==0.16.10",
"pymongo==3.10.1",
"librosa==0.7.2",
"matplotlib==3.2.1",
"pandas==1.0.3",
"tabulate==0.8.7",
"natural==0.2.0",
"num2words==0.5.10",
"typer[all]==0.1.1",
"python-slugify==4.0.0",
"lenses @ git+https://github.com/ingolemo/python-lenses.git@b2a2a9aa5b61540992d70b2cf36008d0121e8948#egg=lenses",
],
"validation": [
"rpyc~=4.1.4",
"pymongo==3.10.1",
"typer[all]==0.1.1",
"tqdm~=4.39.0",
"librosa==0.7.2",
"matplotlib==3.2.1",
"pydub~=0.24.0",
"streamlit==0.58.0",
"natural==0.2.0",
"stringcase==1.2.0",
"google-cloud-speech~=1.3.1",
]
# "train": [
# "torchaudio==0.5.0",
# "torch-stft==0.1.4",
# ]
}
packages = find_packages()
setup(
name="jasper-asr",
@ -17,11 +58,24 @@ setup(
license="MIT",
install_requires=requirements,
extras_require=extra_requirements,
packages=["."],
packages=packages,
entry_points={
"console_scripts": [
"jasper_transcribe = jasper.transcribe:main",
"jasper_asr_rpyc_server = jasper.server:main",
"jasper_server = jasper.server:main",
"jasper_trainer = jasper.training.cli:main",
"jasper_evaluator = jasper.evaluate:main",
"jasper_data_tts_generate = jasper.data.tts_generator:main",
"jasper_data_conv_generate = jasper.data.conv_generator:main",
"jasper_data_nlu_generate = jasper.data.nlu_generator:main",
"jasper_data_test_generate = jasper.data.test_generator:main",
"jasper_data_call_recycle = jasper.data.call_recycler:main",
"jasper_data_asr_recycle = jasper.data.asr_recycler:main",
"jasper_data_rev_recycle = jasper.data.rev_recycler:main",
"jasper_data_server = jasper.data.server:main",
"jasper_data_validation = jasper.data.validation.process:main",
"jasper_data_preprocess = jasper.data.process:main",
"jasper_data_slu_evaluate = jasper.data.slu_evaluator:main",
]
},
zip_safe=False,

validation_ui.py Normal file

@ -0,0 +1,3 @@
import runpy
runpy.run_module("jasper.data.validation.ui", run_name="__main__", alter_sys=True)
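
This three-line wrapper presumably exists so the Streamlit validation UI can be launched as a plain script (for example with: streamlit run validation_ui.py) while the actual UI logic stays importable as jasper.data.validation.ui.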