mirror of
https://github.com/malarinv/plume-asr.git
synced 2026-03-08 04:12:35 +00:00
massive refactor/rename to plume
This commit is contained in:
0
plume/models/wav2vec2/__init__.py
Normal file
0
plume/models/wav2vec2/__init__.py
Normal file
204
plume/models/wav2vec2/asr.py
Normal file
204
plume/models/wav2vec2/asr.py
Normal file
@@ -0,0 +1,204 @@
|
||||
from io import BytesIO
|
||||
import warnings
|
||||
import itertools as it
|
||||
|
||||
import torch
|
||||
import soundfile as sf
|
||||
import torch.nn.functional as F
|
||||
|
||||
try:
|
||||
from fairseq import utils
|
||||
from fairseq.models import BaseFairseqModel
|
||||
from fairseq.data import Dictionary
|
||||
from fairseq.models.wav2vec.wav2vec2_asr import base_architecture, Wav2VecEncoder
|
||||
except ModuleNotFoundError:
|
||||
warnings.warn("Install fairseq")
|
||||
try:
|
||||
from wav2letter.decoder import CriterionType
|
||||
from wav2letter.criterion import CpuViterbiPath, get_data_ptr_as_bytes
|
||||
except ModuleNotFoundError:
|
||||
warnings.warn("Install wav2letter")
|
||||
|
||||
|
||||
class Wav2VecCtc(BaseFairseqModel):
    """A wav2vec 2.0 encoder fine-tuned with a CTC head, wrapped as a fairseq model."""

    def __init__(self, w2v_encoder, args):
        super().__init__()
        self.w2v_encoder = w2v_encoder
        self.args = args

    def upgrade_state_dict_named(self, state_dict, name):
        # Let the base class migrate old checkpoint layouts, then return the
        # (possibly rewritten) state dict to the caller.
        super().upgrade_state_dict_named(state_dict, name)
        return state_dict

    @classmethod
    def build_model(cls, args, target_dict):
        """Build a new model instance."""
        # Fill in any missing architecture hyper-parameters on `args`.
        base_architecture(args)
        encoder = Wav2VecEncoder(args, target_dict)
        return cls(encoder, args)

    def get_normalized_probs(self, net_output, log_probs):
        """Get normalized probabilities (or log probs) from a net's output."""
        scores = net_output["encoder_out"].float()
        normalize = utils.log_softmax if log_probs else utils.softmax
        return normalize(scores, dim=-1)

    def forward(self, **kwargs):
        # Thin pass-through: all computation happens in the wrapped encoder.
        return self.w2v_encoder(**kwargs)
|
||||
|
||||
|
||||
class W2lDecoder(object):
    """Base class for wav2letter-style decoders over CTC emissions."""

    def __init__(self, tgt_dict):
        # tgt_dict: fairseq Dictionary mapping output tokens <-> indices.
        self.tgt_dict = tgt_dict
        self.vocab_size = len(tgt_dict)
        self.nbest = 1

        self.criterion_type = CriterionType.CTC
        # Blank index: prefer an explicit "<ctc_blank>" symbol if the
        # dictionary defines one, otherwise reuse the BOS index as blank.
        self.blank = (
            tgt_dict.index("<ctc_blank>")
            if "<ctc_blank>" in tgt_dict.indices
            else tgt_dict.bos()
        )
        # ASG transition matrix; None means CTC-only (zero transitions).
        self.asg_transitions = None

    def generate(self, model, sample, **unused):
        """Generate a batch of inferences."""
        # model.forward normally channels prev_output_tokens into the decoder
        # separately, but SequenceGenerator directly calls model.encoder
        encoder_input = {
            k: v for k, v in sample["net_input"].items() if k != "prev_output_tokens"
        }
        emissions = self.get_emissions(model, encoder_input)
        return self.decode(emissions)

    def get_emissions(self, model, encoder_input):
        """Run encoder and normalize emissions"""
        # encoder_out = models[0].encoder(**encoder_input)
        encoder_out = model(**encoder_input)
        # NOTE(review): if criterion_type were ever not CTC, `emissions` would
        # be unbound here (NameError). __init__ always sets CTC, so the guard
        # is currently vacuous.
        if self.criterion_type == CriterionType.CTC:
            emissions = model.get_normalized_probs(encoder_out, log_probs=True)

        # Swap the first two dimensions (fairseq encoders emit time-major
        # output — TODO confirm) so decode() sees batch-first (B, T, N).
        return emissions.transpose(0, 1).float().cpu().contiguous()

    def get_tokens(self, idxs):
        """Normalize tokens by handling CTC blank, ASG replabels, etc."""
        # Collapse repeated indices (CTC best-path), then drop blanks.
        idxs = (g[0] for g in it.groupby(idxs))
        idxs = filter(lambda x: x != self.blank, idxs)

        return torch.LongTensor(list(idxs))
