tegra wav2vec2 transformers
parent
db51553320
commit
846f029cf1
4
setup.py
4
setup.py
|
|
@ -26,7 +26,7 @@ requirements = [
|
||||||
# "streamlit~=0.61.0",
|
# "streamlit~=0.61.0",
|
||||||
# "librosa~=0.7.2",
|
# "librosa~=0.7.2",
|
||||||
# "tritonclient[http]~=2.6.0",
|
# "tritonclient[http]~=2.6.0",
|
||||||
"numba~=0.48.0",
|
# "numba~=0.48.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
extra_requirements = {
|
extra_requirements = {
|
||||||
|
|
@ -66,7 +66,7 @@ extra_requirements = {
|
||||||
"pyspellchecker~=0.6.2",
|
"pyspellchecker~=0.6.2",
|
||||||
"num2words~=0.5.10",
|
"num2words~=0.5.10",
|
||||||
"pydub~=0.24.0",
|
"pydub~=0.24.0",
|
||||||
"pyaudio~=0.2.11"
|
"pyaudio~=0.2.11",
|
||||||
],
|
],
|
||||||
"infer_min": [
|
"infer_min": [
|
||||||
"pyspellchecker~=0.6.2",
|
"pyspellchecker~=0.6.2",
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from transformers import Wav2Vec2Tokenizer, Wav2Vec2ForCTC
|
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
|
||||||
|
|
||||||
# import soundfile as sf
|
# import soundfile as sf
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
|
@ -12,20 +12,16 @@ sf = lazy_module("soundfile")
|
||||||
class Wav2Vec2TransformersASR(object):
|
class Wav2Vec2TransformersASR(object):
|
||||||
"""docstring for Wav2Vec2TransformersASR."""
|
"""docstring for Wav2Vec2TransformersASR."""
|
||||||
|
|
||||||
def __init__(self, ctc_path, w2v_path, target_dict_path):
|
def __init__(self, model_dir):
|
||||||
super(Wav2Vec2TransformersASR, self).__init__()
|
super(Wav2Vec2TransformersASR, self).__init__()
|
||||||
self.tokenizer = Wav2Vec2Tokenizer.from_pretrained(
|
self.processor = Wav2Vec2Processor.from_pretrained(model_dir)
|
||||||
"facebook/wav2vec2-large-960h-lv60-self"
|
self.model = Wav2Vec2ForCTC.from_pretrained(model_dir)
|
||||||
)
|
|
||||||
self.model = Wav2Vec2ForCTC.from_pretrained(
|
|
||||||
"facebook/wav2vec2-large-960h-lv60-self"
|
|
||||||
)
|
|
||||||
|
|
||||||
def transcribe(self, audio_data):
|
def transcribe(self, audio_data):
|
||||||
aud_f = BytesIO(audio_data)
|
aud_f = BytesIO(audio_data)
|
||||||
# net_input = {}
|
# net_input = {}
|
||||||
speech_data, _ = sf.read(aud_f)
|
speech_data, _ = sf.read(aud_f)
|
||||||
input_values = self.tokenizer(
|
input_values = self.processor(
|
||||||
speech_data, return_tensors="pt", padding="longest"
|
speech_data, return_tensors="pt", padding="longest"
|
||||||
).input_values # Batch size 1
|
).input_values # Batch size 1
|
||||||
|
|
||||||
|
|
@ -35,5 +31,5 @@ class Wav2Vec2TransformersASR(object):
|
||||||
# take argmax and decode
|
# take argmax and decode
|
||||||
predicted_ids = torch.argmax(logits, dim=-1)
|
predicted_ids = torch.argmax(logits, dim=-1)
|
||||||
|
|
||||||
transcription = self.tokenizer.batch_decode(predicted_ids)[0]
|
transcription = self.processor.batch_decode(predicted_ids)[0]
|
||||||
return transcription
|
return transcription
|
||||||
|
|
|
||||||
|
|
@ -9,15 +9,18 @@ from tqdm import tqdm
|
||||||
from plume.utils import (
|
from plume.utils import (
|
||||||
ExtendedPath,
|
ExtendedPath,
|
||||||
replace_redundant_spaces_with,
|
replace_redundant_spaces_with,
|
||||||
lazy_module
|
lazy_module,
|
||||||
)
|
)
|
||||||
soundfile = lazy_module('soundfile')
|
|
||||||
pydub = lazy_module('pydub')
|
soundfile = lazy_module("soundfile")
|
||||||
|
pydub = lazy_module("pydub")
|
||||||
app = typer.Typer()
|
app = typer.Typer()
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
@app.command()
|
||||||
def export_jasper(src_dataset_path: Path, dest_dataset_path: Path, unlink: bool = True):
|
def export_jasper(
|
||||||
|
src_dataset_path: Path, dest_dataset_path: Path, unlink: bool = True
|
||||||
|
):
|
||||||
dict_ltr = dest_dataset_path / Path("dict.ltr.txt")
|
dict_ltr = dest_dataset_path / Path("dict.ltr.txt")
|
||||||
(dest_dataset_path / Path("wavs")).mkdir(exist_ok=True, parents=True)
|
(dest_dataset_path / Path("wavs")).mkdir(exist_ok=True, parents=True)
|
||||||
tok_counter = Counter()
|
tok_counter = Counter()
|
||||||
|
|
@ -51,13 +54,16 @@ def export_jasper(src_dataset_path: Path, dest_dataset_path: Path, unlink: bool
|
||||||
tsv_f.write(f"{src_dataset_path}\n")
|
tsv_f.write(f"{src_dataset_path}\n")
|
||||||
for md in manifest_data:
|
for md in manifest_data:
|
||||||
audio_fname = md["audio_filepath"]
|
audio_fname = md["audio_filepath"]
|
||||||
pipe_toks = replace_redundant_spaces_with(md["text"], "|").upper()
|
pipe_toks = replace_redundant_spaces_with(
|
||||||
|
md["text"], "|"
|
||||||
|
).upper()
|
||||||
# pipe_toks = "|".join(re.sub(" ", "", md["text"]))
|
# pipe_toks = "|".join(re.sub(" ", "", md["text"]))
|
||||||
# pipe_toks = alnum_to_asr_tokens(md["text"]).upper().replace(" ", "|")
|
|
||||||
tok_counter.update(pipe_toks)
|
tok_counter.update(pipe_toks)
|
||||||
letter_toks = " ".join(pipe_toks) + " |\n"
|
letter_toks = " ".join(pipe_toks) + " |\n"
|
||||||
frame_count = soundfile.info(audio_fname).frames
|
frame_count = soundfile.info(audio_fname).frames
|
||||||
rel_path = Path(audio_fname).relative_to(src_dataset_path.absolute())
|
rel_path = Path(audio_fname).relative_to(
|
||||||
|
src_dataset_path.absolute()
|
||||||
|
)
|
||||||
ltr_f.write(letter_toks)
|
ltr_f.write(letter_toks)
|
||||||
tsv_f.write(f"{rel_path}\t{frame_count}\n")
|
tsv_f.write(f"{rel_path}\t{frame_count}\n")
|
||||||
with dict_ltr.open("w") as d_f:
|
with dict_ltr.open("w") as d_f:
|
||||||
|
|
|
||||||
|
|
@ -1,23 +1,26 @@
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import typer
|
import typer
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
# import pandas as pd
|
# import pandas as pd
|
||||||
|
|
||||||
from plume.utils import (
|
from plume.utils import (
|
||||||
asr_manifest_reader,
|
asr_manifest_reader,
|
||||||
discard_except_digits,
|
discard_except_digits,
|
||||||
replace_digit_symbol,
|
replace_digit_symbol,
|
||||||
lazy_module
|
lazy_module,
|
||||||
# run_shell,
|
# run_shell,
|
||||||
)
|
)
|
||||||
from ...utils.transcribe import triton_transcribe_grpc_gen, transcribe_rpyc_gen
|
from ...utils.transcribe import triton_transcribe_grpc_gen, transcribe_rpyc_gen
|
||||||
|
|
||||||
pd = lazy_module('pandas')
|
pd = lazy_module("pandas")
|
||||||
app = typer.Typer()
|
app = typer.Typer()
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
@app.command()
|
||||||
def manifest(manifest_file: Path, result_file: Path = "results.csv", rpyc: bool = False):
|
def manifest(
|
||||||
|
manifest_file: Path, result_file: Path = "results.csv", rpyc: bool = False
|
||||||
|
):
|
||||||
from pydub import AudioSegment
|
from pydub import AudioSegment
|
||||||
|
|
||||||
host = "localhost"
|
host = "localhost"
|
||||||
|
|
@ -25,7 +28,9 @@ def manifest(manifest_file: Path, result_file: Path = "results.csv", rpyc: bool
|
||||||
if rpyc:
|
if rpyc:
|
||||||
transcriber, audio_prep = transcribe_rpyc_gen(host, port)
|
transcriber, audio_prep = transcribe_rpyc_gen(host, port)
|
||||||
else:
|
else:
|
||||||
transcriber, audio_prep = triton_transcribe_grpc_gen(host, port, method='whole')
|
transcriber, audio_prep = triton_transcribe_grpc_gen(
|
||||||
|
host, port, method="whole"
|
||||||
|
)
|
||||||
result_path = manifest_file.parent / result_file
|
result_path = manifest_file.parent / result_file
|
||||||
manifest_list = list(asr_manifest_reader(manifest_file))
|
manifest_list = list(asr_manifest_reader(manifest_file))
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@ import typer
|
||||||
|
|
||||||
from ...utils.serve import ASRService
|
from ...utils.serve import ASRService
|
||||||
from plume.utils import lazy_callable
|
from plume.utils import lazy_callable
|
||||||
|
|
||||||
# from plume.models.wav2vec2_transformers.asr import Wav2Vec2TransformersASR
|
# from plume.models.wav2vec2_transformers.asr import Wav2Vec2TransformersASR
|
||||||
# from .asr import Wav2Vec2ASR
|
# from .asr import Wav2Vec2ASR
|
||||||
|
|
||||||
|
|
@ -18,14 +19,29 @@ Wav2Vec2TransformersASR = lazy_callable(
|
||||||
app = typer.Typer()
|
app = typer.Typer()
|
||||||
|
|
||||||
|
|
||||||
|
# @app.command()
|
||||||
|
# def rpyc(
|
||||||
|
# w2v_path: Path = "/path/to/base.pt",
|
||||||
|
# ctc_path: Path = "/path/to/ctc.pt",
|
||||||
|
# target_dict_path: Path = "/path/to/dict.ltr.txt",
|
||||||
|
# port: int = int(os.environ.get("ASR_RPYC_PORT", "8044")),
|
||||||
|
# ):
|
||||||
|
# w2vasr = Wav2Vec2TransformersASR(ctc_path, w2v_path, target_dict_path)
|
||||||
|
# service = ASRService(w2vasr)
|
||||||
|
# logging.basicConfig(
|
||||||
|
# level=logging.INFO,
|
||||||
|
# format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||||
|
# )
|
||||||
|
# logging.info("starting asr server...")
|
||||||
|
# t = ThreadedServer(service, port=port)
|
||||||
|
# t.start()
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
@app.command()
|
||||||
def rpyc(
|
def rpyc_dir(
|
||||||
w2v_path: Path = "/path/to/base.pt",
|
model_dir: Path, port: int = int(os.environ.get("ASR_RPYC_PORT", "8044"))
|
||||||
ctc_path: Path = "/path/to/ctc.pt",
|
|
||||||
target_dict_path: Path = "/path/to/dict.ltr.txt",
|
|
||||||
port: int = int(os.environ.get("ASR_RPYC_PORT", "8044")),
|
|
||||||
):
|
):
|
||||||
w2vasr = Wav2Vec2TransformersASR(ctc_path, w2v_path, target_dict_path)
|
w2vasr = Wav2Vec2TransformersASR(model_dir)
|
||||||
service = ASRService(w2vasr)
|
service = ASRService(w2vasr)
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=logging.INFO,
|
level=logging.INFO,
|
||||||
|
|
@ -36,14 +52,6 @@ def rpyc(
|
||||||
t.start()
|
t.start()
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
|
||||||
def rpyc_dir(model_dir: Path, port: int = int(os.environ.get("ASR_RPYC_PORT", "8044"))):
|
|
||||||
ctc_path = model_dir / Path("ctc.pt")
|
|
||||||
w2v_path = model_dir / Path("base.pt")
|
|
||||||
target_dict_path = model_dir / Path("dict.ltr.txt")
|
|
||||||
rpyc(w2v_path, ctc_path, target_dict_path, port)
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
app()
|
app()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -4,8 +4,12 @@ import soundfile as sf
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
# load model and tokenizer
|
# load model and tokenizer
|
||||||
tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-large-960h-lv60-self")
|
tokenizer = Wav2Vec2Tokenizer.from_pretrained(
|
||||||
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60-self")
|
"facebook/wav2vec2-large-960h-lv60-self"
|
||||||
|
)
|
||||||
|
model = Wav2Vec2ForCTC.from_pretrained(
|
||||||
|
"facebook/wav2vec2-large-960h-lv60-self"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# define function to read in sound file
|
# define function to read in sound file
|
||||||
|
|
|
||||||
|
|
@ -1,33 +1,19 @@
|
||||||
import typer
|
import typer
|
||||||
|
|
||||||
# from fairseq_cli.train import cli_main
|
# from fairseq_cli.train import cli_main
|
||||||
import sys
|
# import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import shlex
|
# import shlex
|
||||||
from plume.utils import lazy_callable
|
from plume.utils import lazy_callable
|
||||||
|
|
||||||
cli_main = lazy_callable('fairseq_cli.train.cli_main')
|
cli_main = lazy_callable("fairseq_cli.train.cli_main")
|
||||||
|
|
||||||
app = typer.Typer()
|
app = typer.Typer()
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
@app.command()
|
||||||
def local(dataset_path: Path):
|
def local(dataset_path: Path):
|
||||||
args = f'''--distributed-world-size 1 {dataset_path} \
|
pass
|
||||||
--save-dir /dataset/wav2vec2/model/wav2vec2_l_num_ctc_v1 --post-process letter --valid-subset \
|
|
||||||
valid --no-epoch-checkpoints --best-checkpoint-metric wer --num-workers 4 --max-update 80000 \
|
|
||||||
--sentence-avg --task audio_pretraining --arch wav2vec_ctc --w2v-path /dataset/wav2vec2/pretrained/wav2vec_vox_new.pt \
|
|
||||||
--labels ltr --apply-mask --mask-selection static --mask-other 0 --mask-length 10 --mask-prob 0.5 --layerdrop 0.1 \
|
|
||||||
--mask-channel-selection static --mask-channel-other 0 --mask-channel-length 64 --mask-channel-prob 0.5 \
|
|
||||||
--zero-infinity --feature-grad-mult 0.0 --freeze-finetune-updates 10000 --validate-after-updates 10000 \
|
|
||||||
--optimizer adam --adam-betas '(0.9, 0.98)' --adam-eps 1e-08 --lr 2e-05 --lr-scheduler tri_stage --warmup-steps 8000 \
|
|
||||||
--hold-steps 32000 --decay-steps 40000 --final-lr-scale 0.05 --final-dropout 0.0 --dropout 0.0 \
|
|
||||||
--activation-dropout 0.1 --criterion ctc --attention-dropout 0.0 --max-tokens 1280000 --seed 2337 --log-format json \
|
|
||||||
--log-interval 500 --ddp-backend no_c10d --reset-optimizer --normalize
|
|
||||||
'''
|
|
||||||
new_args = ['train.py']
|
|
||||||
new_args.extend(shlex.split(args))
|
|
||||||
sys.argv = new_args
|
|
||||||
cli_main()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue