1. Fix update_corrections to use ui_dump instead of manifest

2. Update training params: checkpoint save frequency and number of checkpoints to keep
Malar Kannan 2020-06-19 14:16:04 +05:30
parent 000853b600
commit e76ccda5dd
4 changed files with 41 additions and 48 deletions

View File

@ -23,7 +23,7 @@ def fixate_data(dataset_path: Path):
@app.command() @app.command()
def augment_datasets(src_dataset_paths: List[Path], dest_dataset_path: Path): def augment_data(src_dataset_paths: List[Path], dest_dataset_path: Path):
reader_list = [] reader_list = []
abs_manifest_path = Path("abs_manifest.json") abs_manifest_path = Path("abs_manifest.json")
for dataset_path in src_dataset_paths: for dataset_path in src_dataset_paths:

View File

@ -59,6 +59,10 @@ def alnum_to_asr_tokens(text):
return ("".join(num_tokens)).lower() return ("".join(num_tokens)).lower()
def tscript_uuid_fname(transcript):
return str(uuid4()) + "_" + slugify(transcript, max_length=8)
def asr_data_writer(output_dir, dataset_name, asr_data_source, verbose=False): def asr_data_writer(output_dir, dataset_name, asr_data_source, verbose=False):
dataset_dir = output_dir / Path(dataset_name) dataset_dir = output_dir / Path(dataset_name)
(dataset_dir / Path("wav")).mkdir(parents=True, exist_ok=True) (dataset_dir / Path("wav")).mkdir(parents=True, exist_ok=True)
@ -67,7 +71,7 @@ def asr_data_writer(output_dir, dataset_name, asr_data_source, verbose=False):
with asr_manifest.open("w") as mf: with asr_manifest.open("w") as mf:
print(f"writing manifest to {asr_manifest}") print(f"writing manifest to {asr_manifest}")
for transcript, audio_dur, wav_data in asr_data_source: for transcript, audio_dur, wav_data in asr_data_source:
fname = str(uuid4()) + "_" + slugify(transcript, max_length=8) fname = tscript_uuid_fname(transcript)
audio_file = dataset_dir / Path("wav") / Path(fname).with_suffix(".wav") audio_file = dataset_dir / Path("wav") / Path(fname).with_suffix(".wav")
audio_file.write_bytes(wav_data) audio_file.write_bytes(wav_data)
rel_pnr_path = audio_file.relative_to(dataset_dir) rel_pnr_path = audio_file.relative_to(dataset_dir)
@ -174,7 +178,7 @@ def asr_manifest_reader(data_manifest_path: Path):
pnr_data = [json.loads(v) for v in pnr_jsonl] pnr_data = [json.loads(v) for v in pnr_jsonl]
for p in pnr_data: for p in pnr_data:
p["audio_path"] = data_manifest_path.parent / Path(p["audio_filepath"]) p["audio_path"] = data_manifest_path.parent / Path(p["audio_filepath"])
p["chars"] = Path(p["audio_filepath"]).stem p["text"] = p["text"].strip()
yield p yield p

View File

