tegra wav2vec2 transformers

tegra
Malar Kannan 2021-09-08 23:26:13 +05:30
parent db51553320
commit 846f029cf1
7 changed files with 63 additions and 58 deletions

View File

@@ -26,7 +26,7 @@ requirements = [
     # "streamlit~=0.61.0",
     # "librosa~=0.7.2",
     # "tritonclient[http]~=2.6.0",
-    "numba~=0.48.0",
+    # "numba~=0.48.0",
 ]
 extra_requirements = {
@@ -66,7 +66,7 @@ extra_requirements = {
         "pyspellchecker~=0.6.2",
         "num2words~=0.5.10",
         "pydub~=0.24.0",
-        "pyaudio~=0.2.11"
+        "pyaudio~=0.2.11",
     ],
     "infer_min": [
         "pyspellchecker~=0.6.2",

View File

@@ -1,4 +1,4 @@
-from transformers import Wav2Vec2Tokenizer, Wav2Vec2ForCTC
+from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
 
 # import soundfile as sf
 from io import BytesIO
@@ -12,20 +12,16 @@ sf = lazy_module("soundfile")
 class Wav2Vec2TransformersASR(object):
     """docstring for Wav2Vec2TransformersASR."""
 
-    def __init__(self, ctc_path, w2v_path, target_dict_path):
+    def __init__(self, model_dir):
         super(Wav2Vec2TransformersASR, self).__init__()
-        self.tokenizer = Wav2Vec2Tokenizer.from_pretrained(
-            "facebook/wav2vec2-large-960h-lv60-self"
-        )
-        self.model = Wav2Vec2ForCTC.from_pretrained(
-            "facebook/wav2vec2-large-960h-lv60-self"
-        )
+        self.processor = Wav2Vec2Processor.from_pretrained(model_dir)
+        self.model = Wav2Vec2ForCTC.from_pretrained(model_dir)
 
     def transcribe(self, audio_data):
         aud_f = BytesIO(audio_data)
         # net_input = {}
         speech_data, _ = sf.read(aud_f)
-        input_values = self.tokenizer(
+        input_values = self.processor(
             speech_data, return_tensors="pt", padding="longest"
         ).input_values  # Batch size 1
@@ -35,5 +31,5 @@ class Wav2Vec2TransformersASR(object):
 
         # take argmax and decode
         predicted_ids = torch.argmax(logits, dim=-1)
-        transcription = self.tokenizer.batch_decode(predicted_ids)[0]
+        transcription = self.processor.batch_decode(predicted_ids)[0]
         return transcription
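
(For context: a minimal usage sketch of the refactored class. The import path is taken from the commented import in the serve module below; the model directory and WAV file names are placeholders, and wav2vec2 expects 16 kHz mono input.)

from plume.models.wav2vec2_transformers.asr import Wav2Vec2TransformersASR

# "model_dir" is assumed to hold files written by save_pretrained()
# for both Wav2Vec2Processor and Wav2Vec2ForCTC
asr = Wav2Vec2TransformersASR("/models/wav2vec2_en")
with open("sample.wav", "rb") as f:  # placeholder 16 kHz mono WAV
    print(asr.transcribe(f.read()))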

View File

@@ -9,15 +9,18 @@ from tqdm import tqdm
 from plume.utils import (
     ExtendedPath,
     replace_redundant_spaces_with,
-    lazy_module
+    lazy_module,
 )
 
-soundfile = lazy_module('soundfile')
-pydub = lazy_module('pydub')
+soundfile = lazy_module("soundfile")
+pydub = lazy_module("pydub")
 
 app = typer.Typer()
 
 
 @app.command()
-def export_jasper(src_dataset_path: Path, dest_dataset_path: Path, unlink: bool = True):
+def export_jasper(
+    src_dataset_path: Path, dest_dataset_path: Path, unlink: bool = True
+):
     dict_ltr = dest_dataset_path / Path("dict.ltr.txt")
     (dest_dataset_path / Path("wavs")).mkdir(exist_ok=True, parents=True)
     tok_counter = Counter()
@@ -51,13 +54,16 @@ def export_jasper(src_dataset_path: Path, dest_dataset_path: Path, unlink: bool
         tsv_f.write(f"{src_dataset_path}\n")
         for md in manifest_data:
             audio_fname = md["audio_filepath"]
-            pipe_toks = replace_redundant_spaces_with(md["text"], "|").upper()
+            pipe_toks = replace_redundant_spaces_with(
+                md["text"], "|"
+            ).upper()
             # pipe_toks = "|".join(re.sub(" ", "", md["text"]))
             # pipe_toks = alnum_to_asr_tokens(md["text"]).upper().replace(" ", "|")
             tok_counter.update(pipe_toks)
             letter_toks = " ".join(pipe_toks) + " |\n"
             frame_count = soundfile.info(audio_fname).frames
-            rel_path = Path(audio_fname).relative_to(src_dataset_path.absolute())
+            rel_path = Path(audio_fname).relative_to(
+                src_dataset_path.absolute()
+            )
             ltr_f.write(letter_toks)
             tsv_f.write(f"{rel_path}\t{frame_count}\n")
     with dict_ltr.open("w") as d_f:
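
(For reference: a sketch of what export_jasper emits for one manifest entry with text "hello world". The split file names are assumptions; the layout follows the fairseq wav2vec2 fine-tuning format this exporter targets, and the frame count is made up.)

# <split>.tsv — dataset root on the first line, then <relative wav path>\t<frame count>
/data/src_dataset
wavs/utt_0001.wav	54320

# <split>.ltr — space-separated letters, with "|" marking word boundaries
H E L L O | W O R L D |

# dict.ltr.txt — each token with its corpus count (order may vary)
L 3
O 2
| 1
H 1
E 1
W 1
R 1
D 1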

View File

@@ -1,23 +1,26 @@
 from pathlib import Path
 import typer
 from tqdm import tqdm
 
 # import pandas as pd
 from plume.utils import (
     asr_manifest_reader,
     discard_except_digits,
     replace_digit_symbol,
-    lazy_module
+    lazy_module,
     # run_shell,
 )
 from ...utils.transcribe import triton_transcribe_grpc_gen, transcribe_rpyc_gen
 
-pd = lazy_module('pandas')
+pd = lazy_module("pandas")
 
 app = typer.Typer()
 
 
 @app.command()
-def manifest(manifest_file: Path, result_file: Path = "results.csv", rpyc: bool = False):
+def manifest(
+    manifest_file: Path, result_file: Path = "results.csv", rpyc: bool = False
+):
     from pydub import AudioSegment
 
     host = "localhost"
@@ -25,7 +28,9 @@ def manifest(manifest_file: Path, result_file: Path = "results.csv", rpyc: bool
     if rpyc:
         transcriber, audio_prep = transcribe_rpyc_gen(host, port)
     else:
-        transcriber, audio_prep = triton_transcribe_grpc_gen(host, port, method='whole')
+        transcriber, audio_prep = triton_transcribe_grpc_gen(
+            host, port, method="whole"
+        )
     result_path = manifest_file.parent / result_file
     manifest_list = list(asr_manifest_reader(manifest_file))
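
(A hypothetical invocation of this command: the module path is an assumption, but the flags follow from Typer's conventions for the parameters above — result_file becomes --result-file and the boolean becomes --rpyc/--no-rpyc.)

python -m plume.models.wav2vec2_transformers.benchmark manifest \
    /data/test/manifest.json --result-file results.csv --rpyc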

View File

@@ -7,6 +7,7 @@ import typer
 from ...utils.serve import ASRService
 from plume.utils import lazy_callable
 
+# from plume.models.wav2vec2_transformers.asr import Wav2Vec2TransformersASR
 # from .asr import Wav2Vec2ASR
@@ -18,14 +19,29 @@ Wav2Vec2TransformersASR = lazy_callable(
 app = typer.Typer()
 
 
+# @app.command()
+# def rpyc(
+#     w2v_path: Path = "/path/to/base.pt",
+#     ctc_path: Path = "/path/to/ctc.pt",
+#     target_dict_path: Path = "/path/to/dict.ltr.txt",
+#     port: int = int(os.environ.get("ASR_RPYC_PORT", "8044")),
+# ):
+#     w2vasr = Wav2Vec2TransformersASR(ctc_path, w2v_path, target_dict_path)
+#     service = ASRService(w2vasr)
+#     logging.basicConfig(
+#         level=logging.INFO,
+#         format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+#     )
+#     logging.info("starting asr server...")
+#     t = ThreadedServer(service, port=port)
+#     t.start()
+
+
 @app.command()
-def rpyc(
-    w2v_path: Path = "/path/to/base.pt",
-    ctc_path: Path = "/path/to/ctc.pt",
-    target_dict_path: Path = "/path/to/dict.ltr.txt",
-    port: int = int(os.environ.get("ASR_RPYC_PORT", "8044")),
+def rpyc_dir(
+    model_dir: Path, port: int = int(os.environ.get("ASR_RPYC_PORT", "8044"))
 ):
-    w2vasr = Wav2Vec2TransformersASR(ctc_path, w2v_path, target_dict_path)
+    w2vasr = Wav2Vec2TransformersASR(model_dir)
     service = ASRService(w2vasr)
     logging.basicConfig(
         level=logging.INFO,
@@ -36,14 +52,6 @@ def rpyc(
     t.start()
 
 
-@app.command()
-def rpyc_dir(model_dir: Path, port: int = int(os.environ.get("ASR_RPYC_PORT", "8044"))):
-    ctc_path = model_dir / Path("ctc.pt")
-    w2v_path = model_dir / Path("base.pt")
-    target_dict_path = model_dir / Path("dict.ltr.txt")
-    rpyc(w2v_path, ctc_path, target_dict_path, port)
-
-
 def main():
     app()
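
(A hypothetical launch of the new directory-based command — the module path is an assumption; Typer renders the function name as rpyc-dir on the command line, and ASR_RPYC_PORT supplies the default port per the signature above.)

ASR_RPYC_PORT=8044 python -m plume.models.wav2vec2_transformers.serve \
    rpyc-dir /models/wav2vec2_en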

View File

@@ -4,8 +4,12 @@ import soundfile as sf
 import torch
 
 # load model and tokenizer
-tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-large-960h-lv60-self")
-model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60-self")
+tokenizer = Wav2Vec2Tokenizer.from_pretrained(
+    "facebook/wav2vec2-large-960h-lv60-self"
+)
+model = Wav2Vec2ForCTC.from_pretrained(
+    "facebook/wav2vec2-large-960h-lv60-self"
+)
 
 # define function to read in sound file
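
(For completeness, a sketch of the steps this example script typically continues with, following the upstream Hugging Face wav2vec2 usage; the helper name and file path are assumptions.)

def read_audio(path):
    # soundfile returns the samples and the sample rate; wav2vec2 expects 16 kHz mono
    speech, _ = sf.read(path)
    return speech

input_values = tokenizer(read_audio("sample.wav"), return_tensors="pt").input_values
logits = model(input_values).logits
predicted_ids = torch.argmax(logits, dim=-1)
transcription = tokenizer.batch_decode(predicted_ids)[0]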

View File

@@ -1,33 +1,19 @@
 import typer
 
 # from fairseq_cli.train import cli_main
-import sys
+# import sys
 from pathlib import Path
-import shlex
+# import shlex
 from plume.utils import lazy_callable
 
-cli_main = lazy_callable('fairseq_cli.train.cli_main')
+cli_main = lazy_callable("fairseq_cli.train.cli_main")
 
 app = typer.Typer()
 
 
 @app.command()
 def local(dataset_path: Path):
-    args = f'''--distributed-world-size 1 {dataset_path} \
-    --save-dir /dataset/wav2vec2/model/wav2vec2_l_num_ctc_v1 --post-process letter --valid-subset \
-    valid --no-epoch-checkpoints --best-checkpoint-metric wer --num-workers 4 --max-update 80000 \
-    --sentence-avg --task audio_pretraining --arch wav2vec_ctc --w2v-path /dataset/wav2vec2/pretrained/wav2vec_vox_new.pt \
-    --labels ltr --apply-mask --mask-selection static --mask-other 0 --mask-length 10 --mask-prob 0.5 --layerdrop 0.1 \
-    --mask-channel-selection static --mask-channel-other 0 --mask-channel-length 64 --mask-channel-prob 0.5 \
-    --zero-infinity --feature-grad-mult 0.0 --freeze-finetune-updates 10000 --validate-after-updates 10000 \
-    --optimizer adam --adam-betas '(0.9, 0.98)' --adam-eps 1e-08 --lr 2e-05 --lr-scheduler tri_stage --warmup-steps 8000 \
-    --hold-steps 32000 --decay-steps 40000 --final-lr-scale 0.05 --final-dropout 0.0 --dropout 0.0 \
-    --activation-dropout 0.1 --criterion ctc --attention-dropout 0.0 --max-tokens 1280000 --seed 2337 --log-format json \
-    --log-interval 500 --ddp-backend no_c10d --reset-optimizer --normalize
-    '''
-    new_args = ['train.py']
-    new_args.extend(shlex.split(args))
-    sys.argv = new_args
-    cli_main()
+    pass
 
 
 if __name__ == "__main__":