1
0
mirror of https://github.com/malarinv/plume-asr.git synced 2026-03-08 04:12:35 +00:00

massive refactor/rename to plume

This commit is contained in:
2021-02-23 19:43:33 +05:30
parent e8f58a5043
commit ed6117559a
51 changed files with 2864 additions and 1037 deletions

View File

View File

@@ -0,0 +1,204 @@
from io import BytesIO
import warnings
import itertools as it
import torch
import soundfile as sf
import torch.nn.functional as F
# fairseq and wav2letter are treated as optional at import time: a missing
# package only produces a warning here instead of an ImportError.
# NOTE(review): Wav2VecCtc below subclasses BaseFairseqModel at class-definition
# time, so importing this module without fairseq presumably still fails with a
# NameError — confirm whether the guard is sufficient.
try:
    from fairseq import utils
    from fairseq.models import BaseFairseqModel
    from fairseq.data import Dictionary
    from fairseq.models.wav2vec.wav2vec2_asr import base_architecture, Wav2VecEncoder
except ModuleNotFoundError:
    warnings.warn("Install fairseq")
# The wav2letter bindings supply the CTC criterion enum and the CPU Viterbi
# routine used by the decoders defined later in this module.
try:
    from wav2letter.decoder import CriterionType
    from wav2letter.criterion import CpuViterbiPath, get_data_ptr_as_bytes
except ModuleNotFoundError:
    warnings.warn("Install wav2letter")
class Wav2VecCtc(BaseFairseqModel):
    """Thin fairseq model that wraps a ``Wav2VecEncoder`` for CTC inference.

    The encoder does all the work; this class only stores it together with
    the argument namespace and exposes probability normalisation.
    """

    def __init__(self, w2v_encoder, args):
        """Store the encoder and the argument namespace it was built from."""
        super().__init__()
        self.w2v_encoder = w2v_encoder
        self.args = args

    def upgrade_state_dict_named(self, state_dict, name):
        """Delegate to the base-class upgrade hook and return ``state_dict``."""
        super().upgrade_state_dict_named(state_dict, name)
        return state_dict

    @classmethod
    def build_model(cls, args, target_dict):
        """Build a new model instance from ``args`` and a target dictionary."""
        # base_architecture is called for its effect on ``args`` before the
        # encoder reads them.
        base_architecture(args)
        return cls(Wav2VecEncoder(args, target_dict), args)

    def get_normalized_probs(self, net_output, log_probs):
        """Return softmax (or log-softmax when ``log_probs``) of the encoder logits.

        The logits are read from ``net_output["encoder_out"]`` and cast to
        float before normalising over the last dimension.
        """
        logits = net_output["encoder_out"]
        normalize = utils.log_softmax if log_probs else utils.softmax
        return normalize(logits.float(), dim=-1)

    def forward(self, **kwargs):
        """Forward all keyword arguments straight to the wrapped encoder."""
        return self.w2v_encoder(**kwargs)
class W2lDecoder(object):
    """Base decoder: runs the model and collapses frame labels into tokens.

    ``generate`` calls ``self.decode(emissions)``, which this class does not
    define; subclasses are expected to provide it.
    """

    def __init__(self, tgt_dict):
        """Record dictionary-derived settings and pick the blank index."""
        self.tgt_dict = tgt_dict
        self.vocab_size = len(tgt_dict)
        self.nbest = 1
        self.criterion_type = CriterionType.CTC
        # Use the dedicated blank symbol when the dictionary has one,
        # otherwise fall back to the BOS index.
        if "<ctc_blank>" in tgt_dict.indices:
            self.blank = tgt_dict.index("<ctc_blank>")
        else:
            self.blank = tgt_dict.bos()
        self.asg_transitions = None

    def generate(self, model, sample, **unused):
        """Generate a batch of inferences from ``sample["net_input"]``."""
        # Everything except ``prev_output_tokens`` is passed on to the model.
        encoder_input = dict(sample["net_input"])
        encoder_input.pop("prev_output_tokens", None)
        return self.decode(self.get_emissions(model, encoder_input))

    def get_emissions(self, model, encoder_input):
        """Run the model and return normalised emissions as a CPU tensor."""
        net_output = model(**encoder_input)
        if self.criterion_type == CriterionType.CTC:
            emissions = model.get_normalized_probs(net_output, log_probs=True)
        # The first two axes are swapped before handing the data to the decoder.
        swapped = emissions.transpose(0, 1)
        return swapped.float().cpu().contiguous()

    def get_tokens(self, idxs):
        """Collapse consecutive repeats, drop the blank index, return a LongTensor."""
        collapsed = [label for label, _run in it.groupby(idxs)]
        kept = [label for label in collapsed if label != self.blank]
        return torch.LongTensor(kept)
