Compare commits

...

2 Commits

Author SHA1 Message Date
Malar Kannan f5c49338d9 1. added deepgram support
2. compute asr sample accuracy
2020-08-07 12:02:01 +05:30
Malar Kannan fa89775f86 1. add a new streamlit ui to preview manifest
2. implement rpcy transcription client for files
2020-08-07 12:00:33 +05:30
4 changed files with 362 additions and 10 deletions

View File

@ -2,6 +2,10 @@ import os
import logging
import rpyc
from functools import lru_cache
import typer
from pathlib import Path
app = typer.Typer()
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
@ -19,3 +23,28 @@ def transcribe_gen(asr_host=ASR_HOST, asr_port=ASR_PORT):
asr = rpyc.connect(asr_host, asr_port).root
logger.info(f"connected to asr server successfully")
return asr.transcribe
@app.command()
def transcribe_file(audio_file: Path):
from pydub import AudioSegment
transcriber = transcribe_gen()
aud_seg = (
AudioSegment.from_file_using_temporary_files(audio_file)
.set_channels(1)
.set_sample_width(2)
.set_frame_rate(24000)
)
tscript_file_path = audio_file.with_suffix(".txt")
transcription = transcriber(aud_seg.raw_data)
with open(tscript_file_path, "w") as tf:
tf.write(transcription)
def main():
app()
if __name__ == "__main__":
main()

View File

