implement call audio data recycler for asr
parent
2c15b00da3
commit
61048f855e
|
|
@ -1,3 +1,7 @@
|
|||
data/
|
||||
.env*
|
||||
*.yaml
|
||||
|
||||
|
||||
# Created by https://www.gitignore.io/api/python
|
||||
# Edit at https://www.gitignore.io/?templates=python
|
||||
|
|
@ -108,3 +112,36 @@ dmypy.json
|
|||
.pyre/
|
||||
|
||||
# End of https://www.gitignore.io/api/python
|
||||
|
||||
# Created by https://www.gitignore.io/api/macos
|
||||
# Edit at https://www.gitignore.io/?templates=macos
|
||||
|
||||
### macOS ###
|
||||
# General
|
||||
.DS_Store
|
||||
.AppleDouble
|
||||
.LSOverride
|
||||
|
||||
# Icon must end with two \r
|
||||
Icon
|
||||
|
||||
# Thumbnails
|
||||
._*
|
||||
|
||||
# Files that might appear in the root of a volume
|
||||
.DocumentRevisions-V100
|
||||
.fseventsd
|
||||
.Spotlight-V100
|
||||
.TemporaryItems
|
||||
.Trashes
|
||||
.VolumeIcon.icns
|
||||
.com.apple.timemachine.donotpresent
|
||||
|
||||
# Directories potentially created on remote AFP share
|
||||
.AppleDB
|
||||
.AppleDesktop
|
||||
Network Trash Folder
|
||||
Temporary Items
|
||||
.apdisk
|
||||
|
||||
# End of https://www.gitignore.io/api/macos
|
||||
|
|
|
|||
|
|
@ -0,0 +1,333 @@
|
|||
# import argparse
|
||||
|
||||
# import logging
|
||||
import typer
|
||||
from pathlib import Path
|
||||
|
||||
app = typer.Typer()
|
||||
# leader_app = typer.Typer()
|
||||
# app.add_typer(leader_app, name="leaderboard")
|
||||
# plot_app = typer.Typer()
|
||||
# app.add_typer(plot_app, name="plot")
|
||||
|
||||
|
||||
@app.command()
|
||||
def analyze(
|
||||
leaderboard: bool = False,
|
||||
plot_calls: bool = False,
|
||||
extract_data: bool = False,
|
||||
call_logs_file: Path = Path("./call_logs.yaml"),
|
||||
output_dir: Path = Path("./data"),
|
||||
):
|
||||
call_logs_file = Path("./call_logs.yaml")
|
||||
output_dir = Path("./data")
|
||||
|
||||
from urllib.parse import urlsplit
|
||||
from functools import reduce
|
||||
from pymongo import MongoClient
|
||||
import boto3
|
||||
|
||||
from io import BytesIO
|
||||
import json
|
||||
from ruamel.yaml import YAML
|
||||
import re
|
||||
from google.protobuf.timestamp_pb2 import Timestamp
|
||||
from datetime import timedelta
|
||||
|
||||
# from concurrent.futures import ThreadPoolExecutor
|
||||
from dateutil.relativedelta import relativedelta
|
||||
import librosa
|
||||
import librosa.display
|
||||
from lenses import lens
|
||||
from pprint import pprint
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib
|
||||
from tqdm import tqdm
|
||||
from .utils import asr_data_writer
|
||||
from pydub import AudioSegment
|
||||
|
||||
matplotlib.rcParams["agg.path.chunksize"] = 10000
|
||||
|
||||
matplotlib.use("agg")
|
||||
|
||||
# logging.basicConfig(
|
||||
# level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
# )
|
||||
# logger = logging.getLogger(__name__)
|
||||
yaml = YAML()
|
||||
s3 = boto3.client("s3")
|
||||
mongo_collection = MongoClient("mongodb://localhost:27017/").test.calls
|
||||
call_media_dir: Path = output_dir / Path("call_wavs")
|
||||
call_media_dir.mkdir(exist_ok=True, parents=True)
|
||||
call_meta_dir: Path = output_dir / Path("call_metas")
|
||||
call_meta_dir.mkdir(exist_ok=True, parents=True)
|
||||
call_plot_dir: Path = output_dir / Path("plots")
|
||||
call_plot_dir.mkdir(exist_ok=True, parents=True)
|
||||
call_asr_data: Path = output_dir / Path("asr_data")
|
||||
call_asr_data.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
call_logs = yaml.load(call_logs_file.read_text())
|
||||
|
||||
def get_call_meta(call_obj):
|
||||
s3_event_url_p = urlsplit(call_obj["DataURI"])
|
||||
saved_meta_path = call_meta_dir / Path(Path(s3_event_url_p.path).name)
|
||||
if not saved_meta_path.exists():
|
||||
print(f"downloading : {saved_meta_path}")
|
||||
s3.download_file(
|
||||
s3_event_url_p.netloc, s3_event_url_p.path[1:], str(saved_meta_path)
|
||||
)
|
||||
call_metas = json.load(saved_meta_path.open())
|
||||
return call_metas
|
||||
|
||||
def gen_ev_fev_timedelta(fev):
|
||||
fev_p = Timestamp()
|
||||
fev_p.FromJsonString(fev["CreatedTS"])
|
||||
fev_dt = fev_p.ToDatetime()
|
||||
td_0 = timedelta()
|
||||
|
||||
def get_timedelta(ev):
|
||||
ev_p = Timestamp()
|
||||
ev_p.FromJsonString(value=ev["CreatedTS"])
|
||||
ev_dt = ev_p.ToDatetime()
|
||||
delta = ev_dt - fev_dt
|
||||
return delta if delta > td_0 else td_0
|
||||
|
||||
return get_timedelta
|
||||
|
||||
def process_call(call_obj):
|
||||
call_meta = get_call_meta(call_obj)
|
||||
call_events = call_meta["Events"]
|
||||
|
||||
def is_writer_event(ev):
|
||||
return ev["Author"] == "AUDIO_WRITER"
|
||||
|
||||
writer_events = list(filter(is_writer_event, call_events))
|
||||
s3_wav_url = re.search(r"saved to: (.*)", writer_events[0]["Msg"]).groups(0)[0]
|
||||
s3_wav_url_p = urlsplit(s3_wav_url)
|
||||
|
||||
def is_first_audio_ev(state, ev):
|
||||
if state[0]:
|
||||
return state
|
||||
else:
|
||||
return (ev["Author"] == "GATEWAY" and ev["Type"] == "AUDIO", ev)
|
||||
|
||||
(_, first_audio_ev) = reduce(is_first_audio_ev, call_events, (False, {}))
|
||||
|
||||
get_ev_fev_timedelta = gen_ev_fev_timedelta(first_audio_ev)
|
||||
|
||||
def is_utter_event(ev):
|
||||
return (
|
||||
(ev["Author"] == "CONV" or ev["Author"] == "ASR")
|
||||
and (ev["Type"] != "DEBUG")
|
||||
and ev["Type"] != "ASR_RESULT"
|
||||
)
|
||||
|
||||
uevs = list(filter(is_utter_event, call_events))
|
||||
ev_count = len(uevs)
|
||||
utter_events = uevs[: ev_count - ev_count % 3]
|
||||
saved_wav_path = call_media_dir / Path(Path(s3_wav_url_p.path).name)
|
||||
if not saved_wav_path.exists():
|
||||
print(f"downloading : {saved_wav_path}")
|
||||
s3.download_file(
|
||||
s3_wav_url_p.netloc, s3_wav_url_p.path[1:], str(saved_wav_path)
|
||||
)
|
||||
|
||||
# %config InlineBackend.figure_format = "retina"
|
||||
def chunk_n(evs, n):
|
||||
return [evs[i * n : (i + 1) * n] for i in range((len(evs) + n - 1) // n)]
|
||||
|
||||
def get_data_points(utter_events):
|
||||
data_points = []
|
||||
for evs in chunk_n(utter_events, 3):
|
||||
assert evs[0]["Type"] == "CONV_RESULT"
|
||||
assert evs[1]["Type"] == "STARTED_SPEAKING"
|
||||
assert evs[2]["Type"] == "STOPPED_SPEAKING"
|
||||
start_time = get_ev_fev_timedelta(evs[1]).total_seconds() - 1.5
|
||||
end_time = get_ev_fev_timedelta(evs[2]).total_seconds()
|
||||
code = evs[0]["Msg"]
|
||||
data_points.append(
|
||||
{"start_time": start_time, "end_time": end_time, "code": code}
|
||||
)
|
||||
return data_points
|
||||
|
||||
def plot_events(y, sr, utter_events, file_path):
|
||||
plt.figure(figsize=(16, 12))
|
||||
librosa.display.waveplot(y=y, sr=sr)
|
||||
# plt.tight_layout()
|
||||
for evs in chunk_n(utter_events, 3):
|
||||
assert evs[0]["Type"] == "CONV_RESULT"
|
||||
assert evs[1]["Type"] == "STARTED_SPEAKING"
|
||||
assert evs[2]["Type"] == "STOPPED_SPEAKING"
|
||||
for ev in evs:
|
||||
# print(ev["Type"])
|
||||
ev_type = ev["Type"]
|
||||
pos = get_ev_fev_timedelta(ev).total_seconds()
|
||||
if ev_type == "STARTED_SPEAKING":
|
||||
pos = pos - 1.5
|
||||
plt.axvline(pos) # , label="pyplot vertical line")
|
||||
plt.text(
|
||||
pos,
|
||||
0.2,
|
||||
f"event:{ev_type}:{ev['Msg']}",
|
||||
rotation=90,
|
||||
horizontalalignment="left"
|
||||
if ev_type != "STOPPED_SPEAKING"
|
||||
else "right",
|
||||
verticalalignment="center",
|
||||
)
|
||||
plt.title("Monophonic")
|
||||
plt.savefig(file_path, format="png")
|
||||
|
||||
data_points = get_data_points(utter_events)
|
||||
|
||||
return {
|
||||
"wav_path": saved_wav_path,
|
||||
"num_samples": len(utter_events) // 3,
|
||||
"meta": call_obj,
|
||||
"data_points": data_points,
|
||||
}
|
||||
|
||||
def retrieve_callmeta(uri):
|
||||
cid = Path(urlsplit(uri).path).stem
|
||||
meta = mongo_collection.find_one({"SystemID": cid})
|
||||
duration = meta["EndTS"] - meta["StartTS"]
|
||||
process_meta = process_call(meta)
|
||||
return {"url": uri, "meta": meta, "duration": duration, "process": process_meta}
|
||||
|
||||
# @plot_app.command()
|
||||
def plot_calls_data():
|
||||
def plot_data_points(y, sr, data_points, file_path):
|
||||
plt.figure(figsize=(16, 12))
|
||||
librosa.display.waveplot(y=y, sr=sr)
|
||||
for dp in data_points:
|
||||
start, end, code = dp["start_time"], dp["end_time"], dp["code"]
|
||||
plt.axvspan(start, end, color="green", alpha=0.2)
|
||||
text_pos = (start + end) / 2
|
||||
plt.text(
|
||||
text_pos,
|
||||
0.25,
|
||||
f"{code}",
|
||||
rotation=90,
|
||||
horizontalalignment="center",
|
||||
verticalalignment="center",
|
||||
)
|
||||
plt.title("Datapoints")
|
||||
plt.savefig(file_path, format="png")
|
||||
return file_path
|
||||
|
||||
def plot_call(call_obj):
|
||||
saved_wav_path, data_points, sys_id = (
|
||||
call_obj["process"]["wav_path"],
|
||||
call_obj["process"]["data_points"],
|
||||
call_obj["meta"]["SystemID"],
|
||||
)
|
||||
file_path = call_plot_dir / Path(sys_id).with_suffix(".png")
|
||||
if not file_path.exists():
|
||||
print(f"plotting: {file_path}")
|
||||
(y, sr) = librosa.load(saved_wav_path)
|
||||
plot_data_points(y, sr, data_points, str(file_path))
|
||||
return file_path
|
||||
|
||||
# plot_call(retrieve_callmeta("http://saasdev.agaralabs.com/calls/JOR9V47L03AGUEL"))
|
||||
call_lens = lens["users"].Each()["calls"].Each()
|
||||
call_stats = call_lens.modify(retrieve_callmeta)(call_logs)
|
||||
# call_plot_data = call_lens.collect()(call_stats)
|
||||
call_plots = call_lens.modify(plot_call)(call_stats)
|
||||
# with ThreadPoolExecutor(max_workers=20) as exe:
|
||||
# print('starting all plot tasks')
|
||||
# responses = [exe.submit(plot_call, w) for w in call_plot_data]
|
||||
# print('submitted all plot tasks')
|
||||
# call_plots = [r.result() for r in responses]
|
||||
pprint(call_plots)
|
||||
|
||||
def extract_data_points():
|
||||
def gen_data_values(saved_wav_path, data_points):
|
||||
call_seg = (
|
||||
AudioSegment.from_wav(saved_wav_path)
|
||||
.set_channels(1)
|
||||
.set_sample_width(2)
|
||||
.set_frame_rate(24000)
|
||||
)
|
||||
for dp_id, dp in enumerate(data_points):
|
||||
start, end, code = dp["start_time"], dp["end_time"], dp["code"]
|
||||
code_seg = call_seg[start * 1000 : end * 1000]
|
||||
code_fb = BytesIO()
|
||||
code_seg.export(code_fb, format="wav")
|
||||
code_wav = code_fb.getvalue()
|
||||
# import pdb; pdb.set_trace()
|
||||
yield code, code_seg.duration_seconds, code_wav
|
||||
|
||||
call_lens = lens["users"].Each()["calls"].Each()
|
||||
call_stats = call_lens.modify(retrieve_callmeta)(call_logs)
|
||||
call_objs = call_lens.collect()(call_stats)
|
||||
|
||||
def data_source():
|
||||
for call_obj in tqdm(call_objs):
|
||||
saved_wav_path, data_points, sys_id = (
|
||||
call_obj["process"]["wav_path"],
|
||||
call_obj["process"]["data_points"],
|
||||
call_obj["meta"]["SystemID"],
|
||||
)
|
||||
for dp in gen_data_values(saved_wav_path, data_points):
|
||||
yield dp
|
||||
|
||||
asr_data_writer(call_asr_data, "call_alphanum", data_source())
|
||||
|
||||
# @leader_app.command()
|
||||
def show_leaderboard():
|
||||
def compute_user_stats(call_stat):
|
||||
n_samples = (
|
||||
lens["calls"].Each()["process"]["num_samples"].get_monoid()(call_stat)
|
||||
)
|
||||
n_duration = lens["calls"].Each()["duration"].get_monoid()(call_stat)
|
||||
rel_dur = relativedelta(
|
||||
seconds=int(n_duration.total_seconds()),
|
||||
microseconds=n_duration.microseconds,
|
||||
)
|
||||
return {
|
||||
"num_samples": n_samples,
|
||||
"duration": n_duration.total_seconds(),
|
||||
"samples_rate": n_samples / n_duration.total_seconds(),
|
||||
"duration_str": f"{rel_dur.minutes} mins {rel_dur.seconds} secs",
|
||||
"name": call_stat["name"],
|
||||
}
|
||||
|
||||
call_lens = lens["users"].Each()["calls"].Each()
|
||||
call_stats = call_lens.modify(retrieve_callmeta)(call_logs)
|
||||
user_stats = lens["users"].Each().modify(compute_user_stats)(call_stats)
|
||||
leader_df = (
|
||||
pd.DataFrame(user_stats["users"])
|
||||
.sort_values(by=["duration"], ascending=False)
|
||||
.reset_index(drop=True)
|
||||
)
|
||||
leader_df["rank"] = leader_df.index + 1
|
||||
leader_board = leader_df.rename(
|
||||
columns={
|
||||
"rank": "Rank",
|
||||
"num_samples": "Codes",
|
||||
"name": "Name",
|
||||
"samples_rate": "SpeechRate",
|
||||
"duration_str": "Duration",
|
||||
}
|
||||
)[["Rank", "Name", "Codes", "Duration"]]
|
||||
print(
|
||||
"""Today's ASR Speller Dataset Leaderboard:
|
||||
----------------------------------------"""
|
||||
)
|
||||
print(leader_board.to_string(index=False))
|
||||
|
||||
if leaderboard:
|
||||
show_leaderboard()
|
||||
if plot_calls:
|
||||
plot_calls_data()
|
||||
if extract_data:
|
||||
extract_data_points()
|
||||
|
||||
|
||||
def main():
|
||||
app()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -4,7 +4,7 @@
|
|||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from .utils import random_pnr_generator, manifest_str
|
||||
from .utils import random_pnr_generator, asr_data_writer
|
||||
from .tts.googletts import GoogleTTS
|
||||
from tqdm import tqdm
|
||||
import random
|
||||
|
|
@ -15,27 +15,21 @@ logging.basicConfig(
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def generate_asr_data(output_dir, count):
|
||||
def pnr_tts_streamer(count):
|
||||
google_voices = GoogleTTS.voice_list()
|
||||
gtts = GoogleTTS()
|
||||
wav_dir = output_dir / Path("pnr_data")
|
||||
wav_dir.mkdir(parents=True, exist_ok=True)
|
||||
asr_manifest = output_dir / Path("pnr_data").with_suffix(".json")
|
||||
with asr_manifest.open("w") as mf:
|
||||
for pnr_code in tqdm(random_pnr_generator(count)):
|
||||
tts_code = (
|
||||
f'<speak><say-as interpret-as="verbatim">{pnr_code}</say-as></speak>'
|
||||
)
|
||||
param = random.choice(google_voices)
|
||||
param["sample_rate"] = 24000
|
||||
param["num_channels"] = 1
|
||||
wav_data = gtts.text_to_speech(text=tts_code, params=param)
|
||||
audio_dur = len(wav_data[44:]) / (2 * 24000)
|
||||
pnr_af = wav_dir / Path(pnr_code).with_suffix(".wav")
|
||||
pnr_af.write_bytes(wav_data)
|
||||
rel_pnr_path = pnr_af.relative_to(output_dir)
|
||||
manifest = manifest_str(str(rel_pnr_path), audio_dur, pnr_code)
|
||||
mf.write(manifest)
|
||||
for pnr_code in tqdm(random_pnr_generator(count)):
|
||||
tts_code = f'<speak><say-as interpret-as="verbatim">{pnr_code}</say-as></speak>'
|
||||
param = random.choice(google_voices)
|
||||
param["sample_rate"] = 24000
|
||||
param["num_channels"] = 1
|
||||
wav_data = gtts.text_to_speech(text=tts_code, params=param)
|
||||
audio_dur = len(wav_data[44:]) / (2 * 24000)
|
||||
yield pnr_code, audio_dur, wav_data
|
||||
|
||||
|
||||
def generate_asr_data_fromtts(output_dir, dataset_name, count):
|
||||
asr_data_writer(output_dir, dataset_name, pnr_tts_streamer(count))
|
||||
|
||||
|
||||
def arg_parser():
|
||||
|
|
@ -52,13 +46,16 @@ def arg_parser():
|
|||
parser.add_argument(
|
||||
"--count", type=int, default=3, help="number of datapoints to generate"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset_name", type=str, default="pnr_data", help="name of the dataset"
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def main():
|
||||
parser = arg_parser()
|
||||
args = parser.parse_args()
|
||||
generate_asr_data(**vars(args))
|
||||
generate_asr_data_fromtts(**vars(args))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -1,9 +1,13 @@
|
|||
import json
|
||||
from pathlib import Path
|
||||
from sklearn.model_selection import train_test_split
|
||||
from num2words import num2words
|
||||
from .utils import alnum_to_asr_tokens
|
||||
import typer
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
|
||||
def separate_space_convert_digit_setpath():
|
||||
with Path("/home/malar/work/asr-data-utils/asr_data/pnr_data.json").open("r") as pf:
|
||||
pnr_jsonl = pf.readlines()
|
||||
|
|
@ -12,9 +16,7 @@ def separate_space_convert_digit_setpath():
|
|||
|
||||
new_pnr_data = []
|
||||
for i in pnr_data:
|
||||
letters = " ".join(list(i["text"]))
|
||||
num_tokens = [num2words(c) if "0" <= c <= "9" else c for c in letters]
|
||||
i["text"] = ("".join(num_tokens)).lower()
|
||||
i["text"] = alnum_to_asr_tokens(i["text"])
|
||||
i["audio_filepath"] = i["audio_filepath"].replace(
|
||||
"pnr_data/", "/dataset/asr_data/pnr_data/wav/"
|
||||
)
|
||||
|
|
@ -27,24 +29,39 @@ def separate_space_convert_digit_setpath():
|
|||
pf.write(new_pnr_data)
|
||||
|
||||
|
||||
separate_space_convert_digit_setpath()
|
||||
|
||||
|
||||
def split_data():
|
||||
with Path("/dataset/asr_data/pnr_data/pnr_data.json").open("r") as pf:
|
||||
@app.command()
|
||||
def split_data(manifest_path: Path = Path("/dataset/asr_data/pnr_data/pnr_data.json")):
|
||||
with manifest_path.open("r") as pf:
|
||||
pnr_jsonl = pf.readlines()
|
||||
train_pnr, test_pnr = train_test_split(pnr_jsonl, test_size=0.1)
|
||||
with Path("/dataset/asr_data/pnr_data/train_manifest.json").open("w") as pf:
|
||||
with (manifest_path.parent / Path("train_manifest.json")).open("w") as pf:
|
||||
pnr_data = "".join(train_pnr)
|
||||
pf.write(pnr_data)
|
||||
with Path("/dataset/asr_data/pnr_data/test_manifest.json").open("w") as pf:
|
||||
with (manifest_path.parent / Path("test_manifest.json")).open("w") as pf:
|
||||
pnr_data = "".join(test_pnr)
|
||||
pf.write(pnr_data)
|
||||
|
||||
|
||||
split_data()
|
||||
@app.command()
|
||||
def fix_path(
|
||||
dataset_path: Path = Path("/dataset/asr_data/call_alphanum"),
|
||||
):
|
||||
manifest_path = dataset_path / Path('manifest.json')
|
||||
with manifest_path.open("r") as pf:
|
||||
pnr_jsonl = pf.readlines()
|
||||
pnr_data = [json.loads(i) for i in pnr_jsonl]
|
||||
new_pnr_data = []
|
||||
for i in pnr_data:
|
||||
i["audio_filepath"] = str(dataset_path / Path(i["audio_filepath"]))
|
||||
new_pnr_data.append(i)
|
||||
new_pnr_jsonl = [json.dumps(i) for i in new_pnr_data]
|
||||
real_manifest_path = dataset_path / Path('real_manifest.json')
|
||||
with real_manifest_path.open("w") as pf:
|
||||
new_pnr_data = "\n".join(new_pnr_jsonl) # + "\n"
|
||||
pf.write(new_pnr_data)
|
||||
|
||||
|
||||
@app.command()
|
||||
def augment_an4():
|
||||
an4_train = Path("/dataset/asr_data/an4/train_manifest.json").read_bytes()
|
||||
an4_test = Path("/dataset/asr_data/an4/test_manifest.json").read_bytes()
|
||||
|
|
@ -57,10 +74,11 @@ def augment_an4():
|
|||
pf.write(an4_test + pnr_test)
|
||||
|
||||
|
||||
augment_an4()
|
||||
# augment_an4()
|
||||
|
||||
|
||||
def validate_data(data_file):
|
||||
@app.command()
|
||||
def validate_data(data_file: Path = Path("/dataset/asr_data/call_alphanum/train_manifest.json")):
|
||||
with Path(data_file).open("r") as pf:
|
||||
pnr_jsonl = pf.readlines()
|
||||
for (i, s) in enumerate(pnr_jsonl):
|
||||
|
|
@ -70,10 +88,13 @@ def validate_data(data_file):
|
|||
print(f"failed on {i}")
|
||||
|
||||
|
||||
validate_data("/dataset/asr_data/an4_pnr/test_manifest.json")
|
||||
validate_data("/dataset/asr_data/an4_pnr/train_manifest.json")
|
||||
def main():
|
||||
app()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
# def convert_digits(data_file="/dataset/asr_data/an4_pnr/test_manifest.json"):
|
||||
# with Path(data_file).open("r") as pf:
|
||||
# pnr_jsonl = pf.readlines()
|
||||
|
|
|
|||
|
|
@ -2,6 +2,8 @@ import numpy as np
|
|||
import wave
|
||||
import io
|
||||
import json
|
||||
from pathlib import Path
|
||||
from num2words import num2words
|
||||
|
||||
|
||||
def manifest_str(path, dur, text):
|
||||
|
|
@ -38,6 +40,27 @@ def random_pnr_generator(count=10000):
|
|||
return codes
|
||||
|
||||
|
||||
def alnum_to_asr_tokens(text):
|
||||
letters = " ".join(list(text))
|
||||
num_tokens = [num2words(c) if "0" <= c <= "9" else c for c in letters]
|
||||
return ("".join(num_tokens)).lower()
|
||||
|
||||
|
||||
def asr_data_writer(output_dir, dataset_name, asr_data_source):
|
||||
dataset_dir = output_dir / Path(dataset_name)
|
||||
(dataset_dir / Path("wav")).mkdir(parents=True, exist_ok=True)
|
||||
asr_manifest = dataset_dir / Path("manifest.json")
|
||||
with asr_manifest.open("w") as mf:
|
||||
for pnr_code, audio_dur, wav_data in asr_data_source:
|
||||
pnr_af = dataset_dir / Path("wav") / Path(pnr_code).with_suffix(".wav")
|
||||
pnr_af.write_bytes(wav_data)
|
||||
rel_pnr_path = pnr_af.relative_to(dataset_dir)
|
||||
manifest = manifest_str(
|
||||
str(rel_pnr_path), audio_dur, alnum_to_asr_tokens(pnr_code)
|
||||
)
|
||||
mf.write(manifest)
|
||||
|
||||
|
||||
def main():
|
||||
for c in random_pnr_generator():
|
||||
print(c)
|
||||
|
|
|
|||
|
|
@ -82,8 +82,7 @@ def parse_args():
|
|||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.max_steps is not None and args.num_epochs is not None:
|
||||
if args.max_steps is None and args.num_epochs is None:
|
||||
raise ValueError("Either max_steps or num_epochs should be provided.")
|
||||
return args
|
||||
|
||||
|
|
@ -311,7 +310,6 @@ def main():
|
|||
|
||||
# build dags
|
||||
train_loss, callbacks, steps_per_epoch = create_all_dags(args, neural_factory)
|
||||
|
||||
# train model
|
||||
neural_factory.train(
|
||||
tensors_to_optimize=[train_loss],
|
||||
|
|
|
|||
16
setup.py
16
setup.py
|
|
@ -2,6 +2,8 @@ from setuptools import setup, find_packages
|
|||
|
||||
requirements = [
|
||||
"ruamel.yaml",
|
||||
"torch==1.4.0",
|
||||
"torchvision==0.5.0",
|
||||
"nemo_toolkit[asr] @ git+https://github.com/NVIDIA/NeMo.git@09e3ba4dfe333f86d6c5c1048e07210924294be9#egg=nemo_toolkit",
|
||||
]
|
||||
|
||||
|
|
@ -14,7 +16,19 @@ extra_requirements = {
|
|||
"scikit_learn~=0.22.1",
|
||||
"pandas~=1.0.3",
|
||||
"boto3~=1.12.35",
|
||||
"ruamel.yaml==0.16.10",
|
||||
"pymongo==3.10.1",
|
||||
"librosa==0.7.2",
|
||||
"matplotlib==3.2.1",
|
||||
"pandas==1.0.3",
|
||||
"tabulate==0.8.7",
|
||||
"typer[all]==0.1.1",
|
||||
"lenses @ git+https://github.com/ingolemo/python-lenses.git@b2a2a9aa5b61540992d70b2cf36008d0121e8948#egg=lenses",
|
||||
],
|
||||
# "train": [
|
||||
# "torchaudio==0.5.0",
|
||||
# "torch-stft==0.1.4",
|
||||
# ]
|
||||
}
|
||||
packages = find_packages()
|
||||
|
||||
|
|
@ -35,6 +49,8 @@ setup(
|
|||
"jasper_asr_rpyc_server = jasper.server:main",
|
||||
"jasper_asr_trainer = jasper.train:main",
|
||||
"jasper_asr_data_generate = jasper.data_utils.generator:main",
|
||||
"jasper_asr_data_recycle = jasper.data_utils.call_recycler:main",
|
||||
"jasper_asr_data_preprocess = jasper.data_utils.process:main",
|
||||
]
|
||||
},
|
||||
zip_safe=False,
|
||||
|
|
|
|||
Loading…
Reference in New Issue