plume-asr/plume/utils/__init__.py

487 lines
15 KiB
Python

import io
import os
import re
import json
import wave
import logging
from pathlib import Path
from functools import partial
from uuid import uuid4
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
import subprocess
import shutil
from urllib.parse import urlsplit
# from .lazy_loader import LazyLoader
from .lazy_import import lazy_callable, lazy_module
# from ruamel.yaml import YAML
# import boto3
import typer
# import pymongo
# from slugify import slugify
# import pydub
# import matplotlib.pyplot as plt
# import librosa
# import librosa.display as audio_display
# from natural.date import compress
# from num2words import num2words
from tqdm import tqdm
from datetime import timedelta
# from .transcribe import triton_transcribe_grpc_gen
# from .eval import app as eval_app
from .tts import app as tts_app
from .transcribe import app as transcribe_app
from .align import app as align_app
# Heavy optional dependencies are bound lazily so that importing plume.utils
# stays cheap and the extras are only required when actually used.
boto3 = lazy_module('boto3')
pymongo = lazy_module('pymongo')
pydub = lazy_module('pydub')
audio_display = lazy_module('librosa.display')
plt = lazy_module('matplotlib.pyplot')
librosa = lazy_module('librosa')
YAML = lazy_callable('ruamel.yaml.YAML')
num2words = lazy_callable('num2words.num2words')
slugify = lazy_callable('slugify.slugify')
compress = lazy_callable('natural.date.compress')

# Root CLI application with the tts/align/transcribe sub-apps mounted.
app = typer.Typer()
app.add_typer(tts_app, name="tts")
app.add_typer(align_app, name="align")
app.add_typer(transcribe_app, name="transcribe")

# Module-wide logging setup; NOTE(review): basicConfig at import time affects
# the whole process's root logger — confirm this is intended for a library module.
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
def manifest_str(path, dur, text):
    """Serialize one ASR manifest entry as a JSON line (trailing newline included)."""
    entry = {"audio_filepath": path, "duration": round(dur, 1), "text": text}
    return json.dumps(entry) + "\n"
def duration_str(seconds):
    """Render a duration given in seconds as padded human-friendly text."""
    delta = timedelta(seconds=seconds)
    return compress(delta, pad=" ")
def replace_digit_symbol(w2v_out):
    """Lowercase the text and replace spelled-out digit words (zero..nine) with digits."""
    out = w2v_out.lower()
    for digit in range(10):
        # num2words(d) gives the English word; substitute every occurrence.
        out = re.sub(num2words(digit), str(digit), out)
    return out
def discard_except_digits(inp):
    """Strip every character that is not an ASCII digit 0-9."""
    return "".join(ch for ch in inp if "0" <= ch <= "9")
def digits_to_chars(text):
    """Spell out each ASCII digit as its English word plus a space; lowercase the result."""
    pieces = []
    for ch in text:
        pieces.append(num2words(ch) + " " if "0" <= ch <= "9" else ch)
    return "".join(pieces).lower()
def replace_redundant_spaces_with(text, sub):
    """Collapse every run of one-or-more spaces into the given substitute string."""
    runs_of_spaces = re.compile(r" +")
    return runs_of_spaces.sub(sub, text)
def space_out(text):
    """Insert a single space between every character of *text*."""
    # str is already an iterable of characters; wrapping it in list() first
    # (as the original did) allocates a throwaway list for no benefit.
    return " ".join(text)
def wav_bytes(audio_bytes, frame_rate=24000):
    """Wrap raw 16-bit mono PCM samples in a WAV container and return its bytes.

    The header's frame count is patched when the writer is closed, so the
    returned buffer is a complete, playable WAV file.
    """
    buffer = io.BytesIO()
    writer = wave.open(buffer, mode="w")
    try:
        writer.setnchannels(1)  # mono
        writer.setframerate(frame_rate)
        writer.setsampwidth(2)  # 16-bit samples
        writer.writeframesraw(audio_bytes)
    finally:
        writer.close()
    return buffer.getvalue()
def tscript_uuid_fname(transcript):
    """Build a unique filename stem: random UUID plus a short slug of the transcript."""
    slug = slugify(transcript, max_length=8)
    return f"{uuid4()}_{slug}"
