mirror of
https://github.com/malarinv/jasper-asr.git
synced 2026-03-09 10:52:35 +00:00
refactored module structure
This commit is contained in:
1
jasper/data/__init__.py
Normal file
1
jasper/data/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
93
jasper/data/asr_recycler.py
Normal file
93
jasper/data/asr_recycler.py
Normal file
@@ -0,0 +1,93 @@
|
||||
import typer
|
||||
from itertools import chain
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
|
||||
def extract_data(
    call_audio_dir: Path = Path("/dataset/png_prod/call_audio"),
    call_meta_dir: Path = Path("/dataset/png_prod/call_metadata"),
    output_dir: Path = Path("./data"),
    dataset_name: str = "png_gcp_2jan",
    verbose: bool = False,
):
    """Build an ASR dataset from recorded calls and their ASR event metadata.

    For every ``*.wav`` under ``call_audio_dir``, loads the matching JSON
    event file from ``call_meta_dir`` (same relative path, ``.json`` suffix),
    slices each mono channel at ASR result boundaries, and writes the
    resulting (transcript, duration, wav bytes) datapoints via
    ``asr_data_writer`` into ``output_dir/asr_data/<dataset_name>``.
    """
    # Heavy deps imported lazily so the CLI loads fast for other commands.
    from pydub import AudioSegment
    from .utils import ExtendedPath, asr_data_writer
    from lenses import lens

    call_asr_data: Path = output_dir / Path("asr_data")
    call_asr_data.mkdir(exist_ok=True, parents=True)

    def wav_event_generator(call_audio_dir):
        # Yield (audio, wav_path, events) for every call recording found.
        for wav_path in call_audio_dir.glob("**/*.wav"):
            if verbose:
                typer.echo(f"loading events for file {wav_path}")
            call_wav = AudioSegment.from_file_using_temporary_files(wav_path)
            # Metadata mirrors the audio tree: same relative path, .json suffix.
            rel_meta_path = wav_path.with_suffix(".json").relative_to(call_audio_dir)
            meta_path = call_meta_dir / rel_meta_path
            events = ExtendedPath(meta_path).read_json()
            yield call_wav, wav_path, events

    def contains_asr(x):
        # Keep only events that carry an ASR result payload.
        return "AsrResult" in x

    def channel(n):
        # Predicate factory: event belongs to channel n. Events without an
        # explicit Channel field are attributed to channel 0.
        def filter_func(ev):
            return (
                ev["AsrResult"]["Channel"] == n
                if "Channel" in ev["AsrResult"]
                else n == 0
            )

        return filter_func

    def compute_endtime(call_wav, state):
        # Each segment runs from its own StartTime to the next event's
        # StartTime (or to end-of-audio for the last event).
        for (i, st) in enumerate(state):
            start_time = st["AsrResult"]["Alternatives"][0].get("StartTime", 0)
            transcript = st["AsrResult"]["Alternatives"][0]["Transcript"]
            if i + 1 < len(state):
                end_time = state[i + 1]["AsrResult"]["Alternatives"][0]["StartTime"]
            else:
                end_time = call_wav.duration_seconds
            # pydub slices in milliseconds; times here are in seconds.
            code_seg = call_wav[start_time * 1000 : end_time * 1000]
            code_fb = BytesIO()
            code_seg.export(code_fb, format="wav")
            code_wav = code_fb.getvalue()
            # only yield if some audio data is present (at least 0.5 s)
            if code_seg.duration_seconds >= 0.5:
                yield transcript, code_seg.duration_seconds, code_wav

    def asr_data_generator(call_wav, call_wav_fname, events):
        # Split the stereo call into its two mono parties and pair each
        # channel's audio with that channel's ASR events.
        call_wav_0, call_wav_1 = call_wav.split_to_mono()
        asr_events = lens["Events"].Each()["Event"].Filter(contains_asr)
        call_evs_0 = asr_events.Filter(channel(0)).collect()(events)
        call_evs_1 = asr_events.Filter(channel(1)).collect()(events)
        if verbose:
            typer.echo(f"processing data points on {call_wav_fname}")
        call_data_0 = compute_endtime(call_wav_0, call_evs_0)
        call_data_1 = compute_endtime(call_wav_1, call_evs_1)
        return chain(call_data_0, call_data_1)

    def generate_call_asr_data():
        # Collect lazy per-call generators, then stream them all into the
        # dataset writer in one pass.
        full_asr_data = []
        total_duration = 0
        for wav, wav_path, ev in wav_event_generator(call_audio_dir):
            asr_data = asr_data_generator(wav, wav_path, ev)
            total_duration += wav.duration_seconds
            full_asr_data.append(asr_data)
        typer.echo(f"loaded {len(full_asr_data)} calls of duration {total_duration}s")
        n_dps = asr_data_writer(call_asr_data, dataset_name, chain(*full_asr_data))
        typer.echo(f"written {n_dps} data points")

    generate_call_asr_data()
|
||||
|
||||
|
||||
def main():
|
||||
app()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
410
jasper/data/call_recycler.py
Normal file
410
jasper/data/call_recycler.py
Normal file
@@ -0,0 +1,410 @@
|
||||
# import argparse
|
||||
|
||||
# import logging
|
||||
import typer
|
||||
from pathlib import Path
|
||||
|
||||
app = typer.Typer()
|
||||
# leader_app = typer.Typer()
|
||||
# app.add_typer(leader_app, name="leaderboard")
|
||||
# plot_app = typer.Typer()
|
||||
# app.add_typer(plot_app, name="plot")
|
||||
|
||||
|
||||
@app.command()
|
||||
def export_all_logs(call_logs_file: Path = Path("./call_sia_logs.yaml")):
    """Export every call in Mongo to a YAML file grouping call URLs by caller.

    Output shape: ``{"users": [{"name": <caller>, "calls": [<url>, ...]}, ...]}``.
    """
    from .utils import get_mongo_conn
    from collections import defaultdict
    from ruamel.yaml import YAML

    yaml = YAML()
    mongo_coll = get_mongo_conn().test.calls
    caller_calls = defaultdict(lambda: [])
    for call in mongo_coll.find():
        sysid = call["SystemID"]
        # Calls are browsable at this dashboard URL keyed by SystemID.
        call_uri = f"http://sia-data.agaralabs.com/calls/{sysid}"
        caller = call["Caller"]
        caller_calls[caller].append(call_uri)
    caller_list = []
    for caller in caller_calls:
        caller_list.append({"name": caller, "calls": caller_calls[caller]})
    output_yaml = {"users": caller_list}
    typer.echo("exporting call logs to yaml file")
    with call_logs_file.open("w") as yf:
        yaml.dump(output_yaml, yf)
|
||||
|
||||
|
||||
@app.command()
|
||||
def export_calls_between(
    start_cid: str,
    end_cid: str,
    call_logs_file: Path = Path("./call_sia_logs.yaml"),
    mongo_port: int = 27017,
):
    """Export calls whose timestamps fall between two anchor calls.

    Looks up the calls with SystemIDs ``start_cid`` and ``end_cid``, then
    exports every call with ``StartTS >= start.StartTS`` and
    ``EndTS <= end.EndTS`` to YAML, grouped by caller (same shape as
    ``export_all_logs``).
    """
    from collections import defaultdict
    from ruamel.yaml import YAML
    from .utils import get_mongo_conn

    yaml = YAML()
    mongo_coll = get_mongo_conn(port=mongo_port).test.calls
    # NOTE(review): find_one returns None for an unknown SystemID, which would
    # raise TypeError on the subscript below — IDs are assumed valid.
    start_meta = mongo_coll.find_one({"SystemID": start_cid})
    end_meta = mongo_coll.find_one({"SystemID": end_cid})

    caller_calls = defaultdict(lambda: [])
    call_query = mongo_coll.find(
        {
            "StartTS": {"$gte": start_meta["StartTS"]},
            "EndTS": {"$lte": end_meta["EndTS"]},
        }
    )
    for call in call_query:
        sysid = call["SystemID"]
        call_uri = f"http://sia-data.agaralabs.com/calls/{sysid}"
        caller = call["Caller"]
        caller_calls[caller].append(call_uri)
    caller_list = []
    for caller in caller_calls:
        caller_list.append({"name": caller, "calls": caller_calls[caller]})
    output_yaml = {"users": caller_list}
    typer.echo("exporting call logs to yaml file")
    with call_logs_file.open("w") as yf:
        yaml.dump(output_yaml, yf)
