from functools import partial

import tempfile

# from typing import Any, Dict, List, Optional

import torch

import nemo

# import nemo.collections.asr as nemo_asr
from nemo.backends.pytorch import DataLayerNM
from nemo.core import DeviceType

# from nemo.core.neural_types import *
from nemo.core.neural_types import NeuralType, AudioSignal, LengthsType, LabelsType
from nemo.utils.decorators import add_port_docs

from nemo.collections.asr.parts.dataset import (
    # AudioDataset,
    # AudioLabelDataset,
    # KaldiFeatureDataset,
    # TranscriptDataset,
    parsers,
    collections,
    seq_collate_fn,
)

# from functools import lru_cache
import rpyc
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm

from .featurizer import RpycWaveformFeaturizer

# from nemo.collections.asr.parts.features import WaveformFeaturizer

# from nemo.collections.asr.parts.perturb import AudioAugmentor, perturbation_types


logging = nemo.logging


class CachedAudioDataset(torch.utils.data.Dataset):
    """
    Dataset that loads tensors via a json file containing paths to audio
    files, transcripts, and durations (in seconds). Each new line is a
    different sample. Example below:

    {"audio_filepath": "/path/to/audio.wav", "text_filepath":
    "/path/to/audio.txt", "duration": 23.147}
    ...
    {"audio_filepath": "/path/to/audio.wav", "text": "the
    transcription", "offset": 301.75, "duration": 0.82, "utt":
    "utterance_id", "ctm_utt": "en_4156", "side": "A"}

    All features are computed eagerly (in parallel) at construction time
    and cached in memory, so subsequent accesses are served from the cache
    instead of re-invoking the featurizer.

    Args:
        manifest_filepath: Path to manifest json as described above. Can
            be comma-separated paths.
        labels: String containing all the possible characters to map to.
        featurizer: Initialized featurizer class that converts paths of
            audio to feature tensors.
        max_duration: If audio exceeds this length, do not include in dataset.
        min_duration: If audio is less than this length, do not include
            in dataset.
        max_utts: Limit number of utterances.
        blank_index: Blank character index. Defaults to -1.
        unk_index: Unknown character index. Defaults to -1.
        normalize: Whether to normalize transcript text. Defaults to True.
        trim: Whether to trim leading/trailing silence from the audio
            signal. Defaults to False.
        bos_id: Id of beginning-of-sequence symbol to prepend if not None.
        eos_id: Id of end-of-sequence symbol to append if not None.
        load_audio: Boolean flag indicating whether or not to load audio.
        parser: Name of the text parser to build (passed to
            parsers.make_parser). Defaults to "en".
    """
    def __init__(
        self,
        manifest_filepath,
        labels,
        featurizer,
        max_duration=None,
        min_duration=None,
        max_utts=0,
        blank_index=-1,
        unk_index=-1,
        normalize=True,
        trim=False,
        bos_id=None,
        eos_id=None,
        load_audio=True,
        parser="en",
    ):
        self.collection = collections.ASRAudioText(
            manifests_files=manifest_filepath.split(","),
            parser=parsers.make_parser(
                labels=labels,
                name=parser,
                unk_id=unk_index,
                blank_id=blank_index,
                do_normalize=normalize,
            ),
            min_duration=min_duration,
            max_duration=max_duration,
            max_number=max_utts,
        )
        self.index_feature_map = {}

        self.featurizer = featurizer
        self.trim = trim
        self.eos_id = eos_id
        self.bos_id = bos_id
        self.load_audio = load_audio
        logging.info(f"initializing dataset {manifest_filepath}")

        # Eagerly featurize every utterance in parallel; __getitem__
        # populates self.index_feature_map, so later accesses are cache hits.
        def exec_func(i):
            return self[i]

        task_count = len(self.collection)
        with ThreadPoolExecutor() as exe:
            logging.info("starting all loading tasks")
            list(
                tqdm(
                    exe.map(exec_func, range(task_count)),
                    position=0,
                    leave=True,
                    total=task_count,
                )
            )
        logging.info("initialization complete")
    def __getitem__(self, index):
        sample = self.collection[index]
        if self.load_audio:
            cached_features = self.index_feature_map.get(index)
            if cached_features is not None:
                features = cached_features
            else:
                features = self.featurizer.process(
                    sample.audio_file,
                    offset=0,
                    duration=sample.duration,
                    trim=self.trim,
                )
                self.index_feature_map[index] = features
            f, fl = features, torch.tensor(features.shape[0]).long()
        else:
            f, fl = None, None

        t, tl = sample.text_tokens, len(sample.text_tokens)
        if self.bos_id is not None:
            t = [self.bos_id] + t
            tl += 1
        if self.eos_id is not None:
            t = t + [self.eos_id]
            tl += 1

        return f, fl, torch.tensor(t).long(), torch.tensor(tl).long()

    def __len__(self):
        return len(self.collection)
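

# Usage sketch for CachedAudioDataset (illustrative only): the host name,
# manifest path, and label set below are hypothetical stand-ins, not values
# from this repo.
#
#     conn = rpyc.connect("feature-host", 8064).root
#     featurizer = RpycWaveformFeaturizer(
#         sample_rate=16000, int_values=False, augmentor=None, rpyc_conn=conn,
#     )
#     dataset = CachedAudioDataset(
#         manifest_filepath="/data/train_manifest.json",
#         labels=" abcdefghijklmnopqrstuvwxyz'",
#         featurizer=featurizer,
#     )
#     features, feat_len, tokens, tokens_len = dataset[0]  # served from cache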


class RpycAudioToTextDataLayer(DataLayerNM):
    """Data Layer for general ASR tasks.

    Module which reads ASR labeled data. It accepts comma-separated
    JSON manifest files describing the correspondence between wav audio files
    and their transcripts. JSON files should be of the following format::

        {"audio_filepath": path_to_wav_0, "duration": time_in_sec_0, "text": \
transcript_0}
        ...
        {"audio_filepath": path_to_wav_n, "duration": time_in_sec_n, "text": \
transcript_n}

    Manifests are read on the remote rpyc host and staged as local
    temporary files; audio featurization is likewise delegated to the
    remote host via RpycWaveformFeaturizer.

    Args:
        manifest_filepath (str): Dataset parameter.
            Path to JSON containing data.
        labels (list): Dataset parameter.
            List of characters that can be output by the ASR model.
            For Jasper, this is the 28 character set {a-z '}. The CTC blank
            symbol is automatically added later for models using CTC.
        batch_size (int): batch size
        sample_rate (int): Target sampling rate for data. Audio files will be
            resampled to sample_rate if it is not already.
            Defaults to 16000.
        int_values (bool): Bool indicating whether the audio file is saved as
            int data or float data.
            Defaults to False.
        eos_id (id): Dataset parameter.
            End of string symbol id used for seq2seq models.
            Defaults to None.
        min_duration (float): Dataset parameter.
            All training files which have a duration less than min_duration
            are dropped. Note: Duration is read from the manifest JSON.
            Defaults to 0.1.
        max_duration (float): Dataset parameter.
            All training files which have a duration more than max_duration
            are dropped. Note: Duration is read from the manifest JSON.
            Defaults to None.
        normalize_transcripts (bool): Dataset parameter.
            Whether to use automatic text cleaning.
            It is highly recommended to manually clean text for best results.
            Defaults to True.
        trim_silence (bool): Whether to trim silence from the beginning and
            end of the audio signal using librosa.effects.trim().
            Defaults to False.
        load_audio (bool): Dataset parameter.
            Controls whether the dataloader loads the audio signal and
            transcript or just the transcript.
            Defaults to True.
        rpyc_host (str): Hostname of the remote rpyc server that serves
            manifests and computes audio features.
        drop_last (bool): See PyTorch DataLoader.
            Defaults to False.
        shuffle (bool): See PyTorch DataLoader.
            Defaults to True.
        num_workers (int): See PyTorch DataLoader.
            Defaults to 0.
        perturb_config (dict): Currently disabled.
    """
    @property
    @add_port_docs()
    def output_ports(self):
        """Returns definitions of module output ports."""
        return {
            # 'audio_signal': NeuralType({0: AxisType(BatchTag), 1: AxisType(TimeTag)}),
            # 'a_sig_length': NeuralType({0: AxisType(BatchTag)}),
            # 'transcripts': NeuralType({0: AxisType(BatchTag), 1: AxisType(TimeTag)}),
            # 'transcript_length': NeuralType({0: AxisType(BatchTag)}),
            "audio_signal": NeuralType(
                ("B", "T"),
                AudioSignal(freq=self._sample_rate)
                if self._sample_rate is not None
                else AudioSignal(),
            ),
            "a_sig_length": NeuralType(tuple("B"), LengthsType()),
            "transcripts": NeuralType(("B", "T"), LabelsType()),
            "transcript_length": NeuralType(tuple("B"), LengthsType()),
        }
    def __init__(
        self,
        manifest_filepath,
        labels,
        batch_size,
        sample_rate=16000,
        int_values=False,
        bos_id=None,
        eos_id=None,
        pad_id=None,
        min_duration=0.1,
        max_duration=None,
        normalize_transcripts=True,
        trim_silence=False,
        load_audio=True,
        rpyc_host="",
        drop_last=False,
        shuffle=True,
        num_workers=0,
    ):
        super().__init__()
        self._sample_rate = sample_rate

        def rpyc_root_fn():
            # Long sync_request_timeout: featurizing a whole dataset over
            # rpyc can block individual requests for minutes.
            return rpyc.connect(
                rpyc_host, 8064, config={"sync_request_timeout": 600}
            ).root

        rpyc_conn = rpyc_root_fn()

        self._featurizer = RpycWaveformFeaturizer(
            sample_rate=self._sample_rate,
            int_values=int_values,
            augmentor=None,
            rpyc_conn=rpyc_conn,
        )

        def read_remote_manifests():
            # Fetch each manifest from the rpyc host and stage it in a local
            # temporary file so CachedAudioDataset can read it from disk.
            local_mp = []
            for mrp in manifest_filepath.split(","):
                md = rpyc_conn.read_path(mrp)
                mf = tempfile.NamedTemporaryFile(
                    dir="/tmp", prefix="jasper_manifest.", delete=False
                )
                mf.write(md)
                mf.close()
                local_mp.append(mf.name)
            return ",".join(local_mp)

        local_manifest_filepath = read_remote_manifests()
        dataset_params = {
            "manifest_filepath": local_manifest_filepath,
            "labels": labels,
            "featurizer": self._featurizer,
            "max_duration": max_duration,
            "min_duration": min_duration,
            "normalize": normalize_transcripts,
            "trim": trim_silence,
            "bos_id": bos_id,
            "eos_id": eos_id,
            "load_audio": load_audio,
        }

        self._dataset = CachedAudioDataset(**dataset_params)
        self._batch_size = batch_size

        # Set up data loader
        if self._placement == DeviceType.AllGpu:
            logging.info("Parallelizing Datalayer.")
            sampler = torch.utils.data.distributed.DistributedSampler(self._dataset)
        else:
            sampler = None

        if batch_size == -1:
            batch_size = len(self._dataset)

        pad_id = 0 if pad_id is None else pad_id
        self._dataloader = torch.utils.data.DataLoader(
            dataset=self._dataset,
            batch_size=batch_size,
            collate_fn=partial(seq_collate_fn, token_pad_value=pad_id),
            drop_last=drop_last,
            shuffle=shuffle if sampler is None else False,
            sampler=sampler,
            num_workers=num_workers,
        )
    def __len__(self):
        return len(self._dataset)

    @property
    def dataset(self):
        # Return None so NeMo falls back to data_iterator for batching.
        return None

    @property
    def data_iterator(self):
        return self._dataloader
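

# Minimal usage sketch (assumptions: an rpyc feature server is listening on
# port 8064 of "feature-host", and the manifest path exists on that host;
# both names are hypothetical).
#
#     data_layer = RpycAudioToTextDataLayer(
#         manifest_filepath="/remote/data/train_manifest.json",
#         labels=list(" abcdefghijklmnopqrstuvwxyz'"),
#         batch_size=32,
#         rpyc_host="feature-host",
#     )
#     for audio, audio_len, transcript, transcript_len in data_layer.data_iterator:
#         ...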