plume-asr/jasper/training/data_loaders.py

335 lines
11 KiB
Python
Raw Normal View History

1. integrated data generator using google tts 2. added training script fix module packaging issue implement call audio data recycler for asr 1. added streamlit based validation ui with mongodb datastore integration 2. fix asr wrong sample rate inference 3. update requirements 1. refactored streamlit code 2. fixed issues in data manifest handling refresh to next entry on submit and comment out mongo clearing code for safety :P add validation ui and post processing to correct using validation data 1. added a tool to extract asr data from gcp transcripts logs 2. implement a function to export all call logs in a mongodb to a caller-id based yaml file 3. clean-up leaderboard duration logic 4. added a wip dataloader service 5. made the asr_data_writer util more generic with verbose flags and unique filename 6. added extendedpath util class with json support and mongo_conn function to connect to a mongo node 7. refactored the validation post processing to dump a ui config for validation 8. included utility functions to correct, fill update and clear annotations from mongodb data 9. refactored the ui logic to be more generic for any asr data 10. updated setup.py dependencies to support the above features unlink temporary files after transcribing 1. clean-up unused data process code 2. fix invalid sample no from mongo 3. data loader service return remote netref 1. added training utils with custom data loaders with remote rpyc dataservice support 2. fix validation correction dump path 3. cache dataset for precaching before training to memory 4. update dependencies 1. implement dataset augmentation and validation in process 2. added option to skip 'incorrect' annotations in validation data 3. added confirmation on clearing mongo collection 4. added an option to navigate to a given text in the validation ui 5. added a dataset and remote option to trainer to load dataset from directory and remote rpyc service 1. added utility command to export call logs 2. 
mongo conn accepts port refactored module structure 1. enabled silence stripping in chunks when recycling audio from asr logs 2. limit asr recycling to 1 min of start audio to get reliable alignments and ignoring agent channel 3. added rev recycler for generating asr dataset from rev transcripts and audio 4. update pydub dependency for silence stripping fn and removing threadpool hardcoded worker count 1. added support for mono/dual channel rev transcripts 2. handle errors when extracting datapoints from rev meta data 3. added support for annotation only task when dumping ui data cleanup rev recycle added option to disable plots during validation fix skipping null audio and add more verbose logs respect verbose flag don't load audio for annotation only ui and keep spoken as text for normal asr validation 1. refactored wav chunk processing method 2. renamed streamlit to validation_ui show duration on validation of dataset parallelize data loading from remote skipping invalid data points 1. removed the transcriber_pretrained/speller from utils 2. introduced get_mongo_coll to get the collection object directly from mongo uri 3. removed processing of correction entries to remove space/upper casing refactor validation process arguments and logging 1. added a data extraction type argument 2. cleanup/refactor 1. using dataname args for update/fill annotations 2. rename to dump_ui added support for name/dates/cities call data extraction and more logs handling non-pnr cases without parens in text data 1. added conv data generator 2. more utils 1. added start delay arg in call recycler 2. implement ui_dump/manifest writer in call_recycler itself 3. refactored call data point plotter 4. added sample-ui task-ui on the validation process 5. implemented call-quality stats using corrections from mongo 6. support deleting cursors on mongo 7. implement multiple task support on validation ui based on task_id mongo field fix 11st to 11th in ordinal stripping silence on call chunk 1. 
added option to strip silent chunks 2. computing caller quality based on task-id of corrections 1. fix update-correction to use ui_dump instead of manifest 2. update training params no of checkpoints on chpk frequency 1. split extract all data types in one shot with --extraction-type all flag 2. add notes about diffing split extracted and original data 3. add a nlu conv generator to generate conv data based on nlu utterances and entities 4. add task uid support for dumping corrections 5. abstracted generate date fn 1. added a test generator and slu evaluator 2. ui dump now include gcp results 3. showing default option for more args validation process commands added evaluation command clean-up
2020-04-08 11:56:27 +00:00
from functools import partial
import tempfile
# from typing import Any, Dict, List, Optional
import torch
import nemo
# import nemo.collections.asr as nemo_asr
from nemo.backends.pytorch import DataLayerNM
from nemo.core import DeviceType
# from nemo.core.neural_types import *
from nemo.core.neural_types import NeuralType, AudioSignal, LengthsType, LabelsType
from nemo.utils.decorators import add_port_docs
from nemo.collections.asr.parts.dataset import (
# AudioDataset,
# AudioLabelDataset,
# KaldiFeatureDataset,
# TranscriptDataset,
parsers,
collections,
seq_collate_fn,
)
# from functools import lru_cache
import rpyc
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
from .featurizer import RpycWaveformFeaturizer
# from nemo.collections.asr.parts.features import WaveformFeaturizer
# from nemo.collections.asr.parts.perturb import AudioAugmentor, perturbation_types
# Module-wide logger: throughout this file the name ``logging`` refers to
# NeMo's logger object (``nemo.logging``), not the standard-library module.
logging = nemo.logging
class CachedAudioDataset(torch.utils.data.Dataset):
    """ASR dataset that featurizes every sample once and caches it in memory.

    Samples are described by one or more json manifest files, one sample per
    line, for example::

        {"audio_filepath": "/path/to/audio.wav", "text_filepath":
        "/path/to/audio.txt", "duration": 23.147}
        ...
        {"audio_filepath": "/path/to/audio.wav", "text": "the
        transcription", "offset": 301.75, "duration": 0.82, "utt":
        "utterance_id", "ctm_utt": "en_4156", "side": "A"}

    Manifests are parsed by ``collections.ASRAudioText``. When ``load_audio``
    is true, every sample is featurized in a thread pool during construction
    and stored in ``index_feature_map`` (keyed by sample index); later
    ``__getitem__`` calls reuse the cached tensor instead of calling the
    featurizer again.

    Args:
        manifest_filepath: Path to a manifest json as described above. Can
            be comma-separated paths.
        labels: String containing all the possible characters to map to.
        featurizer: Initialized featurizer whose ``process`` method converts
            an audio path into a feature tensor.
        max_duration: If audio exceeds this length, do not include it.
        min_duration: If audio is shorter than this length, do not include it.
        max_utts: Limit number of utterances.
        blank_index: Blank character index, default = -1.
        unk_index: Unk character index, default = -1.
        normalize: Whether to normalize transcript text. Default: True.
        trim: Forwarded as ``trim=`` to ``featurizer.process``.
        bos_id: Id of beginning-of-sequence symbol to prepend if not None.
        eos_id: Id of end-of-sequence symbol to append if not None.
        load_audio: Whether to load (and cache) audio features at all. When
            False only transcripts are returned and nothing is precached.
        parser: Parser name passed to ``parsers.make_parser``. Default: "en".
    """

    def __init__(
        self,
        manifest_filepath,
        labels,
        featurizer,
        max_duration=None,
        min_duration=None,
        max_utts=0,
        blank_index=-1,
        unk_index=-1,
        normalize=True,
        trim=False,
        bos_id=None,
        eos_id=None,
        load_audio=True,
        parser="en",
    ):
        self.collection = collections.ASRAudioText(
            manifests_files=manifest_filepath.split(","),
            parser=parsers.make_parser(
                labels=labels,
                name=parser,
                unk_id=unk_index,
                blank_id=blank_index,
                do_normalize=normalize,
            ),
            min_duration=min_duration,
            max_duration=max_duration,
            max_number=max_utts,
        )
        # index -> feature tensor, filled lazily by __getitem__.
        self.index_feature_map = {}
        self.featurizer = featurizer
        self.trim = trim
        self.eos_id = eos_id
        self.bos_id = bos_id
        self.load_audio = load_audio

        logging.info(f"initializing dataset {manifest_filepath}")
        # Without audio there is nothing to cache, so skip the full pass.
        if self.load_audio:
            self._precache_features()
        logging.info("initializing complete")

    def _precache_features(self):
        """Featurize every sample once, in parallel, to fill the cache."""
        task_count = len(self.collection)
        logging.info("starting all loading tasks")
        with ThreadPoolExecutor() as exe:
            # __getitem__ stores features in index_feature_map as a side
            # effect; the returned items are discarded. Each index is handled
            # by exactly one task, so no two tasks write the same key.
            # list() drains the iterator so a featurization error is
            # re-raised here rather than silently dropped.
            list(
                tqdm(
                    exe.map(self.__getitem__, range(task_count)),
                    position=0,
                    leave=True,
                    total=task_count,
                )
            )

    def __getitem__(self, index):
        """Return ``(features, feature_len, tokens, token_len)`` for a sample.

        ``features`` and ``feature_len`` are None when ``load_audio`` is
        False. Tokens get ``bos_id``/``eos_id`` added when those are set; the
        sample's own token list is never mutated.
        """
        sample = self.collection[index]
        if self.load_audio:
            features = self.index_feature_map.get(index)
            if features is None:
                features = self.featurizer.process(
                    sample.audio_file,
                    offset=0,
                    duration=sample.duration,
                    trim=self.trim,
                )
                self.index_feature_map[index] = features
            f, fl = features, torch.tensor(features.shape[0]).long()
        else:
            f, fl = None, None

        t, tl = sample.text_tokens, len(sample.text_tokens)
        if self.bos_id is not None:
            t = [self.bos_id] + t
            tl += 1
        if self.eos_id is not None:
            t = t + [self.eos_id]
            tl += 1
        return f, fl, torch.tensor(t).long(), torch.tensor(tl).long()

    def __len__(self):
        """Number of samples that survived manifest filtering."""
        return len(self.collection)
