diff --git a/.gitignore b/.gitignore
index aab7ea0..3676125 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,11 @@
+/data/
+/model/
+/train/
+.env*
+*.yaml
+*.yml
+*.json
+
 # Created by https://www.gitignore.io/api/python
 # Edit at https://www.gitignore.io/?templates=python
 
@@ -108,3 +116,36 @@ dmypy.json
 .pyre/
 
 # End of https://www.gitignore.io/api/python
+
+# Created by https://www.gitignore.io/api/macos
+# Edit at https://www.gitignore.io/?templates=macos
+
+### macOS ###
+# General
+.DS_Store
+.AppleDouble
+.LSOverride
+
+# Icon must end with two \r
+Icon
+
+# Thumbnails
+._*
+
+# Files that might appear in the root of a volume
+.DocumentRevisions-V100
+.fseventsd
+.Spotlight-V100
+.TemporaryItems
+.Trashes
+.VolumeIcon.icns
+.com.apple.timemachine.donotpresent
+
+# Directories potentially created on remote AFP share
+.AppleDB
+.AppleDesktop
+Network Trash Folder
+Temporary Items
+.apdisk
+
+# End of https://www.gitignore.io/api/macos
diff --git a/Notes.md b/Notes.md
new file mode 100644
index 0000000..4195305
--- /dev/null
+++ b/Notes.md
@@ -0,0 +1,5 @@
+
+> Diff after splitting based on type
+```
+diff <(cat data/asr_data/call_upwork_test_cnd_*/manifest.json |sort) <(cat data/asr_data/call_upwork_test_cnd/manifest.json |sort)
+```
diff --git a/jasper/asr.py b/jasper/asr.py
index de3d78f..e52695d 100644
--- a/jasper/asr.py
+++ b/jasper/asr.py
@@ -62,7 +62,7 @@ class JasperASR(object):
         wf = wave.open(audio_file_path, "w")
         wf.setnchannels(1)
         wf.setsampwidth(2)
-        wf.setframerate(16000)
+        wf.setframerate(24000)
         wf.writeframesraw(audio_data)
         wf.close()
         manifest = {"audio_filepath": audio_file_path, "duration": 60, "text": "todo"}
@@ -108,6 +108,8 @@ class JasperASR(object):
         tensors = self.neural_factory.infer(tensors=eval_tensors)
         prediction = post_process_predictions(tensors[0], self.labels)
         prediction_text = ". ".join(prediction)
".join(prediction) + os.unlink(manifest_file.name) + os.unlink(audio_file.name) return prediction_text def transcribe_file(self, audio_file, *args, **kwargs): diff --git a/jasper/client.py b/jasper/client.py new file mode 100644 index 0000000..6c474a5 --- /dev/null +++ b/jasper/client.py @@ -0,0 +1,21 @@ +import os +import logging +import rpyc +from functools import lru_cache + +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger(__name__) + + +ASR_HOST = os.environ.get("JASPER_ASR_RPYC_HOST", "localhost") +ASR_PORT = int(os.environ.get("JASPER_ASR_RPYC_PORT", "8045")) + + +@lru_cache() +def transcribe_gen(asr_host=ASR_HOST, asr_port=ASR_PORT): + logger.info(f"connecting to asr server at {asr_host}:{asr_port}") + asr = rpyc.connect(asr_host, asr_port).root + logger.info(f"connected to asr server successfully") + return asr.transcribe diff --git a/jasper/data/__init__.py b/jasper/data/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/jasper/data/__init__.py @@ -0,0 +1 @@ + diff --git a/jasper/data/process.py b/jasper/data/process.py new file mode 100644 index 0000000..472c5bf --- /dev/null +++ b/jasper/data/process.py @@ -0,0 +1,77 @@ +import json +from pathlib import Path +from sklearn.model_selection import train_test_split +from .utils import asr_manifest_reader, asr_manifest_writer +from typing import List +from itertools import chain +import typer + +app = typer.Typer() + + +@app.command() +def fixate_data(dataset_path: Path): + manifest_path = dataset_path / Path("manifest.json") + real_manifest_path = dataset_path / Path("abs_manifest.json") + + def fix_path(): + for i in asr_manifest_reader(manifest_path): + i["audio_filepath"] = str(dataset_path / Path(i["audio_filepath"])) + yield i + + asr_manifest_writer(real_manifest_path, fix_path()) + + +@app.command() +def augment_data(src_dataset_paths: List[Path], dest_dataset_path: Path): + reader_list = [] + abs_manifest_path = Path("abs_manifest.json") + for dataset_path in src_dataset_paths: + manifest_path = dataset_path / abs_manifest_path + reader_list.append(asr_manifest_reader(manifest_path)) + dest_dataset_path.mkdir(parents=True, exist_ok=True) + dest_manifest_path = dest_dataset_path / abs_manifest_path + asr_manifest_writer(dest_manifest_path, chain(*reader_list)) + + +@app.command() +def split_data(dataset_path: Path, test_size: float = 0.1): + manifest_path = dataset_path / Path("abs_manifest.json") + asr_data = list(asr_manifest_reader(manifest_path)) + train_data, test_data = train_test_split(asr_data, test_size=test_size) + asr_manifest_writer(manifest_path.with_name("train_manifest.json"), train_data) + asr_manifest_writer(manifest_path.with_name("test_manifest.json"), test_data) + + +@app.command() +def validate_data(dataset_path: Path): + from natural.date import compress + from datetime import timedelta + + for mf_type in ["train_manifest.json", "test_manifest.json"]: + data_file = dataset_path / Path(mf_type) + print(f"validating {data_file}.") + with Path(data_file).open("r") as pf: + data_jsonl = pf.readlines() + duration = 0 + for (i, s) in enumerate(data_jsonl): + try: + d = json.loads(s) + duration += d["duration"] + audio_file = data_file.parent / Path(d["audio_filepath"]) + if not audio_file.exists(): + raise OSError(f"File {audio_file} not found") + except BaseException as e: + print(f'failed on {i} with "{e}"') + duration_str = compress(timedelta(seconds=duration), pad=" ") + print( + f"no errors 
+        )
+
+
+def main():
+    app()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/jasper/data/server.py b/jasper/data/server.py
new file mode 100644
index 0000000..856c381
--- /dev/null
+++ b/jasper/data/server.py
@@ -0,0 +1,57 @@
+import os
+from pathlib import Path
+
+import typer
+import rpyc
+from rpyc.utils.server import ThreadedServer
+import nemo
+import pickle
+
+# import nemo.collections.asr as nemo_asr
+from nemo.collections.asr.parts.segment import AudioSegment
+
+app = typer.Typer()
+
+nemo.core.NeuralModuleFactory(
+    backend=nemo.core.Backend.PyTorch, placement=nemo.core.DeviceType.CPU
+)
+
+
+class ASRDataService(rpyc.Service):
+    def exposed_get_path_samples(
+        self, file_path, target_sr, int_values, offset, duration, trim
+    ):
+        print(f"loading.. {file_path}")
+        audio = AudioSegment.from_file(
+            file_path,
+            target_sr=target_sr,
+            int_values=int_values,
+            offset=offset,
+            duration=duration,
+            trim=trim,
+        )
+        # print(f"returning.. {len(audio.samples)} items of type{type(audio.samples)}")
+        return pickle.dumps(audio.samples)
+
+    def exposed_read_path(self, file_path):
+        # print(f"reading path.. {file_path}")
+        return Path(file_path).read_bytes()
+
+
+@app.command()
+def run_server(port: int = 0):
+    listen_port = port if port else int(os.environ.get("ASR_DATA_RPYC_PORT", "8064"))
+    service = ASRDataService()
+    t = ThreadedServer(
+        service, port=listen_port, protocol_config={"allow_all_attrs": True}
+    )
+    typer.echo(f"starting asr data server on {listen_port}...")
+    t.start()
+
+
+def main():
+    app()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/jasper/data/utils.py b/jasper/data/utils.py
new file mode 100644
index 0000000..5409e1c
--- /dev/null
+++ b/jasper/data/utils.py
@@ -0,0 +1,236 @@
+import io
+import os
+import json
+import wave
+from pathlib import Path
+from functools import partial
+from uuid import uuid4
+from concurrent.futures import ThreadPoolExecutor
+
+import pymongo
+from slugify import slugify
+from jasper.client import transcribe_gen
+from nemo.collections.asr.metrics import word_error_rate
+import matplotlib.pyplot as plt
+import librosa
+import librosa.display
+from tqdm import tqdm
+
+
+def manifest_str(path, dur, text):
+    return (
+        json.dumps({"audio_filepath": path, "duration": round(dur, 1), "text": text})
+        + "\n"
+    )
+
+
+def wav_bytes(audio_bytes, frame_rate=24000):
+    wf_b = io.BytesIO()
+    with wave.open(wf_b, mode="w") as wf:
+        wf.setnchannels(1)
+        wf.setframerate(frame_rate)
+        wf.setsampwidth(2)
+        wf.writeframesraw(audio_bytes)
+    return wf_b.getvalue()
+
+
+def tscript_uuid_fname(transcript):
+    return str(uuid4()) + "_" + slugify(transcript, max_length=8)
+
+
+def asr_data_writer(output_dir, dataset_name, asr_data_source, verbose=False):
+    dataset_dir = output_dir / Path(dataset_name)
+    (dataset_dir / Path("wav")).mkdir(parents=True, exist_ok=True)
+    asr_manifest = dataset_dir / Path("manifest.json")
+    num_datapoints = 0
+    with asr_manifest.open("w") as mf:
+        print(f"writing manifest to {asr_manifest}")
+        for transcript, audio_dur, wav_data in asr_data_source:
+            fname = tscript_uuid_fname(transcript)
+            audio_file = dataset_dir / Path("wav") / Path(fname).with_suffix(".wav")
+            audio_file.write_bytes(wav_data)
+            rel_data_path = audio_file.relative_to(dataset_dir)
+            manifest = manifest_str(str(rel_data_path), audio_dur, transcript)
+            mf.write(manifest)
+            if verbose:
+                print(f"writing '{transcript}' of duration {audio_dur}")
+            num_datapoints += 1
+    return num_datapoints
+
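+# NOTE: ui_dump_manifest_writer below mirrors asr_data_writer but additionally
+# builds a ui_dump.json for the validation UI: per-datapoint work (pretrained ASR
+# via the rpyc client, WER against the transcript, waveform plots) is deferred
+# through partials and executed in parallel at the end via parallel_apply.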
+ +def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=False): + dataset_dir = output_dir / Path(dataset_name) + (dataset_dir / Path("wav")).mkdir(parents=True, exist_ok=True) + ui_dump_file = dataset_dir / Path("ui_dump.json") + (dataset_dir / Path("wav_plots")).mkdir(parents=True, exist_ok=True) + asr_manifest = dataset_dir / Path("manifest.json") + num_datapoints = 0 + ui_dump = { + "use_domain_asr": False, + "annotation_only": False, + "enable_plots": True, + "data": [], + } + data_funcs = [] + transcriber_pretrained = transcribe_gen(asr_port=8044) + with asr_manifest.open("w") as mf: + print(f"writing manifest to {asr_manifest}") + + def data_fn( + transcript, + audio_dur, + wav_data, + caller_name, + aud_seg, + fname, + audio_path, + num_datapoints, + rel_data_path, + ): + pretrained_result = transcriber_pretrained(aud_seg.raw_data) + pretrained_wer = word_error_rate([transcript], [pretrained_result]) + wav_plot_path = ( + dataset_dir / Path("wav_plots") / Path(fname).with_suffix(".png") + ) + if not wav_plot_path.exists(): + plot_seg(wav_plot_path, audio_path) + return { + "audio_filepath": str(rel_data_path), + "duration": round(audio_dur, 1), + "text": transcript, + "real_idx": num_datapoints, + "audio_path": audio_path, + "spoken": transcript, + "caller": caller_name, + "utterance_id": fname, + "pretrained_asr": pretrained_result, + "pretrained_wer": pretrained_wer, + "plot_path": str(wav_plot_path), + } + + for transcript, audio_dur, wav_data, caller_name, aud_seg in asr_data_source: + fname = str(uuid4()) + "_" + slugify(transcript, max_length=8) + audio_file = dataset_dir / Path("wav") / Path(fname).with_suffix(".wav") + audio_file.write_bytes(wav_data) + audio_path = str(audio_file) + rel_data_path = audio_file.relative_to(dataset_dir) + manifest = manifest_str(str(rel_data_path), audio_dur, transcript) + mf.write(manifest) + data_funcs.append( + partial( + data_fn, + transcript, + audio_dur, + wav_data, + caller_name, + aud_seg, + fname, + audio_path, + num_datapoints, + rel_data_path, + ) + ) + num_datapoints += 1 + dump_data = parallel_apply(lambda x: x(), data_funcs) + # dump_data = [x() for x in tqdm(data_funcs)] + ui_dump["data"] = dump_data + ExtendedPath(ui_dump_file).write_json(ui_dump) + return num_datapoints + + +def asr_manifest_reader(data_manifest_path: Path): + print(f"reading manifest from {data_manifest_path}") + with data_manifest_path.open("r") as pf: + data_jsonl = pf.readlines() + data_data = [json.loads(v) for v in data_jsonl] + for p in data_data: + p["audio_path"] = data_manifest_path.parent / Path(p["audio_filepath"]) + p["text"] = p["text"].strip() + yield p + + +def asr_manifest_writer(asr_manifest_path: Path, manifest_str_source): + with asr_manifest_path.open("w") as mf: + print(f"opening {asr_manifest_path} for writing manifest") + for mani_dict in manifest_str_source: + manifest = manifest_str( + mani_dict["audio_filepath"], mani_dict["duration"], mani_dict["text"] + ) + mf.write(manifest) + + +def asr_test_writer(out_file_path: Path, source): + def dd_str(dd, idx): + path = dd["audio_filepath"] + # dur = dd["duration"] + # return f"SAY {idx}\nPAUSE 3\nPLAY {path}\nPAUSE 3\n\n" + return f"PAUSE 2\nPLAY {path}\nPAUSE 60\n\n" + + res_file = out_file_path.with_suffix(".result.json") + with out_file_path.open("w") as of: + print(f"opening {out_file_path} for writing test") + results = [] + idx = 0 + for ui_dd in source: + results.append(ui_dd) + out_str = dd_str(ui_dd, idx) + of.write(out_str) + idx += 1 + 
of.write("DO_HANGUP\n") + ExtendedPath(res_file).write_json(results) + + +def batch(iterable, n=1): + ls = len(iterable) + return [iterable[ndx : min(ndx + n, ls)] for ndx in range(0, ls, n)] + + +class ExtendedPath(type(Path())): + """docstring for ExtendedPath.""" + + def read_json(self): + print(f"reading json from {self}") + with self.open("r") as jf: + return json.load(jf) + + def write_json(self, data): + print(f"writing json to {self}") + self.parent.mkdir(parents=True, exist_ok=True) + with self.open("w") as jf: + return json.dump(data, jf, indent=2) + + +def get_mongo_conn(host="", port=27017, db="test", col="calls"): + mongo_host = host if host else os.environ.get("MONGO_HOST", "localhost") + mongo_uri = f"mongodb://{mongo_host}:{port}/" + return pymongo.MongoClient(mongo_uri)[db][col] + + +def strip_silence(sound): + from pydub.silence import detect_leading_silence + + start_trim = detect_leading_silence(sound) + end_trim = detect_leading_silence(sound.reverse()) + duration = len(sound) + return sound[start_trim : duration - end_trim] + + +def plot_seg(wav_plot_path, audio_path): + fig = plt.Figure() + ax = fig.add_subplot() + (y, sr) = librosa.load(audio_path) + librosa.display.waveplot(y=y, sr=sr, ax=ax) + with wav_plot_path.open("wb") as wav_plot_f: + fig.set_tight_layout(True) + fig.savefig(wav_plot_f, format="png", dpi=50) + + +def parallel_apply(fn, iterable, workers=8): + with ThreadPoolExecutor(max_workers=workers) as exe: + print(f"parallelly applying {fn}") + return [ + res + for res in tqdm( + exe.map(fn, iterable), position=0, leave=True, total=len(iterable) + ) + ] diff --git a/jasper/data/validation/__init__.py b/jasper/data/validation/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/jasper/data/validation/__init__.py @@ -0,0 +1 @@ + diff --git a/jasper/data/validation/process.py b/jasper/data/validation/process.py new file mode 100644 index 0000000..f887c27 --- /dev/null +++ b/jasper/data/validation/process.py @@ -0,0 +1,418 @@ +import json +import shutil +from pathlib import Path +from enum import Enum + +import typer +from tqdm import tqdm + +from ..utils import ( + alnum_to_asr_tokens, + ExtendedPath, + asr_manifest_reader, + asr_manifest_writer, + tscript_uuid_fname, + get_mongo_conn, + plot_seg, +) + +app = typer.Typer() + + +def preprocess_datapoint( + idx, rel_root, sample, use_domain_asr, annotation_only, enable_plots +): + from pydub import AudioSegment + from nemo.collections.asr.metrics import word_error_rate + from jasper.client import transcribe_gen + + try: + res = dict(sample) + res["real_idx"] = idx + audio_path = rel_root / Path(sample["audio_filepath"]) + res["audio_path"] = str(audio_path) + if use_domain_asr: + res["spoken"] = alnum_to_asr_tokens(res["text"]) + else: + res["spoken"] = res["text"] + res["utterance_id"] = audio_path.stem + if not annotation_only: + transcriber_pretrained = transcribe_gen(asr_port=8044) + + aud_seg = ( + AudioSegment.from_file_using_temporary_files(audio_path) + .set_channels(1) + .set_sample_width(2) + .set_frame_rate(24000) + ) + res["pretrained_asr"] = transcriber_pretrained(aud_seg.raw_data) + res["pretrained_wer"] = word_error_rate( + [res["text"]], [res["pretrained_asr"]] + ) + if use_domain_asr: + transcriber_speller = transcribe_gen(asr_port=8045) + res["domain_asr"] = transcriber_speller(aud_seg.raw_data) + res["domain_wer"] = word_error_rate( + [res["spoken"]], [res["pretrained_asr"]] + ) + if enable_plots: + wav_plot_path = ( + rel_root / Path("wav_plots") / 
Path(audio_path.name).with_suffix(".png") + ) + if not wav_plot_path.exists(): + plot_seg(wav_plot_path, audio_path) + res["plot_path"] = str(wav_plot_path) + return res + except BaseException as e: + print(f'failed on {idx}: {sample["audio_filepath"]} with {e}') + + +@app.command() +def dump_ui( + data_name: str = typer.Option("dataname", show_default=True), + dataset_dir: Path = Path("./data/asr_data"), + dump_dir: Path = Path("./data/valiation_data"), + dump_fname: Path = typer.Option(Path("ui_dump.json"), show_default=True), + use_domain_asr: bool = False, + annotation_only: bool = False, + enable_plots: bool = True, +): + from concurrent.futures import ThreadPoolExecutor + from functools import partial + + data_manifest_path = dataset_dir / Path(data_name) / Path("manifest.json") + dump_path: Path = dump_dir / Path(data_name) / dump_fname + plot_dir = data_manifest_path.parent / Path("wav_plots") + plot_dir.mkdir(parents=True, exist_ok=True) + typer.echo(f"Using data manifest:{data_manifest_path}") + with data_manifest_path.open("r") as pf: + data_jsonl = pf.readlines() + data_funcs = [ + partial( + preprocess_datapoint, + i, + data_manifest_path.parent, + json.loads(v), + use_domain_asr, + annotation_only, + enable_plots, + ) + for i, v in enumerate(data_jsonl) + ] + + def exec_func(f): + return f() + + with ThreadPoolExecutor() as exe: + print("starting all preprocess tasks") + data_final = filter( + None, + list( + tqdm( + exe.map(exec_func, data_funcs), + position=0, + leave=True, + total=len(data_funcs), + ) + ), + ) + if annotation_only: + result = list(data_final) + else: + wer_key = "domain_wer" if use_domain_asr else "pretrained_wer" + result = sorted(data_final, key=lambda x: x[wer_key], reverse=True) + ui_config = { + "use_domain_asr": use_domain_asr, + "annotation_only": annotation_only, + "enable_plots": enable_plots, + "data": result, + } + ExtendedPath(dump_path).write_json(ui_config) + + +@app.command() +def sample_ui( + data_name: str = typer.Option("dataname", show_default=True), + dump_dir: Path = Path("./data/asr_data"), + dump_file: Path = Path("ui_dump.json"), + sample_count: int = typer.Option(80, show_default=True), + sample_file: Path = Path("sample_dump.json"), +): + import pandas as pd + + processed_data_path = dump_dir / Path(data_name) / dump_file + sample_path = dump_dir / Path(data_name) / sample_file + processed_data = ExtendedPath(processed_data_path).read_json() + df = pd.DataFrame(processed_data["data"]) + samples_per_caller = sample_count // len(df["caller"].unique()) + caller_samples = pd.concat( + [g.sample(samples_per_caller) for (c, g) in df.groupby("caller")] + ) + caller_samples = caller_samples.reset_index(drop=True) + caller_samples["real_idx"] = caller_samples.index + sample_data = caller_samples.to_dict("records") + processed_data["data"] = sample_data + typer.echo(f"sampling {sample_count} datapoints") + ExtendedPath(sample_path).write_json(processed_data) + + +@app.command() +def task_ui( + data_name: str = typer.Option("dataname", show_default=True), + dump_dir: Path = Path("./data/asr_data"), + dump_file: Path = Path("ui_dump.json"), + task_count: int = typer.Option(4, show_default=True), + task_file: str = "task_dump", +): + import pandas as pd + import numpy as np + + processed_data_path = dump_dir / Path(data_name) / dump_file + processed_data = ExtendedPath(processed_data_path).read_json() + df = pd.DataFrame(processed_data["data"]).sample(frac=1).reset_index(drop=True) + for t_idx, task_f in enumerate(np.array_split(df, 
task_count)): + task_f = task_f.reset_index(drop=True) + task_f["real_idx"] = task_f.index + task_data = task_f.to_dict("records") + processed_data["data"] = task_data + task_path = dump_dir / Path(data_name) / Path(task_file + f"-{t_idx}.json") + ExtendedPath(task_path).write_json(processed_data) + + +@app.command() +def dump_corrections( + task_uid: str, + data_name: str = typer.Option("dataname", show_default=True), + dump_dir: Path = Path("./data/asr_data"), + dump_fname: Path = Path("corrections.json"), +): + dump_path = dump_dir / Path(data_name) / dump_fname + col = get_mongo_conn(col="asr_validation") + task_id = [c for c in col.distinct("task_id") if c.rsplit("-", 1)[1] == task_uid][0] + corrections = list(col.find({"type": "correction"}, projection={"_id": False})) + cursor_obj = col.find({"type": "correction", "task_id": task_id}, projection={"_id": False}) + corrections = [c for c in cursor_obj] + ExtendedPath(dump_path).write_json(corrections) + + +@app.command() +def caller_quality( + task_uid: str, + data_name: str = typer.Option("dataname", show_default=True), + dump_dir: Path = Path("./data/asr_data"), + dump_fname: Path = Path("ui_dump.json"), + correction_fname: Path = Path("corrections.json"), +): + import copy + import pandas as pd + + dump_path = dump_dir / Path(data_name) / dump_fname + correction_path = dump_dir / Path(data_name) / correction_fname + dump_data = ExtendedPath(dump_path).read_json() + + dump_map = {d["utterance_id"]: d for d in dump_data["data"]} + correction_data = ExtendedPath(correction_path).read_json() + + def correction_dp(c): + dp = copy.deepcopy(dump_map[c["code"]]) + dp["valid"] = c["value"]["status"] == "Correct" + return dp + + corrected_dump = [ + correction_dp(c) + for c in correction_data + if c["task_id"].rsplit("-", 1)[1] == task_uid + ] + df = pd.DataFrame(corrected_dump) + print(f"Total samples: {len(df)}") + for (c, g) in df.groupby("caller"): + total = len(g) + valid = len(g[g["valid"] == True]) + valid_rate = valid * 100 / total + print(f"Caller: {c} Valid%:{valid_rate:.2f} of {total} samples") + + +@app.command() +def fill_unannotated( + data_name: str = typer.Option("dataname", show_default=True), + dump_dir: Path = Path("./data/valiation_data"), + dump_file: Path = Path("ui_dump.json"), + corrections_file: Path = Path("corrections.json"), +): + processed_data_path = dump_dir / Path(data_name) / dump_file + corrections_path = dump_dir / Path(data_name) / corrections_file + processed_data = json.load(processed_data_path.open()) + corrections = json.load(corrections_path.open()) + annotated_codes = {c["code"] for c in corrections} + all_codes = {c["gold_chars"] for c in processed_data} + unann_codes = all_codes - annotated_codes + mongo_conn = get_mongo_conn(col="asr_validation") + for c in unann_codes: + mongo_conn.find_one_and_update( + {"type": "correction", "code": c}, + {"$set": {"value": {"status": "Inaudible", "correction": ""}}}, + upsert=True, + ) + + +@app.command() +def split_extract( + data_name: str = typer.Option("dataname", show_default=True), + # dest_data_name: str = typer.Option("call_aldata_namephanum_date", show_default=True), + # dump_dir: Path = Path("./data/valiation_data"), + dump_dir: Path = Path("./data/asr_data"), + dump_file: Path = Path("ui_dump.json"), + manifest_file: Path = Path("manifest.json"), + corrections_file: str = typer.Option("corrections.json", show_default=True), + conv_data_path: Path = typer.Option(Path("./data/conv_data.json"), show_default=True), + extraction_type: str = "all", +): + 
import shutil + + data_manifest_path = dump_dir / Path(data_name) / manifest_file + conv_data = ExtendedPath(conv_data_path).read_json() + + def extract_data_of_type(extraction_key): + extraction_vals = conv_data[extraction_key] + dest_data_name = data_name + "_" + extraction_key.lower() + + manifest_gen = asr_manifest_reader(data_manifest_path) + dest_data_dir = dump_dir / Path(dest_data_name) + dest_data_dir.mkdir(exist_ok=True, parents=True) + (dest_data_dir / Path("wav")).mkdir(exist_ok=True, parents=True) + dest_manifest_path = dest_data_dir / manifest_file + dest_ui_path = dest_data_dir / dump_file + + def extract_manifest(mg): + for m in mg: + if m["text"] in extraction_vals: + shutil.copy(m["audio_path"], dest_data_dir / Path(m["audio_filepath"])) + yield m + + asr_manifest_writer(dest_manifest_path, extract_manifest(manifest_gen)) + + ui_data_path = dump_dir / Path(data_name) / dump_file + orig_ui_data = ExtendedPath(ui_data_path).read_json() + ui_data = orig_ui_data["data"] + file_ui_map = {Path(u["audio_filepath"]).stem: u for u in ui_data} + extracted_ui_data = list(filter(lambda u: u["text"] in extraction_vals, ui_data)) + final_data = [] + for i, d in enumerate(extracted_ui_data): + d['real_idx'] = i + final_data.append(d) + orig_ui_data['data'] = final_data + ExtendedPath(dest_ui_path).write_json(orig_ui_data) + + if corrections_file: + dest_correction_path = dest_data_dir / corrections_file + corrections_path = dump_dir / Path(data_name) / corrections_file + corrections = json.load(corrections_path.open()) + extracted_corrections = list( + filter( + lambda c: c["code"] in file_ui_map + and file_ui_map[c["code"]]["text"] in extraction_vals, + corrections, + ) + ) + ExtendedPath(dest_correction_path).write_json(extracted_corrections) + + if extraction_type.value == 'all': + for ext_key in conv_data.keys(): + extract_data_of_type(ext_key) + else: + extract_data_of_type(extraction_type.value) + + +@app.command() +def update_corrections( + data_name: str = typer.Option("dataname", show_default=True), + dump_dir: Path = Path("./data/asr_data"), + manifest_file: Path = Path("manifest.json"), + corrections_file: Path = Path("corrections.json"), + ui_dump_file: Path = Path("ui_dump.json"), + skip_incorrect: bool = typer.Option(True, show_default=True), +): + data_manifest_path = dump_dir / Path(data_name) / manifest_file + corrections_path = dump_dir / Path(data_name) / corrections_file + ui_dump_path = dump_dir / Path(data_name) / ui_dump_file + + def correct_manifest(ui_dump_path, corrections_path): + corrections = ExtendedPath(corrections_path).read_json() + ui_data = ExtendedPath(ui_dump_path).read_json()['data'] + correct_set = { + c["code"] for c in corrections if c["value"]["status"] == "Correct" + } + # incorrect_set = {c["code"] for c in corrections if c["value"]["status"] == "Inaudible"} + correction_map = { + c["code"]: c["value"]["correction"] + for c in corrections + if c["value"]["status"] == "Incorrect" + } + # for d in manifest_data_gen: + # if d["chars"] in incorrect_set: + # d["audio_path"].unlink() + # renamed_set = set() + for d in ui_data: + if d["utterance_id"] in correct_set: + yield { + "audio_filepath": d["audio_filepath"], + "duration": d["duration"], + "text": d["text"], + } + elif d["utterance_id"] in correction_map: + correct_text = correction_map[d["utterance_id"]] + if skip_incorrect: + print( + f'skipping incorrect {d["audio_path"]} corrected to {correct_text}' + ) + else: + orig_audio_path = Path(d["audio_path"]) + new_name = 
str(Path(tscript_uuid_fname(correct_text)).with_suffix(".wav")) + new_audio_path = orig_audio_path.with_name(new_name) + orig_audio_path.replace(new_audio_path) + new_filepath = str(Path(d["audio_filepath"]).with_name(new_name)) + yield { + "audio_filepath": new_filepath, + "duration": d["duration"], + "text": correct_text, + } + else: + orig_audio_path = Path(d["audio_path"]) + # don't delete if another correction points to an old file + # if d["text"] not in renamed_set: + orig_audio_path.unlink() + # else: + # print(f'skipping deletion of correction:{d["text"]}') + + typer.echo(f"Using data manifest:{data_manifest_path}") + dataset_dir = data_manifest_path.parent + dataset_name = dataset_dir.name + backup_dir = dataset_dir.with_name(dataset_name + ".bkp") + if not backup_dir.exists(): + typer.echo(f"backing up to :{backup_dir}") + shutil.copytree(str(dataset_dir), str(backup_dir)) + # manifest_gen = asr_manifest_reader(data_manifest_path) + corrected_manifest = correct_manifest(ui_dump_path, corrections_path) + new_data_manifest_path = data_manifest_path.with_name("manifest.new") + asr_manifest_writer(new_data_manifest_path, corrected_manifest) + new_data_manifest_path.replace(data_manifest_path) + + +@app.command() +def clear_mongo_corrections(): + delete = typer.confirm("are you sure you want to clear mongo collection it?") + if delete: + col = get_mongo_conn(col="asr_validation") + col.delete_many({"type": "correction"}) + col.delete_many({"type": "current_cursor"}) + typer.echo("deleted mongo collection.") + return + typer.echo("Aborted") + + +def main(): + app() + + +if __name__ == "__main__": + main() diff --git a/jasper/data/validation/st_rerun.py b/jasper/data/validation/st_rerun.py new file mode 100644 index 0000000..ae80624 --- /dev/null +++ b/jasper/data/validation/st_rerun.py @@ -0,0 +1,38 @@ +import streamlit.ReportThread as ReportThread +from streamlit.ScriptRequestQueue import RerunData +from streamlit.ScriptRunner import RerunException +from streamlit.server.Server import Server + + +def rerun(): + """Rerun a Streamlit app from the top!""" + widget_states = _get_widget_states() + raise RerunException(RerunData(widget_states)) + + +def _get_widget_states(): + # Hack to get the session object from Streamlit. + + ctx = ReportThread.get_report_ctx() + + session = None + + current_server = Server.get_current() + if hasattr(current_server, '_session_infos'): + # Streamlit < 0.56 + session_infos = Server.get_current()._session_infos.values() + else: + session_infos = Server.get_current()._session_info_by_id.values() + + for session_info in session_infos: + if session_info.session.enqueue == ctx.enqueue: + session = session_info.session + + if session is None: + raise RuntimeError( + "Oh noes. Couldn't get your Streamlit Session object" + "Are you doing something fancy with threads?" + ) + # Got the session object! 
+ + return session._widget_states diff --git a/jasper/data/validation/ui.py b/jasper/data/validation/ui.py new file mode 100644 index 0000000..3915aeb --- /dev/null +++ b/jasper/data/validation/ui.py @@ -0,0 +1,158 @@ +from pathlib import Path + +import streamlit as st +import typer +from uuid import uuid4 +from ..utils import ExtendedPath, get_mongo_conn +from .st_rerun import rerun + +app = typer.Typer() + + +if not hasattr(st, "mongo_connected"): + st.mongoclient = get_mongo_conn(col="asr_validation") + mongo_conn = st.mongoclient + st.task_id = str(uuid4()) + + def current_cursor_fn(): + # mongo_conn = st.mongoclient + cursor_obj = mongo_conn.find_one( + {"type": "current_cursor", "task_id": st.task_id} + ) + cursor_val = cursor_obj["cursor"] + return cursor_val + + def update_cursor_fn(val=0): + mongo_conn.find_one_and_update( + {"type": "current_cursor", "task_id": st.task_id}, + {"$set": {"type": "current_cursor", "task_id": st.task_id, "cursor": val}}, + upsert=True, + ) + rerun() + + def get_correction_entry_fn(code): + return mongo_conn.find_one( + {"type": "correction", "code": code}, projection={"_id": False} + ) + + def update_entry_fn(code, value): + mongo_conn.find_one_and_update( + {"type": "correction", "code": code}, + {"$set": {"value": value, "task_id": st.task_id}}, + upsert=True, + ) + + def set_task_fn(mf_path): + task_path = mf_path.parent / Path(f"task-{st.task_id}.lck") + if not task_path.exists(): + print(f"creating task lock at {task_path}") + task_path.touch() + + st.get_current_cursor = current_cursor_fn + st.update_cursor = update_cursor_fn + st.get_correction_entry = get_correction_entry_fn + st.update_entry = update_entry_fn + st.set_task = set_task_fn + st.mongo_connected = True + cursor_obj = mongo_conn.find_one({"type": "current_cursor", "task_id": st.task_id}) + if not cursor_obj: + update_cursor_fn(0) + + +@st.cache() +def load_ui_data(validation_ui_data_path: Path): + typer.echo(f"Using validation ui data from {validation_ui_data_path}") + return ExtendedPath(validation_ui_data_path).read_json() + + +@app.command() +def main(manifest: Path): + st.set_task(manifest) + ui_config = load_ui_data(manifest) + asr_data = ui_config["data"] + use_domain_asr = ui_config.get("use_domain_asr", True) + annotation_only = ui_config.get("annotation_only", False) + enable_plots = ui_config.get("enable_plots", True) + sample_no = st.get_current_cursor() + if len(asr_data) - 1 < sample_no or sample_no < 0: + print("Invalid samplno resetting to 0") + st.update_cursor(0) + sample = asr_data[sample_no] + title_type = "Speller " if use_domain_asr else "" + task_uid = st.task_id.rsplit("-", 1)[1] + if annotation_only: + st.title(f"ASR Annotation - # {task_uid}") + else: + st.title(f"ASR {title_type}Validation - # {task_uid}") + addl_text = f"spelled *{sample['spoken']}*" if use_domain_asr else "" + st.markdown(f"{sample_no+1} of {len(asr_data)} : **{sample['text']}**" + addl_text) + new_sample = st.number_input( + "Go To Sample:", value=sample_no + 1, min_value=1, max_value=len(asr_data) + ) + if new_sample != sample_no + 1: + st.update_cursor(new_sample - 1) + st.sidebar.title(f"Details: [{sample['real_idx']}]") + st.sidebar.markdown(f"Gold Text: **{sample['text']}**") + if not annotation_only: + if use_domain_asr: + st.sidebar.markdown(f"Expected Spelled: *{sample['spoken']}*") + st.sidebar.title("Results:") + st.sidebar.markdown(f"Pretrained: **{sample['pretrained_asr']}**") + if "caller" in sample: + st.sidebar.markdown(f"Caller: **{sample['caller']}**") + if 
use_domain_asr: + st.sidebar.markdown(f"Domain: **{sample['domain_asr']}**") + st.sidebar.title(f"Speller WER: {sample['domain_wer']:.2f}%") + else: + st.sidebar.title(f"Pretrained WER: {sample['pretrained_wer']:.2f}%") + if enable_plots: + st.sidebar.image(Path(sample["plot_path"]).read_bytes()) + st.audio(Path(sample["audio_path"]).open("rb")) + # set default to text + corrected = sample["text"] + correction_entry = st.get_correction_entry(sample["utterance_id"]) + selected_idx = 0 + options = ("Correct", "Incorrect", "Inaudible") + # if correction entry is present set the corresponding ui defaults + if correction_entry: + selected_idx = options.index(correction_entry["value"]["status"]) + corrected = correction_entry["value"]["correction"] + selected = st.radio("The Audio is", options, index=selected_idx) + if selected == "Incorrect": + corrected = st.text_input("Actual:", value=corrected) + if selected == "Inaudible": + corrected = "" + if st.button("Submit"): + st.update_entry( + sample["utterance_id"], {"status": selected, "correction": corrected} + ) + st.update_cursor(sample_no + 1) + if correction_entry: + st.markdown( + f'Your Response: **{correction_entry["value"]["status"]}** Correction: **{correction_entry["value"]["correction"]}**' + ) + text_sample = st.text_input("Go to Text:", value="") + if text_sample != "": + candidates = [ + i + for (i, p) in enumerate(asr_data) + if p["text"] == text_sample or p["spoken"] == text_sample + ] + if len(candidates) > 0: + st.update_cursor(candidates[0]) + real_idx = st.number_input( + "Go to real-index", + value=sample["real_idx"], + min_value=0, + max_value=len(asr_data) - 1, + ) + if real_idx != int(sample["real_idx"]): + idx = [i for (i, p) in enumerate(asr_data) if p["real_idx"] == real_idx][0] + st.update_cursor(idx) + + +if __name__ == "__main__": + try: + app() + except SystemExit: + pass diff --git a/jasper/evaluate.py b/jasper/evaluate.py new file mode 100644 index 0000000..94d8f43 --- /dev/null +++ b/jasper/evaluate.py @@ -0,0 +1,359 @@ +# Copyright (c) 2019 NVIDIA Corporation +import argparse +import copy +# import math +import os +from pathlib import Path +from functools import partial + +from ruamel.yaml import YAML + +import nemo +import nemo.collections.asr as nemo_asr +import nemo.utils.argparse as nm_argparse +from nemo.collections.asr.helpers import ( + # monitor_asr_train_progress, + process_evaluation_batch, + process_evaluation_epoch, +) + +# from nemo.utils.lr_policies import CosineAnnealing +from training.data_loaders import RpycAudioToTextDataLayer + +logging = nemo.logging + + +def parse_args(): + parser = argparse.ArgumentParser( + parents=[nm_argparse.NemoArgParser()], + description="Jasper", + conflict_handler="resolve", + ) + parser.set_defaults( + checkpoint_dir=None, + optimizer="novograd", + batch_size=64, + eval_batch_size=64, + lr=0.002, + amp_opt_level="O1", + create_tb_writer=True, + model_config="./train/jasper10x5dr.yaml", + work_dir="./train/work", + num_epochs=300, + weight_decay=0.005, + checkpoint_save_freq=100, + eval_freq=100, + load_dir="./train/models/jasper/", + warmup_steps=3, + exp_name="jasper-speller", + ) + + # Overwrite default args + parser.add_argument( + "--max_steps", + type=int, + default=None, + required=False, + help="max number of steps to train", + ) + parser.add_argument( + "--num_epochs", type=int, required=False, help="number of epochs to train" + ) + parser.add_argument( + "--model_config", + type=str, + required=False, + help="model configuration file: model.yaml", + ) + 
parser.add_argument( + "--encoder_checkpoint", + type=str, + required=True, + help="encoder checkpoint file: JasperEncoder.pt", + ) + parser.add_argument( + "--decoder_checkpoint", + type=str, + required=True, + help="decoder checkpoint file: JasperDecoderForCTC.pt", + ) + parser.add_argument( + "--remote_data", + type=str, + required=False, + default="", + help="remote dataloader endpoint", + ) + parser.add_argument( + "--dataset", + type=str, + required=False, + default="", + help="dataset directory containing train/test manifests", + ) + + # Create new args + parser.add_argument("--exp_name", default="Jasper", type=str) + parser.add_argument("--beta1", default=0.95, type=float) + parser.add_argument("--beta2", default=0.25, type=float) + parser.add_argument("--warmup_steps", default=0, type=int) + parser.add_argument( + "--load_dir", + default=None, + type=str, + help="directory with pre-trained checkpoint", + ) + + args = parser.parse_args() + if args.max_steps is None and args.num_epochs is None: + raise ValueError("Either max_steps or num_epochs should be provided.") + return args + + +def construct_name( + name, lr, batch_size, max_steps, num_epochs, wd, optimizer, iter_per_step +): + if max_steps is not None: + return "{0}-lr_{1}-bs_{2}-s_{3}-wd_{4}-opt_{5}-ips_{6}".format( + name, lr, batch_size, max_steps, wd, optimizer, iter_per_step + ) + else: + return "{0}-lr_{1}-bs_{2}-e_{3}-wd_{4}-opt_{5}-ips_{6}".format( + name, lr, batch_size, num_epochs, wd, optimizer, iter_per_step + ) + + +def create_all_dags(args, neural_factory): + yaml = YAML(typ="safe") + with open(args.model_config) as f: + jasper_params = yaml.load(f) + vocab = jasper_params["labels"] + sample_rate = jasper_params["sample_rate"] + + # Calculate num_workers for dataloader + total_cpus = os.cpu_count() + cpu_per_traindl = max(int(total_cpus / neural_factory.world_size), 1) + # perturb_config = jasper_params.get('perturb', None) + train_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"]) + train_dl_params.update(jasper_params["AudioToTextDataLayer"]["train"]) + del train_dl_params["train"] + del train_dl_params["eval"] + # del train_dl_params["normalize_transcripts"] + + if args.dataset: + d_path = Path(args.dataset) + if not args.train_dataset: + args.train_dataset = str(d_path / Path("train_manifest.json")) + if not args.eval_datasets: + args.eval_datasets = [str(d_path / Path("test_manifest.json"))] + + data_loader_layer = nemo_asr.AudioToTextDataLayer + + if args.remote_data: + train_dl_params["rpyc_host"] = args.remote_data + data_loader_layer = RpycAudioToTextDataLayer + + # data_layer = data_loader_layer( + # manifest_filepath=args.train_dataset, + # sample_rate=sample_rate, + # labels=vocab, + # batch_size=args.batch_size, + # num_workers=cpu_per_traindl, + # **train_dl_params, + # # normalize_transcripts=False + # ) + # + # N = len(data_layer) + # steps_per_epoch = math.ceil( + # N / (args.batch_size * args.iter_per_step * args.num_gpus) + # ) + # logging.info("Have {0} examples to train on.".format(N)) + # + data_preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor( + sample_rate=sample_rate, **jasper_params["AudioToMelSpectrogramPreprocessor"] + ) + + # multiply_batch_config = jasper_params.get("MultiplyBatch", None) + # if multiply_batch_config: + # multiply_batch = nemo_asr.MultiplyBatch(**multiply_batch_config) + # + # spectr_augment_config = jasper_params.get("SpectrogramAugmentation", None) + # if spectr_augment_config: + # data_spectr_augmentation = 
nemo_asr.SpectrogramAugmentation( + # **spectr_augment_config + # ) + # + eval_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"]) + eval_dl_params.update(jasper_params["AudioToTextDataLayer"]["eval"]) + if args.remote_data: + eval_dl_params["rpyc_host"] = args.remote_data + del eval_dl_params["train"] + del eval_dl_params["eval"] + data_layers_eval = [] + + # if args.eval_datasets: + for eval_datasets in args.eval_datasets: + data_layer_eval = data_loader_layer( + manifest_filepath=eval_datasets, + sample_rate=sample_rate, + labels=vocab, + batch_size=args.eval_batch_size, + num_workers=cpu_per_traindl, + **eval_dl_params, + ) + + data_layers_eval.append(data_layer_eval) + # else: + # logging.warning("There were no val datasets passed") + + jasper_encoder = nemo_asr.JasperEncoder( + feat_in=jasper_params["AudioToMelSpectrogramPreprocessor"]["features"], + **jasper_params["JasperEncoder"], + ) + jasper_encoder.restore_from(args.encoder_checkpoint, local_rank=0) + + jasper_decoder = nemo_asr.JasperDecoderForCTC( + feat_in=jasper_params["JasperEncoder"]["jasper"][-1]["filters"], + num_classes=len(vocab), + ) + jasper_decoder.restore_from(args.decoder_checkpoint, local_rank=0) + + ctc_loss = nemo_asr.CTCLossNM(num_classes=len(vocab)) + + greedy_decoder = nemo_asr.GreedyCTCDecoder() + + # logging.info("================================") + # logging.info(f"Number of parameters in encoder: {jasper_encoder.num_weights}") + # logging.info(f"Number of parameters in decoder: {jasper_decoder.num_weights}") + # logging.info( + # f"Total number of parameters in model: " + # f"{jasper_decoder.num_weights + jasper_encoder.num_weights}" + # ) + # logging.info("================================") + # + # # Train DAG + # (audio_signal_t, a_sig_length_t, transcript_t, transcript_len_t) = data_layer() + # processed_signal_t, p_length_t = data_preprocessor( + # input_signal=audio_signal_t, length=a_sig_length_t + # ) + # + # if multiply_batch_config: + # ( + # processed_signal_t, + # p_length_t, + # transcript_t, + # transcript_len_t, + # ) = multiply_batch( + # in_x=processed_signal_t, + # in_x_len=p_length_t, + # in_y=transcript_t, + # in_y_len=transcript_len_t, + # ) + # + # if spectr_augment_config: + # processed_signal_t = data_spectr_augmentation(input_spec=processed_signal_t) + # + # encoded_t, encoded_len_t = jasper_encoder( + # audio_signal=processed_signal_t, length=p_length_t + # ) + # log_probs_t = jasper_decoder(encoder_output=encoded_t) + # predictions_t = greedy_decoder(log_probs=log_probs_t) + # loss_t = ctc_loss( + # log_probs=log_probs_t, + # targets=transcript_t, + # input_length=encoded_len_t, + # target_length=transcript_len_t, + # ) + # + # # Callbacks needed to print info to console and Tensorboard + # train_callback = nemo.core.SimpleLossLoggerCallback( + # tensors=[loss_t, predictions_t, transcript_t, transcript_len_t], + # print_func=partial(monitor_asr_train_progress, labels=vocab), + # get_tb_values=lambda x: [("loss", x[0])], + # tb_writer=neural_factory.tb_writer, + # ) + # + # chpt_callback = nemo.core.CheckpointCallback( + # folder=neural_factory.checkpoint_dir, + # load_from_folder=args.load_dir, + # step_freq=args.checkpoint_save_freq, + # checkpoints_to_keep=30, + # ) + # + # callbacks = [train_callback, chpt_callback] + callbacks = [] + # assemble eval DAGs + for i, eval_dl in enumerate(data_layers_eval): + (audio_signal_e, a_sig_length_e, transcript_e, transcript_len_e) = eval_dl() + processed_signal_e, p_length_e = data_preprocessor( + 
input_signal=audio_signal_e, length=a_sig_length_e + ) + encoded_e, encoded_len_e = jasper_encoder( + audio_signal=processed_signal_e, length=p_length_e + ) + log_probs_e = jasper_decoder(encoder_output=encoded_e) + predictions_e = greedy_decoder(log_probs=log_probs_e) + loss_e = ctc_loss( + log_probs=log_probs_e, + targets=transcript_e, + input_length=encoded_len_e, + target_length=transcript_len_e, + ) + + # create corresponding eval callback + tagname = os.path.basename(args.eval_datasets[i]).split(".")[0] + eval_callback = nemo.core.EvaluatorCallback( + eval_tensors=[loss_e, predictions_e, transcript_e, transcript_len_e], + user_iter_callback=partial(process_evaluation_batch, labels=vocab), + user_epochs_done_callback=partial(process_evaluation_epoch, tag=tagname), + eval_step=args.eval_freq, + tb_writer=neural_factory.tb_writer, + ) + + callbacks.append(eval_callback) + return callbacks + + +def main(): + args = parse_args() + # name = construct_name( + # args.exp_name, + # args.lr, + # args.batch_size, + # args.max_steps, + # args.num_epochs, + # args.weight_decay, + # args.optimizer, + # args.iter_per_step, + # ) + # log_dir = name + # if args.work_dir: + # log_dir = os.path.join(args.work_dir, name) + + # instantiate Neural Factory with supported backend + neural_factory = nemo.core.NeuralModuleFactory( + placement=nemo.core.DeviceType.GPU, + backend=nemo.core.Backend.PyTorch, + # local_rank=args.local_rank, + # optimization_level=args.amp_opt_level, + # log_dir=log_dir, + # checkpoint_dir=args.checkpoint_dir, + # create_tb_writer=args.create_tb_writer, + # files_to_copy=[args.model_config, __file__], + # cudnn_benchmark=args.cudnn_benchmark, + # tensorboard_dir=args.tensorboard_dir, + ) + args.num_gpus = neural_factory.world_size + + # checkpoint_dir = neural_factory.checkpoint_dir + if args.local_rank is not None: + logging.info("Doing ALL GPU") + + # build dags + callbacks = create_all_dags(args, neural_factory) + # evaluate model + neural_factory.eval(callbacks=callbacks) + + +if __name__ == "__main__": + main() diff --git a/jasper/training/__init__.py b/jasper/training/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/jasper/training/__init__.py @@ -0,0 +1 @@ + diff --git a/jasper/training/cli.py b/jasper/training/cli.py new file mode 100644 index 0000000..7ef9beb --- /dev/null +++ b/jasper/training/cli.py @@ -0,0 +1,366 @@ +# Copyright (c) 2019 NVIDIA Corporation +import argparse +import copy +import math +import os +from pathlib import Path +from functools import partial + +from ruamel.yaml import YAML + +import nemo +import nemo.collections.asr as nemo_asr +import nemo.utils.argparse as nm_argparse +from nemo.collections.asr.helpers import ( + monitor_asr_train_progress, + process_evaluation_batch, + process_evaluation_epoch, +) + +from nemo.utils.lr_policies import CosineAnnealing +from .data_loaders import RpycAudioToTextDataLayer + +logging = nemo.logging + + +def parse_args(): + parser = argparse.ArgumentParser( + parents=[nm_argparse.NemoArgParser()], + description="Jasper", + conflict_handler="resolve", + ) + parser.set_defaults( + checkpoint_dir=None, + optimizer="novograd", + batch_size=64, + eval_batch_size=64, + lr=0.002, + amp_opt_level="O1", + create_tb_writer=True, + model_config="./train/jasper10x5dr.yaml", + work_dir="./train/work", + num_epochs=300, + weight_decay=0.005, + checkpoint_save_freq=100, + eval_freq=100, + load_dir="./train/models/jasper/", + warmup_steps=3, + exp_name="jasper-speller", + ) + + # Overwrite default args + 
parser.add_argument( + "--max_steps", + type=int, + default=None, + required=False, + help="max number of steps to train", + ) + parser.add_argument( + "--num_epochs", + type=int, + required=False, + help="number of epochs to train", + ) + parser.add_argument( + "--model_config", + type=str, + required=False, + help="model configuration file: model.yaml", + ) + parser.add_argument( + "--remote_data", + type=str, + required=False, + default="", + help="remote dataloader endpoint", + ) + parser.add_argument( + "--dataset", + type=str, + required=False, + default="", + help="dataset directory containing train/test manifests", + ) + + # Create new args + parser.add_argument("--exp_name", default="Jasper", type=str) + parser.add_argument("--beta1", default=0.95, type=float) + parser.add_argument("--beta2", default=0.25, type=float) + parser.add_argument("--warmup_steps", default=0, type=int) + parser.add_argument( + "--load_dir", + default=None, + type=str, + help="directory with pre-trained checkpoint", + ) + + args = parser.parse_args() + if args.max_steps is None and args.num_epochs is None: + raise ValueError("Either max_steps or num_epochs should be provided.") + return args + + +def construct_name( + name, lr, batch_size, max_steps, num_epochs, wd, optimizer, iter_per_step +): + if max_steps is not None: + return "{0}-lr_{1}-bs_{2}-s_{3}-wd_{4}-opt_{5}-ips_{6}".format( + name, lr, batch_size, max_steps, wd, optimizer, iter_per_step + ) + else: + return "{0}-lr_{1}-bs_{2}-e_{3}-wd_{4}-opt_{5}-ips_{6}".format( + name, lr, batch_size, num_epochs, wd, optimizer, iter_per_step + ) + + +def create_all_dags(args, neural_factory): + yaml = YAML(typ="safe") + with open(args.model_config) as f: + jasper_params = yaml.load(f) + vocab = jasper_params["labels"] + sample_rate = jasper_params["sample_rate"] + + # Calculate num_workers for dataloader + total_cpus = os.cpu_count() + cpu_per_traindl = max(int(total_cpus / neural_factory.world_size), 1) + # perturb_config = jasper_params.get('perturb', None) + train_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"]) + train_dl_params.update(jasper_params["AudioToTextDataLayer"]["train"]) + del train_dl_params["train"] + del train_dl_params["eval"] + # del train_dl_params["normalize_transcripts"] + + if args.dataset: + d_path = Path(args.dataset) + if not args.train_dataset: + args.train_dataset = str(d_path / Path("train_manifest.json")) + if not args.eval_datasets: + args.eval_datasets = [str(d_path / Path("test_manifest.json"))] + + data_loader_layer = nemo_asr.AudioToTextDataLayer + + if args.remote_data: + train_dl_params["rpyc_host"] = args.remote_data + data_loader_layer = RpycAudioToTextDataLayer + + data_layer = data_loader_layer( + manifest_filepath=args.train_dataset, + sample_rate=sample_rate, + labels=vocab, + batch_size=args.batch_size, + num_workers=cpu_per_traindl, + **train_dl_params, + # normalize_transcripts=False + ) + + N = len(data_layer) + steps_per_epoch = math.ceil( + N / (args.batch_size * args.iter_per_step * args.num_gpus) + ) + logging.info("Have {0} examples to train on.".format(N)) + + data_preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor( + sample_rate=sample_rate, **jasper_params["AudioToMelSpectrogramPreprocessor"] + ) + + multiply_batch_config = jasper_params.get("MultiplyBatch", None) + if multiply_batch_config: + multiply_batch = nemo_asr.MultiplyBatch(**multiply_batch_config) + + spectr_augment_config = jasper_params.get("SpectrogramAugmentation", None) + if spectr_augment_config: + 
data_spectr_augmentation = nemo_asr.SpectrogramAugmentation( + **spectr_augment_config + ) + + eval_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"]) + eval_dl_params.update(jasper_params["AudioToTextDataLayer"]["eval"]) + if args.remote_data: + eval_dl_params["rpyc_host"] = args.remote_data + del eval_dl_params["train"] + del eval_dl_params["eval"] + data_layers_eval = [] + + if args.eval_datasets: + for eval_datasets in args.eval_datasets: + data_layer_eval = data_loader_layer( + manifest_filepath=eval_datasets, + sample_rate=sample_rate, + labels=vocab, + batch_size=args.eval_batch_size, + num_workers=cpu_per_traindl, + **eval_dl_params, + ) + + data_layers_eval.append(data_layer_eval) + else: + logging.warning("There were no val datasets passed") + + jasper_encoder = nemo_asr.JasperEncoder( + feat_in=jasper_params["AudioToMelSpectrogramPreprocessor"]["features"], + **jasper_params["JasperEncoder"], + ) + + jasper_decoder = nemo_asr.JasperDecoderForCTC( + feat_in=jasper_params["JasperEncoder"]["jasper"][-1]["filters"], + num_classes=len(vocab), + ) + + ctc_loss = nemo_asr.CTCLossNM(num_classes=len(vocab)) + + greedy_decoder = nemo_asr.GreedyCTCDecoder() + + logging.info("================================") + logging.info(f"Number of parameters in encoder: {jasper_encoder.num_weights}") + logging.info(f"Number of parameters in decoder: {jasper_decoder.num_weights}") + logging.info( + f"Total number of parameters in model: " + f"{jasper_decoder.num_weights + jasper_encoder.num_weights}" + ) + logging.info("================================") + + # Train DAG + (audio_signal_t, a_sig_length_t, transcript_t, transcript_len_t) = data_layer() + processed_signal_t, p_length_t = data_preprocessor( + input_signal=audio_signal_t, length=a_sig_length_t + ) + + if multiply_batch_config: + ( + processed_signal_t, + p_length_t, + transcript_t, + transcript_len_t, + ) = multiply_batch( + in_x=processed_signal_t, + in_x_len=p_length_t, + in_y=transcript_t, + in_y_len=transcript_len_t, + ) + + if spectr_augment_config: + processed_signal_t = data_spectr_augmentation(input_spec=processed_signal_t) + + encoded_t, encoded_len_t = jasper_encoder( + audio_signal=processed_signal_t, length=p_length_t + ) + log_probs_t = jasper_decoder(encoder_output=encoded_t) + predictions_t = greedy_decoder(log_probs=log_probs_t) + loss_t = ctc_loss( + log_probs=log_probs_t, + targets=transcript_t, + input_length=encoded_len_t, + target_length=transcript_len_t, + ) + + # Callbacks needed to print info to console and Tensorboard + train_callback = nemo.core.SimpleLossLoggerCallback( + tensors=[loss_t, predictions_t, transcript_t, transcript_len_t], + print_func=partial(monitor_asr_train_progress, labels=vocab), + get_tb_values=lambda x: [("loss", x[0])], + tb_writer=neural_factory.tb_writer, + ) + + chpt_callback = nemo.core.CheckpointCallback( + folder=neural_factory.checkpoint_dir, + load_from_folder=args.load_dir, + step_freq=args.checkpoint_save_freq, + checkpoints_to_keep=30, + ) + + callbacks = [train_callback, chpt_callback] + + # assemble eval DAGs + for i, eval_dl in enumerate(data_layers_eval): + (audio_signal_e, a_sig_length_e, transcript_e, transcript_len_e) = eval_dl() + processed_signal_e, p_length_e = data_preprocessor( + input_signal=audio_signal_e, length=a_sig_length_e + ) + encoded_e, encoded_len_e = jasper_encoder( + audio_signal=processed_signal_e, length=p_length_e + ) + log_probs_e = jasper_decoder(encoder_output=encoded_e) + predictions_e = greedy_decoder(log_probs=log_probs_e) + loss_e = 
ctc_loss( + log_probs=log_probs_e, + targets=transcript_e, + input_length=encoded_len_e, + target_length=transcript_len_e, + ) + + # create corresponding eval callback + tagname = os.path.basename(args.eval_datasets[i]).split(".")[0] + eval_callback = nemo.core.EvaluatorCallback( + eval_tensors=[loss_e, predictions_e, transcript_e, transcript_len_e], + user_iter_callback=partial(process_evaluation_batch, labels=vocab), + user_epochs_done_callback=partial(process_evaluation_epoch, tag=tagname), + eval_step=args.eval_freq, + tb_writer=neural_factory.tb_writer, + ) + + callbacks.append(eval_callback) + return loss_t, callbacks, steps_per_epoch + + +def main(): + args = parse_args() + name = construct_name( + args.exp_name, + args.lr, + args.batch_size, + args.max_steps, + args.num_epochs, + args.weight_decay, + args.optimizer, + args.iter_per_step, + ) + log_dir = name + if args.work_dir: + log_dir = os.path.join(args.work_dir, name) + + # instantiate Neural Factory with supported backend + neural_factory = nemo.core.NeuralModuleFactory( + backend=nemo.core.Backend.PyTorch, + local_rank=args.local_rank, + optimization_level=args.amp_opt_level, + log_dir=log_dir, + checkpoint_dir=args.checkpoint_dir, + create_tb_writer=args.create_tb_writer, + files_to_copy=[args.model_config, __file__], + cudnn_benchmark=args.cudnn_benchmark, + tensorboard_dir=args.tensorboard_dir, + ) + args.num_gpus = neural_factory.world_size + + checkpoint_dir = neural_factory.checkpoint_dir + if args.local_rank is not None: + logging.info("Doing ALL GPU") + + # build dags + train_loss, callbacks, steps_per_epoch = create_all_dags(args, neural_factory) + # train model + neural_factory.train( + tensors_to_optimize=[train_loss], + callbacks=callbacks, + lr_policy=CosineAnnealing( + args.max_steps + if args.max_steps is not None + else args.num_epochs * steps_per_epoch, + warmup_steps=args.warmup_steps, + ), + optimizer=args.optimizer, + optimization_params={ + "num_epochs": args.num_epochs, + "max_steps": args.max_steps, + "lr": args.lr, + "betas": (args.beta1, args.beta2), + "weight_decay": args.weight_decay, + "grad_norm_clip": None, + }, + batches_per_step=args.iter_per_step, + ) + + +if __name__ == "__main__": + main() diff --git a/jasper/training/data_loaders.py b/jasper/training/data_loaders.py new file mode 100644 index 0000000..d181dfa --- /dev/null +++ b/jasper/training/data_loaders.py @@ -0,0 +1,334 @@ +from functools import partial +import tempfile + +# from typing import Any, Dict, List, Optional + +import torch +import nemo + +# import nemo.collections.asr as nemo_asr +from nemo.backends.pytorch import DataLayerNM +from nemo.core import DeviceType + +# from nemo.core.neural_types import * +from nemo.core.neural_types import NeuralType, AudioSignal, LengthsType, LabelsType +from nemo.utils.decorators import add_port_docs + +from nemo.collections.asr.parts.dataset import ( + # AudioDataset, + # AudioLabelDataset, + # KaldiFeatureDataset, + # TranscriptDataset, + parsers, + collections, + seq_collate_fn, +) + +# from functools import lru_cache +import rpyc +from concurrent.futures import ThreadPoolExecutor +from tqdm import tqdm +from .featurizer import RpycWaveformFeaturizer + +# from nemo.collections.asr.parts.features import WaveformFeaturizer + +# from nemo.collections.asr.parts.perturb import AudioAugmentor, perturbation_types + + +logging = nemo.logging + + +class CachedAudioDataset(torch.utils.data.Dataset): + """ + Dataset that loads tensors via a json file containing paths to audio + files, transcripts, 
diff --git a/jasper/training/data_loaders.py b/jasper/training/data_loaders.py
new file mode 100644
index 0000000..d181dfa
--- /dev/null
+++ b/jasper/training/data_loaders.py
@@ -0,0 +1,334 @@
+from functools import partial
+import tempfile
+
+# from typing import Any, Dict, List, Optional
+
+import torch
+import nemo
+
+# import nemo.collections.asr as nemo_asr
+from nemo.backends.pytorch import DataLayerNM
+from nemo.core import DeviceType
+
+# from nemo.core.neural_types import *
+from nemo.core.neural_types import NeuralType, AudioSignal, LengthsType, LabelsType
+from nemo.utils.decorators import add_port_docs
+
+from nemo.collections.asr.parts.dataset import (
+    # AudioDataset,
+    # AudioLabelDataset,
+    # KaldiFeatureDataset,
+    # TranscriptDataset,
+    parsers,
+    collections,
+    seq_collate_fn,
+)
+
+# from functools import lru_cache
+import rpyc
+from concurrent.futures import ThreadPoolExecutor
+from tqdm import tqdm
+from .featurizer import RpycWaveformFeaturizer
+
+# from nemo.collections.asr.parts.features import WaveformFeaturizer
+
+# from nemo.collections.asr.parts.perturb import AudioAugmentor, perturbation_types
+
+
+logging = nemo.logging
+
+
+class CachedAudioDataset(torch.utils.data.Dataset):
+    """
+    Dataset that loads tensors via a json file containing paths to audio
+    files, transcripts, and durations (in seconds). Each new line is a
+    different sample. Example below:
+
+    {"audio_filepath": "/path/to/audio.wav", "text_filepath":
+    "/path/to/audio.txt", "duration": 23.147}
+    ...
+    {"audio_filepath": "/path/to/audio.wav", "text": "the
+    transcription", "offset": 301.75, "duration": 0.82, "utt":
+    "utterance_id", "ctm_utt": "en_4156", "side": "A"}
+
+    Args:
+        manifest_filepath: Path to manifest json as described above. Can
+            be comma-separated paths.
+        labels: String containing all the possible characters to map to
+        featurizer: Initialized featurizer class that converts paths of
+            audio to feature tensors
+        max_duration: If audio exceeds this length, do not include in dataset
+        min_duration: If audio is less than this length, do not include
+            in dataset
+        max_utts: Limit number of utterances
+        blank_index: blank character index, default = -1
+        unk_index: unk_character index, default = -1
+        normalize: whether to normalize transcript text (default: True)
+        bos_id: Id of beginning of sequence symbol to append if not None
+        eos_id: Id of end of sequence symbol to append if not None
+        load_audio: Boolean flag indicating whether or not to load audio
+    """
+
+    def __init__(
+        self,
+        manifest_filepath,
+        labels,
+        featurizer,
+        max_duration=None,
+        min_duration=None,
+        max_utts=0,
+        blank_index=-1,
+        unk_index=-1,
+        normalize=True,
+        trim=False,
+        bos_id=None,
+        eos_id=None,
+        load_audio=True,
+        parser="en",
+    ):
+        self.collection = collections.ASRAudioText(
+            manifests_files=manifest_filepath.split(","),
+            parser=parsers.make_parser(
+                labels=labels,
+                name=parser,
+                unk_id=unk_index,
+                blank_id=blank_index,
+                do_normalize=normalize,
+            ),
+            min_duration=min_duration,
+            max_duration=max_duration,
+            max_number=max_utts,
+        )
+        self.index_feature_map = {}
+
+        self.featurizer = featurizer
+        self.trim = trim
+        self.eos_id = eos_id
+        self.bos_id = bos_id
+        self.load_audio = load_audio
+        print(f"initializing dataset {manifest_filepath}")
+
+        def exec_func(i):
+            return self[i]
+
+        task_count = len(self.collection)
+        with ThreadPoolExecutor() as exe:
+            print("starting all loading tasks")
+            list(
+                tqdm(
+                    exe.map(exec_func, range(task_count)),
+                    position=0,
+                    leave=True,
+                    total=task_count,
+                )
+            )
+        print("initialization complete")
+
+    def __getitem__(self, index):
+        sample = self.collection[index]
+        if self.load_audio:
+            cached_features = self.index_feature_map.get(index)
+            if cached_features is not None:
+                features = cached_features
+            else:
+                features = self.featurizer.process(
+                    sample.audio_file,
+                    offset=0,
+                    duration=sample.duration,
+                    trim=self.trim,
+                )
+                self.index_feature_map[index] = features
+            f, fl = features, torch.tensor(features.shape[0]).long()
+        else:
+            f, fl = None, None
+
+        t, tl = sample.text_tokens, len(sample.text_tokens)
+        if self.bos_id is not None:
+            t = [self.bos_id] + t
+            tl += 1
+        if self.eos_id is not None:
+            t = t + [self.eos_id]
+            tl += 1
+
+        return f, fl, torch.tensor(t).long(), torch.tensor(tl).long()
+
+    def __len__(self):
+        return len(self.collection)
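+
+# NOTE: CachedAudioDataset eagerly featurizes every utterance during __init__ (the
+# ThreadPoolExecutor above calls self[i] for each index) and keeps the resulting
+# tensors in index_feature_map, so host memory grows with the total audio duration
+# of the manifests it is given.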
+
+
+class RpycAudioToTextDataLayer(DataLayerNM):
+    """Data Layer for general ASR tasks.
+
+    Module which reads ASR labeled data. It accepts comma-separated
+    JSON manifest files describing the correspondence between wav audio files
+    and their transcripts. JSON files should be of the following format::
+
+        {"audio_filepath": path_to_wav_0, "duration": time_in_sec_0, "text": \
+transcript_0}
+        ...
+        {"audio_filepath": path_to_wav_n, "duration": time_in_sec_n, "text": \
+transcript_n}
+
+    Args:
+        manifest_filepath (str): Dataset parameter.
+            Path to JSON containing data.
+        labels (list): Dataset parameter.
+            List of characters that can be output by the ASR model.
+            For Jasper, this is the 28 character set {a-z '}. The CTC blank
+            symbol is automatically added later for models using ctc.
+        batch_size (int): batch size
+        sample_rate (int): Target sampling rate for data. Audio files will be
+            resampled to sample_rate if it is not already.
+            Defaults to 16000.
+        int_values (bool): Bool indicating whether the audio file is saved as
+            int data or float data.
+            Defaults to False.
+        eos_id (id): Dataset parameter.
+            End of string symbol id used for seq2seq models.
+            Defaults to None.
+        min_duration (float): Dataset parameter.
+            All training files which have a duration less than min_duration
+            are dropped. Note: Duration is read from the manifest JSON.
+            Defaults to 0.1.
+        max_duration (float): Dataset parameter.
+            All training files which have a duration more than max_duration
+            are dropped. Note: Duration is read from the manifest JSON.
+            Defaults to None.
+        normalize_transcripts (bool): Dataset parameter.
+            Whether to use automatic text cleaning.
+            It is highly recommended to manually clean text for best results.
+            Defaults to True.
+        trim_silence (bool): Whether to trim silence from the beginning and end
+            of the audio signal using librosa.effects.trim().
+            Defaults to False.
+        load_audio (bool): Dataset parameter.
+            Controls whether the dataloader loads the audio signal and
+            transcript or just the transcript.
+            Defaults to True.
+        drop_last (bool): See PyTorch DataLoader.
+            Defaults to False.
+        shuffle (bool): See PyTorch DataLoader.
+            Defaults to True.
+        num_workers (int): See PyTorch DataLoader.
+            Defaults to 0.
+        perturb_config (dict): Currently disabled.
+    """
+
+    @property
+    @add_port_docs()
+    def output_ports(self):
+        """Returns definitions of module output ports.
+        """
+        return {
+            # 'audio_signal': NeuralType({0: AxisType(BatchTag), 1: AxisType(TimeTag)}),
+            # 'a_sig_length': NeuralType({0: AxisType(BatchTag)}),
+            # 'transcripts': NeuralType({0: AxisType(BatchTag), 1: AxisType(TimeTag)}),
+            # 'transcript_length': NeuralType({0: AxisType(BatchTag)}),
+            "audio_signal": NeuralType(
+                ("B", "T"),
+                AudioSignal(freq=self._sample_rate)
+                if self is not None and self._sample_rate is not None
+                else AudioSignal(),
+            ),
+            "a_sig_length": NeuralType(tuple("B"), LengthsType()),
+            "transcripts": NeuralType(("B", "T"), LabelsType()),
+            "transcript_length": NeuralType(tuple("B"), LengthsType()),
+        }
+
+    def __init__(
+        self,
+        manifest_filepath,
+        labels,
+        batch_size,
+        sample_rate=16000,
+        int_values=False,
+        bos_id=None,
+        eos_id=None,
+        pad_id=None,
+        min_duration=0.1,
+        max_duration=None,
+        normalize_transcripts=True,
+        trim_silence=False,
+        load_audio=True,
+        rpyc_host="",
+        drop_last=False,
+        shuffle=True,
+        num_workers=0,
+    ):
+        super().__init__()
+        self._sample_rate = sample_rate
+
+        def rpyc_root_fn():
+            return rpyc.connect(
+                rpyc_host, 8064, config={"sync_request_timeout": 600}
+            ).root
+
+        rpyc_conn = rpyc_root_fn()
+
+        self._featurizer = RpycWaveformFeaturizer(
+            sample_rate=self._sample_rate,
+            int_values=int_values,
+            augmentor=None,
+            rpyc_conn=rpyc_conn,
+        )
+
+        def read_remote_manifests():
+            local_mp = []
+            for mrp in manifest_filepath.split(","):
+                md = rpyc_conn.read_path(mrp)
+                mf = tempfile.NamedTemporaryFile(
+                    dir="/tmp", prefix="jasper_manifest.", delete=False
+                )
+                mf.write(md)
+                mf.close()
+                local_mp.append(mf.name)
+            return ",".join(local_mp)
+
+        local_manifest_filepath = read_remote_manifests()
+        dataset_params = {
+            "manifest_filepath": local_manifest_filepath,
+            "labels": labels,
+            "featurizer": self._featurizer,
+            "max_duration": max_duration,
+            "min_duration": min_duration,
+            "normalize": normalize_transcripts,
+            "trim": trim_silence,
+            "bos_id": bos_id,
+            "eos_id": eos_id,
+            "load_audio": load_audio,
+        }
+
+        self._dataset = CachedAudioDataset(**dataset_params)
+        self._batch_size = batch_size
+
+        # Set up data loader
+        if self._placement == DeviceType.AllGpu:
+            logging.info("Parallelizing Datalayer.")
+            sampler = torch.utils.data.distributed.DistributedSampler(self._dataset)
+        else:
+            sampler = None
+
+        if batch_size == -1:
+            batch_size = len(self._dataset)
+
+        pad_id = 0 if pad_id is None else pad_id
+        self._dataloader = torch.utils.data.DataLoader(
+            dataset=self._dataset,
+            batch_size=batch_size,
+            collate_fn=partial(seq_collate_fn, token_pad_value=pad_id),
+            drop_last=drop_last,
+            shuffle=shuffle if sampler is None else False,
+            sampler=sampler,
+            num_workers=1,
+        )
+
+    def __len__(self):
+        return len(self._dataset)
+
+    @property
+    def dataset(self):
+        return None
+
+    @property
+    def data_iterator(self):
+        return self._dataloader
diff --git a/jasper/training/featurizer.py b/jasper/training/featurizer.py
new file mode 100644
index 0000000..030eb36
--- /dev/null
+++ b/jasper/training/featurizer.py
@@ -0,0 +1,51 @@
+# import math
+
+# import librosa
+import torch
+import pickle
+# import torch.nn as nn
+# from torch_stft import STFT
+
+# from nemo import logging
+from nemo.collections.asr.parts.perturb import AudioAugmentor
+# from nemo.collections.asr.parts.segment import AudioSegment
+
+
+class RpycWaveformFeaturizer(object):
+    def __init__(
+        self, sample_rate=16000, int_values=False, augmentor=None, rpyc_conn=None
+    ):
+        self.augmentor = augmentor if augmentor is not None else AudioAugmentor()
+        self.sample_rate = sample_rate
+        self.int_values = int_values
+        self.remote_path_samples = rpyc_conn.get_path_samples
+
+    def max_augmentation_length(self, length):
+        return self.augmentor.max_augmentation_length(length)
+
+    def process(self, file_path, offset=0, duration=0, trim=False):
+        audio = self.remote_path_samples(
+            file_path,
+            target_sr=self.sample_rate,
+            int_values=self.int_values,
+            offset=offset,
+            duration=duration,
+            trim=trim,
+        )
+        return torch.tensor(pickle.loads(audio), dtype=torch.float)
+
+    def process_segment(self, audio_segment):
+        self.augmentor.perturb(audio_segment)
+        return torch.tensor(audio_segment, dtype=torch.float)
+
+    @classmethod
+    def from_config(cls, input_config, perturbation_configs=None):
+        if perturbation_configs is not None:
+            aa = AudioAugmentor.from_config(perturbation_configs)
+        else:
+            aa = None
+
+        sample_rate = input_config.get("sample_rate", 16000)
+        int_values = input_config.get("int_values", False)
+
+        return cls(sample_rate=sample_rate, int_values=int_values, augmentor=aa)
diff --git a/setup.py b/setup.py
index b18df24..cfb4c03 100644
--- a/setup.py
+++ b/setup.py
@@ -1,11 +1,52 @@
-from setuptools import setup
+from setuptools import setup, find_packages

 requirements = [
     "ruamel.yaml",
+    "torch==1.4.0",
+    "torchvision==0.5.0",
     "nemo_toolkit[asr] @ git+https://github.com/NVIDIA/NeMo.git@09e3ba4dfe333f86d6c5c1048e07210924294be9#egg=nemo_toolkit",
 ]
-extra_requirements = {"server": ["rpyc==4.1.4"]}
+extra_requirements = {
+    "server": ["rpyc~=4.1.4", "tqdm~=4.39.0"],
+    "data": [
+        "google-cloud-texttospeech~=1.0.1",
+        "tqdm~=4.39.0",
+        "pydub~=0.24.0",
+        "scikit_learn~=0.22.1",
+        "pandas~=1.0.3",
+        "boto3~=1.12.35",
+        "ruamel.yaml==0.16.10",
+        "pymongo==3.10.1",
+        "librosa==0.7.2",
+        "matplotlib==3.2.1",
+        "pandas==1.0.3",
+        "tabulate==0.8.7",
+        "natural==0.2.0",
+        "num2words==0.5.10",
+        "typer[all]==0.1.1",
+        "python-slugify==4.0.0",
+        "lenses @ git+https://github.com/ingolemo/python-lenses.git@b2a2a9aa5b61540992d70b2cf36008d0121e8948#egg=lenses",
+    ],
+    "validation": [
+        "rpyc~=4.1.4",
+        "pymongo==3.10.1",
+        "typer[all]==0.1.1",
+        "tqdm~=4.39.0",
+        "librosa==0.7.2",
+        "matplotlib==3.2.1",
+        "pydub~=0.24.0",
+        "streamlit==0.58.0",
+        "natural==0.2.0",
+        "stringcase==1.2.0",
+        "google-cloud-speech~=1.3.1",
+    ]
+    # "train": [
+    #     "torchaudio==0.5.0",
+    #     "torch-stft==0.1.4",
+    # ]
+}
+packages = find_packages()

 setup(
     name="jasper-asr",
@@ -17,11 +58,24 @@ setup(
     license="MIT",
     install_requires=requirements,
     extras_require=extra_requirements,
-    packages=["."],
+    packages=packages,
     entry_points={
         "console_scripts": [
             "jasper_transcribe = jasper.transcribe:main",
-            "jasper_asr_rpyc_server = jasper.server:main",
+            "jasper_server = jasper.server:main",
+            "jasper_trainer = jasper.training.cli:main",
+            "jasper_evaluator = jasper.evaluate:main",
+            "jasper_data_tts_generate = jasper.data.tts_generator:main",
+            "jasper_data_conv_generate = jasper.data.conv_generator:main",
+            "jasper_data_nlu_generate = jasper.data.nlu_generator:main",
+            "jasper_data_test_generate = jasper.data.test_generator:main",
+            "jasper_data_call_recycle = jasper.data.call_recycler:main",
+            "jasper_data_asr_recycle = jasper.data.asr_recycler:main",
+            "jasper_data_rev_recycle = jasper.data.rev_recycler:main",
+            "jasper_data_server = jasper.data.server:main",
+            "jasper_data_validation = jasper.data.validation.process:main",
+            "jasper_data_preprocess = jasper.data.process:main",
+            "jasper_data_slu_evaluate = jasper.data.slu_evaluator:main",
         ]
     },
     zip_safe=False,
diff --git a/validation_ui.py b/validation_ui.py
new file mode 100644
index 0000000..b45692e
--- /dev/null
+++ b/validation_ui.py
@@ -0,0 +1,3 @@
+import runpy
+
+runpy.run_module("jasper.data.validation.ui", run_name="__main__", alter_sys=True)
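
> The remote data service that `RpycAudioToTextDataLayer` and `RpycWaveformFeaturizer` talk to (`jasper.data.server`, wired up above as the `jasper_data_server` entry point) is not included in this diff. A minimal sketch of a service that would satisfy the two calls the client makes — `read_path` for manifest bytes and `get_path_samples` for pickled sample arrays on port 8064 — follows; the method names, port, and argument list are taken from the client code, everything else is an assumption:
```
# Hypothetical sketch only -- not the actual jasper.data.server implementation.
import pickle

import rpyc
from rpyc.utils.server import ThreadedServer
from nemo.collections.asr.parts.segment import AudioSegment


class DataService(rpyc.Service):
    def exposed_read_path(self, path):
        # Serve manifest (and other) files as raw bytes to the training host.
        with open(path, "rb") as f:
            return f.read()

    def exposed_get_path_samples(
        self, path, target_sr=16000, int_values=False, offset=0, duration=0, trim=False
    ):
        # Load and resample audio the way a local WaveformFeaturizer would,
        # then pickle the sample array so the client can rebuild a tensor.
        segment = AudioSegment.from_file(
            path,
            target_sr=target_sr,
            int_values=int_values,
            offset=offset,
            duration=duration,
            trim=trim,
        )
        return pickle.dumps(segment.samples)


if __name__ == "__main__":
    # 8064 matches the port hard-coded in RpycAudioToTextDataLayer.
    ThreadedServer(DataService, port=8064).start()
```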
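> The `validation_ui.py` shim at the repository root presumably exists so the Streamlit-based validation UI (note the `streamlit` pin in the `validation` extra) can be launched from a plain script path, e.g. `pip install -e ".[validation]"` followed by `streamlit run validation_ui.py`.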