added option to disable plots during validation

Malar Kannan 2020-05-27 15:43:03 +05:30
parent 7ff2db3e2e
commit a38789d0c3
2 changed files with 21 additions and 15 deletions

View File

@@ -16,7 +16,7 @@ from ..utils import (
app = typer.Typer() app = typer.Typer()
def preprocess_datapoint(idx, rel_root, sample, use_domain_asr, annotation_only): def preprocess_datapoint(idx, rel_root, sample, use_domain_asr, annotation_only, enable_plots):
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import librosa import librosa
import librosa.display import librosa.display
@@ -48,19 +48,20 @@ def preprocess_datapoint(idx, rel_root, sample, use_domain_asr, annotation_only)
res["domain_wer"] = word_error_rate( res["domain_wer"] = word_error_rate(
[res["spoken"]], [res["pretrained_asr"]] [res["spoken"]], [res["pretrained_asr"]]
) )
wav_plot_path = ( if enable_plots:
rel_root / Path("wav_plots") / Path(audio_path.name).with_suffix(".png") wav_plot_path = (
) rel_root / Path("wav_plots") / Path(audio_path.name).with_suffix(".png")
if not wav_plot_path.exists(): )
fig = plt.Figure() if not wav_plot_path.exists():
ax = fig.add_subplot() fig = plt.Figure()
(y, sr) = librosa.load(audio_path) ax = fig.add_subplot()
librosa.display.waveplot(y=y, sr=sr, ax=ax) (y, sr) = librosa.load(audio_path)
with wav_plot_path.open("wb") as wav_plot_f: librosa.display.waveplot(y=y, sr=sr, ax=ax)
fig.set_tight_layout(True) with wav_plot_path.open("wb") as wav_plot_f:
fig.savefig(wav_plot_f, format="png", dpi=50) fig.set_tight_layout(True)
# fig.close() fig.savefig(wav_plot_f, format="png", dpi=50)
res["plot_path"] = str(wav_plot_path) # fig.close()
res["plot_path"] = str(wav_plot_path)
return res return res
except BaseException as e: except BaseException as e:
print(f'failed on {idx}: {sample["audio_filepath"]} with {e}') print(f'failed on {idx}: {sample["audio_filepath"]} with {e}')
@@ -76,6 +77,7 @@ def dump_validation_ui_data(
), ),
use_domain_asr: bool = True, use_domain_asr: bool = True,
annotation_only: bool = True, annotation_only: bool = True,
enable_plots: bool = True,
): ):
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from functools import partial from functools import partial
@@ -93,6 +95,7 @@ def dump_validation_ui_data(
json.loads(v), json.loads(v),
use_domain_asr, use_domain_asr,
annotation_only, annotation_only,
enable_plots,
) )
for i, v in enumerate(pnr_jsonl) for i, v in enumerate(pnr_jsonl)
] ]
@@ -122,6 +125,7 @@ def dump_validation_ui_data(
"use_domain_asr": use_domain_asr, "use_domain_asr": use_domain_asr,
"data": result, "data": result,
"annotation_only": annotation_only, "annotation_only": annotation_only,
"enable_plots": enable_plots,
} }
ExtendedPath(dump_path).write_json(ui_config) ExtendedPath(dump_path).write_json(ui_config)

View File

@@ -63,6 +63,7 @@ def main(manifest: Path):
asr_data = ui_config["data"] asr_data = ui_config["data"]
use_domain_asr = ui_config.get("use_domain_asr", True) use_domain_asr = ui_config.get("use_domain_asr", True)
annotation_only = ui_config.get("annotation_only", False) annotation_only = ui_config.get("annotation_only", False)
enable_plots = ui_config.get("enable_plots", True)
sample_no = st.get_current_cursor() sample_no = st.get_current_cursor()
if len(asr_data) - 1 < sample_no or sample_no < 0: if len(asr_data) - 1 < sample_no or sample_no < 0:
print("Invalid samplno resetting to 0") print("Invalid samplno resetting to 0")
@@ -92,7 +93,8 @@ def main(manifest: Path):
st.sidebar.title(f"Speller WER: {sample['domain_wer']:.2f}%") st.sidebar.title(f"Speller WER: {sample['domain_wer']:.2f}%")
else: else:
st.sidebar.title(f"Pretrained WER: {sample['pretrained_wer']:.2f}%") st.sidebar.title(f"Pretrained WER: {sample['pretrained_wer']:.2f}%")
st.sidebar.image(Path(sample["plot_path"]).read_bytes()) if enable_plots:
st.sidebar.image(Path(sample["plot_path"]).read_bytes())
st.audio(Path(sample["audio_path"]).open("rb")) st.audio(Path(sample["audio_path"]).open("rb"))
# set default to text # set default to text
corrected = sample["text"] corrected = sample["text"]