1
0
mirror of https://github.com/malarinv/jasper-asr.git synced 2026-03-08 10:32:35 +00:00

Compare commits

..

7 Commits

Author SHA1 Message Date
e30dd724f5 Merge pull request #3 from malarinv/dependabot/pip/torch-2.8.0
Bump torch from 1.4.0 to 2.8.0
2025-08-30 18:24:04 +05:30
dependabot[bot]
02df1b5282 Bump torch from 1.4.0 to 2.8.0
Bumps [torch](https://github.com/pytorch/pytorch) from 1.4.0 to 2.8.0.
- [Release notes](https://github.com/pytorch/pytorch/releases)
- [Changelog](https://github.com/pytorch/pytorch/blob/main/RELEASE.md)
- [Commits](https://github.com/pytorch/pytorch/compare/v1.4.0...v2.8.0)

---
updated-dependencies:
- dependency-name: torch
  dependency-version: 2.8.0
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-08-30 12:53:51 +00:00
e8f58a5043 1. refactored ui_dump
2. added flake8
2020-08-09 19:16:35 +05:30
42647196fe 1. fixed dependency issues
2. add task-id option to validation ui to respawn previous task
3. clean-up rastrik-recycler
2020-08-06 22:40:14 +05:30
e77943b2f2 Merge pull request #1 from wrat/master
adding support for asr data generator
2020-08-06 00:11:53 +05:30
wabi_sabi004
14d31a51c3 adding support for asr data generator 2020-08-06 00:08:46 +05:30
e24a8cf9d0 1. integrated data generator using google tts
2. added training script

fix module packaging issue

implement call audio data recycler for asr

1. added streamlit based validation ui with mongodb datastore integration
2. fix asr wrong sample rate inference
3. update requirements

1. refactored streamlit code
2. fixed issues in data manifest handling

refresh to next entry on submit and comment out mongo clearing code for safety :P

add validation ui and post processing to correct using validation data

1. added a tool to extract asr data from gcp transcripts logs
2. implement a funciton to export all call logs in a mongodb to a caller-id based yaml file
3. clean-up leaderboard duration logic
4. added a wip dataloader service
5. made the asr_data_writer util more generic with verbose flags and unique filename
6. added extendedpath util class with json support and mongo_conn function to connect to a mongo node
7. refactored the validation post processing to dump a ui config for validation
8. included utility functions to correct, fill update and clear annotations from mongodb data
9. refactored the ui logic to be more generic for any asr data
10. updated setup.py dependencies to support the above features

unlink temporary files after transcribing

1. clean-up unused data process code
2. fix invalid sample no from mongo
3. data loader service return remote netref

1. added training utils with custom data loaders with remote rpyc dataservice support
2. fix validation correction dump path
3. cache dataset for precaching before training to memory
4. update dependencies

1. implement dataset augmentation and validation in process
2. added option to skip 'incorrect' annotations in validation data
3. added confirmation on clearing mongo collection
4. added an option to navigate to a given text in the validation ui
5. added a dataset and remote option to trainer to load dataset from directory and remote rpyc service

1. added utility command to export call logs
2. mongo conn accepts port

refactored module structure

1. enabled silece stripping in chunks when recycling audio from asr logs
2. limit asr recycling to 1 min of start audio to get reliable alignments and ignoring agent channel
3. added rev recycler for generating asr dataset from rev transcripts and audio
4. update pydub dependency for silence stripping fn and removing threadpool hardcoded worker count

1. added support for mono/dual channel rev transcripts
2. handle errors when extracting datapoints from rev meta data
3. added suport for annotation only task when dumping ui data

cleanup rev recycle

added option to disable plots during validation

fix skipping null audio and add more verbose logs

respect verbose flag

don't load audio for annotation only ui and keep spoken as text for normal asr validation

1. refactored wav chunk processing method
2. renamed streamlit to validation_ui

show duration on validation of dataset

parallelize data loading from remote

skipping invalid data points

1. removed the transcriber_pretrained/speller from utils
2. introduced get_mongo_coll to get the collection object directly from mongo uri
3. removed processing of correction entries to remove space/upper casing

refactor validation process arguments and logging

1. added a data extraction type argument
2. cleanup/refactor

1. using dataname args for update/fill annotations
2. rename to dump_ui

added support for name/dates/cities call data extraction and more logs

handling non-pnr cases without parens in text data

1. added conv data generator
2. more utils

1. added start delay arg in call recycler
2. implement ui_dump/manifest  writer in call_recycler itself
3. refactored call data point plotter
4. added sample-ui task-ui  on the validation process
5. implemented call-quality stats using corrections from mongo
6. support deleting cursors on mongo
7. implement multiple task support on validation ui based on task_id mongo field

fix 11st to 11th in ordinal

stripping silence on call chunk

1. added option to strip silent chunks
2. computing caller quality based on task-id of corrections

1. fix update-correction to use ui_dump instead of manifest
2. update training params no of checkpoints on chpk frequency

1. split extract all data types in one shot with --extraction-type all flag
2. add notes about diffing split extracted and original data
3. add a nlu conv generator to generate conv data based on nlu utterances and entities
4. add task uid support for dumping corrections
5. abstracted generate date fn

1. added a test generator and slu evaluator
2. ui dump now include gcp results
3. showing default option for more args validation process commands

added evaluation command

clean-up
2020-07-14 12:09:46 +05:30
21 changed files with 265 additions and 2130 deletions

4
.flake8 Normal file
View File

@@ -0,0 +1,4 @@
[flake8]
exclude = docs
ignore = E203, W503
max-line-length = 119

View File

@@ -7,10 +7,16 @@
# Table of Contents # Table of Contents
* [Prerequisites](#prerequisites)
* [Features](#features) * [Features](#features)
* [Installation](#installation) * [Installation](#installation)
* [Usage](#usage) * [Usage](#usage)
# Prerequisites
```bash
# apt install libsndfile-dev ffmpeg
```
# Features # Features
* ASR using Jasper (from [NemoToolkit](https://github.com/NVIDIA/NeMo) ) * ASR using Jasper (from [NemoToolkit](https://github.com/NVIDIA/NeMo) )

View File

@@ -2,10 +2,6 @@ import os
import logging import logging
import rpyc import rpyc
from functools import lru_cache from functools import lru_cache
import typer
from pathlib import Path
app = typer.Typer()
logging.basicConfig( logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
@@ -23,28 +19,3 @@ def transcribe_gen(asr_host=ASR_HOST, asr_port=ASR_PORT):
asr = rpyc.connect(asr_host, asr_port).root asr = rpyc.connect(asr_host, asr_port).root
logger.info(f"connected to asr server successfully") logger.info(f"connected to asr server successfully")
return asr.transcribe return asr.transcribe
@app.command()
def transcribe_file(audio_file: Path):
from pydub import AudioSegment
transcriber = transcribe_gen()
aud_seg = (
AudioSegment.from_file_using_temporary_files(audio_file)
.set_channels(1)
.set_sample_width(2)
.set_frame_rate(24000)
)
tscript_file_path = audio_file.with_suffix(".txt")
transcription = transcriber(aud_seg.raw_data)
with open(tscript_file_path, "w") as tf:
tf.write(transcription)
def main():
app()
if __name__ == "__main__":
main()

View File

@@ -1,104 +0,0 @@
import typer
from itertools import chain
from io import BytesIO
from pathlib import Path
app = typer.Typer()
@app.command()
def extract_data(
call_audio_dir: Path = Path("/dataset/png_prod/call_audio"),
call_meta_dir: Path = Path("/dataset/png_prod/call_metadata"),
output_dir: Path = Path("./data"),
dataset_name: str = "png_gcp_2jan",
verbose: bool = False,
):
from pydub import AudioSegment
from .utils import ExtendedPath, asr_data_writer, strip_silence
from lenses import lens
call_asr_data: Path = output_dir / Path("asr_data")
call_asr_data.mkdir(exist_ok=True, parents=True)
def wav_event_generator(call_audio_dir):
for wav_path in call_audio_dir.glob("**/*.wav"):
if verbose:
typer.echo(f"loading events for file {wav_path}")
call_wav = AudioSegment.from_file_using_temporary_files(wav_path)
rel_meta_path = wav_path.with_suffix(".json").relative_to(call_audio_dir)
meta_path = call_meta_dir / rel_meta_path
events = ExtendedPath(meta_path).read_json()
yield call_wav, wav_path, events
def contains_asr(x):
return "AsrResult" in x
def channel(n):
def filter_func(ev):
return (
ev["AsrResult"]["Channel"] == n
if "Channel" in ev["AsrResult"]
else n == 0
)
return filter_func
def compute_endtime(call_wav, state):
for (i, st) in enumerate(state):
start_time = st["AsrResult"]["Alternatives"][0].get("StartTime", 0)
transcript = st["AsrResult"]["Alternatives"][0]["Transcript"]
if i + 1 < len(state):
end_time = state[i + 1]["AsrResult"]["Alternatives"][0]["StartTime"]
else:
end_time = call_wav.duration_seconds
full_code_seg = call_wav[start_time * 1000 : end_time * 1000]
code_seg = strip_silence(full_code_seg)
code_fb = BytesIO()
code_seg.export(code_fb, format="wav")
code_wav = code_fb.getvalue()
# only starting 1 min audio has reliable alignment ignore rest
if start_time > 60:
if verbose:
print(f'start time over 60 seconds of audio skipping.')
break
# only if some reasonable audio data is present yield it
if code_seg.duration_seconds < 0.5:
if verbose:
print(f'transcript chunk "{transcript}" contains no audio skipping.')
continue
yield transcript, code_seg.duration_seconds, code_wav
def asr_data_generator(call_wav, call_wav_fname, events):
call_wav_0, call_wav_1 = call_wav.split_to_mono()
asr_events = lens["Events"].Each()["Event"].Filter(contains_asr)
call_evs_0 = asr_events.Filter(channel(0)).collect()(events)
# Ignoring agent channel events
# call_evs_1 = asr_events.Filter(channel(1)).collect()(events)
if verbose:
typer.echo(f"processing data points on {call_wav_fname}")
call_data_0 = compute_endtime(call_wav_0, call_evs_0)
# Ignoring agent channel
# call_data_1 = compute_endtime(call_wav_1, call_evs_1)
return call_data_0 # chain(call_data_0, call_data_1)
def generate_call_asr_data():
full_asr_data = []
total_duration = 0
for wav, wav_path, ev in wav_event_generator(call_audio_dir):
asr_data = asr_data_generator(wav, wav_path, ev)
total_duration += wav.duration_seconds
full_asr_data.append(asr_data)
typer.echo(f"loaded {len(full_asr_data)} calls of duration {total_duration}s")
n_dps = asr_data_writer(call_asr_data, dataset_name, chain(*full_asr_data))
typer.echo(f"written {n_dps} data points")
generate_call_asr_data()
def main():
app()
if __name__ == "__main__":
main()

View File

@@ -1,509 +0,0 @@
import typer
from pathlib import Path
from enum import Enum
app = typer.Typer()
@app.command()
def export_all_logs(
call_logs_file: Path = typer.Option(Path("./call_logs.yaml"), show_default=True),
domain: str = typer.Option("sia-data.agaralabs.com", show_default=True),
):
from .utils import get_mongo_conn
from collections import defaultdict
from ruamel.yaml import YAML
yaml = YAML()
mongo_coll = get_mongo_conn()
caller_calls = defaultdict(lambda: [])
for call in mongo_coll.find():
sysid = call["SystemID"]
call_uri = f"http://{domain}/calls/{sysid}"
caller = call["Caller"]
caller_calls[caller].append(call_uri)
caller_list = []
for caller in caller_calls:
caller_list.append({"name": caller, "calls": caller_calls[caller]})
output_yaml = {"users": caller_list}
typer.echo(f"exporting call logs to yaml file at {call_logs_file}")
with call_logs_file.open("w") as yf:
yaml.dump(output_yaml, yf)
@app.command()
def export_calls_between(
start_cid: str,
end_cid: str,
call_logs_file: Path = typer.Option(Path("./call_logs.yaml"), show_default=True),
domain: str = typer.Option("sia-data.agaralabs.com", show_default=True),
mongo_port: int = 27017,
):
from collections import defaultdict
from ruamel.yaml import YAML
from .utils import get_mongo_conn
yaml = YAML()
mongo_coll = get_mongo_conn(port=mongo_port)
start_meta = mongo_coll.find_one({"SystemID": start_cid})
end_meta = mongo_coll.find_one({"SystemID": end_cid})
caller_calls = defaultdict(lambda: [])
call_query = mongo_coll.find(
{
"StartTS": {"$gte": start_meta["StartTS"]},
"EndTS": {"$lte": end_meta["EndTS"]},
}
)
for call in call_query:
sysid = call["SystemID"]
call_uri = f"http://{domain}/calls/{sysid}"
caller = call["Caller"]
caller_calls[caller].append(call_uri)
caller_list = []
for caller in caller_calls:
caller_list.append({"name": caller, "calls": caller_calls[caller]})
output_yaml = {"users": caller_list}
typer.echo(f"exporting call logs to yaml file at {call_logs_file}")
with call_logs_file.open("w") as yf:
yaml.dump(output_yaml, yf)
@app.command()
def copy_metas(
call_logs_file: Path = typer.Option(Path("./call_logs.yaml"), show_default=True),
output_dir: Path = Path("./data"),
meta_dir: Path = Path("/tmp/call_metas"),
):
from lenses import lens
from ruamel.yaml import YAML
from urllib.parse import urlsplit
from shutil import copy2
yaml = YAML()
call_logs = yaml.load(call_logs_file.read_text())
call_meta_dir: Path = output_dir / Path("call_metas")
call_meta_dir.mkdir(exist_ok=True, parents=True)
meta_dir.mkdir(exist_ok=True, parents=True)
def get_cid(uri):
return Path(urlsplit(uri).path).stem
def copy_meta(uri):
cid = get_cid(uri)
saved_meta_path = call_meta_dir / Path(f"{cid}.json")
dest_meta_path = meta_dir / Path(f"{cid}.json")
if not saved_meta_path.exists():
print(f"{saved_meta_path} not found")
copy2(saved_meta_path, dest_meta_path)
def download_meta_audio():
call_lens = lens["users"].Each()["calls"].Each()
call_lens.modify(copy_meta)(call_logs)
download_meta_audio()
class ExtractionType(str, Enum):
flow = "flow"
data = "data"
@app.command()
def analyze(
leaderboard: bool = False,
plot_calls: bool = False,
extract_data: bool = False,
extraction_type: ExtractionType = typer.Option(
ExtractionType.data, show_default=True
),
start_delay: float = 1.5,
download_only: bool = False,
strip_silent_chunks: bool = True,
call_logs_file: Path = typer.Option(Path("./call_logs.yaml"), show_default=True),
output_dir: Path = Path("./data"),
data_name: str = None,
mongo_uri: str = typer.Option(
"mongodb://localhost:27017/test.calls", show_default=True
),
):
from urllib.parse import urlsplit
from functools import reduce
import boto3
from io import BytesIO
import json
from ruamel.yaml import YAML
import re
from google.protobuf.timestamp_pb2 import Timestamp
from datetime import timedelta
import librosa
import librosa.display
from lenses import lens
from pprint import pprint
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib
from tqdm import tqdm
from .utils import ui_dump_manifest_writer, strip_silence, get_mongo_coll, get_call_logs
from pydub import AudioSegment
from natural.date import compress
matplotlib.rcParams["agg.path.chunksize"] = 10000
matplotlib.use("agg")
yaml = YAML()
s3 = boto3.client("s3")
mongo_collection = get_mongo_coll(mongo_uri)
call_media_dir: Path = output_dir / Path("call_wavs")
call_media_dir.mkdir(exist_ok=True, parents=True)
call_meta_dir: Path = output_dir / Path("call_metas")
call_meta_dir.mkdir(exist_ok=True, parents=True)
call_plot_dir: Path = output_dir / Path("plots")
call_plot_dir.mkdir(exist_ok=True, parents=True)
call_asr_data: Path = output_dir / Path("asr_data")
call_asr_data.mkdir(exist_ok=True, parents=True)
dataset_name = call_logs_file.stem if not data_name else data_name
call_logs = yaml.load(call_logs_file.read_text())
def gen_ev_fev_timedelta(fev):
fev_p = Timestamp()
fev_p.FromJsonString(fev["CreatedTS"])
fev_dt = fev_p.ToDatetime()
td_0 = timedelta()
def get_timedelta(ev):
ev_p = Timestamp()
ev_p.FromJsonString(value=ev["CreatedTS"])
ev_dt = ev_p.ToDatetime()
delta = ev_dt - fev_dt
return delta if delta > td_0 else td_0
return get_timedelta
def chunk_n(evs, n):
return [evs[i * n : (i + 1) * n] for i in range((len(evs) + n - 1) // n)]
if extraction_type == ExtractionType.data:
def is_utter_event(ev):
return (
(ev["Author"] == "CONV" or ev["Author"] == "ASR")
and (ev["Type"] != "DEBUG")
and ev["Type"] != "ASR_RESULT"
)
def get_data_points(utter_events, td_fn):
data_points = []
for evs in chunk_n(utter_events, 3):
try:
assert evs[0]["Type"] == "CONV_RESULT"
assert evs[1]["Type"] == "STARTED_SPEAKING"
assert evs[2]["Type"] == "STOPPED_SPEAKING"
start_time = td_fn(evs[1]).total_seconds() - start_delay
end_time = td_fn(evs[2]).total_seconds()
spoken = evs[0]["Msg"]
data_points.append(
{"start_time": start_time, "end_time": end_time, "code": spoken}
)
except AssertionError:
# skipping invalid data_points
pass
return data_points
def text_extractor(spoken):
return (
re.search(r"'(.*)'", spoken).groups(0)[0]
if len(spoken) > 6 and re.search(r"'(.*)'", spoken)
else spoken
)
elif extraction_type == ExtractionType.flow:
def is_final_asr_event_or_spoken(ev):
pld = json.loads(ev["Payload"])
return (
pld["AsrResult"]["Results"][0]["IsFinal"]
if ev["Type"] == "ASR_RESULT"
else True
)
def is_utter_event(ev):
return (
ev["Author"] == "CONV"
or (ev["Author"] == "ASR" and is_final_asr_event_or_spoken(ev))
) and (ev["Type"] != "DEBUG")
def get_data_points(utter_events, td_fn):
data_points = []
for evs in chunk_n(utter_events, 4):
try:
assert len(evs) == 4
assert evs[0]["Type"] == "CONV_RESULT"
assert evs[1]["Type"] == "STARTED_SPEAKING"
assert evs[2]["Type"] == "ASR_RESULT"
assert evs[3]["Type"] == "STOPPED_SPEAKING"
start_time = td_fn(evs[1]).total_seconds() - start_delay
end_time = td_fn(evs[2]).total_seconds()
conv_msg = evs[0]["Msg"]
if "full name" in conv_msg.lower():
pld = json.loads(evs[2]["Payload"])
spoken = pld["AsrResult"]["Results"][0]["Alternatives"][0][
"Transcript"
]
data_points.append(
{
"start_time": start_time,
"end_time": end_time,
"code": spoken,
}
)
except AssertionError:
# skipping invalid data_points
pass
return data_points
def text_extractor(spoken):
return spoken
def process_call(call_obj):
call_meta = get_call_logs(call_obj, s3, call_meta_dir)
call_events = call_meta["Events"]
def is_writer_uri_event(ev):
return ev["Author"] == "AUDIO_WRITER" and "s3://" in ev["Msg"]
writer_events = list(filter(is_writer_uri_event, call_events))
s3_wav_url = re.search(r"(s3://.*)", writer_events[0]["Msg"]).groups(0)[0]
s3_wav_url_p = urlsplit(s3_wav_url)
def is_first_audio_ev(state, ev):
if state[0]:
return state
else:
return (ev["Author"] == "GATEWAY" and ev["Type"] == "AUDIO", ev)
(_, first_audio_ev) = reduce(is_first_audio_ev, call_events, (False, {}))
get_ev_fev_timedelta = gen_ev_fev_timedelta(first_audio_ev)
uevs = list(filter(is_utter_event, call_events))
ev_count = len(uevs)
utter_events = uevs[: ev_count - ev_count % 3]
saved_wav_path = call_media_dir / Path(Path(s3_wav_url_p.path).name)
if not saved_wav_path.exists():
print(f"downloading : {saved_wav_path} from {s3_wav_url}")
s3.download_file(
s3_wav_url_p.netloc, s3_wav_url_p.path[1:], str(saved_wav_path)
)
return {
"wav_path": saved_wav_path,
"num_samples": len(utter_events) // 3,
"meta": call_obj,
"first_event_fn": get_ev_fev_timedelta,
"utter_events": utter_events,
}
def get_cid(uri):
return Path(urlsplit(uri).path).stem
def ensure_call(uri):
cid = get_cid(uri)
meta = mongo_collection.find_one({"SystemID": cid})
process_meta = process_call(meta)
return process_meta
def retrieve_processed_callmeta(uri):
cid = get_cid(uri)
meta = mongo_collection.find_one({"SystemID": cid})
duration = meta["EndTS"] - meta["StartTS"]
process_meta = process_call(meta)
data_points = get_data_points(
process_meta["utter_events"], process_meta["first_event_fn"]
)
process_meta["data_points"] = data_points
return {"url": uri, "meta": meta, "duration": duration, "process": process_meta}
def retrieve_callmeta(call_uri):
uri = call_uri["call_uri"]
name = call_uri["name"]
cid = get_cid(uri)
meta = mongo_collection.find_one({"SystemID": cid})
duration = meta["EndTS"] - meta["StartTS"]
process_meta = process_call(meta)
data_points = get_data_points(
process_meta["utter_events"], process_meta["first_event_fn"]
)
process_meta["data_points"] = data_points
return {
"url": uri,
"name": name,
"meta": meta,
"duration": duration,
"process": process_meta,
}
def download_meta_audio():
call_lens = lens["users"].Each()["calls"].Each()
call_lens.modify(ensure_call)(call_logs)
def plot_calls_data():
def plot_data_points(y, sr, data_points, file_path):
plt.figure(figsize=(16, 12))
librosa.display.waveplot(y=y, sr=sr)
for dp in data_points:
start, end, code = dp["start_time"], dp["end_time"], dp["code"]
plt.axvspan(start, end, color="green", alpha=0.2)
text_pos = (start + end) / 2
plt.text(
text_pos,
0.25,
f"{code}",
rotation=90,
horizontalalignment="center",
verticalalignment="center",
)
plt.title("Datapoints")
plt.savefig(file_path, format="png")
return file_path
def plot_call(call_obj):
saved_wav_path, data_points, sys_id = (
call_obj["process"]["wav_path"],
call_obj["process"]["data_points"],
call_obj["meta"]["SystemID"],
)
file_path = call_plot_dir / Path(sys_id).with_suffix(".png")
if not file_path.exists():
print(f"plotting: {file_path}")
(y, sr) = librosa.load(saved_wav_path)
plot_data_points(y, sr, data_points, str(file_path))
return file_path
call_lens = lens["users"].Each()["calls"].Each()
call_stats = call_lens.modify(retrieve_processed_callmeta)(call_logs)
# call_plot_data = call_lens.collect()(call_stats)
call_plots = call_lens.modify(plot_call)(call_stats)
# with ThreadPoolExecutor(max_workers=20) as exe:
# print('starting all plot tasks')
# responses = [exe.submit(plot_call, w) for w in call_plot_data]
# print('submitted all plot tasks')
# call_plots = [r.result() for r in responses]
pprint(call_plots)
def extract_data_points():
if strip_silent_chunks:
def audio_process(seg):
return strip_silence(seg)
else:
def audio_process(seg):
return seg
def gen_data_values(saved_wav_path, data_points, caller_name):
call_seg = (
AudioSegment.from_wav(saved_wav_path)
.set_channels(1)
.set_sample_width(2)
.set_frame_rate(24000)
)
for dp_id, dp in enumerate(data_points):
start, end, spoken = dp["start_time"], dp["end_time"], dp["code"]
spoken_seg = audio_process(call_seg[start * 1000 : end * 1000])
spoken_fb = BytesIO()
spoken_seg.export(spoken_fb, format="wav")
spoken_wav = spoken_fb.getvalue()
# search for actual pnr code and handle plain codes as well
extracted_code = text_extractor(spoken)
if strip_silent_chunks and spoken_seg.duration_seconds < 0.5:
print(f'transcript chunk "{spoken}" contains no audio skipping.')
continue
yield extracted_code, spoken_seg.duration_seconds, spoken_wav, caller_name, spoken_seg
call_lens = lens["users"].Each()["calls"].Each()
def assign_user_call(uc):
return (
lens["calls"]
.Each()
.modify(lambda c: {"call_uri": c, "name": uc["name"]})(uc)
)
user_call_logs = lens["users"].Each().modify(assign_user_call)(call_logs)
call_stats = call_lens.modify(retrieve_callmeta)(user_call_logs)
call_objs = call_lens.collect()(call_stats)
def data_source():
for call_obj in tqdm(call_objs):
saved_wav_path, data_points, name = (
call_obj["process"]["wav_path"],
call_obj["process"]["data_points"],
call_obj["name"],
)
for dp in gen_data_values(saved_wav_path, data_points, name):
yield dp
ui_dump_manifest_writer(call_asr_data, dataset_name, data_source())
def show_leaderboard():
def compute_user_stats(call_stat):
n_samples = (
lens["calls"].Each()["process"]["num_samples"].get_monoid()(call_stat)
)
n_duration = lens["calls"].Each()["duration"].get_monoid()(call_stat)
return {
"num_samples": n_samples,
"duration": n_duration.total_seconds(),
"samples_rate": n_samples / n_duration.total_seconds(),
"duration_str": compress(n_duration, pad=" "),
"name": call_stat["name"],
}
call_lens = lens["users"].Each()["calls"].Each()
call_stats = call_lens.modify(retrieve_processed_callmeta)(call_logs)
user_stats = lens["users"].Each().modify(compute_user_stats)(call_stats)
leader_df = (
pd.DataFrame(user_stats["users"])
.sort_values(by=["duration"], ascending=False)
.reset_index(drop=True)
)
leader_df["rank"] = leader_df.index + 1
leader_board = leader_df.rename(
columns={
"rank": "Rank",
"num_samples": "Count",
"name": "Name",
"samples_rate": "SpeechRate",
"duration_str": "Duration",
}
)[["Rank", "Name", "Count", "Duration"]]
print(
"""ASR Dataset Leaderboard :
---------------------------------"""
)
print(leader_board.to_string(index=False))
if download_only:
download_meta_audio()
return
if leaderboard:
show_leaderboard()
if plot_calls:
plot_calls_data()
if extract_data:
extract_data_points()
def main():
app()
if __name__ == "__main__":
main()

View File

@@ -1,27 +0,0 @@
import typer
from pathlib import Path
from .utils import generate_dates
app = typer.Typer()
@app.command()
def export_conv_json(
conv_src: Path = typer.Option(Path("./conv_data.json"), show_default=True),
conv_dest: Path = typer.Option(Path("./data/conv_data.json"), show_default=True),
):
from .utils import ExtendedPath
conv_data = ExtendedPath(conv_src).read_json()
conv_data["dates"] = generate_dates()
ExtendedPath(conv_dest).write_json(conv_data)
def main():
app()
if __name__ == "__main__":
main()

View File

@@ -1,98 +0,0 @@
from pathlib import Path
import typer
import pandas as pd
from ruamel.yaml import YAML
from itertools import product
from .utils import generate_dates
app = typer.Typer()
def unique_entity_list(entity_template_tags, entity_data):
unique_entity_set = {
t
for n in range(1, 5)
for t in entity_data[f"Answer.utterance-{n}"].tolist()
if any(et in t for et in entity_template_tags)
}
return list(unique_entity_set)
def nlu_entity_reader(nlu_data_file: Path = Path("./nlu_data.yaml")):
yaml = YAML()
nlu_data = yaml.load(nlu_data_file.read_text())
for cf in nlu_data["csv_files"]:
data = pd.read_csv(cf["fname"])
for et in cf["entities"]:
entity_name = et["name"]
entity_template_tags = et["tags"]
if "filter" in et:
entity_data = data[data[cf["filter_key"]] == et["filter"]]
else:
entity_data = data
yield entity_name, entity_template_tags, entity_data
def nlu_samples_reader(nlu_data_file: Path = Path("./nlu_data.yaml")):
yaml = YAML()
nlu_data = yaml.load(nlu_data_file.read_text())
sm = {s["name"]: s for s in nlu_data["samples_per_entity"]}
return sm
@app.command()
def compute_unique_nlu_stats(
nlu_data_file: Path = typer.Option(Path("./nlu_data.yaml"), show_default=True),
):
for entity_name, entity_template_tags, entity_data in nlu_entity_reader(
nlu_data_file
):
entity_count = len(unique_entity_list(entity_template_tags, entity_data))
print(f"{entity_name}\t{entity_count}")
def replace_entity(tmpl, value, tags):
result = tmpl
for t in tags:
result = result.replace(t, value)
return result
@app.command()
def export_nlu_conv_json(
conv_src: Path = typer.Option(Path("./conv_data.json"), show_default=True),
conv_dest: Path = typer.Option(Path("./data/conv_data.json"), show_default=True),
nlu_data_file: Path = typer.Option(Path("./nlu_data.yaml"), show_default=True),
):
from .utils import ExtendedPath
from random import sample
entity_samples = nlu_samples_reader(nlu_data_file)
conv_data = ExtendedPath(conv_src).read_json()
conv_data["Dates"] = generate_dates()
result_dict = {}
data_count = 0
for entity_name, entity_template_tags, entity_data in nlu_entity_reader(
nlu_data_file
):
entity_variants = sample(conv_data[entity_name], entity_samples[entity_name]["test_size"])
unique_entites = unique_entity_list(entity_template_tags, entity_data)
# sample_entites = sample(unique_entites, entity_samples[entity_name]["samples"])
result_dict[entity_name] = []
for val in entity_variants:
sample_entites = sample(unique_entites, entity_samples[entity_name]["samples"])
for tmpl in sample_entites:
result = replace_entity(tmpl, val, entity_template_tags)
result_dict[entity_name].append(result)
data_count += 1
print(f"Total of {data_count} variants generated")
ExtendedPath(conv_dest).write_json(result_dict)
def main():
app()
if __name__ == "__main__":
main()

View File

@@ -38,9 +38,9 @@ def augment_data(src_dataset_paths: List[Path], dest_dataset_path: Path):
def split_data(dataset_path: Path, test_size: float = 0.1): def split_data(dataset_path: Path, test_size: float = 0.1):
manifest_path = dataset_path / Path("abs_manifest.json") manifest_path = dataset_path / Path("abs_manifest.json")
asr_data = list(asr_manifest_reader(manifest_path)) asr_data = list(asr_manifest_reader(manifest_path))
train_pnr, test_pnr = train_test_split(asr_data, test_size=test_size) train_data, test_data = train_test_split(asr_data, test_size=test_size)
asr_manifest_writer(manifest_path.with_name("train_manifest.json"), train_pnr) asr_manifest_writer(manifest_path.with_name("train_manifest.json"), train_data)
asr_manifest_writer(manifest_path.with_name("test_manifest.json"), test_pnr) asr_manifest_writer(manifest_path.with_name("test_manifest.json"), test_data)
@app.command() @app.command()
@@ -52,9 +52,9 @@ def validate_data(dataset_path: Path):
data_file = dataset_path / Path(mf_type) data_file = dataset_path / Path(mf_type)
print(f"validating {data_file}.") print(f"validating {data_file}.")
with Path(data_file).open("r") as pf: with Path(data_file).open("r") as pf:
pnr_jsonl = pf.readlines() data_jsonl = pf.readlines()
duration = 0 duration = 0
for (i, s) in enumerate(pnr_jsonl): for (i, s) in enumerate(data_jsonl):
try: try:
d = json.loads(s) d = json.loads(s)
duration += d["duration"] duration += d["duration"]

View File

@@ -0,0 +1,93 @@
from rastrik.proto.callrecord_pb2 import CallRecord
import gzip
from pydub import AudioSegment
from .utils import ui_dump_manifest_writer, strip_silence
import typer
from itertools import chain
from io import BytesIO
from pathlib import Path
app = typer.Typer()
@app.command()
def extract_manifest(
call_log_dir: Path = Path("./data/call_audio"),
output_dir: Path = Path("./data"),
dataset_name: str = "grassroot_pizzahut_v1",
caller_name: str = "grassroot",
verbose: bool = False,
):
call_asr_data: Path = output_dir / Path("asr_data")
call_asr_data.mkdir(exist_ok=True, parents=True)
def wav_pb2_generator(log_dir):
for wav_path in log_dir.glob("**/*.wav"):
if verbose:
typer.echo(f"loading events for file {wav_path}")
call_wav = AudioSegment.from_file_using_temporary_files(wav_path)
meta_path = wav_path.with_suffix(".pb2.gz")
yield call_wav, wav_path, meta_path
def read_event(call_wav, log_file):
call_wav_0, call_wav_1 = call_wav.split_to_mono()
with gzip.open(log_file, "rb") as log_h:
record_data = log_h.read()
cr = CallRecord()
cr.ParseFromString(record_data)
first_audio_event_timestamp = next(
(
i
for i in cr.events
if i.WhichOneof("event_type") == "call_event"
and i.call_event.WhichOneof("event_type") == "call_audio"
)
).timestamp.ToDatetime()
speech_events = [
i
for i in cr.events
if i.WhichOneof("event_type") == "speech_event"
and i.speech_event.WhichOneof("event_type") == "asr_final"
]
previous_event_timestamp = (
first_audio_event_timestamp - first_audio_event_timestamp
)
for index, each_speech_events in enumerate(speech_events):
asr_final = each_speech_events.speech_event.asr_final
speech_timestamp = each_speech_events.timestamp.ToDatetime()
actual_timestamp = speech_timestamp - first_audio_event_timestamp
start_time = previous_event_timestamp.total_seconds() * 1000
end_time = actual_timestamp.total_seconds() * 1000
audio_segment = strip_silence(call_wav_1[start_time:end_time])
code_fb = BytesIO()
audio_segment.export(code_fb, format="wav")
wav_data = code_fb.getvalue()
previous_event_timestamp = actual_timestamp
duration = (end_time - start_time) / 1000
yield asr_final, duration, wav_data, "grassroot", audio_segment
def generate_call_asr_data():
full_data = []
total_duration = 0
for wav, wav_path, pb2_path in wav_pb2_generator(call_log_dir):
asr_data = read_event(wav, pb2_path)
total_duration += wav.duration_seconds
full_data.append(asr_data)
n_calls = len(full_data)
typer.echo(f"loaded {n_calls} calls of duration {total_duration}s")
n_dps = ui_dump_manifest_writer(call_asr_data, dataset_name, chain(*full_data))
typer.echo(f"written {n_dps} data points")
generate_call_asr_data()
def main():
app()
if __name__ == "__main__":
main()

View File

@@ -1,175 +0,0 @@
import typer
from itertools import chain
from io import BytesIO
from pathlib import Path
import re
app = typer.Typer()
@app.command()
def extract_data(
call_audio_dir: Path = typer.Option(Path("/dataset/rev/wavs"), show_default=True),
call_meta_dir: Path = typer.Option(Path("/dataset/rev/jsons"), show_default=True),
output_dir: Path = typer.Option(Path("./data"), show_default=True),
dataset_name: str = typer.Option("rev_transribed", show_default=True),
verbose: bool = False,
):
from pydub import AudioSegment
from .utils import ExtendedPath, asr_data_writer, strip_silence
from lenses import lens
import datetime
call_asr_data: Path = output_dir / Path("asr_data")
call_asr_data.mkdir(exist_ok=True, parents=True)
def wav_event_generator(call_audio_dir):
for wav_path in call_audio_dir.glob("**/*.wav"):
if verbose:
typer.echo(f"loading events for file {wav_path}")
call_wav = AudioSegment.from_file_using_temporary_files(wav_path)
rel_meta_path = wav_path.with_suffix(".json").relative_to(call_audio_dir)
meta_path = call_meta_dir / rel_meta_path
if meta_path.exists():
events = ExtendedPath(meta_path).read_json()
yield call_wav, wav_path, events
else:
if verbose:
typer.echo(f"missing json corresponding to {wav_path}")
def contains_asr(x):
return "AsrResult" in x
def channel(n):
def filter_func(ev):
return (
ev["AsrResult"]["Channel"] == n
if "Channel" in ev["AsrResult"]
else n == 0
)
return filter_func
def time_to_msecs(time_str):
return (
datetime.datetime.strptime(time_str, "%H:%M:%S,%f")
- datetime.datetime(1900, 1, 1)
).total_seconds() * 1000
def process_utterance_chunk(wav_seg, start_time, end_time, monologue):
# offset by 1sec left side to include vad? discarded audio
full_tscript_wav_seg = wav_seg[
time_to_msecs(start_time) - 1000 : time_to_msecs(end_time) # + 1000
]
tscript_wav_seg = strip_silence(full_tscript_wav_seg)
tscript_wav_fb = BytesIO()
tscript_wav_seg.export(tscript_wav_fb, format="wav")
tscript_wav = tscript_wav_fb.getvalue()
text = "".join(lens["elements"].Each()["value"].collect()(monologue))
text_clean = re.sub(r"\[.*\]", "", text)
return tscript_wav, tscript_wav_seg.duration_seconds, text_clean
def dual_asr_data_generator(wav_seg, wav_path, meta):
left_audio, right_audio = wav_seg.split_to_mono()
channel_map = {"Agent": right_audio, "Client": left_audio}
monologues = lens["monologues"].Each().collect()(meta)
for monologue in monologues:
# print(monologue["speaker_name"])
speaker_channel = channel_map.get(monologue["speaker_name"])
if not speaker_channel:
if verbose:
print(
f'unknown speaker tag {monologue["speaker_name"]} in wav:{wav_path} skipping.'
)
continue
try:
start_time = (
lens["elements"]
.Each()
.Filter(lambda x: "timestamp" in x)["timestamp"]
.collect()(monologue)[0]
)
end_time = (
lens["elements"]
.Each()
.Filter(lambda x: "end_timestamp" in x)["end_timestamp"]
.collect()(monologue)[-1]
)
except IndexError:
if verbose:
print(
f"error when loading timestamp events in wav:{wav_path} skipping."
)
continue
tscript_wav, seg_dur, text_clean = process_utterance_chunk(
speaker_channel, start_time, end_time, monologue
)
if seg_dur < 0.5:
if verbose:
print(
f'transcript chunk "{text_clean}" contains no audio in {wav_path} skipping.'
)
continue
yield text_clean, seg_dur, tscript_wav
def mono_asr_data_generator(wav_seg, wav_path, meta):
monologues = lens["monologues"].Each().collect()(meta)
for monologue in monologues:
try:
start_time = (
lens["elements"]
.Each()
.Filter(lambda x: "timestamp" in x)["timestamp"]
.collect()(monologue)[0]
)
end_time = (
lens["elements"]
.Each()
.Filter(lambda x: "end_timestamp" in x)["end_timestamp"]
.collect()(monologue)[-1]
)
except IndexError:
if verbose:
print(
f"error when loading timestamp events in wav:{wav_path} skipping."
)
continue
tscript_wav, seg_dur, text_clean = process_utterance_chunk(
wav_seg, start_time, end_time, monologue
)
if seg_dur < 0.5:
if verbose:
print(
f'transcript chunk "{text_clean}" contains no audio in {wav_path} skipping.'
)
continue
yield text_clean, seg_dur, tscript_wav
def generate_rev_asr_data():
full_asr_data = []
total_duration = 0
for wav, wav_path, ev in wav_event_generator(call_audio_dir):
if wav.channels > 2:
print(f"skipping many channel audio {wav_path}")
asr_data_generator = (
mono_asr_data_generator
if wav.channels == 1
else dual_asr_data_generator
)
asr_data = asr_data_generator(wav, wav_path, ev)
total_duration += wav.duration_seconds
full_asr_data.append(asr_data)
typer.echo(f"loaded {len(full_asr_data)} calls of duration {total_duration}s")
n_dps = asr_data_writer(call_asr_data, dataset_name, chain(*full_asr_data))
typer.echo(f"written {n_dps} data points")
generate_rev_asr_data()
def main():
app()
if __name__ == "__main__":
main()

View File

@@ -1,180 +0,0 @@
import typer
from pathlib import Path
import json
# from .utils import generate_dates, asr_test_writer
app = typer.Typer()
def run_test(reg_path, coll, s3, call_meta_dir, city_code, test_path):
from time import sleep
import subprocess
from .utils import ExtendedPath, get_call_logs
coll.delete_many({"CallID": test_path.name})
# test_path = dump_dir / data_name / test_file
# "../saas_reg/regression/run.sh -f data/asr_data/call_upwork_test_cnd_cities/asr_test.reg"
test_output = subprocess.run(
["/bin/bash", "-c", f"{str(reg_path)} --addr [::]:15400 -f {str(test_path)}"]
)
if test_output.returncode != 0:
print("Error running test {test_file}")
return
def get_meta():
call_meta = coll.find_one({"CallID": test_path.name})
if call_meta:
return call_meta
else:
sleep(2)
return get_meta()
call_meta = get_meta()
call_logs = get_call_logs(call_meta, s3, call_meta_dir)
call_events = call_logs["Events"]
test_data_path = test_path.with_suffix(".result.json")
test_data = ExtendedPath(test_data_path).read_json()
def is_final_asr_event_or_spoken(ev):
pld = json.loads(ev["Payload"])
return (
pld["AsrResult"]["Results"][0]["IsFinal"]
if ev["Type"] == "ASR_RESULT"
else False
)
def is_test_event(ev):
return (
ev["Author"] == "NLU"
or (ev["Author"] == "ASR" and is_final_asr_event_or_spoken(ev))
) and (ev["Type"] != "DEBUG")
test_evs = list(filter(is_test_event, call_events))
if len(test_evs) == 2:
try:
asr_payload = test_evs[0]["Payload"]
asr_result = json.loads(asr_payload)["AsrResult"]["Results"][0]
alt_tscripts = [alt["Transcript"] for alt in asr_result["Alternatives"]]
gcp_result = "|".join(alt_tscripts)
entity_asr = asr_result["AsrDynamicResults"][0]["Candidate"]["Transcript"]
nlu_payload = test_evs[1]["Payload"]
nlu_result_payload = json.loads(nlu_payload)["NluResults"]
entity = test_data[0]["entity"]
text = test_data[0]["text"]
audio_filepath = test_data[0]["audio_filepath"]
pretrained_asr = test_data[0]["pretrained_asr"]
nlu_entity = list(json.loads(nlu_result_payload)["Entities"].values())[0]
asr_entity = city_code[entity] if entity in city_code else "UNKNOWN"
entities_match = asr_entity == nlu_entity
result = "Success" if entities_match else "Fail"
return {
"expected_entity": entity,
"text": text,
"audio_filepath": audio_filepath,
"pretrained_asr": pretrained_asr,
"entity_asr": entity_asr,
"google_asr": gcp_result,
"nlu_result": nlu_result_payload,
"asr_entity": asr_entity,
"nlu_entity": nlu_entity,
"result": result,
}
except Exception:
return {
"expected_entity": test_data[0]["entity"],
"text": test_data[0]["text"],
"audio_filepath": test_data[0]["audio_filepath"],
"pretrained_asr": test_data[0]["pretrained_asr"],
"entity_asr": "",
"google_asr": "",
"nlu_result": "",
"asr_entity": "",
"nlu_entity": "",
"result": "Error",
}
else:
return {
"expected_entity": test_data[0]["entity"],
"text": test_data[0]["text"],
"audio_filepath": test_data[0]["audio_filepath"],
"pretrained_asr": test_data[0]["pretrained_asr"],
"entity_asr": "",
"google_asr": "",
"nlu_result": "",
"asr_entity": "",
"nlu_entity": "",
"result": "Empty",
}
@app.command()
def evaluate_slu(
# conv_src: Path = typer.Option(Path("./conv_data.json"), show_default=True),
data_name: str = typer.Option("call_upwork_test_cnd_cities", show_default=True),
# extraction_key: str = "Cities",
dump_dir: Path = Path("./data/asr_data"),
call_meta_dir: Path = Path("./data/call_metas"),
test_file_pref: str = "asr_test",
mongo_uri: str = typer.Option(
"mongodb://localhost:27017/test.calls", show_default=True
),
test_results: Path = Path("./data/results.csv"),
airport_codes: Path = Path("./airports_code.csv"),
reg_path: Path = Path("../saas_reg/regression/run.sh"),
test_id: str = "5ef481f27031edf6910e94e0",
):
# import json
from .utils import get_mongo_coll
import pandas as pd
import boto3
from concurrent.futures import ThreadPoolExecutor
from functools import partial
# import subprocess
# from time import sleep
import csv
from tqdm import tqdm
s3 = boto3.client("s3")
df = pd.read_csv(airport_codes)[["iata", "city"]]
city_code = pd.Series(df["iata"].values, index=df["city"]).to_dict()
test_files = list((dump_dir / data_name).glob(test_file_pref + "*.reg"))
coll = get_mongo_coll(mongo_uri)
with test_results.open("w") as csvfile:
fieldnames = [
"expected_entity",
"text",
"audio_filepath",
"pretrained_asr",
"entity_asr",
"google_asr",
"nlu_result",
"asr_entity",
"nlu_entity",
"result",
]
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
writer.writeheader()
with ThreadPoolExecutor(max_workers=8) as exe:
print("starting all loading tasks")
for test_result in tqdm(
exe.map(
partial(run_test, reg_path, coll, s3, call_meta_dir, city_code),
test_files,
),
position=0,
leave=True,
total=len(test_files),
):
writer.writerow(test_result)
def main():
app()
if __name__ == "__main__":
main()

View File

@@ -1,99 +0,0 @@
import typer
from pathlib import Path
from .utils import generate_dates, asr_test_writer
app = typer.Typer()
@app.command()
def export_test_reg(
conv_src: Path = typer.Option(Path("./conv_data.json"), show_default=True),
data_name: str = typer.Option("call_upwork_test_cnd_cities", show_default=True),
extraction_key: str = "Cities",
dump_dir: Path = Path("./data/asr_data"),
dump_file: Path = Path("ui_dump.json"),
manifest_file: Path = Path("manifest.json"),
test_file: Path = Path("asr_test.reg"),
):
from .utils import (
ExtendedPath,
asr_manifest_reader,
gcp_transcribe_gen,
parallel_apply,
)
from ..client import transcribe_gen
from pydub import AudioSegment
from queue import PriorityQueue
jasper_map = {
"PNRs": 8045,
"Cities": 8046,
"Names": 8047,
"Dates": 8048,
}
# jasper_map = {"PNRs": 8050, "Cities": 8050, "Names": 8050, "Dates": 8050}
transcriber_gcp = gcp_transcribe_gen()
transcriber_trained = transcribe_gen(asr_port=jasper_map[extraction_key])
transcriber_all_trained = transcribe_gen(asr_port=8050)
transcriber_libri_all_trained = transcribe_gen(asr_port=8051)
def find_ent(dd, conv_data):
ents = PriorityQueue()
for ent in conv_data:
if ent in dd["text"]:
ents.put((-len(ent), ent))
return ents.get_nowait()[1]
def process_data(d):
orig_seg = AudioSegment.from_wav(d["audio_path"])
jas_seg = orig_seg.set_channels(1).set_sample_width(2).set_frame_rate(24000)
gcp_seg = orig_seg.set_channels(1).set_sample_width(2).set_frame_rate(16000)
deepgram_file = Path("/home/shubham/voice_auto/pnrs/wav/") / Path(
d["audio_path"].stem + ".txt"
)
if deepgram_file.exists():
d["deepgram"] = "".join(
[s.replace("CHANNEL 0:", "") for s in deepgram_file.read_text().split("\n")]
)
else:
d["deepgram"] = 'Not Found'
d["audio_path"] = str(d["audio_path"])
d["gcp_transcript"] = transcriber_gcp(gcp_seg.raw_data)
d["jasper_trained"] = transcriber_trained(jas_seg.raw_data)
d["jasper_all"] = transcriber_all_trained(jas_seg.raw_data)
d["jasper_libri"] = transcriber_libri_all_trained(jas_seg.raw_data)
return d
conv_data = ExtendedPath(conv_src).read_json()
conv_data["Dates"] = generate_dates()
dump_data_path = dump_dir / Path(data_name) / dump_file
ui_dump_data = ExtendedPath(dump_data_path).read_json()["data"]
ui_dump_map = {i["utterance_id"]: i for i in ui_dump_data}
manifest_path = dump_dir / Path(data_name) / manifest_file
test_points = list(asr_manifest_reader(manifest_path))
test_data_objs = [{**(ui_dump_map[t["audio_path"].stem]), **t} for t in test_points]
test_data = parallel_apply(process_data, test_data_objs)
# test_data = [process_data(t) for t in test_data_objs]
test_path = dump_dir / Path(data_name) / test_file
def dd_gen(dump_data):
for dd in dump_data:
ent = find_ent(dd, conv_data[extraction_key])
dd["entity"] = ent
if ent:
yield dd
asr_test_writer(test_path, dd_gen(test_data))
# for i, b in enumerate(batch(test_data, 1)):
# test_fname = Path(f"{test_file.stem}_{i}.reg")
# test_path = dump_dir / Path(data_name) / test_fname
# asr_test_writer(test_path, dd_gen(test_data))
def main():
app()
if __name__ == "__main__":
main()

View File

@@ -1,52 +0,0 @@
from logging import getLogger
from google.cloud import texttospeech
LOGGER = getLogger("googletts")
class GoogleTTS(object):
def __init__(self):
self.client = texttospeech.TextToSpeechClient()
def text_to_speech(self, text: str, params: dict) -> bytes:
tts_input = texttospeech.types.SynthesisInput(ssml=text)
voice = texttospeech.types.VoiceSelectionParams(
language_code=params["language"], name=params["name"]
)
audio_config = texttospeech.types.AudioConfig(
audio_encoding=texttospeech.enums.AudioEncoding.LINEAR16,
sample_rate_hertz=params["sample_rate"],
)
response = self.client.synthesize_speech(tts_input, voice, audio_config)
audio_content = response.audio_content
return audio_content
@classmethod
def voice_list(cls):
"""Lists the available voices."""
client = cls().client
# Performs the list voices request
voices = client.list_voices()
results = []
for voice in voices.voices:
supported_eng_langs = [
lang for lang in voice.language_codes if lang[:2] == "en"
]
if len(supported_eng_langs) > 0:
lang = ",".join(supported_eng_langs)
else:
continue
ssml_gender = texttospeech.enums.SsmlVoiceGender(voice.ssml_gender)
results.append(
{
"name": voice.name,
"language": lang,
"gender": ssml_gender.name,
"engine": "wavenet" if "Wav" in voice.name else "standard",
"sample_rate": voice.natural_sample_rate_hertz,
}
)
return results

View File

@@ -1,26 +0,0 @@
"""
TTSClient Abstract Class
"""
from abc import ABC, abstractmethod
class TTSClient(ABC):
"""
Base class for TTS
"""
@abstractmethod
def text_to_speech(self, text: str, num_channels: int, sample_rate: int,
audio_encoding) -> bytes:
"""
convert text to bytes
Arguments:
text {[type]} -- text to convert
channel {[type]} -- output audio bytes channel setting
width {[type]} -- width of audio bytes
rate {[type]} -- rare for audio bytes
Returns:
[type] -- [description]
"""

View File

@@ -1,62 +0,0 @@
# import io
# import sys
# import json
import argparse
import logging
from pathlib import Path
from .utils import random_pnr_generator, asr_data_writer
from .tts.googletts import GoogleTTS
from tqdm import tqdm
import random
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
def pnr_tts_streamer(count):
google_voices = GoogleTTS.voice_list()
gtts = GoogleTTS()
for pnr_code in tqdm(random_pnr_generator(count)):
tts_code = f'<speak><say-as interpret-as="verbatim">{pnr_code}</say-as></speak>'
param = random.choice(google_voices)
param["sample_rate"] = 24000
param["num_channels"] = 1
wav_data = gtts.text_to_speech(text=tts_code, params=param)
audio_dur = len(wav_data[44:]) / (2 * 24000)
yield pnr_code, audio_dur, wav_data
def generate_asr_data_fromtts(output_dir, dataset_name, count):
asr_data_writer(output_dir, dataset_name, pnr_tts_streamer(count))
def arg_parser():
prog = Path(__file__).stem
parser = argparse.ArgumentParser(
prog=prog, description=f"generates asr training data"
)
parser.add_argument(
"--output_dir",
type=Path,
default=Path("./train/asr_data"),
help="directory to output asr data",
)
parser.add_argument(
"--count", type=int, default=3, help="number of datapoints to generate"
)
parser.add_argument(
"--dataset_name", type=str, default="pnr_data", help="name of the dataset"
)
return parser
def main():
parser = arg_parser()
args = parser.parse_args()
generate_asr_data_fromtts(**vars(args))
if __name__ == "__main__":
main()

View File

@@ -1,22 +1,14 @@
import io import io
import os import os
import re
import json import json
import base64
import wave import wave
from pathlib import Path from pathlib import Path
from itertools import product
from functools import partial from functools import partial
from math import floor
from uuid import uuid4 from uuid import uuid4
from urllib.parse import urlsplit, urlencode
from urllib.request import Request, urlopen
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
import numpy as np
import pymongo import pymongo
from slugify import slugify from slugify import slugify
from num2words import num2words
from jasper.client import transcribe_gen from jasper.client import transcribe_gen
from nemo.collections.asr.metrics import word_error_rate from nemo.collections.asr.metrics import word_error_rate
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
@@ -42,29 +34,6 @@ def wav_bytes(audio_bytes, frame_rate=24000):
return wf_b.getvalue() return wf_b.getvalue()
def random_pnr_generator(count=10000):
LENGTH = 3
# alphabet = list("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")
alphabet = list("ABCDEFGHIJKLMNOPQRSTUVWXYZ")
numeric = list("0123456789")
np_alphabet = np.array(alphabet, dtype="|S1")
np_numeric = np.array(numeric, dtype="|S1")
np_alpha_codes = np.random.choice(np_alphabet, [count, LENGTH])
np_num_codes = np.random.choice(np_numeric, [count, LENGTH])
np_code_seed = np.concatenate((np_alpha_codes, np_num_codes), axis=1).T
np.random.shuffle(np_code_seed)
np_codes = np_code_seed.T
codes = [(b"".join(np_codes[i])).decode("utf-8") for i in range(len(np_codes))]
return codes
def alnum_to_asr_tokens(text):
letters = " ".join(list(text))
num_tokens = [num2words(c) if "0" <= c <= "9" else c for c in letters]
return ("".join(num_tokens)).lower()
def tscript_uuid_fname(transcript): def tscript_uuid_fname(transcript):
return str(uuid4()) + "_" + slugify(transcript, max_length=8) return str(uuid4()) + "_" + slugify(transcript, max_length=8)
@@ -80,8 +49,8 @@ def asr_data_writer(output_dir, dataset_name, asr_data_source, verbose=False):
fname = tscript_uuid_fname(transcript) fname = tscript_uuid_fname(transcript)
audio_file = dataset_dir / Path("wav") / Path(fname).with_suffix(".wav") audio_file = dataset_dir / Path("wav") / Path(fname).with_suffix(".wav")
audio_file.write_bytes(wav_data) audio_file.write_bytes(wav_data)
rel_pnr_path = audio_file.relative_to(dataset_dir) rel_data_path = audio_file.relative_to(dataset_dir)
manifest = manifest_str(str(rel_pnr_path), audio_dur, transcript) manifest = manifest_str(str(rel_data_path), audio_dur, transcript)
mf.write(manifest) mf.write(manifest)
if verbose: if verbose:
print(f"writing '{transcript}' of duration {audio_dur}") print(f"writing '{transcript}' of duration {audio_dur}")
@@ -89,104 +58,97 @@ def asr_data_writer(output_dir, dataset_name, asr_data_source, verbose=False):
return num_datapoints return num_datapoints
def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=False): def ui_data_generator(output_dir, dataset_name, asr_data_source, verbose=False):
dataset_dir = output_dir / Path(dataset_name) dataset_dir = output_dir / Path(dataset_name)
(dataset_dir / Path("wav")).mkdir(parents=True, exist_ok=True) (dataset_dir / Path("wav")).mkdir(parents=True, exist_ok=True)
ui_dump_file = dataset_dir / Path("ui_dump.json")
(dataset_dir / Path("wav_plots")).mkdir(parents=True, exist_ok=True) (dataset_dir / Path("wav_plots")).mkdir(parents=True, exist_ok=True)
asr_manifest = dataset_dir / Path("manifest.json")
num_datapoints = 0
ui_dump = {
"use_domain_asr": False,
"annotation_only": False,
"enable_plots": True,
"data": [],
}
data_funcs = []
deepgram_transcriber = deepgram_transcribe_gen() def data_fn(
# t2n = Text2Num() transcript,
transcriber_gcp = gcp_transcribe_gen() audio_dur,
wav_data,
caller_name,
aud_seg,
fname,
audio_path,
num_datapoints,
rel_data_path,
):
pretrained_result = transcriber_pretrained(aud_seg.raw_data)
pretrained_wer = word_error_rate([transcript], [pretrained_result])
png_path = Path(fname).with_suffix(".png")
wav_plot_path = dataset_dir / Path("wav_plots") / png_path
if not wav_plot_path.exists():
plot_seg(wav_plot_path, audio_path)
return {
"audio_filepath": str(rel_data_path),
"duration": round(audio_dur, 1),
"text": transcript,
"real_idx": num_datapoints,
"audio_path": audio_path,
"spoken": transcript,
"caller": caller_name,
"utterance_id": fname,
"pretrained_asr": pretrained_result,
"pretrained_wer": pretrained_wer,
"plot_path": str(wav_plot_path),
}
num_datapoints = 0
data_funcs = []
transcriber_pretrained = transcribe_gen(asr_port=8044) transcriber_pretrained = transcribe_gen(asr_port=8044)
for transcript, audio_dur, wav_data, caller_name, aud_seg in asr_data_source:
fname = str(uuid4()) + "_" + slugify(transcript, max_length=8)
audio_file = dataset_dir / Path("wav") / Path(fname).with_suffix(".wav")
audio_file.write_bytes(wav_data)
audio_path = str(audio_file)
rel_data_path = audio_file.relative_to(dataset_dir)
data_funcs.append(
partial(
data_fn,
transcript,
audio_dur,
wav_data,
caller_name,
aud_seg,
fname,
audio_path,
num_datapoints,
rel_data_path,
)
)
num_datapoints += 1
ui_data = parallel_apply(lambda x: x(), data_funcs)
return ui_data, num_datapoints
def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=False):
dataset_dir = output_dir / Path(dataset_name)
dump_data, num_datapoints = ui_data_generator(
output_dir, dataset_name, asr_data_source, verbose=verbose
)
asr_manifest = dataset_dir / Path("manifest.json")
with asr_manifest.open("w") as mf: with asr_manifest.open("w") as mf:
print(f"writing manifest to {asr_manifest}") print(f"writing manifest to {asr_manifest}")
for d in dump_data:
def data_fn( rel_data_path = d["audio_filepath"]
transcript, audio_dur = d["duration"]
audio_dur, transcript = d["text"]
wav_data, manifest = manifest_str(str(rel_data_path), audio_dur, transcript)
caller_name,
aud_seg,
fname,
audio_path,
num_datapoints,
rel_pnr_path,
):
pretrained_result = transcriber_pretrained(aud_seg.raw_data)
gcp_seg = aud_seg.set_frame_rate(16000)
gcp_result = transcriber_gcp(gcp_seg.raw_data)
aud_data = audio_path.read_bytes()
dgram_result = deepgram_transcriber(aud_data)
# gtruth = dp['text']
# dgram_result = t2n.convert(dgram_script)
pretrained_wer = word_error_rate([transcript], [pretrained_result])
wav_plot_path = (
dataset_dir / Path("wav_plots") / Path(fname).with_suffix(".png")
)
if not wav_plot_path.exists():
plot_seg(wav_plot_path, audio_path)
return {
"audio_filepath": str(rel_pnr_path),
"duration": round(audio_dur, 1),
"text": transcript,
"real_idx": num_datapoints,
"audio_path": audio_path,
"spoken": transcript,
"caller": caller_name,
"utterance_id": fname,
"gcp_asr": gcp_result,
"deepgram_asr": dgram_result,
"pretrained_asr": pretrained_result,
"pretrained_wer": pretrained_wer,
"plot_path": str(wav_plot_path),
}
for transcript, audio_dur, wav_data, caller_name, aud_seg in asr_data_source:
fname = str(uuid4()) + "_" + slugify(transcript, max_length=8)
audio_file = dataset_dir / Path("wav") / Path(fname).with_suffix(".wav")
audio_file.write_bytes(wav_data)
audio_path = str(audio_file)
rel_pnr_path = audio_file.relative_to(dataset_dir)
manifest = manifest_str(str(rel_pnr_path), audio_dur, transcript)
mf.write(manifest) mf.write(manifest)
data_funcs.append(
partial( ui_dump_file = dataset_dir / Path("ui_dump.json")
data_fn, ExtendedPath(ui_dump_file).write_json({"data": dump_data})
transcript,
audio_dur,
wav_data,
caller_name,
aud_seg,
fname,
audio_path,
num_datapoints,
rel_pnr_path,
)
)
num_datapoints += 1
dump_data = parallel_apply(lambda x: x(), data_funcs)
# dump_data = [x() for x in tqdm(data_funcs)]
ui_dump["data"] = dump_data
ExtendedPath(ui_dump_file).write_json(ui_dump)
return num_datapoints return num_datapoints
def asr_manifest_reader(data_manifest_path: Path): def asr_manifest_reader(data_manifest_path: Path):
print(f"reading manifest from {data_manifest_path}") print(f"reading manifest from {data_manifest_path}")
with data_manifest_path.open("r") as pf: with data_manifest_path.open("r") as pf:
pnr_jsonl = pf.readlines() data_jsonl = pf.readlines()
pnr_data = [json.loads(v) for v in pnr_jsonl] data_data = [json.loads(v) for v in data_jsonl]
for p in pnr_data: for p in data_data:
p["audio_path"] = data_manifest_path.parent / Path(p["audio_filepath"]) p["audio_path"] = data_manifest_path.parent / Path(p["audio_filepath"])
p["text"] = p["text"].strip() p["text"] = p["text"].strip()
yield p yield p
@@ -236,12 +198,6 @@ class ExtendedPath(type(Path())):
with self.open("r") as jf: with self.open("r") as jf:
return json.load(jf) return json.load(jf)
def read_jsonl(self):
print(f"reading jsonl from {self}")
with self.open("r") as jf:
for l in jf.readlines():
yield json.loads(l)
def write_json(self, data): def write_json(self, data):
print(f"writing json to {self}") print(f"writing json to {self}")
self.parent.mkdir(parents=True, exist_ok=True) self.parent.mkdir(parents=True, exist_ok=True)
@@ -249,12 +205,6 @@ class ExtendedPath(type(Path())):
return json.dump(data, jf, indent=2) return json.dump(data, jf, indent=2)
def get_mongo_coll(uri="mongodb://localhost:27017/test.calls"):
ud = pymongo.uri_parser.parse_uri(uri)
conn = pymongo.MongoClient(uri)
return conn[ud["database"]][ud["collection"]]
def get_mongo_conn(host="", port=27017, db="test", col="calls"): def get_mongo_conn(host="", port=27017, db="test", col="calls"):
mongo_host = host if host else os.environ.get("MONGO_HOST", "localhost") mongo_host = host if host else os.environ.get("MONGO_HOST", "localhost")
mongo_uri = f"mongodb://{mongo_host}:{port}/" mongo_uri = f"mongodb://{mongo_host}:{port}/"
@@ -280,405 +230,6 @@ def plot_seg(wav_plot_path, audio_path):
fig.savefig(wav_plot_f, format="png", dpi=50) fig.savefig(wav_plot_f, format="png", dpi=50)
def generate_dates():
days = [i for i in range(1, 32)]
months = [
"January",
"February",
"March",
"April",
"May",
"June",
"July",
"August",
"September",
"October",
"November",
"December",
]
# ordinal from https://stackoverflow.com/questions/9647202/ordinal-numbers-replacement
def ordinal(n):
return "%d%s" % (
n,
"tsnrhtdd"[(floor(n / 10) % 10 != 1) * (n % 10 < 4) * n % 10 :: 4],
)
def canon_vars(d, m):
return [
ordinal(d) + " " + m,
m + " " + ordinal(d),
ordinal(d) + " of " + m,
m + " the " + ordinal(d),
str(d) + " " + m,
m + " " + str(d),
]
return [dm for d, m in product(days, months) for dm in canon_vars(d, m)]
def get_call_logs(call_obj, s3, call_meta_dir):
meta_s3_uri = call_obj["DataURI"]
s3_event_url_p = urlsplit(meta_s3_uri)
saved_meta_path = call_meta_dir / Path(Path(s3_event_url_p.path).name)
if not saved_meta_path.exists():
print(f"downloading : {saved_meta_path} from {meta_s3_uri}")
s3.download_file(
s3_event_url_p.netloc, s3_event_url_p.path[1:], str(saved_meta_path)
)
call_metas = json.load(saved_meta_path.open())
return call_metas
def gcp_transcribe_gen():
from google.cloud import speech_v1
from google.cloud.speech_v1 import enums
# import io
client = speech_v1.SpeechClient()
# local_file_path = 'resources/brooklyn_bridge.raw'
# The language of the supplied audio
language_code = "en-US"
model = "phone_call"
# Sample rate in Hertz of the audio data sent
sample_rate_hertz = 16000
# Encoding of audio data sent. This sample sets this explicitly.
# This field is optional for FLAC and WAV audio formats.
encoding = enums.RecognitionConfig.AudioEncoding.LINEAR16
config = {
"language_code": language_code,
"sample_rate_hertz": sample_rate_hertz,
"encoding": encoding,
"model": model,
"enable_automatic_punctuation": True,
"max_alternatives": 10,
"enable_word_time_offsets": True, # used to detect start and end time of utterances
"speech_contexts": [
{
"phrases": [
"$OOV_CLASS_ALPHANUMERIC_SEQUENCE",
"$OOV_CLASS_DIGIT_SEQUENCE",
"$TIME",
"$YEAR",
]
},
{
"phrases": [
"A",
"B",
"C",
"D",
"E",
"F",
"G",
"H",
"I",
"J",
"K",
"L",
"M",
"N",
"O",
"P",
"Q",
"R",
"S",
"T",
"U",
"V",
"W",
"X",
"Y",
"Z",
]
},
{
"phrases": [
"PNR is $OOV_CLASS_ALPHANUMERIC_SEQUENCE",
"my PNR is $OOV_CLASS_ALPHANUMERIC_SEQUENCE",
"my PNR number is $OOV_CLASS_ALPHANUMERIC_SEQUENCE",
"PNR number is $OOV_CLASS_ALPHANUMERIC_SEQUENCE",
"It's $OOV_CLASS_ALPHANUMERIC_SEQUENCE",
"$OOV_CLASS_ALPHANUMERIC_SEQUENCE is my PNR",
]
},
{"phrases": ["my name is"]},
{"phrases": ["Number $ORDINAL", "Numeral $ORDINAL"]},
{
"phrases": [
"John Smith",
"Carina Hu",
"Travis Lim",
"Marvin Tan",
"Samuel Tan",
"Dawn Mathew",
"Dawn",
"Mathew",
]
},
{
"phrases": [
"Beijing",
"Tokyo",
"London",
"19 August",
"7 October",
"11 December",
"17 September",
"19th August",
"7th October",
"11th December",
"17th September",
"ABC123",
"KWXUNP",
"XLU5K1",
"WL2JV6",
"KBS651",
]
},
{
"phrases": [
"first flight",
"second flight",
"third flight",
"first option",
"second option",
"third option",
"first one",
"second one",
"third one",
]
},
],
"metadata": {
"industry_naics_code_of_audio": 481111,
"interaction_type": enums.RecognitionMetadata.InteractionType.PHONE_CALL,
},
}
def sample_recognize(content):
"""
Transcribe a short audio file using synchronous speech recognition
Args:
local_file_path Path to local audio file, e.g. /path/audio.wav
"""
# with io.open(local_file_path, "rb") as f:
# content = f.read()
audio = {"content": content}
response = client.recognize(config, audio)
for result in response.results:
# First alternative is the most probable result
return "/".join([alt.transcript for alt in result.alternatives])
# print(u"Transcript: {}".format(alternative.transcript))
return ""
return sample_recognize
def deepgram_transcribe_gen():
DEEPGRAM_URL = "https://brain.deepgram.com/v2/listen"
MODEL = "agara"
encoding = "linear16"
sample_rate = "8000"
# diarize = "false"
q_params = {
"model": MODEL,
"encoding": encoding,
"sample_rate": sample_rate,
"language": "en-US",
"multichannel": "false",
"punctuate": "true",
}
url = "{}?{}".format(DEEPGRAM_URL, urlencode(q_params))
# print(url)
creds = ("arjun@agaralabs.com", "PoX1Y@x4h%oS")
def deepgram_offline(audio_data):
request = Request(
url,
method="POST",
headers={
"Authorization": "Basic {}".format(
base64.b64encode("{}:{}".format(*creds).encode("utf-8")).decode(
"utf-8"
)
)
},
data=audio_data,
)
with urlopen(request) as response:
msg = json.loads(response.read())
data = msg["results"]["channels"][0]["alternatives"][0]
return data["transcript"]
return deepgram_offline
class Text2Num(object):
"""docstring for Text2Num."""
def __init__(self):
numwords = {}
if not numwords:
units = [
"zero",
"one",
"two",
"three",
"four",
"five",
"six",
"seven",
"eight",
"nine",
"ten",
"eleven",
"twelve",
"thirteen",
"fourteen",
"fifteen",
"sixteen",
"seventeen",
"eighteen",
"nineteen",
]
tens = [
"",
"",
"twenty",
"thirty",
"forty",
"fifty",
"sixty",
"seventy",
"eighty",
"ninety",
]
scales = ["hundred", "thousand", "million", "billion", "trillion"]
numwords["and"] = (1, 0)
for idx, word in enumerate(units):
numwords[word] = (1, idx)
for idx, word in enumerate(tens):
numwords[word] = (1, idx * 10)
for idx, word in enumerate(scales):
numwords[word] = (10 ** (idx * 3 or 2), 0)
self.numwords = numwords
def is_num(self, word):
return word in self.numwords
def parseOrdinal(self, utterance, **kwargs):
lookup_dict = {
"first": 1,
"second": 2,
"third": 3,
"fourth": 4,
"fifth": 5,
"sixth": 6,
"seventh": 7,
"eighth": 8,
"ninth": 9,
"tenth": 10,
"one": 1,
"two": 2,
"three": 3,
"four": 4,
"five": 5,
"six": 6,
"seven": 7,
"eight": 8,
"nine": 9,
"ten": 10,
"1": 1,
"2": 2,
"3": 3,
"4": 4,
"5": 5,
"6": 6,
"7": 7,
"8": 8,
"9": 9,
"10": 10,
"last": -1,
}
pattern = re.compile(
r"(\s|^)(?P<num>(first)|(third)|(fourth)|(fifth)|(sixth)|(seventh)|(eighth)|(ninth)|(tenth)|(two)|(three)|(four)|(five)|(six)|(seven)|(eight)|(nine)|(ten)|(1)|(2)|(3)|(4)|(5)|(6)|(7)|(8)|(9)|(10)|(last))(\s|$)",
re.IGNORECASE,
)
ordinal = ""
if pattern.search(utterance):
ordinal = pattern.search(utterance).groupdict()["num"].strip()
elif re.search(r"(\s|^)(?P<num>(second))(\s|$)", utterance):
ordinal = "second"
elif re.search(r"(\s|^)(?P<num>(one))(\s|$)", utterance):
ordinal = "one"
ordinal = lookup_dict.get(ordinal, "")
return ordinal
def convert(self, sent):
# res = []
# for token in sent.split():
# if token in self.numwords:
# res.append(str(self.text2int(token)))
# else:
# res.append(token)
# return " ".join(res)
return " ".join(
[
str(self.parseOrdinal(x)) if self.parseOrdinal(x) != "" else x
for x in sent.split()
]
)
def text2int(self, textnum):
current = result = 0
for word in textnum.split():
if word not in self.numwords:
raise Exception("Illegal word: " + word)
scale, increment = self.numwords[word]
current = current * scale + increment
if scale > 100:
result += current
current = 0
return result + current
def is_sub_sequence(str1, str2):
m = len(str1)
n = len(str2)
def check_seq(string1, string2, m, n):
# Base Cases
if m == 0:
return True
if n == 0:
return False
# If last characters of two strings are matching
if string1[m - 1] == string2[n - 1]:
return check_seq(string1, string2, m - 1, n - 1)
# If last characters are not matching
return check_seq(string1, string2, m, n - 1)
return check_seq(str1, str2, m, n)
def parallel_apply(fn, iterable, workers=8): def parallel_apply(fn, iterable, workers=8):
with ThreadPoolExecutor(max_workers=workers) as exe: with ThreadPoolExecutor(max_workers=workers) as exe:
print(f"parallelly applying {fn}") print(f"parallelly applying {fn}")
@@ -688,12 +239,3 @@ def parallel_apply(fn, iterable, workers=8):
exe.map(fn, iterable), position=0, leave=True, total=len(iterable) exe.map(fn, iterable), position=0, leave=True, total=len(iterable)
) )
] ]
def main():
for c in random_pnr_generator():
print(c)
if __name__ == "__main__":
main()

View File

@@ -1,13 +1,10 @@
import json import json
import shutil import shutil
from pathlib import Path from pathlib import Path
from enum import Enum
import typer import typer
from tqdm import tqdm
from ..utils import ( from ..utils import (
alnum_to_asr_tokens,
ExtendedPath, ExtendedPath,
asr_manifest_reader, asr_manifest_reader,
asr_manifest_writer, asr_manifest_writer,
@@ -19,9 +16,7 @@ from ..utils import (
app = typer.Typer() app = typer.Typer()
def preprocess_datapoint( def preprocess_datapoint(idx, rel_root, sample):
idx, rel_root, sample, use_domain_asr, annotation_only, enable_plots
):
from pydub import AudioSegment from pydub import AudioSegment
from nemo.collections.asr.metrics import word_error_rate from nemo.collections.asr.metrics import word_error_rate
from jasper.client import transcribe_gen from jasper.client import transcribe_gen
@@ -31,37 +26,23 @@ def preprocess_datapoint(
res["real_idx"] = idx res["real_idx"] = idx
audio_path = rel_root / Path(sample["audio_filepath"]) audio_path = rel_root / Path(sample["audio_filepath"])
res["audio_path"] = str(audio_path) res["audio_path"] = str(audio_path)
if use_domain_asr:
res["spoken"] = alnum_to_asr_tokens(res["text"])
else:
res["spoken"] = res["text"]
res["utterance_id"] = audio_path.stem res["utterance_id"] = audio_path.stem
if not annotation_only: transcriber_pretrained = transcribe_gen(asr_port=8044)
transcriber_pretrained = transcribe_gen(asr_port=8044)
aud_seg = ( aud_seg = (
AudioSegment.from_file_using_temporary_files(audio_path) AudioSegment.from_file_using_temporary_files(audio_path)
.set_channels(1) .set_channels(1)
.set_sample_width(2) .set_sample_width(2)
.set_frame_rate(24000) .set_frame_rate(24000)
) )
res["pretrained_asr"] = transcriber_pretrained(aud_seg.raw_data) res["pretrained_asr"] = transcriber_pretrained(aud_seg.raw_data)
res["pretrained_wer"] = word_error_rate( res["pretrained_wer"] = word_error_rate([res["text"]], [res["pretrained_asr"]])
[res["text"]], [res["pretrained_asr"]] wav_plot_path = (
) rel_root / Path("wav_plots") / Path(audio_path.name).with_suffix(".png")
if use_domain_asr: )
transcriber_speller = transcribe_gen(asr_port=8045) if not wav_plot_path.exists():
res["domain_asr"] = transcriber_speller(aud_seg.raw_data) plot_seg(wav_plot_path, audio_path)
res["domain_wer"] = word_error_rate( res["plot_path"] = str(wav_plot_path)
[res["spoken"]], [res["pretrained_asr"]]
)
if enable_plots:
wav_plot_path = (
rel_root / Path("wav_plots") / Path(audio_path.name).with_suffix(".png")
)
if not wav_plot_path.exists():
plot_seg(wav_plot_path, audio_path)
res["plot_path"] = str(wav_plot_path)
return res return res
except BaseException as e: except BaseException as e:
print(f'failed on {idx}: {sample["audio_filepath"]} with {e}') print(f'failed on {idx}: {sample["audio_filepath"]} with {e}')
@@ -69,70 +50,59 @@ def preprocess_datapoint(
@app.command() @app.command()
def dump_ui( def dump_ui(
data_name: str = typer.Option("call_alphanum", show_default=True), data_name: str = typer.Option("dataname", show_default=True),
dataset_dir: Path = Path("./data/asr_data"), dataset_dir: Path = Path("./data/asr_data"),
dump_dir: Path = Path("./data/valiation_data"), dump_dir: Path = Path("./data/valiation_data"),
dump_fname: Path = typer.Option(Path("ui_dump.json"), show_default=True), dump_fname: Path = typer.Option(Path("ui_dump.json"), show_default=True),
use_domain_asr: bool = False,
annotation_only: bool = False,
enable_plots: bool = True,
): ):
from concurrent.futures import ThreadPoolExecutor from io import BytesIO
from functools import partial from pydub import AudioSegment
from ..utils import ui_data_generator
data_manifest_path = dataset_dir / Path(data_name) / Path("manifest.json") data_manifest_path = dataset_dir / Path(data_name) / Path("manifest.json")
dump_path: Path = dump_dir / Path(data_name) / dump_fname
plot_dir = data_manifest_path.parent / Path("wav_plots") plot_dir = data_manifest_path.parent / Path("wav_plots")
plot_dir.mkdir(parents=True, exist_ok=True) plot_dir.mkdir(parents=True, exist_ok=True)
typer.echo(f"Using data manifest:{data_manifest_path}") typer.echo(f"Using data manifest:{data_manifest_path}")
with data_manifest_path.open("r") as pf:
pnr_jsonl = pf.readlines()
pnr_funcs = [
partial(
preprocess_datapoint,
i,
data_manifest_path.parent,
json.loads(v),
use_domain_asr,
annotation_only,
enable_plots,
)
for i, v in enumerate(pnr_jsonl)
]
def exec_func(f): def asr_data_source_gen():
return f() with data_manifest_path.open("r") as pf:
data_jsonl = pf.readlines()
for v in data_jsonl:
sample = json.loads(v)
rel_root = data_manifest_path.parent
res = dict(sample)
audio_path = rel_root / Path(sample["audio_filepath"])
audio_segment = (
AudioSegment.from_file_using_temporary_files(audio_path)
.set_channels(1)
.set_sample_width(2)
.set_frame_rate(24000)
)
wav_plot_path = (
rel_root
/ Path("wav_plots")
/ Path(audio_path.name).with_suffix(".png")
)
if not wav_plot_path.exists():
plot_seg(wav_plot_path, audio_path)
res["plot_path"] = str(wav_plot_path)
code_fb = BytesIO()
audio_segment.export(code_fb, format="wav")
wav_data = code_fb.getvalue()
duration = audio_segment.duration_seconds
asr_final = res["text"]
yield asr_final, duration, wav_data, "caller", audio_segment
with ThreadPoolExecutor() as exe: dump_data, num_datapoints = ui_data_generator(
print("starting all preprocess tasks") dataset_dir, data_name, asr_data_source_gen()
pnr_data = filter( )
None, ui_dump_file = dataset_dir / Path("ui_dump.json")
list( ExtendedPath(ui_dump_file).write_json({"data": dump_data})
tqdm(
exe.map(exec_func, pnr_funcs),
position=0,
leave=True,
total=len(pnr_funcs),
)
),
)
if annotation_only:
result = list(pnr_data)
else:
wer_key = "domain_wer" if use_domain_asr else "pretrained_wer"
result = sorted(pnr_data, key=lambda x: x[wer_key], reverse=True)
ui_config = {
"use_domain_asr": use_domain_asr,
"annotation_only": annotation_only,
"enable_plots": enable_plots,
"data": result,
}
ExtendedPath(dump_path).write_json(ui_config)
@app.command() @app.command()
def sample_ui( def sample_ui(
data_name: str = typer.Option("call_upwork_train_cnd", show_default=True), data_name: str = typer.Option("dataname", show_default=True),
dump_dir: Path = Path("./data/asr_data"), dump_dir: Path = Path("./data/asr_data"),
dump_file: Path = Path("ui_dump.json"), dump_file: Path = Path("ui_dump.json"),
sample_count: int = typer.Option(80, show_default=True), sample_count: int = typer.Option(80, show_default=True),
@@ -156,50 +126,9 @@ def sample_ui(
ExtendedPath(sample_path).write_json(processed_data) ExtendedPath(sample_path).write_json(processed_data)
@app.command()
def sample_asr_accuracy(
data_name: str = typer.Option(
"png_06_2020_week1_numbers_window_customer", show_default=True
),
dump_dir: Path = Path("./data/asr_data"),
sample_file: Path = Path("sample_dump.json"),
asr_service: str = "deepgram",
):
# import pandas as pd
# from pydub import AudioSegment
from ..utils import is_sub_sequence, Text2Num
# from ..utils import deepgram_transcribe_gen
#
# deepgram_transcriber = deepgram_transcribe_gen()
t2n = Text2Num()
# processed_data_path = dump_dir / Path(data_name) / dump_file
sample_path = dump_dir / Path(data_name) / sample_file
processed_data = ExtendedPath(sample_path).read_json()
# asr_data = []
match_count, total_samples = 0, len(processed_data["data"])
for dp in tqdm(processed_data["data"]):
# aud_data = Path(dp["audio_path"]).read_bytes()
# dgram_result = deepgram_transcriber(aud_data)
# dp["deepgram_asr"] = dgram_result
gcp_num = dp["text"]
dgm_num = t2n.convert(dp["deepgram_asr"].lower())
if is_sub_sequence(gcp_num, dgm_num):
match_count += 1
print(f"MATCH GCP:{gcp_num}\tDGM:{dgm_num}")
else:
print(f"FAIL GCP:{gcp_num}\tDGM:{dgm_num}")
# asr_data.append(dp)
typer.echo(
f"{match_count} from deepgram matches with {total_samples} gcp transcripts."
)
# processed_data["data"] = asr_data
# ExtendedPath(sample_path).write_json(processed_data)
@app.command() @app.command()
def task_ui( def task_ui(
data_name: str = typer.Option("call_upwork_train_cnd", show_default=True), data_name: str = typer.Option("dataname", show_default=True),
dump_dir: Path = Path("./data/asr_data"), dump_dir: Path = Path("./data/asr_data"),
dump_file: Path = Path("ui_dump.json"), dump_file: Path = Path("ui_dump.json"),
task_count: int = typer.Option(4, show_default=True), task_count: int = typer.Option(4, show_default=True),
@@ -223,7 +152,7 @@ def task_ui(
@app.command() @app.command()
def dump_corrections( def dump_corrections(
task_uid: str, task_uid: str,
data_name: str = typer.Option("call_alphanum", show_default=True), data_name: str = typer.Option("dataname", show_default=True),
dump_dir: Path = Path("./data/asr_data"), dump_dir: Path = Path("./data/asr_data"),
dump_fname: Path = Path("corrections.json"), dump_fname: Path = Path("corrections.json"),
): ):
@@ -241,7 +170,7 @@ def dump_corrections(
@app.command() @app.command()
def caller_quality( def caller_quality(
task_uid: str, task_uid: str,
data_name: str = typer.Option("call_upwork_train_cnd", show_default=True), data_name: str = typer.Option("dataname", show_default=True),
dump_dir: Path = Path("./data/asr_data"), dump_dir: Path = Path("./data/asr_data"),
dump_fname: Path = Path("ui_dump.json"), dump_fname: Path = Path("ui_dump.json"),
correction_fname: Path = Path("corrections.json"), correction_fname: Path = Path("corrections.json"),
@@ -277,7 +206,7 @@ def caller_quality(
@app.command() @app.command()
def fill_unannotated( def fill_unannotated(
data_name: str = typer.Option("call_alphanum", show_default=True), data_name: str = typer.Option("dataname", show_default=True),
dump_dir: Path = Path("./data/valiation_data"), dump_dir: Path = Path("./data/valiation_data"),
dump_file: Path = Path("ui_dump.json"), dump_file: Path = Path("ui_dump.json"),
corrections_file: Path = Path("corrections.json"), corrections_file: Path = Path("corrections.json"),
@@ -298,16 +227,9 @@ def fill_unannotated(
) )
class ExtractionType(str, Enum):
date = "dates"
city = "cities"
name = "names"
all = "all"
@app.command() @app.command()
def split_extract( def split_extract(
data_name: str = typer.Option("call_alphanum", show_default=True), data_name: str = typer.Option("dataname", show_default=True),
# dest_data_name: str = typer.Option("call_aldata_namephanum_date", show_default=True), # dest_data_name: str = typer.Option("call_aldata_namephanum_date", show_default=True),
# dump_dir: Path = Path("./data/valiation_data"), # dump_dir: Path = Path("./data/valiation_data"),
dump_dir: Path = Path("./data/asr_data"), dump_dir: Path = Path("./data/asr_data"),
@@ -317,7 +239,7 @@ def split_extract(
conv_data_path: Path = typer.Option( conv_data_path: Path = typer.Option(
Path("./data/conv_data.json"), show_default=True Path("./data/conv_data.json"), show_default=True
), ),
extraction_type: ExtractionType = ExtractionType.all, extraction_type: str = "all",
): ):
import shutil import shutil
@@ -381,7 +303,7 @@ def split_extract(
@app.command() @app.command()
def update_corrections( def update_corrections(
data_name: str = typer.Option("call_alphanum", show_default=True), data_name: str = typer.Option("dataname", show_default=True),
dump_dir: Path = Path("./data/asr_data"), dump_dir: Path = Path("./data/asr_data"),
manifest_file: Path = Path("manifest.json"), manifest_file: Path = Path("manifest.json"),
corrections_file: Path = Path("corrections.json"), corrections_file: Path = Path("corrections.json"),

View File

@@ -42,7 +42,9 @@ if not hasattr(st, "mongo_connected"):
upsert=True, upsert=True,
) )
def set_task_fn(mf_path): def set_task_fn(mf_path, task_id):
if task_id:
st.task_id = task_id
task_path = mf_path.parent / Path(f"task-{st.task_id}.lck") task_path = mf_path.parent / Path(f"task-{st.task_id}.lck")
if not task_path.exists(): if not task_path.exists():
print(f"creating task lock at {task_path}") print(f"creating task lock at {task_path}")
@@ -66,26 +68,22 @@ def load_ui_data(validation_ui_data_path: Path):
@app.command() @app.command()
def main(manifest: Path): def main(manifest: Path, task_id: str = ""):
st.set_task(manifest) st.set_task(manifest, task_id)
ui_config = load_ui_data(manifest) ui_config = load_ui_data(manifest)
asr_data = ui_config["data"] asr_data = ui_config["data"]
use_domain_asr = ui_config.get("use_domain_asr", True)
annotation_only = ui_config.get("annotation_only", False) annotation_only = ui_config.get("annotation_only", False)
enable_plots = ui_config.get("enable_plots", True)
sample_no = st.get_current_cursor() sample_no = st.get_current_cursor()
if len(asr_data) - 1 < sample_no or sample_no < 0: if len(asr_data) - 1 < sample_no or sample_no < 0:
print("Invalid samplno resetting to 0") print("Invalid samplno resetting to 0")
st.update_cursor(0) st.update_cursor(0)
sample = asr_data[sample_no] sample = asr_data[sample_no]
title_type = "Speller " if use_domain_asr else ""
task_uid = st.task_id.rsplit("-", 1)[1] task_uid = st.task_id.rsplit("-", 1)[1]
if annotation_only: if annotation_only:
st.title(f"ASR Annotation - # {task_uid}") st.title(f"ASR Annotation - # {task_uid}")
else: else:
st.title(f"ASR {title_type}Validation - # {task_uid}") st.title(f"ASR Validation - # {task_uid}")
addl_text = f"spelled *{sample['spoken']}*" if use_domain_asr else "" st.markdown(f"{sample_no+1} of {len(asr_data)} : **{sample['text']}**")
st.markdown(f"{sample_no+1} of {len(asr_data)} : **{sample['text']}**" + addl_text)
new_sample = st.number_input( new_sample = st.number_input(
"Go To Sample:", value=sample_no + 1, min_value=1, max_value=len(asr_data) "Go To Sample:", value=sample_no + 1, min_value=1, max_value=len(asr_data)
) )
@@ -94,19 +92,13 @@ def main(manifest: Path):
st.sidebar.title(f"Details: [{sample['real_idx']}]") st.sidebar.title(f"Details: [{sample['real_idx']}]")
st.sidebar.markdown(f"Gold Text: **{sample['text']}**") st.sidebar.markdown(f"Gold Text: **{sample['text']}**")
if not annotation_only: if not annotation_only:
if use_domain_asr:
st.sidebar.markdown(f"Expected Spelled: *{sample['spoken']}*")
st.sidebar.title("Results:") st.sidebar.title("Results:")
st.sidebar.markdown(f"Pretrained: **{sample['pretrained_asr']}**") st.sidebar.markdown(f"Pretrained: **{sample['pretrained_asr']}**")
if "caller" in sample: if "caller" in sample:
st.sidebar.markdown(f"Caller: **{sample['caller']}**") st.sidebar.markdown(f"Caller: **{sample['caller']}**")
if use_domain_asr:
st.sidebar.markdown(f"Domain: **{sample['domain_asr']}**")
st.sidebar.title(f"Speller WER: {sample['domain_wer']:.2f}%")
else: else:
st.sidebar.title(f"Pretrained WER: {sample['pretrained_wer']:.2f}%") st.sidebar.title(f"Pretrained WER: {sample['pretrained_wer']:.2f}%")
if enable_plots: st.sidebar.image(Path(sample["plot_path"]).read_bytes())
st.sidebar.image(Path(sample["plot_path"]).read_bytes())
st.audio(Path(sample["audio_path"]).open("rb")) st.audio(Path(sample["audio_path"]).open("rb"))
# set default to text # set default to text
corrected = sample["text"] corrected = sample["text"]
@@ -128,16 +120,12 @@ def main(manifest: Path):
) )
st.update_cursor(sample_no + 1) st.update_cursor(sample_no + 1)
if correction_entry: if correction_entry:
st.markdown( status = correction_entry["value"]["status"]
f'Your Response: **{correction_entry["value"]["status"]}** Correction: **{correction_entry["value"]["correction"]}**' correction = correction_entry["value"]["correction"]
) st.markdown(f"Your Response: **{status}** Correction: **{correction}**")
text_sample = st.text_input("Go to Text:", value="") text_sample = st.text_input("Go to Text:", value="")
if text_sample != "": if text_sample != "":
candidates = [ candidates = [i for (i, p) in enumerate(asr_data) if p["text"] == text_sample]
i
for (i, p) in enumerate(asr_data)
if p["text"] == text_sample or p["spoken"] == text_sample
]
if len(candidates) > 0: if len(candidates) > 0:
st.update_cursor(candidates[0]) st.update_cursor(candidates[0])
real_idx = st.number_input( real_idx = st.number_input(

View File

@@ -1,58 +0,0 @@
from pathlib import Path
import streamlit as st
import typer
from jasper.data.utils import ExtendedPath
from jasper.data.validation.st_rerun import rerun
app = typer.Typer()
if not hasattr(st, "mongo_connected"):
# st.task_id = str(uuid4())
task_path = ExtendedPath("preview.lck")
def current_cursor_fn():
return task_path.read_json()["current_cursor"]
def update_cursor_fn(val=0):
task_path.write_json({"current_cursor": val})
rerun()
st.get_current_cursor = current_cursor_fn
st.update_cursor = update_cursor_fn
st.mongo_connected = True
# cursor_obj = mongo_conn.find_one({"type": "current_cursor", "task_id": st.task_id})
# if not cursor_obj:
update_cursor_fn(0)
@st.cache()
def load_ui_data(validation_ui_data_path: Path):
typer.echo(f"Using validation ui data from {validation_ui_data_path}")
return list(ExtendedPath(validation_ui_data_path).read_jsonl())
@app.command()
def main(manifest: Path):
asr_data = load_ui_data(manifest)
sample_no = st.get_current_cursor()
if len(asr_data) - 1 < sample_no or sample_no < 0:
print("Invalid samplno resetting to 0")
st.update_cursor(0)
sample = asr_data[sample_no]
st.title(f"ASR Manifest Preview")
st.markdown(f"{sample_no+1} of {len(asr_data)} : **{sample['text']}**")
new_sample = st.number_input(
"Go To Sample:", value=sample_no + 1, min_value=1, max_value=len(asr_data)
)
if new_sample != sample_no + 1:
st.update_cursor(new_sample - 1)
st.sidebar.markdown(f"Gold Text: **{sample['text']}**")
st.audio(Path(sample["audio_filepath"]).open("rb"))
if __name__ == "__main__":
try:
app()
except SystemExit:
pass

View File

@@ -2,7 +2,7 @@ from setuptools import setup, find_packages
requirements = [ requirements = [
"ruamel.yaml", "ruamel.yaml",
"torch==1.4.0", "torch==2.8.0",
"torchvision==0.5.0", "torchvision==0.5.0",
"nemo_toolkit[asr] @ git+https://github.com/NVIDIA/NeMo.git@09e3ba4dfe333f86d6c5c1048e07210924294be9#egg=nemo_toolkit", "nemo_toolkit[asr] @ git+https://github.com/NVIDIA/NeMo.git@09e3ba4dfe333f86d6c5c1048e07210924294be9#egg=nemo_toolkit",
] ]
@@ -19,13 +19,15 @@ extra_requirements = {
"ruamel.yaml==0.16.10", "ruamel.yaml==0.16.10",
"pymongo==3.10.1", "pymongo==3.10.1",
"librosa==0.7.2", "librosa==0.7.2",
"numba==0.48",
"matplotlib==3.2.1", "matplotlib==3.2.1",
"pandas==1.0.3", "pandas==1.0.3",
"tabulate==0.8.7", "tabulate==0.8.7",
"natural==0.2.0", "natural==0.2.0",
"num2words==0.5.10", "num2words==0.5.10",
"typer[all]==0.1.1", "typer[all]==0.3.1",
"python-slugify==4.0.0", "python-slugify==4.0.0",
"rpyc~=4.1.4",
"lenses @ git+https://github.com/ingolemo/python-lenses.git@b2a2a9aa5b61540992d70b2cf36008d0121e8948#egg=lenses", "lenses @ git+https://github.com/ingolemo/python-lenses.git@b2a2a9aa5b61540992d70b2cf36008d0121e8948#egg=lenses",
], ],
"validation": [ "validation": [
@@ -68,10 +70,7 @@ setup(
"jasper_data_tts_generate = jasper.data.tts_generator:main", "jasper_data_tts_generate = jasper.data.tts_generator:main",
"jasper_data_conv_generate = jasper.data.conv_generator:main", "jasper_data_conv_generate = jasper.data.conv_generator:main",
"jasper_data_nlu_generate = jasper.data.nlu_generator:main", "jasper_data_nlu_generate = jasper.data.nlu_generator:main",
"jasper_data_test_generate = jasper.data.test_generator:main", "jasper_data_rastrik_recycle = jasper.data.rastrik_recycler:main",
"jasper_data_call_recycle = jasper.data.call_recycler:main",
"jasper_data_asr_recycle = jasper.data.asr_recycler:main",
"jasper_data_rev_recycle = jasper.data.rev_recycler:main",
"jasper_data_server = jasper.data.server:main", "jasper_data_server = jasper.data.server:main",
"jasper_data_validation = jasper.data.validation.process:main", "jasper_data_validation = jasper.data.validation.process:main",
"jasper_data_preprocess = jasper.data.process:main", "jasper_data_preprocess = jasper.data.process:main",