jasper-asr/jasper/data/validation/process.py

419 lines
16 KiB
Python
Raw Normal View History

1. integrated data generator using google tts 2. added training script fix module packaging issue implement call audio data recycler for asr 1. added streamlit based validation ui with mongodb datastore integration 2. fix asr wrong sample rate inference 3. update requirements 1. refactored streamlit code 2. fixed issues in data manifest handling refresh to next entry on submit and comment out mongo clearing code for safety :P add validation ui and post processing to correct using validation data 1. added a tool to extract asr data from gcp transcripts logs 2. implement a funciton to export all call logs in a mongodb to a caller-id based yaml file 3. clean-up leaderboard duration logic 4. added a wip dataloader service 5. made the asr_data_writer util more generic with verbose flags and unique filename 6. added extendedpath util class with json support and mongo_conn function to connect to a mongo node 7. refactored the validation post processing to dump a ui config for validation 8. included utility functions to correct, fill update and clear annotations from mongodb data 9. refactored the ui logic to be more generic for any asr data 10. updated setup.py dependencies to support the above features unlink temporary files after transcribing 1. clean-up unused data process code 2. fix invalid sample no from mongo 3. data loader service return remote netref 1. added training utils with custom data loaders with remote rpyc dataservice support 2. fix validation correction dump path 3. cache dataset for precaching before training to memory 4. update dependencies 1. implement dataset augmentation and validation in process 2. added option to skip 'incorrect' annotations in validation data 3. added confirmation on clearing mongo collection 4. added an option to navigate to a given text in the validation ui 5. added a dataset and remote option to trainer to load dataset from directory and remote rpyc service 1. added utility command to export call logs 2. 
mongo conn accepts port refactored module structure 1. enabled silece stripping in chunks when recycling audio from asr logs 2. limit asr recycling to 1 min of start audio to get reliable alignments and ignoring agent channel 3. added rev recycler for generating asr dataset from rev transcripts and audio 4. update pydub dependency for silence stripping fn and removing threadpool hardcoded worker count 1. added support for mono/dual channel rev transcripts 2. handle errors when extracting datapoints from rev meta data 3. added suport for annotation only task when dumping ui data cleanup rev recycle added option to disable plots during validation fix skipping null audio and add more verbose logs respect verbose flag don't load audio for annotation only ui and keep spoken as text for normal asr validation 1. refactored wav chunk processing method 2. renamed streamlit to validation_ui show duration on validation of dataset parallelize data loading from remote skipping invalid data points 1. removed the transcriber_pretrained/speller from utils 2. introduced get_mongo_coll to get the collection object directly from mongo uri 3. removed processing of correction entries to remove space/upper casing refactor validation process arguments and logging 1. added a data extraction type argument 2. cleanup/refactor 1. using dataname args for update/fill annotations 2. rename to dump_ui added support for name/dates/cities call data extraction and more logs handling non-pnr cases without parens in text data 1. added conv data generator 2. more utils 1. added start delay arg in call recycler 2. implement ui_dump/manifest writer in call_recycler itself 3. refactored call data point plotter 4. added sample-ui task-ui on the validation process 5. implemented call-quality stats using corrections from mongo 6. support deleting cursors on mongo 7. implement multiple task support on validation ui based on task_id mongo field fix 11st to 11th in ordinal stripping silence on call chunk 1. 
added option to strip silent chunks 2. computing caller quality based on task-id of corrections 1. fix update-correction to use ui_dump instead of manifest 2. update training params no of checkpoints on chpk frequency 1. split extract all data types in one shot with --extraction-type all flag 2. add notes about diffing split extracted and original data 3. add a nlu conv generator to generate conv data based on nlu utterances and entities 4. add task uid support for dumping corrections 5. abstracted generate date fn 1. added a test generator and slu evaluator 2. ui dump now include gcp results 3. showing default option for more args validation process commands added evaluation command clean-up
2020-04-08 11:56:27 +00:00
import json
import shutil
from pathlib import Path
from enum import Enum
import typer
from tqdm import tqdm
from ..utils import (
alnum_to_asr_tokens,
ExtendedPath,
asr_manifest_reader,
asr_manifest_writer,
tscript_uuid_fname,
get_mongo_conn,
plot_seg,
)
# Typer application object; the functions decorated with ``@app.command()``
# in this module register themselves on it as CLI sub-commands.
app = typer.Typer()
def preprocess_datapoint(
    idx, rel_root, sample, use_domain_asr, annotation_only, enable_plots
):
    """Build the validation-UI record for a single manifest entry.

    Args:
        idx: Position of the entry in the manifest; stored as ``real_idx``.
        rel_root: Directory that ``sample["audio_filepath"]`` is relative to.
        sample: Mapping for one manifest line; ``audio_filepath`` and ``text``
            are read from it. The mapping itself is not modified.
        use_domain_asr: When true, ``spoken`` is derived from ``text`` via
            ``alnum_to_asr_tokens`` and the second ASR service (port 8045)
            is queried in addition to the first one (port 8044).
        annotation_only: When true, no audio is loaded and no ASR service is
            contacted, so no ``*_asr`` / ``*_wer`` keys are produced.
        enable_plots: When true, a waveform plot is created under
            ``rel_root / "wav_plots"`` (if missing) and its path is stored
            as ``plot_path``.

    Returns:
        A new dict with the sample's fields plus the derived ones, or
        ``None`` when processing the entry failed (the failure is printed).
    """
    if not annotation_only:
        # These heavy dependencies are only needed for transcription, so an
        # annotation-only run must not require them to be installed.
        from pydub import AudioSegment
        from nemo.collections.asr.metrics import word_error_rate
        from jasper.client import transcribe_gen

    # Kept outside the try so the error message never has to index `sample`
    # again (the original handler raised KeyError when the key was missing).
    audio_filepath = None
    try:
        res = dict(sample)
        audio_filepath = sample["audio_filepath"]
        res["real_idx"] = idx
        audio_path = rel_root / Path(audio_filepath)
        res["audio_path"] = str(audio_path)
        if use_domain_asr:
            res["spoken"] = alnum_to_asr_tokens(res["text"])
        else:
            res["spoken"] = res["text"]
        res["utterance_id"] = audio_path.stem

        if not annotation_only:
            transcriber_pretrained = transcribe_gen(asr_port=8044)
            # Audio is converted to mono, 2-byte samples, 24 kHz before its
            # raw bytes are handed to the transcriber callables.
            aud_seg = (
                AudioSegment.from_file_using_temporary_files(audio_path)
                .set_channels(1)
                .set_sample_width(2)
                .set_frame_rate(24000)
            )
            res["pretrained_asr"] = transcriber_pretrained(aud_seg.raw_data)
            res["pretrained_wer"] = word_error_rate(
                [res["text"]], [res["pretrained_asr"]]
            )
            if use_domain_asr:
                transcriber_speller = transcribe_gen(asr_port=8045)
                res["domain_asr"] = transcriber_speller(aud_seg.raw_data)
                # Score the domain model against its own transcript; the
                # original compared against ``pretrained_asr`` by mistake.
                res["domain_wer"] = word_error_rate(
                    [res["spoken"]], [res["domain_asr"]]
                )

        if enable_plots:
            wav_plot_path = (
                rel_root
                / Path("wav_plots")
                / Path(audio_path.name).with_suffix(".png")
            )
            if not wav_plot_path.exists():
                plot_seg(wav_plot_path, audio_path)
            res["plot_path"] = str(wav_plot_path)
        return res
    except Exception as e:
        # Best effort: a bad datapoint is reported and skipped by the caller
        # (which filters out ``None``) instead of aborting the whole dump.
        print(f"failed on {idx}: {audio_filepath} with {e}")
        return None
