1. fix update-corrections to use ui_dump instead of manifest
2. update training params: number of checkpoints to keep and checkpoint save frequency
parent
000853b600
commit
e76ccda5dd
|
|
@ -23,7 +23,7 @@ def fixate_data(dataset_path: Path):
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
@app.command()
|
||||||
def augment_datasets(src_dataset_paths: List[Path], dest_dataset_path: Path):
|
def augment_data(src_dataset_paths: List[Path], dest_dataset_path: Path):
|
||||||
reader_list = []
|
reader_list = []
|
||||||
abs_manifest_path = Path("abs_manifest.json")
|
abs_manifest_path = Path("abs_manifest.json")
|
||||||
for dataset_path in src_dataset_paths:
|
for dataset_path in src_dataset_paths:
|
||||||
|
|
|
||||||
|
|
@ -59,6 +59,10 @@ def alnum_to_asr_tokens(text):
|
||||||
return ("".join(num_tokens)).lower()
|
return ("".join(num_tokens)).lower()
|
||||||
|
|
||||||
|
|
||||||
|
def tscript_uuid_fname(transcript):
|
||||||
|
return str(uuid4()) + "_" + slugify(transcript, max_length=8)
|
||||||
|
|
||||||
|
|
||||||
def asr_data_writer(output_dir, dataset_name, asr_data_source, verbose=False):
|
def asr_data_writer(output_dir, dataset_name, asr_data_source, verbose=False):
|
||||||
dataset_dir = output_dir / Path(dataset_name)
|
dataset_dir = output_dir / Path(dataset_name)
|
||||||
(dataset_dir / Path("wav")).mkdir(parents=True, exist_ok=True)
|
(dataset_dir / Path("wav")).mkdir(parents=True, exist_ok=True)
|
||||||
|
|
@ -67,7 +71,7 @@ def asr_data_writer(output_dir, dataset_name, asr_data_source, verbose=False):
|
||||||
with asr_manifest.open("w") as mf:
|
with asr_manifest.open("w") as mf:
|
||||||
print(f"writing manifest to {asr_manifest}")
|
print(f"writing manifest to {asr_manifest}")
|
||||||
for transcript, audio_dur, wav_data in asr_data_source:
|
for transcript, audio_dur, wav_data in asr_data_source:
|
||||||
fname = str(uuid4()) + "_" + slugify(transcript, max_length=8)
|
fname = tscript_uuid_fname(transcript)
|
||||||
audio_file = dataset_dir / Path("wav") / Path(fname).with_suffix(".wav")
|
audio_file = dataset_dir / Path("wav") / Path(fname).with_suffix(".wav")
|
||||||
audio_file.write_bytes(wav_data)
|
audio_file.write_bytes(wav_data)
|
||||||
rel_pnr_path = audio_file.relative_to(dataset_dir)
|
rel_pnr_path = audio_file.relative_to(dataset_dir)
|
||||||
|
|
@ -174,7 +178,7 @@ def asr_manifest_reader(data_manifest_path: Path):
|
||||||
pnr_data = [json.loads(v) for v in pnr_jsonl]
|
pnr_data = [json.loads(v) for v in pnr_jsonl]
|
||||||
for p in pnr_data:
|
for p in pnr_data:
|
||||||
p["audio_path"] = data_manifest_path.parent / Path(p["audio_filepath"])
|
p["audio_path"] = data_manifest_path.parent / Path(p["audio_filepath"])
|
||||||
p["chars"] = Path(p["audio_filepath"]).stem
|
p["text"] = p["text"].strip()
|
||||||
yield p
|
yield p
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -11,6 +11,7 @@ from ..utils import (
|
||||||
ExtendedPath,
|
ExtendedPath,
|
||||||
asr_manifest_reader,
|
asr_manifest_reader,
|
||||||
asr_manifest_writer,
|
asr_manifest_writer,
|
||||||
|
tscript_uuid_fname,
|
||||||
get_mongo_conn,
|
get_mongo_conn,
|
||||||
plot_seg,
|
plot_seg,
|
||||||
)
|
)
|
||||||
|
|
@ -262,9 +263,9 @@ class ExtractionType(str, Enum):
|
||||||
def split_extract(
|
def split_extract(
|
||||||
data_name: str = typer.Option("call_alphanum", show_default=True),
|
data_name: str = typer.Option("call_alphanum", show_default=True),
|
||||||
# dest_data_name: str = typer.Option("call_aldata_namephanum_date", show_default=True),
|
# dest_data_name: str = typer.Option("call_aldata_namephanum_date", show_default=True),
|
||||||
dump_dir: Path = Path("./data/valiation_data"),
|
# dump_dir: Path = Path("./data/valiation_data"),
|
||||||
|
dump_dir: Path = Path("./data/asr_data"),
|
||||||
dump_file: Path = Path("ui_dump.json"),
|
dump_file: Path = Path("ui_dump.json"),
|
||||||
manifest_dir: Path = Path("./data/asr_data"),
|
|
||||||
manifest_file: Path = Path("manifest.json"),
|
manifest_file: Path = Path("manifest.json"),
|
||||||
corrections_file: Path = Path("corrections.json"),
|
corrections_file: Path = Path("corrections.json"),
|
||||||
conv_data_path: Path = Path("./data/conv_data.json"),
|
conv_data_path: Path = Path("./data/conv_data.json"),
|
||||||
|
|
@ -272,33 +273,18 @@ def split_extract(
|
||||||
):
|
):
|
||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
def get_conv_data(cdp):
|
|
||||||
from itertools import product
|
|
||||||
|
|
||||||
conv_data = json.load(cdp.open())
|
|
||||||
days = [str(i) for i in range(1, 32)]
|
|
||||||
months = conv_data["months"]
|
|
||||||
day_months = {d + " " + m for d, m in product(days, months)}
|
|
||||||
return {
|
|
||||||
"cities": set(conv_data["cities"]),
|
|
||||||
"names": set(conv_data["names"]),
|
|
||||||
"dates": day_months,
|
|
||||||
}
|
|
||||||
|
|
||||||
dest_data_name = data_name + "_" + extraction_type.value
|
dest_data_name = data_name + "_" + extraction_type.value
|
||||||
data_manifest_path = manifest_dir / Path(data_name) / manifest_file
|
data_manifest_path = dump_dir / Path(data_name) / manifest_file
|
||||||
conv_data = get_conv_data(conv_data_path)
|
conv_data = ExtendedPath(conv_data_path).read_json()
|
||||||
extraction_vals = conv_data[extraction_type.value]
|
extraction_vals = conv_data[extraction_type.value]
|
||||||
|
|
||||||
manifest_gen = asr_manifest_reader(data_manifest_path)
|
manifest_gen = asr_manifest_reader(data_manifest_path)
|
||||||
dest_data_dir = manifest_dir / Path(dest_data_name)
|
dest_data_dir = dump_dir / Path(dest_data_name)
|
||||||
dest_data_dir.mkdir(exist_ok=True, parents=True)
|
dest_data_dir.mkdir(exist_ok=True, parents=True)
|
||||||
(dest_data_dir / Path("wav")).mkdir(exist_ok=True, parents=True)
|
(dest_data_dir / Path("wav")).mkdir(exist_ok=True, parents=True)
|
||||||
dest_manifest_path = dest_data_dir / manifest_file
|
dest_manifest_path = dest_data_dir / manifest_file
|
||||||
dest_ui_dir = dump_dir / Path(dest_data_name)
|
dest_ui_path = dest_data_dir / dump_file
|
||||||
dest_ui_dir.mkdir(exist_ok=True, parents=True)
|
dest_correction_path = dest_data_dir / corrections_file
|
||||||
dest_ui_path = dest_ui_dir / dump_file
|
|
||||||
dest_correction_path = dest_ui_dir / corrections_file
|
|
||||||
|
|
||||||
def extract_manifest(mg):
|
def extract_manifest(mg):
|
||||||
for m in mg:
|
for m in mg:
|
||||||
|
|
@ -330,19 +316,19 @@ def split_extract(
|
||||||
@app.command()
|
@app.command()
|
||||||
def update_corrections(
|
def update_corrections(
|
||||||
data_name: str = typer.Option("call_alphanum", show_default=True),
|
data_name: str = typer.Option("call_alphanum", show_default=True),
|
||||||
dump_dir: Path = Path("./data/valiation_data"),
|
dump_dir: Path = Path("./data/asr_data"),
|
||||||
manifest_dir: Path = Path("./data/asr_data"),
|
|
||||||
manifest_file: Path = Path("manifest.json"),
|
manifest_file: Path = Path("manifest.json"),
|
||||||
corrections_file: Path = Path("corrections.json"),
|
corrections_file: Path = Path("corrections.json"),
|
||||||
# data_manifest_path: Path = Path("./data/asr_data/call_alphanum/manifest.json"),
|
ui_dump_file: Path = Path("ui_dump.json"),
|
||||||
# corrections_path: Path = Path("./data/valiation_data/corrections.json"),
|
|
||||||
skip_incorrect: bool = True,
|
skip_incorrect: bool = True,
|
||||||
):
|
):
|
||||||
data_manifest_path = manifest_dir / Path(data_name) / manifest_file
|
data_manifest_path = dump_dir / Path(data_name) / manifest_file
|
||||||
corrections_path = dump_dir / Path(data_name) / corrections_file
|
corrections_path = dump_dir / Path(data_name) / corrections_file
|
||||||
|
ui_dump_path = dump_dir / Path(data_name) / ui_dump_file
|
||||||
|
|
||||||
def correct_manifest(manifest_data_gen, corrections_path):
|
def correct_manifest(ui_dump_path, corrections_path):
|
||||||
corrections = json.load(corrections_path.open())
|
corrections = ExtendedPath(corrections_path).read_json()
|
||||||
|
ui_data = ExtendedPath(ui_dump_path).read_json()['data']
|
||||||
correct_set = {
|
correct_set = {
|
||||||
c["code"] for c in corrections if c["value"]["status"] == "Correct"
|
c["code"] for c in corrections if c["value"]["status"] == "Correct"
|
||||||
}
|
}
|
||||||
|
|
@ -355,36 +341,38 @@ def update_corrections(
|
||||||
# for d in manifest_data_gen:
|
# for d in manifest_data_gen:
|
||||||
# if d["chars"] in incorrect_set:
|
# if d["chars"] in incorrect_set:
|
||||||
# d["audio_path"].unlink()
|
# d["audio_path"].unlink()
|
||||||
renamed_set = set()
|
# renamed_set = set()
|
||||||
for d in manifest_data_gen:
|
for d in ui_data:
|
||||||
if d["chars"] in correct_set:
|
if d["utterance_id"] in correct_set:
|
||||||
yield {
|
yield {
|
||||||
"audio_filepath": d["audio_filepath"],
|
"audio_filepath": d["audio_filepath"],
|
||||||
"duration": d["duration"],
|
"duration": d["duration"],
|
||||||
"text": d["text"],
|
"text": d["text"],
|
||||||
}
|
}
|
||||||
elif d["chars"] in correction_map:
|
elif d["utterance_id"] in correction_map:
|
||||||
correct_text = correction_map[d["chars"]]
|
correct_text = correction_map[d["utterance_id"]]
|
||||||
if skip_incorrect:
|
if skip_incorrect:
|
||||||
print(
|
print(
|
||||||
f'skipping incorrect {d["audio_path"]} corrected to {correct_text}'
|
f'skipping incorrect {d["audio_path"]} corrected to {correct_text}'
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
renamed_set.add(correct_text)
|
orig_audio_path = Path(d["audio_path"])
|
||||||
new_name = str(Path(correct_text).with_suffix(".wav"))
|
new_name = str(Path(tscript_uuid_fname(correct_text)).with_suffix(".wav"))
|
||||||
d["audio_path"].replace(d["audio_path"].with_name(new_name))
|
new_audio_path = orig_audio_path.with_name(new_name)
|
||||||
|
orig_audio_path.replace(new_audio_path)
|
||||||
new_filepath = str(Path(d["audio_filepath"]).with_name(new_name))
|
new_filepath = str(Path(d["audio_filepath"]).with_name(new_name))
|
||||||
yield {
|
yield {
|
||||||
"audio_filepath": new_filepath,
|
"audio_filepath": new_filepath,
|
||||||
"duration": d["duration"],
|
"duration": d["duration"],
|
||||||
"text": alnum_to_asr_tokens(correct_text),
|
"text": correct_text,
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
|
orig_audio_path = Path(d["audio_path"])
|
||||||
# don't delete if another correction points to an old file
|
# don't delete if another correction points to an old file
|
||||||
if d["chars"] not in renamed_set:
|
# if d["text"] not in renamed_set:
|
||||||
d["audio_path"].unlink()
|
orig_audio_path.unlink()
|
||||||
else:
|
# else:
|
||||||
print(f'skipping deletion of correction:{d["chars"]}')
|
# print(f'skipping deletion of correction:{d["text"]}')
|
||||||
|
|
||||||
typer.echo(f"Using data manifest:{data_manifest_path}")
|
typer.echo(f"Using data manifest:{data_manifest_path}")
|
||||||
dataset_dir = data_manifest_path.parent
|
dataset_dir = data_manifest_path.parent
|
||||||
|
|
@ -393,8 +381,8 @@ def update_corrections(
|
||||||
if not backup_dir.exists():
|
if not backup_dir.exists():
|
||||||
typer.echo(f"backing up to :{backup_dir}")
|
typer.echo(f"backing up to :{backup_dir}")
|
||||||
shutil.copytree(str(dataset_dir), str(backup_dir))
|
shutil.copytree(str(dataset_dir), str(backup_dir))
|
||||||
manifest_gen = asr_manifest_reader(data_manifest_path)
|
# manifest_gen = asr_manifest_reader(data_manifest_path)
|
||||||
corrected_manifest = correct_manifest(manifest_gen, corrections_path)
|
corrected_manifest = correct_manifest(ui_dump_path, corrections_path)
|
||||||
new_data_manifest_path = data_manifest_path.with_name("manifest.new")
|
new_data_manifest_path = data_manifest_path.with_name("manifest.new")
|
||||||
asr_manifest_writer(new_data_manifest_path, corrected_manifest)
|
asr_manifest_writer(new_data_manifest_path, corrected_manifest)
|
||||||
new_data_manifest_path.replace(data_manifest_path)
|
new_data_manifest_path.replace(data_manifest_path)
|
||||||
|
|
|
||||||
|
|
@ -41,7 +41,7 @@ def parse_args():
|
||||||
work_dir="./train/work",
|
work_dir="./train/work",
|
||||||
num_epochs=300,
|
num_epochs=300,
|
||||||
weight_decay=0.005,
|
weight_decay=0.005,
|
||||||
checkpoint_save_freq=200,
|
checkpoint_save_freq=100,
|
||||||
eval_freq=100,
|
eval_freq=100,
|
||||||
load_dir="./train/models/jasper/",
|
load_dir="./train/models/jasper/",
|
||||||
warmup_steps=3,
|
warmup_steps=3,
|
||||||
|
|
@ -266,6 +266,7 @@ def create_all_dags(args, neural_factory):
|
||||||
folder=neural_factory.checkpoint_dir,
|
folder=neural_factory.checkpoint_dir,
|
||||||
load_from_folder=args.load_dir,
|
load_from_folder=args.load_dir,
|
||||||
step_freq=args.checkpoint_save_freq,
|
step_freq=args.checkpoint_save_freq,
|
||||||
|
checkpoints_to_keep=30,
|
||||||
)
|
)
|
||||||
|
|
||||||
callbacks = [train_callback, chpt_callback]
|
callbacks = [train_callback, chpt_callback]
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue