1. refactored ui_dump

2. added flake8
branch: tegra
Malar Kannan 2020-08-09 19:16:35 +05:30
parent 42647196fe
commit e8f58a5043
4 changed files with 154 additions and 179 deletions

.flake8 (new file, +4 lines)

@@ -0,0 +1,4 @@
[flake8]
exclude = docs
ignore = E203, W503
max-line-length = 119
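
For context: E203 is pycodestyle's "whitespace before ':'" check and W503 its "line break before binary operator" check, the two rules that commonly clash with Black-style slice spacing and operator-first line wrapping (the style of several reformatted hunks below). Note that setting `ignore` replaces flake8's default ignore list, so W503 has to be re-listed explicitly. A small illustration of code these settings accept:

```python
# Accepted under this .flake8 (E203 and W503 are ignored):
items = list(range(10))
upper = items[len(items) // 2 : len(items)]  # E203: whitespace before ':'
total = (
    sum(items)
    + len(items)  # W503: line break before binary operator
)
```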


@@ -58,83 +58,88 @@ def asr_data_writer(output_dir, dataset_name, asr_data_source, verbose=False):
return num_datapoints
def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=False):
def ui_data_generator(output_dir, dataset_name, asr_data_source, verbose=False):
dataset_dir = output_dir / Path(dataset_name)
(dataset_dir / Path("wav")).mkdir(parents=True, exist_ok=True)
ui_dump_file = dataset_dir / Path("ui_dump.json")
(dataset_dir / Path("wav_plots")).mkdir(parents=True, exist_ok=True)
asr_manifest = dataset_dir / Path("manifest.json")
def data_fn(
transcript,
audio_dur,
wav_data,
caller_name,
aud_seg,
fname,
audio_path,
num_datapoints,
rel_data_path,
):
pretrained_result = transcriber_pretrained(aud_seg.raw_data)
pretrained_wer = word_error_rate([transcript], [pretrained_result])
png_path = Path(fname).with_suffix(".png")
wav_plot_path = dataset_dir / Path("wav_plots") / png_path
if not wav_plot_path.exists():
plot_seg(wav_plot_path, audio_path)
return {
"audio_filepath": str(rel_data_path),
"duration": round(audio_dur, 1),
"text": transcript,
"real_idx": num_datapoints,
"audio_path": audio_path,
"spoken": transcript,
"caller": caller_name,
"utterance_id": fname,
"pretrained_asr": pretrained_result,
"pretrained_wer": pretrained_wer,
"plot_path": str(wav_plot_path),
}
num_datapoints = 0
ui_dump = {
"use_domain_asr": False,
"annotation_only": False,
"enable_plots": True,
"data": [],
}
data_funcs = []
transcriber_pretrained = transcribe_gen(asr_port=8044)
for transcript, audio_dur, wav_data, caller_name, aud_seg in asr_data_source:
fname = str(uuid4()) + "_" + slugify(transcript, max_length=8)
audio_file = dataset_dir / Path("wav") / Path(fname).with_suffix(".wav")
audio_file.write_bytes(wav_data)
audio_path = str(audio_file)
rel_data_path = audio_file.relative_to(dataset_dir)
data_funcs.append(
partial(
data_fn,
transcript,
audio_dur,
wav_data,
caller_name,
aud_seg,
fname,
audio_path,
num_datapoints,
rel_data_path,
)
)
num_datapoints += 1
ui_data = parallel_apply(lambda x: x(), data_funcs)
return ui_data, num_datapoints
def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=False):
dataset_dir = output_dir / Path(dataset_name)
dump_data, num_datapoints = ui_data_generator(
output_dir, dataset_name, asr_data_source, verbose=verbose
)
asr_manifest = dataset_dir / Path("manifest.json")
with asr_manifest.open("w") as mf:
print(f"writing manifest to {asr_manifest}")
def data_fn(
transcript,
audio_dur,
wav_data,
caller_name,
aud_seg,
fname,
audio_path,
num_datapoints,
rel_data_path,
):
pretrained_result = transcriber_pretrained(aud_seg.raw_data)
pretrained_wer = word_error_rate([transcript], [pretrained_result])
wav_plot_path = (
dataset_dir / Path("wav_plots") / Path(fname).with_suffix(".png")
)
if not wav_plot_path.exists():
plot_seg(wav_plot_path, audio_path)
return {
"audio_filepath": str(rel_data_path),
"duration": round(audio_dur, 1),
"text": transcript,
"real_idx": num_datapoints,
"audio_path": audio_path,
"spoken": transcript,
"caller": caller_name,
"utterance_id": fname,
"pretrained_asr": pretrained_result,
"pretrained_wer": pretrained_wer,
"plot_path": str(wav_plot_path),
}
for transcript, audio_dur, wav_data, caller_name, aud_seg in asr_data_source:
fname = str(uuid4()) + "_" + slugify(transcript, max_length=8)
audio_file = dataset_dir / Path("wav") / Path(fname).with_suffix(".wav")
audio_file.write_bytes(wav_data)
audio_path = str(audio_file)
rel_data_path = audio_file.relative_to(dataset_dir)
for d in dump_data:
rel_data_path = d["audio_filepath"]
audio_dur = d["duration"]
transcript = d["text"]
manifest = manifest_str(str(rel_data_path), audio_dur, transcript)
mf.write(manifest)
data_funcs.append(
partial(
data_fn,
transcript,
audio_dur,
wav_data,
caller_name,
aud_seg,
fname,
audio_path,
num_datapoints,
rel_data_path,
)
)
num_datapoints += 1
dump_data = parallel_apply(lambda x: x(), data_funcs)
# dump_data = [x() for x in tqdm(data_funcs)]
ui_dump["data"] = dump_data
ExtendedPath(ui_dump_file).write_json(ui_dump)
ui_dump_file = dataset_dir / Path("ui_dump.json")
ExtendedPath(ui_dump_file).write_json({"data": dump_data})
return num_datapoints
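
The old writer is split in two here: ui_data_generator freezes each utterance's arguments into a zero-argument partial and evaluates them concurrently via parallel_apply, while ui_dump_manifest_writer reuses the resulting dicts to emit both manifest.json and ui_dump.json. A minimal sketch of that deferred-call pattern (the thread-based stand-in for parallel_apply is an assumption; the real helper lives in this repo's utils):

```python
from concurrent.futures import ThreadPoolExecutor
from functools import partial

def parallel_apply(fn, items):
    # stand-in for the repo helper: map fn over items concurrently
    with ThreadPoolExecutor() as exe:
        return list(exe.map(fn, items))

# bind arguments now, run the expensive calls later and in parallel
data_funcs = [partial(pow, base, 2) for base in range(5)]
results = parallel_apply(lambda f: f(), data_funcs)  # [0, 1, 4, 9, 16]
```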


@@ -1,13 +1,10 @@
import json
import shutil
from pathlib import Path
from enum import Enum
import typer
from tqdm import tqdm
from ..utils import (
alnum_to_asr_tokens,
ExtendedPath,
asr_manifest_reader,
asr_manifest_writer,
@@ -19,9 +16,7 @@ from ..utils import (
app = typer.Typer()
def preprocess_datapoint(
idx, rel_root, sample, use_domain_asr, annotation_only, enable_plots
):
def preprocess_datapoint(idx, rel_root, sample):
from pydub import AudioSegment
from nemo.collections.asr.metrics import word_error_rate
from jasper.client import transcribe_gen
@@ -31,37 +26,23 @@ def preprocess_datapoint(
res["real_idx"] = idx
audio_path = rel_root / Path(sample["audio_filepath"])
res["audio_path"] = str(audio_path)
if use_domain_asr:
res["spoken"] = alnum_to_asr_tokens(res["text"])
else:
res["spoken"] = res["text"]
res["utterance_id"] = audio_path.stem
if not annotation_only:
transcriber_pretrained = transcribe_gen(asr_port=8044)
transcriber_pretrained = transcribe_gen(asr_port=8044)
aud_seg = (
AudioSegment.from_file_using_temporary_files(audio_path)
.set_channels(1)
.set_sample_width(2)
.set_frame_rate(24000)
)
res["pretrained_asr"] = transcriber_pretrained(aud_seg.raw_data)
res["pretrained_wer"] = word_error_rate(
[res["text"]], [res["pretrained_asr"]]
)
if use_domain_asr:
transcriber_speller = transcribe_gen(asr_port=8045)
res["domain_asr"] = transcriber_speller(aud_seg.raw_data)
res["domain_wer"] = word_error_rate(
[res["spoken"]], [res["pretrained_asr"]]
)
if enable_plots:
wav_plot_path = (
rel_root / Path("wav_plots") / Path(audio_path.name).with_suffix(".png")
)
if not wav_plot_path.exists():
plot_seg(wav_plot_path, audio_path)
res["plot_path"] = str(wav_plot_path)
aud_seg = (
AudioSegment.from_file_using_temporary_files(audio_path)
.set_channels(1)
.set_sample_width(2)
.set_frame_rate(24000)
)
res["pretrained_asr"] = transcriber_pretrained(aud_seg.raw_data)
res["pretrained_wer"] = word_error_rate([res["text"]], [res["pretrained_asr"]])
wav_plot_path = (
rel_root / Path("wav_plots") / Path(audio_path.name).with_suffix(".png")
)
if not wav_plot_path.exists():
plot_seg(wav_plot_path, audio_path)
res["plot_path"] = str(wav_plot_path)
return res
except BaseException as e:
print(f'failed on {idx}: {sample["audio_filepath"]} with {e}')
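
preprocess_datapoint loses its use_domain_asr/annotation_only/enable_plots branches and now always runs the pretrained transcriber and renders the waveform plot. For orientation, a sketch of the WER metric it calls (import path as in the diff; NeMo's word_error_rate returns a fraction, which the UI later formats with a % sign):

```python
from nemo.collections.asr.metrics import word_error_rate

# one substituted word over a two-word reference -> 0.5
wer = word_error_rate(hypotheses=["hello word"], references=["hello world"])
```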
@@ -73,61 +54,50 @@ def dump_ui(
dataset_dir: Path = Path("./data/asr_data"),
dump_dir: Path = Path("./data/validation_data"),
dump_fname: Path = typer.Option(Path("ui_dump.json"), show_default=True),
use_domain_asr: bool = False,
annotation_only: bool = False,
enable_plots: bool = True,
):
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from io import BytesIO
from pydub import AudioSegment
from ..utils import ui_data_generator
data_manifest_path = dataset_dir / Path(data_name) / Path("manifest.json")
dump_path: Path = dump_dir / Path(data_name) / dump_fname
plot_dir = data_manifest_path.parent / Path("wav_plots")
plot_dir.mkdir(parents=True, exist_ok=True)
typer.echo(f"Using data manifest:{data_manifest_path}")
with data_manifest_path.open("r") as pf:
data_jsonl = pf.readlines()
data_funcs = [
partial(
preprocess_datapoint,
i,
data_manifest_path.parent,
json.loads(v),
use_domain_asr,
annotation_only,
enable_plots,
)
for i, v in enumerate(data_jsonl)
]
def exec_func(f):
return f()
def asr_data_source_gen():
with data_manifest_path.open("r") as pf:
data_jsonl = pf.readlines()
for v in data_jsonl:
sample = json.loads(v)
rel_root = data_manifest_path.parent
res = dict(sample)
audio_path = rel_root / Path(sample["audio_filepath"])
audio_segment = (
AudioSegment.from_file_using_temporary_files(audio_path)
.set_channels(1)
.set_sample_width(2)
.set_frame_rate(24000)
)
wav_plot_path = (
rel_root
/ Path("wav_plots")
/ Path(audio_path.name).with_suffix(".png")
)
if not wav_plot_path.exists():
plot_seg(wav_plot_path, audio_path)
res["plot_path"] = str(wav_plot_path)
code_fb = BytesIO()
audio_segment.export(code_fb, format="wav")
wav_data = code_fb.getvalue()
duration = audio_segment.duration_seconds
asr_final = res["text"]
yield asr_final, duration, wav_data, "caller", audio_segment
with ThreadPoolExecutor() as exe:
print("starting all preprocess tasks")
data_final = filter(
None,
list(
tqdm(
exe.map(exec_func, data_funcs),
position=0,
leave=True,
total=len(data_funcs),
)
),
)
if annotation_only:
result = list(data_final)
else:
wer_key = "domain_wer" if use_domain_asr else "pretrained_wer"
result = sorted(data_final, key=lambda x: x[wer_key], reverse=True)
ui_config = {
"use_domain_asr": use_domain_asr,
"annotation_only": annotation_only,
"enable_plots": enable_plots,
"data": result,
}
ExtendedPath(dump_path).write_json(ui_config)
dump_data, num_datapoints = ui_data_generator(
dataset_dir, data_name, asr_data_source_gen()
)
ui_dump_file = dataset_dir / Path("ui_dump.json")
ExtendedPath(ui_dump_file).write_json({"data": dump_data})
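
dump_ui is reduced to an adapter: it re-reads the manifest, wraps each entry into the (transcript, duration, wav_bytes, caller, segment) tuple that ui_data_generator consumes — with "caller" a fixed placeholder — and lets the shared helper write the dump. A hedged sketch of driving it directly (module name and parameter order are assumptions; a transcription service must be listening on port 8044 for transcribe_gen):

```python
from pathlib import Path
from validate_asr import dump_ui  # module name is hypothetical

# direct call, bypassing the typer CLI wrapper
dump_ui(
    data_name="my_dataset",               # assumed leading parameter
    dataset_dir=Path("./data/asr_data"),  # expects my_dataset/manifest.json inside
)
```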
@app.command()
@@ -190,7 +160,9 @@ def dump_corrections(
col = get_mongo_conn(col="asr_validation")
task_id = [c for c in col.distinct("task_id") if c.rsplit("-", 1)[1] == task_uid][0]
corrections = list(col.find({"type": "correction"}, projection={"_id": False}))
cursor_obj = col.find({"type": "correction", "task_id": task_id}, projection={"_id": False})
cursor_obj = col.find(
{"type": "correction", "task_id": task_id}, projection={"_id": False}
)
corrections = [c for c in cursor_obj]
ExtendedPath(dump_path).write_json(corrections)
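
The corrections export is now scoped to the resolved task_id instead of fetching every correction document in the collection. An equivalent standalone pymongo query (connection details illustrative; the repo goes through its get_mongo_conn helper):

```python
from pymongo import MongoClient

col = MongoClient("mongodb://localhost:27017")["validation"]["asr_validation"]
task_id = "asr-task-ab12cd"  # illustrative
corrections = list(
    col.find({"type": "correction", "task_id": task_id}, projection={"_id": False})
)
```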
@@ -264,7 +236,9 @@ def split_extract(
dump_file: Path = Path("ui_dump.json"),
manifest_file: Path = Path("manifest.json"),
corrections_file: str = typer.Option("corrections.json", show_default=True),
conv_data_path: Path = typer.Option(Path("./data/conv_data.json"), show_default=True),
conv_data_path: Path = typer.Option(
Path("./data/conv_data.json"), show_default=True
),
extraction_type: str = "all",
):
import shutil
@@ -286,7 +260,9 @@ def split_extract(
def extract_manifest(mg):
for m in mg:
if m["text"] in extraction_vals:
shutil.copy(m["audio_path"], dest_data_dir / Path(m["audio_filepath"]))
shutil.copy(
m["audio_path"], dest_data_dir / Path(m["audio_filepath"])
)
yield m
asr_manifest_writer(dest_manifest_path, extract_manifest(manifest_gen))
@@ -295,12 +271,14 @@ def split_extract(
orig_ui_data = ExtendedPath(ui_data_path).read_json()
ui_data = orig_ui_data["data"]
file_ui_map = {Path(u["audio_filepath"]).stem: u for u in ui_data}
extracted_ui_data = list(filter(lambda u: u["text"] in extraction_vals, ui_data))
extracted_ui_data = list(
filter(lambda u: u["text"] in extraction_vals, ui_data)
)
final_data = []
for i, d in enumerate(extracted_ui_data):
d['real_idx'] = i
d["real_idx"] = i
final_data.append(d)
orig_ui_data['data'] = final_data
orig_ui_data["data"] = final_data
ExtendedPath(dest_ui_path).write_json(orig_ui_data)
if corrections_file:
@@ -316,7 +294,7 @@ def split_extract(
)
ExtendedPath(dest_correction_path).write_json(extracted_corrections)
if extraction_type.value == 'all':
if extraction_type.value == "all":
for ext_key in conv_data.keys():
extract_data_of_type(ext_key)
else:
@@ -338,7 +316,7 @@ def update_corrections(
def correct_manifest(ui_dump_path, corrections_path):
corrections = ExtendedPath(corrections_path).read_json()
ui_data = ExtendedPath(ui_dump_path).read_json()['data']
ui_data = ExtendedPath(ui_dump_path).read_json()["data"]
correct_set = {
c["code"] for c in corrections if c["value"]["status"] == "Correct"
}
@@ -367,7 +345,9 @@ def update_corrections(
)
else:
orig_audio_path = Path(d["audio_path"])
new_name = str(Path(tscript_uuid_fname(correct_text)).with_suffix(".wav"))
new_name = str(
Path(tscript_uuid_fname(correct_text)).with_suffix(".wav")
)
new_audio_path = orig_audio_path.with_name(new_name)
orig_audio_path.replace(new_audio_path)
new_filepath = str(Path(d["audio_filepath"]).with_name(new_name))
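
The rename above is pure pathlib: with_name swaps only the final path component, and Path.replace moves the file on disk, overwriting any existing target. A self-contained illustration:

```python
from pathlib import Path

src = Path("wav/old_utterance.wav")        # existing file
dst = src.with_name("new_utterance.wav")   # same directory, new filename
src.replace(dst)                           # rename on disk; overwrites dst if present
```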


@@ -72,22 +72,18 @@ def main(manifest: Path, task_id: str = ""):
st.set_task(manifest, task_id)
ui_config = load_ui_data(manifest)
asr_data = ui_config["data"]
use_domain_asr = ui_config.get("use_domain_asr", True)
annotation_only = ui_config.get("annotation_only", False)
enable_plots = ui_config.get("enable_plots", True)
sample_no = st.get_current_cursor()
if len(asr_data) - 1 < sample_no or sample_no < 0:
print("Invalid samplno resetting to 0")
st.update_cursor(0)
sample = asr_data[sample_no]
title_type = "Speller " if use_domain_asr else ""
task_uid = st.task_id.rsplit("-", 1)[1]
if annotation_only:
st.title(f"ASR Annotation - # {task_uid}")
else:
st.title(f"ASR {title_type}Validation - # {task_uid}")
addl_text = f"spelled *{sample['spoken']}*" if use_domain_asr else ""
st.markdown(f"{sample_no+1} of {len(asr_data)} : **{sample['text']}**" + addl_text)
st.title(f"ASR Validation - # {task_uid}")
st.markdown(f"{sample_no+1} of {len(asr_data)} : **{sample['text']}**")
new_sample = st.number_input(
"Go To Sample:", value=sample_no + 1, min_value=1, max_value=len(asr_data)
)
@@ -96,19 +92,13 @@ def main(manifest: Path, task_id: str = ""):
st.sidebar.title(f"Details: [{sample['real_idx']}]")
st.sidebar.markdown(f"Gold Text: **{sample['text']}**")
if not annotation_only:
if use_domain_asr:
st.sidebar.markdown(f"Expected Spelled: *{sample['spoken']}*")
st.sidebar.title("Results:")
st.sidebar.markdown(f"Pretrained: **{sample['pretrained_asr']}**")
if "caller" in sample:
st.sidebar.markdown(f"Caller: **{sample['caller']}**")
if use_domain_asr:
st.sidebar.markdown(f"Domain: **{sample['domain_asr']}**")
st.sidebar.title(f"Speller WER: {sample['domain_wer']:.2f}%")
else:
st.sidebar.title(f"Pretrained WER: {sample['pretrained_wer']:.2f}%")
if enable_plots:
st.sidebar.image(Path(sample["plot_path"]).read_bytes())
st.sidebar.image(Path(sample["plot_path"]).read_bytes())
st.audio(Path(sample["audio_path"]).open("rb"))
# set default to text
corrected = sample["text"]
@@ -130,16 +120,12 @@ def main(manifest: Path, task_id: str = ""):
)
st.update_cursor(sample_no + 1)
if correction_entry:
st.markdown(
f'Your Response: **{correction_entry["value"]["status"]}** Correction: **{correction_entry["value"]["correction"]}**'
)
status = correction_entry["value"]["status"]
correction = correction_entry["value"]["correction"]
st.markdown(f"Your Response: **{status}** Correction: **{correction}**")
text_sample = st.text_input("Go to Text:", value="")
if text_sample != "":
candidates = [
i
for (i, p) in enumerate(asr_data)
if p["text"] == text_sample or p["spoken"] == text_sample
]
candidates = [i for (i, p) in enumerate(asr_data) if p["text"] == text_sample]
if len(candidates) > 0:
st.update_cursor(candidates[0])
real_idx = st.number_input(