mirror of
https://github.com/malarinv/plume-asr.git
synced 2026-03-08 04:12:35 +00:00
1. refactor package root to src/ layout
2. add framwork suffix for models 3. change black max columns to 79 4. add tests 5. integrate vad, encrypt and refactor manifest, regentity, extended_path, audio, parallel utils 6. added ui utils for encrypted preview 7. wip marblenet model 8. added transformers based wav2vec2 inference 9. update readme and manifest 10. add deploy setup target
This commit is contained in:
23
src/plume/cli/__init__.py
Normal file
23
src/plume/cli/__init__.py
Normal file
@@ -0,0 +1,23 @@
|
||||
import typer

from ..utils import app as utils_app
from .data import app as data_app
from ..ui import app as ui_app
from .train import app as train_app
from .eval import app as eval_app
from .serve import app as serve_app

# Root CLI application; each sub-app is mounted in registration order.
app = typer.Typer()

for sub_app in (data_app, ui_app, train_app, eval_app, serve_app, utils_app):
    app.add_typer(sub_app)


def main():
    """Entry point for the plume command line interface."""
    app()


if __name__ == "__main__":
    main()
|
||||
5
src/plume/cli/__main__.py
Normal file
5
src/plume/cli/__main__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
# Package runner: `python -m plume.cli` delegates to the CLI's main().
from . import main


if __name__ == "__main__":
    main()
|
||||
530
src/plume/cli/data/__init__.py
Normal file
530
src/plume/cli/data/__init__.py
Normal file
@@ -0,0 +1,530 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
from random import shuffle
|
||||
from typing import List
|
||||
from itertools import chain
|
||||
|
||||
# from sklearn.model_selection import train_test_split
|
||||
from tqdm import tqdm
|
||||
import shutil
|
||||
import typer
|
||||
|
||||
from plume.utils import (
|
||||
asr_manifest_reader,
|
||||
asr_manifest_writer,
|
||||
ExtendedPath,
|
||||
duration_str,
|
||||
generate_filter_map,
|
||||
get_mongo_conn,
|
||||
tscript_uuid_fname,
|
||||
lazy_callable,
|
||||
lazy_module,
|
||||
wav_cryptor,
|
||||
text_cryptor,
|
||||
parallel_apply,
|
||||
)
|
||||
|
||||
from ...models.wav2vec2.data import app as wav2vec2_app
|
||||
from .generate import app as generate_app
|
||||
|
||||
soundfile = lazy_module("soundfile")
|
||||
pydub = lazy_module("pydub")
|
||||
train_test_split = lazy_callable("sklearn.model_selection.train_test_split")
|
||||
|
||||
# Typer app for the `data` command group; sub-groups are mounted below.
app = typer.Typer()
app.add_typer(generate_app, name="generate")
app.add_typer(wav2vec2_app, name="wav2vec2")


@app.callback()
def data():
    """
    data sub commands
    """
|
||||
|
||||
|
||||
@app.command()
def fix_path(dataset_path: Path, force: bool = False):
    """Create the missing manifest counterpart for a dataset.

    `manifest.json` holds paths relative to `dataset_path`;
    `abs_manifest.json` holds absolute paths. Whichever one is missing is
    regenerated from the other. `force` re-generates `abs_manifest.json`
    even when it already exists.
    """
    manifest_path = dataset_path / Path("manifest.json")
    real_manifest_path = dataset_path / Path("abs_manifest.json")

    def fix_real_path():
        # relative entries -> absolute entries
        for i in asr_manifest_reader(manifest_path):
            i["audio_filepath"] = str(
                (dataset_path / Path(i["audio_filepath"])).absolute()
            )
            yield i

    def fix_rel_path():
        # absolute entries -> entries relative to the dataset root
        for i in asr_manifest_reader(real_manifest_path):
            i["audio_filepath"] = str(
                Path(i["audio_filepath"]).relative_to(dataset_path)
            )
            yield i

    if not manifest_path.exists() and not real_manifest_path.exists():
        typer.echo("Invalid dataset directory")
        # Bug fix: previously fell through and crashed trying to read a
        # manifest that does not exist; abort with a non-zero exit instead.
        raise typer.Exit(code=1)
    if not real_manifest_path.exists() or force:
        asr_manifest_writer(real_manifest_path, fix_real_path())
    if not manifest_path.exists():
        asr_manifest_writer(manifest_path, fix_rel_path())
|
||||
|
||||
|
||||
@app.command()
def merge(src_dataset_paths: List[Path], dest_dataset_path: Path):
    """Concatenate the absolute manifests of several datasets into one."""
    abs_manifest_name = Path("abs_manifest.json")
    readers = [
        asr_manifest_reader(src / abs_manifest_name)
        for src in src_dataset_paths
    ]
    dest_dataset_path.mkdir(parents=True, exist_ok=True)
    asr_manifest_writer(
        dest_dataset_path / abs_manifest_name, chain(*readers)
    )
|
||||
|
||||
|
||||
@app.command()
def training_split(dataset_path: Path, test_size: float = 0.03):
    """Split the absolute manifest into train/test manifests in-place."""
    manifest_path = dataset_path / Path("abs_manifest.json")
    if not manifest_path.exists():
        # Build abs_manifest.json from manifest.json first.
        fix_path(dataset_path)
    samples = list(asr_manifest_reader(manifest_path))
    train_set, test_set = train_test_split(samples, test_size=test_size)
    for out_name, subset in (
        ("train_manifest.json", train_set),
        ("test_manifest.json", test_set),
    ):
        asr_manifest_writer(manifest_path.with_name(out_name), subset)
|
||||
|
||||
|
||||
@app.command()
def parts_split_by_size(
    dataset_path: Path,
    test_size: float = 0.03,
    split_prefix_names: List[str] = ["train", "test"],
):
    """Split a dataset into two sibling datasets by sample fraction.

    Copies the audio of each split into `<dataset>_<prefix>/wavs` and
    writes that split's `abs_manifest.json` (then `fix_path` derives the
    relative manifest).
    """
    manifest_path = dataset_path / Path("abs_manifest.json")
    if not manifest_path.exists():
        fix_path(dataset_path)
    asr_data = list(asr_manifest_reader(manifest_path))
    train_pnr, test_pnr = train_test_split(asr_data, test_size=test_size)
    dest_paths = [
        (dataset_path.parent / (dataset_path.name + "_" + spn), sd)
        for (spn, sd) in zip(split_prefix_names, [train_pnr, test_pnr])
    ]
    for dest_path, manifest_data in dest_paths:
        wav_dir = dest_path / Path("wavs")
        wav_dir.mkdir(exist_ok=True, parents=True)
        abs_manifest_path = ExtendedPath(dest_path / Path("abs_manifest.json"))
        for md in manifest_data:
            orig_path = Path(md["audio_filepath"])
            new_path = wav_dir / Path(orig_path.name)
            shutil.copy(orig_path, new_path)
            # Bug fix: abs_manifest.json must hold absolute paths
            # (parts_split_by_dur already does this); previously a
            # possibly-relative path was written here.
            md["audio_filepath"] = str(new_path.absolute())
            # Robustness: entries without an "audio_path" key used to raise.
            md.pop("audio_path", None)
        abs_manifest_path.write_jsonl(manifest_data)
        fix_path(dest_path)
|
||||
|
||||
|
||||
@app.command()
def parts_split_by_dur(
    dataset_path: Path,
    dur_sec: int = 7200,
    suffix_name: List[str] = ["train", "test"],
):
    """Split a dataset into two sibling datasets by audio duration.

    After shuffling, samples are assigned to the second split until
    roughly `dur_sec` seconds have accumulated; the rest go to the first.
    """
    manifest_path = dataset_path / Path("abs_manifest.json")
    if not manifest_path.exists():
        fix_path(dataset_path)
    asr_data = list(asr_manifest_reader(manifest_path))

    def dur_split(dataset, dur_seconds):
        # Fill test_set until the duration budget is consumed; note the
        # counter is updated after assignment, so the budget is a lower
        # bound on the test split's duration.
        shuffle(dataset)
        counter_dur = 0
        train_set, test_set = [], []
        for d in dataset:
            if counter_dur <= dur_seconds:
                test_set.append(d)
            else:
                train_set.append(d)
            counter_dur += d["duration"]
        return train_set, test_set

    train_pnr, test_pnr = dur_split(asr_data, dur_sec)
    dest_paths = [
        (dataset_path.parent / (dataset_path.name + "_" + spn), sd)
        for (spn, sd) in zip(suffix_name, [train_pnr, test_pnr])
    ]
    for dest_path, manifest_data in dest_paths:
        wav_dir = dest_path / Path("wavs")
        wav_dir.mkdir(exist_ok=True, parents=True)
        abs_manifest_path = ExtendedPath(dest_path / Path("abs_manifest.json"))
        for md in manifest_data:
            orig_path = Path(md["audio_filepath"])
            new_path = wav_dir / Path(orig_path.name)
            shutil.copy(orig_path, new_path)
            md["audio_filepath"] = str(new_path.absolute())
            # Robustness: entries without an "audio_path" key used to raise.
            md.pop("audio_path", None)
        abs_manifest_path.write_jsonl(manifest_data)
        fix_path(dest_path.absolute())
|
||||
|
||||
|
||||
@app.command()
def validate(dataset_path: Path):
    """Validate train/test manifests: each line must parse as JSON and its
    referenced audio file must exist; prints per-manifest total duration."""
    from natural.date import compress
    from datetime import timedelta

    for mf_type in ["train_manifest.json", "test_manifest.json"]:
        data_file = dataset_path / Path(mf_type)
        print(f"validating {data_file}.")
        with Path(data_file).open("r") as pf:
            pnr_jsonl = pf.readlines()
        duration = 0
        for (i, s) in enumerate(pnr_jsonl):
            try:
                d = json.loads(s)
                duration += d["duration"]
                audio_file = data_file.parent / Path(d["audio_filepath"])
                if not audio_file.exists():
                    raise OSError(f"File {audio_file} not found")
            # Bug fix: was `except BaseException`, which also swallows
            # KeyboardInterrupt/SystemExit; narrow to ordinary errors.
            except Exception as e:
                print(f'failed on {i} with "{e}"')
        # Renamed from `duration_str`, which shadowed the helper of the
        # same name imported from plume.utils at module level.
        total_dur_text = compress(timedelta(seconds=duration), pad=" ")
        print(
            f"no errors found. seems like a valid {mf_type}. contains {total_dur_text} of audio"
        )
|
||||
|
||||
|
||||
@app.command()
def filter(
    src_dataset_path: Path,
    dest_dataset_path: Path,
    kind: str = "",
):
    """Copy a dataset through a named filter into a destination dataset."""
    dest_manifest = dest_dataset_path / Path("manifest.json")
    data_file = src_dataset_path / Path("manifest.json")
    dest_wav_dir = dest_dataset_path / Path("wavs")
    dest_wav_dir.mkdir(exist_ok=True, parents=True)
    filter_kind_map = generate_filter_map(
        src_dataset_path, dest_dataset_path, data_file
    )

    chosen = filter_kind_map.get(kind)
    if chosen is None:
        typer.echo(f"filter kind - {kind} not implemented")
        typer.echo(f"select one of {', '.join(filter_kind_map.keys())}")
    else:
        asr_manifest_writer(dest_manifest, chosen())
|
||||
|
||||
|
||||
@app.command()
def info(dataset_path: Path):
    """Print duration and empty-transcript statistics for each manifest
    variant (manifest / abs_ / train_ / test_) present in the dataset."""
    for k in ["", "abs_", "train_", "test_"]:
        mf_wav_duration = (
            real_duration
        ) = max_duration = empty_duration = empty_count = total_count = 0
        data_file = dataset_path / Path(f"{k}manifest.json")
        if data_file.exists():
            print(f"stats on {data_file}")
            for s in ExtendedPath(data_file).read_jsonl():
                total_count += 1
                mf_wav_duration += s["duration"]
                if s["text"] == "":
                    empty_count += 1
                    empty_duration += s["duration"]
                wav_path = str(dataset_path / Path(s["audio_filepath"]))
                # Perf fix: probe each wav once instead of three times.
                wav_duration = soundfile.info(wav_path).duration
                if max_duration < wav_duration:
                    max_duration = wav_duration
                real_duration += wav_duration

            print(
                f"max audio duration : {duration_str(max_duration, show_hours=True)}"
            )
            print(
                f"total audio duration : {duration_str(mf_wav_duration, show_hours=True)}"
            )
            print(
                f"total real audio duration : {duration_str(real_duration, show_hours=True)}"
            )
            print(
                f"total content duration : {duration_str(mf_wav_duration-empty_duration, show_hours=True)}"
            )
            print(
                f"total empty duration : {duration_str(empty_duration, show_hours=True)}"
            )
            # Robustness: an empty manifest used to raise ZeroDivisionError.
            if total_count:
                print(
                    f"total empty samples : {empty_count}/{total_count} ({empty_count*100/total_count:.2f}%)"
                )
            else:
                print("total empty samples : 0/0 (0.00%)")
|
||||
|
||||
|
||||
@app.command()
def audio_duration(dataset_path: Path):
    """Report the combined duration of all wav files under a dataset dir."""
    wav_duration = sum(
        soundfile.info(str(wav_file)).duration
        for wav_file in dataset_path.absolute().glob("**/*.wav")
    )
    typer.echo(
        f"duration of wav files @ {dataset_path}: {duration_str(wav_duration)}"
    )
|
||||
|
||||
|
||||
@app.command()
def encrypt(
    src_dataset_path: Path,
    dest_dataset_path: Path,
    encryption_key: str = typer.Option(..., prompt=True, hide_input=True),
    verbose: bool = False,
):
    """Encrypt a dataset's audio files and transcripts into a new dataset."""
    dest_manifest = dest_dataset_path / Path("manifest.json")
    src_manifest = src_dataset_path / Path("manifest.json")
    dest_wav_dir = dest_dataset_path / Path("wavs")
    dest_wav_dir.mkdir(exist_ok=True, parents=True)
    wav_crypt = wav_cryptor(encryption_key)
    text_crypt = text_cryptor(encryption_key)
    # warmup: force the lazy pydub import before worker fan-out
    _ = pydub.AudioSegment.from_file

    def encrypt_item(sample):
        # Encrypt transcript text and the wav at the mirrored path.
        cipher_text = text_crypt.encrypt_text(sample["text"])
        src_wav_path = src_dataset_path / sample["audio_filepath"]
        dst_wav_path = dest_dataset_path / sample["audio_filepath"]
        wav_crypt.encrypt_wav_path_to(src_wav_path, dst_wav_path)
        sample["text"] = cipher_text.decode("utf-8")
        return sample

    def encrypted_entries():
        entries = list(ExtendedPath(src_manifest).read_jsonl())
        source = tqdm(entries) if verbose else entries
        yield from parallel_apply(
            encrypt_item, source, verbose=verbose, workers=64
        )

    asr_manifest_writer(dest_manifest, encrypted_entries(), verbose=verbose)
|
||||
|
||||
|
||||
@app.command()
def decrypt(
    src_dataset_path: Path,
    dest_dataset_path: Path,
    encryption_key: str = typer.Option(..., prompt=True, hide_input=True),
    verbose: bool = True,
):
    """Decrypt a dataset produced by `encrypt` into a new dataset dir."""
    dest_manifest = dest_dataset_path / Path("manifest.json")
    src_manifest = src_dataset_path / Path("manifest.json")
    dest_wav_dir = dest_dataset_path / Path("wavs")
    dest_wav_dir.mkdir(exist_ok=True, parents=True)
    wav_crypt = wav_cryptor(encryption_key)
    text_crypt = text_cryptor(encryption_key)
    # warmup: force the lazy pydub import before worker fan-out
    _ = pydub.AudioSegment.from_file

    def decrypt_item(sample):
        # Decrypt transcript text and the wav at the mirrored path.
        plain_text = text_crypt.decrypt_text(sample["text"].encode("utf-8"))
        src_wav_path = src_dataset_path / sample["audio_filepath"]
        dst_wav_path = dest_dataset_path / sample["audio_filepath"]
        wav_crypt.decrypt_wav_path_to(src_wav_path, dst_wav_path)
        sample["text"] = plain_text
        return sample

    def decrypted_entries():
        entries = list(ExtendedPath(src_manifest).read_jsonl())
        source = tqdm(entries) if verbose else entries
        yield from parallel_apply(
            decrypt_item, source, verbose=verbose, workers=64
        )

    asr_manifest_writer(dest_manifest, decrypted_entries(), verbose=verbose)
|
||||
|
||||
|
||||
@app.command()
def migrate(src_path: Path, dest_path: Path):
    """Copy a dataset to dest_path and relocate all audio under dest/wavs.

    The original abs_manifest.json is backed up to abs_manifest.json.orig
    before the audio paths are rewritten.
    """
    shutil.copytree(str(src_path), str(dest_path))
    wav_dir = dest_path / Path("wavs")
    wav_dir.mkdir(exist_ok=True, parents=True)
    abs_manifest_path = ExtendedPath(dest_path / Path("abs_manifest.json"))
    # keep a pristine copy before rewriting paths in place
    backup_abs_manifest_path = abs_manifest_path.with_suffix(".json.orig")
    shutil.copy(abs_manifest_path, backup_abs_manifest_path)
    manifest_data = list(abs_manifest_path.read_jsonl())
    for md in manifest_data:
        # copy each referenced audio file into the wavs/ dir and point
        # the manifest entry at the new location
        orig_path = Path(md["audio_filepath"])
        new_path = wav_dir / Path(orig_path.name)
        shutil.copy(orig_path, new_path)
        md["audio_filepath"] = str(new_path)
    abs_manifest_path.write_jsonl(manifest_data)
    # regenerate the relative manifest from the rewritten absolute one
    fix_path(dest_path)
|
||||
|
||||
|
||||
@app.command()
def task_split(
    data_dir: Path,
    dump_file: Path = Path("ui_dump.json"),
    task_count: int = typer.Option(2, show_default=True),
    task_file: str = "task_dump",
    sort: bool = True,
):
    """
    split ui_dump.json to `task_count` tasks
    """
    import pandas as pd
    import numpy as np

    dump = ExtendedPath(data_dir / dump_file).read_json()
    # shuffle all samples before partitioning into tasks
    shuffled = (
        pd.DataFrame(dump["data"]).sample(frac=1).reset_index(drop=True)
    )
    for t_idx, part in enumerate(np.array_split(shuffled, task_count)):
        part = part.reset_index(drop=True)
        part["real_idx"] = part.index
        records = part.to_dict("records")
        if sort:
            # hardest samples (highest WER) first
            records = sorted(
                records, key=lambda r: r["asr_wer"], reverse=True
            )
        dump["data"] = records
        out_path = data_dir / Path(task_file + f"-{t_idx}.json")
        ExtendedPath(out_path).write_json(dump)
|
||||
|
||||
|
||||
def get_corrections(task_uid):
    """Return the correction records for one validation task.

    `task_uid` may be either a full task_id or its trailing uid segment
    (the part after the last "-"). Raises IndexError when no task matches.
    """
    col = get_mongo_conn(col="asr_validation")
    task_id = [
        c
        for c in col.distinct("task_id")
        if c.rsplit("-", 1)[1] == task_uid or c == task_uid
    ][0]
    # Bug fix: a query over ALL corrections used to be materialized here
    # and then immediately overwritten by the task-scoped query below —
    # a wasted full-collection round-trip.
    cursor_obj = col.find(
        {"type": "correction", "task_id": task_id}, projection={"_id": False}
    )
    return list(cursor_obj)
|
||||
|
||||
|
||||
@app.command()
def dump_task_corrections(data_dir: Path, task_uid: str):
    """Write one task's corrections to corrections-<task_uid>.json."""
    dump_path = data_dir / Path(f"corrections-{task_uid}.json")
    ExtendedPath(dump_path).write_json(get_corrections(task_uid))
|
||||
|
||||
|
||||
@app.command()
def dump_all_corrections(data_dir: Path):
    """Dump corrections for every task lock file found under data_dir."""
    for lock_file in data_dir.glob("task-*.lck"):
        uid = lock_file.stem.replace("task-", "")
        dump_task_corrections(data_dir, uid)
|
||||
|
||||
|
||||
@app.command()
def update_corrections(
    data_dir: Path,
    skip_incorrect: bool = typer.Option(
        False, show_default=True, help="treats incorrect as invalid"
    ),
    skip_inaudible: bool = typer.Option(
        False, show_default=True, help="include invalid as blank target"
    ),
):
    """
    applies the corrections-*.json
    backup the original dataset
    """
    manifest_file: Path = Path("manifest.json")
    renames_file: Path = Path("rename_map.json")
    ui_dump_file: Path = Path("ui_dump.json")
    data_manifest_path = data_dir / manifest_file
    renames_path = data_dir / renames_file

    def correct_ui_dump(data_dir, rename_result):
        # Generator over ui_dump entries with reviewer corrections applied.
        # rename_result is mutated in place to record old->new file mappings.
        ui_dump_path = data_dir / ui_dump_file
        # corrections_path = data_dir / Path("corrections.json")
        # flatten all corrections-*.json dumps into one list
        corrections = [
            t
            for p in data_dir.glob("corrections-*.json")
            for t in ExtendedPath(p).read_json()
        ]
        ui_data = ExtendedPath(ui_dump_path).read_json()["data"]
        # utterances confirmed correct by the reviewer
        correct_set = {
            c["code"] for c in corrections if c["value"]["status"] == "Correct"
        }
        # utterance -> replacement transcript for those marked Incorrect
        correction_map = {
            c["code"]: c["value"]["correction"]
            for c in corrections
            if c["value"]["status"] == "Incorrect"
        }
        for d in ui_data:
            orig_audio_path = (data_dir / Path(d["audio_path"])).absolute()
            if d["utterance_id"] in correct_set:
                # confirmed correct: keep as-is, record provenance
                d["corrected_from"] = d["text"]
                yield d
            elif d["utterance_id"] in correction_map:
                correct_text = correction_map[d["utterance_id"]]
                if skip_incorrect:
                    # drop corrected samples entirely (audio removed too)
                    ap = d["audio_path"]
                    print(
                        f"skipping incorrect {ap} corrected to {correct_text}"
                    )
                    orig_audio_path.unlink()
                else:
                    # rename the wav to a uuid derived from the corrected
                    # transcript and update the entry to match
                    new_fname = tscript_uuid_fname(correct_text)
                    rename_result[new_fname] = {
                        "orig_text": d["text"],
                        "correct_text": correct_text,
                        "orig_id": d["utterance_id"],
                    }
                    new_name = str(Path(new_fname).with_suffix(".wav"))
                    new_audio_path = orig_audio_path.with_name(new_name)
                    orig_audio_path.replace(new_audio_path)
                    new_filepath = str(
                        Path(d["audio_path"]).with_name(new_name)
                    )
                    d["corrected_from"] = d["text"]
                    d["text"] = correct_text
                    d["audio_path"] = new_filepath
                    yield d
            else:
                # no reviewer verdict: treat as inaudible/invalid
                if skip_inaudible:
                    orig_audio_path.unlink()
                else:
                    # keep the audio but blank the target transcript
                    d["corrected_from"] = d["text"]
                    d["text"] = ""
                    yield d

    dataset_dir = data_manifest_path.parent
    dataset_name = dataset_dir.name
    backup_dir = dataset_dir.with_name(dataset_name + ".bkp")
    if not backup_dir.exists():
        # one-time full backup before destructive renames/unlinks above
        typer.echo(f"backing up to {backup_dir}")
        shutil.copytree(str(dataset_dir), str(backup_dir))
    renames = {}
    corrected_ui_dump = list(correct_ui_dump(data_dir, renames))
    ExtendedPath(data_dir / ui_dump_file).write_json(
        {"data": corrected_ui_dump}
    )
    # regenerate the manifest from the corrected ui dump entries
    corrected_manifest = (
        {
            "audio_filepath": d["audio_path"],
            "duration": d["duration"],
            "text": d["text"],
        }
        for d in corrected_ui_dump
    )
    asr_manifest_writer(data_manifest_path, corrected_manifest)
    ExtendedPath(renames_path).write_json(renames)
|
||||
|
||||
|
||||
def main():
    # Script entry point: run the data sub-app directly.
    app()


if __name__ == "__main__":
    main()
|
||||
12
src/plume/cli/data/generate.py
Normal file
12
src/plume/cli/data/generate.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from pathlib import Path
|
||||
|
||||
import typer
|
||||
from ...utils.tts import GoogleTTS
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
def tts_dataset(dest_path: Path):
    # WIP stub: instantiates the TTS backend but does not yet synthesize
    # anything into dest_path. TODO: implement dataset generation.
    tts = GoogleTTS()
    pass
|
||||
14
src/plume/cli/eval.py
Normal file
14
src/plume/cli/eval.py
Normal file
@@ -0,0 +1,14 @@
|
||||
import typer
from ..models.wav2vec2.eval import app as wav2vec2_app
from ..models.wav2vec2_transformers.eval import app as wav2vec2_transformers_app

# Evaluation command group, one sub-app per model framework.
app = typer.Typer()
for sub_name, sub_app in (
    ("wav2vec2", wav2vec2_app),
    ("wav2vec2_transformers", wav2vec2_transformers_app),
):
    app.add_typer(sub_app, name=sub_name)


@app.callback()
def eval():
    """
    eval sub commands
    """
|
||||
16
src/plume/cli/serve.py
Normal file
16
src/plume/cli/serve.py
Normal file
@@ -0,0 +1,16 @@
|
||||
import typer
from ..models.wav2vec2.serve import app as wav2vec2_app
from ..models.wav2vec2_transformers.serve import app as wav2vec2_transformers_app
from ..models.jasper_nemo.serve import app as jasper_app

# Serving command group, one sub-app per model framework.
app = typer.Typer()
for sub_name, sub_app in (
    ("wav2vec2", wav2vec2_app),
    ("wav2vec2_transformers", wav2vec2_transformers_app),
    ("jasper", jasper_app),
):
    app.add_typer(sub_app, name=sub_name)


@app.callback()
def serve():
    """
    serve sub commands
    """
|
||||
12
src/plume/cli/train.py
Normal file
12
src/plume/cli/train.py
Normal file
@@ -0,0 +1,12 @@
|
||||
import typer
from ..models.wav2vec2.train import app as wav2vec2_app

# Training command group; currently only wav2vec2 is supported.
app = typer.Typer()
app.add_typer(wav2vec2_app, name="wav2vec2")


@app.callback()
def train():
    """
    train sub commands
    """
|
||||
1
src/plume/models/__init__.py
Normal file
1
src/plume/models/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# from . import jasper, wav2vec2, matchboxnet
|
||||
0
src/plume/models/jasper_nemo/__init__.py
Normal file
0
src/plume/models/jasper_nemo/__init__.py
Normal file
132
src/plume/models/jasper_nemo/asr.py
Normal file
132
src/plume/models/jasper_nemo/asr.py
Normal file
@@ -0,0 +1,132 @@
|
||||
import os
|
||||
import tempfile
|
||||
from ruamel.yaml import YAML
|
||||
import json
|
||||
import nemo
|
||||
import nemo.collections.asr as nemo_asr
|
||||
import wave
|
||||
from nemo.collections.asr.helpers import post_process_predictions
|
||||
|
||||
logging = nemo.logging
|
||||
|
||||
WORK_DIR = "/tmp"
|
||||
|
||||
|
||||
class JasperASR(object):
    """NeMo-based Jasper speech recognizer.

    Loads an encoder/decoder checkpoint pair described by a model YAML and
    transcribes raw PCM audio, optionally with a beam-search LM decoder.
    """

    def __init__(
        self,
        model_yaml,
        encoder_checkpoint,
        decoder_checkpoint,
        language_model=None,
    ):
        """Build the NeMo modules.

        model_yaml: path to the Jasper model definition (labels, encoder
            config, preprocessor features).
        encoder_checkpoint / decoder_checkpoint: checkpoint paths restored
            into the encoder and CTC decoder.
        language_model: optional KenLM path; when given, a beam-search
            decoder with LM rescoring is also constructed.
        """
        super(JasperASR, self).__init__()
        # Read model YAML
        yaml = YAML(typ="safe")
        with open(model_yaml) as f:
            jasper_model_definition = yaml.load(f)
        # NOTE(review): hard-coded GPU placement — fails on CPU-only hosts.
        self.neural_factory = nemo.core.NeuralModuleFactory(
            placement=nemo.core.DeviceType.GPU,
            backend=nemo.core.Backend.PyTorch,
        )
        self.labels = jasper_model_definition["labels"]
        self.data_preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor()
        self.jasper_encoder = nemo_asr.JasperEncoder(
            jasper=jasper_model_definition["JasperEncoder"]["jasper"],
            activation=jasper_model_definition["JasperEncoder"]["activation"],
            feat_in=jasper_model_definition[
                "AudioToMelSpectrogramPreprocessor"
            ]["features"],
        )
        self.jasper_encoder.restore_from(encoder_checkpoint, local_rank=0)
        self.jasper_decoder = nemo_asr.JasperDecoderForCTC(
            feat_in=1024, num_classes=len(self.labels)
        )
        self.jasper_decoder.restore_from(decoder_checkpoint, local_rank=0)
        self.greedy_decoder = nemo_asr.GreedyCTCDecoder()
        self.beam_search_with_lm = None
        if language_model:
            self.beam_search_with_lm = nemo_asr.BeamSearchDecoderWithLM(
                vocab=self.labels,
                beam_width=64,
                alpha=2.0,
                beta=1.0,
                lm_path=language_model,
                num_cpus=max(os.cpu_count(), 1),
            )

    def transcribe(self, audio_data, greedy=True):
        """Transcribe raw PCM frames and return the text.

        Writes the audio and a one-entry manifest to temp files under
        WORK_DIR, builds the NeMo inference DAG, runs it, and cleans up
        the temp files. When greedy is False and an LM decoder exists,
        beam search is used instead of greedy CTC decoding.

        # assumes audio_data is 16-bit mono PCM at 24 kHz — TODO confirm
        # against callers (see setsampwidth(2)/setframerate(24000) below).
        """
        audio_file = tempfile.NamedTemporaryFile(
            dir=WORK_DIR, prefix="jasper_audio.", delete=False
        )
        # audio_file.write(audio_data)
        audio_file.close()
        audio_file_path = audio_file.name
        # wrap the raw frames in a wav container for the data layer
        wf = wave.open(audio_file_path, "w")
        wf.setnchannels(1)
        wf.setsampwidth(2)
        wf.setframerate(24000)
        wf.writeframesraw(audio_data)
        wf.close()
        # duration/text are placeholders; only audio_filepath matters here
        manifest = {
            "audio_filepath": audio_file_path,
            "duration": 60,
            "text": "todo",
        }
        manifest_file = tempfile.NamedTemporaryFile(
            dir=WORK_DIR, prefix="jasper_manifest.", delete=False, mode="w"
        )
        manifest_file.write(json.dumps(manifest))
        manifest_file.close()
        manifest_file_path = manifest_file.name
        data_layer = nemo_asr.AudioToTextDataLayer(
            shuffle=False,
            manifest_filepath=manifest_file_path,
            labels=self.labels,
            batch_size=1,
        )

        # Define inference DAG
        audio_signal, audio_signal_len, _, _ = data_layer()
        processed_signal, processed_signal_len = self.data_preprocessor(
            input_signal=audio_signal, length=audio_signal_len
        )
        encoded, encoded_len = self.jasper_encoder(
            audio_signal=processed_signal, length=processed_signal_len
        )
        log_probs = self.jasper_decoder(encoder_output=encoded)
        predictions = self.greedy_decoder(log_probs=log_probs)

        if greedy:
            eval_tensors = [predictions]
        else:
            if self.beam_search_with_lm:
                logging.info("Running with beam search")
                beam_predictions = self.beam_search_with_lm(
                    log_probs=log_probs, log_probs_length=encoded_len
                )
                eval_tensors = [beam_predictions]
            else:
                # no LM available: beam search was requested but cannot run
                logging.info(
                    "language_model not specified. falling back to greedy decoding."
                )
                eval_tensors = [predictions]

        tensors = self.neural_factory.infer(tensors=eval_tensors)
        prediction = post_process_predictions(tensors[0], self.labels)
        prediction_text = ". ".join(prediction)
        # remove the temp manifest and audio files created above
        os.unlink(manifest_file.name)
        os.unlink(audio_file.name)
        return prediction_text

    def transcribe_file(self, audio_file, *args, **kwargs):
        """Transcribe a wav file and write the text next to it (.txt).

        Extra args/kwargs are forwarded to transcribe() (e.g. greedy=).
        """
        tscript_file_path = audio_file.with_suffix(".txt")
        audio_file_path = str(audio_file)
        with wave.open(audio_file_path, "r") as af:
            frame_count = af.getnframes()
            audio_data = af.readframes(frame_count)
        transcription = self.transcribe(audio_data, *args, **kwargs)
        with open(tscript_file_path, "w") as tf:
            tf.write(transcription)
||||
24
src/plume/models/jasper_nemo/data.py
Normal file
24
src/plume/models/jasper_nemo/data.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from pathlib import Path
|
||||
import typer
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
def set_root(dataset_path: Path, root_path: Path):
    # Stub: intended to rewrite the root line of fairseq-style tsv
    # manifests (see commented sketch below); not implemented yet.
    pass
    # for dataset_kind in ["train", "valid"]:
    #     data_file = dataset_path / Path(dataset_kind).with_suffix(".tsv")
    #     with data_file.open("r") as df:
    #         lines = df.readlines()
    #     with data_file.open("w") as df:
    #         lines[0] = str(root_path) + "\n"
    #         df.writelines(lines)


def main():
    # Script entry point for running this module directly.
    app()


if __name__ == "__main__":
    main()
|
||||
340
src/plume/models/jasper_nemo/data_loaders.py
Normal file
340
src/plume/models/jasper_nemo/data_loaders.py
Normal file
@@ -0,0 +1,340 @@
|
||||
from functools import partial
|
||||
import tempfile
|
||||
|
||||
# from typing import Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
import nemo
|
||||
|
||||
# import nemo.collections.asr as nemo_asr
|
||||
from nemo.backends.pytorch import DataLayerNM
|
||||
from nemo.core import DeviceType
|
||||
|
||||
# from nemo.core.neural_types import *
|
||||
from nemo.core.neural_types import (
|
||||
NeuralType,
|
||||
AudioSignal,
|
||||
LengthsType,
|
||||
LabelsType,
|
||||
)
|
||||
from nemo.utils.decorators import add_port_docs
|
||||
|
||||
from nemo.collections.asr.parts.dataset import (
|
||||
# AudioDataset,
|
||||
# AudioLabelDataset,
|
||||
# KaldiFeatureDataset,
|
||||
# TranscriptDataset,
|
||||
parsers,
|
||||
collections,
|
||||
seq_collate_fn,
|
||||
)
|
||||
|
||||
# from functools import lru_cache
|
||||
import rpyc
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from tqdm import tqdm
|
||||
from .featurizer import RpycWaveformFeaturizer
|
||||
|
||||
# from nemo.collections.asr.parts.features import WaveformFeaturizer
|
||||
|
||||
# from nemo.collections.asr.parts.perturb import AudioAugmentor, perturbation_types
|
||||
|
||||
|
||||
logging = nemo.logging
|
||||
|
||||
|
||||
class CachedAudioDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
Dataset that loads tensors via a json file containing paths to audio
|
||||
files, transcripts, and durations (in seconds). Each new line is a
|
||||
different sample. Example below:
|
||||
|
||||
{"audio_filepath": "/path/to/audio.wav", "text_filepath":
|
||||
"/path/to/audio.txt", "duration": 23.147}
|
||||
...
|
||||
{"audio_filepath": "/path/to/audio.wav", "text": "the
|
||||
transcription", offset": 301.75, "duration": 0.82, "utt":
|
||||
"utterance_id", "ctm_utt": "en_4156", "side": "A"}
|
||||
|
||||
Args:
|
||||
manifest_filepath: Path to manifest json as described above. Can
|
||||
be comma-separated paths.
|
||||
labels: String containing all the possible characters to map to
|
||||
featurizer: Initialized featurizer class that converts paths of
|
||||
audio to feature tensors
|
||||
max_duration: If audio exceeds this length, do not include in dataset
|
||||
min_duration: If audio is less than this length, do not include
|
||||
in dataset
|
||||
max_utts: Limit number of utterances
|
||||
blank_index: blank character index, default = -1
|
||||
unk_index: unk_character index, default = -1
|
||||
normalize: whether to normalize transcript text (default): True
|
||||
bos_id: Id of beginning of sequence symbol to append if not None
|
||||
eos_id: Id of end of sequence symbol to append if not None
|
||||
load_audio: Boolean flag indicate whether do or not load audio
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    manifest_filepath,
    labels,
    featurizer,
    max_duration=None,
    min_duration=None,
    max_utts=0,
    blank_index=-1,
    unk_index=-1,
    normalize=True,
    trim=False,
    bos_id=None,
    eos_id=None,
    load_audio=True,
    parser="en",
):
    """Build the manifest-backed collection and eagerly pre-load features.

    Args:
        manifest_filepath: comma-separated list of manifest JSON paths.
        labels: characters the parser may emit.
        featurizer: object with a ``process(path, ...)`` method returning
            a feature tensor for one utterance.
        max_duration/min_duration: drop utterances outside this range.
        max_utts: cap on number of utterances (0 = unlimited).
        blank_index/unk_index: special token ids for the parser.
        normalize: whether the parser normalizes transcript text.
        trim: trim silence when featurizing.
        bos_id/eos_id: optional sequence delimiter token ids.
        load_audio: if False, __getitem__ returns only transcripts.
        parser: name of the text parser to build (default "en").
    """
    self.collection = collections.ASRAudioText(
        manifests_files=manifest_filepath.split(","),
        parser=parsers.make_parser(
            labels=labels,
            name=parser,
            unk_id=unk_index,
            blank_id=blank_index,
            do_normalize=normalize,
        ),
        min_duration=min_duration,
        max_duration=max_duration,
        max_number=max_utts,
    )
    # index -> feature tensor; populated by __getitem__ and warmed below.
    # NOTE(review): the cache is unbounded — assumes the whole dataset's
    # features fit in memory; confirm for large corpora.
    self.index_feature_map = {}

    self.featurizer = featurizer
    self.trim = trim
    self.eos_id = eos_id
    self.bos_id = bos_id
    self.load_audio = load_audio
    print(f"initializing dataset {manifest_filepath}")

    def exec_func(i):
        # Touch every item once so features land in the cache.
        return self[i]

    task_count = len(self.collection)
    with ThreadPoolExecutor() as exe:
        print("starting all loading tasks")
        list(
            tqdm(
                exe.map(exec_func, range(task_count)),
                position=0,
                leave=True,
                total=task_count,
            )
        )
    # was: print(f"initializing complete") — f-string with no placeholders
    print("initializing complete")
|
||||
def __getitem__(self, index):
    """Return ``(features, feature_len, tokens, token_len)`` for one sample.

    Features are memoized in ``self.index_feature_map`` so repeated epochs
    skip the (remote) featurization step. When ``load_audio`` is False the
    first two elements are ``None``.
    """
    sample = self.collection[index]

    if self.load_audio:
        feats = self.index_feature_map.get(index)
        if feats is None:
            # Cache miss: featurize once and remember the result.
            feats = self.featurizer.process(
                sample.audio_file,
                offset=0,
                duration=sample.duration,
                trim=self.trim,
            )
            self.index_feature_map[index] = feats
        audio, audio_len = feats, torch.tensor(feats.shape[0]).long()
    else:
        audio, audio_len = None, None

    tokens = sample.text_tokens
    token_count = len(tokens)
    if self.bos_id is not None:
        tokens = [self.bos_id] + tokens
        token_count += 1
    if self.eos_id is not None:
        tokens = tokens + [self.eos_id]
        token_count += 1

    return (
        audio,
        audio_len,
        torch.tensor(tokens).long(),
        torch.tensor(token_count).long(),
    )
|
||||
|
||||
def __len__(self):
    """Number of utterances retained from the manifest(s)."""
    return len(self.collection)
|
||||
|
||||
|
||||
class RpycAudioToTextDataLayer(DataLayerNM):
    """ASR data layer that reads manifests and audio through a remote rpyc host.

    Drop-in variant of ``nemo_asr.AudioToTextDataLayer``: it accepts
    comma-separated JSON manifests of the form
    ``{"audio_filepath": ..., "duration": ..., "text": ...}``, but both the
    manifest bytes and the audio samples are fetched from ``rpyc_host``
    (port 8064) instead of the local filesystem, so training can run on a
    machine that does not hold the dataset.

    Args:
        manifest_filepath (str): comma-separated manifest JSON paths
            (remote paths, resolved by the rpyc data service).
        labels (list): output character set of the ASR model.
        batch_size (int): batch size; -1 means "one batch = whole dataset".
        sample_rate (int): target sampling rate (default 16000).
        int_values (bool): whether audio files store int samples.
        bos_id/eos_id/pad_id: optional special token ids (pad defaults to 0).
        min_duration/max_duration (float): duration filter bounds.
        normalize_transcripts (bool): automatic text cleaning.
        trim_silence (bool): trim leading/trailing silence.
        load_audio (bool): load audio or transcripts only.
        rpyc_host (str): hostname of the remote data service.
        drop_last/shuffle/num_workers: standard DataLoader options.
    """

    @property
    @add_port_docs()
    def output_ports(self):
        """Returns definitions of module output ports."""
        return {
            "audio_signal": NeuralType(
                ("B", "T"),
                AudioSignal(freq=self._sample_rate)
                if self is not None and self._sample_rate is not None
                else AudioSignal(),
            ),
            "a_sig_length": NeuralType(tuple("B"), LengthsType()),
            "transcripts": NeuralType(("B", "T"), LabelsType()),
            "transcript_length": NeuralType(tuple("B"), LengthsType()),
        }

    def __init__(
        self,
        manifest_filepath,
        labels,
        batch_size,
        sample_rate=16000,
        int_values=False,
        bos_id=None,
        eos_id=None,
        pad_id=None,
        min_duration=0.1,
        max_duration=None,
        normalize_transcripts=True,
        trim_silence=False,
        load_audio=True,
        rpyc_host="",
        drop_last=False,
        shuffle=True,
        num_workers=0,
    ):
        super().__init__()
        self._sample_rate = sample_rate

        def rpyc_root_fn():
            # Generous sync timeout: featurizing a long file remotely can
            # take a while.
            return rpyc.connect(
                rpyc_host, 8064, config={"sync_request_timeout": 600}
            ).root

        rpyc_conn = rpyc_root_fn()

        self._featurizer = RpycWaveformFeaturizer(
            sample_rate=self._sample_rate,
            int_values=int_values,
            augmentor=None,
            rpyc_conn=rpyc_conn,
        )

        def read_remote_manifests():
            # Copy each remote manifest into a local temp file so the
            # dataset can read it with ordinary file I/O.
            local_mp = []
            for mrp in manifest_filepath.split(","):
                md = rpyc_conn.read_path(mrp)
                mf = tempfile.NamedTemporaryFile(
                    dir="/tmp", prefix="jasper_manifest.", delete=False
                )
                mf.write(md)
                mf.close()
                local_mp.append(mf.name)
            return ",".join(local_mp)

        local_manifest_filepath = read_remote_manifests()
        dataset_params = {
            "manifest_filepath": local_manifest_filepath,
            "labels": labels,
            "featurizer": self._featurizer,
            "max_duration": max_duration,
            "min_duration": min_duration,
            "normalize": normalize_transcripts,
            "trim": trim_silence,
            "bos_id": bos_id,
            "eos_id": eos_id,
            "load_audio": load_audio,
        }

        self._dataset = CachedAudioDataset(**dataset_params)
        self._batch_size = batch_size

        # Set up data loader
        if self._placement == DeviceType.AllGpu:
            logging.info("Parallelizing Datalayer.")
            sampler = torch.utils.data.distributed.DistributedSampler(
                self._dataset
            )
        else:
            sampler = None

        if batch_size == -1:
            batch_size = len(self._dataset)

        pad_id = 0 if pad_id is None else pad_id
        # NOTE(review): num_workers is accepted for interface parity with the
        # local data layer but the DataLoader is pinned to a single worker,
        # presumably because the rpyc-backed dataset cannot be forked —
        # confirm before changing.
        self._dataloader = torch.utils.data.DataLoader(
            dataset=self._dataset,
            batch_size=batch_size,
            collate_fn=partial(seq_collate_fn, token_pad_value=pad_id),
            drop_last=drop_last,
            shuffle=shuffle if sampler is None else False,
            sampler=sampler,
            num_workers=1,
        )

    def __len__(self):
        return len(self._dataset)

    @property
    def dataset(self):
        # Returning None tells NeMo to consume data_iterator instead.
        return None

    @property
    def data_iterator(self):
        return self._dataloader
|
||||
376
src/plume/models/jasper_nemo/eval.py
Normal file
376
src/plume/models/jasper_nemo/eval.py
Normal file
@@ -0,0 +1,376 @@
|
||||
# Copyright (c) 2019 NVIDIA Corporation
|
||||
import argparse
|
||||
import copy
|
||||
|
||||
# import math
|
||||
import os
|
||||
from pathlib import Path
|
||||
from functools import partial
|
||||
|
||||
from ruamel.yaml import YAML
|
||||
|
||||
import nemo
|
||||
import nemo.collections.asr as nemo_asr
|
||||
import nemo.utils.argparse as nm_argparse
|
||||
from nemo.collections.asr.helpers import (
|
||||
# monitor_asr_train_progress,
|
||||
process_evaluation_batch,
|
||||
process_evaluation_epoch,
|
||||
)
|
||||
|
||||
# from nemo.utils.lr_policies import CosineAnnealing
|
||||
from training.data_loaders import RpycAudioToTextDataLayer
|
||||
|
||||
logging = nemo.logging
|
||||
|
||||
|
||||
def parse_args():
    """Parse CLI arguments for Jasper evaluation.

    Extends NeMo's argument parser with checkpoint, remote-data and dataset
    options. Raises ValueError if neither --max_steps nor --num_epochs is
    given.
    """
    parser = argparse.ArgumentParser(
        parents=[nm_argparse.NemoArgParser()],
        description="Jasper",
        conflict_handler="resolve",
    )
    parser.set_defaults(
        checkpoint_dir=None,
        optimizer="novograd",
        batch_size=64,
        eval_batch_size=64,
        lr=0.002,
        amp_opt_level="O1",
        create_tb_writer=True,
        model_config="./train/jasper10x5dr.yaml",
        work_dir="./train/work",
        num_epochs=300,
        weight_decay=0.005,
        checkpoint_save_freq=100,
        eval_freq=100,
        load_dir="./train/models/jasper/",
        warmup_steps=3,
        exp_name="jasper",
    )

    # Arguments that override NemoArgParser defaults.
    overrides = [
        dict(flag="--max_steps", type=int, default=None, required=False,
             help="max number of steps to train"),
        dict(flag="--num_epochs", type=int, required=False,
             help="number of epochs to train"),
        dict(flag="--model_config", type=str, required=False,
             help="model configuration file: model.yaml"),
        dict(flag="--encoder_checkpoint", type=str, required=True,
             help="encoder checkpoint file: JasperEncoder.pt"),
        dict(flag="--decoder_checkpoint", type=str, required=True,
             help="decoder checkpoint file: JasperDecoderForCTC.pt"),
        dict(flag="--remote_data", type=str, required=False, default="",
             help="remote dataloader endpoint"),
        dict(flag="--dataset", type=str, required=False, default="",
             help="dataset directory containing train/test manifests"),
    ]
    for spec in overrides:
        flag = spec.pop("flag")
        parser.add_argument(flag, **spec)

    # New arguments specific to this script.
    parser.add_argument("--exp_name", default="Jasper", type=str)
    parser.add_argument("--beta1", default=0.95, type=float)
    parser.add_argument("--beta2", default=0.25, type=float)
    parser.add_argument("--warmup_steps", default=0, type=int)
    parser.add_argument(
        "--load_dir",
        default=None,
        type=str,
        help="directory with pre-trained checkpoint",
    )

    args = parser.parse_args()
    if args.max_steps is None and args.num_epochs is None:
        raise ValueError("Either max_steps or num_epochs should be provided.")
    return args
|
||||
|
||||
|
||||
def construct_name(
    name, lr, batch_size, max_steps, num_epochs, wd, optimizer, iter_per_step
):
    """Build a run name that encodes the main hyper-parameters.

    The duration component uses ``s_<max_steps>`` when max_steps is given,
    otherwise ``e_<num_epochs>``. (Previously two near-identical format
    strings were duplicated; they are merged here.)
    """
    if max_steps is not None:
        duration_tag, duration = "s", max_steps
    else:
        duration_tag, duration = "e", num_epochs
    return "{0}-lr_{1}-bs_{2}-{3}_{4}-wd_{5}-opt_{6}-ips_{7}".format(
        name, lr, batch_size, duration_tag, duration, wd, optimizer,
        iter_per_step,
    )
|
||||
|
||||
|
||||
def create_all_dags(args, neural_factory):
    """Build the evaluation DAG(s) and return their EvaluatorCallbacks.

    Loads the model YAML, constructs one evaluation data layer per entry in
    ``args.eval_datasets``, restores encoder/decoder weights from the
    checkpoints named in ``args``, and wires preprocessor -> encoder ->
    decoder -> greedy CTC decode + CTC loss for each eval set.

    This is the eval-only variant of the training script's DAG builder:
    the large commented-out training DAG has been removed for clarity.
    """
    yaml = YAML(typ="safe")
    with open(args.model_config) as f:
        jasper_params = yaml.load(f)
    vocab = jasper_params["labels"]
    sample_rate = jasper_params["sample_rate"]

    # Calculate num_workers for dataloader
    total_cpus = os.cpu_count()
    cpu_per_traindl = max(int(total_cpus / neural_factory.world_size), 1)

    # Kept for parity with the training script even though only the eval
    # data layers are instantiated below.
    train_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"])
    train_dl_params.update(jasper_params["AudioToTextDataLayer"]["train"])
    del train_dl_params["train"]
    del train_dl_params["eval"]

    # Resolve default manifest locations from a dataset directory.
    if args.dataset:
        d_path = Path(args.dataset)
        if not args.train_dataset:
            args.train_dataset = str(d_path / Path("train_manifest.json"))
        if not args.eval_datasets:
            args.eval_datasets = [str(d_path / Path("test_manifest.json"))]

    data_loader_layer = nemo_asr.AudioToTextDataLayer
    if args.remote_data:
        # Stream manifests/audio from a remote rpyc data service.
        train_dl_params["rpyc_host"] = args.remote_data
        data_loader_layer = RpycAudioToTextDataLayer

    data_preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor(
        sample_rate=sample_rate,
        **jasper_params["AudioToMelSpectrogramPreprocessor"],
    )

    eval_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"])
    eval_dl_params.update(jasper_params["AudioToTextDataLayer"]["eval"])
    if args.remote_data:
        eval_dl_params["rpyc_host"] = args.remote_data
    del eval_dl_params["train"]
    del eval_dl_params["eval"]

    data_layers_eval = []
    for eval_datasets in args.eval_datasets:
        data_layer_eval = data_loader_layer(
            manifest_filepath=eval_datasets,
            sample_rate=sample_rate,
            labels=vocab,
            batch_size=args.eval_batch_size,
            num_workers=cpu_per_traindl,
            **eval_dl_params,
        )
        data_layers_eval.append(data_layer_eval)

    jasper_encoder = nemo_asr.JasperEncoder(
        feat_in=jasper_params["AudioToMelSpectrogramPreprocessor"]["features"],
        **jasper_params["JasperEncoder"],
    )
    jasper_encoder.restore_from(args.encoder_checkpoint, local_rank=0)

    jasper_decoder = nemo_asr.JasperDecoderForCTC(
        feat_in=jasper_params["JasperEncoder"]["jasper"][-1]["filters"],
        num_classes=len(vocab),
    )
    jasper_decoder.restore_from(args.decoder_checkpoint, local_rank=0)

    ctc_loss = nemo_asr.CTCLossNM(num_classes=len(vocab))
    greedy_decoder = nemo_asr.GreedyCTCDecoder()

    callbacks = []
    # assemble eval DAGs
    for i, eval_dl in enumerate(data_layers_eval):
        (
            audio_signal_e,
            a_sig_length_e,
            transcript_e,
            transcript_len_e,
        ) = eval_dl()
        processed_signal_e, p_length_e = data_preprocessor(
            input_signal=audio_signal_e, length=a_sig_length_e
        )
        encoded_e, encoded_len_e = jasper_encoder(
            audio_signal=processed_signal_e, length=p_length_e
        )
        log_probs_e = jasper_decoder(encoder_output=encoded_e)
        predictions_e = greedy_decoder(log_probs=log_probs_e)
        loss_e = ctc_loss(
            log_probs=log_probs_e,
            targets=transcript_e,
            input_length=encoded_len_e,
            target_length=transcript_len_e,
        )

        # Tag each callback with the manifest's basename (sans extension).
        tagname = os.path.basename(args.eval_datasets[i]).split(".")[0]
        eval_callback = nemo.core.EvaluatorCallback(
            eval_tensors=[
                loss_e,
                predictions_e,
                transcript_e,
                transcript_len_e,
            ],
            user_iter_callback=partial(process_evaluation_batch, labels=vocab),
            user_epochs_done_callback=partial(
                process_evaluation_epoch, tag=tagname
            ),
            eval_step=args.eval_freq,
            tb_writer=neural_factory.tb_writer,
        )
        callbacks.append(eval_callback)
    return callbacks
|
||||
|
||||
|
||||
def main():
    """Entry point: build the eval DAGs and run evaluation on GPU."""
    args = parse_args()

    # instantiate Neural Factory with supported backend
    neural_factory = nemo.core.NeuralModuleFactory(
        placement=nemo.core.DeviceType.GPU,
        backend=nemo.core.Backend.PyTorch,
    )
    args.num_gpus = neural_factory.world_size

    if args.local_rank is not None:
        logging.info("Doing ALL GPU")

    # build dags, then evaluate the model
    callbacks = create_all_dags(args, neural_factory)
    neural_factory.eval(callbacks=callbacks)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
61
src/plume/models/jasper_nemo/featurizer.py
Normal file
61
src/plume/models/jasper_nemo/featurizer.py
Normal file
@@ -0,0 +1,61 @@
|
||||
# import math
|
||||
|
||||
# import librosa
|
||||
import torch
|
||||
import pickle
|
||||
|
||||
# import torch.nn as nn
|
||||
# from torch_stft import STFT
|
||||
|
||||
# from nemo import logging
|
||||
from nemo.collections.asr.parts.perturb import AudioAugmentor
|
||||
|
||||
# from nemo.collections.asr.parts.segment import AudioSegment
|
||||
|
||||
|
||||
class RpycWaveformFeaturizer(object):
    """Waveform featurizer that fetches audio samples over an rpyc connection.

    Mirrors NeMo's WaveformFeaturizer interface, but delegates file loading
    to a remote data service exposing ``get_path_samples``.
    """

    def __init__(
        self,
        sample_rate=16000,
        int_values=False,
        augmentor=None,
        rpyc_conn=None,
    ):
        self.augmentor = (
            augmentor if augmentor is not None else AudioAugmentor()
        )
        self.sample_rate = sample_rate
        self.int_values = int_values
        # BUG FIX: previously `rpyc_conn.get_path_samples` was taken
        # unconditionally, so from_config() (which passes no connection)
        # crashed with AttributeError on None. Defer the failure to
        # process() instead.
        self.remote_path_samples = (
            rpyc_conn.get_path_samples if rpyc_conn is not None else None
        )

    def max_augmentation_length(self, length):
        """Delegate to the augmentor's worst-case length estimate."""
        return self.augmentor.max_augmentation_length(length)

    def process(self, file_path, offset=0, duration=0, trim=False):
        """Fetch and return the samples of ``file_path`` as a float tensor."""
        if self.remote_path_samples is None:
            raise RuntimeError(
                "RpycWaveformFeaturizer has no rpyc connection configured"
            )
        audio = self.remote_path_samples(
            file_path,
            target_sr=self.sample_rate,
            int_values=self.int_values,
            offset=offset,
            duration=duration,
            trim=trim,
        )
        # SECURITY: pickle.loads on bytes received from the remote service —
        # only use against a trusted data host on a trusted network.
        return torch.tensor(pickle.loads(audio), dtype=torch.float)

    def process_segment(self, audio_segment):
        """Perturb a segment in place and return it as a float tensor."""
        self.augmentor.perturb(audio_segment)
        # NOTE(review): passing the segment object itself to torch.tensor
        # looks suspect — presumably ``audio_segment.samples`` is intended;
        # confirm before relying on this path.
        return torch.tensor(audio_segment, dtype=torch.float)

    @classmethod
    def from_config(cls, input_config, perturbation_configs=None):
        """Alternate constructor from a config dict (no rpyc connection)."""
        if perturbation_configs is not None:
            aa = AudioAugmentor.from_config(perturbation_configs)
        else:
            aa = None

        sample_rate = input_config.get("sample_rate", 16000)
        int_values = input_config.get("int_values", False)

        return cls(
            sample_rate=sample_rate, int_values=int_values, augmentor=aa
        )
|
||||
54
src/plume/models/jasper_nemo/serve.py
Normal file
54
src/plume/models/jasper_nemo/serve.py
Normal file
@@ -0,0 +1,54 @@
|
||||
import os
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from rpyc.utils.server import ThreadedServer
|
||||
import typer
|
||||
|
||||
# from .asr import JasperASR
|
||||
from ...utils.serve import ASRService
|
||||
from plume.utils import lazy_callable
|
||||
|
||||
JasperASR = lazy_callable("plume.models.jasper_nemo.asr.JasperASR")
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
def rpyc(
    encoder_path: Path = "/path/to/encoder.pt",
    decoder_path: Path = "/path/to/decoder.pt",
    model_yaml_path: Path = "/path/to/model.yaml",
    port: int = int(os.environ.get("ASR_RPYC_PORT", "8044")),
):
    """Serve a Jasper ASR model over rpyc on ``port``.

    Returns early (after logging) if any of the model files is missing.
    """
    # BUG FIX: basicConfig must run before the first logging.info call,
    # otherwise the missing-file messages were silently dropped.
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )
    for p in [encoder_path, decoder_path, model_yaml_path]:
        if not p.exists():
            logging.info(f"{p} doesn't exist")
            return
    asr = JasperASR(str(model_yaml_path), str(encoder_path), str(decoder_path))
    service = ASRService(asr)
    logging.info("starting asr server...")
    t = ThreadedServer(service, port=port)
    t.start()
|
||||
|
||||
|
||||
@app.command()
def rpyc_dir(
    model_dir: Path, port: int = int(os.environ.get("ASR_RPYC_PORT", "8044"))
):
    """Serve a model from a directory using the conventional file names."""
    # BUG FIX: encoder/decoder paths were swapped — encoder_path pointed at
    # decoder.pt and decoder_path at encoder.pt.
    encoder_path = model_dir / Path("encoder.pt")
    decoder_path = model_dir / Path("decoder.pt")
    model_yaml_path = model_dir / Path("model.yaml")
    rpyc(encoder_path, decoder_path, model_yaml_path, port)
|
||||
|
||||
|
||||
def main():
    """Console entry point for the serve CLI."""
    app()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
59
src/plume/models/jasper_nemo/serve_data.py
Normal file
59
src/plume/models/jasper_nemo/serve_data.py
Normal file
@@ -0,0 +1,59 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import typer
|
||||
import rpyc
|
||||
from rpyc.utils.server import ThreadedServer
|
||||
import nemo
|
||||
import pickle
|
||||
|
||||
# import nemo.collections.asr as nemo_asr
|
||||
from nemo.collections.asr.parts.segment import AudioSegment
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
nemo.core.NeuralModuleFactory(
|
||||
backend=nemo.core.Backend.PyTorch, placement=nemo.core.DeviceType.CPU
|
||||
)
|
||||
|
||||
|
||||
class ASRDataService(rpyc.Service):
    """rpyc service exposing audio samples and raw file bytes to remote trainers."""

    def exposed_get_path_samples(
        self, file_path, target_sr, int_values, offset, duration, trim
    ):
        """Load ``file_path``, resample to ``target_sr``, and return the
        samples pickled as bytes.

        Pickling keeps the sample array intact across the rpyc boundary
        instead of proxying it element by element. SECURITY NOTE: the peer
        unpickles this payload — deploy only on a trusted network.
        """
        print(f"loading.. {file_path}")
        audio = AudioSegment.from_file(
            file_path,
            target_sr=target_sr,
            int_values=int_values,
            offset=offset,
            duration=duration,
            trim=trim,
        )
        return pickle.dumps(audio.samples)

    def exposed_read_path(self, file_path):
        """Return the raw bytes of a file on this host (e.g. a manifest)."""
        return Path(file_path).read_bytes()
|
||||
|
||||
|
||||
@app.command()
def run_server(port: int = 0):
    """Start the rpyc audio-data server (blocks until interrupted).

    A ``port`` of 0 falls back to the environment variable, then 8064.
    """
    # NOTE(review): "ASR_DARA_RPYC_PORT" looks like a typo for
    # "ASR_DATA_RPYC_PORT"; kept as-is since deployments may depend on it.
    if port:
        listen_port = port
    else:
        listen_port = int(os.environ.get("ASR_DARA_RPYC_PORT", "8064"))
    service = ASRDataService()
    # allow_all_attrs lets clients access the pickled payloads' attributes.
    server = ThreadedServer(
        service, port=listen_port, protocol_config={"allow_all_attrs": True}
    )
    typer.echo(f"starting asr server on {listen_port}...")
    server.start()
|
||||
|
||||
|
||||
def main():
    """Console entry point for the data-serving CLI."""
    app()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
392
src/plume/models/jasper_nemo/train.py
Normal file
392
src/plume/models/jasper_nemo/train.py
Normal file
@@ -0,0 +1,392 @@
|
||||
# Copyright (c) 2019 NVIDIA Corporation
|
||||
import argparse
|
||||
import copy
|
||||
import math
|
||||
import os
|
||||
from pathlib import Path
|
||||
from functools import partial
|
||||
|
||||
from ruamel.yaml import YAML
|
||||
|
||||
import nemo
|
||||
import nemo.collections.asr as nemo_asr
|
||||
import nemo.utils.argparse as nm_argparse
|
||||
from nemo.collections.asr.helpers import (
|
||||
monitor_asr_train_progress,
|
||||
process_evaluation_batch,
|
||||
process_evaluation_epoch,
|
||||
)
|
||||
|
||||
from nemo.utils.lr_policies import CosineAnnealing
|
||||
from .data_loaders import RpycAudioToTextDataLayer
|
||||
|
||||
logging = nemo.logging
|
||||
|
||||
|
||||
def parse_args():
    """Parse CLI arguments for Jasper training.

    Extends NeMo's argument parser; raises ValueError if neither
    --max_steps nor --num_epochs is given.
    """
    parser = argparse.ArgumentParser(
        parents=[nm_argparse.NemoArgParser()],
        description="Jasper",
        conflict_handler="resolve",
    )
    parser.set_defaults(
        checkpoint_dir=None,
        optimizer="novograd",
        batch_size=64,
        eval_batch_size=64,
        lr=0.002,
        amp_opt_level="O1",
        create_tb_writer=True,
        model_config="./train/jasper10x5dr.yaml",
        work_dir="./train/work",
        num_epochs=300,
        weight_decay=0.005,
        checkpoint_save_freq=100,
        eval_freq=100,
        load_dir="./train/models/jasper/",
        warmup_steps=3,
        exp_name="jasper-speller",
    )

    # Arguments that override NemoArgParser defaults.
    overrides = [
        dict(flag="--max_steps", type=int, default=None, required=False,
             help="max number of steps to train"),
        dict(flag="--num_epochs", type=int, required=False,
             help="number of epochs to train"),
        dict(flag="--model_config", type=str, required=False,
             help="model configuration file: model.yaml"),
        dict(flag="--remote_data", type=str, required=False, default="",
             help="remote dataloader endpoint"),
        dict(flag="--dataset", type=str, required=False, default="",
             help="dataset directory containing train/test manifests"),
    ]
    for spec in overrides:
        flag = spec.pop("flag")
        parser.add_argument(flag, **spec)

    # New arguments specific to this script.
    parser.add_argument("--exp_name", default="Jasper", type=str)
    parser.add_argument("--beta1", default=0.95, type=float)
    parser.add_argument("--beta2", default=0.25, type=float)
    parser.add_argument("--warmup_steps", default=0, type=int)
    parser.add_argument(
        "--load_dir",
        default=None,
        type=str,
        help="directory with pre-trained checkpoint",
    )

    args = parser.parse_args()
    if args.max_steps is None and args.num_epochs is None:
        raise ValueError("Either max_steps or num_epochs should be provided.")
    return args
|
||||
|
||||
|
||||
def construct_name(
    name, lr, batch_size, max_steps, num_epochs, wd, optimizer, iter_per_step
):
    """Build a run name that encodes the main hyper-parameters.

    The duration component uses ``s_<max_steps>`` when max_steps is given,
    otherwise ``e_<num_epochs>``. (Previously two near-identical format
    strings were duplicated; they are merged here.)
    """
    if max_steps is not None:
        duration_tag, duration = "s", max_steps
    else:
        duration_tag, duration = "e", num_epochs
    return "{0}-lr_{1}-bs_{2}-{3}_{4}-wd_{5}-opt_{6}-ips_{7}".format(
        name, lr, batch_size, duration_tag, duration, wd, optimizer,
        iter_per_step,
    )
|
||||
|
||||
|
||||
def create_all_dags(args, neural_factory):
    """Build the Jasper training DAG plus one evaluation DAG per eval set.

    Reads the model config YAML, creates the train/eval data layers
    (optionally backed by a remote rpyc dataloader), the Jasper
    encoder/decoder, the CTC loss, and the logging/checkpoint/eval
    callbacks.

    Args:
        args: parsed CLI namespace; mutated in place — ``train_dataset``
            and ``eval_datasets`` are filled in from ``args.dataset`` when
            missing, and ``args.num_gpus`` is read for steps_per_epoch.
        neural_factory: ``nemo.core.NeuralModuleFactory`` providing
            world_size, checkpoint_dir and the TensorBoard writer.

    Returns:
        Tuple of (train loss tensor, list of callbacks, steps_per_epoch).
    """
    yaml = YAML(typ="safe")
    with open(args.model_config) as f:
        jasper_params = yaml.load(f)
    vocab = jasper_params["labels"]
    sample_rate = jasper_params["sample_rate"]

    # Calculate num_workers for dataloader: share available CPUs evenly
    # across distributed workers, with at least one per dataloader.
    total_cpus = os.cpu_count()
    cpu_per_traindl = max(int(total_cpus / neural_factory.world_size), 1)
    # perturb_config = jasper_params.get('perturb', None)
    # The YAML holds shared dataloader params plus "train"/"eval" override
    # sections; merge the train overrides in, then drop both sub-sections.
    train_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"])
    train_dl_params.update(jasper_params["AudioToTextDataLayer"]["train"])
    del train_dl_params["train"]
    del train_dl_params["eval"]
    # del train_dl_params["normalize_transcripts"]

    # Derive default manifest paths from the dataset directory when the
    # caller did not pass them explicitly.
    if args.dataset:
        d_path = Path(args.dataset)
        if not args.train_dataset:
            args.train_dataset = str(d_path / Path("train_manifest.json"))
        if not args.eval_datasets:
            args.eval_datasets = [str(d_path / Path("test_manifest.json"))]

    data_loader_layer = nemo_asr.AudioToTextDataLayer

    # Swap in the rpyc-backed data layer when a remote endpoint is given.
    if args.remote_data:
        train_dl_params["rpyc_host"] = args.remote_data
        data_loader_layer = RpycAudioToTextDataLayer

    data_layer = data_loader_layer(
        manifest_filepath=args.train_dataset,
        sample_rate=sample_rate,
        labels=vocab,
        batch_size=args.batch_size,
        num_workers=cpu_per_traindl,
        **train_dl_params,
        # normalize_transcripts=False
    )

    N = len(data_layer)
    # Effective batch per optimizer step is batch_size * iter_per_step * GPUs.
    steps_per_epoch = math.ceil(
        N / (args.batch_size * args.iter_per_step * args.num_gpus)
    )
    logging.info("Have {0} examples to train on.".format(N))

    data_preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor(
        sample_rate=sample_rate,
        **jasper_params["AudioToMelSpectrogramPreprocessor"],
    )

    # Optional batch-multiplication augmentation (only built if configured).
    multiply_batch_config = jasper_params.get("MultiplyBatch", None)
    if multiply_batch_config:
        multiply_batch = nemo_asr.MultiplyBatch(**multiply_batch_config)

    # Optional SpecAugment-style spectrogram augmentation.
    spectr_augment_config = jasper_params.get("SpectrogramAugmentation", None)
    if spectr_augment_config:
        data_spectr_augmentation = nemo_asr.SpectrogramAugmentation(
            **spectr_augment_config
        )

    # Same merge dance as for the train params, but with the eval overrides.
    eval_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"])
    eval_dl_params.update(jasper_params["AudioToTextDataLayer"]["eval"])
    if args.remote_data:
        eval_dl_params["rpyc_host"] = args.remote_data
    del eval_dl_params["train"]
    del eval_dl_params["eval"]
    data_layers_eval = []

    if args.eval_datasets:
        for eval_datasets in args.eval_datasets:
            data_layer_eval = data_loader_layer(
                manifest_filepath=eval_datasets,
                sample_rate=sample_rate,
                labels=vocab,
                batch_size=args.eval_batch_size,
                num_workers=cpu_per_traindl,
                **eval_dl_params,
            )

            data_layers_eval.append(data_layer_eval)
    else:
        logging.warning("There were no val datasets passed")

    jasper_encoder = nemo_asr.JasperEncoder(
        feat_in=jasper_params["AudioToMelSpectrogramPreprocessor"]["features"],
        **jasper_params["JasperEncoder"],
    )

    # Decoder input width comes from the last encoder block's filter count.
    jasper_decoder = nemo_asr.JasperDecoderForCTC(
        feat_in=jasper_params["JasperEncoder"]["jasper"][-1]["filters"],
        num_classes=len(vocab),
    )

    ctc_loss = nemo_asr.CTCLossNM(num_classes=len(vocab))

    greedy_decoder = nemo_asr.GreedyCTCDecoder()

    logging.info("================================")
    logging.info(
        f"Number of parameters in encoder: {jasper_encoder.num_weights}"
    )
    logging.info(
        f"Number of parameters in decoder: {jasper_decoder.num_weights}"
    )
    logging.info(
        f"Total number of parameters in model: "
        f"{jasper_decoder.num_weights + jasper_encoder.num_weights}"
    )
    logging.info("================================")

    # Train DAG: data -> mel spectrogram -> (optional augmentations)
    # -> encoder -> decoder -> CTC loss.
    (
        audio_signal_t,
        a_sig_length_t,
        transcript_t,
        transcript_len_t,
    ) = data_layer()
    processed_signal_t, p_length_t = data_preprocessor(
        input_signal=audio_signal_t, length=a_sig_length_t
    )

    if multiply_batch_config:
        (
            processed_signal_t,
            p_length_t,
            transcript_t,
            transcript_len_t,
        ) = multiply_batch(
            in_x=processed_signal_t,
            in_x_len=p_length_t,
            in_y=transcript_t,
            in_y_len=transcript_len_t,
        )

    if spectr_augment_config:
        processed_signal_t = data_spectr_augmentation(
            input_spec=processed_signal_t
        )

    encoded_t, encoded_len_t = jasper_encoder(
        audio_signal=processed_signal_t, length=p_length_t
    )
    log_probs_t = jasper_decoder(encoder_output=encoded_t)
    predictions_t = greedy_decoder(log_probs=log_probs_t)
    loss_t = ctc_loss(
        log_probs=log_probs_t,
        targets=transcript_t,
        input_length=encoded_len_t,
        target_length=transcript_len_t,
    )

    # Callbacks needed to print info to console and Tensorboard
    train_callback = nemo.core.SimpleLossLoggerCallback(
        tensors=[loss_t, predictions_t, transcript_t, transcript_len_t],
        print_func=partial(monitor_asr_train_progress, labels=vocab),
        get_tb_values=lambda x: [("loss", x[0])],
        tb_writer=neural_factory.tb_writer,
    )

    chpt_callback = nemo.core.CheckpointCallback(
        folder=neural_factory.checkpoint_dir,
        load_from_folder=args.load_dir,
        step_freq=args.checkpoint_save_freq,
        checkpoints_to_keep=30,
    )

    callbacks = [train_callback, chpt_callback]

    # assemble eval DAGs — one inference sub-graph and EvaluatorCallback
    # per eval manifest, sharing the train DAG's encoder/decoder modules.
    for i, eval_dl in enumerate(data_layers_eval):
        (
            audio_signal_e,
            a_sig_length_e,
            transcript_e,
            transcript_len_e,
        ) = eval_dl()
        processed_signal_e, p_length_e = data_preprocessor(
            input_signal=audio_signal_e, length=a_sig_length_e
        )
        encoded_e, encoded_len_e = jasper_encoder(
            audio_signal=processed_signal_e, length=p_length_e
        )
        log_probs_e = jasper_decoder(encoder_output=encoded_e)
        predictions_e = greedy_decoder(log_probs=log_probs_e)
        loss_e = ctc_loss(
            log_probs=log_probs_e,
            targets=transcript_e,
            input_length=encoded_len_e,
            target_length=transcript_len_e,
        )

        # create corresponding eval callback, tagged by the manifest's
        # file stem so TensorBoard curves are distinguishable.
        tagname = os.path.basename(args.eval_datasets[i]).split(".")[0]
        eval_callback = nemo.core.EvaluatorCallback(
            eval_tensors=[
                loss_e,
                predictions_e,
                transcript_e,
                transcript_len_e,
            ],
            user_iter_callback=partial(process_evaluation_batch, labels=vocab),
            user_epochs_done_callback=partial(
                process_evaluation_epoch, tag=tagname
            ),
            eval_step=args.eval_freq,
            tb_writer=neural_factory.tb_writer,
        )

        callbacks.append(eval_callback)
    return loss_t, callbacks, steps_per_epoch
|
||||
|
||||
|
||||
def main():
    """CLI entry point: parse args, build the NeMo factory and DAGs, train.

    Raises:
        ValueError: propagated from parse_args when neither max_steps nor
            num_epochs is provided.
    """
    args = parse_args()
    # Encode the key hyperparameters in the run name so work dirs are
    # self-describing.
    name = construct_name(
        args.exp_name,
        args.lr,
        args.batch_size,
        args.max_steps,
        args.num_epochs,
        args.weight_decay,
        args.optimizer,
        args.iter_per_step,
    )
    log_dir = name
    if args.work_dir:
        log_dir = os.path.join(args.work_dir, name)

    # instantiate Neural Factory with supported backend
    neural_factory = nemo.core.NeuralModuleFactory(
        backend=nemo.core.Backend.PyTorch,
        local_rank=args.local_rank,
        optimization_level=args.amp_opt_level,
        log_dir=log_dir,
        checkpoint_dir=args.checkpoint_dir,
        create_tb_writer=args.create_tb_writer,
        # Snapshot the config and this script alongside the run artifacts.
        files_to_copy=[args.model_config, __file__],
        cudnn_benchmark=args.cudnn_benchmark,
        tensorboard_dir=args.tensorboard_dir,
    )
    # create_all_dags reads num_gpus when computing steps_per_epoch.
    args.num_gpus = neural_factory.world_size

    if args.local_rank is not None:
        logging.info("Doing ALL GPU")

    # build dags
    train_loss, callbacks, steps_per_epoch = create_all_dags(
        args, neural_factory
    )
    # train model
    neural_factory.train(
        tensors_to_optimize=[train_loss],
        callbacks=callbacks,
        lr_policy=CosineAnnealing(
            # Anneal over the explicit step budget if given, otherwise
            # over the full epoch schedule.
            args.max_steps
            if args.max_steps is not None
            else args.num_epochs * steps_per_epoch,
            warmup_steps=args.warmup_steps,
        ),
        optimizer=args.optimizer,
        optimization_params={
            "num_epochs": args.num_epochs,
            "max_steps": args.max_steps,
            "lr": args.lr,
            "betas": (args.beta1, args.beta2),
            "weight_decay": args.weight_decay,
            "grad_norm_clip": None,
        },
        batches_per_step=args.iter_per_step,
    )
|
||||
|
||||
|
||||
# Allow running this training script directly from the command line.
if __name__ == "__main__":
    main()
|
||||
0
src/plume/models/marblenet_nemo/__init__.py
Normal file
0
src/plume/models/marblenet_nemo/__init__.py
Normal file
132
src/plume/models/marblenet_nemo/asr.py
Normal file
132
src/plume/models/marblenet_nemo/asr.py
Normal file
@@ -0,0 +1,132 @@
|
||||
import os
|
||||
import tempfile
|
||||
from ruamel.yaml import YAML
|
||||
import json
|
||||
import nemo
|
||||
import nemo.collections.asr as nemo_asr
|
||||
import wave
|
||||
from nemo.collections.asr.helpers import post_process_predictions
|
||||
|
||||
logging = nemo.logging
|
||||
|
||||
WORK_DIR = "/tmp"
|
||||
|
||||
|
||||
class JasperASR(object):
    """Jasper ASR inference wrapper built on NeMo.

    Restores a Jasper encoder/decoder pair from checkpoints and exposes
    transcription of raw PCM audio, optionally re-scored with a
    KenLM-style language model via beam search.
    """

    def __init__(
        self,
        model_yaml,
        encoder_checkpoint,
        decoder_checkpoint,
        language_model=None,
    ):
        """Load the model definition and restore checkpoint weights.

        Args:
            model_yaml: path to the Jasper model YAML (labels, encoder
                config, preprocessor config).
            encoder_checkpoint: path to the JasperEncoder .pt checkpoint.
            decoder_checkpoint: path to the JasperDecoderForCTC .pt
                checkpoint.
            language_model: optional LM path; when given, a beam-search
                decoder with LM rescoring is also constructed.
        """
        super(JasperASR, self).__init__()
        # Read model YAML
        yaml = YAML(typ="safe")
        with open(model_yaml) as f:
            jasper_model_definition = yaml.load(f)
        self.neural_factory = nemo.core.NeuralModuleFactory(
            placement=nemo.core.DeviceType.GPU,
            backend=nemo.core.Backend.PyTorch,
        )
        self.labels = jasper_model_definition["labels"]
        self.data_preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor()
        self.jasper_encoder = nemo_asr.JasperEncoder(
            jasper=jasper_model_definition["JasperEncoder"]["jasper"],
            activation=jasper_model_definition["JasperEncoder"]["activation"],
            feat_in=jasper_model_definition[
                "AudioToMelSpectrogramPreprocessor"
            ]["features"],
        )
        self.jasper_encoder.restore_from(encoder_checkpoint, local_rank=0)
        # NOTE(review): decoder feat_in is hard-coded to 1024 rather than
        # read from the YAML's last encoder block — confirm it matches the
        # checkpointed architecture.
        self.jasper_decoder = nemo_asr.JasperDecoderForCTC(
            feat_in=1024, num_classes=len(self.labels)
        )
        self.jasper_decoder.restore_from(decoder_checkpoint, local_rank=0)
        self.greedy_decoder = nemo_asr.GreedyCTCDecoder()
        self.beam_search_with_lm = None
        if language_model:
            self.beam_search_with_lm = nemo_asr.BeamSearchDecoderWithLM(
                vocab=self.labels,
                beam_width=64,
                alpha=2.0,
                beta=1.0,
                lm_path=language_model,
                num_cpus=max(os.cpu_count(), 1),
            )

    def transcribe(self, audio_data, greedy=True):
        """Transcribe raw PCM frames and return the predicted text.

        Writes the audio and a one-entry manifest to temp files under
        WORK_DIR, runs the inference DAG over them, then removes both
        temp files.

        Args:
            audio_data: raw PCM frames as bytes; written out as mono
                16-bit WAV at 24000 Hz (assumes the input was captured at
                that rate — TODO confirm).
            greedy: when False and a language model was configured, use
                beam search with LM rescoring instead of greedy decoding.

        Returns:
            The decoded transcription, segments joined with ". ".
        """
        audio_file = tempfile.NamedTemporaryFile(
            dir=WORK_DIR, prefix="jasper_audio.", delete=False
        )
        # audio_file.write(audio_data)
        audio_file.close()
        audio_file_path = audio_file.name
        # Re-open via the wave module to emit a proper WAV header.
        wf = wave.open(audio_file_path, "w")
        wf.setnchannels(1)
        wf.setsampwidth(2)
        wf.setframerate(24000)
        wf.writeframesraw(audio_data)
        wf.close()
        # NOTE(review): duration is a fixed 60 s placeholder and text is
        # "todo"; presumably only the path is consumed by the data layer.
        manifest = {
            "audio_filepath": audio_file_path,
            "duration": 60,
            "text": "todo",
        }
        manifest_file = tempfile.NamedTemporaryFile(
            dir=WORK_DIR, prefix="jasper_manifest.", delete=False, mode="w"
        )
        manifest_file.write(json.dumps(manifest))
        manifest_file.close()
        manifest_file_path = manifest_file.name
        data_layer = nemo_asr.AudioToTextDataLayer(
            shuffle=False,
            manifest_filepath=manifest_file_path,
            labels=self.labels,
            batch_size=1,
        )

        # Define inference DAG
        audio_signal, audio_signal_len, _, _ = data_layer()
        processed_signal, processed_signal_len = self.data_preprocessor(
            input_signal=audio_signal, length=audio_signal_len
        )
        encoded, encoded_len = self.jasper_encoder(
            audio_signal=processed_signal, length=processed_signal_len
        )
        log_probs = self.jasper_decoder(encoder_output=encoded)
        predictions = self.greedy_decoder(log_probs=log_probs)

        if greedy:
            eval_tensors = [predictions]
        else:
            if self.beam_search_with_lm:
                logging.info("Running with beam search")
                beam_predictions = self.beam_search_with_lm(
                    log_probs=log_probs, log_probs_length=encoded_len
                )
                eval_tensors = [beam_predictions]
            else:
                # No LM configured: fall back to the greedy predictions.
                logging.info(
                    "language_model not specified. falling back to greedy decoding."
                )
                eval_tensors = [predictions]

        tensors = self.neural_factory.infer(tensors=eval_tensors)
        prediction = post_process_predictions(tensors[0], self.labels)
        prediction_text = ". ".join(prediction)
        # Clean up the temp manifest and audio files.
        os.unlink(manifest_file.name)
        os.unlink(audio_file.name)
        return prediction_text

    def transcribe_file(self, audio_file, *args, **kwargs):
        """Transcribe a WAV file and write the text next to it as .txt.

        Args:
            audio_file: pathlib.Path to a WAV file; the transcription is
                written to the same path with a .txt suffix.
            *args, **kwargs: forwarded to transcribe (e.g. greedy=).
        """
        tscript_file_path = audio_file.with_suffix(".txt")
        audio_file_path = str(audio_file)
        with wave.open(audio_file_path, "r") as af:
            frame_count = af.getnframes()
            audio_data = af.readframes(frame_count)
        transcription = self.transcribe(audio_data, *args, **kwargs)
        with open(tscript_file_path, "w") as tf:
            tf.write(transcription)
|
||||
24
src/plume/models/marblenet_nemo/data.py
Normal file
24
src/plume/models/marblenet_nemo/data.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from pathlib import Path
|
||||
import typer
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
def set_root(dataset_path: Path, root_path: Path):
    """Rewrite a dataset's root path inside its train/valid manifests.

    NOTE(review): currently a stub — the implementation is commented out
    below, so this command does nothing.
    """
    pass
    # for dataset_kind in ["train", "valid"]:
    #     data_file = dataset_path / Path(dataset_kind).with_suffix(".tsv")
    #     with data_file.open("r") as df:
    #         lines = df.readlines()
    #     with data_file.open("w") as df:
    #         lines[0] = str(root_path) + "\n"
    #         df.writelines(lines)
|
||||
|
||||
|
||||
def main():
    """CLI entry point: run the typer application for the data commands."""
    app()


if __name__ == "__main__":
    main()
|
||||
340
src/plume/models/marblenet_nemo/data_loaders.py
Normal file
340
src/plume/models/marblenet_nemo/data_loaders.py
Normal file
@@ -0,0 +1,340 @@
|
||||
from functools import partial
|
||||
import tempfile
|
||||
|
||||
# from typing import Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
import nemo
|
||||
|
||||
# import nemo.collections.asr as nemo_asr
|
||||
from nemo.backends.pytorch import DataLayerNM
|
||||
from nemo.core import DeviceType
|
||||
|
||||
# from nemo.core.neural_types import *
|
||||
from nemo.core.neural_types import (
|
||||
NeuralType,
|
||||
AudioSignal,
|
||||
LengthsType,
|
||||
LabelsType,
|
||||
)
|
||||
from nemo.utils.decorators import add_port_docs
|
||||
|
||||
from nemo.collections.asr.parts.dataset import (
|
||||
# AudioDataset,
|
||||
# AudioLabelDataset,
|
||||
# KaldiFeatureDataset,
|
||||
# TranscriptDataset,
|
||||
parsers,
|
||||
collections,
|
||||
seq_collate_fn,
|
||||
)
|
||||
|
||||
# from functools import lru_cache
|
||||
import rpyc
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from tqdm import tqdm
|
||||
from .featurizer import RpycWaveformFeaturizer
|
||||
|
||||
# from nemo.collections.asr.parts.features import WaveformFeaturizer
|
||||
|
||||
# from nemo.collections.asr.parts.perturb import AudioAugmentor, perturbation_types
|
||||
|
||||
|
||||
logging = nemo.logging
|
||||
|
||||
|
||||
class CachedAudioDataset(torch.utils.data.Dataset):
    """
    Dataset that loads tensors via a json file containing paths to audio
    files, transcripts, and durations (in seconds). Each new line is a
    different sample. Example below:

    {"audio_filepath": "/path/to/audio.wav", "text_filepath":
    "/path/to/audio.txt", "duration": 23.147}
    ...
    {"audio_filepath": "/path/to/audio.wav", "text": "the
    transcription", offset": 301.75, "duration": 0.82, "utt":
    "utterance_id", "ctm_utt": "en_4156", "side": "A"}

    Unlike the stock AudioDataset, every item is eagerly loaded once in
    __init__ (via a thread pool) and the computed features are kept in an
    in-memory cache, so later __getitem__ calls are cheap.

    Args:
        manifest_filepath: Path to manifest json as described above. Can
            be comma-separated paths.
        labels: String containing all the possible characters to map to
        featurizer: Initialized featurizer class that converts paths of
            audio to feature tensors
        max_duration: If audio exceeds this length, do not include in dataset
        min_duration: If audio is less than this length, do not include
            in dataset
        max_utts: Limit number of utterances
        blank_index: blank character index, default = -1
        unk_index: unk_character index, default = -1
        normalize: whether to normalize transcript text (default): True
        bos_id: Id of beginning of sequence symbol to append if not None
        eos_id: Id of end of sequence symbol to append if not None
        load_audio: Boolean flag indicate whether do or not load audio
    """

    def __init__(
        self,
        manifest_filepath,
        labels,
        featurizer,
        max_duration=None,
        min_duration=None,
        max_utts=0,
        blank_index=-1,
        unk_index=-1,
        normalize=True,
        trim=False,
        bos_id=None,
        eos_id=None,
        load_audio=True,
        parser="en",
    ):
        self.collection = collections.ASRAudioText(
            manifests_files=manifest_filepath.split(","),
            parser=parsers.make_parser(
                labels=labels,
                name=parser,
                unk_id=unk_index,
                blank_id=blank_index,
                do_normalize=normalize,
            ),
            min_duration=min_duration,
            max_duration=max_duration,
            max_number=max_utts,
        )
        # index -> feature tensor cache, populated by the warm-up below.
        self.index_feature_map = {}

        self.featurizer = featurizer
        self.trim = trim
        self.eos_id = eos_id
        self.bos_id = bos_id
        self.load_audio = load_audio
        print(f"initializing dataset {manifest_filepath}")

        def exec_func(i):
            # Touching self[i] runs the featurizer and fills the cache.
            return self[i]

        task_count = len(self.collection)
        with ThreadPoolExecutor() as exe:
            print("starting all loading tasks")
            # Drain the executor's iterator through tqdm for progress;
            # results are discarded (caching is the side effect).
            for _ in tqdm(
                exe.map(exec_func, range(task_count)),
                position=0,
                leave=True,
                total=task_count,
            ):
                pass
        print("initializing complete")

    def __getitem__(self, index):
        """Return (features, feat_len, tokens, token_len) for one sample.

        Features come from the in-memory cache when available; otherwise
        they are computed with the featurizer and cached. When load_audio
        is False, features and their length are None.
        """
        sample = self.collection[index]
        if self.load_audio:
            cached_features = self.index_feature_map.get(index)
            if cached_features is not None:
                features = cached_features
            else:
                features = self.featurizer.process(
                    sample.audio_file,
                    offset=0,
                    duration=sample.duration,
                    trim=self.trim,
                )
                self.index_feature_map[index] = features
            f, fl = features, torch.tensor(features.shape[0]).long()
        else:
            f, fl = None, None

        t, tl = sample.text_tokens, len(sample.text_tokens)
        # Optionally bracket the token sequence with BOS/EOS ids.
        if self.bos_id is not None:
            t = [self.bos_id] + t
            tl += 1
        if self.eos_id is not None:
            t = t + [self.eos_id]
            tl += 1

        return f, fl, torch.tensor(t).long(), torch.tensor(tl).long()

    def __len__(self):
        return len(self.collection)
|
||||
|
||||
|
||||
class RpycAudioToTextDataLayer(DataLayerNM):
    """Data Layer for general ASR tasks.

    Module which reads ASR labeled data. It accepts comma-separated
    JSON manifest files describing the correspondence between wav audio files
    and their transcripts. JSON files should be of the following format::

        {"audio_filepath": path_to_wav_0, "duration": time_in_sec_0, "text": \
transcript_0}
        ...
        {"audio_filepath": path_to_wav_n, "duration": time_in_sec_n, "text": \
transcript_n}

    Manifests and audio are fetched over an rpyc connection to a remote
    host rather than read from the local filesystem.

    Args:
        manifest_filepath (str): Dataset parameter.
            Path to JSON containing data.
        labels (list): Dataset parameter.
            List of characters that can be output by the ASR model.
            For Jasper, this is the 28 character set {a-z '}. The CTC blank
            symbol is automatically added later for models using ctc.
        batch_size (int): batch size
        sample_rate (int): Target sampling rate for data. Audio files will be
            resampled to sample_rate if it is not already.
            Defaults to 16000.
        int_values (bool): Bool indicating whether the audio file is saved as
            int data or float data.
            Defaults to False.
        eos_id (id): Dataset parameter.
            End of string symbol id used for seq2seq models.
            Defaults to None.
        min_duration (float): Dataset parameter.
            All training files which have a duration less than min_duration
            are dropped. Note: Duration is read from the manifest JSON.
            Defaults to 0.1.
        max_duration (float): Dataset parameter.
            All training files which have a duration more than max_duration
            are dropped. Note: Duration is read from the manifest JSON.
            Defaults to None.
        normalize_transcripts (bool): Dataset parameter.
            Whether to use automatic text cleaning.
            It is highly recommended to manually clean text for best results.
            Defaults to True.
        trim_silence (bool): Whether to use trim silence from beginning and end
            of audio signal using librosa.effects.trim().
            Defaults to False.
        load_audio (bool): Dataset parameter.
            Controls whether the dataloader loads the audio signal and
            transcript or just the transcript.
            Defaults to True.
        rpyc_host (str): Hostname of the remote rpyc dataloader service
            (port 8064).
        drop_last (bool): See PyTorch DataLoader.
            Defaults to False.
        shuffle (bool): See PyTorch DataLoader.
            Defaults to True.
        num_workers (int): See PyTorch DataLoader.
            Defaults to 0.
            NOTE(review): currently accepted but not used — the DataLoader
            below hardcodes num_workers=1 (presumably because the rpyc
            connection cannot be shared across worker processes; confirm).
        perturb_config (dict): Currently disabled.
    """

    @property
    @add_port_docs()
    def output_ports(self):
        """Returns definitions of module output ports."""
        return {
            # 'audio_signal': NeuralType({0: AxisType(BatchTag), 1: AxisType(TimeTag)}),
            # 'a_sig_length': NeuralType({0: AxisType(BatchTag)}),
            # 'transcripts': NeuralType({0: AxisType(BatchTag), 1: AxisType(TimeTag)}),
            # 'transcript_length': NeuralType({0: AxisType(BatchTag)}),
            "audio_signal": NeuralType(
                ("B", "T"),
                AudioSignal(freq=self._sample_rate)
                if self is not None and self._sample_rate is not None
                else AudioSignal(),
            ),
            "a_sig_length": NeuralType(tuple("B"), LengthsType()),
            "transcripts": NeuralType(("B", "T"), LabelsType()),
            "transcript_length": NeuralType(tuple("B"), LengthsType()),
        }

    def __init__(
        self,
        manifest_filepath,
        labels,
        batch_size,
        sample_rate=16000,
        int_values=False,
        bos_id=None,
        eos_id=None,
        pad_id=None,
        min_duration=0.1,
        max_duration=None,
        normalize_transcripts=True,
        trim_silence=False,
        load_audio=True,
        rpyc_host="",
        drop_last=False,
        shuffle=True,
        num_workers=0,
    ):
        super().__init__()
        self._sample_rate = sample_rate

        # Connect to the remote dataloader service; long sync timeout
        # because audio transfers for large datasets can be slow.
        def rpyc_root_fn():
            return rpyc.connect(
                rpyc_host, 8064, config={"sync_request_timeout": 600}
            ).root

        rpyc_conn = rpyc_root_fn()

        # Featurizer that pulls waveform data over the rpyc connection.
        self._featurizer = RpycWaveformFeaturizer(
            sample_rate=self._sample_rate,
            int_values=int_values,
            augmentor=None,
            rpyc_conn=rpyc_conn,
        )

        # Copy each remote manifest into a local temp file so the dataset
        # can parse it with ordinary file I/O; returns the joined paths.
        def read_remote_manifests():
            local_mp = []
            for mrp in manifest_filepath.split(","):
                md = rpyc_conn.read_path(mrp)
                mf = tempfile.NamedTemporaryFile(
                    dir="/tmp", prefix="jasper_manifest.", delete=False
                )
                mf.write(md)
                mf.close()
                local_mp.append(mf.name)
            return ",".join(local_mp)

        local_manifest_filepath = read_remote_manifests()
        dataset_params = {
            "manifest_filepath": local_manifest_filepath,
            "labels": labels,
            "featurizer": self._featurizer,
            "max_duration": max_duration,
            "min_duration": min_duration,
            "normalize": normalize_transcripts,
            "trim": trim_silence,
            "bos_id": bos_id,
            "eos_id": eos_id,
            "load_audio": load_audio,
        }

        self._dataset = CachedAudioDataset(**dataset_params)
        self._batch_size = batch_size

        # Set up data loader
        if self._placement == DeviceType.AllGpu:
            logging.info("Parallelizing Datalayer.")
            sampler = torch.utils.data.distributed.DistributedSampler(
                self._dataset
            )
        else:
            sampler = None

        # batch_size == -1 means "one batch containing the whole dataset".
        if batch_size == -1:
            batch_size = len(self._dataset)

        pad_id = 0 if pad_id is None else pad_id
        self._dataloader = torch.utils.data.DataLoader(
            dataset=self._dataset,
            batch_size=batch_size,
            collate_fn=partial(seq_collate_fn, token_pad_value=pad_id),
            drop_last=drop_last,
            # A sampler and shuffle are mutually exclusive in DataLoader.
            shuffle=shuffle if sampler is None else False,
            sampler=sampler,
            num_workers=1,
        )

    def __len__(self):
        return len(self._dataset)

    @property
    def dataset(self):
        # NOTE(review): returns None so NeMo falls back to data_iterator;
        # presumably intentional for this DataLayerNM — confirm.
        return None

    @property
    def data_iterator(self):
        return self._dataloader
|
||||
376
src/plume/models/marblenet_nemo/eval.py
Normal file
376
src/plume/models/marblenet_nemo/eval.py
Normal file
@@ -0,0 +1,376 @@
|
||||
# Copyright (c) 2019 NVIDIA Corporation
|
||||
import argparse
|
||||
import copy
|
||||
|
||||
# import math
|
||||
import os
|
||||
from pathlib import Path
|
||||
from functools import partial
|
||||
|
||||
from ruamel.yaml import YAML
|
||||
|
||||
import nemo
|
||||
import nemo.collections.asr as nemo_asr
|
||||
import nemo.utils.argparse as nm_argparse
|
||||
from nemo.collections.asr.helpers import (
|
||||
# monitor_asr_train_progress,
|
||||
process_evaluation_batch,
|
||||
process_evaluation_epoch,
|
||||
)
|
||||
|
||||
# from nemo.utils.lr_policies import CosineAnnealing
|
||||
from training.data_loaders import RpycAudioToTextDataLayer
|
||||
|
||||
logging = nemo.logging
|
||||
|
||||
|
||||
def parse_args():
    """Parse CLI arguments for Jasper evaluation.

    Builds on NeMo's shared argument parser; requires encoder/decoder
    checkpoint paths.

    NOTE: argparse parser-level defaults (set_defaults below) always
    override argument-level `default=` values given in add_argument, so
    e.g. exp_name defaults to "jasper", warmup_steps to 3 and load_dir to
    "./train/models/jasper/" despite the later add_argument defaults.

    Returns:
        The parsed argparse.Namespace.

    Raises:
        ValueError: if neither --max_steps nor --num_epochs is provided.
    """
    parser = argparse.ArgumentParser(
        parents=[nm_argparse.NemoArgParser()],
        description="Jasper",
        conflict_handler="resolve",
    )
    parser.set_defaults(
        checkpoint_dir=None,
        optimizer="novograd",
        batch_size=64,
        eval_batch_size=64,
        lr=0.002,
        amp_opt_level="O1",
        create_tb_writer=True,
        model_config="./train/jasper10x5dr.yaml",
        work_dir="./train/work",
        num_epochs=300,
        weight_decay=0.005,
        checkpoint_save_freq=100,
        eval_freq=100,
        load_dir="./train/models/jasper/",
        warmup_steps=3,
        exp_name="jasper",
    )

    # Overwrite default args
    parser.add_argument(
        "--max_steps",
        type=int,
        default=None,
        required=False,
        help="max number of steps to train",
    )
    parser.add_argument(
        "--num_epochs",
        type=int,
        required=False,
        help="number of epochs to train",
    )
    parser.add_argument(
        "--model_config",
        type=str,
        required=False,
        help="model configuration file: model.yaml",
    )
    # Checkpoints to evaluate — both are mandatory for this script.
    parser.add_argument(
        "--encoder_checkpoint",
        type=str,
        required=True,
        help="encoder checkpoint file: JasperEncoder.pt",
    )
    parser.add_argument(
        "--decoder_checkpoint",
        type=str,
        required=True,
        help="decoder checkpoint file: JasperDecoderForCTC.pt",
    )
    parser.add_argument(
        "--remote_data",
        type=str,
        required=False,
        default="",
        help="remote dataloader endpoint",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        required=False,
        default="",
        help="dataset directory containing train/test manifests",
    )

    # Create new args
    parser.add_argument("--exp_name", default="Jasper", type=str)
    parser.add_argument("--beta1", default=0.95, type=float)
    parser.add_argument("--beta2", default=0.25, type=float)
    parser.add_argument("--warmup_steps", default=0, type=int)
    parser.add_argument(
        "--load_dir",
        default=None,
        type=str,
        help="directory with pre-trained checkpoint",
    )

    args = parser.parse_args()
    if args.max_steps is None and args.num_epochs is None:
        raise ValueError("Either max_steps or num_epochs should be provided.")
    return args
|
||||
|
||||
|
||||
def construct_name(
    name, lr, batch_size, max_steps, num_epochs, wd, optimizer, iter_per_step
):
    """Build a descriptive experiment name from the hyperparameters.

    Uses a step-count field when max_steps is given, otherwise an
    epoch-count field.
    """
    if max_steps is not None:
        fmt = "{0}-lr_{1}-bs_{2}-s_{3}-wd_{4}-opt_{5}-ips_{6}"
        span = max_steps
    else:
        fmt = "{0}-lr_{1}-bs_{2}-e_{3}-wd_{4}-opt_{5}-ips_{6}"
        span = num_epochs
    return fmt.format(name, lr, batch_size, span, wd, optimizer, iter_per_step)
|
||||
|
||||
|
||||
def create_all_dags(args, neural_factory):
    """Build the evaluation DAG(s) and callbacks for a pretrained Jasper model.

    Loads the model config, restores encoder/decoder weights from the
    checkpoints named in ``args``, wires one eval pipeline per manifest in
    ``args.eval_datasets`` and returns the list of EvaluatorCallbacks.

    NOTE: large blocks of commented-out train-DAG construction (data layer,
    SpecAugment, CTC train loss, train/checkpoint callbacks) were removed
    here; see the companion training script for the live version.
    """
    yaml = YAML(typ="safe")
    with open(args.model_config) as f:
        jasper_params = yaml.load(f)
    vocab = jasper_params["labels"]
    sample_rate = jasper_params["sample_rate"]

    # Calculate num_workers for dataloader (at least one per process).
    total_cpus = os.cpu_count()
    cpu_per_traindl = max(int(total_cpus / neural_factory.world_size), 1)
    # Train-section params are prepared only to mirror the training script's
    # config handling; the train DAG itself is not built in this eval script.
    train_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"])
    train_dl_params.update(jasper_params["AudioToTextDataLayer"]["train"])
    del train_dl_params["train"]
    del train_dl_params["eval"]

    # A dataset directory provides default train/test manifest locations.
    if args.dataset:
        d_path = Path(args.dataset)
        if not args.train_dataset:
            args.train_dataset = str(d_path / Path("train_manifest.json"))
        if not args.eval_datasets:
            args.eval_datasets = [str(d_path / Path("test_manifest.json"))]

    data_loader_layer = nemo_asr.AudioToTextDataLayer

    # With --remote_data, audio is fetched from an rpyc data server
    # instead of the local filesystem.
    if args.remote_data:
        train_dl_params["rpyc_host"] = args.remote_data
        data_loader_layer = RpycAudioToTextDataLayer

    data_preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor(
        sample_rate=sample_rate,
        **jasper_params["AudioToMelSpectrogramPreprocessor"],
    )

    eval_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"])
    eval_dl_params.update(jasper_params["AudioToTextDataLayer"]["eval"])
    if args.remote_data:
        eval_dl_params["rpyc_host"] = args.remote_data
    del eval_dl_params["train"]
    del eval_dl_params["eval"]
    data_layers_eval = []

    # One data layer per eval manifest.
    for eval_datasets in args.eval_datasets:
        data_layer_eval = data_loader_layer(
            manifest_filepath=eval_datasets,
            sample_rate=sample_rate,
            labels=vocab,
            batch_size=args.eval_batch_size,
            num_workers=cpu_per_traindl,
            **eval_dl_params,
        )

        data_layers_eval.append(data_layer_eval)

    jasper_encoder = nemo_asr.JasperEncoder(
        feat_in=jasper_params["AudioToMelSpectrogramPreprocessor"]["features"],
        **jasper_params["JasperEncoder"],
    )
    # Restore pretrained weights; evaluation only, no training occurs here.
    jasper_encoder.restore_from(args.encoder_checkpoint, local_rank=0)

    jasper_decoder = nemo_asr.JasperDecoderForCTC(
        feat_in=jasper_params["JasperEncoder"]["jasper"][-1]["filters"],
        num_classes=len(vocab),
    )
    jasper_decoder.restore_from(args.decoder_checkpoint, local_rank=0)

    ctc_loss = nemo_asr.CTCLossNM(num_classes=len(vocab))

    greedy_decoder = nemo_asr.GreedyCTCDecoder()

    callbacks = []
    # assemble eval DAGs: audio -> mel features -> encoder -> CTC decoder,
    # plus loss/predictions wired into an EvaluatorCallback per dataset.
    for i, eval_dl in enumerate(data_layers_eval):
        (
            audio_signal_e,
            a_sig_length_e,
            transcript_e,
            transcript_len_e,
        ) = eval_dl()
        processed_signal_e, p_length_e = data_preprocessor(
            input_signal=audio_signal_e, length=a_sig_length_e
        )
        encoded_e, encoded_len_e = jasper_encoder(
            audio_signal=processed_signal_e, length=p_length_e
        )
        log_probs_e = jasper_decoder(encoder_output=encoded_e)
        predictions_e = greedy_decoder(log_probs=log_probs_e)
        loss_e = ctc_loss(
            log_probs=log_probs_e,
            targets=transcript_e,
            input_length=encoded_len_e,
            target_length=transcript_len_e,
        )

        # create corresponding eval callback; tag is the manifest basename.
        tagname = os.path.basename(args.eval_datasets[i]).split(".")[0]
        eval_callback = nemo.core.EvaluatorCallback(
            eval_tensors=[
                loss_e,
                predictions_e,
                transcript_e,
                transcript_len_e,
            ],
            user_iter_callback=partial(process_evaluation_batch, labels=vocab),
            user_epochs_done_callback=partial(
                process_evaluation_epoch, tag=tagname
            ),
            eval_step=args.eval_freq,
            tb_writer=neural_factory.tb_writer,
        )

        callbacks.append(eval_callback)
    return callbacks
||||
|
||||
|
||||
def main():
    """Evaluate a pretrained Jasper checkpoint on the configured datasets.

    NOTE: the experiment-naming / work-dir plumbing from the training
    script was removed here (commented-out); evaluation runs with the
    factory defaults below.
    """
    args = parse_args()

    # instantiate Neural Factory with supported backend
    neural_factory = nemo.core.NeuralModuleFactory(
        placement=nemo.core.DeviceType.GPU,
        backend=nemo.core.Backend.PyTorch,
    )
    args.num_gpus = neural_factory.world_size

    if args.local_rank is not None:
        logging.info("Doing ALL GPU")

    # build dags
    callbacks = create_all_dags(args, neural_factory)
    # evaluate model
    neural_factory.eval(callbacks=callbacks)


if __name__ == "__main__":
    main()
|
||||
61
src/plume/models/marblenet_nemo/featurizer.py
Normal file
61
src/plume/models/marblenet_nemo/featurizer.py
Normal file
@@ -0,0 +1,61 @@
|
||||
# import math
|
||||
|
||||
# import librosa
|
||||
import torch
|
||||
import pickle
|
||||
|
||||
# import torch.nn as nn
|
||||
# from torch_stft import STFT
|
||||
|
||||
# from nemo import logging
|
||||
from nemo.collections.asr.parts.perturb import AudioAugmentor
|
||||
|
||||
# from nemo.collections.asr.parts.segment import AudioSegment
|
||||
|
||||
|
||||
class RpycWaveformFeaturizer(object):
    """Waveform featurizer that fetches audio samples from a remote rpyc
    data service instead of reading files locally.

    Intended as a drop-in replacement for NeMo's WaveformFeaturizer when
    the audio files live on the machine running the ``serve_data`` service
    (see ``ASRDataService.exposed_get_path_samples``).
    """

    def __init__(
        self,
        sample_rate=16000,
        int_values=False,
        augmentor=None,
        rpyc_conn=None,
    ):
        # Fall back to a default (no-op) augmentor when none is supplied.
        self.augmentor = (
            augmentor if augmentor is not None else AudioAugmentor()
        )
        self.sample_rate = sample_rate
        self.int_values = int_values
        # Remote callable that returns pickled sample arrays for a path.
        # NOTE(review): assumes rpyc_conn is a connected service root —
        # a None rpyc_conn raises AttributeError here.
        self.remote_path_samples = rpyc_conn.get_path_samples

    def max_augmentation_length(self, length):
        """Return the worst-case sample length after augmentation."""
        return self.augmentor.max_augmentation_length(length)

    def process(self, file_path, offset=0, duration=0, trim=False):
        """Fetch ``file_path`` from the remote service and return its samples
        (resampled to ``self.sample_rate``) as a float tensor."""
        audio = self.remote_path_samples(
            file_path,
            target_sr=self.sample_rate,
            int_values=self.int_values,
            offset=offset,
            duration=duration,
            trim=trim,
        )
        # The remote side pickles a numpy array of samples (bytes transfer
        # much faster over rpyc than proxied arrays).
        return torch.tensor(pickle.loads(audio), dtype=torch.float)

    def process_segment(self, audio_segment):
        # Apply augmentation in place, then convert to a tensor.
        # NOTE(review): NeMo's WaveformFeaturizer converts
        # ``audio_segment.samples`` here; passing the segment object itself
        # to torch.tensor looks suspect — confirm against callers.
        self.augmentor.perturb(audio_segment)
        return torch.tensor(audio_segment, dtype=torch.float)

    @classmethod
    def from_config(cls, input_config, perturbation_configs=None):
        """Build a featurizer from a config dict.

        NOTE(review): ``rpyc_conn`` is not forwarded here, so instances
        built via ``from_config`` have no remote connection and will fail
        in ``__init__`` — confirm whether this path is ever used.
        """
        if perturbation_configs is not None:
            aa = AudioAugmentor.from_config(perturbation_configs)
        else:
            aa = None

        sample_rate = input_config.get("sample_rate", 16000)
        int_values = input_config.get("int_values", False)

        return cls(
            sample_rate=sample_rate, int_values=int_values, augmentor=aa
        )
54
src/plume/models/marblenet_nemo/serve.py
Normal file
54
src/plume/models/marblenet_nemo/serve.py
Normal file
@@ -0,0 +1,54 @@
|
||||
import os
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from rpyc.utils.server import ThreadedServer
|
||||
import typer
|
||||
|
||||
# from .asr import JasperASR
|
||||
from ...utils.serve import ASRService
|
||||
from plume.utils import lazy_callable
|
||||
|
||||
JasperASR = lazy_callable("plume.models.jasper_nemo.asr.JasperASR")
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
def rpyc(
    encoder_path: Path = "/path/to/encoder.pt",
    decoder_path: Path = "/path/to/decoder.pt",
    model_yaml_path: Path = "/path/to/model.yaml",
    port: int = int(os.environ.get("ASR_RPYC_PORT", "8044")),
):
    """Serve a Jasper ASR model over rpyc on ``port``.

    Validates that all three model files exist, builds the ASR service and
    blocks in the rpyc server loop.
    """
    # Fixed: logging was configured only after the missing-file checks, so
    # those logging.info calls were dropped at the default WARNING level.
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )
    for p in [encoder_path, decoder_path, model_yaml_path]:
        if not p.exists():
            # Fixed message grammar ("doesn't exists" -> "doesn't exist")
            # and severity: a missing model file aborts the server.
            logging.error(f"{p} doesn't exist")
            return
    asr = JasperASR(str(model_yaml_path), str(encoder_path), str(decoder_path))
    service = ASRService(asr)
    logging.info("starting asr server...")
    t = ThreadedServer(service, port=port)
    t.start()
||||
|
||||
|
||||
@app.command()
def rpyc_dir(
    model_dir: Path, port: int = int(os.environ.get("ASR_RPYC_PORT", "8044"))
):
    """Serve a model from ``model_dir`` laid out as encoder.pt / decoder.pt /
    model.yaml, delegating to :func:`rpyc`."""
    # Fixed: the encoder and decoder filenames were swapped (encoder_path
    # pointed at decoder.pt and vice versa), so checkpoints were restored
    # into the wrong modules.
    encoder_path = model_dir / Path("encoder.pt")
    decoder_path = model_dir / Path("decoder.pt")
    model_yaml_path = model_dir / Path("model.yaml")
    rpyc(encoder_path, decoder_path, model_yaml_path, port)
|
||||
|
||||
def main():
    """CLI entry point: dispatch to the typer app."""
    app()


if __name__ == "__main__":
    main()
59
src/plume/models/marblenet_nemo/serve_data.py
Normal file
59
src/plume/models/marblenet_nemo/serve_data.py
Normal file
@@ -0,0 +1,59 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import typer
|
||||
import rpyc
|
||||
from rpyc.utils.server import ThreadedServer
|
||||
import nemo
|
||||
import pickle
|
||||
|
||||
# import nemo.collections.asr as nemo_asr
|
||||
from nemo.collections.asr.parts.segment import AudioSegment
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
nemo.core.NeuralModuleFactory(
|
||||
backend=nemo.core.Backend.PyTorch, placement=nemo.core.DeviceType.CPU
|
||||
)
|
||||
|
||||
|
||||
class ASRDataService(rpyc.Service):
    """rpyc service that streams audio data to remote trainers.

    Server-side counterpart of ``RpycWaveformFeaturizer`` /
    ``RpycAudioToTextDataLayer``: it runs on the machine that holds the
    audio files.
    """

    def exposed_get_path_samples(
        self, file_path, target_sr, int_values, offset, duration, trim
    ):
        """Load ``file_path``, resample to ``target_sr``, and return the
        sample array *pickled* — bytes transfer over rpyc far faster than
        proxied numpy arrays."""
        print(f"loading.. {file_path}")
        audio = AudioSegment.from_file(
            file_path,
            target_sr=target_sr,
            int_values=int_values,
            offset=offset,
            duration=duration,
            trim=trim,
        )
        return pickle.dumps(audio.samples)

    def exposed_read_path(self, file_path):
        """Return the raw bytes of ``file_path``."""
        return Path(file_path).read_bytes()
|
||||
|
||||
@app.command()
def run_server(port: int = 0):
    """Start the ASR data rpyc server.

    Port precedence: explicit ``--port``, then the ``ASR_DARA_RPYC_PORT``
    environment variable, then 8064.
    NOTE(review): the env var name looks like a typo of ASR_DATA_RPYC_PORT;
    renaming would break existing deployments, so it is only flagged here.
    """
    listen_port = (
        port if port else int(os.environ.get("ASR_DARA_RPYC_PORT", "8064"))
    )
    service = ASRDataService()
    # allow_all_attrs lets clients access arbitrary attributes on proxied
    # objects — acceptable only on a trusted network.
    t = ThreadedServer(
        service, port=listen_port, protocol_config={"allow_all_attrs": True}
    )
    typer.echo(f"starting asr server on {listen_port}...")
    t.start()
|
||||
|
||||
def main():
    """CLI entry point: dispatch to the typer app."""
    app()


if __name__ == "__main__":
    main()
392
src/plume/models/marblenet_nemo/train.py
Normal file
392
src/plume/models/marblenet_nemo/train.py
Normal file
@@ -0,0 +1,392 @@
|
||||
# Copyright (c) 2019 NVIDIA Corporation
|
||||
import argparse
|
||||
import copy
|
||||
import math
|
||||
import os
|
||||
from pathlib import Path
|
||||
from functools import partial
|
||||
|
||||
from ruamel.yaml import YAML
|
||||
|
||||
import nemo
|
||||
import nemo.collections.asr as nemo_asr
|
||||
import nemo.utils.argparse as nm_argparse
|
||||
from nemo.collections.asr.helpers import (
|
||||
monitor_asr_train_progress,
|
||||
process_evaluation_batch,
|
||||
process_evaluation_epoch,
|
||||
)
|
||||
|
||||
from nemo.utils.lr_policies import CosineAnnealing
|
||||
from .data_loaders import RpycAudioToTextDataLayer
|
||||
|
||||
logging = nemo.logging
|
||||
|
||||
|
||||
def parse_args():
    """Parse CLI args: NeMo's standard trainer args plus Jasper-specific ones.

    Raises:
        ValueError: if neither --max_steps nor --num_epochs resolves to a
            value (set_defaults installs num_epochs=300, so this normally
            cannot trigger).
    """
    parser = argparse.ArgumentParser(
        parents=[nm_argparse.NemoArgParser()],
        description="Jasper",
        conflict_handler="resolve",  # allow re-declaring inherited options
    )
    # Project defaults layered over NemoArgParser's. Per argparse semantics
    # these parser-level defaults override argument-level defaults declared
    # below (e.g. exp_name stays "jasper-speller", warmup_steps stays 3).
    parser.set_defaults(
        checkpoint_dir=None,
        optimizer="novograd",
        batch_size=64,
        eval_batch_size=64,
        lr=0.002,
        amp_opt_level="O1",
        create_tb_writer=True,
        model_config="./train/jasper10x5dr.yaml",
        work_dir="./train/work",
        num_epochs=300,
        weight_decay=0.005,
        checkpoint_save_freq=100,
        eval_freq=100,
        load_dir="./train/models/jasper/",
        warmup_steps=3,
        exp_name="jasper-speller",
    )

    # Overwrite default args
    parser.add_argument(
        "--max_steps",
        type=int,
        default=None,
        required=False,
        help="max number of steps to train",
    )
    parser.add_argument(
        "--num_epochs",
        type=int,
        required=False,
        help="number of epochs to train",
    )
    parser.add_argument(
        "--model_config",
        type=str,
        required=False,
        help="model configuration file: model.yaml",
    )
    parser.add_argument(
        "--remote_data",
        type=str,
        required=False,
        default="",
        help="remote dataloader endpoint",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        required=False,
        default="",
        help="dataset directory containing train/test manifests",
    )

    # Create new args
    parser.add_argument("--exp_name", default="Jasper", type=str)
    parser.add_argument("--beta1", default=0.95, type=float)
    # NOTE(review): 0.25 is unusually low for an Adam-style beta2 — confirm.
    parser.add_argument("--beta2", default=0.25, type=float)
    parser.add_argument("--warmup_steps", default=0, type=int)
    parser.add_argument(
        "--load_dir",
        default=None,
        type=str,
        help="directory with pre-trained checkpoint",
    )

    args = parser.parse_args()
    if args.max_steps is None and args.num_epochs is None:
        raise ValueError("Either max_steps or num_epochs should be provided.")
    return args
|
||||
|
||||
def construct_name(
    name, lr, batch_size, max_steps, num_epochs, wd, optimizer, iter_per_step
):
    """Build a run identifier that encodes the key hyperparameters.

    The duration segment is ``s_<max_steps>`` when ``max_steps`` is given,
    otherwise ``e_<num_epochs>``.
    """
    duration = (
        f"s_{max_steps}" if max_steps is not None else f"e_{num_epochs}"
    )
    parts = [
        name,
        f"lr_{lr}",
        f"bs_{batch_size}",
        duration,
        f"wd_{wd}",
        f"opt_{optimizer}",
        f"ips_{iter_per_step}",
    ]
    return "-".join(parts)
|
||||
|
||||
def create_all_dags(args, neural_factory):
    """Build the Jasper train DAG plus one eval DAG per eval manifest.

    Returns:
        tuple: ``(train_loss_tensor, callbacks, steps_per_epoch)`` as
        consumed by ``neural_factory.train`` in :func:`main`.
    """
    yaml = YAML(typ="safe")
    with open(args.model_config) as f:
        jasper_params = yaml.load(f)
    vocab = jasper_params["labels"]
    sample_rate = jasper_params["sample_rate"]

    # Calculate num_workers for dataloader (at least one per process).
    total_cpus = os.cpu_count()
    cpu_per_traindl = max(int(total_cpus / neural_factory.world_size), 1)
    # perturb_config = jasper_params.get('perturb', None)
    train_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"])
    train_dl_params.update(jasper_params["AudioToTextDataLayer"]["train"])
    del train_dl_params["train"]
    del train_dl_params["eval"]

    # A dataset directory supplies default train/test manifest paths.
    if args.dataset:
        d_path = Path(args.dataset)
        if not args.train_dataset:
            args.train_dataset = str(d_path / Path("train_manifest.json"))
        if not args.eval_datasets:
            args.eval_datasets = [str(d_path / Path("test_manifest.json"))]

    data_loader_layer = nemo_asr.AudioToTextDataLayer

    # With --remote_data, audio is streamed from an rpyc data server.
    if args.remote_data:
        train_dl_params["rpyc_host"] = args.remote_data
        data_loader_layer = RpycAudioToTextDataLayer

    data_layer = data_loader_layer(
        manifest_filepath=args.train_dataset,
        sample_rate=sample_rate,
        labels=vocab,
        batch_size=args.batch_size,
        num_workers=cpu_per_traindl,
        **train_dl_params,
        # normalize_transcripts=False
    )

    # Steps per epoch accounts for gradient accumulation and data
    # parallelism (iter_per_step batches per optimizer step, per GPU).
    N = len(data_layer)
    steps_per_epoch = math.ceil(
        N / (args.batch_size * args.iter_per_step * args.num_gpus)
    )
    logging.info("Have {0} examples to train on.".format(N))

    data_preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor(
        sample_rate=sample_rate,
        **jasper_params["AudioToMelSpectrogramPreprocessor"],
    )

    # Optional batch multiplication / SpecAugment, enabled by config keys.
    multiply_batch_config = jasper_params.get("MultiplyBatch", None)
    if multiply_batch_config:
        multiply_batch = nemo_asr.MultiplyBatch(**multiply_batch_config)

    spectr_augment_config = jasper_params.get("SpectrogramAugmentation", None)
    if spectr_augment_config:
        data_spectr_augmentation = nemo_asr.SpectrogramAugmentation(
            **spectr_augment_config
        )

    eval_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"])
    eval_dl_params.update(jasper_params["AudioToTextDataLayer"]["eval"])
    if args.remote_data:
        eval_dl_params["rpyc_host"] = args.remote_data
    del eval_dl_params["train"]
    del eval_dl_params["eval"]
    data_layers_eval = []

    if args.eval_datasets:
        for eval_datasets in args.eval_datasets:
            data_layer_eval = data_loader_layer(
                manifest_filepath=eval_datasets,
                sample_rate=sample_rate,
                labels=vocab,
                batch_size=args.eval_batch_size,
                num_workers=cpu_per_traindl,
                **eval_dl_params,
            )

            data_layers_eval.append(data_layer_eval)
    else:
        logging.warning("There were no val datasets passed")

    jasper_encoder = nemo_asr.JasperEncoder(
        feat_in=jasper_params["AudioToMelSpectrogramPreprocessor"]["features"],
        **jasper_params["JasperEncoder"],
    )

    jasper_decoder = nemo_asr.JasperDecoderForCTC(
        feat_in=jasper_params["JasperEncoder"]["jasper"][-1]["filters"],
        num_classes=len(vocab),
    )

    ctc_loss = nemo_asr.CTCLossNM(num_classes=len(vocab))

    greedy_decoder = nemo_asr.GreedyCTCDecoder()

    logging.info("================================")
    logging.info(
        f"Number of parameters in encoder: {jasper_encoder.num_weights}"
    )
    logging.info(
        f"Number of parameters in decoder: {jasper_decoder.num_weights}"
    )
    logging.info(
        f"Total number of parameters in model: "
        f"{jasper_decoder.num_weights + jasper_encoder.num_weights}"
    )
    logging.info("================================")

    # Train DAG: audio -> mel features [-> multiply -> SpecAugment]
    # -> encoder -> CTC decoder -> loss.
    (
        audio_signal_t,
        a_sig_length_t,
        transcript_t,
        transcript_len_t,
    ) = data_layer()
    processed_signal_t, p_length_t = data_preprocessor(
        input_signal=audio_signal_t, length=a_sig_length_t
    )

    if multiply_batch_config:
        (
            processed_signal_t,
            p_length_t,
            transcript_t,
            transcript_len_t,
        ) = multiply_batch(
            in_x=processed_signal_t,
            in_x_len=p_length_t,
            in_y=transcript_t,
            in_y_len=transcript_len_t,
        )

    if spectr_augment_config:
        processed_signal_t = data_spectr_augmentation(
            input_spec=processed_signal_t
        )

    encoded_t, encoded_len_t = jasper_encoder(
        audio_signal=processed_signal_t, length=p_length_t
    )
    log_probs_t = jasper_decoder(encoder_output=encoded_t)
    predictions_t = greedy_decoder(log_probs=log_probs_t)
    loss_t = ctc_loss(
        log_probs=log_probs_t,
        targets=transcript_t,
        input_length=encoded_len_t,
        target_length=transcript_len_t,
    )

    # Callbacks needed to print info to console and Tensorboard
    train_callback = nemo.core.SimpleLossLoggerCallback(
        tensors=[loss_t, predictions_t, transcript_t, transcript_len_t],
        print_func=partial(monitor_asr_train_progress, labels=vocab),
        get_tb_values=lambda x: [("loss", x[0])],
        tb_writer=neural_factory.tb_writer,
    )

    chpt_callback = nemo.core.CheckpointCallback(
        folder=neural_factory.checkpoint_dir,
        load_from_folder=args.load_dir,
        step_freq=args.checkpoint_save_freq,
        checkpoints_to_keep=30,
    )

    callbacks = [train_callback, chpt_callback]

    # assemble eval DAGs: same pipeline as training, minus augmentation.
    for i, eval_dl in enumerate(data_layers_eval):
        (
            audio_signal_e,
            a_sig_length_e,
            transcript_e,
            transcript_len_e,
        ) = eval_dl()
        processed_signal_e, p_length_e = data_preprocessor(
            input_signal=audio_signal_e, length=a_sig_length_e
        )
        encoded_e, encoded_len_e = jasper_encoder(
            audio_signal=processed_signal_e, length=p_length_e
        )
        log_probs_e = jasper_decoder(encoder_output=encoded_e)
        predictions_e = greedy_decoder(log_probs=log_probs_e)
        loss_e = ctc_loss(
            log_probs=log_probs_e,
            targets=transcript_e,
            input_length=encoded_len_e,
            target_length=transcript_len_e,
        )

        # create corresponding eval callback; tag is the manifest basename.
        tagname = os.path.basename(args.eval_datasets[i]).split(".")[0]
        eval_callback = nemo.core.EvaluatorCallback(
            eval_tensors=[
                loss_e,
                predictions_e,
                transcript_e,
                transcript_len_e,
            ],
            user_iter_callback=partial(process_evaluation_batch, labels=vocab),
            user_epochs_done_callback=partial(
                process_evaluation_epoch, tag=tagname
            ),
            eval_step=args.eval_freq,
            tb_writer=neural_factory.tb_writer,
        )

        callbacks.append(eval_callback)
    return loss_t, callbacks, steps_per_epoch
|
||||
|
||||
def main():
    """Train Jasper: parse args, set up the NeMo factory, build the DAGs
    and run training with cosine-annealed LR."""
    args = parse_args()
    name = construct_name(
        args.exp_name,
        args.lr,
        args.batch_size,
        args.max_steps,
        args.num_epochs,
        args.weight_decay,
        args.optimizer,
        args.iter_per_step,
    )
    log_dir = name
    if args.work_dir:
        log_dir = os.path.join(args.work_dir, name)

    # instantiate Neural Factory with supported backend
    neural_factory = nemo.core.NeuralModuleFactory(
        backend=nemo.core.Backend.PyTorch,
        local_rank=args.local_rank,
        optimization_level=args.amp_opt_level,
        log_dir=log_dir,
        checkpoint_dir=args.checkpoint_dir,
        create_tb_writer=args.create_tb_writer,
        files_to_copy=[args.model_config, __file__],
        cudnn_benchmark=args.cudnn_benchmark,
        tensorboard_dir=args.tensorboard_dir,
    )
    # create_all_dags needs the world size to compute steps_per_epoch.
    args.num_gpus = neural_factory.world_size

    checkpoint_dir = neural_factory.checkpoint_dir  # NOTE(review): unused
    if args.local_rank is not None:
        logging.info("Doing ALL GPU")

    # build dags
    train_loss, callbacks, steps_per_epoch = create_all_dags(
        args, neural_factory
    )
    # train model; anneal over max_steps when given, else over all epochs.
    neural_factory.train(
        tensors_to_optimize=[train_loss],
        callbacks=callbacks,
        lr_policy=CosineAnnealing(
            args.max_steps
            if args.max_steps is not None
            else args.num_epochs * steps_per_epoch,
            warmup_steps=args.warmup_steps,
        ),
        optimizer=args.optimizer,
        optimization_params={
            "num_epochs": args.num_epochs,
            "max_steps": args.max_steps,
            "lr": args.lr,
            "betas": (args.beta1, args.beta2),
            "weight_decay": args.weight_decay,
            "grad_norm_clip": None,
        },
        batches_per_step=args.iter_per_step,
    )


if __name__ == "__main__":
    main()
22
src/plume/models/marblenet_nemo/trial.py
Normal file
22
src/plume/models/marblenet_nemo/trial.py
Normal file
@@ -0,0 +1,22 @@
|
||||
# WIP scratch script: load the pretrained MarbleNet VAD model and grab its
# config for experimentation (several imports are kept for interactive use).
import numpy as np
import os
import time
import copy

from omegaconf import OmegaConf
import matplotlib.pyplot as plt
import IPython.display as ipd
# import pyaudio as pa
import librosa
import nemo
import nemo.collections.asr as nemo_asr

# sample rate, Hz
SAMPLE_RATE = 16000

# Download/cache the pretrained MarbleNet voice-activity-detection model.
vad_model = nemo_asr.models.EncDecClassificationModel.from_pretrained(
    "vad_marblenet"
)
# Preserve a copy of the full config (._cfg is a private NeMo attribute).
cfg = copy.deepcopy(vad_model._cfg)
# print(OmegaConf.to_yaml(cfg))
0
src/plume/models/wav2vec2/__init__.py
Normal file
0
src/plume/models/wav2vec2/__init__.py
Normal file
204
src/plume/models/wav2vec2/asr.py
Normal file
204
src/plume/models/wav2vec2/asr.py
Normal file
@@ -0,0 +1,204 @@
|
||||
from io import BytesIO
|
||||
import warnings
|
||||
import itertools as it
|
||||
|
||||
import torch
|
||||
import soundfile as sf
|
||||
import torch.nn.functional as F
|
||||
|
||||
try:
|
||||
from fairseq import utils
|
||||
from fairseq.models import BaseFairseqModel
|
||||
from fairseq.data import Dictionary
|
||||
from fairseq.models.wav2vec.wav2vec2_asr import base_architecture, Wav2VecEncoder
|
||||
except ModuleNotFoundError:
|
||||
warnings.warn("Install fairseq")
|
||||
try:
|
||||
from wav2letter.decoder import CriterionType
|
||||
from wav2letter.criterion import CpuViterbiPath, get_data_ptr_as_bytes
|
||||
except ModuleNotFoundError:
|
||||
warnings.warn("Install wav2letter")
|
||||
|
||||
|
||||
class Wav2VecCtc(BaseFairseqModel):
    """fairseq wav2vec 2.0 encoder with a CTC output head, for ASR inference."""

    def __init__(self, w2v_encoder, args):
        super().__init__()
        self.w2v_encoder = w2v_encoder  # Wav2VecEncoder producing frame logits
        self.args = args

    def upgrade_state_dict_named(self, state_dict, name):
        # No custom key renaming needed; delegate to the base class.
        super().upgrade_state_dict_named(state_dict, name)
        return state_dict

    @classmethod
    def build_model(cls, args, target_dict):
        """Build a new model instance."""
        base_architecture(args)  # fill in any missing architecture defaults
        w2v_encoder = Wav2VecEncoder(args, target_dict)
        return cls(w2v_encoder, args)

    def get_normalized_probs(self, net_output, log_probs):
        """Get normalized probabilities (or log probs) from a net's output."""
        logits = net_output["encoder_out"]
        if log_probs:
            return utils.log_softmax(logits.float(), dim=-1)
        else:
            return utils.softmax(logits.float(), dim=-1)

    def forward(self, **kwargs):
        # Thin pass-through to the wav2vec encoder.
        x = self.w2v_encoder(**kwargs)
        return x
|
||||
|
||||
class W2lDecoder(object):
    """Base class for wav2letter-style CTC decoders.

    Holds the target dictionary, the CTC blank index, and shared helpers
    for running the acoustic model and normalizing its output tokens.
    """

    def __init__(self, tgt_dict):
        self.tgt_dict = tgt_dict
        self.vocab_size = len(tgt_dict)
        self.nbest = 1

        self.criterion_type = CriterionType.CTC
        # Prefer an explicit <ctc_blank> symbol; fall back to BOS.
        if "<ctc_blank>" in tgt_dict.indices:
            self.blank = tgt_dict.index("<ctc_blank>")
        else:
            self.blank = tgt_dict.bos()
        self.asg_transitions = None

    def generate(self, model, sample, **unused):
        """Generate a batch of inferences."""
        # model.forward normally channels prev_output_tokens into the
        # decoder separately, but SequenceGenerator calls model.encoder
        # directly, so drop it from the net input.
        encoder_input = {
            key: value
            for key, value in sample["net_input"].items()
            if key != "prev_output_tokens"
        }
        return self.decode(self.get_emissions(model, encoder_input))

    def get_emissions(self, model, encoder_input):
        """Run the encoder and return (batch, time, vocab) log-probs on CPU."""
        net_output = model(**encoder_input)
        if self.criterion_type == CriterionType.CTC:
            emissions = model.get_normalized_probs(net_output, log_probs=True)

        return emissions.transpose(0, 1).float().cpu().contiguous()

    def get_tokens(self, idxs):
        """Normalize tokens by collapsing CTC repeats and dropping blanks."""
        collapsed = [key for key, _group in it.groupby(idxs)]
        kept = [token for token in collapsed if token != self.blank]
        return torch.LongTensor(kept)
|
||||
|
||||
class W2lViterbiDecoder(W2lDecoder):
    """Greedy (Viterbi, no language model) CTC decoder backed by
    wav2letter's C++ ``CpuViterbiPath`` kernel."""

    def __init__(self, tgt_dict):
        super().__init__(tgt_dict)

    def decode(self, emissions):
        """Run Viterbi over ``emissions`` (batch, time, vocab) and return,
        per utterance, a single-hypothesis list of normalized token ids."""
        B, T, N = emissions.size()
        hypos = list()  # NOTE(review): unused; results built in the return

        # Zero transitions == plain CTC best path; asg_transitions stays
        # None for CTC (see W2lDecoder.__init__).
        if self.asg_transitions is None:
            transitions = torch.FloatTensor(N, N).zero_()
        else:
            transitions = torch.FloatTensor(self.asg_transitions).view(N, N)

        viterbi_path = torch.IntTensor(B, T)
        workspace = torch.ByteTensor(CpuViterbiPath.get_workspace_size(B, T, N))
        # The C++ kernel writes the best path into viterbi_path in place.
        CpuViterbiPath.compute(
            B,
            T,
            N,
            get_data_ptr_as_bytes(emissions),
            get_data_ptr_as_bytes(transitions),
            get_data_ptr_as_bytes(viterbi_path),
            get_data_ptr_as_bytes(workspace),
        )
        return [
            [{"tokens": self.get_tokens(viterbi_path[b].tolist()), "score": 0}]
            for b in range(B)
        ]
|
||||
|
||||
def post_process(sentence: str, symbol: str):
|
||||
if symbol == "sentencepiece":
|
||||
sentence = sentence.replace(" ", "").replace("\u2581", " ").strip()
|
||||
elif symbol == "wordpiece":
|
||||
sentence = sentence.replace(" ", "").replace("_", " ").strip()
|
||||
elif symbol == "letter":
|
||||
sentence = sentence.replace(" ", "").replace("|", " ").strip()
|
||||
elif symbol == "_EOW":
|
||||
sentence = sentence.replace(" ", "").replace("_EOW", " ").strip()
|
||||
elif symbol is not None and symbol != "none":
|
||||
sentence = (sentence + " ").replace(symbol, "").rstrip()
|
||||
return sentence
|
||||
|
||||
|
||||
def get_feature(filepath):
    """Load an audio file and return a normalized 1-D float feature tensor.

    Reads the waveform with soundfile, moves it to GPU when available,
    averages multi-channel audio down to mono and applies layer norm
    over the full signal.
    """

    def postprocess(feats, sample_rate):
        # BUG FIX: the original wrote ``feats.dim == 2`` which compares the
        # bound method (never equal to 2) instead of calling it, so stereo
        # input was never averaged down to mono and the assert below fired.
        if feats.dim() == 2:
            feats = feats.mean(-1)

        assert feats.dim() == 1, feats.dim()

        with torch.no_grad():
            feats = F.layer_norm(feats, feats.shape)
        return feats

    wav, sample_rate = sf.read(filepath)
    feats = torch.from_numpy(wav).float()
    if torch.cuda.is_available():
        feats = feats.cuda()
    return postprocess(feats, sample_rate)
|
||||
|
||||
|
||||
def load_model(ctc_model_path, w2v_model_path, target_dict):
    """Build a Wav2VecCtc model from a finetuned CTC checkpoint.

    Loads the checkpoint, repoints its args at the given base wav2vec
    model, restores the weights strictly, and moves the model to GPU
    when one is available.
    """
    checkpoint = torch.load(ctc_model_path)
    checkpoint["args"].w2v_path = w2v_model_path
    model = Wav2VecCtc.build_model(checkpoint["args"], target_dict)
    model.load_state_dict(checkpoint["model"], strict=True)
    return model.cuda() if torch.cuda.is_available() else model
|
||||
|
||||
|
||||
class Wav2Vec2ASR(object):
    """Fairseq wav2vec2 CTC transcriber with greedy Viterbi decoding."""

    def __init__(self, ctc_path, w2v_path, target_dict_path):
        super(Wav2Vec2ASR, self).__init__()
        self.target_dict = Dictionary.load(target_dict_path)

        self.model = load_model(ctc_path, w2v_path, self.target_dict)
        self.model.eval()

        self.generator = W2lViterbiDecoder(self.target_dict)

    def transcribe(self, audio_data, greedy=True):
        """Decode raw wav bytes into a transcript string.

        ``greedy`` is accepted for interface compatibility; only greedy
        Viterbi decoding is implemented.
        """
        feature = get_feature(BytesIO(audio_data))
        source = feature.unsqueeze(0)

        # Single-utterance batch: nothing is padding.
        padding_mask = (
            torch.BoolTensor(source.size(1)).fill_(False).unsqueeze(0)
        )
        if torch.cuda.is_available():
            padding_mask = padding_mask.cuda()

        sample = {
            "net_input": {"source": source, "padding_mask": padding_mask}
        }

        with torch.no_grad():
            hypo = self.generator.generate(self.model, sample, prefix_tokens=None)
        hyp_pieces = self.target_dict.string(hypo[0][0]["tokens"].int().cpu())
        return post_process(hyp_pieces, "letter")
|
||||
234
src/plume/models/wav2vec2/data.py
Normal file
234
src/plume/models/wav2vec2/data.py
Normal file
@@ -0,0 +1,234 @@
|
||||
from pathlib import Path
|
||||
from collections import Counter
|
||||
import shutil
|
||||
import io
|
||||
|
||||
# from time import time
|
||||
|
||||
# import pydub
|
||||
import typer
|
||||
from tqdm import tqdm
|
||||
|
||||
from plume.utils import (
|
||||
ExtendedPath,
|
||||
replace_redundant_spaces_with,
|
||||
lazy_module,
|
||||
random_segs,
|
||||
parallel_apply,
|
||||
batch,
|
||||
run_shell,
|
||||
)
|
||||
|
||||
from plume.utils.vad import VADUtterance
|
||||
|
||||
soundfile = lazy_module("soundfile")
|
||||
pydub = lazy_module("pydub")
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
def export_jasper(src_dataset_path: Path, dest_dataset_path: Path, unlink: bool = True):
    """Convert a jasper-style manifest dataset into fairseq tsv/ltr format.

    Writes ``{train,valid}.tsv`` / ``.ltr`` plus ``dict.ltr.txt`` under
    ``dest_dataset_path``. When ``unlink`` is set, audio is resampled to
    16 kHz mono copies under ``dest_dataset_path/wavs`` and the tsv root
    points there; otherwise the tsv root stays at ``src_dataset_path``.
    """
    dict_ltr = dest_dataset_path / Path("dict.ltr.txt")
    (dest_dataset_path / Path("wavs")).mkdir(exist_ok=True, parents=True)
    tok_counter = Counter()
    # fairseq expects a "valid" split; reuse the test manifest for it.
    shutil.copy(
        src_dataset_path / Path("test_manifest.json"),
        src_dataset_path / Path("valid_manifest.json"),
    )
    if unlink:
        src_wavs = src_dataset_path / Path("wavs")
        for wav_path in tqdm(list(src_wavs.glob("**/*.wav"))):
            resampled = (
                pydub.AudioSegment.from_wav(wav_path)
                .set_frame_rate(16000)
                .set_channels(1)
            )
            resampled.export(
                dest_dataset_path / Path("wavs") / Path(wav_path.name),
                format="wav",
            )

    for dataset_kind in ["train", "valid"]:
        manifest_path = ExtendedPath(
            src_dataset_path / Path(f"{dataset_kind}_manifest.json")
        )
        entries = list(manifest_path.read_jsonl())
        out_tsv = dest_dataset_path / Path(f"{dataset_kind}.tsv")
        out_ltr = dest_dataset_path / Path(f"{dataset_kind}.ltr")
        with out_tsv.open("w") as tsv_f, out_ltr.open("w") as ltr_f:
            root = dest_dataset_path if unlink else src_dataset_path
            tsv_f.write(f"{root}\n")
            for entry in entries:
                audio_fname = entry["audio_filepath"]
                # Uppercase letters with "|" as the word boundary token.
                pipe_toks = replace_redundant_spaces_with(entry["text"], "|").upper()
                tok_counter.update(pipe_toks)
                ltr_f.write(" ".join(pipe_toks) + " |\n")
                frame_count = soundfile.info(audio_fname).frames
                rel_path = Path(audio_fname).relative_to(src_dataset_path.absolute())
                tsv_f.write(f"{rel_path}\t{frame_count}\n")
    with dict_ltr.open("w") as d_f:
        for tok, count in tok_counter.most_common():
            d_f.write(f"{tok} {count}\n")
    # Drop the temporary valid manifest created above.
    (src_dataset_path / Path("valid_manifest.json")).unlink()
|
||||
|
||||
|
||||
@app.command()
def set_root(dataset_path: Path, root_path: Path):
    """Rewrite the root-directory header line of the train/valid tsv files."""
    for split in ("train", "valid"):
        tsv_path = dataset_path / Path(split).with_suffix(".tsv")
        with tsv_path.open("r") as fh:
            contents = fh.readlines()
        # The first tsv line is the dataset root all relative paths hang off.
        contents[0] = str(root_path) + "\n"
        with tsv_path.open("w") as fh:
            fh.writelines(contents)
|
||||
|
||||
|
||||
@app.command()
def convert_audio(log_dir: Path, out_dir: Path):
    """Resample every wav under ``log_dir`` to 16 kHz mono wavs in ``out_dir``.

    Files whose names already exist in ``out_dir`` are skipped; the
    ffmpeg calls are fanned out over a large worker pool.
    """
    out_dir.mkdir(exist_ok=True, parents=True)
    source_wavs = list(log_dir.glob("**/*.wav"))
    by_name = {p.name: p.absolute() for p in source_wavs}
    done_names = set(p.name for p in out_dir.glob("**/*.wav"))
    pending = [by_name[name] for name in set(by_name) - done_names]

    def resample_audio(src):
        dest_wav = out_dir / src.name
        if dest_wav.exists():
            return
        run_shell(f"ffmpeg -i {src.absolute()} -ac 1 -ar 16000 {dest_wav}", verbose=False)

    parallel_apply(resample_audio, pending, workers=256)
|
||||
|
||||
|
||||
@app.command()
def prepare_pretraining(
    log_dir: Path,
    dataset_path: Path,
    format: str = "wav",
    method: str = "random",
    max_silence: int = 3000,
    min_duration: int = 10000,
    max_duration: int = 30000,
    fixed_duration: int = 30000,
    batch_size: int = 100,
):
    """Chunk raw call-log wavs into pretraining audio segments.

    Splits every wav under ``log_dir`` into segments — by VAD, random
    spans, or fixed-size windows depending on ``method`` (durations are
    in milliseconds) — and writes them under ``dataset_path/audio``.
    A touch-file per source wav in ``dataset_path/cache`` makes the
    command resumable; undecodable files are listed in
    ``dataset_path/failed.log``.
    """
    audio_dir = dataset_path / "audio"
    audio_dir.mkdir(exist_ok=True, parents=True)
    cache_dir = dataset_path / "cache"
    cache_dir.mkdir(exist_ok=True, parents=True)
    all_wavs = list(log_dir.glob("**/*.wav"))
    if method not in ["vad", "random", "fixed"]:
        # BUG FIX: the original message omitted the supported "vad" option.
        typer.echo("should be one of vad|random|fixed")
        raise typer.Exit()

    def write_seg_arg(arg):
        # Export one (segment, destination) pair through an in-memory
        # buffer so exports can run safely in worker threads.
        seg, dest_wav = arg
        ob = io.BytesIO()
        seg.export(ob, format=format)
        dest_wav.write_bytes(ob.getvalue())
        ob.close()

    with (dataset_path / "failed.log").open("w") as fl:
        vad_utt = VADUtterance(
            max_silence=max_silence,
            min_utterance=min_duration,
            max_utterance=max_duration,
        )

        def load_audio(wav_path):
            # Shared by all three strategies: return the decoded segment,
            # or None when the wav is already processed or undecodable.
            if (cache_dir / wav_path.stem).exists():
                return None
            try:
                return pydub.AudioSegment.from_file(wav_path)
            except pydub.exceptions.CouldntDecodeError:
                fl.write(wav_path.name + "\n")
                return None

        def vad_process_wav(wav_path):
            full_seg = load_audio(wav_path)
            if full_seg is None:
                return []
            audio_chunk_paths = []
            if len(full_seg) > min_duration:
                for i, chunk_seg in enumerate(vad_utt.stream_segments(full_seg)):
                    dest_wav = audio_dir / (wav_path.stem + f"_{i}.{format}")
                    if dest_wav.exists():
                        continue
                    audio_chunk_paths.append((chunk_seg, dest_wav))
            (cache_dir / wav_path.stem).touch()
            return audio_chunk_paths

        def random_process_wav(wav_path):
            full_seg = load_audio(wav_path)
            if full_seg is None:
                return []
            audio_chunk_paths = []
            if len(full_seg) > min_duration:
                segs = random_segs(len(full_seg), min_duration, max_duration)
                for i, (start, end) in enumerate(segs):
                    dest_wav = audio_dir / (wav_path.stem + f"_{i}.{format}")
                    if dest_wav.exists():
                        continue
                    audio_chunk_paths.append((full_seg[start:end], dest_wav))
            (cache_dir / wav_path.stem).touch()
            return audio_chunk_paths

        def fixed_process_wav(wav_path):
            full_seg = load_audio(wav_path)
            if full_seg is None:
                return []
            audio_chunk_paths = []
            if len(full_seg) > min_duration:
                # pydub slicing with a step yields fixed-size windows.
                for i, chunk_seg in enumerate(full_seg[::fixed_duration]):
                    dest_wav = audio_dir / (wav_path.stem + f"_{i}.{format}")
                    if dest_wav.exists() or len(chunk_seg) < min_duration:
                        continue
                    audio_chunk_paths.append((chunk_seg, dest_wav))
            (cache_dir / wav_path.stem).touch()
            return audio_chunk_paths

        # Warm up the lazy pydub import before fanning out to workers.
        pydub.AudioSegment.from_file(all_wavs[0])
        seg_f = {
            "vad": vad_process_wav,
            "random": random_process_wav,
            "fixed": fixed_process_wav,
        }[method]
        for wp_batch in tqdm(batch(all_wavs, n=batch_size)):
            acp_batch = parallel_apply(seg_f, wp_batch)
            flat_acp_batch = [pair for acp in acp_batch for pair in acp]
            parallel_apply(write_seg_arg, flat_acp_batch)
            # Release the decoded segments before the next batch.
            del flat_acp_batch
            del acp_batch
|
||||
|
||||
|
||||
def main():
    """Module CLI entry point."""
    app()


if __name__ == "__main__":
    main()
|
||||
49
src/plume/models/wav2vec2/eval.py
Normal file
49
src/plume/models/wav2vec2/eval.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from pathlib import Path
|
||||
import typer
|
||||
from tqdm import tqdm
|
||||
# import pandas as pd
|
||||
|
||||
from plume.utils import (
|
||||
asr_manifest_reader,
|
||||
discard_except_digits,
|
||||
replace_digit_symbol,
|
||||
lazy_module
|
||||
# run_shell,
|
||||
)
|
||||
from ...utils.transcribe import triton_transcribe_grpc_gen
|
||||
|
||||
pd = lazy_module('pandas')
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
def manifest(manifest_file: Path, result_file: Path = "results.csv"):
    """Score a manifest against a local triton ASR endpoint, digit-wise.

    Transcribes each entry, reduces both reference and hypothesis to
    their digit content, and writes a per-sample comparison csv next to
    the manifest.
    """
    from pydub import AudioSegment

    host = "localhost"
    port = 8044
    transcriber, audio_prep = triton_transcribe_grpc_gen(host, port, method='whole')
    result_path = manifest_file.parent / result_file
    manifest_list = list(asr_manifest_reader(manifest_file))

    def compute_frame(entry):
        audio_file = entry["audio_path"]
        orig_text = entry["text"]
        orig_num = discard_except_digits(replace_digit_symbol(orig_text))
        asr_text = transcriber(audio_prep(AudioSegment.from_file(audio_file)))
        asr_num = discard_except_digits(replace_digit_symbol(asr_text))
        return {
            "audio_file": audio_file,
            "asr_text": asr_text,
            "asr_num": asr_num,
            "orig_text": orig_text,
            "orig_num": orig_num,
            "asr_match": orig_num == asr_num,
        }

    rows = map(compute_frame, tqdm(manifest_list))
    pd.DataFrame(rows).to_csv(result_path)
|
||||
53
src/plume/models/wav2vec2/serve.py
Normal file
53
src/plume/models/wav2vec2/serve.py
Normal file
@@ -0,0 +1,53 @@
|
||||
import os
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
# from rpyc.utils.server import ThreadedServer
|
||||
import typer
|
||||
|
||||
from ...utils.serve import ASRService
|
||||
from plume.utils import lazy_callable
|
||||
# from .asr import Wav2Vec2ASR
|
||||
|
||||
ThreadedServer = lazy_callable('rpyc.utils.server.ThreadedServer')
|
||||
Wav2Vec2ASR = lazy_callable('plume.models.wav2vec2.asr.Wav2Vec2ASR')
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
def rpyc(
    w2v_path: Path = "/path/to/base.pt",
    ctc_path: Path = "/path/to/ctc.pt",
    target_dict_path: Path = "/path/to/dict.ltr.txt",
    port: int = int(os.environ.get("ASR_RPYC_PORT", "8044")),
):
    """Serve a fairseq Wav2Vec2ASR model over rpyc on ``port``."""
    # Bail out on the first missing model file.
    for candidate in [w2v_path, ctc_path, target_dict_path]:
        if not candidate.exists():
            logging.info(f"{candidate} doesn't exists")
            return
    w2vasr = Wav2Vec2ASR(str(ctc_path), str(w2v_path), str(target_dict_path))
    service = ASRService(w2vasr)
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )
    logging.info("starting asr server...")
    ThreadedServer(service, port=port).start()
|
||||
|
||||
|
||||
@app.command()
def rpyc_dir(model_dir: Path, port: int = int(os.environ.get("ASR_RPYC_PORT", "8044"))):
    """Serve a model laid out in ``model_dir`` (base.pt / ctc.pt / dict.ltr.txt)."""
    rpyc(
        model_dir / Path("base.pt"),
        model_dir / Path("ctc.pt"),
        model_dir / Path("dict.ltr.txt"),
        port,
    )
|
||||
|
||||
|
||||
def main():
    """Module CLI entry point."""
    app()


if __name__ == "__main__":
    main()
|
||||
34
src/plume/models/wav2vec2/train.py
Normal file
34
src/plume/models/wav2vec2/train.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import typer
|
||||
# from fairseq_cli.train import cli_main
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import shlex
|
||||
from plume.utils import lazy_callable
|
||||
|
||||
cli_main = lazy_callable('fairseq_cli.train.cli_main')
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
def local(dataset_path: Path):
    """Finetune wav2vec2 with CTC on ``dataset_path`` via fairseq-train.

    NOTE(review): the save dir and pretrained wav2vec checkpoint path are
    hard-coded below — adjust before running outside that environment.
    """
    args = f'''--distributed-world-size 1 {dataset_path} \
--save-dir /dataset/wav2vec2/model/wav2vec2_l_num_ctc_v1 --post-process letter --valid-subset \
valid --no-epoch-checkpoints --best-checkpoint-metric wer --num-workers 4 --max-update 80000 \
--sentence-avg --task audio_pretraining --arch wav2vec_ctc --w2v-path /dataset/wav2vec2/pretrained/wav2vec_vox_new.pt \
--labels ltr --apply-mask --mask-selection static --mask-other 0 --mask-length 10 --mask-prob 0.5 --layerdrop 0.1 \
--mask-channel-selection static --mask-channel-other 0 --mask-channel-length 64 --mask-channel-prob 0.5 \
--zero-infinity --feature-grad-mult 0.0 --freeze-finetune-updates 10000 --validate-after-updates 10000 \
--optimizer adam --adam-betas '(0.9, 0.98)' --adam-eps 1e-08 --lr 2e-05 --lr-scheduler tri_stage --warmup-steps 8000 \
--hold-steps 32000 --decay-steps 40000 --final-lr-scale 0.05 --final-dropout 0.0 --dropout 0.0 \
--activation-dropout 0.1 --criterion ctc --attention-dropout 0.0 --max-tokens 1280000 --seed 2337 --log-format json \
--log-interval 500 --ddp-backend no_c10d --reset-optimizer --normalize
'''
    # fairseq's cli_main parses sys.argv directly.
    sys.argv = ['train.py'] + shlex.split(args)
    cli_main()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Fix: route through the typer app (as every sibling module does)
    # rather than calling fairseq's cli_main with an unprepared sys.argv.
    app()
|
||||
0
src/plume/models/wav2vec2_transformers/__init__.py
Normal file
0
src/plume/models/wav2vec2_transformers/__init__.py
Normal file
39
src/plume/models/wav2vec2_transformers/asr.py
Normal file
39
src/plume/models/wav2vec2_transformers/asr.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from transformers import Wav2Vec2Tokenizer, Wav2Vec2ForCTC
|
||||
|
||||
# import soundfile as sf
|
||||
from io import BytesIO
|
||||
import torch
|
||||
|
||||
from plume.utils import lazy_module
|
||||
|
||||
sf = lazy_module("soundfile")
|
||||
|
||||
|
||||
class Wav2Vec2TransformersASR(object):
    """huggingface-transformers based wav2vec2 transcriber.

    NOTE(review): the constructor arguments are currently ignored — the
    pretrained weights are always fetched from the hub by name. Confirm
    whether the supplied paths should be wired in instead.
    """

    def __init__(self, ctc_path, w2v_path, target_dict_path):
        super(Wav2Vec2TransformersASR, self).__init__()
        model_name = "facebook/wav2vec2-large-960h-lv60-self"
        self.tokenizer = Wav2Vec2Tokenizer.from_pretrained(model_name)
        self.model = Wav2Vec2ForCTC.from_pretrained(model_name)

    def transcribe(self, audio_data):
        """Transcribe raw wav bytes and return the decoded string."""
        speech_data, _ = sf.read(BytesIO(audio_data))
        # Batch of one utterance.
        input_values = self.tokenizer(
            speech_data, return_tensors="pt", padding="longest"
        ).input_values

        logits = self.model(input_values).logits
        predicted_ids = torch.argmax(logits, dim=-1)
        return self.tokenizer.batch_decode(predicted_ids)[0]
|
||||
85
src/plume/models/wav2vec2_transformers/data.py
Normal file
85
src/plume/models/wav2vec2_transformers/data.py
Normal file
@@ -0,0 +1,85 @@
|
||||
from pathlib import Path
|
||||
from collections import Counter
|
||||
import shutil
|
||||
|
||||
# import pydub
|
||||
import typer
|
||||
from tqdm import tqdm
|
||||
|
||||
from plume.utils import (
|
||||
ExtendedPath,
|
||||
replace_redundant_spaces_with,
|
||||
lazy_module
|
||||
)
|
||||
soundfile = lazy_module('soundfile')
|
||||
pydub = lazy_module('pydub')
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
def export_jasper(src_dataset_path: Path, dest_dataset_path: Path, unlink: bool = True):
    """Convert a jasper-style manifest dataset into fairseq tsv/ltr format.

    Writes ``{train,valid}.tsv`` / ``.ltr`` plus ``dict.ltr.txt`` under
    ``dest_dataset_path``. When ``unlink`` is set, audio is resampled to
    16 kHz mono copies under ``dest_dataset_path/wavs`` and the tsv root
    points there; otherwise the tsv root stays at ``src_dataset_path``.
    """
    dict_ltr = dest_dataset_path / Path("dict.ltr.txt")
    (dest_dataset_path / Path("wavs")).mkdir(exist_ok=True, parents=True)
    tok_counter = Counter()
    # fairseq expects a "valid" split; reuse the test manifest for it.
    shutil.copy(
        src_dataset_path / Path("test_manifest.json"),
        src_dataset_path / Path("valid_manifest.json"),
    )
    if unlink:
        src_wavs = src_dataset_path / Path("wavs")
        for wav_path in tqdm(list(src_wavs.glob("**/*.wav"))):
            resampled = (
                pydub.AudioSegment.from_wav(wav_path)
                .set_frame_rate(16000)
                .set_channels(1)
            )
            resampled.export(
                dest_dataset_path / Path("wavs") / Path(wav_path.name),
                format="wav",
            )

    for dataset_kind in ["train", "valid"]:
        manifest_path = ExtendedPath(
            src_dataset_path / Path(f"{dataset_kind}_manifest.json")
        )
        entries = list(manifest_path.read_jsonl())
        out_tsv = dest_dataset_path / Path(f"{dataset_kind}.tsv")
        out_ltr = dest_dataset_path / Path(f"{dataset_kind}.ltr")
        with out_tsv.open("w") as tsv_f, out_ltr.open("w") as ltr_f:
            root = dest_dataset_path if unlink else src_dataset_path
            tsv_f.write(f"{root}\n")
            for entry in entries:
                audio_fname = entry["audio_filepath"]
                # Uppercase letters with "|" as the word boundary token.
                pipe_toks = replace_redundant_spaces_with(entry["text"], "|").upper()
                tok_counter.update(pipe_toks)
                ltr_f.write(" ".join(pipe_toks) + " |\n")
                frame_count = soundfile.info(audio_fname).frames
                rel_path = Path(audio_fname).relative_to(src_dataset_path.absolute())
                tsv_f.write(f"{rel_path}\t{frame_count}\n")
    with dict_ltr.open("w") as d_f:
        for tok, count in tok_counter.most_common():
            d_f.write(f"{tok} {count}\n")
    # Drop the temporary valid manifest created above.
    (src_dataset_path / Path("valid_manifest.json")).unlink()
|
||||
|
||||
|
||||
@app.command()
def set_root(dataset_path: Path, root_path: Path):
    """Rewrite the root-directory header line of the train/valid tsv files."""
    for split in ("train", "valid"):
        tsv_path = dataset_path / Path(split).with_suffix(".tsv")
        with tsv_path.open("r") as fh:
            contents = fh.readlines()
        # The first tsv line is the dataset root all relative paths hang off.
        contents[0] = str(root_path) + "\n"
        with tsv_path.open("w") as fh:
            fh.writelines(contents)
|
||||
|
||||
|
||||
def main():
    """Module CLI entry point."""
    app()


if __name__ == "__main__":
    main()
|
||||
52
src/plume/models/wav2vec2_transformers/eval.py
Normal file
52
src/plume/models/wav2vec2_transformers/eval.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from pathlib import Path
|
||||
import typer
|
||||
from tqdm import tqdm
|
||||
# import pandas as pd
|
||||
|
||||
from plume.utils import (
|
||||
asr_manifest_reader,
|
||||
discard_except_digits,
|
||||
replace_digit_symbol,
|
||||
lazy_module
|
||||
# run_shell,
|
||||
)
|
||||
from ...utils.transcribe import triton_transcribe_grpc_gen, transcribe_rpyc_gen
|
||||
|
||||
pd = lazy_module('pandas')
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
def manifest(manifest_file: Path, result_file: Path = "results.csv", rpyc: bool = False):
    """Score a manifest against a local ASR endpoint, digit-by-digit.

    Uses an rpyc transport when ``rpyc`` is set, otherwise triton gRPC.
    Transcribes each entry, reduces both reference and hypothesis to
    their digit content, and writes a comparison csv next to the
    manifest.
    """
    from pydub import AudioSegment

    host = "localhost"
    port = 8044
    if rpyc:
        transcriber, audio_prep = transcribe_rpyc_gen(host, port)
    else:
        transcriber, audio_prep = triton_transcribe_grpc_gen(host, port, method='whole')
    result_path = manifest_file.parent / result_file
    manifest_list = list(asr_manifest_reader(manifest_file))

    def compute_frame(entry):
        audio_file = entry["audio_path"]
        orig_text = entry["text"]
        orig_num = discard_except_digits(replace_digit_symbol(orig_text))
        asr_text = transcriber(audio_prep(AudioSegment.from_file(audio_file)))
        asr_num = discard_except_digits(replace_digit_symbol(asr_text))
        return {
            "audio_file": audio_file,
            "asr_text": asr_text,
            "asr_num": asr_num,
            "orig_text": orig_text,
            "orig_num": orig_num,
            "asr_match": orig_num == asr_num,
        }

    rows = map(compute_frame, tqdm(manifest_list))
    pd.DataFrame(rows).to_csv(result_path)
|
||||
52
src/plume/models/wav2vec2_transformers/serve.py
Normal file
52
src/plume/models/wav2vec2_transformers/serve.py
Normal file
@@ -0,0 +1,52 @@
|
||||
import os
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
# from rpyc.utils.server import ThreadedServer
|
||||
import typer
|
||||
|
||||
from ...utils.serve import ASRService
|
||||
from plume.utils import lazy_callable
|
||||
# from plume.models.wav2vec2_transformers.asr import Wav2Vec2TransformersASR
|
||||
# from .asr import Wav2Vec2ASR
|
||||
|
||||
ThreadedServer = lazy_callable("rpyc.utils.server.ThreadedServer")
|
||||
Wav2Vec2TransformersASR = lazy_callable(
|
||||
"plume.models.wav2vec2_transformers.asr.Wav2Vec2TransformersASR"
|
||||
)
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
def rpyc(
    w2v_path: Path = "/path/to/base.pt",
    ctc_path: Path = "/path/to/ctc.pt",
    target_dict_path: Path = "/path/to/dict.ltr.txt",
    port: int = int(os.environ.get("ASR_RPYC_PORT", "8044")),
):
    """Serve the transformers wav2vec2 model over rpyc on ``port``."""
    asr_model = Wav2Vec2TransformersASR(ctc_path, w2v_path, target_dict_path)
    service = ASRService(asr_model)
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )
    logging.info("starting asr server...")
    ThreadedServer(service, port=port).start()
|
||||
|
||||
|
||||
@app.command()
def rpyc_dir(model_dir: Path, port: int = int(os.environ.get("ASR_RPYC_PORT", "8044"))):
    """Serve a model laid out in ``model_dir`` (base.pt / ctc.pt / dict.ltr.txt)."""
    rpyc(
        model_dir / Path("base.pt"),
        model_dir / Path("ctc.pt"),
        model_dir / Path("dict.ltr.txt"),
        port,
    )
|
||||
|
||||
|
||||
def main():
    """Module CLI entry point."""
    app()


if __name__ == "__main__":
    main()
|
||||
41
src/plume/models/wav2vec2_transformers/test.py
Normal file
41
src/plume/models/wav2vec2_transformers/test.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from transformers import Wav2Vec2Tokenizer, Wav2Vec2ForCTC
|
||||
from datasets import load_dataset
|
||||
import soundfile as sf
|
||||
import torch
|
||||
|
||||
# load model and tokenizer
|
||||
tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-large-960h-lv60-self")
|
||||
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60-self")
|
||||
|
||||
|
||||
# define function to read in sound file
|
||||
def map_to_array(batch):
    """Dataset mapper: read ``batch["file"]`` into a ``speech`` array."""
    batch["speech"] = sf.read(batch["file"])[0]
    return batch
|
||||
|
||||
|
||||
# load dummy dataset and read soundfiles
|
||||
def main():
    """Smoke-test the pretrained model on the dummy librispeech split."""
    ds = load_dataset(
        "patrickvonplaten/librispeech_asr_dummy", "clean", split="validation"
    )
    ds = ds.map(map_to_array)

    # Tokenize two utterances into a padded batch.
    input_values = tokenizer(
        ds["speech"][:2], return_tensors="pt", padding="longest"
    ).input_values

    # Forward pass, then greedy argmax decoding.
    logits = model(input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)

    print(tokenizer.batch_decode(predicted_ids))


if __name__ == "__main__":
    main()
|
||||
34
src/plume/models/wav2vec2_transformers/train.py
Normal file
34
src/plume/models/wav2vec2_transformers/train.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import typer
|
||||
# from fairseq_cli.train import cli_main
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import shlex
|
||||
from plume.utils import lazy_callable
|
||||
|
||||
cli_main = lazy_callable('fairseq_cli.train.cli_main')
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
def local(dataset_path: Path):
    """Finetune wav2vec2 with CTC on ``dataset_path`` via fairseq-train.

    NOTE(review): the save dir and pretrained wav2vec checkpoint path are
    hard-coded below — adjust before running outside that environment.
    """
    args = f'''--distributed-world-size 1 {dataset_path} \
--save-dir /dataset/wav2vec2/model/wav2vec2_l_num_ctc_v1 --post-process letter --valid-subset \
valid --no-epoch-checkpoints --best-checkpoint-metric wer --num-workers 4 --max-update 80000 \
--sentence-avg --task audio_pretraining --arch wav2vec_ctc --w2v-path /dataset/wav2vec2/pretrained/wav2vec_vox_new.pt \
--labels ltr --apply-mask --mask-selection static --mask-other 0 --mask-length 10 --mask-prob 0.5 --layerdrop 0.1 \
--mask-channel-selection static --mask-channel-other 0 --mask-channel-length 64 --mask-channel-prob 0.5 \
--zero-infinity --feature-grad-mult 0.0 --freeze-finetune-updates 10000 --validate-after-updates 10000 \
--optimizer adam --adam-betas '(0.9, 0.98)' --adam-eps 1e-08 --lr 2e-05 --lr-scheduler tri_stage --warmup-steps 8000 \
--hold-steps 32000 --decay-steps 40000 --final-lr-scale 0.05 --final-dropout 0.0 --dropout 0.0 \
--activation-dropout 0.1 --criterion ctc --attention-dropout 0.0 --max-tokens 1280000 --seed 2337 --log-format json \
--log-interval 500 --ddp-backend no_c10d --reset-optimizer --normalize
'''
    # fairseq's cli_main parses sys.argv directly.
    sys.argv = ['train.py'] + shlex.split(args)
    cli_main()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Fix: route through the typer app (as every sibling module does)
    # rather than calling fairseq's cli_main with an unprepared sys.argv.
    app()
|
||||
109
src/plume/ui/__init__.py
Normal file
109
src/plume/ui/__init__.py
Normal file
@@ -0,0 +1,109 @@
|
||||
import typer
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from plume.utils import lazy_module
|
||||
|
||||
# from streamlit import cli as stcli
|
||||
|
||||
stcli = lazy_module("streamlit.cli")
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.callback()
def ui():
    """
    ui sub commands
    """
    # typer callback: exists only to group the ui sub-commands and
    # surface the docstring above as CLI help text.
|
||||
|
||||
|
||||
@app.command()
def annotation(data_dir: Path, dump_fname: Path = "ui_dump.json", task_id: str = ""):
    """Launch the streamlit annotation UI for ``data_dir``.

    Arguments after ``--`` are forwarded to the annotation script;
    ``--task-id`` is only passed when non-empty.
    """
    annotation_lit_path = Path(__file__).parent / Path("annotation.py")
    # Build the argv streamlit will parse. Every entry must be a string:
    # the original passed ``dump_fname`` (a Path) through unconverted.
    argv = [
        "streamlit",
        "run",
        str(annotation_lit_path),
        "--",
        str(data_dir),
    ]
    if task_id:
        argv += ["--task-id", task_id]
    argv += ["--dump-fname", str(dump_fname)]
    sys.argv = argv
    sys.exit(stcli.main())
|
||||
|
||||
|
||||
@app.command()
def preview(manifest_path: Path, port: int = 8081):
    """Launch the streamlit manifest preview app on ``port``."""
    script = Path(__file__).parent / Path("preview.py")
    sys.argv = [
        "streamlit",
        "run",
        "--server.port",
        str(port),
        str(script),
        "--",
        str(manifest_path),
    ]
    sys.exit(stcli.main())
|
||||
|
||||
|
||||
@app.command()
def encrypted_preview(manifest_path: Path, key: str, port: int = 8081):
    """Launch the streamlit preview for an encrypted manifest on ``port``."""
    script = Path(__file__).parent / Path("encrypted_preview.py")
    sys.argv = [
        "streamlit",
        "run",
        "--server.port",
        str(port),
        str(script),
        "--",
        str(manifest_path),
        str(key),
    ]
    sys.exit(stcli.main())
|
||||
|
||||
|
||||
@app.command()
def audio(audio_dir: Path):
    """Launch the streamlit audio browser for ``audio_dir``."""
    script = Path(__file__).parent / Path("audio.py")
    sys.argv = ["streamlit", "run", str(script), "--", str(audio_dir)]
    sys.exit(stcli.main())
|
||||
|
||||
|
||||
@app.command()
def collection(data_dir: Path, task_id: str = ""):
    """Placeholder command for a web-based data collection UI."""
    # TODO: Implement web ui for data collection
    pass
|
||||
|
||||
|
||||
@app.command()
def alignment(preview_dir: Path, port: int = 8010):
    """Serve ``preview_dir`` over HTTP with Range support (for audio seeking)."""
    from functools import partial
    from http.server import HTTPServer

    from RangeHTTPServer import RangeRequestHandler

    handler_class = partial(RangeRequestHandler, directory=str(preview_dir))
    HTTPServer(("", port), handler_class).serve_forever()
|
||||
|
||||
|
||||
def main():
    """Console entry point for the ui typer application."""
    app()


if __name__ == "__main__":
    main()
|
||||
108
src/plume/ui/annotation.py
Normal file
108
src/plume/ui/annotation.py
Normal file
@@ -0,0 +1,108 @@
|
||||
# import sys
|
||||
from pathlib import Path
|
||||
|
||||
import streamlit as st
|
||||
import typer
|
||||
from plume.utils import ExtendedPath
|
||||
from plume.utils.ui_persist import setup_mongo_asr_validation_state
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
setup_mongo_asr_validation_state(st)
|
||||
|
||||
|
||||
@st.cache()
def load_ui_data(data_dir: Path, dump_fname: Path):
    """Load (and memoize via streamlit) the annotation UI dump under *data_dir*."""
    dump_path = data_dir / dump_fname
    typer.echo(f"Using annotation ui data from {dump_path}")
    return ExtendedPath(dump_path).read_json()
|
||||
|
||||
|
||||
def show_key(sample, key, trail=""):
    """Render ``sample[key]`` in the streamlit sidebar, titled from the key.

    Floats are shown with two decimals followed by *trail* (e.g. ``"%"``);
    any other type is shown verbatim.  Missing keys are silently skipped.
    """
    if key not in sample:
        return
    title = key.replace("_", " ").title()
    value = sample[key]
    # isinstance (not ``type(x) == float``) so float subclasses format the same
    if isinstance(value, float):
        st.sidebar.markdown(f"{title}: {value:.2f}{trail}")
    else:
        st.sidebar.markdown(f"{title}: {value}")
|
||||
|
||||
|
||||
@app.command()
def main(data_dir: Path, dump_fname: Path = "ui_dump.json", task_id: str = ""):
    """Streamlit page for validating/annotating ASR transcripts.

    Reads the UI dump produced by the data pipeline, shows one sample at a
    time (audio + waveform plot + gold text) and persists the reviewer's
    verdict through the state helpers installed on ``st`` by
    ``setup_mongo_asr_validation_state``.
    """
    st.set_task(data_dir, task_id)
    ui_config = load_ui_data(data_dir, dump_fname)
    asr_data = ui_config["data"]
    annotation_only = ui_config.get("annotation_only", False)
    asr_result_key = ui_config.get("asr_result_key", "pretrained_asr")
    sample_no = st.get_current_cursor()
    # the persisted cursor may point past the end if the dump shrank
    if len(asr_data) - 1 < sample_no or sample_no < 0:
        print("Invalid samplno resetting to 0")
        st.update_cursor(0)
    sample = asr_data[sample_no]
    # task ids look like "<prefix>-<uid>"; show only the uid in the title
    task_uid = st.task_id.rsplit("-", 1)[1]
    if annotation_only:
        st.title(f"ASR Annotation - # {task_uid}")
    else:
        st.title(f"ASR Validation - # {task_uid}")
    st.markdown(f"{sample_no+1} of {len(asr_data)} : **{sample['text']}**")
    new_sample = st.number_input(
        "Go To Sample:", value=sample_no + 1, min_value=1, max_value=len(asr_data)
    )
    if new_sample != sample_no + 1:
        st.update_cursor(new_sample - 1)
    st.sidebar.title(f"Details: [{sample['real_idx']}]")
    st.sidebar.markdown(f"Gold Text: **{sample['text']}**")
    show_key(sample, "caller")
    # validation mode also surfaces the ASR hypothesis and its metrics
    if not annotation_only:
        show_key(sample, asr_result_key)
        show_key(sample, "asr_wer", trail="%")
        show_key(sample, "correct_candidate")

    st.sidebar.image((data_dir / Path(sample["plot_path"])).read_bytes())
    st.audio((data_dir / Path(sample["audio_path"])).open("rb"))
    # set default to text
    corrected = sample["text"]
    correction_entry = st.get_correction_entry(sample["utterance_id"])
    selected_idx = 0
    options = ("Correct", "Incorrect", "Inaudible")
    # if correction entry is present set the corresponding ui defaults
    if correction_entry:
        selected_idx = options.index(correction_entry["value"]["status"])
        corrected = correction_entry["value"]["correction"]
    selected = st.radio("The Audio is", options, index=selected_idx)
    if selected == "Incorrect":
        corrected = st.text_input("Actual:", value=corrected)
    if selected == "Inaudible":
        corrected = ""
    if st.button("Submit"):
        st.update_entry(
            sample["utterance_id"], {"status": selected, "correction": corrected}
        )
        # auto-advance to the next sample after a submission
        st.update_cursor(sample_no + 1)
    if correction_entry:
        status = correction_entry["value"]["status"]
        correction = correction_entry["value"]["correction"]
        st.markdown(f"Your Response: **{status}** Correction: **{correction}**")
    # navigation helpers: jump by exact gold text ...
    text_sample = st.text_input("Go to Text:", value="")
    if text_sample != "":
        candidates = [i for (i, p) in enumerate(asr_data) if p["text"] == text_sample]
        if len(candidates) > 0:
            st.update_cursor(candidates[0])
    # ... or by the sample's original ("real") index in the source data
    real_idx = st.number_input(
        "Go to real-index",
        value=sample["real_idx"],
        min_value=0,
        max_value=len(asr_data) - 1,
    )
    if real_idx != int(sample["real_idx"]):
        idx = [i for (i, p) in enumerate(asr_data) if p["real_idx"] == real_idx][0]
        st.update_cursor(idx)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # typer raises SystemExit even on success; swallow it so streamlit's
    # script runner (which execs this file) is not aborted.
    try:
        app()
    except SystemExit:
        pass
|
||||
21
src/plume/ui/audio.py
Normal file
21
src/plume/ui/audio.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from pathlib import Path
|
||||
|
||||
import streamlit as st
|
||||
import typer
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
def main(wav_dir: Path):
    """Streamlit page previewing the first ``.wav`` found under *wav_dir*.

    Shows a friendly error instead of crashing when the directory holds no
    wav files.
    """
    st.title("Audio Preview")
    # ``glob`` already yields paths prefixed with wav_dir; re-joining them
    # with wav_dir (as the original did) double-prefixes relative dirs.
    wav_file = next(wav_dir.glob("**/*.wav"), None)
    if wav_file is None:
        st.error(f"no .wav files found under {wav_dir}")
        return
    st.audio(str(wav_file))


if __name__ == "__main__":
    try:
        app()
    except SystemExit:
        pass
|
||||
46
src/plume/ui/encrypted_preview.py
Normal file
46
src/plume/ui/encrypted_preview.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from pathlib import Path
|
||||
|
||||
import streamlit as st
|
||||
import typer
|
||||
from plume.utils import ExtendedPath, wav_cryptor, text_cryptor
|
||||
from plume.utils.ui_persist import setup_file_state
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
setup_file_state(st)
|
||||
|
||||
|
||||
@st.cache()
def load_ui_data(validation_ui_data_path: Path):
    """Read and memoize the jsonl manifest backing the preview page."""
    typer.echo(f"Using validation ui data from {validation_ui_data_path}")
    records = ExtendedPath(validation_ui_data_path).read_jsonl()
    return list(records)
|
||||
|
||||
|
||||
@app.command()
def main(manifest: Path, key: str):
    """Streamlit preview of an encrypted ASR manifest.

    Text and wav files are decrypted on the fly with *key*; nothing
    decrypted is persisted back to disk.
    """
    wc = wav_cryptor(key)
    tc = text_cryptor(key)
    asr_data = load_ui_data(manifest)
    sample_no = st.get_current_cursor()
    # the persisted cursor may point past the end if the manifest shrank
    if len(asr_data) - 1 < sample_no or sample_no < 0:
        print("Invalid samplno resetting to 0")
        st.update_cursor(0)
    sample = asr_data[sample_no]
    st.title("ASR Manifest Preview")
    gt_text = tc.decrypt_text(sample["text"].encode("utf-8"))
    st.markdown(f"{sample_no+1} of {len(asr_data)} : **{gt_text}**")
    new_sample = st.number_input(
        "Go To Sample:", value=sample_no + 1, min_value=1, max_value=len(asr_data)
    )
    if new_sample != sample_no + 1:
        st.update_cursor(new_sample - 1)
    st.sidebar.markdown(f"Gold Text: **{gt_text}**")
    # audio paths in the manifest are relative to the manifest's directory
    wav = wc.decrypt_wav_path((manifest.parent / Path(sample["audio_filepath"])))
    st.audio(wav)


if __name__ == "__main__":
    # typer raises SystemExit even on success; swallow it for streamlit
    try:
        app()
    except SystemExit:
        pass
|
||||
42
src/plume/ui/preview.py
Normal file
42
src/plume/ui/preview.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from pathlib import Path
|
||||
|
||||
import streamlit as st
|
||||
import typer
|
||||
from plume.utils import ExtendedPath
|
||||
from plume.utils.ui_persist import setup_file_state
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
setup_file_state(st)
|
||||
|
||||
|
||||
@st.cache()
def load_ui_data(validation_ui_data_path: Path):
    """Read and memoize the jsonl manifest shown by the preview page."""
    typer.echo(f"Using validation ui data from {validation_ui_data_path}")
    return [rec for rec in ExtendedPath(validation_ui_data_path).read_jsonl()]
|
||||
|
||||
|
||||
@app.command()
def main(manifest: Path):
    """Streamlit preview of a plain (unencrypted) ASR manifest."""
    asr_data = load_ui_data(manifest)
    sample_no = st.get_current_cursor()
    # the persisted cursor may point past the end if the manifest shrank
    if len(asr_data) - 1 < sample_no or sample_no < 0:
        print("Invalid samplno resetting to 0")
        st.update_cursor(0)
    sample = asr_data[sample_no]
    st.title("ASR Manifest Preview")
    st.markdown(f"{sample_no+1} of {len(asr_data)} : **{sample['text']}**")
    new_sample = st.number_input(
        "Go To Sample:", value=sample_no + 1, min_value=1, max_value=len(asr_data)
    )
    if new_sample != sample_no + 1:
        st.update_cursor(new_sample - 1)
    st.sidebar.markdown(f"Gold Text: **{sample['text']}**")
    # audio paths are stored relative to the manifest's directory
    st.audio((manifest.parent / Path(sample["audio_filepath"])).open("rb"))


if __name__ == "__main__":
    # typer raises SystemExit even on success; swallow it for streamlit
    try:
        app()
    except SystemExit:
        pass
|
||||
151
src/plume/utils/.gitignore
vendored
Normal file
151
src/plume/utils/.gitignore
vendored
Normal file
@@ -0,0 +1,151 @@
|
||||
/data/
|
||||
/model/
|
||||
/train/
|
||||
.env*
|
||||
*.yaml
|
||||
*.yml
|
||||
*.json
|
||||
|
||||
|
||||
# Created by https://www.gitignore.io/api/python
|
||||
# Edit at https://www.gitignore.io/?templates=python
|
||||
|
||||
### Python ###
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
pip-wheel-metadata/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
target/
|
||||
|
||||
# pyenv
|
||||
.python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# celery beat schedule file
|
||||
celerybeat-schedule
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# Mr Developer
|
||||
.mr.developer.cfg
|
||||
.project
|
||||
.pydevproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# End of https://www.gitignore.io/api/python
|
||||
|
||||
# Created by https://www.gitignore.io/api/macos
|
||||
# Edit at https://www.gitignore.io/?templates=macos
|
||||
|
||||
### macOS ###
|
||||
# General
|
||||
.DS_Store
|
||||
.AppleDouble
|
||||
.LSOverride
|
||||
|
||||
# Icon must end with two \r
|
||||
Icon
|
||||
|
||||
# Thumbnails
|
||||
._*
|
||||
|
||||
# Files that might appear in the root of a volume
|
||||
.DocumentRevisions-V100
|
||||
.fseventsd
|
||||
.Spotlight-V100
|
||||
.TemporaryItems
|
||||
.Trashes
|
||||
.VolumeIcon.icns
|
||||
.com.apple.timemachine.donotpresent
|
||||
|
||||
# Directories potentially created on remote AFP share
|
||||
.AppleDB
|
||||
.AppleDesktop
|
||||
Network Trash Folder
|
||||
Temporary Items
|
||||
.apdisk
|
||||
|
||||
# End of https://www.gitignore.io/api/macos
|
||||
677
src/plume/utils/__init__.py
Normal file
677
src/plume/utils/__init__.py
Normal file
@@ -0,0 +1,677 @@
|
||||
import io
|
||||
import os
|
||||
import re
|
||||
import json
|
||||
import wave
|
||||
import logging
|
||||
import subprocess
|
||||
import shutil
|
||||
import random
|
||||
from pathlib import Path
|
||||
from functools import partial
|
||||
from uuid import uuid4
|
||||
from urllib.parse import urlsplit
|
||||
|
||||
# from .lazy_loader import LazyLoader
|
||||
|
||||
# from ruamel.yaml import YAML
|
||||
# import boto3
|
||||
import typer
|
||||
from tqdm import tqdm
|
||||
|
||||
# import pymongo
|
||||
# from slugify import slugify
|
||||
# import pydub
|
||||
# import matplotlib.pyplot as plt
|
||||
# import librosa
|
||||
# import librosa.display as audio_display
|
||||
# from natural.date import compress
|
||||
# from num2words import num2words
|
||||
import datetime
|
||||
import six
|
||||
|
||||
# from .transcribe import triton_transcribe_grpc_gen
|
||||
# from .eval import app as eval_app
|
||||
from .manifest import asr_manifest_writer, manifest_str
|
||||
from .lazy_import import lazy_callable, lazy_module
|
||||
from .parallel import parallel_apply
|
||||
from .extended_path import ExtendedPath
|
||||
from .tts import app as tts_app
|
||||
from .transcribe import app as transcribe_app
|
||||
from .align import app as align_app
|
||||
from .encrypt import app as encrypt_app, wav_cryptor, text_cryptor # noqa
|
||||
from .regentity import ( # noqa
|
||||
num_replacer,
|
||||
alnum_replacer,
|
||||
num_keeper,
|
||||
alnum_keeper,
|
||||
default_num_rules,
|
||||
default_num_only_rules,
|
||||
default_alnum_rules,
|
||||
entity_replacer_keeper,
|
||||
)
|
||||
|
||||
boto3 = lazy_module("boto3")
|
||||
pymongo = lazy_module("pymongo")
|
||||
pydub = lazy_module("pydub")
|
||||
audio_display = lazy_module("librosa.display")
|
||||
plt = lazy_module("matplotlib.pyplot")
|
||||
librosa = lazy_module("librosa")
|
||||
YAML = lazy_callable("ruamel.yaml.YAML")
|
||||
num2words = lazy_callable("num2words.num2words")
|
||||
slugify = lazy_callable("slugify.slugify")
|
||||
|
||||
app = typer.Typer()
|
||||
app.add_typer(encrypt_app)
|
||||
app.add_typer(tts_app, name="tts")
|
||||
app.add_typer(align_app, name="align")
|
||||
app.add_typer(transcribe_app, name="transcribe")
|
||||
|
||||
|
||||
# typer callback: groups the utils sub commands (tts, align, transcribe,
# encrypt, ...) under one CLI app; the docstring doubles as its help text
@app.callback()
def utils():
    """
    utils sub commands
    """
|
||||
|
||||
|
||||
log_fmt_str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
logging.basicConfig(level=logging.INFO, format=log_fmt_str)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Precalculated timestamps (seconds per unit)
TIME_MINUTE = 60
TIME_HOUR = 3600
TIME_DAY = 86400
TIME_WEEK = 604800


def compress(t, show_hours=False, sign=False, pad=""):
    """
    Convert the input to compressed format, works with a
    :class:`datetime.timedelta` object or a number of seconds.

    ``show_hours=True`` skips the week/day breakdown and expresses the
    whole duration in hours.  ``sign`` prefixes ``+``/``-``; ``pad``
    separates the parts (e.g. ``' '``).

    :param t: seconds or :class:`datetime.timedelta` object
    :param show_hours: default ``False``
    :param sign: default ``False``
    :param pad: default ``''``

    >>> print(compress(0))
    0s
    >>> print(compress(123))
    2m3s
    >>> print(compress(12345))
    3h25m45s
    >>> print(compress(123456))
    1d10h17m36s

    src: https://github.com/tehmaze/natural/blob/master/natural/date.py
    """
    if isinstance(t, datetime.timedelta):
        seconds = t.seconds + (t.days * 86400)
    elif isinstance(t, (int, float)):
        # BUG FIX: the recursive call used to pass ``sign`` and ``pad``
        # positionally, shifting them into ``show_hours``/``sign``.
        # Forward every option by keyword instead.  (Also drops the
        # ``six.integer_types`` check: on Python 3 it is just ``int``.)
        return compress(
            datetime.timedelta(seconds=t),
            show_hours=show_hours,
            sign=sign,
            pad=pad,
        )
    else:
        raise Exception("Invalid time format")

    parts = []
    if sign:
        parts.append("-" if t.days < 0 else "+")

    if not show_hours:
        weeks, seconds = divmod(seconds, TIME_WEEK)
        days, seconds = divmod(seconds, TIME_DAY)
    hours, seconds = divmod(seconds, TIME_HOUR)
    minutes, seconds = divmod(seconds, TIME_MINUTE)

    if not show_hours:
        if weeks:
            parts.append("%dw" % (weeks,))
        if days:
            parts.append("%dd" % (days,))
    if hours:
        parts.append("%dh" % (hours,))
    if minutes:
        parts.append("%dm" % (minutes,))
    # always emit at least one component so 0 renders as "0s"
    if seconds or len(parts) == 0:
        parts.append("%ds" % (seconds,))

    return pad.join(parts)
|
||||
|
||||
|
||||
def duration_str(seconds, show_hours=False):
    """Human-readable duration like ``"2m 3s"`` for a number of seconds."""
    t = datetime.timedelta(seconds=seconds)
    return compress(t, show_hours=show_hours, pad=" ")
|
||||
|
||||
|
||||
def replace_digit_symbol(w2v_out, num_range=10):
    """Rewrite spelled-out numbers below *num_range* in *w2v_out* as digits.

    Larger numbers are substituted first (reversed range) so e.g. "nine"
    inside a longer word form is not clobbered prematurely.
    """

    def word_digit_pair(i):
        return (num2words(i).replace("-", " "), str(i))

    replacements = [word_digit_pair(i) for i in reversed(range(num_range))]
    out = w2v_out.lower()
    for word, digit in replacements:
        out = re.sub(word, digit, out)
    return out
|
||||
|
||||
|
||||
def num_keeper_orig(num_range=10, extra_rules=[]):
    """Build ``keep_numeric_literals(text) -> (kept_text, num_entities)``.

    Compiles rules matching spelled-out numbers below *num_range* (plus
    double/triple/oh/o/hundred idioms and *extra_rules*); the returned
    closure keeps only the merged matching spans of its input.

    NOTE(review): ``extra_rules=[]`` is a mutable default; it is only read
    here, but callers should not mutate it.
    """
    # spelled-out digit -> " digit " rules, highest first
    num_int_map_ty = [
        (
            r"\b" + num2words(i) + r"\b",
            " " + str(i) + " ",
        )
        for i in reversed(range(num_range))
    ]
    re_rules = [
        (re.compile(k, re.IGNORECASE), v)
        for (k, v) in [
            (r"\bdouble(?: |-)(\w+)\b", "\\1 \\1"),
            (r"\btriple(?: |-)(\w+)\b", "\\1 \\1 \\1"),
            (r"hundred", "00"),
            (r"\boh\b", " 0 "),
            (r"\bo\b", " 0 "),
        ]
        + num_int_map_ty
    ] + [(re.compile(k), v) for (k, v) in extra_rules]

    def merge_intervals(intervals):
        # https://codereview.stackexchange.com/a/69249
        sorted_by_lower_bound = sorted(intervals, key=lambda tup: tup[0])
        merged = []

        for higher in sorted_by_lower_bound:
            if not merged:
                merged.append(higher)
            else:
                lower = merged[-1]
                # test for intersection between lower and higher:
                # we know via sorting that lower[0] <= higher[0]
                if higher[0] <= lower[1]:
                    upper_bound = max(lower[1], higher[1])
                    merged[-1] = (
                        lower[0],
                        upper_bound,
                    )  # replace by merged interval
                else:
                    merged.append(higher)
        return merged

    # merging interval tree for optimal # https://www.geeksforgeeks.org/interval-tree/

    def keep_numeric_literals(w2v_out):
        """Keep only the substrings of *w2v_out* matched by the number rules."""
        # normalize separators to single spaces before matching
        out = re.sub(r"[ ;,.]", " ", w2v_out).strip()
        num_spans = []
        # collect every match span from every rule; substitution values
        # are unused here — only the spans matter
        for (k, v) in re_rules:
            matches = k.finditer(out)
            for m in matches:
                num_spans.append(m.span())
        merged = merge_intervals(num_spans)
        num_ents = len(merged)
        keep_out = " ".join((out[s[0] : s[1]] for s in merged))
        return keep_out, num_ents

    return keep_numeric_literals
|
||||
|
||||
|
||||
def discard_except_digits(inp):
    """Drop every character outside ASCII ``0``-``9`` from *inp*."""
    return "".join(ch for ch in inp if "0" <= ch <= "9")
|
||||
|
||||
|
||||
def digits_to_chars(text):
    """Spell out each ASCII digit in *text*, lowercased.

    Every spelled-out digit gets a trailing space; other characters pass
    through unchanged.
    """
    pieces = [num2words(c) + " " if "0" <= c <= "9" else c for c in text]
    return "".join(pieces).lower()
|
||||
|
||||
|
||||
def replace_redundant_spaces_with(text, sub):
    """Collapse every run of one-or-more spaces in *text* into *sub*."""
    return re.sub(" +", sub, text)
|
||||
|
||||
|
||||
def space_out(text):
    """Insert a single space between every character of *text*."""
    return " ".join(text)
|
||||
|
||||
|
||||
def random_segs(total, min_val, max_val):
    """Split ``[0, total)`` into random contiguous ``(start, end)`` segments.

    Segment lengths are drawn from ``[min_val, max_val]``; the final
    appended segment absorbs whatever remains so the spans always cover
    exactly *total*.
    """
    out_list = []
    rand_total = prev_start = 0
    while True:
        # stop when the remainder cannot host another min_val-sized segment
        if total < rand_total + min_val or total < rand_total:
            break
        sample = random.randint(min_val, max_val)
        if total - rand_total < max_val:
            break
        if total - rand_total < max_val + min_val:
            # shrink the draw so the remainder stays at least min_val long
            sample = random.randint(min_val, max_val - min_val)
        prev_start = rand_total
        # bail if this draw would overshoot the end by less than max_val
        if 0 < rand_total + sample - total < max_val:
            break
        rand_total += sample
        out_list.append((prev_start, rand_total))
    # tail segment covering whatever is left
    out_list.append((rand_total, total))
    return out_list
|
||||
|
||||
|
||||
def wav_bytes(audio_bytes, frame_rate=24000):
    """Wrap raw 16-bit mono PCM *audio_bytes* in a WAV container and return it."""
    buf = io.BytesIO()
    with wave.open(buf, mode="w") as wav_writer:
        wav_writer.setnchannels(1)
        wav_writer.setframerate(frame_rate)
        wav_writer.setsampwidth(2)
        wav_writer.writeframesraw(audio_bytes)
    return buf.getvalue()
|
||||
|
||||
|
||||
def tscript_uuid_fname(transcript):
    """Unique-but-readable file stem: ``<uuid4>_<slug of transcript>``."""
    slug = slugify(transcript, max_length=8)
    return f"{uuid4()}_{slug}"
|
||||
|
||||
|
||||
def run_shell(cmd_str, work_dir=".", verbose=True):
    """Run *cmd_str* through the shell with *work_dir* as cwd.

    With ``verbose=True`` the combined stdout/stderr of the child is
    streamed to our stdout line by line; otherwise output is captured and
    discarded.
    """
    cwd = Path(work_dir).absolute()
    if not verbose:
        subprocess.run(cmd_str, shell=True, cwd=cwd, capture_output=True)
        return
    proc = subprocess.Popen(
        cmd_str,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        shell=True,
        cwd=cwd,
    )
    with proc:
        for raw_line in proc.stdout:
            print(raw_line.replace(b"\n", b"").decode("utf-8"))
|
||||
|
||||
|
||||
def upload_s3(dataset_path, s3_path):
    """Sync a local dataset directory up to S3 via the aws CLI."""
    run_shell(f"aws s3 sync {dataset_path} {s3_path}")
|
||||
|
||||
|
||||
def copy_s3(dataset_path, s3_path):
    """Copy a single local path to S3 via the aws CLI."""
    run_shell(f"aws s3 cp {dataset_path} {s3_path}")
|
||||
|
||||
|
||||
def get_download_path(s3_uri, output_path):
    """Map an ``s3://bucket/key`` URI to ``output_path/key``, creating parents."""
    key = urlsplit(s3_uri).path[1:]
    target = output_path / Path(key)
    target.parent.mkdir(exist_ok=True, parents=True)
    return target
|
||||
|
||||
|
||||
def s3_downloader():
    """Build a ``download_s3(s3_uri, download_path)`` closure over one client.

    The boto3 client is created once and shared by every call of the
    returned function; existing files are never re-downloaded.
    """
    client = boto3.client("s3")

    def download_s3(s3_uri, download_path, verbose=False):
        parts = urlsplit(s3_uri)
        download_path.parent.mkdir(exist_ok=True, parents=True)
        if not download_path.exists():
            if verbose:
                print(f"downloading {s3_uri} to {download_path}")
            client.download_file(parts.netloc, parts.path[1:], str(download_path))

    return download_s3
|
||||
|
||||
|
||||
def asr_data_writer(dataset_dir, asr_data_source, verbose=False):
    """Write ``(transcript, duration, wav_bytes)`` triples as an ASR dataset.

    Wavs land under ``<dataset_dir>/wavs`` with uuid+slug names and one
    manifest line is written per sample.  Returns the number of samples
    written.
    """
    (dataset_dir / Path("wavs")).mkdir(parents=True, exist_ok=True)
    asr_manifest = dataset_dir / Path("manifest.json")
    num_datapoints = 0
    with asr_manifest.open("w") as mf:
        print(f"writing manifest to {asr_manifest}")
        for transcript, audio_dur, wav_data in asr_data_source:
            fname = tscript_uuid_fname(transcript)
            wav_fname = Path(fname).with_suffix(".wav")
            audio_file = dataset_dir / Path("wavs") / wav_fname
            audio_file.write_bytes(wav_data)
            # the manifest stores audio paths relative to the dataset root
            rel_data_path = audio_file.relative_to(dataset_dir)
            manifest = manifest_str(str(rel_data_path), audio_dur, transcript)
            mf.write(manifest)
            if verbose:
                print(f"writing '{transcript}' of duration {audio_dur}")
            num_datapoints += 1
    return num_datapoints
|
||||
|
||||
|
||||
def ui_data_generator(dataset_dir, asr_data_source, verbose=False):
    """Write wavs and waveform plots per sample and build the UI records.

    *asr_data_source* yields ``(transcript, duration, wav_bytes,
    caller_name, segment)`` tuples.  Returns ``(records, count)`` where each
    record carries the relative audio/plot paths the annotation UI expects.
    """
    (dataset_dir / Path("wavs")).mkdir(parents=True, exist_ok=True)
    (dataset_dir / Path("wav_plots")).mkdir(parents=True, exist_ok=True)

    def data_fn(
        transcript,
        audio_dur,
        wav_data,
        caller_name,
        aud_seg,
        fname,
        audio_file,
        num_datapoints,
        rel_data_path,
    ):
        # render the waveform plot once; reuse it on subsequent runs
        png_path = Path(fname).with_suffix(".png")
        rel_plot_path = Path("wav_plots") / png_path
        wav_plot_path = dataset_dir / rel_plot_path
        if not wav_plot_path.exists():
            plot_seg(wav_plot_path.absolute(), audio_file)
        return {
            "audio_path": str(rel_data_path),
            "audio_filepath": str(rel_data_path),
            "duration": round(audio_dur, 1),
            "text": transcript,
            "real_idx": num_datapoints,
            "caller": caller_name,
            "utterance_id": fname,
            "plot_path": str(rel_plot_path),
        }

    num_datapoints = 0
    data_funcs = []
    for (
        transcript,
        audio_dur,
        wav_data,
        caller_name,
        aud_seg,
    ) in asr_data_source:
        fname = str(uuid4()) + "_" + slugify(transcript, max_length=8)
        audio_file = (
            dataset_dir / Path("wavs") / Path(fname).with_suffix(".wav")
        ).absolute()
        audio_file.write_bytes(wav_data)
        rel_data_path = audio_file.relative_to(dataset_dir.absolute())
        # defer record construction so the (expensive) plot rendering can
        # run in parallel below
        data_funcs.append(
            partial(
                data_fn,
                transcript,
                audio_dur,
                wav_data,
                caller_name,
                aud_seg,
                fname,
                audio_file,
                num_datapoints,
                rel_data_path,
            )
        )
        num_datapoints += 1
    ui_data = parallel_apply(lambda x: x(), data_funcs)
    return ui_data, num_datapoints
|
||||
|
||||
|
||||
def ui_dump_manifest_writer(dataset_dir, asr_data_source, verbose=False):
    """Materialize both a UI dump and a matching manifest for *asr_data_source*.

    Delegates the heavy lifting to :func:`ui_data_generator`, then writes
    ``manifest.json`` (jsonl) and ``ui_dump.json`` under *dataset_dir*.
    Returns the number of samples written.
    """
    dump_data, num_datapoints = ui_data_generator(
        dataset_dir, asr_data_source, verbose=verbose
    )

    asr_manifest = dataset_dir / Path("manifest.json")
    # the UI records are a superset of the manifest fields, so they can be
    # written out directly as manifest lines
    asr_manifest_writer(asr_manifest, dump_data, verbose=verbose)
    ui_dump_file = dataset_dir / Path("ui_dump.json")
    ExtendedPath(ui_dump_file).write_json({"data": dump_data}, verbose=verbose)
    return num_datapoints
|
||||
|
||||
|
||||
def asr_manifest_reader(data_manifest_path: Path):
    """Yield manifest (jsonl) records with resolved audio paths.

    Each record gains an ``audio_path`` (resolved against the manifest's
    directory) and has its ``text`` stripped of surrounding whitespace.
    """
    print(f"reading manifest from {data_manifest_path}")
    with data_manifest_path.open("r") as manifest_file:
        records = [json.loads(line) for line in manifest_file.readlines()]
    for record in records:
        record["audio_path"] = data_manifest_path.parent / Path(
            record["audio_filepath"]
        )
        record["text"] = record["text"].strip()
        yield record
|
||||
|
||||
|
||||
def asr_test_writer(out_file_path: Path, source):
    """Write a telephony test script from manifest records.

    Emits one PLAY block per record plus a trailing ``DO_HANGUP``, and dumps
    the records themselves to ``<out>.result.json`` for later scoring.
    """

    def play_block(record, position):
        path = record["audio_filepath"]
        return f"PAUSE 2\nPLAY {path}\nPAUSE 60\n\n"

    res_file = out_file_path.with_suffix(".result.json")
    results = []
    with out_file_path.open("w") as out_f:
        print(f"opening {out_file_path} for writing test")
        for position, record in enumerate(source):
            results.append(record)
            out_f.write(play_block(record, position))
        out_f.write("DO_HANGUP\n")
    ExtendedPath(res_file).write_json(results)
|
||||
|
||||
|
||||
def batch(iterable, n=1):
    """Split sequence *iterable* into consecutive chunks of at most *n* items."""
    total = len(iterable)
    # slicing clamps at the end, so no explicit min() is needed
    return [iterable[start : start + n] for start in range(0, total, n)]
|
||||
|
||||
|
||||
def get_mongo_coll(uri):
    """Resolve a full mongodb URI (including db and collection) to a collection."""
    parsed = pymongo.uri_parser.parse_uri(uri)
    client = pymongo.MongoClient(uri)
    return client[parsed["database"]][parsed["collection"]]
|
||||
|
||||
|
||||
def get_mongo_conn(host="", port=27017, db="db", col="collection"):
    """Open a mongo collection; host falls back to ``$MONGO_HOST`` or localhost."""
    resolved_host = host or os.environ.get("MONGO_HOST", "localhost")
    mongo_uri = f"mongodb://{resolved_host}:{port}/"
    return pymongo.MongoClient(mongo_uri)[db][col]
|
||||
|
||||
|
||||
def strip_silence(sound):
    """Trim leading and trailing silence from a pydub AudioSegment."""
    from pydub.silence import detect_leading_silence

    lead = detect_leading_silence(sound)
    # trailing silence is the leading silence of the reversed segment
    tail = detect_leading_silence(sound.reverse())
    return sound[lead : len(sound) - tail]
|
||||
|
||||
|
||||
def plot_seg(wav_plot_path, audio_path):
    """Render the waveform of *audio_path* to *wav_plot_path* as a 50-dpi png."""
    fig = plt.Figure()
    axis = fig.add_subplot()
    samples, sample_rate = librosa.load(str(audio_path))
    audio_display.waveplot(y=samples, sr=sample_rate, ax=axis)
    fig.set_tight_layout(True)
    with wav_plot_path.open("wb") as plot_file:
        fig.savefig(plot_file, format="png", dpi=50)
|
||||
|
||||
|
||||
def generate_filter_map(src_dataset_path, dest_dataset_path, data_file):
|
||||
min_nums = 3
|
||||
max_duration = 1 * 60 * 60
|
||||
skip_duration = 1 * 60 * 60
|
||||
max_sample_dur = 20
|
||||
min_sample_dur = 2
|
||||
verbose = True
|
||||
|
||||
src_data_enum = (
|
||||
tqdm(list(ExtendedPath(data_file).read_jsonl()))
|
||||
if verbose
|
||||
else ExtendedPath(data_file).read_jsonl()
|
||||
)
|
||||
|
||||
def filtered_max_dur():
|
||||
wav_duration = 0
|
||||
for s in src_data_enum:
|
||||
nums = re.sub(" ", "", s["text"])
|
||||
if len(nums) >= min_nums:
|
||||
wav_duration += s["duration"]
|
||||
shutil.copy(
|
||||
src_dataset_path / Path(s["audio_filepath"]),
|
||||
dest_dataset_path / Path(s["audio_filepath"]),
|
||||
)
|
||||
yield s
|
||||
if wav_duration > max_duration:
|
||||
break
|
||||
typer.echo(f"filtered only {duration_str(wav_duration)} of audio")
|
||||
|
||||
def filtered_skip_dur():
|
||||
wav_duration = 0
|
||||
for s in src_data_enum:
|
||||
nums = re.sub(" ", "", s["text"])
|
||||
if len(nums) >= min_nums:
|
||||
wav_duration += s["duration"]
|
||||
if wav_duration <= skip_duration:
|
||||
continue
|
||||
elif len(nums) >= min_nums:
|
||||
shutil.copy(
|
||||
src_dataset_path / Path(s["audio_filepath"]),
|
||||
dest_dataset_path / Path(s["audio_filepath"]),
|
||||
)
|
||||
yield s
|
||||
typer.echo(f"skipped {duration_str(skip_duration)} of audio")
|
||||
|
||||
def filtered_blanks():
|
||||
blank_count = total_count = 0
|
||||
for s in src_data_enum:
|
||||
total_count += 1
|
||||
nums = re.sub(" ", "", s["text"])
|
||||
if nums != "":
|
||||
shutil.copy(
|
||||
src_dataset_path / Path(s["audio_filepath"]),
|
||||
dest_dataset_path / Path(s["audio_filepath"]),
|
||||
)
|
||||
yield s
|
||||
else:
|
||||
blank_count += 1
|
||||
typer.echo(f"filtered {blank_count} of {total_count} blank samples")
|
||||
|
||||
def filtered_max_sample_dur():
|
||||
max_dur_count = 0
|
||||
for s in src_data_enum:
|
||||
wav_duration = s["duration"]
|
||||
if wav_duration <= max_sample_dur:
|
||||
shutil.copy(
|
||||
src_dataset_path / Path(s["audio_filepath"]),
|
||||
dest_dataset_path / Path(s["audio_filepath"]),
|
||||
)
|
||||
yield s
|
||||
else:
|
||||
max_dur_count += 1
|
||||
typer.echo(
|
||||
f"filtered {max_dur_count} samples longer thans {max_sample_dur}s"
|
||||
)
|
||||
|
||||
def filtered_transform_digits():
|
||||
count = 0
|
||||
for s in src_data_enum:
|
||||
count += 1
|
||||
digit_text = replace_digit_symbol(s["text"])
|
||||
only_digits = discard_except_digits(digit_text)
|
||||
char_text = digits_to_chars(only_digits)
|
||||
shutil.copy(
|
||||
src_dataset_path / Path(s["audio_filepath"]),
|
||||
dest_dataset_path / Path(s["audio_filepath"]),
|
||||
)
|
||||
s["text"] = char_text
|
||||
yield s
|
||||
typer.echo(f"transformed {count} samples")
|
||||
|
||||
def filtered_extract_chars():
|
||||
count = 0
|
||||
for s in src_data_enum:
|
||||
count += 1
|
||||
no_digits = digits_to_chars(s["text"]).upper()
|
||||
only_chars = re.sub("[^A-Z'\b]", " ", no_digits)
|
||||
filter_text = replace_redundant_spaces_with(
|
||||
only_chars, " "
|
||||
).strip()
|
||||
shutil.copy(
|
||||
src_dataset_path / Path(s["audio_filepath"]),
|
||||
dest_dataset_path / Path(s["audio_filepath"]),
|
||||
)
|
||||
s["text"] = filter_text
|
||||
yield s
|
||||
typer.echo(f"transformed {count} samples")
|
||||
|
||||
def filtered_resample():
|
||||
count = 0
|
||||
for s in src_data_enum:
|
||||
count += 1
|
||||
src_aud = pydub.AudioSegment.from_file(
|
||||
src_dataset_path / Path(s["audio_filepath"])
|
||||
)
|
||||
dst_aud = (
|
||||
src_aud.set_channels(1)
|
||||
.set_sample_width(1)
|
||||
.set_frame_rate(24000)
|
||||
)
|
||||
dst_aud.export(
|
||||
dest_dataset_path / Path(s["audio_filepath"]), format="wav"
|
||||
)
|
||||
yield s
|
||||
typer.echo(f"transformed {count} samples")
|
||||
|
||||
def filtered_msec_to_sec():
    """Yield samples with ``duration`` converted from milliseconds to seconds.

    Audio files are copied into the destination dataset unchanged.
    """
    total = 0
    for entry in src_data_enum:
        total += 1
        entry["duration"] = entry["duration"] / 1000
        rel = Path(entry["audio_filepath"])
        shutil.copy(src_dataset_path / rel, dest_dataset_path / rel)
        yield entry
    typer.echo(f"transformed {total} samples")
|
||||
|
||||
def filtered_blank_hr_max_dur():
    """Yield up to ~3 hours of duration-bounded samples with redacted text.

    Every sample's transcript is replaced by a fixed opaque token (appears
    to be a pre-computed Fernet ciphertext of a blank transcript -- confirm),
    samples outside [min_sample_dur, max_sample_dur] are skipped, and
    iteration stops once the accumulated audio exceeds three hours.
    """
    # Hard cap on total collected audio: 3 hours, expressed in seconds.
    max_duration = 3 * 60 * 60
    wav_duration = 0
    for s in src_data_enum:
        # nums = re.sub(" ", "", s["text"])
        # Redact the transcript with the fixed pre-encrypted placeholder.
        s["text"] = "gAAAAABgq2FR6ajbhMsDmWRQBzX6gIzyAG5sMwFihGeV7E_6eVJqqF78yzmtTJPsJAOJEEXhJ9Z45MrYNgE1sq7VUdsBVGh2cw=="
        if (
            s["duration"] >= min_sample_dur
            and s["duration"] <= max_sample_dur
        ):
            wav_duration += s["duration"]
            shutil.copy(
                src_dataset_path / Path(s["audio_filepath"]),
                dest_dataset_path / Path(s["audio_filepath"]),
            )
            yield s
        # Checked every iteration (even for skipped samples) so the loop
        # ends as soon as the cap is crossed.
        if wav_duration > max_duration:
            break
    typer.echo(f"filtered only {duration_str(wav_duration)} of audio")
|
||||
|
||||
# Registry mapping each CLI filter-kind name to the generator implementing
# it. The first three entries (filtered_max_dur, filtered_skip_dur,
# filtered_blanks) are defined earlier in the enclosing function, outside
# this excerpt.
filter_kind_map = {
    "max_dur_1hr_min3num": filtered_max_dur,
    "skip_dur_1hr_min3num": filtered_skip_dur,
    "blanks": filtered_blanks,
    "transform_digits": filtered_transform_digits,
    "extract_chars": filtered_extract_chars,
    "resample_ulaw24kmono": filtered_resample,
    "max_sample_dur": filtered_max_sample_dur,
    "msec_to_sec": filtered_msec_to_sec,
    "blank_3hr_max_dur": filtered_blank_hr_max_dur,
}
return filter_kind_map
|
||||
88
src/plume/utils/align.py
Normal file
88
src/plume/utils/align.py
Normal file
@@ -0,0 +1,88 @@
|
||||
from pathlib import Path
|
||||
# from IPython import display
|
||||
import io
|
||||
import shutil
|
||||
|
||||
import typer
|
||||
from plume.utils import lazy_module
|
||||
|
||||
from .tts import GoogleTTS
|
||||
|
||||
display = lazy_module('IPython.display')
|
||||
pydub = lazy_module('pydub')
|
||||
requests = lazy_module('requests')
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
# Start gentle with following command
|
||||
# docker run --rm -d --name gentle_service -p 8765:8765/tcp lowerquality/gentle
|
||||
|
||||
|
||||
def gentle_aligner(service_uri, wav_data, utter_text):
    """POST audio + transcript to a gentle service; return (segment, alignment).

    The wav bytes are transcoded to mp3 in memory before upload, and the
    service is queried synchronously (async=false).
    """
    # service_uri= "http://52.41.161.36:8765/transcriptions"
    wav_seg = pydub.AudioSegment.from_file(io.BytesIO(wav_data))

    mp3_buf = io.BytesIO()
    wav_seg.export(mp3_buf, format="mp3")
    mp3_buf.seek(0)
    response = requests.post(
        service_uri,
        params=(("async", "false"),),
        files={
            "audio": ("audio.mp3", mp3_buf),
            "transcript": (
                "words.txt",
                io.BytesIO(utter_text.encode("utf-8")),
            ),
        },
    )
    print(f"Time duration of audio {wav_seg.duration_seconds}")
    print(f"Time taken to align: {response.elapsed}s")
    return wav_seg, response.json()
|
||||
|
||||
|
||||
def gentle_align_iter(service_uri, wav_data, utter_text):
    """Yield (word, audio_segment) pairs from a gentle alignment response."""
    wav_seg, alignment = gentle_aligner(service_uri, wav_data, utter_text)
    for span in alignment:
        start_ms = int(span["start"] * 1000)
        end_ms = int(span["end"] * 1000)
        yield span["word"], wav_seg[start_ms:end_ms]
|
||||
|
||||
|
||||
def tts_jupyter(service_uri="http://localhost:8765/transcriptions"):
    """Demo: synthesize a sentence with Google TTS and preview word alignments.

    Intended for interactive (Jupyter) use; plays each aligned word segment.

    Args:
        service_uri: gentle transcription endpoint (see the docker command at
            the top of this module). Added as a defaulted parameter so
            existing zero-argument callers keep working.

    Bug fix: the original called ``gentle_align_iter(wav_data, utter_text)``
    without the required ``service_uri`` argument, which always raised
    ``TypeError``.
    """
    google_voices = GoogleTTS.voice_list()
    gtts = GoogleTTS()
    # google_voices[4]
    us_voice = [v for v in google_voices if v["language"] == "en-US"][0]
    utter_text = (
        "I would like to align the audio segments based on word level timestamps"
    )
    wav_data = gtts.text_to_speech(text=utter_text, params=us_voice)
    for word, seg in gentle_align_iter(service_uri, wav_data, utter_text):
        print(word)
        display.display(seg)
|
||||
|
||||
|
||||
@app.command()
def gentle_preview(
    audio_path: Path,
    transcript_path: Path,
    service_uri="http://101.53.142.218:8765/transcriptions",
    gent_preview_dir="./gentle_preview",
):
    """Align an audio file against a transcript and build a browsable preview.

    Copies the bundled gentle_preview web assets into ``gent_preview_dir``,
    runs the alignment against the remote gentle service, then writes the
    audio (a.wav) and alignment (status.json) the preview page expects.
    NOTE(review): shutil.copytree raises FileExistsError when
    ``gent_preview_dir`` already exists -- confirm whether reruns should
    overwrite instead.
    """
    from . import ExtendedPath

    # Static HTML/JS assets shipped inside the package, next to this module.
    pkg_gentle_dir = Path(__file__).parent / 'gentle_preview'

    shutil.copytree(str(pkg_gentle_dir), str(gent_preview_dir))
    ab = audio_path.read_bytes()
    tt = transcript_path.read_text()
    audio, alignment = gentle_aligner(service_uri, ab, tt)
    # The preview page loads a.wav and polls status.json; "OK" tells it the
    # alignment is complete and ready to render.
    audio.export(gent_preview_dir / Path("a.wav"), format="wav")
    alignment["status"] = "OK"
    ExtendedPath(gent_preview_dir / Path("status.json")).write_json(alignment)
|
||||
|
||||
|
||||
def main():
    """Entry point: dispatch to the typer CLI app defined in this module."""
    app()


if __name__ == "__main__":
    main()
|
||||
53
src/plume/utils/audio.py
Normal file
53
src/plume/utils/audio.py
Normal file
@@ -0,0 +1,53 @@
|
||||
import sys
|
||||
from io import BytesIO
|
||||
|
||||
from .lazy_import import lazy_module, lazy_callable
|
||||
|
||||
np = lazy_module("numpy")
|
||||
pydub = lazy_module("pydub")
|
||||
lfilter = lazy_callable("scipy.signal.lfilter")
|
||||
butter = lazy_callable("scipy.signal.butter")
|
||||
read = lazy_callable("scipy.io.wavfile.read")
|
||||
write = lazy_callable("scipy.io.wavfile.write")
|
||||
# from scipy.signal import lfilter, butter
|
||||
# from scipy.io.wavfile import read, write
|
||||
# import numpy as np
|
||||
|
||||
|
||||
def audio_seg_to_wav_bytes(aud_seg):
    """Serialize a pydub AudioSegment to in-memory wav bytes."""
    buf = BytesIO()
    aud_seg.export(buf, format="wav")
    return buf.getvalue()
|
||||
|
||||
|
||||
def audio_wav_bytes_to_seg(wav_bytes):
    """Deserialize in-memory wav bytes into a pydub AudioSegment."""
    return pydub.AudioSegment.from_file(BytesIO(wav_bytes))
|
||||
|
||||
|
||||
def butter_params(low_freq, high_freq, fs, order=5):
    """Design a Butterworth band-pass filter; return (b, a) coefficients.

    Cutoff frequencies are normalized by the Nyquist rate (fs / 2) before
    being handed to scipy's ``butter``.
    """
    nyquist = 0.5 * fs
    band = [low_freq / nyquist, high_freq / nyquist]
    return butter(order, band, btype="band")
|
||||
|
||||
|
||||
def butter_bandpass_filter(data, low_freq, high_freq, fs, order=5):
    """Band-pass filter ``data`` between low_freq and high_freq Hz."""
    coeff_b, coeff_a = butter_params(low_freq, high_freq, fs, order=order)
    return lfilter(coeff_b, coeff_a, data)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Quick manual check: band-pass the wav named on the command line and
    # write the result next to it with a _moded suffix.
    # Bug fix: removed a leftover `import pdb; pdb.set_trace()` debugging
    # breakpoint that halted the script on every run.
    fs, audio = read(sys.argv[1])
    low_freq = 300.0
    high_freq = 4000.0
    filtered_signal = butter_bandpass_filter(
        audio, low_freq, high_freq, fs, order=6
    )
    fname = sys.argv[1].split(".wav")[0] + "_moded.wav"
    write(fname, fs, np.array(filtered_signal, dtype=np.int16))
|
||||
188
src/plume/utils/encrypt.py
Normal file
188
src/plume/utils/encrypt.py
Normal file
@@ -0,0 +1,188 @@
|
||||
from collections import namedtuple
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
||||
# from cryptography.fernet import Fernet
|
||||
|
||||
import typer
|
||||
from tqdm import tqdm
|
||||
|
||||
from . import asr_manifest_writer
|
||||
from .extended_path import ExtendedPath
|
||||
from .audio import audio_seg_to_wav_bytes, audio_wav_bytes_to_seg
|
||||
from .parallel import parallel_apply
|
||||
from .lazy_import import lazy_module
|
||||
|
||||
cryptography = lazy_module("cryptography")
|
||||
# cryptography.fernet = lazy_module("cryptography.fernet")
|
||||
pydub = lazy_module("pydub")
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.callback()
def encrypt():
    """
    encrypt sub commands
    """
    # Typer group callback: the docstring above becomes the CLI help text
    # for the `encrypt` command group; the body is intentionally empty.
|
||||
|
||||
|
||||
def wav_cryptor(key=""):
    """Build a WavCryptor namedtuple of Fernet-based wav helpers bound to ``key``.

    The returned tuple exposes:
      keygen              -- Fernet key generator (classmethod passthrough).
      encrypt_wav_path_to -- encrypt the wav at one path into another path.
      decrypt_wav_path_to -- decrypt the wav at one path into another path.
      decrypt_wav_path    -- decrypt the wav at a path and return wav bytes.

    Only the segment's raw sample payload is encrypted; the ciphertext is
    re-wrapped in a wav container with the original frame rate, channel
    count and sample width, so encrypted files still load through pydub.
    """
    WavCryptor = namedtuple(
        "WavCryptor",
        (
            "keygen",
            "encrypt_wav_path_to",
            "decrypt_wav_path_to",
            "decrypt_wav_path",
        ),
    )
    _enc_key = key
    # One Fernet instance shared by all closures below.
    _crypto_f = cryptography.fernet.Fernet(_enc_key)

    def encrypt_wav_bytes(f, dec_wav_bytes):
        # Parse the incoming wav, encrypt its raw PCM payload, then rebuild
        # a segment with the original audio parameters.
        b = BytesIO(dec_wav_bytes)
        audio_seg = pydub.AudioSegment.from_file(b)
        # audio_seg.raw_data
        enc_wav_bytes = f.encrypt(audio_seg.raw_data)
        encrypted_seg = pydub.AudioSegment(
            enc_wav_bytes,
            frame_rate=audio_seg.frame_rate,
            channels=audio_seg.channels,
            sample_width=audio_seg.sample_width,
        )
        return audio_seg_to_wav_bytes(encrypted_seg)

    def decrypt_wav_bytes(f, enc_wav_bytes):
        # Inverse of encrypt_wav_bytes: the container's payload is the
        # Fernet token; decrypt it back to raw PCM.
        b = BytesIO(enc_wav_bytes)
        audio_seg = pydub.AudioSegment.from_file(b)
        dec_wav_bytes = f.decrypt(audio_seg.raw_data)
        decrypted_seg = pydub.AudioSegment(
            dec_wav_bytes,
            frame_rate=audio_seg.frame_rate,
            channels=audio_seg.channels,
            sample_width=audio_seg.sample_width,
        )
        return audio_seg_to_wav_bytes(decrypted_seg)

    def encrypt_wav_path_to(dec_audio_path: Path, enc_audio_path: Path):
        # Read plaintext wav, write encrypted wav.
        dec_wav_bytes = dec_audio_path.read_bytes()
        enc_audio_path.write_bytes(encrypt_wav_bytes(_crypto_f, dec_wav_bytes))

    def decrypt_wav_path_to(enc_audio_path: Path, dec_audio_path: Path):
        # Read encrypted wav, write plaintext wav.
        enc_wav_bytes = enc_audio_path.read_bytes()
        dec_audio_path.write_bytes(decrypt_wav_bytes(_crypto_f, enc_wav_bytes))

    def decrypt_wav_path(enc_audio_path: Path):
        # Read encrypted wav, return plaintext wav bytes without writing.
        enc_wav_bytes = enc_audio_path.read_bytes()
        return decrypt_wav_bytes(_crypto_f, enc_wav_bytes)

    return WavCryptor(
        cryptography.fernet.Fernet.generate_key,
        encrypt_wav_path_to,
        decrypt_wav_path_to,
        decrypt_wav_path,
    )
|
||||
|
||||
|
||||
def text_cryptor(key=""):
    """Build a TextCryptor namedtuple of Fernet text helpers bound to ``key``."""
    TextCryptor = namedtuple(
        "TextCryptor",
        ("keygen", "encrypt_text", "decrypt_text"),
    )
    fernet = cryptography.fernet.Fernet(key)

    def encrypt_text(text: str):
        """Encrypt a unicode string; returns the Fernet token bytes."""
        return fernet.encrypt(text.encode("utf-8"))

    def decrypt_text(text: str):
        """Decrypt a Fernet token back into a unicode string."""
        return fernet.decrypt(text).decode("utf-8")

    return TextCryptor(
        cryptography.fernet.Fernet.generate_key, encrypt_text, decrypt_text
    )
|
||||
|
||||
|
||||
def encrypted_asr_manifest_reader(
    data_manifest_path: Path, encryption_key: str, verbose=True, parallel=True
):
    """Yield decrypted samples from an encrypted ASR manifest.

    For each manifest entry, decrypts the referenced audio file into a pydub
    segment and the transcript into plain text.

    Args:
        data_manifest_path: jsonl manifest whose entries reference encrypted
            wavs relative to the manifest's directory.
        encryption_key: Fernet key as a string.
        verbose: show progress (forwarded to parallel_apply / tqdm).
        parallel: decrypt entries via parallel_apply when True.

    Bug fixes: the sequential branch called ``tqdm.tqdm`` although ``tqdm``
    is imported as the function itself (AttributeError), and yielded
    ``decrypt_fn(d)`` using an undefined name instead of the loop
    variable ``p`` (NameError).
    """
    print(f"reading encrypted manifest from {data_manifest_path}")
    asr_data = list(ExtendedPath(data_manifest_path).read_jsonl())
    enc_key_bytes = encryption_key.encode("utf-8")
    wc = wav_cryptor(enc_key_bytes)
    tc = text_cryptor(enc_key_bytes)

    def decrypt_fn(p):
        # Decrypt the wav (relative to the manifest dir) and the transcript.
        d = {
            "audio_seg": audio_wav_bytes_to_seg(
                wc.decrypt_wav_path(
                    data_manifest_path.parent / Path(p["audio_filepath"])
                )
            ),
            "text": tc.decrypt_text(p["text"].encode("utf-8")),
        }
        return d

    if parallel:
        for d in parallel_apply(decrypt_fn, asr_data, verbose=verbose):
            yield d
    else:
        for p in tqdm(asr_data) if verbose else asr_data:
            yield decrypt_fn(p)
|
||||
|
||||
|
||||
def decrypt_asr_dataset(
    src_dataset_dir: Path,
    dest_dataset_dir: Path,
    encryption_key: str,
    verbose=True,
    parallel=True,
):
    """Decrypt an encrypted ASR dataset into ``dest_dataset_dir``.

    Decrypts every referenced wav into the destination tree and writes a new
    manifest.json whose transcripts are decrypted plain text.

    Args:
        src_dataset_dir: directory containing manifest.json + encrypted wavs.
        dest_dataset_dir: output directory (a wavs/ subdir is created).
        encryption_key: Fernet key as a string.
        verbose: show progress (forwarded to parallel_apply / tqdm).
        parallel: decrypt entries via parallel_apply when True.

    Bug fixes: the sequential branch called ``tqdm.tqdm`` although ``tqdm``
    is imported as the function itself (AttributeError), and called
    ``decrypt_fn(d)`` with an undefined name instead of the loop
    variable ``p`` (NameError).
    """
    data_manifest_path = src_dataset_dir / "manifest.json"
    (dest_dataset_dir / "wavs").mkdir(exist_ok=True, parents=True)
    dest_manifest_path = dest_dataset_dir / "manifest.json"
    print(f"reading encrypted manifest from {data_manifest_path}")
    asr_data = list(ExtendedPath(data_manifest_path).read_jsonl())
    enc_key_bytes = encryption_key.encode("utf-8")
    wc = wav_cryptor(enc_key_bytes)
    tc = text_cryptor(enc_key_bytes)

    def decrypt_fn(p):
        # Decrypt one wav into the destination tree; return the new entry.
        dest_path = dest_dataset_dir / Path(p["audio_filepath"])
        wc.decrypt_wav_path_to(
            src_dataset_dir / Path(p["audio_filepath"]), dest_path
        )
        d = {
            "audio_filepath": dest_path,
            "duration": p["duration"],
            "text": tc.decrypt_text(p["text"].encode("utf-8")),
        }
        return d

    def datagen():
        if parallel:
            for d in parallel_apply(decrypt_fn, asr_data, verbose=verbose):
                yield d
        else:
            for p in tqdm(asr_data) if verbose else asr_data:
                yield decrypt_fn(p)

    # NOTE(review): datagen is passed uncalled, as in the original -- confirm
    # asr_manifest_writer accepts a callable rather than an iterable.
    asr_manifest_writer(dest_manifest_path, datagen)
|
||||
|
||||
|
||||
@app.command()
def keygen():
    """Generate a fresh Fernet key and print it."""
    new_key = cryptography.fernet.Fernet.generate_key()
    typer.echo(f"KEY: {new_key}")
|
||||
|
||||
|
||||
@app.command()
def encrypt_text(
    text_to_encrypt: str,
    encryption_key: str = typer.Option(..., prompt=True, hide_input=True),
):
    """Encrypt a string with the (prompted, hidden) Fernet key and print it."""
    cryptor = text_cryptor(encryption_key.encode("utf-8"))
    typer.echo(cryptor.encrypt_text(text_to_encrypt))
|
||||
56
src/plume/utils/extended_path.py
Normal file
56
src/plume/utils/extended_path.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from pathlib import Path
|
||||
import json
|
||||
|
||||
from .lazy_import import lazy_module
|
||||
|
||||
yaml = lazy_module("ruamel.yaml")
|
||||
pydub = lazy_module("pydub")
|
||||
|
||||
|
||||
class ExtendedPath(type(Path())):
    """Path subclass adding json / jsonl / yaml / audio read-write helpers."""

    def read_json(self, verbose=False):
        """Load and return the JSON document stored at this path."""
        if verbose:
            print(f"reading json from {self}")
        with self.open("r") as handle:
            return json.load(handle)

    def read_yaml(self, verbose=False):
        """Load and return the YAML document stored at this path."""
        parser = yaml.YAML(typ="safe", pure=True)
        if verbose:
            print(f"reading yaml from {self}")
        with self.open("r") as handle:
            return parser.load(handle)

    def read_jsonl(self, verbose=False):
        """Yield one parsed object per line of this JSON-lines file."""
        if verbose:
            print(f"reading jsonl from {self}")
        with self.open("r") as handle:
            for line in handle.readlines():
                yield json.loads(line)

    def read_audio_segment(self):
        """Load this path as a pydub AudioSegment."""
        return pydub.AudioSegment.from_file(self)

    def write_json(self, data, verbose=False):
        """Write ``data`` as indented JSON, creating parent directories."""
        if verbose:
            print(f"writing json to {self}")
        self.parent.mkdir(parents=True, exist_ok=True)
        with self.open("w") as handle:
            json.dump(data, handle, indent=2)

    def write_yaml(self, data, verbose=False):
        """Write ``data`` as YAML."""
        emitter = yaml.YAML()
        if verbose:
            print(f"writing yaml to {self}")
        with self.open("w") as handle:
            emitter.dump(data, handle)

    def write_jsonl(self, data, verbose=False):
        """Write each item of ``data`` as one JSON line, creating parents."""
        if verbose:
            print(f"writing jsonl to {self}")
        self.parent.mkdir(parents=True, exist_ok=True)
        with self.open("w") as handle:
            for item in data:
                handle.write(json.dumps(item) + "\n")
|
||||
5
src/plume/utils/gentle_preview/README.md
Normal file
5
src/plume/utils/gentle_preview/README.md
Normal file
@@ -0,0 +1,5 @@
|
||||
Serve the preview directory with [RangeHTTPServer](https://github.com/danvk/RangeHTTPServer)
or its CORS-enabled fork, [CORSRangeHTTPServer](https://github.com/claysciences/CORSRangeHTTPServer):

`python -m RangeHTTPServer`

Alternatively, use the Python standard library server:

`python -m http.server`
|
||||
80
src/plume/utils/gentle_preview/align.html
Normal file
80
src/plume/utils/gentle_preview/align.html
Normal file
@@ -0,0 +1,80 @@
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="utf-8" />
|
||||
<style>
|
||||
body {font-family: sans-serif; padding-top: 70px; }
|
||||
textarea { width: 500px; height: 20em; }
|
||||
input, textarea { margin: 1em 0; }
|
||||
#header {
|
||||
position: fixed;
|
||||
top: 0;
|
||||
left: 0;
|
||||
height: 50px;
|
||||
line-height: 50px;
|
||||
width: 100%;
|
||||
background-color: #999;
|
||||
box-shadow: 0px 0px 5px 0px rgba(0,0,0,0.5);
|
||||
font-family: Helvetica, sans-serif;
|
||||
}
|
||||
#header, #header a {
|
||||
color: white;
|
||||
}
|
||||
.home {
|
||||
margin: 0;
|
||||
font-size: 125%;
|
||||
font-weight: lighter;
|
||||
text-transform: lowercase;
|
||||
}
|
||||
.home a {
|
||||
margin: 0;
|
||||
background: #666;
|
||||
padding-left: 25px;
|
||||
padding-right: 30px;
|
||||
margin-right: 20px;
|
||||
float: left;
|
||||
text-decoration: none;
|
||||
}
|
||||
.home:hover a {
|
||||
background: #555;
|
||||
}
|
||||
#align-button {
|
||||
background: #CCC;
|
||||
border: 0;
|
||||
font-size: 18px;
|
||||
padding: 10px 30px;
|
||||
cursor: pointer;
|
||||
}
|
||||
#alignment-flags {
|
||||
background: #CCC;
|
||||
border: 0;
|
||||
font-size: 18px;
|
||||
padding: 10px 30px;
|
||||
}
|
||||
#footer {
|
||||
margin-top: 100px;
|
||||
border-top: 1px dotted black;
|
||||
font-size: 8pt;
|
||||
font-style: italic;
|
||||
padding: 10px;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div id="header">
|
||||
<h1 class="home"><a href="/">Gentle</a></h1>
|
||||
</div>
|
||||
<form action="/transcriptions" method="POST" enctype="multipart/form-data">
|
||||
Audio:<br>
|
||||
<input type=file name=audio><br>
|
||||
<br>
|
||||
Transcript:<br>
|
||||
<textarea name="transcript"></textarea><br>
|
||||
<input id=alignment-flags name=conservative type=checkbox> Conservative<br>
|
||||
<input id=alignment-flags name=disfluency type=checkbox> Include disfluencies<br>
|
||||
<input id="align-button" type=submit value=Align>
|
||||
</form>
|
||||
<div id="footer">
|
||||
<a href="https://lowerquality.com/gentle">Gentle</a> is free software released under the <a href="https://opensource.org/licenses/MIT">MIT license</a>. <a href="https://lowerquality.com/gentle">Homepage</a> | <a href="https://github.com/lowerquality/gentle">Source code</a>.
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
408
src/plume/utils/gentle_preview/index.html
Normal file
408
src/plume/utils/gentle_preview/index.html
Normal file
@@ -0,0 +1,408 @@
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="utf-8" />
|
||||
<style>
|
||||
html, body {
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
}
|
||||
#header {
|
||||
position: fixed;
|
||||
top: 0;
|
||||
left: 0;
|
||||
height: 50px;
|
||||
line-height: 50px;
|
||||
width: 100%;
|
||||
background-color: #999;
|
||||
box-shadow: 0px 0px 5px 0px rgba(0,0,0,0.5);
|
||||
font-family: Helvetica, sans-serif;
|
||||
}
|
||||
#header, #header a {
|
||||
color: white;
|
||||
}
|
||||
#downloads {
|
||||
float: right;
|
||||
background: #999;
|
||||
}
|
||||
.download {
|
||||
float: right;
|
||||
background: #999;
|
||||
padding: 0 5px;
|
||||
}
|
||||
.home {
|
||||
margin: 0;
|
||||
font-size: 125%;
|
||||
font-weight: lighter;
|
||||
text-transform: lowercase;
|
||||
}
|
||||
.home a {
|
||||
margin: 0;
|
||||
background: #666;
|
||||
padding-left: 25px;
|
||||
padding-right: 30px;
|
||||
margin-right: 20px;
|
||||
float: left;
|
||||
text-decoration: none;
|
||||
}
|
||||
.home:hover a {
|
||||
background: #555;
|
||||
}
|
||||
#audio {
|
||||
margin-top: 9px;
|
||||
width: 50%;
|
||||
display: inline-block;
|
||||
}
|
||||
#transcript {
|
||||
margin: 0 15px;
|
||||
margin-top: 70px;
|
||||
margin-bottom: 5em;
|
||||
white-space: pre-wrap;
|
||||
line-height: 2em;
|
||||
max-width: 600px;
|
||||
color: #999;
|
||||
}
|
||||
#transcript.status {
|
||||
background-color: #333;
|
||||
color: #fff;
|
||||
font-family: Courier, mono;
|
||||
line-height: 1em;
|
||||
font-size: 10pt;
|
||||
max-width: 100%;
|
||||
}
|
||||
#transcript.status h2 {
|
||||
padding: 10px;
|
||||
}
|
||||
#transcript.status .entry {
|
||||
margin-bottom: 10px;
|
||||
padding: 10px;
|
||||
}
|
||||
#transcript.status progress {
|
||||
width: 100%;
|
||||
height: 30px;
|
||||
margin-bottom: 20px;
|
||||
}
|
||||
.success {
|
||||
color: black;
|
||||
}
|
||||
.success:hover {
|
||||
text-decoration: underline;
|
||||
}
|
||||
.active {
|
||||
color: magenta;
|
||||
}
|
||||
#preloader {
|
||||
visibility: hidden;
|
||||
}
|
||||
.phactive {
|
||||
text-decoration: underline;
|
||||
}
|
||||
.phones {
|
||||
position: absolute;
|
||||
color: #333;
|
||||
}
|
||||
.phones .phone {
|
||||
margin-right: 5px;
|
||||
font-family: Helvetica, sans-serif;
|
||||
text-transform: uppercase;
|
||||
font-size: 50%;
|
||||
}
|
||||
.phones .phone:last-child {
|
||||
margin-right: 0;
|
||||
}
|
||||
#footer {
|
||||
margin-top: 100px;
|
||||
border-top: 1px dotted black;
|
||||
font-size: 8pt;
|
||||
font-style: italic;
|
||||
font-family: Helvetica, sans-serif;
|
||||
padding: 10px;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div id="header">
|
||||
<!-- <h1 class="home"><a href="/">Gentle</a></h1> -->
|
||||
<audio id="audio" src="a.wav" controls="true" preload="auto"></audio>
|
||||
<img src="/preloader.gif" id="preloader" alt="loading...">
|
||||
<span id="downloads"> </div>
|
||||
</div>
|
||||
<div id="transcript"></div>
|
||||
<!-- <div id="footer">
|
||||
<a href="https://lowerquality.com/gentle">Gentle</a> is free software released under the <a href="https://opensource.org/licenses/MIT">MIT license</a>. <a href="https://lowerquality.com/gentle">Homepage</a> | <a href="https://github.com/lowerquality/gentle">Source code</a>.
|
||||
</div> -->
|
||||
|
||||
<script>
|
||||
|
||||
// Minimal XHR helper: GET `url` and pass the raw response text to `cb`.
function get(url, cb) {
    var xhr = new XMLHttpRequest();
    xhr.open("GET", url, true);
    xhr.onload = function() {
        cb(this.responseText);
    }
    xhr.send();
}
// GET `url` and pass the JSON-parsed response body to `cb`.
function get_json(url, cb) {
    get(url, function(x) {
        cb(JSON.parse(x));
    });
}
|
||||
|
||||
var $a = document.getElementById("audio");
|
||||
window.onkeydown = function(ev) {
|
||||
if(ev.keyCode == 32) {
|
||||
ev.preventDefault();
|
||||
$a.pause();
|
||||
}
|
||||
}
|
||||
|
||||
var $trans = document.getElementById("transcript");
|
||||
var $preloader = document.getElementById('preloader');
|
||||
|
||||
var wds = [];
|
||||
var cur_wd;
|
||||
|
||||
var $phones = document.createElement("div");
|
||||
$phones.className = "phones";
|
||||
document.body.appendChild($phones);
|
||||
|
||||
var cur_phones$ = []; // List of phoneme $divs
|
||||
var $active_phone;
|
||||
|
||||
function render_phones(wd) {
|
||||
cur_phones$ = [];
|
||||
$phones.innerHTML = "";
|
||||
$active_phone = null;
|
||||
|
||||
$phones.style.top = wd.$div.offsetTop + 18;
|
||||
$phones.style.left = wd.$div.offsetLeft;
|
||||
|
||||
var dur = wd.end - wd.start;
|
||||
|
||||
var start_x = wd.$div.offsetLeft;
|
||||
|
||||
wd.phones
|
||||
.forEach(function(ph){
|
||||
var $p = document.createElement("span");
|
||||
$p.className = "phone";
|
||||
$p.textContent = ph.phone.split("_")[0];
|
||||
|
||||
$phones.appendChild($p);
|
||||
cur_phones$.push($p);
|
||||
});
|
||||
|
||||
var offsetToCenter = (wd.$div.offsetWidth - $phones.offsetWidth) / 2;
|
||||
$phones.style.left = wd.$div.offsetLeft + offsetToCenter;
|
||||
}
|
||||
function highlight_phone(t) {
|
||||
if(!cur_wd) {
|
||||
$phones.innerHTML = "";
|
||||
return;
|
||||
}
|
||||
var hit;
|
||||
var cur_t = cur_wd.start;
|
||||
|
||||
cur_wd.phones.forEach(function(ph, idx) {
|
||||
if(cur_t <= t && cur_t + ph.duration >= t) {
|
||||
hit = idx;
|
||||
}
|
||||
cur_t += ph.duration;
|
||||
});
|
||||
|
||||
if(hit) {
|
||||
var $ph = cur_phones$[hit];
|
||||
if($ph != $active_phone) {
|
||||
if($active_phone) {
|
||||
$active_phone.classList.remove("phactive");
|
||||
}
|
||||
if($ph) {
|
||||
$ph.classList.add("phactive");
|
||||
}
|
||||
}
|
||||
$active_phone = $ph;
|
||||
}
|
||||
}
|
||||
|
||||
function highlight_word() {
|
||||
var t = $a.currentTime;
|
||||
// XXX: O(N); use binary search
|
||||
var hits = wds.filter(function(x) {
|
||||
return (t - x.start) > 0.01 && (x.end - t) > 0.01;
|
||||
}, wds);
|
||||
var next_wd = hits[hits.length - 1];
|
||||
|
||||
if(cur_wd != next_wd) {
|
||||
var active = document.querySelectorAll('.active');
|
||||
for(var i = 0; i < active.length; i++) {
|
||||
active[i].classList.remove('active');
|
||||
}
|
||||
if(next_wd && next_wd.$div) {
|
||||
next_wd.$div.classList.add('active');
|
||||
render_phones(next_wd);
|
||||
}
|
||||
}
|
||||
cur_wd = next_wd;
|
||||
highlight_phone(t);
|
||||
|
||||
window.requestAnimationFrame(highlight_word);
|
||||
}
|
||||
window.requestAnimationFrame(highlight_word);
|
||||
|
||||
$trans.innerHTML = "Loading...";
|
||||
|
||||
function render(ret) {
|
||||
wds = ret['words'] || [];
|
||||
transcript = ret['transcript'];
|
||||
|
||||
$trans.innerHTML = '';
|
||||
|
||||
var currentOffset = 0;
|
||||
|
||||
wds.forEach(function(wd) {
|
||||
if(wd.case == 'not-found-in-transcript') {
|
||||
// TODO: show phonemes somewhere
|
||||
var txt = ' ' + wd.word;
|
||||
var $plaintext = document.createTextNode(txt);
|
||||
$trans.appendChild($plaintext);
|
||||
return;
|
||||
}
|
||||
|
||||
// Add non-linked text
|
||||
if(wd.startOffset > currentOffset) {
|
||||
var txt = transcript.slice(currentOffset, wd.startOffset);
|
||||
var $plaintext = document.createTextNode(txt);
|
||||
$trans.appendChild($plaintext);
|
||||
currentOffset = wd.startOffset;
|
||||
}
|
||||
|
||||
var $wd = document.createElement('span');
|
||||
var txt = transcript.slice(wd.startOffset, wd.endOffset);
|
||||
var $wdText = document.createTextNode(txt);
|
||||
$wd.appendChild($wdText);
|
||||
wd.$div = $wd;
|
||||
if(wd.start !== undefined) {
|
||||
$wd.className = 'success';
|
||||
}
|
||||
$wd.onclick = function() {
|
||||
if(wd.start !== undefined) {
|
||||
console.log(wd.start);
|
||||
$a.currentTime = wd.start;
|
||||
$a.play();
|
||||
}
|
||||
};
|
||||
$trans.appendChild($wd);
|
||||
currentOffset = wd.endOffset;
|
||||
});
|
||||
|
||||
var txt = transcript.slice(currentOffset, transcript.length);
|
||||
var $plaintext = document.createTextNode(txt);
|
||||
$trans.appendChild($plaintext);
|
||||
currentOffset = transcript.length;
|
||||
}
|
||||
|
||||
function show_downloads() {
|
||||
var $d = document.getElementById("downloads");
|
||||
$d.textContent = "Download as: ";
|
||||
var uid = window.location.pathname.split("/")[2];
|
||||
// Name, path, title, inhibit-on-file:///
|
||||
[["CSV", "align.csv", "Word alignment CSV"],
|
||||
["JSON", "align.json", "JSON word/phoneme alignment data"],
|
||||
["Zip", "/zip/" + uid + ".zip", "Standalone zipfile", true]]
|
||||
.forEach(function(x) {
|
||||
var $a = document.createElement("a");
|
||||
$a.className = "download";
|
||||
$a.textContent = x[0];
|
||||
$a.href = x[1];
|
||||
$a.title = x[2];
|
||||
if(!x[3] || window.location.protocol != "file:") {
|
||||
$d.appendChild($a);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
var status_init = false;
|
||||
var status_log = []; // [ status ]
|
||||
var $status_pro;
|
||||
|
||||
function render_status(ret) {
|
||||
if(!status_init) {
|
||||
// Clobber the $trans div and use it for status updates
|
||||
$trans.innerHTML = "<h2>transcription in progress</h2>";
|
||||
$trans.className = "status";
|
||||
$status_pro = document.createElement("progress");
|
||||
$status_pro.setAttribute("min", "0");
|
||||
$status_pro.setAttribute("max", "100");
|
||||
$status_pro.value = 0;
|
||||
$trans.appendChild($status_pro);
|
||||
|
||||
status_init = true;
|
||||
}
|
||||
if(ret.status !== "TRANSCRIBING") {
|
||||
if(ret.percent) {
|
||||
$status_pro.value = (100*ret.percent);
|
||||
}
|
||||
}
|
||||
else if(ret.percent && (status_log.length == 0 || status_log[status_log.length-1].percent+0.0001 < ret.percent)) {
|
||||
// New entry
|
||||
var $entry = document.createElement("div");
|
||||
$entry.className = "entry";
|
||||
$entry.textContent = ret.message;
|
||||
ret.$div = $entry;
|
||||
|
||||
if(ret.percent) {
|
||||
$status_pro.value = (100*ret.percent);
|
||||
}
|
||||
|
||||
if(status_log.length > 0) {
|
||||
$trans.insertBefore($entry, status_log[status_log.length-1].$div);
|
||||
}
|
||||
else {
|
||||
$trans.appendChild($entry);
|
||||
}
|
||||
status_log.push(ret);
|
||||
}
|
||||
}
|
||||
|
||||
function update() {
|
||||
if(INLINE_JSON) {
|
||||
// We want this to work from file:/// domains, so we provide a
|
||||
// mechanism for inlining the alignment data.
|
||||
render(INLINE_JSON);
|
||||
// show_downloads();
|
||||
}
|
||||
else {
|
||||
// Show the status
|
||||
get_json('status.json', function(ret) {
|
||||
$a.style.visibility = 'hidden';
|
||||
if (ret.status == 'ERROR') {
|
||||
$preloader.style.visibility = 'hidden';
|
||||
$trans.innerHTML = '<b>' + ret.status + ': ' + ret.error + '</b>';
|
||||
} else if (ret.status == 'TRANSCRIBING' || ret.status == 'ALIGNING') {
|
||||
$preloader.style.visibility = 'visible';
|
||||
render_status(ret);
|
||||
setTimeout(update, 2000);
|
||||
} else if (ret.status == 'OK') {
|
||||
// show_downloads();
|
||||
$preloader.style.visibility = 'hidden';
|
||||
// XXX: should we fetch the align.json?
|
||||
// window.location.reload();
|
||||
$a.style.visibility = 'visible';
|
||||
render(ret);
|
||||
} else if (ret.status == 'ENCODING' || ret.status == 'STARTED') {
|
||||
$preloader.style.visibility = 'visible';
|
||||
$trans.innerHTML = 'Encoding, please wait...';
|
||||
setTimeout(update, 2000);
|
||||
} else {
|
||||
console.log("unknown status", ret);
|
||||
$preloader.style.visibility = 'hidden';
|
||||
$trans.innerHTML = ret.status + '...';
|
||||
setTimeout(update, 5000);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
var INLINE_JSON;
|
||||
|
||||
update();
|
||||
|
||||
</script></body></html>
|
||||
BIN
src/plume/utils/gentle_preview/preloader.gif
Normal file
BIN
src/plume/utils/gentle_preview/preloader.gif
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 2.7 KiB |
731
src/plume/utils/lazy_import.py
Normal file
731
src/plume/utils/lazy_import.py
Normal file
@@ -0,0 +1,731 @@
|
||||
# -*- Mode: python; tab-width: 4; indent-tabs-mode:nil; coding:utf-8 -*-
|
||||
# vim: tabstop=4 expandtab shiftwidth=4 softtabstop=4
|
||||
#
|
||||
# lazy_import --- https://github.com/mnmelo/lazy_import
|
||||
# Copyright (C) 2017-2018 Manuel Nuno Melo
|
||||
#
|
||||
# This file is part of lazy_import.
|
||||
#
|
||||
# lazy_import is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# lazy_import is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with lazy_import. If not, see <http://www.gnu.org/licenses/>.
|
||||
#
|
||||
# lazy_import was based on code from the importing module from the PEAK
|
||||
# package (see <http://peak.telecommunity.com/DevCenter/Importing>). The PEAK
|
||||
# package is released under the following license, reproduced here:
|
||||
#
|
||||
# Copyright (C) 1996-2004 by Phillip J. Eby and Tyler C. Sarna.
|
||||
# All rights reserved. This software may be used under the same terms
|
||||
# as Zope or Python. THERE ARE ABSOLUTELY NO WARRANTIES OF ANY KIND.
|
||||
# Code quality varies between modules, from "beta" to "experimental
|
||||
# pre-alpha". :)
|
||||
#
|
||||
# Code pertaining to lazy loading from PEAK importing was included in
|
||||
# lazy_import, modified in a number of ways. These are detailed in the
|
||||
# CHANGELOG file of lazy_import. Changes mainly involved Python 3
|
||||
# compatibility, extension to allow customizable behavior, and added
|
||||
# functionality (lazy importing of callable objects).
|
||||
#
|
||||
|
||||
"""
|
||||
Lazy module loading
|
||||
===================
|
||||
Functions and classes for lazy module loading that also delay import errors.
|
||||
Heavily borrowed from the `importing`_ module.
|
||||
.. _`importing`: http://peak.telecommunity.com/DevCenter/Importing
|
||||
Files and directories
|
||||
---------------------
|
||||
.. autofunction:: module
|
||||
.. autofunction:: callable
|
||||
"""
|
||||
|
||||
__all__ = [
|
||||
"lazy_module",
|
||||
"lazy_callable",
|
||||
"lazy_function",
|
||||
"lazy_class",
|
||||
"LazyModule",
|
||||
"LazyCallable",
|
||||
"module_basename",
|
||||
"_MSG",
|
||||
"_MSG_CALLABLE",
|
||||
]
|
||||
|
||||
from types import ModuleType
|
||||
import sys
|
||||
|
||||
try:
    from importlib._bootstrap import _ImportLockContext
except ImportError:
    # Python 2 doesn't have the context manager. Roll it ourselves (copied from
    # Python 3's importlib/_bootstrap.py)
    import imp

    class _ImportLockContext:
        """Context manager for the import lock."""

        def __enter__(self):
            # Acquire the interpreter's global import lock.
            imp.acquire_lock()

        def __exit__(self, exc_type, exc_value, exc_traceback):
            # Release even when the body raised; exception info is ignored.
            imp.release_lock()


# Adding a __spec__ doesn't really help. I'll leave the code here in case
# future python implementations start relying on it.
try:
    from importlib.machinery import ModuleSpec
except ImportError:
    # Sentinel None turns the later `if ModuleSpec:` guards into no-ops.
    ModuleSpec = None
|
||||
|
||||
import six
|
||||
from six import raise_from
|
||||
from six.moves import reload_module
|
||||
|
||||
# It is sometime useful to have access to the version number of a library.
|
||||
# This is usually done through the __version__ special attribute.
|
||||
# To make sure the version number is consistent between setup.py and the
|
||||
# library, we read the version number from the file called VERSION that stays
|
||||
# in the module directory.
|
||||
import os
|
||||
|
||||
# VERSION_FILE = os.path.join(os.path.dirname(__file__), "VERSION")
|
||||
# with open(VERSION_FILE) as infile:
|
||||
# __version__ = infile.read().strip()
|
||||
|
||||
# Logging
import logging

# adding a TRACE level for stack debugging
_LAZY_TRACE = 1
logging.addLevelName(1, "LAZY_TRACE")
logging.basicConfig(level=logging.WARNING)
# Logs a formatted stack (takes no message or args/kwargs)
def _lazy_trace(self):
    """Log the current call stack at the custom LAZY_TRACE level.

    Bound onto logging.Logger below, so any logger can call `.lazy_trace()`.
    """
    # Only pay for traceback formatting when the very verbose level is on.
    if self.isEnabledFor(_LAZY_TRACE):
        import traceback

        self._log(_LAZY_TRACE, " ### STACK TRACE ###", ())
        # depth 2: skip this helper and the logging wrapper frame.
        for line in traceback.format_stack(sys._getframe(2)):
            for subline in line.split("\n"):
                self._log(_LAZY_TRACE, subline.rstrip(), ())


logging.Logger.lazy_trace = _lazy_trace
logger = logging.getLogger(__name__)
|
||||
|
||||
################################
|
||||
# Module/function registration #
|
||||
################################
|
||||
|
||||
#### Lazy classes ####
|
||||
|
||||
|
||||
class LazyModule(ModuleType):
    """Class for lazily-loaded modules that triggers proper loading on access.

    Instantiation should be made from a subclass of :class:`LazyModule`, with
    one subclass per instantiated module. Regular attribute set/access can then
    be recovered by setting the subclass's :meth:`__getattribute__` and
    :meth:`__setattribute__` to those of :class:`types.ModuleType`.
    """

    # peak.util.imports sets __slots__ to (), but it seems pointless because
    # the base ModuleType doesn't itself set __slots__.
    def __getattribute__(self, attr):
        """Intercept every attribute access and trigger the real import on
        first access to anything except a few safe names."""
        logger.debug(
            "Getting attr {} of LazyModule instance of {}".format(
                attr, super(LazyModule, self).__getattribute__("__name__")
            )
        )
        logger.lazy_trace()
        # IPython tries to be too clever and constantly inspects, asking for
        # modules' attrs, which causes premature module loading and unesthetic
        # internal errors if the lazily-loaded module doesn't exist.
        if (
            run_from_ipython()
            and (attr.startswith(("__", "_ipython")) or attr == "_repr_mimebundle_")
            and module_basename(_caller_name()) in ("inspect", "IPython")
        ):
            logger.debug(
                "Ignoring request for {}, deemed from IPython's "
                "inspection.".format(
                    super(LazyModule, self).__getattribute__("__name__"), attr
                )
            )
            raise AttributeError
        if not attr in ("__name__", "__class__", "__spec__"):
            # __name__ and __class__ yield their values from the LazyModule;
            # __spec__ causes an AttributeError. Maybe in the future it will be
            # necessary to return an actual ModuleSpec object, but it works as
            # it is without that now.

            # If it's an already-loaded submodule, we return it without
            # triggering a full loading
            try:
                return sys.modules[self.__name__ + "." + attr]
            except KeyError:
                pass
            # Check if it's one of the lazy callables
            try:
                _callable = type(self)._lazy_import_callables[attr]
                logger.debug("Returning lazy-callable '{}'.".format(attr))
                return _callable
            except (AttributeError, KeyError) as err:
                # Not registered as a lazy callable: do the real import now.
                logger.debug(
                    "Proceeding to load module {}, "
                    "from requested value {}".format(
                        super(LazyModule, self).__getattribute__("__name__"), attr
                    )
                )
                _load_module(self)
        # After loading (or for the safe names) fall back to the normal
        # ModuleType attribute lookup.
        logger.debug(
            "Returning value '{}'.".format(
                super(LazyModule, self).__getattribute__(attr)
            )
        )
        return super(LazyModule, self).__getattribute__(attr)

    def __setattr__(self, attr, value):
        """Setting any attribute forces the module to be fully loaded first."""
        logger.debug(
            "Setting attr {} to value {}, in LazyModule instance "
            "of {}".format(
                attr, value, super(LazyModule, self).__getattribute__("__name__")
            )
        )
        _load_module(self)
        return super(LazyModule, self).__setattr__(attr, value)
|
||||
|
||||
|
||||
class LazyCallable(object):
    """Class for lazily-loaded callables that triggers module loading on access"""

    def __init__(self, *args):
        # Expected signature is (module, callable_name). *args is used so the
        # misuse case below (subclassing from a lazy callable) can be detected
        # and reported clearly instead of failing obscurely.
        if len(args) != 2:
            # Maybe the user tried to base a class off this lazy callable?
            try:
                logger.debug(
                    "Got wrong number of args when init'ing "
                    "LazyCallable. args is '{}'".format(args)
                )
                base = args[1][0]
                if isinstance(base, LazyCallable) and len(args) == 3:
                    raise NotImplementedError(
                        "It seems you are trying to use "
                        "a lazy callable as a class "
                        "base. This is not supported."
                    )
            except (IndexError, TypeError):
                raise_from(
                    TypeError(
                        "LazyCallable takes exactly 2 arguments: "
                        "a module/lazy module object and the name of "
                        "a callable to be lazily loaded."
                    ),
                    None,
                )
        self.module, self.cname = args
        self.modclass = type(self.module)
        # Resolved target; None until the first successful __call__.
        self.callable = None
        # Need to save these, since the module-loading gets rid of them
        self.error_msgs = self.modclass._lazy_import_error_msgs
        self.error_strings = self.modclass._lazy_import_error_strings

    def __call__(self, *args, **kwargs):
        """Resolve the real callable on first use, then delegate to it."""
        # No need to go through all the reloading more than once.
        if self.callable:
            return self.callable(*args, **kwargs)
        # Deregister ourselves so the getattr below reaches the real module
        # attribute instead of this wrapper.
        try:
            del self.modclass._lazy_import_callables[self.cname]
        except (AttributeError, KeyError):
            pass
        try:
            self.callable = getattr(self.module, self.cname)
        except AttributeError:
            # Module loaded fine but has no such callable.
            msg = self.error_msgs["msg_callable"]
            raise_from(
                AttributeError(msg.format(callable=self.cname, **self.error_strings)),
                None,
            )
        except ImportError as err:
            # Import failed. We reset the dict and re-raise the ImportError.
            try:
                self.modclass._lazy_import_callables[self.cname] = self
            except AttributeError:
                self.modclass._lazy_import_callables = {self.cname: self}
            raise_from(err, None)
        else:
            return self.callable(*args, **kwargs)
|
||||
|
||||
|
||||
### Functions ###
|
||||
|
||||
|
||||
def lazy_module(modname, error_strings=None, lazy_mod_class=LazyModule, level="leaf"):
    """Function allowing lazy importing of a module into the namespace.

    A lazy module object is created, registered in `sys.modules`, and
    returned. This is a hollow module; actual loading, and `ImportErrors` if
    not found, are delayed until an attempt is made to access attributes of the
    lazy module.

    A handy application is to use :func:`lazy_module` early in your own code
    (say, in `__init__.py`) to register all modulenames you want to be lazy.
    Because of registration in `sys.modules` later invocations of
    `import modulename` will also return the lazy object. This means that after
    initial registration the rest of your code can use regular pyhon import
    statements and retain the lazyness of the modules.

    Parameters
    ----------
    modname : str
        The module to import.
    error_strings : dict, optional
        A dictionary of strings to use when module-loading fails. Key 'msg'
        sets the message to use (defaults to :attr:`lazy_import._MSG`). The
        message is formatted using the remaining dictionary keys. The default
        message informs the user of which module is missing (key 'module'),
        what code loaded the module as lazy (key 'caller'), and which package
        should be installed to solve the dependency (key 'install_name').
        None of the keys is mandatory and all are given smart names by default.
    lazy_mod_class: type, optional
        Which class to use when instantiating the lazy module, to allow
        deep customization. The default is :class:`LazyModule` and custom
        alternatives **must** be a subclass thereof.
    level : str, optional
        Which submodule reference to return. Either a reference to the 'leaf'
        module (the default) or to the 'base' module. This is useful if you'll
        be using the module functionality in the same place you're calling
        :func:`lazy_module` from, since then you don't need to run `import`
        again. Setting *level* does not affect which names/modules get
        registered in `sys.modules`.
        For *level* set to 'base' and *modulename* 'aaa.bbb.ccc'::

            aaa = lazy_import.lazy_module("aaa.bbb.ccc", level='base')
            # 'aaa' becomes defined in the current namespace, with
            # (sub)attributes 'aaa.bbb' and 'aaa.bbb.ccc'.
            # It's the lazy equivalent to:
            import aaa.bbb.ccc

        For *level* set to 'leaf'::

            ccc = lazy_import.lazy_module("aaa.bbb.ccc", level='leaf')
            # Only 'ccc' becomes set in the current namespace.
            # Lazy equivalent to:
            from aaa.bbb import ccc

    Returns
    -------
    module
        The module specified by *modname*, or its base, depending on *level*.
        The module isn't immediately imported. Instead, an instance of
        *lazy_mod_class* is returned. Upon access to any of its attributes, the
        module is finally loaded.

    Examples
    --------
    >>> import lazy_import, sys
    >>> np = lazy_import.lazy_module("numpy")
    >>> np
    Lazily-loaded module numpy
    >>> np is sys.modules['numpy']
    True
    >>> np.pi # This causes the full loading of the module ...
    3.141592653589793
    >>> np # ... and the module is changed in place.
    <module 'numpy' from '/usr/local/lib/python/site-packages/numpy/__init__.py'>

    >>> import lazy_import, sys
    >>> # The following succeeds even when asking for a module that's not available
    >>> missing = lazy_import.lazy_module("missing_module")
    >>> missing
    Lazily-loaded module missing_module
    >>> missing is sys.modules['missing_module']
    True
    >>> missing.some_attr # This causes the full loading of the module, which now fails.
    ImportError: __main__ attempted to use a functionality that requires module missing_module, but it couldn't be loaded. Please install missing_module and retry.

    See Also
    --------
    :func:`lazy_callable`
    :class:`LazyModule`
    """
    if error_strings is None:
        error_strings = {}
    # Fill in 'caller'/'install_name'/'msg' defaults (mutates the dict).
    _set_default_errornames(modname, error_strings)

    mod = _lazy_module(modname, error_strings, lazy_mod_class)
    if level == "base":
        # _lazy_module registered the whole parent chain in sys.modules.
        return sys.modules[module_basename(modname)]
    elif level == "leaf":
        return mod
    else:
        raise ValueError("Parameter 'level' must be one of ('base', 'leaf')")
|
||||
|
||||
|
||||
def _lazy_module(modname, error_strings, lazy_mod_class):
    """Create/register lazy modules for *modname* and all its parents.

    Returns the sys.modules entry for the full *modname*. Already-loaded
    (parent) modules are left untouched.
    """
    with _ImportLockContext():
        fullmodname = modname
        fullsubmodname = None
        # ensure parent module/package is in sys.modules
        # and parent.modname=module, as soon as the parent is imported
        while modname:
            try:
                mod = sys.modules[modname]
                # We reached a (base) module that's already loaded. Let's stop
                # the cycle. Can't use 'break' because we still want to go
                # through the fullsubmodname check below.
                modname = ""
            except KeyError:
                err_s = error_strings.copy()
                err_s.setdefault("module", modname)

                # One throwaway subclass per module: the lazy state lives in
                # class attributes, so each module needs its own class.
                class _LazyModule(lazy_mod_class):
                    _lazy_import_error_msgs = {"msg": err_s.pop("msg")}
                    try:
                        _lazy_import_error_msgs["msg_callable"] = err_s.pop(
                            "msg_callable"
                        )
                    except KeyError:
                        pass
                    _lazy_import_error_strings = err_s
                    _lazy_import_callables = {}
                    _lazy_import_submodules = {}

                    def __repr__(self):
                        return "Lazily-loaded module {}".format(self.__name__)

                # A bit of cosmetic, to make AttributeErrors read more natural
                _LazyModule.__name__ = "module"
                # Actual module instantiation
                mod = sys.modules[modname] = _LazyModule(modname)
                # No need for __spec__. Maybe in the future.
                if ModuleSpec:
                    ModuleType.__setattr__(mod, "__spec__", ModuleSpec(modname, None))
            if fullsubmodname:
                # Link the child we processed on the previous iteration onto
                # this (parent) module, bypassing the lazy __setattr__.
                # NOTE(review): _LazyModule here is the class from the last
                # `except KeyError` iteration — presumably always this mod's
                # class when this branch runs; confirm against upstream.
                submod = sys.modules[fullsubmodname]
                ModuleType.__setattr__(mod, submodname, submod)
                _LazyModule._lazy_import_submodules[submodname] = submod
            fullsubmodname = modname
            # Walk one level up the dotted path (e.g. 'a.b.c' -> 'a.b').
            modname, _, submodname = modname.rpartition(".")
        return sys.modules[fullmodname]
|
||||
|
||||
|
||||
def lazy_callable(modname, *names, **kwargs):
    """Performs lazy importing of one or more callables.

    :func:`lazy_callable` creates functions that are thin wrappers that pass
    any and all arguments straight to the target module's callables. These can
    be functions or classes. The full loading of that module is only actually
    triggered when the returned lazy function itself is called. This lazy
    import of the target module uses the same mechanism as
    :func:`lazy_module`.

    If, however, the target module has already been fully imported prior
    to invocation of :func:`lazy_callable`, then the target callables
    themselves are returned and no lazy imports are made.

    :func:`lazy_function` and :func:`lazy_function` are aliases of
    :func:`lazy_callable`.

    Parameters
    ----------
    modname : str
        The base module from where to import the callable(s) in *names*,
        or a full 'module_name.callable_name' string.
    names : str (optional)
        The callable name(s) to import from the module specified by *modname*.
        If left empty, *modname* is assumed to also include the callable name
        to import.
    error_strings : dict, optional
        A dictionary of strings to use when reporting loading errors (either a
        missing module, or a missing callable name in the loaded module).
        *error_string* follows the same usage as described under
        :func:`lazy_module`, with the exceptions that 1) a further key,
        'msg_callable', can be supplied to be used as the error when a module
        is successfully loaded but the target callable can't be found therein
        (defaulting to :attr:`lazy_import._MSG_CALLABLE`); 2) a key 'callable'
        is always added with the callable name being loaded.
    lazy_mod_class : type, optional
        See definition under :func:`lazy_module`.
    lazy_call_class : type, optional
        Analogously to *lazy_mod_class*, allows setting a custom class to
        handle lazy callables, other than the default :class:`LazyCallable`.

    Returns
    -------
    wrapper function or tuple of wrapper functions
        If *names* is passed, returns a tuple of wrapper functions, one for
        each element in *names*.
        If only *modname* is passed it is assumed to be a full
        'module_name.callable_name' string, in which case the wrapper for the
        imported callable is returned directly, and not in a tuple.

    Notes
    -----
    Unlike :func:`lazy_module`, which returns a lazy module that eventually
    mutates into the fully-functional version, :func:`lazy_callable` only
    returns thin wrappers that never change. This means that the returned
    wrapper object never truly becomes the one under the module's namespace,
    even after successful loading of the module in *modname*. This is fine for
    most practical use cases, but may break code that relies on the usage of
    the returned objects oter than calling them. One such example is the lazy
    import of a class: it's fine to use the returned wrapper to instantiate an
    object, but it can't be used, for instance, to subclass from.

    Examples
    --------
    >>> import lazy_import, sys
    >>> fn = lazy_import.lazy_callable("numpy.arange")
    >>> sys.modules['numpy']
    Lazily-loaded module numpy
    >>> fn(10)
    array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
    >>> sys.modules['numpy']
    <module 'numpy' from '/usr/local/lib/python3.5/site-packages/numpy/__init__.py'>

    >>> import lazy_import, sys
    >>> cl = lazy_import.lazy_callable("numpy.ndarray") # a class
    >>> obj = cl([1, 2]) # This works OK (and also triggers the loading of numpy)
    >>> class MySubclass(cl): # This fails because cls is just a wrapper,
    >>>     pass              # not an actual class.

    See Also
    --------
    :func:`lazy_module`
    :class:`LazyCallable`
    :class:`LazyModule`
    """
    if not names:
        # Single-string form: split 'module.callable' into its two parts.
        modname, _, name = modname.rpartition(".")
    # kwargs handled via _setdef so an explicit None also gets the default.
    lazy_mod_class = _setdef(kwargs, "lazy_mod_class", LazyModule)
    lazy_call_class = _setdef(kwargs, "lazy_call_class", LazyCallable)
    error_strings = _setdef(kwargs, "error_strings", {})
    _set_default_errornames(modname, error_strings, call=True)

    if not names:
        # We allow passing a single string as 'modname.callable_name',
        # in which case the wrapper is returned directly and not as a list.
        return _lazy_callable(
            modname, name, error_strings.copy(), lazy_mod_class, lazy_call_class
        )
    return tuple(
        _lazy_callable(
            modname, cname, error_strings.copy(), lazy_mod_class, lazy_call_class
        )
        for cname in names
    )
|
||||
|
||||
|
||||
# Convenience aliases: lazily importing a function or a class is the same
# operation under the hood.
lazy_function = lazy_class = lazy_callable
|
||||
|
||||
|
||||
def _lazy_callable(modname, cname, error_strings, lazy_mod_class, lazy_call_class):
    """Register *cname* as a lazy callable on module *modname* and return
    whatever attribute access now yields.

    If the module was already fully imported, no wrapper is installed and the
    real callable comes back from the getattr.
    """
    # This could live in LazyCallable.__init__, but doing it here lets us
    # pre-check whether we actually need to be lazy at all.
    target_mod = _lazy_module(modname, error_strings, lazy_mod_class)
    mod_cls = type(target_mod)
    still_lazy = issubclass(mod_cls, LazyModule) and hasattr(
        mod_cls, "_lazy_import_callables"
    )
    if still_lazy:
        registry = mod_cls._lazy_import_callables
        registry.setdefault(cname, lazy_call_class(target_mod, cname))
    return getattr(target_mod, cname)
|
||||
|
||||
|
||||
#######################
|
||||
# Real module loading #
|
||||
#######################
|
||||
|
||||
|
||||
def _load_module(module):
    """Ensures that a module, and its parents, are properly loaded"""
    modclass = type(module)
    # We only take care of our own LazyModule instances
    if not issubclass(modclass, LazyModule):
        raise TypeError("Passed module is not a LazyModule instance.")
    with _ImportLockContext():
        parent, _, modname = module.__name__.rpartition(".")
        logger.debug("loading module {}".format(modname))
        # We first identify whether this is a loadable LazyModule, then we
        # strip as much of lazy_import behavior as possible (keeping it cached,
        # in case loading fails and we need to reset the lazy state).
        if not hasattr(modclass, "_lazy_import_error_msgs"):
            # Alreay loaded (no _lazy_import_error_msgs attr). Not reloading.
            return
        # First, ensure the parent is loaded (using recursion; *very* unlikely
        # we'll ever hit a stack limit in this case).
        modclass._LOADING = True
        try:
            if parent:
                logger.debug("first loading parent module {}".format(parent))
                # Attribute set on a lazy parent triggers its full load, which
                # may in turn load us (recursively) and clear _LOADING.
                setattr(sys.modules[parent], modname, module)
            if not hasattr(modclass, "_LOADING"):
                logger.debug("Module {} already loaded by the parent".format(modname))
                # We've been loaded by the parent. Let's bail.
                return
            cached_data = _clean_lazymodule(module)
            try:
                # Get Python to do the real import!
                reload_module(module)
            except:
                # Loading failed. We reset our lazy state (bare except is
                # deliberate: the original exception is always re-raised).
                logger.debug("Failed to load module {}. Resetting...".format(modname))
                _reset_lazymodule(module, cached_data)
                raise
            else:
                # Successful load
                logger.debug("Successfully loaded module {}".format(modname))
                delattr(modclass, "_LOADING")
                _reset_lazy_submod_refs(module)

        except (AttributeError, ImportError) as err:
            logger.debug(
                "Failed to load {}.\n{}: {}".format(
                    modname, err.__class__.__name__, err
                )
            )
            logger.lazy_trace()
            # Under Python 3 reloading our dummy LazyModule instances causes an
            # AttributeError if the module can't be found. Would be preferrable
            # if we could always rely on an ImportError. As it is we vet the
            # AttributeError as thoroughly as possible.
            if (six.PY3 and isinstance(err, AttributeError)) and not err.args[
                0
            ] == "'NoneType' object has no attribute 'name'":
                # Not the AttributeError we were looking for.
                raise
            # Re-raise as the user-friendly templated ImportError.
            msg = modclass._lazy_import_error_msgs["msg"]
            raise_from(
                ImportError(msg.format(**modclass._lazy_import_error_strings)), None
            )
|
||||
|
||||
|
||||
##############################
|
||||
# Helper functions/constants #
|
||||
##############################
|
||||
|
||||
# Default error template used when a lazily-imported module fails to load.
_MSG = (
    "{caller} attempted to use a functionality that requires module "
    "{module}, but it couldn't be loaded. Please install {install_name} "
    "and retry."
)

# Default error template used when the module loads but the requested
# callable is missing from it.
_MSG_CALLABLE = (
    "{caller} attempted to use a functionality that requires "
    "{callable}, of module {module}, but it couldn't be found in that "
    "module. Please install a version of {install_name} that has "
    "{module}.{callable} and retry."
)

# Class attributes that make up a module's lazy state; stripped before the
# real import and restored if it fails (see _clean_lazymodule).
_CLS_ATTRS = (
    "_lazy_import_error_strings",
    "_lazy_import_error_msgs",
    "_lazy_import_callables",
    "_lazy_import_submodules",
    "__repr__",
)

# "Deletion dictionaries": class attrs whose keys name per-module attributes
# that must also be removed/restored around loading.
_DELETION_DICT = ("_lazy_import_submodules",)
|
||||
|
||||
|
||||
def _setdef(argdict, name, defaultvalue):
|
||||
"""Like dict.setdefault but sets the default value also if None is present."""
|
||||
if not name in argdict or argdict[name] is None:
|
||||
argdict[name] = defaultvalue
|
||||
return argdict[name]
|
||||
|
||||
|
||||
def module_basename(modname):
    """Return the top-level package name of a dotted module path."""
    # Split at most once: only the leading segment is needed.
    return modname.split(".", 1)[0]
|
||||
|
||||
|
||||
def _set_default_errornames(modname, error_strings, call=False):
    """Fill *error_strings* in place with defaults for any missing keys.

    We don't set the 'module' default here because it will change for
    parents of lazily imported submodules.
    """
    defaults = {
        # depth=3 points at the user code that invoked lazy_module/callable.
        "caller": _caller_name(3, default="Python"),
        "install_name": module_basename(modname),
        "msg": _MSG,
    }
    if call:
        defaults["msg_callable"] = _MSG_CALLABLE
    for key, value in defaults.items():
        error_strings.setdefault(key, value)
|
||||
|
||||
|
||||
def _caller_name(depth=2, default=""):
|
||||
"""Returns the name of the calling namespace."""
|
||||
# the presence of sys._getframe might be implementation-dependent.
|
||||
# It isn't that serious if we can't get the caller's name.
|
||||
try:
|
||||
return sys._getframe(depth).f_globals["__name__"]
|
||||
except AttributeError:
|
||||
return default
|
||||
|
||||
|
||||
def _clean_lazymodule(module):
    """Removes all lazy behavior from a module's class, for loading.

    Also removes all module attributes listed under the module's class deletion
    dictionaries. Deletion dictionaries are class attributes with names
    specified in `_DELETION_DICT`.

    Parameters
    ----------
    module: LazyModule

    Returns
    -------
    dict
        A dictionary of deleted class attributes, that can be used to reset the
        lazy state using :func:`_reset_lazymodule`.
    """
    modclass = type(module)
    _clean_lazy_submod_refs(module)

    # Restore plain ModuleType attribute semantics so the real import can
    # populate the module without recursing into the lazy hooks.
    modclass.__getattribute__ = ModuleType.__getattribute__
    modclass.__setattr__ = ModuleType.__setattr__
    cls_attrs = {}
    for cls_attr in _CLS_ATTRS:
        try:
            cls_attrs[cls_attr] = getattr(modclass, cls_attr)
            delattr(modclass, cls_attr)
        except AttributeError:
            # Attribute already absent; nothing to cache or delete.
            pass
    return cls_attrs
|
||||
|
||||
|
||||
def _clean_lazy_submod_refs(module):
    """Remove from *module* every attribute named by its class's deletion
    dictionaries (see `_DELETION_DICT`)."""
    cls = type(module)
    _missing = object()
    for dict_name in _DELETION_DICT:
        names = getattr(cls, dict_name, _missing)
        if names is _missing:
            continue
        for name in names:
            try:
                # Go through ModuleType so we bypass the lazy machinery.
                super(LazyModule, module).__delattr__(name)
            except AttributeError:
                # Attribute never materialized. Maybe raise a warning?
                pass
|
||||
|
||||
|
||||
def _reset_lazymodule(module, cls_attrs):
    """Resets a module's lazy state from cached data.

    Inverse of :func:`_clean_lazymodule`: re-exposes the lazy attribute hooks
    and re-installs the cached class attributes.
    """
    cls = type(module)
    # Deleting the overrides uncovers LazyModule's own lazy dunder hooks.
    del cls.__getattribute__
    del cls.__setattr__
    if hasattr(cls, "_LOADING"):
        del cls._LOADING
    for attr_name in _CLS_ATTRS:
        if attr_name in cls_attrs:
            setattr(cls, attr_name, cls_attrs[attr_name])
    _reset_lazy_submod_refs(module)
|
||||
|
||||
|
||||
def _reset_lazy_submod_refs(module):
    """Re-attach cached submodule references directly onto *module*."""
    cls = type(module)
    _missing = object()
    for dict_name in _DELETION_DICT:
        submods = getattr(cls, dict_name, _missing)
        if submods is _missing:
            continue
        # Write through ModuleType to avoid triggering a full load.
        for attr_name, submodule in submods.items():
            super(LazyModule, module).__setattr__(attr_name, submodule)
|
||||
|
||||
|
||||
def run_from_ipython():
    """Return True when executing inside an IPython session.

    Taken from https://stackoverflow.com/questions/5376837
    """
    # IPython injects __IPYTHON__ into builtins; anywhere else the bare name
    # lookup raises NameError.
    try:
        __IPYTHON__  # noqa: F821
    except NameError:
        return False
    return True
|
||||
46
src/plume/utils/lazy_loader.py
Normal file
46
src/plume/utils/lazy_loader.py
Normal file
@@ -0,0 +1,46 @@
|
||||
# Code copied from https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/util/lazy_loader.py
|
||||
"""A LazyLoader class."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import importlib
|
||||
import types
|
||||
|
||||
|
||||
class LazyLoader(types.ModuleType):
    """Lazily import a module, mainly to avoid pulling in large dependencies.

    `contrib`, and `ffmpeg` are examples of modules that are large and not
    always needed, and this allows them to only be loaded when they are used.
    """

    # The lint error here is incorrect.
    def __init__(
        self, local_name, parent_module_globals, name
    ):  # pylint: disable=super-on-old-class
        super(LazyLoader, self).__init__(name)
        # Where (and under which name) the real module gets installed once
        # it is finally imported.
        self._local_name = local_name
        self._parent_module_globals = parent_module_globals

    def _load(self):
        """Import the real module and splice it into the parent namespace."""
        real_module = importlib.import_module(self.__name__)
        self._parent_module_globals[self._local_name] = real_module

        # Copy the module's dict so that anyone keeping a reference to this
        # LazyLoader gets efficient lookups afterwards (__getattr__ only
        # fires on failed lookups).
        self.__dict__.update(real_module.__dict__)

        return real_module

    def __getattr__(self, item):
        return getattr(self._load(), item)

    def __dir__(self):
        return dir(self._load())
|
||||
68
src/plume/utils/manifest.py
Normal file
68
src/plume/utils/manifest.py
Normal file
@@ -0,0 +1,68 @@
|
||||
from pathlib import Path
|
||||
|
||||
# from tqdm import tqdm
|
||||
import json
|
||||
|
||||
# from .extended_path import ExtendedPath
|
||||
# from .parallel import parallel_apply
|
||||
# from .encrypt import wav_cryptor, text_cryptor
|
||||
|
||||
|
||||
def manifest_str(path, dur, text):
    """Serialize one ASR manifest record as a JSON line (with trailing newline).

    The duration is rounded to one decimal place, matching the manifest format.
    """
    record = {
        "audio_filepath": path,
        "duration": round(dur, 1),
        "text": text,
    }
    return "{}\n".format(json.dumps(record))
|
||||
|
||||
|
||||
def asr_manifest_writer(
    asr_manifest_path: Path, manifest_str_source, verbose=False
):
    """Stream ASR manifest records into *asr_manifest_path* as JSON lines.

    Each item of *manifest_str_source* must be a mapping carrying the keys
    'audio_filepath', 'duration' and 'text'.
    """
    with asr_manifest_path.open("w") as manifest_file:
        if verbose:
            print(f"writing asr manifest to {asr_manifest_path}")
        for record in manifest_str_source:
            # manifest_str handles duration rounding and the trailing newline.
            serialized = manifest_str(
                record["audio_filepath"],
                record["duration"],
                record["text"],
            )
            manifest_file.write(serialized)
|
||||
|
||||
|
||||
#
|
||||
# def decrypt(
|
||||
# src_dataset_dir: Path,
|
||||
# dest_dataset_dir: Path,
|
||||
# encryption_key: str,
|
||||
# verbose=True,
|
||||
# parallel=True,
|
||||
# ):
|
||||
# data_manifest_path = src_dataset_dir / "manifest.json"
|
||||
# (dest_dataset_dir / "wavs").mkdir(exist_ok=True, parents=True)
|
||||
# dest_manifest_path = dest_dataset_dir / "manifest.json"
|
||||
# print(f"reading encrypted manifest from {data_manifest_path}")
|
||||
# asr_data = list(ExtendedPath(data_manifest_path).read_jsonl())
|
||||
# enc_key_bytes = encryption_key.encode("utf-8")
|
||||
# wc = wav_cryptor(enc_key_bytes)
|
||||
# tc = text_cryptor(enc_key_bytes)
|
||||
#
|
||||
# def decrypt_fn(p):
|
||||
# dest_path = dest_dataset_dir / Path(p["audio_filepath"])
|
||||
# wc.decrypt_wav_path_to(
|
||||
# src_dataset_dir / Path(p["audio_filepath"]), dest_path
|
||||
# )
|
||||
# d = {
|
||||
# "audio_filepath": dest_path,
|
||||
# "duration": p["duration"],
|
||||
# "text": tc.decrypt_text(p["text"].encode("utf-8")),
|
||||
# }
|
||||
# return d
|
||||
#
|
||||
# def datagen():
|
||||
# if parallel:
|
||||
# for d in parallel_apply(decrypt_fn, asr_data, verbose=verbose):
|
||||
# yield d
|
||||
# else:
|
||||
# for p in tqdm.tqdm(asr_data) if verbose else asr_data:
|
||||
# yield decrypt_fn(d)
|
||||
#
|
||||
# asr_manifest_writer(dest_manifest_path, datagen)
|
||||
41
src/plume/utils/parallel.py
Normal file
41
src/plume/utils/parallel.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def parallel_apply(fn, iterable, workers=8, pool="thread", verbose=True):
    """Apply *fn* to every element of *iterable* with a worker pool.

    Parameters
    ----------
    fn : callable
        Applied to each element. Note: ``fn(iterable[0])`` is invoked once
        extra as a warm-up before the pool starts, and its result discarded.
    iterable : sequence
        Must support indexing and ``len()`` (not a bare generator).
    workers : int
        Pool size.
    pool : str
        Either "thread" or "process".
    verbose : bool
        When True, print a message and show a tqdm progress bar.

    Returns
    -------
    list
        Results of *fn* over *iterable*, in input order.

    Raises
    ------
    Exception
        If *pool* is neither "thread" nor "process".
    """
    # Guard: the warm-up call below would raise IndexError on empty input.
    if not iterable:
        return []
    # warm-up
    fn(iterable[0])
    if pool == "thread":
        with ThreadPoolExecutor(max_workers=workers) as exe:
            if verbose:
                print(f"parallelly applying {fn}")
                return list(
                    tqdm(
                        exe.map(fn, iterable),
                        position=0,
                        leave=True,
                        total=len(iterable),
                    )
                )
            else:
                return list(exe.map(fn, iterable))
    elif pool == "process":
        with ProcessPoolExecutor(max_workers=workers) as exe:
            if verbose:
                print(f"parallelly applying {fn}")
                with tqdm(total=len(iterable)) as progress:
                    futures = []
                    for item in iterable:
                        future = exe.submit(fn, item)
                        # Tick the bar as each future completes.
                        future.add_done_callback(lambda p: progress.update())
                        futures.append(future)
                    # BUGFIX: previously returned only the *last* future's
                    # result ("return result") instead of the full list.
                    return [future.result() for future in futures]
            else:
                return list(exe.map(fn, iterable))
    else:
        raise Exception(f"unsupported pool type - {pool}")
|
||||
383
src/plume/utils/regentity.py
Normal file
383
src/plume/utils/regentity.py
Normal file
@@ -0,0 +1,383 @@
|
||||
import re
|
||||
|
||||
from .lazy_import import lazy_callable, lazy_module
|
||||
|
||||
num2words = lazy_callable("num2words.num2words")
|
||||
spellchecker = lazy_module("spellchecker")
|
||||
# from num2words import num2words
|
||||
|
||||
|
||||
def entity_replacer_keeper(
    pre_rules=[], entity_rules=[], post_rules=[], verbose=False
):
    """Compile regex rule lists and return (replacer, keep_literals) closures.

    ``replacer`` applies pre, entity, and post rules in order as plain
    substitutions. ``keep_literals`` instead locates entity matches, merges
    their overlapping spans, and keeps only the matched text.

    Args:
        pre_rules: (pattern, repl) pairs applied first, case-sensitive.
        entity_rules: (pattern, repl) pairs compiled case-insensitive.
        post_rules: (pattern, repl) pairs applied last, case-sensitive.
        verbose: print each rule application for debugging.

    NOTE(review): the mutable list defaults are shared across calls; safe
    only because they are never mutated here.
    """
    # def replacer_keeper_gen():
    pre_rules_c = [(re.compile(k), v) for (k, v) in pre_rules]
    # entity rules match regardless of case (spoken-transcript tokens)
    entity_rules_c = [
        (re.compile(k, re.IGNORECASE), v) for (k, v) in entity_rules
    ]
    post_rules_c = [(re.compile(k), v) for (k, v) in post_rules]

    re_rules = pre_rules_c + entity_rules_c + post_rules_c

    def replacer(w2v_out):
        """Apply every compiled rule, in order, to the input string."""
        out = w2v_out
        for (k, v) in re_rules:
            orig = out
            out = k.sub(v, out)
            if verbose:
                print(f"rule |{k}|: sub:|{v}| |{orig}|=> |{out}|")
        return out

    def merge_intervals(intervals):
        """Merge overlapping (start, end) spans into disjoint intervals."""
        # https://codereview.stackexchange.com/a/69249
        sorted_by_lower_bound = sorted(intervals, key=lambda tup: tup[0])
        merged = []

        for higher in sorted_by_lower_bound:
            if not merged:
                merged.append(higher)
            else:
                lower = merged[-1]
                # test for intersection between lower and higher:
                # we know via sorting that lower[0] <= higher[0]
                if higher[0] <= lower[1]:
                    upper_bound = max(lower[1], higher[1])
                    merged[-1] = (
                        lower[0],
                        upper_bound,
                    )  # replace by merged interval
                else:
                    merged.append(higher)
        return merged

    # optimal merging could use an interval tree instead:
    # https://www.geeksforgeeks.org/interval-tree/

    def keep_literals(w2v_out):
        """Return (kept_text, entity_count): only entity-matched spans survive.

        Pre-rules are applied as substitutions; entity rules are used only to
        locate spans (each whitespace-separated token inside a match becomes
        its own span); post-rules clean up the joined result.
        """
        # out = re.sub(r"[ ;,.]", " ", w2v_out).strip()
        out = w2v_out
        for (k, v) in pre_rules_c:
            out = k.sub(v, out)
        num_spans = []
        if verbose:
            print(f"num_rules: {len(entity_rules_c)}")
        for (k, v) in entity_rules_c:  # [94:]:
            matches = k.finditer(out)
            for m in matches:
                # num_spans.append(m.span())
                # look at space separated internal entities
                (start, end) = m.span()
                for s in re.finditer(r"\S+", out[start:end]):
                    (start_e, end_e) = s.span()
                    num_spans.append((start_e + start, end_e + start))
                    if verbose:
                        t = out[start_e + start : end_e + start]
                        print(f"rule |{k}|: sub:|{v}| => |{t}|")

        merged = merge_intervals(num_spans)
        num_ents = len(merged)
        keep_out = " ".join((out[s[0] : s[1]] for s in merged))
        for (k, v) in post_rules_c:
            keep_out = k.sub(v, keep_out)
        return keep_out, num_ents

    return replacer, keep_literals
|
||||
|
||||
|
||||
def default_num_only_rules(num_range):
    """Regex rules mapping spelled-out numbers (and bare digits) to digits.

    Word rules are emitted for ``num_range - 1`` down to 0 so that longer
    number words (e.g. "ninety nine") win over their prefixes; "hundred"
    maps to the "00" suffix.
    """
    word_rules = [
        (r"\b" + num2words(n) + r"\b", str(n))
        for n in reversed(range(num_range))
    ]
    # identity rules so already-digit tokens count as numeric entities
    digit_rules = [
        (r"\b" + str(n) + r"\b", str(n)) for n in reversed(range(10))
    ]
    return word_rules + digit_rules + [(r"\bhundred\b", "00")]
|
||||
|
||||
|
||||
def default_num_rules(num_range):
    """Numeric rules plus oh/o-as-zero and double/triple token expansion."""
    spoken_extras = [
        (r"\boh\b", "0"),
        (r"\bo\b", "0"),
        (r"\bdouble(?: |-)(\w+|\d+)\b", "\\1 \\1"),
        (r"\btriple(?: |-)(\w+|\d+)\b", "\\1 \\1 \\1"),
    ]
    return default_num_only_rules(num_range) + spoken_extras
|
||||
|
||||
|
||||
def infer_num_rules_vocab(num_range):
    """Return (entity_rules, vocab) for spell-corrected numeric inference.

    Unlike ``default_num_rules`` the number-word patterns here carry no
    word-boundary anchors; vocab lists the tokens a spell-corrector should
    accept.
    """
    number_words = [num2words(n) for n in reversed(range(num_range))]
    vocab = number_words + ["hundred", "double", "triple"]
    rules = [
        (num2words(n), str(n)) for n in reversed(range(num_range))
    ]
    rules += [
        (r"\bhundred\b", "00"),
        (r"\boh\b", "0"),
        (r"\bo\b", "0"),
        (r"\bdouble(?: |-)(\w+|\d+)\b", "\\1 \\1"),
        (r"\btriple(?: |-)(\w+|\d+)\b", "\\1 \\1 \\1"),
    ]
    return rules, vocab
|
||||
|
||||
|
||||
def do_tri_verbose_list():
    """Number words for the teens (11-18) and round tens (20-90), plus "hundred"."""
    words = [num2words(v) for v in range(11, 19)]
    words += [num2words(v) for v in range(20, 100, 10)]
    return words + ["hundred"]
|
||||
|
||||
|
||||
def default_alnum_rules(num_range, oh_is_zero, i_oh_limit):
    """Build entity rules for alphanumeric (letters + digits) transcripts.

    Args:
        num_range: upper bound for spelled-out number rules.
        oh_is_zero: when True, "oh"/"o" map to "0"; otherwise "oh" maps to
            the letter "o".
        i_oh_limit: when True, standalone "I"/"O" tokens are only accepted
            when bracketed by other alphanumeric entities (they are wrapped
            in [..] markers for later cleanup); otherwise any single letter
            is accepted as an entity.

    Rule order matters: number words, then oh/zero handling, then
    double/triple expansion, then the single-letter rules.
    """
    oh_is_zero_rules = [
        (r"\boh\b", "0"),
        (r"\bo\b", "0"),
    ]

    # alternation of all known number words plus any single alnum char
    num_list = [num2words(i) for i in reversed(range(num_range))]
    al_num_regex = r"|".join(num_list) + r"|[0-9a-z]"
    o_i_vars = r"(\[?(?:Oh|O|I)\]?)"
    i_oh_limit_rules = [
        # every single letter except i/o is an entity unconditionally
        (r"\b([a-hj-np-z])\b", "\\1"),
        # I/O only count when sandwiched between alnum entities; bracket them
        (
            r"\b((?:"
            + al_num_regex
            + r"|^)\b\s*)(I|O)(\s*\b)(?="
            + al_num_regex
            + r"\s+|$)\b",
            "\\1[\\2]\\3",
        ),
        # (
        #     r"\b" + o_i_vars + r"(\s+)" + o_i_vars + r"\b",
        #     "[\\1]\\2[\\3]",
        # ),
        # adjacent I/O pairs: bracket both, in either bracketing order
        (
            r"(\s+|^)" + o_i_vars + r"(\s+)\[?" + o_i_vars + r"\]?(\s+|$)",
            "\\1[\\2]\\3[\\4]\\5",
        ),
        (
            r"(\s+|^)\[?" + o_i_vars + r"\]?(\s+)" + o_i_vars + r"(\s+|$)",
            "\\1[\\2]\\3[\\4]\\5",
        ),
    ]
    entity_rules = (
        default_num_only_rules(num_range)
        + (oh_is_zero_rules if oh_is_zero else [(r"\boh\b", "o")])
        + [
            (r"\bdouble(?: |-)(\w+|\d+)\b", "\\1 \\1"),
            (r"\btriple(?: |-)(\w+|\d+)\b", "\\1 \\1 \\1"),
            # (r"\b([a-zA-Z])\b", "\\1"),
        ]
        + (i_oh_limit_rules if i_oh_limit else [(r"\b([a-zA-Z])\b", "\\1")])
    )
    return entity_rules
|
||||
|
||||
|
||||
def num_replacer(num_range=100, condense=True):
    """Return a fn rewriting spoken numbers as digits.

    With ``condense`` every non-digit character is stripped from the result.
    """
    post = [(r"[^0-9]", "")] if condense else []
    replace_fn, _ = entity_replacer_keeper(
        entity_rules=default_num_rules(num_range), post_rules=post
    )
    return replace_fn
|
||||
|
||||
|
||||
def num_keeper(num_range=100):
    """Return a fn keeping only the numeric-entity spans of a transcript."""
    _, keep_fn = entity_replacer_keeper(
        pre_rules=[(r"[ ;,.]", " ")],
        entity_rules=default_num_rules(num_range),
        post_rules=[],
    )
    return keep_fn
|
||||
|
||||
|
||||
def alnum_replacer(
    num_range=100, oh_is_zero=False, i_oh_limit=True, condense=True
):
    """Return a fn normalizing an alphanumeric transcript into code characters.

    With ``condense`` the output is stripped to uppercase A-Z/0-9 only.
    """

    def upper_case(match_obj):
        # uppercase the matched tail so the final code is all caps
        return match_obj.group(0).upper()

    post_rules = []
    if condense:
        if i_oh_limit:
            # drop stray standalone i/o tokens, then unbracket kept [X] marks
            post_rules.append((r"(\s|^)(?:o|O|I|i)(\s|$)", "\\1\\2"))
            post_rules.append((r"\[(\w)\]", "\\1"))
        post_rules += [
            # (r"\b[a-zA-Z]+\'[a-zA-Z]+\b", ""),
            (r"\b[a-zA-Z]{2,}\b", ""),
            (r"[^a-zA-Z0-9]", ""),
            (r"([a-z].*)", upper_case),
        ]

    replace_fn, _ = entity_replacer_keeper(
        pre_rules=[(r"[ ;,.]", " "), (r"[']", "")],
        entity_rules=default_alnum_rules(
            num_range, oh_is_zero, i_oh_limit=i_oh_limit
        ),
        post_rules=post_rules,
    )
    return replace_fn
|
||||
|
||||
|
||||
def alnum_keeper(num_range=100, oh_is_zero=False):
    """Return a fn keeping only the alphanumeric-entity spans of a transcript."""
    _, keep_fn = entity_replacer_keeper(
        pre_rules=[(r"[ ;,.]", " "), (r"[']", "")],
        entity_rules=default_alnum_rules(
            num_range, oh_is_zero, i_oh_limit=True
        ),
        post_rules=[],
    )
    return keep_fn
|
||||
|
||||
|
||||
def num_keeper_orig(num_range=10, extra_rules=[]):
    """Legacy variant: return a fn keeping only numeric spans of a transcript.

    Kept for comparison with the ``entity_replacer_keeper``-based pipeline.
    Unlike the newer keeper, spans are whole rule matches (not split per
    token) and no post-rules run on the joined output.

    NOTE(review): the mutable ``extra_rules=[]`` default is shared across
    calls; safe only because it is never mutated here.
    """
    num_int_map_ty = [
        (
            r"\b" + num2words(i) + r"\b",
            " " + str(i) + " ",
        )
        for i in reversed(range(num_range))
    ]
    # spoken-number rules are case-insensitive; caller-supplied extras are not
    re_rules = [
        (re.compile(k, re.IGNORECASE), v)
        for (k, v) in [
            # (r"[ ;,.]", " "),
            (r"\bdouble(?: |-)(\w+)\b", "\\1 \\1"),
            (r"\btriple(?: |-)(\w+)\b", "\\1 \\1 \\1"),
            (r"hundred", "00"),
            (r"\boh\b", " 0 "),
            (r"\bo\b", " 0 "),
        ]
        + num_int_map_ty
    ] + [(re.compile(k), v) for (k, v) in extra_rules]

    def merge_intervals(intervals):
        """Merge overlapping (start, end) spans into disjoint intervals."""
        # https://codereview.stackexchange.com/a/69249
        sorted_by_lower_bound = sorted(intervals, key=lambda tup: tup[0])
        merged = []

        for higher in sorted_by_lower_bound:
            if not merged:
                merged.append(higher)
            else:
                lower = merged[-1]
                # test for intersection between lower and higher:
                # we know via sorting that lower[0] <= higher[0]
                if higher[0] <= lower[1]:
                    upper_bound = max(lower[1], higher[1])
                    merged[-1] = (
                        lower[0],
                        upper_bound,
                    )  # replace by merged interval
                else:
                    merged.append(higher)
        return merged

    # merging interval tree for optimal # https://www.geeksforgeeks.org/interval-tree/

    def keep_numeric_literals(w2v_out):
        """Return (kept_text, entity_count) of numeric rule matches."""
        # out = w2v_out.lower()
        out = re.sub(r"[ ;,.]", " ", w2v_out).strip()
        # out = " " + out.strip() + " "
        # out = re.sub(r"double (\w+)", "\\1 \\1", out)
        # out = re.sub(r"triple (\w+)", "\\1 \\1 \\1", out)
        num_spans = []
        for (k, v) in re_rules:  # [94:]:
            matches = k.finditer(out)
            for m in matches:
                # num_spans.append((k, m.span()))
                num_spans.append(m.span())
            # out = re.sub(k, v, out)
        merged = merge_intervals(num_spans)
        num_ents = len(merged)
        keep_out = " ".join((out[s[0] : s[1]] for s in merged))
        return keep_out, num_ents

    return keep_numeric_literals
|
||||
|
||||
|
||||
def infer_num_replacer(num_range=100, condense=True):
    """Spell-correct toward the numeric vocabulary, then replace numbers with digits."""
    rules, vocab = infer_num_rules_vocab(num_range)
    correct = vocab_corrector_gen(vocab)
    replace_fn, _ = entity_replacer_keeper(
        entity_rules=rules,
        post_rules=[(r"[^0-9]", "")] if condense else [],
    )

    def pipeline(text):
        return replace_fn(correct(text))

    return pipeline
|
||||
|
||||
|
||||
def vocab_corrector_gen(vocab):
    """Return a fn spell-correcting each token toward ``vocab`` (edit distance 1)."""
    spell = spellchecker.SpellChecker(distance=1)
    # restrict the checker's dictionary to exactly the supplied vocabulary
    extraneous = set(spell.word_frequency.words()) - set(vocab)
    spell.word_frequency.remove_words(extraneous)

    def corrector(inp):
        tokens = spell.split_words(inp)
        return " ".join([spell.correction(tok) for tok in tokens])

    return corrector
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Smoke-test the inference-time numeric replacer on a sample utterance.
    # (Previously dropped straight into pdb here — debugger leftover removed.)
    repl = infer_num_replacer()
    print(repl("double two oh seven"))
|
||||
33
src/plume/utils/serve.py
Normal file
33
src/plume/utils/serve.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from plume.utils import lazy_module
|
||||
import typer
|
||||
|
||||
rpyc = lazy_module("rpyc")
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
class ASRService(rpyc.Service):
    """RPyC service exposing an ASR recognizer's transcription over the wire."""

    def __init__(self, asr_recognizer):
        # recognizer must provide transcribe(bytes) -> text
        self.asr = asr_recognizer

    def on_connect(self, conn):
        """Connection established; no per-connection setup is required."""

    def on_disconnect(self, conn):
        """Connection closed; nothing to tear down."""

    def exposed_transcribe(self, utterance: bytes):
        """Transcribe raw audio bytes and return the recognized text."""
        return self.asr.transcribe(utterance)

    def exposed_transcribe_cb(self, utterance: bytes, respond):
        """Transcribe raw audio bytes and deliver the text via the callback."""
        respond(self.asr.transcribe(utterance))
|
||||
45
src/plume/utils/st_rerun.py
Normal file
45
src/plume/utils/st_rerun.py
Normal file
@@ -0,0 +1,45 @@
|
||||
try:
|
||||
# Before Streamlit 0.65
|
||||
from streamlit.ReportThread import get_report_ctx
|
||||
from streamlit.server.Server import Server
|
||||
from streamlit.ScriptRequestQueue import RerunData
|
||||
from streamlit.ScriptRunner import RerunException
|
||||
except ModuleNotFoundError:
|
||||
# After Streamlit 0.65
|
||||
from streamlit.report_thread import get_report_ctx
|
||||
from streamlit.server.server import Server
|
||||
from streamlit.script_request_queue import RerunData
|
||||
from streamlit.script_runner import RerunException
|
||||
|
||||
|
||||
def rerun():
    """Rerun a Streamlit app from the top!"""
    raise RerunException(RerunData(_get_widget_states()))
|
||||
|
||||
|
||||
def _get_widget_states():
    """Return the widget states of the current Streamlit session.

    Hack: walks Streamlit's server internals to locate the session whose
    enqueue callable matches the current report context. Handles both the
    pre-0.56 and post-0.56 internal attribute names.

    Raises:
        RuntimeError: if the current session cannot be located.
    """
    ctx = get_report_ctx()

    session = None

    current_server = Server.get_current()
    if hasattr(current_server, '_session_infos'):
        # Streamlit < 0.56
        session_infos = Server.get_current()._session_infos.values()
    else:
        session_infos = Server.get_current()._session_info_by_id.values()

    for session_info in session_infos:
        if session_info.session.enqueue == ctx.enqueue:
            session = session_info.session

    if session is None:
        # BUG FIX: the two literals previously concatenated without a
        # separator, producing "...objectAre you...".
        raise RuntimeError(
            "Oh noes. Couldn't get your Streamlit Session object. "
            "Are you doing something fancy with threads?"
        )
    # Got the session object!
    return session._widget_states
|
||||
192
src/plume/utils/transcribe.py
Normal file
192
src/plume/utils/transcribe.py
Normal file
@@ -0,0 +1,192 @@
|
||||
import os
|
||||
import logging
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from functools import lru_cache
|
||||
|
||||
import typer
|
||||
|
||||
# import rpyc
|
||||
|
||||
# from tqdm import tqdm
|
||||
# from pydub.silence import split_on_silence
|
||||
from .lazy_import import lazy_module
|
||||
|
||||
rpyc = lazy_module("rpyc")
|
||||
pydub = lazy_module("pydub")
|
||||
np = lazy_module("numpy")
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
ASR_RPYC_HOST = os.environ.get("ASR_RPYC_HOST", "localhost")
|
||||
ASR_RPYC_PORT = int(os.environ.get("ASR_RPYC_PORT", "8044"))
|
||||
|
||||
TRITON_ASR_MODEL = os.environ.get("TRITON_ASR_MODEL", "slu_wav2vec2")
|
||||
|
||||
TRITON_GRPC_ASR_HOST = os.environ.get("TRITON_GRPC_ASR_HOST", "localhost")
|
||||
TRITON_GRPC_ASR_PORT = int(os.environ.get("TRITON_GRPC_ASR_PORT", "8001"))
|
||||
|
||||
|
||||
@lru_cache()
def transcribe_rpyc_gen(asr_host=ASR_RPYC_HOST, asr_port=ASR_RPYC_PORT):
    """Connect to the RPyC ASR server and return (transcribe_fn, audio_prep_fn).

    lru_cache makes the connection a per-(host, port) singleton.

    Raises:
        Exception: when no server is listening at asr_host:asr_port.
    """
    logger.info(f"connecting to asr server at {asr_host}:{asr_port}")
    try:
        asr = rpyc.connect(asr_host, asr_port).root
        logger.info("connected to asr server successfully")
    except ConnectionRefusedError:
        # BUG FIX: the error previously named a wrong env var
        # (JASPER_ASR_RPYC_HOST); the host/port come from ASR_RPYC_HOST
        # and ASR_RPYC_PORT.
        raise Exception(
            f"env-vars ASR_RPYC_HOST/ASR_RPYC_PORT invalid: "
            f"no ASR server reachable at {asr_host}:{asr_port}"
        )

    def audio_prep(aud_seg):
        # normalize to mono / 16-bit / 16 kHz WAV bytes, as the server expects
        asr_seg = (
            aud_seg.set_channels(1).set_sample_width(2).set_frame_rate(16000)
        )
        af = BytesIO()
        asr_seg.export(af, format="wav")
        return af.getvalue()

    return asr.transcribe, audio_prep
|
||||
|
||||
|
||||
def triton_transcribe_grpc_gen(
    asr_host=TRITON_GRPC_ASR_HOST,
    asr_port=TRITON_GRPC_ASR_PORT,
    asr_model=TRITON_ASR_MODEL,
    method="chunked",
    chunk_msec=5000,
    sil_msec=500,
    # overlap=False,
    sep=" ",
):
    """Build a Triton gRPC transcriber; returns (transcribe_fn, audio_prep_fn).

    Args:
        asr_host / asr_port: Triton gRPC endpoint.
        asr_model: Triton model name to infer against.
        method: "whole" sends the entire segment in one request; "chunked"
            splits into fixed chunk_msec windows; "silence" splits on
            silence first, then chunks.
        chunk_msec: chunk window size in milliseconds.
        sil_msec: minimum silence length for splitting, and padding duration
            added around each chunk.
        sep: string used to join per-chunk transcripts.

    Raises:
        Exception: if ``method`` is not one of chunked|silence|whole.
    """
    from tritonclient.utils import np_to_triton_dtype, InferenceServerException
    import tritonclient.grpc as grpcclient

    sup_meth = ["chunked", "silence", "whole"]
    if method not in sup_meth:
        meths = "|".join(sup_meth)
        raise Exception(f"unsupported method {method}. pick one of {meths}")

    client = grpcclient.InferenceServerClient(f"{asr_host}:{asr_port}")

    def transcriber(aud_seg):
        """Send one segment as WAV bytes; return the decoded transcript."""
        af = BytesIO()
        aud_seg.export(af, format="wav")
        input_audio_bytes = af.getvalue()
        # model takes a single BYTES tensor named INPUT_AUDIO
        input_audio_data = np.array([input_audio_bytes])
        inputs = [
            grpcclient.InferInput(
                "INPUT_AUDIO",
                input_audio_data.shape,
                np_to_triton_dtype(input_audio_data.dtype),
            )
        ]
        inputs[0].set_data_from_numpy(input_audio_data)
        outputs = [grpcclient.InferRequestedOutput("OUTPUT_TEXT")]
        try:
            response = client.infer(
                asr_model, inputs, request_id=str(1), outputs=outputs
            )
            transcript = response.as_numpy("OUTPUT_TEXT")[0]
        except InferenceServerException:
            # degrade gracefully instead of propagating the server error
            transcript = b"[server error]"
        return transcript.decode("utf-8")

    def chunked_transcriber(aud_seg):
        """Split the segment per `method`, transcribe chunks, join with `sep`."""
        if method == "silence":
            sil_chunks = pydub.silence.split_on_silence(
                aud_seg,
                min_silence_len=sil_msec,
                silence_thresh=-50,
                keep_silence=500,
            )
            # each silence-delimited piece is further cut into chunk windows
            chunks = [sc for c in sil_chunks for sc in c[::chunk_msec]]
        else:
            chunks = aud_seg[::chunk_msec]
        # if overlap:
        #     chunks = [
        #         aud_seg[start, end]
        #         for start, end in range(0, int(aud_seg.duration_seconds * 1000, 1000))
        #     ]
        #     pass
        transcript_list = []
        # pad each chunk with silence so the model sees clean boundaries
        sil_pad = pydub.AudioSegment.silent(duration=sil_msec)
        for seg in chunks:
            t_seg = sil_pad + seg + sil_pad
            c_transcript = transcriber(t_seg)
            transcript_list.append(c_transcript)
        transcript = sep.join(transcript_list)
        return transcript

    def audio_prep(aud_seg):
        # normalize to mono / 16-bit / 16 kHz before transcription
        asr_seg = aud_seg.set_channels(1).set_sample_width(2).set_frame_rate(16000)
        return asr_seg

    whole_transcriber = transcriber if method == "whole" else chunked_transcriber
    return whole_transcriber, audio_prep
|
||||
|
||||
|
||||
@app.command()
def file(
    audio_file: Path, write_file: bool = False, chunked: bool = True, rpyc: bool = False, model='slu_wav2vec2'
):
    """Transcribe one audio file; optionally save the text next to it."""
    segment = pydub.AudioSegment.from_file(audio_file)
    if rpyc:
        transcriber, prep = transcribe_rpyc_gen()
    else:
        transcriber, prep = triton_transcribe_grpc_gen(asr_model=model)
    transcription = transcriber(prep(segment))

    typer.echo(transcription)
    if write_file:
        # transcript goes alongside the audio with a .txt suffix
        transcript_path = audio_file.with_suffix(".txt")
        with open(transcript_path, "w") as tf:
            tf.write(transcription)
|
||||
|
||||
|
||||
@app.command()
def benchmark(audio_file: Path):
    """Benchmark RPyC vs Triton transcription latency on one audio file.

    Prints best-of-N per-call wall time for each backend, with a pause
    between them.
    """

    def _timeinfo(transcriber, aud_seg, number=100, repeat=10):
        # best-of-`repeat` averaged wall time per call, in milliseconds
        from timeit import Timer

        timer = Timer(lambda: transcriber(aud_seg))
        time_taken = timer.repeat(repeat, number=number)
        best = min(time_taken) * 1000 / number
        print(f"{number} loops, best of {repeat}: {best:.3f} msec per loop")

    file_seg = pydub.AudioSegment.from_file(audio_file)

    transcriber, audio_prep = transcribe_rpyc_gen()
    # FIX: the timing helper was duplicated verbatim for each backend;
    # it is now a single shared helper.
    _timeinfo(transcriber, audio_prep(file_seg))

    import time

    # brief pause so load from the first run doesn't skew the second
    time.sleep(5)

    transcriber, audio_prep = triton_transcribe_grpc_gen()
    _timeinfo(transcriber, audio_prep(file_seg))
|
||||
|
||||
|
||||
def main():
    """CLI entry point: dispatch to the Typer application."""
    app()


if __name__ == "__main__":
    main()
|
||||
109
src/plume/utils/tts.py
Normal file
109
src/plume/utils/tts.py
Normal file
@@ -0,0 +1,109 @@
|
||||
from logging import getLogger
|
||||
from plume.utils import lazy_module
|
||||
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import typer
|
||||
|
||||
# from google.cloud import texttospeech
|
||||
texttospeech = lazy_module('google.cloud.texttospeech')
|
||||
|
||||
LOGGER = getLogger("googletts")
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
class GoogleTTS(object):
    """Thin wrapper over the Google Cloud Text-to-Speech client.

    Synthesizes LINEAR16 (WAV) audio from plain text or SSML. Voice
    parameters are plain dicts with at least "language", "name" and
    "sample_rate" keys (as produced by ``voice_list``).
    """

    def __init__(self):
        self.client = texttospeech.TextToSpeechClient()

    def text_to_speech(self, text: str, params: dict) -> bytes:
        """Synthesize plain ``text``; return LINEAR16 audio bytes.

        Optional "speaking_rate" and "pitch" keys in ``params`` are applied
        when present.
        """
        tts_input = texttospeech.types.SynthesisInput(text=text)
        voice = texttospeech.types.VoiceSelectionParams(
            language_code=params["language"], name=params["name"]
        )
        audio_config = texttospeech.types.AudioConfig(
            audio_encoding=texttospeech.enums.AudioEncoding.LINEAR16,
            sample_rate_hertz=params["sample_rate"],
        )
        if 'speaking_rate' in params:
            audio_config.speaking_rate = params['speaking_rate']
        if 'pitch' in params:
            audio_config.pitch = params['pitch']
        response = self.client.synthesize_speech(tts_input, voice, audio_config)
        audio_content = response.audio_content
        return audio_content

    def ssml_to_speech(self, text: str, params: dict) -> bytes:
        """Synthesize SSML markup ``text``; return LINEAR16 audio bytes."""
        tts_input = texttospeech.types.SynthesisInput(ssml=text)
        voice = texttospeech.types.VoiceSelectionParams(
            language_code=params["language"], name=params["name"]
        )
        audio_config = texttospeech.types.AudioConfig(
            audio_encoding=texttospeech.enums.AudioEncoding.LINEAR16,
            sample_rate_hertz=params["sample_rate"],
        )
        response = self.client.synthesize_speech(tts_input, voice, audio_config)
        audio_content = response.audio_content
        return audio_content

    @classmethod
    def voice_list(cls):
        """List available English voices as parameter dicts."""

        client = cls().client

        # Performs the list voices request
        voices = client.list_voices()
        results = []
        for voice in voices.voices:
            # keep only voices supporting at least one English language code
            supported_eng_langs = [
                lang for lang in voice.language_codes if lang[:2] == "en"
            ]
            if len(supported_eng_langs) > 0:
                lang = ",".join(supported_eng_langs)
            else:
                continue

            ssml_gender = texttospeech.enums.SsmlVoiceGender(voice.ssml_gender)
            results.append(
                {
                    "name": voice.name,
                    "language": lang,
                    "gender": ssml_gender.name,
                    "engine": "wavenet" if "Wav" in voice.name else "standard",
                    "sample_rate": voice.natural_sample_rate_hertz,
                }
            )
        return results

    @classmethod
    def voice_by_name(cls, name):
        """Return the voice dict whose "name" equals ``name``.

        Raises:
            ValueError: when no available voice matches.
        """

        # client = cls().client

        results = cls.voice_list()
        for voice in results:
            if voice['name'] == name:
                return voice
        raise ValueError(f'{name} not a valid voice')
|
||||
|
||||
|
||||
@app.command()
def generate_audio_file(text, dest_path: Path = "./tts_audio.wav", voice="en-US-Wavenet-D"):
    """Synthesize ``text`` with Google TTS and write a LINEAR16 WAV to ``dest_path``.

    Raises:
        ValueError: if ``voice`` does not name an available voice.
    """
    tts = GoogleTTS()
    matching = [v for v in tts.voice_list() if v["name"] == voice]
    if not matching:
        # BUG FIX: previously raised an opaque IndexError ([...][0]) for
        # unknown voice names.
        raise ValueError(f"{voice} not a valid voice")
    wav_data = tts.text_to_speech(text, matching[0])
    # coerce: typer converts the default for CLI use, but direct callers
    # may pass (or hit the annotated default as) a plain string
    with Path(dest_path).open("wb") as wf:
        wf.write(wav_data)
|
||||
|
||||
|
||||
def main():
    """CLI entry point: dispatch to the Typer application."""
    app()


if __name__ == "__main__":
    main()
|
||||
85
src/plume/utils/ui_persist.py
Normal file
85
src/plume/utils/ui_persist.py
Normal file
@@ -0,0 +1,85 @@
|
||||
from plume.utils import ExtendedPath, get_mongo_conn
|
||||
from plume.utils.st_rerun import rerun
|
||||
from uuid import uuid4
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def setup_file_state(st):
    """Attach file-backed cursor helpers to the streamlit module object.

    Idempotent: a ``state_lock`` attribute marks an already-initialized
    module. Cursor state persists in a local "preview.lck" JSON file and
    every update triggers a Streamlit rerun.
    """
    if hasattr(st, "state_lock"):
        return
    # st.task_id = str(uuid4())
    lock_path = ExtendedPath("preview.lck")

    def read_cursor():
        return lock_path.read_json()["current_cursor"]

    def write_cursor(val=0):
        lock_path.write_json({"current_cursor": val})
        rerun()

    st.get_current_cursor = read_cursor
    st.update_cursor = write_cursor
    st.state_lock = True
    # initialize the persisted cursor (also reruns the app)
    write_cursor(0)
|
||||
|
||||
|
||||
def setup_mongo_asr_validation_state(st):
    """Attach Mongo-backed ASR-validation state helpers to the streamlit module.

    Idempotent via the ``mongo_connected`` attribute. Installs closures for
    reading/updating a per-task cursor, reading/updating correction entries,
    and switching tasks. Cursor updates trigger a Streamlit rerun.
    """
    if not hasattr(st, "mongo_connected"):
        st.mongoclient = get_mongo_conn(col="asr_validation")
        mongo_conn = st.mongoclient
        # fresh task id per process; set_task_fn can override it later
        st.task_id = str(uuid4())

        def current_cursor_fn():
            """Read this task's cursor position from Mongo."""
            # mongo_conn = st.mongoclient
            cursor_obj = mongo_conn.find_one(
                {"type": "current_cursor", "task_id": st.task_id}
            )
            cursor_val = cursor_obj["cursor"]
            return cursor_val

        def update_cursor_fn(val=0):
            """Upsert this task's cursor position, then rerun the app."""
            mongo_conn.find_one_and_update(
                {"type": "current_cursor", "task_id": st.task_id},
                {
                    "$set": {
                        "type": "current_cursor",
                        "task_id": st.task_id,
                        "cursor": val,
                    }
                },
                upsert=True,
            )
            rerun()

        def get_correction_entry_fn(code):
            """Fetch the correction entry for `code` (without Mongo _id)."""
            return mongo_conn.find_one(
                {"type": "correction", "code": code}, projection={"_id": False}
            )

        def update_entry_fn(code, value):
            """Upsert the correction value for `code`, tagged with the task id."""
            mongo_conn.find_one_and_update(
                {"type": "correction", "code": code},
                {"$set": {"value": value, "task_id": st.task_id}},
                upsert=True,
            )

        def set_task_fn(data_path, task_id):
            """Switch to `task_id` (if given) and create its lock file."""
            if task_id:
                st.task_id = task_id
            task_path = data_path / Path(f"task-{st.task_id}.lck")
            if not task_path.exists():
                print(f"creating task lock at {task_path}")
                task_path.touch()

        st.get_current_cursor = current_cursor_fn
        st.update_cursor = update_cursor_fn
        st.get_correction_entry = get_correction_entry_fn
        st.update_entry = update_entry_fn
        st.set_task = set_task_fn
        st.mongo_connected = True
        # seed the cursor document on first run for this task
        cursor_obj = mongo_conn.find_one(
            {"type": "current_cursor", "task_id": st.task_id}
        )
        if not cursor_obj:
            update_cursor_fn(0)
|
||||
134
src/plume/utils/vad.py
Normal file
134
src/plume/utils/vad.py
Normal file
@@ -0,0 +1,134 @@
|
||||
import logging
|
||||
from .lazy_import import lazy_module
|
||||
|
||||
webrtcvad = lazy_module("webrtcvad")
|
||||
pydub = lazy_module("pydub")
|
||||
|
||||
DEFAULT_CHUNK_DUR = 30
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def is_frame_voice(vad, seg, chunk_dur):
    """Return True iff ``seg`` is a full ``chunk_dur``-msec frame flagged as speech.

    Args:
        vad: VAD object exposing is_speech(raw_bytes, sample_rate).
        seg: pydub-like audio segment (duration_seconds, raw_data, frame_rate).
        chunk_dur: expected frame duration in milliseconds; partial trailing
            frames are never treated as voice.
    """
    if seg.duration_seconds != chunk_dur / 1000:
        return False
    # bool() guards against VAD implementations returning truthy non-bools;
    # replaces the `True if (...) else False` anti-idiom
    return bool(vad.is_speech(seg.raw_data, seg.frame_rate))
|
||||
|
||||
|
||||
class VADUtterance(object):
    """Segment an audio stream into voiced utterances using WebRTC VAD.

    Frames of ``chunk_dur`` msec are classified as voice/silence; voiced
    frames accumulate into an utterance which is emitted when it exceeds
    ``max_utterance`` or when ``max_silence`` of quiet follows at least
    ``min_utterance`` of voice.
    """

    def __init__(
        self,
        max_silence=500,
        min_utterance=280,
        max_utterance=20000,
        chunk_dur=DEFAULT_CHUNK_DUR,
        start_cycles=3,
        aggression=1,
    ):
        # aggression: webrtcvad mode 0-3 (higher filters more non-speech)
        super(VADUtterance, self).__init__()
        self.vad = webrtcvad.Vad(aggression)
        self.chunk_dur = chunk_dur
        # duration in millisecs
        self.max_sil = max_silence
        self.min_utt = min_utterance
        self.max_utt = max_utterance
        # minimum leading voiced span, in msec; currently unused by
        # stream_segments — TODO confirm intended use
        self.speech_start = start_cycles * chunk_dur

    def __repr__(self):
        return f"VAD(max_silence={self.max_sil},min_utterance:{self.min_utt},max_utterance:{self.max_utt})"

    def stream_segments(self, audio_seg):
        """Yield voiced pydub segments detected in ``audio_seg``.

        The input is normalized to mono / 16-bit / 16 kHz (as webrtcvad
        requires). A trailing buffer between min and max utterance length
        is flushed at end of stream.
        """
        stream_seg = audio_seg.set_channels(1).set_sample_width(2).set_frame_rate(16000)
        silence_buffer = pydub.AudioSegment.empty()
        voice_buffer = pydub.AudioSegment.empty()
        silence_threshold = False
        for c in stream_seg[:: self.chunk_dur]:
            voice_frame = is_frame_voice(self.vad, c, self.chunk_dur)
            # logger.info(f"is audio stream voice? {voice_frame}")
            if voice_frame:
                # voice resets any accumulated silence
                silence_threshold = False
                voice_buffer += c
                silence_buffer = pydub.AudioSegment.empty()
            else:
                silence_buffer += c
            # pydub len() is in milliseconds
            voc_dur = len(voice_buffer)
            sil_dur = len(silence_buffer)

            if voc_dur >= self.max_utt:
                # overflow: emit the utterance even without trailing silence
                yield voice_buffer
                voice_buffer = pydub.AudioSegment.empty()

            if sil_dur >= self.max_sil:
                if voc_dur >= self.min_utt:
                    # enough voice followed by enough silence: emit
                    yield voice_buffer
                    voice_buffer = pydub.AudioSegment.empty()
                # NOTE(review): sub-min_utt voice is NOT cleared here, so it
                # merges into the next utterance — original author comment
                # ("ignore/clear voice if silence reached threshold or indent
                # the statement") suggests this was undecided; confirm intent.
                if not silence_threshold:
                    silence_threshold = True

        # flush a trailing utterance of acceptable length at end of stream
        if self.min_utt < len(voice_buffer) < self.max_utt:
            yield voice_buffer
|
||||
|
||||
# def stream_utterance(self, audio_stream):
|
||||
# silence_buffer = pydub.AudioSegment.empty()
|
||||
# voice_buffer = pydub.AudioSegment.empty()
|
||||
# silence_threshold = False
|
||||
# for avf in audio_stream:
|
||||
# audio_bytes = avf.to_ndarray().tobytes()
|
||||
# c = (
|
||||
# pydub.AudioSegment(
|
||||
# data=audio_bytes,
|
||||
# frame_rate=avf.sample_rate,
|
||||
# channels=len(avf.layout.channels),
|
||||
# sample_width=avf.format.bytes,
|
||||
# )
|
||||
# .set_channels(1)
|
||||
# .set_sample_width(2)
|
||||
# .set_frame_rate(16000)
|
||||
# )
|
||||
# voice_frame = is_frame_voice(self.vad, c, self.chunk_dur)
|
||||
# # logger.info(f"is audio stream voice? {voice_frame}")
|
||||
# if voice_frame:
|
||||
# silence_threshold = False
|
||||
# voice_buffer += c
|
||||
# silence_buffer = pydub.AudioSegment.empty()
|
||||
# else:
|
||||
# silence_buffer += c
|
||||
# voc_dur = voice_buffer.duration_seconds * 1000
|
||||
# sil_dur = silence_buffer.duration_seconds * 1000
|
||||
#
|
||||
# if voc_dur >= self.max_utt:
|
||||
# # logger.info(
|
||||
# # f"detected voice overflow: voice duration {voice_buffer.duration_seconds}"
|
||||
# # )
|
||||
# yield voice_buffer
|
||||
# voice_buffer = pydub.AudioSegment.empty()
|
||||
#
|
||||
# if sil_dur >= self.max_sil:
|
||||
# if voc_dur >= self.min_utt:
|
||||
# # logger.info(
|
||||
# # f"detected silence: voice duration {voice_buffer.duration_seconds}"
|
||||
# # )
|
||||
# yield voice_buffer
|
||||
# voice_buffer = pydub.AudioSegment.empty()
|
||||
# # ignore/clear voice if silence reached threshold or indent the statement
|
||||
# if not silence_threshold:
|
||||
# silence_threshold = True
|
||||
#
|
||||
# if voice_buffer:
|
||||
# yield voice_buffer
|
||||
Reference in New Issue
Block a user