1. integrated data generator using google tts
2. added training script
parent
f7ebd8e90a
commit
d22a99a4f6
|
|
@ -0,0 +1 @@
|
||||||
|
|
||||||
|
|
@ -0,0 +1,65 @@
|
||||||
|
# import io
|
||||||
|
# import sys
|
||||||
|
# import json
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from .utils import random_pnr_generator, manifest_str
|
||||||
|
from .tts.googletts import GoogleTTS
|
||||||
|
from tqdm import tqdm
|
||||||
|
import random
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||||
|
)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_asr_data(output_dir, count, sample_rate=24000):
    """Synthesize `count` random PNR utterances with Google TTS and write
    an ASR training manifest.

    Arguments:
        output_dir -- Path under which `pnr_data/` (wav files) and
            `pnr_data.json` (JSON-lines manifest) are created.
        count -- number of PNR codes to synthesize.
        sample_rate -- output sample rate in Hz (default 24000, the rate
            previously hard-coded in two places).
    """
    # 16-bit mono LINEAR16: 2 bytes per sample; Google returns a WAV
    # container whose header is 44 bytes.
    WAV_HEADER_BYTES = 44
    SAMPLE_WIDTH = 2

    google_voices = GoogleTTS.voice_list()
    gtts = GoogleTTS()
    wav_dir = output_dir / Path("pnr_data")
    wav_dir.mkdir(parents=True, exist_ok=True)
    asr_manifest = output_dir / Path("pnr_data").with_suffix(".json")
    with asr_manifest.open("w") as mf:
        for pnr_code in tqdm(random_pnr_generator(count)):
            # say-as "verbatim" makes the voice spell the code out.
            tts_code = (
                f'<speak><say-as interpret-as="verbatim">{pnr_code}</say-as></speak>'
            )
            # Copy the randomly chosen voice entry so we never mutate the
            # shared dicts inside `google_voices` (the original wrote
            # sample_rate/num_channels straight into them).
            param = dict(random.choice(google_voices))
            param["sample_rate"] = sample_rate
            param["num_channels"] = 1
            wav_data = gtts.text_to_speech(text=tts_code, params=param)
            # Duration = payload bytes / (bytes per sample * rate).
            audio_dur = len(wav_data[WAV_HEADER_BYTES:]) / (SAMPLE_WIDTH * sample_rate)
            pnr_af = wav_dir / Path(pnr_code).with_suffix(".wav")
            pnr_af.write_bytes(wav_data)
            rel_pnr_path = pnr_af.relative_to(output_dir)
            manifest = manifest_str(str(rel_pnr_path), audio_dur, pnr_code)
            mf.write(manifest)
|
||||||
|
|
||||||
|
|
||||||
|
def arg_parser():
    """Build the CLI parser for the ASR data generator.

    Returns:
        argparse.ArgumentParser with --output_dir (Path) and --count (int).
    """
    prog = Path(__file__).stem
    parser = argparse.ArgumentParser(
        # Fixed: was an f-string with no placeholders.
        prog=prog, description="generates asr training data"
    )
    parser.add_argument(
        "--output_dir",
        type=Path,
        default=Path("./train/asr_data"),
        help="directory to output asr data",
    )
    parser.add_argument(
        "--count", type=int, default=3, help="number of datapoints to generate"
    )
    return parser
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """CLI entry point: parse the arguments and generate the dataset."""
    cli_args = arg_parser().parse_args()
    generate_asr_data(**vars(cli_args))
|
||||||
|
|
||||||
|
|
||||||
|
# Allow running this module directly as a script.
if __name__ == "__main__":
    main()
|
||||||
|
|
@ -0,0 +1,30 @@
|
||||||
|
import concurrent.futures
|
||||||
|
import urllib.request
|
||||||
|
|
||||||
|
URLS = [
|
||||||
|
"http://www.foxnews.com/",
|
||||||
|
"http://www.cnn.com/",
|
||||||
|
"http://europe.wsj.com/",
|
||||||
|
"http://www.bbc.co.uk/",
|
||||||
|
"http://some-made-up-domain.com/",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# Retrieve a single page and report the URL and contents
|
||||||
|
def load_url(url, timeout):
    """Fetch *url* (blocking, up to *timeout* seconds) and return the body bytes."""
    response = urllib.request.urlopen(url, timeout=timeout)
    with response:
        return response.read()
|
||||||
|
|
||||||
|
|
||||||
|
# Fan out the URL fetches across a small thread pool (this is the
# ThreadPoolExecutor example from the concurrent.futures docs).
# We can use a with statement to ensure threads are cleaned up promptly
with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
    # Start the load operations and mark each future with its URL
    future_to_url = {executor.submit(load_url, url, 60): url for url in URLS}
    # as_completed yields futures in completion order, not submission order.
    for future in concurrent.futures.as_completed(future_to_url):
        url = future_to_url[future]
        try:
            data = future.result()
        except Exception as exc:
            # Expected for the made-up domain in URLS — report and continue.
            print("%r generated an exception: %s" % (url, exc))
        else:
            print("%r page is %d bytes" % (url, len(data)))
|
||||||
|
|
@ -0,0 +1,95 @@
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from sklearn.model_selection import train_test_split
|
||||||
|
from num2words import num2words
|
||||||
|
|
||||||
|
|
||||||
|
def separate_space_convert_digit_setpath(
    src_manifest="/home/malar/work/asr-data-utils/asr_data/pnr_data.json",
    dst_manifest="/dataset/asr_data/pnr_data/pnr_data.json",
    wav_prefix="/dataset/asr_data/pnr_data/wav/",
):
    """Rewrite the PNR manifest for training.

    For every record: space-separate the characters of "text", spell out
    digits as words (e.g. "3" -> "three"), lowercase the result, and
    rebase "audio_filepath" onto *wav_prefix*.

    The previously hard-coded, machine-specific paths are now parameters
    with the original values as defaults (backward compatible).
    """
    with Path(src_manifest).open("r") as pf:
        pnr_jsonl = pf.readlines()

    pnr_data = [json.loads(line) for line in pnr_jsonl]

    new_pnr_data = []
    for entry in pnr_data:
        # "AB1" -> "A B 1": each character becomes its own token.
        letters = " ".join(list(entry["text"]))
        # Replace each digit character with its spoken-word form.
        num_tokens = [num2words(c) if "0" <= c <= "9" else c for c in letters]
        entry["text"] = ("".join(num_tokens)).lower()
        entry["audio_filepath"] = entry["audio_filepath"].replace(
            "pnr_data/", wav_prefix
        )
        new_pnr_data.append(entry)

    new_pnr_jsonl = [json.dumps(entry) for entry in new_pnr_data]

    with Path(dst_manifest).open("w") as pf:
        # NOTE(review): no trailing newline is written — downstream
        # concatenation must account for that.
        pf.write("\n".join(new_pnr_jsonl))
|
||||||
|
|
||||||
|
|
||||||
|
# NOTE(review): runs at import time; consider guarding under __main__.
separate_space_convert_digit_setpath()
|
||||||
|
|
||||||
|
|
||||||
|
def split_data(
    manifest_path="/dataset/asr_data/pnr_data/pnr_data.json",
    train_path="/dataset/asr_data/pnr_data/train_manifest.json",
    test_path="/dataset/asr_data/pnr_data/test_manifest.json",
    test_size=0.1,
):
    """Split the JSON-lines manifest into train/test manifests.

    Paths and the split ratio were hard-coded; they are now parameters
    with the original values as defaults (backward compatible).
    """
    with Path(manifest_path).open("r") as pf:
        pnr_jsonl = pf.readlines()
    # Random 90/10 split by default (sklearn shuffles the lines).
    train_pnr, test_pnr = train_test_split(pnr_jsonl, test_size=test_size)
    with Path(train_path).open("w") as pf:
        pf.write("".join(train_pnr))
    with Path(test_path).open("w") as pf:
        pf.write("".join(test_pnr))
|
||||||
|
|
||||||
|
|
||||||
|
# NOTE(review): runs at import time; consider guarding under __main__.
split_data()
|
||||||
|
|
||||||
|
|
||||||
|
def augment_an4(
    an4_dir="/dataset/asr_data/an4",
    pnr_dir="/dataset/asr_data/pnr_data",
    out_dir="/dataset/asr_data/an4_pnr",
):
    """Concatenate the an4 and pnr train/test manifests into combined ones.

    Fix: the original blindly concatenated raw bytes, so when the first
    manifest lacked a trailing newline its last JSON record fused with the
    first record of the second manifest, producing an unparseable line
    (the failure mode validate_data() was written to detect). Each chunk
    is now guaranteed to end with a newline before concatenation.

    Directories were hard-coded; they are now parameters with the original
    values as defaults.
    """

    def _read_chunk(path):
        # Guarantee the chunk ends with a newline so records never fuse.
        data = Path(path).read_bytes()
        return data if data.endswith(b"\n") else data + b"\n"

    for split in ("train_manifest.json", "test_manifest.json"):
        combined = _read_chunk(Path(an4_dir) / split) + _read_chunk(
            Path(pnr_dir) / split
        )
        (Path(out_dir) / split).write_bytes(combined)
|
||||||
|
|
||||||
|
|
||||||
|
# NOTE(review): runs at import time; consider guarding under __main__.
augment_an4()
|
||||||
|
|
||||||
|
|
||||||
|
def validate_data(data_file):
    """Verify that every line of *data_file* parses as JSON.

    Prints each failing line's index with the parse error and returns the
    list of zero-based indices that failed (empty list == valid file).

    Fixes: the original caught BaseException (swallowing even
    KeyboardInterrupt), discarded the exception detail, and returned
    nothing, so callers could not act on the result.
    """
    with Path(data_file).open("r") as pf:
        pnr_jsonl = pf.readlines()
    bad_indices = []
    for i, line in enumerate(pnr_jsonl):
        try:
            json.loads(line)
        except json.JSONDecodeError as e:
            print(f"failed on {i}: {e}")
            bad_indices.append(i)
    return bad_indices
|
||||||
|
|
||||||
|
|
||||||
|
# Sanity-check both combined manifests; failures are printed per line.
validate_data("/dataset/asr_data/an4_pnr/test_manifest.json")
validate_data("/dataset/asr_data/an4_pnr/train_manifest.json")
|
||||||
|
|
||||||
|
|
||||||
|
# def convert_digits(data_file="/dataset/asr_data/an4_pnr/test_manifest.json"):
|
||||||
|
# with Path(data_file).open("r") as pf:
|
||||||
|
# pnr_jsonl = pf.readlines()
|
||||||
|
#
|
||||||
|
# pnr_data = [json.loads(i) for i in pnr_jsonl]
|
||||||
|
# new_pnr_data = []
|
||||||
|
# for i in pnr_data:
|
||||||
|
# num_tokens = [num2words(c) for c in i["text"] if "0" <= c <= "9"]
|
||||||
|
# i["text"] = "".join(num_tokens)
|
||||||
|
# new_pnr_data.append(i)
|
||||||
|
#
|
||||||
|
# new_pnr_jsonl = [json.dumps(i) for i in new_pnr_data]
|
||||||
|
#
|
||||||
|
# with Path(data_file).open("w") as pf:
|
||||||
|
# new_pnr_data = "\n".join(new_pnr_jsonl) # + "\n"
|
||||||
|
# pf.write(new_pnr_data)
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# convert_digits(data_file="/dataset/asr_data/an4_pnr/train_manifest.json")
|
||||||
|
|
@ -0,0 +1,52 @@
|
||||||
|
from logging import getLogger
|
||||||
|
from google.cloud import texttospeech
|
||||||
|
|
||||||
|
LOGGER = getLogger("googletts")
|
||||||
|
|
||||||
|
|
||||||
|
class GoogleTTS(object):
    """Thin wrapper around the Google Cloud Text-to-Speech client.

    NOTE(review): uses the ``texttospeech.types`` / ``texttospeech.enums``
    namespaces, which exist only in google-cloud-texttospeech 1.x (the 2.x
    client removed them) — confirm against the pinned dependency version.
    """

    def __init__(self):
        # One API client per instance; credentials are taken from the
        # environment (GOOGLE_APPLICATION_CREDENTIALS).
        self.client = texttospeech.TextToSpeechClient()

    def text_to_speech(self, text: str, params: dict) -> bytes:
        """Synthesize SSML *text* and return LINEAR16 (WAV) audio bytes.

        Arguments:
            text -- an SSML document (``<speak>...</speak>``).
            params -- dict with "language", "name" and "sample_rate" keys,
                as produced by :meth:`voice_list`.
        """
        tts_input = texttospeech.types.SynthesisInput(ssml=text)
        voice = texttospeech.types.VoiceSelectionParams(
            language_code=params["language"], name=params["name"]
        )
        audio_config = texttospeech.types.AudioConfig(
            audio_encoding=texttospeech.enums.AudioEncoding.LINEAR16,
            sample_rate_hertz=params["sample_rate"],
        )
        response = self.client.synthesize_speech(tts_input, voice, audio_config)
        audio_content = response.audio_content
        return audio_content

    @classmethod
    def voice_list(cls):
        """Lists the available voices."""

        # NOTE(review): constructs a throwaway authenticated client just
        # to issue the list request.
        client = cls().client

        # Performs the list voices request
        voices = client.list_voices()
        results = []
        for voice in voices.voices:
            # Keep only voices that support at least one English locale.
            supported_eng_langs = [
                lang for lang in voice.language_codes if lang[:2] == "en"
            ]
            if len(supported_eng_langs) > 0:
                lang = ",".join(supported_eng_langs)
            else:
                continue

            ssml_gender = texttospeech.enums.SsmlVoiceGender(voice.ssml_gender)
            results.append(
                {
                    "name": voice.name,
                    "language": lang,
                    "gender": ssml_gender.name,
                    # Voice names embed "Wavenet" for the neural engine.
                    "engine": "wavenet" if "Wav" in voice.name else "standard",
                    "sample_rate": voice.natural_sample_rate_hertz,
                }
            )
        return results
|
||||||
|
|
@ -0,0 +1,26 @@
|
||||||
|
"""
|
||||||
|
TTSClient Abstract Class
|
||||||
|
"""
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
|
||||||
|
class TTSClient(ABC):
    """
    Base class for TTS
    """

    @abstractmethod
    def text_to_speech(self, text: str, num_channels: int, sample_rate: int,
                       audio_encoding) -> bytes:
        """
        Convert text to synthesized audio bytes.

        Arguments:
            text -- text to convert
            num_channels -- channel count for the output audio
            sample_rate -- sample rate (Hz) for the output audio
            audio_encoding -- encoding of the output audio bytes

        Returns:
            bytes -- the synthesized audio
        """
|
||||||
|
|
@ -0,0 +1,47 @@
|
||||||
|
import numpy as np
|
||||||
|
import wave
|
||||||
|
import io
|
||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
|
def manifest_str(path, dur, text):
    """Return a single newline-terminated JSON-lines manifest entry."""
    entry = {
        "audio_filepath": path,
        "duration": round(dur, 1),
        "text": text,
    }
    return json.dumps(entry) + "\n"
|
||||||
|
|
||||||
|
|
||||||
|
def wav_bytes(audio_bytes, frame_rate=24000):
    """Wrap raw 16-bit mono PCM samples in a WAV container and return it."""
    buffer = io.BytesIO()
    writer = wave.open(buffer, mode="w")
    try:
        writer.setnchannels(1)  # mono
        writer.setframerate(frame_rate)
        writer.setsampwidth(2)  # 16-bit samples
        writer.writeframesraw(audio_bytes)
    finally:
        # close() finalizes the RIFF header sizes.
        writer.close()
    return buffer.getvalue()
|
||||||
|
|
||||||
|
|
||||||
|
def random_pnr_generator(count=10000):
    """Generate *count* random 6-character PNR codes.

    Each code consists of 3 random uppercase letters and 3 random digits.
    The 6 character positions are shuffled once per call, so every code in
    a batch shares the same letter/digit position layout.
    """
    LENGTH = 3

    # Candidate characters as single-byte numpy arrays.
    np_alphabet = np.array(list("ABCDEFGHIJKLMNOPQRSTUVWXYZ"), dtype="|S1")
    np_numeric = np.array(list("0123456789"), dtype="|S1")

    alpha_part = np.random.choice(np_alphabet, [count, LENGTH])
    digit_part = np.random.choice(np_numeric, [count, LENGTH])

    # Shape (6, count): shuffling rows permutes character positions.
    position_major = np.concatenate((alpha_part, digit_part), axis=1).T
    np.random.shuffle(position_major)

    return [b"".join(row).decode("utf-8") for row in position_major.T]
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Print a demo batch of random PNR codes, one per line."""
    codes = random_pnr_generator()
    for code in codes:
        print(code)
|
||||||
|
|
||||||
|
|
||||||
|
# Allow running this module directly as a script.
if __name__ == "__main__":
    main()
|
||||||
|
|
@ -0,0 +1,339 @@
|
||||||
|
# Copyright (c) 2019 NVIDIA Corporation
|
||||||
|
import argparse
|
||||||
|
import copy
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
from ruamel.yaml import YAML
|
||||||
|
|
||||||
|
import nemo
|
||||||
|
import nemo.collections.asr as nemo_asr
|
||||||
|
import nemo.utils.argparse as nm_argparse
|
||||||
|
from nemo.collections.asr.helpers import (
|
||||||
|
monitor_asr_train_progress,
|
||||||
|
process_evaluation_batch,
|
||||||
|
process_evaluation_epoch,
|
||||||
|
)
|
||||||
|
from nemo.utils.lr_policies import CosineAnnealing
|
||||||
|
|
||||||
|
logging = nemo.logging
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
    """Build and parse the CLI arguments for Jasper training.

    Inherits the common NeMo arguments from NemoArgParser, overrides the
    defaults for this speller task, and enforces that --max_steps and
    --num_epochs are not given together.
    """
    parser = argparse.ArgumentParser(
        parents=[nm_argparse.NemoArgParser()],
        description="Jasper",
        conflict_handler="resolve",
    )
    parser.set_defaults(
        checkpoint_dir=None,
        optimizer="novograd",
        batch_size=64,
        eval_batch_size=64,
        lr=0.002,
        amp_opt_level="O1",
        create_tb_writer=True,
        model_config="./train/jasper10x5dr.yaml",
        train_dataset="./train/asr_data/train_manifest.json",
        eval_datasets="./train/asr_data/test_manifest.json",
        work_dir="./train/work",
        num_epochs=50,
        weight_decay=0.005,
        checkpoint_save_freq=1000,
        eval_freq=100,
        load_dir="./train/models/jasper/",
        warmup_steps=3,
        exp_name="jasper-speller",
    )

    # Overwrite default args
    parser.add_argument(
        "--max_steps",
        type=int,
        default=None,
        required=False,
        help="max number of steps to train",
    )
    parser.add_argument(
        "--num_epochs",
        type=int,
        default=None,
        required=False,
        help="number of epochs to train",
    )
    parser.add_argument(
        "--model_config",
        type=str,
        required=False,
        help="model configuration file: model.yaml",
    )

    # Create new args
    parser.add_argument("--exp_name", default="Jasper", type=str)
    parser.add_argument("--beta1", default=0.95, type=float)
    parser.add_argument("--beta2", default=0.25, type=float)
    parser.add_argument("--warmup_steps", default=0, type=int)
    parser.add_argument(
        "--load_dir",
        default=None,
        type=str,
        help="directory with pre-trained checkpoint",
    )

    args = parser.parse_args()

    if args.max_steps is not None and args.num_epochs is not None:
        # Fixed: the condition fires when BOTH bounds are given, but the
        # old message ("Either ... should be provided") described the
        # opposite situation.
        raise ValueError(
            "max_steps and num_epochs are mutually exclusive; provide only one."
        )
    return args
|
||||||
|
|
||||||
|
|
||||||
|
def construct_name(
    name, lr, batch_size, max_steps, num_epochs, wd, optimizer, iter_per_step
):
    """Encode the run's hyperparameters into a single run name.

    The training budget is written as "s_<max_steps>" when max_steps is
    set, otherwise as "e_<num_epochs>".
    """
    if max_steps is not None:
        budget = f"s_{max_steps}"
    else:
        budget = f"e_{num_epochs}"
    return (
        f"{name}-lr_{lr}-bs_{batch_size}-{budget}"
        f"-wd_{wd}-opt_{optimizer}-ips_{iter_per_step}"
    )
|
||||||
|
|
||||||
|
|
||||||
|
def create_all_dags(args, neural_factory):
    """Build the Jasper training DAG and one eval DAG per eval dataset.

    Reads the model YAML config, constructs the data layers, encoder,
    decoder, CTC loss and callbacks, and returns
    (train_loss_tensor, callbacks, steps_per_epoch).
    """
    yaml = YAML(typ="safe")
    with open(args.model_config) as f:
        jasper_params = yaml.load(f)
    vocab = jasper_params["labels"]
    sample_rate = jasper_params["sample_rate"]

    # Calculate num_workers for dataloader
    total_cpus = os.cpu_count()
    cpu_per_traindl = max(int(total_cpus / neural_factory.world_size), 1)

    # perturb_config = jasper_params.get('perturb', None)
    # Merge the "train" overrides into the shared data-layer params, then
    # strip the split-specific sub-dicts before kwarg expansion.
    train_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"])
    train_dl_params.update(jasper_params["AudioToTextDataLayer"]["train"])
    del train_dl_params["train"]
    del train_dl_params["eval"]
    # del train_dl_params["normalize_transcripts"]

    data_layer = nemo_asr.AudioToTextDataLayer(
        manifest_filepath=args.train_dataset,
        sample_rate=sample_rate,
        labels=vocab,
        batch_size=args.batch_size,
        num_workers=cpu_per_traindl,
        **train_dl_params,
        # normalize_transcripts=False
    )

    # Effective steps per epoch given gradient accumulation and GPU count.
    N = len(data_layer)
    steps_per_epoch = math.ceil(
        N / (args.batch_size * args.iter_per_step * args.num_gpus)
    )
    logging.info("Have {0} examples to train on.".format(N))

    data_preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor(
        sample_rate=sample_rate, **jasper_params["AudioToMelSpectrogramPreprocessor"]
    )

    # Optional augmentation modules, enabled only if present in the YAML.
    multiply_batch_config = jasper_params.get("MultiplyBatch", None)
    if multiply_batch_config:
        multiply_batch = nemo_asr.MultiplyBatch(**multiply_batch_config)

    spectr_augment_config = jasper_params.get("SpectrogramAugmentation", None)
    if spectr_augment_config:
        data_spectr_augmentation = nemo_asr.SpectrogramAugmentation(
            **spectr_augment_config
        )

    eval_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"])
    eval_dl_params.update(jasper_params["AudioToTextDataLayer"]["eval"])
    del eval_dl_params["train"]
    del eval_dl_params["eval"]
    data_layers_eval = []

    if args.eval_datasets:
        # NOTE(review): parse_args sets eval_datasets to a plain string by
        # default; iterating a string yields single characters, not paths.
        # This only works if NemoArgParser declares eval_datasets as a
        # list — confirm.
        for eval_datasets in args.eval_datasets:
            data_layer_eval = nemo_asr.AudioToTextDataLayer(
                manifest_filepath=eval_datasets,
                sample_rate=sample_rate,
                labels=vocab,
                batch_size=args.eval_batch_size,
                num_workers=cpu_per_traindl,
                **eval_dl_params,
            )

            data_layers_eval.append(data_layer_eval)
    else:
        logging.warning("There were no val datasets passed")

    jasper_encoder = nemo_asr.JasperEncoder(
        feat_in=jasper_params["AudioToMelSpectrogramPreprocessor"]["features"],
        **jasper_params["JasperEncoder"],
    )

    # Decoder input width = filter count of the encoder's last block.
    jasper_decoder = nemo_asr.JasperDecoderForCTC(
        feat_in=jasper_params["JasperEncoder"]["jasper"][-1]["filters"],
        num_classes=len(vocab),
    )

    ctc_loss = nemo_asr.CTCLossNM(num_classes=len(vocab))

    greedy_decoder = nemo_asr.GreedyCTCDecoder()

    logging.info("================================")
    logging.info(f"Number of parameters in encoder: {jasper_encoder.num_weights}")
    logging.info(f"Number of parameters in decoder: {jasper_decoder.num_weights}")
    logging.info(
        f"Total number of parameters in model: "
        f"{jasper_decoder.num_weights + jasper_encoder.num_weights}"
    )
    logging.info("================================")

    # Train DAG
    (audio_signal_t, a_sig_length_t, transcript_t, transcript_len_t) = data_layer()
    processed_signal_t, p_length_t = data_preprocessor(
        input_signal=audio_signal_t, length=a_sig_length_t
    )

    if multiply_batch_config:
        (
            processed_signal_t,
            p_length_t,
            transcript_t,
            transcript_len_t,
        ) = multiply_batch(
            in_x=processed_signal_t,
            in_x_len=p_length_t,
            in_y=transcript_t,
            in_y_len=transcript_len_t,
        )

    if spectr_augment_config:
        processed_signal_t = data_spectr_augmentation(input_spec=processed_signal_t)

    encoded_t, encoded_len_t = jasper_encoder(
        audio_signal=processed_signal_t, length=p_length_t
    )
    log_probs_t = jasper_decoder(encoder_output=encoded_t)
    predictions_t = greedy_decoder(log_probs=log_probs_t)
    loss_t = ctc_loss(
        log_probs=log_probs_t,
        targets=transcript_t,
        input_length=encoded_len_t,
        target_length=transcript_len_t,
    )

    # Callbacks needed to print info to console and Tensorboard
    train_callback = nemo.core.SimpleLossLoggerCallback(
        tensors=[loss_t, predictions_t, transcript_t, transcript_len_t],
        print_func=partial(monitor_asr_train_progress, labels=vocab),
        get_tb_values=lambda x: [("loss", x[0])],
        tb_writer=neural_factory.tb_writer,
    )

    chpt_callback = nemo.core.CheckpointCallback(
        folder=neural_factory.checkpoint_dir,
        load_from_folder=args.load_dir,
        step_freq=args.checkpoint_save_freq,
    )

    callbacks = [train_callback, chpt_callback]

    # assemble eval DAGs
    for i, eval_dl in enumerate(data_layers_eval):
        (audio_signal_e, a_sig_length_e, transcript_e, transcript_len_e) = eval_dl()
        processed_signal_e, p_length_e = data_preprocessor(
            input_signal=audio_signal_e, length=a_sig_length_e
        )
        # Eval path reuses the same encoder/decoder modules (shared weights),
        # without the augmentation steps.
        encoded_e, encoded_len_e = jasper_encoder(
            audio_signal=processed_signal_e, length=p_length_e
        )
        log_probs_e = jasper_decoder(encoder_output=encoded_e)
        predictions_e = greedy_decoder(log_probs=log_probs_e)
        loss_e = ctc_loss(
            log_probs=log_probs_e,
            targets=transcript_e,
            input_length=encoded_len_e,
            target_length=transcript_len_e,
        )

        # create corresponding eval callback
        tagname = os.path.basename(args.eval_datasets[i]).split(".")[0]
        eval_callback = nemo.core.EvaluatorCallback(
            eval_tensors=[loss_e, predictions_e, transcript_e, transcript_len_e],
            user_iter_callback=partial(process_evaluation_batch, labels=vocab),
            user_epochs_done_callback=partial(process_evaluation_epoch, tag=tagname),
            eval_step=args.eval_freq,
            tb_writer=neural_factory.tb_writer,
        )

        callbacks.append(eval_callback)
    return loss_t, callbacks, steps_per_epoch
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Entry point: parse args, set up the NeMo factory, build DAGs, train."""
    args = parse_args()
    name = construct_name(
        args.exp_name,
        args.lr,
        args.batch_size,
        args.max_steps,
        args.num_epochs,
        args.weight_decay,
        args.optimizer,
        args.iter_per_step,
    )
    log_dir = name
    if args.work_dir:
        log_dir = os.path.join(args.work_dir, name)

    # instantiate Neural Factory with supported backend
    neural_factory = nemo.core.NeuralModuleFactory(
        backend=nemo.core.Backend.PyTorch,
        local_rank=args.local_rank,
        optimization_level=args.amp_opt_level,
        log_dir=log_dir,
        checkpoint_dir=args.checkpoint_dir,
        create_tb_writer=args.create_tb_writer,
        # Snapshot the config and this script alongside the run logs.
        files_to_copy=[args.model_config, __file__],
        cudnn_benchmark=args.cudnn_benchmark,
        tensorboard_dir=args.tensorboard_dir,
    )
    # world_size is known only after the factory initializes distributed.
    args.num_gpus = neural_factory.world_size

    checkpoint_dir = neural_factory.checkpoint_dir  # NOTE(review): unused local
    if args.local_rank is not None:
        logging.info("Doing ALL GPU")

    # build dags
    train_loss, callbacks, steps_per_epoch = create_all_dags(args, neural_factory)

    # train model
    neural_factory.train(
        tensors_to_optimize=[train_loss],
        callbacks=callbacks,
        # Total schedule length: explicit step budget, or epochs * steps.
        lr_policy=CosineAnnealing(
            args.max_steps
            if args.max_steps is not None
            else args.num_epochs * steps_per_epoch,
            warmup_steps=args.warmup_steps,
        ),
        optimizer=args.optimizer,
        optimization_params={
            "num_epochs": args.num_epochs,
            "max_steps": args.max_steps,
            "lr": args.lr,
            "betas": (args.beta1, args.beta2),
            "weight_decay": args.weight_decay,
            "grad_norm_clip": None,
        },
        batches_per_step=args.iter_per_step,
    )
|
||||||
|
|
||||||
|
|
||||||
|
# Allow running this module directly as a script.
if __name__ == "__main__":
    main()
|
||||||
14
setup.py
14
setup.py
|
|
@ -5,7 +5,17 @@ requirements = [
|
||||||
"nemo_toolkit[asr] @ git+https://github.com/NVIDIA/NeMo.git@09e3ba4dfe333f86d6c5c1048e07210924294be9#egg=nemo_toolkit",
|
"nemo_toolkit[asr] @ git+https://github.com/NVIDIA/NeMo.git@09e3ba4dfe333f86d6c5c1048e07210924294be9#egg=nemo_toolkit",
|
||||||
]
|
]
|
||||||
|
|
||||||
extra_requirements = {"server": ["rpyc==4.1.4"]}
|
extra_requirements = {
|
||||||
|
"server": ["rpyc~=4.1.4"],
|
||||||
|
"data": [
|
||||||
|
"google-cloud-texttospeech~=1.0.1",
|
||||||
|
"tqdm~=4.39.0",
|
||||||
|
"pydub~=0.23.1",
|
||||||
|
"scikit_learn~=0.22.1",
|
||||||
|
"pandas~=1.0.3",
|
||||||
|
"boto3~=1.12.35",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
name="jasper-asr",
|
name="jasper-asr",
|
||||||
|
|
@ -22,6 +32,8 @@ setup(
|
||||||
"console_scripts": [
|
"console_scripts": [
|
||||||
"jasper_transcribe = jasper.transcribe:main",
|
"jasper_transcribe = jasper.transcribe:main",
|
||||||
"jasper_asr_rpyc_server = jasper.server:main",
|
"jasper_asr_rpyc_server = jasper.server:main",
|
||||||
|
"jasper_asr_trainer = jasper.train:main",
|
||||||
|
"jasper_asr_data_generate = jasper.data_utils.generator:main",
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
zip_safe=False,
|
zip_safe=False,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue