487 lines
15 KiB
Python
487 lines
15 KiB
Python
import io
|
|
import os
|
|
import re
|
|
import json
|
|
import wave
|
|
import logging
|
|
from pathlib import Path
|
|
from functools import partial
|
|
from uuid import uuid4
|
|
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
|
|
import subprocess
|
|
import shutil
|
|
from urllib.parse import urlsplit
|
|
# from .lazy_loader import LazyLoader
|
|
from .lazy_import import lazy_callable, lazy_module
|
|
|
|
# from ruamel.yaml import YAML
|
|
# import boto3
|
|
import typer
|
|
# import pymongo
|
|
# from slugify import slugify
|
|
# import pydub
|
|
# import matplotlib.pyplot as plt
|
|
# import librosa
|
|
# import librosa.display as audio_display
|
|
# from natural.date import compress
|
|
# from num2words import num2words
|
|
from tqdm import tqdm
|
|
from datetime import timedelta
|
|
|
|
# from .transcribe import triton_transcribe_grpc_gen
|
|
# from .eval import app as eval_app
|
|
from .tts import app as tts_app
|
|
from .transcribe import app as transcribe_app
|
|
from .align import app as align_app
|
|
|
|
# Heavy optional dependencies are bound through the lazy-import helpers so
# that importing this module does not import them eagerly.
boto3 = lazy_module('boto3')
pymongo = lazy_module('pymongo')
pydub = lazy_module('pydub')
audio_display = lazy_module('librosa.display')
plt = lazy_module('matplotlib.pyplot')
librosa = lazy_module('librosa')
YAML = lazy_callable('ruamel.yaml.YAML')
num2words = lazy_callable('num2words.num2words')
slugify = lazy_callable('slugify.slugify')
compress = lazy_callable('natural.date.compress')

# Root command-line application; each sub-application is mounted under its
# own command name.
app = typer.Typer()
app.add_typer(tts_app, name="tts")
app.add_typer(align_app, name="align")
app.add_typer(transcribe_app, name="transcribe")


# Logging is configured once at import time; `logger` is this module's logger.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def manifest_str(path, dur, text):
    """Serialize one manifest record as a single JSON line.

    The record carries the audio path, the duration rounded to one decimal
    place and the transcript text; the returned string ends with a newline.
    """
    record = {
        "audio_filepath": path,
        "duration": round(dur, 1),
        "text": text,
    }
    return f"{json.dumps(record)}\n"
|
|
|
|
|
|
def duration_str(seconds):
    """Format a duration in seconds via ``compress`` with a single-space pad."""
    span = timedelta(seconds=seconds)
    return compress(span, pad=" ")
|
|
|
|
|
|
def replace_digit_symbol(w2v_out):
    """Lower-case *w2v_out* and replace spelled-out digits with numerals.

    The words produced by ``num2words`` for 0 through 9 are each substituted
    by the corresponding digit character wherever they occur in the text.
    """
    word_to_digit = {num2words(digit): str(digit) for digit in range(10)}
    text = w2v_out.lower()
    for word in word_to_digit:
        text = re.sub(word, word_to_digit[word], text)
    return text
|
|
|
|
|
|
def discard_except_digits(inp):
    """Return *inp* with every character other than ``0``-``9`` removed."""
    non_digit = re.compile("[^0-9]")
    return non_digit.sub("", inp)
|
|
|
|
|
|
def digits_to_chars(text):
    """Spell out every digit character of *text* and lower-case the result.

    Each character between ``"0"`` and ``"9"`` is replaced by its
    ``num2words`` rendering followed by a space; all other characters are
    kept as they are.
    """
    pieces = []
    for ch in text:
        if "0" <= ch <= "9":
            pieces.append(num2words(ch) + " ")
        else:
            pieces.append(ch)
    return "".join(pieces).lower()
|
|
|
|
|
|
def replace_redundant_spaces_with(text, sub):
    """Replace every run of one or more spaces in *text* with *sub*."""
    space_run = re.compile(" +")
    return space_run.sub(sub, text)
|
|
|
|
|
|
def space_out(text):
    """Return the characters of *text* separated by single spaces."""
    return " ".join(text)