class W2lViterbiDecoder(W2lDecoder):
    """Decoder that picks the best label path with wav2letter's CPU Viterbi."""

    def __init__(self, tgt_dict):
        """Initialise exactly like the base decoder."""
        super().__init__(tgt_dict)

    def decode(self, emissions):
        """Return one single-hypothesis list per row of ``emissions``.

        ``emissions`` is unpacked as a 3-D tensor; the axes are presumably
        (batch, frames, vocabulary) — confirm against ``get_emissions``.
        Each hypothesis is ``{"tokens": LongTensor, "score": 0}``.
        """
        batch, frames, vocab = emissions.size()
        if self.asg_transitions is not None:
            transitions = torch.FloatTensor(self.asg_transitions).view(vocab, vocab)
        else:
            # No transition scores configured: use an all-zero matrix.
            transitions = torch.FloatTensor(vocab, vocab).zero_()
        # Output and scratch buffers that the native routine fills in place.
        viterbi_path = torch.IntTensor(batch, frames)
        workspace = torch.ByteTensor(
            CpuViterbiPath.get_workspace_size(batch, frames, vocab)
        )
        CpuViterbiPath.compute(
            batch,
            frames,
            vocab,
            get_data_ptr_as_bytes(emissions),
            get_data_ptr_as_bytes(transitions),
            get_data_ptr_as_bytes(viterbi_path),
            get_data_ptr_as_bytes(workspace),
        )
        results = []
        for row in range(batch):
            tokens = self.get_tokens(viterbi_path[row].tolist())
            results.append([{"tokens": tokens, "score": 0}])
        return results
def post_process(sentence: str, symbol: str):
    """Turn a space-separated token string back into readable text.

    For the known schemes ("sentencepiece", "wordpiece", "letter", "_EOW")
    every space is removed, the scheme's word-boundary marker becomes a
    space, and the result is stripped. Any other symbol except ``None`` and
    "none" is removed from ``sentence + " "`` and trailing whitespace is
    trimmed. ``None`` and "none" leave the sentence unchanged.
    """
    boundary_markers = {
        "sentencepiece": "\u2581",
        "wordpiece": "_",
        "letter": "|",
        "_EOW": "_EOW",
    }
    if symbol in boundary_markers:
        joined = sentence.replace(" ", "")
        return joined.replace(boundary_markers[symbol], " ").strip()
    if symbol is None or symbol == "none":
        return sentence
    return (sentence + " ").replace(symbol, "").rstrip()
def get_feature(filepath):
    """Read audio with soundfile and return a layer-normalised 1-D float tensor.

    Args:
        filepath: anything ``soundfile.read`` accepts; callers in this module
            pass an in-memory ``BytesIO`` object.

    Returns:
        A 1-D float tensor, moved to CUDA when CUDA is available.

    Raises:
        AssertionError: if the audio is not 1-D after the optional down-mix.
    """

    def postprocess(feats, sample_rate):
        # ``dim`` is a method. The previous ``feats.dim == 2`` compared the
        # bound method to an int and was always False, so 2-D (multi-channel)
        # audio was never averaged and tripped the assertion below.
        if feats.dim() == 2:
            feats = feats.mean(-1)
        assert feats.dim() == 1, feats.dim()
        with torch.no_grad():
            # Normalise over the whole waveform.
            feats = F.layer_norm(feats, feats.shape)
        return feats

    wav, sample_rate = sf.read(filepath)
    feats = torch.from_numpy(wav).float()
    if torch.cuda.is_available():
        feats = feats.cuda()
    return postprocess(feats, sample_rate)
def load_model(ctc_model_path, w2v_model_path, target_dict):
    """Load a fine-tuned CTC checkpoint and return the ready-to-use model.

    Args:
        ctc_model_path: path to the checkpoint holding ``"args"`` and ``"model"``.
        w2v_model_path: path written into ``args.w2v_path`` before the model
            is built.
        target_dict: dictionary passed to ``Wav2VecCtc.build_model``.

    Returns:
        The model with weights loaded (strict), on CUDA when available.
    """
    # Load onto the CPU first so checkpoints saved from a GPU can still be
    # opened on CPU-only hosts; the model is moved to CUDA below if possible.
    # NOTE(review): torch.load unpickles arbitrary objects — only use it on
    # trusted checkpoint files.
    w2v = torch.load(ctc_model_path, map_location="cpu")
    w2v["args"].w2v_path = w2v_model_path
    model = Wav2VecCtc.build_model(w2v["args"], target_dict)
    model.load_state_dict(w2v["model"], strict=True)
    if torch.cuda.is_available():
        model = model.cuda()
    return model
class Wav2Vec2ASR(object):
    """Speech-to-text wrapper: loads the model once, then transcribes audio bytes."""

    def __init__(self, ctc_path, w2v_path, target_dict_path):
        """Load the dictionary and model, switch to eval mode, build the decoder."""
        super().__init__()
        self.target_dict = Dictionary.load(target_dict_path)
        self.model = load_model(ctc_path, w2v_path, self.target_dict)
        self.model.eval()
        self.generator = W2lViterbiDecoder(self.target_dict)

    def transcribe(self, audio_data, greedy=True):
        """Transcribe one utterance given as encoded audio bytes.

        Args:
            audio_data: bytes of an audio file readable by ``get_feature``.
            greedy: accepted but not used by this implementation.

        Returns:
            The decoded text after "letter" post-processing.
        """
        # Add a leading axis of size one to form a single-item batch.
        source = get_feature(BytesIO(audio_data)).unsqueeze(0)
        # Nothing is padded in a one-item batch, so the mask is all False.
        padding_mask = torch.BoolTensor(source.size(1)).fill_(False).unsqueeze(0)
        if torch.cuda.is_available():
            padding_mask = padding_mask.cuda()
        sample = {"net_input": {"source": source, "padding_mask": padding_mask}}
        with torch.no_grad():
            hypo = self.generator.generate(self.model, sample, prefix_tokens=None)
        # First (only) hypothesis of the first (only) batch item.
        best_tokens = hypo[0][0]["tokens"].int().cpu()
        hyp_pieces = self.target_dict.string(best_tokens)
        return post_process(hyp_pieces, "letter")

View File

