1. refactored streamlit code
2. fixed issues in data manifest handling
parent
41074a1bca
commit
4fd05a56d0
|
|
@ -10,44 +10,15 @@ import matplotlib.pyplot as plt
|
|||
from tqdm import tqdm
|
||||
from pydub import AudioSegment
|
||||
import pymongo
|
||||
import typer
|
||||
from .jasper_client import transcriber_pretrained, transcriber_speller
|
||||
from .st_rerun import rerun
|
||||
|
||||
|
||||
app = typer.Typer()
|
||||
st.title("ASR Speller Validation")
|
||||
|
||||
|
||||
def clear_mongo_corrections():
    """Delete every stored correction from the ASR validation collection.

    Connects to the MongoDB server at ``mongodb://localhost:27017/`` and
    removes all documents in ``test.asr_validation`` whose ``type`` field is
    ``"correction"``.
    """
    # Use the client as a context manager so the connection opened for this
    # one-off cleanup is closed again instead of lingering until GC.
    with pymongo.MongoClient("mongodb://localhost:27017/") as client:
        client.test.asr_validation.delete_many({"type": "correction"})
|
||||
|
||||
|
||||
def preprocess_datapoint(idx, sample):
    """Enrich one manifest entry with ASR transcripts, a WER score and a plot.

    Args:
        idx: Index of the entry in the manifest; stored as ``real_idx``.
        sample: Mapping that provides at least the ``audio_filepath`` and
            ``text`` keys.

    Returns:
        A dict copy of ``sample`` with these keys added: ``real_idx``,
        ``audio_path``, ``gold_chars`` (the audio file's stem),
        ``gold_phone`` (``sample["text"]``), ``pretrained_asr``,
        ``speller_asr``, ``wer`` and ``plot_png`` (a ``BytesIO`` rewound to
        the start of a PNG waveform image).
    """
    res = dict(sample)
    res["real_idx"] = idx
    audio_path = Path(sample["audio_filepath"])
    res["audio_path"] = audio_path
    res["gold_chars"] = audio_path.stem
    res["gold_phone"] = sample["text"]

    # Both transcribers receive raw bytes of the audio converted to mono,
    # 2-byte samples, 24000 Hz.
    aud_seg = (
        AudioSegment.from_wav(audio_path)
        .set_channels(1)
        .set_sample_width(2)
        .set_frame_rate(24000)
    )
    res["pretrained_asr"] = transcriber_pretrained(aud_seg.raw_data)
    res["speller_asr"] = transcriber_speller(aud_seg.raw_data)
    res["wer"] = word_error_rate([res["gold_phone"]], [res["speller_asr"]])

    (y, sr) = librosa.load(audio_path)
    wav_plot_f = BytesIO()
    try:
        librosa.display.waveplot(y=y, sr=sr)
        # Lay out only after the waveform is drawn; calling tight_layout()
        # on a figure without axes has nothing to adjust.
        plt.tight_layout()
        plt.savefig(wav_plot_f, format="png", dpi=50)
    finally:
        # Always release the pyplot figure, even if plotting/saving failed,
        # so figures do not accumulate across a whole dataset.
        plt.close()
    wav_plot_f.seek(0)
    res["plot_png"] = wav_plot_f
    return res
