diff --git a/jasper/data/validation/process.py b/jasper/data/validation/process.py index b9c7d2c..c401d87 100644 --- a/jasper/data/validation/process.py +++ b/jasper/data/validation/process.py @@ -16,7 +16,7 @@ from ..utils import ( app = typer.Typer() -def preprocess_datapoint(idx, rel_root, sample, use_domain_asr, annotation_only): +def preprocess_datapoint(idx, rel_root, sample, use_domain_asr, annotation_only, enable_plots): import matplotlib.pyplot as plt import librosa import librosa.display @@ -48,19 +48,20 @@ def preprocess_datapoint(idx, rel_root, sample, use_domain_asr, annotation_only) res["domain_wer"] = word_error_rate( [res["spoken"]], [res["pretrained_asr"]] ) - wav_plot_path = ( - rel_root / Path("wav_plots") / Path(audio_path.name).with_suffix(".png") - ) - if not wav_plot_path.exists(): - fig = plt.Figure() - ax = fig.add_subplot() - (y, sr) = librosa.load(audio_path) - librosa.display.waveplot(y=y, sr=sr, ax=ax) - with wav_plot_path.open("wb") as wav_plot_f: - fig.set_tight_layout(True) - fig.savefig(wav_plot_f, format="png", dpi=50) - # fig.close() - res["plot_path"] = str(wav_plot_path) + if enable_plots: + wav_plot_path = ( + rel_root / Path("wav_plots") / Path(audio_path.name).with_suffix(".png") + ) + if not wav_plot_path.exists(): + fig = plt.Figure() + ax = fig.add_subplot() + (y, sr) = librosa.load(audio_path) + librosa.display.waveplot(y=y, sr=sr, ax=ax) + with wav_plot_path.open("wb") as wav_plot_f: + fig.set_tight_layout(True) + fig.savefig(wav_plot_f, format="png", dpi=50) + # fig.close() + res["plot_path"] = str(wav_plot_path) return res except BaseException as e: print(f'failed on {idx}: {sample["audio_filepath"]} with {e}') @@ -76,6 +77,7 @@ def dump_validation_ui_data( ), use_domain_asr: bool = True, annotation_only: bool = True, + enable_plots: bool = True, ): from concurrent.futures import ThreadPoolExecutor from functools import partial @@ -93,6 +95,7 @@ def dump_validation_ui_data( json.loads(v), use_domain_asr, annotation_only, + enable_plots, ) for i, v in enumerate(pnr_jsonl) ] @@ -122,6 +125,7 @@ def dump_validation_ui_data( "use_domain_asr": use_domain_asr, "data": result, "annotation_only": annotation_only, + "enable_plots": enable_plots, } ExtendedPath(dump_path).write_json(ui_config) diff --git a/jasper/data/validation/ui.py b/jasper/data/validation/ui.py index 7ba67d3..f013677 100644 --- a/jasper/data/validation/ui.py +++ b/jasper/data/validation/ui.py @@ -63,6 +63,7 @@ def main(manifest: Path): asr_data = ui_config["data"] use_domain_asr = ui_config.get("use_domain_asr", True) annotation_only = ui_config.get("annotation_only", False) + enable_plots = ui_config.get("enable_plots", True) sample_no = st.get_current_cursor() if len(asr_data) - 1 < sample_no or sample_no < 0: print("Invalid samplno resetting to 0") @@ -92,7 +93,8 @@ def main(manifest: Path): st.sidebar.title(f"Speller WER: {sample['domain_wer']:.2f}%") else: st.sidebar.title(f"Pretrained WER: {sample['pretrained_wer']:.2f}%") - st.sidebar.image(Path(sample["plot_path"]).read_bytes()) + if enable_plots: + st.sidebar.image(Path(sample["plot_path"]).read_bytes()) st.audio(Path(sample["audio_path"]).open("rb")) # set default to text corrected = sample["text"]