1. refactored streamlit code

2. fixed issues in data manifest handling
Malar Kannan 2020-04-29 17:22:45 +05:30
parent 41074a1bca
commit 4fd05a56d0
1 changed file with 47 additions and 42 deletions

View File

@ -10,44 +10,15 @@ import matplotlib.pyplot as plt
from tqdm import tqdm
from pydub import AudioSegment
import pymongo
import typer
from .jasper_client import transcriber_pretrained, transcriber_speller
from .st_rerun import rerun
app = typer.Typer()
st.title("ASR Speller Validation")
def clear_mongo_corrections():
    """Delete every stored correction from the ASR validation collection.

    Connects to the MongoDB server at ``mongodb://localhost:27017/`` and
    removes all documents in ``test.asr_validation`` whose ``type`` field
    equals ``"correction"``.
    """
    client = pymongo.MongoClient("mongodb://localhost:27017/")
    try:
        client.test.asr_validation.delete_many({"type": "correction"})
    finally:
        # This client exists only for the duration of the call; close it so
        # its sockets are released instead of accumulating on every call.
        client.close()
def preprocess_datapoint(idx, sample):
    """Build the validation record for a single manifest entry.

    Args:
        idx: Position of the entry in the manifest; stored as ``real_idx``.
        sample: Mapping with at least ``audio_filepath`` and ``text`` keys.

    Returns:
        A copy of ``sample`` extended with ``real_idx``, ``audio_path``,
        ``gold_chars`` (the audio file stem), ``gold_phone`` (the manifest
        text), ``pretrained_asr`` and ``speller_asr`` (transcriptions),
        ``wer`` (``word_error_rate`` of the speller output against
        ``gold_phone``) and ``plot_png`` (a ``BytesIO`` holding a PNG of the
        waveform, rewound to the start).
    """
    res = dict(sample)
    res["real_idx"] = idx
    audio_path = Path(sample["audio_filepath"])
    res["audio_path"] = audio_path
    res["gold_chars"] = audio_path.stem
    res["gold_phone"] = sample["text"]

    # Both transcribers receive the same raw bytes: 1 channel, 2-byte
    # samples, 24000 Hz frame rate.
    aud_seg = (
        AudioSegment.from_wav(audio_path)
        .set_channels(1)
        .set_sample_width(2)
        .set_frame_rate(24000)
    )
    res["pretrained_asr"] = transcriber_pretrained(aud_seg.raw_data)
    res["speller_asr"] = transcriber_speller(aud_seg.raw_data)
    res["wer"] = word_error_rate([res["gold_phone"]], [res["speller_asr"]])

    (y, sr) = librosa.load(audio_path)
    wav_plot_f = BytesIO()
    try:
        librosa.display.waveplot(y=y, sr=sr)
        # tight_layout only adjusts axes that already exist, so it has to
        # run after the waveform has been drawn, not before.
        plt.tight_layout()
        plt.savefig(wav_plot_f, format="png", dpi=50)
    finally:
        # Always release the figure, even when plotting or saving fails,
        # so figures do not pile up across the dataset.
        plt.close()
    wav_plot_f.seek(0)
    res["plot_png"] = wav_plot_f
    return res
if not hasattr(st, "mongo_connected"):
st.mongoclient = pymongo.MongoClient(
"mongodb://localhost:27017/"
@ -94,23 +65,54 @@ if not hasattr(st, "mongo_connected"):
st.mongo_connected = True
def clear_mongo_corrections():
    """Remove all correction documents from ``test.asr_validation``.

    Opens a connection to ``mongodb://localhost:27017/``, deletes every
    document whose ``type`` is ``"correction"``, and closes the connection
    again before returning.
    """
    client = pymongo.MongoClient("mongodb://localhost:27017/")
    try:
        collection = client.test.asr_validation
        collection.delete_many({"type": "correction"})
    finally:
        # A new client is created per call, so it must be closed here;
        # otherwise each call leaves an open connection pool behind.
        client.close()
def preprocess_datapoint(idx, rel, sample):
    """Build the validation record for a single manifest entry.

    Args:
        idx: Position of the entry in the manifest; stored as ``real_idx``.
        rel: Base directory (a ``Path``) that ``sample["audio_filepath"]``
            is joined onto to locate the audio file.
        sample: Mapping with at least ``audio_filepath`` and ``text`` keys.

    Returns:
        A copy of ``sample`` extended with ``real_idx``, ``audio_path``,
        ``gold_chars`` (the audio file stem), ``gold_phone`` (the manifest
        text), ``pretrained_asr`` and ``speller_asr`` (transcriptions),
        ``wer`` (``word_error_rate`` of the speller output against
        ``gold_phone``) and ``plot_png`` (a ``BytesIO`` holding a PNG of the
        waveform, rewound to the start).
    """
    res = dict(sample)
    res["real_idx"] = idx
    audio_path = rel / Path(sample["audio_filepath"])
    res["audio_path"] = audio_path
    res["gold_chars"] = audio_path.stem
    res["gold_phone"] = sample["text"]

    # Both transcribers receive the same raw bytes: 1 channel, 2-byte
    # samples, 24000 Hz frame rate.
    aud_seg = (
        AudioSegment.from_wav(audio_path)
        .set_channels(1)
        .set_sample_width(2)
        .set_frame_rate(24000)
    )
    res["pretrained_asr"] = transcriber_pretrained(aud_seg.raw_data)
    res["speller_asr"] = transcriber_speller(aud_seg.raw_data)
    res["wer"] = word_error_rate([res["gold_phone"]], [res["speller_asr"]])

    (y, sr) = librosa.load(audio_path)
    wav_plot_f = BytesIO()
    try:
        librosa.display.waveplot(y=y, sr=sr)
        # tight_layout only adjusts axes that already exist, so it has to
        # run after the waveform has been drawn, not before.
        plt.tight_layout()
        plt.savefig(wav_plot_f, format="png", dpi=50)
    finally:
        # Always release the figure, even when plotting or saving fails,
        # so figures do not pile up across the dataset.
        plt.close()
    wav_plot_f.seek(0)
    res["plot_png"] = wav_plot_f
    return res
@st.cache(hash_funcs={"rpyc.core.netref.builtins.method": lambda _: None})
def preprocess_dataset(dataset_path: Path = Path("/dataset/asr_data/call_alphanum_v3")):
print("misssed cache : preprocess_dataset")
dataset_path: Path = Path("/dataset/asr_data/call_alphanum_v3")
manifest_path = dataset_path / Path("test_manifest.json")
with manifest_path.open("r") as pf:
def preprocess_dataset(data_manifest_path: Path):
typer.echo(f"Using data manifest:{data_manifest_path}")
with data_manifest_path.open("r") as pf:
pnr_jsonl = pf.readlines()
pnr_data = [
preprocess_datapoint(i, json.loads(v))
preprocess_datapoint(i, data_manifest_path.parent, json.loads(v))
for i, v in enumerate(tqdm(pnr_jsonl))
]
result = sorted(pnr_data, key=lambda x: x["wer"], reverse=True)
return result
def main():
pnr_data = preprocess_dataset()
@app.command()
def main(manifest: Path):
pnr_data = preprocess_dataset(manifest)
sample_no = st.get_current_cursor()
sample = pnr_data[sample_no]
st.markdown(
@ -128,7 +130,7 @@ def main():
st.sidebar.text(f"Pretrained:{sample['pretrained_asr']}")
st.sidebar.text(f"Speller:{sample['speller_asr']}")
st.sidebar.title(f"WER: {sample['wer']:.2f}%")
st.sidebar.title(f"Speller WER: {sample['wer']:.2f}%")
# (y, sr) = librosa.load(sample["audio_path"])
# librosa.display.waveplot(y=y, sr=sr)
# st.sidebar.pyplot(fig=sample["plot_fig"])
@ -168,4 +170,7 @@ def main():
if __name__ == "__main__":
main()
try:
app()
except SystemExit:
pass