|
||||
|
||||
|
||||
if not hasattr(st, "mongo_connected"):
|
||||
st.mongoclient = pymongo.MongoClient(
|
||||
"mongodb://localhost:27017/"
|
||||
|
|
@ -94,23 +65,54 @@ if not hasattr(st, "mongo_connected"):
|
|||
st.mongo_connected = True
|
||||
|
||||
|
||||
def clear_mongo_corrections():
    """Delete every stored correction from the ASR validation collection.

    Connects to the MongoDB server at ``mongodb://localhost:27017/`` and
    removes all documents in ``test.asr_validation`` whose ``type`` field is
    ``"correction"``.
    """
    # Use the client as a context manager so the connection opened for this
    # one-off cleanup is closed again instead of lingering until GC.
    with pymongo.MongoClient("mongodb://localhost:27017/") as client:
        client.test.asr_validation.delete_many({"type": "correction"})
|
||||
|
||||
|
||||
def preprocess_datapoint(idx, rel, sample):
    """Enrich one manifest entry with ASR transcripts, a WER score and a plot.

    Args:
        idx: Index of the entry in the manifest; stored as ``real_idx``.
        rel: Base directory that ``sample["audio_filepath"]`` is joined onto
            (``rel / audio_filepath``).
        sample: Mapping that provides at least the ``audio_filepath`` and
            ``text`` keys.

    Returns:
        A dict copy of ``sample`` with these keys added: ``real_idx``,
        ``audio_path``, ``gold_chars`` (the audio file's stem),
        ``gold_phone`` (``sample["text"]``), ``pretrained_asr``,
        ``speller_asr``, ``wer`` and ``plot_png`` (a ``BytesIO`` rewound to
        the start of a PNG waveform image).
    """
    res = dict(sample)
    res["real_idx"] = idx
    audio_path = rel / Path(sample["audio_filepath"])
    res["audio_path"] = audio_path
    res["gold_chars"] = audio_path.stem
    res["gold_phone"] = sample["text"]

    # Both transcribers receive raw bytes of the audio converted to mono,
    # 2-byte samples, 24000 Hz.
    aud_seg = (
        AudioSegment.from_wav(audio_path)
        .set_channels(1)
        .set_sample_width(2)
        .set_frame_rate(24000)
    )
    res["pretrained_asr"] = transcriber_pretrained(aud_seg.raw_data)
    res["speller_asr"] = transcriber_speller(aud_seg.raw_data)
    res["wer"] = word_error_rate([res["gold_phone"]], [res["speller_asr"]])

    (y, sr) = librosa.load(audio_path)
    wav_plot_f = BytesIO()
    try:
        librosa.display.waveplot(y=y, sr=sr)
        # Lay out only after the waveform is drawn; calling tight_layout()
        # on a figure without axes has nothing to adjust.
        plt.tight_layout()
        plt.savefig(wav_plot_f, format="png", dpi=50)
    finally:
        # Always release the pyplot figure, even if plotting/saving failed,
        # so figures do not accumulate across a whole dataset.
        plt.close()
    wav_plot_f.seek(0)
    res["plot_png"] = wav_plot_f
    return res
|
||||
|
||||
|
||||
@st.cache(hash_funcs={"rpyc.core.netref.builtins.method": lambda _: None})
def preprocess_dataset(data_manifest_path: Path):
    """Load a JSON-lines data manifest and preprocess every entry in it.

    Args:
        data_manifest_path: Path to the manifest file, one JSON object per
            line. Each entry is handed to ``preprocess_datapoint`` together
            with the manifest's parent directory as the base for its audio
            path.

    Returns:
        The list of preprocessed entries, sorted by their ``"wer"`` value in
        descending order.
    """
    typer.echo(f"Using data manifest:{data_manifest_path}")
    # Read everything up front so the file is closed before the slow
    # per-sample processing starts.
    with data_manifest_path.open("r") as pf:
        pnr_jsonl = pf.readlines()
    pnr_data = [
        preprocess_datapoint(i, data_manifest_path.parent, json.loads(v))
        for i, v in enumerate(tqdm(pnr_jsonl))
        # Blank lines (e.g. a trailing newline) are not JSON; skip them
        # while keeping ``i`` equal to the line's position in the file.
        if v.strip()
    ]
    return sorted(pnr_data, key=lambda x: x["wer"], reverse=True)
|
||||
|
||||
|
||||
def main():
|
||||
pnr_data = preprocess_dataset()
|
||||
@app.command()
|
||||
def main(manifest: Path):
|
||||
pnr_data = preprocess_dataset(manifest)
|
||||
sample_no = st.get_current_cursor()
|
||||
sample = pnr_data[sample_no]
|
||||
st.markdown(
|
||||
|
|
@ -128,7 +130,7 @@ def main():
|
|||
st.sidebar.text(f"Pretrained:{sample['pretrained_asr']}")
|
||||
st.sidebar.text(f"Speller:{sample['speller_asr']}")
|
||||
|
||||
st.sidebar.title(f"WER: {sample['wer']:.2f}%")
|
||||
st.sidebar.title(f"Speller WER: {sample['wer']:.2f}%")
|
||||
# (y, sr) = librosa.load(sample["audio_path"])
|
||||
# librosa.display.waveplot(y=y, sr=sr)
|
||||
# st.sidebar.pyplot(fig=sample["plot_fig"])
|
||||
|
|
@ -168,4 +170,7 @@ def main():
|
|||
|
||||
|
||||
if __name__ == "__main__":
    # Run through the typer app so ``main`` receives its ``manifest``
    # argument from the command line; calling ``main()`` directly would fail
    # because that argument is required.
    try:
        app()
    except SystemExit:
        # The CLI runner finishes by raising SystemExit; it is deliberately
        # swallowed here. NOTE(review): presumably so the hosting Streamlit
        # script run is not aborted -- confirm.
        pass
|
||||
|
|
|
|||
Loading…
Reference in New Issue