diff --git a/.gitignore b/.gitignore
index aab7ea0..f5adf10 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,7 @@
+data/
+.env*
+*.yaml
+
# Created by https://www.gitignore.io/api/python
# Edit at https://www.gitignore.io/?templates=python
@@ -108,3 +112,36 @@ dmypy.json
.pyre/
# End of https://www.gitignore.io/api/python
+
+# Created by https://www.gitignore.io/api/macos
+# Edit at https://www.gitignore.io/?templates=macos
+
+### macOS ###
+# General
+.DS_Store
+.AppleDouble
+.LSOverride
+
+# Icon must end with two \r
+Icon
+
+# Thumbnails
+._*
+
+# Files that might appear in the root of a volume
+.DocumentRevisions-V100
+.fseventsd
+.Spotlight-V100
+.TemporaryItems
+.Trashes
+.VolumeIcon.icns
+.com.apple.timemachine.donotpresent
+
+# Directories potentially created on remote AFP share
+.AppleDB
+.AppleDesktop
+Network Trash Folder
+Temporary Items
+.apdisk
+
+# End of https://www.gitignore.io/api/macos
diff --git a/jasper/data_utils/call_recycler.py b/jasper/data_utils/call_recycler.py
new file mode 100644
index 0000000..8678787
--- /dev/null
+++ b/jasper/data_utils/call_recycler.py
@@ -0,0 +1,333 @@
+import typer
+from pathlib import Path
+
+app = typer.Typer()
+# leader_app = typer.Typer()
+# app.add_typer(leader_app, name="leaderboard")
+# plot_app = typer.Typer()
+# app.add_typer(plot_app, name="plot")
+
+
+@app.command()
+def analyze(
+ leaderboard: bool = False,
+ plot_calls: bool = False,
+ extract_data: bool = False,
+ call_logs_file: Path = Path("./call_logs.yaml"),
+ output_dir: Path = Path("./data"),
+):
+
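+    # heavy dependencies are imported inside the command so they load only when it runs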
+ from urllib.parse import urlsplit
+ from functools import reduce
+ from pymongo import MongoClient
+ import boto3
+
+ from io import BytesIO
+ import json
+ from ruamel.yaml import YAML
+ import re
+ from google.protobuf.timestamp_pb2 import Timestamp
+ from datetime import timedelta
+
+ # from concurrent.futures import ThreadPoolExecutor
+ from dateutil.relativedelta import relativedelta
+ import librosa
+ import librosa.display
+ from lenses import lens
+ from pprint import pprint
+ import pandas as pd
+ import matplotlib.pyplot as plt
+ import matplotlib
+ from tqdm import tqdm
+ from .utils import asr_data_writer
+ from pydub import AudioSegment
+
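+    # chunk long paths so Agg can render very large waveform plots without overflowing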
+ matplotlib.rcParams["agg.path.chunksize"] = 10000
+
+ matplotlib.use("agg")
+
+ yaml = YAML()
+ s3 = boto3.client("s3")
+ mongo_collection = MongoClient("mongodb://localhost:27017/").test.calls
+ call_media_dir: Path = output_dir / Path("call_wavs")
+ call_media_dir.mkdir(exist_ok=True, parents=True)
+ call_meta_dir: Path = output_dir / Path("call_metas")
+ call_meta_dir.mkdir(exist_ok=True, parents=True)
+ call_plot_dir: Path = output_dir / Path("plots")
+ call_plot_dir.mkdir(exist_ok=True, parents=True)
+ call_asr_data: Path = output_dir / Path("asr_data")
+ call_asr_data.mkdir(exist_ok=True, parents=True)
+
+ call_logs = yaml.load(call_logs_file.read_text())
+
+ def get_call_meta(call_obj):
+ s3_event_url_p = urlsplit(call_obj["DataURI"])
+ saved_meta_path = call_meta_dir / Path(Path(s3_event_url_p.path).name)
+ if not saved_meta_path.exists():
+ print(f"downloading : {saved_meta_path}")
+ s3.download_file(
+ s3_event_url_p.netloc, s3_event_url_p.path[1:], str(saved_meta_path)
+ )
+ call_metas = json.load(saved_meta_path.open())
+ return call_metas
+
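+    # builds a closure giving each event's offset (clamped to >= 0) from the first audio event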
+ def gen_ev_fev_timedelta(fev):
+ fev_p = Timestamp()
+ fev_p.FromJsonString(fev["CreatedTS"])
+ fev_dt = fev_p.ToDatetime()
+ td_0 = timedelta()
+
+ def get_timedelta(ev):
+ ev_p = Timestamp()
+ ev_p.FromJsonString(value=ev["CreatedTS"])
+ ev_dt = ev_p.ToDatetime()
+ delta = ev_dt - fev_dt
+ return delta if delta > td_0 else td_0
+
+ return get_timedelta
+
+ def process_call(call_obj):
+ call_meta = get_call_meta(call_obj)
+ call_events = call_meta["Events"]
+
+ def is_writer_event(ev):
+ return ev["Author"] == "AUDIO_WRITER"
+
+ writer_events = list(filter(is_writer_event, call_events))
+        s3_wav_url = re.search(r"saved to: (.*)", writer_events[0]["Msg"]).group(1)
+ s3_wav_url_p = urlsplit(s3_wav_url)
+
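+        # fold over events to find the first GATEWAY AUDIO event (start of call audio)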
+ def is_first_audio_ev(state, ev):
+ if state[0]:
+ return state
+ else:
+ return (ev["Author"] == "GATEWAY" and ev["Type"] == "AUDIO", ev)
+
+ (_, first_audio_ev) = reduce(is_first_audio_ev, call_events, (False, {}))
+
+ get_ev_fev_timedelta = gen_ev_fev_timedelta(first_audio_ev)
+
+ def is_utter_event(ev):
+ return (
+ (ev["Author"] == "CONV" or ev["Author"] == "ASR")
+ and (ev["Type"] != "DEBUG")
+ and ev["Type"] != "ASR_RESULT"
+ )
+
+ uevs = list(filter(is_utter_event, call_events))
+ ev_count = len(uevs)
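+        # events arrive as (CONV_RESULT, STARTED_SPEAKING, STOPPED_SPEAKING) triples; drop any trailing partial one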
+ utter_events = uevs[: ev_count - ev_count % 3]
+ saved_wav_path = call_media_dir / Path(Path(s3_wav_url_p.path).name)
+ if not saved_wav_path.exists():
+ print(f"downloading : {saved_wav_path}")
+ s3.download_file(
+ s3_wav_url_p.netloc, s3_wav_url_p.path[1:], str(saved_wav_path)
+ )
+
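+        # split evs into consecutive chunks of length n (last chunk may be shorter)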
+ def chunk_n(evs, n):
+ return [evs[i * n : (i + 1) * n] for i in range((len(evs) + n - 1) // n)]
+
+ def get_data_points(utter_events):
+ data_points = []
+ for evs in chunk_n(utter_events, 3):
+ assert evs[0]["Type"] == "CONV_RESULT"
+ assert evs[1]["Type"] == "STARTED_SPEAKING"
+ assert evs[2]["Type"] == "STOPPED_SPEAKING"
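+                # start 1.5 s early: STARTED_SPEAKING appears to lag actual speech onset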
+ start_time = get_ev_fev_timedelta(evs[1]).total_seconds() - 1.5
+ end_time = get_ev_fev_timedelta(evs[2]).total_seconds()
+ code = evs[0]["Msg"]
+ data_points.append(
+ {"start_time": start_time, "end_time": end_time, "code": code}
+ )
+ return data_points
+
+ def plot_events(y, sr, utter_events, file_path):
+ plt.figure(figsize=(16, 12))
+ librosa.display.waveplot(y=y, sr=sr)
+ # plt.tight_layout()
+ for evs in chunk_n(utter_events, 3):
+ assert evs[0]["Type"] == "CONV_RESULT"
+ assert evs[1]["Type"] == "STARTED_SPEAKING"
+ assert evs[2]["Type"] == "STOPPED_SPEAKING"
+ for ev in evs:
+ ev_type = ev["Type"]
+ pos = get_ev_fev_timedelta(ev).total_seconds()
+ if ev_type == "STARTED_SPEAKING":
+ pos = pos - 1.5
+                plt.axvline(pos)
+ plt.text(
+ pos,
+ 0.2,
+ f"event:{ev_type}:{ev['Msg']}",
+ rotation=90,
+ horizontalalignment="left"
+ if ev_type != "STOPPED_SPEAKING"
+ else "right",
+ verticalalignment="center",
+ )
+ plt.title("Monophonic")
+ plt.savefig(file_path, format="png")
+
+ data_points = get_data_points(utter_events)
+
+ return {
+ "wav_path": saved_wav_path,
+ "num_samples": len(utter_events) // 3,
+ "meta": call_obj,
+ "data_points": data_points,
+ }
+
+ def retrieve_callmeta(uri):
+ cid = Path(urlsplit(uri).path).stem
+ meta = mongo_collection.find_one({"SystemID": cid})
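+        # EndTS/StartTS come back as datetimes, so duration is a timedelta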
+ duration = meta["EndTS"] - meta["StartTS"]
+ process_meta = process_call(meta)
+ return {"url": uri, "meta": meta, "duration": duration, "process": process_meta}
+
+ # @plot_app.command()
+ def plot_calls_data():
+ def plot_data_points(y, sr, data_points, file_path):
+ plt.figure(figsize=(16, 12))
+ librosa.display.waveplot(y=y, sr=sr)
+ for dp in data_points:
+ start, end, code = dp["start_time"], dp["end_time"], dp["code"]
+ plt.axvspan(start, end, color="green", alpha=0.2)
+ text_pos = (start + end) / 2
+ plt.text(
+ text_pos,
+ 0.25,
+ f"{code}",
+ rotation=90,
+ horizontalalignment="center",
+ verticalalignment="center",
+ )
+ plt.title("Datapoints")
+ plt.savefig(file_path, format="png")
+ return file_path
+
+ def plot_call(call_obj):
+ saved_wav_path, data_points, sys_id = (
+ call_obj["process"]["wav_path"],
+ call_obj["process"]["data_points"],
+ call_obj["meta"]["SystemID"],
+ )
+ file_path = call_plot_dir / Path(sys_id).with_suffix(".png")
+ if not file_path.exists():
+ print(f"plotting: {file_path}")
+ (y, sr) = librosa.load(saved_wav_path)
+ plot_data_points(y, sr, data_points, str(file_path))
+ return file_path
+
+ # plot_call(retrieve_callmeta("http://saasdev.agaralabs.com/calls/JOR9V47L03AGUEL"))
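+        # lens focuses every users[*].calls[*] entry; modify() maps retrieve_callmeta over each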
+ call_lens = lens["users"].Each()["calls"].Each()
+ call_stats = call_lens.modify(retrieve_callmeta)(call_logs)
+ # call_plot_data = call_lens.collect()(call_stats)
+ call_plots = call_lens.modify(plot_call)(call_stats)
+ # with ThreadPoolExecutor(max_workers=20) as exe:
+ # print('starting all plot tasks')
+ # responses = [exe.submit(plot_call, w) for w in call_plot_data]
+ # print('submitted all plot tasks')
+ # call_plots = [r.result() for r in responses]
+ pprint(call_plots)
+
+ def extract_data_points():
+ def gen_data_values(saved_wav_path, data_points):
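+            # normalize the call audio to mono, 16-bit, 24 kHz before slicing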
+ call_seg = (
+ AudioSegment.from_wav(saved_wav_path)
+ .set_channels(1)
+ .set_sample_width(2)
+ .set_frame_rate(24000)
+ )
+            for dp in data_points:
+ start, end, code = dp["start_time"], dp["end_time"], dp["code"]
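+                # pydub slices are indexed in milliseconds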
+ code_seg = call_seg[start * 1000 : end * 1000]
+ code_fb = BytesIO()
+ code_seg.export(code_fb, format="wav")
+ code_wav = code_fb.getvalue()
+ yield code, code_seg.duration_seconds, code_wav
+
+ call_lens = lens["users"].Each()["calls"].Each()
+ call_stats = call_lens.modify(retrieve_callmeta)(call_logs)
+ call_objs = call_lens.collect()(call_stats)
+
+ def data_source():
+ for call_obj in tqdm(call_objs):
+ saved_wav_path, data_points, sys_id = (
+ call_obj["process"]["wav_path"],
+ call_obj["process"]["data_points"],
+ call_obj["meta"]["SystemID"],
+ )
+ for dp in gen_data_values(saved_wav_path, data_points):
+ yield dp
+
+ asr_data_writer(call_asr_data, "call_alphanum", data_source())
+
+ # @leader_app.command()
+ def show_leaderboard():
+ def compute_user_stats(call_stat):
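+            # get_monoid() folds all focused values together (here: summing counts and durations)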
+ n_samples = (
+ lens["calls"].Each()["process"]["num_samples"].get_monoid()(call_stat)
+ )
+ n_duration = lens["calls"].Each()["duration"].get_monoid()(call_stat)
+ rel_dur = relativedelta(
+ seconds=int(n_duration.total_seconds()),
+ microseconds=n_duration.microseconds,
+ )
+ return {
+ "num_samples": n_samples,
+ "duration": n_duration.total_seconds(),
+ "samples_rate": n_samples / n_duration.total_seconds(),
+ "duration_str": f"{rel_dur.minutes} mins {rel_dur.seconds} secs",
+ "name": call_stat["name"],
+ }
+
+ call_lens = lens["users"].Each()["calls"].Each()
+ call_stats = call_lens.modify(retrieve_callmeta)(call_logs)
+ user_stats = lens["users"].Each().modify(compute_user_stats)(call_stats)
+ leader_df = (
+ pd.DataFrame(user_stats["users"])
+ .sort_values(by=["duration"], ascending=False)
+ .reset_index(drop=True)
+ )
+ leader_df["rank"] = leader_df.index + 1
+ leader_board = leader_df.rename(
+ columns={
+ "rank": "Rank",
+ "num_samples": "Codes",
+ "name": "Name",
+ "samples_rate": "SpeechRate",
+ "duration_str": "Duration",
+ }
+ )[["Rank", "Name", "Codes", "Duration"]]
+ print(
+ """Today's ASR Speller Dataset Leaderboard:
+----------------------------------------"""
+ )
+ print(leader_board.to_string(index=False))
+
+ if leaderboard:
+ show_leaderboard()
+ if plot_calls:
+ plot_calls_data()
+ if extract_data:
+ extract_data_points()
+
+
+def main():
+ app()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/jasper/data_utils/generator.py b/jasper/data_utils/generator.py
index c49d460..ce26ec1 100644
--- a/jasper/data_utils/generator.py
+++ b/jasper/data_utils/generator.py
@@ -4,7 +4,7 @@
import argparse
import logging
from pathlib import Path
-from .utils import random_pnr_generator, manifest_str
+from .utils import random_pnr_generator, asr_data_writer
from .tts.googletts import GoogleTTS
from tqdm import tqdm
import random
@@ -15,27 +15,22 @@ logging.basicConfig(
logger = logging.getLogger(__name__)
-def generate_asr_data(output_dir, count):
+def pnr_tts_streamer(count):
google_voices = GoogleTTS.voice_list()
gtts = GoogleTTS()
- wav_dir = output_dir / Path("pnr_data")
- wav_dir.mkdir(parents=True, exist_ok=True)
- asr_manifest = output_dir / Path("pnr_data").with_suffix(".json")
- with asr_manifest.open("w") as mf:
- for pnr_code in tqdm(random_pnr_generator(count)):
- tts_code = (
- f'{pnr_code}'
- )
- param = random.choice(google_voices)
- param["sample_rate"] = 24000
- param["num_channels"] = 1
- wav_data = gtts.text_to_speech(text=tts_code, params=param)
- audio_dur = len(wav_data[44:]) / (2 * 24000)
- pnr_af = wav_dir / Path(pnr_code).with_suffix(".wav")
- pnr_af.write_bytes(wav_data)
- rel_pnr_path = pnr_af.relative_to(output_dir)
- manifest = manifest_str(str(rel_pnr_path), audio_dur, pnr_code)
- mf.write(manifest)
+ for pnr_code in tqdm(random_pnr_generator(count)):
+ tts_code = f'{pnr_code}'
+ param = random.choice(google_voices)
+ param["sample_rate"] = 24000
+ param["num_channels"] = 1
+ wav_data = gtts.text_to_speech(text=tts_code, params=param)
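+        # duration from raw PCM size: skip the 44-byte WAV header, 2 bytes/sample at 24 kHz mono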
+ audio_dur = len(wav_data[44:]) / (2 * 24000)
+ yield pnr_code, audio_dur, wav_data
+
+
+def generate_asr_data_fromtts(output_dir, dataset_name, count):
+ asr_data_writer(output_dir, dataset_name, pnr_tts_streamer(count))
def arg_parser():
@@ -52,13 +46,16 @@ def arg_parser():
parser.add_argument(
"--count", type=int, default=3, help="number of datapoints to generate"
)
+ parser.add_argument(
+ "--dataset_name", type=str, default="pnr_data", help="name of the dataset"
+ )
return parser
def main():
parser = arg_parser()
args = parser.parse_args()
- generate_asr_data(**vars(args))
+ generate_asr_data_fromtts(**vars(args))
if __name__ == "__main__":
diff --git a/jasper/data_utils/process.py b/jasper/data_utils/process.py
index 44e4237..7e38523 100644
--- a/jasper/data_utils/process.py
+++ b/jasper/data_utils/process.py
@@ -1,9 +1,13 @@
import json
from pathlib import Path
from sklearn.model_selection import train_test_split
-from num2words import num2words
+from .utils import alnum_to_asr_tokens
+import typer
+
+app = typer.Typer()
+@app.command()
def separate_space_convert_digit_setpath():
with Path("/home/malar/work/asr-data-utils/asr_data/pnr_data.json").open("r") as pf:
pnr_jsonl = pf.readlines()
@@ -12,9 +16,7 @@ def separate_space_convert_digit_setpath():
new_pnr_data = []
for i in pnr_data:
- letters = " ".join(list(i["text"]))
- num_tokens = [num2words(c) if "0" <= c <= "9" else c for c in letters]
- i["text"] = ("".join(num_tokens)).lower()
+ i["text"] = alnum_to_asr_tokens(i["text"])
i["audio_filepath"] = i["audio_filepath"].replace(
"pnr_data/", "/dataset/asr_data/pnr_data/wav/"
)
@@ -27,24 +29,40 @@
pf.write(new_pnr_data)
-separate_space_convert_digit_setpath()
-
-
-def split_data():
- with Path("/dataset/asr_data/pnr_data/pnr_data.json").open("r") as pf:
+@app.command()
+def split_data(manifest_path: Path = Path("/dataset/asr_data/pnr_data/pnr_data.json")):
+ with manifest_path.open("r") as pf:
pnr_jsonl = pf.readlines()
train_pnr, test_pnr = train_test_split(pnr_jsonl, test_size=0.1)
- with Path("/dataset/asr_data/pnr_data/train_manifest.json").open("w") as pf:
+ with (manifest_path.parent / Path("train_manifest.json")).open("w") as pf:
pnr_data = "".join(train_pnr)
pf.write(pnr_data)
- with Path("/dataset/asr_data/pnr_data/test_manifest.json").open("w") as pf:
+ with (manifest_path.parent / Path("test_manifest.json")).open("w") as pf:
pnr_data = "".join(test_pnr)
pf.write(pnr_data)
-split_data()
+@app.command()
+def fix_path(
+ dataset_path: Path = Path("/dataset/asr_data/call_alphanum"),
+):
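+    # rewrite the manifest's relative audio paths to absolute paths under dataset_path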
+    manifest_path = dataset_path / Path("manifest.json")
+ with manifest_path.open("r") as pf:
+ pnr_jsonl = pf.readlines()
+ pnr_data = [json.loads(i) for i in pnr_jsonl]
+ new_pnr_data = []
+ for i in pnr_data:
+ i["audio_filepath"] = str(dataset_path / Path(i["audio_filepath"]))
+ new_pnr_data.append(i)
+ new_pnr_jsonl = [json.dumps(i) for i in new_pnr_data]
+    real_manifest_path = dataset_path / Path("real_manifest.json")
+ with real_manifest_path.open("w") as pf:
+ new_pnr_data = "\n".join(new_pnr_jsonl) # + "\n"
+ pf.write(new_pnr_data)
+@app.command()
def augment_an4():
an4_train = Path("/dataset/asr_data/an4/train_manifest.json").read_bytes()
an4_test = Path("/dataset/asr_data/an4/test_manifest.json").read_bytes()
@@ -57,10 +74,11 @@ def augment_an4():
pf.write(an4_test + pnr_test)
-augment_an4()
+# augment_an4()
-def validate_data(data_file):
+@app.command()
+def validate_data(data_file: Path = Path("/dataset/asr_data/call_alphanum/train_manifest.json")):
with Path(data_file).open("r") as pf:
pnr_jsonl = pf.readlines()
for (i, s) in enumerate(pnr_jsonl):
@@ -70,10 +88,13 @@ def validate_data(data_file):
print(f"failed on {i}")
-validate_data("/dataset/asr_data/an4_pnr/test_manifest.json")
-validate_data("/dataset/asr_data/an4_pnr/train_manifest.json")
+def main():
+ app()
+if __name__ == "__main__":
+ main()
+
# def convert_digits(data_file="/dataset/asr_data/an4_pnr/test_manifest.json"):
# with Path(data_file).open("r") as pf:
# pnr_jsonl = pf.readlines()
diff --git a/jasper/data_utils/utils.py b/jasper/data_utils/utils.py
index bee1e97..aca31e8 100644
--- a/jasper/data_utils/utils.py
+++ b/jasper/data_utils/utils.py
@@ -2,6 +2,8 @@ import numpy as np
import wave
import io
import json
+from pathlib import Path
+from num2words import num2words
def manifest_str(path, dur, text):
@@ -38,6 +40,30 @@ def random_pnr_generator(count=10000):
return codes
+def alnum_to_asr_tokens(text):
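+    # "A1B" -> "a one b": space-separate the characters and spell out digits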
+ letters = " ".join(list(text))
+ num_tokens = [num2words(c) if "0" <= c <= "9" else c for c in letters]
+ return ("".join(num_tokens)).lower()
+
+
+def asr_data_writer(output_dir, dataset_name, asr_data_source):
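+    # consumes (code, duration, wav_bytes) triples: writes each wav under
+    # <output_dir>/<dataset_name>/wav/ and appends a line to manifest.json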
+ dataset_dir = output_dir / Path(dataset_name)
+ (dataset_dir / Path("wav")).mkdir(parents=True, exist_ok=True)
+ asr_manifest = dataset_dir / Path("manifest.json")
+ with asr_manifest.open("w") as mf:
+ for pnr_code, audio_dur, wav_data in asr_data_source:
+ pnr_af = dataset_dir / Path("wav") / Path(pnr_code).with_suffix(".wav")
+ pnr_af.write_bytes(wav_data)
+ rel_pnr_path = pnr_af.relative_to(dataset_dir)
+ manifest = manifest_str(
+ str(rel_pnr_path), audio_dur, alnum_to_asr_tokens(pnr_code)
+ )
+ mf.write(manifest)
+
+
def main():
for c in random_pnr_generator():
print(c)
diff --git a/jasper/train.py b/jasper/train.py
index 9861aff..def978f 100644
--- a/jasper/train.py
+++ b/jasper/train.py
@@ -82,8 +82,7 @@ def parse_args():
)
args = parser.parse_args()
-
- if args.max_steps is not None and args.num_epochs is not None:
+ if args.max_steps is None and args.num_epochs is None:
raise ValueError("Either max_steps or num_epochs should be provided.")
return args
@@ -311,7 +310,6 @@ def main():
# build dags
train_loss, callbacks, steps_per_epoch = create_all_dags(args, neural_factory)
-
# train model
neural_factory.train(
tensors_to_optimize=[train_loss],
diff --git a/setup.py b/setup.py
index 3948e09..015e125 100644
--- a/setup.py
+++ b/setup.py
@@ -2,6 +2,8 @@ from setuptools import setup, find_packages
requirements = [
"ruamel.yaml",
+ "torch==1.4.0",
+ "torchvision==0.5.0",
"nemo_toolkit[asr] @ git+https://github.com/NVIDIA/NeMo.git@09e3ba4dfe333f86d6c5c1048e07210924294be9#egg=nemo_toolkit",
]
@@ -14,7 +16,18 @@
"scikit_learn~=0.22.1",
"pandas~=1.0.3",
"boto3~=1.12.35",
+ "ruamel.yaml==0.16.10",
+ "pymongo==3.10.1",
+ "librosa==0.7.2",
+ "matplotlib==3.2.1",
+ "pandas==1.0.3",
+ "tabulate==0.8.7",
+ "typer[all]==0.1.1",
+ "lenses @ git+https://github.com/ingolemo/python-lenses.git@b2a2a9aa5b61540992d70b2cf36008d0121e8948#egg=lenses",
],
+ # "train": [
+ # "torchaudio==0.5.0",
+ # "torch-stft==0.1.4",
+ # ]
}
packages = find_packages()
@@ -35,6 +49,8 @@ setup(
"jasper_asr_rpyc_server = jasper.server:main",
"jasper_asr_trainer = jasper.train:main",
"jasper_asr_data_generate = jasper.data_utils.generator:main",
+ "jasper_asr_data_recycle = jasper.data_utils.call_recycler:main",
+ "jasper_asr_data_preprocess = jasper.data_utils.process:main",
]
},
zip_safe=False,