1. integrated data generator using google tts
2. added training script fix module packaging issue implement call audio data recycler for asr 1. added streamlit based validation ui with mongodb datastore integration 2. fix asr wrong sample rate inference 3. update requirements 1. refactored streamlit code 2. fixed issues in data manifest handling refresh to next entry on submit and comment out mongo clearing code for safety :P add validation ui and post processing to correct using validation data 1. added a tool to extract asr data from gcp transcripts logs 2. implement a funciton to export all call logs in a mongodb to a caller-id based yaml file 3. clean-up leaderboard duration logic 4. added a wip dataloader service 5. made the asr_data_writer util more generic with verbose flags and unique filename 6. added extendedpath util class with json support and mongo_conn function to connect to a mongo node 7. refactored the validation post processing to dump a ui config for validation 8. included utility functions to correct, fill update and clear annotations from mongodb data 9. refactored the ui logic to be more generic for any asr data 10. updated setup.py dependencies to support the above features unlink temporary files after transcribing 1. clean-up unused data process code 2. fix invalid sample no from mongo 3. data loader service return remote netref 1. added training utils with custom data loaders with remote rpyc dataservice support 2. fix validation correction dump path 3. cache dataset for precaching before training to memory 4. update dependencies 1. implement dataset augmentation and validation in process 2. added option to skip 'incorrect' annotations in validation data 3. added confirmation on clearing mongo collection 4. added an option to navigate to a given text in the validation ui 5. added a dataset and remote option to trainer to load dataset from directory and remote rpyc service 1. added utility command to export call logs 2. mongo conn accepts port refactored module structure 1. enabled silece stripping in chunks when recycling audio from asr logs 2. limit asr recycling to 1 min of start audio to get reliable alignments and ignoring agent channel 3. added rev recycler for generating asr dataset from rev transcripts and audio 4. update pydub dependency for silence stripping fn and removing threadpool hardcoded worker count 1. added support for mono/dual channel rev transcripts 2. handle errors when extracting datapoints from rev meta data 3. added suport for annotation only task when dumping ui data cleanup rev recycle added option to disable plots during validation fix skipping null audio and add more verbose logs respect verbose flag don't load audio for annotation only ui and keep spoken as text for normal asr validation 1. refactored wav chunk processing method 2. renamed streamlit to validation_ui show duration on validation of dataset parallelize data loading from remote skipping invalid data points 1. removed the transcriber_pretrained/speller from utils 2. introduced get_mongo_coll to get the collection object directly from mongo uri 3. removed processing of correction entries to remove space/upper casing refactor validation process arguments and logging 1. added a data extraction type argument 2. cleanup/refactor 1. using dataname args for update/fill annotations 2. rename to dump_ui added support for name/dates/cities call data extraction and more logs handling non-pnr cases without parens in text data 1. added conv data generator 2. more utils 1. added start delay arg in call recycler 2. implement ui_dump/manifest writer in call_recycler itself 3. refactored call data point plotter 4. added sample-ui task-ui on the validation process 5. implemented call-quality stats using corrections from mongo 6. support deleting cursors on mongo 7. implement multiple task support on validation ui based on task_id mongo field fix 11st to 11th in ordinal stripping silence on call chunk 1. added option to strip silent chunks 2. computing caller quality based on task-id of corrections 1. fix update-correction to use ui_dump instead of manifest 2. update training params no of checkpoints on chpk frequency 1. split extract all data types in one shot with --extraction-type all flag 2. add notes about diffing split extracted and original data 3. add a nlu conv generator to generate conv data based on nlu utterances and entities 4. add task uid support for dumping corrections 5. abstracted generate date fn 1. added a test generator and slu evaluator 2. ui dump now include gcp results 3. showing default option for more args validation process commands added evaluation command clean-uptegra
parent
f7ebd8e90a
commit
e24a8cf9d0
|
|
@ -1,3 +1,11 @@
|
|||
/data/
|
||||
/model/
|
||||
/train/
|
||||
.env*
|
||||
*.yaml
|
||||
*.yml
|
||||
*.json
|
||||
|
||||
|
||||
# Created by https://www.gitignore.io/api/python
|
||||
# Edit at https://www.gitignore.io/?templates=python
|
||||
|
|
@ -108,3 +116,36 @@ dmypy.json
|
|||
.pyre/
|
||||
|
||||
# End of https://www.gitignore.io/api/python
|
||||
|
||||
# Created by https://www.gitignore.io/api/macos
|
||||
# Edit at https://www.gitignore.io/?templates=macos
|
||||
|
||||
### macOS ###
|
||||
# General
|
||||
.DS_Store
|
||||
.AppleDouble
|
||||
.LSOverride
|
||||
|
||||
# Icon must end with two \r
|
||||
Icon
|
||||
|
||||
# Thumbnails
|
||||
._*
|
||||
|
||||
# Files that might appear in the root of a volume
|
||||
.DocumentRevisions-V100
|
||||
.fseventsd
|
||||
.Spotlight-V100
|
||||
.TemporaryItems
|
||||
.Trashes
|
||||
.VolumeIcon.icns
|
||||
.com.apple.timemachine.donotpresent
|
||||
|
||||
# Directories potentially created on remote AFP share
|
||||
.AppleDB
|
||||
.AppleDesktop
|
||||
Network Trash Folder
|
||||
Temporary Items
|
||||
.apdisk
|
||||
|
||||
# End of https://www.gitignore.io/api/macos
|
||||
|
|
|
|||
|
|
@ -0,0 +1,5 @@
|
|||
|
||||
> Diff after splitting based on type
|
||||
```
|
||||
diff <(cat data/asr_data/call_upwork_test_cnd_*/manifest.json |sort) <(cat data/asr_data/call_upwork_test_cnd/manifest.json |sort)
|
||||
```
|
||||
|
|
@ -62,7 +62,7 @@ class JasperASR(object):
|
|||
wf = wave.open(audio_file_path, "w")
|
||||
wf.setnchannels(1)
|
||||
wf.setsampwidth(2)
|
||||
wf.setframerate(16000)
|
||||
wf.setframerate(24000)
|
||||
wf.writeframesraw(audio_data)
|
||||
wf.close()
|
||||
manifest = {"audio_filepath": audio_file_path, "duration": 60, "text": "todo"}
|
||||
|
|
@ -108,6 +108,8 @@ class JasperASR(object):
|
|||
tensors = self.neural_factory.infer(tensors=eval_tensors)
|
||||
prediction = post_process_predictions(tensors[0], self.labels)
|
||||
prediction_text = ". ".join(prediction)
|
||||
os.unlink(manifest_file.name)
|
||||
os.unlink(audio_file.name)
|
||||
return prediction_text
|
||||
|
||||
def transcribe_file(self, audio_file, *args, **kwargs):
|
||||
|
|
|
|||
|
|
@ -0,0 +1,21 @@
|
|||
import os
|
||||
import logging
|
||||
import rpyc
|
||||
from functools import lru_cache
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
ASR_HOST = os.environ.get("JASPER_ASR_RPYC_HOST", "localhost")
|
||||
ASR_PORT = int(os.environ.get("JASPER_ASR_RPYC_PORT", "8045"))
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def transcribe_gen(asr_host=ASR_HOST, asr_port=ASR_PORT):
|
||||
logger.info(f"connecting to asr server at {asr_host}:{asr_port}")
|
||||
asr = rpyc.connect(asr_host, asr_port).root
|
||||
logger.info(f"connected to asr server successfully")
|
||||
return asr.transcribe
|
||||
|
|
@ -0,0 +1 @@
|
|||
|
||||
|
|
@ -0,0 +1,77 @@
|
|||
import json
|
||||
from pathlib import Path
|
||||
from sklearn.model_selection import train_test_split
|
||||
from .utils import asr_manifest_reader, asr_manifest_writer
|
||||
from typing import List
|
||||
from itertools import chain
|
||||
import typer
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
|
||||
def fixate_data(dataset_path: Path):
|
||||
manifest_path = dataset_path / Path("manifest.json")
|
||||
real_manifest_path = dataset_path / Path("abs_manifest.json")
|
||||
|
||||
def fix_path():
|
||||
for i in asr_manifest_reader(manifest_path):
|
||||
i["audio_filepath"] = str(dataset_path / Path(i["audio_filepath"]))
|
||||
yield i
|
||||
|
||||
asr_manifest_writer(real_manifest_path, fix_path())
|
||||
|
||||
|
||||
@app.command()
|
||||
def augment_data(src_dataset_paths: List[Path], dest_dataset_path: Path):
|
||||
reader_list = []
|
||||
abs_manifest_path = Path("abs_manifest.json")
|
||||
for dataset_path in src_dataset_paths:
|
||||
manifest_path = dataset_path / abs_manifest_path
|
||||
reader_list.append(asr_manifest_reader(manifest_path))
|
||||
dest_dataset_path.mkdir(parents=True, exist_ok=True)
|
||||
dest_manifest_path = dest_dataset_path / abs_manifest_path
|
||||
asr_manifest_writer(dest_manifest_path, chain(*reader_list))
|
||||
|
||||
|
||||
@app.command()
|
||||
def split_data(dataset_path: Path, test_size: float = 0.1):
|
||||
manifest_path = dataset_path / Path("abs_manifest.json")
|
||||
asr_data = list(asr_manifest_reader(manifest_path))
|
||||
train_data, test_data = train_test_split(asr_data, test_size=test_size)
|
||||
asr_manifest_writer(manifest_path.with_name("train_manifest.json"), train_data)
|
||||
asr_manifest_writer(manifest_path.with_name("test_manifest.json"), test_data)
|
||||
|
||||
|
||||
@app.command()
|
||||
def validate_data(dataset_path: Path):
|
||||
from natural.date import compress
|
||||
from datetime import timedelta
|
||||
|
||||
for mf_type in ["train_manifest.json", "test_manifest.json"]:
|
||||
data_file = dataset_path / Path(mf_type)
|
||||
print(f"validating {data_file}.")
|
||||
with Path(data_file).open("r") as pf:
|
||||
data_jsonl = pf.readlines()
|
||||
duration = 0
|
||||
for (i, s) in enumerate(data_jsonl):
|
||||
try:
|
||||
d = json.loads(s)
|
||||
duration += d["duration"]
|
||||
audio_file = data_file.parent / Path(d["audio_filepath"])
|
||||
if not audio_file.exists():
|
||||
raise OSError(f"File {audio_file} not found")
|
||||
except BaseException as e:
|
||||
print(f'failed on {i} with "{e}"')
|
||||
duration_str = compress(timedelta(seconds=duration), pad=" ")
|
||||
print(
|
||||
f"no errors found. seems like a valid {mf_type}. contains {duration_str}sec of audio"
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
app()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -0,0 +1,57 @@
|
|||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import typer
|
||||
import rpyc
|
||||
from rpyc.utils.server import ThreadedServer
|
||||
import nemo
|
||||
import pickle
|
||||
|
||||
# import nemo.collections.asr as nemo_asr
|
||||
from nemo.collections.asr.parts.segment import AudioSegment
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
nemo.core.NeuralModuleFactory(
|
||||
backend=nemo.core.Backend.PyTorch, placement=nemo.core.DeviceType.CPU
|
||||
)
|
||||
|
||||
|
||||
class ASRDataService(rpyc.Service):
|
||||
def exposed_get_path_samples(
|
||||
self, file_path, target_sr, int_values, offset, duration, trim
|
||||
):
|
||||
print(f"loading.. {file_path}")
|
||||
audio = AudioSegment.from_file(
|
||||
file_path,
|
||||
target_sr=target_sr,
|
||||
int_values=int_values,
|
||||
offset=offset,
|
||||
duration=duration,
|
||||
trim=trim,
|
||||
)
|
||||
# print(f"returning.. {len(audio.samples)} items of type{type(audio.samples)}")
|
||||
return pickle.dumps(audio.samples)
|
||||
|
||||
def exposed_read_path(self, file_path):
|
||||
# print(f"reading path.. {file_path}")
|
||||
return Path(file_path).read_bytes()
|
||||
|
||||
|
||||
@app.command()
|
||||
def run_server(port: int = 0):
|
||||
listen_port = port if port else int(os.environ.get("ASR_DARA_RPYC_PORT", "8064"))
|
||||
service = ASRDataService()
|
||||
t = ThreadedServer(
|
||||
service, port=listen_port, protocol_config={"allow_all_attrs": True}
|
||||
)
|
||||
typer.echo(f"starting asr server on {listen_port}...")
|
||||
t.start()
|
||||
|
||||
|
||||
def main():
|
||||
app()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -0,0 +1,236 @@
|
|||
import io
|
||||
import os
|
||||
import json
|
||||
import wave
|
||||
from pathlib import Path
|
||||
from functools import partial
|
||||
from uuid import uuid4
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
import pymongo
|
||||
from slugify import slugify
|
||||
from jasper.client import transcribe_gen
|
||||
from nemo.collections.asr.metrics import word_error_rate
|
||||
import matplotlib.pyplot as plt
|
||||
import librosa
|
||||
import librosa.display
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def manifest_str(path, dur, text):
|
||||
return (
|
||||
json.dumps({"audio_filepath": path, "duration": round(dur, 1), "text": text})
|
||||
+ "\n"
|
||||
)
|
||||
|
||||
|
||||
def wav_bytes(audio_bytes, frame_rate=24000):
|
||||
wf_b = io.BytesIO()
|
||||
with wave.open(wf_b, mode="w") as wf:
|
||||
wf.setnchannels(1)
|
||||
wf.setframerate(frame_rate)
|
||||
wf.setsampwidth(2)
|
||||
wf.writeframesraw(audio_bytes)
|
||||
return wf_b.getvalue()
|
||||
|
||||
|
||||
def tscript_uuid_fname(transcript):
|
||||
return str(uuid4()) + "_" + slugify(transcript, max_length=8)
|
||||
|
||||
|
||||
def asr_data_writer(output_dir, dataset_name, asr_data_source, verbose=False):
|
||||
dataset_dir = output_dir / Path(dataset_name)
|
||||
(dataset_dir / Path("wav")).mkdir(parents=True, exist_ok=True)
|
||||
asr_manifest = dataset_dir / Path("manifest.json")
|
||||
num_datapoints = 0
|
||||
with asr_manifest.open("w") as mf:
|
||||
print(f"writing manifest to {asr_manifest}")
|
||||
for transcript, audio_dur, wav_data in asr_data_source:
|
||||
fname = tscript_uuid_fname(transcript)
|
||||
audio_file = dataset_dir / Path("wav") / Path(fname).with_suffix(".wav")
|
||||
audio_file.write_bytes(wav_data)
|
||||
rel_data_path = audio_file.relative_to(dataset_dir)
|
||||
manifest = manifest_str(str(rel_data_path), audio_dur, transcript)
|
||||
mf.write(manifest)
|
||||
if verbose:
|
||||
print(f"writing '{transcript}' of duration {audio_dur}")
|
||||
num_datapoints += 1
|
||||
return num_datapoints
|
||||
|
||||
|
||||
def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=False):
|
||||
dataset_dir = output_dir / Path(dataset_name)
|
||||
(dataset_dir / Path("wav")).mkdir(parents=True, exist_ok=True)
|
||||
ui_dump_file = dataset_dir / Path("ui_dump.json")
|
||||
(dataset_dir / Path("wav_plots")).mkdir(parents=True, exist_ok=True)
|
||||
asr_manifest = dataset_dir / Path("manifest.json")
|
||||
num_datapoints = 0
|
||||
ui_dump = {
|
||||
"use_domain_asr": False,
|
||||
"annotation_only": False,
|
||||
"enable_plots": True,
|
||||
"data": [],
|
||||
}
|
||||
data_funcs = []
|
||||
transcriber_pretrained = transcribe_gen(asr_port=8044)
|
||||
with asr_manifest.open("w") as mf:
|
||||
print(f"writing manifest to {asr_manifest}")
|
||||
|
||||
def data_fn(
|
||||
transcript,
|
||||
audio_dur,
|
||||
wav_data,
|
||||
caller_name,
|
||||
aud_seg,
|
||||
fname,
|
||||
audio_path,
|
||||
num_datapoints,
|
||||
rel_data_path,
|
||||
):
|
||||
pretrained_result = transcriber_pretrained(aud_seg.raw_data)
|
||||
pretrained_wer = word_error_rate([transcript], [pretrained_result])
|
||||
wav_plot_path = (
|
||||
dataset_dir / Path("wav_plots") / Path(fname).with_suffix(".png")
|
||||
)
|
||||
if not wav_plot_path.exists():
|
||||
plot_seg(wav_plot_path, audio_path)
|
||||
return {
|
||||
"audio_filepath": str(rel_data_path),
|
||||
"duration": round(audio_dur, 1),
|
||||
"text": transcript,
|
||||
"real_idx": num_datapoints,
|
||||
"audio_path": audio_path,
|
||||
"spoken": transcript,
|
||||
"caller": caller_name,
|
||||
"utterance_id": fname,
|
||||
"pretrained_asr": pretrained_result,
|
||||
"pretrained_wer": pretrained_wer,
|
||||
"plot_path": str(wav_plot_path),
|
||||
}
|
||||
|
||||
for transcript, audio_dur, wav_data, caller_name, aud_seg in asr_data_source:
|
||||
fname = str(uuid4()) + "_" + slugify(transcript, max_length=8)
|
||||
audio_file = dataset_dir / Path("wav") / Path(fname).with_suffix(".wav")
|
||||
audio_file.write_bytes(wav_data)
|
||||
audio_path = str(audio_file)
|
||||
rel_data_path = audio_file.relative_to(dataset_dir)
|
||||
manifest = manifest_str(str(rel_data_path), audio_dur, transcript)
|
||||
mf.write(manifest)
|
||||
data_funcs.append(
|
||||
partial(
|
||||
data_fn,
|
||||
transcript,
|
||||
audio_dur,
|
||||
wav_data,
|
||||
caller_name,
|
||||
aud_seg,
|
||||
fname,
|
||||
audio_path,
|
||||
num_datapoints,
|
||||
rel_data_path,
|
||||
)
|
||||
)
|
||||
num_datapoints += 1
|
||||
dump_data = parallel_apply(lambda x: x(), data_funcs)
|
||||
# dump_data = [x() for x in tqdm(data_funcs)]
|
||||
ui_dump["data"] = dump_data
|
||||
ExtendedPath(ui_dump_file).write_json(ui_dump)
|
||||
return num_datapoints
|
||||
|
||||
|
||||
def asr_manifest_reader(data_manifest_path: Path):
|
||||
print(f"reading manifest from {data_manifest_path}")
|
||||
with data_manifest_path.open("r") as pf:
|
||||
data_jsonl = pf.readlines()
|
||||
data_data = [json.loads(v) for v in data_jsonl]
|
||||
for p in data_data:
|
||||
p["audio_path"] = data_manifest_path.parent / Path(p["audio_filepath"])
|
||||
p["text"] = p["text"].strip()
|
||||
yield p
|
||||
|
||||
|
||||
def asr_manifest_writer(asr_manifest_path: Path, manifest_str_source):
|
||||
with asr_manifest_path.open("w") as mf:
|
||||
print(f"opening {asr_manifest_path} for writing manifest")
|
||||
for mani_dict in manifest_str_source:
|
||||
manifest = manifest_str(
|
||||
mani_dict["audio_filepath"], mani_dict["duration"], mani_dict["text"]
|
||||
)
|
||||
mf.write(manifest)
|
||||
|
||||
|
||||
def asr_test_writer(out_file_path: Path, source):
|
||||
def dd_str(dd, idx):
|
||||
path = dd["audio_filepath"]
|
||||
# dur = dd["duration"]
|
||||
# return f"SAY {idx}\nPAUSE 3\nPLAY {path}\nPAUSE 3\n\n"
|
||||
return f"PAUSE 2\nPLAY {path}\nPAUSE 60\n\n"
|
||||
|
||||
res_file = out_file_path.with_suffix(".result.json")
|
||||
with out_file_path.open("w") as of:
|
||||
print(f"opening {out_file_path} for writing test")
|
||||
results = []
|
||||
idx = 0
|
||||
for ui_dd in source:
|
||||
results.append(ui_dd)
|
||||
out_str = dd_str(ui_dd, idx)
|
||||
of.write(out_str)
|
||||
idx += 1
|
||||
of.write("DO_HANGUP\n")
|
||||
ExtendedPath(res_file).write_json(results)
|
||||
|
||||
|
||||
def batch(iterable, n=1):
|
||||
ls = len(iterable)
|
||||
return [iterable[ndx : min(ndx + n, ls)] for ndx in range(0, ls, n)]
|
||||
|
||||
|
||||
class ExtendedPath(type(Path())):
|
||||
"""docstring for ExtendedPath."""
|
||||
|
||||
def read_json(self):
|
||||
print(f"reading json from {self}")
|
||||
with self.open("r") as jf:
|
||||
return json.load(jf)
|
||||
|
||||
def write_json(self, data):
|
||||
print(f"writing json to {self}")
|
||||
self.parent.mkdir(parents=True, exist_ok=True)
|
||||
with self.open("w") as jf:
|
||||
return json.dump(data, jf, indent=2)
|
||||
|
||||
|
||||
def get_mongo_conn(host="", port=27017, db="test", col="calls"):
|
||||
mongo_host = host if host else os.environ.get("MONGO_HOST", "localhost")
|
||||
mongo_uri = f"mongodb://{mongo_host}:{port}/"
|
||||
return pymongo.MongoClient(mongo_uri)[db][col]
|
||||
|
||||
|
||||
def strip_silence(sound):
|
||||
from pydub.silence import detect_leading_silence
|
||||
|
||||
start_trim = detect_leading_silence(sound)
|
||||
end_trim = detect_leading_silence(sound.reverse())
|
||||
duration = len(sound)
|
||||
return sound[start_trim : duration - end_trim]
|
||||
|
||||
|
||||
def plot_seg(wav_plot_path, audio_path):
|
||||
fig = plt.Figure()
|
||||
ax = fig.add_subplot()
|
||||
(y, sr) = librosa.load(audio_path)
|
||||
librosa.display.waveplot(y=y, sr=sr, ax=ax)
|
||||
with wav_plot_path.open("wb") as wav_plot_f:
|
||||
fig.set_tight_layout(True)
|
||||
fig.savefig(wav_plot_f, format="png", dpi=50)
|
||||
|
||||
|
||||
def parallel_apply(fn, iterable, workers=8):
|
||||
with ThreadPoolExecutor(max_workers=workers) as exe:
|
||||
print(f"parallelly applying {fn}")
|
||||
return [
|
||||
res
|
||||
for res in tqdm(
|
||||
exe.map(fn, iterable), position=0, leave=True, total=len(iterable)
|
||||
)
|
||||
]
|
||||
|
|
@ -0,0 +1 @@
|
|||
|
||||
|
|
@ -0,0 +1,418 @@
|
|||
import json
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from enum import Enum
|
||||
|
||||
import typer
|
||||
from tqdm import tqdm
|
||||
|
||||
from ..utils import (
|
||||
alnum_to_asr_tokens,
|
||||
ExtendedPath,
|
||||
asr_manifest_reader,
|
||||
asr_manifest_writer,
|
||||
tscript_uuid_fname,
|
||||
get_mongo_conn,
|
||||
plot_seg,
|
||||
)
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
def preprocess_datapoint(
|
||||
idx, rel_root, sample, use_domain_asr, annotation_only, enable_plots
|
||||
):
|
||||
from pydub import AudioSegment
|
||||
from nemo.collections.asr.metrics import word_error_rate
|
||||
from jasper.client import transcribe_gen
|
||||
|
||||
try:
|
||||
res = dict(sample)
|
||||
res["real_idx"] = idx
|
||||
audio_path = rel_root / Path(sample["audio_filepath"])
|
||||
res["audio_path"] = str(audio_path)
|
||||
if use_domain_asr:
|
||||
res["spoken"] = alnum_to_asr_tokens(res["text"])
|
||||
else:
|
||||
res["spoken"] = res["text"]
|
||||
res["utterance_id"] = audio_path.stem
|
||||
if not annotation_only:
|
||||
transcriber_pretrained = transcribe_gen(asr_port=8044)
|
||||
|
||||
aud_seg = (
|
||||
AudioSegment.from_file_using_temporary_files(audio_path)
|
||||
.set_channels(1)
|
||||
.set_sample_width(2)
|
||||
.set_frame_rate(24000)
|
||||
)
|
||||
res["pretrained_asr"] = transcriber_pretrained(aud_seg.raw_data)
|
||||
res["pretrained_wer"] = word_error_rate(
|
||||
[res["text"]], [res["pretrained_asr"]]
|
||||
)
|
||||
if use_domain_asr:
|
||||
transcriber_speller = transcribe_gen(asr_port=8045)
|
||||
res["domain_asr"] = transcriber_speller(aud_seg.raw_data)
|
||||
res["domain_wer"] = word_error_rate(
|
||||
[res["spoken"]], [res["pretrained_asr"]]
|
||||
)
|
||||
if enable_plots:
|
||||
wav_plot_path = (
|
||||
rel_root / Path("wav_plots") / Path(audio_path.name).with_suffix(".png")
|
||||
)
|
||||
if not wav_plot_path.exists():
|
||||
plot_seg(wav_plot_path, audio_path)
|
||||
res["plot_path"] = str(wav_plot_path)
|
||||
return res
|
||||
except BaseException as e:
|
||||
print(f'failed on {idx}: {sample["audio_filepath"]} with {e}')
|
||||
|
||||
|
||||
@app.command()
|
||||
def dump_ui(
|
||||
data_name: str = typer.Option("dataname", show_default=True),
|
||||
dataset_dir: Path = Path("./data/asr_data"),
|
||||
dump_dir: Path = Path("./data/valiation_data"),
|
||||
dump_fname: Path = typer.Option(Path("ui_dump.json"), show_default=True),
|
||||
use_domain_asr: bool = False,
|
||||
annotation_only: bool = False,
|
||||
enable_plots: bool = True,
|
||||
):
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from functools import partial
|
||||
|
||||
data_manifest_path = dataset_dir / Path(data_name) / Path("manifest.json")
|
||||
dump_path: Path = dump_dir / Path(data_name) / dump_fname
|
||||
plot_dir = data_manifest_path.parent / Path("wav_plots")
|
||||
plot_dir.mkdir(parents=True, exist_ok=True)
|
||||
typer.echo(f"Using data manifest:{data_manifest_path}")
|
||||
with data_manifest_path.open("r") as pf:
|
||||
data_jsonl = pf.readlines()
|
||||
data_funcs = [
|
||||
partial(
|
||||
preprocess_datapoint,
|
||||
i,
|
||||
data_manifest_path.parent,
|
||||
json.loads(v),
|
||||
use_domain_asr,
|
||||
annotation_only,
|
||||
enable_plots,
|
||||
)
|
||||
for i, v in enumerate(data_jsonl)
|
||||
]
|
||||
|
||||
def exec_func(f):
|
||||
return f()
|
||||
|
||||
with ThreadPoolExecutor() as exe:
|
||||
print("starting all preprocess tasks")
|
||||
data_final = filter(
|
||||
None,
|
||||
list(
|
||||
tqdm(
|
||||
exe.map(exec_func, data_funcs),
|
||||
position=0,
|
||||
leave=True,
|
||||
total=len(data_funcs),
|
||||
)
|
||||
),
|
||||
)
|
||||
if annotation_only:
|
||||
result = list(data_final)
|
||||
else:
|
||||
wer_key = "domain_wer" if use_domain_asr else "pretrained_wer"
|
||||
result = sorted(data_final, key=lambda x: x[wer_key], reverse=True)
|
||||
ui_config = {
|
||||
"use_domain_asr": use_domain_asr,
|
||||
"annotation_only": annotation_only,
|
||||
"enable_plots": enable_plots,
|
||||
"data": result,
|
||||
}
|
||||
ExtendedPath(dump_path).write_json(ui_config)
|
||||
|
||||
|
||||
@app.command()
|
||||
def sample_ui(
|
||||
data_name: str = typer.Option("dataname", show_default=True),
|
||||
dump_dir: Path = Path("./data/asr_data"),
|
||||
dump_file: Path = Path("ui_dump.json"),
|
||||
sample_count: int = typer.Option(80, show_default=True),
|
||||
sample_file: Path = Path("sample_dump.json"),
|
||||
):
|
||||
import pandas as pd
|
||||
|
||||
processed_data_path = dump_dir / Path(data_name) / dump_file
|
||||
sample_path = dump_dir / Path(data_name) / sample_file
|
||||
processed_data = ExtendedPath(processed_data_path).read_json()
|
||||
df = pd.DataFrame(processed_data["data"])
|
||||
samples_per_caller = sample_count // len(df["caller"].unique())
|
||||
caller_samples = pd.concat(
|
||||
[g.sample(samples_per_caller) for (c, g) in df.groupby("caller")]
|
||||
)
|
||||
caller_samples = caller_samples.reset_index(drop=True)
|
||||
caller_samples["real_idx"] = caller_samples.index
|
||||
sample_data = caller_samples.to_dict("records")
|
||||
processed_data["data"] = sample_data
|
||||
typer.echo(f"sampling {sample_count} datapoints")
|
||||
ExtendedPath(sample_path).write_json(processed_data)
|
||||
|
||||
|
||||
@app.command()
|
||||
def task_ui(
|
||||
data_name: str = typer.Option("dataname", show_default=True),
|
||||
dump_dir: Path = Path("./data/asr_data"),
|
||||
dump_file: Path = Path("ui_dump.json"),
|
||||
task_count: int = typer.Option(4, show_default=True),
|
||||
task_file: str = "task_dump",
|
||||
):
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
|
||||
processed_data_path = dump_dir / Path(data_name) / dump_file
|
||||
processed_data = ExtendedPath(processed_data_path).read_json()
|
||||
df = pd.DataFrame(processed_data["data"]).sample(frac=1).reset_index(drop=True)
|
||||
for t_idx, task_f in enumerate(np.array_split(df, task_count)):
|
||||
task_f = task_f.reset_index(drop=True)
|
||||
task_f["real_idx"] = task_f.index
|
||||
task_data = task_f.to_dict("records")
|
||||
processed_data["data"] = task_data
|
||||
task_path = dump_dir / Path(data_name) / Path(task_file + f"-{t_idx}.json")
|
||||
ExtendedPath(task_path).write_json(processed_data)
|
||||
|
||||
|
||||
@app.command()
|
||||
def dump_corrections(
|
||||
task_uid: str,
|
||||
data_name: str = typer.Option("dataname", show_default=True),
|
||||
dump_dir: Path = Path("./data/asr_data"),
|
||||
dump_fname: Path = Path("corrections.json"),
|
||||
):
|
||||
dump_path = dump_dir / Path(data_name) / dump_fname
|
||||
col = get_mongo_conn(col="asr_validation")
|
||||
task_id = [c for c in col.distinct("task_id") if c.rsplit("-", 1)[1] == task_uid][0]
|
||||
corrections = list(col.find({"type": "correction"}, projection={"_id": False}))
|
||||
cursor_obj = col.find({"type": "correction", "task_id": task_id}, projection={"_id": False})
|
||||
corrections = [c for c in cursor_obj]
|
||||
ExtendedPath(dump_path).write_json(corrections)
|
||||
|
||||
|
||||
@app.command()
|
||||
def caller_quality(
|
||||
task_uid: str,
|
||||
data_name: str = typer.Option("dataname", show_default=True),
|
||||
dump_dir: Path = Path("./data/asr_data"),
|
||||
dump_fname: Path = Path("ui_dump.json"),
|
||||
correction_fname: Path = Path("corrections.json"),
|
||||
):
|
||||
import copy
|
||||
import pandas as pd
|
||||
|
||||
dump_path = dump_dir / Path(data_name) / dump_fname
|
||||
correction_path = dump_dir / Path(data_name) / correction_fname
|
||||
dump_data = ExtendedPath(dump_path).read_json()
|
||||
|
||||
dump_map = {d["utterance_id"]: d for d in dump_data["data"]}
|
||||
correction_data = ExtendedPath(correction_path).read_json()
|
||||
|
||||
def correction_dp(c):
|
||||
dp = copy.deepcopy(dump_map[c["code"]])
|
||||
dp["valid"] = c["value"]["status"] == "Correct"
|
||||
return dp
|
||||
|
||||
corrected_dump = [
|
||||
correction_dp(c)
|
||||
for c in correction_data
|
||||
if c["task_id"].rsplit("-", 1)[1] == task_uid
|
||||
]
|
||||
df = pd.DataFrame(corrected_dump)
|
||||
print(f"Total samples: {len(df)}")
|
||||
for (c, g) in df.groupby("caller"):
|
||||
total = len(g)
|
||||
valid = len(g[g["valid"] == True])
|
||||
valid_rate = valid * 100 / total
|
||||
print(f"Caller: {c} Valid%:{valid_rate:.2f} of {total} samples")
|
||||
|
||||
|
||||
@app.command()
|
||||
def fill_unannotated(
|
||||
data_name: str = typer.Option("dataname", show_default=True),
|
||||
dump_dir: Path = Path("./data/valiation_data"),
|
||||
dump_file: Path = Path("ui_dump.json"),
|
||||
corrections_file: Path = Path("corrections.json"),
|
||||
):
|
||||
processed_data_path = dump_dir / Path(data_name) / dump_file
|
||||
corrections_path = dump_dir / Path(data_name) / corrections_file
|
||||
processed_data = json.load(processed_data_path.open())
|
||||
corrections = json.load(corrections_path.open())
|
||||
annotated_codes = {c["code"] for c in corrections}
|
||||
all_codes = {c["gold_chars"] for c in processed_data}
|
||||
unann_codes = all_codes - annotated_codes
|
||||
mongo_conn = get_mongo_conn(col="asr_validation")
|
||||
for c in unann_codes:
|
||||
mongo_conn.find_one_and_update(
|
||||
{"type": "correction", "code": c},
|
||||
{"$set": {"value": {"status": "Inaudible", "correction": ""}}},
|
||||
upsert=True,
|
||||
)
|
||||
|
||||
|
||||
@app.command()
|
||||
def split_extract(
|
||||
data_name: str = typer.Option("dataname", show_default=True),
|
||||
# dest_data_name: str = typer.Option("call_aldata_namephanum_date", show_default=True),
|
||||
# dump_dir: Path = Path("./data/valiation_data"),
|
||||
dump_dir: Path = Path("./data/asr_data"),
|
||||
dump_file: Path = Path("ui_dump.json"),
|
||||
manifest_file: Path = Path("manifest.json"),
|
||||
corrections_file: str = typer.Option("corrections.json", show_default=True),
|
||||
conv_data_path: Path = typer.Option(Path("./data/conv_data.json"), show_default=True),
|
||||
extraction_type: str = "all",
|
||||
):
|
||||
import shutil
|
||||
|
||||
data_manifest_path = dump_dir / Path(data_name) / manifest_file
|
||||
conv_data = ExtendedPath(conv_data_path).read_json()
|
||||
|
||||
def extract_data_of_type(extraction_key):
|
||||
extraction_vals = conv_data[extraction_key]
|
||||
dest_data_name = data_name + "_" + extraction_key.lower()
|
||||
|
||||
manifest_gen = asr_manifest_reader(data_manifest_path)
|
||||
dest_data_dir = dump_dir / Path(dest_data_name)
|
||||
dest_data_dir.mkdir(exist_ok=True, parents=True)
|
||||
(dest_data_dir / Path("wav")).mkdir(exist_ok=True, parents=True)
|
||||
dest_manifest_path = dest_data_dir / manifest_file
|
||||
dest_ui_path = dest_data_dir / dump_file
|
||||
|
||||
def extract_manifest(mg):
|
||||
for m in mg:
|
||||
if m["text"] in extraction_vals:
|
||||
shutil.copy(m["audio_path"], dest_data_dir / Path(m["audio_filepath"]))
|
||||
yield m
|
||||
|
||||
asr_manifest_writer(dest_manifest_path, extract_manifest(manifest_gen))
|
||||
|
||||
ui_data_path = dump_dir / Path(data_name) / dump_file
|
||||
orig_ui_data = ExtendedPath(ui_data_path).read_json()
|
||||
ui_data = orig_ui_data["data"]
|
||||
file_ui_map = {Path(u["audio_filepath"]).stem: u for u in ui_data}
|
||||
extracted_ui_data = list(filter(lambda u: u["text"] in extraction_vals, ui_data))
|
||||
final_data = []
|
||||
for i, d in enumerate(extracted_ui_data):
|
||||
d['real_idx'] = i
|
||||
final_data.append(d)
|
||||
orig_ui_data['data'] = final_data
|
||||
ExtendedPath(dest_ui_path).write_json(orig_ui_data)
|
||||
|
||||
if corrections_file:
|
||||
dest_correction_path = dest_data_dir / corrections_file
|
||||
corrections_path = dump_dir / Path(data_name) / corrections_file
|
||||
corrections = json.load(corrections_path.open())
|
||||
extracted_corrections = list(
|
||||
filter(
|
||||
lambda c: c["code"] in file_ui_map
|
||||
and file_ui_map[c["code"]]["text"] in extraction_vals,
|
||||
corrections,
|
||||
)
|
||||
)
|
||||
ExtendedPath(dest_correction_path).write_json(extracted_corrections)
|
||||
|
||||
if extraction_type.value == 'all':
|
||||
for ext_key in conv_data.keys():
|
||||
extract_data_of_type(ext_key)
|
||||
else:
|
||||
extract_data_of_type(extraction_type.value)
|
||||
|
||||
|
||||
@app.command()
|
||||
def update_corrections(
|
||||
data_name: str = typer.Option("dataname", show_default=True),
|
||||
dump_dir: Path = Path("./data/asr_data"),
|
||||
manifest_file: Path = Path("manifest.json"),
|
||||
corrections_file: Path = Path("corrections.json"),
|
||||
ui_dump_file: Path = Path("ui_dump.json"),
|
||||
skip_incorrect: bool = typer.Option(True, show_default=True),
|
||||
):
|
||||
data_manifest_path = dump_dir / Path(data_name) / manifest_file
|
||||
corrections_path = dump_dir / Path(data_name) / corrections_file
|
||||
ui_dump_path = dump_dir / Path(data_name) / ui_dump_file
|
||||
|
||||
def correct_manifest(ui_dump_path, corrections_path):
|
||||
corrections = ExtendedPath(corrections_path).read_json()
|
||||
ui_data = ExtendedPath(ui_dump_path).read_json()['data']
|
||||
correct_set = {
|
||||
c["code"] for c in corrections if c["value"]["status"] == "Correct"
|
||||
}
|
||||
# incorrect_set = {c["code"] for c in corrections if c["value"]["status"] == "Inaudible"}
|
||||
correction_map = {
|
||||
c["code"]: c["value"]["correction"]
|
||||
for c in corrections
|
||||
if c["value"]["status"] == "Incorrect"
|
||||
}
|
||||
# for d in manifest_data_gen:
|
||||
# if d["chars"] in incorrect_set:
|
||||
# d["audio_path"].unlink()
|
||||
# renamed_set = set()
|
||||
for d in ui_data:
|
||||
if d["utterance_id"] in correct_set:
|
||||
yield {
|
||||
"audio_filepath": d["audio_filepath"],
|
||||
"duration": d["duration"],
|
||||
"text": d["text"],
|
||||
}
|
||||
elif d["utterance_id"] in correction_map:
|
||||
correct_text = correction_map[d["utterance_id"]]
|
||||
if skip_incorrect:
|
||||
print(
|
||||
f'skipping incorrect {d["audio_path"]} corrected to {correct_text}'
|
||||
)
|
||||
else:
|
||||
orig_audio_path = Path(d["audio_path"])
|
||||
new_name = str(Path(tscript_uuid_fname(correct_text)).with_suffix(".wav"))
|
||||
new_audio_path = orig_audio_path.with_name(new_name)
|
||||
orig_audio_path.replace(new_audio_path)
|
||||
new_filepath = str(Path(d["audio_filepath"]).with_name(new_name))
|
||||
yield {
|
||||
"audio_filepath": new_filepath,
|
||||
"duration": d["duration"],
|
||||
"text": correct_text,
|
||||
}
|
||||
else:
|
||||
orig_audio_path = Path(d["audio_path"])
|
||||
# don't delete if another correction points to an old file
|
||||
# if d["text"] not in renamed_set:
|
||||
orig_audio_path.unlink()
|
||||
# else:
|
||||
# print(f'skipping deletion of correction:{d["text"]}')
|
||||
|
||||
typer.echo(f"Using data manifest:{data_manifest_path}")
|
||||
dataset_dir = data_manifest_path.parent
|
||||
dataset_name = dataset_dir.name
|
||||
backup_dir = dataset_dir.with_name(dataset_name + ".bkp")
|
||||
if not backup_dir.exists():
|
||||
typer.echo(f"backing up to :{backup_dir}")
|
||||
shutil.copytree(str(dataset_dir), str(backup_dir))
|
||||
# manifest_gen = asr_manifest_reader(data_manifest_path)
|
||||
corrected_manifest = correct_manifest(ui_dump_path, corrections_path)
|
||||
new_data_manifest_path = data_manifest_path.with_name("manifest.new")
|
||||
asr_manifest_writer(new_data_manifest_path, corrected_manifest)
|
||||
new_data_manifest_path.replace(data_manifest_path)
|
||||
|
||||
|
||||
@app.command()
|
||||
def clear_mongo_corrections():
|
||||
delete = typer.confirm("are you sure you want to clear mongo collection it?")
|
||||
if delete:
|
||||
col = get_mongo_conn(col="asr_validation")
|
||||
col.delete_many({"type": "correction"})
|
||||
col.delete_many({"type": "current_cursor"})
|
||||
typer.echo("deleted mongo collection.")
|
||||
return
|
||||
typer.echo("Aborted")
|
||||
|
||||
|
||||
def main():
|
||||
app()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -0,0 +1,38 @@
|
|||
import streamlit.ReportThread as ReportThread
|
||||
from streamlit.ScriptRequestQueue import RerunData
|
||||
from streamlit.ScriptRunner import RerunException
|
||||
from streamlit.server.Server import Server
|
||||
|
||||
|
||||
def rerun():
|
||||
"""Rerun a Streamlit app from the top!"""
|
||||
widget_states = _get_widget_states()
|
||||
raise RerunException(RerunData(widget_states))
|
||||
|
||||
|
||||
def _get_widget_states():
|
||||
# Hack to get the session object from Streamlit.
|
||||
|
||||
ctx = ReportThread.get_report_ctx()
|
||||
|
||||
session = None
|
||||
|
||||
current_server = Server.get_current()
|
||||
if hasattr(current_server, '_session_infos'):
|
||||
# Streamlit < 0.56
|
||||
session_infos = Server.get_current()._session_infos.values()
|
||||
else:
|
||||
session_infos = Server.get_current()._session_info_by_id.values()
|
||||
|
||||
for session_info in session_infos:
|
||||
if session_info.session.enqueue == ctx.enqueue:
|
||||
session = session_info.session
|
||||
|
||||
if session is None:
|
||||
raise RuntimeError(
|
||||
"Oh noes. Couldn't get your Streamlit Session object"
|
||||
"Are you doing something fancy with threads?"
|
||||
)
|
||||
# Got the session object!
|
||||
|
||||
return session._widget_states
|
||||
|
|
@ -0,0 +1,158 @@
|
|||
from pathlib import Path
|
||||
|
||||
import streamlit as st
|
||||
import typer
|
||||
from uuid import uuid4
|
||||
from ..utils import ExtendedPath, get_mongo_conn
|
||||
from .st_rerun import rerun
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
if not hasattr(st, "mongo_connected"):
|
||||
st.mongoclient = get_mongo_conn(col="asr_validation")
|
||||
mongo_conn = st.mongoclient
|
||||
st.task_id = str(uuid4())
|
||||
|
||||
def current_cursor_fn():
|
||||
# mongo_conn = st.mongoclient
|
||||
cursor_obj = mongo_conn.find_one(
|
||||
{"type": "current_cursor", "task_id": st.task_id}
|
||||
)
|
||||
cursor_val = cursor_obj["cursor"]
|
||||
return cursor_val
|
||||
|
||||
def update_cursor_fn(val=0):
|
||||
mongo_conn.find_one_and_update(
|
||||
{"type": "current_cursor", "task_id": st.task_id},
|
||||
{"$set": {"type": "current_cursor", "task_id": st.task_id, "cursor": val}},
|
||||
upsert=True,
|
||||
)
|
||||
rerun()
|
||||
|
||||
def get_correction_entry_fn(code):
|
||||
return mongo_conn.find_one(
|
||||
{"type": "correction", "code": code}, projection={"_id": False}
|
||||
)
|
||||
|
||||
def update_entry_fn(code, value):
|
||||
mongo_conn.find_one_and_update(
|
||||
{"type": "correction", "code": code},
|
||||
{"$set": {"value": value, "task_id": st.task_id}},
|
||||
upsert=True,
|
||||
)
|
||||
|
||||
def set_task_fn(mf_path):
|
||||
task_path = mf_path.parent / Path(f"task-{st.task_id}.lck")
|
||||
if not task_path.exists():
|
||||
print(f"creating task lock at {task_path}")
|
||||
task_path.touch()
|
||||
|
||||
st.get_current_cursor = current_cursor_fn
|
||||
st.update_cursor = update_cursor_fn
|
||||
st.get_correction_entry = get_correction_entry_fn
|
||||
st.update_entry = update_entry_fn
|
||||
st.set_task = set_task_fn
|
||||
st.mongo_connected = True
|
||||
cursor_obj = mongo_conn.find_one({"type": "current_cursor", "task_id": st.task_id})
|
||||
if not cursor_obj:
|
||||
update_cursor_fn(0)
|
||||
|
||||
|
||||
@st.cache()
|
||||
def load_ui_data(validation_ui_data_path: Path):
|
||||
typer.echo(f"Using validation ui data from {validation_ui_data_path}")
|
||||
return ExtendedPath(validation_ui_data_path).read_json()
|
||||
|
||||
|
||||
@app.command()
|
||||
def main(manifest: Path):
|
||||
st.set_task(manifest)
|
||||
ui_config = load_ui_data(manifest)
|
||||
asr_data = ui_config["data"]
|
||||
use_domain_asr = ui_config.get("use_domain_asr", True)
|
||||
annotation_only = ui_config.get("annotation_only", False)
|
||||
enable_plots = ui_config.get("enable_plots", True)
|
||||
sample_no = st.get_current_cursor()
|
||||
if len(asr_data) - 1 < sample_no or sample_no < 0:
|
||||
print("Invalid samplno resetting to 0")
|
||||
st.update_cursor(0)
|
||||
sample = asr_data[sample_no]
|
||||
title_type = "Speller " if use_domain_asr else ""
|
||||
task_uid = st.task_id.rsplit("-", 1)[1]
|
||||
if annotation_only:
|
||||
st.title(f"ASR Annotation - # {task_uid}")
|
||||
else:
|
||||
st.title(f"ASR {title_type}Validation - # {task_uid}")
|
||||
addl_text = f"spelled *{sample['spoken']}*" if use_domain_asr else ""
|
||||
st.markdown(f"{sample_no+1} of {len(asr_data)} : **{sample['text']}**" + addl_text)
|
||||
new_sample = st.number_input(
|
||||
"Go To Sample:", value=sample_no + 1, min_value=1, max_value=len(asr_data)
|
||||
)
|
||||
if new_sample != sample_no + 1:
|
||||
st.update_cursor(new_sample - 1)
|
||||
st.sidebar.title(f"Details: [{sample['real_idx']}]")
|
||||
st.sidebar.markdown(f"Gold Text: **{sample['text']}**")
|
||||
if not annotation_only:
|
||||
if use_domain_asr:
|
||||
st.sidebar.markdown(f"Expected Spelled: *{sample['spoken']}*")
|
||||
st.sidebar.title("Results:")
|
||||
st.sidebar.markdown(f"Pretrained: **{sample['pretrained_asr']}**")
|
||||
if "caller" in sample:
|
||||
st.sidebar.markdown(f"Caller: **{sample['caller']}**")
|
||||
if use_domain_asr:
|
||||
st.sidebar.markdown(f"Domain: **{sample['domain_asr']}**")
|
||||
st.sidebar.title(f"Speller WER: {sample['domain_wer']:.2f}%")
|
||||
else:
|
||||
st.sidebar.title(f"Pretrained WER: {sample['pretrained_wer']:.2f}%")
|
||||
if enable_plots:
|
||||
st.sidebar.image(Path(sample["plot_path"]).read_bytes())
|
||||
st.audio(Path(sample["audio_path"]).open("rb"))
|
||||
# set default to text
|
||||
corrected = sample["text"]
|
||||
correction_entry = st.get_correction_entry(sample["utterance_id"])
|
||||
selected_idx = 0
|
||||
options = ("Correct", "Incorrect", "Inaudible")
|
||||
# if correction entry is present set the corresponding ui defaults
|
||||
if correction_entry:
|
||||
selected_idx = options.index(correction_entry["value"]["status"])
|
||||
corrected = correction_entry["value"]["correction"]
|
||||
selected = st.radio("The Audio is", options, index=selected_idx)
|
||||
if selected == "Incorrect":
|
||||
corrected = st.text_input("Actual:", value=corrected)
|
||||
if selected == "Inaudible":
|
||||
corrected = ""
|
||||
if st.button("Submit"):
|
||||
st.update_entry(
|
||||
sample["utterance_id"], {"status": selected, "correction": corrected}
|
||||
)
|
||||
st.update_cursor(sample_no + 1)
|
||||
if correction_entry:
|
||||
st.markdown(
|
||||
f'Your Response: **{correction_entry["value"]["status"]}** Correction: **{correction_entry["value"]["correction"]}**'
|
||||
)
|
||||
text_sample = st.text_input("Go to Text:", value="")
|
||||
if text_sample != "":
|
||||
candidates = [
|
||||
i
|
||||
for (i, p) in enumerate(asr_data)
|
||||
if p["text"] == text_sample or p["spoken"] == text_sample
|
||||
]
|
||||
if len(candidates) > 0:
|
||||
st.update_cursor(candidates[0])
|
||||
real_idx = st.number_input(
|
||||
"Go to real-index",
|
||||
value=sample["real_idx"],
|
||||
min_value=0,
|
||||
max_value=len(asr_data) - 1,
|
||||
)
|
||||
if real_idx != int(sample["real_idx"]):
|
||||
idx = [i for (i, p) in enumerate(asr_data) if p["real_idx"] == real_idx][0]
|
||||
st.update_cursor(idx)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
app()
|
||||
except SystemExit:
|
||||
pass
|
||||
|
|
@ -0,0 +1,359 @@
|
|||
# Copyright (c) 2019 NVIDIA Corporation
|
||||
import argparse
|
||||
import copy
|
||||
# import math
|
||||
import os
|
||||
from pathlib import Path
|
||||
from functools import partial
|
||||
|
||||
from ruamel.yaml import YAML
|
||||
|
||||
import nemo
|
||||
import nemo.collections.asr as nemo_asr
|
||||
import nemo.utils.argparse as nm_argparse
|
||||
from nemo.collections.asr.helpers import (
|
||||
# monitor_asr_train_progress,
|
||||
process_evaluation_batch,
|
||||
process_evaluation_epoch,
|
||||
)
|
||||
|
||||
# from nemo.utils.lr_policies import CosineAnnealing
|
||||
from training.data_loaders import RpycAudioToTextDataLayer
|
||||
|
||||
logging = nemo.logging
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
parents=[nm_argparse.NemoArgParser()],
|
||||
description="Jasper",
|
||||
conflict_handler="resolve",
|
||||
)
|
||||
parser.set_defaults(
|
||||
checkpoint_dir=None,
|
||||
optimizer="novograd",
|
||||
batch_size=64,
|
||||
eval_batch_size=64,
|
||||
lr=0.002,
|
||||
amp_opt_level="O1",
|
||||
create_tb_writer=True,
|
||||
model_config="./train/jasper10x5dr.yaml",
|
||||
work_dir="./train/work",
|
||||
num_epochs=300,
|
||||
weight_decay=0.005,
|
||||
checkpoint_save_freq=100,
|
||||
eval_freq=100,
|
||||
load_dir="./train/models/jasper/",
|
||||
warmup_steps=3,
|
||||
exp_name="jasper-speller",
|
||||
)
|
||||
|
||||
# Overwrite default args
|
||||
parser.add_argument(
|
||||
"--max_steps",
|
||||
type=int,
|
||||
default=None,
|
||||
required=False,
|
||||
help="max number of steps to train",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_epochs", type=int, required=False, help="number of epochs to train"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_config",
|
||||
type=str,
|
||||
required=False,
|
||||
help="model configuration file: model.yaml",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--encoder_checkpoint",
|
||||
type=str,
|
||||
required=True,
|
||||
help="encoder checkpoint file: JasperEncoder.pt",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--decoder_checkpoint",
|
||||
type=str,
|
||||
required=True,
|
||||
help="decoder checkpoint file: JasperDecoderForCTC.pt",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--remote_data",
|
||||
type=str,
|
||||
required=False,
|
||||
default="",
|
||||
help="remote dataloader endpoint",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
type=str,
|
||||
required=False,
|
||||
default="",
|
||||
help="dataset directory containing train/test manifests",
|
||||
)
|
||||
|
||||
# Create new args
|
||||
parser.add_argument("--exp_name", default="Jasper", type=str)
|
||||
parser.add_argument("--beta1", default=0.95, type=float)
|
||||
parser.add_argument("--beta2", default=0.25, type=float)
|
||||
parser.add_argument("--warmup_steps", default=0, type=int)
|
||||
parser.add_argument(
|
||||
"--load_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
help="directory with pre-trained checkpoint",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
if args.max_steps is None and args.num_epochs is None:
|
||||
raise ValueError("Either max_steps or num_epochs should be provided.")
|
||||
return args
|
||||
|
||||
|
||||
def construct_name(
|
||||
name, lr, batch_size, max_steps, num_epochs, wd, optimizer, iter_per_step
|
||||
):
|
||||
if max_steps is not None:
|
||||
return "{0}-lr_{1}-bs_{2}-s_{3}-wd_{4}-opt_{5}-ips_{6}".format(
|
||||
name, lr, batch_size, max_steps, wd, optimizer, iter_per_step
|
||||
)
|
||||
else:
|
||||
return "{0}-lr_{1}-bs_{2}-e_{3}-wd_{4}-opt_{5}-ips_{6}".format(
|
||||
name, lr, batch_size, num_epochs, wd, optimizer, iter_per_step
|
||||
)
|
||||
|
||||
|
||||
def create_all_dags(args, neural_factory):
|
||||
yaml = YAML(typ="safe")
|
||||
with open(args.model_config) as f:
|
||||
jasper_params = yaml.load(f)
|
||||
vocab = jasper_params["labels"]
|
||||
sample_rate = jasper_params["sample_rate"]
|
||||
|
||||
# Calculate num_workers for dataloader
|
||||
total_cpus = os.cpu_count()
|
||||
cpu_per_traindl = max(int(total_cpus / neural_factory.world_size), 1)
|
||||
# perturb_config = jasper_params.get('perturb', None)
|
||||
train_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"])
|
||||
train_dl_params.update(jasper_params["AudioToTextDataLayer"]["train"])
|
||||
del train_dl_params["train"]
|
||||
del train_dl_params["eval"]
|
||||
# del train_dl_params["normalize_transcripts"]
|
||||
|
||||
if args.dataset:
|
||||
d_path = Path(args.dataset)
|
||||
if not args.train_dataset:
|
||||
args.train_dataset = str(d_path / Path("train_manifest.json"))
|
||||
if not args.eval_datasets:
|
||||
args.eval_datasets = [str(d_path / Path("test_manifest.json"))]
|
||||
|
||||
data_loader_layer = nemo_asr.AudioToTextDataLayer
|
||||
|
||||
if args.remote_data:
|
||||
train_dl_params["rpyc_host"] = args.remote_data
|
||||
data_loader_layer = RpycAudioToTextDataLayer
|
||||
|
||||
# data_layer = data_loader_layer(
|
||||
# manifest_filepath=args.train_dataset,
|
||||
# sample_rate=sample_rate,
|
||||
# labels=vocab,
|
||||
# batch_size=args.batch_size,
|
||||
# num_workers=cpu_per_traindl,
|
||||
# **train_dl_params,
|
||||
# # normalize_transcripts=False
|
||||
# )
|
||||
#
|
||||
# N = len(data_layer)
|
||||
# steps_per_epoch = math.ceil(
|
||||
# N / (args.batch_size * args.iter_per_step * args.num_gpus)
|
||||
# )
|
||||
# logging.info("Have {0} examples to train on.".format(N))
|
||||
#
|
||||
data_preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor(
|
||||
sample_rate=sample_rate, **jasper_params["AudioToMelSpectrogramPreprocessor"]
|
||||
)
|
||||
|
||||
# multiply_batch_config = jasper_params.get("MultiplyBatch", None)
|
||||
# if multiply_batch_config:
|
||||
# multiply_batch = nemo_asr.MultiplyBatch(**multiply_batch_config)
|
||||
#
|
||||
# spectr_augment_config = jasper_params.get("SpectrogramAugmentation", None)
|
||||
# if spectr_augment_config:
|
||||
# data_spectr_augmentation = nemo_asr.SpectrogramAugmentation(
|
||||
# **spectr_augment_config
|
||||
# )
|
||||
#
|
||||
eval_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"])
|
||||
eval_dl_params.update(jasper_params["AudioToTextDataLayer"]["eval"])
|
||||
if args.remote_data:
|
||||
eval_dl_params["rpyc_host"] = args.remote_data
|
||||
del eval_dl_params["train"]
|
||||
del eval_dl_params["eval"]
|
||||
data_layers_eval = []
|
||||
|
||||
# if args.eval_datasets:
|
||||
for eval_datasets in args.eval_datasets:
|
||||
data_layer_eval = data_loader_layer(
|
||||
manifest_filepath=eval_datasets,
|
||||
sample_rate=sample_rate,
|
||||
labels=vocab,
|
||||
batch_size=args.eval_batch_size,
|
||||
num_workers=cpu_per_traindl,
|
||||
**eval_dl_params,
|
||||
)
|
||||
|
||||
data_layers_eval.append(data_layer_eval)
|
||||
# else:
|
||||
# logging.warning("There were no val datasets passed")
|
||||
|
||||
jasper_encoder = nemo_asr.JasperEncoder(
|
||||
feat_in=jasper_params["AudioToMelSpectrogramPreprocessor"]["features"],
|
||||
**jasper_params["JasperEncoder"],
|
||||
)
|
||||
jasper_encoder.restore_from(args.encoder_checkpoint, local_rank=0)
|
||||
|
||||
jasper_decoder = nemo_asr.JasperDecoderForCTC(
|
||||
feat_in=jasper_params["JasperEncoder"]["jasper"][-1]["filters"],
|
||||
num_classes=len(vocab),
|
||||
)
|
||||
jasper_decoder.restore_from(args.decoder_checkpoint, local_rank=0)
|
||||
|
||||
ctc_loss = nemo_asr.CTCLossNM(num_classes=len(vocab))
|
||||
|
||||
greedy_decoder = nemo_asr.GreedyCTCDecoder()
|
||||
|
||||
# logging.info("================================")
|
||||
# logging.info(f"Number of parameters in encoder: {jasper_encoder.num_weights}")
|
||||
# logging.info(f"Number of parameters in decoder: {jasper_decoder.num_weights}")
|
||||
# logging.info(
|
||||
# f"Total number of parameters in model: "
|
||||
# f"{jasper_decoder.num_weights + jasper_encoder.num_weights}"
|
||||
# )
|
||||
# logging.info("================================")
|
||||
#
|
||||
# # Train DAG
|
||||
# (audio_signal_t, a_sig_length_t, transcript_t, transcript_len_t) = data_layer()
|
||||
# processed_signal_t, p_length_t = data_preprocessor(
|
||||
# input_signal=audio_signal_t, length=a_sig_length_t
|
||||
# )
|
||||
#
|
||||
# if multiply_batch_config:
|
||||
# (
|
||||
# processed_signal_t,
|
||||
# p_length_t,
|
||||
# transcript_t,
|
||||
# transcript_len_t,
|
||||
# ) = multiply_batch(
|
||||
# in_x=processed_signal_t,
|
||||
# in_x_len=p_length_t,
|
||||
# in_y=transcript_t,
|
||||
# in_y_len=transcript_len_t,
|
||||
# )
|
||||
#
|
||||
# if spectr_augment_config:
|
||||
# processed_signal_t = data_spectr_augmentation(input_spec=processed_signal_t)
|
||||
#
|
||||
# encoded_t, encoded_len_t = jasper_encoder(
|
||||
# audio_signal=processed_signal_t, length=p_length_t
|
||||
# )
|
||||
# log_probs_t = jasper_decoder(encoder_output=encoded_t)
|
||||
# predictions_t = greedy_decoder(log_probs=log_probs_t)
|
||||
# loss_t = ctc_loss(
|
||||
# log_probs=log_probs_t,
|
||||
# targets=transcript_t,
|
||||
# input_length=encoded_len_t,
|
||||
# target_length=transcript_len_t,
|
||||
# )
|
||||
#
|
||||
# # Callbacks needed to print info to console and Tensorboard
|
||||
# train_callback = nemo.core.SimpleLossLoggerCallback(
|
||||
# tensors=[loss_t, predictions_t, transcript_t, transcript_len_t],
|
||||
# print_func=partial(monitor_asr_train_progress, labels=vocab),
|
||||
# get_tb_values=lambda x: [("loss", x[0])],
|
||||
# tb_writer=neural_factory.tb_writer,
|
||||
# )
|
||||
#
|
||||
# chpt_callback = nemo.core.CheckpointCallback(
|
||||
# folder=neural_factory.checkpoint_dir,
|
||||
# load_from_folder=args.load_dir,
|
||||
# step_freq=args.checkpoint_save_freq,
|
||||
# checkpoints_to_keep=30,
|
||||
# )
|
||||
#
|
||||
# callbacks = [train_callback, chpt_callback]
|
||||
callbacks = []
|
||||
# assemble eval DAGs
|
||||
for i, eval_dl in enumerate(data_layers_eval):
|
||||
(audio_signal_e, a_sig_length_e, transcript_e, transcript_len_e) = eval_dl()
|
||||
processed_signal_e, p_length_e = data_preprocessor(
|
||||
input_signal=audio_signal_e, length=a_sig_length_e
|
||||
)
|
||||
encoded_e, encoded_len_e = jasper_encoder(
|
||||
audio_signal=processed_signal_e, length=p_length_e
|
||||
)
|
||||
log_probs_e = jasper_decoder(encoder_output=encoded_e)
|
||||
predictions_e = greedy_decoder(log_probs=log_probs_e)
|
||||
loss_e = ctc_loss(
|
||||
log_probs=log_probs_e,
|
||||
targets=transcript_e,
|
||||
input_length=encoded_len_e,
|
||||
target_length=transcript_len_e,
|
||||
)
|
||||
|
||||
# create corresponding eval callback
|
||||
tagname = os.path.basename(args.eval_datasets[i]).split(".")[0]
|
||||
eval_callback = nemo.core.EvaluatorCallback(
|
||||
eval_tensors=[loss_e, predictions_e, transcript_e, transcript_len_e],
|
||||
user_iter_callback=partial(process_evaluation_batch, labels=vocab),
|
||||
user_epochs_done_callback=partial(process_evaluation_epoch, tag=tagname),
|
||||
eval_step=args.eval_freq,
|
||||
tb_writer=neural_factory.tb_writer,
|
||||
)
|
||||
|
||||
callbacks.append(eval_callback)
|
||||
return callbacks
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
# name = construct_name(
|
||||
# args.exp_name,
|
||||
# args.lr,
|
||||
# args.batch_size,
|
||||
# args.max_steps,
|
||||
# args.num_epochs,
|
||||
# args.weight_decay,
|
||||
# args.optimizer,
|
||||
# args.iter_per_step,
|
||||
# )
|
||||
# log_dir = name
|
||||
# if args.work_dir:
|
||||
# log_dir = os.path.join(args.work_dir, name)
|
||||
|
||||
# instantiate Neural Factory with supported backend
|
||||
neural_factory = nemo.core.NeuralModuleFactory(
|
||||
placement=nemo.core.DeviceType.GPU,
|
||||
backend=nemo.core.Backend.PyTorch,
|
||||
# local_rank=args.local_rank,
|
||||
# optimization_level=args.amp_opt_level,
|
||||
# log_dir=log_dir,
|
||||
# checkpoint_dir=args.checkpoint_dir,
|
||||
# create_tb_writer=args.create_tb_writer,
|
||||
# files_to_copy=[args.model_config, __file__],
|
||||
# cudnn_benchmark=args.cudnn_benchmark,
|
||||
# tensorboard_dir=args.tensorboard_dir,
|
||||
)
|
||||
args.num_gpus = neural_factory.world_size
|
||||
|
||||
# checkpoint_dir = neural_factory.checkpoint_dir
|
||||
if args.local_rank is not None:
|
||||
logging.info("Doing ALL GPU")
|
||||
|
||||
# build dags
|
||||
callbacks = create_all_dags(args, neural_factory)
|
||||
# evaluate model
|
||||
neural_factory.eval(callbacks=callbacks)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -0,0 +1 @@
|
|||
|
||||
|
|
@ -0,0 +1,366 @@
|
|||
# Copyright (c) 2019 NVIDIA Corporation
|
||||
import argparse
|
||||
import copy
|
||||
import math
|
||||
import os
|
||||
from pathlib import Path
|
||||
from functools import partial
|
||||
|
||||
from ruamel.yaml import YAML
|
||||
|
||||
import nemo
|
||||
import nemo.collections.asr as nemo_asr
|
||||
import nemo.utils.argparse as nm_argparse
|
||||
from nemo.collections.asr.helpers import (
|
||||
monitor_asr_train_progress,
|
||||
process_evaluation_batch,
|
||||
process_evaluation_epoch,
|
||||
)
|
||||
|
||||
from nemo.utils.lr_policies import CosineAnnealing
|
||||
from .data_loaders import RpycAudioToTextDataLayer
|
||||
|
||||
logging = nemo.logging
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
parents=[nm_argparse.NemoArgParser()],
|
||||
description="Jasper",
|
||||
conflict_handler="resolve",
|
||||
)
|
||||
parser.set_defaults(
|
||||
checkpoint_dir=None,
|
||||
optimizer="novograd",
|
||||
batch_size=64,
|
||||
eval_batch_size=64,
|
||||
lr=0.002,
|
||||
amp_opt_level="O1",
|
||||
create_tb_writer=True,
|
||||
model_config="./train/jasper10x5dr.yaml",
|
||||
work_dir="./train/work",
|
||||
num_epochs=300,
|
||||
weight_decay=0.005,
|
||||
checkpoint_save_freq=100,
|
||||
eval_freq=100,
|
||||
load_dir="./train/models/jasper/",
|
||||
warmup_steps=3,
|
||||
exp_name="jasper-speller",
|
||||
)
|
||||
|
||||
# Overwrite default args
|
||||
parser.add_argument(
|
||||
"--max_steps",
|
||||
type=int,
|
||||
default=None,
|
||||
required=False,
|
||||
help="max number of steps to train",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_epochs",
|
||||
type=int,
|
||||
required=False,
|
||||
help="number of epochs to train",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_config",
|
||||
type=str,
|
||||
required=False,
|
||||
help="model configuration file: model.yaml",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--remote_data",
|
||||
type=str,
|
||||
required=False,
|
||||
default="",
|
||||
help="remote dataloader endpoint",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
type=str,
|
||||
required=False,
|
||||
default="",
|
||||
help="dataset directory containing train/test manifests",
|
||||
)
|
||||
|
||||
# Create new args
|
||||
parser.add_argument("--exp_name", default="Jasper", type=str)
|
||||
parser.add_argument("--beta1", default=0.95, type=float)
|
||||
parser.add_argument("--beta2", default=0.25, type=float)
|
||||
parser.add_argument("--warmup_steps", default=0, type=int)
|
||||
parser.add_argument(
|
||||
"--load_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
help="directory with pre-trained checkpoint",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
if args.max_steps is None and args.num_epochs is None:
|
||||
raise ValueError("Either max_steps or num_epochs should be provided.")
|
||||
return args
|
||||
|
||||
|
||||
def construct_name(
|
||||
name, lr, batch_size, max_steps, num_epochs, wd, optimizer, iter_per_step
|
||||
):
|
||||
if max_steps is not None:
|
||||
return "{0}-lr_{1}-bs_{2}-s_{3}-wd_{4}-opt_{5}-ips_{6}".format(
|
||||
name, lr, batch_size, max_steps, wd, optimizer, iter_per_step
|
||||
)
|
||||
else:
|
||||
return "{0}-lr_{1}-bs_{2}-e_{3}-wd_{4}-opt_{5}-ips_{6}".format(
|
||||
name, lr, batch_size, num_epochs, wd, optimizer, iter_per_step
|
||||
)
|
||||
|
||||
|
||||
def create_all_dags(args, neural_factory):
|
||||
yaml = YAML(typ="safe")
|
||||
with open(args.model_config) as f:
|
||||
jasper_params = yaml.load(f)
|
||||
vocab = jasper_params["labels"]
|
||||
sample_rate = jasper_params["sample_rate"]
|
||||
|
||||
# Calculate num_workers for dataloader
|
||||
total_cpus = os.cpu_count()
|
||||
cpu_per_traindl = max(int(total_cpus / neural_factory.world_size), 1)
|
||||
# perturb_config = jasper_params.get('perturb', None)
|
||||
train_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"])
|
||||
train_dl_params.update(jasper_params["AudioToTextDataLayer"]["train"])
|
||||
del train_dl_params["train"]
|
||||
del train_dl_params["eval"]
|
||||
# del train_dl_params["normalize_transcripts"]
|
||||
|
||||
if args.dataset:
|
||||
d_path = Path(args.dataset)
|
||||
if not args.train_dataset:
|
||||
args.train_dataset = str(d_path / Path("train_manifest.json"))
|
||||
if not args.eval_datasets:
|
||||
args.eval_datasets = [str(d_path / Path("test_manifest.json"))]
|
||||
|
||||
data_loader_layer = nemo_asr.AudioToTextDataLayer
|
||||
|
||||
if args.remote_data:
|
||||
train_dl_params["rpyc_host"] = args.remote_data
|
||||
data_loader_layer = RpycAudioToTextDataLayer
|
||||
|
||||
data_layer = data_loader_layer(
|
||||
manifest_filepath=args.train_dataset,
|
||||
sample_rate=sample_rate,
|
||||
labels=vocab,
|
||||
batch_size=args.batch_size,
|
||||
num_workers=cpu_per_traindl,
|
||||
**train_dl_params,
|
||||
# normalize_transcripts=False
|
||||
)
|
||||
|
||||
N = len(data_layer)
|
||||
steps_per_epoch = math.ceil(
|
||||
N / (args.batch_size * args.iter_per_step * args.num_gpus)
|
||||
)
|
||||
logging.info("Have {0} examples to train on.".format(N))
|
||||
|
||||
data_preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor(
|
||||
sample_rate=sample_rate, **jasper_params["AudioToMelSpectrogramPreprocessor"]
|
||||
)
|
||||
|
||||
multiply_batch_config = jasper_params.get("MultiplyBatch", None)
|
||||
if multiply_batch_config:
|
||||
multiply_batch = nemo_asr.MultiplyBatch(**multiply_batch_config)
|
||||
|
||||
spectr_augment_config = jasper_params.get("SpectrogramAugmentation", None)
|
||||
if spectr_augment_config:
|
||||
data_spectr_augmentation = nemo_asr.SpectrogramAugmentation(
|
||||
**spectr_augment_config
|
||||
)
|
||||
|
||||
eval_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"])
|
||||
eval_dl_params.update(jasper_params["AudioToTextDataLayer"]["eval"])
|
||||
if args.remote_data:
|
||||
eval_dl_params["rpyc_host"] = args.remote_data
|
||||
del eval_dl_params["train"]
|
||||
del eval_dl_params["eval"]
|
||||
data_layers_eval = []
|
||||
|
||||
if args.eval_datasets:
|
||||
for eval_datasets in args.eval_datasets:
|
||||
data_layer_eval = data_loader_layer(
|
||||
manifest_filepath=eval_datasets,
|
||||
sample_rate=sample_rate,
|
||||
labels=vocab,
|
||||
batch_size=args.eval_batch_size,
|
||||
num_workers=cpu_per_traindl,
|
||||
**eval_dl_params,
|
||||
)
|
||||
|
||||
data_layers_eval.append(data_layer_eval)
|
||||
else:
|
||||
logging.warning("There were no val datasets passed")
|
||||
|
||||
jasper_encoder = nemo_asr.JasperEncoder(
|
||||
feat_in=jasper_params["AudioToMelSpectrogramPreprocessor"]["features"],
|
||||
**jasper_params["JasperEncoder"],
|
||||
)
|
||||
|
||||
jasper_decoder = nemo_asr.JasperDecoderForCTC(
|
||||
feat_in=jasper_params["JasperEncoder"]["jasper"][-1]["filters"],
|
||||
num_classes=len(vocab),
|
||||
)
|
||||
|
||||
ctc_loss = nemo_asr.CTCLossNM(num_classes=len(vocab))
|
||||
|
||||
greedy_decoder = nemo_asr.GreedyCTCDecoder()
|
||||
|
||||
logging.info("================================")
|
||||
logging.info(f"Number of parameters in encoder: {jasper_encoder.num_weights}")
|
||||
logging.info(f"Number of parameters in decoder: {jasper_decoder.num_weights}")
|
||||
logging.info(
|
||||
f"Total number of parameters in model: "
|
||||
f"{jasper_decoder.num_weights + jasper_encoder.num_weights}"
|
||||
)
|
||||
logging.info("================================")
|
||||
|
||||
# Train DAG
|
||||
(audio_signal_t, a_sig_length_t, transcript_t, transcript_len_t) = data_layer()
|
||||
processed_signal_t, p_length_t = data_preprocessor(
|
||||
input_signal=audio_signal_t, length=a_sig_length_t
|
||||
)
|
||||
|
||||
if multiply_batch_config:
|
||||
(
|
||||
processed_signal_t,
|
||||
p_length_t,
|
||||
transcript_t,
|
||||
transcript_len_t,
|
||||
) = multiply_batch(
|
||||
in_x=processed_signal_t,
|
||||
in_x_len=p_length_t,
|
||||
in_y=transcript_t,
|
||||
in_y_len=transcript_len_t,
|
||||
)
|
||||
|
||||
if spectr_augment_config:
|
||||
processed_signal_t = data_spectr_augmentation(input_spec=processed_signal_t)
|
||||
|
||||
encoded_t, encoded_len_t = jasper_encoder(
|
||||
audio_signal=processed_signal_t, length=p_length_t
|
||||
)
|
||||
log_probs_t = jasper_decoder(encoder_output=encoded_t)
|
||||
predictions_t = greedy_decoder(log_probs=log_probs_t)
|
||||
loss_t = ctc_loss(
|
||||
log_probs=log_probs_t,
|
||||
targets=transcript_t,
|
||||
input_length=encoded_len_t,
|
||||
target_length=transcript_len_t,
|
||||
)
|
||||
|
||||
# Callbacks needed to print info to console and Tensorboard
|
||||
train_callback = nemo.core.SimpleLossLoggerCallback(
|
||||
tensors=[loss_t, predictions_t, transcript_t, transcript_len_t],
|
||||
print_func=partial(monitor_asr_train_progress, labels=vocab),
|
||||
get_tb_values=lambda x: [("loss", x[0])],
|
||||
tb_writer=neural_factory.tb_writer,
|
||||
)
|
||||
|
||||
chpt_callback = nemo.core.CheckpointCallback(
|
||||
folder=neural_factory.checkpoint_dir,
|
||||
load_from_folder=args.load_dir,
|
||||
step_freq=args.checkpoint_save_freq,
|
||||
checkpoints_to_keep=30,
|
||||
)
|
||||
|
||||
callbacks = [train_callback, chpt_callback]
|
||||
|
||||
# assemble eval DAGs
|
||||
for i, eval_dl in enumerate(data_layers_eval):
|
||||
(audio_signal_e, a_sig_length_e, transcript_e, transcript_len_e) = eval_dl()
|
||||
processed_signal_e, p_length_e = data_preprocessor(
|
||||
input_signal=audio_signal_e, length=a_sig_length_e
|
||||
)
|
||||
encoded_e, encoded_len_e = jasper_encoder(
|
||||
audio_signal=processed_signal_e, length=p_length_e
|
||||
)
|
||||
log_probs_e = jasper_decoder(encoder_output=encoded_e)
|
||||
predictions_e = greedy_decoder(log_probs=log_probs_e)
|
||||
loss_e = ctc_loss(
|
||||
log_probs=log_probs_e,
|
||||
targets=transcript_e,
|
||||
input_length=encoded_len_e,
|
||||
target_length=transcript_len_e,
|
||||
)
|
||||
|
||||
# create corresponding eval callback
|
||||
tagname = os.path.basename(args.eval_datasets[i]).split(".")[0]
|
||||
eval_callback = nemo.core.EvaluatorCallback(
|
||||
eval_tensors=[loss_e, predictions_e, transcript_e, transcript_len_e],
|
||||
user_iter_callback=partial(process_evaluation_batch, labels=vocab),
|
||||
user_epochs_done_callback=partial(process_evaluation_epoch, tag=tagname),
|
||||
eval_step=args.eval_freq,
|
||||
tb_writer=neural_factory.tb_writer,
|
||||
)
|
||||
|
||||
callbacks.append(eval_callback)
|
||||
return loss_t, callbacks, steps_per_epoch
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
name = construct_name(
|
||||
args.exp_name,
|
||||
args.lr,
|
||||
args.batch_size,
|
||||
args.max_steps,
|
||||
args.num_epochs,
|
||||
args.weight_decay,
|
||||
args.optimizer,
|
||||
args.iter_per_step,
|
||||
)
|
||||
log_dir = name
|
||||
if args.work_dir:
|
||||
log_dir = os.path.join(args.work_dir, name)
|
||||
|
||||
# instantiate Neural Factory with supported backend
|
||||
neural_factory = nemo.core.NeuralModuleFactory(
|
||||
backend=nemo.core.Backend.PyTorch,
|
||||
local_rank=args.local_rank,
|
||||
optimization_level=args.amp_opt_level,
|
||||
log_dir=log_dir,
|
||||
checkpoint_dir=args.checkpoint_dir,
|
||||
create_tb_writer=args.create_tb_writer,
|
||||
files_to_copy=[args.model_config, __file__],
|
||||
cudnn_benchmark=args.cudnn_benchmark,
|
||||
tensorboard_dir=args.tensorboard_dir,
|
||||
)
|
||||
args.num_gpus = neural_factory.world_size
|
||||
|
||||
checkpoint_dir = neural_factory.checkpoint_dir
|
||||
if args.local_rank is not None:
|
||||
logging.info("Doing ALL GPU")
|
||||
|
||||
# build dags
|
||||
train_loss, callbacks, steps_per_epoch = create_all_dags(args, neural_factory)
|
||||
# train model
|
||||
neural_factory.train(
|
||||
tensors_to_optimize=[train_loss],
|
||||
callbacks=callbacks,
|
||||
lr_policy=CosineAnnealing(
|
||||
args.max_steps
|
||||
if args.max_steps is not None
|
||||
else args.num_epochs * steps_per_epoch,
|
||||
warmup_steps=args.warmup_steps,
|
||||
),
|
||||
optimizer=args.optimizer,
|
||||
optimization_params={
|
||||
"num_epochs": args.num_epochs,
|
||||
"max_steps": args.max_steps,
|
||||
"lr": args.lr,
|
||||
"betas": (args.beta1, args.beta2),
|
||||
"weight_decay": args.weight_decay,
|
||||
"grad_norm_clip": None,
|
||||
},
|
||||
batches_per_step=args.iter_per_step,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -0,0 +1,334 @@
|
|||
from functools import partial
|
||||
import tempfile
|
||||
|
||||
# from typing import Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
import nemo
|
||||
|
||||
# import nemo.collections.asr as nemo_asr
|
||||
from nemo.backends.pytorch import DataLayerNM
|
||||
from nemo.core import DeviceType
|
||||
|
||||
# from nemo.core.neural_types import *
|
||||
from nemo.core.neural_types import NeuralType, AudioSignal, LengthsType, LabelsType
|
||||
from nemo.utils.decorators import add_port_docs
|
||||
|
||||
from nemo.collections.asr.parts.dataset import (
|
||||
# AudioDataset,
|
||||
# AudioLabelDataset,
|
||||
# KaldiFeatureDataset,
|
||||
# TranscriptDataset,
|
||||
parsers,
|
||||
collections,
|
||||
seq_collate_fn,
|
||||
)
|
||||
|
||||
# from functools import lru_cache
|
||||
import rpyc
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from tqdm import tqdm
|
||||
from .featurizer import RpycWaveformFeaturizer
|
||||
|
||||
# from nemo.collections.asr.parts.features import WaveformFeaturizer
|
||||
|
||||
# from nemo.collections.asr.parts.perturb import AudioAugmentor, perturbation_types
|
||||
|
||||
|
||||
logging = nemo.logging
|
||||
|
||||
|
||||
class CachedAudioDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
Dataset that loads tensors via a json file containing paths to audio
|
||||
files, transcripts, and durations (in seconds). Each new line is a
|
||||
different sample. Example below:
|
||||
|
||||
{"audio_filepath": "/path/to/audio.wav", "text_filepath":
|
||||
"/path/to/audio.txt", "duration": 23.147}
|
||||
...
|
||||
{"audio_filepath": "/path/to/audio.wav", "text": "the
|
||||
transcription", offset": 301.75, "duration": 0.82, "utt":
|
||||
"utterance_id", "ctm_utt": "en_4156", "side": "A"}
|
||||
|
||||
Args:
|
||||
manifest_filepath: Path to manifest json as described above. Can
|
||||
be comma-separated paths.
|
||||
labels: String containing all the possible characters to map to
|
||||
featurizer: Initialized featurizer class that converts paths of
|
||||
audio to feature tensors
|
||||
max_duration: If audio exceeds this length, do not include in dataset
|
||||
min_duration: If audio is less than this length, do not include
|
||||
in dataset
|
||||
max_utts: Limit number of utterances
|
||||
blank_index: blank character index, default = -1
|
||||
unk_index: unk_character index, default = -1
|
||||
normalize: whether to normalize transcript text (default): True
|
||||
bos_id: Id of beginning of sequence symbol to append if not None
|
||||
eos_id: Id of end of sequence symbol to append if not None
|
||||
load_audio: Boolean flag indicate whether do or not load audio
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
manifest_filepath,
|
||||
labels,
|
||||
featurizer,
|
||||
max_duration=None,
|
||||
min_duration=None,
|
||||
max_utts=0,
|
||||
blank_index=-1,
|
||||
unk_index=-1,
|
||||
normalize=True,
|
||||
trim=False,
|
||||
bos_id=None,
|
||||
eos_id=None,
|
||||
load_audio=True,
|
||||
parser="en",
|
||||
):
|
||||
self.collection = collections.ASRAudioText(
|
||||
manifests_files=manifest_filepath.split(","),
|
||||
parser=parsers.make_parser(
|
||||
labels=labels,
|
||||
name=parser,
|
||||
unk_id=unk_index,
|
||||
blank_id=blank_index,
|
||||
do_normalize=normalize,
|
||||
),
|
||||
min_duration=min_duration,
|
||||
max_duration=max_duration,
|
||||
max_number=max_utts,
|
||||
)
|
||||
self.index_feature_map = {}
|
||||
|
||||
self.featurizer = featurizer
|
||||
self.trim = trim
|
||||
self.eos_id = eos_id
|
||||
self.bos_id = bos_id
|
||||
self.load_audio = load_audio
|
||||
print(f"initializing dataset {manifest_filepath}")
|
||||
|
||||
def exec_func(i):
|
||||
return self[i]
|
||||
|
||||
task_count = len(self.collection)
|
||||
with ThreadPoolExecutor() as exe:
|
||||
print("starting all loading tasks")
|
||||
list(
|
||||
tqdm(
|
||||
exe.map(exec_func, range(task_count)),
|
||||
position=0,
|
||||
leave=True,
|
||||
total=task_count,
|
||||
)
|
||||
)
|
||||
print(f"initializing complete")
|
||||
|
||||
def __getitem__(self, index):
|
||||
sample = self.collection[index]
|
||||
if self.load_audio:
|
||||
cached_features = self.index_feature_map.get(index)
|
||||
if cached_features is not None:
|
||||
features = cached_features
|
||||
else:
|
||||
features = self.featurizer.process(
|
||||
sample.audio_file,
|
||||
offset=0,
|
||||
duration=sample.duration,
|
||||
trim=self.trim,
|
||||
)
|
||||
self.index_feature_map[index] = features
|
||||
f, fl = features, torch.tensor(features.shape[0]).long()
|
||||
else:
|
||||
f, fl = None, None
|
||||
|
||||
t, tl = sample.text_tokens, len(sample.text_tokens)
|
||||
if self.bos_id is not None:
|
||||
t = [self.bos_id] + t
|
||||
tl += 1
|
||||
if self.eos_id is not None:
|
||||
t = t + [self.eos_id]
|
||||
tl += 1
|
||||
|
||||
return f, fl, torch.tensor(t).long(), torch.tensor(tl).long()
|
||||
|
||||
def __len__(self):
|
||||
return len(self.collection)
|
||||
|
||||
|
||||
class RpycAudioToTextDataLayer(DataLayerNM):
|
||||
"""Data Layer for general ASR tasks.
|
||||
|
||||
Module which reads ASR labeled data. It accepts comma-separated
|
||||
JSON manifest files describing the correspondence between wav audio files
|
||||
and their transcripts. JSON files should be of the following format::
|
||||
|
||||
{"audio_filepath": path_to_wav_0, "duration": time_in_sec_0, "text": \
|
||||
transcript_0}
|
||||
...
|
||||
{"audio_filepath": path_to_wav_n, "duration": time_in_sec_n, "text": \
|
||||
transcript_n}
|
||||
|
||||
Args:
|
||||
manifest_filepath (str): Dataset parameter.
|
||||
Path to JSON containing data.
|
||||
labels (list): Dataset parameter.
|
||||
List of characters that can be output by the ASR model.
|
||||
For Jasper, this is the 28 character set {a-z '}. The CTC blank
|
||||
symbol is automatically added later for models using ctc.
|
||||
batch_size (int): batch size
|
||||
sample_rate (int): Target sampling rate for data. Audio files will be
|
||||
resampled to sample_rate if it is not already.
|
||||
Defaults to 16000.
|
||||
int_values (bool): Bool indicating whether the audio file is saved as
|
||||
int data or float data.
|
||||
Defaults to False.
|
||||
eos_id (id): Dataset parameter.
|
||||
End of string symbol id used for seq2seq models.
|
||||
Defaults to None.
|
||||
min_duration (float): Dataset parameter.
|
||||
All training files which have a duration less than min_duration
|
||||
are dropped. Note: Duration is read from the manifest JSON.
|
||||
Defaults to 0.1.
|
||||
max_duration (float): Dataset parameter.
|
||||
All training files which have a duration more than max_duration
|
||||
are dropped. Note: Duration is read from the manifest JSON.
|
||||
Defaults to None.
|
||||
normalize_transcripts (bool): Dataset parameter.
|
||||
Whether to use automatic text cleaning.
|
||||
It is highly recommended to manually clean text for best results.
|
||||
Defaults to True.
|
||||
trim_silence (bool): Whether to use trim silence from beginning and end
|
||||
of audio signal using librosa.effects.trim().
|
||||
Defaults to False.
|
||||
load_audio (bool): Dataset parameter.
|
||||
Controls whether the dataloader loads the audio signal and
|
||||
transcript or just the transcript.
|
||||
Defaults to True.
|
||||
drop_last (bool): See PyTorch DataLoader.
|
||||
Defaults to False.
|
||||
shuffle (bool): See PyTorch DataLoader.
|
||||
Defaults to True.
|
||||
num_workers (int): See PyTorch DataLoader.
|
||||
Defaults to 0.
|
||||
perturb_config (dict): Currently disabled.
|
||||
"""
|
||||
|
||||
@property
|
||||
@add_port_docs()
|
||||
def output_ports(self):
|
||||
"""Returns definitions of module output ports.
|
||||
"""
|
||||
return {
|
||||
# 'audio_signal': NeuralType({0: AxisType(BatchTag), 1: AxisType(TimeTag)}),
|
||||
# 'a_sig_length': NeuralType({0: AxisType(BatchTag)}),
|
||||
# 'transcripts': NeuralType({0: AxisType(BatchTag), 1: AxisType(TimeTag)}),
|
||||
# 'transcript_length': NeuralType({0: AxisType(BatchTag)}),
|
||||
"audio_signal": NeuralType(
|
||||
("B", "T"),
|
||||
AudioSignal(freq=self._sample_rate)
|
||||
if self is not None and self._sample_rate is not None
|
||||
else AudioSignal(),
|
||||
),
|
||||
"a_sig_length": NeuralType(tuple("B"), LengthsType()),
|
||||
"transcripts": NeuralType(("B", "T"), LabelsType()),
|
||||
"transcript_length": NeuralType(tuple("B"), LengthsType()),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
manifest_filepath,
|
||||
labels,
|
||||
batch_size,
|
||||
sample_rate=16000,
|
||||
int_values=False,
|
||||
bos_id=None,
|
||||
eos_id=None,
|
||||
pad_id=None,
|
||||
min_duration=0.1,
|
||||
max_duration=None,
|
||||
normalize_transcripts=True,
|
||||
trim_silence=False,
|
||||
load_audio=True,
|
||||
rpyc_host="",
|
||||
drop_last=False,
|
||||
shuffle=True,
|
||||
num_workers=0,
|
||||
):
|
||||
super().__init__()
|
||||
self._sample_rate = sample_rate
|
||||
|
||||
def rpyc_root_fn():
|
||||
return rpyc.connect(
|
||||
rpyc_host, 8064, config={"sync_request_timeout": 600}
|
||||
).root
|
||||
|
||||
rpyc_conn = rpyc_root_fn()
|
||||
|
||||
self._featurizer = RpycWaveformFeaturizer(
|
||||
sample_rate=self._sample_rate,
|
||||
int_values=int_values,
|
||||
augmentor=None,
|
||||
rpyc_conn=rpyc_conn,
|
||||
)
|
||||
|
||||
def read_remote_manifests():
|
||||
local_mp = []
|
||||
for mrp in manifest_filepath.split(","):
|
||||
md = rpyc_conn.read_path(mrp)
|
||||
mf = tempfile.NamedTemporaryFile(
|
||||
dir="/tmp", prefix="jasper_manifest.", delete=False
|
||||
)
|
||||
mf.write(md)
|
||||
mf.close()
|
||||
local_mp.append(mf.name)
|
||||
return ",".join(local_mp)
|
||||
|
||||
local_manifest_filepath = read_remote_manifests()
|
||||
dataset_params = {
|
||||
"manifest_filepath": local_manifest_filepath,
|
||||
"labels": labels,
|
||||
"featurizer": self._featurizer,
|
||||
"max_duration": max_duration,
|
||||
"min_duration": min_duration,
|
||||
"normalize": normalize_transcripts,
|
||||
"trim": trim_silence,
|
||||
"bos_id": bos_id,
|
||||
"eos_id": eos_id,
|
||||
"load_audio": load_audio,
|
||||
}
|
||||
|
||||
self._dataset = CachedAudioDataset(**dataset_params)
|
||||
self._batch_size = batch_size
|
||||
|
||||
# Set up data loader
|
||||
if self._placement == DeviceType.AllGpu:
|
||||
logging.info("Parallelizing Datalayer.")
|
||||
sampler = torch.utils.data.distributed.DistributedSampler(self._dataset)
|
||||
else:
|
||||
sampler = None
|
||||
|
||||
if batch_size == -1:
|
||||
batch_size = len(self._dataset)
|
||||
|
||||
pad_id = 0 if pad_id is None else pad_id
|
||||
self._dataloader = torch.utils.data.DataLoader(
|
||||
dataset=self._dataset,
|
||||
batch_size=batch_size,
|
||||
collate_fn=partial(seq_collate_fn, token_pad_value=pad_id),
|
||||
drop_last=drop_last,
|
||||
shuffle=shuffle if sampler is None else False,
|
||||
sampler=sampler,
|
||||
num_workers=1,
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
return len(self._dataset)
|
||||
|
||||
@property
|
||||
def dataset(self):
|
||||
return None
|
||||
|
||||
@property
|
||||
def data_iterator(self):
|
||||
return self._dataloader
|
||||
|
|
@ -0,0 +1,51 @@
|
|||
# import math
|
||||
|
||||
# import librosa
|
||||
import torch
|
||||
import pickle
|
||||
# import torch.nn as nn
|
||||
# from torch_stft import STFT
|
||||
|
||||
# from nemo import logging
|
||||
from nemo.collections.asr.parts.perturb import AudioAugmentor
|
||||
# from nemo.collections.asr.parts.segment import AudioSegment
|
||||
|
||||
|
||||
class RpycWaveformFeaturizer(object):
|
||||
def __init__(
|
||||
self, sample_rate=16000, int_values=False, augmentor=None, rpyc_conn=None
|
||||
):
|
||||
self.augmentor = augmentor if augmentor is not None else AudioAugmentor()
|
||||
self.sample_rate = sample_rate
|
||||
self.int_values = int_values
|
||||
self.remote_path_samples = rpyc_conn.get_path_samples
|
||||
|
||||
def max_augmentation_length(self, length):
|
||||
return self.augmentor.max_augmentation_length(length)
|
||||
|
||||
def process(self, file_path, offset=0, duration=0, trim=False):
|
||||
audio = self.remote_path_samples(
|
||||
file_path,
|
||||
target_sr=self.sample_rate,
|
||||
int_values=self.int_values,
|
||||
offset=offset,
|
||||
duration=duration,
|
||||
trim=trim,
|
||||
)
|
||||
return torch.tensor(pickle.loads(audio), dtype=torch.float)
|
||||
|
||||
def process_segment(self, audio_segment):
|
||||
self.augmentor.perturb(audio_segment)
|
||||
return torch.tensor(audio_segment, dtype=torch.float)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, input_config, perturbation_configs=None):
|
||||
if perturbation_configs is not None:
|
||||
aa = AudioAugmentor.from_config(perturbation_configs)
|
||||
else:
|
||||
aa = None
|
||||
|
||||
sample_rate = input_config.get("sample_rate", 16000)
|
||||
int_values = input_config.get("int_values", False)
|
||||
|
||||
return cls(sample_rate=sample_rate, int_values=int_values, augmentor=aa)
|
||||
62
setup.py
62
setup.py
|
|
@ -1,11 +1,52 @@
|
|||
from setuptools import setup
|
||||
from setuptools import setup, find_packages
|
||||
|
||||
requirements = [
|
||||
"ruamel.yaml",
|
||||
"torch==1.4.0",
|
||||
"torchvision==0.5.0",
|
||||
"nemo_toolkit[asr] @ git+https://github.com/NVIDIA/NeMo.git@09e3ba4dfe333f86d6c5c1048e07210924294be9#egg=nemo_toolkit",
|
||||
]
|
||||
|
||||
extra_requirements = {"server": ["rpyc==4.1.4"]}
|
||||
extra_requirements = {
|
||||
"server": ["rpyc~=4.1.4", "tqdm~=4.39.0"],
|
||||
"data": [
|
||||
"google-cloud-texttospeech~=1.0.1",
|
||||
"tqdm~=4.39.0",
|
||||
"pydub~=0.24.0",
|
||||
"scikit_learn~=0.22.1",
|
||||
"pandas~=1.0.3",
|
||||
"boto3~=1.12.35",
|
||||
"ruamel.yaml==0.16.10",
|
||||
"pymongo==3.10.1",
|
||||
"librosa==0.7.2",
|
||||
"matplotlib==3.2.1",
|
||||
"pandas==1.0.3",
|
||||
"tabulate==0.8.7",
|
||||
"natural==0.2.0",
|
||||
"num2words==0.5.10",
|
||||
"typer[all]==0.1.1",
|
||||
"python-slugify==4.0.0",
|
||||
"lenses @ git+https://github.com/ingolemo/python-lenses.git@b2a2a9aa5b61540992d70b2cf36008d0121e8948#egg=lenses",
|
||||
],
|
||||
"validation": [
|
||||
"rpyc~=4.1.4",
|
||||
"pymongo==3.10.1",
|
||||
"typer[all]==0.1.1",
|
||||
"tqdm~=4.39.0",
|
||||
"librosa==0.7.2",
|
||||
"matplotlib==3.2.1",
|
||||
"pydub~=0.24.0",
|
||||
"streamlit==0.58.0",
|
||||
"natural==0.2.0",
|
||||
"stringcase==1.2.0",
|
||||
"google-cloud-speech~=1.3.1",
|
||||
]
|
||||
# "train": [
|
||||
# "torchaudio==0.5.0",
|
||||
# "torch-stft==0.1.4",
|
||||
# ]
|
||||
}
|
||||
packages = find_packages()
|
||||
|
||||
setup(
|
||||
name="jasper-asr",
|
||||
|
|
@ -17,11 +58,24 @@ setup(
|
|||
license="MIT",
|
||||
install_requires=requirements,
|
||||
extras_require=extra_requirements,
|
||||
packages=["."],
|
||||
packages=packages,
|
||||
entry_points={
|
||||
"console_scripts": [
|
||||
"jasper_transcribe = jasper.transcribe:main",
|
||||
"jasper_asr_rpyc_server = jasper.server:main",
|
||||
"jasper_server = jasper.server:main",
|
||||
"jasper_trainer = jasper.training.cli:main",
|
||||
"jasper_evaluator = jasper.evaluate:main",
|
||||
"jasper_data_tts_generate = jasper.data.tts_generator:main",
|
||||
"jasper_data_conv_generate = jasper.data.conv_generator:main",
|
||||
"jasper_data_nlu_generate = jasper.data.nlu_generator:main",
|
||||
"jasper_data_test_generate = jasper.data.test_generator:main",
|
||||
"jasper_data_call_recycle = jasper.data.call_recycler:main",
|
||||
"jasper_data_asr_recycle = jasper.data.asr_recycler:main",
|
||||
"jasper_data_rev_recycle = jasper.data.rev_recycler:main",
|
||||
"jasper_data_server = jasper.data.server:main",
|
||||
"jasper_data_validation = jasper.data.validation.process:main",
|
||||
"jasper_data_preprocess = jasper.data.process:main",
|
||||
"jasper_data_slu_evaluate = jasper.data.slu_evaluator:main",
|
||||
]
|
||||
},
|
||||
zip_safe=False,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,3 @@
|
|||
import runpy
|
||||
|
||||
runpy.run_module("jasper.data.validation.ui", run_name="__main__", alter_sys=True)
|
||||
Loading…
Reference in New Issue