1
0
mirror of https://github.com/malarinv/jasper-asr.git synced 2026-03-09 10:52:35 +00:00

refactored module structure

This commit is contained in:
2020-05-21 16:47:45 +05:30
parent 2d5b720284
commit fca9c1aeb3
23 changed files with 17 additions and 115 deletions

1
jasper/data/__init__.py Normal file
View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1,93 @@
import typer
from itertools import chain
from io import BytesIO
from pathlib import Path
# Typer application; the function decorated with @app.command() below registers on it.
app = typer.Typer()
@app.command()
def extract_data(
    call_audio_dir: Path = Path("/dataset/png_prod/call_audio"),
    call_meta_dir: Path = Path("/dataset/png_prod/call_metadata"),
    output_dir: Path = Path("./data"),
    dataset_name: str = "png_gcp_2jan",
    verbose: bool = False,
):
    """Build an ASR dataset from recorded calls and their event metadata.

    Every ``*.wav`` found (recursively) under ``call_audio_dir`` is paired
    with the ``.json`` file at the same relative path under
    ``call_meta_dir``.  The recording is split into its two mono channels,
    the ASR events of each channel are turned into ``(transcript, duration,
    wav bytes)`` data points, and all data points are handed to
    ``asr_data_writer`` which stores them under
    ``output_dir / "asr_data" / dataset_name``.

    Args:
        call_audio_dir: Root directory that is searched for ``*.wav`` files.
        call_meta_dir: Root directory holding the matching ``.json`` files.
        output_dir: Directory under which ``asr_data`` is created.
        dataset_name: Name passed to ``asr_data_writer`` for the dataset.
        verbose: Echo a line per loaded file and per processed call.
    """
    from pydub import AudioSegment
    from .utils import ExtendedPath, asr_data_writer
    from lenses import lens

    call_asr_data: Path = output_dir / Path("asr_data")
    call_asr_data.mkdir(exist_ok=True, parents=True)

    def wav_event_generator(call_audio_dir):
        """Yield ``(audio, wav_path, events)`` for every wav file found."""
        for wav_path in call_audio_dir.glob("**/*.wav"):
            if verbose:
                typer.echo(f"loading events for file {wav_path}")
            call_wav = AudioSegment.from_file_using_temporary_files(wav_path)
            # metadata mirrors the audio tree: same relative path, .json suffix
            rel_meta_path = wav_path.with_suffix(".json").relative_to(call_audio_dir)
            meta_path = call_meta_dir / rel_meta_path
            events = ExtendedPath(meta_path).read_json()
            yield call_wav, wav_path, events

    def contains_asr(x):
        return "AsrResult" in x

    def channel(n):
        """Return a predicate selecting ASR events that belong to channel ``n``."""

        def filter_func(ev):
            # events without an explicit "Channel" are attributed to channel 0
            return (
                ev["AsrResult"]["Channel"] == n
                if "Channel" in ev["AsrResult"]
                else n == 0
            )

        return filter_func

    def compute_endtime(call_wav, state):
        """Yield ``(transcript, duration_seconds, wav_bytes)`` per ASR event.

        A segment starts at the event's own ``StartTime`` (0 when absent) and
        ends at the next event's ``StartTime``, or at the end of the audio
        for the last event.  Segments shorter than 0.5 seconds are dropped.
        """
        for (i, st) in enumerate(state):
            alternative = st["AsrResult"]["Alternatives"][0]
            start_time = alternative.get("StartTime", 0)
            transcript = alternative["Transcript"]
            if i + 1 < len(state):
                end_time = state[i + 1]["AsrResult"]["Alternatives"][0]["StartTime"]
            else:
                end_time = call_wav.duration_seconds
            # times are in seconds; they are scaled by 1000 for the slice
            code_seg = call_wav[start_time * 1000 : end_time * 1000]
            # only if some audio data is present yield it; checking first
            # avoids encoding segments that would be thrown away anyway
            if code_seg.duration_seconds < 0.5:
                continue
            code_fb = BytesIO()
            code_seg.export(code_fb, format="wav")
            yield transcript, code_seg.duration_seconds, code_fb.getvalue()

    def asr_data_generator(call_wav, call_wav_fname, events):
        """Chain the data points of both mono channels of one call."""
        call_wav_0, call_wav_1 = call_wav.split_to_mono()
        asr_events = lens["Events"].Each()["Event"].Filter(contains_asr)
        call_evs_0 = asr_events.Filter(channel(0)).collect()(events)
        call_evs_1 = asr_events.Filter(channel(1)).collect()(events)
        if verbose:
            typer.echo(f"processing data points on {call_wav_fname}")
        call_data_0 = compute_endtime(call_wav_0, call_evs_0)
        call_data_1 = compute_endtime(call_wav_1, call_evs_1)
        return chain(call_data_0, call_data_1)

    def generate_call_asr_data():
        """Load every call, report the totals, then write all data points."""
        full_asr_data = []
        total_duration = 0
        for wav, wav_path, ev in wav_event_generator(call_audio_dir):
            asr_data = asr_data_generator(wav, wav_path, ev)
            total_duration += wav.duration_seconds
            full_asr_data.append(asr_data)
        typer.echo(f"loaded {len(full_asr_data)} calls of duration {total_duration}s")
        n_dps = asr_data_writer(call_asr_data, dataset_name, chain(*full_asr_data))
        typer.echo(f"written {n_dps} data points")

    generate_call_asr_data()
def main():
    """Run the Typer application defined in this module."""
    app()


# Allow the module to be executed directly as a script.
if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,410 @@
# import argparse
# import logging
import typer
from pathlib import Path
# Typer application; the functions decorated with @app.command() below register on it.
app = typer.Typer()
# leader_app = typer.Typer()
# app.add_typer(leader_app, name="leaderboard")
# plot_app = typer.Typer()
# app.add_typer(plot_app, name="plot")
@app.command()
def export_all_logs(call_logs_file: Path = Path("./call_sia_logs.yaml")):
    """Write every call of the Mongo ``test.calls`` collection to a YAML file.

    The calls are grouped by their ``Caller`` field; the resulting document
    has the shape ``{"users": [{"name": <caller>, "calls": [<uri>, ...]}]}``.
    """
    from .utils import get_mongo_conn
    from collections import defaultdict
    from ruamel.yaml import YAML

    yaml_writer = YAML()
    uris_by_caller = defaultdict(list)
    for record in get_mongo_conn().test.calls.find():
        uri = f"http://sia-data.agaralabs.com/calls/{record['SystemID']}"
        uris_by_caller[record["Caller"]].append(uri)
    users = [{"name": name, "calls": uris} for name, uris in uris_by_caller.items()]
    typer.echo("exporting call logs to yaml file")
    with call_logs_file.open("w") as yf:
        yaml_writer.dump({"users": users}, yf)
