implement call audio data recycler for asr

Malar Kannan 2020-04-27 10:53:14 +05:30
parent 2c15b00da3
commit 61048f855e
7 changed files with 465 additions and 40 deletions

.gitignore (vendored)

@@ -1,3 +1,7 @@
+data/
+.env*
+*.yaml
+
 # Created by https://www.gitignore.io/api/python
 # Edit at https://www.gitignore.io/?templates=python
@@ -108,3 +112,36 @@ dmypy.json
 .pyre/

 # End of https://www.gitignore.io/api/python
+
+# Created by https://www.gitignore.io/api/macos
+# Edit at https://www.gitignore.io/?templates=macos
+
+### macOS ###
+# General
+.DS_Store
+.AppleDouble
+.LSOverride
+
+# Icon must end with two \r
+Icon
+
+# Thumbnails
+._*
+
+# Files that might appear in the root of a volume
+.DocumentRevisions-V100
+.fseventsd
+.Spotlight-V100
+.TemporaryItems
+.Trashes
+.VolumeIcon.icns
+.com.apple.timemachine.donotpresent
+
+# Directories potentially created on remote AFP share
+.AppleDB
+.AppleDesktop
+Network Trash Folder
+Temporary Items
+.apdisk
+
+# End of https://www.gitignore.io/api/macos

jasper/data_utils/call_recycler.py

@@ -0,0 +1,333 @@
# import argparse
# import logging
import typer
from pathlib import Path

app = typer.Typer()
# leader_app = typer.Typer()
# app.add_typer(leader_app, name="leaderboard")
# plot_app = typer.Typer()
# app.add_typer(plot_app, name="plot")


@app.command()
def analyze(
    leaderboard: bool = False,
    plot_calls: bool = False,
    extract_data: bool = False,
    call_logs_file: Path = Path("./call_logs.yaml"),
    output_dir: Path = Path("./data"),
):
    # heavy imports are deferred so the CLI itself loads quickly
    from urllib.parse import urlsplit
    from functools import reduce
    from pymongo import MongoClient
    import boto3
    from io import BytesIO
    import json
    from ruamel.yaml import YAML
    import re
    from google.protobuf.timestamp_pb2 import Timestamp
    from datetime import timedelta
    # from concurrent.futures import ThreadPoolExecutor
    from dateutil.relativedelta import relativedelta
    import librosa
    import librosa.display
    from lenses import lens
    from pprint import pprint
    import pandas as pd
    import matplotlib.pyplot as plt
    import matplotlib
    from tqdm import tqdm
    from .utils import asr_data_writer
    from pydub import AudioSegment

    matplotlib.rcParams["agg.path.chunksize"] = 10000
    matplotlib.use("agg")
    # logging.basicConfig(
    #     level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    # )
    # logger = logging.getLogger(__name__)
    yaml = YAML()
    s3 = boto3.client("s3")
    mongo_collection = MongoClient("mongodb://localhost:27017/").test.calls
    call_media_dir: Path = output_dir / Path("call_wavs")
    call_media_dir.mkdir(exist_ok=True, parents=True)
    call_meta_dir: Path = output_dir / Path("call_metas")
    call_meta_dir.mkdir(exist_ok=True, parents=True)
    call_plot_dir: Path = output_dir / Path("plots")
    call_plot_dir.mkdir(exist_ok=True, parents=True)
    call_asr_data: Path = output_dir / Path("asr_data")
    call_asr_data.mkdir(exist_ok=True, parents=True)
    call_logs = yaml.load(call_logs_file.read_text())

    def get_call_meta(call_obj):
        # fetch the call's event-log JSON from S3, caching it on disk
        s3_event_url_p = urlsplit(call_obj["DataURI"])
        saved_meta_path = call_meta_dir / Path(Path(s3_event_url_p.path).name)
        if not saved_meta_path.exists():
            print(f"downloading : {saved_meta_path}")
            s3.download_file(
                s3_event_url_p.netloc, s3_event_url_p.path[1:], str(saved_meta_path)
            )
        call_metas = json.load(saved_meta_path.open())
        return call_metas

    def gen_ev_fev_timedelta(fev):
        # build a closure that measures an event's offset from the first audio event
        fev_p = Timestamp()
        fev_p.FromJsonString(fev["CreatedTS"])
        fev_dt = fev_p.ToDatetime()
        td_0 = timedelta()

        def get_timedelta(ev):
            ev_p = Timestamp()
            ev_p.FromJsonString(value=ev["CreatedTS"])
            ev_dt = ev_p.ToDatetime()
            delta = ev_dt - fev_dt
            return delta if delta > td_0 else td_0

        return get_timedelta

    def process_call(call_obj):
        call_meta = get_call_meta(call_obj)
        call_events = call_meta["Events"]

        def is_writer_event(ev):
            return ev["Author"] == "AUDIO_WRITER"

        writer_events = list(filter(is_writer_event, call_events))
        s3_wav_url = re.search(r"saved to: (.*)", writer_events[0]["Msg"]).group(1)
        s3_wav_url_p = urlsplit(s3_wav_url)

        def is_first_audio_ev(state, ev):
            if state[0]:
                return state
            else:
                return (ev["Author"] == "GATEWAY" and ev["Type"] == "AUDIO", ev)

        (_, first_audio_ev) = reduce(is_first_audio_ev, call_events, (False, {}))
        get_ev_fev_timedelta = gen_ev_fev_timedelta(first_audio_ev)

        def is_utter_event(ev):
            return (
                (ev["Author"] == "CONV" or ev["Author"] == "ASR")
                and (ev["Type"] != "DEBUG")
                and ev["Type"] != "ASR_RESULT"
            )

        # utterance events arrive in triples; drop any trailing partial triple
        uevs = list(filter(is_utter_event, call_events))
        ev_count = len(uevs)
        utter_events = uevs[: ev_count - ev_count % 3]
        saved_wav_path = call_media_dir / Path(Path(s3_wav_url_p.path).name)
        if not saved_wav_path.exists():
            print(f"downloading : {saved_wav_path}")
            s3.download_file(
                s3_wav_url_p.netloc, s3_wav_url_p.path[1:], str(saved_wav_path)
            )
        # %config InlineBackend.figure_format = "retina"

        def chunk_n(evs, n):
            return [evs[i * n : (i + 1) * n] for i in range((len(evs) + n - 1) // n)]

        def get_data_points(utter_events):
            data_points = []
            for evs in chunk_n(utter_events, 3):
                assert evs[0]["Type"] == "CONV_RESULT"
                assert evs[1]["Type"] == "STARTED_SPEAKING"
                assert evs[2]["Type"] == "STOPPED_SPEAKING"
                # pad 1.5s before the speech-start event to avoid clipping the utterance
                start_time = get_ev_fev_timedelta(evs[1]).total_seconds() - 1.5
                end_time = get_ev_fev_timedelta(evs[2]).total_seconds()
                code = evs[0]["Msg"]
                data_points.append(
                    {"start_time": start_time, "end_time": end_time, "code": code}
                )
            return data_points

        def plot_events(y, sr, utter_events, file_path):
            plt.figure(figsize=(16, 12))
            librosa.display.waveplot(y=y, sr=sr)
            # plt.tight_layout()
            for evs in chunk_n(utter_events, 3):
                assert evs[0]["Type"] == "CONV_RESULT"
                assert evs[1]["Type"] == "STARTED_SPEAKING"
                assert evs[2]["Type"] == "STOPPED_SPEAKING"
                for ev in evs:
                    # print(ev["Type"])
                    ev_type = ev["Type"]
                    pos = get_ev_fev_timedelta(ev).total_seconds()
                    if ev_type == "STARTED_SPEAKING":
                        pos = pos - 1.5
                    plt.axvline(pos)  # , label="pyplot vertical line")
                    plt.text(
                        pos,
                        0.2,
                        f"event:{ev_type}:{ev['Msg']}",
                        rotation=90,
                        horizontalalignment="left"
                        if ev_type != "STOPPED_SPEAKING"
                        else "right",
                        verticalalignment="center",
                    )
            plt.title("Monophonic")
            plt.savefig(file_path, format="png")

        data_points = get_data_points(utter_events)
        return {
            "wav_path": saved_wav_path,
            "num_samples": len(utter_events) // 3,
            "meta": call_obj,
            "data_points": data_points,
        }

    def retrieve_callmeta(uri):
        cid = Path(urlsplit(uri).path).stem
        meta = mongo_collection.find_one({"SystemID": cid})
        duration = meta["EndTS"] - meta["StartTS"]
        process_meta = process_call(meta)
        return {"url": uri, "meta": meta, "duration": duration, "process": process_meta}

    # @plot_app.command()
    def plot_calls_data():
        def plot_data_points(y, sr, data_points, file_path):
            plt.figure(figsize=(16, 12))
            librosa.display.waveplot(y=y, sr=sr)
            for dp in data_points:
                start, end, code = dp["start_time"], dp["end_time"], dp["code"]
                plt.axvspan(start, end, color="green", alpha=0.2)
                text_pos = (start + end) / 2
                plt.text(
                    text_pos,
                    0.25,
                    f"{code}",
                    rotation=90,
                    horizontalalignment="center",
                    verticalalignment="center",
                )
            plt.title("Datapoints")
            plt.savefig(file_path, format="png")
            return file_path

        def plot_call(call_obj):
            saved_wav_path, data_points, sys_id = (
                call_obj["process"]["wav_path"],
                call_obj["process"]["data_points"],
                call_obj["meta"]["SystemID"],
            )
            file_path = call_plot_dir / Path(sys_id).with_suffix(".png")
            if not file_path.exists():
                print(f"plotting: {file_path}")
                (y, sr) = librosa.load(saved_wav_path)
                plot_data_points(y, sr, data_points, str(file_path))
            return file_path

        # plot_call(retrieve_callmeta("http://saasdev.agaralabs.com/calls/JOR9V47L03AGUEL"))
        call_lens = lens["users"].Each()["calls"].Each()
        call_stats = call_lens.modify(retrieve_callmeta)(call_logs)
        # call_plot_data = call_lens.collect()(call_stats)
        call_plots = call_lens.modify(plot_call)(call_stats)
        # with ThreadPoolExecutor(max_workers=20) as exe:
        #     print('starting all plot tasks')
        #     responses = [exe.submit(plot_call, w) for w in call_plot_data]
        #     print('submitted all plot tasks')
        #     call_plots = [r.result() for r in responses]
        pprint(call_plots)

    def extract_data_points():
        def gen_data_values(saved_wav_path, data_points):
            # normalize the call recording to 16-bit mono at 24 kHz before slicing
            call_seg = (
                AudioSegment.from_wav(saved_wav_path)
                .set_channels(1)
                .set_sample_width(2)
                .set_frame_rate(24000)
            )
            for dp_id, dp in enumerate(data_points):
                start, end, code = dp["start_time"], dp["end_time"], dp["code"]
                # pydub slices by milliseconds
                code_seg = call_seg[start * 1000 : end * 1000]
                code_fb = BytesIO()
                code_seg.export(code_fb, format="wav")
                code_wav = code_fb.getvalue()
                # import pdb; pdb.set_trace()
                yield code, code_seg.duration_seconds, code_wav

        call_lens = lens["users"].Each()["calls"].Each()
        call_stats = call_lens.modify(retrieve_callmeta)(call_logs)
        call_objs = call_lens.collect()(call_stats)

        def data_source():
            for call_obj in tqdm(call_objs):
                saved_wav_path, data_points, sys_id = (
                    call_obj["process"]["wav_path"],
                    call_obj["process"]["data_points"],
                    call_obj["meta"]["SystemID"],
                )
                for dp in gen_data_values(saved_wav_path, data_points):
                    yield dp

        asr_data_writer(call_asr_data, "call_alphanum", data_source())

    # @leader_app.command()
    def show_leaderboard():
        def compute_user_stats(call_stat):
            n_samples = (
                lens["calls"].Each()["process"]["num_samples"].get_monoid()(call_stat)
            )
            n_duration = lens["calls"].Each()["duration"].get_monoid()(call_stat)
            rel_dur = relativedelta(
                seconds=int(n_duration.total_seconds()),
                microseconds=n_duration.microseconds,
            )
            return {
                "num_samples": n_samples,
                "duration": n_duration.total_seconds(),
                "samples_rate": n_samples / n_duration.total_seconds(),
                "duration_str": f"{rel_dur.minutes} mins {rel_dur.seconds} secs",
                "name": call_stat["name"],
            }

        call_lens = lens["users"].Each()["calls"].Each()
        call_stats = call_lens.modify(retrieve_callmeta)(call_logs)
        user_stats = lens["users"].Each().modify(compute_user_stats)(call_stats)
        leader_df = (
            pd.DataFrame(user_stats["users"])
            .sort_values(by=["duration"], ascending=False)
            .reset_index(drop=True)
        )
        leader_df["rank"] = leader_df.index + 1
        leader_board = leader_df.rename(
            columns={
                "rank": "Rank",
                "num_samples": "Codes",
                "name": "Name",
                "samples_rate": "SpeechRate",
                "duration_str": "Duration",
            }
        )[["Rank", "Name", "Codes", "Duration"]]
        print(
            """Today's ASR Speller Dataset Leaderboard:
----------------------------------------"""
        )
        print(leader_board.to_string(index=False))

    if leaderboard:
        show_leaderboard()
    if plot_calls:
        plot_calls_data()
    if extract_data:
        extract_data_points()


def main():
    app()


if __name__ == "__main__":
    main()

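The call_logs.yaml that analyze() consumes is not part of this commit. Below is a minimal sketch of the shape implied by the lens traversal lens["users"].Each()["calls"].Each() and by compute_user_stats reading call_stat["name"] (the user name is hypothetical; the URL is the one from the commented-out plot_call example):

# call_logs.yaml after yaml.load(), expressed as the equivalent Python structure;
# shape inferred from the code above, not taken from the commit
call_logs = {
    "users": [
        {
            "name": "example-user",  # hypothetical
            "calls": ["http://saasdev.agaralabs.com/calls/JOR9V47L03AGUEL"],
        }
    ]
}
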
jasper/data_utils/generator.py

@@ -4,7 +4,7 @@
 import argparse
 import logging
 from pathlib import Path
-from .utils import random_pnr_generator, manifest_str
+from .utils import random_pnr_generator, asr_data_writer
 from .tts.googletts import GoogleTTS
 from tqdm import tqdm
 import random
@@ -15,27 +15,21 @@ logging.basicConfig(
 logger = logging.getLogger(__name__)


-def generate_asr_data(output_dir, count):
+def pnr_tts_streamer(count):
     google_voices = GoogleTTS.voice_list()
     gtts = GoogleTTS()
-    wav_dir = output_dir / Path("pnr_data")
-    wav_dir.mkdir(parents=True, exist_ok=True)
-    asr_manifest = output_dir / Path("pnr_data").with_suffix(".json")
-    with asr_manifest.open("w") as mf:
-        for pnr_code in tqdm(random_pnr_generator(count)):
-            tts_code = (
-                f'<speak><say-as interpret-as="verbatim">{pnr_code}</say-as></speak>'
-            )
-            param = random.choice(google_voices)
-            param["sample_rate"] = 24000
-            param["num_channels"] = 1
-            wav_data = gtts.text_to_speech(text=tts_code, params=param)
-            audio_dur = len(wav_data[44:]) / (2 * 24000)
-            pnr_af = wav_dir / Path(pnr_code).with_suffix(".wav")
-            pnr_af.write_bytes(wav_data)
-            rel_pnr_path = pnr_af.relative_to(output_dir)
-            manifest = manifest_str(str(rel_pnr_path), audio_dur, pnr_code)
-            mf.write(manifest)
+    for pnr_code in tqdm(random_pnr_generator(count)):
+        tts_code = f'<speak><say-as interpret-as="verbatim">{pnr_code}</say-as></speak>'
+        param = random.choice(google_voices)
+        param["sample_rate"] = 24000
+        param["num_channels"] = 1
+        wav_data = gtts.text_to_speech(text=tts_code, params=param)
+        audio_dur = len(wav_data[44:]) / (2 * 24000)
+        yield pnr_code, audio_dur, wav_data
+
+
+def generate_asr_data_fromtts(output_dir, dataset_name, count):
+    asr_data_writer(output_dir, dataset_name, pnr_tts_streamer(count))


 def arg_parser():
@@ -52,13 +46,16 @@ def arg_parser():
     parser.add_argument(
         "--count", type=int, default=3, help="number of datapoints to generate"
     )
+    parser.add_argument(
+        "--dataset_name", type=str, default="pnr_data", help="name of the dataset"
+    )
     return parser


 def main():
     parser = arg_parser()
     args = parser.parse_args()
-    generate_asr_data(**vars(args))
+    generate_asr_data_fromtts(**vars(args))


 if __name__ == "__main__":

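A note on the duration arithmetic the streamer keeps from the old writer: the TTS request pins 16-bit mono PCM at 24 kHz, so after skipping a standard 44-byte RIFF/WAV header the payload size converts directly to seconds:

# payload bytes / (2 bytes per sample * 24000 samples per second)
audio_dur = len(wav_data[44:]) / (2 * 24000)
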
jasper/data_utils/process.py

@@ -1,9 +1,13 @@
 import json
 from pathlib import Path
 from sklearn.model_selection import train_test_split
-from num2words import num2words
+from .utils import alnum_to_asr_tokens
+import typer
+
+app = typer.Typer()


+@app.command()
 def separate_space_convert_digit_setpath():
     with Path("/home/malar/work/asr-data-utils/asr_data/pnr_data.json").open("r") as pf:
         pnr_jsonl = pf.readlines()
@@ -12,9 +16,7 @@ def separate_space_convert_digit_setpath():
     new_pnr_data = []
     for i in pnr_data:
-        letters = " ".join(list(i["text"]))
-        num_tokens = [num2words(c) if "0" <= c <= "9" else c for c in letters]
-        i["text"] = ("".join(num_tokens)).lower()
+        i["text"] = alnum_to_asr_tokens(i["text"])
         i["audio_filepath"] = i["audio_filepath"].replace(
             "pnr_data/", "/dataset/asr_data/pnr_data/wav/"
         )
@@ -27,24 +29,39 @@ def separate_space_convert_digit_setpath():
         pf.write(new_pnr_data)


-separate_space_convert_digit_setpath()
-
-
-def split_data():
-    with Path("/dataset/asr_data/pnr_data/pnr_data.json").open("r") as pf:
+@app.command()
+def split_data(manifest_path: Path = Path("/dataset/asr_data/pnr_data/pnr_data.json")):
+    with manifest_path.open("r") as pf:
         pnr_jsonl = pf.readlines()
     train_pnr, test_pnr = train_test_split(pnr_jsonl, test_size=0.1)
-    with Path("/dataset/asr_data/pnr_data/train_manifest.json").open("w") as pf:
+    with (manifest_path.parent / Path("train_manifest.json")).open("w") as pf:
         pnr_data = "".join(train_pnr)
         pf.write(pnr_data)
-    with Path("/dataset/asr_data/pnr_data/test_manifest.json").open("w") as pf:
+    with (manifest_path.parent / Path("test_manifest.json")).open("w") as pf:
         pnr_data = "".join(test_pnr)
         pf.write(pnr_data)


-split_data()
+@app.command()
+def fix_path(
+    dataset_path: Path = Path("/dataset/asr_data/call_alphanum"),
+):
+    manifest_path = dataset_path / Path('manifest.json')
+    with manifest_path.open("r") as pf:
+        pnr_jsonl = pf.readlines()
+    pnr_data = [json.loads(i) for i in pnr_jsonl]
+    new_pnr_data = []
+    for i in pnr_data:
+        i["audio_filepath"] = str(dataset_path / Path(i["audio_filepath"]))
+        new_pnr_data.append(i)
+    new_pnr_jsonl = [json.dumps(i) for i in new_pnr_data]
+    real_manifest_path = dataset_path / Path('real_manifest.json')
+    with real_manifest_path.open("w") as pf:
+        new_pnr_data = "\n".join(new_pnr_jsonl)  # + "\n"
+        pf.write(new_pnr_data)


+@app.command()
 def augment_an4():
     an4_train = Path("/dataset/asr_data/an4/train_manifest.json").read_bytes()
     an4_test = Path("/dataset/asr_data/an4/test_manifest.json").read_bytes()
@@ -57,10 +74,11 @@ def augment_an4():
     pf.write(an4_test + pnr_test)


-augment_an4()
+# augment_an4()


-def validate_data(data_file):
+@app.command()
+def validate_data(data_file: Path = Path("/dataset/asr_data/call_alphanum/train_manifest.json")):
     with Path(data_file).open("r") as pf:
         pnr_jsonl = pf.readlines()
     for (i, s) in enumerate(pnr_jsonl):
@@ -70,10 +88,13 @@ def validate_data(data_file):
         print(f"failed on {i}")


-validate_data("/dataset/asr_data/an4_pnr/test_manifest.json")
-validate_data("/dataset/asr_data/an4_pnr/train_manifest.json")
+def main():
+    app()
+
+
+if __name__ == "__main__":
+    main()


 # def convert_digits(data_file="/dataset/asr_data/an4_pnr/test_manifest.json"):
 #     with Path(data_file).open("r") as pf:
 #         pnr_jsonl = pf.readlines()

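For orientation, the manifests these commands rewrite are JSON-lines files. Judging from the keys touched above (audio_filepath, text) and the duration passed to manifest_str, one line looks roughly like this (values invented for illustration):

{"audio_filepath": "/dataset/asr_data/call_alphanum/wav/A3B.wav", "duration": 1.2, "text": "a three b"}
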
jasper/data_utils/utils.py

@@ -2,6 +2,8 @@ import numpy as np
 import wave
 import io
 import json
+from pathlib import Path
+from num2words import num2words


 def manifest_str(path, dur, text):
@@ -38,6 +40,27 @@ def random_pnr_generator(count=10000):
     return codes


+def alnum_to_asr_tokens(text):
+    letters = " ".join(list(text))
+    num_tokens = [num2words(c) if "0" <= c <= "9" else c for c in letters]
+    return ("".join(num_tokens)).lower()
+
+
+def asr_data_writer(output_dir, dataset_name, asr_data_source):
+    dataset_dir = output_dir / Path(dataset_name)
+    (dataset_dir / Path("wav")).mkdir(parents=True, exist_ok=True)
+    asr_manifest = dataset_dir / Path("manifest.json")
+    with asr_manifest.open("w") as mf:
+        for pnr_code, audio_dur, wav_data in asr_data_source:
+            pnr_af = dataset_dir / Path("wav") / Path(pnr_code).with_suffix(".wav")
+            pnr_af.write_bytes(wav_data)
+            rel_pnr_path = pnr_af.relative_to(dataset_dir)
+            manifest = manifest_str(
+                str(rel_pnr_path), audio_dur, alnum_to_asr_tokens(pnr_code)
+            )
+            mf.write(manifest)


 def main():
     for c in random_pnr_generator():
         print(c)

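A quick usage sketch of the new helper (the code "A3B" is a made-up input; the import path assumes the package layout implied by setup.py):

from jasper.data_utils.utils import alnum_to_asr_tokens

# characters are space-separated, digits spelled out via num2words, then lowercased
assert alnum_to_asr_tokens("A3B") == "a three b"
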
jasper/train.py

@@ -82,8 +82,7 @@ def parse_args():
     )
     args = parser.parse_args()
-
-    if args.max_steps is not None and args.num_epochs is not None:
+    if args.max_steps is None and args.num_epochs is None:
         raise ValueError("Either max_steps or num_epochs should be provided.")
     return args
@@ -311,7 +310,6 @@ def main():
     # build dags
     train_loss, callbacks, steps_per_epoch = create_all_dags(args, neural_factory)
-
     # train model
     neural_factory.train(
         tensors_to_optimize=[train_loss],

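Note that the parse_args change inverts the guard so it agrees with its own error message: training now fails fast when neither max_steps nor num_epochs is supplied, instead of erroring when both are.
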
setup.py

@@ -2,6 +2,8 @@ from setuptools import setup, find_packages

 requirements = [
     "ruamel.yaml",
+    "torch==1.4.0",
+    "torchvision==0.5.0",
     "nemo_toolkit[asr] @ git+https://github.com/NVIDIA/NeMo.git@09e3ba4dfe333f86d6c5c1048e07210924294be9#egg=nemo_toolkit",
 ]
@@ -14,7 +16,19 @@ extra_requirements = {
         "scikit_learn~=0.22.1",
         "pandas~=1.0.3",
         "boto3~=1.12.35",
+        "ruamel.yaml==0.16.10",
+        "pymongo==3.10.1",
+        "librosa==0.7.2",
+        "matplotlib==3.2.1",
+        "pandas==1.0.3",
+        "tabulate==0.8.7",
+        "typer[all]==0.1.1",
+        "lenses @ git+https://github.com/ingolemo/python-lenses.git@b2a2a9aa5b61540992d70b2cf36008d0121e8948#egg=lenses",
     ],
+    # "train": [
+    #     "torchaudio==0.5.0",
+    #     "torch-stft==0.1.4",
+    # ]
 }
@@ -35,6 +49,8 @@ setup(
             "jasper_asr_rpyc_server = jasper.server:main",
             "jasper_asr_trainer = jasper.train:main",
             "jasper_asr_data_generate = jasper.data_utils.generator:main",
+            "jasper_asr_data_recycle = jasper.data_utils.call_recycler:main",
+            "jasper_asr_data_preprocess = jasper.data_utils.process:main",
         ]
     },
zip_safe=False, zip_safe=False,
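
After reinstalling the package, the two new console scripts become available. A hypothetical session (the kebab-case command and flag spellings assume typer's default name conversion):

jasper_asr_data_recycle analyze --extract-data --call-logs-file ./call_logs.yaml --output-dir ./data
jasper_asr_data_preprocess fix-path --dataset-path /dataset/asr_data/call_alphanum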