|
||||
|
||||
|
||||
class W2lViterbiDecoder(W2lDecoder):
    """Viterbi (best-path) CTC decoder backed by wav2letter's C++ implementation."""

    def __init__(self, tgt_dict):
        super().__init__(tgt_dict)

    def decode(self, emissions):
        # emissions: batch-first (B, T, N) log-probs from get_emissions().
        B, T, N = emissions.size()
        hypos = list()  # NOTE(review): never used; results are built in the return.

        if self.asg_transitions is None:
            # Pure CTC: all-zero transition scores.
            transitions = torch.FloatTensor(N, N).zero_()
        else:
            transitions = torch.FloatTensor(self.asg_transitions).view(N, N)

        # Output buffer and scratch workspace for the native Viterbi routine;
        # tensors are handed over as raw data pointers, so they must stay
        # alive and contiguous for the duration of the call.
        viterbi_path = torch.IntTensor(B, T)
        workspace = torch.ByteTensor(CpuViterbiPath.get_workspace_size(B, T, N))
        CpuViterbiPath.compute(
            B,
            T,
            N,
            get_data_ptr_as_bytes(emissions),
            get_data_ptr_as_bytes(transitions),
            get_data_ptr_as_bytes(viterbi_path),
            get_data_ptr_as_bytes(workspace),
        )
        # One single-hypothesis list per batch element; score is not computed.
        return [
            [{"tokens": self.get_tokens(viterbi_path[b].tolist()), "score": 0}]
            for b in range(B)
        ]
|
||||
|
||||
|
||||
def post_process(sentence: str, symbol: str):
    """Undo subword segmentation and restore plain-text word spacing.

    For the known segmentation schemes, spaces between subword units are
    removed and the scheme's word-boundary marker becomes a space. Any other
    non-None/"none" symbol is treated as a joiner and stripped outright.
    """
    # Word-boundary marker for each separator-style segmentation scheme.
    boundary_markers = {
        "sentencepiece": "\u2581",
        "wordpiece": "_",
        "letter": "|",
        "_EOW": "_EOW",
    }
    marker = boundary_markers.get(symbol)
    if marker is not None:
        return sentence.replace(" ", "").replace(marker, " ").strip()
    if symbol is not None and symbol != "none":
        # Joiner-style symbol (e.g. BPE "@@ "): remove every occurrence.
        return (sentence + " ").replace(symbol, "").rstrip()
    return sentence
|
||||
|
||||
|
||||
def get_feature(filepath):
    """Load audio and return a layer-normalized 1-D float tensor of samples.

    `filepath` is anything soundfile accepts (a path or a file-like object).
    Multi-channel audio is downmixed to mono by averaging channels; the
    tensor is moved to GPU when available.
    """

    def postprocess(feats, sample_rate):
        # Bug fix: the original wrote `feats.dim == 2`, comparing the bound
        # method itself to 2 — always False — so stereo input was never
        # downmixed and hit the assert below.
        if feats.dim() == 2:
            # soundfile returns (frames, channels); average over channels.
            feats = feats.mean(-1)

        assert feats.dim() == 1, feats.dim()

        with torch.no_grad():
            # Normalize the whole waveform to zero mean / unit variance.
            feats = F.layer_norm(feats, feats.shape)
        return feats

    wav, sample_rate = sf.read(filepath)
    feats = torch.from_numpy(wav).float()
    if torch.cuda.is_available():
        feats = feats.cuda()
    feats = postprocess(feats, sample_rate)
    return feats
