1. enabled silence stripping of chunks when recycling audio from asr logs
2. limited asr recycling to the first 1 min of audio to get reliable alignments, ignoring the agent channel
3. added a rev recycler for generating an asr dataset from rev transcripts and audio
4. updated the pydub dependency for the silence stripping fn and removed the hardcoded threadpool worker count
Malar Kannan 2020-05-27 14:22:44 +05:30
parent fca9c1aeb3
commit 1f2bedc156
5 changed files with 208 additions and 10 deletions

View File

@@ -15,7 +15,7 @@ def extract_data(
    verbose: bool = False,
):
    from pydub import AudioSegment
-    from .utils import ExtendedPath, asr_data_writer
+    from .utils import ExtendedPath, asr_data_writer, strip_silence
    from lenses import lens

    call_asr_data: Path = output_dir / Path("asr_data")
@@ -52,11 +52,15 @@ def extract_data(
                end_time = state[i + 1]["AsrResult"]["Alternatives"][0]["StartTime"]
            else:
                end_time = call_wav.duration_seconds
-            code_seg = call_wav[start_time * 1000 : end_time * 1000]
+            full_code_seg = call_wav[start_time * 1000 : end_time * 1000]
+            code_seg = strip_silence(full_code_seg)
            code_fb = BytesIO()
            code_seg.export(code_fb, format="wav")
            code_wav = code_fb.getvalue()
-            # only of some audio data is present yield it
+            # only the first 1 min of audio has a reliable alignment; ignore the rest
+            if start_time > 60:
+                break
+            # only if some reasonable audio data is present, yield it
            if code_seg.duration_seconds >= 0.5:
                yield transcript, code_seg.duration_seconds, code_wav
@@ -64,12 +68,14 @@ def extract_data(
        call_wav_0, call_wav_1 = call_wav.split_to_mono()
        asr_events = lens["Events"].Each()["Event"].Filter(contains_asr)
        call_evs_0 = asr_events.Filter(channel(0)).collect()(events)
-        call_evs_1 = asr_events.Filter(channel(1)).collect()(events)
+        # Ignoring agent channel events
+        # call_evs_1 = asr_events.Filter(channel(1)).collect()(events)
        if verbose:
            typer.echo(f"processing data points on {call_wav_fname}")
        call_data_0 = compute_endtime(call_wav_0, call_evs_0)
-        call_data_1 = compute_endtime(call_wav_1, call_evs_1)
-        return chain(call_data_0, call_data_1)
+        # Ignoring agent channel
+        # call_data_1 = compute_endtime(call_wav_1, call_evs_1)
+        return call_data_0  # chain(call_data_0, call_data_1)

    def generate_call_asr_data():
        full_asr_data = []
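A caveat on the 1 min cutoff added earlier in this file: the early break assumes the AsrResult events arrive ordered by StartTime. If that ordering ever isn't guaranteed by the logs, a defensive sort before the loop (a sketch, not part of this diff) would keep the cutoff correct:

    state = sorted(
        state, key=lambda st: st["AsrResult"]["Alternatives"][0].get("StartTime", 0)
    )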

jasper/data/rev_recycler.py Normal file (+182)
View File

@@ -0,0 +1,182 @@
import typer
from itertools import chain
from io import BytesIO
from pathlib import Path
import re

app = typer.Typer()


@app.command()
def extract_data(
    call_audio_dir: Path = typer.Option(Path("/dataset/rev/wavs"), show_default=True),
    call_meta_dir: Path = typer.Option(Path("/dataset/rev/jsons"), show_default=True),
    output_dir: Path = typer.Option(Path("./data"), show_default=True),
    dataset_name: str = typer.Option("rev_transcribed", show_default=True),
    verbose: bool = False,
):
    from pydub import AudioSegment
    from .utils import ExtendedPath, asr_data_writer, strip_silence
    from lenses import lens
    import datetime

    call_asr_data: Path = output_dir / Path("asr_data")
    call_asr_data.mkdir(exist_ok=True, parents=True)

    def wav_event_generator(call_audio_dir):
        # pair every wav with its transcript json from the meta dir
        for wav_path in call_audio_dir.glob("**/*.wav"):
            if verbose:
                typer.echo(f"loading events for file {wav_path}")
            call_wav = AudioSegment.from_file_using_temporary_files(wav_path)
            rel_meta_path = wav_path.with_suffix(".json").relative_to(call_audio_dir)
            meta_path = call_meta_dir / rel_meta_path
            if meta_path.exists():
                events = ExtendedPath(meta_path).read_json()
                yield call_wav, wav_path, events
            else:
                typer.echo(f"missing json corresponding to {wav_path}")

    def contains_asr(x):
        return "AsrResult" in x

    def channel(n):
        # events without an explicit channel default to channel 0
        def filter_func(ev):
            return (
                ev["AsrResult"]["Channel"] == n
                if "Channel" in ev["AsrResult"]
                else n == 0
            )

        return filter_func
    def time_to_msecs(time_str):
        # parse an SRT-style "HH:MM:SS,mmm" timestamp into milliseconds
        return (
            datetime.datetime.strptime(time_str, "%H:%M:%S,%f")
            - datetime.datetime(1900, 1, 1)
        ).total_seconds() * 1000
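A quick check of the conversion (the timestamp value here is illustrative, not from a real transcript):

    >>> time_to_msecs("00:01:23,500")
    83500.0

Note that strptime's %f directive right-pads the fractional digits to microseconds, so ",500" is read as 500 ms rather than 500 µs, which is what the SRT-style rev timestamps need.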
    def asr_data_generator(wav_seg, wav_path, meta):
        # rev stereo calls: agent on the right channel, client on the left
        left_audio, right_audio = wav_seg.split_to_mono()
        channel_map = {"Agent": right_audio, "Client": left_audio}
        monologues = lens["monologues"].Each().collect()(meta)
        for monologue in monologues:
            speaker_channel = channel_map.get(monologue["speaker_name"])
            if not speaker_channel:
                print(
                    f'unknown speaker tag {monologue["speaker_name"]} in wav:{wav_path} skipping.'
                )
                continue
            try:
                start_time = (
                    lens["elements"]
                    .Each()
                    .Filter(lambda x: "timestamp" in x)["timestamp"]
                    .collect()(monologue)[0]
                )
                end_time = (
                    lens["elements"]
                    .Each()
                    .Filter(lambda x: "end_timestamp" in x)["end_timestamp"]
                    .collect()(monologue)[-1]
                )
            except IndexError:
                print(f"error when loading timestamp events in wav:{wav_path} skipping.")
                continue
            # offset by 500 msec to include the first bit of audio the vad discarded
            full_tscript_wav_seg = speaker_channel[
                time_to_msecs(start_time) - 500 : time_to_msecs(end_time)
            ]
            tscript_wav_seg = strip_silence(full_tscript_wav_seg)
            tscript_wav_fb = BytesIO()
            tscript_wav_seg.export(tscript_wav_fb, format="wav")
            tscript_wav = tscript_wav_fb.getvalue()
            text = "".join(lens["elements"].Each()["value"].collect()(monologue))
            # drop annotation markers like [inaudible] from the transcript
            text_clean = re.sub(r"\[.*\]", "", text)
            yield text_clean, tscript_wav_seg.duration_seconds, tscript_wav
    def generate_rev_asr_data():
        full_asr_data = []
        total_duration = 0
        for wav, wav_path, ev in wav_event_generator(call_audio_dir):
            asr_data = asr_data_generator(wav, wav_path, ev)
            total_duration += wav.duration_seconds
            full_asr_data.append(asr_data)
        typer.echo(f"loaded {len(full_asr_data)} calls of duration {total_duration}s")
        n_dps = asr_data_writer(call_asr_data, dataset_name, chain(*full_asr_data))
        typer.echo(f"written {n_dps} data points")

    generate_rev_asr_data()
def main():
    app()


if __name__ == "__main__":
    main()
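For context, the parser above assumes the rev transcript json roughly follows this shape (a hand-written illustration, not a real transcript):

    meta = {
        "monologues": [
            {
                "speaker_name": "Client",
                "elements": [
                    {"type": "text", "value": "hello ", "timestamp": "00:00:01,200"},
                    {"type": "text", "value": "there", "end_timestamp": "00:00:02,100"},
                ],
            }
        ],
    }

speaker_name selects the audio channel, the first timestamp and last end_timestamp bound the segment, and the value fields are concatenated into the transcript.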

View File

@@ -104,12 +104,21 @@ class ExtendedPath(type(Path())):
        return json.dump(data, jf, indent=2)


-def get_mongo_conn(host='', port=27017):
+def get_mongo_conn(host="", port=27017):
    mongo_host = host if host else os.environ.get("MONGO_HOST", "localhost")
    mongo_uri = f"mongodb://{mongo_host}:{port}/"
    return pymongo.MongoClient(mongo_uri)


+def strip_silence(sound):
+    from pydub.silence import detect_leading_silence
+
+    # trim leading silence, then trailing silence via a reversed copy
+    start_trim = detect_leading_silence(sound)
+    end_trim = detect_leading_silence(sound.reverse())
+    duration = len(sound)
+    return sound[start_trim : duration - end_trim]


def main():
    for c in random_pnr_generator():
        print(c)
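A note on strip_silence: pydub's detect_leading_silence defaults to a -50.0 dBFS threshold scanned in 10 ms chunks, which suits fairly clean call audio. For noisier recordings a looser threshold may work better; a variant like this sketch (the function name and threshold value are illustrative, not part of this diff) is one way to tune it:

    def strip_silence_noisy(sound, silence_threshold=-35.0, chunk_size=10):
        # a looser threshold treats low-level background noise as silence too
        from pydub.silence import detect_leading_silence

        start_trim = detect_leading_silence(sound, silence_threshold, chunk_size)
        end_trim = detect_leading_silence(sound.reverse(), silence_threshold, chunk_size)
        return sound[start_trim : len(sound) - end_trim]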

View File

@@ -93,7 +93,7 @@ def dump_validation_ui_data(
    def exec_func(f):
        return f()

-    with ThreadPoolExecutor(max_workers=20) as exe:
+    with ThreadPoolExecutor() as exe:
        print("starting all plot tasks")
        pnr_data = filter(
            None,
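Dropping max_workers here falls back to the stdlib default, which since Python 3.8 is min(32, os.cpu_count() + 4). A quick way to confirm the effective count on a given machine (_max_workers is a private attribute, shown only for illustration):

    import os
    from concurrent.futures import ThreadPoolExecutor

    with ThreadPoolExecutor() as exe:
        # e.g. prints 12 on an 8-core box: min(32, 8 + 4)
        print(exe._max_workers, os.cpu_count())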

View File

@@ -12,7 +12,7 @@ extra_requirements = {
    "data": [
        "google-cloud-texttospeech~=1.0.1",
        "tqdm~=4.39.0",
-        "pydub~=0.23.1",
+        "pydub~=0.24.0",
        "scikit_learn~=0.22.1",
        "pandas~=1.0.3",
        "boto3~=1.12.35",
@@ -35,7 +35,7 @@ extra_requirements = {
        "tqdm~=4.39.0",
        "librosa==0.7.2",
        "matplotlib==3.2.1",
-        "pydub~=0.23.1",
+        "pydub~=0.24.0",
        "streamlit==0.58.0",
        "stringcase==1.2.0"
    ]
@@ -65,6 +65,7 @@ setup(
        "jasper_data_generate = jasper.data.tts_generator:main",
        "jasper_data_call_recycle = jasper.data.call_recycler:main",
        "jasper_data_asr_recycle = jasper.data.asr_recycler:main",
+        "jasper_data_rev_recycle = jasper.data.rev_recycler:main",
        "jasper_data_server = jasper.data.server:main",
        "jasper_data_validation = jasper.data.validation.process:main",
        "jasper_data_preprocess = jasper.data.process:main",