1. add asr data generation from audio files and triton-asr results

2. add asr data clean/channel extraction process subcommands
3. add parallel apply without a pool (single-threaded)
4. include support for raw transcripts from triton-asr results
Malar 2021-06-07 15:44:04 +05:30
parent e07c7c9caf
commit af51fe95cb
7 changed files with 140 additions and 22 deletions

View File

@@ -26,13 +26,15 @@ from plume.utils import (
from ...models.wav2vec2.data import app as wav2vec2_app
from .generate import app as generate_app
from .process import app as process_app
soundfile = lazy_module("soundfile")
pydub = lazy_module("pydub")
train_test_split = lazy_callable("sklearn.model_selection.train_test_split")
app = typer.Typer()
app.add_typer(generate_app, name="generate")
app.add_typer(generate_app)
app.add_typer(process_app)
app.add_typer(wav2vec2_app, name="wav2vec2")
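# With no name= passed, Typer falls back to the sub-app's @callback function name, so
# generate_app and process_app should still register as the "generate" and "process" groups.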

View File

@@ -1,12 +1,53 @@
from pathlib import Path
import shutil
import typer
from ...utils.tts import GoogleTTS
from plume.utils.lazy_import import lazy_module
from plume.utils.tts import GoogleTTS
from plume.utils.transcribe import triton_transcribe_grpc_gen
from plume.utils.manifest import asr_manifest_writer
pydub = lazy_module("pydub")
app = typer.Typer()
@app.callback()
def generate():
"""
generate sub commands
"""
@app.command()
def tts_dataset(dest_path: Path):
tts = GoogleTTS()
pass
@app.command()
def asr_dataset(audio_dir: Path, out_dir: Path, model="slu_num_wav2vec2"):
out_wav_dir = out_dir / "wavs"
out_wav_dir.mkdir(exist_ok=True, parents=True)
def data_gen():
aud_files = list(audio_dir.glob("*.mp3")) + list(
audio_dir.glob("*.wav")
)
transcriber, prep = triton_transcribe_grpc_gen(
asr_model=model, method="whole", append_raw=True
)
for af in aud_files:
out_af = out_wav_dir / af.name
audio_af = out_af.relative_to(out_dir)
shutil.copy2(af, out_af)
aud_seg = pydub.AudioSegment.from_file(out_af)
t_seg = prep(aud_seg)
transcript = transcriber(t_seg)
# [digit_tscript, raw_tscript] = transcript.split("|")
yield {
"audio_filepath": str(audio_af),
"duration": aud_seg.duration_seconds,
"text": transcript,
}
asr_manifest_writer(out_dir / "manifest.json", data_gen())
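# Usage sketch (illustration only; the paths below are made up): typer commands are plain
# callables, so dataset generation can also be driven from code. With append_raw=True each
# manifest "text" may hold "<decoded>|<raw>", matching the commented-out split above.
asr_dataset(Path("recordings"), Path("datasets/slu_num"), model="slu_num_wav2vec2")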

View File

@@ -0,0 +1,32 @@
from pathlib import Path
import typer
from ...utils.audio import remove_if_invalid, copy_channel_to
import shutil
app = typer.Typer()
@app.callback()
def process():
"""
clean sub commands
"""
@app.command()
def remove_invalid(audio_dir: Path, out_dir: Path):
shutil.copytree(audio_dir, out_dir, dirs_exist_ok=True)
aud_files = list(out_dir.glob("*.mp3")) + list(out_dir.glob("*.wav"))
for af in aud_files:
remove_if_invalid(af)
@app.command()
def extract_channel(audio_dir: Path, out_dir: Path, channel="left"):
# shutil.copytree(audio_dir, out_dir, dirs_exist_ok=True)
out_dir.mkdir(exist_ok=True, parents=True)
aud_files = list(audio_dir.glob("*.mp3")) + list(audio_dir.glob("*.wav"))
for af in aud_files:
out_af = out_dir / af.relative_to(audio_dir)
copy_channel_to(af, out_af, channel)
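# Usage sketch (illustration only; the paths below are made up): drop undecodable files
# first, then keep a single channel of the remaining recordings.
remove_invalid(Path("raw_calls"), Path("calls_clean"))
extract_channel(Path("calls_clean"), Path("calls_left"), channel="left")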

View File

@@ -424,14 +424,6 @@ def ui_dump_manifest_writer(dataset_dir, asr_data_source, verbose=False):
asr_manifest = dataset_dir / Path("manifest.json")
asr_manifest_writer(asr_manifest, dump_data, verbose=verbose)
# with asr_manifest.open("w") as mf:
# print(f"writing manifest to {asr_manifest}")
# for d in dump_data:
# rel_data_path = d["audio_path"]
# audio_dur = d["duration"]
# transcript = d["text"]
# manifest = manifest_str(str(rel_data_path), audio_dur, transcript)
# mf.write(manifest)
ui_dump_file = dataset_dir / Path("ui_dump.json")
ExtendedPath(ui_dump_file).write_json({"data": dump_data}, verbose=verbose)
return num_datapoints

View File

@@ -1,5 +1,6 @@
import sys
from io import BytesIO
from pathlib import Path
from .lazy_import import lazy_module, lazy_callable
@@ -14,6 +15,27 @@ write = lazy_callable("scipy.io.wavfile.write")
# import numpy as np
def remove_if_invalid(af: Path):
# audio_dir.glob('*.wav')
# aud_files = list(audio_dir.glob("*.mp3")) + list(audio_dir.glob("*.wav"))
# for af in aud_files:
try:
pydub.AudioSegment.from_file(af)
except pydub.exceptions.CouldntDecodeError:
print(f"removing invalid file {af}")
af.unlink()
def copy_channel_to(i_af: Path, o_af: Path, channel):
i_af_seg = pydub.AudioSegment.from_file(i_af)
if i_af_seg.channels > 1:
left, right = i_af_seg.split_to_mono()
channel_seg = left if channel == "left" else right
else:
channel_seg = i_af_seg
channel_seg.export(o_af, format="wav")
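# Usage sketch (illustration only; the file names are made up): keep just the left channel
# of a stereo call recording; mono input passes through unchanged, as handled above.
copy_channel_to(Path("call.wav"), Path("call_left.wav"), channel="left")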
def audio_seg_to_wav_bytes(aud_seg):
b = BytesIO()
aud_seg.export(b, format="wav")

View File

@@ -3,8 +3,8 @@ from tqdm import tqdm
def parallel_apply(fn, iterable, workers=8, pool="thread", verbose=True):
# warm-up
fn(iterable[0])
# warm-up (doesn't work when fn has conditionals that don't follow the hot path)
# fn(iterable[0])
if pool == "thread":
with ThreadPoolExecutor(max_workers=workers) as exe:
if verbose:
@@ -37,5 +37,10 @@ def parallel_apply(fn, iterable, workers=8, pool="thread", verbose=True):
return result
else:
return [res for res in exe.map(fn, iterable)]
elif pool == "none":
if verbose:
return list(map(fn, tqdm(iterable)))
else:
return list(map(fn, iterable))
else:
raise Exception(f"unsupported pool type - {pool}")
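# Usage sketch (illustration only; fn and the iterable are placeholders): pool="none" maps
# fn over the items sequentially in the calling thread, which helps when debugging or when
# fn is not thread-safe.
results = parallel_apply(len, ["a", "bc"], pool="none", verbose=False)  # -> [1, 2]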

View File

@@ -19,7 +19,8 @@ np = lazy_module("numpy")
app = typer.Typer()
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
@@ -43,13 +44,18 @@ def transcribe_rpyc_gen(asr_host=ASR_RPYC_HOST, asr_port=ASR_RPYC_PORT):
raise Exception("env-var JASPER_ASR_RPYC_HOST invalid")
def audio_prep(aud_seg):
asr_seg = aud_seg.set_channels(1).set_sample_width(2).set_frame_rate(16000)
asr_seg = (
aud_seg.set_channels(1).set_sample_width(2).set_frame_rate(16000)
)
af = BytesIO()
asr_seg.export(af, format="wav")
input_audio_bytes = af.getvalue()
return input_audio_bytes
return asr.transcribe, audio_prep
def dummy_transcript(aud, append_raw=False):
return asr.transcribe(aud)
return dummy_transcript, audio_prep
def triton_transcribe_grpc_gen(
@@ -60,10 +66,13 @@ def triton_transcribe_grpc_gen(
chunk_msec=5000,
sil_msec=500,
# overlap=False,
append_raw=False,
sep=" ",
):
from tritonclient.utils import np_to_triton_dtype, InferenceServerException
import tritonclient.grpc as grpcclient
# force loading
np.array
sup_meth = ["chunked", "silence", "whole"]
if method not in sup_meth:
@@ -90,10 +99,13 @@ def triton_transcribe_grpc_gen(
response = client.infer(
asr_model, inputs, request_id=str(1), outputs=outputs
)
transcript = response.as_numpy("OUTPUT_TEXT")[0]
outputs = response.as_numpy("OUTPUT_TEXT")
transcript = outputs[0].decode("utf-8")
if len(outputs) > 1 and append_raw:
transcript = transcript + "|" + outputs[1].decode("utf-8")
except InferenceServerException:
transcript = b"[server error]"
return transcript.decode("utf-8")
transcript = "[server error]"
return transcript
def chunked_transcriber(aud_seg):
if method == "silence":
@@ -122,22 +134,34 @@ def triton_transcribe_grpc_gen(
return transcript
def audio_prep(aud_seg):
asr_seg = aud_seg.set_channels(1).set_sample_width(2).set_frame_rate(16000)
asr_seg = (
aud_seg.set_channels(1).set_sample_width(2).set_frame_rate(16000)
)
return asr_seg
whole_transcriber = transcriber if method == "whole" else chunked_transcriber
whole_transcriber = (
transcriber if method == "whole" else chunked_transcriber
)
return whole_transcriber, audio_prep
@app.command()
def file(
audio_file: Path, write_file: bool = False, chunked: bool = True, rpyc: bool = False, model='slu_wav2vec2'
audio_file: Path,
write_file: bool = False,
chunked: bool = False,
rpyc: bool = False,
append_raw: bool = False,
model="slu_num_wav2vec2",
):
aseg = pydub.AudioSegment.from_file(audio_file)
if rpyc:
transcriber, prep = transcribe_rpyc_gen()
else:
transcriber, prep = triton_transcribe_grpc_gen(asr_model=model)
method = "chunked" if chunked else "whole"
transcriber, prep = triton_transcribe_grpc_gen(
asr_model=model, method=method, append_raw=append_raw
)
transcription = transcriber(prep(aseg))
typer.echo(transcription)
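# Usage sketch (illustration only; sample.wav is made up): request the raw transcript next
# to the decoded one and split on the "|" separator used above.
transcriber, prep = triton_transcribe_grpc_gen(
    asr_model="slu_num_wav2vec2", method="whole", append_raw=True
)
text = transcriber(prep(pydub.AudioSegment.from_file("sample.wav")))
decoded, _, raw = text.partition("|")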