@ -1,13 +1,16 @@
import io
import os
import re
import json
import base64
import wave
from pathlib import Path
from itertools import product
from functools import partial
from math import floor
from uuid import uuid4
from urllib.parse import urlsplit
from urllib.parse import urlsplit, urlencode
from urllib.request import Request, urlopen
from concurrent.futures import ThreadPoolExecutor
import numpy as np
@ -100,6 +103,9 @@ def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=F
"data": [],
}
data_funcs = []
deepgram_transcriber = deepgram_transcribe_gen()
# t2n = Text2Num()
transcriber_gcp = gcp_transcribe_gen()
transcriber_pretrained = transcribe_gen(asr_port=8044)
with asr_manifest.open("w") as mf:
@ -119,6 +125,10 @@ def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=F
pretrained_result = transcriber_pretrained(aud_seg.raw_data)
gcp_seg = aud_seg.set_frame_rate(16000)
gcp_result = transcriber_gcp(gcp_seg.raw_data)
aud_data = audio_path.read_bytes()
dgram_result = deepgram_transcriber(aud_data)
# gtruth = dp['text']
# dgram_result = t2n.convert(dgram_script)
pretrained_wer = word_error_rate([transcript], [pretrained_result])
wav_plot_path = (
dataset_dir / Path("wav_plots") / Path(fname).with_suffix(".png")
@ -135,6 +145,7 @@ def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=F
"caller": caller_name,
"utterance_id": fname,
"gcp_asr": gcp_result,
"deepgram_asr": dgram_result,
"pretrained_asr": pretrained_result,
"pretrained_wer": pretrained_wer,
"plot_path": str(wav_plot_path),
@ -225,6 +236,12 @@ class ExtendedPath(type(Path())):
with self.open("r") as jf:
return json.load(jf)
def read_jsonl(self):
print(f"reading jsonl from {self}")
with self.open("r") as jf:
for l in jf.readlines():
yield json.loads(l)
def write_json(self, data):
print(f"writing json to {self}")
self.parent.mkdir(parents=True, exist_ok=True)
@ -465,6 +482,203 @@ def gcp_transcribe_gen():
return sample_recognize
def deepgram_transcribe_gen():
DEEPGRAM_URL = "https://brain.deepgram.com/v2/listen"
MODEL = "agara"
encoding = "linear16"
sample_rate = "8000"
# diarize = "false"
q_params = {
"model": MODEL,
"encoding": encoding,
"sample_rate": sample_rate,
"language": "en-US",
"multichannel": "false",
"punctuate": "true",
}
url = "{}?{}".format(DEEPGRAM_URL, urlencode(q_params))
# print(url)
creds = ("arjun@agaralabs.com", "PoX1Y@x4h%oS")
def deepgram_offline(audio_data):
request = Request(
url,
method="POST",
headers={
"Authorization": "Basic {}".format(
base64.b64encode("{}:{}".format(*creds).encode("utf-8")).decode(
"utf-8"
)
)
},
data=audio_data,
)
with urlopen(request) as response:
msg = json.loads(response.read())
data = msg["results"]["channels"][0]["alternatives"][0]
return data["transcript"]
return deepgram_offline
class Text2Num(object):
"""docstring for Text2Num."""
def __init__(self):
numwords = {}
if not numwords:
units = [
"zero",
"one",
"two",
"three",
"four",
"five",
"six",
"seven",
"eight",
"nine",
"ten",
"eleven",
"twelve",
"thirteen",
"fourteen",
"fifteen",
"sixteen",
"seventeen",
"eighteen",
"nineteen",
]
tens = [
"",
"",
"twenty",
"thirty",
"forty",
"fifty",
"sixty",
"seventy",
"eighty",
"ninety",
]
scales = ["hundred", "thousand", "million", "billion", "trillion"]
numwords["and"] = (1, 0)
for idx, word in enumerate(units):
numwords[word] = (1, idx)
for idx, word in enumerate(tens):
numwords[word] = (1, idx * 10)
for idx, word in enumerate(scales):
numwords[word] = (10 ** (idx * 3 or 2), 0)
self.numwords = numwords
def is_num(self, word):
return word in self.numwords
def parseOrdinal(self, utterance, **kwargs):
lookup_dict = {
"first": 1,
"second": 2,
"third": 3,
"fourth": 4,
"fifth": 5,
"sixth": 6,
"seventh": 7,
"eighth": 8,
"ninth": 9,
"tenth": 10,
"one": 1,
"two": 2,
"three": 3,
"four": 4,
"five": 5,
"six": 6,
"seven": 7,
"eight": 8,
"nine": 9,
"ten": 10,
"1": 1,
"2": 2,
"3": 3,
"4": 4,
"5": 5,
"6": 6,
"7": 7,
"8": 8,
"9": 9,
"10": 10,
"last": -1,
}
pattern = re.compile(
r"(\s|^)(?P<num>(first)|(third)|(fourth)|(fifth)|(sixth)|(seventh)|(eighth)|(ninth)|(tenth)|(two)|(three)|(four)|(five)|(six)|(seven)|(eight)|(nine)|(ten)|(1)|(2)|(3)|(4)|(5)|(6)|(7)|(8)|(9)|(10)|(last))(\s|$)",
re.IGNORECASE,
)
ordinal = ""
if pattern.search(utterance):
ordinal = pattern.search(utterance).groupdict()["num"].strip()
elif re.search(r"(\s|^)(?P<num>(second))(\s|$)", utterance):
ordinal = "second"
elif re.search(r"(\s|^)(?P<num>(one))(\s|$)", utterance):
ordinal = "one"
ordinal = lookup_dict.get(ordinal, "")
return ordinal
def convert(self, sent):
# res = []
# for token in sent.split():
# if token in self.numwords:
# res.append(str(self.text2int(token)))
# else:
# res.append(token)
# return " ".join(res)
return " ".join(
[
str(self.parseOrdinal(x)) if self.parseOrdinal(x) != "" else x
for x in sent.split()
]
)
def text2int(self, textnum):
current = result = 0
for word in textnum.split():
if word not in self.numwords:
raise Exception("Illegal word: " + word)
scale, increment = self.numwords[word]
current = current * scale + increment
if scale > 100:
result += current
current = 0
return result + current
def is_sub_sequence(str1, str2):
m = len(str1)
n = len(str2)
def check_seq(string1, string2, m, n):
# Base Cases
if m == 0:
return True
if n == 0:
return False
# If last characters of two strings are matching
if string1[m - 1] == string2[n - 1]:
return check_seq(string1, string2, m - 1, n - 1)
# If last characters are not matching
return check_seq(string1, string2, m, n - 1)
return check_seq(str1, str2, m, n)
def parallel_apply(fn, iterable, workers=8):
with ThreadPoolExecutor(max_workers=workers) as exe:
print(f"parallelly applying {fn}")

View File

@ -156,6 +156,47 @@ def sample_ui(
ExtendedPath(sample_path).write_json(processed_data)
@app.command()
def sample_asr_accuracy(
data_name: str = typer.Option(
"png_06_2020_week1_numbers_window_customer", show_default=True
),
dump_dir: Path = Path("./data/asr_data"),
sample_file: Path = Path("sample_dump.json"),
asr_service: str = "deepgram",
):
# import pandas as pd
# from pydub import AudioSegment
from ..utils import is_sub_sequence, Text2Num
# from ..utils import deepgram_transcribe_gen
#
# deepgram_transcriber = deepgram_transcribe_gen()
t2n = Text2Num()
# processed_data_path = dump_dir / Path(data_name) / dump_file
sample_path = dump_dir / Path(data_name) / sample_file
processed_data = ExtendedPath(sample_path).read_json()
# asr_data = []
match_count, total_samples = 0, len(processed_data["data"])
for dp in tqdm(processed_data["data"]):
# aud_data = Path(dp["audio_path"]).read_bytes()
# dgram_result = deepgram_transcriber(aud_data)
# dp["deepgram_asr"] = dgram_result
gcp_num = dp["text"]
dgm_num = t2n.convert(dp["deepgram_asr"].lower())
if is_sub_sequence(gcp_num, dgm_num):
match_count += 1
print(f"MATCH GCP:{gcp_num}\tDGM:{dgm_num}")
else:
print(f"FAIL GCP:{gcp_num}\tDGM:{dgm_num}")
# asr_data.append(dp)
typer.echo(
f"{match_count} from deepgram matches with {total_samples} gcp transcripts."
)
# processed_data["data"] = asr_data
# ExtendedPath(sample_path).write_json(processed_data)
@app.command()
def task_ui(
data_name: str = typer.Option("call_upwork_train_cnd", show_default=True),
@ -190,7 +231,9 @@ def dump_corrections(
col = get_mongo_conn(col="asr_validation")
task_id = [c for c in col.distinct("task_id") if c.rsplit("-", 1)[1] == task_uid][0]
corrections = list(col.find({"type": "correction"}, projection={"_id": False}))
cursor_obj = col.find({"type": "correction", "task_id": task_id}, projection={"_id": False})
cursor_obj = col.find(
{"type": "correction", "task_id": task_id}, projection={"_id": False}
)
corrections = [c for c in cursor_obj]
ExtendedPath(dump_path).write_json(corrections)
@ -271,7 +314,9 @@ def split_extract(
dump_file: Path = Path("ui_dump.json"),
manifest_file: Path = Path("manifest.json"),
corrections_file: str = typer.Option("corrections.json", show_default=True),
conv_data_path: Path = typer.Option(Path("./data/conv_data.json"), show_default=True),
conv_data_path: Path = typer.Option(
Path("./data/conv_data.json"), show_default=True
),
extraction_type: ExtractionType = ExtractionType.all,
):
import shutil
@ -293,7 +338,9 @@ def split_extract(
def extract_manifest(mg):
for m in mg:
if m["text"] in extraction_vals:
shutil.copy(m["audio_path"], dest_data_dir / Path(m["audio_filepath"]))
shutil.copy(
m["audio_path"], dest_data_dir / Path(m["audio_filepath"])
)
yield m
asr_manifest_writer(dest_manifest_path, extract_manifest(manifest_gen))
@ -302,12 +349,14 @@ def split_extract(
orig_ui_data = ExtendedPath(ui_data_path).read_json()
ui_data = orig_ui_data["data"]
file_ui_map = {Path(u["audio_filepath"]).stem: u for u in ui_data}
extracted_ui_data = list(filter(lambda u: u["text"] in extraction_vals, ui_data))
extracted_ui_data = list(
filter(lambda u: u["text"] in extraction_vals, ui_data)
)
final_data = []
for i, d in enumerate(extracted_ui_data):
d['real_idx'] = i
d["real_idx"] = i
final_data.append(d)
orig_ui_data['data'] = final_data
orig_ui_data["data"] = final_data
ExtendedPath(dest_ui_path).write_json(orig_ui_data)
if corrections_file:
@ -323,7 +372,7 @@ def split_extract(
)
ExtendedPath(dest_correction_path).write_json(extracted_corrections)
if extraction_type.value == 'all':
if extraction_type.value == "all":
for ext_key in conv_data.keys():
extract_data_of_type(ext_key)
else:
@ -345,7 +394,7 @@ def update_corrections(
def correct_manifest(ui_dump_path, corrections_path):
corrections = ExtendedPath(corrections_path).read_json()
ui_data = ExtendedPath(ui_dump_path).read_json()['data']
ui_data = ExtendedPath(ui_dump_path).read_json()["data"]
correct_set = {
c["code"] for c in corrections if c["value"]["status"] == "Correct"
}
@ -374,7 +423,9 @@ def update_corrections(
)
else:
orig_audio_path = Path(d["audio_path"])
new_name = str(Path(tscript_uuid_fname(correct_text)).with_suffix(".wav"))
new_name = str(
Path(tscript_uuid_fname(correct_text)).with_suffix(".wav")
)
new_audio_path = orig_audio_path.with_name(new_name)
orig_audio_path.replace(new_audio_path)
new_filepath = str(Path(d["audio_filepath"]).with_name(new_name))

58
manifest_preview.py Normal file
View File

@ -0,0 +1,58 @@
from pathlib import Path
import streamlit as st
import typer
from jasper.data.utils import ExtendedPath
from jasper.data.validation.st_rerun import rerun
app = typer.Typer()
if not hasattr(st, "mongo_connected"):
# st.task_id = str(uuid4())
task_path = ExtendedPath("preview.lck")
def current_cursor_fn():
return task_path.read_json()["current_cursor"]
def update_cursor_fn(val=0):
task_path.write_json({"current_cursor": val})
rerun()
st.get_current_cursor = current_cursor_fn
st.update_cursor = update_cursor_fn
st.mongo_connected = True
# cursor_obj = mongo_conn.find_one({"type": "current_cursor", "task_id": st.task_id})
# if not cursor_obj:
update_cursor_fn(0)
@st.cache()
def load_ui_data(validation_ui_data_path: Path):
typer.echo(f"Using validation ui data from {validation_ui_data_path}")
return list(ExtendedPath(validation_ui_data_path).read_jsonl())
@app.command()
def main(manifest: Path):
asr_data = load_ui_data(manifest)
sample_no = st.get_current_cursor()
if len(asr_data) - 1 < sample_no or sample_no < 0:
print("Invalid samplno resetting to 0")
st.update_cursor(0)
sample = asr_data[sample_no]
st.title(f"ASR Manifest Preview")
st.markdown(f"{sample_no+1} of {len(asr_data)} : **{sample['text']}**")
new_sample = st.number_input(
"Go To Sample:", value=sample_no + 1, min_value=1, max_value=len(asr_data)
)
if new_sample != sample_no + 1:
st.update_cursor(new_sample - 1)
st.sidebar.markdown(f"Gold Text: **{sample['text']}**")
st.audio(Path(sample["audio_filepath"]).open("rb"))
if __name__ == "__main__":
try:
app()
except SystemExit:
pass