1. refactored ui_dump

2. added flake8
tegra
Malar Kannan 2020-08-09 19:16:35 +05:30
parent 42647196fe
commit e8f58a5043
4 changed files with 154 additions and 179 deletions

.flake8 (new file, +4)

@@ -0,0 +1,4 @@
+[flake8]
+exclude = docs
+ignore = E203, W503
+max-line-length = 119
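The two ignored checks are the usual Black-compatibility pair: E203 (whitespace before ':') and W503 (line break before a binary operator). A minimal sketch of running the linter against this config, assuming flake8 is installed in the environment; the subprocess wrapper below is only illustrative, not part of this repo:

import subprocess

# flake8 discovers the .flake8 file in the directory it runs from, so
# invoking it at the repo root applies the exclude/ignore/max-line-length above.
result = subprocess.run(["flake8", "."], capture_output=True, text=True)
print(result.stdout or "no lint errors")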

File 2 of 4

@@ -58,83 +58,88 @@ def asr_data_writer(output_dir, dataset_name, asr_data_source, verbose=False):
     return num_datapoints


-def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=False):
+def ui_data_generator(output_dir, dataset_name, asr_data_source, verbose=False):
     dataset_dir = output_dir / Path(dataset_name)
     (dataset_dir / Path("wav")).mkdir(parents=True, exist_ok=True)
-    ui_dump_file = dataset_dir / Path("ui_dump.json")
     (dataset_dir / Path("wav_plots")).mkdir(parents=True, exist_ok=True)
-    asr_manifest = dataset_dir / Path("manifest.json")
+
+    def data_fn(
+        transcript,
+        audio_dur,
+        wav_data,
+        caller_name,
+        aud_seg,
+        fname,
+        audio_path,
+        num_datapoints,
+        rel_data_path,
+    ):
+        pretrained_result = transcriber_pretrained(aud_seg.raw_data)
+        pretrained_wer = word_error_rate([transcript], [pretrained_result])
+        png_path = Path(fname).with_suffix(".png")
+        wav_plot_path = dataset_dir / Path("wav_plots") / png_path
+        if not wav_plot_path.exists():
+            plot_seg(wav_plot_path, audio_path)
+        return {
+            "audio_filepath": str(rel_data_path),
+            "duration": round(audio_dur, 1),
+            "text": transcript,
+            "real_idx": num_datapoints,
+            "audio_path": audio_path,
+            "spoken": transcript,
+            "caller": caller_name,
+            "utterance_id": fname,
+            "pretrained_asr": pretrained_result,
+            "pretrained_wer": pretrained_wer,
+            "plot_path": str(wav_plot_path),
+        }
+
     num_datapoints = 0
-    ui_dump = {
-        "use_domain_asr": False,
-        "annotation_only": False,
-        "enable_plots": True,
-        "data": [],
-    }
     data_funcs = []
     transcriber_pretrained = transcribe_gen(asr_port=8044)
+    for transcript, audio_dur, wav_data, caller_name, aud_seg in asr_data_source:
+        fname = str(uuid4()) + "_" + slugify(transcript, max_length=8)
+        audio_file = dataset_dir / Path("wav") / Path(fname).with_suffix(".wav")
+        audio_file.write_bytes(wav_data)
+        audio_path = str(audio_file)
+        rel_data_path = audio_file.relative_to(dataset_dir)
+        data_funcs.append(
+            partial(
+                data_fn,
+                transcript,
+                audio_dur,
+                wav_data,
+                caller_name,
+                aud_seg,
+                fname,
+                audio_path,
+                num_datapoints,
+                rel_data_path,
+            )
+        )
+        num_datapoints += 1
+    ui_data = parallel_apply(lambda x: x(), data_funcs)
+    return ui_data, num_datapoints
+
+
+def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=False):
+    dataset_dir = output_dir / Path(dataset_name)
+    dump_data, num_datapoints = ui_data_generator(
+        output_dir, dataset_name, asr_data_source, verbose=verbose
+    )
+    asr_manifest = dataset_dir / Path("manifest.json")
     with asr_manifest.open("w") as mf:
         print(f"writing manifest to {asr_manifest}")
-
-        def data_fn(
-            transcript,
-            audio_dur,
-            wav_data,
-            caller_name,
-            aud_seg,
-            fname,
-            audio_path,
-            num_datapoints,
-            rel_data_path,
-        ):
-            pretrained_result = transcriber_pretrained(aud_seg.raw_data)
-            pretrained_wer = word_error_rate([transcript], [pretrained_result])
-            wav_plot_path = (
-                dataset_dir / Path("wav_plots") / Path(fname).with_suffix(".png")
-            )
-            if not wav_plot_path.exists():
-                plot_seg(wav_plot_path, audio_path)
-            return {
-                "audio_filepath": str(rel_data_path),
-                "duration": round(audio_dur, 1),
-                "text": transcript,
-                "real_idx": num_datapoints,
-                "audio_path": audio_path,
-                "spoken": transcript,
-                "caller": caller_name,
-                "utterance_id": fname,
-                "pretrained_asr": pretrained_result,
-                "pretrained_wer": pretrained_wer,
-                "plot_path": str(wav_plot_path),
-            }
-
-        for transcript, audio_dur, wav_data, caller_name, aud_seg in asr_data_source:
-            fname = str(uuid4()) + "_" + slugify(transcript, max_length=8)
-            audio_file = dataset_dir / Path("wav") / Path(fname).with_suffix(".wav")
-            audio_file.write_bytes(wav_data)
-            audio_path = str(audio_file)
-            rel_data_path = audio_file.relative_to(dataset_dir)
+        for d in dump_data:
+            rel_data_path = d["audio_filepath"]
+            audio_dur = d["duration"]
+            transcript = d["text"]
             manifest = manifest_str(str(rel_data_path), audio_dur, transcript)
             mf.write(manifest)
-            data_funcs.append(
-                partial(
-                    data_fn,
-                    transcript,
-                    audio_dur,
-                    wav_data,
-                    caller_name,
-                    aud_seg,
-                    fname,
-                    audio_path,
-                    num_datapoints,
-                    rel_data_path,
-                )
-            )
-            num_datapoints += 1
-    dump_data = parallel_apply(lambda x: x(), data_funcs)
-    # dump_data = [x() for x in tqdm(data_funcs)]
-    ui_dump["data"] = dump_data
-    ExtendedPath(ui_dump_file).write_json(ui_dump)
+    ui_dump_file = dataset_dir / Path("ui_dump.json")
+    ExtendedPath(ui_dump_file).write_json({"data": dump_data})
     return num_datapoints
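The net effect of this hunk: the old single-pass writer is split into ui_data_generator, which consumes (transcript, audio_dur, wav_data, caller_name, aud_seg) tuples and returns the UI records plus a count, and a thin ui_dump_manifest_writer that rebuilds manifest.json from the already-computed records instead of recomputing anything. A hypothetical driver, assuming a pydub-readable sample.wav, that the two functions above are importable, and that the pretrained ASR service is reachable on port 8044 as in the code; the demo_source generator and file names are illustrative only:

from pathlib import Path
from pydub import AudioSegment

def demo_source():
    # Yields the 5-tuple shape ui_data_generator expects:
    # (transcript, duration_seconds, raw_wav_bytes, caller_name, AudioSegment)
    seg = AudioSegment.from_file("sample.wav")  # illustrative input file
    wav_bytes = seg.export(format="wav").read()
    yield "hello world", seg.duration_seconds, wav_bytes, "caller-42", seg

# Writes wav/, wav_plots/, manifest.json and ui_dump.json under ./out/demo_set.
num = ui_dump_manifest_writer(Path("./out"), "demo_set", demo_source())
print(f"wrote {num} datapoints")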

File 3 of 4

@@ -1,13 +1,10 @@
 import json
 import shutil
 from pathlib import Path
-from enum import Enum

 import typer
-from tqdm import tqdm

 from ..utils import (
-    alnum_to_asr_tokens,
     ExtendedPath,
     asr_manifest_reader,
     asr_manifest_writer,
@@ -19,9 +16,7 @@ from ..utils import (
 app = typer.Typer()


-def preprocess_datapoint(
-    idx, rel_root, sample, use_domain_asr, annotation_only, enable_plots
-):
+def preprocess_datapoint(idx, rel_root, sample):
     from pydub import AudioSegment
     from nemo.collections.asr.metrics import word_error_rate
     from jasper.client import transcribe_gen
@@ -31,37 +26,23 @@ def preprocess_datapoint(
         res["real_idx"] = idx
         audio_path = rel_root / Path(sample["audio_filepath"])
         res["audio_path"] = str(audio_path)
-        if use_domain_asr:
-            res["spoken"] = alnum_to_asr_tokens(res["text"])
-        else:
-            res["spoken"] = res["text"]
         res["utterance_id"] = audio_path.stem
-        if not annotation_only:
-            transcriber_pretrained = transcribe_gen(asr_port=8044)
-            aud_seg = (
-                AudioSegment.from_file_using_temporary_files(audio_path)
-                .set_channels(1)
-                .set_sample_width(2)
-                .set_frame_rate(24000)
-            )
-            res["pretrained_asr"] = transcriber_pretrained(aud_seg.raw_data)
-            res["pretrained_wer"] = word_error_rate(
-                [res["text"]], [res["pretrained_asr"]]
-            )
-            if use_domain_asr:
-                transcriber_speller = transcribe_gen(asr_port=8045)
-                res["domain_asr"] = transcriber_speller(aud_seg.raw_data)
-                res["domain_wer"] = word_error_rate(
-                    [res["spoken"]], [res["pretrained_asr"]]
-                )
-        if enable_plots:
-            wav_plot_path = (
-                rel_root / Path("wav_plots") / Path(audio_path.name).with_suffix(".png")
-            )
-            if not wav_plot_path.exists():
-                plot_seg(wav_plot_path, audio_path)
-            res["plot_path"] = str(wav_plot_path)
+        transcriber_pretrained = transcribe_gen(asr_port=8044)
+        aud_seg = (
+            AudioSegment.from_file_using_temporary_files(audio_path)
+            .set_channels(1)
+            .set_sample_width(2)
+            .set_frame_rate(24000)
+        )
+        res["pretrained_asr"] = transcriber_pretrained(aud_seg.raw_data)
+        res["pretrained_wer"] = word_error_rate([res["text"]], [res["pretrained_asr"]])
+        wav_plot_path = (
+            rel_root / Path("wav_plots") / Path(audio_path.name).with_suffix(".png")
+        )
+        if not wav_plot_path.exists():
+            plot_seg(wav_plot_path, audio_path)
+        res["plot_path"] = str(wav_plot_path)
         return res
     except BaseException as e:
         print(f'failed on {idx}: {sample["audio_filepath"]} with {e}')
@@ -73,61 +54,50 @@ def dump_ui(
     dataset_dir: Path = Path("./data/asr_data"),
     dump_dir: Path = Path("./data/valiation_data"),
    dump_fname: Path = typer.Option(Path("ui_dump.json"), show_default=True),
-    use_domain_asr: bool = False,
-    annotation_only: bool = False,
-    enable_plots: bool = True,
 ):
-    from concurrent.futures import ThreadPoolExecutor
-    from functools import partial
+    from io import BytesIO
+    from pydub import AudioSegment
+    from ..utils import ui_data_generator

     data_manifest_path = dataset_dir / Path(data_name) / Path("manifest.json")
-    dump_path: Path = dump_dir / Path(data_name) / dump_fname
     plot_dir = data_manifest_path.parent / Path("wav_plots")
     plot_dir.mkdir(parents=True, exist_ok=True)
     typer.echo(f"Using data manifest:{data_manifest_path}")
-    with data_manifest_path.open("r") as pf:
-        data_jsonl = pf.readlines()
-    data_funcs = [
-        partial(
-            preprocess_datapoint,
-            i,
-            data_manifest_path.parent,
-            json.loads(v),
-            use_domain_asr,
-            annotation_only,
-            enable_plots,
-        )
-        for i, v in enumerate(data_jsonl)
-    ]

-    def exec_func(f):
-        return f()
+    def asr_data_source_gen():
+        with data_manifest_path.open("r") as pf:
+            data_jsonl = pf.readlines()
+        for v in data_jsonl:
+            sample = json.loads(v)
+            rel_root = data_manifest_path.parent
+            res = dict(sample)
+            audio_path = rel_root / Path(sample["audio_filepath"])
+            audio_segment = (
+                AudioSegment.from_file_using_temporary_files(audio_path)
+                .set_channels(1)
+                .set_sample_width(2)
+                .set_frame_rate(24000)
+            )
+            wav_plot_path = (
+                rel_root
+                / Path("wav_plots")
+                / Path(audio_path.name).with_suffix(".png")
+            )
+            if not wav_plot_path.exists():
+                plot_seg(wav_plot_path, audio_path)
+            res["plot_path"] = str(wav_plot_path)
+            code_fb = BytesIO()
+            audio_segment.export(code_fb, format="wav")
+            wav_data = code_fb.getvalue()
+            duration = audio_segment.duration_seconds
+            asr_final = res["text"]
+            yield asr_final, duration, wav_data, "caller", audio_segment

-    with ThreadPoolExecutor() as exe:
-        print("starting all preprocess tasks")
-        data_final = filter(
-            None,
-            list(
-                tqdm(
-                    exe.map(exec_func, data_funcs),
-                    position=0,
-                    leave=True,
-                    total=len(data_funcs),
-                )
-            ),
-        )
-    if annotation_only:
-        result = list(data_final)
-    else:
-        wer_key = "domain_wer" if use_domain_asr else "pretrained_wer"
-        result = sorted(data_final, key=lambda x: x[wer_key], reverse=True)
-    ui_config = {
-        "use_domain_asr": use_domain_asr,
-        "annotation_only": annotation_only,
-        "enable_plots": enable_plots,
-        "data": result,
-    }
-    ExtendedPath(dump_path).write_json(ui_config)
+    dump_data, num_datapoints = ui_data_generator(
+        dataset_dir, data_name, asr_data_source_gen()
+    )
+    ui_dump_file = dataset_dir / Path("ui_dump.json")
+    ExtendedPath(ui_dump_file).write_json({"data": dump_data})


 @app.command()
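dump_ui now adapts an existing manifest.json into the same generator contract: each clip is re-encoded in memory to mono, 16-bit, 24 kHz WAV and yielded as the 5-tuple (with a hard-coded "caller" placeholder), so the CLI path and the dataset-writing path share one ui_dump.json code path. A hedged sketch of driving the command programmatically through typer's test runner; the command name and positional argument are inferred from the diff and may differ in the actual module:

from typer.testing import CliRunner

runner = CliRunner()
# `app` is this module's typer.Typer() instance; typer exposes dump_ui
# as the "dump-ui" subcommand by default. "demo_set" is illustrative.
result = runner.invoke(app, ["dump-ui", "demo_set"])
print(result.output)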
@@ -190,7 +160,9 @@ def dump_corrections(
     col = get_mongo_conn(col="asr_validation")
     task_id = [c for c in col.distinct("task_id") if c.rsplit("-", 1)[1] == task_uid][0]
     corrections = list(col.find({"type": "correction"}, projection={"_id": False}))
-    cursor_obj = col.find({"type": "correction", "task_id": task_id}, projection={"_id": False})
+    cursor_obj = col.find(
+        {"type": "correction", "task_id": task_id}, projection={"_id": False}
+    )
     corrections = [c for c in cursor_obj]
     ExtendedPath(dump_path).write_json(corrections)
@@ -264,7 +236,9 @@ def split_extract(
     dump_file: Path = Path("ui_dump.json"),
     manifest_file: Path = Path("manifest.json"),
     corrections_file: str = typer.Option("corrections.json", show_default=True),
-    conv_data_path: Path = typer.Option(Path("./data/conv_data.json"), show_default=True),
+    conv_data_path: Path = typer.Option(
+        Path("./data/conv_data.json"), show_default=True
+    ),
     extraction_type: str = "all",
 ):
     import shutil
@@ -286,7 +260,9 @@ def split_extract(
     def extract_manifest(mg):
         for m in mg:
             if m["text"] in extraction_vals:
-                shutil.copy(m["audio_path"], dest_data_dir / Path(m["audio_filepath"]))
+                shutil.copy(
+                    m["audio_path"], dest_data_dir / Path(m["audio_filepath"])
+                )
                 yield m

     asr_manifest_writer(dest_manifest_path, extract_manifest(manifest_gen))
@@ -295,12 +271,14 @@ def split_extract(
     orig_ui_data = ExtendedPath(ui_data_path).read_json()
     ui_data = orig_ui_data["data"]
     file_ui_map = {Path(u["audio_filepath"]).stem: u for u in ui_data}
-    extracted_ui_data = list(filter(lambda u: u["text"] in extraction_vals, ui_data))
+    extracted_ui_data = list(
+        filter(lambda u: u["text"] in extraction_vals, ui_data)
+    )
     final_data = []
     for i, d in enumerate(extracted_ui_data):
-        d['real_idx'] = i
+        d["real_idx"] = i
         final_data.append(d)
-    orig_ui_data['data'] = final_data
+    orig_ui_data["data"] = final_data
     ExtendedPath(dest_ui_path).write_json(orig_ui_data)

     if corrections_file:
@@ -316,7 +294,7 @@ def split_extract(
     )
     ExtendedPath(dest_correction_path).write_json(extracted_corrections)

-    if extraction_type.value == 'all':
+    if extraction_type.value == "all":
         for ext_key in conv_data.keys():
             extract_data_of_type(ext_key)
     else:
@@ -338,7 +316,7 @@ def update_corrections(
 def correct_manifest(ui_dump_path, corrections_path):
     corrections = ExtendedPath(corrections_path).read_json()
-    ui_data = ExtendedPath(ui_dump_path).read_json()['data']
+    ui_data = ExtendedPath(ui_dump_path).read_json()["data"]
     correct_set = {
         c["code"] for c in corrections if c["value"]["status"] == "Correct"
     }
@@ -367,7 +345,9 @@ def update_corrections(
             )
         else:
             orig_audio_path = Path(d["audio_path"])
-            new_name = str(Path(tscript_uuid_fname(correct_text)).with_suffix(".wav"))
+            new_name = str(
+                Path(tscript_uuid_fname(correct_text)).with_suffix(".wav")
+            )
             new_audio_path = orig_audio_path.with_name(new_name)
             orig_audio_path.replace(new_audio_path)
             new_filepath = str(Path(d["audio_filepath"]).with_name(new_name))

File 4 of 4

@@ -72,22 +72,18 @@ def main(manifest: Path, task_id: str = ""):
     st.set_task(manifest, task_id)
     ui_config = load_ui_data(manifest)
     asr_data = ui_config["data"]
-    use_domain_asr = ui_config.get("use_domain_asr", True)
     annotation_only = ui_config.get("annotation_only", False)
-    enable_plots = ui_config.get("enable_plots", True)
     sample_no = st.get_current_cursor()
     if len(asr_data) - 1 < sample_no or sample_no < 0:
         print("Invalid samplno resetting to 0")
         st.update_cursor(0)
     sample = asr_data[sample_no]
-    title_type = "Speller " if use_domain_asr else ""
     task_uid = st.task_id.rsplit("-", 1)[1]
     if annotation_only:
         st.title(f"ASR Annotation - # {task_uid}")
     else:
-        st.title(f"ASR {title_type}Validation - # {task_uid}")
-    addl_text = f"spelled *{sample['spoken']}*" if use_domain_asr else ""
-    st.markdown(f"{sample_no+1} of {len(asr_data)} : **{sample['text']}**" + addl_text)
+        st.title(f"ASR Validation - # {task_uid}")
+    st.markdown(f"{sample_no+1} of {len(asr_data)} : **{sample['text']}**")
     new_sample = st.number_input(
         "Go To Sample:", value=sample_no + 1, min_value=1, max_value=len(asr_data)
     )
@@ -96,19 +92,13 @@ def main(manifest: Path, task_id: str = ""):
     st.sidebar.title(f"Details: [{sample['real_idx']}]")
     st.sidebar.markdown(f"Gold Text: **{sample['text']}**")
     if not annotation_only:
-        if use_domain_asr:
-            st.sidebar.markdown(f"Expected Spelled: *{sample['spoken']}*")
         st.sidebar.title("Results:")
         st.sidebar.markdown(f"Pretrained: **{sample['pretrained_asr']}**")
         if "caller" in sample:
             st.sidebar.markdown(f"Caller: **{sample['caller']}**")
-        if use_domain_asr:
-            st.sidebar.markdown(f"Domain: **{sample['domain_asr']}**")
-            st.sidebar.title(f"Speller WER: {sample['domain_wer']:.2f}%")
         else:
             st.sidebar.title(f"Pretrained WER: {sample['pretrained_wer']:.2f}%")
-    if enable_plots:
-        st.sidebar.image(Path(sample["plot_path"]).read_bytes())
+    st.sidebar.image(Path(sample["plot_path"]).read_bytes())
     st.audio(Path(sample["audio_path"]).open("rb"))
     # set default to text
     corrected = sample["text"]
@@ -130,16 +120,12 @@ def main(manifest: Path, task_id: str = ""):
         )
         st.update_cursor(sample_no + 1)
     if correction_entry:
-        st.markdown(
-            f'Your Response: **{correction_entry["value"]["status"]}** Correction: **{correction_entry["value"]["correction"]}**'
-        )
+        status = correction_entry["value"]["status"]
+        correction = correction_entry["value"]["correction"]
+        st.markdown(f"Your Response: **{status}** Correction: **{correction}**")
     text_sample = st.text_input("Go to Text:", value="")
     if text_sample != "":
-        candidates = [
-            i
-            for (i, p) in enumerate(asr_data)
-            if p["text"] == text_sample or p["spoken"] == text_sample
-        ]
+        candidates = [i for (i, p) in enumerate(asr_data) if p["text"] == text_sample]
         if len(candidates) > 0:
             st.update_cursor(candidates[0])
     real_idx = st.number_input(
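For reference, the correction documents that this UI, dump_corrections, and correct_manifest all agree on carry a nested "value" dict with "status" and "correction" keys. A record consistent with the keys read above; every value here is illustrative, the real documents live in the "asr_validation" Mongo collection:

# Hypothetical correction document, matching the keys the code reads.
correction_entry = {
    "type": "correction",
    "task_id": "asr-demo-abc123",  # illustrative task id
    "code": "utterance-uuid",      # illustrative utterance key
    "value": {"status": "Correct", "correction": "hello world"},
}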