# plume-asr/jasper/training/data_loaders.py

from functools import partial
import tempfile
from concurrent.futures import ThreadPoolExecutor

import torch
import rpyc
from tqdm import tqdm

import nemo
from nemo.backends.pytorch import DataLayerNM
from nemo.core import DeviceType
from nemo.core.neural_types import NeuralType, AudioSignal, LengthsType, LabelsType
from nemo.utils.decorators import add_port_docs
from nemo.collections.asr.parts.dataset import (
    parsers,
    collections,
    seq_collate_fn,
)

from .featurizer import RpycWaveformFeaturizer

logging = nemo.logging
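
# This module glues NeMo's ASR data pipeline to a remote rpyc feature server:
# manifests and audio live on the rpyc host, and extracted features are
# memoized in process memory so later epochs skip the remote round trip.
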
class CachedAudioDataset(torch.utils.data.Dataset):
"""
Dataset that loads tensors via a json file containing paths to audio
files, transcripts, and durations (in seconds). Each new line is a
different sample. Example below:
{"audio_filepath": "/path/to/audio.wav", "text_filepath":
"/path/to/audio.txt", "duration": 23.147}
...
{"audio_filepath": "/path/to/audio.wav", "text": "the
transcription", offset": 301.75, "duration": 0.82, "utt":
"utterance_id", "ctm_utt": "en_4156", "side": "A"}
Args:
manifest_filepath: Path to manifest json as described above. Can
be comma-separated paths.
labels: String containing all the possible characters to map to
featurizer: Initialized featurizer class that converts paths of
audio to feature tensors
max_duration: If audio exceeds this length, do not include in dataset
min_duration: If audio is less than this length, do not include
in dataset
max_utts: Limit number of utterances
blank_index: blank character index, default = -1
unk_index: unk_character index, default = -1
normalize: whether to normalize transcript text (default): True
bos_id: Id of beginning of sequence symbol to append if not None
eos_id: Id of end of sequence symbol to append if not None
load_audio: Boolean flag indicate whether do or not load audio
"""

    def __init__(
        self,
        manifest_filepath,
        labels,
        featurizer,
        max_duration=None,
        min_duration=None,
        max_utts=0,
        blank_index=-1,
        unk_index=-1,
        normalize=True,
        trim=False,
        bos_id=None,
        eos_id=None,
        load_audio=True,
        parser="en",
    ):
        self.collection = collections.ASRAudioText(
            manifests_files=manifest_filepath.split(","),
            parser=parsers.make_parser(
                labels=labels,
                name=parser,
                unk_id=unk_index,
                blank_id=blank_index,
                do_normalize=normalize,
            ),
            min_duration=min_duration,
            max_duration=max_duration,
            max_number=max_utts,
        )
        self.index_feature_map = {}
        self.featurizer = featurizer
        self.trim = trim
        self.eos_id = eos_id
        self.bos_id = bos_id
        self.load_audio = load_audio
        logging.info(f"initializing dataset {manifest_filepath}")
        def exec_func(i):
            return self[i]

        task_count = len(self.collection)
        with ThreadPoolExecutor() as exe:
            logging.info("starting all loading tasks")
            list(
                tqdm(
                    exe.map(exec_func, range(task_count)),
                    position=0,
                    leave=True,
                    total=task_count,
                )
            )
        logging.info("initialization complete")

    def __getitem__(self, index):
        sample = self.collection[index]
        if self.load_audio:
            # Compute features on first access and memoize them per index.
            cached_features = self.index_feature_map.get(index)
            if cached_features is not None:
                features = cached_features
            else:
                features = self.featurizer.process(
                    sample.audio_file,
                    offset=0,
                    duration=sample.duration,
                    trim=self.trim,
                )
                self.index_feature_map[index] = features
            f, fl = features, torch.tensor(features.shape[0]).long()
        else:
            f, fl = None, None

        t, tl = sample.text_tokens, len(sample.text_tokens)
        if self.bos_id is not None:
            t = [self.bos_id] + t
            tl += 1
        if self.eos_id is not None:
            t = t + [self.eos_id]
            tl += 1

        return f, fl, torch.tensor(t).long(), torch.tensor(tl).long()

    def __len__(self):
        return len(self.collection)
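

# Sketch of direct use (illustrative only; the host name and manifest path are
# hypothetical, and the featurizer is normally built by the data layer below):
#
#     conn = rpyc.connect("feature-server", 8064).root
#     featurizer = RpycWaveformFeaturizer(
#         sample_rate=16000, int_values=False, augmentor=None, rpyc_conn=conn
#     )
#     ds = CachedAudioDataset("/tmp/train_manifest.json", labels, featurizer)
#     f, f_len, t, t_len = ds[0]  # waveform, its length, tokens, token count
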
class RpycAudioToTextDataLayer(DataLayerNM):
"""Data Layer for general ASR tasks.
Module which reads ASR labeled data. It accepts comma-separated
JSON manifest files describing the correspondence between wav audio files
and their transcripts. JSON files should be of the following format::
{"audio_filepath": path_to_wav_0, "duration": time_in_sec_0, "text": \
transcript_0}
...
{"audio_filepath": path_to_wav_n, "duration": time_in_sec_n, "text": \
transcript_n}
Args:
manifest_filepath (str): Dataset parameter.
Path to JSON containing data.
labels (list): Dataset parameter.
List of characters that can be output by the ASR model.
For Jasper, this is the 28 character set {a-z '}. The CTC blank
symbol is automatically added later for models using ctc.
batch_size (int): batch size
sample_rate (int): Target sampling rate for data. Audio files will be
resampled to sample_rate if it is not already.
Defaults to 16000.
int_values (bool): Bool indicating whether the audio file is saved as
int data or float data.
Defaults to False.
eos_id (id): Dataset parameter.
End of string symbol id used for seq2seq models.
Defaults to None.
min_duration (float): Dataset parameter.
All training files which have a duration less than min_duration
are dropped. Note: Duration is read from the manifest JSON.
Defaults to 0.1.
max_duration (float): Dataset parameter.
All training files which have a duration more than max_duration
are dropped. Note: Duration is read from the manifest JSON.
Defaults to None.
normalize_transcripts (bool): Dataset parameter.
Whether to use automatic text cleaning.
It is highly recommended to manually clean text for best results.
Defaults to True.
trim_silence (bool): Whether to use trim silence from beginning and end
of audio signal using librosa.effects.trim().
Defaults to False.
load_audio (bool): Dataset parameter.
Controls whether the dataloader loads the audio signal and
transcript or just the transcript.
Defaults to True.
drop_last (bool): See PyTorch DataLoader.
Defaults to False.
shuffle (bool): See PyTorch DataLoader.
Defaults to True.
num_workers (int): See PyTorch DataLoader.
Defaults to 0.
perturb_config (dict): Currently disabled.
"""

    @property
    @add_port_docs()
    def output_ports(self):
        """Returns definitions of module output ports."""
        return {
            "audio_signal": NeuralType(
                ("B", "T"),
                AudioSignal(freq=self._sample_rate)
                if self._sample_rate is not None
                else AudioSignal(),
            ),
            "a_sig_length": NeuralType(tuple("B"), LengthsType()),
            "transcripts": NeuralType(("B", "T"), LabelsType()),
            "transcript_length": NeuralType(tuple("B"), LengthsType()),
        }

    def __init__(
        self,
        manifest_filepath,
        labels,
        batch_size,
        sample_rate=16000,
        int_values=False,
        bos_id=None,
        eos_id=None,
        pad_id=None,
        min_duration=0.1,
        max_duration=None,
        normalize_transcripts=True,
        trim_silence=False,
        load_audio=True,
        rpyc_host="",
        drop_last=False,
        shuffle=True,
        num_workers=0,
    ):
        super().__init__()
        self._sample_rate = sample_rate

        def rpyc_root_fn():
            return rpyc.connect(
                rpyc_host, 8064, config={"sync_request_timeout": 600}
            ).root

        rpyc_conn = rpyc_root_fn()
        self._featurizer = RpycWaveformFeaturizer(
            sample_rate=self._sample_rate,
            int_values=int_values,
            augmentor=None,
            rpyc_conn=rpyc_conn,
        )
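
        # The manifests also live on the rpyc host; copy each one into a
        # local temp file so collections.ASRAudioText can read it from disk.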
        def read_remote_manifests():
            local_mp = []
            for mrp in manifest_filepath.split(","):
                md = rpyc_conn.read_path(mrp)
                mf = tempfile.NamedTemporaryFile(
                    dir="/tmp", prefix="jasper_manifest.", delete=False
                )
                mf.write(md)
                mf.close()
                local_mp.append(mf.name)
            return ",".join(local_mp)

        local_manifest_filepath = read_remote_manifests()

        dataset_params = {
            "manifest_filepath": local_manifest_filepath,
            "labels": labels,
            "featurizer": self._featurizer,
            "max_duration": max_duration,
            "min_duration": min_duration,
            "normalize": normalize_transcripts,
            "trim": trim_silence,
            "bos_id": bos_id,
            "eos_id": eos_id,
            "load_audio": load_audio,
        }
        self._dataset = CachedAudioDataset(**dataset_params)
        self._batch_size = batch_size

        # Set up data loader
        if self._placement == DeviceType.AllGpu:
            logging.info("Parallelizing Datalayer.")
            sampler = torch.utils.data.distributed.DistributedSampler(self._dataset)
        else:
            sampler = None

        if batch_size == -1:
            batch_size = len(self._dataset)

        pad_id = 0 if pad_id is None else pad_id
        self._dataloader = torch.utils.data.DataLoader(
            dataset=self._dataset,
            batch_size=batch_size,
            collate_fn=partial(seq_collate_fn, token_pad_value=pad_id),
            drop_last=drop_last,
            shuffle=shuffle if sampler is None else False,
            sampler=sampler,
            # Pinned to 1 rather than the num_workers argument, presumably
            # because the rpyc connection and in-memory feature cache do not
            # transfer cleanly across multiple worker processes.
            num_workers=1,
        )

    def __len__(self):
        return len(self._dataset)

    @property
    def dataset(self):
        # DataLayerNM convention: return None and expose data_iterator instead.
        return None

    @property
    def data_iterator(self):
        return self._dataloader
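

# Usage sketch (illustrative; the manifest path and rpyc_host below are
# hypothetical, and the factory/port-unpacking pattern assumes the NeMo 0.x
# API this module is written against):
#
#     import nemo
#     nemo.core.NeuralModuleFactory()
#     data_layer = RpycAudioToTextDataLayer(
#         manifest_filepath="/data/manifests/train.json",
#         labels=list("abcdefghijklmnopqrstuvwxyz '"),
#         batch_size=32,
#         rpyc_host="feature-server.internal",
#     )
#     audio, audio_len, transcript, transcript_len = data_layer()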