Compare commits
2 Commits
ae5586be72
...
f5c49338d9
| Author | SHA1 | Date |
|---|---|---|
|
|
f5c49338d9 | |
|
|
fa89775f86 |
|
|
@ -2,6 +2,10 @@ import os
|
||||||
import logging
|
import logging
|
||||||
import rpyc
|
import rpyc
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
|
import typer
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
app = typer.Typer()
|
||||||
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||||
|
|
@ -19,3 +23,28 @@ def transcribe_gen(asr_host=ASR_HOST, asr_port=ASR_PORT):
|
||||||
asr = rpyc.connect(asr_host, asr_port).root
|
asr = rpyc.connect(asr_host, asr_port).root
|
||||||
logger.info(f"connected to asr server successfully")
|
logger.info(f"connected to asr server successfully")
|
||||||
return asr.transcribe
|
return asr.transcribe
|
||||||
|
|
||||||
|
|
||||||
|
# CLI command: transcribe one audio file and save the text next to it.
# (Documented with comments rather than a docstring so the Typer-generated
# ``--help`` output of this command stays exactly as it was.)
@app.command()
def transcribe_file(audio_file: Path):
    # pydub is imported lazily, only when this command actually runs.
    from pydub import AudioSegment

    # Obtain the remote transcription callable before touching the audio.
    transcriber = transcribe_gen()

    # Load the file, then normalise it step by step: one channel,
    # 2-byte samples, 24000 frames per second.
    aud_seg = AudioSegment.from_file_using_temporary_files(audio_file)
    aud_seg = aud_seg.set_channels(1)
    aud_seg = aud_seg.set_sample_width(2)
    aud_seg = aud_seg.set_frame_rate(24000)

    # The transcript is written beside the audio file, with a .txt suffix.
    tscript_file_path = audio_file.with_suffix(".txt")
    transcription = transcriber(aud_seg.raw_data)
    with open(tscript_file_path, "w") as tf:
        tf.write(transcription)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
