mirror of
https://github.com/malarinv/jasper-asr.git
synced 2026-03-09 19:02:35 +00:00
1. removed the transcriber_pretrained/speller from utils
2. introduced get_mongo_coll to get the collection object directly from mongo uri 3. removed processing of correction entries to remove space/upper casing
This commit is contained in:
@@ -24,6 +24,7 @@ def preprocess_datapoint(
|
||||
import librosa.display
|
||||
from pydub import AudioSegment
|
||||
from nemo.collections.asr.metrics import word_error_rate
|
||||
from jasper.client import transcribe_gen
|
||||
|
||||
try:
|
||||
res = dict(sample)
|
||||
@@ -36,7 +37,7 @@ def preprocess_datapoint(
|
||||
res["spoken"] = res["text"]
|
||||
res["utterance_id"] = audio_path.stem
|
||||
if not annotation_only:
|
||||
from jasper.client import transcriber_pretrained, transcriber_speller
|
||||
transcriber_pretrained = transcribe_gen(asr_port=8044)
|
||||
|
||||
aud_seg = (
|
||||
AudioSegment.from_file_using_temporary_files(audio_path)
|
||||
@@ -49,6 +50,7 @@ def preprocess_datapoint(
|
||||
[res["text"]], [res["pretrained_asr"]]
|
||||
)
|
||||
if use_domain_asr:
|
||||
transcriber_speller = transcribe_gen(asr_port=8045)
|
||||
res["domain_asr"] = transcriber_speller(aud_seg.raw_data)
|
||||
res["domain_wer"] = word_error_rate(
|
||||
[res["spoken"]], [res["pretrained_asr"]]
|
||||
@@ -74,19 +76,19 @@ def preprocess_datapoint(
|
||||
|
||||
@app.command()
|
||||
def dump_validation_ui_data(
|
||||
data_manifest_path: Path = typer.Option(
|
||||
Path("./data/asr_data/call_alphanum/manifest.json"), show_default=True
|
||||
dataset_path: Path = typer.Option(
|
||||
Path("./data/asr_data/call_alphanum"), show_default=True
|
||||
),
|
||||
dump_path: Path = typer.Option(
|
||||
Path("./data/valiation_data/ui_dump.json"), show_default=True
|
||||
),
|
||||
use_domain_asr: bool = True,
|
||||
annotation_only: bool = True,
|
||||
dump_name: str = typer.Option("ui_dump.json", show_default=True),
|
||||
use_domain_asr: bool = False,
|
||||
annotation_only: bool = False,
|
||||
enable_plots: bool = True,
|
||||
):
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from functools import partial
|
||||
|
||||
data_manifest_path = dataset_path / Path("manifest.json")
|
||||
dump_path: Path = Path(f"./data/valiation_data/{dataset_path.stem}/{dump_name}")
|
||||
plot_dir = data_manifest_path.parent / Path("wav_plots")
|
||||
plot_dir.mkdir(parents=True, exist_ok=True)
|
||||
typer.echo(f"Using data manifest:{data_manifest_path}")
|
||||
@@ -137,7 +139,7 @@ def dump_validation_ui_data(
|
||||
|
||||
@app.command()
|
||||
def dump_corrections(dump_path: Path = Path("./data/valiation_data/corrections.json")):
|
||||
col = get_mongo_conn().test.asr_validation
|
||||
col = get_mongo_conn(col='asr_validation')
|
||||
|
||||
cursor_obj = col.find({"type": "correction"}, projection={"_id": False})
|
||||
corrections = [c for c in cursor_obj]
|
||||
@@ -154,7 +156,7 @@ def fill_unannotated(
|
||||
annotated_codes = {c["code"] for c in corrections}
|
||||
all_codes = {c["gold_chars"] for c in processed_data}
|
||||
unann_codes = all_codes - annotated_codes
|
||||
mongo_conn = get_mongo_conn().test.asr_validation
|
||||
mongo_conn = get_mongo_conn(col='asr_validation')
|
||||
for c in unann_codes:
|
||||
mongo_conn.find_one_and_update(
|
||||
{"type": "correction", "code": c},
|
||||
@@ -232,7 +234,7 @@ def update_corrections(
|
||||
def clear_mongo_corrections():
|
||||
delete = typer.confirm("are you sure you want to clear mongo collection it?")
|
||||
if delete:
|
||||
col = get_mongo_conn().test.asr_validation
|
||||
col = get_mongo_conn(col='asr_validation')
|
||||
col.delete_many({"type": "correction"})
|
||||
typer.echo("deleted mongo collection.")
|
||||
typer.echo("Aborted")
|
||||
|
||||
@@ -9,7 +9,7 @@ app = typer.Typer()
|
||||
|
||||
|
||||
if not hasattr(st, "mongo_connected"):
|
||||
st.mongoclient = get_mongo_conn().test.asr_validation
|
||||
st.mongoclient = get_mongo_conn(col='asr_validation')
|
||||
mongo_conn = st.mongoclient
|
||||
|
||||
def current_cursor_fn():
|
||||
@@ -111,9 +111,8 @@ def main(manifest: Path):
|
||||
if selected == "Inaudible":
|
||||
corrected = ""
|
||||
if st.button("Submit"):
|
||||
correct_code = corrected.replace(" ", "").upper()
|
||||
st.update_entry(
|
||||
sample["utterance_id"], {"status": selected, "correction": correct_code}
|
||||
sample["utterance_id"], {"status": selected, "correction": corrected}
|
||||
)
|
||||
st.update_cursor(sample_no + 1)
|
||||
if correction_entry:
|
||||
|
||||
Reference in New Issue
Block a user