|
|
|
|
|
|
def wav_bytes(audio_bytes, frame_rate=24000):
    """Wrap raw audio frames in a WAV container and return the file bytes.

    The header declares one channel, two bytes per sample and *frame_rate*
    frames per second; *audio_bytes* is written unchanged as the frame data.
    """
    sink = io.BytesIO()
    with wave.open(sink, mode="w") as writer:
        # The three header setters are independent of each other.
        writer.setsampwidth(2)
        writer.setframerate(frame_rate)
        writer.setnchannels(1)
        writer.writeframesraw(audio_bytes)
    return sink.getvalue()
|
|
|
|
|
|
def tscript_uuid_fname(transcript):
    """Build a file stem: a fresh UUID4, an underscore, and a transcript slug.

    The slug is produced by ``slugify`` with ``max_length=8``.
    """
    return f"{uuid4()}_{slugify(transcript, max_length=8)}"
|
|
|
|
|
|
def run_shell(cmd_str, work_dir="."):
    """Run a shell command and echo its combined output line by line.

    stderr is merged into stdout and every output line is printed as it is
    read. The call blocks until the command has exited.

    Args:
        cmd_str: Command line passed to the shell (``shell=True``). Callers
            are responsible for quoting any values interpolated into it.
        work_dir: Directory the command runs in; converted to an absolute
            path before use.

    Returns:
        The exit status of the command.
    """
    cwd_path = Path(work_dir).absolute()
    # The context manager closes the pipe and waits for the child on exit, so
    # the process is always reaped and its exit status becomes available.
    with subprocess.Popen(
        cmd_str,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        shell=True,
        cwd=cwd_path,
    ) as proc:
        for line in proc.stdout:
            # Undecodable bytes are replaced rather than aborting the loop.
            print(line.replace(b"\n", b"").decode("utf-8", errors="replace"))
    return proc.returncode
|
|
|
|
|
|
def upload_s3(dataset_path, s3_path):
    """Sync *dataset_path* to *s3_path* by running ``aws s3 sync`` in a shell.

    Both values are interpolated into the command line as-is.
    """
    sync_cmd = f"aws s3 sync {dataset_path} {s3_path}"
    run_shell(sync_cmd)
|
|
|
|
|
|
def get_download_path(s3_uri, output_path):
    """Map an S3 URI onto a local path below *output_path*.

    The URI's path component, minus its first character, is appended to
    *output_path*. The parent directory of the resulting path is created
    when missing, and the path itself is returned.
    """
    object_key = urlsplit(s3_uri).path[1:]
    target = output_path / Path(object_key)
    target.parent.mkdir(parents=True, exist_ok=True)
    return target
|
|
|
|
|
|
def s3_downloader():
    """Create a download function bound to one boto3 S3 client.

    Returns:
        A callable ``download_s3(s3_uri, download_path)``. It creates the
        parent directory of *download_path* and downloads the object named by
        the URI (netloc as bucket, path minus its first character as key)
        unless *download_path* already exists.
    """
    client = boto3.client("s3")

    def download_s3(s3_uri, download_path):
        parts = urlsplit(s3_uri)
        download_path.parent.mkdir(parents=True, exist_ok=True)
        if download_path.exists():
            # Already present locally: nothing to fetch.
            return
        print(f"downloading {s3_uri} to {download_path}")
        bucket, key = parts.netloc, parts.path[1:]
        client.download_file(bucket, key, str(download_path))

    return download_s3
|
|
|
|
|
|
def asr_data_writer(dataset_dir, asr_data_source, verbose=False):
    """Write audio clips and a JSON-lines manifest into *dataset_dir*.

    Every ``(transcript, audio_dur, wav_data)`` item from *asr_data_source*
    is stored as ``wavs/<generated name>.wav`` and recorded in
    ``manifest.json`` with its path relative to *dataset_dir*.

    Args:
        dataset_dir: Dataset root; ``wavs/`` is created beneath it.
        asr_data_source: Iterable of ``(transcript, audio_dur, wav_data)``.
        verbose: When true, print a line for each clip written.

    Returns:
        The number of items written.
    """
    wav_dir = dataset_dir / Path("wavs")
    wav_dir.mkdir(parents=True, exist_ok=True)
    manifest_path = dataset_dir / Path("manifest.json")
    written = 0
    with manifest_path.open("w") as manifest_file:
        print(f"writing manifest to {manifest_path}")
        for transcript, audio_dur, wav_data in asr_data_source:
            stem = tscript_uuid_fname(transcript)
            wav_path = wav_dir / Path(stem).with_suffix(".wav")
            wav_path.write_bytes(wav_data)
            rel_wav_path = wav_path.relative_to(dataset_dir)
            manifest_file.write(
                manifest_str(str(rel_wav_path), audio_dur, transcript)
            )
            if verbose:
                print(f"writing '{transcript}' of duration {audio_dur}")
            written += 1
    return written