|
||||
|
||||
|
||||
def load_model(ctc_model_path, w2v_model_path, target_dict):
    """Restore a Wav2VecCtc model from a fine-tuned CTC checkpoint.

    `ctc_model_path` is the fine-tuned checkpoint; `w2v_model_path` points the
    checkpoint's args at the local pretrained wav2vec weights. Returns the
    loaded model, moved to GPU when available.
    """
    checkpoint = torch.load(ctc_model_path)
    # Re-target the pretrained-weights path recorded inside the checkpoint.
    checkpoint["args"].w2v_path = w2v_model_path
    model = Wav2VecCtc.build_model(checkpoint["args"], target_dict)
    model.load_state_dict(checkpoint["model"], strict=True)
    return model.cuda() if torch.cuda.is_available() else model
|
||||
|
||||
|
||||
class Wav2Vec2ASR(object):
    """End-to-end speech recognizer: wav2vec2+CTC model with Viterbi decoding."""

    def __init__(self, ctc_path, w2v_path, target_dict_path):
        super(Wav2Vec2ASR, self).__init__()
        # Letter dictionary shared by the model head and the decoder.
        self.target_dict = Dictionary.load(target_dict_path)

        self.model = load_model(ctc_path, w2v_path, self.target_dict)
        self.model.eval()

        self.generator = W2lViterbiDecoder(self.target_dict)

    def transcribe(self, audio_data, greedy=True):
        # audio_data: raw audio file bytes — presumably 16 kHz mono WAV, since
        # no resampling happens here (the commented pydub code below used to
        # do that conversion) — TODO confirm with callers.
        # NOTE(review): `greedy` is never read; decoding is always Viterbi.
        aud_f = BytesIO(audio_data)
        # aud_seg = pydub.AudioSegment.from_file(aud_f)
        # feat_seg = aud_seg.set_channels(1).set_sample_width(2).set_frame_rate(16000)
        # feat_f = io.BytesIO()
        # feat_seg.export(feat_f, format='wav')
        # feat_f.seek(0)
        net_input = {}
        feature = get_feature(aud_f)
        # Add a batch dimension: (T,) -> (1, T).
        net_input["source"] = feature.unsqueeze(0)

        # All-False mask: the single sample has no padding frames.
        padding_mask = (
            torch.BoolTensor(net_input["source"].size(1)).fill_(False).unsqueeze(0)
        )
        if torch.cuda.is_available():
            padding_mask = padding_mask.cuda()

        net_input["padding_mask"] = padding_mask
        sample = {}
        sample["net_input"] = net_input

        with torch.no_grad():
            hypo = self.generator.generate(self.model, sample, prefix_tokens=None)
        # Best hypothesis of the first (only) batch element -> letter string.
        hyp_pieces = self.target_dict.string(hypo[0][0]["tokens"].int().cpu())
        result = post_process(hyp_pieces, "letter")
        return result
