1. added support for mono/dual channel rev transcripts
2. handled errors when extracting datapoints from rev metadata
3. added support for annotation-only task when dumping ui data
Malar Kannan 2020-05-27 15:19:25 +05:30
parent 1f2bedc156
commit 1acf9e403c
4 changed files with 86 additions and 34 deletions

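Note on change 1: the dual-channel path below relies on pydub's AudioSegment.split_to_mono(), which returns one mono segment per channel. A minimal sketch of the idea, assuming a stereo rev recording where the right channel carries the Agent and the left the Client (the file name is hypothetical; the speaker mapping matches the channel_map in the diff below):

    from pydub import AudioSegment

    wav_seg = AudioSegment.from_wav("call.wav")  # hypothetical stereo recording
    if wav_seg.channels == 2:
        # split_to_mono() returns the channels in order: [left, right]
        left_audio, right_audio = wav_seg.split_to_mono()
        # same speaker mapping as the diff below
        channel_map = {"Agent": right_audio, "Client": left_audio}
    elif wav_seg.channels == 1:
        channel_map = {}  # mono: a single shared track, no per-speaker split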
View File

@@ -88,7 +88,7 @@ def extract_data(
             - datetime.datetime(1900, 1, 1)
         ).total_seconds() * 1000
 
-    def asr_data_generator(wav_seg, wav_path, meta):
+    def dual_asr_data_generator(wav_seg, wav_path, meta):
         left_audio, right_audio = wav_seg.split_to_mono()
         channel_map = {"Agent": right_audio, "Client": left_audio}
         monologues = lens["monologues"].Each().collect()(meta)
@@ -113,6 +113,7 @@ def extract_data(
                 )
             except IndexError:
                 print(f'error when loading timestamp events in wav:{wav_path} skipping.')
+                continue
             # offset by 500 msec to include first vad? discarded audio
             full_tscript_wav_seg = speaker_channel[time_to_msecs(start_time) - 500 : time_to_msecs(end_time)]
             tscript_wav_seg = strip_silence(full_tscript_wav_seg)
@@ -124,10 +125,43 @@ def extract_data(
             text_clean = re.sub(r"\[.*\]", "", text)
             yield text_clean, tscript_wav_seg.duration_seconds, tscript_wav
 
+    def mono_asr_data_generator(wav_seg, wav_path, meta):
+        monologues = lens["monologues"].Each().collect()(meta)
+        for monologue in monologues:
+            try:
+                start_time = (
+                    lens["elements"]
+                    .Each()
+                    .Filter(lambda x: "timestamp" in x)["timestamp"]
+                    .collect()(monologue)[0]
+                )
+                end_time = (
+                    lens["elements"]
+                    .Each()
+                    .Filter(lambda x: "end_timestamp" in x)["end_timestamp"]
+                    .collect()(monologue)[-1]
+                )
+            except IndexError:
+                print(f'error when loading timestamp events in wav:{wav_path} skipping.')
+                continue
+            # offset by 500 msec to include first vad? discarded audio
+            full_tscript_wav_seg = wav_seg[time_to_msecs(start_time) - 500 : time_to_msecs(end_time)]
+            tscript_wav_seg = strip_silence(full_tscript_wav_seg)
+            tscript_wav_fb = BytesIO()
+            tscript_wav_seg.export(tscript_wav_fb, format="wav")
+            tscript_wav = tscript_wav_fb.getvalue()
+            text = "".join(lens["elements"].Each()["value"].collect()(monologue))
+            text_clean = re.sub(r"\[.*\]", "", text)
+            yield text_clean, tscript_wav_seg.duration_seconds, tscript_wav
+
     def generate_rev_asr_data():
         full_asr_data = []
         total_duration = 0
         for wav, wav_path, ev in wav_event_generator(call_audio_dir):
+            if wav.channels > 2:
+                print(f'skipping many channel audio {wav_path}')
+                continue
+            asr_data_generator = mono_asr_data_generator if wav.channels == 1 else dual_asr_data_generator
             asr_data = asr_data_generator(wav, wav_path, ev)
             total_duration += wav.duration_seconds
             full_asr_data.append(asr_data)
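Both generators extract fields with the python-lenses DSL used above: collect() gathers every focused value into a list, and Filter() drops elements without timestamps, which is why indexing with [0]/[-1] can raise the IndexError that the new continue now skips over. An illustrative run on a made-up rev-style meta dict:

    from lenses import lens

    meta = {
        "monologues": [
            {"elements": [
                {"value": "hello", "timestamp": "00:00:01", "end_timestamp": "00:00:02"},
                {"value": " there"},  # no timestamp: filtered out below
            ]}
        ]
    }
    monologues = lens["monologues"].Each().collect()(meta)
    starts = (
        lens["elements"]
        .Each()
        .Filter(lambda x: "timestamp" in x)["timestamp"]
        .collect()(monologues[0])
    )
    print(starts)  # ['00:00:01']; an empty list here is what triggers the IndexError path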

View File

@@ -0,0 +1 @@

View File

@@ -16,16 +16,12 @@ from ..utils import (
 app = typer.Typer()
 
 
-def preprocess_datapoint(idx, rel_root, sample, use_domain_asr):
+def preprocess_datapoint(idx, rel_root, sample, use_domain_asr, annotation_only):
     import matplotlib.pyplot as plt
     import librosa
     import librosa.display
     from pydub import AudioSegment
     from nemo.collections.asr.metrics import word_error_rate
-    from jasper.client import (
-        transcriber_pretrained,
-        transcriber_speller,
-    )
 
     try:
         res = dict(sample)
@@ -40,13 +36,18 @@ def preprocess_datapoint(idx, rel_root, sample, use_domain_asr):
         .set_sample_width(2)
         .set_frame_rate(24000)
     )
-    res["pretrained_asr"] = transcriber_pretrained(aud_seg.raw_data)
-    res["pretrained_wer"] = word_error_rate([res["text"]], [res["pretrained_asr"]])
-    if use_domain_asr:
-        res["domain_asr"] = transcriber_speller(aud_seg.raw_data)
-        res["domain_wer"] = word_error_rate(
-            [res["spoken"]], [res["pretrained_asr"]]
-        )
+    if not annotation_only:
+        from jasper.client import transcriber_pretrained, transcriber_speller
+
+        res["pretrained_asr"] = transcriber_pretrained(aud_seg.raw_data)
+        res["pretrained_wer"] = word_error_rate(
+            [res["text"]], [res["pretrained_asr"]]
+        )
+        if use_domain_asr:
+            res["domain_asr"] = transcriber_speller(aud_seg.raw_data)
+            res["domain_wer"] = word_error_rate(
+                [res["spoken"]], [res["domain_asr"]]
+            )
     wav_plot_path = (
         rel_root / Path("wav_plots") / Path(audio_path.name).with_suffix(".png")
     )
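For reference, NeMo's word_error_rate (already imported above) takes parallel lists of hypotheses and references and returns a single aggregate WER across the pairs:

    from nemo.collections.asr.metrics import word_error_rate

    # one deletion against a four-word reference -> WER of 0.25
    wer = word_error_rate(hypotheses=["the cat sat"], references=["the cat sat down"])
    print(f"{wer:.2f}")  # 0.25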
@@ -67,9 +68,14 @@ def preprocess_datapoint(idx, rel_root, sample, use_domain_asr):
 
 @app.command()
 def dump_validation_ui_data(
-    data_manifest_path: Path = Path("./data/asr_data/call_alphanum/manifest.json"),
-    dump_path: Path = Path("./data/valiation_data/ui_dump.json"),
+    data_manifest_path: Path = typer.Option(
+        Path("./data/asr_data/call_alphanum/manifest.json"), show_default=True
+    ),
+    dump_path: Path = typer.Option(
+        Path("./data/valiation_data/ui_dump.json"), show_default=True
+    ),
     use_domain_asr: bool = True,
+    annotation_only: bool = True,
 ):
     from concurrent.futures import ThreadPoolExecutor
     from functools import partial
@@ -86,6 +92,7 @@ def dump_validation_ui_data(
             data_manifest_path.parent,
             json.loads(v),
             use_domain_asr,
+            annotation_only,
         )
         for i, v in enumerate(pnr_jsonl)
     ]
@@ -94,7 +101,7 @@ def dump_validation_ui_data(
         return f()
 
     with ThreadPoolExecutor() as exe:
-        print("starting all plot tasks")
+        print("starting all preprocess tasks")
         pnr_data = filter(
             None,
             list(
@@ -106,9 +113,16 @@ def dump_validation_ui_data(
                 )
             ),
         )
-    wer_key = "domain_wer" if use_domain_asr else "pretrained_wer"
-    result = sorted(pnr_data, key=lambda x: x[wer_key], reverse=True)
-    ui_config = {"use_domain_asr": use_domain_asr, "data": result}
+    if annotation_only:
+        result = list(pnr_data)
+    else:
+        wer_key = "domain_wer" if use_domain_asr else "pretrained_wer"
+        result = sorted(pnr_data, key=lambda x: x[wer_key], reverse=True)
+    ui_config = {
+        "use_domain_asr": use_domain_asr,
+        "data": result,
+        "annotation_only": annotation_only,
+    }
 
     ExtendedPath(dump_path).write_json(ui_config)
 
@@ -171,7 +185,9 @@ def update_corrections(
         elif d["chars"] in correction_map:
             correct_text = correction_map[d["chars"]]
             if skip_incorrect:
-                print(f'skipping incorrect {d["audio_path"]} corrected to {correct_text}')
+                print(
+                    f'skipping incorrect {d["audio_path"]} corrected to {correct_text}'
+                )
             else:
                 renamed_set.add(correct_text)
                 new_name = str(Path(correct_text).with_suffix(".wav"))
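Since annotation_only is declared as a plain bool, typer derives an --annotation-only/--no-annotation-only flag pair on the dump-validation-ui-data command. A sketch of driving it through typer's test runner (the importing module name is hypothetical):

    from typer.testing import CliRunner

    from validation_dump import app  # hypothetical module holding this typer app

    runner = CliRunner()
    # dump annotation data without running either ASR transcriber
    result = runner.invoke(
        app, ["dump-validation-ui-data", "--annotation-only", "--no-use-domain-asr"]
    )
    assert result.exit_code == 0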

View File

@@ -61,14 +61,18 @@ def load_ui_data(validation_ui_data_path: Path):
 def main(manifest: Path):
     ui_config = load_ui_data(manifest)
     asr_data = ui_config["data"]
-    use_domain_asr = ui_config["use_domain_asr"]
+    use_domain_asr = ui_config.get("use_domain_asr", True)
+    annotation_only = ui_config.get("annotation_only", False)
     sample_no = st.get_current_cursor()
     if len(asr_data) - 1 < sample_no or sample_no < 0:
         print("Invalid samplno resetting to 0")
         st.update_cursor(0)
     sample = asr_data[sample_no]
     title_type = "Speller " if use_domain_asr else ""
-    st.title(f"ASR {title_type}Validation")
+    if annotation_only:
+        st.title("ASR Annotation")
+    else:
+        st.title(f"ASR {title_type}Validation")
     addl_text = f"spelled *{sample['spoken']}*" if use_domain_asr else ""
     st.markdown(f"{sample_no+1} of {len(asr_data)} : **{sample['text']}**" + addl_text)
     new_sample = st.number_input(
@@ -78,15 +82,16 @@ def main(manifest: Path):
         st.update_cursor(new_sample - 1)
     st.sidebar.title(f"Details: [{sample['real_idx']}]")
     st.sidebar.markdown(f"Gold Text: **{sample['text']}**")
-    if use_domain_asr:
-        st.sidebar.markdown(f"Expected Spelled: *{sample['spoken']}*")
-    st.sidebar.title("Results:")
-    st.sidebar.markdown(f"Pretrained: **{sample['pretrained_asr']}**")
-    if use_domain_asr:
-        st.sidebar.markdown(f"Domain: **{sample['domain_asr']}**")
-        st.sidebar.title(f"Speller WER: {sample['domain_wer']:.2f}%")
-    else:
-        st.sidebar.title(f"Pretrained WER: {sample['pretrained_wer']:.2f}%")
+    if not annotation_only:
+        if use_domain_asr:
+            st.sidebar.markdown(f"Expected Spelled: *{sample['spoken']}*")
+        st.sidebar.title("Results:")
+        st.sidebar.markdown(f"Pretrained: **{sample['pretrained_asr']}**")
+        if use_domain_asr:
+            st.sidebar.markdown(f"Domain: **{sample['domain_asr']}**")
+            st.sidebar.title(f"Speller WER: {sample['domain_wer']:.2f}%")
+        else:
+            st.sidebar.title(f"Pretrained WER: {sample['pretrained_wer']:.2f}%")
     st.sidebar.image(Path(sample["plot_path"]).read_bytes())
     st.audio(Path(sample["audio_path"]).open("rb"))
     # set default to text
@@ -113,10 +118,6 @@ def main(manifest: Path):
     st.markdown(
         f'Your Response: **{correction_entry["value"]["status"]}** Correction: **{correction_entry["value"]["correction"]}**'
     )
-    # if st.button("Previous Untagged"):
-    #     pass
-    # if st.button("Next Untagged"):
-    #     pass
     text_sample = st.text_input("Go to Text:", value='')
     if text_sample != '':
         candidates = [i for (i, p) in enumerate(asr_data) if p["text"] == text_sample or p["spoken"] == text_sample]