diff --git a/jasper/data_utils/__init__.py b/jasper/data_utils/__init__.py
new file mode 100644
index 0000000..8b13789
--- /dev/null
+++ b/jasper/data_utils/__init__.py
@@ -0,0 +1 @@
+
diff --git a/jasper/data_utils/generator.py b/jasper/data_utils/generator.py
new file mode 100644
index 0000000..c49d460
--- /dev/null
+++ b/jasper/data_utils/generator.py
@@ -0,0 +1,65 @@
+# import io
+# import sys
+# import json
+import argparse
+import logging
+from pathlib import Path
+from .utils import random_pnr_generator, manifest_str
+from .tts.googletts import GoogleTTS
+from tqdm import tqdm
+import random
+
+logging.basicConfig(
+    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+)
+logger = logging.getLogger(__name__)
+
+
+def generate_asr_data(output_dir, count):
+    google_voices = GoogleTTS.voice_list()
+    gtts = GoogleTTS()
+    wav_dir = output_dir / Path("pnr_data")
+    wav_dir.mkdir(parents=True, exist_ok=True)
+    asr_manifest = output_dir / Path("pnr_data").with_suffix(".json")
+    with asr_manifest.open("w") as mf:
+        for pnr_code in tqdm(random_pnr_generator(count)):
+            # assumed SSML markup: <say-as> makes the voice spell the code out
+            tts_code = (
+                f'<speak><say-as interpret-as="characters">{pnr_code}</say-as></speak>'
+            )
+            param = random.choice(google_voices)
+            param["sample_rate"] = 24000
+            param["num_channels"] = 1
+            wav_data = gtts.text_to_speech(text=tts_code, params=param)
+            # skip the 44-byte WAV header; 16-bit mono samples at 24 kHz
+            audio_dur = len(wav_data[44:]) / (2 * 24000)
+            pnr_af = wav_dir / Path(pnr_code).with_suffix(".wav")
+            pnr_af.write_bytes(wav_data)
+            rel_pnr_path = pnr_af.relative_to(output_dir)
+            manifest = manifest_str(str(rel_pnr_path), audio_dur, pnr_code)
+            mf.write(manifest)
+
+
+def arg_parser():
+    prog = Path(__file__).stem
+    parser = argparse.ArgumentParser(
+        prog=prog, description="generates ASR training data"
+    )
+    parser.add_argument(
+        "--output_dir",
+        type=Path,
+        default=Path("./train/asr_data"),
+        help="directory to output asr data",
+    )
+    parser.add_argument(
+        "--count", type=int, default=3, help="number of datapoints to generate"
+    )
+    return parser
+
+
+def main():
+    parser = arg_parser()
+    args = parser.parse_args()
+    generate_asr_data(**vars(args))
+
+
+if __name__ == "__main__":
+    main()
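+
+# Usage sketch (paths are the argparse defaults above; the console-script name
+# comes from the setup.py entry points below):
+#   jasper_asr_data_generate --output_dir ./train/asr_data --count 1000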
diff --git a/jasper/data_utils/parallel.py b/jasper/data_utils/parallel.py
new file mode 100644
index 0000000..99141b6
--- /dev/null
+++ b/jasper/data_utils/parallel.py
@@ -0,0 +1,30 @@
+# Adapted verbatim from the ThreadPoolExecutor example in the Python docs.
+import concurrent.futures
+import urllib.request
+
+URLS = [
+    "http://www.foxnews.com/",
+    "http://www.cnn.com/",
+    "http://europe.wsj.com/",
+    "http://www.bbc.co.uk/",
+    "http://some-made-up-domain.com/",
+]
+
+
+# Retrieve a single page and report the URL and contents
+def load_url(url, timeout):
+    with urllib.request.urlopen(url, timeout=timeout) as conn:
+        return conn.read()
+
+
+# We can use a with statement to ensure threads are cleaned up promptly
+with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
+    # Start the load operations and mark each future with its URL
+    future_to_url = {executor.submit(load_url, url, 60): url for url in URLS}
+    for future in concurrent.futures.as_completed(future_to_url):
+        url = future_to_url[future]
+        try:
+            data = future.result()
+        except Exception as exc:
+            print("%r generated an exception: %s" % (url, exc))
+        else:
+            print("%r page is %d bytes" % (url, len(data)))
diff --git a/jasper/data_utils/process.py b/jasper/data_utils/process.py
new file mode 100644
index 0000000..44e4237
--- /dev/null
+++ b/jasper/data_utils/process.py
@@ -0,0 +1,95 @@
+# One-off manifest preprocessing; the steps below run at import time, in order.
+import json
+from pathlib import Path
+from sklearn.model_selection import train_test_split
+from num2words import num2words
+
+
+def separate_space_convert_digit_setpath():
+    # space-separate characters, spell out digits, and point paths at the dataset dir
+    with Path("/home/malar/work/asr-data-utils/asr_data/pnr_data.json").open("r") as pf:
+        pnr_jsonl = pf.readlines()
+
+    pnr_data = [json.loads(i) for i in pnr_jsonl]
+
+    new_pnr_data = []
+    for i in pnr_data:
+        letters = " ".join(list(i["text"]))
+        num_tokens = [num2words(c) if "0" <= c <= "9" else c for c in letters]
+        i["text"] = ("".join(num_tokens)).lower()
+        i["audio_filepath"] = i["audio_filepath"].replace(
+            "pnr_data/", "/dataset/asr_data/pnr_data/wav/"
+        )
+        new_pnr_data.append(i)
+
+    new_pnr_jsonl = [json.dumps(i) for i in new_pnr_data]
+
+    with Path("/dataset/asr_data/pnr_data/pnr_data.json").open("w") as pf:
+        # trailing newline keeps concatenated manifests line-aligned downstream
+        new_pnr_data = "\n".join(new_pnr_jsonl) + "\n"
+        pf.write(new_pnr_data)
+
+
+separate_space_convert_digit_setpath()
+
+
+def split_data():
+    with Path("/dataset/asr_data/pnr_data/pnr_data.json").open("r") as pf:
+        pnr_jsonl = pf.readlines()
+    train_pnr, test_pnr = train_test_split(pnr_jsonl, test_size=0.1)
+    with Path("/dataset/asr_data/pnr_data/train_manifest.json").open("w") as pf:
+        pnr_data = "".join(train_pnr)
+        pf.write(pnr_data)
+    with Path("/dataset/asr_data/pnr_data/test_manifest.json").open("w") as pf:
+        pnr_data = "".join(test_pnr)
+        pf.write(pnr_data)
+
+
+split_data()
+
+
+def augment_an4():
+    an4_train = Path("/dataset/asr_data/an4/train_manifest.json").read_bytes()
+    an4_test = Path("/dataset/asr_data/an4/test_manifest.json").read_bytes()
+    pnr_train = Path("/dataset/asr_data/pnr_data/train_manifest.json").read_bytes()
+    pnr_test = Path("/dataset/asr_data/pnr_data/test_manifest.json").read_bytes()
+
+    with Path("/dataset/asr_data/an4_pnr/train_manifest.json").open("wb") as pf:
+        pf.write(an4_train + pnr_train)
+    with Path("/dataset/asr_data/an4_pnr/test_manifest.json").open("wb") as pf:
+        pf.write(an4_test + pnr_test)
+
+
+augment_an4()
+
+
+def validate_data(data_file):
+    with Path(data_file).open("r") as pf:
+        pnr_jsonl = pf.readlines()
+    for (i, s) in enumerate(pnr_jsonl):
+        try:
+            json.loads(s)
+        except json.JSONDecodeError as e:
+            print(f"failed on line {i}: {e}")
+
+
+validate_data("/dataset/asr_data/an4_pnr/test_manifest.json")
+validate_data("/dataset/asr_data/an4_pnr/train_manifest.json")
+
+
+# def convert_digits(data_file="/dataset/asr_data/an4_pnr/test_manifest.json"):
+#     with Path(data_file).open("r") as pf:
+#         pnr_jsonl = pf.readlines()
+#
+#     pnr_data = [json.loads(i) for i in pnr_jsonl]
+#     new_pnr_data = []
+#     for i in pnr_data:
+#         num_tokens = [num2words(c) for c in i["text"] if "0" <= c <= "9"]
+#         i["text"] = "".join(num_tokens)
+#         new_pnr_data.append(i)
+#
+#     new_pnr_jsonl = [json.dumps(i) for i in new_pnr_data]
+#
+#     with Path(data_file).open("w") as pf:
+#         new_pnr_data = "\n".join(new_pnr_jsonl)  # + "\n"
+#         pf.write(new_pnr_data)
+#
+#
+# convert_digits(data_file="/dataset/asr_data/an4_pnr/train_manifest.json")
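+
+# Each manifest line is JSONL as produced by utils.manifest_str; after the
+# processing above a line looks like this (duration value illustrative):
+#   {"audio_filepath": "/dataset/asr_data/pnr_data/wav/ABC123.wav",
+#    "duration": 2.0, "text": "a b c one two three"}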
diff --git a/jasper/data_utils/tts/__init__.py b/jasper/data_utils/tts/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/jasper/data_utils/tts/googletts.py b/jasper/data_utils/tts/googletts.py
new file mode 100644
index 0000000..690eaac
--- /dev/null
+++ b/jasper/data_utils/tts/googletts.py
@@ -0,0 +1,52 @@
+from logging import getLogger
+from google.cloud import texttospeech
+
+LOGGER = getLogger("googletts")
+
+
+class GoogleTTS(object):
+    def __init__(self):
+        self.client = texttospeech.TextToSpeechClient()
+
+    def text_to_speech(self, text: str, params: dict) -> bytes:
+        tts_input = texttospeech.types.SynthesisInput(ssml=text)
+        voice = texttospeech.types.VoiceSelectionParams(
+            language_code=params["language"], name=params["name"]
+        )
+        audio_config = texttospeech.types.AudioConfig(
+            audio_encoding=texttospeech.enums.AudioEncoding.LINEAR16,
+            sample_rate_hertz=params["sample_rate"],
+        )
+        response = self.client.synthesize_speech(tts_input, voice, audio_config)
+        audio_content = response.audio_content
+        return audio_content
+
+    @classmethod
+    def voice_list(cls):
+        """Lists the available voices."""
+
+        client = cls().client
+
+        # Performs the list voices request
+        voices = client.list_voices()
+        results = []
+        for voice in voices.voices:
+            supported_eng_langs = [
+                lang for lang in voice.language_codes if lang[:2] == "en"
+            ]
+            if len(supported_eng_langs) > 0:
+                lang = ",".join(supported_eng_langs)
+            else:
+                continue
+
+            ssml_gender = texttospeech.enums.SsmlVoiceGender(voice.ssml_gender)
+            results.append(
+                {
+                    "name": voice.name,
+                    "language": lang,
+                    "gender": ssml_gender.name,
+                    "engine": "wavenet" if "Wav" in voice.name else "standard",
+                    "sample_rate": voice.natural_sample_rate_hertz,
+                }
+            )
+        return results
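+
+# Usage sketch (assumes GOOGLE_APPLICATION_CREDENTIALS points at a service
+# account with the Cloud Text-to-Speech API enabled):
+#   voice = random.choice(GoogleTTS.voice_list())
+#   voice["sample_rate"] = 24000
+#   wav = GoogleTTS().text_to_speech(text="<speak>A B 1</speak>", params=voice)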
diff --git a/jasper/data_utils/tts/ttsclient.py b/jasper/data_utils/tts/ttsclient.py
new file mode 100644
index 0000000..d61a5a4
--- /dev/null
+++ b/jasper/data_utils/tts/ttsclient.py
@@ -0,0 +1,26 @@
+"""
+TTSClient Abstract Class
+"""
+from abc import ABC, abstractmethod
+
+
+class TTSClient(ABC):
+    """
+    Base class for TTS backends
+    """
+
+    @abstractmethod
+    def text_to_speech(self, text: str, num_channels: int, sample_rate: int,
+                       audio_encoding) -> bytes:
+        """
+        convert text to speech audio bytes
+
+        Arguments:
+            text {str} -- text to convert
+            num_channels {int} -- channel count of the output audio
+            sample_rate {int} -- sample rate of the output audio
+            audio_encoding -- encoding of the output audio
+
+        Returns:
+            bytes -- synthesized audio
+        """
diff --git a/jasper/data_utils/utils.py b/jasper/data_utils/utils.py
new file mode 100644
index 0000000..bee1e97
--- /dev/null
+++ b/jasper/data_utils/utils.py
@@ -0,0 +1,47 @@
+import numpy as np
+import wave
+import io
+import json
+
+
+def manifest_str(path, dur, text):
+    # one JSONL line of a NeMo-style ASR manifest
+    return (
+        json.dumps({"audio_filepath": path, "duration": round(dur, 1), "text": text})
+        + "\n"
+    )
+
+
+def wav_bytes(audio_bytes, frame_rate=24000):
+    # wrap raw 16-bit mono PCM in a WAV container
+    wf_b = io.BytesIO()
+    with wave.open(wf_b, mode="w") as wf:
+        wf.setnchannels(1)
+        wf.setframerate(frame_rate)
+        wf.setsampwidth(2)
+        wf.writeframesraw(audio_bytes)
+    return wf_b.getvalue()
+
+
+def random_pnr_generator(count=10000):
+    LENGTH = 3
+
+    # alphabet = list("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")
+    alphabet = list("ABCDEFGHIJKLMNOPQRSTUVWXYZ")
+    numeric = list("0123456789")
+    np_alphabet = np.array(alphabet, dtype="|S1")
+    np_numeric = np.array(numeric, dtype="|S1")
+    np_alpha_codes = np.random.choice(np_alphabet, [count, LENGTH])
+    np_num_codes = np.random.choice(np_numeric, [count, LENGTH])
+    np_code_seed = np.concatenate((np_alpha_codes, np_num_codes), axis=1).T
+    # shuffles the six character positions; note that the same permutation is
+    # applied to every code in the batch
+    np.random.shuffle(np_code_seed)
+    np_codes = np_code_seed.T
+    codes = [(b"".join(np_codes[i])).decode("utf-8") for i in range(len(np_codes))]
+    return codes
+
+
+def main():
+    for c in random_pnr_generator():
+        print(c)
+
+
+if __name__ == "__main__":
+    main()
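+
+# Example manifest line (the duration argument is illustrative):
+#   manifest_str("pnr_data/ABC123.wav", 2.03, "ABC123")
+#   == '{"audio_filepath": "pnr_data/ABC123.wav", "duration": 2.0, "text": "ABC123"}\n'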
diff --git a/jasper/train.py b/jasper/train.py
new file mode 100644
index 0000000..9861aff
--- /dev/null
+++ b/jasper/train.py
@@ -0,0 +1,339 @@
+# Copyright (c) 2019 NVIDIA Corporation
+import argparse
+import copy
+import math
+import os
+from functools import partial
+
+from ruamel.yaml import YAML
+
+import nemo
+import nemo.collections.asr as nemo_asr
+import nemo.utils.argparse as nm_argparse
+from nemo.collections.asr.helpers import (
+    monitor_asr_train_progress,
+    process_evaluation_batch,
+    process_evaluation_epoch,
+)
+from nemo.utils.lr_policies import CosineAnnealing
+
+logging = nemo.logging
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        parents=[nm_argparse.NemoArgParser()],
+        description="Jasper",
+        conflict_handler="resolve",
+    )
+
+    # Overwrite default args
+    parser.add_argument(
+        "--max_steps",
+        type=int,
+        default=None,
+        required=False,
+        help="max number of steps to train",
+    )
+    parser.add_argument(
+        "--num_epochs",
+        type=int,
+        default=None,
+        required=False,
+        help="number of epochs to train",
+    )
+    parser.add_argument(
+        "--model_config",
+        type=str,
+        required=False,
+        help="model configuration file: model.yaml",
+    )
+
+    # Create new args
+    parser.add_argument("--exp_name", default="Jasper", type=str)
+    parser.add_argument("--beta1", default=0.95, type=float)
+    parser.add_argument("--beta2", default=0.25, type=float)
+    parser.add_argument("--warmup_steps", default=0, type=int)
+    parser.add_argument(
+        "--load_dir",
+        default=None,
+        type=str,
+        help="directory with pre-trained checkpoint",
+    )
+
+    # set_defaults must come after the add_argument calls above, otherwise the
+    # argument-level defaults silently override these project defaults
+    parser.set_defaults(
+        checkpoint_dir=None,
+        optimizer="novograd",
+        batch_size=64,
+        eval_batch_size=64,
+        lr=0.002,
+        amp_opt_level="O1",
+        create_tb_writer=True,
+        model_config="./train/jasper10x5dr.yaml",
+        train_dataset="./train/asr_data/train_manifest.json",
+        # a list: NemoArgParser's --eval_datasets accepts several manifests
+        eval_datasets=["./train/asr_data/test_manifest.json"],
+        work_dir="./train/work",
+        num_epochs=50,
+        weight_decay=0.005,
+        checkpoint_save_freq=1000,
+        eval_freq=100,
+        load_dir="./train/models/jasper/",
+        warmup_steps=3,
+        exp_name="jasper-speller",
+    )
+
+    args = parser.parse_args()
+
+    if args.max_steps is not None and args.num_epochs is not None:
+        # num_epochs always has a default here, so an explicit max_steps wins
+        logging.warning("Both max_steps and num_epochs are set; using max_steps.")
+        args.num_epochs = None
+    return args
+
+
+def construct_name(
+    name, lr, batch_size, max_steps, num_epochs, wd, optimizer, iter_per_step
+):
+    if max_steps is not None:
+        return "{0}-lr_{1}-bs_{2}-s_{3}-wd_{4}-opt_{5}-ips_{6}".format(
+            name, lr, batch_size, max_steps, wd, optimizer, iter_per_step
+        )
+    else:
+        return "{0}-lr_{1}-bs_{2}-e_{3}-wd_{4}-opt_{5}-ips_{6}".format(
+            name, lr, batch_size, num_epochs, wd, optimizer, iter_per_step
+        )
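+
+# e.g. with the defaults above (iter_per_step=1 comes from NemoArgParser):
+#   construct_name("jasper-speller", 0.002, 64, None, 50, 0.005, "novograd", 1)
+#   -> "jasper-speller-lr_0.002-bs_64-e_50-wd_0.005-opt_novograd-ips_1"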
+
+
+def create_all_dags(args, neural_factory):
+    yaml = YAML(typ="safe")
+    with open(args.model_config) as f:
+        jasper_params = yaml.load(f)
+    vocab = jasper_params["labels"]
+    sample_rate = jasper_params["sample_rate"]
+
+    # Calculate num_workers for dataloader
+    total_cpus = os.cpu_count()
+    cpu_per_traindl = max(int(total_cpus / neural_factory.world_size), 1)
+
+    # perturb_config = jasper_params.get('perturb', None)
+    train_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"])
+    train_dl_params.update(jasper_params["AudioToTextDataLayer"]["train"])
+    del train_dl_params["train"]
+    del train_dl_params["eval"]
+    # del train_dl_params["normalize_transcripts"]
+
+    data_layer = nemo_asr.AudioToTextDataLayer(
+        manifest_filepath=args.train_dataset,
+        sample_rate=sample_rate,
+        labels=vocab,
+        batch_size=args.batch_size,
+        num_workers=cpu_per_traindl,
+        **train_dl_params,
+        # normalize_transcripts=False
+    )
+
+    N = len(data_layer)
+    steps_per_epoch = math.ceil(
+        N / (args.batch_size * args.iter_per_step * args.num_gpus)
+    )
+    logging.info("Have {0} examples to train on.".format(N))
+
+    data_preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor(
+        sample_rate=sample_rate, **jasper_params["AudioToMelSpectrogramPreprocessor"]
+    )
+
+    multiply_batch_config = jasper_params.get("MultiplyBatch", None)
+    if multiply_batch_config:
+        multiply_batch = nemo_asr.MultiplyBatch(**multiply_batch_config)
+
+    spectr_augment_config = jasper_params.get("SpectrogramAugmentation", None)
+    if spectr_augment_config:
+        data_spectr_augmentation = nemo_asr.SpectrogramAugmentation(
+            **spectr_augment_config
+        )
+
+    eval_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"])
+    eval_dl_params.update(jasper_params["AudioToTextDataLayer"]["eval"])
+    del eval_dl_params["train"]
+    del eval_dl_params["eval"]
+    data_layers_eval = []
+
+    if args.eval_datasets:
+        for eval_dataset in args.eval_datasets:
+            data_layer_eval = nemo_asr.AudioToTextDataLayer(
+                manifest_filepath=eval_dataset,
+                sample_rate=sample_rate,
+                labels=vocab,
+                batch_size=args.eval_batch_size,
+                num_workers=cpu_per_traindl,
+                **eval_dl_params,
+            )
+
+            data_layers_eval.append(data_layer_eval)
+    else:
+        logging.warning("There were no val datasets passed")
+
+    jasper_encoder = nemo_asr.JasperEncoder(
+        feat_in=jasper_params["AudioToMelSpectrogramPreprocessor"]["features"],
+        **jasper_params["JasperEncoder"],
+    )
+
+    jasper_decoder = nemo_asr.JasperDecoderForCTC(
+        feat_in=jasper_params["JasperEncoder"]["jasper"][-1]["filters"],
+        num_classes=len(vocab),
+    )
+
+    ctc_loss = nemo_asr.CTCLossNM(num_classes=len(vocab))
+
+    greedy_decoder = nemo_asr.GreedyCTCDecoder()
+
+    logging.info("================================")
+    logging.info(f"Number of parameters in encoder: {jasper_encoder.num_weights}")
+    logging.info(f"Number of parameters in decoder: {jasper_decoder.num_weights}")
+    logging.info(
+        f"Total number of parameters in model: "
+        f"{jasper_decoder.num_weights + jasper_encoder.num_weights}"
+    )
+    logging.info("================================")
+
+    # Train DAG
+    (audio_signal_t, a_sig_length_t, transcript_t, transcript_len_t) = data_layer()
+    processed_signal_t, p_length_t = data_preprocessor(
+        input_signal=audio_signal_t, length=a_sig_length_t
+    )
+
+    if multiply_batch_config:
+        (
+            processed_signal_t,
+            p_length_t,
+            transcript_t,
+            transcript_len_t,
+        ) = multiply_batch(
+            in_x=processed_signal_t,
+            in_x_len=p_length_t,
+            in_y=transcript_t,
+            in_y_len=transcript_len_t,
+        )
+
+    if spectr_augment_config:
+        processed_signal_t = data_spectr_augmentation(input_spec=processed_signal_t)
+
+    encoded_t, encoded_len_t = jasper_encoder(
+        audio_signal=processed_signal_t, length=p_length_t
+    )
+    log_probs_t = jasper_decoder(encoder_output=encoded_t)
+    predictions_t = greedy_decoder(log_probs=log_probs_t)
+    loss_t = ctc_loss(
+        log_probs=log_probs_t,
+        targets=transcript_t,
+        input_length=encoded_len_t,
+        target_length=transcript_len_t,
+    )
+
+    # Callbacks needed to print info to console and Tensorboard
+    train_callback = nemo.core.SimpleLossLoggerCallback(
+        tensors=[loss_t, predictions_t, transcript_t, transcript_len_t],
+        print_func=partial(monitor_asr_train_progress, labels=vocab),
+        get_tb_values=lambda x: [("loss", x[0])],
+        tb_writer=neural_factory.tb_writer,
+    )
+
+    chpt_callback = nemo.core.CheckpointCallback(
+        folder=neural_factory.checkpoint_dir,
+        load_from_folder=args.load_dir,
+        step_freq=args.checkpoint_save_freq,
+    )
+
+    callbacks = [train_callback, chpt_callback]
+
+    # assemble eval DAGs
+    for i, eval_dl in enumerate(data_layers_eval):
+        (audio_signal_e, a_sig_length_e, transcript_e, transcript_len_e) = eval_dl()
+        processed_signal_e, p_length_e = data_preprocessor(
+            input_signal=audio_signal_e, length=a_sig_length_e
+        )
+        encoded_e, encoded_len_e = jasper_encoder(
+            audio_signal=processed_signal_e, length=p_length_e
+        )
+        log_probs_e = jasper_decoder(encoder_output=encoded_e)
+        predictions_e = greedy_decoder(log_probs=log_probs_e)
+        loss_e = ctc_loss(
+            log_probs=log_probs_e,
+            targets=transcript_e,
+            input_length=encoded_len_e,
+            target_length=transcript_len_e,
+        )
+
+        # create corresponding eval callback
+        tagname = os.path.basename(args.eval_datasets[i]).split(".")[0]
+        eval_callback = nemo.core.EvaluatorCallback(
+            eval_tensors=[loss_e, predictions_e, transcript_e, transcript_len_e],
+            user_iter_callback=partial(process_evaluation_batch, labels=vocab),
+            user_epochs_done_callback=partial(process_evaluation_epoch, tag=tagname),
+            eval_step=args.eval_freq,
+            tb_writer=neural_factory.tb_writer,
+        )
+
+        callbacks.append(eval_callback)
+    return loss_t, callbacks, steps_per_epoch
+
+
+def main():
+    args = parse_args()
+    name = construct_name(
+        args.exp_name,
+        args.lr,
+        args.batch_size,
+        args.max_steps,
+        args.num_epochs,
+        args.weight_decay,
+        args.optimizer,
+        args.iter_per_step,
+    )
+    log_dir = name
+    if args.work_dir:
+        log_dir = os.path.join(args.work_dir, name)
+
+    # instantiate Neural Factory with supported backend
+    neural_factory = nemo.core.NeuralModuleFactory(
+        backend=nemo.core.Backend.PyTorch,
+        local_rank=args.local_rank,
+        optimization_level=args.amp_opt_level,
+        log_dir=log_dir,
+        checkpoint_dir=args.checkpoint_dir,
+        create_tb_writer=args.create_tb_writer,
+        files_to_copy=[args.model_config, __file__],
+        cudnn_benchmark=args.cudnn_benchmark,
+        tensorboard_dir=args.tensorboard_dir,
+    )
+    args.num_gpus = neural_factory.world_size
+
+    checkpoint_dir = neural_factory.checkpoint_dir
+    if args.local_rank is not None:
+        logging.info("Doing ALL GPU")
+
+    # build dags
+    train_loss, callbacks, steps_per_epoch = create_all_dags(args, neural_factory)
+
+    # train model
+    neural_factory.train(
+        tensors_to_optimize=[train_loss],
+        callbacks=callbacks,
+        lr_policy=CosineAnnealing(
+            args.max_steps
+            if args.max_steps is not None
+            else args.num_epochs * steps_per_epoch,
+            warmup_steps=args.warmup_steps,
+        ),
+        optimizer=args.optimizer,
+        optimization_params={
+            "num_epochs": args.num_epochs,
+            "max_steps": args.max_steps,
+            "lr": args.lr,
+            "betas": (args.beta1, args.beta2),
+            "weight_decay": args.weight_decay,
+            "grad_norm_clip": None,
+        },
+        batches_per_step=args.iter_per_step,
+    )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/setup.py b/setup.py
index b18df24..be67241 100644
--- a/setup.py
+++ b/setup.py
@@ -5,7 +5,17 @@ requirements = [
     "nemo_toolkit[asr] @ git+https://github.com/NVIDIA/NeMo.git@09e3ba4dfe333f86d6c5c1048e07210924294be9#egg=nemo_toolkit",
 ]
 
-extra_requirements = {"server": ["rpyc==4.1.4"]}
+extra_requirements = {
+    "server": ["rpyc~=4.1.4"],
+    "data": [
+        "google-cloud-texttospeech~=1.0.1",
+        "tqdm~=4.39.0",
+        "pydub~=0.23.1",
+        "scikit_learn~=0.22.1",
+        "pandas~=1.0.3",
+        "boto3~=1.12.35",
+    ],
+}
 
 setup(
     name="jasper-asr",
@@ -22,6 +32,8 @@ setup(
         "console_scripts": [
             "jasper_transcribe = jasper.transcribe:main",
             "jasper_asr_rpyc_server = jasper.server:main",
+            "jasper_asr_trainer = jasper.train:main",
+            "jasper_asr_data_generate = jasper.data_utils.generator:main",
         ]
     },
     zip_safe=False,