1. added deepgram support

2. compute asr sample accuracy
Malar Kannan 2020-08-07 12:02:01 +05:30
parent fa89775f86
commit f5c49338d9
2 changed files with 269 additions and 10 deletions

View File

@ -1,13 +1,16 @@
import io import io
import os import os
import re
import json import json
import base64
import wave import wave
from pathlib import Path from pathlib import Path
from itertools import product from itertools import product
from functools import partial from functools import partial
from math import floor from math import floor
from uuid import uuid4 from uuid import uuid4
from urllib.parse import urlsplit from urllib.parse import urlsplit, urlencode
from urllib.request import Request, urlopen
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
import numpy as np import numpy as np
@ -100,6 +103,9 @@ def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=F
"data": [], "data": [],
} }
data_funcs = [] data_funcs = []
deepgram_transcriber = deepgram_transcribe_gen()
# t2n = Text2Num()
transcriber_gcp = gcp_transcribe_gen() transcriber_gcp = gcp_transcribe_gen()
transcriber_pretrained = transcribe_gen(asr_port=8044) transcriber_pretrained = transcribe_gen(asr_port=8044)
with asr_manifest.open("w") as mf: with asr_manifest.open("w") as mf:
@ -119,6 +125,10 @@ def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=F
pretrained_result = transcriber_pretrained(aud_seg.raw_data) pretrained_result = transcriber_pretrained(aud_seg.raw_data)
gcp_seg = aud_seg.set_frame_rate(16000) gcp_seg = aud_seg.set_frame_rate(16000)
gcp_result = transcriber_gcp(gcp_seg.raw_data) gcp_result = transcriber_gcp(gcp_seg.raw_data)
aud_data = audio_path.read_bytes()
dgram_result = deepgram_transcriber(aud_data)
# gtruth = dp['text']
# dgram_result = t2n.convert(dgram_script)
pretrained_wer = word_error_rate([transcript], [pretrained_result]) pretrained_wer = word_error_rate([transcript], [pretrained_result])
wav_plot_path = ( wav_plot_path = (
dataset_dir / Path("wav_plots") / Path(fname).with_suffix(".png") dataset_dir / Path("wav_plots") / Path(fname).with_suffix(".png")
@ -135,6 +145,7 @@ def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=F
"caller": caller_name, "caller": caller_name,
"utterance_id": fname, "utterance_id": fname,
"gcp_asr": gcp_result, "gcp_asr": gcp_result,
"deepgram_asr": dgram_result,
"pretrained_asr": pretrained_result, "pretrained_asr": pretrained_result,
"pretrained_wer": pretrained_wer, "pretrained_wer": pretrained_wer,
"plot_path": str(wav_plot_path), "plot_path": str(wav_plot_path),
@ -471,6 +482,203 @@ def gcp_transcribe_gen():
return sample_recognize return sample_recognize
def deepgram_transcribe_gen():
    """Build a transcriber closure for Deepgram's batch (pre-recorded) API.

    Returns:
        Callable[[bytes], str]: takes raw linear16 / 8 kHz audio bytes,
        POSTs them to the Deepgram listen endpoint and returns the first
        channel's best-alternative transcript string.
    """
    DEEPGRAM_URL = "https://brain.deepgram.com/v2/listen"
    MODEL = "agara"
    encoding = "linear16"
    sample_rate = "8000"
    # diarize = "false"
    q_params = {
        "model": MODEL,
        "encoding": encoding,
        "sample_rate": sample_rate,
        "language": "en-US",
        "multichannel": "false",
        "punctuate": "true",
    }
    url = "{}?{}".format(DEEPGRAM_URL, urlencode(q_params))
    # SECURITY: credentials were hard-coded in source. Prefer environment
    # variables; fall back to the legacy values so existing setups keep
    # working. Rotate these credentials — they have been committed.
    creds = (
        os.environ.get("DEEPGRAM_USER", "arjun@agaralabs.com"),
        os.environ.get("DEEPGRAM_PASS", "PoX1Y@x4h%oS"),
    )

    def deepgram_offline(audio_data):
        """POST *audio_data* (bytes) and return the transcript string."""
        auth_token = base64.b64encode(
            "{}:{}".format(*creds).encode("utf-8")
        ).decode("utf-8")
        request = Request(
            url,
            method="POST",
            headers={"Authorization": "Basic {}".format(auth_token)},
            data=audio_data,
        )
        with urlopen(request) as response:
            msg = json.loads(response.read())
        data = msg["results"]["channels"][0]["alternatives"][0]
        return data["transcript"]

    return deepgram_offline
class Text2Num(object):
    """Convert spelled-out numbers and ordinals in text to integers.

    ``convert`` rewrites a sentence token-by-token ("press two" ->
    "press 2"); ``text2int`` evaluates a whole spelled-out number phrase
    ("twenty one" -> 21).
    """

    def __init__(self):
        # Map word -> (scale, increment); a phrase is evaluated as
        # value = value * scale + increment, token by token.
        numwords = {"and": (1, 0)}
        units = [
            "zero", "one", "two", "three", "four", "five", "six",
            "seven", "eight", "nine", "ten", "eleven", "twelve",
            "thirteen", "fourteen", "fifteen", "sixteen", "seventeen",
            "eighteen", "nineteen",
        ]
        tens = [
            "", "", "twenty", "thirty", "forty", "fifty", "sixty",
            "seventy", "eighty", "ninety",
        ]
        scales = ["hundred", "thousand", "million", "billion", "trillion"]
        for idx, word in enumerate(units):
            numwords[word] = (1, idx)
        for idx, word in enumerate(tens):
            numwords[word] = (1, idx * 10)
        for idx, word in enumerate(scales):
            # hundred -> 10**2, thousand -> 10**3, million -> 10**6, ...
            numwords[word] = (10 ** (idx * 3 or 2), 0)
        self.numwords = numwords

    def is_num(self, word):
        """Return True if *word* is a known number word."""
        return word in self.numwords

    def parseOrdinal(self, utterance, **kwargs):
        """Return the integer for the first ordinal/number word found in
        *utterance*, or "" when none is present.

        "second" and "one" are excluded from the main pattern (presumably
        because they are ambiguous in ASR output — TODO confirm) and are
        only matched as a fallback when nothing else hits.
        """
        lookup_dict = {
            "first": 1, "second": 2, "third": 3, "fourth": 4,
            "fifth": 5, "sixth": 6, "seventh": 7, "eighth": 8,
            "ninth": 9, "tenth": 10,
            "one": 1, "two": 2, "three": 3, "four": 4, "five": 5,
            "six": 6, "seven": 7, "eight": 8, "nine": 9, "ten": 10,
            "1": 1, "2": 2, "3": 3, "4": 4, "5": 5,
            "6": 6, "7": 7, "8": 8, "9": 9, "10": 10,
            "last": -1,
        }
        pattern = re.compile(
            r"(\s|^)(?P<num>(first)|(third)|(fourth)|(fifth)|(sixth)|(seventh)|(eighth)|(ninth)|(tenth)|(two)|(three)|(four)|(five)|(six)|(seven)|(eight)|(nine)|(ten)|(1)|(2)|(3)|(4)|(5)|(6)|(7)|(8)|(9)|(10)|(last))(\s|$)",
            re.IGNORECASE,
        )
        ordinal = ""
        match = pattern.search(utterance)
        if match:
            # BUGFIX: the pattern matches case-insensitively but the
            # lookup table is lowercase; normalize so "Third" -> 3
            # instead of silently returning "".
            ordinal = match.groupdict()["num"].strip().lower()
        # BUGFIX: the fallback searches previously lacked re.IGNORECASE,
        # inconsistent with the main pattern above.
        elif re.search(r"(\s|^)(?P<num>(second))(\s|$)", utterance, re.IGNORECASE):
            ordinal = "second"
        elif re.search(r"(\s|^)(?P<num>(one))(\s|$)", utterance, re.IGNORECASE):
            ordinal = "one"
        return lookup_dict.get(ordinal, "")

    def convert(self, sent):
        """Rewrite *sent* replacing each recognized number token with its
        digit form; unrecognized tokens pass through unchanged."""
        out = []
        for token in sent.split():
            # Call parseOrdinal once per token (was called twice before).
            parsed = self.parseOrdinal(token)
            out.append(str(parsed) if parsed != "" else token)
        return " ".join(out)

    def text2int(self, textnum):
        """Evaluate a spelled-out number phrase, e.g. "twenty one" -> 21.

        Raises:
            Exception: if a token is not a known number word.
        """
        current = result = 0
        for word in textnum.split():
            if word not in self.numwords:
                raise Exception("Illegal word: " + word)
            scale, increment = self.numwords[word]
            current = current * scale + increment
            # Flush completed groups at thousand/million/... boundaries.
            if scale > 100:
                result += current
                current = 0
        return result + current
def is_sub_sequence(str1, str2):
    """Return True if *str1* is a subsequence of *str2*.

    Iterative two-pointer scan. The previous recursive implementation
    recursed once per character of *str2* and could exceed Python's
    recursion limit on long transcripts; this version is O(len(str2))
    time and O(1) space with identical results (empty *str1* -> True).
    """
    target = len(str1)
    idx = 0
    for ch in str2:
        if idx == target:
            break
        if ch == str1[idx]:
            idx += 1
    return idx == target
def parallel_apply(fn, iterable, workers=8): def parallel_apply(fn, iterable, workers=8):
with ThreadPoolExecutor(max_workers=workers) as exe: with ThreadPoolExecutor(max_workers=workers) as exe:
print(f"parallelly applying {fn}") print(f"parallelly applying {fn}")

View File

@ -156,6 +156,47 @@ def sample_ui(
ExtendedPath(sample_path).write_json(processed_data) ExtendedPath(sample_path).write_json(processed_data)
@app.command()
def sample_asr_accuracy(
    data_name: str = typer.Option(
        "png_06_2020_week1_numbers_window_customer", show_default=True
    ),
    dump_dir: Path = Path("./data/asr_data"),
    sample_file: Path = Path("sample_dump.json"),
    asr_service: str = "deepgram",
):
    """Score Deepgram ASR against the stored ground-truth text.

    Reads the sample dump for *data_name*, normalizes each Deepgram
    transcript to digits, and counts samples whose ground-truth text
    appears as a subsequence of the normalized transcript.
    """
    from ..utils import is_sub_sequence, Text2Num

    t2n = Text2Num()
    sample_path = dump_dir / Path(data_name) / sample_file
    processed_data = ExtendedPath(sample_path).read_json()
    samples = processed_data["data"]
    total_samples = len(samples)
    match_count = 0
    for dp in tqdm(samples):
        gcp_num = dp["text"]
        dgm_num = t2n.convert(dp["deepgram_asr"].lower())
        if is_sub_sequence(gcp_num, dgm_num):
            match_count += 1
            print(f"MATCH GCP:{gcp_num}\tDGM:{dgm_num}")
        else:
            print(f"FAIL GCP:{gcp_num}\tDGM:{dgm_num}")
    typer.echo(
        f"{match_count} from deepgram matches with {total_samples} gcp transcripts."
    )
@app.command() @app.command()
def task_ui( def task_ui(
data_name: str = typer.Option("call_upwork_train_cnd", show_default=True), data_name: str = typer.Option("call_upwork_train_cnd", show_default=True),
@ -190,7 +231,9 @@ def dump_corrections(
col = get_mongo_conn(col="asr_validation") col = get_mongo_conn(col="asr_validation")
task_id = [c for c in col.distinct("task_id") if c.rsplit("-", 1)[1] == task_uid][0] task_id = [c for c in col.distinct("task_id") if c.rsplit("-", 1)[1] == task_uid][0]
corrections = list(col.find({"type": "correction"}, projection={"_id": False})) corrections = list(col.find({"type": "correction"}, projection={"_id": False}))
cursor_obj = col.find({"type": "correction", "task_id": task_id}, projection={"_id": False}) cursor_obj = col.find(
{"type": "correction", "task_id": task_id}, projection={"_id": False}
)
corrections = [c for c in cursor_obj] corrections = [c for c in cursor_obj]
ExtendedPath(dump_path).write_json(corrections) ExtendedPath(dump_path).write_json(corrections)
@ -271,7 +314,9 @@ def split_extract(
dump_file: Path = Path("ui_dump.json"), dump_file: Path = Path("ui_dump.json"),
manifest_file: Path = Path("manifest.json"), manifest_file: Path = Path("manifest.json"),
corrections_file: str = typer.Option("corrections.json", show_default=True), corrections_file: str = typer.Option("corrections.json", show_default=True),
conv_data_path: Path = typer.Option(Path("./data/conv_data.json"), show_default=True), conv_data_path: Path = typer.Option(
Path("./data/conv_data.json"), show_default=True
),
extraction_type: ExtractionType = ExtractionType.all, extraction_type: ExtractionType = ExtractionType.all,
): ):
import shutil import shutil
@ -293,7 +338,9 @@ def split_extract(
def extract_manifest(mg): def extract_manifest(mg):
for m in mg: for m in mg:
if m["text"] in extraction_vals: if m["text"] in extraction_vals:
shutil.copy(m["audio_path"], dest_data_dir / Path(m["audio_filepath"])) shutil.copy(
m["audio_path"], dest_data_dir / Path(m["audio_filepath"])
)
yield m yield m
asr_manifest_writer(dest_manifest_path, extract_manifest(manifest_gen)) asr_manifest_writer(dest_manifest_path, extract_manifest(manifest_gen))
@ -302,12 +349,14 @@ def split_extract(
orig_ui_data = ExtendedPath(ui_data_path).read_json() orig_ui_data = ExtendedPath(ui_data_path).read_json()
ui_data = orig_ui_data["data"] ui_data = orig_ui_data["data"]
file_ui_map = {Path(u["audio_filepath"]).stem: u for u in ui_data} file_ui_map = {Path(u["audio_filepath"]).stem: u for u in ui_data}
extracted_ui_data = list(filter(lambda u: u["text"] in extraction_vals, ui_data)) extracted_ui_data = list(
filter(lambda u: u["text"] in extraction_vals, ui_data)
)
final_data = [] final_data = []
for i, d in enumerate(extracted_ui_data): for i, d in enumerate(extracted_ui_data):
d['real_idx'] = i d["real_idx"] = i
final_data.append(d) final_data.append(d)
orig_ui_data['data'] = final_data orig_ui_data["data"] = final_data
ExtendedPath(dest_ui_path).write_json(orig_ui_data) ExtendedPath(dest_ui_path).write_json(orig_ui_data)
if corrections_file: if corrections_file:
@ -323,7 +372,7 @@ def split_extract(
) )
ExtendedPath(dest_correction_path).write_json(extracted_corrections) ExtendedPath(dest_correction_path).write_json(extracted_corrections)
if extraction_type.value == 'all': if extraction_type.value == "all":
for ext_key in conv_data.keys(): for ext_key in conv_data.keys():
extract_data_of_type(ext_key) extract_data_of_type(ext_key)
else: else:
@ -345,7 +394,7 @@ def update_corrections(
def correct_manifest(ui_dump_path, corrections_path): def correct_manifest(ui_dump_path, corrections_path):
corrections = ExtendedPath(corrections_path).read_json() corrections = ExtendedPath(corrections_path).read_json()
ui_data = ExtendedPath(ui_dump_path).read_json()['data'] ui_data = ExtendedPath(ui_dump_path).read_json()["data"]
correct_set = { correct_set = {
c["code"] for c in corrections if c["value"]["status"] == "Correct" c["code"] for c in corrections if c["value"]["status"] == "Correct"
} }
@ -374,7 +423,9 @@ def update_corrections(
) )
else: else:
orig_audio_path = Path(d["audio_path"]) orig_audio_path = Path(d["audio_path"])
new_name = str(Path(tscript_uuid_fname(correct_text)).with_suffix(".wav")) new_name = str(
Path(tscript_uuid_fname(correct_text)).with_suffix(".wav")
)
new_audio_path = orig_audio_path.with_name(new_name) new_audio_path = orig_audio_path.with_name(new_name)
orig_audio_path.replace(new_audio_path) orig_audio_path.replace(new_audio_path)
new_filepath = str(Path(d["audio_filepath"]).with_name(new_name)) new_filepath = str(Path(d["audio_filepath"]).with_name(new_name))