mirror of
https://github.com/malarinv/jasper-asr.git
synced 2026-03-09 19:02:35 +00:00
1. Added Deepgram support
2. Added ASR sample-accuracy computation
This commit is contained in:
@@ -156,6 +156,47 @@ def sample_ui(
|
||||
ExtendedPath(sample_path).write_json(processed_data)
|
||||
|
||||
|
||||
@app.command()
def sample_asr_accuracy(
    data_name: str = typer.Option(
        "png_06_2020_week1_numbers_window_customer", show_default=True
    ),
    dump_dir: Path = Path("./data/asr_data"),
    sample_file: Path = Path("sample_dump.json"),
    asr_service: str = "deepgram",
):
    """Compare Deepgram ASR transcripts against GCP reference transcripts.

    Reads the sample dump JSON at ``dump_dir / data_name / sample_file``
    (each entry must already carry a ``"deepgram_asr"`` field and a GCP
    reference in ``"text"``), normalizes the Deepgram text to a number via
    ``Text2Num``, and prints a per-sample MATCH/FAIL line plus a final
    match-count summary.

    NOTE(review): ``asr_service`` is accepted but never read in this body —
    presumably reserved for selecting other ASR backends; confirm intent.
    """
    # import pandas as pd
    # from pydub import AudioSegment
    # Function-scope relative import keeps the heavy utils out of CLI startup.
    from ..utils import is_sub_sequence, Text2Num

    # from ..utils import deepgram_transcribe_gen
    #
    # deepgram_transcriber = deepgram_transcribe_gen()
    t2n = Text2Num()
    # processed_data_path = dump_dir / Path(data_name) / dump_file
    sample_path = dump_dir / Path(data_name) / sample_file
    processed_data = ExtendedPath(sample_path).read_json()
    # asr_data = []
    # total_samples is the full dataset size; match_count accumulates hits.
    match_count, total_samples = 0, len(processed_data["data"])
    for dp in tqdm(processed_data["data"]):
        # The live-transcription path is disabled; "deepgram_asr" is assumed
        # to be pre-populated in the sample dump (see commented lines below).
        # aud_data = Path(dp["audio_path"]).read_bytes()
        # dgram_result = deepgram_transcriber(aud_data)
        # dp["deepgram_asr"] = dgram_result
        gcp_num = dp["text"]
        # Lowercase before number conversion; Text2Num semantics live in ..utils.
        dgm_num = t2n.convert(dp["deepgram_asr"].lower())
        # assumes is_sub_sequence(a, b) tests a against b as a subsequence —
        # TODO confirm argument order against ..utils.
        if is_sub_sequence(gcp_num, dgm_num):
            match_count += 1
            print(f"MATCH GCP:{gcp_num}\tDGM:{dgm_num}")
        else:
            print(f"FAIL GCP:{gcp_num}\tDGM:{dgm_num}")
        # asr_data.append(dp)
    typer.echo(
        f"{match_count} from deepgram matches with {total_samples} gcp transcripts."
    )
    # processed_data["data"] = asr_data
    # ExtendedPath(sample_path).write_json(processed_data)
|
||||
|
||||
|
||||
@app.command()
|
||||
def task_ui(
|
||||
data_name: str = typer.Option("call_upwork_train_cnd", show_default=True),
|
||||
@@ -190,7 +231,9 @@ def dump_corrections(
|
||||
col = get_mongo_conn(col="asr_validation")
|
||||
task_id = [c for c in col.distinct("task_id") if c.rsplit("-", 1)[1] == task_uid][0]
|
||||
corrections = list(col.find({"type": "correction"}, projection={"_id": False}))
|
||||
cursor_obj = col.find({"type": "correction", "task_id": task_id}, projection={"_id": False})
|
||||
cursor_obj = col.find(
|
||||
{"type": "correction", "task_id": task_id}, projection={"_id": False}
|
||||
)
|
||||
corrections = [c for c in cursor_obj]
|
||||
ExtendedPath(dump_path).write_json(corrections)
|
||||
|
||||
@@ -271,7 +314,9 @@ def split_extract(
|
||||
dump_file: Path = Path("ui_dump.json"),
|
||||
manifest_file: Path = Path("manifest.json"),
|
||||
corrections_file: str = typer.Option("corrections.json", show_default=True),
|
||||
conv_data_path: Path = typer.Option(Path("./data/conv_data.json"), show_default=True),
|
||||
conv_data_path: Path = typer.Option(
|
||||
Path("./data/conv_data.json"), show_default=True
|
||||
),
|
||||
extraction_type: ExtractionType = ExtractionType.all,
|
||||
):
|
||||
import shutil
|
||||
@@ -293,7 +338,9 @@ def split_extract(
|
||||
def extract_manifest(mg):
|
||||
for m in mg:
|
||||
if m["text"] in extraction_vals:
|
||||
shutil.copy(m["audio_path"], dest_data_dir / Path(m["audio_filepath"]))
|
||||
shutil.copy(
|
||||
m["audio_path"], dest_data_dir / Path(m["audio_filepath"])
|
||||
)
|
||||
yield m
|
||||
|
||||
asr_manifest_writer(dest_manifest_path, extract_manifest(manifest_gen))
|
||||
@@ -302,12 +349,14 @@ def split_extract(
|
||||
orig_ui_data = ExtendedPath(ui_data_path).read_json()
|
||||
ui_data = orig_ui_data["data"]
|
||||
file_ui_map = {Path(u["audio_filepath"]).stem: u for u in ui_data}
|
||||
extracted_ui_data = list(filter(lambda u: u["text"] in extraction_vals, ui_data))
|
||||
extracted_ui_data = list(
|
||||
filter(lambda u: u["text"] in extraction_vals, ui_data)
|
||||
)
|
||||
final_data = []
|
||||
for i, d in enumerate(extracted_ui_data):
|
||||
d['real_idx'] = i
|
||||
d["real_idx"] = i
|
||||
final_data.append(d)
|
||||
orig_ui_data['data'] = final_data
|
||||
orig_ui_data["data"] = final_data
|
||||
ExtendedPath(dest_ui_path).write_json(orig_ui_data)
|
||||
|
||||
if corrections_file:
|
||||
@@ -323,7 +372,7 @@ def split_extract(
|
||||
)
|
||||
ExtendedPath(dest_correction_path).write_json(extracted_corrections)
|
||||
|
||||
if extraction_type.value == 'all':
|
||||
if extraction_type.value == "all":
|
||||
for ext_key in conv_data.keys():
|
||||
extract_data_of_type(ext_key)
|
||||
else:
|
||||
@@ -345,7 +394,7 @@ def update_corrections(
|
||||
|
||||
def correct_manifest(ui_dump_path, corrections_path):
|
||||
corrections = ExtendedPath(corrections_path).read_json()
|
||||
ui_data = ExtendedPath(ui_dump_path).read_json()['data']
|
||||
ui_data = ExtendedPath(ui_dump_path).read_json()["data"]
|
||||
correct_set = {
|
||||
c["code"] for c in corrections if c["value"]["status"] == "Correct"
|
||||
}
|
||||
@@ -374,7 +423,9 @@ def update_corrections(
|
||||
)
|
||||
else:
|
||||
orig_audio_path = Path(d["audio_path"])
|
||||
new_name = str(Path(tscript_uuid_fname(correct_text)).with_suffix(".wav"))
|
||||
new_name = str(
|
||||
Path(tscript_uuid_fname(correct_text)).with_suffix(".wav")
|
||||
)
|
||||
new_audio_path = orig_audio_path.with_name(new_name)
|
||||
orig_audio_path.replace(new_audio_path)
|
||||
new_filepath = str(Path(d["audio_filepath"]).with_name(new_name))
|
||||
|
||||
Reference in New Issue
Block a user