1. added support for mono/dual channel rev transcripts
2. handle errors when extracting datapoints from rev meta data 3. added suport for annotation only task when dumping ui data
parent
1f2bedc156
commit
1acf9e403c
|
|
@ -88,7 +88,7 @@ def extract_data(
|
||||||
- datetime.datetime(1900, 1, 1)
|
- datetime.datetime(1900, 1, 1)
|
||||||
).total_seconds() * 1000
|
).total_seconds() * 1000
|
||||||
|
|
||||||
def asr_data_generator(wav_seg, wav_path, meta):
|
def dual_asr_data_generator(wav_seg, wav_path, meta):
|
||||||
left_audio, right_audio = wav_seg.split_to_mono()
|
left_audio, right_audio = wav_seg.split_to_mono()
|
||||||
channel_map = {"Agent": right_audio, "Client": left_audio}
|
channel_map = {"Agent": right_audio, "Client": left_audio}
|
||||||
monologues = lens["monologues"].Each().collect()(meta)
|
monologues = lens["monologues"].Each().collect()(meta)
|
||||||
|
|
@ -113,6 +113,7 @@ def extract_data(
|
||||||
)
|
)
|
||||||
except IndexError:
|
except IndexError:
|
||||||
print(f'error when loading timestamp events in wav:{wav_path} skipping.')
|
print(f'error when loading timestamp events in wav:{wav_path} skipping.')
|
||||||
|
continue
|
||||||
|
|
||||||
# offset by 500 msec to include first vad? discarded audio
|
# offset by 500 msec to include first vad? discarded audio
|
||||||
full_tscript_wav_seg = speaker_channel[time_to_msecs(start_time) - 500 : time_to_msecs(end_time)]
|
full_tscript_wav_seg = speaker_channel[time_to_msecs(start_time) - 500 : time_to_msecs(end_time)]
|
||||||
|
|
@ -124,10 +125,43 @@ def extract_data(
|
||||||
text_clean = re.sub(r"\[.*\]", "", text)
|
text_clean = re.sub(r"\[.*\]", "", text)
|
||||||
yield text_clean, tscript_wav_seg.duration_seconds, tscript_wav
|
yield text_clean, tscript_wav_seg.duration_seconds, tscript_wav
|
||||||
|
|
||||||
|
def mono_asr_data_generator(wav_seg, wav_path, meta):
|
||||||
|
monologues = lens["monologues"].Each().collect()(meta)
|
||||||
|
for monologue in monologues:
|
||||||
|
try:
|
||||||
|
start_time = (
|
||||||
|
lens["elements"]
|
||||||
|
.Each()
|
||||||
|
.Filter(lambda x: "timestamp" in x)["timestamp"]
|
||||||
|
.collect()(monologue)[0]
|
||||||
|
)
|
||||||
|
end_time = (
|
||||||
|
lens["elements"]
|
||||||
|
.Each()
|
||||||
|
.Filter(lambda x: "end_timestamp" in x)["end_timestamp"]
|
||||||
|
.collect()(monologue)[-1]
|
||||||
|
)
|
||||||
|
except IndexError:
|
||||||
|
print(f'error when loading timestamp events in wav:{wav_path} skipping.')
|
||||||
|
continue
|
||||||
|
|
||||||
|
# offset by 500 msec to include first vad? discarded audio
|
||||||
|
full_tscript_wav_seg = wav_seg[time_to_msecs(start_time) - 500 : time_to_msecs(end_time)]
|
||||||
|
tscript_wav_seg = strip_silence(full_tscript_wav_seg)
|
||||||
|
tscript_wav_fb = BytesIO()
|
||||||
|
tscript_wav_seg.export(tscript_wav_fb, format="wav")
|
||||||
|
tscript_wav = tscript_wav_fb.getvalue()
|
||||||
|
text = "".join(lens["elements"].Each()["value"].collect()(monologue))
|
||||||
|
text_clean = re.sub(r"\[.*\]", "", text)
|
||||||
|
yield text_clean, tscript_wav_seg.duration_seconds, tscript_wav
|
||||||
|
|
||||||
def generate_rev_asr_data():
|
def generate_rev_asr_data():
|
||||||
full_asr_data = []
|
full_asr_data = []
|
||||||
total_duration = 0
|
total_duration = 0
|
||||||
for wav, wav_path, ev in wav_event_generator(call_audio_dir):
|
for wav, wav_path, ev in wav_event_generator(call_audio_dir):
|
||||||
|
if wav.channels > 2:
|
||||||
|
print(f'skipping many channel audio {wav_path}')
|
||||||
|
asr_data_generator = mono_asr_data_generator if wav.channels == 1 else dual_asr_data_generator
|
||||||
asr_data = asr_data_generator(wav, wav_path, ev)
|
asr_data = asr_data_generator(wav, wav_path, ev)
|
||||||
total_duration += wav.duration_seconds
|
total_duration += wav.duration_seconds
|
||||||
full_asr_data.append(asr_data)
|
full_asr_data.append(asr_data)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
|
||||||
|
|
@ -16,16 +16,12 @@ from ..utils import (
|
||||||
app = typer.Typer()
|
app = typer.Typer()
|
||||||
|
|
||||||
|
|
||||||
def preprocess_datapoint(idx, rel_root, sample, use_domain_asr):
|
def preprocess_datapoint(idx, rel_root, sample, use_domain_asr, annotation_only):
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import librosa
|
import librosa
|
||||||
import librosa.display
|
import librosa.display
|
||||||
from pydub import AudioSegment
|
from pydub import AudioSegment
|
||||||
from nemo.collections.asr.metrics import word_error_rate
|
from nemo.collections.asr.metrics import word_error_rate
|
||||||
from jasper.client import (
|
|
||||||
transcriber_pretrained,
|
|
||||||
transcriber_speller,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
res = dict(sample)
|
res = dict(sample)
|
||||||
|
|
@ -40,8 +36,13 @@ def preprocess_datapoint(idx, rel_root, sample, use_domain_asr):
|
||||||
.set_sample_width(2)
|
.set_sample_width(2)
|
||||||
.set_frame_rate(24000)
|
.set_frame_rate(24000)
|
||||||
)
|
)
|
||||||
|
if not annotation_only:
|
||||||
|
from jasper.client import transcriber_pretrained, transcriber_speller
|
||||||
|
|
||||||
res["pretrained_asr"] = transcriber_pretrained(aud_seg.raw_data)
|
res["pretrained_asr"] = transcriber_pretrained(aud_seg.raw_data)
|
||||||
res["pretrained_wer"] = word_error_rate([res["text"]], [res["pretrained_asr"]])
|
res["pretrained_wer"] = word_error_rate(
|
||||||
|
[res["text"]], [res["pretrained_asr"]]
|
||||||
|
)
|
||||||
if use_domain_asr:
|
if use_domain_asr:
|
||||||
res["domain_asr"] = transcriber_speller(aud_seg.raw_data)
|
res["domain_asr"] = transcriber_speller(aud_seg.raw_data)
|
||||||
res["domain_wer"] = word_error_rate(
|
res["domain_wer"] = word_error_rate(
|
||||||
|
|
@ -67,9 +68,14 @@ def preprocess_datapoint(idx, rel_root, sample, use_domain_asr):
|
||||||
|
|
||||||
@app.command()
|
@app.command()
|
||||||
def dump_validation_ui_data(
|
def dump_validation_ui_data(
|
||||||
data_manifest_path: Path = Path("./data/asr_data/call_alphanum/manifest.json"),
|
data_manifest_path: Path = typer.Option(
|
||||||
dump_path: Path = Path("./data/valiation_data/ui_dump.json"),
|
Path("./data/asr_data/call_alphanum/manifest.json"), show_default=True
|
||||||
|
),
|
||||||
|
dump_path: Path = typer.Option(
|
||||||
|
Path("./data/valiation_data/ui_dump.json"), show_default=True
|
||||||
|
),
|
||||||
use_domain_asr: bool = True,
|
use_domain_asr: bool = True,
|
||||||
|
annotation_only: bool = True,
|
||||||
):
|
):
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
@ -86,6 +92,7 @@ def dump_validation_ui_data(
|
||||||
data_manifest_path.parent,
|
data_manifest_path.parent,
|
||||||
json.loads(v),
|
json.loads(v),
|
||||||
use_domain_asr,
|
use_domain_asr,
|
||||||
|
annotation_only,
|
||||||
)
|
)
|
||||||
for i, v in enumerate(pnr_jsonl)
|
for i, v in enumerate(pnr_jsonl)
|
||||||
]
|
]
|
||||||
|
|
@ -94,7 +101,7 @@ def dump_validation_ui_data(
|
||||||
return f()
|
return f()
|
||||||
|
|
||||||
with ThreadPoolExecutor() as exe:
|
with ThreadPoolExecutor() as exe:
|
||||||
print("starting all plot tasks")
|
print("starting all preprocess tasks")
|
||||||
pnr_data = filter(
|
pnr_data = filter(
|
||||||
None,
|
None,
|
||||||
list(
|
list(
|
||||||
|
|
@ -106,9 +113,16 @@ def dump_validation_ui_data(
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
if annotation_only:
|
||||||
|
result = pnr_data
|
||||||
|
else:
|
||||||
wer_key = "domain_wer" if use_domain_asr else "pretrained_wer"
|
wer_key = "domain_wer" if use_domain_asr else "pretrained_wer"
|
||||||
result = sorted(pnr_data, key=lambda x: x[wer_key], reverse=True)
|
result = sorted(pnr_data, key=lambda x: x[wer_key], reverse=True)
|
||||||
ui_config = {"use_domain_asr": use_domain_asr, "data": result}
|
ui_config = {
|
||||||
|
"use_domain_asr": use_domain_asr,
|
||||||
|
"data": result,
|
||||||
|
"annotation_only": annotation_only,
|
||||||
|
}
|
||||||
ExtendedPath(dump_path).write_json(ui_config)
|
ExtendedPath(dump_path).write_json(ui_config)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -171,7 +185,9 @@ def update_corrections(
|
||||||
elif d["chars"] in correction_map:
|
elif d["chars"] in correction_map:
|
||||||
correct_text = correction_map[d["chars"]]
|
correct_text = correction_map[d["chars"]]
|
||||||
if skip_incorrect:
|
if skip_incorrect:
|
||||||
print(f'skipping incorrect {d["audio_path"]} corrected to {correct_text}')
|
print(
|
||||||
|
f'skipping incorrect {d["audio_path"]} corrected to {correct_text}'
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
renamed_set.add(correct_text)
|
renamed_set.add(correct_text)
|
||||||
new_name = str(Path(correct_text).with_suffix(".wav"))
|
new_name = str(Path(correct_text).with_suffix(".wav"))
|
||||||
|
|
|
||||||
|
|
@ -61,13 +61,17 @@ def load_ui_data(validation_ui_data_path: Path):
|
||||||
def main(manifest: Path):
|
def main(manifest: Path):
|
||||||
ui_config = load_ui_data(manifest)
|
ui_config = load_ui_data(manifest)
|
||||||
asr_data = ui_config["data"]
|
asr_data = ui_config["data"]
|
||||||
use_domain_asr = ui_config["use_domain_asr"]
|
use_domain_asr = ui_config.get("use_domain_asr", True)
|
||||||
|
annotation_only = ui_config.get("annotation_only", False)
|
||||||
sample_no = st.get_current_cursor()
|
sample_no = st.get_current_cursor()
|
||||||
if len(asr_data) - 1 < sample_no or sample_no < 0:
|
if len(asr_data) - 1 < sample_no or sample_no < 0:
|
||||||
print("Invalid samplno resetting to 0")
|
print("Invalid samplno resetting to 0")
|
||||||
st.update_cursor(0)
|
st.update_cursor(0)
|
||||||
sample = asr_data[sample_no]
|
sample = asr_data[sample_no]
|
||||||
title_type = "Speller " if use_domain_asr else ""
|
title_type = "Speller " if use_domain_asr else ""
|
||||||
|
if annotation_only:
|
||||||
|
st.title(f"ASR Annotation")
|
||||||
|
else:
|
||||||
st.title(f"ASR {title_type}Validation")
|
st.title(f"ASR {title_type}Validation")
|
||||||
addl_text = f"spelled *{sample['spoken']}*" if use_domain_asr else ""
|
addl_text = f"spelled *{sample['spoken']}*" if use_domain_asr else ""
|
||||||
st.markdown(f"{sample_no+1} of {len(asr_data)} : **{sample['text']}**" + addl_text)
|
st.markdown(f"{sample_no+1} of {len(asr_data)} : **{sample['text']}**" + addl_text)
|
||||||
|
|
@ -78,6 +82,7 @@ def main(manifest: Path):
|
||||||
st.update_cursor(new_sample - 1)
|
st.update_cursor(new_sample - 1)
|
||||||
st.sidebar.title(f"Details: [{sample['real_idx']}]")
|
st.sidebar.title(f"Details: [{sample['real_idx']}]")
|
||||||
st.sidebar.markdown(f"Gold Text: **{sample['text']}**")
|
st.sidebar.markdown(f"Gold Text: **{sample['text']}**")
|
||||||
|
if not annotation_only:
|
||||||
if use_domain_asr:
|
if use_domain_asr:
|
||||||
st.sidebar.markdown(f"Expected Spelled: *{sample['spoken']}*")
|
st.sidebar.markdown(f"Expected Spelled: *{sample['spoken']}*")
|
||||||
st.sidebar.title("Results:")
|
st.sidebar.title("Results:")
|
||||||
|
|
@ -113,10 +118,6 @@ def main(manifest: Path):
|
||||||
st.markdown(
|
st.markdown(
|
||||||
f'Your Response: **{correction_entry["value"]["status"]}** Correction: **{correction_entry["value"]["correction"]}**'
|
f'Your Response: **{correction_entry["value"]["status"]}** Correction: **{correction_entry["value"]["correction"]}**'
|
||||||
)
|
)
|
||||||
# if st.button("Previous Untagged"):
|
|
||||||
# pass
|
|
||||||
# if st.button("Next Untagged"):
|
|
||||||
# pass
|
|
||||||
text_sample = st.text_input("Go to Text:", value='')
|
text_sample = st.text_input("Go to Text:", value='')
|
||||||
if text_sample != '':
|
if text_sample != '':
|
||||||
candidates = [i for (i, p) in enumerate(asr_data) if p["text"] == text_sample or p["spoken"] == text_sample]
|
candidates = [i for (i, p) in enumerate(asr_data) if p["text"] == text_sample or p["spoken"] == text_sample]
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue