1. refactor package root to src/ layout
2. add framework suffix for models
3. change black max line length to 79
4. add tests
5. integrate vad, encrypt and refactor manifest, regentity, extended_path, audio, parallel utils
6. add ui utils for encrypted preview
7. wip marblenet model
8. add transformers based wav2vec2 inference
9. update readme and manifest
10. add deploy setup target tegra
parent c474aa5f5a
commit e07c7c9caf
@@ -1 +1 @@
-graft plume/utils/gentle_preview
+graft src
@@ -38,7 +38,7 @@ The installation should work on Python 3.6 or newer. Untested on Python 2.7
### Library
> Jasper
```python
-from plume.models.jasper.asr import JasperASR
+from plume.models.jasper_nemo.asr import JasperASR
asr_model = JasperASR("/path/to/model_config_yaml", "/path/to/encoder_checkpoint", "/path/to/decoder_checkpoint")  # Loads the models
TEXT = asr_model.transcribe(wav_data)  # Returns the text spoken in the wav
```
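For reference, the renamed import path can be exercised end to end; a minimal sketch (paths are placeholders, and it assumes a 16-bit mono WAV whose raw frames are passed to `transcribe`, mirroring `transcribe_file` in `asr.py`):

```python
import wave

from plume.models.jasper_nemo.asr import JasperASR

asr_model = JasperASR(
    "/path/to/model_config_yaml",
    "/path/to/encoder_checkpoint",
    "/path/to/decoder_checkpoint",
)

# transcribe() takes raw PCM frames; read them the same way transcribe_file does.
with wave.open("/path/to/utterance.wav", "rb") as af:
    wav_data = af.readframes(af.getnframes())

print(asr_model.transcribe(wav_data))
```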
@@ -1,28 +0,0 @@
from scipy.signal import lfilter, butter
from scipy.io.wavfile import read, write
from numpy import array, int16
import sys


def butter_params(low_freq, high_freq, fs, order=5):
    nyq = 0.5 * fs
    low = low_freq / nyq
    high = high_freq / nyq
    b, a = butter(order, [low, high], btype="band")
    return b, a


def butter_bandpass_filter(data, low_freq, high_freq, fs, order=5):
    b, a = butter_params(low_freq, high_freq, fs, order=order)
    y = lfilter(b, a, data)
    return y


if __name__ == "__main__":
    fs, audio = read(sys.argv[1])
    import pdb; pdb.set_trace()
    low_freq = 300.0
    high_freq = 4000.0
    filtered_signal = butter_bandpass_filter(audio, low_freq, high_freq, fs, order=6)
    fname = sys.argv[1].split(".wav")[0] + "_moded.wav"
    write(fname, fs, array(filtered_signal, dtype=int16))
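As a usage note for the removed band-pass helper: it keeps only the 300-4000 Hz band typical of telephone speech. A small self-contained sketch of the same filter on synthetic audio (sample rate and signal are illustrative):

```python
import numpy as np
from scipy.signal import butter, lfilter


def butter_bandpass_filter(data, low_freq, high_freq, fs, order=5):
    # Band edges are normalised by the Nyquist frequency before designing the filter.
    nyq = 0.5 * fs
    b, a = butter(order, [low_freq / nyq, high_freq / nyq], btype="band")
    return lfilter(b, a, data)


fs = 16000  # assumed sample rate
t = np.linspace(0, 1.0, fs, endpoint=False)
tone = np.sin(2 * np.pi * 440 * t)  # 440 Hz tone sits inside the pass band
filtered = butter_bandpass_filter(tone, 300.0, 4000.0, fs, order=6)
```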
@ -1,205 +0,0 @@
|
|||
import logging
|
||||
import asyncio
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import webrtcvad
|
||||
import pydub
|
||||
from pydub.playback import play
|
||||
from pydub.utils import make_chunks
|
||||
|
||||
|
||||
DEFAULT_CHUNK_DUR = 20
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def is_frame_voice(vad, seg, chunk_dur):
|
||||
return (
|
||||
True
|
||||
if (
|
||||
seg.duration_seconds == chunk_dur / 1000
|
||||
and vad.is_speech(seg.raw_data, seg.frame_rate)
|
||||
)
|
||||
else False
|
||||
)
|
||||
|
||||
|
||||
class VADFilterAudio(object):
|
||||
"""docstring for VADFilterAudio."""
|
||||
|
||||
def __init__(self, chunk_dur=DEFAULT_CHUNK_DUR):
|
||||
super(VADFilterAudio, self).__init__()
|
||||
self.chunk_dur = chunk_dur
|
||||
self.vad = webrtcvad.Vad()
|
||||
|
||||
def filter_segment(self, wav_seg):
|
||||
chunks = make_chunks(wav_seg, self.chunk_dur)
|
||||
speech_buffer = b""
|
||||
|
||||
for i, c in enumerate(chunks[:-1]):
|
||||
voice_frame = is_frame_voice(self.vad, c, self.chunk_dur)
|
||||
if voice_frame:
|
||||
speech_buffer += c.raw_data
|
||||
filtered_seg = pydub.AudioSegment(
|
||||
data=speech_buffer,
|
||||
frame_rate=wav_seg.frame_rate,
|
||||
channels=wav_seg.channels,
|
||||
sample_width=wav_seg.sample_width,
|
||||
)
|
||||
return filtered_seg
|
||||
|
||||
|
||||
class VADUtterance(object):
|
||||
"""docstring for VADUtterance."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_silence=500,
|
||||
min_utterance=280,
|
||||
max_utterance=20000,
|
||||
chunk_dur=DEFAULT_CHUNK_DUR,
|
||||
start_cycles=3,
|
||||
):
|
||||
super(VADUtterance, self).__init__()
|
||||
self.vad = webrtcvad.Vad()
|
||||
self.chunk_dur = chunk_dur
|
||||
# duration in millisecs
|
||||
self.max_sil = max_silence
|
||||
self.min_utt = min_utterance
|
||||
self.max_utt = max_utterance
|
||||
self.speech_start = start_cycles * chunk_dur
|
||||
|
||||
def __repr__(self):
|
||||
return f"VAD(max_silence={self.max_sil},min_utterance:{self.min_utt},max_utterance:{self.max_utt})"
|
||||
|
||||
async def stream_utterance(self, audio_stream):
|
||||
silence_buffer = pydub.AudioSegment.empty()
|
||||
voice_buffer = pydub.AudioSegment.empty()
|
||||
silence_threshold = False
|
||||
async for c in audio_stream:
|
||||
voice_frame = is_frame_voice(self.vad, c, self.chunk_dur)
|
||||
logger.debug(f"is audio stream voice? {voice_frame}")
|
||||
if voice_frame:
|
||||
silence_threshold = False
|
||||
voice_buffer += c
|
||||
silence_buffer = pydub.AudioSegment.empty()
|
||||
else:
|
||||
silence_buffer += c
|
||||
voc_dur = voice_buffer.duration_seconds * 1000
|
||||
sil_dur = silence_buffer.duration_seconds * 1000
|
||||
|
||||
if voc_dur >= self.max_utt:
|
||||
logger.info(
|
||||
f"detected voice overflow: voice duration {voice_buffer.duration_seconds}"
|
||||
)
|
||||
yield voice_buffer
|
||||
voice_buffer = pydub.AudioSegment.empty()
|
||||
|
||||
if sil_dur >= self.max_sil:
|
||||
if voc_dur >= self.min_utt:
|
||||
logger.info(
|
||||
f"detected silence: voice duration {voice_buffer.duration_seconds}"
|
||||
)
|
||||
yield voice_buffer
|
||||
voice_buffer = pydub.AudioSegment.empty()
|
||||
# ignore/clear voice if silence reached threshold or indent the statement
|
||||
if not silence_threshold:
|
||||
silence_threshold = True
|
||||
|
||||
if voice_buffer:
|
||||
yield voice_buffer
|
||||
|
||||
async def stream_events(self, audio_stream):
|
||||
"""
|
||||
yields 0, voice_buffer for SpeechBuffer
|
||||
yields 1, None for StartedSpeaking
|
||||
yields 2, None for StoppedSpeaking
|
||||
yields 4, audio_stream
|
||||
"""
|
||||
silence_buffer = pydub.AudioSegment.empty()
|
||||
voice_buffer = pydub.AudioSegment.empty()
|
||||
silence_threshold, started_speaking = False, False
|
||||
async for c in audio_stream:
|
||||
# yield (4, c)
|
||||
voice_frame = is_frame_voice(self.vad, c, self.chunk_dur)
|
||||
logger.debug(f"is audio stream voice? {voice_frame}")
|
||||
if voice_frame:
|
||||
silence_threshold = False
|
||||
voice_buffer += c
|
||||
silence_buffer = pydub.AudioSegment.empty()
|
||||
else:
|
||||
silence_buffer += c
|
||||
voc_dur = voice_buffer.duration_seconds * 1000
|
||||
sil_dur = silence_buffer.duration_seconds * 1000
|
||||
|
||||
if voc_dur >= self.speech_start and not started_speaking:
|
||||
started_speaking = True
|
||||
yield (1, None)
|
||||
|
||||
if voc_dur >= self.max_utt:
|
||||
logger.info(
|
||||
f"detected voice overflow: voice duration {voice_buffer.duration_seconds}"
|
||||
)
|
||||
yield (0, voice_buffer)
|
||||
voice_buffer = pydub.AudioSegment.empty()
|
||||
started_speaking = False
|
||||
|
||||
if sil_dur >= self.max_sil:
|
||||
if voc_dur >= self.min_utt:
|
||||
logger.info(
|
||||
f"detected silence: voice duration {voice_buffer.duration_seconds}"
|
||||
)
|
||||
yield (0, voice_buffer)
|
||||
voice_buffer = pydub.AudioSegment.empty()
|
||||
started_speaking = False
|
||||
# ignore/clear voice if silence reached threshold or indent the statement
|
||||
if not silence_threshold:
|
||||
silence_threshold = True
|
||||
yield (2, None)
|
||||
|
||||
if voice_buffer:
|
||||
yield (0, voice_buffer)
|
||||
|
||||
@classmethod
|
||||
async def stream_utterance_file(cls, audio_file):
|
||||
async def stream_gen():
|
||||
audio_seg = pydub.AudioSegment.from_file(audio_file).set_frame_rate(32000)
|
||||
chunks = make_chunks(audio_seg, DEFAULT_CHUNK_DUR)
|
||||
for c in chunks:
|
||||
yield c
|
||||
|
||||
va_ut = cls()
|
||||
buffer_src = va_ut.stream_utterance(stream_gen())
|
||||
async for buf in buffer_src:
|
||||
play(buf)
|
||||
await asyncio.sleep(1)
|
||||
|
||||
|
||||
class VADStreamGen(object):
|
||||
"""docstring for VADStreamGen."""
|
||||
|
||||
def __init__(self, arg):
|
||||
super(VADStreamGen, self).__init__()
|
||||
self.arg = arg
|
||||
|
||||
|
||||
def main():
|
||||
prog = Path(__file__).stem
|
||||
parser = argparse.ArgumentParser(prog=prog, description="transcribes audio file")
|
||||
parser.add_argument(
|
||||
"--audio_file",
|
||||
type=argparse.FileType("rb"),
|
||||
help="audio file to transcribe",
|
||||
default="./test_utter2.wav",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.run_until_complete(VADUtterance.stream_utterance_file(args.audio_file))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
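The deleted `VADUtterance` helper above (relocated as part of the src/ refactor) slices an async stream of fixed-size pydub chunks into utterances. A minimal sketch of driving it from a WAV file, assuming 20 ms mono chunks at a WebRTC-VAD-supported rate (8/16/32/48 kHz):

```python
import asyncio

import pydub
from pydub.utils import make_chunks

# VADUtterance as defined in the module shown above.


async def chunk_stream(path, chunk_ms=20):
    # webrtcvad expects 16-bit mono PCM at 8/16/32/48 kHz.
    seg = pydub.AudioSegment.from_file(path).set_frame_rate(16000).set_channels(1)
    for chunk in make_chunks(seg, chunk_ms):
        yield chunk


async def main():
    vad = VADUtterance(max_silence=500, min_utterance=280)
    async for utterance in vad.stream_utterance(chunk_stream("utterance.wav")):
        print(f"utterance of {utterance.duration_seconds:.2f}s detected")


asyncio.run(main())
```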
@@ -0,0 +1,2 @@
[tool.black]
line-length = 79

setup.py
@@ -3,12 +3,10 @@ from setuptools import setup, find_namespace_packages
# pip install "nvidia-pyindex~=1.0.5"

requirements = [
    "torch~=1.6.0",
    "torchvision~=0.7.0",
    "nemo_toolkit[asr] @ git+https://github.com/NVIDIA/NeMo.git@09e3ba4dfe333f86d6c5c1048e07210924294be9#egg=nemo_toolkit",
    "fairseq @ git+https://github.com/pytorch/fairseq.git@94a1b924f3adec25c8c508ac112410d02b400d1e#egg=fairseq",
    # "nemo_toolkit[asr] @ git+https://github.com/NVIDIA/NeMo.git@09e3ba4dfe333f86d6c5c1048e07210924294be9#egg=nemo_toolkit",
    # "fairseq @ git+https://github.com/pytorch/fairseq.git@94a1b924f3adec25c8c508ac112410d02b400d1e#egg=fairseq",
    # "google-cloud-texttospeech~=1.0.1",
    "tqdm~=4.54.0",
    "tqdm~=4.49.0",
    # "pydub~=0.24.0",
    # "scikit_learn~=0.22.1",
    # "pandas~=1.0.3",
@@ -47,8 +45,30 @@ extra_requirements = {
        "num2words~=0.5.10",
        "python-slugify~=4.0.0",
        "rpyc~=4.1.4",
        "webrtcvad~=2.0.10",
        # "datasets"
        # "lenses @ git+https://github.com/ingolemo/python-lenses.git@b2a2a9aa5b61540992d70b2cf36008d0121e8948#egg=lenses",
    ],
    "models": [
        # "nemo_toolkit[asr] @ git+https://github.com/NVIDIA/NeMo.git@09e3ba4dfe333f86d6c5c1048e07210924294be9#egg=nemo_toolkit",
        "nemo_toolkit[asr] @ git+https://github.com/NVIDIA/NeMo.git@v1.0.0#egg=nemo_toolkit",
        "fairseq @ git+https://github.com/pytorch/fairseq.git@94a1b924f3adec25c8c508ac112410d02b400d1e#egg=fairseq",
        "transformers~=4.5.0",
        "torch~=1.7.0",
        "torchvision~=0.8.2",
        "torchaudio~=0.7.2",
    ],
    "eval": [
        "jiwer~=2.2.0",
        "pydub~=0.24.0",
        "tritonclient[grpc]~=2.9.0",
        "pyspellchecker~=0.6.2",
        "num2words~=0.5.10",
    ],
    "infer": [
        "pyspellchecker~=0.6.2",
        "num2words~=0.5.10",
    ],
    "validation": [
        "pymongo~=3.10.1",
        "matplotlib~=3.2.1",
@@ -61,15 +81,20 @@ extra_requirements = {
    "ui": [
        "rangehttpserver~=1.2.0",
    ],
    "crypto": ["cryptography~=3.4.7"],
    "train": ["torchaudio~=0.6.0", "torch-stft~=0.1.4"],
}

extra_requirements["all"] = list({d for l in extra_requirements.values() for d in l})
packages = find_namespace_packages()
extra_requirements["deploy"] = (
    extra_requirements["models"] + extra_requirements["infer"]
)
extra_requirements["all"] = list(
    {d for r in extra_requirements.values() for d in r}
)
packages = find_namespace_packages("src")

setup(
    name="plume-asr",
    version="0.2.0",
    version="0.2.1",
    description="Multi model ASR base package",
    url="http://github.com/malarinv/plume-asr",
    author="Malar Kannan",
@@ -78,6 +103,7 @@ setup(
    install_requires=requirements,
    extras_require=extra_requirements,
    packages=packages,
    package_dir={"": "src"},
    entry_points={"console_scripts": ["plume = plume.cli:main"]},
    zip_safe=False,
)
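The combination of `packages = find_namespace_packages("src")` and `package_dir={"": "src"}` is what makes the src/ layout work: packages are discovered under `src/` but still import as `plume.*`. A short illustrative sketch (directory names shown are examples of the new layout, not an exhaustive listing):

```python
# src/ layout after this refactor (illustrative):
#   src/plume/__init__.py
#   src/plume/cli.py
#   src/plume/models/jasper_nemo/asr.py
from setuptools import find_namespace_packages

# Returns names relative to src/, e.g. ["plume", "plume.models", ...];
# package_dir={"": "src"} then tells setuptools where those packages live on disk.
print(find_namespace_packages("src"))
```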
@@ -0,0 +1,5 @@
from . import main


if __name__ == "__main__":
    main()
@@ -1,6 +1,14 @@
import json
from pathlib import Path
from random import shuffle
from typing import List
from itertools import chain

# from sklearn.model_selection import train_test_split
from tqdm import tqdm
import shutil
import typer

from plume.utils import (
    asr_manifest_reader,
    asr_manifest_writer,
@@ -9,18 +17,19 @@ from plume.utils import (
    generate_filter_map,
    get_mongo_conn,
    tscript_uuid_fname,
    lazy_callable
    lazy_callable,
    lazy_module,
    wav_cryptor,
    text_cryptor,
    parallel_apply,
)
from typing import List
from itertools import chain
import shutil
import typer
import soundfile

from ...models.wav2vec2.data import app as wav2vec2_app
from .generate import app as generate_app

train_test_split = lazy_callable('sklearn.model_selection.train_test_split')
soundfile = lazy_module("soundfile")
pydub = lazy_module("pydub")
train_test_split = lazy_callable("sklearn.model_selection.train_test_split")

app = typer.Typer()
app.add_typer(generate_app, name="generate")
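`lazy_module` and `lazy_callable` come from `plume.utils`, whose implementation is not part of this diff; only the call sites above are confirmed. A minimal sketch of the deferred-import pattern they presumably follow (hypothetical implementation):

```python
import importlib


def lazy_callable(dotted_path):
    """Return a proxy that imports `module.attr` only when first called."""
    module_name, attr_name = dotted_path.rsplit(".", 1)

    def _proxy(*args, **kwargs):
        module = importlib.import_module(module_name)
        return getattr(module, attr_name)(*args, **kwargs)

    return _proxy


# Heavy dependencies (sklearn, soundfile, pydub) are then only imported when a
# CLI command that actually needs them runs.
train_test_split = lazy_callable("sklearn.model_selection.train_test_split")
```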
@ -62,7 +71,7 @@ def fix_path(dataset_path: Path, force: bool = False):
|
|||
|
||||
|
||||
@app.command()
|
||||
def augment(src_dataset_paths: List[Path], dest_dataset_path: Path):
|
||||
def merge(src_dataset_paths: List[Path], dest_dataset_path: Path):
|
||||
reader_list = []
|
||||
abs_manifest_path = Path("abs_manifest.json")
|
||||
for dataset_path in src_dataset_paths:
|
||||
|
|
@ -74,14 +83,89 @@ def augment(src_dataset_paths: List[Path], dest_dataset_path: Path):
|
|||
|
||||
|
||||
@app.command()
|
||||
def split(dataset_path: Path, test_size: float = 0.03):
|
||||
def training_split(dataset_path: Path, test_size: float = 0.03):
|
||||
manifest_path = dataset_path / Path("abs_manifest.json")
|
||||
if not manifest_path.exists():
|
||||
fix_path(dataset_path)
|
||||
asr_data = list(asr_manifest_reader(manifest_path))
|
||||
train_pnr, test_pnr = train_test_split(asr_data, test_size=test_size)
|
||||
asr_manifest_writer(manifest_path.with_name("train_manifest.json"), train_pnr)
|
||||
asr_manifest_writer(manifest_path.with_name("test_manifest.json"), test_pnr)
|
||||
asr_manifest_writer(
|
||||
manifest_path.with_name("train_manifest.json"), train_pnr
|
||||
)
|
||||
asr_manifest_writer(
|
||||
manifest_path.with_name("test_manifest.json"), test_pnr
|
||||
)
|
||||
|
||||
|
||||
@app.command()
|
||||
def parts_split_by_size(
|
||||
dataset_path: Path,
|
||||
test_size: float = 0.03,
|
||||
split_prefix_names: List[str] = ["train", "test"],
|
||||
):
|
||||
manifest_path = dataset_path / Path("abs_manifest.json")
|
||||
if not manifest_path.exists():
|
||||
fix_path(dataset_path)
|
||||
asr_data = list(asr_manifest_reader(manifest_path))
|
||||
train_pnr, test_pnr = train_test_split(asr_data, test_size=test_size)
|
||||
dest_paths = [
|
||||
(dataset_path.parent / (dataset_path.name + "_" + spn), sd)
|
||||
for (spn, sd) in zip(split_prefix_names, [train_pnr, test_pnr])
|
||||
]
|
||||
for dest_path, manifest_data in dest_paths:
|
||||
wav_dir = dest_path / Path("wavs")
|
||||
wav_dir.mkdir(exist_ok=True, parents=True)
|
||||
abs_manifest_path = ExtendedPath(dest_path / Path("abs_manifest.json"))
|
||||
for md in manifest_data:
|
||||
orig_path = Path(md["audio_filepath"])
|
||||
new_path = wav_dir / Path(orig_path.name)
|
||||
shutil.copy(orig_path, new_path)
|
||||
md["audio_filepath"] = str(new_path)
|
||||
md.pop("audio_path")
|
||||
abs_manifest_path.write_jsonl(manifest_data)
|
||||
fix_path(dest_path)
|
||||
|
||||
|
||||
@app.command()
|
||||
def parts_split_by_dur(
|
||||
dataset_path: Path,
|
||||
dur_sec: int = 7200,
|
||||
suffix_name: List[str] = ["train", "test"],
|
||||
):
|
||||
manifest_path = dataset_path / Path("abs_manifest.json")
|
||||
if not manifest_path.exists():
|
||||
fix_path(dataset_path)
|
||||
asr_data = list(asr_manifest_reader(manifest_path))
|
||||
|
||||
def dur_split(dataset, dur_seconds):
|
||||
shuffle(dataset)
|
||||
counter_dur = 0
|
||||
train_set, test_set = [], []
|
||||
for d in dataset:
|
||||
if counter_dur <= dur_seconds:
|
||||
test_set.append(d)
|
||||
else:
|
||||
train_set.append(d)
|
||||
counter_dur += d["duration"]
|
||||
return train_set, test_set
|
||||
|
||||
train_pnr, test_pnr = dur_split(asr_data, dur_sec)
|
||||
dest_paths = [
|
||||
(dataset_path.parent / (dataset_path.name + "_" + spn), sd)
|
||||
for (spn, sd) in zip(suffix_name, [train_pnr, test_pnr])
|
||||
]
|
||||
for dest_path, manifest_data in dest_paths:
|
||||
wav_dir = dest_path / Path("wavs")
|
||||
wav_dir.mkdir(exist_ok=True, parents=True)
|
||||
abs_manifest_path = ExtendedPath(dest_path / Path("abs_manifest.json"))
|
||||
for md in manifest_data:
|
||||
orig_path = Path(md["audio_filepath"])
|
||||
new_path = wav_dir / Path(orig_path.name)
|
||||
shutil.copy(orig_path, new_path)
|
||||
md["audio_filepath"] = str(new_path.absolute())
|
||||
md.pop("audio_path")
|
||||
abs_manifest_path.write_jsonl(manifest_data)
|
||||
fix_path(dest_path.absolute())
|
||||
|
||||
|
||||
@app.command()
|
||||
|
|
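A clarifying note on `dur_split` above: it shuffles the manifest entries, then fills the test split until the accumulated audio passes `dur_seconds`, and everything after that goes to train. A small self-contained illustration with toy durations (not project data):

```python
from random import seed, shuffle


def dur_split(dataset, dur_seconds):
    # Same logic as in parts_split_by_dur above.
    shuffle(dataset)
    counter_dur = 0
    train_set, test_set = [], []
    for d in dataset:
        if counter_dur <= dur_seconds:
            test_set.append(d)
        else:
            train_set.append(d)
        counter_dur += d["duration"]
    return train_set, test_set


seed(0)
toy = [{"audio_filepath": f"wavs/{i}.wav", "duration": 5.0} for i in range(10)]
train_set, test_set = dur_split(toy, dur_seconds=20)
print(len(train_set), len(test_set))  # 5 5 -> about 25s of audio lands in the test split
```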
@ -111,7 +195,11 @@ def validate(dataset_path: Path):
|
|||
|
||||
|
||||
@app.command()
|
||||
def filter(src_dataset_path: Path, dest_dataset_path: Path, kind: str = "skip_dur"):
|
||||
def filter(
|
||||
src_dataset_path: Path,
|
||||
dest_dataset_path: Path,
|
||||
kind: str = "",
|
||||
):
|
||||
dest_manifest = dest_dataset_path / Path("manifest.json")
|
||||
data_file = src_dataset_path / Path("manifest.json")
|
||||
dest_wav_dir = dest_dataset_path / Path("wavs")
|
||||
|
|
@ -149,13 +237,21 @@ def info(dataset_path: Path):
|
|||
real_duration += soundfile.info(wav_path).duration
|
||||
|
||||
# frame_count = soundfile.info(audio_fname).frames
|
||||
print(f"max audio duration : {duration_str(max_duration)}")
|
||||
print(f"total audio duration : {duration_str(mf_wav_duration)}")
|
||||
print(f"total real audio duration : {duration_str(real_duration)}")
|
||||
print(
|
||||
f"total content duration : {duration_str(mf_wav_duration-empty_duration)}"
|
||||
f"max audio duration : {duration_str(max_duration, show_hours=True)}"
|
||||
)
|
||||
print(
|
||||
f"total audio duration : {duration_str(mf_wav_duration, show_hours=True)}"
|
||||
)
|
||||
print(
|
||||
f"total real audio duration : {duration_str(real_duration, show_hours=True)}"
|
||||
)
|
||||
print(
|
||||
f"total content duration : {duration_str(mf_wav_duration-empty_duration, show_hours=True)}"
|
||||
)
|
||||
print(
|
||||
f"total empty duration : {duration_str(empty_duration, show_hours=True)}"
|
||||
)
|
||||
print(f"total empty duration : {duration_str(empty_duration)}")
|
||||
print(
|
||||
f"total empty samples : {empty_count}/{total_count} ({empty_count*100/total_count:.2f}%)"
|
||||
)
|
||||
|
|
@ -167,7 +263,81 @@ def audio_duration(dataset_path: Path):
|
|||
for audio_rel_fname in dataset_path.absolute().glob("**/*.wav"):
|
||||
audio_fname = str(audio_rel_fname)
|
||||
wav_duration += soundfile.info(audio_fname).duration
|
||||
typer.echo(f"duration of wav files @ {dataset_path}: {duration_str(wav_duration)}")
|
||||
typer.echo(
|
||||
f"duration of wav files @ {dataset_path}: {duration_str(wav_duration)}"
|
||||
)
|
||||
|
||||
|
||||
@app.command()
|
||||
def encrypt(
|
||||
src_dataset_path: Path,
|
||||
dest_dataset_path: Path,
|
||||
encryption_key: str = typer.Option(..., prompt=True, hide_input=True),
|
||||
verbose: bool = False,
|
||||
):
|
||||
dest_manifest = dest_dataset_path / Path("manifest.json")
|
||||
src_manifest = src_dataset_path / Path("manifest.json")
|
||||
dest_wav_dir = dest_dataset_path / Path("wavs")
|
||||
dest_wav_dir.mkdir(exist_ok=True, parents=True)
|
||||
wav_crypt = wav_cryptor(encryption_key)
|
||||
text_crypt = text_cryptor(encryption_key)
|
||||
# warmup
|
||||
_ = pydub.AudioSegment.from_file
|
||||
|
||||
def encrypt_item(s):
|
||||
crypt_text = text_crypt.encrypt_text(s["text"])
|
||||
src_wav_path = src_dataset_path / s["audio_filepath"]
|
||||
dst_wav_path = dest_dataset_path / s["audio_filepath"]
|
||||
wav_crypt.encrypt_wav_path_to(src_wav_path, dst_wav_path)
|
||||
s["text"] = crypt_text.decode("utf-8")
|
||||
return s
|
||||
|
||||
def encryted_gen():
|
||||
data = list(ExtendedPath(src_manifest).read_jsonl())
|
||||
iter_data = tqdm(data) if verbose else data
|
||||
encrypted_iter_data = parallel_apply(
|
||||
encrypt_item, iter_data, verbose=verbose, workers=64
|
||||
)
|
||||
for s in encrypted_iter_data:
|
||||
yield s
|
||||
|
||||
asr_manifest_writer(dest_manifest, encryted_gen(), verbose=verbose)
|
||||
|
||||
|
||||
@app.command()
|
||||
def decrypt(
|
||||
src_dataset_path: Path,
|
||||
dest_dataset_path: Path,
|
||||
encryption_key: str = typer.Option(..., prompt=True, hide_input=True),
|
||||
verbose: bool = True,
|
||||
):
|
||||
dest_manifest = dest_dataset_path / Path("manifest.json")
|
||||
src_manifest = src_dataset_path / Path("manifest.json")
|
||||
dest_wav_dir = dest_dataset_path / Path("wavs")
|
||||
dest_wav_dir.mkdir(exist_ok=True, parents=True)
|
||||
wav_crypt = wav_cryptor(encryption_key)
|
||||
text_crypt = text_cryptor(encryption_key)
|
||||
# warmup
|
||||
_ = pydub.AudioSegment.from_file
|
||||
|
||||
def decrypt_item(s):
|
||||
crypt_text = text_crypt.decrypt_text(s["text"].encode("utf-8"))
|
||||
src_wav_path = src_dataset_path / s["audio_filepath"]
|
||||
dst_wav_path = dest_dataset_path / s["audio_filepath"]
|
||||
wav_crypt.decrypt_wav_path_to(src_wav_path, dst_wav_path)
|
||||
s["text"] = crypt_text
|
||||
return s
|
||||
|
||||
def decryted_gen():
|
||||
data = list(ExtendedPath(src_manifest).read_jsonl())
|
||||
iter_data = tqdm(data) if verbose else data
|
||||
decrypted_iter_data = parallel_apply(
|
||||
decrypt_item, iter_data, verbose=verbose, workers=64
|
||||
)
|
||||
for s in decrypted_iter_data:
|
||||
yield s
|
||||
|
||||
asr_manifest_writer(dest_manifest, decryted_gen(), verbose=verbose)
|
||||
|
||||
|
||||
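`wav_cryptor` and `text_cryptor` are plume utilities whose internals are not shown in this diff; the sketch below is based only on the calls the `encrypt`/`decrypt` commands make above (key and paths are illustrative):

```python
from pathlib import Path

from plume.utils import text_cryptor, wav_cryptor

key = "example-key"  # normally prompted for with hide_input=True
text_crypt = text_cryptor(key)
wav_crypt = wav_cryptor(key)

entry = {"audio_filepath": "wavs/sample.wav", "text": "hello world"}

# Encrypt: ciphertext bytes are stored back in the manifest as UTF-8 text.
entry["text"] = text_crypt.encrypt_text(entry["text"]).decode("utf-8")
wav_crypt.encrypt_wav_path_to(Path("plain") / entry["audio_filepath"],
                              Path("enc") / entry["audio_filepath"])

# Decrypt: the reverse of the above.
entry["text"] = text_crypt.decrypt_text(entry["text"].encode("utf-8"))
wav_crypt.decrypt_wav_path_to(Path("enc") / entry["audio_filepath"],
                              Path("dec") / entry["audio_filepath"])
```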
@app.command()
|
||||
|
|
@ -204,13 +374,19 @@ def task_split(
|
|||
|
||||
processed_data_path = data_dir / dump_file
|
||||
processed_data = ExtendedPath(processed_data_path).read_json()
|
||||
df = pd.DataFrame(processed_data["data"]).sample(frac=1).reset_index(drop=True)
|
||||
df = (
|
||||
pd.DataFrame(processed_data["data"])
|
||||
.sample(frac=1)
|
||||
.reset_index(drop=True)
|
||||
)
|
||||
for t_idx, task_f in enumerate(np.array_split(df, task_count)):
|
||||
task_f = task_f.reset_index(drop=True)
|
||||
task_f["real_idx"] = task_f.index
|
||||
task_data = task_f.to_dict("records")
|
||||
if sort:
|
||||
task_data = sorted(task_data, key=lambda x: x["asr_wer"], reverse=True)
|
||||
task_data = sorted(
|
||||
task_data, key=lambda x: x["asr_wer"], reverse=True
|
||||
)
|
||||
processed_data["data"] = task_data
|
||||
task_path = data_dir / Path(task_file + f"-{t_idx}.json")
|
||||
ExtendedPath(task_path).write_json(processed_data)
|
||||
|
|
@ -223,7 +399,9 @@ def get_corrections(task_uid):
|
|||
for c in col.distinct("task_id")
|
||||
if c.rsplit("-", 1)[1] == task_uid or c == task_uid
|
||||
][0]
|
||||
corrections = list(col.find({"type": "correction"}, projection={"_id": False}))
|
||||
corrections = list(
|
||||
col.find({"type": "correction"}, projection={"_id": False})
|
||||
)
|
||||
cursor_obj = col.find(
|
||||
{"type": "correction", "task_id": task_id}, projection={"_id": False}
|
||||
)
|
||||
|
|
@ -241,8 +419,8 @@ def dump_task_corrections(data_dir: Path, task_uid: str):
|
|||
|
||||
@app.command()
|
||||
def dump_all_corrections(data_dir: Path):
|
||||
for task_lcks in data_dir.glob('task-*.lck'):
|
||||
task_uid = task_lcks.stem.replace('task-', '')
|
||||
for task_lcks in data_dir.glob("task-*.lck"):
|
||||
task_uid = task_lcks.stem.replace("task-", "")
|
||||
dump_task_corrections(data_dir, task_uid)
|
||||
|
||||
|
||||
|
|
@ -292,7 +470,9 @@ def update_corrections(
|
|||
correct_text = correction_map[d["utterance_id"]]
|
||||
if skip_incorrect:
|
||||
ap = d["audio_path"]
|
||||
print(f"skipping incorrect {ap} corrected to {correct_text}")
|
||||
print(
|
||||
f"skipping incorrect {ap} corrected to {correct_text}"
|
||||
)
|
||||
orig_audio_path.unlink()
|
||||
else:
|
||||
new_fname = tscript_uuid_fname(correct_text)
|
||||
|
|
@ -304,7 +484,9 @@ def update_corrections(
|
|||
new_name = str(Path(new_fname).with_suffix(".wav"))
|
||||
new_audio_path = orig_audio_path.with_name(new_name)
|
||||
orig_audio_path.replace(new_audio_path)
|
||||
new_filepath = str(Path(d["audio_path"]).with_name(new_name))
|
||||
new_filepath = str(
|
||||
Path(d["audio_path"]).with_name(new_name)
|
||||
)
|
||||
d["corrected_from"] = d["text"]
|
||||
d["text"] = correct_text
|
||||
d["audio_path"] = new_filepath
|
||||
|
|
@ -325,7 +507,9 @@ def update_corrections(
|
|||
shutil.copytree(str(dataset_dir), str(backup_dir))
|
||||
renames = {}
|
||||
corrected_ui_dump = list(correct_ui_dump(data_dir, renames))
|
||||
ExtendedPath(data_dir / ui_dump_file).write_json({"data": corrected_ui_dump})
|
||||
ExtendedPath(data_dir / ui_dump_file).write_json(
|
||||
{"data": corrected_ui_dump}
|
||||
)
|
||||
corrected_manifest = (
|
||||
{
|
||||
"audio_filepath": d["audio_path"],
|
||||
|
|
@ -1,8 +1,10 @@
|
|||
import typer
|
||||
from ..models.wav2vec2.eval import app as wav2vec2_app
|
||||
from ..models.wav2vec2_transformers.eval import app as wav2vec2_transformers_app
|
||||
|
||||
app = typer.Typer()
|
||||
app.add_typer(wav2vec2_app, name="wav2vec2")
|
||||
app.add_typer(wav2vec2_transformers_app, name="wav2vec2_transformers")
|
||||
|
||||
|
||||
@app.callback()
|
||||
|
|
@ -1,9 +1,11 @@
|
|||
import typer
|
||||
from ..models.wav2vec2.serve import app as wav2vec2_app
|
||||
from ..models.jasper.serve import app as jasper_app
|
||||
from ..models.wav2vec2_transformers.serve import app as wav2vec2_transformers_app
|
||||
from ..models.jasper_nemo.serve import app as jasper_app
|
||||
|
||||
app = typer.Typer()
|
||||
app.add_typer(wav2vec2_app, name="wav2vec2")
|
||||
app.add_typer(wav2vec2_transformers_app, name="wav2vec2_transformers")
|
||||
app.add_typer(jasper_app, name="jasper")
|
||||
|
||||
|
||||
|
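The serve and eval CLIs wire each model backend in as a Typer sub-app, so `plume serve wav2vec2_transformers ...` and `plume serve jasper ...` dispatch to the respective `serve.py` modules. A minimal standalone sketch of the same `add_typer` pattern (the `run` command here is hypothetical):

```python
import typer

wav2vec2_transformers_app = typer.Typer()


@wav2vec2_transformers_app.command()
def run(model_dir: str):
    """Hypothetical command; the real ones live under plume.models.*.serve."""
    typer.echo(f"serving wav2vec2 transformers model from {model_dir}")


app = typer.Typer()
app.add_typer(wav2vec2_transformers_app, name="wav2vec2_transformers")

if __name__ == "__main__":
    app()
```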
|
@ -16,7 +16,11 @@ class JasperASR(object):
|
|||
"""docstring for JasperASR."""
|
||||
|
||||
def __init__(
|
||||
self, model_yaml, encoder_checkpoint, decoder_checkpoint, language_model=None
|
||||
self,
|
||||
model_yaml,
|
||||
encoder_checkpoint,
|
||||
decoder_checkpoint,
|
||||
language_model=None,
|
||||
):
|
||||
super(JasperASR, self).__init__()
|
||||
# Read model YAML
|
||||
|
|
@ -24,16 +28,17 @@ class JasperASR(object):
|
|||
with open(model_yaml) as f:
|
||||
jasper_model_definition = yaml.load(f)
|
||||
self.neural_factory = nemo.core.NeuralModuleFactory(
|
||||
placement=nemo.core.DeviceType.GPU, backend=nemo.core.Backend.PyTorch
|
||||
placement=nemo.core.DeviceType.GPU,
|
||||
backend=nemo.core.Backend.PyTorch,
|
||||
)
|
||||
self.labels = jasper_model_definition["labels"]
|
||||
self.data_preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor()
|
||||
self.jasper_encoder = nemo_asr.JasperEncoder(
|
||||
jasper=jasper_model_definition["JasperEncoder"]["jasper"],
|
||||
activation=jasper_model_definition["JasperEncoder"]["activation"],
|
||||
feat_in=jasper_model_definition["AudioToMelSpectrogramPreprocessor"][
|
||||
"features"
|
||||
],
|
||||
feat_in=jasper_model_definition[
|
||||
"AudioToMelSpectrogramPreprocessor"
|
||||
]["features"],
|
||||
)
|
||||
self.jasper_encoder.restore_from(encoder_checkpoint, local_rank=0)
|
||||
self.jasper_decoder = nemo_asr.JasperDecoderForCTC(
|
||||
|
|
@ -65,7 +70,11 @@ class JasperASR(object):
|
|||
wf.setframerate(24000)
|
||||
wf.writeframesraw(audio_data)
|
||||
wf.close()
|
||||
manifest = {"audio_filepath": audio_file_path, "duration": 60, "text": "todo"}
|
||||
manifest = {
|
||||
"audio_filepath": audio_file_path,
|
||||
"duration": 60,
|
||||
"text": "todo",
|
||||
}
|
||||
manifest_file = tempfile.NamedTemporaryFile(
|
||||
dir=WORK_DIR, prefix="jasper_manifest.", delete=False, mode="w"
|
||||
)
|
||||
|
|
@ -11,7 +11,12 @@ from nemo.backends.pytorch import DataLayerNM
|
|||
from nemo.core import DeviceType
|
||||
|
||||
# from nemo.core.neural_types import *
|
||||
from nemo.core.neural_types import NeuralType, AudioSignal, LengthsType, LabelsType
|
||||
from nemo.core.neural_types import (
|
||||
NeuralType,
|
||||
AudioSignal,
|
||||
LengthsType,
|
||||
LabelsType,
|
||||
)
|
||||
from nemo.utils.decorators import add_port_docs
|
||||
|
||||
from nemo.collections.asr.parts.dataset import (
|
||||
|
|
@ -217,8 +222,7 @@ transcript_n}
|
|||
@property
|
||||
@add_port_docs()
|
||||
def output_ports(self):
|
||||
"""Returns definitions of module output ports.
|
||||
"""
|
||||
"""Returns definitions of module output ports."""
|
||||
return {
|
||||
# 'audio_signal': NeuralType({0: AxisType(BatchTag), 1: AxisType(TimeTag)}),
|
||||
# 'a_sig_length': NeuralType({0: AxisType(BatchTag)}),
|
||||
|
|
@ -304,7 +308,9 @@ transcript_n}
|
|||
# Set up data loader
|
||||
if self._placement == DeviceType.AllGpu:
|
||||
logging.info("Parallelizing Datalayer.")
|
||||
sampler = torch.utils.data.distributed.DistributedSampler(self._dataset)
|
||||
sampler = torch.utils.data.distributed.DistributedSampler(
|
||||
self._dataset
|
||||
)
|
||||
else:
|
||||
sampler = None
|
||||
|
||||
|
|
@ -1,6 +1,7 @@
|
|||
# Copyright (c) 2019 NVIDIA Corporation
|
||||
import argparse
|
||||
import copy
|
||||
|
||||
# import math
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
|
@ -57,7 +58,10 @@ def parse_args():
|
|||
help="max number of steps to train",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_epochs", type=int, required=False, help="number of epochs to train"
|
||||
"--num_epochs",
|
||||
type=int,
|
||||
required=False,
|
||||
help="number of epochs to train",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_config",
|
||||
|
|
@ -170,7 +174,8 @@ def create_all_dags(args, neural_factory):
|
|||
# logging.info("Have {0} examples to train on.".format(N))
|
||||
#
|
||||
data_preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor(
|
||||
sample_rate=sample_rate, **jasper_params["AudioToMelSpectrogramPreprocessor"]
|
||||
sample_rate=sample_rate,
|
||||
**jasper_params["AudioToMelSpectrogramPreprocessor"],
|
||||
)
|
||||
|
||||
# multiply_batch_config = jasper_params.get("MultiplyBatch", None)
|
||||
|
|
@ -284,7 +289,12 @@ def create_all_dags(args, neural_factory):
|
|||
callbacks = []
|
||||
# assemble eval DAGs
|
||||
for i, eval_dl in enumerate(data_layers_eval):
|
||||
(audio_signal_e, a_sig_length_e, transcript_e, transcript_len_e) = eval_dl()
|
||||
(
|
||||
audio_signal_e,
|
||||
a_sig_length_e,
|
||||
transcript_e,
|
||||
transcript_len_e,
|
||||
) = eval_dl()
|
||||
processed_signal_e, p_length_e = data_preprocessor(
|
||||
input_signal=audio_signal_e, length=a_sig_length_e
|
||||
)
|
||||
|
|
@ -303,9 +313,16 @@ def create_all_dags(args, neural_factory):
|
|||
# create corresponding eval callback
|
||||
tagname = os.path.basename(args.eval_datasets[i]).split(".")[0]
|
||||
eval_callback = nemo.core.EvaluatorCallback(
|
||||
eval_tensors=[loss_e, predictions_e, transcript_e, transcript_len_e],
|
||||
eval_tensors=[
|
||||
loss_e,
|
||||
predictions_e,
|
||||
transcript_e,
|
||||
transcript_len_e,
|
||||
],
|
||||
user_iter_callback=partial(process_evaluation_batch, labels=vocab),
|
||||
user_epochs_done_callback=partial(process_evaluation_epoch, tag=tagname),
|
||||
user_epochs_done_callback=partial(
|
||||
process_evaluation_epoch, tag=tagname
|
||||
),
|
||||
eval_step=args.eval_freq,
|
||||
tb_writer=neural_factory.tb_writer,
|
||||
)
|
||||
|
|
@ -3,19 +3,27 @@
|
|||
# import librosa
|
||||
import torch
|
||||
import pickle
|
||||
|
||||
# import torch.nn as nn
|
||||
# from torch_stft import STFT
|
||||
|
||||
# from nemo import logging
|
||||
from nemo.collections.asr.parts.perturb import AudioAugmentor
|
||||
|
||||
# from nemo.collections.asr.parts.segment import AudioSegment
|
||||
|
||||
|
||||
class RpycWaveformFeaturizer(object):
|
||||
def __init__(
|
||||
self, sample_rate=16000, int_values=False, augmentor=None, rpyc_conn=None
|
||||
self,
|
||||
sample_rate=16000,
|
||||
int_values=False,
|
||||
augmentor=None,
|
||||
rpyc_conn=None,
|
||||
):
|
||||
self.augmentor = augmentor if augmentor is not None else AudioAugmentor()
|
||||
self.augmentor = (
|
||||
augmentor if augmentor is not None else AudioAugmentor()
|
||||
)
|
||||
self.sample_rate = sample_rate
|
||||
self.int_values = int_values
|
||||
self.remote_path_samples = rpyc_conn.get_path_samples
|
||||
|
|
@ -48,4 +56,6 @@ class RpycWaveformFeaturizer(object):
|
|||
sample_rate = input_config.get("sample_rate", 16000)
|
||||
int_values = input_config.get("int_values", False)
|
||||
|
||||
return cls(sample_rate=sample_rate, int_values=int_values, augmentor=aa)
|
||||
return cls(
|
||||
sample_rate=sample_rate, int_values=int_values, augmentor=aa
|
||||
)
|
||||
|
|
@ -9,7 +9,7 @@ import typer
|
|||
from ...utils.serve import ASRService
|
||||
from plume.utils import lazy_callable
|
||||
|
||||
JasperASR = lazy_callable('plume.models.jasper.asr.JasperASR')
|
||||
JasperASR = lazy_callable("plume.models.jasper_nemo.asr.JasperASR")
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
|
@ -37,7 +37,9 @@ def rpyc(
|
|||
|
||||
|
||||
@app.command()
|
||||
def rpyc_dir(model_dir: Path, port: int = int(os.environ.get("ASR_RPYC_PORT", "8044"))):
|
||||
def rpyc_dir(
|
||||
model_dir: Path, port: int = int(os.environ.get("ASR_RPYC_PORT", "8044"))
|
||||
):
|
||||
encoder_path = model_dir / Path("encoder.pt")
decoder_path = model_dir / Path("decoder.pt")
|
||||
model_yaml_path = model_dir / Path("model.yaml")
|
||||
|
|
@ -40,7 +40,9 @@ class ASRDataService(rpyc.Service):
|
|||
|
||||
@app.command()
|
||||
def run_server(port: int = 0):
|
||||
listen_port = port if port else int(os.environ.get("ASR_DARA_RPYC_PORT", "8064"))
|
||||
listen_port = (
|
||||
port if port else int(os.environ.get("ASR_DARA_RPYC_PORT", "8064"))
|
||||
)
|
||||
service = ASRDataService()
|
||||
t = ThreadedServer(
|
||||
service, port=listen_port, protocol_config={"allow_all_attrs": True}
|
||||
|
|
@ -161,7 +161,8 @@ def create_all_dags(args, neural_factory):
|
|||
logging.info("Have {0} examples to train on.".format(N))
|
||||
|
||||
data_preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor(
|
||||
sample_rate=sample_rate, **jasper_params["AudioToMelSpectrogramPreprocessor"]
|
||||
sample_rate=sample_rate,
|
||||
**jasper_params["AudioToMelSpectrogramPreprocessor"],
|
||||
)
|
||||
|
||||
multiply_batch_config = jasper_params.get("MultiplyBatch", None)
|
||||
|
|
@ -212,8 +213,12 @@ def create_all_dags(args, neural_factory):
|
|||
greedy_decoder = nemo_asr.GreedyCTCDecoder()
|
||||
|
||||
logging.info("================================")
|
||||
logging.info(f"Number of parameters in encoder: {jasper_encoder.num_weights}")
|
||||
logging.info(f"Number of parameters in decoder: {jasper_decoder.num_weights}")
|
||||
logging.info(
|
||||
f"Number of parameters in encoder: {jasper_encoder.num_weights}"
|
||||
)
|
||||
logging.info(
|
||||
f"Number of parameters in decoder: {jasper_decoder.num_weights}"
|
||||
)
|
||||
logging.info(
|
||||
f"Total number of parameters in model: "
|
||||
f"{jasper_decoder.num_weights + jasper_encoder.num_weights}"
|
||||
|
|
@ -221,7 +226,12 @@ def create_all_dags(args, neural_factory):
|
|||
logging.info("================================")
|
||||
|
||||
# Train DAG
|
||||
(audio_signal_t, a_sig_length_t, transcript_t, transcript_len_t) = data_layer()
|
||||
(
|
||||
audio_signal_t,
|
||||
a_sig_length_t,
|
||||
transcript_t,
|
||||
transcript_len_t,
|
||||
) = data_layer()
|
||||
processed_signal_t, p_length_t = data_preprocessor(
|
||||
input_signal=audio_signal_t, length=a_sig_length_t
|
||||
)
|
||||
|
|
@ -240,7 +250,9 @@ def create_all_dags(args, neural_factory):
|
|||
)
|
||||
|
||||
if spectr_augment_config:
|
||||
processed_signal_t = data_spectr_augmentation(input_spec=processed_signal_t)
|
||||
processed_signal_t = data_spectr_augmentation(
|
||||
input_spec=processed_signal_t
|
||||
)
|
||||
|
||||
encoded_t, encoded_len_t = jasper_encoder(
|
||||
audio_signal=processed_signal_t, length=p_length_t
|
||||
|
|
@ -273,7 +285,12 @@ def create_all_dags(args, neural_factory):
|
|||
|
||||
# assemble eval DAGs
|
||||
for i, eval_dl in enumerate(data_layers_eval):
|
||||
(audio_signal_e, a_sig_length_e, transcript_e, transcript_len_e) = eval_dl()
|
||||
(
|
||||
audio_signal_e,
|
||||
a_sig_length_e,
|
||||
transcript_e,
|
||||
transcript_len_e,
|
||||
) = eval_dl()
|
||||
processed_signal_e, p_length_e = data_preprocessor(
|
||||
input_signal=audio_signal_e, length=a_sig_length_e
|
||||
)
|
||||
|
|
@ -292,9 +309,16 @@ def create_all_dags(args, neural_factory):
|
|||
# create corresponding eval callback
|
||||
tagname = os.path.basename(args.eval_datasets[i]).split(".")[0]
|
||||
eval_callback = nemo.core.EvaluatorCallback(
|
||||
eval_tensors=[loss_e, predictions_e, transcript_e, transcript_len_e],
|
||||
eval_tensors=[
|
||||
loss_e,
|
||||
predictions_e,
|
||||
transcript_e,
|
||||
transcript_len_e,
|
||||
],
|
||||
user_iter_callback=partial(process_evaluation_batch, labels=vocab),
|
||||
user_epochs_done_callback=partial(process_evaluation_epoch, tag=tagname),
|
||||
user_epochs_done_callback=partial(
|
||||
process_evaluation_epoch, tag=tagname
|
||||
),
|
||||
eval_step=args.eval_freq,
|
||||
tb_writer=neural_factory.tb_writer,
|
||||
)
|
||||
|
|
@ -338,7 +362,9 @@ def main():
|
|||
logging.info("Doing ALL GPU")
|
||||
|
||||
# build dags
|
||||
train_loss, callbacks, steps_per_epoch = create_all_dags(args, neural_factory)
|
||||
train_loss, callbacks, steps_per_epoch = create_all_dags(
|
||||
args, neural_factory
|
||||
)
|
||||
# train model
|
||||
neural_factory.train(
|
||||
tensors_to_optimize=[train_loss],
|
||||
|
|
@ -0,0 +1,132 @@
|
|||
import os
|
||||
import tempfile
|
||||
from ruamel.yaml import YAML
|
||||
import json
|
||||
import nemo
|
||||
import nemo.collections.asr as nemo_asr
|
||||
import wave
|
||||
from nemo.collections.asr.helpers import post_process_predictions
|
||||
|
||||
logging = nemo.logging
|
||||
|
||||
WORK_DIR = "/tmp"
|
||||
|
||||
|
||||
class JasperASR(object):
|
||||
"""docstring for JasperASR."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_yaml,
|
||||
encoder_checkpoint,
|
||||
decoder_checkpoint,
|
||||
language_model=None,
|
||||
):
|
||||
super(JasperASR, self).__init__()
|
||||
# Read model YAML
|
||||
yaml = YAML(typ="safe")
|
||||
with open(model_yaml) as f:
|
||||
jasper_model_definition = yaml.load(f)
|
||||
self.neural_factory = nemo.core.NeuralModuleFactory(
|
||||
placement=nemo.core.DeviceType.GPU,
|
||||
backend=nemo.core.Backend.PyTorch,
|
||||
)
|
||||
self.labels = jasper_model_definition["labels"]
|
||||
self.data_preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor()
|
||||
self.jasper_encoder = nemo_asr.JasperEncoder(
|
||||
jasper=jasper_model_definition["JasperEncoder"]["jasper"],
|
||||
activation=jasper_model_definition["JasperEncoder"]["activation"],
|
||||
feat_in=jasper_model_definition[
|
||||
"AudioToMelSpectrogramPreprocessor"
|
||||
]["features"],
|
||||
)
|
||||
self.jasper_encoder.restore_from(encoder_checkpoint, local_rank=0)
|
||||
self.jasper_decoder = nemo_asr.JasperDecoderForCTC(
|
||||
feat_in=1024, num_classes=len(self.labels)
|
||||
)
|
||||
self.jasper_decoder.restore_from(decoder_checkpoint, local_rank=0)
|
||||
self.greedy_decoder = nemo_asr.GreedyCTCDecoder()
|
||||
self.beam_search_with_lm = None
|
||||
if language_model:
|
||||
self.beam_search_with_lm = nemo_asr.BeamSearchDecoderWithLM(
|
||||
vocab=self.labels,
|
||||
beam_width=64,
|
||||
alpha=2.0,
|
||||
beta=1.0,
|
||||
lm_path=language_model,
|
||||
num_cpus=max(os.cpu_count(), 1),
|
||||
)
|
||||
|
||||
def transcribe(self, audio_data, greedy=True):
|
||||
audio_file = tempfile.NamedTemporaryFile(
|
||||
dir=WORK_DIR, prefix="jasper_audio.", delete=False
|
||||
)
|
||||
# audio_file.write(audio_data)
|
||||
audio_file.close()
|
||||
audio_file_path = audio_file.name
|
||||
wf = wave.open(audio_file_path, "w")
|
||||
wf.setnchannels(1)
|
||||
wf.setsampwidth(2)
|
||||
wf.setframerate(24000)
|
||||
wf.writeframesraw(audio_data)
|
||||
wf.close()
|
||||
manifest = {
|
||||
"audio_filepath": audio_file_path,
|
||||
"duration": 60,
|
||||
"text": "todo",
|
||||
}
|
||||
manifest_file = tempfile.NamedTemporaryFile(
|
||||
dir=WORK_DIR, prefix="jasper_manifest.", delete=False, mode="w"
|
||||
)
|
||||
manifest_file.write(json.dumps(manifest))
|
||||
manifest_file.close()
|
||||
manifest_file_path = manifest_file.name
|
||||
data_layer = nemo_asr.AudioToTextDataLayer(
|
||||
shuffle=False,
|
||||
manifest_filepath=manifest_file_path,
|
||||
labels=self.labels,
|
||||
batch_size=1,
|
||||
)
|
||||
|
||||
# Define inference DAG
|
||||
audio_signal, audio_signal_len, _, _ = data_layer()
|
||||
processed_signal, processed_signal_len = self.data_preprocessor(
|
||||
input_signal=audio_signal, length=audio_signal_len
|
||||
)
|
||||
encoded, encoded_len = self.jasper_encoder(
|
||||
audio_signal=processed_signal, length=processed_signal_len
|
||||
)
|
||||
log_probs = self.jasper_decoder(encoder_output=encoded)
|
||||
predictions = self.greedy_decoder(log_probs=log_probs)
|
||||
|
||||
if greedy:
|
||||
eval_tensors = [predictions]
|
||||
else:
|
||||
if self.beam_search_with_lm:
|
||||
logging.info("Running with beam search")
|
||||
beam_predictions = self.beam_search_with_lm(
|
||||
log_probs=log_probs, log_probs_length=encoded_len
|
||||
)
|
||||
eval_tensors = [beam_predictions]
|
||||
else:
|
||||
logging.info(
|
||||
"language_model not specified. falling back to greedy decoding."
|
||||
)
|
||||
eval_tensors = [predictions]
|
||||
|
||||
tensors = self.neural_factory.infer(tensors=eval_tensors)
|
||||
prediction = post_process_predictions(tensors[0], self.labels)
|
||||
prediction_text = ". ".join(prediction)
|
||||
os.unlink(manifest_file.name)
|
||||
os.unlink(audio_file.name)
|
||||
return prediction_text
|
||||
|
||||
def transcribe_file(self, audio_file, *args, **kwargs):
|
||||
tscript_file_path = audio_file.with_suffix(".txt")
|
||||
audio_file_path = str(audio_file)
|
||||
with wave.open(audio_file_path, "r") as af:
|
||||
frame_count = af.getnframes()
|
||||
audio_data = af.readframes(frame_count)
|
||||
transcription = self.transcribe(audio_data, *args, **kwargs)
|
||||
with open(tscript_file_path, "w") as tf:
|
||||
tf.write(transcription)
|
||||
|
|
@ -0,0 +1,24 @@
|
|||
from pathlib import Path
|
||||
import typer
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
|
||||
def set_root(dataset_path: Path, root_path: Path):
|
||||
pass
|
||||
# for dataset_kind in ["train", "valid"]:
|
||||
# data_file = dataset_path / Path(dataset_kind).with_suffix(".tsv")
|
||||
# with data_file.open("r") as df:
|
||||
# lines = df.readlines()
|
||||
# with data_file.open("w") as df:
|
||||
# lines[0] = str(root_path) + "\n"
|
||||
# df.writelines(lines)
|
||||
|
||||
|
||||
def main():
|
||||
app()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -0,0 +1,340 @@
|
|||
from functools import partial
|
||||
import tempfile
|
||||
|
||||
# from typing import Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
import nemo
|
||||
|
||||
# import nemo.collections.asr as nemo_asr
|
||||
from nemo.backends.pytorch import DataLayerNM
|
||||
from nemo.core import DeviceType
|
||||
|
||||
# from nemo.core.neural_types import *
|
||||
from nemo.core.neural_types import (
|
||||
NeuralType,
|
||||
AudioSignal,
|
||||
LengthsType,
|
||||
LabelsType,
|
||||
)
|
||||
from nemo.utils.decorators import add_port_docs
|
||||
|
||||
from nemo.collections.asr.parts.dataset import (
|
||||
# AudioDataset,
|
||||
# AudioLabelDataset,
|
||||
# KaldiFeatureDataset,
|
||||
# TranscriptDataset,
|
||||
parsers,
|
||||
collections,
|
||||
seq_collate_fn,
|
||||
)
|
||||
|
||||
# from functools import lru_cache
|
||||
import rpyc
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from tqdm import tqdm
|
||||
from .featurizer import RpycWaveformFeaturizer
|
||||
|
||||
# from nemo.collections.asr.parts.features import WaveformFeaturizer
|
||||
|
||||
# from nemo.collections.asr.parts.perturb import AudioAugmentor, perturbation_types
|
||||
|
||||
|
||||
logging = nemo.logging
|
||||
|
||||
|
||||
class CachedAudioDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
Dataset that loads tensors via a json file containing paths to audio
|
||||
files, transcripts, and durations (in seconds). Each new line is a
|
||||
different sample. Example below:
|
||||
|
||||
{"audio_filepath": "/path/to/audio.wav", "text_filepath":
|
||||
"/path/to/audio.txt", "duration": 23.147}
|
||||
...
|
||||
{"audio_filepath": "/path/to/audio.wav", "text": "the
|
||||
transcription", offset": 301.75, "duration": 0.82, "utt":
|
||||
"utterance_id", "ctm_utt": "en_4156", "side": "A"}
|
||||
|
||||
Args:
|
||||
manifest_filepath: Path to manifest json as described above. Can
|
||||
be comma-separated paths.
|
||||
labels: String containing all the possible characters to map to
|
||||
featurizer: Initialized featurizer class that converts paths of
|
||||
audio to feature tensors
|
||||
max_duration: If audio exceeds this length, do not include in dataset
|
||||
min_duration: If audio is less than this length, do not include
|
||||
in dataset
|
||||
max_utts: Limit number of utterances
|
||||
blank_index: blank character index, default = -1
|
||||
unk_index: unk_character index, default = -1
|
||||
normalize: whether to normalize transcript text (default): True
|
||||
bos_id: Id of beginning of sequence symbol to append if not None
|
||||
eos_id: Id of end of sequence symbol to append if not None
|
||||
load_audio: Boolean flag indicate whether do or not load audio
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
manifest_filepath,
|
||||
labels,
|
||||
featurizer,
|
||||
max_duration=None,
|
||||
min_duration=None,
|
||||
max_utts=0,
|
||||
blank_index=-1,
|
||||
unk_index=-1,
|
||||
normalize=True,
|
||||
trim=False,
|
||||
bos_id=None,
|
||||
eos_id=None,
|
||||
load_audio=True,
|
||||
parser="en",
|
||||
):
|
||||
self.collection = collections.ASRAudioText(
|
||||
manifests_files=manifest_filepath.split(","),
|
||||
parser=parsers.make_parser(
|
||||
labels=labels,
|
||||
name=parser,
|
||||
unk_id=unk_index,
|
||||
blank_id=blank_index,
|
||||
do_normalize=normalize,
|
||||
),
|
||||
min_duration=min_duration,
|
||||
max_duration=max_duration,
|
||||
max_number=max_utts,
|
||||
)
|
||||
self.index_feature_map = {}
|
||||
|
||||
self.featurizer = featurizer
|
||||
self.trim = trim
|
||||
self.eos_id = eos_id
|
||||
self.bos_id = bos_id
|
||||
self.load_audio = load_audio
|
||||
print(f"initializing dataset {manifest_filepath}")
|
||||
|
||||
def exec_func(i):
|
||||
return self[i]
|
||||
|
||||
task_count = len(self.collection)
|
||||
with ThreadPoolExecutor() as exe:
|
||||
print("starting all loading tasks")
|
||||
list(
|
||||
tqdm(
|
||||
exe.map(exec_func, range(task_count)),
|
||||
position=0,
|
||||
leave=True,
|
||||
total=task_count,
|
||||
)
|
||||
)
|
||||
print(f"initializing complete")
|
||||
|
||||
def __getitem__(self, index):
|
||||
sample = self.collection[index]
|
||||
if self.load_audio:
|
||||
cached_features = self.index_feature_map.get(index)
|
||||
if cached_features is not None:
|
||||
features = cached_features
|
||||
else:
|
||||
features = self.featurizer.process(
|
||||
sample.audio_file,
|
||||
offset=0,
|
||||
duration=sample.duration,
|
||||
trim=self.trim,
|
||||
)
|
||||
self.index_feature_map[index] = features
|
||||
f, fl = features, torch.tensor(features.shape[0]).long()
|
||||
else:
|
||||
f, fl = None, None
|
||||
|
||||
t, tl = sample.text_tokens, len(sample.text_tokens)
|
||||
if self.bos_id is not None:
|
||||
t = [self.bos_id] + t
|
||||
tl += 1
|
||||
if self.eos_id is not None:
|
||||
t = t + [self.eos_id]
|
||||
tl += 1
|
||||
|
||||
return f, fl, torch.tensor(t).long(), torch.tensor(tl).long()
|
||||
|
||||
def __len__(self):
|
||||
return len(self.collection)
|
||||
|
||||
|
||||
class RpycAudioToTextDataLayer(DataLayerNM):
|
||||
"""Data Layer for general ASR tasks.
|
||||
|
||||
Module which reads ASR labeled data. It accepts comma-separated
|
||||
JSON manifest files describing the correspondence between wav audio files
|
||||
and their transcripts. JSON files should be of the following format::
|
||||
|
||||
{"audio_filepath": path_to_wav_0, "duration": time_in_sec_0, "text": \
|
||||
transcript_0}
|
||||
...
|
||||
{"audio_filepath": path_to_wav_n, "duration": time_in_sec_n, "text": \
|
||||
transcript_n}
|
||||
|
||||
Args:
|
||||
manifest_filepath (str): Dataset parameter.
|
||||
Path to JSON containing data.
|
||||
labels (list): Dataset parameter.
|
||||
List of characters that can be output by the ASR model.
|
||||
For Jasper, this is the 28 character set {a-z '}. The CTC blank
|
||||
symbol is automatically added later for models using ctc.
|
||||
batch_size (int): batch size
|
||||
sample_rate (int): Target sampling rate for data. Audio files will be
|
||||
resampled to sample_rate if it is not already.
|
||||
Defaults to 16000.
|
||||
int_values (bool): Bool indicating whether the audio file is saved as
|
||||
int data or float data.
|
||||
Defaults to False.
|
||||
eos_id (id): Dataset parameter.
|
||||
End of string symbol id used for seq2seq models.
|
||||
Defaults to None.
|
||||
min_duration (float): Dataset parameter.
|
||||
All training files which have a duration less than min_duration
|
||||
are dropped. Note: Duration is read from the manifest JSON.
|
||||
Defaults to 0.1.
|
||||
max_duration (float): Dataset parameter.
|
||||
All training files which have a duration more than max_duration
|
||||
are dropped. Note: Duration is read from the manifest JSON.
|
||||
Defaults to None.
|
||||
normalize_transcripts (bool): Dataset parameter.
|
||||
Whether to use automatic text cleaning.
|
||||
It is highly recommended to manually clean text for best results.
|
||||
Defaults to True.
|
||||
trim_silence (bool): Whether to use trim silence from beginning and end
|
||||
of audio signal using librosa.effects.trim().
|
||||
Defaults to False.
|
||||
load_audio (bool): Dataset parameter.
|
||||
Controls whether the dataloader loads the audio signal and
|
||||
transcript or just the transcript.
|
||||
Defaults to True.
|
||||
drop_last (bool): See PyTorch DataLoader.
|
||||
Defaults to False.
|
||||
shuffle (bool): See PyTorch DataLoader.
|
||||
Defaults to True.
|
||||
num_workers (int): See PyTorch DataLoader.
|
||||
Defaults to 0.
|
||||
perturb_config (dict): Currently disabled.
|
||||
"""
|
||||
|
||||
@property
|
||||
@add_port_docs()
|
||||
def output_ports(self):
|
||||
"""Returns definitions of module output ports."""
|
||||
return {
|
||||
# 'audio_signal': NeuralType({0: AxisType(BatchTag), 1: AxisType(TimeTag)}),
|
||||
# 'a_sig_length': NeuralType({0: AxisType(BatchTag)}),
|
||||
# 'transcripts': NeuralType({0: AxisType(BatchTag), 1: AxisType(TimeTag)}),
|
||||
# 'transcript_length': NeuralType({0: AxisType(BatchTag)}),
|
||||
"audio_signal": NeuralType(
|
||||
("B", "T"),
|
||||
AudioSignal(freq=self._sample_rate)
|
||||
if self is not None and self._sample_rate is not None
|
||||
else AudioSignal(),
|
||||
),
|
||||
"a_sig_length": NeuralType(tuple("B"), LengthsType()),
|
||||
"transcripts": NeuralType(("B", "T"), LabelsType()),
|
||||
"transcript_length": NeuralType(tuple("B"), LengthsType()),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
manifest_filepath,
|
||||
labels,
|
||||
batch_size,
|
||||
sample_rate=16000,
|
||||
int_values=False,
|
||||
bos_id=None,
|
||||
eos_id=None,
|
||||
pad_id=None,
|
||||
min_duration=0.1,
|
||||
max_duration=None,
|
||||
normalize_transcripts=True,
|
||||
trim_silence=False,
|
||||
load_audio=True,
|
||||
rpyc_host="",
|
||||
drop_last=False,
|
||||
shuffle=True,
|
||||
num_workers=0,
|
||||
):
|
||||
super().__init__()
|
||||
self._sample_rate = sample_rate
|
||||
|
||||
def rpyc_root_fn():
|
||||
return rpyc.connect(
|
||||
rpyc_host, 8064, config={"sync_request_timeout": 600}
|
||||
).root
|
||||
|
||||
rpyc_conn = rpyc_root_fn()
|
||||
|
||||
self._featurizer = RpycWaveformFeaturizer(
|
||||
sample_rate=self._sample_rate,
|
||||
int_values=int_values,
|
||||
augmentor=None,
|
||||
rpyc_conn=rpyc_conn,
|
||||
)
|
||||
|
||||
def read_remote_manifests():
|
||||
local_mp = []
|
||||
for mrp in manifest_filepath.split(","):
|
||||
md = rpyc_conn.read_path(mrp)
|
||||
mf = tempfile.NamedTemporaryFile(
|
||||
dir="/tmp", prefix="jasper_manifest.", delete=False
|
||||
)
|
||||
mf.write(md)
|
||||
mf.close()
|
||||
local_mp.append(mf.name)
|
||||
return ",".join(local_mp)
|
||||
|
||||
local_manifest_filepath = read_remote_manifests()
|
||||
dataset_params = {
|
||||
"manifest_filepath": local_manifest_filepath,
|
||||
"labels": labels,
|
||||
"featurizer": self._featurizer,
|
||||
"max_duration": max_duration,
|
||||
"min_duration": min_duration,
|
||||
"normalize": normalize_transcripts,
|
||||
"trim": trim_silence,
|
||||
"bos_id": bos_id,
|
||||
"eos_id": eos_id,
|
||||
"load_audio": load_audio,
|
||||
}
|
||||
|
||||
self._dataset = CachedAudioDataset(**dataset_params)
|
||||
self._batch_size = batch_size
|
||||
|
||||
# Set up data loader
|
||||
if self._placement == DeviceType.AllGpu:
|
||||
logging.info("Parallelizing Datalayer.")
|
||||
sampler = torch.utils.data.distributed.DistributedSampler(
|
||||
self._dataset
|
||||
)
|
||||
else:
|
||||
sampler = None
|
||||
|
||||
if batch_size == -1:
|
||||
batch_size = len(self._dataset)
|
||||
|
||||
pad_id = 0 if pad_id is None else pad_id
|
||||
self._dataloader = torch.utils.data.DataLoader(
|
||||
dataset=self._dataset,
|
||||
batch_size=batch_size,
|
||||
collate_fn=partial(seq_collate_fn, token_pad_value=pad_id),
|
||||
drop_last=drop_last,
|
||||
shuffle=shuffle if sampler is None else False,
|
||||
sampler=sampler,
|
||||
num_workers=1,
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
return len(self._dataset)
|
||||
|
||||
@property
|
||||
def dataset(self):
|
||||
return None
|
||||
|
||||
@property
|
||||
def data_iterator(self):
|
||||
return self._dataloader
|
||||
|
|
@ -0,0 +1,376 @@
|
|||
# Copyright (c) 2019 NVIDIA Corporation
|
||||
import argparse
|
||||
import copy
|
||||
|
||||
# import math
|
||||
import os
|
||||
from pathlib import Path
|
||||
from functools import partial
|
||||
|
||||
from ruamel.yaml import YAML
|
||||
|
||||
import nemo
|
||||
import nemo.collections.asr as nemo_asr
|
||||
import nemo.utils.argparse as nm_argparse
|
||||
from nemo.collections.asr.helpers import (
|
||||
# monitor_asr_train_progress,
|
||||
process_evaluation_batch,
|
||||
process_evaluation_epoch,
|
||||
)
|
||||
|
||||
# from nemo.utils.lr_policies import CosineAnnealing
|
||||
from training.data_loaders import RpycAudioToTextDataLayer
|
||||
|
||||
logging = nemo.logging
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
parents=[nm_argparse.NemoArgParser()],
|
||||
description="Jasper",
|
||||
conflict_handler="resolve",
|
||||
)
|
||||
parser.set_defaults(
|
||||
checkpoint_dir=None,
|
||||
optimizer="novograd",
|
||||
batch_size=64,
|
||||
eval_batch_size=64,
|
||||
lr=0.002,
|
||||
amp_opt_level="O1",
|
||||
create_tb_writer=True,
|
||||
model_config="./train/jasper10x5dr.yaml",
|
||||
work_dir="./train/work",
|
||||
num_epochs=300,
|
||||
weight_decay=0.005,
|
||||
checkpoint_save_freq=100,
|
||||
eval_freq=100,
|
||||
load_dir="./train/models/jasper/",
|
||||
warmup_steps=3,
|
||||
exp_name="jasper",
|
||||
)
|
||||
|
||||
# Overwrite default args
|
||||
parser.add_argument(
|
||||
"--max_steps",
|
||||
type=int,
|
||||
default=None,
|
||||
required=False,
|
||||
help="max number of steps to train",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_epochs",
|
||||
type=int,
|
||||
required=False,
|
||||
help="number of epochs to train",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_config",
|
||||
type=str,
|
||||
required=False,
|
||||
help="model configuration file: model.yaml",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--encoder_checkpoint",
|
||||
type=str,
|
||||
required=True,
|
||||
help="encoder checkpoint file: JasperEncoder.pt",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--decoder_checkpoint",
|
||||
type=str,
|
||||
required=True,
|
||||
help="decoder checkpoint file: JasperDecoderForCTC.pt",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--remote_data",
|
||||
type=str,
|
||||
required=False,
|
||||
default="",
|
||||
help="remote dataloader endpoint",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
type=str,
|
||||
required=False,
|
||||
default="",
|
||||
help="dataset directory containing train/test manifests",
|
||||
)
|
||||
|
||||
# Create new args
|
||||
parser.add_argument("--exp_name", default="Jasper", type=str)
|
||||
parser.add_argument("--beta1", default=0.95, type=float)
|
||||
parser.add_argument("--beta2", default=0.25, type=float)
|
||||
parser.add_argument("--warmup_steps", default=0, type=int)
|
||||
parser.add_argument(
|
||||
"--load_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
help="directory with pre-trained checkpoint",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
if args.max_steps is None and args.num_epochs is None:
|
||||
raise ValueError("Either max_steps or num_epochs should be provided.")
|
||||
return args
|
||||
|
||||
|
||||
def construct_name(
|
||||
name, lr, batch_size, max_steps, num_epochs, wd, optimizer, iter_per_step
|
||||
):
|
||||
if max_steps is not None:
|
||||
return "{0}-lr_{1}-bs_{2}-s_{3}-wd_{4}-opt_{5}-ips_{6}".format(
|
||||
name, lr, batch_size, max_steps, wd, optimizer, iter_per_step
|
||||
)
|
||||
else:
|
||||
return "{0}-lr_{1}-bs_{2}-e_{3}-wd_{4}-opt_{5}-ips_{6}".format(
|
||||
name, lr, batch_size, num_epochs, wd, optimizer, iter_per_step
|
||||
)
|
||||
|
||||
|
||||
def create_all_dags(args, neural_factory):
|
||||
yaml = YAML(typ="safe")
|
||||
with open(args.model_config) as f:
|
||||
jasper_params = yaml.load(f)
|
||||
vocab = jasper_params["labels"]
|
||||
sample_rate = jasper_params["sample_rate"]
|
||||
|
||||
# Calculate num_workers for dataloader
|
||||
total_cpus = os.cpu_count()
|
||||
cpu_per_traindl = max(int(total_cpus / neural_factory.world_size), 1)
|
||||
# perturb_config = jasper_params.get('perturb', None)
|
||||
train_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"])
|
||||
train_dl_params.update(jasper_params["AudioToTextDataLayer"]["train"])
|
||||
del train_dl_params["train"]
|
||||
del train_dl_params["eval"]
|
||||
# del train_dl_params["normalize_transcripts"]
|
||||
|
||||
if args.dataset:
|
||||
d_path = Path(args.dataset)
|
||||
if not args.train_dataset:
|
||||
args.train_dataset = str(d_path / Path("train_manifest.json"))
|
||||
if not args.eval_datasets:
|
||||
args.eval_datasets = [str(d_path / Path("test_manifest.json"))]
|
||||
|
||||
data_loader_layer = nemo_asr.AudioToTextDataLayer
|
||||
|
||||
if args.remote_data:
|
||||
train_dl_params["rpyc_host"] = args.remote_data
|
||||
data_loader_layer = RpycAudioToTextDataLayer
|
||||
|
||||
# data_layer = data_loader_layer(
|
||||
# manifest_filepath=args.train_dataset,
|
||||
# sample_rate=sample_rate,
|
||||
# labels=vocab,
|
||||
# batch_size=args.batch_size,
|
||||
# num_workers=cpu_per_traindl,
|
||||
# **train_dl_params,
|
||||
# # normalize_transcripts=False
|
||||
# )
|
||||
#
|
||||
# N = len(data_layer)
|
||||
# steps_per_epoch = math.ceil(
|
||||
# N / (args.batch_size * args.iter_per_step * args.num_gpus)
|
||||
# )
|
||||
# logging.info("Have {0} examples to train on.".format(N))
|
||||
#
|
||||
data_preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor(
|
||||
sample_rate=sample_rate,
|
||||
**jasper_params["AudioToMelSpectrogramPreprocessor"],
|
||||
)
|
||||
|
||||
# multiply_batch_config = jasper_params.get("MultiplyBatch", None)
|
||||
# if multiply_batch_config:
|
||||
# multiply_batch = nemo_asr.MultiplyBatch(**multiply_batch_config)
|
||||
#
|
||||
# spectr_augment_config = jasper_params.get("SpectrogramAugmentation", None)
|
||||
# if spectr_augment_config:
|
||||
# data_spectr_augmentation = nemo_asr.SpectrogramAugmentation(
|
||||
# **spectr_augment_config
|
||||
# )
|
||||
#
|
||||
eval_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"])
|
||||
eval_dl_params.update(jasper_params["AudioToTextDataLayer"]["eval"])
|
||||
if args.remote_data:
|
||||
eval_dl_params["rpyc_host"] = args.remote_data
|
||||
del eval_dl_params["train"]
|
||||
del eval_dl_params["eval"]
|
||||
data_layers_eval = []
|
||||
|
||||
# if args.eval_datasets:
|
||||
for eval_datasets in args.eval_datasets:
|
||||
data_layer_eval = data_loader_layer(
|
||||
manifest_filepath=eval_datasets,
|
||||
sample_rate=sample_rate,
|
||||
labels=vocab,
|
||||
batch_size=args.eval_batch_size,
|
||||
num_workers=cpu_per_traindl,
|
||||
**eval_dl_params,
|
||||
)
|
||||
|
||||
data_layers_eval.append(data_layer_eval)
|
||||
# else:
|
||||
# logging.warning("There were no val datasets passed")
|
||||
|
||||
jasper_encoder = nemo_asr.JasperEncoder(
|
||||
feat_in=jasper_params["AudioToMelSpectrogramPreprocessor"]["features"],
|
||||
**jasper_params["JasperEncoder"],
|
||||
)
|
||||
jasper_encoder.restore_from(args.encoder_checkpoint, local_rank=0)
|
||||
|
||||
jasper_decoder = nemo_asr.JasperDecoderForCTC(
|
||||
feat_in=jasper_params["JasperEncoder"]["jasper"][-1]["filters"],
|
||||
num_classes=len(vocab),
|
||||
)
|
||||
jasper_decoder.restore_from(args.decoder_checkpoint, local_rank=0)
|
||||
|
||||
ctc_loss = nemo_asr.CTCLossNM(num_classes=len(vocab))
|
||||
|
||||
greedy_decoder = nemo_asr.GreedyCTCDecoder()
|
||||
|
||||
# logging.info("================================")
|
||||
# logging.info(f"Number of parameters in encoder: {jasper_encoder.num_weights}")
|
||||
# logging.info(f"Number of parameters in decoder: {jasper_decoder.num_weights}")
|
||||
# logging.info(
|
||||
# f"Total number of parameters in model: "
|
||||
# f"{jasper_decoder.num_weights + jasper_encoder.num_weights}"
|
||||
# )
|
||||
# logging.info("================================")
|
||||
#
|
||||
# # Train DAG
|
||||
# (audio_signal_t, a_sig_length_t, transcript_t, transcript_len_t) = data_layer()
|
||||
# processed_signal_t, p_length_t = data_preprocessor(
|
||||
# input_signal=audio_signal_t, length=a_sig_length_t
|
||||
# )
|
||||
#
|
||||
# if multiply_batch_config:
|
||||
# (
|
||||
# processed_signal_t,
|
||||
# p_length_t,
|
||||
# transcript_t,
|
||||
# transcript_len_t,
|
||||
# ) = multiply_batch(
|
||||
# in_x=processed_signal_t,
|
||||
# in_x_len=p_length_t,
|
||||
# in_y=transcript_t,
|
||||
# in_y_len=transcript_len_t,
|
||||
# )
|
||||
#
|
||||
# if spectr_augment_config:
|
||||
# processed_signal_t = data_spectr_augmentation(input_spec=processed_signal_t)
|
||||
#
|
||||
# encoded_t, encoded_len_t = jasper_encoder(
|
||||
# audio_signal=processed_signal_t, length=p_length_t
|
||||
# )
|
||||
# log_probs_t = jasper_decoder(encoder_output=encoded_t)
|
||||
# predictions_t = greedy_decoder(log_probs=log_probs_t)
|
||||
# loss_t = ctc_loss(
|
||||
# log_probs=log_probs_t,
|
||||
# targets=transcript_t,
|
||||
# input_length=encoded_len_t,
|
||||
# target_length=transcript_len_t,
|
||||
# )
|
||||
#
|
||||
# # Callbacks needed to print info to console and Tensorboard
|
||||
# train_callback = nemo.core.SimpleLossLoggerCallback(
|
||||
# tensors=[loss_t, predictions_t, transcript_t, transcript_len_t],
|
||||
# print_func=partial(monitor_asr_train_progress, labels=vocab),
|
||||
# get_tb_values=lambda x: [("loss", x[0])],
|
||||
# tb_writer=neural_factory.tb_writer,
|
||||
# )
|
||||
#
|
||||
# chpt_callback = nemo.core.CheckpointCallback(
|
||||
# folder=neural_factory.checkpoint_dir,
|
||||
# load_from_folder=args.load_dir,
|
||||
# step_freq=args.checkpoint_save_freq,
|
||||
# checkpoints_to_keep=30,
|
||||
# )
|
||||
#
|
||||
# callbacks = [train_callback, chpt_callback]
|
||||
callbacks = []
|
||||
# assemble eval DAGs
|
||||
for i, eval_dl in enumerate(data_layers_eval):
|
||||
(
|
||||
audio_signal_e,
|
||||
a_sig_length_e,
|
||||
transcript_e,
|
||||
transcript_len_e,
|
||||
) = eval_dl()
|
||||
processed_signal_e, p_length_e = data_preprocessor(
|
||||
input_signal=audio_signal_e, length=a_sig_length_e
|
||||
)
|
||||
encoded_e, encoded_len_e = jasper_encoder(
|
||||
audio_signal=processed_signal_e, length=p_length_e
|
||||
)
|
||||
log_probs_e = jasper_decoder(encoder_output=encoded_e)
|
||||
predictions_e = greedy_decoder(log_probs=log_probs_e)
|
||||
loss_e = ctc_loss(
|
||||
log_probs=log_probs_e,
|
||||
targets=transcript_e,
|
||||
input_length=encoded_len_e,
|
||||
target_length=transcript_len_e,
|
||||
)
|
||||
|
||||
# create corresponding eval callback
|
||||
tagname = os.path.basename(args.eval_datasets[i]).split(".")[0]
|
||||
eval_callback = nemo.core.EvaluatorCallback(
|
||||
eval_tensors=[
|
||||
loss_e,
|
||||
predictions_e,
|
||||
transcript_e,
|
||||
transcript_len_e,
|
||||
],
|
||||
user_iter_callback=partial(process_evaluation_batch, labels=vocab),
|
||||
user_epochs_done_callback=partial(
|
||||
process_evaluation_epoch, tag=tagname
|
||||
),
|
||||
eval_step=args.eval_freq,
|
||||
tb_writer=neural_factory.tb_writer,
|
||||
)
|
||||
|
||||
callbacks.append(eval_callback)
|
||||
return callbacks
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
# name = construct_name(
|
||||
# args.exp_name,
|
||||
# args.lr,
|
||||
# args.batch_size,
|
||||
# args.max_steps,
|
||||
# args.num_epochs,
|
||||
# args.weight_decay,
|
||||
# args.optimizer,
|
||||
# args.iter_per_step,
|
||||
# )
|
||||
# log_dir = name
|
||||
# if args.work_dir:
|
||||
# log_dir = os.path.join(args.work_dir, name)
|
||||
|
||||
# instantiate Neural Factory with supported backend
|
||||
neural_factory = nemo.core.NeuralModuleFactory(
|
||||
placement=nemo.core.DeviceType.GPU,
|
||||
backend=nemo.core.Backend.PyTorch,
|
||||
# local_rank=args.local_rank,
|
||||
# optimization_level=args.amp_opt_level,
|
||||
# log_dir=log_dir,
|
||||
# checkpoint_dir=args.checkpoint_dir,
|
||||
# create_tb_writer=args.create_tb_writer,
|
||||
# files_to_copy=[args.model_config, __file__],
|
||||
# cudnn_benchmark=args.cudnn_benchmark,
|
||||
# tensorboard_dir=args.tensorboard_dir,
|
||||
)
|
||||
args.num_gpus = neural_factory.world_size
|
||||
|
||||
# checkpoint_dir = neural_factory.checkpoint_dir
|
||||
if args.local_rank is not None:
|
||||
logging.info("Doing ALL GPU")
|
||||
|
||||
# build dags
|
||||
callbacks = create_all_dags(args, neural_factory)
|
||||
# evaluate model
|
||||
neural_factory.eval(callbacks=callbacks)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
@@ -0,0 +1,61 @@
# import math
|
||||
|
||||
# import librosa
|
||||
import torch
|
||||
import pickle
|
||||
|
||||
# import torch.nn as nn
|
||||
# from torch_stft import STFT
|
||||
|
||||
# from nemo import logging
|
||||
from nemo.collections.asr.parts.perturb import AudioAugmentor
|
||||
|
||||
# from nemo.collections.asr.parts.segment import AudioSegment
|
||||
|
||||
|
||||
class RpycWaveformFeaturizer(object):
|
||||
def __init__(
|
||||
self,
|
||||
sample_rate=16000,
|
||||
int_values=False,
|
||||
augmentor=None,
|
||||
rpyc_conn=None,
|
||||
):
|
||||
self.augmentor = (
|
||||
augmentor if augmentor is not None else AudioAugmentor()
|
||||
)
|
||||
self.sample_rate = sample_rate
|
||||
self.int_values = int_values
|
||||
self.remote_path_samples = rpyc_conn.get_path_samples
|
||||
|
||||
def max_augmentation_length(self, length):
|
||||
return self.augmentor.max_augmentation_length(length)
|
||||
|
||||
def process(self, file_path, offset=0, duration=0, trim=False):
|
||||
audio = self.remote_path_samples(
|
||||
file_path,
|
||||
target_sr=self.sample_rate,
|
||||
int_values=self.int_values,
|
||||
offset=offset,
|
||||
duration=duration,
|
||||
trim=trim,
|
||||
)
|
||||
return torch.tensor(pickle.loads(audio), dtype=torch.float)
|
||||
|
||||
def process_segment(self, audio_segment):
|
||||
self.augmentor.perturb(audio_segment)
|
||||
return torch.tensor(audio_segment, dtype=torch.float)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, input_config, perturbation_configs=None):
|
||||
if perturbation_configs is not None:
|
||||
aa = AudioAugmentor.from_config(perturbation_configs)
|
||||
else:
|
||||
aa = None
|
||||
|
||||
sample_rate = input_config.get("sample_rate", 16000)
|
||||
int_values = input_config.get("int_values", False)
|
||||
|
||||
return cls(
|
||||
sample_rate=sample_rate, int_values=int_values, augmentor=aa
|
||||
)
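
# Hedged usage sketch (added; host, port and path are placeholders, not from
# this module). The featurizer only needs an object exposing get_path_samples,
# e.g. the root of an rpyc connection to the ASRDataService added in this commit:
#
#   import rpyc
#   conn = rpyc.connect("localhost", 8064, config={"allow_all_attrs": True})
#   feat = RpycWaveformFeaturizer(sample_rate=16000, rpyc_conn=conn.root)
#   samples = feat.process("/remote/path/utterance.wav")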
|
||||
|
|
@@ -0,0 +1,54 @@
import os
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from rpyc.utils.server import ThreadedServer
|
||||
import typer
|
||||
|
||||
# from .asr import JasperASR
|
||||
from ...utils.serve import ASRService
|
||||
from plume.utils import lazy_callable
|
||||
|
||||
JasperASR = lazy_callable("plume.models.jasper_nemo.asr.JasperASR")
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
|
||||
def rpyc(
|
||||
encoder_path: Path = "/path/to/encoder.pt",
|
||||
decoder_path: Path = "/path/to/decoder.pt",
|
||||
model_yaml_path: Path = "/path/to/model.yaml",
|
||||
port: int = int(os.environ.get("ASR_RPYC_PORT", "8044")),
|
||||
):
|
||||
for p in [encoder_path, decoder_path, model_yaml_path]:
|
||||
if not p.exists():
|
||||
logging.info(f"{p} doesn't exists")
|
||||
return
|
||||
asr = JasperASR(str(model_yaml_path), str(encoder_path), str(decoder_path))
|
||||
service = ASRService(asr)
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
)
|
||||
logging.info("starting asr server...")
|
||||
t = ThreadedServer(service, port=port)
|
||||
t.start()
|
||||
|
||||
|
||||
@app.command()
|
||||
def rpyc_dir(
|
||||
model_dir: Path, port: int = int(os.environ.get("ASR_RPYC_PORT", "8044"))
|
||||
):
|
||||
encoder_path = model_dir / Path("decoder.pt")
|
||||
decoder_path = model_dir / Path("encoder.pt")
|
||||
model_yaml_path = model_dir / Path("model.yaml")
|
||||
rpyc(encoder_path, decoder_path, model_yaml_path, port)
|
||||
|
||||
|
||||
def main():
|
||||
app()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
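
# Hedged usage sketch (added; module path, checkpoint paths and port are
# placeholders, not from this file):
#
#   python -m plume.models.jasper_nemo.serve rpyc \
#       --encoder-path /models/jasper/JasperEncoder.pt \
#       --decoder-path /models/jasper/JasperDecoderForCTC.pt \
#       --model-yaml-path /models/jasper/model.yaml --port 8044
#   python -m plume.models.jasper_nemo.serve rpyc-dir /models/jasper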
@@ -0,0 +1,59 @@
import os
|
||||
from pathlib import Path
|
||||
|
||||
import typer
|
||||
import rpyc
|
||||
from rpyc.utils.server import ThreadedServer
|
||||
import nemo
|
||||
import pickle
|
||||
|
||||
# import nemo.collections.asr as nemo_asr
|
||||
from nemo.collections.asr.parts.segment import AudioSegment
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
nemo.core.NeuralModuleFactory(
|
||||
backend=nemo.core.Backend.PyTorch, placement=nemo.core.DeviceType.CPU
|
||||
)
|
||||
|
||||
|
||||
class ASRDataService(rpyc.Service):
|
||||
def exposed_get_path_samples(
|
||||
self, file_path, target_sr, int_values, offset, duration, trim
|
||||
):
|
||||
print(f"loading.. {file_path}")
|
||||
audio = AudioSegment.from_file(
|
||||
file_path,
|
||||
target_sr=target_sr,
|
||||
int_values=int_values,
|
||||
offset=offset,
|
||||
duration=duration,
|
||||
trim=trim,
|
||||
)
|
||||
# print(f"returning.. {len(audio.samples)} items of type{type(audio.samples)}")
|
||||
return pickle.dumps(audio.samples)
|
||||
|
||||
def exposed_read_path(self, file_path):
|
||||
# print(f"reading path.. {file_path}")
|
||||
return Path(file_path).read_bytes()
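
# Added commentary (not in the original): samples are returned as pickled bytes
# rather than as a live numpy array because rpyc would otherwise hand the client
# a netref, turning every element access into a network round trip. The client
# (RpycWaveformFeaturizer.process) unpickles the payload locally.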
|
||||
|
||||
|
||||
@app.command()
|
||||
def run_server(port: int = 0):
|
||||
listen_port = (
|
||||
port if port else int(os.environ.get("ASR_DARA_RPYC_PORT", "8064"))
|
||||
)
|
||||
service = ASRDataService()
|
||||
t = ThreadedServer(
|
||||
service, port=listen_port, protocol_config={"allow_all_attrs": True}
|
||||
)
|
||||
typer.echo(f"starting asr server on {listen_port}...")
|
||||
t.start()
|
||||
|
||||
|
||||
def main():
|
||||
app()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@@ -0,0 +1,392 @@
# Copyright (c) 2019 NVIDIA Corporation
|
||||
import argparse
|
||||
import copy
|
||||
import math
|
||||
import os
|
||||
from pathlib import Path
|
||||
from functools import partial
|
||||
|
||||
from ruamel.yaml import YAML
|
||||
|
||||
import nemo
|
||||
import nemo.collections.asr as nemo_asr
|
||||
import nemo.utils.argparse as nm_argparse
|
||||
from nemo.collections.asr.helpers import (
|
||||
monitor_asr_train_progress,
|
||||
process_evaluation_batch,
|
||||
process_evaluation_epoch,
|
||||
)
|
||||
|
||||
from nemo.utils.lr_policies import CosineAnnealing
|
||||
from .data_loaders import RpycAudioToTextDataLayer
|
||||
|
||||
logging = nemo.logging
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
parents=[nm_argparse.NemoArgParser()],
|
||||
description="Jasper",
|
||||
conflict_handler="resolve",
|
||||
)
|
||||
parser.set_defaults(
|
||||
checkpoint_dir=None,
|
||||
optimizer="novograd",
|
||||
batch_size=64,
|
||||
eval_batch_size=64,
|
||||
lr=0.002,
|
||||
amp_opt_level="O1",
|
||||
create_tb_writer=True,
|
||||
model_config="./train/jasper10x5dr.yaml",
|
||||
work_dir="./train/work",
|
||||
num_epochs=300,
|
||||
weight_decay=0.005,
|
||||
checkpoint_save_freq=100,
|
||||
eval_freq=100,
|
||||
load_dir="./train/models/jasper/",
|
||||
warmup_steps=3,
|
||||
exp_name="jasper-speller",
|
||||
)
|
||||
|
||||
# Overwrite default args
|
||||
parser.add_argument(
|
||||
"--max_steps",
|
||||
type=int,
|
||||
default=None,
|
||||
required=False,
|
||||
help="max number of steps to train",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_epochs",
|
||||
type=int,
|
||||
required=False,
|
||||
help="number of epochs to train",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_config",
|
||||
type=str,
|
||||
required=False,
|
||||
help="model configuration file: model.yaml",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--remote_data",
|
||||
type=str,
|
||||
required=False,
|
||||
default="",
|
||||
help="remote dataloader endpoint",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
type=str,
|
||||
required=False,
|
||||
default="",
|
||||
help="dataset directory containing train/test manifests",
|
||||
)
|
||||
|
||||
# Create new args
|
||||
parser.add_argument("--exp_name", default="Jasper", type=str)
|
||||
parser.add_argument("--beta1", default=0.95, type=float)
|
||||
parser.add_argument("--beta2", default=0.25, type=float)
|
||||
parser.add_argument("--warmup_steps", default=0, type=int)
|
||||
parser.add_argument(
|
||||
"--load_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
help="directory with pre-trained checkpoint",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
if args.max_steps is None and args.num_epochs is None:
|
||||
raise ValueError("Either max_steps or num_epochs should be provided.")
|
||||
return args
|
||||
|
||||
|
||||
def construct_name(
|
||||
name, lr, batch_size, max_steps, num_epochs, wd, optimizer, iter_per_step
|
||||
):
|
||||
if max_steps is not None:
|
||||
return "{0}-lr_{1}-bs_{2}-s_{3}-wd_{4}-opt_{5}-ips_{6}".format(
|
||||
name, lr, batch_size, max_steps, wd, optimizer, iter_per_step
|
||||
)
|
||||
else:
|
||||
return "{0}-lr_{1}-bs_{2}-e_{3}-wd_{4}-opt_{5}-ips_{6}".format(
|
||||
name, lr, batch_size, num_epochs, wd, optimizer, iter_per_step
|
||||
)
|
||||
|
||||
|
||||
def create_all_dags(args, neural_factory):
|
||||
yaml = YAML(typ="safe")
|
||||
with open(args.model_config) as f:
|
||||
jasper_params = yaml.load(f)
|
||||
vocab = jasper_params["labels"]
|
||||
sample_rate = jasper_params["sample_rate"]
|
||||
|
||||
# Calculate num_workers for dataloader
|
||||
total_cpus = os.cpu_count()
|
||||
cpu_per_traindl = max(int(total_cpus / neural_factory.world_size), 1)
|
||||
# perturb_config = jasper_params.get('perturb', None)
|
||||
train_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"])
|
||||
train_dl_params.update(jasper_params["AudioToTextDataLayer"]["train"])
|
||||
del train_dl_params["train"]
|
||||
del train_dl_params["eval"]
|
||||
# del train_dl_params["normalize_transcripts"]
|
||||
|
||||
if args.dataset:
|
||||
d_path = Path(args.dataset)
|
||||
if not args.train_dataset:
|
||||
args.train_dataset = str(d_path / Path("train_manifest.json"))
|
||||
if not args.eval_datasets:
|
||||
args.eval_datasets = [str(d_path / Path("test_manifest.json"))]
|
||||
|
||||
data_loader_layer = nemo_asr.AudioToTextDataLayer
|
||||
|
||||
if args.remote_data:
|
||||
train_dl_params["rpyc_host"] = args.remote_data
|
||||
data_loader_layer = RpycAudioToTextDataLayer
|
||||
|
||||
data_layer = data_loader_layer(
|
||||
manifest_filepath=args.train_dataset,
|
||||
sample_rate=sample_rate,
|
||||
labels=vocab,
|
||||
batch_size=args.batch_size,
|
||||
num_workers=cpu_per_traindl,
|
||||
**train_dl_params,
|
||||
# normalize_transcripts=False
|
||||
)
|
||||
|
||||
N = len(data_layer)
|
||||
steps_per_epoch = math.ceil(
|
||||
N / (args.batch_size * args.iter_per_step * args.num_gpus)
|
||||
)
|
||||
logging.info("Have {0} examples to train on.".format(N))
|
||||
|
||||
data_preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor(
|
||||
sample_rate=sample_rate,
|
||||
**jasper_params["AudioToMelSpectrogramPreprocessor"],
|
||||
)
|
||||
|
||||
multiply_batch_config = jasper_params.get("MultiplyBatch", None)
|
||||
if multiply_batch_config:
|
||||
multiply_batch = nemo_asr.MultiplyBatch(**multiply_batch_config)
|
||||
|
||||
spectr_augment_config = jasper_params.get("SpectrogramAugmentation", None)
|
||||
if spectr_augment_config:
|
||||
data_spectr_augmentation = nemo_asr.SpectrogramAugmentation(
|
||||
**spectr_augment_config
|
||||
)
|
||||
|
||||
eval_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"])
|
||||
eval_dl_params.update(jasper_params["AudioToTextDataLayer"]["eval"])
|
||||
if args.remote_data:
|
||||
eval_dl_params["rpyc_host"] = args.remote_data
|
||||
del eval_dl_params["train"]
|
||||
del eval_dl_params["eval"]
|
||||
data_layers_eval = []
|
||||
|
||||
if args.eval_datasets:
|
||||
for eval_datasets in args.eval_datasets:
|
||||
data_layer_eval = data_loader_layer(
|
||||
manifest_filepath=eval_datasets,
|
||||
sample_rate=sample_rate,
|
||||
labels=vocab,
|
||||
batch_size=args.eval_batch_size,
|
||||
num_workers=cpu_per_traindl,
|
||||
**eval_dl_params,
|
||||
)
|
||||
|
||||
data_layers_eval.append(data_layer_eval)
|
||||
else:
|
||||
logging.warning("There were no val datasets passed")
|
||||
|
||||
jasper_encoder = nemo_asr.JasperEncoder(
|
||||
feat_in=jasper_params["AudioToMelSpectrogramPreprocessor"]["features"],
|
||||
**jasper_params["JasperEncoder"],
|
||||
)
|
||||
|
||||
jasper_decoder = nemo_asr.JasperDecoderForCTC(
|
||||
feat_in=jasper_params["JasperEncoder"]["jasper"][-1]["filters"],
|
||||
num_classes=len(vocab),
|
||||
)
|
||||
|
||||
ctc_loss = nemo_asr.CTCLossNM(num_classes=len(vocab))
|
||||
|
||||
greedy_decoder = nemo_asr.GreedyCTCDecoder()
|
||||
|
||||
logging.info("================================")
|
||||
logging.info(
|
||||
f"Number of parameters in encoder: {jasper_encoder.num_weights}"
|
||||
)
|
||||
logging.info(
|
||||
f"Number of parameters in decoder: {jasper_decoder.num_weights}"
|
||||
)
|
||||
logging.info(
|
||||
f"Total number of parameters in model: "
|
||||
f"{jasper_decoder.num_weights + jasper_encoder.num_weights}"
|
||||
)
|
||||
logging.info("================================")
|
||||
|
||||
# Train DAG
|
||||
(
|
||||
audio_signal_t,
|
||||
a_sig_length_t,
|
||||
transcript_t,
|
||||
transcript_len_t,
|
||||
) = data_layer()
|
||||
processed_signal_t, p_length_t = data_preprocessor(
|
||||
input_signal=audio_signal_t, length=a_sig_length_t
|
||||
)
|
||||
|
||||
if multiply_batch_config:
|
||||
(
|
||||
processed_signal_t,
|
||||
p_length_t,
|
||||
transcript_t,
|
||||
transcript_len_t,
|
||||
) = multiply_batch(
|
||||
in_x=processed_signal_t,
|
||||
in_x_len=p_length_t,
|
||||
in_y=transcript_t,
|
||||
in_y_len=transcript_len_t,
|
||||
)
|
||||
|
||||
if spectr_augment_config:
|
||||
processed_signal_t = data_spectr_augmentation(
|
||||
input_spec=processed_signal_t
|
||||
)
|
||||
|
||||
encoded_t, encoded_len_t = jasper_encoder(
|
||||
audio_signal=processed_signal_t, length=p_length_t
|
||||
)
|
||||
log_probs_t = jasper_decoder(encoder_output=encoded_t)
|
||||
predictions_t = greedy_decoder(log_probs=log_probs_t)
|
||||
loss_t = ctc_loss(
|
||||
log_probs=log_probs_t,
|
||||
targets=transcript_t,
|
||||
input_length=encoded_len_t,
|
||||
target_length=transcript_len_t,
|
||||
)
|
||||
|
||||
# Callbacks needed to print info to console and Tensorboard
|
||||
train_callback = nemo.core.SimpleLossLoggerCallback(
|
||||
tensors=[loss_t, predictions_t, transcript_t, transcript_len_t],
|
||||
print_func=partial(monitor_asr_train_progress, labels=vocab),
|
||||
get_tb_values=lambda x: [("loss", x[0])],
|
||||
tb_writer=neural_factory.tb_writer,
|
||||
)
|
||||
|
||||
chpt_callback = nemo.core.CheckpointCallback(
|
||||
folder=neural_factory.checkpoint_dir,
|
||||
load_from_folder=args.load_dir,
|
||||
step_freq=args.checkpoint_save_freq,
|
||||
checkpoints_to_keep=30,
|
||||
)
|
||||
|
||||
callbacks = [train_callback, chpt_callback]
|
||||
|
||||
# assemble eval DAGs
|
||||
for i, eval_dl in enumerate(data_layers_eval):
|
||||
(
|
||||
audio_signal_e,
|
||||
a_sig_length_e,
|
||||
transcript_e,
|
||||
transcript_len_e,
|
||||
) = eval_dl()
|
||||
processed_signal_e, p_length_e = data_preprocessor(
|
||||
input_signal=audio_signal_e, length=a_sig_length_e
|
||||
)
|
||||
encoded_e, encoded_len_e = jasper_encoder(
|
||||
audio_signal=processed_signal_e, length=p_length_e
|
||||
)
|
||||
log_probs_e = jasper_decoder(encoder_output=encoded_e)
|
||||
predictions_e = greedy_decoder(log_probs=log_probs_e)
|
||||
loss_e = ctc_loss(
|
||||
log_probs=log_probs_e,
|
||||
targets=transcript_e,
|
||||
input_length=encoded_len_e,
|
||||
target_length=transcript_len_e,
|
||||
)
|
||||
|
||||
# create corresponding eval callback
|
||||
tagname = os.path.basename(args.eval_datasets[i]).split(".")[0]
|
||||
eval_callback = nemo.core.EvaluatorCallback(
|
||||
eval_tensors=[
|
||||
loss_e,
|
||||
predictions_e,
|
||||
transcript_e,
|
||||
transcript_len_e,
|
||||
],
|
||||
user_iter_callback=partial(process_evaluation_batch, labels=vocab),
|
||||
user_epochs_done_callback=partial(
|
||||
process_evaluation_epoch, tag=tagname
|
||||
),
|
||||
eval_step=args.eval_freq,
|
||||
tb_writer=neural_factory.tb_writer,
|
||||
)
|
||||
|
||||
callbacks.append(eval_callback)
|
||||
return loss_t, callbacks, steps_per_epoch
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
name = construct_name(
|
||||
args.exp_name,
|
||||
args.lr,
|
||||
args.batch_size,
|
||||
args.max_steps,
|
||||
args.num_epochs,
|
||||
args.weight_decay,
|
||||
args.optimizer,
|
||||
args.iter_per_step,
|
||||
)
|
||||
log_dir = name
|
||||
if args.work_dir:
|
||||
log_dir = os.path.join(args.work_dir, name)
|
||||
|
||||
# instantiate Neural Factory with supported backend
|
||||
neural_factory = nemo.core.NeuralModuleFactory(
|
||||
backend=nemo.core.Backend.PyTorch,
|
||||
local_rank=args.local_rank,
|
||||
optimization_level=args.amp_opt_level,
|
||||
log_dir=log_dir,
|
||||
checkpoint_dir=args.checkpoint_dir,
|
||||
create_tb_writer=args.create_tb_writer,
|
||||
files_to_copy=[args.model_config, __file__],
|
||||
cudnn_benchmark=args.cudnn_benchmark,
|
||||
tensorboard_dir=args.tensorboard_dir,
|
||||
)
|
||||
args.num_gpus = neural_factory.world_size
|
||||
|
||||
checkpoint_dir = neural_factory.checkpoint_dir
|
||||
if args.local_rank is not None:
|
||||
logging.info("Doing ALL GPU")
|
||||
|
||||
# build dags
|
||||
train_loss, callbacks, steps_per_epoch = create_all_dags(
|
||||
args, neural_factory
|
||||
)
|
||||
# train model
|
||||
neural_factory.train(
|
||||
tensors_to_optimize=[train_loss],
|
||||
callbacks=callbacks,
|
||||
lr_policy=CosineAnnealing(
|
||||
args.max_steps
|
||||
if args.max_steps is not None
|
||||
else args.num_epochs * steps_per_epoch,
|
||||
warmup_steps=args.warmup_steps,
|
||||
),
|
||||
optimizer=args.optimizer,
|
||||
optimization_params={
|
||||
"num_epochs": args.num_epochs,
|
||||
"max_steps": args.max_steps,
|
||||
"lr": args.lr,
|
||||
"betas": (args.beta1, args.beta2),
|
||||
"weight_decay": args.weight_decay,
|
||||
"grad_norm_clip": None,
|
||||
},
|
||||
batches_per_step=args.iter_per_step,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@@ -0,0 +1,22 @@
import numpy as np
import os
import time
import copy

from omegaconf import OmegaConf
import matplotlib.pyplot as plt
import IPython.display as ipd
# import pyaudio as pa
import librosa
import nemo
import nemo.collections.asr as nemo_asr

# sample rate, Hz
SAMPLE_RATE = 16000

vad_model = nemo_asr.models.EncDecClassificationModel.from_pretrained(
    "vad_marblenet"
)
# Preserve a copy of the full config
cfg = copy.deepcopy(vad_model._cfg)
# print(OmegaConf.to_yaml(cfg))
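
# Hedged continuation sketch (added; it mirrors the NeMo streaming-VAD demo this
# snippet appears to follow and is not part of the original file):
#
#   vad_model.preprocessor = vad_model.from_config_dict(cfg.preprocessor)
#   vad_model.eval()
#   vad_model = vad_model.to(vad_model.device)
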
@@ -0,0 +1,234 @@
from pathlib import Path
|
||||
from collections import Counter
|
||||
import shutil
|
||||
import io
|
||||
|
||||
# from time import time
|
||||
|
||||
# import pydub
|
||||
import typer
|
||||
from tqdm import tqdm
|
||||
|
||||
from plume.utils import (
|
||||
ExtendedPath,
|
||||
replace_redundant_spaces_with,
|
||||
lazy_module,
|
||||
random_segs,
|
||||
parallel_apply,
|
||||
batch,
|
||||
run_shell,
|
||||
)
|
||||
|
||||
from plume.utils.vad import VADUtterance
|
||||
|
||||
soundfile = lazy_module("soundfile")
|
||||
pydub = lazy_module("pydub")
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
|
||||
def export_jasper(src_dataset_path: Path, dest_dataset_path: Path, unlink: bool = True):
|
||||
dict_ltr = dest_dataset_path / Path("dict.ltr.txt")
|
||||
(dest_dataset_path / Path("wavs")).mkdir(exist_ok=True, parents=True)
|
||||
tok_counter = Counter()
|
||||
shutil.copy(
|
||||
src_dataset_path / Path("test_manifest.json"),
|
||||
src_dataset_path / Path("valid_manifest.json"),
|
||||
)
|
||||
if unlink:
|
||||
src_wavs = src_dataset_path / Path("wavs")
|
||||
for wav_path in tqdm(list(src_wavs.glob("**/*.wav"))):
|
||||
audio_seg = (
|
||||
pydub.AudioSegment.from_wav(wav_path)
|
||||
.set_frame_rate(16000)
|
||||
.set_channels(1)
|
||||
)
|
||||
dest_path = dest_dataset_path / Path("wavs") / Path(wav_path.name)
|
||||
audio_seg.export(dest_path, format="wav")
|
||||
|
||||
for dataset_kind in ["train", "valid"]:
|
||||
abs_manifest_path = ExtendedPath(
|
||||
src_dataset_path / Path(f"{dataset_kind}_manifest.json")
|
||||
)
|
||||
manifest_data = list(abs_manifest_path.read_jsonl())
|
||||
o_tsv, o_ltr = f"{dataset_kind}.tsv", f"{dataset_kind}.ltr"
|
||||
out_tsv = dest_dataset_path / Path(o_tsv)
|
||||
out_ltr = dest_dataset_path / Path(o_ltr)
|
||||
with out_tsv.open("w") as tsv_f, out_ltr.open("w") as ltr_f:
|
||||
if unlink:
|
||||
tsv_f.write(f"{dest_dataset_path}\n")
|
||||
else:
|
||||
tsv_f.write(f"{src_dataset_path}\n")
|
||||
for md in manifest_data:
|
||||
audio_fname = md["audio_filepath"]
|
||||
pipe_toks = replace_redundant_spaces_with(md["text"], "|").upper()
|
||||
# pipe_toks = "|".join(re.sub(" ", "", md["text"]))
|
||||
# pipe_toks = alnum_to_asr_tokens(md["text"]).upper().replace(" ", "|")
|
||||
tok_counter.update(pipe_toks)
|
||||
letter_toks = " ".join(pipe_toks) + " |\n"
|
||||
frame_count = soundfile.info(audio_fname).frames
|
||||
rel_path = Path(audio_fname).relative_to(src_dataset_path.absolute())
|
||||
ltr_f.write(letter_toks)
|
||||
tsv_f.write(f"{rel_path}\t{frame_count}\n")
|
||||
with dict_ltr.open("w") as d_f:
|
||||
for k, v in tok_counter.most_common():
|
||||
d_f.write(f"{k} {v}\n")
|
||||
(src_dataset_path / Path("valid_manifest.json")).unlink()
|
||||
|
||||
|
||||
@app.command()
|
||||
def set_root(dataset_path: Path, root_path: Path):
|
||||
for dataset_kind in ["train", "valid"]:
|
||||
data_file = dataset_path / Path(dataset_kind).with_suffix(".tsv")
|
||||
with data_file.open("r") as df:
|
||||
lines = df.readlines()
|
||||
with data_file.open("w") as df:
|
||||
lines[0] = str(root_path) + "\n"
|
||||
df.writelines(lines)
|
||||
|
||||
|
||||
@app.command()
|
||||
def convert_audio(log_dir: Path, out_dir: Path):
|
||||
out_dir.mkdir(exist_ok=True, parents=True)
|
||||
all_wavs = list((log_dir).glob("**/*.wav"))
|
||||
name_wav_map = {i.name: i.absolute() for i in all_wavs}
|
||||
exists_wavs = list((out_dir).glob("**/*.wav"))
|
||||
rem_wavs = list(
|
||||
set((i.name for i in all_wavs)) - set((i.name for i in exists_wavs))
|
||||
)
|
||||
rem_wavs_real = [name_wav_map[i] for i in rem_wavs]
|
||||
|
||||
def resample_audio(i):
|
||||
dest_wav = out_dir / i.name
|
||||
if dest_wav.exists():
|
||||
return
|
||||
run_shell(f"ffmpeg -i {i.absolute()} -ac 1 -ar 16000 {dest_wav}", verbose=False)
|
||||
|
||||
parallel_apply(resample_audio, rem_wavs_real, workers=256)
|
||||
|
||||
|
||||
@app.command()
|
||||
def prepare_pretraining(
|
||||
log_dir: Path,
|
||||
dataset_path: Path,
|
||||
format: str = "wav",
|
||||
method: str = "random",
|
||||
max_silence: int = 3000,
|
||||
min_duration: int = 10000,
|
||||
max_duration: int = 30000,
|
||||
fixed_duration: int = 30000,
|
||||
batch_size: int = 100,
|
||||
):
|
||||
audio_dir = dataset_path / "audio"
|
||||
audio_dir.mkdir(exist_ok=True, parents=True)
|
||||
cache_dir = dataset_path / "cache"
|
||||
cache_dir.mkdir(exist_ok=True, parents=True)
|
||||
all_wavs = list((log_dir).glob("**/*.wav"))
|
||||
if method not in ["vad", "random", "fixed"]:
|
||||
typer.echo("should be one of random|fixed")
|
||||
raise typer.Exit()
|
||||
|
||||
def write_seg_arg(arg):
|
||||
seg, dest_wav = arg
|
||||
ob = io.BytesIO()
|
||||
seg.export(ob, format=format)
|
||||
dest_wav.write_bytes(ob.getvalue())
|
||||
ob.close()
|
||||
|
||||
with (dataset_path / "failed.log").open("w") as fl:
|
||||
vad_utt = VADUtterance(
|
||||
max_silence=max_silence,
|
||||
min_utterance=min_duration,
|
||||
max_utterance=max_duration,
|
||||
)
|
||||
|
||||
def vad_process_wav(wav_path):
|
||||
if (cache_dir / wav_path.stem).exists():
|
||||
return []
|
||||
try:
|
||||
aud_seg = pydub.AudioSegment.from_file(wav_path)
|
||||
except pydub.exceptions.CouldntDecodeError:
|
||||
fl.write(wav_path.name + "\n")
|
||||
return []
|
||||
full_seg = aud_seg
|
||||
# segs = random_segs(len(full_seg), min_duration, max_duration)
|
||||
segs = vad_utt.stream_segments(full_seg)
|
||||
audio_chunk_paths = []
|
||||
if len(full_seg) > min_duration:
|
||||
for (i, chunk_seg) in enumerate(segs):
|
||||
dest_wav = audio_dir / (wav_path.stem + f"_{i}.{format}")
|
||||
if dest_wav.exists():
|
||||
continue
|
||||
audio_chunk_paths.append((chunk_seg, dest_wav))
|
||||
(cache_dir / wav_path.stem).touch()
|
||||
return audio_chunk_paths
|
||||
|
||||
def random_process_wav(wav_path):
|
||||
if (cache_dir / wav_path.stem).exists():
|
||||
return []
|
||||
try:
|
||||
aud_seg = pydub.AudioSegment.from_file(wav_path)
|
||||
except pydub.exceptions.CouldntDecodeError:
|
||||
fl.write(wav_path.name + "\n")
|
||||
return []
|
||||
full_seg = aud_seg
|
||||
segs = random_segs(len(full_seg), min_duration, max_duration)
|
||||
audio_chunk_paths = []
|
||||
if len(full_seg) > min_duration:
|
||||
for (i, (start, end)) in enumerate(segs):
|
||||
dest_wav = audio_dir / (wav_path.stem + f"_{i}.{format}")
|
||||
if dest_wav.exists():
|
||||
continue
|
||||
chunk_seg = aud_seg[start:end]
|
||||
audio_chunk_paths.append((chunk_seg, dest_wav))
|
||||
(cache_dir / wav_path.stem).touch()
|
||||
return audio_chunk_paths
|
||||
|
||||
def fixed_process_wav(wav_path):
|
||||
if (cache_dir / wav_path.stem).exists():
|
||||
return []
|
||||
try:
|
||||
aud_seg = pydub.AudioSegment.from_file(wav_path)
|
||||
except pydub.exceptions.CouldntDecodeError:
|
||||
fl.write(wav_path.name + "\n")
|
||||
return []
|
||||
full_seg = aud_seg
|
||||
audio_chunk_paths = []
|
||||
if len(full_seg) > min_duration:
|
||||
for (i, chunk_seg) in enumerate(full_seg[::fixed_duration]):
|
||||
dest_wav = audio_dir / (wav_path.stem + f"_{i}.{format}")
|
||||
if dest_wav.exists() or len(chunk_seg) < min_duration:
|
||||
continue
|
||||
audio_chunk_paths.append((chunk_seg, dest_wav))
|
||||
(cache_dir / wav_path.stem).touch()
|
||||
return audio_chunk_paths
|
||||
|
||||
# warmup
|
||||
pydub.AudioSegment.from_file(all_wavs[0])
|
||||
# parallel_apply(process_wav, all_wavs, pool='process')
|
||||
# parallel_apply(process_wav, all_wavs)
|
||||
seg_f = (
|
||||
vad_process_wav
|
||||
if method == "vad"
|
||||
else (random_process_wav if method == "random" else fixed_process_wav)
|
||||
)
|
||||
for wp_batch in tqdm(batch(all_wavs, n=batch_size)):
|
||||
acp_batch = parallel_apply(seg_f, wp_batch)
|
||||
# acp_batch = list(map(seg_f, tqdm(wp_batch)))
|
||||
flat_acp_batch = [sd for acp in acp_batch for sd in acp]
|
||||
parallel_apply(write_seg_arg, flat_acp_batch)
|
||||
# for acp in acp_batch:
|
||||
# for (seg, des) in acp:
|
||||
# seg.export(des)
|
||||
# for seg_des in tqdm(flat_acp_batch):
|
||||
# write_seg_arg(seg_des)
|
||||
del flat_acp_batch
|
||||
del acp_batch
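
# Hedged example invocation (added; the module path and directories are
# placeholders). Durations are in milliseconds, following pydub:
#
#   python -m <this cli module> prepare-pretraining /call_logs /datasets/pretrain \
#       --method vad --min-duration 10000 --max-duration 30000
#
# "vad" cuts on detected speech boundaries, "random" cuts random-length chunks,
# and "fixed" cuts windows of fixed_duration.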
|
||||
|
||||
|
||||
def main():
|
||||
app()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
@@ -0,0 +1,39 @@
from transformers import Wav2Vec2Tokenizer, Wav2Vec2ForCTC

# import soundfile as sf
from io import BytesIO
import torch

from plume.utils import lazy_module

sf = lazy_module("soundfile")


class Wav2Vec2TransformersASR(object):
    """docstring for Wav2Vec2TransformersASR."""

    def __init__(self, ctc_path, w2v_path, target_dict_path):
        super(Wav2Vec2TransformersASR, self).__init__()
        self.tokenizer = Wav2Vec2Tokenizer.from_pretrained(
            "facebook/wav2vec2-large-960h-lv60-self"
        )
        self.model = Wav2Vec2ForCTC.from_pretrained(
            "facebook/wav2vec2-large-960h-lv60-self"
        )

    def transcribe(self, audio_data):
        aud_f = BytesIO(audio_data)
        # net_input = {}
        speech_data, _ = sf.read(aud_f)
        input_values = self.tokenizer(
            speech_data, return_tensors="pt", padding="longest"
        ).input_values  # Batch size 1

        # retrieve logits
        logits = self.model(input_values).logits

        # take argmax and decode
        predicted_ids = torch.argmax(logits, dim=-1)

        transcription = self.tokenizer.batch_decode(predicted_ids)[0]
        return transcription
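
# Hedged usage sketch (added; the file paths are placeholders and, as the code
# above shows, the constructor arguments are currently unused because weights
# come from the Hugging Face hub). audio_data is expected to be the bytes of a
# 16 kHz mono WAV, matching the pretrained model's sample rate:
#
#   from pathlib import Path
#   asr = Wav2Vec2TransformersASR("ctc.pt", "base.pt", "dict.ltr.txt")
#   text = asr.transcribe(Path("utterance.wav").read_bytes())
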
@@ -2,7 +2,6 @@ from pathlib import Path
from collections import Counter
|
||||
import shutil
|
||||
|
||||
import soundfile
|
||||
# import pydub
|
||||
import typer
|
||||
from tqdm import tqdm
|
||||
|
|
@@ -12,8 +11,8 @@ from plume.utils import (
replace_redundant_spaces_with,
|
||||
lazy_module
|
||||
)
|
||||
soundfile = lazy_module('soundfile')
|
||||
pydub = lazy_module('pydub')
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
|
|
@@ -0,0 +1,52 @@
from pathlib import Path
|
||||
import typer
|
||||
from tqdm import tqdm
|
||||
# import pandas as pd
|
||||
|
||||
from plume.utils import (
|
||||
asr_manifest_reader,
|
||||
discard_except_digits,
|
||||
replace_digit_symbol,
|
||||
lazy_module
|
||||
# run_shell,
|
||||
)
|
||||
from ...utils.transcribe import triton_transcribe_grpc_gen, transcribe_rpyc_gen
|
||||
|
||||
pd = lazy_module('pandas')
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
|
||||
def manifest(manifest_file: Path, result_file: Path = "results.csv", rpyc: bool = False):
|
||||
from pydub import AudioSegment
|
||||
|
||||
host = "localhost"
|
||||
port = 8044
|
||||
if rpyc:
|
||||
transcriber, audio_prep = transcribe_rpyc_gen(host, port)
|
||||
else:
|
||||
transcriber, audio_prep = triton_transcribe_grpc_gen(host, port, method='whole')
|
||||
result_path = manifest_file.parent / result_file
|
||||
manifest_list = list(asr_manifest_reader(manifest_file))
|
||||
|
||||
def compute_frame(d):
|
||||
audio_file = d["audio_path"]
|
||||
orig_text = d["text"]
|
||||
orig_num = discard_except_digits(replace_digit_symbol(orig_text))
|
||||
aud_seg = AudioSegment.from_file(audio_file)
|
||||
t_audio = audio_prep(aud_seg)
|
||||
asr_text = transcriber(t_audio)
|
||||
asr_num = discard_except_digits(replace_digit_symbol(asr_text))
|
||||
return {
|
||||
"audio_file": audio_file,
|
||||
"asr_text": asr_text,
|
||||
"asr_num": asr_num,
|
||||
"orig_text": orig_text,
|
||||
"orig_num": orig_num,
|
||||
"asr_match": orig_num == asr_num,
|
||||
}
|
||||
|
||||
# df_data = parallel_apply(compute_frame, manifest_list)
|
||||
df_data = map(compute_frame, tqdm(manifest_list))
|
||||
df = pd.DataFrame(df_data)
|
||||
df.to_csv(result_path)
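
# Hedged example invocation (added; the module path is a placeholder). Without
# --rpyc the audio is sent to a Triton gRPC endpoint on localhost:8044 via
# triton_transcribe_grpc_gen; with --rpyc it goes to the rpyc ASR service:
#
#   python -m plume.models.<model>.eval manifest /data/test_manifest.json \
#       --result-file results.csv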
|
||||
|
|
@@ -0,0 +1,52 @@
import os
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
# from rpyc.utils.server import ThreadedServer
|
||||
import typer
|
||||
|
||||
from ...utils.serve import ASRService
|
||||
from plume.utils import lazy_callable
|
||||
# from plume.models.wav2vec2_transformers.asr import Wav2Vec2TransformersASR
|
||||
# from .asr import Wav2Vec2ASR
|
||||
|
||||
ThreadedServer = lazy_callable("rpyc.utils.server.ThreadedServer")
|
||||
Wav2Vec2TransformersASR = lazy_callable(
|
||||
"plume.models.wav2vec2_transformers.asr.Wav2Vec2TransformersASR"
|
||||
)
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
|
||||
def rpyc(
|
||||
w2v_path: Path = "/path/to/base.pt",
|
||||
ctc_path: Path = "/path/to/ctc.pt",
|
||||
target_dict_path: Path = "/path/to/dict.ltr.txt",
|
||||
port: int = int(os.environ.get("ASR_RPYC_PORT", "8044")),
|
||||
):
|
||||
w2vasr = Wav2Vec2TransformersASR(ctc_path, w2v_path, target_dict_path)
|
||||
service = ASRService(w2vasr)
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
)
|
||||
logging.info("starting asr server...")
|
||||
t = ThreadedServer(service, port=port)
|
||||
t.start()
|
||||
|
||||
|
||||
@app.command()
|
||||
def rpyc_dir(model_dir: Path, port: int = int(os.environ.get("ASR_RPYC_PORT", "8044"))):
|
||||
ctc_path = model_dir / Path("ctc.pt")
|
||||
w2v_path = model_dir / Path("base.pt")
|
||||
target_dict_path = model_dir / Path("dict.ltr.txt")
|
||||
rpyc(w2v_path, ctc_path, target_dict_path, port)
|
||||
|
||||
|
||||
def main():
|
||||
app()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@@ -0,0 +1,41 @@
from transformers import Wav2Vec2Tokenizer, Wav2Vec2ForCTC
|
||||
from datasets import load_dataset
|
||||
import soundfile as sf
|
||||
import torch
|
||||
|
||||
# load model and tokenizer
|
||||
tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-large-960h-lv60-self")
|
||||
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60-self")
|
||||
|
||||
|
||||
# define function to read in sound file
|
||||
def map_to_array(batch):
|
||||
speech, _ = sf.read(batch["file"])
|
||||
batch["speech"] = speech
|
||||
return batch
|
||||
|
||||
|
||||
# load dummy dataset and read soundfiles
|
||||
def main():
|
||||
ds = load_dataset(
|
||||
"patrickvonplaten/librispeech_asr_dummy", "clean", split="validation"
|
||||
)
|
||||
ds = ds.map(map_to_array)
|
||||
|
||||
# tokenize
|
||||
input_values = tokenizer(
|
||||
ds["speech"][:2], return_tensors="pt", padding="longest"
|
||||
).input_values # Batch size 1
|
||||
|
||||
# retrieve logits
|
||||
logits = model(input_values).logits
|
||||
|
||||
# take argmax and decode
|
||||
predicted_ids = torch.argmax(logits, dim=-1)
|
||||
|
||||
transcription = tokenizer.batch_decode(predicted_ids)
|
||||
print(transcription)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@@ -0,0 +1,34 @@
import typer
|
||||
# from fairseq_cli.train import cli_main
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import shlex
|
||||
from plume.utils import lazy_callable
|
||||
|
||||
cli_main = lazy_callable('fairseq_cli.train.cli_main')
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
|
||||
def local(dataset_path: Path):
|
||||
args = f'''--distributed-world-size 1 {dataset_path} \
|
||||
--save-dir /dataset/wav2vec2/model/wav2vec2_l_num_ctc_v1 --post-process letter --valid-subset \
|
||||
valid --no-epoch-checkpoints --best-checkpoint-metric wer --num-workers 4 --max-update 80000 \
|
||||
--sentence-avg --task audio_pretraining --arch wav2vec_ctc --w2v-path /dataset/wav2vec2/pretrained/wav2vec_vox_new.pt \
|
||||
--labels ltr --apply-mask --mask-selection static --mask-other 0 --mask-length 10 --mask-prob 0.5 --layerdrop 0.1 \
|
||||
--mask-channel-selection static --mask-channel-other 0 --mask-channel-length 64 --mask-channel-prob 0.5 \
|
||||
--zero-infinity --feature-grad-mult 0.0 --freeze-finetune-updates 10000 --validate-after-updates 10000 \
|
||||
--optimizer adam --adam-betas '(0.9, 0.98)' --adam-eps 1e-08 --lr 2e-05 --lr-scheduler tri_stage --warmup-steps 8000 \
|
||||
--hold-steps 32000 --decay-steps 40000 --final-lr-scale 0.05 --final-dropout 0.0 --dropout 0.0 \
|
||||
--activation-dropout 0.1 --criterion ctc --attention-dropout 0.0 --max-tokens 1280000 --seed 2337 --log-format json \
|
||||
--log-interval 500 --ddp-backend no_c10d --reset-optimizer --normalize
|
||||
'''
|
||||
new_args = ['train.py']
|
||||
new_args.extend(shlex.split(args))
|
||||
sys.argv = new_args
|
||||
cli_main()
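
# Added commentary (not in the original): the argument string above follows
# fairseq's wav2vec 2.0 CTC fine-tuning recipe; sys.argv is rewritten so that
# fairseq_cli.train.cli_main() parses it as if it had been launched as
# "train.py <args>". The --save-dir and --w2v-path values are specific to this
# environment and would need to change for another setup.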
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli_main()
|
||||
|
|
@@ -46,9 +46,40 @@ def annotation(data_dir: Path, dump_fname: Path = "ui_dump.json", task_id: str =
|
||||
|
||||
@app.command()
|
||||
def preview(manifest_path: Path):
|
||||
def preview(manifest_path: Path, port: int = 8081):
|
||||
annotation_lit_path = Path(__file__).parent / Path("preview.py")
|
||||
sys.argv = ["streamlit", "run", str(annotation_lit_path), "--", str(manifest_path)]
|
||||
sys.argv = [
|
||||
"streamlit",
|
||||
"run",
|
||||
"--server.port",
|
||||
str(port),
|
||||
str(annotation_lit_path),
|
||||
"--",
|
||||
str(manifest_path),
|
||||
]
|
||||
sys.exit(stcli.main())
|
||||
|
||||
|
||||
@app.command()
|
||||
def encrypted_preview(manifest_path: Path, key: str, port: int = 8081):
|
||||
lit_path = Path(__file__).parent / Path("encrypted_preview.py")
|
||||
sys.argv = [
|
||||
"streamlit",
|
||||
"run",
|
||||
"--server.port",
|
||||
str(port),
|
||||
str(lit_path),
|
||||
"--",
|
||||
str(manifest_path),
|
||||
str(key),
|
||||
]
|
||||
sys.exit(stcli.main())
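
# Hedged example invocation (added; the module path and key are placeholders).
# The key must be the same one used by the plume.utils.encrypt helpers when the
# manifest's audio and text were encrypted:
#
#   python -m plume.utils.<ui module> encrypted-preview /data/manifest.json "$KEY" \
#       --port 8081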
|
||||
|
||||
|
||||
@app.command()
|
||||
def audio(audio_dir: Path):
|
||||
lit_path = Path(__file__).parent / Path("audio.py")
|
||||
sys.argv = ["streamlit", "run", str(lit_path), "--", str(audio_dir)]
|
||||
sys.exit(stcli.main())
|
||||
|
||||
|
||||
|
|
@@ -13,9 +13,9 @@ setup_mongo_asr_validation_state(st)
|
||||
@st.cache()
|
||||
def load_ui_data(data_dir: Path, dump_fname: Path):
|
||||
validation_ui_data_path = data_dir / dump_fname
|
||||
typer.echo(f"Using validation ui data from {validation_ui_data_path}")
|
||||
return ExtendedPath(validation_ui_data_path).read_json()
|
||||
annotation_ui_data_path = data_dir / dump_fname
|
||||
typer.echo(f"Using annotation ui data from {annotation_ui_data_path}")
|
||||
return ExtendedPath(annotation_ui_data_path).read_json()
|
||||
|
||||
|
||||
def show_key(sample, key, trail=""):
@@ -0,0 +1,21 @@
from pathlib import Path

import streamlit as st
import typer

app = typer.Typer()


@app.command()
def main(wav_dir: Path):
    wav_file = list(wav_dir.glob('**/*.wav'))[0]
    st.title("Audio Preview")
    print(wav_file.exists())
    st.audio(str(wav_file))


if __name__ == "__main__":
    try:
        app()
    except SystemExit:
        pass

@@ -0,0 +1,46 @@
from pathlib import Path
|
||||
|
||||
import streamlit as st
|
||||
import typer
|
||||
from plume.utils import ExtendedPath, wav_cryptor, text_cryptor
|
||||
from plume.utils.ui_persist import setup_file_state
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
setup_file_state(st)
|
||||
|
||||
|
||||
@st.cache()
|
||||
def load_ui_data(validation_ui_data_path: Path):
|
||||
typer.echo(f"Using validation ui data from {validation_ui_data_path}")
|
||||
return list(ExtendedPath(validation_ui_data_path).read_jsonl())
|
||||
|
||||
|
||||
@app.command()
|
||||
def main(manifest: Path, key: str):
|
||||
wc = wav_cryptor(key)
|
||||
tc = text_cryptor(key)
|
||||
asr_data = load_ui_data(manifest)
|
||||
sample_no = st.get_current_cursor()
|
||||
if len(asr_data) - 1 < sample_no or sample_no < 0:
|
||||
print("Invalid samplno resetting to 0")
|
||||
st.update_cursor(0)
|
||||
sample = asr_data[sample_no]
|
||||
st.title("ASR Manifest Preview")
|
||||
gt_text = tc.decrypt_text(sample["text"].encode("utf-8"))
|
||||
st.markdown(f"{sample_no+1} of {len(asr_data)} : **{gt_text}**")
|
||||
new_sample = st.number_input(
|
||||
"Go To Sample:", value=sample_no + 1, min_value=1, max_value=len(asr_data)
|
||||
)
|
||||
if new_sample != sample_no + 1:
|
||||
st.update_cursor(new_sample - 1)
|
||||
st.sidebar.markdown(f"Gold Text: **{gt_text}**")
|
||||
wav = wc.decrypt_wav_path((manifest.parent / Path(sample["audio_filepath"])))
|
||||
st.audio(wav)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
app()
|
||||
except SystemExit:
|
||||
pass
|
||||
|
|
@@ -4,20 +4,20 @@ import re
import json
|
||||
import wave
|
||||
import logging
|
||||
import subprocess
|
||||
import shutil
|
||||
import random
|
||||
from pathlib import Path
|
||||
from functools import partial
|
||||
from uuid import uuid4
|
||||
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
|
||||
import subprocess
|
||||
import shutil
|
||||
from urllib.parse import urlsplit
|
||||
|
||||
# from .lazy_loader import LazyLoader
|
||||
from .lazy_import import lazy_callable, lazy_module
|
||||
|
||||
# from ruamel.yaml import YAML
|
||||
# import boto3
|
||||
import typer
|
||||
from tqdm import tqdm
|
||||
|
||||
# import pymongo
|
||||
# from slugify import slugify
|
||||
|
|
@@ -27,14 +27,29 @@ import typer
# import librosa.display as audio_display
|
||||
# from natural.date import compress
|
||||
# from num2words import num2words
|
||||
from tqdm import tqdm
|
||||
from datetime import timedelta
|
||||
import datetime
|
||||
import six
|
||||
|
||||
# from .transcribe import triton_transcribe_grpc_gen
|
||||
# from .eval import app as eval_app
|
||||
from .manifest import asr_manifest_writer, manifest_str
|
||||
from .lazy_import import lazy_callable, lazy_module
|
||||
from .parallel import parallel_apply
|
||||
from .extended_path import ExtendedPath
|
||||
from .tts import app as tts_app
|
||||
from .transcribe import app as transcribe_app
|
||||
from .align import app as align_app
|
||||
from .encrypt import app as encrypt_app, wav_cryptor, text_cryptor # noqa
|
||||
from .regentity import ( # noqa
|
||||
num_replacer,
|
||||
alnum_replacer,
|
||||
num_keeper,
|
||||
alnum_keeper,
|
||||
default_num_rules,
|
||||
default_num_only_rules,
|
||||
default_alnum_rules,
|
||||
entity_replacer_keeper,
|
||||
)
|
||||
|
||||
boto3 = lazy_module("boto3")
|
||||
pymongo = lazy_module("pymongo")
|
||||
|
|
@@ -45,9 +60,9 @@ librosa = lazy_module("librosa")
YAML = lazy_callable("ruamel.yaml.YAML")
|
||||
num2words = lazy_callable("num2words.num2words")
|
||||
slugify = lazy_callable("slugify.slugify")
|
||||
compress = lazy_callable("natural.date.compress")
|
||||
|
||||
app = typer.Typer()
|
||||
app.add_typer(encrypt_app)
|
||||
app.add_typer(tts_app, name="tts")
|
||||
app.add_typer(align_app, name="align")
|
||||
app.add_typer(transcribe_app, name="transcribe")
|
||||
|
|
@@ -60,31 +75,164 @@ def utils():
"""
|
||||
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
log_fmt_str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
logging.basicConfig(level=logging.INFO, format=log_fmt_str)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def manifest_str(path, dur, text):
|
||||
return (
|
||||
json.dumps({"audio_filepath": path, "duration": round(dur, 1), "text": text})
|
||||
+ "\n"
|
||||
)
|
||||
# Precalculated timestamps
|
||||
TIME_MINUTE = 60
|
||||
TIME_HOUR = 3600
|
||||
TIME_DAY = 86400
|
||||
TIME_WEEK = 604800
|
||||
|
||||
|
||||
def duration_str(seconds):
|
||||
return compress(timedelta(seconds=seconds), pad=" ")
|
||||
def compress(t, show_hours=False, sign=False, pad=""):
|
||||
"""
|
||||
Convert the input to compressed format, works with a
|
||||
:class:`datetime.timedelta` object or a number that represents the number
|
||||
of seconds you want to compress. If you supply a timestamp or a
|
||||
:class:`datetime.datetime` object, it will give the delta relative to the
|
||||
current time.
|
||||
You can enable showing a sign in front of the compressed format with the
|
||||
``sign`` parameter, the default is not to show signs.
|
||||
Optionally, you can chose to pad the output. If you wish your values to be
|
||||
separated by spaces, set ``pad`` to ``' '``.
|
||||
:param t: seconds or :class:`datetime.timedelta` object
|
||||
:param sign: default ``False``
|
||||
:param pad: default ``''``
|
||||
>>> print(compress(0))
|
||||
0s
|
||||
>>> print(compress(1))
|
||||
1s
|
||||
>>> print(compress(12))
|
||||
12s
|
||||
>>> print(compress(123))
|
||||
2m3s
|
||||
>>> print(compress(1234))
|
||||
20m34s
|
||||
>>> print(compress(12345))
|
||||
3h25m45s
|
||||
>>> print(compress(123456))
|
||||
1d10h17m36s
|
||||
==============
|
||||
src: https://github.com/tehmaze/natural/blob/master/natural/date.py
|
||||
"""
|
||||
|
||||
if isinstance(t, datetime.timedelta):
|
||||
seconds = t.seconds + (t.days * 86400)
|
||||
elif isinstance(t, six.integer_types + (float,)):
|
||||
return compress(datetime.timedelta(seconds=t), sign, pad)
|
||||
else:
|
||||
raise Exception("Invalid time format")
|
||||
|
||||
parts = []
|
||||
if sign:
|
||||
parts.append("-" if t.days < 0 else "+")
|
||||
|
||||
if not show_hours:
|
||||
weeks, seconds = divmod(seconds, TIME_WEEK)
|
||||
days, seconds = divmod(seconds, TIME_DAY)
|
||||
hours, seconds = divmod(seconds, TIME_HOUR)
|
||||
minutes, seconds = divmod(seconds, TIME_MINUTE)
|
||||
|
||||
if not show_hours:
|
||||
if weeks:
|
||||
parts.append(("%dw") % (weeks,))
|
||||
if days:
|
||||
parts.append(("%dd") % (days,))
|
||||
if hours:
|
||||
parts.append(("%dh") % (hours,))
|
||||
if minutes:
|
||||
parts.append(("%dm") % (minutes,))
|
||||
if seconds or len(parts) == 0:
|
||||
parts.append(("%ds") % (seconds,))
|
||||
|
||||
return pad.join(parts)
|
||||
|
||||
|
||||
def replace_digit_symbol(w2v_out):
|
||||
num_int_map = {num2words(i): str(i) for i in range(10)}
|
||||
def duration_str(seconds, show_hours=False):
|
||||
t = datetime.timedelta(seconds=seconds)
|
||||
return compress(t, show_hours=show_hours, pad=" ")
|
||||
|
||||
|
||||
def replace_digit_symbol(w2v_out, num_range=10):
|
||||
def rep_i(i):
|
||||
return (num2words(i).replace("-", " "), str(i))
|
||||
|
||||
num_int_map = [rep_i(i) for i in reversed(range(num_range))]
|
||||
out = w2v_out.lower()
|
||||
for (k, v) in num_int_map.items():
|
||||
for (k, v) in num_int_map:
|
||||
out = re.sub(k, v, out)
|
||||
return out
|
||||
|
||||
|
||||
def num_keeper_orig(num_range=10, extra_rules=[]):
|
||||
num_int_map_ty = [
|
||||
(
|
||||
r"\b" + num2words(i) + r"\b",
|
||||
" " + str(i) + " ",
|
||||
)
|
||||
for i in reversed(range(num_range))
|
||||
]
|
||||
re_rules = [
|
||||
(re.compile(k, re.IGNORECASE), v)
|
||||
for (k, v) in [
|
||||
# (r"[ ;,.]", " "),
|
||||
(r"\bdouble(?: |-)(\w+)\b", "\\1 \\1"),
|
||||
(r"\btriple(?: |-)(\w+)\b", "\\1 \\1 \\1"),
|
||||
(r"hundred", "00"),
|
||||
(r"\boh\b", " 0 "),
|
||||
(r"\bo\b", " 0 "),
|
||||
]
|
||||
+ num_int_map_ty
|
||||
] + [(re.compile(k), v) for (k, v) in extra_rules]
|
||||
|
||||
def merge_intervals(intervals):
|
||||
# https://codereview.stackexchange.com/a/69249
|
||||
sorted_by_lower_bound = sorted(intervals, key=lambda tup: tup[0])
|
||||
merged = []
|
||||
|
||||
for higher in sorted_by_lower_bound:
|
||||
if not merged:
|
||||
merged.append(higher)
|
||||
else:
|
||||
lower = merged[-1]
|
||||
# test for intersection between lower and higher:
|
||||
# we know via sorting that lower[0] <= higher[0]
|
||||
if higher[0] <= lower[1]:
|
||||
upper_bound = max(lower[1], higher[1])
|
||||
merged[-1] = (
|
||||
lower[0],
|
||||
upper_bound,
|
||||
) # replace by merged interval
|
||||
else:
|
||||
merged.append(higher)
|
||||
return merged
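
# Worked example (added for clarity, not in the original):
#   merge_intervals([(0, 3), (2, 5), (7, 9)]) -> [(0, 5), (7, 9)]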
|
||||
|
||||
# merging interval tree for optimal # https://www.geeksforgeeks.org/interval-tree/
|
||||
|
||||
def keep_numeric_literals(w2v_out):
|
||||
# out = w2v_out.lower()
|
||||
out = re.sub(r"[ ;,.]", " ", w2v_out).strip()
|
||||
# out = " " + out.strip() + " "
|
||||
# out = re.sub(r"double (\w+)", "\\1 \\1", out)
|
||||
# out = re.sub(r"triple (\w+)", "\\1 \\1 \\1", out)
|
||||
num_spans = []
|
||||
for (k, v) in re_rules: # [94:]:
|
||||
matches = k.finditer(out)
|
||||
for m in matches:
|
||||
# num_spans.append((k, m.span()))
|
||||
num_spans.append(m.span())
|
||||
# out = re.sub(k, v, out)
|
||||
merged = merge_intervals(num_spans)
|
||||
num_ents = len(merged)
|
||||
keep_out = " ".join((out[s[0] : s[1]] for s in merged))
|
||||
return keep_out, num_ents
|
||||
|
||||
return keep_numeric_literals
|
||||
|
||||
|
||||
def discard_except_digits(inp):
|
||||
return re.sub("[^0-9]", "", inp)
|
||||
|
||||
|
|
@@ -103,6 +251,26 @@ def space_out(text):
|
|||
return letters
|
||||
|
||||
|
||||
def random_segs(total, min_val, max_val):
|
||||
out_list = []
|
||||
rand_total = prev_start = 0
|
||||
while True:
|
||||
if total < rand_total + min_val or total < rand_total:
|
||||
break
|
||||
sample = random.randint(min_val, max_val)
|
||||
if total - rand_total < max_val:
|
||||
break
|
||||
if total - rand_total < max_val + min_val:
|
||||
sample = random.randint(min_val, max_val - min_val)
|
||||
prev_start = rand_total
|
||||
if 0 < rand_total + sample - total < max_val:
|
||||
break
|
||||
rand_total += sample
|
||||
out_list.append((prev_start, rand_total))
|
||||
out_list.append((rand_total, total))
|
||||
return out_list
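A short sketch of the intended contract of random_segs; the `_random_segs_example` name and the sample boundaries are illustrative, and the exact split changes on every call.
def _random_segs_example():
    segs = random_segs(100, 5, 20)  # e.g. [(0, 13), (13, 27), ..., (88, 100)]
    # consecutive segments share a boundary and together tile [0, total]
    assert segs[0][0] == 0 and segs[-1][1] == 100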
|
||||
|
||||
|
||||
def wav_bytes(audio_bytes, frame_rate=24000):
|
||||
wf_b = io.BytesIO()
|
||||
with wave.open(wf_b, mode="w") as wf:
|
||||
|
|
@@ -117,17 +285,20 @@ def tscript_uuid_fname(transcript):
|
|||
return str(uuid4()) + "_" + slugify(transcript, max_length=8)
|
||||
|
||||
|
||||
def run_shell(cmd_str, work_dir="."):
|
||||
def run_shell(cmd_str, work_dir=".", verbose=True):
|
||||
cwd_path = Path(work_dir).absolute()
|
||||
p = subprocess.Popen(
|
||||
cmd_str,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
shell=True,
|
||||
cwd=cwd_path,
|
||||
)
|
||||
for line in p.stdout:
|
||||
print(line.replace(b"\n", b"").decode("utf-8"))
|
||||
if verbose:
|
||||
with subprocess.Popen(
|
||||
cmd_str,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
shell=True,
|
||||
cwd=cwd_path,
|
||||
) as p:
|
||||
for line in p.stdout:
|
||||
print(line.replace(b"\n", b"").decode("utf-8"))
|
||||
else:
|
||||
subprocess.run(cmd_str, shell=True, cwd=cwd_path, capture_output=True)
|
||||
|
||||
|
||||
def upload_s3(dataset_path, s3_path):
|
||||
|
|
@@ -154,7 +325,8 @@ def s3_downloader():
|
|||
if not download_path.exists():
|
||||
if verbose:
|
||||
print(f"downloading {s3_uri} to {download_path}")
|
||||
s3.download_file(s3_uri_p.netloc, s3_uri_p.path[1:], str(download_path))
|
||||
dp_s = str(download_path)
|
||||
s3.download_file(s3_uri_p.netloc, s3_uri_p.path[1:], dp_s)
|
||||
|
||||
return download_s3
|
||||
|
||||
|
|
@@ -167,7 +339,8 @@ def asr_data_writer(dataset_dir, asr_data_source, verbose=False):
|
|||
print(f"writing manifest to {asr_manifest}")
|
||||
for transcript, audio_dur, wav_data in asr_data_source:
|
||||
fname = tscript_uuid_fname(transcript)
|
||||
audio_file = dataset_dir / Path("wavs") / Path(fname).with_suffix(".wav")
|
||||
wav_fname = Path(fname).with_suffix(".wav")
|
||||
audio_file = dataset_dir / Path("wavs") / wav_fname
|
||||
audio_file.write_bytes(wav_data)
|
||||
rel_data_path = audio_file.relative_to(dataset_dir)
|
||||
manifest = manifest_str(str(rel_data_path), audio_dur, transcript)
|
||||
|
|
@@ -211,7 +384,13 @@ def ui_data_generator(dataset_dir, asr_data_source, verbose=False):
|
|||
|
||||
num_datapoints = 0
|
||||
data_funcs = []
|
||||
for transcript, audio_dur, wav_data, caller_name, aud_seg in asr_data_source:
|
||||
for (
|
||||
transcript,
|
||||
audio_dur,
|
||||
wav_data,
|
||||
caller_name,
|
||||
aud_seg,
|
||||
) in asr_data_source:
|
||||
fname = str(uuid4()) + "_" + slugify(transcript, max_length=8)
|
||||
audio_file = (
|
||||
dataset_dir / Path("wavs") / Path(fname).with_suffix(".wav")
|
||||
|
|
@@ -269,17 +448,6 @@ def asr_manifest_reader(data_manifest_path: Path):
|
|||
yield p
|
||||
|
||||
|
||||
def asr_manifest_writer(asr_manifest_path: Path, manifest_str_source, verbose=False):
|
||||
with asr_manifest_path.open("w") as mf:
|
||||
if verbose:
|
||||
print(f"writing asr manifest to {asr_manifest_path}")
|
||||
for mani_dict in manifest_str_source:
|
||||
manifest = manifest_str(
|
||||
mani_dict["audio_filepath"], mani_dict["duration"], mani_dict["text"]
|
||||
)
|
||||
mf.write(manifest)
|
||||
|
||||
|
||||
def asr_test_writer(out_file_path: Path, source):
|
||||
def dd_str(dd, idx):
|
||||
path = dd["audio_filepath"]
|
||||
|
|
@@ -306,52 +474,6 @@ def batch(iterable, n=1):
|
|||
return [iterable[ndx : min(ndx + n, ls)] for ndx in range(0, ls, n)]
|
||||
|
||||
|
||||
class ExtendedPath(type(Path())):
|
||||
"""docstring for ExtendedPath."""
|
||||
|
||||
def read_json(self, verbose=False):
|
||||
if verbose:
|
||||
print(f"reading json from {self}")
|
||||
with self.open("r") as jf:
|
||||
return json.load(jf)
|
||||
|
||||
def read_yaml(self, verbose=False):
|
||||
yaml = YAML(typ="safe", pure=True)
|
||||
if verbose:
|
||||
print(f"reading yaml from {self}")
|
||||
with self.open("r") as yf:
|
||||
return yaml.load(yf)
|
||||
|
||||
def read_jsonl(self, verbose=False):
|
||||
if verbose:
|
||||
print(f"reading jsonl from {self}")
|
||||
with self.open("r") as jf:
|
||||
for ln in jf.readlines():
|
||||
yield json.loads(ln)
|
||||
|
||||
def write_json(self, data, verbose=False):
|
||||
if verbose:
|
||||
print(f"writing json to {self}")
|
||||
self.parent.mkdir(parents=True, exist_ok=True)
|
||||
with self.open("w") as jf:
|
||||
json.dump(data, jf, indent=2)
|
||||
|
||||
def write_yaml(self, data, verbose=False):
|
||||
yaml = YAML()
|
||||
if verbose:
|
||||
print(f"writing yaml to {self}")
|
||||
with self.open("w") as yf:
|
||||
yaml.dump(data, yf)
|
||||
|
||||
def write_jsonl(self, data, verbose=False):
|
||||
if verbose:
|
||||
print(f"writing jsonl to {self}")
|
||||
self.parent.mkdir(parents=True, exist_ok=True)
|
||||
with self.open("w") as jf:
|
||||
for d in data:
|
||||
jf.write(json.dumps(d) + "\n")
|
||||
|
||||
|
||||
def get_mongo_coll(uri):
|
||||
ud = pymongo.uri_parser.parse_uri(uri)
|
||||
conn = pymongo.MongoClient(uri)
|
||||
|
|
@@ -383,37 +505,23 @@ def plot_seg(wav_plot_path, audio_path):
|
|||
fig.savefig(wav_plot_f, format="png", dpi=50)
|
||||
|
||||
|
||||
def parallel_apply(fn, iterable, workers=8, pool="thread"):
|
||||
if pool == "thread":
|
||||
with ThreadPoolExecutor(max_workers=workers) as exe:
|
||||
print(f"parallelly applying {fn}")
|
||||
return [
|
||||
res
|
||||
for res in tqdm(
|
||||
exe.map(fn, iterable), position=0, leave=True, total=len(iterable)
|
||||
)
|
||||
]
|
||||
elif pool == "process":
|
||||
with ProcessPoolExecutor(max_workers=workers) as exe:
|
||||
print(f"parallelly applying {fn}")
|
||||
return [
|
||||
res
|
||||
for res in tqdm(
|
||||
exe.map(fn, iterable), position=0, leave=True, total=len(iterable)
|
||||
)
|
||||
]
|
||||
else:
|
||||
raise Exception(f"unsupported pool type - {pool}")
|
||||
|
||||
|
||||
def generate_filter_map(src_dataset_path, dest_dataset_path, data_file):
|
||||
min_nums = 3
|
||||
max_duration = 1 * 60 * 60
|
||||
skip_duration = 1 * 60 * 60
|
||||
max_sample_dur = 20
|
||||
min_sample_dur = 2
|
||||
verbose = True
|
||||
|
||||
src_data_enum = (
|
||||
tqdm(list(ExtendedPath(data_file).read_jsonl()))
|
||||
if verbose
|
||||
else ExtendedPath(data_file).read_jsonl()
|
||||
)
|
||||
|
||||
def filtered_max_dur():
|
||||
wav_duration = 0
|
||||
for s in ExtendedPath(data_file).read_jsonl():
|
||||
for s in src_data_enum:
|
||||
nums = re.sub(" ", "", s["text"])
|
||||
if len(nums) >= min_nums:
|
||||
wav_duration += s["duration"]
|
||||
|
|
@@ -428,36 +536,54 @@ def generate_filter_map(src_dataset_path, dest_dataset_path, data_file):
|
|||
|
||||
def filtered_skip_dur():
|
||||
wav_duration = 0
|
||||
for s in ExtendedPath(data_file).read_jsonl():
|
||||
for s in src_data_enum:
|
||||
nums = re.sub(" ", "", s["text"])
|
||||
if len(nums) >= min_nums:
|
||||
wav_duration += s["duration"]
|
||||
if wav_duration <= skip_duration:
|
||||
continue
|
||||
elif len(nums) >= min_nums:
|
||||
yield s
|
||||
shutil.copy(
|
||||
src_dataset_path / Path(s["audio_filepath"]),
|
||||
dest_dataset_path / Path(s["audio_filepath"]),
|
||||
)
|
||||
yield s
|
||||
typer.echo(f"skipped {duration_str(skip_duration)} of audio")
|
||||
|
||||
def filtered_blanks():
|
||||
blank_count = 0
|
||||
for s in ExtendedPath(data_file).read_jsonl():
|
||||
blank_count = total_count = 0
|
||||
for s in src_data_enum:
|
||||
total_count += 1
|
||||
nums = re.sub(" ", "", s["text"])
|
||||
if nums != "":
|
||||
blank_count += 1
|
||||
shutil.copy(
|
||||
src_dataset_path / Path(s["audio_filepath"]),
|
||||
dest_dataset_path / Path(s["audio_filepath"]),
|
||||
)
|
||||
yield s
|
||||
typer.echo(f"filtered {blank_count} blank samples")
|
||||
else:
|
||||
blank_count += 1
|
||||
typer.echo(f"filtered {blank_count} of {total_count} blank samples")
|
||||
|
||||
def filtered_max_sample_dur():
|
||||
max_dur_count = 0
|
||||
for s in src_data_enum:
|
||||
wav_duration = s["duration"]
|
||||
if wav_duration <= max_sample_dur:
|
||||
shutil.copy(
|
||||
src_dataset_path / Path(s["audio_filepath"]),
|
||||
dest_dataset_path / Path(s["audio_filepath"]),
|
||||
)
|
||||
yield s
|
||||
else:
|
||||
max_dur_count += 1
|
||||
typer.echo(
|
||||
f"filtered {max_dur_count} samples longer thans {max_sample_dur}s"
|
||||
)
|
||||
|
||||
def filtered_transform_digits():
|
||||
count = 0
|
||||
for s in ExtendedPath(data_file).read_jsonl():
|
||||
for s in src_data_enum:
|
||||
count += 1
|
||||
digit_text = replace_digit_symbol(s["text"])
|
||||
only_digits = discard_except_digits(digit_text)
|
||||
|
|
@@ -472,11 +598,13 @@ def generate_filter_map(src_dataset_path, dest_dataset_path, data_file):
|
|||
|
||||
def filtered_extract_chars():
|
||||
count = 0
|
||||
for s in ExtendedPath(data_file).read_jsonl():
|
||||
for s in src_data_enum:
|
||||
count += 1
|
||||
no_digits = digits_to_chars(s["text"]).upper()
|
||||
only_chars = re.sub("[^A-Z'\b]", " ", no_digits)
|
||||
filter_text = replace_redundant_spaces_with(only_chars, " ").strip()
|
||||
filter_text = replace_redundant_spaces_with(
|
||||
only_chars, " "
|
||||
).strip()
|
||||
shutil.copy(
|
||||
src_dataset_path / Path(s["audio_filepath"]),
|
||||
dest_dataset_path / Path(s["audio_filepath"]),
|
||||
|
|
@@ -487,16 +615,54 @@ def generate_filter_map(src_dataset_path, dest_dataset_path, data_file):
|
|||
|
||||
def filtered_resample():
|
||||
count = 0
|
||||
for s in ExtendedPath(data_file).read_jsonl():
|
||||
for s in src_data_enum:
|
||||
count += 1
|
||||
src_aud = pydub.AudioSegment.from_file(
|
||||
src_dataset_path / Path(s["audio_filepath"])
|
||||
)
|
||||
dst_aud = src_aud.set_channels(1).set_sample_width(1).set_frame_rate(24000)
|
||||
dst_aud.export(dest_dataset_path / Path(s["audio_filepath"]), format="wav")
|
||||
dst_aud = (
|
||||
src_aud.set_channels(1)
|
||||
.set_sample_width(1)
|
||||
.set_frame_rate(24000)
|
||||
)
|
||||
dst_aud.export(
|
||||
dest_dataset_path / Path(s["audio_filepath"]), format="wav"
|
||||
)
|
||||
yield s
|
||||
typer.echo(f"transformed {count} samples")
|
||||
|
||||
def filtered_msec_to_sec():
|
||||
count = 0
|
||||
for s in src_data_enum:
|
||||
count += 1
|
||||
s["duration"] = s["duration"] / 1000
|
||||
shutil.copy(
|
||||
src_dataset_path / Path(s["audio_filepath"]),
|
||||
dest_dataset_path / Path(s["audio_filepath"]),
|
||||
)
|
||||
yield s
|
||||
typer.echo(f"transformed {count} samples")
|
||||
|
||||
def filtered_blank_hr_max_dur():
|
||||
max_duration = 3 * 60 * 60
|
||||
wav_duration = 0
|
||||
for s in src_data_enum:
|
||||
# nums = re.sub(" ", "", s["text"])
|
||||
s["text"] = "gAAAAABgq2FR6ajbhMsDmWRQBzX6gIzyAG5sMwFihGeV7E_6eVJqqF78yzmtTJPsJAOJEEXhJ9Z45MrYNgE1sq7VUdsBVGh2cw=="
|
||||
if (
|
||||
s["duration"] >= min_sample_dur
|
||||
and s["duration"] <= max_sample_dur
|
||||
):
|
||||
wav_duration += s["duration"]
|
||||
shutil.copy(
|
||||
src_dataset_path / Path(s["audio_filepath"]),
|
||||
dest_dataset_path / Path(s["audio_filepath"]),
|
||||
)
|
||||
yield s
|
||||
if wav_duration > max_duration:
|
||||
break
|
||||
typer.echo(f"filtered only {duration_str(wav_duration)} of audio")
|
||||
|
||||
filter_kind_map = {
|
||||
"max_dur_1hr_min3num": filtered_max_dur,
|
||||
"skip_dur_1hr_min3num": filtered_skip_dur,
|
||||
|
|
@@ -504,5 +670,8 @@ def generate_filter_map(src_dataset_path, dest_dataset_path, data_file):
|
|||
"transform_digits": filtered_transform_digits,
|
||||
"extract_chars": filtered_extract_chars,
|
||||
"resample_ulaw24kmono": filtered_resample,
|
||||
"max_sample_dur": filtered_max_sample_dur,
|
||||
"msec_to_sec": filtered_msec_to_sec,
|
||||
"blank_3hr_max_dur": filtered_blank_hr_max_dur,
|
||||
}
|
||||
return filter_kind_map
|
||||
|
|
@@ -1,6 +1,5 @@
|
|||
from pathlib import Path
|
||||
# from IPython import display
|
||||
import requests
|
||||
import io
|
||||
import shutil
|
||||
|
||||
|
|
@@ -11,6 +10,7 @@ from .tts import GoogleTTS
|
|||
|
||||
display = lazy_module('IPython.display')
|
||||
pydub = lazy_module('pydub')
|
||||
requests = lazy_module('requests')
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
|
@@ -72,12 +72,12 @@ def gentle_preview(
|
|||
pkg_gentle_dir = Path(__file__).parent / 'gentle_preview'
|
||||
|
||||
shutil.copytree(str(pkg_gentle_dir), str(gent_preview_dir))
|
||||
# ab = audio_path.read_bytes()
|
||||
# tt = transcript_path.read_text()
|
||||
# audio, alignment = gentle_aligner(service_uri, ab, tt)
|
||||
# audio.export(gent_preview_dir / Path("a.wav"), format="wav")
|
||||
# alignment["status"] = "OK"
|
||||
# ExtendedPath(gent_preview_dir / Path("status.json")).write_json(alignment)
|
||||
ab = audio_path.read_bytes()
|
||||
tt = transcript_path.read_text()
|
||||
audio, alignment = gentle_aligner(service_uri, ab, tt)
|
||||
audio.export(gent_preview_dir / Path("a.wav"), format="wav")
|
||||
alignment["status"] = "OK"
|
||||
ExtendedPath(gent_preview_dir / Path("status.json")).write_json(alignment)
|
||||
|
||||
|
||||
def main():
|
||||
|
|
@@ -0,0 +1,53 @@
|
|||
import sys
|
||||
from io import BytesIO
|
||||
|
||||
from .lazy_import import lazy_module, lazy_callable
|
||||
|
||||
np = lazy_module("numpy")
|
||||
pydub = lazy_module("pydub")
|
||||
lfilter = lazy_callable("scipy.signal.lfilter")
|
||||
butter = lazy_callable("scipy.signal.butter")
|
||||
read = lazy_callable("scipy.io.wavfile.read")
|
||||
write = lazy_callable("scipy.io.wavfile.write")
|
||||
# from scipy.signal import lfilter, butter
|
||||
# from scipy.io.wavfile import read, write
|
||||
# import numpy as np
|
||||
|
||||
|
||||
def audio_seg_to_wav_bytes(aud_seg):
|
||||
b = BytesIO()
|
||||
aud_seg.export(b, format="wav")
|
||||
return b.getvalue()
|
||||
|
||||
|
||||
def audio_wav_bytes_to_seg(wav_bytes):
|
||||
b = BytesIO(wav_bytes)
|
||||
return pydub.AudioSegment.from_file(b)
|
||||
|
||||
|
||||
def butter_params(low_freq, high_freq, fs, order=5):
|
||||
nyq = 0.5 * fs
|
||||
low = low_freq / nyq
|
||||
high = high_freq / nyq
|
||||
b, a = butter(order, [low, high], btype="band")
|
||||
return b, a
|
||||
|
||||
|
||||
def butter_bandpass_filter(data, low_freq, high_freq, fs, order=5):
|
||||
b, a = butter_params(low_freq, high_freq, fs, order=order)
|
||||
y = lfilter(b, a, data)
|
||||
return y
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fs, audio = read(sys.argv[1])
|
||||
import pdb
|
||||
|
||||
pdb.set_trace()
|
||||
low_freq = 300.0
|
||||
high_freq = 4000.0
|
||||
filtered_signal = butter_bandpass_filter(
|
||||
audio, low_freq, high_freq, fs, order=6
|
||||
)
|
||||
fname = sys.argv[1].split(".wav")[0] + "_moded.wav"
|
||||
write(fname, fs, np.array(filtered_signal, dtype=np.int16))
|
||||
|
|
@@ -0,0 +1,188 @@
|
|||
from collections import namedtuple
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
||||
# from cryptography.fernet import Fernet
|
||||
|
||||
import typer
|
||||
from tqdm import tqdm
|
||||
|
||||
from . import asr_manifest_writer
|
||||
from .extended_path import ExtendedPath
|
||||
from .audio import audio_seg_to_wav_bytes, audio_wav_bytes_to_seg
|
||||
from .parallel import parallel_apply
|
||||
from .lazy_import import lazy_module
|
||||
|
||||
cryptography = lazy_module("cryptography")
|
||||
# cryptography.fernet = lazy_module("cryptography.fernet")
|
||||
pydub = lazy_module("pydub")
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.callback()
|
||||
def encrypt():
|
||||
"""
|
||||
encrypt sub commands
|
||||
"""
|
||||
|
||||
|
||||
def wav_cryptor(key=""):
|
||||
WavCryptor = namedtuple(
|
||||
"WavCryptor",
|
||||
(
|
||||
"keygen",
|
||||
"encrypt_wav_path_to",
|
||||
"decrypt_wav_path_to",
|
||||
"decrypt_wav_path",
|
||||
),
|
||||
)
|
||||
_enc_key = key
|
||||
_crypto_f = cryptography.fernet.Fernet(_enc_key)
|
||||
|
||||
def encrypt_wav_bytes(f, dec_wav_bytes):
|
||||
b = BytesIO(dec_wav_bytes)
|
||||
audio_seg = pydub.AudioSegment.from_file(b)
|
||||
# audio_seg.raw_data
|
||||
enc_wav_bytes = f.encrypt(audio_seg.raw_data)
|
||||
encrypted_seg = pydub.AudioSegment(
|
||||
enc_wav_bytes,
|
||||
frame_rate=audio_seg.frame_rate,
|
||||
channels=audio_seg.channels,
|
||||
sample_width=audio_seg.sample_width,
|
||||
)
|
||||
return audio_seg_to_wav_bytes(encrypted_seg)
|
||||
|
||||
def decrypt_wav_bytes(f, enc_wav_bytes):
|
||||
b = BytesIO(enc_wav_bytes)
|
||||
audio_seg = pydub.AudioSegment.from_file(b)
|
||||
dec_wav_bytes = f.decrypt(audio_seg.raw_data)
|
||||
decrypted_seg = pydub.AudioSegment(
|
||||
dec_wav_bytes,
|
||||
frame_rate=audio_seg.frame_rate,
|
||||
channels=audio_seg.channels,
|
||||
sample_width=audio_seg.sample_width,
|
||||
)
|
||||
return audio_seg_to_wav_bytes(decrypted_seg)
|
||||
|
||||
def encrypt_wav_path_to(dec_audio_path: Path, enc_audio_path: Path):
|
||||
dec_wav_bytes = dec_audio_path.read_bytes()
|
||||
enc_audio_path.write_bytes(encrypt_wav_bytes(_crypto_f, dec_wav_bytes))
|
||||
|
||||
def decrypt_wav_path_to(enc_audio_path: Path, dec_audio_path: Path):
|
||||
enc_wav_bytes = enc_audio_path.read_bytes()
|
||||
dec_audio_path.write_bytes(decrypt_wav_bytes(_crypto_f, enc_wav_bytes))
|
||||
|
||||
def decrypt_wav_path(enc_audio_path: Path):
|
||||
enc_wav_bytes = enc_audio_path.read_bytes()
|
||||
return decrypt_wav_bytes(_crypto_f, enc_wav_bytes)
|
||||
|
||||
return WavCryptor(
|
||||
cryptography.fernet.Fernet.generate_key,
|
||||
encrypt_wav_path_to,
|
||||
decrypt_wav_path_to,
|
||||
decrypt_wav_path,
|
||||
)
|
||||
|
||||
|
||||
def text_cryptor(key=""):
|
||||
TextCryptor = namedtuple(
|
||||
"TextCryptor",
|
||||
("keygen", "encrypt_text", "decrypt_text"),
|
||||
)
|
||||
_enc_key = key
|
||||
_crypto_f = cryptography.fernet.Fernet(_enc_key)
|
||||
|
||||
def encrypt_text(text: str):
|
||||
return _crypto_f.encrypt(text.encode("utf-8"))
|
||||
|
||||
def decrypt_text(text: str):
|
||||
return _crypto_f.decrypt(text).decode("utf-8")
|
||||
|
||||
return TextCryptor(
|
||||
cryptography.fernet.Fernet.generate_key, encrypt_text, decrypt_text
|
||||
)
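A round-trip sketch for the two cryptor factories; `_cryptor_round_trip_sketch` and the wav paths are hypothetical, and real keys normally come from the keygen command below.
def _cryptor_round_trip_sketch():
    key = cryptography.fernet.Fernet.generate_key()
    tc = text_cryptor(key)
    token = tc.encrypt_text("hello world")
    assert tc.decrypt_text(token) == "hello world"
    # wav files work the same way, via paths instead of strings:
    # wc = wav_cryptor(key)
    # wc.encrypt_wav_path_to(Path("in.wav"), Path("in.enc.wav"))
    # wc.decrypt_wav_path_to(Path("in.enc.wav"), Path("round_trip.wav"))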
|
||||
|
||||
|
||||
def encrypted_asr_manifest_reader(
|
||||
data_manifest_path: Path, encryption_key: str, verbose=True, parallel=True
|
||||
):
|
||||
print(f"reading encrypted manifest from {data_manifest_path}")
|
||||
asr_data = list(ExtendedPath(data_manifest_path).read_jsonl())
|
||||
enc_key_bytes = encryption_key.encode("utf-8")
|
||||
wc = wav_cryptor(enc_key_bytes)
|
||||
tc = text_cryptor(enc_key_bytes)
|
||||
|
||||
def decrypt_fn(p):
|
||||
d = {
|
||||
"audio_seg": audio_wav_bytes_to_seg(
|
||||
wc.decrypt_wav_path(
|
||||
data_manifest_path.parent / Path(p["audio_filepath"])
|
||||
)
|
||||
),
|
||||
"text": tc.decrypt_text(p["text"].encode("utf-8")),
|
||||
}
|
||||
return d
|
||||
|
||||
if parallel:
|
||||
for d in parallel_apply(decrypt_fn, asr_data, verbose=verbose):
|
||||
yield d
|
||||
else:
|
||||
for p in tqdm(asr_data) if verbose else asr_data:
|
||||
yield decrypt_fn(p)
|
||||
|
||||
|
||||
def decrypt_asr_dataset(
|
||||
src_dataset_dir: Path,
|
||||
dest_dataset_dir: Path,
|
||||
encryption_key: str,
|
||||
verbose=True,
|
||||
parallel=True,
|
||||
):
|
||||
data_manifest_path = src_dataset_dir / "manifest.json"
|
||||
(dest_dataset_dir / "wavs").mkdir(exist_ok=True, parents=True)
|
||||
dest_manifest_path = dest_dataset_dir / "manifest.json"
|
||||
print(f"reading encrypted manifest from {data_manifest_path}")
|
||||
asr_data = list(ExtendedPath(data_manifest_path).read_jsonl())
|
||||
enc_key_bytes = encryption_key.encode("utf-8")
|
||||
wc = wav_cryptor(enc_key_bytes)
|
||||
tc = text_cryptor(enc_key_bytes)
|
||||
|
||||
def decrypt_fn(p):
|
||||
dest_path = dest_dataset_dir / Path(p["audio_filepath"])
|
||||
wc.decrypt_wav_path_to(
|
||||
src_dataset_dir / Path(p["audio_filepath"]), dest_path
|
||||
)
|
||||
d = {
|
||||
"audio_filepath": dest_path,
|
||||
"duration": p["duration"],
|
||||
"text": tc.decrypt_text(p["text"].encode("utf-8")),
|
||||
}
|
||||
return d
|
||||
|
||||
def datagen():
|
||||
if parallel:
|
||||
for d in parallel_apply(decrypt_fn, asr_data, verbose=verbose):
|
||||
yield d
|
||||
else:
|
||||
for p in tqdm(asr_data) if verbose else asr_data:
|
||||
yield decrypt_fn(p)
|
||||
|
||||
asr_manifest_writer(dest_manifest_path, datagen())
|
||||
|
||||
|
||||
@app.command()
|
||||
def keygen():
|
||||
gen_key = cryptography.fernet.Fernet.generate_key()
|
||||
typer.echo(f"KEY: {gen_key}")
|
||||
|
||||
|
||||
@app.command()
|
||||
def encrypt_text(
|
||||
text_to_encrypt: str,
|
||||
encryption_key: str = typer.Option(..., prompt=True, hide_input=True),
|
||||
):
|
||||
enc_key_bytes = encryption_key.encode("utf-8")
|
||||
tc = text_cryptor(enc_key_bytes)
|
||||
cryptext = tc.encrypt_text(text_to_encrypt)
|
||||
typer.echo(cryptext)
|
||||
|
|
@@ -0,0 +1,56 @@
|
|||
from pathlib import Path
|
||||
import json
|
||||
|
||||
from .lazy_import import lazy_module
|
||||
|
||||
yaml = lazy_module("ruamel.yaml")
|
||||
pydub = lazy_module("pydub")
|
||||
|
||||
|
||||
class ExtendedPath(type(Path())):
|
||||
"""docstring for ExtendedPath."""
|
||||
|
||||
def read_json(self, verbose=False):
|
||||
if verbose:
|
||||
print(f"reading json from {self}")
|
||||
with self.open("r") as jf:
|
||||
return json.load(jf)
|
||||
|
||||
def read_yaml(self, verbose=False):
|
||||
yaml_o = yaml.YAML(typ="safe", pure=True)
|
||||
if verbose:
|
||||
print(f"reading yaml from {self}")
|
||||
with self.open("r") as yf:
|
||||
return yaml_o.load(yf)
|
||||
|
||||
def read_jsonl(self, verbose=False):
|
||||
if verbose:
|
||||
print(f"reading jsonl from {self}")
|
||||
with self.open("r") as jf:
|
||||
for ln in jf.readlines():
|
||||
yield json.loads(ln)
|
||||
|
||||
def read_audio_segment(self):
|
||||
return pydub.AudioSegment.from_file(self)
|
||||
|
||||
def write_json(self, data, verbose=False):
|
||||
if verbose:
|
||||
print(f"writing json to {self}")
|
||||
self.parent.mkdir(parents=True, exist_ok=True)
|
||||
with self.open("w") as jf:
|
||||
json.dump(data, jf, indent=2)
|
||||
|
||||
def write_yaml(self, data, verbose=False):
|
||||
yaml_o = yaml.YAML()
|
||||
if verbose:
|
||||
print(f"writing yaml to {self}")
|
||||
with self.open("w") as yf:
|
||||
yaml_o.dump(data, yf)
|
||||
|
||||
def write_jsonl(self, data, verbose=False):
|
||||
if verbose:
|
||||
print(f"writing jsonl to {self}")
|
||||
self.parent.mkdir(parents=True, exist_ok=True)
|
||||
with self.open("w") as jf:
|
||||
for d in data:
|
||||
jf.write(json.dumps(d) + "\n")
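A small usage sketch of ExtendedPath; the `_extended_path_sketch` helper and the /tmp paths are illustrative only.
def _extended_path_sketch():
    p = ExtendedPath("/tmp/demo/manifest.jsonl")
    p.write_jsonl(
        [{"audio_filepath": "wavs/a.wav", "duration": 1.2, "text": "hi"}]
    )
    rows = list(p.read_jsonl())
    assert rows[0]["text"] == "hi"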
|
||||
|
|
|
@@ -82,10 +82,10 @@ except ImportError:
|
|||
|
||||
# Adding a __spec__ doesn't really help. I'll leave the code here in case
|
||||
# future python implementations start relying on it.
|
||||
# try:
|
||||
# from importlib.machinery import ModuleSpec
|
||||
# except ImportError:
|
||||
# ModuleSpec = None
|
||||
try:
|
||||
from importlib.machinery import ModuleSpec
|
||||
except ImportError:
|
||||
ModuleSpec = None
|
||||
|
||||
import six
|
||||
from six import raise_from
|
||||
|
|
@@ -206,8 +206,7 @@ class LazyModule(ModuleType):
|
|||
|
||||
|
||||
class LazyCallable(object):
|
||||
"""Class for lazily-loaded callables that triggers module loading on access
|
||||
"""
|
||||
"""Class for lazily-loaded callables that triggers module loading on access"""
|
||||
|
||||
def __init__(self, *args):
|
||||
if len(args) != 2:
|
||||
|
|
@@ -399,9 +398,8 @@ def _lazy_module(modname, error_strings, lazy_mod_class):
|
|||
# Actual module instantiation
|
||||
mod = sys.modules[modname] = _LazyModule(modname)
|
||||
# No need for __spec__. Maybe in the future.
|
||||
# if ModuleSpec:
|
||||
# ModuleType.__setattr__(mod, '__spec__',
|
||||
# ModuleSpec(modname, None))
|
||||
if ModuleSpec:
|
||||
ModuleType.__setattr__(mod, "__spec__", ModuleSpec(modname, None))
|
||||
if fullsubmodname:
|
||||
submod = sys.modules[fullsubmodname]
|
||||
ModuleType.__setattr__(mod, submodname, submod)
|
||||
|
|
@@ -531,8 +529,7 @@ def _lazy_callable(modname, cname, error_strings, lazy_mod_class, lazy_call_clas
|
|||
|
||||
|
||||
def _load_module(module):
|
||||
"""Ensures that a module, and its parents, are properly loaded
|
||||
"""
|
||||
"""Ensures that a module, and its parents, are properly loaded"""
|
||||
modclass = type(module)
|
||||
# We only take care of our own LazyModule instances
|
||||
if not issubclass(modclass, LazyModule):
|
||||
|
|
@@ -623,8 +620,7 @@ _DELETION_DICT = ("_lazy_import_submodules",)
|
|||
|
||||
|
||||
def _setdef(argdict, name, defaultvalue):
|
||||
"""Like dict.setdefault but sets the default value also if None is present.
|
||||
"""
|
||||
"""Like dict.setdefault but sets the default value also if None is present."""
|
||||
if not name in argdict or argdict[name] is None:
|
||||
argdict[name] = defaultvalue
|
||||
return argdict[name]
|
||||
|
|
@@ -645,8 +641,7 @@ def _set_default_errornames(modname, error_strings, call=False):
|
|||
|
||||
|
||||
def _caller_name(depth=2, default=""):
|
||||
"""Returns the name of the calling namespace.
|
||||
"""
|
||||
"""Returns the name of the calling namespace."""
|
||||
# the presence of sys._getframe might be implementation-dependent.
|
||||
# It isn't that serious if we can't get the caller's name.
|
||||
try:
|
||||
|
|
@@ -700,8 +695,7 @@ def _clean_lazy_submod_refs(module):
|
|||
|
||||
|
||||
def _reset_lazymodule(module, cls_attrs):
|
||||
"""Resets a module's lazy state from cached data.
|
||||
"""
|
||||
"""Resets a module's lazy state from cached data."""
|
||||
modclass = type(module)
|
||||
del modclass.__getattribute__
|
||||
del modclass.__setattr__
|
||||
|
|
@@ -0,0 +1,68 @@
|
|||
from pathlib import Path
|
||||
|
||||
# from tqdm import tqdm
|
||||
import json
|
||||
|
||||
# from .extended_path import ExtendedPath
|
||||
# from .parallel import parallel_apply
|
||||
# from .encrypt import wav_cryptor, text_cryptor
|
||||
|
||||
|
||||
def manifest_str(path, dur, text):
|
||||
k = {"audio_filepath": path, "duration": round(dur, 1), "text": text}
|
||||
return json.dumps(k) + "\n"
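For clarity, one manifest line as produced by manifest_str; the `_manifest_str_example` wrapper is illustrative.
def _manifest_str_example():
    line = manifest_str("wavs/a.wav", 1.23, "hello")
    # duration is rounded to one decimal place
    assert line == (
        '{"audio_filepath": "wavs/a.wav", "duration": 1.2, "text": "hello"}\n'
    )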
|
||||
|
||||
|
||||
def asr_manifest_writer(
|
||||
asr_manifest_path: Path, manifest_str_source, verbose=False
|
||||
):
|
||||
with asr_manifest_path.open("w") as mf:
|
||||
if verbose:
|
||||
print(f"writing asr manifest to {asr_manifest_path}")
|
||||
for mani_dict in manifest_str_source:
|
||||
manifest = manifest_str(
|
||||
mani_dict["audio_filepath"],
|
||||
mani_dict["duration"],
|
||||
mani_dict["text"],
|
||||
)
|
||||
mf.write(manifest)
|
||||
|
||||
|
||||
#
|
||||
# def decrypt(
|
||||
# src_dataset_dir: Path,
|
||||
# dest_dataset_dir: Path,
|
||||
# encryption_key: str,
|
||||
# verbose=True,
|
||||
# parallel=True,
|
||||
# ):
|
||||
# data_manifest_path = src_dataset_dir / "manifest.json"
|
||||
# (dest_dataset_dir / "wavs").mkdir(exist_ok=True, parents=True)
|
||||
# dest_manifest_path = dest_dataset_dir / "manifest.json"
|
||||
# print(f"reading encrypted manifest from {data_manifest_path}")
|
||||
# asr_data = list(ExtendedPath(data_manifest_path).read_jsonl())
|
||||
# enc_key_bytes = encryption_key.encode("utf-8")
|
||||
# wc = wav_cryptor(enc_key_bytes)
|
||||
# tc = text_cryptor(enc_key_bytes)
|
||||
#
|
||||
# def decrypt_fn(p):
|
||||
# dest_path = dest_dataset_dir / Path(p["audio_filepath"])
|
||||
# wc.decrypt_wav_path_to(
|
||||
# src_dataset_dir / Path(p["audio_filepath"]), dest_path
|
||||
# )
|
||||
# d = {
|
||||
# "audio_filepath": dest_path,
|
||||
# "duration": p["duration"],
|
||||
# "text": tc.decrypt_text(p["text"].encode("utf-8")),
|
||||
# }
|
||||
# return d
|
||||
#
|
||||
# def datagen():
|
||||
# if parallel:
|
||||
# for d in parallel_apply(decrypt_fn, asr_data, verbose=verbose):
|
||||
# yield d
|
||||
# else:
|
||||
# for p in tqdm.tqdm(asr_data) if verbose else asr_data:
|
||||
# yield decrypt_fn(d)
|
||||
#
|
||||
# asr_manifest_writer(dest_manifest_path, datagen)
|
||||
|
|
@@ -0,0 +1,41 @@
|
|||
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def parallel_apply(fn, iterable, workers=8, pool="thread", verbose=True):
|
||||
# warm-up
|
||||
fn(iterable[0])
|
||||
if pool == "thread":
|
||||
with ThreadPoolExecutor(max_workers=workers) as exe:
|
||||
if verbose:
|
||||
print(f"parallelly applying {fn}")
|
||||
return [
|
||||
res
|
||||
for res in tqdm(
|
||||
exe.map(fn, iterable),
|
||||
position=0,
|
||||
leave=True,
|
||||
total=len(iterable),
|
||||
)
|
||||
]
|
||||
else:
|
||||
return [res for res in exe.map(fn, iterable)]
|
||||
elif pool == "process":
|
||||
with ProcessPoolExecutor(max_workers=workers) as exe:
|
||||
if verbose:
|
||||
print(f"parallelly applying {fn}")
|
||||
with tqdm(total=len(iterable)) as progress:
|
||||
futures = []
|
||||
for i in iterable:
|
||||
future = exe.submit(fn, i)
|
||||
future.add_done_callback(lambda p: progress.update())
|
||||
futures.append(future)
|
||||
results = []
|
||||
for future in futures:
|
||||
result = future.result()
|
||||
results.append(result)
|
||||
return results
|
||||
else:
|
||||
return [res for res in exe.map(fn, iterable)]
|
||||
else:
|
||||
raise Exception(f"unsupported pool type - {pool}")
|
||||
|
|
@@ -0,0 +1,383 @@
|
|||
import re
|
||||
|
||||
from .lazy_import import lazy_callable, lazy_module
|
||||
|
||||
num2words = lazy_callable("num2words.num2words")
|
||||
spellchecker = lazy_module("spellchecker")
|
||||
# from num2words import num2words
|
||||
|
||||
|
||||
def entity_replacer_keeper(
|
||||
pre_rules=[], entity_rules=[], post_rules=[], verbose=False
|
||||
):
|
||||
# def replacer_keeper_gen():
|
||||
pre_rules_c = [(re.compile(k), v) for (k, v) in pre_rules]
|
||||
entity_rules_c = [
|
||||
(re.compile(k, re.IGNORECASE), v) for (k, v) in entity_rules
|
||||
]
|
||||
post_rules_c = [(re.compile(k), v) for (k, v) in post_rules]
|
||||
|
||||
re_rules = pre_rules_c + entity_rules_c + post_rules_c
|
||||
|
||||
def replacer(w2v_out):
|
||||
out = w2v_out
|
||||
for (k, v) in re_rules:
|
||||
orig = out
|
||||
out = k.sub(v, out)
|
||||
if verbose:
|
||||
print(f"rule |{k}|: sub:|{v}| |{orig}|=> |{out}|")
|
||||
return out
|
||||
|
||||
def merge_intervals(intervals):
|
||||
# https://codereview.stackexchange.com/a/69249
|
||||
sorted_by_lower_bound = sorted(intervals, key=lambda tup: tup[0])
|
||||
merged = []
|
||||
|
||||
for higher in sorted_by_lower_bound:
|
||||
if not merged:
|
||||
merged.append(higher)
|
||||
else:
|
||||
lower = merged[-1]
|
||||
# test for intersection between lower and higher:
|
||||
# we know via sorting that lower[0] <= higher[0]
|
||||
if higher[0] <= lower[1]:
|
||||
upper_bound = max(lower[1], higher[1])
|
||||
merged[-1] = (
|
||||
lower[0],
|
||||
upper_bound,
|
||||
) # replace by merged interval
|
||||
else:
|
||||
merged.append(higher)
|
||||
return merged
|
||||
|
||||
# optimal merging interval tree
|
||||
# https://www.geeksforgeeks.org/interval-tree/
|
||||
|
||||
def keep_literals(w2v_out):
|
||||
# out = re.sub(r"[ ;,.]", " ", w2v_out).strip()
|
||||
out = w2v_out
|
||||
for (k, v) in pre_rules_c:
|
||||
out = k.sub(v, out)
|
||||
num_spans = []
|
||||
if verbose:
|
||||
print(f"num_rules: {len(entity_rules_c)}")
|
||||
for (k, v) in entity_rules_c: # [94:]:
|
||||
matches = k.finditer(out)
|
||||
for m in matches:
|
||||
# num_spans.append(m.span())
|
||||
# look at space-separated internal entities
|
||||
(start, end) = m.span()
|
||||
for s in re.finditer(r"\S+", out[start:end]):
|
||||
(start_e, end_e) = s.span()
|
||||
num_spans.append((start_e + start, end_e + start))
|
||||
if verbose:
|
||||
t = out[start_e + start : end_e + start]
|
||||
print(f"rule |{k}|: sub:|{v}| => |{t}|")
|
||||
|
||||
merged = merge_intervals(num_spans)
|
||||
num_ents = len(merged)
|
||||
keep_out = " ".join((out[s[0] : s[1]] for s in merged))
|
||||
for (k, v) in post_rules_c:
|
||||
keep_out = k.sub(v, keep_out)
|
||||
return keep_out, num_ents
|
||||
|
||||
return replacer, keep_literals
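A compact sketch of the two callables returned above, using a single hypothetical rule; `_entity_replacer_sketch` is illustrative only.
def _entity_replacer_sketch():
    rep, keep = entity_replacer_keeper(entity_rules=[(r"\bfive\b", "5")])
    assert rep("Five by five") == "5 by 5"  # entity rules are case-insensitive
    assert keep("say five now") == ("five", 1)  # kept span plus entity count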
|
||||
|
||||
|
||||
def default_num_only_rules(num_range):
|
||||
entity_rules = (
|
||||
[
|
||||
(
|
||||
r"\b" + num2words(i) + r"\b",
|
||||
str(i),
|
||||
)
|
||||
for i in reversed(range(num_range))
|
||||
]
|
||||
+ [
|
||||
(
|
||||
r"\b" + str(i) + r"\b",
|
||||
str(i),
|
||||
)
|
||||
for i in reversed(range(10))
|
||||
]
|
||||
+ [
|
||||
(r"\bhundred\b", "00"),
|
||||
]
|
||||
)
|
||||
return entity_rules
|
||||
|
||||
|
||||
def default_num_rules(num_range):
|
||||
entity_rules = default_num_only_rules(num_range) + [
|
||||
(r"\boh\b", "0"),
|
||||
(r"\bo\b", "0"),
|
||||
(r"\bdouble(?: |-)(\w+|\d+)\b", "\\1 \\1"),
|
||||
(r"\btriple(?: |-)(\w+|\d+)\b", "\\1 \\1 \\1"),
|
||||
]
|
||||
return entity_rules
|
||||
|
||||
|
||||
def infer_num_rules_vocab(num_range):
|
||||
vocab = [num2words(i) for i in reversed(range(num_range))] + [
|
||||
"hundred",
|
||||
"double",
|
||||
"triple",
|
||||
]
|
||||
entity_rules = [
|
||||
(
|
||||
num2words(i),
|
||||
str(i),
|
||||
)
|
||||
for i in reversed(range(num_range))
|
||||
] + [
|
||||
(r"\bhundred\b", "00"),
|
||||
(r"\boh\b", "0"),
|
||||
(r"\bo\b", "0"),
|
||||
(r"\bdouble(?: |-)(\w+|\d+)\b", "\\1 \\1"),
|
||||
(r"\btriple(?: |-)(\w+|\d+)\b", "\\1 \\1 \\1"),
|
||||
]
|
||||
return entity_rules, vocab
|
||||
|
||||
|
||||
def do_tri_verbose_list():
|
||||
return [
|
||||
num2words(i) for i in list(range(11, 19)) + list(range(20, 100, 10))
|
||||
] + ["hundred"]
|
||||
|
||||
|
||||
def default_alnum_rules(num_range, oh_is_zero, i_oh_limit):
|
||||
oh_is_zero_rules = [
|
||||
(r"\boh\b", "0"),
|
||||
(r"\bo\b", "0"),
|
||||
]
|
||||
|
||||
num_list = [num2words(i) for i in reversed(range(num_range))]
|
||||
al_num_regex = r"|".join(num_list) + r"|[0-9a-z]"
|
||||
o_i_vars = r"(\[?(?:Oh|O|I)\]?)"
|
||||
i_oh_limit_rules = [
|
||||
(r"\b([a-hj-np-z])\b", "\\1"),
|
||||
(
|
||||
r"\b((?:"
|
||||
+ al_num_regex
|
||||
+ r"|^)\b\s*)(I|O)(\s*\b)(?="
|
||||
+ al_num_regex
|
||||
+ r"\s+|$)\b",
|
||||
"\\1[\\2]\\3",
|
||||
),
|
||||
# (
|
||||
# r"\b" + o_i_vars + r"(\s+)" + o_i_vars + r"\b",
|
||||
# "[\\1]\\2[\\3]",
|
||||
# ),
|
||||
(
|
||||
r"(\s+|^)" + o_i_vars + r"(\s+)\[?" + o_i_vars + r"\]?(\s+|$)",
|
||||
"\\1[\\2]\\3[\\4]\\5",
|
||||
),
|
||||
(
|
||||
r"(\s+|^)\[?" + o_i_vars + r"\]?(\s+)" + o_i_vars + r"(\s+|$)",
|
||||
"\\1[\\2]\\3[\\4]\\5",
|
||||
),
|
||||
]
|
||||
entity_rules = (
|
||||
default_num_only_rules(num_range)
|
||||
+ (oh_is_zero_rules if oh_is_zero else [(r"\boh\b", "o")])
|
||||
+ [
|
||||
(r"\bdouble(?: |-)(\w+|\d+)\b", "\\1 \\1"),
|
||||
(r"\btriple(?: |-)(\w+|\d+)\b", "\\1 \\1 \\1"),
|
||||
# (r"\b([a-zA-Z])\b", "\\1"),
|
||||
]
|
||||
+ (i_oh_limit_rules if i_oh_limit else [(r"\b([a-zA-Z])\b", "\\1")])
|
||||
)
|
||||
return entity_rules
|
||||
|
||||
|
||||
def num_replacer(num_range=100, condense=True):
|
||||
entity_rules = default_num_rules(num_range)
|
||||
post_rules = [(r"[^0-9]", "")] if condense else []
|
||||
replacer, keeper = entity_replacer_keeper(
|
||||
entity_rules=entity_rules, post_rules=post_rules
|
||||
)
|
||||
return replacer
|
||||
|
||||
|
||||
def num_keeper(num_range=100):
|
||||
entity_rules = default_num_rules(num_range)
|
||||
pre_rules = [(r"[ ;,.]", " ")]
|
||||
post_rules = []
|
||||
replacer, keeper = entity_replacer_keeper(
|
||||
pre_rules=pre_rules, entity_rules=entity_rules, post_rules=post_rules
|
||||
)
|
||||
return keeper
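A sketch of the prebuilt helpers with their default rule sets; `_regentity_usage_sketch` is illustrative and the expected values follow from the rules defined above.
def _regentity_usage_sketch():
    extract = num_replacer()
    keep = num_keeper()
    assert extract("triple five fifty three") == "555503"
    assert keep("call me at five oh nine") == ("five oh nine", 3)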
|
||||
|
||||
|
||||
def alnum_replacer(
|
||||
num_range=100, oh_is_zero=False, i_oh_limit=True, condense=True
|
||||
):
|
||||
entity_rules = default_alnum_rules(
|
||||
num_range, oh_is_zero, i_oh_limit=i_oh_limit
|
||||
)
|
||||
# entity_rules = default_num_rules(num_range)
|
||||
pre_rules = [
|
||||
(r"[ ;,.]", " "),
|
||||
(r"[']", ""),
|
||||
# (
|
||||
# r"((?:(?<=\w{2,2})|^)\s*)(?:\bI\b|\bi\b|\bOh\b|\boh\b)(\s*(?:\w{2,}|$))",
|
||||
# "",
|
||||
# ),
|
||||
]
|
||||
|
||||
def upper_case(match_obj):
|
||||
char_elem = match_obj.group(0)
|
||||
return char_elem.upper()
|
||||
|
||||
post_rules = (
|
||||
(
|
||||
(
|
||||
[
|
||||
(r"(\s|^)(?:o|O|I|i)(\s|$)", "\\1\\2"),
|
||||
(r"\[(\w)\]", "\\1"),
|
||||
]
|
||||
if i_oh_limit
|
||||
else []
|
||||
)
|
||||
+ [
|
||||
# (r"\b[a-zA-Z]+\'[a-zA-Z]+\b", ""),
|
||||
(r"\b[a-zA-Z]{2,}\b", ""),
|
||||
(r"[^a-zA-Z0-9]", ""),
|
||||
(r"([a-z].*)", upper_case),
|
||||
]
|
||||
)
|
||||
if condense
|
||||
else []
|
||||
)
|
||||
replacer, keeper = entity_replacer_keeper(
|
||||
pre_rules=pre_rules, entity_rules=entity_rules, post_rules=post_rules
|
||||
)
|
||||
return replacer
|
||||
|
||||
|
||||
def alnum_keeper(num_range=100, oh_is_zero=False):
|
||||
entity_rules = default_alnum_rules(num_range, oh_is_zero, i_oh_limit=True)
|
||||
|
||||
# def strip_space(match_obj):
|
||||
# # char_elem = match_obj.group(1)
|
||||
# return match_obj.group(1).strip() + match_obj.group(2).strip()
|
||||
|
||||
pre_rules = [
|
||||
(r"[ ;,.]", " "),
|
||||
(r"[']", ""),
|
||||
# (
|
||||
# r"((?:(?<=\w{2,2})|^)\s*)(?:\bI\b|\bi\b|\bOh\b|\boh\b)(\s*(?:\w{2,}|$))",
|
||||
# strip_space,
|
||||
# ),
|
||||
]
|
||||
|
||||
post_rules = [
|
||||
# (
|
||||
# r"((?:(?<=\w{2,2})|^)\s*)(?:\bI\b|\bi\b|\bOh\b|\boh\b)(\s*(?:\w{2,}|$))",
|
||||
# strip_space,
|
||||
# )
|
||||
]
|
||||
replacer, keeper = entity_replacer_keeper(
|
||||
pre_rules=pre_rules, entity_rules=entity_rules, post_rules=post_rules
|
||||
)
|
||||
return keeper
|
||||
|
||||
|
||||
def num_keeper_orig(num_range=10, extra_rules=[]):
|
||||
num_int_map_ty = [
|
||||
(
|
||||
r"\b" + num2words(i) + r"\b",
|
||||
" " + str(i) + " ",
|
||||
)
|
||||
for i in reversed(range(num_range))
|
||||
]
|
||||
re_rules = [
|
||||
(re.compile(k, re.IGNORECASE), v)
|
||||
for (k, v) in [
|
||||
# (r"[ ;,.]", " "),
|
||||
(r"\bdouble(?: |-)(\w+)\b", "\\1 \\1"),
|
||||
(r"\btriple(?: |-)(\w+)\b", "\\1 \\1 \\1"),
|
||||
(r"hundred", "00"),
|
||||
(r"\boh\b", " 0 "),
|
||||
(r"\bo\b", " 0 "),
|
||||
]
|
||||
+ num_int_map_ty
|
||||
] + [(re.compile(k), v) for (k, v) in extra_rules]
|
||||
|
||||
def merge_intervals(intervals):
|
||||
# https://codereview.stackexchange.com/a/69249
|
||||
sorted_by_lower_bound = sorted(intervals, key=lambda tup: tup[0])
|
||||
merged = []
|
||||
|
||||
for higher in sorted_by_lower_bound:
|
||||
if not merged:
|
||||
merged.append(higher)
|
||||
else:
|
||||
lower = merged[-1]
|
||||
# test for intersection between lower and higher:
|
||||
# we know via sorting that lower[0] <= higher[0]
|
||||
if higher[0] <= lower[1]:
|
||||
upper_bound = max(lower[1], higher[1])
|
||||
merged[-1] = (
|
||||
lower[0],
|
||||
upper_bound,
|
||||
) # replace by merged interval
|
||||
else:
|
||||
merged.append(higher)
|
||||
return merged
|
||||
|
||||
# merging interval tree for optimal # https://www.geeksforgeeks.org/interval-tree/
|
||||
|
||||
def keep_numeric_literals(w2v_out):
|
||||
# out = w2v_out.lower()
|
||||
out = re.sub(r"[ ;,.]", " ", w2v_out).strip()
|
||||
# out = " " + out.strip() + " "
|
||||
# out = re.sub(r"double (\w+)", "\\1 \\1", out)
|
||||
# out = re.sub(r"triple (\w+)", "\\1 \\1 \\1", out)
|
||||
num_spans = []
|
||||
for (k, v) in re_rules: # [94:]:
|
||||
matches = k.finditer(out)
|
||||
for m in matches:
|
||||
# num_spans.append((k, m.span()))
|
||||
num_spans.append(m.span())
|
||||
# out = re.sub(k, v, out)
|
||||
merged = merge_intervals(num_spans)
|
||||
num_ents = len(merged)
|
||||
keep_out = " ".join((out[s[0] : s[1]] for s in merged))
|
||||
return keep_out, num_ents
|
||||
|
||||
return keep_numeric_literals
|
||||
|
||||
|
||||
def infer_num_replacer(num_range=100, condense=True):
|
||||
entity_rules, vocab = infer_num_rules_vocab(num_range)
|
||||
corrector = vocab_corrector_gen(vocab)
|
||||
post_rules = [(r"[^0-9]", "")] if condense else []
|
||||
replacer, keeper = entity_replacer_keeper(
|
||||
entity_rules=entity_rules, post_rules=post_rules
|
||||
)
|
||||
|
||||
def final_replacer(x):
|
||||
return replacer(corrector(x))
|
||||
|
||||
return final_replacer
|
||||
|
||||
|
||||
def vocab_corrector_gen(vocab):
|
||||
spell = spellchecker.SpellChecker(distance=1)
|
||||
words_to_remove = set(spell.word_frequency.words()) - set(vocab)
|
||||
spell.word_frequency.remove_words(words_to_remove)
|
||||
|
||||
def corrector(inp):
|
||||
return " ".join(
|
||||
[spell.correction(tok) for tok in spell.split_words(inp)]
|
||||
)
|
||||
|
||||
return corrector
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
repl = infer_num_replacer()
|
||||
import pdb
|
||||
|
||||
pdb.set_trace()
|
||||
|
|
@@ -1,7 +1,7 @@
|
|||
from plume.utils import lazy_module
|
||||
import typer
|
||||
|
||||
rpyc = lazy_module('rpyc')
|
||||
rpyc = lazy_module("rpyc")
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
|
@@ -20,7 +20,9 @@ class ASRService(rpyc.Service):
|
|||
# (to finalize the service, if needed)
|
||||
pass
|
||||
|
||||
def exposed_transcribe(self, utterance: bytes): # this is an exposed method
|
||||
def exposed_transcribe(
|
||||
self, utterance: bytes
|
||||
): # this is an exposed method
|
||||
speech_audio = self.asr.transcribe(utterance)
|
||||
return speech_audio
|
||||
|
||||
|
|
@@ -5,15 +5,16 @@ from pathlib import Path
|
|||
from functools import lru_cache
|
||||
|
||||
import typer
|
||||
|
||||
# import rpyc
|
||||
|
||||
# from tqdm import tqdm
|
||||
# from pydub.silence import split_on_silence
|
||||
from plume.utils import lazy_module, lazy_callable
|
||||
from .lazy_import import lazy_module
|
||||
|
||||
rpyc = lazy_module('rpyc')
|
||||
pydub = lazy_module('pydub')
|
||||
split_on_silence = lazy_callable('pydub.silence.split_on_silence')
|
||||
rpyc = lazy_module("rpyc")
|
||||
pydub = lazy_module("pydub")
|
||||
np = lazy_module("numpy")
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
|
@@ -23,7 +24,7 @@ logging.basicConfig(
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
ASR_RPYC_HOST = os.environ.get("JASR_RPYC_HOST", "localhost")
|
||||
ASR_RPYC_HOST = os.environ.get("ASR_RPYC_HOST", "localhost")
|
||||
ASR_RPYC_PORT = int(os.environ.get("ASR_RPYC_PORT", "8044"))
|
||||
|
||||
TRITON_ASR_MODEL = os.environ.get("TRITON_ASR_MODEL", "slu_wav2vec2")
|
||||
|
|
@@ -37,13 +38,16 @@ def transcribe_rpyc_gen(asr_host=ASR_RPYC_HOST, asr_port=ASR_RPYC_PORT):
|
|||
logger.info(f"connecting to asr server at {asr_host}:{asr_port}")
|
||||
try:
|
||||
asr = rpyc.connect(asr_host, asr_port).root
|
||||
logger.info(f"connected to asr server successfully")
|
||||
logger.info("connected to asr server successfully")
|
||||
except ConnectionRefusedError:
|
||||
raise Exception("env-var JASPER_ASR_RPYC_HOST invalid")
|
||||
|
||||
def audio_prep(aud_seg):
|
||||
asr_seg = aud_seg.set_channels(1).set_sample_width(2).set_frame_rate(16000)
|
||||
return asr_seg
|
||||
af = BytesIO()
|
||||
asr_seg.export(af, format="wav")
|
||||
input_audio_bytes = af.getvalue()
|
||||
return input_audio_bytes
|
||||
|
||||
return asr.transcribe, audio_prep
|
||||
|
||||
|
|
@@ -58,9 +62,8 @@ def triton_transcribe_grpc_gen(
|
|||
# overlap=False,
|
||||
sep=" ",
|
||||
):
|
||||
from tritonclient.utils import np_to_triton_dtype
|
||||
from tritonclient.utils import np_to_triton_dtype, InferenceServerException
|
||||
import tritonclient.grpc as grpcclient
|
||||
import numpy as np
|
||||
|
||||
sup_meth = ["chunked", "silence", "whole"]
|
||||
if method not in sup_meth:
|
||||
|
|
@@ -83,13 +86,18 @@ def triton_transcribe_grpc_gen(
|
|||
]
|
||||
inputs[0].set_data_from_numpy(input_audio_data)
|
||||
outputs = [grpcclient.InferRequestedOutput("OUTPUT_TEXT")]
|
||||
response = client.infer(asr_model, inputs, request_id=str(1), outputs=outputs)
|
||||
transcript = response.as_numpy("OUTPUT_TEXT")[0]
|
||||
try:
|
||||
response = client.infer(
|
||||
asr_model, inputs, request_id=str(1), outputs=outputs
|
||||
)
|
||||
transcript = response.as_numpy("OUTPUT_TEXT")[0]
|
||||
except InferenceServerException:
|
||||
transcript = b"[server error]"
|
||||
return transcript.decode("utf-8")
|
||||
|
||||
def chunked_transcriber(aud_seg):
|
||||
if method == "silence":
|
||||
sil_chunks = split_on_silence(
|
||||
sil_chunks = pydub.silence.split_on_silence(
|
||||
aud_seg,
|
||||
min_silence_len=sil_msec,
|
||||
silence_thresh=-50,
|
||||
|
|
@@ -122,9 +130,14 @@ def triton_transcribe_grpc_gen(
|
|||
|
||||
|
||||
@app.command()
|
||||
def file(audio_file: Path, write_file: bool = False, chunked=True):
|
||||
def file(
|
||||
audio_file: Path, write_file: bool = False, chunked: bool = True, rpyc: bool = False, model='slu_wav2vec2'
|
||||
):
|
||||
aseg = pydub.AudioSegment.from_file(audio_file)
|
||||
transcriber, prep = triton_transcribe_grpc_gen()
|
||||
if rpyc:
|
||||
transcriber, prep = transcribe_rpyc_gen()
|
||||
else:
|
||||
transcriber, prep = triton_transcribe_grpc_gen(asr_model=model)
|
||||
transcription = transcriber(prep(aseg))
|
||||
|
||||
typer.echo(transcription)
|
||||
|
|
@@ -0,0 +1,134 @@
|
|||
import logging
|
||||
from .lazy_import import lazy_module
|
||||
|
||||
webrtcvad = lazy_module("webrtcvad")
|
||||
pydub = lazy_module("pydub")
|
||||
|
||||
DEFAULT_CHUNK_DUR = 30
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def is_frame_voice(vad, seg, chunk_dur):
|
||||
return (
|
||||
True
|
||||
if (
|
||||
seg.duration_seconds == chunk_dur / 1000
|
||||
and vad.is_speech(seg.raw_data, seg.frame_rate)
|
||||
)
|
||||
else False
|
||||
)
|
||||
|
||||
|
||||
class VADUtterance(object):
|
||||
"""docstring for VADUtterance."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_silence=500,
|
||||
min_utterance=280,
|
||||
max_utterance=20000,
|
||||
chunk_dur=DEFAULT_CHUNK_DUR,
|
||||
start_cycles=3,
|
||||
aggression=1,
|
||||
):
|
||||
super(VADUtterance, self).__init__()
|
||||
self.vad = webrtcvad.Vad(aggression)
|
||||
self.chunk_dur = chunk_dur
|
||||
# duration in millisecs
|
||||
self.max_sil = max_silence
|
||||
self.min_utt = min_utterance
|
||||
self.max_utt = max_utterance
|
||||
self.speech_start = start_cycles * chunk_dur
|
||||
|
||||
def __repr__(self):
|
||||
return f"VAD(max_silence={self.max_sil},min_utterance:{self.min_utt},max_utterance:{self.max_utt})"
|
||||
|
||||
def stream_segments(self, audio_seg):
|
||||
stream_seg = audio_seg.set_channels(1).set_sample_width(2).set_frame_rate(16000)
|
||||
silence_buffer = pydub.AudioSegment.empty()
|
||||
voice_buffer = pydub.AudioSegment.empty()
|
||||
silence_threshold = False
|
||||
for c in stream_seg[:: self.chunk_dur]:
|
||||
voice_frame = is_frame_voice(self.vad, c, self.chunk_dur)
|
||||
# logger.info(f"is audio stream voice? {voice_frame}")
|
||||
if voice_frame:
|
||||
silence_threshold = False
|
||||
voice_buffer += c
|
||||
silence_buffer = pydub.AudioSegment.empty()
|
||||
else:
|
||||
silence_buffer += c
|
||||
voc_dur = len(voice_buffer)
|
||||
sil_dur = len(silence_buffer)
|
||||
|
||||
if voc_dur >= self.max_utt:
|
||||
# logger.info(
|
||||
# f"detected voice overflow: voice duration {voice_buffer.duration_seconds}"
|
||||
# )
|
||||
yield voice_buffer
|
||||
voice_buffer = pydub.AudioSegment.empty()
|
||||
|
||||
if sil_dur >= self.max_sil:
|
||||
if voc_dur >= self.min_utt:
|
||||
# logger.info(
|
||||
# f"detected silence: voice duration {voice_buffer.duration_seconds}"
|
||||
# )
|
||||
yield voice_buffer
|
||||
voice_buffer = pydub.AudioSegment.empty()
|
||||
# ignore/clear voice once the silence threshold has been reached
|
||||
if not silence_threshold:
|
||||
silence_threshold = True
|
||||
|
||||
# if voice_buffer:
|
||||
# yield voice_buffer
|
||||
|
||||
if self.min_utt < len(voice_buffer) < self.max_utt:
|
||||
yield voice_buffer
|
||||
|
||||
# def stream_utterance(self, audio_stream):
|
||||
# silence_buffer = pydub.AudioSegment.empty()
|
||||
# voice_buffer = pydub.AudioSegment.empty()
|
||||
# silence_threshold = False
|
||||
# for avf in audio_stream:
|
||||
# audio_bytes = avf.to_ndarray().tobytes()
|
||||
# c = (
|
||||
# pydub.AudioSegment(
|
||||
# data=audio_bytes,
|
||||
# frame_rate=avf.sample_rate,
|
||||
# channels=len(avf.layout.channels),
|
||||
# sample_width=avf.format.bytes,
|
||||
# )
|
||||
# .set_channels(1)
|
||||
# .set_sample_width(2)
|
||||
# .set_frame_rate(16000)
|
||||
# )
|
||||
# voice_frame = is_frame_voice(self.vad, c, self.chunk_dur)
|
||||
# # logger.info(f"is audio stream voice? {voice_frame}")
|
||||
# if voice_frame:
|
||||
# silence_threshold = False
|
||||
# voice_buffer += c
|
||||
# silence_buffer = pydub.AudioSegment.empty()
|
||||
# else:
|
||||
# silence_buffer += c
|
||||
# voc_dur = voice_buffer.duration_seconds * 1000
|
||||
# sil_dur = silence_buffer.duration_seconds * 1000
|
||||
#
|
||||
# if voc_dur >= self.max_utt:
|
||||
# # logger.info(
|
||||
# # f"detected voice overflow: voice duration {voice_buffer.duration_seconds}"
|
||||
# # )
|
||||
# yield voice_buffer
|
||||
# voice_buffer = pydub.AudioSegment.empty()
|
||||
#
|
||||
# if sil_dur >= self.max_sil:
|
||||
# if voc_dur >= self.min_utt:
|
||||
# # logger.info(
|
||||
# # f"detected silence: voice duration {voice_buffer.duration_seconds}"
|
||||
# # )
|
||||
# yield voice_buffer
|
||||
# voice_buffer = pydub.AudioSegment.empty()
|
||||
# # ignore/clear voice if silence reached threshold or indent the statement
|
||||
# if not silence_threshold:
|
||||
# silence_threshold = True
|
||||
#
|
||||
# if voice_buffer:
|
||||
# yield voice_buffer
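A usage sketch for VADUtterance; `_vad_utterance_sketch` and the input file name are hypothetical, and stream_segments already resamples to 16 kHz mono internally.
def _vad_utterance_sketch():
    seg = pydub.AudioSegment.from_file("call_recording.wav")
    vad = VADUtterance(max_silence=400, min_utterance=300, aggression=2)
    for utt in vad.stream_segments(seg):
        print(round(utt.duration_seconds, 2))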
|
||||
|
|
@@ -0,0 +1,317 @@
|
|||
import re
|
||||
|
||||
|
||||
def entity_replacer_keeper(pre_rules=[], entity_rules=[], post_rules=[]):
|
||||
# def replacer_keeper_gen():
|
||||
pre_rules_c = [(re.compile(k), v) for (k, v) in pre_rules]
|
||||
entity_rules_c = [(re.compile(k, re.IGNORECASE), v) for (k, v) in entity_rules]
|
||||
post_rules_c = [(re.compile(k), v) for (k, v) in post_rules]
|
||||
|
||||
re_rules = pre_rules_c + entity_rules_c + post_rules_c
|
||||
|
||||
def replacer(w2v_out):
|
||||
out = w2v_out
|
||||
for (k, v) in re_rules:
|
||||
out = k.sub(v, out)
|
||||
return out
|
||||
|
||||
def merge_intervals(intervals):
|
||||
# https://codereview.stackexchange.com/a/69249
|
||||
sorted_by_lower_bound = sorted(intervals, key=lambda tup: tup[0])
|
||||
merged = []
|
||||
|
||||
for higher in sorted_by_lower_bound:
|
||||
if not merged:
|
||||
merged.append(higher)
|
||||
else:
|
||||
lower = merged[-1]
|
||||
# test for intersection between lower and higher:
|
||||
# we know via sorting that lower[0] <= higher[0]
|
||||
if higher[0] <= lower[1]:
|
||||
upper_bound = max(lower[1], higher[1])
|
||||
merged[-1] = (
|
||||
lower[0],
|
||||
upper_bound,
|
||||
) # replace by merged interval
|
||||
else:
|
||||
merged.append(higher)
|
||||
return merged
|
||||
|
||||
# merging interval tree for optimal # https://www.geeksforgeeks.org/interval-tree/
|
||||
|
||||
def keep_literals(w2v_out):
|
||||
# out = re.sub(r"[ ;,.]", " ", w2v_out).strip()
|
||||
out = w2v_out
|
||||
for (k, v) in pre_rules_c:
|
||||
out = k.sub(v, out)
|
||||
num_spans = []
|
||||
for (k, v) in entity_rules_c: # [94:]:
|
||||
matches = k.finditer(out)
|
||||
for m in matches:
|
||||
# num_spans.append((k, m.span()))
|
||||
num_spans.append(m.span())
|
||||
# out = re.sub(k, v, out)
|
||||
merged = merge_intervals(num_spans)
|
||||
num_ents = len(merged)
|
||||
keep_out = " ".join((out[s[0] : s[1]] for s in merged))
|
||||
for (k, v) in post_rules_c:
|
||||
keep_out = k.sub(v, keep_out)
|
||||
return keep_out, num_ents
|
||||
|
||||
return replacer, keep_literals
|
||||
|
||||
|
||||
def default_num_only_rules(num_range):
|
||||
entity_rules = (
|
||||
[
|
||||
("\\bninety-nine\\b", "99"),
|
||||
("\\bninety-eight\\b", "98"),
|
||||
("\\bninety-seven\\b", "97"),
|
||||
("\\bninety-six\\b", "96"),
|
||||
("\\bninety-five\\b", "95"),
|
||||
("\\bninety-four\\b", "94"),
|
||||
("\\bninety-three\\b", "93"),
|
||||
("\\bninety-two\\b", "92"),
|
||||
("\\bninety-one\\b", "91"),
|
||||
("\\bninety\\b", "90"),
|
||||
("\\beighty-nine\\b", "89"),
|
||||
("\\beighty-eight\\b", "88"),
|
||||
("\\beighty-seven\\b", "87"),
|
||||
("\\beighty-six\\b", "86"),
|
||||
("\\beighty-five\\b", "85"),
            ("\\beighty-four\\b", "84"),
            ("\\beighty-three\\b", "83"),
            ("\\beighty-two\\b", "82"),
            ("\\beighty-one\\b", "81"),
            ("\\beighty\\b", "80"),
            ("\\bseventy-nine\\b", "79"),
            ("\\bseventy-eight\\b", "78"),
            ("\\bseventy-seven\\b", "77"),
            ("\\bseventy-six\\b", "76"),
            ("\\bseventy-five\\b", "75"),
            ("\\bseventy-four\\b", "74"),
            ("\\bseventy-three\\b", "73"),
            ("\\bseventy-two\\b", "72"),
            ("\\bseventy-one\\b", "71"),
            ("\\bseventy\\b", "70"),
            ("\\bsixty-nine\\b", "69"),
            ("\\bsixty-eight\\b", "68"),
            ("\\bsixty-seven\\b", "67"),
            ("\\bsixty-six\\b", "66"),
            ("\\bsixty-five\\b", "65"),
            ("\\bsixty-four\\b", "64"),
            ("\\bsixty-three\\b", "63"),
            ("\\bsixty-two\\b", "62"),
            ("\\bsixty-one\\b", "61"),
            ("\\bsixty\\b", "60"),
            ("\\bfifty-nine\\b", "59"),
            ("\\bfifty-eight\\b", "58"),
            ("\\bfifty-seven\\b", "57"),
            ("\\bfifty-six\\b", "56"),
            ("\\bfifty-five\\b", "55"),
            ("\\bfifty-four\\b", "54"),
            ("\\bfifty-three\\b", "53"),
            ("\\bfifty-two\\b", "52"),
            ("\\bfifty-one\\b", "51"),
            ("\\bfifty\\b", "50"),
            ("\\bforty-nine\\b", "49"),
            ("\\bforty-eight\\b", "48"),
            ("\\bforty-seven\\b", "47"),
            ("\\bforty-six\\b", "46"),
            ("\\bforty-five\\b", "45"),
            ("\\bforty-four\\b", "44"),
            ("\\bforty-three\\b", "43"),
            ("\\bforty-two\\b", "42"),
            ("\\bforty-one\\b", "41"),
            ("\\bforty\\b", "40"),
            ("\\bthirty-nine\\b", "39"),
            ("\\bthirty-eight\\b", "38"),
            ("\\bthirty-seven\\b", "37"),
            ("\\bthirty-six\\b", "36"),
            ("\\bthirty-five\\b", "35"),
            ("\\bthirty-four\\b", "34"),
            ("\\bthirty-three\\b", "33"),
            ("\\bthirty-two\\b", "32"),
            ("\\bthirty-one\\b", "31"),
            ("\\bthirty\\b", "30"),
            ("\\btwenty-nine\\b", "29"),
            ("\\btwenty-eight\\b", "28"),
            ("\\btwenty-seven\\b", "27"),
            ("\\btwenty-six\\b", "26"),
            ("\\btwenty-five\\b", "25"),
            ("\\btwenty-four\\b", "24"),
            ("\\btwenty-three\\b", "23"),
            ("\\btwenty-two\\b", "22"),
            ("\\btwenty-one\\b", "21"),
            ("\\btwenty\\b", "20"),
            ("\\bnineteen\\b", "19"),
            ("\\beighteen\\b", "18"),
            ("\\bseventeen\\b", "17"),
            ("\\bsixteen\\b", "16"),
            ("\\bfifteen\\b", "15"),
            ("\\bfourteen\\b", "14"),
            ("\\bthirteen\\b", "13"),
            ("\\btwelve\\b", "12"),
            ("\\beleven\\b", "11"),
            ("\\bten\\b", "10"),
            ("\\bnine\\b", "9"),
            ("\\beight\\b", "8"),
            ("\\bseven\\b", "7"),
            ("\\bsix\\b", "6"),
            ("\\bfive\\b", "5"),
            ("\\bfour\\b", "4"),
            ("\\bthree\\b", "3"),
            ("\\btwo\\b", "2"),
            ("\\bone\\b", "1"),
            ("\\bzero\\b", "0"),
        ]
        + [
            (
                r"\b" + str(i) + r"\b",
                str(i),
            )
            for i in reversed(range(10))
        ]
        + [
            (r"\bhundred\b", "00"),
        ]
    )
    return entity_rules


def default_num_rules(num_range):
    entity_rules = default_num_only_rules(num_range) + [
        (r"\boh\b", " 0 "),
        (r"\bo\b", " 0 "),
        (r"\bdouble(?: |-)(\w+|\d+)\b", "\\1 \\1"),
        (r"\btriple(?: |-)(\w+|\d+)\b", "\\1 \\1 \\1"),
    ]
    return entity_rules


def default_alnum_rules(num_range, oh_is_zero):
    oh_is_zero_rules = [
        (r"\boh\b", "0"),
        (r"\bo\b", "0"),
    ]
    entity_rules = (
        default_num_only_rules(num_range)
        + (oh_is_zero_rules if oh_is_zero else [(r"\boh\b", "o")])
        + [
            (r"\bdouble(?: |-)(\w+|\d+)\b", "\\1 \\1"),
            (r"\btriple(?: |-)(\w+|\d+)\b", "\\1 \\1 \\1"),
            (r"\b([a-zA-Z])\b", "\\1"),
        ]
    )
    return entity_rules


def num_replacer(num_range=100, condense=True):
    entity_rules = default_num_rules(num_range)
    post_rules = [(r"[^0-9]", "")] if condense else []
    # post_rules = []
    replacer, keeper = entity_replacer_keeper(
        entity_rules=entity_rules, post_rules=post_rules
    )
    return replacer


def num_keeper(num_range=100):
    entity_rules = default_num_rules(num_range)
    pre_rules = [(r"[ ;,.]", " ")]
    post_rules = []
    replacer, keeper = entity_replacer_keeper(
        pre_rules=pre_rules, entity_rules=entity_rules, post_rules=post_rules
    )
    return keeper


def alnum_replacer(num_range=100, oh_is_zero=False, condense=True):
    entity_rules = default_alnum_rules(num_range, oh_is_zero)
    # entity_rules = default_num_rules(num_range)
    pre_rules = [(r"[ ;,.]", " "), (r"[']", "")]

    def upper_case(match_obj):
        char_elem = match_obj.group(0)
        return char_elem.upper()

    post_rules = (
        [
            # (r"\b[a-zA-Z]+\'[a-zA-Z]+\b", ""),
            (r"\b[a-zA-Z]{2,}\b", ""),
            (r"[^a-zA-Z0-9]", ""),
            (r"([a-z].*)", upper_case),
        ]
        if condense
        else []
    )
    replacer, keeper = entity_replacer_keeper(
        pre_rules=pre_rules, entity_rules=entity_rules, post_rules=post_rules
    )
    return replacer


def alnum_keeper(num_range=100, oh_is_zero=False):
    entity_rules = default_alnum_rules(num_range, oh_is_zero)
    pre_rules = [(r"[ ;,.]", " "), (r"[']", "")]
    post_rules = []
    replacer, keeper = entity_replacer_keeper(
        pre_rules=pre_rules, entity_rules=entity_rules, post_rules=post_rules
    )
    return keeper

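The factories above turn the regex rule tables into ready-to-use callables. The sketch below is a usage example added for illustration and is not part of the commit; the import path mirrors the new test module and every expected value is copied from the assertions that follow.

```python
from plume.utils import num_replacer, num_keeper, alnum_replacer

to_digits = num_replacer()  # collapse spoken numbers into a digit string
to_digits("not thirty-two fifty-nine")  # -> "3259"

keep_numbers = num_keeper()  # keep number-bearing tokens plus a token count
keep_numbers(
    "my phone number is One hundred fifty-eight not 5 oh o fifty more"
)  # -> ("One hundred fifty-eight 5 oh o fifty", 7)

to_alnum = alnum_replacer(oh_is_zero=True)  # single letters survive, "oh" -> 0
to_alnum("One Z fifty-eight 5 oh o b fifty")  # -> "1Z58500B50"
```
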
def test_num():
    num_extractor = num_replacer()
    keeper = num_keeper()
    num_only_replacer = num_replacer(condense=False)
    assert num_extractor("thirty-two") == "32"
    assert num_extractor("not thirty-two fifty-nine") == "3259"
    assert num_extractor(" triPle 5 fifty 3") == "555503"
    assert num_only_replacer(" triPle 5 fifty 3") == " 5 5 5 50 3"
    assert num_extractor("douBle 2 130") == "22130"
    assert num_extractor("It is a One fifty eIght 5 fifty ") == "1508550"
    assert (
        num_only_replacer(" It is a One fifty eIght 5 fifty ")
        == " It is a 1 50 8 5 50 "
    )
    assert num_extractor("One fifty-eight 5 oh o fifty") == "15850050"
    assert keeper(
        "my phone number is One hundred fifty-eight not 5 oh o fifty more"
    ) == ("One hundred fifty-eight 5 oh o fifty", 7)


def test_alnum():
    extractor_oh = alnum_replacer(oh_is_zero=True)
    extractor = alnum_replacer()
    keeper = alnum_keeper()
    only_replacer = alnum_replacer(condense=False)
    assert extractor("I'm thirty-two") == "32"
    assert extractor("a thirty-two") == "A32"
    assert extractor("not a b thirty-two fifty-nine") == "AB3259"
    assert extractor(" triPle 5 fifty 3") == "555503"
    assert only_replacer(" triPle 5 fifty 3") == " 5 5 5 50 3"
    assert extractor("douBle 2 130") == "22130"
    assert extractor("It is a One b fifty eIght A Z 5 fifty ") == "A1B508AZ550"
    assert (
        only_replacer(" It's a ; One b fifty eIght A Z 5 fifty ")
        == " Its a 1 b 50 8 A Z 5 50 "
    )
    assert (
        only_replacer(" I'm is a One b fifty eIght A Z 5 fifty ")
        == " Im is a 1 b 50 8 A Z 5 50 "
    )
    assert extractor("One Z fifty-eight 5 oh o b fifty") == "1Z585OOB50"
    assert extractor_oh("One Z fifty-eight 5 oh o b fifty") == "1Z58500B50"
    assert keeper(
        "I'll phone number One hundred n fifty-eight not 5 oh o fifty A B more"
    ) == ("One hundred n fifty-eight 5 oh o fifty A B", 10)
    assert keeper("I'm One hundred n fifty-eight not 5 oh o fifty A B more") == (
        "One hundred n fifty-eight 5 oh o fifty A B",
        10,
    )

    assert keeper("I am One hundred n fifty-eight not 5 oh o fifty A B more") == (
        "I One hundred n fifty-eight 5 oh o fifty A B",
        11,
    )

@ -0,0 +1,105 @@
from plume.utils import (
    num_replacer,
    num_keeper,
    alnum_replacer,
    alnum_keeper,
    random_segs,
)
import numpy
import random as rand
import pytest


def test_num_replacer_keeper():
    num_extractor = num_replacer()
    num_only_replacer = num_replacer(condense=False)
    assert num_extractor("thirty-two") == "32"
    assert num_extractor("not thirty-two fifty-nine") == "3259"
    assert num_extractor(" triPle 5 fifty 3") == "555503"
    assert num_only_replacer(" triPle 5 fifty 3") == " 5 5 5 50 3"
    assert num_extractor("douBle 2 130") == "22130"
    assert num_extractor("It is a One fifty eIght 5 fifty ") == "1508550"
    assert (
        num_only_replacer(" It is a One fifty eIght 5 fifty ")
        == " It is a 1 50 8 5 50 "
    )
    assert num_extractor("One fifty-eight 5 oh o fifty") == "15850050"
    keeper = num_keeper()
    assert keeper(
        "my phone number is One hundred fifty-eight not 5 oh o fifty more"
    ) == ("One hundred fifty-eight 5 oh o fifty", 7)


def test_alnum_replacer():
    extractor_oh = alnum_replacer(oh_is_zero=True)
    extractor = alnum_replacer()
    only_replacer = alnum_replacer(condense=False)
    assert extractor("5 oh i c 3") == "5OIC3"
    assert extractor("I am, oh it is 3. I will") == "3"
    assert extractor("I oh o 3") == "IOO3"
    assert extractor("I will 3 I") == "3I"
    assert extractor("I'm thirty-two") == "32"
    assert extractor("I am thirty-two") == "32"
    assert extractor("I j thirty-two") == "IJ32"
    assert extractor("a thirty-two") == "A32"
    assert extractor("not a b thirty-two fifty-nine") == "AB3259"
    assert extractor(" triPle 5 fifty 3") == "555503"
    assert only_replacer(" triPle 5 fifty 3") == " 5 5 5 50 3"
    assert extractor("douBle 2 130") == "22130"
    assert extractor("It is a One b fifty eIght A Z 5 fifty ") == "A1B508AZ550"
    assert (
        only_replacer(" It's a ; One b fifty eIght A Z 5 fifty ")
        == " Its a 1 b 50 8 A Z 5 50 "
    )
    assert (
        only_replacer(" I'm is a One b fifty eIght A Z 5 fifty ")
        == " Im is a 1 b 50 8 A Z 5 50 "
    )
    assert extractor("One Z fifty-eight 5 oh o b fifty") == "1Z585OOB50"
    assert extractor_oh("One Z fifty-eight 5 oh o b fifty") == "1Z58500B50"
    assert (
        extractor("I One hundred n fifty-eight not 5 oh o fifty A B more")
        == "I100N585OO50AB"
    )


def test_alnum_keeper():
    keeper = alnum_keeper()
    assert keeper("I One hundred n fifty-eight not 5 oh o fifty A B more") == (
        "I One hundred n fifty-eight 5 oh o fifty A B",
        11,
    )
    assert keeper(
        "I'll phone number One hundred n fifty-eight not 5 oh o fifty A B more"
    ) == ("One hundred n fifty-eight 5 oh o fifty A B", 10)
    assert keeper(
        "I'm One hundred n fifty-eight not 5 oh o fifty A B more"
    ) == (
        "One hundred n fifty-eight 5 oh o fifty A B",
        10,
    )

    assert keeper(
        "I am One hundred n fifty-eight not 5 oh o fifty A B more"
    ) == (
        "One hundred n fifty-eight 5 oh o fifty A B",
        10,
    )


@pytest.fixture
def random():
    rand.seed(0)
    numpy.random.seed(0)


def test_random_segs(random):
    segs = random_segs(100000, 1000, 3000)

    def segs_comply(segs, min, max):
        for (start, end) in segs:
            if end - start < min or end - start > max:
                return False
        return True

    assert segs_comply(segs, 1000, 3000) == True

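The `random_segs` test only constrains segment lengths, so everything beyond that is guesswork. As a hypothetical illustration (assuming the first argument is the overall span and the values are milliseconds), the returned `(start, end)` pairs could be used to slice a pydub `AudioSegment`, which this package already depends on:

```python
import pydub
from plume.utils import random_segs

# Assumptions for this sketch: len(audio) and the (start, end) pairs are in
# milliseconds, and "call.wav" is a hypothetical input file.
audio = pydub.AudioSegment.from_wav("call.wav")
segs = random_segs(len(audio), 1000, 3000)  # 1-3 s segments, per the test bounds
clips = [audio[start:end] for start, end in segs]  # pydub slices by milliseconds
```
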
@ -0,0 +1,17 @@
from plume.utils.regentity import infer_num_replacer


def test_infer_num():
    repl = infer_num_replacer()

    assert (
        repl(
            "SIX NINE TRIPL EIGHT SIX SIX DOULE NINE THREE ZERO TWO SEVENT-ONE"
        )
        == "69888669930271"
    )

    assert (
        repl("SIX NINE FSIX EIGHT IGSIX SIX NINE NINE THRE ZERO TWO SEVEN ONE")
        == "6968669930271"
    )

@ -0,0 +1,13 @@
# tox (https://tox.readthedocs.io/) is a tool for running tests
# in multiple virtualenvs. This configuration file will run the
# test suite on all supported python versions. To use it, "pip install tox"
# and then run "tox" from this directory.

[tox]
envlist = py38

[testenv]
deps =
    pytest
commands =
    pytest