1
0
mirror of https://github.com/malarinv/jasper-asr.git synced 2026-03-08 02:22:34 +00:00

Compare commits

..

6 Commits

Author SHA1 Message Date
e30dd724f5 Merge pull request #3 from malarinv/dependabot/pip/torch-2.8.0
Bump torch from 1.4.0 to 2.8.0
2025-08-30 18:24:04 +05:30
dependabot[bot]
02df1b5282 Bump torch from 1.4.0 to 2.8.0
Bumps [torch](https://github.com/pytorch/pytorch) from 1.4.0 to 2.8.0.
- [Release notes](https://github.com/pytorch/pytorch/releases)
- [Changelog](https://github.com/pytorch/pytorch/blob/main/RELEASE.md)
- [Commits](https://github.com/pytorch/pytorch/compare/v1.4.0...v2.8.0)

---
updated-dependencies:
- dependency-name: torch
  dependency-version: 2.8.0
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-08-30 12:53:51 +00:00
e8f58a5043 1. refactored ui_dump
2. added flake8
2020-08-09 19:16:35 +05:30
42647196fe 1. fixed dependency issues
2. add task-id option to validation ui to respawn previous task
3. clean-up rastrik-recycler
2020-08-06 22:40:14 +05:30
e77943b2f2 Merge pull request #1 from wrat/master
adding support for asr data generator
2020-08-06 00:11:53 +05:30
wabi_sabi004
14d31a51c3 adding support for asr data generator 2020-08-06 00:08:46 +05:30
7 changed files with 263 additions and 188 deletions

4
.flake8 Normal file
View File

@@ -0,0 +1,4 @@
[flake8]
exclude = docs
ignore = E203, W503
max-line-length = 119

View File

@@ -7,10 +7,16 @@
# Table of Contents
* [Prerequisites](#prerequisites)
* [Features](#features)
* [Installation](#installation)
* [Usage](#usage)
# Prerequisites
```bash
# apt install libsndfile-dev ffmpeg
```
# Features
* ASR using Jasper (from [NemoToolkit](https://github.com/NVIDIA/NeMo) )

View File

@@ -0,0 +1,93 @@
from rastrik.proto.callrecord_pb2 import CallRecord
import gzip
from pydub import AudioSegment
from .utils import ui_dump_manifest_writer, strip_silence
import typer
from itertools import chain
from io import BytesIO
from pathlib import Path
# Typer application object; ``extract_manifest`` is registered as a sub-command.
app = typer.Typer()


@app.command()
def extract_manifest(
    call_log_dir: Path = Path("./data/call_audio"),
    output_dir: Path = Path("./data"),
    dataset_name: str = "grassroot_pizzahut_v1",
    caller_name: str = "grassroot",
    verbose: bool = False,
):
    """Build an ASR dataset from recorded calls and their event logs.

    Scans *call_log_dir* recursively for ``.wav`` recordings, pairs each
    with its gzipped ``CallRecord`` protobuf sidecar (``.pb2.gz``), slices
    one audio channel into per-utterance segments using the ASR-final
    speech-event timestamps, and hands the resulting stream to
    ``ui_dump_manifest_writer`` under ``output_dir / "asr_data"``.

    Args:
        call_log_dir: directory searched recursively for call ``.wav`` files.
        output_dir: root under which ``asr_data/<dataset_name>`` is written.
        dataset_name: name of the generated dataset directory.
        caller_name: caller label attached to every generated data point.
        verbose: echo each file as its events are loaded.
    """
    call_asr_data: Path = output_dir / Path("asr_data")
    call_asr_data.mkdir(exist_ok=True, parents=True)

    def wav_pb2_generator(log_dir):
        # Yield (audio, wav_path, protobuf_path) triples; the ``.pb2.gz``
        # sidecar is assumed to sit next to each ``.wav`` file.
        for wav_path in log_dir.glob("**/*.wav"):
            if verbose:
                typer.echo(f"loading events for file {wav_path}")
            call_wav = AudioSegment.from_file_using_temporary_files(wav_path)
            meta_path = wav_path.with_suffix(".pb2.gz")
            yield call_wav, wav_path, meta_path

    def read_event(call_wav, log_file):
        # Generator: slice one call into
        # (transcript, duration_s, wav_bytes, caller, segment) tuples,
        # one per ASR-final speech event.
        # Only channel 1 is sliced below; channel 0 is unused here —
        # presumably the other party of the call. TODO confirm mapping.
        _unused_channel, speech_channel = call_wav.split_to_mono()
        with gzip.open(log_file, "rb") as log_h:
            record_data = log_h.read()
        cr = CallRecord()
        cr.ParseFromString(record_data)
        # Timestamp of the first audio event; all speech-event timestamps
        # are made relative to it. If the record has no audio event, the
        # bare next() raises StopIteration (surfaced as RuntimeError inside
        # this generator under PEP 479).
        first_audio_event_timestamp = next(
            i
            for i in cr.events
            if i.WhichOneof("event_type") == "call_event"
            and i.call_event.WhichOneof("event_type") == "call_audio"
        ).timestamp.ToDatetime()
        speech_events = [
            i
            for i in cr.events
            if i.WhichOneof("event_type") == "speech_event"
            and i.speech_event.WhichOneof("event_type") == "asr_final"
        ]
        # Zero timedelta: each utterance starts where the previous one ended.
        previous_event_timestamp = (
            first_audio_event_timestamp - first_audio_event_timestamp
        )
        for speech_event in speech_events:
            asr_final = speech_event.speech_event.asr_final
            speech_timestamp = speech_event.timestamp.ToDatetime()
            actual_timestamp = speech_timestamp - first_audio_event_timestamp
            # pydub slices in milliseconds.
            start_time = previous_event_timestamp.total_seconds() * 1000
            end_time = actual_timestamp.total_seconds() * 1000
            audio_segment = strip_silence(speech_channel[start_time:end_time])
            code_fb = BytesIO()
            audio_segment.export(code_fb, format="wav")
            wav_data = code_fb.getvalue()
            previous_event_timestamp = actual_timestamp
            duration = (end_time - start_time) / 1000
            # BUG FIX: the caller label was hard-coded to "grassroot",
            # silently ignoring the ``caller_name`` CLI option. The default
            # value is unchanged, so default invocations behave identically.
            yield asr_final, duration, wav_data, caller_name, audio_segment

    def generate_call_asr_data():
        # Collect one lazy generator per call; they are consumed only when
        # ``ui_dump_manifest_writer`` iterates the chained stream below.
        full_data = []
        total_duration = 0
        for wav, _wav_path, pb2_path in wav_pb2_generator(call_log_dir):
            asr_data = read_event(wav, pb2_path)
            total_duration += wav.duration_seconds
            full_data.append(asr_data)
        n_calls = len(full_data)
        typer.echo(f"loaded {n_calls} calls of duration {total_duration}s")
        n_dps = ui_dump_manifest_writer(call_asr_data, dataset_name, chain(*full_data))
        typer.echo(f"written {n_dps} data points")

    generate_call_asr_data()
def main():
    """Console-script entry point: dispatch to the Typer CLI app."""
    app()


if __name__ == "__main__":
    main()

View File

@@ -58,23 +58,10 @@ def asr_data_writer(output_dir, dataset_name, asr_data_source, verbose=False):
return num_datapoints
def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=False):
def ui_data_generator(output_dir, dataset_name, asr_data_source, verbose=False):
dataset_dir = output_dir / Path(dataset_name)
(dataset_dir / Path("wav")).mkdir(parents=True, exist_ok=True)
ui_dump_file = dataset_dir / Path("ui_dump.json")
(dataset_dir / Path("wav_plots")).mkdir(parents=True, exist_ok=True)
asr_manifest = dataset_dir / Path("manifest.json")
num_datapoints = 0
ui_dump = {
"use_domain_asr": False,
"annotation_only": False,
"enable_plots": True,
"data": [],
}
data_funcs = []
transcriber_pretrained = transcribe_gen(asr_port=8044)
with asr_manifest.open("w") as mf:
print(f"writing manifest to {asr_manifest}")
def data_fn(
transcript,
@@ -89,9 +76,8 @@ def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=F
):
pretrained_result = transcriber_pretrained(aud_seg.raw_data)
pretrained_wer = word_error_rate([transcript], [pretrained_result])
wav_plot_path = (
dataset_dir / Path("wav_plots") / Path(fname).with_suffix(".png")
)
png_path = Path(fname).with_suffix(".png")
wav_plot_path = dataset_dir / Path("wav_plots") / png_path
if not wav_plot_path.exists():
plot_seg(wav_plot_path, audio_path)
return {
@@ -108,14 +94,15 @@ def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=F
"plot_path": str(wav_plot_path),
}
num_datapoints = 0
data_funcs = []
transcriber_pretrained = transcribe_gen(asr_port=8044)
for transcript, audio_dur, wav_data, caller_name, aud_seg in asr_data_source:
fname = str(uuid4()) + "_" + slugify(transcript, max_length=8)
audio_file = dataset_dir / Path("wav") / Path(fname).with_suffix(".wav")
audio_file.write_bytes(wav_data)
audio_path = str(audio_file)
rel_data_path = audio_file.relative_to(dataset_dir)
manifest = manifest_str(str(rel_data_path), audio_dur, transcript)
mf.write(manifest)
data_funcs.append(
partial(
data_fn,
@@ -131,10 +118,28 @@ def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=F
)
)
num_datapoints += 1
dump_data = parallel_apply(lambda x: x(), data_funcs)
# dump_data = [x() for x in tqdm(data_funcs)]
ui_dump["data"] = dump_data
ExtendedPath(ui_dump_file).write_json(ui_dump)
ui_data = parallel_apply(lambda x: x(), data_funcs)
return ui_data, num_datapoints
def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=False):
dataset_dir = output_dir / Path(dataset_name)
dump_data, num_datapoints = ui_data_generator(
output_dir, dataset_name, asr_data_source, verbose=verbose
)
asr_manifest = dataset_dir / Path("manifest.json")
with asr_manifest.open("w") as mf:
print(f"writing manifest to {asr_manifest}")
for d in dump_data:
rel_data_path = d["audio_filepath"]
audio_dur = d["duration"]
transcript = d["text"]
manifest = manifest_str(str(rel_data_path), audio_dur, transcript)
mf.write(manifest)
ui_dump_file = dataset_dir / Path("ui_dump.json")
ExtendedPath(ui_dump_file).write_json({"data": dump_data})
return num_datapoints

View File

@@ -1,13 +1,10 @@
import json
import shutil
from pathlib import Path
from enum import Enum
import typer
from tqdm import tqdm
from ..utils import (
alnum_to_asr_tokens,
ExtendedPath,
asr_manifest_reader,
asr_manifest_writer,
@@ -19,9 +16,7 @@ from ..utils import (
app = typer.Typer()
def preprocess_datapoint(
idx, rel_root, sample, use_domain_asr, annotation_only, enable_plots
):
def preprocess_datapoint(idx, rel_root, sample):
from pydub import AudioSegment
from nemo.collections.asr.metrics import word_error_rate
from jasper.client import transcribe_gen
@@ -31,12 +26,7 @@ def preprocess_datapoint(
res["real_idx"] = idx
audio_path = rel_root / Path(sample["audio_filepath"])
res["audio_path"] = str(audio_path)
if use_domain_asr:
res["spoken"] = alnum_to_asr_tokens(res["text"])
else:
res["spoken"] = res["text"]
res["utterance_id"] = audio_path.stem
if not annotation_only:
transcriber_pretrained = transcribe_gen(asr_port=8044)
aud_seg = (
@@ -46,16 +36,7 @@ def preprocess_datapoint(
.set_frame_rate(24000)
)
res["pretrained_asr"] = transcriber_pretrained(aud_seg.raw_data)
res["pretrained_wer"] = word_error_rate(
[res["text"]], [res["pretrained_asr"]]
)
if use_domain_asr:
transcriber_speller = transcribe_gen(asr_port=8045)
res["domain_asr"] = transcriber_speller(aud_seg.raw_data)
res["domain_wer"] = word_error_rate(
[res["spoken"]], [res["pretrained_asr"]]
)
if enable_plots:
res["pretrained_wer"] = word_error_rate([res["text"]], [res["pretrained_asr"]])
wav_plot_path = (
rel_root / Path("wav_plots") / Path(audio_path.name).with_suffix(".png")
)
@@ -73,61 +54,50 @@ def dump_ui(
dataset_dir: Path = Path("./data/asr_data"),
dump_dir: Path = Path("./data/valiation_data"),
dump_fname: Path = typer.Option(Path("ui_dump.json"), show_default=True),
use_domain_asr: bool = False,
annotation_only: bool = False,
enable_plots: bool = True,
):
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from io import BytesIO
from pydub import AudioSegment
from ..utils import ui_data_generator
data_manifest_path = dataset_dir / Path(data_name) / Path("manifest.json")
dump_path: Path = dump_dir / Path(data_name) / dump_fname
plot_dir = data_manifest_path.parent / Path("wav_plots")
plot_dir.mkdir(parents=True, exist_ok=True)
typer.echo(f"Using data manifest:{data_manifest_path}")
def asr_data_source_gen():
with data_manifest_path.open("r") as pf:
data_jsonl = pf.readlines()
data_funcs = [
partial(
preprocess_datapoint,
i,
data_manifest_path.parent,
json.loads(v),
use_domain_asr,
annotation_only,
enable_plots,
for v in data_jsonl:
sample = json.loads(v)
rel_root = data_manifest_path.parent
res = dict(sample)
audio_path = rel_root / Path(sample["audio_filepath"])
audio_segment = (
AudioSegment.from_file_using_temporary_files(audio_path)
.set_channels(1)
.set_sample_width(2)
.set_frame_rate(24000)
)
for i, v in enumerate(data_jsonl)
]
wav_plot_path = (
rel_root
/ Path("wav_plots")
/ Path(audio_path.name).with_suffix(".png")
)
if not wav_plot_path.exists():
plot_seg(wav_plot_path, audio_path)
res["plot_path"] = str(wav_plot_path)
code_fb = BytesIO()
audio_segment.export(code_fb, format="wav")
wav_data = code_fb.getvalue()
duration = audio_segment.duration_seconds
asr_final = res["text"]
yield asr_final, duration, wav_data, "caller", audio_segment
def exec_func(f):
return f()
with ThreadPoolExecutor() as exe:
print("starting all preprocess tasks")
data_final = filter(
None,
list(
tqdm(
exe.map(exec_func, data_funcs),
position=0,
leave=True,
total=len(data_funcs),
dump_data, num_datapoints = ui_data_generator(
dataset_dir, data_name, asr_data_source_gen()
)
),
)
if annotation_only:
result = list(data_final)
else:
wer_key = "domain_wer" if use_domain_asr else "pretrained_wer"
result = sorted(data_final, key=lambda x: x[wer_key], reverse=True)
ui_config = {
"use_domain_asr": use_domain_asr,
"annotation_only": annotation_only,
"enable_plots": enable_plots,
"data": result,
}
ExtendedPath(dump_path).write_json(ui_config)
ui_dump_file = dataset_dir / Path("ui_dump.json")
ExtendedPath(ui_dump_file).write_json({"data": dump_data})
@app.command()
@@ -190,7 +160,9 @@ def dump_corrections(
col = get_mongo_conn(col="asr_validation")
task_id = [c for c in col.distinct("task_id") if c.rsplit("-", 1)[1] == task_uid][0]
corrections = list(col.find({"type": "correction"}, projection={"_id": False}))
cursor_obj = col.find({"type": "correction", "task_id": task_id}, projection={"_id": False})
cursor_obj = col.find(
{"type": "correction", "task_id": task_id}, projection={"_id": False}
)
corrections = [c for c in cursor_obj]
ExtendedPath(dump_path).write_json(corrections)
@@ -264,7 +236,9 @@ def split_extract(
dump_file: Path = Path("ui_dump.json"),
manifest_file: Path = Path("manifest.json"),
corrections_file: str = typer.Option("corrections.json", show_default=True),
conv_data_path: Path = typer.Option(Path("./data/conv_data.json"), show_default=True),
conv_data_path: Path = typer.Option(
Path("./data/conv_data.json"), show_default=True
),
extraction_type: str = "all",
):
import shutil
@@ -286,7 +260,9 @@ def split_extract(
def extract_manifest(mg):
for m in mg:
if m["text"] in extraction_vals:
shutil.copy(m["audio_path"], dest_data_dir / Path(m["audio_filepath"]))
shutil.copy(
m["audio_path"], dest_data_dir / Path(m["audio_filepath"])
)
yield m
asr_manifest_writer(dest_manifest_path, extract_manifest(manifest_gen))
@@ -295,12 +271,14 @@ def split_extract(
orig_ui_data = ExtendedPath(ui_data_path).read_json()
ui_data = orig_ui_data["data"]
file_ui_map = {Path(u["audio_filepath"]).stem: u for u in ui_data}
extracted_ui_data = list(filter(lambda u: u["text"] in extraction_vals, ui_data))
extracted_ui_data = list(
filter(lambda u: u["text"] in extraction_vals, ui_data)
)
final_data = []
for i, d in enumerate(extracted_ui_data):
d['real_idx'] = i
d["real_idx"] = i
final_data.append(d)
orig_ui_data['data'] = final_data
orig_ui_data["data"] = final_data
ExtendedPath(dest_ui_path).write_json(orig_ui_data)
if corrections_file:
@@ -316,7 +294,7 @@ def split_extract(
)
ExtendedPath(dest_correction_path).write_json(extracted_corrections)
if extraction_type.value == 'all':
if extraction_type.value == "all":
for ext_key in conv_data.keys():
extract_data_of_type(ext_key)
else:
@@ -338,7 +316,7 @@ def update_corrections(
def correct_manifest(ui_dump_path, corrections_path):
corrections = ExtendedPath(corrections_path).read_json()
ui_data = ExtendedPath(ui_dump_path).read_json()['data']
ui_data = ExtendedPath(ui_dump_path).read_json()["data"]
correct_set = {
c["code"] for c in corrections if c["value"]["status"] == "Correct"
}
@@ -367,7 +345,9 @@ def update_corrections(
)
else:
orig_audio_path = Path(d["audio_path"])
new_name = str(Path(tscript_uuid_fname(correct_text)).with_suffix(".wav"))
new_name = str(
Path(tscript_uuid_fname(correct_text)).with_suffix(".wav")
)
new_audio_path = orig_audio_path.with_name(new_name)
orig_audio_path.replace(new_audio_path)
new_filepath = str(Path(d["audio_filepath"]).with_name(new_name))

View File

@@ -42,7 +42,9 @@ if not hasattr(st, "mongo_connected"):
upsert=True,
)
def set_task_fn(mf_path):
def set_task_fn(mf_path, task_id):
if task_id:
st.task_id = task_id
task_path = mf_path.parent / Path(f"task-{st.task_id}.lck")
if not task_path.exists():
print(f"creating task lock at {task_path}")
@@ -66,26 +68,22 @@ def load_ui_data(validation_ui_data_path: Path):
@app.command()
def main(manifest: Path):
st.set_task(manifest)
def main(manifest: Path, task_id: str = ""):
st.set_task(manifest, task_id)
ui_config = load_ui_data(manifest)
asr_data = ui_config["data"]
use_domain_asr = ui_config.get("use_domain_asr", True)
annotation_only = ui_config.get("annotation_only", False)
enable_plots = ui_config.get("enable_plots", True)
sample_no = st.get_current_cursor()
if len(asr_data) - 1 < sample_no or sample_no < 0:
print("Invalid samplno resetting to 0")
st.update_cursor(0)
sample = asr_data[sample_no]
title_type = "Speller " if use_domain_asr else ""
task_uid = st.task_id.rsplit("-", 1)[1]
if annotation_only:
st.title(f"ASR Annotation - # {task_uid}")
else:
st.title(f"ASR {title_type}Validation - # {task_uid}")
addl_text = f"spelled *{sample['spoken']}*" if use_domain_asr else ""
st.markdown(f"{sample_no+1} of {len(asr_data)} : **{sample['text']}**" + addl_text)
st.title(f"ASR Validation - # {task_uid}")
st.markdown(f"{sample_no+1} of {len(asr_data)} : **{sample['text']}**")
new_sample = st.number_input(
"Go To Sample:", value=sample_no + 1, min_value=1, max_value=len(asr_data)
)
@@ -94,18 +92,12 @@ def main(manifest: Path):
st.sidebar.title(f"Details: [{sample['real_idx']}]")
st.sidebar.markdown(f"Gold Text: **{sample['text']}**")
if not annotation_only:
if use_domain_asr:
st.sidebar.markdown(f"Expected Spelled: *{sample['spoken']}*")
st.sidebar.title("Results:")
st.sidebar.markdown(f"Pretrained: **{sample['pretrained_asr']}**")
if "caller" in sample:
st.sidebar.markdown(f"Caller: **{sample['caller']}**")
if use_domain_asr:
st.sidebar.markdown(f"Domain: **{sample['domain_asr']}**")
st.sidebar.title(f"Speller WER: {sample['domain_wer']:.2f}%")
else:
st.sidebar.title(f"Pretrained WER: {sample['pretrained_wer']:.2f}%")
if enable_plots:
st.sidebar.image(Path(sample["plot_path"]).read_bytes())
st.audio(Path(sample["audio_path"]).open("rb"))
# set default to text
@@ -128,16 +120,12 @@ def main(manifest: Path):
)
st.update_cursor(sample_no + 1)
if correction_entry:
st.markdown(
f'Your Response: **{correction_entry["value"]["status"]}** Correction: **{correction_entry["value"]["correction"]}**'
)
status = correction_entry["value"]["status"]
correction = correction_entry["value"]["correction"]
st.markdown(f"Your Response: **{status}** Correction: **{correction}**")
text_sample = st.text_input("Go to Text:", value="")
if text_sample != "":
candidates = [
i
for (i, p) in enumerate(asr_data)
if p["text"] == text_sample or p["spoken"] == text_sample
]
candidates = [i for (i, p) in enumerate(asr_data) if p["text"] == text_sample]
if len(candidates) > 0:
st.update_cursor(candidates[0])
real_idx = st.number_input(

View File

@@ -2,7 +2,7 @@ from setuptools import setup, find_packages
requirements = [
"ruamel.yaml",
"torch==1.4.0",
"torch==2.8.0",
"torchvision==0.5.0",
"nemo_toolkit[asr] @ git+https://github.com/NVIDIA/NeMo.git@09e3ba4dfe333f86d6c5c1048e07210924294be9#egg=nemo_toolkit",
]
@@ -19,13 +19,15 @@ extra_requirements = {
"ruamel.yaml==0.16.10",
"pymongo==3.10.1",
"librosa==0.7.2",
"numba==0.48",
"matplotlib==3.2.1",
"pandas==1.0.3",
"tabulate==0.8.7",
"natural==0.2.0",
"num2words==0.5.10",
"typer[all]==0.1.1",
"typer[all]==0.3.1",
"python-slugify==4.0.0",
"rpyc~=4.1.4",
"lenses @ git+https://github.com/ingolemo/python-lenses.git@b2a2a9aa5b61540992d70b2cf36008d0121e8948#egg=lenses",
],
"validation": [
@@ -68,10 +70,7 @@ setup(
"jasper_data_tts_generate = jasper.data.tts_generator:main",
"jasper_data_conv_generate = jasper.data.conv_generator:main",
"jasper_data_nlu_generate = jasper.data.nlu_generator:main",
"jasper_data_test_generate = jasper.data.test_generator:main",
"jasper_data_call_recycle = jasper.data.call_recycler:main",
"jasper_data_asr_recycle = jasper.data.asr_recycler:main",
"jasper_data_rev_recycle = jasper.data.rev_recycler:main",
"jasper_data_rastrik_recycle = jasper.data.rastrik_recycler:main",
"jasper_data_server = jasper.data.server:main",
"jasper_data_validation = jasper.data.validation.process:main",
"jasper_data_preprocess = jasper.data.process:main",