1. fix update_corrections to use ui_dump instead of manifest

2. update training params: reduce checkpoint save frequency and set number of checkpoints to keep
Malar Kannan 2020-06-19 14:16:04 +05:30
parent 000853b600
commit e76ccda5dd
4 changed files with 41 additions and 48 deletions

View File

@ -23,7 +23,7 @@ def fixate_data(dataset_path: Path):
@app.command()
def augment_datasets(src_dataset_paths: List[Path], dest_dataset_path: Path):
def augment_data(src_dataset_paths: List[Path], dest_dataset_path: Path):
reader_list = []
abs_manifest_path = Path("abs_manifest.json")
for dataset_path in src_dataset_paths:

View File

@ -59,6 +59,10 @@ def alnum_to_asr_tokens(text):
return ("".join(num_tokens)).lower()
def tscript_uuid_fname(transcript):
return str(uuid4()) + "_" + slugify(transcript, max_length=8)
def asr_data_writer(output_dir, dataset_name, asr_data_source, verbose=False):
dataset_dir = output_dir / Path(dataset_name)
(dataset_dir / Path("wav")).mkdir(parents=True, exist_ok=True)
@ -67,7 +71,7 @@ def asr_data_writer(output_dir, dataset_name, asr_data_source, verbose=False):
with asr_manifest.open("w") as mf:
print(f"writing manifest to {asr_manifest}")
for transcript, audio_dur, wav_data in asr_data_source:
fname = str(uuid4()) + "_" + slugify(transcript, max_length=8)
fname = tscript_uuid_fname(transcript)
audio_file = dataset_dir / Path("wav") / Path(fname).with_suffix(".wav")
audio_file.write_bytes(wav_data)
rel_pnr_path = audio_file.relative_to(dataset_dir)
@ -174,7 +178,7 @@ def asr_manifest_reader(data_manifest_path: Path):
pnr_data = [json.loads(v) for v in pnr_jsonl]
for p in pnr_data:
p["audio_path"] = data_manifest_path.parent / Path(p["audio_filepath"])
p["chars"] = Path(p["audio_filepath"]).stem
p["text"] = p["text"].strip()
yield p

View File