|
|
|
|
|
|
def ui_data_generator(dataset_dir, asr_data_source, verbose=False):
    """Store audio clips and build one UI record per clip.

    Each ``(transcript, audio_dur, wav_data, caller_name, aud_seg)`` item is
    written to ``wavs/`` under *dataset_dir*. The records are then produced
    through ``parallel_apply``; producing a record also renders the clip's
    plot under ``wav_plots/`` when that file does not exist yet.

    Args:
        dataset_dir: Dataset root; ``wavs/`` and ``wav_plots/`` are created.
        asr_data_source: Iterable of the 5-tuples described above.
        verbose: Accepted for signature compatibility; not used here.

    Returns:
        ``(ui_data, num_datapoints)``: the list of record dicts in source
        order and the number of items processed.
    """
    wav_dir = dataset_dir / Path("wavs")
    plot_dir = dataset_dir / Path("wav_plots")
    for folder in (wav_dir, plot_dir):
        folder.mkdir(parents=True, exist_ok=True)

    def build_entry(
        transcript, audio_dur, caller_name, utt_id, wav_path, index, rel_wav_path
    ):
        rel_plot_path = Path("wav_plots") / Path(utt_id).with_suffix(".png")
        plot_file = dataset_dir / rel_plot_path
        if not plot_file.exists():
            plot_seg(plot_file.absolute(), wav_path)
        return {
            "audio_path": str(rel_wav_path),
            "duration": round(audio_dur, 1),
            "text": transcript,
            "real_idx": index,
            "caller": caller_name,
            "utterance_id": utt_id,
            "plot_path": str(rel_plot_path),
        }

    jobs = []
    for index, item in enumerate(asr_data_source):
        transcript, audio_dur, wav_data, caller_name, _aud_seg = item
        utt_id = f"{uuid4()}_{slugify(transcript, max_length=8)}"
        wav_path = (wav_dir / Path(utt_id).with_suffix(".wav")).absolute()
        wav_path.write_bytes(wav_data)
        rel_wav_path = wav_path.relative_to(dataset_dir.absolute())
        jobs.append(
            partial(
                build_entry,
                transcript,
                audio_dur,
                caller_name,
                utt_id,
                wav_path,
                index,
                rel_wav_path,
            )
        )
    ui_data = parallel_apply(lambda x: x(), jobs)
    return ui_data, len(jobs)
|
|
|
|
|
|
def ui_dump_manifest_writer(dataset_dir, asr_data_source, verbose=False):
    """Write ``manifest.json`` and ``ui_dump.json`` for a UI dataset.

    The records come from ``ui_data_generator``. The manifest receives one
    line per record (audio path, duration, text); the complete records are
    stored under the ``"data"`` key of ``ui_dump.json``.

    Returns:
        The number of datapoints reported by ``ui_data_generator``.
    """
    entries, total = ui_data_generator(
        dataset_dir, asr_data_source, verbose=verbose
    )

    manifest_path = dataset_dir / Path("manifest.json")
    with manifest_path.open("w") as manifest_file:
        print(f"writing manifest to {manifest_path}")
        for entry in entries:
            line = manifest_str(
                str(entry["audio_path"]), entry["duration"], entry["text"]
            )
            manifest_file.write(line)

    dump_path = dataset_dir / Path("ui_dump.json")
    ExtendedPath(dump_path).write_json({"data": entries})
    return total
|
|
|
|
|
|
def asr_manifest_reader(data_manifest_path: Path):
    """Yield the records of a JSON-lines ASR manifest.

    Every yielded dict is the parsed manifest line with two adjustments:
    ``audio_path`` is added as the manifest's directory joined with the
    record's ``audio_filepath``, and ``text`` is stripped of surrounding
    whitespace. Blank lines (for example a trailing empty line) are skipped
    instead of aborting the read with a JSON decode error.

    Args:
        data_manifest_path: Path of the manifest file.

    Yields:
        One dict per non-blank manifest line, in file order.
    """
    print(f"reading manifest from {data_manifest_path}")
    with data_manifest_path.open("r") as manifest_file:
        # Parse everything up front so the file is closed before the first
        # record is handed to the consumer.
        records = [json.loads(line) for line in manifest_file if line.strip()]
    base_dir = data_manifest_path.parent
    for record in records:
        record["audio_path"] = base_dir / Path(record["audio_filepath"])
        record["text"] = record["text"].strip()
        yield record
|
|
|
|
|
|
def asr_manifest_writer(asr_manifest_path: Path, manifest_str_source):
    """Write manifest records to *asr_manifest_path*, one JSON line each.

    Args:
        asr_manifest_path: Destination file; opened for writing.
        manifest_str_source: Iterable of dicts providing ``audio_filepath``,
            ``duration`` and ``text``.
    """
    with asr_manifest_path.open("w") as manifest_file:
        print(f"opening {asr_manifest_path} for writing manifest")
        for record in manifest_str_source:
            line = manifest_str(
                record["audio_filepath"], record["duration"], record["text"]
            )
            manifest_file.write(line)
|
|
|
|
|
|
def asr_test_writer(out_file_path: Path, source):
    """Write a playback test script plus a JSON copy of the source records.

    For each record of *source* the script receives a pause, a ``PLAY`` of
    the record's ``audio_filepath`` and a longer pause; the script ends with
    ``DO_HANGUP``. The consumed records are saved next to the script under
    the ``.result.json`` suffix.

    Args:
        out_file_path: Path of the script file to write.
        source: Iterable of dicts providing ``audio_filepath``.
    """

    def script_entry(dd):
        path = dd["audio_filepath"]
        return f"PAUSE 2\nPLAY {path}\nPAUSE 60\n\n"

    result_path = out_file_path.with_suffix(".result.json")
    played = []
    with out_file_path.open("w") as script:
        print(f"opening {out_file_path} for writing test")
        for item in source:
            played.append(item)
            script.write(script_entry(item))
        script.write("DO_HANGUP\n")
    ExtendedPath(result_path).write_json(played)
|
|
|
|
|
|
def batch(iterable, n=1):
    """Split a sized, sliceable sequence into consecutive slices of length *n*.

    The final slice is shorter when the length is not a multiple of *n*.
    """
    total = len(iterable)
    chunks = []
    for start in range(0, total, n):
        stop = min(start + n, total)
        chunks.append(iterable[start:stop])
    return chunks
|
|
|
|
|
|
class ExtendedPath(type(Path())):
    """Concrete path class with JSON, JSON-lines and YAML read/write helpers.

    The base class is ``type(Path())``, i.e. whichever concrete class
    ``Path()`` instantiates on the running platform. Every helper prints the
    path it is about to read or write, and every writer creates missing
    parent directories first.
    """

    def read_json(self):
        """Parse the file as one JSON document and return the result."""
        print(f"reading json from {self}")
        with self.open("r") as jf:
            return json.load(jf)

    def read_yaml(self):
        """Load the file with a safe, pure-Python ``YAML`` loader."""
        yaml = YAML(typ="safe", pure=True)
        print(f"reading yaml from {self}")
        with self.open("r") as yf:
            return yaml.load(yf)

    def read_jsonl(self):
        """Yield one parsed JSON value per non-blank line of the file."""
        print(f"reading jsonl from {self}")
        with self.open("r") as jf:
            # Iterate the file directly so it is not loaded into memory whole.
            for line in jf:
                # A blank line is not valid JSON and would abort the read.
                if line.strip():
                    yield json.loads(line)

    def write_json(self, data):
        """Write *data* as indented JSON."""
        print(f"writing json to {self}")
        self.parent.mkdir(parents=True, exist_ok=True)
        with self.open("w") as jf:
            json.dump(data, jf, indent=2)

    def write_yaml(self, data):
        """Write *data* as YAML."""
        yaml = YAML()
        print(f"writing yaml to {self}")
        # Same parent-directory handling as write_json and write_jsonl.
        self.parent.mkdir(parents=True, exist_ok=True)
        with self.open("w") as yf:
            yaml.dump(data, yf)

    def write_jsonl(self, data):
        """Write each item of *data* as one JSON line."""
        print(f"writing jsonl to {self}")
        self.parent.mkdir(parents=True, exist_ok=True)
        with self.open("w") as jf:
            for d in data:
                jf.write(json.dumps(d) + "\n")