@app.command()
def export_calls_between(
    start_cid: str,
    end_cid: str,
    call_logs_file: Path = Path("./call_sia_logs.yaml"),
    mongo_port: int = 27017,
):
    """Export the calls that lie between two reference calls to a YAML file.

    The calls whose ``SystemID`` equals ``start_cid`` and ``end_cid`` are
    looked up first.  Every call with ``StartTS`` at or after the start
    call's ``StartTS`` and ``EndTS`` at or before the end call's ``EndTS``
    is then exported, grouped by ``Caller``, in the shape
    ``{"users": [{"name": <caller>, "calls": [<uri>, ...]}]}``.

    Args:
        start_cid: ``SystemID`` of the call that opens the range.
        end_cid: ``SystemID`` of the call that closes the range.
        call_logs_file: Destination YAML file.
        mongo_port: Port of the MongoDB server.

    Raises:
        typer.Exit: With code 1 when either reference call does not exist.
    """
    from collections import defaultdict
    from ruamel.yaml import YAML
    from .utils import get_mongo_conn

    yaml = YAML()
    mongo_coll = get_mongo_conn(port=mongo_port).test.calls
    start_meta = mongo_coll.find_one({"SystemID": start_cid})
    end_meta = mongo_coll.find_one({"SystemID": end_cid})
    # find_one returns None for an unknown id; report that instead of failing
    # later with an opaque "NoneType is not subscriptable"
    for cid, meta in ((start_cid, start_meta), (end_cid, end_meta)):
        if meta is None:
            typer.echo(f"no call found with SystemID {cid}", err=True)
            raise typer.Exit(code=1)
    caller_calls = defaultdict(list)
    call_query = mongo_coll.find(
        {
            "StartTS": {"$gte": start_meta["StartTS"]},
            "EndTS": {"$lte": end_meta["EndTS"]},
        }
    )
    for call in call_query:
        sysid = call["SystemID"]
        call_uri = f"http://sia-data.agaralabs.com/calls/{sysid}"
        caller_calls[call["Caller"]].append(call_uri)
    caller_list = [
        {"name": caller, "calls": calls} for caller, calls in caller_calls.items()
    ]
    output_yaml = {"users": caller_list}
    typer.echo("exporting call logs to yaml file")
    with call_logs_file.open("w") as yf:
        yaml.dump(output_yaml, yf)
