1. fix update-correction to use ui_dump instead of manifest
2. update training params no of checkpoints on chpk frequency
parent
000853b600
commit
e76ccda5dd
|
|
@ -23,7 +23,7 @@ def fixate_data(dataset_path: Path):
|
|||
|
||||
|
||||
@app.command()
|
||||
def augment_datasets(src_dataset_paths: List[Path], dest_dataset_path: Path):
|
||||
def augment_data(src_dataset_paths: List[Path], dest_dataset_path: Path):
|
||||
reader_list = []
|
||||
abs_manifest_path = Path("abs_manifest.json")
|
||||
for dataset_path in src_dataset_paths:
|
||||
|
|
|
|||
|
|
@ -59,6 +59,10 @@ def alnum_to_asr_tokens(text):
|
|||
return ("".join(num_tokens)).lower()
|
||||
|
||||
|
||||
def tscript_uuid_fname(transcript):
|
||||
return str(uuid4()) + "_" + slugify(transcript, max_length=8)
|
||||
|
||||
|
||||
def asr_data_writer(output_dir, dataset_name, asr_data_source, verbose=False):
|
||||
dataset_dir = output_dir / Path(dataset_name)
|
||||
(dataset_dir / Path("wav")).mkdir(parents=True, exist_ok=True)
|
||||
|
|
@ -67,7 +71,7 @@ def asr_data_writer(output_dir, dataset_name, asr_data_source, verbose=False):
|
|||
with asr_manifest.open("w") as mf:
|
||||
print(f"writing manifest to {asr_manifest}")
|
||||
for transcript, audio_dur, wav_data in asr_data_source:
|
||||
fname = str(uuid4()) + "_" + slugify(transcript, max_length=8)
|
||||
fname = tscript_uuid_fname(transcript)
|
||||
audio_file = dataset_dir / Path("wav") / Path(fname).with_suffix(".wav")
|
||||
audio_file.write_bytes(wav_data)
|
||||
rel_pnr_path = audio_file.relative_to(dataset_dir)
|
||||
|
|
@ -174,7 +178,7 @@ def asr_manifest_reader(data_manifest_path: Path):
|
|||
pnr_data = [json.loads(v) for v in pnr_jsonl]
|
||||
for p in pnr_data:
|
||||
p["audio_path"] = data_manifest_path.parent / Path(p["audio_filepath"])
|
||||
p["chars"] = Path(p["audio_filepath"]).stem
|
||||
p["text"] = p["text"].strip()
|
||||
yield p
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from ..utils import (
|
|||
ExtendedPath,
|
||||
asr_manifest_reader,
|
||||
asr_manifest_writer,
|
||||
tscript_uuid_fname,
|
||||
get_mongo_conn,
|
||||
plot_seg,
|
||||
)
|
||||
|
|
@ -262,9 +263,9 @@ class ExtractionType(str, Enum):
|
|||
def split_extract(
|
||||
data_name: str = typer.Option("call_alphanum", show_default=True),
|
||||
# dest_data_name: str = typer.Option("call_aldata_namephanum_date", show_default=True),
|
||||
dump_dir: Path = Path("./data/valiation_data"),
|
||||
# dump_dir: Path = Path("./data/valiation_data"),
|
||||
dump_dir: Path = Path("./data/asr_data"),
|
||||
dump_file: Path = Path("ui_dump.json"),
|
||||
manifest_dir: Path = Path("./data/asr_data"),
|
||||
manifest_file: Path = Path("manifest.json"),
|
||||
corrections_file: Path = Path("corrections.json"),
|
||||
conv_data_path: Path = Path("./data/conv_data.json"),
|
||||
|
|
@ -272,33 +273,18 @@ def split_extract(
|
|||
):
|
||||
import shutil
|
||||
|
||||
def get_conv_data(cdp):
|
||||
from itertools import product
|
||||
|
||||
conv_data = json.load(cdp.open())
|
||||
days = [str(i) for i in range(1, 32)]
|
||||
months = conv_data["months"]
|
||||
day_months = {d + " " + m for d, m in product(days, months)}
|
||||
return {
|
||||
"cities": set(conv_data["cities"]),
|
||||
"names": set(conv_data["names"]),
|
||||
"dates": day_months,
|
||||
}
|
||||
|
||||
dest_data_name = data_name + "_" + extraction_type.value
|
||||
data_manifest_path = manifest_dir / Path(data_name) / manifest_file
|
||||
conv_data = get_conv_data(conv_data_path)
|
||||
data_manifest_path = dump_dir / Path(data_name) / manifest_file
|
||||
conv_data = ExtendedPath(conv_data_path).read_json()
|
||||
extraction_vals = conv_data[extraction_type.value]
|
||||
|
||||
manifest_gen = asr_manifest_reader(data_manifest_path)
|
||||
dest_data_dir = manifest_dir / Path(dest_data_name)
|
||||
dest_data_dir = dump_dir / Path(dest_data_name)
|
||||
dest_data_dir.mkdir(exist_ok=True, parents=True)
|
||||
(dest_data_dir / Path("wav")).mkdir(exist_ok=True, parents=True)
|
||||
dest_manifest_path = dest_data_dir / manifest_file
|
||||
dest_ui_dir = dump_dir / Path(dest_data_name)
|
||||
dest_ui_dir.mkdir(exist_ok=True, parents=True)
|
||||
dest_ui_path = dest_ui_dir / dump_file
|
||||
dest_correction_path = dest_ui_dir / corrections_file
|
||||
dest_ui_path = dest_data_dir / dump_file
|
||||
dest_correction_path = dest_data_dir / corrections_file
|
||||
|
||||
def extract_manifest(mg):
|
||||
for m in mg:
|
||||
|
|
@ -330,19 +316,19 @@ def split_extract(
|
|||
@app.command()
|
||||
def update_corrections(
|
||||
data_name: str = typer.Option("call_alphanum", show_default=True),
|
||||
dump_dir: Path = Path("./data/valiation_data"),
|
||||
manifest_dir: Path = Path("./data/asr_data"),
|
||||
dump_dir: Path = Path("./data/asr_data"),
|
||||
manifest_file: Path = Path("manifest.json"),
|
||||
corrections_file: Path = Path("corrections.json"),
|
||||
# data_manifest_path: Path = Path("./data/asr_data/call_alphanum/manifest.json"),
|
||||
# corrections_path: Path = Path("./data/valiation_data/corrections.json"),
|
||||
ui_dump_file: Path = Path("ui_dump.json"),
|
||||
skip_incorrect: bool = True,
|
||||
):
|
||||
data_manifest_path = manifest_dir / Path(data_name) / manifest_file
|
||||
data_manifest_path = dump_dir / Path(data_name) / manifest_file
|
||||
corrections_path = dump_dir / Path(data_name) / corrections_file
|
||||
ui_dump_path = dump_dir / Path(data_name) / ui_dump_file
|
||||
|
||||
def correct_manifest(manifest_data_gen, corrections_path):
|
||||
corrections = json.load(corrections_path.open())
|
||||
def correct_manifest(ui_dump_path, corrections_path):
|
||||
corrections = ExtendedPath(corrections_path).read_json()
|
||||
ui_data = ExtendedPath(ui_dump_path).read_json()['data']
|
||||
correct_set = {
|
||||
c["code"] for c in corrections if c["value"]["status"] == "Correct"
|
||||
}
|
||||
|
|
@ -355,36 +341,38 @@ def update_corrections(
|
|||
# for d in manifest_data_gen:
|
||||
# if d["chars"] in incorrect_set:
|
||||
# d["audio_path"].unlink()
|
||||
renamed_set = set()
|
||||
for d in manifest_data_gen:
|
||||
if d["chars"] in correct_set:
|
||||
# renamed_set = set()
|
||||
for d in ui_data:
|
||||
if d["utterance_id"] in correct_set:
|
||||
yield {
|
||||
"audio_filepath": d["audio_filepath"],
|
||||
"duration": d["duration"],
|
||||
"text": d["text"],
|
||||
}
|
||||
elif d["chars"] in correction_map:
|
||||
correct_text = correction_map[d["chars"]]
|
||||
elif d["utterance_id"] in correction_map:
|
||||
correct_text = correction_map[d["utterance_id"]]
|
||||
if skip_incorrect:
|
||||
print(
|
||||
f'skipping incorrect {d["audio_path"]} corrected to {correct_text}'
|
||||
)
|
||||
else:
|
||||
renamed_set.add(correct_text)
|
||||
new_name = str(Path(correct_text).with_suffix(".wav"))
|
||||
d["audio_path"].replace(d["audio_path"].with_name(new_name))
|
||||
orig_audio_path = Path(d["audio_path"])
|
||||
new_name = str(Path(tscript_uuid_fname(correct_text)).with_suffix(".wav"))
|
||||
new_audio_path = orig_audio_path.with_name(new_name)
|
||||
orig_audio_path.replace(new_audio_path)
|
||||
new_filepath = str(Path(d["audio_filepath"]).with_name(new_name))
|
||||
yield {
|
||||
"audio_filepath": new_filepath,
|
||||
"duration": d["duration"],
|
||||
"text": alnum_to_asr_tokens(correct_text),
|
||||
"text": correct_text,
|
||||
}
|
||||
else:
|
||||
orig_audio_path = Path(d["audio_path"])
|
||||
# don't delete if another correction points to an old file
|
||||
if d["chars"] not in renamed_set:
|
||||
d["audio_path"].unlink()
|
||||
else:
|
||||
print(f'skipping deletion of correction:{d["chars"]}')
|
||||
# if d["text"] not in renamed_set:
|
||||
orig_audio_path.unlink()
|
||||
# else:
|
||||
# print(f'skipping deletion of correction:{d["text"]}')
|
||||
|
||||
typer.echo(f"Using data manifest:{data_manifest_path}")
|
||||
dataset_dir = data_manifest_path.parent
|
||||
|
|
@ -393,8 +381,8 @@ def update_corrections(
|
|||
if not backup_dir.exists():
|
||||
typer.echo(f"backing up to :{backup_dir}")
|
||||
shutil.copytree(str(dataset_dir), str(backup_dir))
|
||||
manifest_gen = asr_manifest_reader(data_manifest_path)
|
||||
corrected_manifest = correct_manifest(manifest_gen, corrections_path)
|
||||
# manifest_gen = asr_manifest_reader(data_manifest_path)
|
||||
corrected_manifest = correct_manifest(ui_dump_path, corrections_path)
|
||||
new_data_manifest_path = data_manifest_path.with_name("manifest.new")
|
||||
asr_manifest_writer(new_data_manifest_path, corrected_manifest)
|
||||
new_data_manifest_path.replace(data_manifest_path)
|
||||
|
|
|
|||
|
|
@ -41,7 +41,7 @@ def parse_args():
|
|||
work_dir="./train/work",
|
||||
num_epochs=300,
|
||||
weight_decay=0.005,
|
||||
checkpoint_save_freq=200,
|
||||
checkpoint_save_freq=100,
|
||||
eval_freq=100,
|
||||
load_dir="./train/models/jasper/",
|
||||
warmup_steps=3,
|
||||
|
|
@ -266,6 +266,7 @@ def create_all_dags(args, neural_factory):
|
|||
folder=neural_factory.checkpoint_dir,
|
||||
load_from_folder=args.load_dir,
|
||||
step_freq=args.checkpoint_save_freq,
|
||||
checkpoints_to_keep=30,
|
||||
)
|
||||
|
||||
callbacks = [train_callback, chpt_callback]
|
||||
|
|
|
|||
Loading…
Reference in New Issue