|
|
|
|
|
|
def get_mongo_coll(uri):
    """Return the collection addressed by a MongoDB URI.

    The database and collection names are taken from the ``database`` and
    ``collection`` entries of the parsed URI.
    """
    parsed = pymongo.uri_parser.parse_uri(uri)
    client = pymongo.MongoClient(uri)
    database = client[parsed["database"]]
    return database[parsed["collection"]]
|
|
|
|
|
|
def get_mongo_conn(host="", port=27017, db="db", col="collection"):
    """Return collection *col* of database *db* on a MongoDB server.

    When *host* is empty, the ``MONGO_HOST`` environment variable is used,
    falling back to ``localhost``.
    """
    mongo_host = host or os.environ.get("MONGO_HOST", "localhost")
    client = pymongo.MongoClient(f"mongodb://{mongo_host}:{port}/")
    return client[db][col]
|
|
|
|
|
|
def strip_silence(sound):
    """Return *sound* with leading and trailing silence sliced away.

    ``detect_leading_silence`` is applied to the sound and to its reverse;
    the two results bound the slice that is returned.
    """
    from pydub.silence import detect_leading_silence

    lead_trim = detect_leading_silence(sound)
    tail_trim = detect_leading_silence(sound.reverse())
    return sound[lead_trim : len(sound) - tail_trim]
|
|
|
|
|
|
def plot_seg(wav_plot_path, audio_path):
    """Render the waveform of *audio_path* as a PNG at *wav_plot_path*.

    The audio is loaded with ``librosa.load`` and drawn with
    ``librosa.display.waveplot`` on a fresh figure, which is saved at 50 dpi
    with tight layout enabled.
    """
    figure = plt.Figure()
    axes = figure.add_subplot()
    samples, sample_rate = librosa.load(str(audio_path))
    audio_display.waveplot(y=samples, sr=sample_rate, ax=axes)
    with wav_plot_path.open("wb") as png_file:
        figure.set_tight_layout(True)
        figure.savefig(png_file, format="png", dpi=50)
|
|
|
|
|
|
def parallel_apply(fn, iterable, workers=8, pool="thread"):
    """Apply *fn* to every item concurrently and return the results in order.

    Progress is reported through ``tqdm``.

    Args:
        fn: Callable invoked with one item at a time.
        iterable: Items to process. Unsized iterables such as generators are
            materialised into a list so the progress total can be computed.
        workers: Maximum number of pool workers.
        pool: ``"thread"`` for a ``ThreadPoolExecutor`` or ``"process"`` for
            a ``ProcessPoolExecutor``.

    Returns:
        List of results, in the order of the input items.

    Raises:
        ValueError: If *pool* is neither ``"thread"`` nor ``"process"``.
    """
    if pool == "thread":
        executor_cls = ThreadPoolExecutor
    elif pool == "process":
        executor_cls = ProcessPoolExecutor
    else:
        raise ValueError(f"unsupported pool type - {pool}")

    # len() is needed for the progress total; generators do not support it.
    items = iterable if hasattr(iterable, "__len__") else list(iterable)
    with executor_cls(max_workers=workers) as exe:
        print(f"parallelly applying {fn}")
        return list(
            tqdm(exe.map(fn, items), position=0, leave=True, total=len(items))
        )