@@ -0,0 +1,86 @@
from pathlib import Path
from collections import Counter
import shutil
import soundfile
# import pydub
import typer
from tqdm import tqdm
from plume.utils import (
ExtendedPath,
replace_redundant_spaces_with,
lazy_module
)
# pydub is obtained through plume's lazy_module helper rather than a direct
# import (presumably deferring the import until first use — confirm in plume.utils).
pydub = lazy_module('pydub')
# Typer application that the @app.command() functions below register with.
app = typer.Typer()
@app.command()
def export_jasper(src_dataset_path: Path, dest_dataset_path: Path, unlink: bool = True):
    """Export a manifest-based dataset to .tsv/.ltr files plus a letter dictionary.

    For the "train" and "valid" splits this reads ``<split>_manifest.json``
    from ``src_dataset_path`` and writes ``<split>.tsv`` (root directory on the
    first line, then ``relative_path<TAB>frame_count`` rows) and
    ``<split>.ltr`` (space-separated upper-case characters, with "|" standing
    in for spaces) into ``dest_dataset_path``. ``dict.ltr.txt`` receives the
    character counts, most frequent first.

    Args:
        src_dataset_path: dataset root with the manifests and ``wavs/``.
        dest_dataset_path: output directory (created if missing).
        unlink: when true, every source wav is re-exported as 16 kHz mono
            into ``dest_dataset_path/wavs`` and the .tsv root points at the
            destination; otherwise the root points at the source.
    """
    dict_ltr = dest_dataset_path / Path("dict.ltr.txt")
    dest_wavs = dest_dataset_path / Path("wavs")
    dest_wavs.mkdir(exist_ok=True, parents=True)
    tok_counter = Counter()
    # The test manifest is temporarily copied to serve as the validation
    # manifest; the copy is removed again when the export finishes.
    tmp_valid_manifest = src_dataset_path / Path("valid_manifest.json")
    shutil.copy(src_dataset_path / Path("test_manifest.json"), tmp_valid_manifest)
    try:
        if unlink:
            src_wavs = src_dataset_path / Path("wavs")
            for wav_path in tqdm(list(src_wavs.glob("**/*.wav"))):
                audio_seg = (
                    pydub.AudioSegment.from_wav(wav_path)
                    .set_frame_rate(16000)
                    .set_channels(1)
                )
                # NOTE(review): files are flattened to wavs/<name>, while the
                # .tsv rows below stay relative to the source root — confirm
                # that the source wavs are never nested.
                audio_seg.export(dest_wavs / Path(wav_path.name), format="wav")
        tsv_root = dest_dataset_path if unlink else src_dataset_path
        for dataset_kind in ["train", "valid"]:
            abs_manifest_path = ExtendedPath(
                src_dataset_path / Path(f"{dataset_kind}_manifest.json")
            )
            manifest_data = list(abs_manifest_path.read_jsonl())
            out_tsv = dest_dataset_path / Path(f"{dataset_kind}.tsv")
            out_ltr = dest_dataset_path / Path(f"{dataset_kind}.ltr")
            with out_tsv.open("w") as tsv_f, out_ltr.open("w") as ltr_f:
                tsv_f.write(f"{tsv_root}\n")
                for md in manifest_data:
                    audio_fname = md["audio_filepath"]
                    pipe_toks = replace_redundant_spaces_with(md["text"], "|").upper()
                    # Counter.update on a string counts individual characters.
                    tok_counter.update(pipe_toks)
                    letter_toks = " ".join(pipe_toks) + " |\n"
                    frame_count = soundfile.info(audio_fname).frames
                    rel_path = Path(audio_fname).relative_to(
                        src_dataset_path.absolute()
                    )
                    ltr_f.write(letter_toks)
                    tsv_f.write(f"{rel_path}\t{frame_count}\n")
        with dict_ltr.open("w") as d_f:
            for k, v in tok_counter.most_common():
                d_f.write(f"{k} {v}\n")
    finally:
        # Always remove the temporary validation manifest, even when the
        # export fails part-way, so the source dataset is left as it was found.
        tmp_valid_manifest.unlink()
@app.command()
def set_root(dataset_path: Path, root_path: Path):
    """Rewrite the first line of train.tsv and valid.tsv to ``root_path``."""
    header = str(root_path) + "\n"
    for split in ("train", "valid"):
        tsv_path = dataset_path / Path(split).with_suffix(".tsv")
        with tsv_path.open("r") as reader:
            rows = reader.readlines()
        with tsv_path.open("w") as writer:
            # Only the first (root directory) line changes; the rest is
            # written back untouched.
            rows[0] = header
            writer.writelines(rows)