|
||||
|
||||
|
||||
@app.command()
|
||||
def analyze(
|
||||
leaderboard: bool = False,
|
||||
plot_calls: bool = False,
|
||||
extract_data: bool = False,
|
||||
download_only: bool = False,
|
||||
call_logs_file: Path = Path("./call_logs.yaml"),
|
||||
output_dir: Path = Path("./data"),
|
||||
mongo_port: int = 27017,
|
||||
):
|
||||
|
||||
from urllib.parse import urlsplit
|
||||
from functools import reduce
|
||||
import boto3
|
||||
|
||||
from io import BytesIO
|
||||
import json
|
||||
from ruamel.yaml import YAML
|
||||
import re
|
||||
from google.protobuf.timestamp_pb2 import Timestamp
|
||||
from datetime import timedelta
|
||||
|
||||
# from concurrent.futures import ThreadPoolExecutor
|
||||
import librosa
|
||||
import librosa.display
|
||||
from lenses import lens
|
||||
from pprint import pprint
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib
|
||||
from tqdm import tqdm
|
||||
from .utils import asr_data_writer, get_mongo_conn
|
||||
from pydub import AudioSegment
|
||||
from natural.date import compress
|
||||
|
||||
# from itertools import product, chain
|
||||
|
||||
matplotlib.rcParams["agg.path.chunksize"] = 10000
|
||||
|
||||
matplotlib.use("agg")
|
||||
|
||||
# logging.basicConfig(
|
||||
# level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
# )
|
||||
# logger = logging.getLogger(__name__)
|
||||
yaml = YAML()
|
||||
s3 = boto3.client("s3")
|
||||
mongo_collection = get_mongo_conn(port=mongo_port).test.calls
|
||||
call_media_dir: Path = output_dir / Path("call_wavs")
|
||||
call_media_dir.mkdir(exist_ok=True, parents=True)
|
||||
call_meta_dir: Path = output_dir / Path("call_metas")
|
||||
call_meta_dir.mkdir(exist_ok=True, parents=True)
|
||||
call_plot_dir: Path = output_dir / Path("plots")
|
||||
call_plot_dir.mkdir(exist_ok=True, parents=True)
|
||||
call_asr_data: Path = output_dir / Path("asr_data")
|
||||
call_asr_data.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
call_logs = yaml.load(call_logs_file.read_text())
|
||||
|
||||
def get_call_meta(call_obj):
|
||||
meta_s3_uri = call_obj["DataURI"]
|
||||
s3_event_url_p = urlsplit(meta_s3_uri)
|
||||
saved_meta_path = call_meta_dir / Path(Path(s3_event_url_p.path).name)
|
||||
if not saved_meta_path.exists():
|
||||
print(f"downloading : {saved_meta_path} from {meta_s3_uri}")
|
||||
s3.download_file(
|
||||
s3_event_url_p.netloc, s3_event_url_p.path[1:], str(saved_meta_path)
|
||||
)
|
||||
call_metas = json.load(saved_meta_path.open())
|
||||
return call_metas
|
||||
|
||||
def gen_ev_fev_timedelta(fev):
|
||||
fev_p = Timestamp()
|
||||
fev_p.FromJsonString(fev["CreatedTS"])
|
||||
fev_dt = fev_p.ToDatetime()
|
||||
td_0 = timedelta()
|
||||
|
||||
def get_timedelta(ev):
|
||||
ev_p = Timestamp()
|
||||
ev_p.FromJsonString(value=ev["CreatedTS"])
|
||||
ev_dt = ev_p.ToDatetime()
|
||||
delta = ev_dt - fev_dt
|
||||
return delta if delta > td_0 else td_0
|
||||
|
||||
return get_timedelta
|
||||
|
||||
def chunk_n(evs, n):
|
||||
return [evs[i * n : (i + 1) * n] for i in range((len(evs) + n - 1) // n)]
|
||||
|
||||
def get_data_points(utter_events, td_fn):
|
||||
data_points = []
|
||||
for evs in chunk_n(utter_events, 3):
|
||||
assert evs[0]["Type"] == "CONV_RESULT"
|
||||
assert evs[1]["Type"] == "STARTED_SPEAKING"
|
||||
assert evs[2]["Type"] == "STOPPED_SPEAKING"
|
||||
start_time = td_fn(evs[1]).total_seconds() - 1.5
|
||||
end_time = td_fn(evs[2]).total_seconds()
|
||||
code = evs[0]["Msg"]
|
||||
data_points.append(
|
||||
{"start_time": start_time, "end_time": end_time, "code": code}
|
||||
)
|
||||
return data_points
|
||||
|
||||
def process_call(call_obj):
|
||||
call_meta = get_call_meta(call_obj)
|
||||
call_events = call_meta["Events"]
|
||||
|
||||
def is_writer_uri_event(ev):
|
||||
return ev["Author"] == "AUDIO_WRITER" and 's3://' in ev["Msg"]
|
||||
|
||||
writer_events = list(filter(is_writer_uri_event, call_events))
|
||||
s3_wav_url = re.search(r"(s3://.*)", writer_events[0]["Msg"]).groups(0)[0]
|
||||
s3_wav_url_p = urlsplit(s3_wav_url)
|
||||
|
||||
def is_first_audio_ev(state, ev):
|
||||
if state[0]:
|
||||
return state
|
||||
else:
|
||||
return (ev["Author"] == "GATEWAY" and ev["Type"] == "AUDIO", ev)
|
||||
|
||||
(_, first_audio_ev) = reduce(is_first_audio_ev, call_events, (False, {}))
|
||||
|
||||
get_ev_fev_timedelta = gen_ev_fev_timedelta(first_audio_ev)
|
||||
|
||||
def is_utter_event(ev):
|
||||
return (
|
||||
(ev["Author"] == "CONV" or ev["Author"] == "ASR")
|
||||
and (ev["Type"] != "DEBUG")
|
||||
and ev["Type"] != "ASR_RESULT"
|
||||
)
|
||||
|
||||
uevs = list(filter(is_utter_event, call_events))
|
||||
ev_count = len(uevs)
|
||||
utter_events = uevs[: ev_count - ev_count % 3]
|
||||
saved_wav_path = call_media_dir / Path(Path(s3_wav_url_p.path).name)
|
||||
if not saved_wav_path.exists():
|
||||
print(f"downloading : {saved_wav_path} from {s3_wav_url}")
|
||||
s3.download_file(
|
||||
s3_wav_url_p.netloc, s3_wav_url_p.path[1:], str(saved_wav_path)
|
||||
)
|
||||
|
||||
# %config InlineBackend.figure_format = "retina"
|
||||
|
||||
def plot_events(y, sr, utter_events, file_path):
|
||||
plt.figure(figsize=(16, 12))
|
||||
librosa.display.waveplot(y=y, sr=sr)
|
||||
# plt.tight_layout()
|
||||
for evs in chunk_n(utter_events, 3):
|
||||
assert evs[0]["Type"] == "CONV_RESULT"
|
||||
assert evs[1]["Type"] == "STARTED_SPEAKING"
|
||||
assert evs[2]["Type"] == "STOPPED_SPEAKING"
|
||||
for ev in evs:
|
||||
# print(ev["Type"])
|
||||
ev_type = ev["Type"]
|
||||
pos = get_ev_fev_timedelta(ev).total_seconds()
|
||||
if ev_type == "STARTED_SPEAKING":
|
||||
pos = pos - 1.5
|
||||
plt.axvline(pos) # , label="pyplot vertical line")
|
||||
plt.text(
|
||||
pos,
|
||||
0.2,
|
||||
f"event:{ev_type}:{ev['Msg']}",
|
||||
rotation=90,
|
||||
horizontalalignment="left"
|
||||
if ev_type != "STOPPED_SPEAKING"
|
||||
else "right",
|
||||
verticalalignment="center",
|
||||
)
|
||||
plt.title("Monophonic")
|
||||
plt.savefig(file_path, format="png")
|
||||
|
||||
return {
|
||||
"wav_path": saved_wav_path,
|
||||
"num_samples": len(utter_events) // 3,
|
||||
"meta": call_obj,
|
||||
"first_event_fn": get_ev_fev_timedelta,
|
||||
"utter_events": utter_events,
|
||||
}
|
||||
|
||||
def get_cid(uri):
|
||||
return Path(urlsplit(uri).path).stem
|
||||
|
||||
def ensure_call(uri):
|
||||
cid = get_cid(uri)
|
||||
meta = mongo_collection.find_one({"SystemID": cid})
|
||||
process_meta = process_call(meta)
|
||||
return process_meta
|
||||
|
||||
def retrieve_processed_callmeta(uri):
|
||||
cid = get_cid(uri)
|
||||
meta = mongo_collection.find_one({"SystemID": cid})
|
||||
duration = meta["EndTS"] - meta["StartTS"]
|
||||
process_meta = process_call(meta)
|
||||
data_points = get_data_points(process_meta['utter_events'], process_meta['first_event_fn'])
|
||||
process_meta['data_points'] = data_points
|
||||
return {"url": uri, "meta": meta, "duration": duration, "process": process_meta}
|
||||
|
||||
def download_meta_audio():
|
||||
call_lens = lens["users"].Each()["calls"].Each()
|
||||
call_lens.modify(ensure_call)(call_logs)
|
||||
|
||||
# @plot_app.command()
|
||||
def plot_calls_data():
|
||||
def plot_data_points(y, sr, data_points, file_path):
|
||||
plt.figure(figsize=(16, 12))
|
||||
librosa.display.waveplot(y=y, sr=sr)
|
||||
for dp in data_points:
|
||||
start, end, code = dp["start_time"], dp["end_time"], dp["code"]
|
||||
plt.axvspan(start, end, color="green", alpha=0.2)
|
||||
text_pos = (start + end) / 2
|
||||
plt.text(
|
||||
text_pos,
|
||||
0.25,
|
||||
f"{code}",
|
||||
rotation=90,
|
||||
horizontalalignment="center",
|
||||
verticalalignment="center",
|
||||
)
|
||||
plt.title("Datapoints")
|
||||
plt.savefig(file_path, format="png")
|
||||
return file_path
|
||||
|
||||
def plot_call(call_obj):
|
||||
saved_wav_path, data_points, sys_id = (
|
||||
call_obj["process"]["wav_path"],
|
||||
call_obj["process"]["data_points"],
|
||||
call_obj["meta"]["SystemID"],
|
||||
)
|
||||
file_path = call_plot_dir / Path(sys_id).with_suffix(".png")
|
||||
if not file_path.exists():
|
||||
print(f"plotting: {file_path}")
|
||||
(y, sr) = librosa.load(saved_wav_path)
|
||||
plot_data_points(y, sr, data_points, str(file_path))
|
||||
return file_path
|
||||
|
||||
call_lens = lens["users"].Each()["calls"].Each()
|
||||
call_stats = call_lens.modify(retrieve_processed_callmeta)(call_logs)
|
||||
# call_plot_data = call_lens.collect()(call_stats)
|
||||
call_plots = call_lens.modify(plot_call)(call_stats)
|
||||
# with ThreadPoolExecutor(max_workers=20) as exe:
|
||||
# print('starting all plot tasks')
|
||||
# responses = [exe.submit(plot_call, w) for w in call_plot_data]
|
||||
# print('submitted all plot tasks')
|
||||
# call_plots = [r.result() for r in responses]
|
||||
pprint(call_plots)
|
||||
|
||||
def extract_data_points():
|
||||
def gen_data_values(saved_wav_path, data_points):
|
||||
call_seg = (
|
||||
AudioSegment.from_wav(saved_wav_path)
|
||||
.set_channels(1)
|
||||
.set_sample_width(2)
|
||||
.set_frame_rate(24000)
|
||||
)
|
||||
for dp_id, dp in enumerate(data_points):
|
||||
start, end, code = dp["start_time"], dp["end_time"], dp["code"]
|
||||
code_seg = call_seg[start * 1000 : end * 1000]
|
||||
code_fb = BytesIO()
|
||||
code_seg.export(code_fb, format="wav")
|
||||
code_wav = code_fb.getvalue()
|
||||
# search for actual pnr code and handle plain codes as well
|
||||
extracted_code = (
|
||||
re.search(r"'(.*)'", code).groups(0)[0] if len(code) > 6 else code
|
||||
)
|
||||
yield extracted_code, code_seg.duration_seconds, code_wav
|
||||
|
||||
call_lens = lens["users"].Each()["calls"].Each()
|
||||
call_stats = call_lens.modify(retrieve_processed_callmeta)(call_logs)
|
||||
call_objs = call_lens.collect()(call_stats)
|
||||
|
||||
def data_source():
|
||||
for call_obj in tqdm(call_objs):
|
||||
saved_wav_path, data_points, sys_id = (
|
||||
call_obj["process"]["wav_path"],
|
||||
call_obj["process"]["data_points"],
|
||||
call_obj["meta"]["SystemID"],
|
||||
)
|
||||
for dp in gen_data_values(saved_wav_path, data_points):
|
||||
yield dp
|
||||
|
||||
asr_data_writer(call_asr_data, "call_alphanum", data_source())
|
||||
|
||||
def show_leaderboard():
|
||||
def compute_user_stats(call_stat):
|
||||
n_samples = (
|
||||
lens["calls"].Each()["process"]["num_samples"].get_monoid()(call_stat)
|
||||
)
|
||||
n_duration = lens["calls"].Each()["duration"].get_monoid()(call_stat)
|
||||
return {
|
||||
"num_samples": n_samples,
|
||||
"duration": n_duration.total_seconds(),
|
||||
"samples_rate": n_samples / n_duration.total_seconds(),
|
||||
"duration_str": compress(n_duration, pad=" "),
|
||||
"name": call_stat["name"],
|
||||
}
|
||||
|
||||
call_lens = lens["users"].Each()["calls"].Each()
|
||||
call_stats = call_lens.modify(retrieve_processed_callmeta)(call_logs)
|
||||
user_stats = lens["users"].Each().modify(compute_user_stats)(call_stats)
|
||||
leader_df = (
|
||||
pd.DataFrame(user_stats["users"])
|
||||
.sort_values(by=["duration"], ascending=False)
|
||||
.reset_index(drop=True)
|
||||
)
|
||||
leader_df["rank"] = leader_df.index + 1
|
||||
leader_board = leader_df.rename(
|
||||
columns={
|
||||
"rank": "Rank",
|
||||
"num_samples": "Codes",
|
||||
"name": "Name",
|
||||
"samples_rate": "SpeechRate",
|
||||
"duration_str": "Duration",
|
||||
}
|
||||
)[["Rank", "Name", "Codes", "Duration"]]
|
||||
print(
|
||||
"""ASR Speller Dataset Leaderboard :
|
||||
---------------------------------"""
|
||||
)
|
||||
print(leader_board.to_string(index=False))
|
||||
|
||||
if download_only:
|
||||
download_meta_audio()
|
||||
return
|
||||
if leaderboard:
|
||||
show_leaderboard()
|
||||
if plot_calls:
|
||||
plot_calls_data()
|
||||
if extract_data:
|
||||
extract_data_points()
|
||||
|
||||
|
||||
def main():
|
||||
app()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
69
jasper/data/process.py
Normal file
69
jasper/data/process.py
Normal file
@@ -0,0 +1,69 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
from sklearn.model_selection import train_test_split
|
||||
from .utils import asr_manifest_reader, asr_manifest_writer
|
||||
from typing import List
|
||||
from itertools import chain
|
||||
import typer
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
|
||||
def fixate_data(dataset_path: Path):
    """Rewrite manifest.json with absolute audio paths as abs_manifest.json."""
    manifest_path = dataset_path / Path("manifest.json")
    real_manifest_path = dataset_path / Path("abs_manifest.json")

    def absolutized_entries():
        # Prefix each relative audio path with the dataset directory.
        for entry in asr_manifest_reader(manifest_path):
            entry["audio_filepath"] = str(dataset_path / Path(entry["audio_filepath"]))
            yield entry

    asr_manifest_writer(real_manifest_path, absolutized_entries())
|
||||
|
||||
|
||||
@app.command()
|
||||
def augment_datasets(src_dataset_paths: List[Path], dest_dataset_path: Path):
    """Merge abs_manifest.json from several source datasets into one manifest."""
    abs_manifest_path = Path("abs_manifest.json")
    readers = [
        asr_manifest_reader(dataset_path / abs_manifest_path)
        for dataset_path in src_dataset_paths
    ]
    dest_dataset_path.mkdir(parents=True, exist_ok=True)
    dest_manifest_path = dest_dataset_path / abs_manifest_path
    asr_manifest_writer(dest_manifest_path, chain(*readers))
|
||||
|
||||
|
||||
@app.command()
|
||||
def split_data(dataset_path: Path, test_size: float = 0.1):
    """Split abs_manifest.json into train_manifest.json / test_manifest.json."""
    manifest_path = dataset_path / Path("abs_manifest.json")
    entries = list(asr_manifest_reader(manifest_path))
    train_set, test_set = train_test_split(entries, test_size=test_size)
    asr_manifest_writer(manifest_path.with_name("train_manifest.json"), train_set)
    asr_manifest_writer(manifest_path.with_name("test_manifest.json"), test_set)
|
||||
|
||||
|
||||
@app.command()
|
||||
def validate_data(dataset_path: Path):
    """Validate train/test manifests: every line parses as JSON and references
    an existing audio file.

    Prints a failure message per bad line; prints the success message only
    when a manifest has no errors (previously it printed unconditionally,
    even after failures).

    Args:
        dataset_path: dataset directory containing train_manifest.json and
            test_manifest.json.
    """
    for mf_type in ["train_manifest.json", "test_manifest.json"]:
        data_file = dataset_path / Path(mf_type)
        print(f"validating {data_file}.")
        error_count = 0
        with Path(data_file).open("r") as pf:
            for (i, s) in enumerate(pf.readlines()):
                try:
                    d = json.loads(s)
                    audio_file = data_file.parent / Path(d["audio_filepath"])
                    if not audio_file.exists():
                        raise OSError(f"File {audio_file} not found")
                # was `except BaseException`: that also swallowed
                # KeyboardInterrupt / SystemExit.
                except Exception as e:
                    error_count += 1
                    print(f'failed on {i} with "{e}"')
        if error_count == 0:
            print(f"no errors found. seems like a valid {mf_type}.")
        else:
            print(f"{error_count} error(s) found in {mf_type}.")
|
||||
|
||||
|
||||
def main():
    """Entry point: run the typer CLI app for dataset processing commands."""
    app()


if __name__ == "__main__":
    main()
|
||||
57
jasper/data/server.py
Normal file
57
jasper/data/server.py
Normal file
@@ -0,0 +1,57 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import typer
|
||||
import rpyc
|
||||
from rpyc.utils.server import ThreadedServer
|
||||
import nemo
|
||||
import pickle
|
||||
|
||||
# import nemo.collections.asr as nemo_asr
|
||||
from nemo.collections.asr.parts.segment import AudioSegment
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
nemo.core.NeuralModuleFactory(
|
||||
backend=nemo.core.Backend.PyTorch, placement=nemo.core.DeviceType.CPU
|
||||
)
|
||||
|
||||
|
||||
class ASRDataService(rpyc.Service):
    """RPyC service exposing remote audio loading for ASR data pipelines."""

    def exposed_get_path_samples(
        self, file_path, target_sr, int_values, offset, duration, trim
    ):
        """Load an audio file via NeMo's AudioSegment and return its samples,
        pickled — presumably so rpyc transfers plain bytes instead of proxying
        the array object (TODO confirm).
        """
        print(f"loading.. {file_path}")
        audio = AudioSegment.from_file(
            file_path,
            target_sr=target_sr,
            int_values=int_values,
            offset=offset,
            duration=duration,
            trim=trim,
        )
        # print(f"returning.. {len(audio.samples)} items of type{type(audio.samples)}")
        return pickle.dumps(audio.samples)

    def exposed_read_path(self, file_path):
        """Return the raw bytes of the file at ``file_path`` on the server host."""
        # print(f"reading path.. {file_path}")
        return Path(file_path).read_bytes()
|
||||
|
||||
|
||||
@app.command()
|
||||
def run_server(port: int = 0):
    """Start a threaded RPyC server hosting ASRDataService.

    A port of 0 (the default) falls back to the ASR_DARA_RPYC_PORT
    environment variable, then to 8064.
    """
    if port:
        listen_port = port
    else:
        # NOTE(review): "ASR_DARA_RPYC_PORT" looks like a typo for
        # ASR_DATA_RPYC_PORT, but it is an external contract — kept as-is.
        listen_port = int(os.environ.get("ASR_DARA_RPYC_PORT", "8064"))
    server = ThreadedServer(
        ASRDataService(),
        port=listen_port,
        protocol_config={"allow_all_attrs": True},
    )
    typer.echo(f"starting asr server on {listen_port}...")
    server.start()
|
||||
|
||||
|
||||
def main():
|
||||
app()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
0
jasper/data/tts/__init__.py
Normal file
0
jasper/data/tts/__init__.py
Normal file
52
jasper/data/tts/googletts.py
Normal file
52
jasper/data/tts/googletts.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from logging import getLogger
|
||||
from google.cloud import texttospeech
|
||||
|
||||
LOGGER = getLogger("googletts")
|
||||
|
||||
|
||||
class GoogleTTS(object):
    """Thin wrapper around the Google Cloud Text-to-Speech client."""

    def __init__(self):
        # Uses application-default credentials from the environment.
        self.client = texttospeech.TextToSpeechClient()

    def text_to_speech(self, text: str, params: dict) -> bytes:
        """Synthesize SSML ``text`` and return LINEAR16 (WAV) audio bytes.

        Args:
            text: SSML input (passed via the ``ssml`` field, not plain text).
            params: dict with "language", "name" (voice name) and
                "sample_rate" keys — the shape produced by ``voice_list``.
        """
        tts_input = texttospeech.types.SynthesisInput(ssml=text)
        voice = texttospeech.types.VoiceSelectionParams(
            language_code=params["language"], name=params["name"]
        )
        audio_config = texttospeech.types.AudioConfig(
            audio_encoding=texttospeech.enums.AudioEncoding.LINEAR16,
            sample_rate_hertz=params["sample_rate"],
        )
        response = self.client.synthesize_speech(tts_input, voice, audio_config)
        audio_content = response.audio_content
        return audio_content

    @classmethod
    def voice_list(cls):
        """List available English voices as dicts usable as ``text_to_speech`` params.

        Each result has: name, language (comma-joined "en*" codes), gender,
        engine ("wavenet"/"standard" inferred from the voice name), and the
        voice's natural sample rate. Voices with no English language code
        are skipped.
        """

        client = cls().client

        # Performs the list voices request
        voices = client.list_voices()
        results = []
        for voice in voices.voices:
            supported_eng_langs = [
                lang for lang in voice.language_codes if lang[:2] == "en"
            ]
            if len(supported_eng_langs) > 0:
                lang = ",".join(supported_eng_langs)
            else:
                continue

            ssml_gender = texttospeech.enums.SsmlVoiceGender(voice.ssml_gender)
            results.append(
                {
                    "name": voice.name,
                    "language": lang,
                    "gender": ssml_gender.name,
                    "engine": "wavenet" if "Wav" in voice.name else "standard",
                    "sample_rate": voice.natural_sample_rate_hertz,
                }
            )
        return results
|
||||
26
jasper/data/tts/ttsclient.py
Normal file
26
jasper/data/tts/ttsclient.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""
|
||||
TTSClient Abstract Class
|
||||
"""
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class TTSClient(ABC):
    """
    Abstract base class for text-to-speech client implementations.
    """

    @abstractmethod
    def text_to_speech(self, text: str, num_channels: int, sample_rate: int,
                       audio_encoding) -> bytes:
        """
        Synthesize speech audio for the given text.

        Arguments:
            text -- text (or SSML) to convert to speech
            num_channels -- number of channels for the output audio
            sample_rate -- sample rate in Hz for the output audio
            audio_encoding -- encoding of the returned audio bytes

        Returns:
            raw audio bytes in the requested encoding
        """
|
||||
62
jasper/data/tts_generator.py
Normal file
62
jasper/data/tts_generator.py
Normal file
@@ -0,0 +1,62 @@
|
||||
# import io
|
||||
# import sys
|
||||
# import json
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from .utils import random_pnr_generator, asr_data_writer
|
||||
from .tts.googletts import GoogleTTS
|
||||
from tqdm import tqdm
|
||||
import random
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def pnr_tts_streamer(count):
    """Yield (pnr_code, duration_seconds, wav_bytes) for `count` synthesized codes.

    Each random PNR code is spoken character-by-character (SSML verbatim)
    by a randomly chosen Google TTS English voice at 24 kHz mono.
    """
    google_voices = GoogleTTS.voice_list()
    gtts = GoogleTTS()
    for pnr_code in tqdm(random_pnr_generator(count)):
        tts_code = f'<speak><say-as interpret-as="verbatim">{pnr_code}</say-as></speak>'
        param = random.choice(google_voices)
        # Override the voice's natural rate/channels with the dataset format.
        param["sample_rate"] = 24000
        param["num_channels"] = 1
        wav_data = gtts.text_to_speech(text=tts_code, params=param)
        # Duration from payload size: assumes a 44-byte WAV header and
        # 16-bit mono samples at 24 kHz — TODO confirm header size holds
        # for all returned files.
        audio_dur = len(wav_data[44:]) / (2 * 24000)
        yield pnr_code, audio_dur, wav_data
|
||||
|
||||
|
||||
def generate_asr_data_fromtts(output_dir, dataset_name, count):
    """Synthesize `count` PNR utterances via Google TTS and write an ASR dataset."""
    asr_data_writer(output_dir, dataset_name, pnr_tts_streamer(count))
|
||||
|
||||
|
||||
def arg_parser():
    """Build the CLI argument parser for the TTS ASR-data generator."""
    parser = argparse.ArgumentParser(
        prog=Path(__file__).stem,
        description="generates asr training data",
    )
    parser.add_argument(
        "--output_dir",
        type=Path,
        default=Path("./train/asr_data"),
        help="directory to output asr data",
    )
    parser.add_argument(
        "--count", type=int, default=3, help="number of datapoints to generate"
    )
    parser.add_argument(
        "--dataset_name", type=str, default="pnr_data", help="name of the dataset"
    )
    return parser
|
||||
|
||||
|
||||
def main():
    """CLI entry point: parse args and generate ASR data from TTS."""
    parser = arg_parser()
    args = parser.parse_args()
    generate_asr_data_fromtts(**vars(args))


if __name__ == "__main__":
    main()
|
||||
119
jasper/data/utils.py
Normal file
119
jasper/data/utils.py
Normal file
@@ -0,0 +1,119 @@
|
||||
import numpy as np
|
||||
import wave
|
||||
import io
|
||||
import os
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import pymongo
|
||||
from slugify import slugify
|
||||
from uuid import uuid4
|
||||
from num2words import num2words
|
||||
|
||||
|
||||
def manifest_str(path, dur, text):
    """Serialize one ASR manifest entry as a JSON line (with trailing newline)."""
    entry = {"audio_filepath": path, "duration": round(dur, 1), "text": text}
    return "{}\n".format(json.dumps(entry))
|
||||
|
||||
|
||||
def wav_bytes(audio_bytes, frame_rate=24000):
    """Wrap raw 16-bit mono PCM bytes in a WAV container and return the bytes."""
    buffer = io.BytesIO()
    with wave.open(buffer, mode="w") as writer:
        writer.setnchannels(1)  # mono
        writer.setframerate(frame_rate)
        writer.setsampwidth(2)  # 16-bit samples
        writer.writeframesraw(audio_bytes)
    return buffer.getvalue()
|
||||
|
||||
|
||||
def random_pnr_generator(count=10000, length=3):
    """Generate random PNR-style codes of `length` letters and `length` digits.

    Each code contains `length` uppercase ASCII letters and `length` decimal
    digits (total code length 2 * length). The letter/digit column order is
    shuffled once per batch, so every code in a batch shares the same
    interleaving.

    Args:
        count: number of codes to generate.
        length: letters (and digits) per code. Defaults to 3, matching the
            previous hard-coded LENGTH, so existing callers are unaffected.

    Returns:
        list[str] of `count` codes.
    """
    # alphabet = list("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")
    alphabet = list("ABCDEFGHIJKLMNOPQRSTUVWXYZ")
    numeric = list("0123456789")
    np_alphabet = np.array(alphabet, dtype="|S1")
    np_numeric = np.array(numeric, dtype="|S1")
    np_alpha_codes = np.random.choice(np_alphabet, [count, length])
    np_num_codes = np.random.choice(np_numeric, [count, length])
    # Transpose so rows are character positions, shuffle the positions once,
    # then transpose back: all codes share one letter/digit interleaving.
    np_code_seed = np.concatenate((np_alpha_codes, np_num_codes), axis=1).T
    np.random.shuffle(np_code_seed)
    np_codes = np_code_seed.T
    codes = [(b"".join(np_codes[i])).decode("utf-8") for i in range(len(np_codes))]
    return codes
|
||||
|
||||
|
||||
def alnum_to_asr_tokens(text):
    """Spell an alphanumeric code as lowercase, space-separated ASR tokens.

    Characters are space-separated; ASCII digits are expanded to words
    via num2words (e.g. "A1" -> "a one").
    """
    spaced = " ".join(text)
    spelled = [num2words(ch) if "0" <= ch <= "9" else ch for ch in spaced]
    return "".join(spelled).lower()
|
||||
|
||||
|
||||
def asr_data_writer(output_dir, dataset_name, asr_data_source, verbose=False):
    """Write an ASR dataset (wav files + manifest.json) under output_dir/dataset_name.

    Args:
        output_dir: base directory for datasets (a ``Path``).
        dataset_name: subdirectory name for this dataset.
        asr_data_source: iterable of (transcript, duration_seconds, wav_bytes).
        verbose: when True, print each written transcript and duration.

    Returns:
        The number of datapoints written.
    """
    dataset_dir = output_dir / Path(dataset_name)
    # wav/ holds one file per datapoint; the manifest references them relatively.
    (dataset_dir / Path("wav")).mkdir(parents=True, exist_ok=True)
    asr_manifest = dataset_dir / Path("manifest.json")
    num_datapoints = 0
    with asr_manifest.open("w") as mf:
        for transcript, audio_dur, wav_data in asr_data_source:
            # uuid prefix guarantees uniqueness; slug suffix keeps names readable.
            fname = str(uuid4()) + "_" + slugify(transcript, max_length=8)
            pnr_af = dataset_dir / Path("wav") / Path(fname).with_suffix(".wav")
            pnr_af.write_bytes(wav_data)
            rel_pnr_path = pnr_af.relative_to(dataset_dir)
            manifest = manifest_str(str(rel_pnr_path), audio_dur, transcript)
            mf.write(manifest)
            if verbose:
                print(f"writing '{transcript}' of duration {audio_dur}")
            num_datapoints += 1
    return num_datapoints
|
||||
|
||||
|
||||
def asr_manifest_reader(data_manifest_path: Path):
    """Yield manifest entries augmented with a resolved `audio_path` and `chars`.

    `audio_path` is the manifest-relative audio path joined onto the
    manifest's directory; `chars` is the audio file's stem.
    """
    print(f"reading manifest from {data_manifest_path}")
    with data_manifest_path.open("r") as manifest_file:
        entries = [json.loads(line) for line in manifest_file.readlines()]
    for entry in entries:
        rel_audio = Path(entry["audio_filepath"])
        entry["audio_path"] = data_manifest_path.parent / rel_audio
        entry["chars"] = rel_audio.stem
        yield entry
|
||||
|
||||
|
||||
def asr_manifest_writer(asr_manifest_path: Path, manifest_str_source):
    """Write manifest dicts (audio_filepath/duration/text keys) as JSON lines."""
    with asr_manifest_path.open("w") as manifest_file:
        print(f"opening {asr_manifest_path} for writing manifest")
        for entry in manifest_str_source:
            line = manifest_str(
                entry["audio_filepath"], entry["duration"], entry["text"]
            )
            manifest_file.write(line)
|
||||
|
||||
|
||||
class ExtendedPath(type(Path())):
    """Path subclass with JSON read/write helpers.

    Derives from ``type(Path())`` so it subclasses the concrete platform
    flavour (PosixPath/WindowsPath) and behaves like a regular path.
    """

    def read_json(self):
        """Parse this file as JSON and return the resulting object."""
        with self.open("r") as handle:
            return json.load(handle)

    def write_json(self, data):
        """Write ``data`` as indented JSON to this path, creating parent dirs."""
        self.parent.mkdir(parents=True, exist_ok=True)
        with self.open("w") as handle:
            return json.dump(data, handle, indent=2)
|
||||
|
||||
|
||||
def get_mongo_conn(host='', port=27017):
    """Return a MongoClient; host falls back to $MONGO_HOST, then localhost."""
    resolved_host = host or os.environ.get("MONGO_HOST", "localhost")
    connection_uri = f"mongodb://{resolved_host}:{port}/"
    return pymongo.MongoClient(connection_uri)
|
||||
|
||||
|
||||
def main():
    """Smoke test: print a fresh batch of random PNR codes."""
    for c in random_pnr_generator():
        print(c)


if __name__ == "__main__":
    main()
|
||||
221
jasper/data/validation/process.py
Normal file
221
jasper/data/validation/process.py
Normal file
@@ -0,0 +1,221 @@
|
||||
import json
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import typer
|
||||
from tqdm import tqdm
|
||||
|
||||
from ..utils import (
|
||||
alnum_to_asr_tokens,
|
||||
ExtendedPath,
|
||||
asr_manifest_reader,
|
||||
asr_manifest_writer,
|
||||
get_mongo_conn,
|
||||
)
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
def preprocess_datapoint(idx, rel_root, sample, use_domain_asr):
    """Enrich one manifest entry with ASR transcripts, WERs and a waveform plot.

    Args:
        idx: position of the sample in the manifest (stored as "real_idx").
        rel_root: directory that "audio_filepath" entries are relative to.
        sample: manifest dict with at least "audio_filepath" and "text".
        use_domain_asr: also run the domain (speller) ASR and compute its WER.

    Returns:
        The enriched dict, or None when processing fails — errors are logged
        rather than raised so callers can filter failures out.
    """
    import matplotlib.pyplot as plt
    import librosa
    import librosa.display
    from pydub import AudioSegment
    from nemo.collections.asr.metrics import word_error_rate
    from jasper.client import (
        transcriber_pretrained,
        transcriber_speller,
    )

    try:
        res = dict(sample)
        res["real_idx"] = idx
        audio_path = rel_root / Path(sample["audio_filepath"])
        res["audio_path"] = str(audio_path)
        res["spoken"] = alnum_to_asr_tokens(res["text"])
        res["utterance_id"] = audio_path.stem
        # Normalize to mono, 16-bit samples, 24 kHz before transcription.
        aud_seg = (
            AudioSegment.from_file_using_temporary_files(audio_path)
            .set_channels(1)
            .set_sample_width(2)
            .set_frame_rate(24000)
        )
        res["pretrained_asr"] = transcriber_pretrained(aud_seg.raw_data)
        res["pretrained_wer"] = word_error_rate([res["text"]], [res["pretrained_asr"]])
        if use_domain_asr:
            res["domain_asr"] = transcriber_speller(aud_seg.raw_data)
            # Fix: score the domain (speller) transcript against the spoken
            # form; previously this mistakenly re-scored the pretrained one.
            res["domain_wer"] = word_error_rate([res["spoken"]], [res["domain_asr"]])
        wav_plot_path = (
            rel_root / Path("wav_plots") / Path(audio_path.name).with_suffix(".png")
        )
        # Plots are cached on disk; only render when missing.
        if not wav_plot_path.exists():
            fig = plt.Figure()
            ax = fig.add_subplot()
            (y, sr) = librosa.load(audio_path)
            librosa.display.waveplot(y=y, sr=sr, ax=ax)
            with wav_plot_path.open("wb") as wav_plot_f:
                fig.set_tight_layout(True)
                fig.savefig(wav_plot_f, format="png", dpi=50)
        res["plot_path"] = str(wav_plot_path)
        return res
    except Exception as e:
        # Fix: was BaseException, which also swallowed KeyboardInterrupt and
        # SystemExit; best-effort handling should catch ordinary errors only.
        print(f'failed on {idx}: {sample["audio_filepath"]} with {e}')
|
||||
|
||||
|
||||
@app.command()
def dump_validation_ui_data(
    data_manifest_path: Path = Path("./data/asr_data/call_alphanum/manifest.json"),
    dump_path: Path = Path("./data/valiation_data/ui_dump.json"),
    use_domain_asr: bool = True,
):
    """Process every manifest entry in parallel and dump UI data sorted by WER."""
    from concurrent.futures import ThreadPoolExecutor
    from functools import partial

    plot_dir = data_manifest_path.parent / Path("wav_plots")
    plot_dir.mkdir(parents=True, exist_ok=True)
    typer.echo(f"Using data manifest:{data_manifest_path}")
    with data_manifest_path.open("r") as manifest_file:
        manifest_lines = manifest_file.readlines()
        # One zero-arg callable per manifest entry, ready for the pool.
        tasks = [
            partial(
                preprocess_datapoint,
                line_no,
                data_manifest_path.parent,
                json.loads(line),
                use_domain_asr,
            )
            for line_no, line in enumerate(manifest_lines)
        ]

        def invoke(task):
            return task()

        with ThreadPoolExecutor(max_workers=20) as pool:
            print("starting all plot tasks")
            # Failed datapoints come back as None; drop them.
            processed = filter(
                None,
                list(
                    tqdm(
                        pool.map(invoke, tasks),
                        position=0,
                        leave=True,
                        total=len(tasks),
                    )
                ),
            )
        wer_key = "domain_wer" if use_domain_asr else "pretrained_wer"
        ranked = sorted(processed, key=lambda entry: entry[wer_key], reverse=True)
        ui_config = {"use_domain_asr": use_domain_asr, "data": ranked}
        ExtendedPath(dump_path).write_json(ui_config)
|
||||
|
||||
|
||||
@app.command()
def dump_corrections(dump_path: Path = Path("./data/valiation_data/corrections.json")):
    """Export all correction records from mongo into a JSON file."""
    collection = get_mongo_conn().test.asr_validation

    correction_docs = list(
        collection.find({"type": "correction"}, projection={"_id": False})
    )
    ExtendedPath(dump_path).write_json(correction_docs)
|
||||
|
||||
|
||||
@app.command()
def fill_unannotated(
    processed_data_path: Path = Path("./data/valiation_data/ui_dump.json"),
    corrections_path: Path = Path("./data/valiation_data/corrections.json"),
):
    """Mark every code present in the UI dump but missing a correction as Inaudible."""
    processed_data = json.load(processed_data_path.open())
    corrections = json.load(corrections_path.open())
    annotated_codes = {entry["code"] for entry in corrections}
    all_codes = {entry["gold_chars"] for entry in processed_data}
    collection = get_mongo_conn().test.asr_validation
    for code in all_codes - annotated_codes:
        collection.find_one_and_update(
            {"type": "correction", "code": code},
            {"$set": {"value": {"status": "Inaudible", "correction": ""}}},
            upsert=True,
        )
|
||||
|
||||
|
||||
@app.command()
def update_corrections(
    data_manifest_path: Path = Path("./data/asr_data/call_alphanum/manifest.json"),
    corrections_path: Path = Path("./data/valiation_data/corrections.json"),
    skip_incorrect: bool = True,
):
    """Apply human corrections to the dataset and rewrite its manifest.

    Backs up the dataset directory once, then streams the manifest through
    ``correct_manifest``: entries annotated Correct are kept as-is, Incorrect
    entries are skipped (or, with ``skip_incorrect=False``, their audio files
    are renamed to the corrected code and re-emitted), and unannotated or
    inaudible entries have their audio files deleted.
    """

    def correct_manifest(manifest_data_gen, corrections_path):
        # Generator: yields corrected manifest dicts and, as a side effect,
        # renames/deletes audio files on disk.
        corrections = json.load(corrections_path.open())
        correct_set = {
            c["code"] for c in corrections if c["value"]["status"] == "Correct"
        }
        # incorrect_set = {c["code"] for c in corrections if c["value"]["status"] == "Inaudible"}
        correction_map = {
            c["code"]: c["value"]["correction"]
            for c in corrections
            if c["value"]["status"] == "Incorrect"
        }
        # for d in manifest_data_gen:
        #     if d["chars"] in incorrect_set:
        #         d["audio_path"].unlink()
        # Target codes of files renamed so far in this pass.
        renamed_set = set()
        for d in manifest_data_gen:
            if d["chars"] in correct_set:
                # Annotated correct: keep the entry unchanged.
                yield {
                    "audio_filepath": d["audio_filepath"],
                    "duration": d["duration"],
                    "text": d["text"],
                }
            elif d["chars"] in correction_map:
                correct_text = correction_map[d["chars"]]
                if skip_incorrect:
                    print(f'skipping incorrect {d["audio_path"]} corrected to {correct_text}')
                else:
                    # Rename the audio file to the corrected code and emit a
                    # manifest entry pointing at the new name.
                    renamed_set.add(correct_text)
                    new_name = str(Path(correct_text).with_suffix(".wav"))
                    d["audio_path"].replace(d["audio_path"].with_name(new_name))
                    new_filepath = str(Path(d["audio_filepath"]).with_name(new_name))
                    yield {
                        "audio_filepath": new_filepath,
                        "duration": d["duration"],
                        "text": alnum_to_asr_tokens(correct_text),
                    }
            else:
                # Unannotated/inaudible entry: drop its audio file.
                # don't delete if another correction points to an old file
                # NOTE(review): renamed_set only holds renames processed
                # earlier in the stream, so iteration order decides which
                # files survive — confirm this is intended.
                if d["chars"] not in renamed_set:
                    d["audio_path"].unlink()
                else:
                    print(f'skipping deletion of correction:{d["chars"]}')

    typer.echo(f"Using data manifest:{data_manifest_path}")
    dataset_dir = data_manifest_path.parent
    dataset_name = dataset_dir.name
    backup_dir = dataset_dir.with_name(dataset_name + ".bkp")
    if not backup_dir.exists():
        # One-time backup of the dataset before any destructive edits.
        typer.echo(f"backing up to :{backup_dir}")
        shutil.copytree(str(dataset_dir), str(backup_dir))
    manifest_gen = asr_manifest_reader(data_manifest_path)
    corrected_manifest = correct_manifest(manifest_gen, corrections_path)
    new_data_manifest_path = data_manifest_path.with_name("manifest.new")
    asr_manifest_writer(new_data_manifest_path, corrected_manifest)
    # Swap the rewritten manifest into place.
    new_data_manifest_path.replace(data_manifest_path)
|
||||
|
||||
|
||||
@app.command()
def clear_mongo_corrections():
    """Delete all correction records from mongo after interactive confirmation."""
    delete = typer.confirm("are you sure you want to clear mongo collection it?")
    if delete:
        col = get_mongo_conn().test.asr_validation
        col.delete_many({"type": "correction"})
        typer.echo("deleted mongo collection.")
    else:
        # Fix: "Aborted" previously printed unconditionally, even after a
        # successful deletion; only report it when the user declined.
        typer.echo("Aborted")
|
||||
|
||||
|
||||
def main():
    """CLI entry point: dispatch to the typer application."""
    app()
|
||||
|
||||
|
||||
# Script entry point.
if __name__ == "__main__":
    main()
|
||||
38
jasper/data/validation/st_rerun.py
Normal file
38
jasper/data/validation/st_rerun.py
Normal file
@@ -0,0 +1,38 @@
|
||||
import streamlit.ReportThread as ReportThread
|
||||
from streamlit.ScriptRequestQueue import RerunData
|
||||
from streamlit.ScriptRunner import RerunException
|
||||
from streamlit.server.Server import Server
|
||||
|
||||
|
||||
def rerun():
    """Rerun a Streamlit app from the top!"""
    states = _get_widget_states()
    raise RerunException(RerunData(states))
|
||||
|
||||
|
||||
def _get_widget_states():
    """Return the current session's widget states via Streamlit internals."""
    # Hack to get the session object from Streamlit.

    ctx = ReportThread.get_report_ctx()

    session = None

    current_server = Server.get_current()
    # The private session registry was renamed across Streamlit versions.
    if hasattr(current_server, '_session_infos'):
        # Streamlit < 0.56
        session_infos = Server.get_current()._session_infos.values()
    else:
        session_infos = Server.get_current()._session_info_by_id.values()

    # Identify our session by matching its enqueue callback against the
    # one on the current report context.
    for session_info in session_infos:
        if session_info.session.enqueue == ctx.enqueue:
            session = session_info.session

    if session is None:
        raise RuntimeError(
            "Oh noes. Couldn't get your Streamlit Session object"
            "Are you doing something fancy with threads?"
        )
    # Got the session object!

    return session._widget_states
|
||||
140
jasper/data/validation/ui.py
Normal file
140
jasper/data/validation/ui.py
Normal file
@@ -0,0 +1,140 @@
|
||||
from pathlib import Path
|
||||
|
||||
import streamlit as st
|
||||
import typer
|
||||
from ..utils import ExtendedPath, get_mongo_conn
|
||||
from .st_rerun import rerun
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
# One-time wiring of mongo-backed persistence helpers onto the ``st`` module.
# Streamlit re-executes this script on every interaction, so the helpers are
# attached once and guarded by the ``mongo_connected`` flag.
if not hasattr(st, "mongo_connected"):
    st.mongoclient = get_mongo_conn().test.asr_validation
    mongo_conn = st.mongoclient

    def current_cursor_fn():
        # Return the persisted cursor (index of the sample being validated).
        # mongo_conn = st.mongoclient
        cursor_obj = mongo_conn.find_one({"type": "current_cursor"})
        cursor_val = cursor_obj["cursor"]
        return cursor_val

    def update_cursor_fn(val=0):
        # Persist the cursor, then force a Streamlit rerun so the UI updates.
        mongo_conn.find_one_and_update(
            {"type": "current_cursor"},
            {"$set": {"type": "current_cursor", "cursor": val}},
            upsert=True,
        )
        rerun()

    def get_correction_entry_fn(code):
        # Fetch the stored correction for *code*; None when not yet annotated.
        # mongo_conn = st.mongoclient
        # cursor_obj = mongo_conn.find_one({"type": "correction", "code": code})
        # cursor_val = cursor_obj["cursor"]
        return mongo_conn.find_one(
            {"type": "correction", "code": code}, projection={"_id": False}
        )

    def update_entry_fn(code, value):
        # Upsert the correction record for *code*.
        mongo_conn.find_one_and_update(
            {"type": "correction", "code": code},
            {"$set": {"value": value}},
            upsert=True,
        )

    # Seed the cursor document on first run.
    # NOTE(review): update_cursor_fn(0) triggers rerun(), which raises before
    # the mongo_connected flag below is set, so setup re-runs on the next
    # pass — confirm this is intended.
    cursor_obj = mongo_conn.find_one({"type": "current_cursor"})
    if not cursor_obj:
        update_cursor_fn(0)
    st.get_current_cursor = current_cursor_fn
    st.update_cursor = update_cursor_fn
    st.get_correction_entry = get_correction_entry_fn
    st.update_entry = update_entry_fn
    st.mongo_connected = True
|
||||
|
||||
|
||||
@st.cache()
def load_ui_data(validation_ui_data_path: Path):
    """Load the validation UI dump JSON (memoized via st.cache)."""
    typer.echo(f"Using validation ui data from {validation_ui_data_path}")
    ui_data = ExtendedPath(validation_ui_data_path).read_json()
    return ui_data
|
||||
|
||||
|
||||
@app.command()
def main(manifest: Path):
    """Streamlit page: review one ASR sample at a time and record a verdict."""
    ui_config = load_ui_data(manifest)
    asr_data = ui_config["data"]
    use_domain_asr = ui_config["use_domain_asr"]
    sample_no = st.get_current_cursor()
    # Reset an out-of-range cursor; st.update_cursor triggers a rerun.
    if len(asr_data) - 1 < sample_no or sample_no < 0:
        print("Invalid samplno resetting to 0")
        st.update_cursor(0)
    sample = asr_data[sample_no]
    title_type = "Speller " if use_domain_asr else ""
    st.title(f"ASR {title_type}Validation")
    addl_text = f"spelled *{sample['spoken']}*" if use_domain_asr else ""
    st.markdown(f"{sample_no+1} of {len(asr_data)} : **{sample['text']}**" + addl_text)
    # Jump-to-sample control (1-based in the UI, 0-based cursor).
    new_sample = st.number_input(
        "Go To Sample:", value=sample_no + 1, min_value=1, max_value=len(asr_data)
    )
    if new_sample != sample_no + 1:
        st.update_cursor(new_sample - 1)
    # Sidebar: gold text, ASR outputs, WER and the cached waveform plot.
    st.sidebar.title(f"Details: [{sample['real_idx']}]")
    st.sidebar.markdown(f"Gold Text: **{sample['text']}**")
    if use_domain_asr:
        st.sidebar.markdown(f"Expected Spelled: *{sample['spoken']}*")
    st.sidebar.title("Results:")
    st.sidebar.markdown(f"Pretrained: **{sample['pretrained_asr']}**")
    if use_domain_asr:
        st.sidebar.markdown(f"Domain: **{sample['domain_asr']}**")
        st.sidebar.title(f"Speller WER: {sample['domain_wer']:.2f}%")
    else:
        st.sidebar.title(f"Pretrained WER: {sample['pretrained_wer']:.2f}%")
    st.sidebar.image(Path(sample["plot_path"]).read_bytes())
    st.audio(Path(sample["audio_path"]).open("rb"))
    # set default to text
    corrected = sample["text"]
    correction_entry = st.get_correction_entry(sample["utterance_id"])
    selected_idx = 0
    options = ("Correct", "Incorrect", "Inaudible")
    # if correction entry is present set the corresponding ui defaults
    if correction_entry:
        selected_idx = options.index(correction_entry["value"]["status"])
        corrected = correction_entry["value"]["correction"]
    selected = st.radio("The Audio is", options, index=selected_idx)
    if selected == "Incorrect":
        corrected = st.text_input("Actual:", value=corrected)
    if selected == "Inaudible":
        corrected = ""
    if st.button("Submit"):
        # Persist the verdict keyed by utterance id, then advance the cursor.
        correct_code = corrected.replace(" ", "").upper()
        st.update_entry(
            sample["utterance_id"], {"status": selected, "correction": correct_code}
        )
        st.update_cursor(sample_no + 1)
    if correction_entry:
        st.markdown(
            f'Your Response: **{correction_entry["value"]["status"]}** Correction: **{correction_entry["value"]["correction"]}**'
        )
    # if st.button("Previous Untagged"):
    #     pass
    # if st.button("Next Untagged"):
    #     pass
    # Jump to the first sample whose gold text or spoken form matches.
    text_sample = st.text_input("Go to Text:", value='')
    if text_sample != '':
        candidates = [i for (i, p) in enumerate(asr_data) if p["text"] == text_sample or p["spoken"] == text_sample]
        if len(candidates) > 0:
            st.update_cursor(candidates[0])
    # Jump by the original manifest index ("real_idx"), which differs from
    # the display order because samples are sorted by WER.
    real_idx = st.number_input(
        "Go to real-index",
        value=sample["real_idx"],
        min_value=0,
        max_value=len(asr_data) - 1,
    )
    if real_idx != int(sample["real_idx"]):
        idx = [i for (i, p) in enumerate(asr_data) if p["real_idx"] == real_idx][0]
        st.update_cursor(idx)
|
||||
|
||||
|
||||
# Script entry point. Typer exits via SystemExit after command dispatch;
# swallow it so the surrounding Streamlit script run is not aborted.
if __name__ == "__main__":
    try:
        app()
    except SystemExit:
        pass
|
||||
Reference in New Issue
Block a user