|
|
|
|
|
|
def generate_filter_map(src_dataset_path, dest_dataset_path, data_file):
    """Build the table of named dataset filters.

    Every filter is a zero-argument generator function. When iterated it
    reads the JSON-lines manifest *data_file*, writes the audio of each record
    it keeps to *dest_dataset_path* (at the record's ``audio_filepath``), and
    yields the kept record, possibly with a rewritten ``text``. A summary is
    echoed once a filter has been exhausted.

    Args:
        src_dataset_path: Directory the records' ``audio_filepath`` values
            are relative to.
        dest_dataset_path: Directory receiving the kept audio under the same
            relative paths.
        data_file: Path of the JSON-lines manifest to read.

    Returns:
        Dict mapping filter name to its generator function.
    """
    min_nums = 3
    max_duration = 1 * 60 * 60
    skip_duration = 1 * 60 * 60

    def read_samples():
        # Each filter gets its own fresh pass over the manifest.
        return ExtendedPath(data_file).read_jsonl()

    def copy_audio(sample):
        # Copy one record's audio from the source to the destination dataset.
        rel_path = Path(sample["audio_filepath"])
        shutil.copy(src_dataset_path / rel_path, dest_dataset_path / rel_path)

    def text_without_spaces(sample):
        # Transcript with every space removed.
        return re.sub(" ", "", sample["text"])

    def filtered_max_dur():
        # Keep records of at least `min_nums` characters until the kept
        # duration exceeds `max_duration`.
        wav_duration = 0
        for s in read_samples():
            if len(text_without_spaces(s)) >= min_nums:
                wav_duration += s["duration"]
                copy_audio(s)
                yield s
            if wav_duration > max_duration:
                break
        typer.echo(f"filtered only {duration_str(wav_duration)} of audio")

    def filtered_skip_dur():
        # Skip qualifying records until `skip_duration` has accumulated,
        # then keep the qualifying records that follow.
        wav_duration = 0
        for s in read_samples():
            long_enough = len(text_without_spaces(s)) >= min_nums
            if long_enough:
                wav_duration += s["duration"]
            if long_enough and wav_duration > skip_duration:
                # Copy before yielding, like every other filter, so the audio
                # exists when the consumer receives the record.
                copy_audio(s)
                yield s
        typer.echo(f"skipped {duration_str(skip_duration)} of audio")

    def filtered_blanks():
        # Drop records whose transcript is empty once spaces are removed.
        blank_count = 0
        for s in read_samples():
            if text_without_spaces(s) == "":
                # Count the dropped records, which is what the summary reports.
                blank_count += 1
                continue
            copy_audio(s)
            yield s
        typer.echo(f"filtered {blank_count} blank samples")

    def filtered_transform_digits():
        # Reduce each transcript to its digits and spell them out as words.
        count = 0
        for s in read_samples():
            count += 1
            digit_text = replace_digit_symbol(s["text"])
            only_digits = discard_except_digits(digit_text)
            char_text = digits_to_chars(only_digits)
            copy_audio(s)
            s["text"] = char_text
            yield s
        typer.echo(f"transformed {count} samples")

    def filtered_extract_chars():
        # Spell out digits, upper-case, and replace everything outside the
        # character class below with single spaces.
        count = 0
        for s in read_samples():
            count += 1
            no_digits = digits_to_chars(s["text"]).upper()
            only_chars = re.sub("[^A-Z'\b]", " ", no_digits)
            filter_text = replace_redundant_spaces_with(only_chars, " ").strip()
            copy_audio(s)
            s["text"] = filter_text
            yield s
        typer.echo(f"transformed {count} samples")

    def filtered_resample():
        # Re-export each clip as WAV with 1 channel, sample width 1 and a
        # 24000 frame rate.
        count = 0
        for s in read_samples():
            count += 1
            src_aud = pydub.AudioSegment.from_file(
                src_dataset_path / Path(s["audio_filepath"])
            )
            dst_aud = src_aud.set_channels(1).set_sample_width(1).set_frame_rate(24000)
            dst_aud.export(dest_dataset_path / Path(s["audio_filepath"]), format="wav")
            yield s
        typer.echo(f"transformed {count} samples")

    filter_kind_map = {
        "max_dur_1hr_min3num": filtered_max_dur,
        "skip_dur_1hr_min3num": filtered_skip_dur,
        "blanks": filtered_blanks,
        "transform_digits": filtered_transform_digits,
        "extract_chars": filtered_extract_chars,
        "resample_ulaw24kmono": filtered_resample,
    }
    return filter_kind_map
|