diff --git a/MANIFEST.in b/MANIFEST.in index f630a67..11a5eb0 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1 +1 @@ -graft plume/utils/gentle_preview +graft src diff --git a/README.md b/README.md index 691e69f..19a1d6f 100644 --- a/README.md +++ b/README.md @@ -38,7 +38,7 @@ The installation should work on Python 3.6 or newer. Untested on Python 2.7 ### Library > Jasper ```python -from plume.models.jasper.asr import JasperASR +from plume.models.jasper_nemo.asr import JasperASR asr_model = JasperASR("/path/to/model_config_yaml","/path/to/encoder_checkpoint","/path/to/decoder_checkpoint") # Loads the models TEXT = asr_model.transcribe(wav_data) # Returns the text spoken in the wav ``` diff --git a/plume/utils/audio.py b/plume/utils/audio.py deleted file mode 100644 index 35b9f8f..0000000 --- a/plume/utils/audio.py +++ /dev/null @@ -1,28 +0,0 @@ -from scipy.signal import lfilter, butter -from scipy.io.wavfile import read, write -from numpy import array, int16 -import sys - - -def butter_params(low_freq, high_freq, fs, order=5): - nyq = 0.5 * fs - low = low_freq / nyq - high = high_freq / nyq - b, a = butter(order, [low, high], btype="band") - return b, a - - -def butter_bandpass_filter(data, low_freq, high_freq, fs, order=5): - b, a = butter_params(low_freq, high_freq, fs, order=order) - y = lfilter(b, a, data) - return y - - -if __name__ == "__main__": - fs, audio = read(sys.argv[1]) - import pdb; pdb.set_trace() - low_freq = 300.0 - high_freq = 4000.0 - filtered_signal = butter_bandpass_filter(audio, low_freq, high_freq, fs, order=6) - fname = sys.argv[1].split(".wav")[0] + "_moded.wav" - write(fname, fs, array(filtered_signal, dtype=int16)) diff --git a/plume/utils/vad.py b/plume/utils/vad.py deleted file mode 100644 index 5832914..0000000 --- a/plume/utils/vad.py +++ /dev/null @@ -1,205 +0,0 @@ -import logging -import asyncio -import argparse -from pathlib import Path - -import webrtcvad -import pydub -from pydub.playback import play -from pydub.utils import make_chunks - - -DEFAULT_CHUNK_DUR = 20 - -logging.basicConfig( - level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" -) -logger = logging.getLogger(__name__) - - -def is_frame_voice(vad, seg, chunk_dur): - return ( - True - if ( - seg.duration_seconds == chunk_dur / 1000 - and vad.is_speech(seg.raw_data, seg.frame_rate) - ) - else False - ) - - -class VADFilterAudio(object): - """docstring for VADFilterAudio.""" - - def __init__(self, chunk_dur=DEFAULT_CHUNK_DUR): - super(VADFilterAudio, self).__init__() - self.chunk_dur = chunk_dur - self.vad = webrtcvad.Vad() - - def filter_segment(self, wav_seg): - chunks = make_chunks(wav_seg, self.chunk_dur) - speech_buffer = b"" - - for i, c in enumerate(chunks[:-1]): - voice_frame = is_frame_voice(self.vad, c, self.chunk_dur) - if voice_frame: - speech_buffer += c.raw_data - filtered_seg = pydub.AudioSegment( - data=speech_buffer, - frame_rate=wav_seg.frame_rate, - channels=wav_seg.channels, - sample_width=wav_seg.sample_width, - ) - return filtered_seg - - -class VADUtterance(object): - """docstring for VADUtterance.""" - - def __init__( - self, - max_silence=500, - min_utterance=280, - max_utterance=20000, - chunk_dur=DEFAULT_CHUNK_DUR, - start_cycles=3, - ): - super(VADUtterance, self).__init__() - self.vad = webrtcvad.Vad() - self.chunk_dur = chunk_dur - # duration in millisecs - self.max_sil = max_silence - self.min_utt = min_utterance - self.max_utt = max_utterance - self.speech_start = start_cycles * chunk_dur - - def __repr__(self): - return 
f"VAD(max_silence={self.max_sil},min_utterance:{self.min_utt},max_utterance:{self.max_utt})" - - async def stream_utterance(self, audio_stream): - silence_buffer = pydub.AudioSegment.empty() - voice_buffer = pydub.AudioSegment.empty() - silence_threshold = False - async for c in audio_stream: - voice_frame = is_frame_voice(self.vad, c, self.chunk_dur) - logger.debug(f"is audio stream voice? {voice_frame}") - if voice_frame: - silence_threshold = False - voice_buffer += c - silence_buffer = pydub.AudioSegment.empty() - else: - silence_buffer += c - voc_dur = voice_buffer.duration_seconds * 1000 - sil_dur = silence_buffer.duration_seconds * 1000 - - if voc_dur >= self.max_utt: - logger.info( - f"detected voice overflow: voice duration {voice_buffer.duration_seconds}" - ) - yield voice_buffer - voice_buffer = pydub.AudioSegment.empty() - - if sil_dur >= self.max_sil: - if voc_dur >= self.min_utt: - logger.info( - f"detected silence: voice duration {voice_buffer.duration_seconds}" - ) - yield voice_buffer - voice_buffer = pydub.AudioSegment.empty() - # ignore/clear voice if silence reached threshold or indent the statement - if not silence_threshold: - silence_threshold = True - - if voice_buffer: - yield voice_buffer - - async def stream_events(self, audio_stream): - """ - yields 0, voice_buffer for SpeechBuffer - yields 1, None for StartedSpeaking - yields 2, None for StoppedSpeaking - yields 4, audio_stream - """ - silence_buffer = pydub.AudioSegment.empty() - voice_buffer = pydub.AudioSegment.empty() - silence_threshold, started_speaking = False, False - async for c in audio_stream: - # yield (4, c) - voice_frame = is_frame_voice(self.vad, c, self.chunk_dur) - logger.debug(f"is audio stream voice? {voice_frame}") - if voice_frame: - silence_threshold = False - voice_buffer += c - silence_buffer = pydub.AudioSegment.empty() - else: - silence_buffer += c - voc_dur = voice_buffer.duration_seconds * 1000 - sil_dur = silence_buffer.duration_seconds * 1000 - - if voc_dur >= self.speech_start and not started_speaking: - started_speaking = True - yield (1, None) - - if voc_dur >= self.max_utt: - logger.info( - f"detected voice overflow: voice duration {voice_buffer.duration_seconds}" - ) - yield (0, voice_buffer) - voice_buffer = pydub.AudioSegment.empty() - started_speaking = False - - if sil_dur >= self.max_sil: - if voc_dur >= self.min_utt: - logger.info( - f"detected silence: voice duration {voice_buffer.duration_seconds}" - ) - yield (0, voice_buffer) - voice_buffer = pydub.AudioSegment.empty() - started_speaking = False - # ignore/clear voice if silence reached threshold or indent the statement - if not silence_threshold: - silence_threshold = True - yield (2, None) - - if voice_buffer: - yield (0, voice_buffer) - - @classmethod - async def stream_utterance_file(cls, audio_file): - async def stream_gen(): - audio_seg = pydub.AudioSegment.from_file(audio_file).set_frame_rate(32000) - chunks = make_chunks(audio_seg, DEFAULT_CHUNK_DUR) - for c in chunks: - yield c - - va_ut = cls() - buffer_src = va_ut.stream_utterance(stream_gen()) - async for buf in buffer_src: - play(buf) - await asyncio.sleep(1) - - -class VADStreamGen(object): - """docstring for VADStreamGen.""" - - def __init__(self, arg): - super(VADStreamGen, self).__init__() - self.arg = arg - - -def main(): - prog = Path(__file__).stem - parser = argparse.ArgumentParser(prog=prog, description="transcribes audio file") - parser.add_argument( - "--audio_file", - type=argparse.FileType("rb"), - help="audio file to transcribe", - 
default="./test_utter2.wav", - ) - args = parser.parse_args() - loop = asyncio.get_event_loop() - loop.run_until_complete(VADUtterance.stream_utterance_file(args.audio_file)) - - -if __name__ == "__main__": - main() diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..a8f43fe --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,2 @@ +[tool.black] +line-length = 79 diff --git a/setup.py b/setup.py index 8ea5a79..92925e8 100644 --- a/setup.py +++ b/setup.py @@ -3,12 +3,10 @@ from setuptools import setup, find_namespace_packages # pip install "nvidia-pyindex~=1.0.5" requirements = [ - "torch~=1.6.0", - "torchvision~=0.7.0", - "nemo_toolkit[asr] @ git+https://github.com/NVIDIA/NeMo.git@09e3ba4dfe333f86d6c5c1048e07210924294be9#egg=nemo_toolkit", - "fairseq @ git+https://github.com/pytorch/fairseq.git@94a1b924f3adec25c8c508ac112410d02b400d1e#egg=fairseq", + # "nemo_toolkit[asr] @ git+https://github.com/NVIDIA/NeMo.git@09e3ba4dfe333f86d6c5c1048e07210924294be9#egg=nemo_toolkit", + # "fairseq @ git+https://github.com/pytorch/fairseq.git@94a1b924f3adec25c8c508ac112410d02b400d1e#egg=fairseq", # "google-cloud-texttospeech~=1.0.1", - "tqdm~=4.54.0", + "tqdm~=4.49.0", # "pydub~=0.24.0", # "scikit_learn~=0.22.1", # "pandas~=1.0.3", @@ -47,8 +45,30 @@ extra_requirements = { "num2words~=0.5.10", "python-slugify~=4.0.0", "rpyc~=4.1.4", + "webrtcvad~=2.0.10", + # "datasets" # "lenses @ git+https://github.com/ingolemo/python-lenses.git@b2a2a9aa5b61540992d70b2cf36008d0121e8948#egg=lenses", ], + "models": [ + # "nemo_toolkit[asr] @ git+https://github.com/NVIDIA/NeMo.git@09e3ba4dfe333f86d6c5c1048e07210924294be9#egg=nemo_toolkit", + "nemo_toolkit[asr] @ git+https://github.com/NVIDIA/NeMo.git@v1.0.0#egg=nemo_toolkit", + "fairseq @ git+https://github.com/pytorch/fairseq.git@94a1b924f3adec25c8c508ac112410d02b400d1e#egg=fairseq", + "transformers~=4.5.0", + "torch~=1.7.0", + "torchvision~=0.8.2", + "torchaudio~=0.7.2", + ], + "eval": [ + "jiwer~=2.2.0", + "pydub~=0.24.0", + "tritonclient[grpc]~=2.9.0", + "pyspellchecker~=0.6.2", + "num2words~=0.5.10", + ], + "infer": [ + "pyspellchecker~=0.6.2", + "num2words~=0.5.10", + ], "validation": [ "pymongo~=3.10.1", "matplotlib~=3.2.1", @@ -61,15 +81,20 @@ extra_requirements = { "ui": [ "rangehttpserver~=1.2.0", ], + "crypto": ["cryptography~=3.4.7"], "train": ["torchaudio~=0.6.0", "torch-stft~=0.1.4"], } - -extra_requirements["all"] = list({d for l in extra_requirements.values() for d in l}) -packages = find_namespace_packages() +extra_requirements["deploy"] = ( + extra_requirements["models"] + extra_requirements["infer"] +) +extra_requirements["all"] = list( + {d for r in extra_requirements.values() for d in r} +) +packages = find_namespace_packages("src") setup( name="plume-asr", - version="0.2.0", + version="0.2.1", description="Multi model ASR base package", url="http://github.com/malarinv/plume-asr", author="Malar Kannan", @@ -78,6 +103,7 @@ setup( install_requires=requirements, extras_require=extra_requirements, packages=packages, + package_dir={"": "src"}, entry_points={"console_scripts": ["plume = plume.cli:main"]}, zip_safe=False, ) diff --git a/plume/cli/__init__.py b/src/plume/cli/__init__.py similarity index 100% rename from plume/cli/__init__.py rename to src/plume/cli/__init__.py diff --git a/src/plume/cli/__main__.py b/src/plume/cli/__main__.py new file mode 100644 index 0000000..29c2e0a --- /dev/null +++ b/src/plume/cli/__main__.py @@ -0,0 +1,5 @@ +from . 
import main + + +if __name__ == "__main__": + main() diff --git a/plume/cli/data/__init__.py b/src/plume/cli/data/__init__.py similarity index 60% rename from plume/cli/data/__init__.py rename to src/plume/cli/data/__init__.py index e7bcb7b..bcf1042 100644 --- a/plume/cli/data/__init__.py +++ b/src/plume/cli/data/__init__.py @@ -1,6 +1,14 @@ import json from pathlib import Path +from random import shuffle +from typing import List +from itertools import chain + # from sklearn.model_selection import train_test_split +from tqdm import tqdm +import shutil +import typer + from plume.utils import ( asr_manifest_reader, asr_manifest_writer, @@ -9,18 +17,19 @@ from plume.utils import ( generate_filter_map, get_mongo_conn, tscript_uuid_fname, - lazy_callable + lazy_callable, + lazy_module, + wav_cryptor, + text_cryptor, + parallel_apply, ) -from typing import List -from itertools import chain -import shutil -import typer -import soundfile from ...models.wav2vec2.data import app as wav2vec2_app from .generate import app as generate_app -train_test_split = lazy_callable('sklearn.model_selection.train_test_split') +soundfile = lazy_module("soundfile") +pydub = lazy_module("pydub") +train_test_split = lazy_callable("sklearn.model_selection.train_test_split") app = typer.Typer() app.add_typer(generate_app, name="generate") @@ -62,7 +71,7 @@ def fix_path(dataset_path: Path, force: bool = False): @app.command() -def augment(src_dataset_paths: List[Path], dest_dataset_path: Path): +def merge(src_dataset_paths: List[Path], dest_dataset_path: Path): reader_list = [] abs_manifest_path = Path("abs_manifest.json") for dataset_path in src_dataset_paths: @@ -74,14 +83,89 @@ def augment(src_dataset_paths: List[Path], dest_dataset_path: Path): @app.command() -def split(dataset_path: Path, test_size: float = 0.03): +def training_split(dataset_path: Path, test_size: float = 0.03): manifest_path = dataset_path / Path("abs_manifest.json") if not manifest_path.exists(): fix_path(dataset_path) asr_data = list(asr_manifest_reader(manifest_path)) train_pnr, test_pnr = train_test_split(asr_data, test_size=test_size) - asr_manifest_writer(manifest_path.with_name("train_manifest.json"), train_pnr) - asr_manifest_writer(manifest_path.with_name("test_manifest.json"), test_pnr) + asr_manifest_writer( + manifest_path.with_name("train_manifest.json"), train_pnr + ) + asr_manifest_writer( + manifest_path.with_name("test_manifest.json"), test_pnr + ) + + +@app.command() +def parts_split_by_size( + dataset_path: Path, + test_size: float = 0.03, + split_prefix_names: List[str] = ["train", "test"], +): + manifest_path = dataset_path / Path("abs_manifest.json") + if not manifest_path.exists(): + fix_path(dataset_path) + asr_data = list(asr_manifest_reader(manifest_path)) + train_pnr, test_pnr = train_test_split(asr_data, test_size=test_size) + dest_paths = [ + (dataset_path.parent / (dataset_path.name + "_" + spn), sd) + for (spn, sd) in zip(split_prefix_names, [train_pnr, test_pnr]) + ] + for dest_path, manifest_data in dest_paths: + wav_dir = dest_path / Path("wavs") + wav_dir.mkdir(exist_ok=True, parents=True) + abs_manifest_path = ExtendedPath(dest_path / Path("abs_manifest.json")) + for md in manifest_data: + orig_path = Path(md["audio_filepath"]) + new_path = wav_dir / Path(orig_path.name) + shutil.copy(orig_path, new_path) + md["audio_filepath"] = str(new_path) + md.pop("audio_path") + abs_manifest_path.write_jsonl(manifest_data) + fix_path(dest_path) + + +@app.command() +def parts_split_by_dur( + dataset_path: Path, + dur_sec: 
int = 7200, + suffix_name: List[str] = ["train", "test"], +): + manifest_path = dataset_path / Path("abs_manifest.json") + if not manifest_path.exists(): + fix_path(dataset_path) + asr_data = list(asr_manifest_reader(manifest_path)) + + def dur_split(dataset, dur_seconds): + shuffle(dataset) + counter_dur = 0 + train_set, test_set = [], [] + for d in dataset: + if counter_dur <= dur_seconds: + test_set.append(d) + else: + train_set.append(d) + counter_dur += d["duration"] + return train_set, test_set + + train_pnr, test_pnr = dur_split(asr_data, dur_sec) + dest_paths = [ + (dataset_path.parent / (dataset_path.name + "_" + spn), sd) + for (spn, sd) in zip(suffix_name, [train_pnr, test_pnr]) + ] + for dest_path, manifest_data in dest_paths: + wav_dir = dest_path / Path("wavs") + wav_dir.mkdir(exist_ok=True, parents=True) + abs_manifest_path = ExtendedPath(dest_path / Path("abs_manifest.json")) + for md in manifest_data: + orig_path = Path(md["audio_filepath"]) + new_path = wav_dir / Path(orig_path.name) + shutil.copy(orig_path, new_path) + md["audio_filepath"] = str(new_path.absolute()) + md.pop("audio_path") + abs_manifest_path.write_jsonl(manifest_data) + fix_path(dest_path.absolute()) @app.command() @@ -111,7 +195,11 @@ def validate(dataset_path: Path): @app.command() -def filter(src_dataset_path: Path, dest_dataset_path: Path, kind: str = "skip_dur"): +def filter( + src_dataset_path: Path, + dest_dataset_path: Path, + kind: str = "", +): dest_manifest = dest_dataset_path / Path("manifest.json") data_file = src_dataset_path / Path("manifest.json") dest_wav_dir = dest_dataset_path / Path("wavs") @@ -149,13 +237,21 @@ def info(dataset_path: Path): real_duration += soundfile.info(wav_path).duration # frame_count = soundfile.info(audio_fname).frames - print(f"max audio duration : {duration_str(max_duration)}") - print(f"total audio duration : {duration_str(mf_wav_duration)}") - print(f"total real audio duration : {duration_str(real_duration)}") print( - f"total content duration : {duration_str(mf_wav_duration-empty_duration)}" + f"max audio duration : {duration_str(max_duration, show_hours=True)}" + ) + print( + f"total audio duration : {duration_str(mf_wav_duration, show_hours=True)}" + ) + print( + f"total real audio duration : {duration_str(real_duration, show_hours=True)}" + ) + print( + f"total content duration : {duration_str(mf_wav_duration-empty_duration, show_hours=True)}" + ) + print( + f"total empty duration : {duration_str(empty_duration, show_hours=True)}" ) - print(f"total empty duration : {duration_str(empty_duration)}") print( f"total empty samples : {empty_count}/{total_count} ({empty_count*100/total_count:.2f}%)" ) @@ -167,7 +263,81 @@ def audio_duration(dataset_path: Path): for audio_rel_fname in dataset_path.absolute().glob("**/*.wav"): audio_fname = str(audio_rel_fname) wav_duration += soundfile.info(audio_fname).duration - typer.echo(f"duration of wav files @ {dataset_path}: {duration_str(wav_duration)}") + typer.echo( + f"duration of wav files @ {dataset_path}: {duration_str(wav_duration)}" + ) + + +@app.command() +def encrypt( + src_dataset_path: Path, + dest_dataset_path: Path, + encryption_key: str = typer.Option(..., prompt=True, hide_input=True), + verbose: bool = False, +): + dest_manifest = dest_dataset_path / Path("manifest.json") + src_manifest = src_dataset_path / Path("manifest.json") + dest_wav_dir = dest_dataset_path / Path("wavs") + dest_wav_dir.mkdir(exist_ok=True, parents=True) + wav_crypt = wav_cryptor(encryption_key) + text_crypt = 
text_cryptor(encryption_key)
+    # warmup
+    _ = pydub.AudioSegment.from_file
+
+    def encrypt_item(s):
+        crypt_text = text_crypt.encrypt_text(s["text"])
+        src_wav_path = src_dataset_path / s["audio_filepath"]
+        dst_wav_path = dest_dataset_path / s["audio_filepath"]
+        wav_crypt.encrypt_wav_path_to(src_wav_path, dst_wav_path)
+        s["text"] = crypt_text.decode("utf-8")
+        return s
+
+    def encrypted_gen():
+        data = list(ExtendedPath(src_manifest).read_jsonl())
+        iter_data = tqdm(data) if verbose else data
+        encrypted_iter_data = parallel_apply(
+            encrypt_item, iter_data, verbose=verbose, workers=64
+        )
+        for s in encrypted_iter_data:
+            yield s
+
+    asr_manifest_writer(dest_manifest, encrypted_gen(), verbose=verbose)
+
+
+@app.command()
+def decrypt(
+    src_dataset_path: Path,
+    dest_dataset_path: Path,
+    encryption_key: str = typer.Option(..., prompt=True, hide_input=True),
+    verbose: bool = True,
+):
+    dest_manifest = dest_dataset_path / Path("manifest.json")
+    src_manifest = src_dataset_path / Path("manifest.json")
+    dest_wav_dir = dest_dataset_path / Path("wavs")
+    dest_wav_dir.mkdir(exist_ok=True, parents=True)
+    wav_crypt = wav_cryptor(encryption_key)
+    text_crypt = text_cryptor(encryption_key)
+    # warmup
+    _ = pydub.AudioSegment.from_file
+
+    def decrypt_item(s):
+        crypt_text = text_crypt.decrypt_text(s["text"].encode("utf-8"))
+        src_wav_path = src_dataset_path / s["audio_filepath"]
+        dst_wav_path = dest_dataset_path / s["audio_filepath"]
+        wav_crypt.decrypt_wav_path_to(src_wav_path, dst_wav_path)
+        s["text"] = crypt_text
+        return s
+
+    def decrypted_gen():
+        data = list(ExtendedPath(src_manifest).read_jsonl())
+        iter_data = tqdm(data) if verbose else data
+        decrypted_iter_data = parallel_apply(
+            decrypt_item, iter_data, verbose=verbose, workers=64
+        )
+        for s in decrypted_iter_data:
+            yield s
+
+    asr_manifest_writer(dest_manifest, decrypted_gen(), verbose=verbose)
@@ -204,13 +374,19 @@ def task_split(
     processed_data_path = data_dir / dump_file
     processed_data = ExtendedPath(processed_data_path).read_json()
 
-    df = pd.DataFrame(processed_data["data"]).sample(frac=1).reset_index(drop=True)
+    df = (
+        pd.DataFrame(processed_data["data"])
+        .sample(frac=1)
+        .reset_index(drop=True)
+    )
     for t_idx, task_f in enumerate(np.array_split(df, task_count)):
         task_f = task_f.reset_index(drop=True)
         task_f["real_idx"] = task_f.index
         task_data = task_f.to_dict("records")
         if sort:
-            task_data = sorted(task_data, key=lambda x: x["asr_wer"], reverse=True)
+            task_data = sorted(
+                task_data, key=lambda x: x["asr_wer"], reverse=True
+            )
         processed_data["data"] = task_data
         task_path = data_dir / Path(task_file + f"-{t_idx}.json")
         ExtendedPath(task_path).write_json(processed_data)
@@ -223,7 +399,9 @@ def get_corrections(task_uid):
         for c in col.distinct("task_id")
         if c.rsplit("-", 1)[1] == task_uid or c == task_uid
     ][0]
-    corrections = list(col.find({"type": "correction"}, projection={"_id": False}))
+    corrections = list(
+        col.find({"type": "correction"}, projection={"_id": False})
+    )
     cursor_obj = col.find(
         {"type": "correction", "task_id": task_id}, projection={"_id": False}
     )
@@ -241,8 +419,8 @@ def dump_task_corrections(data_dir: Path, task_uid: str):
 
 @app.command()
 def dump_all_corrections(data_dir: Path):
-    for task_lcks in data_dir.glob('task-*.lck'):
-        task_uid = task_lcks.stem.replace('task-', '')
+    for task_lcks in data_dir.glob("task-*.lck"):
+        task_uid = task_lcks.stem.replace("task-", "")
         dump_task_corrections(data_dir, task_uid)
@@ -292,7 +470,9 @@ def update_corrections(
             correct_text = 
correction_map[d["utterance_id"]] if skip_incorrect: ap = d["audio_path"] - print(f"skipping incorrect {ap} corrected to {correct_text}") + print( + f"skipping incorrect {ap} corrected to {correct_text}" + ) orig_audio_path.unlink() else: new_fname = tscript_uuid_fname(correct_text) @@ -304,7 +484,9 @@ def update_corrections( new_name = str(Path(new_fname).with_suffix(".wav")) new_audio_path = orig_audio_path.with_name(new_name) orig_audio_path.replace(new_audio_path) - new_filepath = str(Path(d["audio_path"]).with_name(new_name)) + new_filepath = str( + Path(d["audio_path"]).with_name(new_name) + ) d["corrected_from"] = d["text"] d["text"] = correct_text d["audio_path"] = new_filepath @@ -325,7 +507,9 @@ def update_corrections( shutil.copytree(str(dataset_dir), str(backup_dir)) renames = {} corrected_ui_dump = list(correct_ui_dump(data_dir, renames)) - ExtendedPath(data_dir / ui_dump_file).write_json({"data": corrected_ui_dump}) + ExtendedPath(data_dir / ui_dump_file).write_json( + {"data": corrected_ui_dump} + ) corrected_manifest = ( { "audio_filepath": d["audio_path"], diff --git a/plume/cli/data/generate.py b/src/plume/cli/data/generate.py similarity index 100% rename from plume/cli/data/generate.py rename to src/plume/cli/data/generate.py diff --git a/plume/cli/eval.py b/src/plume/cli/eval.py similarity index 57% rename from plume/cli/eval.py rename to src/plume/cli/eval.py index 53a2aef..17d1017 100644 --- a/plume/cli/eval.py +++ b/src/plume/cli/eval.py @@ -1,8 +1,10 @@ import typer from ..models.wav2vec2.eval import app as wav2vec2_app +from ..models.wav2vec2_transformers.eval import app as wav2vec2_transformers_app app = typer.Typer() app.add_typer(wav2vec2_app, name="wav2vec2") +app.add_typer(wav2vec2_transformers_app, name="wav2vec2_transformers") @app.callback() diff --git a/plume/cli/serve.py b/src/plume/cli/serve.py similarity index 53% rename from plume/cli/serve.py rename to src/plume/cli/serve.py index 7b7e29d..6e974aa 100644 --- a/plume/cli/serve.py +++ b/src/plume/cli/serve.py @@ -1,9 +1,11 @@ import typer from ..models.wav2vec2.serve import app as wav2vec2_app -from ..models.jasper.serve import app as jasper_app +from ..models.wav2vec2_transformers.serve import app as wav2vec2_transformers_app +from ..models.jasper_nemo.serve import app as jasper_app app = typer.Typer() app.add_typer(wav2vec2_app, name="wav2vec2") +app.add_typer(wav2vec2_transformers_app, name="wav2vec2_transformers") app.add_typer(jasper_app, name="jasper") diff --git a/plume/cli/train.py b/src/plume/cli/train.py similarity index 100% rename from plume/cli/train.py rename to src/plume/cli/train.py diff --git a/plume/models/__init__.py b/src/plume/models/__init__.py similarity index 100% rename from plume/models/__init__.py rename to src/plume/models/__init__.py diff --git a/plume/models/jasper/__init__.py b/src/plume/models/jasper_nemo/__init__.py similarity index 100% rename from plume/models/jasper/__init__.py rename to src/plume/models/jasper_nemo/__init__.py diff --git a/plume/models/jasper/asr.py b/src/plume/models/jasper_nemo/asr.py similarity index 90% rename from plume/models/jasper/asr.py rename to src/plume/models/jasper_nemo/asr.py index e52695d..8cabd56 100644 --- a/plume/models/jasper/asr.py +++ b/src/plume/models/jasper_nemo/asr.py @@ -16,7 +16,11 @@ class JasperASR(object): """docstring for JasperASR.""" def __init__( - self, model_yaml, encoder_checkpoint, decoder_checkpoint, language_model=None + self, + model_yaml, + encoder_checkpoint, + decoder_checkpoint, + language_model=None, ): 
super(JasperASR, self).__init__() # Read model YAML @@ -24,16 +28,17 @@ class JasperASR(object): with open(model_yaml) as f: jasper_model_definition = yaml.load(f) self.neural_factory = nemo.core.NeuralModuleFactory( - placement=nemo.core.DeviceType.GPU, backend=nemo.core.Backend.PyTorch + placement=nemo.core.DeviceType.GPU, + backend=nemo.core.Backend.PyTorch, ) self.labels = jasper_model_definition["labels"] self.data_preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor() self.jasper_encoder = nemo_asr.JasperEncoder( jasper=jasper_model_definition["JasperEncoder"]["jasper"], activation=jasper_model_definition["JasperEncoder"]["activation"], - feat_in=jasper_model_definition["AudioToMelSpectrogramPreprocessor"][ - "features" - ], + feat_in=jasper_model_definition[ + "AudioToMelSpectrogramPreprocessor" + ]["features"], ) self.jasper_encoder.restore_from(encoder_checkpoint, local_rank=0) self.jasper_decoder = nemo_asr.JasperDecoderForCTC( @@ -65,7 +70,11 @@ class JasperASR(object): wf.setframerate(24000) wf.writeframesraw(audio_data) wf.close() - manifest = {"audio_filepath": audio_file_path, "duration": 60, "text": "todo"} + manifest = { + "audio_filepath": audio_file_path, + "duration": 60, + "text": "todo", + } manifest_file = tempfile.NamedTemporaryFile( dir=WORK_DIR, prefix="jasper_manifest.", delete=False, mode="w" ) diff --git a/plume/models/jasper/data.py b/src/plume/models/jasper_nemo/data.py similarity index 100% rename from plume/models/jasper/data.py rename to src/plume/models/jasper_nemo/data.py diff --git a/plume/models/jasper/data_loaders.py b/src/plume/models/jasper_nemo/data_loaders.py similarity index 98% rename from plume/models/jasper/data_loaders.py rename to src/plume/models/jasper_nemo/data_loaders.py index d181dfa..4a7ab97 100644 --- a/plume/models/jasper/data_loaders.py +++ b/src/plume/models/jasper_nemo/data_loaders.py @@ -11,7 +11,12 @@ from nemo.backends.pytorch import DataLayerNM from nemo.core import DeviceType # from nemo.core.neural_types import * -from nemo.core.neural_types import NeuralType, AudioSignal, LengthsType, LabelsType +from nemo.core.neural_types import ( + NeuralType, + AudioSignal, + LengthsType, + LabelsType, +) from nemo.utils.decorators import add_port_docs from nemo.collections.asr.parts.dataset import ( @@ -217,8 +222,7 @@ transcript_n} @property @add_port_docs() def output_ports(self): - """Returns definitions of module output ports. 
- """ + """Returns definitions of module output ports.""" return { # 'audio_signal': NeuralType({0: AxisType(BatchTag), 1: AxisType(TimeTag)}), # 'a_sig_length': NeuralType({0: AxisType(BatchTag)}), @@ -304,7 +308,9 @@ transcript_n} # Set up data loader if self._placement == DeviceType.AllGpu: logging.info("Parallelizing Datalayer.") - sampler = torch.utils.data.distributed.DistributedSampler(self._dataset) + sampler = torch.utils.data.distributed.DistributedSampler( + self._dataset + ) else: sampler = None diff --git a/plume/models/jasper/eval.py b/src/plume/models/jasper_nemo/eval.py similarity index 94% rename from plume/models/jasper/eval.py rename to src/plume/models/jasper_nemo/eval.py index 12f558f..ab18a8e 100644 --- a/plume/models/jasper/eval.py +++ b/src/plume/models/jasper_nemo/eval.py @@ -1,6 +1,7 @@ # Copyright (c) 2019 NVIDIA Corporation import argparse import copy + # import math import os from pathlib import Path @@ -57,7 +58,10 @@ def parse_args(): help="max number of steps to train", ) parser.add_argument( - "--num_epochs", type=int, required=False, help="number of epochs to train" + "--num_epochs", + type=int, + required=False, + help="number of epochs to train", ) parser.add_argument( "--model_config", @@ -170,7 +174,8 @@ def create_all_dags(args, neural_factory): # logging.info("Have {0} examples to train on.".format(N)) # data_preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor( - sample_rate=sample_rate, **jasper_params["AudioToMelSpectrogramPreprocessor"] + sample_rate=sample_rate, + **jasper_params["AudioToMelSpectrogramPreprocessor"], ) # multiply_batch_config = jasper_params.get("MultiplyBatch", None) @@ -284,7 +289,12 @@ def create_all_dags(args, neural_factory): callbacks = [] # assemble eval DAGs for i, eval_dl in enumerate(data_layers_eval): - (audio_signal_e, a_sig_length_e, transcript_e, transcript_len_e) = eval_dl() + ( + audio_signal_e, + a_sig_length_e, + transcript_e, + transcript_len_e, + ) = eval_dl() processed_signal_e, p_length_e = data_preprocessor( input_signal=audio_signal_e, length=a_sig_length_e ) @@ -303,9 +313,16 @@ def create_all_dags(args, neural_factory): # create corresponding eval callback tagname = os.path.basename(args.eval_datasets[i]).split(".")[0] eval_callback = nemo.core.EvaluatorCallback( - eval_tensors=[loss_e, predictions_e, transcript_e, transcript_len_e], + eval_tensors=[ + loss_e, + predictions_e, + transcript_e, + transcript_len_e, + ], user_iter_callback=partial(process_evaluation_batch, labels=vocab), - user_epochs_done_callback=partial(process_evaluation_epoch, tag=tagname), + user_epochs_done_callback=partial( + process_evaluation_epoch, tag=tagname + ), eval_step=args.eval_freq, tb_writer=neural_factory.tb_writer, ) diff --git a/plume/models/jasper/featurizer.py b/src/plume/models/jasper_nemo/featurizer.py similarity index 81% rename from plume/models/jasper/featurizer.py rename to src/plume/models/jasper_nemo/featurizer.py index 030eb36..7ada867 100644 --- a/plume/models/jasper/featurizer.py +++ b/src/plume/models/jasper_nemo/featurizer.py @@ -3,19 +3,27 @@ # import librosa import torch import pickle + # import torch.nn as nn # from torch_stft import STFT # from nemo import logging from nemo.collections.asr.parts.perturb import AudioAugmentor + # from nemo.collections.asr.parts.segment import AudioSegment class RpycWaveformFeaturizer(object): def __init__( - self, sample_rate=16000, int_values=False, augmentor=None, rpyc_conn=None + self, + sample_rate=16000, + int_values=False, + augmentor=None, + 
rpyc_conn=None,
     ):
-        self.augmentor = augmentor if augmentor is not None else AudioAugmentor()
+        self.augmentor = (
+            augmentor if augmentor is not None else AudioAugmentor()
+        )
         self.sample_rate = sample_rate
         self.int_values = int_values
         self.remote_path_samples = rpyc_conn.get_path_samples
@@ -48,4 +56,6 @@ class RpycWaveformFeaturizer(object):
         sample_rate = input_config.get("sample_rate", 16000)
         int_values = input_config.get("int_values", False)
 
-        return cls(sample_rate=sample_rate, int_values=int_values, augmentor=aa)
+        return cls(
+            sample_rate=sample_rate, int_values=int_values, augmentor=aa
+        )
diff --git a/plume/models/jasper/serve.py b/src/plume/models/jasper_nemo/serve.py
similarity index 88%
rename from plume/models/jasper/serve.py
rename to src/plume/models/jasper_nemo/serve.py
index 892c64e..686c90f 100644
--- a/plume/models/jasper/serve.py
+++ b/src/plume/models/jasper_nemo/serve.py
@@ -9,7 +9,7 @@ import typer
 from ...utils.serve import ASRService
 from plume.utils import lazy_callable
 
-JasperASR = lazy_callable('plume.models.jasper.asr.JasperASR')
+JasperASR = lazy_callable("plume.models.jasper_nemo.asr.JasperASR")
 
 app = typer.Typer()
 
@@ -37,7 +37,9 @@ def rpyc(
 
 
 @app.command()
-def rpyc_dir(model_dir: Path, port: int = int(os.environ.get("ASR_RPYC_PORT", "8044"))):
+def rpyc_dir(
+    model_dir: Path, port: int = int(os.environ.get("ASR_RPYC_PORT", "8044"))
+):
     encoder_path = model_dir / Path("encoder.pt")
     decoder_path = model_dir / Path("decoder.pt")
     model_yaml_path = model_dir / Path("model.yaml")
diff --git a/plume/models/jasper/serve_data.py b/src/plume/models/jasper_nemo/serve_data.py
similarity index 93%
rename from plume/models/jasper/serve_data.py
rename to src/plume/models/jasper_nemo/serve_data.py
index 856c381..5729de3 100644
--- a/plume/models/jasper/serve_data.py
+++ b/src/plume/models/jasper_nemo/serve_data.py
@@ -40,7 +40,9 @@ class ASRDataService(rpyc.Service):
 
 @app.command()
 def run_server(port: int = 0):
-    listen_port = port if port else int(os.environ.get("ASR_DARA_RPYC_PORT", "8064"))
+    listen_port = (
+        port if port else int(os.environ.get("ASR_DARA_RPYC_PORT", "8064"))
+    )
     service = ASRDataService()
     t = ThreadedServer(
         service, port=listen_port, protocol_config={"allow_all_attrs": True}
diff --git a/plume/models/jasper/train.py b/src/plume/models/jasper_nemo/train.py
similarity index 91%
rename from plume/models/jasper/train.py
rename to src/plume/models/jasper_nemo/train.py
index 7ef9beb..618bae7 100644
--- a/plume/models/jasper/train.py
+++ b/src/plume/models/jasper_nemo/train.py
@@ -161,7 +161,8 @@ def create_all_dags(args, neural_factory):
     logging.info("Have {0} examples to train on.".format(N))
 
     data_preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor(
-        sample_rate=sample_rate, **jasper_params["AudioToMelSpectrogramPreprocessor"]
+        sample_rate=sample_rate,
+        **jasper_params["AudioToMelSpectrogramPreprocessor"],
     )
 
     multiply_batch_config = jasper_params.get("MultiplyBatch", None)
@@ -212,8 +213,12 @@
     greedy_decoder = nemo_asr.GreedyCTCDecoder()
 
     logging.info("================================")
-    logging.info(f"Number of parameters in encoder: {jasper_encoder.num_weights}")
-    logging.info(f"Number of parameters in decoder: {jasper_decoder.num_weights}")
+    logging.info(
+        f"Number of parameters in encoder: {jasper_encoder.num_weights}"
+    )
+    logging.info(
+        f"Number of parameters in decoder: {jasper_decoder.num_weights}"
+    )
     logging.info(
         f"Total number of parameters in model: "
         f"{jasper_decoder.num_weights 
+ jasper_encoder.num_weights}" @@ -221,7 +226,12 @@ def create_all_dags(args, neural_factory): logging.info("================================") # Train DAG - (audio_signal_t, a_sig_length_t, transcript_t, transcript_len_t) = data_layer() + ( + audio_signal_t, + a_sig_length_t, + transcript_t, + transcript_len_t, + ) = data_layer() processed_signal_t, p_length_t = data_preprocessor( input_signal=audio_signal_t, length=a_sig_length_t ) @@ -240,7 +250,9 @@ def create_all_dags(args, neural_factory): ) if spectr_augment_config: - processed_signal_t = data_spectr_augmentation(input_spec=processed_signal_t) + processed_signal_t = data_spectr_augmentation( + input_spec=processed_signal_t + ) encoded_t, encoded_len_t = jasper_encoder( audio_signal=processed_signal_t, length=p_length_t @@ -273,7 +285,12 @@ def create_all_dags(args, neural_factory): # assemble eval DAGs for i, eval_dl in enumerate(data_layers_eval): - (audio_signal_e, a_sig_length_e, transcript_e, transcript_len_e) = eval_dl() + ( + audio_signal_e, + a_sig_length_e, + transcript_e, + transcript_len_e, + ) = eval_dl() processed_signal_e, p_length_e = data_preprocessor( input_signal=audio_signal_e, length=a_sig_length_e ) @@ -292,9 +309,16 @@ def create_all_dags(args, neural_factory): # create corresponding eval callback tagname = os.path.basename(args.eval_datasets[i]).split(".")[0] eval_callback = nemo.core.EvaluatorCallback( - eval_tensors=[loss_e, predictions_e, transcript_e, transcript_len_e], + eval_tensors=[ + loss_e, + predictions_e, + transcript_e, + transcript_len_e, + ], user_iter_callback=partial(process_evaluation_batch, labels=vocab), - user_epochs_done_callback=partial(process_evaluation_epoch, tag=tagname), + user_epochs_done_callback=partial( + process_evaluation_epoch, tag=tagname + ), eval_step=args.eval_freq, tb_writer=neural_factory.tb_writer, ) @@ -338,7 +362,9 @@ def main(): logging.info("Doing ALL GPU") # build dags - train_loss, callbacks, steps_per_epoch = create_all_dags(args, neural_factory) + train_loss, callbacks, steps_per_epoch = create_all_dags( + args, neural_factory + ) # train model neural_factory.train( tensors_to_optimize=[train_loss], diff --git a/plume/models/matchboxnet/__init__.py b/src/plume/models/marblenet_nemo/__init__.py similarity index 100% rename from plume/models/matchboxnet/__init__.py rename to src/plume/models/marblenet_nemo/__init__.py diff --git a/src/plume/models/marblenet_nemo/asr.py b/src/plume/models/marblenet_nemo/asr.py new file mode 100644 index 0000000..8cabd56 --- /dev/null +++ b/src/plume/models/marblenet_nemo/asr.py @@ -0,0 +1,132 @@ +import os +import tempfile +from ruamel.yaml import YAML +import json +import nemo +import nemo.collections.asr as nemo_asr +import wave +from nemo.collections.asr.helpers import post_process_predictions + +logging = nemo.logging + +WORK_DIR = "/tmp" + + +class JasperASR(object): + """docstring for JasperASR.""" + + def __init__( + self, + model_yaml, + encoder_checkpoint, + decoder_checkpoint, + language_model=None, + ): + super(JasperASR, self).__init__() + # Read model YAML + yaml = YAML(typ="safe") + with open(model_yaml) as f: + jasper_model_definition = yaml.load(f) + self.neural_factory = nemo.core.NeuralModuleFactory( + placement=nemo.core.DeviceType.GPU, + backend=nemo.core.Backend.PyTorch, + ) + self.labels = jasper_model_definition["labels"] + self.data_preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor() + self.jasper_encoder = nemo_asr.JasperEncoder( + jasper=jasper_model_definition["JasperEncoder"]["jasper"], + 
activation=jasper_model_definition["JasperEncoder"]["activation"], + feat_in=jasper_model_definition[ + "AudioToMelSpectrogramPreprocessor" + ]["features"], + ) + self.jasper_encoder.restore_from(encoder_checkpoint, local_rank=0) + self.jasper_decoder = nemo_asr.JasperDecoderForCTC( + feat_in=1024, num_classes=len(self.labels) + ) + self.jasper_decoder.restore_from(decoder_checkpoint, local_rank=0) + self.greedy_decoder = nemo_asr.GreedyCTCDecoder() + self.beam_search_with_lm = None + if language_model: + self.beam_search_with_lm = nemo_asr.BeamSearchDecoderWithLM( + vocab=self.labels, + beam_width=64, + alpha=2.0, + beta=1.0, + lm_path=language_model, + num_cpus=max(os.cpu_count(), 1), + ) + + def transcribe(self, audio_data, greedy=True): + audio_file = tempfile.NamedTemporaryFile( + dir=WORK_DIR, prefix="jasper_audio.", delete=False + ) + # audio_file.write(audio_data) + audio_file.close() + audio_file_path = audio_file.name + wf = wave.open(audio_file_path, "w") + wf.setnchannels(1) + wf.setsampwidth(2) + wf.setframerate(24000) + wf.writeframesraw(audio_data) + wf.close() + manifest = { + "audio_filepath": audio_file_path, + "duration": 60, + "text": "todo", + } + manifest_file = tempfile.NamedTemporaryFile( + dir=WORK_DIR, prefix="jasper_manifest.", delete=False, mode="w" + ) + manifest_file.write(json.dumps(manifest)) + manifest_file.close() + manifest_file_path = manifest_file.name + data_layer = nemo_asr.AudioToTextDataLayer( + shuffle=False, + manifest_filepath=manifest_file_path, + labels=self.labels, + batch_size=1, + ) + + # Define inference DAG + audio_signal, audio_signal_len, _, _ = data_layer() + processed_signal, processed_signal_len = self.data_preprocessor( + input_signal=audio_signal, length=audio_signal_len + ) + encoded, encoded_len = self.jasper_encoder( + audio_signal=processed_signal, length=processed_signal_len + ) + log_probs = self.jasper_decoder(encoder_output=encoded) + predictions = self.greedy_decoder(log_probs=log_probs) + + if greedy: + eval_tensors = [predictions] + else: + if self.beam_search_with_lm: + logging.info("Running with beam search") + beam_predictions = self.beam_search_with_lm( + log_probs=log_probs, log_probs_length=encoded_len + ) + eval_tensors = [beam_predictions] + else: + logging.info( + "language_model not specified. falling back to greedy decoding." + ) + eval_tensors = [predictions] + + tensors = self.neural_factory.infer(tensors=eval_tensors) + prediction = post_process_predictions(tensors[0], self.labels) + prediction_text = ". 
".join(prediction) + os.unlink(manifest_file.name) + os.unlink(audio_file.name) + return prediction_text + + def transcribe_file(self, audio_file, *args, **kwargs): + tscript_file_path = audio_file.with_suffix(".txt") + audio_file_path = str(audio_file) + with wave.open(audio_file_path, "r") as af: + frame_count = af.getnframes() + audio_data = af.readframes(frame_count) + transcription = self.transcribe(audio_data, *args, **kwargs) + with open(tscript_file_path, "w") as tf: + tf.write(transcription) diff --git a/src/plume/models/marblenet_nemo/data.py b/src/plume/models/marblenet_nemo/data.py new file mode 100644 index 0000000..1d3babf --- /dev/null +++ b/src/plume/models/marblenet_nemo/data.py @@ -0,0 +1,24 @@ +from pathlib import Path +import typer + +app = typer.Typer() + + +@app.command() +def set_root(dataset_path: Path, root_path: Path): + pass + # for dataset_kind in ["train", "valid"]: + # data_file = dataset_path / Path(dataset_kind).with_suffix(".tsv") + # with data_file.open("r") as df: + # lines = df.readlines() + # with data_file.open("w") as df: + # lines[0] = str(root_path) + "\n" + # df.writelines(lines) + + +def main(): + app() + + +if __name__ == "__main__": + main() diff --git a/src/plume/models/marblenet_nemo/data_loaders.py b/src/plume/models/marblenet_nemo/data_loaders.py new file mode 100644 index 0000000..4a7ab97 --- /dev/null +++ b/src/plume/models/marblenet_nemo/data_loaders.py @@ -0,0 +1,340 @@ +from functools import partial +import tempfile + +# from typing import Any, Dict, List, Optional + +import torch +import nemo + +# import nemo.collections.asr as nemo_asr +from nemo.backends.pytorch import DataLayerNM +from nemo.core import DeviceType + +# from nemo.core.neural_types import * +from nemo.core.neural_types import ( + NeuralType, + AudioSignal, + LengthsType, + LabelsType, +) +from nemo.utils.decorators import add_port_docs + +from nemo.collections.asr.parts.dataset import ( + # AudioDataset, + # AudioLabelDataset, + # KaldiFeatureDataset, + # TranscriptDataset, + parsers, + collections, + seq_collate_fn, +) + +# from functools import lru_cache +import rpyc +from concurrent.futures import ThreadPoolExecutor +from tqdm import tqdm +from .featurizer import RpycWaveformFeaturizer + +# from nemo.collections.asr.parts.features import WaveformFeaturizer + +# from nemo.collections.asr.parts.perturb import AudioAugmentor, perturbation_types + + +logging = nemo.logging + + +class CachedAudioDataset(torch.utils.data.Dataset): + """ + Dataset that loads tensors via a json file containing paths to audio + files, transcripts, and durations (in seconds). Each new line is a + different sample. Example below: + + {"audio_filepath": "/path/to/audio.wav", "text_filepath": + "/path/to/audio.txt", "duration": 23.147} + ... + {"audio_filepath": "/path/to/audio.wav", "text": "the + transcription", offset": 301.75, "duration": 0.82, "utt": + "utterance_id", "ctm_utt": "en_4156", "side": "A"} + + Args: + manifest_filepath: Path to manifest json as described above. Can + be comma-separated paths. 
+        labels: String containing all the possible characters to map to
+        featurizer: Initialized featurizer class that converts paths of
+            audio to feature tensors
+        max_duration: If audio exceeds this length, do not include in dataset
+        min_duration: If audio is less than this length, do not include
+            in dataset
+        max_utts: Limit number of utterances
+        blank_index: blank character index, default = -1
+        unk_index: unk character index, default = -1
+        normalize: whether to normalize transcript text (default: True)
+        bos_id: Id of beginning of sequence symbol to append if not None
+        eos_id: Id of end of sequence symbol to append if not None
+        load_audio: Boolean flag indicating whether or not to load audio
+    """
+
+    def __init__(
+        self,
+        manifest_filepath,
+        labels,
+        featurizer,
+        max_duration=None,
+        min_duration=None,
+        max_utts=0,
+        blank_index=-1,
+        unk_index=-1,
+        normalize=True,
+        trim=False,
+        bos_id=None,
+        eos_id=None,
+        load_audio=True,
+        parser="en",
+    ):
+        self.collection = collections.ASRAudioText(
+            manifests_files=manifest_filepath.split(","),
+            parser=parsers.make_parser(
+                labels=labels,
+                name=parser,
+                unk_id=unk_index,
+                blank_id=blank_index,
+                do_normalize=normalize,
+            ),
+            min_duration=min_duration,
+            max_duration=max_duration,
+            max_number=max_utts,
+        )
+        self.index_feature_map = {}
+
+        self.featurizer = featurizer
+        self.trim = trim
+        self.eos_id = eos_id
+        self.bos_id = bos_id
+        self.load_audio = load_audio
+        print(f"initializing dataset {manifest_filepath}")
+
+        def exec_func(i):
+            return self[i]
+
+        task_count = len(self.collection)
+        with ThreadPoolExecutor() as exe:
+            print("starting all loading tasks")
+            list(
+                tqdm(
+                    exe.map(exec_func, range(task_count)),
+                    position=0,
+                    leave=True,
+                    total=task_count,
+                )
+            )
+        print("initializing complete")
+
+    def __getitem__(self, index):
+        sample = self.collection[index]
+        if self.load_audio:
+            cached_features = self.index_feature_map.get(index)
+            if cached_features is not None:
+                features = cached_features
+            else:
+                features = self.featurizer.process(
+                    sample.audio_file,
+                    offset=0,
+                    duration=sample.duration,
+                    trim=self.trim,
+                )
+                self.index_feature_map[index] = features
+            f, fl = features, torch.tensor(features.shape[0]).long()
+        else:
+            f, fl = None, None
+
+        t, tl = sample.text_tokens, len(sample.text_tokens)
+        if self.bos_id is not None:
+            t = [self.bos_id] + t
+            tl += 1
+        if self.eos_id is not None:
+            t = t + [self.eos_id]
+            tl += 1
+
+        return f, fl, torch.tensor(t).long(), torch.tensor(tl).long()
+
+    def __len__(self):
+        return len(self.collection)
+
+
+class RpycAudioToTextDataLayer(DataLayerNM):
+    """Data Layer for general ASR tasks.
+
+    Module which reads ASR labeled data. It accepts comma-separated
+    JSON manifest files describing the correspondence between wav audio files
+    and their transcripts. JSON files should be of the following format::
+
+        {"audio_filepath": path_to_wav_0, "duration": time_in_sec_0, "text": \
+transcript_0}
+        ...
+        {"audio_filepath": path_to_wav_n, "duration": time_in_sec_n, "text": \
+transcript_n}
+
+    Args:
+        manifest_filepath (str): Dataset parameter.
+            Path to JSON containing data.
+        labels (list): Dataset parameter.
+            List of characters that can be output by the ASR model.
+            For Jasper, this is the 28 character set {a-z '}. The CTC blank
+            symbol is automatically added later for models using ctc.
+        batch_size (int): batch size
+        sample_rate (int): Target sampling rate for data. Audio files will be
+            resampled to sample_rate if it is not already.
+            Defaults to 16000. 
+ int_values (bool): Bool indicating whether the audio file is saved as + int data or float data. + Defaults to False. + eos_id (id): Dataset parameter. + End of string symbol id used for seq2seq models. + Defaults to None. + min_duration (float): Dataset parameter. + All training files which have a duration less than min_duration + are dropped. Note: Duration is read from the manifest JSON. + Defaults to 0.1. + max_duration (float): Dataset parameter. + All training files which have a duration more than max_duration + are dropped. Note: Duration is read from the manifest JSON. + Defaults to None. + normalize_transcripts (bool): Dataset parameter. + Whether to use automatic text cleaning. + It is highly recommended to manually clean text for best results. + Defaults to True. + trim_silence (bool): Whether to use trim silence from beginning and end + of audio signal using librosa.effects.trim(). + Defaults to False. + load_audio (bool): Dataset parameter. + Controls whether the dataloader loads the audio signal and + transcript or just the transcript. + Defaults to True. + drop_last (bool): See PyTorch DataLoader. + Defaults to False. + shuffle (bool): See PyTorch DataLoader. + Defaults to True. + num_workers (int): See PyTorch DataLoader. + Defaults to 0. + perturb_config (dict): Currently disabled. + """ + + @property + @add_port_docs() + def output_ports(self): + """Returns definitions of module output ports.""" + return { + # 'audio_signal': NeuralType({0: AxisType(BatchTag), 1: AxisType(TimeTag)}), + # 'a_sig_length': NeuralType({0: AxisType(BatchTag)}), + # 'transcripts': NeuralType({0: AxisType(BatchTag), 1: AxisType(TimeTag)}), + # 'transcript_length': NeuralType({0: AxisType(BatchTag)}), + "audio_signal": NeuralType( + ("B", "T"), + AudioSignal(freq=self._sample_rate) + if self is not None and self._sample_rate is not None + else AudioSignal(), + ), + "a_sig_length": NeuralType(tuple("B"), LengthsType()), + "transcripts": NeuralType(("B", "T"), LabelsType()), + "transcript_length": NeuralType(tuple("B"), LengthsType()), + } + + def __init__( + self, + manifest_filepath, + labels, + batch_size, + sample_rate=16000, + int_values=False, + bos_id=None, + eos_id=None, + pad_id=None, + min_duration=0.1, + max_duration=None, + normalize_transcripts=True, + trim_silence=False, + load_audio=True, + rpyc_host="", + drop_last=False, + shuffle=True, + num_workers=0, + ): + super().__init__() + self._sample_rate = sample_rate + + def rpyc_root_fn(): + return rpyc.connect( + rpyc_host, 8064, config={"sync_request_timeout": 600} + ).root + + rpyc_conn = rpyc_root_fn() + + self._featurizer = RpycWaveformFeaturizer( + sample_rate=self._sample_rate, + int_values=int_values, + augmentor=None, + rpyc_conn=rpyc_conn, + ) + + def read_remote_manifests(): + local_mp = [] + for mrp in manifest_filepath.split(","): + md = rpyc_conn.read_path(mrp) + mf = tempfile.NamedTemporaryFile( + dir="/tmp", prefix="jasper_manifest.", delete=False + ) + mf.write(md) + mf.close() + local_mp.append(mf.name) + return ",".join(local_mp) + + local_manifest_filepath = read_remote_manifests() + dataset_params = { + "manifest_filepath": local_manifest_filepath, + "labels": labels, + "featurizer": self._featurizer, + "max_duration": max_duration, + "min_duration": min_duration, + "normalize": normalize_transcripts, + "trim": trim_silence, + "bos_id": bos_id, + "eos_id": eos_id, + "load_audio": load_audio, + } + + self._dataset = CachedAudioDataset(**dataset_params) + self._batch_size = batch_size + + # Set up data loader + if 
self._placement == DeviceType.AllGpu: + logging.info("Parallelizing Datalayer.") + sampler = torch.utils.data.distributed.DistributedSampler( + self._dataset + ) + else: + sampler = None + + if batch_size == -1: + batch_size = len(self._dataset) + + pad_id = 0 if pad_id is None else pad_id + self._dataloader = torch.utils.data.DataLoader( + dataset=self._dataset, + batch_size=batch_size, + collate_fn=partial(seq_collate_fn, token_pad_value=pad_id), + drop_last=drop_last, + shuffle=shuffle if sampler is None else False, + sampler=sampler, + num_workers=1, + ) + + def __len__(self): + return len(self._dataset) + + @property + def dataset(self): + return None + + @property + def data_iterator(self): + return self._dataloader diff --git a/src/plume/models/marblenet_nemo/eval.py b/src/plume/models/marblenet_nemo/eval.py new file mode 100644 index 0000000..ab18a8e --- /dev/null +++ b/src/plume/models/marblenet_nemo/eval.py @@ -0,0 +1,376 @@ +# Copyright (c) 2019 NVIDIA Corporation +import argparse +import copy + +# import math +import os +from pathlib import Path +from functools import partial + +from ruamel.yaml import YAML + +import nemo +import nemo.collections.asr as nemo_asr +import nemo.utils.argparse as nm_argparse +from nemo.collections.asr.helpers import ( + # monitor_asr_train_progress, + process_evaluation_batch, + process_evaluation_epoch, +) + +# from nemo.utils.lr_policies import CosineAnnealing +from training.data_loaders import RpycAudioToTextDataLayer + +logging = nemo.logging + + +def parse_args(): + parser = argparse.ArgumentParser( + parents=[nm_argparse.NemoArgParser()], + description="Jasper", + conflict_handler="resolve", + ) + parser.set_defaults( + checkpoint_dir=None, + optimizer="novograd", + batch_size=64, + eval_batch_size=64, + lr=0.002, + amp_opt_level="O1", + create_tb_writer=True, + model_config="./train/jasper10x5dr.yaml", + work_dir="./train/work", + num_epochs=300, + weight_decay=0.005, + checkpoint_save_freq=100, + eval_freq=100, + load_dir="./train/models/jasper/", + warmup_steps=3, + exp_name="jasper", + ) + + # Overwrite default args + parser.add_argument( + "--max_steps", + type=int, + default=None, + required=False, + help="max number of steps to train", + ) + parser.add_argument( + "--num_epochs", + type=int, + required=False, + help="number of epochs to train", + ) + parser.add_argument( + "--model_config", + type=str, + required=False, + help="model configuration file: model.yaml", + ) + parser.add_argument( + "--encoder_checkpoint", + type=str, + required=True, + help="encoder checkpoint file: JasperEncoder.pt", + ) + parser.add_argument( + "--decoder_checkpoint", + type=str, + required=True, + help="decoder checkpoint file: JasperDecoderForCTC.pt", + ) + parser.add_argument( + "--remote_data", + type=str, + required=False, + default="", + help="remote dataloader endpoint", + ) + parser.add_argument( + "--dataset", + type=str, + required=False, + default="", + help="dataset directory containing train/test manifests", + ) + + # Create new args + parser.add_argument("--exp_name", default="Jasper", type=str) + parser.add_argument("--beta1", default=0.95, type=float) + parser.add_argument("--beta2", default=0.25, type=float) + parser.add_argument("--warmup_steps", default=0, type=int) + parser.add_argument( + "--load_dir", + default=None, + type=str, + help="directory with pre-trained checkpoint", + ) + + args = parser.parse_args() + if args.max_steps is None and args.num_epochs is None: + raise ValueError("Either max_steps or num_epochs should be 
provided.") + return args + + +def construct_name( + name, lr, batch_size, max_steps, num_epochs, wd, optimizer, iter_per_step +): + if max_steps is not None: + return "{0}-lr_{1}-bs_{2}-s_{3}-wd_{4}-opt_{5}-ips_{6}".format( + name, lr, batch_size, max_steps, wd, optimizer, iter_per_step + ) + else: + return "{0}-lr_{1}-bs_{2}-e_{3}-wd_{4}-opt_{5}-ips_{6}".format( + name, lr, batch_size, num_epochs, wd, optimizer, iter_per_step + ) + + +def create_all_dags(args, neural_factory): + yaml = YAML(typ="safe") + with open(args.model_config) as f: + jasper_params = yaml.load(f) + vocab = jasper_params["labels"] + sample_rate = jasper_params["sample_rate"] + + # Calculate num_workers for dataloader + total_cpus = os.cpu_count() + cpu_per_traindl = max(int(total_cpus / neural_factory.world_size), 1) + # perturb_config = jasper_params.get('perturb', None) + train_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"]) + train_dl_params.update(jasper_params["AudioToTextDataLayer"]["train"]) + del train_dl_params["train"] + del train_dl_params["eval"] + # del train_dl_params["normalize_transcripts"] + + if args.dataset: + d_path = Path(args.dataset) + if not args.train_dataset: + args.train_dataset = str(d_path / Path("train_manifest.json")) + if not args.eval_datasets: + args.eval_datasets = [str(d_path / Path("test_manifest.json"))] + + data_loader_layer = nemo_asr.AudioToTextDataLayer + + if args.remote_data: + train_dl_params["rpyc_host"] = args.remote_data + data_loader_layer = RpycAudioToTextDataLayer + + # data_layer = data_loader_layer( + # manifest_filepath=args.train_dataset, + # sample_rate=sample_rate, + # labels=vocab, + # batch_size=args.batch_size, + # num_workers=cpu_per_traindl, + # **train_dl_params, + # # normalize_transcripts=False + # ) + # + # N = len(data_layer) + # steps_per_epoch = math.ceil( + # N / (args.batch_size * args.iter_per_step * args.num_gpus) + # ) + # logging.info("Have {0} examples to train on.".format(N)) + # + data_preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor( + sample_rate=sample_rate, + **jasper_params["AudioToMelSpectrogramPreprocessor"], + ) + + # multiply_batch_config = jasper_params.get("MultiplyBatch", None) + # if multiply_batch_config: + # multiply_batch = nemo_asr.MultiplyBatch(**multiply_batch_config) + # + # spectr_augment_config = jasper_params.get("SpectrogramAugmentation", None) + # if spectr_augment_config: + # data_spectr_augmentation = nemo_asr.SpectrogramAugmentation( + # **spectr_augment_config + # ) + # + eval_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"]) + eval_dl_params.update(jasper_params["AudioToTextDataLayer"]["eval"]) + if args.remote_data: + eval_dl_params["rpyc_host"] = args.remote_data + del eval_dl_params["train"] + del eval_dl_params["eval"] + data_layers_eval = [] + + # if args.eval_datasets: + for eval_datasets in args.eval_datasets: + data_layer_eval = data_loader_layer( + manifest_filepath=eval_datasets, + sample_rate=sample_rate, + labels=vocab, + batch_size=args.eval_batch_size, + num_workers=cpu_per_traindl, + **eval_dl_params, + ) + + data_layers_eval.append(data_layer_eval) + # else: + # logging.warning("There were no val datasets passed") + + jasper_encoder = nemo_asr.JasperEncoder( + feat_in=jasper_params["AudioToMelSpectrogramPreprocessor"]["features"], + **jasper_params["JasperEncoder"], + ) + jasper_encoder.restore_from(args.encoder_checkpoint, local_rank=0) + + jasper_decoder = nemo_asr.JasperDecoderForCTC( + feat_in=jasper_params["JasperEncoder"]["jasper"][-1]["filters"], 
+ num_classes=len(vocab), + ) + jasper_decoder.restore_from(args.decoder_checkpoint, local_rank=0) + + ctc_loss = nemo_asr.CTCLossNM(num_classes=len(vocab)) + + greedy_decoder = nemo_asr.GreedyCTCDecoder() + + # logging.info("================================") + # logging.info(f"Number of parameters in encoder: {jasper_encoder.num_weights}") + # logging.info(f"Number of parameters in decoder: {jasper_decoder.num_weights}") + # logging.info( + # f"Total number of parameters in model: " + # f"{jasper_decoder.num_weights + jasper_encoder.num_weights}" + # ) + # logging.info("================================") + # + # # Train DAG + # (audio_signal_t, a_sig_length_t, transcript_t, transcript_len_t) = data_layer() + # processed_signal_t, p_length_t = data_preprocessor( + # input_signal=audio_signal_t, length=a_sig_length_t + # ) + # + # if multiply_batch_config: + # ( + # processed_signal_t, + # p_length_t, + # transcript_t, + # transcript_len_t, + # ) = multiply_batch( + # in_x=processed_signal_t, + # in_x_len=p_length_t, + # in_y=transcript_t, + # in_y_len=transcript_len_t, + # ) + # + # if spectr_augment_config: + # processed_signal_t = data_spectr_augmentation(input_spec=processed_signal_t) + # + # encoded_t, encoded_len_t = jasper_encoder( + # audio_signal=processed_signal_t, length=p_length_t + # ) + # log_probs_t = jasper_decoder(encoder_output=encoded_t) + # predictions_t = greedy_decoder(log_probs=log_probs_t) + # loss_t = ctc_loss( + # log_probs=log_probs_t, + # targets=transcript_t, + # input_length=encoded_len_t, + # target_length=transcript_len_t, + # ) + # + # # Callbacks needed to print info to console and Tensorboard + # train_callback = nemo.core.SimpleLossLoggerCallback( + # tensors=[loss_t, predictions_t, transcript_t, transcript_len_t], + # print_func=partial(monitor_asr_train_progress, labels=vocab), + # get_tb_values=lambda x: [("loss", x[0])], + # tb_writer=neural_factory.tb_writer, + # ) + # + # chpt_callback = nemo.core.CheckpointCallback( + # folder=neural_factory.checkpoint_dir, + # load_from_folder=args.load_dir, + # step_freq=args.checkpoint_save_freq, + # checkpoints_to_keep=30, + # ) + # + # callbacks = [train_callback, chpt_callback] + callbacks = [] + # assemble eval DAGs + for i, eval_dl in enumerate(data_layers_eval): + ( + audio_signal_e, + a_sig_length_e, + transcript_e, + transcript_len_e, + ) = eval_dl() + processed_signal_e, p_length_e = data_preprocessor( + input_signal=audio_signal_e, length=a_sig_length_e + ) + encoded_e, encoded_len_e = jasper_encoder( + audio_signal=processed_signal_e, length=p_length_e + ) + log_probs_e = jasper_decoder(encoder_output=encoded_e) + predictions_e = greedy_decoder(log_probs=log_probs_e) + loss_e = ctc_loss( + log_probs=log_probs_e, + targets=transcript_e, + input_length=encoded_len_e, + target_length=transcript_len_e, + ) + + # create corresponding eval callback + tagname = os.path.basename(args.eval_datasets[i]).split(".")[0] + eval_callback = nemo.core.EvaluatorCallback( + eval_tensors=[ + loss_e, + predictions_e, + transcript_e, + transcript_len_e, + ], + user_iter_callback=partial(process_evaluation_batch, labels=vocab), + user_epochs_done_callback=partial( + process_evaluation_epoch, tag=tagname + ), + eval_step=args.eval_freq, + tb_writer=neural_factory.tb_writer, + ) + + callbacks.append(eval_callback) + return callbacks + + +def main(): + args = parse_args() + # name = construct_name( + # args.exp_name, + # args.lr, + # args.batch_size, + # args.max_steps, + # args.num_epochs, + # args.weight_decay, + # 
args.optimizer, + # args.iter_per_step, + # ) + # log_dir = name + # if args.work_dir: + # log_dir = os.path.join(args.work_dir, name) + + # instantiate Neural Factory with supported backend + neural_factory = nemo.core.NeuralModuleFactory( + placement=nemo.core.DeviceType.GPU, + backend=nemo.core.Backend.PyTorch, + # local_rank=args.local_rank, + # optimization_level=args.amp_opt_level, + # log_dir=log_dir, + # checkpoint_dir=args.checkpoint_dir, + # create_tb_writer=args.create_tb_writer, + # files_to_copy=[args.model_config, __file__], + # cudnn_benchmark=args.cudnn_benchmark, + # tensorboard_dir=args.tensorboard_dir, + ) + args.num_gpus = neural_factory.world_size + + # checkpoint_dir = neural_factory.checkpoint_dir + if args.local_rank is not None: + logging.info("Doing ALL GPU") + + # build dags + callbacks = create_all_dags(args, neural_factory) + # evaluate model + neural_factory.eval(callbacks=callbacks) + + +if __name__ == "__main__": + main() diff --git a/src/plume/models/marblenet_nemo/featurizer.py b/src/plume/models/marblenet_nemo/featurizer.py new file mode 100644 index 0000000..7ada867 --- /dev/null +++ b/src/plume/models/marblenet_nemo/featurizer.py @@ -0,0 +1,61 @@ +# import math + +# import librosa +import torch +import pickle + +# import torch.nn as nn +# from torch_stft import STFT + +# from nemo import logging +from nemo.collections.asr.parts.perturb import AudioAugmentor + +# from nemo.collections.asr.parts.segment import AudioSegment + + +class RpycWaveformFeaturizer(object): + def __init__( + self, + sample_rate=16000, + int_values=False, + augmentor=None, + rpyc_conn=None, + ): + self.augmentor = ( + augmentor if augmentor is not None else AudioAugmentor() + ) + self.sample_rate = sample_rate + self.int_values = int_values + self.remote_path_samples = rpyc_conn.get_path_samples + + def max_augmentation_length(self, length): + return self.augmentor.max_augmentation_length(length) + + def process(self, file_path, offset=0, duration=0, trim=False): + audio = self.remote_path_samples( + file_path, + target_sr=self.sample_rate, + int_values=self.int_values, + offset=offset, + duration=duration, + trim=trim, + ) + return torch.tensor(pickle.loads(audio), dtype=torch.float) + + def process_segment(self, audio_segment): + self.augmentor.perturb(audio_segment) + return torch.tensor(audio_segment, dtype=torch.float) + + @classmethod + def from_config(cls, input_config, perturbation_configs=None): + if perturbation_configs is not None: + aa = AudioAugmentor.from_config(perturbation_configs) + else: + aa = None + + sample_rate = input_config.get("sample_rate", 16000) + int_values = input_config.get("int_values", False) + + return cls( + sample_rate=sample_rate, int_values=int_values, augmentor=aa + ) diff --git a/src/plume/models/marblenet_nemo/serve.py b/src/plume/models/marblenet_nemo/serve.py new file mode 100644 index 0000000..686c90f --- /dev/null +++ b/src/plume/models/marblenet_nemo/serve.py @@ -0,0 +1,54 @@ +import os +import logging +from pathlib import Path + +from rpyc.utils.server import ThreadedServer +import typer + +# from .asr import JasperASR +from ...utils.serve import ASRService +from plume.utils import lazy_callable + +JasperASR = lazy_callable("plume.models.jasper_nemo.asr.JasperASR") + +app = typer.Typer() + + +@app.command() +def rpyc( + encoder_path: Path = "/path/to/encoder.pt", + decoder_path: Path = "/path/to/decoder.pt", + model_yaml_path: Path = "/path/to/model.yaml", + port: int = int(os.environ.get("ASR_RPYC_PORT", "8044")), +): + for p in 
[encoder_path, decoder_path, model_yaml_path]: + if not p.exists(): + logging.info(f"{p} doesn't exists") + return + asr = JasperASR(str(model_yaml_path), str(encoder_path), str(decoder_path)) + service = ASRService(asr) + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + ) + logging.info("starting asr server...") + t = ThreadedServer(service, port=port) + t.start() + + +@app.command() +def rpyc_dir( + model_dir: Path, port: int = int(os.environ.get("ASR_RPYC_PORT", "8044")) +): + encoder_path = model_dir / Path("decoder.pt") + decoder_path = model_dir / Path("encoder.pt") + model_yaml_path = model_dir / Path("model.yaml") + rpyc(encoder_path, decoder_path, model_yaml_path, port) + + +def main(): + app() + + +if __name__ == "__main__": + main() diff --git a/src/plume/models/marblenet_nemo/serve_data.py b/src/plume/models/marblenet_nemo/serve_data.py new file mode 100644 index 0000000..5729de3 --- /dev/null +++ b/src/plume/models/marblenet_nemo/serve_data.py @@ -0,0 +1,59 @@ +import os +from pathlib import Path + +import typer +import rpyc +from rpyc.utils.server import ThreadedServer +import nemo +import pickle + +# import nemo.collections.asr as nemo_asr +from nemo.collections.asr.parts.segment import AudioSegment + +app = typer.Typer() + +nemo.core.NeuralModuleFactory( + backend=nemo.core.Backend.PyTorch, placement=nemo.core.DeviceType.CPU +) + + +class ASRDataService(rpyc.Service): + def exposed_get_path_samples( + self, file_path, target_sr, int_values, offset, duration, trim + ): + print(f"loading.. {file_path}") + audio = AudioSegment.from_file( + file_path, + target_sr=target_sr, + int_values=int_values, + offset=offset, + duration=duration, + trim=trim, + ) + # print(f"returning.. {len(audio.samples)} items of type{type(audio.samples)}") + return pickle.dumps(audio.samples) + + def exposed_read_path(self, file_path): + # print(f"reading path.. 
{file_path}") + return Path(file_path).read_bytes() + + +@app.command() +def run_server(port: int = 0): + listen_port = ( + port if port else int(os.environ.get("ASR_DARA_RPYC_PORT", "8064")) + ) + service = ASRDataService() + t = ThreadedServer( + service, port=listen_port, protocol_config={"allow_all_attrs": True} + ) + typer.echo(f"starting asr server on {listen_port}...") + t.start() + + +def main(): + app() + + +if __name__ == "__main__": + main() diff --git a/src/plume/models/marblenet_nemo/train.py b/src/plume/models/marblenet_nemo/train.py new file mode 100644 index 0000000..618bae7 --- /dev/null +++ b/src/plume/models/marblenet_nemo/train.py @@ -0,0 +1,392 @@ +# Copyright (c) 2019 NVIDIA Corporation +import argparse +import copy +import math +import os +from pathlib import Path +from functools import partial + +from ruamel.yaml import YAML + +import nemo +import nemo.collections.asr as nemo_asr +import nemo.utils.argparse as nm_argparse +from nemo.collections.asr.helpers import ( + monitor_asr_train_progress, + process_evaluation_batch, + process_evaluation_epoch, +) + +from nemo.utils.lr_policies import CosineAnnealing +from .data_loaders import RpycAudioToTextDataLayer + +logging = nemo.logging + + +def parse_args(): + parser = argparse.ArgumentParser( + parents=[nm_argparse.NemoArgParser()], + description="Jasper", + conflict_handler="resolve", + ) + parser.set_defaults( + checkpoint_dir=None, + optimizer="novograd", + batch_size=64, + eval_batch_size=64, + lr=0.002, + amp_opt_level="O1", + create_tb_writer=True, + model_config="./train/jasper10x5dr.yaml", + work_dir="./train/work", + num_epochs=300, + weight_decay=0.005, + checkpoint_save_freq=100, + eval_freq=100, + load_dir="./train/models/jasper/", + warmup_steps=3, + exp_name="jasper-speller", + ) + + # Overwrite default args + parser.add_argument( + "--max_steps", + type=int, + default=None, + required=False, + help="max number of steps to train", + ) + parser.add_argument( + "--num_epochs", + type=int, + required=False, + help="number of epochs to train", + ) + parser.add_argument( + "--model_config", + type=str, + required=False, + help="model configuration file: model.yaml", + ) + parser.add_argument( + "--remote_data", + type=str, + required=False, + default="", + help="remote dataloader endpoint", + ) + parser.add_argument( + "--dataset", + type=str, + required=False, + default="", + help="dataset directory containing train/test manifests", + ) + + # Create new args + parser.add_argument("--exp_name", default="Jasper", type=str) + parser.add_argument("--beta1", default=0.95, type=float) + parser.add_argument("--beta2", default=0.25, type=float) + parser.add_argument("--warmup_steps", default=0, type=int) + parser.add_argument( + "--load_dir", + default=None, + type=str, + help="directory with pre-trained checkpoint", + ) + + args = parser.parse_args() + if args.max_steps is None and args.num_epochs is None: + raise ValueError("Either max_steps or num_epochs should be provided.") + return args + + +def construct_name( + name, lr, batch_size, max_steps, num_epochs, wd, optimizer, iter_per_step +): + if max_steps is not None: + return "{0}-lr_{1}-bs_{2}-s_{3}-wd_{4}-opt_{5}-ips_{6}".format( + name, lr, batch_size, max_steps, wd, optimizer, iter_per_step + ) + else: + return "{0}-lr_{1}-bs_{2}-e_{3}-wd_{4}-opt_{5}-ips_{6}".format( + name, lr, batch_size, num_epochs, wd, optimizer, iter_per_step + ) + + +def create_all_dags(args, neural_factory): + yaml = YAML(typ="safe") + with open(args.model_config) as f: + 
jasper_params = yaml.load(f) + vocab = jasper_params["labels"] + sample_rate = jasper_params["sample_rate"] + + # Calculate num_workers for dataloader + total_cpus = os.cpu_count() + cpu_per_traindl = max(int(total_cpus / neural_factory.world_size), 1) + # perturb_config = jasper_params.get('perturb', None) + train_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"]) + train_dl_params.update(jasper_params["AudioToTextDataLayer"]["train"]) + del train_dl_params["train"] + del train_dl_params["eval"] + # del train_dl_params["normalize_transcripts"] + + if args.dataset: + d_path = Path(args.dataset) + if not args.train_dataset: + args.train_dataset = str(d_path / Path("train_manifest.json")) + if not args.eval_datasets: + args.eval_datasets = [str(d_path / Path("test_manifest.json"))] + + data_loader_layer = nemo_asr.AudioToTextDataLayer + + if args.remote_data: + train_dl_params["rpyc_host"] = args.remote_data + data_loader_layer = RpycAudioToTextDataLayer + + data_layer = data_loader_layer( + manifest_filepath=args.train_dataset, + sample_rate=sample_rate, + labels=vocab, + batch_size=args.batch_size, + num_workers=cpu_per_traindl, + **train_dl_params, + # normalize_transcripts=False + ) + + N = len(data_layer) + steps_per_epoch = math.ceil( + N / (args.batch_size * args.iter_per_step * args.num_gpus) + ) + logging.info("Have {0} examples to train on.".format(N)) + + data_preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor( + sample_rate=sample_rate, + **jasper_params["AudioToMelSpectrogramPreprocessor"], + ) + + multiply_batch_config = jasper_params.get("MultiplyBatch", None) + if multiply_batch_config: + multiply_batch = nemo_asr.MultiplyBatch(**multiply_batch_config) + + spectr_augment_config = jasper_params.get("SpectrogramAugmentation", None) + if spectr_augment_config: + data_spectr_augmentation = nemo_asr.SpectrogramAugmentation( + **spectr_augment_config + ) + + eval_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"]) + eval_dl_params.update(jasper_params["AudioToTextDataLayer"]["eval"]) + if args.remote_data: + eval_dl_params["rpyc_host"] = args.remote_data + del eval_dl_params["train"] + del eval_dl_params["eval"] + data_layers_eval = [] + + if args.eval_datasets: + for eval_datasets in args.eval_datasets: + data_layer_eval = data_loader_layer( + manifest_filepath=eval_datasets, + sample_rate=sample_rate, + labels=vocab, + batch_size=args.eval_batch_size, + num_workers=cpu_per_traindl, + **eval_dl_params, + ) + + data_layers_eval.append(data_layer_eval) + else: + logging.warning("There were no val datasets passed") + + jasper_encoder = nemo_asr.JasperEncoder( + feat_in=jasper_params["AudioToMelSpectrogramPreprocessor"]["features"], + **jasper_params["JasperEncoder"], + ) + + jasper_decoder = nemo_asr.JasperDecoderForCTC( + feat_in=jasper_params["JasperEncoder"]["jasper"][-1]["filters"], + num_classes=len(vocab), + ) + + ctc_loss = nemo_asr.CTCLossNM(num_classes=len(vocab)) + + greedy_decoder = nemo_asr.GreedyCTCDecoder() + + logging.info("================================") + logging.info( + f"Number of parameters in encoder: {jasper_encoder.num_weights}" + ) + logging.info( + f"Number of parameters in decoder: {jasper_decoder.num_weights}" + ) + logging.info( + f"Total number of parameters in model: " + f"{jasper_decoder.num_weights + jasper_encoder.num_weights}" + ) + logging.info("================================") + + # Train DAG + ( + audio_signal_t, + a_sig_length_t, + transcript_t, + transcript_len_t, + ) = data_layer() + processed_signal_t, 
p_length_t = data_preprocessor( + input_signal=audio_signal_t, length=a_sig_length_t + ) + + if multiply_batch_config: + ( + processed_signal_t, + p_length_t, + transcript_t, + transcript_len_t, + ) = multiply_batch( + in_x=processed_signal_t, + in_x_len=p_length_t, + in_y=transcript_t, + in_y_len=transcript_len_t, + ) + + if spectr_augment_config: + processed_signal_t = data_spectr_augmentation( + input_spec=processed_signal_t + ) + + encoded_t, encoded_len_t = jasper_encoder( + audio_signal=processed_signal_t, length=p_length_t + ) + log_probs_t = jasper_decoder(encoder_output=encoded_t) + predictions_t = greedy_decoder(log_probs=log_probs_t) + loss_t = ctc_loss( + log_probs=log_probs_t, + targets=transcript_t, + input_length=encoded_len_t, + target_length=transcript_len_t, + ) + + # Callbacks needed to print info to console and Tensorboard + train_callback = nemo.core.SimpleLossLoggerCallback( + tensors=[loss_t, predictions_t, transcript_t, transcript_len_t], + print_func=partial(monitor_asr_train_progress, labels=vocab), + get_tb_values=lambda x: [("loss", x[0])], + tb_writer=neural_factory.tb_writer, + ) + + chpt_callback = nemo.core.CheckpointCallback( + folder=neural_factory.checkpoint_dir, + load_from_folder=args.load_dir, + step_freq=args.checkpoint_save_freq, + checkpoints_to_keep=30, + ) + + callbacks = [train_callback, chpt_callback] + + # assemble eval DAGs + for i, eval_dl in enumerate(data_layers_eval): + ( + audio_signal_e, + a_sig_length_e, + transcript_e, + transcript_len_e, + ) = eval_dl() + processed_signal_e, p_length_e = data_preprocessor( + input_signal=audio_signal_e, length=a_sig_length_e + ) + encoded_e, encoded_len_e = jasper_encoder( + audio_signal=processed_signal_e, length=p_length_e + ) + log_probs_e = jasper_decoder(encoder_output=encoded_e) + predictions_e = greedy_decoder(log_probs=log_probs_e) + loss_e = ctc_loss( + log_probs=log_probs_e, + targets=transcript_e, + input_length=encoded_len_e, + target_length=transcript_len_e, + ) + + # create corresponding eval callback + tagname = os.path.basename(args.eval_datasets[i]).split(".")[0] + eval_callback = nemo.core.EvaluatorCallback( + eval_tensors=[ + loss_e, + predictions_e, + transcript_e, + transcript_len_e, + ], + user_iter_callback=partial(process_evaluation_batch, labels=vocab), + user_epochs_done_callback=partial( + process_evaluation_epoch, tag=tagname + ), + eval_step=args.eval_freq, + tb_writer=neural_factory.tb_writer, + ) + + callbacks.append(eval_callback) + return loss_t, callbacks, steps_per_epoch + + +def main(): + args = parse_args() + name = construct_name( + args.exp_name, + args.lr, + args.batch_size, + args.max_steps, + args.num_epochs, + args.weight_decay, + args.optimizer, + args.iter_per_step, + ) + log_dir = name + if args.work_dir: + log_dir = os.path.join(args.work_dir, name) + + # instantiate Neural Factory with supported backend + neural_factory = nemo.core.NeuralModuleFactory( + backend=nemo.core.Backend.PyTorch, + local_rank=args.local_rank, + optimization_level=args.amp_opt_level, + log_dir=log_dir, + checkpoint_dir=args.checkpoint_dir, + create_tb_writer=args.create_tb_writer, + files_to_copy=[args.model_config, __file__], + cudnn_benchmark=args.cudnn_benchmark, + tensorboard_dir=args.tensorboard_dir, + ) + args.num_gpus = neural_factory.world_size + + checkpoint_dir = neural_factory.checkpoint_dir + if args.local_rank is not None: + logging.info("Doing ALL GPU") + + # build dags + train_loss, callbacks, steps_per_epoch = create_all_dags( + args, neural_factory + ) + # train 
model + neural_factory.train( + tensors_to_optimize=[train_loss], + callbacks=callbacks, + lr_policy=CosineAnnealing( + args.max_steps + if args.max_steps is not None + else args.num_epochs * steps_per_epoch, + warmup_steps=args.warmup_steps, + ), + optimizer=args.optimizer, + optimization_params={ + "num_epochs": args.num_epochs, + "max_steps": args.max_steps, + "lr": args.lr, + "betas": (args.beta1, args.beta2), + "weight_decay": args.weight_decay, + "grad_norm_clip": None, + }, + batches_per_step=args.iter_per_step, + ) + + +if __name__ == "__main__": + main() diff --git a/src/plume/models/marblenet_nemo/trial.py b/src/plume/models/marblenet_nemo/trial.py new file mode 100644 index 0000000..2a21ddb --- /dev/null +++ b/src/plume/models/marblenet_nemo/trial.py @@ -0,0 +1,22 @@ +import numpy as np +import os +import time +import copy + +from omegaconf import OmegaConf +import matplotlib.pyplot as plt +import IPython.display as ipd +# import pyaudio as pa +import librosa +import nemo +import nemo.collections.asr as nemo_asr + +# sample rate, Hz +SAMPLE_RATE = 16000 + +vad_model = nemo_asr.models.EncDecClassificationModel.from_pretrained( + "vad_marblenet" +) +# Preserve a copy of the full config +cfg = copy.deepcopy(vad_model._cfg) +# print(OmegaConf.to_yaml(cfg)) diff --git a/plume/models/wav2vec2/__init__.py b/src/plume/models/wav2vec2/__init__.py similarity index 100% rename from plume/models/wav2vec2/__init__.py rename to src/plume/models/wav2vec2/__init__.py diff --git a/plume/models/wav2vec2/asr.py b/src/plume/models/wav2vec2/asr.py similarity index 100% rename from plume/models/wav2vec2/asr.py rename to src/plume/models/wav2vec2/asr.py diff --git a/src/plume/models/wav2vec2/data.py b/src/plume/models/wav2vec2/data.py new file mode 100644 index 0000000..f6a3fe9 --- /dev/null +++ b/src/plume/models/wav2vec2/data.py @@ -0,0 +1,234 @@ +from pathlib import Path +from collections import Counter +import shutil +import io + +# from time import time + +# import pydub +import typer +from tqdm import tqdm + +from plume.utils import ( + ExtendedPath, + replace_redundant_spaces_with, + lazy_module, + random_segs, + parallel_apply, + batch, + run_shell, +) + +from plume.utils.vad import VADUtterance + +soundfile = lazy_module("soundfile") +pydub = lazy_module("pydub") +app = typer.Typer() + + +@app.command() +def export_jasper(src_dataset_path: Path, dest_dataset_path: Path, unlink: bool = True): + dict_ltr = dest_dataset_path / Path("dict.ltr.txt") + (dest_dataset_path / Path("wavs")).mkdir(exist_ok=True, parents=True) + tok_counter = Counter() + shutil.copy( + src_dataset_path / Path("test_manifest.json"), + src_dataset_path / Path("valid_manifest.json"), + ) + if unlink: + src_wavs = src_dataset_path / Path("wavs") + for wav_path in tqdm(list(src_wavs.glob("**/*.wav"))): + audio_seg = ( + pydub.AudioSegment.from_wav(wav_path) + .set_frame_rate(16000) + .set_channels(1) + ) + dest_path = dest_dataset_path / Path("wavs") / Path(wav_path.name) + audio_seg.export(dest_path, format="wav") + + for dataset_kind in ["train", "valid"]: + abs_manifest_path = ExtendedPath( + src_dataset_path / Path(f"{dataset_kind}_manifest.json") + ) + manifest_data = list(abs_manifest_path.read_jsonl()) + o_tsv, o_ltr = f"{dataset_kind}.tsv", f"{dataset_kind}.ltr" + out_tsv = dest_dataset_path / Path(o_tsv) + out_ltr = dest_dataset_path / Path(o_ltr) + with out_tsv.open("w") as tsv_f, out_ltr.open("w") as ltr_f: + if unlink: + tsv_f.write(f"{dest_dataset_path}\n") + else: + tsv_f.write(f"{src_dataset_path}\n") + for md 
in manifest_data: + audio_fname = md["audio_filepath"] + pipe_toks = replace_redundant_spaces_with(md["text"], "|").upper() + # pipe_toks = "|".join(re.sub(" ", "", md["text"])) + # pipe_toks = alnum_to_asr_tokens(md["text"]).upper().replace(" ", "|") + tok_counter.update(pipe_toks) + letter_toks = " ".join(pipe_toks) + " |\n" + frame_count = soundfile.info(audio_fname).frames + rel_path = Path(audio_fname).relative_to(src_dataset_path.absolute()) + ltr_f.write(letter_toks) + tsv_f.write(f"{rel_path}\t{frame_count}\n") + with dict_ltr.open("w") as d_f: + for k, v in tok_counter.most_common(): + d_f.write(f"{k} {v}\n") + (src_dataset_path / Path("valid_manifest.json")).unlink() + + +@app.command() +def set_root(dataset_path: Path, root_path: Path): + for dataset_kind in ["train", "valid"]: + data_file = dataset_path / Path(dataset_kind).with_suffix(".tsv") + with data_file.open("r") as df: + lines = df.readlines() + with data_file.open("w") as df: + lines[0] = str(root_path) + "\n" + df.writelines(lines) + + +@app.command() +def convert_audio(log_dir: Path, out_dir: Path): + out_dir.mkdir(exist_ok=True, parents=True) + all_wavs = list((log_dir).glob("**/*.wav")) + name_wav_map = {i.name: i.absolute() for i in all_wavs} + exists_wavs = list((out_dir).glob("**/*.wav")) + rem_wavs = list( + set((i.name for i in all_wavs)) - set((i.name for i in exists_wavs)) + ) + rem_wavs_real = [name_wav_map[i] for i in rem_wavs] + + def resample_audio(i): + dest_wav = out_dir / i.name + if dest_wav.exists(): + return + run_shell(f"ffmpeg -i {i.absolute()} -ac 1 -ar 16000 {dest_wav}", verbose=False) + + parallel_apply(resample_audio, rem_wavs_real, workers=256) + + +@app.command() +def prepare_pretraining( + log_dir: Path, + dataset_path: Path, + format: str = "wav", + method: str = "random", + max_silence: int = 3000, + min_duration: int = 10000, + max_duration: int = 30000, + fixed_duration: int = 30000, + batch_size: int = 100, +): + audio_dir = dataset_path / "audio" + audio_dir.mkdir(exist_ok=True, parents=True) + cache_dir = dataset_path / "cache" + cache_dir.mkdir(exist_ok=True, parents=True) + all_wavs = list((log_dir).glob("**/*.wav")) + if method not in ["vad", "random", "fixed"]: + typer.echo("should be one of random|fixed") + raise typer.Exit() + + def write_seg_arg(arg): + seg, dest_wav = arg + ob = io.BytesIO() + seg.export(ob, format=format) + dest_wav.write_bytes(ob.getvalue()) + ob.close() + + with (dataset_path / "failed.log").open("w") as fl: + vad_utt = VADUtterance( + max_silence=max_silence, + min_utterance=min_duration, + max_utterance=max_duration, + ) + + def vad_process_wav(wav_path): + if (cache_dir / wav_path.stem).exists(): + return [] + try: + aud_seg = pydub.AudioSegment.from_file(wav_path) + except pydub.exceptions.CouldntDecodeError: + fl.write(wav_path.name + "\n") + return [] + full_seg = aud_seg + # segs = random_segs(len(full_seg), min_duration, max_duration) + segs = vad_utt.stream_segments(full_seg) + audio_chunk_paths = [] + if len(full_seg) > min_duration: + for (i, chunk_seg) in enumerate(segs): + dest_wav = audio_dir / (wav_path.stem + f"_{i}.{format}") + if dest_wav.exists(): + continue + audio_chunk_paths.append((chunk_seg, dest_wav)) + (cache_dir / wav_path.stem).touch() + return audio_chunk_paths + + def random_process_wav(wav_path): + if (cache_dir / wav_path.stem).exists(): + return [] + try: + aud_seg = pydub.AudioSegment.from_file(wav_path) + except pydub.exceptions.CouldntDecodeError: + fl.write(wav_path.name + "\n") + return [] + full_seg = aud_seg + segs 
= random_segs(len(full_seg), min_duration, max_duration) + audio_chunk_paths = [] + if len(full_seg) > min_duration: + for (i, (start, end)) in enumerate(segs): + dest_wav = audio_dir / (wav_path.stem + f"_{i}.{format}") + if dest_wav.exists(): + continue + chunk_seg = aud_seg[start:end] + audio_chunk_paths.append((chunk_seg, dest_wav)) + (cache_dir / wav_path.stem).touch() + return audio_chunk_paths + + def fixed_process_wav(wav_path): + if (cache_dir / wav_path.stem).exists(): + return [] + try: + aud_seg = pydub.AudioSegment.from_file(wav_path) + except pydub.exceptions.CouldntDecodeError: + fl.write(wav_path.name + "\n") + return [] + full_seg = aud_seg + audio_chunk_paths = [] + if len(full_seg) > min_duration: + for (i, chunk_seg) in enumerate(full_seg[::fixed_duration]): + dest_wav = audio_dir / (wav_path.stem + f"_{i}.{format}") + if dest_wav.exists() or len(chunk_seg) < min_duration: + continue + audio_chunk_paths.append((chunk_seg, dest_wav)) + (cache_dir / wav_path.stem).touch() + return audio_chunk_paths + + # warmup + pydub.AudioSegment.from_file(all_wavs[0]) + # parallel_apply(process_wav, all_wavs, pool='process') + # parallel_apply(process_wav, all_wavs) + seg_f = ( + vad_process_wav + if method == "vad" + else (random_process_wav if method == "random" else fixed_process_wav) + ) + for wp_batch in tqdm(batch(all_wavs, n=batch_size)): + acp_batch = parallel_apply(seg_f, wp_batch) + # acp_batch = list(map(seg_f, tqdm(wp_batch))) + flat_acp_batch = [sd for acp in acp_batch for sd in acp] + parallel_apply(write_seg_arg, flat_acp_batch) + # for acp in acp_batch: + # for (seg, des) in acp: + # seg.export(des) + # for seg_des in tqdm(flat_acp_batch): + # write_seg_arg(seg_des) + del flat_acp_batch + del acp_batch + + +def main(): + app() + + +if __name__ == "__main__": + main() diff --git a/plume/models/wav2vec2/eval.py b/src/plume/models/wav2vec2/eval.py similarity index 100% rename from plume/models/wav2vec2/eval.py rename to src/plume/models/wav2vec2/eval.py diff --git a/plume/models/wav2vec2/serve.py b/src/plume/models/wav2vec2/serve.py similarity index 100% rename from plume/models/wav2vec2/serve.py rename to src/plume/models/wav2vec2/serve.py diff --git a/plume/models/wav2vec2/train.py b/src/plume/models/wav2vec2/train.py similarity index 100% rename from plume/models/wav2vec2/train.py rename to src/plume/models/wav2vec2/train.py diff --git a/src/plume/models/wav2vec2_transformers/__init__.py b/src/plume/models/wav2vec2_transformers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/plume/models/wav2vec2_transformers/asr.py b/src/plume/models/wav2vec2_transformers/asr.py new file mode 100644 index 0000000..ef7f14f --- /dev/null +++ b/src/plume/models/wav2vec2_transformers/asr.py @@ -0,0 +1,39 @@ +from transformers import Wav2Vec2Tokenizer, Wav2Vec2ForCTC + +# import soundfile as sf +from io import BytesIO +import torch + +from plume.utils import lazy_module + +sf = lazy_module("soundfile") + + +class Wav2Vec2TransformersASR(object): + """docstring for Wav2Vec2TransformersASR.""" + + def __init__(self, ctc_path, w2v_path, target_dict_path): + super(Wav2Vec2TransformersASR, self).__init__() + self.tokenizer = Wav2Vec2Tokenizer.from_pretrained( + "facebook/wav2vec2-large-960h-lv60-self" + ) + self.model = Wav2Vec2ForCTC.from_pretrained( + "facebook/wav2vec2-large-960h-lv60-self" + ) + + def transcribe(self, audio_data): + aud_f = BytesIO(audio_data) + # net_input = {} + speech_data, _ = sf.read(aud_f) + input_values = self.tokenizer( + speech_data, 
return_tensors="pt", padding="longest" + ).input_values # Batch size 1 + + # retrieve logits + logits = self.model(input_values).logits + + # take argmax and decode + predicted_ids = torch.argmax(logits, dim=-1) + + transcription = self.tokenizer.batch_decode(predicted_ids)[0] + return transcription diff --git a/plume/models/wav2vec2/data.py b/src/plume/models/wav2vec2_transformers/data.py similarity index 98% rename from plume/models/wav2vec2/data.py rename to src/plume/models/wav2vec2_transformers/data.py index b79bbef..42e3e33 100644 --- a/plume/models/wav2vec2/data.py +++ b/src/plume/models/wav2vec2_transformers/data.py @@ -2,7 +2,6 @@ from pathlib import Path from collections import Counter import shutil -import soundfile # import pydub import typer from tqdm import tqdm @@ -12,8 +11,8 @@ from plume.utils import ( replace_redundant_spaces_with, lazy_module ) +soundfile = lazy_module('soundfile') pydub = lazy_module('pydub') - app = typer.Typer() diff --git a/src/plume/models/wav2vec2_transformers/eval.py b/src/plume/models/wav2vec2_transformers/eval.py new file mode 100644 index 0000000..4b99501 --- /dev/null +++ b/src/plume/models/wav2vec2_transformers/eval.py @@ -0,0 +1,52 @@ +from pathlib import Path +import typer +from tqdm import tqdm +# import pandas as pd + +from plume.utils import ( + asr_manifest_reader, + discard_except_digits, + replace_digit_symbol, + lazy_module + # run_shell, +) +from ...utils.transcribe import triton_transcribe_grpc_gen, transcribe_rpyc_gen + +pd = lazy_module('pandas') +app = typer.Typer() + + +@app.command() +def manifest(manifest_file: Path, result_file: Path = "results.csv", rpyc: bool = False): + from pydub import AudioSegment + + host = "localhost" + port = 8044 + if rpyc: + transcriber, audio_prep = transcribe_rpyc_gen(host, port) + else: + transcriber, audio_prep = triton_transcribe_grpc_gen(host, port, method='whole') + result_path = manifest_file.parent / result_file + manifest_list = list(asr_manifest_reader(manifest_file)) + + def compute_frame(d): + audio_file = d["audio_path"] + orig_text = d["text"] + orig_num = discard_except_digits(replace_digit_symbol(orig_text)) + aud_seg = AudioSegment.from_file(audio_file) + t_audio = audio_prep(aud_seg) + asr_text = transcriber(t_audio) + asr_num = discard_except_digits(replace_digit_symbol(asr_text)) + return { + "audio_file": audio_file, + "asr_text": asr_text, + "asr_num": asr_num, + "orig_text": orig_text, + "orig_num": orig_num, + "asr_match": orig_num == asr_num, + } + + # df_data = parallel_apply(compute_frame, manifest_list) + df_data = map(compute_frame, tqdm(manifest_list)) + df = pd.DataFrame(df_data) + df.to_csv(result_path) diff --git a/src/plume/models/wav2vec2_transformers/serve.py b/src/plume/models/wav2vec2_transformers/serve.py new file mode 100644 index 0000000..5ea8334 --- /dev/null +++ b/src/plume/models/wav2vec2_transformers/serve.py @@ -0,0 +1,52 @@ +import os +import logging +from pathlib import Path + +# from rpyc.utils.server import ThreadedServer +import typer + +from ...utils.serve import ASRService +from plume.utils import lazy_callable +# from plume.models.wav2vec2_transformers.asr import Wav2Vec2TransformersASR +# from .asr import Wav2Vec2ASR + +ThreadedServer = lazy_callable("rpyc.utils.server.ThreadedServer") +Wav2Vec2TransformersASR = lazy_callable( + "plume.models.wav2vec2_transformers.asr.Wav2Vec2TransformersASR" +) + +app = typer.Typer() + + +@app.command() +def rpyc( + w2v_path: Path = "/path/to/base.pt", + ctc_path: Path = "/path/to/ctc.pt", + 
target_dict_path: Path = "/path/to/dict.ltr.txt", + port: int = int(os.environ.get("ASR_RPYC_PORT", "8044")), +): + w2vasr = Wav2Vec2TransformersASR(ctc_path, w2v_path, target_dict_path) + service = ASRService(w2vasr) + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + ) + logging.info("starting asr server...") + t = ThreadedServer(service, port=port) + t.start() + + +@app.command() +def rpyc_dir(model_dir: Path, port: int = int(os.environ.get("ASR_RPYC_PORT", "8044"))): + ctc_path = model_dir / Path("ctc.pt") + w2v_path = model_dir / Path("base.pt") + target_dict_path = model_dir / Path("dict.ltr.txt") + rpyc(w2v_path, ctc_path, target_dict_path, port) + + +def main(): + app() + + +if __name__ == "__main__": + main() diff --git a/src/plume/models/wav2vec2_transformers/test.py b/src/plume/models/wav2vec2_transformers/test.py new file mode 100644 index 0000000..f86ac75 --- /dev/null +++ b/src/plume/models/wav2vec2_transformers/test.py @@ -0,0 +1,41 @@ +from transformers import Wav2Vec2Tokenizer, Wav2Vec2ForCTC +from datasets import load_dataset +import soundfile as sf +import torch + +# load model and tokenizer +tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-large-960h-lv60-self") +model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60-self") + + +# define function to read in sound file +def map_to_array(batch): + speech, _ = sf.read(batch["file"]) + batch["speech"] = speech + return batch + + +# load dummy dataset and read soundfiles +def main(): + ds = load_dataset( + "patrickvonplaten/librispeech_asr_dummy", "clean", split="validation" + ) + ds = ds.map(map_to_array) + + # tokenize + input_values = tokenizer( + ds["speech"][:2], return_tensors="pt", padding="longest" + ).input_values # Batch size 1 + + # retrieve logits + logits = model(input_values).logits + + # take argmax and decode + predicted_ids = torch.argmax(logits, dim=-1) + + transcription = tokenizer.batch_decode(predicted_ids) + print(transcription) + + +if __name__ == "__main__": + main() diff --git a/src/plume/models/wav2vec2_transformers/train.py b/src/plume/models/wav2vec2_transformers/train.py new file mode 100644 index 0000000..ffbaeca --- /dev/null +++ b/src/plume/models/wav2vec2_transformers/train.py @@ -0,0 +1,34 @@ +import typer +# from fairseq_cli.train import cli_main +import sys +from pathlib import Path +import shlex +from plume.utils import lazy_callable + +cli_main = lazy_callable('fairseq_cli.train.cli_main') + +app = typer.Typer() + + +@app.command() +def local(dataset_path: Path): + args = f'''--distributed-world-size 1 {dataset_path} \ +--save-dir /dataset/wav2vec2/model/wav2vec2_l_num_ctc_v1 --post-process letter --valid-subset \ +valid --no-epoch-checkpoints --best-checkpoint-metric wer --num-workers 4 --max-update 80000 \ +--sentence-avg --task audio_pretraining --arch wav2vec_ctc --w2v-path /dataset/wav2vec2/pretrained/wav2vec_vox_new.pt \ +--labels ltr --apply-mask --mask-selection static --mask-other 0 --mask-length 10 --mask-prob 0.5 --layerdrop 0.1 \ +--mask-channel-selection static --mask-channel-other 0 --mask-channel-length 64 --mask-channel-prob 0.5 \ +--zero-infinity --feature-grad-mult 0.0 --freeze-finetune-updates 10000 --validate-after-updates 10000 \ +--optimizer adam --adam-betas '(0.9, 0.98)' --adam-eps 1e-08 --lr 2e-05 --lr-scheduler tri_stage --warmup-steps 8000 \ +--hold-steps 32000 --decay-steps 40000 --final-lr-scale 0.05 --final-dropout 0.0 --dropout 0.0 \ +--activation-dropout 0.1 
--criterion ctc --attention-dropout 0.0 --max-tokens 1280000 --seed 2337 --log-format json \ +--log-interval 500 --ddp-backend no_c10d --reset-optimizer --normalize +''' + new_args = ['train.py'] + new_args.extend(shlex.split(args)) + sys.argv = new_args + cli_main() + + +if __name__ == "__main__": + cli_main() diff --git a/plume/ui/__init__.py b/src/plume/ui/__init__.py similarity index 66% rename from plume/ui/__init__.py rename to src/plume/ui/__init__.py index 67a5c35..d8ab1e4 100644 --- a/plume/ui/__init__.py +++ b/src/plume/ui/__init__.py @@ -46,9 +46,40 @@ def annotation(data_dir: Path, dump_fname: Path = "ui_dump.json", task_id: str = @app.command() -def preview(manifest_path: Path): +def preview(manifest_path: Path, port: int = 8081): annotation_lit_path = Path(__file__).parent / Path("preview.py") - sys.argv = ["streamlit", "run", str(annotation_lit_path), "--", str(manifest_path)] + sys.argv = [ + "streamlit", + "run", + "--server.port", + str(port), + str(annotation_lit_path), + "--", + str(manifest_path), + ] + sys.exit(stcli.main()) + + +@app.command() +def encrypted_preview(manifest_path: Path, key: str, port: int = 8081): + lit_path = Path(__file__).parent / Path("encrypted_preview.py") + sys.argv = [ + "streamlit", + "run", + "--server.port", + str(port), + str(lit_path), + "--", + str(manifest_path), + str(key), + ] + sys.exit(stcli.main()) + + +@app.command() +def audio(audio_dir: Path): + lit_path = Path(__file__).parent / Path("audio.py") + sys.argv = ["streamlit", "run", str(lit_path), "--", str(audio_dir)] sys.exit(stcli.main()) diff --git a/plume/ui/annotation.py b/src/plume/ui/annotation.py similarity index 95% rename from plume/ui/annotation.py rename to src/plume/ui/annotation.py index bcb883f..04c2a17 100644 --- a/plume/ui/annotation.py +++ b/src/plume/ui/annotation.py @@ -13,9 +13,9 @@ setup_mongo_asr_validation_state(st) @st.cache() def load_ui_data(data_dir: Path, dump_fname: Path): - validation_ui_data_path = data_dir / dump_fname - typer.echo(f"Using validation ui data from {validation_ui_data_path}") - return ExtendedPath(validation_ui_data_path).read_json() + annotation_ui_data_path = data_dir / dump_fname + typer.echo(f"Using annotation ui data from {annotation_ui_data_path}") + return ExtendedPath(annotation_ui_data_path).read_json() def show_key(sample, key, trail=""): diff --git a/src/plume/ui/audio.py b/src/plume/ui/audio.py new file mode 100644 index 0000000..b517a68 --- /dev/null +++ b/src/plume/ui/audio.py @@ -0,0 +1,21 @@ +from pathlib import Path + +import streamlit as st +import typer + +app = typer.Typer() + + +@app.command() +def main(wav_dir: Path): + wav_file = list(wav_dir.glob('**/*.wav'))[0] + st.title("Audio Preview") + print(wav_file.exists()) + st.audio(str(wav_dir / wav_file)) + + +if __name__ == "__main__": + try: + app() + except SystemExit: + pass diff --git a/src/plume/ui/encrypted_preview.py b/src/plume/ui/encrypted_preview.py new file mode 100644 index 0000000..41ba7a8 --- /dev/null +++ b/src/plume/ui/encrypted_preview.py @@ -0,0 +1,46 @@ +from pathlib import Path + +import streamlit as st +import typer +from plume.utils import ExtendedPath, wav_cryptor, text_cryptor +from plume.utils.ui_persist import setup_file_state + +app = typer.Typer() + +setup_file_state(st) + + +@st.cache() +def load_ui_data(validation_ui_data_path: Path): + typer.echo(f"Using validation ui data from {validation_ui_data_path}") + return list(ExtendedPath(validation_ui_data_path).read_jsonl()) + + +@app.command() +def main(manifest: Path, key: str): + wc 
= wav_cryptor(key) + tc = text_cryptor(key) + asr_data = load_ui_data(manifest) + sample_no = st.get_current_cursor() + if len(asr_data) - 1 < sample_no or sample_no < 0: + print("Invalid samplno resetting to 0") + st.update_cursor(0) + sample = asr_data[sample_no] + st.title("ASR Manifest Preview") + gt_text = tc.decrypt_text(sample["text"].encode("utf-8")) + st.markdown(f"{sample_no+1} of {len(asr_data)} : **{gt_text}**") + new_sample = st.number_input( + "Go To Sample:", value=sample_no + 1, min_value=1, max_value=len(asr_data) + ) + if new_sample != sample_no + 1: + st.update_cursor(new_sample - 1) + st.sidebar.markdown(f"Gold Text: **{gt_text}**") + wav = wc.decrypt_wav_path((manifest.parent / Path(sample["audio_filepath"]))) + st.audio(wav) + + +if __name__ == "__main__": + try: + app() + except SystemExit: + pass diff --git a/plume/ui/preview.py b/src/plume/ui/preview.py similarity index 100% rename from plume/ui/preview.py rename to src/plume/ui/preview.py diff --git a/plume/utils/.gitignore b/src/plume/utils/.gitignore similarity index 100% rename from plume/utils/.gitignore rename to src/plume/utils/.gitignore diff --git a/plume/utils/__init__.py b/src/plume/utils/__init__.py similarity index 55% rename from plume/utils/__init__.py rename to src/plume/utils/__init__.py index 46f3094..6c69da8 100644 --- a/plume/utils/__init__.py +++ b/src/plume/utils/__init__.py @@ -4,20 +4,20 @@ import re import json import wave import logging +import subprocess +import shutil +import random from pathlib import Path from functools import partial from uuid import uuid4 -from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor -import subprocess -import shutil from urllib.parse import urlsplit # from .lazy_loader import LazyLoader -from .lazy_import import lazy_callable, lazy_module # from ruamel.yaml import YAML # import boto3 import typer +from tqdm import tqdm # import pymongo # from slugify import slugify @@ -27,14 +27,29 @@ import typer # import librosa.display as audio_display # from natural.date import compress # from num2words import num2words -from tqdm import tqdm -from datetime import timedelta +import datetime +import six # from .transcribe import triton_transcribe_grpc_gen # from .eval import app as eval_app +from .manifest import asr_manifest_writer, manifest_str +from .lazy_import import lazy_callable, lazy_module +from .parallel import parallel_apply +from .extended_path import ExtendedPath from .tts import app as tts_app from .transcribe import app as transcribe_app from .align import app as align_app +from .encrypt import app as encrypt_app, wav_cryptor, text_cryptor # noqa +from .regentity import ( # noqa + num_replacer, + alnum_replacer, + num_keeper, + alnum_keeper, + default_num_rules, + default_num_only_rules, + default_alnum_rules, + entity_replacer_keeper, +) boto3 = lazy_module("boto3") pymongo = lazy_module("pymongo") @@ -45,9 +60,9 @@ librosa = lazy_module("librosa") YAML = lazy_callable("ruamel.yaml.YAML") num2words = lazy_callable("num2words.num2words") slugify = lazy_callable("slugify.slugify") -compress = lazy_callable("natural.date.compress") app = typer.Typer() +app.add_typer(encrypt_app) app.add_typer(tts_app, name="tts") app.add_typer(align_app, name="align") app.add_typer(transcribe_app, name="transcribe") @@ -60,31 +75,164 @@ def utils(): """ -logging.basicConfig( - level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" -) +log_fmt_str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" 
+logging.basicConfig(level=logging.INFO, format=log_fmt_str) logger = logging.getLogger(__name__) -def manifest_str(path, dur, text): - return ( - json.dumps({"audio_filepath": path, "duration": round(dur, 1), "text": text}) - + "\n" - ) +# Precalculated timestamps +TIME_MINUTE = 60 +TIME_HOUR = 3600 +TIME_DAY = 86400 +TIME_WEEK = 604800 -def duration_str(seconds): - return compress(timedelta(seconds=seconds), pad=" ") +def compress(t, show_hours=False, sign=False, pad=""): + """ + Convert the input to compressed format, works with a + :class:`datetime.timedelta` object or a number that represents the number + of seconds you want to compress. If you supply a timestamp or a + :class:`datetime.datetime` object, it will give the delta relative to the + current time. + You can enable showing a sign in front of the compressed format with the + ``sign`` parameter, the default is not to show signs. + Optionally, you can chose to pad the output. If you wish your values to be + separated by spaces, set ``pad`` to ``' '``. + :param t: seconds or :class:`datetime.timedelta` object + :param sign: default ``False`` + :param pad: default ``''`` + >>> print(compress(0)) + 0s + >>> print(compress(1)) + 1s + >>> print(compress(12)) + 12s + >>> print(compress(123)) + 2m3s + >>> print(compress(1234)) + 20m34s + >>> print(compress(12345)) + 3h25m45s + >>> print(compress(123456)) + 1d10h17m36s + ============== + src: https://github.com/tehmaze/natural/blob/master/natural/date.py + """ + + if isinstance(t, datetime.timedelta): + seconds = t.seconds + (t.days * 86400) + elif isinstance(t, six.integer_types + (float,)): + return compress(datetime.timedelta(seconds=t), sign, pad) + else: + raise Exception("Invalid time format") + + parts = [] + if sign: + parts.append("-" if t.days < 0 else "+") + + if not show_hours: + weeks, seconds = divmod(seconds, TIME_WEEK) + days, seconds = divmod(seconds, TIME_DAY) + hours, seconds = divmod(seconds, TIME_HOUR) + minutes, seconds = divmod(seconds, TIME_MINUTE) + + if not show_hours: + if weeks: + parts.append(("%dw") % (weeks,)) + if days: + parts.append(("%dd") % (days,)) + if hours: + parts.append(("%dh") % (hours,)) + if minutes: + parts.append(("%dm") % (minutes,)) + if seconds or len(parts) == 0: + parts.append(("%ds") % (seconds,)) + + return pad.join(parts) -def replace_digit_symbol(w2v_out): - num_int_map = {num2words(i): str(i) for i in range(10)} +def duration_str(seconds, show_hours=False): + t = datetime.timedelta(seconds=seconds) + return compress(t, show_hours=show_hours, pad=" ") + + +def replace_digit_symbol(w2v_out, num_range=10): + def rep_i(i): + return (num2words(i).replace("-", " "), str(i)) + + num_int_map = [rep_i(i) for i in reversed(range(num_range))] out = w2v_out.lower() - for (k, v) in num_int_map.items(): + for (k, v) in num_int_map: out = re.sub(k, v, out) return out +def num_keeper_orig(num_range=10, extra_rules=[]): + num_int_map_ty = [ + ( + r"\b" + num2words(i) + r"\b", + " " + str(i) + " ", + ) + for i in reversed(range(num_range)) + ] + re_rules = [ + (re.compile(k, re.IGNORECASE), v) + for (k, v) in [ + # (r"[ ;,.]", " "), + (r"\bdouble(?: |-)(\w+)\b", "\\1 \\1"), + (r"\btriple(?: |-)(\w+)\b", "\\1 \\1 \\1"), + (r"hundred", "00"), + (r"\boh\b", " 0 "), + (r"\bo\b", " 0 "), + ] + + num_int_map_ty + ] + [(re.compile(k), v) for (k, v) in extra_rules] + + def merge_intervals(intervals): + # https://codereview.stackexchange.com/a/69249 + sorted_by_lower_bound = sorted(intervals, key=lambda tup: tup[0]) + merged = [] + + for higher in 
sorted_by_lower_bound: + if not merged: + merged.append(higher) + else: + lower = merged[-1] + # test for intersection between lower and higher: + # we know via sorting that lower[0] <= higher[0] + if higher[0] <= lower[1]: + upper_bound = max(lower[1], higher[1]) + merged[-1] = ( + lower[0], + upper_bound, + ) # replace by merged interval + else: + merged.append(higher) + return merged + + # merging interval tree for optimal # https://www.geeksforgeeks.org/interval-tree/ + + def keep_numeric_literals(w2v_out): + # out = w2v_out.lower() + out = re.sub(r"[ ;,.]", " ", w2v_out).strip() + # out = " " + out.strip() + " " + # out = re.sub(r"double (\w+)", "\\1 \\1", out) + # out = re.sub(r"triple (\w+)", "\\1 \\1 \\1", out) + num_spans = [] + for (k, v) in re_rules: # [94:]: + matches = k.finditer(out) + for m in matches: + # num_spans.append((k, m.span())) + num_spans.append(m.span()) + # out = re.sub(k, v, out) + merged = merge_intervals(num_spans) + num_ents = len(merged) + keep_out = " ".join((out[s[0] : s[1]] for s in merged)) + return keep_out, num_ents + + return keep_numeric_literals + + def discard_except_digits(inp): return re.sub("[^0-9]", "", inp) @@ -103,6 +251,26 @@ def space_out(text): return letters +def random_segs(total, min_val, max_val): + out_list = [] + rand_total = prev_start = 0 + while True: + if total < rand_total + min_val or total < rand_total: + break + sample = random.randint(min_val, max_val) + if total - rand_total < max_val: + break + if total - rand_total < max_val + min_val: + sample = random.randint(min_val, max_val - min_val) + prev_start = rand_total + if 0 < rand_total + sample - total < max_val: + break + rand_total += sample + out_list.append((prev_start, rand_total)) + out_list.append((rand_total, total)) + return out_list + + def wav_bytes(audio_bytes, frame_rate=24000): wf_b = io.BytesIO() with wave.open(wf_b, mode="w") as wf: @@ -117,17 +285,20 @@ def tscript_uuid_fname(transcript): return str(uuid4()) + "_" + slugify(transcript, max_length=8) -def run_shell(cmd_str, work_dir="."): +def run_shell(cmd_str, work_dir=".", verbose=True): cwd_path = Path(work_dir).absolute() - p = subprocess.Popen( - cmd_str, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - shell=True, - cwd=cwd_path, - ) - for line in p.stdout: - print(line.replace(b"\n", b"").decode("utf-8")) + if verbose: + with subprocess.Popen( + cmd_str, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + shell=True, + cwd=cwd_path, + ) as p: + for line in p.stdout: + print(line.replace(b"\n", b"").decode("utf-8")) + else: + subprocess.run(cmd_str, shell=True, cwd=cwd_path, capture_output=True) def upload_s3(dataset_path, s3_path): @@ -154,7 +325,8 @@ def s3_downloader(): if not download_path.exists(): if verbose: print(f"downloading {s3_uri} to {download_path}") - s3.download_file(s3_uri_p.netloc, s3_uri_p.path[1:], str(download_path)) + dp_s = str(download_path) + s3.download_file(s3_uri_p.netloc, s3_uri_p.path[1:], dp_s) return download_s3 @@ -167,7 +339,8 @@ def asr_data_writer(dataset_dir, asr_data_source, verbose=False): print(f"writing manifest to {asr_manifest}") for transcript, audio_dur, wav_data in asr_data_source: fname = tscript_uuid_fname(transcript) - audio_file = dataset_dir / Path("wavs") / Path(fname).with_suffix(".wav") + wav_fname = Path(fname).with_suffix(".wav") + audio_file = dataset_dir / Path("wavs") / wav_fname audio_file.write_bytes(wav_data) rel_data_path = audio_file.relative_to(dataset_dir) manifest = manifest_str(str(rel_data_path), audio_dur, transcript) @@ 
-211,7 +384,13 @@ def ui_data_generator(dataset_dir, asr_data_source, verbose=False): num_datapoints = 0 data_funcs = [] - for transcript, audio_dur, wav_data, caller_name, aud_seg in asr_data_source: + for ( + transcript, + audio_dur, + wav_data, + caller_name, + aud_seg, + ) in asr_data_source: fname = str(uuid4()) + "_" + slugify(transcript, max_length=8) audio_file = ( dataset_dir / Path("wavs") / Path(fname).with_suffix(".wav") @@ -269,17 +448,6 @@ def asr_manifest_reader(data_manifest_path: Path): yield p -def asr_manifest_writer(asr_manifest_path: Path, manifest_str_source, verbose=False): - with asr_manifest_path.open("w") as mf: - if verbose: - print(f"writing asr manifest to {asr_manifest_path}") - for mani_dict in manifest_str_source: - manifest = manifest_str( - mani_dict["audio_filepath"], mani_dict["duration"], mani_dict["text"] - ) - mf.write(manifest) - - def asr_test_writer(out_file_path: Path, source): def dd_str(dd, idx): path = dd["audio_filepath"] @@ -306,52 +474,6 @@ def batch(iterable, n=1): return [iterable[ndx : min(ndx + n, ls)] for ndx in range(0, ls, n)] -class ExtendedPath(type(Path())): - """docstring for ExtendedPath.""" - - def read_json(self, verbose=False): - if verbose: - print(f"reading json from {self}") - with self.open("r") as jf: - return json.load(jf) - - def read_yaml(self, verbose=False): - yaml = YAML(typ="safe", pure=True) - if verbose: - print(f"reading yaml from {self}") - with self.open("r") as yf: - return yaml.load(yf) - - def read_jsonl(self, verbose=False): - if verbose: - print(f"reading jsonl from {self}") - with self.open("r") as jf: - for ln in jf.readlines(): - yield json.loads(ln) - - def write_json(self, data, verbose=False): - if verbose: - print(f"writing json to {self}") - self.parent.mkdir(parents=True, exist_ok=True) - with self.open("w") as jf: - json.dump(data, jf, indent=2) - - def write_yaml(self, data, verbose=False): - yaml = YAML() - if verbose: - print(f"writing yaml to {self}") - with self.open("w") as yf: - yaml.dump(data, yf) - - def write_jsonl(self, data, verbose=False): - if verbose: - print(f"writing jsonl to {self}") - self.parent.mkdir(parents=True, exist_ok=True) - with self.open("w") as jf: - for d in data: - jf.write(json.dumps(d) + "\n") - - def get_mongo_coll(uri): ud = pymongo.uri_parser.parse_uri(uri) conn = pymongo.MongoClient(uri) @@ -383,37 +505,23 @@ def plot_seg(wav_plot_path, audio_path): fig.savefig(wav_plot_f, format="png", dpi=50) -def parallel_apply(fn, iterable, workers=8, pool="thread"): - if pool == "thread": - with ThreadPoolExecutor(max_workers=workers) as exe: - print(f"parallelly applying {fn}") - return [ - res - for res in tqdm( - exe.map(fn, iterable), position=0, leave=True, total=len(iterable) - ) - ] - elif pool == "process": - with ProcessPoolExecutor(max_workers=workers) as exe: - print(f"parallelly applying {fn}") - return [ - res - for res in tqdm( - exe.map(fn, iterable), position=0, leave=True, total=len(iterable) - ) - ] - else: - raise Exception(f"unsupported pool type - {pool}") - - def generate_filter_map(src_dataset_path, dest_dataset_path, data_file): min_nums = 3 max_duration = 1 * 60 * 60 skip_duration = 1 * 60 * 60 + max_sample_dur = 20 + min_sample_dur = 2 + verbose = True + + src_data_enum = ( + tqdm(list(ExtendedPath(data_file).read_jsonl())) + if verbose + else ExtendedPath(data_file).read_jsonl() + ) def filtered_max_dur(): wav_duration = 0 - for s in ExtendedPath(data_file).read_jsonl(): + for s in src_data_enum: nums = re.sub(" ", "", s["text"]) if len(nums) 
>= min_nums: wav_duration += s["duration"] @@ -428,36 +536,54 @@ def generate_filter_map(src_dataset_path, dest_dataset_path, data_file): def filtered_skip_dur(): wav_duration = 0 - for s in ExtendedPath(data_file).read_jsonl(): + for s in src_data_enum: nums = re.sub(" ", "", s["text"]) if len(nums) >= min_nums: wav_duration += s["duration"] if wav_duration <= skip_duration: continue elif len(nums) >= min_nums: - yield s shutil.copy( src_dataset_path / Path(s["audio_filepath"]), dest_dataset_path / Path(s["audio_filepath"]), ) + yield s typer.echo(f"skipped {duration_str(skip_duration)} of audio") def filtered_blanks(): - blank_count = 0 - for s in ExtendedPath(data_file).read_jsonl(): + blank_count = total_count = 0 + for s in src_data_enum: + total_count += 1 nums = re.sub(" ", "", s["text"]) if nums != "": - blank_count += 1 shutil.copy( src_dataset_path / Path(s["audio_filepath"]), dest_dataset_path / Path(s["audio_filepath"]), ) yield s - typer.echo(f"filtered {blank_count} blank samples") + else: + blank_count += 1 + typer.echo(f"filtered {blank_count} of {total_count} blank samples") + + def filtered_max_sample_dur(): + max_dur_count = 0 + for s in src_data_enum: + wav_duration = s["duration"] + if wav_duration <= max_sample_dur: + shutil.copy( + src_dataset_path / Path(s["audio_filepath"]), + dest_dataset_path / Path(s["audio_filepath"]), + ) + yield s + else: + max_dur_count += 1 + typer.echo( + f"filtered {max_dur_count} samples longer thans {max_sample_dur}s" + ) def filtered_transform_digits(): count = 0 - for s in ExtendedPath(data_file).read_jsonl(): + for s in src_data_enum: count += 1 digit_text = replace_digit_symbol(s["text"]) only_digits = discard_except_digits(digit_text) @@ -472,11 +598,13 @@ def generate_filter_map(src_dataset_path, dest_dataset_path, data_file): def filtered_extract_chars(): count = 0 - for s in ExtendedPath(data_file).read_jsonl(): + for s in src_data_enum: count += 1 no_digits = digits_to_chars(s["text"]).upper() only_chars = re.sub("[^A-Z'\b]", " ", no_digits) - filter_text = replace_redundant_spaces_with(only_chars, " ").strip() + filter_text = replace_redundant_spaces_with( + only_chars, " " + ).strip() shutil.copy( src_dataset_path / Path(s["audio_filepath"]), dest_dataset_path / Path(s["audio_filepath"]), @@ -487,16 +615,54 @@ def generate_filter_map(src_dataset_path, dest_dataset_path, data_file): def filtered_resample(): count = 0 - for s in ExtendedPath(data_file).read_jsonl(): + for s in src_data_enum: count += 1 src_aud = pydub.AudioSegment.from_file( src_dataset_path / Path(s["audio_filepath"]) ) - dst_aud = src_aud.set_channels(1).set_sample_width(1).set_frame_rate(24000) - dst_aud.export(dest_dataset_path / Path(s["audio_filepath"]), format="wav") + dst_aud = ( + src_aud.set_channels(1) + .set_sample_width(1) + .set_frame_rate(24000) + ) + dst_aud.export( + dest_dataset_path / Path(s["audio_filepath"]), format="wav" + ) yield s typer.echo(f"transformed {count} samples") + def filtered_msec_to_sec(): + count = 0 + for s in src_data_enum: + count += 1 + s["duration"] = s["duration"] / 1000 + shutil.copy( + src_dataset_path / Path(s["audio_filepath"]), + dest_dataset_path / Path(s["audio_filepath"]), + ) + yield s + typer.echo(f"transformed {count} samples") + + def filtered_blank_hr_max_dur(): + max_duration = 3 * 60 * 60 + wav_duration = 0 + for s in src_data_enum: + # nums = re.sub(" ", "", s["text"]) + s["text"] = "gAAAAABgq2FR6ajbhMsDmWRQBzX6gIzyAG5sMwFihGeV7E_6eVJqqF78yzmtTJPsJAOJEEXhJ9Z45MrYNgE1sq7VUdsBVGh2cw==" + if ( + 
s["duration"] >= min_sample_dur + and s["duration"] <= max_sample_dur + ): + wav_duration += s["duration"] + shutil.copy( + src_dataset_path / Path(s["audio_filepath"]), + dest_dataset_path / Path(s["audio_filepath"]), + ) + yield s + if wav_duration > max_duration: + break + typer.echo(f"filtered only {duration_str(wav_duration)} of audio") + filter_kind_map = { "max_dur_1hr_min3num": filtered_max_dur, "skip_dur_1hr_min3num": filtered_skip_dur, @@ -504,5 +670,8 @@ def generate_filter_map(src_dataset_path, dest_dataset_path, data_file): "transform_digits": filtered_transform_digits, "extract_chars": filtered_extract_chars, "resample_ulaw24kmono": filtered_resample, + "max_sample_dur": filtered_max_sample_dur, + "msec_to_sec": filtered_msec_to_sec, + "blank_3hr_max_dur": filtered_blank_hr_max_dur, } return filter_kind_map diff --git a/plume/utils/align.py b/src/plume/utils/align.py similarity index 86% rename from plume/utils/align.py rename to src/plume/utils/align.py index 5e937ce..3e17406 100644 --- a/plume/utils/align.py +++ b/src/plume/utils/align.py @@ -1,6 +1,5 @@ from pathlib import Path # from IPython import display -import requests import io import shutil @@ -11,6 +10,7 @@ from .tts import GoogleTTS display = lazy_module('IPython.display') pydub = lazy_module('pydub') +requests = lazy_module('requests') app = typer.Typer() @@ -72,12 +72,12 @@ def gentle_preview( pkg_gentle_dir = Path(__file__).parent / 'gentle_preview' shutil.copytree(str(pkg_gentle_dir), str(gent_preview_dir)) - # ab = audio_path.read_bytes() - # tt = transcript_path.read_text() - # audio, alignment = gentle_aligner(service_uri, ab, tt) - # audio.export(gent_preview_dir / Path("a.wav"), format="wav") - # alignment["status"] = "OK" - # ExtendedPath(gent_preview_dir / Path("status.json")).write_json(alignment) + ab = audio_path.read_bytes() + tt = transcript_path.read_text() + audio, alignment = gentle_aligner(service_uri, ab, tt) + audio.export(gent_preview_dir / Path("a.wav"), format="wav") + alignment["status"] = "OK" + ExtendedPath(gent_preview_dir / Path("status.json")).write_json(alignment) def main(): diff --git a/src/plume/utils/audio.py b/src/plume/utils/audio.py new file mode 100644 index 0000000..416b4a2 --- /dev/null +++ b/src/plume/utils/audio.py @@ -0,0 +1,53 @@ +import sys +from io import BytesIO + +from .lazy_import import lazy_module, lazy_callable + +np = lazy_module("numpy") +pydub = lazy_module("pydub") +lfilter = lazy_callable("scipy.signal.lfilter") +butter = lazy_callable("scipy.signal.butter") +read = lazy_callable("scipy.io.wavfile.read") +write = lazy_callable("scipy.io.wavfile.write") +# from scipy.signal import lfilter, butter +# from scipy.io.wavfile import read, write +# import numpy as np + + +def audio_seg_to_wav_bytes(aud_seg): + b = BytesIO() + aud_seg.export(b, format="wav") + return b.getvalue() + + +def audio_wav_bytes_to_seg(wav_bytes): + b = BytesIO(wav_bytes) + return pydub.AudioSegment.from_file(b) + + +def butter_params(low_freq, high_freq, fs, order=5): + nyq = 0.5 * fs + low = low_freq / nyq + high = high_freq / nyq + b, a = butter(order, [low, high], btype="band") + return b, a + + +def butter_bandpass_filter(data, low_freq, high_freq, fs, order=5): + b, a = butter_params(low_freq, high_freq, fs, order=order) + y = lfilter(b, a, data) + return y + + +if __name__ == "__main__": + fs, audio = read(sys.argv[1]) + import pdb + + pdb.set_trace() + low_freq = 300.0 + high_freq = 4000.0 + filtered_signal = butter_bandpass_filter( + audio, low_freq, high_freq, fs, order=6 + ) + 
fname = sys.argv[1].split(".wav")[0] + "_moded.wav" + write(fname, fs, np.array(filtered_signal, dtype=np.int16)) diff --git a/src/plume/utils/encrypt.py b/src/plume/utils/encrypt.py new file mode 100644 index 0000000..cd84fc3 --- /dev/null +++ b/src/plume/utils/encrypt.py @@ -0,0 +1,188 @@ +from collections import namedtuple +from io import BytesIO +from pathlib import Path + +# from cryptography.fernet import Fernet + +import typer +from tqdm import tqdm + +from . import asr_manifest_writer +from .extended_path import ExtendedPath +from .audio import audio_seg_to_wav_bytes, audio_wav_bytes_to_seg +from .parallel import parallel_apply +from .lazy_import import lazy_module + +cryptography = lazy_module("cryptography") +# cryptography.fernet = lazy_module("cryptography.fernet") +pydub = lazy_module("pydub") + +app = typer.Typer() + + +@app.callback() +def encrypt(): + """ + encrypt sub commands + """ + + +def wav_cryptor(key=""): + WavCryptor = namedtuple( + "WavCryptor", + ( + "keygen", + "encrypt_wav_path_to", + "decrypt_wav_path_to", + "decrypt_wav_path", + ), + ) + _enc_key = key + _crypto_f = cryptography.fernet.Fernet(_enc_key) + + def encrypt_wav_bytes(f, dec_wav_bytes): + b = BytesIO(dec_wav_bytes) + audio_seg = pydub.AudioSegment.from_file(b) + # audio_seg.raw_data + enc_wav_bytes = f.encrypt(audio_seg.raw_data) + encrypted_seg = pydub.AudioSegment( + enc_wav_bytes, + frame_rate=audio_seg.frame_rate, + channels=audio_seg.channels, + sample_width=audio_seg.sample_width, + ) + return audio_seg_to_wav_bytes(encrypted_seg) + + def decrypt_wav_bytes(f, enc_wav_bytes): + b = BytesIO(enc_wav_bytes) + audio_seg = pydub.AudioSegment.from_file(b) + dec_wav_bytes = f.decrypt(audio_seg.raw_data) + decrypted_seg = pydub.AudioSegment( + dec_wav_bytes, + frame_rate=audio_seg.frame_rate, + channels=audio_seg.channels, + sample_width=audio_seg.sample_width, + ) + return audio_seg_to_wav_bytes(decrypted_seg) + + def encrypt_wav_path_to(dec_audio_path: Path, enc_audio_path: Path): + dec_wav_bytes = dec_audio_path.read_bytes() + enc_audio_path.write_bytes(encrypt_wav_bytes(_crypto_f, dec_wav_bytes)) + + def decrypt_wav_path_to(enc_audio_path: Path, dec_audio_path: Path): + enc_wav_bytes = enc_audio_path.read_bytes() + dec_audio_path.write_bytes(decrypt_wav_bytes(_crypto_f, enc_wav_bytes)) + + def decrypt_wav_path(enc_audio_path: Path): + enc_wav_bytes = enc_audio_path.read_bytes() + return decrypt_wav_bytes(_crypto_f, enc_wav_bytes) + + return WavCryptor( + cryptography.fernet.Fernet.generate_key, + encrypt_wav_path_to, + decrypt_wav_path_to, + decrypt_wav_path, + ) + + +def text_cryptor(key=""): + TextCryptor = namedtuple( + "TextCryptor", + ("keygen", "encrypt_text", "decrypt_text"), + ) + _enc_key = key + _crypto_f = cryptography.fernet.Fernet(_enc_key) + + def encrypt_text(text: str): + return _crypto_f.encrypt(text.encode("utf-8")) + + def decrypt_text(text: str): + return _crypto_f.decrypt(text).decode("utf-8") + + return TextCryptor( + cryptography.fernet.Fernet.generate_key, encrypt_text, decrypt_text + ) + + +def encrypted_asr_manifest_reader( + data_manifest_path: Path, encryption_key: str, verbose=True, parallel=True +): + print(f"reading encrypted manifest from {data_manifest_path}") + asr_data = list(ExtendedPath(data_manifest_path).read_jsonl()) + enc_key_bytes = encryption_key.encode("utf-8") + wc = wav_cryptor(enc_key_bytes) + tc = text_cryptor(enc_key_bytes) + + def decrypt_fn(p): + d = { + "audio_seg": audio_wav_bytes_to_seg( + wc.decrypt_wav_path( + data_manifest_path.parent / 
Path(p["audio_filepath"]) + ) + ), + "text": tc.decrypt_text(p["text"].encode("utf-8")), + } + return d + + if parallel: + for d in parallel_apply(decrypt_fn, asr_data, verbose=verbose): + yield d + else: + for p in tqdm.tqdm(asr_data) if verbose else asr_data: + yield decrypt_fn(d) + + +def decrypt_asr_dataset( + src_dataset_dir: Path, + dest_dataset_dir: Path, + encryption_key: str, + verbose=True, + parallel=True, +): + data_manifest_path = src_dataset_dir / "manifest.json" + (dest_dataset_dir / "wavs").mkdir(exist_ok=True, parents=True) + dest_manifest_path = dest_dataset_dir / "manifest.json" + print(f"reading encrypted manifest from {data_manifest_path}") + asr_data = list(ExtendedPath(data_manifest_path).read_jsonl()) + enc_key_bytes = encryption_key.encode("utf-8") + wc = wav_cryptor(enc_key_bytes) + tc = text_cryptor(enc_key_bytes) + + def decrypt_fn(p): + dest_path = dest_dataset_dir / Path(p["audio_filepath"]) + wc.decrypt_wav_path_to( + src_dataset_dir / Path(p["audio_filepath"]), dest_path + ) + d = { + "audio_filepath": dest_path, + "duration": p["duration"], + "text": tc.decrypt_text(p["text"].encode("utf-8")), + } + return d + + def datagen(): + if parallel: + for d in parallel_apply(decrypt_fn, asr_data, verbose=verbose): + yield d + else: + for p in tqdm.tqdm(asr_data) if verbose else asr_data: + yield decrypt_fn(d) + + asr_manifest_writer(dest_manifest_path, datagen) + + +@app.command() +def keygen(): + gen_key = cryptography.fernet.Fernet.generate_key() + typer.echo(f"KEY: {gen_key}") + + +@app.command() +def encrypt_text( + text_to_encrypt: str, + encryption_key: str = typer.Option(..., prompt=True, hide_input=True), +): + enc_key_bytes = encryption_key.encode("utf-8") + tc = text_cryptor(enc_key_bytes) + cryptext = tc.encrypt_text(text_to_encrypt) + typer.echo(cryptext) diff --git a/src/plume/utils/extended_path.py b/src/plume/utils/extended_path.py new file mode 100644 index 0000000..53be355 --- /dev/null +++ b/src/plume/utils/extended_path.py @@ -0,0 +1,56 @@ +from pathlib import Path +import json + +from .lazy_import import lazy_module + +yaml = lazy_module("ruamel.yaml") +pydub = lazy_module("pydub") + + +class ExtendedPath(type(Path())): + """docstring for ExtendedPath.""" + + def read_json(self, verbose=False): + if verbose: + print(f"reading json from {self}") + with self.open("r") as jf: + return json.load(jf) + + def read_yaml(self, verbose=False): + yaml_o = yaml.YAML(typ="safe", pure=True) + if verbose: + print(f"reading yaml from {self}") + with self.open("r") as yf: + return yaml_o.load(yf) + + def read_jsonl(self, verbose=False): + if verbose: + print(f"reading jsonl from {self}") + with self.open("r") as jf: + for ln in jf.readlines(): + yield json.loads(ln) + + def read_audio_segment(self): + return pydub.AudioSegment.from_file(self) + + def write_json(self, data, verbose=False): + if verbose: + print(f"writing json to {self}") + self.parent.mkdir(parents=True, exist_ok=True) + with self.open("w") as jf: + json.dump(data, jf, indent=2) + + def write_yaml(self, data, verbose=False): + yaml_o = yaml.YAML() + if verbose: + print(f"writing yaml to {self}") + with self.open("w") as yf: + yaml_o.dump(data, yf) + + def write_jsonl(self, data, verbose=False): + if verbose: + print(f"writing jsonl to {self}") + self.parent.mkdir(parents=True, exist_ok=True) + with self.open("w") as jf: + for d in data: + jf.write(json.dumps(d) + "\n") diff --git a/plume/utils/gentle_preview/README.md b/src/plume/utils/gentle_preview/README.md similarity index 100% rename from 
plume/utils/gentle_preview/README.md rename to src/plume/utils/gentle_preview/README.md diff --git a/plume/utils/gentle_preview/align.html b/src/plume/utils/gentle_preview/align.html similarity index 100% rename from plume/utils/gentle_preview/align.html rename to src/plume/utils/gentle_preview/align.html diff --git a/plume/utils/gentle_preview/index.html b/src/plume/utils/gentle_preview/index.html similarity index 100% rename from plume/utils/gentle_preview/index.html rename to src/plume/utils/gentle_preview/index.html diff --git a/plume/utils/gentle_preview/preloader.gif b/src/plume/utils/gentle_preview/preloader.gif similarity index 100% rename from plume/utils/gentle_preview/preloader.gif rename to src/plume/utils/gentle_preview/preloader.gif diff --git a/plume/utils/lazy_import.py b/src/plume/utils/lazy_import.py similarity index 98% rename from plume/utils/lazy_import.py rename to src/plume/utils/lazy_import.py index 615f596..6b56dc3 100644 --- a/plume/utils/lazy_import.py +++ b/src/plume/utils/lazy_import.py @@ -82,10 +82,10 @@ except ImportError: # Adding a __spec__ doesn't really help. I'll leave the code here in case # future python implementations start relying on it. -# try: -# from importlib.machinery import ModuleSpec -# except ImportError: -# ModuleSpec = None +try: + from importlib.machinery import ModuleSpec +except ImportError: + ModuleSpec = None import six from six import raise_from @@ -206,8 +206,7 @@ class LazyModule(ModuleType): class LazyCallable(object): - """Class for lazily-loaded callables that triggers module loading on access - """ + """Class for lazily-loaded callables that triggers module loading on access""" def __init__(self, *args): if len(args) != 2: @@ -399,9 +398,8 @@ def _lazy_module(modname, error_strings, lazy_mod_class): # Actual module instantiation mod = sys.modules[modname] = _LazyModule(modname) # No need for __spec__. Maybe in the future. - # if ModuleSpec: - # ModuleType.__setattr__(mod, '__spec__', - # ModuleSpec(modname, None)) + if ModuleSpec: + ModuleType.__setattr__(mod, "__spec__", ModuleSpec(modname, None)) if fullsubmodname: submod = sys.modules[fullsubmodname] ModuleType.__setattr__(mod, submodname, submod) @@ -531,8 +529,7 @@ def _lazy_callable(modname, cname, error_strings, lazy_mod_class, lazy_call_clas def _load_module(module): - """Ensures that a module, and its parents, are properly loaded - """ + """Ensures that a module, and its parents, are properly loaded""" modclass = type(module) # We only take care of our own LazyModule instances if not issubclass(modclass, LazyModule): @@ -623,8 +620,7 @@ _DELETION_DICT = ("_lazy_import_submodules",) def _setdef(argdict, name, defaultvalue): - """Like dict.setdefault but sets the default value also if None is present. - """ + """Like dict.setdefault but sets the default value also if None is present.""" if not name in argdict or argdict[name] is None: argdict[name] = defaultvalue return argdict[name] @@ -645,8 +641,7 @@ def _set_default_errornames(modname, error_strings, call=False): def _caller_name(depth=2, default=""): - """Returns the name of the calling namespace. - """ + """Returns the name of the calling namespace.""" # the presence of sys._getframe might be implementation-dependent. # It isn't that serious if we can't get the caller's name. try: @@ -700,8 +695,7 @@ def _clean_lazy_submod_refs(module): def _reset_lazymodule(module, cls_attrs): - """Resets a module's lazy state from cached data. 
- """ + """Resets a module's lazy state from cached data.""" modclass = type(module) del modclass.__getattribute__ del modclass.__setattr__ diff --git a/plume/utils/lazy_loader.py b/src/plume/utils/lazy_loader.py similarity index 100% rename from plume/utils/lazy_loader.py rename to src/plume/utils/lazy_loader.py diff --git a/src/plume/utils/manifest.py b/src/plume/utils/manifest.py new file mode 100644 index 0000000..ee7a475 --- /dev/null +++ b/src/plume/utils/manifest.py @@ -0,0 +1,68 @@ +from pathlib import Path + +# from tqdm import tqdm +import json + +# from .extended_path import ExtendedPath +# from .parallel import parallel_apply +# from .encrypt import wav_cryptor, text_cryptor + + +def manifest_str(path, dur, text): + k = {"audio_filepath": path, "duration": round(dur, 1), "text": text} + return json.dumps(k) + "\n" + + +def asr_manifest_writer( + asr_manifest_path: Path, manifest_str_source, verbose=False +): + with asr_manifest_path.open("w") as mf: + if verbose: + print(f"writing asr manifest to {asr_manifest_path}") + for mani_dict in manifest_str_source: + manifest = manifest_str( + mani_dict["audio_filepath"], + mani_dict["duration"], + mani_dict["text"], + ) + mf.write(manifest) + + +# +# def decrypt( +# src_dataset_dir: Path, +# dest_dataset_dir: Path, +# encryption_key: str, +# verbose=True, +# parallel=True, +# ): +# data_manifest_path = src_dataset_dir / "manifest.json" +# (dest_dataset_dir / "wavs").mkdir(exist_ok=True, parents=True) +# dest_manifest_path = dest_dataset_dir / "manifest.json" +# print(f"reading encrypted manifest from {data_manifest_path}") +# asr_data = list(ExtendedPath(data_manifest_path).read_jsonl()) +# enc_key_bytes = encryption_key.encode("utf-8") +# wc = wav_cryptor(enc_key_bytes) +# tc = text_cryptor(enc_key_bytes) +# +# def decrypt_fn(p): +# dest_path = dest_dataset_dir / Path(p["audio_filepath"]) +# wc.decrypt_wav_path_to( +# src_dataset_dir / Path(p["audio_filepath"]), dest_path +# ) +# d = { +# "audio_filepath": dest_path, +# "duration": p["duration"], +# "text": tc.decrypt_text(p["text"].encode("utf-8")), +# } +# return d +# +# def datagen(): +# if parallel: +# for d in parallel_apply(decrypt_fn, asr_data, verbose=verbose): +# yield d +# else: +# for p in tqdm.tqdm(asr_data) if verbose else asr_data: +# yield decrypt_fn(d) +# +# asr_manifest_writer(dest_manifest_path, datagen) diff --git a/src/plume/utils/parallel.py b/src/plume/utils/parallel.py new file mode 100644 index 0000000..d125de5 --- /dev/null +++ b/src/plume/utils/parallel.py @@ -0,0 +1,41 @@ +from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor +from tqdm import tqdm + + +def parallel_apply(fn, iterable, workers=8, pool="thread", verbose=True): + # warm-up + fn(iterable[0]) + if pool == "thread": + with ThreadPoolExecutor(max_workers=workers) as exe: + if verbose: + print(f"parallelly applying {fn}") + return [ + res + for res in tqdm( + exe.map(fn, iterable), + position=0, + leave=True, + total=len(iterable), + ) + ] + else: + return [res for res in exe.map(fn, iterable)] + elif pool == "process": + with ProcessPoolExecutor(max_workers=workers) as exe: + if verbose: + print(f"parallelly applying {fn}") + with tqdm(total=len(iterable)) as progress: + futures = [] + for i in iterable: + future = exe.submit(fn, i) + future.add_done_callback(lambda p: progress.update()) + futures.append(future) + results = [] + for future in futures: + result = future.result() + results.append(result) + return result + else: + return [res for res in exe.map(fn, iterable)] + 
else: + raise Exception(f"unsupported pool type - {pool}") diff --git a/src/plume/utils/regentity.py b/src/plume/utils/regentity.py new file mode 100644 index 0000000..029cee6 --- /dev/null +++ b/src/plume/utils/regentity.py @@ -0,0 +1,383 @@ +import re + +from .lazy_import import lazy_callable, lazy_module + +num2words = lazy_callable("num2words.num2words") +spellchecker = lazy_module("spellchecker") +# from num2words import num2words + + +def entity_replacer_keeper( + pre_rules=[], entity_rules=[], post_rules=[], verbose=False +): + # def replacer_keeper_gen(): + pre_rules_c = [(re.compile(k), v) for (k, v) in pre_rules] + entity_rules_c = [ + (re.compile(k, re.IGNORECASE), v) for (k, v) in entity_rules + ] + post_rules_c = [(re.compile(k), v) for (k, v) in post_rules] + + re_rules = pre_rules_c + entity_rules_c + post_rules_c + + def replacer(w2v_out): + out = w2v_out + for (k, v) in re_rules: + orig = out + out = k.sub(v, out) + if verbose: + print(f"rule |{k}|: sub:|{v}| |{orig}|=> |{out}|") + return out + + def merge_intervals(intervals): + # https://codereview.stackexchange.com/a/69249 + sorted_by_lower_bound = sorted(intervals, key=lambda tup: tup[0]) + merged = [] + + for higher in sorted_by_lower_bound: + if not merged: + merged.append(higher) + else: + lower = merged[-1] + # test for intersection between lower and higher: + # we know via sorting that lower[0] <= higher[0] + if higher[0] <= lower[1]: + upper_bound = max(lower[1], higher[1]) + merged[-1] = ( + lower[0], + upper_bound, + ) # replace by merged interval + else: + merged.append(higher) + return merged + + # optimal merging interval tree + # https://www.geeksforgeeks.org/interval-tree/ + + def keep_literals(w2v_out): + # out = re.sub(r"[ ;,.]", " ", w2v_out).strip() + out = w2v_out + for (k, v) in pre_rules_c: + out = k.sub(v, out) + num_spans = [] + if verbose: + print(f"num_rules: {len(entity_rules_c)}") + for (k, v) in entity_rules_c: # [94:]: + matches = k.finditer(out) + for m in matches: + # num_spans.append(m.span()) + # look at space seprated internal entities + (start, end) = m.span() + for s in re.finditer(r"\S+", out[start:end]): + (start_e, end_e) = s.span() + num_spans.append((start_e + start, end_e + start)) + if verbose: + t = out[start_e + start : end_e + start] + print(f"rule |{k}|: sub:|{v}| => |{t}|") + + merged = merge_intervals(num_spans) + num_ents = len(merged) + keep_out = " ".join((out[s[0] : s[1]] for s in merged)) + for (k, v) in post_rules_c: + keep_out = k.sub(v, keep_out) + return keep_out, num_ents + + return replacer, keep_literals + + +def default_num_only_rules(num_range): + entity_rules = ( + [ + ( + r"\b" + num2words(i) + r"\b", + str(i), + ) + for i in reversed(range(num_range)) + ] + + [ + ( + r"\b" + str(i) + r"\b", + str(i), + ) + for i in reversed(range(10)) + ] + + [ + (r"\bhundred\b", "00"), + ] + ) + return entity_rules + + +def default_num_rules(num_range): + entity_rules = default_num_only_rules(num_range) + [ + (r"\boh\b", "0"), + (r"\bo\b", "0"), + (r"\bdouble(?: |-)(\w+|\d+)\b", "\\1 \\1"), + (r"\btriple(?: |-)(\w+|\d+)\b", "\\1 \\1 \\1"), + ] + return entity_rules + + +def infer_num_rules_vocab(num_range): + vocab = [num2words(i) for i in reversed(range(num_range))] + [ + "hundred", + "double", + "triple", + ] + entity_rules = [ + ( + num2words(i), + str(i), + ) + for i in reversed(range(num_range)) + ] + [ + (r"\bhundred\b", "00"), + (r"\boh\b", "0"), + (r"\bo\b", "0"), + (r"\bdouble(?: |-)(\w+|\d+)\b", "\\1 \\1"), + (r"\btriple(?: |-)(\w+|\d+)\b", "\\1 \\1 \\1"), + ] 
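+    # vocab restricts the spell-checker's dictionary to number words; entity_rules then map the corrected words to digits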
+ return entity_rules, vocab + + +def do_tri_verbose_list(): + return [ + num2words(i) for i in list(range(11, 19)) + list(range(20, 100, 10)) + ] + ["hundred"] + + +def default_alnum_rules(num_range, oh_is_zero, i_oh_limit): + oh_is_zero_rules = [ + (r"\boh\b", "0"), + (r"\bo\b", "0"), + ] + + num_list = [num2words(i) for i in reversed(range(num_range))] + al_num_regex = r"|".join(num_list) + r"|[0-9a-z]" + o_i_vars = r"(\[?(?:Oh|O|I)\]?)" + i_oh_limit_rules = [ + (r"\b([a-hj-np-z])\b", "\\1"), + ( + r"\b((?:" + + al_num_regex + + r"|^)\b\s*)(I|O)(\s*\b)(?=" + + al_num_regex + + r"\s+|$)\b", + "\\1[\\2]\\3", + ), + # ( + # r"\b" + o_i_vars + r"(\s+)" + o_i_vars + r"\b", + # "[\\1]\\2[\\3]", + # ), + ( + r"(\s+|^)" + o_i_vars + r"(\s+)\[?" + o_i_vars + r"\]?(\s+|$)", + "\\1[\\2]\\3[\\4]\\5", + ), + ( + r"(\s+|^)\[?" + o_i_vars + r"\]?(\s+)" + o_i_vars + r"(\s+|$)", + "\\1[\\2]\\3[\\4]\\5", + ), + ] + entity_rules = ( + default_num_only_rules(num_range) + + (oh_is_zero_rules if oh_is_zero else [(r"\boh\b", "o")]) + + [ + (r"\bdouble(?: |-)(\w+|\d+)\b", "\\1 \\1"), + (r"\btriple(?: |-)(\w+|\d+)\b", "\\1 \\1 \\1"), + # (r"\b([a-zA-Z])\b", "\\1"), + ] + + (i_oh_limit_rules if i_oh_limit else [(r"\b([a-zA-Z])\b", "\\1")]) + ) + return entity_rules + + +def num_replacer(num_range=100, condense=True): + entity_rules = default_num_rules(num_range) + post_rules = [(r"[^0-9]", "")] if condense else [] + replacer, keeper = entity_replacer_keeper( + entity_rules=entity_rules, post_rules=post_rules + ) + return replacer + + +def num_keeper(num_range=100): + entity_rules = default_num_rules(num_range) + pre_rules = [(r"[ ;,.]", " ")] + post_rules = [] + replacer, keeper = entity_replacer_keeper( + pre_rules=pre_rules, entity_rules=entity_rules, post_rules=post_rules + ) + return keeper + + +def alnum_replacer( + num_range=100, oh_is_zero=False, i_oh_limit=True, condense=True +): + entity_rules = default_alnum_rules( + num_range, oh_is_zero, i_oh_limit=i_oh_limit + ) + # entity_rules = default_num_rules(num_range) + pre_rules = [ + (r"[ ;,.]", " "), + (r"[']", ""), + # ( + # r"((?:(?<=\w{2,2})|^)\s*)(?:\bI\b|\bi\b|\bOh\b|\boh\b)(\s*(?:\w{2,}|$))", + # "", + # ), + ] + + def upper_case(match_obj): + char_elem = match_obj.group(0) + return char_elem.upper() + + post_rules = ( + ( + ( + [ + (r"(\s|^)(?:o|O|I|i)(\s|$)", "\\1\\2"), + (r"\[(\w)\]", "\\1"), + ] + if i_oh_limit + else [] + ) + + [ + # (r"\b[a-zA-Z]+\'[a-zA-Z]+\b", ""), + (r"\b[a-zA-Z]{2,}\b", ""), + (r"[^a-zA-Z0-9]", ""), + (r"([a-z].*)", upper_case), + ] + ) + if condense + else [] + ) + replacer, keeper = entity_replacer_keeper( + pre_rules=pre_rules, entity_rules=entity_rules, post_rules=post_rules + ) + return replacer + + +def alnum_keeper(num_range=100, oh_is_zero=False): + entity_rules = default_alnum_rules(num_range, oh_is_zero, i_oh_limit=True) + + # def strip_space(match_obj): + # # char_elem = match_obj.group(1) + # return match_obj.group(1).strip() + match_obj.group(2).strip() + + pre_rules = [ + (r"[ ;,.]", " "), + (r"[']", ""), + # ( + # r"((?:(?<=\w{2,2})|^)\s*)(?:\bI\b|\bi\b|\bOh\b|\boh\b)(\s*(?:\w{2,}|$))", + # strip_space, + # ), + ] + + post_rules = [ + # ( + # r"((?:(?<=\w{2,2})|^)\s*)(?:\bI\b|\bi\b|\bOh\b|\boh\b)(\s*(?:\w{2,}|$))", + # strip_space, + # ) + ] + replacer, keeper = entity_replacer_keeper( + pre_rules=pre_rules, entity_rules=entity_rules, post_rules=post_rules + ) + return keeper + + +def num_keeper_orig(num_range=10, extra_rules=[]): + num_int_map_ty = [ + ( + r"\b" + num2words(i) + r"\b", + " " + str(i) + " ", 
+ ) + for i in reversed(range(num_range)) + ] + re_rules = [ + (re.compile(k, re.IGNORECASE), v) + for (k, v) in [ + # (r"[ ;,.]", " "), + (r"\bdouble(?: |-)(\w+)\b", "\\1 \\1"), + (r"\btriple(?: |-)(\w+)\b", "\\1 \\1 \\1"), + (r"hundred", "00"), + (r"\boh\b", " 0 "), + (r"\bo\b", " 0 "), + ] + + num_int_map_ty + ] + [(re.compile(k), v) for (k, v) in extra_rules] + + def merge_intervals(intervals): + # https://codereview.stackexchange.com/a/69249 + sorted_by_lower_bound = sorted(intervals, key=lambda tup: tup[0]) + merged = [] + + for higher in sorted_by_lower_bound: + if not merged: + merged.append(higher) + else: + lower = merged[-1] + # test for intersection between lower and higher: + # we know via sorting that lower[0] <= higher[0] + if higher[0] <= lower[1]: + upper_bound = max(lower[1], higher[1]) + merged[-1] = ( + lower[0], + upper_bound, + ) # replace by merged interval + else: + merged.append(higher) + return merged + + # merging interval tree for optimal # https://www.geeksforgeeks.org/interval-tree/ + + def keep_numeric_literals(w2v_out): + # out = w2v_out.lower() + out = re.sub(r"[ ;,.]", " ", w2v_out).strip() + # out = " " + out.strip() + " " + # out = re.sub(r"double (\w+)", "\\1 \\1", out) + # out = re.sub(r"triple (\w+)", "\\1 \\1 \\1", out) + num_spans = [] + for (k, v) in re_rules: # [94:]: + matches = k.finditer(out) + for m in matches: + # num_spans.append((k, m.span())) + num_spans.append(m.span()) + # out = re.sub(k, v, out) + merged = merge_intervals(num_spans) + num_ents = len(merged) + keep_out = " ".join((out[s[0] : s[1]] for s in merged)) + return keep_out, num_ents + + return keep_numeric_literals + + +def infer_num_replacer(num_range=100, condense=True): + entity_rules, vocab = infer_num_rules_vocab(num_range) + corrector = vocab_corrector_gen(vocab) + post_rules = [(r"[^0-9]", "")] if condense else [] + replacer, keeper = entity_replacer_keeper( + entity_rules=entity_rules, post_rules=post_rules + ) + + def final_replacer(x): + return replacer(corrector(x)) + + return final_replacer + + +def vocab_corrector_gen(vocab): + spell = spellchecker.SpellChecker(distance=1) + words_to_remove = set(spell.word_frequency.words()) - set(vocab) + spell.word_frequency.remove_words(words_to_remove) + + def corrector(inp): + return " ".join( + [spell.correction(tok) for tok in spell.split_words(inp)] + ) + + return corrector + + +if __name__ == "__main__": + repl = infer_num_replacer() + import pdb + + pdb.set_trace() diff --git a/plume/utils/serve.py b/src/plume/utils/serve.py similarity index 86% rename from plume/utils/serve.py rename to src/plume/utils/serve.py index 103d68a..02375a7 100644 --- a/plume/utils/serve.py +++ b/src/plume/utils/serve.py @@ -1,7 +1,7 @@ from plume.utils import lazy_module import typer -rpyc = lazy_module('rpyc') +rpyc = lazy_module("rpyc") app = typer.Typer() @@ -20,7 +20,9 @@ class ASRService(rpyc.Service): # (to finalize the service, if needed) pass - def exposed_transcribe(self, utterance: bytes): # this is an exposed method + def exposed_transcribe( + self, utterance: bytes + ): # this is an exposed method speech_audio = self.asr.transcribe(utterance) return speech_audio diff --git a/plume/utils/st_rerun.py b/src/plume/utils/st_rerun.py similarity index 100% rename from plume/utils/st_rerun.py rename to src/plume/utils/st_rerun.py diff --git a/plume/utils/transcribe.py b/src/plume/utils/transcribe.py similarity index 81% rename from plume/utils/transcribe.py rename to src/plume/utils/transcribe.py index 330177d..8964b6d 100644 --- 
a/plume/utils/transcribe.py +++ b/src/plume/utils/transcribe.py @@ -5,15 +5,16 @@ from pathlib import Path from functools import lru_cache import typer + # import rpyc # from tqdm import tqdm # from pydub.silence import split_on_silence -from plume.utils import lazy_module, lazy_callable +from .lazy_import import lazy_module -rpyc = lazy_module('rpyc') -pydub = lazy_module('pydub') -split_on_silence = lazy_callable('pydub.silence.split_on_silence') +rpyc = lazy_module("rpyc") +pydub = lazy_module("pydub") +np = lazy_module("numpy") app = typer.Typer() @@ -23,7 +24,7 @@ logging.basicConfig( logger = logging.getLogger(__name__) -ASR_RPYC_HOST = os.environ.get("JASR_RPYC_HOST", "localhost") +ASR_RPYC_HOST = os.environ.get("ASR_RPYC_HOST", "localhost") ASR_RPYC_PORT = int(os.environ.get("ASR_RPYC_PORT", "8044")) TRITON_ASR_MODEL = os.environ.get("TRITON_ASR_MODEL", "slu_wav2vec2") @@ -37,13 +38,16 @@ def transcribe_rpyc_gen(asr_host=ASR_RPYC_HOST, asr_port=ASR_RPYC_PORT): logger.info(f"connecting to asr server at {asr_host}:{asr_port}") try: asr = rpyc.connect(asr_host, asr_port).root - logger.info(f"connected to asr server successfully") + logger.info("connected to asr server successfully") except ConnectionRefusedError: raise Exception("env-var JASPER_ASR_RPYC_HOST invalid") def audio_prep(aud_seg): asr_seg = aud_seg.set_channels(1).set_sample_width(2).set_frame_rate(16000) - return asr_seg + af = BytesIO() + asr_seg.export(af, format="wav") + input_audio_bytes = af.getvalue() + return input_audio_bytes return asr.transcribe, audio_prep @@ -58,9 +62,8 @@ def triton_transcribe_grpc_gen( # overlap=False, sep=" ", ): - from tritonclient.utils import np_to_triton_dtype + from tritonclient.utils import np_to_triton_dtype, InferenceServerException import tritonclient.grpc as grpcclient - import numpy as np sup_meth = ["chunked", "silence", "whole"] if method not in sup_meth: @@ -83,13 +86,18 @@ def triton_transcribe_grpc_gen( ] inputs[0].set_data_from_numpy(input_audio_data) outputs = [grpcclient.InferRequestedOutput("OUTPUT_TEXT")] - response = client.infer(asr_model, inputs, request_id=str(1), outputs=outputs) - transcript = response.as_numpy("OUTPUT_TEXT")[0] + try: + response = client.infer( + asr_model, inputs, request_id=str(1), outputs=outputs + ) + transcript = response.as_numpy("OUTPUT_TEXT")[0] + except InferenceServerException: + transcript = b"[server error]" return transcript.decode("utf-8") def chunked_transcriber(aud_seg): if method == "silence": - sil_chunks = split_on_silence( + sil_chunks = pydub.silence.split_on_silence( aud_seg, min_silence_len=sil_msec, silence_thresh=-50, @@ -122,9 +130,14 @@ def triton_transcribe_grpc_gen( @app.command() -def file(audio_file: Path, write_file: bool = False, chunked=True): +def file( + audio_file: Path, write_file: bool = False, chunked: bool = True, rpyc: bool = False, model: str = "slu_wav2vec2" +): aseg = pydub.AudioSegment.from_file(audio_file) - transcriber, prep = triton_transcribe_grpc_gen() + if rpyc: + transcriber, prep = transcribe_rpyc_gen() + else: + transcriber, prep = triton_transcribe_grpc_gen(asr_model=model) transcription = transcriber(prep(aseg)) typer.echo(transcription) diff --git a/plume/utils/tts.py b/src/plume/utils/tts.py similarity index 100% rename from plume/utils/tts.py rename to src/plume/utils/tts.py diff --git a/plume/utils/ui_persist.py b/src/plume/utils/ui_persist.py similarity index 100% rename from plume/utils/ui_persist.py rename to src/plume/utils/ui_persist.py diff --git a/src/plume/utils/vad.py
b/src/plume/utils/vad.py new file mode 100644 index 0000000..f93f018 --- /dev/null +++ b/src/plume/utils/vad.py @@ -0,0 +1,134 @@ +import logging +from .lazy_import import lazy_module + +webrtcvad = lazy_module("webrtcvad") +pydub = lazy_module("pydub") + +DEFAULT_CHUNK_DUR = 30 +logger = logging.getLogger(__name__) + + +def is_frame_voice(vad, seg, chunk_dur): + return ( + True + if ( + seg.duration_seconds == chunk_dur / 1000 + and vad.is_speech(seg.raw_data, seg.frame_rate) + ) + else False + ) + + +class VADUtterance(object): + """docstring for VADUtterance.""" + + def __init__( + self, + max_silence=500, + min_utterance=280, + max_utterance=20000, + chunk_dur=DEFAULT_CHUNK_DUR, + start_cycles=3, + aggression=1, + ): + super(VADUtterance, self).__init__() + self.vad = webrtcvad.Vad(aggression) + self.chunk_dur = chunk_dur + # duration in millisecs + self.max_sil = max_silence + self.min_utt = min_utterance + self.max_utt = max_utterance + self.speech_start = start_cycles * chunk_dur + + def __repr__(self): + return f"VAD(max_silence={self.max_sil},min_utterance:{self.min_utt},max_utterance:{self.max_utt})" + + def stream_segments(self, audio_seg): + stream_seg = audio_seg.set_channels(1).set_sample_width(2).set_frame_rate(16000) + silence_buffer = pydub.AudioSegment.empty() + voice_buffer = pydub.AudioSegment.empty() + silence_threshold = False + for c in stream_seg[:: self.chunk_dur]: + voice_frame = is_frame_voice(self.vad, c, self.chunk_dur) + # logger.info(f"is audio stream voice? {voice_frame}") + if voice_frame: + silence_threshold = False + voice_buffer += c + silence_buffer = pydub.AudioSegment.empty() + else: + silence_buffer += c + voc_dur = len(voice_buffer) + sil_dur = len(silence_buffer) + + if voc_dur >= self.max_utt: + # logger.info( + # f"detected voice overflow: voice duration {voice_buffer.duration_seconds}" + # ) + yield voice_buffer + voice_buffer = pydub.AudioSegment.empty() + + if sil_dur >= self.max_sil: + if voc_dur >= self.min_utt: + # logger.info( + # f"detected silence: voice duration {voice_buffer.duration_seconds}" + # ) + yield voice_buffer + voice_buffer = pydub.AudioSegment.empty() + # ignore/clear voice if silence reached threshold or indent the statement + if not silence_threshold: + silence_threshold = True + + # if voice_buffer: + # yield voice_buffer + + if self.min_utt < len(voice_buffer) < self.max_utt: + yield voice_buffer + + # def stream_utterance(self, audio_stream): + # silence_buffer = pydub.AudioSegment.empty() + # voice_buffer = pydub.AudioSegment.empty() + # silence_threshold = False + # for avf in audio_stream: + # audio_bytes = avf.to_ndarray().tobytes() + # c = ( + # pydub.AudioSegment( + # data=audio_bytes, + # frame_rate=avf.sample_rate, + # channels=len(avf.layout.channels), + # sample_width=avf.format.bytes, + # ) + # .set_channels(1) + # .set_sample_width(2) + # .set_frame_rate(16000) + # ) + # voice_frame = is_frame_voice(self.vad, c, self.chunk_dur) + # # logger.info(f"is audio stream voice? 
{voice_frame}") + # if voice_frame: + # silence_threshold = False + # voice_buffer += c + # silence_buffer = pydub.AudioSegment.empty() + # else: + # silence_buffer += c + # voc_dur = voice_buffer.duration_seconds * 1000 + # sil_dur = silence_buffer.duration_seconds * 1000 + # + # if voc_dur >= self.max_utt: + # # logger.info( + # # f"detected voice overflow: voice duration {voice_buffer.duration_seconds}" + # # ) + # yield voice_buffer + # voice_buffer = pydub.AudioSegment.empty() + # + # if sil_dur >= self.max_sil: + # if voc_dur >= self.min_utt: + # # logger.info( + # # f"detected silence: voice duration {voice_buffer.duration_seconds}" + # # ) + # yield voice_buffer + # voice_buffer = pydub.AudioSegment.empty() + # # ignore/clear voice if silence reached threshold or indent the statement + # if not silence_threshold: + # silence_threshold = True + # + # if voice_buffer: + # yield voice_buffer diff --git a/tests/plume/test_entity_replacer_standalone.py b/tests/plume/test_entity_replacer_standalone.py new file mode 100644 index 0000000..7ba3caa --- /dev/null +++ b/tests/plume/test_entity_replacer_standalone.py @@ -0,0 +1,317 @@ +import re + + +def entity_replacer_keeper(pre_rules=[], entity_rules=[], post_rules=[]): + # def replacer_keeper_gen(): + pre_rules_c = [(re.compile(k), v) for (k, v) in pre_rules] + entity_rules_c = [(re.compile(k, re.IGNORECASE), v) for (k, v) in entity_rules] + post_rules_c = [(re.compile(k), v) for (k, v) in post_rules] + + re_rules = pre_rules_c + entity_rules_c + post_rules_c + + def replacer(w2v_out): + out = w2v_out + for (k, v) in re_rules: + out = k.sub(v, out) + return out + + def merge_intervals(intervals): + # https://codereview.stackexchange.com/a/69249 + sorted_by_lower_bound = sorted(intervals, key=lambda tup: tup[0]) + merged = [] + + for higher in sorted_by_lower_bound: + if not merged: + merged.append(higher) + else: + lower = merged[-1] + # test for intersection between lower and higher: + # we know via sorting that lower[0] <= higher[0] + if higher[0] <= lower[1]: + upper_bound = max(lower[1], higher[1]) + merged[-1] = ( + lower[0], + upper_bound, + ) # replace by merged interval + else: + merged.append(higher) + return merged + + # merging interval tree for optimal # https://www.geeksforgeeks.org/interval-tree/ + + def keep_literals(w2v_out): + # out = re.sub(r"[ ;,.]", " ", w2v_out).strip() + out = w2v_out + for (k, v) in pre_rules_c: + out = k.sub(v, out) + num_spans = [] + for (k, v) in entity_rules_c: # [94:]: + matches = k.finditer(out) + for m in matches: + # num_spans.append((k, m.span())) + num_spans.append(m.span()) + # out = re.sub(k, v, out) + merged = merge_intervals(num_spans) + num_ents = len(merged) + keep_out = " ".join((out[s[0] : s[1]] for s in merged)) + for (k, v) in post_rules_c: + keep_out = k.sub(v, keep_out) + return keep_out, num_ents + + return replacer, keep_literals + + +def default_num_only_rules(num_range): + entity_rules = ( + [ + ("\\bninety-nine\\b", "99"), + ("\\bninety-eight\\b", "98"), + ("\\bninety-seven\\b", "97"), + ("\\bninety-six\\b", "96"), + ("\\bninety-five\\b", "95"), + ("\\bninety-four\\b", "94"), + ("\\bninety-three\\b", "93"), + ("\\bninety-two\\b", "92"), + ("\\bninety-one\\b", "91"), + ("\\bninety\\b", "90"), + ("\\beighty-nine\\b", "89"), + ("\\beighty-eight\\b", "88"), + ("\\beighty-seven\\b", "87"), + ("\\beighty-six\\b", "86"), + ("\\beighty-five\\b", "85"), + ("\\beighty-four\\b", "84"), + ("\\beighty-three\\b", "83"), + ("\\beighty-two\\b", "82"), + ("\\beighty-one\\b", "81"), + 
("\\beighty\\b", "80"), + ("\\bseventy-nine\\b", "79"), + ("\\bseventy-eight\\b", "78"), + ("\\bseventy-seven\\b", "77"), + ("\\bseventy-six\\b", "76"), + ("\\bseventy-five\\b", "75"), + ("\\bseventy-four\\b", "74"), + ("\\bseventy-three\\b", "73"), + ("\\bseventy-two\\b", "72"), + ("\\bseventy-one\\b", "71"), + ("\\bseventy\\b", "70"), + ("\\bsixty-nine\\b", "69"), + ("\\bsixty-eight\\b", "68"), + ("\\bsixty-seven\\b", "67"), + ("\\bsixty-six\\b", "66"), + ("\\bsixty-five\\b", "65"), + ("\\bsixty-four\\b", "64"), + ("\\bsixty-three\\b", "63"), + ("\\bsixty-two\\b", "62"), + ("\\bsixty-one\\b", "61"), + ("\\bsixty\\b", "60"), + ("\\bfifty-nine\\b", "59"), + ("\\bfifty-eight\\b", "58"), + ("\\bfifty-seven\\b", "57"), + ("\\bfifty-six\\b", "56"), + ("\\bfifty-five\\b", "55"), + ("\\bfifty-four\\b", "54"), + ("\\bfifty-three\\b", "53"), + ("\\bfifty-two\\b", "52"), + ("\\bfifty-one\\b", "51"), + ("\\bfifty\\b", "50"), + ("\\bforty-nine\\b", "49"), + ("\\bforty-eight\\b", "48"), + ("\\bforty-seven\\b", "47"), + ("\\bforty-six\\b", "46"), + ("\\bforty-five\\b", "45"), + ("\\bforty-four\\b", "44"), + ("\\bforty-three\\b", "43"), + ("\\bforty-two\\b", "42"), + ("\\bforty-one\\b", "41"), + ("\\bforty\\b", "40"), + ("\\bthirty-nine\\b", "39"), + ("\\bthirty-eight\\b", "38"), + ("\\bthirty-seven\\b", "37"), + ("\\bthirty-six\\b", "36"), + ("\\bthirty-five\\b", "35"), + ("\\bthirty-four\\b", "34"), + ("\\bthirty-three\\b", "33"), + ("\\bthirty-two\\b", "32"), + ("\\bthirty-one\\b", "31"), + ("\\bthirty\\b", "30"), + ("\\btwenty-nine\\b", "29"), + ("\\btwenty-eight\\b", "28"), + ("\\btwenty-seven\\b", "27"), + ("\\btwenty-six\\b", "26"), + ("\\btwenty-five\\b", "25"), + ("\\btwenty-four\\b", "24"), + ("\\btwenty-three\\b", "23"), + ("\\btwenty-two\\b", "22"), + ("\\btwenty-one\\b", "21"), + ("\\btwenty\\b", "20"), + ("\\bnineteen\\b", "19"), + ("\\beighteen\\b", "18"), + ("\\bseventeen\\b", "17"), + ("\\bsixteen\\b", "16"), + ("\\bfifteen\\b", "15"), + ("\\bfourteen\\b", "14"), + ("\\bthirteen\\b", "13"), + ("\\btwelve\\b", "12"), + ("\\beleven\\b", "11"), + ("\\bten\\b", "10"), + ("\\bnine\\b", "9"), + ("\\beight\\b", "8"), + ("\\bseven\\b", "7"), + ("\\bsix\\b", "6"), + ("\\bfive\\b", "5"), + ("\\bfour\\b", "4"), + ("\\bthree\\b", "3"), + ("\\btwo\\b", "2"), + ("\\bone\\b", "1"), + ("\\bzero\\b", "0"), + ] + + [ + ( + r"\b" + str(i) + r"\b", + str(i), + ) + for i in reversed(range(10)) + ] + + [ + (r"\bhundred\b", "00"), + ] + ) + return entity_rules + + +def default_num_rules(num_range): + entity_rules = default_num_only_rules(num_range) + [ + (r"\boh\b", " 0 "), + (r"\bo\b", " 0 "), + (r"\bdouble(?: |-)(\w+|\d+)\b", "\\1 \\1"), + (r"\btriple(?: |-)(\w+|\d+)\b", "\\1 \\1 \\1"), + ] + return entity_rules + + +def default_alnum_rules(num_range, oh_is_zero): + oh_is_zero_rules = [ + (r"\boh\b", "0"), + (r"\bo\b", "0"), + ] + entity_rules = ( + default_num_only_rules(num_range) + + (oh_is_zero_rules if oh_is_zero else [(r"\boh\b", "o")]) + + [ + (r"\bdouble(?: |-)(\w+|\d+)\b", "\\1 \\1"), + (r"\btriple(?: |-)(\w+|\d+)\b", "\\1 \\1 \\1"), + (r"\b([a-zA-Z])\b", "\\1"), + ] + ) + return entity_rules + + +def num_replacer(num_range=100, condense=True): + entity_rules = default_num_rules(num_range) + post_rules = [(r"[^0-9]", "")] if condense else [] + # post_rules = [] + replacer, keeper = entity_replacer_keeper( + entity_rules=entity_rules, post_rules=post_rules + ) + return replacer + + +def num_keeper(num_range=100): + entity_rules = default_num_rules(num_range) + pre_rules = [(r"[ ;,.]", " ")] + 
post_rules = [] + replacer, keeper = entity_replacer_keeper( + pre_rules=pre_rules, entity_rules=entity_rules, post_rules=post_rules + ) + return keeper + + +def alnum_replacer(num_range=100, oh_is_zero=False, condense=True): + entity_rules = default_alnum_rules(num_range, oh_is_zero) + # entity_rules = default_num_rules(num_range) + pre_rules = [(r"[ ;,.]", " "), (r"[']", "")] + + def upper_case(match_obj): + char_elem = match_obj.group(0) + return char_elem.upper() + + post_rules = ( + [ + # (r"\b[a-zA-Z]+\'[a-zA-Z]+\b", ""), + (r"\b[a-zA-Z]{2,}\b", ""), + (r"[^a-zA-Z0-9]", ""), + (r"([a-z].*)", upper_case), + ] + if condense + else [] + ) + replacer, keeper = entity_replacer_keeper( + pre_rules=pre_rules, entity_rules=entity_rules, post_rules=post_rules + ) + return replacer + + +def alnum_keeper(num_range=100, oh_is_zero=False): + entity_rules = default_alnum_rules(num_range, oh_is_zero) + pre_rules = [(r"[ ;,.]", " "), (r"[']", "")] + post_rules = [] + replacer, keeper = entity_replacer_keeper( + pre_rules=pre_rules, entity_rules=entity_rules, post_rules=post_rules + ) + return keeper + + +def test_num(): + num_extractor = num_replacer() + keeper = num_keeper() + num_only_replacer = num_replacer(condense=False) + assert num_extractor("thirty-two") == "32" + assert num_extractor("not thirty-two fifty-nine") == "3259" + assert num_extractor(" triPle 5 fifty 3") == "555503" + assert num_only_replacer(" triPle 5 fifty 3") == " 5 5 5 50 3" + assert num_extractor("douBle 2 130") == "22130" + assert num_extractor("It is a One fifty eIght 5 fifty ") == "1508550" + assert ( + num_only_replacer(" It is a One fifty eIght 5 fifty ") + == " It is a 1 50 8 5 50 " + ) + assert num_extractor("One fifty-eight 5 oh o fifty") == "15850050" + assert keeper( + "my phone number is One hundred fifty-eight not 5 oh o fifty more" + ) == ("One hundred fifty-eight 5 oh o fifty", 7) + + +def test_alnum(): + extractor_oh = alnum_replacer(oh_is_zero=True) + extractor = alnum_replacer() + keeper = alnum_keeper() + only_replacer = alnum_replacer(condense=False) + assert extractor("I'm thirty-two") == "32" + assert extractor("a thirty-two") == "A32" + assert extractor("not a b thirty-two fifty-nine") == "AB3259" + assert extractor(" triPle 5 fifty 3") == "555503" + assert only_replacer(" triPle 5 fifty 3") == " 5 5 5 50 3" + assert extractor("douBle 2 130") == "22130" + assert extractor("It is a One b fifty eIght A Z 5 fifty ") == "A1B508AZ550" + assert ( + only_replacer(" It's a ; One b fifty eIght A Z 5 fifty ") + == " Its a 1 b 50 8 A Z 5 50 " + ) + assert ( + only_replacer(" I'm is a One b fifty eIght A Z 5 fifty ") + == " Im is a 1 b 50 8 A Z 5 50 " + ) + assert extractor("One Z fifty-eight 5 oh o b fifty") == "1Z585OOB50" + assert extractor_oh("One Z fifty-eight 5 oh o b fifty") == "1Z58500B50" + assert keeper( + "I'll phone number One hundred n fifty-eight not 5 oh o fifty A B more" + ) == ("One hundred n fifty-eight 5 oh o fifty A B", 10) + assert keeper("I'm One hundred n fifty-eight not 5 oh o fifty A B more") == ( + "One hundred n fifty-eight 5 oh o fifty A B", + 10, + ) + + assert keeper("I am One hundred n fifty-eight not 5 oh o fifty A B more") == ( + "I One hundred n fifty-eight 5 oh o fifty A B", + 11, + ) diff --git a/tests/plume/test_utils.py b/tests/plume/test_utils.py new file mode 100644 index 0000000..0af41ee --- /dev/null +++ b/tests/plume/test_utils.py @@ -0,0 +1,105 @@ +from plume.utils import ( + num_replacer, + num_keeper, + alnum_replacer, + alnum_keeper, + random_segs, +) +import numpy 
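+# note: the random fixture below seeds the stdlib and numpy RNGs so these tests are deterministic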
+import random as rand +import pytest + + +def test_num_replacer_keeper(): + num_extractor = num_replacer() + num_only_replacer = num_replacer(condense=False) + assert num_extractor("thirty-two") == "32" + assert num_extractor("not thirty-two fifty-nine") == "3259" + assert num_extractor(" triPle 5 fifty 3") == "555503" + assert num_only_replacer(" triPle 5 fifty 3") == " 5 5 5 50 3" + assert num_extractor("douBle 2 130") == "22130" + assert num_extractor("It is a One fifty eIght 5 fifty ") == "1508550" + assert ( + num_only_replacer(" It is a One fifty eIght 5 fifty ") + == " It is a 1 50 8 5 50 " + ) + assert num_extractor("One fifty-eight 5 oh o fifty") == "15850050" + keeper = num_keeper() + assert keeper( + "my phone number is One hundred fifty-eight not 5 oh o fifty more" + ) == ("One hundred fifty-eight 5 oh o fifty", 7) + + +def test_alnum_replacer(): + extractor_oh = alnum_replacer(oh_is_zero=True) + extractor = alnum_replacer() + only_replacer = alnum_replacer(condense=False) + assert extractor("5 oh i c 3") == "5OIC3" + assert extractor("I am, oh it is 3. I will") == "3" + assert extractor("I oh o 3") == "IOO3" + assert extractor("I will 3 I") == "3I" + assert extractor("I'm thirty-two") == "32" + assert extractor("I am thirty-two") == "32" + assert extractor("I j thirty-two") == "IJ32" + assert extractor("a thirty-two") == "A32" + assert extractor("not a b thirty-two fifty-nine") == "AB3259" + assert extractor(" triPle 5 fifty 3") == "555503" + assert only_replacer(" triPle 5 fifty 3") == " 5 5 5 50 3" + assert extractor("douBle 2 130") == "22130" + assert extractor("It is a One b fifty eIght A Z 5 fifty ") == "A1B508AZ550" + assert ( + only_replacer(" It's a ; One b fifty eIght A Z 5 fifty ") + == " Its a 1 b 50 8 A Z 5 50 " + ) + assert ( + only_replacer(" I'm is a One b fifty eIght A Z 5 fifty ") + == " Im is a 1 b 50 8 A Z 5 50 " + ) + assert extractor("One Z fifty-eight 5 oh o b fifty") == "1Z585OOB50" + assert extractor_oh("One Z fifty-eight 5 oh o b fifty") == "1Z58500B50" + assert ( + extractor("I One hundred n fifty-eight not 5 oh o fifty A B more") + == "I100N585OO50AB" + ) + + +def test_alnum_keeper(): + keeper = alnum_keeper() + assert keeper("I One hundred n fifty-eight not 5 oh o fifty A B more") == ( + "I One hundred n fifty-eight 5 oh o fifty A B", + 11, + ) + assert keeper( + "I'll phone number One hundred n fifty-eight not 5 oh o fifty A B more" + ) == ("One hundred n fifty-eight 5 oh o fifty A B", 10) + assert keeper( + "I'm One hundred n fifty-eight not 5 oh o fifty A B more" + ) == ( + "One hundred n fifty-eight 5 oh o fifty A B", + 10, + ) + + assert keeper( + "I am One hundred n fifty-eight not 5 oh o fifty A B more" + ) == ( + "One hundred n fifty-eight 5 oh o fifty A B", + 10, + ) + + +@pytest.fixture +def random(): + rand.seed(0) + numpy.random.seed(0) + + +def test_random_segs(random): + segs = random_segs(100000, 1000, 3000) + + def segs_comply(segs, min, max): + for (start, end) in segs: + if end - start < min or end - start > max: + return False + return True + + assert segs_comply(segs, 1000, 3000) == True diff --git a/tests/plume/utils/test_regentity.py b/tests/plume/utils/test_regentity.py new file mode 100644 index 0000000..e17a0dd --- /dev/null +++ b/tests/plume/utils/test_regentity.py @@ -0,0 +1,17 @@ +from plume.utils.regentity import infer_num_replacer + + +def test_infer_num(): + repl = infer_num_replacer() + + assert ( + repl( + "SIX NINE TRIPL EIGHT SIX SIX DOULE NINE THREE ZERO TWO SEVENT-ONE" + ) + == "69888669930271" + ) + + assert ( 
+ repl("SIX NINE FSIX EIGHT IGSIX SIX NINE NINE THRE ZERO TWO SEVEN ONE") + == "6968669930271" + ) diff --git a/tox.ini b/tox.ini new file mode 100644 index 0000000..036a9ea --- /dev/null +++ b/tox.ini @@ -0,0 +1,13 @@ +# tox (https://tox.readthedocs.io/) is a tool for running tests +# in multiple virtualenvs. This configuration file will run the +# test suite on all supported python versions. To use it, "pip install tox" +# and then run "tox" from this directory. + +[tox] +envlist = py38 + +[testenv] +deps = + pytest +commands = + pytest