Mirror of https://github.com/malarinv/jasper-asr.git (synced 2026-03-09 19:02:35 +00:00)

add validation ui and post processing to correct using validation data

This commit is contained in:
2020-05-06 12:18:34 +05:30
parent aae03a6ae4
commit a7da729c0b
4 changed files with 237 additions and 35 deletions

View File

@@ -0,0 +1,176 @@
import pymongo
import typer
# import matplotlib.pyplot as plt
from pathlib import Path
import json
import shutil
# import pandas as pd
from pydub import AudioSegment
# from .jasper_client import transcriber_pretrained, transcriber_speller
from jasper.data_utils.validation.jasper_client import (
transcriber_pretrained,
transcriber_speller,
)
from jasper.data_utils.utils import alnum_to_asr_tokens
# import importlib
# import jasper.data_utils.utils
# importlib.reload(jasper.data_utils.utils)
from jasper.data_utils.utils import asr_manifest_reader, asr_manifest_writer
from nemo.collections.asr.metrics import word_error_rate
# from tqdm import tqdm as tqdm_base
from tqdm import tqdm
app = typer.Typer()
@app.command()
def dump_corrections(dump_path: Path = Path("./data/corrections.json")):
    """Export all "correction" documents from MongoDB into a JSON file.

    Args:
        dump_path: Destination JSON file for the exported corrections.
    """
    col = pymongo.MongoClient("mongodb://localhost:27017/").test.asr_validation
    # Drop the Mongo-internal _id so the dump is plain, portable JSON.
    cursor_obj = col.find({"type": "correction"}, projection={"_id": False})
    corrections = list(cursor_obj)
    # Context manager guarantees the handle is closed even if json.dump raises;
    # the original opened/closed the file manually.
    with dump_path.open("w") as dump_f:
        json.dump(corrections, dump_f, indent=2)
def preprocess_datapoint(idx, rel, sample):
    """Enrich one manifest entry with both ASR transcriptions and its WER.

    Args:
        idx: Original position of the entry in the manifest (kept as "real_idx").
        rel: Directory the manifest's relative audio paths are resolved against.
        sample: One decoded manifest record (expects "audio_filepath" and "text").

    Returns:
        A copy of ``sample`` extended with audio path, gold labels, the
        pretrained and speller transcriptions, and the speller WER.
    """
    record = dict(sample)
    record["real_idx"] = idx
    wav_path = rel / Path(sample["audio_filepath"])
    record["audio_path"] = str(wav_path)
    # File stem is the gold character sequence; "text" is its spelled-out form.
    record["gold_chars"] = wav_path.stem
    record["gold_phone"] = sample["text"]
    # Normalize to mono / 16-bit / 24 kHz before handing raw PCM to the ASR.
    segment = AudioSegment.from_wav(wav_path)
    segment = segment.set_channels(1).set_sample_width(2).set_frame_rate(24000)
    record["pretrained_asr"] = transcriber_pretrained(segment.raw_data)
    record["speller_asr"] = transcriber_speller(segment.raw_data)
    record["wer"] = word_error_rate([record["gold_phone"]], [record["speller_asr"]])
    return record
def load_dataset(data_manifest_path: Path):
    """Read an ASR manifest, enrich every entry, and sort by descending WER.

    Args:
        data_manifest_path: Path to a JSON-lines ASR manifest file.

    Returns:
        List of enriched records (see preprocess_datapoint), worst WER first.
    """
    typer.echo(f"Using data manifest:{data_manifest_path}")
    with data_manifest_path.open("r") as pf:
        manifest_lines = pf.readlines()
    enriched = []
    for position, raw_line in enumerate(tqdm(manifest_lines, position=0, leave=True)):
        entry = preprocess_datapoint(position, data_manifest_path.parent, json.loads(raw_line))
        enriched.append(entry)
    # Surface the worst transcriptions first for review.
    enriched.sort(key=lambda record: record["wer"], reverse=True)
    return enriched