def run_shell(cmd_str, work_dir="."):
    """Run *cmd_str* through the shell in *work_dir*, echoing its combined
    stdout/stderr line by line as it is produced.

    Returns the command's exit code.  (The original never reaped the child
    process — leaving a zombie until GC — and silently discarded failures;
    returning ``p.wait()`` fixes both and is backward compatible, since
    callers previously received ``None`` and ignored it.)
    """
    cwd_path = Path(work_dir).absolute()
    p = subprocess.Popen(
        cmd_str,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        shell=True,
        cwd=cwd_path,
    )
    for line in p.stdout:
        print(line.replace(b"\n", b"").decode("utf-8"))
    # Stream is exhausted once the child closes stdout; now collect its status.
    return p.wait()
def upload_s3(dataset_path, s3_path):
    """Sync a local dataset directory up to S3 using the aws CLI."""
    sync_cmd = f"aws s3 sync {dataset_path} {s3_path}"
    run_shell(sync_cmd)
def get_download_path(s3_uri, output_path):
    """Map an s3:// URI onto a local path under *output_path*, creating parent dirs."""
    key = urlsplit(s3_uri).path[1:]  # drop the leading "/"
    local_path = output_path / Path(key)
    local_path.parent.mkdir(exist_ok=True, parents=True)
    return local_path
def s3_downloader():
    """Create a ``download(s3_uri, download_path)`` closure bound to one boto3 S3 client."""
    client = boto3.client("s3")

    def download_s3(s3_uri, download_path):
        parts = urlsplit(s3_uri)
        download_path.parent.mkdir(exist_ok=True, parents=True)
        if download_path.exists():
            return  # already fetched; skip the network round-trip
        print(f"downloading {s3_uri} to {download_path}")
        client.download_file(parts.netloc, parts.path[1:], str(download_path))

    return download_s3
def asr_data_writer(dataset_dir, asr_data_source, verbose=False):
    """Persist (transcript, duration, wav_bytes) triples as an ASR dataset.

    Writes <dataset_dir>/wavs/<uuid>_<slug>.wav files plus a manifest.json of
    JSON lines, and returns the number of datapoints written.
    """
    wav_dir = dataset_dir / Path("wavs")
    wav_dir.mkdir(parents=True, exist_ok=True)
    asr_manifest = dataset_dir / Path("manifest.json")
    num_datapoints = 0
    with asr_manifest.open("w") as mf:
        print(f"writing manifest to {asr_manifest}")
        for transcript, audio_dur, wav_data in asr_data_source:
            audio_file = wav_dir / Path(tscript_uuid_fname(transcript)).with_suffix(".wav")
            audio_file.write_bytes(wav_data)
            rel_data_path = audio_file.relative_to(dataset_dir)
            mf.write(manifest_str(str(rel_data_path), audio_dur, transcript))
            if verbose:
                print(f"writing '{transcript}' of duration {audio_dur}")
            num_datapoints += 1
    return num_datapoints
def ui_data_generator(dataset_dir, asr_data_source, verbose=False):
    """Write wav files for each datapoint and build UI dump records with plots.

    *asr_data_source* yields (transcript, duration, wav_bytes, caller_name,
    audio_segment) tuples.  Returns (list of record dicts, datapoint count).
    """
    (dataset_dir / Path("wavs")).mkdir(parents=True, exist_ok=True)
    (dataset_dir / Path("wav_plots")).mkdir(parents=True, exist_ok=True)

    def data_fn(
        transcript,
        audio_dur,
        wav_data,
        caller_name,
        aud_seg,
        fname,
        audio_file,
        num_datapoints,
        rel_data_path,
    ):
        # Render (and cache) a waveform PNG for the UI; skipped if it already exists.
        png_path = Path(fname).with_suffix(".png")
        rel_plot_path = Path("wav_plots") / png_path
        wav_plot_path = dataset_dir / rel_plot_path
        if not wav_plot_path.exists():
            plot_seg(wav_plot_path.absolute(), audio_file)
        return {
            "audio_path": str(rel_data_path),
            "duration": round(audio_dur, 1),
            "text": transcript,
            "real_idx": num_datapoints,
            "caller": caller_name,
            "utterance_id": fname,
            "plot_path": str(rel_plot_path),
        }

    num_datapoints = 0
    data_funcs = []
    # First pass: write each wav to disk and queue a deferred record builder
    # (a partial) so the slow plot rendering can be parallelized afterwards.
    for transcript, audio_dur, wav_data, caller_name, aud_seg in asr_data_source:
        fname = str(uuid4()) + "_" + slugify(transcript, max_length=8)
        audio_file = (
            dataset_dir / Path("wavs") / Path(fname).with_suffix(".wav")
        ).absolute()
        audio_file.write_bytes(wav_data)
        # audio_path = str(audio_file)
        rel_data_path = audio_file.relative_to(dataset_dir.absolute())
        data_funcs.append(
            partial(
                data_fn,
                transcript,
                audio_dur,
                wav_data,
                caller_name,
                aud_seg,
                fname,
                audio_file,
                num_datapoints,
                rel_data_path,
            )
        )
        num_datapoints += 1
    # Second pass: invoke the queued builders in a worker pool (see parallel_apply).
    ui_data = parallel_apply(lambda x: x(), data_funcs)
    return ui_data, num_datapoints
