Add option to disable plots during validation
parent
7ff2db3e2e
commit
a38789d0c3
|
|
@ -16,7 +16,7 @@ from ..utils import (
|
|||
app = typer.Typer()
|
||||
|
||||
|
||||
def preprocess_datapoint(idx, rel_root, sample, use_domain_asr, annotation_only):
|
||||
def preprocess_datapoint(idx, rel_root, sample, use_domain_asr, annotation_only, enable_plots):
|
||||
import matplotlib.pyplot as plt
|
||||
import librosa
|
||||
import librosa.display
|
||||
|
|
@ -48,19 +48,20 @@ def preprocess_datapoint(idx, rel_root, sample, use_domain_asr, annotation_only)
|
|||
res["domain_wer"] = word_error_rate(
|
||||
[res["spoken"]], [res["pretrained_asr"]]
|
||||
)
|
||||
wav_plot_path = (
|
||||
rel_root / Path("wav_plots") / Path(audio_path.name).with_suffix(".png")
|
||||
)
|
||||
if not wav_plot_path.exists():
|
||||
fig = plt.Figure()
|
||||
ax = fig.add_subplot()
|
||||
(y, sr) = librosa.load(audio_path)
|
||||
librosa.display.waveplot(y=y, sr=sr, ax=ax)
|
||||
with wav_plot_path.open("wb") as wav_plot_f:
|
||||
fig.set_tight_layout(True)
|
||||
fig.savefig(wav_plot_f, format="png", dpi=50)
|
||||
# fig.close()
|
||||
res["plot_path"] = str(wav_plot_path)
|
||||
if enable_plots:
|
||||
wav_plot_path = (
|
||||
rel_root / Path("wav_plots") / Path(audio_path.name).with_suffix(".png")
|
||||
)
|
||||
if not wav_plot_path.exists():
|
||||
fig = plt.Figure()
|
||||
ax = fig.add_subplot()
|
||||
(y, sr) = librosa.load(audio_path)
|
||||
librosa.display.waveplot(y=y, sr=sr, ax=ax)
|
||||
with wav_plot_path.open("wb") as wav_plot_f:
|
||||
fig.set_tight_layout(True)
|
||||
fig.savefig(wav_plot_f, format="png", dpi=50)
|
||||
# fig.close()
|
||||
res["plot_path"] = str(wav_plot_path)
|
||||
return res
|
||||
except BaseException as e:
|
||||
print(f'failed on {idx}: {sample["audio_filepath"]} with {e}')
|
||||
|
|
@ -76,6 +77,7 @@ def dump_validation_ui_data(
|
|||
),
|
||||
use_domain_asr: bool = True,
|
||||
annotation_only: bool = True,
|
||||
enable_plots: bool = True,
|
||||
):
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from functools import partial
|
||||
|
|
@ -93,6 +95,7 @@ def dump_validation_ui_data(
|
|||
json.loads(v),
|
||||
use_domain_asr,
|
||||
annotation_only,
|
||||
enable_plots,
|
||||
)
|
||||
for i, v in enumerate(pnr_jsonl)
|
||||
]
|
||||
|
|
@ -122,6 +125,7 @@ def dump_validation_ui_data(
|
|||
"use_domain_asr": use_domain_asr,
|
||||
"data": result,
|
||||
"annotation_only": annotation_only,
|
||||
"enable_plots": enable_plots,
|
||||
}
|
||||
ExtendedPath(dump_path).write_json(ui_config)
|
||||
|
||||
|
|
|
|||
|
|
@ -63,6 +63,7 @@ def main(manifest: Path):
|
|||
asr_data = ui_config["data"]
|
||||
use_domain_asr = ui_config.get("use_domain_asr", True)
|
||||
annotation_only = ui_config.get("annotation_only", False)
|
||||
enable_plots = ui_config.get("enable_plots", True)
|
||||
sample_no = st.get_current_cursor()
|
||||
if len(asr_data) - 1 < sample_no or sample_no < 0:
|
||||
print("Invalid samplno resetting to 0")
|
||||
|
|
@ -92,7 +93,8 @@ def main(manifest: Path):
|
|||
st.sidebar.title(f"Speller WER: {sample['domain_wer']:.2f}%")
|
||||
else:
|
||||
st.sidebar.title(f"Pretrained WER: {sample['pretrained_wer']:.2f}%")
|
||||
st.sidebar.image(Path(sample["plot_path"]).read_bytes())
|
||||
if enable_plots:
|
||||
st.sidebar.image(Path(sample["plot_path"]).read_bytes())
|
||||
st.audio(Path(sample["audio_path"]).open("rb"))
|
||||
# set default to text
|
||||
corrected = sample["text"]
|
||||
|
|
|
|||
Loading…
Reference in New Issue