From 61048f855e54e019dbacb2e784a6f7670707d0f6 Mon Sep 17 00:00:00 2001
From: Malar Kannan
Date: Mon, 27 Apr 2020 10:53:14 +0530
Subject: [PATCH] implement call audio data recycler for ASR

---
 .gitignore                         |  37 ++++
 jasper/data_utils/call_recycler.py | 333 +++++++++++++++++++++++++++++
 jasper/data_utils/generator.py     |  39 ++--
 jasper/data_utils/process.py       |  53 +++--
 jasper/data_utils/utils.py         |  23 ++
 jasper/train.py                    |   4 +-
 setup.py                           |  16 ++
 7 files changed, 465 insertions(+), 40 deletions(-)
 create mode 100644 jasper/data_utils/call_recycler.py

diff --git a/.gitignore b/.gitignore
index aab7ea0..f5adf10 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,7 @@
+data/
+.env*
+*.yaml
+
 # Created by https://www.gitignore.io/api/python
 # Edit at https://www.gitignore.io/?templates=python
 
@@ -108,3 +112,36 @@ dmypy.json
 .pyre/
 
 # End of https://www.gitignore.io/api/python
+
+# Created by https://www.gitignore.io/api/macos
+# Edit at https://www.gitignore.io/?templates=macos
+
+### macOS ###
+# General
+.DS_Store
+.AppleDouble
+.LSOverride
+
+# Icon must end with two \r
+Icon
+
+# Thumbnails
+._*
+
+# Files that might appear in the root of a volume
+.DocumentRevisions-V100
+.fseventsd
+.Spotlight-V100
+.TemporaryItems
+.Trashes
+.VolumeIcon.icns
+.com.apple.timemachine.donotpresent
+
+# Directories potentially created on remote AFP share
+.AppleDB
+.AppleDesktop
+Network Trash Folder
+Temporary Items
+.apdisk
+
+# End of https://www.gitignore.io/api/macos
diff --git a/jasper/data_utils/call_recycler.py b/jasper/data_utils/call_recycler.py
new file mode 100644
index 0000000..8678787
--- /dev/null
+++ b/jasper/data_utils/call_recycler.py
@@ -0,0 +1,333 @@
+# import argparse
+
+# import logging
+import typer
+from pathlib import Path
+
+app = typer.Typer()
+# leader_app = typer.Typer()
+# app.add_typer(leader_app, name="leaderboard")
+# plot_app = typer.Typer()
+# app.add_typer(plot_app, name="plot")
+
+
+@app.command()
+def analyze(
+    leaderboard: bool = False,
+    plot_calls: bool = False,
+    extract_data: bool = False,
+    call_logs_file: Path = Path("./call_logs.yaml"),
+    output_dir: Path = Path("./data"),
+):
+    from urllib.parse import urlsplit
+    from functools import reduce
+    from pymongo import MongoClient
+    import boto3
+
+    from io import BytesIO
+    import json
+    from ruamel.yaml import YAML
+    import re
+    from google.protobuf.timestamp_pb2 import Timestamp
+    from datetime import timedelta
+
+    # from concurrent.futures import ThreadPoolExecutor
+    from dateutil.relativedelta import relativedelta
+    import librosa
+    import librosa.display
+    from lenses import lens
+    from pprint import pprint
+    import pandas as pd
+    import matplotlib.pyplot as plt
+    import matplotlib
+    from tqdm import tqdm
+    from .utils import asr_data_writer
+    from pydub import AudioSegment
+
+    matplotlib.rcParams["agg.path.chunksize"] = 10000
+
+    matplotlib.use("agg")
+
+    # logging.basicConfig(
+    #     level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+    # )
+    # logger = logging.getLogger(__name__)
+    yaml = YAML()
+    s3 = boto3.client("s3")
+    mongo_collection = MongoClient("mongodb://localhost:27017/").test.calls
+    call_media_dir: Path = output_dir / Path("call_wavs")
+    call_media_dir.mkdir(exist_ok=True, parents=True)
+    call_meta_dir: Path = output_dir / Path("call_metas")
+    call_meta_dir.mkdir(exist_ok=True, parents=True)
+    call_plot_dir: Path = output_dir / Path("plots")
+    call_plot_dir.mkdir(exist_ok=True, parents=True)
+    call_asr_data: Path = output_dir / Path("asr_data")
+    call_asr_data.mkdir(exist_ok=True, parents=True)
+
+    call_logs = yaml.load(call_logs_file.read_text())
+
+    def get_call_meta(call_obj):
+        s3_event_url_p = urlsplit(call_obj["DataURI"])
+        saved_meta_path = call_meta_dir / Path(Path(s3_event_url_p.path).name)
+        if not saved_meta_path.exists():
+            print(f"downloading : {saved_meta_path}")
+            s3.download_file(
+                s3_event_url_p.netloc, s3_event_url_p.path[1:], str(saved_meta_path)
+            )
+        call_metas = json.load(saved_meta_path.open())
+        return call_metas
+
+    def gen_ev_fev_timedelta(fev):
+        fev_p = Timestamp()
+        fev_p.FromJsonString(fev["CreatedTS"])
+        fev_dt = fev_p.ToDatetime()
+        td_0 = timedelta()
+
+        def get_timedelta(ev):
+            ev_p = Timestamp()
+            ev_p.FromJsonString(value=ev["CreatedTS"])
+            ev_dt = ev_p.ToDatetime()
+            delta = ev_dt - fev_dt
+            return delta if delta > td_0 else td_0
+
+        return get_timedelta
+
+    def process_call(call_obj):
+        call_meta = get_call_meta(call_obj)
+        call_events = call_meta["Events"]
+
+        def is_writer_event(ev):
+            return ev["Author"] == "AUDIO_WRITER"
+
+        writer_events = list(filter(is_writer_event, call_events))
+        s3_wav_url = re.search(r"saved to: (.*)", writer_events[0]["Msg"]).groups(0)[0]
+        s3_wav_url_p = urlsplit(s3_wav_url)
+
+        def is_first_audio_ev(state, ev):
+            if state[0]:
+                return state
+            else:
+                return (ev["Author"] == "GATEWAY" and ev["Type"] == "AUDIO", ev)
+
+        (_, first_audio_ev) = reduce(is_first_audio_ev, call_events, (False, {}))
+
+        get_ev_fev_timedelta = gen_ev_fev_timedelta(first_audio_ev)
+
+        def is_utter_event(ev):
+            return (
+                (ev["Author"] == "CONV" or ev["Author"] == "ASR")
+                and (ev["Type"] != "DEBUG")
+                and ev["Type"] != "ASR_RESULT"
+            )
+
+        uevs = list(filter(is_utter_event, call_events))
+        ev_count = len(uevs)
+        utter_events = uevs[: ev_count - ev_count % 3]
+        saved_wav_path = call_media_dir / Path(Path(s3_wav_url_p.path).name)
+        if not saved_wav_path.exists():
+            print(f"downloading : {saved_wav_path}")
+            s3.download_file(
+                s3_wav_url_p.netloc, s3_wav_url_p.path[1:], str(saved_wav_path)
+            )
+
+        # %config InlineBackend.figure_format = "retina"
+        def chunk_n(evs, n):
+            return [evs[i * n : (i + 1) * n] for i in range((len(evs) + n - 1) // n)]
+
+        def get_data_points(utter_events):
+            data_points = []
+            for evs in chunk_n(utter_events, 3):
+                assert evs[0]["Type"] == "CONV_RESULT"
+                assert evs[1]["Type"] == "STARTED_SPEAKING"
+                assert evs[2]["Type"] == "STOPPED_SPEAKING"
+                start_time = get_ev_fev_timedelta(evs[1]).total_seconds() - 1.5
+                end_time = get_ev_fev_timedelta(evs[2]).total_seconds()
+                code = evs[0]["Msg"]
+                data_points.append(
+                    {"start_time": start_time, "end_time": end_time, "code": code}
+                )
+            return data_points
+
+        def plot_events(y, sr, utter_events, file_path):
+            plt.figure(figsize=(16, 12))
+            librosa.display.waveplot(y=y, sr=sr)
+            # plt.tight_layout()
+            for evs in chunk_n(utter_events, 3):
+                assert evs[0]["Type"] == "CONV_RESULT"
+                assert evs[1]["Type"] == "STARTED_SPEAKING"
+                assert evs[2]["Type"] == "STOPPED_SPEAKING"
+                for ev in evs:
+                    # print(ev["Type"])
+                    ev_type = ev["Type"]
+                    pos = get_ev_fev_timedelta(ev).total_seconds()
+                    if ev_type == "STARTED_SPEAKING":
+                        pos = pos - 1.5
+                    plt.axvline(pos)  # , label="pyplot vertical line")
+                    plt.text(
+                        pos,
+                        0.2,
+                        f"event:{ev_type}:{ev['Msg']}",
+                        rotation=90,
+                        horizontalalignment="left"
+                        if ev_type != "STOPPED_SPEAKING"
+                        else "right",
+                        verticalalignment="center",
+                    )
+            plt.title("Monophonic")
+            plt.savefig(file_path, format="png")
+
+        data_points = get_data_points(utter_events)
+
+        return {
+            "wav_path": saved_wav_path,
+            "num_samples": len(utter_events) // 3,
+            "meta": call_obj,
+            "data_points": data_points,
+        }
+
+    def retrieve_callmeta(uri):
+        cid = Path(urlsplit(uri).path).stem
+        meta = mongo_collection.find_one({"SystemID": cid})
+        duration = meta["EndTS"] - meta["StartTS"]
+        process_meta = process_call(meta)
+        return {"url": uri, "meta": meta, "duration": duration, "process": process_meta}
+
+    # @plot_app.command()
+    def plot_calls_data():
+        def plot_data_points(y, sr, data_points, file_path):
+            plt.figure(figsize=(16, 12))
+            librosa.display.waveplot(y=y, sr=sr)
+            for dp in data_points:
+                start, end, code = dp["start_time"], dp["end_time"], dp["code"]
+                plt.axvspan(start, end, color="green", alpha=0.2)
+                text_pos = (start + end) / 2
+                plt.text(
+                    text_pos,
+                    0.25,
+                    f"{code}",
+                    rotation=90,
+                    horizontalalignment="center",
+                    verticalalignment="center",
+                )
+            plt.title("Datapoints")
+            plt.savefig(file_path, format="png")
+            return file_path
+
+        def plot_call(call_obj):
+            saved_wav_path, data_points, sys_id = (
+                call_obj["process"]["wav_path"],
+                call_obj["process"]["data_points"],
+                call_obj["meta"]["SystemID"],
+            )
+            file_path = call_plot_dir / Path(sys_id).with_suffix(".png")
+            if not file_path.exists():
+                print(f"plotting: {file_path}")
+                (y, sr) = librosa.load(saved_wav_path)
+                plot_data_points(y, sr, data_points, str(file_path))
+            return file_path
+
+        # plot_call(retrieve_callmeta("http://saasdev.agaralabs.com/calls/JOR9V47L03AGUEL"))
+        call_lens = lens["users"].Each()["calls"].Each()
+        call_stats = call_lens.modify(retrieve_callmeta)(call_logs)
+        # call_plot_data = call_lens.collect()(call_stats)
+        call_plots = call_lens.modify(plot_call)(call_stats)
+        # with ThreadPoolExecutor(max_workers=20) as exe:
+        #     print('starting all plot tasks')
+        #     responses = [exe.submit(plot_call, w) for w in call_plot_data]
+        #     print('submitted all plot tasks')
+        #     call_plots = [r.result() for r in responses]
+        pprint(call_plots)
+
+    def extract_data_points():
+        def gen_data_values(saved_wav_path, data_points):
+            call_seg = (
+                AudioSegment.from_wav(saved_wav_path)
+                .set_channels(1)
+                .set_sample_width(2)
+                .set_frame_rate(24000)
+            )
+            for dp_id, dp in enumerate(data_points):
+                start, end, code = dp["start_time"], dp["end_time"], dp["code"]
+                code_seg = call_seg[start * 1000 : end * 1000]
+                code_fb = BytesIO()
+                code_seg.export(code_fb, format="wav")
+                code_wav = code_fb.getvalue()
+                # import pdb; pdb.set_trace()
+                yield code, code_seg.duration_seconds, code_wav
+
+        call_lens = lens["users"].Each()["calls"].Each()
+        call_stats = call_lens.modify(retrieve_callmeta)(call_logs)
+        call_objs = call_lens.collect()(call_stats)
+
+        def data_source():
+            for call_obj in tqdm(call_objs):
+                saved_wav_path, data_points, sys_id = (
+                    call_obj["process"]["wav_path"],
+                    call_obj["process"]["data_points"],
+                    call_obj["meta"]["SystemID"],
+                )
+                for dp in gen_data_values(saved_wav_path, data_points):
+                    yield dp
+
+        asr_data_writer(call_asr_data, "call_alphanum", data_source())
+
+    # @leader_app.command()
+    def show_leaderboard():
+        def compute_user_stats(call_stat):
+            n_samples = (
+                lens["calls"].Each()["process"]["num_samples"].get_monoid()(call_stat)
+            )
+            n_duration = lens["calls"].Each()["duration"].get_monoid()(call_stat)
+            rel_dur = relativedelta(
+                seconds=int(n_duration.total_seconds()),
+                microseconds=n_duration.microseconds,
+            )
+            return {
+                "num_samples": n_samples,
+                "duration": n_duration.total_seconds(),
+                "samples_rate": n_samples / n_duration.total_seconds(),
+                "duration_str": f"{rel_dur.minutes} mins {rel_dur.seconds} secs",
+                "name": call_stat["name"],
+            }
+
+        call_lens = lens["users"].Each()["calls"].Each()
+        call_stats = call_lens.modify(retrieve_callmeta)(call_logs)
+        user_stats = lens["users"].Each().modify(compute_user_stats)(call_stats)
+        leader_df = (
+            pd.DataFrame(user_stats["users"])
+            .sort_values(by=["duration"], ascending=False)
+            .reset_index(drop=True)
+        )
+        leader_df["rank"] = leader_df.index + 1
+        leader_board = leader_df.rename(
+            columns={
+                "rank": "Rank",
+                "num_samples": "Codes",
+                "name": "Name",
+                "samples_rate": "SpeechRate",
+                "duration_str": "Duration",
+            }
+        )[["Rank", "Name", "Codes", "Duration"]]
+        print(
+            """Today's ASR Speller Dataset Leaderboard:
+----------------------------------------"""
+        )
+        print(leader_board.to_string(index=False))
+
+    if leaderboard:
+        show_leaderboard()
+    if plot_calls:
+        plot_calls_data()
+    if extract_data:
+        extract_data_points()
+
+
+def main():
+    app()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/jasper/data_utils/generator.py b/jasper/data_utils/generator.py
index c49d460..ce26ec1 100644
--- a/jasper/data_utils/generator.py
+++ b/jasper/data_utils/generator.py
@@ -4,7 +4,7 @@ import argparse
 import logging
 from pathlib import Path
 
-from .utils import random_pnr_generator, manifest_str
+from .utils import random_pnr_generator, asr_data_writer
 from .tts.googletts import GoogleTTS
 from tqdm import tqdm
 import random
@@ -15,27 +15,21 @@ logging.basicConfig(
 logger = logging.getLogger(__name__)
 
 
-def generate_asr_data(output_dir, count):
+def pnr_tts_streamer(count):
     google_voices = GoogleTTS.voice_list()
     gtts = GoogleTTS()
-    wav_dir = output_dir / Path("pnr_data")
-    wav_dir.mkdir(parents=True, exist_ok=True)
-    asr_manifest = output_dir / Path("pnr_data").with_suffix(".json")
-    with asr_manifest.open("w") as mf:
-        for pnr_code in tqdm(random_pnr_generator(count)):
-            tts_code = (
-                f'{pnr_code}'
-            )
-            param = random.choice(google_voices)
-            param["sample_rate"] = 24000
-            param["num_channels"] = 1
-            wav_data = gtts.text_to_speech(text=tts_code, params=param)
-            audio_dur = len(wav_data[44:]) / (2 * 24000)
-            pnr_af = wav_dir / Path(pnr_code).with_suffix(".wav")
-            pnr_af.write_bytes(wav_data)
-            rel_pnr_path = pnr_af.relative_to(output_dir)
-            manifest = manifest_str(str(rel_pnr_path), audio_dur, pnr_code)
-            mf.write(manifest)
+    for pnr_code in tqdm(random_pnr_generator(count)):
+        tts_code = f'{pnr_code}'
+        param = random.choice(google_voices)
+        param["sample_rate"] = 24000
+        param["num_channels"] = 1
+        wav_data = gtts.text_to_speech(text=tts_code, params=param)
+        audio_dur = len(wav_data[44:]) / (2 * 24000)
+        yield pnr_code, audio_dur, wav_data
+
+
+def generate_asr_data_fromtts(output_dir, dataset_name, count):
+    asr_data_writer(output_dir, dataset_name, pnr_tts_streamer(count))
 
 
 def arg_parser():
@@ -52,13 +46,16 @@ def arg_parser():
     parser.add_argument(
         "--count", type=int, default=3, help="number of datapoints to generate"
     )
+    parser.add_argument(
+        "--dataset_name", type=str, default="pnr_data", help="name of the dataset"
+    )
     return parser
 
 
 def main():
     parser = arg_parser()
     args = parser.parse_args()
-    generate_asr_data(**vars(args))
+    generate_asr_data_fromtts(**vars(args))
 
 
 if __name__ == "__main__":
diff --git a/jasper/data_utils/process.py b/jasper/data_utils/process.py
index 44e4237..7e38523 100644
--- a/jasper/data_utils/process.py
+++ b/jasper/data_utils/process.py
@@ -1,9 +1,13 @@
 import json
 from pathlib import Path
 from sklearn.model_selection import train_test_split
-from num2words import num2words
+from .utils import alnum_to_asr_tokens
+import typer
+
+app = typer.Typer()
+
+
+@app.command()
 def separate_space_convert_digit_setpath():
     with Path("/home/malar/work/asr-data-utils/asr_data/pnr_data.json").open("r") as pf:
         pnr_jsonl = pf.readlines()
@@ -12,9 +16,7 @@ def separate_space_convert_digit_setpath():
     new_pnr_data = []
     for i in pnr_data:
-        letters = " ".join(list(i["text"]))
-        num_tokens = [num2words(c) if "0" <= c <= "9" else c for c in letters]
-        i["text"] = ("".join(num_tokens)).lower()
+        i["text"] = alnum_to_asr_tokens(i["text"])
         i["audio_filepath"] = i["audio_filepath"].replace(
             "pnr_data/", "/dataset/asr_data/pnr_data/wav/"
         )
@@ -27,24 +29,39 @@ def separate_space_convert_digit_setpath():
         pf.write(new_pnr_data)
 
 
-separate_space_convert_digit_setpath()
-
-
-def split_data():
-    with Path("/dataset/asr_data/pnr_data/pnr_data.json").open("r") as pf:
+@app.command()
+def split_data(manifest_path: Path = Path("/dataset/asr_data/pnr_data/pnr_data.json")):
+    with manifest_path.open("r") as pf:
         pnr_jsonl = pf.readlines()
     train_pnr, test_pnr = train_test_split(pnr_jsonl, test_size=0.1)
-    with Path("/dataset/asr_data/pnr_data/train_manifest.json").open("w") as pf:
+    with (manifest_path.parent / Path("train_manifest.json")).open("w") as pf:
         pnr_data = "".join(train_pnr)
         pf.write(pnr_data)
-    with Path("/dataset/asr_data/pnr_data/test_manifest.json").open("w") as pf:
+    with (manifest_path.parent / Path("test_manifest.json")).open("w") as pf:
         pnr_data = "".join(test_pnr)
         pf.write(pnr_data)
 
 
-split_data()
+@app.command()
+def fix_path(
+    dataset_path: Path = Path("/dataset/asr_data/call_alphanum"),
+):
+    manifest_path = dataset_path / Path('manifest.json')
+    with manifest_path.open("r") as pf:
+        pnr_jsonl = pf.readlines()
+    pnr_data = [json.loads(i) for i in pnr_jsonl]
+    new_pnr_data = []
+    for i in pnr_data:
+        i["audio_filepath"] = str(dataset_path / Path(i["audio_filepath"]))
+        new_pnr_data.append(i)
+    new_pnr_jsonl = [json.dumps(i) for i in new_pnr_data]
+    real_manifest_path = dataset_path / Path('real_manifest.json')
+    with real_manifest_path.open("w") as pf:
+        new_pnr_data = "\n".join(new_pnr_jsonl)  # + "\n"
+        pf.write(new_pnr_data)
 
 
+@app.command()
 def augment_an4():
     an4_train = Path("/dataset/asr_data/an4/train_manifest.json").read_bytes()
     an4_test = Path("/dataset/asr_data/an4/test_manifest.json").read_bytes()
@@ -57,10 +74,11 @@ def augment_an4():
         pf.write(an4_test + pnr_test)
 
 
-augment_an4()
+# augment_an4()
 
 
-def validate_data(data_file):
+@app.command()
+def validate_data(data_file: Path = Path("/dataset/asr_data/call_alphanum/train_manifest.json")):
     with Path(data_file).open("r") as pf:
         pnr_jsonl = pf.readlines()
     for (i, s) in enumerate(pnr_jsonl):
@@ -70,10 +88,13 @@ def validate_data(data_file):
             print(f"failed on {i}")
 
 
-validate_data("/dataset/asr_data/an4_pnr/test_manifest.json")
-validate_data("/dataset/asr_data/an4_pnr/train_manifest.json")
+def main():
+    app()
 
 
+if __name__ == "__main__":
+    main()
+
 # def convert_digits(data_file="/dataset/asr_data/an4_pnr/test_manifest.json"):
 #     with Path(data_file).open("r") as pf:
 #         pnr_jsonl = pf.readlines()
diff --git a/jasper/data_utils/utils.py b/jasper/data_utils/utils.py
index bee1e97..aca31e8 100644
--- a/jasper/data_utils/utils.py
+++ b/jasper/data_utils/utils.py
@@ -2,6 +2,8 @@ import numpy as np
 import wave
 import io
 import json
+from pathlib import Path
+from num2words import num2words
 
 
 def manifest_str(path, dur, text):
@@ -38,6 +40,27 @@ def random_pnr_generator(count=10000):
     return codes
 
 
+def alnum_to_asr_tokens(text):
+    letters = " ".join(list(text))
+    num_tokens = [num2words(c) if "0" <= c <= "9" else c for c in letters]
+    return ("".join(num_tokens)).lower()
+
+
+def asr_data_writer(output_dir, dataset_name, asr_data_source):
+    dataset_dir = output_dir / Path(dataset_name)
+    (dataset_dir / Path("wav")).mkdir(parents=True, exist_ok=True)
+    asr_manifest = dataset_dir / Path("manifest.json")
+    with asr_manifest.open("w") as mf:
+        for pnr_code, audio_dur, wav_data in asr_data_source:
+            pnr_af = dataset_dir / Path("wav") / Path(pnr_code).with_suffix(".wav")
+            pnr_af.write_bytes(wav_data)
+            rel_pnr_path = pnr_af.relative_to(dataset_dir)
+            manifest = manifest_str(
+                str(rel_pnr_path), audio_dur, alnum_to_asr_tokens(pnr_code)
+            )
+            mf.write(manifest)
+
+
 def main():
     for c in random_pnr_generator():
         print(c)
diff --git a/jasper/train.py b/jasper/train.py
index 9861aff..def978f 100644
--- a/jasper/train.py
+++ b/jasper/train.py
@@ -82,8 +82,7 @@ def parse_args():
     )
 
     args = parser.parse_args()
-
-    if args.max_steps is not None and args.num_epochs is not None:
+    if args.max_steps is None and args.num_epochs is None:
         raise ValueError("Either max_steps or num_epochs should be provided.")
 
     return args
@@ -311,7 +310,6 @@ def main():
 
     # build dags
     train_loss, callbacks, steps_per_epoch = create_all_dags(args, neural_factory)
-
     # train model
     neural_factory.train(
         tensors_to_optimize=[train_loss],
diff --git a/setup.py b/setup.py
index 3948e09..015e125 100644
--- a/setup.py
+++ b/setup.py
@@ -2,6 +2,8 @@ from setuptools import setup, find_packages
 
 requirements = [
     "ruamel.yaml",
+    "torch==1.4.0",
+    "torchvision==0.5.0",
     "nemo_toolkit[asr] @ git+https://github.com/NVIDIA/NeMo.git@09e3ba4dfe333f86d6c5c1048e07210924294be9#egg=nemo_toolkit",
 ]
 
@@ -14,7 +16,19 @@ extra_requirements = {
         "scikit_learn~=0.22.1",
         "pandas~=1.0.3",
         "boto3~=1.12.35",
+        "ruamel.yaml==0.16.10",
+        "pymongo==3.10.1",
+        "librosa==0.7.2",
+        "matplotlib==3.2.1",
+        "pandas==1.0.3",
+        "tabulate==0.8.7",
+        "typer[all]==0.1.1",
+        "lenses @ git+https://github.com/ingolemo/python-lenses.git@b2a2a9aa5b61540992d70b2cf36008d0121e8948#egg=lenses",
     ],
+    # "train": [
+    #     "torchaudio==0.5.0",
+    #     "torch-stft==0.1.4",
+    # ]
 }
 
 packages = find_packages()
@@ -35,6 +49,8 @@ setup(
             "jasper_asr_rpyc_server = jasper.server:main",
             "jasper_asr_trainer = jasper.train:main",
             "jasper_asr_data_generate = jasper.data_utils.generator:main",
+            "jasper_asr_data_recycle = jasper.data_utils.call_recycler:main",
+            "jasper_asr_data_preprocess = jasper.data_utils.process:main",
         ]
     },
     zip_safe=False,
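
Usage sketch (not part of the patch; the sample string and invocations below are illustrative assumptions, not taken from the diff):

    from jasper.data_utils.utils import alnum_to_asr_tokens

    # Letters are space-separated and lower-cased; digits are spelled out via num2words.
    print(alnum_to_asr_tokens("AB12"))  # expected: "a b one two"

The console scripts added in setup.py would then drive the pipeline end to end, for example
"jasper_asr_data_recycle analyze --extract-data" to cut per-utterance segments out of the call
recordings into an ASR manifest, followed by "jasper_asr_data_preprocess fix-path" (typer derives
the subcommand name from fix_path) to rewrite that manifest with absolute audio paths.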