Compare commits

..

No commits in common. "f5c49338d92f0c415b0de98609c008cd668ede49" and "ae5586be7224d6bb283d8abfcf7b67110ca94ab7" have entirely different histories.

4 changed files with 10 additions and 362 deletions

View File

@ -2,10 +2,6 @@ import os
import logging
import rpyc
from functools import lru_cache
import typer
from pathlib import Path
app = typer.Typer()
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
@ -23,28 +19,3 @@ def transcribe_gen(asr_host=ASR_HOST, asr_port=ASR_PORT):
asr = rpyc.connect(asr_host, asr_port).root
logger.info(f"connected to asr server successfully")
return asr.transcribe
@app.command()
def transcribe_file(audio_file: Path):
    """Transcribe one audio file and store the text beside it.

    The audio is decoded with pydub, converted to a single channel with a
    sample width of 2 bytes at a 24000 Hz frame rate, and its raw bytes are
    handed to the callable returned by ``transcribe_gen()``. The returned
    text is written to ``audio_file`` with its suffix replaced by ``.txt``.
    """
    # Imported at call time rather than at module import.
    from pydub import AudioSegment

    transcribe = transcribe_gen()

    segment = AudioSegment.from_file_using_temporary_files(audio_file)
    segment = segment.set_channels(1)
    segment = segment.set_sample_width(2)
    segment = segment.set_frame_rate(24000)

    output_path = audio_file.with_suffix(".txt")
    text = transcribe(segment.raw_data)
    with open(output_path, "w") as out:
        out.write(text)
def main():
    """Run the module-level typer application."""
    app()


if __name__ == "__main__":
    main()

View File

