import json
import shutil
from pathlib import Path
from enum import Enum

import typer
from tqdm import tqdm

from ..utils import (
    alnum_to_asr_tokens,
    ExtendedPath,
    asr_manifest_reader,
    asr_manifest_writer,
    get_mongo_conn,
)

app = typer.Typer()
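
# Example invocations (illustrative; assumes this Typer app is exposed as a
# console script, here called `asr-validate`, and run from the project root):
#   asr-validate dump-ui --data-name call_alphanum
#   asr-validate dump-corrections
#   asr-validate update-corrections --no-skip-incorrect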


def preprocess_datapoint(
    idx, rel_root, sample, use_domain_asr, annotation_only, enable_plots
):
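    """Build one UI datapoint from a manifest entry: resolve the audio path,
    optionally run pretrained/domain ASR and compute WER, and optionally
    render a waveform plot. Returns None (via the except path) on failure."""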
    import matplotlib.pyplot as plt
    import librosa
    import librosa.display
    from pydub import AudioSegment
    from nemo.collections.asr.metrics import word_error_rate
    from jasper.client import transcribe_gen

    try:
        res = dict(sample)
        res["real_idx"] = idx
        audio_path = rel_root / Path(sample["audio_filepath"])
        res["audio_path"] = str(audio_path)
        if use_domain_asr:
            res["spoken"] = alnum_to_asr_tokens(res["text"])
        else:
            res["spoken"] = res["text"]
        res["utterance_id"] = audio_path.stem
        if not annotation_only:
            transcriber_pretrained = transcribe_gen(asr_port=8044)

            aud_seg = (
                AudioSegment.from_file_using_temporary_files(audio_path)
                .set_channels(1)
                .set_sample_width(2)
                .set_frame_rate(24000)
            )
            res["pretrained_asr"] = transcriber_pretrained(aud_seg.raw_data)
            res["pretrained_wer"] = word_error_rate(
                [res["text"]], [res["pretrained_asr"]]
            )
            if use_domain_asr:
                transcriber_speller = transcribe_gen(asr_port=8045)
                res["domain_asr"] = transcriber_speller(aud_seg.raw_data)
                # score the domain transcript (not the pretrained one)
                # against the tokenized reference
                res["domain_wer"] = word_error_rate(
                    [res["spoken"]], [res["domain_asr"]]
                )
        if enable_plots:
            wav_plot_path = (
                rel_root / Path("wav_plots") / Path(audio_path.name).with_suffix(".png")
            )
            if not wav_plot_path.exists():
                fig = plt.Figure()
                ax = fig.add_subplot()
                (y, sr) = librosa.load(audio_path)
                librosa.display.waveplot(y=y, sr=sr, ax=ax)
                with wav_plot_path.open("wb") as wav_plot_f:
                    fig.set_tight_layout(True)
                    fig.savefig(wav_plot_f, format="png", dpi=50)
                # fig.close()
            res["plot_path"] = str(wav_plot_path)
        return res
    except BaseException as e:
        print(f'failed on {idx}: {sample["audio_filepath"]} with {e}')


@app.command()
def dump_ui(
    data_name: str = typer.Option("call_alphanum", show_default=True),
    dataset_dir: Path = Path("./data/asr_data"),
    dump_dir: Path = Path("./data/validation_data"),
    dump_fname: Path = typer.Option(Path("ui_dump.json"), show_default=True),
    use_domain_asr: bool = False,
    annotation_only: bool = False,
    enable_plots: bool = True,
):
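    """Preprocess every manifest entry in parallel and dump a JSON blob that
    drives the validation UI, sorted worst-WER-first unless running in
    annotation-only mode."""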
    from concurrent.futures import ThreadPoolExecutor
    from functools import partial

    data_manifest_path = dataset_dir / Path(data_name) / Path("manifest.json")
    dump_path: Path = dump_dir / Path(data_name) / dump_fname
    plot_dir = data_manifest_path.parent / Path("wav_plots")
    plot_dir.mkdir(parents=True, exist_ok=True)
    typer.echo(f"Using data manifest:{data_manifest_path}")
    with data_manifest_path.open("r") as pf:
        pnr_jsonl = pf.readlines()
    pnr_funcs = [
        partial(
            preprocess_datapoint,
            i,
            data_manifest_path.parent,
            json.loads(v),
            use_domain_asr,
            annotation_only,
            enable_plots,
        )
        for i, v in enumerate(pnr_jsonl)
    ]

    def exec_func(f):
        return f()

    with ThreadPoolExecutor() as exe:
        print("starting all preprocess tasks")
        # filter(None, ...) drops datapoints that failed preprocessing,
        # since preprocess_datapoint returns None on error
        pnr_data = filter(
            None,
            list(
                tqdm(
                    exe.map(exec_func, pnr_funcs),
                    position=0,
                    leave=True,
                    total=len(pnr_funcs),
                )
            ),
        )
    if annotation_only:
        result = list(pnr_data)
    else:
        wer_key = "domain_wer" if use_domain_asr else "pretrained_wer"
        result = sorted(pnr_data, key=lambda x: x[wer_key], reverse=True)
    ui_config = {
        "use_domain_asr": use_domain_asr,
        "data": result,
        "annotation_only": annotation_only,
        "enable_plots": enable_plots,
    }
    ExtendedPath(dump_path).write_json(ui_config)


@app.command()
def dump_corrections(
    data_name: str = typer.Option("call_alphanum", show_default=True),
    dump_dir: Path = Path("./data/validation_data"),
    dump_fname: Path = Path("corrections.json"),
):
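    """Export all correction records from mongo to a JSON file."""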
    dump_path = dump_dir / Path(data_name) / dump_fname
    col = get_mongo_conn(col="asr_validation")

    cursor_obj = col.find({"type": "correction"}, projection={"_id": False})
    corrections = list(cursor_obj)
    ExtendedPath(dump_path).write_json(corrections)
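
# Shape of a correction record, as read and written by the commands below
# (fields inferred from usage in this module; values illustrative):
#   {"type": "correction", "code": "<utterance_id>",
#    "value": {"status": "Correct" | "Incorrect" | "Inaudible",
#              "correction": "<corrected text>"}}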


@app.command()
def fill_unannotated(
    data_name: str = typer.Option("call_alphanum", show_default=True),
    dump_dir: Path = Path("./data/validation_data"),
    dump_file: Path = Path("ui_dump.json"),
    corrections_file: Path = Path("corrections.json"),
):
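    """Mark every code present in the UI dump but missing from corrections as
    'Inaudible', upserting a correction record for it in mongo."""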
    processed_data_path = dump_dir / Path(data_name) / dump_file
    corrections_path = dump_dir / Path(data_name) / corrections_file
    # the UI dump stores its datapoints under the "data" key (see dump_ui)
    processed_data = json.load(processed_data_path.open())["data"]
    corrections = json.load(corrections_path.open())
    annotated_codes = {c["code"] for c in corrections}
    all_codes = {c["gold_chars"] for c in processed_data}
    unann_codes = all_codes - annotated_codes
    mongo_conn = get_mongo_conn(col="asr_validation")
    for c in unann_codes:
        mongo_conn.find_one_and_update(
            {"type": "correction", "code": c},
            {"$set": {"value": {"status": "Inaudible", "correction": ""}}},
            upsert=True,
        )


class ExtractionType(str, Enum):
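    """Values name the conv_data.json keys that split_extract filters on."""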
    date = "dates"
    city = "cities"
    name = "names"


@app.command()
def split_extract(
    data_name: str = typer.Option("call_alphanum", show_default=True),
    # dest_data_name: str = typer.Option("call_aldata_namephanum_date", show_default=True),
    dump_dir: Path = Path("./data/validation_data"),
    dump_file: Path = Path("ui_dump.json"),
    manifest_dir: Path = Path("./data/asr_data"),
    manifest_file: Path = Path("manifest.json"),
    corrections_file: Path = Path("corrections.json"),
    conv_data_path: Path = Path("./data/conv_data.json"),
    extraction_type: ExtractionType = ExtractionType.date,
):
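    """Carve out a sub-dataset (audio, manifest, UI dump, and corrections)
    containing only utterances whose text matches the chosen extraction type
    (dates, cities, or names) from conv_data.json."""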
    def get_conv_data(cdp):
        from itertools import product

        conv_data = json.load(cdp.open())
        days = [str(i) for i in range(1, 32)]
        months = conv_data["months"]
        day_months = {d + " " + m for d, m in product(days, months)}
        return {
            "cities": set(conv_data["cities"]),
            "names": set(conv_data["names"]),
            "dates": day_months,
        }

    dest_data_name = data_name + "_" + extraction_type.value
    data_manifest_path = manifest_dir / Path(data_name) / manifest_file
    conv_data = get_conv_data(conv_data_path)
    extraction_vals = conv_data[extraction_type.value]

    manifest_gen = asr_manifest_reader(data_manifest_path)
    dest_data_dir = manifest_dir / Path(dest_data_name)
    dest_data_dir.mkdir(exist_ok=True, parents=True)
    (dest_data_dir / Path("wav")).mkdir(exist_ok=True, parents=True)
    dest_manifest_path = dest_data_dir / manifest_file
    dest_ui_dir = dump_dir / Path(dest_data_name)
    dest_ui_dir.mkdir(exist_ok=True, parents=True)
    dest_ui_path = dest_ui_dir / dump_file
    dest_correction_path = dest_ui_dir / corrections_file

    def extract_manifest(mg):
        for m in mg:
            if m["text"] in extraction_vals:
                shutil.copy(m["audio_path"], dest_data_dir / Path(m["audio_filepath"]))
                yield m

    asr_manifest_writer(dest_manifest_path, extract_manifest(manifest_gen))

    ui_data_path = dump_dir / Path(data_name) / dump_file
    corrections_path = dump_dir / Path(data_name) / corrections_file
    ui_data = json.load(ui_data_path.open())["data"]
    file_ui_map = {Path(u["audio_filepath"]).stem: u for u in ui_data}
    corrections = json.load(corrections_path.open())

    extracted_ui_data = list(filter(lambda u: u["text"] in extraction_vals, ui_data))
    ExtendedPath(dest_ui_path).write_json(extracted_ui_data)

    extracted_corrections = list(
        filter(
            lambda c: c["code"] in file_ui_map
            and file_ui_map[c["code"]]["text"] in extraction_vals,
            corrections,
        )
    )
    ExtendedPath(dest_correction_path).write_json(extracted_corrections)


@app.command()
def update_corrections(
    data_name: str = typer.Option("call_alphanum", show_default=True),
    dump_dir: Path = Path("./data/validation_data"),
    manifest_dir: Path = Path("./data/asr_data"),
    manifest_file: Path = Path("manifest.json"),
    corrections_file: Path = Path("corrections.json"),
    # data_manifest_path: Path = Path("./data/asr_data/call_alphanum/manifest.json"),
    # corrections_path: Path = Path("./data/validation_data/corrections.json"),
    skip_incorrect: bool = True,
):
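    """Apply mongo-dumped corrections to the dataset manifest in place, after
    backing the dataset up: 'Correct' entries are kept, 'Incorrect' ones are
    renamed and re-tokenized (or skipped when skip_incorrect), and the rest
    have their audio deleted."""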
    data_manifest_path = manifest_dir / Path(data_name) / manifest_file
    corrections_path = dump_dir / Path(data_name) / corrections_file

    def correct_manifest(manifest_data_gen, corrections_path):
        corrections = json.load(corrections_path.open())
        correct_set = {
            c["code"] for c in corrections if c["value"]["status"] == "Correct"
        }
        # incorrect_set = {c["code"] for c in corrections if c["value"]["status"] == "Inaudible"}
        correction_map = {
            c["code"]: c["value"]["correction"]
            for c in corrections
            if c["value"]["status"] == "Incorrect"
        }
        # for d in manifest_data_gen:
        #     if d["chars"] in incorrect_set:
        #         d["audio_path"].unlink()
        renamed_set = set()
        for d in manifest_data_gen:
            if d["chars"] in correct_set:
                yield {
                    "audio_filepath": d["audio_filepath"],
                    "duration": d["duration"],
                    "text": d["text"],
                }
            elif d["chars"] in correction_map:
                correct_text = correction_map[d["chars"]]
                if skip_incorrect:
                    print(
                        f'skipping incorrect {d["audio_path"]} corrected to {correct_text}'
                    )
                else:
                    renamed_set.add(correct_text)
                    new_name = str(Path(correct_text).with_suffix(".wav"))
                    d["audio_path"].replace(d["audio_path"].with_name(new_name))
                    new_filepath = str(Path(d["audio_filepath"]).with_name(new_name))
                    yield {
                        "audio_filepath": new_filepath,
                        "duration": d["duration"],
                        "text": alnum_to_asr_tokens(correct_text),
                    }
            else:
                # don't delete if another correction points to an old file
                if d["chars"] not in renamed_set:
                    d["audio_path"].unlink()
                else:
                    print(f'skipping deletion of correction:{d["chars"]}')

    typer.echo(f"Using data manifest:{data_manifest_path}")
    dataset_dir = data_manifest_path.parent
    dataset_name = dataset_dir.name
    backup_dir = dataset_dir.with_name(dataset_name + ".bkp")
    if not backup_dir.exists():
        typer.echo(f"backing up to :{backup_dir}")
        shutil.copytree(str(dataset_dir), str(backup_dir))
    manifest_gen = asr_manifest_reader(data_manifest_path)
    corrected_manifest = correct_manifest(manifest_gen, corrections_path)
    new_data_manifest_path = data_manifest_path.with_name("manifest.new")
    asr_manifest_writer(new_data_manifest_path, corrected_manifest)
    new_data_manifest_path.replace(data_manifest_path)


@app.command()
def clear_mongo_corrections():
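    """Interactively delete all correction records from mongo."""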
    delete = typer.confirm("are you sure you want to clear the mongo collection?")
    if delete:
        col = get_mongo_conn(col="asr_validation")
        col.delete_many({"type": "correction"})
        typer.echo("deleted mongo collection.")
    else:
        typer.echo("Aborted")


def main():
    app()


if __name__ == "__main__":
    main()