massive refactor/rename to plume

tegra
Malar Kannan 2021-02-23 19:43:33 +05:30
parent e8f58a5043
commit ed6117559a
51 changed files with 2864 additions and 1037 deletions

View File

@ -3,3 +3,17 @@
```
diff <(cat data/asr_data/call_upwork_test_cnd_*/manifest.json |sort) <(cat data/asr_data/call_upwork_test_cnd/manifest.json |sort)
```
> Prepare Augmented Data
```
plume data filter /dataset/png_entities/png_numbers_2020_07/ /dataset/png_entities/png_numbers_2020_07_skip1hour/
plume data augment /dataset/agara_slu/call_alphanum_ag_sg_v1_abs/ /dataset/png_entities/png_numbers_2020_07_1hour_noblank/ /dataset/png_entities/png_numbers_2020_07_skip1hour/ /dataset/png_entities/aug_pngskip1hour-agsgalnum-1hournoblank/
plume data filter --kind transform_digits /dataset/agara_slu/png1hour-agsgalnum-1hournoblank/ /dataset/agara_slu/png1hour-agsgalnum-1hournoblank_prep/
```
```
KENLM_INC=/usr/local/include/kenlm/ pip install -e ../deps/wav2letter/bindings/python/
```

View File

@ -1,8 +1,8 @@
# Jasper ASR
# Plume ASR
[![image](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/python/black)
> Generates text from speech audio
> Generates text from audio containing speech
---
# Table of Contents
@ -20,7 +20,7 @@
# Features
* ASR using Jasper (from [NemoToolkit](https://github.com/NVIDIA/NeMo) )
* ASR using Wav2Vec2 (from [fairseq](https://github.com/pytorch/fairseq) )
# Installation
To install the package and its dependencies, run:
@ -29,14 +29,26 @@ python setup.py install
```
or with pip
```bash
pip install .[server]
pip install .[all]
```
The installation should work on Python 3.6 or newer. It is untested on Python 2.7.
# Usage
### Library
> Jasper
```python
from jasper.asr import JasperASR
from plume.models.jasper.asr import JasperASR
asr_model = JasperASR("/path/to/model_config_yaml","/path/to/encoder_checkpoint","/path/to/decoder_checkpoint") # Loads the models
TEXT = asr_model.transcribe(wav_data) # Returns the text spoken in the wav
```
> Wav2Vec2
```python
from plume.models.wav2vec2.asr import Wav2Vec2ASR
asr_model = Wav2Vec2ASR("/path/to/ctc_checkpoint","/path/to/w2v_checkpoint","/path/to/target_dictionary") # Loads the models
TEXT = asr_model.transcribe(wav_data) # Returns the text spoken in the wav
```
### Command Line
```
$ plume
```

View File

@ -1 +0,0 @@

View File

@ -1,21 +0,0 @@
import os
import logging
import rpyc
from functools import lru_cache
# Configure logging at import time: INFO level with timestamped records.
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
# Address of the ASR RPyC server; both values can be overridden through the
# environment (the port defaults to 8045 and is parsed as an int).
ASR_HOST = os.environ.get("JASPER_ASR_RPYC_HOST", "localhost")
ASR_PORT = int(os.environ.get("JASPER_ASR_RPYC_PORT", "8045"))
@lru_cache()
def transcribe_gen(asr_host=ASR_HOST, asr_port=ASR_PORT):
    """Connect to the ASR RPyC server and return its remote ``transcribe``.

    The result is memoized by ``lru_cache`` on the call arguments, so a
    repeated call with the same arguments reuses the first connection.
    """
    logger.info(f"connecting to asr server at {asr_host}:{asr_port}")
    connection = rpyc.connect(asr_host, asr_port)
    remote_service = connection.root
    logger.info(f"connected to asr server successfully")
    return remote_service.transcribe

View File

@ -1 +0,0 @@

View File

@ -1,77 +0,0 @@
import json
from pathlib import Path
from sklearn.model_selection import train_test_split
from .utils import asr_manifest_reader, asr_manifest_writer
from typing import List
from itertools import chain
import typer
# Typer application on which the dataset commands below are registered.
app = typer.Typer()
@app.command()
def fixate_data(dataset_path: Path):
    """Write ``abs_manifest.json`` with audio paths prefixed by ``dataset_path``.

    Every entry of ``<dataset_path>/manifest.json`` is copied with its
    ``audio_filepath`` rewritten to ``dataset_path / audio_filepath``.
    """
    source_manifest = dataset_path / Path("manifest.json")
    target_manifest = dataset_path / Path("abs_manifest.json")

    def rebased_entries():
        for entry in asr_manifest_reader(source_manifest):
            rebased = dataset_path / Path(entry["audio_filepath"])
            entry["audio_filepath"] = str(rebased)
            yield entry

    asr_manifest_writer(target_manifest, rebased_entries())
@app.command()
def augment_data(src_dataset_paths: List[Path], dest_dataset_path: Path):
    """Concatenate the ``abs_manifest.json`` of several datasets into one.

    The destination directory is created if needed and receives a single
    ``abs_manifest.json`` holding the entries of every source, in order.
    """
    manifest_name = Path("abs_manifest.json")
    readers = [
        asr_manifest_reader(source_dir / manifest_name)
        for source_dir in src_dataset_paths
    ]
    dest_dataset_path.mkdir(parents=True, exist_ok=True)
    asr_manifest_writer(dest_dataset_path / manifest_name, chain(*readers))
@app.command()
def split_data(dataset_path: Path, test_size: float = 0.1):
    """Split ``abs_manifest.json`` into train and test manifests.

    ``test_size`` is forwarded to ``train_test_split``; the two subsets are
    written next to the source manifest as ``train_manifest.json`` and
    ``test_manifest.json``.
    """
    abs_manifest = dataset_path / Path("abs_manifest.json")
    entries = list(asr_manifest_reader(abs_manifest))
    train_entries, test_entries = train_test_split(entries, test_size=test_size)
    outputs = (
        ("train_manifest.json", train_entries),
        ("test_manifest.json", test_entries),
    )
    for file_name, subset in outputs:
        asr_manifest_writer(abs_manifest.with_name(file_name), subset)
@app.command()
def validate_data(dataset_path: Path):
    """Validate the train and test manifests of a dataset.

    For ``train_manifest.json`` and ``test_manifest.json`` under
    ``dataset_path``, every line is parsed as JSON, its ``duration`` is
    accumulated and the referenced audio file is checked for existence.
    Failing lines are reported individually and counted, so the final
    summary only claims "no errors found" when that is actually true.
    """
    from natural.date import compress
    from datetime import timedelta

    for mf_type in ["train_manifest.json", "test_manifest.json"]:
        data_file = dataset_path / Path(mf_type)
        print(f"validating {data_file}.")
        with Path(data_file).open("r") as pf:
            data_jsonl = pf.readlines()
        duration = 0
        failures = 0
        for (i, s) in enumerate(data_jsonl):
            try:
                d = json.loads(s)
                duration += d["duration"]
                audio_file = data_file.parent / Path(d["audio_filepath"])
                if not audio_file.exists():
                    raise OSError(f"File {audio_file} not found")
            # Exception (not BaseException) so Ctrl-C / SystemExit still abort.
            except Exception as e:
                failures += 1
                print(f'failed on {i} with "{e}"')
        duration_str = compress(timedelta(seconds=duration), pad=" ")
        if failures:
            print(
                f"found {failures} invalid entries in {mf_type}. lists {duration_str}sec of audio"
            )
        else:
            print(
                f"no errors found. seems like a valid {mf_type}. contains {duration_str}sec of audio"
            )
def main():
    """Entry point: run the Typer application."""
    app()


if __name__ == "__main__":
    main()

View File

@ -1,93 +0,0 @@
from rastrik.proto.callrecord_pb2 import CallRecord
import gzip
from pydub import AudioSegment
from .utils import ui_dump_manifest_writer, strip_silence
import typer
from itertools import chain
from io import BytesIO
from pathlib import Path
# Typer application on which the call-log extraction command is registered.
app = typer.Typer()
@app.command()
def extract_manifest(
    call_log_dir: Path = Path("./data/call_audio"),
    output_dir: Path = Path("./data"),
    dataset_name: str = "grassroot_pizzahut_v1",
    caller_name: str = "grassroot",
    verbose: bool = False,
):
    """Cut recorded calls into per-utterance ASR samples and write a dataset.

    Every ``*.wav`` under ``call_log_dir`` is paired with the gzipped
    ``CallRecord`` protobuf stored next to it (same name, ``.pb2.gz`` suffix).
    The audio between consecutive ``asr_final`` speech events is sliced,
    stripped of leading/trailing silence and handed to
    ``ui_dump_manifest_writer`` under ``<output_dir>/asr_data``.

    Args:
        call_log_dir: Directory searched recursively for call recordings.
        output_dir: Parent directory of the ``asr_data`` output directory.
        dataset_name: Name passed to ``ui_dump_manifest_writer``.
        caller_name: Caller label attached to every extracted sample.
        verbose: Echo each wav file as it is loaded.
    """
    call_asr_data: Path = output_dir / Path("asr_data")
    call_asr_data.mkdir(exist_ok=True, parents=True)

    def wav_pb2_generator(log_dir):
        """Yield ``(audio, wav_path, record_path)`` for each wav in ``log_dir``."""
        for wav_path in log_dir.glob("**/*.wav"):
            if verbose:
                typer.echo(f"loading events for file {wav_path}")
            call_wav = AudioSegment.from_file_using_temporary_files(wav_path)
            meta_path = wav_path.with_suffix(".pb2.gz")
            yield call_wav, wav_path, meta_path

    def read_event(call_wav, log_file):
        """Yield one sample tuple per ``asr_final`` event of a call record."""
        # Only the second mono channel is sliced for the samples.
        _, call_wav_1 = call_wav.split_to_mono()
        with gzip.open(log_file, "rb") as log_h:
            record_data = log_h.read()
        cr = CallRecord()
        cr.ParseFromString(record_data)
        # NOTE(review): next() has no default; a record without a call_audio
        # event raises StopIteration inside this generator (RuntimeError under
        # PEP 479) — confirm whether such records can occur.
        first_audio_event_timestamp = next(
            (
                i
                for i in cr.events
                if i.WhichOneof("event_type") == "call_event"
                and i.call_event.WhichOneof("event_type") == "call_audio"
            )
        ).timestamp.ToDatetime()
        speech_events = [
            i
            for i in cr.events
            if i.WhichOneof("event_type") == "speech_event"
            and i.speech_event.WhichOneof("event_type") == "asr_final"
        ]
        # Zero offset: the first slice starts at the first audio event.
        previous_event_timestamp = (
            first_audio_event_timestamp - first_audio_event_timestamp
        )
        for each_speech_event in speech_events:
            asr_final = each_speech_event.speech_event.asr_final
            speech_timestamp = each_speech_event.timestamp.ToDatetime()
            actual_timestamp = speech_timestamp - first_audio_event_timestamp
            # Slice bounds are in milliseconds.
            start_time = previous_event_timestamp.total_seconds() * 1000
            end_time = actual_timestamp.total_seconds() * 1000
            audio_segment = strip_silence(call_wav_1[start_time:end_time])
            code_fb = BytesIO()
            audio_segment.export(code_fb, format="wav")
            wav_data = code_fb.getvalue()
            previous_event_timestamp = actual_timestamp
            duration = (end_time - start_time) / 1000
            # Honour the caller_name option instead of a hard-coded label.
            yield asr_final, duration, wav_data, caller_name, audio_segment

    def generate_call_asr_data():
        """Collect the samples of every call and write the dataset."""
        full_data = []
        total_duration = 0
        for wav, wav_path, pb2_path in wav_pb2_generator(call_log_dir):
            asr_data = read_event(wav, pb2_path)
            total_duration += wav.duration_seconds
            full_data.append(asr_data)
        n_calls = len(full_data)
        typer.echo(f"loaded {n_calls} calls of duration {total_duration}s")
        n_dps = ui_dump_manifest_writer(call_asr_data, dataset_name, chain(*full_data))
        typer.echo(f"written {n_dps} data points")

    generate_call_asr_data()
def main():
    """Entry point: run the Typer application."""
    app()


if __name__ == "__main__":
    main()

View File

@ -1,241 +0,0 @@
import io
import os
import json
import wave
from pathlib import Path
from functools import partial
from uuid import uuid4
from concurrent.futures import ThreadPoolExecutor
import pymongo
from slugify import slugify
from jasper.client import transcribe_gen
from nemo.collections.asr.metrics import word_error_rate
import matplotlib.pyplot as plt
import librosa
import librosa.display
from tqdm import tqdm
def manifest_str(path, dur, text):
    """Serialize one manifest entry as a newline-terminated JSON line.

    The duration is rounded to one decimal place.
    """
    entry = {"audio_filepath": path, "duration": round(dur, 1), "text": text}
    return f"{json.dumps(entry)}\n"
def wav_bytes(audio_bytes, frame_rate=24000):
    """Wrap raw PCM bytes in a WAV container and return the file's bytes.

    The header is written for 1 channel and a sample width of 2 bytes at
    ``frame_rate`` frames per second.
    """
    buffer = io.BytesIO()
    with wave.open(buffer, mode="w") as writer:
        writer.setnchannels(1)
        writer.setsampwidth(2)
        writer.setframerate(frame_rate)
        writer.writeframesraw(audio_bytes)
    return buffer.getvalue()
def tscript_uuid_fname(transcript):
    """Return a file stem: a random UUID4, ``_``, then a slug of the transcript.

    The slug is produced by ``slugify`` with ``max_length=8``.
    """
    unique_part = str(uuid4())
    slug_part = slugify(transcript, max_length=8)
    return f"{unique_part}_{slug_part}"
def asr_data_writer(output_dir, dataset_name, asr_data_source, verbose=False):
    """Write wav files plus a ``manifest.json`` for a dataset.

    ``asr_data_source`` yields ``(transcript, duration, wav_bytes)`` tuples.
    Each wav is stored under ``<output_dir>/<dataset_name>/wav`` and listed
    in the manifest with a path relative to the dataset directory.

    Returns the number of samples written.
    """
    dataset_dir = output_dir / Path(dataset_name)
    wav_dir = dataset_dir / Path("wav")
    wav_dir.mkdir(parents=True, exist_ok=True)
    asr_manifest = dataset_dir / Path("manifest.json")
    written = 0
    with asr_manifest.open("w") as manifest_file:
        print(f"writing manifest to {asr_manifest}")
        for transcript, audio_dur, wav_data in asr_data_source:
            stem = tscript_uuid_fname(transcript)
            audio_file = wav_dir / Path(stem).with_suffix(".wav")
            audio_file.write_bytes(wav_data)
            relative_audio = audio_file.relative_to(dataset_dir)
            manifest_file.write(
                manifest_str(str(relative_audio), audio_dur, transcript)
            )
            if verbose:
                print(f"writing '{transcript}' of duration {audio_dur}")
            written += 1
    return written
def ui_data_generator(output_dir, dataset_name, asr_data_source, verbose=False):
    """Write each sample's wav and build the validation-UI record for it.

    ``asr_data_source`` yields ``(transcript, duration, wav_bytes, caller,
    audio_segment)`` tuples. The wav bytes are written under
    ``<output_dir>/<dataset_name>/wav`` while iterating; the records
    (pretrained transcription, its word error rate against the transcript,
    waveform plot) are then computed through ``parallel_apply``.

    Returns a ``(records, sample_count)`` tuple.
    """
    dataset_dir = output_dir / Path(dataset_name)
    wav_dir = dataset_dir / Path("wav")
    plot_dir = dataset_dir / Path("wav_plots")
    wav_dir.mkdir(parents=True, exist_ok=True)
    plot_dir.mkdir(parents=True, exist_ok=True)

    def data_fn(
        transcript,
        audio_dur,
        wav_data,
        caller_name,
        aud_seg,
        fname,
        audio_path,
        num_datapoints,
        rel_data_path,
    ):
        """Build the UI record of one sample (runs inside the worker pool)."""
        pretrained_result = transcriber_pretrained(aud_seg.raw_data)
        pretrained_wer = word_error_rate([transcript], [pretrained_result])
        wav_plot_path = plot_dir / Path(fname).with_suffix(".png")
        # Plots are only rendered when missing.
        if not wav_plot_path.exists():
            plot_seg(wav_plot_path, audio_path)
        record = {
            "audio_filepath": str(rel_data_path),
            "duration": round(audio_dur, 1),
            "text": transcript,
            "real_idx": num_datapoints,
            "audio_path": audio_path,
            "spoken": transcript,
            "caller": caller_name,
            "utterance_id": fname,
            "pretrained_asr": pretrained_result,
            "pretrained_wer": pretrained_wer,
            "plot_path": str(wav_plot_path),
        }
        return record

    # Bound before any data_fn call runs; data_fn reads it as a closure variable.
    transcriber_pretrained = transcribe_gen(asr_port=8044)
    pending = []
    for transcript, audio_dur, wav_data, caller_name, aud_seg in asr_data_source:
        fname = str(uuid4()) + "_" + slugify(transcript, max_length=8)
        audio_file = wav_dir / Path(fname).with_suffix(".wav")
        audio_file.write_bytes(wav_data)
        job = partial(
            data_fn,
            transcript,
            audio_dur,
            wav_data,
            caller_name,
            aud_seg,
            fname,
            str(audio_file),
            len(pending),
            audio_file.relative_to(dataset_dir),
        )
        pending.append(job)
    ui_data = parallel_apply(lambda job: job(), pending)
    return ui_data, len(pending)
def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=False):
    """Generate UI records, then write ``manifest.json`` and ``ui_dump.json``.

    Both files are written into ``<output_dir>/<dataset_name>``; the UI dump
    stores the records under the ``"data"`` key.

    Returns the number of samples.
    """
    dataset_dir = output_dir / Path(dataset_name)
    records, sample_count = ui_data_generator(
        output_dir, dataset_name, asr_data_source, verbose=verbose
    )
    asr_manifest = dataset_dir / Path("manifest.json")
    with asr_manifest.open("w") as manifest_file:
        print(f"writing manifest to {asr_manifest}")
        for record in records:
            line = manifest_str(
                str(record["audio_filepath"]), record["duration"], record["text"]
            )
            manifest_file.write(line)
    ui_dump_file = dataset_dir / Path("ui_dump.json")
    ExtendedPath(ui_dump_file).write_json({"data": records})
    return sample_count
def asr_manifest_reader(data_manifest_path: Path):
    """Yield the entries of a JSON-lines manifest.

    All lines are parsed before the first entry is yielded. Each entry gains
    an ``audio_path`` key (the manifest's directory joined with
    ``audio_filepath``) and has its ``text`` stripped of surrounding
    whitespace.
    """
    print(f"reading manifest from {data_manifest_path}")
    with data_manifest_path.open("r") as manifest_file:
        raw_lines = manifest_file.readlines()
    entries = [json.loads(raw) for raw in raw_lines]
    manifest_dir = data_manifest_path.parent
    for entry in entries:
        entry["audio_path"] = manifest_dir / Path(entry["audio_filepath"])
        entry["text"] = entry["text"].strip()
        yield entry
def asr_manifest_writer(asr_manifest_path: Path, manifest_str_source):
    """Write manifest entries to ``asr_manifest_path`` as JSON lines.

    Only the ``audio_filepath``, ``duration`` and ``text`` keys of each entry
    are serialized (through ``manifest_str``).
    """
    with asr_manifest_path.open("w") as manifest_file:
        print(f"opening {asr_manifest_path} for writing manifest")
        for entry in manifest_str_source:
            line = manifest_str(
                entry["audio_filepath"], entry["duration"], entry["text"]
            )
            manifest_file.write(line)
def asr_test_writer(out_file_path: Path, source):
    """Write a PAUSE/PLAY test script for the given samples.

    Each sample contributes a ``PAUSE 2`` / ``PLAY <audio_filepath>`` /
    ``PAUSE 60`` stanza and the script ends with ``DO_HANGUP``. The samples
    themselves are saved next to the script with a ``.result.json`` suffix.
    """

    def stanza(sample, position):
        # `position` is accepted for parity with the call site but not used.
        path = sample["audio_filepath"]
        return f"PAUSE 2\nPLAY {path}\nPAUSE 60\n\n"

    res_file = out_file_path.with_suffix(".result.json")
    results = []
    with out_file_path.open("w") as script_file:
        print(f"opening {out_file_path} for writing test")
        for position, sample in enumerate(source):
            results.append(sample)
            script_file.write(stanza(sample, position))
        script_file.write("DO_HANGUP\n")
    ExtendedPath(res_file).write_json(results)
def batch(iterable, n=1):
    """Split a sliceable sequence into consecutive chunks of at most ``n`` items."""
    total = len(iterable)
    chunks = []
    for start in range(0, total, n):
        stop = min(start + n, total)
        chunks.append(iterable[start:stop])
    return chunks
class ExtendedPath(type(Path())):
    """Concrete ``Path`` subclass with JSON read/write helpers."""

    def read_json(self):
        """Load and return the JSON document stored at this path."""
        print(f"reading json from {self}")
        with self.open("r") as json_file:
            content = json.load(json_file)
        return content

    def write_json(self, data):
        """Dump ``data`` as indented JSON, creating parent directories first."""
        print(f"writing json to {self}")
        self.parent.mkdir(parents=True, exist_ok=True)
        with self.open("w") as json_file:
            outcome = json.dump(data, json_file, indent=2)
        return outcome
def get_mongo_conn(host="", port=27017, db="test", col="calls"):
    """Return the collection ``col`` of database ``db`` on a MongoDB server.

    When ``host`` is empty the ``MONGO_HOST`` environment variable is used,
    falling back to ``localhost``.
    """
    if host:
        mongo_host = host
    else:
        mongo_host = os.environ.get("MONGO_HOST", "localhost")
    client = pymongo.MongoClient(f"mongodb://{mongo_host}:{port}/")
    return client[db][col]
def strip_silence(sound):
    """Return ``sound`` with its leading and trailing silence cut off.

    Trailing silence is measured as the leading silence of the reversed
    segment, using pydub's ``detect_leading_silence`` defaults.
    """
    from pydub.silence import detect_leading_silence

    leading = detect_leading_silence(sound)
    trailing = detect_leading_silence(sound.reverse())
    keep_until = len(sound) - trailing
    return sound[leading:keep_until]
def plot_seg(wav_plot_path, audio_path):
    """Render the waveform of ``audio_path`` into a PNG at ``wav_plot_path``."""
    samples, sample_rate = librosa.load(audio_path)
    figure = plt.Figure()
    axes = figure.add_subplot()
    librosa.display.waveplot(y=samples, sr=sample_rate, ax=axes)
    with wav_plot_path.open("wb") as png_file:
        figure.set_tight_layout(True)
        figure.savefig(png_file, format="png", dpi=50)
def parallel_apply(fn, iterable, workers=8):
    """Apply ``fn`` to every item using a thread pool and return the results.

    Results keep the input order (``Executor.map``); progress is shown with
    ``tqdm``, which needs ``len(iterable)``.
    """
    with ThreadPoolExecutor(max_workers=workers) as pool:
        print(f"parallelly applying {fn}")
        progress = tqdm(
            pool.map(fn, iterable), position=0, leave=True, total=len(iterable)
        )
        return list(progress)

View File

@ -1 +0,0 @@

View File

@ -1,398 +0,0 @@
import json
import shutil
from pathlib import Path
import typer
from ..utils import (
ExtendedPath,
asr_manifest_reader,
asr_manifest_writer,
tscript_uuid_fname,
get_mongo_conn,
plot_seg,
)
# Typer application on which the validation commands below are registered.
app = typer.Typer()
def preprocess_datapoint(idx, rel_root, sample):
    """Enrich one manifest sample with the fields used by the validation UI.

    Adds the sample index, audio path, utterance id, the pretrained model's
    transcription with its word error rate against ``sample["text"]``, and
    the path of a waveform plot (rendered when missing).

    Args:
        idx: Position of the sample, stored as ``real_idx``.
        rel_root: Directory that ``sample["audio_filepath"]`` is relative to.
        sample: Manifest entry (mapping) to enrich; it is not mutated.

    Returns:
        The enriched copy, or ``None`` when processing failed (the failure
        is printed).
    """
    from pydub import AudioSegment
    from nemo.collections.asr.metrics import word_error_rate
    from jasper.client import transcribe_gen

    try:
        res = dict(sample)
        res["real_idx"] = idx
        audio_path = rel_root / Path(sample["audio_filepath"])
        res["audio_path"] = str(audio_path)
        res["utterance_id"] = audio_path.stem
        transcriber_pretrained = transcribe_gen(asr_port=8044)
        # Convert to 1 channel, 2-byte samples, 24000 frames/s before sending.
        aud_seg = (
            AudioSegment.from_file_using_temporary_files(audio_path)
            .set_channels(1)
            .set_sample_width(2)
            .set_frame_rate(24000)
        )
        res["pretrained_asr"] = transcriber_pretrained(aud_seg.raw_data)
        res["pretrained_wer"] = word_error_rate([res["text"]], [res["pretrained_asr"]])
        wav_plot_path = (
            rel_root / Path("wav_plots") / Path(audio_path.name).with_suffix(".png")
        )
        if not wav_plot_path.exists():
            plot_seg(wav_plot_path, audio_path)
        res["plot_path"] = str(wav_plot_path)
        return res
    # Exception (not BaseException) so Ctrl-C / SystemExit are not swallowed.
    except Exception as e:
        print(f'failed on {idx}: {sample["audio_filepath"]} with {e}')
        return None
@app.command()
def dump_ui(
    data_name: str = typer.Option("dataname", show_default=True),
    dataset_dir: Path = Path("./data/asr_data"),
    dump_dir: Path = Path("./data/valiation_data"),
    dump_fname: Path = typer.Option(Path("ui_dump.json"), show_default=True),
):
    """Build the validation-UI dump for an existing dataset.

    Reads ``<dataset_dir>/<data_name>/manifest.json``, converts each audio
    file to 1 channel / 2-byte samples / 24000 frames/s, feeds the samples
    to ``ui_data_generator`` and writes the records to
    ``<dataset_dir>/<data_name>/<dump_fname>`` — the location the other
    commands of this module read the dump from.

    ``dump_dir`` is accepted for command-line compatibility but is not used.
    """
    from io import BytesIO
    from pydub import AudioSegment
    from ..utils import ui_data_generator

    data_manifest_path = dataset_dir / Path(data_name) / Path("manifest.json")
    plot_dir = data_manifest_path.parent / Path("wav_plots")
    plot_dir.mkdir(parents=True, exist_ok=True)
    typer.echo(f"Using data manifest:{data_manifest_path}")

    def asr_data_source_gen():
        """Yield ``(text, duration, wav_bytes, caller, segment)`` per line."""
        with data_manifest_path.open("r") as pf:
            data_jsonl = pf.readlines()
        rel_root = data_manifest_path.parent
        for v in data_jsonl:
            sample = json.loads(v)
            audio_path = rel_root / Path(sample["audio_filepath"])
            audio_segment = (
                AudioSegment.from_file_using_temporary_files(audio_path)
                .set_channels(1)
                .set_sample_width(2)
                .set_frame_rate(24000)
            )
            wav_plot_path = (
                rel_root
                / Path("wav_plots")
                / Path(audio_path.name).with_suffix(".png")
            )
            if not wav_plot_path.exists():
                plot_seg(wav_plot_path, audio_path)
            code_fb = BytesIO()
            audio_segment.export(code_fb, format="wav")
            wav_data = code_fb.getvalue()
            duration = audio_segment.duration_seconds
            yield sample["text"], duration, wav_data, "caller", audio_segment

    dump_data, num_datapoints = ui_data_generator(
        dataset_dir, data_name, asr_data_source_gen()
    )
    # Previously written to dataset_dir / "ui_dump.json", ignoring both
    # data_name and dump_fname, so the commands reading the dump never found it.
    ui_dump_file = dataset_dir / Path(data_name) / dump_fname
    ExtendedPath(ui_dump_file).write_json({"data": dump_data})
@app.command()
def sample_ui(
    data_name: str = typer.Option("dataname", show_default=True),
    dump_dir: Path = Path("./data/asr_data"),
    dump_file: Path = Path("ui_dump.json"),
    sample_count: int = typer.Option(80, show_default=True),
    sample_file: Path = Path("sample_dump.json"),
):
    """Write a UI dump holding an equal-sized random sample per caller.

    ``sample_count`` is divided (integer division) by the number of distinct
    callers; that many records are sampled from each caller group, re-indexed
    through ``real_idx`` and saved as ``sample_file`` next to the source dump.
    """
    import pandas as pd

    data_dir = dump_dir / Path(data_name)
    processed_data = ExtendedPath(data_dir / dump_file).read_json()
    frame = pd.DataFrame(processed_data["data"])
    per_caller = sample_count // len(frame["caller"].unique())
    picked_groups = [
        group.sample(per_caller) for (_caller, group) in frame.groupby("caller")
    ]
    picked = pd.concat(picked_groups).reset_index(drop=True)
    picked["real_idx"] = picked.index
    processed_data["data"] = picked.to_dict("records")
    typer.echo(f"sampling {sample_count} datapoints")
    ExtendedPath(data_dir / sample_file).write_json(processed_data)
@app.command()
def task_ui(
    data_name: str = typer.Option("dataname", show_default=True),
    dump_dir: Path = Path("./data/asr_data"),
    dump_file: Path = Path("ui_dump.json"),
    task_count: int = typer.Option(4, show_default=True),
    task_file: str = "task_dump",
):
    """Shuffle a UI dump and split it into ``task_count`` task files.

    Each part is re-indexed through ``real_idx`` and written next to the
    source dump as ``<task_file>-<n>.json``.
    """
    import pandas as pd
    import numpy as np

    data_dir = dump_dir / Path(data_name)
    processed_data = ExtendedPath(data_dir / dump_file).read_json()
    shuffled = (
        pd.DataFrame(processed_data["data"]).sample(frac=1).reset_index(drop=True)
    )
    for task_number, part in enumerate(np.array_split(shuffled, task_count)):
        part = part.reset_index(drop=True)
        part["real_idx"] = part.index
        processed_data["data"] = part.to_dict("records")
        task_path = data_dir / Path(task_file + f"-{task_number}.json")
        ExtendedPath(task_path).write_json(processed_data)
@app.command()
def dump_corrections(
    task_uid: str,
    data_name: str = typer.Option("dataname", show_default=True),
    dump_dir: Path = Path("./data/asr_data"),
    dump_fname: Path = Path("corrections.json"),
):
    """Export the corrections of one validation task from MongoDB to JSON.

    The task is the first ``task_id`` in the ``asr_validation`` collection
    whose part after the last ``-`` equals ``task_uid``. Its ``correction``
    documents (without ``_id``) are written to
    ``<dump_dir>/<data_name>/<dump_fname>``.

    Exits with status 1 when no task matches ``task_uid``.
    """
    dump_path = dump_dir / Path(data_name) / dump_fname
    col = get_mongo_conn(col="asr_validation")
    task_id = next(
        (c for c in col.distinct("task_id") if c.rsplit("-", 1)[1] == task_uid),
        None,
    )
    if task_id is None:
        # Report the missing task instead of failing with a bare IndexError.
        typer.echo(f"no task found with uid {task_uid}")
        raise typer.Exit(code=1)
    # Single filtered query; the former unfiltered fetch was thrown away.
    cursor_obj = col.find(
        {"type": "correction", "task_id": task_id}, projection={"_id": False}
    )
    ExtendedPath(dump_path).write_json(list(cursor_obj))
@app.command()
def caller_quality(
    task_uid: str,
    data_name: str = typer.Option("dataname", show_default=True),
    dump_dir: Path = Path("./data/asr_data"),
    dump_fname: Path = Path("ui_dump.json"),
    correction_fname: Path = Path("corrections.json"),
):
    """Print, per caller, the share of samples marked ``Correct`` in a task.

    Corrections whose ``task_id`` ends with ``-<task_uid>`` are joined to the
    UI dump through ``code`` == ``utterance_id``.
    """
    import copy
    import pandas as pd

    data_dir = dump_dir / Path(data_name)
    dump_data = ExtendedPath(data_dir / dump_fname).read_json()
    correction_data = ExtendedPath(data_dir / correction_fname).read_json()
    by_utterance = {d["utterance_id"]: d for d in dump_data["data"]}

    corrected_dump = []
    for correction in correction_data:
        if correction["task_id"].rsplit("-", 1)[1] != task_uid:
            continue
        # Deep copy so the shared dump records stay untouched.
        datapoint = copy.deepcopy(by_utterance[correction["code"]])
        datapoint["valid"] = correction["value"]["status"] == "Correct"
        corrected_dump.append(datapoint)

    df = pd.DataFrame(corrected_dump)
    print(f"Total samples: {len(df)}")
    for (c, group) in df.groupby("caller"):
        total = len(group)
        valid = len(group[group["valid"] == True])
        valid_rate = valid * 100 / total
        print(f"Caller: {c} Valid%:{valid_rate:.2f} of {total} samples")
@app.command()
def fill_unannotated(
    data_name: str = typer.Option("dataname", show_default=True),
    dump_dir: Path = Path("./data/valiation_data"),
    dump_file: Path = Path("ui_dump.json"),
    corrections_file: Path = Path("corrections.json"),
):
    """Mark every sample without a correction as ``Inaudible`` in MongoDB.

    Sample codes come from the UI dump and annotated codes from the
    corrections file; each code found only in the dump is upserted into the
    ``asr_validation`` collection with status ``Inaudible``.

    The dump may be a bare list of records or a ``{"data": [...]}`` mapping;
    a record's code is its ``gold_chars`` value when present, otherwise its
    ``utterance_id``.
    NOTE(review): the ``utterance_id`` fallback assumes corrections are keyed
    by utterance id, as ``caller_quality`` does — confirm for this dump.
    """
    processed_data_path = dump_dir / Path(data_name) / dump_file
    corrections_path = dump_dir / Path(data_name) / corrections_file
    # Context managers so the file handles are closed deterministically.
    with processed_data_path.open() as dump_f:
        processed_data = json.load(dump_f)
    with corrections_path.open() as corrections_f:
        corrections = json.load(corrections_f)

    # Iterating the mapping itself would yield its keys, not the records.
    if isinstance(processed_data, dict):
        records = processed_data["data"]
    else:
        records = processed_data

    def record_code(record):
        if "gold_chars" in record:
            return record["gold_chars"]
        return record["utterance_id"]

    annotated_codes = {c["code"] for c in corrections}
    all_codes = {record_code(r) for r in records}
    unann_codes = all_codes - annotated_codes
    mongo_conn = get_mongo_conn(col="asr_validation")
    for c in unann_codes:
        mongo_conn.find_one_and_update(
            {"type": "correction", "code": c},
            {"$set": {"value": {"status": "Inaudible", "correction": ""}}},
            upsert=True,
        )
@app.command()
def split_extract(
    data_name: str = typer.Option("dataname", show_default=True),
    dump_dir: Path = Path("./data/asr_data"),
    dump_file: Path = Path("ui_dump.json"),
    manifest_file: Path = Path("manifest.json"),
    corrections_file: str = typer.Option("corrections.json", show_default=True),
    conv_data_path: Path = typer.Option(
        Path("./data/conv_data.json"), show_default=True
    ),
    extraction_type: str = "all",
):
    """Extract the samples whose text belongs to a conversation-data category.

    ``conv_data_path`` maps category keys to collections of texts. For each
    requested key a sibling dataset ``<data_name>_<key lowercased>`` is
    created holding the matching audio files, manifest, UI dump and (when
    ``corrections_file`` is non-empty) corrections.

    Args:
        extraction_type: A key of the conversation data, or ``"all"`` to
            extract every key.
    """
    data_manifest_path = dump_dir / Path(data_name) / manifest_file
    conv_data = ExtendedPath(conv_data_path).read_json()

    def extract_data_of_type(extraction_key):
        """Create the sub-dataset for one conversation-data key."""
        extraction_vals = conv_data[extraction_key]
        dest_data_name = data_name + "_" + extraction_key.lower()

        manifest_gen = asr_manifest_reader(data_manifest_path)
        dest_data_dir = dump_dir / Path(dest_data_name)
        dest_data_dir.mkdir(exist_ok=True, parents=True)
        (dest_data_dir / Path("wav")).mkdir(exist_ok=True, parents=True)
        dest_manifest_path = dest_data_dir / manifest_file
        dest_ui_path = dest_data_dir / dump_file

        def extract_manifest(mg):
            # Copy each matching audio file as its manifest entry is emitted.
            for m in mg:
                if m["text"] in extraction_vals:
                    shutil.copy(
                        m["audio_path"], dest_data_dir / Path(m["audio_filepath"])
                    )
                    yield m

        asr_manifest_writer(dest_manifest_path, extract_manifest(manifest_gen))

        ui_data_path = dump_dir / Path(data_name) / dump_file
        orig_ui_data = ExtendedPath(ui_data_path).read_json()
        ui_data = orig_ui_data["data"]
        file_ui_map = {Path(u["audio_filepath"]).stem: u for u in ui_data}
        final_data = []
        for u in ui_data:
            if u["text"] in extraction_vals:
                # Re-index so the extracted dump is numbered from zero.
                u["real_idx"] = len(final_data)
                final_data.append(u)
        orig_ui_data["data"] = final_data
        ExtendedPath(dest_ui_path).write_json(orig_ui_data)

        if corrections_file:
            dest_correction_path = dest_data_dir / corrections_file
            corrections_path = dump_dir / Path(data_name) / corrections_file
            with corrections_path.open() as corrections_f:
                corrections = json.load(corrections_f)
            extracted_corrections = [
                c
                for c in corrections
                if c["code"] in file_ui_map
                and file_ui_map[c["code"]]["text"] in extraction_vals
            ]
            ExtendedPath(dest_correction_path).write_json(extracted_corrections)

    # The option is declared as a plain str; `.value` only exists on enum
    # members, so accept either form instead of failing with AttributeError.
    extraction_key = getattr(extraction_type, "value", extraction_type)
    if extraction_key == "all":
        for ext_key in conv_data.keys():
            extract_data_of_type(ext_key)
    else:
        extract_data_of_type(extraction_key)
@app.command()
def update_corrections(
    data_name: str = typer.Option("dataname", show_default=True),
    dump_dir: Path = Path("./data/asr_data"),
    manifest_file: Path = Path("manifest.json"),
    corrections_file: Path = Path("corrections.json"),
    ui_dump_file: Path = Path("ui_dump.json"),
    skip_incorrect: bool = typer.Option(True, show_default=True),
):
    """Rewrite a dataset's manifest according to the reviewers' corrections.

    Samples marked ``Correct`` are kept. Samples marked ``Incorrect`` are
    skipped when ``skip_incorrect`` is set; otherwise their audio file is
    renamed after the corrected text and listed with that text. Audio files
    of all remaining samples are deleted. The dataset directory is copied to
    ``<name>.bkp`` first unless that backup already exists.
    """
    data_dir = dump_dir / Path(data_name)
    data_manifest_path = data_dir / manifest_file
    corrections_path = data_dir / corrections_file
    ui_dump_path = data_dir / ui_dump_file

    def correct_manifest(ui_dump_path, corrections_path):
        """Yield the surviving manifest entries; renames/deletes audio lazily."""
        corrections = ExtendedPath(corrections_path).read_json()
        ui_data = ExtendedPath(ui_dump_path).read_json()["data"]
        correct_set = set()
        correction_map = {}
        for c in corrections:
            status = c["value"]["status"]
            if status == "Correct":
                correct_set.add(c["code"])
            if status == "Incorrect":
                correction_map[c["code"]] = c["value"]["correction"]

        for d in ui_data:
            utterance_id = d["utterance_id"]
            if utterance_id in correct_set:
                yield {
                    "audio_filepath": d["audio_filepath"],
                    "duration": d["duration"],
                    "text": d["text"],
                }
                continue
            if utterance_id not in correction_map:
                # Neither correct nor corrected: remove the audio file.
                Path(d["audio_path"]).unlink()
                continue
            correct_text = correction_map[utterance_id]
            if skip_incorrect:
                print(
                    f'skipping incorrect {d["audio_path"]} corrected to {correct_text}'
                )
                continue
            orig_audio_path = Path(d["audio_path"])
            new_name = str(Path(tscript_uuid_fname(correct_text)).with_suffix(".wav"))
            orig_audio_path.replace(orig_audio_path.with_name(new_name))
            yield {
                "audio_filepath": str(Path(d["audio_filepath"]).with_name(new_name)),
                "duration": d["duration"],
                "text": correct_text,
            }

    typer.echo(f"Using data manifest:{data_manifest_path}")
    dataset_dir = data_manifest_path.parent
    backup_dir = dataset_dir.with_name(dataset_dir.name + ".bkp")
    if not backup_dir.exists():
        typer.echo(f"backing up to :{backup_dir}")
        shutil.copytree(str(dataset_dir), str(backup_dir))
    corrected_manifest = correct_manifest(ui_dump_path, corrections_path)
    # Write to a temporary name, then swap it over the original manifest.
    new_data_manifest_path = data_manifest_path.with_name("manifest.new")
    asr_manifest_writer(new_data_manifest_path, corrected_manifest)
    new_data_manifest_path.replace(data_manifest_path)
@app.command()
def clear_mongo_corrections():
    """After confirmation, delete the stored corrections and cursors.

    Removes every ``correction`` and ``current_cursor`` document from the
    ``asr_validation`` collection.
    """
    confirmed = typer.confirm("are you sure you want to clear mongo collection it?")
    if not confirmed:
        typer.echo("Aborted")
        return
    col = get_mongo_conn(col="asr_validation")
    for doc_type in ("correction", "current_cursor"):
        col.delete_many({"type": doc_type})
    typer.echo("deleted mongo collection.")
def main():
    """Entry point: run the Typer application."""
    app()


if __name__ == "__main__":
    main()

View File

@ -1,57 +0,0 @@
import os
import logging
import rpyc
from rpyc.utils.server import ThreadedServer
from .asr import JasperASR
from .utils import arg_parser
class ASRService(rpyc.Service):
    """RPyC service exposing a speech recognizer's ``transcribe`` method."""

    def __init__(self, asr_recognizer):
        self.asr = asr_recognizer

    def on_connect(self, conn):
        """Hook run when a connection is created; nothing to initialise."""

    def on_disconnect(self, conn):
        """Hook run after a connection has closed; nothing to finalise."""

    def exposed_transcribe(self, utterance: bytes):
        """Transcribe ``utterance`` and return the recognizer's result."""
        return self.asr.transcribe(utterance)

    def exposed_transcribe_cb(self, utterance: bytes, respond):
        """Transcribe ``utterance`` and pass the result to ``respond``."""
        transcript = self.asr.transcribe(utterance)
        respond(transcript)
def main():
    """Parse the command line, load the Jasper model and serve it over RPyC.

    ``--port`` defaults to the ``ASR_RPYC_PORT`` environment variable (8044
    when unset); every other parsed argument is forwarded to ``JasperASR``.
    """
    default_port = int(os.environ.get("ASR_RPYC_PORT", "8044"))
    parser = arg_parser("jasper_transcribe")
    parser.description = "jasper asr rpyc server"
    parser.add_argument(
        "--port", type=int, default=default_port, help="port to listen on"
    )
    model_kwargs = vars(parser.parse_args())
    port = model_kwargs.pop("port")
    service = ASRService(JasperASR(**model_kwargs))
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )
    logging.info("starting asr server...")
    server = ThreadedServer(service, port=port)
    server.start()


if __name__ == "__main__":
    main()

View File

@ -1 +0,0 @@

View File

@ -1,22 +0,0 @@
from pathlib import Path
from .asr import JasperASR
from .utils import arg_parser
def main():
    """Transcribe an audio file with the Jasper model from the command line.

    Besides the shared model options from ``arg_parser``, the command takes
    the audio file and ``--greedy <bool>``. The boolean is parsed from its
    text (``true``/``false``, ``yes``/``no``, ``1``/``0``, ...): argparse's
    ``type=bool`` treats every non-empty string — including ``"False"`` — as
    ``True``. Unrecognised values are rejected with a usage error.
    """
    import argparse

    def _str2bool(value):
        """Parse a command-line boolean spelled as text."""
        lowered = str(value).strip().lower()
        if lowered in ("1", "true", "t", "yes", "y", "on"):
            return True
        if lowered in ("", "0", "false", "f", "no", "n", "off"):
            return False
        raise argparse.ArgumentTypeError(f"invalid boolean value: {value!r}")

    parser = arg_parser('jasper_transcribe')
    parser.description = 'transcribe audio file to text'
    parser.add_argument(
        "audio_file",
        type=Path,
        help="audio file(16khz 1channel int16 wav) to transcribe",
    )
    parser.add_argument(
        "--greedy", type=_str2bool, default=False, help="enables greedy decoding"
    )
    args = parser.parse_args()
    args_dict = vars(args)
    audio_file = args_dict.pop("audio_file")
    greedy = args_dict.pop("greedy")
    # The remaining arguments are the model options expected by JasperASR.
    jasper_asr = JasperASR(**args_dict)
    jasper_asr.transcribe_file(audio_file, greedy)

View File

@ -1,40 +0,0 @@
import os
import argparse
from pathlib import Path
# Default locations of the model files; each one can be overridden through
# the environment variable named in its lookup.
MODEL_YAML = os.environ.get("JASPER_MODEL_CONFIG", "/models/jasper/jasper10x5dr.yaml")
CHECKPOINT_ENCODER = os.environ.get(
"JASPER_ENCODER_CHECKPOINT", "/models/jasper/JasperEncoder-STEP-265520.pt"
)
CHECKPOINT_DECODER = os.environ.get(
"JASPER_DECODER_CHECKPOINT", "/models/jasper/JasperDecoderForCTC-STEP-265520.pt"
)
# NOTE(review): KEN_LM is not referenced by arg_parser in this file, whose
# --language_model default is None — confirm whether that is intended.
KEN_LM = os.environ.get("JASPER_KEN_LM", "/models/jasper/kenlm.pt")
def arg_parser(prog):
    """Build the argument parser shared by the Jasper command-line tools.

    It declares the model options (``--model_yaml``, ``--encoder_checkpoint``,
    ``--decoder_checkpoint``, ``--language_model``), all parsed as ``Path``;
    callers add their own arguments to the returned parser.
    """
    parser = argparse.ArgumentParser(prog=prog, description="convert speech to text")
    path_options = (
        ("--model_yaml", Path(MODEL_YAML), "model config yaml file"),
        (
            "--encoder_checkpoint",
            Path(CHECKPOINT_ENCODER),
            "encoder checkpoint weights file",
        ),
        (
            "--decoder_checkpoint",
            Path(CHECKPOINT_DECODER),
            "decoder checkpoint weights file",
        ),
        ("--language_model", None, "kenlm language model file"),
    )
    for flag, default, help_text in path_options:
        parser.add_argument(flag, type=Path, default=default, help=help_text)
    return parser

23
plume/cli/__init__.py Normal file
View File

@ -0,0 +1,23 @@
import typer
from ..utils import app as utils_app
from .data import app as data_app
from ..ui import app as ui_app
from .train import app as train_app
from .eval import app as eval_app
from .serve import app as serve_app
# Root Typer application; each sub-application below is mounted under the
# given command-group name.
app = typer.Typer()
app.add_typer(data_app, name="data")
app.add_typer(ui_app, name="ui")
app.add_typer(train_app, name="train")
app.add_typer(eval_app, name="eval")
app.add_typer(serve_app, name="serve")
app.add_typer(utils_app, name='utils')
def main():
    """Entry point: run the root Typer application."""
    app()


if __name__ == "__main__":
    main()

339
plume/cli/data/__init__.py Normal file
View File

@ -0,0 +1,339 @@
import json
from pathlib import Path
# from sklearn.model_selection import train_test_split
from plume.utils import (
asr_manifest_reader,
asr_manifest_writer,
ExtendedPath,
duration_str,
generate_filter_map,
get_mongo_conn,
tscript_uuid_fname,
lazy_callable
)
from typing import List
from itertools import chain
import shutil
import typer
import soundfile
from ...models.wav2vec2.data import app as wav2vec2_app
from .generate import app as generate_app
# Looked up by dotted path through lazy_callable instead of a top-level import
# (presumably to defer importing sklearn until first use — confirm in plume.utils).
train_test_split = lazy_callable('sklearn.model_selection.train_test_split')
# Data command group, with the generate and wav2vec2 sub-applications mounted.
app = typer.Typer()
app.add_typer(generate_app, name="generate")
app.add_typer(wav2vec2_app, name="wav2vec2")
@app.command()
def fix_path(dataset_path: Path, force: bool = False):
    """Create whichever of ``manifest.json`` / ``abs_manifest.json`` is missing.

    ``abs_manifest.json`` is derived from ``manifest.json`` by making every
    ``audio_filepath`` absolute; ``manifest.json`` is derived from
    ``abs_manifest.json`` by making the paths relative to the dataset.

    Args:
        dataset_path: Dataset directory holding at least one of the manifests.
        force: Regenerate ``abs_manifest.json`` even when it already exists
            (only possible when ``manifest.json`` exists).

    Exits with status 1 when neither manifest exists.
    """
    manifest_path = dataset_path / Path("manifest.json")
    real_manifest_path = dataset_path / Path("abs_manifest.json")

    def fix_real_path():
        for i in asr_manifest_reader(manifest_path):
            i["audio_filepath"] = str(
                (dataset_path / Path(i["audio_filepath"])).absolute()
            )
            yield i

    def fix_rel_path():
        for i in asr_manifest_reader(real_manifest_path):
            audio_path = Path(i["audio_filepath"])
            try:
                rel_path = audio_path.relative_to(dataset_path)
            except ValueError:
                # Absolute entries with a relative dataset_path argument.
                rel_path = audio_path.relative_to(dataset_path.absolute())
            i["audio_filepath"] = str(rel_path)
            yield i

    has_manifest = manifest_path.exists()
    has_real_manifest = real_manifest_path.exists()
    if not has_manifest and not has_real_manifest:
        typer.echo("Invalid dataset directory")
        # Stop here: continuing would create an empty abs_manifest.json and
        # then fail while reading the missing manifest.json.
        raise typer.Exit(code=1)
    # Only derive the absolute manifest when its source exists; otherwise
    # `force` would truncate the only manifest before the read fails.
    if has_manifest and (force or not has_real_manifest):
        asr_manifest_writer(real_manifest_path, fix_real_path())
    if not has_manifest:
        asr_manifest_writer(manifest_path, fix_rel_path())
@app.command()
def augment(src_dataset_paths: List[Path], dest_dataset_path: Path):
    """
    Concatenate the abs_manifest.json of every source dataset into
    abs_manifest.json of the destination dataset (created if missing).
    """
    manifest_name = Path("abs_manifest.json")
    readers = [
        asr_manifest_reader(src_path / manifest_name)
        for src_path in src_dataset_paths
    ]
    dest_dataset_path.mkdir(parents=True, exist_ok=True)
    merged_entries = chain.from_iterable(readers)
    asr_manifest_writer(dest_dataset_path / manifest_name, merged_entries)
@app.command()
def split(dataset_path: Path, test_size: float = 0.03):
    """
    Split abs_manifest.json into train_manifest.json and test_manifest.json.

    The absolute manifest is generated through fix_path first when it does
    not exist. `test_size` is forwarded to train_test_split.
    """
    abs_manifest = dataset_path / Path("abs_manifest.json")
    if not abs_manifest.exists():
        fix_path(dataset_path)
    samples = list(asr_manifest_reader(abs_manifest))
    train_samples, test_samples = train_test_split(samples, test_size=test_size)
    outputs = (
        ("train_manifest.json", train_samples),
        ("test_manifest.json", test_samples),
    )
    for out_name, subset in outputs:
        asr_manifest_writer(abs_manifest.with_name(out_name), subset)
@app.command()
def validate(dataset_path: Path):
    """
    Check train_manifest.json and test_manifest.json of a dataset.

    Every line must be valid JSON with a "duration" and an "audio_filepath"
    whose file exists (resolved against the manifest's directory). Each
    failing line is reported with its index; the summary states how many
    lines failed instead of always claiming success.
    """
    for mf_type in ["train_manifest.json", "test_manifest.json"]:
        data_file = dataset_path / Path(mf_type)
        print(f"validating {data_file}.")
        with Path(data_file).open("r") as pf:
            pnr_jsonl = pf.readlines()
        duration = 0
        error_count = 0
        for (i, s) in enumerate(pnr_jsonl):
            # Only data problems are caught; KeyboardInterrupt/SystemExit
            # must still propagate.
            try:
                d = json.loads(s)
                duration += d["duration"]
                audio_file = data_file.parent / Path(d["audio_filepath"])
                if not audio_file.exists():
                    raise OSError(f"File {audio_file} not found")
            except (ValueError, KeyError, TypeError, OSError) as e:
                error_count += 1
                print(f'failed on {i} with "{e}"')
        audio_length = duration_str(duration)
        if error_count:
            print(
                f"found {error_count} errors in {mf_type}. contains {audio_length} of audio"
            )
        else:
            print(
                f"no errors found. seems like a valid {mf_type}. contains {audio_length} of audio"
            )
@app.command()
def filter(src_dataset_path: Path, dest_dataset_path: Path, kind: str = "skip_dur"):
    """
    Write a filtered copy of the source manifest.json into the destination
    dataset, using the filter named `kind` from generate_filter_map.

    The destination "wavs" directory is created up front. An unknown `kind`
    only prints the available filter names.
    """
    src_manifest = src_dataset_path / Path("manifest.json")
    out_manifest = dest_dataset_path / Path("manifest.json")
    (dest_dataset_path / Path("wavs")).mkdir(exist_ok=True, parents=True)
    available_filters = generate_filter_map(
        src_dataset_path, dest_dataset_path, src_manifest
    )
    chosen = available_filters.get(kind, None)
    if not chosen:
        typer.echo(f"filter kind - {kind} not implemented")
        typer.echo(f"select one of {', '.join(available_filters.keys())}")
        return
    asr_manifest_writer(out_manifest, chosen())
@app.command()
def info(dataset_path: Path):
    """
    Print duration statistics for each manifest found in the dataset.

    Looks for manifest.json, abs_manifest.json, train_manifest.json and
    test_manifest.json. For every one that exists it reports the manifest
    ("duration" field) and real (read from the audio file) durations, the
    longest file, and how many samples have empty text.
    """
    for k in ["", "abs_", "train_", "test_"]:
        data_file = dataset_path / Path(f"{k}manifest.json")
        if not data_file.exists():
            continue
        mf_wav_duration = real_duration = max_duration = 0
        empty_duration = empty_count = total_count = 0
        print(f"stats on {data_file}")
        for s in ExtendedPath(data_file).read_jsonl():
            total_count += 1
            mf_wav_duration += s["duration"]
            if s["text"] == "":
                empty_count += 1
                empty_duration += s["duration"]
            wav_path = str(dataset_path / Path(s["audio_filepath"]))
            # Read the audio header once per sample instead of three times.
            wav_duration = soundfile.info(wav_path).duration
            max_duration = max(max_duration, wav_duration)
            real_duration += wav_duration
            # frame_count = soundfile.info(audio_fname).frames
        # An empty manifest must not raise ZeroDivisionError.
        empty_pct = empty_count * 100 / total_count if total_count else 0.0
        print(f"max audio duration : {duration_str(max_duration)}")
        print(f"total audio duration : {duration_str(mf_wav_duration)}")
        print(f"total real audio duration : {duration_str(real_duration)}")
        print(
            f"total content duration : {duration_str(mf_wav_duration-empty_duration)}"
        )
        print(f"total empty duration : {duration_str(empty_duration)}")
        print(
            f"total empty samples : {empty_count}/{total_count} ({empty_pct:.2f}%)"
        )
@app.command()
def audio_duration(dataset_path: Path):
    """Print the combined duration of every .wav file under the dataset directory."""
    total_seconds = 0
    wav_files = dataset_path.absolute().glob("**/*.wav")
    for wav_file in wav_files:
        total_seconds += soundfile.info(str(wav_file)).duration
    summary = duration_str(total_seconds)
    typer.echo(f"duration of wav files @ {dataset_path}: {summary}")
@app.command()
def migrate(src_path: Path, dest_path: Path):
    """
    Copy a dataset to `dest_path` and gather its audio under dest/wavs.

    The copied abs_manifest.json is backed up as abs_manifest.json.orig,
    every referenced audio file is copied into the "wavs" directory, the
    manifest is rewritten to point at the copies, and fix_path is run on
    the new dataset.
    """
    shutil.copytree(str(src_path), str(dest_path))
    wav_dir = dest_path / Path("wavs")
    wav_dir.mkdir(exist_ok=True, parents=True)
    manifest = ExtendedPath(dest_path / Path("abs_manifest.json"))
    shutil.copy(manifest, manifest.with_suffix(".json.orig"))
    entries = list(manifest.read_jsonl())
    for entry in entries:
        source_wav = Path(entry["audio_filepath"])
        copied_wav = wav_dir / Path(source_wav.name)
        shutil.copy(source_wav, copied_wav)
        entry["audio_filepath"] = str(copied_wav)
    manifest.write_jsonl(entries)
    fix_path(dest_path)
@app.command()
def task_split(
    data_dir: Path,
    dump_file: Path = Path("ui_dump.json"),
    task_count: int = typer.Option(2, show_default=True),
    task_file: str = "task_dump",
    sort: bool = True,
):
    """
    split ui_dump.json to `task_count` tasks

    The entries are shuffled, divided into `task_count` chunks and written
    to `<task_file>-<n>.json`; with `sort` each chunk is ordered by
    descending "asr_wer".
    """
    import pandas as pd
    import numpy as np

    dump = ExtendedPath(data_dir / dump_file).read_json()
    shuffled = pd.DataFrame(dump["data"]).sample(frac=1).reset_index(drop=True)
    for task_no, chunk in enumerate(np.array_split(shuffled, task_count)):
        chunk = chunk.reset_index(drop=True)
        # Position of the sample inside its own task file.
        chunk["real_idx"] = chunk.index
        records = chunk.to_dict("records")
        if sort:
            records.sort(key=lambda rec: rec["asr_wer"], reverse=True)
        dump["data"] = records
        out_path = data_dir / Path(task_file + f"-{task_no}.json")
        ExtendedPath(out_path).write_json(dump)
def get_corrections(task_uid):
    """
    Return the correction documents stored for a task.

    `task_uid` may be a full task id or only the part after its last "-".
    The first matching task id found in the "asr_validation" collection is
    used.

    Raises:
        IndexError: when no stored task id matches `task_uid` (same
            exception type as before, now with a message).
    """
    col = get_mongo_conn(col="asr_validation")
    # [-1] instead of [1] so task ids without a "-" no longer raise.
    matching_ids = [
        c
        for c in col.distinct("task_id")
        if c == task_uid or c.rsplit("-", 1)[-1] == task_uid
    ]
    if not matching_ids:
        raise IndexError(f"no task_id matching {task_uid!r} found")
    # Only the task-filtered query is needed; the former unfiltered query
    # fetched every correction and discarded the result.
    cursor_obj = col.find(
        {"type": "correction", "task_id": matching_ids[0]},
        projection={"_id": False},
    )
    return list(cursor_obj)
@app.command()
def dump_task_corrections(data_dir: Path, task_uid: str):
    """Write the corrections of one task to corrections-<task_uid>.json in `data_dir`."""
    target = data_dir / Path(f"corrections-{task_uid}.json")
    task_corrections = get_corrections(task_uid)
    ExtendedPath(target).write_json(task_corrections)
@app.command()
def dump_all_corrections(data_dir: Path):
    """Dump corrections for every task that has a task-*.lck file in `data_dir`."""
    lock_files = data_dir.glob('task-*.lck')
    for lock_file in lock_files:
        uid = lock_file.stem.replace('task-', '')
        dump_task_corrections(data_dir, uid)
@app.command()
def update_corrections(
    data_dir: Path,
    # NOTE(review): the two help texts below do not obviously describe what
    # the flags do in the code (both delete the audio file and drop the
    # entry) — confirm the wording.
    skip_incorrect: bool = typer.Option(
        False, show_default=True, help="treats incorrect as invalid"
    ),
    skip_inaudible: bool = typer.Option(
        False, show_default=True, help="include invalid as blank target"
    ),
):
    """
    applies the corrections-*.json
    backup the original dataset

    The dataset directory is copied to "<name>.bkp" (only if that backup
    does not exist yet), then ui_dump.json, manifest.json and
    rename_map.json are rewritten in place from the corrected entries.
    Audio files of corrected utterances are renamed on disk; audio files of
    skipped utterances are deleted.
    """
    manifest_file: Path = Path("manifest.json")
    renames_file: Path = Path("rename_map.json")
    ui_dump_file: Path = Path("ui_dump.json")
    data_manifest_path = data_dir / manifest_file
    renames_path = data_dir / renames_file

    def correct_ui_dump(data_dir, rename_result):
        # Generator over the ui_dump entries that survive correction.
        # `rename_result` is filled in place with one record per renamed file.
        ui_dump_path = data_dir / ui_dump_file
        # corrections_path = data_dir / Path("corrections.json")
        # Flatten every corrections-*.json file into one list.
        corrections = [
            t
            for p in data_dir.glob("corrections-*.json")
            for t in ExtendedPath(p).read_json()
        ]
        ui_data = ExtendedPath(ui_dump_path).read_json()["data"]
        # Utterance codes whose text was confirmed as-is.
        correct_set = {
            c["code"] for c in corrections if c["value"]["status"] == "Correct"
        }
        # Utterance code -> replacement text for entries marked "Incorrect".
        correction_map = {
            c["code"]: c["value"]["correction"]
            for c in corrections
            if c["value"]["status"] == "Incorrect"
        }
        for d in ui_data:
            orig_audio_path = (data_dir / Path(d["audio_path"])).absolute()
            if d["utterance_id"] in correct_set:
                d["corrected_from"] = d["text"]
                yield d
            elif d["utterance_id"] in correction_map:
                correct_text = correction_map[d["utterance_id"]]
                if skip_incorrect:
                    # Drop the utterance entirely: delete audio, yield nothing.
                    ap = d["audio_path"]
                    print(f"skipping incorrect {ap} corrected to {correct_text}")
                    orig_audio_path.unlink()
                else:
                    # The file name is derived from the corrected text, so the
                    # audio is renamed on disk and the rename is recorded.
                    new_fname = tscript_uuid_fname(correct_text)
                    rename_result[new_fname] = {
                        "orig_text": d["text"],
                        "correct_text": correct_text,
                        "orig_id": d["utterance_id"],
                    }
                    new_name = str(Path(new_fname).with_suffix(".wav"))
                    new_audio_path = orig_audio_path.with_name(new_name)
                    orig_audio_path.replace(new_audio_path)
                    new_filepath = str(Path(d["audio_path"]).with_name(new_name))
                    d["corrected_from"] = d["text"]
                    d["text"] = correct_text
                    d["audio_path"] = new_filepath
                    yield d
            else:
                # Reached for every utterance that is neither "Correct" nor
                # "Incorrect" — including utterances with no correction
                # entry at all.
                if skip_inaudible:
                    orig_audio_path.unlink()
                else:
                    # Keep the audio but blank the target text.
                    d["corrected_from"] = d["text"]
                    d["text"] = ""
                    yield d

    dataset_dir = data_manifest_path.parent
    dataset_name = dataset_dir.name
    backup_dir = dataset_dir.with_name(dataset_name + ".bkp")
    # An existing backup is never overwritten.
    if not backup_dir.exists():
        typer.echo(f"backing up to {backup_dir}")
        shutil.copytree(str(dataset_dir), str(backup_dir))
    renames = {}
    # Materialised before writing: the generator renames/deletes files and
    # reads ui_dump.json, which is overwritten on the next line.
    corrected_ui_dump = list(correct_ui_dump(data_dir, renames))
    # Only the "data" key is written back; any other top-level keys of the
    # original ui_dump.json are not preserved.
    ExtendedPath(data_dir / ui_dump_file).write_json({"data": corrected_ui_dump})
    corrected_manifest = (
        {
            "audio_filepath": d["audio_path"],
            "duration": d["duration"],
            "text": d["text"],
        }
        for d in corrected_ui_dump
    )
    asr_manifest_writer(data_manifest_path, corrected_manifest)
    ExtendedPath(renames_path).write_json(renames)
def main():
    """Run the dataset command group."""
    app()


# Allow running this module directly as a script.
if __name__ == "__main__":
    main()

View File

@ -0,0 +1,12 @@
from pathlib import Path
import typer
from ...utils.tts import GoogleTTS
# Command group for dataset generation.
app = typer.Typer()


@app.command()
def tts_dataset(dest_path: Path):
    """Placeholder: builds a GoogleTTS client but does not generate anything yet."""
    # `dest_path` is currently unused.
    tts = GoogleTTS()
    pass

5
plume/cli/eval.py Normal file
View File

@ -0,0 +1,5 @@
import typer
from ..models.wav2vec2.eval import app as wav2vec2_app
# Evaluation command group; the wav2vec2 evaluation commands are mounted
# under the "wav2vec2" name.
app = typer.Typer()
app.add_typer(wav2vec2_app, name="wav2vec2")

7
plume/cli/serve.py Normal file
View File

@ -0,0 +1,7 @@
import typer
from ..models.wav2vec2.serve import app as wav2vec2_app
from ..models.jasper.serve import app as jasper_app
# Serving command group; one sub-group per model family.
app = typer.Typer()
app.add_typer(wav2vec2_app, name="wav2vec2")
app.add_typer(jasper_app, name="jasper")

5
plume/cli/train.py Normal file
View File

@ -0,0 +1,5 @@
import typer
from ..models.wav2vec2.train import app as train_app
# Training command group; the wav2vec2 training commands are mounted under
# the "wav2vec2" name.
app = typer.Typer()
app.add_typer(train_app, name="wav2vec2")

1
plume/models/__init__.py Normal file
View File

@ -0,0 +1 @@
# from . import jasper, wav2vec2, matchboxnet

View File

View File

@ -0,0 +1,24 @@
from pathlib import Path
import typer
# Command group for this module.
app = typer.Typer()


@app.command()
def set_root(dataset_path: Path, root_path: Path):
    """Placeholder command: currently does nothing with its arguments."""
    pass
    # Disabled implementation, kept for reference: it would replace the
    # first line of train.tsv / valid.tsv with `root_path`.
    # for dataset_kind in ["train", "valid"]:
    # data_file = dataset_path / Path(dataset_kind).with_suffix(".tsv")
    # with data_file.open("r") as df:
    # lines = df.readlines()
    # with data_file.open("w") as df:
    # lines[0] = str(root_path) + "\n"
    # df.writelines(lines)
def main():
    """Run this module's command group."""
    app()


# Allow running this module directly as a script.
if __name__ == "__main__":
    main()

View File

@ -45,7 +45,7 @@ def parse_args():
eval_freq=100,
load_dir="./train/models/jasper/",
warmup_steps=3,
exp_name="jasper-speller",
exp_name="jasper",
)
# Overwrite default args

View File

@ -0,0 +1,52 @@
import os
import logging
from pathlib import Path
from rpyc.utils.server import ThreadedServer
import typer
# from .asr import JasperASR
from ...utils.serve import ASRService
from plume.utils import lazy_callable
# JasperASR is bound through lazy_callable instead of the direct import
# commented out above (presumably deferring the model dependencies until the
# class is first called — confirm against plume.utils.lazy_import).
JasperASR = lazy_callable('plume.models.jasper.asr.JasperASR')
# Command group for serving the Jasper model.
app = typer.Typer()
@app.command()
def rpyc(
    encoder_path: Path = "/path/to/encoder.pt",
    decoder_path: Path = "/path/to/decoder.pt",
    model_yaml_path: Path = "/path/to/model.yaml",
    port: int = int(os.environ.get("ASR_RPYC_PORT", "8044")),
):
    """
    Serve a Jasper ASR model over rpyc on `port`.

    Returns without starting the server when any of the three model files
    is missing. The default port is taken from ASR_RPYC_PORT.
    """
    required_files = (encoder_path, decoder_path, model_yaml_path)
    for candidate in required_files:
        if candidate.exists():
            continue
        logging.info(f"{candidate} doesn't exists")
        return
    model = JasperASR(str(model_yaml_path), str(encoder_path), str(decoder_path))
    asr_service = ASRService(model)
    log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    logging.basicConfig(level=logging.INFO, format=log_format)
    logging.info("starting asr server...")
    server = ThreadedServer(asr_service, port=port)
    server.start()
@app.command()
def rpyc_dir(model_dir: Path, port: int = int(os.environ.get("ASR_RPYC_PORT", "8044"))):
    """
    Serve the Jasper model stored in `model_dir` over rpyc.

    Expects encoder.pt, decoder.pt and model.yaml inside the directory and
    delegates to `rpyc`.
    """
    # Each checkpoint variable now points at its own file; the names were
    # previously crossed (encoder_path -> decoder.pt and vice versa).
    encoder_path = model_dir / Path("encoder.pt")
    decoder_path = model_dir / Path("decoder.pt")
    model_yaml_path = model_dir / Path("model.yaml")
    rpyc(encoder_path, decoder_path, model_yaml_path, port)
def main():
    """Run the Jasper serving command group."""
    app()


# Allow running this module directly as a script.
if __name__ == "__main__":
    main()

View File

View File

View File

@ -0,0 +1,204 @@
from io import BytesIO
import warnings
import itertools as it
import torch
import soundfile as sf
import torch.nn.functional as F
try:
from fairseq import utils
from fairseq.models import BaseFairseqModel
from fairseq.data import Dictionary
from fairseq.models.wav2vec.wav2vec2_asr import base_architecture, Wav2VecEncoder
except ModuleNotFoundError:
warnings.warn("Install fairseq")
try:
from wav2letter.decoder import CriterionType
from wav2letter.criterion import CpuViterbiPath, get_data_ptr_as_bytes
except ModuleNotFoundError:
warnings.warn("Install wav2letter")
class Wav2VecCtc(BaseFairseqModel):
    """Fairseq model that wraps a Wav2VecEncoder and exposes its output for CTC decoding."""

    def __init__(self, w2v_encoder, args):
        super().__init__()
        self.w2v_encoder = w2v_encoder
        self.args = args

    def upgrade_state_dict_named(self, state_dict, name):
        """Run the parent upgrade hook and hand back the same state dict."""
        super().upgrade_state_dict_named(state_dict, name)
        return state_dict

    @classmethod
    def build_model(cls, args, target_dict):
        """Build a new model instance."""
        base_architecture(args)
        encoder = Wav2VecEncoder(args, target_dict)
        return cls(encoder, args)

    def get_normalized_probs(self, net_output, log_probs):
        """Get normalized probabilities (or log probs) from a net's output."""
        logits = net_output["encoder_out"]
        normalize = utils.log_softmax if log_probs else utils.softmax
        return normalize(logits.float(), dim=-1)

    def forward(self, **kwargs):
        """Forward all keyword arguments to the wrapped encoder."""
        return self.w2v_encoder(**kwargs)
class W2lDecoder(object):
    """Base CTC decoder: produces emissions from a model and cleans up token paths."""

    def __init__(self, tgt_dict):
        self.tgt_dict = tgt_dict
        self.vocab_size = len(tgt_dict)
        self.nbest = 1
        self.criterion_type = CriterionType.CTC
        # Use the dedicated blank symbol when the dictionary has one,
        # otherwise fall back to the dictionary's bos index.
        if "<ctc_blank>" in tgt_dict.indices:
            self.blank = tgt_dict.index("<ctc_blank>")
        else:
            self.blank = tgt_dict.bos()
        self.asg_transitions = None

    def generate(self, model, sample, **unused):
        """Generate a batch of inferences."""
        # model.forward normally channels prev_output_tokens into the decoder
        # separately, but SequenceGenerator directly calls model.encoder
        encoder_input = dict(sample["net_input"])
        encoder_input.pop("prev_output_tokens", None)
        return self.decode(self.get_emissions(model, encoder_input))

    def get_emissions(self, model, encoder_input):
        """Run encoder and normalize emissions"""
        net_output = model(**encoder_input)
        if self.criterion_type == CriterionType.CTC:
            emissions = model.get_normalized_probs(net_output, log_probs=True)
        swapped = emissions.transpose(0, 1)
        return swapped.float().cpu().contiguous()

    def get_tokens(self, idxs):
        """Normalize tokens by handling CTC blank, ASG replabels, etc."""
        # Collapse runs of repeated indices, then drop the blank index.
        collapsed = [token for token, _run in it.groupby(idxs)]
        kept = [token for token in collapsed if token != self.blank]
        return torch.LongTensor(kept)
class W2lViterbiDecoder(W2lDecoder):
    """Decoder that picks the best path with wav2letter's CpuViterbiPath."""

    def __init__(self, tgt_dict):
        super().__init__(tgt_dict)

    def decode(self, emissions):
        """
        Decode a 3-D emissions tensor (unpacked as B, T, N) into one
        single-hypothesis list per batch item.
        """
        batch, frames, vocab = emissions.size()
        if self.asg_transitions is not None:
            transitions = torch.FloatTensor(self.asg_transitions).view(vocab, vocab)
        else:
            transitions = torch.FloatTensor(vocab, vocab).zero_()
        # Output and scratch buffers written to by CpuViterbiPath.compute.
        best_path = torch.IntTensor(batch, frames)
        scratch = torch.ByteTensor(
            CpuViterbiPath.get_workspace_size(batch, frames, vocab)
        )
        CpuViterbiPath.compute(
            batch,
            frames,
            vocab,
            get_data_ptr_as_bytes(emissions),
            get_data_ptr_as_bytes(transitions),
            get_data_ptr_as_bytes(best_path),
            get_data_ptr_as_bytes(scratch),
        )
        results = []
        for item in range(batch):
            tokens = self.get_tokens(best_path[item].tolist())
            results.append([{"tokens": tokens, "score": 0}])
        return results
def post_process(sentence: str, symbol: str):
    """
    Turn a space-separated token string back into plain text.

    For the known schemes ("sentencepiece", "wordpiece", "letter", "_EOW")
    all spaces are removed and the scheme's word-boundary marker becomes a
    space. Any other symbol (except None / "none") is simply removed from
    the sentence.
    """
    boundary_markers = (
        ("sentencepiece", "\u2581"),
        ("wordpiece", "_"),
        ("letter", "|"),
        ("_EOW", "_EOW"),
    )
    for scheme, marker in boundary_markers:
        if symbol == scheme:
            return sentence.replace(" ", "").replace(marker, " ").strip()
    if symbol is None or symbol == "none":
        return sentence
    return (sentence + " ").replace(symbol, "").rstrip()
def get_feature(filepath):
    """
    Read an audio file and return a layer-normalised 1-D float tensor.

    `filepath` is passed straight to soundfile.read. Two-dimensional
    (multi-channel) audio is averaged over its last axis; the tensor is
    moved to the GPU when CUDA is available.
    """

    def postprocess(feats, sample_rate):
        # `dim` is a method: the previous `feats.dim == 2` compared the bound
        # method with an int, so multi-channel input was never down-mixed.
        if feats.dim() == 2:
            feats = feats.mean(-1)
        assert feats.dim() == 1, feats.dim()
        with torch.no_grad():
            feats = F.layer_norm(feats, feats.shape)
        return feats

    wav, sample_rate = sf.read(filepath)
    feats = torch.from_numpy(wav).float()
    if torch.cuda.is_available():
        feats = feats.cuda()
    feats = postprocess(feats, sample_rate)
    return feats
def load_model(ctc_model_path, w2v_model_path, target_dict):
    """
    Build a Wav2VecCtc model from a fine-tuned CTC checkpoint.

    The checkpoint's stored args are pointed at `w2v_model_path` before the
    model is built, the checkpoint weights are loaded strictly, and the
    model is moved to the GPU when CUDA is available.
    """
    # Load onto the CPU first so a checkpoint saved from a GPU can also be
    # opened on a CPU-only host; the model is moved to CUDA below if present.
    w2v = torch.load(ctc_model_path, map_location="cpu")
    w2v["args"].w2v_path = w2v_model_path
    model = Wav2VecCtc.build_model(w2v["args"], target_dict)
    model.load_state_dict(w2v["model"], strict=True)
    if torch.cuda.is_available():
        model = model.cuda()
    return model
class Wav2Vec2ASR(object):
    """Speech-to-text wrapper around a Wav2VecCtc model and a Viterbi decoder."""

    def __init__(self, ctc_path, w2v_path, target_dict_path):
        super().__init__()
        self.target_dict = Dictionary.load(target_dict_path)
        self.model = load_model(ctc_path, w2v_path, self.target_dict)
        self.model.eval()
        self.generator = W2lViterbiDecoder(self.target_dict)

    def transcribe(self, audio_data, greedy=True):
        """
        Transcribe audio bytes and return the decoded text.

        `greedy` is accepted but not used.
        """
        feature = get_feature(BytesIO(audio_data))
        source = feature.unsqueeze(0)
        # All-False mask: a single utterance has no padded positions.
        padding_mask = torch.BoolTensor(source.size(1)).fill_(False).unsqueeze(0)
        if torch.cuda.is_available():
            padding_mask = padding_mask.cuda()
        sample = {"net_input": {"source": source, "padding_mask": padding_mask}}
        with torch.no_grad():
            hypo = self.generator.generate(self.model, sample, prefix_tokens=None)
        best_tokens = hypo[0][0]["tokens"].int().cpu()
        hyp_pieces = self.target_dict.string(best_tokens)
        return post_process(hyp_pieces, "letter")

View File

@ -0,0 +1,86 @@
from pathlib import Path
from collections import Counter
import shutil
import soundfile
# import pydub
import typer
from tqdm import tqdm
from plume.utils import (
ExtendedPath,
replace_redundant_spaces_with,
lazy_module
)
# pydub is bound through lazy_module instead of the direct import commented
# out above (presumably imported on first attribute access — confirm against
# plume.utils.lazy_import).
pydub = lazy_module('pydub')
# Command group for wav2vec2 dataset preparation.
app = typer.Typer()
@app.command()
def export_jasper(src_dataset_path: Path, dest_dataset_path: Path, unlink: bool = True):
    """
    Export a manifest-based dataset into train/valid .tsv and .ltr files
    plus a dict.ltr.txt token count file in `dest_dataset_path`.

    With `unlink` (the default) every wav under the source "wavs" directory
    is re-exported as 16 kHz mono into the destination, and the .tsv root
    line points at the destination; otherwise it points at the source.
    """
    dict_ltr = dest_dataset_path / Path("dict.ltr.txt")
    (dest_dataset_path / Path("wavs")).mkdir(exist_ok=True, parents=True)
    tok_counter = Counter()
    # The test manifest doubles as the "valid" split: a temporary copy is
    # created here and removed on the last line of this function.
    shutil.copy(
        src_dataset_path / Path("test_manifest.json"),
        src_dataset_path / Path("valid_manifest.json"),
    )
    if unlink:
        src_wavs = src_dataset_path / Path("wavs")
        for wav_path in tqdm(list(src_wavs.glob("**/*.wav"))):
            audio_seg = (
                pydub.AudioSegment.from_wav(wav_path)
                .set_frame_rate(16000)
                .set_channels(1)
            )
            # Files are flattened into dest/wavs by file name only.
            dest_path = dest_dataset_path / Path("wavs") / Path(wav_path.name)
            audio_seg.export(dest_path, format="wav")
    for dataset_kind in ["train", "valid"]:
        abs_manifest_path = ExtendedPath(
            src_dataset_path / Path(f"{dataset_kind}_manifest.json")
        )
        manifest_data = list(abs_manifest_path.read_jsonl())
        o_tsv, o_ltr = f"{dataset_kind}.tsv", f"{dataset_kind}.ltr"
        out_tsv = dest_dataset_path / Path(o_tsv)
        out_ltr = dest_dataset_path / Path(o_ltr)
        with out_tsv.open("w") as tsv_f, out_ltr.open("w") as ltr_f:
            # First line of the .tsv is the root directory of the audio.
            if unlink:
                tsv_f.write(f"{dest_dataset_path}\n")
            else:
                tsv_f.write(f"{src_dataset_path}\n")
            for md in manifest_data:
                audio_fname = md["audio_filepath"]
                # Runs of spaces become a single "|" and text is upper-cased.
                pipe_toks = replace_redundant_spaces_with(md["text"], "|").upper()
                # pipe_toks = "|".join(re.sub(" ", "", md["text"]))
                # pipe_toks = alnum_to_asr_tokens(md["text"]).upper().replace(" ", "|")
                # Counter.update on a string counts individual characters.
                tok_counter.update(pipe_toks)
                letter_toks = " ".join(pipe_toks) + " |\n"
                # NOTE(review): the frame count is read from the source file
                # named in the manifest, not from the 16 kHz copy exported
                # above — confirm the two always agree.
                frame_count = soundfile.info(audio_fname).frames
                rel_path = Path(audio_fname).relative_to(src_dataset_path.absolute())
                ltr_f.write(letter_toks)
                tsv_f.write(f"{rel_path}\t{frame_count}\n")
    # One "<token> <count>" line per character, most frequent first.
    with dict_ltr.open("w") as d_f:
        for k, v in tok_counter.most_common():
            d_f.write(f"{k} {v}\n")
    # Remove the temporary valid manifest (not reached if an error occurs
    # earlier in the function).
    (src_dataset_path / Path("valid_manifest.json")).unlink()
@app.command()
def set_root(dataset_path: Path, root_path: Path):
    """Replace the first line of train.tsv and valid.tsv with `root_path`."""
    for split_name in ("train", "valid"):
        tsv_path = dataset_path / Path(split_name).with_suffix(".tsv")
        with tsv_path.open("r") as source:
            rows = source.readlines()
        with tsv_path.open("w") as target:
            rows[0] = str(root_path) + "\n"
            target.writelines(rows)
def main():
    """Run the wav2vec2 dataset command group."""
    app()


# Allow running this module directly as a script.
if __name__ == "__main__":
    main()

View File

@ -0,0 +1,49 @@
from pathlib import Path
import typer
from tqdm import tqdm
# import pandas as pd
from plume.utils import (
asr_manifest_reader,
discard_except_digits,
replace_digit_symbol,
lazy_module
# run_shell,
)
from ...utils.transcribe import triton_transcribe_grpc_gen
# pandas is bound through lazy_module instead of the direct import commented
# out above (presumably imported on first attribute access — confirm against
# plume.utils.lazy_import).
pd = lazy_module('pandas')
# Command group for wav2vec2 evaluation.
app = typer.Typer()
@app.command()
def manifest(manifest_file: Path, result_file: Path = "results.csv"):
    """
    Transcribe every entry of a manifest and write a comparison CSV.

    Each row holds the reference text, the transcription, the digits
    extracted from both, and whether the digit strings match. The CSV is
    written next to the manifest. The transcription service is expected at
    localhost:8044.
    """
    from pydub import AudioSegment

    transcriber, audio_prep = triton_transcribe_grpc_gen(
        "localhost", 8044, method='whole'
    )
    result_path = manifest_file.parent / result_file
    entries = list(asr_manifest_reader(manifest_file))

    def digits_of(text):
        return discard_except_digits(replace_digit_symbol(text))

    def compute_frame(d):
        audio_file = d["audio_path"]
        orig_text = d["text"]
        orig_num = digits_of(orig_text)
        segment = AudioSegment.from_file(audio_file)
        asr_text = transcriber(audio_prep(segment))
        asr_num = digits_of(asr_text)
        return {
            "audio_file": audio_file,
            "asr_text": asr_text,
            "asr_num": asr_num,
            "orig_text": orig_text,
            "orig_num": orig_num,
            "asr_match": orig_num == asr_num,
        }

    # Rows are produced lazily while the DataFrame consumes the iterator.
    rows = map(compute_frame, tqdm(entries))
    pd.DataFrame(rows).to_csv(result_path)

View File

@ -0,0 +1,53 @@
import os
import logging
from pathlib import Path
# from rpyc.utils.server import ThreadedServer
import typer
from ...utils.serve import ASRService
from plume.utils import lazy_callable
# from .asr import Wav2Vec2ASR
# The rpyc server class and the ASR model class are bound through
# lazy_callable instead of the direct imports commented out above
# (presumably deferring those imports until first call — confirm against
# plume.utils.lazy_import).
ThreadedServer = lazy_callable('rpyc.utils.server.ThreadedServer')
Wav2Vec2ASR = lazy_callable('plume.models.wav2vec2.asr.Wav2Vec2ASR')
# Command group for serving the wav2vec2 model.
app = typer.Typer()
@app.command()
def rpyc(
    w2v_path: Path = "/path/to/base.pt",
    ctc_path: Path = "/path/to/ctc.pt",
    target_dict_path: Path = "/path/to/dict.ltr.txt",
    port: int = int(os.environ.get("ASR_RPYC_PORT", "8044")),
):
    """
    Serve a wav2vec2 ASR model over rpyc on `port`.

    Returns without starting the server when any of the three model files
    is missing. The default port is taken from ASR_RPYC_PORT.
    """
    required_files = (w2v_path, ctc_path, target_dict_path)
    for candidate in required_files:
        if candidate.exists():
            continue
        logging.info(f"{candidate} doesn't exists")
        return
    model = Wav2Vec2ASR(str(ctc_path), str(w2v_path), str(target_dict_path))
    asr_service = ASRService(model)
    log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    logging.basicConfig(level=logging.INFO, format=log_format)
    logging.info("starting asr server...")
    server = ThreadedServer(asr_service, port=port)
    server.start()
@app.command()
def rpyc_dir(model_dir: Path, port: int = int(os.environ.get("ASR_RPYC_PORT", "8044"))):
    """
    Serve the wav2vec2 model stored in `model_dir` over rpyc.

    Expects base.pt, ctc.pt and dict.ltr.txt inside the directory and
    delegates to `rpyc`.
    """
    base_checkpoint = model_dir / Path("base.pt")
    ctc_checkpoint = model_dir / Path("ctc.pt")
    dictionary_file = model_dir / Path("dict.ltr.txt")
    rpyc(base_checkpoint, ctc_checkpoint, dictionary_file, port)
def main():
    """Run the wav2vec2 serving command group."""
    app()


# Allow running this module directly as a script.
if __name__ == "__main__":
    main()

View File

@ -0,0 +1,34 @@
import typer
# from fairseq_cli.train import cli_main
import sys
from pathlib import Path
import shlex
from plume.utils import lazy_callable
# fairseq's training entry point is bound through lazy_callable instead of
# the direct import commented out above (presumably deferring the fairseq
# import until first call — confirm against plume.utils.lazy_import).
cli_main = lazy_callable('fairseq_cli.train.cli_main')
# Command group for wav2vec2 training.
app = typer.Typer()
@app.command()
def local(dataset_path: Path):
    """
    Run fairseq training on `dataset_path` with a fixed set of arguments.

    Only the dataset path is configurable; every other option, including
    the --save-dir and --w2v-path locations, is hard-coded in the argument
    string below.
    """
    # The trailing backslashes join the lines into one command string.
    args = f'''--distributed-world-size 1 {dataset_path} \
--save-dir /dataset/wav2vec2/model/wav2vec2_l_num_ctc_v1 --post-process letter --valid-subset \
valid --no-epoch-checkpoints --best-checkpoint-metric wer --num-workers 4 --max-update 80000 \
--sentence-avg --task audio_pretraining --arch wav2vec_ctc --w2v-path /dataset/wav2vec2/pretrained/wav2vec_vox_new.pt \
--labels ltr --apply-mask --mask-selection static --mask-other 0 --mask-length 10 --mask-prob 0.5 --layerdrop 0.1 \
--mask-channel-selection static --mask-channel-other 0 --mask-channel-length 64 --mask-channel-prob 0.5 \
--zero-infinity --feature-grad-mult 0.0 --freeze-finetune-updates 10000 --validate-after-updates 10000 \
--optimizer adam --adam-betas '(0.9, 0.98)' --adam-eps 1e-08 --lr 2e-05 --lr-scheduler tri_stage --warmup-steps 8000 \
--hold-steps 32000 --decay-steps 40000 --final-lr-scale 0.05 --final-dropout 0.0 --dropout 0.0 \
--activation-dropout 0.1 --criterion ctc --attention-dropout 0.0 --max-tokens 1280000 --seed 2337 --log-format json \
--log-interval 500 --ddp-backend no_c10d --reset-optimizer --normalize
'''
    # cli_main is called without arguments, so the options are handed over
    # by replacing sys.argv; 'train.py' stands in for the program name.
    new_args = ['train.py']
    # shlex keeps the quoted "(0.9, 0.98)" value together as one argument.
    new_args.extend(shlex.split(args))
    sys.argv = new_args
    cli_main()
# NOTE(review): run as a script, this calls fairseq's cli_main with the
# process's own sys.argv rather than the typer `app` defined above, so the
# `local` command is not reachable this way — confirm this is intended.
if __name__ == "__main__":
    cli_main()

64
plume/ui/__init__.py Normal file
View File

@ -0,0 +1,64 @@
import typer
import sys
from pathlib import Path
from plume.utils import lazy_module
# from streamlit import cli as stcli
# streamlit's CLI module is bound through lazy_module instead of the direct
# import commented out above (presumably imported on first attribute access
# — confirm against plume.utils.lazy_import).
stcli = lazy_module('streamlit.cli')
# Command group for the streamlit based user interfaces.
app = typer.Typer()
@app.command()
def annotation(data_dir: Path, dump_fname: Path = "ui_dump.json", task_id: str = ""):
    """
    Launch the annotation UI (annotation.py next to this file) via streamlit.

    `data_dir` and `dump_fname` are forwarded to the script; `--task-id` is
    only forwarded when `task_id` is non-empty. The process exits with
    streamlit's return value.
    """
    annotation_lit_path = Path(__file__).parent / Path("annotation.py")
    streamlit_args = [
        "streamlit",
        "run",
        str(annotation_lit_path),
        "--",
        str(data_dir),
    ]
    if task_id:
        streamlit_args += ["--task-id", task_id]
    # sys.argv must hold strings only; dump_fname is annotated as a Path and
    # was previously inserted unconverted.
    streamlit_args += ["--dump-fname", str(dump_fname)]
    sys.argv = streamlit_args
    sys.exit(stcli.main())
@app.command()
def preview(manifest_path: Path):
    """
    Launch the manifest preview UI (preview.py next to this file) via
    streamlit and exit with streamlit's return value.
    """
    preview_script = Path(__file__).parent / Path("preview.py")
    launch_args = ["streamlit", "run", str(preview_script), "--"]
    launch_args.append(str(manifest_path))
    sys.argv = launch_args
    sys.exit(stcli.main())
@app.command()
def collection(data_dir: Path, task_id: str = ""):
    """Placeholder for a data-collection web UI; currently does nothing."""
    # TODO: Implement web ui for data collection
    pass
def main():
    """Run the UI command group."""
    app()


# Allow running this module directly as a script.
if __name__ == "__main__":
    main()

View File

@ -1,10 +1,12 @@
# import sys
from pathlib import Path
from uuid import uuid4
import streamlit as st
import typer
from uuid import uuid4
from ..utils import ExtendedPath, get_mongo_conn
from .st_rerun import rerun
from plume.utils import ExtendedPath, get_mongo_conn
from plume.preview.st_rerun import rerun
app = typer.Typer()
@ -42,10 +44,10 @@ if not hasattr(st, "mongo_connected"):
upsert=True,
)
def set_task_fn(mf_path, task_id):
def set_task_fn(data_path, task_id):
if task_id:
st.task_id = task_id
task_path = mf_path.parent / Path(f"task-{st.task_id}.lck")
task_path = data_path / Path(f"task-{st.task_id}.lck")
if not task_path.exists():
print(f"creating task lock at {task_path}")
task_path.touch()
@ -62,17 +64,28 @@ if not hasattr(st, "mongo_connected"):
@st.cache()
def load_ui_data(validation_ui_data_path: Path):
def load_ui_data(data_dir: Path, dump_fname: Path):
validation_ui_data_path = data_dir / dump_fname
typer.echo(f"Using validation ui data from {validation_ui_data_path}")
return ExtendedPath(validation_ui_data_path).read_json()
def show_key(sample, key, trail=""):
    """
    Show `sample[key]` in the sidebar under a title derived from the key.

    Nothing is shown when the key is absent. Float values are printed with
    two decimals followed by `trail`; other values are printed as-is
    (without `trail`).
    """
    if key not in sample:
        return
    title = key.replace("_", " ").title()
    value = sample[key]
    if type(value) is float:
        st.sidebar.markdown(f"{title}: {value:.2f}{trail}")
        return
    st.sidebar.markdown(f"{title}: {value}")
@app.command()
def main(manifest: Path, task_id: str = ""):
st.set_task(manifest, task_id)
ui_config = load_ui_data(manifest)
def main(data_dir: Path, dump_fname: Path = "ui_dump.json", task_id: str = ""):
st.set_task(data_dir, task_id)
ui_config = load_ui_data(data_dir, dump_fname)
asr_data = ui_config["data"]
annotation_only = ui_config.get("annotation_only", False)
asr_result_key = ui_config.get("asr_result_key", "pretrained_asr")
sample_no = st.get_current_cursor()
if len(asr_data) - 1 < sample_no or sample_no < 0:
print("Invalid samplno resetting to 0")
@ -91,15 +104,16 @@ def main(manifest: Path, task_id: str = ""):
st.update_cursor(new_sample - 1)
st.sidebar.title(f"Details: [{sample['real_idx']}]")
st.sidebar.markdown(f"Gold Text: **{sample['text']}**")
# if "caller" in sample:
# st.sidebar.markdown(f"Caller: **{sample['caller']}**")
show_key(sample, "caller")
if not annotation_only:
st.sidebar.title("Results:")
st.sidebar.markdown(f"Pretrained: **{sample['pretrained_asr']}**")
if "caller" in sample:
st.sidebar.markdown(f"Caller: **{sample['caller']}**")
else:
st.sidebar.title(f"Pretrained WER: {sample['pretrained_wer']:.2f}%")
st.sidebar.image(Path(sample["plot_path"]).read_bytes())
st.audio(Path(sample["audio_path"]).open("rb"))
show_key(sample, asr_result_key)
show_key(sample, "asr_wer", trail="%")
show_key(sample, "correct_candidate")
st.sidebar.image((data_dir / Path(sample["plot_path"])).read_bytes())
st.audio((data_dir / Path(sample["audio_path"])).open("rb"))
# set default to text
corrected = sample["text"]
correction_entry = st.get_correction_entry(sample["utterance_id"])

58
plume/ui/preview.py Normal file
View File

@ -0,0 +1,58 @@
from pathlib import Path
import streamlit as st
import typer
from plume.utils import ExtendedPath
from plume.preview.st_rerun import rerun
# Command group for the preview script.
app = typer.Typer()

# One-time setup: helper functions are attached to the streamlit module
# itself, and the "state_lock" attribute marks that this has been done so
# the block is skipped on later executions of the script.
if not hasattr(st, "state_lock"):
    # st.task_id = str(uuid4())
    # The cursor position is persisted in preview.lck, relative to the
    # current working directory.
    task_path = ExtendedPath("preview.lck")

    def current_cursor_fn():
        """Return the sample index stored in the lock file."""
        return task_path.read_json()["current_cursor"]

    def update_cursor_fn(val=0):
        """Store `val` as the current sample index and call rerun()."""
        task_path.write_json({"current_cursor": val})
        rerun()

    st.get_current_cursor = current_cursor_fn
    st.update_cursor = update_cursor_fn
    # Set before the call below, which ends in rerun().
    st.state_lock = True
    # cursor_obj = mongo_conn.find_one({"type": "current_cursor", "task_id": st.task_id})
    # if not cursor_obj:
    # Start from the first sample.
    update_cursor_fn(0)
@st.cache()
def load_ui_data(validation_ui_data_path: Path):
    """Read the JSON-lines manifest at the given path into a list (cached by streamlit)."""
    typer.echo(f"Using validation ui data from {validation_ui_data_path}")
    manifest_file = ExtendedPath(validation_ui_data_path)
    entries = manifest_file.read_jsonl()
    return list(entries)
@app.command()
def main(manifest: Path):
    """
    Render one manifest sample (text and audio) with a "go to sample" input.

    The current sample index comes from st.get_current_cursor(); an
    out-of-range index resets the cursor to 0. Audio paths are resolved
    against the manifest's directory.
    """
    asr_data = load_ui_data(manifest)
    total = len(asr_data)
    sample_no = st.get_current_cursor()
    if total - 1 < sample_no or sample_no < 0:
        print("Invalid samplno resetting to 0")
        st.update_cursor(0)
    sample = asr_data[sample_no]
    shown_no = sample_no + 1
    st.title("ASR Manifest Preview")
    st.markdown(f"{shown_no} of {total} : **{sample['text']}**")
    new_sample = st.number_input(
        "Go To Sample:", value=shown_no, min_value=1, max_value=total
    )
    if new_sample != shown_no:
        st.update_cursor(new_sample - 1)
    st.sidebar.markdown(f"Gold Text: **{sample['text']}**")
    audio_path = manifest.parent / Path(sample["audio_filepath"])
    st.audio(audio_path.open("rb"))
# Run the command when executed as a script. SystemExit raised by app() is
# swallowed so it does not propagate out of the script.
if __name__ == "__main__":
    try:
        app()
    except SystemExit:
        pass

View File

@ -1,7 +1,15 @@
import streamlit.ReportThread as ReportThread
from streamlit.ScriptRequestQueue import RerunData
from streamlit.ScriptRunner import RerunException
from streamlit.server.Server import Server
try:
# Before Streamlit 0.65
from streamlit.ReportThread import get_report_ctx
from streamlit.server.Server import Server
from streamlit.ScriptRequestQueue import RerunData
from streamlit.ScriptRunner import RerunException
except ModuleNotFoundError:
# After Streamlit 0.65
from streamlit.report_thread import get_report_ctx
from streamlit.server.server import Server
from streamlit.script_request_queue import RerunData
from streamlit.script_runner import RerunException
def rerun():
@ -13,7 +21,7 @@ def rerun():
def _get_widget_states():
# Hack to get the session object from Streamlit.
ctx = ReportThread.get_report_ctx()
ctx = get_report_ctx()
session = None
@ -34,5 +42,4 @@ def _get_widget_states():
"Are you doing something fancy with threads?"
)
# Got the session object!
return session._widget_states

486
plume/utils/__init__.py Normal file
View File

@ -0,0 +1,486 @@
import io
import os
import re
import json
import wave
import logging
from pathlib import Path
from functools import partial
from uuid import uuid4
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
import subprocess
import shutil
from urllib.parse import urlsplit
# from .lazy_loader import LazyLoader
from .lazy_import import lazy_callable, lazy_module
# from ruamel.yaml import YAML
# import boto3
import typer
# import pymongo
# from slugify import slugify
# import pydub
# import matplotlib.pyplot as plt
# import librosa
# import librosa.display as audio_display
# from natural.date import compress
# from num2words import num2words
from tqdm import tqdm
from datetime import timedelta
# from .transcribe import triton_transcribe_grpc_gen
# from .eval import app as eval_app
from .tts import app as tts_app
from .transcribe import app as transcribe_app
from .align import app as align_app
boto3 = lazy_module('boto3')
pymongo = lazy_module('pymongo')
pydub = lazy_module('pydub')
audio_display = lazy_module('librosa.display')
plt = lazy_module('matplotlib.pyplot')
librosa = lazy_module('librosa')
YAML = lazy_callable('ruamel.yaml.YAML')
num2words = lazy_callable('num2words.num2words')
slugify = lazy_callable('slugify.slugify')
compress = lazy_callable('natural.date.compress')
# Root Typer application; the tts, align and transcribe sub-apps are mounted
# under it as named sub-commands.
app = typer.Typer()
app.add_typer(tts_app, name="tts")
app.add_typer(align_app, name="align")
app.add_typer(transcribe_app, name="transcribe")

# NOTE(review): `.lazy_import` is imported earlier in this module and itself
# calls logging.basicConfig(level=logging.WARNING) at import time.
# logging.basicConfig does nothing once the root logger already has handlers,
# so the INFO level/format requested here may silently not take effect —
# confirm and decide which module should own root-logger configuration.
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
def manifest_str(path, dur, text):
    """Serialise one manifest entry as a newline-terminated JSON line.

    The duration is rounded to one decimal place before serialisation.
    """
    entry = {"audio_filepath": path, "duration": round(dur, 1), "text": text}
    return f"{json.dumps(entry)}\n"
def duration_str(seconds):
    """Format a duration given in seconds via ``compress`` with a space pad."""
    delta = timedelta(seconds=seconds)
    return compress(delta, pad=" ")
def replace_digit_symbol(w2v_out):
    """Lower-case *w2v_out* and replace spelled-out digits with numerals.

    The words for 0-9 (as produced by ``num2words``) are substituted in
    ascending order; matches are not restricted to whole words.
    """
    word_to_digit = [(num2words(digit), str(digit)) for digit in range(10)]
    converted = w2v_out.lower()
    for word, numeral in word_to_digit:
        converted = re.sub(word, numeral, converted)
    return converted
def discard_except_digits(inp):
    """Return *inp* with every character other than ASCII 0-9 removed."""
    return "".join(ch for ch in inp if ch in "0123456789")
def digits_to_chars(text):
    """Spell out each ASCII digit in *text* and lower-case the result.

    Every digit is replaced by its ``num2words`` form followed by one space;
    all other characters are kept as they are.
    """
    pieces = []
    for ch in text:
        if "0" <= ch <= "9":
            pieces.append(num2words(ch) + " ")
        else:
            pieces.append(ch)
    return "".join(pieces).lower()
def replace_redundant_spaces_with(text, sub):
    """Replace every run of one or more spaces in *text* with *sub*."""
    space_run = re.compile(" +")
    return space_run.sub(sub, text)
def space_out(text):
    """Return *text* with a single space inserted between its characters."""
    return " ".join(text)
def wav_bytes(audio_bytes, frame_rate=24000):
    """Wrap raw PCM frames in a WAV container and return the file bytes.

    The output is written as mono with a 2-byte sample width at *frame_rate*.
    """
    buffer = io.BytesIO()
    with wave.open(buffer, mode="w") as writer:
        writer.setnchannels(1)
        writer.setframerate(frame_rate)
        writer.setsampwidth(2)
        writer.writeframesraw(audio_bytes)
    return buffer.getvalue()
def tscript_uuid_fname(transcript):
    """Build a file stem: a fresh UUID4, "_", and a slug of *transcript* (max 8 chars)."""
    return f"{uuid4()}_{slugify(transcript, max_length=8)}"
def run_shell(cmd_str, work_dir="."):
    """Run *cmd_str* through the shell, echoing its combined output line by line.

    stderr is merged into stdout. The process is always waited on and its
    pipe closed (the previous version left the child un-reaped).

    Args:
        cmd_str: Shell command line. It is executed with ``shell=True``, so
            callers must quote any untrusted fragments themselves.
        work_dir: Directory to run the command in; resolved to an absolute path.

    Returns:
        The exit status of the command (previously nothing was returned).
    """
    cwd_path = Path(work_dir).absolute()
    # The context manager closes stdout and waits for the child on exit.
    with subprocess.Popen(
        cmd_str,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        shell=True,
        cwd=cwd_path,
    ) as proc:
        for line in proc.stdout:
            # errors="replace": a stray non-UTF-8 byte from the command must
            # not abort the echo loop.
            print(line.replace(b"\n", b"").decode("utf-8", errors="replace"))
    return proc.returncode
def upload_s3(dataset_path, s3_path):
    """Sync *dataset_path* to *s3_path* with ``aws s3 sync``.

    Both arguments are shell-quoted before being placed in the command line,
    so paths containing spaces or shell metacharacters are passed to the AWS
    CLI as single arguments.
    """
    import shlex

    src = shlex.quote(str(dataset_path))
    dest = shlex.quote(str(s3_path))
    run_shell(f"aws s3 sync {src} {dest}")
def get_download_path(s3_uri, output_path):
    """Map an S3 URI to a local path under *output_path*.

    The URI's path component (without the leading "/") is appended to
    *output_path*; the parent directory of the result is created if missing.
    """
    object_key = urlsplit(s3_uri).path[1:]
    target = output_path / Path(object_key)
    target.parent.mkdir(exist_ok=True, parents=True)
    return target
def s3_downloader():
    """Create an S3 client and return a download function bound to it.

    The returned ``download_s3(s3_uri, download_path)`` creates the parent
    directory and downloads the object only when *download_path* does not
    exist yet.
    """
    client = boto3.client("s3")

    def download_s3(s3_uri, download_path):
        uri_parts = urlsplit(s3_uri)
        download_path.parent.mkdir(exist_ok=True, parents=True)
        if download_path.exists():
            return
        print(f"downloading {s3_uri} to {download_path}")
        # netloc is the bucket, path (minus the leading "/") is the key.
        client.download_file(uri_parts.netloc, uri_parts.path[1:], str(download_path))

    return download_s3
def asr_data_writer(dataset_dir, asr_data_source, verbose=False):
    """Write wav files and a manifest for every item of *asr_data_source*.

    Each item is a ``(transcript, audio_dur, wav_data)`` triple. The audio is
    stored under ``<dataset_dir>/wavs`` and one manifest line per item is
    written to ``<dataset_dir>/manifest.json`` with a path relative to
    *dataset_dir*.

    Returns:
        The number of items written.
    """
    wav_dir = dataset_dir / Path("wavs")
    wav_dir.mkdir(parents=True, exist_ok=True)
    manifest_path = dataset_dir / Path("manifest.json")
    written = 0
    with manifest_path.open("w") as manifest_file:
        print(f"writing manifest to {manifest_path}")
        for transcript, audio_dur, wav_data in asr_data_source:
            stem = tscript_uuid_fname(transcript)
            wav_path = dataset_dir / Path("wavs") / Path(stem).with_suffix(".wav")
            wav_path.write_bytes(wav_data)
            relative_wav = wav_path.relative_to(dataset_dir)
            manifest_file.write(manifest_str(str(relative_wav), audio_dur, transcript))
            if verbose:
                print(f"writing '{transcript}' of duration {audio_dur}")
            written += 1
    return written
def ui_data_generator(dataset_dir, asr_data_source, verbose=False):
    """Write wavs and waveform plots and build the UI records for a dataset.

    Each item of *asr_data_source* is a
    ``(transcript, audio_dur, wav_data, caller_name, aud_seg)`` tuple. The
    audio is written under ``<dataset_dir>/wavs`` immediately; plots under
    ``<dataset_dir>/wav_plots`` are then rendered through ``parallel_apply``.

    Args:
        dataset_dir: Root directory of the dataset being written.
        asr_data_source: Iterable of the 5-tuples described above.
        verbose: Accepted for signature parity with the other writers; unused.

    Returns:
        ``(ui_data, num_datapoints)`` — the list of per-sample dicts and the
        number of samples processed.
    """
    (dataset_dir / Path("wavs")).mkdir(parents=True, exist_ok=True)
    (dataset_dir / Path("wav_plots")).mkdir(parents=True, exist_ok=True)

    def data_fn(
        transcript, audio_dur, caller_name, fname, audio_file, real_idx, rel_data_path
    ):
        """Render the plot for one sample (if missing) and return its record."""
        rel_plot_path = Path("wav_plots") / Path(fname).with_suffix(".png")
        wav_plot_path = dataset_dir / rel_plot_path
        if not wav_plot_path.exists():
            plot_seg(wav_plot_path.absolute(), audio_file)
        return {
            "audio_path": str(rel_data_path),
            "duration": round(audio_dur, 1),
            "text": transcript,
            "real_idx": real_idx,
            "caller": caller_name,
            "utterance_id": fname,
            "plot_path": str(rel_plot_path),
        }

    num_datapoints = 0
    data_funcs = []
    for transcript, audio_dur, wav_data, caller_name, _aud_seg in asr_data_source:
        # Same naming scheme as asr_data_writer.
        fname = tscript_uuid_fname(transcript)
        audio_file = (
            dataset_dir / Path("wavs") / Path(fname).with_suffix(".wav")
        ).absolute()
        audio_file.write_bytes(wav_data)
        rel_data_path = audio_file.relative_to(dataset_dir.absolute())
        # Only what data_fn reads is bound: the wav bytes and audio segment
        # are already on disk and must not be kept alive by the queued task.
        data_funcs.append(
            partial(
                data_fn,
                transcript,
                audio_dur,
                caller_name,
                fname,
                audio_file,
                num_datapoints,
                rel_data_path,
            )
        )
        num_datapoints += 1
    ui_data = parallel_apply(lambda x: x(), data_funcs)
    return ui_data, num_datapoints
def ui_dump_manifest_writer(dataset_dir, asr_data_source, verbose=False):
    """Generate UI data, then write ``manifest.json`` and ``ui_dump.json``.

    Returns:
        The number of samples reported by ``ui_data_generator``.
    """
    records, num_datapoints = ui_data_generator(
        dataset_dir, asr_data_source, verbose=verbose
    )

    manifest_path = dataset_dir / Path("manifest.json")
    with manifest_path.open("w") as manifest_file:
        print(f"writing manifest to {manifest_path}")
        for record in records:
            line = manifest_str(
                str(record["audio_path"]), record["duration"], record["text"]
            )
            manifest_file.write(line)

    dump_path = dataset_dir / Path("ui_dump.json")
    ExtendedPath(dump_path).write_json({"data": records})
    return num_datapoints
def asr_manifest_reader(data_manifest_path: Path):
    """Yield the entries of a JSON-lines manifest.

    All lines are parsed before the first entry is yielded. Each entry gets
    an ``audio_path`` key (its ``audio_filepath`` resolved against the
    manifest's directory) and its ``text`` stripped of surrounding whitespace.
    """
    print(f"reading manifest from {data_manifest_path}")
    with data_manifest_path.open("r") as manifest_file:
        entries = [json.loads(line) for line in manifest_file]
    manifest_dir = data_manifest_path.parent
    for entry in entries:
        entry["audio_path"] = manifest_dir / Path(entry["audio_filepath"])
        entry["text"] = entry["text"].strip()
        yield entry
def asr_manifest_writer(asr_manifest_path: Path, manifest_str_source):
    """Write one manifest line per dict in *manifest_str_source*.

    Each dict must provide ``audio_filepath``, ``duration`` and ``text``.
    """
    with asr_manifest_path.open("w") as manifest_file:
        print(f"opening {asr_manifest_path} for writing manifest")
        for entry in manifest_str_source:
            line = manifest_str(
                entry["audio_filepath"], entry["duration"], entry["text"]
            )
            manifest_file.write(line)
def asr_test_writer(out_file_path: Path, source):
    """Write a PAUSE/PLAY test script plus a ``.result.json`` side file.

    For every dict in *source* a ``PAUSE 2 / PLAY <audio_filepath> / PAUSE 60``
    stanza is written to *out_file_path*, followed by a final ``DO_HANGUP``.
    The dicts themselves are dumped to ``<out_file_path>.result.json``.
    """

    def play_stanza(item):
        return f"PAUSE 2\nPLAY {item['audio_filepath']}\nPAUSE 60\n\n"

    result_path = out_file_path.with_suffix(".result.json")
    with out_file_path.open("w") as script_file:
        print(f"opening {out_file_path} for writing test")
        seen = []
        for item in source:
            seen.append(item)
            script_file.write(play_stanza(item))
        script_file.write("DO_HANGUP\n")
        ExtendedPath(result_path).write_json(seen)
def batch(iterable, n=1):
    """Split a sliceable sequence into consecutive chunks of at most *n* items."""
    total = len(iterable)
    chunks = []
    for start in range(0, total, n):
        stop = min(start + n, total)
        chunks.append(iterable[start:stop])
    return chunks
class ExtendedPath(type(Path())):
    """Concrete ``Path`` subclass with JSON / JSON-lines / YAML helpers.

    Every writer creates the parent directory first, so callers can target a
    path inside a directory that does not exist yet.
    """

    def read_json(self):
        """Return the parsed content of this JSON file."""
        print(f"reading json from {self}")
        with self.open("r") as jf:
            return json.load(jf)

    def read_yaml(self):
        """Return the content of this YAML file, loaded with the safe, pure-Python loader."""
        yaml = YAML(typ="safe", pure=True)
        print(f"reading yaml from {self}")
        with self.open("r") as yf:
            return yaml.load(yf)

    def read_jsonl(self):
        """Yield one parsed object per non-blank line of this JSON-lines file."""
        print(f"reading jsonl from {self}")
        with self.open("r") as jf:
            for line in jf:
                # Blank lines (e.g. a trailing empty line) are not records.
                if line.strip():
                    yield json.loads(line)

    def write_json(self, data):
        """Dump *data* to this path as indented JSON."""
        print(f"writing json to {self}")
        self.parent.mkdir(parents=True, exist_ok=True)
        with self.open("w") as jf:
            json.dump(data, jf, indent=2)

    def write_yaml(self, data):
        """Dump *data* to this path as YAML."""
        yaml = YAML()
        print(f"writing yaml to {self}")
        # Create the parent directory, as write_json / write_jsonl do.
        self.parent.mkdir(parents=True, exist_ok=True)
        with self.open("w") as yf:
            yaml.dump(data, yf)

    def write_jsonl(self, data):
        """Write each item of *data* to this path as one JSON line."""
        print(f"writing jsonl to {self}")
        self.parent.mkdir(parents=True, exist_ok=True)
        with self.open("w") as jf:
            for d in data:
                jf.write(json.dumps(d) + "\n")
def get_mongo_coll(uri):
    """Return the collection addressed by a MongoDB *uri*.

    The database and collection names are taken from the parsed URI.
    """
    parsed = pymongo.uri_parser.parse_uri(uri)
    client = pymongo.MongoClient(uri)
    database = client[parsed["database"]]
    return database[parsed["collection"]]
def get_mongo_conn(host="", port=27017, db="db", col="collection"):
    """Return collection *col* of database *db* on a MongoDB server.

    When *host* is empty the ``MONGO_HOST`` environment variable is used,
    falling back to ``localhost``.
    """
    mongo_host = host or os.environ.get("MONGO_HOST", "localhost")
    client = pymongo.MongoClient(f"mongodb://{mongo_host}:{port}/")
    return client[db][col]
def strip_silence(sound):
    """Return *sound* with leading and trailing silence sliced off.

    Silence is measured with pydub's ``detect_leading_silence`` on the sound
    and on its reverse.
    """
    from pydub.silence import detect_leading_silence

    head = detect_leading_silence(sound)
    tail = detect_leading_silence(sound.reverse())
    total = len(sound)
    return sound[head : total - tail]
def plot_seg(wav_plot_path, audio_path):
    """Render the waveform of *audio_path* to a PNG at *wav_plot_path* (dpi 50)."""
    figure = plt.Figure()
    axes = figure.add_subplot()
    samples, sample_rate = librosa.load(str(audio_path))
    audio_display.waveplot(y=samples, sr=sample_rate, ax=axes)
    with wav_plot_path.open("wb") as png_file:
        figure.set_tight_layout(True)
        figure.savefig(png_file, format="png", dpi=50)
def parallel_apply(fn, iterable, workers=8, pool="thread"):
    """Map *fn* over *iterable* in a worker pool, showing a progress bar.

    Args:
        fn: Callable applied to each item.
        iterable: Sized collection of inputs (``len`` is used for the bar).
        workers: Maximum number of pool workers.
        pool: ``"thread"`` or ``"process"``.

    Returns:
        The results as a list, in input order.

    Raises:
        Exception: If *pool* is neither ``"thread"`` nor ``"process"``.
    """
    if pool == "thread":
        executor_cls = ThreadPoolExecutor
    elif pool == "process":
        executor_cls = ProcessPoolExecutor
    else:
        raise Exception(f"unsupported pool type - {pool}")

    with executor_cls(max_workers=workers) as exe:
        print(f"parallelly applying {fn}")
        progress = tqdm(
            exe.map(fn, iterable), position=0, leave=True, total=len(iterable)
        )
        return list(progress)
def generate_filter_map(src_dataset_path, dest_dataset_path, data_file):
    """Build the named dataset filters for one source/destination pair.

    Every filter is a zero-argument generator function. It reads the
    JSON-lines manifest *data_file*, yields the samples it keeps (possibly
    with a rewritten ``text``) and places the matching audio under
    *dest_dataset_path* before yielding the sample. A summary line is echoed
    once a filter has been fully consumed.

    Args:
        src_dataset_path: Directory the manifest's ``audio_filepath`` entries
            are relative to.
        dest_dataset_path: Directory the kept audio files are written to.
        data_file: Path of the JSON-lines manifest to read.

    Returns:
        Dict mapping the filter kind name to its generator function.
    """
    min_nums = 3
    max_duration = 1 * 60 * 60
    skip_duration = 1 * 60 * 60

    def samples():
        """Iterate the manifest entries."""
        return ExtendedPath(data_file).read_jsonl()

    def copy_audio(s):
        """Copy the audio file of sample *s* from the source to the destination."""
        shutil.copy(
            src_dataset_path / Path(s["audio_filepath"]),
            dest_dataset_path / Path(s["audio_filepath"]),
        )

    def filtered_max_dur():
        # Keep samples with at least min_nums non-space characters until the
        # accumulated duration exceeds max_duration.
        wav_duration = 0
        for s in samples():
            nums = re.sub(" ", "", s["text"])
            if len(nums) >= min_nums:
                wav_duration += s["duration"]
                copy_audio(s)
                yield s
            if wav_duration > max_duration:
                break
        typer.echo(f"filtered only {duration_str(wav_duration)} of audio")

    def filtered_skip_dur():
        # Drop the first skip_duration seconds of qualifying audio, then keep
        # every further qualifying sample.
        wav_duration = 0
        for s in samples():
            nums = re.sub(" ", "", s["text"])
            qualifies = len(nums) >= min_nums
            if qualifies:
                wav_duration += s["duration"]
            if wav_duration <= skip_duration:
                continue
            if qualifies:
                # Copy before yielding (as the other filters do) so the file
                # exists even if the consumer stops after this sample.
                copy_audio(s)
                yield s
        typer.echo(f"skipped {duration_str(skip_duration)} of audio")

    def filtered_blanks():
        # Keep only samples whose text is not blank; count the dropped ones
        # so the summary reports what was actually filtered out.
        blank_count = 0
        for s in samples():
            nums = re.sub(" ", "", s["text"])
            if nums == "":
                blank_count += 1
                continue
            copy_audio(s)
            yield s
        typer.echo(f"filtered {blank_count} blank samples")

    def filtered_transform_digits():
        # Reduce the text to its digits and spell them out as words.
        count = 0
        for s in samples():
            count += 1
            digit_text = replace_digit_symbol(s["text"])
            only_digits = discard_except_digits(digit_text)
            char_text = digits_to_chars(only_digits)
            copy_audio(s)
            s["text"] = char_text
            yield s
        typer.echo(f"transformed {count} samples")

    def filtered_extract_chars():
        # Spell out digits, upper-case, and blank out every other character.
        count = 0
        for s in samples():
            count += 1
            no_digits = digits_to_chars(s["text"]).upper()
            # NOTE(review): "\b" in this non-raw string is the backspace
            # character, not a regex word boundary — confirm that is intended.
            only_chars = re.sub("[^A-Z'\b]", " ", no_digits)
            filter_text = replace_redundant_spaces_with(only_chars, " ").strip()
            copy_audio(s)
            s["text"] = filter_text
            yield s
        typer.echo(f"transformed {count} samples")

    def filtered_resample():
        # Re-export the audio as mono, 1-byte samples, 24 kHz wav.
        count = 0
        for s in samples():
            count += 1
            src_aud = pydub.AudioSegment.from_file(
                src_dataset_path / Path(s["audio_filepath"])
            )
            dst_aud = src_aud.set_channels(1).set_sample_width(1).set_frame_rate(24000)
            dst_aud.export(dest_dataset_path / Path(s["audio_filepath"]), format="wav")
            yield s
        typer.echo(f"transformed {count} samples")

    filter_kind_map = {
        "max_dur_1hr_min3num": filtered_max_dur,
        "skip_dur_1hr_min3num": filtered_skip_dur,
        "blanks": filtered_blanks,
        "transform_digits": filtered_transform_digits,
        "extract_chars": filtered_extract_chars,
        "resample_ulaw24kmono": filtered_resample,
    }
    return filter_kind_map

117
plume/utils/align.py Normal file
View File

@ -0,0 +1,117 @@
from pathlib import Path
from .tts import GoogleTTS
# from IPython import display
import requests
import io
import typer
from plume.utils import lazy_module
display = lazy_module('IPython.display')
pydub = lazy_module('pydub')
app = typer.Typer()
# Start gentle with following command
# docker run --rm -d --name gentle_service -p 8765:8765/tcp lowerquality/gentle
def gentle_aligner(service_uri, wav_data, utter_text):
    """Post audio and transcript to a Gentle service and return the result.

    The wav bytes are re-encoded to mp3 in memory and uploaded together with
    the transcript as a synchronous (``async=false``) request.

    Returns:
        ``(audio_segment, response_json)`` — the decoded input audio and the
        parsed JSON body of the service response.
    """
    # service_uri= "http://52.41.161.36:8765/transcriptions"
    segment = pydub.AudioSegment.from_file(io.BytesIO(wav_data))
    mp3_buffer = io.BytesIO()
    segment.export(mp3_buffer, format="mp3")
    mp3_buffer.seek(0)  # rewind so the upload reads from the start
    transcript_buffer = io.BytesIO(utter_text.encode("utf-8"))
    response = requests.post(
        service_uri,
        params=(("async", "false"),),
        files={
            "audio": ("audio.mp3", mp3_buffer),
            "transcript": ("words.txt", transcript_buffer),
        },
    )
    print(f"Time duration of audio {segment.duration_seconds}")
    print(f"Time taken to align: {response.elapsed}s")
    return segment, response.json()
def gentle_align_iter(service_uri, wav_data, utter_text):
    """Yield ``(word, audio_segment)`` for every aligned word of an utterance.

    ``gentle_preview`` treats the aligner response as a dict, so the spans
    are read from its ``"words"`` list when the response is a dict
    (presumably Gentle's documented response shape — confirm against the
    service); a bare list of spans is still accepted. Spans without both
    ``start`` and ``end`` (words that could not be aligned) are skipped.
    """
    wav_seg, response = gentle_aligner(service_uri, wav_data, utter_text)
    spans = response.get("words", []) if isinstance(response, dict) else response
    for span in spans:
        if "start" not in span or "end" not in span:
            continue
        # Timestamps are in seconds; segment slicing is in milliseconds.
        word_seg = wav_seg[int(span["start"] * 1000) : int(span["end"] * 1000)]
        yield (span["word"], word_seg)
def tts_jupyter(service_uri="http://localhost:8765/transcriptions"):
    """Notebook demo: synthesise a sentence and display it word by word.

    Args:
        service_uri: Transcriptions endpoint of a running Gentle service. It
            is forwarded to ``gentle_align_iter``, which requires it; the
            previous call omitted it and raised ``TypeError``. The default
            assumes the local docker service on port 8765 mentioned in this
            module's comments.
    """
    google_voices = GoogleTTS.voice_list()
    gtts = GoogleTTS()
    # First en-US voice offered by the service.
    us_voice = [v for v in google_voices if v["language"] == "en-US"][0]
    utter_text = (
        "I would like to align the audio segments based on word level timestamps"
    )
    wav_data = gtts.text_to_speech(text=utter_text, params=us_voice)
    for word, seg in gentle_align_iter(service_uri, wav_data, utter_text):
        print(word)
        display.display(seg)
@app.command()
def cut(audio_path: Path, transcript_path: Path, out_dir: Path = "/tmp"):
    """Write the first 15 minutes of a recording and of its transcript.

    ``audio.mp3`` receives the first 15 minutes of *audio_path*;
    ``words.txt`` receives the concatenated text elements of the transcript
    JSON up to and including the first element stamped at or after 15
    minutes, with bracketed and parenthesised spans removed.
    """
    from . import ExtendedPath
    import datetime
    import re

    limit_msecs = 15 * 60 * 1000

    def time_to_msecs(time_str):
        parsed = datetime.datetime.strptime(time_str, "%H:%M:%S,%f")
        # strptime without a date yields 1900-01-01, so subtracting that
        # date leaves the time of day as a timedelta.
        return (parsed - datetime.datetime(1900, 1, 1)).total_seconds() * 1000

    def words_until_limit(monologues):
        for monologue in monologues:
            for element in monologue["elements"]:
                if element["type"] == "text":
                    cleaned = re.sub(r"\[.*\]", "", element["value"])
                    yield re.sub(r"\(.*\)", "", cleaned)
                if (
                    "timestamp" in element
                    and time_to_msecs(element["timestamp"]) >= limit_msecs
                ):
                    return

    aud_seg = pydub.AudioSegment.from_file(audio_path)
    aud_seg[:limit_msecs].export(out_dir / Path("audio.mp3"), format="mp3")
    tscript_json = ExtendedPath(transcript_path).read_json()
    tscript_words = list(words_until_limit(tscript_json["monologues"]))
    (out_dir / Path("words.txt")).write_text("".join(tscript_words))
@app.command()
def gentle_preview(
    audio_path: Path,
    transcript_path: Path,
    service_uri="http://101.53.142.218:8765/transcriptions",
    gent_preview_dir="../gentle_preview",
):
    """Align a recording with its transcript and write the preview files.

    ``a.wav`` (the audio) and ``status.json`` (the alignment, with
    ``status`` set to ``"OK"``) are written into *gent_preview_dir*.
    """
    from . import ExtendedPath

    audio_bytes = audio_path.read_bytes()
    transcript_text = transcript_path.read_text()
    audio, alignment = gentle_aligner(service_uri, audio_bytes, transcript_text)
    audio.export(gent_preview_dir / Path("a.wav"), format="wav")
    alignment["status"] = "OK"
    status_path = gent_preview_dir / Path("status.json")
    ExtendedPath(status_path).write_json(alignment)
def main():
    """Run this module's Typer application."""
    app()


# Allow running the module directly as a script.
if __name__ == "__main__":
    main()

28
plume/utils/audio.py Normal file
View File

@ -0,0 +1,28 @@
from scipy.signal import lfilter, butter
from scipy.io.wavfile import read, write
from numpy import array, int16
import sys
def butter_params(low_freq, high_freq, fs, order=5):
    """Return the ``(b, a)`` coefficients of a Butterworth band-pass filter.

    The cut-off frequencies are normalised by the Nyquist frequency
    (``0.5 * fs``) before being passed to ``scipy.signal.butter``.
    """
    nyquist = 0.5 * fs
    band = [low_freq / nyquist, high_freq / nyquist]
    numerator, denominator = butter(order, band, btype="band")
    return numerator, denominator
def butter_bandpass_filter(data, low_freq, high_freq, fs, order=5):
    """Apply a Butterworth band-pass filter to *data* and return the result."""
    numerator, denominator = butter_params(low_freq, high_freq, fs, order=order)
    return lfilter(numerator, denominator, data)
if __name__ == "__main__":
    # Script usage: pass the path of a wav file as the first argument. The
    # 300-4000 Hz band-passed signal is written next to it with a "_moded"
    # suffix. (A leftover pdb.set_trace() that halted every run was removed.)
    fs, audio = read(sys.argv[1])
    low_freq = 300.0
    high_freq = 4000.0
    filtered_signal = butter_bandpass_filter(audio, low_freq, high_freq, fs, order=6)
    fname = sys.argv[1].split(".wav")[0] + "_moded.wav"
    # Cast back to 16-bit integers for the output file.
    write(fname, fs, array(filtered_signal, dtype=int16))

737
plume/utils/lazy_import.py Normal file
View File

@ -0,0 +1,737 @@
# -*- Mode: python; tab-width: 4; indent-tabs-mode:nil; coding:utf-8 -*-
# vim: tabstop=4 expandtab shiftwidth=4 softtabstop=4
#
# lazy_import --- https://github.com/mnmelo/lazy_import
# Copyright (C) 2017-2018 Manuel Nuno Melo
#
# This file is part of lazy_import.
#
# lazy_import is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# lazy_import is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with lazy_import. If not, see <http://www.gnu.org/licenses/>.
#
# lazy_import was based on code from the importing module from the PEAK
# package (see <http://peak.telecommunity.com/DevCenter/Importing>). The PEAK
# package is released under the following license, reproduced here:
#
# Copyright (C) 1996-2004 by Phillip J. Eby and Tyler C. Sarna.
# All rights reserved. This software may be used under the same terms
# as Zope or Python. THERE ARE ABSOLUTELY NO WARRANTIES OF ANY KIND.
# Code quality varies between modules, from "beta" to "experimental
# pre-alpha". :)
#
# Code pertaining to lazy loading from PEAK importing was included in
# lazy_import, modified in a number of ways. These are detailed in the
# CHANGELOG file of lazy_import. Changes mainly involved Python 3
# compatibility, extension to allow customizable behavior, and added
# functionality (lazy importing of callable objects).
#
"""
Lazy module loading
===================
Functions and classes for lazy module loading that also delay import errors.
Heavily borrowed from the `importing`_ module.
.. _`importing`: http://peak.telecommunity.com/DevCenter/Importing
Files and directories
---------------------
.. autofunction:: module
.. autofunction:: callable
"""
__all__ = [
"lazy_module",
"lazy_callable",
"lazy_function",
"lazy_class",
"LazyModule",
"LazyCallable",
"module_basename",
"_MSG",
"_MSG_CALLABLE",
]
from types import ModuleType
import sys
# Prefer the interpreter's own import-lock context manager; fall back to a
# local equivalent built on ``imp`` when it cannot be imported.
try:
    from importlib._bootstrap import _ImportLockContext
except ImportError:
    # Python 2 doesn't have the context manager. Roll it ourselves (copied from
    # Python 3's importlib/_bootstrap.py)
    import imp

    class _ImportLockContext:
        """Context manager for the import lock."""

        def __enter__(self):
            # Acquire the global import lock on entry.
            imp.acquire_lock()

        def __exit__(self, exc_type, exc_value, exc_traceback):
            # Release the lock on exit; exceptions are not suppressed.
            imp.release_lock()
# Adding a __spec__ doesn't really help. I'll leave the code here in case
# future python implementations start relying on it.
# try:
# from importlib.machinery import ModuleSpec
# except ImportError:
# ModuleSpec = None
import six
from six import raise_from
from six.moves import reload_module
# It is sometime useful to have access to the version number of a library.
# This is usually done through the __version__ special attribute.
# To make sure the version number is consistent between setup.py and the
# library, we read the version number from the file called VERSION that stays
# in the module directory.
import os
# VERSION_FILE = os.path.join(os.path.dirname(__file__), "VERSION")
# with open(VERSION_FILE) as infile:
# __version__ = infile.read().strip()
# Logging
import logging
# adding a TRACE level for stack debugging
_LAZY_TRACE = 1
logging.addLevelName(1, "LAZY_TRACE")
logging.basicConfig(level=logging.WARNING)
# Logs a formatted stack (takes no message or args/kwargs)
def _lazy_trace(self):
    """Log a formatted stack at the LAZY_TRACE level (no message or args).

    Does nothing unless the logger is enabled for that level. The stack is
    taken from two frames above this function and emitted one record per
    line.
    """
    if not self.isEnabledFor(_LAZY_TRACE):
        return
    import traceback

    self._log(_LAZY_TRACE, " ### STACK TRACE ###", ())
    for entry in traceback.format_stack(sys._getframe(2)):
        for text in entry.split("\n"):
            self._log(_LAZY_TRACE, text.rstrip(), ())
logging.Logger.lazy_trace = _lazy_trace
logger = logging.getLogger(__name__)
################################
# Module/function registration #
################################
#### Lazy classes ####
class LazyModule(ModuleType):
    """Class for lazily-loaded modules that triggers proper loading on access.

    Instantiation should be made from a subclass of :class:`LazyModule`, with
    one subclass per instantiated module. Regular attribute set/access can then
    be recovered by setting the subclass's :meth:`__getattribute__` and
    :meth:`__setattr__` to those of :class:`types.ModuleType`.
    """

    # peak.util.imports sets __slots__ to (), but it seems pointless because
    # the base ModuleType doesn't itself set __slots__.
    def __getattribute__(self, attr):
        """Return *attr*, loading the real module first when required.

        ``__name__``, ``__class__`` and ``__spec__`` are answered without
        loading. For any other name, an already-imported submodule or a
        registered lazy callable is returned directly; otherwise
        ``_load_module`` is called before the normal attribute lookup.
        """
        logger.debug(
            "Getting attr {} of LazyModule instance of {}".format(
                attr, super(LazyModule, self).__getattribute__("__name__")
            )
        )
        logger.lazy_trace()
        # IPython tries to be too clever and constantly inspects, asking for
        # modules' attrs, which causes premature module loading and unesthetic
        # internal errors if the lazily-loaded module doesn't exist.
        if (
            run_from_ipython()
            and (attr.startswith(("__", "_ipython")) or attr == "_repr_mimebundle_")
            and module_basename(_caller_name()) in ("inspect", "IPython")
        ):
            # NOTE(review): the format string below has one placeholder but
            # two arguments, so only the module name is shown and `attr` is
            # never part of the message.
            logger.debug(
                "Ignoring request for {}, deemed from IPython's "
                "inspection.".format(
                    super(LazyModule, self).__getattribute__("__name__"), attr
                )
            )
            raise AttributeError
        if not attr in ("__name__", "__class__", "__spec__"):
            # __name__ and __class__ yield their values from the LazyModule;
            # __spec__ causes an AttributeError. Maybe in the future it will be
            # necessary to return an actual ModuleSpec object, but it works as
            # it is without that now.
            # If it's an already-loaded submodule, we return it without
            # triggering a full loading
            try:
                return sys.modules[self.__name__ + "." + attr]
            except KeyError:
                pass
            # Check if it's one of the lazy callables
            try:
                _callable = type(self)._lazy_import_callables[attr]
                logger.debug("Returning lazy-callable '{}'.".format(attr))
                return _callable
            except (AttributeError, KeyError) as err:
                # Neither a loaded submodule nor a registered lazy callable:
                # load the module for real.
                logger.debug(
                    "Proceeding to load module {}, "
                    "from requested value {}".format(
                        super(LazyModule, self).__getattribute__("__name__"), attr
                    )
                )
                _load_module(self)
        logger.debug(
            "Returning value '{}'.".format(
                super(LazyModule, self).__getattribute__(attr)
            )
        )
        return super(LazyModule, self).__getattribute__(attr)

    def __setattr__(self, attr, value):
        """Load the real module, then set *attr* to *value* on it."""
        logger.debug(
            "Setting attr {} to value {}, in LazyModule instance "
            "of {}".format(
                attr, value, super(LazyModule, self).__getattribute__("__name__")
            )
        )
        _load_module(self)
        return super(LazyModule, self).__setattr__(attr, value)
class LazyCallable(object):
    """Class for lazily-loaded callables that triggers module loading on access
    """

    def __init__(self, *args):
        """Store the ``(module, callable_name)`` pair to resolve later.

        Raises:
            NotImplementedError: If the three-argument call looks like a
                class statement using a lazy callable as its base.
            TypeError: For any other call without exactly two arguments
                that trips the inspection below.
        """
        if len(args) != 2:
            # Maybe the user tried to base a class off this lazy callable?
            try:
                logger.debug(
                    "Got wrong number of args when init'ing "
                    "LazyCallable. args is '{}'".format(args)
                )
                base = args[1][0]
                if isinstance(base, LazyCallable) and len(args) == 3:
                    raise NotImplementedError(
                        "It seems you are trying to use "
                        "a lazy callable as a class "
                        "base. This is not supported."
                    )
            except (IndexError, TypeError):
                raise_from(
                    TypeError(
                        "LazyCallable takes exactly 2 arguments: "
                        "a module/lazy module object and the name of "
                        "a callable to be lazily loaded."
                    ),
                    None,
                )
        self.module, self.cname = args
        self.modclass = type(self.module)
        # Resolved target; filled in on the first successful call.
        self.callable = None
        # Need to save these, since the module-loading gets rid of them
        self.error_msgs = self.modclass._lazy_import_error_msgs
        self.error_strings = self.modclass._lazy_import_error_strings

    def __call__(self, *args, **kwargs):
        """Resolve the target callable on first use and call it.

        Raises:
            AttributeError: If the module has no attribute named after the
                callable.
            ImportError: If accessing the module raised it; this wrapper is
                registered again before the error is re-raised.
        """
        # No need to go through all the reloading more than once.
        if self.callable:
            return self.callable(*args, **kwargs)
        # Unregister this wrapper so the getattr below is not answered with
        # the wrapper itself (LazyModule.__getattribute__ checks the registry).
        try:
            del self.modclass._lazy_import_callables[self.cname]
        except (AttributeError, KeyError):
            pass
        try:
            self.callable = getattr(self.module, self.cname)
        except AttributeError:
            msg = self.error_msgs["msg_callable"]
            raise_from(
                AttributeError(msg.format(callable=self.cname, **self.error_strings)),
                None,
            )
        except ImportError as err:
            # Import failed. We reset the dict and re-raise the ImportError.
            try:
                self.modclass._lazy_import_callables[self.cname] = self
            except AttributeError:
                self.modclass._lazy_import_callables = {self.cname: self}
            raise_from(err, None)
        else:
            return self.callable(*args, **kwargs)
### Functions ###
def lazy_module(modname, error_strings=None, lazy_mod_class=LazyModule, level="leaf"):
    """Lazily import a module into the namespace.

    A hollow module object is created, registered in `sys.modules` and
    returned. Actual loading — and any `ImportError` for a missing module —
    is delayed until an attribute of the lazy module is accessed.

    Calling :func:`lazy_module` early (say, in `__init__.py`) registers the
    names to be lazy; because of the `sys.modules` registration, later plain
    `import modulename` statements also get the lazy object.

    Parameters
    ----------
    modname : str
        The module to import.
    error_strings : dict, optional
        Strings used when module loading fails. Key 'msg' sets the message
        (defaults to :attr:`lazy_import._MSG`), which is formatted with the
        remaining keys: 'module' (the missing module), 'caller' (the code
        that requested the lazy module) and 'install_name' (the package to
        install). No key is mandatory; all get defaults.
    lazy_mod_class : type, optional
        Class used to instantiate the lazy module. Defaults to
        :class:`LazyModule`; alternatives **must** subclass it.
    level : str, optional
        Which reference to return: the 'leaf' module (default) or the 'base'
        module. It does not affect what is registered in `sys.modules`.

        With *level* 'base' and *modname* 'aaa.bbb.ccc'::

            aaa = lazy_import.lazy_module("aaa.bbb.ccc", level='base')
            # lazy equivalent of: import aaa.bbb.ccc

        With *level* 'leaf'::

            ccc = lazy_import.lazy_module("aaa.bbb.ccc", level='leaf')
            # lazy equivalent of: from aaa.bbb import ccc

    Returns
    -------
    module
        An instance of *lazy_mod_class* standing in for *modname* (or for
        its base, depending on *level*). It is loaded for real on the first
        attribute access.

    Raises
    ------
    ValueError
        If *level* is neither 'base' nor 'leaf'.

    Examples
    --------
    >>> import lazy_import, sys
    >>> np = lazy_import.lazy_module("numpy")
    >>> np
    Lazily-loaded module numpy
    >>> np is sys.modules['numpy']
    True
    >>> np.pi  # This causes the full loading of the module ...
    3.141592653589793

    A missing module only fails once it is used:

    >>> missing = lazy_import.lazy_module("missing_module")
    >>> missing
    Lazily-loaded module missing_module
    >>> missing.some_attr  # doctest: +SKIP
    ImportError: __main__ attempted to use a functionality that requires module missing_module, but it couldn't be loaded. Please install missing_module and retry.

    See Also
    --------
    :func:`lazy_callable`
    :class:`LazyModule`
    """
    if error_strings is None:
        error_strings = {}
    _set_default_errornames(modname, error_strings)

    # Registration happens before *level* is validated.
    leaf = _lazy_module(modname, error_strings, lazy_mod_class)
    if level == "base":
        return sys.modules[module_basename(modname)]
    if level == "leaf":
        return leaf
    raise ValueError("Parameter 'level' must be one of ('base', 'leaf')")
def _lazy_module(modname, error_strings, lazy_mod_class):
    """Register lazy stand-ins for *modname* and its missing parent packages.

    Walks from the full dotted name up towards the base package. Every name
    not yet in ``sys.modules`` gets an instance of a fresh
    *lazy_mod_class* subclass; each child is attached to its parent as an
    attribute. The walk stops at the first name that is already in
    ``sys.modules``. Runs under the import lock.

    Returns:
        The ``sys.modules`` entry for the full *modname*.
    """
    with _ImportLockContext():
        fullmodname = modname
        fullsubmodname = None
        # ensure parent module/package is in sys.modules
        # and parent.modname=module, as soon as the parent is imported
        while modname:
            try:
                mod = sys.modules[modname]
                # We reached a (base) module that's already loaded. Let's stop
                # the cycle. Can't use 'break' because we still want to go
                # through the fullsubmodname check below.
                modname = ""
            except KeyError:
                # Per-module copy so popping 'msg' below leaves the caller's
                # dict intact.
                err_s = error_strings.copy()
                err_s.setdefault("module", modname)

                class _LazyModule(lazy_mod_class):
                    _lazy_import_error_msgs = {"msg": err_s.pop("msg")}
                    try:
                        _lazy_import_error_msgs["msg_callable"] = err_s.pop(
                            "msg_callable"
                        )
                    except KeyError:
                        pass
                    _lazy_import_error_strings = err_s
                    _lazy_import_callables = {}
                    _lazy_import_submodules = {}

                    def __repr__(self):
                        return "Lazily-loaded module {}".format(self.__name__)

                # A bit of cosmetic, to make AttributeErrors read more natural
                _LazyModule.__name__ = "module"
                # Actual module instantiation
                mod = sys.modules[modname] = _LazyModule(modname)
                # No need for __spec__. Maybe in the future.
                # if ModuleSpec:
                #     ModuleType.__setattr__(mod, '__spec__',
                #                            ModuleSpec(modname, None))
            if fullsubmodname:
                submod = sys.modules[fullsubmodname]
                # ModuleType.__setattr__ bypasses LazyModule.__setattr__,
                # which would trigger loading.
                ModuleType.__setattr__(mod, submodname, submod)
                # NOTE(review): when `mod` was found already loaded above,
                # `_LazyModule` here is still the class created for the
                # previous (child) name, not a class of `mod` — confirm this
                # is intended.
                _LazyModule._lazy_import_submodules[submodname] = submod
            fullsubmodname = modname
            modname, _, submodname = modname.rpartition(".")
        return sys.modules[fullmodname]
def lazy_callable(modname, *names, **kwargs):
    """Lazily import one or more callables.

    The returned objects are thin wrappers that pass all arguments straight
    to the target module's callables (functions or classes). The module is
    only loaded when a wrapper is called, using the same mechanism as
    :func:`lazy_module`. If the target module has already been fully
    imported, the real callables are returned and nothing is lazy.

    :func:`lazy_function` and :func:`lazy_class` are aliases of
    :func:`lazy_callable`.

    Parameters
    ----------
    modname : str
        The module to import the callable(s) in *names* from, or a full
        'module_name.callable_name' string.
    names : str, optional
        Callable name(s) to import from *modname*. If empty, *modname* is
        taken to include the callable name.
    error_strings : dict, optional
        Strings used when reporting loading errors. Same usage as in
        :func:`lazy_module`, except that 1) the extra key 'msg_callable'
        sets the message used when the module loads but the callable is
        missing (defaults to :attr:`lazy_import._MSG_CALLABLE`), and 2) a
        key 'callable' is always added with the callable name.
    lazy_mod_class : type, optional
        See :func:`lazy_module`.
    lazy_call_class : type, optional
        Class handling the lazy callables; defaults to
        :class:`LazyCallable`.

    Returns
    -------
    wrapper function or tuple of wrapper functions
        A tuple with one wrapper per element of *names*; or, when only
        *modname* is given, the single wrapper itself (not in a tuple).

    Notes
    -----
    Unlike the module returned by :func:`lazy_module`, which mutates into
    the real module, these wrappers never change: they never become the
    object in the module's namespace. Calling them is fine, but other uses
    may break — for instance, a lazily imported class can be instantiated
    through its wrapper but cannot be subclassed from it.

    Examples
    --------
    >>> import lazy_import, sys
    >>> fn = lazy_import.lazy_callable("numpy.arange")
    >>> sys.modules['numpy']
    Lazily-loaded module numpy
    >>> fn(10)
    array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

    >>> cl = lazy_import.lazy_callable("numpy.ndarray")  # a class
    >>> obj = cl([1, 2])  # works, and triggers the loading of numpy
    >>> class MySubclass(cl):  # fails: cl is just a wrapper  # doctest: +SKIP
    ...     pass

    See Also
    --------
    :func:`lazy_module`
    :class:`LazyCallable`
    :class:`LazyModule`
    """
    # A lone 'modname.callable_name' string: split the callable name off.
    single = not names
    if single:
        modname, _, name = modname.rpartition(".")

    lazy_mod_class = _setdef(kwargs, "lazy_mod_class", LazyModule)
    lazy_call_class = _setdef(kwargs, "lazy_call_class", LazyCallable)
    error_strings = _setdef(kwargs, "error_strings", {})
    _set_default_errornames(modname, error_strings, call=True)

    def wrap(cname):
        # Each wrapper gets its own copy of the error strings.
        return _lazy_callable(
            modname, cname, error_strings.copy(), lazy_mod_class, lazy_call_class
        )

    if single:
        # The wrapper is returned directly, not inside a tuple.
        return wrap(name)
    return tuple(wrap(cname) for cname in names)
# Convenience aliases: both names are bound to the very same function object
# as lazy_callable.
lazy_function = lazy_class = lazy_callable
def _lazy_callable(modname, cname, error_strings, lazy_mod_class, lazy_call_class):
    """Return the attribute *cname* of the (possibly lazy) module *modname*.

    The module is obtained through :func:`_lazy_module`. While its class is
    still a :class:`LazyModule` subclass carrying a ``_lazy_import_callables``
    registry, a ``lazy_call_class(module, cname)`` wrapper is registered under
    *cname* (an already registered wrapper is kept). The result is whatever
    ``getattr(module, cname)`` then yields.
    """
    # Doing this here rather than in LazyCallable.__init__ lets us pre-check
    # whether laziness is still needed at all.
    lazy_mod = _lazy_module(modname, error_strings, lazy_mod_class)
    mod_type = type(lazy_mod)
    still_lazy = issubclass(mod_type, LazyModule) and hasattr(
        mod_type, "_lazy_import_callables"
    )
    if still_lazy:
        registry = mod_type._lazy_import_callables
        registry.setdefault(cname, lazy_call_class(lazy_mod, cname))
    return getattr(lazy_mod, cname)
#######################
# Real module loading #
#######################
def _load_module(module):
    """Ensures that a module, and its parents, are properly loaded

    Parameters
    ----------
    module : LazyModule
        The lazy module instance to turn into a really-imported module.

    Raises
    ------
    TypeError
        If the class of *module* is not a subclass of :class:`LazyModule`.
    ImportError
        If the real import fails. The message is built from the 'msg' template
        and the error strings stored on the module's class, and the original
        exception is detached (``raise_from(..., None)``).
    AttributeError
        Under Python 3, when loading raised an AttributeError other than the
        specific "'NoneType' object has no attribute 'name'" one.
    """
    modclass = type(module)
    # We only take care of our own LazyModule instances
    if not issubclass(modclass, LazyModule):
        raise TypeError("Passed module is not a LazyModule instance.")
    # Everything below runs while holding the import lock context.
    with _ImportLockContext():
        parent, _, modname = module.__name__.rpartition(".")
        logger.debug("loading module {}".format(modname))
        # We first identify whether this is a loadable LazyModule, then we
        # strip as much of lazy_import behavior as possible (keeping it cached,
        # in case loading fails and we need to reset the lazy state).
        if not hasattr(modclass, "_lazy_import_error_msgs"):
            # Already loaded (no _lazy_import_error_msgs attr). Not reloading.
            return
        # First, ensure the parent is loaded (using recursion; *very* unlikely
        # we'll ever hit a stack limit in this case).
        # _LOADING is a class-level marker; it is removed again on success
        # (below) and by _reset_lazymodule on failure.
        modclass._LOADING = True
        try:
            if parent:
                logger.debug("first loading parent module {}".format(parent))
                # NOTE(review): presumably setting an attribute on a lazy
                # parent triggers the parent's own load -- confirm against
                # LazyModule.__setattr__, which is not visible here.
                setattr(sys.modules[parent], modname, module)
            if not hasattr(modclass, "_LOADING"):
                logger.debug("Module {} already loaded by the parent".format(modname))
                # We've been loaded by the parent. Let's bail.
                return
            # Keep the stripped class attributes so the lazy state can be
            # restored if the real import fails.
            cached_data = _clean_lazymodule(module)
            try:
                # Get Python to do the real import!
                reload_module(module)
            except:
                # Loading failed. We reset our lazy state.
                # The bare except is followed by a re-raise, so nothing is
                # swallowed here.
                logger.debug("Failed to load module {}. Resetting...".format(modname))
                _reset_lazymodule(module, cached_data)
                raise
            else:
                # Successful load
                logger.debug("Successfully loaded module {}".format(modname))
                delattr(modclass, "_LOADING")
                _reset_lazy_submod_refs(module)
        except (AttributeError, ImportError) as err:
            logger.debug(
                "Failed to load {}.\n{}: {}".format(
                    modname, err.__class__.__name__, err
                )
            )
            logger.lazy_trace()
            # Under Python 3 reloading our dummy LazyModule instances causes an
            # AttributeError if the module can't be found. Would be preferable
            # if we could always rely on an ImportError. As it is we vet the
            # AttributeError as thoroughly as possible.
            if (six.PY3 and isinstance(err, AttributeError)) and not err.args[
                0
            ] == "'NoneType' object has no attribute 'name'":
                # Not the AttributeError we were looking for.
                raise
            # Any ImportError, and the one vetted AttributeError, are reported
            # as an ImportError carrying the user-facing message.
            msg = modclass._lazy_import_error_msgs["msg"]
            raise_from(
                ImportError(msg.format(**modclass._lazy_import_error_strings)), None
            )
##############################
# Helper functions/constants #
##############################
# Default error template used when a lazily imported module cannot be loaded.
# Placeholders: {caller}, {module}, {install_name}.
_MSG = (
    "{caller} attempted to use a functionality that requires module "
    "{module}, but it couldn't be loaded. Please install {install_name} "
    "and retry."
)
# Default error template for a callable that is missing from a module.
# Placeholders: {caller}, {callable}, {module}, {install_name}.
_MSG_CALLABLE = (
    "{caller} attempted to use a functionality that requires "
    "{callable}, of module {module}, but it couldn't be found in that "
    "module. Please install a version of {install_name} that has "
    "{module}.{callable} and retry."
)
# Class attributes that _clean_lazymodule removes from a lazy module's class
# and that _reset_lazymodule puts back if loading fails.
_CLS_ATTRS = (
    "_lazy_import_error_strings",
    "_lazy_import_error_msgs",
    "_lazy_import_callables",
    "_lazy_import_submodules",
    "__repr__",
)
# Names of class attributes ("deletion dictionaries") whose keys are module
# attributes to delete before loading and to set again afterwards.
_DELETION_DICT = ("_lazy_import_submodules",)
def _setdef(argdict, name, defaultvalue):
    """Like dict.setdefault but sets the default value also if None is present.

    Parameters
    ----------
    argdict : dict
        Mapping that is updated in place.
    name : hashable
        Key to look up in *argdict*.
    defaultvalue : object
        Value stored under *name* when the key is missing or maps to None.

    Returns
    -------
    object
        The value stored under *name* once the call returns.
    """
    if name not in argdict or argdict[name] is None:
        argdict[name] = defaultvalue
    return argdict[name]
def module_basename(modname):
    """Return the part of a dotted module path that precedes the first dot."""
    root, _, _ = modname.partition(".")
    return root
def _set_default_errornames(modname, error_strings, call=False):
    """Fill *error_strings* in place with defaults for the keys it lacks.

    Keys already present are left untouched. 'caller', 'install_name' and
    'msg' are always defaulted; 'msg_callable' only when *call* is true.
    """
    # We don't set the modulename default here because it will change for
    # parents of lazily imported submodules.
    fill = error_strings.setdefault
    # _caller_name must be called from this very frame: depth 3 counts
    # _caller_name itself (0), this function (1) and its direct caller (2).
    fill("caller", _caller_name(3, default="Python"))
    fill("install_name", module_basename(modname))
    fill("msg", _MSG)
    if not call:
        return
    fill("msg_callable", _MSG_CALLABLE)
def _caller_name(depth=2, default=""):
    """Returns the name of the calling namespace.

    Parameters
    ----------
    depth : int
        Number of frames to go up the call stack, where 0 is this function's
        own frame.
    default : str
        Value returned when the name cannot be determined.

    Returns
    -------
    str
        The ``__name__`` found in the globals of the selected frame, or
        *default* when frames are unavailable, the stack is shallower than
        *depth*, or those globals define no ``__name__``.
    """
    # the presence of sys._getframe might be implementation-dependent.
    # It isn't that serious if we can't get the caller's name.
    try:
        frame = sys._getframe(depth)
    except (AttributeError, ValueError):
        # AttributeError: no sys._getframe on this implementation.
        # ValueError: the call stack is not *depth* frames deep.
        return default
    # Code run through exec() with custom globals may not define __name__.
    return frame.f_globals.get("__name__", default)
def _clean_lazymodule(module):
    """Strip the lazy behavior from a module's class ahead of a real import.

    The module attributes named in the class deletion dictionaries (the class
    attributes listed in `_DELETION_DICT`) are deleted first. Attribute
    access on the class is then pointed back at the plain ``ModuleType``
    implementations, and every class attribute listed in `_CLS_ATTRS` is
    removed.

    Parameters
    ----------
    module: LazyModule

    Returns
    -------
    dict
        The class attributes that were read while stripping, keyed by name.
        Pass it to :func:`_reset_lazymodule` to restore the lazy state.
    """
    lazy_cls = type(module)
    _clean_lazy_submod_refs(module)
    lazy_cls.__getattribute__ = ModuleType.__getattribute__
    lazy_cls.__setattr__ = ModuleType.__setattr__
    stashed = {}
    for attr_name in _CLS_ATTRS:
        try:
            # Stash before deleting: the value is kept even when the delete
            # itself fails (e.g. the attribute is only inherited).
            stashed[attr_name] = getattr(lazy_cls, attr_name)
            delattr(lazy_cls, attr_name)
        except AttributeError:
            continue
    return stashed
def _clean_lazy_submod_refs(module):
    """Delete from *module* the attributes named in its class deletion dicts.

    For every class attribute listed in `_DELETION_DICT` that exists on the
    module's class, each name it contains is deleted from the module through
    the attribute machinery of LazyModule's parent class. Names that are not
    set on the module are skipped silently.
    """
    lazy_cls = type(module)
    absent = object()
    for registry_name in _DELETION_DICT:
        registry = getattr(lazy_cls, registry_name, absent)
        if registry is absent:
            continue
        for attr_name in registry:
            try:
                super(LazyModule, module).__delattr__(attr_name)
            except AttributeError:
                # Maybe raise a warning?
                pass
def _reset_lazymodule(module, cls_attrs):
    """Restore a module's lazy state from the data cached while cleaning it.

    Parameters
    ----------
    module: LazyModule
    cls_attrs: dict
        Class attributes keyed by name, as returned by
        :func:`_clean_lazymodule`. Names from `_CLS_ATTRS` missing from it
        are skipped.
    """
    lazy_cls = type(module)
    # Drop the plain overrides that were installed on the class for loading.
    delattr(lazy_cls, "__getattribute__")
    delattr(lazy_cls, "__setattr__")
    try:
        delattr(lazy_cls, "_LOADING")
    except AttributeError:
        pass
    for attr_name in _CLS_ATTRS:
        try:
            setattr(lazy_cls, attr_name, cls_attrs[attr_name])
        except KeyError:
            continue
    _reset_lazy_submod_refs(module)
def _reset_lazy_submod_refs(module):
    """Set on *module* every name/value pair held in its class deletion dicts.

    For every class attribute listed in `_DELETION_DICT` that exists on the
    module's class, each ``name -> value`` item it holds is assigned on the
    module through the attribute machinery of LazyModule's parent class.
    """
    lazy_cls = type(module)
    absent = object()
    for registry_name in _DELETION_DICT:
        registry = getattr(lazy_cls, registry_name, absent)
        if registry is absent:
            continue
        for attr_name, submodule in registry.items():
            super(LazyModule, module).__setattr__(attr_name, submodule)
def run_from_ipython():
    """Return True when the name ``__IPYTHON__`` resolves, False otherwise."""
    # Taken from https://stackoverflow.com/questions/5376837
    try:
        __IPYTHON__
    except NameError:
        return False
    return True

View File

@ -0,0 +1,46 @@
# Code copied from https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/util/lazy_loader.py
"""A LazyLoader class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import importlib
import types
class LazyLoader(types.ModuleType):
    """Stand-in module that imports its target on first attribute lookup.

    Useful to avoid pulling in large dependencies that are not always needed:
    the real import only happens when an attribute that the stand-in does not
    have yet is requested, or when ``dir()`` is called on it.
    """

    # The lint error here is incorrect.
    def __init__(
        self, local_name, parent_module_globals, name
    ):  # pylint: disable=super-on-old-class
        # local_name: key under which the real module is stored in
        # parent_module_globals once loaded; name: module path to import.
        self._local_name = local_name
        self._parent_module_globals = parent_module_globals
        super(LazyLoader, self).__init__(name)

    def _load(self):
        """Import the target module and publish it in the parent's namespace."""
        target = importlib.import_module(self.__name__)
        self._parent_module_globals[self._local_name] = target
        # Mirror the real module's namespace on this object so that lookups
        # through a kept reference to the LazyLoader stay efficient
        # (__getattr__ is only called on lookups that fail).
        vars(self).update(vars(target))
        return target

    def __getattr__(self, item):
        return getattr(self._load(), item)

    def __dir__(self):
        return dir(self._load())

31
plume/utils/serve.py Normal file
View File

@ -0,0 +1,31 @@
from plume.utils import lazy_module
import typer
rpyc = lazy_module('rpyc')
app = typer.Typer()
class ASRService(rpyc.Service):
    """rpyc service that forwards utterances to an ASR recognizer.

    The recognizer passed to the constructor only needs a ``transcribe``
    method; its return value is handed back to the caller unchanged.
    """

    def __init__(self, asr_recognizer):
        self.asr = asr_recognizer

    def on_connect(self, conn):
        """Hook that runs when a connection is created; nothing to set up."""

    def on_disconnect(self, conn):
        """Hook that runs after a connection has closed; nothing to tear down."""

    def exposed_transcribe(self, utterance: bytes):  # this is an exposed method
        """Transcribe *utterance* and return the recognizer's result."""
        return self.asr.transcribe(utterance)

    def exposed_transcribe_cb(
        self, utterance: bytes, respond
    ):  # this is an exposed method
        """Transcribe *utterance* and pass the result to the *respond* callback."""
        transcript = self.asr.transcribe(utterance)
        respond(transcript)

184
plume/utils/transcribe.py Normal file
View File

@ -0,0 +1,184 @@
import os
import logging
from io import BytesIO
from pathlib import Path
from functools import lru_cache
import typer
# import rpyc
# from tqdm import tqdm
# from pydub import AudioSegment
# from pydub.silence import split_on_silence
from plume.utils import lazy_module, lazy_callable
rpyc = lazy_module('rpyc')
AudioSegment = lazy_callable('pydub.AudioSegment')
split_on_silence = lazy_callable('pydub.silence.split_on_silence')
app = typer.Typer()
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

# Connection settings, overridable through environment variables of the same
# name. The rpyc host used to be read from the misspelled "JASR_RPYC_HOST";
# that name is still honoured as a fallback so existing deployments keep
# working.
ASR_RPYC_HOST = os.environ.get(
    "ASR_RPYC_HOST", os.environ.get("JASR_RPYC_HOST", "localhost")
)
ASR_RPYC_PORT = int(os.environ.get("ASR_RPYC_PORT", "8044"))
TRITON_ASR_MODEL = os.environ.get("TRITON_ASR_MODEL", "slu_wav2vec2")
TRITON_GRPC_ASR_HOST = os.environ.get("TRITON_GRPC_ASR_HOST", "localhost")
TRITON_GRPC_ASR_PORT = int(os.environ.get("TRITON_GRPC_ASR_PORT", "8001"))
@lru_cache()
def transcribe_rpyc_gen(asr_host=ASR_RPYC_HOST, asr_port=ASR_RPYC_PORT):
    """Connect to an rpyc ASR server and return its transcriber.

    The result is memoized per ``(asr_host, asr_port)`` by ``lru_cache``, so
    repeated calls with the same arguments reuse the first connection.

    Parameters
    ----------
    asr_host : str
        Host name of the rpyc ASR server.
    asr_port : int
        Port of the rpyc ASR server.

    Returns
    -------
    tuple
        ``(transcribe, audio_prep)``: the remote ``transcribe`` callable and a
        function converting an audio segment to mono, 2-byte samples, 16 kHz.

    Raises
    ------
    Exception
        If the server refuses the connection; the original
        ``ConnectionRefusedError`` is chained as the cause.
    """
    logger.info(f"connecting to asr server at {asr_host}:{asr_port}")
    try:
        asr = rpyc.connect(asr_host, asr_port).root
    except ConnectionRefusedError as err:
        # Report the address actually tried: it may come from arguments as
        # well as from the environment.
        raise Exception(
            f"could not connect to asr server at {asr_host}:{asr_port}; "
            "check the configured rpyc host and port"
        ) from err
    logger.info("connected to asr server successfully")

    def audio_prep(aud_seg):
        return aud_seg.set_channels(1).set_sample_width(2).set_frame_rate(16000)

    return asr.transcribe, audio_prep
def triton_transcribe_grpc_gen(
    asr_host=TRITON_GRPC_ASR_HOST,
    asr_port=TRITON_GRPC_ASR_PORT,
    asr_model=TRITON_ASR_MODEL,
    method="chunked",
    chunk_msec=5000,
    sil_msec=500,
    sep=" ",
):
    """Build a transcriber backed by a Triton inference server over gRPC.

    Parameters
    ----------
    asr_host, asr_port
        Address of the Triton gRPC endpoint.
    asr_model : str
        Name of the model to run inference on.
    method : str
        ``"whole"`` sends the audio in a single request; ``"chunked"`` cuts
        it into pieces of *chunk_msec*; ``"silence"`` first splits on
        silence, then cuts every piece into chunks of *chunk_msec*.
    chunk_msec : int
        Chunk length in milliseconds for the chunked and silence methods.
    sil_msec : int
        Minimum silence length used for splitting, and length of the silence
        padded on both sides of every chunk, in milliseconds.
    sep : str
        Separator used to join the transcripts of the chunks.

    Returns
    -------
    tuple
        ``(transcriber, audio_prep)``: a function mapping an audio segment to
        text, and a function converting an audio segment to mono, 2-byte
        samples, 16 kHz.

    Raises
    ------
    ValueError
        If *method* is not one of the supported methods.
    """
    from tritonclient.utils import np_to_triton_dtype
    import tritonclient.grpc as grpcclient
    import numpy as np

    sup_meth = ["chunked", "silence", "whole"]
    if method not in sup_meth:
        meths = "|".join(sup_meth)
        raise ValueError(f"unsupported method {method}. pick one of {meths}")
    client = grpcclient.InferenceServerClient(f"{asr_host}:{asr_port}")

    def transcriber(aud_seg):
        """Send one audio segment as WAV bytes and return the decoded text."""
        af = BytesIO()
        aud_seg.export(af, format="wav")
        # dtype=object keeps the payload intact: a fixed-width bytes array
        # (dtype "S") drops trailing NUL bytes, which truncates audio that
        # ends in digital silence.
        input_audio_data = np.array([af.getvalue()], dtype=np.object_)
        infer_input = grpcclient.InferInput(
            "INPUT_AUDIO",
            input_audio_data.shape,
            np_to_triton_dtype(input_audio_data.dtype),
        )
        infer_input.set_data_from_numpy(input_audio_data)
        outputs = [grpcclient.InferRequestedOutput("OUTPUT_TEXT")]
        response = client.infer(
            asr_model, [infer_input], request_id=str(1), outputs=outputs
        )
        transcript = response.as_numpy("OUTPUT_TEXT")[0]
        return transcript.decode("utf-8")

    def chunked_transcriber(aud_seg):
        """Transcribe the segment piecewise and join the partial transcripts."""
        # The real class is needed for ``AudioSegment.silent``: the
        # module-level lazy_callable wrapper is documented as a thin wrapper
        # that is only meant to be called.
        from pydub import AudioSegment

        if method == "silence":
            sil_chunks = split_on_silence(
                aud_seg,
                min_silence_len=sil_msec,
                silence_thresh=-50,
                keep_silence=500,
            )
            chunks = [sc for c in sil_chunks for sc in c[::chunk_msec]]
        else:
            chunks = aud_seg[::chunk_msec]
        # TODO: overlapping chunks are not implemented.
        sil_pad = AudioSegment.silent(duration=sil_msec)
        transcript_list = [transcriber(sil_pad + seg + sil_pad) for seg in chunks]
        return sep.join(transcript_list)

    def audio_prep(aud_seg):
        return aud_seg.set_channels(1).set_sample_width(2).set_frame_rate(16000)

    whole_transcriber = transcriber if method == "whole" else chunked_transcriber
    return whole_transcriber, audio_prep
@app.command()
def file(audio_file: Path, write_file: bool = False, chunked: bool = True):
    """Transcribe an audio file through the Triton ASR server and print the text.

    Parameters
    ----------
    audio_file : Path
        Audio file to transcribe.
    write_file : bool
        Also write the transcription next to the audio file, with a ``.txt``
        suffix.
    chunked : bool
        Transcribe the audio in chunks (the default) instead of sending it as
        a whole in a single request.
    """
    from pydub import AudioSegment

    aseg = AudioSegment.from_file(audio_file)
    # ``chunked`` used to be accepted but ignored; the default still selects
    # the chunked method, exactly as before.
    transcriber, prep = triton_transcribe_grpc_gen(
        method="chunked" if chunked else "whole"
    )
    transcription = transcriber(prep(aseg))
    typer.echo(transcription)
    if write_file:
        audio_file.with_suffix(".txt").write_text(transcription)
@app.command()
def benchmark(audio_file: Path):
    """Time transcription of one audio file on the rpyc and Triton backends.

    Each backend is timed with ``timeit`` (10 repeats of 100 calls) and the
    best per-call time is printed; the two runs are five seconds apart.
    """
    from pydub import AudioSegment
    from timeit import Timer
    import time

    def report(transcribe, segment):
        # Print the best per-call time in milliseconds.
        loops, rounds = 100, 10
        timings = Timer(lambda: transcribe(segment)).repeat(rounds, number=loops)
        best = min(timings) * 1000 / loops
        print(f"{loops} loops, best of {rounds}: {best:.3f} msec per loop")

    rpyc_transcriber, rpyc_prep = transcribe_rpyc_gen()
    file_seg = AudioSegment.from_file(audio_file)
    report(rpyc_transcriber, rpyc_prep(file_seg))

    time.sleep(5)

    triton_transcriber, triton_prep = triton_transcribe_grpc_gen()
    report(triton_transcriber, triton_prep(file_seg))
def main():
    """Run the typer application defined in this module."""
    app()


# Allow running this module directly as a script.
if __name__ == "__main__":
    main()

92
plume/utils/tts.py Normal file
View File

@ -0,0 +1,92 @@
from logging import getLogger
from plume.utils import lazy_module
from pathlib import Path
import typer
# from google.cloud import texttospeech
texttospeech = lazy_module('google.cloud.texttospeech')
LOGGER = getLogger("googletts")
app = typer.Typer()
class GoogleTTS(object):
    """Wrapper around the Google Cloud Text-to-Speech client."""

    def __init__(self):
        self.client = texttospeech.TextToSpeechClient()

    def _synthesize(self, tts_input, params: dict) -> bytes:
        """Synthesize *tts_input* as LINEAR16 audio and return the audio bytes.

        *params* must provide the keys "language", "name" and "sample_rate".
        """
        voice = texttospeech.types.VoiceSelectionParams(
            language_code=params["language"], name=params["name"]
        )
        audio_config = texttospeech.types.AudioConfig(
            audio_encoding=texttospeech.enums.AudioEncoding.LINEAR16,
            sample_rate_hertz=params["sample_rate"],
        )
        response = self.client.synthesize_speech(tts_input, voice, audio_config)
        return response.audio_content

    def text_to_speech(self, text: str, params: dict) -> bytes:
        """Synthesize plain *text* with the voice described by *params*."""
        return self._synthesize(texttospeech.types.SynthesisInput(text=text), params)

    def ssml_to_speech(self, text: str, params: dict) -> bytes:
        """Synthesize SSML markup *text* with the voice described by *params*."""
        return self._synthesize(texttospeech.types.SynthesisInput(ssml=text), params)

    @classmethod
    def voice_list(cls):
        """Lists the available voices that support an English language code.

        Returns a list of dicts with the keys "name", "language" (the
        comma-joined English language codes), "gender", "engine" and
        "sample_rate".
        """
        client = cls().client
        # Performs the list voices request
        voices = client.list_voices()
        results = []
        for voice in voices.voices:
            supported_eng_langs = [
                lang for lang in voice.language_codes if lang.startswith("en")
            ]
            if not supported_eng_langs:
                continue
            ssml_gender = texttospeech.enums.SsmlVoiceGender(voice.ssml_gender)
            results.append(
                {
                    "name": voice.name,
                    "language": ",".join(supported_eng_langs),
                    "gender": ssml_gender.name,
                    "engine": "wavenet" if "Wav" in voice.name else "standard",
                    "sample_rate": voice.natural_sample_rate_hertz,
                }
            )
        return results
@app.command()
def generate_audio_file(
    text: str,
    dest_path: Path = Path("./tts_audio.wav"),
    voice: str = "en-US-Wavenet-D",
):
    """Synthesize *text* with a Google TTS voice and write the audio to a file.

    Parameters
    ----------
    text : str
        Text to synthesize.
    dest_path : Path
        File the audio bytes are written to.
    voice : str
        Name of the voice; it must be one of those returned by
        ``GoogleTTS.voice_list``.

    Raises
    ------
    typer.BadParameter
        If *voice* is not among the listed voices.
    """
    # Direct Python callers may still pass a plain string path.
    dest_path = Path(dest_path)
    tts = GoogleTTS()
    selected_voice = next((v for v in tts.voice_list() if v["name"] == voice), None)
    if selected_voice is None:
        raise typer.BadParameter(f"voice {voice!r} is not among the available voices")
    wav_data = tts.text_to_speech(text, selected_voice)
    dest_path.write_bytes(wav_data)
def main():
    """Run the typer application defined in this module."""
    app()


# Allow running this module directly as a script.
if __name__ == "__main__":
    main()

103
setup.py
View File

@ -1,81 +1,80 @@
from setuptools import setup, find_packages
# pip install "nvidia-pyindex~=1.0.5"
requirements = [
"ruamel.yaml",
"torch==1.4.0",
"torchvision==0.5.0",
"torch~=1.6.0",
"torchvision~=0.7.0",
"nemo_toolkit[asr] @ git+https://github.com/NVIDIA/NeMo.git@09e3ba4dfe333f86d6c5c1048e07210924294be9#egg=nemo_toolkit",
"fairseq @ git+https://github.com/pytorch/fairseq.git@94a1b924f3adec25c8c508ac112410d02b400d1e#egg=fairseq",
# "google-cloud-texttospeech~=1.0.1",
"tqdm~=4.54.0",
# "pydub~=0.24.0",
# "scikit_learn~=0.22.1",
# "pandas~=1.0.3",
# "boto3~=1.12.35",
# "ruamel.yaml~=0.16.10",
# "pymongo==3.10.1",
# "matplotlib==3.2.1",
# "tabulate==0.8.7",
# "natural==0.2.0",
# "num2words==0.5.10",
"typer[all]~=0.3.2",
# "python-slugify==4.0.0",
# "websockets==8.1",
# "lenses @ git+https://github.com/ingolemo/python-lenses.git@b2a2a9aa5b61540992d70b2cf36008d0121e8948#egg=lenses",
"rpyc~=4.1.4",
# "streamlit~=0.61.0",
# "librosa~=0.7.2",
# "tritonclient[http]~=2.6.0",
"numba~=0.48.0",
]
extra_requirements = {
"server": ["rpyc~=4.1.4", "tqdm~=4.39.0"],
"data": [
"google-cloud-texttospeech~=1.0.1",
"tqdm~=4.39.0",
"pydub~=0.24.0",
"google-cloud-texttospeech~=1.0.1",
"scikit_learn~=0.22.1",
"pandas~=1.0.3",
"boto3~=1.12.35",
"ruamel.yaml==0.16.10",
"pymongo==3.10.1",
"librosa==0.7.2",
"numba==0.48",
"matplotlib==3.2.1",
"pandas==1.0.3",
"tabulate==0.8.7",
"natural==0.2.0",
"num2words==0.5.10",
"typer[all]==0.3.1",
"python-slugify==4.0.0",
"ruamel.yaml~=0.16.10",
"pymongo~=3.10.1",
"librosa~=0.7.2",
"matplotlib~=3.2.1",
"pandas~=1.0.3",
"tabulate~=0.8.7",
"natural~=0.2.0",
"num2words~=0.5.10",
"python-slugify~=4.0.0",
"rpyc~=4.1.4",
"lenses @ git+https://github.com/ingolemo/python-lenses.git@b2a2a9aa5b61540992d70b2cf36008d0121e8948#egg=lenses",
# "lenses @ git+https://github.com/ingolemo/python-lenses.git@b2a2a9aa5b61540992d70b2cf36008d0121e8948#egg=lenses",
],
"validation": [
"rpyc~=4.1.4",
"pymongo==3.10.1",
"typer[all]==0.1.1",
"tqdm~=4.39.0",
"librosa==0.7.2",
"matplotlib==3.2.1",
"pymongo~=3.10.1",
"matplotlib~=3.2.1",
"pydub~=0.24.0",
"streamlit==0.58.0",
"natural==0.2.0",
"stringcase==1.2.0",
"streamlit~=0.58.0",
"natural~=0.2.0",
"stringcase~=1.2.0",
"google-cloud-speech~=1.3.1",
]
# "train": [
# "torchaudio==0.5.0",
# "torch-stft==0.1.4",
# ]
],
"train": ["torchaudio~=0.6.0", "torch-stft~=0.1.4"],
}
extra_requirements["all"] = list({d for l in extra_requirements.values() for d in l})
packages = find_packages()
setup(
name="jasper-asr",
version="0.1",
description="Tool to get gcp alignments of tts-data",
url="http://github.com/malarinv/jasper-asr",
name="plume-asr",
version="0.11",
description="Multi model ASR base package",
url="http://github.com/malarinv/plume-asr",
author="Malar Kannan",
author_email="malarkannan.invention@gmail.com",
license="MIT",
install_requires=requirements,
extras_require=extra_requirements,
packages=packages,
entry_points={
"console_scripts": [
"jasper_transcribe = jasper.transcribe:main",
"jasper_server = jasper.server:main",
"jasper_trainer = jasper.training.cli:main",
"jasper_evaluator = jasper.evaluate:main",
"jasper_data_tts_generate = jasper.data.tts_generator:main",
"jasper_data_conv_generate = jasper.data.conv_generator:main",
"jasper_data_nlu_generate = jasper.data.nlu_generator:main",
"jasper_data_rastrik_recycle = jasper.data.rastrik_recycler:main",
"jasper_data_server = jasper.data.server:main",
"jasper_data_validation = jasper.data.validation.process:main",
"jasper_data_preprocess = jasper.data.process:main",
"jasper_data_slu_evaluate = jasper.data.slu_evaluator:main",
]
},
entry_points={"console_scripts": ["plume = plume.cli:main"]},
zip_safe=False,
)

View File

@ -1,3 +0,0 @@
import runpy
runpy.run_module("jasper.data.validation.ui", run_name="__main__", alter_sys=True)