Compare commits
No commits in common. "f5c49338d92f0c415b0de98609c008cd668ede49" and "ae5586be7224d6bb283d8abfcf7b67110ca94ab7" have entirely different histories.
f5c49338d9...ae5586be72
@@ -2,10 +2,6 @@ import os
 import logging
 import rpyc
 from functools import lru_cache
-import typer
-from pathlib import Path
 
-app = typer.Typer()
-
 logging.basicConfig(
     level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
@@ -23,28 +19,3 @@ def transcribe_gen(asr_host=ASR_HOST, asr_port=ASR_PORT):
     asr = rpyc.connect(asr_host, asr_port).root
     logger.info(f"connected to asr server successfully")
     return asr.transcribe
-
-
-@app.command()
-def transcribe_file(audio_file: Path):
-    from pydub import AudioSegment
-
-    transcriber = transcribe_gen()
-    aud_seg = (
-        AudioSegment.from_file_using_temporary_files(audio_file)
-        .set_channels(1)
-        .set_sample_width(2)
-        .set_frame_rate(24000)
-    )
-    tscript_file_path = audio_file.with_suffix(".txt")
-    transcription = transcriber(aud_seg.raw_data)
-    with open(tscript_file_path, "w") as tf:
-        tf.write(transcription)
-
-
-def main():
-    app()
-
-
-if __name__ == "__main__":
-    main()
@@ -1,16 +1,13 @@
 import io
 import os
-import re
 import json
-import base64
 import wave
 from pathlib import Path
 from itertools import product
 from functools import partial
 from math import floor
 from uuid import uuid4
-from urllib.parse import urlsplit, urlencode
-from urllib.request import Request, urlopen
+from urllib.parse import urlsplit
 from concurrent.futures import ThreadPoolExecutor
 
 import numpy as np
@@ -103,9 +100,6 @@ def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=F
         "data": [],
     }
     data_funcs = []
-
-    deepgram_transcriber = deepgram_transcribe_gen()
-    # t2n = Text2Num()
     transcriber_gcp = gcp_transcribe_gen()
     transcriber_pretrained = transcribe_gen(asr_port=8044)
     with asr_manifest.open("w") as mf:
@@ -125,10 +119,6 @@ def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=F
         pretrained_result = transcriber_pretrained(aud_seg.raw_data)
         gcp_seg = aud_seg.set_frame_rate(16000)
         gcp_result = transcriber_gcp(gcp_seg.raw_data)
