plume-asr/plume/cli/data/__init__.py

import json
import shutil
from itertools import chain
from pathlib import Path
from typing import List

import soundfile
import typer

from plume.utils import (
    asr_manifest_reader,
    asr_manifest_writer,
    ExtendedPath,
    duration_str,
    generate_filter_map,
    get_mongo_conn,
    tscript_uuid_fname,
    lazy_callable,
)

from ...models.wav2vec2.data import app as wav2vec2_app
from .generate import app as generate_app

# defer the heavy sklearn import until train_test_split is first called
train_test_split = lazy_callable("sklearn.model_selection.train_test_split")

app = typer.Typer()
app.add_typer(generate_app, name="generate")
app.add_typer(wav2vec2_app, name="wav2vec2")
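
# Usage sketch (an assumption: this Typer app is mounted as the `data`
# sub-command of a top-level `plume` CLI; typer renders the function names
# below with hyphens, e.g. fix_path -> fix-path):
#   plume data fix-path /datasets/my_set
#   plume data split /datasets/my_set --test-size 0.05
#   plume data info /datasets/my_set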

@app.command()
def fix_path(dataset_path: Path, force: bool = False):
    """Create whichever of manifest.json / abs_manifest.json is missing from the other."""
    manifest_path = dataset_path / Path("manifest.json")
    real_manifest_path = dataset_path / Path("abs_manifest.json")

    def fix_real_path():
        for i in asr_manifest_reader(manifest_path):
            i["audio_filepath"] = str(
                (dataset_path / Path(i["audio_filepath"])).absolute()
            )
            yield i

    def fix_rel_path():
        for i in asr_manifest_reader(real_manifest_path):
            i["audio_filepath"] = str(
                Path(i["audio_filepath"]).relative_to(dataset_path)
            )
            yield i

    if not manifest_path.exists() and not real_manifest_path.exists():
        typer.echo("Invalid dataset directory: no manifest.json or abs_manifest.json found")
        raise typer.Exit(code=1)
    if not real_manifest_path.exists() or force:
        asr_manifest_writer(real_manifest_path, fix_real_path())
    if not manifest_path.exists():
        asr_manifest_writer(manifest_path, fix_rel_path())
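
# Manifest entries, as read and written above, are JSONL records of the form
# (illustrative file name):
#   {"audio_filepath": "wavs/<name>.wav", "duration": 2.4, "text": "hello world"}
# manifest.json keeps paths relative to the dataset root; abs_manifest.json
# keeps the same records with absolute paths.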

@app.command()
def augment(src_dataset_paths: List[Path], dest_dataset_path: Path):
    """Merge the abs_manifest.json of several source datasets into one destination manifest."""
    abs_manifest_path = Path("abs_manifest.json")
    reader_list = []
    for dataset_path in src_dataset_paths:
        manifest_path = dataset_path / abs_manifest_path
        reader_list.append(asr_manifest_reader(manifest_path))
    dest_dataset_path.mkdir(parents=True, exist_ok=True)
    dest_manifest_path = dest_dataset_path / abs_manifest_path
    asr_manifest_writer(dest_manifest_path, chain(*reader_list))

@app.command()
def split(dataset_path: Path, test_size: float = 0.03):
    manifest_path = dataset_path / Path("abs_manifest.json")
    if not manifest_path.exists():
        fix_path(dataset_path)
    asr_data = list(asr_manifest_reader(manifest_path))
    train_pnr, test_pnr = train_test_split(asr_data, test_size=test_size)
    asr_manifest_writer(manifest_path.with_name("train_manifest.json"), train_pnr)
    asr_manifest_writer(manifest_path.with_name("test_manifest.json"), test_pnr)

@app.command()
def validate(dataset_path: Path):
    """Check that every manifest entry parses and points to an existing audio file."""
    from datetime import timedelta

    from natural.date import compress

    for mf_type in ["train_manifest.json", "test_manifest.json"]:
        data_file = dataset_path / Path(mf_type)
        print(f"validating {data_file}.")
        with Path(data_file).open("r") as pf:
            pnr_jsonl = pf.readlines()
        duration = 0
        error_count = 0
        for i, s in enumerate(pnr_jsonl):
            try:
                d = json.loads(s)
                duration += d["duration"]
                audio_file = data_file.parent / Path(d["audio_filepath"])
                if not audio_file.exists():
                    raise OSError(f"File {audio_file} not found")
            except Exception as e:
                error_count += 1
                print(f'failed on {i} with "{e}"')
        dur_text = compress(timedelta(seconds=duration), pad=" ")
        if error_count == 0:
            print(
                f"no errors found. seems like a valid {mf_type}. contains {dur_text} of audio"
            )
        else:
            print(f"found {error_count} errors in {mf_type}. contains {dur_text} of audio")

@app.command()
def filter(src_dataset_path: Path, dest_dataset_path: Path, kind: str = "skip_dur"):
    dest_manifest = dest_dataset_path / Path("manifest.json")
    data_file = src_dataset_path / Path("manifest.json")
    dest_wav_dir = dest_dataset_path / Path("wavs")
    dest_wav_dir.mkdir(exist_ok=True, parents=True)
    filter_kind_map = generate_filter_map(
        src_dataset_path, dest_dataset_path, data_file
    )
    selected_filter = filter_kind_map.get(kind)
    if selected_filter:
        asr_manifest_writer(dest_manifest, selected_filter())
    else:
        typer.echo(f"filter kind - {kind} not implemented")
        typer.echo(f"select one of {', '.join(filter_kind_map.keys())}")

@app.command()
def info(dataset_path: Path):
    """Print duration statistics for every manifest variant present in the dataset."""
    for k in ["", "abs_", "train_", "test_"]:
        mf_wav_duration = real_duration = max_duration = empty_duration = 0
        empty_count = total_count = 0
        data_file = dataset_path / Path(f"{k}manifest.json")
        if not data_file.exists():
            continue
        print(f"stats on {data_file}")
        for s in ExtendedPath(data_file).read_jsonl():
            total_count += 1
            mf_wav_duration += s["duration"]
            if s["text"] == "":
                empty_count += 1
                empty_duration += s["duration"]
            wav_path = str(dataset_path / Path(s["audio_filepath"]))
            # read the audio header once per file instead of three times
            wav_duration = soundfile.info(wav_path).duration
            max_duration = max(max_duration, wav_duration)
            real_duration += wav_duration
        print(f"max audio duration : {duration_str(max_duration)}")
        print(f"total audio duration : {duration_str(mf_wav_duration)}")
        print(f"total real audio duration : {duration_str(real_duration)}")
        print(
            f"total content duration : {duration_str(mf_wav_duration - empty_duration)}"
        )
        print(f"total empty duration : {duration_str(empty_duration)}")
        if total_count:
            print(
                f"total empty samples : {empty_count}/{total_count} ({empty_count * 100 / total_count:.2f}%)"
            )

@app.command()
def audio_duration(dataset_path: Path):
    wav_duration = 0
    for audio_fname in dataset_path.absolute().glob("**/*.wav"):
        wav_duration += soundfile.info(str(audio_fname)).duration
    typer.echo(f"duration of wav files @ {dataset_path}: {duration_str(wav_duration)}")

@app.command()
def migrate(src_path: Path, dest_path: Path):
    """Copy a dataset to a new location, gathering all audio under dest_path/wavs."""
    shutil.copytree(str(src_path), str(dest_path))
    wav_dir = dest_path / Path("wavs")
    wav_dir.mkdir(exist_ok=True, parents=True)
    abs_manifest_path = ExtendedPath(dest_path / Path("abs_manifest.json"))
    backup_abs_manifest_path = abs_manifest_path.with_suffix(".json.orig")
    shutil.copy(abs_manifest_path, backup_abs_manifest_path)
    manifest_data = list(abs_manifest_path.read_jsonl())
    for md in manifest_data:
        orig_path = Path(md["audio_filepath"])
        new_path = wav_dir / Path(orig_path.name)
        shutil.copy(orig_path, new_path)
        md["audio_filepath"] = str(new_path)
    abs_manifest_path.write_jsonl(manifest_data)
    fix_path(dest_path)
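
# After migrate, every audio file lives under dest_path/wavs/ and
# abs_manifest.json points at those copies; the pre-rewrite manifest is kept
# as abs_manifest.json.orig, and fix_path fills in the relative manifest.json
# if it is absent.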

@app.command()
def task_split(
    data_dir: Path,
    dump_file: Path = Path("ui_dump.json"),
    task_count: int = typer.Option(2, show_default=True),
    task_file: str = "task_dump",
    sort: bool = True,
):
    """
    split ui_dump.json into `task_count` task files
    """
    import numpy as np
    import pandas as pd

    processed_data_path = data_dir / dump_file
    processed_data = ExtendedPath(processed_data_path).read_json()
    # shuffle all samples, then cut them into task_count roughly equal chunks
    df = pd.DataFrame(processed_data["data"]).sample(frac=1).reset_index(drop=True)
    for t_idx, task_f in enumerate(np.array_split(df, task_count)):
        task_f = task_f.reset_index(drop=True)
        task_f["real_idx"] = task_f.index
        task_data = task_f.to_dict("records")
        if sort:
            # surface the worst ASR hypotheses (highest WER) first
            task_data = sorted(task_data, key=lambda x: x["asr_wer"], reverse=True)
        processed_data["data"] = task_data
        task_path = data_dir / Path(task_file + f"-{t_idx}.json")
        ExtendedPath(task_path).write_json(processed_data)
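
# e.g. task_count=2 writes task_dump-0.json and task_dump-1.json, each a copy
# of the ui_dump structure whose "data" holds one shuffled subset.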

def get_corrections(task_uid):
    """Fetch correction records for a task, matching the full task_id or its uid suffix."""
    col = get_mongo_conn(col="asr_validation")
    matches = [
        c
        for c in col.distinct("task_id")
        if c.rsplit("-", 1)[-1] == task_uid or c == task_uid
    ]
    if not matches:
        raise ValueError(f"no task_id matching {task_uid} found")
    task_id = matches[0]
    cursor_obj = col.find(
        {"type": "correction", "task_id": task_id}, projection={"_id": False}
    )
    return list(cursor_obj)
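
# A correction record, as consumed by update_corrections below, looks like:
#   {"type": "correction", "task_id": "task-<uid>", "code": "<utterance_id>",
#    "value": {"status": "Correct" | "Incorrect", "correction": "<fixed text>"}}
# Any status other than Correct/Incorrect is treated as inaudible/invalid.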

@app.command()
def dump_task_corrections(data_dir: Path, task_uid: str):
    dump_fname: Path = Path(f"corrections-{task_uid}.json")
    dump_path = data_dir / dump_fname
    corrections = get_corrections(task_uid)
    ExtendedPath(dump_path).write_json(corrections)

@app.command()
def dump_all_corrections(data_dir: Path):
    for task_lck in data_dir.glob("task-*.lck"):
        task_uid = task_lck.stem.replace("task-", "")
        dump_task_corrections(data_dir, task_uid)

@app.command()
def update_corrections(
    data_dir: Path,
    skip_incorrect: bool = typer.Option(
        False,
        show_default=True,
        help="drop samples marked Incorrect instead of applying their corrections",
    ),
    skip_inaudible: bool = typer.Option(
        False,
        show_default=True,
        help="drop inaudible/invalid samples instead of keeping them with a blank target",
    ),
):
    """
    apply the corrections-*.json files to the dataset,
    backing up the original dataset first
    """
    manifest_file: Path = Path("manifest.json")
    renames_file: Path = Path("rename_map.json")
    ui_dump_file: Path = Path("ui_dump.json")
    data_manifest_path = data_dir / manifest_file
    renames_path = data_dir / renames_file

    def correct_ui_dump(data_dir, rename_result):
        ui_dump_path = data_dir / ui_dump_file
        corrections = [
            t
            for p in data_dir.glob("corrections-*.json")
            for t in ExtendedPath(p).read_json()
        ]
        ui_data = ExtendedPath(ui_dump_path).read_json()["data"]
        correct_set = {
            c["code"] for c in corrections if c["value"]["status"] == "Correct"
        }
        correction_map = {
            c["code"]: c["value"]["correction"]
            for c in corrections
            if c["value"]["status"] == "Incorrect"
        }
        for d in ui_data:
            orig_audio_path = (data_dir / Path(d["audio_path"])).absolute()
            if d["utterance_id"] in correct_set:
                # transcript confirmed correct: keep as-is
                d["corrected_from"] = d["text"]
                yield d
            elif d["utterance_id"] in correction_map:
                correct_text = correction_map[d["utterance_id"]]
                if skip_incorrect:
                    ap = d["audio_path"]
                    print(f"skipping incorrect {ap} corrected to {correct_text}")
                    orig_audio_path.unlink()
                else:
                    # rename the audio file to match the corrected transcript
                    new_fname = tscript_uuid_fname(correct_text)
                    rename_result[new_fname] = {
                        "orig_text": d["text"],
                        "correct_text": correct_text,
                        "orig_id": d["utterance_id"],
                    }
                    new_name = str(Path(new_fname).with_suffix(".wav"))
                    new_audio_path = orig_audio_path.with_name(new_name)
                    orig_audio_path.replace(new_audio_path)
                    new_filepath = str(Path(d["audio_path"]).with_name(new_name))
                    d["corrected_from"] = d["text"]
                    d["text"] = correct_text
                    d["audio_path"] = new_filepath
                    yield d
            else:
                # no verdict: treat the sample as inaudible/invalid
                if skip_inaudible:
                    orig_audio_path.unlink()
                else:
                    d["corrected_from"] = d["text"]
                    d["text"] = ""
                    yield d

    dataset_dir = data_manifest_path.parent
    dataset_name = dataset_dir.name
    backup_dir = dataset_dir.with_name(dataset_name + ".bkp")
    if not backup_dir.exists():
        typer.echo(f"backing up to {backup_dir}")
        shutil.copytree(str(dataset_dir), str(backup_dir))
    renames = {}
    corrected_ui_dump = list(correct_ui_dump(data_dir, renames))
    ExtendedPath(data_dir / ui_dump_file).write_json({"data": corrected_ui_dump})
    corrected_manifest = (
        {
            "audio_filepath": d["audio_path"],
            "duration": d["duration"],
            "text": d["text"],
        }
        for d in corrected_ui_dump
    )
    asr_manifest_writer(data_manifest_path, corrected_manifest)
    ExtendedPath(renames_path).write_json(renames)
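
# Typical review loop with these commands (a sketch, not a prescribed
# pipeline): task-split the ui dump, collect annotator verdicts in mongo,
# dump-all-corrections, then update-corrections to rewrite the dataset
# (a .bkp copy of the original is kept).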

def main():
    app()


if __name__ == "__main__":
    main()