def ui_dump_manifest_writer(dataset_dir, asr_data_source, verbose=False):
    """Generate UI dump data, then persist both manifest.json and ui_dump.json."""
    dump_data, num_datapoints = ui_data_generator(
        dataset_dir, asr_data_source, verbose=verbose
    )
    asr_manifest = dataset_dir / Path("manifest.json")
    with asr_manifest.open("w") as mf:
        print(f"writing manifest to {asr_manifest}")
        for record in dump_data:
            line = manifest_str(
                str(record["audio_path"]), record["duration"], record["text"]
            )
            mf.write(line)
    ui_dump_file = dataset_dir / Path("ui_dump.json")
    ExtendedPath(ui_dump_file).write_json({"data": dump_data})
    return num_datapoints
def asr_manifest_reader(data_manifest_path: Path):
    """Yield manifest entries, adding a resolved 'audio_path' and stripped 'text'."""
    print(f"reading manifest from {data_manifest_path}")
    with data_manifest_path.open("r") as pf:
        entries = [json.loads(line) for line in pf]
    for entry in entries:
        entry["audio_path"] = data_manifest_path.parent / Path(entry["audio_filepath"])
        entry["text"] = entry["text"].strip()
        yield entry
def asr_manifest_writer(asr_manifest_path: Path, manifest_str_source):
    """Write dicts carrying audio_filepath/duration/text keys as a JSON-lines manifest."""
    with asr_manifest_path.open("w") as mf:
        print(f"opening {asr_manifest_path} for writing manifest")
        for mani_dict in manifest_str_source:
            line = manifest_str(
                mani_dict["audio_filepath"], mani_dict["duration"], mani_dict["text"]
            )
            mf.write(line)
def asr_test_writer(out_file_path: Path, source):
    """Write a PLAY/PAUSE call-script from manifest entries plus a results JSON."""

    def dd_str(dd, idx):
        path = dd["audio_filepath"]
        # dur = dd["duration"]
        # return f"SAY {idx}\nPAUSE 3\nPLAY {path}\nPAUSE 3\n\n"
        return f"PAUSE 2\nPLAY {path}\nPAUSE 60\n\n"

    res_file = out_file_path.with_suffix(".result.json")
    results = []
    with out_file_path.open("w") as of:
        print(f"opening {out_file_path} for writing test")
        for idx, ui_dd in enumerate(source):
            results.append(ui_dd)
            of.write(dd_str(ui_dd, idx))
        of.write("DO_HANGUP\n")
    ExtendedPath(res_file).write_json(results)
def batch(iterable, n=1):
    """Split a sequence into consecutive chunks of size *n* (last may be shorter)."""
    total = len(iterable)
    chunks = []
    for start in range(0, total, n):
        chunks.append(iterable[start : min(start + n, total)])
    return chunks
