tegra wav2vec2 transformers
parent
db51553320
commit
846f029cf1
4
setup.py
4
setup.py
|
|
@ -26,7 +26,7 @@ requirements = [
|
|||
# "streamlit~=0.61.0",
|
||||
# "librosa~=0.7.2",
|
||||
# "tritonclient[http]~=2.6.0",
|
||||
"numba~=0.48.0",
|
||||
# "numba~=0.48.0",
|
||||
]
|
||||
|
||||
extra_requirements = {
|
||||
|
|
@ -66,7 +66,7 @@ extra_requirements = {
|
|||
"pyspellchecker~=0.6.2",
|
||||
"num2words~=0.5.10",
|
||||
"pydub~=0.24.0",
|
||||
"pyaudio~=0.2.11"
|
||||
"pyaudio~=0.2.11",
|
||||
],
|
||||
"infer_min": [
|
||||
"pyspellchecker~=0.6.2",
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from transformers import Wav2Vec2Tokenizer, Wav2Vec2ForCTC
|
||||
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
|
||||
|
||||
# import soundfile as sf
|
||||
from io import BytesIO
|
||||
|
|
@ -12,20 +12,16 @@ sf = lazy_module("soundfile")
|
|||
class Wav2Vec2TransformersASR(object):
|
||||
"""docstring for Wav2Vec2TransformersASR."""
|
||||
|
||||
def __init__(self, ctc_path, w2v_path, target_dict_path):
|
||||
def __init__(self, model_dir):
|
||||
super(Wav2Vec2TransformersASR, self).__init__()
|
||||
self.tokenizer = Wav2Vec2Tokenizer.from_pretrained(
|
||||
"facebook/wav2vec2-large-960h-lv60-self"
|
||||
)
|
||||
self.model = Wav2Vec2ForCTC.from_pretrained(
|
||||
"facebook/wav2vec2-large-960h-lv60-self"
|
||||
)
|
||||
self.processor = Wav2Vec2Processor.from_pretrained(model_dir)
|
||||
self.model = Wav2Vec2ForCTC.from_pretrained(model_dir)
|
||||
|
||||
def transcribe(self, audio_data):
|
||||
aud_f = BytesIO(audio_data)
|
||||
# net_input = {}
|
||||
speech_data, _ = sf.read(aud_f)
|
||||
input_values = self.tokenizer(
|
||||
input_values = self.processor(
|
||||
speech_data, return_tensors="pt", padding="longest"
|
||||
).input_values # Batch size 1
|
||||
|
||||
|
|
@ -35,5 +31,5 @@ class Wav2Vec2TransformersASR(object):
|
|||
# take argmax and decode
|
||||
predicted_ids = torch.argmax(logits, dim=-1)
|
||||
|
||||
transcription = self.tokenizer.batch_decode(predicted_ids)[0]
|
||||
transcription = self.processor.batch_decode(predicted_ids)[0]
|
||||
return transcription
|
||||
|
|
|
|||
|
|
@ -9,15 +9,18 @@ from tqdm import tqdm
|
|||
from plume.utils import (
|
||||
ExtendedPath,
|
||||
replace_redundant_spaces_with,
|
||||
lazy_module
|
||||
lazy_module,
|
||||
)
|
||||
soundfile = lazy_module('soundfile')
|
||||
pydub = lazy_module('pydub')
|
||||
|
||||
soundfile = lazy_module("soundfile")
|
||||
pydub = lazy_module("pydub")
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
|
||||
def export_jasper(src_dataset_path: Path, dest_dataset_path: Path, unlink: bool = True):
|
||||
def export_jasper(
|
||||
src_dataset_path: Path, dest_dataset_path: Path, unlink: bool = True
|
||||
):
|
||||
dict_ltr = dest_dataset_path / Path("dict.ltr.txt")
|
||||
(dest_dataset_path / Path("wavs")).mkdir(exist_ok=True, parents=True)
|
||||
tok_counter = Counter()
|
||||
|
|
@ -51,13 +54,16 @@ def export_jasper(src_dataset_path: Path, dest_dataset_path: Path, unlink: bool
|
|||
tsv_f.write(f"{src_dataset_path}\n")
|
||||
for md in manifest_data:
|
||||
audio_fname = md["audio_filepath"]
|
||||
pipe_toks = replace_redundant_spaces_with(md["text"], "|").upper()
|
||||
pipe_toks = replace_redundant_spaces_with(
|
||||
md["text"], "|"
|
||||
).upper()
|
||||
# pipe_toks = "|".join(re.sub(" ", "", md["text"]))
|
||||
# pipe_toks = alnum_to_asr_tokens(md["text"]).upper().replace(" ", "|")
|
||||
tok_counter.update(pipe_toks)
|
||||
letter_toks = " ".join(pipe_toks) + " |\n"
|
||||
frame_count = soundfile.info(audio_fname).frames
|
||||
rel_path = Path(audio_fname).relative_to(src_dataset_path.absolute())
|
||||
rel_path = Path(audio_fname).relative_to(
|
||||
src_dataset_path.absolute()
|
||||
)
|
||||
ltr_f.write(letter_toks)
|
||||
tsv_f.write(f"{rel_path}\t{frame_count}\n")
|
||||
with dict_ltr.open("w") as d_f:
|
||||
|
|
|
|||
|
|
@ -1,23 +1,26 @@
|
|||
from pathlib import Path
|
||||
import typer
|
||||
from tqdm import tqdm
|
||||
|
||||
# import pandas as pd
|
||||
|
||||
from plume.utils import (
|
||||
asr_manifest_reader,
|
||||
discard_except_digits,
|
||||
replace_digit_symbol,
|
||||
lazy_module
|
||||
lazy_module,
|
||||
# run_shell,
|
||||
)
|
||||
from ...utils.transcribe import triton_transcribe_grpc_gen, transcribe_rpyc_gen
|
||||
|
||||
pd = lazy_module('pandas')
|
||||
pd = lazy_module("pandas")
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
|
||||
def manifest(manifest_file: Path, result_file: Path = "results.csv", rpyc: bool = False):
|
||||
def manifest(
|
||||
manifest_file: Path, result_file: Path = "results.csv", rpyc: bool = False
|
||||
):
|
||||
from pydub import AudioSegment
|
||||
|
||||
host = "localhost"
|
||||
|
|
@ -25,7 +28,9 @@ def manifest(manifest_file: Path, result_file: Path = "results.csv", rpyc: bool
|
|||
if rpyc:
|
||||
transcriber, audio_prep = transcribe_rpyc_gen(host, port)
|
||||
else:
|
||||
transcriber, audio_prep = triton_transcribe_grpc_gen(host, port, method='whole')
|
||||
transcriber, audio_prep = triton_transcribe_grpc_gen(
|
||||
host, port, method="whole"
|
||||
)
|
||||
result_path = manifest_file.parent / result_file
|
||||
manifest_list = list(asr_manifest_reader(manifest_file))
|
||||
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ import typer
|
|||
|
||||
from ...utils.serve import ASRService
|
||||
from plume.utils import lazy_callable
|
||||
|
||||
# from plume.models.wav2vec2_transformers.asr import Wav2Vec2TransformersASR
|
||||
# from .asr import Wav2Vec2ASR
|
||||
|
||||
|
|
@ -18,14 +19,29 @@ Wav2Vec2TransformersASR = lazy_callable(
|
|||
app = typer.Typer()
|
||||
|
||||
|
||||
# @app.command()
|
||||
# def rpyc(
|
||||
# w2v_path: Path = "/path/to/base.pt",
|
||||
# ctc_path: Path = "/path/to/ctc.pt",
|
||||
# target_dict_path: Path = "/path/to/dict.ltr.txt",
|
||||
# port: int = int(os.environ.get("ASR_RPYC_PORT", "8044")),
|
||||
# ):
|
||||
# w2vasr = Wav2Vec2TransformersASR(ctc_path, w2v_path, target_dict_path)
|
||||
# service = ASRService(w2vasr)
|
||||
# logging.basicConfig(
|
||||
# level=logging.INFO,
|
||||
# format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
# )
|
||||
# logging.info("starting asr server...")
|
||||
# t = ThreadedServer(service, port=port)
|
||||
# t.start()
|
||||
|
||||
|
||||
@app.command()
|
||||
def rpyc(
|
||||
w2v_path: Path = "/path/to/base.pt",
|
||||
ctc_path: Path = "/path/to/ctc.pt",
|
||||
target_dict_path: Path = "/path/to/dict.ltr.txt",
|
||||
port: int = int(os.environ.get("ASR_RPYC_PORT", "8044")),
|
||||
def rpyc_dir(
|
||||
model_dir: Path, port: int = int(os.environ.get("ASR_RPYC_PORT", "8044"))
|
||||
):
|
||||
w2vasr = Wav2Vec2TransformersASR(ctc_path, w2v_path, target_dict_path)
|
||||
w2vasr = Wav2Vec2TransformersASR(model_dir)
|
||||
service = ASRService(w2vasr)
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
|
|
@ -36,14 +52,6 @@ def rpyc(
|
|||
t.start()
|
||||
|
||||
|
||||
@app.command()
|
||||
def rpyc_dir(model_dir: Path, port: int = int(os.environ.get("ASR_RPYC_PORT", "8044"))):
|
||||
ctc_path = model_dir / Path("ctc.pt")
|
||||
w2v_path = model_dir / Path("base.pt")
|
||||
target_dict_path = model_dir / Path("dict.ltr.txt")
|
||||
rpyc(w2v_path, ctc_path, target_dict_path, port)
|
||||
|
||||
|
||||
def main():
|
||||
app()
|
||||
|
||||
|
|
|
|||
|
|
@ -4,8 +4,12 @@ import soundfile as sf
|
|||
import torch
|
||||
|
||||
# load model and tokenizer
|
||||
tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-large-960h-lv60-self")
|
||||
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60-self")
|
||||
tokenizer = Wav2Vec2Tokenizer.from_pretrained(
|
||||
"facebook/wav2vec2-large-960h-lv60-self"
|
||||
)
|
||||
model = Wav2Vec2ForCTC.from_pretrained(
|
||||
"facebook/wav2vec2-large-960h-lv60-self"
|
||||
)
|
||||
|
||||
|
||||
# define function to read in sound file
|
||||
|
|
|
|||
|
|
@ -1,33 +1,19 @@
|
|||
import typer
|
||||
|
||||
# from fairseq_cli.train import cli_main
|
||||
import sys
|
||||
# import sys
|
||||
from pathlib import Path
|
||||
import shlex
|
||||
# import shlex
|
||||
from plume.utils import lazy_callable
|
||||
|
||||
cli_main = lazy_callable('fairseq_cli.train.cli_main')
|
||||
cli_main = lazy_callable("fairseq_cli.train.cli_main")
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
|
||||
def local(dataset_path: Path):
|
||||
args = f'''--distributed-world-size 1 {dataset_path} \
|
||||
--save-dir /dataset/wav2vec2/model/wav2vec2_l_num_ctc_v1 --post-process letter --valid-subset \
|
||||
valid --no-epoch-checkpoints --best-checkpoint-metric wer --num-workers 4 --max-update 80000 \
|
||||
--sentence-avg --task audio_pretraining --arch wav2vec_ctc --w2v-path /dataset/wav2vec2/pretrained/wav2vec_vox_new.pt \
|
||||
--labels ltr --apply-mask --mask-selection static --mask-other 0 --mask-length 10 --mask-prob 0.5 --layerdrop 0.1 \
|
||||
--mask-channel-selection static --mask-channel-other 0 --mask-channel-length 64 --mask-channel-prob 0.5 \
|
||||
--zero-infinity --feature-grad-mult 0.0 --freeze-finetune-updates 10000 --validate-after-updates 10000 \
|
||||
--optimizer adam --adam-betas '(0.9, 0.98)' --adam-eps 1e-08 --lr 2e-05 --lr-scheduler tri_stage --warmup-steps 8000 \
|
||||
--hold-steps 32000 --decay-steps 40000 --final-lr-scale 0.05 --final-dropout 0.0 --dropout 0.0 \
|
||||
--activation-dropout 0.1 --criterion ctc --attention-dropout 0.0 --max-tokens 1280000 --seed 2337 --log-format json \
|
||||
--log-interval 500 --ddp-backend no_c10d --reset-optimizer --normalize
|
||||
'''
|
||||
new_args = ['train.py']
|
||||
new_args.extend(shlex.split(args))
|
||||
sys.argv = new_args
|
||||
cli_main()
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
Loading…
Reference in New Issue