mirror of https://github.com/malarinv/jasper-asr.git synced 2026-03-08 02:22:34 +00:00

Compare commits


7 Commits

Author SHA1 Message Date
e30dd724f5 Merge pull request #3 from malarinv/dependabot/pip/torch-2.8.0
Bump torch from 1.4.0 to 2.8.0
2025-08-30 18:24:04 +05:30
dependabot[bot]
02df1b5282 Bump torch from 1.4.0 to 2.8.0
Bumps [torch](https://github.com/pytorch/pytorch) from 1.4.0 to 2.8.0.
- [Release notes](https://github.com/pytorch/pytorch/releases)
- [Changelog](https://github.com/pytorch/pytorch/blob/main/RELEASE.md)
- [Commits](https://github.com/pytorch/pytorch/compare/v1.4.0...v2.8.0)

---
updated-dependencies:
- dependency-name: torch
  dependency-version: 2.8.0
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-08-30 12:53:51 +00:00
e8f58a5043 1. refactored ui_dump
2. added flake8
2020-08-09 19:16:35 +05:30
42647196fe 1. fixed dependency issues
2. add task-id option to validation ui to respawn previous task
3. clean-up rastrik-recycler
2020-08-06 22:40:14 +05:30
e77943b2f2 Merge pull request #1 from wrat/master
adding support for asr data generator
2020-08-06 00:11:53 +05:30
wabi_sabi004
14d31a51c3 adding support for asr data generator 2020-08-06 00:08:46 +05:30
e24a8cf9d0 1. integrated data generator using google tts
2. added training script

fix module packaging issue

implement call audio data recycler for asr

1. added streamlit based validation ui with mongodb datastore integration
2. fix asr wrong sample rate inference
3. update requirements

1. refactored streamlit code
2. fixed issues in data manifest handling

refresh to next entry on submit and comment out mongo clearing code for safety :P

add validation ui and post processing to correct using validation data

1. added a tool to extract asr data from gcp transcripts logs
2. implement a function to export all call logs in a mongodb to a caller-id based yaml file
3. clean-up leaderboard duration logic
4. added a wip dataloader service
5. made the asr_data_writer util more generic with verbose flags and unique filename
6. added extendedpath util class with json support and mongo_conn function to connect to a mongo node
7. refactored the validation post processing to dump a ui config for validation
8. included utility functions to correct, fill, update and clear annotations from mongodb data
9. refactored the ui logic to be more generic for any asr data
10. updated setup.py dependencies to support the above features

unlink temporary files after transcribing

1. clean-up unused data process code
2. fix invalid sample no from mongo
3. data loader service return remote netref

1. added training utils with custom data loaders with remote rpyc dataservice support
2. fix validation correction dump path
3. cache dataset for precaching before training to memory
4. update dependencies

1. implement dataset augmentation and validation in process
2. added option to skip 'incorrect' annotations in validation data
3. added confirmation on clearing mongo collection
4. added an option to navigate to a given text in the validation ui
5. added a dataset and remote option to trainer to load dataset from directory and remote rpyc service

1. added utility command to export call logs
2. mongo conn accepts port

refactored module structure

1. enabled silence stripping in chunks when recycling audio from asr logs
2. limit asr recycling to the first 1 min of audio to get reliable alignments, ignoring the agent channel
3. added rev recycler for generating asr dataset from rev transcripts and audio
4. update pydub dependency for the silence stripping fn and remove the hardcoded threadpool worker count

1. added support for mono/dual channel rev transcripts
2. handle errors when extracting datapoints from rev meta data
3. added support for annotation-only tasks when dumping ui data

cleanup rev recycle

added option to disable plots during validation

fix skipping null audio and add more verbose logs

respect verbose flag

don't load audio for annotation only ui and keep spoken as text for normal asr validation

1. refactored wav chunk processing method
2. renamed streamlit to validation_ui

show duration on validation of dataset

parallelize data loading from remote

skipping invalid data points

1. removed the transcriber_pretrained/speller from utils
2. introduced get_mongo_coll to get the collection object directly from mongo uri
3. removed processing of correction entries to remove space/upper casing

refactor validation process arguments and logging

1. added a data extraction type argument
2. cleanup/refactor

1. using dataname args for update/fill annotations
2. rename to dump_ui

added support for name/dates/cities call data extraction and more logs

handling non-pnr cases without parens in text data

1. added conv data generator
2. more utils

1. added start delay arg in call recycler
2. implement ui_dump/manifest writer in call_recycler itself
3. refactored call data point plotter
4. added sample-ui/task-ui on the validation process
5. implemented call-quality stats using corrections from mongo
6. support deleting cursors on mongo
7. implement multiple task support on validation ui based on task_id mongo field

fix 11st to 11th in ordinal

stripping silence on call chunk

1. added option to strip silent chunks
2. computing caller quality based on task-id of corrections

1. fix update-correction to use ui_dump instead of manifest
2. update training params: number of checkpoints and checkpoint frequency

1. extract all data types in one shot with the --extraction-type all flag
2. add notes about diffing split extracted and original data
3. add an nlu conv generator to generate conv data based on nlu utterances and entities
4. add task uid support for dumping corrections
5. abstracted generate date fn

1. added a test generator and slu evaluator
2. ui dump now includes gcp results
3. showing default options for more validation process command args

added evaluation command

clean-up
2020-07-14 12:09:46 +05:30
20 changed files with 769 additions and 1433 deletions

.flake8 Normal file

@@ -0,0 +1,4 @@
[flake8]
exclude = docs
ignore = E203, W503
max-line-length = 119

Notes.md Normal file

@@ -0,0 +1,5 @@
> Diff after splitting based on type
```
diff <(cat data/asr_data/call_upwork_test_cnd_*/manifest.json |sort) <(cat data/asr_data/call_upwork_test_cnd/manifest.json |sort)
```
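The same set-wise manifest comparison can be sketched in plain Python; `manifest_diff` is a hypothetical helper, not part of the repo:

```python
def manifest_diff(lines_a, lines_b):
    # compare two manifests as unordered sets of JSON lines,
    # mirroring diff over the sorted concatenated files
    a, b = set(lines_a), set(lines_b)
    return sorted(a - b), sorted(b - a)

# lines only in the split manifests vs. only in the original
print(manifest_diff(["a", "b"], ["b", "c"]))  # (['a'], ['c'])
```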


@@ -7,10 +7,16 @@
# Table of Contents
* [Prerequisites](#prerequisites)
* [Features](#features)
* [Installation](#installation)
* [Usage](#usage)
# Prerequisites
```bash
# apt install libsndfile-dev ffmpeg
```
# Features
* ASR using Jasper (from [NemoToolkit](https://github.com/NVIDIA/NeMo) )


@@ -1,104 +0,0 @@
import typer
from itertools import chain
from io import BytesIO
from pathlib import Path

app = typer.Typer()


@app.command()
def extract_data(
    call_audio_dir: Path = Path("/dataset/png_prod/call_audio"),
    call_meta_dir: Path = Path("/dataset/png_prod/call_metadata"),
    output_dir: Path = Path("./data"),
    dataset_name: str = "png_gcp_2jan",
    verbose: bool = False,
):
    from pydub import AudioSegment
    from .utils import ExtendedPath, asr_data_writer, strip_silence
    from lenses import lens

    call_asr_data: Path = output_dir / Path("asr_data")
    call_asr_data.mkdir(exist_ok=True, parents=True)

    def wav_event_generator(call_audio_dir):
        for wav_path in call_audio_dir.glob("**/*.wav"):
            if verbose:
                typer.echo(f"loading events for file {wav_path}")
            call_wav = AudioSegment.from_file_using_temporary_files(wav_path)
            rel_meta_path = wav_path.with_suffix(".json").relative_to(call_audio_dir)
            meta_path = call_meta_dir / rel_meta_path
            events = ExtendedPath(meta_path).read_json()
            yield call_wav, wav_path, events

    def contains_asr(x):
        return "AsrResult" in x

    def channel(n):
        def filter_func(ev):
            return (
                ev["AsrResult"]["Channel"] == n
                if "Channel" in ev["AsrResult"]
                else n == 0
            )

        return filter_func

    def compute_endtime(call_wav, state):
        for (i, st) in enumerate(state):
            start_time = st["AsrResult"]["Alternatives"][0].get("StartTime", 0)
            transcript = st["AsrResult"]["Alternatives"][0]["Transcript"]
            if i + 1 < len(state):
                end_time = state[i + 1]["AsrResult"]["Alternatives"][0]["StartTime"]
            else:
                end_time = call_wav.duration_seconds
            full_code_seg = call_wav[start_time * 1000 : end_time * 1000]
            code_seg = strip_silence(full_code_seg)
            code_fb = BytesIO()
            code_seg.export(code_fb, format="wav")
            code_wav = code_fb.getvalue()
            # only the first 1 min of audio has reliable alignment; ignore the rest
            if start_time > 60:
                if verbose:
                    print("start time over 60 seconds; skipping rest of audio.")
                break
            # yield only if a reasonable amount of audio data is present
            if code_seg.duration_seconds < 0.5:
                if verbose:
                    print(f'transcript chunk "{transcript}" contains no audio; skipping.')
                continue
            yield transcript, code_seg.duration_seconds, code_wav

    def asr_data_generator(call_wav, call_wav_fname, events):
        call_wav_0, call_wav_1 = call_wav.split_to_mono()
        asr_events = lens["Events"].Each()["Event"].Filter(contains_asr)
        call_evs_0 = asr_events.Filter(channel(0)).collect()(events)
        # Ignoring agent channel events
        # call_evs_1 = asr_events.Filter(channel(1)).collect()(events)
        if verbose:
            typer.echo(f"processing data points on {call_wav_fname}")
        call_data_0 = compute_endtime(call_wav_0, call_evs_0)
        # Ignoring agent channel
        # call_data_1 = compute_endtime(call_wav_1, call_evs_1)
        return call_data_0  # chain(call_data_0, call_data_1)

    def generate_call_asr_data():
        full_asr_data = []
        total_duration = 0
        for wav, wav_path, ev in wav_event_generator(call_audio_dir):
            asr_data = asr_data_generator(wav, wav_path, ev)
            total_duration += wav.duration_seconds
            full_asr_data.append(asr_data)
        typer.echo(f"loaded {len(full_asr_data)} calls of duration {total_duration}s")
        n_dps = asr_data_writer(call_asr_data, dataset_name, chain(*full_asr_data))
        typer.echo(f"written {n_dps} data points")

    generate_call_asr_data()


def main():
    app()


if __name__ == "__main__":
    main()


@@ -1,521 +0,0 @@
import typer
from pathlib import Path
from enum import Enum
app = typer.Typer()
@app.command()
def export_all_logs(
call_logs_file: Path = typer.Option(Path("./call_logs.yaml"), show_default=True),
domain: str = typer.Option("sia-data.agaralabs.com", show_default=True),
):
from .utils import get_mongo_conn
from collections import defaultdict
from ruamel.yaml import YAML
yaml = YAML()
mongo_coll = get_mongo_conn()
caller_calls = defaultdict(lambda: [])
for call in mongo_coll.find():
sysid = call["SystemID"]
call_uri = f"http://{domain}/calls/{sysid}"
caller = call["Caller"]
caller_calls[caller].append(call_uri)
caller_list = []
for caller in caller_calls:
caller_list.append({"name": caller, "calls": caller_calls[caller]})
output_yaml = {"users": caller_list}
typer.echo(f"exporting call logs to yaml file at {call_logs_file}")
with call_logs_file.open("w") as yf:
yaml.dump(output_yaml, yf)
@app.command()
def export_calls_between(
start_cid: str,
end_cid: str,
call_logs_file: Path = typer.Option(Path("./call_logs.yaml"), show_default=True),
domain: str = typer.Option("sia-data.agaralabs.com", show_default=True),
mongo_port: int = 27017,
):
from collections import defaultdict
from ruamel.yaml import YAML
from .utils import get_mongo_conn
yaml = YAML()
mongo_coll = get_mongo_conn(port=mongo_port)
start_meta = mongo_coll.find_one({"SystemID": start_cid})
end_meta = mongo_coll.find_one({"SystemID": end_cid})
caller_calls = defaultdict(lambda: [])
call_query = mongo_coll.find(
{
"StartTS": {"$gte": start_meta["StartTS"]},
"EndTS": {"$lte": end_meta["EndTS"]},
}
)
for call in call_query:
sysid = call["SystemID"]
call_uri = f"http://{domain}/calls/{sysid}"
caller = call["Caller"]
caller_calls[caller].append(call_uri)
caller_list = []
for caller in caller_calls:
caller_list.append({"name": caller, "calls": caller_calls[caller]})
output_yaml = {"users": caller_list}
typer.echo(f"exporting call logs to yaml file at {call_logs_file}")
with call_logs_file.open("w") as yf:
yaml.dump(output_yaml, yf)
@app.command()
def copy_metas(
call_logs_file: Path = typer.Option(Path("./call_logs.yaml"), show_default=True),
output_dir: Path = Path("./data"),
meta_dir: Path = Path("/tmp/call_metas"),
):
from lenses import lens
from ruamel.yaml import YAML
from urllib.parse import urlsplit
from shutil import copy2
yaml = YAML()
call_logs = yaml.load(call_logs_file.read_text())
call_meta_dir: Path = output_dir / Path("call_metas")
call_meta_dir.mkdir(exist_ok=True, parents=True)
meta_dir.mkdir(exist_ok=True, parents=True)
def get_cid(uri):
return Path(urlsplit(uri).path).stem
def copy_meta(uri):
cid = get_cid(uri)
saved_meta_path = call_meta_dir / Path(f"{cid}.json")
dest_meta_path = meta_dir / Path(f"{cid}.json")
if not saved_meta_path.exists():
print(f"{saved_meta_path} not found")
copy2(saved_meta_path, dest_meta_path)
def download_meta_audio():
call_lens = lens["users"].Each()["calls"].Each()
call_lens.modify(copy_meta)(call_logs)
download_meta_audio()
class ExtractionType(str, Enum):
flow = "flow"
data = "data"
@app.command()
def analyze(
leaderboard: bool = False,
plot_calls: bool = False,
extract_data: bool = False,
extraction_type: ExtractionType = typer.Option(
ExtractionType.data, show_default=True
),
start_delay: float = 1.5,
download_only: bool = False,
strip_silent_chunks: bool = True,
call_logs_file: Path = typer.Option(Path("./call_logs.yaml"), show_default=True),
output_dir: Path = Path("./data"),
data_name: str = None,
mongo_uri: str = typer.Option(
"mongodb://localhost:27017/test.calls", show_default=True
),
):
from urllib.parse import urlsplit
from functools import reduce
import boto3
from io import BytesIO
import json
from ruamel.yaml import YAML
import re
from google.protobuf.timestamp_pb2 import Timestamp
from datetime import timedelta
import librosa
import librosa.display
from lenses import lens
from pprint import pprint
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib
from tqdm import tqdm
from .utils import ui_dump_manifest_writer, strip_silence, get_mongo_coll
from pydub import AudioSegment
from natural.date import compress
matplotlib.rcParams["agg.path.chunksize"] = 10000
matplotlib.use("agg")
yaml = YAML()
s3 = boto3.client("s3")
mongo_collection = get_mongo_coll(mongo_uri)
call_media_dir: Path = output_dir / Path("call_wavs")
call_media_dir.mkdir(exist_ok=True, parents=True)
call_meta_dir: Path = output_dir / Path("call_metas")
call_meta_dir.mkdir(exist_ok=True, parents=True)
call_plot_dir: Path = output_dir / Path("plots")
call_plot_dir.mkdir(exist_ok=True, parents=True)
call_asr_data: Path = output_dir / Path("asr_data")
call_asr_data.mkdir(exist_ok=True, parents=True)
dataset_name = call_logs_file.stem if not data_name else data_name
call_logs = yaml.load(call_logs_file.read_text())
def get_call_meta(call_obj):
meta_s3_uri = call_obj["DataURI"]
s3_event_url_p = urlsplit(meta_s3_uri)
saved_meta_path = call_meta_dir / Path(Path(s3_event_url_p.path).name)
if not saved_meta_path.exists():
print(f"downloading : {saved_meta_path} from {meta_s3_uri}")
s3.download_file(
s3_event_url_p.netloc, s3_event_url_p.path[1:], str(saved_meta_path)
)
call_metas = json.load(saved_meta_path.open())
return call_metas
def gen_ev_fev_timedelta(fev):
fev_p = Timestamp()
fev_p.FromJsonString(fev["CreatedTS"])
fev_dt = fev_p.ToDatetime()
td_0 = timedelta()
def get_timedelta(ev):
ev_p = Timestamp()
ev_p.FromJsonString(value=ev["CreatedTS"])
ev_dt = ev_p.ToDatetime()
delta = ev_dt - fev_dt
return delta if delta > td_0 else td_0
return get_timedelta
def chunk_n(evs, n):
return [evs[i * n : (i + 1) * n] for i in range((len(evs) + n - 1) // n)]
if extraction_type == ExtractionType.data:
def is_utter_event(ev):
return (
(ev["Author"] == "CONV" or ev["Author"] == "ASR")
and (ev["Type"] != "DEBUG")
and ev["Type"] != "ASR_RESULT"
)
def get_data_points(utter_events, td_fn):
data_points = []
for evs in chunk_n(utter_events, 3):
try:
assert evs[0]["Type"] == "CONV_RESULT"
assert evs[1]["Type"] == "STARTED_SPEAKING"
assert evs[2]["Type"] == "STOPPED_SPEAKING"
start_time = td_fn(evs[1]).total_seconds() - start_delay
end_time = td_fn(evs[2]).total_seconds()
spoken = evs[0]["Msg"]
data_points.append(
{"start_time": start_time, "end_time": end_time, "code": spoken}
)
except AssertionError:
# skipping invalid data_points
pass
return data_points
def text_extractor(spoken):
return (
re.search(r"'(.*)'", spoken).groups(0)[0]
if len(spoken) > 6 and re.search(r"'(.*)'", spoken)
else spoken
)
elif extraction_type == ExtractionType.flow:
def is_final_asr_event_or_spoken(ev):
pld = json.loads(ev["Payload"])
return (
pld["AsrResult"]["Results"][0]["IsFinal"]
if ev["Type"] == "ASR_RESULT"
else True
)
def is_utter_event(ev):
return (
ev["Author"] == "CONV"
or (ev["Author"] == "ASR" and is_final_asr_event_or_spoken(ev))
) and (ev["Type"] != "DEBUG")
def get_data_points(utter_events, td_fn):
data_points = []
for evs in chunk_n(utter_events, 4):
try:
assert len(evs) == 4
assert evs[0]["Type"] == "CONV_RESULT"
assert evs[1]["Type"] == "STARTED_SPEAKING"
assert evs[2]["Type"] == "ASR_RESULT"
assert evs[3]["Type"] == "STOPPED_SPEAKING"
start_time = td_fn(evs[1]).total_seconds() - start_delay
end_time = td_fn(evs[2]).total_seconds()
conv_msg = evs[0]["Msg"]
if "full name" in conv_msg.lower():
pld = json.loads(evs[2]["Payload"])
spoken = pld["AsrResult"]["Results"][0]["Alternatives"][0][
"Transcript"
]
data_points.append(
{
"start_time": start_time,
"end_time": end_time,
"code": spoken,
}
)
except AssertionError:
# skipping invalid data_points
pass
return data_points
def text_extractor(spoken):
return spoken
def process_call(call_obj):
call_meta = get_call_meta(call_obj)
call_events = call_meta["Events"]
def is_writer_uri_event(ev):
return ev["Author"] == "AUDIO_WRITER" and "s3://" in ev["Msg"]
writer_events = list(filter(is_writer_uri_event, call_events))
s3_wav_url = re.search(r"(s3://.*)", writer_events[0]["Msg"]).groups(0)[0]
s3_wav_url_p = urlsplit(s3_wav_url)
def is_first_audio_ev(state, ev):
if state[0]:
return state
else:
return (ev["Author"] == "GATEWAY" and ev["Type"] == "AUDIO", ev)
(_, first_audio_ev) = reduce(is_first_audio_ev, call_events, (False, {}))
get_ev_fev_timedelta = gen_ev_fev_timedelta(first_audio_ev)
uevs = list(filter(is_utter_event, call_events))
ev_count = len(uevs)
utter_events = uevs[: ev_count - ev_count % 3]
saved_wav_path = call_media_dir / Path(Path(s3_wav_url_p.path).name)
if not saved_wav_path.exists():
print(f"downloading : {saved_wav_path} from {s3_wav_url}")
s3.download_file(
s3_wav_url_p.netloc, s3_wav_url_p.path[1:], str(saved_wav_path)
)
return {
"wav_path": saved_wav_path,
"num_samples": len(utter_events) // 3,
"meta": call_obj,
"first_event_fn": get_ev_fev_timedelta,
"utter_events": utter_events,
}
def get_cid(uri):
return Path(urlsplit(uri).path).stem
def ensure_call(uri):
cid = get_cid(uri)
meta = mongo_collection.find_one({"SystemID": cid})
process_meta = process_call(meta)
return process_meta
def retrieve_processed_callmeta(uri):
cid = get_cid(uri)
meta = mongo_collection.find_one({"SystemID": cid})
duration = meta["EndTS"] - meta["StartTS"]
process_meta = process_call(meta)
data_points = get_data_points(
process_meta["utter_events"], process_meta["first_event_fn"]
)
process_meta["data_points"] = data_points
return {"url": uri, "meta": meta, "duration": duration, "process": process_meta}
def retrieve_callmeta(call_uri):
uri = call_uri["call_uri"]
name = call_uri["name"]
cid = get_cid(uri)
meta = mongo_collection.find_one({"SystemID": cid})
duration = meta["EndTS"] - meta["StartTS"]
process_meta = process_call(meta)
data_points = get_data_points(
process_meta["utter_events"], process_meta["first_event_fn"]
)
process_meta["data_points"] = data_points
return {
"url": uri,
"name": name,
"meta": meta,
"duration": duration,
"process": process_meta,
}
def download_meta_audio():
call_lens = lens["users"].Each()["calls"].Each()
call_lens.modify(ensure_call)(call_logs)
def plot_calls_data():
def plot_data_points(y, sr, data_points, file_path):
plt.figure(figsize=(16, 12))
librosa.display.waveplot(y=y, sr=sr)
for dp in data_points:
start, end, code = dp["start_time"], dp["end_time"], dp["code"]
plt.axvspan(start, end, color="green", alpha=0.2)
text_pos = (start + end) / 2
plt.text(
text_pos,
0.25,
f"{code}",
rotation=90,
horizontalalignment="center",
verticalalignment="center",
)
plt.title("Datapoints")
plt.savefig(file_path, format="png")
return file_path
def plot_call(call_obj):
saved_wav_path, data_points, sys_id = (
call_obj["process"]["wav_path"],
call_obj["process"]["data_points"],
call_obj["meta"]["SystemID"],
)
file_path = call_plot_dir / Path(sys_id).with_suffix(".png")
if not file_path.exists():
print(f"plotting: {file_path}")
(y, sr) = librosa.load(saved_wav_path)
plot_data_points(y, sr, data_points, str(file_path))
return file_path
call_lens = lens["users"].Each()["calls"].Each()
call_stats = call_lens.modify(retrieve_processed_callmeta)(call_logs)
# call_plot_data = call_lens.collect()(call_stats)
call_plots = call_lens.modify(plot_call)(call_stats)
# with ThreadPoolExecutor(max_workers=20) as exe:
# print('starting all plot tasks')
# responses = [exe.submit(plot_call, w) for w in call_plot_data]
# print('submitted all plot tasks')
# call_plots = [r.result() for r in responses]
pprint(call_plots)
def extract_data_points():
if strip_silent_chunks:
def audio_process(seg):
return strip_silence(seg)
else:
def audio_process(seg):
return seg
def gen_data_values(saved_wav_path, data_points, caller_name):
call_seg = (
AudioSegment.from_wav(saved_wav_path)
.set_channels(1)
.set_sample_width(2)
.set_frame_rate(24000)
)
for dp_id, dp in enumerate(data_points):
start, end, spoken = dp["start_time"], dp["end_time"], dp["code"]
spoken_seg = audio_process(call_seg[start * 1000 : end * 1000])
spoken_fb = BytesIO()
spoken_seg.export(spoken_fb, format="wav")
spoken_wav = spoken_fb.getvalue()
# search for actual pnr code and handle plain codes as well
extracted_code = text_extractor(spoken)
if strip_silent_chunks and spoken_seg.duration_seconds < 0.5:
print(f'transcript chunk "{spoken}" contains no audio skipping.')
continue
yield extracted_code, spoken_seg.duration_seconds, spoken_wav, caller_name, spoken_seg
call_lens = lens["users"].Each()["calls"].Each()
def assign_user_call(uc):
return (
lens["calls"]
.Each()
.modify(lambda c: {"call_uri": c, "name": uc["name"]})(uc)
)
user_call_logs = lens["users"].Each().modify(assign_user_call)(call_logs)
call_stats = call_lens.modify(retrieve_callmeta)(user_call_logs)
call_objs = call_lens.collect()(call_stats)
def data_source():
for call_obj in tqdm(call_objs):
saved_wav_path, data_points, name = (
call_obj["process"]["wav_path"],
call_obj["process"]["data_points"],
call_obj["name"],
)
for dp in gen_data_values(saved_wav_path, data_points, name):
yield dp
ui_dump_manifest_writer(call_asr_data, dataset_name, data_source())
def show_leaderboard():
def compute_user_stats(call_stat):
n_samples = (
lens["calls"].Each()["process"]["num_samples"].get_monoid()(call_stat)
)
n_duration = lens["calls"].Each()["duration"].get_monoid()(call_stat)
return {
"num_samples": n_samples,
"duration": n_duration.total_seconds(),
"samples_rate": n_samples / n_duration.total_seconds(),
"duration_str": compress(n_duration, pad=" "),
"name": call_stat["name"],
}
call_lens = lens["users"].Each()["calls"].Each()
call_stats = call_lens.modify(retrieve_processed_callmeta)(call_logs)
user_stats = lens["users"].Each().modify(compute_user_stats)(call_stats)
leader_df = (
pd.DataFrame(user_stats["users"])
.sort_values(by=["duration"], ascending=False)
.reset_index(drop=True)
)
leader_df["rank"] = leader_df.index + 1
leader_board = leader_df.rename(
columns={
"rank": "Rank",
"num_samples": "Count",
"name": "Name",
"samples_rate": "SpeechRate",
"duration_str": "Duration",
}
)[["Rank", "Name", "Count", "Duration"]]
print(
"""ASR Dataset Leaderboard :
---------------------------------"""
)
print(leader_board.to_string(index=False))
if download_only:
download_meta_audio()
return
if leaderboard:
show_leaderboard()
if plot_calls:
plot_calls_data()
if extract_data:
extract_data_points()
def main():
app()
if __name__ == "__main__":
main()
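The analyze command above windows utterance events with the `chunk_n` helper (groups of 3 for data extraction, 4 for flow extraction) before asserting the expected event types in each group. The grouping itself is pure Python and can be exercised standalone:

```python
def chunk_n(evs, n):
    # split evs into consecutive groups of size n (last group may be shorter)
    return [evs[i * n : (i + 1) * n] for i in range((len(evs) + n - 1) // n)]

print(chunk_n(list(range(7)), 3))  # [[0, 1, 2], [3, 4, 5], [6]]
```

Note that `process_call` trims the event list to a multiple of 3 (`uevs[: ev_count - ev_count % 3]`), so in practice the trailing short group never reaches the assertions.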


@@ -1,68 +0,0 @@
import typer
from pathlib import Path
from random import randrange
from itertools import product
from math import floor

app = typer.Typer()


@app.command()
def export_conv_json(
    conv_src: Path = typer.Option(Path("./conv_data.json"), show_default=True),
    conv_dest: Path = typer.Option(Path("./data/conv_data.json"), show_default=True),
):
    from .utils import ExtendedPath

    conv_data = ExtendedPath(conv_src).read_json()
    days = [i for i in range(1, 32)]
    months = [
        "January",
        "February",
        "March",
        "April",
        "May",
        "June",
        "July",
        "August",
        "September",
        "October",
        "November",
        "December",
    ]

    # ordinal from https://stackoverflow.com/questions/9647202/ordinal-numbers-replacement
    def ordinal(n):
        return "%d%s" % (
            n,
            "tsnrhtdd"[(floor(n / 10) % 10 != 1) * (n % 10 < 4) * n % 10 :: 4],
        )

    def canon_vars(d, m):
        return [
            ordinal(d) + " " + m,
            m + " " + ordinal(d),
            ordinal(d) + " of " + m,
            m + " the " + ordinal(d),
            str(d) + " " + m,
            m + " " + str(d),
        ]

    day_months = [dm for d, m in product(days, months) for dm in canon_vars(d, m)]
    conv_data["dates"] = day_months

    def dates_data_gen():
        i = randrange(len(day_months))
        return day_months[i]

    ExtendedPath(conv_dest).write_json(conv_data)


def main():
    app()


if __name__ == "__main__":
    main()
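The `ordinal` helper above (credited to the linked Stack Overflow answer) packs the four English suffixes into one slice of `"tsnrhtdd"`; the "fix 11st to 11th in ordinal" commit in the log corresponds to the `floor(n / 10) % 10 != 1` guard for the 11/12/13 teens. A quick standalone check:

```python
from math import floor

def ordinal(n):
    # "1st", "2nd", "3rd", "4th", ... with 11/12/13 mapped to "th"
    return "%d%s" % (
        n,
        "tsnrhtdd"[(floor(n / 10) % 10 != 1) * (n % 10 < 4) * n % 10 :: 4],
    )

print(ordinal(1), ordinal(11), ordinal(21), ordinal(23))  # 1st 11th 21st 23rd
```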


@@ -23,7 +23,7 @@ def fixate_data(dataset_path: Path):
@app.command()
def augment_datasets(src_dataset_paths: List[Path], dest_dataset_path: Path):
def augment_data(src_dataset_paths: List[Path], dest_dataset_path: Path):
reader_list = []
abs_manifest_path = Path("abs_manifest.json")
for dataset_path in src_dataset_paths:
@@ -38,9 +38,9 @@ def augment_datasets(src_dataset_paths: List[Path], dest_dataset_path: Path):
def split_data(dataset_path: Path, test_size: float = 0.1):
manifest_path = dataset_path / Path("abs_manifest.json")
asr_data = list(asr_manifest_reader(manifest_path))
train_pnr, test_pnr = train_test_split(asr_data, test_size=test_size)
asr_manifest_writer(manifest_path.with_name("train_manifest.json"), train_pnr)
asr_manifest_writer(manifest_path.with_name("test_manifest.json"), test_pnr)
train_data, test_data = train_test_split(asr_data, test_size=test_size)
asr_manifest_writer(manifest_path.with_name("train_manifest.json"), train_data)
asr_manifest_writer(manifest_path.with_name("test_manifest.json"), test_data)
@app.command()
@@ -52,9 +52,9 @@ def validate_data(dataset_path: Path):
data_file = dataset_path / Path(mf_type)
print(f"validating {data_file}.")
with Path(data_file).open("r") as pf:
pnr_jsonl = pf.readlines()
data_jsonl = pf.readlines()
duration = 0
for (i, s) in enumerate(pnr_jsonl):
for (i, s) in enumerate(data_jsonl):
try:
d = json.loads(s)
duration += d["duration"]


@@ -0,0 +1,93 @@
from rastrik.proto.callrecord_pb2 import CallRecord
import gzip
from pydub import AudioSegment
from .utils import ui_dump_manifest_writer, strip_silence
import typer
from itertools import chain
from io import BytesIO
from pathlib import Path

app = typer.Typer()


@app.command()
def extract_manifest(
    call_log_dir: Path = Path("./data/call_audio"),
    output_dir: Path = Path("./data"),
    dataset_name: str = "grassroot_pizzahut_v1",
    caller_name: str = "grassroot",
    verbose: bool = False,
):
    call_asr_data: Path = output_dir / Path("asr_data")
    call_asr_data.mkdir(exist_ok=True, parents=True)

    def wav_pb2_generator(log_dir):
        for wav_path in log_dir.glob("**/*.wav"):
            if verbose:
                typer.echo(f"loading events for file {wav_path}")
            call_wav = AudioSegment.from_file_using_temporary_files(wav_path)
            meta_path = wav_path.with_suffix(".pb2.gz")
            yield call_wav, wav_path, meta_path

    def read_event(call_wav, log_file):
        call_wav_0, call_wav_1 = call_wav.split_to_mono()
        with gzip.open(log_file, "rb") as log_h:
            record_data = log_h.read()
        cr = CallRecord()
        cr.ParseFromString(record_data)
        first_audio_event_timestamp = next(
            (
                i
                for i in cr.events
                if i.WhichOneof("event_type") == "call_event"
                and i.call_event.WhichOneof("event_type") == "call_audio"
            )
        ).timestamp.ToDatetime()
        speech_events = [
            i
            for i in cr.events
            if i.WhichOneof("event_type") == "speech_event"
            and i.speech_event.WhichOneof("event_type") == "asr_final"
        ]
        previous_event_timestamp = (
            first_audio_event_timestamp - first_audio_event_timestamp
        )
        for index, speech_event in enumerate(speech_events):
            asr_final = speech_event.speech_event.asr_final
            speech_timestamp = speech_event.timestamp.ToDatetime()
            actual_timestamp = speech_timestamp - first_audio_event_timestamp
            start_time = previous_event_timestamp.total_seconds() * 1000
            end_time = actual_timestamp.total_seconds() * 1000
            audio_segment = strip_silence(call_wav_1[start_time:end_time])
            code_fb = BytesIO()
            audio_segment.export(code_fb, format="wav")
            wav_data = code_fb.getvalue()
            previous_event_timestamp = actual_timestamp
            duration = (end_time - start_time) / 1000
            yield asr_final, duration, wav_data, "grassroot", audio_segment

    def generate_call_asr_data():
        full_data = []
        total_duration = 0
        for wav, wav_path, pb2_path in wav_pb2_generator(call_log_dir):
            asr_data = read_event(wav, pb2_path)
            total_duration += wav.duration_seconds
            full_data.append(asr_data)
        n_calls = len(full_data)
        typer.echo(f"loaded {n_calls} calls of duration {total_duration}s")
        n_dps = ui_dump_manifest_writer(call_asr_data, dataset_name, chain(*full_data))
        typer.echo(f"written {n_dps} data points")

    generate_call_asr_data()


def main():
    app()


if __name__ == "__main__":
    main()


@@ -1,175 +0,0 @@
import typer
from itertools import chain
from io import BytesIO
from pathlib import Path
import re
app = typer.Typer()
@app.command()
def extract_data(
call_audio_dir: Path = typer.Option(Path("/dataset/rev/wavs"), show_default=True),
call_meta_dir: Path = typer.Option(Path("/dataset/rev/jsons"), show_default=True),
output_dir: Path = typer.Option(Path("./data"), show_default=True),
dataset_name: str = typer.Option("rev_transribed", show_default=True),
verbose: bool = False,
):
from pydub import AudioSegment
from .utils import ExtendedPath, asr_data_writer, strip_silence
from lenses import lens
import datetime
call_asr_data: Path = output_dir / Path("asr_data")
call_asr_data.mkdir(exist_ok=True, parents=True)
def wav_event_generator(call_audio_dir):
for wav_path in call_audio_dir.glob("**/*.wav"):
if verbose:
typer.echo(f"loading events for file {wav_path}")
call_wav = AudioSegment.from_file_using_temporary_files(wav_path)
rel_meta_path = wav_path.with_suffix(".json").relative_to(call_audio_dir)
meta_path = call_meta_dir / rel_meta_path
if meta_path.exists():
events = ExtendedPath(meta_path).read_json()
yield call_wav, wav_path, events
else:
if verbose:
typer.echo(f"missing json corresponding to {wav_path}")
def contains_asr(x):
return "AsrResult" in x
def channel(n):
def filter_func(ev):
return (
ev["AsrResult"]["Channel"] == n
if "Channel" in ev["AsrResult"]
else n == 0
)
return filter_func
def time_to_msecs(time_str):
return (
datetime.datetime.strptime(time_str, "%H:%M:%S,%f")
- datetime.datetime(1900, 1, 1)
).total_seconds() * 1000
def process_utterance_chunk(wav_seg, start_time, end_time, monologue):
# offset by 1sec left side to include vad? discarded audio
full_tscript_wav_seg = wav_seg[
time_to_msecs(start_time) - 1000 : time_to_msecs(end_time) # + 1000
]
tscript_wav_seg = strip_silence(full_tscript_wav_seg)
tscript_wav_fb = BytesIO()
tscript_wav_seg.export(tscript_wav_fb, format="wav")
tscript_wav = tscript_wav_fb.getvalue()
text = "".join(lens["elements"].Each()["value"].collect()(monologue))
text_clean = re.sub(r"\[.*\]", "", text)
return tscript_wav, tscript_wav_seg.duration_seconds, text_clean
def dual_asr_data_generator(wav_seg, wav_path, meta):
left_audio, right_audio = wav_seg.split_to_mono()
channel_map = {"Agent": right_audio, "Client": left_audio}
monologues = lens["monologues"].Each().collect()(meta)
for monologue in monologues:
# print(monologue["speaker_name"])
speaker_channel = channel_map.get(monologue["speaker_name"])
if not speaker_channel:
if verbose:
print(
f'unknown speaker tag {monologue["speaker_name"]} in wav:{wav_path} skipping.'
)
continue
try:
start_time = (
lens["elements"]
.Each()
.Filter(lambda x: "timestamp" in x)["timestamp"]
.collect()(monologue)[0]
)
end_time = (
lens["elements"]
.Each()
.Filter(lambda x: "end_timestamp" in x)["end_timestamp"]
.collect()(monologue)[-1]
)
except IndexError:
if verbose:
print(
f"error when loading timestamp events in wav:{wav_path} skipping."
)
continue
tscript_wav, seg_dur, text_clean = process_utterance_chunk(
speaker_channel, start_time, end_time, monologue
)
if seg_dur < 0.5:
if verbose:
print(
f'transcript chunk "{text_clean}" has under 0.5s of audio in {wav_path}; skipping.'
)
continue
yield text_clean, seg_dur, tscript_wav
def mono_asr_data_generator(wav_seg, wav_path, meta):
monologues = lens["monologues"].Each().collect()(meta)
for monologue in monologues:
try:
start_time = (
lens["elements"]
.Each()
.Filter(lambda x: "timestamp" in x)["timestamp"]
.collect()(monologue)[0]
)
end_time = (
lens["elements"]
.Each()
.Filter(lambda x: "end_timestamp" in x)["end_timestamp"]
.collect()(monologue)[-1]
)
except IndexError:
if verbose:
print(
f"error when loading timestamp events in wav:{wav_path} skipping."
)
continue
tscript_wav, seg_dur, text_clean = process_utterance_chunk(
wav_seg, start_time, end_time, monologue
)
if seg_dur < 0.5:
if verbose:
print(
f'transcript chunk "{text_clean}" has under 0.5s of audio in {wav_path}; skipping.'
)
continue
yield text_clean, seg_dur, tscript_wav
def generate_rev_asr_data():
full_asr_data = []
total_duration = 0
for wav, wav_path, ev in wav_event_generator(call_audio_dir):
if wav.channels > 2:
print(f"skipping audio with more than 2 channels {wav_path}")
continue
asr_data_generator = (
mono_asr_data_generator
if wav.channels == 1
else dual_asr_data_generator
)
asr_data = asr_data_generator(wav, wav_path, ev)
total_duration += wav.duration_seconds
full_asr_data.append(asr_data)
typer.echo(f"loaded {len(full_asr_data)} calls of duration {total_duration}s")
n_dps = asr_data_writer(call_asr_data, dataset_name, chain(*full_asr_data))
typer.echo(f"written {n_dps} data points")
generate_rev_asr_data()
def main():
app()
if __name__ == "__main__":
main()
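The `time_to_msecs` helper above converts Rev-style `HH:MM:SS,mmm` timestamps to milliseconds by subtracting `strptime`'s default epoch of 1900-01-01. A minimal standalone sketch of the same arithmetic:

```python
import datetime

def time_to_msecs(time_str):
    # parse "HH:MM:SS,mmm" and return the offset from midnight in milliseconds
    return (
        datetime.datetime.strptime(time_str, "%H:%M:%S,%f")
        - datetime.datetime(1900, 1, 1)
    ).total_seconds() * 1000

print(time_to_msecs("00:01:30,500"))  # → 90500.0
```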


@@ -1,52 +0,0 @@
from logging import getLogger
from google.cloud import texttospeech
LOGGER = getLogger("googletts")
class GoogleTTS(object):
def __init__(self):
self.client = texttospeech.TextToSpeechClient()
def text_to_speech(self, text: str, params: dict) -> bytes:
tts_input = texttospeech.types.SynthesisInput(ssml=text)
voice = texttospeech.types.VoiceSelectionParams(
language_code=params["language"], name=params["name"]
)
audio_config = texttospeech.types.AudioConfig(
audio_encoding=texttospeech.enums.AudioEncoding.LINEAR16,
sample_rate_hertz=params["sample_rate"],
)
response = self.client.synthesize_speech(tts_input, voice, audio_config)
audio_content = response.audio_content
return audio_content
@classmethod
def voice_list(cls):
"""Lists the available voices."""
client = cls().client
# Performs the list voices request
voices = client.list_voices()
results = []
for voice in voices.voices:
supported_eng_langs = [
lang for lang in voice.language_codes if lang[:2] == "en"
]
if len(supported_eng_langs) > 0:
lang = ",".join(supported_eng_langs)
else:
continue
ssml_gender = texttospeech.enums.SsmlVoiceGender(voice.ssml_gender)
results.append(
{
"name": voice.name,
"language": lang,
"gender": ssml_gender.name,
"engine": "wavenet" if "Wav" in voice.name else "standard",
"sample_rate": voice.natural_sample_rate_hertz,
}
)
return results
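`voice_list` above keeps only voices that advertise an English language code and tags names containing "Wav" as Wavenet voices. The filtering step in isolation, with hypothetical voice records standing in for the API response:

```python
# hypothetical voice records shaped like the list_voices() response
voices = [
    {"name": "en-US-Wavenet-A", "language_codes": ["en-US"]},
    {"name": "de-DE-Standard-B", "language_codes": ["de-DE"]},
]
results = []
for voice in voices:
    # keep only voices supporting at least one English language code
    eng_langs = [lang for lang in voice["language_codes"] if lang[:2] == "en"]
    if not eng_langs:
        continue
    results.append({
        "name": voice["name"],
        "language": ",".join(eng_langs),
        "engine": "wavenet" if "Wav" in voice["name"] else "standard",
    })
print([r["name"] for r in results])  # → ['en-US-Wavenet-A']
```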


@@ -1,26 +0,0 @@
"""
TTSClient Abstract Class
"""
from abc import ABC, abstractmethod
class TTSClient(ABC):
"""
Base class for TTS
"""
@abstractmethod
def text_to_speech(self, text: str, num_channels: int, sample_rate: int,
audio_encoding) -> bytes:
"""
convert text to bytes
Arguments:
text {[type]} -- text to convert
channel {[type]} -- output audio bytes channel setting
width {[type]} -- width of audio bytes
rate {[type]} -- rare for audio bytes
Returns:
[type] -- [description]
"""


@@ -1,62 +0,0 @@
# import io
# import sys
# import json
import argparse
import logging
from pathlib import Path
from .utils import random_pnr_generator, asr_data_writer
from .tts.googletts import GoogleTTS
from tqdm import tqdm
import random
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
def pnr_tts_streamer(count):
google_voices = GoogleTTS.voice_list()
gtts = GoogleTTS()
for pnr_code in tqdm(random_pnr_generator(count)):
tts_code = f'<speak><say-as interpret-as="verbatim">{pnr_code}</say-as></speak>'
param = random.choice(google_voices)
param["sample_rate"] = 24000
param["num_channels"] = 1
wav_data = gtts.text_to_speech(text=tts_code, params=param)
audio_dur = len(wav_data[44:]) / (2 * 24000)  # payload bytes / (2 bytes per sample * 24 kHz), skipping the 44-byte WAV header
yield pnr_code, audio_dur, wav_data
def generate_asr_data_fromtts(output_dir, dataset_name, count):
asr_data_writer(output_dir, dataset_name, pnr_tts_streamer(count))
def arg_parser():
prog = Path(__file__).stem
parser = argparse.ArgumentParser(
prog=prog, description="generates ASR training data"
)
parser.add_argument(
"--output_dir",
type=Path,
default=Path("./train/asr_data"),
help="directory to output asr data",
)
parser.add_argument(
"--count", type=int, default=3, help="number of datapoints to generate"
)
parser.add_argument(
"--dataset_name", type=str, default="pnr_data", help="name of the dataset"
)
return parser
def main():
parser = arg_parser()
args = parser.parse_args()
generate_asr_data_fromtts(**vars(args))
if __name__ == "__main__":
main()
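`pnr_tts_streamer` above estimates duration as `len(wav_data[44:]) / (2 * 24000)`: payload bytes divided by bytes-per-second for 16-bit mono audio behind a standard 44-byte WAV header. A self-contained check of that arithmetic using the stdlib `wave` module:

```python
import io
import wave

# build one second of 16-bit mono silence at 24 kHz
buf = io.BytesIO()
with wave.open(buf, "wb") as wf:
    wf.setnchannels(1)
    wf.setsampwidth(2)
    wf.setframerate(24000)
    wf.writeframes(b"\x00\x00" * 24000)
wav_data = buf.getvalue()

# same formula as pnr_tts_streamer: skip the 44-byte header,
# divide by bytes-per-second (2 bytes/sample * 24000 samples/s)
audio_dur = len(wav_data[44:]) / (2 * 24000)
print(audio_dur)  # → 1.0
```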


@@ -1,92 +0,0 @@
import pandas as pd
def compute_pnr_name_city():
data = pd.read_csv("./customer_utterance_processing/customer_provide_answer.csv")
def unique_pnr_count():
pnr_data = data[data["Input.Answer"] == "ZZZZZZ"]
unique_pnr_set = {
t
for n in range(1, 5)
for t in pnr_data[f"Answer.utterance-{n}"].tolist()
if "ZZZZZZ" in t
}
return len(unique_pnr_set)
def unique_name_count():
name_data = data[data["Input.Answer"] == "John Doe"]
unique_name_set = {
t
for n in range(1, 5)
for t in name_data[f"Answer.utterance-{n}"].tolist()
if "John Doe" in t
}
return len(unique_name_set)
def unique_city_count():
city_data = data[data["Input.Answer"] == "Heathrow Airport"]
unique_city_set = {
t
for n in range(1, 5)
for t in city_data[f"Answer.utterance-{n}"].tolist()
if "Heathrow Airport" in t
}
return len(unique_city_set)
def unique_entity_count(entity_template_tags):
# entity_data = data[data['Input.Prompt'] == entity_template_tag]
entity_data = data
unique_entity_set = {
t
for n in range(1, 5)
for t in entity_data[f"Answer.utterance-{n}"].tolist()
if any(et in t for et in entity_template_tags)
}
return len(unique_entity_set)
print('PNR', unique_pnr_count())
print('Name', unique_name_count())
print('City', unique_city_count())
print('Payment', unique_entity_count(['KPay', 'ZPay', 'Credit Card']))
def compute_date():
entity_template_tags = ['27 january', 'December 18']
data = pd.read_csv("./customer_utterance_processing/customer_provide_departure.csv")
# data.sample(10)
def unique_entity_count(entity_template_tags):
# entity_data = data[data['Input.Prompt'] == entity_template_tag]
entity_data = data
unique_entity_set = {
t
for n in range(1, 5)
for t in entity_data[f"Answer.utterance-{n}"].tolist()
if any(et in t for et in entity_template_tags)
}
return len(unique_entity_set)
print('Date', unique_entity_count(entity_template_tags))
def compute_option():
entity_template_tag = 'third'
data = pd.read_csv("./customer_utterance_processing/customer_provide_flight_selection.csv")
def unique_entity_count():
entity_data = data[data['Input.Prompt'] == entity_template_tag]
unique_entity_set = {
t
for n in range(1, 5)
for t in entity_data[f"Answer.utterance-{n}"].tolist()
if entity_template_tag in t
}
return len(unique_entity_set)
print('Option', unique_entity_count())
compute_pnr_name_city()
compute_date()
compute_option()
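Each `unique_*_count` above follows one pattern: gather the four `Answer.utterance-{n}` columns, keep utterances containing the template tag, and count distinct strings. The same pattern in plain Python over hypothetical rows:

```python
# hypothetical rows standing in for the pandas DataFrame
rows = [
    {"Answer.utterance-1": "my pnr is ZZZZZZ", "Answer.utterance-2": "ZZZZZZ please",
     "Answer.utterance-3": "it is ZZZZZZ", "Answer.utterance-4": "no idea"},
    {"Answer.utterance-1": "ZZZZZZ please", "Answer.utterance-2": "code ZZZZZZ",
     "Answer.utterance-3": "hello", "Answer.utterance-4": "the code is ZZZZZZ"},
]
# set comprehension over the four utterance columns, keeping tagged utterances
unique_set = {
    row[f"Answer.utterance-{n}"]
    for n in range(1, 5)
    for row in rows
    if "ZZZZZZ" in row[f"Answer.utterance-{n}"]
}
print(len(unique_set))  # → 5 ("ZZZZZZ please" appears twice)
```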


@@ -1,22 +1,20 @@
import numpy as np
import wave
import io
import os
import json
import wave
from pathlib import Path
from functools import partial
from uuid import uuid4
from concurrent.futures import ThreadPoolExecutor
import pymongo
from slugify import slugify
from uuid import uuid4
from num2words import num2words
from jasper.client import transcribe_gen
from nemo.collections.asr.metrics import word_error_rate
import matplotlib.pyplot as plt
import librosa
import librosa.display
from tqdm import tqdm
from functools import partial
from concurrent.futures import ThreadPoolExecutor
def manifest_str(path, dur, text):
@@ -36,27 +34,8 @@ def wav_bytes(audio_bytes, frame_rate=24000):
return wf_b.getvalue()
def random_pnr_generator(count=10000):
LENGTH = 3
# alphabet = list("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")
alphabet = list("ABCDEFGHIJKLMNOPQRSTUVWXYZ")
numeric = list("0123456789")
np_alphabet = np.array(alphabet, dtype="|S1")
np_numeric = np.array(numeric, dtype="|S1")
np_alpha_codes = np.random.choice(np_alphabet, [count, LENGTH])
np_num_codes = np.random.choice(np_numeric, [count, LENGTH])
np_code_seed = np.concatenate((np_alpha_codes, np_num_codes), axis=1).T
np.random.shuffle(np_code_seed)
np_codes = np_code_seed.T
codes = [(b"".join(np_codes[i])).decode("utf-8") for i in range(len(np_codes))]
return codes
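`random_pnr_generator` draws three random letters and three random digits per code, then shuffles the six character positions (one shared permutation for the whole batch). A simpler per-code analogue, shuffling each code independently, which keeps the same 3-letters-plus-3-digits invariant:

```python
import random

def random_pnr(length_each=3):
    # per-code analogue: three random letters plus three random digits,
    # positions shuffled independently for each code (unlike the batched
    # numpy version, which applies one permutation to all codes)
    chars = random.choices("ABCDEFGHIJKLMNOPQRSTUVWXYZ", k=length_each)
    chars += random.choices("0123456789", k=length_each)
    random.shuffle(chars)
    return "".join(chars)

print(random_pnr())  # e.g. "4QX7A1"
```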
def alnum_to_asr_tokens(text):
letters = " ".join(list(text))
num_tokens = [num2words(c) if "0" <= c <= "9" else c for c in letters]
return ("".join(num_tokens)).lower()
def tscript_uuid_fname(transcript):
return str(uuid4()) + "_" + slugify(transcript, max_length=8)
def asr_data_writer(output_dir, dataset_name, asr_data_source, verbose=False):
@@ -67,11 +46,11 @@ def asr_data_writer(output_dir, dataset_name, asr_data_source, verbose=False):
with asr_manifest.open("w") as mf:
print(f"writing manifest to {asr_manifest}")
for transcript, audio_dur, wav_data in asr_data_source:
fname = str(uuid4()) + "_" + slugify(transcript, max_length=8)
fname = tscript_uuid_fname(transcript)
audio_file = dataset_dir / Path("wav") / Path(fname).with_suffix(".wav")
audio_file.write_bytes(wav_data)
rel_pnr_path = audio_file.relative_to(dataset_dir)
manifest = manifest_str(str(rel_pnr_path), audio_dur, transcript)
rel_data_path = audio_file.relative_to(dataset_dir)
manifest = manifest_str(str(rel_data_path), audio_dur, transcript)
mf.write(manifest)
if verbose:
print(f"writing '{transcript}' of duration {audio_dur}")
@@ -79,23 +58,10 @@ def asr_data_writer(output_dir, dataset_name, asr_data_source, verbose=False):
return num_datapoints
def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=False):
def ui_data_generator(output_dir, dataset_name, asr_data_source, verbose=False):
dataset_dir = output_dir / Path(dataset_name)
(dataset_dir / Path("wav")).mkdir(parents=True, exist_ok=True)
ui_dump_file = dataset_dir / Path("ui_dump.json")
(dataset_dir / Path("wav_plots")).mkdir(parents=True, exist_ok=True)
asr_manifest = dataset_dir / Path("manifest.json")
num_datapoints = 0
ui_dump = {
"use_domain_asr": False,
"annotation_only": False,
"enable_plots": True,
"data": [],
}
data_funcs = []
transcriber_pretrained = transcribe_gen(asr_port=8044)
with asr_manifest.open("w") as mf:
print(f"writing manifest to {asr_manifest}")
def data_fn(
transcript,
@@ -106,17 +72,16 @@ def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=F
fname,
audio_path,
num_datapoints,
rel_pnr_path,
rel_data_path,
):
pretrained_result = transcriber_pretrained(aud_seg.raw_data)
pretrained_wer = word_error_rate([transcript], [pretrained_result])
wav_plot_path = (
dataset_dir / Path("wav_plots") / Path(fname).with_suffix(".png")
)
png_path = Path(fname).with_suffix(".png")
wav_plot_path = dataset_dir / Path("wav_plots") / png_path
if not wav_plot_path.exists():
plot_seg(wav_plot_path, audio_path)
return {
"audio_filepath": str(rel_pnr_path),
"audio_filepath": str(rel_data_path),
"duration": round(audio_dur, 1),
"text": transcript,
"real_idx": num_datapoints,
@@ -129,14 +94,15 @@ def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=F
"plot_path": str(wav_plot_path),
}
num_datapoints = 0
data_funcs = []
transcriber_pretrained = transcribe_gen(asr_port=8044)
for transcript, audio_dur, wav_data, caller_name, aud_seg in asr_data_source:
fname = str(uuid4()) + "_" + slugify(transcript, max_length=8)
audio_file = dataset_dir / Path("wav") / Path(fname).with_suffix(".wav")
audio_file.write_bytes(wav_data)
audio_path = str(audio_file)
rel_pnr_path = audio_file.relative_to(dataset_dir)
manifest = manifest_str(str(rel_pnr_path), audio_dur, transcript)
mf.write(manifest)
rel_data_path = audio_file.relative_to(dataset_dir)
data_funcs.append(
partial(
data_fn,
@@ -148,33 +114,43 @@ def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=F
fname,
audio_path,
num_datapoints,
rel_pnr_path,
rel_data_path,
)
)
num_datapoints += 1
with ThreadPoolExecutor() as exe:
print("starting all plot/transcription tasks")
dump_data = list(
tqdm(
exe.map(lambda x: x(), data_funcs),
position=0,
leave=True,
total=len(data_funcs),
ui_data = parallel_apply(lambda x: x(), data_funcs)
return ui_data, num_datapoints
def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=False):
dataset_dir = output_dir / Path(dataset_name)
dump_data, num_datapoints = ui_data_generator(
output_dir, dataset_name, asr_data_source, verbose=verbose
)
)
ui_dump["data"] = dump_data
ExtendedPath(ui_dump_file).write_json(ui_dump)
asr_manifest = dataset_dir / Path("manifest.json")
with asr_manifest.open("w") as mf:
print(f"writing manifest to {asr_manifest}")
for d in dump_data:
rel_data_path = d["audio_filepath"]
audio_dur = d["duration"]
transcript = d["text"]
manifest = manifest_str(str(rel_data_path), audio_dur, transcript)
mf.write(manifest)
ui_dump_file = dataset_dir / Path("ui_dump.json")
ExtendedPath(ui_dump_file).write_json({"data": dump_data})
return num_datapoints
def asr_manifest_reader(data_manifest_path: Path):
print(f"reading manifest from {data_manifest_path}")
with data_manifest_path.open("r") as pf:
pnr_jsonl = pf.readlines()
pnr_data = [json.loads(v) for v in pnr_jsonl]
for p in pnr_data:
data_jsonl = pf.readlines()
data_data = [json.loads(v) for v in data_jsonl]
for p in data_data:
p["audio_path"] = data_manifest_path.parent / Path(p["audio_filepath"])
p["chars"] = Path(p["audio_filepath"]).stem
p["text"] = p["text"].strip()
yield p
@@ -188,6 +164,32 @@ def asr_manifest_writer(asr_manifest_path: Path, manifest_str_source):
mf.write(manifest)
def asr_test_writer(out_file_path: Path, source):
def dd_str(dd, idx):
path = dd["audio_filepath"]
# dur = dd["duration"]
# return f"SAY {idx}\nPAUSE 3\nPLAY {path}\nPAUSE 3\n\n"
return f"PAUSE 2\nPLAY {path}\nPAUSE 60\n\n"
res_file = out_file_path.with_suffix(".result.json")
with out_file_path.open("w") as of:
print(f"opening {out_file_path} for writing test")
results = []
idx = 0
for ui_dd in source:
results.append(ui_dd)
out_str = dd_str(ui_dd, idx)
of.write(out_str)
idx += 1
of.write("DO_HANGUP\n")
ExtendedPath(res_file).write_json(results)
def batch(iterable, n=1):
ls = len(iterable)
return [iterable[ndx : min(ndx + n, ls)] for ndx in range(0, ls, n)]
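`batch` slices an indexable sequence into consecutive chunks of at most `n` items; the last chunk may be shorter:

```python
def batch(iterable, n=1):
    # split an indexable sequence into consecutive chunks of size <= n
    ls = len(iterable)
    return [iterable[ndx : min(ndx + n, ls)] for ndx in range(0, ls, n)]

print(batch(list(range(5)), 2))  # → [[0, 1], [2, 3], [4]]
```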
class ExtendedPath(type(Path())):
"""docstring for ExtendedPath."""
@@ -203,12 +205,6 @@ class ExtendedPath(type(Path())):
return json.dump(data, jf, indent=2)
def get_mongo_coll(uri="mongodb://localhost:27017/test.calls"):
ud = pymongo.uri_parser.parse_uri(uri)
conn = pymongo.MongoClient(uri)
return conn[ud["database"]][ud["collection"]]
def get_mongo_conn(host="", port=27017, db="test", col="calls"):
mongo_host = host if host else os.environ.get("MONGO_HOST", "localhost")
mongo_uri = f"mongodb://{mongo_host}:{port}/"
@@ -234,10 +230,12 @@ def plot_seg(wav_plot_path, audio_path):
fig.savefig(wav_plot_f, format="png", dpi=50)
def main():
for c in random_pnr_generator():
print(c)
if __name__ == "__main__":
main()
def parallel_apply(fn, iterable, workers=8):
with ThreadPoolExecutor(max_workers=workers) as exe:
print(f"parallelly applying {fn}")
return [
res
for res in tqdm(
exe.map(fn, iterable), position=0, leave=True, total=len(iterable)
)
]
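`parallel_apply` wraps `ThreadPoolExecutor.map` in a tqdm progress bar and preserves input order; stripped of the progress bar, the pattern reduces to:

```python
from concurrent.futures import ThreadPoolExecutor

def parallel_apply(fn, iterable, workers=8):
    # apply fn to every element on a thread pool; map preserves input order
    with ThreadPoolExecutor(max_workers=workers) as exe:
        return list(exe.map(fn, iterable))

print(parallel_apply(lambda x: x * x, [1, 2, 3, 4]))  # → [1, 4, 9, 16]
```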


@@ -1,16 +1,14 @@
import json
import shutil
from pathlib import Path
from enum import Enum
import typer
from tqdm import tqdm
from ..utils import (
alnum_to_asr_tokens,
ExtendedPath,
asr_manifest_reader,
asr_manifest_writer,
tscript_uuid_fname,
get_mongo_conn,
plot_seg,
)
@@ -18,9 +16,7 @@ from ..utils import (
app = typer.Typer()
def preprocess_datapoint(
idx, rel_root, sample, use_domain_asr, annotation_only, enable_plots
):
def preprocess_datapoint(idx, rel_root, sample):
from pydub import AudioSegment
from nemo.collections.asr.metrics import word_error_rate
from jasper.client import transcribe_gen
@@ -30,12 +26,7 @@ def preprocess_datapoint(
res["real_idx"] = idx
audio_path = rel_root / Path(sample["audio_filepath"])
res["audio_path"] = str(audio_path)
if use_domain_asr:
res["spoken"] = alnum_to_asr_tokens(res["text"])
else:
res["spoken"] = res["text"]
res["utterance_id"] = audio_path.stem
if not annotation_only:
transcriber_pretrained = transcribe_gen(asr_port=8044)
aud_seg = (
@@ -45,16 +36,7 @@ def preprocess_datapoint(
.set_frame_rate(24000)
)
res["pretrained_asr"] = transcriber_pretrained(aud_seg.raw_data)
res["pretrained_wer"] = word_error_rate(
[res["text"]], [res["pretrained_asr"]]
)
if use_domain_asr:
transcriber_speller = transcribe_gen(asr_port=8045)
res["domain_asr"] = transcriber_speller(aud_seg.raw_data)
res["domain_wer"] = word_error_rate(
[res["spoken"]], [res["pretrained_asr"]]
)
if enable_plots:
res["pretrained_wer"] = word_error_rate([res["text"]], [res["pretrained_asr"]])
wav_plot_path = (
rel_root / Path("wav_plots") / Path(audio_path.name).with_suffix(".png")
)
@@ -68,70 +50,59 @@ def preprocess_datapoint(
@app.command()
def dump_ui(
data_name: str = typer.Option("call_alphanum", show_default=True),
data_name: str = typer.Option("dataname", show_default=True),
dataset_dir: Path = Path("./data/asr_data"),
dump_dir: Path = Path("./data/valiation_data"),
dump_fname: Path = typer.Option(Path("ui_dump.json"), show_default=True),
use_domain_asr: bool = False,
annotation_only: bool = False,
enable_plots: bool = True,
):
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from io import BytesIO
from pydub import AudioSegment
from ..utils import ui_data_generator
data_manifest_path = dataset_dir / Path(data_name) / Path("manifest.json")
dump_path: Path = dump_dir / Path(data_name) / dump_fname
plot_dir = data_manifest_path.parent / Path("wav_plots")
plot_dir.mkdir(parents=True, exist_ok=True)
typer.echo(f"Using data manifest:{data_manifest_path}")
def asr_data_source_gen():
with data_manifest_path.open("r") as pf:
pnr_jsonl = pf.readlines()
pnr_funcs = [
partial(
preprocess_datapoint,
i,
data_manifest_path.parent,
json.loads(v),
use_domain_asr,
annotation_only,
enable_plots,
data_jsonl = pf.readlines()
for v in data_jsonl:
sample = json.loads(v)
rel_root = data_manifest_path.parent
res = dict(sample)
audio_path = rel_root / Path(sample["audio_filepath"])
audio_segment = (
AudioSegment.from_file_using_temporary_files(audio_path)
.set_channels(1)
.set_sample_width(2)
.set_frame_rate(24000)
)
for i, v in enumerate(pnr_jsonl)
]
wav_plot_path = (
rel_root
/ Path("wav_plots")
/ Path(audio_path.name).with_suffix(".png")
)
if not wav_plot_path.exists():
plot_seg(wav_plot_path, audio_path)
res["plot_path"] = str(wav_plot_path)
code_fb = BytesIO()
audio_segment.export(code_fb, format="wav")
wav_data = code_fb.getvalue()
duration = audio_segment.duration_seconds
asr_final = res["text"]
yield asr_final, duration, wav_data, "caller", audio_segment
def exec_func(f):
return f()
with ThreadPoolExecutor() as exe:
print("starting all preprocess tasks")
pnr_data = filter(
None,
list(
tqdm(
exe.map(exec_func, pnr_funcs),
position=0,
leave=True,
total=len(pnr_funcs),
dump_data, num_datapoints = ui_data_generator(
dataset_dir, data_name, asr_data_source_gen()
)
),
)
if annotation_only:
result = list(pnr_data)
else:
wer_key = "domain_wer" if use_domain_asr else "pretrained_wer"
result = sorted(pnr_data, key=lambda x: x[wer_key], reverse=True)
ui_config = {
"use_domain_asr": use_domain_asr,
"annotation_only": annotation_only,
"enable_plots": enable_plots,
"data": result,
}
ExtendedPath(dump_path).write_json(ui_config)
ui_dump_file = dataset_dir / Path("ui_dump.json")
ExtendedPath(ui_dump_file).write_json({"data": dump_data})
@app.command()
def sample_ui(
data_name: str = typer.Option("call_upwork_train_cnd", show_default=True),
data_name: str = typer.Option("dataname", show_default=True),
dump_dir: Path = Path("./data/asr_data"),
dump_file: Path = Path("ui_dump.json"),
sample_count: int = typer.Option(80, show_default=True),
@@ -157,7 +128,7 @@ def sample_ui(
@app.command()
def task_ui(
data_name: str = typer.Option("call_upwork_train_cnd", show_default=True),
data_name: str = typer.Option("dataname", show_default=True),
dump_dir: Path = Path("./data/asr_data"),
dump_file: Path = Path("ui_dump.json"),
task_count: int = typer.Option(4, show_default=True),
@@ -180,14 +151,18 @@ def task_ui(
@app.command()
def dump_corrections(
data_name: str = typer.Option("call_alphanum", show_default=True),
task_uid: str,
data_name: str = typer.Option("dataname", show_default=True),
dump_dir: Path = Path("./data/asr_data"),
dump_fname: Path = Path("corrections.json"),
):
dump_path = dump_dir / Path(data_name) / dump_fname
col = get_mongo_conn(col="asr_validation")
cursor_obj = col.find({"type": "correction"}, projection={"_id": False})
task_id = [c for c in col.distinct("task_id") if c.rsplit("-", 1)[1] == task_uid][0]
corrections = list(col.find({"type": "correction"}, projection={"_id": False}))
cursor_obj = col.find(
{"type": "correction", "task_id": task_id}, projection={"_id": False}
)
corrections = [c for c in cursor_obj]
ExtendedPath(dump_path).write_json(corrections)
@@ -195,7 +170,7 @@ def dump_corrections(
@app.command()
def caller_quality(
task_uid: str,
data_name: str = typer.Option("call_upwork_train_cnd", show_default=True),
data_name: str = typer.Option("dataname", show_default=True),
dump_dir: Path = Path("./data/asr_data"),
dump_fname: Path = Path("ui_dump.json"),
correction_fname: Path = Path("corrections.json"),
@@ -231,7 +206,7 @@ def caller_quality(
@app.command()
def fill_unannotated(
data_name: str = typer.Option("call_alphanum", show_default=True),
data_name: str = typer.Option("dataname", show_default=True),
dump_dir: Path = Path("./data/valiation_data"),
dump_file: Path = Path("ui_dump.json"),
corrections_file: Path = Path("corrections.json"),
@@ -252,71 +227,64 @@ def fill_unannotated(
)
class ExtractionType(str, Enum):
date = "dates"
city = "cities"
name = "names"
@app.command()
def split_extract(
data_name: str = typer.Option("call_alphanum", show_default=True),
data_name: str = typer.Option("dataname", show_default=True),
# dest_data_name: str = typer.Option("call_aldata_namephanum_date", show_default=True),
dump_dir: Path = Path("./data/valiation_data"),
# dump_dir: Path = Path("./data/valiation_data"),
dump_dir: Path = Path("./data/asr_data"),
dump_file: Path = Path("ui_dump.json"),
manifest_dir: Path = Path("./data/asr_data"),
manifest_file: Path = Path("manifest.json"),
corrections_file: Path = Path("corrections.json"),
conv_data_path: Path = Path("./data/conv_data.json"),
extraction_type: ExtractionType = ExtractionType.date,
corrections_file: str = typer.Option("corrections.json", show_default=True),
conv_data_path: Path = typer.Option(
Path("./data/conv_data.json"), show_default=True
),
extraction_type: str = "all",
):
import shutil
def get_conv_data(cdp):
from itertools import product
data_manifest_path = dump_dir / Path(data_name) / manifest_file
conv_data = ExtendedPath(conv_data_path).read_json()
conv_data = json.load(cdp.open())
days = [str(i) for i in range(1, 32)]
months = conv_data["months"]
day_months = {d + " " + m for d, m in product(days, months)}
return {
"cities": set(conv_data["cities"]),
"names": set(conv_data["names"]),
"dates": day_months,
}
dest_data_name = data_name + "_" + extraction_type.value
data_manifest_path = manifest_dir / Path(data_name) / manifest_file
conv_data = get_conv_data(conv_data_path)
extraction_vals = conv_data[extraction_type.value]
def extract_data_of_type(extraction_key):
extraction_vals = conv_data[extraction_key]
dest_data_name = data_name + "_" + extraction_key.lower()
manifest_gen = asr_manifest_reader(data_manifest_path)
dest_data_dir = manifest_dir / Path(dest_data_name)
dest_data_dir = dump_dir / Path(dest_data_name)
dest_data_dir.mkdir(exist_ok=True, parents=True)
(dest_data_dir / Path("wav")).mkdir(exist_ok=True, parents=True)
dest_manifest_path = dest_data_dir / manifest_file
dest_ui_dir = dump_dir / Path(dest_data_name)
dest_ui_dir.mkdir(exist_ok=True, parents=True)
dest_ui_path = dest_ui_dir / dump_file
dest_correction_path = dest_ui_dir / corrections_file
dest_ui_path = dest_data_dir / dump_file
def extract_manifest(mg):
for m in mg:
if m["text"] in extraction_vals:
shutil.copy(m["audio_path"], dest_data_dir / Path(m["audio_filepath"]))
shutil.copy(
m["audio_path"], dest_data_dir / Path(m["audio_filepath"])
)
yield m
asr_manifest_writer(dest_manifest_path, extract_manifest(manifest_gen))
ui_data_path = dump_dir / Path(data_name) / dump_file
corrections_path = dump_dir / Path(data_name) / corrections_file
ui_data = json.load(ui_data_path.open())["data"]
orig_ui_data = ExtendedPath(ui_data_path).read_json()
ui_data = orig_ui_data["data"]
file_ui_map = {Path(u["audio_filepath"]).stem: u for u in ui_data}
extracted_ui_data = list(
filter(lambda u: u["text"] in extraction_vals, ui_data)
)
final_data = []
for i, d in enumerate(extracted_ui_data):
d["real_idx"] = i
final_data.append(d)
orig_ui_data["data"] = final_data
ExtendedPath(dest_ui_path).write_json(orig_ui_data)
if corrections_file:
dest_correction_path = dest_data_dir / corrections_file
corrections_path = dump_dir / Path(data_name) / corrections_file
corrections = json.load(corrections_path.open())
extracted_ui_data = list(filter(lambda u: u["text"] in extraction_vals, ui_data))
ExtendedPath(dest_ui_path).write_json(extracted_ui_data)
extracted_corrections = list(
filter(
lambda c: c["code"] in file_ui_map
@@ -326,23 +294,29 @@ def split_extract(
)
ExtendedPath(dest_correction_path).write_json(extracted_corrections)
if extraction_type.value == "all":
for ext_key in conv_data.keys():
extract_data_of_type(ext_key)
else:
extract_data_of_type(extraction_type.value)
@app.command()
def update_corrections(
data_name: str = typer.Option("call_alphanum", show_default=True),
dump_dir: Path = Path("./data/valiation_data"),
manifest_dir: Path = Path("./data/asr_data"),
data_name: str = typer.Option("dataname", show_default=True),
dump_dir: Path = Path("./data/asr_data"),
manifest_file: Path = Path("manifest.json"),
corrections_file: Path = Path("corrections.json"),
# data_manifest_path: Path = Path("./data/asr_data/call_alphanum/manifest.json"),
# corrections_path: Path = Path("./data/valiation_data/corrections.json"),
skip_incorrect: bool = True,
ui_dump_file: Path = Path("ui_dump.json"),
skip_incorrect: bool = typer.Option(True, show_default=True),
):
data_manifest_path = manifest_dir / Path(data_name) / manifest_file
data_manifest_path = dump_dir / Path(data_name) / manifest_file
corrections_path = dump_dir / Path(data_name) / corrections_file
ui_dump_path = dump_dir / Path(data_name) / ui_dump_file
def correct_manifest(manifest_data_gen, corrections_path):
corrections = json.load(corrections_path.open())
def correct_manifest(ui_dump_path, corrections_path):
corrections = ExtendedPath(corrections_path).read_json()
ui_data = ExtendedPath(ui_dump_path).read_json()["data"]
correct_set = {
c["code"] for c in corrections if c["value"]["status"] == "Correct"
}
@@ -355,36 +329,40 @@ def update_corrections(
# for d in manifest_data_gen:
# if d["chars"] in incorrect_set:
# d["audio_path"].unlink()
renamed_set = set()
for d in manifest_data_gen:
if d["chars"] in correct_set:
# renamed_set = set()
for d in ui_data:
if d["utterance_id"] in correct_set:
yield {
"audio_filepath": d["audio_filepath"],
"duration": d["duration"],
"text": d["text"],
}
elif d["chars"] in correction_map:
correct_text = correction_map[d["chars"]]
elif d["utterance_id"] in correction_map:
correct_text = correction_map[d["utterance_id"]]
if skip_incorrect:
print(
f'skipping incorrect {d["audio_path"]} corrected to {correct_text}'
)
else:
renamed_set.add(correct_text)
new_name = str(Path(correct_text).with_suffix(".wav"))
d["audio_path"].replace(d["audio_path"].with_name(new_name))
orig_audio_path = Path(d["audio_path"])
new_name = str(
Path(tscript_uuid_fname(correct_text)).with_suffix(".wav")
)
new_audio_path = orig_audio_path.with_name(new_name)
orig_audio_path.replace(new_audio_path)
new_filepath = str(Path(d["audio_filepath"]).with_name(new_name))
yield {
"audio_filepath": new_filepath,
"duration": d["duration"],
"text": alnum_to_asr_tokens(correct_text),
"text": correct_text,
}
else:
orig_audio_path = Path(d["audio_path"])
# don't delete if another correction points to an old file
if d["chars"] not in renamed_set:
d["audio_path"].unlink()
else:
print(f'skipping deletion of correction:{d["chars"]}')
# if d["text"] not in renamed_set:
orig_audio_path.unlink()
# else:
# print(f'skipping deletion of correction:{d["text"]}')
typer.echo(f"Using data manifest:{data_manifest_path}")
dataset_dir = data_manifest_path.parent
@@ -393,8 +371,8 @@ def update_corrections(
if not backup_dir.exists():
typer.echo(f"backing up to :{backup_dir}")
shutil.copytree(str(dataset_dir), str(backup_dir))
manifest_gen = asr_manifest_reader(data_manifest_path)
corrected_manifest = correct_manifest(manifest_gen, corrections_path)
# manifest_gen = asr_manifest_reader(data_manifest_path)
corrected_manifest = correct_manifest(ui_dump_path, corrections_path)
new_data_manifest_path = data_manifest_path.with_name("manifest.new")
asr_manifest_writer(new_data_manifest_path, corrected_manifest)
new_data_manifest_path.replace(data_manifest_path)


@@ -42,7 +42,9 @@ if not hasattr(st, "mongo_connected"):
upsert=True,
)
def set_task_fn(mf_path):
def set_task_fn(mf_path, task_id):
if task_id:
st.task_id = task_id
task_path = mf_path.parent / Path(f"task-{st.task_id}.lck")
if not task_path.exists():
print(f"creating task lock at {task_path}")
@@ -66,26 +68,22 @@ def load_ui_data(validation_ui_data_path: Path):
@app.command()
def main(manifest: Path):
st.set_task(manifest)
def main(manifest: Path, task_id: str = ""):
st.set_task(manifest, task_id)
ui_config = load_ui_data(manifest)
asr_data = ui_config["data"]
use_domain_asr = ui_config.get("use_domain_asr", True)
annotation_only = ui_config.get("annotation_only", False)
enable_plots = ui_config.get("enable_plots", True)
sample_no = st.get_current_cursor()
if len(asr_data) - 1 < sample_no or sample_no < 0:
print("Invalid samplno resetting to 0")
st.update_cursor(0)
sample = asr_data[sample_no]
title_type = "Speller " if use_domain_asr else ""
task_uid = st.task_id.rsplit("-", 1)[1]
if annotation_only:
st.title(f"ASR Annotation - # {task_uid}")
else:
st.title(f"ASR {title_type}Validation - # {task_uid}")
addl_text = f"spelled *{sample['spoken']}*" if use_domain_asr else ""
st.markdown(f"{sample_no+1} of {len(asr_data)} : **{sample['text']}**" + addl_text)
st.title(f"ASR Validation - # {task_uid}")
st.markdown(f"{sample_no+1} of {len(asr_data)} : **{sample['text']}**")
new_sample = st.number_input(
"Go To Sample:", value=sample_no + 1, min_value=1, max_value=len(asr_data)
)
@@ -94,18 +92,12 @@ def main(manifest: Path):
st.sidebar.title(f"Details: [{sample['real_idx']}]")
st.sidebar.markdown(f"Gold Text: **{sample['text']}**")
if not annotation_only:
if use_domain_asr:
st.sidebar.markdown(f"Expected Spelled: *{sample['spoken']}*")
st.sidebar.title("Results:")
st.sidebar.markdown(f"Pretrained: **{sample['pretrained_asr']}**")
if "caller" in sample:
st.sidebar.markdown(f"Caller: **{sample['caller']}**")
if use_domain_asr:
st.sidebar.markdown(f"Domain: **{sample['domain_asr']}**")
st.sidebar.title(f"Speller WER: {sample['domain_wer']:.2f}%")
else:
st.sidebar.title(f"Pretrained WER: {sample['pretrained_wer']:.2f}%")
if enable_plots:
st.sidebar.image(Path(sample["plot_path"]).read_bytes())
st.audio(Path(sample["audio_path"]).open("rb"))
# set default to text
@@ -128,16 +120,12 @@ def main(manifest: Path):
)
st.update_cursor(sample_no + 1)
if correction_entry:
st.markdown(
f'Your Response: **{correction_entry["value"]["status"]}** Correction: **{correction_entry["value"]["correction"]}**'
)
status = correction_entry["value"]["status"]
correction = correction_entry["value"]["correction"]
st.markdown(f"Your Response: **{status}** Correction: **{correction}**")
text_sample = st.text_input("Go to Text:", value="")
if text_sample != "":
candidates = [
i
for (i, p) in enumerate(asr_data)
if p["text"] == text_sample or p["spoken"] == text_sample
]
candidates = [i for (i, p) in enumerate(asr_data) if p["text"] == text_sample]
if len(candidates) > 0:
st.update_cursor(candidates[0])
real_idx = st.number_input(
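The cursor bounds check near the top of this file (the "Invalid sample_no" branch) can be sketched as a standalone helper. `clamp_cursor` is a hypothetical name for illustration, not part of the validation UI:

```python
def clamp_cursor(sample_no: int, n_samples: int) -> int:
    """Reset the cursor to 0 when it points outside the loaded data,
    mirroring the bounds check in the validation UI above."""
    if n_samples - 1 < sample_no or sample_no < 0:
        return 0
    return sample_no


print(clamp_cursor(3, 10))   # in range, kept as-is
print(clamp_cursor(10, 10))  # past the last index, reset to 0
print(clamp_cursor(-1, 10))  # negative, reset to 0
```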

jasper/evaluate.py (new file, 359 lines)

@@ -0,0 +1,359 @@
# Copyright (c) 2019 NVIDIA Corporation
import argparse
import copy
# import math
import os
from pathlib import Path
from functools import partial
from ruamel.yaml import YAML
import nemo
import nemo.collections.asr as nemo_asr
import nemo.utils.argparse as nm_argparse
from nemo.collections.asr.helpers import (
# monitor_asr_train_progress,
process_evaluation_batch,
process_evaluation_epoch,
)
# from nemo.utils.lr_policies import CosineAnnealing
from training.data_loaders import RpycAudioToTextDataLayer
logging = nemo.logging
def parse_args():
parser = argparse.ArgumentParser(
parents=[nm_argparse.NemoArgParser()],
description="Jasper",
conflict_handler="resolve",
)
parser.set_defaults(
checkpoint_dir=None,
optimizer="novograd",
batch_size=64,
eval_batch_size=64,
lr=0.002,
amp_opt_level="O1",
create_tb_writer=True,
model_config="./train/jasper10x5dr.yaml",
work_dir="./train/work",
num_epochs=300,
weight_decay=0.005,
checkpoint_save_freq=100,
eval_freq=100,
load_dir="./train/models/jasper/",
warmup_steps=3,
exp_name="jasper-speller",
)
# Overwrite default args
parser.add_argument(
"--max_steps",
type=int,
default=None,
required=False,
help="max number of steps to train",
)
parser.add_argument(
"--num_epochs", type=int, required=False, help="number of epochs to train"
)
parser.add_argument(
"--model_config",
type=str,
required=False,
help="model configuration file: model.yaml",
)
parser.add_argument(
"--encoder_checkpoint",
type=str,
required=True,
help="encoder checkpoint file: JasperEncoder.pt",
)
parser.add_argument(
"--decoder_checkpoint",
type=str,
required=True,
help="decoder checkpoint file: JasperDecoderForCTC.pt",
)
parser.add_argument(
"--remote_data",
type=str,
required=False,
default="",
help="remote dataloader endpoint",
)
parser.add_argument(
"--dataset",
type=str,
required=False,
default="",
help="dataset directory containing train/test manifests",
)
# Create new args
parser.add_argument("--exp_name", default="Jasper", type=str)
parser.add_argument("--beta1", default=0.95, type=float)
parser.add_argument("--beta2", default=0.25, type=float)
parser.add_argument("--warmup_steps", default=0, type=int)
parser.add_argument(
"--load_dir",
default=None,
type=str,
help="directory with pre-trained checkpoint",
)
args = parser.parse_args()
if args.max_steps is None and args.num_epochs is None:
raise ValueError("Either max_steps or num_epochs should be provided.")
return args
def construct_name(
name, lr, batch_size, max_steps, num_epochs, wd, optimizer, iter_per_step
):
if max_steps is not None:
return "{0}-lr_{1}-bs_{2}-s_{3}-wd_{4}-opt_{5}-ips_{6}".format(
name, lr, batch_size, max_steps, wd, optimizer, iter_per_step
)
else:
return "{0}-lr_{1}-bs_{2}-e_{3}-wd_{4}-opt_{5}-ips_{6}".format(
name, lr, batch_size, num_epochs, wd, optimizer, iter_per_step
)
def create_all_dags(args, neural_factory):
yaml = YAML(typ="safe")
with open(args.model_config) as f:
jasper_params = yaml.load(f)
vocab = jasper_params["labels"]
sample_rate = jasper_params["sample_rate"]
# Calculate num_workers for dataloader
total_cpus = os.cpu_count()
cpu_per_traindl = max(int(total_cpus / neural_factory.world_size), 1)
# perturb_config = jasper_params.get('perturb', None)
train_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"])
train_dl_params.update(jasper_params["AudioToTextDataLayer"]["train"])
del train_dl_params["train"]
del train_dl_params["eval"]
# del train_dl_params["normalize_transcripts"]
if args.dataset:
d_path = Path(args.dataset)
if not args.train_dataset:
args.train_dataset = str(d_path / Path("train_manifest.json"))
if not args.eval_datasets:
args.eval_datasets = [str(d_path / Path("test_manifest.json"))]
data_loader_layer = nemo_asr.AudioToTextDataLayer
if args.remote_data:
train_dl_params["rpyc_host"] = args.remote_data
data_loader_layer = RpycAudioToTextDataLayer
# data_layer = data_loader_layer(
# manifest_filepath=args.train_dataset,
# sample_rate=sample_rate,
# labels=vocab,
# batch_size=args.batch_size,
# num_workers=cpu_per_traindl,
# **train_dl_params,
# # normalize_transcripts=False
# )
#
# N = len(data_layer)
# steps_per_epoch = math.ceil(
# N / (args.batch_size * args.iter_per_step * args.num_gpus)
# )
# logging.info("Have {0} examples to train on.".format(N))
#
data_preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor(
sample_rate=sample_rate, **jasper_params["AudioToMelSpectrogramPreprocessor"]
)
# multiply_batch_config = jasper_params.get("MultiplyBatch", None)
# if multiply_batch_config:
# multiply_batch = nemo_asr.MultiplyBatch(**multiply_batch_config)
#
# spectr_augment_config = jasper_params.get("SpectrogramAugmentation", None)
# if spectr_augment_config:
# data_spectr_augmentation = nemo_asr.SpectrogramAugmentation(
# **spectr_augment_config
# )
#
eval_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"])
eval_dl_params.update(jasper_params["AudioToTextDataLayer"]["eval"])
if args.remote_data:
eval_dl_params["rpyc_host"] = args.remote_data
del eval_dl_params["train"]
del eval_dl_params["eval"]
data_layers_eval = []
# if args.eval_datasets:
for eval_datasets in args.eval_datasets:
data_layer_eval = data_loader_layer(
manifest_filepath=eval_datasets,
sample_rate=sample_rate,
labels=vocab,
batch_size=args.eval_batch_size,
num_workers=cpu_per_traindl,
**eval_dl_params,
)
data_layers_eval.append(data_layer_eval)
# else:
# logging.warning("There were no val datasets passed")
jasper_encoder = nemo_asr.JasperEncoder(
feat_in=jasper_params["AudioToMelSpectrogramPreprocessor"]["features"],
**jasper_params["JasperEncoder"],
)
jasper_encoder.restore_from(args.encoder_checkpoint, local_rank=0)
jasper_decoder = nemo_asr.JasperDecoderForCTC(
feat_in=jasper_params["JasperEncoder"]["jasper"][-1]["filters"],
num_classes=len(vocab),
)
jasper_decoder.restore_from(args.decoder_checkpoint, local_rank=0)
ctc_loss = nemo_asr.CTCLossNM(num_classes=len(vocab))
greedy_decoder = nemo_asr.GreedyCTCDecoder()
# logging.info("================================")
# logging.info(f"Number of parameters in encoder: {jasper_encoder.num_weights}")
# logging.info(f"Number of parameters in decoder: {jasper_decoder.num_weights}")
# logging.info(
# f"Total number of parameters in model: "
# f"{jasper_decoder.num_weights + jasper_encoder.num_weights}"
# )
# logging.info("================================")
#
# # Train DAG
# (audio_signal_t, a_sig_length_t, transcript_t, transcript_len_t) = data_layer()
# processed_signal_t, p_length_t = data_preprocessor(
# input_signal=audio_signal_t, length=a_sig_length_t
# )
#
# if multiply_batch_config:
# (
# processed_signal_t,
# p_length_t,
# transcript_t,
# transcript_len_t,
# ) = multiply_batch(
# in_x=processed_signal_t,
# in_x_len=p_length_t,
# in_y=transcript_t,
# in_y_len=transcript_len_t,
# )
#
# if spectr_augment_config:
# processed_signal_t = data_spectr_augmentation(input_spec=processed_signal_t)
#
# encoded_t, encoded_len_t = jasper_encoder(
# audio_signal=processed_signal_t, length=p_length_t
# )
# log_probs_t = jasper_decoder(encoder_output=encoded_t)
# predictions_t = greedy_decoder(log_probs=log_probs_t)
# loss_t = ctc_loss(
# log_probs=log_probs_t,
# targets=transcript_t,
# input_length=encoded_len_t,
# target_length=transcript_len_t,
# )
#
# # Callbacks needed to print info to console and Tensorboard
# train_callback = nemo.core.SimpleLossLoggerCallback(
# tensors=[loss_t, predictions_t, transcript_t, transcript_len_t],
# print_func=partial(monitor_asr_train_progress, labels=vocab),
# get_tb_values=lambda x: [("loss", x[0])],
# tb_writer=neural_factory.tb_writer,
# )
#
# chpt_callback = nemo.core.CheckpointCallback(
# folder=neural_factory.checkpoint_dir,
# load_from_folder=args.load_dir,
# step_freq=args.checkpoint_save_freq,
# checkpoints_to_keep=30,
# )
#
# callbacks = [train_callback, chpt_callback]
callbacks = []
# assemble eval DAGs
for i, eval_dl in enumerate(data_layers_eval):
(audio_signal_e, a_sig_length_e, transcript_e, transcript_len_e) = eval_dl()
processed_signal_e, p_length_e = data_preprocessor(
input_signal=audio_signal_e, length=a_sig_length_e
)
encoded_e, encoded_len_e = jasper_encoder(
audio_signal=processed_signal_e, length=p_length_e
)
log_probs_e = jasper_decoder(encoder_output=encoded_e)
predictions_e = greedy_decoder(log_probs=log_probs_e)
loss_e = ctc_loss(
log_probs=log_probs_e,
targets=transcript_e,
input_length=encoded_len_e,
target_length=transcript_len_e,
)
# create corresponding eval callback
tagname = os.path.basename(args.eval_datasets[i]).split(".")[0]
eval_callback = nemo.core.EvaluatorCallback(
eval_tensors=[loss_e, predictions_e, transcript_e, transcript_len_e],
user_iter_callback=partial(process_evaluation_batch, labels=vocab),
user_epochs_done_callback=partial(process_evaluation_epoch, tag=tagname),
eval_step=args.eval_freq,
tb_writer=neural_factory.tb_writer,
)
callbacks.append(eval_callback)
return callbacks
def main():
args = parse_args()
# name = construct_name(
# args.exp_name,
# args.lr,
# args.batch_size,
# args.max_steps,
# args.num_epochs,
# args.weight_decay,
# args.optimizer,
# args.iter_per_step,
# )
# log_dir = name
# if args.work_dir:
# log_dir = os.path.join(args.work_dir, name)
# instantiate Neural Factory with supported backend
neural_factory = nemo.core.NeuralModuleFactory(
placement=nemo.core.DeviceType.GPU,
backend=nemo.core.Backend.PyTorch,
# local_rank=args.local_rank,
# optimization_level=args.amp_opt_level,
# log_dir=log_dir,
# checkpoint_dir=args.checkpoint_dir,
# create_tb_writer=args.create_tb_writer,
# files_to_copy=[args.model_config, __file__],
# cudnn_benchmark=args.cudnn_benchmark,
# tensorboard_dir=args.tensorboard_dir,
)
args.num_gpus = neural_factory.world_size
# checkpoint_dir = neural_factory.checkpoint_dir
if args.local_rank is not None:
logging.info("Doing ALL GPU")
# build dags
callbacks = create_all_dags(args, neural_factory)
# evaluate model
neural_factory.eval(callbacks=callbacks)
if __name__ == "__main__":
main()
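The `construct_name` helper defined in this new file encodes run hyperparameters into an experiment name (it is currently unused in the evaluation path, like much of the commented-out training DAG). A standalone copy shows the two naming schemes it switches between:

```python
def construct_name(name, lr, batch_size, max_steps, num_epochs, wd, optimizer, iter_per_step):
    # Step-based runs are tagged "-s_<steps>"; epoch-based runs "-e_<epochs>".
    if max_steps is not None:
        return "{0}-lr_{1}-bs_{2}-s_{3}-wd_{4}-opt_{5}-ips_{6}".format(
            name, lr, batch_size, max_steps, wd, optimizer, iter_per_step
        )
    return "{0}-lr_{1}-bs_{2}-e_{3}-wd_{4}-opt_{5}-ips_{6}".format(
        name, lr, batch_size, num_epochs, wd, optimizer, iter_per_step
    )


# With the parser defaults above (no max_steps, 300 epochs):
print(construct_name("jasper-speller", 0.002, 64, None, 300, 0.005, "novograd", 1))
# jasper-speller-lr_0.002-bs_64-e_300-wd_0.005-opt_novograd-ips_1
```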


@@ -41,7 +41,7 @@ def parse_args():
work_dir="./train/work",
num_epochs=300,
weight_decay=0.005,
checkpoint_save_freq=200,
checkpoint_save_freq=100,
eval_freq=100,
load_dir="./train/models/jasper/",
warmup_steps=3,
@@ -266,6 +266,7 @@ def create_all_dags(args, neural_factory):
folder=neural_factory.checkpoint_dir,
load_from_folder=args.load_dir,
step_freq=args.checkpoint_save_freq,
checkpoints_to_keep=30,
)
callbacks = [train_callback, chpt_callback]


@@ -2,7 +2,7 @@ from setuptools import setup, find_packages
requirements = [
"ruamel.yaml",
"torch==1.4.0",
"torch==2.8.0",
"torchvision==0.5.0",
"nemo_toolkit[asr] @ git+https://github.com/NVIDIA/NeMo.git@09e3ba4dfe333f86d6c5c1048e07210924294be9#egg=nemo_toolkit",
]
@@ -19,13 +19,15 @@ extra_requirements = {
"ruamel.yaml==0.16.10",
"pymongo==3.10.1",
"librosa==0.7.2",
"numba==0.48",
"matplotlib==3.2.1",
"pandas==1.0.3",
"tabulate==0.8.7",
"natural==0.2.0",
"num2words==0.5.10",
"typer[all]==0.1.1",
"typer[all]==0.3.1",
"python-slugify==4.0.0",
"rpyc~=4.1.4",
"lenses @ git+https://github.com/ingolemo/python-lenses.git@b2a2a9aa5b61540992d70b2cf36008d0121e8948#egg=lenses",
],
"validation": [
@@ -39,6 +41,7 @@ extra_requirements = {
"streamlit==0.58.0",
"natural==0.2.0",
"stringcase==1.2.0",
"google-cloud-speech~=1.3.1",
]
# "train": [
# "torchaudio==0.5.0",
@@ -63,14 +66,15 @@ setup(
"jasper_transcribe = jasper.transcribe:main",
"jasper_server = jasper.server:main",
"jasper_trainer = jasper.training.cli:main",
"jasper_evaluator = jasper.evaluate:main",
"jasper_data_tts_generate = jasper.data.tts_generator:main",
"jasper_data_conv_generate = jasper.data.conv_generator:main",
"jasper_data_call_recycle = jasper.data.call_recycler:main",
"jasper_data_asr_recycle = jasper.data.asr_recycler:main",
"jasper_data_rev_recycle = jasper.data.rev_recycler:main",
"jasper_data_nlu_generate = jasper.data.nlu_generator:main",
"jasper_data_rastrik_recycle = jasper.data.rastrik_recycler:main",
"jasper_data_server = jasper.data.server:main",
"jasper_data_validation = jasper.data.validation.process:main",
"jasper_data_preprocess = jasper.data.process:main",
"jasper_data_slu_evaluate = jasper.data.slu_evaluator:main",
]
},
zip_safe=False,
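The Dependabot bump above changes only the `torch` pin; the adjacent `torchvision==0.5.0` pin (released alongside torch 1.4.0) is left untouched and would likely conflict at install time with torch 2.8.0. A minimal sketch of splitting such exact `==` pins apart, using a hypothetical `parse_pin` helper:

```python
def parse_pin(requirement: str):
    """Split an exact '==' pin like 'torch==2.8.0' into (name, version)."""
    name, _, version = requirement.partition("==")
    return name, version


# The pair of pins from the requirements list above:
pins = dict(parse_pin(r) for r in ["torch==2.8.0", "torchvision==0.5.0"])
print(pins)  # {'torch': '2.8.0', 'torchvision': '0.5.0'}
```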