@app.command()
def dump_ui(
    data_name: str = typer.Option("dataname", show_default=True),
    dataset_dir: Path = Path("./data/asr_data"),
    dump_dir: Path = Path("./data/valiation_data"),
    dump_fname: Path = typer.Option(Path("ui_dump.json"), show_default=True),
    use_domain_asr: bool = False,
    annotation_only: bool = False,
    enable_plots: bool = True,
):
    """Preprocess every manifest entry of a dataset and write the UI dump."""
    from concurrent.futures import ThreadPoolExecutor
    from functools import partial

    dataset_root = dataset_dir / Path(data_name)
    data_manifest_path = dataset_root / Path("manifest.json")
    dump_path: Path = dump_dir / Path(data_name) / dump_fname
    # The plot directory is created up front, whether or not plots are enabled.
    (dataset_root / Path("wav_plots")).mkdir(parents=True, exist_ok=True)
    typer.echo(f"Using data manifest:{data_manifest_path}")
    with data_manifest_path.open("r") as manifest_fp:
        manifest_lines = manifest_fp.readlines()

    # One zero-argument task per manifest line (one JSON object per line);
    # all lines are parsed before any worker thread is started.
    tasks = []
    for line_no, line in enumerate(manifest_lines):
        tasks.append(
            partial(
                preprocess_datapoint,
                line_no,
                dataset_root,
                json.loads(line),
                use_domain_asr,
                annotation_only,
                enable_plots,
            )
        )

    with ThreadPoolExecutor() as pool:
        print("starting all preprocess tasks")
        progress = tqdm(
            pool.map(lambda task: task(), tasks),
            position=0,
            leave=True,
            total=len(tasks),
        )
        outcomes = list(progress)

    # Failed datapoints come back as None and are dropped here.
    result = [record for record in outcomes if record]
    if not annotation_only:
        # Worst transcriptions first, so they are reviewed first.
        wer_key = "domain_wer" if use_domain_asr else "pretrained_wer"
        result.sort(key=lambda record: record[wer_key], reverse=True)

    ExtendedPath(dump_path).write_json(
        {
            "use_domain_asr": use_domain_asr,
            "annotation_only": annotation_only,
            "enable_plots": enable_plots,
            "data": result,
        }
    )
