78 lines
2.6 KiB
Python
78 lines
2.6 KiB
Python
import json
from itertools import chain
from pathlib import Path
from typing import List, Optional

import typer
from sklearn.model_selection import train_test_split

from .utils import asr_manifest_reader, asr_manifest_writer
|
|
|
|
# Typer CLI application; each function below decorated with ``@app.command()``
# is registered on it as a subcommand.
app = typer.Typer()
|
|
|
|
|
|
@app.command()
def fixate_data(dataset_path: Path):
    """Write ``abs_manifest.json`` whose audio paths are absolute.

    Reads ``<dataset_path>/manifest.json`` and, for every entry, joins its
    ``audio_filepath`` onto the (resolved) dataset directory. The rewritten
    entries go to ``<dataset_path>/abs_manifest.json``.

    Args:
        dataset_path: Directory containing ``manifest.json``. It may be given
            relative to the current working directory.
    """
    # Resolve up front. Otherwise a relative ``dataset_path`` leaks into the
    # "absolute" manifest. ``validate_data`` joins entries onto the manifest's
    # parent directory, so it would then see a doubled prefix (data/data/...).
    dataset_dir = dataset_path.resolve()
    manifest_path = dataset_dir / "manifest.json"
    real_manifest_path = dataset_dir / "abs_manifest.json"

    def fix_path():
        """Yield manifest entries with ``audio_filepath`` made absolute."""
        for entry in asr_manifest_reader(manifest_path):
            # pathlib: joining onto an already-absolute audio path returns it unchanged.
            entry["audio_filepath"] = str(dataset_dir / Path(entry["audio_filepath"]))
            yield entry

    asr_manifest_writer(real_manifest_path, fix_path())
|
|
|
|
|
|
@app.command()
def augment_data(src_dataset_paths: List[Path], dest_dataset_path: Path):
    """Merge the ``abs_manifest.json`` of several datasets into one manifest.

    Entries from the sources are chained in the order the directories are
    given. The combined stream is written to
    ``<dest_dataset_path>/abs_manifest.json``. The destination directory, and
    any missing parents, is created first.

    Args:
        src_dataset_paths: Dataset directories, each holding an ``abs_manifest.json``.
        dest_dataset_path: Directory that receives the merged ``abs_manifest.json``.
    """
    manifest_name = Path("abs_manifest.json")
    # Build every source reader before touching the destination directory.
    readers = [asr_manifest_reader(src / manifest_name) for src in src_dataset_paths]
    dest_dataset_path.mkdir(parents=True, exist_ok=True)
    asr_manifest_writer(dest_dataset_path / manifest_name, chain.from_iterable(readers))
|
|
|
|
|
|
@app.command()
def split_data(
    dataset_path: Path,
    test_size: float = 0.1,
    random_state: Optional[int] = None,
):
    """Randomly split ``abs_manifest.json`` into train and test manifests.

    The entries of ``<dataset_path>/abs_manifest.json`` are split with
    ``sklearn.model_selection.train_test_split``. The results are written
    alongside it as ``train_manifest.json`` and ``test_manifest.json``.

    Args:
        dataset_path: Directory containing ``abs_manifest.json``.
        test_size: Fraction of entries (between 0 and 1) placed in the test split.
        random_state: Seed for the shuffle, so a split can be reproduced. With
            ``None`` (the default) the split differs on every run.
    """
    manifest_path = dataset_path / Path("abs_manifest.json")
    asr_data = list(asr_manifest_reader(manifest_path))
    train_data, test_data = train_test_split(
        asr_data, test_size=test_size, random_state=random_state
    )
    asr_manifest_writer(manifest_path.with_name("train_manifest.json"), train_data)
    asr_manifest_writer(manifest_path.with_name("test_manifest.json"), test_data)
|
|
|
|
|
|
@app.command()
def validate_data(dataset_path: Path):
    """Check ``train_manifest.json`` and ``test_manifest.json`` under ``dataset_path``.

    Each non-blank line of a manifest is checked as follows:

    - It is parsed as JSON.
    - Its ``audio_filepath`` is joined onto the manifest's directory and must
      exist. Absolute paths are used as-is.
    - Its ``duration`` is then added to the manifest's total.

    Every bad line is reported with its 0-based index. Each manifest ends with
    a summary: "no errors found" only when every entry passed, otherwise the
    number of invalid entries. Both summaries give the audio duration of the
    valid entries.

    Args:
        dataset_path: Directory containing both manifests (as written by
            ``split_data``).

    Raises:
        OSError: If a manifest file itself cannot be opened.
    """
    from natural.date import compress
    from datetime import timedelta

    for mf_type in ["train_manifest.json", "test_manifest.json"]:
        data_file = dataset_path / Path(mf_type)
        print(f"validating {data_file}.")
        duration = 0
        error_count = 0
        with data_file.open("r") as pf:
            for i, line in enumerate(pf):
                if not line.strip():
                    # A blank line (e.g. a doubled trailing newline) holds no entry.
                    continue
                try:
                    d = json.loads(line)
                    audio_file = data_file.parent / Path(d["audio_filepath"])
                    if not audio_file.exists():
                        raise OSError(f"File {audio_file} not found")
                    # Summed last so entries that fail any check are not counted.
                    duration += d["duration"]
                except Exception as e:
                    # Exception, not BaseException: Ctrl-C / SystemExit must still
                    # abort the run rather than be reported as a bad line.
                    error_count += 1
                    print(f'failed on {i} with "{e}"')
        duration_str = compress(timedelta(seconds=duration), pad=" ")
        if error_count:
            print(
                f"found {error_count} invalid entries in {mf_type}. "
                f"valid entries contain {duration_str}sec of audio"
            )
        else:
            print(
                f"no errors found. seems like a valid {mf_type}. contains {duration_str}sec of audio"
            )
|
|
|
|
|
|
def main():
    """Console entry point: run the Typer CLI held in ``app``."""
    app()


# Allow running the module directly as well as via a console-script entry point.
if __name__ == "__main__":
    main()
|