1. add asr data generation from audio files and triton-asr results
2. add asr data clean/channel extraction process subcommands 3. add parallel without pool(single thread) 4. include support for raw transcripts from triton-asr resultstegra
parent
e07c7c9caf
commit
af51fe95cb
|
|
@ -26,13 +26,15 @@ from plume.utils import (
|
|||
|
||||
from ...models.wav2vec2.data import app as wav2vec2_app
|
||||
from .generate import app as generate_app
|
||||
from .process import app as process_app
|
||||
|
||||
soundfile = lazy_module("soundfile")
|
||||
pydub = lazy_module("pydub")
|
||||
train_test_split = lazy_callable("sklearn.model_selection.train_test_split")
|
||||
|
||||
app = typer.Typer()
|
||||
app.add_typer(generate_app, name="generate")
|
||||
app.add_typer(generate_app)
|
||||
app.add_typer(process_app)
|
||||
app.add_typer(wav2vec2_app, name="wav2vec2")
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,12 +1,53 @@
|
|||
from pathlib import Path
|
||||
import shutil
|
||||
|
||||
import typer
|
||||
from ...utils.tts import GoogleTTS
|
||||
from plume.utils.lazy_import import lazy_module
|
||||
from plume.utils.tts import GoogleTTS
|
||||
from plume.utils.transcribe import triton_transcribe_grpc_gen
|
||||
from plume.utils.manifest import asr_manifest_writer
|
||||
|
||||
pydub = lazy_module("pydub")
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.callback()
|
||||
def generate():
|
||||
"""
|
||||
generate sub commands
|
||||
"""
|
||||
|
||||
|
||||
@app.command()
|
||||
def tts_dataset(dest_path: Path):
|
||||
tts = GoogleTTS()
|
||||
pass
|
||||
|
||||
|
||||
@app.command()
|
||||
def asr_dataset(audio_dir: Path, out_dir: Path, model="slu_num_wav2vec2"):
|
||||
out_wav_dir = out_dir / "wavs"
|
||||
out_wav_dir.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
def data_gen():
|
||||
aud_files = list(audio_dir.glob("*.mp3")) + list(
|
||||
audio_dir.glob("*.wav")
|
||||
)
|
||||
transcriber, prep = triton_transcribe_grpc_gen(
|
||||
asr_model=model, method="whole", append_raw=True
|
||||
)
|
||||
for af in aud_files:
|
||||
out_af = out_wav_dir / af.name
|
||||
audio_af = out_af.relative_to(out_dir)
|
||||
shutil.copy2(af, out_af)
|
||||
aud_seg = pydub.AudioSegment.from_file(out_af)
|
||||
t_seg = prep(aud_seg)
|
||||
transcript = transcriber(t_seg)
|
||||
# [digit_tscript, raw_tscript] = transcript.split("|")
|
||||
yield {
|
||||
"audio_filepath": str(audio_af),
|
||||
"duration": aud_seg.duration_seconds,
|
||||
"text": transcript,
|
||||
}
|
||||
|
||||
asr_manifest_writer(out_dir / 'manifest.json', data_gen())
|
||||
|
|
|
|||
|
|
@ -0,0 +1,32 @@
|
|||
from pathlib import Path
|
||||
|
||||
import typer
|
||||
from ...utils.audio import remove_if_invalid, copy_channel_to
|
||||
import shutil
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.callback()
|
||||
def process():
|
||||
"""
|
||||
clean sub commands
|
||||
"""
|
||||
|
||||
|
||||
@app.command()
|
||||
def remove_invalid(audio_dir: Path, out_dir: Path):
|
||||
shutil.copytree(audio_dir, out_dir, dirs_exist_ok=True)
|
||||
aud_files = list(out_dir.glob("*.mp3")) + list(out_dir.glob("*.wav"))
|
||||
for af in aud_files:
|
||||
remove_if_invalid(af)
|
||||
|
||||
|
||||
@app.command()
|
||||
def extract_channel(audio_dir: Path, out_dir: Path, channel="left"):
|
||||
# shutil.copytree(audio_dir, out_dir, dirs_exist_ok=True)
|
||||
out_dir.mkdir(exist_ok=True, parents=True)
|
||||
aud_files = list(audio_dir.glob("*.mp3")) + list(audio_dir.glob("*.wav"))
|
||||
for af in aud_files:
|
||||
out_af = out_dir / af.relative_to(audio_dir)
|
||||
copy_channel_to(af, out_af, channel)
|
||||
|
|
@ -424,14 +424,6 @@ def ui_dump_manifest_writer(dataset_dir, asr_data_source, verbose=False):
|
|||
|
||||
asr_manifest = dataset_dir / Path("manifest.json")
|
||||
asr_manifest_writer(asr_manifest, dump_data, verbose=verbose)
|
||||
# with asr_manifest.open("w") as mf:
|
||||
# print(f"writing manifest to {asr_manifest}")
|
||||
# for d in dump_data:
|
||||
# rel_data_path = d["audio_path"]
|
||||
# audio_dur = d["duration"]
|
||||
# transcript = d["text"]
|
||||
# manifest = manifest_str(str(rel_data_path), audio_dur, transcript)
|
||||
# mf.write(manifest)
|
||||
ui_dump_file = dataset_dir / Path("ui_dump.json")
|
||||
ExtendedPath(ui_dump_file).write_json({"data": dump_data}, verbose=verbose)
|
||||
return num_datapoints
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import sys
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
||||
from .lazy_import import lazy_module, lazy_callable
|
||||
|
||||
|
|
@ -14,6 +15,27 @@ write = lazy_callable("scipy.io.wavfile.write")
|
|||
# import numpy as np
|
||||
|
||||
|
||||
def remove_if_invalid(af: Path):
|
||||
# audio_dir.glob('*.wav')
|
||||
# aud_files = list(audio_dir.glob("*.mp3")) + list(audio_dir.glob("*.wav"))
|
||||
# for af in aud_files:
|
||||
try:
|
||||
pydub.AudioSegment.from_file(af)
|
||||
except pydub.exceptions.CouldntDecodeError:
|
||||
print(f"removing invalid file {af}")
|
||||
af.unlink()
|
||||
|
||||
|
||||
def copy_channel_to(i_af: Path, o_af: Path, channel):
|
||||
i_af_seg = pydub.AudioSegment.from_file(i_af)
|
||||
if i_af_seg.channels > 1:
|
||||
left, right = i_af_seg.split_to_mono()
|
||||
channel_seg = left if channel == "left" else right
|
||||
else:
|
||||
channel_seg = i_af_seg
|
||||
channel_seg.export(o_af, format="wav")
|
||||
|
||||
|
||||
def audio_seg_to_wav_bytes(aud_seg):
|
||||
b = BytesIO()
|
||||
aud_seg.export(b, format="wav")
|
||||
|
|
|
|||
|
|
@ -3,8 +3,8 @@ from tqdm import tqdm
|
|||
|
||||
|
||||
def parallel_apply(fn, iterable, workers=8, pool="thread", verbose=True):
|
||||
# warm-up
|
||||
fn(iterable[0])
|
||||
# warm-up (doesn't work there fn conditionals that doesn't follow hot path)
|
||||
# fn(iterable[0])
|
||||
if pool == "thread":
|
||||
with ThreadPoolExecutor(max_workers=workers) as exe:
|
||||
if verbose:
|
||||
|
|
@ -37,5 +37,10 @@ def parallel_apply(fn, iterable, workers=8, pool="thread", verbose=True):
|
|||
return result
|
||||
else:
|
||||
return [res for res in exe.map(fn, iterable)]
|
||||
elif pool == "none":
|
||||
if verbose:
|
||||
return list(map(fn, tqdm(iterable)))
|
||||
else:
|
||||
return list(map(fn, iterable))
|
||||
else:
|
||||
raise Exception(f"unsupported pool type - {pool}")
|
||||
|
|
|
|||
|
|
@ -19,7 +19,8 @@ np = lazy_module("numpy")
|
|||
app = typer.Typer()
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -43,13 +44,18 @@ def transcribe_rpyc_gen(asr_host=ASR_RPYC_HOST, asr_port=ASR_RPYC_PORT):
|
|||
raise Exception("env-var JASPER_ASR_RPYC_HOST invalid")
|
||||
|
||||
def audio_prep(aud_seg):
|
||||
asr_seg = aud_seg.set_channels(1).set_sample_width(2).set_frame_rate(16000)
|
||||
asr_seg = (
|
||||
aud_seg.set_channels(1).set_sample_width(2).set_frame_rate(16000)
|
||||
)
|
||||
af = BytesIO()
|
||||
asr_seg.export(af, format="wav")
|
||||
input_audio_bytes = af.getvalue()
|
||||
return input_audio_bytes
|
||||
|
||||
return asr.transcribe, audio_prep
|
||||
def dummy_transcript(aud, append_raw=False):
|
||||
return asr.transcribe(aud)
|
||||
|
||||
return dummy_transcript, audio_prep
|
||||
|
||||
|
||||
def triton_transcribe_grpc_gen(
|
||||
|
|
@ -60,10 +66,13 @@ def triton_transcribe_grpc_gen(
|
|||
chunk_msec=5000,
|
||||
sil_msec=500,
|
||||
# overlap=False,
|
||||
append_raw=False,
|
||||
sep=" ",
|
||||
):
|
||||
from tritonclient.utils import np_to_triton_dtype, InferenceServerException
|
||||
import tritonclient.grpc as grpcclient
|
||||
# force loading
|
||||
np.array
|
||||
|
||||
sup_meth = ["chunked", "silence", "whole"]
|
||||
if method not in sup_meth:
|
||||
|
|
@ -90,10 +99,13 @@ def triton_transcribe_grpc_gen(
|
|||
response = client.infer(
|
||||
asr_model, inputs, request_id=str(1), outputs=outputs
|
||||
)
|
||||
transcript = response.as_numpy("OUTPUT_TEXT")[0]
|
||||
outputs = response.as_numpy("OUTPUT_TEXT")
|
||||
transcript = outputs[0].decode("utf-8")
|
||||
if len(outputs) > 1 and append_raw:
|
||||
transcript = transcript + "|" + outputs[1].decode("utf-8")
|
||||
except InferenceServerException:
|
||||
transcript = b"[server error]"
|
||||
return transcript.decode("utf-8")
|
||||
transcript = "[server error]"
|
||||
return transcript
|
||||
|
||||
def chunked_transcriber(aud_seg):
|
||||
if method == "silence":
|
||||
|
|
@ -122,22 +134,34 @@ def triton_transcribe_grpc_gen(
|
|||
return transcript
|
||||
|
||||
def audio_prep(aud_seg):
|
||||
asr_seg = aud_seg.set_channels(1).set_sample_width(2).set_frame_rate(16000)
|
||||
asr_seg = (
|
||||
aud_seg.set_channels(1).set_sample_width(2).set_frame_rate(16000)
|
||||
)
|
||||
return asr_seg
|
||||
|
||||
whole_transcriber = transcriber if method == "whole" else chunked_transcriber
|
||||
whole_transcriber = (
|
||||
transcriber if method == "whole" else chunked_transcriber
|
||||
)
|
||||
return whole_transcriber, audio_prep
|
||||
|
||||
|
||||
@app.command()
|
||||
def file(
|
||||
audio_file: Path, write_file: bool = False, chunked: bool = True, rpyc: bool = False, model='slu_wav2vec2'
|
||||
audio_file: Path,
|
||||
write_file: bool = False,
|
||||
chunked: bool = False,
|
||||
rpyc: bool = False,
|
||||
append_raw: bool = False,
|
||||
model="slu_num_wav2vec2",
|
||||
):
|
||||
aseg = pydub.AudioSegment.from_file(audio_file)
|
||||
if rpyc:
|
||||
transcriber, prep = transcribe_rpyc_gen()
|
||||
else:
|
||||
transcriber, prep = triton_transcribe_grpc_gen(asr_model=model)
|
||||
method = "chunked" if chunked else "whole"
|
||||
transcriber, prep = triton_transcribe_grpc_gen(
|
||||
asr_model=model, method=method, append_raw=append_raw
|
||||
)
|
||||
transcription = transcriber(prep(aseg))
|
||||
|
||||
typer.echo(transcription)
|
||||
|
|
|
|||
Loading…
Reference in New Issue