|
||||
86
plume/models/wav2vec2/data.py
Normal file
86
plume/models/wav2vec2/data.py
Normal file
@@ -0,0 +1,86 @@
|
||||
from pathlib import Path
|
||||
from collections import Counter
|
||||
import shutil
|
||||
|
||||
import soundfile
|
||||
# import pydub
|
||||
import typer
|
||||
from tqdm import tqdm
|
||||
|
||||
from plume.utils import (
|
||||
ExtendedPath,
|
||||
replace_redundant_spaces_with,
|
||||
lazy_module
|
||||
)
|
||||
pydub = lazy_module('pydub')
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
|
||||
def export_jasper(src_dataset_path: Path, dest_dataset_path: Path, unlink: bool = True):
    # Convert a Jasper/NeMo-style manifest dataset into fairseq wav2vec2
    # fine-tuning format: {train,valid}.tsv, {train,valid}.ltr, dict.ltr.txt.
    # NOTE(review): `unlink` also gates the wav re-encode below and which root
    # is written to the tsv header — the name undersells what it controls;
    # confirm intended semantics with callers.
    dict_ltr = dest_dataset_path / Path("dict.ltr.txt")
    (dest_dataset_path / Path("wavs")).mkdir(exist_ok=True, parents=True)
    tok_counter = Counter()
    # Use the test manifest as the validation split (removed again at the end).
    shutil.copy(
        src_dataset_path / Path("test_manifest.json"),
        src_dataset_path / Path("valid_manifest.json"),
    )
    if unlink:
        # Re-encode every source wav to 16 kHz mono into the destination tree.
        src_wavs = src_dataset_path / Path("wavs")
        for wav_path in tqdm(list(src_wavs.glob("**/*.wav"))):
            audio_seg = (
                pydub.AudioSegment.from_wav(wav_path)
                .set_frame_rate(16000)
                .set_channels(1)
            )
            dest_path = dest_dataset_path / Path("wavs") / Path(wav_path.name)
            audio_seg.export(dest_path, format="wav")

    for dataset_kind in ["train", "valid"]:
        abs_manifest_path = ExtendedPath(
            src_dataset_path / Path(f"{dataset_kind}_manifest.json")
        )
        manifest_data = list(abs_manifest_path.read_jsonl())
        o_tsv, o_ltr = f"{dataset_kind}.tsv", f"{dataset_kind}.ltr"
        out_tsv = dest_dataset_path / Path(o_tsv)
        out_ltr = dest_dataset_path / Path(o_ltr)
        with out_tsv.open("w") as tsv_f, out_ltr.open("w") as ltr_f:
            # First tsv line is the dataset root directory expected by fairseq.
            if unlink:
                tsv_f.write(f"{dest_dataset_path}\n")
            else:
                tsv_f.write(f"{src_dataset_path}\n")
            for md in manifest_data:
                audio_fname = md["audio_filepath"]
                # Text -> pipe-delimited uppercase letters ("|" marks word ends).
                pipe_toks = replace_redundant_spaces_with(md["text"], "|").upper()
                # pipe_toks = "|".join(re.sub(" ", "", md["text"]))
                # pipe_toks = alnum_to_asr_tokens(md["text"]).upper().replace(" ", "|")
                tok_counter.update(pipe_toks)
                letter_toks = " ".join(pipe_toks) + " |\n"
                # NOTE(review): frame count is read from the *source* wav even
                # when unlink re-encoded it to 16 kHz — counts may disagree
                # with the exported file if sample rates differ; verify.
                frame_count = soundfile.info(audio_fname).frames
                rel_path = Path(audio_fname).relative_to(src_dataset_path.absolute())
                ltr_f.write(letter_toks)
                tsv_f.write(f"{rel_path}\t{frame_count}\n")
    # Letter frequency table used as the fairseq target dictionary.
    with dict_ltr.open("w") as d_f:
        for k, v in tok_counter.most_common():
            d_f.write(f"{k} {v}\n")
    # Clean up the temporary validation manifest created above.
    (src_dataset_path / Path("valid_manifest.json")).unlink()
|
||||
|
||||
|
||||
@app.command()
|
||||
def set_root(dataset_path: Path, root_path: Path):
    """Rewrite the dataset-root header (first line) of train.tsv and valid.tsv."""
    for split in ("train", "valid"):
        tsv_path = dataset_path / Path(split).with_suffix(".tsv")
        with tsv_path.open("r") as fh:
            rows = fh.readlines()
        with tsv_path.open("w") as fh:
            # fairseq tsv format: line 1 is the audio root directory.
            rows[0] = str(root_path) + "\n"
            fh.writelines(rows)
