diff --git a/.flake8 b/.flake8
index 170a050..6e6e140 100644
--- a/.flake8
+++ b/.flake8
@@ -1,4 +1,4 @@
[flake8]
exclude = docs
ignore = E203, W503
-max-line-length = 119
+max-line-length = 79
diff --git a/notebooks/Diarization.ipynb b/notebooks/Diarization.ipynb
new file mode 100644
index 0000000..de5e9b7
--- /dev/null
+++ b/notebooks/Diarization.ipynb
@@ -0,0 +1,225 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "808d647c-deef-41ec-8a69-06f7a72689cb",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "e49e45a1-2133-487c-a8bc-69caed303074",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#!pip install pyannote.audio==1.1.1\n",
+ "#!pip install pyannote.core[notebook]\n",
+ "#!pip install pyannote.pipeline\n",
+ "#!pip install pyannote.core\n",
+ "!pip install ipywidgets"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "59437acd-b487-4531-bb4e-a4c216a67a29",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Using cache found in /home/reves/.cache/torch/hub/pyannote_pyannote-audio_master\n",
+ "Using cache found in /home/reves/.cache/torch/hub/pyannote_pyannote-audio_master\n",
+ "Using cache found in /home/reves/.cache/torch/hub/pyannote_pyannote-audio_master\n",
+ "Using cache found in /home/reves/.cache/torch/hub/pyannote_pyannote-audio_master\n",
+ "/home/reves/plume-asr/.direnv/python-3.8.10/lib/python3.8/site-packages/pyannote/audio/embedding/approaches/arcface_loss.py:170: FutureWarning: The 's' parameter is deprecated in favor of 'scale', and will be removed in a future release\n",
+ " warnings.warn(msg, FutureWarning)\n",
+ "/home/reves/plume-asr/.direnv/python-3.8.10/lib/python3.8/site-packages/pyannote/audio/features/pretrained.py:156: UserWarning: Model was trained with 4s chunks and is applied on 2s chunks. This might lead to sub-optimal results.\n",
+ " warnings.warn(msg)\n",
+ "Using cache found in /home/reves/.cache/torch/hub/pyannote_pyannote-audio_master\n"
+ ]
+ }
+ ],
+ "source": [
+ "#pipeline = torch.hub.load('pyannote/pyannote-audio', 'dia_ami')\n",
+ "pipeline = torch.hub.load('pyannote/pyannote-audio', 'dia')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 28,
+ "id": "5f69fda5-d94d-4c6f-854d-fd2788ff5bc5",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "CPU times: user 3.94 s, sys: 179 ms, total: 4.12 s\n",
+ "Wall time: 4.12 s\n"
+ ]
+ }
+ ],
+ "source": [
+ "%time diarization = pipeline({'audio': '/home/reves/test-dt.wav'})"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "93eb8203-af2f-4905-90ec-7e667200e647",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "for turn, _, speaker in diarization.itertracks(yield_label=True):\n",
+ " print(f'Speaker \"{speaker}\" speaks between t={turn.start:.1f}s and t={turn.end:.1f}s.')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "bc0075a9-d089-46a9-bf0c-47ff4a158d8f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from ipywidgets import Audio"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "efdd69fb-06f3-4eee-b167-a73496ebbdca",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "Audio(value=open('/home/reves/test-dt.wav','rb').read(),format=\"wav\")\n",
+ "# test = Audio.from_file('/home/reves/test-dt.wav', autoplay=False)\n",
+ "# test.play()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "9af87675-eade-4b5b-aa22-8b877202102a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from plume.utils.transcribe import chunk_transcribe_meta_gen, transcribe_rpyc_gen\n",
+ "base_transcriber, base_prep = transcribe_rpyc_gen()\n",
+ "transcriber, prep = chunk_transcribe_meta_gen(\n",
+ " base_transcriber, base_prep, method=\"chunked\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "id": "0fef3f35-7143-49c6-8cd2-862b5d2b7344",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 15,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import pydub\n",
+ "audio_file = '/home/reves/test-dt.wav'\n",
+ "aseg = pydub.AudioSegment.from_file(audio_file)\n",
+ "aseg"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 32,
+ "id": "34cbea05-4c57-4a3c-8832-5677d2486231",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "#Agent[1.2s-7.4s]: வணக்கம் சார் என்ளடைய நன்ுன் நகந்துக்கு சுலப்கப்பட்டி ரெச்சோரும் சல் எ்ன செலவு கன்றதுக்காவர் பா் \n",
+ "#Customer[7.4s-8.5s]: யர ர் \n",
+ "#Agent[8.8s-13.6s]: சரனால பிண்டிக்கல் தலப்பாகட்டி ரெஸ்டோரன்ட்ருந்து ஃபீட்பாக் கலெக்ட் வண்றதுக்காக கால் பண்ணிருக்கோம் ச \n",
+ "#Customer[13.9s-14.7s]: அ்கலு மேடம் \n",
+ "#Agent[15.4s-19.7s]: சர் இந்த கால் வந்து ரெண்டு ண்து மூனு நிமிஷம் நடிக்கும் பேசலாம்வாக செர்ன் தெரிலமா \n",
+ "#Customer[20.2s-22.8s]: இல்ல அவள வேலைப \n",
+ "#Customer[23.3s-24.1s]: சரிாச்சா \n",
+ "#Agent[24.7s-33.2s]: எச்ச ஃபீஸ்பாக் கலெக்ட் பண்றதுக்காக கால் பண்ணிருக்கோம் உங்க பேர் ரமேஷன்லா படேஷன் குவாலிட்டி மட்டு வி்சலி் க பு்பப்பறப்பா் செயப்படு் \n",
+ "#Agent[33.5s-37.8s]: நீங்க டைமன் பணிரந்த் அனுபவம் எப்டி எந்த்க்கரக்்சசி்டங் \n",
+ "#Customer[38.2s-41.2s]: ஆஹஆஹ் நல்லாருக்கேன் மேடம்்ுடி \n",
+ "#Customer[42.0s-45.5s]: நன்காசசாகடன் இந்ல எனக்கு ச்ிப்ணிநான்பகே்் \n",
+ "#Customer[46.1s-46.9s]: ஓகேயவநல்ா் \n",
+ "#Agent[47.4s-56.7s]: இ்தரெஸ்கரண்டுக்கு மொத்தம் ஐந்து மதுப்பெண்கள்ர்கறைசர் ஒன்ணு குறைவானது ஐந்ததிலும் ங்ளுட அனுபவத்து வச்சு எுத்கு நீங்க ஐந்து கைந்து குடுத்திருக்கீங்க இந்த மதிப்பெண்கள் உங்களல குடுக்கப்பட்டதீங்களா \n",
+ "#Customer[58.0s-58.4s]: என்ன மேடம் \n",
+ "#Agent[58.9s-68.1s]: அதாவது நேத்து நீங்க ரேட்டிங்பை்கு பைல் மிச் ரெஸ்ட்ரோன்ஸ்ல ப்கிஙகீங்களா சார் வரேட்டிங் சகம் மித்தநன்றி்ச வ்மேஷன் மித் க \n",
+ "#Customer[68.1s-72.8s]: மேடம்என் நம்ப தரவட ஓர் சி இரங்குத்தம்பாத்துருவான கஷ்சப்பட்டங்க எப்டி பண்ணம் சாப்டுவாங்க \n",
+ "#Customer[73.8s-78.4s]: சாபால் நல்லாருந்து என்ன பேசிும் ல்லா இல்லதாு எதனத்தம்ப தரனா \n",
+ "#Customer[78.6s-81.0s]: சொல்லுாசி \n",
+ "#Agent[82.7s-85.6s]: சோதலை அதிகமாருக்கு நிலங்க \n",
+ "#Customer[86.1s-88.6s]: போ்காொ்சம் வேலை யரக்குன்டா ஈவ்னிங் தூரும் \n",
+ "#Customer[89.6s-90.2s]: உ \n"
+ ]
+ }
+ ],
+ "source": [
+ "for turn, _, speaker in diarization.itertracks(yield_label=True):\n",
+ " #print(f'Speaker \"{speaker}\" speaks between t={turn.start:.1f}s and t={turn.end:.1f}s.')\n",
+ " speaker_label = \"Agent\" if speaker == \"B\" else \"Customer\"\n",
+ " tscirpt = transcriber(prep(aseg[turn.start*1000:turn.end*1000]))\n",
+ " print(f'#{speaker_label}[{turn.start:.1f}s-{turn.end:.1f}s]: {tscirpt}')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "f4ea9ee1-cc14-4e00-8433-45d416298af2",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.10"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/setup.py b/setup.py
index edba725..8821295 100644
--- a/setup.py
+++ b/setup.py
@@ -76,7 +76,7 @@ extra_requirements = {
"pymongo~=3.10.1",
"matplotlib~=3.2.1",
"pydub~=0.24.0",
- "streamlit~=0.58.0",
+ "streamlit~=1.0.0",
"natural~=0.2.0",
"stringcase~=1.2.0",
"google-cloud-speech~=1.3.1",
@@ -85,6 +85,11 @@ extra_requirements = {
"pyspellchecker~=0.6.2",
"google-cloud-texttospeech~=1.0.1",
"rangehttpserver~=1.2.0",
+ "streamlit~=1.0.0",
+ ],
+ "dev": [
+ "jupyterlab~=3.1.18",
+ "ipykernel~=6.4.1",
],
"crypto": ["cryptography~=3.4.7"],
"train": ["torchaudio~=0.6.0", "torch-stft~=0.1.4"],
diff --git a/src/plume/cli/data/generate.py b/src/plume/cli/data/generate.py
index 464346b..8056c60 100644
--- a/src/plume/cli/data/generate.py
+++ b/src/plume/cli/data/generate.py
@@ -1,11 +1,17 @@
from pathlib import Path
import shutil
-
+from tqdm import tqdm
import typer
from plume.utils.lazy_import import lazy_module
from plume.utils.tts import GoogleTTS
-from plume.utils.transcribe import triton_transcribe_grpc_gen
+from plume.utils.transcribe import (
+ triton_transcribe_grpc_gen,
+ chunk_transcribe_meta_gen,
+ transcribe_rpyc_gen,
+)
from plume.utils.manifest import asr_manifest_writer
+from plume.utils.diarize import diarize_audio_gen
+from plume.utils.extended_path import ExtendedPath
pydub = lazy_module("pydub")
app = typer.Typer()
@@ -50,4 +56,46 @@ def asr_dataset(audio_dir: Path, out_dir: Path, model="slu_num_wav2vec2"):
"text": transcript,
}
- asr_manifest_writer(out_dir / 'manifest.json', data_gen())
+ asr_manifest_writer(out_dir / "manifest.json", data_gen())
+
+
+@app.command()
+def mono_diarize_asr_dataset(audio_dir: Path, out_dir: Path):
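+    """Diarize every recording under audio_dir, transcribe each speaker
+    turn, and write per-turn wavs plus manifest.json and diameta.json
+    into out_dir."""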
+ out_wav_dir = out_dir / "wavs"
+ out_wav_dir.mkdir(exist_ok=True, parents=True)
+ diarize_audio = diarize_audio_gen()
+
+ def data_gen():
+ aud_files = list(audio_dir.glob("*/*.mp3")) + list(
+ audio_dir.glob("*/*.wav")
+ )
+ diameta = ExtendedPath(out_dir / "diameta.json")
+ base_transcriber, base_prep = transcribe_rpyc_gen()
+ transcriber, prep = chunk_transcribe_meta_gen(
+ base_transcriber, base_prep, method="chunked"
+ )
+
+ diametadata = []
+ for af in tqdm(aud_files):
+ try:
+ for dres in diarize_audio(af):
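+                    # each dres bundles the turn's wav bytes, its pydub
+                    # segment, and speaker/start/end metadata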
+ sample_fname = dres.pop("sample_fname")
+ out_af = out_wav_dir / sample_fname
+ wav_bytes = dres.pop("wav")
+ out_af.write_bytes(wav_bytes)
+ audio_af = out_af.relative_to(out_dir)
+ aud_seg = dres.pop("wavseg")
+ t_seg = prep(aud_seg)
+ transcript = transcriber(t_seg)
+ diametadata.append(dres)
+ yield {
+ "audio_filepath": str(audio_af),
+ "duration": aud_seg.duration_seconds,
+ "text": transcript,
+ }
+ except Exception as e:
+                print(f"error diarizing/transcribing {af} - {e}")
+ diameta.write_json(diametadata)
+
+ asr_manifest_writer(out_dir / "manifest.json", data_gen())
diff --git a/src/plume/models/pyann-dia/test.py b/src/plume/models/pyann-dia/test.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/plume/models/wav2vec2_transformers/asr.py b/src/plume/models/wav2vec2_transformers/asr.py
index c8d9d59..08255f9 100644
--- a/src/plume/models/wav2vec2_transformers/asr.py
+++ b/src/plume/models/wav2vec2_transformers/asr.py
@@ -13,23 +13,38 @@ class Wav2Vec2TransformersASR(object):
"""docstring for Wav2Vec2TransformersASR."""
def __init__(self, model_dir):
- super(Wav2Vec2TransformersASR, self).__init__()
+ # super(Wav2Vec2TransformersASR, self).__init__()
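+        # pin inference to the second GPU ("cuda:1") when CUDA is
+        # available, otherwise fall back to CPU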
+ self.device = "cuda:1" if torch.cuda.is_available() else "cpu"
+ # sd = torch.load(
+ # model_dir / "pytorch_model.bin", map_location=self.device
+ # )
+ # self.processor = Wav2Vec2Processor.from_pretrained(
+ # model_dir, state_dict=sd
+ # )
+ # self.model = Wav2Vec2ForCTC.from_pretrained(model_dir, state_dict=sd).to(self.device)
self.processor = Wav2Vec2Processor.from_pretrained(model_dir)
- self.model = Wav2Vec2ForCTC.from_pretrained(model_dir)
+ self.model = Wav2Vec2ForCTC.from_pretrained(model_dir).to(self.device)
def transcribe(self, audio_data):
aud_f = BytesIO(audio_data)
# net_input = {}
speech_data, _ = sf.read(aud_f)
input_values = self.processor(
- speech_data, return_tensors="pt", padding="longest"
- ).input_values # Batch size 1
+ speech_data,
+ return_tensors="pt",
+ padding="longest",
+ sampling_rate=16000,
+ ).input_values.to(
+ self.device
+ ) # Batch size 1
# retrieve logits
+ #print(f"audio:{speech_data.shape} processed:{input_values.shape}")
logits = self.model(input_values).logits
-
+ #print(f"logit shape:{logits.shape}")
# take argmax and decode
predicted_ids = torch.argmax(logits, dim=-1)
-
+ #print(f"predicted_ids shape:{predicted_ids.shape}")
transcription = self.processor.batch_decode(predicted_ids)[0]
- return transcription
+ result = transcription.replace('', '')
+ return result
diff --git a/src/plume/models/wav2vec2_transformers/serve.py b/src/plume/models/wav2vec2_transformers/serve.py
index 2feed4e..a511f46 100644
--- a/src/plume/models/wav2vec2_transformers/serve.py
+++ b/src/plume/models/wav2vec2_transformers/serve.py
@@ -1,5 +1,5 @@
import os
-import logging
+# import logging
from pathlib import Path
# from rpyc.utils.server import ThreadedServer
@@ -10,6 +10,10 @@ from plume.utils import lazy_callable
# from plume.models.wav2vec2_transformers.asr import Wav2Vec2TransformersASR
# from .asr import Wav2Vec2ASR
+# logging.basicConfig(
+# level=logging.INFO,
+# format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+# )
ThreadedServer = lazy_callable("rpyc.utils.server.ThreadedServer")
Wav2Vec2TransformersASR = lazy_callable(
@@ -41,13 +45,12 @@ app = typer.Typer()
def rpyc_dir(
model_dir: Path, port: int = int(os.environ.get("ASR_RPYC_PORT", "8044"))
):
+ typer.echo("loading asr model...")
w2vasr = Wav2Vec2TransformersASR(model_dir)
+ typer.echo("loaded asr model")
service = ASRService(w2vasr)
- logging.basicConfig(
- level=logging.INFO,
- format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
- )
- logging.info("starting asr server...")
+
+ typer.echo(f"serving asr on :{port}...")
t = ThreadedServer(service, port=port)
t.start()
diff --git a/src/plume/ui/annotation.py b/src/plume/ui/annotation.py
index 04c2a17..f04cba6 100644
--- a/src/plume/ui/annotation.py
+++ b/src/plume/ui/annotation.py
@@ -46,7 +46,10 @@ def main(data_dir: Path, dump_fname: Path = "ui_dump.json", task_id: str = ""):
st.title(f"ASR Validation - # {task_uid}")
st.markdown(f"{sample_no+1} of {len(asr_data)} : **{sample['text']}**")
new_sample = st.number_input(
- "Go To Sample:", value=sample_no + 1, min_value=1, max_value=len(asr_data)
+ "Go To Sample:",
+ value=sample_no + 1,
+ min_value=1,
+ max_value=len(asr_data),
)
if new_sample != sample_no + 1:
st.update_cursor(new_sample - 1)
@@ -60,7 +63,9 @@ def main(data_dir: Path, dump_fname: Path = "ui_dump.json", task_id: str = ""):
show_key(sample, "asr_wer", trail="%")
show_key(sample, "correct_candidate")
- st.sidebar.image((data_dir / Path(sample["plot_path"])).read_bytes())
+ if "plot_path" in sample:
+ st.sidebar.image((data_dir / Path(sample["plot_path"])).read_bytes())
+
st.audio((data_dir / Path(sample["audio_path"])).open("rb"))
# set default to text
corrected = sample["text"]
@@ -78,27 +83,37 @@ def main(data_dir: Path, dump_fname: Path = "ui_dump.json", task_id: str = ""):
corrected = ""
if st.button("Submit"):
st.update_entry(
- sample["utterance_id"], {"status": selected, "correction": corrected}
+ sample["utterance_id"],
+ {"status": selected, "correction": corrected},
)
st.update_cursor(sample_no + 1)
if correction_entry:
status = correction_entry["value"]["status"]
correction = correction_entry["value"]["correction"]
- st.markdown(f"Your Response: **{status}** Correction: **{correction}**")
+ st.markdown(
+ f"Your Response: **{status}** Correction: **{correction}**"
+ )
text_sample = st.text_input("Go to Text:", value="")
if text_sample != "":
- candidates = [i for (i, p) in enumerate(asr_data) if p["text"] == text_sample]
+ candidates = [
+ i for (i, p) in enumerate(asr_data) if p["text"] == text_sample
+ ]
if len(candidates) > 0:
st.update_cursor(candidates[0])
- real_idx = st.number_input(
- "Go to real-index",
- value=sample["real_idx"],
- min_value=0,
- max_value=len(asr_data) - 1,
- )
- if real_idx != int(sample["real_idx"]):
- idx = [i for (i, p) in enumerate(asr_data) if p["real_idx"] == real_idx][0]
- st.update_cursor(idx)
+ if "real_idx" in sample:
+ real_idx = st.number_input(
+ "Go to real-index",
+ value=sample["real_idx"],
+ min_value=0,
+ max_value=len(asr_data) - 1,
+ )
+ if real_idx != int(sample["real_idx"]):
+ idx = [
+ i
+ for (i, p) in enumerate(asr_data)
+ if p["real_idx"] == real_idx
+ ][0]
+ st.update_cursor(idx)
if __name__ == "__main__":
diff --git a/src/plume/ui/preview.py b/src/plume/ui/preview.py
index 89af3ad..6250931 100644
--- a/src/plume/ui/preview.py
+++ b/src/plume/ui/preview.py
@@ -27,7 +27,10 @@ def main(manifest: Path):
st.title("ASR Manifest Preview")
st.markdown(f"{sample_no+1} of {len(asr_data)} : **{sample['text']}**")
new_sample = st.number_input(
- "Go To Sample:", value=sample_no + 1, min_value=1, max_value=len(asr_data)
+ "Go To Sample:",
+ value=sample_no + 1,
+ min_value=1,
+ max_value=len(asr_data),
)
if new_sample != sample_no + 1:
st.update_cursor(new_sample - 1)
diff --git a/src/plume/utils/__init__.py b/src/plume/utils/__init__.py
index 87b2237..f0f5d86 100644
--- a/src/plume/utils/__init__.py
+++ b/src/plume/utils/__init__.py
@@ -1,7 +1,6 @@
import io
import os
import re
-import json
import wave
import logging
import subprocess
diff --git a/src/plume/utils/diarize.py b/src/plume/utils/diarize.py
new file mode 100644
index 0000000..2f9ae6b
--- /dev/null
+++ b/src/plume/utils/diarize.py
@@ -0,0 +1,65 @@
+from pathlib import Path
+from plume.utils import lazy_module
+from plume.utils.audio import audio_seg_to_wav_bytes
+
+pydub = lazy_module("pydub")
+torch = lazy_module("torch")
+
+
+def transform_audio(file_location, path_to_save):
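+    """Re-encode an audio file as 16 kHz, 16-bit wav."""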
+ audio_seg = (
+ pydub.AudioSegment.from_file(file_location)
+ .set_frame_rate(16000)
+ .set_sample_width(2)
+ )
+ audio_seg.export(path_to_save, format="wav")
+
+
+def gen_diarizer():
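+    # load pyannote-audio's pretrained "dia" speaker-diarization
+    # pipeline from torch hub (cached after the first download)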
+ pipeline = torch.hub.load("pyannote/pyannote-audio", "dia")
+
+ def _diarizer(audio_path):
+ return pipeline({"audio": audio_path})
+
+ return _diarizer
+
+
+# base_transcriber, base_prep = transcribe_rpyc_gen()
+# transcriber, prep = chunk_transcribe_meta_gen(
+# base_transcriber, base_prep, method="chunked")
+
+# diarizer = gen_diarizer()
+
+
+def diarize_audio_gen():
+ diarizer = gen_diarizer()
+
+ def _diarize_audio(audio_path: Path):
+ aseg = (
+ pydub.AudioSegment.from_file(audio_path)
+ .set_frame_rate(16000)
+ .set_sample_width(2)
+ .set_channels(1)
+ )
+ aseg.export("/tmp/temp.wav", format="wav")
+ diarization = diarizer("/tmp/temp.wav")
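+        # itertracks(yield_label=True) yields one (segment, track,
+        # speaker label) triple per contiguous speaker turn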
+ for n, (turn, _, speaker) in enumerate(
+ diarization.itertracks(yield_label=True)
+ ):
+ # speaker_label = "Agent" if speaker == "B" else "Customer"
+ turn_seg = aseg[turn.start * 1000 : turn.end * 1000]
+            sample_fname = f"{audio_path.stem}_{n}.wav"
+ yield {
+ "speaker": speaker,
+ "wav": audio_seg_to_wav_bytes(turn_seg),
+ "wavseg": turn_seg,
+ "start": turn.start,
+ "end": turn.end,
+ "turnidx": n,
+ "filename": audio_path.name,
+ "sample_fname": sample_fname
+ }
+
+ return _diarize_audio
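+
+
+# Minimal usage sketch (hypothetical input path; keys match the dicts
+# yielded above):
+#   diarize = diarize_audio_gen()
+#   for turn in diarize(Path("call.wav")):
+#       print(turn["speaker"], f"{turn['start']:.1f}s-{turn['end']:.1f}s")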
diff --git a/src/plume/utils/transcribe.py b/src/plume/utils/transcribe.py
index a4661bc..3ddcc58 100644
--- a/src/plume/utils/transcribe.py
+++ b/src/plume/utils/transcribe.py
@@ -47,12 +47,15 @@ def transcribe_rpyc_gen(asr_host=ASR_RPYC_HOST, asr_port=ASR_RPYC_PORT):
asr_seg = (
aud_seg.set_channels(1).set_sample_width(2).set_frame_rate(16000)
)
+ # af = BytesIO()
+ # asr_seg.export(af, format="wav")
+ # input_audio_bytes = af.getvalue()
+ return asr_seg
+
+ def dummy_transcript(asr_seg, append_raw=False):
af = BytesIO()
asr_seg.export(af, format="wav")
- input_audio_bytes = af.getvalue()
- return input_audio_bytes
-
- def dummy_transcript(aud, append_raw=False):
+ aud = af.getvalue()
return asr.transcribe(aud)
return dummy_transcript, audio_prep
@@ -147,6 +150,56 @@ def triton_transcribe_grpc_gen(
return whole_transcriber, audio_prep
+def chunk_transcribe_meta_gen(
+ transcriber,
+ prep,
+ method="chunked",
+ chunk_msec=5000,
+ sil_msec=500,
+ sep=" ",
+):
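+    """Wrap a base transcriber so long audio is transcribed in chunks.
+
+    method="whole" passes audio through unchanged, "chunked" slices it
+    into chunk_msec pieces, and "silence" first splits on silence and
+    then re-chunks; chunk transcripts are joined with `sep`.
+    """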
+
+ sup_meth = ["chunked", "silence", "whole"]
+ if method not in sup_meth:
+ meths = "|".join(sup_meth)
+ raise Exception(f"unsupported method {method}. pick one of {meths}")
+
+ def chunked_transcriber(aud_seg):
+ if method == "silence":
+ sil_chunks = pydub.silence.split_on_silence(
+ aud_seg,
+ min_silence_len=sil_msec,
+ silence_thresh=-50,
+ keep_silence=500,
+ )
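+            # re-slice each silence-delimited chunk into chunk_msec pieces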
+ chunks = [sc for c in sil_chunks for sc in c[::chunk_msec]]
+ else:
+ chunks = aud_seg[::chunk_msec]
+ # if overlap:
+ # chunks = [
+ # aud_seg[start, end]
+ # for start, end in range(0, int(aud_seg.duration_seconds * 1000, 1000))
+ # ]
+ # pass
+ transcript_list = []
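+        # wrap each chunk in silence so the ASR model sees clean
+        # onsets/offsets at the chunk boundaries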
+ sil_pad = pydub.AudioSegment.silent(duration=sil_msec)
+ for seg in chunks:
+ t_seg = sil_pad + seg + sil_pad
+ c_transcript = transcriber(t_seg)
+ transcript_list.append(c_transcript)
+ transcript = sep.join(transcript_list)
+ return transcript
+
+    whole_transcriber = (
+ transcriber if method == "whole" else chunked_transcriber
+ )
+
+ return whole_transcriber, prep
+
+
@app.command()
def audio_file(
audio_file: Path,
@@ -157,13 +210,15 @@ def audio_file(
model="slu_num_wav2vec2",
):
aseg = pydub.AudioSegment.from_file(audio_file)
+ method = "chunked" if chunked else "whole"
if rpyc:
- transcriber, prep = transcribe_rpyc_gen()
+ base_transcriber, base_prep = transcribe_rpyc_gen()
else:
- method = "chunked" if chunked else "whole"
- transcriber, prep = triton_transcribe_grpc_gen(
- asr_model=model, method=method, append_raw=append_raw
+ base_transcriber, base_prep = triton_transcribe_grpc_gen(
+        asr_model=model, method="whole", append_raw=append_raw
)
+    transcriber, prep = chunk_transcribe_meta_gen(
+        base_transcriber, base_prep, method=method
+    )
transcription = transcriber(prep(aseg))
typer.echo(transcription)
diff --git a/src/plume/utils/ui_persist.py b/src/plume/utils/ui_persist.py
index f050d60..9113b79 100644
--- a/src/plume/utils/ui_persist.py
+++ b/src/plume/utils/ui_persist.py
@@ -11,10 +11,17 @@ def setup_file_state(st):
def current_cursor_fn():
return task_path.read_json()["current_cursor"]
+ # if "audio_sample_idx" not in st.session_state:
+ # st.session_state.audio_sample_idx = task_path.read_json()[
+ # "current_cursor"
+ # ]
+ # return st.session_state.audio_sample_idx
def update_cursor_fn(val=0):
task_path.write_json({"current_cursor": val})
- rerun()
+ # rerun()
+ # st.session_state.audio_sample_idx = val
+ st.experimental_rerun()
st.get_current_cursor = current_cursor_fn
st.update_cursor = update_cursor_fn