1
0
mirror of https://github.com/malarinv/jasper-asr.git synced 2026-03-09 19:02:35 +00:00

1. added training utils with custom data loaders with remote rpyc dataservice support

2. fix validation correction dump path
3. precache the dataset in memory before training
4. update dependencies
This commit is contained in:
2020-05-14 15:39:44 +05:30
parent d4aef4088d
commit 83db445a6f
7 changed files with 419 additions and 13 deletions

View File

@@ -1,24 +1,50 @@
import os
# from pathlib import Path
from pathlib import Path
import typer
import rpyc
from rpyc.utils.server import ThreadedServer
import nemo.collections.asr as nemo_asr
import nemo
import pickle
# import nemo.collections.asr as nemo_asr
from nemo.collections.asr.parts.segment import AudioSegment
# Typer application object; CLI commands are registered on it with
# the @app.command() decorator.
app = typer.Typer()

# Instantiate NeMo's module factory on the PyTorch backend with CPU placement.
# The returned object is deliberately discarded.
# NOTE(review): presumably NeMo needs a factory to exist before its audio
# helpers are used -- confirm against the NeMo version in use.
nemo.core.NeuralModuleFactory(
    placement=nemo.core.DeviceType.CPU,
    backend=nemo.core.Backend.PyTorch,
)
class ASRDataService(rpyc.Service):
    """rpyc service that reads audio data on the server for remote callers.

    The ``exposed_``-prefixed methods follow rpyc's naming convention for
    methods that connected clients are allowed to call.
    """

    def get_data_loader(self):
        """Return the NeMo ``AudioToTextDataLayer`` class (not an instance)."""
        return nemo_asr.AudioToTextDataLayer

    def exposed_get_path_samples(
        self, file_path, target_sr, int_values, offset, duration, trim
    ):
        """Load the audio at ``file_path`` and return its samples, pickled.

        Every argument after ``file_path`` is forwarded unchanged, by
        keyword, to ``AudioSegment.from_file``.

        Returns:
            bytes: ``pickle.dumps`` of the loaded segment's ``samples``.
        """
        print(f"loading.. {file_path}")
        load_options = {
            "target_sr": target_sr,
            "int_values": int_values,
            "offset": offset,
            "duration": duration,
            "trim": trim,
        }
        segment = AudioSegment.from_file(file_path, **load_options)
        # The samples are serialized here, so the caller receives plain bytes.
        # NOTE(review): the client presumably unpickles them -- confirm.
        return pickle.dumps(segment.samples)

    def exposed_read_path(self, file_path):
        """Return the raw bytes of the file at ``file_path``."""
        target = Path(file_path)
        return target.read_bytes()
@app.command()
def run_server(port: int = 0):
    """Start the ASR data rpyc server.

    Args:
        port: TCP port to listen on. When 0 (the default), the port is read
            from the ``ASR_DARA_RPYC_PORT`` environment variable, falling
            back to 8064 if that variable is unset.
    """
    # NOTE(review): "DARA" looks like a typo for "DATA". The name is kept
    # as-is because clients or deployments may already set this variable --
    # confirm before renaming.
    listen_port = port if port else int(os.environ.get("ASR_DARA_RPYC_PORT", "8064"))
    service = ASRDataService()
    # Build the server exactly once. A superseded, earlier construction
    # (without protocol_config) was dead code and left an unused server
    # object behind.
    # NOTE(review): allow_all_attrs relaxes rpyc's attribute-access
    # restrictions for remote peers; presumably this server is only meant
    # for trusted networks -- confirm.
    t = ThreadedServer(
        service, port=listen_port, protocol_config={"allow_all_attrs": True}
    )
    typer.echo(f"starting asr server on {listen_port}...")
    t.start()

View File

@@ -113,7 +113,7 @@ def dump_validation_ui_data(
@app.command()
def dump_corrections(dump_path: Path = Path("./data/corrections.json")):
def dump_corrections(dump_path: Path = Path("./data/valiation_data/corrections.json")):
col = get_mongo_conn().test.asr_validation
cursor_obj = col.find({"type": "correction"}, projection={"_id": False})