94 lines
3.4 KiB
Python
94 lines
3.4 KiB
Python
import typer
|
|
from itertools import chain
|
|
from io import BytesIO
|
|
from pathlib import Path
|
|
|
|
# Typer CLI application; commands are registered on it with the @app.command() decorator.
app: typer.Typer = typer.Typer()
|
|
|
|
|
|
@app.command()
def extract_data(
    call_audio_dir: Path = Path("/dataset/png_prod/call_audio"),
    call_meta_dir: Path = Path("/dataset/png_prod/call_metadata"),
    output_dir: Path = Path("./data"),
    dataset_name: str = "png_gcp_2jan",
    verbose: bool = False,
):
    """Build an ASR dataset by cutting call recordings into transcribed segments.

    Every .wav file under call_audio_dir is paired with the .json file at the same
    relative path under call_meta_dir. The ASR events in that JSON are grouped per
    audio channel; each event's transcript is paired with the audio running from its
    start time to the next event's start time on the same channel (or to the end of
    the recording). Segments of at least 0.5 s are written to output_dir/asr_data
    under dataset_name.
    """
    from pydub import AudioSegment

    from .utils import ExtendedPath, asr_data_writer
    from lenses import lens

    # Segments shorter than this many seconds are dropped from the dataset.
    min_segment_seconds = 0.5

    call_asr_data: Path = output_dir / Path("asr_data")
    call_asr_data.mkdir(exist_ok=True, parents=True)

    def wav_event_generator(call_audio_dir):
        """Yield (audio, wav_path, events) for every WAV file under call_audio_dir."""
        # Sorted so the dataset is produced in a reproducible order; raw glob order
        # depends on the filesystem.
        for wav_path in sorted(call_audio_dir.glob("**/*.wav")):
            if verbose:
                typer.echo(f"loading events for file {wav_path}")
            call_wav = AudioSegment.from_file_using_temporary_files(wav_path)
            # The metadata JSON lives at the WAV's relative path under call_meta_dir.
            rel_meta_path = wav_path.with_suffix(".json").relative_to(call_audio_dir)
            meta_path = call_meta_dir / rel_meta_path
            events = ExtendedPath(meta_path).read_json()
            yield call_wav, wav_path, events

    def contains_asr(x):
        """Return True if the event dict carries an "AsrResult" entry."""
        return "AsrResult" in x

    def channel(n):
        """Return a predicate selecting ASR events that belong to channel ``n``.

        Events whose AsrResult has no "Channel" key are attributed to channel 0.
        """

        def filter_func(ev):
            asr_result = ev["AsrResult"]
            if "Channel" in asr_result:
                return asr_result["Channel"] == n
            return n == 0

        return filter_func

    def start_time_of(ev):
        """Start time of an event's first alternative; a missing StartTime counts as 0."""
        # NOTE(review): assumes StartTime is numeric seconds, since it is scaled by
        # 1000 for millisecond slicing below — confirm against the metadata producer.
        return ev["AsrResult"]["Alternatives"][0].get("StartTime", 0)

    def compute_endtime(call_wav, state):
        """Yield (transcript, duration_seconds, wav_bytes) for each event in ``state``.

        Each segment spans from the event's start time to the next event's start
        time, or to the end of ``call_wav`` for the last event. Segments shorter than
        ``min_segment_seconds`` are skipped.
        """
        for i, event in enumerate(state):
            start_time = start_time_of(event)
            transcript = event["AsrResult"]["Alternatives"][0]["Transcript"]
            if i + 1 < len(state):
                # Same accessor as for the current event, so a next event without
                # StartTime does not raise KeyError.
                end_time = start_time_of(state[i + 1])
            else:
                end_time = call_wav.duration_seconds
            # AudioSegment slices are indexed in milliseconds.
            code_seg = call_wav[start_time * 1000 : end_time * 1000]
            # Check the length before encoding so discarded segments are never exported.
            if code_seg.duration_seconds < min_segment_seconds:
                continue
            code_fb = BytesIO()
            code_seg.export(code_fb, format="wav")
            yield transcript, code_seg.duration_seconds, code_fb.getvalue()

    def asr_data_generator(call_wav, call_wav_fname, events):
        """Return an iterator over the data points of every channel of one call.

        Channels are processed in order (0, 1, ...), so mono and multi-channel
        recordings are both supported. Events tagged with a channel the audio does
        not have are ignored.
        """
        asr_events = lens["Events"].Each()["Event"].Filter(contains_asr)
        # Split channels and collect each channel's events now; cutting and
        # exporting audio happens lazily when the returned iterator is consumed.
        per_channel = [
            (mono_wav, asr_events.Filter(channel(n)).collect()(events))
            for n, mono_wav in enumerate(call_wav.split_to_mono())
        ]
        if verbose:
            typer.echo(f"processing data points on {call_wav_fname}")
        return chain.from_iterable(
            compute_endtime(mono_wav, channel_events)
            for mono_wav, channel_events in per_channel
        )

    def generate_call_asr_data():
        """Load every call, report call count and total duration, then write all data points."""
        # NOTE: each pending iterator keeps its call's audio referenced until
        # asr_data_writer consumes it, so memory grows with the number of calls.
        full_asr_data = []
        total_duration = 0
        for wav, wav_path, ev in wav_event_generator(call_audio_dir):
            asr_data = asr_data_generator(wav, wav_path, ev)
            total_duration += wav.duration_seconds
            full_asr_data.append(asr_data)
        typer.echo(f"loaded {len(full_asr_data)} calls of duration {total_duration}s")
        n_dps = asr_data_writer(
            call_asr_data, dataset_name, chain.from_iterable(full_asr_data)
        )
        typer.echo(f"written {n_dps} data points")

    generate_call_asr_data()
|
|
|
|
|
|
def main(argv=None):
    """Run the Typer CLI application ``app``.

    Args:
        argv: Optional list of command-line arguments to parse instead of the
            process arguments. ``None`` (the default) parses the process
            arguments, exactly as when the script is run from a shell.
    """
    app(args=argv)


if __name__ == "__main__":
    main()
|