|
||||
|
||||
|
||||
def main():
    # Console-script entry point: dispatch to the typer CLI app.
    app()
|
||||
|
||||
|
||||
# Allow running this module directly as a script.
if __name__ == "__main__":
    main()
|
||||
49
plume/models/wav2vec2/eval.py
Normal file
49
plume/models/wav2vec2/eval.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from pathlib import Path
|
||||
import typer
|
||||
from tqdm import tqdm
|
||||
# import pandas as pd
|
||||
|
||||
from plume.utils import (
|
||||
asr_manifest_reader,
|
||||
discard_except_digits,
|
||||
replace_digit_symbol,
|
||||
lazy_module
|
||||
# run_shell,
|
||||
)
|
||||
from ...utils.transcribe import triton_transcribe_grpc_gen
|
||||
|
||||
pd = lazy_module('pandas')
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
|
||||
def manifest(manifest_file: Path, result_file: Path = "results.csv"):
    # Evaluate digit-transcription accuracy of a Triton-served ASR model over
    # an ASR manifest, writing one row per utterance to a CSV next to the
    # manifest. NOTE(review): `result_file` default is a str despite the Path
    # annotation — typer coerces it, but callers in code should pass a Path.
    from pydub import AudioSegment

    # Triton gRPC endpoint — presumably a local deployment; confirm.
    host = "localhost"
    port = 8044
    transcriber, audio_prep = triton_transcribe_grpc_gen(host, port, method='whole')
    result_path = manifest_file.parent / result_file
    manifest_list = list(asr_manifest_reader(manifest_file))

    def compute_frame(d):
        # Build one result row: transcribe the audio and compare only the
        # digit content of the reference vs. the hypothesis.
        audio_file = d["audio_path"]
        orig_text = d["text"]
        orig_num = discard_except_digits(replace_digit_symbol(orig_text))
        aud_seg = AudioSegment.from_file(audio_file)
        t_audio = audio_prep(aud_seg)
        asr_text = transcriber(t_audio)
        asr_num = discard_except_digits(replace_digit_symbol(asr_text))
        return {
            "audio_file": audio_file,
            "asr_text": asr_text,
            "asr_num": asr_num,
            "orig_text": orig_text,
            "orig_num": orig_num,
            "asr_match": orig_num == asr_num,
        }

    # df_data = parallel_apply(compute_frame, manifest_list)
    df_data = map(compute_frame, tqdm(manifest_list))
    df = pd.DataFrame(df_data)
    df.to_csv(result_path)
|
||||
53
plume/models/wav2vec2/serve.py
Normal file
53
plume/models/wav2vec2/serve.py
Normal file
@@ -0,0 +1,53 @@
|
||||
import os
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
# from rpyc.utils.server import ThreadedServer
|
||||
import typer
|
||||
|
||||
from ...utils.serve import ASRService
|
||||
from plume.utils import lazy_callable
|
||||
# from .asr import Wav2Vec2ASR
|
||||
|
||||
ThreadedServer = lazy_callable('rpyc.utils.server.ThreadedServer')
|
||||
Wav2Vec2ASR = lazy_callable('plume.models.wav2vec2.asr.Wav2Vec2ASR')
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
|
||||
def rpyc(
    w2v_path: Path = "/path/to/base.pt",
    ctc_path: Path = "/path/to/ctc.pt",
    target_dict_path: Path = "/path/to/dict.ltr.txt",
    port: int = int(os.environ.get("ASR_RPYC_PORT", "8044")),
):
    """Serve a Wav2Vec2 ASR model over rpyc on the given port (blocks forever).

    Validates that all three model files exist before loading; logs and
    returns early otherwise.
    """
    # Bug fix: configure logging *before* any logging.info call — the original
    # logged the missing-path message first, so with no handler configured at
    # INFO level the warning was silently dropped.
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )
    for p in [w2v_path, ctc_path, target_dict_path]:
        if not p.exists():
            logging.info(f"{p} doesn't exist")
            return
    w2vasr = Wav2Vec2ASR(str(ctc_path), str(w2v_path), str(target_dict_path))
    service = ASRService(w2vasr)
    logging.info("starting asr server...")
    # ThreadedServer.start() blocks, serving requests until interrupted.
    t = ThreadedServer(service, port=port)
    t.start()