@app.command()
def sample_ui(
    data_name: str = typer.Option("dataname", show_default=True),
    dump_dir: Path = Path("./data/asr_data"),
    dump_file: Path = Path("ui_dump.json"),
    sample_count: int = typer.Option(80, show_default=True),
    sample_file: Path = Path("sample_dump.json"),
):
    """Write a random per-caller sample of a UI dump to ``sample_file``."""
    import pandas as pd

    processed_data_path = dump_dir / Path(data_name) / dump_file
    sample_path = dump_dir / Path(data_name) / sample_file
    processed_data = ExtendedPath(processed_data_path).read_json()
    df = pd.DataFrame(processed_data["data"])
    if df.empty:
        # An empty frame has no "caller" column; fail with a readable message
        # instead of a KeyError.
        typer.echo(f"no datapoints found in {processed_data_path}", err=True)
        raise typer.Exit(code=1)

    # The requested total is divided evenly (floor) between the callers.
    samples_per_caller = sample_count // len(df["caller"].unique())
    caller_samples = pd.concat(
        [
            # DataFrame.sample raises ValueError when asked for more rows
            # than the group holds, so callers with few utterances
            # contribute everything they have.
            group.sample(min(samples_per_caller, len(group)))
            for _, group in df.groupby("caller")
        ]
    )
    caller_samples = caller_samples.reset_index(drop=True)
    # real_idx is renumbered to the position inside the sampled dump.
    caller_samples["real_idx"] = caller_samples.index
    sample_data = caller_samples.to_dict("records")
    processed_data["data"] = sample_data
    # Report what was actually drawn; it can be below sample_count because of
    # the floor division and the per-caller cap.
    typer.echo(f"sampling {len(sample_data)} datapoints")
    ExtendedPath(sample_path).write_json(processed_data)
@app.command()
def task_ui(
    data_name: str = typer.Option("dataname", show_default=True),
    dump_dir: Path = Path("./data/asr_data"),
    dump_file: Path = Path("ui_dump.json"),
    task_count: int = typer.Option(4, show_default=True),
    task_file: str = "task_dump",
):
    """Shuffle a UI dump and split it into ``task_count`` task files."""
    import pandas as pd
    import numpy as np

    data_dir = dump_dir / Path(data_name)
    processed_data = ExtendedPath(data_dir / dump_file).read_json()
    # sample(frac=1) returns all rows in random order.
    shuffled = pd.DataFrame(processed_data["data"]).sample(frac=1)
    shuffled = shuffled.reset_index(drop=True)
    for t_idx, part in enumerate(np.array_split(shuffled, task_count)):
        # Each task file gets its own 0-based real_idx numbering.
        renumbered = part.reset_index(drop=True).assign(
            real_idx=lambda frame: frame.index
        )
        processed_data["data"] = renumbered.to_dict("records")
        task_path = data_dir / Path(f"{task_file}-{t_idx}.json")
        ExtendedPath(task_path).write_json(processed_data)
