1
0
mirror of https://github.com/malarinv/jasper-asr.git synced 2026-06-13 12:32:08 +00:00

1. refactored ui_dump

2. added flake8
This commit is contained in:
2020-08-09 19:16:35 +05:30
parent 42647196fe
commit e8f58a5043
4 changed files with 154 additions and 179 deletions

View File

@@ -1,13 +1,10 @@
import json
import shutil
from pathlib import Path
from enum import Enum
import typer
from tqdm import tqdm
from ..utils import (
alnum_to_asr_tokens,
ExtendedPath,
asr_manifest_reader,
asr_manifest_writer,
@@ -19,9 +16,7 @@ from ..utils import (
app = typer.Typer()
def preprocess_datapoint(
idx, rel_root, sample, use_domain_asr, annotation_only, enable_plots
):
def preprocess_datapoint(idx, rel_root, sample):
from pydub import AudioSegment
from nemo.collections.asr.metrics import word_error_rate
from jasper.client import transcribe_gen
@@ -31,37 +26,23 @@ def preprocess_datapoint(
res["real_idx"] = idx
audio_path = rel_root / Path(sample["audio_filepath"])
res["audio_path"] = str(audio_path)
if use_domain_asr:
res["spoken"] = alnum_to_asr_tokens(res["text"])
else:
res["spoken"] = res["text"]
res["utterance_id"] = audio_path.stem
if not annotation_only:
transcriber_pretrained = transcribe_gen(asr_port=8044)
transcriber_pretrained = transcribe_gen(asr_port=8044)
aud_seg = (
AudioSegment.from_file_using_temporary_files(audio_path)
.set_channels(1)
.set_sample_width(2)
.set_frame_rate(24000)
)
res["pretrained_asr"] = transcriber_pretrained(aud_seg.raw_data)
res["pretrained_wer"] = word_error_rate(
[res["text"]], [res["pretrained_asr"]]
)
if use_domain_asr:
transcriber_speller = transcribe_gen(asr_port=8045)
res["domain_asr"] = transcriber_speller(aud_seg.raw_data)
res["domain_wer"] = word_error_rate(
[res["spoken"]], [res["pretrained_asr"]]
)
if enable_plots:
wav_plot_path = (
rel_root / Path("wav_plots") / Path(audio_path.name).with_suffix(".png")
)
if not wav_plot_path.exists():
plot_seg(wav_plot_path, audio_path)
res["plot_path"] = str(wav_plot_path)
aud_seg = (
AudioSegment.from_file_using_temporary_files(audio_path)
.set_channels(1)
.set_sample_width(2)
.set_frame_rate(24000)
)
res["pretrained_asr"] = transcriber_pretrained(aud_seg.raw_data)
res["pretrained_wer"] = word_error_rate([res["text"]], [res["pretrained_asr"]])
wav_plot_path = (
rel_root / Path("wav_plots") / Path(audio_path.name).with_suffix(".png")
)
if not wav_plot_path.exists():
plot_seg(wav_plot_path, audio_path)
res["plot_path"] = str(wav_plot_path)
return res
except BaseException as e:
print(f'failed on {idx}: {sample["audio_filepath"]} with {e}')
@@ -73,61 +54,50 @@ def dump_ui(
dataset_dir: Path = Path("./data/asr_data"),
dump_dir: Path = Path("./data/valiation_data"),
dump_fname: Path = typer.Option(Path("ui_dump.json"), show_default=True),
use_domain_asr: bool = False,
annotation_only: bool = False,
enable_plots: bool = True,
):
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from io import BytesIO
from pydub import AudioSegment
from ..utils import ui_data_generator
data_manifest_path = dataset_dir / Path(data_name) / Path("manifest.json")
dump_path: Path = dump_dir / Path(data_name) / dump_fname
plot_dir = data_manifest_path.parent / Path("wav_plots")
plot_dir.mkdir(parents=True, exist_ok=True)
typer.echo(f"Using data manifest:{data_manifest_path}")
with data_manifest_path.open("r") as pf:
data_jsonl = pf.readlines()
data_funcs = [
partial(
preprocess_datapoint,
i,
data_manifest_path.parent,
json.loads(v),
use_domain_asr,
annotation_only,
enable_plots,
)
for i, v in enumerate(data_jsonl)
]
def exec_func(f):
return f()
def asr_data_source_gen():
with data_manifest_path.open("r") as pf:
data_jsonl = pf.readlines()
for v in data_jsonl:
sample = json.loads(v)
rel_root = data_manifest_path.parent
res = dict(sample)
audio_path = rel_root / Path(sample["audio_filepath"])
audio_segment = (
AudioSegment.from_file_using_temporary_files(audio_path)
.set_channels(1)
.set_sample_width(2)
.set_frame_rate(24000)
)
wav_plot_path = (
rel_root
/ Path("wav_plots")
/ Path(audio_path.name).with_suffix(".png")
)
if not wav_plot_path.exists():
plot_seg(wav_plot_path, audio_path)
res["plot_path"] = str(wav_plot_path)
code_fb = BytesIO()
audio_segment.export(code_fb, format="wav")
wav_data = code_fb.getvalue()
duration = audio_segment.duration_seconds
asr_final = res["text"]
yield asr_final, duration, wav_data, "caller", audio_segment
with ThreadPoolExecutor() as exe:
print("starting all preprocess tasks")
data_final = filter(
None,
list(
tqdm(
exe.map(exec_func, data_funcs),
position=0,
leave=True,
total=len(data_funcs),
)
),
)
if annotation_only:
result = list(data_final)
else:
wer_key = "domain_wer" if use_domain_asr else "pretrained_wer"
result = sorted(data_final, key=lambda x: x[wer_key], reverse=True)
ui_config = {
"use_domain_asr": use_domain_asr,
"annotation_only": annotation_only,
"enable_plots": enable_plots,
"data": result,
}
ExtendedPath(dump_path).write_json(ui_config)
dump_data, num_datapoints = ui_data_generator(
dataset_dir, data_name, asr_data_source_gen()
)
ui_dump_file = dataset_dir / Path("ui_dump.json")
ExtendedPath(ui_dump_file).write_json({"data": dump_data})
@app.command()
@@ -190,7 +160,9 @@ def dump_corrections(
col = get_mongo_conn(col="asr_validation")
task_id = [c for c in col.distinct("task_id") if c.rsplit("-", 1)[1] == task_uid][0]
corrections = list(col.find({"type": "correction"}, projection={"_id": False}))
cursor_obj = col.find({"type": "correction", "task_id": task_id}, projection={"_id": False})
cursor_obj = col.find(
{"type": "correction", "task_id": task_id}, projection={"_id": False}
)
corrections = [c for c in cursor_obj]
ExtendedPath(dump_path).write_json(corrections)
@@ -264,7 +236,9 @@ def split_extract(
dump_file: Path = Path("ui_dump.json"),
manifest_file: Path = Path("manifest.json"),
corrections_file: str = typer.Option("corrections.json", show_default=True),
conv_data_path: Path = typer.Option(Path("./data/conv_data.json"), show_default=True),
conv_data_path: Path = typer.Option(
Path("./data/conv_data.json"), show_default=True
),
extraction_type: str = "all",
):
import shutil
@@ -286,7 +260,9 @@ def split_extract(
def extract_manifest(mg):
for m in mg:
if m["text"] in extraction_vals:
shutil.copy(m["audio_path"], dest_data_dir / Path(m["audio_filepath"]))
shutil.copy(
m["audio_path"], dest_data_dir / Path(m["audio_filepath"])
)
yield m
asr_manifest_writer(dest_manifest_path, extract_manifest(manifest_gen))
@@ -295,12 +271,14 @@ def split_extract(
orig_ui_data = ExtendedPath(ui_data_path).read_json()
ui_data = orig_ui_data["data"]
file_ui_map = {Path(u["audio_filepath"]).stem: u for u in ui_data}
extracted_ui_data = list(filter(lambda u: u["text"] in extraction_vals, ui_data))
extracted_ui_data = list(
filter(lambda u: u["text"] in extraction_vals, ui_data)
)
final_data = []
for i, d in enumerate(extracted_ui_data):
d['real_idx'] = i
d["real_idx"] = i
final_data.append(d)
orig_ui_data['data'] = final_data
orig_ui_data["data"] = final_data
ExtendedPath(dest_ui_path).write_json(orig_ui_data)
if corrections_file:
@@ -316,7 +294,7 @@ def split_extract(
)
ExtendedPath(dest_correction_path).write_json(extracted_corrections)
if extraction_type.value == 'all':
if extraction_type.value == "all":
for ext_key in conv_data.keys():
extract_data_of_type(ext_key)
else:
@@ -338,7 +316,7 @@ def update_corrections(
def correct_manifest(ui_dump_path, corrections_path):
corrections = ExtendedPath(corrections_path).read_json()
ui_data = ExtendedPath(ui_dump_path).read_json()['data']
ui_data = ExtendedPath(ui_dump_path).read_json()["data"]
correct_set = {
c["code"] for c in corrections if c["value"]["status"] == "Correct"
}
@@ -367,7 +345,9 @@ def update_corrections(
)
else:
orig_audio_path = Path(d["audio_path"])
new_name = str(Path(tscript_uuid_fname(correct_text)).with_suffix(".wav"))
new_name = str(
Path(tscript_uuid_fname(correct_text)).with_suffix(".wav")
)
new_audio_path = orig_audio_path.with_name(new_name)
orig_audio_path.replace(new_audio_path)
new_filepath = str(Path(d["audio_filepath"]).with_name(new_name))