|
||||
|
||||
|
||||
@app.command()
|
||||
def rpyc_dir(model_dir: Path, port: int = int(os.environ.get("ASR_RPYC_PORT", "8044"))):
    """Serve a model from a directory laid out with the conventional file names."""
    # Expected layout: base.pt (pretrained), ctc.pt (fine-tuned), dict.ltr.txt.
    rpyc(
        model_dir / "base.pt",
        model_dir / "ctc.pt",
        model_dir / "dict.ltr.txt",
        port,
    )
|
||||
|
||||
|
||||
def main():
    # Console-script entry point: dispatch to the typer CLI app.
    app()
|
||||
|
||||
|
||||
# Allow running this module directly as a script.
if __name__ == "__main__":
    main()
|
||||
34
plume/models/wav2vec2/train.py
Normal file
34
plume/models/wav2vec2/train.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import typer
|
||||
# from fairseq_cli.train import cli_main
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import shlex
|
||||
from plume.utils import lazy_callable
|
||||
|
||||
cli_main = lazy_callable('fairseq_cli.train.cli_main')
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
|
||||
def local(dataset_path: Path):
    # Launch a local single-GPU wav2vec2 CTC fine-tuning run by rewriting
    # sys.argv and invoking fairseq's train CLI in-process. Hyper-parameters
    # (save dir, pretrained checkpoint path, schedule) are hard-coded below.
    args = f'''--distributed-world-size 1 {dataset_path} \
--save-dir /dataset/wav2vec2/model/wav2vec2_l_num_ctc_v1 --post-process letter --valid-subset \
valid --no-epoch-checkpoints --best-checkpoint-metric wer --num-workers 4 --max-update 80000 \
--sentence-avg --task audio_pretraining --arch wav2vec_ctc --w2v-path /dataset/wav2vec2/pretrained/wav2vec_vox_new.pt \
--labels ltr --apply-mask --mask-selection static --mask-other 0 --mask-length 10 --mask-prob 0.5 --layerdrop 0.1 \
--mask-channel-selection static --mask-channel-other 0 --mask-channel-length 64 --mask-channel-prob 0.5 \
--zero-infinity --feature-grad-mult 0.0 --freeze-finetune-updates 10000 --validate-after-updates 10000 \
--optimizer adam --adam-betas '(0.9, 0.98)' --adam-eps 1e-08 --lr 2e-05 --lr-scheduler tri_stage --warmup-steps 8000 \
--hold-steps 32000 --decay-steps 40000 --final-lr-scale 0.05 --final-dropout 0.0 --dropout 0.0 \
--activation-dropout 0.1 --criterion ctc --attention-dropout 0.0 --max-tokens 1280000 --seed 2337 --log-format json \
--log-interval 500 --ddp-backend no_c10d --reset-optimizer --normalize
'''
    # fairseq's cli_main() parses sys.argv directly; argv[0] is a dummy
    # program name and shlex.split handles the quoted '(0.9, 0.98)' tuple.
    new_args = ['train.py']
    new_args.extend(shlex.split(args))
    sys.argv = new_args
    cli_main()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Fix: dispatch through the typer CLI like the sibling modules (data.py,
    # serve.py). The original called fairseq's cli_main() directly, which
    # bypassed the `local` command and its sys.argv setup entirely.
    app()
|
||||
Reference in New Issue
Block a user