@app.command()
def analyze(
    leaderboard: bool = False,
    plot_calls: bool = False,
    extract_data: bool = False,
    download_only: bool = False,
    call_logs_file: Path = Path("./call_logs.yaml"),
    output_dir: Path = Path("./data"),
    mongo_port: int = 27017,
):
    """Download, inspect and export the calls listed in ``call_logs_file``.

    The YAML call log must have the shape produced by the export commands of
    this module: ``{"users": [{"name": ..., "calls": [<call uri>, ...]}]}``.
    For every call URI the matching document is looked up in the Mongo
    ``test.calls`` collection; its event metadata and recording are
    downloaded from S3 into ``output_dir`` unless the files already exist.

    Args:
        leaderboard: Print a per-user table of code counts and durations.
        plot_calls: Save one waveform plot per call with its data points.
        extract_data: Cut the data points out of every recording and write
            them through ``asr_data_writer`` as dataset ``call_alphanum``.
        download_only: Only fetch metadata and audio, then return; the three
            flags above are ignored in that case.
        call_logs_file: YAML file listing users and their call URIs.
        output_dir: Directory receiving ``call_wavs``, ``call_metas``,
            ``plots`` and ``asr_data``.
        mongo_port: Port of the MongoDB server.
    """
    from urllib.parse import urlsplit
    from functools import reduce
    import boto3
    from io import BytesIO
    import json
    from ruamel.yaml import YAML
    import re
    from google.protobuf.timestamp_pb2 import Timestamp
    from datetime import timedelta
    import librosa
    import librosa.display
    from lenses import lens
    from pprint import pprint
    import pandas as pd
    import matplotlib.pyplot as plt
    import matplotlib
    from tqdm import tqdm
    from .utils import asr_data_writer, get_mongo_conn
    from pydub import AudioSegment
    from natural.date import compress

    matplotlib.rcParams["agg.path.chunksize"] = 10000
    matplotlib.use("agg")

    yaml = YAML()
    s3 = boto3.client("s3")
    mongo_collection = get_mongo_conn(port=mongo_port).test.calls
    call_media_dir: Path = output_dir / Path("call_wavs")
    call_media_dir.mkdir(exist_ok=True, parents=True)
    call_meta_dir: Path = output_dir / Path("call_metas")
    call_meta_dir.mkdir(exist_ok=True, parents=True)
    call_plot_dir: Path = output_dir / Path("plots")
    call_plot_dir.mkdir(exist_ok=True, parents=True)
    call_asr_data: Path = output_dir / Path("asr_data")
    call_asr_data.mkdir(exist_ok=True, parents=True)

    call_logs = yaml.load(call_logs_file.read_text())

    def get_call_meta(call_obj):
        """Return the parsed event metadata of a call, downloading it once."""
        meta_s3_uri = call_obj["DataURI"]
        s3_event_url_p = urlsplit(meta_s3_uri)
        saved_meta_path = call_meta_dir / Path(Path(s3_event_url_p.path).name)
        if not saved_meta_path.exists():
            print(f"downloading : {saved_meta_path} from {meta_s3_uri}")
            # netloc is used as the bucket, the path without its leading
            # slash as the object key
            s3.download_file(
                s3_event_url_p.netloc, s3_event_url_p.path[1:], str(saved_meta_path)
            )
        # close the handle deterministically instead of leaking it
        with saved_meta_path.open() as meta_file:
            return json.load(meta_file)

    def gen_ev_fev_timedelta(fev):
        """Return a function giving an event's offset from event ``fev``.

        Offsets are computed from the ``CreatedTS`` fields and clamped so
        that they are never negative.
        """
        fev_p = Timestamp()
        fev_p.FromJsonString(fev["CreatedTS"])
        fev_dt = fev_p.ToDatetime()
        td_0 = timedelta()

        def get_timedelta(ev):
            ev_p = Timestamp()
            ev_p.FromJsonString(value=ev["CreatedTS"])
            ev_dt = ev_p.ToDatetime()
            delta = ev_dt - fev_dt
            return delta if delta > td_0 else td_0

        return get_timedelta

    def chunk_n(evs, n):
        """Split ``evs`` into consecutive chunks of at most ``n`` items."""
        return [evs[i * n : (i + 1) * n] for i in range((len(evs) + n - 1) // n)]

    def get_data_points(utter_events, td_fn):
        """Turn event triples into ``start_time`` / ``end_time`` / ``code`` dicts."""
        data_points = []
        for evs in chunk_n(utter_events, 3):
            assert evs[0]["Type"] == "CONV_RESULT"
            assert evs[1]["Type"] == "STARTED_SPEAKING"
            assert evs[2]["Type"] == "STOPPED_SPEAKING"
            # the start is moved 1.5 seconds ahead of the STARTED_SPEAKING event
            start_time = td_fn(evs[1]).total_seconds() - 1.5
            end_time = td_fn(evs[2]).total_seconds()
            code = evs[0]["Msg"]
            data_points.append(
                {"start_time": start_time, "end_time": end_time, "code": code}
            )
        return data_points

    def process_call(call_obj):
        """Fetch metadata and audio of one call and collect its utterance events."""
        call_meta = get_call_meta(call_obj)
        call_events = call_meta["Events"]

        def is_writer_uri_event(ev):
            return ev["Author"] == "AUDIO_WRITER" and 's3://' in ev["Msg"]

        writer_events = list(filter(is_writer_uri_event, call_events))
        s3_wav_url = re.search(r"(s3://.*)", writer_events[0]["Msg"]).groups(0)[0]
        s3_wav_url_p = urlsplit(s3_wav_url)

        def is_first_audio_ev(state, ev):
            # state is (found, event); once found it is carried through unchanged
            if state[0]:
                return state
            else:
                return (ev["Author"] == "GATEWAY" and ev["Type"] == "AUDIO", ev)

        (_, first_audio_ev) = reduce(is_first_audio_ev, call_events, (False, {}))

        get_ev_fev_timedelta = gen_ev_fev_timedelta(first_audio_ev)

        def is_utter_event(ev):
            return (
                (ev["Author"] == "CONV" or ev["Author"] == "ASR")
                and (ev["Type"] != "DEBUG")
                and ev["Type"] != "ASR_RESULT"
            )

        uevs = list(filter(is_utter_event, call_events))
        ev_count = len(uevs)
        # keep only whole triples; a trailing partial triple is dropped
        utter_events = uevs[: ev_count - ev_count % 3]
        saved_wav_path = call_media_dir / Path(Path(s3_wav_url_p.path).name)
        if not saved_wav_path.exists():
            print(f"downloading : {saved_wav_path} from {s3_wav_url}")
            s3.download_file(
                s3_wav_url_p.netloc, s3_wav_url_p.path[1:], str(saved_wav_path)
            )

        return {
            "wav_path": saved_wav_path,
            "num_samples": len(utter_events) // 3,
            "meta": call_obj,
            "first_event_fn": get_ev_fev_timedelta,
            "utter_events": utter_events,
        }

    def get_cid(uri):
        """Return the call id, i.e. the last path component of a call URI."""
        return Path(urlsplit(uri).path).stem

    def ensure_call(uri):
        """Make sure metadata and audio of the call behind ``uri`` are on disk."""
        cid = get_cid(uri)
        meta = mongo_collection.find_one({"SystemID": cid})
        process_meta = process_call(meta)
        return process_meta

    def retrieve_processed_callmeta(uri):
        """Return url, Mongo document, duration and processed data of a call."""
        cid = get_cid(uri)
        meta = mongo_collection.find_one({"SystemID": cid})
        duration = meta["EndTS"] - meta["StartTS"]
        process_meta = process_call(meta)
        data_points = get_data_points(
            process_meta['utter_events'], process_meta['first_event_fn']
        )
        process_meta['data_points'] = data_points
        return {"url": uri, "meta": meta, "duration": duration, "process": process_meta}

    def download_meta_audio():
        call_lens = lens["users"].Each()["calls"].Each()
        call_lens.modify(ensure_call)(call_logs)

    def plot_calls_data():
        """Save one PNG per call showing the waveform and its data points."""

        def plot_data_points(y, sr, data_points, file_path):
            fig = plt.figure(figsize=(16, 12))
            try:
                librosa.display.waveplot(y=y, sr=sr)
                for dp in data_points:
                    start, end, code = dp["start_time"], dp["end_time"], dp["code"]
                    plt.axvspan(start, end, color="green", alpha=0.2)
                    text_pos = (start + end) / 2
                    plt.text(
                        text_pos,
                        0.25,
                        f"{code}",
                        rotation=90,
                        horizontalalignment="center",
                        verticalalignment="center",
                    )
                plt.title("Datapoints")
                plt.savefig(file_path, format="png")
            finally:
                # pyplot keeps a figure alive until it is closed explicitly;
                # without this every plotted call stays in memory
                plt.close(fig)
            return file_path

        def plot_call(call_obj):
            saved_wav_path, data_points, sys_id = (
                call_obj["process"]["wav_path"],
                call_obj["process"]["data_points"],
                call_obj["meta"]["SystemID"],
            )
            file_path = call_plot_dir / Path(sys_id).with_suffix(".png")
            # existing plots are not regenerated
            if not file_path.exists():
                print(f"plotting: {file_path}")
                (y, sr) = librosa.load(saved_wav_path)
                plot_data_points(y, sr, data_points, str(file_path))
            return file_path

        call_lens = lens["users"].Each()["calls"].Each()
        call_stats = call_lens.modify(retrieve_processed_callmeta)(call_logs)
        call_plots = call_lens.modify(plot_call)(call_stats)
        pprint(call_plots)

    def extract_data_points():
        """Write every data point of every call as dataset ``call_alphanum``."""

        def gen_data_values(saved_wav_path, data_points):
            call_seg = (
                AudioSegment.from_wav(saved_wav_path)
                .set_channels(1)
                .set_sample_width(2)
                .set_frame_rate(24000)
            )
            for dp in data_points:
                start, end, code = dp["start_time"], dp["end_time"], dp["code"]
                code_seg = call_seg[start * 1000 : end * 1000]
                code_fb = BytesIO()
                code_seg.export(code_fb, format="wav")
                code_wav = code_fb.getvalue()
                # search for actual pnr code and handle plain codes as well:
                # messages longer than 6 characters carry the code in quotes
                extracted_code = (
                    re.search(r"'(.*)'", code).groups(0)[0] if len(code) > 6 else code
                )
                yield extracted_code, code_seg.duration_seconds, code_wav

        call_lens = lens["users"].Each()["calls"].Each()
        call_stats = call_lens.modify(retrieve_processed_callmeta)(call_logs)
        call_objs = call_lens.collect()(call_stats)

        def data_source():
            for call_obj in tqdm(call_objs):
                saved_wav_path = call_obj["process"]["wav_path"]
                data_points = call_obj["process"]["data_points"]
                for dp in gen_data_values(saved_wav_path, data_points):
                    yield dp

        asr_data_writer(call_asr_data, "call_alphanum", data_source())

    def show_leaderboard():
        """Print the users ranked by the total duration of their calls."""

        def compute_user_stats(call_stat):
            n_samples = (
                lens["calls"].Each()["process"]["num_samples"].get_monoid()(call_stat)
            )
            n_duration = lens["calls"].Each()["duration"].get_monoid()(call_stat)
            return {
                "num_samples": n_samples,
                "duration": n_duration.total_seconds(),
                "samples_rate": n_samples / n_duration.total_seconds(),
                "duration_str": compress(n_duration, pad=" "),
                "name": call_stat["name"],
            }

        call_lens = lens["users"].Each()["calls"].Each()
        call_stats = call_lens.modify(retrieve_processed_callmeta)(call_logs)
        user_stats = lens["users"].Each().modify(compute_user_stats)(call_stats)
        leader_df = (
            pd.DataFrame(user_stats["users"])
            .sort_values(by=["duration"], ascending=False)
            .reset_index(drop=True)
        )
        leader_df["rank"] = leader_df.index + 1
        leader_board = leader_df.rename(
            columns={
                "rank": "Rank",
                "num_samples": "Codes",
                "name": "Name",
                "samples_rate": "SpeechRate",
                "duration_str": "Duration",
            }
        )[["Rank", "Name", "Codes", "Duration"]]
        print(
            """ASR Speller Dataset Leaderboard :
---------------------------------"""
        )
        print(leader_board.to_string(index=False))

    if download_only:
        download_meta_audio()
        return
    if leaderboard:
        show_leaderboard()
    if plot_calls:
        plot_calls_data()
    if extract_data:
        extract_data_points()
def main():
    """Run the Typer application defined in this module."""
    app()


# Allow the module to be executed directly as a script.
if __name__ == "__main__":
    main()

69
jasper/data/process.py Normal file
View File

@@ -0,0 +1,69 @@
import json
from pathlib import Path
from sklearn.model_selection import train_test_split
from .utils import asr_manifest_reader, asr_manifest_writer
from typing import List
from itertools import chain
import typer
# Typer application; the functions decorated with @app.command() below register on it.
app = typer.Typer()
@app.command()
def fixate_data(dataset_path: Path):
    """Write ``abs_manifest.json`` with audio paths prefixed by ``dataset_path``.

    Each entry of ``dataset_path / "manifest.json"`` is copied to
    ``dataset_path / "abs_manifest.json"`` with ``dataset_path`` joined in
    front of its ``audio_filepath``.
    """
    source_manifest = dataset_path / Path("manifest.json")
    target_manifest = dataset_path / Path("abs_manifest.json")

    def with_dataset_prefix(entry):
        entry["audio_filepath"] = str(dataset_path / Path(entry["audio_filepath"]))
        return entry

    asr_manifest_writer(
        target_manifest, map(with_dataset_prefix, asr_manifest_reader(source_manifest))
    )
@app.command()
def augment_datasets(src_dataset_paths: List[Path], dest_dataset_path: Path):
    """Concatenate the ``abs_manifest.json`` of several datasets into one.

    The entries of every source manifest are written, in the order the
    sources are given, to ``dest_dataset_path / "abs_manifest.json"``; the
    destination directory is created when missing.
    """
    manifest_name = Path("abs_manifest.json")
    readers = [asr_manifest_reader(src / manifest_name) for src in src_dataset_paths]
    dest_dataset_path.mkdir(parents=True, exist_ok=True)
    asr_manifest_writer(
        dest_dataset_path / manifest_name, chain.from_iterable(readers)
    )
@app.command()
def split_data(dataset_path: Path, test_size: float = 0.1):
    """Split ``abs_manifest.json`` into train and test manifests.

    The split is done with ``train_test_split``; the two parts are written
    next to the source manifest as ``train_manifest.json`` and
    ``test_manifest.json``.
    """
    abs_manifest = dataset_path / Path("abs_manifest.json")
    entries = list(asr_manifest_reader(abs_manifest))
    train_entries, test_entries = train_test_split(entries, test_size=test_size)
    outputs = (
        ("train_manifest.json", train_entries),
        ("test_manifest.json", test_entries),
    )
    for manifest_name, manifest_entries in outputs:
        asr_manifest_writer(abs_manifest.with_name(manifest_name), manifest_entries)
@app.command()
def validate_data(dataset_path: Path):
    """Check the train and test manifests of a dataset line by line.

    For ``train_manifest.json`` and ``test_manifest.json`` under
    ``dataset_path`` every line must be valid JSON with an
    ``audio_filepath`` that exists (resolved against the manifest's
    directory).  Each failing line is reported with its index, followed by a
    per-manifest summary.
    """
    for mf_type in ["train_manifest.json", "test_manifest.json"]:
        data_file = dataset_path / Path(mf_type)
        print(f"validating {data_file}.")
        with Path(data_file).open("r") as pf:
            pnr_jsonl = pf.readlines()
        failures = 0
        for (i, s) in enumerate(pnr_jsonl):
            try:
                d = json.loads(s)
                audio_file = data_file.parent / Path(d["audio_filepath"])
                if not audio_file.exists():
                    raise OSError(f"File {audio_file} not found")
            # malformed JSON (ValueError), missing key / wrong shape and
            # missing audio are data errors; anything else should propagate
            except (ValueError, KeyError, TypeError, OSError) as e:
                failures += 1
                print(f'failed on {i} with "{e}"')
        # only claim validity when no line failed
        if failures:
            print(f"found {failures} error(s) in {mf_type}.")
        else:
            print(f"no errors found. seems like a valid {mf_type}.")
def main():
    """Run the Typer application defined in this module."""
    app()


# Allow the module to be executed directly as a script.
if __name__ == "__main__":
    main()

57
jasper/data/server.py Normal file
View File

@@ -0,0 +1,57 @@
import os
from pathlib import Path
import typer
import rpyc
from rpyc.utils.server import ThreadedServer
import nemo
import pickle
# import nemo.collections.asr as nemo_asr
from nemo.collections.asr.parts.segment import AudioSegment
# Typer application; run_server below registers on it.
app = typer.Typer()
# Executed at import time: builds a NeMo module factory with the PyTorch
# backend and CPU placement.  The returned object is discarded.
# NOTE(review): presumably called for its global side effects -- confirm
# against the NeMo documentation.
nemo.core.NeuralModuleFactory(
    backend=nemo.core.Backend.PyTorch, placement=nemo.core.DeviceType.CPU
)
class ASRDataService(rpyc.Service):
    """rpyc service exposing audio samples and raw file contents by path."""

    def exposed_get_path_samples(
        self, file_path, target_sr, int_values, offset, duration, trim
    ):
        """Load ``file_path`` as an ``AudioSegment`` and return its samples pickled.

        All remaining arguments are forwarded unchanged to
        ``AudioSegment.from_file``.
        """
        print(f"loading.. {file_path}")
        load_options = {
            "target_sr": target_sr,
            "int_values": int_values,
            "offset": offset,
            "duration": duration,
            "trim": trim,
        }
        segment = AudioSegment.from_file(file_path, **load_options)
        return pickle.dumps(segment.samples)

    def exposed_read_path(self, file_path):
        """Return the complete contents of ``file_path`` as bytes."""
        requested = Path(file_path)
        return requested.read_bytes()
@app.command()
def run_server(port: int = 0):
    """Start a threaded rpyc server for ``ASRDataService``.

    A non-zero ``port`` is used as given; otherwise the port is read from
    the ``ASR_DARA_RPYC_PORT`` environment variable, defaulting to 8064.
    """
    if port:
        listen_port = port
    else:
        listen_port = int(os.environ.get("ASR_DARA_RPYC_PORT", "8064"))
    server = ThreadedServer(
        ASRDataService(),
        port=listen_port,
        protocol_config={"allow_all_attrs": True},
    )
    typer.echo(f"starting asr server on {listen_port}...")
    server.start()
def main():
    """Run the Typer application defined in this module."""
    app()


# Allow the module to be executed directly as a script.
if __name__ == "__main__":
    main()

View File

View File

@@ -0,0 +1,52 @@
from logging import getLogger
from google.cloud import texttospeech
# Module-level logger registered under the name "googletts".
LOGGER = getLogger("googletts")
class GoogleTTS(object):
    """Small wrapper around the Google Cloud Text-to-Speech client."""

    def __init__(self):
        self.client = texttospeech.TextToSpeechClient()

    def text_to_speech(self, text: str, params: dict) -> bytes:
        """Synthesize SSML ``text`` and return the audio content.

        ``params`` must provide ``language``, ``name`` and ``sample_rate``;
        the audio is requested with LINEAR16 encoding.
        """
        synthesis_input = texttospeech.types.SynthesisInput(ssml=text)
        voice_selection = texttospeech.types.VoiceSelectionParams(
            language_code=params["language"], name=params["name"]
        )
        audio_settings = texttospeech.types.AudioConfig(
            audio_encoding=texttospeech.enums.AudioEncoding.LINEAR16,
            sample_rate_hertz=params["sample_rate"],
        )
        return self.client.synthesize_speech(
            synthesis_input, voice_selection, audio_settings
        ).audio_content

    @classmethod
    def voice_list(cls):
        """Return the available voices that support an ``en*`` language code.

        Each voice is described by a dict with the keys ``name``,
        ``language`` (comma-joined matching codes), ``gender``, ``engine``
        and ``sample_rate``.
        """
        english_voices = []
        # a throw-away instance is created only to obtain a client
        for voice in cls().client.list_voices().voices:
            english_codes = [
                code for code in voice.language_codes if code[:2] == "en"
            ]
            if not english_codes:
                continue
            gender = texttospeech.enums.SsmlVoiceGender(voice.ssml_gender)
            english_voices.append(
                {
                    "name": voice.name,
                    "language": ",".join(english_codes),
                    "gender": gender.name,
                    "engine": "wavenet" if "Wav" in voice.name else "standard",
                    "sample_rate": voice.natural_sample_rate_hertz,
                }
            )
        return english_voices

View File

@@ -0,0 +1,26 @@
"""
TTSClient Abstract Class
"""
from abc import ABC, abstractmethod
class TTSClient(ABC):
    """Abstract interface that every text-to-speech client implements."""

    @abstractmethod
    def text_to_speech(
        self, text: str, num_channels: int, sample_rate: int, audio_encoding
    ) -> bytes:
        """Convert text to audio bytes.

        Arguments:
            text -- text to convert
            num_channels -- channel count of the output audio
            sample_rate -- sample rate of the output audio
            audio_encoding -- encoding of the output audio

        Returns:
            bytes -- the synthesized audio
        """

View File

@@ -0,0 +1,62 @@
# import io
# import sys
# import json
import argparse
import logging
from pathlib import Path
from .utils import random_pnr_generator, asr_data_writer
from .tts.googletts import GoogleTTS
from tqdm import tqdm
import random
# Configure root logging at import time: INFO level, with timestamp, logger
# name and level in front of every message.
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
# Logger named after this module.
logger = logging.getLogger(__name__)
def pnr_tts_streamer(count):
    """Yield ``(pnr_code, duration_seconds, wav_data)`` for random codes.

    ``count`` random codes are generated; each is spoken verbatim through
    Google TTS using a randomly chosen voice at 24000 Hz.
    """
    voices = GoogleTTS.voice_list()
    synthesizer = GoogleTTS()
    for pnr_code in tqdm(random_pnr_generator(count)):
        ssml = f'<speak><say-as interpret-as="verbatim">{pnr_code}</say-as></speak>'
        voice_params = random.choice(voices)
        # the chosen voice dict itself is updated, not a copy
        voice_params.update(sample_rate=24000, num_channels=1)
        audio = synthesizer.text_to_speech(text=ssml, params=voice_params)
        # the first 44 bytes are skipped (presumably the WAV header -- TODO
        # confirm); the rest is divided by 2 bytes per sample * 24000 Hz
        payload_size = len(audio[44:])
        yield pnr_code, payload_size / (2 * 24000), audio
def generate_asr_data_fromtts(output_dir, dataset_name, count):
    """Write ``count`` TTS-generated data points as dataset ``dataset_name`` under ``output_dir``."""
    asr_data_writer(output_dir, dataset_name, pnr_tts_streamer(count))
def arg_parser():
    """Build the command-line parser of the TTS data generator.

    The parser accepts ``--output_dir``, ``--count`` and ``--dataset_name``
    and is named after this file.
    """
    parser = argparse.ArgumentParser(
        prog=Path(__file__).stem, description="generates asr training data"
    )
    options = (
        (
            "--output_dir",
            {
                "type": Path,
                "default": Path("./train/asr_data"),
                "help": "directory to output asr data",
            },
        ),
        (
            "--count",
            {"type": int, "default": 3, "help": "number of datapoints to generate"},
        ),
        (
            "--dataset_name",
            {"type": str, "default": "pnr_data", "help": "name of the dataset"},
        ),
    )
    for flag, spec in options:
        parser.add_argument(flag, **spec)
    return parser
def main():
    """Parse the command line and generate the TTS dataset."""
    options = vars(arg_parser().parse_args())
    generate_asr_data_fromtts(**options)


if __name__ == "__main__":
    main()

119
jasper/data/utils.py Normal file
View File

@@ -0,0 +1,119 @@
import numpy as np
import wave
import io
import os
import json
from pathlib import Path
import pymongo
from slugify import slugify
from uuid import uuid4
from num2words import num2words
def manifest_str(path, dur, text):
    """Return one manifest record as a newline-terminated JSON line.

    The duration is rounded to one decimal place.
    """
    record = {"audio_filepath": path, "duration": round(dur, 1), "text": text}
    return f"{json.dumps(record)}\n"
def wav_bytes(audio_bytes, frame_rate=24000):
    """Wrap raw audio frames in a WAV container and return it as bytes.

    The container is written as mono with a sample width of 2 bytes at
    ``frame_rate``.
    """
    buffer = io.BytesIO()
    with wave.open(buffer, mode="w") as writer:
        # (nchannels, sampwidth, framerate, nframes, comptype, compname)
        writer.setparams((1, 2, frame_rate, 0, "NONE", "not compressed"))
        writer.writeframesraw(audio_bytes)
    # read the buffer only after the writer is closed, so the header sizes
    # have been patched
    return buffer.getvalue()
def random_pnr_generator(count=10000):
    """Return ``count`` random 6-character codes of 3 letters and 3 digits.

    Letters are uppercase A-Z.  The character positions are shuffled once,
    so every code of one call shares the same letter/digit layout.
    """
    LENGTH = 3
    letters = np.array(list("ABCDEFGHIJKLMNOPQRSTUVWXYZ"), dtype="|S1")
    digits = np.array(list("0123456789"), dtype="|S1")
    letter_block = np.random.choice(letters, [count, LENGTH])
    digit_block = np.random.choice(digits, [count, LENGTH])
    # transpose so that rows are character positions, then shuffle the rows
    by_position = np.concatenate((letter_block, digit_block), axis=1).T
    np.random.shuffle(by_position)
    return [b"".join(code).decode("utf-8") for code in by_position.T]
def alnum_to_asr_tokens(text):
    """Spell ``text`` character by character as lowercase, space-separated tokens.

    Characters between ``"0"`` and ``"9"`` are replaced by the output of
    ``num2words``; every other character is kept as is.
    """
    spoken = (num2words(ch) if "0" <= ch <= "9" else ch for ch in text)
    return " ".join(spoken).lower()
def asr_data_writer(output_dir, dataset_name, asr_data_source, verbose=False):
    """Write data points as wav files plus a ``manifest.json`` and count them.

    ``asr_data_source`` yields ``(transcript, audio_dur, wav_data)`` tuples.
    Each item is stored under ``output_dir / dataset_name / "wav"`` with a
    name made of a fresh UUID and a slug of the transcript, and gets one
    manifest line whose audio path is relative to the dataset directory.

    Returns:
        The number of data points written.
    """
    dataset_dir = output_dir / Path(dataset_name)
    wav_dir = dataset_dir / Path("wav")
    wav_dir.mkdir(parents=True, exist_ok=True)
    written = 0
    with (dataset_dir / Path("manifest.json")).open("w") as mf:
        for written, item in enumerate(asr_data_source, start=1):
            transcript, audio_dur, wav_data = item
            stem = str(uuid4()) + "_" + slugify(transcript, max_length=8)
            wav_path = wav_dir / Path(stem).with_suffix(".wav")
            wav_path.write_bytes(wav_data)
            rel_wav_path = wav_path.relative_to(dataset_dir)
            mf.write(manifest_str(str(rel_wav_path), audio_dur, transcript))
            if verbose:
                print(f"writing '{transcript}' of duration {audio_dur}")
    return written
def asr_manifest_reader(data_manifest_path: Path):
    """Yield the entries of a JSON-lines manifest, one dict per line.

    Every entry gains ``audio_path`` (its ``audio_filepath`` joined to the
    manifest's directory) and ``chars`` (the stem of ``audio_filepath``).
    All lines are read and parsed before the first entry is yielded.
    """
    print(f"reading manifest from {data_manifest_path}")
    with data_manifest_path.open("r") as pf:
        raw_lines = pf.readlines()
    manifest_dir = data_manifest_path.parent
    for entry in list(map(json.loads, raw_lines)):
        rel_audio = Path(entry["audio_filepath"])
        entry["audio_path"] = manifest_dir / rel_audio
        entry["chars"] = rel_audio.stem
        yield entry
def asr_manifest_writer(asr_manifest_path: Path, manifest_str_source):
    """Write manifest dicts to ``asr_manifest_path`` as JSON lines.

    Only the ``audio_filepath``, ``duration`` and ``text`` keys of each
    entry are written.
    """
    with asr_manifest_path.open("w") as mf:
        print(f"opening {asr_manifest_path} for writing manifest")
        mf.writelines(
            manifest_str(entry["audio_filepath"], entry["duration"], entry["text"])
            for entry in manifest_str_source
        )
class ExtendedPath(type(Path())):
    """Concrete path class of this platform with JSON read/write helpers."""

    def read_json(self):
        """Parse the file at this path as JSON and return the result."""
        return json.loads(self.read_text())

    def write_json(self, data):
        """Dump ``data`` as JSON (indent 2), creating parent directories first."""
        self.parent.mkdir(parents=True, exist_ok=True)
        with self.open("w") as json_file:
            result = json.dump(data, json_file, indent=2)
        return result
def get_mongo_conn(host='', port=27017):
    """Return a ``pymongo.MongoClient`` for ``host``:``port``.

    An empty ``host`` falls back to the ``MONGO_HOST`` environment variable
    and then to ``localhost``.
    """
    if not host:
        host = os.environ.get("MONGO_HOST", "localhost")
    return pymongo.MongoClient(f"mongodb://{host}:{port}/")
def main():
    """Print a batch of random codes, one per line."""
    print("\n".join(random_pnr_generator()))


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,221 @@
import json
import shutil
from pathlib import Path
import typer
from tqdm import tqdm
from ..utils import (
alnum_to_asr_tokens,
ExtendedPath,
asr_manifest_reader,
asr_manifest_writer,
get_mongo_conn,
)
# Typer application; the functions decorated with @app.command() below register on it.
app = typer.Typer()
def preprocess_datapoint(idx, rel_root, sample, use_domain_asr):
    """Enrich one manifest sample with transcriptions, WERs and a waveform plot.

    Args:
        idx: Position of the sample in the manifest; stored as ``real_idx``.
        rel_root: Directory against which ``sample["audio_filepath"]`` is
            resolved and under which ``wav_plots`` lives.
        sample: Manifest entry with at least ``audio_filepath`` and ``text``.
        use_domain_asr: Also transcribe with ``transcriber_speller`` and
            score that transcription against the spelled-out text.

    Returns:
        A copy of ``sample`` with the added fields, or ``None`` when
        processing failed (the failure is printed).
    """
    import matplotlib.pyplot as plt
    import librosa
    import librosa.display
    from pydub import AudioSegment
    from nemo.collections.asr.metrics import word_error_rate
    from jasper.client import (
        transcriber_pretrained,
        transcriber_speller,
    )

    try:
        res = dict(sample)
        res["real_idx"] = idx
        audio_path = rel_root / Path(sample["audio_filepath"])
        res["audio_path"] = str(audio_path)
        res["spoken"] = alnum_to_asr_tokens(res["text"])
        res["utterance_id"] = audio_path.stem
        aud_seg = (
            AudioSegment.from_file_using_temporary_files(audio_path)
            .set_channels(1)
            .set_sample_width(2)
            .set_frame_rate(24000)
        )
        res["pretrained_asr"] = transcriber_pretrained(aud_seg.raw_data)
        # NOTE(review): the ground truth is passed first here; confirm the
        # argument order expected by word_error_rate.
        res["pretrained_wer"] = word_error_rate([res["text"]], [res["pretrained_asr"]])
        if use_domain_asr:
            res["domain_asr"] = transcriber_speller(aud_seg.raw_data)
            # score the domain model's own output, not the pretrained one
            res["domain_wer"] = word_error_rate(
                [res["spoken"]], [res["domain_asr"]]
            )
        wav_plot_path = (
            rel_root / Path("wav_plots") / Path(audio_path.name).with_suffix(".png")
        )
        # an existing plot is reused
        if not wav_plot_path.exists():
            fig = plt.Figure()
            ax = fig.add_subplot()
            (y, sr) = librosa.load(audio_path)
            librosa.display.waveplot(y=y, sr=sr, ax=ax)
            with wav_plot_path.open("wb") as wav_plot_f:
                fig.set_tight_layout(True)
                fig.savefig(wav_plot_f, format="png", dpi=50)
        res["plot_path"] = str(wav_plot_path)
        return res
    # best effort per sample: report and skip, but let interpreter-exit
    # signals (KeyboardInterrupt, SystemExit) propagate
    except Exception as e:
        print(f'failed on {idx}: {sample["audio_filepath"]} with {e}')
        return None
@app.command()
def dump_validation_ui_data(
    data_manifest_path: Path = Path("./data/asr_data/call_alphanum/manifest.json"),
    dump_path: Path = Path("./data/valiation_data/ui_dump.json"),
    use_domain_asr: bool = True,
):
    """Preprocess every manifest sample and dump the result for the validation UI.

    Samples are processed by ``preprocess_datapoint`` on a pool of 20
    threads; failed samples are dropped.  The remaining ones are sorted by
    descending WER and written to ``dump_path`` as
    ``{"use_domain_asr": ..., "data": [...]}``.
    """
    from concurrent.futures import ThreadPoolExecutor
    from functools import partial

    dataset_root = data_manifest_path.parent
    (dataset_root / Path("wav_plots")).mkdir(parents=True, exist_ok=True)
    typer.echo(f"Using data manifest:{data_manifest_path}")
    with data_manifest_path.open("r") as pf:
        manifest_lines = pf.readlines()
    tasks = []
    for position, line in enumerate(manifest_lines):
        tasks.append(
            partial(
                preprocess_datapoint,
                position,
                dataset_root,
                json.loads(line),
                use_domain_asr,
            )
        )

    def run_task(task):
        return task()

    with ThreadPoolExecutor(max_workers=20) as exe:
        print("starting all plot tasks")
        progress = tqdm(
            exe.map(run_task, tasks), position=0, leave=True, total=len(tasks)
        )
        outcomes = list(progress)
    # failed samples come back falsy and are left out
    processed = [outcome for outcome in outcomes if outcome]
    if use_domain_asr:
        wer_key = "domain_wer"
    else:
        wer_key = "pretrained_wer"
    ranked = sorted(processed, key=lambda x: x[wer_key], reverse=True)
    ExtendedPath(dump_path).write_json(
        {"use_domain_asr": use_domain_asr, "data": ranked}
    )
@app.command()
def dump_corrections(dump_path: Path = Path("./data/valiation_data/corrections.json")):
    """Write all ``correction`` documents of ``test.asr_validation`` to ``dump_path``.

    The Mongo ``_id`` field is excluded from the dumped documents.
    """
    collection = get_mongo_conn().test.asr_validation
    found = collection.find({"type": "correction"}, projection={"_id": False})
    ExtendedPath(dump_path).write_json(list(found))
@app.command()
def fill_unannotated(
    processed_data_path: Path = Path("./data/valiation_data/ui_dump.json"),
    corrections_path: Path = Path("./data/valiation_data/corrections.json"),
):
    """Mark every code without a correction as ``Inaudible`` in Mongo.

    The codes found in ``processed_data_path`` that have no entry in
    ``corrections_path`` are upserted into ``test.asr_validation`` as
    corrections with status ``Inaudible`` and an empty correction text.
    """
    # read both files through context managers so the handles are closed
    with processed_data_path.open() as processed_file:
        processed_data = json.load(processed_file)
    with corrections_path.open() as corrections_file:
        corrections = json.load(corrections_file)
    # dump_validation_ui_data writes {"use_domain_asr": ..., "data": [...]};
    # a bare list of entries is still accepted
    if isinstance(processed_data, dict):
        entries = processed_data["data"]
    else:
        entries = processed_data
    annotated_codes = {c["code"] for c in corrections}
    # NOTE(review): entries produced by preprocess_datapoint carry
    # "utterance_id" but no "gold_chars" key -- confirm which key holds the
    # code for the dumps this command is run against.
    all_codes = {c["gold_chars"] for c in entries}
    unann_codes = all_codes - annotated_codes
    mongo_conn = get_mongo_conn().test.asr_validation
    for c in unann_codes:
        mongo_conn.find_one_and_update(
            {"type": "correction", "code": c},
            {"$set": {"value": {"status": "Inaudible", "correction": ""}}},
            upsert=True,
        )
@app.command()
def update_corrections(
    data_manifest_path: Path = Path("./data/asr_data/call_alphanum/manifest.json"),
    corrections_path: Path = Path("./data/valiation_data/corrections.json"),
    skip_incorrect: bool = True,
):
    """Rewrite a manifest according to the validation corrections.

    The dataset directory is first copied to ``<dataset>.bkp`` unless that
    backup already exists.  Each manifest entry is then handled by the
    status of the correction whose ``code`` equals the entry's ``chars``:

    * ``Correct``: the entry is kept unchanged.
    * ``Incorrect``: with ``skip_incorrect`` the entry is left out of the
      manifest (its audio file stays); otherwise the audio file is renamed
      to the corrected text and the entry is rewritten accordingly.
    * anything else, including no correction at all: the entry is left out
      and its audio file is DELETED, unless an earlier rename targeted the
      same name.

    The new manifest is written to ``manifest.new`` and then moved over the
    original manifest.
    """

    def correct_manifest(manifest_data_gen, corrections_path):
        # close the corrections file deterministically instead of leaking it
        with corrections_path.open() as corrections_file:
            corrections = json.load(corrections_file)
        correct_set = {
            c["code"] for c in corrections if c["value"]["status"] == "Correct"
        }
        correction_map = {
            c["code"]: c["value"]["correction"]
            for c in corrections
            if c["value"]["status"] == "Incorrect"
        }
        # corrected texts that audio files have been renamed to so far
        renamed_set = set()
        for d in manifest_data_gen:
            if d["chars"] in correct_set:
                yield {
                    "audio_filepath": d["audio_filepath"],
                    "duration": d["duration"],
                    "text": d["text"],
                }
            elif d["chars"] in correction_map:
                correct_text = correction_map[d["chars"]]
                if skip_incorrect:
                    print(f'skipping incorrect {d["audio_path"]} corrected to {correct_text}')
                else:
                    renamed_set.add(correct_text)
                    new_name = str(Path(correct_text).with_suffix(".wav"))
                    d["audio_path"].replace(d["audio_path"].with_name(new_name))
                    new_filepath = str(Path(d["audio_filepath"]).with_name(new_name))
                    yield {
                        "audio_filepath": new_filepath,
                        "duration": d["duration"],
                        "text": alnum_to_asr_tokens(correct_text),
                    }
            else:
                # don't delete if another correction points to an old file
                if d["chars"] not in renamed_set:
                    d["audio_path"].unlink()
                else:
                    print(f'skipping deletion of correction:{d["chars"]}')

    typer.echo(f"Using data manifest:{data_manifest_path}")
    dataset_dir = data_manifest_path.parent
    dataset_name = dataset_dir.name
    backup_dir = dataset_dir.with_name(dataset_name + ".bkp")
    if not backup_dir.exists():
        typer.echo(f"backing up to :{backup_dir}")
        shutil.copytree(str(dataset_dir), str(backup_dir))
    manifest_gen = asr_manifest_reader(data_manifest_path)
    corrected_manifest = correct_manifest(manifest_gen, corrections_path)
    # write to a temporary name first, then replace the original manifest
    new_data_manifest_path = data_manifest_path.with_name("manifest.new")
    asr_manifest_writer(new_data_manifest_path, corrected_manifest)
    new_data_manifest_path.replace(data_manifest_path)
@app.command()
def clear_mongo_corrections():
    """Delete all ``correction`` documents from ``test.asr_validation``.

    The user is asked for confirmation first; declining leaves the
    collection untouched and prints ``Aborted``.
    """
    delete = typer.confirm("are you sure you want to clear mongo collection it?")
    if not delete:
        typer.echo("Aborted")
        return
    col = get_mongo_conn().test.asr_validation
    col.delete_many({"type": "correction"})
    typer.echo("deleted mongo collection.")
def main():
    """Run the Typer application defined in this module."""
    app()


# Allow the module to be executed directly as a script.
if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,38 @@
import streamlit.ReportThread as ReportThread
from streamlit.ScriptRequestQueue import RerunData
from streamlit.ScriptRunner import RerunException
from streamlit.server.Server import Server
def rerun():
    """Rerun a Streamlit app from the top!

    Always raises ``RerunException`` carrying the current widget states.
    """
    raise RerunException(RerunData(_get_widget_states()))
def _get_widget_states():
    """Return the widget states of the Streamlit session running this script.

    Relies on private Streamlit attributes (``_session_infos`` /
    ``_session_info_by_id`` and ``_widget_states``), so it is tied to the
    Streamlit versions that expose them.

    Raises:
        RuntimeError: if no session on the current server has the same
            ``enqueue`` as the current report context.
    """
    # Hack to get the session object from Streamlit.
    ctx = ReportThread.get_report_ctx()
    current_server = Server.get_current()
    if hasattr(current_server, "_session_infos"):
        # Streamlit < 0.56
        session_infos = current_server._session_infos.values()
    else:
        session_infos = current_server._session_info_by_id.values()

    # The session is identified by comparing its enqueue callable with the
    # one held by the current report context.
    session = None
    for session_info in session_infos:
        if session_info.session.enqueue == ctx.enqueue:
            session = session_info.session

    if session is None:
        # The two literals are concatenated; the separator keeps the
        # sentences from running together in the error message.
        raise RuntimeError(
            "Oh noes. Couldn't get your Streamlit Session object. "
            "Are you doing something fancy with threads?"
        )
    # Got the session object!
    return session._widget_states

View File

@@ -0,0 +1,140 @@
from pathlib import Path
import streamlit as st
import typer
from ..utils import ExtendedPath, get_mongo_conn
from .st_rerun import rerun
app = typer.Typer()

# One-time setup: the collection handle and its helper functions are attached
# to the ``st`` module, and ``st.mongo_connected`` marks the setup as done.
# NOTE(review): presumably this avoids reconnecting each time Streamlit
# re-executes the script - confirm.
if not hasattr(st, "mongo_connected"):
    st.mongoclient = get_mongo_conn().test.asr_validation
    mongo_conn = st.mongoclient

    def current_cursor_fn():
        """Return the sample index stored in the ``current_cursor`` document."""
        return mongo_conn.find_one({"type": "current_cursor"})["cursor"]

    def update_cursor_fn(val=0):
        """Store ``val`` as the cursor (upserting the document), then rerun the app."""
        cursor_filter = {"type": "current_cursor"}
        cursor_update = {"$set": {"type": "current_cursor", "cursor": val}}
        mongo_conn.find_one_and_update(cursor_filter, cursor_update, upsert=True)
        rerun()

    def get_correction_entry_fn(code):
        """Look up the correction document for ``code``, excluding its ``_id`` field."""
        correction_filter = {"type": "correction", "code": code}
        return mongo_conn.find_one(correction_filter, projection={"_id": False})

    def update_entry_fn(code, value):
        """Upsert ``value`` as the correction recorded for ``code``."""
        correction_filter = {"type": "correction", "code": code}
        mongo_conn.find_one_and_update(
            correction_filter, {"$set": {"value": value}}, upsert=True
        )

    # Seed the cursor when none is stored yet. update_cursor_fn ends in
    # rerun(), which raises, so the registrations below only execute on a
    # run where the cursor document already exists.
    cursor_obj = mongo_conn.find_one({"type": "current_cursor"})
    if not cursor_obj:
        update_cursor_fn(0)

    st.get_current_cursor = current_cursor_fn
    st.update_cursor = update_cursor_fn
    st.get_correction_entry = get_correction_entry_fn
    st.update_entry = update_entry_fn
    st.mongo_connected = True
@st.cache()
def load_ui_data(validation_ui_data_path: Path):
    """Read the validation UI data JSON at ``validation_ui_data_path``.

    The path in use is echoed before reading, and the function is wrapped
    in ``st.cache``.
    """
    typer.echo(f"Using validation ui data from {validation_ui_data_path}")
    ui_data_file = ExtendedPath(validation_ui_data_path)
    return ui_data_file.read_json()
@app.command()
def main(manifest: Path):
    """Render the ASR validation page for the sample at the stored cursor.

    Shows the sample's text, audio, plot and ASR results, lets the user mark
    it Correct / Incorrect / Inaudible (with an optional corrected text), and
    offers three ways to jump to another sample: by position, by text, and
    by ``real_idx``. Every navigation goes through ``st.update_cursor``,
    which stores the new cursor and reruns the script.

    Args:
        manifest: Path of the validation UI JSON; it must provide the keys
            ``data`` (list of samples) and ``use_domain_asr``.
    """
    ui_config = load_ui_data(manifest)
    asr_data = ui_config["data"]
    use_domain_asr = ui_config["use_domain_asr"]
    if not asr_data:
        # With no samples every cursor is invalid; resetting and rerunning
        # would repeat forever, so stop here instead.
        st.warning("No samples found in the validation data.")
        return
    sample_no = st.get_current_cursor()
    if len(asr_data) - 1 < sample_no or sample_no < 0:
        print("Invalid samplno resetting to 0")
        st.update_cursor(0)
    sample = asr_data[sample_no]

    title_type = "Speller " if use_domain_asr else ""
    st.title(f"ASR {title_type}Validation")
    addl_text = f"spelled *{sample['spoken']}*" if use_domain_asr else ""
    st.markdown(f"{sample_no+1} of {len(asr_data)} : **{sample['text']}**" + addl_text)

    # Position-based navigation; the widget is 1-based, the cursor 0-based.
    new_sample = st.number_input(
        "Go To Sample:", value=sample_no + 1, min_value=1, max_value=len(asr_data)
    )
    if new_sample != sample_no + 1:
        st.update_cursor(new_sample - 1)

    st.sidebar.title(f"Details: [{sample['real_idx']}]")
    st.sidebar.markdown(f"Gold Text: **{sample['text']}**")
    if use_domain_asr:
        st.sidebar.markdown(f"Expected Spelled: *{sample['spoken']}*")
    st.sidebar.title("Results:")
    st.sidebar.markdown(f"Pretrained: **{sample['pretrained_asr']}**")
    if use_domain_asr:
        st.sidebar.markdown(f"Domain: **{sample['domain_asr']}**")
        st.sidebar.title(f"Speller WER: {sample['domain_wer']:.2f}%")
    else:
        st.sidebar.title(f"Pretrained WER: {sample['pretrained_wer']:.2f}%")
    st.sidebar.image(Path(sample["plot_path"]).read_bytes())
    # Pass bytes rather than an open handle so no file is left unclosed.
    st.audio(Path(sample["audio_path"]).read_bytes())

    # set default to text
    corrected = sample["text"]
    correction_entry = st.get_correction_entry(sample["utterance_id"])
    selected_idx = 0
    options = ("Correct", "Incorrect", "Inaudible")
    # if correction entry is present set the corresponding ui defaults
    if correction_entry:
        selected_idx = options.index(correction_entry["value"]["status"])
        corrected = correction_entry["value"]["correction"]
    selected = st.radio("The Audio is", options, index=selected_idx)
    if selected == "Incorrect":
        corrected = st.text_input("Actual:", value=corrected)
    if selected == "Inaudible":
        corrected = ""
    if st.button("Submit"):
        # Corrections are stored with spaces removed and upper-cased.
        correct_code = corrected.replace(" ", "").upper()
        st.update_entry(
            sample["utterance_id"], {"status": selected, "correction": correct_code}
        )
        st.update_cursor(sample_no + 1)
    if correction_entry:
        st.markdown(
            f'Your Response: **{correction_entry["value"]["status"]}** Correction: **{correction_entry["value"]["correction"]}**'
        )
    # TODO: "Previous Untagged" / "Next Untagged" navigation buttons.

    # Text-based navigation: jump to the first sample whose text (or spoken
    # form, when the sample has one) equals the query.
    text_sample = st.text_input("Go to Text:", value="")
    if text_sample != "":
        match_idx = next(
            (
                i
                for (i, p) in enumerate(asr_data)
                if p["text"] == text_sample or p.get("spoken") == text_sample
            ),
            None,
        )
        # Skip the update when already on the match: a rerun to the same
        # sample would be redundant.
        if match_idx is not None and match_idx != sample_no:
            st.update_cursor(match_idx)

    real_idx = st.number_input(
        "Go to real-index",
        value=sample["real_idx"],
        min_value=0,
        max_value=len(asr_data) - 1,
    )
    if real_idx != int(sample["real_idx"]):
        target_idx = next(
            (i for (i, p) in enumerate(asr_data) if p["real_idx"] == real_idx), None
        )
        # Ignore a real index that no sample carries instead of failing.
        if target_idx is not None:
            st.update_cursor(target_idx)
if __name__ == "__main__":
    # NOTE(review): app() presumably ends by raising SystemExit; it is
    # swallowed here, likely so the hosting Streamlit process is not asked
    # to exit after each script run - confirm.
    try:
        app()
    except SystemExit:
        pass