From 846f029cf1918554787d0a586406a9f389b95a8d Mon Sep 17 00:00:00 2001
From: Malar Kannan
Date: Wed, 8 Sep 2021 23:26:13 +0530
Subject: [PATCH] Integrate wav2vec2 transformers

---
 setup.py                                      |  4 +--
 src/plume/models/wav2vec2_transformers/asr.py | 16 ++++-----
 .../models/wav2vec2_transformers/data.py      | 20 +++++++----
 .../models/wav2vec2_transformers/eval.py      | 13 ++++---
 .../models/wav2vec2_transformers/serve.py     | 36 +++++++++++--------
 .../models/wav2vec2_transformers/test.py      |  8 +++--
 .../models/wav2vec2_transformers/train.py     | 24 +++----------
 7 files changed, 63 insertions(+), 58 deletions(-)

diff --git a/setup.py b/setup.py
index 8aa3101..edba725 100644
--- a/setup.py
+++ b/setup.py
@@ -26,7 +26,7 @@ requirements = [
     # "streamlit~=0.61.0",
     # "librosa~=0.7.2",
     # "tritonclient[http]~=2.6.0",
-    "numba~=0.48.0",
+    # "numba~=0.48.0",
 ]
 
 extra_requirements = {
@@ -66,7 +66,7 @@ extra_requirements = {
         "pyspellchecker~=0.6.2",
         "num2words~=0.5.10",
         "pydub~=0.24.0",
-        "pyaudio~=0.2.11"
+        "pyaudio~=0.2.11",
     ],
     "infer_min": [
         "pyspellchecker~=0.6.2",
diff --git a/src/plume/models/wav2vec2_transformers/asr.py b/src/plume/models/wav2vec2_transformers/asr.py
index ef7f14f..c8d9d59 100644
--- a/src/plume/models/wav2vec2_transformers/asr.py
+++ b/src/plume/models/wav2vec2_transformers/asr.py
@@ -1,4 +1,4 @@
-from transformers import Wav2Vec2Tokenizer, Wav2Vec2ForCTC
+from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
 
 # import soundfile as sf
 from io import BytesIO
@@ -12,20 +12,16 @@ sf = lazy_module("soundfile")
 class Wav2Vec2TransformersASR(object):
     """docstring for Wav2Vec2TransformersASR."""
 
-    def __init__(self, ctc_path, w2v_path, target_dict_path):
+    def __init__(self, model_dir):
         super(Wav2Vec2TransformersASR, self).__init__()
-        self.tokenizer = Wav2Vec2Tokenizer.from_pretrained(
-            "facebook/wav2vec2-large-960h-lv60-self"
-        )
-        self.model = Wav2Vec2ForCTC.from_pretrained(
-            "facebook/wav2vec2-large-960h-lv60-self"
-        )
+        self.processor = Wav2Vec2Processor.from_pretrained(model_dir)
+        self.model = Wav2Vec2ForCTC.from_pretrained(model_dir)
 
     def transcribe(self, audio_data):
         aud_f = BytesIO(audio_data)
         # net_input = {}
         speech_data, _ = sf.read(aud_f)
-        input_values = self.tokenizer(
+        input_values = self.processor(
             speech_data, return_tensors="pt", padding="longest"
         ).input_values  # Batch size 1
@@ -35,5 +31,5 @@ class Wav2Vec2TransformersASR(object):
 
         # take argmax and decode
         predicted_ids = torch.argmax(logits, dim=-1)
-        transcription = self.tokenizer.batch_decode(predicted_ids)[0]
+        transcription = self.processor.batch_decode(predicted_ids)[0]
         return transcription
diff --git a/src/plume/models/wav2vec2_transformers/data.py b/src/plume/models/wav2vec2_transformers/data.py
index 42e3e33..f4d82be 100644
--- a/src/plume/models/wav2vec2_transformers/data.py
+++ b/src/plume/models/wav2vec2_transformers/data.py
@@ -9,15 +9,18 @@ from tqdm import tqdm
 from plume.utils import (
     ExtendedPath,
     replace_redundant_spaces_with,
-    lazy_module
+    lazy_module,
 )
+
 soundfile = lazy_module("soundfile")
 pydub = lazy_module("pydub")
 
 app = typer.Typer()
 
 @app.command()
-def export_jasper(src_dataset_path: Path, dest_dataset_path: Path, unlink: bool = True):
+def export_jasper(
+    src_dataset_path: Path, dest_dataset_path: Path, unlink: bool = True
+):
     dict_ltr = dest_dataset_path / Path("dict.ltr.txt")
     (dest_dataset_path / Path("wavs")).mkdir(exist_ok=True, parents=True)
     tok_counter = Counter()
@@ -51,13 +54,16 @@ def export_jasper(src_dataset_path: Path, dest_dataset_path: Path, unlink: bool
         tsv_f.write(f"{src_dataset_path}\n")
         for md in manifest_data:
             audio_fname = md["audio_filepath"]
-            pipe_toks = replace_redundant_spaces_with(md["text"], "|").upper()
+            pipe_toks = replace_redundant_spaces_with(
+                md["text"], "|"
+            ).upper()
             # pipe_toks = "|".join(re.sub(" ", "", md["text"]))
-            # pipe_toks = alnum_to_asr_tokens(md["text"]).upper().replace(" ", "|")
             tok_counter.update(pipe_toks)
             letter_toks = " ".join(pipe_toks) + " |\n"
             frame_count = soundfile.info(audio_fname).frames
-            rel_path = Path(audio_fname).relative_to(src_dataset_path.absolute())
+            rel_path = Path(audio_fname).relative_to(
+                src_dataset_path.absolute()
+            )
             ltr_f.write(letter_toks)
             tsv_f.write(f"{rel_path}\t{frame_count}\n")
     with dict_ltr.open("w") as d_f:
diff --git a/src/plume/models/wav2vec2_transformers/eval.py b/src/plume/models/wav2vec2_transformers/eval.py
index 4b99501..aebab6c 100644
--- a/src/plume/models/wav2vec2_transformers/eval.py
+++ b/src/plume/models/wav2vec2_transformers/eval.py
@@ -1,23 +1,26 @@
 from pathlib import Path
 import typer
 from tqdm import tqdm
+
 # import pandas as pd
 from plume.utils import (
     asr_manifest_reader,
     discard_except_digits,
     replace_digit_symbol,
-    lazy_module
+    lazy_module,
     # run_shell,
 )
 from ...utils.transcribe import triton_transcribe_grpc_gen, transcribe_rpyc_gen
 
 pd = lazy_module("pandas")
 
 app = typer.Typer()
 
 
 @app.command()
-def manifest(manifest_file: Path, result_file: Path = "results.csv", rpyc: bool = False):
+def manifest(
+    manifest_file: Path, result_file: Path = "results.csv", rpyc: bool = False
+):
     from pydub import AudioSegment
 
     host = "localhost"
@@ -25,7 +28,9 @@ def manifest(manifest_file: Path, result_file: Path = "results.csv", rpyc: bool
     if rpyc:
         transcriber, audio_prep = transcribe_rpyc_gen(host, port)
     else:
-        transcriber, audio_prep = triton_transcribe_grpc_gen(host, port, method='whole')
+        transcriber, audio_prep = triton_transcribe_grpc_gen(
+            host, port, method="whole"
+        )
     result_path = manifest_file.parent / result_file
     manifest_list = list(asr_manifest_reader(manifest_file))
diff --git a/src/plume/models/wav2vec2_transformers/serve.py b/src/plume/models/wav2vec2_transformers/serve.py
index 5ea8334..2feed4e 100644
--- a/src/plume/models/wav2vec2_transformers/serve.py
+++ b/src/plume/models/wav2vec2_transformers/serve.py
@@ -7,6 +7,7 @@ import typer
 
 from ...utils.serve import ASRService
 from plume.utils import lazy_callable
+
 # from plume.models.wav2vec2_transformers.asr import Wav2Vec2TransformersASR
 # from .asr import Wav2Vec2ASR
 
@@ -18,14 +19,29 @@ Wav2Vec2TransformersASR = lazy_callable(
 app = typer.Typer()
 
 
+# @app.command()
+# def rpyc(
+#     w2v_path: Path = "/path/to/base.pt",
+#     ctc_path: Path = "/path/to/ctc.pt",
+#     target_dict_path: Path = "/path/to/dict.ltr.txt",
+#     port: int = int(os.environ.get("ASR_RPYC_PORT", "8044")),
+# ):
+#     w2vasr = Wav2Vec2TransformersASR(ctc_path, w2v_path, target_dict_path)
+#     service = ASRService(w2vasr)
+#     logging.basicConfig(
+#         level=logging.INFO,
+#         format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+#     )
+#     logging.info("starting asr server...")
+#     t = ThreadedServer(service, port=port)
+#     t.start()
+
+
 @app.command()
-def rpyc(
-    w2v_path: Path = "/path/to/base.pt",
-    ctc_path: Path = "/path/to/ctc.pt",
-    target_dict_path: Path = "/path/to/dict.ltr.txt",
-    port: int = int(os.environ.get("ASR_RPYC_PORT", "8044")),
+def rpyc_dir(
+    model_dir: Path, port: int = int(os.environ.get("ASR_RPYC_PORT", "8044"))
 ):
-    w2vasr = Wav2Vec2TransformersASR(ctc_path, w2v_path, target_dict_path)
+    w2vasr = Wav2Vec2TransformersASR(model_dir)
     service = ASRService(w2vasr)
     logging.basicConfig(
         level=logging.INFO,
@@ -36,14 +52,6 @@ def rpyc(
     t.start()
 
 
-@app.command()
-def rpyc_dir(model_dir: Path, port: int = int(os.environ.get("ASR_RPYC_PORT", "8044"))):
-    ctc_path = model_dir / Path("ctc.pt")
-    w2v_path = model_dir / Path("base.pt")
-    target_dict_path = model_dir / Path("dict.ltr.txt")
-    rpyc(w2v_path, ctc_path, target_dict_path, port)
-
-
 def main():
     app()
diff --git a/src/plume/models/wav2vec2_transformers/test.py b/src/plume/models/wav2vec2_transformers/test.py
index f86ac75..84c9fa2 100644
--- a/src/plume/models/wav2vec2_transformers/test.py
+++ b/src/plume/models/wav2vec2_transformers/test.py
@@ -4,8 +4,12 @@ import soundfile as sf
 import torch
 
 # load model and tokenizer
-tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-large-960h-lv60-self")
-model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60-self")
+tokenizer = Wav2Vec2Tokenizer.from_pretrained(
+    "facebook/wav2vec2-large-960h-lv60-self"
+)
+model = Wav2Vec2ForCTC.from_pretrained(
+    "facebook/wav2vec2-large-960h-lv60-self"
+)
 
 
 # define function to read in sound file
diff --git a/src/plume/models/wav2vec2_transformers/train.py b/src/plume/models/wav2vec2_transformers/train.py
index ffbaeca..709f907 100644
--- a/src/plume/models/wav2vec2_transformers/train.py
+++ b/src/plume/models/wav2vec2_transformers/train.py
@@ -1,33 +1,19 @@
 import typer
+
 # from fairseq_cli.train import cli_main
-import sys
+# import sys
 from pathlib import Path
-import shlex
+# import shlex
 from plume.utils import lazy_callable
 
 cli_main = lazy_callable("fairseq_cli.train.cli_main")
 
 app = typer.Typer()
 
 
 @app.command()
 def local(dataset_path: Path):
-    args = f'''--distributed-world-size 1 {dataset_path} \
---save-dir /dataset/wav2vec2/model/wav2vec2_l_num_ctc_v1 --post-process letter --valid-subset \
-valid --no-epoch-checkpoints --best-checkpoint-metric wer --num-workers 4 --max-update 80000 \
---sentence-avg --task audio_pretraining --arch wav2vec_ctc --w2v-path /dataset/wav2vec2/pretrained/wav2vec_vox_new.pt \
---labels ltr --apply-mask --mask-selection static --mask-other 0 --mask-length 10 --mask-prob 0.5 --layerdrop 0.1 \
---mask-channel-selection static --mask-channel-other 0 --mask-channel-length 64 --mask-channel-prob 0.5 \
---zero-infinity --feature-grad-mult 0.0 --freeze-finetune-updates 10000 --validate-after-updates 10000 \
---optimizer adam --adam-betas '(0.9, 0.98)' --adam-eps 1e-08 --lr 2e-05 --lr-scheduler tri_stage --warmup-steps 8000 \
---hold-steps 32000 --decay-steps 40000 --final-lr-scale 0.05 --final-dropout 0.0 --dropout 0.0 \
---activation-dropout 0.1 --criterion ctc --attention-dropout 0.0 --max-tokens 1280000 --seed 2337 --log-format json \
---log-interval 500 --ddp-backend no_c10d --reset-optimizer --normalize
-'''
-    new_args = ['train.py']
-    new_args.extend(shlex.split(args))
-    sys.argv = new_args
-    cli_main()
+    pass
 
 
 if __name__ == "__main__":
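
Usage sketch (outside the diff): after this change both the processor and the
CTC model are loaded from one directory, so serving a fine-tuned model only
needs the folder written by save_pretrained() on Wav2Vec2ForCTC and
Wav2Vec2Processor, or a hub id such as facebook/wav2vec2-large-960h-lv60-self.
The directory path below is hypothetical.

    from plume.models.wav2vec2_transformers.asr import Wav2Vec2TransformersASR

    # hypothetical directory holding config.json, pytorch_model.bin,
    # preprocessor_config.json and vocab.json
    asr = Wav2Vec2TransformersASR("/models/wav2vec2_ctc_v1")

    # transcribe() takes raw audio file bytes, decodes them in memory via
    # soundfile, then greedy-decodes the CTC logits
    with open("sample.wav", "rb") as f:
        print(asr.transcribe(f.read()))

The same directory drives the new serve command (Typer exposes rpyc_dir as
"rpyc-dir"), e.g. rpyc-dir /models/wav2vec2_ctc_v1 --port 8044.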