@app.command()
def dump_corrections(
    task_uid: str,
    data_name: str = typer.Option("dataname", show_default=True),
    dump_dir: Path = Path("./data/asr_data"),
    dump_fname: Path = Path("corrections.json"),
):
    """Export the corrections of one validation task from mongo to JSON."""
    dump_path = dump_dir / Path(data_name) / dump_fname
    col = get_mongo_conn(col="asr_validation")
    # A task matches when the text after the last "-" of its task_id equals
    # task_uid. rpartition never raises, so ids without "-" (or non-string
    # values) are simply skipped.
    task_id = next(
        (
            candidate
            for candidate in col.distinct("task_id")
            if isinstance(candidate, str)
            and candidate.rpartition("-")[1:] == ("-", task_uid)
        ),
        None,
    )
    if task_id is None:
        typer.echo(f"no task found for task uid: {task_uid}", err=True)
        raise typer.Exit(code=1)
    # Only the selected task's corrections are fetched; the original also
    # loaded every correction in the collection and then discarded the result.
    cursor_obj = col.find(
        {"type": "correction", "task_id": task_id}, projection={"_id": False}
    )
    ExtendedPath(dump_path).write_json(list(cursor_obj))
@app.command()
def caller_quality(
    task_uid: str,
    data_name: str = typer.Option("dataname", show_default=True),
    dump_dir: Path = Path("./data/asr_data"),
    dump_fname: Path = Path("ui_dump.json"),
    correction_fname: Path = Path("corrections.json"),
):
    """Print, per caller, the share of a task's samples marked "Correct"."""
    import pandas as pd

    dump_path = dump_dir / Path(data_name) / dump_fname
    correction_path = dump_dir / Path(data_name) / correction_fname
    dump_data = ExtendedPath(dump_path).read_json()
    # Correction "code" values are looked up as utterance ids of the UI dump.
    dump_map = {d["utterance_id"]: d for d in dump_data["data"]}
    correction_data = ExtendedPath(correction_path).read_json()

    def belongs_to_task(correction):
        # The uid is the text after the last "-" of the task id; entries
        # without a usable task id are treated as not matching.
        task_id = correction.get("task_id")
        return (
            isinstance(task_id, str)
            and task_id.rpartition("-")[1:] == ("-", task_uid)
        )

    # A shallow copy is enough: only the new top-level "valid" key is added.
    corrected_dump = [
        {**dump_map[c["code"]], "valid": c["value"]["status"] == "Correct"}
        for c in correction_data
        if belongs_to_task(c)
    ]
    df = pd.DataFrame(corrected_dump)
    print(f"Total samples: {len(df)}")
    if df.empty:
        # Without rows there is no "caller" column to group by.
        return
    for caller, group in df.groupby("caller"):
        total = len(group)
        valid = int(group["valid"].sum())
        valid_rate = valid * 100 / total
        print(f"Caller: {caller} Valid%:{valid_rate:.2f} of {total} samples")