class ExtendedPath(type(Path())):
    """Path subclass adding JSON / JSONL / YAML convenience readers and writers.

    Subclasses ``type(Path())`` (the concrete PosixPath/WindowsPath) because
    ``Path`` itself could not be subclassed directly on older Pythons.
    """

    def read_json(self):
        """Load and return the JSON document stored at this path."""
        print(f"reading json from {self}")
        with self.open("r") as jf:
            return json.load(jf)

    def read_yaml(self):
        """Load and return the YAML document stored at this path."""
        yaml = YAML(typ="safe", pure=True)
        print(f"reading yaml from {self}")
        with self.open("r") as yf:
            return yaml.load(yf)

    def read_jsonl(self):
        """Yield one parsed object per line of this JSON-lines file."""
        print(f"reading jsonl from {self}")
        with self.open("r") as jf:
            for line in jf.readlines():
                yield json.loads(line)

    def write_json(self, data):
        """Write *data* as indented JSON, creating parent directories as needed."""
        print(f"writing json to {self}")
        self.parent.mkdir(parents=True, exist_ok=True)
        with self.open("w") as jf:
            json.dump(data, jf, indent=2)

    def write_yaml(self, data):
        """Write *data* as YAML."""
        yaml = YAML()
        print(f"writing yaml to {self}")
        with self.open("w") as yf:
            yaml.dump(data, yf)

    def write_jsonl(self, data):
        """Write each item of *data* as one JSON line, creating parent directories."""
        print(f"writing jsonl to {self}")
        self.parent.mkdir(parents=True, exist_ok=True)
        with self.open("w") as jf:
            for item in data:
                jf.write(json.dumps(item) + "\n")
def get_mongo_coll(uri):
    """Resolve a fully-qualified mongodb:// URI to its collection object."""
    parsed = pymongo.uri_parser.parse_uri(uri)
    client = pymongo.MongoClient(uri)
    return client[parsed["database"]][parsed["collection"]]
def get_mongo_conn(host="", port=27017, db="db", col="collection"):
    """Build a collection handle; host falls back to $MONGO_HOST, then localhost."""
    mongo_host = host or os.environ.get("MONGO_HOST", "localhost")
    mongo_uri = f"mongodb://{mongo_host}:{port}/"
    return pymongo.MongoClient(mongo_uri)[db][col]
def strip_silence(sound):
    """Trim leading and trailing silence from a pydub AudioSegment."""
    from pydub.silence import detect_leading_silence

    leading = detect_leading_silence(sound)
    # Trailing silence is the leading silence of the reversed segment.
    trailing = detect_leading_silence(sound.reverse())
    return sound[leading : len(sound) - trailing]
def plot_seg(wav_plot_path, audio_path):
    """Render the waveform of *audio_path* to a PNG at *wav_plot_path*.

    Builds a standalone Figure rather than using pyplot's implicit state —
    presumably so concurrent calls (see parallel_apply usage) don't interfere;
    confirm before relying on thread-safety.
    """
    fig = plt.Figure()
    ax = fig.add_subplot()
    # librosa.load returns (samples, sample_rate); resampling behavior follows
    # librosa's defaults — TODO confirm expected sr for these datasets.
    (y, sr) = librosa.load(str(audio_path))
    audio_display.waveplot(y=y, sr=sr, ax=ax)
    with wav_plot_path.open("wb") as wav_plot_f:
        fig.set_tight_layout(True)
        # low dpi keeps the UI thumbnail files small
        fig.savefig(wav_plot_f, format="png", dpi=50)
def parallel_apply(fn, iterable, workers=8, pool="thread"):
    """Apply *fn* to every item of *iterable* in parallel, with a progress bar.

    pool="thread" uses a ThreadPoolExecutor (fine for I/O-bound *fn*);
    pool="process" uses a ProcessPoolExecutor (*fn* must be picklable).
    Results are returned as a list in input order.

    Raises Exception for an unsupported *pool* value.  (The original
    duplicated the whole executor body for each pool kind; selecting the
    executor class from a map removes the copy-paste while keeping behavior
    identical.)
    """
    executor_types = {"thread": ThreadPoolExecutor, "process": ProcessPoolExecutor}
    if pool not in executor_types:
        raise Exception(f"unsupported pool type - {pool}")
    with executor_types[pool](max_workers=workers) as exe:
        print(f"parallelly applying {fn}")
        # exe.map preserves input order; tqdm needs total= since map is lazy.
        return list(
            tqdm(exe.map(fn, iterable), position=0, leave=True, total=len(iterable))
        )
