1. added a tool to extract asr data from gcp transcripts logs
2. implement a funciton to export all call logs in a mongodb to a caller-id based yaml file 3. clean-up leaderboard duration logic 4. added a wip dataloader service 5. made the asr_data_writer util more generic with verbose flags and unique filename 6. added extendedpath util class with json support and mongo_conn function to connect to a mongo node 7. refactored the validation post processing to dump a ui config for validation 8. included utility functions to correct, fill update and clear annotations from mongodb data 9. refactored the ui logic to be more generic for any asr data 10. updated setup.py dependencies to support the above features
parent
a7da729c0b
commit
c06a0814b9
|
|
@ -0,0 +1,93 @@
|
|||
import typer
|
||||
from itertools import chain
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
|
||||
def extract_data(
|
||||
call_audio_dir: Path = Path("/dataset/png_prod/call_audio"),
|
||||
call_meta_dir: Path = Path("/dataset/png_prod/call_metadata"),
|
||||
output_dir: Path = Path("./data"),
|
||||
dataset_name: str = "png_gcp_2jan",
|
||||
verbose: bool = False,
|
||||
):
|
||||
from pydub import AudioSegment
|
||||
from .utils import ExtendedPath, asr_data_writer
|
||||
from lenses import lens
|
||||
|
||||
call_asr_data: Path = output_dir / Path("asr_data")
|
||||
call_asr_data.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
def wav_event_generator(call_audio_dir):
|
||||
for wav_path in call_audio_dir.glob("**/*.wav"):
|
||||
if verbose:
|
||||
typer.echo(f"loading events for file {wav_path}")
|
||||
call_wav = AudioSegment.from_file_using_temporary_files(wav_path)
|
||||
rel_meta_path = wav_path.with_suffix(".json").relative_to(call_audio_dir)
|
||||
meta_path = call_meta_dir / rel_meta_path
|
||||
events = ExtendedPath(meta_path).read_json()
|
||||
yield call_wav, wav_path, events
|
||||
|
||||
def contains_asr(x):
|
||||
return "AsrResult" in x
|
||||
|
||||
def channel(n):
|
||||
def filter_func(ev):
|
||||
return (
|
||||
ev["AsrResult"]["Channel"] == n
|
||||
if "Channel" in ev["AsrResult"]
|
||||
else n == 0
|
||||
)
|
||||
|
||||
return filter_func
|
||||
|
||||
def compute_endtime(call_wav, state):
|
||||
for (i, st) in enumerate(state):
|
||||
start_time = st["AsrResult"]["Alternatives"][0].get("StartTime", 0)
|
||||
transcript = st["AsrResult"]["Alternatives"][0]["Transcript"]
|
||||
if i + 1 < len(state):
|
||||
end_time = state[i + 1]["AsrResult"]["Alternatives"][0]["StartTime"]
|
||||
else:
|
||||
end_time = call_wav.duration_seconds
|
||||
code_seg = call_wav[start_time * 1000 : end_time * 1000]
|
||||
code_fb = BytesIO()
|
||||
code_seg.export(code_fb, format="wav")
|
||||
code_wav = code_fb.getvalue()
|
||||
# only of some audio data is present yield it
|
||||
if code_seg.duration_seconds >= 0.5:
|
||||
yield transcript, code_seg.duration_seconds, code_wav
|
||||
|
||||
def asr_data_generator(call_wav, call_wav_fname, events):
|
||||
call_wav_0, call_wav_1 = call_wav.split_to_mono()
|
||||
asr_events = lens["Events"].Each()["Event"].Filter(contains_asr)
|
||||
call_evs_0 = asr_events.Filter(channel(0)).collect()(events)
|
||||
call_evs_1 = asr_events.Filter(channel(1)).collect()(events)
|
||||
if verbose:
|
||||
typer.echo(f"processing data points on {call_wav_fname}")
|
||||
call_data_0 = compute_endtime(call_wav_0, call_evs_0)
|
||||
call_data_1 = compute_endtime(call_wav_1, call_evs_1)
|
||||
return chain(call_data_0, call_data_1)
|
||||
|
||||
def generate_call_asr_data():
|
||||
full_asr_data = []
|
||||
total_duration = 0
|
||||
for wav, wav_path, ev in wav_event_generator(call_audio_dir):
|
||||
asr_data = asr_data_generator(wav, wav_path, ev)
|
||||
total_duration += wav.duration_seconds
|
||||
full_asr_data.append(asr_data)
|
||||
typer.echo(f"loaded {len(full_asr_data)} calls of duration {total_duration}s")
|
||||
n_dps = asr_data_writer(call_asr_data, dataset_name, chain(*full_asr_data))
|
||||
typer.echo(f"written {n_dps} data points")
|
||||
|
||||
generate_call_asr_data()
|
||||
|
||||
|
||||
def main():
|
||||
app()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -11,6 +11,29 @@ app = typer.Typer()
|
|||
# app.add_typer(plot_app, name="plot")
|
||||
|
||||
|
||||
@app.command()
|
||||
def export_logs(call_logs_file: Path = Path("./call_sia_logs.yaml")):
|
||||
from pymongo import MongoClient
|
||||
from collections import defaultdict
|
||||
from ruamel.yaml import YAML
|
||||
|
||||
yaml = YAML()
|
||||
mongo_collection = MongoClient("mongodb://localhost:27017/").test.calls
|
||||
caller_calls = defaultdict(lambda: [])
|
||||
for call in mongo_collection.find():
|
||||
sysid = call["SystemID"]
|
||||
call_uri = f"http://sia-data.agaralabs.com/calls/{sysid}"
|
||||
caller = call["Caller"]
|
||||
caller_calls[caller].append(call_uri)
|
||||
caller_list = []
|
||||
for caller in caller_calls:
|
||||
caller_list.append({"name": caller, "calls": caller_calls[caller]})
|
||||
output_yaml = {"users": caller_list}
|
||||
typer.echo("exporting call logs to yaml file")
|
||||
with call_logs_file.open("w") as yf:
|
||||
yaml.dump(output_yaml, yf)
|
||||
|
||||
|
||||
@app.command()
|
||||
def analyze(
|
||||
leaderboard: bool = False,
|
||||
|
|
@ -19,8 +42,6 @@ def analyze(
|
|||
call_logs_file: Path = Path("./call_logs.yaml"),
|
||||
output_dir: Path = Path("./data"),
|
||||
):
|
||||
call_logs_file = Path("./call_logs.yaml")
|
||||
output_dir = Path("./data")
|
||||
|
||||
from urllib.parse import urlsplit
|
||||
from functools import reduce
|
||||
|
|
@ -35,7 +56,6 @@ def analyze(
|
|||
from datetime import timedelta
|
||||
|
||||
# from concurrent.futures import ThreadPoolExecutor
|
||||
from dateutil.relativedelta import relativedelta
|
||||
import librosa
|
||||
import librosa.display
|
||||
from lenses import lens
|
||||
|
|
@ -46,6 +66,8 @@ def analyze(
|
|||
from tqdm import tqdm
|
||||
from .utils import asr_data_writer
|
||||
from pydub import AudioSegment
|
||||
from natural.date import compress
|
||||
|
||||
# from itertools import product, chain
|
||||
|
||||
matplotlib.rcParams["agg.path.chunksize"] = 10000
|
||||
|
|
@ -256,8 +278,11 @@ def analyze(
|
|||
code_fb = BytesIO()
|
||||
code_seg.export(code_fb, format="wav")
|
||||
code_wav = code_fb.getvalue()
|
||||
# import pdb; pdb.set_trace()
|
||||
yield code, code_seg.duration_seconds, code_wav
|
||||
# search for actual pnr code and handle plain codes as well
|
||||
extracted_code = (
|
||||
re.search(r"'(.*)'", code).groups(0)[0] if len(code) > 6 else code
|
||||
)
|
||||
yield extracted_code, code_seg.duration_seconds, code_wav
|
||||
|
||||
call_lens = lens["users"].Each()["calls"].Each()
|
||||
call_stats = call_lens.modify(retrieve_callmeta)(call_logs)
|
||||
|
|
@ -275,22 +300,17 @@ def analyze(
|
|||
|
||||
asr_data_writer(call_asr_data, "call_alphanum", data_source())
|
||||
|
||||
# @leader_app.command()
|
||||
def show_leaderboard():
|
||||
def compute_user_stats(call_stat):
|
||||
n_samples = (
|
||||
lens["calls"].Each()["process"]["num_samples"].get_monoid()(call_stat)
|
||||
)
|
||||
n_duration = lens["calls"].Each()["duration"].get_monoid()(call_stat)
|
||||
rel_dur = relativedelta(
|
||||
seconds=int(n_duration.total_seconds()),
|
||||
microseconds=n_duration.microseconds,
|
||||
)
|
||||
return {
|
||||
"num_samples": n_samples,
|
||||
"duration": n_duration.total_seconds(),
|
||||
"samples_rate": n_samples / n_duration.total_seconds(),
|
||||
"duration_str": f"{rel_dur.minutes} mins {rel_dur.seconds} secs",
|
||||
"duration_str": compress(n_duration, pad=" "),
|
||||
"name": call_stat["name"],
|
||||
}
|
||||
|
||||
|
|
@ -313,8 +333,8 @@ def analyze(
|
|||
}
|
||||
)[["Rank", "Name", "Codes", "Duration"]]
|
||||
print(
|
||||
"""Today's ASR Speller Dataset Leaderboard:
|
||||
----------------------------------------"""
|
||||
"""ASR Speller Dataset Leaderboard :
|
||||
---------------------------------"""
|
||||
)
|
||||
print(leader_board.to_string(index=False))
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,29 @@
|
|||
import typer
|
||||
import rpyc
|
||||
import os
|
||||
from pathlib import Path
|
||||
from rpyc.utils.server import ThreadedServer
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
class ASRDataService(rpyc.Service):
|
||||
def get_data_loader(self, data_manifest: Path):
|
||||
return "hello"
|
||||
|
||||
|
||||
@app.command()
|
||||
def run_server(port: int = 0):
|
||||
listen_port = port if port else int(os.environ.get("ASR_RPYC_PORT", "8044"))
|
||||
service = ASRDataService()
|
||||
t = ThreadedServer(service, port=listen_port)
|
||||
typer.echo(f"starting asr server on {listen_port}...")
|
||||
t.start()
|
||||
|
||||
|
||||
def main():
|
||||
app()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -1,8 +1,13 @@
|
|||
import numpy as np
|
||||
import wave
|
||||
import io
|
||||
import os
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import pymongo
|
||||
from slugify import slugify
|
||||
from uuid import uuid4
|
||||
from num2words import num2words
|
||||
|
||||
|
||||
|
|
@ -46,42 +51,65 @@ def alnum_to_asr_tokens(text):
|
|||
return ("".join(num_tokens)).lower()
|
||||
|
||||
|
||||
def asr_data_writer(output_dir, dataset_name, asr_data_source):
|
||||
def asr_data_writer(output_dir, dataset_name, asr_data_source, verbose=False):
|
||||
dataset_dir = output_dir / Path(dataset_name)
|
||||
(dataset_dir / Path("wav")).mkdir(parents=True, exist_ok=True)
|
||||
asr_manifest = dataset_dir / Path("manifest.json")
|
||||
num_datapoints = 0
|
||||
with asr_manifest.open("w") as mf:
|
||||
for pnr_code, audio_dur, wav_data in asr_data_source:
|
||||
pnr_af = dataset_dir / Path("wav") / Path(pnr_code).with_suffix(".wav")
|
||||
for transcript, audio_dur, wav_data in asr_data_source:
|
||||
fname = str(uuid4()) + "_" + slugify(transcript, max_length=8)
|
||||
pnr_af = dataset_dir / Path("wav") / Path(fname).with_suffix(".wav")
|
||||
pnr_af.write_bytes(wav_data)
|
||||
rel_pnr_path = pnr_af.relative_to(dataset_dir)
|
||||
manifest = manifest_str(
|
||||
str(rel_pnr_path), audio_dur, alnum_to_asr_tokens(pnr_code)
|
||||
)
|
||||
manifest = manifest_str(str(rel_pnr_path), audio_dur, transcript)
|
||||
mf.write(manifest)
|
||||
if verbose:
|
||||
print(f"writing '{transcript}' of duration {audio_dur}")
|
||||
num_datapoints += 1
|
||||
return num_datapoints
|
||||
|
||||
|
||||
def asr_manifest_reader(data_manifest_path: Path):
|
||||
print(f'reading manifest from {data_manifest_path}')
|
||||
print(f"reading manifest from {data_manifest_path}")
|
||||
with data_manifest_path.open("r") as pf:
|
||||
pnr_jsonl = pf.readlines()
|
||||
pnr_data = [json.loads(v) for v in pnr_jsonl]
|
||||
for p in pnr_data:
|
||||
p['audio_path'] = data_manifest_path.parent / Path(p['audio_filepath'])
|
||||
p['chars'] = Path(p['audio_filepath']).stem
|
||||
p["audio_path"] = data_manifest_path.parent / Path(p["audio_filepath"])
|
||||
p["chars"] = Path(p["audio_filepath"]).stem
|
||||
yield p
|
||||
|
||||
|
||||
def asr_manifest_writer(asr_manifest_path: Path, manifest_str_source):
|
||||
with asr_manifest_path.open("w") as mf:
|
||||
print(f'opening {asr_manifest_path} for writing manifest')
|
||||
print(f"opening {asr_manifest_path} for writing manifest")
|
||||
for mani_dict in manifest_str_source:
|
||||
manifest = manifest_str(
|
||||
mani_dict['audio_filepath'], mani_dict['duration'], mani_dict['text']
|
||||
mani_dict["audio_filepath"], mani_dict["duration"], mani_dict["text"]
|
||||
)
|
||||
mf.write(manifest)
|
||||
|
||||
|
||||
class ExtendedPath(type(Path())):
|
||||
"""docstring for ExtendedPath."""
|
||||
|
||||
def read_json(self):
|
||||
with self.open("r") as jf:
|
||||
return json.load(jf)
|
||||
|
||||
def write_json(self, data):
|
||||
self.parent.mkdir(parents=True, exist_ok=True)
|
||||
with self.open("w") as jf:
|
||||
return json.dump(data, jf, indent=2)
|
||||
|
||||
|
||||
def get_mongo_conn(host=''):
|
||||
mongo_host = host if host else os.environ.get("MONGO_HOST", "localhost")
|
||||
mongo_uri = f"mongodb://{mongo_host}:27017/"
|
||||
return pymongo.MongoClient(mongo_uri)
|
||||
|
||||
|
||||
def main():
|
||||
for c in random_pnr_generator():
|
||||
print(c)
|
||||
|
|
|
|||
|
|
@ -1,105 +1,137 @@
|
|||
import pymongo
|
||||
import typer
|
||||
|
||||
# import matplotlib.pyplot as plt
|
||||
from pathlib import Path
|
||||
import json
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
# import pandas as pd
|
||||
from pydub import AudioSegment
|
||||
|
||||
# from .jasper_client import transcriber_pretrained, transcriber_speller
|
||||
from jasper.data_utils.validation.jasper_client import (
|
||||
transcriber_pretrained,
|
||||
transcriber_speller,
|
||||
)
|
||||
from jasper.data_utils.utils import alnum_to_asr_tokens
|
||||
|
||||
# import importlib
|
||||
# import jasper.data_utils.utils
|
||||
# importlib.reload(jasper.data_utils.utils)
|
||||
from jasper.data_utils.utils import asr_manifest_reader, asr_manifest_writer
|
||||
from nemo.collections.asr.metrics import word_error_rate
|
||||
|
||||
# from tqdm import tqdm as tqdm_base
|
||||
import typer
|
||||
from tqdm import tqdm
|
||||
|
||||
from ..utils import (
|
||||
alnum_to_asr_tokens,
|
||||
ExtendedPath,
|
||||
asr_manifest_reader,
|
||||
asr_manifest_writer,
|
||||
get_mongo_conn,
|
||||
)
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
def preprocess_datapoint(idx, rel_root, sample, use_domain_asr):
|
||||
import matplotlib.pyplot as plt
|
||||
import librosa
|
||||
import librosa.display
|
||||
from pydub import AudioSegment
|
||||
from nemo.collections.asr.metrics import word_error_rate
|
||||
from jasper.data_utils.validation.jasper_client import (
|
||||
transcriber_pretrained,
|
||||
transcriber_speller,
|
||||
)
|
||||
|
||||
try:
|
||||
res = dict(sample)
|
||||
res["real_idx"] = idx
|
||||
audio_path = rel_root / Path(sample["audio_filepath"])
|
||||
res["audio_path"] = str(audio_path)
|
||||
res["spoken"] = alnum_to_asr_tokens(res["text"])
|
||||
res["utterance_id"] = audio_path.stem
|
||||
aud_seg = (
|
||||
AudioSegment.from_file_using_temporary_files(audio_path)
|
||||
.set_channels(1)
|
||||
.set_sample_width(2)
|
||||
.set_frame_rate(24000)
|
||||
)
|
||||
res["pretrained_asr"] = transcriber_pretrained(aud_seg.raw_data)
|
||||
res["pretrained_wer"] = word_error_rate([res["text"]], [res["pretrained_asr"]])
|
||||
if use_domain_asr:
|
||||
res["domain_asr"] = transcriber_speller(aud_seg.raw_data)
|
||||
res["domain_wer"] = word_error_rate(
|
||||
[res["spoken"]], [res["pretrained_asr"]]
|
||||
)
|
||||
wav_plot_path = (
|
||||
rel_root / Path("wav_plots") / Path(audio_path.name).with_suffix(".png")
|
||||
)
|
||||
if not wav_plot_path.exists():
|
||||
fig = plt.Figure()
|
||||
ax = fig.add_subplot()
|
||||
(y, sr) = librosa.load(audio_path)
|
||||
librosa.display.waveplot(y=y, sr=sr, ax=ax)
|
||||
with wav_plot_path.open("wb") as wav_plot_f:
|
||||
fig.set_tight_layout(True)
|
||||
fig.savefig(wav_plot_f, format="png", dpi=50)
|
||||
# fig.close()
|
||||
res["plot_path"] = str(wav_plot_path)
|
||||
return res
|
||||
except BaseException as e:
|
||||
print(f'failed on {idx}: {sample["audio_filepath"]} with {e}')
|
||||
|
||||
|
||||
@app.command()
|
||||
def dump_validation_ui_data(
|
||||
data_manifest_path: Path = Path("./data/asr_data/call_alphanum/manifest.json"),
|
||||
dump_path: Path = Path("./data/valiation_data/ui_dump.json"),
|
||||
use_domain_asr: bool = True,
|
||||
):
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from functools import partial
|
||||
|
||||
plot_dir = data_manifest_path.parent / Path("wav_plots")
|
||||
plot_dir.mkdir(parents=True, exist_ok=True)
|
||||
typer.echo(f"Using data manifest:{data_manifest_path}")
|
||||
with data_manifest_path.open("r") as pf:
|
||||
pnr_jsonl = pf.readlines()
|
||||
pnr_funcs = [
|
||||
partial(
|
||||
preprocess_datapoint,
|
||||
i,
|
||||
data_manifest_path.parent,
|
||||
json.loads(v),
|
||||
use_domain_asr,
|
||||
)
|
||||
for i, v in enumerate(pnr_jsonl)
|
||||
]
|
||||
|
||||
def exec_func(f):
|
||||
return f()
|
||||
|
||||
with ThreadPoolExecutor(max_workers=20) as exe:
|
||||
print("starting all plot tasks")
|
||||
pnr_data = filter(
|
||||
None,
|
||||
list(
|
||||
tqdm(
|
||||
exe.map(exec_func, pnr_funcs),
|
||||
position=0,
|
||||
leave=True,
|
||||
total=len(pnr_funcs),
|
||||
)
|
||||
),
|
||||
)
|
||||
wer_key = "domain_wer" if use_domain_asr else "pretrained_wer"
|
||||
result = sorted(pnr_data, key=lambda x: x[wer_key], reverse=True)
|
||||
ui_config = {"use_domain_asr": use_domain_asr, "data": result}
|
||||
ExtendedPath(dump_path).write_json(ui_config)
|
||||
|
||||
|
||||
@app.command()
|
||||
def dump_corrections(dump_path: Path = Path("./data/corrections.json")):
|
||||
col = pymongo.MongoClient("mongodb://localhost:27017/").test.asr_validation
|
||||
col = get_mongo_conn().test.asr_validation
|
||||
|
||||
cursor_obj = col.find({"type": "correction"}, projection={"_id": False})
|
||||
corrections = [c for c in cursor_obj]
|
||||
dump_f = dump_path.open("w")
|
||||
json.dump(corrections, dump_f, indent=2)
|
||||
dump_f.close()
|
||||
|
||||
|
||||
def preprocess_datapoint(idx, rel, sample):
|
||||
res = dict(sample)
|
||||
res["real_idx"] = idx
|
||||
audio_path = rel / Path(sample["audio_filepath"])
|
||||
res["audio_path"] = str(audio_path)
|
||||
res["gold_chars"] = audio_path.stem
|
||||
res["gold_phone"] = sample["text"]
|
||||
aud_seg = (
|
||||
AudioSegment.from_wav(audio_path)
|
||||
.set_channels(1)
|
||||
.set_sample_width(2)
|
||||
.set_frame_rate(24000)
|
||||
)
|
||||
res["pretrained_asr"] = transcriber_pretrained(aud_seg.raw_data)
|
||||
res["speller_asr"] = transcriber_speller(aud_seg.raw_data)
|
||||
res["wer"] = word_error_rate([res["gold_phone"]], [res["speller_asr"]])
|
||||
return res
|
||||
|
||||
|
||||
def load_dataset(data_manifest_path: Path):
|
||||
typer.echo(f"Using data manifest:{data_manifest_path}")
|
||||
with data_manifest_path.open("r") as pf:
|
||||
pnr_jsonl = pf.readlines()
|
||||
pnr_data = [
|
||||
preprocess_datapoint(i, data_manifest_path.parent, json.loads(v))
|
||||
for i, v in enumerate(tqdm(pnr_jsonl, position=0, leave=True))
|
||||
]
|
||||
result = sorted(pnr_data, key=lambda x: x["wer"], reverse=True)
|
||||
return result
|
||||
|
||||
|
||||
@app.command()
|
||||
def dump_processed_data(
|
||||
data_manifest_path: Path = Path("./data/asr_data/call_alphanum/manifest.json"),
|
||||
dump_path: Path = Path("./data/processed_data.json"),
|
||||
):
|
||||
typer.echo(f"Using data manifest:{data_manifest_path}")
|
||||
with data_manifest_path.open("r") as pf:
|
||||
pnr_jsonl = pf.readlines()
|
||||
pnr_data = [
|
||||
preprocess_datapoint(i, data_manifest_path.parent, json.loads(v))
|
||||
for i, v in enumerate(tqdm(pnr_jsonl, position=0, leave=True))
|
||||
]
|
||||
result = sorted(pnr_data, key=lambda x: x["wer"], reverse=True)
|
||||
dump_path = Path("./data/processed_data.json")
|
||||
dump_f = dump_path.open("w")
|
||||
json.dump(result, dump_f, indent=2)
|
||||
dump_f.close()
|
||||
ExtendedPath(dump_path).write_json(corrections)
|
||||
|
||||
|
||||
@app.command()
|
||||
def fill_unannotated(
|
||||
processed_data_path: Path = Path("./data/processed_data.json"),
|
||||
corrections_path: Path = Path("./data/corrections.json"),
|
||||
processed_data_path: Path = Path("./data/valiation_data/ui_dump.json"),
|
||||
corrections_path: Path = Path("./data/valiation_data/corrections.json"),
|
||||
):
|
||||
processed_data = json.load(processed_data_path.open())
|
||||
corrections = json.load(corrections_path.open())
|
||||
annotated_codes = {c["code"] for c in corrections}
|
||||
all_codes = {c["gold_chars"] for c in processed_data}
|
||||
unann_codes = all_codes - annotated_codes
|
||||
mongo_conn = pymongo.MongoClient("mongodb://localhost:27017/").test.asr_validation
|
||||
mongo_conn = get_mongo_conn().test.asr_validation
|
||||
for c in unann_codes:
|
||||
mongo_conn.find_one_and_update(
|
||||
{"type": "correction", "code": c},
|
||||
|
|
@ -111,8 +143,8 @@ def fill_unannotated(
|
|||
@app.command()
|
||||
def update_corrections(
|
||||
data_manifest_path: Path = Path("./data/asr_data/call_alphanum/manifest.json"),
|
||||
processed_data_path: Path = Path("./data/processed_data.json"),
|
||||
corrections_path: Path = Path("./data/corrections.json"),
|
||||
processed_data_path: Path = Path("./data/valiation_data/ui_dump.json"),
|
||||
corrections_path: Path = Path("./data/valiation_data/corrections.json"),
|
||||
):
|
||||
def correct_manifest(manifest_data_gen, corrections_path):
|
||||
corrections = json.load(corrections_path.open())
|
||||
|
|
@ -168,6 +200,12 @@ def update_corrections(
|
|||
new_data_manifest_path.replace(data_manifest_path)
|
||||
|
||||
|
||||
@app.command()
|
||||
def clear_mongo_corrections():
|
||||
col = get_mongo_conn().test.asr_validation
|
||||
col.delete_many({"type": "correction"})
|
||||
|
||||
|
||||
def main():
|
||||
app()
|
||||
|
||||
|
|
|
|||
|
|
@ -1,27 +1,15 @@
|
|||
import json
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
||||
import streamlit as st
|
||||
from nemo.collections.asr.metrics import word_error_rate
|
||||
import librosa
|
||||
import librosa.display
|
||||
import matplotlib.pyplot as plt
|
||||
from tqdm import tqdm
|
||||
from pydub import AudioSegment
|
||||
import pymongo
|
||||
import typer
|
||||
from .jasper_client import transcriber_pretrained, transcriber_speller
|
||||
from ..utils import ExtendedPath, get_mongo_conn
|
||||
from .st_rerun import rerun
|
||||
|
||||
app = typer.Typer()
|
||||
st.title("ASR Speller Validation")
|
||||
|
||||
|
||||
if not hasattr(st, "mongo_connected"):
|
||||
st.mongoclient = pymongo.MongoClient(
|
||||
"mongodb://localhost:27017/"
|
||||
).test.asr_validation
|
||||
st.mongoclient = get_mongo_conn().test.asr_validation
|
||||
mongo_conn = st.mongoclient
|
||||
|
||||
def current_cursor_fn():
|
||||
|
|
@ -63,80 +51,49 @@ if not hasattr(st, "mongo_connected"):
|
|||
st.mongo_connected = True
|
||||
|
||||
|
||||
# def clear_mongo_corrections():
|
||||
# col = pymongo.MongoClient("mongodb://localhost:27017/").test.asr_validation
|
||||
# col.delete_many({"type": "correction"})
|
||||
|
||||
|
||||
def preprocess_datapoint(idx, rel, sample):
|
||||
res = dict(sample)
|
||||
res["real_idx"] = idx
|
||||
audio_path = rel / Path(sample["audio_filepath"])
|
||||
res["audio_path"] = audio_path
|
||||
res["gold_chars"] = audio_path.stem
|
||||
aud_seg = (
|
||||
AudioSegment.from_wav(audio_path)
|
||||
.set_channels(1)
|
||||
.set_sample_width(2)
|
||||
.set_frame_rate(24000)
|
||||
)
|
||||
res["pretrained_asr"] = transcriber_pretrained(aud_seg.raw_data)
|
||||
res["speller_asr"] = transcriber_speller(aud_seg.raw_data)
|
||||
res["wer"] = word_error_rate([res["text"]], [res["speller_asr"]])
|
||||
(y, sr) = librosa.load(audio_path)
|
||||
plt.tight_layout()
|
||||
librosa.display.waveplot(y=y, sr=sr)
|
||||
wav_plot_f = BytesIO()
|
||||
plt.savefig(wav_plot_f, format="png", dpi=50)
|
||||
plt.close()
|
||||
wav_plot_f.seek(0)
|
||||
res["plot_png"] = wav_plot_f
|
||||
return res
|
||||
|
||||
|
||||
@st.cache(hash_funcs={"rpyc.core.netref.builtins.method": lambda _: None})
|
||||
def preprocess_dataset(data_manifest_path: Path):
|
||||
typer.echo(f"Using data manifest:{data_manifest_path}")
|
||||
with data_manifest_path.open("r") as pf:
|
||||
pnr_jsonl = pf.readlines()
|
||||
pnr_data = [
|
||||
preprocess_datapoint(i, data_manifest_path.parent, json.loads(v))
|
||||
for i, v in enumerate(tqdm(pnr_jsonl))
|
||||
]
|
||||
result = sorted(pnr_data, key=lambda x: x["wer"], reverse=True)
|
||||
return result
|
||||
@st.cache()
|
||||
def load_ui_data(validation_ui_data_path: Path):
|
||||
typer.echo(f"Using validation ui data from :{validation_ui_data_path}")
|
||||
return ExtendedPath(validation_ui_data_path).read_json()
|
||||
|
||||
|
||||
@app.command()
|
||||
def main(manifest: Path):
|
||||
pnr_data = preprocess_dataset(manifest)
|
||||
ui_config = load_ui_data(manifest)
|
||||
asr_data = ui_config["data"]
|
||||
use_domain_asr = ui_config["use_domain_asr"]
|
||||
sample_no = st.get_current_cursor()
|
||||
sample = pnr_data[sample_no]
|
||||
st.markdown(
|
||||
f"{sample_no+1} of {len(pnr_data)} : **{sample['gold_chars']}** spelled *{sample['text']}*"
|
||||
sample = asr_data[sample_no]
|
||||
title_type = 'Speller ' if use_domain_asr else ''
|
||||
st.title(f"ASR {title_type}Validation")
|
||||
addl_text = (
|
||||
f"spelled *{sample['spoken']}*" if use_domain_asr else ""
|
||||
)
|
||||
st.markdown(f"{sample_no+1} of {len(asr_data)} : **{sample['text']}**" + addl_text)
|
||||
new_sample = st.number_input(
|
||||
"Go To Sample:", value=sample_no + 1, min_value=1, max_value=len(pnr_data)
|
||||
"Go To Sample:", value=sample_no + 1, min_value=1, max_value=len(asr_data)
|
||||
)
|
||||
if new_sample != sample_no + 1:
|
||||
st.update_cursor(new_sample - 1)
|
||||
st.sidebar.title(f"Details: [{sample['real_idx']}]")
|
||||
st.sidebar.markdown(f"Gold: **{sample['gold_chars']}**")
|
||||
st.sidebar.markdown(f"Expected Speech: *{sample['text']}*")
|
||||
st.sidebar.markdown(f"Gold Text: **{sample['text']}**")
|
||||
if use_domain_asr:
|
||||
st.sidebar.markdown(f"Expected Spelled: *{sample['spoken']}*")
|
||||
st.sidebar.title("Results:")
|
||||
st.sidebar.text(f"Pretrained:{sample['pretrained_asr']}")
|
||||
st.sidebar.text(f"Speller:{sample['speller_asr']}")
|
||||
|
||||
st.sidebar.title(f"Speller WER: {sample['wer']:.2f}%")
|
||||
# (y, sr) = librosa.load(sample["audio_path"])
|
||||
# librosa.display.waveplot(y=y, sr=sr)
|
||||
# st.sidebar.pyplot(fig=sample["plot_fig"])
|
||||
st.sidebar.image(sample["plot_png"])
|
||||
st.audio(sample["audio_path"].open("rb"))
|
||||
corrected = sample["gold_chars"]
|
||||
correction_entry = st.get_correction_entry(sample["gold_chars"])
|
||||
st.sidebar.markdown(f"Pretrained: **{sample['pretrained_asr']}**")
|
||||
if use_domain_asr:
|
||||
st.sidebar.markdown(f"Domain: **{sample['domain_asr']}**")
|
||||
st.sidebar.title(f"Speller WER: {sample['domain_wer']:.2f}%")
|
||||
else:
|
||||
st.sidebar.title(f"Pretrained WER: {sample['pretrained_wer']:.2f}%")
|
||||
st.sidebar.image(Path(sample["plot_path"]).read_bytes())
|
||||
st.audio(Path(sample["audio_path"]).open("rb"))
|
||||
# set default to text
|
||||
corrected = sample["text"]
|
||||
correction_entry = st.get_correction_entry(sample["utterance_id"])
|
||||
selected_idx = 0
|
||||
options = ("Correct", "Incorrect", "Inaudible")
|
||||
# if correction entry is present set the corresponding ui defaults
|
||||
if correction_entry:
|
||||
selected_idx = options.index(correction_entry["value"]["status"])
|
||||
corrected = correction_entry["value"]["correction"]
|
||||
|
|
@ -148,24 +105,26 @@ def main(manifest: Path):
|
|||
if st.button("Submit"):
|
||||
correct_code = corrected.replace(" ", "").upper()
|
||||
st.update_entry(
|
||||
sample["gold_chars"], {"status": selected, "correction": correct_code}
|
||||
sample["utterance_id"], {"status": selected, "correction": correct_code}
|
||||
)
|
||||
st.update_cursor(sample_no + 1)
|
||||
if correction_entry:
|
||||
st.markdown(
|
||||
f'Your Response: **{correction_entry["value"]["status"]}** Correction: **{correction_entry["value"]["correction"]}**'
|
||||
)
|
||||
# real_idx = st.text_input("Go to real-index:", value=sample['real_idx'])
|
||||
# st.markdown(
|
||||
# ",".join(
|
||||
# [
|
||||
# "**" + str(p["real_idx"]) + "**"
|
||||
# if p["real_idx"] == sample["real_idx"]
|
||||
# else str(p["real_idx"])
|
||||
# for p in pnr_data
|
||||
# ]
|
||||
# )
|
||||
# )
|
||||
# if st.button("Previous Untagged"):
|
||||
# pass
|
||||
# if st.button("Next Untagged"):
|
||||
# pass
|
||||
real_idx = st.number_input(
|
||||
"Go to real-index",
|
||||
value=sample["real_idx"],
|
||||
min_value=0,
|
||||
max_value=len(asr_data) - 1,
|
||||
)
|
||||
if real_idx != int(sample["real_idx"]):
|
||||
idx = [i for (i, p) in enumerate(asr_data) if p["real_idx"] == real_idx][0]
|
||||
st.update_cursor(idx)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
7
setup.py
7
setup.py
|
|
@ -22,13 +22,19 @@ extra_requirements = {
|
|||
"matplotlib==3.2.1",
|
||||
"pandas==1.0.3",
|
||||
"tabulate==0.8.7",
|
||||
"natural==0.2.0",
|
||||
"num2words==0.5.10",
|
||||
"typer[all]==0.1.1",
|
||||
"python-slugify==4.0.0",
|
||||
"lenses @ git+https://github.com/ingolemo/python-lenses.git@b2a2a9aa5b61540992d70b2cf36008d0121e8948#egg=lenses",
|
||||
],
|
||||
"validation": [
|
||||
"rpyc~=4.1.4",
|
||||
"pymongo==3.10.1",
|
||||
"typer[all]==0.1.1",
|
||||
"tqdm~=4.39.0",
|
||||
"librosa==0.7.2",
|
||||
"matplotlib==3.2.1",
|
||||
"pydub~=0.23.1",
|
||||
"streamlit==0.58.0",
|
||||
"stringcase==1.2.0"
|
||||
|
|
@ -58,6 +64,7 @@ setup(
|
|||
"jasper_asr_trainer = jasper.train:main",
|
||||
"jasper_asr_data_generate = jasper.data_utils.generator:main",
|
||||
"jasper_asr_data_recycle = jasper.data_utils.call_recycler:main",
|
||||
"jasper_asr_data_validation = jasper.data_utils.validation.process:main",
|
||||
"jasper_asr_data_preprocess = jasper.data_utils.process:main",
|
||||
]
|
||||
},
|
||||
|
|
|
|||
Loading…
Reference in New Issue