@app.command()
def fill_unannotated(
    data_name: str = typer.Option("dataname", show_default=True),
    dump_dir: Path = Path("./data/valiation_data"),
    dump_file: Path = Path("ui_dump.json"),
    corrections_file: Path = Path("corrections.json"),
):
    """Mark every datapoint without a correction as "Inaudible" in mongo."""
    data_dir = dump_dir / Path(data_name)
    # Context managers close the files; json.load(path.open()) leaked them.
    with (data_dir / dump_file).open() as dump_fp:
        processed_data = json.load(dump_fp)
    with (data_dir / corrections_file).open() as corrections_fp:
        corrections = json.load(corrections_fp)

    # The UI dump written by this module is {"data": [...], ...}; iterating
    # the dict itself only yields its keys. A bare list is still accepted.
    if isinstance(processed_data, dict):
        entries = processed_data["data"]
    else:
        entries = processed_data

    annotated_codes = {c["code"] for c in corrections}
    # The other commands in this module match a correction's "code" against
    # "utterance_id"; "gold_chars" is kept as a fallback for entries that
    # predate that key (presumably an older dump layout - verify if needed).
    all_codes = {
        d["utterance_id"] if "utterance_id" in d else d["gold_chars"]
        for d in entries
    }
    unann_codes = all_codes - annotated_codes

    mongo_conn = get_mongo_conn(col="asr_validation")
    for code in sorted(unann_codes):
        mongo_conn.find_one_and_update(
            {"type": "correction", "code": code},
            {"$set": {"value": {"status": "Inaudible", "correction": ""}}},
            upsert=True,
        )
@app.command()
def split_extract(
    data_name: str = typer.Option("dataname", show_default=True),
    dump_dir: Path = Path("./data/asr_data"),
    dump_file: Path = Path("ui_dump.json"),
    manifest_file: Path = Path("manifest.json"),
    corrections_file: str = typer.Option("corrections.json", show_default=True),
    conv_data_path: Path = typer.Option(Path("./data/conv_data.json"), show_default=True),
    extraction_type: str = "all",
):
    """Split a dataset into one sub-dataset per extraction type.

    For each key of the conv-data JSON (or only ``extraction_type``), the
    entries whose text is listed under that key are copied into a sibling
    dataset named ``<data_name>_<key lowercased>``.
    """
    src_data_dir = dump_dir / Path(data_name)
    data_manifest_path = src_data_dir / manifest_file
    conv_data = ExtendedPath(conv_data_path).read_json()
    # extraction_type is declared as a plain string, so reading ``.value``
    # unconditionally always raised AttributeError. getattr also keeps an
    # Enum member working if one is ever passed in.
    requested_type = getattr(extraction_type, "value", extraction_type)

    def extract_data_of_type(extraction_key):
        """Write manifest, audio, UI dump and corrections for one key."""
        extraction_vals = conv_data[extraction_key]
        dest_data_dir = dump_dir / Path(data_name + "_" + extraction_key.lower())
        # parents=True creates dest_data_dir together with its wav folder.
        (dest_data_dir / Path("wav")).mkdir(exist_ok=True, parents=True)
        dest_manifest_path = dest_data_dir / manifest_file
        dest_ui_path = dest_data_dir / dump_file

        def extract_manifest(manifest_entries):
            # Copies the audio of each matching entry as it is yielded.
            for m in manifest_entries:
                if m["text"] in extraction_vals:
                    dest_audio_path = dest_data_dir / Path(m["audio_filepath"])
                    # audio_filepath may point outside the wav folder.
                    dest_audio_path.parent.mkdir(exist_ok=True, parents=True)
                    shutil.copy(m["audio_path"], dest_audio_path)
                    yield m

        asr_manifest_writer(
            dest_manifest_path,
            extract_manifest(asr_manifest_reader(data_manifest_path)),
        )

        # The UI dump is re-read for every key because its entries are
        # renumbered in place below.
        orig_ui_data = ExtendedPath(src_data_dir / dump_file).read_json()
        ui_data = orig_ui_data["data"]
        file_ui_map = {Path(u["audio_filepath"]).stem: u for u in ui_data}
        final_data = [u for u in ui_data if u["text"] in extraction_vals]
        for i, d in enumerate(final_data):
            d["real_idx"] = i
        orig_ui_data["data"] = final_data
        ExtendedPath(dest_ui_path).write_json(orig_ui_data)

        # An empty corrections_file disables the corrections export.
        if corrections_file:
            with (src_data_dir / corrections_file).open() as corrections_fp:
                corrections = json.load(corrections_fp)
            extracted_corrections = [
                c
                for c in corrections
                if c["code"] in file_ui_map
                and file_ui_map[c["code"]]["text"] in extraction_vals
            ]
            ExtendedPath(dest_data_dir / corrections_file).write_json(
                extracted_corrections
            )

    if requested_type == "all":
        for ext_key in conv_data:
            extract_data_of_type(ext_key)
    elif requested_type in conv_data:
        extract_data_of_type(requested_type)
    else:
        raise typer.BadParameter(
            f"unknown extraction type {requested_type!r}; "
            f"expected 'all' or one of {sorted(conv_data)}"
        )