-        aud_data = audio_path.read_bytes()
-        dgram_result = deepgram_transcriber(aud_data)
-        # gtruth = dp['text']
-        # dgram_result = t2n.convert(dgram_script)
         pretrained_wer = word_error_rate([transcript], [pretrained_result])
         wav_plot_path = (
             dataset_dir / Path("wav_plots") / Path(fname).with_suffix(".png")
@@ -145,7 +135,6 @@ def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=F
             "caller": caller_name,
             "utterance_id": fname,
             "gcp_asr": gcp_result,
-            "deepgram_asr": dgram_result,
             "pretrained_asr": pretrained_result,
             "pretrained_wer": pretrained_wer,
             "plot_path": str(wav_plot_path),
@@ -236,12 +225,6 @@ class ExtendedPath(type(Path())):
         with self.open("r") as jf:
             return json.load(jf)
 
-    def read_jsonl(self):
-        print(f"reading jsonl from {self}")
-        with self.open("r") as jf:
-            for l in jf.readlines():
-                yield json.loads(l)
-
     def write_json(self, data):
         print(f"writing json to {self}")
         self.parent.mkdir(parents=True, exist_ok=True)
@@ -482,203 +465,6 @@ def gcp_transcribe_gen():
     return sample_recognize
 
 
-def deepgram_transcribe_gen():
-
-    DEEPGRAM_URL = "https://brain.deepgram.com/v2/listen"
-    MODEL = "agara"
-    encoding = "linear16"
-    sample_rate = "8000"
-    # diarize = "false"
-    q_params = {
-        "model": MODEL,
-        "encoding": encoding,
-        "sample_rate": sample_rate,
-        "language": "en-US",
-        "multichannel": "false",
-        "punctuate": "true",
-    }
-    url = "{}?{}".format(DEEPGRAM_URL, urlencode(q_params))
-    # print(url)
-    creds = ("arjun@agaralabs.com", "PoX1Y@x4h%oS")
-
-    def deepgram_offline(audio_data):
-        request = Request(
-            url,
-            method="POST",
-            headers={
-                "Authorization": "Basic {}".format(
-                    base64.b64encode("{}:{}".format(*creds).encode("utf-8")).decode(
-                        "utf-8"
-                    )
-                )
-            },
-            data=audio_data,
-        )
-        with urlopen(request) as response:
-            msg = json.loads(response.read())
-            data = msg["results"]["channels"][0]["alternatives"][0]
-            return data["transcript"]
-
-    return deepgram_offline
-
-
-class Text2Num(object):
-    """docstring for Text2Num."""
-
-    def __init__(self):
-        numwords = {}
-        if not numwords:
-            units = [
-                "zero",
-                "one",
-                "two",
-                "three",
-                "four",
-                "five",
-                "six",
-                "seven",
-                "eight",
-                "nine",
-                "ten",
-                "eleven",
-                "twelve",
-                "thirteen",
-                "fourteen",
-                "fifteen",
-                "sixteen",
-                "seventeen",
-                "eighteen",
-                "nineteen",
-            ]
-
-            tens = [
-                "",
-                "",
-                "twenty",
-                "thirty",
-                "forty",
-                "fifty",
-                "sixty",
-                "seventy",
-                "eighty",
-                "ninety",
-            ]
-
-            scales = ["hundred", "thousand", "million", "billion", "trillion"]
-
-            numwords["and"] = (1, 0)
-            for idx, word in enumerate(units):
-                numwords[word] = (1, idx)
-            for idx, word in enumerate(tens):
-                numwords[word] = (1, idx * 10)
-            for idx, word in enumerate(scales):
-                numwords[word] = (10 ** (idx * 3 or 2), 0)
-        self.numwords = numwords
-
-    def is_num(self, word):
-        return word in self.numwords
-
-    def parseOrdinal(self, utterance, **kwargs):
-        lookup_dict = {
-            "first": 1,
-            "second": 2,
-            "third": 3,
-            "fourth": 4,
-            "fifth": 5,
-            "sixth": 6,
-            "seventh": 7,
-            "eighth": 8,
-            "ninth": 9,
-            "tenth": 10,
-            "one": 1,
-            "two": 2,
-            "three": 3,
-            "four": 4,
-            "five": 5,
-            "six": 6,
-            "seven": 7,
-            "eight": 8,
-            "nine": 9,
-            "ten": 10,
-            "1": 1,
-            "2": 2,
-            "3": 3,
-            "4": 4,
-            "5": 5,
-            "6": 6,
-            "7": 7,
-            "8": 8,
-            "9": 9,
-            "10": 10,
-            "last": -1,
-        }
-        pattern = re.compile(
-            r"(\s|^)(?P<num>(first)|(third)|(fourth)|(fifth)|(sixth)|(seventh)|(eighth)|(ninth)|(tenth)|(two)|(three)|(four)|(five)|(six)|(seven)|(eight)|(nine)|(ten)|(1)|(2)|(3)|(4)|(5)|(6)|(7)|(8)|(9)|(10)|(last))(\s|$)",
-            re.IGNORECASE,
-        )
-        ordinal = ""
-        if pattern.search(utterance):
-            ordinal = pattern.search(utterance).groupdict()["num"].strip()
-        elif re.search(r"(\s|^)(?P<num>(second))(\s|$)", utterance):
-            ordinal = "second"
-        elif re.search(r"(\s|^)(?P<num>(one))(\s|$)", utterance):
-            ordinal = "one"
-        ordinal = lookup_dict.get(ordinal, "")
-        return ordinal
-
-    def convert(self, sent):
-        # res = []
-        # for token in sent.split():
-        #     if token in self.numwords:
-        #         res.append(str(self.text2int(token)))
-        #     else:
-        #         res.append(token)
-        # return " ".join(res)
-
-        return " ".join(
-            [
-                str(self.parseOrdinal(x)) if self.parseOrdinal(x) != "" else x
-                for x in sent.split()
-            ]
-        )
-
-    def text2int(self, textnum):
-
-        current = result = 0
-        for word in textnum.split():
-            if word not in self.numwords:
-                raise Exception("Illegal word: " + word)
-
-            scale, increment = self.numwords[word]
-            current = current * scale + increment
-            if scale > 100:
-                result += current
-                current = 0
-
-        return result + current
-
-
-def is_sub_sequence(str1, str2):
-    m = len(str1)
-    n = len(str2)
-
-    def check_seq(string1, string2, m, n):
-        # Base Cases
-        if m == 0:
-            return True
-        if n == 0:
-            return False
-
-        # If last characters of two strings are matching
-        if string1[m - 1] == string2[n - 1]:
-            return check_seq(string1, string2, m - 1, n - 1)
-
-        # If last characters are not matching
-        return check_seq(string1, string2, m, n - 1)
-
-    return check_seq(str1, str2, m, n)
-
-
 def parallel_apply(fn, iterable, workers=8):
     with ThreadPoolExecutor(max_workers=workers) as exe:
         print(f"parallelly applying {fn}")
@@ -156,47 +156,6 @@ def sample_ui(
     ExtendedPath(sample_path).write_json(processed_data)
 
 
-@app.command()
-def sample_asr_accuracy(
-    data_name: str = typer.Option(
-        "png_06_2020_week1_numbers_window_customer", show_default=True
-    ),
-    dump_dir: Path = Path("./data/asr_data"),
-    sample_file: Path = Path("sample_dump.json"),
-    asr_service: str = "deepgram",
-):
-    # import pandas as pd
-    # from pydub import AudioSegment
-    from ..utils import is_sub_sequence, Text2Num
-
-    # from ..utils import deepgram_transcribe_gen
-    #
-    # deepgram_transcriber = deepgram_transcribe_gen()
-    t2n = Text2Num()
-    # processed_data_path = dump_dir / Path(data_name) / dump_file
-    sample_path = dump_dir / Path(data_name) / sample_file
-    processed_data = ExtendedPath(sample_path).read_json()
-    # asr_data = []
-    match_count, total_samples = 0, len(processed_data["data"])
-    for dp in tqdm(processed_data["data"]):
-        # aud_data = Path(dp["audio_path"]).read_bytes()
-        # dgram_result = deepgram_transcriber(aud_data)
-        # dp["deepgram_asr"] = dgram_result
-        gcp_num = dp["text"]
-        dgm_num = t2n.convert(dp["deepgram_asr"].lower())
-        if is_sub_sequence(gcp_num, dgm_num):
-            match_count += 1
-            print(f"MATCH GCP:{gcp_num}\tDGM:{dgm_num}")
-        else:
-            print(f"FAIL GCP:{gcp_num}\tDGM:{dgm_num}")
-        # asr_data.append(dp)
-    typer.echo(
-        f"{match_count} from deepgram matches with {total_samples} gcp transcripts."
-    )
-    # processed_data["data"] = asr_data
-    # ExtendedPath(sample_path).write_json(processed_data)
-
-
 @app.command()
 def task_ui(
     data_name: str = typer.Option("call_upwork_train_cnd", show_default=True),
@@ -231,9 +190,7 @@ def dump_corrections(
     col = get_mongo_conn(col="asr_validation")
     task_id = [c for c in col.distinct("task_id") if c.rsplit("-", 1)[1] == task_uid][0]
     corrections = list(col.find({"type": "correction"}, projection={"_id": False}))
-    cursor_obj = col.find(
-        {"type": "correction", "task_id": task_id}, projection={"_id": False}
-    )
+    cursor_obj = col.find({"type": "correction", "task_id": task_id}, projection={"_id": False})
     corrections = [c for c in cursor_obj]
     ExtendedPath(dump_path).write_json(corrections)
 
@@ -314,9 +271,7 @@ def split_extract(
     dump_file: Path = Path("ui_dump.json"),
     manifest_file: Path = Path("manifest.json"),
     corrections_file: str = typer.Option("corrections.json", show_default=True),
-    conv_data_path: Path = typer.Option(
-        Path("./data/conv_data.json"), show_default=True
-    ),
+    conv_data_path: Path = typer.Option(Path("./data/conv_data.json"), show_default=True),
     extraction_type: ExtractionType = ExtractionType.all,
 ):
     import shutil
@@ -338,9 +293,7 @@ def split_extract(
     def extract_manifest(mg):
         for m in mg:
             if m["text"] in extraction_vals:
-                shutil.copy(
-                    m["audio_path"], dest_data_dir / Path(m["audio_filepath"])
-                )
+                shutil.copy(m["audio_path"], dest_data_dir / Path(m["audio_filepath"]))
             yield m
 
     asr_manifest_writer(dest_manifest_path, extract_manifest(manifest_gen))
@@ -349,14 +302,12 @@ def split_extract(
     orig_ui_data = ExtendedPath(ui_data_path).read_json()
     ui_data = orig_ui_data["data"]
     file_ui_map = {Path(u["audio_filepath"]).stem: u for u in ui_data}
-    extracted_ui_data = list(
-        filter(lambda u: u["text"] in extraction_vals, ui_data)
-    )
+    extracted_ui_data = list(filter(lambda u: u["text"] in extraction_vals, ui_data))
     final_data = []
     for i, d in enumerate(extracted_ui_data):
-        d["real_idx"] = i
+        d['real_idx'] = i
         final_data.append(d)
-    orig_ui_data["data"] = final_data
+    orig_ui_data['data'] = final_data
     ExtendedPath(dest_ui_path).write_json(orig_ui_data)
 
     if corrections_file:
@@ -372,7 +323,7 @@ def split_extract(
         )
         ExtendedPath(dest_correction_path).write_json(extracted_corrections)
 
-    if extraction_type.value == "all":
+    if extraction_type.value == 'all':
         for ext_key in conv_data.keys():
             extract_data_of_type(ext_key)
     else:
@@ -394,7 +345,7 @@ def update_corrections(
 
     def correct_manifest(ui_dump_path, corrections_path):
         corrections = ExtendedPath(corrections_path).read_json()
-        ui_data = ExtendedPath(ui_dump_path).read_json()["data"]
+        ui_data = ExtendedPath(ui_dump_path).read_json()['data']
         correct_set = {
             c["code"] for c in corrections if c["value"]["status"] == "Correct"
         }
@@ -423,9 +374,7 @@ def update_corrections(
                 )
             else:
                 orig_audio_path = Path(d["audio_path"])
-                new_name = str(
-                    Path(tscript_uuid_fname(correct_text)).with_suffix(".wav")
-                )
+                new_name = str(Path(tscript_uuid_fname(correct_text)).with_suffix(".wav"))
                 new_audio_path = orig_audio_path.with_name(new_name)
                 orig_audio_path.replace(new_audio_path)
                 new_filepath = str(Path(d["audio_filepath"]).with_name(new_name))
@@ -1,58 +0,0 @@
-from pathlib import Path
-
-import streamlit as st
-import typer
-from jasper.data.utils import ExtendedPath
-from jasper.data.validation.st_rerun import rerun
-
-app = typer.Typer()
-
-if not hasattr(st, "mongo_connected"):
-    # st.task_id = str(uuid4())
-    task_path = ExtendedPath("preview.lck")
-
-    def current_cursor_fn():
-        return task_path.read_json()["current_cursor"]
-
-    def update_cursor_fn(val=0):
-        task_path.write_json({"current_cursor": val})
-        rerun()
-
-    st.get_current_cursor = current_cursor_fn
-    st.update_cursor = update_cursor_fn
-    st.mongo_connected = True
-    # cursor_obj = mongo_conn.find_one({"type": "current_cursor", "task_id": st.task_id})
-    # if not cursor_obj:
-    update_cursor_fn(0)
-
-
-@st.cache()
-def load_ui_data(validation_ui_data_path: Path):
-    typer.echo(f"Using validation ui data from {validation_ui_data_path}")
-    return list(ExtendedPath(validation_ui_data_path).read_jsonl())
-
-
-@app.command()
-def main(manifest: Path):
-    asr_data = load_ui_data(manifest)
-    sample_no = st.get_current_cursor()
-    if len(asr_data) - 1 < sample_no or sample_no < 0:
-        print("Invalid samplno resetting to 0")
-        st.update_cursor(0)
-    sample = asr_data[sample_no]
-    st.title(f"ASR Manifest Preview")
-    st.markdown(f"{sample_no+1} of {len(asr_data)} : **{sample['text']}**")
-    new_sample = st.number_input(
-        "Go To Sample:", value=sample_no + 1, min_value=1, max_value=len(asr_data)
-    )
-    if new_sample != sample_no + 1:
-        st.update_cursor(new_sample - 1)
-    st.sidebar.markdown(f"Gold Text: **{sample['text']}**")
-    st.audio(Path(sample["audio_filepath"]).open("rb"))
-
-
-if __name__ == "__main__":
-    try:
-        app()
-    except SystemExit:
-        pass