mirror of
https://github.com/malarinv/jasper-asr.git
synced 2026-03-08 02:22:34 +00:00
Compare commits
7 Commits
f5c49338d9
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
| e30dd724f5 | |||
|
|
02df1b5282 | ||
| e8f58a5043 | |||
| 42647196fe | |||
| e77943b2f2 | |||
|
|
14d31a51c3 | ||
| e24a8cf9d0 |
4
.flake8
Normal file
4
.flake8
Normal file
@@ -0,0 +1,4 @@
|
||||
[flake8]
|
||||
exclude = docs
|
||||
ignore = E203, W503
|
||||
max-line-length = 119
|
||||
41
.gitignore
vendored
41
.gitignore
vendored
@@ -1,3 +1,11 @@
|
||||
/data/
|
||||
/model/
|
||||
/train/
|
||||
.env*
|
||||
*.yaml
|
||||
*.yml
|
||||
*.json
|
||||
|
||||
|
||||
# Created by https://www.gitignore.io/api/python
|
||||
# Edit at https://www.gitignore.io/?templates=python
|
||||
@@ -108,3 +116,36 @@ dmypy.json
|
||||
.pyre/
|
||||
|
||||
# End of https://www.gitignore.io/api/python
|
||||
|
||||
# Created by https://www.gitignore.io/api/macos
|
||||
# Edit at https://www.gitignore.io/?templates=macos
|
||||
|
||||
### macOS ###
|
||||
# General
|
||||
.DS_Store
|
||||
.AppleDouble
|
||||
.LSOverride
|
||||
|
||||
# Icon must end with two \r
|
||||
Icon
|
||||
|
||||
# Thumbnails
|
||||
._*
|
||||
|
||||
# Files that might appear in the root of a volume
|
||||
.DocumentRevisions-V100
|
||||
.fseventsd
|
||||
.Spotlight-V100
|
||||
.TemporaryItems
|
||||
.Trashes
|
||||
.VolumeIcon.icns
|
||||
.com.apple.timemachine.donotpresent
|
||||
|
||||
# Directories potentially created on remote AFP share
|
||||
.AppleDB
|
||||
.AppleDesktop
|
||||
Network Trash Folder
|
||||
Temporary Items
|
||||
.apdisk
|
||||
|
||||
# End of https://www.gitignore.io/api/macos
|
||||
|
||||
5
Notes.md
Normal file
5
Notes.md
Normal file
@@ -0,0 +1,5 @@
|
||||
|
||||
> Diff after splitting based on type
|
||||
```
|
||||
diff <(cat data/asr_data/call_upwork_test_cnd_*/manifest.json |sort) <(cat data/asr_data/call_upwork_test_cnd/manifest.json |sort)
|
||||
```
|
||||
@@ -7,10 +7,16 @@
|
||||
|
||||
# Table of Contents
|
||||
|
||||
* [Prerequisites](#prerequisites)
|
||||
* [Features](#features)
|
||||
* [Installation](#installation)
|
||||
* [Usage](#usage)
|
||||
|
||||
# Prerequisites
|
||||
```bash
|
||||
# apt install libsndfile-dev ffmpeg
|
||||
```
|
||||
|
||||
# Features
|
||||
|
||||
* ASR using Jasper (from [NemoToolkit](https://github.com/NVIDIA/NeMo) )
|
||||
|
||||
@@ -62,7 +62,7 @@ class JasperASR(object):
|
||||
wf = wave.open(audio_file_path, "w")
|
||||
wf.setnchannels(1)
|
||||
wf.setsampwidth(2)
|
||||
wf.setframerate(16000)
|
||||
wf.setframerate(24000)
|
||||
wf.writeframesraw(audio_data)
|
||||
wf.close()
|
||||
manifest = {"audio_filepath": audio_file_path, "duration": 60, "text": "todo"}
|
||||
@@ -108,6 +108,8 @@ class JasperASR(object):
|
||||
tensors = self.neural_factory.infer(tensors=eval_tensors)
|
||||
prediction = post_process_predictions(tensors[0], self.labels)
|
||||
prediction_text = ". ".join(prediction)
|
||||
os.unlink(manifest_file.name)
|
||||
os.unlink(audio_file.name)
|
||||
return prediction_text
|
||||
|
||||
def transcribe_file(self, audio_file, *args, **kwargs):
|
||||
|
||||
21
jasper/client.py
Normal file
21
jasper/client.py
Normal file
@@ -0,0 +1,21 @@
|
||||
import os
|
||||
import logging
|
||||
import rpyc
|
||||
from functools import lru_cache
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
ASR_HOST = os.environ.get("JASPER_ASR_RPYC_HOST", "localhost")
|
||||
ASR_PORT = int(os.environ.get("JASPER_ASR_RPYC_PORT", "8045"))
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def transcribe_gen(asr_host=ASR_HOST, asr_port=ASR_PORT):
|
||||
logger.info(f"connecting to asr server at {asr_host}:{asr_port}")
|
||||
asr = rpyc.connect(asr_host, asr_port).root
|
||||
logger.info(f"connected to asr server successfully")
|
||||
return asr.transcribe
|
||||
1
jasper/data/__init__.py
Normal file
1
jasper/data/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
77
jasper/data/process.py
Normal file
77
jasper/data/process.py
Normal file
@@ -0,0 +1,77 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
from sklearn.model_selection import train_test_split
|
||||
from .utils import asr_manifest_reader, asr_manifest_writer
|
||||
from typing import List
|
||||
from itertools import chain
|
||||
import typer
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
|
||||
def fixate_data(dataset_path: Path):
|
||||
manifest_path = dataset_path / Path("manifest.json")
|
||||
real_manifest_path = dataset_path / Path("abs_manifest.json")
|
||||
|
||||
def fix_path():
|
||||
for i in asr_manifest_reader(manifest_path):
|
||||
i["audio_filepath"] = str(dataset_path / Path(i["audio_filepath"]))
|
||||
yield i
|
||||
|
||||
asr_manifest_writer(real_manifest_path, fix_path())
|
||||
|
||||
|
||||
@app.command()
|
||||
def augment_data(src_dataset_paths: List[Path], dest_dataset_path: Path):
|
||||
reader_list = []
|
||||
abs_manifest_path = Path("abs_manifest.json")
|
||||
for dataset_path in src_dataset_paths:
|
||||
manifest_path = dataset_path / abs_manifest_path
|
||||
reader_list.append(asr_manifest_reader(manifest_path))
|
||||
dest_dataset_path.mkdir(parents=True, exist_ok=True)
|
||||
dest_manifest_path = dest_dataset_path / abs_manifest_path
|
||||
asr_manifest_writer(dest_manifest_path, chain(*reader_list))
|
||||
|
||||
|
||||
@app.command()
|
||||
def split_data(dataset_path: Path, test_size: float = 0.1):
|
||||
manifest_path = dataset_path / Path("abs_manifest.json")
|
||||
asr_data = list(asr_manifest_reader(manifest_path))
|
||||
train_data, test_data = train_test_split(asr_data, test_size=test_size)
|
||||
asr_manifest_writer(manifest_path.with_name("train_manifest.json"), train_data)
|
||||
asr_manifest_writer(manifest_path.with_name("test_manifest.json"), test_data)
|
||||
|
||||
|
||||
@app.command()
|
||||
def validate_data(dataset_path: Path):
|
||||
from natural.date import compress
|
||||
from datetime import timedelta
|
||||
|
||||
for mf_type in ["train_manifest.json", "test_manifest.json"]:
|
||||
data_file = dataset_path / Path(mf_type)
|
||||
print(f"validating {data_file}.")
|
||||
with Path(data_file).open("r") as pf:
|
||||
data_jsonl = pf.readlines()
|
||||
duration = 0
|
||||
for (i, s) in enumerate(data_jsonl):
|
||||
try:
|
||||
d = json.loads(s)
|
||||
duration += d["duration"]
|
||||
audio_file = data_file.parent / Path(d["audio_filepath"])
|
||||
if not audio_file.exists():
|
||||
raise OSError(f"File {audio_file} not found")
|
||||
except BaseException as e:
|
||||
print(f'failed on {i} with "{e}"')
|
||||
duration_str = compress(timedelta(seconds=duration), pad=" ")
|
||||
print(
|
||||
f"no errors found. seems like a valid {mf_type}. contains {duration_str}sec of audio"
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
app()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
93
jasper/data/rastrik_recycler.py
Normal file
93
jasper/data/rastrik_recycler.py
Normal file
@@ -0,0 +1,93 @@
|
||||
from rastrik.proto.callrecord_pb2 import CallRecord
|
||||
import gzip
|
||||
from pydub import AudioSegment
|
||||
from .utils import ui_dump_manifest_writer, strip_silence
|
||||
|
||||
import typer
|
||||
from itertools import chain
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
|
||||
def extract_manifest(
|
||||
call_log_dir: Path = Path("./data/call_audio"),
|
||||
output_dir: Path = Path("./data"),
|
||||
dataset_name: str = "grassroot_pizzahut_v1",
|
||||
caller_name: str = "grassroot",
|
||||
verbose: bool = False,
|
||||
):
|
||||
call_asr_data: Path = output_dir / Path("asr_data")
|
||||
call_asr_data.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
def wav_pb2_generator(log_dir):
|
||||
for wav_path in log_dir.glob("**/*.wav"):
|
||||
if verbose:
|
||||
typer.echo(f"loading events for file {wav_path}")
|
||||
call_wav = AudioSegment.from_file_using_temporary_files(wav_path)
|
||||
meta_path = wav_path.with_suffix(".pb2.gz")
|
||||
yield call_wav, wav_path, meta_path
|
||||
|
||||
def read_event(call_wav, log_file):
|
||||
call_wav_0, call_wav_1 = call_wav.split_to_mono()
|
||||
with gzip.open(log_file, "rb") as log_h:
|
||||
record_data = log_h.read()
|
||||
cr = CallRecord()
|
||||
cr.ParseFromString(record_data)
|
||||
|
||||
first_audio_event_timestamp = next(
|
||||
(
|
||||
i
|
||||
for i in cr.events
|
||||
if i.WhichOneof("event_type") == "call_event"
|
||||
and i.call_event.WhichOneof("event_type") == "call_audio"
|
||||
)
|
||||
).timestamp.ToDatetime()
|
||||
|
||||
speech_events = [
|
||||
i
|
||||
for i in cr.events
|
||||
if i.WhichOneof("event_type") == "speech_event"
|
||||
and i.speech_event.WhichOneof("event_type") == "asr_final"
|
||||
]
|
||||
previous_event_timestamp = (
|
||||
first_audio_event_timestamp - first_audio_event_timestamp
|
||||
)
|
||||
for index, each_speech_events in enumerate(speech_events):
|
||||
asr_final = each_speech_events.speech_event.asr_final
|
||||
speech_timestamp = each_speech_events.timestamp.ToDatetime()
|
||||
actual_timestamp = speech_timestamp - first_audio_event_timestamp
|
||||
start_time = previous_event_timestamp.total_seconds() * 1000
|
||||
end_time = actual_timestamp.total_seconds() * 1000
|
||||
audio_segment = strip_silence(call_wav_1[start_time:end_time])
|
||||
|
||||
code_fb = BytesIO()
|
||||
audio_segment.export(code_fb, format="wav")
|
||||
wav_data = code_fb.getvalue()
|
||||
previous_event_timestamp = actual_timestamp
|
||||
duration = (end_time - start_time) / 1000
|
||||
yield asr_final, duration, wav_data, "grassroot", audio_segment
|
||||
|
||||
def generate_call_asr_data():
|
||||
full_data = []
|
||||
total_duration = 0
|
||||
for wav, wav_path, pb2_path in wav_pb2_generator(call_log_dir):
|
||||
asr_data = read_event(wav, pb2_path)
|
||||
total_duration += wav.duration_seconds
|
||||
full_data.append(asr_data)
|
||||
n_calls = len(full_data)
|
||||
typer.echo(f"loaded {n_calls} calls of duration {total_duration}s")
|
||||
n_dps = ui_dump_manifest_writer(call_asr_data, dataset_name, chain(*full_data))
|
||||
typer.echo(f"written {n_dps} data points")
|
||||
|
||||
generate_call_asr_data()
|
||||
|
||||
|
||||
def main():
|
||||
app()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
57
jasper/data/server.py
Normal file
57
jasper/data/server.py
Normal file
@@ -0,0 +1,57 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import typer
|
||||
import rpyc
|
||||
from rpyc.utils.server import ThreadedServer
|
||||
import nemo
|
||||
import pickle
|
||||
|
||||
# import nemo.collections.asr as nemo_asr
|
||||
from nemo.collections.asr.parts.segment import AudioSegment
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
nemo.core.NeuralModuleFactory(
|
||||
backend=nemo.core.Backend.PyTorch, placement=nemo.core.DeviceType.CPU
|
||||
)
|
||||
|
||||
|
||||
class ASRDataService(rpyc.Service):
|
||||
def exposed_get_path_samples(
|
||||
self, file_path, target_sr, int_values, offset, duration, trim
|
||||
):
|
||||
print(f"loading.. {file_path}")
|
||||
audio = AudioSegment.from_file(
|
||||
file_path,
|
||||
target_sr=target_sr,
|
||||
int_values=int_values,
|
||||
offset=offset,
|
||||
duration=duration,
|
||||
trim=trim,
|
||||
)
|
||||
# print(f"returning.. {len(audio.samples)} items of type{type(audio.samples)}")
|
||||
return pickle.dumps(audio.samples)
|
||||
|
||||
def exposed_read_path(self, file_path):
|
||||
# print(f"reading path.. {file_path}")
|
||||
return Path(file_path).read_bytes()
|
||||
|
||||
|
||||
@app.command()
|
||||
def run_server(port: int = 0):
|
||||
listen_port = port if port else int(os.environ.get("ASR_DARA_RPYC_PORT", "8064"))
|
||||
service = ASRDataService()
|
||||
t = ThreadedServer(
|
||||
service, port=listen_port, protocol_config={"allow_all_attrs": True}
|
||||
)
|
||||
typer.echo(f"starting asr server on {listen_port}...")
|
||||
t.start()
|
||||
|
||||
|
||||
def main():
|
||||
app()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
241
jasper/data/utils.py
Normal file
241
jasper/data/utils.py
Normal file
@@ -0,0 +1,241 @@
|
||||
import io
|
||||
import os
|
||||
import json
|
||||
import wave
|
||||
from pathlib import Path
|
||||
from functools import partial
|
||||
from uuid import uuid4
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
import pymongo
|
||||
from slugify import slugify
|
||||
from jasper.client import transcribe_gen
|
||||
from nemo.collections.asr.metrics import word_error_rate
|
||||
import matplotlib.pyplot as plt
|
||||
import librosa
|
||||
import librosa.display
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def manifest_str(path, dur, text):
|
||||
return (
|
||||
json.dumps({"audio_filepath": path, "duration": round(dur, 1), "text": text})
|
||||
+ "\n"
|
||||
)
|
||||
|
||||
|
||||
def wav_bytes(audio_bytes, frame_rate=24000):
|
||||
wf_b = io.BytesIO()
|
||||
with wave.open(wf_b, mode="w") as wf:
|
||||
wf.setnchannels(1)
|
||||
wf.setframerate(frame_rate)
|
||||
wf.setsampwidth(2)
|
||||
wf.writeframesraw(audio_bytes)
|
||||
return wf_b.getvalue()
|
||||
|
||||
|
||||
def tscript_uuid_fname(transcript):
|
||||
return str(uuid4()) + "_" + slugify(transcript, max_length=8)
|
||||
|
||||
|
||||
def asr_data_writer(output_dir, dataset_name, asr_data_source, verbose=False):
|
||||
dataset_dir = output_dir / Path(dataset_name)
|
||||
(dataset_dir / Path("wav")).mkdir(parents=True, exist_ok=True)
|
||||
asr_manifest = dataset_dir / Path("manifest.json")
|
||||
num_datapoints = 0
|
||||
with asr_manifest.open("w") as mf:
|
||||
print(f"writing manifest to {asr_manifest}")
|
||||
for transcript, audio_dur, wav_data in asr_data_source:
|
||||
fname = tscript_uuid_fname(transcript)
|
||||
audio_file = dataset_dir / Path("wav") / Path(fname).with_suffix(".wav")
|
||||
audio_file.write_bytes(wav_data)
|
||||
rel_data_path = audio_file.relative_to(dataset_dir)
|
||||
manifest = manifest_str(str(rel_data_path), audio_dur, transcript)
|
||||
mf.write(manifest)
|
||||
if verbose:
|
||||
print(f"writing '{transcript}' of duration {audio_dur}")
|
||||
num_datapoints += 1
|
||||
return num_datapoints
|
||||
|
||||
|
||||
def ui_data_generator(output_dir, dataset_name, asr_data_source, verbose=False):
|
||||
dataset_dir = output_dir / Path(dataset_name)
|
||||
(dataset_dir / Path("wav")).mkdir(parents=True, exist_ok=True)
|
||||
(dataset_dir / Path("wav_plots")).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def data_fn(
|
||||
transcript,
|
||||
audio_dur,
|
||||
wav_data,
|
||||
caller_name,
|
||||
aud_seg,
|
||||
fname,
|
||||
audio_path,
|
||||
num_datapoints,
|
||||
rel_data_path,
|
||||
):
|
||||
pretrained_result = transcriber_pretrained(aud_seg.raw_data)
|
||||
pretrained_wer = word_error_rate([transcript], [pretrained_result])
|
||||
png_path = Path(fname).with_suffix(".png")
|
||||
wav_plot_path = dataset_dir / Path("wav_plots") / png_path
|
||||
if not wav_plot_path.exists():
|
||||
plot_seg(wav_plot_path, audio_path)
|
||||
return {
|
||||
"audio_filepath": str(rel_data_path),
|
||||
"duration": round(audio_dur, 1),
|
||||
"text": transcript,
|
||||
"real_idx": num_datapoints,
|
||||
"audio_path": audio_path,
|
||||
"spoken": transcript,
|
||||
"caller": caller_name,
|
||||
"utterance_id": fname,
|
||||
"pretrained_asr": pretrained_result,
|
||||
"pretrained_wer": pretrained_wer,
|
||||
"plot_path": str(wav_plot_path),
|
||||
}
|
||||
|
||||
num_datapoints = 0
|
||||
data_funcs = []
|
||||
transcriber_pretrained = transcribe_gen(asr_port=8044)
|
||||
for transcript, audio_dur, wav_data, caller_name, aud_seg in asr_data_source:
|
||||
fname = str(uuid4()) + "_" + slugify(transcript, max_length=8)
|
||||
audio_file = dataset_dir / Path("wav") / Path(fname).with_suffix(".wav")
|
||||
audio_file.write_bytes(wav_data)
|
||||
audio_path = str(audio_file)
|
||||
rel_data_path = audio_file.relative_to(dataset_dir)
|
||||
data_funcs.append(
|
||||
partial(
|
||||
data_fn,
|
||||
transcript,
|
||||
audio_dur,
|
||||
wav_data,
|
||||
caller_name,
|
||||
aud_seg,
|
||||
fname,
|
||||
audio_path,
|
||||
num_datapoints,
|
||||
rel_data_path,
|
||||
)
|
||||
)
|
||||
num_datapoints += 1
|
||||
ui_data = parallel_apply(lambda x: x(), data_funcs)
|
||||
return ui_data, num_datapoints
|
||||
|
||||
|
||||
def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=False):
|
||||
dataset_dir = output_dir / Path(dataset_name)
|
||||
dump_data, num_datapoints = ui_data_generator(
|
||||
output_dir, dataset_name, asr_data_source, verbose=verbose
|
||||
)
|
||||
|
||||
asr_manifest = dataset_dir / Path("manifest.json")
|
||||
with asr_manifest.open("w") as mf:
|
||||
print(f"writing manifest to {asr_manifest}")
|
||||
for d in dump_data:
|
||||
rel_data_path = d["audio_filepath"]
|
||||
audio_dur = d["duration"]
|
||||
transcript = d["text"]
|
||||
manifest = manifest_str(str(rel_data_path), audio_dur, transcript)
|
||||
mf.write(manifest)
|
||||
|
||||
ui_dump_file = dataset_dir / Path("ui_dump.json")
|
||||
ExtendedPath(ui_dump_file).write_json({"data": dump_data})
|
||||
return num_datapoints
|
||||
|
||||
|
||||
def asr_manifest_reader(data_manifest_path: Path):
|
||||
print(f"reading manifest from {data_manifest_path}")
|
||||
with data_manifest_path.open("r") as pf:
|
||||
data_jsonl = pf.readlines()
|
||||
data_data = [json.loads(v) for v in data_jsonl]
|
||||
for p in data_data:
|
||||
p["audio_path"] = data_manifest_path.parent / Path(p["audio_filepath"])
|
||||
p["text"] = p["text"].strip()
|
||||
yield p
|
||||
|
||||
|
||||
def asr_manifest_writer(asr_manifest_path: Path, manifest_str_source):
|
||||
with asr_manifest_path.open("w") as mf:
|
||||
print(f"opening {asr_manifest_path} for writing manifest")
|
||||
for mani_dict in manifest_str_source:
|
||||
manifest = manifest_str(
|
||||
mani_dict["audio_filepath"], mani_dict["duration"], mani_dict["text"]
|
||||
)
|
||||
mf.write(manifest)
|
||||
|
||||
|
||||
def asr_test_writer(out_file_path: Path, source):
|
||||
def dd_str(dd, idx):
|
||||
path = dd["audio_filepath"]
|
||||
# dur = dd["duration"]
|
||||
# return f"SAY {idx}\nPAUSE 3\nPLAY {path}\nPAUSE 3\n\n"
|
||||
return f"PAUSE 2\nPLAY {path}\nPAUSE 60\n\n"
|
||||
|
||||
res_file = out_file_path.with_suffix(".result.json")
|
||||
with out_file_path.open("w") as of:
|
||||
print(f"opening {out_file_path} for writing test")
|
||||
results = []
|
||||
idx = 0
|
||||
for ui_dd in source:
|
||||
results.append(ui_dd)
|
||||
out_str = dd_str(ui_dd, idx)
|
||||
of.write(out_str)
|
||||
idx += 1
|
||||
of.write("DO_HANGUP\n")
|
||||
ExtendedPath(res_file).write_json(results)
|
||||
|
||||
|
||||
def batch(iterable, n=1):
|
||||
ls = len(iterable)
|
||||
return [iterable[ndx : min(ndx + n, ls)] for ndx in range(0, ls, n)]
|
||||
|
||||
|
||||
class ExtendedPath(type(Path())):
|
||||
"""docstring for ExtendedPath."""
|
||||
|
||||
def read_json(self):
|
||||
print(f"reading json from {self}")
|
||||
with self.open("r") as jf:
|
||||
return json.load(jf)
|
||||
|
||||
def write_json(self, data):
|
||||
print(f"writing json to {self}")
|
||||
self.parent.mkdir(parents=True, exist_ok=True)
|
||||
with self.open("w") as jf:
|
||||
return json.dump(data, jf, indent=2)
|
||||
|
||||
|
||||
def get_mongo_conn(host="", port=27017, db="test", col="calls"):
|
||||
mongo_host = host if host else os.environ.get("MONGO_HOST", "localhost")
|
||||
mongo_uri = f"mongodb://{mongo_host}:{port}/"
|
||||
return pymongo.MongoClient(mongo_uri)[db][col]
|
||||
|
||||
|
||||
def strip_silence(sound):
|
||||
from pydub.silence import detect_leading_silence
|
||||
|
||||
start_trim = detect_leading_silence(sound)
|
||||
end_trim = detect_leading_silence(sound.reverse())
|
||||
duration = len(sound)
|
||||
return sound[start_trim : duration - end_trim]
|
||||
|
||||
|
||||
def plot_seg(wav_plot_path, audio_path):
|
||||
fig = plt.Figure()
|
||||
ax = fig.add_subplot()
|
||||
(y, sr) = librosa.load(audio_path)
|
||||
librosa.display.waveplot(y=y, sr=sr, ax=ax)
|
||||
with wav_plot_path.open("wb") as wav_plot_f:
|
||||
fig.set_tight_layout(True)
|
||||
fig.savefig(wav_plot_f, format="png", dpi=50)
|
||||
|
||||
|
||||
def parallel_apply(fn, iterable, workers=8):
|
||||
with ThreadPoolExecutor(max_workers=workers) as exe:
|
||||
print(f"parallelly applying {fn}")
|
||||
return [
|
||||
res
|
||||
for res in tqdm(
|
||||
exe.map(fn, iterable), position=0, leave=True, total=len(iterable)
|
||||
)
|
||||
]
|
||||
1
jasper/data/validation/__init__.py
Normal file
1
jasper/data/validation/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
398
jasper/data/validation/process.py
Normal file
398
jasper/data/validation/process.py
Normal file
@@ -0,0 +1,398 @@
|
||||
import json
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import typer
|
||||
|
||||
from ..utils import (
|
||||
ExtendedPath,
|
||||
asr_manifest_reader,
|
||||
asr_manifest_writer,
|
||||
tscript_uuid_fname,
|
||||
get_mongo_conn,
|
||||
plot_seg,
|
||||
)
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
def preprocess_datapoint(idx, rel_root, sample):
|
||||
from pydub import AudioSegment
|
||||
from nemo.collections.asr.metrics import word_error_rate
|
||||
from jasper.client import transcribe_gen
|
||||
|
||||
try:
|
||||
res = dict(sample)
|
||||
res["real_idx"] = idx
|
||||
audio_path = rel_root / Path(sample["audio_filepath"])
|
||||
res["audio_path"] = str(audio_path)
|
||||
res["utterance_id"] = audio_path.stem
|
||||
transcriber_pretrained = transcribe_gen(asr_port=8044)
|
||||
|
||||
aud_seg = (
|
||||
AudioSegment.from_file_using_temporary_files(audio_path)
|
||||
.set_channels(1)
|
||||
.set_sample_width(2)
|
||||
.set_frame_rate(24000)
|
||||
)
|
||||
res["pretrained_asr"] = transcriber_pretrained(aud_seg.raw_data)
|
||||
res["pretrained_wer"] = word_error_rate([res["text"]], [res["pretrained_asr"]])
|
||||
wav_plot_path = (
|
||||
rel_root / Path("wav_plots") / Path(audio_path.name).with_suffix(".png")
|
||||
)
|
||||
if not wav_plot_path.exists():
|
||||
plot_seg(wav_plot_path, audio_path)
|
||||
res["plot_path"] = str(wav_plot_path)
|
||||
return res
|
||||
except BaseException as e:
|
||||
print(f'failed on {idx}: {sample["audio_filepath"]} with {e}')
|
||||
|
||||
|
||||
@app.command()
|
||||
def dump_ui(
|
||||
data_name: str = typer.Option("dataname", show_default=True),
|
||||
dataset_dir: Path = Path("./data/asr_data"),
|
||||
dump_dir: Path = Path("./data/valiation_data"),
|
||||
dump_fname: Path = typer.Option(Path("ui_dump.json"), show_default=True),
|
||||
):
|
||||
from io import BytesIO
|
||||
from pydub import AudioSegment
|
||||
from ..utils import ui_data_generator
|
||||
|
||||
data_manifest_path = dataset_dir / Path(data_name) / Path("manifest.json")
|
||||
plot_dir = data_manifest_path.parent / Path("wav_plots")
|
||||
plot_dir.mkdir(parents=True, exist_ok=True)
|
||||
typer.echo(f"Using data manifest:{data_manifest_path}")
|
||||
|
||||
def asr_data_source_gen():
|
||||
with data_manifest_path.open("r") as pf:
|
||||
data_jsonl = pf.readlines()
|
||||
for v in data_jsonl:
|
||||
sample = json.loads(v)
|
||||
rel_root = data_manifest_path.parent
|
||||
res = dict(sample)
|
||||
audio_path = rel_root / Path(sample["audio_filepath"])
|
||||
audio_segment = (
|
||||
AudioSegment.from_file_using_temporary_files(audio_path)
|
||||
.set_channels(1)
|
||||
.set_sample_width(2)
|
||||
.set_frame_rate(24000)
|
||||
)
|
||||
wav_plot_path = (
|
||||
rel_root
|
||||
/ Path("wav_plots")
|
||||
/ Path(audio_path.name).with_suffix(".png")
|
||||
)
|
||||
if not wav_plot_path.exists():
|
||||
plot_seg(wav_plot_path, audio_path)
|
||||
res["plot_path"] = str(wav_plot_path)
|
||||
code_fb = BytesIO()
|
||||
audio_segment.export(code_fb, format="wav")
|
||||
wav_data = code_fb.getvalue()
|
||||
duration = audio_segment.duration_seconds
|
||||
asr_final = res["text"]
|
||||
yield asr_final, duration, wav_data, "caller", audio_segment
|
||||
|
||||
dump_data, num_datapoints = ui_data_generator(
|
||||
dataset_dir, data_name, asr_data_source_gen()
|
||||
)
|
||||
ui_dump_file = dataset_dir / Path("ui_dump.json")
|
||||
ExtendedPath(ui_dump_file).write_json({"data": dump_data})
|
||||
|
||||
|
||||
@app.command()
|
||||
def sample_ui(
|
||||
data_name: str = typer.Option("dataname", show_default=True),
|
||||
dump_dir: Path = Path("./data/asr_data"),
|
||||
dump_file: Path = Path("ui_dump.json"),
|
||||
sample_count: int = typer.Option(80, show_default=True),
|
||||
sample_file: Path = Path("sample_dump.json"),
|
||||
):
|
||||
import pandas as pd
|
||||
|
||||
processed_data_path = dump_dir / Path(data_name) / dump_file
|
||||
sample_path = dump_dir / Path(data_name) / sample_file
|
||||
processed_data = ExtendedPath(processed_data_path).read_json()
|
||||
df = pd.DataFrame(processed_data["data"])
|
||||
samples_per_caller = sample_count // len(df["caller"].unique())
|
||||
caller_samples = pd.concat(
|
||||
[g.sample(samples_per_caller) for (c, g) in df.groupby("caller")]
|
||||
)
|
||||
caller_samples = caller_samples.reset_index(drop=True)
|
||||
caller_samples["real_idx"] = caller_samples.index
|
||||
sample_data = caller_samples.to_dict("records")
|
||||
processed_data["data"] = sample_data
|
||||
typer.echo(f"sampling {sample_count} datapoints")
|
||||
ExtendedPath(sample_path).write_json(processed_data)
|
||||
|
||||
|
||||
@app.command()
|
||||
def task_ui(
|
||||
data_name: str = typer.Option("dataname", show_default=True),
|
||||
dump_dir: Path = Path("./data/asr_data"),
|
||||
dump_file: Path = Path("ui_dump.json"),
|
||||
task_count: int = typer.Option(4, show_default=True),
|
||||
task_file: str = "task_dump",
|
||||
):
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
|
||||
processed_data_path = dump_dir / Path(data_name) / dump_file
|
||||
processed_data = ExtendedPath(processed_data_path).read_json()
|
||||
df = pd.DataFrame(processed_data["data"]).sample(frac=1).reset_index(drop=True)
|
||||
for t_idx, task_f in enumerate(np.array_split(df, task_count)):
|
||||
task_f = task_f.reset_index(drop=True)
|
||||
task_f["real_idx"] = task_f.index
|
||||
task_data = task_f.to_dict("records")
|
||||
processed_data["data"] = task_data
|
||||
task_path = dump_dir / Path(data_name) / Path(task_file + f"-{t_idx}.json")
|
||||
ExtendedPath(task_path).write_json(processed_data)
|
||||
|
||||
|
||||
@app.command()
|
||||
def dump_corrections(
|
||||
task_uid: str,
|
||||
data_name: str = typer.Option("dataname", show_default=True),
|
||||
dump_dir: Path = Path("./data/asr_data"),
|
||||
dump_fname: Path = Path("corrections.json"),
|
||||
):
|
||||
dump_path = dump_dir / Path(data_name) / dump_fname
|
||||
col = get_mongo_conn(col="asr_validation")
|
||||
task_id = [c for c in col.distinct("task_id") if c.rsplit("-", 1)[1] == task_uid][0]
|
||||
corrections = list(col.find({"type": "correction"}, projection={"_id": False}))
|
||||
cursor_obj = col.find(
|
||||
{"type": "correction", "task_id": task_id}, projection={"_id": False}
|
||||
)
|
||||
corrections = [c for c in cursor_obj]
|
||||
ExtendedPath(dump_path).write_json(corrections)
|
||||
|
||||
|
||||
@app.command()
|
||||
def caller_quality(
|
||||
task_uid: str,
|
||||
data_name: str = typer.Option("dataname", show_default=True),
|
||||
dump_dir: Path = Path("./data/asr_data"),
|
||||
dump_fname: Path = Path("ui_dump.json"),
|
||||
correction_fname: Path = Path("corrections.json"),
|
||||
):
|
||||
import copy
|
||||
import pandas as pd
|
||||
|
||||
dump_path = dump_dir / Path(data_name) / dump_fname
|
||||
correction_path = dump_dir / Path(data_name) / correction_fname
|
||||
dump_data = ExtendedPath(dump_path).read_json()
|
||||
|
||||
dump_map = {d["utterance_id"]: d for d in dump_data["data"]}
|
||||
correction_data = ExtendedPath(correction_path).read_json()
|
||||
|
||||
def correction_dp(c):
|
||||
dp = copy.deepcopy(dump_map[c["code"]])
|
||||
dp["valid"] = c["value"]["status"] == "Correct"
|
||||
return dp
|
||||
|
||||
corrected_dump = [
|
||||
correction_dp(c)
|
||||
for c in correction_data
|
||||
if c["task_id"].rsplit("-", 1)[1] == task_uid
|
||||
]
|
||||
df = pd.DataFrame(corrected_dump)
|
||||
print(f"Total samples: {len(df)}")
|
||||
for (c, g) in df.groupby("caller"):
|
||||
total = len(g)
|
||||
valid = len(g[g["valid"] == True])
|
||||
valid_rate = valid * 100 / total
|
||||
print(f"Caller: {c} Valid%:{valid_rate:.2f} of {total} samples")
|
||||
|
||||
|
||||
@app.command()
|
||||
def fill_unannotated(
|
||||
data_name: str = typer.Option("dataname", show_default=True),
|
||||
dump_dir: Path = Path("./data/valiation_data"),
|
||||
dump_file: Path = Path("ui_dump.json"),
|
||||
corrections_file: Path = Path("corrections.json"),
|
||||
):
|
||||
processed_data_path = dump_dir / Path(data_name) / dump_file
|
||||
corrections_path = dump_dir / Path(data_name) / corrections_file
|
||||
processed_data = json.load(processed_data_path.open())
|
||||
corrections = json.load(corrections_path.open())
|
||||
annotated_codes = {c["code"] for c in corrections}
|
||||
all_codes = {c["gold_chars"] for c in processed_data}
|
||||
unann_codes = all_codes - annotated_codes
|
||||
mongo_conn = get_mongo_conn(col="asr_validation")
|
||||
for c in unann_codes:
|
||||
mongo_conn.find_one_and_update(
|
||||
{"type": "correction", "code": c},
|
||||
{"$set": {"value": {"status": "Inaudible", "correction": ""}}},
|
||||
upsert=True,
|
||||
)
|
||||
|
||||
|
||||
@app.command()
|
||||
def split_extract(
|
||||
data_name: str = typer.Option("dataname", show_default=True),
|
||||
# dest_data_name: str = typer.Option("call_aldata_namephanum_date", show_default=True),
|
||||
# dump_dir: Path = Path("./data/valiation_data"),
|
||||
dump_dir: Path = Path("./data/asr_data"),
|
||||
dump_file: Path = Path("ui_dump.json"),
|
||||
manifest_file: Path = Path("manifest.json"),
|
||||
corrections_file: str = typer.Option("corrections.json", show_default=True),
|
||||
conv_data_path: Path = typer.Option(
|
||||
Path("./data/conv_data.json"), show_default=True
|
||||
),
|
||||
extraction_type: str = "all",
|
||||
):
|
||||
import shutil
|
||||
|
||||
data_manifest_path = dump_dir / Path(data_name) / manifest_file
|
||||
conv_data = ExtendedPath(conv_data_path).read_json()
|
||||
|
||||
def extract_data_of_type(extraction_key):
|
||||
extraction_vals = conv_data[extraction_key]
|
||||
dest_data_name = data_name + "_" + extraction_key.lower()
|
||||
|
||||
manifest_gen = asr_manifest_reader(data_manifest_path)
|
||||
dest_data_dir = dump_dir / Path(dest_data_name)
|
||||
dest_data_dir.mkdir(exist_ok=True, parents=True)
|
||||
(dest_data_dir / Path("wav")).mkdir(exist_ok=True, parents=True)
|
||||
dest_manifest_path = dest_data_dir / manifest_file
|
||||
dest_ui_path = dest_data_dir / dump_file
|
||||
|
||||
def extract_manifest(mg):
|
||||
for m in mg:
|
||||
if m["text"] in extraction_vals:
|
||||
shutil.copy(
|
||||
m["audio_path"], dest_data_dir / Path(m["audio_filepath"])
|
||||
)
|
||||
yield m
|
||||
|
||||
asr_manifest_writer(dest_manifest_path, extract_manifest(manifest_gen))
|
||||
|
||||
ui_data_path = dump_dir / Path(data_name) / dump_file
|
||||
orig_ui_data = ExtendedPath(ui_data_path).read_json()
|
||||
ui_data = orig_ui_data["data"]
|
||||
file_ui_map = {Path(u["audio_filepath"]).stem: u for u in ui_data}
|
||||
extracted_ui_data = list(
|
||||
filter(lambda u: u["text"] in extraction_vals, ui_data)
|
||||
)
|
||||
final_data = []
|
||||
for i, d in enumerate(extracted_ui_data):
|
||||
d["real_idx"] = i
|
||||
final_data.append(d)
|
||||
orig_ui_data["data"] = final_data
|
||||
ExtendedPath(dest_ui_path).write_json(orig_ui_data)
|
||||
|
||||
if corrections_file:
|
||||
dest_correction_path = dest_data_dir / corrections_file
|
||||
corrections_path = dump_dir / Path(data_name) / corrections_file
|
||||
corrections = json.load(corrections_path.open())
|
||||
extracted_corrections = list(
|
||||
filter(
|
||||
lambda c: c["code"] in file_ui_map
|
||||
and file_ui_map[c["code"]]["text"] in extraction_vals,
|
||||
corrections,
|
||||
)
|
||||
)
|
||||
ExtendedPath(dest_correction_path).write_json(extracted_corrections)
|
||||
|
||||
if extraction_type.value == "all":
|
||||
for ext_key in conv_data.keys():
|
||||
extract_data_of_type(ext_key)
|
||||
else:
|
||||
extract_data_of_type(extraction_type.value)
|
||||
|
||||
|
||||
@app.command()
|
||||
def update_corrections(
|
||||
data_name: str = typer.Option("dataname", show_default=True),
|
||||
dump_dir: Path = Path("./data/asr_data"),
|
||||
manifest_file: Path = Path("manifest.json"),
|
||||
corrections_file: Path = Path("corrections.json"),
|
||||
ui_dump_file: Path = Path("ui_dump.json"),
|
||||
skip_incorrect: bool = typer.Option(True, show_default=True),
|
||||
):
|
||||
data_manifest_path = dump_dir / Path(data_name) / manifest_file
|
||||
corrections_path = dump_dir / Path(data_name) / corrections_file
|
||||
ui_dump_path = dump_dir / Path(data_name) / ui_dump_file
|
||||
|
||||
def correct_manifest(ui_dump_path, corrections_path):
|
||||
corrections = ExtendedPath(corrections_path).read_json()
|
||||
ui_data = ExtendedPath(ui_dump_path).read_json()["data"]
|
||||
correct_set = {
|
||||
c["code"] for c in corrections if c["value"]["status"] == "Correct"
|
||||
}
|
||||
# incorrect_set = {c["code"] for c in corrections if c["value"]["status"] == "Inaudible"}
|
||||
correction_map = {
|
||||
c["code"]: c["value"]["correction"]
|
||||
for c in corrections
|
||||
if c["value"]["status"] == "Incorrect"
|
||||
}
|
||||
# for d in manifest_data_gen:
|
||||
# if d["chars"] in incorrect_set:
|
||||
# d["audio_path"].unlink()
|
||||
# renamed_set = set()
|
||||
for d in ui_data:
|
||||
if d["utterance_id"] in correct_set:
|
||||
yield {
|
||||
"audio_filepath": d["audio_filepath"],
|
||||
"duration": d["duration"],
|
||||
"text": d["text"],
|
||||
}
|
||||
elif d["utterance_id"] in correction_map:
|
||||
correct_text = correction_map[d["utterance_id"]]
|
||||
if skip_incorrect:
|
||||
print(
|
||||
f'skipping incorrect {d["audio_path"]} corrected to {correct_text}'
|
||||
)
|
||||
else:
|
||||
orig_audio_path = Path(d["audio_path"])
|
||||
new_name = str(
|
||||
Path(tscript_uuid_fname(correct_text)).with_suffix(".wav")
|
||||
)
|
||||
new_audio_path = orig_audio_path.with_name(new_name)
|
||||
orig_audio_path.replace(new_audio_path)
|
||||
new_filepath = str(Path(d["audio_filepath"]).with_name(new_name))
|
||||
yield {
|
||||
"audio_filepath": new_filepath,
|
||||
"duration": d["duration"],
|
||||
"text": correct_text,
|
||||
}
|
||||
else:
|
||||
orig_audio_path = Path(d["audio_path"])
|
||||
# don't delete if another correction points to an old file
|
||||
# if d["text"] not in renamed_set:
|
||||
orig_audio_path.unlink()
|
||||
# else:
|
||||
# print(f'skipping deletion of correction:{d["text"]}')
|
||||
|
||||
typer.echo(f"Using data manifest:{data_manifest_path}")
|
||||
dataset_dir = data_manifest_path.parent
|
||||
dataset_name = dataset_dir.name
|
||||
backup_dir = dataset_dir.with_name(dataset_name + ".bkp")
|
||||
if not backup_dir.exists():
|
||||
typer.echo(f"backing up to :{backup_dir}")
|
||||
shutil.copytree(str(dataset_dir), str(backup_dir))
|
||||
# manifest_gen = asr_manifest_reader(data_manifest_path)
|
||||
corrected_manifest = correct_manifest(ui_dump_path, corrections_path)
|
||||
new_data_manifest_path = data_manifest_path.with_name("manifest.new")
|
||||
asr_manifest_writer(new_data_manifest_path, corrected_manifest)
|
||||
new_data_manifest_path.replace(data_manifest_path)
|
||||
|
||||
|
||||
@app.command()
|
||||
def clear_mongo_corrections():
|
||||
delete = typer.confirm("are you sure you want to clear mongo collection it?")
|
||||
if delete:
|
||||
col = get_mongo_conn(col="asr_validation")
|
||||
col.delete_many({"type": "correction"})
|
||||
col.delete_many({"type": "current_cursor"})
|
||||
typer.echo("deleted mongo collection.")
|
||||
return
|
||||
typer.echo("Aborted")
|
||||
|
||||
|
||||
def main():
|
||||
app()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
38
jasper/data/validation/st_rerun.py
Normal file
38
jasper/data/validation/st_rerun.py
Normal file
@@ -0,0 +1,38 @@
|
||||
import streamlit.ReportThread as ReportThread
|
||||
from streamlit.ScriptRequestQueue import RerunData
|
||||
from streamlit.ScriptRunner import RerunException
|
||||
from streamlit.server.Server import Server
|
||||
|
||||
|
||||
def rerun():
|
||||
"""Rerun a Streamlit app from the top!"""
|
||||
widget_states = _get_widget_states()
|
||||
raise RerunException(RerunData(widget_states))
|
||||
|
||||
|
||||
def _get_widget_states():
|
||||
# Hack to get the session object from Streamlit.
|
||||
|
||||
ctx = ReportThread.get_report_ctx()
|
||||
|
||||
session = None
|
||||
|
||||
current_server = Server.get_current()
|
||||
if hasattr(current_server, '_session_infos'):
|
||||
# Streamlit < 0.56
|
||||
session_infos = Server.get_current()._session_infos.values()
|
||||
else:
|
||||
session_infos = Server.get_current()._session_info_by_id.values()
|
||||
|
||||
for session_info in session_infos:
|
||||
if session_info.session.enqueue == ctx.enqueue:
|
||||
session = session_info.session
|
||||
|
||||
if session is None:
|
||||
raise RuntimeError(
|
||||
"Oh noes. Couldn't get your Streamlit Session object"
|
||||
"Are you doing something fancy with threads?"
|
||||
)
|
||||
# Got the session object!
|
||||
|
||||
return session._widget_states
|
||||
146
jasper/data/validation/ui.py
Normal file
146
jasper/data/validation/ui.py
Normal file
@@ -0,0 +1,146 @@
|
||||
from pathlib import Path
|
||||
|
||||
import streamlit as st
|
||||
import typer
|
||||
from uuid import uuid4
|
||||
from ..utils import ExtendedPath, get_mongo_conn
|
||||
from .st_rerun import rerun
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
if not hasattr(st, "mongo_connected"):
|
||||
st.mongoclient = get_mongo_conn(col="asr_validation")
|
||||
mongo_conn = st.mongoclient
|
||||
st.task_id = str(uuid4())
|
||||
|
||||
def current_cursor_fn():
|
||||
# mongo_conn = st.mongoclient
|
||||
cursor_obj = mongo_conn.find_one(
|
||||
{"type": "current_cursor", "task_id": st.task_id}
|
||||
)
|
||||
cursor_val = cursor_obj["cursor"]
|
||||
return cursor_val
|
||||
|
||||
def update_cursor_fn(val=0):
|
||||
mongo_conn.find_one_and_update(
|
||||
{"type": "current_cursor", "task_id": st.task_id},
|
||||
{"$set": {"type": "current_cursor", "task_id": st.task_id, "cursor": val}},
|
||||
upsert=True,
|
||||
)
|
||||
rerun()
|
||||
|
||||
def get_correction_entry_fn(code):
|
||||
return mongo_conn.find_one(
|
||||
{"type": "correction", "code": code}, projection={"_id": False}
|
||||
)
|
||||
|
||||
def update_entry_fn(code, value):
|
||||
mongo_conn.find_one_and_update(
|
||||
{"type": "correction", "code": code},
|
||||
{"$set": {"value": value, "task_id": st.task_id}},
|
||||
upsert=True,
|
||||
)
|
||||
|
||||
def set_task_fn(mf_path, task_id):
|
||||
if task_id:
|
||||
st.task_id = task_id
|
||||
task_path = mf_path.parent / Path(f"task-{st.task_id}.lck")
|
||||
if not task_path.exists():
|
||||
print(f"creating task lock at {task_path}")
|
||||
task_path.touch()
|
||||
|
||||
st.get_current_cursor = current_cursor_fn
|
||||
st.update_cursor = update_cursor_fn
|
||||
st.get_correction_entry = get_correction_entry_fn
|
||||
st.update_entry = update_entry_fn
|
||||
st.set_task = set_task_fn
|
||||
st.mongo_connected = True
|
||||
cursor_obj = mongo_conn.find_one({"type": "current_cursor", "task_id": st.task_id})
|
||||
if not cursor_obj:
|
||||
update_cursor_fn(0)
|
||||
|
||||
|
||||
@st.cache()
|
||||
def load_ui_data(validation_ui_data_path: Path):
|
||||
typer.echo(f"Using validation ui data from {validation_ui_data_path}")
|
||||
return ExtendedPath(validation_ui_data_path).read_json()
|
||||
|
||||
|
||||
@app.command()
|
||||
def main(manifest: Path, task_id: str = ""):
|
||||
st.set_task(manifest, task_id)
|
||||
ui_config = load_ui_data(manifest)
|
||||
asr_data = ui_config["data"]
|
||||
annotation_only = ui_config.get("annotation_only", False)
|
||||
sample_no = st.get_current_cursor()
|
||||
if len(asr_data) - 1 < sample_no or sample_no < 0:
|
||||
print("Invalid samplno resetting to 0")
|
||||
st.update_cursor(0)
|
||||
sample = asr_data[sample_no]
|
||||
task_uid = st.task_id.rsplit("-", 1)[1]
|
||||
if annotation_only:
|
||||
st.title(f"ASR Annotation - # {task_uid}")
|
||||
else:
|
||||
st.title(f"ASR Validation - # {task_uid}")
|
||||
st.markdown(f"{sample_no+1} of {len(asr_data)} : **{sample['text']}**")
|
||||
new_sample = st.number_input(
|
||||
"Go To Sample:", value=sample_no + 1, min_value=1, max_value=len(asr_data)
|
||||
)
|
||||
if new_sample != sample_no + 1:
|
||||
st.update_cursor(new_sample - 1)
|
||||
st.sidebar.title(f"Details: [{sample['real_idx']}]")
|
||||
st.sidebar.markdown(f"Gold Text: **{sample['text']}**")
|
||||
if not annotation_only:
|
||||
st.sidebar.title("Results:")
|
||||
st.sidebar.markdown(f"Pretrained: **{sample['pretrained_asr']}**")
|
||||
if "caller" in sample:
|
||||
st.sidebar.markdown(f"Caller: **{sample['caller']}**")
|
||||
else:
|
||||
st.sidebar.title(f"Pretrained WER: {sample['pretrained_wer']:.2f}%")
|
||||
st.sidebar.image(Path(sample["plot_path"]).read_bytes())
|
||||
st.audio(Path(sample["audio_path"]).open("rb"))
|
||||
# set default to text
|
||||
corrected = sample["text"]
|
||||
correction_entry = st.get_correction_entry(sample["utterance_id"])
|
||||
selected_idx = 0
|
||||
options = ("Correct", "Incorrect", "Inaudible")
|
||||
# if correction entry is present set the corresponding ui defaults
|
||||
if correction_entry:
|
||||
selected_idx = options.index(correction_entry["value"]["status"])
|
||||
corrected = correction_entry["value"]["correction"]
|
||||
selected = st.radio("The Audio is", options, index=selected_idx)
|
||||
if selected == "Incorrect":
|
||||
corrected = st.text_input("Actual:", value=corrected)
|
||||
if selected == "Inaudible":
|
||||
corrected = ""
|
||||
if st.button("Submit"):
|
||||
st.update_entry(
|
||||
sample["utterance_id"], {"status": selected, "correction": corrected}
|
||||
)
|
||||
st.update_cursor(sample_no + 1)
|
||||
if correction_entry:
|
||||
status = correction_entry["value"]["status"]
|
||||
correction = correction_entry["value"]["correction"]
|
||||
st.markdown(f"Your Response: **{status}** Correction: **{correction}**")
|
||||
text_sample = st.text_input("Go to Text:", value="")
|
||||
if text_sample != "":
|
||||
candidates = [i for (i, p) in enumerate(asr_data) if p["text"] == text_sample]
|
||||
if len(candidates) > 0:
|
||||
st.update_cursor(candidates[0])
|
||||
real_idx = st.number_input(
|
||||
"Go to real-index",
|
||||
value=sample["real_idx"],
|
||||
min_value=0,
|
||||
max_value=len(asr_data) - 1,
|
||||
)
|
||||
if real_idx != int(sample["real_idx"]):
|
||||
idx = [i for (i, p) in enumerate(asr_data) if p["real_idx"] == real_idx][0]
|
||||
st.update_cursor(idx)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
app()
|
||||
except SystemExit:
|
||||
pass
|
||||
359
jasper/evaluate.py
Normal file
359
jasper/evaluate.py
Normal file
@@ -0,0 +1,359 @@
|
||||
# Copyright (c) 2019 NVIDIA Corporation
|
||||
import argparse
|
||||
import copy
|
||||
# import math
|
||||
import os
|
||||
from pathlib import Path
|
||||
from functools import partial
|
||||
|
||||
from ruamel.yaml import YAML
|
||||
|
||||
import nemo
|
||||
import nemo.collections.asr as nemo_asr
|
||||
import nemo.utils.argparse as nm_argparse
|
||||
from nemo.collections.asr.helpers import (
|
||||
# monitor_asr_train_progress,
|
||||
process_evaluation_batch,
|
||||
process_evaluation_epoch,
|
||||
)
|
||||
|
||||
# from nemo.utils.lr_policies import CosineAnnealing
|
||||
from training.data_loaders import RpycAudioToTextDataLayer
|
||||
|
||||
logging = nemo.logging
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
parents=[nm_argparse.NemoArgParser()],
|
||||
description="Jasper",
|
||||
conflict_handler="resolve",
|
||||
)
|
||||
parser.set_defaults(
|
||||
checkpoint_dir=None,
|
||||
optimizer="novograd",
|
||||
batch_size=64,
|
||||
eval_batch_size=64,
|
||||
lr=0.002,
|
||||
amp_opt_level="O1",
|
||||
create_tb_writer=True,
|
||||
model_config="./train/jasper10x5dr.yaml",
|
||||
work_dir="./train/work",
|
||||
num_epochs=300,
|
||||
weight_decay=0.005,
|
||||
checkpoint_save_freq=100,
|
||||
eval_freq=100,
|
||||
load_dir="./train/models/jasper/",
|
||||
warmup_steps=3,
|
||||
exp_name="jasper-speller",
|
||||
)
|
||||
|
||||
# Overwrite default args
|
||||
parser.add_argument(
|
||||
"--max_steps",
|
||||
type=int,
|
||||
default=None,
|
||||
required=False,
|
||||
help="max number of steps to train",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_epochs", type=int, required=False, help="number of epochs to train"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_config",
|
||||
type=str,
|
||||
required=False,
|
||||
help="model configuration file: model.yaml",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--encoder_checkpoint",
|
||||
type=str,
|
||||
required=True,
|
||||
help="encoder checkpoint file: JasperEncoder.pt",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--decoder_checkpoint",
|
||||
type=str,
|
||||
required=True,
|
||||
help="decoder checkpoint file: JasperDecoderForCTC.pt",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--remote_data",
|
||||
type=str,
|
||||
required=False,
|
||||
default="",
|
||||
help="remote dataloader endpoint",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
type=str,
|
||||
required=False,
|
||||
default="",
|
||||
help="dataset directory containing train/test manifests",
|
||||
)
|
||||
|
||||
# Create new args
|
||||
parser.add_argument("--exp_name", default="Jasper", type=str)
|
||||
parser.add_argument("--beta1", default=0.95, type=float)
|
||||
parser.add_argument("--beta2", default=0.25, type=float)
|
||||
parser.add_argument("--warmup_steps", default=0, type=int)
|
||||
parser.add_argument(
|
||||
"--load_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
help="directory with pre-trained checkpoint",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
if args.max_steps is None and args.num_epochs is None:
|
||||
raise ValueError("Either max_steps or num_epochs should be provided.")
|
||||
return args
|
||||
|
||||
|
||||
def construct_name(
|
||||
name, lr, batch_size, max_steps, num_epochs, wd, optimizer, iter_per_step
|
||||
):
|
||||
if max_steps is not None:
|
||||
return "{0}-lr_{1}-bs_{2}-s_{3}-wd_{4}-opt_{5}-ips_{6}".format(
|
||||
name, lr, batch_size, max_steps, wd, optimizer, iter_per_step
|
||||
)
|
||||
else:
|
||||
return "{0}-lr_{1}-bs_{2}-e_{3}-wd_{4}-opt_{5}-ips_{6}".format(
|
||||
name, lr, batch_size, num_epochs, wd, optimizer, iter_per_step
|
||||
)
|
||||
|
||||
|
||||
def create_all_dags(args, neural_factory):
|
||||
yaml = YAML(typ="safe")
|
||||
with open(args.model_config) as f:
|
||||
jasper_params = yaml.load(f)
|
||||
vocab = jasper_params["labels"]
|
||||
sample_rate = jasper_params["sample_rate"]
|
||||
|
||||
# Calculate num_workers for dataloader
|
||||
total_cpus = os.cpu_count()
|
||||
cpu_per_traindl = max(int(total_cpus / neural_factory.world_size), 1)
|
||||
# perturb_config = jasper_params.get('perturb', None)
|
||||
train_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"])
|
||||
train_dl_params.update(jasper_params["AudioToTextDataLayer"]["train"])
|
||||
del train_dl_params["train"]
|
||||
del train_dl_params["eval"]
|
||||
# del train_dl_params["normalize_transcripts"]
|
||||
|
||||
if args.dataset:
|
||||
d_path = Path(args.dataset)
|
||||
if not args.train_dataset:
|
||||
args.train_dataset = str(d_path / Path("train_manifest.json"))
|
||||
if not args.eval_datasets:
|
||||
args.eval_datasets = [str(d_path / Path("test_manifest.json"))]
|
||||
|
||||
data_loader_layer = nemo_asr.AudioToTextDataLayer
|
||||
|
||||
if args.remote_data:
|
||||
train_dl_params["rpyc_host"] = args.remote_data
|
||||
data_loader_layer = RpycAudioToTextDataLayer
|
||||
|
||||
# data_layer = data_loader_layer(
|
||||
# manifest_filepath=args.train_dataset,
|
||||
# sample_rate=sample_rate,
|
||||
# labels=vocab,
|
||||
# batch_size=args.batch_size,
|
||||
# num_workers=cpu_per_traindl,
|
||||
# **train_dl_params,
|
||||
# # normalize_transcripts=False
|
||||
# )
|
||||
#
|
||||
# N = len(data_layer)
|
||||
# steps_per_epoch = math.ceil(
|
||||
# N / (args.batch_size * args.iter_per_step * args.num_gpus)
|
||||
# )
|
||||
# logging.info("Have {0} examples to train on.".format(N))
|
||||
#
|
||||
data_preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor(
|
||||
sample_rate=sample_rate, **jasper_params["AudioToMelSpectrogramPreprocessor"]
|
||||
)
|
||||
|
||||
# multiply_batch_config = jasper_params.get("MultiplyBatch", None)
|
||||
# if multiply_batch_config:
|
||||
# multiply_batch = nemo_asr.MultiplyBatch(**multiply_batch_config)
|
||||
#
|
||||
# spectr_augment_config = jasper_params.get("SpectrogramAugmentation", None)
|
||||
# if spectr_augment_config:
|
||||
# data_spectr_augmentation = nemo_asr.SpectrogramAugmentation(
|
||||
# **spectr_augment_config
|
||||
# )
|
||||
#
|
||||
eval_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"])
|
||||
eval_dl_params.update(jasper_params["AudioToTextDataLayer"]["eval"])
|
||||
if args.remote_data:
|
||||
eval_dl_params["rpyc_host"] = args.remote_data
|
||||
del eval_dl_params["train"]
|
||||
del eval_dl_params["eval"]
|
||||
data_layers_eval = []
|
||||
|
||||
# if args.eval_datasets:
|
||||
for eval_datasets in args.eval_datasets:
|
||||
data_layer_eval = data_loader_layer(
|
||||
manifest_filepath=eval_datasets,
|
||||
sample_rate=sample_rate,
|
||||
labels=vocab,
|
||||
batch_size=args.eval_batch_size,
|
||||
num_workers=cpu_per_traindl,
|
||||
**eval_dl_params,
|
||||
)
|
||||
|
||||
data_layers_eval.append(data_layer_eval)
|
||||
# else:
|
||||
# logging.warning("There were no val datasets passed")
|
||||
|
||||
jasper_encoder = nemo_asr.JasperEncoder(
|
||||
feat_in=jasper_params["AudioToMelSpectrogramPreprocessor"]["features"],
|
||||
**jasper_params["JasperEncoder"],
|
||||
)
|
||||
jasper_encoder.restore_from(args.encoder_checkpoint, local_rank=0)
|
||||
|
||||
jasper_decoder = nemo_asr.JasperDecoderForCTC(
|
||||
feat_in=jasper_params["JasperEncoder"]["jasper"][-1]["filters"],
|
||||
num_classes=len(vocab),
|
||||
)
|
||||
jasper_decoder.restore_from(args.decoder_checkpoint, local_rank=0)
|
||||
|
||||
ctc_loss = nemo_asr.CTCLossNM(num_classes=len(vocab))
|
||||
|
||||
greedy_decoder = nemo_asr.GreedyCTCDecoder()
|
||||
|
||||
# logging.info("================================")
|
||||
# logging.info(f"Number of parameters in encoder: {jasper_encoder.num_weights}")
|
||||
# logging.info(f"Number of parameters in decoder: {jasper_decoder.num_weights}")
|
||||
# logging.info(
|
||||
# f"Total number of parameters in model: "
|
||||
# f"{jasper_decoder.num_weights + jasper_encoder.num_weights}"
|
||||
# )
|
||||
# logging.info("================================")
|
||||
#
|
||||
# # Train DAG
|
||||
# (audio_signal_t, a_sig_length_t, transcript_t, transcript_len_t) = data_layer()
|
||||
# processed_signal_t, p_length_t = data_preprocessor(
|
||||
# input_signal=audio_signal_t, length=a_sig_length_t
|
||||
# )
|
||||
#
|
||||
# if multiply_batch_config:
|
||||
# (
|
||||
# processed_signal_t,
|
||||
# p_length_t,
|
||||
# transcript_t,
|
||||
# transcript_len_t,
|
||||
# ) = multiply_batch(
|
||||
# in_x=processed_signal_t,
|
||||
# in_x_len=p_length_t,
|
||||
# in_y=transcript_t,
|
||||
# in_y_len=transcript_len_t,
|
||||
# )
|
||||
#
|
||||
# if spectr_augment_config:
|
||||
# processed_signal_t = data_spectr_augmentation(input_spec=processed_signal_t)
|
||||
#
|
||||
# encoded_t, encoded_len_t = jasper_encoder(
|
||||
# audio_signal=processed_signal_t, length=p_length_t
|
||||
# )
|
||||
# log_probs_t = jasper_decoder(encoder_output=encoded_t)
|
||||
# predictions_t = greedy_decoder(log_probs=log_probs_t)
|
||||
# loss_t = ctc_loss(
|
||||
# log_probs=log_probs_t,
|
||||
# targets=transcript_t,
|
||||
# input_length=encoded_len_t,
|
||||
# target_length=transcript_len_t,
|
||||
# )
|
||||
#
|
||||
# # Callbacks needed to print info to console and Tensorboard
|
||||
# train_callback = nemo.core.SimpleLossLoggerCallback(
|
||||
# tensors=[loss_t, predictions_t, transcript_t, transcript_len_t],
|
||||
# print_func=partial(monitor_asr_train_progress, labels=vocab),
|
||||
# get_tb_values=lambda x: [("loss", x[0])],
|
||||
# tb_writer=neural_factory.tb_writer,
|
||||
# )
|
||||
#
|
||||
# chpt_callback = nemo.core.CheckpointCallback(
|
||||
# folder=neural_factory.checkpoint_dir,
|
||||
# load_from_folder=args.load_dir,
|
||||
# step_freq=args.checkpoint_save_freq,
|
||||
# checkpoints_to_keep=30,
|
||||
# )
|
||||
#
|
||||
# callbacks = [train_callback, chpt_callback]
|
||||
callbacks = []
|
||||
# assemble eval DAGs
|
||||
for i, eval_dl in enumerate(data_layers_eval):
|
||||
(audio_signal_e, a_sig_length_e, transcript_e, transcript_len_e) = eval_dl()
|
||||
processed_signal_e, p_length_e = data_preprocessor(
|
||||
input_signal=audio_signal_e, length=a_sig_length_e
|
||||
)
|
||||
encoded_e, encoded_len_e = jasper_encoder(
|
||||
audio_signal=processed_signal_e, length=p_length_e
|
||||
)
|
||||
log_probs_e = jasper_decoder(encoder_output=encoded_e)
|
||||
predictions_e = greedy_decoder(log_probs=log_probs_e)
|
||||
loss_e = ctc_loss(
|
||||
log_probs=log_probs_e,
|
||||
targets=transcript_e,
|
||||
input_length=encoded_len_e,
|
||||
target_length=transcript_len_e,
|
||||
)
|
||||
|
||||
# create corresponding eval callback
|
||||
tagname = os.path.basename(args.eval_datasets[i]).split(".")[0]
|
||||
eval_callback = nemo.core.EvaluatorCallback(
|
||||
eval_tensors=[loss_e, predictions_e, transcript_e, transcript_len_e],
|
||||
user_iter_callback=partial(process_evaluation_batch, labels=vocab),
|
||||
user_epochs_done_callback=partial(process_evaluation_epoch, tag=tagname),
|
||||
eval_step=args.eval_freq,
|
||||
tb_writer=neural_factory.tb_writer,
|
||||
)
|
||||
|
||||
callbacks.append(eval_callback)
|
||||
return callbacks
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
# name = construct_name(
|
||||
# args.exp_name,
|
||||
# args.lr,
|
||||
# args.batch_size,
|
||||
# args.max_steps,
|
||||
# args.num_epochs,
|
||||
# args.weight_decay,
|
||||
# args.optimizer,
|
||||
# args.iter_per_step,
|
||||
# )
|
||||
# log_dir = name
|
||||
# if args.work_dir:
|
||||
# log_dir = os.path.join(args.work_dir, name)
|
||||
|
||||
# instantiate Neural Factory with supported backend
|
||||
neural_factory = nemo.core.NeuralModuleFactory(
|
||||
placement=nemo.core.DeviceType.GPU,
|
||||
backend=nemo.core.Backend.PyTorch,
|
||||
# local_rank=args.local_rank,
|
||||
# optimization_level=args.amp_opt_level,
|
||||
# log_dir=log_dir,
|
||||
# checkpoint_dir=args.checkpoint_dir,
|
||||
# create_tb_writer=args.create_tb_writer,
|
||||
# files_to_copy=[args.model_config, __file__],
|
||||
# cudnn_benchmark=args.cudnn_benchmark,
|
||||
# tensorboard_dir=args.tensorboard_dir,
|
||||
)
|
||||
args.num_gpus = neural_factory.world_size
|
||||
|
||||
# checkpoint_dir = neural_factory.checkpoint_dir
|
||||
if args.local_rank is not None:
|
||||
logging.info("Doing ALL GPU")
|
||||
|
||||
# build dags
|
||||
callbacks = create_all_dags(args, neural_factory)
|
||||
# evaluate model
|
||||
neural_factory.eval(callbacks=callbacks)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
1
jasper/training/__init__.py
Normal file
1
jasper/training/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
366
jasper/training/cli.py
Normal file
366
jasper/training/cli.py
Normal file
@@ -0,0 +1,366 @@
|
||||
# Copyright (c) 2019 NVIDIA Corporation
|
||||
import argparse
|
||||
import copy
|
||||
import math
|
||||
import os
|
||||
from pathlib import Path
|
||||
from functools import partial
|
||||
|
||||
from ruamel.yaml import YAML
|
||||
|
||||
import nemo
|
||||
import nemo.collections.asr as nemo_asr
|
||||
import nemo.utils.argparse as nm_argparse
|
||||
from nemo.collections.asr.helpers import (
|
||||
monitor_asr_train_progress,
|
||||
process_evaluation_batch,
|
||||
process_evaluation_epoch,
|
||||
)
|
||||
|
||||
from nemo.utils.lr_policies import CosineAnnealing
|
||||
from .data_loaders import RpycAudioToTextDataLayer
|
||||
|
||||
logging = nemo.logging
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
parents=[nm_argparse.NemoArgParser()],
|
||||
description="Jasper",
|
||||
conflict_handler="resolve",
|
||||
)
|
||||
parser.set_defaults(
|
||||
checkpoint_dir=None,
|
||||
optimizer="novograd",
|
||||
batch_size=64,
|
||||
eval_batch_size=64,
|
||||
lr=0.002,
|
||||
amp_opt_level="O1",
|
||||
create_tb_writer=True,
|
||||
model_config="./train/jasper10x5dr.yaml",
|
||||
work_dir="./train/work",
|
||||
num_epochs=300,
|
||||
weight_decay=0.005,
|
||||
checkpoint_save_freq=100,
|
||||
eval_freq=100,
|
||||
load_dir="./train/models/jasper/",
|
||||
warmup_steps=3,
|
||||
exp_name="jasper-speller",
|
||||
)
|
||||
|
||||
# Overwrite default args
|
||||
parser.add_argument(
|
||||
"--max_steps",
|
||||
type=int,
|
||||
default=None,
|
||||
required=False,
|
||||
help="max number of steps to train",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_epochs",
|
||||
type=int,
|
||||
required=False,
|
||||
help="number of epochs to train",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_config",
|
||||
type=str,
|
||||
required=False,
|
||||
help="model configuration file: model.yaml",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--remote_data",
|
||||
type=str,
|
||||
required=False,
|
||||
default="",
|
||||
help="remote dataloader endpoint",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
type=str,
|
||||
required=False,
|
||||
default="",
|
||||
help="dataset directory containing train/test manifests",
|
||||
)
|
||||
|
||||
# Create new args
|
||||
parser.add_argument("--exp_name", default="Jasper", type=str)
|
||||
parser.add_argument("--beta1", default=0.95, type=float)
|
||||
parser.add_argument("--beta2", default=0.25, type=float)
|
||||
parser.add_argument("--warmup_steps", default=0, type=int)
|
||||
parser.add_argument(
|
||||
"--load_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
help="directory with pre-trained checkpoint",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
if args.max_steps is None and args.num_epochs is None:
|
||||
raise ValueError("Either max_steps or num_epochs should be provided.")
|
||||
return args
|
||||
|
||||
|
||||
def construct_name(
|
||||
name, lr, batch_size, max_steps, num_epochs, wd, optimizer, iter_per_step
|
||||
):
|
||||
if max_steps is not None:
|
||||
return "{0}-lr_{1}-bs_{2}-s_{3}-wd_{4}-opt_{5}-ips_{6}".format(
|
||||
name, lr, batch_size, max_steps, wd, optimizer, iter_per_step
|
||||
)
|
||||
else:
|
||||
return "{0}-lr_{1}-bs_{2}-e_{3}-wd_{4}-opt_{5}-ips_{6}".format(
|
||||
name, lr, batch_size, num_epochs, wd, optimizer, iter_per_step
|
||||
)
|
||||
|
||||
|
||||
def create_all_dags(args, neural_factory):
|
||||
yaml = YAML(typ="safe")
|
||||
with open(args.model_config) as f:
|
||||
jasper_params = yaml.load(f)
|
||||
vocab = jasper_params["labels"]
|
||||
sample_rate = jasper_params["sample_rate"]
|
||||
|
||||
# Calculate num_workers for dataloader
|
||||
total_cpus = os.cpu_count()
|
||||
cpu_per_traindl = max(int(total_cpus / neural_factory.world_size), 1)
|
||||
# perturb_config = jasper_params.get('perturb', None)
|
||||
train_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"])
|
||||
train_dl_params.update(jasper_params["AudioToTextDataLayer"]["train"])
|
||||
del train_dl_params["train"]
|
||||
del train_dl_params["eval"]
|
||||
# del train_dl_params["normalize_transcripts"]
|
||||
|
||||
if args.dataset:
|
||||
d_path = Path(args.dataset)
|
||||
if not args.train_dataset:
|
||||
args.train_dataset = str(d_path / Path("train_manifest.json"))
|
||||
if not args.eval_datasets:
|
||||
args.eval_datasets = [str(d_path / Path("test_manifest.json"))]
|
||||
|
||||
data_loader_layer = nemo_asr.AudioToTextDataLayer
|
||||
|
||||
if args.remote_data:
|
||||
train_dl_params["rpyc_host"] = args.remote_data
|
||||
data_loader_layer = RpycAudioToTextDataLayer
|
||||
|
||||
data_layer = data_loader_layer(
|
||||
manifest_filepath=args.train_dataset,
|
||||
sample_rate=sample_rate,
|
||||
labels=vocab,
|
||||
batch_size=args.batch_size,
|
||||
num_workers=cpu_per_traindl,
|
||||
**train_dl_params,
|
||||
# normalize_transcripts=False
|
||||
)
|
||||
|
||||
N = len(data_layer)
|
||||
steps_per_epoch = math.ceil(
|
||||
N / (args.batch_size * args.iter_per_step * args.num_gpus)
|
||||
)
|
||||
logging.info("Have {0} examples to train on.".format(N))
|
||||
|
||||
data_preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor(
|
||||
sample_rate=sample_rate, **jasper_params["AudioToMelSpectrogramPreprocessor"]
|
||||
)
|
||||
|
||||
multiply_batch_config = jasper_params.get("MultiplyBatch", None)
|
||||
if multiply_batch_config:
|
||||
multiply_batch = nemo_asr.MultiplyBatch(**multiply_batch_config)
|
||||
|
||||
spectr_augment_config = jasper_params.get("SpectrogramAugmentation", None)
|
||||
if spectr_augment_config:
|
||||
data_spectr_augmentation = nemo_asr.SpectrogramAugmentation(
|
||||
**spectr_augment_config
|
||||
)
|
||||
|
||||
eval_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"])
|
||||
eval_dl_params.update(jasper_params["AudioToTextDataLayer"]["eval"])
|
||||
if args.remote_data:
|
||||
eval_dl_params["rpyc_host"] = args.remote_data
|
||||
del eval_dl_params["train"]
|
||||
del eval_dl_params["eval"]
|
||||
data_layers_eval = []
|
||||
|
||||
if args.eval_datasets:
|
||||
for eval_datasets in args.eval_datasets:
|
||||
data_layer_eval = data_loader_layer(
|
||||
manifest_filepath=eval_datasets,
|
||||
sample_rate=sample_rate,
|
||||
labels=vocab,
|
||||
batch_size=args.eval_batch_size,
|
||||
num_workers=cpu_per_traindl,
|
||||
**eval_dl_params,
|
||||
)
|
||||
|
||||
data_layers_eval.append(data_layer_eval)
|
||||
else:
|
||||
logging.warning("There were no val datasets passed")
|
||||
|
||||
jasper_encoder = nemo_asr.JasperEncoder(
|
||||
feat_in=jasper_params["AudioToMelSpectrogramPreprocessor"]["features"],
|
||||
**jasper_params["JasperEncoder"],
|
||||
)
|
||||
|
||||
jasper_decoder = nemo_asr.JasperDecoderForCTC(
|
||||
feat_in=jasper_params["JasperEncoder"]["jasper"][-1]["filters"],
|
||||
num_classes=len(vocab),
|
||||
)
|
||||
|
||||
ctc_loss = nemo_asr.CTCLossNM(num_classes=len(vocab))
|
||||
|
||||
greedy_decoder = nemo_asr.GreedyCTCDecoder()
|
||||
|
||||
logging.info("================================")
|
||||
logging.info(f"Number of parameters in encoder: {jasper_encoder.num_weights}")
|
||||
logging.info(f"Number of parameters in decoder: {jasper_decoder.num_weights}")
|
||||
logging.info(
|
||||
f"Total number of parameters in model: "
|
||||
f"{jasper_decoder.num_weights + jasper_encoder.num_weights}"
|
||||
)
|
||||
logging.info("================================")
|
||||
|
||||
# Train DAG
|
||||
(audio_signal_t, a_sig_length_t, transcript_t, transcript_len_t) = data_layer()
|
||||
processed_signal_t, p_length_t = data_preprocessor(
|
||||
input_signal=audio_signal_t, length=a_sig_length_t
|
||||
)
|
||||
|
||||
if multiply_batch_config:
|
||||
(
|
||||
processed_signal_t,
|
||||
p_length_t,
|
||||
transcript_t,
|
||||
transcript_len_t,
|
||||
) = multiply_batch(
|
||||
in_x=processed_signal_t,
|
||||
in_x_len=p_length_t,
|
||||
in_y=transcript_t,
|
||||
in_y_len=transcript_len_t,
|
||||
)
|
||||
|
||||
if spectr_augment_config:
|
||||
processed_signal_t = data_spectr_augmentation(input_spec=processed_signal_t)
|
||||
|
||||
encoded_t, encoded_len_t = jasper_encoder(
|
||||
audio_signal=processed_signal_t, length=p_length_t
|
||||
)
|
||||
log_probs_t = jasper_decoder(encoder_output=encoded_t)
|
||||
predictions_t = greedy_decoder(log_probs=log_probs_t)
|
||||
loss_t = ctc_loss(
|
||||
log_probs=log_probs_t,
|
||||
targets=transcript_t,
|
||||
input_length=encoded_len_t,
|
||||
target_length=transcript_len_t,
|
||||
)
|
||||
|
||||
# Callbacks needed to print info to console and Tensorboard
|
||||
train_callback = nemo.core.SimpleLossLoggerCallback(
|
||||
tensors=[loss_t, predictions_t, transcript_t, transcript_len_t],
|
||||
print_func=partial(monitor_asr_train_progress, labels=vocab),
|
||||
get_tb_values=lambda x: [("loss", x[0])],
|
||||
tb_writer=neural_factory.tb_writer,
|
||||
)
|
||||
|
||||
chpt_callback = nemo.core.CheckpointCallback(
|
||||
folder=neural_factory.checkpoint_dir,
|
||||
load_from_folder=args.load_dir,
|
||||
step_freq=args.checkpoint_save_freq,
|
||||
checkpoints_to_keep=30,
|
||||
)
|
||||
|
||||
callbacks = [train_callback, chpt_callback]
|
||||
|
||||
# assemble eval DAGs
|
||||
for i, eval_dl in enumerate(data_layers_eval):
|
||||
(audio_signal_e, a_sig_length_e, transcript_e, transcript_len_e) = eval_dl()
|
||||
processed_signal_e, p_length_e = data_preprocessor(
|
||||
input_signal=audio_signal_e, length=a_sig_length_e
|
||||
)
|
||||
encoded_e, encoded_len_e = jasper_encoder(
|
||||
audio_signal=processed_signal_e, length=p_length_e
|
||||
)
|
||||
log_probs_e = jasper_decoder(encoder_output=encoded_e)
|
||||
predictions_e = greedy_decoder(log_probs=log_probs_e)
|
||||
loss_e = ctc_loss(
|
||||
log_probs=log_probs_e,
|
||||
targets=transcript_e,
|
||||
input_length=encoded_len_e,
|
||||
target_length=transcript_len_e,
|
||||
)
|
||||
|
||||
# create corresponding eval callback
|
||||
tagname = os.path.basename(args.eval_datasets[i]).split(".")[0]
|
||||
eval_callback = nemo.core.EvaluatorCallback(
|
||||
eval_tensors=[loss_e, predictions_e, transcript_e, transcript_len_e],
|
||||
user_iter_callback=partial(process_evaluation_batch, labels=vocab),
|
||||
user_epochs_done_callback=partial(process_evaluation_epoch, tag=tagname),
|
||||
eval_step=args.eval_freq,
|
||||
tb_writer=neural_factory.tb_writer,
|
||||
)
|
||||
|
||||
callbacks.append(eval_callback)
|
||||
return loss_t, callbacks, steps_per_epoch
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
name = construct_name(
|
||||
args.exp_name,
|
||||
args.lr,
|
||||
args.batch_size,
|
||||
args.max_steps,
|
||||
args.num_epochs,
|
||||
args.weight_decay,
|
||||
args.optimizer,
|
||||
args.iter_per_step,
|
||||
)
|
||||
log_dir = name
|
||||
if args.work_dir:
|
||||
log_dir = os.path.join(args.work_dir, name)
|
||||
|
||||
# instantiate Neural Factory with supported backend
|
||||
neural_factory = nemo.core.NeuralModuleFactory(
|
||||
backend=nemo.core.Backend.PyTorch,
|
||||
local_rank=args.local_rank,
|
||||
optimization_level=args.amp_opt_level,
|
||||
log_dir=log_dir,
|
||||
checkpoint_dir=args.checkpoint_dir,
|
||||
create_tb_writer=args.create_tb_writer,
|
||||
files_to_copy=[args.model_config, __file__],
|
||||
cudnn_benchmark=args.cudnn_benchmark,
|
||||
tensorboard_dir=args.tensorboard_dir,
|
||||
)
|
||||
args.num_gpus = neural_factory.world_size
|
||||
|
||||
checkpoint_dir = neural_factory.checkpoint_dir
|
||||
if args.local_rank is not None:
|
||||
logging.info("Doing ALL GPU")
|
||||
|
||||
# build dags
|
||||
train_loss, callbacks, steps_per_epoch = create_all_dags(args, neural_factory)
|
||||
# train model
|
||||
neural_factory.train(
|
||||
tensors_to_optimize=[train_loss],
|
||||
callbacks=callbacks,
|
||||
lr_policy=CosineAnnealing(
|
||||
args.max_steps
|
||||
if args.max_steps is not None
|
||||
else args.num_epochs * steps_per_epoch,
|
||||
warmup_steps=args.warmup_steps,
|
||||
),
|
||||
optimizer=args.optimizer,
|
||||
optimization_params={
|
||||
"num_epochs": args.num_epochs,
|
||||
"max_steps": args.max_steps,
|
||||
"lr": args.lr,
|
||||
"betas": (args.beta1, args.beta2),
|
||||
"weight_decay": args.weight_decay,
|
||||
"grad_norm_clip": None,
|
||||
},
|
||||
batches_per_step=args.iter_per_step,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
334
jasper/training/data_loaders.py
Normal file
334
jasper/training/data_loaders.py
Normal file
@@ -0,0 +1,334 @@
|
||||
from functools import partial
|
||||
import tempfile
|
||||
|
||||
# from typing import Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
import nemo
|
||||
|
||||
# import nemo.collections.asr as nemo_asr
|
||||
from nemo.backends.pytorch import DataLayerNM
|
||||
from nemo.core import DeviceType
|
||||
|
||||
# from nemo.core.neural_types import *
|
||||
from nemo.core.neural_types import NeuralType, AudioSignal, LengthsType, LabelsType
|
||||
from nemo.utils.decorators import add_port_docs
|
||||
|
||||
from nemo.collections.asr.parts.dataset import (
|
||||
# AudioDataset,
|
||||
# AudioLabelDataset,
|
||||
# KaldiFeatureDataset,
|
||||
# TranscriptDataset,
|
||||
parsers,
|
||||
collections,
|
||||
seq_collate_fn,
|
||||
)
|
||||
|
||||
# from functools import lru_cache
|
||||
import rpyc
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from tqdm import tqdm
|
||||
from .featurizer import RpycWaveformFeaturizer
|
||||
|
||||
# from nemo.collections.asr.parts.features import WaveformFeaturizer
|
||||
|
||||
# from nemo.collections.asr.parts.perturb import AudioAugmentor, perturbation_types
|
||||
|
||||
|
||||
logging = nemo.logging
|
||||
|
||||
|
||||
class CachedAudioDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
Dataset that loads tensors via a json file containing paths to audio
|
||||
files, transcripts, and durations (in seconds). Each new line is a
|
||||
different sample. Example below:
|
||||
|
||||
{"audio_filepath": "/path/to/audio.wav", "text_filepath":
|
||||
"/path/to/audio.txt", "duration": 23.147}
|
||||
...
|
||||
{"audio_filepath": "/path/to/audio.wav", "text": "the
|
||||
transcription", offset": 301.75, "duration": 0.82, "utt":
|
||||
"utterance_id", "ctm_utt": "en_4156", "side": "A"}
|
||||
|
||||
Args:
|
||||
manifest_filepath: Path to manifest json as described above. Can
|
||||
be comma-separated paths.
|
||||
labels: String containing all the possible characters to map to
|
||||
featurizer: Initialized featurizer class that converts paths of
|
||||
audio to feature tensors
|
||||
max_duration: If audio exceeds this length, do not include in dataset
|
||||
min_duration: If audio is less than this length, do not include
|
||||
in dataset
|
||||
max_utts: Limit number of utterances
|
||||
blank_index: blank character index, default = -1
|
||||
unk_index: unk_character index, default = -1
|
||||
normalize: whether to normalize transcript text (default): True
|
||||
bos_id: Id of beginning of sequence symbol to append if not None
|
||||
eos_id: Id of end of sequence symbol to append if not None
|
||||
load_audio: Boolean flag indicate whether do or not load audio
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
manifest_filepath,
|
||||
labels,
|
||||
featurizer,
|
||||
max_duration=None,
|
||||
min_duration=None,
|
||||
max_utts=0,
|
||||
blank_index=-1,
|
||||
unk_index=-1,
|
||||
normalize=True,
|
||||
trim=False,
|
||||
bos_id=None,
|
||||
eos_id=None,
|
||||
load_audio=True,
|
||||
parser="en",
|
||||
):
|
||||
self.collection = collections.ASRAudioText(
|
||||
manifests_files=manifest_filepath.split(","),
|
||||
parser=parsers.make_parser(
|
||||
labels=labels,
|
||||
name=parser,
|
||||
unk_id=unk_index,
|
||||
blank_id=blank_index,
|
||||
do_normalize=normalize,
|
||||
),
|
||||
min_duration=min_duration,
|
||||
max_duration=max_duration,
|
||||
max_number=max_utts,
|
||||
)
|
||||
self.index_feature_map = {}
|
||||
|
||||
self.featurizer = featurizer
|
||||
self.trim = trim
|
||||
self.eos_id = eos_id
|
||||
self.bos_id = bos_id
|
||||
self.load_audio = load_audio
|
||||
print(f"initializing dataset {manifest_filepath}")
|
||||
|
||||
def exec_func(i):
|
||||
return self[i]
|
||||
|
||||
task_count = len(self.collection)
|
||||
with ThreadPoolExecutor() as exe:
|
||||
print("starting all loading tasks")
|
||||
list(
|
||||
tqdm(
|
||||
exe.map(exec_func, range(task_count)),
|
||||
position=0,
|
||||
leave=True,
|
||||
total=task_count,
|
||||
)
|
||||
)
|
||||
print(f"initializing complete")
|
||||
|
||||
def __getitem__(self, index):
|
||||
sample = self.collection[index]
|
||||
if self.load_audio:
|
||||
cached_features = self.index_feature_map.get(index)
|
||||
if cached_features is not None:
|
||||
features = cached_features
|
||||
else:
|
||||
features = self.featurizer.process(
|
||||
sample.audio_file,
|
||||
offset=0,
|
||||
duration=sample.duration,
|
||||
trim=self.trim,
|
||||
)
|
||||
self.index_feature_map[index] = features
|
||||
f, fl = features, torch.tensor(features.shape[0]).long()
|
||||
else:
|
||||
f, fl = None, None
|
||||
|
||||
t, tl = sample.text_tokens, len(sample.text_tokens)
|
||||
if self.bos_id is not None:
|
||||
t = [self.bos_id] + t
|
||||
tl += 1
|
||||
if self.eos_id is not None:
|
||||
t = t + [self.eos_id]
|
||||
tl += 1
|
||||
|
||||
return f, fl, torch.tensor(t).long(), torch.tensor(tl).long()
|
||||
|
||||
def __len__(self):
|
||||
return len(self.collection)
|
||||
|
||||
|
||||
class RpycAudioToTextDataLayer(DataLayerNM):
|
||||
"""Data Layer for general ASR tasks.
|
||||
|
||||
Module which reads ASR labeled data. It accepts comma-separated
|
||||
JSON manifest files describing the correspondence between wav audio files
|
||||
and their transcripts. JSON files should be of the following format::
|
||||
|
||||
{"audio_filepath": path_to_wav_0, "duration": time_in_sec_0, "text": \
|
||||
transcript_0}
|
||||
...
|
||||
{"audio_filepath": path_to_wav_n, "duration": time_in_sec_n, "text": \
|
||||
transcript_n}
|
||||
|
||||
Args:
|
||||
manifest_filepath (str): Dataset parameter.
|
||||
Path to JSON containing data.
|
||||
labels (list): Dataset parameter.
|
||||
List of characters that can be output by the ASR model.
|
||||
For Jasper, this is the 28 character set {a-z '}. The CTC blank
|
||||
symbol is automatically added later for models using ctc.
|
||||
batch_size (int): batch size
|
||||
sample_rate (int): Target sampling rate for data. Audio files will be
|
||||
resampled to sample_rate if it is not already.
|
||||
Defaults to 16000.
|
||||
int_values (bool): Bool indicating whether the audio file is saved as
|
||||
int data or float data.
|
||||
Defaults to False.
|
||||
eos_id (id): Dataset parameter.
|
||||
End of string symbol id used for seq2seq models.
|
||||
Defaults to None.
|
||||
min_duration (float): Dataset parameter.
|
||||
All training files which have a duration less than min_duration
|
||||
are dropped. Note: Duration is read from the manifest JSON.
|
||||
Defaults to 0.1.
|
||||
max_duration (float): Dataset parameter.
|
||||
All training files which have a duration more than max_duration
|
||||
are dropped. Note: Duration is read from the manifest JSON.
|
||||
Defaults to None.
|
||||
normalize_transcripts (bool): Dataset parameter.
|
||||
Whether to use automatic text cleaning.
|
||||
It is highly recommended to manually clean text for best results.
|
||||
Defaults to True.
|
||||
trim_silence (bool): Whether to use trim silence from beginning and end
|
||||
of audio signal using librosa.effects.trim().
|
||||
Defaults to False.
|
||||
load_audio (bool): Dataset parameter.
|
||||
Controls whether the dataloader loads the audio signal and
|
||||
transcript or just the transcript.
|
||||
Defaults to True.
|
||||
drop_last (bool): See PyTorch DataLoader.
|
||||
Defaults to False.
|
||||
shuffle (bool): See PyTorch DataLoader.
|
||||
Defaults to True.
|
||||
num_workers (int): See PyTorch DataLoader.
|
||||
Defaults to 0.
|
||||
perturb_config (dict): Currently disabled.
|
||||
"""
|
||||
|
||||
@property
|
||||
@add_port_docs()
|
||||
def output_ports(self):
|
||||
"""Returns definitions of module output ports.
|
||||
"""
|
||||
return {
|
||||
# 'audio_signal': NeuralType({0: AxisType(BatchTag), 1: AxisType(TimeTag)}),
|
||||
# 'a_sig_length': NeuralType({0: AxisType(BatchTag)}),
|
||||
# 'transcripts': NeuralType({0: AxisType(BatchTag), 1: AxisType(TimeTag)}),
|
||||
# 'transcript_length': NeuralType({0: AxisType(BatchTag)}),
|
||||
"audio_signal": NeuralType(
|
||||
("B", "T"),
|
||||
AudioSignal(freq=self._sample_rate)
|
||||
if self is not None and self._sample_rate is not None
|
||||
else AudioSignal(),
|
||||
),
|
||||
"a_sig_length": NeuralType(tuple("B"), LengthsType()),
|
||||
"transcripts": NeuralType(("B", "T"), LabelsType()),
|
||||
"transcript_length": NeuralType(tuple("B"), LengthsType()),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
manifest_filepath,
|
||||
labels,
|
||||
batch_size,
|
||||
sample_rate=16000,
|
||||
int_values=False,
|
||||
bos_id=None,
|
||||
eos_id=None,
|
||||
pad_id=None,
|
||||
min_duration=0.1,
|
||||
max_duration=None,
|
||||
normalize_transcripts=True,
|
||||
trim_silence=False,
|
||||
load_audio=True,
|
||||
rpyc_host="",
|
||||
drop_last=False,
|
||||
shuffle=True,
|
||||
num_workers=0,
|
||||
):
|
||||
super().__init__()
|
||||
self._sample_rate = sample_rate
|
||||
|
||||
def rpyc_root_fn():
|
||||
return rpyc.connect(
|
||||
rpyc_host, 8064, config={"sync_request_timeout": 600}
|
||||
).root
|
||||
|
||||
rpyc_conn = rpyc_root_fn()
|
||||
|
||||
self._featurizer = RpycWaveformFeaturizer(
|
||||
sample_rate=self._sample_rate,
|
||||
int_values=int_values,
|
||||
augmentor=None,
|
||||
rpyc_conn=rpyc_conn,
|
||||
)
|
||||
|
||||
def read_remote_manifests():
|
||||
local_mp = []
|
||||
for mrp in manifest_filepath.split(","):
|
||||
md = rpyc_conn.read_path(mrp)
|
||||
mf = tempfile.NamedTemporaryFile(
|
||||
dir="/tmp", prefix="jasper_manifest.", delete=False
|
||||
)
|
||||
mf.write(md)
|
||||
mf.close()
|
||||
local_mp.append(mf.name)
|
||||
return ",".join(local_mp)
|
||||
|
||||
local_manifest_filepath = read_remote_manifests()
|
||||
dataset_params = {
|
||||
"manifest_filepath": local_manifest_filepath,
|
||||
"labels": labels,
|
||||
"featurizer": self._featurizer,
|
||||
"max_duration": max_duration,
|
||||
"min_duration": min_duration,
|
||||
"normalize": normalize_transcripts,
|
||||
"trim": trim_silence,
|
||||
"bos_id": bos_id,
|
||||
"eos_id": eos_id,
|
||||
"load_audio": load_audio,
|
||||
}
|
||||
|
||||
self._dataset = CachedAudioDataset(**dataset_params)
|
||||
self._batch_size = batch_size
|
||||
|
||||
# Set up data loader
|
||||
if self._placement == DeviceType.AllGpu:
|
||||
logging.info("Parallelizing Datalayer.")
|
||||
sampler = torch.utils.data.distributed.DistributedSampler(self._dataset)
|
||||
else:
|
||||
sampler = None
|
||||
|
||||
if batch_size == -1:
|
||||
batch_size = len(self._dataset)
|
||||
|
||||
pad_id = 0 if pad_id is None else pad_id
|
||||
self._dataloader = torch.utils.data.DataLoader(
|
||||
dataset=self._dataset,
|
||||
batch_size=batch_size,
|
||||
collate_fn=partial(seq_collate_fn, token_pad_value=pad_id),
|
||||
drop_last=drop_last,
|
||||
shuffle=shuffle if sampler is None else False,
|
||||
sampler=sampler,
|
||||
num_workers=1,
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
return len(self._dataset)
|
||||
|
||||
@property
|
||||
def dataset(self):
|
||||
return None
|
||||
|
||||
@property
|
||||
def data_iterator(self):
|
||||
return self._dataloader
|
||||
51
jasper/training/featurizer.py
Normal file
51
jasper/training/featurizer.py
Normal file
@@ -0,0 +1,51 @@
|
||||
# import math
|
||||
|
||||
# import librosa
|
||||
import torch
|
||||
import pickle
|
||||
# import torch.nn as nn
|
||||
# from torch_stft import STFT
|
||||
|
||||
# from nemo import logging
|
||||
from nemo.collections.asr.parts.perturb import AudioAugmentor
|
||||
# from nemo.collections.asr.parts.segment import AudioSegment
|
||||
|
||||
|
||||
class RpycWaveformFeaturizer(object):
|
||||
def __init__(
|
||||
self, sample_rate=16000, int_values=False, augmentor=None, rpyc_conn=None
|
||||
):
|
||||
self.augmentor = augmentor if augmentor is not None else AudioAugmentor()
|
||||
self.sample_rate = sample_rate
|
||||
self.int_values = int_values
|
||||
self.remote_path_samples = rpyc_conn.get_path_samples
|
||||
|
||||
def max_augmentation_length(self, length):
|
||||
return self.augmentor.max_augmentation_length(length)
|
||||
|
||||
def process(self, file_path, offset=0, duration=0, trim=False):
|
||||
audio = self.remote_path_samples(
|
||||
file_path,
|
||||
target_sr=self.sample_rate,
|
||||
int_values=self.int_values,
|
||||
offset=offset,
|
||||
duration=duration,
|
||||
trim=trim,
|
||||
)
|
||||
return torch.tensor(pickle.loads(audio), dtype=torch.float)
|
||||
|
||||
def process_segment(self, audio_segment):
|
||||
self.augmentor.perturb(audio_segment)
|
||||
return torch.tensor(audio_segment, dtype=torch.float)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, input_config, perturbation_configs=None):
|
||||
if perturbation_configs is not None:
|
||||
aa = AudioAugmentor.from_config(perturbation_configs)
|
||||
else:
|
||||
aa = None
|
||||
|
||||
sample_rate = input_config.get("sample_rate", 16000)
|
||||
int_values = input_config.get("int_values", False)
|
||||
|
||||
return cls(sample_rate=sample_rate, int_values=int_values, augmentor=aa)
|
||||
61
setup.py
61
setup.py
@@ -1,11 +1,54 @@
|
||||
from setuptools import setup
|
||||
from setuptools import setup, find_packages
|
||||
|
||||
requirements = [
|
||||
"ruamel.yaml",
|
||||
"torch==2.8.0",
|
||||
"torchvision==0.5.0",
|
||||
"nemo_toolkit[asr] @ git+https://github.com/NVIDIA/NeMo.git@09e3ba4dfe333f86d6c5c1048e07210924294be9#egg=nemo_toolkit",
|
||||
]
|
||||
|
||||
extra_requirements = {"server": ["rpyc==4.1.4"]}
|
||||
extra_requirements = {
|
||||
"server": ["rpyc~=4.1.4", "tqdm~=4.39.0"],
|
||||
"data": [
|
||||
"google-cloud-texttospeech~=1.0.1",
|
||||
"tqdm~=4.39.0",
|
||||
"pydub~=0.24.0",
|
||||
"scikit_learn~=0.22.1",
|
||||
"pandas~=1.0.3",
|
||||
"boto3~=1.12.35",
|
||||
"ruamel.yaml==0.16.10",
|
||||
"pymongo==3.10.1",
|
||||
"librosa==0.7.2",
|
||||
"numba==0.48",
|
||||
"matplotlib==3.2.1",
|
||||
"pandas==1.0.3",
|
||||
"tabulate==0.8.7",
|
||||
"natural==0.2.0",
|
||||
"num2words==0.5.10",
|
||||
"typer[all]==0.3.1",
|
||||
"python-slugify==4.0.0",
|
||||
"rpyc~=4.1.4",
|
||||
"lenses @ git+https://github.com/ingolemo/python-lenses.git@b2a2a9aa5b61540992d70b2cf36008d0121e8948#egg=lenses",
|
||||
],
|
||||
"validation": [
|
||||
"rpyc~=4.1.4",
|
||||
"pymongo==3.10.1",
|
||||
"typer[all]==0.1.1",
|
||||
"tqdm~=4.39.0",
|
||||
"librosa==0.7.2",
|
||||
"matplotlib==3.2.1",
|
||||
"pydub~=0.24.0",
|
||||
"streamlit==0.58.0",
|
||||
"natural==0.2.0",
|
||||
"stringcase==1.2.0",
|
||||
"google-cloud-speech~=1.3.1",
|
||||
]
|
||||
# "train": [
|
||||
# "torchaudio==0.5.0",
|
||||
# "torch-stft==0.1.4",
|
||||
# ]
|
||||
}
|
||||
packages = find_packages()
|
||||
|
||||
setup(
|
||||
name="jasper-asr",
|
||||
@@ -17,11 +60,21 @@ setup(
|
||||
license="MIT",
|
||||
install_requires=requirements,
|
||||
extras_require=extra_requirements,
|
||||
packages=["."],
|
||||
packages=packages,
|
||||
entry_points={
|
||||
"console_scripts": [
|
||||
"jasper_transcribe = jasper.transcribe:main",
|
||||
"jasper_asr_rpyc_server = jasper.server:main",
|
||||
"jasper_server = jasper.server:main",
|
||||
"jasper_trainer = jasper.training.cli:main",
|
||||
"jasper_evaluator = jasper.evaluate:main",
|
||||
"jasper_data_tts_generate = jasper.data.tts_generator:main",
|
||||
"jasper_data_conv_generate = jasper.data.conv_generator:main",
|
||||
"jasper_data_nlu_generate = jasper.data.nlu_generator:main",
|
||||
"jasper_data_rastrik_recycle = jasper.data.rastrik_recycler:main",
|
||||
"jasper_data_server = jasper.data.server:main",
|
||||
"jasper_data_validation = jasper.data.validation.process:main",
|
||||
"jasper_data_preprocess = jasper.data.process:main",
|
||||
"jasper_data_slu_evaluate = jasper.data.slu_evaluator:main",
|
||||
]
|
||||
},
|
||||
zip_safe=False,
|
||||
|
||||
3
validation_ui.py
Normal file
3
validation_ui.py
Normal file
@@ -0,0 +1,3 @@
|
||||
import runpy
|
||||
|
||||
runpy.run_module("jasper.data.validation.ui", run_name="__main__", alter_sys=True)
|
||||
Reference in New Issue
Block a user