def main():
    """Run the Typer application defined in this module."""
    app()


# Allow running this module directly as a script.
if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,49 @@
from pathlib import Path
import typer
from tqdm import tqdm
# import pandas as pd
from plume.utils import (
asr_manifest_reader,
discard_except_digits,
replace_digit_symbol,
lazy_module
# run_shell,
)
from ...utils.transcribe import triton_transcribe_grpc_gen
# pandas is obtained through plume's lazy_module helper rather than a direct
# import (presumably deferring the import until first use — confirm in plume.utils).
pd = lazy_module('pandas')
# Typer application that the @app.command() functions below register with.
app = typer.Typer()
@app.command()
def manifest(manifest_file: Path, result_file: Path = "results.csv"):
    """Transcribe every manifest entry and write a digit-match report as CSV.

    The CSV is written next to ``manifest_file`` under the name
    ``result_file``. Each row holds the audio path, the original and ASR
    text, the digits extracted from both, and whether those digits match.
    """
    from pydub import AudioSegment

    # The transcription service location is fixed to a local endpoint.
    host, port = "localhost", 8044
    transcriber, audio_prep = triton_transcribe_grpc_gen(host, port, method="whole")
    result_path = manifest_file.parent / result_file
    manifest_list = list(asr_manifest_reader(manifest_file))

    def digits_of(text):
        # Reduce a text to its digits after symbol replacement.
        return discard_except_digits(replace_digit_symbol(text))

    def compute_frame(d):
        # Build one report row for a single manifest entry.
        audio_file = d["audio_path"]
        orig_text = d["text"]
        orig_num = digits_of(orig_text)
        asr_text = transcriber(audio_prep(AudioSegment.from_file(audio_file)))
        asr_num = digits_of(asr_text)
        return {
            "audio_file": audio_file,
            "asr_text": asr_text,
            "asr_num": asr_num,
            "orig_text": orig_text,
            "orig_num": orig_num,
            "asr_match": orig_num == asr_num,
        }

    rows = [compute_frame(entry) for entry in tqdm(manifest_list)]
    pd.DataFrame(rows).to_csv(result_path)

View File

@@ -0,0 +1,53 @@
import os
import logging
from pathlib import Path
# from rpyc.utils.server import ThreadedServer
import typer
from ...utils.serve import ASRService
from plume.utils import lazy_callable
# from .asr import Wav2Vec2ASR
# Both callables are referenced by dotted path through plume's lazy_callable
# helper instead of being imported directly (presumably so rpyc and the model
# stack are only imported when actually called — confirm in plume.utils).
ThreadedServer = lazy_callable('rpyc.utils.server.ThreadedServer')
Wav2Vec2ASR = lazy_callable('plume.models.wav2vec2.asr.Wav2Vec2ASR')
# Typer application that the @app.command() functions below register with.
app = typer.Typer()
@app.command()
def rpyc(
    w2v_path: Path = "/path/to/base.pt",
    ctc_path: Path = "/path/to/ctc.pt",
    target_dict_path: Path = "/path/to/dict.ltr.txt",
    port: int = int(os.environ.get("ASR_RPYC_PORT", "8044")),
):
    """Start an RPyC server exposing a Wav2Vec2ASR instance.

    Args:
        w2v_path: path to the base wav2vec checkpoint.
        ctc_path: path to the fine-tuned CTC checkpoint.
        target_dict_path: path to the letter dictionary.
        port: TCP port to listen on (default taken from ``ASR_RPYC_PORT``).

    Returns early, after logging which file is missing, when any of the
    three paths does not exist.
    """
    # Configure logging before anything is logged. Previously the
    # missing-file message below was emitted before basicConfig ran, so it was
    # dropped at the root logger's default WARNING level.
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )
    for p in [w2v_path, ctc_path, target_dict_path]:
        # Path(...) also covers the plain-string defaults when this function
        # is called directly instead of through the CLI.
        if not Path(p).exists():
            logging.info(f"{p} doesn't exists")
            return
    w2vasr = Wav2Vec2ASR(str(ctc_path), str(w2v_path), str(target_dict_path))
    service = ASRService(w2vasr)
    logging.info("starting asr server...")
    t = ThreadedServer(service, port=port)
    t.start()
