massive refactor/rename to plume

parent e8f58a5043
commit ed6117559a

Notes.md (14 lines changed)

@@ -3,3 +3,17 @@
 ```
 diff <(cat data/asr_data/call_upwork_test_cnd_*/manifest.json |sort) <(cat data/asr_data/call_upwork_test_cnd/manifest.json |sort)
 ```
+
+> Prepare Augmented Data
+```
+plume data filter /dataset/png_entities/png_numbers_2020_07/ /dataset/png_entities/png_numbers_2020_07_skip1hour/
+
+plume data augment /dataset/agara_slu/call_alphanum_ag_sg_v1_abs/ /dataset/png_entities/png_numbers_2020_07_1hour_noblank/ /dataset/png_entities/png_numbers_2020_07_skip1hour/ /dataset/png_entities/aug_pngskip1hour-agsgalnum-1hournoblank/
+
+plume data filter --kind transform_digits /dataset/agara_slu/png1hour-agsgalnum-1hournoblank/ /dataset/agara_slu/png1hour-agsgalnum-1hournoblank_prep/
+```
+
+
+```
+KENLM_INC=/usr/local/include/kenlm/ pip install -e ../deps/wav2letter/bindings/python/
+```
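The plume data commands added above operate on NeMo-style manifest.json files: one JSON object per line carrying audio_filepath, duration and text keys, the same layout used by the asr_manifest_reader/writer helpers elsewhere in this commit. A minimal reader sketch under that assumption (file path illustrative):

```python
import json
from pathlib import Path


def read_manifest(manifest_path: Path):
    """Yield one dict per line of a JSON-lines ASR manifest."""
    with manifest_path.open("r") as mf:
        for line in mf:
            line = line.strip()
            if not line:
                continue
            # each entry is expected to have audio_filepath, duration and text
            yield json.loads(line)


# example: total audio duration of a dataset
total = sum(d["duration"] for d in read_manifest(Path("manifest.json")))
print(f"{total:.1f}s of audio")
```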
README.md (22 lines changed)

@@ -1,8 +1,8 @@
-# Jasper ASR
+# Plume ASR
 
 [](https://github.com/python/black)
 
-> Generates text from speech audio
+> Generates text from audio containing speech
 ---
 
 # Table of Contents
@@ -20,7 +20,7 @@
 # Features
 
 * ASR using Jasper (from [NemoToolkit](https://github.com/NVIDIA/NeMo) )
+* ASR using Wav2Vec2 (from [fairseq](https://github.com/pytorch/fairseq) )
 
 # Installation
 To install the packages and its dependencies run.
 
@@ -29,14 +29,26 @@ python setup.py install
 ```
 or with pip
 ```bash
-pip install .[server]
+pip install .[all]
 ```
 
 The installation should work on Python 3.6 or newer. Untested on Python 2.7
 
 # Usage
+### Library
+> Jasper
 ```python
-from jasper.asr import JasperASR
+from plume.models.jasper.asr import JasperASR
 asr_model = JasperASR("/path/to/model_config_yaml","/path/to/encoder_checkpoint","/path/to/decoder_checkpoint") # Loads the models
 TEXT = asr_model.transcribe(wav_data) # Returns the text spoken in the wav
 ```
+> Wav2Vec2
+```python
+from plume.models.wav2vec2.asr import Wav2Vec2ASR
+asr_model = Wav2Vec2ASR("/path/to/ctc_checkpoint","/path/to/w2v_checkpoint","/path/to/target_dictionary") # Loads the models
+TEXT = asr_model.transcribe(wav_data) # Returns the text spoken in the wav
+```
+### Command Line
+```
+$ plume
+```
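In both Library examples, transcribe() receives the raw bytes of a wav recording. A minimal, illustrative continuation of the snippet above, assuming a 16 kHz mono 16-bit wav on disk (the file name is hypothetical; the format is the one the old jasper_transcribe CLI documented):

```python
from pathlib import Path

# hypothetical input file; transcribe() expects the wav file's raw bytes
wav_data = Path("utterance.wav").read_bytes()
text = asr_model.transcribe(wav_data)  # asr_model from the Library example above
print(text)
```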
@@ -1 +0,0 @@

@@ -1,21 +0,0 @@
-import os
-import logging
-import rpyc
-from functools import lru_cache
-
-logging.basicConfig(
-    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
-)
-logger = logging.getLogger(__name__)
-
-
-ASR_HOST = os.environ.get("JASPER_ASR_RPYC_HOST", "localhost")
-ASR_PORT = int(os.environ.get("JASPER_ASR_RPYC_PORT", "8045"))
-
-
-@lru_cache()
-def transcribe_gen(asr_host=ASR_HOST, asr_port=ASR_PORT):
-    logger.info(f"connecting to asr server at {asr_host}:{asr_port}")
-    asr = rpyc.connect(asr_host, asr_port).root
-    logger.info(f"connected to asr server successfully")
-    return asr.transcribe

@@ -1 +0,0 @@

@@ -1,77 +0,0 @@
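For context, the removed client module above paired with the rpyc ASR server that is also deleted in this commit: transcribe_gen() connects to the host and port named by JASPER_ASR_RPYC_HOST/PORT and hands back the server's remote transcribe callable. A usage sketch (input file name illustrative):

```python
# sketch only: requires a running jasper ASR rpyc server (see the removed serve.py)
from pathlib import Path
from jasper.client import transcribe_gen  # module deleted by this commit

transcribe = transcribe_gen()                   # cached rpyc connection
wav_bytes = Path("utterance.wav").read_bytes()  # illustrative raw wav input
print(transcribe(wav_bytes))
```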
import json
|
|
||||||
from pathlib import Path
|
|
||||||
from sklearn.model_selection import train_test_split
|
|
||||||
from .utils import asr_manifest_reader, asr_manifest_writer
|
|
||||||
from typing import List
|
|
||||||
from itertools import chain
|
|
||||||
import typer
|
|
||||||
|
|
||||||
app = typer.Typer()
|
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
|
||||||
def fixate_data(dataset_path: Path):
|
|
||||||
manifest_path = dataset_path / Path("manifest.json")
|
|
||||||
real_manifest_path = dataset_path / Path("abs_manifest.json")
|
|
||||||
|
|
||||||
def fix_path():
|
|
||||||
for i in asr_manifest_reader(manifest_path):
|
|
||||||
i["audio_filepath"] = str(dataset_path / Path(i["audio_filepath"]))
|
|
||||||
yield i
|
|
||||||
|
|
||||||
asr_manifest_writer(real_manifest_path, fix_path())
|
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
|
||||||
def augment_data(src_dataset_paths: List[Path], dest_dataset_path: Path):
|
|
||||||
reader_list = []
|
|
||||||
abs_manifest_path = Path("abs_manifest.json")
|
|
||||||
for dataset_path in src_dataset_paths:
|
|
||||||
manifest_path = dataset_path / abs_manifest_path
|
|
||||||
reader_list.append(asr_manifest_reader(manifest_path))
|
|
||||||
dest_dataset_path.mkdir(parents=True, exist_ok=True)
|
|
||||||
dest_manifest_path = dest_dataset_path / abs_manifest_path
|
|
||||||
asr_manifest_writer(dest_manifest_path, chain(*reader_list))
|
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
|
||||||
def split_data(dataset_path: Path, test_size: float = 0.1):
|
|
||||||
manifest_path = dataset_path / Path("abs_manifest.json")
|
|
||||||
asr_data = list(asr_manifest_reader(manifest_path))
|
|
||||||
train_data, test_data = train_test_split(asr_data, test_size=test_size)
|
|
||||||
asr_manifest_writer(manifest_path.with_name("train_manifest.json"), train_data)
|
|
||||||
asr_manifest_writer(manifest_path.with_name("test_manifest.json"), test_data)
|
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
|
||||||
def validate_data(dataset_path: Path):
|
|
||||||
from natural.date import compress
|
|
||||||
from datetime import timedelta
|
|
||||||
|
|
||||||
for mf_type in ["train_manifest.json", "test_manifest.json"]:
|
|
||||||
data_file = dataset_path / Path(mf_type)
|
|
||||||
print(f"validating {data_file}.")
|
|
||||||
with Path(data_file).open("r") as pf:
|
|
||||||
data_jsonl = pf.readlines()
|
|
||||||
duration = 0
|
|
||||||
for (i, s) in enumerate(data_jsonl):
|
|
||||||
try:
|
|
||||||
d = json.loads(s)
|
|
||||||
duration += d["duration"]
|
|
||||||
audio_file = data_file.parent / Path(d["audio_filepath"])
|
|
||||||
if not audio_file.exists():
|
|
||||||
raise OSError(f"File {audio_file} not found")
|
|
||||||
except BaseException as e:
|
|
||||||
print(f'failed on {i} with "{e}"')
|
|
||||||
duration_str = compress(timedelta(seconds=duration), pad=" ")
|
|
||||||
print(
|
|
||||||
f"no errors found. seems like a valid {mf_type}. contains {duration_str}sec of audio"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
app()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -1,93 +0,0 @@
from rastrik.proto.callrecord_pb2 import CallRecord
|
|
||||||
import gzip
|
|
||||||
from pydub import AudioSegment
|
|
||||||
from .utils import ui_dump_manifest_writer, strip_silence
|
|
||||||
|
|
||||||
import typer
|
|
||||||
from itertools import chain
|
|
||||||
from io import BytesIO
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
app = typer.Typer()
|
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
|
||||||
def extract_manifest(
|
|
||||||
call_log_dir: Path = Path("./data/call_audio"),
|
|
||||||
output_dir: Path = Path("./data"),
|
|
||||||
dataset_name: str = "grassroot_pizzahut_v1",
|
|
||||||
caller_name: str = "grassroot",
|
|
||||||
verbose: bool = False,
|
|
||||||
):
|
|
||||||
call_asr_data: Path = output_dir / Path("asr_data")
|
|
||||||
call_asr_data.mkdir(exist_ok=True, parents=True)
|
|
||||||
|
|
||||||
def wav_pb2_generator(log_dir):
|
|
||||||
for wav_path in log_dir.glob("**/*.wav"):
|
|
||||||
if verbose:
|
|
||||||
typer.echo(f"loading events for file {wav_path}")
|
|
||||||
call_wav = AudioSegment.from_file_using_temporary_files(wav_path)
|
|
||||||
meta_path = wav_path.with_suffix(".pb2.gz")
|
|
||||||
yield call_wav, wav_path, meta_path
|
|
||||||
|
|
||||||
def read_event(call_wav, log_file):
|
|
||||||
call_wav_0, call_wav_1 = call_wav.split_to_mono()
|
|
||||||
with gzip.open(log_file, "rb") as log_h:
|
|
||||||
record_data = log_h.read()
|
|
||||||
cr = CallRecord()
|
|
||||||
cr.ParseFromString(record_data)
|
|
||||||
|
|
||||||
first_audio_event_timestamp = next(
|
|
||||||
(
|
|
||||||
i
|
|
||||||
for i in cr.events
|
|
||||||
if i.WhichOneof("event_type") == "call_event"
|
|
||||||
and i.call_event.WhichOneof("event_type") == "call_audio"
|
|
||||||
)
|
|
||||||
).timestamp.ToDatetime()
|
|
||||||
|
|
||||||
speech_events = [
|
|
||||||
i
|
|
||||||
for i in cr.events
|
|
||||||
if i.WhichOneof("event_type") == "speech_event"
|
|
||||||
and i.speech_event.WhichOneof("event_type") == "asr_final"
|
|
||||||
]
|
|
||||||
previous_event_timestamp = (
|
|
||||||
first_audio_event_timestamp - first_audio_event_timestamp
|
|
||||||
)
|
|
||||||
for index, each_speech_events in enumerate(speech_events):
|
|
||||||
asr_final = each_speech_events.speech_event.asr_final
|
|
||||||
speech_timestamp = each_speech_events.timestamp.ToDatetime()
|
|
||||||
actual_timestamp = speech_timestamp - first_audio_event_timestamp
|
|
||||||
start_time = previous_event_timestamp.total_seconds() * 1000
|
|
||||||
end_time = actual_timestamp.total_seconds() * 1000
|
|
||||||
audio_segment = strip_silence(call_wav_1[start_time:end_time])
|
|
||||||
|
|
||||||
code_fb = BytesIO()
|
|
||||||
audio_segment.export(code_fb, format="wav")
|
|
||||||
wav_data = code_fb.getvalue()
|
|
||||||
previous_event_timestamp = actual_timestamp
|
|
||||||
duration = (end_time - start_time) / 1000
|
|
||||||
yield asr_final, duration, wav_data, "grassroot", audio_segment
|
|
||||||
|
|
||||||
def generate_call_asr_data():
|
|
||||||
full_data = []
|
|
||||||
total_duration = 0
|
|
||||||
for wav, wav_path, pb2_path in wav_pb2_generator(call_log_dir):
|
|
||||||
asr_data = read_event(wav, pb2_path)
|
|
||||||
total_duration += wav.duration_seconds
|
|
||||||
full_data.append(asr_data)
|
|
||||||
n_calls = len(full_data)
|
|
||||||
typer.echo(f"loaded {n_calls} calls of duration {total_duration}s")
|
|
||||||
n_dps = ui_dump_manifest_writer(call_asr_data, dataset_name, chain(*full_data))
|
|
||||||
typer.echo(f"written {n_dps} data points")
|
|
||||||
|
|
||||||
generate_call_asr_data()
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
app()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -1,241 +0,0 @@
import io
|
|
||||||
import os
|
|
||||||
import json
|
|
||||||
import wave
|
|
||||||
from pathlib import Path
|
|
||||||
from functools import partial
|
|
||||||
from uuid import uuid4
|
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
|
||||||
|
|
||||||
import pymongo
|
|
||||||
from slugify import slugify
|
|
||||||
from jasper.client import transcribe_gen
|
|
||||||
from nemo.collections.asr.metrics import word_error_rate
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import librosa
|
|
||||||
import librosa.display
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
|
|
||||||
def manifest_str(path, dur, text):
|
|
||||||
return (
|
|
||||||
json.dumps({"audio_filepath": path, "duration": round(dur, 1), "text": text})
|
|
||||||
+ "\n"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def wav_bytes(audio_bytes, frame_rate=24000):
|
|
||||||
wf_b = io.BytesIO()
|
|
||||||
with wave.open(wf_b, mode="w") as wf:
|
|
||||||
wf.setnchannels(1)
|
|
||||||
wf.setframerate(frame_rate)
|
|
||||||
wf.setsampwidth(2)
|
|
||||||
wf.writeframesraw(audio_bytes)
|
|
||||||
return wf_b.getvalue()
|
|
||||||
|
|
||||||
|
|
||||||
def tscript_uuid_fname(transcript):
|
|
||||||
return str(uuid4()) + "_" + slugify(transcript, max_length=8)
|
|
||||||
|
|
||||||
|
|
||||||
def asr_data_writer(output_dir, dataset_name, asr_data_source, verbose=False):
|
|
||||||
dataset_dir = output_dir / Path(dataset_name)
|
|
||||||
(dataset_dir / Path("wav")).mkdir(parents=True, exist_ok=True)
|
|
||||||
asr_manifest = dataset_dir / Path("manifest.json")
|
|
||||||
num_datapoints = 0
|
|
||||||
with asr_manifest.open("w") as mf:
|
|
||||||
print(f"writing manifest to {asr_manifest}")
|
|
||||||
for transcript, audio_dur, wav_data in asr_data_source:
|
|
||||||
fname = tscript_uuid_fname(transcript)
|
|
||||||
audio_file = dataset_dir / Path("wav") / Path(fname).with_suffix(".wav")
|
|
||||||
audio_file.write_bytes(wav_data)
|
|
||||||
rel_data_path = audio_file.relative_to(dataset_dir)
|
|
||||||
manifest = manifest_str(str(rel_data_path), audio_dur, transcript)
|
|
||||||
mf.write(manifest)
|
|
||||||
if verbose:
|
|
||||||
print(f"writing '{transcript}' of duration {audio_dur}")
|
|
||||||
num_datapoints += 1
|
|
||||||
return num_datapoints
|
|
||||||
|
|
||||||
|
|
||||||
def ui_data_generator(output_dir, dataset_name, asr_data_source, verbose=False):
|
|
||||||
dataset_dir = output_dir / Path(dataset_name)
|
|
||||||
(dataset_dir / Path("wav")).mkdir(parents=True, exist_ok=True)
|
|
||||||
(dataset_dir / Path("wav_plots")).mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
def data_fn(
|
|
||||||
transcript,
|
|
||||||
audio_dur,
|
|
||||||
wav_data,
|
|
||||||
caller_name,
|
|
||||||
aud_seg,
|
|
||||||
fname,
|
|
||||||
audio_path,
|
|
||||||
num_datapoints,
|
|
||||||
rel_data_path,
|
|
||||||
):
|
|
||||||
pretrained_result = transcriber_pretrained(aud_seg.raw_data)
|
|
||||||
pretrained_wer = word_error_rate([transcript], [pretrained_result])
|
|
||||||
png_path = Path(fname).with_suffix(".png")
|
|
||||||
wav_plot_path = dataset_dir / Path("wav_plots") / png_path
|
|
||||||
if not wav_plot_path.exists():
|
|
||||||
plot_seg(wav_plot_path, audio_path)
|
|
||||||
return {
|
|
||||||
"audio_filepath": str(rel_data_path),
|
|
||||||
"duration": round(audio_dur, 1),
|
|
||||||
"text": transcript,
|
|
||||||
"real_idx": num_datapoints,
|
|
||||||
"audio_path": audio_path,
|
|
||||||
"spoken": transcript,
|
|
||||||
"caller": caller_name,
|
|
||||||
"utterance_id": fname,
|
|
||||||
"pretrained_asr": pretrained_result,
|
|
||||||
"pretrained_wer": pretrained_wer,
|
|
||||||
"plot_path": str(wav_plot_path),
|
|
||||||
}
|
|
||||||
|
|
||||||
num_datapoints = 0
|
|
||||||
data_funcs = []
|
|
||||||
transcriber_pretrained = transcribe_gen(asr_port=8044)
|
|
||||||
for transcript, audio_dur, wav_data, caller_name, aud_seg in asr_data_source:
|
|
||||||
fname = str(uuid4()) + "_" + slugify(transcript, max_length=8)
|
|
||||||
audio_file = dataset_dir / Path("wav") / Path(fname).with_suffix(".wav")
|
|
||||||
audio_file.write_bytes(wav_data)
|
|
||||||
audio_path = str(audio_file)
|
|
||||||
rel_data_path = audio_file.relative_to(dataset_dir)
|
|
||||||
data_funcs.append(
|
|
||||||
partial(
|
|
||||||
data_fn,
|
|
||||||
transcript,
|
|
||||||
audio_dur,
|
|
||||||
wav_data,
|
|
||||||
caller_name,
|
|
||||||
aud_seg,
|
|
||||||
fname,
|
|
||||||
audio_path,
|
|
||||||
num_datapoints,
|
|
||||||
rel_data_path,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
num_datapoints += 1
|
|
||||||
ui_data = parallel_apply(lambda x: x(), data_funcs)
|
|
||||||
return ui_data, num_datapoints
|
|
||||||
|
|
||||||
|
|
||||||
def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=False):
|
|
||||||
dataset_dir = output_dir / Path(dataset_name)
|
|
||||||
dump_data, num_datapoints = ui_data_generator(
|
|
||||||
output_dir, dataset_name, asr_data_source, verbose=verbose
|
|
||||||
)
|
|
||||||
|
|
||||||
asr_manifest = dataset_dir / Path("manifest.json")
|
|
||||||
with asr_manifest.open("w") as mf:
|
|
||||||
print(f"writing manifest to {asr_manifest}")
|
|
||||||
for d in dump_data:
|
|
||||||
rel_data_path = d["audio_filepath"]
|
|
||||||
audio_dur = d["duration"]
|
|
||||||
transcript = d["text"]
|
|
||||||
manifest = manifest_str(str(rel_data_path), audio_dur, transcript)
|
|
||||||
mf.write(manifest)
|
|
||||||
|
|
||||||
ui_dump_file = dataset_dir / Path("ui_dump.json")
|
|
||||||
ExtendedPath(ui_dump_file).write_json({"data": dump_data})
|
|
||||||
return num_datapoints
|
|
||||||
|
|
||||||
|
|
||||||
def asr_manifest_reader(data_manifest_path: Path):
|
|
||||||
print(f"reading manifest from {data_manifest_path}")
|
|
||||||
with data_manifest_path.open("r") as pf:
|
|
||||||
data_jsonl = pf.readlines()
|
|
||||||
data_data = [json.loads(v) for v in data_jsonl]
|
|
||||||
for p in data_data:
|
|
||||||
p["audio_path"] = data_manifest_path.parent / Path(p["audio_filepath"])
|
|
||||||
p["text"] = p["text"].strip()
|
|
||||||
yield p
|
|
||||||
|
|
||||||
|
|
||||||
def asr_manifest_writer(asr_manifest_path: Path, manifest_str_source):
|
|
||||||
with asr_manifest_path.open("w") as mf:
|
|
||||||
print(f"opening {asr_manifest_path} for writing manifest")
|
|
||||||
for mani_dict in manifest_str_source:
|
|
||||||
manifest = manifest_str(
|
|
||||||
mani_dict["audio_filepath"], mani_dict["duration"], mani_dict["text"]
|
|
||||||
)
|
|
||||||
mf.write(manifest)
|
|
||||||
|
|
||||||
|
|
||||||
def asr_test_writer(out_file_path: Path, source):
|
|
||||||
def dd_str(dd, idx):
|
|
||||||
path = dd["audio_filepath"]
|
|
||||||
# dur = dd["duration"]
|
|
||||||
# return f"SAY {idx}\nPAUSE 3\nPLAY {path}\nPAUSE 3\n\n"
|
|
||||||
return f"PAUSE 2\nPLAY {path}\nPAUSE 60\n\n"
|
|
||||||
|
|
||||||
res_file = out_file_path.with_suffix(".result.json")
|
|
||||||
with out_file_path.open("w") as of:
|
|
||||||
print(f"opening {out_file_path} for writing test")
|
|
||||||
results = []
|
|
||||||
idx = 0
|
|
||||||
for ui_dd in source:
|
|
||||||
results.append(ui_dd)
|
|
||||||
out_str = dd_str(ui_dd, idx)
|
|
||||||
of.write(out_str)
|
|
||||||
idx += 1
|
|
||||||
of.write("DO_HANGUP\n")
|
|
||||||
ExtendedPath(res_file).write_json(results)
|
|
||||||
|
|
||||||
|
|
||||||
def batch(iterable, n=1):
|
|
||||||
ls = len(iterable)
|
|
||||||
return [iterable[ndx : min(ndx + n, ls)] for ndx in range(0, ls, n)]
|
|
||||||
|
|
||||||
|
|
||||||
class ExtendedPath(type(Path())):
|
|
||||||
"""docstring for ExtendedPath."""
|
|
||||||
|
|
||||||
def read_json(self):
|
|
||||||
print(f"reading json from {self}")
|
|
||||||
with self.open("r") as jf:
|
|
||||||
return json.load(jf)
|
|
||||||
|
|
||||||
def write_json(self, data):
|
|
||||||
print(f"writing json to {self}")
|
|
||||||
self.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
with self.open("w") as jf:
|
|
||||||
return json.dump(data, jf, indent=2)
|
|
||||||
|
|
||||||
|
|
||||||
def get_mongo_conn(host="", port=27017, db="test", col="calls"):
|
|
||||||
mongo_host = host if host else os.environ.get("MONGO_HOST", "localhost")
|
|
||||||
mongo_uri = f"mongodb://{mongo_host}:{port}/"
|
|
||||||
return pymongo.MongoClient(mongo_uri)[db][col]
|
|
||||||
|
|
||||||
|
|
||||||
def strip_silence(sound):
|
|
||||||
from pydub.silence import detect_leading_silence
|
|
||||||
|
|
||||||
start_trim = detect_leading_silence(sound)
|
|
||||||
end_trim = detect_leading_silence(sound.reverse())
|
|
||||||
duration = len(sound)
|
|
||||||
return sound[start_trim : duration - end_trim]
|
|
||||||
|
|
||||||
|
|
||||||
def plot_seg(wav_plot_path, audio_path):
|
|
||||||
fig = plt.Figure()
|
|
||||||
ax = fig.add_subplot()
|
|
||||||
(y, sr) = librosa.load(audio_path)
|
|
||||||
librosa.display.waveplot(y=y, sr=sr, ax=ax)
|
|
||||||
with wav_plot_path.open("wb") as wav_plot_f:
|
|
||||||
fig.set_tight_layout(True)
|
|
||||||
fig.savefig(wav_plot_f, format="png", dpi=50)
|
|
||||||
|
|
||||||
|
|
||||||
def parallel_apply(fn, iterable, workers=8):
|
|
||||||
with ThreadPoolExecutor(max_workers=workers) as exe:
|
|
||||||
print(f"parallelly applying {fn}")
|
|
||||||
return [
|
|
||||||
res
|
|
||||||
for res in tqdm(
|
|
||||||
exe.map(fn, iterable), position=0, leave=True, total=len(iterable)
|
|
||||||
)
|
|
||||||
]
|
|
||||||
@@ -1 +0,0 @@

@@ -1,398 +0,0 @@
import json
|
|
||||||
import shutil
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import typer
|
|
||||||
|
|
||||||
from ..utils import (
|
|
||||||
ExtendedPath,
|
|
||||||
asr_manifest_reader,
|
|
||||||
asr_manifest_writer,
|
|
||||||
tscript_uuid_fname,
|
|
||||||
get_mongo_conn,
|
|
||||||
plot_seg,
|
|
||||||
)
|
|
||||||
|
|
||||||
app = typer.Typer()
|
|
||||||
|
|
||||||
|
|
||||||
def preprocess_datapoint(idx, rel_root, sample):
|
|
||||||
from pydub import AudioSegment
|
|
||||||
from nemo.collections.asr.metrics import word_error_rate
|
|
||||||
from jasper.client import transcribe_gen
|
|
||||||
|
|
||||||
try:
|
|
||||||
res = dict(sample)
|
|
||||||
res["real_idx"] = idx
|
|
||||||
audio_path = rel_root / Path(sample["audio_filepath"])
|
|
||||||
res["audio_path"] = str(audio_path)
|
|
||||||
res["utterance_id"] = audio_path.stem
|
|
||||||
transcriber_pretrained = transcribe_gen(asr_port=8044)
|
|
||||||
|
|
||||||
aud_seg = (
|
|
||||||
AudioSegment.from_file_using_temporary_files(audio_path)
|
|
||||||
.set_channels(1)
|
|
||||||
.set_sample_width(2)
|
|
||||||
.set_frame_rate(24000)
|
|
||||||
)
|
|
||||||
res["pretrained_asr"] = transcriber_pretrained(aud_seg.raw_data)
|
|
||||||
res["pretrained_wer"] = word_error_rate([res["text"]], [res["pretrained_asr"]])
|
|
||||||
wav_plot_path = (
|
|
||||||
rel_root / Path("wav_plots") / Path(audio_path.name).with_suffix(".png")
|
|
||||||
)
|
|
||||||
if not wav_plot_path.exists():
|
|
||||||
plot_seg(wav_plot_path, audio_path)
|
|
||||||
res["plot_path"] = str(wav_plot_path)
|
|
||||||
return res
|
|
||||||
except BaseException as e:
|
|
||||||
print(f'failed on {idx}: {sample["audio_filepath"]} with {e}')
|
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
|
||||||
def dump_ui(
|
|
||||||
data_name: str = typer.Option("dataname", show_default=True),
|
|
||||||
dataset_dir: Path = Path("./data/asr_data"),
|
|
||||||
dump_dir: Path = Path("./data/valiation_data"),
|
|
||||||
dump_fname: Path = typer.Option(Path("ui_dump.json"), show_default=True),
|
|
||||||
):
|
|
||||||
from io import BytesIO
|
|
||||||
from pydub import AudioSegment
|
|
||||||
from ..utils import ui_data_generator
|
|
||||||
|
|
||||||
data_manifest_path = dataset_dir / Path(data_name) / Path("manifest.json")
|
|
||||||
plot_dir = data_manifest_path.parent / Path("wav_plots")
|
|
||||||
plot_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
typer.echo(f"Using data manifest:{data_manifest_path}")
|
|
||||||
|
|
||||||
def asr_data_source_gen():
|
|
||||||
with data_manifest_path.open("r") as pf:
|
|
||||||
data_jsonl = pf.readlines()
|
|
||||||
for v in data_jsonl:
|
|
||||||
sample = json.loads(v)
|
|
||||||
rel_root = data_manifest_path.parent
|
|
||||||
res = dict(sample)
|
|
||||||
audio_path = rel_root / Path(sample["audio_filepath"])
|
|
||||||
audio_segment = (
|
|
||||||
AudioSegment.from_file_using_temporary_files(audio_path)
|
|
||||||
.set_channels(1)
|
|
||||||
.set_sample_width(2)
|
|
||||||
.set_frame_rate(24000)
|
|
||||||
)
|
|
||||||
wav_plot_path = (
|
|
||||||
rel_root
|
|
||||||
/ Path("wav_plots")
|
|
||||||
/ Path(audio_path.name).with_suffix(".png")
|
|
||||||
)
|
|
||||||
if not wav_plot_path.exists():
|
|
||||||
plot_seg(wav_plot_path, audio_path)
|
|
||||||
res["plot_path"] = str(wav_plot_path)
|
|
||||||
code_fb = BytesIO()
|
|
||||||
audio_segment.export(code_fb, format="wav")
|
|
||||||
wav_data = code_fb.getvalue()
|
|
||||||
duration = audio_segment.duration_seconds
|
|
||||||
asr_final = res["text"]
|
|
||||||
yield asr_final, duration, wav_data, "caller", audio_segment
|
|
||||||
|
|
||||||
dump_data, num_datapoints = ui_data_generator(
|
|
||||||
dataset_dir, data_name, asr_data_source_gen()
|
|
||||||
)
|
|
||||||
ui_dump_file = dataset_dir / Path("ui_dump.json")
|
|
||||||
ExtendedPath(ui_dump_file).write_json({"data": dump_data})
|
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
|
||||||
def sample_ui(
|
|
||||||
data_name: str = typer.Option("dataname", show_default=True),
|
|
||||||
dump_dir: Path = Path("./data/asr_data"),
|
|
||||||
dump_file: Path = Path("ui_dump.json"),
|
|
||||||
sample_count: int = typer.Option(80, show_default=True),
|
|
||||||
sample_file: Path = Path("sample_dump.json"),
|
|
||||||
):
|
|
||||||
import pandas as pd
|
|
||||||
|
|
||||||
processed_data_path = dump_dir / Path(data_name) / dump_file
|
|
||||||
sample_path = dump_dir / Path(data_name) / sample_file
|
|
||||||
processed_data = ExtendedPath(processed_data_path).read_json()
|
|
||||||
df = pd.DataFrame(processed_data["data"])
|
|
||||||
samples_per_caller = sample_count // len(df["caller"].unique())
|
|
||||||
caller_samples = pd.concat(
|
|
||||||
[g.sample(samples_per_caller) for (c, g) in df.groupby("caller")]
|
|
||||||
)
|
|
||||||
caller_samples = caller_samples.reset_index(drop=True)
|
|
||||||
caller_samples["real_idx"] = caller_samples.index
|
|
||||||
sample_data = caller_samples.to_dict("records")
|
|
||||||
processed_data["data"] = sample_data
|
|
||||||
typer.echo(f"sampling {sample_count} datapoints")
|
|
||||||
ExtendedPath(sample_path).write_json(processed_data)
|
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
|
||||||
def task_ui(
|
|
||||||
data_name: str = typer.Option("dataname", show_default=True),
|
|
||||||
dump_dir: Path = Path("./data/asr_data"),
|
|
||||||
dump_file: Path = Path("ui_dump.json"),
|
|
||||||
task_count: int = typer.Option(4, show_default=True),
|
|
||||||
task_file: str = "task_dump",
|
|
||||||
):
|
|
||||||
import pandas as pd
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
processed_data_path = dump_dir / Path(data_name) / dump_file
|
|
||||||
processed_data = ExtendedPath(processed_data_path).read_json()
|
|
||||||
df = pd.DataFrame(processed_data["data"]).sample(frac=1).reset_index(drop=True)
|
|
||||||
for t_idx, task_f in enumerate(np.array_split(df, task_count)):
|
|
||||||
task_f = task_f.reset_index(drop=True)
|
|
||||||
task_f["real_idx"] = task_f.index
|
|
||||||
task_data = task_f.to_dict("records")
|
|
||||||
processed_data["data"] = task_data
|
|
||||||
task_path = dump_dir / Path(data_name) / Path(task_file + f"-{t_idx}.json")
|
|
||||||
ExtendedPath(task_path).write_json(processed_data)
|
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
|
||||||
def dump_corrections(
|
|
||||||
task_uid: str,
|
|
||||||
data_name: str = typer.Option("dataname", show_default=True),
|
|
||||||
dump_dir: Path = Path("./data/asr_data"),
|
|
||||||
dump_fname: Path = Path("corrections.json"),
|
|
||||||
):
|
|
||||||
dump_path = dump_dir / Path(data_name) / dump_fname
|
|
||||||
col = get_mongo_conn(col="asr_validation")
|
|
||||||
task_id = [c for c in col.distinct("task_id") if c.rsplit("-", 1)[1] == task_uid][0]
|
|
||||||
corrections = list(col.find({"type": "correction"}, projection={"_id": False}))
|
|
||||||
cursor_obj = col.find(
|
|
||||||
{"type": "correction", "task_id": task_id}, projection={"_id": False}
|
|
||||||
)
|
|
||||||
corrections = [c for c in cursor_obj]
|
|
||||||
ExtendedPath(dump_path).write_json(corrections)
|
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
|
||||||
def caller_quality(
|
|
||||||
task_uid: str,
|
|
||||||
data_name: str = typer.Option("dataname", show_default=True),
|
|
||||||
dump_dir: Path = Path("./data/asr_data"),
|
|
||||||
dump_fname: Path = Path("ui_dump.json"),
|
|
||||||
correction_fname: Path = Path("corrections.json"),
|
|
||||||
):
|
|
||||||
import copy
|
|
||||||
import pandas as pd
|
|
||||||
|
|
||||||
dump_path = dump_dir / Path(data_name) / dump_fname
|
|
||||||
correction_path = dump_dir / Path(data_name) / correction_fname
|
|
||||||
dump_data = ExtendedPath(dump_path).read_json()
|
|
||||||
|
|
||||||
dump_map = {d["utterance_id"]: d for d in dump_data["data"]}
|
|
||||||
correction_data = ExtendedPath(correction_path).read_json()
|
|
||||||
|
|
||||||
def correction_dp(c):
|
|
||||||
dp = copy.deepcopy(dump_map[c["code"]])
|
|
||||||
dp["valid"] = c["value"]["status"] == "Correct"
|
|
||||||
return dp
|
|
||||||
|
|
||||||
corrected_dump = [
|
|
||||||
correction_dp(c)
|
|
||||||
for c in correction_data
|
|
||||||
if c["task_id"].rsplit("-", 1)[1] == task_uid
|
|
||||||
]
|
|
||||||
df = pd.DataFrame(corrected_dump)
|
|
||||||
print(f"Total samples: {len(df)}")
|
|
||||||
for (c, g) in df.groupby("caller"):
|
|
||||||
total = len(g)
|
|
||||||
valid = len(g[g["valid"] == True])
|
|
||||||
valid_rate = valid * 100 / total
|
|
||||||
print(f"Caller: {c} Valid%:{valid_rate:.2f} of {total} samples")
|
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
|
||||||
def fill_unannotated(
|
|
||||||
data_name: str = typer.Option("dataname", show_default=True),
|
|
||||||
dump_dir: Path = Path("./data/valiation_data"),
|
|
||||||
dump_file: Path = Path("ui_dump.json"),
|
|
||||||
corrections_file: Path = Path("corrections.json"),
|
|
||||||
):
|
|
||||||
processed_data_path = dump_dir / Path(data_name) / dump_file
|
|
||||||
corrections_path = dump_dir / Path(data_name) / corrections_file
|
|
||||||
processed_data = json.load(processed_data_path.open())
|
|
||||||
corrections = json.load(corrections_path.open())
|
|
||||||
annotated_codes = {c["code"] for c in corrections}
|
|
||||||
all_codes = {c["gold_chars"] for c in processed_data}
|
|
||||||
unann_codes = all_codes - annotated_codes
|
|
||||||
mongo_conn = get_mongo_conn(col="asr_validation")
|
|
||||||
for c in unann_codes:
|
|
||||||
mongo_conn.find_one_and_update(
|
|
||||||
{"type": "correction", "code": c},
|
|
||||||
{"$set": {"value": {"status": "Inaudible", "correction": ""}}},
|
|
||||||
upsert=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
|
||||||
def split_extract(
|
|
||||||
data_name: str = typer.Option("dataname", show_default=True),
|
|
||||||
# dest_data_name: str = typer.Option("call_aldata_namephanum_date", show_default=True),
|
|
||||||
# dump_dir: Path = Path("./data/valiation_data"),
|
|
||||||
dump_dir: Path = Path("./data/asr_data"),
|
|
||||||
dump_file: Path = Path("ui_dump.json"),
|
|
||||||
manifest_file: Path = Path("manifest.json"),
|
|
||||||
corrections_file: str = typer.Option("corrections.json", show_default=True),
|
|
||||||
conv_data_path: Path = typer.Option(
|
|
||||||
Path("./data/conv_data.json"), show_default=True
|
|
||||||
),
|
|
||||||
extraction_type: str = "all",
|
|
||||||
):
|
|
||||||
import shutil
|
|
||||||
|
|
||||||
data_manifest_path = dump_dir / Path(data_name) / manifest_file
|
|
||||||
conv_data = ExtendedPath(conv_data_path).read_json()
|
|
||||||
|
|
||||||
def extract_data_of_type(extraction_key):
|
|
||||||
extraction_vals = conv_data[extraction_key]
|
|
||||||
dest_data_name = data_name + "_" + extraction_key.lower()
|
|
||||||
|
|
||||||
manifest_gen = asr_manifest_reader(data_manifest_path)
|
|
||||||
dest_data_dir = dump_dir / Path(dest_data_name)
|
|
||||||
dest_data_dir.mkdir(exist_ok=True, parents=True)
|
|
||||||
(dest_data_dir / Path("wav")).mkdir(exist_ok=True, parents=True)
|
|
||||||
dest_manifest_path = dest_data_dir / manifest_file
|
|
||||||
dest_ui_path = dest_data_dir / dump_file
|
|
||||||
|
|
||||||
def extract_manifest(mg):
|
|
||||||
for m in mg:
|
|
||||||
if m["text"] in extraction_vals:
|
|
||||||
shutil.copy(
|
|
||||||
m["audio_path"], dest_data_dir / Path(m["audio_filepath"])
|
|
||||||
)
|
|
||||||
yield m
|
|
||||||
|
|
||||||
asr_manifest_writer(dest_manifest_path, extract_manifest(manifest_gen))
|
|
||||||
|
|
||||||
ui_data_path = dump_dir / Path(data_name) / dump_file
|
|
||||||
orig_ui_data = ExtendedPath(ui_data_path).read_json()
|
|
||||||
ui_data = orig_ui_data["data"]
|
|
||||||
file_ui_map = {Path(u["audio_filepath"]).stem: u for u in ui_data}
|
|
||||||
extracted_ui_data = list(
|
|
||||||
filter(lambda u: u["text"] in extraction_vals, ui_data)
|
|
||||||
)
|
|
||||||
final_data = []
|
|
||||||
for i, d in enumerate(extracted_ui_data):
|
|
||||||
d["real_idx"] = i
|
|
||||||
final_data.append(d)
|
|
||||||
orig_ui_data["data"] = final_data
|
|
||||||
ExtendedPath(dest_ui_path).write_json(orig_ui_data)
|
|
||||||
|
|
||||||
if corrections_file:
|
|
||||||
dest_correction_path = dest_data_dir / corrections_file
|
|
||||||
corrections_path = dump_dir / Path(data_name) / corrections_file
|
|
||||||
corrections = json.load(corrections_path.open())
|
|
||||||
extracted_corrections = list(
|
|
||||||
filter(
|
|
||||||
lambda c: c["code"] in file_ui_map
|
|
||||||
and file_ui_map[c["code"]]["text"] in extraction_vals,
|
|
||||||
corrections,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
ExtendedPath(dest_correction_path).write_json(extracted_corrections)
|
|
||||||
|
|
||||||
if extraction_type.value == "all":
|
|
||||||
for ext_key in conv_data.keys():
|
|
||||||
extract_data_of_type(ext_key)
|
|
||||||
else:
|
|
||||||
extract_data_of_type(extraction_type.value)
|
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
|
||||||
def update_corrections(
|
|
||||||
data_name: str = typer.Option("dataname", show_default=True),
|
|
||||||
dump_dir: Path = Path("./data/asr_data"),
|
|
||||||
manifest_file: Path = Path("manifest.json"),
|
|
||||||
corrections_file: Path = Path("corrections.json"),
|
|
||||||
ui_dump_file: Path = Path("ui_dump.json"),
|
|
||||||
skip_incorrect: bool = typer.Option(True, show_default=True),
|
|
||||||
):
|
|
||||||
data_manifest_path = dump_dir / Path(data_name) / manifest_file
|
|
||||||
corrections_path = dump_dir / Path(data_name) / corrections_file
|
|
||||||
ui_dump_path = dump_dir / Path(data_name) / ui_dump_file
|
|
||||||
|
|
||||||
def correct_manifest(ui_dump_path, corrections_path):
|
|
||||||
corrections = ExtendedPath(corrections_path).read_json()
|
|
||||||
ui_data = ExtendedPath(ui_dump_path).read_json()["data"]
|
|
||||||
correct_set = {
|
|
||||||
c["code"] for c in corrections if c["value"]["status"] == "Correct"
|
|
||||||
}
|
|
||||||
# incorrect_set = {c["code"] for c in corrections if c["value"]["status"] == "Inaudible"}
|
|
||||||
correction_map = {
|
|
||||||
c["code"]: c["value"]["correction"]
|
|
||||||
for c in corrections
|
|
||||||
if c["value"]["status"] == "Incorrect"
|
|
||||||
}
|
|
||||||
# for d in manifest_data_gen:
|
|
||||||
# if d["chars"] in incorrect_set:
|
|
||||||
# d["audio_path"].unlink()
|
|
||||||
# renamed_set = set()
|
|
||||||
for d in ui_data:
|
|
||||||
if d["utterance_id"] in correct_set:
|
|
||||||
yield {
|
|
||||||
"audio_filepath": d["audio_filepath"],
|
|
||||||
"duration": d["duration"],
|
|
||||||
"text": d["text"],
|
|
||||||
}
|
|
||||||
elif d["utterance_id"] in correction_map:
|
|
||||||
correct_text = correction_map[d["utterance_id"]]
|
|
||||||
if skip_incorrect:
|
|
||||||
print(
|
|
||||||
f'skipping incorrect {d["audio_path"]} corrected to {correct_text}'
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
orig_audio_path = Path(d["audio_path"])
|
|
||||||
new_name = str(
|
|
||||||
Path(tscript_uuid_fname(correct_text)).with_suffix(".wav")
|
|
||||||
)
|
|
||||||
new_audio_path = orig_audio_path.with_name(new_name)
|
|
||||||
orig_audio_path.replace(new_audio_path)
|
|
||||||
new_filepath = str(Path(d["audio_filepath"]).with_name(new_name))
|
|
||||||
yield {
|
|
||||||
"audio_filepath": new_filepath,
|
|
||||||
"duration": d["duration"],
|
|
||||||
"text": correct_text,
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
orig_audio_path = Path(d["audio_path"])
|
|
||||||
# don't delete if another correction points to an old file
|
|
||||||
# if d["text"] not in renamed_set:
|
|
||||||
orig_audio_path.unlink()
|
|
||||||
# else:
|
|
||||||
# print(f'skipping deletion of correction:{d["text"]}')
|
|
||||||
|
|
||||||
typer.echo(f"Using data manifest:{data_manifest_path}")
|
|
||||||
dataset_dir = data_manifest_path.parent
|
|
||||||
dataset_name = dataset_dir.name
|
|
||||||
backup_dir = dataset_dir.with_name(dataset_name + ".bkp")
|
|
||||||
if not backup_dir.exists():
|
|
||||||
typer.echo(f"backing up to :{backup_dir}")
|
|
||||||
shutil.copytree(str(dataset_dir), str(backup_dir))
|
|
||||||
# manifest_gen = asr_manifest_reader(data_manifest_path)
|
|
||||||
corrected_manifest = correct_manifest(ui_dump_path, corrections_path)
|
|
||||||
new_data_manifest_path = data_manifest_path.with_name("manifest.new")
|
|
||||||
asr_manifest_writer(new_data_manifest_path, corrected_manifest)
|
|
||||||
new_data_manifest_path.replace(data_manifest_path)
|
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
|
||||||
def clear_mongo_corrections():
|
|
||||||
delete = typer.confirm("are you sure you want to clear mongo collection it?")
|
|
||||||
if delete:
|
|
||||||
col = get_mongo_conn(col="asr_validation")
|
|
||||||
col.delete_many({"type": "correction"})
|
|
||||||
col.delete_many({"type": "current_cursor"})
|
|
||||||
typer.echo("deleted mongo collection.")
|
|
||||||
return
|
|
||||||
typer.echo("Aborted")
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
app()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -1,57 +0,0 @@
import os
|
|
||||||
import logging
|
|
||||||
|
|
||||||
import rpyc
|
|
||||||
from rpyc.utils.server import ThreadedServer
|
|
||||||
|
|
||||||
from .asr import JasperASR
|
|
||||||
from .utils import arg_parser
|
|
||||||
|
|
||||||
|
|
||||||
class ASRService(rpyc.Service):
|
|
||||||
def __init__(self, asr_recognizer):
|
|
||||||
self.asr = asr_recognizer
|
|
||||||
|
|
||||||
def on_connect(self, conn):
|
|
||||||
# code that runs when a connection is created
|
|
||||||
# (to init the service, if needed)
|
|
||||||
pass
|
|
||||||
|
|
||||||
def on_disconnect(self, conn):
|
|
||||||
# code that runs after the connection has already closed
|
|
||||||
# (to finalize the service, if needed)
|
|
||||||
pass
|
|
||||||
|
|
||||||
def exposed_transcribe(self, utterance: bytes): # this is an exposed method
|
|
||||||
speech_audio = self.asr.transcribe(utterance)
|
|
||||||
return speech_audio
|
|
||||||
|
|
||||||
def exposed_transcribe_cb(
|
|
||||||
self, utterance: bytes, respond
|
|
||||||
): # this is an exposed method
|
|
||||||
speech_audio = self.asr.transcribe(utterance)
|
|
||||||
respond(speech_audio)
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
parser = arg_parser('jasper_transcribe')
|
|
||||||
parser.description = 'jasper asr rpyc server'
|
|
||||||
parser.add_argument(
|
|
||||||
"--port", type=int, default=int(os.environ.get("ASR_RPYC_PORT", "8044")), help="port to listen on"
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
args_dict = vars(args)
|
|
||||||
port = args_dict.pop("port")
|
|
||||||
jasper_asr = JasperASR(**args_dict)
|
|
||||||
service = ASRService(jasper_asr)
|
|
||||||
logging.basicConfig(
|
|
||||||
level=logging.INFO,
|
|
||||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
|
||||||
)
|
|
||||||
logging.info("starting asr server...")
|
|
||||||
t = ThreadedServer(service, port=port)
|
|
||||||
t.start()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -1 +0,0 @@

@@ -1,22 +0,0 @@
from pathlib import Path
|
|
||||||
from .asr import JasperASR
|
|
||||||
from .utils import arg_parser
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
parser = arg_parser('jasper_transcribe')
|
|
||||||
parser.description = 'transcribe audio file to text'
|
|
||||||
parser.add_argument(
|
|
||||||
"audio_file",
|
|
||||||
type=Path,
|
|
||||||
help="audio file(16khz 1channel int16 wav) to transcribe",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--greedy", type=bool, default=False, help="enables greedy decoding"
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
args_dict = vars(args)
|
|
||||||
audio_file = args_dict.pop("audio_file")
|
|
||||||
greedy = args_dict.pop("greedy")
|
|
||||||
jasper_asr = JasperASR(**args_dict)
|
|
||||||
jasper_asr.transcribe_file(audio_file, greedy)
|
|
||||||
@@ -1,40 +0,0 @@
import os
|
|
||||||
import argparse
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
MODEL_YAML = os.environ.get("JASPER_MODEL_CONFIG", "/models/jasper/jasper10x5dr.yaml")
|
|
||||||
CHECKPOINT_ENCODER = os.environ.get(
|
|
||||||
"JASPER_ENCODER_CHECKPOINT", "/models/jasper/JasperEncoder-STEP-265520.pt"
|
|
||||||
)
|
|
||||||
CHECKPOINT_DECODER = os.environ.get(
|
|
||||||
"JASPER_DECODER_CHECKPOINT", "/models/jasper/JasperDecoderForCTC-STEP-265520.pt"
|
|
||||||
)
|
|
||||||
KEN_LM = os.environ.get("JASPER_KEN_LM", "/models/jasper/kenlm.pt")
|
|
||||||
|
|
||||||
|
|
||||||
def arg_parser(prog):
|
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
prog=prog, description=f"convert speech to text"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--model_yaml",
|
|
||||||
type=Path,
|
|
||||||
default=Path(MODEL_YAML),
|
|
||||||
help="model config yaml file",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--encoder_checkpoint",
|
|
||||||
type=Path,
|
|
||||||
default=Path(CHECKPOINT_ENCODER),
|
|
||||||
help="encoder checkpoint weights file",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--decoder_checkpoint",
|
|
||||||
type=Path,
|
|
||||||
default=Path(CHECKPOINT_DECODER),
|
|
||||||
help="decoder checkpoint weights file",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--language_model", type=Path, default=None, help="kenlm language model file"
|
|
||||||
)
|
|
||||||
return parser
|
|
||||||
@@ -0,0 +1,23 @@
+import typer
+from ..utils import app as utils_app
+from .data import app as data_app
+from ..ui import app as ui_app
+from .train import app as train_app
+from .eval import app as eval_app
+from .serve import app as serve_app
+
+app = typer.Typer()
+app.add_typer(data_app, name="data")
+app.add_typer(ui_app, name="ui")
+app.add_typer(train_app, name="train")
+app.add_typer(eval_app, name="eval")
+app.add_typer(serve_app, name="serve")
+app.add_typer(utils_app, name='utils')
+
+
+def main():
+    app()
+
+
+if __name__ == "__main__":
+    main()

@@ -0,0 +1,339 @@
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
# from sklearn.model_selection import train_test_split
|
||||||
|
from plume.utils import (
|
||||||
|
asr_manifest_reader,
|
||||||
|
asr_manifest_writer,
|
||||||
|
ExtendedPath,
|
||||||
|
duration_str,
|
||||||
|
generate_filter_map,
|
||||||
|
get_mongo_conn,
|
||||||
|
tscript_uuid_fname,
|
||||||
|
lazy_callable
|
||||||
|
)
|
||||||
|
from typing import List
|
||||||
|
from itertools import chain
|
||||||
|
import shutil
|
||||||
|
import typer
|
||||||
|
import soundfile
|
||||||
|
|
||||||
|
from ...models.wav2vec2.data import app as wav2vec2_app
|
||||||
|
from .generate import app as generate_app
|
||||||
|
|
||||||
|
train_test_split = lazy_callable('sklearn.model_selection.train_test_split')
|
||||||
|
|
||||||
|
app = typer.Typer()
|
||||||
|
app.add_typer(generate_app, name="generate")
|
||||||
|
app.add_typer(wav2vec2_app, name="wav2vec2")
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def fix_path(dataset_path: Path, force: bool = False):
|
||||||
|
manifest_path = dataset_path / Path("manifest.json")
|
||||||
|
real_manifest_path = dataset_path / Path("abs_manifest.json")
|
||||||
|
|
||||||
|
def fix_real_path():
|
||||||
|
for i in asr_manifest_reader(manifest_path):
|
||||||
|
i["audio_filepath"] = str(
|
||||||
|
(dataset_path / Path(i["audio_filepath"])).absolute()
|
||||||
|
)
|
||||||
|
yield i
|
||||||
|
|
||||||
|
def fix_rel_path():
|
||||||
|
for i in asr_manifest_reader(real_manifest_path):
|
||||||
|
i["audio_filepath"] = str(
|
||||||
|
Path(i["audio_filepath"]).relative_to(dataset_path)
|
||||||
|
)
|
||||||
|
yield i
|
||||||
|
|
||||||
|
if not manifest_path.exists() and not real_manifest_path.exists():
|
||||||
|
typer.echo("Invalid dataset directory")
|
||||||
|
if not real_manifest_path.exists() or force:
|
||||||
|
asr_manifest_writer(real_manifest_path, fix_real_path())
|
||||||
|
if not manifest_path.exists():
|
||||||
|
asr_manifest_writer(manifest_path, fix_rel_path())
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def augment(src_dataset_paths: List[Path], dest_dataset_path: Path):
|
||||||
|
reader_list = []
|
||||||
|
abs_manifest_path = Path("abs_manifest.json")
|
||||||
|
for dataset_path in src_dataset_paths:
|
||||||
|
manifest_path = dataset_path / abs_manifest_path
|
||||||
|
reader_list.append(asr_manifest_reader(manifest_path))
|
||||||
|
dest_dataset_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
dest_manifest_path = dest_dataset_path / abs_manifest_path
|
||||||
|
asr_manifest_writer(dest_manifest_path, chain(*reader_list))
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def split(dataset_path: Path, test_size: float = 0.03):
|
||||||
|
manifest_path = dataset_path / Path("abs_manifest.json")
|
||||||
|
if not manifest_path.exists():
|
||||||
|
fix_path(dataset_path)
|
||||||
|
asr_data = list(asr_manifest_reader(manifest_path))
|
||||||
|
train_pnr, test_pnr = train_test_split(asr_data, test_size=test_size)
|
||||||
|
asr_manifest_writer(manifest_path.with_name("train_manifest.json"), train_pnr)
|
||||||
|
asr_manifest_writer(manifest_path.with_name("test_manifest.json"), test_pnr)
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def validate(dataset_path: Path):
|
||||||
|
from natural.date import compress
|
||||||
|
from datetime import timedelta
|
||||||
|
|
||||||
|
for mf_type in ["train_manifest.json", "test_manifest.json"]:
|
||||||
|
data_file = dataset_path / Path(mf_type)
|
||||||
|
print(f"validating {data_file}.")
|
||||||
|
with Path(data_file).open("r") as pf:
|
||||||
|
pnr_jsonl = pf.readlines()
|
||||||
|
duration = 0
|
||||||
|
for (i, s) in enumerate(pnr_jsonl):
|
||||||
|
try:
|
||||||
|
d = json.loads(s)
|
||||||
|
duration += d["duration"]
|
||||||
|
audio_file = data_file.parent / Path(d["audio_filepath"])
|
||||||
|
if not audio_file.exists():
|
||||||
|
raise OSError(f"File {audio_file} not found")
|
||||||
|
except BaseException as e:
|
||||||
|
print(f'failed on {i} with "{e}"')
|
||||||
|
duration_str = compress(timedelta(seconds=duration), pad=" ")
|
||||||
|
print(
|
||||||
|
f"no errors found. seems like a valid {mf_type}. contains {duration_str} of audio"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def filter(src_dataset_path: Path, dest_dataset_path: Path, kind: str = "skip_dur"):
|
||||||
|
dest_manifest = dest_dataset_path / Path("manifest.json")
|
||||||
|
data_file = src_dataset_path / Path("manifest.json")
|
||||||
|
dest_wav_dir = dest_dataset_path / Path("wavs")
|
||||||
|
dest_wav_dir.mkdir(exist_ok=True, parents=True)
|
||||||
|
filter_kind_map = generate_filter_map(
|
||||||
|
src_dataset_path, dest_dataset_path, data_file
|
||||||
|
)
|
||||||
|
|
||||||
|
selected_filter = filter_kind_map.get(kind, None)
|
||||||
|
if selected_filter:
|
||||||
|
asr_manifest_writer(dest_manifest, selected_filter())
|
||||||
|
else:
|
||||||
|
typer.echo(f"filter kind - {kind} not implemented")
|
||||||
|
typer.echo(f"select one of {', '.join(filter_kind_map.keys())}")
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def info(dataset_path: Path):
|
||||||
|
for k in ["", "abs_", "train_", "test_"]:
|
||||||
|
mf_wav_duration = (
|
||||||
|
real_duration
|
||||||
|
) = max_duration = empty_duration = empty_count = total_count = 0
|
||||||
|
data_file = dataset_path / Path(f"{k}manifest.json")
|
||||||
|
if data_file.exists():
|
||||||
|
print(f"stats on {data_file}")
|
||||||
|
for s in ExtendedPath(data_file).read_jsonl():
|
||||||
|
total_count += 1
|
||||||
|
mf_wav_duration += s["duration"]
|
||||||
|
if s["text"] == "":
|
||||||
|
empty_count += 1
|
||||||
|
empty_duration += s["duration"]
|
||||||
|
wav_path = str(dataset_path / Path(s["audio_filepath"]))
|
||||||
|
if max_duration < soundfile.info(wav_path).duration:
|
||||||
|
max_duration = soundfile.info(wav_path).duration
|
||||||
|
real_duration += soundfile.info(wav_path).duration
|
||||||
|
|
||||||
|
# frame_count = soundfile.info(audio_fname).frames
|
||||||
|
print(f"max audio duration : {duration_str(max_duration)}")
|
||||||
|
print(f"total audio duration : {duration_str(mf_wav_duration)}")
|
||||||
|
print(f"total real audio duration : {duration_str(real_duration)}")
|
||||||
|
print(
|
||||||
|
f"total content duration : {duration_str(mf_wav_duration-empty_duration)}"
|
||||||
|
)
|
||||||
|
print(f"total empty duration : {duration_str(empty_duration)}")
|
||||||
|
print(
|
||||||
|
f"total empty samples : {empty_count}/{total_count} ({empty_count*100/total_count:.2f}%)"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def audio_duration(dataset_path: Path):
|
||||||
|
wav_duration = 0
|
||||||
|
for audio_rel_fname in dataset_path.absolute().glob("**/*.wav"):
|
||||||
|
audio_fname = str(audio_rel_fname)
|
||||||
|
wav_duration += soundfile.info(audio_fname).duration
|
||||||
|
typer.echo(f"duration of wav files @ {dataset_path}: {duration_str(wav_duration)}")
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def migrate(src_path: Path, dest_path: Path):
|
||||||
|
shutil.copytree(str(src_path), str(dest_path))
|
||||||
|
wav_dir = dest_path / Path("wavs")
|
||||||
|
wav_dir.mkdir(exist_ok=True, parents=True)
|
||||||
|
abs_manifest_path = ExtendedPath(dest_path / Path("abs_manifest.json"))
|
||||||
|
backup_abs_manifest_path = abs_manifest_path.with_suffix(".json.orig")
|
||||||
|
shutil.copy(abs_manifest_path, backup_abs_manifest_path)
|
||||||
|
manifest_data = list(abs_manifest_path.read_jsonl())
|
||||||
|
for md in manifest_data:
|
||||||
|
orig_path = Path(md["audio_filepath"])
|
||||||
|
new_path = wav_dir / Path(orig_path.name)
|
||||||
|
shutil.copy(orig_path, new_path)
|
||||||
|
md["audio_filepath"] = str(new_path)
|
||||||
|
abs_manifest_path.write_jsonl(manifest_data)
|
||||||
|
fix_path(dest_path)
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def task_split(
|
||||||
|
data_dir: Path,
|
||||||
|
dump_file: Path = Path("ui_dump.json"),
|
||||||
|
task_count: int = typer.Option(2, show_default=True),
|
||||||
|
task_file: str = "task_dump",
|
||||||
|
sort: bool = True,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
split ui_dump.json to `task_count` tasks
|
||||||
|
"""
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
processed_data_path = data_dir / dump_file
|
||||||
|
processed_data = ExtendedPath(processed_data_path).read_json()
|
||||||
|
df = pd.DataFrame(processed_data["data"]).sample(frac=1).reset_index(drop=True)
|
||||||
|
for t_idx, task_f in enumerate(np.array_split(df, task_count)):
|
||||||
|
task_f = task_f.reset_index(drop=True)
|
||||||
|
task_f["real_idx"] = task_f.index
|
||||||
|
task_data = task_f.to_dict("records")
|
||||||
|
if sort:
|
||||||
|
task_data = sorted(task_data, key=lambda x: x["asr_wer"], reverse=True)
|
||||||
|
processed_data["data"] = task_data
|
||||||
|
task_path = data_dir / Path(task_file + f"-{t_idx}.json")
|
||||||
|
ExtendedPath(task_path).write_json(processed_data)
|
||||||
|
|
||||||
|
|
||||||
|
def get_corrections(task_uid):
|
||||||
|
col = get_mongo_conn(col="asr_validation")
|
||||||
|
task_id = [
|
||||||
|
c
|
||||||
|
for c in col.distinct("task_id")
|
||||||
|
if c.rsplit("-", 1)[1] == task_uid or c == task_uid
|
||||||
|
][0]
|
||||||
|
corrections = list(col.find({"type": "correction"}, projection={"_id": False}))
|
||||||
|
cursor_obj = col.find(
|
||||||
|
{"type": "correction", "task_id": task_id}, projection={"_id": False}
|
||||||
|
)
|
||||||
|
corrections = [c for c in cursor_obj]
|
||||||
|
return corrections
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def dump_task_corrections(data_dir: Path, task_uid: str):
|
||||||
|
dump_fname: Path = Path(f"corrections-{task_uid}.json")
|
||||||
|
dump_path = data_dir / dump_fname
|
||||||
|
corrections = get_corrections(task_uid)
|
||||||
|
ExtendedPath(dump_path).write_json(corrections)
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def dump_all_corrections(data_dir: Path):
|
||||||
|
for task_lcks in data_dir.glob('task-*.lck'):
|
||||||
|
task_uid = task_lcks.stem.replace('task-', '')
|
||||||
|
dump_task_corrections(data_dir, task_uid)
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
def update_corrections(
    data_dir: Path,
    skip_incorrect: bool = typer.Option(
        False, show_default=True, help="treats incorrect as invalid"
    ),
    skip_inaudible: bool = typer.Option(
        False, show_default=True, help="include invalid as blank target"
    ),
):
    """
    Applies the corrections-*.json files to the dataset,
    backing up the original dataset first.
    """
    manifest_file: Path = Path("manifest.json")
    renames_file: Path = Path("rename_map.json")
    ui_dump_file: Path = Path("ui_dump.json")
    data_manifest_path = data_dir / manifest_file
    renames_path = data_dir / renames_file

    def correct_ui_dump(data_dir, rename_result):
        ui_dump_path = data_dir / ui_dump_file
        # corrections_path = data_dir / Path("corrections.json")
        corrections = [
            t
            for p in data_dir.glob("corrections-*.json")
            for t in ExtendedPath(p).read_json()
        ]
        ui_data = ExtendedPath(ui_dump_path).read_json()["data"]
        correct_set = {
            c["code"] for c in corrections if c["value"]["status"] == "Correct"
        }
        correction_map = {
            c["code"]: c["value"]["correction"]
            for c in corrections
            if c["value"]["status"] == "Incorrect"
        }
        for d in ui_data:
            orig_audio_path = (data_dir / Path(d["audio_path"])).absolute()
            if d["utterance_id"] in correct_set:
                d["corrected_from"] = d["text"]
                yield d
            elif d["utterance_id"] in correction_map:
                correct_text = correction_map[d["utterance_id"]]
                if skip_incorrect:
                    ap = d["audio_path"]
                    print(f"skipping incorrect {ap} corrected to {correct_text}")
                    orig_audio_path.unlink()
                else:
                    new_fname = tscript_uuid_fname(correct_text)
                    rename_result[new_fname] = {
                        "orig_text": d["text"],
                        "correct_text": correct_text,
                        "orig_id": d["utterance_id"],
                    }
                    new_name = str(Path(new_fname).with_suffix(".wav"))
                    new_audio_path = orig_audio_path.with_name(new_name)
                    orig_audio_path.replace(new_audio_path)
                    new_filepath = str(Path(d["audio_path"]).with_name(new_name))
                    d["corrected_from"] = d["text"]
                    d["text"] = correct_text
                    d["audio_path"] = new_filepath
                    yield d
            else:
                if skip_inaudible:
                    orig_audio_path.unlink()
                else:
                    d["corrected_from"] = d["text"]
                    d["text"] = ""
                    yield d

    dataset_dir = data_manifest_path.parent
    dataset_name = dataset_dir.name
    backup_dir = dataset_dir.with_name(dataset_name + ".bkp")
    if not backup_dir.exists():
        typer.echo(f"backing up to {backup_dir}")
        shutil.copytree(str(dataset_dir), str(backup_dir))
    renames = {}
    corrected_ui_dump = list(correct_ui_dump(data_dir, renames))
    ExtendedPath(data_dir / ui_dump_file).write_json({"data": corrected_ui_dump})
    corrected_manifest = (
        {
            "audio_filepath": d["audio_path"],
            "duration": d["duration"],
            "text": d["text"],
        }
        for d in corrected_ui_dump
    )
    asr_manifest_writer(data_manifest_path, corrected_manifest)
    ExtendedPath(renames_path).write_json(renames)


def main():
    app()


if __name__ == "__main__":
    main()
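# Editor's sketch (not part of the commit): update_corrections reads each
# corrections-*.json file as a list of annotation entries. A plausible minimal
# shape, inferred only from the keys accessed above (c["code"],
# c["value"]["status"], c["value"]["correction"]), would be:
#
# [
#   {"code": "<utterance_id>", "value": {"status": "Correct", "correction": ""}},
#   {"code": "<utterance_id>", "value": {"status": "Incorrect", "correction": "one two three"}}
# ]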

@@ -0,0 +1,12 @@
from pathlib import Path

import typer
from ...utils.tts import GoogleTTS

app = typer.Typer()


@app.command()
def tts_dataset(dest_path: Path):
    tts = GoogleTTS()
    pass

@@ -0,0 +1,5 @@
import typer
from ..models.wav2vec2.eval import app as wav2vec2_app

app = typer.Typer()
app.add_typer(wav2vec2_app, name="wav2vec2")

@@ -0,0 +1,7 @@
import typer
from ..models.wav2vec2.serve import app as wav2vec2_app
from ..models.jasper.serve import app as jasper_app

app = typer.Typer()
app.add_typer(wav2vec2_app, name="wav2vec2")
app.add_typer(jasper_app, name="jasper")

@@ -0,0 +1,5 @@
import typer
from ..models.wav2vec2.train import app as train_app

app = typer.Typer()
app.add_typer(train_app, name="wav2vec2")

@@ -0,0 +1 @@
# from . import jasper, wav2vec2, matchboxnet

@@ -0,0 +1,24 @@
from pathlib import Path
import typer

app = typer.Typer()


@app.command()
def set_root(dataset_path: Path, root_path: Path):
    pass
    # for dataset_kind in ["train", "valid"]:
    #     data_file = dataset_path / Path(dataset_kind).with_suffix(".tsv")
    #     with data_file.open("r") as df:
    #         lines = df.readlines()
    #     with data_file.open("w") as df:
    #         lines[0] = str(root_path) + "\n"
    #         df.writelines(lines)


def main():
    app()


if __name__ == "__main__":
    main()

@@ -45,7 +45,7 @@ def parse_args():
         eval_freq=100,
         load_dir="./train/models/jasper/",
         warmup_steps=3,
-        exp_name="jasper-speller",
+        exp_name="jasper",
     )

     # Overwrite default args

@@ -0,0 +1,52 @@
import os
import logging
from pathlib import Path

from rpyc.utils.server import ThreadedServer
import typer

# from .asr import JasperASR
from ...utils.serve import ASRService
from plume.utils import lazy_callable

JasperASR = lazy_callable('plume.models.jasper.asr.JasperASR')

app = typer.Typer()


@app.command()
def rpyc(
    encoder_path: Path = "/path/to/encoder.pt",
    decoder_path: Path = "/path/to/decoder.pt",
    model_yaml_path: Path = "/path/to/model.yaml",
    port: int = int(os.environ.get("ASR_RPYC_PORT", "8044")),
):
    for p in [encoder_path, decoder_path, model_yaml_path]:
        if not p.exists():
            logging.info(f"{p} doesn't exist")
            return
    asr = JasperASR(str(model_yaml_path), str(encoder_path), str(decoder_path))
    service = ASRService(asr)
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )
    logging.info("starting asr server...")
    t = ThreadedServer(service, port=port)
    t.start()


@app.command()
def rpyc_dir(model_dir: Path, port: int = int(os.environ.get("ASR_RPYC_PORT", "8044"))):
    encoder_path = model_dir / Path("encoder.pt")
    decoder_path = model_dir / Path("decoder.pt")
    model_yaml_path = model_dir / Path("model.yaml")
    rpyc(encoder_path, decoder_path, model_yaml_path, port)


def main():
    app()


if __name__ == "__main__":
    main()
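# Editor's sketch (not part of the commit): connecting to the rpyc ASR server
# started above from a client. The exposed method name on ASRService is an
# assumption here ("transcribe"); check plume/utils/serve.py for the actual
# interface.
#
# import rpyc
# conn = rpyc.connect("localhost", 8044)
# wav_bytes = open("sample.wav", "rb").read()
# print(conn.root.transcribe(wav_bytes))  # assumed exposed method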

@@ -0,0 +1,204 @@
from io import BytesIO
import warnings
import itertools as it

import torch
import soundfile as sf
import torch.nn.functional as F

try:
    from fairseq import utils
    from fairseq.models import BaseFairseqModel
    from fairseq.data import Dictionary
    from fairseq.models.wav2vec.wav2vec2_asr import base_architecture, Wav2VecEncoder
except ModuleNotFoundError:
    warnings.warn("Install fairseq")
try:
    from wav2letter.decoder import CriterionType
    from wav2letter.criterion import CpuViterbiPath, get_data_ptr_as_bytes
except ModuleNotFoundError:
    warnings.warn("Install wav2letter")


class Wav2VecCtc(BaseFairseqModel):
    def __init__(self, w2v_encoder, args):
        super().__init__()
        self.w2v_encoder = w2v_encoder
        self.args = args

    def upgrade_state_dict_named(self, state_dict, name):
        super().upgrade_state_dict_named(state_dict, name)
        return state_dict

    @classmethod
    def build_model(cls, args, target_dict):
        """Build a new model instance."""
        base_architecture(args)
        w2v_encoder = Wav2VecEncoder(args, target_dict)
        return cls(w2v_encoder, args)

    def get_normalized_probs(self, net_output, log_probs):
        """Get normalized probabilities (or log probs) from a net's output."""
        logits = net_output["encoder_out"]
        if log_probs:
            return utils.log_softmax(logits.float(), dim=-1)
        else:
            return utils.softmax(logits.float(), dim=-1)

    def forward(self, **kwargs):
        x = self.w2v_encoder(**kwargs)
        return x


class W2lDecoder(object):
    def __init__(self, tgt_dict):
        self.tgt_dict = tgt_dict
        self.vocab_size = len(tgt_dict)
        self.nbest = 1

        self.criterion_type = CriterionType.CTC
        self.blank = (
            tgt_dict.index("<ctc_blank>")
            if "<ctc_blank>" in tgt_dict.indices
            else tgt_dict.bos()
        )
        self.asg_transitions = None

    def generate(self, model, sample, **unused):
        """Generate a batch of inferences."""
        # model.forward normally channels prev_output_tokens into the decoder
        # separately, but SequenceGenerator directly calls model.encoder
        encoder_input = {
            k: v for k, v in sample["net_input"].items() if k != "prev_output_tokens"
        }
        emissions = self.get_emissions(model, encoder_input)
        return self.decode(emissions)

    def get_emissions(self, model, encoder_input):
        """Run encoder and normalize emissions"""
        # encoder_out = models[0].encoder(**encoder_input)
        encoder_out = model(**encoder_input)
        if self.criterion_type == CriterionType.CTC:
            emissions = model.get_normalized_probs(encoder_out, log_probs=True)

        return emissions.transpose(0, 1).float().cpu().contiguous()

    def get_tokens(self, idxs):
        """Normalize tokens by handling CTC blank, ASG replabels, etc."""
        idxs = (g[0] for g in it.groupby(idxs))
        idxs = filter(lambda x: x != self.blank, idxs)

        return torch.LongTensor(list(idxs))


class W2lViterbiDecoder(W2lDecoder):
    def __init__(self, tgt_dict):
        super().__init__(tgt_dict)

    def decode(self, emissions):
        B, T, N = emissions.size()
        hypos = list()

        if self.asg_transitions is None:
            transitions = torch.FloatTensor(N, N).zero_()
        else:
            transitions = torch.FloatTensor(self.asg_transitions).view(N, N)

        viterbi_path = torch.IntTensor(B, T)
        workspace = torch.ByteTensor(CpuViterbiPath.get_workspace_size(B, T, N))
        CpuViterbiPath.compute(
            B,
            T,
            N,
            get_data_ptr_as_bytes(emissions),
            get_data_ptr_as_bytes(transitions),
            get_data_ptr_as_bytes(viterbi_path),
            get_data_ptr_as_bytes(workspace),
        )
        return [
            [{"tokens": self.get_tokens(viterbi_path[b].tolist()), "score": 0}]
            for b in range(B)
        ]


def post_process(sentence: str, symbol: str):
    if symbol == "sentencepiece":
        sentence = sentence.replace(" ", "").replace("\u2581", " ").strip()
    elif symbol == "wordpiece":
        sentence = sentence.replace(" ", "").replace("_", " ").strip()
    elif symbol == "letter":
        sentence = sentence.replace(" ", "").replace("|", " ").strip()
    elif symbol == "_EOW":
        sentence = sentence.replace(" ", "").replace("_EOW", " ").strip()
    elif symbol is not None and symbol != "none":
        sentence = (sentence + " ").replace(symbol, "").rstrip()
    return sentence


def get_feature(filepath):
    def postprocess(feats, sample_rate):
        if feats.dim() == 2:
            feats = feats.mean(-1)

        assert feats.dim() == 1, feats.dim()

        with torch.no_grad():
            feats = F.layer_norm(feats, feats.shape)
        return feats

    wav, sample_rate = sf.read(filepath)
    feats = torch.from_numpy(wav).float()
    if torch.cuda.is_available():
        feats = feats.cuda()
    feats = postprocess(feats, sample_rate)
    return feats


def load_model(ctc_model_path, w2v_model_path, target_dict):
    w2v = torch.load(ctc_model_path)
    w2v["args"].w2v_path = w2v_model_path
    model = Wav2VecCtc.build_model(w2v["args"], target_dict)
    model.load_state_dict(w2v["model"], strict=True)
    if torch.cuda.is_available():
        model = model.cuda()
    return model


class Wav2Vec2ASR(object):
    """Wav2Vec2 + CTC ASR wrapper that decodes raw wav bytes to text."""

    def __init__(self, ctc_path, w2v_path, target_dict_path):
        super(Wav2Vec2ASR, self).__init__()
        self.target_dict = Dictionary.load(target_dict_path)

        self.model = load_model(ctc_path, w2v_path, self.target_dict)
        self.model.eval()

        self.generator = W2lViterbiDecoder(self.target_dict)

    def transcribe(self, audio_data, greedy=True):
        aud_f = BytesIO(audio_data)
        # aud_seg = pydub.AudioSegment.from_file(aud_f)
        # feat_seg = aud_seg.set_channels(1).set_sample_width(2).set_frame_rate(16000)
        # feat_f = io.BytesIO()
        # feat_seg.export(feat_f, format='wav')
        # feat_f.seek(0)
        net_input = {}
        feature = get_feature(aud_f)
        net_input["source"] = feature.unsqueeze(0)

        padding_mask = (
            torch.BoolTensor(net_input["source"].size(1)).fill_(False).unsqueeze(0)
        )
        if torch.cuda.is_available():
            padding_mask = padding_mask.cuda()

        net_input["padding_mask"] = padding_mask
        sample = {}
        sample["net_input"] = net_input

        with torch.no_grad():
            hypo = self.generator.generate(self.model, sample, prefix_tokens=None)
        hyp_pieces = self.target_dict.string(hypo[0][0]["tokens"].int().cpu())
        result = post_process(hyp_pieces, "letter")
        return result
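# Editor's sketch (not part of the commit): minimal usage of the class above.
# The checkpoint and dictionary paths are placeholders.
#
# asr = Wav2Vec2ASR("/path/to/ctc.pt", "/path/to/base.pt", "/path/to/dict.ltr.txt")
# wav_bytes = open("utterance.wav", "rb").read()
# print(asr.transcribe(wav_bytes))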

@@ -0,0 +1,86 @@
from pathlib import Path
from collections import Counter
import shutil

import soundfile
# import pydub
import typer
from tqdm import tqdm

from plume.utils import (
    ExtendedPath,
    replace_redundant_spaces_with,
    lazy_module
)
pydub = lazy_module('pydub')

app = typer.Typer()


@app.command()
def export_jasper(src_dataset_path: Path, dest_dataset_path: Path, unlink: bool = True):
    dict_ltr = dest_dataset_path / Path("dict.ltr.txt")
    (dest_dataset_path / Path("wavs")).mkdir(exist_ok=True, parents=True)
    tok_counter = Counter()
    shutil.copy(
        src_dataset_path / Path("test_manifest.json"),
        src_dataset_path / Path("valid_manifest.json"),
    )
    if unlink:
        src_wavs = src_dataset_path / Path("wavs")
        for wav_path in tqdm(list(src_wavs.glob("**/*.wav"))):
            audio_seg = (
                pydub.AudioSegment.from_wav(wav_path)
                .set_frame_rate(16000)
                .set_channels(1)
            )
            dest_path = dest_dataset_path / Path("wavs") / Path(wav_path.name)
            audio_seg.export(dest_path, format="wav")

    for dataset_kind in ["train", "valid"]:
        abs_manifest_path = ExtendedPath(
            src_dataset_path / Path(f"{dataset_kind}_manifest.json")
        )
        manifest_data = list(abs_manifest_path.read_jsonl())
        o_tsv, o_ltr = f"{dataset_kind}.tsv", f"{dataset_kind}.ltr"
        out_tsv = dest_dataset_path / Path(o_tsv)
        out_ltr = dest_dataset_path / Path(o_ltr)
        with out_tsv.open("w") as tsv_f, out_ltr.open("w") as ltr_f:
            if unlink:
                tsv_f.write(f"{dest_dataset_path}\n")
            else:
                tsv_f.write(f"{src_dataset_path}\n")
            for md in manifest_data:
                audio_fname = md["audio_filepath"]
                pipe_toks = replace_redundant_spaces_with(md["text"], "|").upper()
                # pipe_toks = "|".join(re.sub(" ", "", md["text"]))
                # pipe_toks = alnum_to_asr_tokens(md["text"]).upper().replace(" ", "|")
                tok_counter.update(pipe_toks)
                letter_toks = " ".join(pipe_toks) + " |\n"
                frame_count = soundfile.info(audio_fname).frames
                rel_path = Path(audio_fname).relative_to(src_dataset_path.absolute())
                ltr_f.write(letter_toks)
                tsv_f.write(f"{rel_path}\t{frame_count}\n")
    with dict_ltr.open("w") as d_f:
        for k, v in tok_counter.most_common():
            d_f.write(f"{k} {v}\n")
    (src_dataset_path / Path("valid_manifest.json")).unlink()


@app.command()
def set_root(dataset_path: Path, root_path: Path):
    for dataset_kind in ["train", "valid"]:
        data_file = dataset_path / Path(dataset_kind).with_suffix(".tsv")
        with data_file.open("r") as df:
            lines = df.readlines()
        with data_file.open("w") as df:
            lines[0] = str(root_path) + "\n"
            df.writelines(lines)


def main():
    app()


if __name__ == "__main__":
    main()
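# Editor's note (illustrative, not part of the commit): export_jasper emits the
# fairseq wav2vec2 fine-tuning layout, as can be read off the writes above:
#
#   train.tsv / valid.tsv -> first line is the dataset root, then "<rel/path.wav>\t<frames>"
#   train.ltr / valid.ltr -> letters space-separated with "|" as word separator, e.g. "O N E | T W O |"
#   dict.ltr.txt          -> "<token> <count>" pairs sorted by frequency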

@@ -0,0 +1,49 @@
from pathlib import Path
import typer
from tqdm import tqdm
# import pandas as pd

from plume.utils import (
    asr_manifest_reader,
    discard_except_digits,
    replace_digit_symbol,
    lazy_module
    # run_shell,
)
from ...utils.transcribe import triton_transcribe_grpc_gen

pd = lazy_module('pandas')
app = typer.Typer()


@app.command()
def manifest(manifest_file: Path, result_file: Path = "results.csv"):
    from pydub import AudioSegment

    host = "localhost"
    port = 8044
    transcriber, audio_prep = triton_transcribe_grpc_gen(host, port, method='whole')
    result_path = manifest_file.parent / result_file
    manifest_list = list(asr_manifest_reader(manifest_file))

    def compute_frame(d):
        audio_file = d["audio_path"]
        orig_text = d["text"]
        orig_num = discard_except_digits(replace_digit_symbol(orig_text))
        aud_seg = AudioSegment.from_file(audio_file)
        t_audio = audio_prep(aud_seg)
        asr_text = transcriber(t_audio)
        asr_num = discard_except_digits(replace_digit_symbol(asr_text))
        return {
            "audio_file": audio_file,
            "asr_text": asr_text,
            "asr_num": asr_num,
            "orig_text": orig_text,
            "orig_num": orig_num,
            "asr_match": orig_num == asr_num,
        }

    # df_data = parallel_apply(compute_frame, manifest_list)
    df_data = map(compute_frame, tqdm(manifest_list))
    df = pd.DataFrame(df_data)
    df.to_csv(result_path)
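# Editor's note (illustrative, not part of the commit): the results.csv written by
# the manifest command above has one row per utterance with the columns returned by
# compute_frame: audio_file, asr_text, asr_num, orig_text, orig_num, asr_match.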

@@ -0,0 +1,53 @@
import os
import logging
from pathlib import Path

# from rpyc.utils.server import ThreadedServer
import typer

from ...utils.serve import ASRService
from plume.utils import lazy_callable
# from .asr import Wav2Vec2ASR

ThreadedServer = lazy_callable('rpyc.utils.server.ThreadedServer')
Wav2Vec2ASR = lazy_callable('plume.models.wav2vec2.asr.Wav2Vec2ASR')

app = typer.Typer()


@app.command()
def rpyc(
    w2v_path: Path = "/path/to/base.pt",
    ctc_path: Path = "/path/to/ctc.pt",
    target_dict_path: Path = "/path/to/dict.ltr.txt",
    port: int = int(os.environ.get("ASR_RPYC_PORT", "8044")),
):
    for p in [w2v_path, ctc_path, target_dict_path]:
        if not p.exists():
            logging.info(f"{p} doesn't exist")
            return
    w2vasr = Wav2Vec2ASR(str(ctc_path), str(w2v_path), str(target_dict_path))
    service = ASRService(w2vasr)
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )
    logging.info("starting asr server...")
    t = ThreadedServer(service, port=port)
    t.start()


@app.command()
def rpyc_dir(model_dir: Path, port: int = int(os.environ.get("ASR_RPYC_PORT", "8044"))):
    ctc_path = model_dir / Path("ctc.pt")
    w2v_path = model_dir / Path("base.pt")
    target_dict_path = model_dir / Path("dict.ltr.txt")
    rpyc(w2v_path, ctc_path, target_dict_path, port)


def main():
    app()


if __name__ == "__main__":
    main()

@@ -0,0 +1,34 @@
import typer
# from fairseq_cli.train import cli_main
import sys
from pathlib import Path
import shlex
from plume.utils import lazy_callable

cli_main = lazy_callable('fairseq_cli.train.cli_main')

app = typer.Typer()


@app.command()
def local(dataset_path: Path):
    args = f'''--distributed-world-size 1 {dataset_path} \
--save-dir /dataset/wav2vec2/model/wav2vec2_l_num_ctc_v1 --post-process letter --valid-subset \
valid --no-epoch-checkpoints --best-checkpoint-metric wer --num-workers 4 --max-update 80000 \
--sentence-avg --task audio_pretraining --arch wav2vec_ctc --w2v-path /dataset/wav2vec2/pretrained/wav2vec_vox_new.pt \
--labels ltr --apply-mask --mask-selection static --mask-other 0 --mask-length 10 --mask-prob 0.5 --layerdrop 0.1 \
--mask-channel-selection static --mask-channel-other 0 --mask-channel-length 64 --mask-channel-prob 0.5 \
--zero-infinity --feature-grad-mult 0.0 --freeze-finetune-updates 10000 --validate-after-updates 10000 \
--optimizer adam --adam-betas '(0.9, 0.98)' --adam-eps 1e-08 --lr 2e-05 --lr-scheduler tri_stage --warmup-steps 8000 \
--hold-steps 32000 --decay-steps 40000 --final-lr-scale 0.05 --final-dropout 0.0 --dropout 0.0 \
--activation-dropout 0.1 --criterion ctc --attention-dropout 0.0 --max-tokens 1280000 --seed 2337 --log-format json \
--log-interval 500 --ddp-backend no_c10d --reset-optimizer --normalize
'''
    new_args = ['train.py']
    new_args.extend(shlex.split(args))
    sys.argv = new_args
    cli_main()


if __name__ == "__main__":
    cli_main()
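# Editor's note (assumed invocation, not part of the commit): with the typer apps
# wired up as above, fine-tuning would be launched roughly as
#
#   plume train wav2vec2 local /dataset/wav2vec2/data/my_dataset
#
# where the exact command path depends on how the top-level plume app mounts the
# train sub-app.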

@@ -0,0 +1,64 @@
import typer
import sys
from pathlib import Path

from plume.utils import lazy_module
# from streamlit import cli as stcli

stcli = lazy_module('streamlit.cli')
app = typer.Typer()


@app.command()
def annotation(data_dir: Path, dump_fname: Path = "ui_dump.json", task_id: str = ""):
    annotation_lit_path = Path(__file__).parent / Path("annotation.py")
    if task_id:
        sys.argv = [
            "streamlit",
            "run",
            str(annotation_lit_path),
            "--",
            str(data_dir),
            "--task-id",
            task_id,
            "--dump-fname",
            str(dump_fname),
        ]
    else:
        sys.argv = [
            "streamlit",
            "run",
            str(annotation_lit_path),
            "--",
            str(data_dir),
            "--dump-fname",
            str(dump_fname),
        ]
    sys.exit(stcli.main())


@app.command()
def preview(manifest_path: Path):
    annotation_lit_path = Path(__file__).parent / Path("preview.py")
    sys.argv = [
        "streamlit",
        "run",
        str(annotation_lit_path),
        "--",
        str(manifest_path),
    ]
    sys.exit(stcli.main())


@app.command()
def collection(data_dir: Path, task_id: str = ""):
    # TODO: Implement web ui for data collection
    pass


def main():
    app()


if __name__ == "__main__":
    main()

@@ -1,10 +1,12 @@
-# import sys
 from pathlib import Path
+from uuid import uuid4
 
 import streamlit as st
 import typer
-from uuid import uuid4
 
-from ..utils import ExtendedPath, get_mongo_conn
+from plume.utils import ExtendedPath, get_mongo_conn
-from .st_rerun import rerun
+from plume.preview.st_rerun import rerun
 
 app = typer.Typer()

@@ -42,10 +44,10 @@ if not hasattr(st, "mongo_connected"):
             upsert=True,
         )
 
-    def set_task_fn(mf_path, task_id):
+    def set_task_fn(data_path, task_id):
         if task_id:
             st.task_id = task_id
-        task_path = mf_path.parent / Path(f"task-{st.task_id}.lck")
+        task_path = data_path / Path(f"task-{st.task_id}.lck")
         if not task_path.exists():
             print(f"creating task lock at {task_path}")
             task_path.touch()

@@ -62,17 +64,28 @@ if not hasattr(st, "mongo_connected"):
 @st.cache()
-def load_ui_data(validation_ui_data_path: Path):
+def load_ui_data(data_dir: Path, dump_fname: Path):
+    validation_ui_data_path = data_dir / dump_fname
     typer.echo(f"Using validation ui data from {validation_ui_data_path}")
     return ExtendedPath(validation_ui_data_path).read_json()
 
 
+def show_key(sample, key, trail=""):
+    if key in sample:
+        title = key.replace("_", " ").title()
+        if type(sample[key]) == float:
+            st.sidebar.markdown(f"{title}: {sample[key]:.2f}{trail}")
+        else:
+            st.sidebar.markdown(f"{title}: {sample[key]}")
+
+
 @app.command()
-def main(manifest: Path, task_id: str = ""):
+def main(data_dir: Path, dump_fname: Path = "ui_dump.json", task_id: str = ""):
-    st.set_task(manifest, task_id)
+    st.set_task(data_dir, task_id)
-    ui_config = load_ui_data(manifest)
+    ui_config = load_ui_data(data_dir, dump_fname)
     asr_data = ui_config["data"]
     annotation_only = ui_config.get("annotation_only", False)
+    asr_result_key = ui_config.get("asr_result_key", "pretrained_asr")
     sample_no = st.get_current_cursor()
     if len(asr_data) - 1 < sample_no or sample_no < 0:
         print("Invalid samplno resetting to 0")

@@ -91,15 +104,16 @@ def main(manifest: Path, task_id: str = ""):
         st.update_cursor(new_sample - 1)
     st.sidebar.title(f"Details: [{sample['real_idx']}]")
     st.sidebar.markdown(f"Gold Text: **{sample['text']}**")
+    # if "caller" in sample:
+    #     st.sidebar.markdown(f"Caller: **{sample['caller']}**")
+    show_key(sample, "caller")
     if not annotation_only:
-        st.sidebar.title("Results:")
-        st.sidebar.markdown(f"Pretrained: **{sample['pretrained_asr']}**")
-        if "caller" in sample:
-            st.sidebar.markdown(f"Caller: **{sample['caller']}**")
-    else:
-        st.sidebar.title(f"Pretrained WER: {sample['pretrained_wer']:.2f}%")
-        st.sidebar.image(Path(sample["plot_path"]).read_bytes())
-        st.audio(Path(sample["audio_path"]).open("rb"))
+        show_key(sample, asr_result_key)
+        show_key(sample, "asr_wer", trail="%")
+        show_key(sample, "correct_candidate")
+
+    st.sidebar.image((data_dir / Path(sample["plot_path"])).read_bytes())
+    st.audio((data_dir / Path(sample["audio_path"])).open("rb"))
 
     # set default to text
     corrected = sample["text"]
     correction_entry = st.get_correction_entry(sample["utterance_id"])

@@ -0,0 +1,58 @@
from pathlib import Path

import streamlit as st
import typer
from plume.utils import ExtendedPath
from plume.preview.st_rerun import rerun

app = typer.Typer()

if not hasattr(st, "state_lock"):
    # st.task_id = str(uuid4())
    task_path = ExtendedPath("preview.lck")

    def current_cursor_fn():
        return task_path.read_json()["current_cursor"]

    def update_cursor_fn(val=0):
        task_path.write_json({"current_cursor": val})
        rerun()

    st.get_current_cursor = current_cursor_fn
    st.update_cursor = update_cursor_fn
    st.state_lock = True
    # cursor_obj = mongo_conn.find_one({"type": "current_cursor", "task_id": st.task_id})
    # if not cursor_obj:
    update_cursor_fn(0)


@st.cache()
def load_ui_data(validation_ui_data_path: Path):
    typer.echo(f"Using validation ui data from {validation_ui_data_path}")
    return list(ExtendedPath(validation_ui_data_path).read_jsonl())


@app.command()
def main(manifest: Path):
    asr_data = load_ui_data(manifest)
    sample_no = st.get_current_cursor()
    if len(asr_data) - 1 < sample_no or sample_no < 0:
        print("Invalid sample no, resetting to 0")
        st.update_cursor(0)
    sample = asr_data[sample_no]
    st.title("ASR Manifest Preview")
    st.markdown(f"{sample_no+1} of {len(asr_data)} : **{sample['text']}**")
    new_sample = st.number_input(
        "Go To Sample:", value=sample_no + 1, min_value=1, max_value=len(asr_data)
    )
    if new_sample != sample_no + 1:
        st.update_cursor(new_sample - 1)
    st.sidebar.markdown(f"Gold Text: **{sample['text']}**")
    st.audio((manifest.parent / Path(sample["audio_filepath"])).open("rb"))


if __name__ == "__main__":
    try:
        app()
    except SystemExit:
        pass

@@ -1,7 +1,15 @@
-import streamlit.ReportThread as ReportThread
-from streamlit.ScriptRequestQueue import RerunData
-from streamlit.ScriptRunner import RerunException
-from streamlit.server.Server import Server
+try:
+    # Before Streamlit 0.65
+    from streamlit.ReportThread import get_report_ctx
+    from streamlit.server.Server import Server
+    from streamlit.ScriptRequestQueue import RerunData
+    from streamlit.ScriptRunner import RerunException
+except ModuleNotFoundError:
+    # After Streamlit 0.65
+    from streamlit.report_thread import get_report_ctx
+    from streamlit.server.server import Server
+    from streamlit.script_request_queue import RerunData
+    from streamlit.script_runner import RerunException
 
 
 def rerun():

@@ -13,7 +21,7 @@ def rerun():
 def _get_widget_states():
     # Hack to get the session object from Streamlit.
 
-    ctx = ReportThread.get_report_ctx()
+    ctx = get_report_ctx()
 
     session = None

@@ -34,5 +42,4 @@ def _get_widget_states():
             "Are you doing something fancy with threads?"
         )
     # Got the session object!
-
     return session._widget_states

@@ -0,0 +1,486 @@
import io
import os
import re
import json
import wave
import logging
from pathlib import Path
from functools import partial
from uuid import uuid4
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
import subprocess
import shutil
from urllib.parse import urlsplit
# from .lazy_loader import LazyLoader
from .lazy_import import lazy_callable, lazy_module

# from ruamel.yaml import YAML
# import boto3
import typer
# import pymongo
# from slugify import slugify
# import pydub
# import matplotlib.pyplot as plt
# import librosa
# import librosa.display as audio_display
# from natural.date import compress
# from num2words import num2words
from tqdm import tqdm
from datetime import timedelta

# from .transcribe import triton_transcribe_grpc_gen
# from .eval import app as eval_app
from .tts import app as tts_app
from .transcribe import app as transcribe_app
from .align import app as align_app

boto3 = lazy_module('boto3')
pymongo = lazy_module('pymongo')
pydub = lazy_module('pydub')
audio_display = lazy_module('librosa.display')
plt = lazy_module('matplotlib.pyplot')
librosa = lazy_module('librosa')
YAML = lazy_callable('ruamel.yaml.YAML')
num2words = lazy_callable('num2words.num2words')
slugify = lazy_callable('slugify.slugify')
compress = lazy_callable('natural.date.compress')

app = typer.Typer()
app.add_typer(tts_app, name="tts")
app.add_typer(align_app, name="align")
app.add_typer(transcribe_app, name="transcribe")


logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)


def manifest_str(path, dur, text):
    return (
        json.dumps({"audio_filepath": path, "duration": round(dur, 1), "text": text})
        + "\n"
    )


def duration_str(seconds):
    return compress(timedelta(seconds=seconds), pad=" ")


def replace_digit_symbol(w2v_out):
    num_int_map = {num2words(i): str(i) for i in range(10)}
    out = w2v_out.lower()
    for (k, v) in num_int_map.items():
        out = re.sub(k, v, out)
    return out


def discard_except_digits(inp):
    return re.sub("[^0-9]", "", inp)


def digits_to_chars(text):
    num_tokens = [num2words(c) + " " if "0" <= c <= "9" else c for c in text]
    return ("".join(num_tokens)).lower()


def replace_redundant_spaces_with(text, sub):
    return re.sub(" +", sub, text)


def space_out(text):
    letters = " ".join(list(text))
    return letters


def wav_bytes(audio_bytes, frame_rate=24000):
    wf_b = io.BytesIO()
    with wave.open(wf_b, mode="w") as wf:
        wf.setnchannels(1)
        wf.setframerate(frame_rate)
        wf.setsampwidth(2)
        wf.writeframesraw(audio_bytes)
    return wf_b.getvalue()


def tscript_uuid_fname(transcript):
    return str(uuid4()) + "_" + slugify(transcript, max_length=8)


def run_shell(cmd_str, work_dir="."):
    cwd_path = Path(work_dir).absolute()
    p = subprocess.Popen(
        cmd_str,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        shell=True,
        cwd=cwd_path,
    )
    for line in p.stdout:
        print(line.replace(b"\n", b"").decode("utf-8"))


def upload_s3(dataset_path, s3_path):
    run_shell(f"aws s3 sync {dataset_path} {s3_path}")


def get_download_path(s3_uri, output_path):
    s3_uri_p = urlsplit(s3_uri)
    download_path = output_path / Path(s3_uri_p.path[1:])
    download_path.parent.mkdir(exist_ok=True, parents=True)
    return download_path


def s3_downloader():
    s3 = boto3.client("s3")

    def download_s3(s3_uri, download_path):
        s3_uri_p = urlsplit(s3_uri)
        download_path.parent.mkdir(exist_ok=True, parents=True)
        if not download_path.exists():
            print(f"downloading {s3_uri} to {download_path}")
            s3.download_file(s3_uri_p.netloc, s3_uri_p.path[1:], str(download_path))

    return download_s3


def asr_data_writer(dataset_dir, asr_data_source, verbose=False):
    (dataset_dir / Path("wavs")).mkdir(parents=True, exist_ok=True)
    asr_manifest = dataset_dir / Path("manifest.json")
    num_datapoints = 0
    with asr_manifest.open("w") as mf:
        print(f"writing manifest to {asr_manifest}")
        for transcript, audio_dur, wav_data in asr_data_source:
            fname = tscript_uuid_fname(transcript)
            audio_file = dataset_dir / Path("wavs") / Path(fname).with_suffix(".wav")
            audio_file.write_bytes(wav_data)
            rel_data_path = audio_file.relative_to(dataset_dir)
            manifest = manifest_str(str(rel_data_path), audio_dur, transcript)
            mf.write(manifest)
            if verbose:
                print(f"writing '{transcript}' of duration {audio_dur}")
            num_datapoints += 1
    return num_datapoints


def ui_data_generator(dataset_dir, asr_data_source, verbose=False):
    (dataset_dir / Path("wavs")).mkdir(parents=True, exist_ok=True)
    (dataset_dir / Path("wav_plots")).mkdir(parents=True, exist_ok=True)

    def data_fn(
        transcript,
        audio_dur,
        wav_data,
        caller_name,
        aud_seg,
        fname,
        audio_file,
        num_datapoints,
        rel_data_path,
    ):
        png_path = Path(fname).with_suffix(".png")
        rel_plot_path = Path("wav_plots") / png_path
        wav_plot_path = dataset_dir / rel_plot_path
        if not wav_plot_path.exists():
            plot_seg(wav_plot_path.absolute(), audio_file)
        return {
            "audio_path": str(rel_data_path),
            "duration": round(audio_dur, 1),
            "text": transcript,
            "real_idx": num_datapoints,
            "caller": caller_name,
            "utterance_id": fname,
            "plot_path": str(rel_plot_path),
        }

    num_datapoints = 0
    data_funcs = []
    for transcript, audio_dur, wav_data, caller_name, aud_seg in asr_data_source:
        fname = str(uuid4()) + "_" + slugify(transcript, max_length=8)
        audio_file = (
            dataset_dir / Path("wavs") / Path(fname).with_suffix(".wav")
        ).absolute()
        audio_file.write_bytes(wav_data)
        # audio_path = str(audio_file)
        rel_data_path = audio_file.relative_to(dataset_dir.absolute())
        data_funcs.append(
            partial(
                data_fn,
                transcript,
                audio_dur,
                wav_data,
                caller_name,
                aud_seg,
                fname,
                audio_file,
                num_datapoints,
                rel_data_path,
            )
        )
        num_datapoints += 1
    ui_data = parallel_apply(lambda x: x(), data_funcs)
    return ui_data, num_datapoints


def ui_dump_manifest_writer(dataset_dir, asr_data_source, verbose=False):
    dump_data, num_datapoints = ui_data_generator(
        dataset_dir, asr_data_source, verbose=verbose
    )

    asr_manifest = dataset_dir / Path("manifest.json")
    with asr_manifest.open("w") as mf:
        print(f"writing manifest to {asr_manifest}")
        for d in dump_data:
            rel_data_path = d["audio_path"]
            audio_dur = d["duration"]
            transcript = d["text"]
            manifest = manifest_str(str(rel_data_path), audio_dur, transcript)
            mf.write(manifest)

    ui_dump_file = dataset_dir / Path("ui_dump.json")
    ExtendedPath(ui_dump_file).write_json({"data": dump_data})
    return num_datapoints


def asr_manifest_reader(data_manifest_path: Path):
    print(f"reading manifest from {data_manifest_path}")
    with data_manifest_path.open("r") as pf:
        data_jsonl = pf.readlines()
    data_data = [json.loads(v) for v in data_jsonl]
    for p in data_data:
        p["audio_path"] = data_manifest_path.parent / Path(p["audio_filepath"])
        p["text"] = p["text"].strip()
        yield p


def asr_manifest_writer(asr_manifest_path: Path, manifest_str_source):
    with asr_manifest_path.open("w") as mf:
        print(f"opening {asr_manifest_path} for writing manifest")
        for mani_dict in manifest_str_source:
            manifest = manifest_str(
                mani_dict["audio_filepath"], mani_dict["duration"], mani_dict["text"]
            )
            mf.write(manifest)


def asr_test_writer(out_file_path: Path, source):
    def dd_str(dd, idx):
        path = dd["audio_filepath"]
        # dur = dd["duration"]
        # return f"SAY {idx}\nPAUSE 3\nPLAY {path}\nPAUSE 3\n\n"
        return f"PAUSE 2\nPLAY {path}\nPAUSE 60\n\n"

    res_file = out_file_path.with_suffix(".result.json")
    with out_file_path.open("w") as of:
        print(f"opening {out_file_path} for writing test")
        results = []
        idx = 0
        for ui_dd in source:
            results.append(ui_dd)
            out_str = dd_str(ui_dd, idx)
            of.write(out_str)
            idx += 1
        of.write("DO_HANGUP\n")
    ExtendedPath(res_file).write_json(results)


def batch(iterable, n=1):
    ls = len(iterable)
    return [iterable[ndx : min(ndx + n, ls)] for ndx in range(0, ls, n)]


class ExtendedPath(type(Path())):
    """Path subclass with JSON/JSONL/YAML read and write helpers."""

    def read_json(self):
        print(f"reading json from {self}")
        with self.open("r") as jf:
            return json.load(jf)

    def read_yaml(self):
        yaml = YAML(typ="safe", pure=True)
        print(f"reading yaml from {self}")
        with self.open("r") as yf:
            return yaml.load(yf)

    def read_jsonl(self):
        print(f"reading jsonl from {self}")
        with self.open("r") as jf:
            for l in jf.readlines():
                yield json.loads(l)

    def write_json(self, data):
        print(f"writing json to {self}")
        self.parent.mkdir(parents=True, exist_ok=True)
        with self.open("w") as jf:
            json.dump(data, jf, indent=2)

    def write_yaml(self, data):
        yaml = YAML()
        print(f"writing yaml to {self}")
        with self.open("w") as yf:
            yaml.dump(data, yf)

    def write_jsonl(self, data):
        print(f"writing jsonl to {self}")
        self.parent.mkdir(parents=True, exist_ok=True)
        with self.open("w") as jf:
            for d in data:
                jf.write(json.dumps(d) + "\n")


def get_mongo_coll(uri):
    ud = pymongo.uri_parser.parse_uri(uri)
    conn = pymongo.MongoClient(uri)
    return conn[ud["database"]][ud["collection"]]


def get_mongo_conn(host="", port=27017, db="db", col="collection"):
    mongo_host = host if host else os.environ.get("MONGO_HOST", "localhost")
    mongo_uri = f"mongodb://{mongo_host}:{port}/"
    return pymongo.MongoClient(mongo_uri)[db][col]


def strip_silence(sound):
    from pydub.silence import detect_leading_silence

    start_trim = detect_leading_silence(sound)
    end_trim = detect_leading_silence(sound.reverse())
    duration = len(sound)
    return sound[start_trim : duration - end_trim]


def plot_seg(wav_plot_path, audio_path):
    fig = plt.Figure()
    ax = fig.add_subplot()
    (y, sr) = librosa.load(str(audio_path))
    audio_display.waveplot(y=y, sr=sr, ax=ax)
    with wav_plot_path.open("wb") as wav_plot_f:
        fig.set_tight_layout(True)
        fig.savefig(wav_plot_f, format="png", dpi=50)


def parallel_apply(fn, iterable, workers=8, pool="thread"):
    if pool == "thread":
        with ThreadPoolExecutor(max_workers=workers) as exe:
            print(f"applying {fn} in parallel")
            return [
                res
                for res in tqdm(
                    exe.map(fn, iterable), position=0, leave=True, total=len(iterable)
                )
            ]
    elif pool == "process":
        with ProcessPoolExecutor(max_workers=workers) as exe:
            print(f"applying {fn} in parallel")
            return [
                res
                for res in tqdm(
                    exe.map(fn, iterable), position=0, leave=True, total=len(iterable)
                )
            ]
    else:
        raise Exception(f"unsupported pool type - {pool}")


def generate_filter_map(src_dataset_path, dest_dataset_path, data_file):
    min_nums = 3
    max_duration = 1 * 60 * 60
    skip_duration = 1 * 60 * 60

    def filtered_max_dur():
        wav_duration = 0
        for s in ExtendedPath(data_file).read_jsonl():
            nums = re.sub(" ", "", s["text"])
            if len(nums) >= min_nums:
                wav_duration += s["duration"]
                shutil.copy(
                    src_dataset_path / Path(s["audio_filepath"]),
                    dest_dataset_path / Path(s["audio_filepath"]),
                )
                yield s
            if wav_duration > max_duration:
                break
        typer.echo(f"filtered only {duration_str(wav_duration)} of audio")

    def filtered_skip_dur():
        wav_duration = 0
        for s in ExtendedPath(data_file).read_jsonl():
            nums = re.sub(" ", "", s["text"])
            if len(nums) >= min_nums:
                wav_duration += s["duration"]
            if wav_duration <= skip_duration:
                continue
            elif len(nums) >= min_nums:
                yield s
                shutil.copy(
                    src_dataset_path / Path(s["audio_filepath"]),
                    dest_dataset_path / Path(s["audio_filepath"]),
                )
        typer.echo(f"skipped {duration_str(skip_duration)} of audio")

    def filtered_blanks():
        blank_count = 0
        for s in ExtendedPath(data_file).read_jsonl():
            nums = re.sub(" ", "", s["text"])
            if nums != "":
                blank_count += 1
                shutil.copy(
                    src_dataset_path / Path(s["audio_filepath"]),
                    dest_dataset_path / Path(s["audio_filepath"]),
                )
                yield s
        typer.echo(f"kept {blank_count} non-blank samples")

    def filtered_transform_digits():
        count = 0
        for s in ExtendedPath(data_file).read_jsonl():
            count += 1
            digit_text = replace_digit_symbol(s["text"])
            only_digits = discard_except_digits(digit_text)
            char_text = digits_to_chars(only_digits)
            shutil.copy(
                src_dataset_path / Path(s["audio_filepath"]),
                dest_dataset_path / Path(s["audio_filepath"]),
            )
            s["text"] = char_text
            yield s
        typer.echo(f"transformed {count} samples")

    def filtered_extract_chars():
        count = 0
        for s in ExtendedPath(data_file).read_jsonl():
            count += 1
            no_digits = digits_to_chars(s["text"]).upper()
            only_chars = re.sub("[^A-Z'\b]", " ", no_digits)
            filter_text = replace_redundant_spaces_with(only_chars, " ").strip()
            shutil.copy(
                src_dataset_path / Path(s["audio_filepath"]),
                dest_dataset_path / Path(s["audio_filepath"]),
            )
            s["text"] = filter_text
            yield s
        typer.echo(f"transformed {count} samples")

    def filtered_resample():
        count = 0
        for s in ExtendedPath(data_file).read_jsonl():
            count += 1
            src_aud = pydub.AudioSegment.from_file(
                src_dataset_path / Path(s["audio_filepath"])
            )
            dst_aud = src_aud.set_channels(1).set_sample_width(1).set_frame_rate(24000)
            dst_aud.export(dest_dataset_path / Path(s["audio_filepath"]), format="wav")
            yield s
        typer.echo(f"transformed {count} samples")

    filter_kind_map = {
        "max_dur_1hr_min3num": filtered_max_dur,
        "skip_dur_1hr_min3num": filtered_skip_dur,
        "blanks": filtered_blanks,
        "transform_digits": filtered_transform_digits,
        "extract_chars": filtered_extract_chars,
        "resample_ulaw24kmono": filtered_resample,
    }
    return filter_kind_map
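# Editor's sketch (not part of the commit): how the helpers above compose. The file
# names are placeholders.
#
# from plume.utils import ExtendedPath, parallel_apply
# rows = list(ExtendedPath("manifest.json").read_jsonl())      # one dict per line
# durs = parallel_apply(lambda r: r["duration"], rows, workers=4, pool="thread")
# ExtendedPath("out/summary.json").write_json({"total_sec": sum(durs)})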

@@ -0,0 +1,117 @@
from pathlib import Path
from .tts import GoogleTTS
# from IPython import display
import requests
import io
import typer

from plume.utils import lazy_module

display = lazy_module('IPython.display')
pydub = lazy_module('pydub')

app = typer.Typer()

# Start gentle with following command
# docker run --rm -d --name gentle_service -p 8765:8765/tcp lowerquality/gentle


def gentle_aligner(service_uri, wav_data, utter_text):
    # service_uri= "http://52.41.161.36:8765/transcriptions"
    wav_f = io.BytesIO(wav_data)
    wav_seg = pydub.AudioSegment.from_file(wav_f)

    mp3_f = io.BytesIO()
    wav_seg.export(mp3_f, format="mp3")
    mp3_f.seek(0)
    params = (("async", "false"),)
    files = {
        "audio": ("audio.mp3", mp3_f),
        "transcript": ("words.txt", io.BytesIO(utter_text.encode("utf-8"))),
    }

    response = requests.post(service_uri, params=params, files=files)
    print(f"Time duration of audio {wav_seg.duration_seconds}")
    print(f"Time taken to align: {response.elapsed}s")
    return wav_seg, response.json()


def gentle_align_iter(service_uri, wav_data, utter_text):
    wav_seg, response = gentle_aligner(service_uri, wav_data, utter_text)
    for span in response:
        word_seg = wav_seg[int(span["start"] * 1000) : int(span["end"] * 1000)]
        word = span["word"]
        yield (word, word_seg)


def tts_jupyter(service_uri):
    google_voices = GoogleTTS.voice_list()
    gtts = GoogleTTS()
    # google_voices[4]
    us_voice = [v for v in google_voices if v["language"] == "en-US"][0]
    utter_text = (
        "I would like to align the audio segments based on word level timestamps"
    )
    wav_data = gtts.text_to_speech(text=utter_text, params=us_voice)
    for word, seg in gentle_align_iter(service_uri, wav_data, utter_text):
        print(word)
        display.display(seg)


@app.command()
def cut(audio_path: Path, transcript_path: Path, out_dir: Path = "/tmp"):
    from . import ExtendedPath
    import datetime
    import re

    aud_seg = pydub.AudioSegment.from_file(audio_path)
    aud_seg[: 15 * 60 * 1000].export(out_dir / Path("audio.mp3"), format="mp3")
    tscript_json = ExtendedPath(transcript_path).read_json()

    def time_to_msecs(time_str):
        return (
            datetime.datetime.strptime(time_str, "%H:%M:%S,%f")
            - datetime.datetime(1900, 1, 1)
        ).total_seconds() * 1000

    tscript_words = []
    broken = False
    for m in tscript_json["monologues"]:
        # tscript_words.append("|")
        for e in m["elements"]:
            if e["type"] == "text":
                text = e["value"]
                text = re.sub(r"\[.*\]", "", text)
                text = re.sub(r"\(.*\)", "", text)
                tscript_words.append(text)
            if "timestamp" in e and time_to_msecs(e["timestamp"]) >= 15 * 60 * 1000:
                broken = True
                break
        if broken:
            break
    (out_dir / Path("words.txt")).write_text("".join(tscript_words))


@app.command()
def gentle_preview(
    audio_path: Path,
    transcript_path: Path,
    service_uri="http://101.53.142.218:8765/transcriptions",
    gent_preview_dir: Path = Path("../gentle_preview"),
):
    from . import ExtendedPath

    ab = audio_path.read_bytes()
    tt = transcript_path.read_text()
    audio, alignment = gentle_aligner(service_uri, ab, tt)
    audio.export(gent_preview_dir / Path("a.wav"), format="wav")
    alignment["status"] = "OK"
    ExtendedPath(gent_preview_dir / Path("status.json")).write_json(alignment)


def main():
    app()


if __name__ == "__main__":
    main()
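# Editor's sketch (not part of the commit): iterating word-level alignments against
# a locally running gentle container (see the docker command above). The wav path
# and transcript are placeholders.
#
# wav_bytes = open("utterance.wav", "rb").read()
# for word, seg in gentle_align_iter(
#     "http://localhost:8765/transcriptions", wav_bytes, "hello world"
# ):
#     print(word, seg.duration_seconds)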

@@ -0,0 +1,28 @@
from scipy.signal import lfilter, butter
from scipy.io.wavfile import read, write
from numpy import array, int16
import sys


def butter_params(low_freq, high_freq, fs, order=5):
    nyq = 0.5 * fs
    low = low_freq / nyq
    high = high_freq / nyq
    b, a = butter(order, [low, high], btype="band")
    return b, a


def butter_bandpass_filter(data, low_freq, high_freq, fs, order=5):
    b, a = butter_params(low_freq, high_freq, fs, order=order)
    y = lfilter(b, a, data)
    return y


if __name__ == "__main__":
    fs, audio = read(sys.argv[1])
    low_freq = 300.0
    high_freq = 4000.0
    filtered_signal = butter_bandpass_filter(audio, low_freq, high_freq, fs, order=6)
    fname = sys.argv[1].split(".wav")[0] + "_moded.wav"
    write(fname, fs, array(filtered_signal, dtype=int16))

@@ -0,0 +1,737 @@
# -*- Mode: python; tab-width: 4; indent-tabs-mode:nil; coding:utf-8 -*-
# vim: tabstop=4 expandtab shiftwidth=4 softtabstop=4
#
# lazy_import --- https://github.com/mnmelo/lazy_import
# Copyright (C) 2017-2018 Manuel Nuno Melo
#
# This file is part of lazy_import.
#
# lazy_import is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# lazy_import is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with lazy_import. If not, see <http://www.gnu.org/licenses/>.
#
# lazy_import was based on code from the importing module from the PEAK
# package (see <http://peak.telecommunity.com/DevCenter/Importing>). The PEAK
# package is released under the following license, reproduced here:
#
# Copyright (C) 1996-2004 by Phillip J. Eby and Tyler C. Sarna.
# All rights reserved. This software may be used under the same terms
# as Zope or Python. THERE ARE ABSOLUTELY NO WARRANTIES OF ANY KIND.
# Code quality varies between modules, from "beta" to "experimental
# pre-alpha". :)
#
# Code pertaining to lazy loading from PEAK importing was included in
# lazy_import, modified in a number of ways. These are detailed in the
# CHANGELOG file of lazy_import. Changes mainly involved Python 3
# compatibility, extension to allow customizable behavior, and added
# functionality (lazy importing of callable objects).
#

"""
Lazy module loading
===================
Functions and classes for lazy module loading that also delay import errors.
Heavily borrowed from the `importing`_ module.
.. _`importing`: http://peak.telecommunity.com/DevCenter/Importing
Files and directories
---------------------
.. autofunction:: module
.. autofunction:: callable
"""

__all__ = [
    "lazy_module",
    "lazy_callable",
    "lazy_function",
    "lazy_class",
    "LazyModule",
    "LazyCallable",
    "module_basename",
    "_MSG",
    "_MSG_CALLABLE",
]

from types import ModuleType
import sys

try:
    from importlib._bootstrap import _ImportLockContext
except ImportError:
    # Python 2 doesn't have the context manager. Roll it ourselves (copied from
    # Python 3's importlib/_bootstrap.py)
    import imp

    class _ImportLockContext:
        """Context manager for the import lock."""

        def __enter__(self):
            imp.acquire_lock()

        def __exit__(self, exc_type, exc_value, exc_traceback):
            imp.release_lock()


# Adding a __spec__ doesn't really help. I'll leave the code here in case
# future python implementations start relying on it.
# try:
#     from importlib.machinery import ModuleSpec
# except ImportError:
#     ModuleSpec = None

import six
from six import raise_from
from six.moves import reload_module

# It is sometime useful to have access to the version number of a library.
# This is usually done through the __version__ special attribute.
# To make sure the version number is consistent between setup.py and the
# library, we read the version number from the file called VERSION that stays
# in the module directory.
import os

# VERSION_FILE = os.path.join(os.path.dirname(__file__), "VERSION")
# with open(VERSION_FILE) as infile:
#     __version__ = infile.read().strip()

# Logging
import logging

# adding a TRACE level for stack debugging
_LAZY_TRACE = 1
logging.addLevelName(1, "LAZY_TRACE")
logging.basicConfig(level=logging.WARNING)
# Logs a formatted stack (takes no message or args/kwargs)
def _lazy_trace(self):
    if self.isEnabledFor(_LAZY_TRACE):
        import traceback

        self._log(_LAZY_TRACE, " ### STACK TRACE ###", ())
        for line in traceback.format_stack(sys._getframe(2)):
            for subline in line.split("\n"):
                self._log(_LAZY_TRACE, subline.rstrip(), ())
|
||||||
|
|
||||||
|
|
||||||
|
logging.Logger.lazy_trace = _lazy_trace
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
################################
|
||||||
|
# Module/function registration #
|
||||||
|
################################
|
||||||
|
|
||||||
|
#### Lazy classes ####
|
||||||
|
|
||||||
|
|
||||||
|
class LazyModule(ModuleType):
|
||||||
|
"""Class for lazily-loaded modules that triggers proper loading on access.
|
||||||
|
Instantiation should be made from a subclass of :class:`LazyModule`, with
|
||||||
|
one subclass per instantiated module. Regular attribute set/access can then
|
||||||
|
be recovered by setting the subclass's :meth:`__getattribute__` and
|
||||||
|
:meth:`__setattribute__` to those of :class:`types.ModuleType`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# peak.util.imports sets __slots__ to (), but it seems pointless because
|
||||||
|
# the base ModuleType doesn't itself set __slots__.
|
||||||
|
def __getattribute__(self, attr):
|
||||||
|
logger.debug(
|
||||||
|
"Getting attr {} of LazyModule instance of {}".format(
|
||||||
|
attr, super(LazyModule, self).__getattribute__("__name__")
|
||||||
|
)
|
||||||
|
)
|
||||||
|
logger.lazy_trace()
|
||||||
|
# IPython tries to be too clever and constantly inspects, asking for
|
||||||
|
# modules' attrs, which causes premature module loading and unaesthetic
|
||||||
|
# internal errors if the lazily-loaded module doesn't exist.
|
||||||
|
if (
|
||||||
|
run_from_ipython()
|
||||||
|
and (attr.startswith(("__", "_ipython")) or attr == "_repr_mimebundle_")
|
||||||
|
and module_basename(_caller_name()) in ("inspect", "IPython")
|
||||||
|
):
|
||||||
|
logger.debug(
|
||||||
|
"Ignoring request for {}, deemed from IPython's "
|
||||||
|
"inspection.".format(
|
||||||
|
super(LazyModule, self).__getattribute__("__name__"), attr
|
||||||
|
)
|
||||||
|
)
|
||||||
|
raise AttributeError
|
||||||
|
if not attr in ("__name__", "__class__", "__spec__"):
|
||||||
|
# __name__ and __class__ yield their values from the LazyModule;
|
||||||
|
# __spec__ causes an AttributeError. Maybe in the future it will be
|
||||||
|
# necessary to return an actual ModuleSpec object, but it works as
|
||||||
|
# it is without that now.
|
||||||
|
|
||||||
|
# If it's an already-loaded submodule, we return it without
|
||||||
|
# triggering a full loading
|
||||||
|
try:
|
||||||
|
return sys.modules[self.__name__ + "." + attr]
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
|
# Check if it's one of the lazy callables
|
||||||
|
try:
|
||||||
|
_callable = type(self)._lazy_import_callables[attr]
|
||||||
|
logger.debug("Returning lazy-callable '{}'.".format(attr))
|
||||||
|
return _callable
|
||||||
|
except (AttributeError, KeyError) as err:
|
||||||
|
logger.debug(
|
||||||
|
"Proceeding to load module {}, "
|
||||||
|
"from requested value {}".format(
|
||||||
|
super(LazyModule, self).__getattribute__("__name__"), attr
|
||||||
|
)
|
||||||
|
)
|
||||||
|
_load_module(self)
|
||||||
|
logger.debug(
|
||||||
|
"Returning value '{}'.".format(
|
||||||
|
super(LazyModule, self).__getattribute__(attr)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return super(LazyModule, self).__getattribute__(attr)
|
||||||
|
|
||||||
|
def __setattr__(self, attr, value):
|
||||||
|
logger.debug(
|
||||||
|
"Setting attr {} to value {}, in LazyModule instance "
|
||||||
|
"of {}".format(
|
||||||
|
attr, value, super(LazyModule, self).__getattribute__("__name__")
|
||||||
|
)
|
||||||
|
)
|
||||||
|
_load_module(self)
|
||||||
|
return super(LazyModule, self).__setattr__(attr, value)
|
||||||
|
|
||||||
|
|
||||||
|
class LazyCallable(object):
|
||||||
|
"""Class for lazily-loaded callables that triggers module loading on access
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *args):
|
||||||
|
if len(args) != 2:
|
||||||
|
# Maybe the user tried to base a class off this lazy callable?
|
||||||
|
try:
|
||||||
|
logger.debug(
|
||||||
|
"Got wrong number of args when init'ing "
|
||||||
|
"LazyCallable. args is '{}'".format(args)
|
||||||
|
)
|
||||||
|
base = args[1][0]
|
||||||
|
if isinstance(base, LazyCallable) and len(args) == 3:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"It seems you are trying to use "
|
||||||
|
"a lazy callable as a class "
|
||||||
|
"base. This is not supported."
|
||||||
|
)
|
||||||
|
except (IndexError, TypeError):
|
||||||
|
raise_from(
|
||||||
|
TypeError(
|
||||||
|
"LazyCallable takes exactly 2 arguments: "
|
||||||
|
"a module/lazy module object and the name of "
|
||||||
|
"a callable to be lazily loaded."
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
self.module, self.cname = args
|
||||||
|
self.modclass = type(self.module)
|
||||||
|
self.callable = None
|
||||||
|
# Need to save these, since the module-loading gets rid of them
|
||||||
|
self.error_msgs = self.modclass._lazy_import_error_msgs
|
||||||
|
self.error_strings = self.modclass._lazy_import_error_strings
|
||||||
|
|
||||||
|
def __call__(self, *args, **kwargs):
|
||||||
|
# No need to go through all the reloading more than once.
|
||||||
|
if self.callable:
|
||||||
|
return self.callable(*args, **kwargs)
|
||||||
|
try:
|
||||||
|
del self.modclass._lazy_import_callables[self.cname]
|
||||||
|
except (AttributeError, KeyError):
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
self.callable = getattr(self.module, self.cname)
|
||||||
|
except AttributeError:
|
||||||
|
msg = self.error_msgs["msg_callable"]
|
||||||
|
raise_from(
|
||||||
|
AttributeError(msg.format(callable=self.cname, **self.error_strings)),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
except ImportError as err:
|
||||||
|
# Import failed. We reset the dict and re-raise the ImportError.
|
||||||
|
try:
|
||||||
|
self.modclass._lazy_import_callables[self.cname] = self
|
||||||
|
except AttributeError:
|
||||||
|
self.modclass._lazy_import_callables = {self.cname: self}
|
||||||
|
raise_from(err, None)
|
||||||
|
else:
|
||||||
|
return self.callable(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
### Functions ###
|
||||||
|
|
||||||
|
|
||||||
|
def lazy_module(modname, error_strings=None, lazy_mod_class=LazyModule, level="leaf"):
|
||||||
|
"""Function allowing lazy importing of a module into the namespace.
|
||||||
|
A lazy module object is created, registered in `sys.modules`, and
|
||||||
|
returned. This is a hollow module; actual loading, and `ImportErrors` if
|
||||||
|
not found, are delayed until an attempt is made to access attributes of the
|
||||||
|
lazy module.
|
||||||
|
A handy application is to use :func:`lazy_module` early in your own code
|
||||||
|
(say, in `__init__.py`) to register all modulenames you want to be lazy.
|
||||||
|
Because of registration in `sys.modules` later invocations of
|
||||||
|
`import modulename` will also return the lazy object. This means that after
|
||||||
|
initial registration the rest of your code can use regular Python import
|
||||||
|
statements and retain the laziness of the modules.
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
modname : str
|
||||||
|
The module to import.
|
||||||
|
error_strings : dict, optional
|
||||||
|
A dictionary of strings to use when module-loading fails. Key 'msg'
|
||||||
|
sets the message to use (defaults to :attr:`lazy_import._MSG`). The
|
||||||
|
message is formatted using the remaining dictionary keys. The default
|
||||||
|
message informs the user of which module is missing (key 'module'),
|
||||||
|
what code loaded the module as lazy (key 'caller'), and which package
|
||||||
|
should be installed to solve the dependency (key 'install_name').
|
||||||
|
None of the keys is mandatory and all are given smart names by default.
|
||||||
|
lazy_mod_class: type, optional
|
||||||
|
Which class to use when instantiating the lazy module, to allow
|
||||||
|
deep customization. The default is :class:`LazyModule` and custom
|
||||||
|
alternatives **must** be a subclass thereof.
|
||||||
|
level : str, optional
|
||||||
|
Which submodule reference to return. Either a reference to the 'leaf'
|
||||||
|
module (the default) or to the 'base' module. This is useful if you'll
|
||||||
|
be using the module functionality in the same place you're calling
|
||||||
|
:func:`lazy_module` from, since then you don't need to run `import`
|
||||||
|
again. Setting *level* does not affect which names/modules get
|
||||||
|
registered in `sys.modules`.
|
||||||
|
For *level* set to 'base' and *modulename* 'aaa.bbb.ccc'::
|
||||||
|
aaa = lazy_import.lazy_module("aaa.bbb.ccc", level='base')
|
||||||
|
# 'aaa' becomes defined in the current namespace, with
|
||||||
|
# (sub)attributes 'aaa.bbb' and 'aaa.bbb.ccc'.
|
||||||
|
# It's the lazy equivalent to:
|
||||||
|
import aaa.bbb.ccc
|
||||||
|
For *level* set to 'leaf'::
|
||||||
|
ccc = lazy_import.lazy_module("aaa.bbb.ccc", level='leaf')
|
||||||
|
# Only 'ccc' becomes set in the current namespace.
|
||||||
|
# Lazy equivalent to:
|
||||||
|
from aaa.bbb import ccc
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
module
|
||||||
|
The module specified by *modname*, or its base, depending on *level*.
|
||||||
|
The module isn't immediately imported. Instead, an instance of
|
||||||
|
*lazy_mod_class* is returned. Upon access to any of its attributes, the
|
||||||
|
module is finally loaded.
|
||||||
|
Examples
|
||||||
|
--------
|
||||||
|
>>> import lazy_import, sys
|
||||||
|
>>> np = lazy_import.lazy_module("numpy")
|
||||||
|
>>> np
|
||||||
|
Lazily-loaded module numpy
|
||||||
|
>>> np is sys.modules['numpy']
|
||||||
|
True
|
||||||
|
>>> np.pi # This causes the full loading of the module ...
|
||||||
|
3.141592653589793
|
||||||
|
>>> np # ... and the module is changed in place.
|
||||||
|
<module 'numpy' from '/usr/local/lib/python/site-packages/numpy/__init__.py'>
|
||||||
|
>>> import lazy_import, sys
|
||||||
|
>>> # The following succeeds even when asking for a module that's not available
|
||||||
|
>>> missing = lazy_import.lazy_module("missing_module")
|
||||||
|
>>> missing
|
||||||
|
Lazily-loaded module missing_module
|
||||||
|
>>> missing is sys.modules['missing_module']
|
||||||
|
True
|
||||||
|
>>> missing.some_attr # This causes the full loading of the module, which now fails.
|
||||||
|
ImportError: __main__ attempted to use a functionality that requires module missing_module, but it couldn't be loaded. Please install missing_module and retry.
|
||||||
|
See Also
|
||||||
|
--------
|
||||||
|
:func:`lazy_callable`
|
||||||
|
:class:`LazyModule`
|
||||||
|
"""
|
||||||
|
if error_strings is None:
|
||||||
|
error_strings = {}
|
||||||
|
_set_default_errornames(modname, error_strings)
|
||||||
|
|
||||||
|
mod = _lazy_module(modname, error_strings, lazy_mod_class)
|
||||||
|
if level == "base":
|
||||||
|
return sys.modules[module_basename(modname)]
|
||||||
|
elif level == "leaf":
|
||||||
|
return mod
|
||||||
|
else:
|
||||||
|
raise ValueError("Parameter 'level' must be one of ('base', 'leaf')")
|
||||||
|
|
||||||
|
|
||||||
|
def _lazy_module(modname, error_strings, lazy_mod_class):
|
||||||
|
with _ImportLockContext():
|
||||||
|
fullmodname = modname
|
||||||
|
fullsubmodname = None
|
||||||
|
# ensure parent module/package is in sys.modules
|
||||||
|
# and parent.modname=module, as soon as the parent is imported
|
||||||
|
while modname:
|
||||||
|
try:
|
||||||
|
mod = sys.modules[modname]
|
||||||
|
# We reached a (base) module that's already loaded. Let's stop
|
||||||
|
# the cycle. Can't use 'break' because we still want to go
|
||||||
|
# through the fullsubmodname check below.
|
||||||
|
modname = ""
|
||||||
|
except KeyError:
|
||||||
|
err_s = error_strings.copy()
|
||||||
|
err_s.setdefault("module", modname)
|
||||||
|
|
||||||
|
class _LazyModule(lazy_mod_class):
|
||||||
|
_lazy_import_error_msgs = {"msg": err_s.pop("msg")}
|
||||||
|
try:
|
||||||
|
_lazy_import_error_msgs["msg_callable"] = err_s.pop(
|
||||||
|
"msg_callable"
|
||||||
|
)
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
|
_lazy_import_error_strings = err_s
|
||||||
|
_lazy_import_callables = {}
|
||||||
|
_lazy_import_submodules = {}
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "Lazily-loaded module {}".format(self.__name__)
|
||||||
|
|
||||||
|
# A bit of cosmetic, to make AttributeErrors read more natural
|
||||||
|
_LazyModule.__name__ = "module"
|
||||||
|
# Actual module instantiation
|
||||||
|
mod = sys.modules[modname] = _LazyModule(modname)
|
||||||
|
# No need for __spec__. Maybe in the future.
|
||||||
|
# if ModuleSpec:
|
||||||
|
# ModuleType.__setattr__(mod, '__spec__',
|
||||||
|
# ModuleSpec(modname, None))
|
||||||
|
if fullsubmodname:
|
||||||
|
submod = sys.modules[fullsubmodname]
|
||||||
|
ModuleType.__setattr__(mod, submodname, submod)
|
||||||
|
_LazyModule._lazy_import_submodules[submodname] = submod
|
||||||
|
fullsubmodname = modname
|
||||||
|
modname, _, submodname = modname.rpartition(".")
|
||||||
|
return sys.modules[fullmodname]
|
||||||
|
|
||||||
|
|
||||||
|
def lazy_callable(modname, *names, **kwargs):
|
||||||
|
"""Performs lazy importing of one or more callables.
|
||||||
|
:func:`lazy_callable` creates functions that are thin wrappers that pass
|
||||||
|
any and all arguments straight to the target module's callables. These can
|
||||||
|
be functions or classes. The full loading of that module is only actually
|
||||||
|
triggered when the returned lazy function itself is called. This lazy
|
||||||
|
import of the target module uses the same mechanism as
|
||||||
|
:func:`lazy_module`.
|
||||||
|
|
||||||
|
If, however, the target module has already been fully imported prior
|
||||||
|
to invocation of :func:`lazy_callable`, then the target callables
|
||||||
|
themselves are returned and no lazy imports are made.
|
||||||
|
:func:`lazy_function` and :func:`lazy_class` are aliases of
|
||||||
|
:func:`lazy_callable`.
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
modname : str
|
||||||
|
The base module from where to import the callable(s) in *names*,
|
||||||
|
or a full 'module_name.callable_name' string.
|
||||||
|
names : str (optional)
|
||||||
|
The callable name(s) to import from the module specified by *modname*.
|
||||||
|
If left empty, *modname* is assumed to also include the callable name
|
||||||
|
to import.
|
||||||
|
error_strings : dict, optional
|
||||||
|
A dictionary of strings to use when reporting loading errors (either a
|
||||||
|
missing module, or a missing callable name in the loaded module).
|
||||||
|
*error_string* follows the same usage as described under
|
||||||
|
:func:`lazy_module`, with the exceptions that 1) a further key,
|
||||||
|
'msg_callable', can be supplied to be used as the error when a module
|
||||||
|
is successfully loaded but the target callable can't be found therein
|
||||||
|
(defaulting to :attr:`lazy_import._MSG_CALLABLE`); 2) a key 'callable'
|
||||||
|
is always added with the callable name being loaded.
|
||||||
|
lazy_mod_class : type, optional
|
||||||
|
See definition under :func:`lazy_module`.
|
||||||
|
lazy_call_class : type, optional
|
||||||
|
Analogously to *lazy_mod_class*, allows setting a custom class to
|
||||||
|
handle lazy callables, other than the default :class:`LazyCallable`.
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
wrapper function or tuple of wrapper functions
|
||||||
|
If *names* is passed, returns a tuple of wrapper functions, one for
|
||||||
|
each element in *names*.
|
||||||
|
If only *modname* is passed it is assumed to be a full
|
||||||
|
'module_name.callable_name' string, in which case the wrapper for the
|
||||||
|
imported callable is returned directly, and not in a tuple.
|
||||||
|
|
||||||
|
Notes
|
||||||
|
-----
|
||||||
|
Unlike :func:`lazy_module`, which returns a lazy module that eventually
|
||||||
|
mutates into the fully-functional version, :func:`lazy_callable` only
|
||||||
|
returns thin wrappers that never change. This means that the returned
|
||||||
|
wrapper object never truly becomes the one under the module's namespace,
|
||||||
|
even after successful loading of the module in *modname*. This is fine for
|
||||||
|
most practical use cases, but may break code that relies on the usage of
|
||||||
|
the returned objects other than calling them. One such example is the lazy
|
||||||
|
import of a class: it's fine to use the returned wrapper to instantiate an
|
||||||
|
object, but it can't be used, for instance, to subclass from.
|
||||||
|
Examples
|
||||||
|
--------
|
||||||
|
>>> import lazy_import, sys
|
||||||
|
>>> fn = lazy_import.lazy_callable("numpy.arange")
|
||||||
|
>>> sys.modules['numpy']
|
||||||
|
Lazily-loaded module numpy
|
||||||
|
>>> fn(10)
|
||||||
|
array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
|
||||||
|
>>> sys.modules['numpy']
|
||||||
|
<module 'numpy' from '/usr/local/lib/python3.5/site-packages/numpy/__init__.py'>
|
||||||
|
>>> import lazy_import, sys
|
||||||
|
>>> cl = lazy_import.lazy_callable("numpy.ndarray") # a class
|
||||||
|
>>> obj = cl([1, 2]) # This works OK (and also triggers the loading of numpy)
|
||||||
|
>>> class MySubclass(cl): # This fails because cl is just a wrapper,
|
||||||
|
>>> pass # not an actual class.
|
||||||
|
See Also
|
||||||
|
--------
|
||||||
|
:func:`lazy_module`
|
||||||
|
:class:`LazyCallable`
|
||||||
|
:class:`LazyModule`
|
||||||
|
"""
|
||||||
|
if not names:
|
||||||
|
modname, _, name = modname.rpartition(".")
|
||||||
|
lazy_mod_class = _setdef(kwargs, "lazy_mod_class", LazyModule)
|
||||||
|
lazy_call_class = _setdef(kwargs, "lazy_call_class", LazyCallable)
|
||||||
|
error_strings = _setdef(kwargs, "error_strings", {})
|
||||||
|
_set_default_errornames(modname, error_strings, call=True)
|
||||||
|
|
||||||
|
if not names:
|
||||||
|
# We allow passing a single string as 'modname.callable_name',
|
||||||
|
# in which case the wrapper is returned directly and not as a list.
|
||||||
|
return _lazy_callable(
|
||||||
|
modname, name, error_strings.copy(), lazy_mod_class, lazy_call_class
|
||||||
|
)
|
||||||
|
return tuple(
|
||||||
|
_lazy_callable(
|
||||||
|
modname, cname, error_strings.copy(), lazy_mod_class, lazy_call_class
|
||||||
|
)
|
||||||
|
for cname in names
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
lazy_function = lazy_class = lazy_callable
|
||||||
|
|
||||||
|
|
||||||
|
def _lazy_callable(modname, cname, error_strings, lazy_mod_class, lazy_call_class):
|
||||||
|
# We could do most of this in the LazyCallable __init__, but here we can
|
||||||
|
# pre-check whether to actually be lazy or not.
|
||||||
|
module = _lazy_module(modname, error_strings, lazy_mod_class)
|
||||||
|
modclass = type(module)
|
||||||
|
if issubclass(modclass, LazyModule) and hasattr(modclass, "_lazy_import_callables"):
|
||||||
|
modclass._lazy_import_callables.setdefault(
|
||||||
|
cname, lazy_call_class(module, cname)
|
||||||
|
)
|
||||||
|
return getattr(module, cname)
|
||||||
|
|
||||||
|
|
||||||
|
#######################
|
||||||
|
# Real module loading #
|
||||||
|
#######################
|
||||||
|
|
||||||
|
|
||||||
|
def _load_module(module):
|
||||||
|
"""Ensures that a module, and its parents, are properly loaded
|
||||||
|
"""
|
||||||
|
modclass = type(module)
|
||||||
|
# We only take care of our own LazyModule instances
|
||||||
|
if not issubclass(modclass, LazyModule):
|
||||||
|
raise TypeError("Passed module is not a LazyModule instance.")
|
||||||
|
with _ImportLockContext():
|
||||||
|
parent, _, modname = module.__name__.rpartition(".")
|
||||||
|
logger.debug("loading module {}".format(modname))
|
||||||
|
# We first identify whether this is a loadable LazyModule, then we
|
||||||
|
# strip as much of lazy_import behavior as possible (keeping it cached,
|
||||||
|
# in case loading fails and we need to reset the lazy state).
|
||||||
|
if not hasattr(modclass, "_lazy_import_error_msgs"):
|
||||||
|
# Already loaded (no _lazy_import_error_msgs attr). Not reloading.
|
||||||
|
return
|
||||||
|
# First, ensure the parent is loaded (using recursion; *very* unlikely
|
||||||
|
# we'll ever hit a stack limit in this case).
|
||||||
|
modclass._LOADING = True
|
||||||
|
try:
|
||||||
|
if parent:
|
||||||
|
logger.debug("first loading parent module {}".format(parent))
|
||||||
|
setattr(sys.modules[parent], modname, module)
|
||||||
|
if not hasattr(modclass, "_LOADING"):
|
||||||
|
logger.debug("Module {} already loaded by the parent".format(modname))
|
||||||
|
# We've been loaded by the parent. Let's bail.
|
||||||
|
return
|
||||||
|
cached_data = _clean_lazymodule(module)
|
||||||
|
try:
|
||||||
|
# Get Python to do the real import!
|
||||||
|
reload_module(module)
|
||||||
|
except:
|
||||||
|
# Loading failed. We reset our lazy state.
|
||||||
|
logger.debug("Failed to load module {}. Resetting...".format(modname))
|
||||||
|
_reset_lazymodule(module, cached_data)
|
||||||
|
raise
|
||||||
|
else:
|
||||||
|
# Successful load
|
||||||
|
logger.debug("Successfully loaded module {}".format(modname))
|
||||||
|
delattr(modclass, "_LOADING")
|
||||||
|
_reset_lazy_submod_refs(module)
|
||||||
|
|
||||||
|
except (AttributeError, ImportError) as err:
|
||||||
|
logger.debug(
|
||||||
|
"Failed to load {}.\n{}: {}".format(
|
||||||
|
modname, err.__class__.__name__, err
|
||||||
|
)
|
||||||
|
)
|
||||||
|
logger.lazy_trace()
|
||||||
|
# Under Python 3 reloading our dummy LazyModule instances causes an
|
||||||
|
# AttributeError if the module can't be found. Would be preferable
|
||||||
|
# if we could always rely on an ImportError. As it is we vet the
|
||||||
|
# AttributeError as thoroughly as possible.
|
||||||
|
if (six.PY3 and isinstance(err, AttributeError)) and not err.args[
|
||||||
|
0
|
||||||
|
] == "'NoneType' object has no attribute 'name'":
|
||||||
|
# Not the AttributeError we were looking for.
|
||||||
|
raise
|
||||||
|
msg = modclass._lazy_import_error_msgs["msg"]
|
||||||
|
raise_from(
|
||||||
|
ImportError(msg.format(**modclass._lazy_import_error_strings)), None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
##############################
|
||||||
|
# Helper functions/constants #
|
||||||
|
##############################
|
||||||
|
|
||||||
|
_MSG = (
|
||||||
|
"{caller} attempted to use a functionality that requires module "
|
||||||
|
"{module}, but it couldn't be loaded. Please install {install_name} "
|
||||||
|
"and retry."
|
||||||
|
)
|
||||||
|
|
||||||
|
_MSG_CALLABLE = (
|
||||||
|
"{caller} attempted to use a functionality that requires "
|
||||||
|
"{callable}, of module {module}, but it couldn't be found in that "
|
||||||
|
"module. Please install a version of {install_name} that has "
|
||||||
|
"{module}.{callable} and retry."
|
||||||
|
)
|
||||||
|
|
||||||
|
_CLS_ATTRS = (
|
||||||
|
"_lazy_import_error_strings",
|
||||||
|
"_lazy_import_error_msgs",
|
||||||
|
"_lazy_import_callables",
|
||||||
|
"_lazy_import_submodules",
|
||||||
|
"__repr__",
|
||||||
|
)
|
||||||
|
|
||||||
|
_DELETION_DICT = ("_lazy_import_submodules",)
|
||||||
|
|
||||||
|
|
||||||
|
def _setdef(argdict, name, defaultvalue):
|
||||||
|
"""Like dict.setdefault but sets the default value also if None is present.
|
||||||
|
"""
|
||||||
|
if not name in argdict or argdict[name] is None:
|
||||||
|
argdict[name] = defaultvalue
|
||||||
|
return argdict[name]
|
||||||
|
|
||||||
|
|
||||||
|
def module_basename(modname):
|
||||||
|
return modname.partition(".")[0]
|
||||||
|
|
||||||
|
|
||||||
|
def _set_default_errornames(modname, error_strings, call=False):
|
||||||
|
# We don't set the modulename default here because it will change for
|
||||||
|
# parents of lazily imported submodules.
|
||||||
|
error_strings.setdefault("caller", _caller_name(3, default="Python"))
|
||||||
|
error_strings.setdefault("install_name", module_basename(modname))
|
||||||
|
error_strings.setdefault("msg", _MSG)
|
||||||
|
if call:
|
||||||
|
error_strings.setdefault("msg_callable", _MSG_CALLABLE)
|
||||||
|
|
||||||
|
|
||||||
|
def _caller_name(depth=2, default=""):
|
||||||
|
"""Returns the name of the calling namespace.
|
||||||
|
"""
|
||||||
|
# the presence of sys._getframe might be implementation-dependent.
|
||||||
|
# It isn't that serious if we can't get the caller's name.
|
||||||
|
try:
|
||||||
|
return sys._getframe(depth).f_globals["__name__"]
|
||||||
|
except AttributeError:
|
||||||
|
return default
|
||||||
|
|
||||||
|
|
||||||
|
def _clean_lazymodule(module):
|
||||||
|
"""Removes all lazy behavior from a module's class, for loading.
|
||||||
|
Also removes all module attributes listed under the module's class deletion
|
||||||
|
dictionaries. Deletion dictionaries are class attributes with names
|
||||||
|
specified in `_DELETION_DICT`.
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
module: LazyModule
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
dict
|
||||||
|
A dictionary of deleted class attributes, that can be used to reset the
|
||||||
|
lazy state using :func:`_reset_lazymodule`.
|
||||||
|
"""
|
||||||
|
modclass = type(module)
|
||||||
|
_clean_lazy_submod_refs(module)
|
||||||
|
|
||||||
|
modclass.__getattribute__ = ModuleType.__getattribute__
|
||||||
|
modclass.__setattr__ = ModuleType.__setattr__
|
||||||
|
cls_attrs = {}
|
||||||
|
for cls_attr in _CLS_ATTRS:
|
||||||
|
try:
|
||||||
|
cls_attrs[cls_attr] = getattr(modclass, cls_attr)
|
||||||
|
delattr(modclass, cls_attr)
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
return cls_attrs
|
||||||
|
|
||||||
|
|
||||||
|
def _clean_lazy_submod_refs(module):
|
||||||
|
modclass = type(module)
|
||||||
|
for deldict in _DELETION_DICT:
|
||||||
|
try:
|
||||||
|
delnames = getattr(modclass, deldict)
|
||||||
|
except AttributeError:
|
||||||
|
continue
|
||||||
|
for delname in delnames:
|
||||||
|
try:
|
||||||
|
super(LazyModule, module).__delattr__(delname)
|
||||||
|
except AttributeError:
|
||||||
|
# Maybe raise a warning?
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _reset_lazymodule(module, cls_attrs):
|
||||||
|
"""Resets a module's lazy state from cached data.
|
||||||
|
"""
|
||||||
|
modclass = type(module)
|
||||||
|
del modclass.__getattribute__
|
||||||
|
del modclass.__setattr__
|
||||||
|
try:
|
||||||
|
del modclass._LOADING
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
for cls_attr in _CLS_ATTRS:
|
||||||
|
try:
|
||||||
|
setattr(modclass, cls_attr, cls_attrs[cls_attr])
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
|
_reset_lazy_submod_refs(module)
|
||||||
|
|
||||||
|
|
||||||
|
def _reset_lazy_submod_refs(module):
|
||||||
|
modclass = type(module)
|
||||||
|
for deldict in _DELETION_DICT:
|
||||||
|
try:
|
||||||
|
resetnames = getattr(modclass, deldict)
|
||||||
|
except AttributeError:
|
||||||
|
continue
|
||||||
|
for name, submod in resetnames.items():
|
||||||
|
super(LazyModule, module).__setattr__(name, submod)
|
||||||
|
|
||||||
|
|
||||||
|
def run_from_ipython():
|
||||||
|
# Taken from https://stackoverflow.com/questions/5376837
|
||||||
|
try:
|
||||||
|
__IPYTHON__
|
||||||
|
return True
|
||||||
|
except NameError:
|
||||||
|
return False
|
||||||
|
|
@ -0,0 +1,46 @@
|
||||||
|
# Code copied from https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/util/lazy_loader.py
|
||||||
|
"""A LazyLoader class."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
import types
|
||||||
|
|
||||||
|
|
||||||
|
class LazyLoader(types.ModuleType):
|
||||||
|
"""Lazily import a module, mainly to avoid pulling in large dependencies.
|
||||||
|
|
||||||
|
`contrib`, and `ffmpeg` are examples of modules that are large and not always
|
||||||
|
needed, and this allows them to only be loaded when they are used.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# The lint error here is incorrect.
|
||||||
|
def __init__(
|
||||||
|
self, local_name, parent_module_globals, name
|
||||||
|
): # pylint: disable=super-on-old-class
|
||||||
|
self._local_name = local_name
|
||||||
|
self._parent_module_globals = parent_module_globals
|
||||||
|
|
||||||
|
super(LazyLoader, self).__init__(name)
|
||||||
|
|
||||||
|
def _load(self):
|
||||||
|
# Import the target module and insert it into the parent's namespace
|
||||||
|
module = importlib.import_module(self.__name__)
|
||||||
|
self._parent_module_globals[self._local_name] = module
|
||||||
|
|
||||||
|
# Update this object's dict so that if someone keeps a reference to the
|
||||||
|
# LazyLoader, lookups are efficient (__getattr__ is only called on lookups
|
||||||
|
# that fail).
|
||||||
|
self.__dict__.update(module.__dict__)
|
||||||
|
|
||||||
|
return module
|
||||||
|
|
||||||
|
def __getattr__(self, item):
|
||||||
|
module = self._load()
|
||||||
|
return getattr(module, item)
|
||||||
|
|
||||||
|
def __dir__(self):
|
||||||
|
module = self._load()
|
||||||
|
return dir(module)
|
||||||
|
|
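A short usage sketch of the loader above, assuming it is imported from this module; the target module name (`numpy`) is only an example. Nothing is imported until the first attribute access, at which point the parent namespace is patched with the real module:

```python
# Hypothetical usage of the LazyLoader class above.
np = LazyLoader("np", globals(), "numpy")   # no import happens here
print(np.pi)            # first attribute access triggers importlib.import_module("numpy")
print(globals()["np"])  # the parent namespace now holds the real numpy module
```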
@ -0,0 +1,31 @@
|
||||||
|
from plume.utils import lazy_module
|
||||||
|
import typer
|
||||||
|
|
||||||
|
rpyc = lazy_module('rpyc')
|
||||||
|
|
||||||
|
app = typer.Typer()
|
||||||
|
|
||||||
|
|
||||||
|
class ASRService(rpyc.Service):
|
||||||
|
def __init__(self, asr_recognizer):
|
||||||
|
self.asr = asr_recognizer
|
||||||
|
|
||||||
|
def on_connect(self, conn):
|
||||||
|
# code that runs when a connection is created
|
||||||
|
# (to init the service, if needed)
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_disconnect(self, conn):
|
||||||
|
# code that runs after the connection has already closed
|
||||||
|
# (to finalize the service, if needed)
|
||||||
|
pass
|
||||||
|
|
||||||
|
def exposed_transcribe(self, utterance: bytes): # this is an exposed method
|
||||||
|
speech_audio = self.asr.transcribe(utterance)
|
||||||
|
return speech_audio
|
||||||
|
|
||||||
|
def exposed_transcribe_cb(
|
||||||
|
self, utterance: bytes, respond
|
||||||
|
): # this is an exposed method
|
||||||
|
speech_audio = self.asr.transcribe(utterance)
|
||||||
|
respond(speech_audio)
|
||||||
|
|
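The class above only defines the exposed methods; a sketch of actually serving it, assuming an `asr_recognizer` object with a `transcribe(wav_bytes) -> text` method and assuming the same 8044 port default the rpyc client uses elsewhere in the package:

```python
# Hypothetical server bootstrap for the ASRService above.
from rpyc.utils.server import ThreadedServer

def serve(asr_recognizer, port=8044):
    # rpyc >= 4 accepts a Service *instance*, so the loaded model is shared
    # across all incoming connections.
    ThreadedServer(ASRService(asr_recognizer), port=port).start()
```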
@ -0,0 +1,184 @@
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
from io import BytesIO
|
||||||
|
from pathlib import Path
|
||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
|
import typer
|
||||||
|
# import rpyc
|
||||||
|
|
||||||
|
# from tqdm import tqdm
|
||||||
|
# from pydub import AudioSegment
|
||||||
|
# from pydub.silence import split_on_silence
|
||||||
|
from plume.utils import lazy_module, lazy_callable
|
||||||
|
|
||||||
|
rpyc = lazy_module('rpyc')
|
||||||
|
AudioSegment = lazy_callable('pydub.AudioSegment')
|
||||||
|
split_on_silence = lazy_callable('pydub.silence.split_on_silence')
|
||||||
|
|
||||||
|
app = typer.Typer()
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||||
|
)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
ASR_RPYC_HOST = os.environ.get("JASR_RPYC_HOST", "localhost")
|
||||||
|
ASR_RPYC_PORT = int(os.environ.get("ASR_RPYC_PORT", "8044"))
|
||||||
|
|
||||||
|
TRITON_ASR_MODEL = os.environ.get("TRITON_ASR_MODEL", "slu_wav2vec2")
|
||||||
|
|
||||||
|
TRITON_GRPC_ASR_HOST = os.environ.get("TRITON_GRPC_ASR_HOST", "localhost")
|
||||||
|
TRITON_GRPC_ASR_PORT = int(os.environ.get("TRITON_GRPC_ASR_PORT", "8001"))
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def transcribe_rpyc_gen(asr_host=ASR_RPYC_HOST, asr_port=ASR_RPYC_PORT):
|
||||||
|
logger.info(f"connecting to asr server at {asr_host}:{asr_port}")
|
||||||
|
try:
|
||||||
|
asr = rpyc.connect(asr_host, asr_port).root
|
||||||
|
logger.info(f"connected to asr server successfully")
|
||||||
|
except ConnectionRefusedError:
|
||||||
|
raise Exception("env-var JASPER_ASR_RPYC_HOST invalid")
|
||||||
|
|
||||||
|
def audio_prep(aud_seg):
|
||||||
|
asr_seg = aud_seg.set_channels(1).set_sample_width(2).set_frame_rate(16000)
|
||||||
|
return asr_seg
|
||||||
|
|
||||||
|
return asr.transcribe, audio_prep
|
||||||
|
|
||||||
|
|
||||||
|
def triton_transcribe_grpc_gen(
|
||||||
|
asr_host=TRITON_GRPC_ASR_HOST,
|
||||||
|
asr_port=TRITON_GRPC_ASR_PORT,
|
||||||
|
asr_model=TRITON_ASR_MODEL,
|
||||||
|
method="chunked",
|
||||||
|
chunk_msec=5000,
|
||||||
|
sil_msec=500,
|
||||||
|
# overlap=False,
|
||||||
|
sep=" ",
|
||||||
|
):
|
||||||
|
from tritonclient.utils import np_to_triton_dtype
|
||||||
|
import tritonclient.grpc as grpcclient
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
sup_meth = ["chunked", "silence", "whole"]
|
||||||
|
if method not in sup_meth:
|
||||||
|
meths = "|".join(sup_meth)
|
||||||
|
raise Exception(f"unsupported method {method}. pick one of {meths}")
|
||||||
|
|
||||||
|
client = grpcclient.InferenceServerClient(f"{asr_host}:{asr_port}")
|
||||||
|
|
||||||
|
def transcriber(aud_seg):
|
||||||
|
af = BytesIO()
|
||||||
|
aud_seg.export(af, format="wav")
|
||||||
|
input_audio_bytes = af.getvalue()
|
||||||
|
input_audio_data = np.array([input_audio_bytes])
|
||||||
|
inputs = [
|
||||||
|
grpcclient.InferInput(
|
||||||
|
"INPUT_AUDIO",
|
||||||
|
input_audio_data.shape,
|
||||||
|
np_to_triton_dtype(input_audio_data.dtype),
|
||||||
|
)
|
||||||
|
]
|
||||||
|
inputs[0].set_data_from_numpy(input_audio_data)
|
||||||
|
outputs = [grpcclient.InferRequestedOutput("OUTPUT_TEXT")]
|
||||||
|
response = client.infer(asr_model, inputs, request_id=str(1), outputs=outputs)
|
||||||
|
transcript = response.as_numpy("OUTPUT_TEXT")[0]
|
||||||
|
return transcript.decode("utf-8")
|
||||||
|
|
||||||
|
def chunked_transcriber(aud_seg):
|
||||||
|
if method == "silence":
|
||||||
|
sil_chunks = split_on_silence(
|
||||||
|
aud_seg,
|
||||||
|
min_silence_len=sil_msec,
|
||||||
|
silence_thresh=-50,
|
||||||
|
keep_silence=500,
|
||||||
|
)
|
||||||
|
chunks = [sc for c in sil_chunks for sc in c[::chunk_msec]]
|
||||||
|
else:
|
||||||
|
chunks = aud_seg[::chunk_msec]
|
||||||
|
# if overlap:
|
||||||
|
# chunks = [
|
||||||
|
# aud_seg[start, end]
|
||||||
|
# for start, end in range(0, int(aud_seg.duration_seconds * 1000, 1000))
|
||||||
|
# ]
|
||||||
|
# pass
|
||||||
|
transcript_list = []
|
||||||
|
sil_pad = AudioSegment.silent(duration=sil_msec)
|
||||||
|
for seg in chunks:
|
||||||
|
t_seg = sil_pad + seg + sil_pad
|
||||||
|
c_transcript = transcriber(t_seg)
|
||||||
|
transcript_list.append(c_transcript)
|
||||||
|
transcript = sep.join(transcript_list)
|
||||||
|
return transcript
|
||||||
|
|
||||||
|
def audio_prep(aud_seg):
|
||||||
|
asr_seg = aud_seg.set_channels(1).set_sample_width(2).set_frame_rate(16000)
|
||||||
|
return asr_seg
|
||||||
|
|
||||||
|
whole_transcriber = transcriber if method == "whole" else chunked_transcriber
|
||||||
|
return whole_transcriber, audio_prep
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def file(audio_file: Path, write_file: bool = False, chunked=True):
|
||||||
|
from pydub import AudioSegment
|
||||||
|
|
||||||
|
aseg = AudioSegment.from_file(audio_file)
|
||||||
|
transcriber, prep = triton_transcribe_grpc_gen()
|
||||||
|
transcription = transcriber(prep(aseg))
|
||||||
|
|
||||||
|
typer.echo(transcription)
|
||||||
|
if write_file:
|
||||||
|
tscript_file_path = audio_file.with_suffix(".txt")
|
||||||
|
with open(tscript_file_path, "w") as tf:
|
||||||
|
tf.write(transcription)
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def benchmark(audio_file: Path):
|
||||||
|
from pydub import AudioSegment
|
||||||
|
|
||||||
|
transcriber, audio_prep = transcribe_rpyc_gen()
|
||||||
|
file_seg = AudioSegment.from_file(audio_file)
|
||||||
|
aud_seg = audio_prep(file_seg)
|
||||||
|
|
||||||
|
def timeinfo():
|
||||||
|
from timeit import Timer
|
||||||
|
|
||||||
|
timer = Timer(lambda: transcriber(aud_seg))
|
||||||
|
number = 100
|
||||||
|
repeat = 10
|
||||||
|
time_taken = timer.repeat(repeat, number=number)
|
||||||
|
best = min(time_taken) * 1000 / number
|
||||||
|
print(f"{number} loops, best of {repeat}: {best:.3f} msec per loop")
|
||||||
|
|
||||||
|
timeinfo()
|
||||||
|
import time
|
||||||
|
|
||||||
|
time.sleep(5)
|
||||||
|
|
||||||
|
transcriber, audio_prep = triton_transcribe_grpc_gen()
|
||||||
|
aud_seg = audio_prep(file_seg)
|
||||||
|
|
||||||
|
def timeinfo():
|
||||||
|
from timeit import Timer
|
||||||
|
|
||||||
|
timer = Timer(lambda: transcriber(aud_seg))
|
||||||
|
number = 100
|
||||||
|
repeat = 10
|
||||||
|
time_taken = timer.repeat(repeat, number=number)
|
||||||
|
best = min(time_taken) * 1000 / number
|
||||||
|
print(f"{number} loops, best of {repeat}: {best:.3f} msec per loop")
|
||||||
|
|
||||||
|
timeinfo()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
app()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
@ -0,0 +1,92 @@
|
||||||
|
from logging import getLogger
|
||||||
|
from plume.utils import lazy_module
|
||||||
|
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import typer
|
||||||
|
|
||||||
|
# from google.cloud import texttospeech
|
||||||
|
texttospeech = lazy_module('google.cloud.texttospeech')
|
||||||
|
|
||||||
|
LOGGER = getLogger("googletts")
|
||||||
|
|
||||||
|
app = typer.Typer()
|
||||||
|
|
||||||
|
|
||||||
|
class GoogleTTS(object):
|
||||||
|
def __init__(self):
|
||||||
|
self.client = texttospeech.TextToSpeechClient()
|
||||||
|
|
||||||
|
def text_to_speech(self, text: str, params: dict) -> bytes:
|
||||||
|
tts_input = texttospeech.types.SynthesisInput(text=text)
|
||||||
|
voice = texttospeech.types.VoiceSelectionParams(
|
||||||
|
language_code=params["language"], name=params["name"]
|
||||||
|
)
|
||||||
|
audio_config = texttospeech.types.AudioConfig(
|
||||||
|
audio_encoding=texttospeech.enums.AudioEncoding.LINEAR16,
|
||||||
|
sample_rate_hertz=params["sample_rate"],
|
||||||
|
)
|
||||||
|
response = self.client.synthesize_speech(tts_input, voice, audio_config)
|
||||||
|
audio_content = response.audio_content
|
||||||
|
return audio_content
|
||||||
|
|
||||||
|
def ssml_to_speech(self, text: str, params: dict) -> bytes:
|
||||||
|
tts_input = texttospeech.types.SynthesisInput(ssml=text)
|
||||||
|
voice = texttospeech.types.VoiceSelectionParams(
|
||||||
|
language_code=params["language"], name=params["name"]
|
||||||
|
)
|
||||||
|
audio_config = texttospeech.types.AudioConfig(
|
||||||
|
audio_encoding=texttospeech.enums.AudioEncoding.LINEAR16,
|
||||||
|
sample_rate_hertz=params["sample_rate"],
|
||||||
|
)
|
||||||
|
response = self.client.synthesize_speech(tts_input, voice, audio_config)
|
||||||
|
audio_content = response.audio_content
|
||||||
|
return audio_content
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def voice_list(cls):
|
||||||
|
"""Lists the available voices."""
|
||||||
|
|
||||||
|
client = cls().client
|
||||||
|
|
||||||
|
# Performs the list voices request
|
||||||
|
voices = client.list_voices()
|
||||||
|
results = []
|
||||||
|
for voice in voices.voices:
|
||||||
|
supported_eng_langs = [
|
||||||
|
lang for lang in voice.language_codes if lang[:2] == "en"
|
||||||
|
]
|
||||||
|
if len(supported_eng_langs) > 0:
|
||||||
|
lang = ",".join(supported_eng_langs)
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
|
||||||
|
ssml_gender = texttospeech.enums.SsmlVoiceGender(voice.ssml_gender)
|
||||||
|
results.append(
|
||||||
|
{
|
||||||
|
"name": voice.name,
|
||||||
|
"language": lang,
|
||||||
|
"gender": ssml_gender.name,
|
||||||
|
"engine": "wavenet" if "Wav" in voice.name else "standard",
|
||||||
|
"sample_rate": voice.natural_sample_rate_hertz,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def generate_audio_file(text, dest_path: Path = "./tts_audio.wav", voice="en-US-Wavenet-D"):
|
||||||
|
tts = GoogleTTS()
|
||||||
|
selected_voice = [v for v in tts.voice_list() if v["name"] == voice][0]
|
||||||
|
wav_data = tts.text_to_speech(text, selected_voice)
|
||||||
|
with dest_path.open("wb") as wf:
|
||||||
|
wf.write(wav_data)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
app()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
103
setup.py
103
setup.py
|
|
@ -1,81 +1,80 @@
|
||||||
from setuptools import setup, find_packages
|
from setuptools import setup, find_packages
|
||||||
|
|
||||||
|
# pip install "nvidia-pyindex~=1.0.5"
|
||||||
|
|
||||||
requirements = [
|
requirements = [
|
||||||
"ruamel.yaml",
|
"torch~=1.6.0",
|
||||||
"torch==1.4.0",
|
"torchvision~=0.7.0",
|
||||||
"torchvision==0.5.0",
|
|
||||||
"nemo_toolkit[asr] @ git+https://github.com/NVIDIA/NeMo.git@09e3ba4dfe333f86d6c5c1048e07210924294be9#egg=nemo_toolkit",
|
"nemo_toolkit[asr] @ git+https://github.com/NVIDIA/NeMo.git@09e3ba4dfe333f86d6c5c1048e07210924294be9#egg=nemo_toolkit",
|
||||||
|
"fairseq @ git+https://github.com/pytorch/fairseq.git@94a1b924f3adec25c8c508ac112410d02b400d1e#egg=fairseq",
|
||||||
|
# "google-cloud-texttospeech~=1.0.1",
|
||||||
|
"tqdm~=4.54.0",
|
||||||
|
# "pydub~=0.24.0",
|
||||||
|
# "scikit_learn~=0.22.1",
|
||||||
|
# "pandas~=1.0.3",
|
||||||
|
# "boto3~=1.12.35",
|
||||||
|
# "ruamel.yaml~=0.16.10",
|
||||||
|
# "pymongo==3.10.1",
|
||||||
|
# "matplotlib==3.2.1",
|
||||||
|
# "tabulate==0.8.7",
|
||||||
|
# "natural==0.2.0",
|
||||||
|
# "num2words==0.5.10",
|
||||||
|
"typer[all]~=0.3.2",
|
||||||
|
# "python-slugify==4.0.0",
|
||||||
|
# "websockets==8.1",
|
||||||
|
# "lenses @ git+https://github.com/ingolemo/python-lenses.git@b2a2a9aa5b61540992d70b2cf36008d0121e8948#egg=lenses",
|
||||||
|
"rpyc~=4.1.4",
|
||||||
|
# "streamlit~=0.61.0",
|
||||||
|
# "librosa~=0.7.2",
|
||||||
|
# "tritonclient[http]~=2.6.0",
|
||||||
|
"numba~=0.48.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
extra_requirements = {
|
extra_requirements = {
|
||||||
"server": ["rpyc~=4.1.4", "tqdm~=4.39.0"],
|
|
||||||
"data": [
|
"data": [
|
||||||
"google-cloud-texttospeech~=1.0.1",
|
|
||||||
"tqdm~=4.39.0",
|
|
||||||
"pydub~=0.24.0",
|
"pydub~=0.24.0",
|
||||||
|
"google-cloud-texttospeech~=1.0.1",
|
||||||
"scikit_learn~=0.22.1",
|
"scikit_learn~=0.22.1",
|
||||||
"pandas~=1.0.3",
|
"pandas~=1.0.3",
|
||||||
"boto3~=1.12.35",
|
"boto3~=1.12.35",
|
||||||
"ruamel.yaml==0.16.10",
|
"ruamel.yaml~=0.16.10",
|
||||||
"pymongo==3.10.1",
|
"pymongo~=3.10.1",
|
||||||
"librosa==0.7.2",
|
"librosa~=0.7.2",
|
||||||
"numba==0.48",
|
"matplotlib~=3.2.1",
|
||||||
"matplotlib==3.2.1",
|
"pandas~=1.0.3",
|
||||||
"pandas==1.0.3",
|
"tabulate~=0.8.7",
|
||||||
"tabulate==0.8.7",
|
"natural~=0.2.0",
|
||||||
"natural==0.2.0",
|
"num2words~=0.5.10",
|
||||||
"num2words==0.5.10",
|
"python-slugify~=4.0.0",
|
||||||
"typer[all]==0.3.1",
|
|
||||||
"python-slugify==4.0.0",
|
|
||||||
"rpyc~=4.1.4",
|
"rpyc~=4.1.4",
|
||||||
"lenses @ git+https://github.com/ingolemo/python-lenses.git@b2a2a9aa5b61540992d70b2cf36008d0121e8948#egg=lenses",
|
# "lenses @ git+https://github.com/ingolemo/python-lenses.git@b2a2a9aa5b61540992d70b2cf36008d0121e8948#egg=lenses",
|
||||||
],
|
],
|
||||||
"validation": [
|
"validation": [
|
||||||
"rpyc~=4.1.4",
|
"pymongo~=3.10.1",
|
||||||
"pymongo==3.10.1",
|
"matplotlib~=3.2.1",
|
||||||
"typer[all]==0.1.1",
|
|
||||||
"tqdm~=4.39.0",
|
|
||||||
"librosa==0.7.2",
|
|
||||||
"matplotlib==3.2.1",
|
|
||||||
"pydub~=0.24.0",
|
"pydub~=0.24.0",
|
||||||
"streamlit==0.58.0",
|
"streamlit~=0.58.0",
|
||||||
"natural==0.2.0",
|
"natural~=0.2.0",
|
||||||
"stringcase==1.2.0",
|
"stringcase~=1.2.0",
|
||||||
"google-cloud-speech~=1.3.1",
|
"google-cloud-speech~=1.3.1",
|
||||||
]
|
],
|
||||||
# "train": [
|
"train": ["torchaudio~=0.6.0", "torch-stft~=0.1.4"],
|
||||||
# "torchaudio==0.5.0",
|
|
||||||
# "torch-stft==0.1.4",
|
|
||||||
# ]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
extra_requirements["all"] = list({d for l in extra_requirements.values() for d in l})
|
||||||
packages = find_packages()
|
packages = find_packages()
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
name="jasper-asr",
|
name="plume-asr",
|
||||||
version="0.1",
|
version="0.11",
|
||||||
description="Tool to get gcp alignments of tts-data",
|
description="Multi model ASR base package",
|
||||||
url="http://github.com/malarinv/jasper-asr",
|
url="http://github.com/malarinv/plume-asr",
|
||||||
author="Malar Kannan",
|
author="Malar Kannan",
|
||||||
author_email="malarkannan.invention@gmail.com",
|
author_email="malarkannan.invention@gmail.com",
|
||||||
license="MIT",
|
license="MIT",
|
||||||
install_requires=requirements,
|
install_requires=requirements,
|
||||||
extras_require=extra_requirements,
|
extras_require=extra_requirements,
|
||||||
packages=packages,
|
packages=packages,
|
||||||
entry_points={
|
entry_points={"console_scripts": ["plume = plume.cli:main"]},
|
||||||
"console_scripts": [
|
|
||||||
"jasper_transcribe = jasper.transcribe:main",
|
|
||||||
"jasper_server = jasper.server:main",
|
|
||||||
"jasper_trainer = jasper.training.cli:main",
|
|
||||||
"jasper_evaluator = jasper.evaluate:main",
|
|
||||||
"jasper_data_tts_generate = jasper.data.tts_generator:main",
|
|
||||||
"jasper_data_conv_generate = jasper.data.conv_generator:main",
|
|
||||||
"jasper_data_nlu_generate = jasper.data.nlu_generator:main",
|
|
||||||
"jasper_data_rastrik_recycle = jasper.data.rastrik_recycler:main",
|
|
||||||
"jasper_data_server = jasper.data.server:main",
|
|
||||||
"jasper_data_validation = jasper.data.validation.process:main",
|
|
||||||
"jasper_data_preprocess = jasper.data.process:main",
|
|
||||||
"jasper_data_slu_evaluate = jasper.data.slu_evaluator:main",
|
|
||||||
]
|
|
||||||
},
|
|
||||||
zip_safe=False,
|
zip_safe=False,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,3 +0,0 @@
|
||||||
import runpy
|
|
||||||
|
|
||||||
runpy.run_module("jasper.data.validation.ui", run_name="__main__", alter_sys=True)
|
|
||||||