1. added deepgram support
2. compute asr sample accuracy
parent
fa89775f86
commit
f5c49338d9
|
|
@ -1,13 +1,16 @@
|
|||
import io
|
||||
import os
|
||||
import re
|
||||
import json
|
||||
import base64
|
||||
import wave
|
||||
from pathlib import Path
|
||||
from itertools import product
|
||||
from functools import partial
|
||||
from math import floor
|
||||
from uuid import uuid4
|
||||
from urllib.parse import urlsplit
|
||||
from urllib.parse import urlsplit, urlencode
|
||||
from urllib.request import Request, urlopen
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
import numpy as np
|
||||
|
|
@ -100,6 +103,9 @@ def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=F
|
|||
"data": [],
|
||||
}
|
||||
data_funcs = []
|
||||
|
||||
deepgram_transcriber = deepgram_transcribe_gen()
|
||||
# t2n = Text2Num()
|
||||
transcriber_gcp = gcp_transcribe_gen()
|
||||
transcriber_pretrained = transcribe_gen(asr_port=8044)
|
||||
with asr_manifest.open("w") as mf:
|
||||
|
|
@ -119,6 +125,10 @@ def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=F
|
|||
pretrained_result = transcriber_pretrained(aud_seg.raw_data)
|
||||
gcp_seg = aud_seg.set_frame_rate(16000)
|
||||
gcp_result = transcriber_gcp(gcp_seg.raw_data)
|
||||
aud_data = audio_path.read_bytes()
|
||||
dgram_result = deepgram_transcriber(aud_data)
|
||||
# gtruth = dp['text']
|
||||
# dgram_result = t2n.convert(dgram_script)
|
||||
pretrained_wer = word_error_rate([transcript], [pretrained_result])
|
||||
wav_plot_path = (
|
||||
dataset_dir / Path("wav_plots") / Path(fname).with_suffix(".png")
|
||||
|
|
@ -135,6 +145,7 @@ def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=F
|
|||
"caller": caller_name,
|
||||
"utterance_id": fname,
|
||||
"gcp_asr": gcp_result,
|
||||
"deepgram_asr": dgram_result,
|
||||
"pretrained_asr": pretrained_result,
|
||||
"pretrained_wer": pretrained_wer,
|
||||
"plot_path": str(wav_plot_path),
|
||||
|
|
@ -471,6 +482,203 @@ def gcp_transcribe_gen():
|
|||
return sample_recognize
|
||||
|
||||
|
||||
def deepgram_transcribe_gen():
    """Build a Deepgram batch-transcription function.

    Returns a callable that POSTs raw linear16 / 8 kHz audio bytes to the
    Deepgram "listen" endpoint (model "agara", en-US, punctuated) and returns
    the top-alternative transcript of the first channel.

    Raises:
        RuntimeError: if DEEPGRAM_USER / DEEPGRAM_PASS are not set in the
            environment.
    """
    DEEPGRAM_URL = "https://brain.deepgram.com/v2/listen"
    MODEL = "agara"
    encoding = "linear16"
    sample_rate = "8000"
    q_params = {
        "model": MODEL,
        "encoding": encoding,
        "sample_rate": sample_rate,
        "language": "en-US",
        "multichannel": "false",
        "punctuate": "true",
    }
    url = "{}?{}".format(DEEPGRAM_URL, urlencode(q_params))
    # SECURITY: credentials were previously hard-coded in source (and are now
    # burned — rotate them). Read from the environment and fail fast with a
    # clear message instead of shipping secrets in the repository.
    try:
        creds = (os.environ["DEEPGRAM_USER"], os.environ["DEEPGRAM_PASS"])
    except KeyError as err:
        raise RuntimeError(
            "Deepgram credentials missing: set DEEPGRAM_USER and DEEPGRAM_PASS"
        ) from err
    # HTTP Basic auth header, computed once — it does not change per request.
    token = base64.b64encode("{}:{}".format(*creds).encode("utf-8")).decode("utf-8")

    def deepgram_offline(audio_data):
        """POST *audio_data* (raw PCM bytes) and return the transcript string."""
        request = Request(
            url,
            method="POST",
            headers={"Authorization": "Basic {}".format(token)},
            data=audio_data,
        )
        with urlopen(request) as response:
            msg = json.loads(response.read())
        # First channel, best alternative — matches the original behavior.
        data = msg["results"]["channels"][0]["alternatives"][0]
        return data["transcript"]

    return deepgram_offline
|
||||
|
||||
|
||||
class Text2Num(object):
    """Convert spelled-out English numbers and ordinals in text to digits.

    ``convert`` rewrites whole-word number/ordinal tokens ("two", "first",
    "last", "3", ...) to their integer form and leaves other tokens intact;
    ``text2int`` parses a full spelled-out number phrase
    ("one hundred and five") into an int.
    """

    def __init__(self):
        # numwords maps a word to a (scale, increment) pair consumed by
        # text2int as: current = current * scale + increment.
        numwords = {}
        units = [
            "zero",
            "one",
            "two",
            "three",
            "four",
            "five",
            "six",
            "seven",
            "eight",
            "nine",
            "ten",
            "eleven",
            "twelve",
            "thirteen",
            "fourteen",
            "fifteen",
            "sixteen",
            "seventeen",
            "eighteen",
            "nineteen",
        ]

        tens = [
            "",
            "",
            "twenty",
            "thirty",
            "forty",
            "fifty",
            "sixty",
            "seventy",
            "eighty",
            "ninety",
        ]

        scales = ["hundred", "thousand", "million", "billion", "trillion"]

        numwords["and"] = (1, 0)  # "and" is a no-op connective
        for idx, word in enumerate(units):
            numwords[word] = (1, idx)
        for idx, word in enumerate(tens):
            numwords[word] = (1, idx * 10)
        for idx, word in enumerate(scales):
            # "hundred" -> scale 100 (idx*3 or 2 == 2); thousand/million/...
            # -> 10 ** (3 * idx).
            numwords[word] = (10 ** (idx * 3 or 2), 0)
        self.numwords = numwords

        # Hoisted from parseOrdinal: build the lookup table and compile the
        # regex once per instance instead of once per call.
        self._ordinal_lookup = {
            "first": 1,
            "second": 2,
            "third": 3,
            "fourth": 4,
            "fifth": 5,
            "sixth": 6,
            "seventh": 7,
            "eighth": 8,
            "ninth": 9,
            "tenth": 10,
            "one": 1,
            "two": 2,
            "three": 3,
            "four": 4,
            "five": 5,
            "six": 6,
            "seven": 7,
            "eight": 8,
            "nine": 9,
            "ten": 10,
            "1": 1,
            "2": 2,
            "3": 3,
            "4": 4,
            "5": 5,
            "6": 6,
            "7": 7,
            "8": 8,
            "9": 9,
            "10": 10,
            "last": -1,
        }
        self._ordinal_pattern = re.compile(
            r"(\s|^)(?P<num>(first)|(third)|(fourth)|(fifth)|(sixth)|(seventh)|(eighth)|(ninth)|(tenth)|(two)|(three)|(four)|(five)|(six)|(seven)|(eight)|(nine)|(ten)|(1)|(2)|(3)|(4)|(5)|(6)|(7)|(8)|(9)|(10)|(last))(\s|$)",
            re.IGNORECASE,
        )

    def is_num(self, word):
        """Return True when *word* is a recognized number word."""
        return word in self.numwords

    def parseOrdinal(self, utterance, **kwargs):
        """Return the int for a whole-word number/ordinal in *utterance*.

        Returns the empty string "" when nothing recognizable is found.
        """
        # Search once and reuse the match object (the original searched the
        # same pattern twice per call).
        match = self._ordinal_pattern.search(utterance)
        if match:
            ordinal = match.groupdict()["num"].strip()
        elif re.search(r"(\s|^)(?P<num>(second))(\s|$)", utterance):
            # "second" and "one" are deliberately outside the main pattern
            # (and matched case-sensitively), preserving original behavior.
            ordinal = "second"
        elif re.search(r"(\s|^)(?P<num>(one))(\s|$)", utterance):
            ordinal = "one"
        else:
            ordinal = ""
        return self._ordinal_lookup.get(ordinal, "")

    def convert(self, sent):
        """Rewrite each number/ordinal token of *sent* to digits.

        Non-number tokens are kept as-is; tokens are joined with spaces.
        """
        out = []
        for token in sent.split():
            # Evaluate once per token (the original called parseOrdinal twice).
            parsed = self.parseOrdinal(token)
            out.append(str(parsed) if parsed != "" else token)
        return " ".join(out)

    def text2int(self, textnum):
        """Parse a spelled-out number phrase into an int.

        Raises:
            Exception: when a token is not a known number word.
        """
        current = result = 0
        for word in textnum.split():
            if word not in self.numwords:
                raise Exception("Illegal word: " + word)

            scale, increment = self.numwords[word]
            current = current * scale + increment
            # Closing a scale group ("thousand" and above) banks the running
            # value; "hundred" (scale 100) stays inside the current group.
            if scale > 100:
                result += current
                current = 0

        return result + current
|
||||
|
||||
|
||||
def is_sub_sequence(str1, str2):
    """Return True if *str1* is a subsequence of *str2* (chars in order).

    The previous recursive formulation recursed once per character of str2
    and hit Python's recursion limit (~1000 frames) on realistic transcript
    lengths; this iterative two-pointer scan is O(len(str2)) and stack-safe.
    """
    m = len(str1)
    i = 0  # index of the next str1 character still to be matched
    for ch in str2:
        if i == m:
            break  # all of str1 already matched
        if ch == str1[i]:
            i += 1
    # Empty str1 is a subsequence of anything (matches the old m == 0 base
    # case); non-empty str1 against empty str2 yields False.
    return i == m
|
||||
|
||||
|
||||
def parallel_apply(fn, iterable, workers=8):
|
||||
with ThreadPoolExecutor(max_workers=workers) as exe:
|
||||
print(f"parallelly applying {fn}")
|
||||
|
|
|
|||
|
|
@ -156,6 +156,47 @@ def sample_ui(
|
|||
ExtendedPath(sample_path).write_json(processed_data)
|
||||
|
||||
|
||||
@app.command()
def sample_asr_accuracy(
    data_name: str = typer.Option(
        "png_06_2020_week1_numbers_window_customer", show_default=True
    ),
    dump_dir: Path = Path("./data/asr_data"),
    sample_file: Path = Path("sample_dump.json"),
    asr_service: str = "deepgram",
):
    """Score Deepgram transcripts in a sample dump against GCP references.

    A sample counts as a match when its GCP reference text is a character
    subsequence of the number-normalized, lower-cased Deepgram transcript.
    Prints a MATCH/FAIL line per sample and a final summary count.
    """
    from ..utils import is_sub_sequence, Text2Num

    text2num = Text2Num()
    sample_path = dump_dir / Path(data_name) / sample_file
    dump = ExtendedPath(sample_path).read_json()
    samples = dump["data"]
    total_samples = len(samples)
    match_count = 0
    for sample in tqdm(samples):
        reference = sample["text"]
        hypothesis = text2num.convert(sample["deepgram_asr"].lower())
        if is_sub_sequence(reference, hypothesis):
            match_count += 1
            print(f"MATCH GCP:{reference}\tDGM:{hypothesis}")
        else:
            print(f"FAIL GCP:{reference}\tDGM:{hypothesis}")
    typer.echo(
        f"{match_count} from deepgram matches with {total_samples} gcp transcripts."
    )
|
||||
|
||||
|
||||
@app.command()
|
||||
def task_ui(
|
||||
data_name: str = typer.Option("call_upwork_train_cnd", show_default=True),
|
||||
|
|
@ -190,7 +231,9 @@ def dump_corrections(
|
|||
col = get_mongo_conn(col="asr_validation")
|
||||
task_id = [c for c in col.distinct("task_id") if c.rsplit("-", 1)[1] == task_uid][0]
|
||||
corrections = list(col.find({"type": "correction"}, projection={"_id": False}))
|
||||
cursor_obj = col.find({"type": "correction", "task_id": task_id}, projection={"_id": False})
|
||||
cursor_obj = col.find(
|
||||
{"type": "correction", "task_id": task_id}, projection={"_id": False}
|
||||
)
|
||||
corrections = [c for c in cursor_obj]
|
||||
ExtendedPath(dump_path).write_json(corrections)
|
||||
|
||||
|
|
@ -271,7 +314,9 @@ def split_extract(
|
|||
dump_file: Path = Path("ui_dump.json"),
|
||||
manifest_file: Path = Path("manifest.json"),
|
||||
corrections_file: str = typer.Option("corrections.json", show_default=True),
|
||||
conv_data_path: Path = typer.Option(Path("./data/conv_data.json"), show_default=True),
|
||||
conv_data_path: Path = typer.Option(
|
||||
Path("./data/conv_data.json"), show_default=True
|
||||
),
|
||||
extraction_type: ExtractionType = ExtractionType.all,
|
||||
):
|
||||
import shutil
|
||||
|
|
@ -293,7 +338,9 @@ def split_extract(
|
|||
def extract_manifest(mg):
|
||||
for m in mg:
|
||||
if m["text"] in extraction_vals:
|
||||
shutil.copy(m["audio_path"], dest_data_dir / Path(m["audio_filepath"]))
|
||||
shutil.copy(
|
||||
m["audio_path"], dest_data_dir / Path(m["audio_filepath"])
|
||||
)
|
||||
yield m
|
||||
|
||||
asr_manifest_writer(dest_manifest_path, extract_manifest(manifest_gen))
|
||||
|
|
@ -302,12 +349,14 @@ def split_extract(
|
|||
orig_ui_data = ExtendedPath(ui_data_path).read_json()
|
||||
ui_data = orig_ui_data["data"]
|
||||
file_ui_map = {Path(u["audio_filepath"]).stem: u for u in ui_data}
|
||||
extracted_ui_data = list(filter(lambda u: u["text"] in extraction_vals, ui_data))
|
||||
extracted_ui_data = list(
|
||||
filter(lambda u: u["text"] in extraction_vals, ui_data)
|
||||
)
|
||||
final_data = []
|
||||
for i, d in enumerate(extracted_ui_data):
|
||||
d['real_idx'] = i
|
||||
d["real_idx"] = i
|
||||
final_data.append(d)
|
||||
orig_ui_data['data'] = final_data
|
||||
orig_ui_data["data"] = final_data
|
||||
ExtendedPath(dest_ui_path).write_json(orig_ui_data)
|
||||
|
||||
if corrections_file:
|
||||
|
|
@ -323,7 +372,7 @@ def split_extract(
|
|||
)
|
||||
ExtendedPath(dest_correction_path).write_json(extracted_corrections)
|
||||
|
||||
if extraction_type.value == 'all':
|
||||
if extraction_type.value == "all":
|
||||
for ext_key in conv_data.keys():
|
||||
extract_data_of_type(ext_key)
|
||||
else:
|
||||
|
|
@ -345,7 +394,7 @@ def update_corrections(
|
|||
|
||||
def correct_manifest(ui_dump_path, corrections_path):
|
||||
corrections = ExtendedPath(corrections_path).read_json()
|
||||
ui_data = ExtendedPath(ui_dump_path).read_json()['data']
|
||||
ui_data = ExtendedPath(ui_dump_path).read_json()["data"]
|
||||
correct_set = {
|
||||
c["code"] for c in corrections if c["value"]["status"] == "Correct"
|
||||
}
|
||||
|
|
@ -374,7 +423,9 @@ def update_corrections(
|
|||
)
|
||||
else:
|
||||
orig_audio_path = Path(d["audio_path"])
|
||||
new_name = str(Path(tscript_uuid_fname(correct_text)).with_suffix(".wav"))
|
||||
new_name = str(
|
||||
Path(tscript_uuid_fname(correct_text)).with_suffix(".wav")
|
||||
)
|
||||
new_audio_path = orig_audio_path.with_name(new_name)
|
||||
orig_audio_path.replace(new_audio_path)
|
||||
new_filepath = str(Path(d["audio_filepath"]).with_name(new_name))
|
||||
|
|
|
|||
Loading…
Reference in New Issue