app()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,16 @@
|
||||||
import io
|
import io
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
import json
|
import json
|
||||||
|
import base64
|
||||||
import wave
|
import wave
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from itertools import product
|
from itertools import product
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from math import floor
|
from math import floor
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
from urllib.parse import urlsplit
|
from urllib.parse import urlsplit, urlencode
|
||||||
|
from urllib.request import Request, urlopen
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
@ -100,6 +103,9 @@ def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=F
|
||||||
"data": [],
|
"data": [],
|
||||||
}
|
}
|
||||||
data_funcs = []
|
data_funcs = []
|
||||||
|
|
||||||
|
deepgram_transcriber = deepgram_transcribe_gen()
|
||||||
|
# t2n = Text2Num()
|
||||||
transcriber_gcp = gcp_transcribe_gen()
|
transcriber_gcp = gcp_transcribe_gen()
|
||||||
transcriber_pretrained = transcribe_gen(asr_port=8044)
|
transcriber_pretrained = transcribe_gen(asr_port=8044)
|
||||||
with asr_manifest.open("w") as mf:
|
with asr_manifest.open("w") as mf:
|
||||||
|
|
@ -119,6 +125,10 @@ def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=F
|
||||||
pretrained_result = transcriber_pretrained(aud_seg.raw_data)
|
pretrained_result = transcriber_pretrained(aud_seg.raw_data)
|
||||||
gcp_seg = aud_seg.set_frame_rate(16000)
|
gcp_seg = aud_seg.set_frame_rate(16000)
|
||||||
gcp_result = transcriber_gcp(gcp_seg.raw_data)
|
gcp_result = transcriber_gcp(gcp_seg.raw_data)
|
||||||
|
aud_data = audio_path.read_bytes()
|
||||||
|
dgram_result = deepgram_transcriber(aud_data)
|
||||||
|
# gtruth = dp['text']
|
||||||
|
# dgram_result = t2n.convert(dgram_script)
|
||||||
pretrained_wer = word_error_rate([transcript], [pretrained_result])
|
pretrained_wer = word_error_rate([transcript], [pretrained_result])
|
||||||
wav_plot_path = (
|
wav_plot_path = (
|
||||||
dataset_dir / Path("wav_plots") / Path(fname).with_suffix(".png")
|
dataset_dir / Path("wav_plots") / Path(fname).with_suffix(".png")
|
||||||
|
|
@ -135,6 +145,7 @@ def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=F
|
||||||
"caller": caller_name,
|
"caller": caller_name,
|
||||||
"utterance_id": fname,
|
"utterance_id": fname,
|
||||||
"gcp_asr": gcp_result,
|
"gcp_asr": gcp_result,
|
||||||
|
"deepgram_asr": dgram_result,
|
||||||
"pretrained_asr": pretrained_result,
|
"pretrained_asr": pretrained_result,
|
||||||
"pretrained_wer": pretrained_wer,
|
"pretrained_wer": pretrained_wer,
|
||||||
"plot_path": str(wav_plot_path),
|
"plot_path": str(wav_plot_path),
|
||||||
|
|
@ -225,6 +236,12 @@ class ExtendedPath(type(Path())):
|
||||||
with self.open("r") as jf:
|
with self.open("r") as jf:
|
||||||
return json.load(jf)
|
return json.load(jf)
|
||||||
|
|
||||||
|
def read_jsonl(self):
    """Yield one decoded JSON value per non-blank line of this file.

    The file is streamed line by line instead of being loaded with
    ``readlines()``, so large manifests are not held in memory at once.
    Lines that contain only whitespace are skipped; previously such a
    line made ``json.loads`` raise.

    This is a generator: nothing is printed or opened until the first
    item is requested, and the file stays open until the generator is
    exhausted or closed.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    print(f"reading jsonl from {self}")
    with self.open("r") as jf:
        for line in jf:
            if not line.strip():
                # Tolerate blank separator / trailing lines.
                continue
            yield json.loads(line)
|
||||||
|
|
||||||
def write_json(self, data):
|
def write_json(self, data):
|
||||||
print(f"writing json to {self}")
|
print(f"writing json to {self}")
|
||||||
self.parent.mkdir(parents=True, exist_ok=True)
|
self.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
@ -465,6 +482,203 @@ def gcp_transcribe_gen():
|
||||||
return sample_recognize
|
return sample_recognize
|
||||||
|
|
||||||
|
|
||||||
|
def deepgram_transcribe_gen(username=None, password=None):
    """Build a callable that transcribes raw audio bytes via Deepgram.

    Credentials are no longer embedded in the source. They are taken from
    the ``username`` / ``password`` arguments when given, otherwise from
    the ``DEEPGRAM_USERNAME`` / ``DEEPGRAM_PASSWORD`` environment
    variables.

    NOTE(review): an account password used to be hard-coded here and is
    therefore present in version-control history; it should be rotated.

    Args:
        username: Deepgram account name; falls back to the environment.
        password: Deepgram account password; falls back to the environment.

    Returns:
        A function ``deepgram_offline(audio_data)`` that POSTs
        ``audio_data`` to the listen endpoint and returns the transcript
        of the first alternative of the first channel.

    Raises:
        RuntimeError: if no credentials can be resolved. Raised here, at
            construction time, so a misconfigured run fails before any
            audio is processed.
    """
    DEEPGRAM_URL = "https://brain.deepgram.com/v2/listen"
    q_params = {
        "model": "agara",
        "encoding": "linear16",
        "sample_rate": "8000",
        "language": "en-US",
        "multichannel": "false",
        "punctuate": "true",
    }
    url = "{}?{}".format(DEEPGRAM_URL, urlencode(q_params))

    username = username or os.environ.get("DEEPGRAM_USERNAME")
    password = password or os.environ.get("DEEPGRAM_PASSWORD")
    if not username or not password:
        raise RuntimeError(
            "Deepgram credentials are missing: pass username/password or set "
            "the DEEPGRAM_USERNAME and DEEPGRAM_PASSWORD environment variables"
        )

    # HTTP Basic auth header, computed once instead of on every request.
    token = base64.b64encode(
        "{}:{}".format(username, password).encode("utf-8")
    ).decode("utf-8")
    auth_header = "Basic {}".format(token)

    def deepgram_offline(audio_data):
        """POST ``audio_data`` and return the transcript string."""
        request = Request(
            url,
            method="POST",
            headers={"Authorization": auth_header},
            data=audio_data,
        )
        with urlopen(request) as response:
            msg = json.loads(response.read())
        data = msg["results"]["channels"][0]["alternatives"][0]
        return data["transcript"]

    return deepgram_offline
|
||||||
|
|
||||||
|
|
||||||
|
class Text2Num(object):
    """Turn spelled-out English numbers and small ordinals into integers.

    ``text2int`` converts a whole cardinal phrase ("two hundred and five").
    ``parseOrdinal`` finds one small number / ordinal word in an utterance.
    ``convert`` rewrites a sentence token by token using ``parseOrdinal``.

    The lookup tables and regular expressions are class-level constants so
    they are built once, not on every call.
    """

    _UNITS = (
        "zero",
        "one",
        "two",
        "three",
        "four",
        "five",
        "six",
        "seven",
        "eight",
        "nine",
        "ten",
        "eleven",
        "twelve",
        "thirteen",
        "fourteen",
        "fifteen",
        "sixteen",
        "seventeen",
        "eighteen",
        "nineteen",
    )

    # The two leading "" entries only keep index * 10 aligned with the word;
    # they are not words themselves.
    _TENS = (
        "",
        "",
        "twenty",
        "thirty",
        "forty",
        "fifty",
        "sixty",
        "seventy",
        "eighty",
        "ninety",
    )

    _SCALES = ("hundred", "thousand", "million", "billion", "trillion")

    # Word / digit -> value table used by parseOrdinal. Keys are lower-case.
    _ORDINALS = {
        "first": 1,
        "second": 2,
        "third": 3,
        "fourth": 4,
        "fifth": 5,
        "sixth": 6,
        "seventh": 7,
        "eighth": 8,
        "ninth": 9,
        "tenth": 10,
        "one": 1,
        "two": 2,
        "three": 3,
        "four": 4,
        "five": 5,
        "six": 6,
        "seven": 7,
        "eight": 8,
        "nine": 9,
        "ten": 10,
        "1": 1,
        "2": 2,
        "3": 3,
        "4": 4,
        "5": 5,
        "6": 6,
        "7": 7,
        "8": 8,
        "9": 9,
        "10": 10,
        "last": -1,
    }

    # Primary pattern: every supported word except "second" and "one",
    # which are only tried as fallbacks (see parseOrdinal).
    _ORDINAL_PATTERN = re.compile(
        r"(\s|^)(?P<num>"
        r"(first)|(third)|(fourth)|(fifth)|(sixth)|(seventh)|(eighth)|(ninth)|(tenth)"
        r"|(two)|(three)|(four)|(five)|(six)|(seven)|(eight)|(nine)|(ten)"
        r"|(1)|(2)|(3)|(4)|(5)|(6)|(7)|(8)|(9)|(10)|(last)"
        r")(\s|$)",
        re.IGNORECASE,
    )
    _SECOND_PATTERN = re.compile(r"(\s|^)(?P<num>(second))(\s|$)", re.IGNORECASE)
    _ONE_PATTERN = re.compile(r"(\s|^)(?P<num>(one))(\s|$)", re.IGNORECASE)

    def __init__(self):
        """Build ``self.numwords``: word -> (scale, increment) for text2int."""
        numwords = {"and": (1, 0)}
        for idx, word in enumerate(self._UNITS):
            numwords[word] = (1, idx)
        for idx, word in enumerate(self._TENS):
            if word:
                # Skip the "" placeholders so the empty string is never
                # registered as a number word.
                numwords[word] = (1, idx * 10)
        for idx, word in enumerate(self._SCALES):
            # "hundred" (idx 0) scales by 10**2; the rest by 10**(3 * idx).
            numwords[word] = (10 ** (idx * 3 or 2), 0)
        self.numwords = numwords

    def is_num(self, word):
        """Return True if ``word`` is a known cardinal word (or "and")."""
        return word in self.numwords

    def parseOrdinal(self, utterance, **kwargs):
        """Return the value of the first number/ordinal word in ``utterance``.

        The word must be delimited by whitespace or the string boundaries.
        "second" and "one" are considered only when no other supported word
        is present. Matching is case-insensitive, and the matched text is
        lower-cased before the table lookup, so "First" yields 1.

        Args:
            utterance: text to search.
            **kwargs: accepted for call compatibility and ignored.

        Returns:
            The integer value (``-1`` for "last"), or ``""`` if nothing
            matched.
        """
        match = self._ORDINAL_PATTERN.search(utterance)
        if match:
            ordinal = match.group("num").strip().lower()
        elif self._SECOND_PATTERN.search(utterance):
            ordinal = "second"
        elif self._ONE_PATTERN.search(utterance):
            ordinal = "one"
        else:
            ordinal = ""
        return self._ORDINALS.get(ordinal, "")

    def convert(self, sent):
        """Replace each whitespace-separated token that parseOrdinal
        recognises with its numeric string; other tokens are kept as-is.

        Tokens are re-joined with single spaces.
        """
        converted = []
        for token in sent.split():
            value = self.parseOrdinal(token)  # parsed once per token
            converted.append(token if value == "" else str(value))
        return " ".join(converted)

    def text2int(self, textnum):
        """Convert a cardinal phrase such as "two hundred and five" to an int.

        Raises:
            ValueError: if a word is not in ``self.numwords``. ValueError is
                a subclass of Exception, so existing ``except Exception``
                handlers keep working.
        """
        current = result = 0
        for word in textnum.split():
            if word not in self.numwords:
                raise ValueError("Illegal word: " + word)

            scale, increment = self.numwords[word]
            current = current * scale + increment
            if scale > 100:
                # thousand and above close the current group.
                result += current
                current = 0

        return result + current
|
||||||
|
|
||||||
|
|
||||||
|
def is_sub_sequence(str1, str2):
    """Return True if ``str1`` is a subsequence of ``str2``.

    ``str1`` is a subsequence when all of its items occur in ``str2`` in
    the same relative order, not necessarily next to each other. An empty
    ``str1`` is a subsequence of anything.

    This is a single iterative pass over ``str2``. The earlier recursive
    version used one stack frame per item of ``str2`` and raised
    RecursionError on inputs longer than the interpreter's recursion limit.
    """
    remaining = iter(str2)
    # ``item in remaining`` consumes the iterator up to and including the
    # first match, so each later item is searched only in what is left.
    return all(item in remaining for item in str1)
|
||||||
|
|
||||||
|
|
||||||
def parallel_apply(fn, iterable, workers=8):
|
def parallel_apply(fn, iterable, workers=8):
|
||||||
with ThreadPoolExecutor(max_workers=workers) as exe:
|
with ThreadPoolExecutor(max_workers=workers) as exe:
|
||||||
print(f"parallelly applying {fn}")
|
print(f"parallelly applying {fn}")
|
||||||
|
|
|
||||||
|
|
@ -156,6 +156,47 @@ def sample_ui(
|
||||||
ExtendedPath(sample_path).write_json(processed_data)
|
ExtendedPath(sample_path).write_json(processed_data)
|
||||||
|
|
||||||
|
|
||||||
|
# CLI command: for every entry of a sample dump, check whether its "text"
# is a subsequence of the number-normalised, lower-cased "deepgram_asr"
# transcript, print MATCH/FAIL per entry and a final count.
# (Comments instead of a docstring keep the Typer ``--help`` text unchanged.)
# NOTE(review): ``asr_service`` is accepted but not used by this command.
@app.command()
def sample_asr_accuracy(
    data_name: str = typer.Option(
        "png_06_2020_week1_numbers_window_customer", show_default=True
    ),
    dump_dir: Path = Path("./data/asr_data"),
    sample_file: Path = Path("sample_dump.json"),
    asr_service: str = "deepgram",
):
    from ..utils import is_sub_sequence, Text2Num

    t2n = Text2Num()
    sample_path = dump_dir / Path(data_name) / sample_file
    samples = ExtendedPath(sample_path).read_json()["data"]
    total_samples = len(samples)
    match_count = 0

    for dp in tqdm(samples):
        gcp_num = dp["text"]
        dgm_num = t2n.convert(dp["deepgram_asr"].lower())
        if not is_sub_sequence(gcp_num, dgm_num):
            print(f"FAIL GCP:{gcp_num}\tDGM:{dgm_num}")
            continue
        match_count += 1
        print(f"MATCH GCP:{gcp_num}\tDGM:{dgm_num}")

    typer.echo(
        f"{match_count} from deepgram matches with {total_samples} gcp transcripts."
    )
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
@app.command()
|
||||||
def task_ui(
|
def task_ui(
|
||||||
data_name: str = typer.Option("call_upwork_train_cnd", show_default=True),
|
data_name: str = typer.Option("call_upwork_train_cnd", show_default=True),
|
||||||
|
|
@ -190,7 +231,9 @@ def dump_corrections(
|
||||||
col = get_mongo_conn(col="asr_validation")
|
col = get_mongo_conn(col="asr_validation")
|
||||||
task_id = [c for c in col.distinct("task_id") if c.rsplit("-", 1)[1] == task_uid][0]
|
task_id = [c for c in col.distinct("task_id") if c.rsplit("-", 1)[1] == task_uid][0]
|
||||||
corrections = list(col.find({"type": "correction"}, projection={"_id": False}))
|
corrections = list(col.find({"type": "correction"}, projection={"_id": False}))
|
||||||
cursor_obj = col.find({"type": "correction", "task_id": task_id}, projection={"_id": False})
|
cursor_obj = col.find(
|
||||||
|
{"type": "correction", "task_id": task_id}, projection={"_id": False}
|
||||||
|
)
|
||||||
corrections = [c for c in cursor_obj]
|
corrections = [c for c in cursor_obj]
|
||||||
ExtendedPath(dump_path).write_json(corrections)
|
ExtendedPath(dump_path).write_json(corrections)
|
||||||
|
|
||||||
|
|
@ -271,7 +314,9 @@ def split_extract(
|
||||||
dump_file: Path = Path("ui_dump.json"),
|
dump_file: Path = Path("ui_dump.json"),
|
||||||
manifest_file: Path = Path("manifest.json"),
|
manifest_file: Path = Path("manifest.json"),
|
||||||
corrections_file: str = typer.Option("corrections.json", show_default=True),
|
corrections_file: str = typer.Option("corrections.json", show_default=True),
|
||||||
conv_data_path: Path = typer.Option(Path("./data/conv_data.json"), show_default=True),
|
conv_data_path: Path = typer.Option(
|
||||||
|
Path("./data/conv_data.json"), show_default=True
|
||||||
|
),
|
||||||
extraction_type: ExtractionType = ExtractionType.all,
|
extraction_type: ExtractionType = ExtractionType.all,
|
||||||
):
|
):
|
||||||
import shutil
|
import shutil
|
||||||
|
|
@ -293,7 +338,9 @@ def split_extract(
|
||||||
def extract_manifest(mg):
|
def extract_manifest(mg):
|
||||||
for m in mg:
|
for m in mg:
|
||||||
if m["text"] in extraction_vals:
|
if m["text"] in extraction_vals:
|
||||||
shutil.copy(m["audio_path"], dest_data_dir / Path(m["audio_filepath"]))
|
shutil.copy(
|
||||||
|
m["audio_path"], dest_data_dir / Path(m["audio_filepath"])
|
||||||
|
)
|
||||||
yield m
|
yield m
|
||||||
|
|
||||||
asr_manifest_writer(dest_manifest_path, extract_manifest(manifest_gen))
|
asr_manifest_writer(dest_manifest_path, extract_manifest(manifest_gen))
|
||||||
|
|
@ -302,12 +349,14 @@ def split_extract(
|
||||||
orig_ui_data = ExtendedPath(ui_data_path).read_json()
|
orig_ui_data = ExtendedPath(ui_data_path).read_json()
|
||||||
ui_data = orig_ui_data["data"]
|
ui_data = orig_ui_data["data"]
|
||||||
file_ui_map = {Path(u["audio_filepath"]).stem: u for u in ui_data}
|
file_ui_map = {Path(u["audio_filepath"]).stem: u for u in ui_data}
|
||||||
extracted_ui_data = list(filter(lambda u: u["text"] in extraction_vals, ui_data))
|
extracted_ui_data = list(
|
||||||
|
filter(lambda u: u["text"] in extraction_vals, ui_data)
|
||||||
|
)
|
||||||
final_data = []
|
final_data = []
|
||||||
for i, d in enumerate(extracted_ui_data):
|
for i, d in enumerate(extracted_ui_data):
|
||||||
d['real_idx'] = i
|
d["real_idx"] = i
|
||||||
final_data.append(d)
|
final_data.append(d)
|
||||||
orig_ui_data['data'] = final_data
|
orig_ui_data["data"] = final_data
|
||||||
ExtendedPath(dest_ui_path).write_json(orig_ui_data)
|
ExtendedPath(dest_ui_path).write_json(orig_ui_data)
|
||||||
|
|
||||||
if corrections_file:
|
if corrections_file:
|
||||||
|
|
@ -323,7 +372,7 @@ def split_extract(
|
||||||
)
|
)
|
||||||
ExtendedPath(dest_correction_path).write_json(extracted_corrections)
|
ExtendedPath(dest_correction_path).write_json(extracted_corrections)
|
||||||
|
|
||||||
if extraction_type.value == 'all':
|
if extraction_type.value == "all":
|
||||||
for ext_key in conv_data.keys():
|
for ext_key in conv_data.keys():
|
||||||
extract_data_of_type(ext_key)
|
extract_data_of_type(ext_key)
|
||||||
else:
|
else:
|
||||||
|
|
@ -345,7 +394,7 @@ def update_corrections(
|
||||||
|
|
||||||
def correct_manifest(ui_dump_path, corrections_path):
|
def correct_manifest(ui_dump_path, corrections_path):
|
||||||
corrections = ExtendedPath(corrections_path).read_json()
|
corrections = ExtendedPath(corrections_path).read_json()
|
||||||
ui_data = ExtendedPath(ui_dump_path).read_json()['data']
|
ui_data = ExtendedPath(ui_dump_path).read_json()["data"]
|
||||||
correct_set = {
|
correct_set = {
|
||||||
c["code"] for c in corrections if c["value"]["status"] == "Correct"
|
c["code"] for c in corrections if c["value"]["status"] == "Correct"
|
||||||
}
|
}
|
||||||
|
|
@ -374,7 +423,9 @@ def update_corrections(
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
orig_audio_path = Path(d["audio_path"])
|
orig_audio_path = Path(d["audio_path"])
|
||||||
new_name = str(Path(tscript_uuid_fname(correct_text)).with_suffix(".wav"))
|
new_name = str(
|
||||||
|
Path(tscript_uuid_fname(correct_text)).with_suffix(".wav")
|
||||||
|
)
|
||||||
new_audio_path = orig_audio_path.with_name(new_name)
|
new_audio_path = orig_audio_path.with_name(new_name)
|
||||||
orig_audio_path.replace(new_audio_path)
|
orig_audio_path.replace(new_audio_path)
|
||||||
new_filepath = str(Path(d["audio_filepath"]).with_name(new_name))
|
new_filepath = str(Path(d["audio_filepath"]).with_name(new_name))
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,58 @@
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import streamlit as st
|
||||||
|
import typer
|
||||||
|
from jasper.data.utils import ExtendedPath
|
||||||
|
from jasper.data.validation.st_rerun import rerun
|
||||||
|
|
||||||
|
app = typer.Typer()
|
||||||
|
|
||||||
|
if not hasattr(st, "mongo_connected"):
    # One-time setup, guarded by a marker attribute stored on the
    # ``streamlit`` module object itself.
    # NOTE(review): indentation is not visible in the reviewed text; the
    # final ``update_cursor_fn(0)`` call is assumed to belong to this
    # guarded block - confirm against the real file.

    # JSON file holding the index of the sample currently shown.
    task_path = ExtendedPath("preview.lck")

    def current_cursor_fn():
        """Return the ``current_cursor`` value stored in the lock file."""
        cursor_state = task_path.read_json()
        return cursor_state["current_cursor"]

    def update_cursor_fn(val=0):
        """Write ``val`` as ``current_cursor`` to the lock file, then call
        ``rerun``."""
        cursor_state = {"current_cursor": val}
        task_path.write_json(cursor_state)
        rerun()

    # Expose the accessors and the marker on ``st`` before the first
    # cursor write, which ends in ``rerun()``.
    st.mongo_connected = True
    st.update_cursor = update_cursor_fn
    st.get_current_cursor = current_cursor_fn
    update_cursor_fn(0)
|
||||||
|
|
||||||
|
|
||||||
|
@st.cache()
def load_ui_data(validation_ui_data_path: Path):
    """Read the JSON-lines file at ``validation_ui_data_path`` into a list.

    Echoes the path being used, then materialises every record yielded by
    ``ExtendedPath.read_jsonl``. The function is wrapped in ``st.cache()``.
    """
    typer.echo(f"Using validation ui data from {validation_ui_data_path}")
    manifest = ExtendedPath(validation_ui_data_path)
    return [*manifest.read_jsonl()]
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
def main(manifest: Path):
    """Preview the samples of an ASR manifest (JSON lines) one at a time.

    Shows the text of the current sample, a "Go To Sample" number input
    and an audio player for the sample's ``audio_filepath``.

    Args:
        manifest: path of the JSON-lines manifest; each record is read for
            its ``text`` and ``audio_filepath`` keys.
    """
    asr_data = load_ui_data(manifest)
    if not asr_data:
        # With no samples there is no valid index and the number input
        # below would get max_value=0 < min_value=1.
        st.warning(f"No samples found in {manifest}")
        return

    sample_no = st.get_current_cursor()
    if len(asr_data) - 1 < sample_no or sample_no < 0:
        print("Invalid samplno resetting to 0")
        st.update_cursor(0)
        # Harmless if update_cursor never returns; if it does return, this
        # stops the stale out-of-range index from being used below.
        sample_no = 0
    sample = asr_data[sample_no]

    st.title("ASR Manifest Preview")
    st.markdown(f"{sample_no+1} of {len(asr_data)} : **{sample['text']}**")
    # The widget is 1-based for the user; the stored cursor is 0-based.
    new_sample = st.number_input(
        "Go To Sample:", value=sample_no + 1, min_value=1, max_value=len(asr_data)
    )
    if new_sample != sample_no + 1:
        st.update_cursor(new_sample - 1)
    st.sidebar.markdown(f"Gold Text: **{sample['text']}**")
    # Pass the bytes instead of an open handle that was never closed.
    st.audio(Path(sample["audio_filepath"]).read_bytes())
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
try:
|
||||||
|
app()
|
||||||
|
except SystemExit:
|
||||||
|
pass
|
||||||
Loading…
Reference in New Issue