@app.command()
def dump_processed_data(
    data_manifest_path: Path = Path("./data/asr_data/call_alphanum/manifest.json"),
    dump_path: Path = Path("./data/processed_data.json"),
):
    """Enrich the manifest entries with ASR results and dump them as JSON.

    Args:
        data_manifest_path: JSON-lines ASR manifest to process.
        dump_path: Destination JSON file for the enriched, WER-sorted records.
    """
    # Bug fix: the original rebound dump_path to a hard-coded
    # Path("./data/processed_data.json") here, silently ignoring the
    # caller-supplied argument. It also duplicated load_dataset's
    # read/enrich/sort logic inline; reuse the helper instead.
    result = load_dataset(data_manifest_path)
    # Context manager closes the file even if json.dump raises.
    with dump_path.open("w") as dump_f:
        json.dump(result, dump_f, indent=2)
@app.command()
def fill_unannotated(
    processed_data_path: Path = Path("./data/processed_data.json"),
    corrections_path: Path = Path("./data/corrections.json"),
):
    """Upsert an 'Inaudible' correction for every code that was never annotated.

    Args:
        processed_data_path: JSON dump produced by dump_processed_data.
        corrections_path: JSON dump produced by dump_corrections.
    """
    # read_text + json.loads avoids leaking the file handles the original
    # left open via json.load(path.open()).
    processed_data = json.loads(processed_data_path.read_text())
    corrections = json.loads(corrections_path.read_text())
    annotated_codes = {c["code"] for c in corrections}
    all_codes = {c["gold_chars"] for c in processed_data}
    mongo_conn = pymongo.MongoClient("mongodb://localhost:27017/").test.asr_validation
    # Default every unreviewed code to Inaudible with an empty correction.
    for code in all_codes - annotated_codes:
        mongo_conn.find_one_and_update(
            {"type": "correction", "code": code},
            {"$set": {"value": {"status": "Inaudible", "correction": ""}}},
            upsert=True,
        )
@app.command()
def update_corrections(
    data_manifest_path: Path = Path("./data/asr_data/call_alphanum/manifest.json"),
    processed_data_path: Path = Path("./data/processed_data.json"),
    corrections_path: Path = Path("./data/corrections.json"),
):
    """Apply human validation results to the on-disk ASR dataset.

    Rewrites the manifest (and renames/deletes wav files) according to the
    corrections dump: "Correct" entries are kept, "Incorrect" entries are
    renamed to their corrected code, everything else is deleted. A one-time
    backup of the dataset directory is taken before any mutation.

    NOTE(review): processed_data_path is accepted but never read here —
    presumably kept for CLI symmetry with the other commands; confirm.
    """
    def correct_manifest(manifest_data_gen, corrections_path):
        # Generator: yields corrected manifest records while mutating the
        # wav files on disk (rename for corrections, unlink for rejects).
        corrections = json.load(corrections_path.open())
        # Codes the annotator confirmed as correct — pass through unchanged.
        correct_set = {
            c["code"] for c in corrections if c["value"]["status"] == "Correct"
        }
        # incorrect_set = {c["code"] for c in corrections if c["value"]["status"] == "Inaudible"}
        # Map from wrong code -> annotator-supplied replacement text.
        correction_map = {
            c["code"]: c["value"]["correction"]
            for c in corrections
            if c["value"]["status"] == "Incorrect"
        }
        # for d in manifest_data_gen:
        #     if d["chars"] in incorrect_set:
        #         d["audio_path"].unlink()
        # Targets already claimed by a rename; protects them from deletion below.
        renamed_set = set()
        for d in manifest_data_gen:
            if d["chars"] in correct_set:
                yield {
                    "audio_filepath": d["audio_filepath"],
                    "duration": d["duration"],
                    "text": d["text"],
                }
            elif d["chars"] in correction_map:
                correct_text = correction_map[d["chars"]]
                renamed_set.add(correct_text)
                # Rename the wav on disk to match the corrected code, then
                # emit a manifest entry pointing at the new name.
                new_name = str(Path(correct_text).with_suffix(".wav"))
                d["audio_path"].replace(d["audio_path"].with_name(new_name))
                new_filepath = str(Path(d["audio_filepath"]).with_name(new_name))
                yield {
                    "audio_filepath": new_filepath,
                    "duration": d["duration"],
                    "text": alnum_to_asr_tokens(correct_text),
                }
            else:
                # don't delete if another correction points to an old file
                if d["chars"] not in renamed_set:
                    d["audio_path"].unlink()
                else:
                    print(f'skipping deletion of correction:{d["chars"]}')

    typer.echo(f"Using data manifest:{data_manifest_path}")
    dataset_dir = data_manifest_path.parent
    dataset_name = dataset_dir.name
    # One-time safety copy of the whole dataset before destructive edits.
    backup_dir = dataset_dir.with_name(dataset_name + ".bkp")
    if not backup_dir.exists():
        typer.echo(f"backing up to :{backup_dir}")
        shutil.copytree(str(dataset_dir), str(backup_dir))
    manifest_gen = asr_manifest_reader(data_manifest_path)
    corrected_manifest = correct_manifest(manifest_gen, corrections_path)
    # Write to manifest.new first, then atomically replace the original.
    new_data_manifest_path = data_manifest_path.with_name("manifest.new")
    asr_manifest_writer(new_data_manifest_path, corrected_manifest)
    new_data_manifest_path.replace(data_manifest_path)
