diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000..f630a67 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1 @@ +graft plume/utils/gentle_preview diff --git a/plume/cli/__init__.py b/plume/cli/__init__.py index 2200e2e..2de73b0 100644 --- a/plume/cli/__init__.py +++ b/plume/cli/__init__.py @@ -7,12 +7,12 @@ from .eval import app as eval_app from .serve import app as serve_app app = typer.Typer() -app.add_typer(data_app, name="data") -app.add_typer(ui_app, name="ui") -app.add_typer(train_app, name="train") -app.add_typer(eval_app, name="eval") -app.add_typer(serve_app, name="serve") -app.add_typer(utils_app, name='utils') +app.add_typer(data_app) +app.add_typer(ui_app) +app.add_typer(train_app) +app.add_typer(eval_app) +app.add_typer(serve_app) +app.add_typer(utils_app) def main(): diff --git a/plume/cli/data/__init__.py b/plume/cli/data/__init__.py index 9a90926..e7bcb7b 100644 --- a/plume/cli/data/__init__.py +++ b/plume/cli/data/__init__.py @@ -27,6 +27,13 @@ app.add_typer(generate_app, name="generate") app.add_typer(wav2vec2_app, name="wav2vec2") +@app.callback() +def data(): + """ + data sub commands + """ + + @app.command() def fix_path(dataset_path: Path, force: bool = False): manifest_path = dataset_path / Path("manifest.json") diff --git a/plume/cli/eval.py b/plume/cli/eval.py index 5686d77..53a2aef 100644 --- a/plume/cli/eval.py +++ b/plume/cli/eval.py @@ -3,3 +3,10 @@ from ..models.wav2vec2.eval import app as wav2vec2_app app = typer.Typer() app.add_typer(wav2vec2_app, name="wav2vec2") + + +@app.callback() +def eval(): + """ + eval sub commands + """ diff --git a/plume/cli/serve.py b/plume/cli/serve.py index 8397682..7b7e29d 100644 --- a/plume/cli/serve.py +++ b/plume/cli/serve.py @@ -5,3 +5,10 @@ from ..models.jasper.serve import app as jasper_app app = typer.Typer() app.add_typer(wav2vec2_app, name="wav2vec2") app.add_typer(jasper_app, name="jasper") + + +@app.callback() +def serve(): + """ + serve sub commands + """ diff --git a/plume/cli/train.py b/plume/cli/train.py index c067984..e8141b5 100644 --- a/plume/cli/train.py +++ b/plume/cli/train.py @@ -1,5 +1,12 @@ import typer -from ..models.wav2vec2.train import app as train_app +from ..models.wav2vec2.train import app as wav2vec2_app app = typer.Typer() -app.add_typer(train_app, name="wav2vec2") +app.add_typer(wav2vec2_app, name="wav2vec2") + + +@app.callback() +def train(): + """ + train sub commands + """ diff --git a/plume/ui/__init__.py b/plume/ui/__init__.py index 3aa516d..67a5c35 100644 --- a/plume/ui/__init__.py +++ b/plume/ui/__init__.py @@ -3,12 +3,20 @@ import sys from pathlib import Path from plume.utils import lazy_module + # from streamlit import cli as stcli -stcli = lazy_module('streamlit.cli') +stcli = lazy_module("streamlit.cli") app = typer.Typer() +@app.callback() +def ui(): + """ + ui sub commands + """ + + @app.command() def annotation(data_dir: Path, dump_fname: Path = "ui_dump.json", task_id: str = ""): annotation_lit_path = Path(__file__).parent / Path("annotation.py") @@ -40,13 +48,7 @@ def annotation(data_dir: Path, dump_fname: Path = "ui_dump.json", task_id: str = @app.command() def preview(manifest_path: Path): annotation_lit_path = Path(__file__).parent / Path("preview.py") - sys.argv = [ - "streamlit", - "run", - str(annotation_lit_path), - "--", - str(manifest_path) - ] + sys.argv = ["streamlit", "run", str(annotation_lit_path), "--", str(manifest_path)] sys.exit(stcli.main()) @@ -56,6 +58,18 @@ def collection(data_dir: Path, task_id: str = ""): pass +@app.command() +def 
alignment(preview_dir: Path, port: int = 8010): + from RangeHTTPServer import RangeRequestHandler + from functools import partial + from http.server import HTTPServer + + server_address = ("", port) + handler_class = partial(RangeRequestHandler, directory=str(preview_dir)) + httpd = HTTPServer(server_address, handler_class) + httpd.serve_forever() + + def main(): app() diff --git a/plume/ui/annotation.py b/plume/ui/annotation.py index 1c45c54..bcb883f 100644 --- a/plume/ui/annotation.py +++ b/plume/ui/annotation.py @@ -1,66 +1,14 @@ # import sys from pathlib import Path -from uuid import uuid4 import streamlit as st import typer - -from plume.utils import ExtendedPath, get_mongo_conn -from plume.preview.st_rerun import rerun +from plume.utils import ExtendedPath +from plume.utils.ui_persist import setup_mongo_asr_validation_state app = typer.Typer() - -if not hasattr(st, "mongo_connected"): - st.mongoclient = get_mongo_conn(col="asr_validation") - mongo_conn = st.mongoclient - st.task_id = str(uuid4()) - - def current_cursor_fn(): - # mongo_conn = st.mongoclient - cursor_obj = mongo_conn.find_one( - {"type": "current_cursor", "task_id": st.task_id} - ) - cursor_val = cursor_obj["cursor"] - return cursor_val - - def update_cursor_fn(val=0): - mongo_conn.find_one_and_update( - {"type": "current_cursor", "task_id": st.task_id}, - {"$set": {"type": "current_cursor", "task_id": st.task_id, "cursor": val}}, - upsert=True, - ) - rerun() - - def get_correction_entry_fn(code): - return mongo_conn.find_one( - {"type": "correction", "code": code}, projection={"_id": False} - ) - - def update_entry_fn(code, value): - mongo_conn.find_one_and_update( - {"type": "correction", "code": code}, - {"$set": {"value": value, "task_id": st.task_id}}, - upsert=True, - ) - - def set_task_fn(data_path, task_id): - if task_id: - st.task_id = task_id - task_path = data_path / Path(f"task-{st.task_id}.lck") - if not task_path.exists(): - print(f"creating task lock at {task_path}") - task_path.touch() - - st.get_current_cursor = current_cursor_fn - st.update_cursor = update_cursor_fn - st.get_correction_entry = get_correction_entry_fn - st.update_entry = update_entry_fn - st.set_task = set_task_fn - st.mongo_connected = True - cursor_obj = mongo_conn.find_one({"type": "current_cursor", "task_id": st.task_id}) - if not cursor_obj: - update_cursor_fn(0) +setup_mongo_asr_validation_state(st) @st.cache() diff --git a/plume/ui/preview.py b/plume/ui/preview.py index 60f8dd6..89af3ad 100644 --- a/plume/ui/preview.py +++ b/plume/ui/preview.py @@ -3,27 +3,11 @@ from pathlib import Path import streamlit as st import typer from plume.utils import ExtendedPath -from plume.preview.st_rerun import rerun +from plume.utils.ui_persist import setup_file_state app = typer.Typer() -if not hasattr(st, "state_lock"): - # st.task_id = str(uuid4()) - task_path = ExtendedPath("preview.lck") - - def current_cursor_fn(): - return task_path.read_json()["current_cursor"] - - def update_cursor_fn(val=0): - task_path.write_json({"current_cursor": val}) - rerun() - - st.get_current_cursor = current_cursor_fn - st.update_cursor = update_cursor_fn - st.state_lock = True - # cursor_obj = mongo_conn.find_one({"type": "current_cursor", "task_id": st.task_id}) - # if not cursor_obj: - update_cursor_fn(0) +setup_file_state(st) @st.cache() @@ -40,7 +24,7 @@ def main(manifest: Path): print("Invalid samplno resetting to 0") st.update_cursor(0) sample = asr_data[sample_no] - st.title(f"ASR Manifest Preview") + st.title("ASR Manifest Preview") 
st.markdown(f"{sample_no+1} of {len(asr_data)} : **{sample['text']}**") new_sample = st.number_input( "Go To Sample:", value=sample_no + 1, min_value=1, max_value=len(asr_data) diff --git a/plume/utils/.gitignore b/plume/utils/.gitignore new file mode 100644 index 0000000..3676125 --- /dev/null +++ b/plume/utils/.gitignore @@ -0,0 +1,151 @@ +/data/ +/model/ +/train/ +.env* +*.yaml +*.yml +*.json + + +# Created by https://www.gitignore.io/api/python +# Edit at https://www.gitignore.io/?templates=python + +### Python ### +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# Mr Developer +.mr.developer.cfg +.project +.pydevproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# End of https://www.gitignore.io/api/python + +# Created by https://www.gitignore.io/api/macos +# Edit at https://www.gitignore.io/?templates=macos + +### macOS ### +# General +.DS_Store +.AppleDouble +.LSOverride + +# Icon must end with two \r +Icon + +# Thumbnails +._* + +# Files that might appear in the root of a volume +.DocumentRevisions-V100 +.fseventsd +.Spotlight-V100 +.TemporaryItems +.Trashes +.VolumeIcon.icns +.com.apple.timemachine.donotpresent + +# Directories potentially created on remote AFP share +.AppleDB +.AppleDesktop +Network Trash Folder +Temporary Items +.apdisk + +# End of https://www.gitignore.io/api/macos diff --git a/plume/utils/__init__.py b/plume/utils/__init__.py index 2b43219..46f3094 100644 --- a/plume/utils/__init__.py +++ b/plume/utils/__init__.py @@ -11,12 +11,14 @@ from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor import subprocess import shutil from urllib.parse import urlsplit + # from .lazy_loader import LazyLoader from .lazy_import import lazy_callable, lazy_module # from ruamel.yaml import YAML # import boto3 import typer + # import pymongo # from slugify import slugify # import pydub @@ -34,16 +36,16 @@ from .tts import app as tts_app from .transcribe import app as transcribe_app from .align import app as align_app -boto3 = lazy_module('boto3') -pymongo = lazy_module('pymongo') -pydub = lazy_module('pydub') -audio_display = 
lazy_module('librosa.display') -plt = lazy_module('matplotlib.pyplot') -librosa = lazy_module('librosa') -YAML = lazy_callable('ruamel.yaml.YAML') -num2words = lazy_callable('num2words.num2words') -slugify = lazy_callable('slugify.slugify') -compress = lazy_callable('natural.date.compress') +boto3 = lazy_module("boto3") +pymongo = lazy_module("pymongo") +pydub = lazy_module("pydub") +audio_display = lazy_module("librosa.display") +plt = lazy_module("matplotlib.pyplot") +librosa = lazy_module("librosa") +YAML = lazy_callable("ruamel.yaml.YAML") +num2words = lazy_callable("num2words.num2words") +slugify = lazy_callable("slugify.slugify") +compress = lazy_callable("natural.date.compress") app = typer.Typer() app.add_typer(tts_app, name="tts") @@ -51,6 +53,13 @@ app.add_typer(align_app, name="align") app.add_typer(transcribe_app, name="transcribe") +@app.callback() +def utils(): + """ + utils sub commands + """ + + logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" ) @@ -125,6 +134,10 @@ def upload_s3(dataset_path, s3_path): run_shell(f"aws s3 sync {dataset_path} {s3_path}") +def copy_s3(dataset_path, s3_path): + run_shell(f"aws s3 cp {dataset_path} {s3_path}") + + def get_download_path(s3_uri, output_path): s3_uri_p = urlsplit(s3_uri) download_path = output_path / Path(s3_uri_p.path[1:]) @@ -135,11 +148,12 @@ def get_download_path(s3_uri, output_path): def s3_downloader(): s3 = boto3.client("s3") - def download_s3(s3_uri, download_path): + def download_s3(s3_uri, download_path, verbose=False): s3_uri_p = urlsplit(s3_uri) download_path.parent.mkdir(exist_ok=True, parents=True) if not download_path.exists(): - print(f"downloading {s3_uri} to {download_path}") + if verbose: + print(f"downloading {s3_uri} to {download_path}") s3.download_file(s3_uri_p.netloc, s3_uri_p.path[1:], str(download_path)) return download_s3 @@ -186,6 +200,7 @@ def ui_data_generator(dataset_dir, asr_data_source, verbose=False): plot_seg(wav_plot_path.absolute(), audio_file) return { "audio_path": str(rel_data_path), + "audio_filepath": str(rel_data_path), "duration": round(audio_dur, 1), "text": transcript, "real_idx": num_datapoints, @@ -229,17 +244,17 @@ def ui_dump_manifest_writer(dataset_dir, asr_data_source, verbose=False): ) asr_manifest = dataset_dir / Path("manifest.json") - with asr_manifest.open("w") as mf: - print(f"writing manifest to {asr_manifest}") - for d in dump_data: - rel_data_path = d["audio_path"] - audio_dur = d["duration"] - transcript = d["text"] - manifest = manifest_str(str(rel_data_path), audio_dur, transcript) - mf.write(manifest) - + asr_manifest_writer(asr_manifest, dump_data, verbose=verbose) + # with asr_manifest.open("w") as mf: + # print(f"writing manifest to {asr_manifest}") + # for d in dump_data: + # rel_data_path = d["audio_path"] + # audio_dur = d["duration"] + # transcript = d["text"] + # manifest = manifest_str(str(rel_data_path), audio_dur, transcript) + # mf.write(manifest) ui_dump_file = dataset_dir / Path("ui_dump.json") - ExtendedPath(ui_dump_file).write_json({"data": dump_data}) + ExtendedPath(ui_dump_file).write_json({"data": dump_data}, verbose=verbose) return num_datapoints @@ -254,9 +269,10 @@ def asr_manifest_reader(data_manifest_path: Path): yield p -def asr_manifest_writer(asr_manifest_path: Path, manifest_str_source): +def asr_manifest_writer(asr_manifest_path: Path, manifest_str_source, verbose=False): with asr_manifest_path.open("w") as mf: - print(f"opening {asr_manifest_path} for writing manifest") + if verbose: 
+ print(f"writing asr manifest to {asr_manifest_path}") for mani_dict in manifest_str_source: manifest = manifest_str( mani_dict["audio_filepath"], mani_dict["duration"], mani_dict["text"] @@ -293,37 +309,43 @@ def batch(iterable, n=1): class ExtendedPath(type(Path())): """docstring for ExtendedPath.""" - def read_json(self): - print(f"reading json from {self}") + def read_json(self, verbose=False): + if verbose: + print(f"reading json from {self}") with self.open("r") as jf: return json.load(jf) - def read_yaml(self): + def read_yaml(self, verbose=False): yaml = YAML(typ="safe", pure=True) - print(f"reading yaml from {self}") + if verbose: + print(f"reading yaml from {self}") with self.open("r") as yf: return yaml.load(yf) - def read_jsonl(self): - print(f"reading jsonl from {self}") + def read_jsonl(self, verbose=False): + if verbose: + print(f"reading jsonl from {self}") with self.open("r") as jf: - for l in jf.readlines(): - yield json.loads(l) + for ln in jf.readlines(): + yield json.loads(ln) - def write_json(self, data): - print(f"writing json to {self}") + def write_json(self, data, verbose=False): + if verbose: + print(f"writing json to {self}") self.parent.mkdir(parents=True, exist_ok=True) with self.open("w") as jf: json.dump(data, jf, indent=2) - def write_yaml(self, data): + def write_yaml(self, data, verbose=False): yaml = YAML() - print(f"writing yaml to {self}") + if verbose: + print(f"writing yaml to {self}") with self.open("w") as yf: yaml.dump(data, yf) - def write_jsonl(self, data): - print(f"writing jsonl to {self}") + def write_jsonl(self, data, verbose=False): + if verbose: + print(f"writing jsonl to {self}") self.parent.mkdir(parents=True, exist_ok=True) with self.open("w") as jf: for d in data: diff --git a/plume/utils/align.py b/plume/utils/align.py index 943676b..5e937ce 100644 --- a/plume/utils/align.py +++ b/plume/utils/align.py @@ -1,12 +1,14 @@ from pathlib import Path -from .tts import GoogleTTS # from IPython import display import requests import io -import typer +import shutil +import typer from plume.utils import lazy_module +from .tts import GoogleTTS + display = lazy_module('IPython.display') pydub = lazy_module('pydub') @@ -63,16 +65,19 @@ def gentle_preview( audio_path: Path, transcript_path: Path, service_uri="http://101.53.142.218:8765/transcriptions", - gent_preview_dir="../gentle_preview", + gent_preview_dir="./gentle_preview", ): from . 
import ExtendedPath - ab = audio_path.read_bytes() - tt = transcript_path.read_text() - audio, alignment = gentle_aligner(service_uri, ab, tt) - audio.export(gent_preview_dir / Path("a.wav"), format="wav") - alignment["status"] = "OK" - ExtendedPath(gent_preview_dir / Path("status.json")).write_json(alignment) + pkg_gentle_dir = Path(__file__).parent / 'gentle_preview' + + shutil.copytree(str(pkg_gentle_dir), str(gent_preview_dir)) + # ab = audio_path.read_bytes() + # tt = transcript_path.read_text() + # audio, alignment = gentle_aligner(service_uri, ab, tt) + # audio.export(gent_preview_dir / Path("a.wav"), format="wav") + # alignment["status"] = "OK" + # ExtendedPath(gent_preview_dir / Path("status.json")).write_json(alignment) def main(): diff --git a/plume/utils/gentle_preview/README.md new file mode 100644 index 0000000..9337aca --- /dev/null +++ b/plume/utils/gentle_preview/README.md @@ -0,0 +1,5 @@ +Serve this directory with RangeHTTPServer (https://github.com/danvk/RangeHTTPServer) +or CORSRangeHTTPServer (https://github.com/claysciences/CORSRangeHTTPServer) if cross-origin requests are needed: + +`python -m RangeHTTPServer` +A plain `python -m http.server` also works, but it does not serve HTTP range requests. diff --git a/plume/utils/gentle_preview/align.html new file mode 100644 index 0000000..35931de --- /dev/null +++ b/plume/utils/gentle_preview/align.html @@ -0,0 +1,80 @@
+ [80-line page body omitted: an upload form with "Audio:" and "Transcript:" inputs and "Conservative" / "Include disfluencies" options]
diff --git a/plume/utils/gentle_preview/index.html new file mode 100644 index 0000000..eab0828 --- /dev/null +++ b/plume/utils/gentle_preview/index.html @@ -0,0 +1,408 @@
+ [408-line page body omitted]
+ + + diff --git a/plume/utils/gentle_preview/preloader.gif b/plume/utils/gentle_preview/preloader.gif new file mode 100644 index 0000000..6c64343 Binary files /dev/null and b/plume/utils/gentle_preview/preloader.gif differ diff --git a/plume/ui/st_rerun.py b/plume/utils/st_rerun.py similarity index 100% rename from plume/ui/st_rerun.py rename to plume/utils/st_rerun.py diff --git a/plume/utils/transcribe.py b/plume/utils/transcribe.py index f1f74c1..330177d 100644 --- a/plume/utils/transcribe.py +++ b/plume/utils/transcribe.py @@ -8,12 +8,11 @@ import typer # import rpyc # from tqdm import tqdm -# from pydub import AudioSegment # from pydub.silence import split_on_silence from plume.utils import lazy_module, lazy_callable rpyc = lazy_module('rpyc') -AudioSegment = lazy_callable('pydub.AudioSegment') +pydub = lazy_module('pydub') split_on_silence = lazy_callable('pydub.silence.split_on_silence') app = typer.Typer() @@ -106,7 +105,7 @@ def triton_transcribe_grpc_gen( # ] # pass transcript_list = [] - sil_pad = AudioSegment.silent(duration=sil_msec) + sil_pad = pydub.AudioSegment.silent(duration=sil_msec) for seg in chunks: t_seg = sil_pad + seg + sil_pad c_transcript = transcriber(t_seg) @@ -124,9 +123,7 @@ def triton_transcribe_grpc_gen( @app.command() def file(audio_file: Path, write_file: bool = False, chunked=True): - from pydub import AudioSegment - - aseg = AudioSegment.from_file(audio_file) + aseg = pydub.AudioSegment.from_file(audio_file) transcriber, prep = triton_transcribe_grpc_gen() transcription = transcriber(prep(aseg)) @@ -139,10 +136,8 @@ def file(audio_file: Path, write_file: bool = False, chunked=True): @app.command() def benchmark(audio_file: Path): - from pydub import AudioSegment - transcriber, audio_prep = transcribe_rpyc_gen() - file_seg = AudioSegment.from_file(audio_file) + file_seg = pydub.AudioSegment.from_file(audio_file) aud_seg = audio_prep(file_seg) def timeinfo(): diff --git a/plume/utils/tts.py b/plume/utils/tts.py index c99fa97..a7eb892 100644 --- a/plume/utils/tts.py +++ b/plume/utils/tts.py @@ -27,6 +27,10 @@ class GoogleTTS(object): audio_encoding=texttospeech.enums.AudioEncoding.LINEAR16, sample_rate_hertz=params["sample_rate"], ) + if 'speaking_rate' in params: + audio_config.speaking_rate = params['speaking_rate'] + if 'pitch' in params: + audio_config.pitch = params['pitch'] response = self.client.synthesize_speech(tts_input, voice, audio_config) audio_content = response.audio_content return audio_content @@ -74,6 +78,19 @@ class GoogleTTS(object): ) return results + @classmethod + def voice_by_name(cls, name): + """Lists the available voices.""" + + # client = cls().client + + # Performs the list voices request + results = cls.voice_list() + for voice in results: + if voice['name'] == name: + return voice + raise ValueError(f'{name} not a valid voice') + @app.command() def generate_audio_file(text, dest_path: Path = "./tts_audio.wav", voice="en-US-Wavenet-D"): diff --git a/plume/utils/ui_persist.py b/plume/utils/ui_persist.py new file mode 100644 index 0000000..f050d60 --- /dev/null +++ b/plume/utils/ui_persist.py @@ -0,0 +1,85 @@ +from plume.utils import ExtendedPath, get_mongo_conn +from plume.utils.st_rerun import rerun +from uuid import uuid4 +from pathlib import Path + + +def setup_file_state(st): + if not hasattr(st, "state_lock"): + # st.task_id = str(uuid4()) + task_path = ExtendedPath("preview.lck") + + def current_cursor_fn(): + return task_path.read_json()["current_cursor"] + + def update_cursor_fn(val=0): + 
task_path.write_json({"current_cursor": val}) + rerun() + + st.get_current_cursor = current_cursor_fn + st.update_cursor = update_cursor_fn + st.state_lock = True + # cursor_obj = mongo_conn.find_one({"type": "current_cursor", "task_id": st.task_id}) + # if not cursor_obj: + update_cursor_fn(0) + + +def setup_mongo_asr_validation_state(st): + if not hasattr(st, "mongo_connected"): + st.mongoclient = get_mongo_conn(col="asr_validation") + mongo_conn = st.mongoclient + st.task_id = str(uuid4()) + + def current_cursor_fn(): + # mongo_conn = st.mongoclient + cursor_obj = mongo_conn.find_one( + {"type": "current_cursor", "task_id": st.task_id} + ) + cursor_val = cursor_obj["cursor"] + return cursor_val + + def update_cursor_fn(val=0): + mongo_conn.find_one_and_update( + {"type": "current_cursor", "task_id": st.task_id}, + { + "$set": { + "type": "current_cursor", + "task_id": st.task_id, + "cursor": val, + } + }, + upsert=True, + ) + rerun() + + def get_correction_entry_fn(code): + return mongo_conn.find_one( + {"type": "correction", "code": code}, projection={"_id": False} + ) + + def update_entry_fn(code, value): + mongo_conn.find_one_and_update( + {"type": "correction", "code": code}, + {"$set": {"value": value, "task_id": st.task_id}}, + upsert=True, + ) + + def set_task_fn(data_path, task_id): + if task_id: + st.task_id = task_id + task_path = data_path / Path(f"task-{st.task_id}.lck") + if not task_path.exists(): + print(f"creating task lock at {task_path}") + task_path.touch() + + st.get_current_cursor = current_cursor_fn + st.update_cursor = update_cursor_fn + st.get_correction_entry = get_correction_entry_fn + st.update_entry = update_entry_fn + st.set_task = set_task_fn + st.mongo_connected = True + cursor_obj = mongo_conn.find_one( + {"type": "current_cursor", "task_id": st.task_id} + ) + if not cursor_obj: + update_cursor_fn(0) diff --git a/plume/utils/vad.py b/plume/utils/vad.py new file mode 100644 index 0000000..5832914 --- /dev/null +++ b/plume/utils/vad.py @@ -0,0 +1,205 @@ +import logging +import asyncio +import argparse +from pathlib import Path + +import webrtcvad +import pydub +from pydub.playback import play +from pydub.utils import make_chunks + + +DEFAULT_CHUNK_DUR = 20 + +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger(__name__) + + +def is_frame_voice(vad, seg, chunk_dur): + return ( + True + if ( + seg.duration_seconds == chunk_dur / 1000 + and vad.is_speech(seg.raw_data, seg.frame_rate) + ) + else False + ) + + +class VADFilterAudio(object): + """docstring for VADFilterAudio.""" + + def __init__(self, chunk_dur=DEFAULT_CHUNK_DUR): + super(VADFilterAudio, self).__init__() + self.chunk_dur = chunk_dur + self.vad = webrtcvad.Vad() + + def filter_segment(self, wav_seg): + chunks = make_chunks(wav_seg, self.chunk_dur) + speech_buffer = b"" + + for i, c in enumerate(chunks[:-1]): + voice_frame = is_frame_voice(self.vad, c, self.chunk_dur) + if voice_frame: + speech_buffer += c.raw_data + filtered_seg = pydub.AudioSegment( + data=speech_buffer, + frame_rate=wav_seg.frame_rate, + channels=wav_seg.channels, + sample_width=wav_seg.sample_width, + ) + return filtered_seg + + +class VADUtterance(object): + """docstring for VADUtterance.""" + + def __init__( + self, + max_silence=500, + min_utterance=280, + max_utterance=20000, + chunk_dur=DEFAULT_CHUNK_DUR, + start_cycles=3, + ): + super(VADUtterance, self).__init__() + self.vad = webrtcvad.Vad() + self.chunk_dur = chunk_dur + # duration in 
millisecs + self.max_sil = max_silence + self.min_utt = min_utterance + self.max_utt = max_utterance + self.speech_start = start_cycles * chunk_dur + + def __repr__(self): + return f"VAD(max_silence={self.max_sil},min_utterance:{self.min_utt},max_utterance:{self.max_utt})" + + async def stream_utterance(self, audio_stream): + silence_buffer = pydub.AudioSegment.empty() + voice_buffer = pydub.AudioSegment.empty() + silence_threshold = False + async for c in audio_stream: + voice_frame = is_frame_voice(self.vad, c, self.chunk_dur) + logger.debug(f"is audio stream voice? {voice_frame}") + if voice_frame: + silence_threshold = False + voice_buffer += c + silence_buffer = pydub.AudioSegment.empty() + else: + silence_buffer += c + voc_dur = voice_buffer.duration_seconds * 1000 + sil_dur = silence_buffer.duration_seconds * 1000 + + if voc_dur >= self.max_utt: + logger.info( + f"detected voice overflow: voice duration {voice_buffer.duration_seconds}" + ) + yield voice_buffer + voice_buffer = pydub.AudioSegment.empty() + + if sil_dur >= self.max_sil: + if voc_dur >= self.min_utt: + logger.info( + f"detected silence: voice duration {voice_buffer.duration_seconds}" + ) + yield voice_buffer + voice_buffer = pydub.AudioSegment.empty() + # ignore/clear voice if silence reached threshold or indent the statement + if not silence_threshold: + silence_threshold = True + + if voice_buffer: + yield voice_buffer + + async def stream_events(self, audio_stream): + """ + yields 0, voice_buffer for SpeechBuffer + yields 1, None for StartedSpeaking + yields 2, None for StoppedSpeaking + yields 4, audio_stream + """ + silence_buffer = pydub.AudioSegment.empty() + voice_buffer = pydub.AudioSegment.empty() + silence_threshold, started_speaking = False, False + async for c in audio_stream: + # yield (4, c) + voice_frame = is_frame_voice(self.vad, c, self.chunk_dur) + logger.debug(f"is audio stream voice? 
{voice_frame}") + if voice_frame: + silence_threshold = False + voice_buffer += c + silence_buffer = pydub.AudioSegment.empty() + else: + silence_buffer += c + voc_dur = voice_buffer.duration_seconds * 1000 + sil_dur = silence_buffer.duration_seconds * 1000 + + if voc_dur >= self.speech_start and not started_speaking: + started_speaking = True + yield (1, None) + + if voc_dur >= self.max_utt: + logger.info( + f"detected voice overflow: voice duration {voice_buffer.duration_seconds}" + ) + yield (0, voice_buffer) + voice_buffer = pydub.AudioSegment.empty() + started_speaking = False + + if sil_dur >= self.max_sil: + if voc_dur >= self.min_utt: + logger.info( + f"detected silence: voice duration {voice_buffer.duration_seconds}" + ) + yield (0, voice_buffer) + voice_buffer = pydub.AudioSegment.empty() + started_speaking = False + # ignore/clear voice if silence reached threshold or indent the statement + if not silence_threshold: + silence_threshold = True + yield (2, None) + + if voice_buffer: + yield (0, voice_buffer) + + @classmethod + async def stream_utterance_file(cls, audio_file): + async def stream_gen(): + audio_seg = pydub.AudioSegment.from_file(audio_file).set_frame_rate(32000) + chunks = make_chunks(audio_seg, DEFAULT_CHUNK_DUR) + for c in chunks: + yield c + + va_ut = cls() + buffer_src = va_ut.stream_utterance(stream_gen()) + async for buf in buffer_src: + play(buf) + await asyncio.sleep(1) + + +class VADStreamGen(object): + """docstring for VADStreamGen.""" + + def __init__(self, arg): + super(VADStreamGen, self).__init__() + self.arg = arg + + +def main(): + prog = Path(__file__).stem + parser = argparse.ArgumentParser(prog=prog, description="transcribes audio file") + parser.add_argument( + "--audio_file", + type=argparse.FileType("rb"), + help="audio file to transcribe", + default="./test_utter2.wav", + ) + args = parser.parse_args() + loop = asyncio.get_event_loop() + loop.run_until_complete(VADUtterance.stream_utterance_file(args.audio_file)) + + +if __name__ == "__main__": + main() diff --git a/setup.py b/setup.py index eac5e4b..8ea5a79 100644 --- a/setup.py +++ b/setup.py @@ -58,6 +58,9 @@ extra_requirements = { "stringcase~=1.2.0", "google-cloud-speech~=1.3.1", ], + "ui": [ + "rangehttpserver~=1.2.0", + ], "train": ["torchaudio~=0.6.0", "torch-stft~=0.1.4"], }
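Taken together, the packaged `gentle_preview` assets (`MANIFEST.in`), the `gentle_preview` copier in `plume/utils/align.py`, the new `alignment` server command in `plume/ui/__init__.py`, and the `rangehttpserver` extra are meant to be used roughly as follows. This is a minimal sketch, assuming the package exposes a `plume` console script; `call.wav` and `call.txt` are hypothetical inputs.

```python
from pathlib import Path

from plume.utils.align import gentle_preview

# Copies the bundled viewer assets (align.html, index.html, preloader.gif)
# into ./gentle_preview. The audio/transcript arguments are accepted, but the
# actual call to the gentle alignment service is commented out in this version.
gentle_preview(
    audio_path=Path("call.wav"),        # hypothetical input audio
    transcript_path=Path("call.txt"),   # hypothetical transcript
    gent_preview_dir="./gentle_preview",
)

# Serve the directory with HTTP range-request support so the browser can seek
# within the audio, then open http://localhost:8010/ in a browser:
#
#   plume ui alignment ./gentle_preview --port 8010
#
# This is roughly equivalent to running `python -m RangeHTTPServer` from that
# directory, as noted in the gentle_preview README.
```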