1. added a tool to extract ASR data from GCP transcript logs
2. implemented a function to export all call logs in a MongoDB collection to a caller-ID based YAML file
3. cleaned up the leaderboard duration logic
4. added a WIP dataloader service
5. made the asr_data_writer util more generic, with a verbose flag and unique filenames
6. added an ExtendedPath util class with JSON support and a get_mongo_conn function for connecting to a Mongo node
7. refactored the validation post-processing to dump a UI config for validation
8. included utility functions to correct, fill, update, and clear annotations in MongoDB
9. refactored the UI logic to be generic over any ASR data
10. updated setup.py dependencies to support the above features
Malar Kannan 2020-05-12 23:38:06 +05:30
parent a7da729c0b
commit c06a0814b9
7 changed files with 365 additions and 191 deletions

View File

@@ -0,0 +1,93 @@
import typer
from itertools import chain
from io import BytesIO
from pathlib import Path
app = typer.Typer()
@app.command()
def extract_data(
call_audio_dir: Path = Path("/dataset/png_prod/call_audio"),
call_meta_dir: Path = Path("/dataset/png_prod/call_metadata"),
output_dir: Path = Path("./data"),
dataset_name: str = "png_gcp_2jan",
verbose: bool = False,
):
from pydub import AudioSegment
from .utils import ExtendedPath, asr_data_writer
from lenses import lens
call_asr_data: Path = output_dir / Path("asr_data")
call_asr_data.mkdir(exist_ok=True, parents=True)
def wav_event_generator(call_audio_dir):
for wav_path in call_audio_dir.glob("**/*.wav"):
if verbose:
typer.echo(f"loading events for file {wav_path}")
call_wav = AudioSegment.from_file_using_temporary_files(wav_path)
rel_meta_path = wav_path.with_suffix(".json").relative_to(call_audio_dir)
meta_path = call_meta_dir / rel_meta_path
events = ExtendedPath(meta_path).read_json()
yield call_wav, wav_path, events
def contains_asr(x):
return "AsrResult" in x
# events missing a Channel field are attributed to channel 0 (mono/legacy logs)
def channel(n):
def filter_func(ev):
return (
ev["AsrResult"]["Channel"] == n
if "Channel" in ev["AsrResult"]
else n == 0
)
return filter_func
def compute_endtime(call_wav, state):
for (i, st) in enumerate(state):
start_time = st["AsrResult"]["Alternatives"][0].get("StartTime", 0)
transcript = st["AsrResult"]["Alternatives"][0]["Transcript"]
# an event ends where the next one starts; the last event runs to the end of the call
if i + 1 < len(state):
end_time = state[i + 1]["AsrResult"]["Alternatives"][0]["StartTime"]
else:
end_time = call_wav.duration_seconds
code_seg = call_wav[start_time * 1000 : end_time * 1000]
code_fb = BytesIO()
code_seg.export(code_fb, format="wav")
code_wav = code_fb.getvalue()
# only yield the segment if some meaningful audio (>= 0.5s) is present
if code_seg.duration_seconds >= 0.5:
yield transcript, code_seg.duration_seconds, code_wav
def asr_data_generator(call_wav, call_wav_fname, events):
call_wav_0, call_wav_1 = call_wav.split_to_mono()
asr_events = lens["Events"].Each()["Event"].Filter(contains_asr)
call_evs_0 = asr_events.Filter(channel(0)).collect()(events)
call_evs_1 = asr_events.Filter(channel(1)).collect()(events)
if verbose:
typer.echo(f"processing data points on {call_wav_fname}")
call_data_0 = compute_endtime(call_wav_0, call_evs_0)
call_data_1 = compute_endtime(call_wav_1, call_evs_1)
return chain(call_data_0, call_data_1)
def generate_call_asr_data():
full_asr_data = []
total_duration = 0
for wav, wav_path, ev in wav_event_generator(call_audio_dir):
asr_data = asr_data_generator(wav, wav_path, ev)
total_duration += wav.duration_seconds
full_asr_data.append(asr_data)
typer.echo(f"loaded {len(full_asr_data)} calls of duration {total_duration}s")
n_dps = asr_data_writer(call_asr_data, dataset_name, chain(*full_asr_data))
typer.echo(f"written {n_dps} data points")
generate_call_asr_data()
def main():
app()
if __name__ == "__main__":
main()
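The lens traversal above assumes a specific shape for the GCP transcript metadata JSON. A minimal sketch of that shape, inferred only from the accessors in contains_asr, channel, and compute_endtime (all field values are illustrative):

# Hedged sketch of the metadata JSON this tool expects; only the fields the
# code actually reads are shown, everything else is omitted.
example_events = {
    "Events": [
        {
            "Event": {
                "AsrResult": {
                    "Channel": 0,  # may be absent; such events are treated as channel 0
                    "Alternatives": [
                        # StartTime may be absent on the first event; it defaults to 0
                        {"StartTime": 1.25, "Transcript": "hello"}
                    ],
                }
            }
        }
    ]
}

# The lens pipeline then collects per-channel ASR events, e.g.:
# lens["Events"].Each()["Event"].Filter(contains_asr).Filter(channel(0)).collect()(example_events)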

View File

@@ -11,6 +11,29 @@ app = typer.Typer()
# app.add_typer(plot_app, name="plot")
@app.command()
def export_logs(call_logs_file: Path = Path("./call_sia_logs.yaml")):
from pymongo import MongoClient
from collections import defaultdict
from ruamel.yaml import YAML
yaml = YAML()
mongo_collection = MongoClient("mongodb://localhost:27017/").test.calls
caller_calls = defaultdict(list)
for call in mongo_collection.find():
sysid = call["SystemID"]
call_uri = f"http://sia-data.agaralabs.com/calls/{sysid}"
caller = call["Caller"]
caller_calls[caller].append(call_uri)
caller_list = []
for caller in caller_calls:
caller_list.append({"name": caller, "calls": caller_calls[caller]})
output_yaml = {"users": caller_list}
typer.echo("exporting call logs to yaml file")
with call_logs_file.open("w") as yf:
yaml.dump(output_yaml, yf)
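For reference, the export groups call URIs under each caller. A hedged sketch of the resulting YAML (caller names and SystemIDs are made up):

# call_sia_logs.yaml (illustrative):
# users:
# - name: '+14155550123'
#   calls:
#   - http://sia-data.agaralabs.com/calls/abc123
#   - http://sia-data.agaralabs.com/calls/def456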
@app.command()
def analyze(
leaderboard: bool = False,
@@ -19,8 +42,6 @@ def analyze(
call_logs_file: Path = Path("./call_logs.yaml"),
output_dir: Path = Path("./data"),
):
call_logs_file = Path("./call_logs.yaml")
output_dir = Path("./data")
from urllib.parse import urlsplit
from functools import reduce
@@ -35,7 +56,6 @@ def analyze(
from datetime import timedelta
# from concurrent.futures import ThreadPoolExecutor
from dateutil.relativedelta import relativedelta
import librosa
import librosa.display
from lenses import lens
@@ -46,6 +66,8 @@ def analyze(
from tqdm import tqdm
from .utils import asr_data_writer
from pydub import AudioSegment
from natural.date import compress
# from itertools import product, chain
matplotlib.rcParams["agg.path.chunksize"] = 10000
@@ -256,8 +278,11 @@ def analyze(
code_fb = BytesIO()
code_seg.export(code_fb, format="wav")
code_wav = code_fb.getvalue()
# import pdb; pdb.set_trace()
yield code, code_seg.duration_seconds, code_wav
# search for the actual PNR code (quoted in the transcript) and handle plain codes as well
extracted_code = (
re.search(r"'(.*)'", code).groups(0)[0] if len(code) > 6 else code
)
yield extracted_code, code_seg.duration_seconds, code_wav
call_lens = lens["users"].Each()["calls"].Each()
call_stats = call_lens.modify(retrieve_callmeta)(call_logs)
@@ -275,22 +300,17 @@ def analyze(
asr_data_writer(call_asr_data, "call_alphanum", data_source())
# @leader_app.command()
def show_leaderboard():
def compute_user_stats(call_stat):
n_samples = (
lens["calls"].Each()["process"]["num_samples"].get_monoid()(call_stat)
)
n_duration = lens["calls"].Each()["duration"].get_monoid()(call_stat)
rel_dur = relativedelta(
seconds=int(n_duration.total_seconds()),
microseconds=n_duration.microseconds,
)
return {
"num_samples": n_samples,
"duration": n_duration.total_seconds(),
"samples_rate": n_samples / n_duration.total_seconds(),
"duration_str": f"{rel_dur.minutes} mins {rel_dur.seconds} secs",
"duration_str": compress(n_duration, pad=" "),
"name": call_stat["name"],
}
@@ -313,8 +333,8 @@ def analyze(
}
)[["Rank", "Name", "Codes", "Duration"]]
print(
"""Today's ASR Speller Dataset Leaderboard:
----------------------------------------"""
"""ASR Speller Dataset Leaderboard :
---------------------------------"""
)
print(leader_board.to_string(index=False))
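The duration_str change above swaps the hand-rolled relativedelta formatting for natural's compress helper. A quick sketch of what it produces (output format recalled from the natural docs, so treat as approximate):

from datetime import timedelta
from natural.date import compress

# compress() renders a timedelta as a compact human string, e.g. "4m30s";
# pad=" " inserts a space between the units, e.g. "4m 30s".
print(compress(timedelta(minutes=4, seconds=30), pad=" "))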

View File

@@ -0,0 +1,29 @@
import typer
import rpyc
import os
from pathlib import Path
from rpyc.utils.server import ThreadedServer
app = typer.Typer()
class ASRDataService(rpyc.Service):
def get_data_loader(self, data_manifest: Path):
return "hello"
@app.command()
def run_server(port: int = 0):
listen_port = port if port else int(os.environ.get("ASR_RPYC_PORT", "8044"))
service = ASRDataService()
t = ThreadedServer(service, port=listen_port)
typer.echo(f"starting asr server on {listen_port}...")
t.start()
def main():
app()
if __name__ == "__main__":
main()
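A hedged client-side sketch for this WIP service. Note that rpyc only exposes attributes prefixed with `exposed_` by default, so `get_data_loader` would need renaming (or the server would need `allow_public_attrs`) before this works as written:

import rpyc

# connect to the dataloader service started by run_server (default port 8044)
conn = rpyc.connect("localhost", 8044)
# remote call against the service root; returns "hello" with the current stub
loader = conn.root.get_data_loader("/dataset/manifest.json")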

View File

@@ -1,8 +1,13 @@
import numpy as np
import wave
import io
import os
import json
from pathlib import Path
import pymongo
from slugify import slugify
from uuid import uuid4
from num2words import num2words
@@ -46,42 +51,65 @@ def alnum_to_asr_tokens(text):
return ("".join(num_tokens)).lower()
def asr_data_writer(output_dir, dataset_name, asr_data_source):
def asr_data_writer(output_dir, dataset_name, asr_data_source, verbose=False):
dataset_dir = output_dir / Path(dataset_name)
(dataset_dir / Path("wav")).mkdir(parents=True, exist_ok=True)
asr_manifest = dataset_dir / Path("manifest.json")
num_datapoints = 0
with asr_manifest.open("w") as mf:
for pnr_code, audio_dur, wav_data in asr_data_source:
pnr_af = dataset_dir / Path("wav") / Path(pnr_code).with_suffix(".wav")
for transcript, audio_dur, wav_data in asr_data_source:
fname = str(uuid4()) + "_" + slugify(transcript, max_length=8)
pnr_af = dataset_dir / Path("wav") / Path(fname).with_suffix(".wav")
pnr_af.write_bytes(wav_data)
rel_pnr_path = pnr_af.relative_to(dataset_dir)
manifest = manifest_str(
str(rel_pnr_path), audio_dur, alnum_to_asr_tokens(pnr_code)
)
manifest = manifest_str(str(rel_pnr_path), audio_dur, transcript)
mf.write(manifest)
if verbose:
print(f"writing '{transcript}' of duration {audio_dur}")
num_datapoints += 1
return num_datapoints
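A minimal sketch of driving the generalized writer with an in-memory source; each item is a (transcript, duration_seconds, wav_bytes) tuple. The input file and transcript here are hypothetical:

from pathlib import Path
from jasper.data_utils.utils import asr_data_writer

def toy_source():
    # hypothetical pre-exported payload; any bytes in RIFF/WAV format work
    wav_bytes = Path("sample.wav").read_bytes()
    yield "one two three", 2.5, wav_bytes

n_dps = asr_data_writer(Path("./data"), "toy_dataset", toy_source(), verbose=True)
print(f"wrote {n_dps} data points")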
def asr_manifest_reader(data_manifest_path: Path):
print(f'reading manifest from {data_manifest_path}')
print(f"reading manifest from {data_manifest_path}")
with data_manifest_path.open("r") as pf:
pnr_jsonl = pf.readlines()
pnr_data = [json.loads(v) for v in pnr_jsonl]
for p in pnr_data:
p['audio_path'] = data_manifest_path.parent / Path(p['audio_filepath'])
p['chars'] = Path(p['audio_filepath']).stem
p["audio_path"] = data_manifest_path.parent / Path(p["audio_filepath"])
p["chars"] = Path(p["audio_filepath"]).stem
yield p
def asr_manifest_writer(asr_manifest_path: Path, manifest_str_source):
with asr_manifest_path.open("w") as mf:
print(f'opening {asr_manifest_path} for writing manifest')
print(f"opening {asr_manifest_path} for writing manifest")
for mani_dict in manifest_str_source:
manifest = manifest_str(
mani_dict['audio_filepath'], mani_dict['duration'], mani_dict['text']
mani_dict["audio_filepath"], mani_dict["duration"], mani_dict["text"]
)
mf.write(manifest)
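manifest_str is defined elsewhere in this module; judging by the reader and writer above, each manifest line is one standalone JSON object. A hedged sketch of a single line (field values illustrative):

# {"audio_filepath": "wav/<uuid>_one-two.wav", "duration": 2.5, "text": "one two three"}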
class ExtendedPath(type(Path())):
"""docstring for ExtendedPath."""
def read_json(self):
with self.open("r") as jf:
return json.load(jf)
def write_json(self, data):
self.parent.mkdir(parents=True, exist_ok=True)
with self.open("w") as jf:
return json.dump(data, jf, indent=2)
def get_mongo_conn(host=''):
mongo_host = host if host else os.environ.get("MONGO_HOST", "localhost")
mongo_uri = f"mongodb://{mongo_host}:27017/"
return pymongo.MongoClient(mongo_uri)
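A small usage sketch for the two new helpers (paths and collection names are illustrative):

from jasper.data_utils.utils import ExtendedPath, get_mongo_conn

cfg = ExtendedPath("./data/validation_data/ui_dump.json")
cfg.write_json({"use_domain_asr": True, "data": []})  # creates parent dirs as needed
print(cfg.read_json()["use_domain_asr"])

# honours MONGO_HOST when set, falls back to localhost
col = get_mongo_conn().test.asr_validation
col.find_one({"type": "correction"})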
def main():
for c in random_pnr_generator():
print(c)

View File

@@ -1,105 +1,137 @@
import pymongo
import typer
# import matplotlib.pyplot as plt
from pathlib import Path
import json
import shutil
from pathlib import Path
# import pandas as pd
from pydub import AudioSegment
# from .jasper_client import transcriber_pretrained, transcriber_speller
from jasper.data_utils.validation.jasper_client import (
transcriber_pretrained,
transcriber_speller,
)
from jasper.data_utils.utils import alnum_to_asr_tokens
# import importlib
# import jasper.data_utils.utils
# importlib.reload(jasper.data_utils.utils)
from jasper.data_utils.utils import asr_manifest_reader, asr_manifest_writer
from nemo.collections.asr.metrics import word_error_rate
# from tqdm import tqdm as tqdm_base
import typer
from tqdm import tqdm
from ..utils import (
alnum_to_asr_tokens,
ExtendedPath,
asr_manifest_reader,
asr_manifest_writer,
get_mongo_conn,
)
app = typer.Typer()
def preprocess_datapoint(idx, rel_root, sample, use_domain_asr):
import matplotlib.pyplot as plt
import librosa
import librosa.display
from pydub import AudioSegment
from nemo.collections.asr.metrics import word_error_rate
from jasper.data_utils.validation.jasper_client import (
transcriber_pretrained,
transcriber_speller,
)
try:
res = dict(sample)
res["real_idx"] = idx
audio_path = rel_root / Path(sample["audio_filepath"])
res["audio_path"] = str(audio_path)
res["spoken"] = alnum_to_asr_tokens(res["text"])
res["utterance_id"] = audio_path.stem
aud_seg = (
AudioSegment.from_file_using_temporary_files(audio_path)
.set_channels(1)
.set_sample_width(2)
.set_frame_rate(24000)
)
res["pretrained_asr"] = transcriber_pretrained(aud_seg.raw_data)
res["pretrained_wer"] = word_error_rate([res["text"]], [res["pretrained_asr"]])
if use_domain_asr:
res["domain_asr"] = transcriber_speller(aud_seg.raw_data)
res["domain_wer"] = word_error_rate(
[res["spoken"]], [res["pretrained_asr"]]
)
wav_plot_path = (
rel_root / Path("wav_plots") / Path(audio_path.name).with_suffix(".png")
)
if not wav_plot_path.exists():
fig = plt.Figure()
ax = fig.add_subplot()
(y, sr) = librosa.load(audio_path)
librosa.display.waveplot(y=y, sr=sr, ax=ax)
with wav_plot_path.open("wb") as wav_plot_f:
fig.set_tight_layout(True)
fig.savefig(wav_plot_f, format="png", dpi=50)
# fig.close()
res["plot_path"] = str(wav_plot_path)
return res
except BaseException as e:
print(f'failed on {idx}: {sample["audio_filepath"]} with {e}')
@app.command()
def dump_validation_ui_data(
data_manifest_path: Path = Path("./data/asr_data/call_alphanum/manifest.json"),
dump_path: Path = Path("./data/valiation_data/ui_dump.json"),
use_domain_asr: bool = True,
):
from concurrent.futures import ThreadPoolExecutor
from functools import partial
plot_dir = data_manifest_path.parent / Path("wav_plots")
plot_dir.mkdir(parents=True, exist_ok=True)
typer.echo(f"Using data manifest:{data_manifest_path}")
with data_manifest_path.open("r") as pf:
pnr_jsonl = pf.readlines()
pnr_funcs = [
partial(
preprocess_datapoint,
i,
data_manifest_path.parent,
json.loads(v),
use_domain_asr,
)
for i, v in enumerate(pnr_jsonl)
]
def exec_func(f):
return f()
with ThreadPoolExecutor(max_workers=20) as exe:
print("starting all plot tasks")
pnr_data = filter(
None,
list(
tqdm(
exe.map(exec_func, pnr_funcs),
position=0,
leave=True,
total=len(pnr_funcs),
)
),
)
wer_key = "domain_wer" if use_domain_asr else "pretrained_wer"
result = sorted(pnr_data, key=lambda x: x[wer_key], reverse=True)
ui_config = {"use_domain_asr": use_domain_asr, "data": result}
ExtendedPath(dump_path).write_json(ui_config)
@app.command()
def dump_corrections(dump_path: Path = Path("./data/corrections.json")):
col = pymongo.MongoClient("mongodb://localhost:27017/").test.asr_validation
col = get_mongo_conn().test.asr_validation
cursor_obj = col.find({"type": "correction"}, projection={"_id": False})
corrections = [c for c in cursor_obj]
dump_f = dump_path.open("w")
json.dump(corrections, dump_f, indent=2)
dump_f.close()
def preprocess_datapoint(idx, rel, sample):
res = dict(sample)
res["real_idx"] = idx
audio_path = rel / Path(sample["audio_filepath"])
res["audio_path"] = str(audio_path)
res["gold_chars"] = audio_path.stem
res["gold_phone"] = sample["text"]
aud_seg = (
AudioSegment.from_wav(audio_path)
.set_channels(1)
.set_sample_width(2)
.set_frame_rate(24000)
)
res["pretrained_asr"] = transcriber_pretrained(aud_seg.raw_data)
res["speller_asr"] = transcriber_speller(aud_seg.raw_data)
res["wer"] = word_error_rate([res["gold_phone"]], [res["speller_asr"]])
return res
def load_dataset(data_manifest_path: Path):
typer.echo(f"Using data manifest:{data_manifest_path}")
with data_manifest_path.open("r") as pf:
pnr_jsonl = pf.readlines()
pnr_data = [
preprocess_datapoint(i, data_manifest_path.parent, json.loads(v))
for i, v in enumerate(tqdm(pnr_jsonl, position=0, leave=True))
]
result = sorted(pnr_data, key=lambda x: x["wer"], reverse=True)
return result
@app.command()
def dump_processed_data(
data_manifest_path: Path = Path("./data/asr_data/call_alphanum/manifest.json"),
dump_path: Path = Path("./data/processed_data.json"),
):
typer.echo(f"Using data manifest:{data_manifest_path}")
with data_manifest_path.open("r") as pf:
pnr_jsonl = pf.readlines()
pnr_data = [
preprocess_datapoint(i, data_manifest_path.parent, json.loads(v))
for i, v in enumerate(tqdm(pnr_jsonl, position=0, leave=True))
]
result = sorted(pnr_data, key=lambda x: x["wer"], reverse=True)
dump_path = Path("./data/processed_data.json")
dump_f = dump_path.open("w")
json.dump(result, dump_f, indent=2)
dump_f.close()
ExtendedPath(dump_path).write_json(corrections)
@app.command()
def fill_unannotated(
processed_data_path: Path = Path("./data/processed_data.json"),
corrections_path: Path = Path("./data/corrections.json"),
processed_data_path: Path = Path("./data/valiation_data/ui_dump.json"),
corrections_path: Path = Path("./data/valiation_data/corrections.json"),
):
processed_data = json.load(processed_data_path.open())
corrections = json.load(corrections_path.open())
annotated_codes = {c["code"] for c in corrections}
all_codes = {c["gold_chars"] for c in processed_data}
unann_codes = all_codes - annotated_codes
mongo_conn = pymongo.MongoClient("mongodb://localhost:27017/").test.asr_validation
mongo_conn = get_mongo_conn().test.asr_validation
for c in unann_codes:
mongo_conn.find_one_and_update(
{"type": "correction", "code": c},
@@ -111,8 +143,8 @@ def fill_unannotated(
@app.command()
def update_corrections(
data_manifest_path: Path = Path("./data/asr_data/call_alphanum/manifest.json"),
processed_data_path: Path = Path("./data/processed_data.json"),
corrections_path: Path = Path("./data/corrections.json"),
processed_data_path: Path = Path("./data/valiation_data/ui_dump.json"),
corrections_path: Path = Path("./data/valiation_data/corrections.json"),
):
def correct_manifest(manifest_data_gen, corrections_path):
corrections = json.load(corrections_path.open())
@@ -168,6 +200,12 @@ def update_corrections(
new_data_manifest_path.replace(data_manifest_path)
@app.command()
def clear_mongo_corrections():
col = get_mongo_conn().test.asr_validation
col.delete_many({"type": "correction"})
def main():
app()
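Since this module is a Typer app, the annotation commands can also be exercised programmatically. A hedged sketch using Typer's test runner (shipped in recent Typer releases; command names are the kebab-cased function names Typer derives automatically):

from typer.testing import CliRunner

runner = CliRunner()
# dump the UI config for the validation app, then export corrections from Mongo
runner.invoke(app, ["dump-validation-ui-data"])
runner.invoke(app, ["dump-corrections", "--dump-path", "./data/validation_data/corrections.json"])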

View File

@@ -1,27 +1,15 @@
import json
from io import BytesIO
from pathlib import Path
import streamlit as st
from nemo.collections.asr.metrics import word_error_rate
import librosa
import librosa.display
import matplotlib.pyplot as plt
from tqdm import tqdm
from pydub import AudioSegment
import pymongo
import typer
from .jasper_client import transcriber_pretrained, transcriber_speller
from ..utils import ExtendedPath, get_mongo_conn
from .st_rerun import rerun
app = typer.Typer()
st.title("ASR Speller Validation")
if not hasattr(st, "mongo_connected"):
st.mongoclient = pymongo.MongoClient(
"mongodb://localhost:27017/"
).test.asr_validation
st.mongoclient = get_mongo_conn().test.asr_validation
mongo_conn = st.mongoclient
def current_cursor_fn():
@@ -63,80 +51,49 @@ if not hasattr(st, "mongo_connected"):
st.mongo_connected = True
# def clear_mongo_corrections():
# col = pymongo.MongoClient("mongodb://localhost:27017/").test.asr_validation
# col.delete_many({"type": "correction"})
def preprocess_datapoint(idx, rel, sample):
res = dict(sample)
res["real_idx"] = idx
audio_path = rel / Path(sample["audio_filepath"])
res["audio_path"] = audio_path
res["gold_chars"] = audio_path.stem
aud_seg = (
AudioSegment.from_wav(audio_path)
.set_channels(1)
.set_sample_width(2)
.set_frame_rate(24000)
)
res["pretrained_asr"] = transcriber_pretrained(aud_seg.raw_data)
res["speller_asr"] = transcriber_speller(aud_seg.raw_data)
res["wer"] = word_error_rate([res["text"]], [res["speller_asr"]])
(y, sr) = librosa.load(audio_path)
plt.tight_layout()
librosa.display.waveplot(y=y, sr=sr)
wav_plot_f = BytesIO()
plt.savefig(wav_plot_f, format="png", dpi=50)
plt.close()
wav_plot_f.seek(0)
res["plot_png"] = wav_plot_f
return res
@st.cache(hash_funcs={"rpyc.core.netref.builtins.method": lambda _: None})
def preprocess_dataset(data_manifest_path: Path):
typer.echo(f"Using data manifest:{data_manifest_path}")
with data_manifest_path.open("r") as pf:
pnr_jsonl = pf.readlines()
pnr_data = [
preprocess_datapoint(i, data_manifest_path.parent, json.loads(v))
for i, v in enumerate(tqdm(pnr_jsonl))
]
result = sorted(pnr_data, key=lambda x: x["wer"], reverse=True)
return result
@st.cache()
def load_ui_data(validation_ui_data_path: Path):
typer.echo(f"Using validation ui data from :{validation_ui_data_path}")
return ExtendedPath(validation_ui_data_path).read_json()
@app.command()
def main(manifest: Path):
pnr_data = preprocess_dataset(manifest)
ui_config = load_ui_data(manifest)
asr_data = ui_config["data"]
use_domain_asr = ui_config["use_domain_asr"]
sample_no = st.get_current_cursor()
sample = pnr_data[sample_no]
st.markdown(
f"{sample_no+1} of {len(pnr_data)} : **{sample['gold_chars']}** spelled *{sample['text']}*"
sample = asr_data[sample_no]
title_type = 'Speller ' if use_domain_asr else ''
st.title(f"ASR {title_type}Validation")
# leading space so the markdown line reads "... : **text** spelled *...*"
addl_text = (
f" spelled *{sample['spoken']}*" if use_domain_asr else ""
)
st.markdown(f"{sample_no+1} of {len(asr_data)} : **{sample['text']}**" + addl_text)
new_sample = st.number_input(
"Go To Sample:", value=sample_no + 1, min_value=1, max_value=len(pnr_data)
"Go To Sample:", value=sample_no + 1, min_value=1, max_value=len(asr_data)
)
if new_sample != sample_no + 1:
st.update_cursor(new_sample - 1)
st.sidebar.title(f"Details: [{sample['real_idx']}]")
st.sidebar.markdown(f"Gold: **{sample['gold_chars']}**")
st.sidebar.markdown(f"Expected Speech: *{sample['text']}*")
st.sidebar.markdown(f"Gold Text: **{sample['text']}**")
if use_domain_asr:
st.sidebar.markdown(f"Expected Spelled: *{sample['spoken']}*")
st.sidebar.title("Results:")
st.sidebar.text(f"Pretrained:{sample['pretrained_asr']}")
st.sidebar.text(f"Speller:{sample['speller_asr']}")
st.sidebar.title(f"Speller WER: {sample['wer']:.2f}%")
# (y, sr) = librosa.load(sample["audio_path"])
# librosa.display.waveplot(y=y, sr=sr)
# st.sidebar.pyplot(fig=sample["plot_fig"])
st.sidebar.image(sample["plot_png"])
st.audio(sample["audio_path"].open("rb"))
corrected = sample["gold_chars"]
correction_entry = st.get_correction_entry(sample["gold_chars"])
st.sidebar.markdown(f"Pretrained: **{sample['pretrained_asr']}**")
if use_domain_asr:
st.sidebar.markdown(f"Domain: **{sample['domain_asr']}**")
st.sidebar.title(f"Speller WER: {sample['domain_wer']:.2f}%")
else:
st.sidebar.title(f"Pretrained WER: {sample['pretrained_wer']:.2f}%")
st.sidebar.image(Path(sample["plot_path"]).read_bytes())
st.audio(Path(sample["audio_path"]).open("rb"))
# set default to text
corrected = sample["text"]
correction_entry = st.get_correction_entry(sample["utterance_id"])
selected_idx = 0
options = ("Correct", "Incorrect", "Inaudible")
# if correction entry is present set the corresponding ui defaults
if correction_entry:
selected_idx = options.index(correction_entry["value"]["status"])
corrected = correction_entry["value"]["correction"]
@@ -148,24 +105,26 @@ def main(manifest: Path):
if st.button("Submit"):
correct_code = corrected.replace(" ", "").upper()
st.update_entry(
sample["gold_chars"], {"status": selected, "correction": correct_code}
sample["utterance_id"], {"status": selected, "correction": correct_code}
)
st.update_cursor(sample_no + 1)
if correction_entry:
st.markdown(
f'Your Response: **{correction_entry["value"]["status"]}** Correction: **{correction_entry["value"]["correction"]}**'
)
# real_idx = st.text_input("Go to real-index:", value=sample['real_idx'])
# st.markdown(
# ",".join(
# [
# "**" + str(p["real_idx"]) + "**"
# if p["real_idx"] == sample["real_idx"]
# else str(p["real_idx"])
# for p in pnr_data
# ]
# )
# )
# if st.button("Previous Untagged"):
# pass
# if st.button("Next Untagged"):
# pass
real_idx = st.number_input(
"Go to real-index",
value=sample["real_idx"],
min_value=0,
max_value=len(asr_data) - 1,
)
if real_idx != int(sample["real_idx"]):
idx = [i for (i, p) in enumerate(asr_data) if p["real_idx"] == real_idx][0]
st.update_cursor(idx)
if __name__ == "__main__":

View File

@@ -22,13 +22,19 @@ extra_requirements = {
"matplotlib==3.2.1",
"pandas==1.0.3",
"tabulate==0.8.7",
"natural==0.2.0",
"num2words==0.5.10",
"typer[all]==0.1.1",
"python-slugify==4.0.0",
"lenses @ git+https://github.com/ingolemo/python-lenses.git@b2a2a9aa5b61540992d70b2cf36008d0121e8948#egg=lenses",
],
"validation": [
"rpyc~=4.1.4",
"pymongo==3.10.1",
"typer[all]==0.1.1",
"tqdm~=4.39.0",
"librosa==0.7.2",
"matplotlib==3.2.1",
"pydub~=0.23.1",
"streamlit==0.58.0",
"stringcase==1.2.0"
@@ -58,6 +64,7 @@ setup(
"jasper_asr_trainer = jasper.train:main",
"jasper_asr_data_generate = jasper.data_utils.generator:main",
"jasper_asr_data_recycle = jasper.data_utils.call_recycler:main",
"jasper_asr_data_validation = jasper.data_utils.validation.process:main",
"jasper_asr_data_preprocess = jasper.data_utils.process:main",
]
},