mirror of
https://github.com/malarinv/jasper-asr.git
synced 2026-03-08 02:22:34 +00:00
Compare commits
7 Commits
000853b600
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
| e30dd724f5 | |||
|
|
02df1b5282 | ||
| e8f58a5043 | |||
| 42647196fe | |||
| e77943b2f2 | |||
|
|
14d31a51c3 | ||
| e24a8cf9d0 |
4
.flake8
Normal file
4
.flake8
Normal file
@@ -0,0 +1,4 @@
|
||||
[flake8]
|
||||
exclude = docs
|
||||
ignore = E203, W503
|
||||
max-line-length = 119
|
||||
5
Notes.md
Normal file
5
Notes.md
Normal file
@@ -0,0 +1,5 @@
|
||||
|
||||
> Diff after splitting based on type
|
||||
```
|
||||
diff <(cat data/asr_data/call_upwork_test_cnd_*/manifest.json |sort) <(cat data/asr_data/call_upwork_test_cnd/manifest.json |sort)
|
||||
```
|
||||
@@ -7,10 +7,16 @@
|
||||
|
||||
# Table of Contents
|
||||
|
||||
* [Prerequisites](#prerequisites)
|
||||
* [Features](#features)
|
||||
* [Installation](#installation)
|
||||
* [Usage](#usage)
|
||||
|
||||
# Prerequisites
|
||||
```bash
|
||||
# apt install libsndfile-dev ffmpeg
|
||||
```
|
||||
|
||||
# Features
|
||||
|
||||
* ASR using Jasper (from [NemoToolkit](https://github.com/NVIDIA/NeMo) )
|
||||
|
||||
@@ -1,104 +0,0 @@
|
||||
import typer
|
||||
from itertools import chain
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
|
||||
def extract_data(
|
||||
call_audio_dir: Path = Path("/dataset/png_prod/call_audio"),
|
||||
call_meta_dir: Path = Path("/dataset/png_prod/call_metadata"),
|
||||
output_dir: Path = Path("./data"),
|
||||
dataset_name: str = "png_gcp_2jan",
|
||||
verbose: bool = False,
|
||||
):
|
||||
from pydub import AudioSegment
|
||||
from .utils import ExtendedPath, asr_data_writer, strip_silence
|
||||
from lenses import lens
|
||||
|
||||
call_asr_data: Path = output_dir / Path("asr_data")
|
||||
call_asr_data.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
def wav_event_generator(call_audio_dir):
|
||||
for wav_path in call_audio_dir.glob("**/*.wav"):
|
||||
if verbose:
|
||||
typer.echo(f"loading events for file {wav_path}")
|
||||
call_wav = AudioSegment.from_file_using_temporary_files(wav_path)
|
||||
rel_meta_path = wav_path.with_suffix(".json").relative_to(call_audio_dir)
|
||||
meta_path = call_meta_dir / rel_meta_path
|
||||
events = ExtendedPath(meta_path).read_json()
|
||||
yield call_wav, wav_path, events
|
||||
|
||||
def contains_asr(x):
|
||||
return "AsrResult" in x
|
||||
|
||||
def channel(n):
|
||||
def filter_func(ev):
|
||||
return (
|
||||
ev["AsrResult"]["Channel"] == n
|
||||
if "Channel" in ev["AsrResult"]
|
||||
else n == 0
|
||||
)
|
||||
|
||||
return filter_func
|
||||
|
||||
def compute_endtime(call_wav, state):
|
||||
for (i, st) in enumerate(state):
|
||||
start_time = st["AsrResult"]["Alternatives"][0].get("StartTime", 0)
|
||||
transcript = st["AsrResult"]["Alternatives"][0]["Transcript"]
|
||||
if i + 1 < len(state):
|
||||
end_time = state[i + 1]["AsrResult"]["Alternatives"][0]["StartTime"]
|
||||
else:
|
||||
end_time = call_wav.duration_seconds
|
||||
full_code_seg = call_wav[start_time * 1000 : end_time * 1000]
|
||||
code_seg = strip_silence(full_code_seg)
|
||||
code_fb = BytesIO()
|
||||
code_seg.export(code_fb, format="wav")
|
||||
code_wav = code_fb.getvalue()
|
||||
# only starting 1 min audio has reliable alignment ignore rest
|
||||
if start_time > 60:
|
||||
if verbose:
|
||||
print(f'start time over 60 seconds of audio skipping.')
|
||||
break
|
||||
# only if some reasonable audio data is present yield it
|
||||
if code_seg.duration_seconds < 0.5:
|
||||
if verbose:
|
||||
print(f'transcript chunk "{transcript}" contains no audio skipping.')
|
||||
continue
|
||||
yield transcript, code_seg.duration_seconds, code_wav
|
||||
|
||||
def asr_data_generator(call_wav, call_wav_fname, events):
|
||||
call_wav_0, call_wav_1 = call_wav.split_to_mono()
|
||||
asr_events = lens["Events"].Each()["Event"].Filter(contains_asr)
|
||||
call_evs_0 = asr_events.Filter(channel(0)).collect()(events)
|
||||
# Ignoring agent channel events
|
||||
# call_evs_1 = asr_events.Filter(channel(1)).collect()(events)
|
||||
if verbose:
|
||||
typer.echo(f"processing data points on {call_wav_fname}")
|
||||
call_data_0 = compute_endtime(call_wav_0, call_evs_0)
|
||||
# Ignoring agent channel
|
||||
# call_data_1 = compute_endtime(call_wav_1, call_evs_1)
|
||||
return call_data_0 # chain(call_data_0, call_data_1)
|
||||
|
||||
def generate_call_asr_data():
|
||||
full_asr_data = []
|
||||
total_duration = 0
|
||||
for wav, wav_path, ev in wav_event_generator(call_audio_dir):
|
||||
asr_data = asr_data_generator(wav, wav_path, ev)
|
||||
total_duration += wav.duration_seconds
|
||||
full_asr_data.append(asr_data)
|
||||
typer.echo(f"loaded {len(full_asr_data)} calls of duration {total_duration}s")
|
||||
n_dps = asr_data_writer(call_asr_data, dataset_name, chain(*full_asr_data))
|
||||
typer.echo(f"written {n_dps} data points")
|
||||
|
||||
generate_call_asr_data()
|
||||
|
||||
|
||||
def main():
|
||||
app()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,521 +0,0 @@
|
||||
import typer
|
||||
from pathlib import Path
|
||||
from enum import Enum
|
||||
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
|
||||
def export_all_logs(
|
||||
call_logs_file: Path = typer.Option(Path("./call_logs.yaml"), show_default=True),
|
||||
domain: str = typer.Option("sia-data.agaralabs.com", show_default=True),
|
||||
):
|
||||
from .utils import get_mongo_conn
|
||||
from collections import defaultdict
|
||||
from ruamel.yaml import YAML
|
||||
|
||||
yaml = YAML()
|
||||
mongo_coll = get_mongo_conn()
|
||||
caller_calls = defaultdict(lambda: [])
|
||||
for call in mongo_coll.find():
|
||||
sysid = call["SystemID"]
|
||||
call_uri = f"http://{domain}/calls/{sysid}"
|
||||
caller = call["Caller"]
|
||||
caller_calls[caller].append(call_uri)
|
||||
caller_list = []
|
||||
for caller in caller_calls:
|
||||
caller_list.append({"name": caller, "calls": caller_calls[caller]})
|
||||
output_yaml = {"users": caller_list}
|
||||
typer.echo(f"exporting call logs to yaml file at {call_logs_file}")
|
||||
with call_logs_file.open("w") as yf:
|
||||
yaml.dump(output_yaml, yf)
|
||||
|
||||
|
||||
@app.command()
|
||||
def export_calls_between(
|
||||
start_cid: str,
|
||||
end_cid: str,
|
||||
call_logs_file: Path = typer.Option(Path("./call_logs.yaml"), show_default=True),
|
||||
domain: str = typer.Option("sia-data.agaralabs.com", show_default=True),
|
||||
mongo_port: int = 27017,
|
||||
):
|
||||
from collections import defaultdict
|
||||
from ruamel.yaml import YAML
|
||||
from .utils import get_mongo_conn
|
||||
|
||||
yaml = YAML()
|
||||
mongo_coll = get_mongo_conn(port=mongo_port)
|
||||
start_meta = mongo_coll.find_one({"SystemID": start_cid})
|
||||
end_meta = mongo_coll.find_one({"SystemID": end_cid})
|
||||
|
||||
caller_calls = defaultdict(lambda: [])
|
||||
call_query = mongo_coll.find(
|
||||
{
|
||||
"StartTS": {"$gte": start_meta["StartTS"]},
|
||||
"EndTS": {"$lte": end_meta["EndTS"]},
|
||||
}
|
||||
)
|
||||
for call in call_query:
|
||||
sysid = call["SystemID"]
|
||||
call_uri = f"http://{domain}/calls/{sysid}"
|
||||
caller = call["Caller"]
|
||||
caller_calls[caller].append(call_uri)
|
||||
caller_list = []
|
||||
for caller in caller_calls:
|
||||
caller_list.append({"name": caller, "calls": caller_calls[caller]})
|
||||
output_yaml = {"users": caller_list}
|
||||
typer.echo(f"exporting call logs to yaml file at {call_logs_file}")
|
||||
with call_logs_file.open("w") as yf:
|
||||
yaml.dump(output_yaml, yf)
|
||||
|
||||
|
||||
@app.command()
|
||||
def copy_metas(
|
||||
call_logs_file: Path = typer.Option(Path("./call_logs.yaml"), show_default=True),
|
||||
output_dir: Path = Path("./data"),
|
||||
meta_dir: Path = Path("/tmp/call_metas"),
|
||||
):
|
||||
from lenses import lens
|
||||
from ruamel.yaml import YAML
|
||||
from urllib.parse import urlsplit
|
||||
from shutil import copy2
|
||||
|
||||
yaml = YAML()
|
||||
call_logs = yaml.load(call_logs_file.read_text())
|
||||
|
||||
call_meta_dir: Path = output_dir / Path("call_metas")
|
||||
call_meta_dir.mkdir(exist_ok=True, parents=True)
|
||||
meta_dir.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
def get_cid(uri):
|
||||
return Path(urlsplit(uri).path).stem
|
||||
|
||||
def copy_meta(uri):
|
||||
cid = get_cid(uri)
|
||||
saved_meta_path = call_meta_dir / Path(f"{cid}.json")
|
||||
dest_meta_path = meta_dir / Path(f"{cid}.json")
|
||||
if not saved_meta_path.exists():
|
||||
print(f"{saved_meta_path} not found")
|
||||
copy2(saved_meta_path, dest_meta_path)
|
||||
|
||||
def download_meta_audio():
|
||||
call_lens = lens["users"].Each()["calls"].Each()
|
||||
call_lens.modify(copy_meta)(call_logs)
|
||||
|
||||
download_meta_audio()
|
||||
|
||||
|
||||
class ExtractionType(str, Enum):
|
||||
flow = "flow"
|
||||
data = "data"
|
||||
|
||||
|
||||
@app.command()
|
||||
def analyze(
|
||||
leaderboard: bool = False,
|
||||
plot_calls: bool = False,
|
||||
extract_data: bool = False,
|
||||
extraction_type: ExtractionType = typer.Option(
|
||||
ExtractionType.data, show_default=True
|
||||
),
|
||||
start_delay: float = 1.5,
|
||||
download_only: bool = False,
|
||||
strip_silent_chunks: bool = True,
|
||||
call_logs_file: Path = typer.Option(Path("./call_logs.yaml"), show_default=True),
|
||||
output_dir: Path = Path("./data"),
|
||||
data_name: str = None,
|
||||
mongo_uri: str = typer.Option(
|
||||
"mongodb://localhost:27017/test.calls", show_default=True
|
||||
),
|
||||
):
|
||||
|
||||
from urllib.parse import urlsplit
|
||||
from functools import reduce
|
||||
import boto3
|
||||
from io import BytesIO
|
||||
import json
|
||||
from ruamel.yaml import YAML
|
||||
import re
|
||||
from google.protobuf.timestamp_pb2 import Timestamp
|
||||
from datetime import timedelta
|
||||
import librosa
|
||||
import librosa.display
|
||||
from lenses import lens
|
||||
from pprint import pprint
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib
|
||||
from tqdm import tqdm
|
||||
from .utils import ui_dump_manifest_writer, strip_silence, get_mongo_coll
|
||||
from pydub import AudioSegment
|
||||
from natural.date import compress
|
||||
|
||||
matplotlib.rcParams["agg.path.chunksize"] = 10000
|
||||
|
||||
matplotlib.use("agg")
|
||||
|
||||
yaml = YAML()
|
||||
s3 = boto3.client("s3")
|
||||
mongo_collection = get_mongo_coll(mongo_uri)
|
||||
call_media_dir: Path = output_dir / Path("call_wavs")
|
||||
call_media_dir.mkdir(exist_ok=True, parents=True)
|
||||
call_meta_dir: Path = output_dir / Path("call_metas")
|
||||
call_meta_dir.mkdir(exist_ok=True, parents=True)
|
||||
call_plot_dir: Path = output_dir / Path("plots")
|
||||
call_plot_dir.mkdir(exist_ok=True, parents=True)
|
||||
call_asr_data: Path = output_dir / Path("asr_data")
|
||||
call_asr_data.mkdir(exist_ok=True, parents=True)
|
||||
dataset_name = call_logs_file.stem if not data_name else data_name
|
||||
|
||||
call_logs = yaml.load(call_logs_file.read_text())
|
||||
|
||||
def get_call_meta(call_obj):
|
||||
meta_s3_uri = call_obj["DataURI"]
|
||||
s3_event_url_p = urlsplit(meta_s3_uri)
|
||||
saved_meta_path = call_meta_dir / Path(Path(s3_event_url_p.path).name)
|
||||
if not saved_meta_path.exists():
|
||||
print(f"downloading : {saved_meta_path} from {meta_s3_uri}")
|
||||
s3.download_file(
|
||||
s3_event_url_p.netloc, s3_event_url_p.path[1:], str(saved_meta_path)
|
||||
)
|
||||
call_metas = json.load(saved_meta_path.open())
|
||||
return call_metas
|
||||
|
||||
def gen_ev_fev_timedelta(fev):
|
||||
fev_p = Timestamp()
|
||||
fev_p.FromJsonString(fev["CreatedTS"])
|
||||
fev_dt = fev_p.ToDatetime()
|
||||
td_0 = timedelta()
|
||||
|
||||
def get_timedelta(ev):
|
||||
ev_p = Timestamp()
|
||||
ev_p.FromJsonString(value=ev["CreatedTS"])
|
||||
ev_dt = ev_p.ToDatetime()
|
||||
delta = ev_dt - fev_dt
|
||||
return delta if delta > td_0 else td_0
|
||||
|
||||
return get_timedelta
|
||||
|
||||
def chunk_n(evs, n):
|
||||
return [evs[i * n : (i + 1) * n] for i in range((len(evs) + n - 1) // n)]
|
||||
|
||||
if extraction_type == ExtractionType.data:
|
||||
|
||||
def is_utter_event(ev):
|
||||
return (
|
||||
(ev["Author"] == "CONV" or ev["Author"] == "ASR")
|
||||
and (ev["Type"] != "DEBUG")
|
||||
and ev["Type"] != "ASR_RESULT"
|
||||
)
|
||||
|
||||
def get_data_points(utter_events, td_fn):
|
||||
data_points = []
|
||||
for evs in chunk_n(utter_events, 3):
|
||||
try:
|
||||
assert evs[0]["Type"] == "CONV_RESULT"
|
||||
assert evs[1]["Type"] == "STARTED_SPEAKING"
|
||||
assert evs[2]["Type"] == "STOPPED_SPEAKING"
|
||||
start_time = td_fn(evs[1]).total_seconds() - start_delay
|
||||
end_time = td_fn(evs[2]).total_seconds()
|
||||
spoken = evs[0]["Msg"]
|
||||
data_points.append(
|
||||
{"start_time": start_time, "end_time": end_time, "code": spoken}
|
||||
)
|
||||
except AssertionError:
|
||||
# skipping invalid data_points
|
||||
pass
|
||||
return data_points
|
||||
|
||||
def text_extractor(spoken):
|
||||
return (
|
||||
re.search(r"'(.*)'", spoken).groups(0)[0]
|
||||
if len(spoken) > 6 and re.search(r"'(.*)'", spoken)
|
||||
else spoken
|
||||
)
|
||||
|
||||
elif extraction_type == ExtractionType.flow:
|
||||
|
||||
def is_final_asr_event_or_spoken(ev):
|
||||
pld = json.loads(ev["Payload"])
|
||||
return (
|
||||
pld["AsrResult"]["Results"][0]["IsFinal"]
|
||||
if ev["Type"] == "ASR_RESULT"
|
||||
else True
|
||||
)
|
||||
|
||||
def is_utter_event(ev):
|
||||
return (
|
||||
ev["Author"] == "CONV"
|
||||
or (ev["Author"] == "ASR" and is_final_asr_event_or_spoken(ev))
|
||||
) and (ev["Type"] != "DEBUG")
|
||||
|
||||
def get_data_points(utter_events, td_fn):
|
||||
data_points = []
|
||||
for evs in chunk_n(utter_events, 4):
|
||||
try:
|
||||
assert len(evs) == 4
|
||||
assert evs[0]["Type"] == "CONV_RESULT"
|
||||
assert evs[1]["Type"] == "STARTED_SPEAKING"
|
||||
assert evs[2]["Type"] == "ASR_RESULT"
|
||||
assert evs[3]["Type"] == "STOPPED_SPEAKING"
|
||||
start_time = td_fn(evs[1]).total_seconds() - start_delay
|
||||
end_time = td_fn(evs[2]).total_seconds()
|
||||
conv_msg = evs[0]["Msg"]
|
||||
if "full name" in conv_msg.lower():
|
||||
pld = json.loads(evs[2]["Payload"])
|
||||
spoken = pld["AsrResult"]["Results"][0]["Alternatives"][0][
|
||||
"Transcript"
|
||||
]
|
||||
data_points.append(
|
||||
{
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
"code": spoken,
|
||||
}
|
||||
)
|
||||
except AssertionError:
|
||||
# skipping invalid data_points
|
||||
pass
|
||||
return data_points
|
||||
|
||||
def text_extractor(spoken):
|
||||
return spoken
|
||||
|
||||
def process_call(call_obj):
|
||||
call_meta = get_call_meta(call_obj)
|
||||
call_events = call_meta["Events"]
|
||||
|
||||
def is_writer_uri_event(ev):
|
||||
return ev["Author"] == "AUDIO_WRITER" and "s3://" in ev["Msg"]
|
||||
|
||||
writer_events = list(filter(is_writer_uri_event, call_events))
|
||||
s3_wav_url = re.search(r"(s3://.*)", writer_events[0]["Msg"]).groups(0)[0]
|
||||
s3_wav_url_p = urlsplit(s3_wav_url)
|
||||
|
||||
def is_first_audio_ev(state, ev):
|
||||
if state[0]:
|
||||
return state
|
||||
else:
|
||||
return (ev["Author"] == "GATEWAY" and ev["Type"] == "AUDIO", ev)
|
||||
|
||||
(_, first_audio_ev) = reduce(is_first_audio_ev, call_events, (False, {}))
|
||||
|
||||
get_ev_fev_timedelta = gen_ev_fev_timedelta(first_audio_ev)
|
||||
|
||||
uevs = list(filter(is_utter_event, call_events))
|
||||
ev_count = len(uevs)
|
||||
utter_events = uevs[: ev_count - ev_count % 3]
|
||||
saved_wav_path = call_media_dir / Path(Path(s3_wav_url_p.path).name)
|
||||
if not saved_wav_path.exists():
|
||||
print(f"downloading : {saved_wav_path} from {s3_wav_url}")
|
||||
s3.download_file(
|
||||
s3_wav_url_p.netloc, s3_wav_url_p.path[1:], str(saved_wav_path)
|
||||
)
|
||||
|
||||
return {
|
||||
"wav_path": saved_wav_path,
|
||||
"num_samples": len(utter_events) // 3,
|
||||
"meta": call_obj,
|
||||
"first_event_fn": get_ev_fev_timedelta,
|
||||
"utter_events": utter_events,
|
||||
}
|
||||
|
||||
def get_cid(uri):
|
||||
return Path(urlsplit(uri).path).stem
|
||||
|
||||
def ensure_call(uri):
|
||||
cid = get_cid(uri)
|
||||
meta = mongo_collection.find_one({"SystemID": cid})
|
||||
process_meta = process_call(meta)
|
||||
return process_meta
|
||||
|
||||
def retrieve_processed_callmeta(uri):
|
||||
cid = get_cid(uri)
|
||||
meta = mongo_collection.find_one({"SystemID": cid})
|
||||
duration = meta["EndTS"] - meta["StartTS"]
|
||||
process_meta = process_call(meta)
|
||||
data_points = get_data_points(
|
||||
process_meta["utter_events"], process_meta["first_event_fn"]
|
||||
)
|
||||
process_meta["data_points"] = data_points
|
||||
return {"url": uri, "meta": meta, "duration": duration, "process": process_meta}
|
||||
|
||||
def retrieve_callmeta(call_uri):
|
||||
uri = call_uri["call_uri"]
|
||||
name = call_uri["name"]
|
||||
cid = get_cid(uri)
|
||||
meta = mongo_collection.find_one({"SystemID": cid})
|
||||
duration = meta["EndTS"] - meta["StartTS"]
|
||||
process_meta = process_call(meta)
|
||||
data_points = get_data_points(
|
||||
process_meta["utter_events"], process_meta["first_event_fn"]
|
||||
)
|
||||
process_meta["data_points"] = data_points
|
||||
return {
|
||||
"url": uri,
|
||||
"name": name,
|
||||
"meta": meta,
|
||||
"duration": duration,
|
||||
"process": process_meta,
|
||||
}
|
||||
|
||||
def download_meta_audio():
|
||||
call_lens = lens["users"].Each()["calls"].Each()
|
||||
call_lens.modify(ensure_call)(call_logs)
|
||||
|
||||
def plot_calls_data():
|
||||
def plot_data_points(y, sr, data_points, file_path):
|
||||
plt.figure(figsize=(16, 12))
|
||||
librosa.display.waveplot(y=y, sr=sr)
|
||||
for dp in data_points:
|
||||
start, end, code = dp["start_time"], dp["end_time"], dp["code"]
|
||||
plt.axvspan(start, end, color="green", alpha=0.2)
|
||||
text_pos = (start + end) / 2
|
||||
plt.text(
|
||||
text_pos,
|
||||
0.25,
|
||||
f"{code}",
|
||||
rotation=90,
|
||||
horizontalalignment="center",
|
||||
verticalalignment="center",
|
||||
)
|
||||
plt.title("Datapoints")
|
||||
plt.savefig(file_path, format="png")
|
||||
return file_path
|
||||
|
||||
def plot_call(call_obj):
|
||||
saved_wav_path, data_points, sys_id = (
|
||||
call_obj["process"]["wav_path"],
|
||||
call_obj["process"]["data_points"],
|
||||
call_obj["meta"]["SystemID"],
|
||||
)
|
||||
file_path = call_plot_dir / Path(sys_id).with_suffix(".png")
|
||||
if not file_path.exists():
|
||||
print(f"plotting: {file_path}")
|
||||
(y, sr) = librosa.load(saved_wav_path)
|
||||
plot_data_points(y, sr, data_points, str(file_path))
|
||||
return file_path
|
||||
|
||||
call_lens = lens["users"].Each()["calls"].Each()
|
||||
call_stats = call_lens.modify(retrieve_processed_callmeta)(call_logs)
|
||||
# call_plot_data = call_lens.collect()(call_stats)
|
||||
call_plots = call_lens.modify(plot_call)(call_stats)
|
||||
# with ThreadPoolExecutor(max_workers=20) as exe:
|
||||
# print('starting all plot tasks')
|
||||
# responses = [exe.submit(plot_call, w) for w in call_plot_data]
|
||||
# print('submitted all plot tasks')
|
||||
# call_plots = [r.result() for r in responses]
|
||||
pprint(call_plots)
|
||||
|
||||
def extract_data_points():
|
||||
if strip_silent_chunks:
|
||||
|
||||
def audio_process(seg):
|
||||
return strip_silence(seg)
|
||||
|
||||
else:
|
||||
|
||||
def audio_process(seg):
|
||||
return seg
|
||||
|
||||
def gen_data_values(saved_wav_path, data_points, caller_name):
|
||||
call_seg = (
|
||||
AudioSegment.from_wav(saved_wav_path)
|
||||
.set_channels(1)
|
||||
.set_sample_width(2)
|
||||
.set_frame_rate(24000)
|
||||
)
|
||||
for dp_id, dp in enumerate(data_points):
|
||||
start, end, spoken = dp["start_time"], dp["end_time"], dp["code"]
|
||||
spoken_seg = audio_process(call_seg[start * 1000 : end * 1000])
|
||||
spoken_fb = BytesIO()
|
||||
spoken_seg.export(spoken_fb, format="wav")
|
||||
spoken_wav = spoken_fb.getvalue()
|
||||
# search for actual pnr code and handle plain codes as well
|
||||
extracted_code = text_extractor(spoken)
|
||||
if strip_silent_chunks and spoken_seg.duration_seconds < 0.5:
|
||||
print(f'transcript chunk "{spoken}" contains no audio skipping.')
|
||||
continue
|
||||
yield extracted_code, spoken_seg.duration_seconds, spoken_wav, caller_name, spoken_seg
|
||||
|
||||
call_lens = lens["users"].Each()["calls"].Each()
|
||||
|
||||
def assign_user_call(uc):
|
||||
return (
|
||||
lens["calls"]
|
||||
.Each()
|
||||
.modify(lambda c: {"call_uri": c, "name": uc["name"]})(uc)
|
||||
)
|
||||
|
||||
user_call_logs = lens["users"].Each().modify(assign_user_call)(call_logs)
|
||||
call_stats = call_lens.modify(retrieve_callmeta)(user_call_logs)
|
||||
call_objs = call_lens.collect()(call_stats)
|
||||
|
||||
def data_source():
|
||||
for call_obj in tqdm(call_objs):
|
||||
saved_wav_path, data_points, name = (
|
||||
call_obj["process"]["wav_path"],
|
||||
call_obj["process"]["data_points"],
|
||||
call_obj["name"],
|
||||
)
|
||||
for dp in gen_data_values(saved_wav_path, data_points, name):
|
||||
yield dp
|
||||
|
||||
ui_dump_manifest_writer(call_asr_data, dataset_name, data_source())
|
||||
|
||||
def show_leaderboard():
|
||||
def compute_user_stats(call_stat):
|
||||
n_samples = (
|
||||
lens["calls"].Each()["process"]["num_samples"].get_monoid()(call_stat)
|
||||
)
|
||||
n_duration = lens["calls"].Each()["duration"].get_monoid()(call_stat)
|
||||
return {
|
||||
"num_samples": n_samples,
|
||||
"duration": n_duration.total_seconds(),
|
||||
"samples_rate": n_samples / n_duration.total_seconds(),
|
||||
"duration_str": compress(n_duration, pad=" "),
|
||||
"name": call_stat["name"],
|
||||
}
|
||||
|
||||
call_lens = lens["users"].Each()["calls"].Each()
|
||||
call_stats = call_lens.modify(retrieve_processed_callmeta)(call_logs)
|
||||
user_stats = lens["users"].Each().modify(compute_user_stats)(call_stats)
|
||||
leader_df = (
|
||||
pd.DataFrame(user_stats["users"])
|
||||
.sort_values(by=["duration"], ascending=False)
|
||||
.reset_index(drop=True)
|
||||
)
|
||||
leader_df["rank"] = leader_df.index + 1
|
||||
leader_board = leader_df.rename(
|
||||
columns={
|
||||
"rank": "Rank",
|
||||
"num_samples": "Count",
|
||||
"name": "Name",
|
||||
"samples_rate": "SpeechRate",
|
||||
"duration_str": "Duration",
|
||||
}
|
||||
)[["Rank", "Name", "Count", "Duration"]]
|
||||
print(
|
||||
"""ASR Dataset Leaderboard :
|
||||
---------------------------------"""
|
||||
)
|
||||
print(leader_board.to_string(index=False))
|
||||
|
||||
if download_only:
|
||||
download_meta_audio()
|
||||
return
|
||||
if leaderboard:
|
||||
show_leaderboard()
|
||||
if plot_calls:
|
||||
plot_calls_data()
|
||||
if extract_data:
|
||||
extract_data_points()
|
||||
|
||||
|
||||
def main():
|
||||
app()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,68 +0,0 @@
|
||||
import typer
|
||||
from pathlib import Path
|
||||
from random import randrange
|
||||
from itertools import product
|
||||
from math import floor
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
|
||||
def export_conv_json(
|
||||
conv_src: Path = typer.Option(Path("./conv_data.json"), show_default=True),
|
||||
conv_dest: Path = typer.Option(Path("./data/conv_data.json"), show_default=True),
|
||||
):
|
||||
from .utils import ExtendedPath
|
||||
|
||||
conv_data = ExtendedPath(conv_src).read_json()
|
||||
|
||||
days = [i for i in range(1, 32)]
|
||||
months = [
|
||||
"January",
|
||||
"February",
|
||||
"March",
|
||||
"April",
|
||||
"May",
|
||||
"June",
|
||||
"July",
|
||||
"August",
|
||||
"September",
|
||||
"October",
|
||||
"November",
|
||||
"December",
|
||||
]
|
||||
# ordinal from https://stackoverflow.com/questions/9647202/ordinal-numbers-replacement
|
||||
|
||||
def ordinal(n):
|
||||
return "%d%s" % (
|
||||
n,
|
||||
"tsnrhtdd"[(floor(n / 10) % 10 != 1) * (n % 10 < 4) * n % 10 :: 4],
|
||||
)
|
||||
|
||||
def canon_vars(d, m):
|
||||
return [
|
||||
ordinal(d) + " " + m,
|
||||
m + " " + ordinal(d),
|
||||
ordinal(d) + " of " + m,
|
||||
m + " the " + ordinal(d),
|
||||
str(d) + " " + m,
|
||||
m + " " + str(d),
|
||||
]
|
||||
|
||||
day_months = [dm for d, m in product(days, months) for dm in canon_vars(d, m)]
|
||||
|
||||
conv_data["dates"] = day_months
|
||||
|
||||
def dates_data_gen():
|
||||
i = randrange(len(day_months))
|
||||
return day_months[i]
|
||||
|
||||
ExtendedPath(conv_dest).write_json(conv_data)
|
||||
|
||||
|
||||
def main():
|
||||
app()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -23,7 +23,7 @@ def fixate_data(dataset_path: Path):
|
||||
|
||||
|
||||
@app.command()
|
||||
def augment_datasets(src_dataset_paths: List[Path], dest_dataset_path: Path):
|
||||
def augment_data(src_dataset_paths: List[Path], dest_dataset_path: Path):
|
||||
reader_list = []
|
||||
abs_manifest_path = Path("abs_manifest.json")
|
||||
for dataset_path in src_dataset_paths:
|
||||
@@ -38,9 +38,9 @@ def augment_datasets(src_dataset_paths: List[Path], dest_dataset_path: Path):
|
||||
def split_data(dataset_path: Path, test_size: float = 0.1):
|
||||
manifest_path = dataset_path / Path("abs_manifest.json")
|
||||
asr_data = list(asr_manifest_reader(manifest_path))
|
||||
train_pnr, test_pnr = train_test_split(asr_data, test_size=test_size)
|
||||
asr_manifest_writer(manifest_path.with_name("train_manifest.json"), train_pnr)
|
||||
asr_manifest_writer(manifest_path.with_name("test_manifest.json"), test_pnr)
|
||||
train_data, test_data = train_test_split(asr_data, test_size=test_size)
|
||||
asr_manifest_writer(manifest_path.with_name("train_manifest.json"), train_data)
|
||||
asr_manifest_writer(manifest_path.with_name("test_manifest.json"), test_data)
|
||||
|
||||
|
||||
@app.command()
|
||||
@@ -52,9 +52,9 @@ def validate_data(dataset_path: Path):
|
||||
data_file = dataset_path / Path(mf_type)
|
||||
print(f"validating {data_file}.")
|
||||
with Path(data_file).open("r") as pf:
|
||||
pnr_jsonl = pf.readlines()
|
||||
data_jsonl = pf.readlines()
|
||||
duration = 0
|
||||
for (i, s) in enumerate(pnr_jsonl):
|
||||
for (i, s) in enumerate(data_jsonl):
|
||||
try:
|
||||
d = json.loads(s)
|
||||
duration += d["duration"]
|
||||
|
||||
93
jasper/data/rastrik_recycler.py
Normal file
93
jasper/data/rastrik_recycler.py
Normal file
@@ -0,0 +1,93 @@
|
||||
from rastrik.proto.callrecord_pb2 import CallRecord
|
||||
import gzip
|
||||
from pydub import AudioSegment
|
||||
from .utils import ui_dump_manifest_writer, strip_silence
|
||||
|
||||
import typer
|
||||
from itertools import chain
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
|
||||
def extract_manifest(
|
||||
call_log_dir: Path = Path("./data/call_audio"),
|
||||
output_dir: Path = Path("./data"),
|
||||
dataset_name: str = "grassroot_pizzahut_v1",
|
||||
caller_name: str = "grassroot",
|
||||
verbose: bool = False,
|
||||
):
|
||||
call_asr_data: Path = output_dir / Path("asr_data")
|
||||
call_asr_data.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
def wav_pb2_generator(log_dir):
|
||||
for wav_path in log_dir.glob("**/*.wav"):
|
||||
if verbose:
|
||||
typer.echo(f"loading events for file {wav_path}")
|
||||
call_wav = AudioSegment.from_file_using_temporary_files(wav_path)
|
||||
meta_path = wav_path.with_suffix(".pb2.gz")
|
||||
yield call_wav, wav_path, meta_path
|
||||
|
||||
def read_event(call_wav, log_file):
|
||||
call_wav_0, call_wav_1 = call_wav.split_to_mono()
|
||||
with gzip.open(log_file, "rb") as log_h:
|
||||
record_data = log_h.read()
|
||||
cr = CallRecord()
|
||||
cr.ParseFromString(record_data)
|
||||
|
||||
first_audio_event_timestamp = next(
|
||||
(
|
||||
i
|
||||
for i in cr.events
|
||||
if i.WhichOneof("event_type") == "call_event"
|
||||
and i.call_event.WhichOneof("event_type") == "call_audio"
|
||||
)
|
||||
).timestamp.ToDatetime()
|
||||
|
||||
speech_events = [
|
||||
i
|
||||
for i in cr.events
|
||||
if i.WhichOneof("event_type") == "speech_event"
|
||||
and i.speech_event.WhichOneof("event_type") == "asr_final"
|
||||
]
|
||||
previous_event_timestamp = (
|
||||
first_audio_event_timestamp - first_audio_event_timestamp
|
||||
)
|
||||
for index, each_speech_events in enumerate(speech_events):
|
||||
asr_final = each_speech_events.speech_event.asr_final
|
||||
speech_timestamp = each_speech_events.timestamp.ToDatetime()
|
||||
actual_timestamp = speech_timestamp - first_audio_event_timestamp
|
||||
start_time = previous_event_timestamp.total_seconds() * 1000
|
||||
end_time = actual_timestamp.total_seconds() * 1000
|
||||
audio_segment = strip_silence(call_wav_1[start_time:end_time])
|
||||
|
||||
code_fb = BytesIO()
|
||||
audio_segment.export(code_fb, format="wav")
|
||||
wav_data = code_fb.getvalue()
|
||||
previous_event_timestamp = actual_timestamp
|
||||
duration = (end_time - start_time) / 1000
|
||||
yield asr_final, duration, wav_data, "grassroot", audio_segment
|
||||
|
||||
def generate_call_asr_data():
|
||||
full_data = []
|
||||
total_duration = 0
|
||||
for wav, wav_path, pb2_path in wav_pb2_generator(call_log_dir):
|
||||
asr_data = read_event(wav, pb2_path)
|
||||
total_duration += wav.duration_seconds
|
||||
full_data.append(asr_data)
|
||||
n_calls = len(full_data)
|
||||
typer.echo(f"loaded {n_calls} calls of duration {total_duration}s")
|
||||
n_dps = ui_dump_manifest_writer(call_asr_data, dataset_name, chain(*full_data))
|
||||
typer.echo(f"written {n_dps} data points")
|
||||
|
||||
generate_call_asr_data()
|
||||
|
||||
|
||||
def main():
|
||||
app()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,175 +0,0 @@
|
||||
import typer
|
||||
from itertools import chain
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
import re
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
|
||||
def extract_data(
|
||||
call_audio_dir: Path = typer.Option(Path("/dataset/rev/wavs"), show_default=True),
|
||||
call_meta_dir: Path = typer.Option(Path("/dataset/rev/jsons"), show_default=True),
|
||||
output_dir: Path = typer.Option(Path("./data"), show_default=True),
|
||||
dataset_name: str = typer.Option("rev_transribed", show_default=True),
|
||||
verbose: bool = False,
|
||||
):
|
||||
from pydub import AudioSegment
|
||||
from .utils import ExtendedPath, asr_data_writer, strip_silence
|
||||
from lenses import lens
|
||||
import datetime
|
||||
|
||||
call_asr_data: Path = output_dir / Path("asr_data")
|
||||
call_asr_data.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
def wav_event_generator(call_audio_dir):
|
||||
for wav_path in call_audio_dir.glob("**/*.wav"):
|
||||
if verbose:
|
||||
typer.echo(f"loading events for file {wav_path}")
|
||||
call_wav = AudioSegment.from_file_using_temporary_files(wav_path)
|
||||
rel_meta_path = wav_path.with_suffix(".json").relative_to(call_audio_dir)
|
||||
meta_path = call_meta_dir / rel_meta_path
|
||||
if meta_path.exists():
|
||||
events = ExtendedPath(meta_path).read_json()
|
||||
yield call_wav, wav_path, events
|
||||
else:
|
||||
if verbose:
|
||||
typer.echo(f"missing json corresponding to {wav_path}")
|
||||
|
||||
def contains_asr(x):
|
||||
return "AsrResult" in x
|
||||
|
||||
def channel(n):
|
||||
def filter_func(ev):
|
||||
return (
|
||||
ev["AsrResult"]["Channel"] == n
|
||||
if "Channel" in ev["AsrResult"]
|
||||
else n == 0
|
||||
)
|
||||
|
||||
return filter_func
|
||||
|
||||
def time_to_msecs(time_str):
|
||||
return (
|
||||
datetime.datetime.strptime(time_str, "%H:%M:%S,%f")
|
||||
- datetime.datetime(1900, 1, 1)
|
||||
).total_seconds() * 1000
|
||||
|
||||
def process_utterance_chunk(wav_seg, start_time, end_time, monologue):
|
||||
# offset by 1sec left side to include vad? discarded audio
|
||||
full_tscript_wav_seg = wav_seg[
|
||||
time_to_msecs(start_time) - 1000 : time_to_msecs(end_time) # + 1000
|
||||
]
|
||||
tscript_wav_seg = strip_silence(full_tscript_wav_seg)
|
||||
tscript_wav_fb = BytesIO()
|
||||
tscript_wav_seg.export(tscript_wav_fb, format="wav")
|
||||
tscript_wav = tscript_wav_fb.getvalue()
|
||||
text = "".join(lens["elements"].Each()["value"].collect()(monologue))
|
||||
text_clean = re.sub(r"\[.*\]", "", text)
|
||||
return tscript_wav, tscript_wav_seg.duration_seconds, text_clean
|
||||
|
||||
def dual_asr_data_generator(wav_seg, wav_path, meta):
|
||||
left_audio, right_audio = wav_seg.split_to_mono()
|
||||
channel_map = {"Agent": right_audio, "Client": left_audio}
|
||||
monologues = lens["monologues"].Each().collect()(meta)
|
||||
for monologue in monologues:
|
||||
# print(monologue["speaker_name"])
|
||||
speaker_channel = channel_map.get(monologue["speaker_name"])
|
||||
if not speaker_channel:
|
||||
if verbose:
|
||||
print(
|
||||
f'unknown speaker tag {monologue["speaker_name"]} in wav:{wav_path} skipping.'
|
||||
)
|
||||
continue
|
||||
try:
|
||||
start_time = (
|
||||
lens["elements"]
|
||||
.Each()
|
||||
.Filter(lambda x: "timestamp" in x)["timestamp"]
|
||||
.collect()(monologue)[0]
|
||||
)
|
||||
end_time = (
|
||||
lens["elements"]
|
||||
.Each()
|
||||
.Filter(lambda x: "end_timestamp" in x)["end_timestamp"]
|
||||
.collect()(monologue)[-1]
|
||||
)
|
||||
except IndexError:
|
||||
if verbose:
|
||||
print(
|
||||
f"error when loading timestamp events in wav:{wav_path} skipping."
|
||||
)
|
||||
continue
|
||||
tscript_wav, seg_dur, text_clean = process_utterance_chunk(
|
||||
speaker_channel, start_time, end_time, monologue
|
||||
)
|
||||
if seg_dur < 0.5:
|
||||
if verbose:
|
||||
print(
|
||||
f'transcript chunk "{text_clean}" contains no audio in {wav_path} skipping.'
|
||||
)
|
||||
continue
|
||||
yield text_clean, seg_dur, tscript_wav
|
||||
|
||||
def mono_asr_data_generator(wav_seg, wav_path, meta):
|
||||
monologues = lens["monologues"].Each().collect()(meta)
|
||||
for monologue in monologues:
|
||||
try:
|
||||
start_time = (
|
||||
lens["elements"]
|
||||
.Each()
|
||||
.Filter(lambda x: "timestamp" in x)["timestamp"]
|
||||
.collect()(monologue)[0]
|
||||
)
|
||||
end_time = (
|
||||
lens["elements"]
|
||||
.Each()
|
||||
.Filter(lambda x: "end_timestamp" in x)["end_timestamp"]
|
||||
.collect()(monologue)[-1]
|
||||
)
|
||||
except IndexError:
|
||||
if verbose:
|
||||
print(
|
||||
f"error when loading timestamp events in wav:{wav_path} skipping."
|
||||
)
|
||||
continue
|
||||
|
||||
tscript_wav, seg_dur, text_clean = process_utterance_chunk(
|
||||
wav_seg, start_time, end_time, monologue
|
||||
)
|
||||
if seg_dur < 0.5:
|
||||
if verbose:
|
||||
print(
|
||||
f'transcript chunk "{text_clean}" contains no audio in {wav_path} skipping.'
|
||||
)
|
||||
continue
|
||||
yield text_clean, seg_dur, tscript_wav
|
||||
|
||||
def generate_rev_asr_data():
|
||||
full_asr_data = []
|
||||
total_duration = 0
|
||||
for wav, wav_path, ev in wav_event_generator(call_audio_dir):
|
||||
if wav.channels > 2:
|
||||
print(f"skipping many channel audio {wav_path}")
|
||||
asr_data_generator = (
|
||||
mono_asr_data_generator
|
||||
if wav.channels == 1
|
||||
else dual_asr_data_generator
|
||||
)
|
||||
asr_data = asr_data_generator(wav, wav_path, ev)
|
||||
total_duration += wav.duration_seconds
|
||||
full_asr_data.append(asr_data)
|
||||
typer.echo(f"loaded {len(full_asr_data)} calls of duration {total_duration}s")
|
||||
n_dps = asr_data_writer(call_asr_data, dataset_name, chain(*full_asr_data))
|
||||
typer.echo(f"written {n_dps} data points")
|
||||
|
||||
generate_rev_asr_data()
|
||||
|
||||
|
||||
def main():
|
||||
app()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,52 +0,0 @@
|
||||
from logging import getLogger
|
||||
from google.cloud import texttospeech
|
||||
|
||||
LOGGER = getLogger("googletts")
|
||||
|
||||
|
||||
class GoogleTTS(object):
|
||||
def __init__(self):
|
||||
self.client = texttospeech.TextToSpeechClient()
|
||||
|
||||
def text_to_speech(self, text: str, params: dict) -> bytes:
|
||||
tts_input = texttospeech.types.SynthesisInput(ssml=text)
|
||||
voice = texttospeech.types.VoiceSelectionParams(
|
||||
language_code=params["language"], name=params["name"]
|
||||
)
|
||||
audio_config = texttospeech.types.AudioConfig(
|
||||
audio_encoding=texttospeech.enums.AudioEncoding.LINEAR16,
|
||||
sample_rate_hertz=params["sample_rate"],
|
||||
)
|
||||
response = self.client.synthesize_speech(tts_input, voice, audio_config)
|
||||
audio_content = response.audio_content
|
||||
return audio_content
|
||||
|
||||
@classmethod
|
||||
def voice_list(cls):
|
||||
"""Lists the available voices."""
|
||||
|
||||
client = cls().client
|
||||
|
||||
# Performs the list voices request
|
||||
voices = client.list_voices()
|
||||
results = []
|
||||
for voice in voices.voices:
|
||||
supported_eng_langs = [
|
||||
lang for lang in voice.language_codes if lang[:2] == "en"
|
||||
]
|
||||
if len(supported_eng_langs) > 0:
|
||||
lang = ",".join(supported_eng_langs)
|
||||
else:
|
||||
continue
|
||||
|
||||
ssml_gender = texttospeech.enums.SsmlVoiceGender(voice.ssml_gender)
|
||||
results.append(
|
||||
{
|
||||
"name": voice.name,
|
||||
"language": lang,
|
||||
"gender": ssml_gender.name,
|
||||
"engine": "wavenet" if "Wav" in voice.name else "standard",
|
||||
"sample_rate": voice.natural_sample_rate_hertz,
|
||||
}
|
||||
)
|
||||
return results
|
||||
@@ -1,26 +0,0 @@
|
||||
"""
|
||||
TTSClient Abstract Class
|
||||
"""
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class TTSClient(ABC):
|
||||
"""
|
||||
Base class for TTS
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def text_to_speech(self, text: str, num_channels: int, sample_rate: int,
|
||||
audio_encoding) -> bytes:
|
||||
"""
|
||||
convert text to bytes
|
||||
|
||||
Arguments:
|
||||
text {[type]} -- text to convert
|
||||
channel {[type]} -- output audio bytes channel setting
|
||||
width {[type]} -- width of audio bytes
|
||||
rate {[type]} -- rare for audio bytes
|
||||
|
||||
Returns:
|
||||
[type] -- [description]
|
||||
"""
|
||||
@@ -1,62 +0,0 @@
|
||||
# import io
|
||||
# import sys
|
||||
# import json
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from .utils import random_pnr_generator, asr_data_writer
|
||||
from .tts.googletts import GoogleTTS
|
||||
from tqdm import tqdm
|
||||
import random
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def pnr_tts_streamer(count):
|
||||
google_voices = GoogleTTS.voice_list()
|
||||
gtts = GoogleTTS()
|
||||
for pnr_code in tqdm(random_pnr_generator(count)):
|
||||
tts_code = f'<speak><say-as interpret-as="verbatim">{pnr_code}</say-as></speak>'
|
||||
param = random.choice(google_voices)
|
||||
param["sample_rate"] = 24000
|
||||
param["num_channels"] = 1
|
||||
wav_data = gtts.text_to_speech(text=tts_code, params=param)
|
||||
audio_dur = len(wav_data[44:]) / (2 * 24000)
|
||||
yield pnr_code, audio_dur, wav_data
|
||||
|
||||
|
||||
def generate_asr_data_fromtts(output_dir, dataset_name, count):
|
||||
asr_data_writer(output_dir, dataset_name, pnr_tts_streamer(count))
|
||||
|
||||
|
||||
def arg_parser():
|
||||
prog = Path(__file__).stem
|
||||
parser = argparse.ArgumentParser(
|
||||
prog=prog, description=f"generates asr training data"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
type=Path,
|
||||
default=Path("./train/asr_data"),
|
||||
help="directory to output asr data",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--count", type=int, default=3, help="number of datapoints to generate"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset_name", type=str, default="pnr_data", help="name of the dataset"
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def main():
|
||||
parser = arg_parser()
|
||||
args = parser.parse_args()
|
||||
generate_asr_data_fromtts(**vars(args))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,92 +0,0 @@
|
||||
import pandas as pd
|
||||
|
||||
|
||||
def compute_pnr_name_city():
|
||||
data = pd.read_csv("./customer_utterance_processing/customer_provide_answer.csv")
|
||||
|
||||
def unique_pnr_count():
|
||||
pnr_data = data[data["Input.Answer"] == "ZZZZZZ"]
|
||||
unique_pnr_set = {
|
||||
t
|
||||
for n in range(1, 5)
|
||||
for t in pnr_data[f"Answer.utterance-{n}"].tolist()
|
||||
if "ZZZZZZ" in t
|
||||
}
|
||||
return len(unique_pnr_set)
|
||||
|
||||
def unique_name_count():
|
||||
pnr_data = data[data["Input.Answer"] == "John Doe"]
|
||||
unique_pnr_set = {
|
||||
t
|
||||
for n in range(1, 5)
|
||||
for t in pnr_data[f"Answer.utterance-{n}"].tolist()
|
||||
if "John Doe" in t
|
||||
}
|
||||
return len(unique_pnr_set)
|
||||
|
||||
def unique_city_count():
|
||||
pnr_data = data[data["Input.Answer"] == "Heathrow Airport"]
|
||||
unique_pnr_set = {
|
||||
t
|
||||
for n in range(1, 5)
|
||||
for t in pnr_data[f"Answer.utterance-{n}"].tolist()
|
||||
if "Heathrow Airport" in t
|
||||
}
|
||||
return len(unique_pnr_set)
|
||||
|
||||
def unique_entity_count(entity_template_tags):
|
||||
# entity_data = data[data['Input.Prompt'] == entity_template_tag]
|
||||
entity_data = data
|
||||
unique_entity_set = {
|
||||
t
|
||||
for n in range(1, 5)
|
||||
for t in entity_data[f"Answer.utterance-{n}"].tolist()
|
||||
if any(et in t for et in entity_template_tags)
|
||||
}
|
||||
return len(unique_entity_set)
|
||||
|
||||
print('PNR', unique_pnr_count())
|
||||
print('Name', unique_name_count())
|
||||
print('City', unique_city_count())
|
||||
print('Payment', unique_entity_count(['KPay', 'ZPay', 'Credit Card']))
|
||||
|
||||
|
||||
def compute_date():
|
||||
entity_template_tags = ['27 january', 'December 18']
|
||||
data = pd.read_csv("./customer_utterance_processing/customer_provide_departure.csv")
|
||||
# data.sample(10)
|
||||
|
||||
def unique_entity_count(entity_template_tags):
|
||||
# entity_data = data[data['Input.Prompt'] == entity_template_tag]
|
||||
entity_data = data
|
||||
unique_entity_set = {
|
||||
t
|
||||
for n in range(1, 5)
|
||||
for t in entity_data[f"Answer.utterance-{n}"].tolist()
|
||||
if any(et in t for et in entity_template_tags)
|
||||
}
|
||||
return len(unique_entity_set)
|
||||
|
||||
print('Date', unique_entity_count(entity_template_tags))
|
||||
|
||||
|
||||
def compute_option():
|
||||
entity_template_tag = 'third'
|
||||
data = pd.read_csv("./customer_utterance_processing/customer_provide_flight_selection.csv")
|
||||
|
||||
def unique_entity_count():
|
||||
entity_data = data[data['Input.Prompt'] == entity_template_tag]
|
||||
unique_entity_set = {
|
||||
t
|
||||
for n in range(1, 5)
|
||||
for t in entity_data[f"Answer.utterance-{n}"].tolist()
|
||||
if entity_template_tag in t
|
||||
}
|
||||
return len(unique_entity_set)
|
||||
|
||||
print('Option', unique_entity_count())
|
||||
|
||||
|
||||
compute_pnr_name_city()
|
||||
compute_date()
|
||||
compute_option()
|
||||
@@ -1,22 +1,20 @@
|
||||
import numpy as np
|
||||
import wave
|
||||
import io
|
||||
import os
|
||||
import json
|
||||
import wave
|
||||
from pathlib import Path
|
||||
from functools import partial
|
||||
from uuid import uuid4
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
import pymongo
|
||||
from slugify import slugify
|
||||
from uuid import uuid4
|
||||
from num2words import num2words
|
||||
from jasper.client import transcribe_gen
|
||||
from nemo.collections.asr.metrics import word_error_rate
|
||||
import matplotlib.pyplot as plt
|
||||
import librosa
|
||||
import librosa.display
|
||||
from tqdm import tqdm
|
||||
from functools import partial
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
|
||||
def manifest_str(path, dur, text):
|
||||
@@ -36,27 +34,8 @@ def wav_bytes(audio_bytes, frame_rate=24000):
|
||||
return wf_b.getvalue()
|
||||
|
||||
|
||||
def random_pnr_generator(count=10000):
|
||||
LENGTH = 3
|
||||
|
||||
# alphabet = list("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")
|
||||
alphabet = list("ABCDEFGHIJKLMNOPQRSTUVWXYZ")
|
||||
numeric = list("0123456789")
|
||||
np_alphabet = np.array(alphabet, dtype="|S1")
|
||||
np_numeric = np.array(numeric, dtype="|S1")
|
||||
np_alpha_codes = np.random.choice(np_alphabet, [count, LENGTH])
|
||||
np_num_codes = np.random.choice(np_numeric, [count, LENGTH])
|
||||
np_code_seed = np.concatenate((np_alpha_codes, np_num_codes), axis=1).T
|
||||
np.random.shuffle(np_code_seed)
|
||||
np_codes = np_code_seed.T
|
||||
codes = [(b"".join(np_codes[i])).decode("utf-8") for i in range(len(np_codes))]
|
||||
return codes
|
||||
|
||||
|
||||
def alnum_to_asr_tokens(text):
|
||||
letters = " ".join(list(text))
|
||||
num_tokens = [num2words(c) if "0" <= c <= "9" else c for c in letters]
|
||||
return ("".join(num_tokens)).lower()
|
||||
def tscript_uuid_fname(transcript):
|
||||
return str(uuid4()) + "_" + slugify(transcript, max_length=8)
|
||||
|
||||
|
||||
def asr_data_writer(output_dir, dataset_name, asr_data_source, verbose=False):
|
||||
@@ -67,11 +46,11 @@ def asr_data_writer(output_dir, dataset_name, asr_data_source, verbose=False):
|
||||
with asr_manifest.open("w") as mf:
|
||||
print(f"writing manifest to {asr_manifest}")
|
||||
for transcript, audio_dur, wav_data in asr_data_source:
|
||||
fname = str(uuid4()) + "_" + slugify(transcript, max_length=8)
|
||||
fname = tscript_uuid_fname(transcript)
|
||||
audio_file = dataset_dir / Path("wav") / Path(fname).with_suffix(".wav")
|
||||
audio_file.write_bytes(wav_data)
|
||||
rel_pnr_path = audio_file.relative_to(dataset_dir)
|
||||
manifest = manifest_str(str(rel_pnr_path), audio_dur, transcript)
|
||||
rel_data_path = audio_file.relative_to(dataset_dir)
|
||||
manifest = manifest_str(str(rel_data_path), audio_dur, transcript)
|
||||
mf.write(manifest)
|
||||
if verbose:
|
||||
print(f"writing '{transcript}' of duration {audio_dur}")
|
||||
@@ -79,102 +58,99 @@ def asr_data_writer(output_dir, dataset_name, asr_data_source, verbose=False):
|
||||
return num_datapoints
|
||||
|
||||
|
||||
def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=False):
|
||||
def ui_data_generator(output_dir, dataset_name, asr_data_source, verbose=False):
|
||||
dataset_dir = output_dir / Path(dataset_name)
|
||||
(dataset_dir / Path("wav")).mkdir(parents=True, exist_ok=True)
|
||||
ui_dump_file = dataset_dir / Path("ui_dump.json")
|
||||
(dataset_dir / Path("wav_plots")).mkdir(parents=True, exist_ok=True)
|
||||
asr_manifest = dataset_dir / Path("manifest.json")
|
||||
|
||||
def data_fn(
|
||||
transcript,
|
||||
audio_dur,
|
||||
wav_data,
|
||||
caller_name,
|
||||
aud_seg,
|
||||
fname,
|
||||
audio_path,
|
||||
num_datapoints,
|
||||
rel_data_path,
|
||||
):
|
||||
pretrained_result = transcriber_pretrained(aud_seg.raw_data)
|
||||
pretrained_wer = word_error_rate([transcript], [pretrained_result])
|
||||
png_path = Path(fname).with_suffix(".png")
|
||||
wav_plot_path = dataset_dir / Path("wav_plots") / png_path
|
||||
if not wav_plot_path.exists():
|
||||
plot_seg(wav_plot_path, audio_path)
|
||||
return {
|
||||
"audio_filepath": str(rel_data_path),
|
||||
"duration": round(audio_dur, 1),
|
||||
"text": transcript,
|
||||
"real_idx": num_datapoints,
|
||||
"audio_path": audio_path,
|
||||
"spoken": transcript,
|
||||
"caller": caller_name,
|
||||
"utterance_id": fname,
|
||||
"pretrained_asr": pretrained_result,
|
||||
"pretrained_wer": pretrained_wer,
|
||||
"plot_path": str(wav_plot_path),
|
||||
}
|
||||
|
||||
num_datapoints = 0
|
||||
ui_dump = {
|
||||
"use_domain_asr": False,
|
||||
"annotation_only": False,
|
||||
"enable_plots": True,
|
||||
"data": [],
|
||||
}
|
||||
data_funcs = []
|
||||
transcriber_pretrained = transcribe_gen(asr_port=8044)
|
||||
with asr_manifest.open("w") as mf:
|
||||
print(f"writing manifest to {asr_manifest}")
|
||||
|
||||
def data_fn(
|
||||
transcript,
|
||||
audio_dur,
|
||||
wav_data,
|
||||
caller_name,
|
||||
aud_seg,
|
||||
fname,
|
||||
audio_path,
|
||||
num_datapoints,
|
||||
rel_pnr_path,
|
||||
):
|
||||
pretrained_result = transcriber_pretrained(aud_seg.raw_data)
|
||||
pretrained_wer = word_error_rate([transcript], [pretrained_result])
|
||||
wav_plot_path = (
|
||||
dataset_dir / Path("wav_plots") / Path(fname).with_suffix(".png")
|
||||
)
|
||||
if not wav_plot_path.exists():
|
||||
plot_seg(wav_plot_path, audio_path)
|
||||
return {
|
||||
"audio_filepath": str(rel_pnr_path),
|
||||
"duration": round(audio_dur, 1),
|
||||
"text": transcript,
|
||||
"real_idx": num_datapoints,
|
||||
"audio_path": audio_path,
|
||||
"spoken": transcript,
|
||||
"caller": caller_name,
|
||||
"utterance_id": fname,
|
||||
"pretrained_asr": pretrained_result,
|
||||
"pretrained_wer": pretrained_wer,
|
||||
"plot_path": str(wav_plot_path),
|
||||
}
|
||||
|
||||
for transcript, audio_dur, wav_data, caller_name, aud_seg in asr_data_source:
|
||||
fname = str(uuid4()) + "_" + slugify(transcript, max_length=8)
|
||||
audio_file = dataset_dir / Path("wav") / Path(fname).with_suffix(".wav")
|
||||
audio_file.write_bytes(wav_data)
|
||||
audio_path = str(audio_file)
|
||||
rel_pnr_path = audio_file.relative_to(dataset_dir)
|
||||
manifest = manifest_str(str(rel_pnr_path), audio_dur, transcript)
|
||||
mf.write(manifest)
|
||||
data_funcs.append(
|
||||
partial(
|
||||
data_fn,
|
||||
transcript,
|
||||
audio_dur,
|
||||
wav_data,
|
||||
caller_name,
|
||||
aud_seg,
|
||||
fname,
|
||||
audio_path,
|
||||
num_datapoints,
|
||||
rel_pnr_path,
|
||||
)
|
||||
)
|
||||
num_datapoints += 1
|
||||
with ThreadPoolExecutor() as exe:
|
||||
print("starting all plot/transcription tasks")
|
||||
dump_data = list(
|
||||
tqdm(
|
||||
exe.map(lambda x: x(), data_funcs),
|
||||
position=0,
|
||||
leave=True,
|
||||
total=len(data_funcs),
|
||||
for transcript, audio_dur, wav_data, caller_name, aud_seg in asr_data_source:
|
||||
fname = str(uuid4()) + "_" + slugify(transcript, max_length=8)
|
||||
audio_file = dataset_dir / Path("wav") / Path(fname).with_suffix(".wav")
|
||||
audio_file.write_bytes(wav_data)
|
||||
audio_path = str(audio_file)
|
||||
rel_data_path = audio_file.relative_to(dataset_dir)
|
||||
data_funcs.append(
|
||||
partial(
|
||||
data_fn,
|
||||
transcript,
|
||||
audio_dur,
|
||||
wav_data,
|
||||
caller_name,
|
||||
aud_seg,
|
||||
fname,
|
||||
audio_path,
|
||||
num_datapoints,
|
||||
rel_data_path,
|
||||
)
|
||||
)
|
||||
ui_dump["data"] = dump_data
|
||||
ExtendedPath(ui_dump_file).write_json(ui_dump)
|
||||
num_datapoints += 1
|
||||
ui_data = parallel_apply(lambda x: x(), data_funcs)
|
||||
return ui_data, num_datapoints
|
||||
|
||||
|
||||
def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=False):
|
||||
dataset_dir = output_dir / Path(dataset_name)
|
||||
dump_data, num_datapoints = ui_data_generator(
|
||||
output_dir, dataset_name, asr_data_source, verbose=verbose
|
||||
)
|
||||
|
||||
asr_manifest = dataset_dir / Path("manifest.json")
|
||||
with asr_manifest.open("w") as mf:
|
||||
print(f"writing manifest to {asr_manifest}")
|
||||
for d in dump_data:
|
||||
rel_data_path = d["audio_filepath"]
|
||||
audio_dur = d["duration"]
|
||||
transcript = d["text"]
|
||||
manifest = manifest_str(str(rel_data_path), audio_dur, transcript)
|
||||
mf.write(manifest)
|
||||
|
||||
ui_dump_file = dataset_dir / Path("ui_dump.json")
|
||||
ExtendedPath(ui_dump_file).write_json({"data": dump_data})
|
||||
return num_datapoints
|
||||
|
||||
|
||||
def asr_manifest_reader(data_manifest_path: Path):
|
||||
print(f"reading manifest from {data_manifest_path}")
|
||||
with data_manifest_path.open("r") as pf:
|
||||
pnr_jsonl = pf.readlines()
|
||||
pnr_data = [json.loads(v) for v in pnr_jsonl]
|
||||
for p in pnr_data:
|
||||
data_jsonl = pf.readlines()
|
||||
data_data = [json.loads(v) for v in data_jsonl]
|
||||
for p in data_data:
|
||||
p["audio_path"] = data_manifest_path.parent / Path(p["audio_filepath"])
|
||||
p["chars"] = Path(p["audio_filepath"]).stem
|
||||
p["text"] = p["text"].strip()
|
||||
yield p
|
||||
|
||||
|
||||
@@ -188,6 +164,32 @@ def asr_manifest_writer(asr_manifest_path: Path, manifest_str_source):
|
||||
mf.write(manifest)
|
||||
|
||||
|
||||
def asr_test_writer(out_file_path: Path, source):
|
||||
def dd_str(dd, idx):
|
||||
path = dd["audio_filepath"]
|
||||
# dur = dd["duration"]
|
||||
# return f"SAY {idx}\nPAUSE 3\nPLAY {path}\nPAUSE 3\n\n"
|
||||
return f"PAUSE 2\nPLAY {path}\nPAUSE 60\n\n"
|
||||
|
||||
res_file = out_file_path.with_suffix(".result.json")
|
||||
with out_file_path.open("w") as of:
|
||||
print(f"opening {out_file_path} for writing test")
|
||||
results = []
|
||||
idx = 0
|
||||
for ui_dd in source:
|
||||
results.append(ui_dd)
|
||||
out_str = dd_str(ui_dd, idx)
|
||||
of.write(out_str)
|
||||
idx += 1
|
||||
of.write("DO_HANGUP\n")
|
||||
ExtendedPath(res_file).write_json(results)
|
||||
|
||||
|
||||
def batch(iterable, n=1):
|
||||
ls = len(iterable)
|
||||
return [iterable[ndx : min(ndx + n, ls)] for ndx in range(0, ls, n)]
|
||||
|
||||
|
||||
class ExtendedPath(type(Path())):
|
||||
"""docstring for ExtendedPath."""
|
||||
|
||||
@@ -203,12 +205,6 @@ class ExtendedPath(type(Path())):
|
||||
return json.dump(data, jf, indent=2)
|
||||
|
||||
|
||||
def get_mongo_coll(uri="mongodb://localhost:27017/test.calls"):
|
||||
ud = pymongo.uri_parser.parse_uri(uri)
|
||||
conn = pymongo.MongoClient(uri)
|
||||
return conn[ud["database"]][ud["collection"]]
|
||||
|
||||
|
||||
def get_mongo_conn(host="", port=27017, db="test", col="calls"):
|
||||
mongo_host = host if host else os.environ.get("MONGO_HOST", "localhost")
|
||||
mongo_uri = f"mongodb://{mongo_host}:{port}/"
|
||||
@@ -234,10 +230,12 @@ def plot_seg(wav_plot_path, audio_path):
|
||||
fig.savefig(wav_plot_f, format="png", dpi=50)
|
||||
|
||||
|
||||
def main():
|
||||
for c in random_pnr_generator():
|
||||
print(c)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
def parallel_apply(fn, iterable, workers=8):
|
||||
with ThreadPoolExecutor(max_workers=workers) as exe:
|
||||
print(f"parallelly applying {fn}")
|
||||
return [
|
||||
res
|
||||
for res in tqdm(
|
||||
exe.map(fn, iterable), position=0, leave=True, total=len(iterable)
|
||||
)
|
||||
]
|
||||
|
||||
@@ -1,16 +1,14 @@
|
||||
import json
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from enum import Enum
|
||||
|
||||
import typer
|
||||
from tqdm import tqdm
|
||||
|
||||
from ..utils import (
|
||||
alnum_to_asr_tokens,
|
||||
ExtendedPath,
|
||||
asr_manifest_reader,
|
||||
asr_manifest_writer,
|
||||
tscript_uuid_fname,
|
||||
get_mongo_conn,
|
||||
plot_seg,
|
||||
)
|
||||
@@ -18,9 +16,7 @@ from ..utils import (
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
def preprocess_datapoint(
|
||||
idx, rel_root, sample, use_domain_asr, annotation_only, enable_plots
|
||||
):
|
||||
def preprocess_datapoint(idx, rel_root, sample):
|
||||
from pydub import AudioSegment
|
||||
from nemo.collections.asr.metrics import word_error_rate
|
||||
from jasper.client import transcribe_gen
|
||||
@@ -30,37 +26,23 @@ def preprocess_datapoint(
|
||||
res["real_idx"] = idx
|
||||
audio_path = rel_root / Path(sample["audio_filepath"])
|
||||
res["audio_path"] = str(audio_path)
|
||||
if use_domain_asr:
|
||||
res["spoken"] = alnum_to_asr_tokens(res["text"])
|
||||
else:
|
||||
res["spoken"] = res["text"]
|
||||
res["utterance_id"] = audio_path.stem
|
||||
if not annotation_only:
|
||||
transcriber_pretrained = transcribe_gen(asr_port=8044)
|
||||
transcriber_pretrained = transcribe_gen(asr_port=8044)
|
||||
|
||||
aud_seg = (
|
||||
AudioSegment.from_file_using_temporary_files(audio_path)
|
||||
.set_channels(1)
|
||||
.set_sample_width(2)
|
||||
.set_frame_rate(24000)
|
||||
)
|
||||
res["pretrained_asr"] = transcriber_pretrained(aud_seg.raw_data)
|
||||
res["pretrained_wer"] = word_error_rate(
|
||||
[res["text"]], [res["pretrained_asr"]]
|
||||
)
|
||||
if use_domain_asr:
|
||||
transcriber_speller = transcribe_gen(asr_port=8045)
|
||||
res["domain_asr"] = transcriber_speller(aud_seg.raw_data)
|
||||
res["domain_wer"] = word_error_rate(
|
||||
[res["spoken"]], [res["pretrained_asr"]]
|
||||
)
|
||||
if enable_plots:
|
||||
wav_plot_path = (
|
||||
rel_root / Path("wav_plots") / Path(audio_path.name).with_suffix(".png")
|
||||
)
|
||||
if not wav_plot_path.exists():
|
||||
plot_seg(wav_plot_path, audio_path)
|
||||
res["plot_path"] = str(wav_plot_path)
|
||||
aud_seg = (
|
||||
AudioSegment.from_file_using_temporary_files(audio_path)
|
||||
.set_channels(1)
|
||||
.set_sample_width(2)
|
||||
.set_frame_rate(24000)
|
||||
)
|
||||
res["pretrained_asr"] = transcriber_pretrained(aud_seg.raw_data)
|
||||
res["pretrained_wer"] = word_error_rate([res["text"]], [res["pretrained_asr"]])
|
||||
wav_plot_path = (
|
||||
rel_root / Path("wav_plots") / Path(audio_path.name).with_suffix(".png")
|
||||
)
|
||||
if not wav_plot_path.exists():
|
||||
plot_seg(wav_plot_path, audio_path)
|
||||
res["plot_path"] = str(wav_plot_path)
|
||||
return res
|
||||
except BaseException as e:
|
||||
print(f'failed on {idx}: {sample["audio_filepath"]} with {e}')
|
||||
@@ -68,70 +50,59 @@ def preprocess_datapoint(
|
||||
|
||||
@app.command()
|
||||
def dump_ui(
|
||||
data_name: str = typer.Option("call_alphanum", show_default=True),
|
||||
data_name: str = typer.Option("dataname", show_default=True),
|
||||
dataset_dir: Path = Path("./data/asr_data"),
|
||||
dump_dir: Path = Path("./data/valiation_data"),
|
||||
dump_fname: Path = typer.Option(Path("ui_dump.json"), show_default=True),
|
||||
use_domain_asr: bool = False,
|
||||
annotation_only: bool = False,
|
||||
enable_plots: bool = True,
|
||||
):
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from functools import partial
|
||||
from io import BytesIO
|
||||
from pydub import AudioSegment
|
||||
from ..utils import ui_data_generator
|
||||
|
||||
data_manifest_path = dataset_dir / Path(data_name) / Path("manifest.json")
|
||||
dump_path: Path = dump_dir / Path(data_name) / dump_fname
|
||||
plot_dir = data_manifest_path.parent / Path("wav_plots")
|
||||
plot_dir.mkdir(parents=True, exist_ok=True)
|
||||
typer.echo(f"Using data manifest:{data_manifest_path}")
|
||||
with data_manifest_path.open("r") as pf:
|
||||
pnr_jsonl = pf.readlines()
|
||||
pnr_funcs = [
|
||||
partial(
|
||||
preprocess_datapoint,
|
||||
i,
|
||||
data_manifest_path.parent,
|
||||
json.loads(v),
|
||||
use_domain_asr,
|
||||
annotation_only,
|
||||
enable_plots,
|
||||
)
|
||||
for i, v in enumerate(pnr_jsonl)
|
||||
]
|
||||
|
||||
def exec_func(f):
|
||||
return f()
|
||||
def asr_data_source_gen():
|
||||
with data_manifest_path.open("r") as pf:
|
||||
data_jsonl = pf.readlines()
|
||||
for v in data_jsonl:
|
||||
sample = json.loads(v)
|
||||
rel_root = data_manifest_path.parent
|
||||
res = dict(sample)
|
||||
audio_path = rel_root / Path(sample["audio_filepath"])
|
||||
audio_segment = (
|
||||
AudioSegment.from_file_using_temporary_files(audio_path)
|
||||
.set_channels(1)
|
||||
.set_sample_width(2)
|
||||
.set_frame_rate(24000)
|
||||
)
|
||||
wav_plot_path = (
|
||||
rel_root
|
||||
/ Path("wav_plots")
|
||||
/ Path(audio_path.name).with_suffix(".png")
|
||||
)
|
||||
if not wav_plot_path.exists():
|
||||
plot_seg(wav_plot_path, audio_path)
|
||||
res["plot_path"] = str(wav_plot_path)
|
||||
code_fb = BytesIO()
|
||||
audio_segment.export(code_fb, format="wav")
|
||||
wav_data = code_fb.getvalue()
|
||||
duration = audio_segment.duration_seconds
|
||||
asr_final = res["text"]
|
||||
yield asr_final, duration, wav_data, "caller", audio_segment
|
||||
|
||||
with ThreadPoolExecutor() as exe:
|
||||
print("starting all preprocess tasks")
|
||||
pnr_data = filter(
|
||||
None,
|
||||
list(
|
||||
tqdm(
|
||||
exe.map(exec_func, pnr_funcs),
|
||||
position=0,
|
||||
leave=True,
|
||||
total=len(pnr_funcs),
|
||||
)
|
||||
),
|
||||
)
|
||||
if annotation_only:
|
||||
result = list(pnr_data)
|
||||
else:
|
||||
wer_key = "domain_wer" if use_domain_asr else "pretrained_wer"
|
||||
result = sorted(pnr_data, key=lambda x: x[wer_key], reverse=True)
|
||||
ui_config = {
|
||||
"use_domain_asr": use_domain_asr,
|
||||
"annotation_only": annotation_only,
|
||||
"enable_plots": enable_plots,
|
||||
"data": result,
|
||||
}
|
||||
ExtendedPath(dump_path).write_json(ui_config)
|
||||
dump_data, num_datapoints = ui_data_generator(
|
||||
dataset_dir, data_name, asr_data_source_gen()
|
||||
)
|
||||
ui_dump_file = dataset_dir / Path("ui_dump.json")
|
||||
ExtendedPath(ui_dump_file).write_json({"data": dump_data})
|
||||
|
||||
|
||||
@app.command()
|
||||
def sample_ui(
|
||||
data_name: str = typer.Option("call_upwork_train_cnd", show_default=True),
|
||||
data_name: str = typer.Option("dataname", show_default=True),
|
||||
dump_dir: Path = Path("./data/asr_data"),
|
||||
dump_file: Path = Path("ui_dump.json"),
|
||||
sample_count: int = typer.Option(80, show_default=True),
|
||||
@@ -157,7 +128,7 @@ def sample_ui(
|
||||
|
||||
@app.command()
|
||||
def task_ui(
|
||||
data_name: str = typer.Option("call_upwork_train_cnd", show_default=True),
|
||||
data_name: str = typer.Option("dataname", show_default=True),
|
||||
dump_dir: Path = Path("./data/asr_data"),
|
||||
dump_file: Path = Path("ui_dump.json"),
|
||||
task_count: int = typer.Option(4, show_default=True),
|
||||
@@ -180,14 +151,18 @@ def task_ui(
|
||||
|
||||
@app.command()
|
||||
def dump_corrections(
|
||||
data_name: str = typer.Option("call_alphanum", show_default=True),
|
||||
task_uid: str,
|
||||
data_name: str = typer.Option("dataname", show_default=True),
|
||||
dump_dir: Path = Path("./data/asr_data"),
|
||||
dump_fname: Path = Path("corrections.json"),
|
||||
):
|
||||
dump_path = dump_dir / Path(data_name) / dump_fname
|
||||
col = get_mongo_conn(col="asr_validation")
|
||||
|
||||
cursor_obj = col.find({"type": "correction"}, projection={"_id": False})
|
||||
task_id = [c for c in col.distinct("task_id") if c.rsplit("-", 1)[1] == task_uid][0]
|
||||
corrections = list(col.find({"type": "correction"}, projection={"_id": False}))
|
||||
cursor_obj = col.find(
|
||||
{"type": "correction", "task_id": task_id}, projection={"_id": False}
|
||||
)
|
||||
corrections = [c for c in cursor_obj]
|
||||
ExtendedPath(dump_path).write_json(corrections)
|
||||
|
||||
@@ -195,7 +170,7 @@ def dump_corrections(
|
||||
@app.command()
|
||||
def caller_quality(
|
||||
task_uid: str,
|
||||
data_name: str = typer.Option("call_upwork_train_cnd", show_default=True),
|
||||
data_name: str = typer.Option("dataname", show_default=True),
|
||||
dump_dir: Path = Path("./data/asr_data"),
|
||||
dump_fname: Path = Path("ui_dump.json"),
|
||||
correction_fname: Path = Path("corrections.json"),
|
||||
@@ -231,7 +206,7 @@ def caller_quality(
|
||||
|
||||
@app.command()
|
||||
def fill_unannotated(
|
||||
data_name: str = typer.Option("call_alphanum", show_default=True),
|
||||
data_name: str = typer.Option("dataname", show_default=True),
|
||||
dump_dir: Path = Path("./data/valiation_data"),
|
||||
dump_file: Path = Path("ui_dump.json"),
|
||||
corrections_file: Path = Path("corrections.json"),
|
||||
@@ -252,97 +227,96 @@ def fill_unannotated(
|
||||
)
|
||||
|
||||
|
||||
class ExtractionType(str, Enum):
|
||||
date = "dates"
|
||||
city = "cities"
|
||||
name = "names"
|
||||
|
||||
|
||||
@app.command()
|
||||
def split_extract(
|
||||
data_name: str = typer.Option("call_alphanum", show_default=True),
|
||||
data_name: str = typer.Option("dataname", show_default=True),
|
||||
# dest_data_name: str = typer.Option("call_aldata_namephanum_date", show_default=True),
|
||||
dump_dir: Path = Path("./data/valiation_data"),
|
||||
# dump_dir: Path = Path("./data/valiation_data"),
|
||||
dump_dir: Path = Path("./data/asr_data"),
|
||||
dump_file: Path = Path("ui_dump.json"),
|
||||
manifest_dir: Path = Path("./data/asr_data"),
|
||||
manifest_file: Path = Path("manifest.json"),
|
||||
corrections_file: Path = Path("corrections.json"),
|
||||
conv_data_path: Path = Path("./data/conv_data.json"),
|
||||
extraction_type: ExtractionType = ExtractionType.date,
|
||||
corrections_file: str = typer.Option("corrections.json", show_default=True),
|
||||
conv_data_path: Path = typer.Option(
|
||||
Path("./data/conv_data.json"), show_default=True
|
||||
),
|
||||
extraction_type: str = "all",
|
||||
):
|
||||
import shutil
|
||||
|
||||
def get_conv_data(cdp):
|
||||
from itertools import product
|
||||
data_manifest_path = dump_dir / Path(data_name) / manifest_file
|
||||
conv_data = ExtendedPath(conv_data_path).read_json()
|
||||
|
||||
conv_data = json.load(cdp.open())
|
||||
days = [str(i) for i in range(1, 32)]
|
||||
months = conv_data["months"]
|
||||
day_months = {d + " " + m for d, m in product(days, months)}
|
||||
return {
|
||||
"cities": set(conv_data["cities"]),
|
||||
"names": set(conv_data["names"]),
|
||||
"dates": day_months,
|
||||
}
|
||||
def extract_data_of_type(extraction_key):
|
||||
extraction_vals = conv_data[extraction_key]
|
||||
dest_data_name = data_name + "_" + extraction_key.lower()
|
||||
|
||||
dest_data_name = data_name + "_" + extraction_type.value
|
||||
data_manifest_path = manifest_dir / Path(data_name) / manifest_file
|
||||
conv_data = get_conv_data(conv_data_path)
|
||||
extraction_vals = conv_data[extraction_type.value]
|
||||
manifest_gen = asr_manifest_reader(data_manifest_path)
|
||||
dest_data_dir = dump_dir / Path(dest_data_name)
|
||||
dest_data_dir.mkdir(exist_ok=True, parents=True)
|
||||
(dest_data_dir / Path("wav")).mkdir(exist_ok=True, parents=True)
|
||||
dest_manifest_path = dest_data_dir / manifest_file
|
||||
dest_ui_path = dest_data_dir / dump_file
|
||||
|
||||
manifest_gen = asr_manifest_reader(data_manifest_path)
|
||||
dest_data_dir = manifest_dir / Path(dest_data_name)
|
||||
dest_data_dir.mkdir(exist_ok=True, parents=True)
|
||||
(dest_data_dir / Path("wav")).mkdir(exist_ok=True, parents=True)
|
||||
dest_manifest_path = dest_data_dir / manifest_file
|
||||
dest_ui_dir = dump_dir / Path(dest_data_name)
|
||||
dest_ui_dir.mkdir(exist_ok=True, parents=True)
|
||||
dest_ui_path = dest_ui_dir / dump_file
|
||||
dest_correction_path = dest_ui_dir / corrections_file
|
||||
def extract_manifest(mg):
|
||||
for m in mg:
|
||||
if m["text"] in extraction_vals:
|
||||
shutil.copy(
|
||||
m["audio_path"], dest_data_dir / Path(m["audio_filepath"])
|
||||
)
|
||||
yield m
|
||||
|
||||
def extract_manifest(mg):
|
||||
for m in mg:
|
||||
if m["text"] in extraction_vals:
|
||||
shutil.copy(m["audio_path"], dest_data_dir / Path(m["audio_filepath"]))
|
||||
yield m
|
||||
asr_manifest_writer(dest_manifest_path, extract_manifest(manifest_gen))
|
||||
|
||||
asr_manifest_writer(dest_manifest_path, extract_manifest(manifest_gen))
|
||||
|
||||
ui_data_path = dump_dir / Path(data_name) / dump_file
|
||||
corrections_path = dump_dir / Path(data_name) / corrections_file
|
||||
ui_data = json.load(ui_data_path.open())["data"]
|
||||
file_ui_map = {Path(u["audio_filepath"]).stem: u for u in ui_data}
|
||||
corrections = json.load(corrections_path.open())
|
||||
|
||||
extracted_ui_data = list(filter(lambda u: u["text"] in extraction_vals, ui_data))
|
||||
ExtendedPath(dest_ui_path).write_json(extracted_ui_data)
|
||||
|
||||
extracted_corrections = list(
|
||||
filter(
|
||||
lambda c: c["code"] in file_ui_map
|
||||
and file_ui_map[c["code"]]["text"] in extraction_vals,
|
||||
corrections,
|
||||
ui_data_path = dump_dir / Path(data_name) / dump_file
|
||||
orig_ui_data = ExtendedPath(ui_data_path).read_json()
|
||||
ui_data = orig_ui_data["data"]
|
||||
file_ui_map = {Path(u["audio_filepath"]).stem: u for u in ui_data}
|
||||
extracted_ui_data = list(
|
||||
filter(lambda u: u["text"] in extraction_vals, ui_data)
|
||||
)
|
||||
)
|
||||
ExtendedPath(dest_correction_path).write_json(extracted_corrections)
|
||||
final_data = []
|
||||
for i, d in enumerate(extracted_ui_data):
|
||||
d["real_idx"] = i
|
||||
final_data.append(d)
|
||||
orig_ui_data["data"] = final_data
|
||||
ExtendedPath(dest_ui_path).write_json(orig_ui_data)
|
||||
|
||||
if corrections_file:
|
||||
dest_correction_path = dest_data_dir / corrections_file
|
||||
corrections_path = dump_dir / Path(data_name) / corrections_file
|
||||
corrections = json.load(corrections_path.open())
|
||||
extracted_corrections = list(
|
||||
filter(
|
||||
lambda c: c["code"] in file_ui_map
|
||||
and file_ui_map[c["code"]]["text"] in extraction_vals,
|
||||
corrections,
|
||||
)
|
||||
)
|
||||
ExtendedPath(dest_correction_path).write_json(extracted_corrections)
|
||||
|
||||
if extraction_type.value == "all":
|
||||
for ext_key in conv_data.keys():
|
||||
extract_data_of_type(ext_key)
|
||||
else:
|
||||
extract_data_of_type(extraction_type.value)
|
||||
|
||||
|
||||
@app.command()
|
||||
def update_corrections(
|
||||
data_name: str = typer.Option("call_alphanum", show_default=True),
|
||||
dump_dir: Path = Path("./data/valiation_data"),
|
||||
manifest_dir: Path = Path("./data/asr_data"),
|
||||
data_name: str = typer.Option("dataname", show_default=True),
|
||||
dump_dir: Path = Path("./data/asr_data"),
|
||||
manifest_file: Path = Path("manifest.json"),
|
||||
corrections_file: Path = Path("corrections.json"),
|
||||
# data_manifest_path: Path = Path("./data/asr_data/call_alphanum/manifest.json"),
|
||||
# corrections_path: Path = Path("./data/valiation_data/corrections.json"),
|
||||
skip_incorrect: bool = True,
|
||||
ui_dump_file: Path = Path("ui_dump.json"),
|
||||
skip_incorrect: bool = typer.Option(True, show_default=True),
|
||||
):
|
||||
data_manifest_path = manifest_dir / Path(data_name) / manifest_file
|
||||
data_manifest_path = dump_dir / Path(data_name) / manifest_file
|
||||
corrections_path = dump_dir / Path(data_name) / corrections_file
|
||||
ui_dump_path = dump_dir / Path(data_name) / ui_dump_file
|
||||
|
||||
def correct_manifest(manifest_data_gen, corrections_path):
|
||||
corrections = json.load(corrections_path.open())
|
||||
def correct_manifest(ui_dump_path, corrections_path):
|
||||
corrections = ExtendedPath(corrections_path).read_json()
|
||||
ui_data = ExtendedPath(ui_dump_path).read_json()["data"]
|
||||
correct_set = {
|
||||
c["code"] for c in corrections if c["value"]["status"] == "Correct"
|
||||
}
|
||||
@@ -355,36 +329,40 @@ def update_corrections(
|
||||
# for d in manifest_data_gen:
|
||||
# if d["chars"] in incorrect_set:
|
||||
# d["audio_path"].unlink()
|
||||
renamed_set = set()
|
||||
for d in manifest_data_gen:
|
||||
if d["chars"] in correct_set:
|
||||
# renamed_set = set()
|
||||
for d in ui_data:
|
||||
if d["utterance_id"] in correct_set:
|
||||
yield {
|
||||
"audio_filepath": d["audio_filepath"],
|
||||
"duration": d["duration"],
|
||||
"text": d["text"],
|
||||
}
|
||||
elif d["chars"] in correction_map:
|
||||
correct_text = correction_map[d["chars"]]
|
||||
elif d["utterance_id"] in correction_map:
|
||||
correct_text = correction_map[d["utterance_id"]]
|
||||
if skip_incorrect:
|
||||
print(
|
||||
f'skipping incorrect {d["audio_path"]} corrected to {correct_text}'
|
||||
)
|
||||
else:
|
||||
renamed_set.add(correct_text)
|
||||
new_name = str(Path(correct_text).with_suffix(".wav"))
|
||||
d["audio_path"].replace(d["audio_path"].with_name(new_name))
|
||||
orig_audio_path = Path(d["audio_path"])
|
||||
new_name = str(
|
||||
Path(tscript_uuid_fname(correct_text)).with_suffix(".wav")
|
||||
)
|
||||
new_audio_path = orig_audio_path.with_name(new_name)
|
||||
orig_audio_path.replace(new_audio_path)
|
||||
new_filepath = str(Path(d["audio_filepath"]).with_name(new_name))
|
||||
yield {
|
||||
"audio_filepath": new_filepath,
|
||||
"duration": d["duration"],
|
||||
"text": alnum_to_asr_tokens(correct_text),
|
||||
"text": correct_text,
|
||||
}
|
||||
else:
|
||||
orig_audio_path = Path(d["audio_path"])
|
||||
# don't delete if another correction points to an old file
|
||||
if d["chars"] not in renamed_set:
|
||||
d["audio_path"].unlink()
|
||||
else:
|
||||
print(f'skipping deletion of correction:{d["chars"]}')
|
||||
# if d["text"] not in renamed_set:
|
||||
orig_audio_path.unlink()
|
||||
# else:
|
||||
# print(f'skipping deletion of correction:{d["text"]}')
|
||||
|
||||
typer.echo(f"Using data manifest:{data_manifest_path}")
|
||||
dataset_dir = data_manifest_path.parent
|
||||
@@ -393,8 +371,8 @@ def update_corrections(
|
||||
if not backup_dir.exists():
|
||||
typer.echo(f"backing up to :{backup_dir}")
|
||||
shutil.copytree(str(dataset_dir), str(backup_dir))
|
||||
manifest_gen = asr_manifest_reader(data_manifest_path)
|
||||
corrected_manifest = correct_manifest(manifest_gen, corrections_path)
|
||||
# manifest_gen = asr_manifest_reader(data_manifest_path)
|
||||
corrected_manifest = correct_manifest(ui_dump_path, corrections_path)
|
||||
new_data_manifest_path = data_manifest_path.with_name("manifest.new")
|
||||
asr_manifest_writer(new_data_manifest_path, corrected_manifest)
|
||||
new_data_manifest_path.replace(data_manifest_path)
|
||||
|
||||
@@ -42,7 +42,9 @@ if not hasattr(st, "mongo_connected"):
|
||||
upsert=True,
|
||||
)
|
||||
|
||||
def set_task_fn(mf_path):
|
||||
def set_task_fn(mf_path, task_id):
|
||||
if task_id:
|
||||
st.task_id = task_id
|
||||
task_path = mf_path.parent / Path(f"task-{st.task_id}.lck")
|
||||
if not task_path.exists():
|
||||
print(f"creating task lock at {task_path}")
|
||||
@@ -66,26 +68,22 @@ def load_ui_data(validation_ui_data_path: Path):
|
||||
|
||||
|
||||
@app.command()
|
||||
def main(manifest: Path):
|
||||
st.set_task(manifest)
|
||||
def main(manifest: Path, task_id: str = ""):
|
||||
st.set_task(manifest, task_id)
|
||||
ui_config = load_ui_data(manifest)
|
||||
asr_data = ui_config["data"]
|
||||
use_domain_asr = ui_config.get("use_domain_asr", True)
|
||||
annotation_only = ui_config.get("annotation_only", False)
|
||||
enable_plots = ui_config.get("enable_plots", True)
|
||||
sample_no = st.get_current_cursor()
|
||||
if len(asr_data) - 1 < sample_no or sample_no < 0:
|
||||
print("Invalid samplno resetting to 0")
|
||||
st.update_cursor(0)
|
||||
sample = asr_data[sample_no]
|
||||
title_type = "Speller " if use_domain_asr else ""
|
||||
task_uid = st.task_id.rsplit("-", 1)[1]
|
||||
if annotation_only:
|
||||
st.title(f"ASR Annotation - # {task_uid}")
|
||||
else:
|
||||
st.title(f"ASR {title_type}Validation - # {task_uid}")
|
||||
addl_text = f"spelled *{sample['spoken']}*" if use_domain_asr else ""
|
||||
st.markdown(f"{sample_no+1} of {len(asr_data)} : **{sample['text']}**" + addl_text)
|
||||
st.title(f"ASR Validation - # {task_uid}")
|
||||
st.markdown(f"{sample_no+1} of {len(asr_data)} : **{sample['text']}**")
|
||||
new_sample = st.number_input(
|
||||
"Go To Sample:", value=sample_no + 1, min_value=1, max_value=len(asr_data)
|
||||
)
|
||||
@@ -94,19 +92,13 @@ def main(manifest: Path):
|
||||
st.sidebar.title(f"Details: [{sample['real_idx']}]")
|
||||
st.sidebar.markdown(f"Gold Text: **{sample['text']}**")
|
||||
if not annotation_only:
|
||||
if use_domain_asr:
|
||||
st.sidebar.markdown(f"Expected Spelled: *{sample['spoken']}*")
|
||||
st.sidebar.title("Results:")
|
||||
st.sidebar.markdown(f"Pretrained: **{sample['pretrained_asr']}**")
|
||||
if "caller" in sample:
|
||||
st.sidebar.markdown(f"Caller: **{sample['caller']}**")
|
||||
if use_domain_asr:
|
||||
st.sidebar.markdown(f"Domain: **{sample['domain_asr']}**")
|
||||
st.sidebar.title(f"Speller WER: {sample['domain_wer']:.2f}%")
|
||||
else:
|
||||
st.sidebar.title(f"Pretrained WER: {sample['pretrained_wer']:.2f}%")
|
||||
if enable_plots:
|
||||
st.sidebar.image(Path(sample["plot_path"]).read_bytes())
|
||||
st.sidebar.image(Path(sample["plot_path"]).read_bytes())
|
||||
st.audio(Path(sample["audio_path"]).open("rb"))
|
||||
# set default to text
|
||||
corrected = sample["text"]
|
||||
@@ -128,16 +120,12 @@ def main(manifest: Path):
|
||||
)
|
||||
st.update_cursor(sample_no + 1)
|
||||
if correction_entry:
|
||||
st.markdown(
|
||||
f'Your Response: **{correction_entry["value"]["status"]}** Correction: **{correction_entry["value"]["correction"]}**'
|
||||
)
|
||||
status = correction_entry["value"]["status"]
|
||||
correction = correction_entry["value"]["correction"]
|
||||
st.markdown(f"Your Response: **{status}** Correction: **{correction}**")
|
||||
text_sample = st.text_input("Go to Text:", value="")
|
||||
if text_sample != "":
|
||||
candidates = [
|
||||
i
|
||||
for (i, p) in enumerate(asr_data)
|
||||
if p["text"] == text_sample or p["spoken"] == text_sample
|
||||
]
|
||||
candidates = [i for (i, p) in enumerate(asr_data) if p["text"] == text_sample]
|
||||
if len(candidates) > 0:
|
||||
st.update_cursor(candidates[0])
|
||||
real_idx = st.number_input(
|
||||
|
||||
359
jasper/evaluate.py
Normal file
359
jasper/evaluate.py
Normal file
@@ -0,0 +1,359 @@
|
||||
# Copyright (c) 2019 NVIDIA Corporation
|
||||
import argparse
|
||||
import copy
|
||||
# import math
|
||||
import os
|
||||
from pathlib import Path
|
||||
from functools import partial
|
||||
|
||||
from ruamel.yaml import YAML
|
||||
|
||||
import nemo
|
||||
import nemo.collections.asr as nemo_asr
|
||||
import nemo.utils.argparse as nm_argparse
|
||||
from nemo.collections.asr.helpers import (
|
||||
# monitor_asr_train_progress,
|
||||
process_evaluation_batch,
|
||||
process_evaluation_epoch,
|
||||
)
|
||||
|
||||
# from nemo.utils.lr_policies import CosineAnnealing
|
||||
from training.data_loaders import RpycAudioToTextDataLayer
|
||||
|
||||
logging = nemo.logging
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
parents=[nm_argparse.NemoArgParser()],
|
||||
description="Jasper",
|
||||
conflict_handler="resolve",
|
||||
)
|
||||
parser.set_defaults(
|
||||
checkpoint_dir=None,
|
||||
optimizer="novograd",
|
||||
batch_size=64,
|
||||
eval_batch_size=64,
|
||||
lr=0.002,
|
||||
amp_opt_level="O1",
|
||||
create_tb_writer=True,
|
||||
model_config="./train/jasper10x5dr.yaml",
|
||||
work_dir="./train/work",
|
||||
num_epochs=300,
|
||||
weight_decay=0.005,
|
||||
checkpoint_save_freq=100,
|
||||
eval_freq=100,
|
||||
load_dir="./train/models/jasper/",
|
||||
warmup_steps=3,
|
||||
exp_name="jasper-speller",
|
||||
)
|
||||
|
||||
# Overwrite default args
|
||||
parser.add_argument(
|
||||
"--max_steps",
|
||||
type=int,
|
||||
default=None,
|
||||
required=False,
|
||||
help="max number of steps to train",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_epochs", type=int, required=False, help="number of epochs to train"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_config",
|
||||
type=str,
|
||||
required=False,
|
||||
help="model configuration file: model.yaml",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--encoder_checkpoint",
|
||||
type=str,
|
||||
required=True,
|
||||
help="encoder checkpoint file: JasperEncoder.pt",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--decoder_checkpoint",
|
||||
type=str,
|
||||
required=True,
|
||||
help="decoder checkpoint file: JasperDecoderForCTC.pt",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--remote_data",
|
||||
type=str,
|
||||
required=False,
|
||||
default="",
|
||||
help="remote dataloader endpoint",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
type=str,
|
||||
required=False,
|
||||
default="",
|
||||
help="dataset directory containing train/test manifests",
|
||||
)
|
||||
|
||||
# Create new args
|
||||
parser.add_argument("--exp_name", default="Jasper", type=str)
|
||||
parser.add_argument("--beta1", default=0.95, type=float)
|
||||
parser.add_argument("--beta2", default=0.25, type=float)
|
||||
parser.add_argument("--warmup_steps", default=0, type=int)
|
||||
parser.add_argument(
|
||||
"--load_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
help="directory with pre-trained checkpoint",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
if args.max_steps is None and args.num_epochs is None:
|
||||
raise ValueError("Either max_steps or num_epochs should be provided.")
|
||||
return args
|
||||
|
||||
|
||||
def construct_name(
|
||||
name, lr, batch_size, max_steps, num_epochs, wd, optimizer, iter_per_step
|
||||
):
|
||||
if max_steps is not None:
|
||||
return "{0}-lr_{1}-bs_{2}-s_{3}-wd_{4}-opt_{5}-ips_{6}".format(
|
||||
name, lr, batch_size, max_steps, wd, optimizer, iter_per_step
|
||||
)
|
||||
else:
|
||||
return "{0}-lr_{1}-bs_{2}-e_{3}-wd_{4}-opt_{5}-ips_{6}".format(
|
||||
name, lr, batch_size, num_epochs, wd, optimizer, iter_per_step
|
||||
)
|
||||
|
||||
|
||||
def create_all_dags(args, neural_factory):
|
||||
yaml = YAML(typ="safe")
|
||||
with open(args.model_config) as f:
|
||||
jasper_params = yaml.load(f)
|
||||
vocab = jasper_params["labels"]
|
||||
sample_rate = jasper_params["sample_rate"]
|
||||
|
||||
# Calculate num_workers for dataloader
|
||||
total_cpus = os.cpu_count()
|
||||
cpu_per_traindl = max(int(total_cpus / neural_factory.world_size), 1)
|
||||
# perturb_config = jasper_params.get('perturb', None)
|
||||
train_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"])
|
||||
train_dl_params.update(jasper_params["AudioToTextDataLayer"]["train"])
|
||||
del train_dl_params["train"]
|
||||
del train_dl_params["eval"]
|
||||
# del train_dl_params["normalize_transcripts"]
|
||||
|
||||
if args.dataset:
|
||||
d_path = Path(args.dataset)
|
||||
if not args.train_dataset:
|
||||
args.train_dataset = str(d_path / Path("train_manifest.json"))
|
||||
if not args.eval_datasets:
|
||||
args.eval_datasets = [str(d_path / Path("test_manifest.json"))]
|
||||
|
||||
data_loader_layer = nemo_asr.AudioToTextDataLayer
|
||||
|
||||
if args.remote_data:
|
||||
train_dl_params["rpyc_host"] = args.remote_data
|
||||
data_loader_layer = RpycAudioToTextDataLayer
|
||||
|
||||
# data_layer = data_loader_layer(
|
||||
# manifest_filepath=args.train_dataset,
|
||||
# sample_rate=sample_rate,
|
||||
# labels=vocab,
|
||||
# batch_size=args.batch_size,
|
||||
# num_workers=cpu_per_traindl,
|
||||
# **train_dl_params,
|
||||
# # normalize_transcripts=False
|
||||
# )
|
||||
#
|
||||
# N = len(data_layer)
|
||||
# steps_per_epoch = math.ceil(
|
||||
# N / (args.batch_size * args.iter_per_step * args.num_gpus)
|
||||
# )
|
||||
# logging.info("Have {0} examples to train on.".format(N))
|
||||
#
|
||||
data_preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor(
|
||||
sample_rate=sample_rate, **jasper_params["AudioToMelSpectrogramPreprocessor"]
|
||||
)
|
||||
|
||||
# multiply_batch_config = jasper_params.get("MultiplyBatch", None)
|
||||
# if multiply_batch_config:
|
||||
# multiply_batch = nemo_asr.MultiplyBatch(**multiply_batch_config)
|
||||
#
|
||||
# spectr_augment_config = jasper_params.get("SpectrogramAugmentation", None)
|
||||
# if spectr_augment_config:
|
||||
# data_spectr_augmentation = nemo_asr.SpectrogramAugmentation(
|
||||
# **spectr_augment_config
|
||||
# )
|
||||
#
|
||||
eval_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"])
|
||||
eval_dl_params.update(jasper_params["AudioToTextDataLayer"]["eval"])
|
||||
if args.remote_data:
|
||||
eval_dl_params["rpyc_host"] = args.remote_data
|
||||
del eval_dl_params["train"]
|
||||
del eval_dl_params["eval"]
|
||||
data_layers_eval = []
|
||||
|
||||
# if args.eval_datasets:
|
||||
for eval_datasets in args.eval_datasets:
|
||||
data_layer_eval = data_loader_layer(
|
||||
manifest_filepath=eval_datasets,
|
||||
sample_rate=sample_rate,
|
||||
labels=vocab,
|
||||
batch_size=args.eval_batch_size,
|
||||
num_workers=cpu_per_traindl,
|
||||
**eval_dl_params,
|
||||
)
|
||||
|
||||
data_layers_eval.append(data_layer_eval)
|
||||
# else:
|
||||
# logging.warning("There were no val datasets passed")
|
||||
|
||||
jasper_encoder = nemo_asr.JasperEncoder(
|
||||
feat_in=jasper_params["AudioToMelSpectrogramPreprocessor"]["features"],
|
||||
**jasper_params["JasperEncoder"],
|
||||
)
|
||||
jasper_encoder.restore_from(args.encoder_checkpoint, local_rank=0)
|
||||
|
||||
jasper_decoder = nemo_asr.JasperDecoderForCTC(
|
||||
feat_in=jasper_params["JasperEncoder"]["jasper"][-1]["filters"],
|
||||
num_classes=len(vocab),
|
||||
)
|
||||
jasper_decoder.restore_from(args.decoder_checkpoint, local_rank=0)
|
||||
|
||||
ctc_loss = nemo_asr.CTCLossNM(num_classes=len(vocab))
|
||||
|
||||
greedy_decoder = nemo_asr.GreedyCTCDecoder()
|
||||
|
||||
# logging.info("================================")
|
||||
# logging.info(f"Number of parameters in encoder: {jasper_encoder.num_weights}")
|
||||
# logging.info(f"Number of parameters in decoder: {jasper_decoder.num_weights}")
|
||||
# logging.info(
|
||||
# f"Total number of parameters in model: "
|
||||
# f"{jasper_decoder.num_weights + jasper_encoder.num_weights}"
|
||||
# )
|
||||
# logging.info("================================")
|
||||
#
|
||||
# # Train DAG
|
||||
# (audio_signal_t, a_sig_length_t, transcript_t, transcript_len_t) = data_layer()
|
||||
# processed_signal_t, p_length_t = data_preprocessor(
|
||||
# input_signal=audio_signal_t, length=a_sig_length_t
|
||||
# )
|
||||
#
|
||||
# if multiply_batch_config:
|
||||
# (
|
||||
# processed_signal_t,
|
||||
# p_length_t,
|
||||
# transcript_t,
|
||||
# transcript_len_t,
|
||||
# ) = multiply_batch(
|
||||
# in_x=processed_signal_t,
|
||||
# in_x_len=p_length_t,
|
||||
# in_y=transcript_t,
|
||||
# in_y_len=transcript_len_t,
|
||||
# )
|
||||
#
|
||||
# if spectr_augment_config:
|
||||
# processed_signal_t = data_spectr_augmentation(input_spec=processed_signal_t)
|
||||
#
|
||||
# encoded_t, encoded_len_t = jasper_encoder(
|
||||
# audio_signal=processed_signal_t, length=p_length_t
|
||||
# )
|
||||
# log_probs_t = jasper_decoder(encoder_output=encoded_t)
|
||||
# predictions_t = greedy_decoder(log_probs=log_probs_t)
|
||||
# loss_t = ctc_loss(
|
||||
# log_probs=log_probs_t,
|
||||
# targets=transcript_t,
|
||||
# input_length=encoded_len_t,
|
||||
# target_length=transcript_len_t,
|
||||
# )
|
||||
#
|
||||
# # Callbacks needed to print info to console and Tensorboard
|
||||
# train_callback = nemo.core.SimpleLossLoggerCallback(
|
||||
# tensors=[loss_t, predictions_t, transcript_t, transcript_len_t],
|
||||
# print_func=partial(monitor_asr_train_progress, labels=vocab),
|
||||
# get_tb_values=lambda x: [("loss", x[0])],
|
||||
# tb_writer=neural_factory.tb_writer,
|
||||
# )
|
||||
#
|
||||
# chpt_callback = nemo.core.CheckpointCallback(
|
||||
# folder=neural_factory.checkpoint_dir,
|
||||
# load_from_folder=args.load_dir,
|
||||
# step_freq=args.checkpoint_save_freq,
|
||||
# checkpoints_to_keep=30,
|
||||
# )
|
||||
#
|
||||
# callbacks = [train_callback, chpt_callback]
|
||||
callbacks = []
|
||||
# assemble eval DAGs
|
||||
for i, eval_dl in enumerate(data_layers_eval):
|
||||
(audio_signal_e, a_sig_length_e, transcript_e, transcript_len_e) = eval_dl()
|
||||
processed_signal_e, p_length_e = data_preprocessor(
|
||||
input_signal=audio_signal_e, length=a_sig_length_e
|
||||
)
|
||||
encoded_e, encoded_len_e = jasper_encoder(
|
||||
audio_signal=processed_signal_e, length=p_length_e
|
||||
)
|
||||
log_probs_e = jasper_decoder(encoder_output=encoded_e)
|
||||
predictions_e = greedy_decoder(log_probs=log_probs_e)
|
||||
loss_e = ctc_loss(
|
||||
log_probs=log_probs_e,
|
||||
targets=transcript_e,
|
||||
input_length=encoded_len_e,
|
||||
target_length=transcript_len_e,
|
||||
)
|
||||
|
||||
# create corresponding eval callback
|
||||
tagname = os.path.basename(args.eval_datasets[i]).split(".")[0]
|
||||
eval_callback = nemo.core.EvaluatorCallback(
|
||||
eval_tensors=[loss_e, predictions_e, transcript_e, transcript_len_e],
|
||||
user_iter_callback=partial(process_evaluation_batch, labels=vocab),
|
||||
user_epochs_done_callback=partial(process_evaluation_epoch, tag=tagname),
|
||||
eval_step=args.eval_freq,
|
||||
tb_writer=neural_factory.tb_writer,
|
||||
)
|
||||
|
||||
callbacks.append(eval_callback)
|
||||
return callbacks
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
# name = construct_name(
|
||||
# args.exp_name,
|
||||
# args.lr,
|
||||
# args.batch_size,
|
||||
# args.max_steps,
|
||||
# args.num_epochs,
|
||||
# args.weight_decay,
|
||||
# args.optimizer,
|
||||
# args.iter_per_step,
|
||||
# )
|
||||
# log_dir = name
|
||||
# if args.work_dir:
|
||||
# log_dir = os.path.join(args.work_dir, name)
|
||||
|
||||
# instantiate Neural Factory with supported backend
|
||||
neural_factory = nemo.core.NeuralModuleFactory(
|
||||
placement=nemo.core.DeviceType.GPU,
|
||||
backend=nemo.core.Backend.PyTorch,
|
||||
# local_rank=args.local_rank,
|
||||
# optimization_level=args.amp_opt_level,
|
||||
# log_dir=log_dir,
|
||||
# checkpoint_dir=args.checkpoint_dir,
|
||||
# create_tb_writer=args.create_tb_writer,
|
||||
# files_to_copy=[args.model_config, __file__],
|
||||
# cudnn_benchmark=args.cudnn_benchmark,
|
||||
# tensorboard_dir=args.tensorboard_dir,
|
||||
)
|
||||
args.num_gpus = neural_factory.world_size
|
||||
|
||||
# checkpoint_dir = neural_factory.checkpoint_dir
|
||||
if args.local_rank is not None:
|
||||
logging.info("Doing ALL GPU")
|
||||
|
||||
# build dags
|
||||
callbacks = create_all_dags(args, neural_factory)
|
||||
# evaluate model
|
||||
neural_factory.eval(callbacks=callbacks)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -41,7 +41,7 @@ def parse_args():
|
||||
work_dir="./train/work",
|
||||
num_epochs=300,
|
||||
weight_decay=0.005,
|
||||
checkpoint_save_freq=200,
|
||||
checkpoint_save_freq=100,
|
||||
eval_freq=100,
|
||||
load_dir="./train/models/jasper/",
|
||||
warmup_steps=3,
|
||||
@@ -266,6 +266,7 @@ def create_all_dags(args, neural_factory):
|
||||
folder=neural_factory.checkpoint_dir,
|
||||
load_from_folder=args.load_dir,
|
||||
step_freq=args.checkpoint_save_freq,
|
||||
checkpoints_to_keep=30,
|
||||
)
|
||||
|
||||
callbacks = [train_callback, chpt_callback]
|
||||
|
||||
14
setup.py
14
setup.py
@@ -2,7 +2,7 @@ from setuptools import setup, find_packages
|
||||
|
||||
requirements = [
|
||||
"ruamel.yaml",
|
||||
"torch==1.4.0",
|
||||
"torch==2.8.0",
|
||||
"torchvision==0.5.0",
|
||||
"nemo_toolkit[asr] @ git+https://github.com/NVIDIA/NeMo.git@09e3ba4dfe333f86d6c5c1048e07210924294be9#egg=nemo_toolkit",
|
||||
]
|
||||
@@ -19,13 +19,15 @@ extra_requirements = {
|
||||
"ruamel.yaml==0.16.10",
|
||||
"pymongo==3.10.1",
|
||||
"librosa==0.7.2",
|
||||
"numba==0.48",
|
||||
"matplotlib==3.2.1",
|
||||
"pandas==1.0.3",
|
||||
"tabulate==0.8.7",
|
||||
"natural==0.2.0",
|
||||
"num2words==0.5.10",
|
||||
"typer[all]==0.1.1",
|
||||
"typer[all]==0.3.1",
|
||||
"python-slugify==4.0.0",
|
||||
"rpyc~=4.1.4",
|
||||
"lenses @ git+https://github.com/ingolemo/python-lenses.git@b2a2a9aa5b61540992d70b2cf36008d0121e8948#egg=lenses",
|
||||
],
|
||||
"validation": [
|
||||
@@ -39,6 +41,7 @@ extra_requirements = {
|
||||
"streamlit==0.58.0",
|
||||
"natural==0.2.0",
|
||||
"stringcase==1.2.0",
|
||||
"google-cloud-speech~=1.3.1",
|
||||
]
|
||||
# "train": [
|
||||
# "torchaudio==0.5.0",
|
||||
@@ -63,14 +66,15 @@ setup(
|
||||
"jasper_transcribe = jasper.transcribe:main",
|
||||
"jasper_server = jasper.server:main",
|
||||
"jasper_trainer = jasper.training.cli:main",
|
||||
"jasper_evaluator = jasper.evaluate:main",
|
||||
"jasper_data_tts_generate = jasper.data.tts_generator:main",
|
||||
"jasper_data_conv_generate = jasper.data.conv_generator:main",
|
||||
"jasper_data_call_recycle = jasper.data.call_recycler:main",
|
||||
"jasper_data_asr_recycle = jasper.data.asr_recycler:main",
|
||||
"jasper_data_rev_recycle = jasper.data.rev_recycler:main",
|
||||
"jasper_data_nlu_generate = jasper.data.nlu_generator:main",
|
||||
"jasper_data_rastrik_recycle = jasper.data.rastrik_recycler:main",
|
||||
"jasper_data_server = jasper.data.server:main",
|
||||
"jasper_data_validation = jasper.data.validation.process:main",
|
||||
"jasper_data_preprocess = jasper.data.process:main",
|
||||
"jasper_data_slu_evaluate = jasper.data.slu_evaluator:main",
|
||||
]
|
||||
},
|
||||
zip_safe=False,
|
||||
|
||||
Reference in New Issue
Block a user