From 79aa5e85788070aad688d41fbbf9f8b1f8aa8fb5 Mon Sep 17 00:00:00 2001
From: Malar Kannan
Date: Sun, 24 Oct 2021 01:32:15 +0530
Subject: [PATCH] 1. set flake8 max-line to 79
 2. update streamlit dep to 1.0
 3. add dev optional dep key
 4. implement mono diarized dataset generation script
 5. enable gpu support on asr transformers inference pipeline
 6. use typer logging
 7. clean up annotation ui with everything other than asr-data keys optional (including plots)
 8. implement chunk_transcribe_meta_gen abstraction for asr chunking logic
 9. make ui_persist compatibility change for streamlit 1.0
 10. add diarize commands (bugfix)
 11. add notebooks for diarization

---
 .flake8                                       |   2 +-
 notebooks/Diarization.ipynb                   | 225 ++++++++++++++++++
 setup.py                                      |   6 +-
 src/plume/cli/data/generate.py                |  54 ++++-
 src/plume/models/pyann-dia/test.py            |   0
 src/plume/models/wav2vec2_transformers/asr.py |  29 ++-
 .../models/wav2vec2_transformers/serve.py     |  15 +-
 src/plume/ui/annotation.py                    |  43 ++--
 src/plume/ui/preview.py                       |   5 +-
 src/plume/utils/__init__.py                   |   1 -
 src/plume/utils/diarize.py                    |  65 +++++
 src/plume/utils/transcribe.py                 |  71 +++++-
 src/plume/utils/ui_persist.py                 |   9 +-
 13 files changed, 482 insertions(+), 43 deletions(-)
 create mode 100644 notebooks/Diarization.ipynb
 create mode 100644 src/plume/models/pyann-dia/test.py
 create mode 100644 src/plume/utils/diarize.py

diff --git a/.flake8 b/.flake8
index 170a050..6e6e140 100644
--- a/.flake8
+++ b/.flake8
@@ -1,4 +1,4 @@
 [flake8]
 exclude = docs
 ignore = E203, W503
-max-line-length = 119
+max-line-length = 79
diff --git a/notebooks/Diarization.ipynb b/notebooks/Diarization.ipynb
new file mode 100644
index 0000000..de5e9b7
--- /dev/null
+++ b/notebooks/Diarization.ipynb
@@ -0,0 +1,225 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "id": "808d647c-deef-41ec-8a69-06f7a72689cb",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "e49e45a1-2133-487c-a8bc-69caed303074",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "#!pip install pyannote.audio==1.1.1\n",
+    "#!pip install pyannote.core[notebook]\n",
+    "#!pip install pyannote.pipeline\n",
+    "#!pip install pyannote.core\n",
+    "!pip install ipywidgets"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "id": "59437acd-b487-4531-bb4e-a4c216a67a29",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Using cache found in /home/reves/.cache/torch/hub/pyannote_pyannote-audio_master\n",
+      "Using cache found in /home/reves/.cache/torch/hub/pyannote_pyannote-audio_master\n",
+      "Using cache found in /home/reves/.cache/torch/hub/pyannote_pyannote-audio_master\n",
+      "Using cache found in /home/reves/.cache/torch/hub/pyannote_pyannote-audio_master\n",
+      "/home/reves/plume-asr/.direnv/python-3.8.10/lib/python3.8/site-packages/pyannote/audio/embedding/approaches/arcface_loss.py:170: FutureWarning: The 's' parameter is deprecated in favor of 'scale', and will be removed in a future release\n",
+      "  warnings.warn(msg, FutureWarning)\n",
+      "/home/reves/plume-asr/.direnv/python-3.8.10/lib/python3.8/site-packages/pyannote/audio/features/pretrained.py:156: UserWarning: Model was trained with 4s chunks and is applied on 2s chunks. \nThis might lead to sub-optimal results.\n",
+      "  warnings.warn(msg)\n",
+      "Using cache found in /home/reves/.cache/torch/hub/pyannote_pyannote-audio_master\n"
+     ]
+    }
+   ],
+   "source": [
+    "#pipeline = torch.hub.load('pyannote/pyannote-audio', 'dia_ami')\n",
+    "pipeline = torch.hub.load('pyannote/pyannote-audio', 'dia')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 28,
+   "id": "5f69fda5-d94d-4c6f-854d-fd2788ff5bc5",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "CPU times: user 3.94 s, sys: 179 ms, total: 4.12 s\n",
+      "Wall time: 4.12 s\n"
+     ]
+    }
+   ],
+   "source": [
+    "%time diarization = pipeline({'audio': '/home/reves/test-dt.wav'})"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "93eb8203-af2f-4905-90ec-7e667200e647",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "for turn, _, speaker in diarization.itertracks(yield_label=True):\n",
+    "    print(f'Speaker \"{speaker}\" speaks between t={turn.start:.1f}s and t={turn.end:.1f}s.')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "bc0075a9-d089-46a9-bf0c-47ff4a158d8f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from ipywidgets import Audio"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "efdd69fb-06f3-4eee-b167-a73496ebbdca",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "Audio(value=open('/home/reves/test-dt.wav','rb').read(),format=\"wav\")\n",
+    "# test = Audio.from_file('/home/reves/test-dt.wav', autoplay=False)\n",
+    "# test.play()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "9af87675-eade-4b5b-aa22-8b877202102a",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from plume.utils.transcribe import chunk_transcribe_meta_gen, transcribe_rpyc_gen\n",
+    "base_transcriber, base_prep = transcribe_rpyc_gen()\n",
+    "transcriber, prep = chunk_transcribe_meta_gen(\n",
+    "    base_transcriber, base_prep, method=\"chunked\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "id": "0fef3f35-7143-49c6-8cd2-862b5d2b7344",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "\n",
+       " \n",
+       " "
+      ],
+      "text/plain": [
+       ""
+      ]
+     },
+     "execution_count": 15,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "import pydub\n",
+    "audio_file = '/home/reves/test-dt.wav'\n",
+    "aseg = pydub.AudioSegment.from_file(audio_file)\n",
+    "aseg"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 32,
+   "id": "34cbea05-4c57-4a3c-8832-5677d2486231",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "#Agent[1.2s-7.4s]: வணக்கம் சார் என்ளடைய நன்ுன் நகந்துக்கு சுலப்கப்பட்டி ரெச்சோரும் சல் எ்ன செலவு கன்றதுக்காவர் பா் \n",
+      "#Customer[7.4s-8.5s]: யர ர் \n",
+      "#Agent[8.8s-13.6s]: சரனால பிண்டிக்கல் தலப்பாகட்டி ரெஸ்டோரன்ட்ருந்து ஃபீட்பாக் கலெக்ட் வண்றதுக்காக கால் பண்ணிருக்கோம் ச \n",
+      "#Customer[13.9s-14.7s]: அ்கலு மேடம் \n",
+      "#Agent[15.4s-19.7s]: சர் இந்த கால் வந்து ரெண்டு ண்து மூனு நிமிஷம் நடிக்கும் பேசலாம்வாக செர்ன் தெரிலமா \n",
+      "#Customer[20.2s-22.8s]: இல்ல அவள வேலைப \n",
+      "#Customer[23.3s-24.1s]: சரிாச்சா \n",
+      "#Agent[24.7s-33.2s]: எச்ச ஃபீஸ்பாக் கலெக்ட் பண்றதுக்காக கால் பண்ணிருக்கோம் உங்க பேர் ரமேஷன்லா படேஷன் குவாலிட்டி மட்டு வி்சலி் க பு்பப்பறப்பா் செயப்படு் \n",
+      "#Agent[33.5s-37.8s]: நீங்க டைமன் பணிரந்த் அனுபவம் எப்டி எந்த்க்கரக்்சசி்டங் \n",
+      "#Customer[38.2s-41.2s]: ஆஹஆஹ் நல்லாருக்கேன் மேடம்்ுடி \n",
+      "#Customer[42.0s-45.5s]: நன்காசசாகடன் இந்ல எனக்கு \nச்ிப்ணிநான்பகே்் \n",
+      "#Customer[46.1s-46.9s]: ஓகேயவநல்ா் \n",
+      "#Agent[47.4s-56.7s]: இ்தரெஸ்கரண்டுக்கு மொத்தம் ஐந்து மதுப்பெண்கள்ர்கறைசர் ஒன்ணு குறைவானது ஐந்ததிலும் ங்ளுட அனுபவத்து வச்சு எுத்கு நீங்க ஐந்து கைந்து குடுத்திருக்கீங்க இந்த மதிப்பெண்கள் உங்களல குடுக்கப்பட்டதீங்களா \n",
+      "#Customer[58.0s-58.4s]: என்ன மேடம் \n",
+      "#Agent[58.9s-68.1s]: அதாவது நேத்து நீங்க ரேட்டிங்பை்கு பைல் மிச் ரெஸ்ட்ரோன்ஸ்ல ப்கிஙகீங்களா சார் வரேட்டிங் சகம் மித்தநன்றி்ச வ்மேஷன் மித் க \n",
+      "#Customer[68.1s-72.8s]: மேடம்என் நம்ப தரவட ஓர் சி இரங்குத்தம்பாத்துருவான கஷ்சப்பட்டங்க எப்டி பண்ணம் சாப்டுவாங்க \n",
+      "#Customer[73.8s-78.4s]: சாபால் நல்லாருந்து என்ன பேசிும் ல்லா இல்லதாு எதனத்தம்ப தரனா \n",
+      "#Customer[78.6s-81.0s]: சொல்லுாசி \n",
+      "#Agent[82.7s-85.6s]: சோதலை அதிகமாருக்கு நிலங்க \n",
+      "#Customer[86.1s-88.6s]: போ்காொ்சம் வேலை யரக்குன்டா ஈவ்னிங் தூரும் \n",
+      "#Customer[89.6s-90.2s]: உ \n"
+     ]
+    }
+   ],
+   "source": [
+    "for turn, _, speaker in diarization.itertracks(yield_label=True):\n",
+    "    #print(f'Speaker \"{speaker}\" speaks between t={turn.start:.1f}s and t={turn.end:.1f}s.')\n",
+    "    speaker_label = \"Agent\" if speaker == \"B\" else \"Customer\"\n",
+    "    transcript = transcriber(prep(aseg[turn.start*1000:turn.end*1000]))\n",
+    "    print(f'#{speaker_label}[{turn.start:.1f}s-{turn.end:.1f}s]: {transcript}')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "f4ea9ee1-cc14-4e00-8433-45d416298af2",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.10"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/setup.py b/setup.py
index edba725..8821295 100644
--- a/setup.py
+++ b/setup.py
@@ -76,7 +76,7 @@ extra_requirements = {
         "pymongo~=3.10.1",
         "matplotlib~=3.2.1",
         "pydub~=0.24.0",
-        "streamlit~=0.58.0",
+        "streamlit~=1.0.0",
         "natural~=0.2.0",
         "stringcase~=1.2.0",
         "google-cloud-speech~=1.3.1",
@@ -85,6 +85,10 @@ extra_requirements = {
         "pyspellchecker~=0.6.2",
         "google-cloud-texttospeech~=1.0.1",
         "rangehttpserver~=1.2.0",
+    ],
+    "dev": [
+        "jupyterlab~=3.1.18",
+        "ipykernel~=6.4.1",
     ],
     "crypto": ["cryptography~=3.4.7"],
     "train": ["torchaudio~=0.6.0", "torch-stft~=0.1.4"],
diff --git a/src/plume/cli/data/generate.py b/src/plume/cli/data/generate.py
index 464346b..8056c60 100644
--- a/src/plume/cli/data/generate.py
+++ b/src/plume/cli/data/generate.py
@@ -1,11 +1,17 @@
 from pathlib import Path
 import shutil
-
+from tqdm import tqdm
 import typer
 from plume.utils.lazy_import import lazy_module
 from plume.utils.tts import GoogleTTS
-from plume.utils.transcribe import triton_transcribe_grpc_gen
+from plume.utils.transcribe import (
+    triton_transcribe_grpc_gen,
+    chunk_transcribe_meta_gen,
+    transcribe_rpyc_gen,
+)
 from plume.utils.manifest import asr_manifest_writer
+from plume.utils.diarize import diarize_audio_gen
+from plume.utils.extended_path import ExtendedPath
 
 pydub = lazy_module("pydub")
 app = typer.Typer()
@@ -50,4 +56,46 @@ def asr_dataset(audio_dir: Path, out_dir: Path, model="slu_num_wav2vec2"):
                 "text": transcript,
             }
 
-    asr_manifest_writer(out_dir / 'manifest.json', data_gen())
+    asr_manifest_writer(out_dir / "manifest.json", data_gen())
+
+
+@app.command()
+def mono_diarize_asr_dataset(audio_dir: Path, out_dir: Path):
+    out_wav_dir = out_dir / "wavs"
+    out_wav_dir.mkdir(exist_ok=True, parents=True)
+    diarize_audio = diarize_audio_gen()
+
+    def data_gen():
+        aud_files = list(audio_dir.glob("*/*.mp3")) + list(
+            audio_dir.glob("*/*.wav")
+        )
+        diameta = ExtendedPath(out_dir / "diameta.json")
+        base_transcriber, base_prep = transcribe_rpyc_gen()
+        transcriber, prep = chunk_transcribe_meta_gen(
+            base_transcriber, base_prep, method="chunked"
+        )
+
+        diametadata = []
+        for af in tqdm(aud_files):
+            try:
+                # raise Exception("Test")
+                for dres in diarize_audio(af):
+                    sample_fname = dres.pop("sample_fname")
+                    out_af = out_wav_dir / sample_fname
+                    wav_bytes = dres.pop("wav")
+                    out_af.write_bytes(wav_bytes)
+                    audio_af = out_af.relative_to(out_dir)
+                    aud_seg = dres.pop("wavseg")
+                    t_seg = prep(aud_seg)
+                    transcript = transcriber(t_seg)
+                    diametadata.append(dres)
+                    yield {
+                        "audio_filepath": str(audio_af),
+                        "duration": aud_seg.duration_seconds,
+                        "text": transcript,
+                    }
+            except Exception as e:
+                print(f'error diarizing/transcribing {af} - {e}')
+        diameta.write_json(diametadata)
+
+    asr_manifest_writer(out_dir / "manifest.json", data_gen())
diff --git a/src/plume/models/pyann-dia/test.py b/src/plume/models/pyann-dia/test.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/plume/models/wav2vec2_transformers/asr.py b/src/plume/models/wav2vec2_transformers/asr.py
index c8d9d59..08255f9 100644
--- a/src/plume/models/wav2vec2_transformers/asr.py
+++ b/src/plume/models/wav2vec2_transformers/asr.py
@@ -13,23 +13,38 @@
 class Wav2Vec2TransformersASR(object):
     """docstring for Wav2Vec2TransformersASR."""
 
     def __init__(self, model_dir):
-        super(Wav2Vec2TransformersASR, self).__init__()
+        # super(Wav2Vec2TransformersASR, self).__init__()
+        self.device = "cuda:1" if torch.cuda.is_available() else "cpu"
+        # sd = torch.load(
+        #     model_dir / "pytorch_model.bin", map_location=self.device
+        # )
+        # self.processor = Wav2Vec2Processor.from_pretrained(
+        #     model_dir, state_dict=sd
+        # )
+        # self.model = Wav2Vec2ForCTC.from_pretrained(model_dir, state_dict=sd).to(self.device)
         self.processor = Wav2Vec2Processor.from_pretrained(model_dir)
-        self.model = Wav2Vec2ForCTC.from_pretrained(model_dir)
+        self.model = Wav2Vec2ForCTC.from_pretrained(model_dir).to(self.device)
 
     def transcribe(self, audio_data):
         aud_f = BytesIO(audio_data)
         # net_input = {}
         speech_data, _ = sf.read(aud_f)
         input_values = self.processor(
-            speech_data, return_tensors="pt", padding="longest"
-        ).input_values  # Batch size 1
+            speech_data,
+            return_tensors="pt",
+            padding="longest",
+            sampling_rate=16000,
+        ).input_values.to(
+            self.device
+        )  # Batch size 1
         # retrieve logits
+        #print(f"audio:{speech_data.shape} processed:{input_values.shape}")
         logits = self.model(input_values).logits
-
+        #print(f"logit shape:{logits.shape}")
         # take argmax and decode
         predicted_ids = torch.argmax(logits, dim=-1)
-
+        #print(f"predicted_ids shape:{predicted_ids.shape}")
         transcription = self.processor.batch_decode(predicted_ids)[0]
-        return transcription
+        result = transcription.replace('', '')
+        return result
diff --git a/src/plume/models/wav2vec2_transformers/serve.py b/src/plume/models/wav2vec2_transformers/serve.py
index 2feed4e..a511f46 100644
--- a/src/plume/models/wav2vec2_transformers/serve.py
+++ b/src/plume/models/wav2vec2_transformers/serve.py
@@ -1,5 +1,5 @@
 import os
-import logging
+# import logging
 from pathlib import Path
 
 # from rpyc.utils.server import ThreadedServer
@@ -10,6 +10,10 @@ from plume.utils import lazy_callable
 
 # from plume.models.wav2vec2_transformers.asr import Wav2Vec2TransformersASR
 # from .asr import Wav2Vec2ASR
+# logging.basicConfig(
+#     level=logging.INFO,
+#     format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+# )
 
 ThreadedServer = lazy_callable("rpyc.utils.server.ThreadedServer")
 Wav2Vec2TransformersASR = lazy_callable(
@@ -41,13 +45,12 @@ app = typer.Typer()
 def rpyc_dir(
     model_dir: Path, port: int = int(os.environ.get("ASR_RPYC_PORT", "8044"))
 ):
+    typer.echo("loading asr model...")
     w2vasr = Wav2Vec2TransformersASR(model_dir)
+    typer.echo("loaded asr model")
     service = ASRService(w2vasr)
-    logging.basicConfig(
-        level=logging.INFO,
-        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
-    )
-    logging.info("starting asr server...")
+
+    typer.echo(f"serving asr on :{port}...")
     t = ThreadedServer(service, port=port)
     t.start()
 
diff --git a/src/plume/ui/annotation.py b/src/plume/ui/annotation.py
index 04c2a17..f04cba6 100644
--- a/src/plume/ui/annotation.py
+++ b/src/plume/ui/annotation.py
@@ -46,7 +46,10 @@ def main(data_dir: Path, dump_fname: Path = "ui_dump.json", task_id: str = ""):
     st.title(f"ASR Validation - # {task_uid}")
     st.markdown(f"{sample_no+1} of {len(asr_data)} : **{sample['text']}**")
     new_sample = st.number_input(
-        "Go To Sample:", value=sample_no + 1, min_value=1, max_value=len(asr_data)
+        "Go To Sample:",
+        value=sample_no + 1,
+        min_value=1,
+        max_value=len(asr_data),
     )
     if new_sample != sample_no + 1:
         st.update_cursor(new_sample - 1)
@@ -60,7 +63,9 @@
     show_key(sample, "asr_wer", trail="%")
     show_key(sample, "correct_candidate")
 
-    st.sidebar.image((data_dir / Path(sample["plot_path"])).read_bytes())
+    if "plot_path" in sample:
+        st.sidebar.image((data_dir / Path(sample["plot_path"])).read_bytes())
+
     st.audio((data_dir / Path(sample["audio_path"])).open("rb"))
     # set default to text
     corrected = sample["text"]
@@ -78,27 +83,37 @@
         corrected = ""
     if st.button("Submit"):
         st.update_entry(
-            sample["utterance_id"], {"status": selected, "correction": corrected}
+            sample["utterance_id"],
+            {"status": selected, "correction": corrected},
         )
         st.update_cursor(sample_no + 1)
     if correction_entry:
         status = correction_entry["value"]["status"]
         correction = correction_entry["value"]["correction"]
-        st.markdown(f"Your Response: **{status}** Correction: **{correction}**")
+        st.markdown(
+            f"Your Response: **{status}** Correction: **{correction}**"
+        )
     text_sample = st.text_input("Go to Text:", value="")
    if text_sample != "":
-        candidates = [i for (i, p) in enumerate(asr_data) if p["text"] == text_sample]
+        candidates = [
+            i for (i, p) in enumerate(asr_data) if p["text"] == text_sample
+        ]
         if len(candidates) > 0:
             st.update_cursor(candidates[0])
-    real_idx = st.number_input(
-        "Go to real-index",
-        value=sample["real_idx"],
-        min_value=0,
-        max_value=len(asr_data) - 1,
-    )
-    if real_idx != int(sample["real_idx"]):
-        idx = [i for (i, p) in enumerate(asr_data) if p["real_idx"] == real_idx][0]
-        st.update_cursor(idx)
+    if "real_idx" in sample:
+        real_idx = st.number_input(
+            "Go to real-index",
+            value=sample["real_idx"],
+            min_value=0,
+            max_value=len(asr_data) - 1,
+        )
+        if real_idx != int(sample["real_idx"]):
+            idx = [
+                i
+                for (i, p) in enumerate(asr_data)
+                if p["real_idx"] == real_idx
+            ][0]
+            st.update_cursor(idx)
 
 
 if __name__ == "__main__":
diff --git a/src/plume/ui/preview.py b/src/plume/ui/preview.py
index 89af3ad..6250931 100644
--- a/src/plume/ui/preview.py
+++ b/src/plume/ui/preview.py
@@ -27,7 +27,10 @@
     st.title("ASR Manifest Preview")
     st.markdown(f"{sample_no+1} of {len(asr_data)} : **{sample['text']}**")
     new_sample = st.number_input(
-        "Go To Sample:", value=sample_no + 1, min_value=1, max_value=len(asr_data)
+        "Go To Sample:",
+        value=sample_no + 1,
+        min_value=1,
+        max_value=len(asr_data),
     )
     if new_sample != sample_no + 1:
         st.update_cursor(new_sample - 1)
diff --git a/src/plume/utils/__init__.py b/src/plume/utils/__init__.py
index 87b2237..f0f5d86 100644
--- a/src/plume/utils/__init__.py
+++ b/src/plume/utils/__init__.py
@@ -1,7 +1,6 @@
 import io
 import os
 import re
-import json
 import wave
 import logging
 import subprocess
diff --git a/src/plume/utils/diarize.py b/src/plume/utils/diarize.py
new file mode 100644
index 0000000..2f9ae6b
--- /dev/null
+++ b/src/plume/utils/diarize.py
@@ -0,0 +1,65 @@
+from pathlib import Path
+from plume.utils import lazy_module
+from plume.utils.audio import audio_seg_to_wav_bytes
+
+pydub = lazy_module('pydub')
+torch = lazy_module('torch')
+
+
+def transform_audio(file_location, path_to_save):
+    audio_seg = (
+        pydub.AudioSegment.from_file(file_location)
+        .set_frame_rate(16000)
+        .set_sample_width(2)
+    )
+    audio_seg.export(path_to_save, format="wav")
+
+
+def gen_diarizer():
+    pipeline = torch.hub.load("pyannote/pyannote-audio", "dia")
+
+    def _diarizer(audio_path):
+        return pipeline({"audio": audio_path})
+
+    return _diarizer
+
+
+# base_transcriber, base_prep = transcribe_rpyc_gen()
+# transcriber, prep = chunk_transcribe_meta_gen(
+#     base_transcriber, base_prep, method="chunked")
+
+# diarizer = gen_diarizer()
+
+
+def diarize_audio_gen():
+    diarizer = gen_diarizer()
+
+    def _diarize_audio(audio_path: Path):
+        aseg = (
+            pydub.AudioSegment.from_file(audio_path)
+            .set_frame_rate(16000)
+            .set_sample_width(2)
+            .set_channels(1)
+        )
+        aseg.export("/tmp/temp.wav", format="wav")
+        diarization = diarizer("/tmp/temp.wav")
+        for n, (turn, _, speaker) in enumerate(
+            diarization.itertracks(yield_label=True)
+        ):
+            # speaker_label = "Agent" if speaker == "B" else "Customer"
+            turn_seg = aseg[turn.start * 1000 : turn.end * 1000]
+            sample_fname = (
+                audio_path.stem + "_" + str(n) + ".wav"
+            )
+            yield {
+                "speaker": speaker,
+                "wav": audio_seg_to_wav_bytes(turn_seg),
+                "wavseg": turn_seg,
+                "start": turn.start,
+                "end": turn.end,
+                "turnidx": n,
+                "filename": audio_path.name,
+                "sample_fname": sample_fname
+            }
+
+    return _diarize_audio
diff --git a/src/plume/utils/transcribe.py b/src/plume/utils/transcribe.py
index a4661bc..3ddcc58 100644
--- a/src/plume/utils/transcribe.py
+++ b/src/plume/utils/transcribe.py
@@ -47,12 +47,15 @@ def transcribe_rpyc_gen(asr_host=ASR_RPYC_HOST, asr_port=ASR_RPYC_PORT):
         asr_seg = (
             aud_seg.set_channels(1).set_sample_width(2).set_frame_rate(16000)
         )
+        # af = BytesIO()
+        # asr_seg.export(af, format="wav")
+        # input_audio_bytes = af.getvalue()
+        return asr_seg
+
+    def dummy_transcript(asr_seg, append_raw=False):
         af = BytesIO()
         asr_seg.export(af, format="wav")
-        input_audio_bytes = af.getvalue()
-        return input_audio_bytes
-
-    def dummy_transcript(aud, append_raw=False):
+        aud = af.getvalue()
         return asr.transcribe(aud)
 
     return dummy_transcript, audio_prep
@@ -147,6 +150,56 @@
     return whole_transcriber, audio_prep
 
 
+def chunk_transcribe_meta_gen(
+    transcriber,
+    prep,
+    method="chunked",
+    chunk_msec=5000,
+    sil_msec=500,
+    sep=" ",
+):
+    from tritonclient.utils import np_to_triton_dtype, InferenceServerException
+    import tritonclient.grpc as grpcclient
+    # force loading
+    np.array
+
+    sup_meth = ["chunked", "silence", "whole"]
+    if method not in sup_meth:
+        meths = "|".join(sup_meth)
+        raise Exception(f"unsupported method {method}. pick one of {meths}")
+
+    def chunked_transcriber(aud_seg):
+        if method == "silence":
+            sil_chunks = pydub.silence.split_on_silence(
+                aud_seg,
+                min_silence_len=sil_msec,
+                silence_thresh=-50,
+                keep_silence=500,
+            )
+            chunks = [sc for c in sil_chunks for sc in c[::chunk_msec]]
+        else:
+            chunks = aud_seg[::chunk_msec]
+        # if overlap:
+        #     chunks = [
+        #         aud_seg[start, end]
+        #         for start, end in range(0, int(aud_seg.duration_seconds * 1000, 1000))
+        #     ]
+        #     pass
+        transcript_list = []
+        sil_pad = pydub.AudioSegment.silent(duration=sil_msec)
+        for seg in chunks:
+            t_seg = sil_pad + seg + sil_pad
+            c_transcript = transcriber(t_seg)
+            transcript_list.append(c_transcript)
+        transcript = sep.join(transcript_list)
+        return transcript
+    whole_transcriber = (
+        transcriber if method == "whole" else chunked_transcriber
+    )
+
+    return whole_transcriber, prep
+
+
 @app.command()
 def audio_file(
     audio_file: Path,
@@ -157,13 +210,15 @@
     model="slu_num_wav2vec2",
 ):
     aseg = pydub.AudioSegment.from_file(audio_file)
+    method = "chunked" if chunked else "whole"
     if rpyc:
-        transcriber, prep = transcribe_rpyc_gen()
+        base_transcriber, base_prep = transcribe_rpyc_gen()
     else:
-        method = "chunked" if chunked else "whole"
-        transcriber, prep = triton_transcribe_grpc_gen(
-            asr_model=model, method=method, append_raw=append_raw
+        base_transcriber, base_prep = triton_transcribe_grpc_gen(
+            asr_model=model, method='whole', append_raw=append_raw
         )
+    transcriber, prep = chunk_transcribe_meta_gen(
+        base_transcriber, base_prep, method=method)
     transcription = transcriber(prep(aseg))
     typer.echo(transcription)
 
diff --git a/src/plume/utils/ui_persist.py b/src/plume/utils/ui_persist.py
index f050d60..9113b79 100644
--- a/src/plume/utils/ui_persist.py
+++ b/src/plume/utils/ui_persist.py
@@ -11,10 +11,17 @@ def setup_file_state(st):
 
     def current_cursor_fn():
         return task_path.read_json()["current_cursor"]
+        # if "audio_sample_idx" not in st.session_state:
+        #     st.session_state.audio_sample_idx = task_path.read_json()[
+        #         "current_cursor"
+        #     ]
+        #     return st.session_state.audio_sample_idx
 
     def update_cursor_fn(val=0):
         task_path.write_json({"current_cursor": val})
-        rerun()
+        # rerun()
+        # st.session_state.audio_sample_idx = val
+        st.experimental_rerun()
 
     st.get_current_cursor = current_cursor_fn
     st.update_cursor = update_cursor_fn
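
Usage sketch, not part of the diff above: a minimal end-to-end composition of the new diarize_audio_gen and chunk_transcribe_meta_gen pieces, mirroring the Diarization.ipynb cells. It assumes this patch is applied, the optional deps (pyannote via torch.hub, tritonclient, pydub) are installed, an ASR rpyc server from serve.py rpyc_dir is running on the default host/port, and the wav path is illustrative.

    from pathlib import Path

    from plume.utils.diarize import diarize_audio_gen
    from plume.utils.transcribe import (
        chunk_transcribe_meta_gen,
        transcribe_rpyc_gen,
    )

    # loads the pyannote 'dia' pipeline via torch.hub and wraps it as a
    # generator of per-speaker turn dicts
    diarize_audio = diarize_audio_gen()

    # wrap the raw rpyc transcriber with the chunking abstraction so long
    # turns are transcribed in fixed-size, silence-padded chunks
    base_transcriber, base_prep = transcribe_rpyc_gen()
    transcriber, prep = chunk_transcribe_meta_gen(
        base_transcriber, base_prep, method="chunked"
    )

    # illustrative input file; each yielded dict carries the pydub segment
    # ("wavseg") plus speaker/timing metadata
    for turn in diarize_audio(Path("/home/reves/test-dt.wav")):
        text = transcriber(prep(turn["wavseg"]))
        print(f"{turn['speaker']}[{turn['start']:.1f}s-{turn['end']:.1f}s]: {text}")

This is the same flow mono_diarize_asr_dataset drives in generate.py, minus the manifest/diameta.json bookkeeping.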