2020-05-06 06:48:34 +00:00
|
|
|
import json
|
|
|
|
|
import shutil
|
2020-05-12 18:08:06 +00:00
|
|
|
from pathlib import Path
|
2020-05-06 06:48:34 +00:00
|
|
|
|
2020-05-12 18:08:06 +00:00
|
|
|
import typer
|
|
|
|
|
from tqdm import tqdm
|
2020-05-06 06:48:34 +00:00
|
|
|
|
2020-05-12 18:08:06 +00:00
|
|
|
from ..utils import (
|
|
|
|
|
alnum_to_asr_tokens,
|
|
|
|
|
ExtendedPath,
|
|
|
|
|
asr_manifest_reader,
|
|
|
|
|
asr_manifest_writer,
|
|
|
|
|
get_mongo_conn,
|
2020-05-06 06:48:34 +00:00
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Typer application object; the functions decorated with @app.command()
# in this module are registered on it as CLI sub-commands.
app = typer.Typer()
|
|
|
|
|
|
|
|
|
|
|
2020-05-27 10:27:42 +00:00
|
|
|
def preprocess_datapoint(
    idx, rel_root, sample, use_domain_asr, annotation_only, enable_plots
):
    """Build the validation-UI record for one manifest entry.

    Args:
        idx: Position of the entry in the manifest; stored as ``real_idx``.
        rel_root: Directory that ``sample["audio_filepath"]`` is relative to.
            Waveform plots are written below ``rel_root / "wav_plots"``.
        sample: Manifest entry; ``audio_filepath`` and ``text`` are read.
        use_domain_asr: When true, ``spoken`` is ``text`` passed through
            ``alnum_to_asr_tokens`` and (outside annotation-only mode) the
            speller transcriber is run and scored as well.
        annotation_only: When true, skip transcription and WER scoring.
        enable_plots: When true, render a PNG waveform plot (if it does not
            exist yet) and store its location under ``plot_path``.

    Returns:
        A new ``dict`` with the fields of ``sample`` plus the derived ones,
        or ``None`` when processing failed (the failure is printed instead
        of being raised so a single bad entry does not stop the batch).
    """
    # Heavy optional dependencies are only imported for the modes that use
    # them, so an annotation-only / plot-less run does not require them.
    if not annotation_only:
        from pydub import AudioSegment
        from nemo.collections.asr.metrics import word_error_rate
    if enable_plots:
        import matplotlib.pyplot as plt
        import librosa
        import librosa.display

    try:
        res = dict(sample)
        res["real_idx"] = idx
        audio_path = rel_root / Path(sample["audio_filepath"])
        res["audio_path"] = str(audio_path)
        if use_domain_asr:
            res["spoken"] = alnum_to_asr_tokens(res["text"])
        else:
            res["spoken"] = res["text"]
        res["utterance_id"] = audio_path.stem

        if not annotation_only:
            from jasper.client import transcriber_pretrained, transcriber_speller

            # Normalise to mono / 16-bit / 24 kHz before handing the raw
            # samples to the transcribers.
            aud_seg = (
                AudioSegment.from_file_using_temporary_files(audio_path)
                .set_channels(1)
                .set_sample_width(2)
                .set_frame_rate(24000)
            )
            res["pretrained_asr"] = transcriber_pretrained(aud_seg.raw_data)
            res["pretrained_wer"] = word_error_rate(
                [res["text"]], [res["pretrained_asr"]]
            )
            if use_domain_asr:
                res["domain_asr"] = transcriber_speller(aud_seg.raw_data)
                # Score the domain model's own hypothesis (previously the
                # pretrained hypothesis was scored here by mistake).
                res["domain_wer"] = word_error_rate(
                    [res["spoken"]], [res["domain_asr"]]
                )

        if enable_plots:
            wav_plot_path = (
                rel_root / Path("wav_plots") / Path(audio_path.name).with_suffix(".png")
            )
            # Plots are cached on disk; only render the ones that are missing.
            if not wav_plot_path.exists():
                fig = plt.Figure()
                ax = fig.add_subplot()
                (y, sr) = librosa.load(audio_path)
                librosa.display.waveplot(y=y, sr=sr, ax=ax)
                with wav_plot_path.open("wb") as wav_plot_f:
                    fig.set_tight_layout(True)
                    fig.savefig(wav_plot_f, format="png", dpi=50)
            res["plot_path"] = str(wav_plot_path)
        return res
    except Exception as e:
        # Look the path up defensively: a missing key may be the very reason
        # we ended up here, and the handler itself must not raise.
        audio_ref = sample.get("audio_filepath") if isinstance(sample, dict) else None
        print(f"failed on {idx}: {audio_ref} with {e}")
        return None
|
2020-05-06 06:48:34 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.command()
def dump_validation_ui_data(
    data_manifest_path: Path = typer.Option(
        Path("./data/asr_data/call_alphanum/manifest.json"), show_default=True
    ),
    dump_path: Path = typer.Option(
        Path("./data/valiation_data/ui_dump.json"), show_default=True
    ),
    use_domain_asr: bool = True,
    annotation_only: bool = True,
    enable_plots: bool = True,
):
    """Preprocess every manifest entry and dump the data for the validation UI.

    Each non-blank line of ``data_manifest_path`` is parsed as JSON and passed
    to ``preprocess_datapoint`` on a thread pool. Entries for which
    preprocessing returned nothing are dropped. The surviving records are
    written to ``dump_path`` together with the three flags. Outside
    annotation-only mode the records are sorted by descending WER
    (``domain_wer`` when ``use_domain_asr`` is set, else ``pretrained_wer``).
    """
    from concurrent.futures import ThreadPoolExecutor
    from functools import partial

    plot_dir = data_manifest_path.parent / Path("wav_plots")
    plot_dir.mkdir(parents=True, exist_ok=True)
    typer.echo(f"Using data manifest:{data_manifest_path}")
    with data_manifest_path.open("r") as pf:
        pnr_jsonl = pf.readlines()

    # One zero-argument callable per manifest line. Blank lines are skipped
    # (json.loads would raise on them) while the original line index is kept
    # so ``real_idx`` still points at the manifest line.
    pnr_funcs = [
        partial(
            preprocess_datapoint,
            i,
            data_manifest_path.parent,
            json.loads(v),
            use_domain_asr,
            annotation_only,
            enable_plots,
        )
        for i, v in enumerate(pnr_jsonl)
        if v.strip()
    ]

    def exec_func(f):
        return f()

    with ThreadPoolExecutor() as exe:
        print("starting all preprocess tasks")
        progress = tqdm(
            exe.map(exec_func, pnr_funcs),
            position=0,
            leave=True,
            total=len(pnr_funcs),
        )
        # Failed entries come back falsy (None) and are filtered out here.
        pnr_data = [rec for rec in progress if rec]

    if annotation_only:
        result = pnr_data
    else:
        wer_key = "domain_wer" if use_domain_asr else "pretrained_wer"
        result = sorted(pnr_data, key=lambda x: x[wer_key], reverse=True)

    ui_config = {
        "use_domain_asr": use_domain_asr,
        "data": result,
        "annotation_only": annotation_only,
        "enable_plots": enable_plots,
    }
    ExtendedPath(dump_path).write_json(ui_config)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.command()
def dump_corrections(dump_path: Path = Path("./data/valiation_data/corrections.json")):
    """Write all ``correction`` documents of the validation collection to JSON.

    The documents are fetched from ``test.asr_validation`` with the ``_id``
    field projected away and written to ``dump_path``.
    """
    validation_col = get_mongo_conn().test.asr_validation
    found = validation_col.find({"type": "correction"}, projection={"_id": False})
    ExtendedPath(dump_path).write_json(list(found))
|
2020-05-06 06:48:34 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.command()
def fill_unannotated(
    processed_data_path: Path = Path("./data/valiation_data/ui_dump.json"),
    corrections_path: Path = Path("./data/valiation_data/corrections.json"),
):
    """Upsert an ``Inaudible`` correction for every code without a correction.

    Codes are collected from the processed UI dump, the codes already present
    in the corrections file are subtracted, and for each remaining code a
    ``correction`` document with status ``Inaudible`` and an empty correction
    is upserted into ``test.asr_validation``.
    """
    # Context managers make sure both files are closed after parsing.
    with processed_data_path.open() as processed_f:
        processed_data = json.load(processed_f)
    with corrections_path.open() as corrections_f:
        corrections = json.load(corrections_f)

    # dump_validation_ui_data writes a dict with the records under "data";
    # a bare list of records is accepted as well.
    if isinstance(processed_data, dict):
        records = processed_data["data"]
    else:
        records = processed_data

    annotated_codes = {c["code"] for c in corrections}
    # NOTE(review): records are keyed by "gold_chars" here, while
    # update_corrections matches corrections against "chars" - confirm which
    # key the dump actually carries.
    all_codes = {c["gold_chars"] for c in records}
    unann_codes = all_codes - annotated_codes

    mongo_conn = get_mongo_conn().test.asr_validation
    for c in unann_codes:
        mongo_conn.find_one_and_update(
            {"type": "correction", "code": c},
            {"$set": {"value": {"status": "Inaudible", "correction": ""}}},
            upsert=True,
        )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.command()
def update_corrections(
    data_manifest_path: Path = Path("./data/asr_data/call_alphanum/manifest.json"),
    corrections_path: Path = Path("./data/valiation_data/corrections.json"),
    skip_incorrect: bool = True,
):
    """Apply reviewed corrections to the dataset manifest and its audio files.

    The dataset directory is copied to a sibling ``<name>.bkp`` directory
    once (only if that backup does not exist yet). The manifest is then
    rewritten through ``manifest.new`` and moved over the original.

    Per manifest entry, matched on its ``chars`` value:
      * status ``Correct``   - the entry is kept unchanged;
      * status ``Incorrect`` - with ``skip_incorrect`` the entry is dropped
        from the manifest (its audio file is left in place); otherwise the
        audio file is renamed after the corrected text and the entry is
        emitted with the new path and text;
      * anything else        - the audio file is deleted, unless an earlier
        correction was renamed to this very name.
    """

    def correct_manifest(manifest_data_gen, corrections_path):
        # Close the corrections file as soon as it has been parsed.
        with corrections_path.open() as corrections_f:
            corrections = json.load(corrections_f)
        correct_set = {
            c["code"] for c in corrections if c["value"]["status"] == "Correct"
        }
        correction_map = {
            c["code"]: c["value"]["correction"]
            for c in corrections
            if c["value"]["status"] == "Incorrect"
        }
        # Corrected texts that already had a file renamed onto them.
        renamed_set = set()
        for d in manifest_data_gen:
            if d["chars"] in correct_set:
                yield {
                    "audio_filepath": d["audio_filepath"],
                    "duration": d["duration"],
                    "text": d["text"],
                }
            elif d["chars"] in correction_map:
                correct_text = correction_map[d["chars"]]
                if skip_incorrect:
                    print(
                        f'skipping incorrect {d["audio_path"]} corrected to {correct_text}'
                    )
                else:
                    renamed_set.add(correct_text)
                    new_name = str(Path(correct_text).with_suffix(".wav"))
                    d["audio_path"].replace(d["audio_path"].with_name(new_name))
                    new_filepath = str(Path(d["audio_filepath"]).with_name(new_name))
                    yield {
                        "audio_filepath": new_filepath,
                        "duration": d["duration"],
                        "text": alnum_to_asr_tokens(correct_text),
                    }
            else:
                # don't delete if another correction points to an old file
                if d["chars"] not in renamed_set:
                    d["audio_path"].unlink()
                else:
                    print(f'skipping deletion of correction:{d["chars"]}')

    typer.echo(f"Using data manifest:{data_manifest_path}")
    dataset_dir = data_manifest_path.parent
    dataset_name = dataset_dir.name
    backup_dir = dataset_dir.with_name(dataset_name + ".bkp")
    if not backup_dir.exists():
        typer.echo(f"backing up to :{backup_dir}")
        shutil.copytree(str(dataset_dir), str(backup_dir))

    manifest_gen = asr_manifest_reader(data_manifest_path)
    corrected_manifest = correct_manifest(manifest_gen, corrections_path)
    # Write to a temporary sibling first, then swap it in for the original.
    new_data_manifest_path = data_manifest_path.with_name("manifest.new")
    asr_manifest_writer(new_data_manifest_path, corrected_manifest)
    new_data_manifest_path.replace(data_manifest_path)
|
|
|
|
|
|
|
|
|
|
|
2020-05-12 18:08:06 +00:00
|
|
|
@app.command()
def clear_mongo_corrections():
    """Delete all ``correction`` documents from ``test.asr_validation``.

    The user is asked for confirmation first. Declining prints "Aborted"
    and leaves the collection untouched.
    """
    delete = typer.confirm("are you sure you want to clear mongo collection it?")
    if not delete:
        # Only report an abort when nothing was deleted.
        typer.echo("Aborted")
        return
    col = get_mongo_conn().test.asr_validation
    col.delete_many({"type": "correction"})
    typer.echo("deleted mongo collection.")
|
2020-05-12 18:08:06 +00:00
|
|
|
|
|
|
|
|
|
2020-05-06 06:48:34 +00:00
|
|
|
def main():
    """Entry point: invoke the Typer application defined in this module."""
    app()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Run the CLI when this file is executed as the main module.
# NOTE(review): the module uses a relative import (``from ..utils import``),
# so it presumably has to be started with ``python -m`` from the package
# root rather than by file path - confirm.
if __name__ == "__main__":
    main()
|