1. implement dataset augmentation and validation in process
2. added option to skip 'incorrect' annotations in validation data
3. added confirmation on clearing mongo collection
4. added an option to navigate to a given text in the validation ui
5. added a dataset and remote option to trainer to load dataset from directory and remote rpyc service
parent 83db445a6f
commit 8e79bbb571
@@ -2,20 +2,13 @@ import json
 from pathlib import Path
 from sklearn.model_selection import train_test_split
 from .utils import asr_manifest_reader, asr_manifest_writer
+from typing import List
+from itertools import chain
 import typer

 app = typer.Typer()


-@app.command()
-def split_data(dataset_path: Path, test_size: float = 0.1):
-    manifest_path = dataset_path / Path("abs_manifest.json")
-    asr_data = list(asr_manifest_reader(manifest_path))
-    train_pnr, test_pnr = train_test_split(asr_data, test_size=test_size)
-    asr_manifest_writer(manifest_path.with_name("train_manifest.json"), train_pnr)
-    asr_manifest_writer(manifest_path.with_name("test_manifest.json"), test_pnr)
-
-
 @app.command()
 def fixate_data(dataset_path: Path):
     manifest_path = dataset_path / Path("manifest.json")
@@ -30,31 +23,42 @@ def fixate_data(dataset_path: Path):


 @app.command()
-def augment_an4():
-    an4_train = Path("/dataset/asr_data/an4/train_manifest.json").read_bytes()
-    an4_test = Path("/dataset/asr_data/an4/test_manifest.json").read_bytes()
-    pnr_train = Path("/dataset/asr_data/pnr_data/train_manifest.json").read_bytes()
-    pnr_test = Path("/dataset/asr_data/pnr_data/test_manifest.json").read_bytes()
-    with Path("/dataset/asr_data/an4_pnr/train_manifest.json").open("wb") as pf:
-        pf.write(an4_train + pnr_train)
-    with Path("/dataset/asr_data/an4_pnr/test_manifest.json").open("wb") as pf:
-        pf.write(an4_test + pnr_test)
+def augment_datasets(src_dataset_paths: List[Path], dest_dataset_path: Path):
+    reader_list = []
+    abs_manifest_path = Path("abs_manifest.json")
+    for dataset_path in src_dataset_paths:
+        manifest_path = dataset_path / abs_manifest_path
+        reader_list.append(asr_manifest_reader(manifest_path))
+    dest_dataset_path.mkdir(parents=True, exist_ok=True)
+    dest_manifest_path = dest_dataset_path / abs_manifest_path
+    asr_manifest_writer(dest_manifest_path, chain(*reader_list))


 @app.command()
-def validate_data(data_file: Path):
-    with Path(data_file).open("r") as pf:
-        pnr_jsonl = pf.readlines()
-    for (i, s) in enumerate(pnr_jsonl):
-        try:
-            d = json.loads(s)
-            audio_file = data_file.parent / Path(d["audio_filepath"])
-            if not audio_file.exists():
-                raise OSError(f"File {audio_file} not found")
-        except BaseException as e:
-            print(f'failed on {i} with "{e}"')
-    print("no errors found. seems like a valid manifest.")
+def split_data(dataset_path: Path, test_size: float = 0.1):
+    manifest_path = dataset_path / Path("abs_manifest.json")
+    asr_data = list(asr_manifest_reader(manifest_path))
+    train_pnr, test_pnr = train_test_split(asr_data, test_size=test_size)
+    asr_manifest_writer(manifest_path.with_name("train_manifest.json"), train_pnr)
+    asr_manifest_writer(manifest_path.with_name("test_manifest.json"), test_pnr)
+
+
+@app.command()
+def validate_data(dataset_path: Path):
+    for mf_type in ["train_manifest.json", "test_manifest.json"]:
+        data_file = dataset_path / Path(mf_type)
+        print(f"validating {data_file}.")
+        with Path(data_file).open("r") as pf:
+            pnr_jsonl = pf.readlines()
+        for (i, s) in enumerate(pnr_jsonl):
+            try:
+                d = json.loads(s)
+                audio_file = data_file.parent / Path(d["audio_filepath"])
+                if not audio_file.exists():
+                    raise OSError(f"File {audio_file} not found")
+            except BaseException as e:
+                print(f'failed on {i} with "{e}"')
+        print(f"no errors found. seems like a valid {mf_type}.")


 def main():
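Note (not part of the commit): a minimal usage sketch of the new dataset commands above, calling them as plain functions. The module name data_process is hypothetical, and the an4/pnr_data/an4_pnr paths are reused from the removed augment_an4 command.

    # Hypothetical usage sketch; module name and dataset locations are assumptions.
    from pathlib import Path

    from data_process import augment_datasets, split_data, validate_data  # hypothetical module name

    # merge the abs_manifest.json of several source datasets into one destination dataset
    augment_datasets(
        src_dataset_paths=[Path("/dataset/asr_data/an4"), Path("/dataset/asr_data/pnr_data")],
        dest_dataset_path=Path("/dataset/asr_data/an4_pnr"),
    )
    # then split the merged manifest and sanity-check the resulting train/test manifests
    split_data(Path("/dataset/asr_data/an4_pnr"), test_size=0.1)
    validate_data(Path("/dataset/asr_data/an4_pnr"))
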
@@ -143,8 +143,8 @@ def fill_unannotated(
 @app.command()
 def update_corrections(
     data_manifest_path: Path = Path("./data/asr_data/call_alphanum/manifest.json"),
-    processed_data_path: Path = Path("./data/valiation_data/ui_dump.json"),
     corrections_path: Path = Path("./data/valiation_data/corrections.json"),
+    skip_incorrect: bool = True,
 ):
     def correct_manifest(manifest_data_gen, corrections_path):
         corrections = json.load(corrections_path.open())
@@ -170,15 +170,18 @@ def update_corrections(
                 }
             elif d["chars"] in correction_map:
                 correct_text = correction_map[d["chars"]]
-                renamed_set.add(correct_text)
-                new_name = str(Path(correct_text).with_suffix(".wav"))
-                d["audio_path"].replace(d["audio_path"].with_name(new_name))
-                new_filepath = str(Path(d["audio_filepath"]).with_name(new_name))
-                yield {
-                    "audio_filepath": new_filepath,
-                    "duration": d["duration"],
-                    "text": alnum_to_asr_tokens(correct_text),
-                }
+                if skip_incorrect:
+                    print(f'skipping incorrect {d["audio_path"]} corrected to {correct_text}')
+                else:
+                    renamed_set.add(correct_text)
+                    new_name = str(Path(correct_text).with_suffix(".wav"))
+                    d["audio_path"].replace(d["audio_path"].with_name(new_name))
+                    new_filepath = str(Path(d["audio_filepath"]).with_name(new_name))
+                    yield {
+                        "audio_filepath": new_filepath,
+                        "duration": d["duration"],
+                        "text": alnum_to_asr_tokens(correct_text),
+                    }
             else:
                 # don't delete if another correction points to an old file
                 if d["chars"] not in renamed_set:
@@ -202,8 +205,12 @@ def update_corrections(

 @app.command()
 def clear_mongo_corrections():
-    col = get_mongo_conn().test.asr_validation
-    col.delete_many({"type": "correction"})
+    delete = typer.confirm("are you sure you want to clear mongo collection it?")
+    if delete:
+        col = get_mongo_conn().test.asr_validation
+        col.delete_many({"type": "correction"})
+        typer.echo("deleted mongo collection.")
+    else:
+        typer.echo("Aborted")


 def main():
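Side note on the confirmation prompt (not from the commit): typer can abort on a "no" answer by itself, so an equivalent hedged sketch of clear_mongo_corrections using the library's abort flag would be:

    # Alternative sketch; `get_mongo_conn` is the repo helper used in the diff above
    # and is assumed to be importable alongside this function.
    import typer

    def clear_mongo_corrections():
        # a "no" answer raises typer.Abort, which prints "Aborted" and exits
        typer.confirm("are you sure you want to clear the mongo collection?", abort=True)
        col = get_mongo_conn().test.asr_validation
        col.delete_many({"type": "correction"})
        typer.echo("deleted mongo collection.")
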
@@ -53,7 +53,7 @@ if not hasattr(st, "mongo_connected"):

 @st.cache()
 def load_ui_data(validation_ui_data_path: Path):
-    typer.echo(f"Using validation ui data from :{validation_ui_data_path}")
+    typer.echo(f"Using validation ui data from {validation_ui_data_path}")
     return ExtendedPath(validation_ui_data_path).read_json()

@@ -117,6 +117,11 @@ def main(manifest: Path):
     # pass
     # if st.button("Next Untagged"):
     #     pass
+    text_sample = st.text_input("Go to Text:", value='')
+    if text_sample != '':
+        candidates = [i for (i, p) in enumerate(asr_data) if p["text"] == text_sample or p["spoken"] == text_sample]
+        if len(candidates) > 0:
+            st.update_cursor(candidates[0])
     real_idx = st.number_input(
         "Go to real-index",
         value=sample["real_idx"],
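To illustrate the new "Go to Text" lookup in isolation, here is a toy sketch (not from the commit): the asr_data entries are made-up stand-ins shaped like the UI's records, and st.update_cursor is the app's own helper, so a print stands in for it.

    # Toy data shaped like the UI's asr_data entries ("text"/"spoken" keys as in the diff).
    asr_data = [
        {"text": "one two three", "spoken": "1 2 3"},
        {"text": "four five six", "spoken": "4 5 6"},
    ]

    text_sample = "4 5 6"
    # same comprehension as the UI: match either the canonical text or the spoken form
    candidates = [i for (i, p) in enumerate(asr_data)
                  if p["text"] == text_sample or p["spoken"] == text_sample]
    if len(candidates) > 0:
        print(f"jumping cursor to sample index {candidates[0]}")  # the UI calls st.update_cursor here
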
@@ -3,6 +3,7 @@ import argparse
 import copy
 import math
 import os
+from pathlib import Path
 from functools import partial

 from ruamel.yaml import YAML
@@ -36,13 +37,13 @@ def parse_args():
         lr=0.002,
         amp_opt_level="O1",
         create_tb_writer=True,
-        model_config="./train/jasper10x5dr.yaml",
-        train_dataset="./train/asr_data/train_manifest.json",
-        eval_datasets="./train/asr_data/test_manifest.json",
+        model_config="./train/jasper-speller10x5dr.yaml",
+        # train_dataset="./train/asr_data/train_manifest.json",
+        # eval_datasets="./train/asr_data/test_manifest.json",
         work_dir="./train/work",
-        num_epochs=50,
+        num_epochs=300,
         weight_decay=0.005,
-        checkpoint_save_freq=1000,
+        checkpoint_save_freq=200,
         eval_freq=100,
         load_dir="./train/models/jasper/",
         warmup_steps=3,
@@ -70,7 +71,6 @@ def parse_args():
         required=False,
         help="model configuration file: model.yaml",
     )
-
     parser.add_argument(
         "--remote_data",
         type=str,
@@ -78,6 +78,13 @@ def parse_args():
         default="",
         help="remote dataloader endpoint",
     )
+    parser.add_argument(
+        "--dataset",
+        type=str,
+        required=False,
+        default="",
+        help="dataset directory containing train/test manifests",
+    )

     # Create new args
     parser.add_argument("--exp_name", default="Jasper", type=str)
@@ -120,22 +127,26 @@ def create_all_dags(args, neural_factory):
     # Calculate num_workers for dataloader
     total_cpus = os.cpu_count()
     cpu_per_traindl = max(int(total_cpus / neural_factory.world_size), 1)
-    # cpu_per_traindl = 1
     # perturb_config = jasper_params.get('perturb', None)
     train_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"])
     train_dl_params.update(jasper_params["AudioToTextDataLayer"]["train"])
     del train_dl_params["train"]
     del train_dl_params["eval"]
     # del train_dl_params["normalize_transcripts"]

+    if args.dataset:
+        d_path = Path(args.dataset)
+        if not args.train_dataset:
+            args.train_dataset = str(d_path / Path("train_manifest.json"))
+        if not args.eval_datasets:
+            args.eval_datasets = [str(d_path / Path("test_manifest.json"))]
+
     data_loader_layer = nemo_asr.AudioToTextDataLayer

     if args.remote_data:
-        train_dl_params['rpyc_host'] = args.remote_data
+        train_dl_params["rpyc_host"] = args.remote_data
         data_loader_layer = RpycAudioToTextDataLayer
-    # if args.remote_data:
-    #     # import pdb; pdb.set_trace()
-    #     data_loader_layer = rpyc.connect(
-    #         args.remote_data, 8064, config={"sync_request_timeout": 600}
-    #     ).root.get_data_loader()
     data_layer = data_loader_layer(
         manifest_filepath=args.train_dataset,
         sample_rate=sample_rate,
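The commit message mentions loading data from a remote rpyc service; on the trainer side this only sets rpyc_host and swaps in RpycAudioToTextDataLayer, while the removed comments hint at a service on port 8064 exposing get_data_loader. A minimal server-side sketch under those assumptions (the service class, port, and return value are guesses, not code from this repository) might look like:

    # Hypothetical rpyc service sketch; everything here is assumed, not taken from the repo.
    import rpyc
    from rpyc.utils.server import ThreadedServer


    class DataLoaderService(rpyc.Service):
        def exposed_get_data_loader(self):
            # return whatever object the RpycAudioToTextDataLayer client expects;
            # a manifest path is used here purely as a placeholder
            return "/dataset/asr_data/train_manifest.json"


    if __name__ == "__main__":
        # port 8064 matches the commented-out rpyc.connect(..., 8064, ...) call removed in the diff
        ThreadedServer(DataLoaderService(), port=8064).start()
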
@@ -169,7 +180,7 @@ def create_all_dags(args, neural_factory):
     eval_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"])
     eval_dl_params.update(jasper_params["AudioToTextDataLayer"]["eval"])
     if args.remote_data:
-        eval_dl_params['rpyc_host'] = args.remote_data
+        eval_dl_params["rpyc_host"] = args.remote_data
     del eval_dl_params["train"]
     del eval_dl_params["eval"]
     data_layers_eval = []