1. added start_delay arg to call_recycler
2. implemented ui_dump/manifest writer in call_recycler itself
3. refactored call data point plotter
4. added sample-ui and task-ui commands to the validation process
5. implemented call-quality stats using corrections from mongo
6. added support for deleting cursors in mongo
7. implemented multi-task support in the validation UI, keyed on the task_id mongo field
parent 7dbb04dcbf
commit 8e238c254e
@@ -3,6 +3,7 @@
 /train/
 .env*
 *.yaml
 *.yml
+*.json
@@ -93,8 +93,8 @@ def copy_metas(
     def copy_meta(uri):
         cid = get_cid(uri)
-        saved_meta_path = call_meta_dir / Path(f'{cid}.json')
-        dest_meta_path = meta_dir / Path(f'{cid}.json')
+        saved_meta_path = call_meta_dir / Path(f"{cid}.json")
+        dest_meta_path = meta_dir / Path(f"{cid}.json")
         if not saved_meta_path.exists():
             print(f"{saved_meta_path} not found")
         copy2(saved_meta_path, dest_meta_path)
@@ -106,7 +106,6 @@ def copy_metas(
     download_meta_audio()


-
 class ExtractionType(str, Enum):
     flow = "flow"
     data = "data"
@@ -120,6 +119,7 @@ def analyze(
     extraction_type: ExtractionType = typer.Option(
         ExtractionType.data, show_default=True
     ),
+    start_delay: float = 3,
     download_only: bool = False,
     call_logs_file: Path = typer.Option(Path("./call_logs.yaml"), show_default=True),
     output_dir: Path = Path("./data"),
@@ -146,7 +146,7 @@ def analyze(
     import matplotlib.pyplot as plt
     import matplotlib
     from tqdm import tqdm
-    from .utils import asr_data_writer, get_mongo_coll
+    from .utils import ui_dump_manifest_writer, get_mongo_coll
     from pydub import AudioSegment
     from natural.date import compress
@@ -215,7 +215,7 @@ def analyze(
                 assert evs[0]["Type"] == "CONV_RESULT"
                 assert evs[1]["Type"] == "STARTED_SPEAKING"
                 assert evs[2]["Type"] == "STOPPED_SPEAKING"
-                start_time = td_fn(evs[1]).total_seconds() - 2
+                start_time = td_fn(evs[1]).total_seconds() - start_delay
                 end_time = td_fn(evs[2]).total_seconds()
                 spoken = evs[0]["Msg"]
                 data_points.append(
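The hunk above swaps the hard-coded 2-second lookback for the new `start_delay` option (default 3, declared earlier in `analyze`). A minimal sketch of the resulting clip window, with hypothetical event timestamps standing in for `td_fn(evs[...])`:

```python
from datetime import timedelta

start_delay = 3.0  # CLI default from the option added above

# Hypothetical offsets standing in for td_fn(evs[1]) / td_fn(evs[2])
started_speaking = timedelta(seconds=10.0)
stopped_speaking = timedelta(seconds=14.2)

start_time = started_speaking.total_seconds() - start_delay  # 7.0 (was 8.0 with the fixed 2s)
end_time = stopped_speaking.total_seconds()                  # 14.2
```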
@@ -227,7 +227,11 @@ def analyze(
             return data_points

         def text_extractor(spoken):
-            return re.search(r"'(.*)'", spoken).groups(0)[0] if len(spoken) > 6 and re.search(r"'(.*)'", spoken) else spoken
+            return (
+                re.search(r"'(.*)'", spoken).groups(0)[0]
+                if len(spoken) > 6 and re.search(r"'(.*)'", spoken)
+                else spoken
+            )

     elif extraction_type == ExtractionType.flow:
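For reference, `text_extractor` prefers a single-quoted span when the prompt echo contains one, otherwise it passes the raw text through. A standalone sketch with the duplicated `re.search` folded into one call and hypothetical inputs:

```python
import re

def text_extractor(spoken):
    # Same behaviour as the diff, with the repeated re.search cached once.
    m = re.search(r"'(.*)'", spoken)
    return m.groups(0)[0] if len(spoken) > 6 and m else spoken

print(text_extractor("you said 'AB12CD'"))  # -> AB12CD
print(text_extractor("AB12CD"))             # len == 6: falls through unchanged
```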
@@ -254,14 +258,20 @@ def analyze(
                 assert evs[1]["Type"] == "STARTED_SPEAKING"
                 assert evs[2]["Type"] == "ASR_RESULT"
                 assert evs[3]["Type"] == "STOPPED_SPEAKING"
-                start_time = td_fn(evs[1]).total_seconds() - 1.5
+                start_time = td_fn(evs[1]).total_seconds() - start_delay
                 end_time = td_fn(evs[2]).total_seconds()
                 conv_msg = evs[0]["Msg"]
-                if 'full name' in conv_msg.lower():
+                if "full name" in conv_msg.lower():
                     pld = json.loads(evs[2]["Payload"])
-                    spoken = pld["AsrResult"]["Results"][0]["Alternatives"][0]['Transcript']
+                    spoken = pld["AsrResult"]["Results"][0]["Alternatives"][0][
+                        "Transcript"
+                    ]
                     data_points.append(
-                        {"start_time": start_time, "end_time": end_time, "code": spoken}
+                        {
+                            "start_time": start_time,
+                            "end_time": end_time,
+                            "code": spoken,
+                        }
                     )
             except AssertionError:
                 # skipping invalid data_points
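The flow branch digs the transcript out of the `ASR_RESULT` payload. A hypothetical payload with the nesting the indexing above assumes:

```python
import json

# Hypothetical ASR_RESULT payload; shape inferred from the indexing above.
raw = json.dumps(
    {"AsrResult": {"Results": [{"Alternatives": [{"Transcript": "john doe"}]}]}}
)
pld = json.loads(raw)
spoken = pld["AsrResult"]["Results"][0]["Alternatives"][0]["Transcript"]
assert spoken == "john doe"
```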
@@ -330,6 +340,25 @@ def analyze(
         process_meta["data_points"] = data_points
         return {"url": uri, "meta": meta, "duration": duration, "process": process_meta}

+    def retrieve_callmeta(call_uri):
+        uri = call_uri["call_uri"]
+        name = call_uri["name"]
+        cid = get_cid(uri)
+        meta = mongo_collection.find_one({"SystemID": cid})
+        duration = meta["EndTS"] - meta["StartTS"]
+        process_meta = process_call(meta)
+        data_points = get_data_points(
+            process_meta["utter_events"], process_meta["first_event_fn"]
+        )
+        process_meta["data_points"] = data_points
+        return {
+            "url": uri,
+            "name": name,
+            "meta": meta,
+            "duration": duration,
+            "process": process_meta,
+        }
+
     def download_meta_audio():
         call_lens = lens["users"].Each()["calls"].Each()
         call_lens.modify(ensure_call)(call_logs)
@@ -379,7 +408,7 @@ def analyze(
         pprint(call_plots)

     def extract_data_points():
-        def gen_data_values(saved_wav_path, data_points):
+        def gen_data_values(saved_wav_path, data_points, caller_name):
             call_seg = (
                 AudioSegment.from_wav(saved_wav_path)
                 .set_channels(1)
@@ -394,23 +423,32 @@ def analyze(
             spoken_wav = spoken_fb.getvalue()
             # search for actual pnr code and handle plain codes as well
             extracted_code = text_extractor(spoken)
-            yield extracted_code, spoken_seg.duration_seconds, spoken_wav
+            yield extracted_code, spoken_seg.duration_seconds, spoken_wav, caller_name, spoken_seg

         call_lens = lens["users"].Each()["calls"].Each()
-        call_stats = call_lens.modify(retrieve_processed_callmeta)(call_logs)
+
+        def assign_user_call(uc):
+            return (
+                lens["calls"]
+                .Each()
+                .modify(lambda c: {"call_uri": c, "name": uc["name"]})(uc)
+            )
+
+        user_call_logs = lens["users"].Each().modify(assign_user_call)(call_logs)
+        call_stats = call_lens.modify(retrieve_callmeta)(user_call_logs)
         call_objs = call_lens.collect()(call_stats)

         def data_source():
             for call_obj in tqdm(call_objs):
-                saved_wav_path, data_points, sys_id = (
+                saved_wav_path, data_points, name = (
                     call_obj["process"]["wav_path"],
                     call_obj["process"]["data_points"],
-                    call_obj["meta"]["SystemID"],
+                    call_obj["name"],
                 )
-                for dp in gen_data_values(saved_wav_path, data_points):
+                for dp in gen_data_values(saved_wav_path, data_points, name):
                     yield dp

-        asr_data_writer(call_asr_data, dataset_name, data_source())
+        ui_dump_manifest_writer(call_asr_data, dataset_name, data_source())

     def show_leaderboard():
         def compute_user_stats(call_stat):
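`assign_user_call` uses the `lenses` library to tag every call URI with its caller's name before `retrieve_callmeta` runs, which is how the name survives into `gen_data_values` and the UI dump. A self-contained sketch on toy data (the structure mirrors `call_logs.yaml`; names and URIs are hypothetical):

```python
from lenses import lens  # python-lenses, as used in the diff

call_logs = {  # toy structure; real data comes from call_logs.yaml
    "users": [
        {"name": "alice", "calls": ["s3://bucket/call-1"]},
        {"name": "bob", "calls": ["s3://bucket/call-2"]},
    ]
}

def assign_user_call(uc):
    # Replace each bare URI with a dict carrying the caller's name.
    return lens["calls"].Each().modify(
        lambda c: {"call_uri": c, "name": uc["name"]}
    )(uc)

user_call_logs = lens["users"].Each().modify(assign_user_call)(call_logs)
print(user_call_logs["users"][0]["calls"][0])
# {'call_uri': 's3://bucket/call-1', 'name': 'alice'}
```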
@@ -9,6 +9,14 @@ import pymongo
 from slugify import slugify
 from uuid import uuid4
 from num2words import num2words
+from jasper.client import transcribe_gen
+from nemo.collections.asr.metrics import word_error_rate
+import matplotlib.pyplot as plt
+import librosa
+import librosa.display
+from tqdm import tqdm
+from functools import partial
+from concurrent.futures import ThreadPoolExecutor


 def manifest_str(path, dur, text):
@@ -57,11 +65,12 @@ def asr_data_writer(output_dir, dataset_name, asr_data_source, verbose=False):
     asr_manifest = dataset_dir / Path("manifest.json")
     num_datapoints = 0
     with asr_manifest.open("w") as mf:
         print(f"writing manifest to {asr_manifest}")
         for transcript, audio_dur, wav_data in asr_data_source:
             fname = str(uuid4()) + "_" + slugify(transcript, max_length=8)
-            pnr_af = dataset_dir / Path("wav") / Path(fname).with_suffix(".wav")
-            pnr_af.write_bytes(wav_data)
-            rel_pnr_path = pnr_af.relative_to(dataset_dir)
+            audio_file = dataset_dir / Path("wav") / Path(fname).with_suffix(".wav")
+            audio_file.write_bytes(wav_data)
+            rel_pnr_path = audio_file.relative_to(dataset_dir)
             manifest = manifest_str(str(rel_pnr_path), audio_dur, transcript)
             mf.write(manifest)
             if verbose:
@@ -70,6 +79,94 @@ def asr_data_writer(output_dir, dataset_name, asr_data_source, verbose=False):
     return num_datapoints


+def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=False):
+    dataset_dir = output_dir / Path(dataset_name)
+    (dataset_dir / Path("wav")).mkdir(parents=True, exist_ok=True)
+    ui_dump_file = dataset_dir / Path("ui_dump.json")
+    (dataset_dir / Path("wav_plots")).mkdir(parents=True, exist_ok=True)
+    asr_manifest = dataset_dir / Path("manifest.json")
+    num_datapoints = 0
+    ui_dump = {
+        "use_domain_asr": False,
+        "annotation_only": False,
+        "enable_plots": True,
+        "data": [],
+    }
+    data_funcs = []
+    transcriber_pretrained = transcribe_gen(asr_port=8044)
+    with asr_manifest.open("w") as mf:
+        print(f"writing manifest to {asr_manifest}")
+
+        def data_fn(
+            transcript,
+            audio_dur,
+            wav_data,
+            caller_name,
+            aud_seg,
+            fname,
+            audio_path,
+            num_datapoints,
+            rel_pnr_path,
+        ):
+            pretrained_result = transcriber_pretrained(aud_seg.raw_data)
+            pretrained_wer = word_error_rate([transcript], [pretrained_result])
+            wav_plot_path = (
+                dataset_dir / Path("wav_plots") / Path(fname).with_suffix(".png")
+            )
+            if not wav_plot_path.exists():
+                plot_seg(wav_plot_path, audio_path)
+            return {
+                "audio_filepath": str(rel_pnr_path),
+                "duration": round(audio_dur, 1),
+                "text": transcript,
+                "real_idx": num_datapoints,
+                "audio_path": audio_path,
+                "spoken": transcript,
+                "caller": caller_name,
+                "utterance_id": fname,
+                "pretrained_asr": pretrained_result,
+                "pretrained_wer": pretrained_wer,
+                "plot_path": str(wav_plot_path),
+            }
+
+        for transcript, audio_dur, wav_data, caller_name, aud_seg in asr_data_source:
+            fname = str(uuid4()) + "_" + slugify(transcript, max_length=8)
+            audio_file = dataset_dir / Path("wav") / Path(fname).with_suffix(".wav")
+            audio_file.write_bytes(wav_data)
+            audio_path = str(audio_file)
+            rel_pnr_path = audio_file.relative_to(dataset_dir)
+            manifest = manifest_str(str(rel_pnr_path), audio_dur, transcript)
+            mf.write(manifest)
+            data_funcs.append(
+                partial(
+                    data_fn,
+                    transcript,
+                    audio_dur,
+                    wav_data,
+                    caller_name,
+                    aud_seg,
+                    fname,
+                    audio_path,
+                    num_datapoints,
+                    rel_pnr_path,
+                )
+            )
+            num_datapoints += 1
+    with ThreadPoolExecutor() as exe:
+        print("starting all plot/transcription tasks")
+        dump_data = list(
+            tqdm(
+                exe.map(lambda x: x(), data_funcs),
+                position=0,
+                leave=True,
+                total=len(data_funcs),
+            )
+        )
+    ui_dump["data"] = dump_data
+    ExtendedPath(ui_dump_file).write_json(ui_dump)
+    return num_datapoints
+
+
 def asr_manifest_reader(data_manifest_path: Path):
     print(f"reading manifest from {data_manifest_path}")
     with data_manifest_path.open("r") as pf:
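`ui_dump_manifest_writer` defers the slow per-utterance work (pretrained transcription, waveform plotting) by appending zero-argument `partial`s while streaming the source, then draining them through a thread pool. The pattern in isolation, with a stand-in `data_fn`:

```python
from concurrent.futures import ThreadPoolExecutor
from functools import partial

def data_fn(idx, text):
    # Stand-in for the real ASR + plotting work above.
    return {"real_idx": idx, "text": text.upper()}

# Bind arguments now, run later — exactly how data_funcs is built above.
data_funcs = [partial(data_fn, i, t) for i, t in enumerate(["ab12", "cd34"])]

with ThreadPoolExecutor() as exe:
    dump_data = list(exe.map(lambda f: f(), data_funcs))

print(dump_data)
# [{'real_idx': 0, 'text': 'AB12'}, {'real_idx': 1, 'text': 'CD34'}]
```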
@@ -95,12 +192,12 @@ class ExtendedPath(type(Path())):
     """docstring for ExtendedPath."""

     def read_json(self):
-        print(f'reading json from {self}')
+        print(f"reading json from {self}")
         with self.open("r") as jf:
             return json.load(jf)

     def write_json(self, data):
-        print(f'writing json to {self}')
+        print(f"writing json to {self}")
         self.parent.mkdir(parents=True, exist_ok=True)
         with self.open("w") as jf:
             return json.dump(data, jf, indent=2)
@@ -109,7 +206,7 @@ class ExtendedPath(type(Path())):
 def get_mongo_coll(uri="mongodb://localhost:27017/test.calls"):
     ud = pymongo.uri_parser.parse_uri(uri)
     conn = pymongo.MongoClient(uri)
-    return conn[ud['database']][ud['collection']]
+    return conn[ud["database"]][ud["collection"]]


 def get_mongo_conn(host="", port=27017, db="test", col="calls"):
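`get_mongo_coll` leans on `pymongo.uri_parser.parse_uri` exposing the database and collection parsed from a `host/db.collection` URI:

```python
import pymongo

ud = pymongo.uri_parser.parse_uri("mongodb://localhost:27017/test.calls")
print(ud["database"], ud["collection"])  # test calls
```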
@@ -127,6 +224,16 @@ def strip_silence(sound):
     return sound[start_trim : duration - end_trim]


+def plot_seg(wav_plot_path, audio_path):
+    fig = plt.Figure()
+    ax = fig.add_subplot()
+    (y, sr) = librosa.load(audio_path)
+    librosa.display.waveplot(y=y, sr=sr, ax=ax)
+    with wav_plot_path.open("wb") as wav_plot_f:
+        fig.set_tight_layout(True)
+        fig.savefig(wav_plot_f, format="png", dpi=50)
+
+
 def main():
     for c in random_pnr_generator():
         print(c)
@@ -12,6 +12,7 @@ from ..utils import (
     asr_manifest_reader,
     asr_manifest_writer,
     get_mongo_conn,
+    plot_seg,
 )

 app = typer.Typer()
@@ -20,9 +21,6 @@ app = typer.Typer()
 def preprocess_datapoint(
     idx, rel_root, sample, use_domain_asr, annotation_only, enable_plots
 ):
-    import matplotlib.pyplot as plt
-    import librosa
-    import librosa.display
     from pydub import AudioSegment
     from nemo.collections.asr.metrics import word_error_rate
     from jasper.client import transcribe_gen
@@ -61,14 +59,7 @@ def preprocess_datapoint(
             rel_root / Path("wav_plots") / Path(audio_path.name).with_suffix(".png")
         )
         if not wav_plot_path.exists():
-            fig = plt.Figure()
-            ax = fig.add_subplot()
-            (y, sr) = librosa.load(audio_path)
-            librosa.display.waveplot(y=y, sr=sr, ax=ax)
-            with wav_plot_path.open("wb") as wav_plot_f:
-                fig.set_tight_layout(True)
-                fig.savefig(wav_plot_f, format="png", dpi=50)
-            # fig.close()
+            plot_seg(wav_plot_path, audio_path)
         res["plot_path"] = str(wav_plot_path)
         return res
     except BaseException as e:
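Both call sites now share the `plot_seg` helper extracted into utils. A minimal usage sketch; the import path depends on the package layout, and the paths are hypothetical:

```python
from pathlib import Path

from utils import plot_seg  # adjust to the actual package layout

wav_plot_path = Path("wav_plots/utt-0001.png")  # hypothetical output
audio_path = "wav/utt-0001.wav"                 # hypothetical input
if not wav_plot_path.exists():  # same guard both call sites keep
    plot_seg(wav_plot_path, audio_path)
```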
@@ -131,17 +122,66 @@ def dump_ui(
     result = sorted(pnr_data, key=lambda x: x[wer_key], reverse=True)
     ui_config = {
         "use_domain_asr": use_domain_asr,
-        "data": result,
         "annotation_only": annotation_only,
         "enable_plots": enable_plots,
+        "data": result,
     }
     ExtendedPath(dump_path).write_json(ui_config)


+@app.command()
+def sample_ui(
+    data_name: str = typer.Option("call_upwork_train_cnd", show_default=True),
+    dump_dir: Path = Path("./data/asr_data"),
+    dump_file: Path = Path("ui_dump.json"),
+    sample_count: int = typer.Option(80, show_default=True),
+    sample_file: Path = Path("sample_dump.json"),
+):
+    import pandas as pd
+
+    processed_data_path = dump_dir / Path(data_name) / dump_file
+    sample_path = dump_dir / Path(data_name) / sample_file
+    processed_data = ExtendedPath(processed_data_path).read_json()
+    df = pd.DataFrame(processed_data["data"])
+    samples_per_caller = sample_count // len(df["caller"].unique())
+    caller_samples = pd.concat(
+        [g.sample(samples_per_caller) for (c, g) in df.groupby("caller")]
+    )
+    caller_samples = caller_samples.reset_index(drop=True)
+    caller_samples["real_idx"] = caller_samples.index
+    sample_data = caller_samples.to_dict("records")
+    processed_data["data"] = sample_data
+    typer.echo(f"sampling {sample_count} datapoints")
+    ExtendedPath(sample_path).write_json(processed_data)
+
+
+@app.command()
+def task_ui(
+    data_name: str = typer.Option("call_upwork_train_cnd", show_default=True),
+    dump_dir: Path = Path("./data/asr_data"),
+    dump_file: Path = Path("ui_dump.json"),
+    task_count: int = typer.Option(4, show_default=True),
+    task_file: str = "task_dump",
+):
+    import pandas as pd
+    import numpy as np
+
+    processed_data_path = dump_dir / Path(data_name) / dump_file
+    processed_data = ExtendedPath(processed_data_path).read_json()
+    df = pd.DataFrame(processed_data["data"]).sample(frac=1).reset_index(drop=True)
+    for t_idx, task_f in enumerate(np.array_split(df, task_count)):
+        task_f = task_f.reset_index(drop=True)
+        task_f["real_idx"] = task_f.index
+        task_data = task_f.to_dict("records")
+        processed_data["data"] = task_data
+        task_path = dump_dir / Path(data_name) / Path(task_file + f"-{t_idx}.json")
+        ExtendedPath(task_path).write_json(processed_data)
+
+
 @app.command()
 def dump_corrections(
     data_name: str = typer.Option("call_alphanum", show_default=True),
-    dump_dir: Path = Path("./data/valiation_data"),
+    dump_dir: Path = Path("./data/asr_data"),
     dump_fname: Path = Path("corrections.json"),
 ):
     dump_path = dump_dir / Path(data_name) / dump_fname
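`task_ui` shuffles the dump once, then shards it with `np.array_split` so each annotator's task file gets a re-indexed slice. The core of that on toy rows:

```python
import numpy as np
import pandas as pd

# Hypothetical dump rows; mirrors how task_ui shuffles then shards.
df = pd.DataFrame({"utterance_id": range(10), "caller": ["a", "b"] * 5})
df = df.sample(frac=1).reset_index(drop=True)  # one global shuffle

for t_idx, task_f in enumerate(np.array_split(df, 4)):
    task_f = task_f.reset_index(drop=True)
    task_f["real_idx"] = task_f.index  # re-number within each task shard
    print(t_idx, len(task_f))          # shards of 3, 3, 2, 2 rows
```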
@@ -152,6 +192,38 @@ def dump_corrections(
     ExtendedPath(dump_path).write_json(corrections)


+@app.command()
+def caller_quality(
+    data_name: str = typer.Option("call_upwork_train_cnd", show_default=True),
+    dump_dir: Path = Path("./data/asr_data"),
+    dump_fname: Path = Path("ui_dump.json"),
+    correction_fname: Path = Path("corrections.json"),
+):
+    import copy
+    import pandas as pd
+
+    dump_path = dump_dir / Path(data_name) / dump_fname
+    correction_path = dump_dir / Path(data_name) / correction_fname
+    dump_data = ExtendedPath(dump_path).read_json()
+
+    dump_map = {d["utterance_id"]: d for d in dump_data["data"]}
+    correction_data = ExtendedPath(correction_path).read_json()
+
+    def correction_dp(c):
+        dp = copy.deepcopy(dump_map[c["code"]])
+        dp["valid"] = c["value"]["status"] == "Correct"
+        return dp
+
+    corrected_dump = [correction_dp(c) for c in correction_data]
+    df = pd.DataFrame(corrected_dump)
+    print(f"Total samples: {len(df)}")
+    for (c, g) in df.groupby("caller"):
+        total = len(g)
+        valid = len(g[g["valid"] == True])
+        valid_rate = valid * 100 / total
+        print(f"Caller: {c} Valid%:{valid_rate:.2f} of {total} samples")
+
+
 @app.command()
 def fill_unannotated(
     data_name: str = typer.Option("call_alphanum", show_default=True),

@@ -329,7 +401,9 @@ def clear_mongo_corrections():
     if delete:
         col = get_mongo_conn(col="asr_validation")
         col.delete_many({"type": "correction"})
+        col.delete_many({"type": "current_cursor"})
         typer.echo("deleted mongo collection.")
         return
     typer.echo("Aborted")
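`caller_quality` joins corrections back onto the dump by `utterance_id` and reports a per-caller valid rate. The groupby arithmetic on hypothetical rows:

```python
import pandas as pd

df = pd.DataFrame(  # hypothetical corrected-dump rows
    [
        {"caller": "alice", "valid": True},
        {"caller": "alice", "valid": False},
        {"caller": "bob", "valid": True},
    ]
)
for c, g in df.groupby("caller"):
    total, valid = len(g), int(g["valid"].sum())
    print(f"Caller: {c} Valid%:{valid * 100 / total:.2f} of {total} samples")
# Caller: alice Valid%:50.00 of 2 samples
# Caller: bob Valid%:100.00 of 1 samples
```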
@@ -2,6 +2,7 @@ from pathlib import Path

 import streamlit as st
 import typer
+from uuid import uuid4
 from ..utils import ExtendedPath, get_mongo_conn
 from .st_rerun import rerun
@@ -11,25 +12,25 @@ app = typer.Typer()
 if not hasattr(st, "mongo_connected"):
     st.mongoclient = get_mongo_conn(col="asr_validation")
     mongo_conn = st.mongoclient
+    st.task_id = str(uuid4())

     def current_cursor_fn():
         # mongo_conn = st.mongoclient
-        cursor_obj = mongo_conn.find_one({"type": "current_cursor"})
+        cursor_obj = mongo_conn.find_one(
+            {"type": "current_cursor", "task_id": st.task_id}
+        )
         cursor_val = cursor_obj["cursor"]
         return cursor_val

     def update_cursor_fn(val=0):
         mongo_conn.find_one_and_update(
-            {"type": "current_cursor"},
-            {"$set": {"type": "current_cursor", "cursor": val}},
+            {"type": "current_cursor", "task_id": st.task_id},
+            {"$set": {"type": "current_cursor", "task_id": st.task_id, "cursor": val}},
             upsert=True,
         )
         rerun()

     def get_correction_entry_fn(code):
         # mongo_conn = st.mongoclient
         # cursor_obj = mongo_conn.find_one({"type": "correction", "code": code})
         # cursor_val = cursor_obj["cursor"]
         return mongo_conn.find_one(
             {"type": "correction", "code": code}, projection={"_id": False}
         )
@@ -37,18 +38,25 @@ if not hasattr(st, "mongo_connected"):
     def update_entry_fn(code, value):
         mongo_conn.find_one_and_update(
             {"type": "correction", "code": code},
-            {"$set": {"value": value}},
+            {"$set": {"value": value, "task_id": st.task_id}},
             upsert=True,
         )

-    cursor_obj = mongo_conn.find_one({"type": "current_cursor"})
-    if not cursor_obj:
-        update_cursor_fn(0)
+    def set_task_fn(mf_path):
+        task_path = mf_path.parent / Path(f"task-{st.task_id}.lck")
+        if not task_path.exists():
+            print(f"creating task lock at {task_path}")
+            task_path.touch()
+
     st.get_current_cursor = current_cursor_fn
     st.update_cursor = update_cursor_fn
     st.get_correction_entry = get_correction_entry_fn
     st.update_entry = update_entry_fn
+    st.set_task = set_task_fn
     st.mongo_connected = True
+    cursor_obj = mongo_conn.find_one({"type": "current_cursor", "task_id": st.task_id})
+    if not cursor_obj:
+        update_cursor_fn(0)


 @st.cache()
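Each browser session now keys its cursor on a fresh `task_id`, so concurrent annotators no longer clobber each other's position. The upsert shape, exercised here against `mongomock` as an in-memory stand-in for the real `get_mongo_conn` collection:

```python
import mongomock  # in-memory stand-in for the real mongo collection

col = mongomock.MongoClient()["test"]["asr_validation"]
task_id = "1b9d6bcd-bbfd-4b2d-9b5d-ab8dfbbd4bed"  # hypothetical st.task_id

# Same upsert as update_cursor_fn: one cursor document per task_id.
col.find_one_and_update(
    {"type": "current_cursor", "task_id": task_id},
    {"$set": {"type": "current_cursor", "task_id": task_id, "cursor": 5}},
    upsert=True,
)
print(col.find_one({"type": "current_cursor", "task_id": task_id})["cursor"])  # 5
```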
@@ -59,6 +67,7 @@ def load_ui_data(validation_ui_data_path: Path):

 @app.command()
 def main(manifest: Path):
+    st.set_task(manifest)
     ui_config = load_ui_data(manifest)
     asr_data = ui_config["data"]
     use_domain_asr = ui_config.get("use_domain_asr", True)
@@ -70,10 +79,11 @@ def main(manifest: Path):
         st.update_cursor(0)
     sample = asr_data[sample_no]
     title_type = "Speller " if use_domain_asr else ""
+    task_uid = st.task_id.rsplit("-", 1)[1]
     if annotation_only:
-        st.title(f"ASR Annotation")
+        st.title(f"ASR Annotation - # {task_uid}")
     else:
-        st.title(f"ASR {title_type}Validation")
+        st.title(f"ASR {title_type}Validation - # {task_uid}")
     addl_text = f"spelled *{sample['spoken']}*" if use_domain_asr else ""
     st.markdown(f"{sample_no+1} of {len(asr_data)} : **{sample['text']}**" + addl_text)
     new_sample = st.number_input(
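The title suffix is just the last dash-separated chunk of the session's uuid4:

```python
from uuid import uuid4

task_id = str(uuid4())                # e.g. '1b9d6bcd-bbfd-4b2d-9b5d-ab8dfbbd4bed'
task_uid = task_id.rsplit("-", 1)[1]  # last 12 hex chars: a short on-screen task tag
print(task_uid)
```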
@@ -88,6 +98,8 @@ def main(manifest: Path):
     st.sidebar.markdown(f"Expected Spelled: *{sample['spoken']}*")
     st.sidebar.title("Results:")
     st.sidebar.markdown(f"Pretrained: **{sample['pretrained_asr']}**")
+    if "caller" in sample:
+        st.sidebar.markdown(f"Caller: **{sample['caller']}**")
     if use_domain_asr:
         st.sidebar.markdown(f"Domain: **{sample['domain_asr']}**")
         st.sidebar.title(f"Speller WER: {sample['domain_wer']:.2f}%")