@ -11,6 +11,7 @@ from ..utils import (
ExtendedPath, ExtendedPath,
asr_manifest_reader, asr_manifest_reader,
asr_manifest_writer, asr_manifest_writer,
tscript_uuid_fname,
get_mongo_conn, get_mongo_conn,
plot_seg, plot_seg,
) )
@ -262,9 +263,9 @@ class ExtractionType(str, Enum):
def split_extract( def split_extract(
data_name: str = typer.Option("call_alphanum", show_default=True), data_name: str = typer.Option("call_alphanum", show_default=True),
# dest_data_name: str = typer.Option("call_aldata_namephanum_date", show_default=True), # dest_data_name: str = typer.Option("call_aldata_namephanum_date", show_default=True),
dump_dir: Path = Path("./data/valiation_data"), # dump_dir: Path = Path("./data/valiation_data"),
dump_dir: Path = Path("./data/asr_data"),
dump_file: Path = Path("ui_dump.json"), dump_file: Path = Path("ui_dump.json"),
manifest_dir: Path = Path("./data/asr_data"),
manifest_file: Path = Path("manifest.json"), manifest_file: Path = Path("manifest.json"),
corrections_file: Path = Path("corrections.json"), corrections_file: Path = Path("corrections.json"),
conv_data_path: Path = Path("./data/conv_data.json"), conv_data_path: Path = Path("./data/conv_data.json"),
@ -272,33 +273,18 @@ def split_extract(
): ):
import shutil import shutil
def get_conv_data(cdp):
from itertools import product
conv_data = json.load(cdp.open())
days = [str(i) for i in range(1, 32)]
months = conv_data["months"]
day_months = {d + " " + m for d, m in product(days, months)}
return {
"cities": set(conv_data["cities"]),
"names": set(conv_data["names"]),
"dates": day_months,
}
dest_data_name = data_name + "_" + extraction_type.value dest_data_name = data_name + "_" + extraction_type.value
data_manifest_path = manifest_dir / Path(data_name) / manifest_file data_manifest_path = dump_dir / Path(data_name) / manifest_file
conv_data = get_conv_data(conv_data_path) conv_data = ExtendedPath(conv_data_path).read_json()
extraction_vals = conv_data[extraction_type.value] extraction_vals = conv_data[extraction_type.value]
manifest_gen = asr_manifest_reader(data_manifest_path) manifest_gen = asr_manifest_reader(data_manifest_path)
dest_data_dir = manifest_dir / Path(dest_data_name) dest_data_dir = dump_dir / Path(dest_data_name)
dest_data_dir.mkdir(exist_ok=True, parents=True) dest_data_dir.mkdir(exist_ok=True, parents=True)
(dest_data_dir / Path("wav")).mkdir(exist_ok=True, parents=True) (dest_data_dir / Path("wav")).mkdir(exist_ok=True, parents=True)
dest_manifest_path = dest_data_dir / manifest_file dest_manifest_path = dest_data_dir / manifest_file
dest_ui_dir = dump_dir / Path(dest_data_name) dest_ui_path = dest_data_dir / dump_file
dest_ui_dir.mkdir(exist_ok=True, parents=True) dest_correction_path = dest_data_dir / corrections_file
dest_ui_path = dest_ui_dir / dump_file
dest_correction_path = dest_ui_dir / corrections_file
def extract_manifest(mg): def extract_manifest(mg):
for m in mg: for m in mg:
@ -330,19 +316,19 @@ def split_extract(
@app.command() @app.command()
def update_corrections( def update_corrections(
data_name: str = typer.Option("call_alphanum", show_default=True), data_name: str = typer.Option("call_alphanum", show_default=True),
dump_dir: Path = Path("./data/valiation_data"), dump_dir: Path = Path("./data/asr_data"),
manifest_dir: Path = Path("./data/asr_data"),
manifest_file: Path = Path("manifest.json"), manifest_file: Path = Path("manifest.json"),
corrections_file: Path = Path("corrections.json"), corrections_file: Path = Path("corrections.json"),
# data_manifest_path: Path = Path("./data/asr_data/call_alphanum/manifest.json"), ui_dump_file: Path = Path("ui_dump.json"),
# corrections_path: Path = Path("./data/valiation_data/corrections.json"),
skip_incorrect: bool = True, skip_incorrect: bool = True,
): ):
data_manifest_path = manifest_dir / Path(data_name) / manifest_file data_manifest_path = dump_dir / Path(data_name) / manifest_file
corrections_path = dump_dir / Path(data_name) / corrections_file corrections_path = dump_dir / Path(data_name) / corrections_file
ui_dump_path = dump_dir / Path(data_name) / ui_dump_file
def correct_manifest(manifest_data_gen, corrections_path): def correct_manifest(ui_dump_path, corrections_path):
corrections = json.load(corrections_path.open()) corrections = ExtendedPath(corrections_path).read_json()
ui_data = ExtendedPath(ui_dump_path).read_json()['data']
correct_set = { correct_set = {
c["code"] for c in corrections if c["value"]["status"] == "Correct" c["code"] for c in corrections if c["value"]["status"] == "Correct"
} }
@ -355,36 +341,38 @@ def update_corrections(
# for d in manifest_data_gen: # for d in manifest_data_gen:
# if d["chars"] in incorrect_set: # if d["chars"] in incorrect_set:
# d["audio_path"].unlink() # d["audio_path"].unlink()
renamed_set = set() # renamed_set = set()
for d in manifest_data_gen: for d in ui_data:
if d["chars"] in correct_set: if d["utterance_id"] in correct_set:
yield { yield {
"audio_filepath": d["audio_filepath"], "audio_filepath": d["audio_filepath"],
"duration": d["duration"], "duration": d["duration"],
"text": d["text"], "text": d["text"],
} }
elif d["chars"] in correction_map: elif d["utterance_id"] in correction_map:
correct_text = correction_map[d["chars"]] correct_text = correction_map[d["utterance_id"]]
if skip_incorrect: if skip_incorrect:
print( print(
f'skipping incorrect {d["audio_path"]} corrected to {correct_text}' f'skipping incorrect {d["audio_path"]} corrected to {correct_text}'
) )
else: else:
renamed_set.add(correct_text) orig_audio_path = Path(d["audio_path"])
new_name = str(Path(correct_text).with_suffix(".wav")) new_name = str(Path(tscript_uuid_fname(correct_text)).with_suffix(".wav"))
d["audio_path"].replace(d["audio_path"].with_name(new_name)) new_audio_path = orig_audio_path.with_name(new_name)
orig_audio_path.replace(new_audio_path)
new_filepath = str(Path(d["audio_filepath"]).with_name(new_name)) new_filepath = str(Path(d["audio_filepath"]).with_name(new_name))
yield { yield {
"audio_filepath": new_filepath, "audio_filepath": new_filepath,
"duration": d["duration"], "duration": d["duration"],
"text": alnum_to_asr_tokens(correct_text), "text": correct_text,
} }
else: else:
orig_audio_path = Path(d["audio_path"])
# don't delete if another correction points to an old file # don't delete if another correction points to an old file
if d["chars"] not in renamed_set: # if d["text"] not in renamed_set:
d["audio_path"].unlink() orig_audio_path.unlink()
else: # else:
print(f'skipping deletion of correction:{d["chars"]}') # print(f'skipping deletion of correction:{d["text"]}')
typer.echo(f"Using data manifest:{data_manifest_path}") typer.echo(f"Using data manifest:{data_manifest_path}")
dataset_dir = data_manifest_path.parent dataset_dir = data_manifest_path.parent
@ -393,8 +381,8 @@ def update_corrections(
if not backup_dir.exists(): if not backup_dir.exists():
typer.echo(f"backing up to :{backup_dir}") typer.echo(f"backing up to :{backup_dir}")
shutil.copytree(str(dataset_dir), str(backup_dir)) shutil.copytree(str(dataset_dir), str(backup_dir))
manifest_gen = asr_manifest_reader(data_manifest_path) # manifest_gen = asr_manifest_reader(data_manifest_path)
corrected_manifest = correct_manifest(manifest_gen, corrections_path) corrected_manifest = correct_manifest(ui_dump_path, corrections_path)
new_data_manifest_path = data_manifest_path.with_name("manifest.new") new_data_manifest_path = data_manifest_path.with_name("manifest.new")
asr_manifest_writer(new_data_manifest_path, corrected_manifest) asr_manifest_writer(new_data_manifest_path, corrected_manifest)
new_data_manifest_path.replace(data_manifest_path) new_data_manifest_path.replace(data_manifest_path)

View File

@ -41,7 +41,7 @@ def parse_args():
work_dir="./train/work", work_dir="./train/work",
num_epochs=300, num_epochs=300,
weight_decay=0.005, weight_decay=0.005,
checkpoint_save_freq=200, checkpoint_save_freq=100,
eval_freq=100, eval_freq=100,
load_dir="./train/models/jasper/", load_dir="./train/models/jasper/",
warmup_steps=3, warmup_steps=3,
@ -266,6 +266,7 @@ def create_all_dags(args, neural_factory):
folder=neural_factory.checkpoint_dir, folder=neural_factory.checkpoint_dir,
load_from_folder=args.load_dir, load_from_folder=args.load_dir,
step_freq=args.checkpoint_save_freq, step_freq=args.checkpoint_save_freq,
checkpoints_to_keep=30,
) )
callbacks = [train_callback, chpt_callback] callbacks = [train_callback, chpt_callback]