@ -11,6 +11,7 @@ from ..utils import (
ExtendedPath,
asr_manifest_reader,
asr_manifest_writer,
tscript_uuid_fname,
get_mongo_conn,
plot_seg,
)
@ -262,9 +263,9 @@ class ExtractionType(str, Enum):
def split_extract(
data_name: str = typer.Option("call_alphanum", show_default=True),
# dest_data_name: str = typer.Option("call_aldata_namephanum_date", show_default=True),
dump_dir: Path = Path("./data/valiation_data"),
# dump_dir: Path = Path("./data/valiation_data"),
dump_dir: Path = Path("./data/asr_data"),
dump_file: Path = Path("ui_dump.json"),
manifest_dir: Path = Path("./data/asr_data"),
manifest_file: Path = Path("manifest.json"),
corrections_file: Path = Path("corrections.json"),
conv_data_path: Path = Path("./data/conv_data.json"),
@ -272,33 +273,18 @@ def split_extract(
):
import shutil
def get_conv_data(cdp):
from itertools import product
conv_data = json.load(cdp.open())
days = [str(i) for i in range(1, 32)]
months = conv_data["months"]
day_months = {d + " " + m for d, m in product(days, months)}
return {
"cities": set(conv_data["cities"]),
"names": set(conv_data["names"]),
"dates": day_months,
}
dest_data_name = data_name + "_" + extraction_type.value
data_manifest_path = manifest_dir / Path(data_name) / manifest_file
conv_data = get_conv_data(conv_data_path)
data_manifest_path = dump_dir / Path(data_name) / manifest_file
conv_data = ExtendedPath(conv_data_path).read_json()
extraction_vals = conv_data[extraction_type.value]
manifest_gen = asr_manifest_reader(data_manifest_path)
dest_data_dir = manifest_dir / Path(dest_data_name)
dest_data_dir = dump_dir / Path(dest_data_name)
dest_data_dir.mkdir(exist_ok=True, parents=True)
(dest_data_dir / Path("wav")).mkdir(exist_ok=True, parents=True)
dest_manifest_path = dest_data_dir / manifest_file
dest_ui_dir = dump_dir / Path(dest_data_name)
dest_ui_dir.mkdir(exist_ok=True, parents=True)
dest_ui_path = dest_ui_dir / dump_file
dest_correction_path = dest_ui_dir / corrections_file
dest_ui_path = dest_data_dir / dump_file
dest_correction_path = dest_data_dir / corrections_file
def extract_manifest(mg):
for m in mg:
@ -330,19 +316,19 @@ def split_extract(
@app.command()
def update_corrections(
data_name: str = typer.Option("call_alphanum", show_default=True),
dump_dir: Path = Path("./data/valiation_data"),
manifest_dir: Path = Path("./data/asr_data"),
dump_dir: Path = Path("./data/asr_data"),
manifest_file: Path = Path("manifest.json"),
corrections_file: Path = Path("corrections.json"),
# data_manifest_path: Path = Path("./data/asr_data/call_alphanum/manifest.json"),
# corrections_path: Path = Path("./data/valiation_data/corrections.json"),
ui_dump_file: Path = Path("ui_dump.json"),
skip_incorrect: bool = True,
):
data_manifest_path = manifest_dir / Path(data_name) / manifest_file
data_manifest_path = dump_dir / Path(data_name) / manifest_file
corrections_path = dump_dir / Path(data_name) / corrections_file
ui_dump_path = dump_dir / Path(data_name) / ui_dump_file
def correct_manifest(manifest_data_gen, corrections_path):
corrections = json.load(corrections_path.open())
def correct_manifest(ui_dump_path, corrections_path):
corrections = ExtendedPath(corrections_path).read_json()
ui_data = ExtendedPath(ui_dump_path).read_json()['data']
correct_set = {
c["code"] for c in corrections if c["value"]["status"] == "Correct"
}
@ -355,36 +341,38 @@ def update_corrections(
# for d in manifest_data_gen:
# if d["chars"] in incorrect_set:
# d["audio_path"].unlink()
renamed_set = set()
for d in manifest_data_gen:
if d["chars"] in correct_set:
# renamed_set = set()
for d in ui_data:
if d["utterance_id"] in correct_set:
yield {
"audio_filepath": d["audio_filepath"],
"duration": d["duration"],
"text": d["text"],
}
elif d["chars"] in correction_map:
correct_text = correction_map[d["chars"]]
elif d["utterance_id"] in correction_map:
correct_text = correction_map[d["utterance_id"]]
if skip_incorrect:
print(
f'skipping incorrect {d["audio_path"]} corrected to {correct_text}'
)
else:
renamed_set.add(correct_text)
new_name = str(Path(correct_text).with_suffix(".wav"))
d["audio_path"].replace(d["audio_path"].with_name(new_name))
orig_audio_path = Path(d["audio_path"])
new_name = str(Path(tscript_uuid_fname(correct_text)).with_suffix(".wav"))
new_audio_path = orig_audio_path.with_name(new_name)
orig_audio_path.replace(new_audio_path)
new_filepath = str(Path(d["audio_filepath"]).with_name(new_name))
yield {
"audio_filepath": new_filepath,
"duration": d["duration"],
"text": alnum_to_asr_tokens(correct_text),
"text": correct_text,
}
else:
orig_audio_path = Path(d["audio_path"])
# don't delete if another correction points to an old file
if d["chars"] not in renamed_set:
d["audio_path"].unlink()
else:
print(f'skipping deletion of correction:{d["chars"]}')
# if d["text"] not in renamed_set:
orig_audio_path.unlink()
# else:
# print(f'skipping deletion of correction:{d["text"]}')
typer.echo(f"Using data manifest:{data_manifest_path}")
dataset_dir = data_manifest_path.parent
@ -393,8 +381,8 @@ def update_corrections(
if not backup_dir.exists():
typer.echo(f"backing up to :{backup_dir}")
shutil.copytree(str(dataset_dir), str(backup_dir))
manifest_gen = asr_manifest_reader(data_manifest_path)
corrected_manifest = correct_manifest(manifest_gen, corrections_path)
# manifest_gen = asr_manifest_reader(data_manifest_path)
corrected_manifest = correct_manifest(ui_dump_path, corrections_path)
new_data_manifest_path = data_manifest_path.with_name("manifest.new")
asr_manifest_writer(new_data_manifest_path, corrected_manifest)
new_data_manifest_path.replace(data_manifest_path)

View File

@ -41,7 +41,7 @@ def parse_args():
work_dir="./train/work",
num_epochs=300,
weight_decay=0.005,
checkpoint_save_freq=200,
checkpoint_save_freq=100,
eval_freq=100,
load_dir="./train/models/jasper/",
warmup_steps=3,
@ -266,6 +266,7 @@ def create_all_dags(args, neural_factory):
folder=neural_factory.checkpoint_dir,
load_from_folder=args.load_dir,
step_freq=args.checkpoint_save_freq,
checkpoints_to_keep=30,
)
callbacks = [train_callback, chpt_callback]