mirror of
https://github.com/malarinv/plume-asr.git
synced 2026-03-08 04:12:35 +00:00
1. refactor package root to src/ layout
2. add framwork suffix for models 3. change black max columns to 79 4. add tests 5. integrate vad, encrypt and refactor manifest, regentity, extended_path, audio, parallel utils 6. added ui utils for encrypted preview 7. wip marblenet model 8. added transformers based wav2vec2 inference 9. update readme and manifest 10. add deploy setup target
This commit is contained in:
1
src/plume/models/__init__.py
Normal file
1
src/plume/models/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# from . import jasper, wav2vec2, matchboxnet
|
||||
0
src/plume/models/jasper_nemo/__init__.py
Normal file
0
src/plume/models/jasper_nemo/__init__.py
Normal file
132
src/plume/models/jasper_nemo/asr.py
Normal file
132
src/plume/models/jasper_nemo/asr.py
Normal file
@@ -0,0 +1,132 @@
|
||||
import os
|
||||
import tempfile
|
||||
from ruamel.yaml import YAML
|
||||
import json
|
||||
import nemo
|
||||
import nemo.collections.asr as nemo_asr
|
||||
import wave
|
||||
from nemo.collections.asr.helpers import post_process_predictions
|
||||
|
||||
logging = nemo.logging
|
||||
|
||||
WORK_DIR = "/tmp"
|
||||
|
||||
|
||||
class JasperASR(object):
|
||||
"""docstring for JasperASR."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_yaml,
|
||||
encoder_checkpoint,
|
||||
decoder_checkpoint,
|
||||
language_model=None,
|
||||
):
|
||||
super(JasperASR, self).__init__()
|
||||
# Read model YAML
|
||||
yaml = YAML(typ="safe")
|
||||
with open(model_yaml) as f:
|
||||
jasper_model_definition = yaml.load(f)
|
||||
self.neural_factory = nemo.core.NeuralModuleFactory(
|
||||
placement=nemo.core.DeviceType.GPU,
|
||||
backend=nemo.core.Backend.PyTorch,
|
||||
)
|
||||
self.labels = jasper_model_definition["labels"]
|
||||
self.data_preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor()
|
||||
self.jasper_encoder = nemo_asr.JasperEncoder(
|
||||
jasper=jasper_model_definition["JasperEncoder"]["jasper"],
|
||||
activation=jasper_model_definition["JasperEncoder"]["activation"],
|
||||
feat_in=jasper_model_definition[
|
||||
"AudioToMelSpectrogramPreprocessor"
|
||||
]["features"],
|
||||
)
|
||||
self.jasper_encoder.restore_from(encoder_checkpoint, local_rank=0)
|
||||
self.jasper_decoder = nemo_asr.JasperDecoderForCTC(
|
||||
feat_in=1024, num_classes=len(self.labels)
|
||||
)
|
||||
self.jasper_decoder.restore_from(decoder_checkpoint, local_rank=0)
|
||||
self.greedy_decoder = nemo_asr.GreedyCTCDecoder()
|
||||
self.beam_search_with_lm = None
|
||||
if language_model:
|
||||
self.beam_search_with_lm = nemo_asr.BeamSearchDecoderWithLM(
|
||||
vocab=self.labels,
|
||||
beam_width=64,
|
||||
alpha=2.0,
|
||||
beta=1.0,
|
||||
lm_path=language_model,
|
||||
num_cpus=max(os.cpu_count(), 1),
|
||||
)
|
||||
|
||||
def transcribe(self, audio_data, greedy=True):
|
||||
audio_file = tempfile.NamedTemporaryFile(
|
||||
dir=WORK_DIR, prefix="jasper_audio.", delete=False
|
||||
)
|
||||
# audio_file.write(audio_data)
|
||||
audio_file.close()
|
||||
audio_file_path = audio_file.name
|
||||
wf = wave.open(audio_file_path, "w")
|
||||
wf.setnchannels(1)
|
||||
wf.setsampwidth(2)
|
||||
wf.setframerate(24000)
|
||||
wf.writeframesraw(audio_data)
|
||||
wf.close()
|
||||
manifest = {
|
||||
"audio_filepath": audio_file_path,
|
||||
"duration": 60,
|
||||
"text": "todo",
|
||||
}
|
||||
manifest_file = tempfile.NamedTemporaryFile(
|
||||
dir=WORK_DIR, prefix="jasper_manifest.", delete=False, mode="w"
|
||||
)
|
||||
manifest_file.write(json.dumps(manifest))
|
||||
manifest_file.close()
|
||||
manifest_file_path = manifest_file.name
|
||||
data_layer = nemo_asr.AudioToTextDataLayer(
|
||||
shuffle=False,
|
||||
manifest_filepath=manifest_file_path,
|
||||
labels=self.labels,
|
||||
batch_size=1,
|
||||
)
|
||||
|
||||
# Define inference DAG
|
||||
audio_signal, audio_signal_len, _, _ = data_layer()
|
||||
processed_signal, processed_signal_len = self.data_preprocessor(
|
||||
input_signal=audio_signal, length=audio_signal_len
|
||||
)
|
||||
encoded, encoded_len = self.jasper_encoder(
|
||||
audio_signal=processed_signal, length=processed_signal_len
|
||||
)
|
||||
log_probs = self.jasper_decoder(encoder_output=encoded)
|
||||
predictions = self.greedy_decoder(log_probs=log_probs)
|
||||
|
||||
if greedy:
|
||||
eval_tensors = [predictions]
|
||||
else:
|
||||
if self.beam_search_with_lm:
|
||||
logging.info("Running with beam search")
|
||||
beam_predictions = self.beam_search_with_lm(
|
||||
log_probs=log_probs, log_probs_length=encoded_len
|
||||
)
|
||||
eval_tensors = [beam_predictions]
|
||||
else:
|
||||
logging.info(
|
||||
"language_model not specified. falling back to greedy decoding."
|
||||
)
|
||||
eval_tensors = [predictions]
|
||||
|
||||
tensors = self.neural_factory.infer(tensors=eval_tensors)
|
||||
prediction = post_process_predictions(tensors[0], self.labels)
|
||||
prediction_text = ". ".join(prediction)
|
||||
os.unlink(manifest_file.name)
|
||||
os.unlink(audio_file.name)
|
||||
return prediction_text
|
||||
|
||||
def transcribe_file(self, audio_file, *args, **kwargs):
|
||||
tscript_file_path = audio_file.with_suffix(".txt")
|
||||
audio_file_path = str(audio_file)
|
||||
with wave.open(audio_file_path, "r") as af:
|
||||
frame_count = af.getnframes()
|
||||
audio_data = af.readframes(frame_count)
|
||||
transcription = self.transcribe(audio_data, *args, **kwargs)
|
||||
with open(tscript_file_path, "w") as tf:
|
||||
tf.write(transcription)
|
||||
24
src/plume/models/jasper_nemo/data.py
Normal file
24
src/plume/models/jasper_nemo/data.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from pathlib import Path
|
||||
import typer
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
|
||||
def set_root(dataset_path: Path, root_path: Path):
|
||||
pass
|
||||
# for dataset_kind in ["train", "valid"]:
|
||||
# data_file = dataset_path / Path(dataset_kind).with_suffix(".tsv")
|
||||
# with data_file.open("r") as df:
|
||||
# lines = df.readlines()
|
||||
# with data_file.open("w") as df:
|
||||
# lines[0] = str(root_path) + "\n"
|
||||
# df.writelines(lines)
|
||||
|
||||
|
||||
def main():
|
||||
app()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
340
src/plume/models/jasper_nemo/data_loaders.py
Normal file
340
src/plume/models/jasper_nemo/data_loaders.py
Normal file
@@ -0,0 +1,340 @@
|
||||
from functools import partial
|
||||
import tempfile
|
||||
|
||||
# from typing import Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
import nemo
|
||||
|
||||
# import nemo.collections.asr as nemo_asr
|
||||
from nemo.backends.pytorch import DataLayerNM
|
||||
from nemo.core import DeviceType
|
||||
|
||||
# from nemo.core.neural_types import *
|
||||
from nemo.core.neural_types import (
|
||||
NeuralType,
|
||||
AudioSignal,
|
||||
LengthsType,
|
||||
LabelsType,
|
||||
)
|
||||
from nemo.utils.decorators import add_port_docs
|
||||
|
||||
from nemo.collections.asr.parts.dataset import (
|
||||
# AudioDataset,
|
||||
# AudioLabelDataset,
|
||||
# KaldiFeatureDataset,
|
||||
# TranscriptDataset,
|
||||
parsers,
|
||||
collections,
|
||||
seq_collate_fn,
|
||||
)
|
||||
|
||||
# from functools import lru_cache
|
||||
import rpyc
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from tqdm import tqdm
|
||||
from .featurizer import RpycWaveformFeaturizer
|
||||
|
||||
# from nemo.collections.asr.parts.features import WaveformFeaturizer
|
||||
|
||||
# from nemo.collections.asr.parts.perturb import AudioAugmentor, perturbation_types
|
||||
|
||||
|
||||
logging = nemo.logging
|
||||
|
||||
|
||||
class CachedAudioDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
Dataset that loads tensors via a json file containing paths to audio
|
||||
files, transcripts, and durations (in seconds). Each new line is a
|
||||
different sample. Example below:
|
||||
|
||||
{"audio_filepath": "/path/to/audio.wav", "text_filepath":
|
||||
"/path/to/audio.txt", "duration": 23.147}
|
||||
...
|
||||
{"audio_filepath": "/path/to/audio.wav", "text": "the
|
||||
transcription", offset": 301.75, "duration": 0.82, "utt":
|
||||
"utterance_id", "ctm_utt": "en_4156", "side": "A"}
|
||||
|
||||
Args:
|
||||
manifest_filepath: Path to manifest json as described above. Can
|
||||
be comma-separated paths.
|
||||
labels: String containing all the possible characters to map to
|
||||
featurizer: Initialized featurizer class that converts paths of
|
||||
audio to feature tensors
|
||||
max_duration: If audio exceeds this length, do not include in dataset
|
||||
min_duration: If audio is less than this length, do not include
|
||||
in dataset
|
||||
max_utts: Limit number of utterances
|
||||
blank_index: blank character index, default = -1
|
||||
unk_index: unk_character index, default = -1
|
||||
normalize: whether to normalize transcript text (default): True
|
||||
bos_id: Id of beginning of sequence symbol to append if not None
|
||||
eos_id: Id of end of sequence symbol to append if not None
|
||||
load_audio: Boolean flag indicate whether do or not load audio
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
manifest_filepath,
|
||||
labels,
|
||||
featurizer,
|
||||
max_duration=None,
|
||||
min_duration=None,
|
||||
max_utts=0,
|
||||
blank_index=-1,
|
||||
unk_index=-1,
|
||||
normalize=True,
|
||||
trim=False,
|
||||
bos_id=None,
|
||||
eos_id=None,
|
||||
load_audio=True,
|
||||
parser="en",
|
||||
):
|
||||
self.collection = collections.ASRAudioText(
|
||||
manifests_files=manifest_filepath.split(","),
|
||||
parser=parsers.make_parser(
|
||||
labels=labels,
|
||||
name=parser,
|
||||
unk_id=unk_index,
|
||||
blank_id=blank_index,
|
||||
do_normalize=normalize,
|
||||
),
|
||||
min_duration=min_duration,
|
||||
max_duration=max_duration,
|
||||
max_number=max_utts,
|
||||
)
|
||||
self.index_feature_map = {}
|
||||
|
||||
self.featurizer = featurizer
|
||||
self.trim = trim
|
||||
self.eos_id = eos_id
|
||||
self.bos_id = bos_id
|
||||
self.load_audio = load_audio
|
||||
print(f"initializing dataset {manifest_filepath}")
|
||||
|
||||
def exec_func(i):
|
||||
return self[i]
|
||||
|
||||
task_count = len(self.collection)
|
||||
with ThreadPoolExecutor() as exe:
|
||||
print("starting all loading tasks")
|
||||
list(
|
||||
tqdm(
|
||||
exe.map(exec_func, range(task_count)),
|
||||
position=0,
|
||||
leave=True,
|
||||
total=task_count,
|
||||
)
|
||||
)
|
||||
print(f"initializing complete")
|
||||
|
||||
def __getitem__(self, index):
|
||||
sample = self.collection[index]
|
||||
if self.load_audio:
|
||||
cached_features = self.index_feature_map.get(index)
|
||||
if cached_features is not None:
|
||||
features = cached_features
|
||||
else:
|
||||
features = self.featurizer.process(
|
||||
sample.audio_file,
|
||||
offset=0,
|
||||
duration=sample.duration,
|
||||
trim=self.trim,
|
||||
)
|
||||
self.index_feature_map[index] = features
|
||||
f, fl = features, torch.tensor(features.shape[0]).long()
|
||||
else:
|
||||
f, fl = None, None
|
||||
|
||||
t, tl = sample.text_tokens, len(sample.text_tokens)
|
||||
if self.bos_id is not None:
|
||||
t = [self.bos_id] + t
|
||||
tl += 1
|
||||
if self.eos_id is not None:
|
||||
t = t + [self.eos_id]
|
||||
tl += 1
|
||||
|
||||
return f, fl, torch.tensor(t).long(), torch.tensor(tl).long()
|
||||
|
||||
def __len__(self):
|
||||
return len(self.collection)
|
||||
|
||||
|
||||
class RpycAudioToTextDataLayer(DataLayerNM):
|
||||
"""Data Layer for general ASR tasks.
|
||||
|
||||
Module which reads ASR labeled data. It accepts comma-separated
|
||||
JSON manifest files describing the correspondence between wav audio files
|
||||
and their transcripts. JSON files should be of the following format::
|
||||
|
||||
{"audio_filepath": path_to_wav_0, "duration": time_in_sec_0, "text": \
|
||||
transcript_0}
|
||||
...
|
||||
{"audio_filepath": path_to_wav_n, "duration": time_in_sec_n, "text": \
|
||||
transcript_n}
|
||||
|
||||
Args:
|
||||
manifest_filepath (str): Dataset parameter.
|
||||
Path to JSON containing data.
|
||||
labels (list): Dataset parameter.
|
||||
List of characters that can be output by the ASR model.
|
||||
For Jasper, this is the 28 character set {a-z '}. The CTC blank
|
||||
symbol is automatically added later for models using ctc.
|
||||
batch_size (int): batch size
|
||||
sample_rate (int): Target sampling rate for data. Audio files will be
|
||||
resampled to sample_rate if it is not already.
|
||||
Defaults to 16000.
|
||||
int_values (bool): Bool indicating whether the audio file is saved as
|
||||
int data or float data.
|
||||
Defaults to False.
|
||||
eos_id (id): Dataset parameter.
|
||||
End of string symbol id used for seq2seq models.
|
||||
Defaults to None.
|
||||
min_duration (float): Dataset parameter.
|
||||
All training files which have a duration less than min_duration
|
||||
are dropped. Note: Duration is read from the manifest JSON.
|
||||
Defaults to 0.1.
|
||||
max_duration (float): Dataset parameter.
|
||||
All training files which have a duration more than max_duration
|
||||
are dropped. Note: Duration is read from the manifest JSON.
|
||||
Defaults to None.
|
||||
normalize_transcripts (bool): Dataset parameter.
|
||||
Whether to use automatic text cleaning.
|
||||
It is highly recommended to manually clean text for best results.
|
||||
Defaults to True.
|
||||
trim_silence (bool): Whether to use trim silence from beginning and end
|
||||
of audio signal using librosa.effects.trim().
|
||||
Defaults to False.
|
||||
load_audio (bool): Dataset parameter.
|
||||
Controls whether the dataloader loads the audio signal and
|
||||
transcript or just the transcript.
|
||||
Defaults to True.
|
||||
drop_last (bool): See PyTorch DataLoader.
|
||||
Defaults to False.
|
||||
shuffle (bool): See PyTorch DataLoader.
|
||||
Defaults to True.
|
||||
num_workers (int): See PyTorch DataLoader.
|
||||
Defaults to 0.
|
||||
perturb_config (dict): Currently disabled.
|
||||
"""
|
||||
|
||||
@property
|
||||
@add_port_docs()
|
||||
def output_ports(self):
|
||||
"""Returns definitions of module output ports."""
|
||||
return {
|
||||
# 'audio_signal': NeuralType({0: AxisType(BatchTag), 1: AxisType(TimeTag)}),
|
||||
# 'a_sig_length': NeuralType({0: AxisType(BatchTag)}),
|
||||
# 'transcripts': NeuralType({0: AxisType(BatchTag), 1: AxisType(TimeTag)}),
|
||||
# 'transcript_length': NeuralType({0: AxisType(BatchTag)}),
|
||||
"audio_signal": NeuralType(
|
||||
("B", "T"),
|
||||
AudioSignal(freq=self._sample_rate)
|
||||
if self is not None and self._sample_rate is not None
|
||||
else AudioSignal(),
|
||||
),
|
||||
"a_sig_length": NeuralType(tuple("B"), LengthsType()),
|
||||
"transcripts": NeuralType(("B", "T"), LabelsType()),
|
||||
"transcript_length": NeuralType(tuple("B"), LengthsType()),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
manifest_filepath,
|
||||
labels,
|
||||
batch_size,
|
||||
sample_rate=16000,
|
||||
int_values=False,
|
||||
bos_id=None,
|
||||
eos_id=None,
|
||||
pad_id=None,
|
||||
min_duration=0.1,
|
||||
max_duration=None,
|
||||
normalize_transcripts=True,
|
||||
trim_silence=False,
|
||||
load_audio=True,
|
||||
rpyc_host="",
|
||||
drop_last=False,
|
||||
shuffle=True,
|
||||
num_workers=0,
|
||||
):
|
||||
super().__init__()
|
||||
self._sample_rate = sample_rate
|
||||
|
||||
def rpyc_root_fn():
|
||||
return rpyc.connect(
|
||||
rpyc_host, 8064, config={"sync_request_timeout": 600}
|
||||
).root
|
||||
|
||||
rpyc_conn = rpyc_root_fn()
|
||||
|
||||
self._featurizer = RpycWaveformFeaturizer(
|
||||
sample_rate=self._sample_rate,
|
||||
int_values=int_values,
|
||||
augmentor=None,
|
||||
rpyc_conn=rpyc_conn,
|
||||
)
|
||||
|
||||
def read_remote_manifests():
|
||||
local_mp = []
|
||||
for mrp in manifest_filepath.split(","):
|
||||
md = rpyc_conn.read_path(mrp)
|
||||
mf = tempfile.NamedTemporaryFile(
|
||||
dir="/tmp", prefix="jasper_manifest.", delete=False
|
||||
)
|
||||
mf.write(md)
|
||||
mf.close()
|
||||
local_mp.append(mf.name)
|
||||
return ",".join(local_mp)
|
||||
|
||||
local_manifest_filepath = read_remote_manifests()
|
||||
dataset_params = {
|
||||
"manifest_filepath": local_manifest_filepath,
|
||||
"labels": labels,
|
||||
"featurizer": self._featurizer,
|
||||
"max_duration": max_duration,
|
||||
"min_duration": min_duration,
|
||||
"normalize": normalize_transcripts,
|
||||
"trim": trim_silence,
|
||||
"bos_id": bos_id,
|
||||
"eos_id": eos_id,
|
||||
"load_audio": load_audio,
|
||||
}
|
||||
|
||||
self._dataset = CachedAudioDataset(**dataset_params)
|
||||
self._batch_size = batch_size
|
||||
|
||||
# Set up data loader
|
||||
if self._placement == DeviceType.AllGpu:
|
||||
logging.info("Parallelizing Datalayer.")
|
||||
sampler = torch.utils.data.distributed.DistributedSampler(
|
||||
self._dataset
|
||||
)
|
||||
else:
|
||||
sampler = None
|
||||
|
||||
if batch_size == -1:
|
||||
batch_size = len(self._dataset)
|
||||
|
||||
pad_id = 0 if pad_id is None else pad_id
|
||||
self._dataloader = torch.utils.data.DataLoader(
|
||||
dataset=self._dataset,
|
||||
batch_size=batch_size,
|
||||
collate_fn=partial(seq_collate_fn, token_pad_value=pad_id),
|
||||
drop_last=drop_last,
|
||||
shuffle=shuffle if sampler is None else False,
|
||||
sampler=sampler,
|
||||
num_workers=1,
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
return len(self._dataset)
|
||||
|
||||
@property
|
||||
def dataset(self):
|
||||
return None
|
||||
|
||||
@property
|
||||
def data_iterator(self):
|
||||
return self._dataloader
|
||||
376
src/plume/models/jasper_nemo/eval.py
Normal file
376
src/plume/models/jasper_nemo/eval.py
Normal file
@@ -0,0 +1,376 @@
|
||||
# Copyright (c) 2019 NVIDIA Corporation
|
||||
import argparse
|
||||
import copy
|
||||
|
||||
# import math
|
||||
import os
|
||||
from pathlib import Path
|
||||
from functools import partial
|
||||
|
||||
from ruamel.yaml import YAML
|
||||
|
||||
import nemo
|
||||
import nemo.collections.asr as nemo_asr
|
||||
import nemo.utils.argparse as nm_argparse
|
||||
from nemo.collections.asr.helpers import (
|
||||
# monitor_asr_train_progress,
|
||||
process_evaluation_batch,
|
||||
process_evaluation_epoch,
|
||||
)
|
||||
|
||||
# from nemo.utils.lr_policies import CosineAnnealing
|
||||
from training.data_loaders import RpycAudioToTextDataLayer
|
||||
|
||||
logging = nemo.logging
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
parents=[nm_argparse.NemoArgParser()],
|
||||
description="Jasper",
|
||||
conflict_handler="resolve",
|
||||
)
|
||||
parser.set_defaults(
|
||||
checkpoint_dir=None,
|
||||
optimizer="novograd",
|
||||
batch_size=64,
|
||||
eval_batch_size=64,
|
||||
lr=0.002,
|
||||
amp_opt_level="O1",
|
||||
create_tb_writer=True,
|
||||
model_config="./train/jasper10x5dr.yaml",
|
||||
work_dir="./train/work",
|
||||
num_epochs=300,
|
||||
weight_decay=0.005,
|
||||
checkpoint_save_freq=100,
|
||||
eval_freq=100,
|
||||
load_dir="./train/models/jasper/",
|
||||
warmup_steps=3,
|
||||
exp_name="jasper",
|
||||
)
|
||||
|
||||
# Overwrite default args
|
||||
parser.add_argument(
|
||||
"--max_steps",
|
||||
type=int,
|
||||
default=None,
|
||||
required=False,
|
||||
help="max number of steps to train",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_epochs",
|
||||
type=int,
|
||||
required=False,
|
||||
help="number of epochs to train",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_config",
|
||||
type=str,
|
||||
required=False,
|
||||
help="model configuration file: model.yaml",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--encoder_checkpoint",
|
||||
type=str,
|
||||
required=True,
|
||||
help="encoder checkpoint file: JasperEncoder.pt",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--decoder_checkpoint",
|
||||
type=str,
|
||||
required=True,
|
||||
help="decoder checkpoint file: JasperDecoderForCTC.pt",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--remote_data",
|
||||
type=str,
|
||||
required=False,
|
||||
default="",
|
||||
help="remote dataloader endpoint",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
type=str,
|
||||
required=False,
|
||||
default="",
|
||||
help="dataset directory containing train/test manifests",
|
||||
)
|
||||
|
||||
# Create new args
|
||||
parser.add_argument("--exp_name", default="Jasper", type=str)
|
||||
parser.add_argument("--beta1", default=0.95, type=float)
|
||||
parser.add_argument("--beta2", default=0.25, type=float)
|
||||
parser.add_argument("--warmup_steps", default=0, type=int)
|
||||
parser.add_argument(
|
||||
"--load_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
help="directory with pre-trained checkpoint",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
if args.max_steps is None and args.num_epochs is None:
|
||||
raise ValueError("Either max_steps or num_epochs should be provided.")
|
||||
return args
|
||||
|
||||
|
||||
def construct_name(
|
||||
name, lr, batch_size, max_steps, num_epochs, wd, optimizer, iter_per_step
|
||||
):
|
||||
if max_steps is not None:
|
||||
return "{0}-lr_{1}-bs_{2}-s_{3}-wd_{4}-opt_{5}-ips_{6}".format(
|
||||
name, lr, batch_size, max_steps, wd, optimizer, iter_per_step
|
||||
)
|
||||
else:
|
||||
return "{0}-lr_{1}-bs_{2}-e_{3}-wd_{4}-opt_{5}-ips_{6}".format(
|
||||
name, lr, batch_size, num_epochs, wd, optimizer, iter_per_step
|
||||
)
|
||||
|
||||
|
||||
def create_all_dags(args, neural_factory):
|
||||
yaml = YAML(typ="safe")
|
||||
with open(args.model_config) as f:
|
||||
jasper_params = yaml.load(f)
|
||||
vocab = jasper_params["labels"]
|
||||
sample_rate = jasper_params["sample_rate"]
|
||||
|
||||
# Calculate num_workers for dataloader
|
||||
total_cpus = os.cpu_count()
|
||||
cpu_per_traindl = max(int(total_cpus / neural_factory.world_size), 1)
|
||||
# perturb_config = jasper_params.get('perturb', None)
|
||||
train_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"])
|
||||
train_dl_params.update(jasper_params["AudioToTextDataLayer"]["train"])
|
||||
del train_dl_params["train"]
|
||||
del train_dl_params["eval"]
|
||||
# del train_dl_params["normalize_transcripts"]
|
||||
|
||||
if args.dataset:
|
||||
d_path = Path(args.dataset)
|
||||
if not args.train_dataset:
|
||||
args.train_dataset = str(d_path / Path("train_manifest.json"))
|
||||
if not args.eval_datasets:
|
||||
args.eval_datasets = [str(d_path / Path("test_manifest.json"))]
|
||||
|
||||
data_loader_layer = nemo_asr.AudioToTextDataLayer
|
||||
|
||||
if args.remote_data:
|
||||
train_dl_params["rpyc_host"] = args.remote_data
|
||||
data_loader_layer = RpycAudioToTextDataLayer
|
||||
|
||||
# data_layer = data_loader_layer(
|
||||
# manifest_filepath=args.train_dataset,
|
||||
# sample_rate=sample_rate,
|
||||
# labels=vocab,
|
||||
# batch_size=args.batch_size,
|
||||
# num_workers=cpu_per_traindl,
|
||||
# **train_dl_params,
|
||||
# # normalize_transcripts=False
|
||||
# )
|
||||
#
|
||||
# N = len(data_layer)
|
||||
# steps_per_epoch = math.ceil(
|
||||
# N / (args.batch_size * args.iter_per_step * args.num_gpus)
|
||||
# )
|
||||
# logging.info("Have {0} examples to train on.".format(N))
|
||||
#
|
||||
data_preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor(
|
||||
sample_rate=sample_rate,
|
||||
**jasper_params["AudioToMelSpectrogramPreprocessor"],
|
||||
)
|
||||
|
||||
# multiply_batch_config = jasper_params.get("MultiplyBatch", None)
|
||||
# if multiply_batch_config:
|
||||
# multiply_batch = nemo_asr.MultiplyBatch(**multiply_batch_config)
|
||||
#
|
||||
# spectr_augment_config = jasper_params.get("SpectrogramAugmentation", None)
|
||||
# if spectr_augment_config:
|
||||
# data_spectr_augmentation = nemo_asr.SpectrogramAugmentation(
|
||||
# **spectr_augment_config
|
||||
# )
|
||||
#
|
||||
eval_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"])
|
||||
eval_dl_params.update(jasper_params["AudioToTextDataLayer"]["eval"])
|
||||
if args.remote_data:
|
||||
eval_dl_params["rpyc_host"] = args.remote_data
|
||||
del eval_dl_params["train"]
|
||||
del eval_dl_params["eval"]
|
||||
data_layers_eval = []
|
||||
|
||||
# if args.eval_datasets:
|
||||
for eval_datasets in args.eval_datasets:
|
||||
data_layer_eval = data_loader_layer(
|
||||
manifest_filepath=eval_datasets,
|
||||
sample_rate=sample_rate,
|
||||
labels=vocab,
|
||||
batch_size=args.eval_batch_size,
|
||||
num_workers=cpu_per_traindl,
|
||||
**eval_dl_params,
|
||||
)
|
||||
|
||||
data_layers_eval.append(data_layer_eval)
|
||||
# else:
|
||||
# logging.warning("There were no val datasets passed")
|
||||
|
||||
jasper_encoder = nemo_asr.JasperEncoder(
|
||||
feat_in=jasper_params["AudioToMelSpectrogramPreprocessor"]["features"],
|
||||
**jasper_params["JasperEncoder"],
|
||||
)
|
||||
jasper_encoder.restore_from(args.encoder_checkpoint, local_rank=0)
|
||||
|
||||
jasper_decoder = nemo_asr.JasperDecoderForCTC(
|
||||
feat_in=jasper_params["JasperEncoder"]["jasper"][-1]["filters"],
|
||||
num_classes=len(vocab),
|
||||
)
|
||||
jasper_decoder.restore_from(args.decoder_checkpoint, local_rank=0)
|
||||
|
||||
ctc_loss = nemo_asr.CTCLossNM(num_classes=len(vocab))
|
||||
|
||||
greedy_decoder = nemo_asr.GreedyCTCDecoder()
|
||||
|
||||
# logging.info("================================")
|
||||
# logging.info(f"Number of parameters in encoder: {jasper_encoder.num_weights}")
|
||||
# logging.info(f"Number of parameters in decoder: {jasper_decoder.num_weights}")
|
||||
# logging.info(
|
||||
# f"Total number of parameters in model: "
|
||||
# f"{jasper_decoder.num_weights + jasper_encoder.num_weights}"
|
||||
# )
|
||||
# logging.info("================================")
|
||||
#
|
||||
# # Train DAG
|
||||
# (audio_signal_t, a_sig_length_t, transcript_t, transcript_len_t) = data_layer()
|
||||
# processed_signal_t, p_length_t = data_preprocessor(
|
||||
# input_signal=audio_signal_t, length=a_sig_length_t
|
||||
# )
|
||||
#
|
||||
# if multiply_batch_config:
|
||||
# (
|
||||
# processed_signal_t,
|
||||
# p_length_t,
|
||||
# transcript_t,
|
||||
# transcript_len_t,
|
||||
# ) = multiply_batch(
|
||||
# in_x=processed_signal_t,
|
||||
# in_x_len=p_length_t,
|
||||
# in_y=transcript_t,
|
||||
# in_y_len=transcript_len_t,
|
||||
# )
|
||||
#
|
||||
# if spectr_augment_config:
|
||||
# processed_signal_t = data_spectr_augmentation(input_spec=processed_signal_t)
|
||||
#
|
||||
# encoded_t, encoded_len_t = jasper_encoder(
|
||||
# audio_signal=processed_signal_t, length=p_length_t
|
||||
# )
|
||||
# log_probs_t = jasper_decoder(encoder_output=encoded_t)
|
||||
# predictions_t = greedy_decoder(log_probs=log_probs_t)
|
||||
# loss_t = ctc_loss(
|
||||
# log_probs=log_probs_t,
|
||||
# targets=transcript_t,
|
||||
# input_length=encoded_len_t,
|
||||
# target_length=transcript_len_t,
|
||||
# )
|
||||
#
|
||||
# # Callbacks needed to print info to console and Tensorboard
|
||||
# train_callback = nemo.core.SimpleLossLoggerCallback(
|
||||
# tensors=[loss_t, predictions_t, transcript_t, transcript_len_t],
|
||||
# print_func=partial(monitor_asr_train_progress, labels=vocab),
|
||||
# get_tb_values=lambda x: [("loss", x[0])],
|
||||
# tb_writer=neural_factory.tb_writer,
|
||||
# )
|
||||
#
|
||||
# chpt_callback = nemo.core.CheckpointCallback(
|
||||
# folder=neural_factory.checkpoint_dir,
|
||||
# load_from_folder=args.load_dir,
|
||||
# step_freq=args.checkpoint_save_freq,
|
||||
# checkpoints_to_keep=30,
|
||||
# )
|
||||
#
|
||||
# callbacks = [train_callback, chpt_callback]
|
||||
callbacks = []
|
||||
# assemble eval DAGs
|
||||
for i, eval_dl in enumerate(data_layers_eval):
|
||||
(
|
||||
audio_signal_e,
|
||||
a_sig_length_e,
|
||||
transcript_e,
|
||||
transcript_len_e,
|
||||
) = eval_dl()
|
||||
processed_signal_e, p_length_e = data_preprocessor(
|
||||
input_signal=audio_signal_e, length=a_sig_length_e
|
||||
)
|
||||
encoded_e, encoded_len_e = jasper_encoder(
|
||||
audio_signal=processed_signal_e, length=p_length_e
|
||||
)
|
||||
log_probs_e = jasper_decoder(encoder_output=encoded_e)
|
||||
predictions_e = greedy_decoder(log_probs=log_probs_e)
|
||||
loss_e = ctc_loss(
|
||||
log_probs=log_probs_e,
|
||||
targets=transcript_e,
|
||||
input_length=encoded_len_e,
|
||||
target_length=transcript_len_e,
|
||||
)
|
||||
|
||||
# create corresponding eval callback
|
||||
tagname = os.path.basename(args.eval_datasets[i]).split(".")[0]
|
||||
eval_callback = nemo.core.EvaluatorCallback(
|
||||
eval_tensors=[
|
||||
loss_e,
|
||||
predictions_e,
|
||||
transcript_e,
|
||||
transcript_len_e,
|
||||
],
|
||||
user_iter_callback=partial(process_evaluation_batch, labels=vocab),
|
||||
user_epochs_done_callback=partial(
|
||||
process_evaluation_epoch, tag=tagname
|
||||
),
|
||||
eval_step=args.eval_freq,
|
||||
tb_writer=neural_factory.tb_writer,
|
||||
)
|
||||
|
||||
callbacks.append(eval_callback)
|
||||
return callbacks
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
# name = construct_name(
|
||||
# args.exp_name,
|
||||
# args.lr,
|
||||
# args.batch_size,
|
||||
# args.max_steps,
|
||||
# args.num_epochs,
|
||||
# args.weight_decay,
|
||||
# args.optimizer,
|
||||
# args.iter_per_step,
|
||||
# )
|
||||
# log_dir = name
|
||||
# if args.work_dir:
|
||||
# log_dir = os.path.join(args.work_dir, name)
|
||||
|
||||
# instantiate Neural Factory with supported backend
|
||||
neural_factory = nemo.core.NeuralModuleFactory(
|
||||
placement=nemo.core.DeviceType.GPU,
|
||||
backend=nemo.core.Backend.PyTorch,
|
||||
# local_rank=args.local_rank,
|
||||
# optimization_level=args.amp_opt_level,
|
||||
# log_dir=log_dir,
|
||||
# checkpoint_dir=args.checkpoint_dir,
|
||||
# create_tb_writer=args.create_tb_writer,
|
||||
# files_to_copy=[args.model_config, __file__],
|
||||
# cudnn_benchmark=args.cudnn_benchmark,
|
||||
# tensorboard_dir=args.tensorboard_dir,
|
||||
)
|
||||
args.num_gpus = neural_factory.world_size
|
||||
|
||||
# checkpoint_dir = neural_factory.checkpoint_dir
|
||||
if args.local_rank is not None:
|
||||
logging.info("Doing ALL GPU")
|
||||
|
||||
# build dags
|
||||
callbacks = create_all_dags(args, neural_factory)
|
||||
# evaluate model
|
||||
neural_factory.eval(callbacks=callbacks)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
61
src/plume/models/jasper_nemo/featurizer.py
Normal file
61
src/plume/models/jasper_nemo/featurizer.py
Normal file
@@ -0,0 +1,61 @@
|
||||
# import math
|
||||
|
||||
# import librosa
|
||||
import torch
|
||||
import pickle
|
||||
|
||||
# import torch.nn as nn
|
||||
# from torch_stft import STFT
|
||||
|
||||
# from nemo import logging
|
||||
from nemo.collections.asr.parts.perturb import AudioAugmentor
|
||||
|
||||
# from nemo.collections.asr.parts.segment import AudioSegment
|
||||
|
||||
|
||||
class RpycWaveformFeaturizer(object):
|
||||
def __init__(
|
||||
self,
|
||||
sample_rate=16000,
|
||||
int_values=False,
|
||||
augmentor=None,
|
||||
rpyc_conn=None,
|
||||
):
|
||||
self.augmentor = (
|
||||
augmentor if augmentor is not None else AudioAugmentor()
|
||||
)
|
||||
self.sample_rate = sample_rate
|
||||
self.int_values = int_values
|
||||
self.remote_path_samples = rpyc_conn.get_path_samples
|
||||
|
||||
def max_augmentation_length(self, length):
|
||||
return self.augmentor.max_augmentation_length(length)
|
||||
|
||||
def process(self, file_path, offset=0, duration=0, trim=False):
|
||||
audio = self.remote_path_samples(
|
||||
file_path,
|
||||
target_sr=self.sample_rate,
|
||||
int_values=self.int_values,
|
||||
offset=offset,
|
||||
duration=duration,
|
||||
trim=trim,
|
||||
)
|
||||
return torch.tensor(pickle.loads(audio), dtype=torch.float)
|
||||
|
||||
def process_segment(self, audio_segment):
|
||||
self.augmentor.perturb(audio_segment)
|
||||
return torch.tensor(audio_segment, dtype=torch.float)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, input_config, perturbation_configs=None):
|
||||
if perturbation_configs is not None:
|
||||
aa = AudioAugmentor.from_config(perturbation_configs)
|
||||
else:
|
||||
aa = None
|
||||
|
||||
sample_rate = input_config.get("sample_rate", 16000)
|
||||
int_values = input_config.get("int_values", False)
|
||||
|
||||
return cls(
|
||||
sample_rate=sample_rate, int_values=int_values, augmentor=aa
|
||||
)
|
||||
54
src/plume/models/jasper_nemo/serve.py
Normal file
54
src/plume/models/jasper_nemo/serve.py
Normal file
@@ -0,0 +1,54 @@
|
||||
import os
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from rpyc.utils.server import ThreadedServer
|
||||
import typer
|
||||
|
||||
# from .asr import JasperASR
|
||||
from ...utils.serve import ASRService
|
||||
from plume.utils import lazy_callable
|
||||
|
||||
JasperASR = lazy_callable("plume.models.jasper_nemo.asr.JasperASR")
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
|
||||
def rpyc(
|
||||
encoder_path: Path = "/path/to/encoder.pt",
|
||||
decoder_path: Path = "/path/to/decoder.pt",
|
||||
model_yaml_path: Path = "/path/to/model.yaml",
|
||||
port: int = int(os.environ.get("ASR_RPYC_PORT", "8044")),
|
||||
):
|
||||
for p in [encoder_path, decoder_path, model_yaml_path]:
|
||||
if not p.exists():
|
||||
logging.info(f"{p} doesn't exists")
|
||||
return
|
||||
asr = JasperASR(str(model_yaml_path), str(encoder_path), str(decoder_path))
|
||||
service = ASRService(asr)
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
)
|
||||
logging.info("starting asr server...")
|
||||
t = ThreadedServer(service, port=port)
|
||||
t.start()
|
||||
|
||||
|
||||
@app.command()
|
||||
def rpyc_dir(
|
||||
model_dir: Path, port: int = int(os.environ.get("ASR_RPYC_PORT", "8044"))
|
||||
):
|
||||
encoder_path = model_dir / Path("decoder.pt")
|
||||
decoder_path = model_dir / Path("encoder.pt")
|
||||
model_yaml_path = model_dir / Path("model.yaml")
|
||||
rpyc(encoder_path, decoder_path, model_yaml_path, port)
|
||||
|
||||
|
||||
def main():
|
||||
app()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
59
src/plume/models/jasper_nemo/serve_data.py
Normal file
59
src/plume/models/jasper_nemo/serve_data.py
Normal file
@@ -0,0 +1,59 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import typer
|
||||
import rpyc
|
||||
from rpyc.utils.server import ThreadedServer
|
||||
import nemo
|
||||
import pickle
|
||||
|
||||
# import nemo.collections.asr as nemo_asr
|
||||
from nemo.collections.asr.parts.segment import AudioSegment
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
nemo.core.NeuralModuleFactory(
|
||||
backend=nemo.core.Backend.PyTorch, placement=nemo.core.DeviceType.CPU
|
||||
)
|
||||
|
||||
|
||||
class ASRDataService(rpyc.Service):
|
||||
def exposed_get_path_samples(
|
||||
self, file_path, target_sr, int_values, offset, duration, trim
|
||||
):
|
||||
print(f"loading.. {file_path}")
|
||||
audio = AudioSegment.from_file(
|
||||
file_path,
|
||||
target_sr=target_sr,
|
||||
int_values=int_values,
|
||||
offset=offset,
|
||||
duration=duration,
|
||||
trim=trim,
|
||||
)
|
||||
# print(f"returning.. {len(audio.samples)} items of type{type(audio.samples)}")
|
||||
return pickle.dumps(audio.samples)
|
||||
|
||||
def exposed_read_path(self, file_path):
|
||||
# print(f"reading path.. {file_path}")
|
||||
return Path(file_path).read_bytes()
|
||||
|
||||
|
||||
@app.command()
|
||||
def run_server(port: int = 0):
|
||||
listen_port = (
|
||||
port if port else int(os.environ.get("ASR_DARA_RPYC_PORT", "8064"))
|
||||
)
|
||||
service = ASRDataService()
|
||||
t = ThreadedServer(
|
||||
service, port=listen_port, protocol_config={"allow_all_attrs": True}
|
||||
)
|
||||
typer.echo(f"starting asr server on {listen_port}...")
|
||||
t.start()
|
||||
|
||||
|
||||
def main():
|
||||
app()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
392
src/plume/models/jasper_nemo/train.py
Normal file
392
src/plume/models/jasper_nemo/train.py
Normal file
@@ -0,0 +1,392 @@
|
||||
# Copyright (c) 2019 NVIDIA Corporation
|
||||
import argparse
|
||||
import copy
|
||||
import math
|
||||
import os
|
||||
from pathlib import Path
|
||||
from functools import partial
|
||||
|
||||
from ruamel.yaml import YAML
|
||||
|
||||
import nemo
|
||||
import nemo.collections.asr as nemo_asr
|
||||
import nemo.utils.argparse as nm_argparse
|
||||
from nemo.collections.asr.helpers import (
|
||||
monitor_asr_train_progress,
|
||||
process_evaluation_batch,
|
||||
process_evaluation_epoch,
|
||||
)
|
||||
|
||||
from nemo.utils.lr_policies import CosineAnnealing
|
||||
from .data_loaders import RpycAudioToTextDataLayer
|
||||
|
||||
logging = nemo.logging
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
parents=[nm_argparse.NemoArgParser()],
|
||||
description="Jasper",
|
||||
conflict_handler="resolve",
|
||||
)
|
||||
parser.set_defaults(
|
||||
checkpoint_dir=None,
|
||||
optimizer="novograd",
|
||||
batch_size=64,
|
||||
eval_batch_size=64,
|
||||
lr=0.002,
|
||||
amp_opt_level="O1",
|
||||
create_tb_writer=True,
|
||||
model_config="./train/jasper10x5dr.yaml",
|
||||
work_dir="./train/work",
|
||||
num_epochs=300,
|
||||
weight_decay=0.005,
|
||||
checkpoint_save_freq=100,
|
||||
eval_freq=100,
|
||||
load_dir="./train/models/jasper/",
|
||||
warmup_steps=3,
|
||||
exp_name="jasper-speller",
|
||||
)
|
||||
|
||||
# Overwrite default args
|
||||
parser.add_argument(
|
||||
"--max_steps",
|
||||
type=int,
|
||||
default=None,
|
||||
required=False,
|
||||
help="max number of steps to train",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_epochs",
|
||||
type=int,
|
||||
required=False,
|
||||
help="number of epochs to train",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_config",
|
||||
type=str,
|
||||
required=False,
|
||||
help="model configuration file: model.yaml",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--remote_data",
|
||||
type=str,
|
||||
required=False,
|
||||
default="",
|
||||
help="remote dataloader endpoint",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
type=str,
|
||||
required=False,
|
||||
default="",
|
||||
help="dataset directory containing train/test manifests",
|
||||
)
|
||||
|
||||
# Create new args
|
||||
parser.add_argument("--exp_name", default="Jasper", type=str)
|
||||
parser.add_argument("--beta1", default=0.95, type=float)
|
||||
parser.add_argument("--beta2", default=0.25, type=float)
|
||||
parser.add_argument("--warmup_steps", default=0, type=int)
|
||||
parser.add_argument(
|
||||
"--load_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
help="directory with pre-trained checkpoint",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
if args.max_steps is None and args.num_epochs is None:
|
||||
raise ValueError("Either max_steps or num_epochs should be provided.")
|
||||
return args
|
||||
|
||||
|
||||
def construct_name(
|
||||
name, lr, batch_size, max_steps, num_epochs, wd, optimizer, iter_per_step
|
||||
):
|
||||
if max_steps is not None:
|
||||
return "{0}-lr_{1}-bs_{2}-s_{3}-wd_{4}-opt_{5}-ips_{6}".format(
|
||||
name, lr, batch_size, max_steps, wd, optimizer, iter_per_step
|
||||
)
|
||||
else:
|
||||
return "{0}-lr_{1}-bs_{2}-e_{3}-wd_{4}-opt_{5}-ips_{6}".format(
|
||||
name, lr, batch_size, num_epochs, wd, optimizer, iter_per_step
|
||||
)
|
||||
|
||||
|
||||
def create_all_dags(args, neural_factory):
|
||||
yaml = YAML(typ="safe")
|
||||
with open(args.model_config) as f:
|
||||
jasper_params = yaml.load(f)
|
||||
vocab = jasper_params["labels"]
|
||||
sample_rate = jasper_params["sample_rate"]
|
||||
|
||||
# Calculate num_workers for dataloader
|
||||
total_cpus = os.cpu_count()
|
||||
cpu_per_traindl = max(int(total_cpus / neural_factory.world_size), 1)
|
||||
# perturb_config = jasper_params.get('perturb', None)
|
||||
train_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"])
|
||||
train_dl_params.update(jasper_params["AudioToTextDataLayer"]["train"])
|
||||
del train_dl_params["train"]
|
||||
del train_dl_params["eval"]
|
||||
# del train_dl_params["normalize_transcripts"]
|
||||
|
||||
if args.dataset:
|
||||
d_path = Path(args.dataset)
|
||||
if not args.train_dataset:
|
||||
args.train_dataset = str(d_path / Path("train_manifest.json"))
|
||||
if not args.eval_datasets:
|
||||
args.eval_datasets = [str(d_path / Path("test_manifest.json"))]
|
||||
|
||||
data_loader_layer = nemo_asr.AudioToTextDataLayer
|
||||
|
||||
if args.remote_data:
|
||||
train_dl_params["rpyc_host"] = args.remote_data
|
||||
data_loader_layer = RpycAudioToTextDataLayer
|
||||
|
||||
data_layer = data_loader_layer(
|
||||
manifest_filepath=args.train_dataset,
|
||||
sample_rate=sample_rate,
|
||||
labels=vocab,
|
||||
batch_size=args.batch_size,
|
||||
num_workers=cpu_per_traindl,
|
||||
**train_dl_params,
|
||||
# normalize_transcripts=False
|
||||
)
|
||||
|
||||
N = len(data_layer)
|
||||
steps_per_epoch = math.ceil(
|
||||
N / (args.batch_size * args.iter_per_step * args.num_gpus)
|
||||
)
|
||||
logging.info("Have {0} examples to train on.".format(N))
|
||||
|
||||
data_preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor(
|
||||
sample_rate=sample_rate,
|
||||
**jasper_params["AudioToMelSpectrogramPreprocessor"],
|
||||
)
|
||||
|
||||
multiply_batch_config = jasper_params.get("MultiplyBatch", None)
|
||||
if multiply_batch_config:
|
||||
multiply_batch = nemo_asr.MultiplyBatch(**multiply_batch_config)
|
||||
|
||||
spectr_augment_config = jasper_params.get("SpectrogramAugmentation", None)
|
||||
if spectr_augment_config:
|
||||
data_spectr_augmentation = nemo_asr.SpectrogramAugmentation(
|
||||
**spectr_augment_config
|
||||
)
|
||||
|
||||
eval_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"])
|
||||
eval_dl_params.update(jasper_params["AudioToTextDataLayer"]["eval"])
|
||||
if args.remote_data:
|
||||
eval_dl_params["rpyc_host"] = args.remote_data
|
||||
del eval_dl_params["train"]
|
||||
del eval_dl_params["eval"]
|
||||
data_layers_eval = []
|
||||
|
||||
if args.eval_datasets:
|
||||
for eval_datasets in args.eval_datasets:
|
||||
data_layer_eval = data_loader_layer(
|
||||
manifest_filepath=eval_datasets,
|
||||
sample_rate=sample_rate,
|
||||
labels=vocab,
|
||||
batch_size=args.eval_batch_size,
|
||||
num_workers=cpu_per_traindl,
|
||||
**eval_dl_params,
|
||||
)
|
||||
|
||||
data_layers_eval.append(data_layer_eval)
|
||||
else:
|
||||
logging.warning("There were no val datasets passed")
|
||||
|
||||
jasper_encoder = nemo_asr.JasperEncoder(
|
||||
feat_in=jasper_params["AudioToMelSpectrogramPreprocessor"]["features"],
|
||||
**jasper_params["JasperEncoder"],
|
||||
)
|
||||
|
||||
jasper_decoder = nemo_asr.JasperDecoderForCTC(
|
||||
feat_in=jasper_params["JasperEncoder"]["jasper"][-1]["filters"],
|
||||
num_classes=len(vocab),
|
||||
)
|
||||
|
||||
ctc_loss = nemo_asr.CTCLossNM(num_classes=len(vocab))
|
||||
|
||||
greedy_decoder = nemo_asr.GreedyCTCDecoder()
|
||||
|
||||
logging.info("================================")
|
||||
logging.info(
|
||||
f"Number of parameters in encoder: {jasper_encoder.num_weights}"
|
||||
)
|
||||
logging.info(
|
||||
f"Number of parameters in decoder: {jasper_decoder.num_weights}"
|
||||
)
|
||||
logging.info(
|
||||
f"Total number of parameters in model: "
|
||||
f"{jasper_decoder.num_weights + jasper_encoder.num_weights}"
|
||||
)
|
||||
logging.info("================================")
|
||||
|
||||
# Train DAG
|
||||
(
|
||||
audio_signal_t,
|
||||
a_sig_length_t,
|
||||
transcript_t,
|
||||
transcript_len_t,
|
||||
) = data_layer()
|
||||
processed_signal_t, p_length_t = data_preprocessor(
|
||||
input_signal=audio_signal_t, length=a_sig_length_t
|
||||
)
|
||||
|
||||
if multiply_batch_config:
|
||||
(
|
||||
processed_signal_t,
|
||||
p_length_t,
|
||||
transcript_t,
|
||||
transcript_len_t,
|
||||
) = multiply_batch(
|
||||
in_x=processed_signal_t,
|
||||
in_x_len=p_length_t,
|
||||
in_y=transcript_t,
|
||||
in_y_len=transcript_len_t,
|
||||
)
|
||||
|
||||
if spectr_augment_config:
|
||||
processed_signal_t = data_spectr_augmentation(
|
||||
input_spec=processed_signal_t
|
||||
)
|
||||
|
||||
encoded_t, encoded_len_t = jasper_encoder(
|
||||
audio_signal=processed_signal_t, length=p_length_t
|
||||
)
|
||||
log_probs_t = jasper_decoder(encoder_output=encoded_t)
|
||||
predictions_t = greedy_decoder(log_probs=log_probs_t)
|
||||
loss_t = ctc_loss(
|
||||
log_probs=log_probs_t,
|
||||
targets=transcript_t,
|
||||
input_length=encoded_len_t,
|
||||
target_length=transcript_len_t,
|
||||
)
|
||||
|
||||
# Callbacks needed to print info to console and Tensorboard
|
||||
train_callback = nemo.core.SimpleLossLoggerCallback(
|
||||
tensors=[loss_t, predictions_t, transcript_t, transcript_len_t],
|
||||
print_func=partial(monitor_asr_train_progress, labels=vocab),
|
||||
get_tb_values=lambda x: [("loss", x[0])],
|
||||
tb_writer=neural_factory.tb_writer,
|
||||
)
|
||||
|
||||
chpt_callback = nemo.core.CheckpointCallback(
|
||||
folder=neural_factory.checkpoint_dir,
|
||||
load_from_folder=args.load_dir,
|
||||
step_freq=args.checkpoint_save_freq,
|
||||
checkpoints_to_keep=30,
|
||||
)
|
||||
|
||||
callbacks = [train_callback, chpt_callback]
|
||||
|
||||
# assemble eval DAGs
|
||||
for i, eval_dl in enumerate(data_layers_eval):
|
||||
(
|
||||
audio_signal_e,
|
||||
a_sig_length_e,
|
||||
transcript_e,
|
||||
transcript_len_e,
|
||||
) = eval_dl()
|
||||
processed_signal_e, p_length_e = data_preprocessor(
|
||||
input_signal=audio_signal_e, length=a_sig_length_e
|
||||
)
|
||||
encoded_e, encoded_len_e = jasper_encoder(
|
||||
audio_signal=processed_signal_e, length=p_length_e
|
||||
)
|
||||
log_probs_e = jasper_decoder(encoder_output=encoded_e)
|
||||
predictions_e = greedy_decoder(log_probs=log_probs_e)
|
||||
loss_e = ctc_loss(
|
||||
log_probs=log_probs_e,
|
||||
targets=transcript_e,
|
||||
input_length=encoded_len_e,
|
||||
target_length=transcript_len_e,
|
||||
)
|
||||
|
||||
# create corresponding eval callback
|
||||
tagname = os.path.basename(args.eval_datasets[i]).split(".")[0]
|
||||
eval_callback = nemo.core.EvaluatorCallback(
|
||||
eval_tensors=[
|
||||
loss_e,
|
||||
predictions_e,
|
||||
transcript_e,
|
||||
transcript_len_e,
|
||||
],
|
||||
user_iter_callback=partial(process_evaluation_batch, labels=vocab),
|
||||
user_epochs_done_callback=partial(
|
||||
process_evaluation_epoch, tag=tagname
|
||||
),
|
||||
eval_step=args.eval_freq,
|
||||
tb_writer=neural_factory.tb_writer,
|
||||
)
|
||||
|
||||
callbacks.append(eval_callback)
|
||||
return loss_t, callbacks, steps_per_epoch
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
name = construct_name(
|
||||
args.exp_name,
|
||||
args.lr,
|
||||
args.batch_size,
|
||||
args.max_steps,
|
||||
args.num_epochs,
|
||||
args.weight_decay,
|
||||
args.optimizer,
|
||||
args.iter_per_step,
|
||||
)
|
||||
log_dir = name
|
||||
if args.work_dir:
|
||||
log_dir = os.path.join(args.work_dir, name)
|
||||
|
||||
# instantiate Neural Factory with supported backend
|
||||
neural_factory = nemo.core.NeuralModuleFactory(
|
||||
backend=nemo.core.Backend.PyTorch,
|
||||
local_rank=args.local_rank,
|
||||
optimization_level=args.amp_opt_level,
|
||||
log_dir=log_dir,
|
||||
checkpoint_dir=args.checkpoint_dir,
|
||||
create_tb_writer=args.create_tb_writer,
|
||||
files_to_copy=[args.model_config, __file__],
|
||||
cudnn_benchmark=args.cudnn_benchmark,
|
||||
tensorboard_dir=args.tensorboard_dir,
|
||||
)
|
||||
args.num_gpus = neural_factory.world_size
|
||||
|
||||
checkpoint_dir = neural_factory.checkpoint_dir
|
||||
if args.local_rank is not None:
|
||||
logging.info("Doing ALL GPU")
|
||||
|
||||
# build dags
|
||||
train_loss, callbacks, steps_per_epoch = create_all_dags(
|
||||
args, neural_factory
|
||||
)
|
||||
# train model
|
||||
neural_factory.train(
|
||||
tensors_to_optimize=[train_loss],
|
||||
callbacks=callbacks,
|
||||
lr_policy=CosineAnnealing(
|
||||
args.max_steps
|
||||
if args.max_steps is not None
|
||||
else args.num_epochs * steps_per_epoch,
|
||||
warmup_steps=args.warmup_steps,
|
||||
),
|
||||
optimizer=args.optimizer,
|
||||
optimization_params={
|
||||
"num_epochs": args.num_epochs,
|
||||
"max_steps": args.max_steps,
|
||||
"lr": args.lr,
|
||||
"betas": (args.beta1, args.beta2),
|
||||
"weight_decay": args.weight_decay,
|
||||
"grad_norm_clip": None,
|
||||
},
|
||||
batches_per_step=args.iter_per_step,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
0
src/plume/models/marblenet_nemo/__init__.py
Normal file
0
src/plume/models/marblenet_nemo/__init__.py
Normal file
132
src/plume/models/marblenet_nemo/asr.py
Normal file
132
src/plume/models/marblenet_nemo/asr.py
Normal file
@@ -0,0 +1,132 @@
|
||||
import os
|
||||
import tempfile
|
||||
from ruamel.yaml import YAML
|
||||
import json
|
||||
import nemo
|
||||
import nemo.collections.asr as nemo_asr
|
||||
import wave
|
||||
from nemo.collections.asr.helpers import post_process_predictions
|
||||
|
||||
logging = nemo.logging
|
||||
|
||||
WORK_DIR = "/tmp"
|
||||
|
||||
|
||||
class JasperASR(object):
|
||||
"""docstring for JasperASR."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_yaml,
|
||||
encoder_checkpoint,
|
||||
decoder_checkpoint,
|
||||
language_model=None,
|
||||
):
|
||||
super(JasperASR, self).__init__()
|
||||
# Read model YAML
|
||||
yaml = YAML(typ="safe")
|
||||
with open(model_yaml) as f:
|
||||
jasper_model_definition = yaml.load(f)
|
||||
self.neural_factory = nemo.core.NeuralModuleFactory(
|
||||
placement=nemo.core.DeviceType.GPU,
|
||||
backend=nemo.core.Backend.PyTorch,
|
||||
)
|
||||
self.labels = jasper_model_definition["labels"]
|
||||
self.data_preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor()
|
||||
self.jasper_encoder = nemo_asr.JasperEncoder(
|
||||
jasper=jasper_model_definition["JasperEncoder"]["jasper"],
|
||||
activation=jasper_model_definition["JasperEncoder"]["activation"],
|
||||
feat_in=jasper_model_definition[
|
||||
"AudioToMelSpectrogramPreprocessor"
|
||||
]["features"],
|
||||
)
|
||||
self.jasper_encoder.restore_from(encoder_checkpoint, local_rank=0)
|
||||
self.jasper_decoder = nemo_asr.JasperDecoderForCTC(
|
||||
feat_in=1024, num_classes=len(self.labels)
|
||||
)
|
||||
self.jasper_decoder.restore_from(decoder_checkpoint, local_rank=0)
|
||||
self.greedy_decoder = nemo_asr.GreedyCTCDecoder()
|
||||
self.beam_search_with_lm = None
|
||||
if language_model:
|
||||
self.beam_search_with_lm = nemo_asr.BeamSearchDecoderWithLM(
|
||||
vocab=self.labels,
|
||||
beam_width=64,
|
||||
alpha=2.0,
|
||||
beta=1.0,
|
||||
lm_path=language_model,
|
||||
num_cpus=max(os.cpu_count(), 1),
|
||||
)
|
||||
|
||||
def transcribe(self, audio_data, greedy=True):
|
||||
audio_file = tempfile.NamedTemporaryFile(
|
||||
dir=WORK_DIR, prefix="jasper_audio.", delete=False
|
||||
)
|
||||
# audio_file.write(audio_data)
|
||||
audio_file.close()
|
||||
audio_file_path = audio_file.name
|
||||
wf = wave.open(audio_file_path, "w")
|
||||
wf.setnchannels(1)
|
||||
wf.setsampwidth(2)
|
||||
wf.setframerate(24000)
|
||||
wf.writeframesraw(audio_data)
|
||||
wf.close()
|
||||
manifest = {
|
||||
"audio_filepath": audio_file_path,
|
||||
"duration": 60,
|
||||
"text": "todo",
|
||||
}
|
||||
manifest_file = tempfile.NamedTemporaryFile(
|
||||
dir=WORK_DIR, prefix="jasper_manifest.", delete=False, mode="w"
|
||||
)
|
||||
manifest_file.write(json.dumps(manifest))
|
||||
manifest_file.close()
|
||||
manifest_file_path = manifest_file.name
|
||||
data_layer = nemo_asr.AudioToTextDataLayer(
|
||||
shuffle=False,
|
||||
manifest_filepath=manifest_file_path,
|
||||
labels=self.labels,
|
||||
batch_size=1,
|
||||
)
|
||||
|
||||
# Define inference DAG
|
||||
audio_signal, audio_signal_len, _, _ = data_layer()
|
||||
processed_signal, processed_signal_len = self.data_preprocessor(
|
||||
input_signal=audio_signal, length=audio_signal_len
|
||||
)
|
||||
encoded, encoded_len = self.jasper_encoder(
|
||||
audio_signal=processed_signal, length=processed_signal_len
|
||||
)
|
||||
log_probs = self.jasper_decoder(encoder_output=encoded)
|
||||
predictions = self.greedy_decoder(log_probs=log_probs)
|
||||
|
||||
if greedy:
|
||||
eval_tensors = [predictions]
|
||||
else:
|
||||
if self.beam_search_with_lm:
|
||||
logging.info("Running with beam search")
|
||||
beam_predictions = self.beam_search_with_lm(
|
||||
log_probs=log_probs, log_probs_length=encoded_len
|
||||
)
|
||||
eval_tensors = [beam_predictions]
|
||||
else:
|
||||
logging.info(
|
||||
"language_model not specified. falling back to greedy decoding."
|
||||
)
|
||||
eval_tensors = [predictions]
|
||||
|
||||
tensors = self.neural_factory.infer(tensors=eval_tensors)
|
||||
prediction = post_process_predictions(tensors[0], self.labels)
|
||||
prediction_text = ". ".join(prediction)
|
||||
os.unlink(manifest_file.name)
|
||||
os.unlink(audio_file.name)
|
||||
return prediction_text
|
||||
|
||||
def transcribe_file(self, audio_file, *args, **kwargs):
|
||||
tscript_file_path = audio_file.with_suffix(".txt")
|
||||
audio_file_path = str(audio_file)
|
||||
with wave.open(audio_file_path, "r") as af:
|
||||
frame_count = af.getnframes()
|
||||
audio_data = af.readframes(frame_count)
|
||||
transcription = self.transcribe(audio_data, *args, **kwargs)
|
||||
with open(tscript_file_path, "w") as tf:
|
||||
tf.write(transcription)
|
||||
24
src/plume/models/marblenet_nemo/data.py
Normal file
24
src/plume/models/marblenet_nemo/data.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from pathlib import Path
|
||||
import typer
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
|
||||
def set_root(dataset_path: Path, root_path: Path):
|
||||
pass
|
||||
# for dataset_kind in ["train", "valid"]:
|
||||
# data_file = dataset_path / Path(dataset_kind).with_suffix(".tsv")
|
||||
# with data_file.open("r") as df:
|
||||
# lines = df.readlines()
|
||||
# with data_file.open("w") as df:
|
||||
# lines[0] = str(root_path) + "\n"
|
||||
# df.writelines(lines)
|
||||
|
||||
|
||||
def main():
|
||||
app()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
340
src/plume/models/marblenet_nemo/data_loaders.py
Normal file
340
src/plume/models/marblenet_nemo/data_loaders.py
Normal file
@@ -0,0 +1,340 @@
|
||||
from functools import partial
|
||||
import tempfile
|
||||
|
||||
# from typing import Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
import nemo
|
||||
|
||||
# import nemo.collections.asr as nemo_asr
|
||||
from nemo.backends.pytorch import DataLayerNM
|
||||
from nemo.core import DeviceType
|
||||
|
||||
# from nemo.core.neural_types import *
|
||||
from nemo.core.neural_types import (
|
||||
NeuralType,
|
||||
AudioSignal,
|
||||
LengthsType,
|
||||
LabelsType,
|
||||
)
|
||||
from nemo.utils.decorators import add_port_docs
|
||||
|
||||
from nemo.collections.asr.parts.dataset import (
|
||||
# AudioDataset,
|
||||
# AudioLabelDataset,
|
||||
# KaldiFeatureDataset,
|
||||
# TranscriptDataset,
|
||||
parsers,
|
||||
collections,
|
||||
seq_collate_fn,
|
||||
)
|
||||
|
||||
# from functools import lru_cache
|
||||
import rpyc
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from tqdm import tqdm
|
||||
from .featurizer import RpycWaveformFeaturizer
|
||||
|
||||
# from nemo.collections.asr.parts.features import WaveformFeaturizer
|
||||
|
||||
# from nemo.collections.asr.parts.perturb import AudioAugmentor, perturbation_types
|
||||
|
||||
|
||||
logging = nemo.logging
|
||||
|
||||
|
||||
class CachedAudioDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
Dataset that loads tensors via a json file containing paths to audio
|
||||
files, transcripts, and durations (in seconds). Each new line is a
|
||||
different sample. Example below:
|
||||
|
||||
{"audio_filepath": "/path/to/audio.wav", "text_filepath":
|
||||
"/path/to/audio.txt", "duration": 23.147}
|
||||
...
|
||||
{"audio_filepath": "/path/to/audio.wav", "text": "the
|
||||
transcription", offset": 301.75, "duration": 0.82, "utt":
|
||||
"utterance_id", "ctm_utt": "en_4156", "side": "A"}
|
||||
|
||||
Args:
|
||||
manifest_filepath: Path to manifest json as described above. Can
|
||||
be comma-separated paths.
|
||||
labels: String containing all the possible characters to map to
|
||||
featurizer: Initialized featurizer class that converts paths of
|
||||
audio to feature tensors
|
||||
max_duration: If audio exceeds this length, do not include in dataset
|
||||
min_duration: If audio is less than this length, do not include
|
||||
in dataset
|
||||
max_utts: Limit number of utterances
|
||||
blank_index: blank character index, default = -1
|
||||
unk_index: unk_character index, default = -1
|
||||
normalize: whether to normalize transcript text (default): True
|
||||
bos_id: Id of beginning of sequence symbol to append if not None
|
||||
eos_id: Id of end of sequence symbol to append if not None
|
||||
load_audio: Boolean flag indicate whether do or not load audio
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
manifest_filepath,
|
||||
labels,
|
||||
featurizer,
|
||||
max_duration=None,
|
||||
min_duration=None,
|
||||
max_utts=0,
|
||||
blank_index=-1,
|
||||
unk_index=-1,
|
||||
normalize=True,
|
||||
trim=False,
|
||||
bos_id=None,
|
||||
eos_id=None,
|
||||
load_audio=True,
|
||||
parser="en",
|
||||
):
|
||||
self.collection = collections.ASRAudioText(
|
||||
manifests_files=manifest_filepath.split(","),
|
||||
parser=parsers.make_parser(
|
||||
labels=labels,
|
||||
name=parser,
|
||||
unk_id=unk_index,
|
||||
blank_id=blank_index,
|
||||
do_normalize=normalize,
|
||||
),
|
||||
min_duration=min_duration,
|
||||
max_duration=max_duration,
|
||||
max_number=max_utts,
|
||||
)
|
||||
self.index_feature_map = {}
|
||||
|
||||
self.featurizer = featurizer
|
||||
self.trim = trim
|
||||
self.eos_id = eos_id
|
||||
self.bos_id = bos_id
|
||||
self.load_audio = load_audio
|
||||
print(f"initializing dataset {manifest_filepath}")
|
||||
|
||||
def exec_func(i):
|
||||
return self[i]
|
||||
|
||||
task_count = len(self.collection)
|
||||
with ThreadPoolExecutor() as exe:
|
||||
print("starting all loading tasks")
|
||||
list(
|
||||
tqdm(
|
||||
exe.map(exec_func, range(task_count)),
|
||||
position=0,
|
||||
leave=True,
|
||||
total=task_count,
|
||||
)
|
||||
)
|
||||
print(f"initializing complete")
|
||||
|
||||
def __getitem__(self, index):
|
||||
sample = self.collection[index]
|
||||
if self.load_audio:
|
||||
cached_features = self.index_feature_map.get(index)
|
||||
if cached_features is not None:
|
||||
features = cached_features
|
||||
else:
|
||||
features = self.featurizer.process(
|
||||
sample.audio_file,
|
||||
offset=0,
|
||||
duration=sample.duration,
|
||||
trim=self.trim,
|
||||
)
|
||||
self.index_feature_map[index] = features
|
||||
f, fl = features, torch.tensor(features.shape[0]).long()
|
||||
else:
|
||||
f, fl = None, None
|
||||
|
||||
t, tl = sample.text_tokens, len(sample.text_tokens)
|
||||
if self.bos_id is not None:
|
||||
t = [self.bos_id] + t
|
||||
tl += 1
|
||||
if self.eos_id is not None:
|
||||
t = t + [self.eos_id]
|
||||
tl += 1
|
||||
|
||||
return f, fl, torch.tensor(t).long(), torch.tensor(tl).long()
|
||||
|
||||
def __len__(self):
|
||||
return len(self.collection)
|
||||
|
||||
|
||||
class RpycAudioToTextDataLayer(DataLayerNM):
|
||||
"""Data Layer for general ASR tasks.
|
||||
|
||||
Module which reads ASR labeled data. It accepts comma-separated
|
||||
JSON manifest files describing the correspondence between wav audio files
|
||||
and their transcripts. JSON files should be of the following format::
|
||||
|
||||
{"audio_filepath": path_to_wav_0, "duration": time_in_sec_0, "text": \
|
||||
transcript_0}
|
||||
...
|
||||
{"audio_filepath": path_to_wav_n, "duration": time_in_sec_n, "text": \
|
||||
transcript_n}
|
||||
|
||||
Args:
|
||||
manifest_filepath (str): Dataset parameter.
|
||||
Path to JSON containing data.
|
||||
labels (list): Dataset parameter.
|
||||
List of characters that can be output by the ASR model.
|
||||
For Jasper, this is the 28 character set {a-z '}. The CTC blank
|
||||
symbol is automatically added later for models using ctc.
|
||||
batch_size (int): batch size
|
||||
sample_rate (int): Target sampling rate for data. Audio files will be
|
||||
resampled to sample_rate if it is not already.
|
||||
Defaults to 16000.
|
||||
int_values (bool): Bool indicating whether the audio file is saved as
|
||||
int data or float data.
|
||||
Defaults to False.
|
||||
eos_id (id): Dataset parameter.
|
||||
End of string symbol id used for seq2seq models.
|
||||
Defaults to None.
|
||||
min_duration (float): Dataset parameter.
|
||||
All training files which have a duration less than min_duration
|
||||
are dropped. Note: Duration is read from the manifest JSON.
|
||||
Defaults to 0.1.
|
||||
max_duration (float): Dataset parameter.
|
||||
All training files which have a duration more than max_duration
|
||||
are dropped. Note: Duration is read from the manifest JSON.
|
||||
Defaults to None.
|
||||
normalize_transcripts (bool): Dataset parameter.
|
||||
Whether to use automatic text cleaning.
|
||||
It is highly recommended to manually clean text for best results.
|
||||
Defaults to True.
|
||||
trim_silence (bool): Whether to use trim silence from beginning and end
|
||||
of audio signal using librosa.effects.trim().
|
||||
Defaults to False.
|
||||
load_audio (bool): Dataset parameter.
|
||||
Controls whether the dataloader loads the audio signal and
|
||||
transcript or just the transcript.
|
||||
Defaults to True.
|
||||
drop_last (bool): See PyTorch DataLoader.
|
||||
Defaults to False.
|
||||
shuffle (bool): See PyTorch DataLoader.
|
||||
Defaults to True.
|
||||
num_workers (int): See PyTorch DataLoader.
|
||||
Defaults to 0.
|
||||
perturb_config (dict): Currently disabled.
|
||||
"""
|
||||
|
||||
@property
|
||||
@add_port_docs()
|
||||
def output_ports(self):
|
||||
"""Returns definitions of module output ports."""
|
||||
return {
|
||||
# 'audio_signal': NeuralType({0: AxisType(BatchTag), 1: AxisType(TimeTag)}),
|
||||
# 'a_sig_length': NeuralType({0: AxisType(BatchTag)}),
|
||||
# 'transcripts': NeuralType({0: AxisType(BatchTag), 1: AxisType(TimeTag)}),
|
||||
# 'transcript_length': NeuralType({0: AxisType(BatchTag)}),
|
||||
"audio_signal": NeuralType(
|
||||
("B", "T"),
|
||||
AudioSignal(freq=self._sample_rate)
|
||||
if self is not None and self._sample_rate is not None
|
||||
else AudioSignal(),
|
||||
),
|
||||
"a_sig_length": NeuralType(tuple("B"), LengthsType()),
|
||||
"transcripts": NeuralType(("B", "T"), LabelsType()),
|
||||
"transcript_length": NeuralType(tuple("B"), LengthsType()),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
manifest_filepath,
|
||||
labels,
|
||||
batch_size,
|
||||
sample_rate=16000,
|
||||
int_values=False,
|
||||
bos_id=None,
|
||||
eos_id=None,
|
||||
pad_id=None,
|
||||
min_duration=0.1,
|
||||
max_duration=None,
|
||||
normalize_transcripts=True,
|
||||
trim_silence=False,
|
||||
load_audio=True,
|
||||
rpyc_host="",
|
||||
drop_last=False,
|
||||
shuffle=True,
|
||||
num_workers=0,
|
||||
):
|
||||
super().__init__()
|
||||
self._sample_rate = sample_rate
|
||||
|
||||
def rpyc_root_fn():
|
||||
return rpyc.connect(
|
||||
rpyc_host, 8064, config={"sync_request_timeout": 600}
|
||||
).root
|
||||
|
||||
rpyc_conn = rpyc_root_fn()
|
||||
|
||||
self._featurizer = RpycWaveformFeaturizer(
|
||||
sample_rate=self._sample_rate,
|
||||
int_values=int_values,
|
||||
augmentor=None,
|
||||
rpyc_conn=rpyc_conn,
|
||||
)
|
||||
|
||||
def read_remote_manifests():
|
||||
local_mp = []
|
||||
for mrp in manifest_filepath.split(","):
|
||||
md = rpyc_conn.read_path(mrp)
|
||||
mf = tempfile.NamedTemporaryFile(
|
||||
dir="/tmp", prefix="jasper_manifest.", delete=False
|
||||
)
|
||||
mf.write(md)
|
||||
mf.close()
|
||||
local_mp.append(mf.name)
|
||||
return ",".join(local_mp)
|
||||
|
||||
local_manifest_filepath = read_remote_manifests()
|
||||
dataset_params = {
|
||||
"manifest_filepath": local_manifest_filepath,
|
||||
"labels": labels,
|
||||
"featurizer": self._featurizer,
|
||||
"max_duration": max_duration,
|
||||
"min_duration": min_duration,
|
||||
"normalize": normalize_transcripts,
|
||||
"trim": trim_silence,
|
||||
"bos_id": bos_id,
|
||||
"eos_id": eos_id,
|
||||
"load_audio": load_audio,
|
||||
}
|
||||
|
||||
self._dataset = CachedAudioDataset(**dataset_params)
|
||||
self._batch_size = batch_size
|
||||
|
||||
# Set up data loader
|
||||
if self._placement == DeviceType.AllGpu:
|
||||
logging.info("Parallelizing Datalayer.")
|
||||
sampler = torch.utils.data.distributed.DistributedSampler(
|
||||
self._dataset
|
||||
)
|
||||
else:
|
||||
sampler = None
|
||||
|
||||
if batch_size == -1:
|
||||
batch_size = len(self._dataset)
|
||||
|
||||
pad_id = 0 if pad_id is None else pad_id
|
||||
self._dataloader = torch.utils.data.DataLoader(
|
||||
dataset=self._dataset,
|
||||
batch_size=batch_size,
|
||||
collate_fn=partial(seq_collate_fn, token_pad_value=pad_id),
|
||||
drop_last=drop_last,
|
||||
shuffle=shuffle if sampler is None else False,
|
||||
sampler=sampler,
|
||||
num_workers=1,
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
return len(self._dataset)
|
||||
|
||||
@property
|
||||
def dataset(self):
|
||||
return None
|
||||
|
||||
@property
|
||||
def data_iterator(self):
|
||||
return self._dataloader
|
||||
376
src/plume/models/marblenet_nemo/eval.py
Normal file
376
src/plume/models/marblenet_nemo/eval.py
Normal file
@@ -0,0 +1,376 @@
|
||||
# Copyright (c) 2019 NVIDIA Corporation
|
||||
import argparse
|
||||
import copy
|
||||
|
||||
# import math
|
||||
import os
|
||||
from pathlib import Path
|
||||
from functools import partial
|
||||
|
||||
from ruamel.yaml import YAML
|
||||
|
||||
import nemo
|
||||
import nemo.collections.asr as nemo_asr
|
||||
import nemo.utils.argparse as nm_argparse
|
||||
from nemo.collections.asr.helpers import (
|
||||
# monitor_asr_train_progress,
|
||||
process_evaluation_batch,
|
||||
process_evaluation_epoch,
|
||||
)
|
||||
|
||||
# from nemo.utils.lr_policies import CosineAnnealing
|
||||
from training.data_loaders import RpycAudioToTextDataLayer
|
||||
|
||||
logging = nemo.logging
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
parents=[nm_argparse.NemoArgParser()],
|
||||
description="Jasper",
|
||||
conflict_handler="resolve",
|
||||
)
|
||||
parser.set_defaults(
|
||||
checkpoint_dir=None,
|
||||
optimizer="novograd",
|
||||
batch_size=64,
|
||||
eval_batch_size=64,
|
||||
lr=0.002,
|
||||
amp_opt_level="O1",
|
||||
create_tb_writer=True,
|
||||
model_config="./train/jasper10x5dr.yaml",
|
||||
work_dir="./train/work",
|
||||
num_epochs=300,
|
||||
weight_decay=0.005,
|
||||
checkpoint_save_freq=100,
|
||||
eval_freq=100,
|
||||
load_dir="./train/models/jasper/",
|
||||
warmup_steps=3,
|
||||
exp_name="jasper",
|
||||
)
|
||||
|
||||
# Overwrite default args
|
||||
parser.add_argument(
|
||||
"--max_steps",
|
||||
type=int,
|
||||
default=None,
|
||||
required=False,
|
||||
help="max number of steps to train",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_epochs",
|
||||
type=int,
|
||||
required=False,
|
||||
help="number of epochs to train",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_config",
|
||||
type=str,
|
||||
required=False,
|
||||
help="model configuration file: model.yaml",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--encoder_checkpoint",
|
||||
type=str,
|
||||
required=True,
|
||||
help="encoder checkpoint file: JasperEncoder.pt",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--decoder_checkpoint",
|
||||
type=str,
|
||||
required=True,
|
||||
help="decoder checkpoint file: JasperDecoderForCTC.pt",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--remote_data",
|
||||
type=str,
|
||||
required=False,
|
||||
default="",
|
||||
help="remote dataloader endpoint",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
type=str,
|
||||
required=False,
|
||||
default="",
|
||||
help="dataset directory containing train/test manifests",
|
||||
)
|
||||
|
||||
# Create new args
|
||||
parser.add_argument("--exp_name", default="Jasper", type=str)
|
||||
parser.add_argument("--beta1", default=0.95, type=float)
|
||||
parser.add_argument("--beta2", default=0.25, type=float)
|
||||
parser.add_argument("--warmup_steps", default=0, type=int)
|
||||
parser.add_argument(
|
||||
"--load_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
help="directory with pre-trained checkpoint",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
if args.max_steps is None and args.num_epochs is None:
|
||||
raise ValueError("Either max_steps or num_epochs should be provided.")
|
||||
return args
|
||||
|
||||
|
||||
def construct_name(
|
||||
name, lr, batch_size, max_steps, num_epochs, wd, optimizer, iter_per_step
|
||||
):
|
||||
if max_steps is not None:
|
||||
return "{0}-lr_{1}-bs_{2}-s_{3}-wd_{4}-opt_{5}-ips_{6}".format(
|
||||
name, lr, batch_size, max_steps, wd, optimizer, iter_per_step
|
||||
)
|
||||
else:
|
||||
return "{0}-lr_{1}-bs_{2}-e_{3}-wd_{4}-opt_{5}-ips_{6}".format(
|
||||
name, lr, batch_size, num_epochs, wd, optimizer, iter_per_step
|
||||
)
|
||||
|
||||
|
||||
def create_all_dags(args, neural_factory):
|
||||
yaml = YAML(typ="safe")
|
||||
with open(args.model_config) as f:
|
||||
jasper_params = yaml.load(f)
|
||||
vocab = jasper_params["labels"]
|
||||
sample_rate = jasper_params["sample_rate"]
|
||||
|
||||
# Calculate num_workers for dataloader
|
||||
total_cpus = os.cpu_count()
|
||||
cpu_per_traindl = max(int(total_cpus / neural_factory.world_size), 1)
|
||||
# perturb_config = jasper_params.get('perturb', None)
|
||||
train_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"])
|
||||
train_dl_params.update(jasper_params["AudioToTextDataLayer"]["train"])
|
||||
del train_dl_params["train"]
|
||||
del train_dl_params["eval"]
|
||||
# del train_dl_params["normalize_transcripts"]
|
||||
|
||||
if args.dataset:
|
||||
d_path = Path(args.dataset)
|
||||
if not args.train_dataset:
|
||||
args.train_dataset = str(d_path / Path("train_manifest.json"))
|
||||
if not args.eval_datasets:
|
||||
args.eval_datasets = [str(d_path / Path("test_manifest.json"))]
|
||||
|
||||
data_loader_layer = nemo_asr.AudioToTextDataLayer
|
||||
|
||||
if args.remote_data:
|
||||
train_dl_params["rpyc_host"] = args.remote_data
|
||||
data_loader_layer = RpycAudioToTextDataLayer
|
||||
|
||||
# data_layer = data_loader_layer(
|
||||
# manifest_filepath=args.train_dataset,
|
||||
# sample_rate=sample_rate,
|
||||
# labels=vocab,
|
||||
# batch_size=args.batch_size,
|
||||
# num_workers=cpu_per_traindl,
|
||||
# **train_dl_params,
|
||||
# # normalize_transcripts=False
|
||||
# )
|
||||
#
|
||||
# N = len(data_layer)
|
||||
# steps_per_epoch = math.ceil(
|
||||
# N / (args.batch_size * args.iter_per_step * args.num_gpus)
|
||||
# )
|
||||
# logging.info("Have {0} examples to train on.".format(N))
|
||||
#
|
||||
data_preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor(
|
||||
sample_rate=sample_rate,
|
||||
**jasper_params["AudioToMelSpectrogramPreprocessor"],
|
||||
)
|
||||
|
||||
# multiply_batch_config = jasper_params.get("MultiplyBatch", None)
|
||||
# if multiply_batch_config:
|
||||
# multiply_batch = nemo_asr.MultiplyBatch(**multiply_batch_config)
|
||||
#
|
||||
# spectr_augment_config = jasper_params.get("SpectrogramAugmentation", None)
|
||||
# if spectr_augment_config:
|
||||
# data_spectr_augmentation = nemo_asr.SpectrogramAugmentation(
|
||||
# **spectr_augment_config
|
||||
# )
|
||||
#
|
||||
eval_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"])
|
||||
eval_dl_params.update(jasper_params["AudioToTextDataLayer"]["eval"])
|
||||
if args.remote_data:
|
||||
eval_dl_params["rpyc_host"] = args.remote_data
|
||||
del eval_dl_params["train"]
|
||||
del eval_dl_params["eval"]
|
||||
data_layers_eval = []
|
||||
|
||||
# if args.eval_datasets:
|
||||
for eval_datasets in args.eval_datasets:
|
||||
data_layer_eval = data_loader_layer(
|
||||
manifest_filepath=eval_datasets,
|
||||
sample_rate=sample_rate,
|
||||
labels=vocab,
|
||||
batch_size=args.eval_batch_size,
|
||||
num_workers=cpu_per_traindl,
|
||||
**eval_dl_params,
|
||||
)
|
||||
|
||||
data_layers_eval.append(data_layer_eval)
|
||||
# else:
|
||||
# logging.warning("There were no val datasets passed")
|
||||
|
||||
jasper_encoder = nemo_asr.JasperEncoder(
|
||||
feat_in=jasper_params["AudioToMelSpectrogramPreprocessor"]["features"],
|
||||
**jasper_params["JasperEncoder"],
|
||||
)
|
||||
jasper_encoder.restore_from(args.encoder_checkpoint, local_rank=0)
|
||||
|
||||
jasper_decoder = nemo_asr.JasperDecoderForCTC(
|
||||
feat_in=jasper_params["JasperEncoder"]["jasper"][-1]["filters"],
|
||||
num_classes=len(vocab),
|
||||
)
|
||||
jasper_decoder.restore_from(args.decoder_checkpoint, local_rank=0)
|
||||
|
||||
ctc_loss = nemo_asr.CTCLossNM(num_classes=len(vocab))
|
||||
|
||||
greedy_decoder = nemo_asr.GreedyCTCDecoder()
|
||||
|
||||
# logging.info("================================")
|
||||
# logging.info(f"Number of parameters in encoder: {jasper_encoder.num_weights}")
|
||||
# logging.info(f"Number of parameters in decoder: {jasper_decoder.num_weights}")
|
||||
# logging.info(
|
||||
# f"Total number of parameters in model: "
|
||||
# f"{jasper_decoder.num_weights + jasper_encoder.num_weights}"
|
||||
# )
|
||||
# logging.info("================================")
|
||||
#
|
||||
# # Train DAG
|
||||
# (audio_signal_t, a_sig_length_t, transcript_t, transcript_len_t) = data_layer()
|
||||
# processed_signal_t, p_length_t = data_preprocessor(
|
||||
# input_signal=audio_signal_t, length=a_sig_length_t
|
||||
# )
|
||||
#
|
||||
# if multiply_batch_config:
|
||||
# (
|
||||
# processed_signal_t,
|
||||
# p_length_t,
|
||||
# transcript_t,
|
||||
# transcript_len_t,
|
||||
# ) = multiply_batch(
|
||||
# in_x=processed_signal_t,
|
||||
# in_x_len=p_length_t,
|
||||
# in_y=transcript_t,
|
||||
# in_y_len=transcript_len_t,
|
||||
# )
|
||||
#
|
||||
# if spectr_augment_config:
|
||||
# processed_signal_t = data_spectr_augmentation(input_spec=processed_signal_t)
|
||||
#
|
||||
# encoded_t, encoded_len_t = jasper_encoder(
|
||||
# audio_signal=processed_signal_t, length=p_length_t
|
||||
# )
|
||||
# log_probs_t = jasper_decoder(encoder_output=encoded_t)
|
||||
# predictions_t = greedy_decoder(log_probs=log_probs_t)
|
||||
# loss_t = ctc_loss(
|
||||
# log_probs=log_probs_t,
|
||||
# targets=transcript_t,
|
||||
# input_length=encoded_len_t,
|
||||
# target_length=transcript_len_t,
|
||||
# )
|
||||
#
|
||||
# # Callbacks needed to print info to console and Tensorboard
|
||||
# train_callback = nemo.core.SimpleLossLoggerCallback(
|
||||
# tensors=[loss_t, predictions_t, transcript_t, transcript_len_t],
|
||||
# print_func=partial(monitor_asr_train_progress, labels=vocab),
|
||||
# get_tb_values=lambda x: [("loss", x[0])],
|
||||
# tb_writer=neural_factory.tb_writer,
|
||||
# )
|
||||
#
|
||||
# chpt_callback = nemo.core.CheckpointCallback(
|
||||
# folder=neural_factory.checkpoint_dir,
|
||||
# load_from_folder=args.load_dir,
|
||||
# step_freq=args.checkpoint_save_freq,
|
||||
# checkpoints_to_keep=30,
|
||||
# )
|
||||
#
|
||||
# callbacks = [train_callback, chpt_callback]
|
||||
callbacks = []
|
||||
# assemble eval DAGs
|
||||
for i, eval_dl in enumerate(data_layers_eval):
|
||||
(
|
||||
audio_signal_e,
|
||||
a_sig_length_e,
|
||||
transcript_e,
|
||||
transcript_len_e,
|
||||
) = eval_dl()
|
||||
processed_signal_e, p_length_e = data_preprocessor(
|
||||
input_signal=audio_signal_e, length=a_sig_length_e
|
||||
)
|
||||
encoded_e, encoded_len_e = jasper_encoder(
|
||||
audio_signal=processed_signal_e, length=p_length_e
|
||||
)
|
||||
log_probs_e = jasper_decoder(encoder_output=encoded_e)
|
||||
predictions_e = greedy_decoder(log_probs=log_probs_e)
|
||||
loss_e = ctc_loss(
|
||||
log_probs=log_probs_e,
|
||||
targets=transcript_e,
|
||||
input_length=encoded_len_e,
|
||||
target_length=transcript_len_e,
|
||||
)
|
||||
|
||||
# create corresponding eval callback
|
||||
tagname = os.path.basename(args.eval_datasets[i]).split(".")[0]
|
||||
eval_callback = nemo.core.EvaluatorCallback(
|
||||
eval_tensors=[
|
||||
loss_e,
|
||||
predictions_e,
|
||||
transcript_e,
|
||||
transcript_len_e,
|
||||
],
|
||||
user_iter_callback=partial(process_evaluation_batch, labels=vocab),
|
||||
user_epochs_done_callback=partial(
|
||||
process_evaluation_epoch, tag=tagname
|
||||
),
|
||||
eval_step=args.eval_freq,
|
||||
tb_writer=neural_factory.tb_writer,
|
||||
)
|
||||
|
||||
callbacks.append(eval_callback)
|
||||
return callbacks
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
# name = construct_name(
|
||||
# args.exp_name,
|
||||
# args.lr,
|
||||
# args.batch_size,
|
||||
# args.max_steps,
|
||||
# args.num_epochs,
|
||||
# args.weight_decay,
|
||||
# args.optimizer,
|
||||
# args.iter_per_step,
|
||||
# )
|
||||
# log_dir = name
|
||||
# if args.work_dir:
|
||||
# log_dir = os.path.join(args.work_dir, name)
|
||||
|
||||
# instantiate Neural Factory with supported backend
|
||||
neural_factory = nemo.core.NeuralModuleFactory(
|
||||
placement=nemo.core.DeviceType.GPU,
|
||||
backend=nemo.core.Backend.PyTorch,
|
||||
# local_rank=args.local_rank,
|
||||
# optimization_level=args.amp_opt_level,
|
||||
# log_dir=log_dir,
|
||||
# checkpoint_dir=args.checkpoint_dir,
|
||||
# create_tb_writer=args.create_tb_writer,
|
||||
# files_to_copy=[args.model_config, __file__],
|
||||
# cudnn_benchmark=args.cudnn_benchmark,
|
||||
# tensorboard_dir=args.tensorboard_dir,
|
||||
)
|
||||
args.num_gpus = neural_factory.world_size
|
||||
|
||||
# checkpoint_dir = neural_factory.checkpoint_dir
|
||||
if args.local_rank is not None:
|
||||
logging.info("Doing ALL GPU")
|
||||
|
||||
# build dags
|
||||
callbacks = create_all_dags(args, neural_factory)
|
||||
# evaluate model
|
||||
neural_factory.eval(callbacks=callbacks)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
61
src/plume/models/marblenet_nemo/featurizer.py
Normal file
61
src/plume/models/marblenet_nemo/featurizer.py
Normal file
@@ -0,0 +1,61 @@
|
||||
# import math
|
||||
|
||||
# import librosa
|
||||
import torch
|
||||
import pickle
|
||||
|
||||
# import torch.nn as nn
|
||||
# from torch_stft import STFT
|
||||
|
||||
# from nemo import logging
|
||||
from nemo.collections.asr.parts.perturb import AudioAugmentor
|
||||
|
||||
# from nemo.collections.asr.parts.segment import AudioSegment
|
||||
|
||||
|
||||
class RpycWaveformFeaturizer(object):
|
||||
def __init__(
|
||||
self,
|
||||
sample_rate=16000,
|
||||
int_values=False,
|
||||
augmentor=None,
|
||||
rpyc_conn=None,
|
||||
):
|
||||
self.augmentor = (
|
||||
augmentor if augmentor is not None else AudioAugmentor()
|
||||
)
|
||||
self.sample_rate = sample_rate
|
||||
self.int_values = int_values
|
||||
self.remote_path_samples = rpyc_conn.get_path_samples
|
||||
|
||||
def max_augmentation_length(self, length):
|
||||
return self.augmentor.max_augmentation_length(length)
|
||||
|
||||
def process(self, file_path, offset=0, duration=0, trim=False):
|
||||
audio = self.remote_path_samples(
|
||||
file_path,
|
||||
target_sr=self.sample_rate,
|
||||
int_values=self.int_values,
|
||||
offset=offset,
|
||||
duration=duration,
|
||||
trim=trim,
|
||||
)
|
||||
return torch.tensor(pickle.loads(audio), dtype=torch.float)
|
||||
|
||||
def process_segment(self, audio_segment):
|
||||
self.augmentor.perturb(audio_segment)
|
||||
return torch.tensor(audio_segment, dtype=torch.float)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, input_config, perturbation_configs=None):
|
||||
if perturbation_configs is not None:
|
||||
aa = AudioAugmentor.from_config(perturbation_configs)
|
||||
else:
|
||||
aa = None
|
||||
|
||||
sample_rate = input_config.get("sample_rate", 16000)
|
||||
int_values = input_config.get("int_values", False)
|
||||
|
||||
return cls(
|
||||
sample_rate=sample_rate, int_values=int_values, augmentor=aa
|
||||
)
|
||||
54
src/plume/models/marblenet_nemo/serve.py
Normal file
54
src/plume/models/marblenet_nemo/serve.py
Normal file
@@ -0,0 +1,54 @@
|
||||
import os
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from rpyc.utils.server import ThreadedServer
|
||||
import typer
|
||||
|
||||
# from .asr import JasperASR
|
||||
from ...utils.serve import ASRService
|
||||
from plume.utils import lazy_callable
|
||||
|
||||
JasperASR = lazy_callable("plume.models.jasper_nemo.asr.JasperASR")
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
|
||||
def rpyc(
|
||||
encoder_path: Path = "/path/to/encoder.pt",
|
||||
decoder_path: Path = "/path/to/decoder.pt",
|
||||
model_yaml_path: Path = "/path/to/model.yaml",
|
||||
port: int = int(os.environ.get("ASR_RPYC_PORT", "8044")),
|
||||
):
|
||||
for p in [encoder_path, decoder_path, model_yaml_path]:
|
||||
if not p.exists():
|
||||
logging.info(f"{p} doesn't exists")
|
||||
return
|
||||
asr = JasperASR(str(model_yaml_path), str(encoder_path), str(decoder_path))
|
||||
service = ASRService(asr)
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
)
|
||||
logging.info("starting asr server...")
|
||||
t = ThreadedServer(service, port=port)
|
||||
t.start()
|
||||
|
||||
|
||||
@app.command()
|
||||
def rpyc_dir(
|
||||
model_dir: Path, port: int = int(os.environ.get("ASR_RPYC_PORT", "8044"))
|
||||
):
|
||||
encoder_path = model_dir / Path("decoder.pt")
|
||||
decoder_path = model_dir / Path("encoder.pt")
|
||||
model_yaml_path = model_dir / Path("model.yaml")
|
||||
rpyc(encoder_path, decoder_path, model_yaml_path, port)
|
||||
|
||||
|
||||
def main():
|
||||
app()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
59
src/plume/models/marblenet_nemo/serve_data.py
Normal file
59
src/plume/models/marblenet_nemo/serve_data.py
Normal file
@@ -0,0 +1,59 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import typer
|
||||
import rpyc
|
||||
from rpyc.utils.server import ThreadedServer
|
||||
import nemo
|
||||
import pickle
|
||||
|
||||
# import nemo.collections.asr as nemo_asr
|
||||
from nemo.collections.asr.parts.segment import AudioSegment
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
nemo.core.NeuralModuleFactory(
|
||||
backend=nemo.core.Backend.PyTorch, placement=nemo.core.DeviceType.CPU
|
||||
)
|
||||
|
||||
|
||||
class ASRDataService(rpyc.Service):
|
||||
def exposed_get_path_samples(
|
||||
self, file_path, target_sr, int_values, offset, duration, trim
|
||||
):
|
||||
print(f"loading.. {file_path}")
|
||||
audio = AudioSegment.from_file(
|
||||
file_path,
|
||||
target_sr=target_sr,
|
||||
int_values=int_values,
|
||||
offset=offset,
|
||||
duration=duration,
|
||||
trim=trim,
|
||||
)
|
||||
# print(f"returning.. {len(audio.samples)} items of type{type(audio.samples)}")
|
||||
return pickle.dumps(audio.samples)
|
||||
|
||||
def exposed_read_path(self, file_path):
|
||||
# print(f"reading path.. {file_path}")
|
||||
return Path(file_path).read_bytes()
|
||||
|
||||
|
||||
@app.command()
|
||||
def run_server(port: int = 0):
|
||||
listen_port = (
|
||||
port if port else int(os.environ.get("ASR_DARA_RPYC_PORT", "8064"))
|
||||
)
|
||||
service = ASRDataService()
|
||||
t = ThreadedServer(
|
||||
service, port=listen_port, protocol_config={"allow_all_attrs": True}
|
||||
)
|
||||
typer.echo(f"starting asr server on {listen_port}...")
|
||||
t.start()
|
||||
|
||||
|
||||
def main():
|
||||
app()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
392
src/plume/models/marblenet_nemo/train.py
Normal file
392
src/plume/models/marblenet_nemo/train.py
Normal file
@@ -0,0 +1,392 @@
|
||||
# Copyright (c) 2019 NVIDIA Corporation
|
||||
import argparse
|
||||
import copy
|
||||
import math
|
||||
import os
|
||||
from pathlib import Path
|
||||
from functools import partial
|
||||
|
||||
from ruamel.yaml import YAML
|
||||
|
||||
import nemo
|
||||
import nemo.collections.asr as nemo_asr
|
||||
import nemo.utils.argparse as nm_argparse
|
||||
from nemo.collections.asr.helpers import (
|
||||
monitor_asr_train_progress,
|
||||
process_evaluation_batch,
|
||||
process_evaluation_epoch,
|
||||
)
|
||||
|
||||
from nemo.utils.lr_policies import CosineAnnealing
|
||||
from .data_loaders import RpycAudioToTextDataLayer
|
||||
|
||||
logging = nemo.logging
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
parents=[nm_argparse.NemoArgParser()],
|
||||
description="Jasper",
|
||||
conflict_handler="resolve",
|
||||
)
|
||||
parser.set_defaults(
|
||||
checkpoint_dir=None,
|
||||
optimizer="novograd",
|
||||
batch_size=64,
|
||||
eval_batch_size=64,
|
||||
lr=0.002,
|
||||
amp_opt_level="O1",
|
||||
create_tb_writer=True,
|
||||
model_config="./train/jasper10x5dr.yaml",
|
||||
work_dir="./train/work",
|
||||
num_epochs=300,
|
||||
weight_decay=0.005,
|
||||
checkpoint_save_freq=100,
|
||||
eval_freq=100,
|
||||
load_dir="./train/models/jasper/",
|
||||
warmup_steps=3,
|
||||
exp_name="jasper-speller",
|
||||
)
|
||||
|
||||
# Overwrite default args
|
||||
parser.add_argument(
|
||||
"--max_steps",
|
||||
type=int,
|
||||
default=None,
|
||||
required=False,
|
||||
help="max number of steps to train",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_epochs",
|
||||
type=int,
|
||||
required=False,
|
||||
help="number of epochs to train",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_config",
|
||||
type=str,
|
||||
required=False,
|
||||
help="model configuration file: model.yaml",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--remote_data",
|
||||
type=str,
|
||||
required=False,
|
||||
default="",
|
||||
help="remote dataloader endpoint",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
type=str,
|
||||
required=False,
|
||||
default="",
|
||||
help="dataset directory containing train/test manifests",
|
||||
)
|
||||
|
||||
# Create new args
|
||||
parser.add_argument("--exp_name", default="Jasper", type=str)
|
||||
parser.add_argument("--beta1", default=0.95, type=float)
|
||||
parser.add_argument("--beta2", default=0.25, type=float)
|
||||
parser.add_argument("--warmup_steps", default=0, type=int)
|
||||
parser.add_argument(
|
||||
"--load_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
help="directory with pre-trained checkpoint",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
if args.max_steps is None and args.num_epochs is None:
|
||||
raise ValueError("Either max_steps or num_epochs should be provided.")
|
||||
return args
|
||||
|
||||
|
||||
def construct_name(
|
||||
name, lr, batch_size, max_steps, num_epochs, wd, optimizer, iter_per_step
|
||||
):
|
||||
if max_steps is not None:
|
||||
return "{0}-lr_{1}-bs_{2}-s_{3}-wd_{4}-opt_{5}-ips_{6}".format(
|
||||
name, lr, batch_size, max_steps, wd, optimizer, iter_per_step
|
||||
)
|
||||
else:
|
||||
return "{0}-lr_{1}-bs_{2}-e_{3}-wd_{4}-opt_{5}-ips_{6}".format(
|
||||
name, lr, batch_size, num_epochs, wd, optimizer, iter_per_step
|
||||
)
|
||||
|
||||
|
||||
def create_all_dags(args, neural_factory):
|
||||
yaml = YAML(typ="safe")
|
||||
with open(args.model_config) as f:
|
||||
jasper_params = yaml.load(f)
|
||||
vocab = jasper_params["labels"]
|
||||
sample_rate = jasper_params["sample_rate"]
|
||||
|
||||
# Calculate num_workers for dataloader
|
||||
total_cpus = os.cpu_count()
|
||||
cpu_per_traindl = max(int(total_cpus / neural_factory.world_size), 1)
|
||||
# perturb_config = jasper_params.get('perturb', None)
|
||||
train_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"])
|
||||
train_dl_params.update(jasper_params["AudioToTextDataLayer"]["train"])
|
||||
del train_dl_params["train"]
|
||||
del train_dl_params["eval"]
|
||||
# del train_dl_params["normalize_transcripts"]
|
||||
|
||||
if args.dataset:
|
||||
d_path = Path(args.dataset)
|
||||
if not args.train_dataset:
|
||||
args.train_dataset = str(d_path / Path("train_manifest.json"))
|
||||
if not args.eval_datasets:
|
||||
args.eval_datasets = [str(d_path / Path("test_manifest.json"))]
|
||||
|
||||
data_loader_layer = nemo_asr.AudioToTextDataLayer
|
||||
|
||||
if args.remote_data:
|
||||
train_dl_params["rpyc_host"] = args.remote_data
|
||||
data_loader_layer = RpycAudioToTextDataLayer
|
||||
|
||||
data_layer = data_loader_layer(
|
||||
manifest_filepath=args.train_dataset,
|
||||
sample_rate=sample_rate,
|
||||
labels=vocab,
|
||||
batch_size=args.batch_size,
|
||||
num_workers=cpu_per_traindl,
|
||||
**train_dl_params,
|
||||
# normalize_transcripts=False
|
||||
)
|
||||
|
||||
N = len(data_layer)
|
||||
steps_per_epoch = math.ceil(
|
||||
N / (args.batch_size * args.iter_per_step * args.num_gpus)
|
||||
)
|
||||
logging.info("Have {0} examples to train on.".format(N))
|
||||
|
||||
data_preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor(
|
||||
sample_rate=sample_rate,
|
||||
**jasper_params["AudioToMelSpectrogramPreprocessor"],
|
||||
)
|
||||
|
||||
multiply_batch_config = jasper_params.get("MultiplyBatch", None)
|
||||
if multiply_batch_config:
|
||||
multiply_batch = nemo_asr.MultiplyBatch(**multiply_batch_config)
|
||||
|
||||
spectr_augment_config = jasper_params.get("SpectrogramAugmentation", None)
|
||||
if spectr_augment_config:
|
||||
data_spectr_augmentation = nemo_asr.SpectrogramAugmentation(
|
||||
**spectr_augment_config
|
||||
)
|
||||
|
||||
eval_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"])
|
||||
eval_dl_params.update(jasper_params["AudioToTextDataLayer"]["eval"])
|
||||
if args.remote_data:
|
||||
eval_dl_params["rpyc_host"] = args.remote_data
|
||||
del eval_dl_params["train"]
|
||||
del eval_dl_params["eval"]
|
||||
data_layers_eval = []
|
||||
|
||||
if args.eval_datasets:
|
||||
for eval_datasets in args.eval_datasets:
|
||||
data_layer_eval = data_loader_layer(
|
||||
manifest_filepath=eval_datasets,
|
||||
sample_rate=sample_rate,
|
||||
labels=vocab,
|
||||
batch_size=args.eval_batch_size,
|
||||
num_workers=cpu_per_traindl,
|
||||
**eval_dl_params,
|
||||
)
|
||||
|
||||
data_layers_eval.append(data_layer_eval)
|
||||
else:
|
||||
logging.warning("There were no val datasets passed")
|
||||
|
||||
jasper_encoder = nemo_asr.JasperEncoder(
|
||||
feat_in=jasper_params["AudioToMelSpectrogramPreprocessor"]["features"],
|
||||
**jasper_params["JasperEncoder"],
|
||||
)
|
||||
|
||||
jasper_decoder = nemo_asr.JasperDecoderForCTC(
|
||||
feat_in=jasper_params["JasperEncoder"]["jasper"][-1]["filters"],
|
||||
num_classes=len(vocab),
|
||||
)
|
||||
|
||||
ctc_loss = nemo_asr.CTCLossNM(num_classes=len(vocab))
|
||||
|
||||
greedy_decoder = nemo_asr.GreedyCTCDecoder()
|
||||
|
||||
logging.info("================================")
|
||||
logging.info(
|
||||
f"Number of parameters in encoder: {jasper_encoder.num_weights}"
|
||||
)
|
||||
logging.info(
|
||||
f"Number of parameters in decoder: {jasper_decoder.num_weights}"
|
||||
)
|
||||
logging.info(
|
||||
f"Total number of parameters in model: "
|
||||
f"{jasper_decoder.num_weights + jasper_encoder.num_weights}"
|
||||
)
|
||||
logging.info("================================")
|
||||
|
||||
# Train DAG
|
||||
(
|
||||
audio_signal_t,
|
||||
a_sig_length_t,
|
||||
transcript_t,
|
||||
transcript_len_t,
|
||||
) = data_layer()
|
||||
processed_signal_t, p_length_t = data_preprocessor(
|
||||
input_signal=audio_signal_t, length=a_sig_length_t
|
||||
)
|
||||
|
||||
if multiply_batch_config:
|
||||
(
|
||||
processed_signal_t,
|
||||
p_length_t,
|
||||
transcript_t,
|
||||
transcript_len_t,
|
||||
) = multiply_batch(
|
||||
in_x=processed_signal_t,
|
||||
in_x_len=p_length_t,
|
||||
in_y=transcript_t,
|
||||
in_y_len=transcript_len_t,
|
||||
)
|
||||
|
||||
if spectr_augment_config:
|
||||
processed_signal_t = data_spectr_augmentation(
|
||||
input_spec=processed_signal_t
|
||||
)
|
||||
|
||||
encoded_t, encoded_len_t = jasper_encoder(
|
||||
audio_signal=processed_signal_t, length=p_length_t
|
||||
)
|
||||
log_probs_t = jasper_decoder(encoder_output=encoded_t)
|
||||
predictions_t = greedy_decoder(log_probs=log_probs_t)
|
||||
loss_t = ctc_loss(
|
||||
log_probs=log_probs_t,
|
||||
targets=transcript_t,
|
||||
input_length=encoded_len_t,
|
||||
target_length=transcript_len_t,
|
||||
)
|
||||
|
||||
# Callbacks needed to print info to console and Tensorboard
|
||||
train_callback = nemo.core.SimpleLossLoggerCallback(
|
||||
tensors=[loss_t, predictions_t, transcript_t, transcript_len_t],
|
||||
print_func=partial(monitor_asr_train_progress, labels=vocab),
|
||||
get_tb_values=lambda x: [("loss", x[0])],
|
||||
tb_writer=neural_factory.tb_writer,
|
||||
)
|
||||
|
||||
chpt_callback = nemo.core.CheckpointCallback(
|
||||
folder=neural_factory.checkpoint_dir,
|
||||
load_from_folder=args.load_dir,
|
||||
step_freq=args.checkpoint_save_freq,
|
||||
checkpoints_to_keep=30,
|
||||
)
|
||||
|
||||
callbacks = [train_callback, chpt_callback]
|
||||
|
||||
# assemble eval DAGs
|
||||
for i, eval_dl in enumerate(data_layers_eval):
|
||||
(
|
||||
audio_signal_e,
|
||||
a_sig_length_e,
|
||||
transcript_e,
|
||||
transcript_len_e,
|
||||
) = eval_dl()
|
||||
processed_signal_e, p_length_e = data_preprocessor(
|
||||
input_signal=audio_signal_e, length=a_sig_length_e
|
||||
)
|
||||
encoded_e, encoded_len_e = jasper_encoder(
|
||||
audio_signal=processed_signal_e, length=p_length_e
|
||||
)
|
||||
log_probs_e = jasper_decoder(encoder_output=encoded_e)
|
||||
predictions_e = greedy_decoder(log_probs=log_probs_e)
|
||||
loss_e = ctc_loss(
|
||||
log_probs=log_probs_e,
|
||||
targets=transcript_e,
|
||||
input_length=encoded_len_e,
|
||||
target_length=transcript_len_e,
|
||||
)
|
||||
|
||||
# create corresponding eval callback
|
||||
tagname = os.path.basename(args.eval_datasets[i]).split(".")[0]
|
||||
eval_callback = nemo.core.EvaluatorCallback(
|
||||
eval_tensors=[
|
||||
loss_e,
|
||||
predictions_e,
|
||||
transcript_e,
|
||||
transcript_len_e,
|
||||
],
|
||||
user_iter_callback=partial(process_evaluation_batch, labels=vocab),
|
||||
user_epochs_done_callback=partial(
|
||||
process_evaluation_epoch, tag=tagname
|
||||
),
|
||||
eval_step=args.eval_freq,
|
||||
tb_writer=neural_factory.tb_writer,
|
||||
)
|
||||
|
||||
callbacks.append(eval_callback)
|
||||
return loss_t, callbacks, steps_per_epoch
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
name = construct_name(
|
||||
args.exp_name,
|
||||
args.lr,
|
||||
args.batch_size,
|
||||
args.max_steps,
|
||||
args.num_epochs,
|
||||
args.weight_decay,
|
||||
args.optimizer,
|
||||
args.iter_per_step,
|
||||
)
|
||||
log_dir = name
|
||||
if args.work_dir:
|
||||
log_dir = os.path.join(args.work_dir, name)
|
||||
|
||||
# instantiate Neural Factory with supported backend
|
||||
neural_factory = nemo.core.NeuralModuleFactory(
|
||||
backend=nemo.core.Backend.PyTorch,
|
||||
local_rank=args.local_rank,
|
||||
optimization_level=args.amp_opt_level,
|
||||
log_dir=log_dir,
|
||||
checkpoint_dir=args.checkpoint_dir,
|
||||
create_tb_writer=args.create_tb_writer,
|
||||
files_to_copy=[args.model_config, __file__],
|
||||
cudnn_benchmark=args.cudnn_benchmark,
|
||||
tensorboard_dir=args.tensorboard_dir,
|
||||
)
|
||||
args.num_gpus = neural_factory.world_size
|
||||
|
||||
checkpoint_dir = neural_factory.checkpoint_dir
|
||||
if args.local_rank is not None:
|
||||
logging.info("Doing ALL GPU")
|
||||
|
||||
# build dags
|
||||
train_loss, callbacks, steps_per_epoch = create_all_dags(
|
||||
args, neural_factory
|
||||
)
|
||||
# train model
|
||||
neural_factory.train(
|
||||
tensors_to_optimize=[train_loss],
|
||||
callbacks=callbacks,
|
||||
lr_policy=CosineAnnealing(
|
||||
args.max_steps
|
||||
if args.max_steps is not None
|
||||
else args.num_epochs * steps_per_epoch,
|
||||
warmup_steps=args.warmup_steps,
|
||||
),
|
||||
optimizer=args.optimizer,
|
||||
optimization_params={
|
||||
"num_epochs": args.num_epochs,
|
||||
"max_steps": args.max_steps,
|
||||
"lr": args.lr,
|
||||
"betas": (args.beta1, args.beta2),
|
||||
"weight_decay": args.weight_decay,
|
||||
"grad_norm_clip": None,
|
||||
},
|
||||
batches_per_step=args.iter_per_step,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
22
src/plume/models/marblenet_nemo/trial.py
Normal file
22
src/plume/models/marblenet_nemo/trial.py
Normal file
@@ -0,0 +1,22 @@
|
||||
import numpy as np
|
||||
import os
|
||||
import time
|
||||
import copy
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
import matplotlib.pyplot as plt
|
||||
import IPython.display as ipd
|
||||
# import pyaudio as pa
|
||||
import librosa
|
||||
import nemo
|
||||
import nemo.collections.asr as nemo_asr
|
||||
|
||||
# sample rate, Hz
|
||||
SAMPLE_RATE = 16000
|
||||
|
||||
vad_model = nemo_asr.models.EncDecClassificationModel.from_pretrained(
|
||||
"vad_marblenet"
|
||||
)
|
||||
# Preserve a copy of the full config
|
||||
cfg = copy.deepcopy(vad_model._cfg)
|
||||
# print(OmegaConf.to_yaml(cfg))
|
||||
0
src/plume/models/wav2vec2/__init__.py
Normal file
0
src/plume/models/wav2vec2/__init__.py
Normal file
204
src/plume/models/wav2vec2/asr.py
Normal file
204
src/plume/models/wav2vec2/asr.py
Normal file
@@ -0,0 +1,204 @@
|
||||
from io import BytesIO
|
||||
import warnings
|
||||
import itertools as it
|
||||
|
||||
import torch
|
||||
import soundfile as sf
|
||||
import torch.nn.functional as F
|
||||
|
||||
try:
|
||||
from fairseq import utils
|
||||
from fairseq.models import BaseFairseqModel
|
||||
from fairseq.data import Dictionary
|
||||
from fairseq.models.wav2vec.wav2vec2_asr import base_architecture, Wav2VecEncoder
|
||||
except ModuleNotFoundError:
|
||||
warnings.warn("Install fairseq")
|
||||
try:
|
||||
from wav2letter.decoder import CriterionType
|
||||
from wav2letter.criterion import CpuViterbiPath, get_data_ptr_as_bytes
|
||||
except ModuleNotFoundError:
|
||||
warnings.warn("Install wav2letter")
|
||||
|
||||
|
||||
class Wav2VecCtc(BaseFairseqModel):
|
||||
def __init__(self, w2v_encoder, args):
|
||||
super().__init__()
|
||||
self.w2v_encoder = w2v_encoder
|
||||
self.args = args
|
||||
|
||||
def upgrade_state_dict_named(self, state_dict, name):
|
||||
super().upgrade_state_dict_named(state_dict, name)
|
||||
return state_dict
|
||||
|
||||
@classmethod
|
||||
def build_model(cls, args, target_dict):
|
||||
"""Build a new model instance."""
|
||||
base_architecture(args)
|
||||
w2v_encoder = Wav2VecEncoder(args, target_dict)
|
||||
return cls(w2v_encoder, args)
|
||||
|
||||
def get_normalized_probs(self, net_output, log_probs):
|
||||
"""Get normalized probabilities (or log probs) from a net's output."""
|
||||
logits = net_output["encoder_out"]
|
||||
if log_probs:
|
||||
return utils.log_softmax(logits.float(), dim=-1)
|
||||
else:
|
||||
return utils.softmax(logits.float(), dim=-1)
|
||||
|
||||
def forward(self, **kwargs):
|
||||
x = self.w2v_encoder(**kwargs)
|
||||
return x
|
||||
|
||||
|
||||
class W2lDecoder(object):
|
||||
def __init__(self, tgt_dict):
|
||||
self.tgt_dict = tgt_dict
|
||||
self.vocab_size = len(tgt_dict)
|
||||
self.nbest = 1
|
||||
|
||||
self.criterion_type = CriterionType.CTC
|
||||
self.blank = (
|
||||
tgt_dict.index("<ctc_blank>")
|
||||
if "<ctc_blank>" in tgt_dict.indices
|
||||
else tgt_dict.bos()
|
||||
)
|
||||
self.asg_transitions = None
|
||||
|
||||
def generate(self, model, sample, **unused):
|
||||
"""Generate a batch of inferences."""
|
||||
# model.forward normally channels prev_output_tokens into the decoder
|
||||
# separately, but SequenceGenerator directly calls model.encoder
|
||||
encoder_input = {
|
||||
k: v for k, v in sample["net_input"].items() if k != "prev_output_tokens"
|
||||
}
|
||||
emissions = self.get_emissions(model, encoder_input)
|
||||
return self.decode(emissions)
|
||||
|
||||
def get_emissions(self, model, encoder_input):
|
||||
"""Run encoder and normalize emissions"""
|
||||
# encoder_out = models[0].encoder(**encoder_input)
|
||||
encoder_out = model(**encoder_input)
|
||||
if self.criterion_type == CriterionType.CTC:
|
||||
emissions = model.get_normalized_probs(encoder_out, log_probs=True)
|
||||
|
||||
return emissions.transpose(0, 1).float().cpu().contiguous()
|
||||
|
||||
def get_tokens(self, idxs):
|
||||
"""Normalize tokens by handling CTC blank, ASG replabels, etc."""
|
||||
idxs = (g[0] for g in it.groupby(idxs))
|
||||
idxs = filter(lambda x: x != self.blank, idxs)
|
||||
|
||||
return torch.LongTensor(list(idxs))
|
||||
|
||||
|
||||
class W2lViterbiDecoder(W2lDecoder):
|
||||
def __init__(self, tgt_dict):
|
||||
super().__init__(tgt_dict)
|
||||
|
||||
def decode(self, emissions):
|
||||
B, T, N = emissions.size()
|
||||
hypos = list()
|
||||
|
||||
if self.asg_transitions is None:
|
||||
transitions = torch.FloatTensor(N, N).zero_()
|
||||
else:
|
||||
transitions = torch.FloatTensor(self.asg_transitions).view(N, N)
|
||||
|
||||
viterbi_path = torch.IntTensor(B, T)
|
||||
workspace = torch.ByteTensor(CpuViterbiPath.get_workspace_size(B, T, N))
|
||||
CpuViterbiPath.compute(
|
||||
B,
|
||||
T,
|
||||
N,
|
||||
get_data_ptr_as_bytes(emissions),
|
||||
get_data_ptr_as_bytes(transitions),
|
||||
get_data_ptr_as_bytes(viterbi_path),
|
||||
get_data_ptr_as_bytes(workspace),
|
||||
)
|
||||
return [
|
||||
[{"tokens": self.get_tokens(viterbi_path[b].tolist()), "score": 0}]
|
||||
for b in range(B)
|
||||
]
|
||||
|
||||
|
||||
def post_process(sentence: str, symbol: str):
|
||||
if symbol == "sentencepiece":
|
||||
sentence = sentence.replace(" ", "").replace("\u2581", " ").strip()
|
||||
elif symbol == "wordpiece":
|
||||
sentence = sentence.replace(" ", "").replace("_", " ").strip()
|
||||
elif symbol == "letter":
|
||||
sentence = sentence.replace(" ", "").replace("|", " ").strip()
|
||||
elif symbol == "_EOW":
|
||||
sentence = sentence.replace(" ", "").replace("_EOW", " ").strip()
|
||||
elif symbol is not None and symbol != "none":
|
||||
sentence = (sentence + " ").replace(symbol, "").rstrip()
|
||||
return sentence
|
||||
|
||||
|
||||
def get_feature(filepath):
|
||||
def postprocess(feats, sample_rate):
|
||||
if feats.dim == 2:
|
||||
feats = feats.mean(-1)
|
||||
|
||||
assert feats.dim() == 1, feats.dim()
|
||||
|
||||
with torch.no_grad():
|
||||
feats = F.layer_norm(feats, feats.shape)
|
||||
return feats
|
||||
|
||||
wav, sample_rate = sf.read(filepath)
|
||||
feats = torch.from_numpy(wav).float()
|
||||
if torch.cuda.is_available():
|
||||
feats = feats.cuda()
|
||||
feats = postprocess(feats, sample_rate)
|
||||
return feats
|
||||
|
||||
|
||||
def load_model(ctc_model_path, w2v_model_path, target_dict):
|
||||
w2v = torch.load(ctc_model_path)
|
||||
w2v["args"].w2v_path = w2v_model_path
|
||||
model = Wav2VecCtc.build_model(w2v["args"], target_dict)
|
||||
model.load_state_dict(w2v["model"], strict=True)
|
||||
if torch.cuda.is_available():
|
||||
model = model.cuda()
|
||||
return model
|
||||
|
||||
|
||||
class Wav2Vec2ASR(object):
|
||||
"""docstring for Wav2Vec2ASR."""
|
||||
|
||||
def __init__(self, ctc_path, w2v_path, target_dict_path):
|
||||
super(Wav2Vec2ASR, self).__init__()
|
||||
self.target_dict = Dictionary.load(target_dict_path)
|
||||
|
||||
self.model = load_model(ctc_path, w2v_path, self.target_dict)
|
||||
self.model.eval()
|
||||
|
||||
self.generator = W2lViterbiDecoder(self.target_dict)
|
||||
|
||||
def transcribe(self, audio_data, greedy=True):
|
||||
aud_f = BytesIO(audio_data)
|
||||
# aud_seg = pydub.AudioSegment.from_file(aud_f)
|
||||
# feat_seg = aud_seg.set_channels(1).set_sample_width(2).set_frame_rate(16000)
|
||||
# feat_f = io.BytesIO()
|
||||
# feat_seg.export(feat_f, format='wav')
|
||||
# feat_f.seek(0)
|
||||
net_input = {}
|
||||
feature = get_feature(aud_f)
|
||||
net_input["source"] = feature.unsqueeze(0)
|
||||
|
||||
padding_mask = (
|
||||
torch.BoolTensor(net_input["source"].size(1)).fill_(False).unsqueeze(0)
|
||||
)
|
||||
if torch.cuda.is_available():
|
||||
padding_mask = padding_mask.cuda()
|
||||
|
||||
net_input["padding_mask"] = padding_mask
|
||||
sample = {}
|
||||
sample["net_input"] = net_input
|
||||
|
||||
with torch.no_grad():
|
||||
hypo = self.generator.generate(self.model, sample, prefix_tokens=None)
|
||||
hyp_pieces = self.target_dict.string(hypo[0][0]["tokens"].int().cpu())
|
||||
result = post_process(hyp_pieces, "letter")
|
||||
return result
|
||||
234
src/plume/models/wav2vec2/data.py
Normal file
234
src/plume/models/wav2vec2/data.py
Normal file
@@ -0,0 +1,234 @@
|
||||
from pathlib import Path
|
||||
from collections import Counter
|
||||
import shutil
|
||||
import io
|
||||
|
||||
# from time import time
|
||||
|
||||
# import pydub
|
||||
import typer
|
||||
from tqdm import tqdm
|
||||
|
||||
from plume.utils import (
|
||||
ExtendedPath,
|
||||
replace_redundant_spaces_with,
|
||||
lazy_module,
|
||||
random_segs,
|
||||
parallel_apply,
|
||||
batch,
|
||||
run_shell,
|
||||
)
|
||||
|
||||
from plume.utils.vad import VADUtterance
|
||||
|
||||
soundfile = lazy_module("soundfile")
|
||||
pydub = lazy_module("pydub")
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
|
||||
def export_jasper(src_dataset_path: Path, dest_dataset_path: Path, unlink: bool = True):
|
||||
dict_ltr = dest_dataset_path / Path("dict.ltr.txt")
|
||||
(dest_dataset_path / Path("wavs")).mkdir(exist_ok=True, parents=True)
|
||||
tok_counter = Counter()
|
||||
shutil.copy(
|
||||
src_dataset_path / Path("test_manifest.json"),
|
||||
src_dataset_path / Path("valid_manifest.json"),
|
||||
)
|
||||
if unlink:
|
||||
src_wavs = src_dataset_path / Path("wavs")
|
||||
for wav_path in tqdm(list(src_wavs.glob("**/*.wav"))):
|
||||
audio_seg = (
|
||||
pydub.AudioSegment.from_wav(wav_path)
|
||||
.set_frame_rate(16000)
|
||||
.set_channels(1)
|
||||
)
|
||||
dest_path = dest_dataset_path / Path("wavs") / Path(wav_path.name)
|
||||
audio_seg.export(dest_path, format="wav")
|
||||
|
||||
for dataset_kind in ["train", "valid"]:
|
||||
abs_manifest_path = ExtendedPath(
|
||||
src_dataset_path / Path(f"{dataset_kind}_manifest.json")
|
||||
)
|
||||
manifest_data = list(abs_manifest_path.read_jsonl())
|
||||
o_tsv, o_ltr = f"{dataset_kind}.tsv", f"{dataset_kind}.ltr"
|
||||
out_tsv = dest_dataset_path / Path(o_tsv)
|
||||
out_ltr = dest_dataset_path / Path(o_ltr)
|
||||
with out_tsv.open("w") as tsv_f, out_ltr.open("w") as ltr_f:
|
||||
if unlink:
|
||||
tsv_f.write(f"{dest_dataset_path}\n")
|
||||
else:
|
||||
tsv_f.write(f"{src_dataset_path}\n")
|
||||
for md in manifest_data:
|
||||
audio_fname = md["audio_filepath"]
|
||||
pipe_toks = replace_redundant_spaces_with(md["text"], "|").upper()
|
||||
# pipe_toks = "|".join(re.sub(" ", "", md["text"]))
|
||||
# pipe_toks = alnum_to_asr_tokens(md["text"]).upper().replace(" ", "|")
|
||||
tok_counter.update(pipe_toks)
|
||||
letter_toks = " ".join(pipe_toks) + " |\n"
|
||||
frame_count = soundfile.info(audio_fname).frames
|
||||
rel_path = Path(audio_fname).relative_to(src_dataset_path.absolute())
|
||||
ltr_f.write(letter_toks)
|
||||
tsv_f.write(f"{rel_path}\t{frame_count}\n")
|
||||
with dict_ltr.open("w") as d_f:
|
||||
for k, v in tok_counter.most_common():
|
||||
d_f.write(f"{k} {v}\n")
|
||||
(src_dataset_path / Path("valid_manifest.json")).unlink()
|
||||
|
||||
|
||||
@app.command()
|
||||
def set_root(dataset_path: Path, root_path: Path):
|
||||
for dataset_kind in ["train", "valid"]:
|
||||
data_file = dataset_path / Path(dataset_kind).with_suffix(".tsv")
|
||||
with data_file.open("r") as df:
|
||||
lines = df.readlines()
|
||||
with data_file.open("w") as df:
|
||||
lines[0] = str(root_path) + "\n"
|
||||
df.writelines(lines)
|
||||
|
||||
|
||||
@app.command()
|
||||
def convert_audio(log_dir: Path, out_dir: Path):
|
||||
out_dir.mkdir(exist_ok=True, parents=True)
|
||||
all_wavs = list((log_dir).glob("**/*.wav"))
|
||||
name_wav_map = {i.name: i.absolute() for i in all_wavs}
|
||||
exists_wavs = list((out_dir).glob("**/*.wav"))
|
||||
rem_wavs = list(
|
||||
set((i.name for i in all_wavs)) - set((i.name for i in exists_wavs))
|
||||
)
|
||||
rem_wavs_real = [name_wav_map[i] for i in rem_wavs]
|
||||
|
||||
def resample_audio(i):
|
||||
dest_wav = out_dir / i.name
|
||||
if dest_wav.exists():
|
||||
return
|
||||
run_shell(f"ffmpeg -i {i.absolute()} -ac 1 -ar 16000 {dest_wav}", verbose=False)
|
||||
|
||||
parallel_apply(resample_audio, rem_wavs_real, workers=256)
|
||||
|
||||
|
||||
@app.command()
|
||||
def prepare_pretraining(
|
||||
log_dir: Path,
|
||||
dataset_path: Path,
|
||||
format: str = "wav",
|
||||
method: str = "random",
|
||||
max_silence: int = 3000,
|
||||
min_duration: int = 10000,
|
||||
max_duration: int = 30000,
|
||||
fixed_duration: int = 30000,
|
||||
batch_size: int = 100,
|
||||
):
|
||||
audio_dir = dataset_path / "audio"
|
||||
audio_dir.mkdir(exist_ok=True, parents=True)
|
||||
cache_dir = dataset_path / "cache"
|
||||
cache_dir.mkdir(exist_ok=True, parents=True)
|
||||
all_wavs = list((log_dir).glob("**/*.wav"))
|
||||
if method not in ["vad", "random", "fixed"]:
|
||||
typer.echo("should be one of random|fixed")
|
||||
raise typer.Exit()
|
||||
|
||||
def write_seg_arg(arg):
|
||||
seg, dest_wav = arg
|
||||
ob = io.BytesIO()
|
||||
seg.export(ob, format=format)
|
||||
dest_wav.write_bytes(ob.getvalue())
|
||||
ob.close()
|
||||
|
||||
with (dataset_path / "failed.log").open("w") as fl:
|
||||
vad_utt = VADUtterance(
|
||||
max_silence=max_silence,
|
||||
min_utterance=min_duration,
|
||||
max_utterance=max_duration,
|
||||
)
|
||||
|
||||
def vad_process_wav(wav_path):
|
||||
if (cache_dir / wav_path.stem).exists():
|
||||
return []
|
||||
try:
|
||||
aud_seg = pydub.AudioSegment.from_file(wav_path)
|
||||
except pydub.exceptions.CouldntDecodeError:
|
||||
fl.write(wav_path.name + "\n")
|
||||
return []
|
||||
full_seg = aud_seg
|
||||
# segs = random_segs(len(full_seg), min_duration, max_duration)
|
||||
segs = vad_utt.stream_segments(full_seg)
|
||||
audio_chunk_paths = []
|
||||
if len(full_seg) > min_duration:
|
||||
for (i, chunk_seg) in enumerate(segs):
|
||||
dest_wav = audio_dir / (wav_path.stem + f"_{i}.{format}")
|
||||
if dest_wav.exists():
|
||||
continue
|
||||
audio_chunk_paths.append((chunk_seg, dest_wav))
|
||||
(cache_dir / wav_path.stem).touch()
|
||||
return audio_chunk_paths
|
||||
|
||||
def random_process_wav(wav_path):
|
||||
if (cache_dir / wav_path.stem).exists():
|
||||
return []
|
||||
try:
|
||||
aud_seg = pydub.AudioSegment.from_file(wav_path)
|
||||
except pydub.exceptions.CouldntDecodeError:
|
||||
fl.write(wav_path.name + "\n")
|
||||
return []
|
||||
full_seg = aud_seg
|
||||
segs = random_segs(len(full_seg), min_duration, max_duration)
|
||||
audio_chunk_paths = []
|
||||
if len(full_seg) > min_duration:
|
||||
for (i, (start, end)) in enumerate(segs):
|
||||
dest_wav = audio_dir / (wav_path.stem + f"_{i}.{format}")
|
||||
if dest_wav.exists():
|
||||
continue
|
||||
chunk_seg = aud_seg[start:end]
|
||||
audio_chunk_paths.append((chunk_seg, dest_wav))
|
||||
(cache_dir / wav_path.stem).touch()
|
||||
return audio_chunk_paths
|
||||
|
||||
def fixed_process_wav(wav_path):
|
||||
if (cache_dir / wav_path.stem).exists():
|
||||
return []
|
||||
try:
|
||||
aud_seg = pydub.AudioSegment.from_file(wav_path)
|
||||
except pydub.exceptions.CouldntDecodeError:
|
||||
fl.write(wav_path.name + "\n")
|
||||
return []
|
||||
full_seg = aud_seg
|
||||
audio_chunk_paths = []
|
||||
if len(full_seg) > min_duration:
|
||||
for (i, chunk_seg) in enumerate(full_seg[::fixed_duration]):
|
||||
dest_wav = audio_dir / (wav_path.stem + f"_{i}.{format}")
|
||||
if dest_wav.exists() or len(chunk_seg) < min_duration:
|
||||
continue
|
||||
audio_chunk_paths.append((chunk_seg, dest_wav))
|
||||
(cache_dir / wav_path.stem).touch()
|
||||
return audio_chunk_paths
|
||||
|
||||
# warmup
|
||||
pydub.AudioSegment.from_file(all_wavs[0])
|
||||
# parallel_apply(process_wav, all_wavs, pool='process')
|
||||
# parallel_apply(process_wav, all_wavs)
|
||||
seg_f = (
|
||||
vad_process_wav
|
||||
if method == "vad"
|
||||
else (random_process_wav if method == "random" else fixed_process_wav)
|
||||
)
|
||||
for wp_batch in tqdm(batch(all_wavs, n=batch_size)):
|
||||
acp_batch = parallel_apply(seg_f, wp_batch)
|
||||
# acp_batch = list(map(seg_f, tqdm(wp_batch)))
|
||||
flat_acp_batch = [sd for acp in acp_batch for sd in acp]
|
||||
parallel_apply(write_seg_arg, flat_acp_batch)
|
||||
# for acp in acp_batch:
|
||||
# for (seg, des) in acp:
|
||||
# seg.export(des)
|
||||
# for seg_des in tqdm(flat_acp_batch):
|
||||
# write_seg_arg(seg_des)
|
||||
del flat_acp_batch
|
||||
del acp_batch
|
||||
|
||||
|
||||
def main():
|
||||
app()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
49
src/plume/models/wav2vec2/eval.py
Normal file
49
src/plume/models/wav2vec2/eval.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from pathlib import Path
|
||||
import typer
|
||||
from tqdm import tqdm
|
||||
# import pandas as pd
|
||||
|
||||
from plume.utils import (
|
||||
asr_manifest_reader,
|
||||
discard_except_digits,
|
||||
replace_digit_symbol,
|
||||
lazy_module
|
||||
# run_shell,
|
||||
)
|
||||
from ...utils.transcribe import triton_transcribe_grpc_gen
|
||||
|
||||
pd = lazy_module('pandas')
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
|
||||
def manifest(manifest_file: Path, result_file: Path = "results.csv"):
|
||||
from pydub import AudioSegment
|
||||
|
||||
host = "localhost"
|
||||
port = 8044
|
||||
transcriber, audio_prep = triton_transcribe_grpc_gen(host, port, method='whole')
|
||||
result_path = manifest_file.parent / result_file
|
||||
manifest_list = list(asr_manifest_reader(manifest_file))
|
||||
|
||||
def compute_frame(d):
|
||||
audio_file = d["audio_path"]
|
||||
orig_text = d["text"]
|
||||
orig_num = discard_except_digits(replace_digit_symbol(orig_text))
|
||||
aud_seg = AudioSegment.from_file(audio_file)
|
||||
t_audio = audio_prep(aud_seg)
|
||||
asr_text = transcriber(t_audio)
|
||||
asr_num = discard_except_digits(replace_digit_symbol(asr_text))
|
||||
return {
|
||||
"audio_file": audio_file,
|
||||
"asr_text": asr_text,
|
||||
"asr_num": asr_num,
|
||||
"orig_text": orig_text,
|
||||
"orig_num": orig_num,
|
||||
"asr_match": orig_num == asr_num,
|
||||
}
|
||||
|
||||
# df_data = parallel_apply(compute_frame, manifest_list)
|
||||
df_data = map(compute_frame, tqdm(manifest_list))
|
||||
df = pd.DataFrame(df_data)
|
||||
df.to_csv(result_path)
|
||||
53
src/plume/models/wav2vec2/serve.py
Normal file
53
src/plume/models/wav2vec2/serve.py
Normal file
@@ -0,0 +1,53 @@
|
||||
import os
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
# from rpyc.utils.server import ThreadedServer
|
||||
import typer
|
||||
|
||||
from ...utils.serve import ASRService
|
||||
from plume.utils import lazy_callable
|
||||
# from .asr import Wav2Vec2ASR
|
||||
|
||||
ThreadedServer = lazy_callable('rpyc.utils.server.ThreadedServer')
|
||||
Wav2Vec2ASR = lazy_callable('plume.models.wav2vec2.asr.Wav2Vec2ASR')
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
|
||||
def rpyc(
|
||||
w2v_path: Path = "/path/to/base.pt",
|
||||
ctc_path: Path = "/path/to/ctc.pt",
|
||||
target_dict_path: Path = "/path/to/dict.ltr.txt",
|
||||
port: int = int(os.environ.get("ASR_RPYC_PORT", "8044")),
|
||||
):
|
||||
for p in [w2v_path, ctc_path, target_dict_path]:
|
||||
if not p.exists():
|
||||
logging.info(f"{p} doesn't exists")
|
||||
return
|
||||
w2vasr = Wav2Vec2ASR(str(ctc_path), str(w2v_path), str(target_dict_path))
|
||||
service = ASRService(w2vasr)
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
)
|
||||
logging.info("starting asr server...")
|
||||
t = ThreadedServer(service, port=port)
|
||||
t.start()
|
||||
|
||||
|
||||
@app.command()
|
||||
def rpyc_dir(model_dir: Path, port: int = int(os.environ.get("ASR_RPYC_PORT", "8044"))):
|
||||
ctc_path = model_dir / Path("ctc.pt")
|
||||
w2v_path = model_dir / Path("base.pt")
|
||||
target_dict_path = model_dir / Path("dict.ltr.txt")
|
||||
rpyc(w2v_path, ctc_path, target_dict_path, port)
|
||||
|
||||
|
||||
def main():
|
||||
app()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
34
src/plume/models/wav2vec2/train.py
Normal file
34
src/plume/models/wav2vec2/train.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import typer
|
||||
# from fairseq_cli.train import cli_main
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import shlex
|
||||
from plume.utils import lazy_callable
|
||||
|
||||
cli_main = lazy_callable('fairseq_cli.train.cli_main')
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
|
||||
def local(dataset_path: Path):
|
||||
args = f'''--distributed-world-size 1 {dataset_path} \
|
||||
--save-dir /dataset/wav2vec2/model/wav2vec2_l_num_ctc_v1 --post-process letter --valid-subset \
|
||||
valid --no-epoch-checkpoints --best-checkpoint-metric wer --num-workers 4 --max-update 80000 \
|
||||
--sentence-avg --task audio_pretraining --arch wav2vec_ctc --w2v-path /dataset/wav2vec2/pretrained/wav2vec_vox_new.pt \
|
||||
--labels ltr --apply-mask --mask-selection static --mask-other 0 --mask-length 10 --mask-prob 0.5 --layerdrop 0.1 \
|
||||
--mask-channel-selection static --mask-channel-other 0 --mask-channel-length 64 --mask-channel-prob 0.5 \
|
||||
--zero-infinity --feature-grad-mult 0.0 --freeze-finetune-updates 10000 --validate-after-updates 10000 \
|
||||
--optimizer adam --adam-betas '(0.9, 0.98)' --adam-eps 1e-08 --lr 2e-05 --lr-scheduler tri_stage --warmup-steps 8000 \
|
||||
--hold-steps 32000 --decay-steps 40000 --final-lr-scale 0.05 --final-dropout 0.0 --dropout 0.0 \
|
||||
--activation-dropout 0.1 --criterion ctc --attention-dropout 0.0 --max-tokens 1280000 --seed 2337 --log-format json \
|
||||
--log-interval 500 --ddp-backend no_c10d --reset-optimizer --normalize
|
||||
'''
|
||||
new_args = ['train.py']
|
||||
new_args.extend(shlex.split(args))
|
||||
sys.argv = new_args
|
||||
cli_main()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli_main()
|
||||
0
src/plume/models/wav2vec2_transformers/__init__.py
Normal file
0
src/plume/models/wav2vec2_transformers/__init__.py
Normal file
39
src/plume/models/wav2vec2_transformers/asr.py
Normal file
39
src/plume/models/wav2vec2_transformers/asr.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from transformers import Wav2Vec2Tokenizer, Wav2Vec2ForCTC
|
||||
|
||||
# import soundfile as sf
|
||||
from io import BytesIO
|
||||
import torch
|
||||
|
||||
from plume.utils import lazy_module
|
||||
|
||||
sf = lazy_module("soundfile")
|
||||
|
||||
|
||||
class Wav2Vec2TransformersASR(object):
|
||||
"""docstring for Wav2Vec2TransformersASR."""
|
||||
|
||||
def __init__(self, ctc_path, w2v_path, target_dict_path):
|
||||
super(Wav2Vec2TransformersASR, self).__init__()
|
||||
self.tokenizer = Wav2Vec2Tokenizer.from_pretrained(
|
||||
"facebook/wav2vec2-large-960h-lv60-self"
|
||||
)
|
||||
self.model = Wav2Vec2ForCTC.from_pretrained(
|
||||
"facebook/wav2vec2-large-960h-lv60-self"
|
||||
)
|
||||
|
||||
def transcribe(self, audio_data):
|
||||
aud_f = BytesIO(audio_data)
|
||||
# net_input = {}
|
||||
speech_data, _ = sf.read(aud_f)
|
||||
input_values = self.tokenizer(
|
||||
speech_data, return_tensors="pt", padding="longest"
|
||||
).input_values # Batch size 1
|
||||
|
||||
# retrieve logits
|
||||
logits = self.model(input_values).logits
|
||||
|
||||
# take argmax and decode
|
||||
predicted_ids = torch.argmax(logits, dim=-1)
|
||||
|
||||
transcription = self.tokenizer.batch_decode(predicted_ids)[0]
|
||||
return transcription
|
||||
85
src/plume/models/wav2vec2_transformers/data.py
Normal file
85
src/plume/models/wav2vec2_transformers/data.py
Normal file
@@ -0,0 +1,85 @@
|
||||
from pathlib import Path
|
||||
from collections import Counter
|
||||
import shutil
|
||||
|
||||
# import pydub
|
||||
import typer
|
||||
from tqdm import tqdm
|
||||
|
||||
from plume.utils import (
|
||||
ExtendedPath,
|
||||
replace_redundant_spaces_with,
|
||||
lazy_module
|
||||
)
|
||||
soundfile = lazy_module('soundfile')
|
||||
pydub = lazy_module('pydub')
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
|
||||
def export_jasper(src_dataset_path: Path, dest_dataset_path: Path, unlink: bool = True):
|
||||
dict_ltr = dest_dataset_path / Path("dict.ltr.txt")
|
||||
(dest_dataset_path / Path("wavs")).mkdir(exist_ok=True, parents=True)
|
||||
tok_counter = Counter()
|
||||
shutil.copy(
|
||||
src_dataset_path / Path("test_manifest.json"),
|
||||
src_dataset_path / Path("valid_manifest.json"),
|
||||
)
|
||||
if unlink:
|
||||
src_wavs = src_dataset_path / Path("wavs")
|
||||
for wav_path in tqdm(list(src_wavs.glob("**/*.wav"))):
|
||||
audio_seg = (
|
||||
pydub.AudioSegment.from_wav(wav_path)
|
||||
.set_frame_rate(16000)
|
||||
.set_channels(1)
|
||||
)
|
||||
dest_path = dest_dataset_path / Path("wavs") / Path(wav_path.name)
|
||||
audio_seg.export(dest_path, format="wav")
|
||||
|
||||
for dataset_kind in ["train", "valid"]:
|
||||
abs_manifest_path = ExtendedPath(
|
||||
src_dataset_path / Path(f"{dataset_kind}_manifest.json")
|
||||
)
|
||||
manifest_data = list(abs_manifest_path.read_jsonl())
|
||||
o_tsv, o_ltr = f"{dataset_kind}.tsv", f"{dataset_kind}.ltr"
|
||||
out_tsv = dest_dataset_path / Path(o_tsv)
|
||||
out_ltr = dest_dataset_path / Path(o_ltr)
|
||||
with out_tsv.open("w") as tsv_f, out_ltr.open("w") as ltr_f:
|
||||
if unlink:
|
||||
tsv_f.write(f"{dest_dataset_path}\n")
|
||||
else:
|
||||
tsv_f.write(f"{src_dataset_path}\n")
|
||||
for md in manifest_data:
|
||||
audio_fname = md["audio_filepath"]
|
||||
pipe_toks = replace_redundant_spaces_with(md["text"], "|").upper()
|
||||
# pipe_toks = "|".join(re.sub(" ", "", md["text"]))
|
||||
# pipe_toks = alnum_to_asr_tokens(md["text"]).upper().replace(" ", "|")
|
||||
tok_counter.update(pipe_toks)
|
||||
letter_toks = " ".join(pipe_toks) + " |\n"
|
||||
frame_count = soundfile.info(audio_fname).frames
|
||||
rel_path = Path(audio_fname).relative_to(src_dataset_path.absolute())
|
||||
ltr_f.write(letter_toks)
|
||||
tsv_f.write(f"{rel_path}\t{frame_count}\n")
|
||||
with dict_ltr.open("w") as d_f:
|
||||
for k, v in tok_counter.most_common():
|
||||
d_f.write(f"{k} {v}\n")
|
||||
(src_dataset_path / Path("valid_manifest.json")).unlink()
|
||||
|
||||
|
||||
@app.command()
|
||||
def set_root(dataset_path: Path, root_path: Path):
|
||||
for dataset_kind in ["train", "valid"]:
|
||||
data_file = dataset_path / Path(dataset_kind).with_suffix(".tsv")
|
||||
with data_file.open("r") as df:
|
||||
lines = df.readlines()
|
||||
with data_file.open("w") as df:
|
||||
lines[0] = str(root_path) + "\n"
|
||||
df.writelines(lines)
|
||||
|
||||
|
||||
def main():
|
||||
app()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
52
src/plume/models/wav2vec2_transformers/eval.py
Normal file
52
src/plume/models/wav2vec2_transformers/eval.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from pathlib import Path
|
||||
import typer
|
||||
from tqdm import tqdm
|
||||
# import pandas as pd
|
||||
|
||||
from plume.utils import (
|
||||
asr_manifest_reader,
|
||||
discard_except_digits,
|
||||
replace_digit_symbol,
|
||||
lazy_module
|
||||
# run_shell,
|
||||
)
|
||||
from ...utils.transcribe import triton_transcribe_grpc_gen, transcribe_rpyc_gen
|
||||
|
||||
pd = lazy_module('pandas')
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
|
||||
def manifest(manifest_file: Path, result_file: Path = "results.csv", rpyc: bool = False):
|
||||
from pydub import AudioSegment
|
||||
|
||||
host = "localhost"
|
||||
port = 8044
|
||||
if rpyc:
|
||||
transcriber, audio_prep = transcribe_rpyc_gen(host, port)
|
||||
else:
|
||||
transcriber, audio_prep = triton_transcribe_grpc_gen(host, port, method='whole')
|
||||
result_path = manifest_file.parent / result_file
|
||||
manifest_list = list(asr_manifest_reader(manifest_file))
|
||||
|
||||
def compute_frame(d):
|
||||
audio_file = d["audio_path"]
|
||||
orig_text = d["text"]
|
||||
orig_num = discard_except_digits(replace_digit_symbol(orig_text))
|
||||
aud_seg = AudioSegment.from_file(audio_file)
|
||||
t_audio = audio_prep(aud_seg)
|
||||
asr_text = transcriber(t_audio)
|
||||
asr_num = discard_except_digits(replace_digit_symbol(asr_text))
|
||||
return {
|
||||
"audio_file": audio_file,
|
||||
"asr_text": asr_text,
|
||||
"asr_num": asr_num,
|
||||
"orig_text": orig_text,
|
||||
"orig_num": orig_num,
|
||||
"asr_match": orig_num == asr_num,
|
||||
}
|
||||
|
||||
# df_data = parallel_apply(compute_frame, manifest_list)
|
||||
df_data = map(compute_frame, tqdm(manifest_list))
|
||||
df = pd.DataFrame(df_data)
|
||||
df.to_csv(result_path)
|
||||
52
src/plume/models/wav2vec2_transformers/serve.py
Normal file
52
src/plume/models/wav2vec2_transformers/serve.py
Normal file
@@ -0,0 +1,52 @@
|
||||
import os
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
# from rpyc.utils.server import ThreadedServer
|
||||
import typer
|
||||
|
||||
from ...utils.serve import ASRService
|
||||
from plume.utils import lazy_callable
|
||||
# from plume.models.wav2vec2_transformers.asr import Wav2Vec2TransformersASR
|
||||
# from .asr import Wav2Vec2ASR
|
||||
|
||||
ThreadedServer = lazy_callable("rpyc.utils.server.ThreadedServer")
|
||||
Wav2Vec2TransformersASR = lazy_callable(
|
||||
"plume.models.wav2vec2_transformers.asr.Wav2Vec2TransformersASR"
|
||||
)
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
|
||||
def rpyc(
|
||||
w2v_path: Path = "/path/to/base.pt",
|
||||
ctc_path: Path = "/path/to/ctc.pt",
|
||||
target_dict_path: Path = "/path/to/dict.ltr.txt",
|
||||
port: int = int(os.environ.get("ASR_RPYC_PORT", "8044")),
|
||||
):
|
||||
w2vasr = Wav2Vec2TransformersASR(ctc_path, w2v_path, target_dict_path)
|
||||
service = ASRService(w2vasr)
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
)
|
||||
logging.info("starting asr server...")
|
||||
t = ThreadedServer(service, port=port)
|
||||
t.start()
|
||||
|
||||
|
||||
@app.command()
|
||||
def rpyc_dir(model_dir: Path, port: int = int(os.environ.get("ASR_RPYC_PORT", "8044"))):
|
||||
ctc_path = model_dir / Path("ctc.pt")
|
||||
w2v_path = model_dir / Path("base.pt")
|
||||
target_dict_path = model_dir / Path("dict.ltr.txt")
|
||||
rpyc(w2v_path, ctc_path, target_dict_path, port)
|
||||
|
||||
|
||||
def main():
|
||||
app()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
41
src/plume/models/wav2vec2_transformers/test.py
Normal file
41
src/plume/models/wav2vec2_transformers/test.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from transformers import Wav2Vec2Tokenizer, Wav2Vec2ForCTC
|
||||
from datasets import load_dataset
|
||||
import soundfile as sf
|
||||
import torch
|
||||
|
||||
# load model and tokenizer
|
||||
tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-large-960h-lv60-self")
|
||||
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60-self")
|
||||
|
||||
|
||||
# define function to read in sound file
|
||||
def map_to_array(batch):
|
||||
speech, _ = sf.read(batch["file"])
|
||||
batch["speech"] = speech
|
||||
return batch
|
||||
|
||||
|
||||
# load dummy dataset and read soundfiles
|
||||
def main():
|
||||
ds = load_dataset(
|
||||
"patrickvonplaten/librispeech_asr_dummy", "clean", split="validation"
|
||||
)
|
||||
ds = ds.map(map_to_array)
|
||||
|
||||
# tokenize
|
||||
input_values = tokenizer(
|
||||
ds["speech"][:2], return_tensors="pt", padding="longest"
|
||||
).input_values # Batch size 1
|
||||
|
||||
# retrieve logits
|
||||
logits = model(input_values).logits
|
||||
|
||||
# take argmax and decode
|
||||
predicted_ids = torch.argmax(logits, dim=-1)
|
||||
|
||||
transcription = tokenizer.batch_decode(predicted_ids)
|
||||
print(transcription)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
34
src/plume/models/wav2vec2_transformers/train.py
Normal file
34
src/plume/models/wav2vec2_transformers/train.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import typer
|
||||
# from fairseq_cli.train import cli_main
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import shlex
|
||||
from plume.utils import lazy_callable
|
||||
|
||||
cli_main = lazy_callable('fairseq_cli.train.cli_main')
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
|
||||
def local(dataset_path: Path):
|
||||
args = f'''--distributed-world-size 1 {dataset_path} \
|
||||
--save-dir /dataset/wav2vec2/model/wav2vec2_l_num_ctc_v1 --post-process letter --valid-subset \
|
||||
valid --no-epoch-checkpoints --best-checkpoint-metric wer --num-workers 4 --max-update 80000 \
|
||||
--sentence-avg --task audio_pretraining --arch wav2vec_ctc --w2v-path /dataset/wav2vec2/pretrained/wav2vec_vox_new.pt \
|
||||
--labels ltr --apply-mask --mask-selection static --mask-other 0 --mask-length 10 --mask-prob 0.5 --layerdrop 0.1 \
|
||||
--mask-channel-selection static --mask-channel-other 0 --mask-channel-length 64 --mask-channel-prob 0.5 \
|
||||
--zero-infinity --feature-grad-mult 0.0 --freeze-finetune-updates 10000 --validate-after-updates 10000 \
|
||||
--optimizer adam --adam-betas '(0.9, 0.98)' --adam-eps 1e-08 --lr 2e-05 --lr-scheduler tri_stage --warmup-steps 8000 \
|
||||
--hold-steps 32000 --decay-steps 40000 --final-lr-scale 0.05 --final-dropout 0.0 --dropout 0.0 \
|
||||
--activation-dropout 0.1 --criterion ctc --attention-dropout 0.0 --max-tokens 1280000 --seed 2337 --log-format json \
|
||||
--log-interval 500 --ddp-backend no_c10d --reset-optimizer --normalize
|
||||
'''
|
||||
new_args = ['train.py']
|
||||
new_args.extend(shlex.split(args))
|
||||
sys.argv = new_args
|
||||
cli_main()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli_main()
|
||||
Reference in New Issue
Block a user