tegra wav2vec2 transformers

tegra
Malar Kannan 2021-09-08 23:26:13 +05:30
parent db51553320
commit 846f029cf1
7 changed files with 63 additions and 58 deletions

View File

@ -26,7 +26,7 @@ requirements = [
# "streamlit~=0.61.0", # "streamlit~=0.61.0",
# "librosa~=0.7.2", # "librosa~=0.7.2",
# "tritonclient[http]~=2.6.0", # "tritonclient[http]~=2.6.0",
"numba~=0.48.0", # "numba~=0.48.0",
] ]
extra_requirements = { extra_requirements = {
@ -66,7 +66,7 @@ extra_requirements = {
"pyspellchecker~=0.6.2", "pyspellchecker~=0.6.2",
"num2words~=0.5.10", "num2words~=0.5.10",
"pydub~=0.24.0", "pydub~=0.24.0",
"pyaudio~=0.2.11" "pyaudio~=0.2.11",
], ],
"infer_min": [ "infer_min": [
"pyspellchecker~=0.6.2", "pyspellchecker~=0.6.2",

View File

@ -1,4 +1,4 @@
from transformers import Wav2Vec2Tokenizer, Wav2Vec2ForCTC from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
# import soundfile as sf # import soundfile as sf
from io import BytesIO from io import BytesIO
@ -12,20 +12,16 @@ sf = lazy_module("soundfile")
class Wav2Vec2TransformersASR(object): class Wav2Vec2TransformersASR(object):
"""docstring for Wav2Vec2TransformersASR.""" """docstring for Wav2Vec2TransformersASR."""
def __init__(self, ctc_path, w2v_path, target_dict_path): def __init__(self, model_dir):
super(Wav2Vec2TransformersASR, self).__init__() super(Wav2Vec2TransformersASR, self).__init__()
self.tokenizer = Wav2Vec2Tokenizer.from_pretrained( self.processor = Wav2Vec2Processor.from_pretrained(model_dir)
"facebook/wav2vec2-large-960h-lv60-self" self.model = Wav2Vec2ForCTC.from_pretrained(model_dir)
)
self.model = Wav2Vec2ForCTC.from_pretrained(
"facebook/wav2vec2-large-960h-lv60-self"
)
def transcribe(self, audio_data): def transcribe(self, audio_data):
aud_f = BytesIO(audio_data) aud_f = BytesIO(audio_data)
# net_input = {} # net_input = {}
speech_data, _ = sf.read(aud_f) speech_data, _ = sf.read(aud_f)
input_values = self.tokenizer( input_values = self.processor(
speech_data, return_tensors="pt", padding="longest" speech_data, return_tensors="pt", padding="longest"
).input_values # Batch size 1 ).input_values # Batch size 1
@ -35,5 +31,5 @@ class Wav2Vec2TransformersASR(object):
# take argmax and decode # take argmax and decode
predicted_ids = torch.argmax(logits, dim=-1) predicted_ids = torch.argmax(logits, dim=-1)
transcription = self.tokenizer.batch_decode(predicted_ids)[0] transcription = self.processor.batch_decode(predicted_ids)[0]
return transcription return transcription

View File

@ -9,15 +9,18 @@ from tqdm import tqdm
from plume.utils import ( from plume.utils import (
ExtendedPath, ExtendedPath,
replace_redundant_spaces_with, replace_redundant_spaces_with,
lazy_module lazy_module,
) )
soundfile = lazy_module('soundfile')
pydub = lazy_module('pydub') soundfile = lazy_module("soundfile")
pydub = lazy_module("pydub")
app = typer.Typer() app = typer.Typer()
@app.command() @app.command()
def export_jasper(src_dataset_path: Path, dest_dataset_path: Path, unlink: bool = True): def export_jasper(
src_dataset_path: Path, dest_dataset_path: Path, unlink: bool = True
):
dict_ltr = dest_dataset_path / Path("dict.ltr.txt") dict_ltr = dest_dataset_path / Path("dict.ltr.txt")
(dest_dataset_path / Path("wavs")).mkdir(exist_ok=True, parents=True) (dest_dataset_path / Path("wavs")).mkdir(exist_ok=True, parents=True)
tok_counter = Counter() tok_counter = Counter()
@ -51,13 +54,16 @@ def export_jasper(src_dataset_path: Path, dest_dataset_path: Path, unlink: bool
tsv_f.write(f"{src_dataset_path}\n") tsv_f.write(f"{src_dataset_path}\n")
for md in manifest_data: for md in manifest_data:
audio_fname = md["audio_filepath"] audio_fname = md["audio_filepath"]
pipe_toks = replace_redundant_spaces_with(md["text"], "|").upper() pipe_toks = replace_redundant_spaces_with(
md["text"], "|"
).upper()
# pipe_toks = "|".join(re.sub(" ", "", md["text"])) # pipe_toks = "|".join(re.sub(" ", "", md["text"]))
# pipe_toks = alnum_to_asr_tokens(md["text"]).upper().replace(" ", "|")
tok_counter.update(pipe_toks) tok_counter.update(pipe_toks)
letter_toks = " ".join(pipe_toks) + " |\n" letter_toks = " ".join(pipe_toks) + " |\n"
frame_count = soundfile.info(audio_fname).frames frame_count = soundfile.info(audio_fname).frames
rel_path = Path(audio_fname).relative_to(src_dataset_path.absolute()) rel_path = Path(audio_fname).relative_to(
src_dataset_path.absolute()
)
ltr_f.write(letter_toks) ltr_f.write(letter_toks)
tsv_f.write(f"{rel_path}\t{frame_count}\n") tsv_f.write(f"{rel_path}\t{frame_count}\n")
with dict_ltr.open("w") as d_f: with dict_ltr.open("w") as d_f:

View File

@ -1,23 +1,26 @@
from pathlib import Path from pathlib import Path
import typer import typer
from tqdm import tqdm from tqdm import tqdm
# import pandas as pd # import pandas as pd
from plume.utils import ( from plume.utils import (
asr_manifest_reader, asr_manifest_reader,
discard_except_digits, discard_except_digits,
replace_digit_symbol, replace_digit_symbol,
lazy_module lazy_module,
# run_shell, # run_shell,
) )
from ...utils.transcribe import triton_transcribe_grpc_gen, transcribe_rpyc_gen from ...utils.transcribe import triton_transcribe_grpc_gen, transcribe_rpyc_gen
pd = lazy_module('pandas') pd = lazy_module("pandas")
app = typer.Typer() app = typer.Typer()
@app.command() @app.command()
def manifest(manifest_file: Path, result_file: Path = "results.csv", rpyc: bool = False): def manifest(
manifest_file: Path, result_file: Path = "results.csv", rpyc: bool = False
):
from pydub import AudioSegment from pydub import AudioSegment
host = "localhost" host = "localhost"
@ -25,7 +28,9 @@ def manifest(manifest_file: Path, result_file: Path = "results.csv", rpyc: bool
if rpyc: if rpyc:
transcriber, audio_prep = transcribe_rpyc_gen(host, port) transcriber, audio_prep = transcribe_rpyc_gen(host, port)
else: else:
transcriber, audio_prep = triton_transcribe_grpc_gen(host, port, method='whole') transcriber, audio_prep = triton_transcribe_grpc_gen(
host, port, method="whole"
)
result_path = manifest_file.parent / result_file result_path = manifest_file.parent / result_file
manifest_list = list(asr_manifest_reader(manifest_file)) manifest_list = list(asr_manifest_reader(manifest_file))

View File

@ -7,6 +7,7 @@ import typer
from ...utils.serve import ASRService from ...utils.serve import ASRService
from plume.utils import lazy_callable from plume.utils import lazy_callable
# from plume.models.wav2vec2_transformers.asr import Wav2Vec2TransformersASR # from plume.models.wav2vec2_transformers.asr import Wav2Vec2TransformersASR
# from .asr import Wav2Vec2ASR # from .asr import Wav2Vec2ASR
@ -18,14 +19,29 @@ Wav2Vec2TransformersASR = lazy_callable(
app = typer.Typer() app = typer.Typer()
# @app.command()
# def rpyc(
# w2v_path: Path = "/path/to/base.pt",
# ctc_path: Path = "/path/to/ctc.pt",
# target_dict_path: Path = "/path/to/dict.ltr.txt",
# port: int = int(os.environ.get("ASR_RPYC_PORT", "8044")),
# ):
# w2vasr = Wav2Vec2TransformersASR(ctc_path, w2v_path, target_dict_path)
# service = ASRService(w2vasr)
# logging.basicConfig(
# level=logging.INFO,
# format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
# )
# logging.info("starting asr server...")
# t = ThreadedServer(service, port=port)
# t.start()
@app.command() @app.command()
def rpyc( def rpyc_dir(
w2v_path: Path = "/path/to/base.pt", model_dir: Path, port: int = int(os.environ.get("ASR_RPYC_PORT", "8044"))
ctc_path: Path = "/path/to/ctc.pt",
target_dict_path: Path = "/path/to/dict.ltr.txt",
port: int = int(os.environ.get("ASR_RPYC_PORT", "8044")),
): ):
w2vasr = Wav2Vec2TransformersASR(ctc_path, w2v_path, target_dict_path) w2vasr = Wav2Vec2TransformersASR(model_dir)
service = ASRService(w2vasr) service = ASRService(w2vasr)
logging.basicConfig( logging.basicConfig(
level=logging.INFO, level=logging.INFO,
@ -36,14 +52,6 @@ def rpyc(
t.start() t.start()
@app.command()
def rpyc_dir(model_dir: Path, port: int = int(os.environ.get("ASR_RPYC_PORT", "8044"))):
ctc_path = model_dir / Path("ctc.pt")
w2v_path = model_dir / Path("base.pt")
target_dict_path = model_dir / Path("dict.ltr.txt")
rpyc(w2v_path, ctc_path, target_dict_path, port)
def main(): def main():
app() app()

View File

@ -4,8 +4,12 @@ import soundfile as sf
import torch import torch
# load model and tokenizer # load model and tokenizer
tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-large-960h-lv60-self") tokenizer = Wav2Vec2Tokenizer.from_pretrained(
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60-self") "facebook/wav2vec2-large-960h-lv60-self"
)
model = Wav2Vec2ForCTC.from_pretrained(
"facebook/wav2vec2-large-960h-lv60-self"
)
# define function to read in sound file # define function to read in sound file

View File

@ -1,33 +1,19 @@
import typer import typer
# from fairseq_cli.train import cli_main # from fairseq_cli.train import cli_main
import sys # import sys
from pathlib import Path from pathlib import Path
import shlex # import shlex
from plume.utils import lazy_callable from plume.utils import lazy_callable
cli_main = lazy_callable('fairseq_cli.train.cli_main') cli_main = lazy_callable("fairseq_cli.train.cli_main")
app = typer.Typer() app = typer.Typer()
@app.command() @app.command()
def local(dataset_path: Path): def local(dataset_path: Path):
args = f'''--distributed-world-size 1 {dataset_path} \ pass
--save-dir /dataset/wav2vec2/model/wav2vec2_l_num_ctc_v1 --post-process letter --valid-subset \
valid --no-epoch-checkpoints --best-checkpoint-metric wer --num-workers 4 --max-update 80000 \
--sentence-avg --task audio_pretraining --arch wav2vec_ctc --w2v-path /dataset/wav2vec2/pretrained/wav2vec_vox_new.pt \
--labels ltr --apply-mask --mask-selection static --mask-other 0 --mask-length 10 --mask-prob 0.5 --layerdrop 0.1 \
--mask-channel-selection static --mask-channel-other 0 --mask-channel-length 64 --mask-channel-prob 0.5 \
--zero-infinity --feature-grad-mult 0.0 --freeze-finetune-updates 10000 --validate-after-updates 10000 \
--optimizer adam --adam-betas '(0.9, 0.98)' --adam-eps 1e-08 --lr 2e-05 --lr-scheduler tri_stage --warmup-steps 8000 \
--hold-steps 32000 --decay-steps 40000 --final-lr-scale 0.05 --final-dropout 0.0 --dropout 0.0 \
--activation-dropout 0.1 --criterion ctc --attention-dropout 0.0 --max-tokens 1280000 --seed 2337 --log-format json \
--log-interval 500 --ddp-backend no_c10d --reset-optimizer --normalize
'''
new_args = ['train.py']
new_args.extend(shlex.split(args))
sys.argv = new_args
cli_main()
if __name__ == "__main__": if __name__ == "__main__":