diff --git a/src/plume/cli/data/__init__.py b/src/plume/cli/data/__init__.py index bcf1042..e8de6a1 100644 --- a/src/plume/cli/data/__init__.py +++ b/src/plume/cli/data/__init__.py @@ -26,13 +26,15 @@ from plume.utils import ( from ...models.wav2vec2.data import app as wav2vec2_app from .generate import app as generate_app +from .process import app as process_app soundfile = lazy_module("soundfile") pydub = lazy_module("pydub") train_test_split = lazy_callable("sklearn.model_selection.train_test_split") app = typer.Typer() -app.add_typer(generate_app, name="generate") +app.add_typer(generate_app) +app.add_typer(process_app) app.add_typer(wav2vec2_app, name="wav2vec2") diff --git a/src/plume/cli/data/generate.py b/src/plume/cli/data/generate.py index b1ee129..464346b 100644 --- a/src/plume/cli/data/generate.py +++ b/src/plume/cli/data/generate.py @@ -1,12 +1,53 @@ from pathlib import Path +import shutil import typer -from ...utils.tts import GoogleTTS +from plume.utils.lazy_import import lazy_module +from plume.utils.tts import GoogleTTS +from plume.utils.transcribe import triton_transcribe_grpc_gen +from plume.utils.manifest import asr_manifest_writer +pydub = lazy_module("pydub") app = typer.Typer() +@app.callback() +def generate(): + """ + generate sub commands + """ + + @app.command() def tts_dataset(dest_path: Path): tts = GoogleTTS() pass + + +@app.command() +def asr_dataset(audio_dir: Path, out_dir: Path, model="slu_num_wav2vec2"): + out_wav_dir = out_dir / "wavs" + out_wav_dir.mkdir(exist_ok=True, parents=True) + + def data_gen(): + aud_files = list(audio_dir.glob("*.mp3")) + list( + audio_dir.glob("*.wav") + ) + transcriber, prep = triton_transcribe_grpc_gen( + asr_model=model, method="whole", append_raw=True + ) + for af in aud_files: + out_af = out_wav_dir / af.name + audio_af = out_af.relative_to(out_dir) + shutil.copy2(af, out_af) + aud_seg = pydub.AudioSegment.from_file(out_af) + t_seg = prep(aud_seg) + transcript = transcriber(t_seg) + # [digit_tscript, raw_tscript] = transcript.split("|") + yield { + "audio_filepath": str(audio_af), + "duration": aud_seg.duration_seconds, + "text": transcript, + } + + asr_manifest_writer(out_dir / 'manifest.json', data_gen()) diff --git a/src/plume/cli/data/process.py b/src/plume/cli/data/process.py new file mode 100644 index 0000000..0b230f8 --- /dev/null +++ b/src/plume/cli/data/process.py @@ -0,0 +1,32 @@ +from pathlib import Path + +import typer +from ...utils.audio import remove_if_invalid, copy_channel_to +import shutil + +app = typer.Typer() + + +@app.callback() +def process(): + """ + clean sub commands + """ + + +@app.command() +def remove_invalid(audio_dir: Path, out_dir: Path): + shutil.copytree(audio_dir, out_dir, dirs_exist_ok=True) + aud_files = list(out_dir.glob("*.mp3")) + list(out_dir.glob("*.wav")) + for af in aud_files: + remove_if_invalid(af) + + +@app.command() +def extract_channel(audio_dir: Path, out_dir: Path, channel="left"): + # shutil.copytree(audio_dir, out_dir, dirs_exist_ok=True) + out_dir.mkdir(exist_ok=True, parents=True) + aud_files = list(audio_dir.glob("*.mp3")) + list(audio_dir.glob("*.wav")) + for af in aud_files: + out_af = out_dir / af.relative_to(audio_dir) + copy_channel_to(af, out_af, channel) diff --git a/src/plume/utils/__init__.py b/src/plume/utils/__init__.py index 6c69da8..a3e3646 100644 --- a/src/plume/utils/__init__.py +++ b/src/plume/utils/__init__.py @@ -424,14 +424,6 @@ def ui_dump_manifest_writer(dataset_dir, asr_data_source, verbose=False): asr_manifest = dataset_dir / Path("manifest.json") asr_manifest_writer(asr_manifest, dump_data, verbose=verbose) - # with asr_manifest.open("w") as mf: - # print(f"writing manifest to {asr_manifest}") - # for d in dump_data: - # rel_data_path = d["audio_path"] - # audio_dur = d["duration"] - # transcript = d["text"] - # manifest = manifest_str(str(rel_data_path), audio_dur, transcript) - # mf.write(manifest) ui_dump_file = dataset_dir / Path("ui_dump.json") ExtendedPath(ui_dump_file).write_json({"data": dump_data}, verbose=verbose) return num_datapoints diff --git a/src/plume/utils/audio.py b/src/plume/utils/audio.py index 416b4a2..e7d1934 100644 --- a/src/plume/utils/audio.py +++ b/src/plume/utils/audio.py @@ -1,5 +1,6 @@ import sys from io import BytesIO +from pathlib import Path from .lazy_import import lazy_module, lazy_callable @@ -14,6 +15,27 @@ write = lazy_callable("scipy.io.wavfile.write") # import numpy as np +def remove_if_invalid(af: Path): + # audio_dir.glob('*.wav') + # aud_files = list(audio_dir.glob("*.mp3")) + list(audio_dir.glob("*.wav")) + # for af in aud_files: + try: + pydub.AudioSegment.from_file(af) + except pydub.exceptions.CouldntDecodeError: + print(f"removing invalid file {af}") + af.unlink() + + +def copy_channel_to(i_af: Path, o_af: Path, channel): + i_af_seg = pydub.AudioSegment.from_file(i_af) + if i_af_seg.channels > 1: + left, right = i_af_seg.split_to_mono() + channel_seg = left if channel == "left" else right + else: + channel_seg = i_af_seg + channel_seg.export(o_af, format="wav") + + def audio_seg_to_wav_bytes(aud_seg): b = BytesIO() aud_seg.export(b, format="wav") diff --git a/src/plume/utils/parallel.py b/src/plume/utils/parallel.py index d125de5..9d9773d 100644 --- a/src/plume/utils/parallel.py +++ b/src/plume/utils/parallel.py @@ -3,8 +3,8 @@ from tqdm import tqdm def parallel_apply(fn, iterable, workers=8, pool="thread", verbose=True): - # warm-up - fn(iterable[0]) + # warm-up (doesn't work there fn conditionals that doesn't follow hot path) + # fn(iterable[0]) if pool == "thread": with ThreadPoolExecutor(max_workers=workers) as exe: if verbose: @@ -37,5 +37,10 @@ def parallel_apply(fn, iterable, workers=8, pool="thread", verbose=True): return result else: return [res for res in exe.map(fn, iterable)] + elif pool == "none": + if verbose: + return list(map(fn, tqdm(iterable))) + else: + return list(map(fn, iterable)) else: raise Exception(f"unsupported pool type - {pool}") diff --git a/src/plume/utils/transcribe.py b/src/plume/utils/transcribe.py index 8964b6d..bb02f24 100644 --- a/src/plume/utils/transcribe.py +++ b/src/plume/utils/transcribe.py @@ -19,7 +19,8 @@ np = lazy_module("numpy") app = typer.Typer() logging.basicConfig( - level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", ) logger = logging.getLogger(__name__) @@ -43,13 +44,18 @@ def transcribe_rpyc_gen(asr_host=ASR_RPYC_HOST, asr_port=ASR_RPYC_PORT): raise Exception("env-var JASPER_ASR_RPYC_HOST invalid") def audio_prep(aud_seg): - asr_seg = aud_seg.set_channels(1).set_sample_width(2).set_frame_rate(16000) + asr_seg = ( + aud_seg.set_channels(1).set_sample_width(2).set_frame_rate(16000) + ) af = BytesIO() asr_seg.export(af, format="wav") input_audio_bytes = af.getvalue() return input_audio_bytes - return asr.transcribe, audio_prep + def dummy_transcript(aud, append_raw=False): + return asr.transcribe(aud) + + return dummy_transcript, audio_prep def triton_transcribe_grpc_gen( @@ -60,10 +66,13 @@ def triton_transcribe_grpc_gen( chunk_msec=5000, sil_msec=500, # overlap=False, + append_raw=False, sep=" ", ): from tritonclient.utils import np_to_triton_dtype, InferenceServerException import tritonclient.grpc as grpcclient + # force loading + np.array sup_meth = ["chunked", "silence", "whole"] if method not in sup_meth: @@ -90,10 +99,13 @@ def triton_transcribe_grpc_gen( response = client.infer( asr_model, inputs, request_id=str(1), outputs=outputs ) - transcript = response.as_numpy("OUTPUT_TEXT")[0] + outputs = response.as_numpy("OUTPUT_TEXT") + transcript = outputs[0].decode("utf-8") + if len(outputs) > 1 and append_raw: + transcript = transcript + "|" + outputs[1].decode("utf-8") except InferenceServerException: - transcript = b"[server error]" - return transcript.decode("utf-8") + transcript = "[server error]" + return transcript def chunked_transcriber(aud_seg): if method == "silence": @@ -122,22 +134,34 @@ def triton_transcribe_grpc_gen( return transcript def audio_prep(aud_seg): - asr_seg = aud_seg.set_channels(1).set_sample_width(2).set_frame_rate(16000) + asr_seg = ( + aud_seg.set_channels(1).set_sample_width(2).set_frame_rate(16000) + ) return asr_seg - whole_transcriber = transcriber if method == "whole" else chunked_transcriber + whole_transcriber = ( + transcriber if method == "whole" else chunked_transcriber + ) return whole_transcriber, audio_prep @app.command() def file( - audio_file: Path, write_file: bool = False, chunked: bool = True, rpyc: bool = False, model='slu_wav2vec2' + audio_file: Path, + write_file: bool = False, + chunked: bool = False, + rpyc: bool = False, + append_raw: bool = False, + model="slu_num_wav2vec2", ): aseg = pydub.AudioSegment.from_file(audio_file) if rpyc: transcriber, prep = transcribe_rpyc_gen() else: - transcriber, prep = triton_transcribe_grpc_gen(asr_model=model) + method = "chunked" if chunked else "whole" + transcriber, prep = triton_transcribe_grpc_gen( + asr_model=model, method=method, append_raw=append_raw + ) transcription = transcriber(prep(aseg)) typer.echo(transcription)