From 1f2bedc156ffbd0123508a40f9da6f579c11644c Mon Sep 17 00:00:00 2001 From: Malar Kannan Date: Wed, 27 May 2020 14:22:44 +0530 Subject: [PATCH] 1. enabled silece stripping in chunks when recycling audio from asr logs 2. limit asr recycling to 1 min of start audio to get reliable alignments and ignoring agent channel 3. added rev recycler for generating asr dataset from rev transcripts and audio 4. update pydub dependency for silence stripping fn and removing threadpool hardcoded worker count --- jasper/data/asr_recycler.py | 18 ++- jasper/data/rev_recycler.py | 182 ++++++++++++++++++++++++++++++ jasper/data/utils.py | 11 +- jasper/data/validation/process.py | 2 +- setup.py | 5 +- 5 files changed, 208 insertions(+), 10 deletions(-) create mode 100644 jasper/data/rev_recycler.py diff --git a/jasper/data/asr_recycler.py b/jasper/data/asr_recycler.py index 5f9cfc6..cc38f8b 100644 --- a/jasper/data/asr_recycler.py +++ b/jasper/data/asr_recycler.py @@ -15,7 +15,7 @@ def extract_data( verbose: bool = False, ): from pydub import AudioSegment - from .utils import ExtendedPath, asr_data_writer + from .utils import ExtendedPath, asr_data_writer, strip_silence from lenses import lens call_asr_data: Path = output_dir / Path("asr_data") @@ -52,11 +52,15 @@ def extract_data( end_time = state[i + 1]["AsrResult"]["Alternatives"][0]["StartTime"] else: end_time = call_wav.duration_seconds - code_seg = call_wav[start_time * 1000 : end_time * 1000] + full_code_seg = call_wav[start_time * 1000 : end_time * 1000] + code_seg = strip_silence(full_code_seg) code_fb = BytesIO() code_seg.export(code_fb, format="wav") code_wav = code_fb.getvalue() - # only of some audio data is present yield it + # only starting 1 min audio has reliable alignment ignore rest + if start_time > 60: + break + # only if some reasonable audio data is present yield it if code_seg.duration_seconds >= 0.5: yield transcript, code_seg.duration_seconds, code_wav @@ -64,12 +68,14 @@ def extract_data( call_wav_0, call_wav_1 = call_wav.split_to_mono() asr_events = lens["Events"].Each()["Event"].Filter(contains_asr) call_evs_0 = asr_events.Filter(channel(0)).collect()(events) - call_evs_1 = asr_events.Filter(channel(1)).collect()(events) + # Ignoring agent channel events + # call_evs_1 = asr_events.Filter(channel(1)).collect()(events) if verbose: typer.echo(f"processing data points on {call_wav_fname}") call_data_0 = compute_endtime(call_wav_0, call_evs_0) - call_data_1 = compute_endtime(call_wav_1, call_evs_1) - return chain(call_data_0, call_data_1) + # Ignoring agent channel + # call_data_1 = compute_endtime(call_wav_1, call_evs_1) + return call_data_0 # chain(call_data_0, call_data_1) def generate_call_asr_data(): full_asr_data = [] diff --git a/jasper/data/rev_recycler.py b/jasper/data/rev_recycler.py new file mode 100644 index 0000000..bad1ab7 --- /dev/null +++ b/jasper/data/rev_recycler.py @@ -0,0 +1,182 @@ +import typer +from itertools import chain +from io import BytesIO +from pathlib import Path +import re + +app = typer.Typer() + + +@app.command() +def extract_data( + call_audio_dir: Path = typer.Option(Path("/dataset/rev/wavs"), show_default=True), + call_meta_dir: Path = typer.Option(Path("/dataset/rev/jsons"), show_default=True), + output_dir: Path = typer.Option(Path("./data"), show_default=True), + dataset_name: str = typer.Option("rev_transribed", show_default=True), + verbose: bool = False, +): + from pydub import AudioSegment + from .utils import ExtendedPath, asr_data_writer, strip_silence + from lenses import lens + import datetime + + call_asr_data: Path = output_dir / Path("asr_data") + call_asr_data.mkdir(exist_ok=True, parents=True) + + def wav_event_generator(call_audio_dir): + for wav_path in call_audio_dir.glob("**/*.wav"): + if verbose: + typer.echo(f"loading events for file {wav_path}") + call_wav = AudioSegment.from_file_using_temporary_files(wav_path) + rel_meta_path = wav_path.with_suffix(".json").relative_to(call_audio_dir) + meta_path = call_meta_dir / rel_meta_path + if meta_path.exists(): + events = ExtendedPath(meta_path).read_json() + yield call_wav, wav_path, events + else: + typer.echo(f"missing json corresponding to {wav_path}") + + def contains_asr(x): + return "AsrResult" in x + + def channel(n): + def filter_func(ev): + return ( + ev["AsrResult"]["Channel"] == n + if "Channel" in ev["AsrResult"] + else n == 0 + ) + + return filter_func + + # def compute_endtime(call_wav, state): + # for (i, st) in enumerate(state): + # start_time = st["AsrResult"]["Alternatives"][0].get("StartTime", 0) + # transcript = st["AsrResult"]["Alternatives"][0]["Transcript"] + # if i + 1 < len(state): + # end_time = state[i + 1]["AsrResult"]["Alternatives"][0]["StartTime"] + # else: + # end_time = call_wav.duration_seconds + # full_code_seg = call_wav[start_time * 1000 : end_time * 1000] + # code_seg = strip_silence(full_code_seg) + # code_fb = BytesIO() + # code_seg.export(code_fb, format="wav") + # code_wav = code_fb.getvalue() + # # only starting 1 min audio has reliable alignment + # if start_time > 60: + # break + # # only of some audio data is present yield it + # if code_seg.duration_seconds >= 0.5: + # yield transcript, code_seg.duration_seconds, code_wav + + # def generate_call_asr_data(): + # full_asr_data = [] + # total_duration = 0 + # for wav, wav_path, ev in wav_event_generator(call_audio_dir): + # asr_data = asr_data_generator(wav, wav_path, ev) + # total_duration += wav.duration_seconds + # full_asr_data.append(asr_data) + # typer.echo(f"loaded {len(full_asr_data)} calls of duration {total_duration}s") + # n_dps = asr_data_writer(call_asr_data, dataset_name, chain(*full_asr_data)) + # typer.echo(f"written {n_dps} data points") + + # generate_call_asr_data() + + def time_to_msecs(time_str): + return ( + datetime.datetime.strptime(time_str, "%H:%M:%S,%f") + - datetime.datetime(1900, 1, 1) + ).total_seconds() * 1000 + + def asr_data_generator(wav_seg, wav_path, meta): + left_audio, right_audio = wav_seg.split_to_mono() + channel_map = {"Agent": right_audio, "Client": left_audio} + monologues = lens["monologues"].Each().collect()(meta) + for monologue in monologues: + # print(monologue["speaker_name"]) + speaker_channel = channel_map.get(monologue["speaker_name"]) + if not speaker_channel: + print(f'unknown speaker tag {monologue["speaker_name"]} in wav:{wav_path} skipping.') + continue + try: + start_time = ( + lens["elements"] + .Each() + .Filter(lambda x: "timestamp" in x)["timestamp"] + .collect()(monologue)[0] + ) + end_time = ( + lens["elements"] + .Each() + .Filter(lambda x: "end_timestamp" in x)["end_timestamp"] + .collect()(monologue)[-1] + ) + except IndexError: + print(f'error when loading timestamp events in wav:{wav_path} skipping.') + + # offset by 500 msec to include first vad? discarded audio + full_tscript_wav_seg = speaker_channel[time_to_msecs(start_time) - 500 : time_to_msecs(end_time)] + tscript_wav_seg = strip_silence(full_tscript_wav_seg) + tscript_wav_fb = BytesIO() + tscript_wav_seg.export(tscript_wav_fb, format="wav") + tscript_wav = tscript_wav_fb.getvalue() + text = "".join(lens["elements"].Each()["value"].collect()(monologue)) + text_clean = re.sub(r"\[.*\]", "", text) + yield text_clean, tscript_wav_seg.duration_seconds, tscript_wav + + def generate_rev_asr_data(): + full_asr_data = [] + total_duration = 0 + for wav, wav_path, ev in wav_event_generator(call_audio_dir): + asr_data = asr_data_generator(wav, wav_path, ev) + total_duration += wav.duration_seconds + full_asr_data.append(asr_data) + typer.echo(f"loaded {len(full_asr_data)} calls of duration {total_duration}s") + n_dps = asr_data_writer(call_asr_data, dataset_name, chain(*full_asr_data)) + typer.echo(f"written {n_dps} data points") + + generate_rev_asr_data() + # DEBUG + # data = list(wav_event_generator(call_audio_dir)) + # wav_seg, wav_path, meta = data[0] + # left_audio, right_audio = wav_seg.split_to_mono() + # channel_map = {"Agent": right_audio, "Client": left_audio} + # # data[0][2]['speakers'] + # # data[0][1] + # monologues = lens["monologues"].Each().collect()(meta) + # for monologue in monologues: + # # print(monologue["speaker_name"]) + # speaker_channel = channel_map.get(monologue["speaker_name"]) + # # monologue = monologues[0] + # # monologue + # start_time = ( + # lens["elements"] + # .Each() + # .Filter(lambda x: "timestamp" in x)["timestamp"] + # .collect()(monologue)[0] + # ) + # end_time = ( + # lens["elements"] + # .Each() + # .Filter(lambda x: "end_timestamp" in x)["end_timestamp"] + # .collect()(monologue)[-1] + # ) + # start_time, end_time + # + # # offset by 500 msec to include first vad? discarded audio + # speaker_channel[time_to_msecs(start_time) - 500 : time_to_msecs(end_time)] + # + # # start_time = lens["elements"][0].get()(monologue)['timestamp'] + # # end_time = lens["elements"][-1].get()(monologue)['timestamp'] + # text = "".join(lens["elements"].Each()["value"].collect()(monologue)) + # text_clean = re.sub(r"\[.*\]", "", text) + # # print(text) + # # print(text_clean) + + +def main(): + app() + + +if __name__ == "__main__": + main() diff --git a/jasper/data/utils.py b/jasper/data/utils.py index 9ff26eb..76ba597 100644 --- a/jasper/data/utils.py +++ b/jasper/data/utils.py @@ -104,12 +104,21 @@ class ExtendedPath(type(Path())): return json.dump(data, jf, indent=2) -def get_mongo_conn(host='', port=27017): +def get_mongo_conn(host="", port=27017): mongo_host = host if host else os.environ.get("MONGO_HOST", "localhost") mongo_uri = f"mongodb://{mongo_host}:{port}/" return pymongo.MongoClient(mongo_uri) +def strip_silence(sound): + from pydub.silence import detect_leading_silence + + start_trim = detect_leading_silence(sound) + end_trim = detect_leading_silence(sound.reverse()) + duration = len(sound) + return sound[start_trim : duration - end_trim] + + def main(): for c in random_pnr_generator(): print(c) diff --git a/jasper/data/validation/process.py b/jasper/data/validation/process.py index a911684..3ea1233 100644 --- a/jasper/data/validation/process.py +++ b/jasper/data/validation/process.py @@ -93,7 +93,7 @@ def dump_validation_ui_data( def exec_func(f): return f() - with ThreadPoolExecutor(max_workers=20) as exe: + with ThreadPoolExecutor() as exe: print("starting all plot tasks") pnr_data = filter( None, diff --git a/setup.py b/setup.py index 3900b78..801831c 100644 --- a/setup.py +++ b/setup.py @@ -12,7 +12,7 @@ extra_requirements = { "data": [ "google-cloud-texttospeech~=1.0.1", "tqdm~=4.39.0", - "pydub~=0.23.1", + "pydub~=0.24.0", "scikit_learn~=0.22.1", "pandas~=1.0.3", "boto3~=1.12.35", @@ -35,7 +35,7 @@ extra_requirements = { "tqdm~=4.39.0", "librosa==0.7.2", "matplotlib==3.2.1", - "pydub~=0.23.1", + "pydub~=0.24.0", "streamlit==0.58.0", "stringcase==1.2.0" ] @@ -65,6 +65,7 @@ setup( "jasper_data_generate = jasper.data.tts_generator:main", "jasper_data_call_recycle = jasper.data.call_recycler:main", "jasper_data_asr_recycle = jasper.data.asr_recycler:main", + "jasper_data_rev_recycle = jasper.data.rev_recycler:main", "jasper_data_server = jasper.data.server:main", "jasper_data_validation = jasper.data.validation.process:main", "jasper_data_preprocess = jasper.data.process:main",