Compare commits

...

2 Commits

Author          SHA1          Message                                             Date
Malar Kannan    f5c49338d9    1. added deepgram support                           2020-08-07 12:02:01 +05:30
                              2. compute asr sample accuracy
Malar Kannan    fa89775f86    1. add a new streamlit ui to preview manifest       2020-08-07 12:00:33 +05:30
                              2. implement rpyc transcription client for files
4 changed files with 362 additions and 10 deletions

View File

@@ -2,6 +2,10 @@ import os
import logging
import rpyc
from functools import lru_cache
import typer
from pathlib import Path

app = typer.Typer()
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
@@ -19,3 +23,28 @@ def transcribe_gen(asr_host=ASR_HOST, asr_port=ASR_PORT):
    asr = rpyc.connect(asr_host, asr_port).root
    logger.info(f"connected to asr server successfully")
    return asr.transcribe

@app.command()
def transcribe_file(audio_file: Path):
    from pydub import AudioSegment

    transcriber = transcribe_gen()
    aud_seg = (
        AudioSegment.from_file_using_temporary_files(audio_file)
        .set_channels(1)
        .set_sample_width(2)
        .set_frame_rate(24000)
    )
    tscript_file_path = audio_file.with_suffix(".txt")
    transcription = transcriber(aud_seg.raw_data)
    with open(tscript_file_path, "w") as tf:
        tf.write(transcription)


def main():
    app()


if __name__ == "__main__":
    main()
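A minimal sketch of how the new command could be exercised (illustrative only: the audio filename is assumed, and on the command line Typer exposes the function as a `transcribe-file` subcommand):

    transcribe_file(Path("call.wav"))  # writes the transcript next to the input, as call.txt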

View File

@@ -1,13 +1,16 @@
import io
import os
import re
import json
import base64
import wave
from pathlib import Path
from itertools import product
from functools import partial
from math import floor
from uuid import uuid4
from urllib.parse import urlsplit, urlencode
from urllib.request import Request, urlopen
from concurrent.futures import ThreadPoolExecutor
import numpy as np
@@ -100,6 +103,9 @@ def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=F
        "data": [],
    }
    data_funcs = []
    deepgram_transcriber = deepgram_transcribe_gen()
    # t2n = Text2Num()
    transcriber_gcp = gcp_transcribe_gen()
    transcriber_pretrained = transcribe_gen(asr_port=8044)
    with asr_manifest.open("w") as mf:
@@ -119,6 +125,10 @@ def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=F
        pretrained_result = transcriber_pretrained(aud_seg.raw_data)
        gcp_seg = aud_seg.set_frame_rate(16000)
        gcp_result = transcriber_gcp(gcp_seg.raw_data)
        aud_data = audio_path.read_bytes()
        dgram_result = deepgram_transcriber(aud_data)
        # gtruth = dp['text']
        # dgram_result = t2n.convert(dgram_script)
        pretrained_wer = word_error_rate([transcript], [pretrained_result])
        wav_plot_path = (
            dataset_dir / Path("wav_plots") / Path(fname).with_suffix(".png")
@@ -135,6 +145,7 @@ def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=F
            "caller": caller_name,
            "utterance_id": fname,
            "gcp_asr": gcp_result,
            "deepgram_asr": dgram_result,
            "pretrained_asr": pretrained_result,
            "pretrained_wer": pretrained_wer,
            "plot_path": str(wav_plot_path),
@@ -225,6 +236,12 @@ class ExtendedPath(type(Path())):
        with self.open("r") as jf:
            return json.load(jf)

    def read_jsonl(self):
        print(f"reading jsonl from {self}")
        with self.open("r") as jf:
            for l in jf.readlines():
                yield json.loads(l)
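    # Illustrative usage (not part of the diff): iterate a JSON-lines file,
    # decoding one object per line; the manifest path is hypothetical:
    #   for rec in ExtendedPath("manifest.jsonl").read_jsonl():
    #       print(rec["text"])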

    def write_json(self, data):
        print(f"writing json to {self}")
        self.parent.mkdir(parents=True, exist_ok=True)
@@ -465,6 +482,203 @@ def gcp_transcribe_gen():
    return sample_recognize

def deepgram_transcribe_gen():
    DEEPGRAM_URL = "https://brain.deepgram.com/v2/listen"
    MODEL = "agara"
    encoding = "linear16"
    sample_rate = "8000"
    # diarize = "false"
    q_params = {
        "model": MODEL,
        "encoding": encoding,
        "sample_rate": sample_rate,
        "language": "en-US",
        "multichannel": "false",
        "punctuate": "true",
    }
    url = "{}?{}".format(DEEPGRAM_URL, urlencode(q_params))
    # print(url)
    creds = ("arjun@agaralabs.com", "PoX1Y@x4h%oS")

    def deepgram_offline(audio_data):
        request = Request(
            url,
            method="POST",
            headers={
                "Authorization": "Basic {}".format(
                    base64.b64encode("{}:{}".format(*creds).encode("utf-8")).decode(
                        "utf-8"
                    )
                )
            },
            data=audio_data,
        )
        with urlopen(request) as response:
            msg = json.loads(response.read())
        data = msg["results"]["channels"][0]["alternatives"][0]
        return data["transcript"]

    return deepgram_offline
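# Illustrative usage (not part of the diff): the returned closure POSTs raw
# audio bytes with basic auth and returns the top alternative's transcript.
# The wav path is hypothetical; the query params above assume 8 kHz 16-bit
# linear PCM audio.
#   transcribe = deepgram_transcribe_gen()
#   print(transcribe(Path("sample.wav").read_bytes()))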

class Text2Num(object):
    """docstring for Text2Num."""

    def __init__(self):
        numwords = {}
        if not numwords:
            units = [
                "zero",
                "one",
                "two",
                "three",
                "four",
                "five",
                "six",
                "seven",
                "eight",
                "nine",
                "ten",
                "eleven",
                "twelve",
                "thirteen",
                "fourteen",
                "fifteen",
                "sixteen",
                "seventeen",
                "eighteen",
                "nineteen",
            ]
            tens = [
                "",
                "",
                "twenty",
                "thirty",
                "forty",
                "fifty",
                "sixty",
                "seventy",
                "eighty",
                "ninety",
            ]
            scales = ["hundred", "thousand", "million", "billion", "trillion"]
            numwords["and"] = (1, 0)
            for idx, word in enumerate(units):
                numwords[word] = (1, idx)
            for idx, word in enumerate(tens):
                numwords[word] = (1, idx * 10)
            for idx, word in enumerate(scales):
                numwords[word] = (10 ** (idx * 3 or 2), 0)
        self.numwords = numwords

    def is_num(self, word):
        return word in self.numwords

    def parseOrdinal(self, utterance, **kwargs):
        lookup_dict = {
            "first": 1,
            "second": 2,
            "third": 3,
            "fourth": 4,
            "fifth": 5,
            "sixth": 6,
            "seventh": 7,
            "eighth": 8,
            "ninth": 9,
            "tenth": 10,
            "one": 1,
            "two": 2,
            "three": 3,
            "four": 4,
            "five": 5,
            "six": 6,
            "seven": 7,
            "eight": 8,
            "nine": 9,
            "ten": 10,
            "1": 1,
            "2": 2,
            "3": 3,
            "4": 4,
            "5": 5,
            "6": 6,
            "7": 7,
            "8": 8,
            "9": 9,
            "10": 10,
            "last": -1,
        }
        pattern = re.compile(
            r"(\s|^)(?P<num>(first)|(third)|(fourth)|(fifth)|(sixth)|(seventh)|(eighth)|(ninth)|(tenth)|(two)|(three)|(four)|(five)|(six)|(seven)|(eight)|(nine)|(ten)|(1)|(2)|(3)|(4)|(5)|(6)|(7)|(8)|(9)|(10)|(last))(\s|$)",
            re.IGNORECASE,
        )
        ordinal = ""
        if pattern.search(utterance):
            ordinal = pattern.search(utterance).groupdict()["num"].strip()
        elif re.search(r"(\s|^)(?P<num>(second))(\s|$)", utterance):
            ordinal = "second"
        elif re.search(r"(\s|^)(?P<num>(one))(\s|$)", utterance):
            ordinal = "one"
        ordinal = lookup_dict.get(ordinal, "")
        return ordinal

    def convert(self, sent):
        # res = []
        # for token in sent.split():
        #     if token in self.numwords:
        #         res.append(str(self.text2int(token)))
        #     else:
        #         res.append(token)
        # return " ".join(res)
        return " ".join(
            [
                str(self.parseOrdinal(x)) if self.parseOrdinal(x) != "" else x
                for x in sent.split()
            ]
        )

    def text2int(self, textnum):
        current = result = 0
        for word in textnum.split():
            if word not in self.numwords:
                raise Exception("Illegal word: " + word)
            scale, increment = self.numwords[word]
            current = current * scale + increment
            if scale > 100:
                result += current
                current = 0
        return result + current
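    # Worked examples (illustrative, not part of the diff): text2int folds
    # the (scale, increment) pairs left to right, flushing the running value
    # into the result at scales above one hundred:
    #   Text2Num().text2int("three hundred forty two")  # -> 342
    #   Text2Num().text2int("two thousand and five")    # -> 2005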

def is_sub_sequence(str1, str2):
    m = len(str1)
    n = len(str2)

    def check_seq(string1, string2, m, n):
        # Base Cases
        if m == 0:
            return True
        if n == 0:
            return False
        # If last characters of two strings are matching
        if string1[m - 1] == string2[n - 1]:
            return check_seq(string1, string2, m - 1, n - 1)
        # If last characters are not matching
        return check_seq(string1, string2, m, n - 1)

    return check_seq(str1, str2, m, n)
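# Illustrative check (not part of the diff): "123" is a character
# subsequence of "1 2 3", so is_sub_sequence("123", "1 2 3") returns True,
# while is_sub_sequence("132", "1 2 3") returns False.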

def parallel_apply(fn, iterable, workers=8):
    with ThreadPoolExecutor(max_workers=workers) as exe:
        print(f"parallelly applying {fn}")

View File

@@ -156,6 +156,47 @@ def sample_ui(
    ExtendedPath(sample_path).write_json(processed_data)

@app.command()
def sample_asr_accuracy(
    data_name: str = typer.Option(
        "png_06_2020_week1_numbers_window_customer", show_default=True
    ),
    dump_dir: Path = Path("./data/asr_data"),
    sample_file: Path = Path("sample_dump.json"),
    asr_service: str = "deepgram",
):
    # import pandas as pd
    # from pydub import AudioSegment
    from ..utils import is_sub_sequence, Text2Num

    # from ..utils import deepgram_transcribe_gen
    #
    # deepgram_transcriber = deepgram_transcribe_gen()
    t2n = Text2Num()
    # processed_data_path = dump_dir / Path(data_name) / dump_file
    sample_path = dump_dir / Path(data_name) / sample_file
    processed_data = ExtendedPath(sample_path).read_json()
    # asr_data = []
    match_count, total_samples = 0, len(processed_data["data"])
    for dp in tqdm(processed_data["data"]):
        # aud_data = Path(dp["audio_path"]).read_bytes()
        # dgram_result = deepgram_transcriber(aud_data)
        # dp["deepgram_asr"] = dgram_result
        gcp_num = dp["text"]
        dgm_num = t2n.convert(dp["deepgram_asr"].lower())
        if is_sub_sequence(gcp_num, dgm_num):
            match_count += 1
            print(f"MATCH GCP:{gcp_num}\tDGM:{dgm_num}")
        else:
            print(f"FAIL GCP:{gcp_num}\tDGM:{dgm_num}")
        # asr_data.append(dp)
    typer.echo(
        f"{match_count} of {total_samples} deepgram transcripts match the gcp transcripts."
    )
    # processed_data["data"] = asr_data
    # ExtendedPath(sample_path).write_json(processed_data)
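    # Worked example of the matching logic above (values are illustrative):
    #   dp["text"] = "123"; dp["deepgram_asr"] = "One Two Three"
    #   t2n.convert("one two three") -> "1 2 3"
    #   is_sub_sequence("123", "1 2 3") -> True, counted as a match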

@app.command()
def task_ui(
    data_name: str = typer.Option("call_upwork_train_cnd", show_default=True),
@@ -190,7 +231,9 @@ def dump_corrections(
    col = get_mongo_conn(col="asr_validation")
    task_id = [c for c in col.distinct("task_id") if c.rsplit("-", 1)[1] == task_uid][0]
    corrections = list(col.find({"type": "correction"}, projection={"_id": False}))
    cursor_obj = col.find(
        {"type": "correction", "task_id": task_id}, projection={"_id": False}
    )
    corrections = [c for c in cursor_obj]
    ExtendedPath(dump_path).write_json(corrections)
@@ -271,7 +314,9 @@ def split_extract(
    dump_file: Path = Path("ui_dump.json"),
    manifest_file: Path = Path("manifest.json"),
    corrections_file: str = typer.Option("corrections.json", show_default=True),
    conv_data_path: Path = typer.Option(
        Path("./data/conv_data.json"), show_default=True
    ),
    extraction_type: ExtractionType = ExtractionType.all,
):
    import shutil
@@ -293,7 +338,9 @@ def split_extract(
    def extract_manifest(mg):
        for m in mg:
            if m["text"] in extraction_vals:
                shutil.copy(
                    m["audio_path"], dest_data_dir / Path(m["audio_filepath"])
                )
                yield m

    asr_manifest_writer(dest_manifest_path, extract_manifest(manifest_gen))
@@ -302,12 +349,14 @@ def split_extract(
    orig_ui_data = ExtendedPath(ui_data_path).read_json()
    ui_data = orig_ui_data["data"]
    file_ui_map = {Path(u["audio_filepath"]).stem: u for u in ui_data}
    extracted_ui_data = list(
        filter(lambda u: u["text"] in extraction_vals, ui_data)
    )
    final_data = []
    for i, d in enumerate(extracted_ui_data):
        d["real_idx"] = i
        final_data.append(d)
    orig_ui_data["data"] = final_data
    ExtendedPath(dest_ui_path).write_json(orig_ui_data)

    if corrections_file:
@@ -323,7 +372,7 @@ def split_extract(
    )
    ExtendedPath(dest_correction_path).write_json(extracted_corrections)

    if extraction_type.value == "all":
        for ext_key in conv_data.keys():
            extract_data_of_type(ext_key)
    else:
@@ -345,7 +394,7 @@ def update_corrections(
def correct_manifest(ui_dump_path, corrections_path):
    corrections = ExtendedPath(corrections_path).read_json()
    ui_data = ExtendedPath(ui_dump_path).read_json()["data"]
    correct_set = {
        c["code"] for c in corrections if c["value"]["status"] == "Correct"
    }
@@ -374,7 +423,9 @@ def update_corrections(
            )
        else:
            orig_audio_path = Path(d["audio_path"])
            new_name = str(
                Path(tscript_uuid_fname(correct_text)).with_suffix(".wav")
            )
            new_audio_path = orig_audio_path.with_name(new_name)
            orig_audio_path.replace(new_audio_path)
            new_filepath = str(Path(d["audio_filepath"]).with_name(new_name))

manifest_preview.py (new file, 58 lines)
View File

@@ -0,0 +1,58 @@
from pathlib import Path

import streamlit as st
import typer

from jasper.data.utils import ExtendedPath
from jasper.data.validation.st_rerun import rerun

app = typer.Typer()

if not hasattr(st, "mongo_connected"):
    # st.task_id = str(uuid4())
    task_path = ExtendedPath("preview.lck")

    def current_cursor_fn():
        return task_path.read_json()["current_cursor"]

    def update_cursor_fn(val=0):
        task_path.write_json({"current_cursor": val})
        rerun()

    st.get_current_cursor = current_cursor_fn
    st.update_cursor = update_cursor_fn
    st.mongo_connected = True
    # cursor_obj = mongo_conn.find_one({"type": "current_cursor", "task_id": st.task_id})
    # if not cursor_obj:
    update_cursor_fn(0)


@st.cache()
def load_ui_data(validation_ui_data_path: Path):
    typer.echo(f"Using validation ui data from {validation_ui_data_path}")
    return list(ExtendedPath(validation_ui_data_path).read_jsonl())


@app.command()
def main(manifest: Path):
    asr_data = load_ui_data(manifest)
    sample_no = st.get_current_cursor()
    if len(asr_data) - 1 < sample_no or sample_no < 0:
        print("Invalid sample no, resetting to 0")
        st.update_cursor(0)
    sample = asr_data[sample_no]
    st.title(f"ASR Manifest Preview")
    st.markdown(f"{sample_no+1} of {len(asr_data)} : **{sample['text']}**")
    new_sample = st.number_input(
        "Go To Sample:", value=sample_no + 1, min_value=1, max_value=len(asr_data)
    )
    if new_sample != sample_no + 1:
        st.update_cursor(new_sample - 1)
    st.sidebar.markdown(f"Gold Text: **{sample['text']}**")
    st.audio(Path(sample["audio_filepath"]).open("rb"))


if __name__ == "__main__":
    try:
        app()
    except SystemExit:
        pass
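A minimal sketch of how this preview would be launched (the manifest path is illustrative; Streamlit forwards arguments after `--` to the script, where Typer picks them up):

    streamlit run manifest_preview.py -- data/asr_data/sample/manifest.jsonl

The `SystemExit` guard exists because Typer ends a command by raising `SystemExit`, which would otherwise surface as an error in Streamlit's script runner.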