diff --git a/.flake8 b/.flake8 index 170a050..6e6e140 100644 --- a/.flake8 +++ b/.flake8 @@ -1,4 +1,4 @@ [flake8] exclude = docs ignore = E203, W503 -max-line-length = 119 +max-line-length = 79 diff --git a/notebooks/Diarization.ipynb b/notebooks/Diarization.ipynb new file mode 100644 index 0000000..de5e9b7 --- /dev/null +++ b/notebooks/Diarization.ipynb @@ -0,0 +1,225 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 11, + "id": "808d647c-deef-41ec-8a69-06f7a72689cb", + "metadata": {}, + "outputs": [], + "source": [ + "import torch" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e49e45a1-2133-487c-a8bc-69caed303074", + "metadata": {}, + "outputs": [], + "source": [ + "#!pip install pyannote.audio==1.1.1\n", + "#!pip install pyannote.core[notebook]\n", + "#!pip install pyannote.pipeline\n", + "#!pip install pyannote.core\n", + "!pip install ipywidgets" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "59437acd-b487-4531-bb4e-a4c216a67a29", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Using cache found in /home/reves/.cache/torch/hub/pyannote_pyannote-audio_master\n", + "Using cache found in /home/reves/.cache/torch/hub/pyannote_pyannote-audio_master\n", + "Using cache found in /home/reves/.cache/torch/hub/pyannote_pyannote-audio_master\n", + "Using cache found in /home/reves/.cache/torch/hub/pyannote_pyannote-audio_master\n", + "/home/reves/plume-asr/.direnv/python-3.8.10/lib/python3.8/site-packages/pyannote/audio/embedding/approaches/arcface_loss.py:170: FutureWarning: The 's' parameter is deprecated in favor of 'scale', and will be removed in a future release\n", + " warnings.warn(msg, FutureWarning)\n", + "/home/reves/plume-asr/.direnv/python-3.8.10/lib/python3.8/site-packages/pyannote/audio/features/pretrained.py:156: UserWarning: Model was trained with 4s chunks and is applied on 2s chunks. 
This might lead to sub-optimal results.\n", + " warnings.warn(msg)\n", + "Using cache found in /home/reves/.cache/torch/hub/pyannote_pyannote-audio_master\n" + ] + } + ], + "source": [ + "#pipeline = torch.hub.load('pyannote/pyannote-audio', 'dia_ami')\n", + "pipeline = torch.hub.load('pyannote/pyannote-audio', 'dia')" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "5f69fda5-d94d-4c6f-854d-fd2788ff5bc5", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 3.94 s, sys: 179 ms, total: 4.12 s\n", + "Wall time: 4.12 s\n" + ] + } + ], + "source": [ + "%time diarization = pipeline({'audio': '/home/reves/test-dt.wav'})" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "93eb8203-af2f-4905-90ec-7e667200e647", + "metadata": {}, + "outputs": [], + "source": [ + "for turn, _, speaker in diarization.itertracks(yield_label=True):\n", + " print(f'Speaker \"{speaker}\" speaks between t={turn.start:.1f}s and t={turn.end:.1f}s.')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bc0075a9-d089-46a9-bf0c-47ff4a158d8f", + "metadata": {}, + "outputs": [], + "source": [ + "from ipywidgets import Audio" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "efdd69fb-06f3-4eee-b167-a73496ebbdca", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "Audio(value=open('/home/reves/test-dt.wav','rb').read(),format=\"wav\")\n", + "# test = Audio.from_file('/home/reves/test-dt.wav', autoplay=False)\n", + "# test.play()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "9af87675-eade-4b5b-aa22-8b877202102a", + "metadata": {}, + "outputs": [], + "source": [ + "from plume.utils.transcribe import chunk_transcribe_meta_gen, transcribe_rpyc_gen\n", + "base_transcriber, base_prep = transcribe_rpyc_gen()\n", + "transcriber, prep = chunk_transcribe_meta_gen(\n", + " base_transcriber, base_prep, method=\"chunked\")" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "0fef3f35-7143-49c6-8cd2-862b5d2b7344", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import pydub\n", + "audio_file = '/home/reves/test-dt.wav'\n", + "aseg = pydub.AudioSegment.from_file(audio_file)\n", + "aseg" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "34cbea05-4c57-4a3c-8832-5677d2486231", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "#Agent[1.2s-7.4s]: வணக்கம் சார் என்ளடைய நன்ுன் நகந்துக்கு சுலப்கப்பட்டி ரெச்சோரும் சல் எ்ன செலவு கன்றதுக்காவர் பா் \n", + "#Customer[7.4s-8.5s]: யர ர் \n", + "#Agent[8.8s-13.6s]: சரனால பிண்டிக்கல் தலப்பாகட்டி ரெஸ்டோரன்ட்ருந்து ஃபீட்பாக் கலெக்ட் வண்றதுக்காக கால் பண்ணிருக்கோம் ச \n", + "#Customer[13.9s-14.7s]: அ்கலு மேடம் \n", + "#Agent[15.4s-19.7s]: சர் இந்த கால் வந்து ரெண்டு ண்து மூனு நிமிஷம் நடிக்கும் பேசலாம்வாக செர்ன் தெரிலமா \n", + "#Customer[20.2s-22.8s]: இல்ல அவள வேலைப \n", + "#Customer[23.3s-24.1s]: சரிாச்சா \n", + "#Agent[24.7s-33.2s]: எச்ச ஃபீஸ்பாக் கலெக்ட் பண்றதுக்காக கால் பண்ணிருக்கோம் உங்க பேர் ரமேஷன்லா படேஷன் குவாலிட்டி மட்டு வி்சலி் க பு்பப்பறப்பா் செயப்படு் \n", + "#Agent[33.5s-37.8s]: நீங்க டைமன் பணிரந்த் அனுபவம் எப்டி எந்த்க்கரக்்சசி்டங் \n", + "#Customer[38.2s-41.2s]: ஆஹஆஹ் நல்லாருக்கேன் மேடம்்ுடி \n", + "#Customer[42.0s-45.5s]: நன்காசசாகடன் இந்ல எனக்கு 
ச்ிப்ணிநான்பகே்் \n", + "#Customer[46.1s-46.9s]: ஓகேயவநல்ா் \n", + "#Agent[47.4s-56.7s]: இ்தரெஸ்கரண்டுக்கு மொத்தம் ஐந்து மதுப்பெண்கள்ர்கறைசர் ஒன்ணு குறைவானது ஐந்ததிலும் ங்ளுட அனுபவத்து வச்சு எுத்கு நீங்க ஐந்து கைந்து குடுத்திருக்கீங்க இந்த மதிப்பெண்கள் உங்களல குடுக்கப்பட்டதீங்களா \n", + "#Customer[58.0s-58.4s]: என்ன மேடம் \n", + "#Agent[58.9s-68.1s]: அதாவது நேத்து நீங்க ரேட்டிங்பை்கு பைல் மிச் ரெஸ்ட்ரோன்ஸ்ல ப்கிஙகீங்களா சார் வரேட்டிங் சகம் மித்தநன்றி்ச வ்மேஷன் மித் க \n", + "#Customer[68.1s-72.8s]: மேடம்என் நம்ப தரவட ஓர் சி இரங்குத்தம்பாத்துருவான கஷ்சப்பட்டங்க எப்டி பண்ணம் சாப்டுவாங்க \n", + "#Customer[73.8s-78.4s]: சாபால் நல்லாருந்து என்ன பேசிும் ல்லா இல்லதாு எதனத்தம்ப தரனா \n", + "#Customer[78.6s-81.0s]: சொல்லுாசி \n", + "#Agent[82.7s-85.6s]: சோதலை அதிகமாருக்கு நிலங்க \n", + "#Customer[86.1s-88.6s]: போ்காொ்சம் வேலை யரக்குன்டா ஈவ்னிங் தூரும் \n", + "#Customer[89.6s-90.2s]: உ \n" + ] + } + ], + "source": [ + "for turn, _, speaker in diarization.itertracks(yield_label=True):\n", + " #print(f'Speaker \"{speaker}\" speaks between t={turn.start:.1f}s and t={turn.end:.1f}s.')\n", + " speaker_label = \"Agent\" if speaker == \"B\" else \"Customer\"\n", + " transcript = transcriber(prep(aseg[turn.start*1000:turn.end*1000]))\n", + " print(f'#{speaker_label}[{turn.start:.1f}s-{turn.end:.1f}s]: {transcript}')\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f4ea9ee1-cc14-4e00-8433-45d416298af2", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/setup.py b/setup.py index edba725..8821295 100644 --- a/setup.py +++ b/setup.py @@ -76,7 +76,7 @@ extra_requirements = { "pymongo~=3.10.1", "matplotlib~=3.2.1", "pydub~=0.24.0", - "streamlit~=0.58.0", + "streamlit~=1.0.0", "natural~=0.2.0", "stringcase~=1.2.0", "google-cloud-speech~=1.3.1", @@ -85,6 +85,10 @@ extra_requirements = { "pyspellchecker~=0.6.2", "google-cloud-texttospeech~=1.0.1", "rangehttpserver~=1.2.0", + ], + "dev": [ + "jupyterlab~=3.1.18", + "ipykernel~=6.4.1", ], "crypto": ["cryptography~=3.4.7"], "train": ["torchaudio~=0.6.0", "torch-stft~=0.1.4"], diff --git a/src/plume/cli/data/generate.py b/src/plume/cli/data/generate.py index 464346b..8056c60 100644 --- a/src/plume/cli/data/generate.py +++ b/src/plume/cli/data/generate.py @@ -1,11 +1,17 @@ from pathlib import Path import shutil - +from tqdm import tqdm import typer from plume.utils.lazy_import import lazy_module from plume.utils.tts import GoogleTTS -from plume.utils.transcribe import triton_transcribe_grpc_gen +from plume.utils.transcribe import ( + triton_transcribe_grpc_gen, + chunk_transcribe_meta_gen, + transcribe_rpyc_gen, +) from plume.utils.manifest import asr_manifest_writer +from plume.utils.diarize import diarize_audio_gen +from plume.utils.extended_path import ExtendedPath pydub = lazy_module("pydub") app = typer.Typer() @@ -50,4 +56,46 @@ def asr_dataset(audio_dir: Path, out_dir: Path, model="slu_num_wav2vec2"): "text": transcript, } - asr_manifest_writer(out_dir / 'manifest.json', data_gen()) + asr_manifest_writer(out_dir / "manifest.json", data_gen()) + + +@app.command() +def 
mono_diarize_asr_dataset(audio_dir: Path, out_dir: Path): + out_wav_dir = out_dir / "wavs" + out_wav_dir.mkdir(exist_ok=True, parents=True) + diarize_audio = diarize_audio_gen() + + def data_gen(): + aud_files = list(audio_dir.glob("*/*.mp3")) + list( + audio_dir.glob("*/*.wav") + ) + diameta = ExtendedPath(out_dir / "diameta.json") + base_transcriber, base_prep = transcribe_rpyc_gen() + transcriber, prep = chunk_transcribe_meta_gen( + base_transcriber, base_prep, method="chunked" + ) + + diametadata = [] + for af in tqdm(aud_files): + try: + # raise Exception("Test") + for dres in diarize_audio(af): + sample_fname = dres.pop("sample_fname") + out_af = out_wav_dir / sample_fname + wav_bytes = dres.pop("wav") + out_af.write_bytes(wav_bytes) + audio_af = out_af.relative_to(out_dir) + aud_seg = dres.pop("wavseg") + t_seg = prep(aud_seg) + transcript = transcriber(t_seg) + diametadata.append(dres) + yield { + "audio_filepath": str(audio_af), + "duration": aud_seg.duration_seconds, + "text": transcript, + } + except Exception as e: + print(f'error diarizing/transcribing {af} - {e}') + diameta.write_json(diametadata) + + asr_manifest_writer(out_dir / "manifest.json", data_gen()) diff --git a/src/plume/models/pyann-dia/test.py b/src/plume/models/pyann-dia/test.py new file mode 100644 index 0000000..e69de29 diff --git a/src/plume/models/wav2vec2_transformers/asr.py b/src/plume/models/wav2vec2_transformers/asr.py index c8d9d59..08255f9 100644 --- a/src/plume/models/wav2vec2_transformers/asr.py +++ b/src/plume/models/wav2vec2_transformers/asr.py @@ -13,23 +13,38 @@ class Wav2Vec2TransformersASR(object): """docstring for Wav2Vec2TransformersASR.""" def __init__(self, model_dir): - super(Wav2Vec2TransformersASR, self).__init__() + # super(Wav2Vec2TransformersASR, self).__init__() + self.device = "cuda:1" if torch.cuda.is_available() else "cpu" + # sd = torch.load( + # model_dir / "pytorch_model.bin", map_location=self.device + # ) + # self.processor = Wav2Vec2Processor.from_pretrained( + # model_dir, state_dict=sd + # ) + # self.model = Wav2Vec2ForCTC.from_pretrained(model_dir, state_dict=sd).to(self.device) self.processor = Wav2Vec2Processor.from_pretrained(model_dir) - self.model = Wav2Vec2ForCTC.from_pretrained(model_dir) + self.model = Wav2Vec2ForCTC.from_pretrained(model_dir).to(self.device) def transcribe(self, audio_data): aud_f = BytesIO(audio_data) # net_input = {} speech_data, _ = sf.read(aud_f) input_values = self.processor( - speech_data, return_tensors="pt", padding="longest" - ).input_values # Batch size 1 + speech_data, + return_tensors="pt", + padding="longest", + sampling_rate=16000, + ).input_values.to( + self.device + ) # Batch size 1 # retrieve logits + #print(f"audio:{speech_data.shape} processed:{input_values.shape}") logits = self.model(input_values).logits - + #print(f"logit shape:{logits.shape}") # take argmax and decode predicted_ids = torch.argmax(logits, dim=-1) - + #print(f"predicted_ids shape:{predicted_ids.shape}") transcription = self.processor.batch_decode(predicted_ids)[0] - return transcription + result = transcription.replace('<unk>', '') # drop stray <unk> tokens + return result diff --git a/src/plume/models/wav2vec2_transformers/serve.py b/src/plume/models/wav2vec2_transformers/serve.py index 2feed4e..a511f46 100644 --- a/src/plume/models/wav2vec2_transformers/serve.py +++ b/src/plume/models/wav2vec2_transformers/serve.py @@ -1,5 +1,5 @@ import os -import logging +# import logging from pathlib import Path # from rpyc.utils.server import ThreadedServer @@ -10,6 +10,10 @@ from plume.utils import 
lazy_callable # from plume.models.wav2vec2_transformers.asr import Wav2Vec2TransformersASR # from .asr import Wav2Vec2ASR +# logging.basicConfig( +# level=logging.INFO, +# format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", +# ) ThreadedServer = lazy_callable("rpyc.utils.server.ThreadedServer") Wav2Vec2TransformersASR = lazy_callable( @@ -41,13 +45,12 @@ app = typer.Typer() def rpyc_dir( model_dir: Path, port: int = int(os.environ.get("ASR_RPYC_PORT", "8044")) ): + typer.echo("loading asr model...") w2vasr = Wav2Vec2TransformersASR(model_dir) + typer.echo("loaded asr model") service = ASRService(w2vasr) - logging.basicConfig( - level=logging.INFO, - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", - ) - logging.info("starting asr server...") + + typer.echo(f"serving asr on :{port}...") t = ThreadedServer(service, port=port) t.start() diff --git a/src/plume/ui/annotation.py b/src/plume/ui/annotation.py index 04c2a17..f04cba6 100644 --- a/src/plume/ui/annotation.py +++ b/src/plume/ui/annotation.py @@ -46,7 +46,10 @@ def main(data_dir: Path, dump_fname: Path = "ui_dump.json", task_id: str = ""): st.title(f"ASR Validation - # {task_uid}") st.markdown(f"{sample_no+1} of {len(asr_data)} : **{sample['text']}**") new_sample = st.number_input( - "Go To Sample:", value=sample_no + 1, min_value=1, max_value=len(asr_data) + "Go To Sample:", + value=sample_no + 1, + min_value=1, + max_value=len(asr_data), ) if new_sample != sample_no + 1: st.update_cursor(new_sample - 1) @@ -60,7 +63,9 @@ def main(data_dir: Path, dump_fname: Path = "ui_dump.json", task_id: str = ""): show_key(sample, "asr_wer", trail="%") show_key(sample, "correct_candidate") - st.sidebar.image((data_dir / Path(sample["plot_path"])).read_bytes()) + if "plot_path" in sample: + st.sidebar.image((data_dir / Path(sample["plot_path"])).read_bytes()) + st.audio((data_dir / Path(sample["audio_path"])).open("rb")) # set default to text corrected = sample["text"] @@ -78,27 +83,37 @@ def main(data_dir: Path, dump_fname: Path = "ui_dump.json", task_id: str = ""): corrected = "" if st.button("Submit"): st.update_entry( - sample["utterance_id"], {"status": selected, "correction": corrected} + sample["utterance_id"], + {"status": selected, "correction": corrected}, ) st.update_cursor(sample_no + 1) if correction_entry: status = correction_entry["value"]["status"] correction = correction_entry["value"]["correction"] - st.markdown(f"Your Response: **{status}** Correction: **{correction}**") + st.markdown( + f"Your Response: **{status}** Correction: **{correction}**" + ) text_sample = st.text_input("Go to Text:", value="") if text_sample != "": - candidates = [i for (i, p) in enumerate(asr_data) if p["text"] == text_sample] + candidates = [ + i for (i, p) in enumerate(asr_data) if p["text"] == text_sample + ] if len(candidates) > 0: st.update_cursor(candidates[0]) - real_idx = st.number_input( - "Go to real-index", - value=sample["real_idx"], - min_value=0, - max_value=len(asr_data) - 1, - ) - if real_idx != int(sample["real_idx"]): - idx = [i for (i, p) in enumerate(asr_data) if p["real_idx"] == real_idx][0] - st.update_cursor(idx) + if "real_idx" in sample: + real_idx = st.number_input( + "Go to real-index", + value=sample["real_idx"], + min_value=0, + max_value=len(asr_data) - 1, + ) + if real_idx != int(sample["real_idx"]): + idx = [ + i + for (i, p) in enumerate(asr_data) + if p["real_idx"] == real_idx + ][0] + st.update_cursor(idx) if __name__ == "__main__": diff --git a/src/plume/ui/preview.py b/src/plume/ui/preview.py 
index 89af3ad..6250931 100644 --- a/src/plume/ui/preview.py +++ b/src/plume/ui/preview.py @@ -27,7 +27,10 @@ def main(manifest: Path): st.title("ASR Manifest Preview") st.markdown(f"{sample_no+1} of {len(asr_data)} : **{sample['text']}**") new_sample = st.number_input( - "Go To Sample:", value=sample_no + 1, min_value=1, max_value=len(asr_data) + "Go To Sample:", + value=sample_no + 1, + min_value=1, + max_value=len(asr_data), ) if new_sample != sample_no + 1: st.update_cursor(new_sample - 1) diff --git a/src/plume/utils/__init__.py b/src/plume/utils/__init__.py index 87b2237..f0f5d86 100644 --- a/src/plume/utils/__init__.py +++ b/src/plume/utils/__init__.py @@ -1,7 +1,6 @@ import io import os import re -import json import wave import logging import subprocess diff --git a/src/plume/utils/diarize.py b/src/plume/utils/diarize.py new file mode 100644 index 0000000..2f9ae6b --- /dev/null +++ b/src/plume/utils/diarize.py @@ -0,0 +1,65 @@ +from pathlib import Path +from plume.utils import lazy_module +from plume.utils.audio import audio_seg_to_wav_bytes + +pydub = lazy_module('pydub') +torch = lazy_module('torch') + + +def transform_audio(file_location, path_to_save): + audio_seg = ( + pydub.AudioSegment.from_file(file_location) + .set_frame_rate(16000) + .set_sample_width(2) + ) + audio_seg.export(path_to_save, format="wav") + + +def gen_diarizer(): + pipeline = torch.hub.load("pyannote/pyannote-audio", "dia") + + def _diarizer(audio_path): + return pipeline({"audio": audio_path}) + + return _diarizer + + +# base_transcriber, base_prep = transcribe_rpyc_gen() +# transcriber, prep = chunk_transcribe_meta_gen( +# base_transcriber, base_prep, method="chunked") + +# diarizer = gen_diarizer() + + +def diarize_audio_gen(): + diarizer = gen_diarizer() + + def _diarize_audio(audio_path: Path): + aseg = ( + pydub.AudioSegment.from_file(audio_path) + .set_frame_rate(16000) + .set_sample_width(2) + .set_channels(1) + ) + aseg.export("/tmp/temp.wav", format="wav") + diarization = diarizer("/tmp/temp.wav") + for n, (turn, _, speaker) in enumerate( + diarization.itertracks(yield_label=True) + ): + # speaker_label = "Agent" if speaker == "B" else "Customer" + turn_seg = aseg[turn.start * 1000 : turn.end * 1000] + sample_fname = ( + audio_path.stem + "_" + str(n) + ".wav" + ) + yield { + "speaker": speaker, + "wav": audio_seg_to_wav_bytes(turn_seg), + "wavseg": turn_seg, + "start": turn.start, + "end": turn.end, + "turnidx": n, + "filename": audio_path.name, + "sample_fname": sample_fname + } + + return _diarize_audio diff --git a/src/plume/utils/transcribe.py b/src/plume/utils/transcribe.py index a4661bc..3ddcc58 100644 --- a/src/plume/utils/transcribe.py +++ b/src/plume/utils/transcribe.py @@ -47,12 +47,15 @@ def transcribe_rpyc_gen(asr_host=ASR_RPYC_HOST, asr_port=ASR_RPYC_PORT): asr_seg = ( aud_seg.set_channels(1).set_sample_width(2).set_frame_rate(16000) ) + # af = BytesIO() + # asr_seg.export(af, format="wav") + # input_audio_bytes = af.getvalue() + return asr_seg + + def dummy_transcript(asr_seg, append_raw=False): af = BytesIO() asr_seg.export(af, format="wav") - input_audio_bytes = af.getvalue() - return input_audio_bytes - - def dummy_transcript(aud, append_raw=False): + aud = af.getvalue() return asr.transcribe(aud) return dummy_transcript, audio_prep @@ -147,6 +150,51 @@ +def chunk_transcribe_meta_gen( + transcriber, + prep, + method="chunked", + chunk_msec=5000, + sil_msec=500, + sep=" ", +): + sup_meth = ["chunked", "silence", "whole"] + if method not in sup_meth: + meths = "|".join(sup_meth) + raise Exception(f"unsupported method {method}. pick one of {meths}") + + def chunked_transcriber(aud_seg): + if method == "silence": + sil_chunks = pydub.silence.split_on_silence( + aud_seg, + min_silence_len=sil_msec, + silence_thresh=-50, + keep_silence=500, + ) + chunks = [sc for c in sil_chunks for sc in c[::chunk_msec]] + else: + chunks = aud_seg[::chunk_msec] + # if overlap: + # chunks = [ + # aud_seg[start, end] + # for start, end in range(0, int(aud_seg.duration_seconds * 1000, 1000)) + # ] + # pass + transcript_list = [] + sil_pad = pydub.AudioSegment.silent(duration=sil_msec) + for seg in chunks: + t_seg = sil_pad + seg + sil_pad + c_transcript = transcriber(t_seg) + transcript_list.append(c_transcript) + transcript = sep.join(transcript_list) + return transcript + whole_transcriber = ( + transcriber if method == "whole" else chunked_transcriber + ) + + return whole_transcriber, prep + + @app.command() def audio_file( audio_file: Path, @@ -157,13 +205,15 @@ model="slu_num_wav2vec2", ): aseg = pydub.AudioSegment.from_file(audio_file) + method = "chunked" if chunked else "whole" if rpyc: - transcriber, prep = transcribe_rpyc_gen() + base_transcriber, base_prep = transcribe_rpyc_gen() else: - method = "chunked" if chunked else "whole" - transcriber, prep = triton_transcribe_grpc_gen( - asr_model=model, method=method, append_raw=append_raw + base_transcriber, base_prep = triton_transcribe_grpc_gen( + asr_model=model, method='whole', append_raw=append_raw ) + transcriber, prep = chunk_transcribe_meta_gen( + base_transcriber, base_prep, method=method) transcription = transcriber(prep(aseg)) typer.echo(transcription) diff --git a/src/plume/utils/ui_persist.py b/src/plume/utils/ui_persist.py index f050d60..9113b79 100644 --- a/src/plume/utils/ui_persist.py +++ b/src/plume/utils/ui_persist.py @@ -11,10 +11,17 @@ def setup_file_state(st): def current_cursor_fn(): return task_path.read_json()["current_cursor"] + # if "audio_sample_idx" not in st.session_state: + # st.session_state.audio_sample_idx = task_path.read_json()[ + # "current_cursor" + # ] + # return st.session_state.audio_sample_idx def update_cursor_fn(val=0): task_path.write_json({"current_cursor": val}) - rerun() + # rerun() + # st.session_state.audio_sample_idx = val + st.experimental_rerun() st.get_current_cursor = current_cursor_fn st.update_cursor = update_cursor_fn
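
Usage sketch (not part of the diff): a minimal example of how the new helpers compose, with diarize_audio_gen (src/plume/utils/diarize.py) feeding chunk_transcribe_meta_gen (src/plume/utils/transcribe.py). It assumes an ASR rpyc server (the rpyc_dir command in serve.py) is reachable on the default host/port; "call.wav" is a placeholder path.

# hypothetical driver script; all paths are placeholders
from pathlib import Path

from plume.utils.diarize import diarize_audio_gen
from plume.utils.transcribe import (
    chunk_transcribe_meta_gen,
    transcribe_rpyc_gen,
)

diarize_audio = diarize_audio_gen()  # loads the pyannote "dia" pipeline once
base_transcriber, base_prep = transcribe_rpyc_gen()
# with the defaults, each 5s chunk is padded with 500ms of silence and
# the per-chunk transcripts are joined with a single space
transcriber, prep = chunk_transcribe_meta_gen(
    base_transcriber, base_prep, method="chunked"
)

# each yielded dict carries the turn's pydub segment ("wavseg"), raw wav
# bytes ("wav"), speaker label, and start/end offsets in seconds
for turn in diarize_audio(Path("call.wav")):
    text = transcriber(prep(turn["wavseg"]))
    print(f"#{turn['speaker']}[{turn['start']:.1f}s-{turn['end']:.1f}s]: {text}")

This is the same flow the new mono_diarize_asr_dataset command wires together, minus the manifest.json and diameta.json bookkeeping.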