1. added streamlit based validation ui with mongodb datastore integration
2. fix asr wrong sample rate inference
3. update requirements
parent 61048f855e
commit 41074a1bca
@@ -62,7 +62,7 @@ class JasperASR(object):
         wf = wave.open(audio_file_path, "w")
         wf.setnchannels(1)
         wf.setsampwidth(2)
-        wf.setframerate(16000)
+        wf.setframerate(24000)
         wf.writeframesraw(audio_data)
         wf.close()
         manifest = {"audio_filepath": audio_file_path, "duration": 60, "text": "todo"}
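The sample-rate fix only changes the value written into the WAV header, so it matters that it matches the raw PCM actually dumped (the validation code below normalises audio to 24 kHz). With a 16 kHz header over 24 kHz samples, playback and feature extraction treat each clip as roughly 1.5x slower and longer. A minimal read-back check, sketched as an illustration rather than part of the commit ("sample.wav" is a placeholder path):

import wave

# Inspect what the header claims; under the old 16000 header a 24 kHz
# recording reports a duration 1.5x longer than it really is.
with wave.open("sample.wav", "rb") as wf:
    rate = wf.getframerate()
    print("frame rate:", rate)
    print("duration (s):", wf.getnframes() / rate)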
@@ -46,6 +46,7 @@ def analyze(
 from tqdm import tqdm
 from .utils import asr_data_writer
 from pydub import AudioSegment
+# from itertools import product, chain

 matplotlib.rcParams["agg.path.chunksize"] = 10000

@@ -0,0 +1,23 @@
+import os
+import logging
+import rpyc
+
+logging.basicConfig(
+    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+)
+logger = logging.getLogger(__name__)
+
+
+ASR_HOST = os.environ.get("JASPER_ASR_RPYC_HOST", "localhost")
+ASR_PORT = int(os.environ.get("JASPER_ASR_RPYC_PORT", "8045"))
+
+
+def transcribe_gen(asr_host=ASR_HOST, asr_port=ASR_PORT):
+    logger.info(f"connecting to asr server at {asr_host}:{asr_port}")
+    asr = rpyc.connect(asr_host, asr_port).root
+    logger.info(f"connected to asr server successfully")
+    return asr.transcribe
+
+
+transcriber_pretrained = transcribe_gen(asr_port=8044)
+transcriber_speller = transcribe_gen(asr_port=8045)
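This new module creates both transcribe handles eagerly at import time, so the two rpyc ASR services (pretrained on 8044, speller on 8045) must already be listening when it is imported. A usage sketch under that assumption; "sample.wav" is a placeholder path, and the import path mirrors the first UI below (the second UI imports it relatively as .jasper_client):

from pydub import AudioSegment
from jasper.client import transcriber_pretrained, transcriber_speller

# The handles accept raw PCM bytes; the second UI below normalises clips to
# mono, 16-bit, 24 kHz before passing raw_data, so the same is done here.
seg = (
    AudioSegment.from_wav("sample.wav")
    .set_channels(1)
    .set_sample_width(2)
    .set_frame_rate(24000)
)
print("pretrained:", transcriber_pretrained(seg.raw_data))
print("speller:", transcriber_speller(seg.raw_data))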
@@ -0,0 +1,73 @@
+import json
+from pathlib import Path
+import streamlit as st
+
+# import matplotlib.pyplot as plt
+# import numpy as np
+import librosa
+import librosa.display
+from pydub import AudioSegment
+from jasper.client import transcriber_pretrained, transcriber_speller
+
+# from pymongo import MongoClient
+
+st.title("ASR Speller Validation")
+dataset_path: Path = Path("/dataset/asr_data/call_alphanum_v3")
+manifest_path = dataset_path / Path("test_manifest.json")
+# print(manifest_path)
+with manifest_path.open("r") as pf:
+    pnr_jsonl = pf.readlines()
+pnr_data = [json.loads(i) for i in pnr_jsonl]
+
+
+def main():
+    # pnr_data = MongoClient("mongodb://localhost:27017/").test.asr_pnr
+    # sample_no = 0
+    sample_no = (
+        st.slider(
+            "Sample",
+            min_value=1,
+            max_value=len(pnr_data),
+            value=1,
+            step=1,
+            format=None,
+            key=None,
+        )
+        - 1
+    )
+    sample = pnr_data[sample_no]
+    st.write(f"Sample No: {sample_no+1} of {len(pnr_data)}")
+    audio_path = Path(sample["audio_filepath"])
+    # st.write(f"Audio Path:{audio_path}")
+    aud_seg = AudioSegment.from_wav(audio_path)  # .set_channels(1).set_sample_width(2).set_frame_rate(24000)
+    st.sidebar.text("Transcription")
+    st.sidebar.text(f"Pretrained:{transcriber_pretrained(aud_seg.raw_data)}")
+    st.sidebar.text(f"Speller:{transcriber_speller(aud_seg.raw_data)}")
+    st.sidebar.text(f"Expected: {audio_path.stem}")
+    spell_text = sample["text"]
+    st.sidebar.text(f"Spelled: {spell_text}")
+    st.audio(audio_path.open("rb"))
+    selected = st.radio("The Audio is", ("Correct", "Incorrect", "Inaudible"))
+    corrected = audio_path.stem
+    if selected == "Incorrect":
+        corrected = st.text_input("Actual:", value=corrected)
+    # content = ''
+    if sample_no > 0 and st.button("Previous"):
+        sample_no -= 1
+    if st.button("Next"):
+        st.write(sample_no, selected, corrected)
+        sample_no += 1
+
+    (y, sr) = librosa.load(audio_path)
+    librosa.display.waveplot(y=y, sr=sr)
+    # arr = np.random.normal(1, 1, size=100)
+    # plt.hist(arr, bins=20)
+    st.sidebar.pyplot()
+
+
+# def main():
+#     app()
+
+
+if __name__ == "__main__":
+    main()
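A caveat with this first version of the UI: Streamlit re-executes the whole script on every widget interaction, so the sample_no changes made behind the Previous/Next buttons do not survive the next rerun; only the slider's own value persists. A minimal illustration of that execution model (not from the commit):

import streamlit as st

count = 0                 # re-initialised on every rerun of the script
if st.button("Next"):
    count += 1            # becomes 1 for this run only
st.write(count)           # never exceeds 1 without external state

This is the gap the MongoDB-backed cursor in the second version of the UI, further below, is there to close.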
@@ -0,0 +1,38 @@
+import streamlit.ReportThread as ReportThread
+from streamlit.ScriptRequestQueue import RerunData
+from streamlit.ScriptRunner import RerunException
+from streamlit.server.Server import Server
+
+
+def rerun():
+    """Rerun a Streamlit app from the top!"""
+    widget_states = _get_widget_states()
+    raise RerunException(RerunData(widget_states))
+
+
+def _get_widget_states():
+    # Hack to get the session object from Streamlit.
+
+    ctx = ReportThread.get_report_ctx()
+
+    session = None
+
+    current_server = Server.get_current()
+    if hasattr(current_server, '_session_infos'):
+        # Streamlit < 0.56
+        session_infos = Server.get_current()._session_infos.values()
+    else:
+        session_infos = Server.get_current()._session_info_by_id.values()
+
+    for session_info in session_infos:
+        if session_info.session.enqueue == ctx.enqueue:
+            session = session_info.session
+
+    if session is None:
+        raise RuntimeError(
+            "Oh noes. Couldn't get your Streamlit Session object"
+            "Are you doing something fancy with threads?"
+        )
+    # Got the session object!
+
+    return session._widget_states
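This helper reaches into pre-1.0 Streamlit internals (ReportThread, ScriptRunner, Server), which is why setup.py below pins streamlit==0.58.0; raising RerunException makes Streamlit restart the script from the top with the current widget states. A usage sketch with a hypothetical button; the flat import path is an assumption, the UI below imports it relatively as .st_rerun:

import streamlit as st
from st_rerun import rerun  # assumed to sit next to the app script

if st.button("Skip to next sample"):
    # persist whatever must survive the restart (the UI below writes its
    # cursor to MongoDB first), then force a fresh top-to-bottom run
    rerun()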
@@ -0,0 +1,171 @@
+import json
+from io import BytesIO
+from pathlib import Path
+
+import streamlit as st
+from nemo.collections.asr.metrics import word_error_rate
+import librosa
+import librosa.display
+import matplotlib.pyplot as plt
+from tqdm import tqdm
+from pydub import AudioSegment
+import pymongo
+from .jasper_client import transcriber_pretrained, transcriber_speller
+from .st_rerun import rerun
+
+st.title("ASR Speller Validation")
+
+
+def clear_mongo_corrections():
+    col = pymongo.MongoClient("mongodb://localhost:27017/").test.asr_validation
+    col.delete_many({"type": "correction"})
+
+
+def preprocess_datapoint(idx, sample):
+    res = dict(sample)
+    res["real_idx"] = idx
+    audio_path = Path(sample["audio_filepath"])
+    res["audio_path"] = audio_path
+    res["gold_chars"] = audio_path.stem
+    res["gold_phone"] = sample["text"]
+    aud_seg = (
+        AudioSegment.from_wav(audio_path)
+        .set_channels(1)
+        .set_sample_width(2)
+        .set_frame_rate(24000)
+    )
+    res["pretrained_asr"] = transcriber_pretrained(aud_seg.raw_data)
+    res["speller_asr"] = transcriber_speller(aud_seg.raw_data)
+    res["wer"] = word_error_rate([res["gold_phone"]], [res["speller_asr"]])
+    (y, sr) = librosa.load(audio_path)
+    plt.tight_layout()
+    librosa.display.waveplot(y=y, sr=sr)
+    wav_plot_f = BytesIO()
+    plt.savefig(wav_plot_f, format="png", dpi=50)
+    plt.close()
+    wav_plot_f.seek(0)
+    res["plot_png"] = wav_plot_f
+    return res
+
+
+if not hasattr(st, "mongo_connected"):
+    st.mongoclient = pymongo.MongoClient(
+        "mongodb://localhost:27017/"
+    ).test.asr_validation
+    mongo_conn = st.mongoclient
+
+    def current_cursor_fn():
+        # mongo_conn = st.mongoclient
+        cursor_obj = mongo_conn.find_one({"type": "current_cursor"})
+        cursor_val = cursor_obj["cursor"]
+        return cursor_val
+
+    def update_cursor_fn(val=0):
+        mongo_conn.find_one_and_update(
+            {"type": "current_cursor"},
+            {"$set": {"type": "current_cursor", "cursor": val}},
+            upsert=True,
+        )
+        rerun()
+
+    def get_correction_entry_fn(code):
+        # mongo_conn = st.mongoclient
+        # cursor_obj = mongo_conn.find_one({"type": "correction", "code": code})
+        # cursor_val = cursor_obj["cursor"]
+        return mongo_conn.find_one(
+            {"type": "correction", "code": code}, projection={"_id": False}
+        )
+
+    def update_entry_fn(code, value):
+        mongo_conn.find_one_and_update(
+            {"type": "correction", "code": code},
+            {"$set": {"value": value}},
+            upsert=True,
+        )
+        rerun()
+
+    cursor_obj = mongo_conn.find_one({"type": "current_cursor"})
+    if not cursor_obj:
+        update_cursor_fn(0)
+    st.get_current_cursor = current_cursor_fn
+    st.update_cursor = update_cursor_fn
+    st.get_correction_entry = get_correction_entry_fn
+    st.update_entry = update_entry_fn
+    st.mongo_connected = True
+
+
+@st.cache(hash_funcs={"rpyc.core.netref.builtins.method": lambda _: None})
+def preprocess_dataset(dataset_path: Path = Path("/dataset/asr_data/call_alphanum_v3")):
+    print("misssed cache : preprocess_dataset")
+    dataset_path: Path = Path("/dataset/asr_data/call_alphanum_v3")
+    manifest_path = dataset_path / Path("test_manifest.json")
+    with manifest_path.open("r") as pf:
+        pnr_jsonl = pf.readlines()
+    pnr_data = [
+        preprocess_datapoint(i, json.loads(v))
+        for i, v in enumerate(tqdm(pnr_jsonl))
+    ]
+    result = sorted(pnr_data, key=lambda x: x["wer"], reverse=True)
+    return result
+
+
+def main():
+    pnr_data = preprocess_dataset()
+    sample_no = st.get_current_cursor()
+    sample = pnr_data[sample_no]
+    st.markdown(
+        f"{sample_no+1} of {len(pnr_data)} : **{sample['gold_chars']}** spelled *{sample['gold_phone']}*"
+    )
+    new_sample = st.number_input(
+        "Go To Sample:", value=sample_no + 1, min_value=1, max_value=len(pnr_data)
+    )
+    if new_sample != sample_no + 1:
+        st.update_cursor(new_sample - 1)
+    st.sidebar.title(f"Details: [{sample['real_idx']}]")
+    st.sidebar.markdown(f"Gold: **{sample['gold_chars']}**")
+    st.sidebar.markdown(f"Expected Speech: *{sample['gold_phone']}*")
+    st.sidebar.title("Results:")
+    st.sidebar.text(f"Pretrained:{sample['pretrained_asr']}")
+    st.sidebar.text(f"Speller:{sample['speller_asr']}")
+
+    st.sidebar.title(f"WER: {sample['wer']:.2f}%")
+    # (y, sr) = librosa.load(sample["audio_path"])
+    # librosa.display.waveplot(y=y, sr=sr)
+    # st.sidebar.pyplot(fig=sample["plot_fig"])
+    st.sidebar.image(sample["plot_png"])
+    st.audio(sample["audio_path"].open("rb"))
+    corrected = sample["gold_chars"]
+    correction_entry = st.get_correction_entry(sample["gold_chars"])
+    selected_idx = 0
+    options = ("Correct", "Incorrect", "Inaudible")
+    if correction_entry:
+        selected_idx = options.index(correction_entry["value"]["status"])
+        corrected = correction_entry["value"]["correction"]
+    selected = st.radio("The Audio is", options, index=selected_idx)
+    if selected == "Incorrect":
+        corrected = st.text_input("Actual:", value=corrected)
+    if selected == "Inaudible":
+        corrected = ""
+    if st.button("Submit"):
+        correct_code = corrected.replace(" ", "").upper()
+        st.update_entry(
+            sample["gold_chars"], {"status": selected, "correction": correct_code}
+        )
+    if correction_entry:
+        st.markdown(
+            f'Your Response: **{correction_entry["value"]["status"]}** Correction: **{correction_entry["value"]["correction"]}**'
+        )
+    # st.markdown(
+    #     ",".join(
+    #         [
+    #             "**" + str(p["real_idx"]) + "**"
+    #             if p["real_idx"] == sample["real_idx"]
+    #             else str(p["real_idx"])
+    #             for p in pnr_data
+    #         ]
+    #     )
+    # )
+
+
+if __name__ == "__main__":
+    main()
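Judging from the update helpers above, the MongoDB datastore holds two document shapes in test.asr_validation: a single {"type": "current_cursor", "cursor": <int>} document, and one {"type": "correction", "code": <gold chars>, "value": {"status": ..., "correction": ...}} document per reviewed sample. A read-only sketch (assuming the same localhost instance) for dumping reviewer decisions:

import pymongo

col = pymongo.MongoClient("mongodb://localhost:27017/").test.asr_validation
print(col.find_one({"type": "current_cursor"}))
for doc in col.find({"type": "correction"}, projection={"_id": False}):
    print(doc["code"], doc["value"]["status"], doc["value"]["correction"])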
setup.py (+8)
@@ -25,6 +25,14 @@ extra_requirements = {
         "typer[all]==0.1.1",
         "lenses @ git+https://github.com/ingolemo/python-lenses.git@b2a2a9aa5b61540992d70b2cf36008d0121e8948#egg=lenses",
     ],
+    "validation": [
+        "rpyc~=4.1.4",
+        "tqdm~=4.39.0",
+        "librosa==0.7.2",
+        "pydub~=0.23.1",
+        "streamlit==0.58.0",
+        "stringcase==1.2.0"
+    ]
     # "train": [
     # "torchaudio==0.5.0",
     # "torch-stft==0.1.4",
@@ -0,0 +1,3 @@
+import runpy
+
+runpy.run_module("jasper.data_utils.validation.ui", run_name="__main__", alter_sys=True)