refactor validation process arguments and logging

Malar Kannan 2020-06-05 16:32:08 +05:30
parent bca227a7d7
commit 8db1be0083
2 changed files with 25 additions and 14 deletions

View File

@ -76,10 +76,10 @@ def preprocess_datapoint(
@app.command() @app.command()
def dump_validation_ui_data( def dump_validation_ui_data(
dataset_path: Path = typer.Option( data_name: str = typer.Option("call_alphanum", show_default=True),
Path("./data/asr_data/call_alphanum"), show_default=True dataset_dir: Path = Path("./data/asr_data"),
), dump_dir: Path = Path("./data/valiation_data"),
dump_name: str = typer.Option("ui_dump.json", show_default=True), dump_fname: Path = typer.Option(Path("ui_dump.json"), show_default=True),
use_domain_asr: bool = False, use_domain_asr: bool = False,
annotation_only: bool = False, annotation_only: bool = False,
enable_plots: bool = True, enable_plots: bool = True,
@ -87,8 +87,8 @@ def dump_validation_ui_data(
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from functools import partial from functools import partial
data_manifest_path = dataset_path / Path("manifest.json") data_manifest_path = dataset_dir / Path(data_name) / Path("manifest.json")
dump_path: Path = Path(f"./data/valiation_data/{dataset_path.stem}/{dump_name}") dump_path: Path = dump_dir / Path(data_name) / dump_fname
plot_dir = data_manifest_path.parent / Path("wav_plots") plot_dir = data_manifest_path.parent / Path("wav_plots")
plot_dir.mkdir(parents=True, exist_ok=True) plot_dir.mkdir(parents=True, exist_ok=True)
typer.echo(f"Using data manifest:{data_manifest_path}") typer.echo(f"Using data manifest:{data_manifest_path}")
@ -134,15 +134,22 @@ def dump_validation_ui_data(
"annotation_only": annotation_only, "annotation_only": annotation_only,
"enable_plots": enable_plots, "enable_plots": enable_plots,
} }
typer.echo(f"Writing dump to {dump_path}")
ExtendedPath(dump_path).write_json(ui_config) ExtendedPath(dump_path).write_json(ui_config)
@app.command() @app.command()
def dump_corrections(dump_path: Path = Path("./data/valiation_data/corrections.json")): def dump_corrections(
col = get_mongo_conn(col='asr_validation') data_name: str = typer.Option("call_alphanum", show_default=True),
dump_dir: Path = Path("./data/valiation_data"),
dump_fname: Path = Path("corrections.json"),
):
dump_path = dump_dir / Path(data_name) / dump_fname
col = get_mongo_conn(col="asr_validation")
cursor_obj = col.find({"type": "correction"}, projection={"_id": False}) cursor_obj = col.find({"type": "correction"}, projection={"_id": False})
corrections = [c for c in cursor_obj] corrections = [c for c in cursor_obj]
typer.echo(f"Writing dump to {dump_path}")
ExtendedPath(dump_path).write_json(corrections) ExtendedPath(dump_path).write_json(corrections)
@ -156,7 +163,7 @@ def fill_unannotated(
annotated_codes = {c["code"] for c in corrections} annotated_codes = {c["code"] for c in corrections}
all_codes = {c["gold_chars"] for c in processed_data} all_codes = {c["gold_chars"] for c in processed_data}
unann_codes = all_codes - annotated_codes unann_codes = all_codes - annotated_codes
mongo_conn = get_mongo_conn(col='asr_validation') mongo_conn = get_mongo_conn(col="asr_validation")
for c in unann_codes: for c in unann_codes:
mongo_conn.find_one_and_update( mongo_conn.find_one_and_update(
{"type": "correction", "code": c}, {"type": "correction", "code": c},
@ -234,7 +241,7 @@ def update_corrections(
def clear_mongo_corrections(): def clear_mongo_corrections():
delete = typer.confirm("are you sure you want to clear mongo collection it?") delete = typer.confirm("are you sure you want to clear mongo collection it?")
if delete: if delete:
col = get_mongo_conn(col='asr_validation') col = get_mongo_conn(col="asr_validation")
col.delete_many({"type": "correction"}) col.delete_many({"type": "correction"})
typer.echo("deleted mongo collection.") typer.echo("deleted mongo collection.")
typer.echo("Aborted") typer.echo("Aborted")

View File

@ -9,7 +9,7 @@ app = typer.Typer()
if not hasattr(st, "mongo_connected"): if not hasattr(st, "mongo_connected"):
st.mongoclient = get_mongo_conn(col='asr_validation') st.mongoclient = get_mongo_conn(col="asr_validation")
mongo_conn = st.mongoclient mongo_conn = st.mongoclient
def current_cursor_fn(): def current_cursor_fn():
@ -119,9 +119,13 @@ def main(manifest: Path):
st.markdown( st.markdown(
f'Your Response: **{correction_entry["value"]["status"]}** Correction: **{correction_entry["value"]["correction"]}**' f'Your Response: **{correction_entry["value"]["status"]}** Correction: **{correction_entry["value"]["correction"]}**'
) )
text_sample = st.text_input("Go to Text:", value='') text_sample = st.text_input("Go to Text:", value="")
if text_sample != '': if text_sample != "":
candidates = [i for (i, p) in enumerate(asr_data) if p["text"] == text_sample or p["spoken"] == text_sample] candidates = [
i
for (i, p) in enumerate(asr_data)
if p["text"] == text_sample or p["spoken"] == text_sample
]
if len(candidates) > 0: if len(candidates) > 0:
st.update_cursor(candidates[0]) st.update_cursor(candidates[0])
real_idx = st.number_input( real_idx = st.number_input(