From ed6117559a26b3c2dc9f08542b8da47a7b087b90 Mon Sep 17 00:00:00 2001 From: Malar Kannan Date: Tue, 23 Feb 2021 19:43:33 +0530 Subject: [PATCH] massive refactor/rename to plume --- Notes.md | 14 + README.md | 22 +- jasper/__init__.py | 1 - jasper/client.py | 21 - jasper/data/__init__.py | 1 - jasper/data/process.py | 77 -- jasper/data/rastrik_recycler.py | 93 --- jasper/data/utils.py | 241 ------ jasper/data/validation/__init__.py | 1 - jasper/data/validation/process.py | 398 ---------- jasper/server.py | 57 -- jasper/training/__init__.py | 1 - jasper/transcribe.py | 22 - jasper/utils.py | 40 - plume/cli/__init__.py | 23 + plume/cli/data/__init__.py | 339 ++++++++ plume/cli/data/generate.py | 12 + plume/cli/eval.py | 5 + plume/cli/serve.py | 7 + plume/cli/train.py | 5 + plume/models/__init__.py | 1 + plume/models/jasper/__init__.py | 0 {jasper => plume/models/jasper}/asr.py | 0 plume/models/jasper/data.py | 24 + .../models/jasper}/data_loaders.py | 0 .../models/jasper/eval.py | 2 +- .../models/jasper}/featurizer.py | 0 plume/models/jasper/serve.py | 52 ++ .../models/jasper/serve_data.py | 0 .../cli.py => plume/models/jasper/train.py | 0 plume/models/matchboxnet/__init__.py | 0 plume/models/wav2vec2/__init__.py | 0 plume/models/wav2vec2/asr.py | 204 +++++ plume/models/wav2vec2/data.py | 86 ++ plume/models/wav2vec2/eval.py | 49 ++ plume/models/wav2vec2/serve.py | 53 ++ plume/models/wav2vec2/train.py | 34 + plume/ui/__init__.py | 64 ++ .../ui.py => plume/ui/annotation.py | 48 +- plume/ui/preview.py | 58 ++ .../data/validation => plume/ui}/st_rerun.py | 19 +- plume/utils/__init__.py | 486 ++++++++++++ plume/utils/align.py | 117 +++ plume/utils/audio.py | 28 + plume/utils/lazy_import.py | 737 ++++++++++++++++++ plume/utils/lazy_loader.py | 46 ++ plume/utils/serve.py | 31 + plume/utils/transcribe.py | 184 +++++ plume/utils/tts.py | 92 +++ setup.py | 103 ++- validation_ui.py | 3 - 51 files changed, 2864 insertions(+), 1037 deletions(-) delete mode 100644 jasper/__init__.py delete mode 100644 jasper/client.py delete mode 100644 jasper/data/__init__.py delete mode 100644 jasper/data/process.py delete mode 100644 jasper/data/rastrik_recycler.py delete mode 100644 jasper/data/utils.py delete mode 100644 jasper/data/validation/__init__.py delete mode 100644 jasper/data/validation/process.py delete mode 100644 jasper/server.py delete mode 100644 jasper/training/__init__.py delete mode 100644 jasper/transcribe.py delete mode 100644 jasper/utils.py create mode 100644 plume/cli/__init__.py create mode 100644 plume/cli/data/__init__.py create mode 100644 plume/cli/data/generate.py create mode 100644 plume/cli/eval.py create mode 100644 plume/cli/serve.py create mode 100644 plume/cli/train.py create mode 100644 plume/models/__init__.py create mode 100644 plume/models/jasper/__init__.py rename {jasper => plume/models/jasper}/asr.py (100%) create mode 100644 plume/models/jasper/data.py rename {jasper/training => plume/models/jasper}/data_loaders.py (100%) rename jasper/evaluate.py => plume/models/jasper/eval.py (99%) rename {jasper/training => plume/models/jasper}/featurizer.py (100%) create mode 100644 plume/models/jasper/serve.py rename jasper/data/server.py => plume/models/jasper/serve_data.py (100%) rename jasper/training/cli.py => plume/models/jasper/train.py (100%) create mode 100644 plume/models/matchboxnet/__init__.py create mode 100644 plume/models/wav2vec2/__init__.py create mode 100644 plume/models/wav2vec2/asr.py create mode 100644 plume/models/wav2vec2/data.py create mode 100644 
plume/models/wav2vec2/eval.py create mode 100644 plume/models/wav2vec2/serve.py create mode 100644 plume/models/wav2vec2/train.py create mode 100644 plume/ui/__init__.py rename jasper/data/validation/ui.py => plume/ui/annotation.py (78%) create mode 100644 plume/ui/preview.py rename {jasper/data/validation => plume/ui}/st_rerun.py (63%) create mode 100644 plume/utils/__init__.py create mode 100644 plume/utils/align.py create mode 100644 plume/utils/audio.py create mode 100644 plume/utils/lazy_import.py create mode 100644 plume/utils/lazy_loader.py create mode 100644 plume/utils/serve.py create mode 100644 plume/utils/transcribe.py create mode 100644 plume/utils/tts.py delete mode 100644 validation_ui.py diff --git a/Notes.md b/Notes.md index 4195305..090056e 100644 --- a/Notes.md +++ b/Notes.md @@ -3,3 +3,17 @@ ``` diff <(cat data/asr_data/call_upwork_test_cnd_*/manifest.json |sort) <(cat data/asr_data/call_upwork_test_cnd/manifest.json |sort) ``` + +> Prepare Augmented Data +``` +plume data filter /dataset/png_entities/png_numbers_2020_07/ /dataset/png_entities/png_numbers_2020_07_skip1hour/ + +plume data augment /dataset/agara_slu/call_alphanum_ag_sg_v1_abs/ /dataset/png_entities/png_numbers_2020_07_1hour_noblank/ /dataset/png_entities/png_numbers_2020_07_skip1hour/ /dataset/png_entities/aug_pngskip1hour-agsgalnum-1hournoblank/ + +plume data filter --kind transform_digits /dataset/agara_slu/png1hour-agsgalnum-1hournoblank/ /dataset/agara_slu/png1hour-agsgalnum-1hournoblank_prep/ +``` + + +``` +KENLM_INC=/usr/local/include/kenlm/ pip install -e ../deps/wav2letter/bindings/python/ +``` diff --git a/README.md b/README.md index d95808d..c21365a 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,8 @@ -# Jasper ASR +# Plume ASR [![image](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/python/black) -> Generates text from speech audio +> Generates text from audio containing speech --- # Table of Contents @@ -20,7 +20,7 @@ # Features * ASR using Jasper (from [NemoToolkit](https://github.com/NVIDIA/NeMo) ) - +* ASR using Wav2Vec2 (from [fairseq](https://github.com/pytorch/fairseq) ) # Installation To install the packages and its dependencies run. @@ -29,14 +29,26 @@ python setup.py install ``` or with pip ```bash -pip install .[server] +pip install .[all] ``` The installation should work on Python 3.6 or newer. 
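+For development, an editable install with the same extras should also work (a
+sketch, assuming the `all` extra used above is defined in `setup.py`):
+```bash
+pip install -e ".[all]"
+```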
Untested on Python 2.7 # Usage +### Library +> Jasper ```python -from jasper.asr import JasperASR +from plume.models.jasper.asr import JasperASR asr_model = JasperASR("/path/to/model_config_yaml","/path/to/encoder_checkpoint","/path/to/decoder_checkpoint") # Loads the models TEXT = asr_model.transcribe(wav_data) # Returns the text spoken in the wav ``` +> Wav2Vec2 +```python +from plume.models.wav2vec2.asr import Wav2Vec2ASR +asr_model = Wav2Vec2ASR("/path/to/ctc_checkpoint","/path/to/w2v_checkpoint","/path/to/target_dictionary") # Loads the models +TEXT = asr_model.transcribe(wav_data) # Returns the text spoken in the wav +``` +### Command Line +``` +$ plume +``` diff --git a/jasper/__init__.py b/jasper/__init__.py deleted file mode 100644 index 8b13789..0000000 --- a/jasper/__init__.py +++ /dev/null @@ -1 +0,0 @@ - diff --git a/jasper/client.py b/jasper/client.py deleted file mode 100644 index 6c474a5..0000000 --- a/jasper/client.py +++ /dev/null @@ -1,21 +0,0 @@ -import os -import logging -import rpyc -from functools import lru_cache - -logging.basicConfig( - level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" -) -logger = logging.getLogger(__name__) - - -ASR_HOST = os.environ.get("JASPER_ASR_RPYC_HOST", "localhost") -ASR_PORT = int(os.environ.get("JASPER_ASR_RPYC_PORT", "8045")) - - -@lru_cache() -def transcribe_gen(asr_host=ASR_HOST, asr_port=ASR_PORT): - logger.info(f"connecting to asr server at {asr_host}:{asr_port}") - asr = rpyc.connect(asr_host, asr_port).root - logger.info(f"connected to asr server successfully") - return asr.transcribe diff --git a/jasper/data/__init__.py b/jasper/data/__init__.py deleted file mode 100644 index 8b13789..0000000 --- a/jasper/data/__init__.py +++ /dev/null @@ -1 +0,0 @@ - diff --git a/jasper/data/process.py b/jasper/data/process.py deleted file mode 100644 index 472c5bf..0000000 --- a/jasper/data/process.py +++ /dev/null @@ -1,77 +0,0 @@ -import json -from pathlib import Path -from sklearn.model_selection import train_test_split -from .utils import asr_manifest_reader, asr_manifest_writer -from typing import List -from itertools import chain -import typer - -app = typer.Typer() - - -@app.command() -def fixate_data(dataset_path: Path): - manifest_path = dataset_path / Path("manifest.json") - real_manifest_path = dataset_path / Path("abs_manifest.json") - - def fix_path(): - for i in asr_manifest_reader(manifest_path): - i["audio_filepath"] = str(dataset_path / Path(i["audio_filepath"])) - yield i - - asr_manifest_writer(real_manifest_path, fix_path()) - - -@app.command() -def augment_data(src_dataset_paths: List[Path], dest_dataset_path: Path): - reader_list = [] - abs_manifest_path = Path("abs_manifest.json") - for dataset_path in src_dataset_paths: - manifest_path = dataset_path / abs_manifest_path - reader_list.append(asr_manifest_reader(manifest_path)) - dest_dataset_path.mkdir(parents=True, exist_ok=True) - dest_manifest_path = dest_dataset_path / abs_manifest_path - asr_manifest_writer(dest_manifest_path, chain(*reader_list)) - - -@app.command() -def split_data(dataset_path: Path, test_size: float = 0.1): - manifest_path = dataset_path / Path("abs_manifest.json") - asr_data = list(asr_manifest_reader(manifest_path)) - train_data, test_data = train_test_split(asr_data, test_size=test_size) - asr_manifest_writer(manifest_path.with_name("train_manifest.json"), train_data) - asr_manifest_writer(manifest_path.with_name("test_manifest.json"), test_data) - - -@app.command() -def validate_data(dataset_path: Path): - 
from natural.date import compress - from datetime import timedelta - - for mf_type in ["train_manifest.json", "test_manifest.json"]: - data_file = dataset_path / Path(mf_type) - print(f"validating {data_file}.") - with Path(data_file).open("r") as pf: - data_jsonl = pf.readlines() - duration = 0 - for (i, s) in enumerate(data_jsonl): - try: - d = json.loads(s) - duration += d["duration"] - audio_file = data_file.parent / Path(d["audio_filepath"]) - if not audio_file.exists(): - raise OSError(f"File {audio_file} not found") - except BaseException as e: - print(f'failed on {i} with "{e}"') - duration_str = compress(timedelta(seconds=duration), pad=" ") - print( - f"no errors found. seems like a valid {mf_type}. contains {duration_str}sec of audio" - ) - - -def main(): - app() - - -if __name__ == "__main__": - main() diff --git a/jasper/data/rastrik_recycler.py b/jasper/data/rastrik_recycler.py deleted file mode 100644 index 6093b18..0000000 --- a/jasper/data/rastrik_recycler.py +++ /dev/null @@ -1,93 +0,0 @@ -from rastrik.proto.callrecord_pb2 import CallRecord -import gzip -from pydub import AudioSegment -from .utils import ui_dump_manifest_writer, strip_silence - -import typer -from itertools import chain -from io import BytesIO -from pathlib import Path - -app = typer.Typer() - - -@app.command() -def extract_manifest( - call_log_dir: Path = Path("./data/call_audio"), - output_dir: Path = Path("./data"), - dataset_name: str = "grassroot_pizzahut_v1", - caller_name: str = "grassroot", - verbose: bool = False, -): - call_asr_data: Path = output_dir / Path("asr_data") - call_asr_data.mkdir(exist_ok=True, parents=True) - - def wav_pb2_generator(log_dir): - for wav_path in log_dir.glob("**/*.wav"): - if verbose: - typer.echo(f"loading events for file {wav_path}") - call_wav = AudioSegment.from_file_using_temporary_files(wav_path) - meta_path = wav_path.with_suffix(".pb2.gz") - yield call_wav, wav_path, meta_path - - def read_event(call_wav, log_file): - call_wav_0, call_wav_1 = call_wav.split_to_mono() - with gzip.open(log_file, "rb") as log_h: - record_data = log_h.read() - cr = CallRecord() - cr.ParseFromString(record_data) - - first_audio_event_timestamp = next( - ( - i - for i in cr.events - if i.WhichOneof("event_type") == "call_event" - and i.call_event.WhichOneof("event_type") == "call_audio" - ) - ).timestamp.ToDatetime() - - speech_events = [ - i - for i in cr.events - if i.WhichOneof("event_type") == "speech_event" - and i.speech_event.WhichOneof("event_type") == "asr_final" - ] - previous_event_timestamp = ( - first_audio_event_timestamp - first_audio_event_timestamp - ) - for index, each_speech_events in enumerate(speech_events): - asr_final = each_speech_events.speech_event.asr_final - speech_timestamp = each_speech_events.timestamp.ToDatetime() - actual_timestamp = speech_timestamp - first_audio_event_timestamp - start_time = previous_event_timestamp.total_seconds() * 1000 - end_time = actual_timestamp.total_seconds() * 1000 - audio_segment = strip_silence(call_wav_1[start_time:end_time]) - - code_fb = BytesIO() - audio_segment.export(code_fb, format="wav") - wav_data = code_fb.getvalue() - previous_event_timestamp = actual_timestamp - duration = (end_time - start_time) / 1000 - yield asr_final, duration, wav_data, "grassroot", audio_segment - - def generate_call_asr_data(): - full_data = [] - total_duration = 0 - for wav, wav_path, pb2_path in wav_pb2_generator(call_log_dir): - asr_data = read_event(wav, pb2_path) - total_duration += wav.duration_seconds - 
full_data.append(asr_data) - n_calls = len(full_data) - typer.echo(f"loaded {n_calls} calls of duration {total_duration}s") - n_dps = ui_dump_manifest_writer(call_asr_data, dataset_name, chain(*full_data)) - typer.echo(f"written {n_dps} data points") - - generate_call_asr_data() - - -def main(): - app() - - -if __name__ == "__main__": - main() diff --git a/jasper/data/utils.py b/jasper/data/utils.py deleted file mode 100644 index 300a2da..0000000 --- a/jasper/data/utils.py +++ /dev/null @@ -1,241 +0,0 @@ -import io -import os -import json -import wave -from pathlib import Path -from functools import partial -from uuid import uuid4 -from concurrent.futures import ThreadPoolExecutor - -import pymongo -from slugify import slugify -from jasper.client import transcribe_gen -from nemo.collections.asr.metrics import word_error_rate -import matplotlib.pyplot as plt -import librosa -import librosa.display -from tqdm import tqdm - - -def manifest_str(path, dur, text): - return ( - json.dumps({"audio_filepath": path, "duration": round(dur, 1), "text": text}) - + "\n" - ) - - -def wav_bytes(audio_bytes, frame_rate=24000): - wf_b = io.BytesIO() - with wave.open(wf_b, mode="w") as wf: - wf.setnchannels(1) - wf.setframerate(frame_rate) - wf.setsampwidth(2) - wf.writeframesraw(audio_bytes) - return wf_b.getvalue() - - -def tscript_uuid_fname(transcript): - return str(uuid4()) + "_" + slugify(transcript, max_length=8) - - -def asr_data_writer(output_dir, dataset_name, asr_data_source, verbose=False): - dataset_dir = output_dir / Path(dataset_name) - (dataset_dir / Path("wav")).mkdir(parents=True, exist_ok=True) - asr_manifest = dataset_dir / Path("manifest.json") - num_datapoints = 0 - with asr_manifest.open("w") as mf: - print(f"writing manifest to {asr_manifest}") - for transcript, audio_dur, wav_data in asr_data_source: - fname = tscript_uuid_fname(transcript) - audio_file = dataset_dir / Path("wav") / Path(fname).with_suffix(".wav") - audio_file.write_bytes(wav_data) - rel_data_path = audio_file.relative_to(dataset_dir) - manifest = manifest_str(str(rel_data_path), audio_dur, transcript) - mf.write(manifest) - if verbose: - print(f"writing '{transcript}' of duration {audio_dur}") - num_datapoints += 1 - return num_datapoints - - -def ui_data_generator(output_dir, dataset_name, asr_data_source, verbose=False): - dataset_dir = output_dir / Path(dataset_name) - (dataset_dir / Path("wav")).mkdir(parents=True, exist_ok=True) - (dataset_dir / Path("wav_plots")).mkdir(parents=True, exist_ok=True) - - def data_fn( - transcript, - audio_dur, - wav_data, - caller_name, - aud_seg, - fname, - audio_path, - num_datapoints, - rel_data_path, - ): - pretrained_result = transcriber_pretrained(aud_seg.raw_data) - pretrained_wer = word_error_rate([transcript], [pretrained_result]) - png_path = Path(fname).with_suffix(".png") - wav_plot_path = dataset_dir / Path("wav_plots") / png_path - if not wav_plot_path.exists(): - plot_seg(wav_plot_path, audio_path) - return { - "audio_filepath": str(rel_data_path), - "duration": round(audio_dur, 1), - "text": transcript, - "real_idx": num_datapoints, - "audio_path": audio_path, - "spoken": transcript, - "caller": caller_name, - "utterance_id": fname, - "pretrained_asr": pretrained_result, - "pretrained_wer": pretrained_wer, - "plot_path": str(wav_plot_path), - } - - num_datapoints = 0 - data_funcs = [] - transcriber_pretrained = transcribe_gen(asr_port=8044) - for transcript, audio_dur, wav_data, caller_name, aud_seg in asr_data_source: - fname = str(uuid4()) + "_" + 
slugify(transcript, max_length=8) - audio_file = dataset_dir / Path("wav") / Path(fname).with_suffix(".wav") - audio_file.write_bytes(wav_data) - audio_path = str(audio_file) - rel_data_path = audio_file.relative_to(dataset_dir) - data_funcs.append( - partial( - data_fn, - transcript, - audio_dur, - wav_data, - caller_name, - aud_seg, - fname, - audio_path, - num_datapoints, - rel_data_path, - ) - ) - num_datapoints += 1 - ui_data = parallel_apply(lambda x: x(), data_funcs) - return ui_data, num_datapoints - - -def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=False): - dataset_dir = output_dir / Path(dataset_name) - dump_data, num_datapoints = ui_data_generator( - output_dir, dataset_name, asr_data_source, verbose=verbose - ) - - asr_manifest = dataset_dir / Path("manifest.json") - with asr_manifest.open("w") as mf: - print(f"writing manifest to {asr_manifest}") - for d in dump_data: - rel_data_path = d["audio_filepath"] - audio_dur = d["duration"] - transcript = d["text"] - manifest = manifest_str(str(rel_data_path), audio_dur, transcript) - mf.write(manifest) - - ui_dump_file = dataset_dir / Path("ui_dump.json") - ExtendedPath(ui_dump_file).write_json({"data": dump_data}) - return num_datapoints - - -def asr_manifest_reader(data_manifest_path: Path): - print(f"reading manifest from {data_manifest_path}") - with data_manifest_path.open("r") as pf: - data_jsonl = pf.readlines() - data_data = [json.loads(v) for v in data_jsonl] - for p in data_data: - p["audio_path"] = data_manifest_path.parent / Path(p["audio_filepath"]) - p["text"] = p["text"].strip() - yield p - - -def asr_manifest_writer(asr_manifest_path: Path, manifest_str_source): - with asr_manifest_path.open("w") as mf: - print(f"opening {asr_manifest_path} for writing manifest") - for mani_dict in manifest_str_source: - manifest = manifest_str( - mani_dict["audio_filepath"], mani_dict["duration"], mani_dict["text"] - ) - mf.write(manifest) - - -def asr_test_writer(out_file_path: Path, source): - def dd_str(dd, idx): - path = dd["audio_filepath"] - # dur = dd["duration"] - # return f"SAY {idx}\nPAUSE 3\nPLAY {path}\nPAUSE 3\n\n" - return f"PAUSE 2\nPLAY {path}\nPAUSE 60\n\n" - - res_file = out_file_path.with_suffix(".result.json") - with out_file_path.open("w") as of: - print(f"opening {out_file_path} for writing test") - results = [] - idx = 0 - for ui_dd in source: - results.append(ui_dd) - out_str = dd_str(ui_dd, idx) - of.write(out_str) - idx += 1 - of.write("DO_HANGUP\n") - ExtendedPath(res_file).write_json(results) - - -def batch(iterable, n=1): - ls = len(iterable) - return [iterable[ndx : min(ndx + n, ls)] for ndx in range(0, ls, n)] - - -class ExtendedPath(type(Path())): - """docstring for ExtendedPath.""" - - def read_json(self): - print(f"reading json from {self}") - with self.open("r") as jf: - return json.load(jf) - - def write_json(self, data): - print(f"writing json to {self}") - self.parent.mkdir(parents=True, exist_ok=True) - with self.open("w") as jf: - return json.dump(data, jf, indent=2) - - -def get_mongo_conn(host="", port=27017, db="test", col="calls"): - mongo_host = host if host else os.environ.get("MONGO_HOST", "localhost") - mongo_uri = f"mongodb://{mongo_host}:{port}/" - return pymongo.MongoClient(mongo_uri)[db][col] - - -def strip_silence(sound): - from pydub.silence import detect_leading_silence - - start_trim = detect_leading_silence(sound) - end_trim = detect_leading_silence(sound.reverse()) - duration = len(sound) - return sound[start_trim : duration - end_trim] - - 
-def plot_seg(wav_plot_path, audio_path): - fig = plt.Figure() - ax = fig.add_subplot() - (y, sr) = librosa.load(audio_path) - librosa.display.waveplot(y=y, sr=sr, ax=ax) - with wav_plot_path.open("wb") as wav_plot_f: - fig.set_tight_layout(True) - fig.savefig(wav_plot_f, format="png", dpi=50) - - -def parallel_apply(fn, iterable, workers=8): - with ThreadPoolExecutor(max_workers=workers) as exe: - print(f"parallelly applying {fn}") - return [ - res - for res in tqdm( - exe.map(fn, iterable), position=0, leave=True, total=len(iterable) - ) - ] diff --git a/jasper/data/validation/__init__.py b/jasper/data/validation/__init__.py deleted file mode 100644 index 8b13789..0000000 --- a/jasper/data/validation/__init__.py +++ /dev/null @@ -1 +0,0 @@ - diff --git a/jasper/data/validation/process.py b/jasper/data/validation/process.py deleted file mode 100644 index 619113b..0000000 --- a/jasper/data/validation/process.py +++ /dev/null @@ -1,398 +0,0 @@ -import json -import shutil -from pathlib import Path - -import typer - -from ..utils import ( - ExtendedPath, - asr_manifest_reader, - asr_manifest_writer, - tscript_uuid_fname, - get_mongo_conn, - plot_seg, -) - -app = typer.Typer() - - -def preprocess_datapoint(idx, rel_root, sample): - from pydub import AudioSegment - from nemo.collections.asr.metrics import word_error_rate - from jasper.client import transcribe_gen - - try: - res = dict(sample) - res["real_idx"] = idx - audio_path = rel_root / Path(sample["audio_filepath"]) - res["audio_path"] = str(audio_path) - res["utterance_id"] = audio_path.stem - transcriber_pretrained = transcribe_gen(asr_port=8044) - - aud_seg = ( - AudioSegment.from_file_using_temporary_files(audio_path) - .set_channels(1) - .set_sample_width(2) - .set_frame_rate(24000) - ) - res["pretrained_asr"] = transcriber_pretrained(aud_seg.raw_data) - res["pretrained_wer"] = word_error_rate([res["text"]], [res["pretrained_asr"]]) - wav_plot_path = ( - rel_root / Path("wav_plots") / Path(audio_path.name).with_suffix(".png") - ) - if not wav_plot_path.exists(): - plot_seg(wav_plot_path, audio_path) - res["plot_path"] = str(wav_plot_path) - return res - except BaseException as e: - print(f'failed on {idx}: {sample["audio_filepath"]} with {e}') - - -@app.command() -def dump_ui( - data_name: str = typer.Option("dataname", show_default=True), - dataset_dir: Path = Path("./data/asr_data"), - dump_dir: Path = Path("./data/valiation_data"), - dump_fname: Path = typer.Option(Path("ui_dump.json"), show_default=True), -): - from io import BytesIO - from pydub import AudioSegment - from ..utils import ui_data_generator - - data_manifest_path = dataset_dir / Path(data_name) / Path("manifest.json") - plot_dir = data_manifest_path.parent / Path("wav_plots") - plot_dir.mkdir(parents=True, exist_ok=True) - typer.echo(f"Using data manifest:{data_manifest_path}") - - def asr_data_source_gen(): - with data_manifest_path.open("r") as pf: - data_jsonl = pf.readlines() - for v in data_jsonl: - sample = json.loads(v) - rel_root = data_manifest_path.parent - res = dict(sample) - audio_path = rel_root / Path(sample["audio_filepath"]) - audio_segment = ( - AudioSegment.from_file_using_temporary_files(audio_path) - .set_channels(1) - .set_sample_width(2) - .set_frame_rate(24000) - ) - wav_plot_path = ( - rel_root - / Path("wav_plots") - / Path(audio_path.name).with_suffix(".png") - ) - if not wav_plot_path.exists(): - plot_seg(wav_plot_path, audio_path) - res["plot_path"] = str(wav_plot_path) - code_fb = BytesIO() - audio_segment.export(code_fb, format="wav") 
- wav_data = code_fb.getvalue() - duration = audio_segment.duration_seconds - asr_final = res["text"] - yield asr_final, duration, wav_data, "caller", audio_segment - - dump_data, num_datapoints = ui_data_generator( - dataset_dir, data_name, asr_data_source_gen() - ) - ui_dump_file = dataset_dir / Path("ui_dump.json") - ExtendedPath(ui_dump_file).write_json({"data": dump_data}) - - -@app.command() -def sample_ui( - data_name: str = typer.Option("dataname", show_default=True), - dump_dir: Path = Path("./data/asr_data"), - dump_file: Path = Path("ui_dump.json"), - sample_count: int = typer.Option(80, show_default=True), - sample_file: Path = Path("sample_dump.json"), -): - import pandas as pd - - processed_data_path = dump_dir / Path(data_name) / dump_file - sample_path = dump_dir / Path(data_name) / sample_file - processed_data = ExtendedPath(processed_data_path).read_json() - df = pd.DataFrame(processed_data["data"]) - samples_per_caller = sample_count // len(df["caller"].unique()) - caller_samples = pd.concat( - [g.sample(samples_per_caller) for (c, g) in df.groupby("caller")] - ) - caller_samples = caller_samples.reset_index(drop=True) - caller_samples["real_idx"] = caller_samples.index - sample_data = caller_samples.to_dict("records") - processed_data["data"] = sample_data - typer.echo(f"sampling {sample_count} datapoints") - ExtendedPath(sample_path).write_json(processed_data) - - -@app.command() -def task_ui( - data_name: str = typer.Option("dataname", show_default=True), - dump_dir: Path = Path("./data/asr_data"), - dump_file: Path = Path("ui_dump.json"), - task_count: int = typer.Option(4, show_default=True), - task_file: str = "task_dump", -): - import pandas as pd - import numpy as np - - processed_data_path = dump_dir / Path(data_name) / dump_file - processed_data = ExtendedPath(processed_data_path).read_json() - df = pd.DataFrame(processed_data["data"]).sample(frac=1).reset_index(drop=True) - for t_idx, task_f in enumerate(np.array_split(df, task_count)): - task_f = task_f.reset_index(drop=True) - task_f["real_idx"] = task_f.index - task_data = task_f.to_dict("records") - processed_data["data"] = task_data - task_path = dump_dir / Path(data_name) / Path(task_file + f"-{t_idx}.json") - ExtendedPath(task_path).write_json(processed_data) - - -@app.command() -def dump_corrections( - task_uid: str, - data_name: str = typer.Option("dataname", show_default=True), - dump_dir: Path = Path("./data/asr_data"), - dump_fname: Path = Path("corrections.json"), -): - dump_path = dump_dir / Path(data_name) / dump_fname - col = get_mongo_conn(col="asr_validation") - task_id = [c for c in col.distinct("task_id") if c.rsplit("-", 1)[1] == task_uid][0] - corrections = list(col.find({"type": "correction"}, projection={"_id": False})) - cursor_obj = col.find( - {"type": "correction", "task_id": task_id}, projection={"_id": False} - ) - corrections = [c for c in cursor_obj] - ExtendedPath(dump_path).write_json(corrections) - - -@app.command() -def caller_quality( - task_uid: str, - data_name: str = typer.Option("dataname", show_default=True), - dump_dir: Path = Path("./data/asr_data"), - dump_fname: Path = Path("ui_dump.json"), - correction_fname: Path = Path("corrections.json"), -): - import copy - import pandas as pd - - dump_path = dump_dir / Path(data_name) / dump_fname - correction_path = dump_dir / Path(data_name) / correction_fname - dump_data = ExtendedPath(dump_path).read_json() - - dump_map = {d["utterance_id"]: d for d in dump_data["data"]} - correction_data = 
ExtendedPath(correction_path).read_json() - - def correction_dp(c): - dp = copy.deepcopy(dump_map[c["code"]]) - dp["valid"] = c["value"]["status"] == "Correct" - return dp - - corrected_dump = [ - correction_dp(c) - for c in correction_data - if c["task_id"].rsplit("-", 1)[1] == task_uid - ] - df = pd.DataFrame(corrected_dump) - print(f"Total samples: {len(df)}") - for (c, g) in df.groupby("caller"): - total = len(g) - valid = len(g[g["valid"] == True]) - valid_rate = valid * 100 / total - print(f"Caller: {c} Valid%:{valid_rate:.2f} of {total} samples") - - -@app.command() -def fill_unannotated( - data_name: str = typer.Option("dataname", show_default=True), - dump_dir: Path = Path("./data/valiation_data"), - dump_file: Path = Path("ui_dump.json"), - corrections_file: Path = Path("corrections.json"), -): - processed_data_path = dump_dir / Path(data_name) / dump_file - corrections_path = dump_dir / Path(data_name) / corrections_file - processed_data = json.load(processed_data_path.open()) - corrections = json.load(corrections_path.open()) - annotated_codes = {c["code"] for c in corrections} - all_codes = {c["gold_chars"] for c in processed_data} - unann_codes = all_codes - annotated_codes - mongo_conn = get_mongo_conn(col="asr_validation") - for c in unann_codes: - mongo_conn.find_one_and_update( - {"type": "correction", "code": c}, - {"$set": {"value": {"status": "Inaudible", "correction": ""}}}, - upsert=True, - ) - - -@app.command() -def split_extract( - data_name: str = typer.Option("dataname", show_default=True), - # dest_data_name: str = typer.Option("call_aldata_namephanum_date", show_default=True), - # dump_dir: Path = Path("./data/valiation_data"), - dump_dir: Path = Path("./data/asr_data"), - dump_file: Path = Path("ui_dump.json"), - manifest_file: Path = Path("manifest.json"), - corrections_file: str = typer.Option("corrections.json", show_default=True), - conv_data_path: Path = typer.Option( - Path("./data/conv_data.json"), show_default=True - ), - extraction_type: str = "all", -): - import shutil - - data_manifest_path = dump_dir / Path(data_name) / manifest_file - conv_data = ExtendedPath(conv_data_path).read_json() - - def extract_data_of_type(extraction_key): - extraction_vals = conv_data[extraction_key] - dest_data_name = data_name + "_" + extraction_key.lower() - - manifest_gen = asr_manifest_reader(data_manifest_path) - dest_data_dir = dump_dir / Path(dest_data_name) - dest_data_dir.mkdir(exist_ok=True, parents=True) - (dest_data_dir / Path("wav")).mkdir(exist_ok=True, parents=True) - dest_manifest_path = dest_data_dir / manifest_file - dest_ui_path = dest_data_dir / dump_file - - def extract_manifest(mg): - for m in mg: - if m["text"] in extraction_vals: - shutil.copy( - m["audio_path"], dest_data_dir / Path(m["audio_filepath"]) - ) - yield m - - asr_manifest_writer(dest_manifest_path, extract_manifest(manifest_gen)) - - ui_data_path = dump_dir / Path(data_name) / dump_file - orig_ui_data = ExtendedPath(ui_data_path).read_json() - ui_data = orig_ui_data["data"] - file_ui_map = {Path(u["audio_filepath"]).stem: u for u in ui_data} - extracted_ui_data = list( - filter(lambda u: u["text"] in extraction_vals, ui_data) - ) - final_data = [] - for i, d in enumerate(extracted_ui_data): - d["real_idx"] = i - final_data.append(d) - orig_ui_data["data"] = final_data - ExtendedPath(dest_ui_path).write_json(orig_ui_data) - - if corrections_file: - dest_correction_path = dest_data_dir / corrections_file - corrections_path = dump_dir / Path(data_name) / corrections_file - corrections = 
json.load(corrections_path.open()) - extracted_corrections = list( - filter( - lambda c: c["code"] in file_ui_map - and file_ui_map[c["code"]]["text"] in extraction_vals, - corrections, - ) - ) - ExtendedPath(dest_correction_path).write_json(extracted_corrections) - - if extraction_type.value == "all": - for ext_key in conv_data.keys(): - extract_data_of_type(ext_key) - else: - extract_data_of_type(extraction_type.value) - - -@app.command() -def update_corrections( - data_name: str = typer.Option("dataname", show_default=True), - dump_dir: Path = Path("./data/asr_data"), - manifest_file: Path = Path("manifest.json"), - corrections_file: Path = Path("corrections.json"), - ui_dump_file: Path = Path("ui_dump.json"), - skip_incorrect: bool = typer.Option(True, show_default=True), -): - data_manifest_path = dump_dir / Path(data_name) / manifest_file - corrections_path = dump_dir / Path(data_name) / corrections_file - ui_dump_path = dump_dir / Path(data_name) / ui_dump_file - - def correct_manifest(ui_dump_path, corrections_path): - corrections = ExtendedPath(corrections_path).read_json() - ui_data = ExtendedPath(ui_dump_path).read_json()["data"] - correct_set = { - c["code"] for c in corrections if c["value"]["status"] == "Correct" - } - # incorrect_set = {c["code"] for c in corrections if c["value"]["status"] == "Inaudible"} - correction_map = { - c["code"]: c["value"]["correction"] - for c in corrections - if c["value"]["status"] == "Incorrect" - } - # for d in manifest_data_gen: - # if d["chars"] in incorrect_set: - # d["audio_path"].unlink() - # renamed_set = set() - for d in ui_data: - if d["utterance_id"] in correct_set: - yield { - "audio_filepath": d["audio_filepath"], - "duration": d["duration"], - "text": d["text"], - } - elif d["utterance_id"] in correction_map: - correct_text = correction_map[d["utterance_id"]] - if skip_incorrect: - print( - f'skipping incorrect {d["audio_path"]} corrected to {correct_text}' - ) - else: - orig_audio_path = Path(d["audio_path"]) - new_name = str( - Path(tscript_uuid_fname(correct_text)).with_suffix(".wav") - ) - new_audio_path = orig_audio_path.with_name(new_name) - orig_audio_path.replace(new_audio_path) - new_filepath = str(Path(d["audio_filepath"]).with_name(new_name)) - yield { - "audio_filepath": new_filepath, - "duration": d["duration"], - "text": correct_text, - } - else: - orig_audio_path = Path(d["audio_path"]) - # don't delete if another correction points to an old file - # if d["text"] not in renamed_set: - orig_audio_path.unlink() - # else: - # print(f'skipping deletion of correction:{d["text"]}') - - typer.echo(f"Using data manifest:{data_manifest_path}") - dataset_dir = data_manifest_path.parent - dataset_name = dataset_dir.name - backup_dir = dataset_dir.with_name(dataset_name + ".bkp") - if not backup_dir.exists(): - typer.echo(f"backing up to :{backup_dir}") - shutil.copytree(str(dataset_dir), str(backup_dir)) - # manifest_gen = asr_manifest_reader(data_manifest_path) - corrected_manifest = correct_manifest(ui_dump_path, corrections_path) - new_data_manifest_path = data_manifest_path.with_name("manifest.new") - asr_manifest_writer(new_data_manifest_path, corrected_manifest) - new_data_manifest_path.replace(data_manifest_path) - - -@app.command() -def clear_mongo_corrections(): - delete = typer.confirm("are you sure you want to clear mongo collection it?") - if delete: - col = get_mongo_conn(col="asr_validation") - col.delete_many({"type": "correction"}) - col.delete_many({"type": "current_cursor"}) - typer.echo("deleted mongo 
collection.") - return - typer.echo("Aborted") - - -def main(): - app() - - -if __name__ == "__main__": - main() diff --git a/jasper/server.py b/jasper/server.py deleted file mode 100644 index 163a022..0000000 --- a/jasper/server.py +++ /dev/null @@ -1,57 +0,0 @@ -import os -import logging - -import rpyc -from rpyc.utils.server import ThreadedServer - -from .asr import JasperASR -from .utils import arg_parser - - -class ASRService(rpyc.Service): - def __init__(self, asr_recognizer): - self.asr = asr_recognizer - - def on_connect(self, conn): - # code that runs when a connection is created - # (to init the service, if needed) - pass - - def on_disconnect(self, conn): - # code that runs after the connection has already closed - # (to finalize the service, if needed) - pass - - def exposed_transcribe(self, utterance: bytes): # this is an exposed method - speech_audio = self.asr.transcribe(utterance) - return speech_audio - - def exposed_transcribe_cb( - self, utterance: bytes, respond - ): # this is an exposed method - speech_audio = self.asr.transcribe(utterance) - respond(speech_audio) - - -def main(): - parser = arg_parser('jasper_transcribe') - parser.description = 'jasper asr rpyc server' - parser.add_argument( - "--port", type=int, default=int(os.environ.get("ASR_RPYC_PORT", "8044")), help="port to listen on" - ) - args = parser.parse_args() - args_dict = vars(args) - port = args_dict.pop("port") - jasper_asr = JasperASR(**args_dict) - service = ASRService(jasper_asr) - logging.basicConfig( - level=logging.INFO, - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", - ) - logging.info("starting asr server...") - t = ThreadedServer(service, port=port) - t.start() - - -if __name__ == "__main__": - main() diff --git a/jasper/training/__init__.py b/jasper/training/__init__.py deleted file mode 100644 index 8b13789..0000000 --- a/jasper/training/__init__.py +++ /dev/null @@ -1 +0,0 @@ - diff --git a/jasper/transcribe.py b/jasper/transcribe.py deleted file mode 100644 index d58fe9d..0000000 --- a/jasper/transcribe.py +++ /dev/null @@ -1,22 +0,0 @@ -from pathlib import Path -from .asr import JasperASR -from .utils import arg_parser - - -def main(): - parser = arg_parser('jasper_transcribe') - parser.description = 'transcribe audio file to text' - parser.add_argument( - "audio_file", - type=Path, - help="audio file(16khz 1channel int16 wav) to transcribe", - ) - parser.add_argument( - "--greedy", type=bool, default=False, help="enables greedy decoding" - ) - args = parser.parse_args() - args_dict = vars(args) - audio_file = args_dict.pop("audio_file") - greedy = args_dict.pop("greedy") - jasper_asr = JasperASR(**args_dict) - jasper_asr.transcribe_file(audio_file, greedy) diff --git a/jasper/utils.py b/jasper/utils.py deleted file mode 100644 index 5f5ed3e..0000000 --- a/jasper/utils.py +++ /dev/null @@ -1,40 +0,0 @@ -import os -import argparse -from pathlib import Path - -MODEL_YAML = os.environ.get("JASPER_MODEL_CONFIG", "/models/jasper/jasper10x5dr.yaml") -CHECKPOINT_ENCODER = os.environ.get( - "JASPER_ENCODER_CHECKPOINT", "/models/jasper/JasperEncoder-STEP-265520.pt" -) -CHECKPOINT_DECODER = os.environ.get( - "JASPER_DECODER_CHECKPOINT", "/models/jasper/JasperDecoderForCTC-STEP-265520.pt" -) -KEN_LM = os.environ.get("JASPER_KEN_LM", "/models/jasper/kenlm.pt") - - -def arg_parser(prog): - parser = argparse.ArgumentParser( - prog=prog, description=f"convert speech to text" - ) - parser.add_argument( - "--model_yaml", - type=Path, - default=Path(MODEL_YAML), - help="model config yaml 
file", - ) - parser.add_argument( - "--encoder_checkpoint", - type=Path, - default=Path(CHECKPOINT_ENCODER), - help="encoder checkpoint weights file", - ) - parser.add_argument( - "--decoder_checkpoint", - type=Path, - default=Path(CHECKPOINT_DECODER), - help="decoder checkpoint weights file", - ) - parser.add_argument( - "--language_model", type=Path, default=None, help="kenlm language model file" - ) - return parser diff --git a/plume/cli/__init__.py b/plume/cli/__init__.py new file mode 100644 index 0000000..2200e2e --- /dev/null +++ b/plume/cli/__init__.py @@ -0,0 +1,23 @@ +import typer +from ..utils import app as utils_app +from .data import app as data_app +from ..ui import app as ui_app +from .train import app as train_app +from .eval import app as eval_app +from .serve import app as serve_app + +app = typer.Typer() +app.add_typer(data_app, name="data") +app.add_typer(ui_app, name="ui") +app.add_typer(train_app, name="train") +app.add_typer(eval_app, name="eval") +app.add_typer(serve_app, name="serve") +app.add_typer(utils_app, name='utils') + + +def main(): + app() + + +if __name__ == "__main__": + main() diff --git a/plume/cli/data/__init__.py b/plume/cli/data/__init__.py new file mode 100644 index 0000000..9a90926 --- /dev/null +++ b/plume/cli/data/__init__.py @@ -0,0 +1,339 @@ +import json +from pathlib import Path +# from sklearn.model_selection import train_test_split +from plume.utils import ( + asr_manifest_reader, + asr_manifest_writer, + ExtendedPath, + duration_str, + generate_filter_map, + get_mongo_conn, + tscript_uuid_fname, + lazy_callable +) +from typing import List +from itertools import chain +import shutil +import typer +import soundfile + +from ...models.wav2vec2.data import app as wav2vec2_app +from .generate import app as generate_app + +train_test_split = lazy_callable('sklearn.model_selection.train_test_split') + +app = typer.Typer() +app.add_typer(generate_app, name="generate") +app.add_typer(wav2vec2_app, name="wav2vec2") + + +@app.command() +def fix_path(dataset_path: Path, force: bool = False): + manifest_path = dataset_path / Path("manifest.json") + real_manifest_path = dataset_path / Path("abs_manifest.json") + + def fix_real_path(): + for i in asr_manifest_reader(manifest_path): + i["audio_filepath"] = str( + (dataset_path / Path(i["audio_filepath"])).absolute() + ) + yield i + + def fix_rel_path(): + for i in asr_manifest_reader(real_manifest_path): + i["audio_filepath"] = str( + Path(i["audio_filepath"]).relative_to(dataset_path) + ) + yield i + + if not manifest_path.exists() and not real_manifest_path.exists(): + typer.echo("Invalid dataset directory") + if not real_manifest_path.exists() or force: + asr_manifest_writer(real_manifest_path, fix_real_path()) + if not manifest_path.exists(): + asr_manifest_writer(manifest_path, fix_rel_path()) + + +@app.command() +def augment(src_dataset_paths: List[Path], dest_dataset_path: Path): + reader_list = [] + abs_manifest_path = Path("abs_manifest.json") + for dataset_path in src_dataset_paths: + manifest_path = dataset_path / abs_manifest_path + reader_list.append(asr_manifest_reader(manifest_path)) + dest_dataset_path.mkdir(parents=True, exist_ok=True) + dest_manifest_path = dest_dataset_path / abs_manifest_path + asr_manifest_writer(dest_manifest_path, chain(*reader_list)) + + +@app.command() +def split(dataset_path: Path, test_size: float = 0.03): + manifest_path = dataset_path / Path("abs_manifest.json") + if not manifest_path.exists(): + fix_path(dataset_path) + asr_data = 
+    train_pnr, test_pnr = train_test_split(asr_data, test_size=test_size)
+    asr_manifest_writer(manifest_path.with_name("train_manifest.json"), train_pnr)
+    asr_manifest_writer(manifest_path.with_name("test_manifest.json"), test_pnr)
+
+
+@app.command()
+def validate(dataset_path: Path):
+    from natural.date import compress
+    from datetime import timedelta
+
+    for mf_type in ["train_manifest.json", "test_manifest.json"]:
+        data_file = dataset_path / Path(mf_type)
+        print(f"validating {data_file}.")
+        with Path(data_file).open("r") as pf:
+            pnr_jsonl = pf.readlines()
+        duration = 0
+        errors = 0
+        for (i, s) in enumerate(pnr_jsonl):
+            try:
+                d = json.loads(s)
+                duration += d["duration"]
+                audio_file = data_file.parent / Path(d["audio_filepath"])
+                if not audio_file.exists():
+                    raise OSError(f"File {audio_file} not found")
+            except BaseException as e:
+                errors += 1
+                print(f'failed on {i} with "{e}"')
+        dur_str = compress(timedelta(seconds=duration), pad=" ")
+        if errors:
+            print(f"{errors} errors found in {mf_type}. contains {dur_str} of audio")
+        else:
+            print(
+                f"no errors found. seems like a valid {mf_type}. contains {dur_str} of audio"
+            )
+
+
+@app.command()
+def filter(src_dataset_path: Path, dest_dataset_path: Path, kind: str = "skip_dur"):
+    dest_manifest = dest_dataset_path / Path("manifest.json")
+    data_file = src_dataset_path / Path("manifest.json")
+    dest_wav_dir = dest_dataset_path / Path("wavs")
+    dest_wav_dir.mkdir(exist_ok=True, parents=True)
+    filter_kind_map = generate_filter_map(
+        src_dataset_path, dest_dataset_path, data_file
+    )
+
+    selected_filter = filter_kind_map.get(kind, None)
+    if selected_filter:
+        asr_manifest_writer(dest_manifest, selected_filter())
+    else:
+        typer.echo(f"filter kind - {kind} not implemented")
+        typer.echo(f"select one of {', '.join(filter_kind_map.keys())}")
+
+
+@app.command()
+def info(dataset_path: Path):
+    for k in ["", "abs_", "train_", "test_"]:
+        mf_wav_duration = (
+            real_duration
+        ) = max_duration = empty_duration = empty_count = total_count = 0
+        data_file = dataset_path / Path(f"{k}manifest.json")
+        if data_file.exists():
+            print(f"stats on {data_file}")
+            for s in ExtendedPath(data_file).read_jsonl():
+                total_count += 1
+                mf_wav_duration += s["duration"]
+                if s["text"] == "":
+                    empty_count += 1
+                    empty_duration += s["duration"]
+                wav_path = str(dataset_path / Path(s["audio_filepath"]))
+                wav_duration = soundfile.info(wav_path).duration
+                if max_duration < wav_duration:
+                    max_duration = wav_duration
+                real_duration += wav_duration
+
+            # frame_count = soundfile.info(audio_fname).frames
+            print(f"max audio duration : {duration_str(max_duration)}")
+            print(f"total audio duration : {duration_str(mf_wav_duration)}")
+            print(f"total real audio duration : {duration_str(real_duration)}")
+            print(
+                f"total content duration : {duration_str(mf_wav_duration-empty_duration)}"
+            )
+            print(f"total empty duration : {duration_str(empty_duration)}")
+            print(
+                f"total empty samples : {empty_count}/{total_count} ({empty_count*100/total_count:.2f}%)"
+            )
+
+
+@app.command()
+def audio_duration(dataset_path: Path):
+    wav_duration = 0
+    for audio_rel_fname in dataset_path.absolute().glob("**/*.wav"):
+        audio_fname = str(audio_rel_fname)
+        wav_duration += soundfile.info(audio_fname).duration
+    typer.echo(f"duration of wav files @ {dataset_path}: {duration_str(wav_duration)}")
+
+
+@app.command()
+def migrate(src_path: Path, dest_path: Path):
+    shutil.copytree(str(src_path), str(dest_path))
+    wav_dir = dest_path / Path("wavs")
+    wav_dir.mkdir(exist_ok=True, parents=True)
+    abs_manifest_path = ExtendedPath(dest_path / Path("abs_manifest.json"))
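+    # snapshot the manifest before rewriting audio paths into the new wavs/ dir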
/ Path("abs_manifest.json")) + backup_abs_manifest_path = abs_manifest_path.with_suffix(".json.orig") + shutil.copy(abs_manifest_path, backup_abs_manifest_path) + manifest_data = list(abs_manifest_path.read_jsonl()) + for md in manifest_data: + orig_path = Path(md["audio_filepath"]) + new_path = wav_dir / Path(orig_path.name) + shutil.copy(orig_path, new_path) + md["audio_filepath"] = str(new_path) + abs_manifest_path.write_jsonl(manifest_data) + fix_path(dest_path) + + +@app.command() +def task_split( + data_dir: Path, + dump_file: Path = Path("ui_dump.json"), + task_count: int = typer.Option(2, show_default=True), + task_file: str = "task_dump", + sort: bool = True, +): + """ + split ui_dump.json to `task_count` tasks + """ + import pandas as pd + import numpy as np + + processed_data_path = data_dir / dump_file + processed_data = ExtendedPath(processed_data_path).read_json() + df = pd.DataFrame(processed_data["data"]).sample(frac=1).reset_index(drop=True) + for t_idx, task_f in enumerate(np.array_split(df, task_count)): + task_f = task_f.reset_index(drop=True) + task_f["real_idx"] = task_f.index + task_data = task_f.to_dict("records") + if sort: + task_data = sorted(task_data, key=lambda x: x["asr_wer"], reverse=True) + processed_data["data"] = task_data + task_path = data_dir / Path(task_file + f"-{t_idx}.json") + ExtendedPath(task_path).write_json(processed_data) + + +def get_corrections(task_uid): + col = get_mongo_conn(col="asr_validation") + task_id = [ + c + for c in col.distinct("task_id") + if c.rsplit("-", 1)[1] == task_uid or c == task_uid + ][0] + corrections = list(col.find({"type": "correction"}, projection={"_id": False})) + cursor_obj = col.find( + {"type": "correction", "task_id": task_id}, projection={"_id": False} + ) + corrections = [c for c in cursor_obj] + return corrections + + +@app.command() +def dump_task_corrections(data_dir: Path, task_uid: str): + dump_fname: Path = Path(f"corrections-{task_uid}.json") + dump_path = data_dir / dump_fname + corrections = get_corrections(task_uid) + ExtendedPath(dump_path).write_json(corrections) + + +@app.command() +def dump_all_corrections(data_dir: Path): + for task_lcks in data_dir.glob('task-*.lck'): + task_uid = task_lcks.stem.replace('task-', '') + dump_task_corrections(data_dir, task_uid) + + +@app.command() +def update_corrections( + data_dir: Path, + skip_incorrect: bool = typer.Option( + False, show_default=True, help="treats incorrect as invalid" + ), + skip_inaudible: bool = typer.Option( + False, show_default=True, help="include invalid as blank target" + ), +): + """ + applies the corrections-*.json + backup the original dataset + """ + manifest_file: Path = Path("manifest.json") + renames_file: Path = Path("rename_map.json") + ui_dump_file: Path = Path("ui_dump.json") + data_manifest_path = data_dir / manifest_file + renames_path = data_dir / renames_file + + def correct_ui_dump(data_dir, rename_result): + ui_dump_path = data_dir / ui_dump_file + # corrections_path = data_dir / Path("corrections.json") + corrections = [ + t + for p in data_dir.glob("corrections-*.json") + for t in ExtendedPath(p).read_json() + ] + ui_data = ExtendedPath(ui_dump_path).read_json()["data"] + correct_set = { + c["code"] for c in corrections if c["value"]["status"] == "Correct" + } + correction_map = { + c["code"]: c["value"]["correction"] + for c in corrections + if c["value"]["status"] == "Incorrect" + } + for d in ui_data: + orig_audio_path = (data_dir / Path(d["audio_path"])).absolute() + if d["utterance_id"] in correct_set: + 
d["corrected_from"] = d["text"] + yield d + elif d["utterance_id"] in correction_map: + correct_text = correction_map[d["utterance_id"]] + if skip_incorrect: + ap = d["audio_path"] + print(f"skipping incorrect {ap} corrected to {correct_text}") + orig_audio_path.unlink() + else: + new_fname = tscript_uuid_fname(correct_text) + rename_result[new_fname] = { + "orig_text": d["text"], + "correct_text": correct_text, + "orig_id": d["utterance_id"], + } + new_name = str(Path(new_fname).with_suffix(".wav")) + new_audio_path = orig_audio_path.with_name(new_name) + orig_audio_path.replace(new_audio_path) + new_filepath = str(Path(d["audio_path"]).with_name(new_name)) + d["corrected_from"] = d["text"] + d["text"] = correct_text + d["audio_path"] = new_filepath + yield d + else: + if skip_inaudible: + orig_audio_path.unlink() + else: + d["corrected_from"] = d["text"] + d["text"] = "" + yield d + + dataset_dir = data_manifest_path.parent + dataset_name = dataset_dir.name + backup_dir = dataset_dir.with_name(dataset_name + ".bkp") + if not backup_dir.exists(): + typer.echo(f"backing up to {backup_dir}") + shutil.copytree(str(dataset_dir), str(backup_dir)) + renames = {} + corrected_ui_dump = list(correct_ui_dump(data_dir, renames)) + ExtendedPath(data_dir / ui_dump_file).write_json({"data": corrected_ui_dump}) + corrected_manifest = ( + { + "audio_filepath": d["audio_path"], + "duration": d["duration"], + "text": d["text"], + } + for d in corrected_ui_dump + ) + asr_manifest_writer(data_manifest_path, corrected_manifest) + ExtendedPath(renames_path).write_json(renames) + + +def main(): + app() + + +if __name__ == "__main__": + main() diff --git a/plume/cli/data/generate.py b/plume/cli/data/generate.py new file mode 100644 index 0000000..b1ee129 --- /dev/null +++ b/plume/cli/data/generate.py @@ -0,0 +1,12 @@ +from pathlib import Path + +import typer +from ...utils.tts import GoogleTTS + +app = typer.Typer() + + +@app.command() +def tts_dataset(dest_path: Path): + tts = GoogleTTS() + pass diff --git a/plume/cli/eval.py b/plume/cli/eval.py new file mode 100644 index 0000000..5686d77 --- /dev/null +++ b/plume/cli/eval.py @@ -0,0 +1,5 @@ +import typer +from ..models.wav2vec2.eval import app as wav2vec2_app + +app = typer.Typer() +app.add_typer(wav2vec2_app, name="wav2vec2") diff --git a/plume/cli/serve.py b/plume/cli/serve.py new file mode 100644 index 0000000..8397682 --- /dev/null +++ b/plume/cli/serve.py @@ -0,0 +1,7 @@ +import typer +from ..models.wav2vec2.serve import app as wav2vec2_app +from ..models.jasper.serve import app as jasper_app + +app = typer.Typer() +app.add_typer(wav2vec2_app, name="wav2vec2") +app.add_typer(jasper_app, name="jasper") diff --git a/plume/cli/train.py b/plume/cli/train.py new file mode 100644 index 0000000..c067984 --- /dev/null +++ b/plume/cli/train.py @@ -0,0 +1,5 @@ +import typer +from ..models.wav2vec2.train import app as train_app + +app = typer.Typer() +app.add_typer(train_app, name="wav2vec2") diff --git a/plume/models/__init__.py b/plume/models/__init__.py new file mode 100644 index 0000000..3ff17b6 --- /dev/null +++ b/plume/models/__init__.py @@ -0,0 +1 @@ +# from . 
import jasper, wav2vec2, matchboxnet
diff --git a/plume/models/jasper/__init__.py b/plume/models/jasper/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/jasper/asr.py b/plume/models/jasper/asr.py
similarity index 100%
rename from jasper/asr.py
rename to plume/models/jasper/asr.py
diff --git a/plume/models/jasper/data.py b/plume/models/jasper/data.py
new file mode 100644
index 0000000..1d3babf
--- /dev/null
+++ b/plume/models/jasper/data.py
@@ -0,0 +1,24 @@
+from pathlib import Path
+import typer
+
+app = typer.Typer()
+
+
+@app.command()
+def set_root(dataset_path: Path, root_path: Path):
+    pass
+    # for dataset_kind in ["train", "valid"]:
+    #     data_file = dataset_path / Path(dataset_kind).with_suffix(".tsv")
+    #     with data_file.open("r") as df:
+    #         lines = df.readlines()
+    #     with data_file.open("w") as df:
+    #         lines[0] = str(root_path) + "\n"
+    #         df.writelines(lines)
+
+
+def main():
+    app()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/jasper/training/data_loaders.py b/plume/models/jasper/data_loaders.py
similarity index 100%
rename from jasper/training/data_loaders.py
rename to plume/models/jasper/data_loaders.py
diff --git a/jasper/evaluate.py b/plume/models/jasper/eval.py
similarity index 99%
rename from jasper/evaluate.py
rename to plume/models/jasper/eval.py
index 94d8f43..12f558f 100644
--- a/jasper/evaluate.py
+++ b/plume/models/jasper/eval.py
@@ -45,7 +45,7 @@ def parse_args():
         eval_freq=100,
         load_dir="./train/models/jasper/",
         warmup_steps=3,
-        exp_name="jasper-speller",
+        exp_name="jasper",
     )
 
     # Overwrite default args
diff --git a/jasper/training/featurizer.py b/plume/models/jasper/featurizer.py
similarity index 100%
rename from jasper/training/featurizer.py
rename to plume/models/jasper/featurizer.py
diff --git a/plume/models/jasper/serve.py b/plume/models/jasper/serve.py
new file mode 100644
index 0000000..892c64e
--- /dev/null
+++ b/plume/models/jasper/serve.py
@@ -0,0 +1,52 @@
+import os
+import logging
+from pathlib import Path
+
+from rpyc.utils.server import ThreadedServer
+import typer
+
+# from .asr import JasperASR
+from ...utils.serve import ASRService
+from plume.utils import lazy_callable
+
+JasperASR = lazy_callable('plume.models.jasper.asr.JasperASR')
+
+app = typer.Typer()
+
+
+@app.command()
+def rpyc(
+    encoder_path: Path = "/path/to/encoder.pt",
+    decoder_path: Path = "/path/to/decoder.pt",
+    model_yaml_path: Path = "/path/to/model.yaml",
+    port: int = int(os.environ.get("ASR_RPYC_PORT", "8044")),
+):
+    logging.basicConfig(
+        level=logging.INFO,
+        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+    )
+    for p in [encoder_path, decoder_path, model_yaml_path]:
+        if not p.exists():
+            logging.info(f"{p} doesn't exist")
+            return
+    asr = JasperASR(str(model_yaml_path), str(encoder_path), str(decoder_path))
+    service = ASRService(asr)
+    logging.info("starting asr server...")
+    t = ThreadedServer(service, port=port)
+    t.start()
+
+
+@app.command()
+def rpyc_dir(model_dir: Path, port: int = int(os.environ.get("ASR_RPYC_PORT", "8044"))):
+    encoder_path = model_dir / Path("encoder.pt")
+    decoder_path = model_dir / Path("decoder.pt")
+    model_yaml_path = model_dir / Path("model.yaml")
+    rpyc(encoder_path, decoder_path, model_yaml_path, port)
+
+
+def main():
+    app()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/jasper/data/server.py b/plume/models/jasper/serve_data.py
similarity index 100%
rename from jasper/data/server.py
rename to plume/models/jasper/serve_data.py
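The rpyc entry points above replace the deleted `jasper/client.py`; a minimal
client sketch in its spirit (an assumption that `plume.utils.serve.ASRService`
still exposes `transcribe` the way the old `jasper/server.py` did, and that
`plume serve jasper rpyc` is listening on the default port):

```python
import rpyc

# connect to the service started by `plume serve jasper rpyc` (default port 8044)
asr = rpyc.connect("localhost", 8044).root
wav_data = open("utterance.wav", "rb").read()  # 16khz 1channel int16 wav
print(asr.transcribe(wav_data))
```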
diff --git a/jasper/training/cli.py b/plume/models/jasper/train.py
similarity index 100%
rename from jasper/training/cli.py
rename to plume/models/jasper/train.py
diff --git a/plume/models/matchboxnet/__init__.py b/plume/models/matchboxnet/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/plume/models/wav2vec2/__init__.py b/plume/models/wav2vec2/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/plume/models/wav2vec2/asr.py b/plume/models/wav2vec2/asr.py
new file mode 100644
index 0000000..e4f1581
--- /dev/null
+++ b/plume/models/wav2vec2/asr.py
@@ -0,0 +1,204 @@
+from io import BytesIO
+import warnings
+import itertools as it
+
+import torch
+import soundfile as sf
+import torch.nn.functional as F
+
+try:
+    from fairseq import utils
+    from fairseq.models import BaseFairseqModel
+    from fairseq.data import Dictionary
+    from fairseq.models.wav2vec.wav2vec2_asr import base_architecture, Wav2VecEncoder
+except ModuleNotFoundError:
+    warnings.warn("Install fairseq")
+try:
+    from wav2letter.decoder import CriterionType
+    from wav2letter.criterion import CpuViterbiPath, get_data_ptr_as_bytes
+except ModuleNotFoundError:
+    warnings.warn("Install wav2letter")
+
+
+class Wav2VecCtc(BaseFairseqModel):
+    def __init__(self, w2v_encoder, args):
+        super().__init__()
+        self.w2v_encoder = w2v_encoder
+        self.args = args
+
+    def upgrade_state_dict_named(self, state_dict, name):
+        super().upgrade_state_dict_named(state_dict, name)
+        return state_dict
+
+    @classmethod
+    def build_model(cls, args, target_dict):
+        """Build a new model instance."""
+        base_architecture(args)
+        w2v_encoder = Wav2VecEncoder(args, target_dict)
+        return cls(w2v_encoder, args)
+
+    def get_normalized_probs(self, net_output, log_probs):
+        """Get normalized probabilities (or log probs) from a net's output."""
+        logits = net_output["encoder_out"]
+        if log_probs:
+            return utils.log_softmax(logits.float(), dim=-1)
+        else:
+            return utils.softmax(logits.float(), dim=-1)
+
+    def forward(self, **kwargs):
+        x = self.w2v_encoder(**kwargs)
+        return x
+
+
+class W2lDecoder(object):
+    def __init__(self, tgt_dict):
+        self.tgt_dict = tgt_dict
+        self.vocab_size = len(tgt_dict)
+        self.nbest = 1
+
+        self.criterion_type = CriterionType.CTC
+        self.blank = (
+            tgt_dict.index("<ctc_blank>")
+            if "<ctc_blank>" in tgt_dict.indices
+            else tgt_dict.bos()
+        )
+        self.asg_transitions = None
+
+    def generate(self, model, sample, **unused):
+        """Generate a batch of inferences."""
+        # model.forward normally channels prev_output_tokens into the decoder
+        # separately, but SequenceGenerator directly calls model.encoder
+        encoder_input = {
+            k: v for k, v in sample["net_input"].items() if k != "prev_output_tokens"
+        }
+        emissions = self.get_emissions(model, encoder_input)
+        return self.decode(emissions)
+
+    def get_emissions(self, model, encoder_input):
+        """Run encoder and normalize emissions"""
+        # encoder_out = models[0].encoder(**encoder_input)
+        encoder_out = model(**encoder_input)
+        if self.criterion_type == CriterionType.CTC:
+            emissions = model.get_normalized_probs(encoder_out, log_probs=True)
+
+        return emissions.transpose(0, 1).float().cpu().contiguous()
+
+    def get_tokens(self, idxs):
+        """Normalize tokens by handling CTC blank, ASG replabels, etc."""
+        idxs = (g[0] for g in it.groupby(idxs))
+        idxs = filter(lambda x: x != self.blank, idxs)
+
+        return torch.LongTensor(list(idxs))
+
+
+class W2lViterbiDecoder(W2lDecoder):
+    def __init__(self, tgt_dict):
+        super().__init__(tgt_dict)
+
+    def decode(self, emissions):
+        B, T, N = emissions.size()
+        hypos = list()
+
+        if self.asg_transitions is None:
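+            # no ASG transition scores were provided: Viterbi-decode the CTC
+            # emissions with an all-zero transition matrix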
self.asg_transitions is None:
+            transitions = torch.FloatTensor(N, N).zero_()
+        else:
+            transitions = torch.FloatTensor(self.asg_transitions).view(N, N)
+
+        viterbi_path = torch.IntTensor(B, T)
+        workspace = torch.ByteTensor(CpuViterbiPath.get_workspace_size(B, T, N))
+        CpuViterbiPath.compute(
+            B,
+            T,
+            N,
+            get_data_ptr_as_bytes(emissions),
+            get_data_ptr_as_bytes(transitions),
+            get_data_ptr_as_bytes(viterbi_path),
+            get_data_ptr_as_bytes(workspace),
+        )
+        return [
+            [{"tokens": self.get_tokens(viterbi_path[b].tolist()), "score": 0}]
+            for b in range(B)
+        ]
+
+
+def post_process(sentence: str, symbol: str):
+    if symbol == "sentencepiece":
+        sentence = sentence.replace(" ", "").replace("\u2581", " ").strip()
+    elif symbol == "wordpiece":
+        sentence = sentence.replace(" ", "").replace("_", " ").strip()
+    elif symbol == "letter":
+        sentence = sentence.replace(" ", "").replace("|", " ").strip()
+    elif symbol == "_EOW":
+        sentence = sentence.replace(" ", "").replace("_EOW", " ").strip()
+    elif symbol is not None and symbol != "none":
+        sentence = (sentence + " ").replace(symbol, "").rstrip()
+    return sentence
+
+
+def get_feature(filepath):
+    def postprocess(feats, sample_rate):
+        if feats.dim() == 2:
+            feats = feats.mean(-1)
+
+        assert feats.dim() == 1, feats.dim()
+
+        with torch.no_grad():
+            feats = F.layer_norm(feats, feats.shape)
+        return feats
+
+    wav, sample_rate = sf.read(filepath)
+    feats = torch.from_numpy(wav).float()
+    if torch.cuda.is_available():
+        feats = feats.cuda()
+    feats = postprocess(feats, sample_rate)
+    return feats
+
+
+def load_model(ctc_model_path, w2v_model_path, target_dict):
+    w2v = torch.load(ctc_model_path)
+    w2v["args"].w2v_path = w2v_model_path
+    model = Wav2VecCtc.build_model(w2v["args"], target_dict)
+    model.load_state_dict(w2v["model"], strict=True)
+    if torch.cuda.is_available():
+        model = model.cuda()
+    return model
+
+
+class Wav2Vec2ASR(object):
+    """docstring for Wav2Vec2ASR."""
+
+    def __init__(self, ctc_path, w2v_path, target_dict_path):
+        super(Wav2Vec2ASR, self).__init__()
+        self.target_dict = Dictionary.load(target_dict_path)
+
+        self.model = load_model(ctc_path, w2v_path, self.target_dict)
+        self.model.eval()
+
+        self.generator = W2lViterbiDecoder(self.target_dict)
+
+    def transcribe(self, audio_data, greedy=True):
+        aud_f = BytesIO(audio_data)
+        # aud_seg = pydub.AudioSegment.from_file(aud_f)
+        # feat_seg = aud_seg.set_channels(1).set_sample_width(2).set_frame_rate(16000)
+        # feat_f = io.BytesIO()
+        # feat_seg.export(feat_f, format='wav')
+        # feat_f.seek(0)
+        net_input = {}
+        feature = get_feature(aud_f)
+        net_input["source"] = feature.unsqueeze(0)
+
+        padding_mask = (
+            torch.BoolTensor(net_input["source"].size(1)).fill_(False).unsqueeze(0)
+        )
+        if torch.cuda.is_available():
+            padding_mask = padding_mask.cuda()
+
+        net_input["padding_mask"] = padding_mask
+        sample = {}
+        sample["net_input"] = net_input
+
+        with torch.no_grad():
+            hypo = self.generator.generate(self.model, sample, prefix_tokens=None)
+        hyp_pieces = self.target_dict.string(hypo[0][0]["tokens"].int().cpu())
+        result = post_process(hyp_pieces, "letter")
+        return result
diff --git a/plume/models/wav2vec2/data.py b/plume/models/wav2vec2/data.py
new file mode 100644
index 0000000..b79bbef
--- /dev/null
+++ b/plume/models/wav2vec2/data.py
@@ -0,0 +1,86 @@
+from pathlib import Path
+from collections import Counter
+import shutil
+
+import soundfile
+# import pydub
+import typer
+from tqdm import tqdm
+
+from plume.utils import (
+    ExtendedPath,
+    replace_redundant_spaces_with,
+    
lazy_module +) +pydub = lazy_module('pydub') + +app = typer.Typer() + + +@app.command() +def export_jasper(src_dataset_path: Path, dest_dataset_path: Path, unlink: bool = True): + dict_ltr = dest_dataset_path / Path("dict.ltr.txt") + (dest_dataset_path / Path("wavs")).mkdir(exist_ok=True, parents=True) + tok_counter = Counter() + shutil.copy( + src_dataset_path / Path("test_manifest.json"), + src_dataset_path / Path("valid_manifest.json"), + ) + if unlink: + src_wavs = src_dataset_path / Path("wavs") + for wav_path in tqdm(list(src_wavs.glob("**/*.wav"))): + audio_seg = ( + pydub.AudioSegment.from_wav(wav_path) + .set_frame_rate(16000) + .set_channels(1) + ) + dest_path = dest_dataset_path / Path("wavs") / Path(wav_path.name) + audio_seg.export(dest_path, format="wav") + + for dataset_kind in ["train", "valid"]: + abs_manifest_path = ExtendedPath( + src_dataset_path / Path(f"{dataset_kind}_manifest.json") + ) + manifest_data = list(abs_manifest_path.read_jsonl()) + o_tsv, o_ltr = f"{dataset_kind}.tsv", f"{dataset_kind}.ltr" + out_tsv = dest_dataset_path / Path(o_tsv) + out_ltr = dest_dataset_path / Path(o_ltr) + with out_tsv.open("w") as tsv_f, out_ltr.open("w") as ltr_f: + if unlink: + tsv_f.write(f"{dest_dataset_path}\n") + else: + tsv_f.write(f"{src_dataset_path}\n") + for md in manifest_data: + audio_fname = md["audio_filepath"] + pipe_toks = replace_redundant_spaces_with(md["text"], "|").upper() + # pipe_toks = "|".join(re.sub(" ", "", md["text"])) + # pipe_toks = alnum_to_asr_tokens(md["text"]).upper().replace(" ", "|") + tok_counter.update(pipe_toks) + letter_toks = " ".join(pipe_toks) + " |\n" + frame_count = soundfile.info(audio_fname).frames + rel_path = Path(audio_fname).relative_to(src_dataset_path.absolute()) + ltr_f.write(letter_toks) + tsv_f.write(f"{rel_path}\t{frame_count}\n") + with dict_ltr.open("w") as d_f: + for k, v in tok_counter.most_common(): + d_f.write(f"{k} {v}\n") + (src_dataset_path / Path("valid_manifest.json")).unlink() + + +@app.command() +def set_root(dataset_path: Path, root_path: Path): + for dataset_kind in ["train", "valid"]: + data_file = dataset_path / Path(dataset_kind).with_suffix(".tsv") + with data_file.open("r") as df: + lines = df.readlines() + with data_file.open("w") as df: + lines[0] = str(root_path) + "\n" + df.writelines(lines) + + +def main(): + app() + + +if __name__ == "__main__": + main() diff --git a/plume/models/wav2vec2/eval.py b/plume/models/wav2vec2/eval.py new file mode 100644 index 0000000..fbbb3f3 --- /dev/null +++ b/plume/models/wav2vec2/eval.py @@ -0,0 +1,49 @@ +from pathlib import Path +import typer +from tqdm import tqdm +# import pandas as pd + +from plume.utils import ( + asr_manifest_reader, + discard_except_digits, + replace_digit_symbol, + lazy_module + # run_shell, +) +from ...utils.transcribe import triton_transcribe_grpc_gen + +pd = lazy_module('pandas') +app = typer.Typer() + + +@app.command() +def manifest(manifest_file: Path, result_file: Path = "results.csv"): + from pydub import AudioSegment + + host = "localhost" + port = 8044 + transcriber, audio_prep = triton_transcribe_grpc_gen(host, port, method='whole') + result_path = manifest_file.parent / result_file + manifest_list = list(asr_manifest_reader(manifest_file)) + + def compute_frame(d): + audio_file = d["audio_path"] + orig_text = d["text"] + orig_num = discard_except_digits(replace_digit_symbol(orig_text)) + aud_seg = AudioSegment.from_file(audio_file) + t_audio = audio_prep(aud_seg) + asr_text = transcriber(t_audio) + asr_num = 
discard_except_digits(replace_digit_symbol(asr_text))
+        return {
+            "audio_file": audio_file,
+            "asr_text": asr_text,
+            "asr_num": asr_num,
+            "orig_text": orig_text,
+            "orig_num": orig_num,
+            "asr_match": orig_num == asr_num,
+        }
+
+    # df_data = parallel_apply(compute_frame, manifest_list)
+    df_data = map(compute_frame, tqdm(manifest_list))
+    df = pd.DataFrame(df_data)
+    df.to_csv(result_path)
diff --git a/plume/models/wav2vec2/serve.py b/plume/models/wav2vec2/serve.py
new file mode 100644
index 0000000..b549904
--- /dev/null
+++ b/plume/models/wav2vec2/serve.py
@@ -0,0 +1,53 @@
+import os
+import logging
+from pathlib import Path
+
+# from rpyc.utils.server import ThreadedServer
+import typer
+
+from ...utils.serve import ASRService
+from plume.utils import lazy_callable
+# from .asr import Wav2Vec2ASR
+
+ThreadedServer = lazy_callable('rpyc.utils.server.ThreadedServer')
+Wav2Vec2ASR = lazy_callable('plume.models.wav2vec2.asr.Wav2Vec2ASR')
+
+app = typer.Typer()
+
+
+@app.command()
+def rpyc(
+    w2v_path: Path = "/path/to/base.pt",
+    ctc_path: Path = "/path/to/ctc.pt",
+    target_dict_path: Path = "/path/to/dict.ltr.txt",
+    port: int = int(os.environ.get("ASR_RPYC_PORT", "8044")),
+):
+    for p in [w2v_path, ctc_path, target_dict_path]:
+        if not p.exists():
+            logging.info(f"{p} doesn't exist")
+            return
+    w2vasr = Wav2Vec2ASR(str(ctc_path), str(w2v_path), str(target_dict_path))
+    service = ASRService(w2vasr)
+    logging.basicConfig(
+        level=logging.INFO,
+        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+    )
+    logging.info("starting asr server...")
+    t = ThreadedServer(service, port=port)
+    t.start()
+
+
+@app.command()
+def rpyc_dir(model_dir: Path, port: int = int(os.environ.get("ASR_RPYC_PORT", "8044"))):
+    ctc_path = model_dir / Path("ctc.pt")
+    w2v_path = model_dir / Path("base.pt")
+    target_dict_path = model_dir / Path("dict.ltr.txt")
+    rpyc(w2v_path, ctc_path, target_dict_path, port)
+
+
+def main():
+    app()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/plume/models/wav2vec2/train.py b/plume/models/wav2vec2/train.py
new file mode 100644
index 0000000..ffbaeca
--- /dev/null
+++ b/plume/models/wav2vec2/train.py
@@ -0,0 +1,34 @@
+import typer
+# from fairseq_cli.train import cli_main
+import sys
+from pathlib import Path
+import shlex
+from plume.utils import lazy_callable
+
+cli_main = lazy_callable('fairseq_cli.train.cli_main')
+
+app = typer.Typer()
+
+
+@app.command()
+def local(dataset_path: Path):
+    args = f'''--distributed-world-size 1 {dataset_path} \
+--save-dir /dataset/wav2vec2/model/wav2vec2_l_num_ctc_v1 --post-process letter --valid-subset \
+valid --no-epoch-checkpoints --best-checkpoint-metric wer --num-workers 4 --max-update 80000 \
+--sentence-avg --task audio_pretraining --arch wav2vec_ctc --w2v-path /dataset/wav2vec2/pretrained/wav2vec_vox_new.pt \
+--labels ltr --apply-mask --mask-selection static --mask-other 0 --mask-length 10 --mask-prob 0.5 --layerdrop 0.1 \
+--mask-channel-selection static --mask-channel-other 0 --mask-channel-length 64 --mask-channel-prob 0.5 \
+--zero-infinity --feature-grad-mult 0.0 --freeze-finetune-updates 10000 --validate-after-updates 10000 \
+--optimizer adam --adam-betas '(0.9, 0.98)' --adam-eps 1e-08 --lr 2e-05 --lr-scheduler tri_stage --warmup-steps 8000 \
+--hold-steps 32000 --decay-steps 40000 --final-lr-scale 0.05 --final-dropout 0.0 --dropout 0.0 \
+--activation-dropout 0.1 --criterion ctc --attention-dropout 0.0 --max-tokens 1280000 --seed 2337 --log-format json \
+--log-interval 500 --ddp-backend no_c10d 
--reset-optimizer --normalize
+'''
+    new_args = ['train.py']
+    new_args.extend(shlex.split(args))
+    sys.argv = new_args
+    cli_main()
+
+
+if __name__ == "__main__":
+    app()
diff --git a/plume/ui/__init__.py b/plume/ui/__init__.py
new file mode 100644
index 0000000..3aa516d
--- /dev/null
+++ b/plume/ui/__init__.py
@@ -0,0 +1,64 @@
+import typer
+import sys
+from pathlib import Path
+
+from plume.utils import lazy_module
+# from streamlit import cli as stcli
+
+stcli = lazy_module('streamlit.cli')
+app = typer.Typer()
+
+
+@app.command()
+def annotation(data_dir: Path, dump_fname: Path = "ui_dump.json", task_id: str = ""):
+    annotation_lit_path = Path(__file__).parent / Path("annotation.py")
+    if task_id:
+        sys.argv = [
+            "streamlit",
+            "run",
+            str(annotation_lit_path),
+            "--",
+            str(data_dir),
+            "--task-id",
+            task_id,
+            "--dump-fname",
+            str(dump_fname),
+        ]
+    else:
+        sys.argv = [
+            "streamlit",
+            "run",
+            str(annotation_lit_path),
+            "--",
+            str(data_dir),
+            "--dump-fname",
+            str(dump_fname),
+        ]
+    sys.exit(stcli.main())
+
+
+@app.command()
+def preview(manifest_path: Path):
+    preview_lit_path = Path(__file__).parent / Path("preview.py")
+    sys.argv = [
+        "streamlit",
+        "run",
+        str(preview_lit_path),
+        "--",
+        str(manifest_path)
+    ]
+    sys.exit(stcli.main())
+
+
+@app.command()
+def collection(data_dir: Path, task_id: str = ""):
+    # TODO: Implement web ui for data collection
+    pass
+
+
+def main():
+    app()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/jasper/data/validation/ui.py b/plume/ui/annotation.py
similarity index 78%
rename from jasper/data/validation/ui.py
rename to plume/ui/annotation.py
index 8d6a72c..1c45c54 100644
--- a/jasper/data/validation/ui.py
+++ b/plume/ui/annotation.py
@@ -1,10 +1,12 @@
+# import sys
 from pathlib import Path
+from uuid import uuid4
 
 import streamlit as st
 import typer
-from uuid import uuid4
-from ..utils import ExtendedPath, get_mongo_conn
-from .st_rerun import rerun
+
+from plume.utils import ExtendedPath, get_mongo_conn
+from plume.ui.st_rerun import rerun
 
 app = typer.Typer()
 
@@ -42,10 +44,10 @@ if not hasattr(st, "mongo_connected"):
             upsert=True,
         )
 
-    def set_task_fn(mf_path, task_id):
+    def set_task_fn(data_path, task_id):
         if task_id:
             st.task_id = task_id
-        task_path = mf_path.parent / Path(f"task-{st.task_id}.lck")
+        task_path = data_path / Path(f"task-{st.task_id}.lck")
         if not task_path.exists():
             print(f"creating task lock at {task_path}")
             task_path.touch()
@@ -62,17 +64,28 @@ if not hasattr(st, "mongo_connected"):
 
 
 @st.cache()
-def load_ui_data(validation_ui_data_path: Path):
+def load_ui_data(data_dir: Path, dump_fname: Path):
+    validation_ui_data_path = data_dir / dump_fname
     typer.echo(f"Using validation ui data from {validation_ui_data_path}")
     return ExtendedPath(validation_ui_data_path).read_json()
 
 
+def show_key(sample, key, trail=""):
+    if key in sample:
+        title = key.replace("_", " ").title()
+        if type(sample[key]) == float:
+            st.sidebar.markdown(f"{title}: {sample[key]:.2f}{trail}")
+        else:
+            st.sidebar.markdown(f"{title}: {sample[key]}")
+
+
 @app.command()
-def main(manifest: Path, task_id: str = ""):
-    st.set_task(manifest, task_id)
-    ui_config = load_ui_data(manifest)
+def main(data_dir: Path, dump_fname: Path = "ui_dump.json", task_id: str = ""):
+    st.set_task(data_dir, task_id)
+    ui_config = load_ui_data(data_dir, dump_fname)
     asr_data = ui_config["data"]
     annotation_only = ui_config.get("annotation_only", False)
+    asr_result_key = ui_config.get("asr_result_key", "pretrained_asr")
     sample_no = st.get_current_cursor()
     if 
len(asr_data) - 1 < sample_no or sample_no < 0:
         print("Invalid samplno resetting to 0")
@@ -91,15 +104,16 @@ def main(manifest: Path, task_id: str = ""):
         st.update_cursor(new_sample - 1)
     st.sidebar.title(f"Details: [{sample['real_idx']}]")
     st.sidebar.markdown(f"Gold Text: **{sample['text']}**")
+    # if "caller" in sample:
+    #     st.sidebar.markdown(f"Caller: **{sample['caller']}**")
+    show_key(sample, "caller")
     if not annotation_only:
-        st.sidebar.title("Results:")
-        st.sidebar.markdown(f"Pretrained: **{sample['pretrained_asr']}**")
-        if "caller" in sample:
-            st.sidebar.markdown(f"Caller: **{sample['caller']}**")
-    else:
-        st.sidebar.title(f"Pretrained WER: {sample['pretrained_wer']:.2f}%")
-    st.sidebar.image(Path(sample["plot_path"]).read_bytes())
-    st.audio(Path(sample["audio_path"]).open("rb"))
+        show_key(sample, asr_result_key)
+        show_key(sample, "asr_wer", trail="%")
+        show_key(sample, "correct_candidate")
+
+    st.sidebar.image((data_dir / Path(sample["plot_path"])).read_bytes())
+    st.audio((data_dir / Path(sample["audio_path"])).open("rb"))
     # set default to text
     corrected = sample["text"]
     correction_entry = st.get_correction_entry(sample["utterance_id"])
diff --git a/plume/ui/preview.py b/plume/ui/preview.py
new file mode 100644
index 0000000..60f8dd6
--- /dev/null
+++ b/plume/ui/preview.py
@@ -0,0 +1,58 @@
+from pathlib import Path
+
+import streamlit as st
+import typer
+from plume.utils import ExtendedPath
+from plume.ui.st_rerun import rerun
+
+app = typer.Typer()
+
+if not hasattr(st, "state_lock"):
+    # st.task_id = str(uuid4())
+    task_path = ExtendedPath("preview.lck")
+
+    def current_cursor_fn():
+        return task_path.read_json()["current_cursor"]
+
+    def update_cursor_fn(val=0):
+        task_path.write_json({"current_cursor": val})
+        rerun()
+
+    st.get_current_cursor = current_cursor_fn
+    st.update_cursor = update_cursor_fn
+    st.state_lock = True
+    # cursor_obj = mongo_conn.find_one({"type": "current_cursor", "task_id": st.task_id})
+    # if not cursor_obj:
+    update_cursor_fn(0)
+
+
+@st.cache()
+def load_ui_data(validation_ui_data_path: Path):
+    typer.echo(f"Using validation ui data from {validation_ui_data_path}")
+    return list(ExtendedPath(validation_ui_data_path).read_jsonl())
+
+
+@app.command()
+def main(manifest: Path):
+    asr_data = load_ui_data(manifest)
+    sample_no = st.get_current_cursor()
+    if len(asr_data) - 1 < sample_no or sample_no < 0:
+        print("Invalid sample_no resetting to 0")
+        st.update_cursor(0)
+    sample = asr_data[sample_no]
+    st.title("ASR Manifest Preview")
+    st.markdown(f"{sample_no+1} of {len(asr_data)} : **{sample['text']}**")
+    new_sample = st.number_input(
+        "Go To Sample:", value=sample_no + 1, min_value=1, max_value=len(asr_data)
+    )
+    if new_sample != sample_no + 1:
+        st.update_cursor(new_sample - 1)
+    st.sidebar.markdown(f"Gold Text: **{sample['text']}**")
+    st.audio((manifest.parent / Path(sample["audio_filepath"])).open("rb"))
+
+
+if __name__ == "__main__":
+    try:
+        app()
+    except SystemExit:
+        pass
diff --git a/jasper/data/validation/st_rerun.py b/plume/ui/st_rerun.py
similarity index 63%
rename from jasper/data/validation/st_rerun.py
rename to plume/ui/st_rerun.py
index ae80624..6243d70 100644
--- a/jasper/data/validation/st_rerun.py
+++ b/plume/ui/st_rerun.py
@@ -1,7 +1,15 @@
-import streamlit.ReportThread as ReportThread
-from streamlit.ScriptRequestQueue import RerunData
-from streamlit.ScriptRunner import RerunException
-from streamlit.server.Server import Server
+try:
+    # Before Streamlit 0.65
+    from streamlit.ReportThread import get_report_ctx
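+    # NOTE: rerun() leans on Streamlit internals rather than a public API;
+    # this try/except shim covers the ~0.65 module renames, but later
+    # Streamlit releases may move these internals again.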
+ from streamlit.server.Server import Server + from streamlit.ScriptRequestQueue import RerunData + from streamlit.ScriptRunner import RerunException +except ModuleNotFoundError: + # After Streamlit 0.65 + from streamlit.report_thread import get_report_ctx + from streamlit.server.server import Server + from streamlit.script_request_queue import RerunData + from streamlit.script_runner import RerunException def rerun(): @@ -13,7 +21,7 @@ def rerun(): def _get_widget_states(): # Hack to get the session object from Streamlit. - ctx = ReportThread.get_report_ctx() + ctx = get_report_ctx() session = None @@ -34,5 +42,4 @@ def _get_widget_states(): "Are you doing something fancy with threads?" ) # Got the session object! - return session._widget_states diff --git a/plume/utils/__init__.py b/plume/utils/__init__.py new file mode 100644 index 0000000..2b43219 --- /dev/null +++ b/plume/utils/__init__.py @@ -0,0 +1,486 @@ +import io +import os +import re +import json +import wave +import logging +from pathlib import Path +from functools import partial +from uuid import uuid4 +from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor +import subprocess +import shutil +from urllib.parse import urlsplit +# from .lazy_loader import LazyLoader +from .lazy_import import lazy_callable, lazy_module + +# from ruamel.yaml import YAML +# import boto3 +import typer +# import pymongo +# from slugify import slugify +# import pydub +# import matplotlib.pyplot as plt +# import librosa +# import librosa.display as audio_display +# from natural.date import compress +# from num2words import num2words +from tqdm import tqdm +from datetime import timedelta + +# from .transcribe import triton_transcribe_grpc_gen +# from .eval import app as eval_app +from .tts import app as tts_app +from .transcribe import app as transcribe_app +from .align import app as align_app + +boto3 = lazy_module('boto3') +pymongo = lazy_module('pymongo') +pydub = lazy_module('pydub') +audio_display = lazy_module('librosa.display') +plt = lazy_module('matplotlib.pyplot') +librosa = lazy_module('librosa') +YAML = lazy_callable('ruamel.yaml.YAML') +num2words = lazy_callable('num2words.num2words') +slugify = lazy_callable('slugify.slugify') +compress = lazy_callable('natural.date.compress') + +app = typer.Typer() +app.add_typer(tts_app, name="tts") +app.add_typer(align_app, name="align") +app.add_typer(transcribe_app, name="transcribe") + + +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger(__name__) + + +def manifest_str(path, dur, text): + return ( + json.dumps({"audio_filepath": path, "duration": round(dur, 1), "text": text}) + + "\n" + ) + + +def duration_str(seconds): + return compress(timedelta(seconds=seconds), pad=" ") + + +def replace_digit_symbol(w2v_out): + num_int_map = {num2words(i): str(i) for i in range(10)} + out = w2v_out.lower() + for (k, v) in num_int_map.items(): + out = re.sub(k, v, out) + return out + + +def discard_except_digits(inp): + return re.sub("[^0-9]", "", inp) + + +def digits_to_chars(text): + num_tokens = [num2words(c) + " " if "0" <= c <= "9" else c for c in text] + return ("".join(num_tokens)).lower() + + +def replace_redundant_spaces_with(text, sub): + return re.sub(" +", sub, text) + + +def space_out(text): + letters = " ".join(list(text)) + return letters + + +def wav_bytes(audio_bytes, frame_rate=24000): + wf_b = io.BytesIO() + with wave.open(wf_b, mode="w") as wf: + wf.setnchannels(1) + wf.setframerate(frame_rate) 
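+        # assumes audio_bytes holds raw little-endian mono PCM frames at
+        # frame_rate; setsampwidth(2) below declares 16-bit (2-byte) samples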
+ wf.setsampwidth(2) + wf.writeframesraw(audio_bytes) + return wf_b.getvalue() + + +def tscript_uuid_fname(transcript): + return str(uuid4()) + "_" + slugify(transcript, max_length=8) + + +def run_shell(cmd_str, work_dir="."): + cwd_path = Path(work_dir).absolute() + p = subprocess.Popen( + cmd_str, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + shell=True, + cwd=cwd_path, + ) + for line in p.stdout: + print(line.replace(b"\n", b"").decode("utf-8")) + + +def upload_s3(dataset_path, s3_path): + run_shell(f"aws s3 sync {dataset_path} {s3_path}") + + +def get_download_path(s3_uri, output_path): + s3_uri_p = urlsplit(s3_uri) + download_path = output_path / Path(s3_uri_p.path[1:]) + download_path.parent.mkdir(exist_ok=True, parents=True) + return download_path + + +def s3_downloader(): + s3 = boto3.client("s3") + + def download_s3(s3_uri, download_path): + s3_uri_p = urlsplit(s3_uri) + download_path.parent.mkdir(exist_ok=True, parents=True) + if not download_path.exists(): + print(f"downloading {s3_uri} to {download_path}") + s3.download_file(s3_uri_p.netloc, s3_uri_p.path[1:], str(download_path)) + + return download_s3 + + +def asr_data_writer(dataset_dir, asr_data_source, verbose=False): + (dataset_dir / Path("wavs")).mkdir(parents=True, exist_ok=True) + asr_manifest = dataset_dir / Path("manifest.json") + num_datapoints = 0 + with asr_manifest.open("w") as mf: + print(f"writing manifest to {asr_manifest}") + for transcript, audio_dur, wav_data in asr_data_source: + fname = tscript_uuid_fname(transcript) + audio_file = dataset_dir / Path("wavs") / Path(fname).with_suffix(".wav") + audio_file.write_bytes(wav_data) + rel_data_path = audio_file.relative_to(dataset_dir) + manifest = manifest_str(str(rel_data_path), audio_dur, transcript) + mf.write(manifest) + if verbose: + print(f"writing '{transcript}' of duration {audio_dur}") + num_datapoints += 1 + return num_datapoints + + +def ui_data_generator(dataset_dir, asr_data_source, verbose=False): + (dataset_dir / Path("wavs")).mkdir(parents=True, exist_ok=True) + (dataset_dir / Path("wav_plots")).mkdir(parents=True, exist_ok=True) + + def data_fn( + transcript, + audio_dur, + wav_data, + caller_name, + aud_seg, + fname, + audio_file, + num_datapoints, + rel_data_path, + ): + png_path = Path(fname).with_suffix(".png") + rel_plot_path = Path("wav_plots") / png_path + wav_plot_path = dataset_dir / rel_plot_path + if not wav_plot_path.exists(): + plot_seg(wav_plot_path.absolute(), audio_file) + return { + "audio_path": str(rel_data_path), + "duration": round(audio_dur, 1), + "text": transcript, + "real_idx": num_datapoints, + "caller": caller_name, + "utterance_id": fname, + "plot_path": str(rel_plot_path), + } + + num_datapoints = 0 + data_funcs = [] + for transcript, audio_dur, wav_data, caller_name, aud_seg in asr_data_source: + fname = str(uuid4()) + "_" + slugify(transcript, max_length=8) + audio_file = ( + dataset_dir / Path("wavs") / Path(fname).with_suffix(".wav") + ).absolute() + audio_file.write_bytes(wav_data) + # audio_path = str(audio_file) + rel_data_path = audio_file.relative_to(dataset_dir.absolute()) + data_funcs.append( + partial( + data_fn, + transcript, + audio_dur, + wav_data, + caller_name, + aud_seg, + fname, + audio_file, + num_datapoints, + rel_data_path, + ) + ) + num_datapoints += 1 + ui_data = parallel_apply(lambda x: x(), data_funcs) + return ui_data, num_datapoints + + +def ui_dump_manifest_writer(dataset_dir, asr_data_source, verbose=False): + dump_data, num_datapoints = ui_data_generator( + dataset_dir, 
asr_data_source, verbose=verbose + ) + + asr_manifest = dataset_dir / Path("manifest.json") + with asr_manifest.open("w") as mf: + print(f"writing manifest to {asr_manifest}") + for d in dump_data: + rel_data_path = d["audio_path"] + audio_dur = d["duration"] + transcript = d["text"] + manifest = manifest_str(str(rel_data_path), audio_dur, transcript) + mf.write(manifest) + + ui_dump_file = dataset_dir / Path("ui_dump.json") + ExtendedPath(ui_dump_file).write_json({"data": dump_data}) + return num_datapoints + + +def asr_manifest_reader(data_manifest_path: Path): + print(f"reading manifest from {data_manifest_path}") + with data_manifest_path.open("r") as pf: + data_jsonl = pf.readlines() + data_data = [json.loads(v) for v in data_jsonl] + for p in data_data: + p["audio_path"] = data_manifest_path.parent / Path(p["audio_filepath"]) + p["text"] = p["text"].strip() + yield p + + +def asr_manifest_writer(asr_manifest_path: Path, manifest_str_source): + with asr_manifest_path.open("w") as mf: + print(f"opening {asr_manifest_path} for writing manifest") + for mani_dict in manifest_str_source: + manifest = manifest_str( + mani_dict["audio_filepath"], mani_dict["duration"], mani_dict["text"] + ) + mf.write(manifest) + + +def asr_test_writer(out_file_path: Path, source): + def dd_str(dd, idx): + path = dd["audio_filepath"] + # dur = dd["duration"] + # return f"SAY {idx}\nPAUSE 3\nPLAY {path}\nPAUSE 3\n\n" + return f"PAUSE 2\nPLAY {path}\nPAUSE 60\n\n" + + res_file = out_file_path.with_suffix(".result.json") + with out_file_path.open("w") as of: + print(f"opening {out_file_path} for writing test") + results = [] + idx = 0 + for ui_dd in source: + results.append(ui_dd) + out_str = dd_str(ui_dd, idx) + of.write(out_str) + idx += 1 + of.write("DO_HANGUP\n") + ExtendedPath(res_file).write_json(results) + + +def batch(iterable, n=1): + ls = len(iterable) + return [iterable[ndx : min(ndx + n, ls)] for ndx in range(0, ls, n)] + + +class ExtendedPath(type(Path())): + """docstring for ExtendedPath.""" + + def read_json(self): + print(f"reading json from {self}") + with self.open("r") as jf: + return json.load(jf) + + def read_yaml(self): + yaml = YAML(typ="safe", pure=True) + print(f"reading yaml from {self}") + with self.open("r") as yf: + return yaml.load(yf) + + def read_jsonl(self): + print(f"reading jsonl from {self}") + with self.open("r") as jf: + for l in jf.readlines(): + yield json.loads(l) + + def write_json(self, data): + print(f"writing json to {self}") + self.parent.mkdir(parents=True, exist_ok=True) + with self.open("w") as jf: + json.dump(data, jf, indent=2) + + def write_yaml(self, data): + yaml = YAML() + print(f"writing yaml to {self}") + with self.open("w") as yf: + yaml.dump(data, yf) + + def write_jsonl(self, data): + print(f"writing jsonl to {self}") + self.parent.mkdir(parents=True, exist_ok=True) + with self.open("w") as jf: + for d in data: + jf.write(json.dumps(d) + "\n") + + +def get_mongo_coll(uri): + ud = pymongo.uri_parser.parse_uri(uri) + conn = pymongo.MongoClient(uri) + return conn[ud["database"]][ud["collection"]] + + +def get_mongo_conn(host="", port=27017, db="db", col="collection"): + mongo_host = host if host else os.environ.get("MONGO_HOST", "localhost") + mongo_uri = f"mongodb://{mongo_host}:{port}/" + return pymongo.MongoClient(mongo_uri)[db][col] + + +def strip_silence(sound): + from pydub.silence import detect_leading_silence + + start_trim = detect_leading_silence(sound) + end_trim = detect_leading_silence(sound.reverse()) + duration = len(sound) + return 
sound[start_trim : duration - end_trim] + + +def plot_seg(wav_plot_path, audio_path): + fig = plt.Figure() + ax = fig.add_subplot() + (y, sr) = librosa.load(str(audio_path)) + audio_display.waveplot(y=y, sr=sr, ax=ax) + with wav_plot_path.open("wb") as wav_plot_f: + fig.set_tight_layout(True) + fig.savefig(wav_plot_f, format="png", dpi=50) + + +def parallel_apply(fn, iterable, workers=8, pool="thread"): + if pool == "thread": + with ThreadPoolExecutor(max_workers=workers) as exe: + print(f"parallelly applying {fn}") + return [ + res + for res in tqdm( + exe.map(fn, iterable), position=0, leave=True, total=len(iterable) + ) + ] + elif pool == "process": + with ProcessPoolExecutor(max_workers=workers) as exe: + print(f"parallelly applying {fn}") + return [ + res + for res in tqdm( + exe.map(fn, iterable), position=0, leave=True, total=len(iterable) + ) + ] + else: + raise Exception(f"unsupported pool type - {pool}") + + +def generate_filter_map(src_dataset_path, dest_dataset_path, data_file): + min_nums = 3 + max_duration = 1 * 60 * 60 + skip_duration = 1 * 60 * 60 + + def filtered_max_dur(): + wav_duration = 0 + for s in ExtendedPath(data_file).read_jsonl(): + nums = re.sub(" ", "", s["text"]) + if len(nums) >= min_nums: + wav_duration += s["duration"] + shutil.copy( + src_dataset_path / Path(s["audio_filepath"]), + dest_dataset_path / Path(s["audio_filepath"]), + ) + yield s + if wav_duration > max_duration: + break + typer.echo(f"filtered only {duration_str(wav_duration)} of audio") + + def filtered_skip_dur(): + wav_duration = 0 + for s in ExtendedPath(data_file).read_jsonl(): + nums = re.sub(" ", "", s["text"]) + if len(nums) >= min_nums: + wav_duration += s["duration"] + if wav_duration <= skip_duration: + continue + elif len(nums) >= min_nums: + yield s + shutil.copy( + src_dataset_path / Path(s["audio_filepath"]), + dest_dataset_path / Path(s["audio_filepath"]), + ) + typer.echo(f"skipped {duration_str(skip_duration)} of audio") + + def filtered_blanks(): + blank_count = 0 + for s in ExtendedPath(data_file).read_jsonl(): + nums = re.sub(" ", "", s["text"]) + if nums != "": + blank_count += 1 + shutil.copy( + src_dataset_path / Path(s["audio_filepath"]), + dest_dataset_path / Path(s["audio_filepath"]), + ) + yield s + typer.echo(f"filtered {blank_count} blank samples") + + def filtered_transform_digits(): + count = 0 + for s in ExtendedPath(data_file).read_jsonl(): + count += 1 + digit_text = replace_digit_symbol(s["text"]) + only_digits = discard_except_digits(digit_text) + char_text = digits_to_chars(only_digits) + shutil.copy( + src_dataset_path / Path(s["audio_filepath"]), + dest_dataset_path / Path(s["audio_filepath"]), + ) + s["text"] = char_text + yield s + typer.echo(f"transformed {count} samples") + + def filtered_extract_chars(): + count = 0 + for s in ExtendedPath(data_file).read_jsonl(): + count += 1 + no_digits = digits_to_chars(s["text"]).upper() + only_chars = re.sub("[^A-Z'\b]", " ", no_digits) + filter_text = replace_redundant_spaces_with(only_chars, " ").strip() + shutil.copy( + src_dataset_path / Path(s["audio_filepath"]), + dest_dataset_path / Path(s["audio_filepath"]), + ) + s["text"] = filter_text + yield s + typer.echo(f"transformed {count} samples") + + def filtered_resample(): + count = 0 + for s in ExtendedPath(data_file).read_jsonl(): + count += 1 + src_aud = pydub.AudioSegment.from_file( + src_dataset_path / Path(s["audio_filepath"]) + ) + dst_aud = src_aud.set_channels(1).set_sample_width(1).set_frame_rate(24000) + dst_aud.export(dest_dataset_path / 
Path(s["audio_filepath"]), format="wav") + yield s + typer.echo(f"transformed {count} samples") + + filter_kind_map = { + "max_dur_1hr_min3num": filtered_max_dur, + "skip_dur_1hr_min3num": filtered_skip_dur, + "blanks": filtered_blanks, + "transform_digits": filtered_transform_digits, + "extract_chars": filtered_extract_chars, + "resample_ulaw24kmono": filtered_resample, + } + return filter_kind_map diff --git a/plume/utils/align.py b/plume/utils/align.py new file mode 100644 index 0000000..5c9c74c --- /dev/null +++ b/plume/utils/align.py @@ -0,0 +1,117 @@ +from pathlib import Path +from .tts import GoogleTTS +# from IPython import display +import requests +import io +import typer + +from plume.utils import lazy_module + +display = lazy_module('IPython.display') +pydub = lazy_module('pydub') + +app = typer.Typer() + +# Start gentle with following command +# docker run --rm -d --name gentle_service -p 8765:8765/tcp lowerquality/gentle + + +def gentle_aligner(service_uri, wav_data, utter_text): + # service_uri= "http://52.41.161.36:8765/transcriptions" + wav_f = io.BytesIO(wav_data) + wav_seg = pydub.AudioSegment.from_file(wav_f) + + mp3_f = io.BytesIO() + wav_seg.export(mp3_f, format="mp3") + mp3_f.seek(0) + params = (("async", "false"),) + files = { + "audio": ("audio.mp3", mp3_f), + "transcript": ("words.txt", io.BytesIO(utter_text.encode("utf-8"))), + } + + response = requests.post(service_uri, params=params, files=files) + print(f"Time duration of audio {wav_seg.duration_seconds}") + print(f"Time taken to align: {response.elapsed}s") + return wav_seg, response.json() + + +def gentle_align_iter(service_uri, wav_data, utter_text): + wav_seg, response = gentle_aligner(service_uri, wav_data, utter_text) + for span in response: + word_seg = wav_seg[int(span["start"] * 1000) : int(span["end"] * 1000)] + word = span["word"] + yield (word, word_seg) + + +def tts_jupyter(): + google_voices = GoogleTTS.voice_list() + gtts = GoogleTTS() + # google_voices[4] + us_voice = [v for v in google_voices if v["language"] == "en-US"][0] + utter_text = ( + "I would like to align the audio segments based on word level timestamps" + ) + wav_data = gtts.text_to_speech(text=utter_text, params=us_voice) + for word, seg in gentle_align_iter(wav_data, utter_text): + print(word) + display.display(seg) + + +@app.command() +def cut(audio_path: Path, transcript_path: Path, out_dir: Path = "/tmp"): + from . import ExtendedPath + import datetime + import re + + aud_seg = pydub.AudioSegment.from_file(audio_path) + aud_seg[: 15 * 60 * 1000].export(out_dir / Path("audio.mp3"), format="mp3") + tscript_json = ExtendedPath(transcript_path).read_json() + + def time_to_msecs(time_str): + return ( + datetime.datetime.strptime(time_str, "%H:%M:%S,%f") + - datetime.datetime(1900, 1, 1) + ).total_seconds() * 1000 + + tscript_words = [] + broken = False + for m in tscript_json["monologues"]: + # tscript_words.append("|") + for e in m["elements"]: + if e["type"] == "text": + text = e["value"] + text = re.sub(r"\[.*\]", "", text) + text = re.sub(r"\(.*\)", "", text) + tscript_words.append(text) + if "timestamp" in e and time_to_msecs(e["timestamp"]) >= 15 * 60 * 1000: + broken = True + break + if broken: + break + (out_dir / Path("words.txt")).write_text("".join(tscript_words)) + + +@app.command() +def gentle_preview( + audio_path: Path, + transcript_path: Path, + service_uri="http://101.53.142.218:8765/transcriptions", + gent_preview_dir="../gentle_preview", +): + from . 
import ExtendedPath
+
+    ab = audio_path.read_bytes()
+    tt = transcript_path.read_text()
+    audio, alignment = gentle_aligner(service_uri, ab, tt)
+    audio.export(gent_preview_dir / Path("a.wav"), format="wav")
+    alignment["status"] = "OK"
+    ExtendedPath(gent_preview_dir / Path("status.json")).write_json(alignment)
+
+
+def main():
+    app()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/plume/utils/audio.py b/plume/utils/audio.py
new file mode 100644
index 0000000..35b9f8f
--- /dev/null
+++ b/plume/utils/audio.py
@@ -0,0 +1,28 @@
+from scipy.signal import lfilter, butter
+from scipy.io.wavfile import read, write
+from numpy import array, int16
+import sys
+
+
+def butter_params(low_freq, high_freq, fs, order=5):
+    nyq = 0.5 * fs
+    low = low_freq / nyq
+    high = high_freq / nyq
+    b, a = butter(order, [low, high], btype="band")
+    return b, a
+
+
+def butter_bandpass_filter(data, low_freq, high_freq, fs, order=5):
+    b, a = butter_params(low_freq, high_freq, fs, order=order)
+    y = lfilter(b, a, data)
+    return y
+
+
+if __name__ == "__main__":
+    fs, audio = read(sys.argv[1])
+    # import pdb; pdb.set_trace()
+    low_freq = 300.0
+    high_freq = 4000.0
+    filtered_signal = butter_bandpass_filter(audio, low_freq, high_freq, fs, order=6)
+    fname = sys.argv[1].split(".wav")[0] + "_moded.wav"
+    write(fname, fs, array(filtered_signal, dtype=int16))
diff --git a/plume/utils/lazy_import.py b/plume/utils/lazy_import.py
new file mode 100644
index 0000000..615f596
--- /dev/null
+++ b/plume/utils/lazy_import.py
@@ -0,0 +1,737 @@
+# -*- Mode: python; tab-width: 4; indent-tabs-mode:nil; coding:utf-8 -*-
+# vim: tabstop=4 expandtab shiftwidth=4 softtabstop=4
+#
+# lazy_import --- https://github.com/mnmelo/lazy_import
+# Copyright (C) 2017-2018 Manuel Nuno Melo
+#
+# This file is part of lazy_import.
+#
+# lazy_import is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# lazy_import is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with lazy_import. If not, see <http://www.gnu.org/licenses/>.
+#
+# lazy_import was based on code from the importing module from the PEAK
+# package (see <http://peak.telecommunity.com/DevCenter/Importing>). The PEAK
+# package is released under the following license, reproduced here:
+#
+# Copyright (C) 1996-2004 by Phillip J. Eby and Tyler C. Sarna.
+# All rights reserved. This software may be used under the same terms
+# as Zope or Python. THERE ARE ABSOLUTELY NO WARRANTIES OF ANY KIND.
+# Code quality varies between modules, from "beta" to "experimental
+# pre-alpha". :)
+#
+# Code pertaining to lazy loading from PEAK importing was included in
+# lazy_import, modified in a number of ways. These are detailed in the
+# CHANGELOG file of lazy_import. Changes mainly involved Python 3
+# compatibility, extension to allow customizable behavior, and added
+# functionality (lazy importing of callable objects).
+#
+
+"""
+Lazy module loading
+===================
+Functions and classes for lazy module loading that also delay import errors.
+Heavily borrowed from the `importing`_ module.
+.. _`importing`: http://peak.telecommunity.com/DevCenter/Importing
+Files and directories
+---------------------
+.. autofunction:: module
+.. 
autofunction:: callable +""" + +__all__ = [ + "lazy_module", + "lazy_callable", + "lazy_function", + "lazy_class", + "LazyModule", + "LazyCallable", + "module_basename", + "_MSG", + "_MSG_CALLABLE", +] + +from types import ModuleType +import sys + +try: + from importlib._bootstrap import _ImportLockContext +except ImportError: + # Python 2 doesn't have the context manager. Roll it ourselves (copied from + # Python 3's importlib/_bootstrap.py) + import imp + + class _ImportLockContext: + """Context manager for the import lock.""" + + def __enter__(self): + imp.acquire_lock() + + def __exit__(self, exc_type, exc_value, exc_traceback): + imp.release_lock() + + +# Adding a __spec__ doesn't really help. I'll leave the code here in case +# future python implementations start relying on it. +# try: +# from importlib.machinery import ModuleSpec +# except ImportError: +# ModuleSpec = None + +import six +from six import raise_from +from six.moves import reload_module + +# It is sometime useful to have access to the version number of a library. +# This is usually done through the __version__ special attribute. +# To make sure the version number is consistent between setup.py and the +# library, we read the version number from the file called VERSION that stays +# in the module directory. +import os + +# VERSION_FILE = os.path.join(os.path.dirname(__file__), "VERSION") +# with open(VERSION_FILE) as infile: +# __version__ = infile.read().strip() + +# Logging +import logging + +# adding a TRACE level for stack debugging +_LAZY_TRACE = 1 +logging.addLevelName(1, "LAZY_TRACE") +logging.basicConfig(level=logging.WARNING) +# Logs a formatted stack (takes no message or args/kwargs) +def _lazy_trace(self): + if self.isEnabledFor(_LAZY_TRACE): + import traceback + + self._log(_LAZY_TRACE, " ### STACK TRACE ###", ()) + for line in traceback.format_stack(sys._getframe(2)): + for subline in line.split("\n"): + self._log(_LAZY_TRACE, subline.rstrip(), ()) + + +logging.Logger.lazy_trace = _lazy_trace +logger = logging.getLogger(__name__) + +################################ +# Module/function registration # +################################ + +#### Lazy classes #### + + +class LazyModule(ModuleType): + """Class for lazily-loaded modules that triggers proper loading on access. + Instantiation should be made from a subclass of :class:`LazyModule`, with + one subclass per instantiated module. Regular attribute set/access can then + be recovered by setting the subclass's :meth:`__getattribute__` and + :meth:`__setattribute__` to those of :class:`types.ModuleType`. + """ + + # peak.util.imports sets __slots__ to (), but it seems pointless because + # the base ModuleType doesn't itself set __slots__. + def __getattribute__(self, attr): + logger.debug( + "Getting attr {} of LazyModule instance of {}".format( + attr, super(LazyModule, self).__getattribute__("__name__") + ) + ) + logger.lazy_trace() + # IPython tries to be too clever and constantly inspects, asking for + # modules' attrs, which causes premature module loading and unesthetic + # internal errors if the lazily-loaded module doesn't exist. 
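+        # The check below refuses dunder/_ipython attribute probes that come
+        # from inspect/IPython frames, so that inspection alone does not
+        # force the real import.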
+ if ( + run_from_ipython() + and (attr.startswith(("__", "_ipython")) or attr == "_repr_mimebundle_") + and module_basename(_caller_name()) in ("inspect", "IPython") + ): + logger.debug( + "Ignoring request for {}, deemed from IPython's " + "inspection.".format( + super(LazyModule, self).__getattribute__("__name__"), attr + ) + ) + raise AttributeError + if not attr in ("__name__", "__class__", "__spec__"): + # __name__ and __class__ yield their values from the LazyModule; + # __spec__ causes an AttributeError. Maybe in the future it will be + # necessary to return an actual ModuleSpec object, but it works as + # it is without that now. + + # If it's an already-loaded submodule, we return it without + # triggering a full loading + try: + return sys.modules[self.__name__ + "." + attr] + except KeyError: + pass + # Check if it's one of the lazy callables + try: + _callable = type(self)._lazy_import_callables[attr] + logger.debug("Returning lazy-callable '{}'.".format(attr)) + return _callable + except (AttributeError, KeyError) as err: + logger.debug( + "Proceeding to load module {}, " + "from requested value {}".format( + super(LazyModule, self).__getattribute__("__name__"), attr + ) + ) + _load_module(self) + logger.debug( + "Returning value '{}'.".format( + super(LazyModule, self).__getattribute__(attr) + ) + ) + return super(LazyModule, self).__getattribute__(attr) + + def __setattr__(self, attr, value): + logger.debug( + "Setting attr {} to value {}, in LazyModule instance " + "of {}".format( + attr, value, super(LazyModule, self).__getattribute__("__name__") + ) + ) + _load_module(self) + return super(LazyModule, self).__setattr__(attr, value) + + +class LazyCallable(object): + """Class for lazily-loaded callables that triggers module loading on access + """ + + def __init__(self, *args): + if len(args) != 2: + # Maybe the user tried to base a class off this lazy callable? + try: + logger.debug( + "Got wrong number of args when init'ing " + "LazyCallable. args is '{}'".format(args) + ) + base = args[1][0] + if isinstance(base, LazyCallable) and len(args) == 3: + raise NotImplementedError( + "It seems you are trying to use " + "a lazy callable as a class " + "base. This is not supported." + ) + except (IndexError, TypeError): + raise_from( + TypeError( + "LazyCallable takes exactly 2 arguments: " + "a module/lazy module object and the name of " + "a callable to be lazily loaded." + ), + None, + ) + self.module, self.cname = args + self.modclass = type(self.module) + self.callable = None + # Need to save these, since the module-loading gets rid of them + self.error_msgs = self.modclass._lazy_import_error_msgs + self.error_strings = self.modclass._lazy_import_error_strings + + def __call__(self, *args, **kwargs): + # No need to go through all the reloading more than once. + if self.callable: + return self.callable(*args, **kwargs) + try: + del self.modclass._lazy_import_callables[self.cname] + except (AttributeError, KeyError): + pass + try: + self.callable = getattr(self.module, self.cname) + except AttributeError: + msg = self.error_msgs["msg_callable"] + raise_from( + AttributeError(msg.format(callable=self.cname, **self.error_strings)), + None, + ) + except ImportError as err: + # Import failed. We reset the dict and re-raise the ImportError. 
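+            # Re-registering the wrapper lets a later call retry the import,
+            # e.g. once the missing dependency has been installed.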
+ try: + self.modclass._lazy_import_callables[self.cname] = self + except AttributeError: + self.modclass._lazy_import_callables = {self.cname: self} + raise_from(err, None) + else: + return self.callable(*args, **kwargs) + + +### Functions ### + + +def lazy_module(modname, error_strings=None, lazy_mod_class=LazyModule, level="leaf"): + """Function allowing lazy importing of a module into the namespace. + A lazy module object is created, registered in `sys.modules`, and + returned. This is a hollow module; actual loading, and `ImportErrors` if + not found, are delayed until an attempt is made to access attributes of the + lazy module. + A handy application is to use :func:`lazy_module` early in your own code + (say, in `__init__.py`) to register all modulenames you want to be lazy. + Because of registration in `sys.modules` later invocations of + `import modulename` will also return the lazy object. This means that after + initial registration the rest of your code can use regular pyhon import + statements and retain the lazyness of the modules. + Parameters + ---------- + modname : str + The module to import. + error_strings : dict, optional + A dictionary of strings to use when module-loading fails. Key 'msg' + sets the message to use (defaults to :attr:`lazy_import._MSG`). The + message is formatted using the remaining dictionary keys. The default + message informs the user of which module is missing (key 'module'), + what code loaded the module as lazy (key 'caller'), and which package + should be installed to solve the dependency (key 'install_name'). + None of the keys is mandatory and all are given smart names by default. + lazy_mod_class: type, optional + Which class to use when instantiating the lazy module, to allow + deep customization. The default is :class:`LazyModule` and custom + alternatives **must** be a subclass thereof. + level : str, optional + Which submodule reference to return. Either a reference to the 'leaf' + module (the default) or to the 'base' module. This is useful if you'll + be using the module functionality in the same place you're calling + :func:`lazy_module` from, since then you don't need to run `import` + again. Setting *level* does not affect which names/modules get + registered in `sys.modules`. + For *level* set to 'base' and *modulename* 'aaa.bbb.ccc':: + aaa = lazy_import.lazy_module("aaa.bbb.ccc", level='base') + # 'aaa' becomes defined in the current namespace, with + # (sub)attributes 'aaa.bbb' and 'aaa.bbb.ccc'. + # It's the lazy equivalent to: + import aaa.bbb.ccc + For *level* set to 'leaf':: + ccc = lazy_import.lazy_module("aaa.bbb.ccc", level='leaf') + # Only 'ccc' becomes set in the current namespace. + # Lazy equivalent to: + from aaa.bbb import ccc + Returns + ------- + module + The module specified by *modname*, or its base, depending on *level*. + The module isn't immediately imported. Instead, an instance of + *lazy_mod_class* is returned. Upon access to any of its attributes, the + module is finally loaded. + Examples + -------- + >>> import lazy_import, sys + >>> np = lazy_import.lazy_module("numpy") + >>> np + Lazily-loaded module numpy + >>> np is sys.modules['numpy'] + True + >>> np.pi # This causes the full loading of the module ... + 3.141592653589793 + >>> np # ... and the module is changed in place. 
+ + >>> import lazy_import, sys + >>> # The following succeeds even when asking for a module that's not available + >>> missing = lazy_import.lazy_module("missing_module") + >>> missing + Lazily-loaded module missing_module + >>> missing is sys.modules['missing_module'] + True + >>> missing.some_attr # This causes the full loading of the module, which now fails. + ImportError: __main__ attempted to use a functionality that requires module missing_module, but it couldn't be loaded. Please install missing_module and retry. + See Also + -------- + :func:`lazy_callable` + :class:`LazyModule` + """ + if error_strings is None: + error_strings = {} + _set_default_errornames(modname, error_strings) + + mod = _lazy_module(modname, error_strings, lazy_mod_class) + if level == "base": + return sys.modules[module_basename(modname)] + elif level == "leaf": + return mod + else: + raise ValueError("Parameter 'level' must be one of ('base', 'leaf')") + + +def _lazy_module(modname, error_strings, lazy_mod_class): + with _ImportLockContext(): + fullmodname = modname + fullsubmodname = None + # ensure parent module/package is in sys.modules + # and parent.modname=module, as soon as the parent is imported + while modname: + try: + mod = sys.modules[modname] + # We reached a (base) module that's already loaded. Let's stop + # the cycle. Can't use 'break' because we still want to go + # through the fullsubmodname check below. + modname = "" + except KeyError: + err_s = error_strings.copy() + err_s.setdefault("module", modname) + + class _LazyModule(lazy_mod_class): + _lazy_import_error_msgs = {"msg": err_s.pop("msg")} + try: + _lazy_import_error_msgs["msg_callable"] = err_s.pop( + "msg_callable" + ) + except KeyError: + pass + _lazy_import_error_strings = err_s + _lazy_import_callables = {} + _lazy_import_submodules = {} + + def __repr__(self): + return "Lazily-loaded module {}".format(self.__name__) + + # A bit of cosmetic, to make AttributeErrors read more natural + _LazyModule.__name__ = "module" + # Actual module instantiation + mod = sys.modules[modname] = _LazyModule(modname) + # No need for __spec__. Maybe in the future. + # if ModuleSpec: + # ModuleType.__setattr__(mod, '__spec__', + # ModuleSpec(modname, None)) + if fullsubmodname: + submod = sys.modules[fullsubmodname] + ModuleType.__setattr__(mod, submodname, submod) + _LazyModule._lazy_import_submodules[submodname] = submod + fullsubmodname = modname + modname, _, submodname = modname.rpartition(".") + return sys.modules[fullmodname] + + +def lazy_callable(modname, *names, **kwargs): + """Performs lazy importing of one or more callables. + :func:`lazy_callable` creates functions that are thin wrappers that pass + any and all arguments straight to the target module's callables. These can + be functions or classes. The full loading of that module is only actually + triggered when the returned lazy function itself is called. This lazy + import of the target module uses the same mechanism as + :func:`lazy_module`. + + If, however, the target module has already been fully imported prior + to invocation of :func:`lazy_callable`, then the target callables + themselves are returned and no lazy imports are made. + :func:`lazy_function` and :func:`lazy_function` are aliases of + :func:`lazy_callable`. + Parameters + ---------- + modname : str + The base module from where to import the callable(s) in *names*, + or a full 'module_name.callable_name' string. + names : str (optional) + The callable name(s) to import from the module specified by *modname*. 
+ If left empty, *modname* is assumed to also include the callable name + to import. + error_strings : dict, optional + A dictionary of strings to use when reporting loading errors (either a + missing module, or a missing callable name in the loaded module). + *error_string* follows the same usage as described under + :func:`lazy_module`, with the exceptions that 1) a further key, + 'msg_callable', can be supplied to be used as the error when a module + is successfully loaded but the target callable can't be found therein + (defaulting to :attr:`lazy_import._MSG_CALLABLE`); 2) a key 'callable' + is always added with the callable name being loaded. + lazy_mod_class : type, optional + See definition under :func:`lazy_module`. + lazy_call_class : type, optional + Analogously to *lazy_mod_class*, allows setting a custom class to + handle lazy callables, other than the default :class:`LazyCallable`. + Returns + ------- + wrapper function or tuple of wrapper functions + If *names* is passed, returns a tuple of wrapper functions, one for + each element in *names*. + If only *modname* is passed it is assumed to be a full + 'module_name.callable_name' string, in which case the wrapper for the + imported callable is returned directly, and not in a tuple. + + Notes + ----- + Unlike :func:`lazy_module`, which returns a lazy module that eventually + mutates into the fully-functional version, :func:`lazy_callable` only + returns thin wrappers that never change. This means that the returned + wrapper object never truly becomes the one under the module's namespace, + even after successful loading of the module in *modname*. This is fine for + most practical use cases, but may break code that relies on the usage of + the returned objects oter than calling them. One such example is the lazy + import of a class: it's fine to use the returned wrapper to instantiate an + object, but it can't be used, for instance, to subclass from. + Examples + -------- + >>> import lazy_import, sys + >>> fn = lazy_import.lazy_callable("numpy.arange") + >>> sys.modules['numpy'] + Lazily-loaded module numpy + >>> fn(10) + array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) + >>> sys.modules['numpy'] + + >>> import lazy_import, sys + >>> cl = lazy_import.lazy_callable("numpy.ndarray") # a class + >>> obj = cl([1, 2]) # This works OK (and also triggers the loading of numpy) + >>> class MySubclass(cl): # This fails because cls is just a wrapper, + >>> pass # not an actual class. + See Also + -------- + :func:`lazy_module` + :class:`LazyCallable` + :class:`LazyModule` + """ + if not names: + modname, _, name = modname.rpartition(".") + lazy_mod_class = _setdef(kwargs, "lazy_mod_class", LazyModule) + lazy_call_class = _setdef(kwargs, "lazy_call_class", LazyCallable) + error_strings = _setdef(kwargs, "error_strings", {}) + _set_default_errornames(modname, error_strings, call=True) + + if not names: + # We allow passing a single string as 'modname.callable_name', + # in which case the wrapper is returned directly and not as a list. + return _lazy_callable( + modname, name, error_strings.copy(), lazy_mod_class, lazy_call_class + ) + return tuple( + _lazy_callable( + modname, cname, error_strings.copy(), lazy_mod_class, lazy_call_class + ) + for cname in names + ) + + +lazy_function = lazy_class = lazy_callable + + +def _lazy_callable(modname, cname, error_strings, lazy_mod_class, lazy_call_class): + # We could do most of this in the LazyCallable __init__, but here we can + # pre-check whether to actually be lazy or not. 
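+    # If modname was already fully imported before this call, the module
+    # object is a plain ModuleType, so the getattr below returns the real
+    # callable and no lazy wrapper is registered.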
+ module = _lazy_module(modname, error_strings, lazy_mod_class) + modclass = type(module) + if issubclass(modclass, LazyModule) and hasattr(modclass, "_lazy_import_callables"): + modclass._lazy_import_callables.setdefault( + cname, lazy_call_class(module, cname) + ) + return getattr(module, cname) + + +####################### +# Real module loading # +####################### + + +def _load_module(module): + """Ensures that a module, and its parents, are properly loaded + """ + modclass = type(module) + # We only take care of our own LazyModule instances + if not issubclass(modclass, LazyModule): + raise TypeError("Passed module is not a LazyModule instance.") + with _ImportLockContext(): + parent, _, modname = module.__name__.rpartition(".") + logger.debug("loading module {}".format(modname)) + # We first identify whether this is a loadable LazyModule, then we + # strip as much of lazy_import behavior as possible (keeping it cached, + # in case loading fails and we need to reset the lazy state). + if not hasattr(modclass, "_lazy_import_error_msgs"): + # Alreay loaded (no _lazy_import_error_msgs attr). Not reloading. + return + # First, ensure the parent is loaded (using recursion; *very* unlikely + # we'll ever hit a stack limit in this case). + modclass._LOADING = True + try: + if parent: + logger.debug("first loading parent module {}".format(parent)) + setattr(sys.modules[parent], modname, module) + if not hasattr(modclass, "_LOADING"): + logger.debug("Module {} already loaded by the parent".format(modname)) + # We've been loaded by the parent. Let's bail. + return + cached_data = _clean_lazymodule(module) + try: + # Get Python to do the real import! + reload_module(module) + except: + # Loading failed. We reset our lazy state. + logger.debug("Failed to load module {}. Resetting...".format(modname)) + _reset_lazymodule(module, cached_data) + raise + else: + # Successful load + logger.debug("Successfully loaded module {}".format(modname)) + delattr(modclass, "_LOADING") + _reset_lazy_submod_refs(module) + + except (AttributeError, ImportError) as err: + logger.debug( + "Failed to load {}.\n{}: {}".format( + modname, err.__class__.__name__, err + ) + ) + logger.lazy_trace() + # Under Python 3 reloading our dummy LazyModule instances causes an + # AttributeError if the module can't be found. Would be preferrable + # if we could always rely on an ImportError. As it is we vet the + # AttributeError as thoroughly as possible. + if (six.PY3 and isinstance(err, AttributeError)) and not err.args[ + 0 + ] == "'NoneType' object has no attribute 'name'": + # Not the AttributeError we were looking for. + raise + msg = modclass._lazy_import_error_msgs["msg"] + raise_from( + ImportError(msg.format(**modclass._lazy_import_error_strings)), None + ) + + +############################## +# Helper functions/constants # +############################## + +_MSG = ( + "{caller} attempted to use a functionality that requires module " + "{module}, but it couldn't be loaded. Please install {install_name} " + "and retry." +) + +_MSG_CALLABLE = ( + "{caller} attempted to use a functionality that requires " + "{callable}, of module {module}, but it couldn't be found in that " + "module. Please install a version of {install_name} that has " + "{module}.{callable} and retry." 
+
+_CLS_ATTRS = (
+    "_lazy_import_error_strings",
+    "_lazy_import_error_msgs",
+    "_lazy_import_callables",
+    "_lazy_import_submodules",
+    "__repr__",
+)
+
+_DELETION_DICT = ("_lazy_import_submodules",)
+
+
+def _setdef(argdict, name, defaultvalue):
+    """Like dict.setdefault, but sets the default value also if None is present."""
+    if name not in argdict or argdict[name] is None:
+        argdict[name] = defaultvalue
+    return argdict[name]
+
+
+def module_basename(modname):
+    return modname.partition(".")[0]
+
+
+def _set_default_errornames(modname, error_strings, call=False):
+    # We don't set the modulename default here because it will change for
+    # parents of lazily imported submodules.
+    error_strings.setdefault("caller", _caller_name(3, default="Python"))
+    error_strings.setdefault("install_name", module_basename(modname))
+    error_strings.setdefault("msg", _MSG)
+    if call:
+        error_strings.setdefault("msg_callable", _MSG_CALLABLE)
+
+
+def _caller_name(depth=2, default=""):
+    """Returns the name of the calling namespace."""
+    # The presence of sys._getframe might be implementation-dependent.
+    # It isn't that serious if we can't get the caller's name.
+    try:
+        return sys._getframe(depth).f_globals["__name__"]
+    except AttributeError:
+        return default
+
+
+def _clean_lazymodule(module):
+    """Removes all lazy behavior from a module's class, for loading.
+
+    Also removes all module attributes listed under the module's class deletion
+    dictionaries. Deletion dictionaries are class attributes with names
+    specified in `_DELETION_DICT`.
+
+    Parameters
+    ----------
+    module : LazyModule
+
+    Returns
+    -------
+    dict
+        A dictionary of deleted class attributes, that can be used to reset the
+        lazy state using :func:`_reset_lazymodule`.
+    """
+    modclass = type(module)
+    _clean_lazy_submod_refs(module)
+
+    modclass.__getattribute__ = ModuleType.__getattribute__
+    modclass.__setattr__ = ModuleType.__setattr__
+    cls_attrs = {}
+    for cls_attr in _CLS_ATTRS:
+        try:
+            cls_attrs[cls_attr] = getattr(modclass, cls_attr)
+            delattr(modclass, cls_attr)
+        except AttributeError:
+            pass
+    return cls_attrs
+
+
+def _clean_lazy_submod_refs(module):
+    modclass = type(module)
+    for deldict in _DELETION_DICT:
+        try:
+            delnames = getattr(modclass, deldict)
+        except AttributeError:
+            continue
+        for delname in delnames:
+            try:
+                super(LazyModule, module).__delattr__(delname)
+            except AttributeError:
+                # Maybe raise a warning?
+                pass
+
+
+def _reset_lazymodule(module, cls_attrs):
+    """Resets a module's lazy state from cached data."""
+ """ + modclass = type(module) + del modclass.__getattribute__ + del modclass.__setattr__ + try: + del modclass._LOADING + except AttributeError: + pass + for cls_attr in _CLS_ATTRS: + try: + setattr(modclass, cls_attr, cls_attrs[cls_attr]) + except KeyError: + pass + _reset_lazy_submod_refs(module) + + +def _reset_lazy_submod_refs(module): + modclass = type(module) + for deldict in _DELETION_DICT: + try: + resetnames = getattr(modclass, deldict) + except AttributeError: + continue + for name, submod in resetnames.items(): + super(LazyModule, module).__setattr__(name, submod) + + +def run_from_ipython(): + # Taken from https://stackoverflow.com/questions/5376837 + try: + __IPYTHON__ + return True + except NameError: + return False diff --git a/plume/utils/lazy_loader.py b/plume/utils/lazy_loader.py new file mode 100644 index 0000000..ed475ac --- /dev/null +++ b/plume/utils/lazy_loader.py @@ -0,0 +1,46 @@ +# Code copied from https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/util/lazy_loader.py +"""A LazyLoader class.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import importlib +import types + + +class LazyLoader(types.ModuleType): + """Lazily import a module, mainly to avoid pulling in large dependencies. + + `contrib`, and `ffmpeg` are examples of modules that are large and not always + needed, and this allows them to only be loaded when they are used. + """ + + # The lint error here is incorrect. + def __init__( + self, local_name, parent_module_globals, name + ): # pylint: disable=super-on-old-class + self._local_name = local_name + self._parent_module_globals = parent_module_globals + + super(LazyLoader, self).__init__(name) + + def _load(self): + # Import the target module and insert it into the parent's namespace + module = importlib.import_module(self.__name__) + self._parent_module_globals[self._local_name] = module + + # Update this object's dict so that if someone keeps a reference to the + # LazyLoader, lookups are efficient (__getattr__ is only called on lookups + # that fail). 
diff --git a/plume/utils/serve.py b/plume/utils/serve.py
new file mode 100644
index 0000000..103d68a
--- /dev/null
+++ b/plume/utils/serve.py
@@ -0,0 +1,31 @@
+from plume.utils import lazy_module
+import typer
+
+rpyc = lazy_module('rpyc')
+
+app = typer.Typer()
+
+
+class ASRService(rpyc.Service):
+    def __init__(self, asr_recognizer):
+        self.asr = asr_recognizer
+
+    def on_connect(self, conn):
+        # code that runs when a connection is created
+        # (to init the service, if needed)
+        pass
+
+    def on_disconnect(self, conn):
+        # code that runs after the connection has already closed
+        # (to finalize the service, if needed)
+        pass
+
+    def exposed_transcribe(self, utterance: bytes):  # this is an exposed method
+        transcript = self.asr.transcribe(utterance)
+        return transcript
+
+    def exposed_transcribe_cb(
+        self, utterance: bytes, respond
+    ):  # this is an exposed method
+        transcript = self.asr.transcribe(utterance)
+        respond(transcript)
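+
+
+# Note: subclassing `rpyc.Service` above forces the lazy `rpyc` module to load
+# as soon as this file is imported, so the lazy_module call is effectively
+# eager here.
+# Example (a sketch; `MyASR` stands for any recognizer object with a
+# `transcribe(bytes) -> str` method):
+#     from rpyc.utils.server import ThreadedServer
+#     ThreadedServer(ASRService(MyASR()), port=8044).start()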
pick one of {meths}") + + client = grpcclient.InferenceServerClient(f"{asr_host}:{asr_port}") + + def transcriber(aud_seg): + af = BytesIO() + aud_seg.export(af, format="wav") + input_audio_bytes = af.getvalue() + input_audio_data = np.array([input_audio_bytes]) + inputs = [ + grpcclient.InferInput( + "INPUT_AUDIO", + input_audio_data.shape, + np_to_triton_dtype(input_audio_data.dtype), + ) + ] + inputs[0].set_data_from_numpy(input_audio_data) + outputs = [grpcclient.InferRequestedOutput("OUTPUT_TEXT")] + response = client.infer(asr_model, inputs, request_id=str(1), outputs=outputs) + transcript = response.as_numpy("OUTPUT_TEXT")[0] + return transcript.decode("utf-8") + + def chunked_transcriber(aud_seg): + if method == "silence": + sil_chunks = split_on_silence( + aud_seg, + min_silence_len=sil_msec, + silence_thresh=-50, + keep_silence=500, + ) + chunks = [sc for c in sil_chunks for sc in c[::chunk_msec]] + else: + chunks = aud_seg[::chunk_msec] + # if overlap: + # chunks = [ + # aud_seg[start, end] + # for start, end in range(0, int(aud_seg.duration_seconds * 1000, 1000)) + # ] + # pass + transcript_list = [] + sil_pad = AudioSegment.silent(duration=sil_msec) + for seg in chunks: + t_seg = sil_pad + seg + sil_pad + c_transcript = transcriber(t_seg) + transcript_list.append(c_transcript) + transcript = sep.join(transcript_list) + return transcript + + def audio_prep(aud_seg): + asr_seg = aud_seg.set_channels(1).set_sample_width(2).set_frame_rate(16000) + return asr_seg + + whole_transcriber = transcriber if method == "whole" else chunked_transcriber + return whole_transcriber, audio_prep + + +@app.command() +def file(audio_file: Path, write_file: bool = False, chunked=True): + from pydub import AudioSegment + + aseg = AudioSegment.from_file(audio_file) + transcriber, prep = triton_transcribe_grpc_gen() + transcription = transcriber(prep(aseg)) + + typer.echo(transcription) + if write_file: + tscript_file_path = audio_file.with_suffix(".txt") + with open(tscript_file_path, "w") as tf: + tf.write(transcription) + + +@app.command() +def benchmark(audio_file: Path): + from pydub import AudioSegment + + transcriber, audio_prep = transcribe_rpyc_gen() + file_seg = AudioSegment.from_file(audio_file) + aud_seg = audio_prep(file_seg) + + def timeinfo(): + from timeit import Timer + + timer = Timer(lambda: transcriber(aud_seg)) + number = 100 + repeat = 10 + time_taken = timer.repeat(repeat, number=number) + best = min(time_taken) * 1000 / number + print(f"{number} loops, best of {repeat}: {best:.3f} msec per loop") + + timeinfo() + import time + + time.sleep(5) + + transcriber, audio_prep = triton_transcribe_grpc_gen() + aud_seg = audio_prep(file_seg) + + def timeinfo(): + from timeit import Timer + + timer = Timer(lambda: transcriber(aud_seg)) + number = 100 + repeat = 10 + time_taken = timer.repeat(repeat, number=number) + best = min(time_taken) * 1000 / number + print(f"{number} loops, best of {repeat}: {best:.3f} msec per loop") + + timeinfo() + + +def main(): + app() + + +if __name__ == "__main__": + main() diff --git a/plume/utils/tts.py b/plume/utils/tts.py new file mode 100644 index 0000000..c99fa97 --- /dev/null +++ b/plume/utils/tts.py @@ -0,0 +1,92 @@ +from logging import getLogger +from plume.utils import lazy_module + + +from pathlib import Path + +import typer + +# from google.cloud import texttospeech +texttospeech = lazy_module('google.cloud.texttospeech') + +LOGGER = getLogger("googletts") + +app = typer.Typer() + + +class GoogleTTS(object): + def __init__(self): + self.client = 
diff --git a/plume/utils/tts.py b/plume/utils/tts.py
new file mode 100644
index 0000000..c99fa97
--- /dev/null
+++ b/plume/utils/tts.py
@@ -0,0 +1,92 @@
+from logging import getLogger
+from pathlib import Path
+
+import typer
+
+from plume.utils import lazy_module
+
+texttospeech = lazy_module('google.cloud.texttospeech')
+
+LOGGER = getLogger("googletts")
+
+app = typer.Typer()
+
+
+class GoogleTTS(object):
+    def __init__(self):
+        self.client = texttospeech.TextToSpeechClient()
+
+    def _synthesize(self, tts_input, params: dict) -> bytes:
+        # shared by text_to_speech/ssml_to_speech below
+        voice = texttospeech.types.VoiceSelectionParams(
+            language_code=params["language"], name=params["name"]
+        )
+        audio_config = texttospeech.types.AudioConfig(
+            audio_encoding=texttospeech.enums.AudioEncoding.LINEAR16,
+            sample_rate_hertz=params["sample_rate"],
+        )
+        response = self.client.synthesize_speech(tts_input, voice, audio_config)
+        return response.audio_content
+
+    def text_to_speech(self, text: str, params: dict) -> bytes:
+        return self._synthesize(texttospeech.types.SynthesisInput(text=text), params)
+
+    def ssml_to_speech(self, ssml: str, params: dict) -> bytes:
+        return self._synthesize(texttospeech.types.SynthesisInput(ssml=ssml), params)
+
+    @classmethod
+    def voice_list(cls):
+        """Lists the available English voices."""
+
+        client = cls().client
+
+        # Performs the list voices request
+        voices = client.list_voices()
+        results = []
+        for voice in voices.voices:
+            supported_eng_langs = [
+                lang for lang in voice.language_codes if lang[:2] == "en"
+            ]
+            if len(supported_eng_langs) > 0:
+                lang = ",".join(supported_eng_langs)
+            else:
+                continue
+
+            ssml_gender = texttospeech.enums.SsmlVoiceGender(voice.ssml_gender)
+            results.append(
+                {
+                    "name": voice.name,
+                    "language": lang,
+                    "gender": ssml_gender.name,
+                    "engine": "wavenet" if "Wav" in voice.name else "standard",
+                    "sample_rate": voice.natural_sample_rate_hertz,
+                }
+            )
+        return results
+
+
+@app.command()
+def generate_audio_file(
+    text: str,
+    dest_path: Path = Path("./tts_audio.wav"),
+    voice: str = "en-US-Wavenet-D",
+):
+    tts = GoogleTTS()
+    matching = [v for v in tts.voice_list() if v["name"] == voice]
+    if not matching:
+        raise typer.BadParameter(f"unknown voice: {voice}")
+    wav_data = tts.text_to_speech(text, matching[0])
+    with dest_path.open("wb") as wf:
+        wf.write(wav_data)
+
+
+def main():
+    app()
+
+
+if __name__ == "__main__":
+    main()
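+
+
+# CLI usage (a sketch; the client reads credentials from the standard
+# GOOGLE_APPLICATION_CREDENTIALS env-var):
+#     python -m plume.utils.tts "hello world" --dest-path hello.wav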
"pandas~=1.0.3", "boto3~=1.12.35", - "ruamel.yaml==0.16.10", - "pymongo==3.10.1", - "librosa==0.7.2", - "numba==0.48", - "matplotlib==3.2.1", - "pandas==1.0.3", - "tabulate==0.8.7", - "natural==0.2.0", - "num2words==0.5.10", - "typer[all]==0.3.1", - "python-slugify==4.0.0", + "ruamel.yaml~=0.16.10", + "pymongo~=3.10.1", + "librosa~=0.7.2", + "matplotlib~=3.2.1", + "pandas~=1.0.3", + "tabulate~=0.8.7", + "natural~=0.2.0", + "num2words~=0.5.10", + "python-slugify~=4.0.0", "rpyc~=4.1.4", - "lenses @ git+https://github.com/ingolemo/python-lenses.git@b2a2a9aa5b61540992d70b2cf36008d0121e8948#egg=lenses", + # "lenses @ git+https://github.com/ingolemo/python-lenses.git@b2a2a9aa5b61540992d70b2cf36008d0121e8948#egg=lenses", ], "validation": [ - "rpyc~=4.1.4", - "pymongo==3.10.1", - "typer[all]==0.1.1", - "tqdm~=4.39.0", - "librosa==0.7.2", - "matplotlib==3.2.1", + "pymongo~=3.10.1", + "matplotlib~=3.2.1", "pydub~=0.24.0", - "streamlit==0.58.0", - "natural==0.2.0", - "stringcase==1.2.0", + "streamlit~=0.58.0", + "natural~=0.2.0", + "stringcase~=1.2.0", "google-cloud-speech~=1.3.1", - ] - # "train": [ - # "torchaudio==0.5.0", - # "torch-stft==0.1.4", - # ] + ], + "train": ["torchaudio~=0.6.0", "torch-stft~=0.1.4"], } + +extra_requirements["all"] = list({d for l in extra_requirements.values() for d in l}) packages = find_packages() setup( - name="jasper-asr", - version="0.1", - description="Tool to get gcp alignments of tts-data", - url="http://github.com/malarinv/jasper-asr", + name="plume-asr", + version="0.11", + description="Multi model ASR base package", + url="http://github.com/malarinv/plume-asr", author="Malar Kannan", author_email="malarkannan.invention@gmail.com", license="MIT", install_requires=requirements, extras_require=extra_requirements, packages=packages, - entry_points={ - "console_scripts": [ - "jasper_transcribe = jasper.transcribe:main", - "jasper_server = jasper.server:main", - "jasper_trainer = jasper.training.cli:main", - "jasper_evaluator = jasper.evaluate:main", - "jasper_data_tts_generate = jasper.data.tts_generator:main", - "jasper_data_conv_generate = jasper.data.conv_generator:main", - "jasper_data_nlu_generate = jasper.data.nlu_generator:main", - "jasper_data_rastrik_recycle = jasper.data.rastrik_recycler:main", - "jasper_data_server = jasper.data.server:main", - "jasper_data_validation = jasper.data.validation.process:main", - "jasper_data_preprocess = jasper.data.process:main", - "jasper_data_slu_evaluate = jasper.data.slu_evaluator:main", - ] - }, + entry_points={"console_scripts": ["plume = plume.cli:main"]}, zip_safe=False, ) diff --git a/validation_ui.py b/validation_ui.py deleted file mode 100644 index b45692e..0000000 --- a/validation_ui.py +++ /dev/null @@ -1,3 +0,0 @@ -import runpy - -runpy.run_module("jasper.data.validation.ui", run_name="__main__", alter_sys=True)