From 41074a1bcabac5db47888bf2aa4ae88d853e559e Mon Sep 17 00:00:00 2001 From: Malar Kannan Date: Wed, 29 Apr 2020 14:26:11 +0530 Subject: [PATCH] 1. added streamlit based validation ui with mongodb datastore integration 2. fix asr wrong sample rate inference 3. update requirements --- jasper/asr.py | 2 +- jasper/data_utils/call_recycler.py | 1 + jasper/data_utils/validation/jasper_client.py | 23 +++ jasper/data_utils/validation/orig_ui.py | 73 ++++++++ jasper/data_utils/validation/st_rerun.py | 38 ++++ jasper/data_utils/validation/ui.py | 171 ++++++++++++++++++ setup.py | 8 + streamlit.py | 3 + 8 files changed, 318 insertions(+), 1 deletion(-) create mode 100644 jasper/data_utils/validation/jasper_client.py create mode 100644 jasper/data_utils/validation/orig_ui.py create mode 100644 jasper/data_utils/validation/st_rerun.py create mode 100644 jasper/data_utils/validation/ui.py create mode 100644 streamlit.py diff --git a/jasper/asr.py b/jasper/asr.py index de3d78f..2f16635 100644 --- a/jasper/asr.py +++ b/jasper/asr.py @@ -62,7 +62,7 @@ class JasperASR(object): wf = wave.open(audio_file_path, "w") wf.setnchannels(1) wf.setsampwidth(2) - wf.setframerate(16000) + wf.setframerate(24000) wf.writeframesraw(audio_data) wf.close() manifest = {"audio_filepath": audio_file_path, "duration": 60, "text": "todo"} diff --git a/jasper/data_utils/call_recycler.py b/jasper/data_utils/call_recycler.py index 8678787..d3904c0 100644 --- a/jasper/data_utils/call_recycler.py +++ b/jasper/data_utils/call_recycler.py @@ -46,6 +46,7 @@ def analyze( from tqdm import tqdm from .utils import asr_data_writer from pydub import AudioSegment + # from itertools import product, chain matplotlib.rcParams["agg.path.chunksize"] = 10000 diff --git a/jasper/data_utils/validation/jasper_client.py b/jasper/data_utils/validation/jasper_client.py new file mode 100644 index 0000000..84fd465 --- /dev/null +++ b/jasper/data_utils/validation/jasper_client.py @@ -0,0 +1,23 @@ +import os +import logging +import rpyc + +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger(__name__) + + +ASR_HOST = os.environ.get("JASPER_ASR_RPYC_HOST", "localhost") +ASR_PORT = int(os.environ.get("JASPER_ASR_RPYC_PORT", "8045")) + + +def transcribe_gen(asr_host=ASR_HOST, asr_port=ASR_PORT): + logger.info(f"connecting to asr server at {asr_host}:{asr_port}") + asr = rpyc.connect(asr_host, asr_port).root + logger.info(f"connected to asr server successfully") + return asr.transcribe + + +transcriber_pretrained = transcribe_gen(asr_port=8044) +transcriber_speller = transcribe_gen(asr_port=8045) diff --git a/jasper/data_utils/validation/orig_ui.py b/jasper/data_utils/validation/orig_ui.py new file mode 100644 index 0000000..95a23a9 --- /dev/null +++ b/jasper/data_utils/validation/orig_ui.py @@ -0,0 +1,73 @@ +import json +from pathlib import Path +import streamlit as st + +# import matplotlib.pyplot as plt +# import numpy as np +import librosa +import librosa.display +from pydub import AudioSegment +from jasper.client import transcriber_pretrained, transcriber_speller + +# from pymongo import MongoClient + +st.title("ASR Speller Validation") +dataset_path: Path = Path("/dataset/asr_data/call_alphanum_v3") +manifest_path = dataset_path / Path("test_manifest.json") +# print(manifest_path) +with manifest_path.open("r") as pf: + pnr_jsonl = pf.readlines() + pnr_data = [json.loads(i) for i in pnr_jsonl] + + +def main(): + # pnr_data = MongoClient("mongodb://localhost:27017/").test.asr_pnr + # sample_no = 0 + sample_no = ( + st.slider( + "Sample", + min_value=1, + max_value=len(pnr_data), + value=1, + step=1, + format=None, + key=None, + ) + - 1 + ) + sample = pnr_data[sample_no] + st.write(f"Sample No: {sample_no+1} of {len(pnr_data)}") + audio_path = Path(sample["audio_filepath"]) + # st.write(f"Audio Path:{audio_path}") + aud_seg = AudioSegment.from_wav(audio_path) # .set_channels(1).set_sample_width(2).set_frame_rate(24000) + st.sidebar.text("Transcription") + st.sidebar.text(f"Pretrained:{transcriber_pretrained(aud_seg.raw_data)}") + st.sidebar.text(f"Speller:{transcriber_speller(aud_seg.raw_data)}") + st.sidebar.text(f"Expected: {audio_path.stem}") + spell_text = sample["text"] + st.sidebar.text(f"Spelled: {spell_text}") + st.audio(audio_path.open("rb")) + selected = st.radio("The Audio is", ("Correct", "Incorrect", "Inaudible")) + corrected = audio_path.stem + if selected == "Incorrect": + corrected = st.text_input("Actual:", value=corrected) + # content = '' + if sample_no > 0 and st.button("Previous"): + sample_no -= 1 + if st.button("Next"): + st.write(sample_no, selected, corrected) + sample_no += 1 + + (y, sr) = librosa.load(audio_path) + librosa.display.waveplot(y=y, sr=sr) + # arr = np.random.normal(1, 1, size=100) + # plt.hist(arr, bins=20) + st.sidebar.pyplot() + + +# def main(): +# app() + + +if __name__ == "__main__": + main() diff --git a/jasper/data_utils/validation/st_rerun.py b/jasper/data_utils/validation/st_rerun.py new file mode 100644 index 0000000..ae80624 --- /dev/null +++ b/jasper/data_utils/validation/st_rerun.py @@ -0,0 +1,38 @@ +import streamlit.ReportThread as ReportThread +from streamlit.ScriptRequestQueue import RerunData +from streamlit.ScriptRunner import RerunException +from streamlit.server.Server import Server + + +def rerun(): + """Rerun a Streamlit app from the top!""" + widget_states = _get_widget_states() + raise RerunException(RerunData(widget_states)) + + +def _get_widget_states(): + # Hack to get the session object from Streamlit. + + ctx = ReportThread.get_report_ctx() + + session = None + + current_server = Server.get_current() + if hasattr(current_server, '_session_infos'): + # Streamlit < 0.56 + session_infos = Server.get_current()._session_infos.values() + else: + session_infos = Server.get_current()._session_info_by_id.values() + + for session_info in session_infos: + if session_info.session.enqueue == ctx.enqueue: + session = session_info.session + + if session is None: + raise RuntimeError( + "Oh noes. Couldn't get your Streamlit Session object" + "Are you doing something fancy with threads?" + ) + # Got the session object! + + return session._widget_states diff --git a/jasper/data_utils/validation/ui.py b/jasper/data_utils/validation/ui.py new file mode 100644 index 0000000..03fae29 --- /dev/null +++ b/jasper/data_utils/validation/ui.py @@ -0,0 +1,171 @@ +import json +from io import BytesIO +from pathlib import Path + +import streamlit as st +from nemo.collections.asr.metrics import word_error_rate +import librosa +import librosa.display +import matplotlib.pyplot as plt +from tqdm import tqdm +from pydub import AudioSegment +import pymongo +from .jasper_client import transcriber_pretrained, transcriber_speller +from .st_rerun import rerun + +st.title("ASR Speller Validation") + + +def clear_mongo_corrections(): + col = pymongo.MongoClient("mongodb://localhost:27017/").test.asr_validation + col.delete_many({"type": "correction"}) + + +def preprocess_datapoint(idx, sample): + res = dict(sample) + res["real_idx"] = idx + audio_path = Path(sample["audio_filepath"]) + res["audio_path"] = audio_path + res["gold_chars"] = audio_path.stem + res["gold_phone"] = sample["text"] + aud_seg = ( + AudioSegment.from_wav(audio_path) + .set_channels(1) + .set_sample_width(2) + .set_frame_rate(24000) + ) + res["pretrained_asr"] = transcriber_pretrained(aud_seg.raw_data) + res["speller_asr"] = transcriber_speller(aud_seg.raw_data) + res["wer"] = word_error_rate([res["gold_phone"]], [res["speller_asr"]]) + (y, sr) = librosa.load(audio_path) + plt.tight_layout() + librosa.display.waveplot(y=y, sr=sr) + wav_plot_f = BytesIO() + plt.savefig(wav_plot_f, format="png", dpi=50) + plt.close() + wav_plot_f.seek(0) + res["plot_png"] = wav_plot_f + return res + + +if not hasattr(st, "mongo_connected"): + st.mongoclient = pymongo.MongoClient( + "mongodb://localhost:27017/" + ).test.asr_validation + mongo_conn = st.mongoclient + + def current_cursor_fn(): + # mongo_conn = st.mongoclient + cursor_obj = mongo_conn.find_one({"type": "current_cursor"}) + cursor_val = cursor_obj["cursor"] + return cursor_val + + def update_cursor_fn(val=0): + mongo_conn.find_one_and_update( + {"type": "current_cursor"}, + {"$set": {"type": "current_cursor", "cursor": val}}, + upsert=True, + ) + rerun() + + def get_correction_entry_fn(code): + # mongo_conn = st.mongoclient + # cursor_obj = mongo_conn.find_one({"type": "correction", "code": code}) + # cursor_val = cursor_obj["cursor"] + return mongo_conn.find_one( + {"type": "correction", "code": code}, projection={"_id": False} + ) + + def update_entry_fn(code, value): + mongo_conn.find_one_and_update( + {"type": "correction", "code": code}, + {"$set": {"value": value}}, + upsert=True, + ) + rerun() + + cursor_obj = mongo_conn.find_one({"type": "current_cursor"}) + if not cursor_obj: + update_cursor_fn(0) + st.get_current_cursor = current_cursor_fn + st.update_cursor = update_cursor_fn + st.get_correction_entry = get_correction_entry_fn + st.update_entry = update_entry_fn + st.mongo_connected = True + + +@st.cache(hash_funcs={"rpyc.core.netref.builtins.method": lambda _: None}) +def preprocess_dataset(dataset_path: Path = Path("/dataset/asr_data/call_alphanum_v3")): + print("misssed cache : preprocess_dataset") + dataset_path: Path = Path("/dataset/asr_data/call_alphanum_v3") + manifest_path = dataset_path / Path("test_manifest.json") + with manifest_path.open("r") as pf: + pnr_jsonl = pf.readlines() + pnr_data = [ + preprocess_datapoint(i, json.loads(v)) + for i, v in enumerate(tqdm(pnr_jsonl)) + ] + result = sorted(pnr_data, key=lambda x: x["wer"], reverse=True) + return result + + +def main(): + pnr_data = preprocess_dataset() + sample_no = st.get_current_cursor() + sample = pnr_data[sample_no] + st.markdown( + f"{sample_no+1} of {len(pnr_data)} : **{sample['gold_chars']}** spelled *{sample['gold_phone']}*" + ) + new_sample = st.number_input( + "Go To Sample:", value=sample_no + 1, min_value=1, max_value=len(pnr_data) + ) + if new_sample != sample_no + 1: + st.update_cursor(new_sample - 1) + st.sidebar.title(f"Details: [{sample['real_idx']}]") + st.sidebar.markdown(f"Gold: **{sample['gold_chars']}**") + st.sidebar.markdown(f"Expected Speech: *{sample['gold_phone']}*") + st.sidebar.title("Results:") + st.sidebar.text(f"Pretrained:{sample['pretrained_asr']}") + st.sidebar.text(f"Speller:{sample['speller_asr']}") + + st.sidebar.title(f"WER: {sample['wer']:.2f}%") + # (y, sr) = librosa.load(sample["audio_path"]) + # librosa.display.waveplot(y=y, sr=sr) + # st.sidebar.pyplot(fig=sample["plot_fig"]) + st.sidebar.image(sample["plot_png"]) + st.audio(sample["audio_path"].open("rb")) + corrected = sample["gold_chars"] + correction_entry = st.get_correction_entry(sample["gold_chars"]) + selected_idx = 0 + options = ("Correct", "Incorrect", "Inaudible") + if correction_entry: + selected_idx = options.index(correction_entry["value"]["status"]) + corrected = correction_entry["value"]["correction"] + selected = st.radio("The Audio is", options, index=selected_idx) + if selected == "Incorrect": + corrected = st.text_input("Actual:", value=corrected) + if selected == "Inaudible": + corrected = "" + if st.button("Submit"): + correct_code = corrected.replace(" ", "").upper() + st.update_entry( + sample["gold_chars"], {"status": selected, "correction": correct_code} + ) + if correction_entry: + st.markdown( + f'Your Response: **{correction_entry["value"]["status"]}** Correction: **{correction_entry["value"]["correction"]}**' + ) + # st.markdown( + # ",".join( + # [ + # "**" + str(p["real_idx"]) + "**" + # if p["real_idx"] == sample["real_idx"] + # else str(p["real_idx"]) + # for p in pnr_data + # ] + # ) + # ) + + +if __name__ == "__main__": + main() diff --git a/setup.py b/setup.py index 015e125..7111587 100644 --- a/setup.py +++ b/setup.py @@ -25,6 +25,14 @@ extra_requirements = { "typer[all]==0.1.1", "lenses @ git+https://github.com/ingolemo/python-lenses.git@b2a2a9aa5b61540992d70b2cf36008d0121e8948#egg=lenses", ], + "validation": [ + "rpyc~=4.1.4", + "tqdm~=4.39.0", + "librosa==0.7.2", + "pydub~=0.23.1", + "streamlit==0.58.0", + "stringcase==1.2.0" + ] # "train": [ # "torchaudio==0.5.0", # "torch-stft==0.1.4", diff --git a/streamlit.py b/streamlit.py new file mode 100644 index 0000000..58ba0bd --- /dev/null +++ b/streamlit.py @@ -0,0 +1,3 @@ +import runpy + +runpy.run_module("jasper.data_utils.validation.ui", run_name="__main__", alter_sys=True)