@ -1,16 +1,13 @@
import io
import os
import re
import json
import base64
import wave
from pathlib import Path
from itertools import product
from functools import partial
from math import floor
from uuid import uuid4
from urllib.parse import urlsplit, urlencode
from urllib.request import Request, urlopen
from urllib.parse import urlsplit
from concurrent.futures import ThreadPoolExecutor
import numpy as np
@ -103,9 +100,6 @@ def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=F
"data": [],
}
data_funcs = []
deepgram_transcriber = deepgram_transcribe_gen()
# t2n = Text2Num()
transcriber_gcp = gcp_transcribe_gen()
transcriber_pretrained = transcribe_gen(asr_port=8044)
with asr_manifest.open("w") as mf:
@ -125,10 +119,6 @@ def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=F
pretrained_result = transcriber_pretrained(aud_seg.raw_data)
gcp_seg = aud_seg.set_frame_rate(16000)
gcp_result = transcriber_gcp(gcp_seg.raw_data)
aud_data = audio_path.read_bytes()
dgram_result = deepgram_transcriber(aud_data)
# gtruth = dp['text']
# dgram_result = t2n.convert(dgram_script)
pretrained_wer = word_error_rate([transcript], [pretrained_result])
wav_plot_path = (
dataset_dir / Path("wav_plots") / Path(fname).with_suffix(".png")
@ -145,7 +135,6 @@ def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=F
"caller": caller_name,
"utterance_id": fname,
"gcp_asr": gcp_result,
"deepgram_asr": dgram_result,
"pretrained_asr": pretrained_result,
"pretrained_wer": pretrained_wer,
"plot_path": str(wav_plot_path),
@ -236,12 +225,6 @@ class ExtendedPath(type(Path())):
with self.open("r") as jf:
return json.load(jf)
def read_jsonl(self):
    """Lazily yield one decoded JSON value per non-blank line of this file.

    This is a generator: the file is opened (and the progress message is
    printed) only when iteration starts, and lines are decoded one at a time
    instead of loading the whole file into memory first.

    Yields:
        The object produced by ``json.loads`` for each non-blank line.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    print(f"reading jsonl from {self}")
    with self.open("r") as jf:
        for line in jf:
            # Blank lines (for example a trailing newline at the end of the
            # file) carry no record; skip them instead of failing to decode.
            if line.strip():
                yield json.loads(line)
def write_json(self, data):
print(f"writing json to {self}")
self.parent.mkdir(parents=True, exist_ok=True)
@ -482,203 +465,6 @@ def gcp_transcribe_gen():
return sample_recognize
def deepgram_transcribe_gen(username=None, password=None):
    """Build a function that transcribes raw audio through the Deepgram API.

    Credentials are no longer embedded in the source. They are taken from the
    ``username`` / ``password`` arguments or, when those are not given, from
    the ``DEEPGRAM_USERNAME`` / ``DEEPGRAM_PASSWORD`` environment variables.
    NOTE(review): the credentials that used to be hard-coded here have been
    committed to version control and should be rotated.

    Args:
        username: Deepgram account name; defaults to ``$DEEPGRAM_USERNAME``.
        password: Deepgram password; defaults to ``$DEEPGRAM_PASSWORD``.

    Returns:
        A callable ``deepgram_offline(audio_data)`` that POSTs ``audio_data``
        to the listen endpoint and returns the transcript of the first
        alternative of the first channel in the JSON response.

    Raises:
        RuntimeError: if no credentials are available.
    """
    DEEPGRAM_URL = "https://brain.deepgram.com/v2/listen"
    MODEL = "agara"
    encoding = "linear16"
    sample_rate = "8000"
    q_params = {
        "model": MODEL,
        "encoding": encoding,
        "sample_rate": sample_rate,
        "language": "en-US",
        "multichannel": "false",
        "punctuate": "true",
    }
    url = "{}?{}".format(DEEPGRAM_URL, urlencode(q_params))

    username = username or os.environ.get("DEEPGRAM_USERNAME")
    password = password or os.environ.get("DEEPGRAM_PASSWORD")
    if not username or not password:
        raise RuntimeError(
            "Deepgram credentials missing: pass username/password or set "
            "DEEPGRAM_USERNAME and DEEPGRAM_PASSWORD"
        )

    # The Basic-auth header never changes between requests, so build it once.
    token = base64.b64encode(
        "{}:{}".format(username, password).encode("utf-8")
    ).decode("utf-8")
    headers = {"Authorization": "Basic {}".format(token)}

    def deepgram_offline(audio_data):
        request = Request(url, method="POST", headers=headers, data=audio_data)
        with urlopen(request) as response:
            msg = json.loads(response.read())
        data = msg["results"]["channels"][0]["alternatives"][0]
        return data["transcript"]

    return deepgram_offline
class Text2Num(object):
    """Convert spelled-out numbers and small ordinals in text to integers."""

    # Word / digit -> value table used by parseOrdinal.
    _ORDINAL_VALUES = {
        "first": 1,
        "second": 2,
        "third": 3,
        "fourth": 4,
        "fifth": 5,
        "sixth": 6,
        "seventh": 7,
        "eighth": 8,
        "ninth": 9,
        "tenth": 10,
        "one": 1,
        "two": 2,
        "three": 3,
        "four": 4,
        "five": 5,
        "six": 6,
        "seven": 7,
        "eight": 8,
        "nine": 9,
        "ten": 10,
        "1": 1,
        "2": 2,
        "3": 3,
        "4": 4,
        "5": 5,
        "6": 6,
        "7": 7,
        "8": 8,
        "9": 9,
        "10": 10,
        "last": -1,
    }
    # Compiled once at class creation instead of on every parseOrdinal call.
    # "second" and "one" are deliberately absent here: they are only tried
    # (case-sensitively) when nothing in this pattern matches.
    _ORDINAL_PATTERN = re.compile(
        r"(\s|^)(?P<num>(first)|(third)|(fourth)|(fifth)|(sixth)|(seventh)"
        r"|(eighth)|(ninth)|(tenth)|(two)|(three)|(four)|(five)|(six)|(seven)"
        r"|(eight)|(nine)|(ten)|(1)|(2)|(3)|(4)|(5)|(6)|(7)|(8)|(9)|(10)|(last))"
        r"(\s|$)",
        re.IGNORECASE,
    )
    _SECOND_PATTERN = re.compile(r"(\s|^)(?P<num>(second))(\s|$)")
    _ONE_PATTERN = re.compile(r"(\s|^)(?P<num>(one))(\s|$)")

    def __init__(self):
        """Build ``self.numwords``: word -> (scale, increment) for text2int."""
        units = [
            "zero",
            "one",
            "two",
            "three",
            "four",
            "five",
            "six",
            "seven",
            "eight",
            "nine",
            "ten",
            "eleven",
            "twelve",
            "thirteen",
            "fourteen",
            "fifteen",
            "sixteen",
            "seventeen",
            "eighteen",
            "nineteen",
        ]
        # The two empty strings are placeholders so that the list index times
        # ten equals the value of the word at that index.
        tens = [
            "",
            "",
            "twenty",
            "thirty",
            "forty",
            "fifty",
            "sixty",
            "seventy",
            "eighty",
            "ninety",
        ]
        scales = ["hundred", "thousand", "million", "billion", "trillion"]

        numwords = {"and": (1, 0)}
        for idx, word in enumerate(units):
            numwords[word] = (1, idx)
        for idx, word in enumerate(tens):
            numwords[word] = (1, idx * 10)
        for idx, word in enumerate(scales):
            # idx 0 ("hundred") -> 10**2, later entries -> 10**(3 * idx).
            numwords[word] = (10 ** (idx * 3 or 2), 0)
        self.numwords = numwords

    def is_num(self, word):
        """Return True if ``word`` is a key of the number-word table."""
        return word in self.numwords

    def parseOrdinal(self, utterance, **kwargs):
        """Return the integer for the first ordinal/number word in ``utterance``.

        Matching against the main pattern is case-insensitive, and the matched
        text is lower-cased before the lookup so that e.g. "First" resolves
        to 1 instead of falling through to the empty result.

        Args:
            utterance: text to search.
            **kwargs: accepted for interface compatibility; not used.

        Returns:
            The mapped integer (``-1`` for "last"), or ``""`` if nothing
            matches.
        """
        match = self._ORDINAL_PATTERN.search(utterance)
        if match:
            ordinal = match.group("num").strip().lower()
        elif self._SECOND_PATTERN.search(utterance):
            ordinal = "second"
        elif self._ONE_PATTERN.search(utterance):
            ordinal = "one"
        else:
            ordinal = ""
        return self._ORDINAL_VALUES.get(ordinal, "")

    def convert(self, sent):
        """Replace each whitespace-separated token that parses as an ordinal.

        Tokens for which parseOrdinal returns ``""`` are kept unchanged; the
        result is the tokens joined by single spaces.
        """
        converted = []
        for token in sent.split():
            value = self.parseOrdinal(token)  # parsed once per token
            converted.append(token if value == "" else str(value))
        return " ".join(converted)

    def text2int(self, textnum):
        """Convert a spelled-out number such as "two hundred and five" to int.

        Raises:
            ValueError: if a word is not in the number-word table
                (``ValueError`` is still caught by ``except Exception``).
        """
        current = result = 0
        for word in textnum.split():
            if word not in self.numwords:
                raise ValueError("Illegal word: " + word)
            scale, increment = self.numwords[word]
            current = current * scale + increment
            # Scales above one hundred close a group: bank it and start over.
            if scale > 100:
                result += current
                current = 0
        return result + current
def is_sub_sequence(str1, str2):
    """Return True if ``str1`` is a subsequence of ``str2``.

    A subsequence keeps the order of its items but they do not have to be
    adjacent in ``str2``. An empty ``str1`` is a subsequence of anything.

    The check is iterative, so long inputs no longer hit Python's recursion
    limit the way the previous one-frame-per-character recursion did.
    """
    remaining = iter(str2)
    # `item in remaining` consumes the iterator up to and including the first
    # match, so each later item is searched only in what is left of str2.
    return all(item in remaining for item in str1)
def parallel_apply(fn, iterable, workers=8):
with ThreadPoolExecutor(max_workers=workers) as exe:
print(f"parallelly applying {fn}")

View File

@ -156,47 +156,6 @@ def sample_ui(
ExtendedPath(sample_path).write_json(processed_data)
@app.command()
def sample_asr_accuracy(
    data_name: str = typer.Option(
        "png_06_2020_week1_numbers_window_customer", show_default=True
    ),
    dump_dir: Path = Path("./data/asr_data"),
    sample_file: Path = Path("sample_dump.json"),
    asr_service: str = "deepgram",
):
    """Count samples whose reference text is a subsequence of the Deepgram text.

    Reads ``dump_dir / data_name / sample_file`` as JSON and, for every entry
    under its ``"data"`` key, lower-cases the ``"deepgram_asr"`` field, runs it
    through ``Text2Num.convert`` and checks whether the entry's ``"text"`` is a
    subsequence of the result. A MATCH/FAIL line is printed per entry and a
    summary is echoed at the end. ``asr_service`` is not read by the body.
    """
    from ..utils import is_sub_sequence, Text2Num

    converter = Text2Num()
    sample_path = dump_dir / Path(data_name) / sample_file
    samples = ExtendedPath(sample_path).read_json()["data"]

    total_samples = len(samples)
    match_count = 0
    for sample in tqdm(samples):
        gcp_num = sample["text"]
        dgm_num = converter.convert(sample["deepgram_asr"].lower())
        if not is_sub_sequence(gcp_num, dgm_num):
            print(f"FAIL GCP:{gcp_num}\tDGM:{dgm_num}")
            continue
        match_count += 1
        print(f"MATCH GCP:{gcp_num}\tDGM:{dgm_num}")

    typer.echo(
        f"{match_count} from deepgram matches with {total_samples} gcp transcripts."
    )
@app.command()
def task_ui(
data_name: str = typer.Option("call_upwork_train_cnd", show_default=True),
@ -231,9 +190,7 @@ def dump_corrections(
col = get_mongo_conn(col="asr_validation")
task_id = [c for c in col.distinct("task_id") if c.rsplit("-", 1)[1] == task_uid][0]
corrections = list(col.find({"type": "correction"}, projection={"_id": False}))
cursor_obj = col.find(
{"type": "correction", "task_id": task_id}, projection={"_id": False}
)
cursor_obj = col.find({"type": "correction", "task_id": task_id}, projection={"_id": False})
corrections = [c for c in cursor_obj]
ExtendedPath(dump_path).write_json(corrections)
@ -314,9 +271,7 @@ def split_extract(
dump_file: Path = Path("ui_dump.json"),
manifest_file: Path = Path("manifest.json"),
corrections_file: str = typer.Option("corrections.json", show_default=True),
conv_data_path: Path = typer.Option(
Path("./data/conv_data.json"), show_default=True
),
conv_data_path: Path = typer.Option(Path("./data/conv_data.json"), show_default=True),
extraction_type: ExtractionType = ExtractionType.all,
):
import shutil
@ -338,9 +293,7 @@ def split_extract(
def extract_manifest(mg):
for m in mg:
if m["text"] in extraction_vals:
shutil.copy(
m["audio_path"], dest_data_dir / Path(m["audio_filepath"])
)
shutil.copy(m["audio_path"], dest_data_dir / Path(m["audio_filepath"]))
yield m
asr_manifest_writer(dest_manifest_path, extract_manifest(manifest_gen))
@ -349,14 +302,12 @@ def split_extract(
orig_ui_data = ExtendedPath(ui_data_path).read_json()
ui_data = orig_ui_data["data"]
file_ui_map = {Path(u["audio_filepath"]).stem: u for u in ui_data}
extracted_ui_data = list(
filter(lambda u: u["text"] in extraction_vals, ui_data)
)
extracted_ui_data = list(filter(lambda u: u["text"] in extraction_vals, ui_data))
final_data = []
for i, d in enumerate(extracted_ui_data):
d["real_idx"] = i
d['real_idx'] = i
final_data.append(d)
orig_ui_data["data"] = final_data
orig_ui_data['data'] = final_data
ExtendedPath(dest_ui_path).write_json(orig_ui_data)
if corrections_file:
@ -372,7 +323,7 @@ def split_extract(
)
ExtendedPath(dest_correction_path).write_json(extracted_corrections)
if extraction_type.value == "all":
if extraction_type.value == 'all':
for ext_key in conv_data.keys():
extract_data_of_type(ext_key)
else:
@ -394,7 +345,7 @@ def update_corrections(
def correct_manifest(ui_dump_path, corrections_path):
corrections = ExtendedPath(corrections_path).read_json()
ui_data = ExtendedPath(ui_dump_path).read_json()["data"]
ui_data = ExtendedPath(ui_dump_path).read_json()['data']
correct_set = {
c["code"] for c in corrections if c["value"]["status"] == "Correct"
}
@ -423,9 +374,7 @@ def update_corrections(
)
else:
orig_audio_path = Path(d["audio_path"])
new_name = str(
Path(tscript_uuid_fname(correct_text)).with_suffix(".wav")
)
new_name = str(Path(tscript_uuid_fname(correct_text)).with_suffix(".wav"))
new_audio_path = orig_audio_path.with_name(new_name)
orig_audio_path.replace(new_audio_path)
new_filepath = str(Path(d["audio_filepath"]).with_name(new_name))

View File

@ -1,58 +0,0 @@
from pathlib import Path
import streamlit as st
import typer
from jasper.data.utils import ExtendedPath
from jasper.data.validation.st_rerun import rerun
app = typer.Typer()
# One-time setup, guarded by a marker attribute stored on the streamlit module.
# NOTE(review): indentation was lost in this view, so the exact extent of the
# `if` body and of the two helper functions below cannot be confirmed here.
if not hasattr(st, "mongo_connected"):
# st.task_id = str(uuid4())
# The cursor position is kept in a JSON file named "preview.lck".
task_path = ExtendedPath("preview.lck")
# Read the stored cursor position back from that file.
def current_cursor_fn():
return task_path.read_json()["current_cursor"]
# Store a new cursor position (default 0) in that file.
def update_cursor_fn(val=0):
task_path.write_json({"current_cursor": val})
rerun()
# Attach the helpers and the marker flag to the `st` module.
st.get_current_cursor = current_cursor_fn
st.update_cursor = update_cursor_fn
st.mongo_connected = True
# cursor_obj = mongo_conn.find_one({"type": "current_cursor", "task_id": st.task_id})
# if not cursor_obj:
update_cursor_fn(0)
@st.cache()
def load_ui_data(validation_ui_data_path: Path):
    """Read every record of the given JSONL file into a list.

    The path is echoed first, then the generator returned by
    ``ExtendedPath.read_jsonl`` is fully consumed.
    """
    typer.echo(f"Using validation ui data from {validation_ui_data_path}")
    records = ExtendedPath(validation_ui_data_path).read_jsonl()
    return [record for record in records]
@app.command()
def main(manifest: Path):
    """Render a one-sample-at-a-time preview of an ASR manifest.

    Loads the manifest records, shows the record at the stored cursor
    position (its text and its audio), and offers a number input for jumping
    to another record.

    Args:
        manifest: path of the JSONL manifest; each record is read for its
            ``"text"`` and ``"audio_filepath"`` fields.
    """
    asr_data = load_ui_data(manifest)
    sample_no = st.get_current_cursor()
    if len(asr_data) - 1 < sample_no or sample_no < 0:
        print("Invalid samplno resetting to 0")
        st.update_cursor(0)
        # Only reached if update_cursor returns instead of restarting the
        # script; never index the data with the stale, invalid cursor.
        sample_no = 0
    sample = asr_data[sample_no]
    st.title("ASR Manifest Preview")
    st.markdown(f"{sample_no+1} of {len(asr_data)} : **{sample['text']}**")
    new_sample = st.number_input(
        "Go To Sample:", value=sample_no + 1, min_value=1, max_value=len(asr_data)
    )
    if new_sample != sample_no + 1:
        st.update_cursor(new_sample - 1)
    st.sidebar.markdown(f"Gold Text: **{sample['text']}**")
    # read_bytes() opens, reads and closes the file; the previous open("rb")
    # handle was handed to st.audio and never closed.
    st.audio(Path(sample["audio_filepath"]).read_bytes())
# Script entry point: run the typer app.
if __name__ == "__main__":
try:
app()
# NOTE(review): SystemExit is deliberately ignored here, presumably so the
# exit raised when the typer app finishes does not end the streamlit script
# run — confirm.
except SystemExit:
pass