@app.command()
def update_corrections(
    data_name: str = typer.Option("dataname", show_default=True),
    dump_dir: Path = Path("./data/asr_data"),
    manifest_file: Path = Path("manifest.json"),
    corrections_file: Path = Path("corrections.json"),
    ui_dump_file: Path = Path("ui_dump.json"),
    skip_incorrect: bool = typer.Option(True, show_default=True),
):
    """Rewrite the dataset manifest according to the validation corrections.

    The dataset directory is backed up once to ``<name>.bkp`` before any
    audio file is renamed or deleted.
    """
    data_manifest_path = dump_dir / Path(data_name) / manifest_file
    corrections_path = dump_dir / Path(data_name) / corrections_file
    ui_dump_path = dump_dir / Path(data_name) / ui_dump_file

    def correct_manifest(ui_dump_path, corrections_path):
        # Generator: nothing is read or touched on disk until the manifest
        # writer starts consuming it.
        corrections = ExtendedPath(corrections_path).read_json()
        ui_data = ExtendedPath(ui_dump_path).read_json()["data"]

        # Bucket the corrections by status in a single pass.
        correct_set = set()
        correction_map = {}
        for entry in corrections:
            status = entry["value"]["status"]
            if status == "Correct":
                correct_set.add(entry["code"])
            elif status == "Incorrect":
                correction_map[entry["code"]] = entry["value"]["correction"]

        for d in ui_data:
            utterance_id = d["utterance_id"]

            # Confirmed entries are kept unchanged.
            if utterance_id in correct_set:
                yield {
                    "audio_filepath": d["audio_filepath"],
                    "duration": d["duration"],
                    "text": d["text"],
                }
                continue

            # Neither "Correct" nor "Incorrect": the audio file is deleted
            # and the entry is left out of the manifest.
            if utterance_id not in correction_map:
                Path(d["audio_path"]).unlink()
                continue

            correct_text = correction_map[utterance_id]
            if skip_incorrect:
                print(
                    f'skipping incorrect {d["audio_path"]} corrected to {correct_text}'
                )
                continue

            # Rename the audio file after the corrected transcript and emit
            # the entry with the new path and text.
            orig_audio_path = Path(d["audio_path"])
            new_name = str(Path(tscript_uuid_fname(correct_text)).with_suffix(".wav"))
            orig_audio_path.replace(orig_audio_path.with_name(new_name))
            yield {
                "audio_filepath": str(Path(d["audio_filepath"]).with_name(new_name)),
                "duration": d["duration"],
                "text": correct_text,
            }

    typer.echo(f"Using data manifest:{data_manifest_path}")
    dataset_dir = data_manifest_path.parent
    backup_dir = dataset_dir.with_name(dataset_dir.name + ".bkp")
    # An existing backup is never overwritten.
    if not backup_dir.exists():
        typer.echo(f"backing up to :{backup_dir}")
        shutil.copytree(str(dataset_dir), str(backup_dir))

    # Write to a temporary name first, then swap it in for the old manifest.
    new_data_manifest_path = data_manifest_path.with_name("manifest.new")
    asr_manifest_writer(
        new_data_manifest_path, correct_manifest(ui_dump_path, corrections_path)
    )
    new_data_manifest_path.replace(data_manifest_path)
@app.command()
def clear_mongo_corrections():
    """Delete all correction and cursor documents after a confirmation."""
    if not typer.confirm("are you sure you want to clear mongo collection it?"):
        typer.echo("Aborted")
        return
    collection = get_mongo_conn(col="asr_validation")
    # Corrections are removed first, then the stored UI cursors.
    for doc_type in ("correction", "current_cursor"):
        collection.delete_many({"type": doc_type})
    typer.echo("deleted mongo collection.")
def main():
    """Run the Typer application defined in this module."""
    app()


# Allow running this module directly as a script.
if __name__ == "__main__":
    main()