class RpycAudioToTextDataLayer(DataLayerNM):
    """Data Layer for ASR tasks whose data is served by an rpyc data service.

    Module which reads ASR labeled data. It accepts comma-separated JSON
    manifest paths (read on the rpyc server via ``read_path``) describing the
    correspondence between wav audio files and their transcripts. JSON files
    should be of the following format::

        {"audio_filepath": path_to_wav_0, "duration": time_in_sec_0, "text": transcript_0}
        ...
        {"audio_filepath": path_to_wav_n, "duration": time_in_sec_n, "text": transcript_n}

    Each manifest is copied to a local temporary file, parsed into a
    ``CachedAudioDataset`` (which precaches all features), and the temporary
    copies are then removed.

    Args:
        manifest_filepath (str): Dataset parameter.
            Comma-separated paths (on the rpyc server) to JSON manifests.
        labels (list): Dataset parameter.
            List of characters that can be output by the ASR model.
            For Jasper, this is the 28 character set {a-z '}. The CTC blank
            symbol is automatically added later for models using ctc.
        batch_size (int): batch size. -1 means the whole dataset in one batch.
        sample_rate (int): Target sampling rate passed to the featurizer and
            advertised on the ``audio_signal`` port. Defaults to 16000.
        int_values (bool): Passed to the featurizer; presumably indicates the
            audio is stored as int rather than float data. Defaults to False.
        bos_id (int): Dataset parameter. Beginning-of-sequence id prepended to
            transcripts when not None. Defaults to None.
        eos_id (int): Dataset parameter. End-of-sequence id appended to
            transcripts when not None. Defaults to None.
        pad_id (int): Token value used to pad transcripts in a batch.
            Defaults to None, which means 0.
        min_duration (float): Dataset parameter.
            All training files which have a duration less than min_duration
            are dropped. Note: Duration is read from the manifest JSON.
            Defaults to 0.1.
        max_duration (float): Dataset parameter.
            All training files which have a duration more than max_duration
            are dropped. Note: Duration is read from the manifest JSON.
            Defaults to None.
        normalize_transcripts (bool): Dataset parameter.
            Whether to use automatic text cleaning. Defaults to True.
        trim_silence (bool): Passed to the dataset as ``trim`` and forwarded
            to the featurizer. Defaults to False.
        load_audio (bool): Dataset parameter.
            Controls whether the dataloader loads the audio signal and
            transcript or just the transcript. Defaults to True.
        rpyc_host (str): Host of the rpyc data service. Defaults to "".
        drop_last (bool): See PyTorch DataLoader. Defaults to False.
        shuffle (bool): See PyTorch DataLoader. Ignored when a distributed
            sampler is used. Defaults to True.
        num_workers (int): See PyTorch DataLoader. Defaults to 0.
        rpyc_port (int): Port of the rpyc data service. Defaults to 8064.
    """

    @property
    @add_port_docs()
    def output_ports(self):
        """Returns definitions of module output ports.
        """
        if self._sample_rate is not None:
            audio_type = AudioSignal(freq=self._sample_rate)
        else:
            audio_type = AudioSignal()
        return {
            "audio_signal": NeuralType(("B", "T"), audio_type),
            "a_sig_length": NeuralType(tuple("B"), LengthsType()),
            "transcripts": NeuralType(("B", "T"), LabelsType()),
            "transcript_length": NeuralType(tuple("B"), LengthsType()),
        }

    def __init__(
        self,
        manifest_filepath,
        labels,
        batch_size,
        sample_rate=16000,
        int_values=False,
        bos_id=None,
        eos_id=None,
        pad_id=None,
        min_duration=0.1,
        max_duration=None,
        normalize_transcripts=True,
        trim_silence=False,
        load_audio=True,
        rpyc_host="",
        drop_last=False,
        shuffle=True,
        num_workers=0,
        rpyc_port=8064,
    ):
        import os  # only needed to remove the temporary manifest copies

        super().__init__()
        self._sample_rate = sample_rate

        rpyc_conn = rpyc.connect(
            rpyc_host, rpyc_port, config={"sync_request_timeout": 600}
        ).root
        self._featurizer = RpycWaveformFeaturizer(
            sample_rate=self._sample_rate,
            int_values=int_values,
            augmentor=None,
            rpyc_conn=rpyc_conn,
        )

        # The manifests live on the rpyc server: copy each one into a local
        # temp file so the dataset can parse it. CachedAudioDataset builds its
        # sample collection inside its constructor, so the copies are removed
        # as soon as it returns (or if anything fails part-way) rather than
        # being left behind in /tmp.
        # NOTE(review): relies on collections.ASRAudioText reading the
        # manifests eagerly at construction time.
        local_manifests = []
        try:
            for remote_path in manifest_filepath.split(","):
                manifest_data = rpyc_conn.read_path(remote_path)
                with tempfile.NamedTemporaryFile(
                    dir="/tmp", prefix="jasper_manifest.", delete=False
                ) as local_file:
                    local_manifests.append(local_file.name)
                    local_file.write(manifest_data)
            self._dataset = CachedAudioDataset(
                manifest_filepath=",".join(local_manifests),
                labels=labels,
                featurizer=self._featurizer,
                max_duration=max_duration,
                min_duration=min_duration,
                normalize=normalize_transcripts,
                trim=trim_silence,
                bos_id=bos_id,
                eos_id=eos_id,
                load_audio=load_audio,
            )
        finally:
            for path in local_manifests:
                try:
                    os.unlink(path)
                except OSError:
                    logging.warning(f"could not remove temp manifest {path}")

        self._batch_size = batch_size

        # Set up data loader
        if self._placement == DeviceType.AllGpu:
            logging.info("Parallelizing Datalayer.")
            sampler = torch.utils.data.distributed.DistributedSampler(self._dataset)
        else:
            sampler = None

        if batch_size == -1:
            batch_size = len(self._dataset)

        pad_id = 0 if pad_id is None else pad_id
        self._dataloader = torch.utils.data.DataLoader(
            dataset=self._dataset,
            batch_size=batch_size,
            collate_fn=partial(seq_collate_fn, token_pad_value=pad_id),
            drop_last=drop_last,
            # DataLoader does not allow shuffle=True together with a sampler.
            shuffle=shuffle if sampler is None else False,
            sampler=sampler,
            # Previously hard-coded to 1, silently ignoring this argument.
            num_workers=num_workers,
        )

    def __len__(self):
        """Number of samples in the underlying dataset."""
        return len(self._dataset)

    @property
    def dataset(self):
        # Deliberately None: this layer supplies its own DataLoader through
        # data_iterator. NOTE(review): presumably NeMo falls back to
        # data_iterator when dataset is None — confirm for the pinned version.
        return None

    @property
    def data_iterator(self):
        """The preconfigured PyTorch DataLoader over the cached dataset."""
        return self._dataloader