@app.command()
def rpyc_dir(model_dir: Path, port: int = int(os.environ.get("ASR_RPYC_PORT", "8044"))):
    """Start the RPyC server from a directory with conventionally named files.

    ``model_dir`` must contain ``ctc.pt``, ``base.pt`` and ``dict.ltr.txt``;
    they are handed to ``rpyc`` together with ``port``.
    """
    rpyc(
        w2v_path=model_dir / Path("base.pt"),
        ctc_path=model_dir / Path("ctc.pt"),
        target_dict_path=model_dir / Path("dict.ltr.txt"),
        port=port,
    )
def main():
    """Run the Typer application defined in this module."""
    app()


# Allow running this module directly as a script.
if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,34 @@
import typer
# from fairseq_cli.train import cli_main
import sys
from pathlib import Path
import shlex
from plume.utils import lazy_callable
# fairseq's training entry point is referenced by dotted path through plume's
# lazy_callable helper instead of being imported directly (presumably so
# fairseq is only imported when training is actually started — confirm).
cli_main = lazy_callable('fairseq_cli.train.cli_main')
# Typer application that the @app.command() functions below register with.
app = typer.Typer()
@app.command()
def local(dataset_path: Path):
    """Run fairseq training on ``dataset_path`` with a fixed CTC fine-tuning recipe.

    The command line is assembled as a string, split with ``shlex`` and
    installed as ``sys.argv`` before calling fairseq's ``cli_main``. Save
    directory and pretrained-model path are hard-coded in the recipe.
    """
    # Quote the dataset path so that spaces or shell metacharacters in it
    # survive shlex.split below as a single argument.
    quoted_dataset = shlex.quote(str(dataset_path))
    args = f'''--distributed-world-size 1 {quoted_dataset} \
--save-dir /dataset/wav2vec2/model/wav2vec2_l_num_ctc_v1 --post-process letter --valid-subset \
valid --no-epoch-checkpoints --best-checkpoint-metric wer --num-workers 4 --max-update 80000 \
--sentence-avg --task audio_pretraining --arch wav2vec_ctc --w2v-path /dataset/wav2vec2/pretrained/wav2vec_vox_new.pt \
--labels ltr --apply-mask --mask-selection static --mask-other 0 --mask-length 10 --mask-prob 0.5 --layerdrop 0.1 \
--mask-channel-selection static --mask-channel-other 0 --mask-channel-length 64 --mask-channel-prob 0.5 \
--zero-infinity --feature-grad-mult 0.0 --freeze-finetune-updates 10000 --validate-after-updates 10000 \
--optimizer adam --adam-betas '(0.9, 0.98)' --adam-eps 1e-08 --lr 2e-05 --lr-scheduler tri_stage --warmup-steps 8000 \
--hold-steps 32000 --decay-steps 40000 --final-lr-scale 0.05 --final-dropout 0.0 --dropout 0.0 \
--activation-dropout 0.1 --criterion ctc --attention-dropout 0.0 --max-tokens 1280000 --seed 2337 --log-format json \
--log-interval 500 --ddp-backend no_c10d --reset-optimizer --normalize
'''
    # fairseq's entry point reads its options from sys.argv; the first
    # element stands in for the program name.
    new_args = ['train.py']
    new_args.extend(shlex.split(args))
    sys.argv = new_args
    cli_main()
# Running this module as a script hands the unmodified command line directly
# to fairseq's training entry point.
# NOTE(review): this bypasses the Typer ``app`` (and therefore the ``local``
# command); the sibling modules call ``app()`` here — confirm this is intended.
if __name__ == "__main__":
    cli_main()