Add validation UI and post-processing to correct the dataset using validation data

parent aae03a6ae4
commit a7da729c0b
@@ -1,7 +1,7 @@
 import json
 from pathlib import Path
 from sklearn.model_selection import train_test_split
-from .utils import alnum_to_asr_tokens
+from .utils import alnum_to_asr_tokens, asr_manifest_reader, asr_manifest_writer
 import typer

 app = typer.Typer()
@@ -30,35 +30,38 @@ def separate_space_convert_digit_setpath():


 @app.command()
-def split_data(manifest_path: Path = Path("/dataset/asr_data/pnr_data/pnr_data.json")):
-    with manifest_path.open("r") as pf:
-        pnr_jsonl = pf.readlines()
-    train_pnr, test_pnr = train_test_split(pnr_jsonl, test_size=0.1)
-    with (manifest_path.parent / Path("train_manifest.json")).open("w") as pf:
-        pnr_data = "".join(train_pnr)
-        pf.write(pnr_data)
-    with (manifest_path.parent / Path("test_manifest.json")).open("w") as pf:
-        pnr_data = "".join(test_pnr)
-        pf.write(pnr_data)
+def split_data(dataset_path: Path, test_size: float = 0.1):
+    manifest_path = dataset_path / Path("abs_manifest.json")
+    asr_data = list(asr_manifest_reader(manifest_path))
+    train_pnr, test_pnr = train_test_split(asr_data, test_size=test_size)
+    asr_manifest_writer(manifest_path.with_name("train_manifest.json"), train_pnr)
+    asr_manifest_writer(manifest_path.with_name("test_manifest.json"), test_pnr)


 @app.command()
-def fix_path(
-    dataset_path: Path = Path("/dataset/asr_data/call_alphanum"),
-):
-    manifest_path = dataset_path / Path('manifest.json')
-    with manifest_path.open("r") as pf:
-        pnr_jsonl = pf.readlines()
-    pnr_data = [json.loads(i) for i in pnr_jsonl]
-    new_pnr_data = []
-    for i in pnr_data:
-        i["audio_filepath"] = str(dataset_path / Path(i["audio_filepath"]))
-        new_pnr_data.append(i)
-    new_pnr_jsonl = [json.dumps(i) for i in new_pnr_data]
-    real_manifest_path = dataset_path / Path('real_manifest.json')
-    with real_manifest_path.open("w") as pf:
-        new_pnr_data = "\n".join(new_pnr_jsonl)  # + "\n"
-        pf.write(new_pnr_data)
+def fixate_data(dataset_path: Path):
+    manifest_path = dataset_path / Path("manifest.json")
+    real_manifest_path = dataset_path / Path("abs_manifest.json")
+
+    def fix_path():
+        for i in asr_manifest_reader(manifest_path):
+            i["audio_filepath"] = str(dataset_path / Path(i["audio_filepath"]))
+            yield i
+
+    asr_manifest_writer(real_manifest_path, fix_path())
+
+    # with manifest_path.open("r") as pf:
+    #     pnr_jsonl = pf.readlines()
+    # pnr_data = [json.loads(i) for i in pnr_jsonl]
+    # new_pnr_data = []
+    # for i in pnr_data:
+    #     i["audio_filepath"] = str(dataset_path / Path(i["audio_filepath"]))
+    #     new_pnr_data.append(i)
+    # new_pnr_jsonl = [json.dumps(i) for i in new_pnr_data]
+    # real_manifest_path = dataset_path / Path("abs_manifest.json")
+    # with real_manifest_path.open("w") as pf:
+    #     new_pnr_data = "\n".join(new_pnr_jsonl)  # + "\n"
+    #     pf.write(new_pnr_data)


 @app.command()
@@ -78,14 +81,18 @@ def augment_an4():


 @app.command()
-def validate_data(data_file: Path = Path("/dataset/asr_data/call_alphanum/train_manifest.json")):
+def validate_data(data_file: Path):
     with Path(data_file).open("r") as pf:
         pnr_jsonl = pf.readlines()
     for (i, s) in enumerate(pnr_jsonl):
         try:
-            json.loads(s)
+            d = json.loads(s)
+            audio_file = data_file.parent / Path(d["audio_filepath"])
+            if not audio_file.exists():
+                raise OSError(f"File {audio_file} not found")
         except BaseException as e:
-            print(f"failed on {i}")
+            print(f'failed on {i} with "{e}"')
+    print("no errors found. seems like a valid manifest.")


 def main():
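Usage note: a minimal sketch of the intended order of these commands, assuming split_data, fixate_data and validate_data are imported from this CLI module; the dataset path is illustrative.

from pathlib import Path

dataset = Path("/dataset/asr_data/call_alphanum")   # illustrative dataset root
fixate_data(dataset)                                # manifest.json -> abs_manifest.json (audio paths made absolute)
split_data(dataset, test_size=0.1)                  # abs_manifest.json -> train_manifest.json / test_manifest.json
validate_data(dataset / "train_manifest.json")      # check each line parses and its audio file exists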
@@ -61,6 +61,27 @@ def asr_data_writer(output_dir, dataset_name, asr_data_source):
             mf.write(manifest)


+def asr_manifest_reader(data_manifest_path: Path):
+    print(f'reading manifest from {data_manifest_path}')
+    with data_manifest_path.open("r") as pf:
+        pnr_jsonl = pf.readlines()
+    pnr_data = [json.loads(v) for v in pnr_jsonl]
+    for p in pnr_data:
+        p['audio_path'] = data_manifest_path.parent / Path(p['audio_filepath'])
+        p['chars'] = Path(p['audio_filepath']).stem
+        yield p
+
+
+def asr_manifest_writer(asr_manifest_path: Path, manifest_str_source):
+    with asr_manifest_path.open("w") as mf:
+        print(f'opening {asr_manifest_path} for writing manifest')
+        for mani_dict in manifest_str_source:
+            manifest = manifest_str(
+                mani_dict['audio_filepath'], mani_dict['duration'], mani_dict['text']
+            )
+            mf.write(manifest)
+
+
 def main():
     for c in random_pnr_generator():
         print(c)
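Usage note: these helpers assume a NeMo-style JSONL manifest, one JSON object per line with audio_filepath, duration and text, presumably serialized by manifest_str() defined earlier in utils. A minimal round-trip sketch, with an illustrative transformation and path:

from pathlib import Path
from jasper.data_utils.utils import asr_manifest_reader, asr_manifest_writer

def lowercase_text(manifest_path: Path):
    # the reader yields each record plus derived "audio_path" and "chars" keys
    for rec in asr_manifest_reader(manifest_path):
        rec["text"] = rec["text"].lower()
        yield rec

src = Path("/dataset/asr_data/call_alphanum/abs_manifest.json")  # illustrative path
asr_manifest_writer(src.with_name("lower_manifest.json"), lowercase_text(src))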
@@ -0,0 +1,176 @@
+import pymongo
+import typer
+
+# import matplotlib.pyplot as plt
+from pathlib import Path
+import json
+import shutil
+
+# import pandas as pd
+from pydub import AudioSegment
+
+# from .jasper_client import transcriber_pretrained, transcriber_speller
+from jasper.data_utils.validation.jasper_client import (
+    transcriber_pretrained,
+    transcriber_speller,
+)
+from jasper.data_utils.utils import alnum_to_asr_tokens
+
+# import importlib
+# import jasper.data_utils.utils
+# importlib.reload(jasper.data_utils.utils)
+from jasper.data_utils.utils import asr_manifest_reader, asr_manifest_writer
+from nemo.collections.asr.metrics import word_error_rate
+
+# from tqdm import tqdm as tqdm_base
+from tqdm import tqdm
+
+app = typer.Typer()
+
+
+@app.command()
+def dump_corrections(dump_path: Path = Path("./data/corrections.json")):
+    col = pymongo.MongoClient("mongodb://localhost:27017/").test.asr_validation
+
+    cursor_obj = col.find({"type": "correction"}, projection={"_id": False})
+    corrections = [c for c in cursor_obj]
+    dump_f = dump_path.open("w")
+    json.dump(corrections, dump_f, indent=2)
+    dump_f.close()
+
+
+def preprocess_datapoint(idx, rel, sample):
+    res = dict(sample)
+    res["real_idx"] = idx
+    audio_path = rel / Path(sample["audio_filepath"])
+    res["audio_path"] = str(audio_path)
+    res["gold_chars"] = audio_path.stem
+    res["gold_phone"] = sample["text"]
+    aud_seg = (
+        AudioSegment.from_wav(audio_path)
+        .set_channels(1)
+        .set_sample_width(2)
+        .set_frame_rate(24000)
+    )
+    res["pretrained_asr"] = transcriber_pretrained(aud_seg.raw_data)
+    res["speller_asr"] = transcriber_speller(aud_seg.raw_data)
+    res["wer"] = word_error_rate([res["gold_phone"]], [res["speller_asr"]])
+    return res
+
+
+def load_dataset(data_manifest_path: Path):
+    typer.echo(f"Using data manifest:{data_manifest_path}")
+    with data_manifest_path.open("r") as pf:
+        pnr_jsonl = pf.readlines()
+    pnr_data = [
+        preprocess_datapoint(i, data_manifest_path.parent, json.loads(v))
+        for i, v in enumerate(tqdm(pnr_jsonl, position=0, leave=True))
+    ]
+    result = sorted(pnr_data, key=lambda x: x["wer"], reverse=True)
+    return result
+
+
+@app.command()
+def dump_processed_data(
+    data_manifest_path: Path = Path("./data/asr_data/call_alphanum/manifest.json"),
+    dump_path: Path = Path("./data/processed_data.json"),
+):
+    typer.echo(f"Using data manifest:{data_manifest_path}")
+    with data_manifest_path.open("r") as pf:
+        pnr_jsonl = pf.readlines()
+    pnr_data = [
+        preprocess_datapoint(i, data_manifest_path.parent, json.loads(v))
+        for i, v in enumerate(tqdm(pnr_jsonl, position=0, leave=True))
+    ]
+    result = sorted(pnr_data, key=lambda x: x["wer"], reverse=True)
+    dump_path = Path("./data/processed_data.json")
+    dump_f = dump_path.open("w")
+    json.dump(result, dump_f, indent=2)
+    dump_f.close()
+
+
+@app.command()
+def fill_unannotated(
+    processed_data_path: Path = Path("./data/processed_data.json"),
+    corrections_path: Path = Path("./data/corrections.json"),
+):
+    processed_data = json.load(processed_data_path.open())
+    corrections = json.load(corrections_path.open())
+    annotated_codes = {c["code"] for c in corrections}
+    all_codes = {c["gold_chars"] for c in processed_data}
+    unann_codes = all_codes - annotated_codes
+    mongo_conn = pymongo.MongoClient("mongodb://localhost:27017/").test.asr_validation
+    for c in unann_codes:
+        mongo_conn.find_one_and_update(
+            {"type": "correction", "code": c},
+            {"$set": {"value": {"status": "Inaudible", "correction": ""}}},
+            upsert=True,
+        )
+
+
+@app.command()
+def update_corrections(
+    data_manifest_path: Path = Path("./data/asr_data/call_alphanum/manifest.json"),
+    processed_data_path: Path = Path("./data/processed_data.json"),
+    corrections_path: Path = Path("./data/corrections.json"),
+):
+    def correct_manifest(manifest_data_gen, corrections_path):
+        corrections = json.load(corrections_path.open())
+        correct_set = {
+            c["code"] for c in corrections if c["value"]["status"] == "Correct"
+        }
+        # incorrect_set = {c["code"] for c in corrections if c["value"]["status"] == "Inaudible"}
+        correction_map = {
+            c["code"]: c["value"]["correction"]
+            for c in corrections
+            if c["value"]["status"] == "Incorrect"
+        }
+        # for d in manifest_data_gen:
+        #     if d["chars"] in incorrect_set:
+        #         d["audio_path"].unlink()
+        renamed_set = set()
+        for d in manifest_data_gen:
+            if d["chars"] in correct_set:
+                yield {
+                    "audio_filepath": d["audio_filepath"],
+                    "duration": d["duration"],
+                    "text": d["text"],
+                }
+            elif d["chars"] in correction_map:
+                correct_text = correction_map[d["chars"]]
+                renamed_set.add(correct_text)
+                new_name = str(Path(correct_text).with_suffix(".wav"))
+                d["audio_path"].replace(d["audio_path"].with_name(new_name))
+                new_filepath = str(Path(d["audio_filepath"]).with_name(new_name))
+                yield {
+                    "audio_filepath": new_filepath,
+                    "duration": d["duration"],
+                    "text": alnum_to_asr_tokens(correct_text),
+                }
+            else:
+                # don't delete if another correction points to an old file
+                if d["chars"] not in renamed_set:
+                    d["audio_path"].unlink()
+                else:
+                    print(f'skipping deletion of correction:{d["chars"]}')
+
+    typer.echo(f"Using data manifest:{data_manifest_path}")
+    dataset_dir = data_manifest_path.parent
+    dataset_name = dataset_dir.name
+    backup_dir = dataset_dir.with_name(dataset_name + ".bkp")
+    if not backup_dir.exists():
+        typer.echo(f"backing up to :{backup_dir}")
+        shutil.copytree(str(dataset_dir), str(backup_dir))
+    manifest_gen = asr_manifest_reader(data_manifest_path)
+    corrected_manifest = correct_manifest(manifest_gen, corrections_path)
+    new_data_manifest_path = data_manifest_path.with_name("manifest.new")
+    asr_manifest_writer(new_data_manifest_path, corrected_manifest)
+    new_data_manifest_path.replace(data_manifest_path)
+
+
+def main():
+    app()
+
+
+if __name__ == "__main__":
+    main()
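Note: the correction documents this script expects in the test.asr_validation collection, as implied by the queries above; the field values below are illustrative, not real data.

correction = {
    "type": "correction",
    "code": "7ABC12",            # gold character sequence, i.e. the audio file stem
    "value": {
        "status": "Incorrect",   # one of "Correct", "Incorrect", "Inaudible"
        "correction": "7ABC21",  # replacement code, used when status == "Incorrect"
    },
}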
@@ -14,7 +14,6 @@ import typer
 from .jasper_client import transcriber_pretrained, transcriber_speller
 from .st_rerun import rerun
-

 app = typer.Typer()
 st.title("ASR Speller Validation")

@@ -53,7 +52,6 @@ if not hasattr(st, "mongo_connected"):
             {"$set": {"value": value}},
             upsert=True,
         )
-        rerun()

     cursor_obj = mongo_conn.find_one({"type": "current_cursor"})
     if not cursor_obj:
@@ -76,7 +74,6 @@ def preprocess_datapoint(idx, rel, sample):
     audio_path = rel / Path(sample["audio_filepath"])
     res["audio_path"] = audio_path
     res["gold_chars"] = audio_path.stem
-    res["gold_phone"] = sample["text"]
     aud_seg = (
         AudioSegment.from_wav(audio_path)
         .set_channels(1)
@@ -85,7 +82,7 @@ def preprocess_datapoint(idx, rel, sample):
     )
     res["pretrained_asr"] = transcriber_pretrained(aud_seg.raw_data)
     res["speller_asr"] = transcriber_speller(aud_seg.raw_data)
-    res["wer"] = word_error_rate([res["gold_phone"]], [res["speller_asr"]])
+    res["wer"] = word_error_rate([res["text"]], [res["speller_asr"]])
     (y, sr) = librosa.load(audio_path)
     plt.tight_layout()
     librosa.display.waveplot(y=y, sr=sr)
@@ -116,7 +113,7 @@ def main(manifest: Path):
     sample_no = st.get_current_cursor()
     sample = pnr_data[sample_no]
     st.markdown(
-        f"{sample_no+1} of {len(pnr_data)} : **{sample['gold_chars']}** spelled *{sample['gold_phone']}*"
+        f"{sample_no+1} of {len(pnr_data)} : **{sample['gold_chars']}** spelled *{sample['text']}*"
     )
     new_sample = st.number_input(
         "Go To Sample:", value=sample_no + 1, min_value=1, max_value=len(pnr_data)
@@ -125,7 +122,7 @@ def main(manifest: Path):
         st.update_cursor(new_sample - 1)
     st.sidebar.title(f"Details: [{sample['real_idx']}]")
     st.sidebar.markdown(f"Gold: **{sample['gold_chars']}**")
-    st.sidebar.markdown(f"Expected Speech: *{sample['gold_phone']}*")
+    st.sidebar.markdown(f"Expected Speech: *{sample['text']}*")
     st.sidebar.title("Results:")
     st.sidebar.text(f"Pretrained:{sample['pretrained_asr']}")
     st.sidebar.text(f"Speller:{sample['speller_asr']}")
@@ -158,6 +155,7 @@ def main(manifest: Path):
     st.markdown(
         f'Your Response: **{correction_entry["value"]["status"]}** Correction: **{correction_entry["value"]["correction"]}**'
     )
+    # real_idx = st.text_input("Go to real-index:", value=sample['real_idx'])
     # st.markdown(
     #     ",".join(
     #         [
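Usage note: a sketch of the full correction loop these changes enable, calling the typer command functions of the new post-processing module directly (assuming they are importable); default paths come from the command signatures above, and the annotation step itself happens in the Streamlit validation UI.

from pathlib import Path

dump_processed_data(
    data_manifest_path=Path("./data/asr_data/call_alphanum/manifest.json"),
    dump_path=Path("./data/processed_data.json"),
)
# ... review samples in the validation UI; it upserts {"type": "correction", ...}
# documents into mongo as annotators respond ...
dump_corrections()     # export mongo corrections to ./data/corrections.json
fill_unannotated()     # upsert "Inaudible" entries for samples never reviewed
dump_corrections()     # re-export so the file includes the filled-in entries
update_corrections()   # back up the dataset, rename/remove audio, rewrite the manifest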