def main():
    """CLI entry point: dispatch to the Typer application."""
    app()
if __name__ == "__main__":
main()

View File

@@ -14,7 +14,6 @@ import typer
from .jasper_client import transcriber_pretrained, transcriber_speller
from .st_rerun import rerun
app = typer.Typer()
st.title("ASR Speller Validation")
@@ -53,7 +52,6 @@ if not hasattr(st, "mongo_connected"):
{"$set": {"value": value}},
upsert=True,
)
rerun()
cursor_obj = mongo_conn.find_one({"type": "current_cursor"})
if not cursor_obj:
@@ -76,7 +74,6 @@ def preprocess_datapoint(idx, rel, sample):
audio_path = rel / Path(sample["audio_filepath"])
res["audio_path"] = audio_path
res["gold_chars"] = audio_path.stem
res["gold_phone"] = sample["text"]
aud_seg = (
AudioSegment.from_wav(audio_path)
.set_channels(1)
@@ -85,7 +82,7 @@ def preprocess_datapoint(idx, rel, sample):
)
res["pretrained_asr"] = transcriber_pretrained(aud_seg.raw_data)
res["speller_asr"] = transcriber_speller(aud_seg.raw_data)
res["wer"] = word_error_rate([res["gold_phone"]], [res["speller_asr"]])
res["wer"] = word_error_rate([res["text"]], [res["speller_asr"]])
(y, sr) = librosa.load(audio_path)
plt.tight_layout()
librosa.display.waveplot(y=y, sr=sr)
@@ -116,7 +113,7 @@ def main(manifest: Path):
sample_no = st.get_current_cursor()
sample = pnr_data[sample_no]
st.markdown(
f"{sample_no+1} of {len(pnr_data)} : **{sample['gold_chars']}** spelled *{sample['gold_phone']}*"
f"{sample_no+1} of {len(pnr_data)} : **{sample['gold_chars']}** spelled *{sample['text']}*"
)
new_sample = st.number_input(
"Go To Sample:", value=sample_no + 1, min_value=1, max_value=len(pnr_data)
@@ -125,7 +122,7 @@ def main(manifest: Path):
st.update_cursor(new_sample - 1)
st.sidebar.title(f"Details: [{sample['real_idx']}]")
st.sidebar.markdown(f"Gold: **{sample['gold_chars']}**")
st.sidebar.markdown(f"Expected Speech: *{sample['gold_phone']}*")
st.sidebar.markdown(f"Expected Speech: *{sample['text']}*")
st.sidebar.title("Results:")
st.sidebar.text(f"Pretrained:{sample['pretrained_asr']}")
st.sidebar.text(f"Speller:{sample['speller_asr']}")
@@ -158,6 +155,7 @@ def main(manifest: Path):
st.markdown(
f'Your Response: **{correction_entry["value"]["status"]}** Correction: **{correction_entry["value"]["correction"]}**'
)
# real_idx = st.text_input("Go to real-index:", value=sample['real_idx'])
# st.markdown(
# ",".join(
# [