def generate_filter_map(src_dataset_path, dest_dataset_path, data_file):
    """Build a map of named dataset filters/transforms over *data_file* (JSONL).

    Each value is a zero-argument generator factory; iterating it yields the
    kept (possibly rewritten) manifest entries and, as a side effect, copies
    or re-encodes the referenced wav from *src_dataset_path* into
    *dest_dataset_path*.
    """
    # minimum count of non-space characters "text" must have to keep a sample
    min_nums = 3
    max_duration = 1 * 60 * 60  # cap: one hour of kept audio
    skip_duration = 1 * 60 * 60  # skip window: first hour of qualifying audio

    def filtered_max_dur():
        # Keep qualifying samples until max_duration of audio is collected.
        wav_duration = 0
        for s in ExtendedPath(data_file).read_jsonl():
            nums = re.sub(" ", "", s["text"])
            if len(nums) >= min_nums:
                wav_duration += s["duration"]
                shutil.copy(
                    src_dataset_path / Path(s["audio_filepath"]),
                    dest_dataset_path / Path(s["audio_filepath"]),
                )
                yield s
                if wav_duration > max_duration:
                    break
        typer.echo(f"filtered only {duration_str(wav_duration)} of audio")

    def filtered_skip_dur():
        # Drop qualifying samples until skip_duration has accumulated,
        # then yield the remaining qualifying ones.
        wav_duration = 0
        for s in ExtendedPath(data_file).read_jsonl():
            nums = re.sub(" ", "", s["text"])
            if len(nums) >= min_nums:
                wav_duration += s["duration"]
            if wav_duration <= skip_duration:
                continue
            elif len(nums) >= min_nums:
                yield s
                shutil.copy(
                    src_dataset_path / Path(s["audio_filepath"]),
                    dest_dataset_path / Path(s["audio_filepath"]),
                )
        typer.echo(f"skipped {duration_str(skip_duration)} of audio")

    def filtered_blanks():
        # Keep only samples whose text is non-blank.
        # NOTE(review): blank_count actually counts the *kept* (non-blank)
        # samples, so the echoed message is misleading — confirm intent.
        blank_count = 0
        for s in ExtendedPath(data_file).read_jsonl():
            nums = re.sub(" ", "", s["text"])
            if nums != "":
                blank_count += 1
                shutil.copy(
                    src_dataset_path / Path(s["audio_filepath"]),
                    dest_dataset_path / Path(s["audio_filepath"]),
                )
                yield s
        typer.echo(f"filtered {blank_count} blank samples")

    def filtered_transform_digits():
        # Rewrite text: digit words -> digits -> digits only -> spelled-out chars.
        count = 0
        for s in ExtendedPath(data_file).read_jsonl():
            count += 1
            digit_text = replace_digit_symbol(s["text"])
            only_digits = discard_except_digits(digit_text)
            char_text = digits_to_chars(only_digits)
            shutil.copy(
                src_dataset_path / Path(s["audio_filepath"]),
                dest_dataset_path / Path(s["audio_filepath"]),
            )
            s["text"] = char_text
            yield s
        typer.echo(f"transformed {count} samples")

    def filtered_extract_chars():
        # Spell out digits, keep only A-Z/apostrophe, collapse repeated spaces.
        count = 0
        for s in ExtendedPath(data_file).read_jsonl():
            count += 1
            no_digits = digits_to_chars(s["text"]).upper()
            only_chars = re.sub("[^A-Z'\b]", " ", no_digits)
            filter_text = replace_redundant_spaces_with(only_chars, " ").strip()
            shutil.copy(
                src_dataset_path / Path(s["audio_filepath"]),
                dest_dataset_path / Path(s["audio_filepath"]),
            )
            s["text"] = filter_text
            yield s
        typer.echo(f"transformed {count} samples")

    def filtered_resample():
        # Re-encode audio as mono 8-bit 24kHz wav while passing entries through.
        count = 0
        for s in ExtendedPath(data_file).read_jsonl():
            count += 1
            src_aud = pydub.AudioSegment.from_file(
                src_dataset_path / Path(s["audio_filepath"])
            )
            dst_aud = src_aud.set_channels(1).set_sample_width(1).set_frame_rate(24000)
            dst_aud.export(dest_dataset_path / Path(s["audio_filepath"]), format="wav")
            yield s
        typer.echo(f"transformed {count} samples")

    # filter-name -> generator factory
    filter_kind_map = {
        "max_dur_1hr_min3num": filtered_max_dur,
        "skip_dur_1hr_min3num": filtered_skip_dur,
        "blanks": filtered_blanks,
        "transform_digits": filtered_transform_digits,
        "extract_chars": filtered_extract_chars,
        "resample_ulaw24kmono": filtered_resample,
    }
    return filter_kind_map