mirror of
https://github.com/malarinv/jasper-asr.git
synced 2026-03-08 02:22:34 +00:00
Compare commits
42 Commits
master
...
f5c49338d9
| Author | SHA1 | Date | |
|---|---|---|---|
| f5c49338d9 | |||
| fa89775f86 | |||
| ae5586be72 | |||
| 069392d098 | |||
| 515e9c1037 | |||
| e76ccda5dd | |||
| 000853b600 | |||
| ac0e04c226 | |||
| 62eefb9294 | |||
| 8e238c254e | |||
| 7dbb04dcbf | |||
| 7472b6457d | |||
| 120302aad3 | |||
| a7a25e9b07 | |||
| 6d149d282d | |||
| 8db1be0083 | |||
| bca227a7d7 | |||
| e3a01169c2 | |||
| 3a5ce069ab | |||
| 9f9cb62b60 | |||
| de21952349 | |||
| d87369c8fe | |||
| 41af0a87de | |||
| 6f395af10d | |||
| a38789d0c3 | |||
| 7ff2db3e2e | |||
| 1acf9e403c | |||
| 1f2bedc156 | |||
| fca9c1aeb3 | |||
| 2d5b720284 | |||
| 8e79bbb571 | |||
| 83db445a6f | |||
| d4aef4088d | |||
| fdccea6b23 | |||
| c06a0814b9 | |||
| a7da729c0b | |||
| aae03a6ae4 | |||
| 4fd05a56d0 | |||
| 41074a1bca | |||
| 61048f855e | |||
| 2c15b00da3 | |||
| d22a99a4f6 |
4
.flake8
4
.flake8
@@ -1,4 +0,0 @@
|
|||||||
[flake8]
|
|
||||||
exclude = docs
|
|
||||||
ignore = E203, W503
|
|
||||||
max-line-length = 119
|
|
||||||
@@ -7,16 +7,10 @@
|
|||||||
|
|
||||||
# Table of Contents
|
# Table of Contents
|
||||||
|
|
||||||
* [Prerequisites](#prerequisites)
|
|
||||||
* [Features](#features)
|
* [Features](#features)
|
||||||
* [Installation](#installation)
|
* [Installation](#installation)
|
||||||
* [Usage](#usage)
|
* [Usage](#usage)
|
||||||
|
|
||||||
# Prerequisites
|
|
||||||
```bash
|
|
||||||
# apt install libsndfile-dev ffmpeg
|
|
||||||
```
|
|
||||||
|
|
||||||
# Features
|
# Features
|
||||||
|
|
||||||
* ASR using Jasper (from [NemoToolkit](https://github.com/NVIDIA/NeMo))
|
* ASR using Jasper (from [NemoToolkit](https://github.com/NVIDIA/NeMo))
|
||||||
|
|||||||
@@ -2,6 +2,10 @@ import os
|
|||||||
import logging
|
import logging
|
||||||
import rpyc
|
import rpyc
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
|
import typer
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
app = typer.Typer()
|
||||||
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||||
@@ -19,3 +23,28 @@ def transcribe_gen(asr_host=ASR_HOST, asr_port=ASR_PORT):
|
|||||||
asr = rpyc.connect(asr_host, asr_port).root
|
asr = rpyc.connect(asr_host, asr_port).root
|
||||||
logger.info(f"connected to asr server successfully")
|
logger.info(f"connected to asr server successfully")
|
||||||
return asr.transcribe
|
return asr.transcribe
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
def transcribe_file(audio_file: Path):
    """Transcribe one audio file and save the text next to it as ``.txt``."""
    # pydub is imported lazily so the other entry points do not pay for it.
    from pydub import AudioSegment

    transcriber = transcribe_gen()
    # Normalise to mono, 16-bit samples at 24 kHz before sending the raw
    # PCM bytes to the transcriber.
    aud_seg = (
        AudioSegment.from_file_using_temporary_files(audio_file)
        .set_channels(1)
        .set_sample_width(2)
        .set_frame_rate(24000)
    )
    transcription = transcriber(aud_seg.raw_data)
    tscript_file_path = audio_file.with_suffix(".txt")
    # Explicit encoding: transcripts must not depend on the platform locale.
    tscript_file_path.write_text(transcription, encoding="utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Run the Typer command-line application."""
    app()


if __name__ == "__main__":
    main()
|
||||||
|
|||||||
104
jasper/data/asr_recycler.py
Normal file
104
jasper/data/asr_recycler.py
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
import typer
|
||||||
|
from itertools import chain
|
||||||
|
from io import BytesIO
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
app = typer.Typer()
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
def extract_data(
    call_audio_dir: Path = Path("/dataset/png_prod/call_audio"),
    call_meta_dir: Path = Path("/dataset/png_prod/call_metadata"),
    output_dir: Path = Path("./data"),
    dataset_name: str = "png_gcp_2jan",
    verbose: bool = False,
):
    """Cut call recordings into per-transcript audio chunks and write them out.

    Every ``*.wav`` below ``call_audio_dir`` is paired with the ``.json`` file
    at the same relative path below ``call_meta_dir``; the ASR events of
    channel 0 are turned into (transcript, duration, wav bytes) records and
    handed to ``asr_data_writer``.
    """
    from pydub import AudioSegment
    from .utils import ExtendedPath, asr_data_writer, strip_silence
    from lenses import lens

    call_asr_data: Path = output_dir / Path("asr_data")
    call_asr_data.mkdir(exist_ok=True, parents=True)

    def wav_event_generator(call_audio_dir):
        """Yield (audio, wav path, events json) for every recording found."""
        for wav_path in call_audio_dir.glob("**/*.wav"):
            if verbose:
                typer.echo(f"loading events for file {wav_path}")
            call_wav = AudioSegment.from_file_using_temporary_files(wav_path)
            rel_meta_path = wav_path.with_suffix(".json").relative_to(call_audio_dir)
            events = ExtendedPath(call_meta_dir / rel_meta_path).read_json()
            yield call_wav, wav_path, events

    def contains_asr(x):
        return "AsrResult" in x

    def channel(n):
        """Build a predicate selecting ASR events that belong to channel *n*."""

        def filter_func(ev):
            asr_result = ev["AsrResult"]
            # Events without an explicit channel are treated as channel 0.
            if "Channel" in asr_result:
                return asr_result["Channel"] == n
            return n == 0

        return filter_func

    def start_of(ev):
        """Start time of an ASR event; 0 when the event carries none."""
        return ev["AsrResult"]["Alternatives"][0].get("StartTime", 0)

    def compute_endtime(call_wav, state):
        """Yield (transcript, duration, wav bytes) for each usable ASR event.

        An event ends where the next one starts; the last one ends with the
        recording.
        """
        for i, st in enumerate(state):
            start_time = start_of(st)
            # only starting 1 min audio has reliable alignment ignore rest
            # (checked before any slicing/export so no work is thrown away)
            if start_time > 60:
                if verbose:
                    print("start time over 60 seconds of audio skipping.")
                break
            transcript = st["AsrResult"]["Alternatives"][0]["Transcript"]
            if i + 1 < len(state):
                # Read the same way as start_time so a missing key cannot
                # crash here while being tolerated one iteration later.
                end_time = start_of(state[i + 1])
            else:
                end_time = call_wav.duration_seconds
            code_seg = strip_silence(call_wav[start_time * 1000 : end_time * 1000])
            # only if some reasonable audio data is present yield it
            if code_seg.duration_seconds < 0.5:
                if verbose:
                    print(f'transcript chunk "{transcript}" contains no audio skipping.')
                continue
            code_fb = BytesIO()
            code_seg.export(code_fb, format="wav")
            yield transcript, code_seg.duration_seconds, code_fb.getvalue()

    def asr_data_generator(call_wav, call_wav_fname, events):
        """Return the data-point generator for the first channel of a call."""
        # Only the first channel is used (agent channel events are ignored).
        # Indexing instead of two-name unpacking also accepts mono recordings.
        call_wav_0 = call_wav.split_to_mono()[0]
        asr_events = lens["Events"].Each()["Event"].Filter(contains_asr)
        call_evs_0 = asr_events.Filter(channel(0)).collect()(events)
        if verbose:
            typer.echo(f"processing data points on {call_wav_fname}")
        return compute_endtime(call_wav_0, call_evs_0)

    def generate_call_asr_data():
        """Collect the generators of all calls and write the dataset."""
        full_asr_data = []
        total_duration = 0
        for wav, wav_path, ev in wav_event_generator(call_audio_dir):
            full_asr_data.append(asr_data_generator(wav, wav_path, ev))
            total_duration += wav.duration_seconds
        typer.echo(f"loaded {len(full_asr_data)} calls of duration {total_duration}s")
        n_dps = asr_data_writer(
            call_asr_data, dataset_name, chain.from_iterable(full_asr_data)
        )
        typer.echo(f"written {n_dps} data points")

    generate_call_asr_data()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Run the Typer command-line application."""
    app()


if __name__ == "__main__":
    main()
|
||||||
509
jasper/data/call_recycler.py
Normal file
509
jasper/data/call_recycler.py
Normal file
@@ -0,0 +1,509 @@
|
|||||||
|
import typer
|
||||||
|
from pathlib import Path
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
app = typer.Typer()
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
def export_all_logs(
    call_logs_file: Path = typer.Option(Path("./call_logs.yaml"), show_default=True),
    domain: str = typer.Option("sia-data.agaralabs.com", show_default=True),
):
    """Export every call in the Mongo collection, grouped by caller, to YAML."""
    from .utils import get_mongo_conn
    from collections import defaultdict
    from ruamel.yaml import YAML

    yaml = YAML()
    mongo_coll = get_mongo_conn()

    # caller -> list of call URIs, in the order the collection returns them
    caller_calls = defaultdict(list)
    for call in mongo_coll.find():
        sysid = call["SystemID"]
        caller_calls[call["Caller"]].append(f"http://{domain}/calls/{sysid}")

    output_yaml = {
        "users": [
            {"name": caller, "calls": calls} for caller, calls in caller_calls.items()
        ]
    }
    typer.echo(f"exporting call logs to yaml file at {call_logs_file}")
    with call_logs_file.open("w") as yf:
        yaml.dump(output_yaml, yf)
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
def export_calls_between(
    start_cid: str,
    end_cid: str,
    call_logs_file: Path = typer.Option(Path("./call_logs.yaml"), show_default=True),
    domain: str = typer.Option("sia-data.agaralabs.com", show_default=True),
    mongo_port: int = 27017,
):
    """Export the calls lying between two reference calls, grouped by caller.

    The window runs from the ``StartTS`` of call ``start_cid`` to the
    ``EndTS`` of call ``end_cid`` (both inclusive).
    """
    from collections import defaultdict
    from ruamel.yaml import YAML
    from .utils import get_mongo_conn

    yaml = YAML()
    mongo_coll = get_mongo_conn(port=mongo_port)
    start_meta = mongo_coll.find_one({"SystemID": start_cid})
    end_meta = mongo_coll.find_one({"SystemID": end_cid})
    # find_one yields None for an unknown id; fail with a readable message
    # instead of a TypeError when the timestamps are looked up below.
    for cid, meta in ((start_cid, start_meta), (end_cid, end_meta)):
        if meta is None:
            typer.echo(f"no call found with SystemID {cid}", err=True)
            raise typer.Exit(code=1)

    call_query = mongo_coll.find(
        {
            "StartTS": {"$gte": start_meta["StartTS"]},
            "EndTS": {"$lte": end_meta["EndTS"]},
        }
    )
    # caller -> list of call URIs
    caller_calls = defaultdict(list)
    for call in call_query:
        sysid = call["SystemID"]
        caller_calls[call["Caller"]].append(f"http://{domain}/calls/{sysid}")

    output_yaml = {
        "users": [
            {"name": caller, "calls": calls} for caller, calls in caller_calls.items()
        ]
    }
    typer.echo(f"exporting call logs to yaml file at {call_logs_file}")
    with call_logs_file.open("w") as yf:
        yaml.dump(output_yaml, yf)
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
def copy_metas(
    call_logs_file: Path = typer.Option(Path("./call_logs.yaml"), show_default=True),
    output_dir: Path = Path("./data"),
    meta_dir: Path = Path("/tmp/call_metas"),
):
    """Copy the saved metadata JSON of every logged call into ``meta_dir``."""
    from lenses import lens
    from ruamel.yaml import YAML
    from urllib.parse import urlsplit
    from shutil import copy2

    yaml = YAML()
    call_logs = yaml.load(call_logs_file.read_text())

    call_meta_dir: Path = output_dir / Path("call_metas")
    call_meta_dir.mkdir(exist_ok=True, parents=True)
    meta_dir.mkdir(exist_ok=True, parents=True)

    def get_cid(uri):
        """Call id = last path component of the call URI, without suffix."""
        return Path(urlsplit(uri).path).stem

    def copy_meta(uri):
        cid = get_cid(uri)
        saved_meta_path = call_meta_dir / Path(f"{cid}.json")
        dest_meta_path = meta_dir / Path(f"{cid}.json")
        if not saved_meta_path.exists():
            # Report and skip: copying a missing file would abort the whole
            # run with FileNotFoundError right after this message.
            print(f"{saved_meta_path} not found")
            return
        copy2(saved_meta_path, dest_meta_path)

    # Visit every call URI of every user listed in the log file.
    for call_uri in lens["users"].Each()["calls"].Each().collect()(call_logs):
        copy_meta(call_uri)
|
||||||
|
|
||||||
|
|
||||||
|
class ExtractionType(str, Enum):
    """Extraction modes selectable on the ``analyze`` command line.

    Members are also ``str`` instances, so they compare equal to their
    plain string value.
    """

    flow = "flow"
    data = "data"
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
def analyze(
    leaderboard: bool = False,
    plot_calls: bool = False,
    extract_data: bool = False,
    extraction_type: ExtractionType = typer.Option(
        ExtractionType.data, show_default=True
    ),
    start_delay: float = 1.5,
    download_only: bool = False,
    strip_silent_chunks: bool = True,
    call_logs_file: Path = typer.Option(Path("./call_logs.yaml"), show_default=True),
    output_dir: Path = Path("./data"),
    data_name: str = None,
    mongo_uri: str = typer.Option(
        "mongodb://localhost:27017/test.calls", show_default=True
    ),
):
    """Download, inspect and export the calls listed in a call-log YAML file.

    Depending on the flags this prints a per-caller leaderboard, renders
    waveform plots with the extracted utterance spans, and/or writes the
    utterance audio chunks through ``ui_dump_manifest_writer``.
    """
    from urllib.parse import urlsplit
    from functools import reduce
    import boto3
    from io import BytesIO
    import json
    from ruamel.yaml import YAML
    import re
    from google.protobuf.timestamp_pb2 import Timestamp
    from datetime import timedelta
    import librosa
    import librosa.display
    from lenses import lens
    from pprint import pprint
    import pandas as pd
    import matplotlib.pyplot as plt
    import matplotlib
    from tqdm import tqdm
    from .utils import ui_dump_manifest_writer, strip_silence, get_mongo_coll, get_call_logs
    from pydub import AudioSegment
    from natural.date import compress

    matplotlib.rcParams["agg.path.chunksize"] = 10000
    matplotlib.use("agg")

    yaml = YAML()
    s3 = boto3.client("s3")
    mongo_collection = get_mongo_coll(mongo_uri)
    call_media_dir: Path = output_dir / Path("call_wavs")
    call_media_dir.mkdir(exist_ok=True, parents=True)
    call_meta_dir: Path = output_dir / Path("call_metas")
    call_meta_dir.mkdir(exist_ok=True, parents=True)
    call_plot_dir: Path = output_dir / Path("plots")
    call_plot_dir.mkdir(exist_ok=True, parents=True)
    call_asr_data: Path = output_dir / Path("asr_data")
    call_asr_data.mkdir(exist_ok=True, parents=True)
    dataset_name = call_logs_file.stem if not data_name else data_name

    call_logs = yaml.load(call_logs_file.read_text())

    def gen_ev_fev_timedelta(fev):
        """Build a function giving an event's offset from reference event *fev*.

        Offsets are computed from the ``CreatedTS`` fields and clamped so
        they are never negative.
        """
        fev_p = Timestamp()
        fev_p.FromJsonString(fev["CreatedTS"])
        fev_dt = fev_p.ToDatetime()
        td_0 = timedelta()

        def get_timedelta(ev):
            ev_p = Timestamp()
            ev_p.FromJsonString(value=ev["CreatedTS"])
            delta = ev_p.ToDatetime() - fev_dt
            return delta if delta > td_0 else td_0

        return get_timedelta

    def chunk_n(evs, n):
        """Split *evs* into consecutive slices of at most *n* items."""
        return [evs[i * n : (i + 1) * n] for i in range((len(evs) + n - 1) // n)]

    if extraction_type == ExtractionType.data:
        # One data point = CONV_RESULT, STARTED_SPEAKING, STOPPED_SPEAKING.
        expected_types = ["CONV_RESULT", "STARTED_SPEAKING", "STOPPED_SPEAKING"]
        utter_chunk_size = len(expected_types)

        def is_utter_event(ev):
            return (
                (ev["Author"] == "CONV" or ev["Author"] == "ASR")
                and (ev["Type"] != "DEBUG")
                and ev["Type"] != "ASR_RESULT"
            )

        def get_data_points(utter_events, td_fn):
            """Turn well-formed event triples into start/end/code records."""
            data_points = []
            for evs in chunk_n(utter_events, utter_chunk_size):
                # Explicit comparison instead of assert, so the validation
                # survives ``python -O``; malformed chunks are skipped.
                if [ev["Type"] for ev in evs] != expected_types:
                    continue
                data_points.append(
                    {
                        "start_time": td_fn(evs[1]).total_seconds() - start_delay,
                        "end_time": td_fn(evs[2]).total_seconds(),
                        "code": evs[0]["Msg"],
                    }
                )
            return data_points

        def text_extractor(spoken):
            """Return the single-quoted part of longer messages, else as is."""
            quoted = re.search(r"'(.*)'", spoken) if len(spoken) > 6 else None
            return quoted.group(1) if quoted else spoken

    elif extraction_type == ExtractionType.flow:
        # One data point = CONV_RESULT, STARTED_SPEAKING, final ASR_RESULT,
        # STOPPED_SPEAKING.
        expected_types = [
            "CONV_RESULT",
            "STARTED_SPEAKING",
            "ASR_RESULT",
            "STOPPED_SPEAKING",
        ]
        utter_chunk_size = len(expected_types)

        def is_final_asr_event_or_spoken(ev):
            # Only ASR_RESULT events need their payload inspected.
            if ev["Type"] != "ASR_RESULT":
                return True
            pld = json.loads(ev["Payload"])
            return pld["AsrResult"]["Results"][0]["IsFinal"]

        def is_utter_event(ev):
            return (
                ev["Author"] == "CONV"
                or (ev["Author"] == "ASR" and is_final_asr_event_or_spoken(ev))
            ) and (ev["Type"] != "DEBUG")

        def get_data_points(utter_events, td_fn):
            """Turn well-formed event quadruples into start/end/code records.

            Only chunks whose prompt mentions "full name" produce a record;
            its text is the first alternative of the ASR result.
            """
            data_points = []
            for evs in chunk_n(utter_events, utter_chunk_size):
                # Explicit comparison instead of assert, so the validation
                # survives ``python -O``; malformed chunks are skipped.
                if [ev["Type"] for ev in evs] != expected_types:
                    continue
                start_time = td_fn(evs[1]).total_seconds() - start_delay
                end_time = td_fn(evs[2]).total_seconds()
                conv_msg = evs[0]["Msg"]
                if "full name" in conv_msg.lower():
                    pld = json.loads(evs[2]["Payload"])
                    spoken = pld["AsrResult"]["Results"][0]["Alternatives"][0][
                        "Transcript"
                    ]
                    data_points.append(
                        {
                            "start_time": start_time,
                            "end_time": end_time,
                            "code": spoken,
                        }
                    )
            return data_points

        def text_extractor(spoken):
            return spoken

    def find_call_meta(cid):
        """Fetch the Mongo record of call *cid*; fail clearly when absent."""
        meta = mongo_collection.find_one({"SystemID": cid})
        if meta is None:
            raise LookupError(f"no call record found for SystemID {cid}")
        return meta

    def process_call(call_obj):
        """Fetch the events and the recording of one call.

        Returns the local wav path, the utterance events truncated to whole
        chunks, and the function mapping an event to its offset from the
        first audio event.
        """
        call_meta = get_call_logs(call_obj, s3, call_meta_dir)
        call_events = call_meta["Events"]

        def is_writer_uri_event(ev):
            return ev["Author"] == "AUDIO_WRITER" and "s3://" in ev["Msg"]

        writer_event = next(filter(is_writer_uri_event, call_events), None)
        if writer_event is None:
            raise LookupError("call has no AUDIO_WRITER event carrying an s3:// url")
        s3_wav_url = re.search(r"(s3://.*)", writer_event["Msg"]).group(1)
        s3_wav_url_p = urlsplit(s3_wav_url)

        def is_first_audio_ev(state, ev):
            # Once a GATEWAY/AUDIO event has been seen the state is frozen;
            # until then the state carries the most recent event.
            if state[0]:
                return state
            return (ev["Author"] == "GATEWAY" and ev["Type"] == "AUDIO", ev)

        (_, first_audio_ev) = reduce(is_first_audio_ev, call_events, (False, {}))
        get_ev_fev_timedelta = gen_ev_fev_timedelta(first_audio_ev)

        uevs = list(filter(is_utter_event, call_events))
        ev_count = len(uevs)
        # Keep whole chunks only; the size follows the extraction mode
        # (3 events in data mode, 4 in flow mode).
        utter_events = uevs[: ev_count - ev_count % utter_chunk_size]
        saved_wav_path = call_media_dir / Path(Path(s3_wav_url_p.path).name)
        if not saved_wav_path.exists():
            print(f"downloading : {saved_wav_path} from {s3_wav_url}")
            s3.download_file(
                s3_wav_url_p.netloc, s3_wav_url_p.path[1:], str(saved_wav_path)
            )

        return {
            "wav_path": saved_wav_path,
            "num_samples": len(utter_events) // utter_chunk_size,
            "meta": call_obj,
            "first_event_fn": get_ev_fev_timedelta,
            "utter_events": utter_events,
        }

    def get_cid(uri):
        """Call id = last path component of the call URI, without suffix."""
        return Path(urlsplit(uri).path).stem

    def ensure_call(uri):
        """Make sure metadata and audio of the call at *uri* are available."""
        return process_call(find_call_meta(get_cid(uri)))

    def retrieve_processed_callmeta(uri):
        """Collect record, duration and extracted data points of one call."""
        meta = find_call_meta(get_cid(uri))
        duration = meta["EndTS"] - meta["StartTS"]
        process_meta = process_call(meta)
        process_meta["data_points"] = get_data_points(
            process_meta["utter_events"], process_meta["first_event_fn"]
        )
        return {"url": uri, "meta": meta, "duration": duration, "process": process_meta}

    def retrieve_callmeta(call_uri):
        """Like ``retrieve_processed_callmeta`` plus the caller's name."""
        call_stat = retrieve_processed_callmeta(call_uri["call_uri"])
        call_stat["name"] = call_uri["name"]
        return call_stat

    def download_meta_audio():
        call_lens = lens["users"].Each()["calls"].Each()
        call_lens.modify(ensure_call)(call_logs)

    def plot_calls_data():
        """Render one waveform plot per call with its data points marked."""

        def plot_data_points(y, sr, data_points, file_path):
            plt.figure(figsize=(16, 12))
            # NOTE(review): waveplot is named waveshow in newer librosa
            # releases - confirm against the pinned version.
            librosa.display.waveplot(y=y, sr=sr)
            for dp in data_points:
                start, end, code = dp["start_time"], dp["end_time"], dp["code"]
                plt.axvspan(start, end, color="green", alpha=0.2)
                text_pos = (start + end) / 2
                plt.text(
                    text_pos,
                    0.25,
                    f"{code}",
                    rotation=90,
                    horizontalalignment="center",
                    verticalalignment="center",
                )
            plt.title("Datapoints")
            plt.savefig(file_path, format="png")
            # Release the figure; otherwise every plotted call stays in memory.
            plt.close()
            return file_path

        def plot_call(call_obj):
            saved_wav_path, data_points, sys_id = (
                call_obj["process"]["wav_path"],
                call_obj["process"]["data_points"],
                call_obj["meta"]["SystemID"],
            )
            file_path = call_plot_dir / Path(sys_id).with_suffix(".png")
            if not file_path.exists():
                print(f"plotting: {file_path}")
                (y, sr) = librosa.load(saved_wav_path)
                plot_data_points(y, sr, data_points, str(file_path))
            return file_path

        call_lens = lens["users"].Each()["calls"].Each()
        call_stats = call_lens.modify(retrieve_processed_callmeta)(call_logs)
        call_plots = call_lens.modify(plot_call)(call_stats)
        pprint(call_plots)

    def extract_data_points():
        """Cut every data point out of its call audio and write the dataset."""
        if strip_silent_chunks:

            def audio_process(seg):
                return strip_silence(seg)

        else:

            def audio_process(seg):
                return seg

        def gen_data_values(saved_wav_path, data_points, caller_name):
            # Mono, 16-bit samples, 24 kHz.
            call_seg = (
                AudioSegment.from_wav(saved_wav_path)
                .set_channels(1)
                .set_sample_width(2)
                .set_frame_rate(24000)
            )
            for dp in data_points:
                start, end, spoken = dp["start_time"], dp["end_time"], dp["code"]
                spoken_seg = audio_process(call_seg[start * 1000 : end * 1000])
                # Decide before exporting, so skipped chunks cost no encoding.
                if strip_silent_chunks and spoken_seg.duration_seconds < 0.5:
                    print(f'transcript chunk "{spoken}" contains no audio skipping.')
                    continue
                spoken_fb = BytesIO()
                spoken_seg.export(spoken_fb, format="wav")
                spoken_wav = spoken_fb.getvalue()
                # search for actual pnr code and handle plain codes as well
                extracted_code = text_extractor(spoken)
                yield extracted_code, spoken_seg.duration_seconds, spoken_wav, caller_name, spoken_seg

        call_lens = lens["users"].Each()["calls"].Each()

        def assign_user_call(uc):
            # Replace every call URI by a dict that also carries the user name.
            return (
                lens["calls"]
                .Each()
                .modify(lambda c: {"call_uri": c, "name": uc["name"]})(uc)
            )

        user_call_logs = lens["users"].Each().modify(assign_user_call)(call_logs)
        call_stats = call_lens.modify(retrieve_callmeta)(user_call_logs)
        call_objs = call_lens.collect()(call_stats)

        def data_source():
            for call_obj in tqdm(call_objs):
                saved_wav_path, data_points, name = (
                    call_obj["process"]["wav_path"],
                    call_obj["process"]["data_points"],
                    call_obj["name"],
                )
                yield from gen_data_values(saved_wav_path, data_points, name)

        ui_dump_manifest_writer(call_asr_data, dataset_name, data_source())

    def show_leaderboard():
        """Print callers ranked by their total call duration."""

        def compute_user_stats(call_stat):
            n_samples = (
                lens["calls"].Each()["process"]["num_samples"].get_monoid()(call_stat)
            )
            n_duration = lens["calls"].Each()["duration"].get_monoid()(call_stat)
            total_secs = n_duration.total_seconds()
            return {
                "num_samples": n_samples,
                "duration": total_secs,
                # Guard against a zero total duration.
                "samples_rate": n_samples / total_secs if total_secs else 0.0,
                "duration_str": compress(n_duration, pad=" "),
                "name": call_stat["name"],
            }

        call_lens = lens["users"].Each()["calls"].Each()
        call_stats = call_lens.modify(retrieve_processed_callmeta)(call_logs)
        user_stats = lens["users"].Each().modify(compute_user_stats)(call_stats)
        leader_df = (
            pd.DataFrame(user_stats["users"])
            .sort_values(by=["duration"], ascending=False)
            .reset_index(drop=True)
        )
        leader_df["rank"] = leader_df.index + 1
        leader_board = leader_df.rename(
            columns={
                "rank": "Rank",
                "num_samples": "Count",
                "name": "Name",
                "samples_rate": "SpeechRate",
                "duration_str": "Duration",
            }
        )[["Rank", "Name", "Count", "Duration"]]
        print(
            """ASR Dataset Leaderboard :
---------------------------------"""
        )
        print(leader_board.to_string(index=False))

    if download_only:
        download_meta_audio()
        return
    if leaderboard:
        show_leaderboard()
    if plot_calls:
        plot_calls_data()
    if extract_data:
        extract_data_points()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Run the Typer command-line application."""
    app()


if __name__ == "__main__":
    main()
|
||||||
27
jasper/data/conv_generator.py
Normal file
27
jasper/data/conv_generator.py
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
import typer
|
||||||
|
from pathlib import Path
|
||||||
|
from .utils import generate_dates
|
||||||
|
|
||||||
|
app = typer.Typer()
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
def export_conv_json(
    conv_src: Path = typer.Option(Path("./conv_data.json"), show_default=True),
    conv_dest: Path = typer.Option(Path("./data/conv_data.json"), show_default=True),
):
    """Copy the conversation JSON to ``conv_dest`` with generated dates added."""
    from .utils import ExtendedPath

    conv_data = ExtendedPath(conv_src).read_json()
    # NOTE(review): the NLU exporter stores its dates under "Dates"
    # (capitalised); confirm which key the consumers expect.
    conv_data["dates"] = generate_dates()

    # The default destination lives in ./data, which may not exist yet.
    Path(conv_dest).parent.mkdir(parents=True, exist_ok=True)
    ExtendedPath(conv_dest).write_json(conv_data)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Run the Typer command-line application."""
    app()


if __name__ == "__main__":
    main()
|
||||||
98
jasper/data/nlu_generator.py
Normal file
98
jasper/data/nlu_generator.py
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import typer
|
||||||
|
import pandas as pd
|
||||||
|
from ruamel.yaml import YAML
|
||||||
|
from itertools import product
|
||||||
|
from .utils import generate_dates
|
||||||
|
|
||||||
|
app = typer.Typer()
|
||||||
|
|
||||||
|
|
||||||
|
def unique_entity_list(entity_template_tags, entity_data):
    """Return the distinct utterances that mention at least one template tag.

    The columns ``Answer.utterance-1`` .. ``Answer.utterance-4`` of
    *entity_data* are scanned in that order. Cells that are not strings
    (for example NaN for an empty CSV cell) are ignored. The result keeps
    first-seen order, so it is identical from run to run.
    """
    matching = (
        text
        for n in range(1, 5)
        for text in entity_data[f"Answer.utterance-{n}"].tolist()
        # Non-string cells would raise TypeError on the ``in`` test below.
        if isinstance(text, str) and any(tag in text for tag in entity_template_tags)
    )
    # dict.fromkeys de-duplicates while preserving insertion order.
    return list(dict.fromkeys(matching))
|
||||||
|
|
||||||
|
|
||||||
|
def nlu_entity_reader(nlu_data_file: Path = Path("./nlu_data.yaml")):
    """Yield ``(entity name, template tags, data frame)`` per configured entity.

    The YAML file lists CSV files under ``csv_files``; each entry names the
    CSV (``fname``) and its ``entities``. An entity carrying a ``filter``
    value only receives the rows whose ``filter_key`` column equals it.
    """
    yaml = YAML()
    nlu_data = yaml.load(nlu_data_file.read_text())
    for cf in nlu_data["csv_files"]:
        # Each CSV is read once and shared by all of its entities.
        data = pd.read_csv(cf["fname"])
        for et in cf["entities"]:
            entity_data = data
            if "filter" in et:
                entity_data = data[data[cf["filter_key"]] == et["filter"]]
            yield et["name"], et["tags"], entity_data
|
||||||
|
|
||||||
|
|
||||||
|
def nlu_samples_reader(nlu_data_file: Path = Path("./nlu_data.yaml")):
    """Return the ``samples_per_entity`` entries of the YAML file, keyed by name."""
    yaml = YAML()
    nlu_data = yaml.load(nlu_data_file.read_text())
    return {entry["name"]: entry for entry in nlu_data["samples_per_entity"]}
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
def compute_unique_nlu_stats(
    nlu_data_file: Path = typer.Option(Path("./nlu_data.yaml"), show_default=True),
):
    """Print, per entity, how many distinct tagged utterances are available."""
    for entity_name, entity_template_tags, entity_data in nlu_entity_reader(
        nlu_data_file
    ):
        entity_count = len(unique_entity_list(entity_template_tags, entity_data))
        # Tab-separated: entity name, count.
        print(f"{entity_name}\t{entity_count}")
|
||||||
|
|
||||||
|
|
||||||
|
def replace_entity(tmpl, value, tags):
    """Return *tmpl* with every occurrence of each tag replaced by *value*.

    Tags are applied one after another in the given order. Empty tags are
    ignored, and *value* is converted with ``str`` so numbers can be used.
    """
    replacement = str(value)
    result = tmpl
    for tag in tags:
        # str.replace("", x) would insert x between every character.
        if tag:
            result = result.replace(tag, replacement)
    return result
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
def export_nlu_conv_json(
    conv_src: Path = typer.Option(Path("./conv_data.json"), show_default=True),
    conv_dest: Path = typer.Option(Path("./data/conv_data.json"), show_default=True),
    nlu_data_file: Path = typer.Option(Path("./nlu_data.yaml"), show_default=True),
):
    """Generate utterance variants per entity and write them to ``conv_dest``.

    For each entity a random subset of its values is drawn from the
    conversation data, and for every value a random subset of the utterance
    templates is filled in.
    """
    from .utils import ExtendedPath
    from random import sample

    entity_samples = nlu_samples_reader(nlu_data_file)
    conv_data = ExtendedPath(conv_src).read_json()
    conv_data["Dates"] = generate_dates()
    result_dict = {}
    data_count = 0
    for entity_name, entity_template_tags, entity_data in nlu_entity_reader(
        nlu_data_file
    ):
        sizes = entity_samples[entity_name]
        values = conv_data[entity_name]
        # random.sample raises ValueError when asked for more items than
        # exist, so both sample sizes are capped at the population size.
        entity_variants = sample(values, min(sizes["test_size"], len(values)))
        unique_entites = unique_entity_list(entity_template_tags, entity_data)
        n_templates = min(sizes["samples"], len(unique_entites))
        # setdefault: an entity listed under several CSV files accumulates
        # its variants instead of discarding the ones already counted.
        variants = result_dict.setdefault(entity_name, [])
        for val in entity_variants:
            # A fresh template subset is drawn for every value.
            for tmpl in sample(unique_entites, n_templates):
                variants.append(replace_entity(tmpl, val, entity_template_tags))
                data_count += 1
    print(f"Total of {data_count} variants generated")
    ExtendedPath(conv_dest).write_json(result_dict)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Entry point: invoke the module-level ``app`` command-line object."""
    app()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -38,9 +38,9 @@ def augment_data(src_dataset_paths: List[Path], dest_dataset_path: Path):
|
|||||||
def split_data(dataset_path: Path, test_size: float = 0.1):
|
def split_data(dataset_path: Path, test_size: float = 0.1):
|
||||||
manifest_path = dataset_path / Path("abs_manifest.json")
|
manifest_path = dataset_path / Path("abs_manifest.json")
|
||||||
asr_data = list(asr_manifest_reader(manifest_path))
|
asr_data = list(asr_manifest_reader(manifest_path))
|
||||||
train_data, test_data = train_test_split(asr_data, test_size=test_size)
|
train_pnr, test_pnr = train_test_split(asr_data, test_size=test_size)
|
||||||
asr_manifest_writer(manifest_path.with_name("train_manifest.json"), train_data)
|
asr_manifest_writer(manifest_path.with_name("train_manifest.json"), train_pnr)
|
||||||
asr_manifest_writer(manifest_path.with_name("test_manifest.json"), test_data)
|
asr_manifest_writer(manifest_path.with_name("test_manifest.json"), test_pnr)
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
@app.command()
|
||||||
@@ -52,9 +52,9 @@ def validate_data(dataset_path: Path):
|
|||||||
data_file = dataset_path / Path(mf_type)
|
data_file = dataset_path / Path(mf_type)
|
||||||
print(f"validating {data_file}.")
|
print(f"validating {data_file}.")
|
||||||
with Path(data_file).open("r") as pf:
|
with Path(data_file).open("r") as pf:
|
||||||
data_jsonl = pf.readlines()
|
pnr_jsonl = pf.readlines()
|
||||||
duration = 0
|
duration = 0
|
||||||
for (i, s) in enumerate(data_jsonl):
|
for (i, s) in enumerate(pnr_jsonl):
|
||||||
try:
|
try:
|
||||||
d = json.loads(s)
|
d = json.loads(s)
|
||||||
duration += d["duration"]
|
duration += d["duration"]
|
||||||
|
|||||||
@@ -1,93 +0,0 @@
|
|||||||
from rastrik.proto.callrecord_pb2 import CallRecord
|
|
||||||
import gzip
|
|
||||||
from pydub import AudioSegment
|
|
||||||
from .utils import ui_dump_manifest_writer, strip_silence
|
|
||||||
|
|
||||||
import typer
|
|
||||||
from itertools import chain
|
|
||||||
from io import BytesIO
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
app = typer.Typer()
|
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
|
||||||
def extract_manifest(
|
|
||||||
call_log_dir: Path = Path("./data/call_audio"),
|
|
||||||
output_dir: Path = Path("./data"),
|
|
||||||
dataset_name: str = "grassroot_pizzahut_v1",
|
|
||||||
caller_name: str = "grassroot",
|
|
||||||
verbose: bool = False,
|
|
||||||
):
|
|
||||||
call_asr_data: Path = output_dir / Path("asr_data")
|
|
||||||
call_asr_data.mkdir(exist_ok=True, parents=True)
|
|
||||||
|
|
||||||
def wav_pb2_generator(log_dir):
|
|
||||||
for wav_path in log_dir.glob("**/*.wav"):
|
|
||||||
if verbose:
|
|
||||||
typer.echo(f"loading events for file {wav_path}")
|
|
||||||
call_wav = AudioSegment.from_file_using_temporary_files(wav_path)
|
|
||||||
meta_path = wav_path.with_suffix(".pb2.gz")
|
|
||||||
yield call_wav, wav_path, meta_path
|
|
||||||
|
|
||||||
def read_event(call_wav, log_file):
|
|
||||||
call_wav_0, call_wav_1 = call_wav.split_to_mono()
|
|
||||||
with gzip.open(log_file, "rb") as log_h:
|
|
||||||
record_data = log_h.read()
|
|
||||||
cr = CallRecord()
|
|
||||||
cr.ParseFromString(record_data)
|
|
||||||
|
|
||||||
first_audio_event_timestamp = next(
|
|
||||||
(
|
|
||||||
i
|
|
||||||
for i in cr.events
|
|
||||||
if i.WhichOneof("event_type") == "call_event"
|
|
||||||
and i.call_event.WhichOneof("event_type") == "call_audio"
|
|
||||||
)
|
|
||||||
).timestamp.ToDatetime()
|
|
||||||
|
|
||||||
speech_events = [
|
|
||||||
i
|
|
||||||
for i in cr.events
|
|
||||||
if i.WhichOneof("event_type") == "speech_event"
|
|
||||||
and i.speech_event.WhichOneof("event_type") == "asr_final"
|
|
||||||
]
|
|
||||||
previous_event_timestamp = (
|
|
||||||
first_audio_event_timestamp - first_audio_event_timestamp
|
|
||||||
)
|
|
||||||
for index, each_speech_events in enumerate(speech_events):
|
|
||||||
asr_final = each_speech_events.speech_event.asr_final
|
|
||||||
speech_timestamp = each_speech_events.timestamp.ToDatetime()
|
|
||||||
actual_timestamp = speech_timestamp - first_audio_event_timestamp
|
|
||||||
start_time = previous_event_timestamp.total_seconds() * 1000
|
|
||||||
end_time = actual_timestamp.total_seconds() * 1000
|
|
||||||
audio_segment = strip_silence(call_wav_1[start_time:end_time])
|
|
||||||
|
|
||||||
code_fb = BytesIO()
|
|
||||||
audio_segment.export(code_fb, format="wav")
|
|
||||||
wav_data = code_fb.getvalue()
|
|
||||||
previous_event_timestamp = actual_timestamp
|
|
||||||
duration = (end_time - start_time) / 1000
|
|
||||||
yield asr_final, duration, wav_data, "grassroot", audio_segment
|
|
||||||
|
|
||||||
def generate_call_asr_data():
|
|
||||||
full_data = []
|
|
||||||
total_duration = 0
|
|
||||||
for wav, wav_path, pb2_path in wav_pb2_generator(call_log_dir):
|
|
||||||
asr_data = read_event(wav, pb2_path)
|
|
||||||
total_duration += wav.duration_seconds
|
|
||||||
full_data.append(asr_data)
|
|
||||||
n_calls = len(full_data)
|
|
||||||
typer.echo(f"loaded {n_calls} calls of duration {total_duration}s")
|
|
||||||
n_dps = ui_dump_manifest_writer(call_asr_data, dataset_name, chain(*full_data))
|
|
||||||
typer.echo(f"written {n_dps} data points")
|
|
||||||
|
|
||||||
generate_call_asr_data()
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
app()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
175
jasper/data/rev_recycler.py
Normal file
175
jasper/data/rev_recycler.py
Normal file
@@ -0,0 +1,175 @@
|
|||||||
|
import typer
|
||||||
|
from itertools import chain
|
||||||
|
from io import BytesIO
|
||||||
|
from pathlib import Path
|
||||||
|
import re
|
||||||
|
|
||||||
|
app = typer.Typer()
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
def extract_data(
    call_audio_dir: Path = typer.Option(Path("/dataset/rev/wavs"), show_default=True),
    call_meta_dir: Path = typer.Option(Path("/dataset/rev/jsons"), show_default=True),
    output_dir: Path = typer.Option(Path("./data"), show_default=True),
    dataset_name: str = typer.Option("rev_transribed", show_default=True),
    verbose: bool = False,
):
    """Build an ASR dataset from call recordings and their transcript JSON files.

    Every ``*.wav`` below ``call_audio_dir`` is paired with the ``.json`` file at
    the same relative path below ``call_meta_dir``. Each monologue of the
    transcript is cut out of the recording, stripped of silence and handed,
    together with its cleaned text, to ``asr_data_writer`` which writes the
    dataset ``dataset_name`` under ``output_dir / "asr_data"``.

    Recordings without a transcript file or with more than two channels are
    skipped, as are monologues without timestamps, monologues of an unknown
    speaker (stereo recordings only) and clips shorter than 0.5 seconds after
    silence stripping. ``verbose`` prints a message for each skipped item.
    """
    import datetime

    from lenses import lens
    from pydub import AudioSegment

    from .utils import ExtendedPath, asr_data_writer, strip_silence

    call_asr_data: Path = output_dir / Path("asr_data")
    call_asr_data.mkdir(exist_ok=True, parents=True)

    def wav_event_generator(call_audio_dir):
        """Yield ``(audio, wav_path, transcript)`` for each wav that has metadata."""
        for wav_path in call_audio_dir.glob("**/*.wav"):
            if verbose:
                typer.echo(f"loading events for file {wav_path}")
            rel_meta_path = wav_path.with_suffix(".json").relative_to(call_audio_dir)
            meta_path = call_meta_dir / rel_meta_path
            # Check for the transcript first so audio that will be discarded
            # anyway is never decoded.
            if not meta_path.exists():
                if verbose:
                    typer.echo(f"missing json corresponding to {wav_path}")
                continue
            call_wav = AudioSegment.from_file_using_temporary_files(wav_path)
            events = ExtendedPath(meta_path).read_json()
            yield call_wav, wav_path, events

    def time_to_msecs(time_str):
        """Convert an ``HH:MM:SS,ffffff`` timestamp to milliseconds."""
        # strptime fills the missing date with 1900-01-01, so subtracting that
        # date leaves only the time-of-day offset.
        return (
            datetime.datetime.strptime(time_str, "%H:%M:%S,%f")
            - datetime.datetime(1900, 1, 1)
        ).total_seconds() * 1000

    def process_utterance_chunk(wav_seg, start_time, end_time, monologue):
        """Return ``(wav_bytes, duration_seconds, text)`` for one monologue."""
        # Start one second early to recover audio that may have been trimmed
        # before the first timestamp. Clamp at 0: pydub resolves a negative
        # slice start relative to the END of the segment, which would turn an
        # utterance from the first second of the call into an empty clip.
        clip_start = max(0, time_to_msecs(start_time) - 1000)
        full_tscript_wav_seg = wav_seg[clip_start : time_to_msecs(end_time)]
        tscript_wav_seg = strip_silence(full_tscript_wav_seg)
        tscript_wav_fb = BytesIO()
        tscript_wav_seg.export(tscript_wav_fb, format="wav")
        tscript_wav = tscript_wav_fb.getvalue()
        text = "".join(lens["elements"].Each()["value"].collect()(monologue))
        # Non-greedy match: drop each bracketed annotation on its own instead
        # of everything between the first "[" and the last "]".
        text_clean = re.sub(r"\[.*?\]", "", text)
        return tscript_wav, tscript_wav_seg.duration_seconds, text_clean

    def monologue_datapoint(audio, wav_path, monologue):
        """Return ``(text, duration, wav_bytes)`` for a monologue, or ``None`` to skip it."""
        try:
            start_time = (
                lens["elements"]
                .Each()
                .Filter(lambda x: "timestamp" in x)["timestamp"]
                .collect()(monologue)[0]
            )
            end_time = (
                lens["elements"]
                .Each()
                .Filter(lambda x: "end_timestamp" in x)["end_timestamp"]
                .collect()(monologue)[-1]
            )
        except IndexError:
            # no element of the monologue carries the needed timestamp
            if verbose:
                print(
                    f"error when loading timestamp events in wav:{wav_path} skipping."
                )
            return None
        tscript_wav, seg_dur, text_clean = process_utterance_chunk(
            audio, start_time, end_time, monologue
        )
        if seg_dur < 0.5:
            if verbose:
                print(
                    f'transcript chunk "{text_clean}" contains no audio in {wav_path} skipping.'
                )
            return None
        return text_clean, seg_dur, tscript_wav

    def dual_asr_data_generator(wav_seg, wav_path, meta):
        """Yield datapoints of a stereo call, reading each speaker from its own channel."""
        left_audio, right_audio = wav_seg.split_to_mono()
        channel_map = {"Agent": right_audio, "Client": left_audio}
        for monologue in lens["monologues"].Each().collect()(meta):
            speaker_channel = channel_map.get(monologue["speaker_name"])
            # Compare with None explicitly: the truth value of an audio
            # segment says nothing about whether the speaker tag is known.
            if speaker_channel is None:
                if verbose:
                    print(
                        f'unknown speaker tag {monologue["speaker_name"]} in wav:{wav_path} skipping.'
                    )
                continue
            datapoint = monologue_datapoint(speaker_channel, wav_path, monologue)
            if datapoint is not None:
                yield datapoint

    def mono_asr_data_generator(wav_seg, wav_path, meta):
        """Yield datapoints of a single-channel call."""
        for monologue in lens["monologues"].Each().collect()(meta):
            datapoint = monologue_datapoint(wav_seg, wav_path, monologue)
            if datapoint is not None:
                yield datapoint

    def generate_rev_asr_data():
        """Collect the datapoints of every usable call and write the dataset."""
        full_asr_data = []
        total_duration = 0
        for wav, wav_path, ev in wav_event_generator(call_audio_dir):
            if wav.channels > 2:
                print(f"skipping many channel audio {wav_path}")
                # Actually skip: the stereo generator unpacks exactly two channels.
                continue
            asr_data_generator = (
                mono_asr_data_generator
                if wav.channels == 1
                else dual_asr_data_generator
            )
            asr_data = asr_data_generator(wav, wav_path, ev)
            total_duration += wav.duration_seconds
            full_asr_data.append(asr_data)
        typer.echo(f"loaded {len(full_asr_data)} calls of duration {total_duration}s")
        n_dps = asr_data_writer(call_asr_data, dataset_name, chain(*full_asr_data))
        typer.echo(f"written {n_dps} data points")

    generate_rev_asr_data()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Entry point: run the Typer ``app`` that exposes ``extract_data``."""
    app()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
180
jasper/data/slu_evaluator.py
Normal file
180
jasper/data/slu_evaluator.py
Normal file
@@ -0,0 +1,180 @@
|
|||||||
|
import typer
|
||||||
|
from pathlib import Path
|
||||||
|
import json
|
||||||
|
|
||||||
|
# from .utils import generate_dates, asr_test_writer
|
||||||
|
|
||||||
|
app = typer.Typer()
|
||||||
|
|
||||||
|
|
||||||
|
def run_test(reg_path, coll, s3, call_meta_dir, city_code, test_path):
    """Run one regression test file and grade the entity extracted for it.

    Arguments:
        reg_path -- regression runner script, executed through ``/bin/bash -c``
        coll -- collection queried (and cleaned) by ``CallID == test_path.name``
        s3 -- client object forwarded to ``get_call_logs``
        call_meta_dir -- directory forwarded to ``get_call_logs``
        city_code -- mapping from expected entity text to its code
        test_path -- path of the test file; the expected values are read from
            the sibling file with suffix ``.result.json``

    Returns:
        A dict with one value per results column. ``result`` is ``"Success"``
        or ``"Fail"`` when the call produced exactly two test events,
        ``"Error"`` when those events could not be parsed and ``"Empty"``
        otherwise. ``None`` is returned when the runner script fails.
    """
    import shlex
    import subprocess
    from time import sleep

    from .utils import ExtendedPath, get_call_logs

    def result_row(test_case, status, **observed):
        """Build a result row; columns not given in ``observed`` stay empty."""
        row = {
            "expected_entity": test_case["entity"],
            "text": test_case["text"],
            "audio_filepath": test_case["audio_filepath"],
            "pretrained_asr": test_case["pretrained_asr"],
            "entity_asr": "",
            "google_asr": "",
            "nlu_result": "",
            "asr_entity": "",
            "nlu_entity": "",
            "result": status,
        }
        row.update(observed)
        return row

    call_query = {"CallID": test_path.name}
    # Drop records of earlier runs so the polling below only sees this run.
    coll.delete_many(call_query)
    # Quote the paths: they are interpolated into a shell command line.
    command = (
        f"{shlex.quote(str(reg_path))} --addr [::]:15400"
        f" -f {shlex.quote(str(test_path))}"
    )
    test_output = subprocess.run(["/bin/bash", "-c", command])
    if test_output.returncode != 0:
        print(f"Error running test {test_path}")
        return

    # Poll every two seconds until the call record shows up. A loop is used
    # so that a long wait cannot exhaust the recursion limit.
    call_meta = coll.find_one(call_query)
    while not call_meta:
        sleep(2)
        call_meta = coll.find_one(call_query)

    call_logs = get_call_logs(call_meta, s3, call_meta_dir)
    call_events = call_logs["Events"]

    test_data_path = test_path.with_suffix(".result.json")
    test_data = ExtendedPath(test_data_path).read_json()
    test_case = test_data[0]

    def is_final_asr_event(ev):
        """True for ``ASR_RESULT`` events whose first result is marked final."""
        if ev["Type"] != "ASR_RESULT":
            return False
        return json.loads(ev["Payload"])["AsrResult"]["Results"][0]["IsFinal"]

    def is_test_event(ev):
        """Keep non-debug NLU events and final ASR results."""
        if ev["Type"] == "DEBUG":
            return False
        return ev["Author"] == "NLU" or (
            ev["Author"] == "ASR" and is_final_asr_event(ev)
        )

    test_evs = list(filter(is_test_event, call_events))
    if len(test_evs) != 2:
        return result_row(test_case, "Empty")

    try:
        # NOTE(review): assumes the ASR event precedes the NLU event - confirm.
        asr_result = json.loads(test_evs[0]["Payload"])["AsrResult"]["Results"][0]
        gcp_result = "|".join(
            alt["Transcript"] for alt in asr_result["Alternatives"]
        )
        entity_asr = asr_result["AsrDynamicResults"][0]["Candidate"]["Transcript"]
        # "NluResults" is itself a JSON-encoded string, hence the second loads.
        nlu_result_payload = json.loads(test_evs[1]["Payload"])["NluResults"]
        nlu_entity = list(json.loads(nlu_result_payload)["Entities"].values())[0]
        asr_entity = city_code.get(test_case["entity"], "UNKNOWN")
    except Exception:
        # Best effort: a malformed event becomes an "Error" row instead of
        # aborting the whole evaluation run.
        return result_row(test_case, "Error")

    return result_row(
        test_case,
        "Success" if asr_entity == nlu_entity else "Fail",
        entity_asr=entity_asr,
        google_asr=gcp_result,
        nlu_result=nlu_result_payload,
        asr_entity=asr_entity,
        nlu_entity=nlu_entity,
    )
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
def evaluate_slu(
    data_name: str = typer.Option("call_upwork_test_cnd_cities", show_default=True),
    dump_dir: Path = Path("./data/asr_data"),
    call_meta_dir: Path = Path("./data/call_metas"),
    test_file_pref: str = "asr_test",
    mongo_uri: str = typer.Option(
        "mongodb://localhost:27017/test.calls", show_default=True
    ),
    test_results: Path = Path("./data/results.csv"),
    airport_codes: Path = Path("./airports_code.csv"),
    reg_path: Path = Path("../saas_reg/regression/run.sh"),
    test_id: str = "5ef481f27031edf6910e94e0",
):
    """Run every regression test of a dataset and write the outcomes to a CSV.

    All files matching ``<test_file_pref>*.reg`` in ``dump_dir / data_name``
    are passed to ``run_test`` on a pool of eight threads; each returned row
    is written to ``test_results``. The ``city`` -> ``iata`` mapping used for
    grading is read from ``airport_codes``.

    ``test_id`` is accepted but not used by this function.
    """
    import csv
    from concurrent.futures import ThreadPoolExecutor
    from functools import partial

    import boto3
    import pandas as pd
    from tqdm import tqdm

    from .utils import get_mongo_coll

    s3 = boto3.client("s3")
    df = pd.read_csv(airport_codes)[["iata", "city"]]
    city_code = pd.Series(df["iata"].values, index=df["city"]).to_dict()

    test_files = list((dump_dir / data_name).glob(test_file_pref + "*.reg"))
    coll = get_mongo_coll(mongo_uri)
    fieldnames = [
        "expected_entity",
        "text",
        "audio_filepath",
        "pretrained_asr",
        "entity_asr",
        "google_asr",
        "nlu_result",
        "asr_entity",
        "nlu_entity",
        "result",
    ]
    # newline="" lets the csv module control line endings itself, as its
    # documentation requires for files handed to a writer.
    with test_results.open("w", newline="") as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()
        with ThreadPoolExecutor(max_workers=8) as exe:
            print("starting all loading tasks")
            run_one = partial(run_test, reg_path, coll, s3, call_meta_dir, city_code)
            for test_result in tqdm(
                exe.map(run_one, test_files),
                position=0,
                leave=True,
                total=len(test_files),
            ):
                # run_test returns None when the regression script itself
                # failed; there is no row to write for such a test.
                if test_result is None:
                    continue
                writer.writerow(test_result)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Entry point: run the Typer ``app`` that exposes ``evaluate_slu``."""
    app()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
99
jasper/data/test_generator.py
Normal file
99
jasper/data/test_generator.py
Normal file
@@ -0,0 +1,99 @@
|
|||||||
|
import typer
|
||||||
|
from pathlib import Path
|
||||||
|
from .utils import generate_dates, asr_test_writer
|
||||||
|
|
||||||
|
app = typer.Typer()
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
def export_test_reg(
    conv_src: Path = typer.Option(Path("./conv_data.json"), show_default=True),
    data_name: str = typer.Option("call_upwork_test_cnd_cities", show_default=True),
    extraction_key: str = "Cities",
    dump_dir: Path = Path("./data/asr_data"),
    dump_file: Path = Path("ui_dump.json"),
    manifest_file: Path = Path("manifest.json"),
    test_file: Path = Path("asr_test.reg"),
):
    """Transcribe a dataset with several ASR services and export it as a test file.

    Each manifest entry of ``dump_dir / data_name`` is merged with the UI dump
    record of the same utterance id, transcribed with the GCP transcriber and
    three Jasper endpoints, and tagged with the longest value of
    ``conv_src[extraction_key]`` found in its text. Entries in which no such
    value occurs are left out. The rest is written with ``asr_test_writer`` to
    ``dump_dir / data_name / test_file``.
    """
    from pydub import AudioSegment

    from ..client import transcribe_gen
    from .utils import (
        ExtendedPath,
        asr_manifest_reader,
        gcp_transcribe_gen,
        parallel_apply,
    )

    # ASR port of the Jasper service used for each entity type.
    jasper_map = {
        "PNRs": 8045,
        "Cities": 8046,
        "Names": 8047,
        "Dates": 8048,
    }
    transcriber_gcp = gcp_transcribe_gen()
    transcriber_trained = transcribe_gen(asr_port=jasper_map[extraction_key])
    transcriber_all_trained = transcribe_gen(asr_port=8050)
    transcriber_libri_all_trained = transcribe_gen(asr_port=8051)

    def find_ent(dd, conv_data):
        """Return the longest entity contained in ``dd["text"]``, or ``None``."""
        matches = [ent for ent in conv_data if ent in dd["text"]]
        # Longest match wins and ties go to the alphabetically first entity.
        # ``default=None`` covers texts without any entity, which the caller
        # filters out with its ``if ent`` check.
        return min(matches, key=lambda ent: (-len(ent), ent), default=None)

    def process_data(d):
        """Add the transcripts of every ASR service to the data point ``d``."""
        orig_seg = AudioSegment.from_wav(d["audio_path"])
        jas_seg = orig_seg.set_channels(1).set_sample_width(2).set_frame_rate(24000)
        gcp_seg = orig_seg.set_channels(1).set_sample_width(2).set_frame_rate(16000)
        # NOTE(review): machine-specific location of the Deepgram transcripts;
        # consider exposing it as a command option.
        deepgram_file = Path("/home/shubham/voice_auto/pnrs/wav/") / Path(
            d["audio_path"].stem + ".txt"
        )
        if deepgram_file.exists():
            d["deepgram"] = "".join(
                s.replace("CHANNEL 0:", "")
                for s in deepgram_file.read_text().split("\n")
            )
        else:
            d["deepgram"] = "Not Found"
        # Stringify the path only after ``.stem`` has been read from it above.
        d["audio_path"] = str(d["audio_path"])
        d["gcp_transcript"] = transcriber_gcp(gcp_seg.raw_data)
        d["jasper_trained"] = transcriber_trained(jas_seg.raw_data)
        d["jasper_all"] = transcriber_all_trained(jas_seg.raw_data)
        d["jasper_libri"] = transcriber_libri_all_trained(jas_seg.raw_data)
        return d

    conv_data = ExtendedPath(conv_src).read_json()
    conv_data["Dates"] = generate_dates()

    dump_data_path = dump_dir / Path(data_name) / dump_file
    ui_dump_data = ExtendedPath(dump_data_path).read_json()["data"]
    ui_dump_map = {i["utterance_id"]: i for i in ui_dump_data}
    manifest_path = dump_dir / Path(data_name) / manifest_file
    test_points = list(asr_manifest_reader(manifest_path))
    # Manifest fields take precedence over the UI dump fields of the same name.
    test_data_objs = [{**(ui_dump_map[t["audio_path"].stem]), **t} for t in test_points]
    test_data = parallel_apply(process_data, test_data_objs)
    test_path = dump_dir / Path(data_name) / test_file

    def dd_gen(dump_data):
        """Yield the data points in which an entity was found, tagged with it."""
        for dd in dump_data:
            ent = find_ent(dd, conv_data[extraction_key])
            dd["entity"] = ent
            if ent:
                yield dd

    asr_test_writer(test_path, dd_gen(test_data))
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Entry point: run the Typer ``app`` that exposes ``export_test_reg``."""
    app()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
0
jasper/data/tts/__init__.py
Normal file
0
jasper/data/tts/__init__.py
Normal file
52
jasper/data/tts/googletts.py
Normal file
52
jasper/data/tts/googletts.py
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
from logging import getLogger
|
||||||
|
from google.cloud import texttospeech
|
||||||
|
|
||||||
|
LOGGER = getLogger("googletts")
|
||||||
|
|
||||||
|
|
||||||
|
class GoogleTTS(object):
    """Small wrapper around ``texttospeech.TextToSpeechClient``."""

    def __init__(self):
        self.client = texttospeech.TextToSpeechClient()

    def text_to_speech(self, text: str, params: dict) -> bytes:
        """Synthesize the SSML document ``text`` and return the audio content.

        ``params`` must provide the keys ``language``, ``name`` and
        ``sample_rate``; the audio is requested with LINEAR16 encoding.
        """
        request_parts = (
            texttospeech.types.SynthesisInput(ssml=text),
            texttospeech.types.VoiceSelectionParams(
                language_code=params["language"], name=params["name"]
            ),
            texttospeech.types.AudioConfig(
                audio_encoding=texttospeech.enums.AudioEncoding.LINEAR16,
                sample_rate_hertz=params["sample_rate"],
            ),
        )
        return self.client.synthesize_speech(*request_parts).audio_content

    @classmethod
    def voice_list(cls):
        """Return a description of every voice that has an English language code.

        Each entry is a dict with the keys ``name``, ``language`` (the voice's
        English codes joined with commas), ``gender``, ``engine`` and
        ``sample_rate``.
        """
        listing = cls().client.list_voices()

        english_voices = []
        for voice in listing.voices:
            english_codes = [
                code for code in voice.language_codes if code.startswith("en")
            ]
            if not english_codes:
                # voice has no English variant
                continue
            gender = texttospeech.enums.SsmlVoiceGender(voice.ssml_gender)
            english_voices.append(
                {
                    "name": voice.name,
                    "language": ",".join(english_codes),
                    "gender": gender.name,
                    "engine": "wavenet" if "Wav" in voice.name else "standard",
                    "sample_rate": voice.natural_sample_rate_hertz,
                }
            )
        return english_voices
|
||||||
26
jasper/data/tts/ttsclient.py
Normal file
26
jasper/data/tts/ttsclient.py
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
"""
|
||||||
|
TTSClient Abstract Class
|
||||||
|
"""
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
|
||||||
|
class TTSClient(ABC):
    """
    Base class for TTS
    """

    @abstractmethod
    def text_to_speech(self, text: str, num_channels: int, sample_rate: int,
                       audio_encoding) -> bytes:
        """
        convert text to bytes

        Arguments:
            text {str} -- text to convert
            num_channels {int} -- number of channels of the output audio
            sample_rate {int} -- sample rate of the output audio
            audio_encoding -- encoding of the output audio

        Returns:
            bytes -- the audio generated for the text
        """
|
||||||
62
jasper/data/tts_generator.py
Normal file
62
jasper/data/tts_generator.py
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
# import io
|
||||||
|
# import sys
|
||||||
|
# import json
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from .utils import random_pnr_generator, asr_data_writer
|
||||||
|
from .tts.googletts import GoogleTTS
|
||||||
|
from tqdm import tqdm
|
||||||
|
import random
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||||
|
)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def pnr_tts_streamer(count):
    """Yield ``(pnr_code, duration, wav_data)`` for ``count`` random PNR codes.

    Every code is spelled out verbatim through SSML by a randomly chosen voice
    from ``GoogleTTS.voice_list()``, requested at a sample rate of 24000.
    """
    voices = GoogleTTS.voice_list()
    synthesizer = GoogleTTS()
    for pnr_code in tqdm(random_pnr_generator(count)):
        voice_params = random.choice(voices)
        voice_params.update(sample_rate=24000, num_channels=1)
        ssml = f'<speak><say-as interpret-as="verbatim">{pnr_code}</say-as></speak>'
        wav_data = synthesizer.text_to_speech(text=ssml, params=voice_params)
        # presumably 44 is the WAV header size and 2 * 24000 the number of
        # bytes per second (2-byte samples at 24000 Hz) - TODO confirm
        audio_dur = len(wav_data[44:]) / (2 * 24000)
        yield pnr_code, audio_dur, wav_data
|
||||||
|
|
||||||
|
|
||||||
|
def generate_asr_data_fromtts(output_dir, dataset_name, count):
    """Write ``count`` synthesized PNR utterances as dataset ``dataset_name`` in ``output_dir``."""
    tts_stream = pnr_tts_streamer(count)
    asr_data_writer(output_dir, dataset_name, tts_stream)
|
||||||
|
|
||||||
|
|
||||||
|
def arg_parser():
    """Build the command-line parser of the TTS data generator.

    The parser accepts ``--output_dir``, ``--count`` and ``--dataset_name``.
    """
    parser = argparse.ArgumentParser(
        prog=Path(__file__).stem, description="generates asr training data"
    )
    option_specs = (
        (
            "--output_dir",
            {
                "type": Path,
                "default": Path("./train/asr_data"),
                "help": "directory to output asr data",
            },
        ),
        (
            "--count",
            {"type": int, "default": 3, "help": "number of datapoints to generate"},
        ),
        (
            "--dataset_name",
            {"type": str, "default": "pnr_data", "help": "name of the dataset"},
        ),
    )
    for flag, settings in option_specs:
        parser.add_argument(flag, **settings)
    return parser
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Parse the command-line arguments and generate the TTS-based ASR dataset."""
    cli_args = arg_parser().parse_args()
    generate_asr_data_fromtts(**vars(cli_args))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -1,14 +1,22 @@
|
|||||||
import io
|
import io
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
import json
|
import json
|
||||||
|
import base64
|
||||||
import wave
|
import wave
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from itertools import product
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
from math import floor
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
from urllib.parse import urlsplit, urlencode
|
||||||
|
from urllib.request import Request, urlopen
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import pymongo
|
import pymongo
|
||||||
from slugify import slugify
|
from slugify import slugify
|
||||||
|
from num2words import num2words
|
||||||
from jasper.client import transcribe_gen
|
from jasper.client import transcribe_gen
|
||||||
from nemo.collections.asr.metrics import word_error_rate
|
from nemo.collections.asr.metrics import word_error_rate
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
@@ -34,6 +42,29 @@ def wav_bytes(audio_bytes, frame_rate=24000):
|
|||||||
return wf_b.getvalue()
|
return wf_b.getvalue()
|
||||||
|
|
||||||
|
|
||||||
|
def random_pnr_generator(count=10000):
    """Return ``count`` random codes made of three uppercase letters and three digits.

    The six character positions are shuffled once for the whole batch, so all
    codes of one call share the same letter/digit layout.
    """
    half_len = 3

    letter_pool = np.array(list("ABCDEFGHIJKLMNOPQRSTUVWXYZ"), dtype="|S1")
    digit_pool = np.array(list("0123456789"), dtype="|S1")
    letter_block = np.random.choice(letter_pool, [count, half_len])
    digit_block = np.random.choice(digit_pool, [count, half_len])
    # Transposed so that each row holds one character position of all codes;
    # shuffling the rows permutes the positions.
    positions = np.concatenate((letter_block, digit_block), axis=1).T
    np.random.shuffle(positions)
    return [b"".join(code_chars).decode("utf-8") for code_chars in positions.T]
|
||||||
|
|
||||||
|
|
||||||
|
def alnum_to_asr_tokens(text):
    """Spell out ``text`` character by character for use as an ASR transcript.

    The characters are separated by single spaces, every digit ``0``-``9`` is
    replaced by the output of ``num2words`` and the result is lower-cased.
    """
    spaced = " ".join(text)
    tokens = (num2words(ch) if "0" <= ch <= "9" else ch for ch in spaced)
    return "".join(tokens).lower()
|
||||||
|
|
||||||
|
|
||||||
def tscript_uuid_fname(transcript):
|
def tscript_uuid_fname(transcript):
|
||||||
return str(uuid4()) + "_" + slugify(transcript, max_length=8)
|
return str(uuid4()) + "_" + slugify(transcript, max_length=8)
|
||||||
|
|
||||||
@@ -49,8 +80,8 @@ def asr_data_writer(output_dir, dataset_name, asr_data_source, verbose=False):
|
|||||||
fname = tscript_uuid_fname(transcript)
|
fname = tscript_uuid_fname(transcript)
|
||||||
audio_file = dataset_dir / Path("wav") / Path(fname).with_suffix(".wav")
|
audio_file = dataset_dir / Path("wav") / Path(fname).with_suffix(".wav")
|
||||||
audio_file.write_bytes(wav_data)
|
audio_file.write_bytes(wav_data)
|
||||||
rel_data_path = audio_file.relative_to(dataset_dir)
|
rel_pnr_path = audio_file.relative_to(dataset_dir)
|
||||||
manifest = manifest_str(str(rel_data_path), audio_dur, transcript)
|
manifest = manifest_str(str(rel_pnr_path), audio_dur, transcript)
|
||||||
mf.write(manifest)
|
mf.write(manifest)
|
||||||
if verbose:
|
if verbose:
|
||||||
print(f"writing '{transcript}' of duration {audio_dur}")
|
print(f"writing '{transcript}' of duration {audio_dur}")
|
||||||
@@ -58,97 +89,104 @@ def asr_data_writer(output_dir, dataset_name, asr_data_source, verbose=False):
|
|||||||
return num_datapoints
|
return num_datapoints
|
||||||
|
|
||||||
|
|
||||||
def ui_data_generator(output_dir, dataset_name, asr_data_source, verbose=False):
    """Write every utterance to disk and build its validation-UI record.

    Args:
        output_dir: directory under which ``dataset_name`` is created.
        dataset_name: name of the dataset sub-directory; ``wav`` and
            ``wav_plots`` folders are created inside it.
        asr_data_source: iterable of ``(transcript, audio_dur, wav_data,
            caller_name, aud_seg)`` tuples.
        verbose: accepted for signature parity with the other writers; it is
            not used here.

    Returns:
        ``(ui_data, num_datapoints)``: the list of per-utterance dicts built
        through ``parallel_apply`` and the number of utterances consumed.
    """
    dataset_dir = output_dir / Path(dataset_name)
    (dataset_dir / Path("wav")).mkdir(parents=True, exist_ok=True)
    (dataset_dir / Path("wav_plots")).mkdir(parents=True, exist_ok=True)
    transcriber_pretrained = transcribe_gen(asr_port=8044)

    def data_fn(
        transcript,
        audio_dur,
        caller_name,
        aud_seg,
        fname,
        audio_path,
        real_idx,
        rel_data_path,
    ):
        # Runs inside the worker pool: transcribe, score and plot one utterance.
        pretrained_result = transcriber_pretrained(aud_seg.raw_data)
        pretrained_wer = word_error_rate([transcript], [pretrained_result])
        png_path = Path(fname).with_suffix(".png")
        wav_plot_path = dataset_dir / Path("wav_plots") / png_path
        if not wav_plot_path.exists():
            plot_seg(wav_plot_path, audio_path)
        return {
            "audio_filepath": str(rel_data_path),
            "duration": round(audio_dur, 1),
            "text": transcript,
            "real_idx": real_idx,
            "audio_path": audio_path,
            "spoken": transcript,
            "caller": caller_name,
            "utterance_id": fname,
            "pretrained_asr": pretrained_result,
            "pretrained_wer": pretrained_wer,
            "plot_path": str(wav_plot_path),
        }

    data_funcs = []
    for real_idx, datapoint in enumerate(asr_data_source):
        transcript, audio_dur, wav_data, caller_name, aud_seg = datapoint
        # Same naming scheme as the other writers in this module.
        fname = tscript_uuid_fname(transcript)
        audio_file = dataset_dir / Path("wav") / Path(fname).with_suffix(".wav")
        audio_file.write_bytes(wav_data)
        data_funcs.append(
            partial(
                data_fn,
                transcript,
                audio_dur,
                caller_name,
                aud_seg,
                fname,
                str(audio_file),
                real_idx,
                audio_file.relative_to(dataset_dir),
            )
        )
    ui_data = parallel_apply(lambda x: x(), data_funcs)
    return ui_data, len(data_funcs)
|
|
||||||
|
|
||||||
|
|
||||||
def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=False):
|
def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=False):
|
||||||
dataset_dir = output_dir / Path(dataset_name)
|
dataset_dir = output_dir / Path(dataset_name)
|
||||||
dump_data, num_datapoints = ui_data_generator(
|
(dataset_dir / Path("wav")).mkdir(parents=True, exist_ok=True)
|
||||||
output_dir, dataset_name, asr_data_source, verbose=verbose
|
ui_dump_file = dataset_dir / Path("ui_dump.json")
|
||||||
)
|
(dataset_dir / Path("wav_plots")).mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
asr_manifest = dataset_dir / Path("manifest.json")
|
asr_manifest = dataset_dir / Path("manifest.json")
|
||||||
|
num_datapoints = 0
|
||||||
|
ui_dump = {
|
||||||
|
"use_domain_asr": False,
|
||||||
|
"annotation_only": False,
|
||||||
|
"enable_plots": True,
|
||||||
|
"data": [],
|
||||||
|
}
|
||||||
|
data_funcs = []
|
||||||
|
|
||||||
|
deepgram_transcriber = deepgram_transcribe_gen()
|
||||||
|
# t2n = Text2Num()
|
||||||
|
transcriber_gcp = gcp_transcribe_gen()
|
||||||
|
transcriber_pretrained = transcribe_gen(asr_port=8044)
|
||||||
with asr_manifest.open("w") as mf:
|
with asr_manifest.open("w") as mf:
|
||||||
print(f"writing manifest to {asr_manifest}")
|
print(f"writing manifest to {asr_manifest}")
|
||||||
for d in dump_data:
|
|
||||||
rel_data_path = d["audio_filepath"]
|
|
||||||
audio_dur = d["duration"]
|
|
||||||
transcript = d["text"]
|
|
||||||
manifest = manifest_str(str(rel_data_path), audio_dur, transcript)
|
|
||||||
mf.write(manifest)
|
|
||||||
|
|
||||||
ui_dump_file = dataset_dir / Path("ui_dump.json")
|
def data_fn(
|
||||||
ExtendedPath(ui_dump_file).write_json({"data": dump_data})
|
transcript,
|
||||||
|
audio_dur,
|
||||||
|
wav_data,
|
||||||
|
caller_name,
|
||||||
|
aud_seg,
|
||||||
|
fname,
|
||||||
|
audio_path,
|
||||||
|
num_datapoints,
|
||||||
|
rel_pnr_path,
|
||||||
|
):
|
||||||
|
pretrained_result = transcriber_pretrained(aud_seg.raw_data)
|
||||||
|
gcp_seg = aud_seg.set_frame_rate(16000)
|
||||||
|
gcp_result = transcriber_gcp(gcp_seg.raw_data)
|
||||||
|
aud_data = audio_path.read_bytes()
|
||||||
|
dgram_result = deepgram_transcriber(aud_data)
|
||||||
|
# gtruth = dp['text']
|
||||||
|
# dgram_result = t2n.convert(dgram_script)
|
||||||
|
pretrained_wer = word_error_rate([transcript], [pretrained_result])
|
||||||
|
wav_plot_path = (
|
||||||
|
dataset_dir / Path("wav_plots") / Path(fname).with_suffix(".png")
|
||||||
|
)
|
||||||
|
if not wav_plot_path.exists():
|
||||||
|
plot_seg(wav_plot_path, audio_path)
|
||||||
|
return {
|
||||||
|
"audio_filepath": str(rel_pnr_path),
|
||||||
|
"duration": round(audio_dur, 1),
|
||||||
|
"text": transcript,
|
||||||
|
"real_idx": num_datapoints,
|
||||||
|
"audio_path": audio_path,
|
||||||
|
"spoken": transcript,
|
||||||
|
"caller": caller_name,
|
||||||
|
"utterance_id": fname,
|
||||||
|
"gcp_asr": gcp_result,
|
||||||
|
"deepgram_asr": dgram_result,
|
||||||
|
"pretrained_asr": pretrained_result,
|
||||||
|
"pretrained_wer": pretrained_wer,
|
||||||
|
"plot_path": str(wav_plot_path),
|
||||||
|
}
|
||||||
|
|
||||||
|
for transcript, audio_dur, wav_data, caller_name, aud_seg in asr_data_source:
|
||||||
|
fname = str(uuid4()) + "_" + slugify(transcript, max_length=8)
|
||||||
|
audio_file = dataset_dir / Path("wav") / Path(fname).with_suffix(".wav")
|
||||||
|
audio_file.write_bytes(wav_data)
|
||||||
|
audio_path = str(audio_file)
|
||||||
|
rel_pnr_path = audio_file.relative_to(dataset_dir)
|
||||||
|
manifest = manifest_str(str(rel_pnr_path), audio_dur, transcript)
|
||||||
|
mf.write(manifest)
|
||||||
|
data_funcs.append(
|
||||||
|
partial(
|
||||||
|
data_fn,
|
||||||
|
transcript,
|
||||||
|
audio_dur,
|
||||||
|
wav_data,
|
||||||
|
caller_name,
|
||||||
|
aud_seg,
|
||||||
|
fname,
|
||||||
|
audio_path,
|
||||||
|
num_datapoints,
|
||||||
|
rel_pnr_path,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
num_datapoints += 1
|
||||||
|
dump_data = parallel_apply(lambda x: x(), data_funcs)
|
||||||
|
# dump_data = [x() for x in tqdm(data_funcs)]
|
||||||
|
ui_dump["data"] = dump_data
|
||||||
|
ExtendedPath(ui_dump_file).write_json(ui_dump)
|
||||||
return num_datapoints
|
return num_datapoints
|
||||||
|
|
||||||
|
|
||||||
def asr_manifest_reader(data_manifest_path: Path):
    """Yield the records of a JSON-lines ASR manifest.

    Each record is the decoded JSON object of one manifest line, with two
    adjustments: ``"audio_path"`` is added as the manifest's parent directory
    joined with the record's ``"audio_filepath"``, and ``"text"`` is stripped
    of surrounding whitespace. Blank lines in the manifest are ignored.

    Args:
        data_manifest_path: path of the manifest file.

    Yields:
        One dict per non-blank manifest line, in file order.
    """
    print(f"reading manifest from {data_manifest_path}")
    manifest_dir = data_manifest_path.parent
    # Parse everything before yielding so the file is closed by the time the
    # consumer starts working with the records.
    with data_manifest_path.open("r") as pf:
        records = [json.loads(line) for line in pf if line.strip()]
    for record in records:
        record["audio_path"] = manifest_dir / Path(record["audio_filepath"])
        record["text"] = record["text"].strip()
        yield record
|
||||||
@@ -198,6 +236,12 @@ class ExtendedPath(type(Path())):
|
|||||||
with self.open("r") as jf:
|
with self.open("r") as jf:
|
||||||
return json.load(jf)
|
return json.load(jf)
|
||||||
|
|
||||||
|
def read_jsonl(self):
    """Lazily yield one decoded JSON value per non-blank line of this file.

    The file is iterated line by line rather than read up front, and stays
    open until the generator is exhausted or closed.
    """
    print(f"reading jsonl from {self}")
    with self.open("r") as jf:
        for line in jf:
            if line.strip():
                yield json.loads(line)
|
||||||
|
|
||||||
def write_json(self, data):
|
def write_json(self, data):
|
||||||
print(f"writing json to {self}")
|
print(f"writing json to {self}")
|
||||||
self.parent.mkdir(parents=True, exist_ok=True)
|
self.parent.mkdir(parents=True, exist_ok=True)
|
||||||
@@ -205,6 +249,12 @@ class ExtendedPath(type(Path())):
|
|||||||
return json.dump(data, jf, indent=2)
|
return json.dump(data, jf, indent=2)
|
||||||
|
|
||||||
|
|
||||||
|
def get_mongo_coll(uri="mongodb://localhost:27017/test.calls"):
    """Return the collection addressed by a MongoDB URI.

    The database and collection names are taken from the ``"database"`` and
    ``"collection"`` entries of ``pymongo.uri_parser.parse_uri(uri)``; the
    client is created from the same URI.
    """
    parsed = pymongo.uri_parser.parse_uri(uri)
    client = pymongo.MongoClient(uri)
    database = client[parsed["database"]]
    return database[parsed["collection"]]
|
||||||
|
|
||||||
|
|
||||||
def get_mongo_conn(host="", port=27017, db="test", col="calls"):
|
def get_mongo_conn(host="", port=27017, db="test", col="calls"):
|
||||||
mongo_host = host if host else os.environ.get("MONGO_HOST", "localhost")
|
mongo_host = host if host else os.environ.get("MONGO_HOST", "localhost")
|
||||||
mongo_uri = f"mongodb://{mongo_host}:{port}/"
|
mongo_uri = f"mongodb://{mongo_host}:{port}/"
|
||||||
@@ -230,6 +280,405 @@ def plot_seg(wav_plot_path, audio_path):
|
|||||||
fig.savefig(wav_plot_f, format="png", dpi=50)
|
fig.savefig(wav_plot_f, format="png", dpi=50)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_dates():
    """Return day/month phrases for every day number 1-31 of every month.

    Each (day, month) pair contributes six variants, in this order:
    "1st January", "January 1st", "1st of January", "January the 1st",
    "1 January", "January 1". Days vary slowest. No calendar validation is
    done, so combinations such as "31st February" are included.
    """
    month_names = [
        "January",
        "February",
        "March",
        "April",
        "May",
        "June",
        "July",
        "August",
        "September",
        "October",
        "November",
        "December",
    ]

    def ordinal(n):
        # Suffix trick from
        # https://stackoverflow.com/questions/9647202/ordinal-numbers-replacement
        # "tsnrhtdd"[k::4] is "th", "st", "nd", "rd" for k = 0, 1, 2, 3.
        takes_special_suffix = floor(n / 10) % 10 != 1 and n % 10 < 4
        start = n % 10 if takes_special_suffix else 0
        return str(n) + "tsnrhtdd"[start::4]

    phrases = []
    for day, month in product(range(1, 32), month_names):
        nth = ordinal(day)
        phrases.extend(
            (
                nth + " " + month,
                month + " " + nth,
                nth + " of " + month,
                month + " the " + nth,
                str(day) + " " + month,
                month + " " + str(day),
            )
        )
    return phrases
|
||||||
|
|
||||||
|
|
||||||
|
def get_call_logs(call_obj, s3, call_meta_dir):
    """Load a call's metadata JSON, downloading it first if not cached.

    Args:
        call_obj: mapping whose ``"DataURI"`` entry is the URI of the
            metadata file; its network location is used as the bucket and
            its path (without the leading "/") as the key.
        s3: client object exposing ``download_file(bucket, key, filename)``;
            only used when the file is not already in ``call_meta_dir``.
        call_meta_dir: directory in which the metadata file is cached under
            the final component of the URI path.

    Returns:
        The decoded JSON content of the metadata file.
    """
    meta_s3_uri = call_obj["DataURI"]
    uri_parts = urlsplit(meta_s3_uri)
    saved_meta_path = call_meta_dir / Path(Path(uri_parts.path).name)
    if not saved_meta_path.exists():
        print(f"downloading : {saved_meta_path} from {meta_s3_uri}")
        s3.download_file(uri_parts.netloc, uri_parts.path[1:], str(saved_meta_path))
    # Context manager so the handle is closed instead of leaking.
    with saved_meta_path.open() as meta_file:
        return json.load(meta_file)
|
||||||
|
|
||||||
|
|
||||||
|
def gcp_transcribe_gen():
    """Build a synchronous Google Cloud Speech transcriber.

    A ``speech_v1.SpeechClient`` and a recognition config (en-US,
    ``phone_call`` model, LINEAR16 at 16000 Hz, up to 10 alternatives, a set
    of phrase hints) are created once and captured by the returned function.

    Returns:
        A function taking the audio content and returning the "/"-joined
        transcripts of all alternatives of the first recognition result, or
        "" when the response contains no results.
    """
    from google.cloud import speech_v1
    from google.cloud.speech_v1 import enums

    client = speech_v1.SpeechClient()

    # Phrase hints passed to the recognizer as speech contexts.
    phrase_hints = [
        {
            "phrases": [
                "$OOV_CLASS_ALPHANUMERIC_SEQUENCE",
                "$OOV_CLASS_DIGIT_SEQUENCE",
                "$TIME",
                "$YEAR",
            ]
        },
        # One single-letter phrase per upper-case letter, "A" through "Z".
        {"phrases": list("ABCDEFGHIJKLMNOPQRSTUVWXYZ")},
        {
            "phrases": [
                "PNR is $OOV_CLASS_ALPHANUMERIC_SEQUENCE",
                "my PNR is $OOV_CLASS_ALPHANUMERIC_SEQUENCE",
                "my PNR number is $OOV_CLASS_ALPHANUMERIC_SEQUENCE",
                "PNR number is $OOV_CLASS_ALPHANUMERIC_SEQUENCE",
                "It's $OOV_CLASS_ALPHANUMERIC_SEQUENCE",
                "$OOV_CLASS_ALPHANUMERIC_SEQUENCE is my PNR",
            ]
        },
        {"phrases": ["my name is"]},
        {"phrases": ["Number $ORDINAL", "Numeral $ORDINAL"]},
        {
            "phrases": [
                "John Smith",
                "Carina Hu",
                "Travis Lim",
                "Marvin Tan",
                "Samuel Tan",
                "Dawn Mathew",
                "Dawn",
                "Mathew",
            ]
        },
        {
            "phrases": [
                "Beijing",
                "Tokyo",
                "London",
                "19 August",
                "7 October",
                "11 December",
                "17 September",
                "19th August",
                "7th October",
                "11th December",
                "17th September",
                "ABC123",
                "KWXUNP",
                "XLU5K1",
                "WL2JV6",
                "KBS651",
            ]
        },
        {
            "phrases": [
                "first flight",
                "second flight",
                "third flight",
                "first option",
                "second option",
                "third option",
                "first one",
                "second one",
                "third one",
            ]
        },
    ]

    config = {
        "language_code": "en-US",
        "sample_rate_hertz": 16000,
        "encoding": enums.RecognitionConfig.AudioEncoding.LINEAR16,
        "model": "phone_call",
        "enable_automatic_punctuation": True,
        "max_alternatives": 10,
        # used to detect start and end time of utterances
        "enable_word_time_offsets": True,
        "speech_contexts": phrase_hints,
        "metadata": {
            "industry_naics_code_of_audio": 481111,
            "interaction_type": enums.RecognitionMetadata.InteractionType.PHONE_CALL,
        },
    }

    def sample_recognize(content):
        """Recognize ``content`` and return the first result's transcripts.

        Args:
            content: audio payload placed in the request's ``"content"`` field.

        Returns:
            All alternatives of the first result joined by "/", or "" when
            there is no result.
        """
        response = client.recognize(config, {"content": content})
        first_result = next(iter(response.results), None)
        if first_result is None:
            return ""
        return "/".join([alt.transcript for alt in first_result.alternatives])

    return sample_recognize
|
||||||
|
|
||||||
|
|
||||||
|
def deepgram_transcribe_gen(username=None, password=None):
    """Build a function that transcribes audio bytes through Deepgram.

    Credentials are never stored in source. They are taken from the arguments
    or, when those are omitted, from the ``DEEPGRAM_USERNAME`` and
    ``DEEPGRAM_PASSWORD`` environment variables.

    Args:
        username: Deepgram account name; defaults to ``$DEEPGRAM_USERNAME``.
        password: Deepgram account password; defaults to ``$DEEPGRAM_PASSWORD``.

    Returns:
        A function taking the audio payload and returning the transcript of
        the first alternative of the first channel in the response. Requests
        declare ``linear16`` encoding at a sample rate of 8000.

    Raises:
        ValueError: if no username or password can be resolved.
    """
    deepgram_url = "https://brain.deepgram.com/v2/listen"
    q_params = {
        "model": "agara",
        "encoding": "linear16",
        "sample_rate": "8000",
        "language": "en-US",
        "multichannel": "false",
        "punctuate": "true",
    }
    url = "{}?{}".format(deepgram_url, urlencode(q_params))

    username = username or os.environ.get("DEEPGRAM_USERNAME")
    password = password or os.environ.get("DEEPGRAM_PASSWORD")
    if not (username and password):
        raise ValueError(
            "Deepgram credentials missing: pass username/password or set "
            "DEEPGRAM_USERNAME and DEEPGRAM_PASSWORD"
        )

    # The Basic-auth header never changes, so encode it once up front.
    token = base64.b64encode(
        "{}:{}".format(username, password).encode("utf-8")
    ).decode("utf-8")
    headers = {"Authorization": "Basic {}".format(token)}

    def deepgram_offline(audio_data):
        """POST ``audio_data`` to Deepgram and return the top transcript."""
        request = Request(url, method="POST", headers=headers, data=audio_data)
        with urlopen(request) as response:
            msg = json.loads(response.read())
        data = msg["results"]["channels"][0]["alternatives"][0]
        return data["transcript"]

    return deepgram_offline
|
||||||
|
|
||||||
|
|
||||||
|
class Text2Num(object):
    """Convert spelled-out numbers and ordinal words to integers.

    ``text2int`` parses cardinal phrases ("two hundred and five" -> 205);
    ``parseOrdinal`` and ``convert`` map small ordinals, cardinals and digits
    ("first", "two", "3", "last") to integers.
    """

    _UNITS = (
        "zero",
        "one",
        "two",
        "three",
        "four",
        "five",
        "six",
        "seven",
        "eight",
        "nine",
        "ten",
        "eleven",
        "twelve",
        "thirteen",
        "fourteen",
        "fifteen",
        "sixteen",
        "seventeen",
        "eighteen",
        "nineteen",
    )
    # Index i stands for i * 10; the first two slots are placeholders.
    _TENS = (
        "",
        "",
        "twenty",
        "thirty",
        "forty",
        "fifty",
        "sixty",
        "seventy",
        "eighty",
        "ninety",
    )
    _SCALES = ("hundred", "thousand", "million", "billion", "trillion")

    _ORDINALS = {
        "first": 1,
        "second": 2,
        "third": 3,
        "fourth": 4,
        "fifth": 5,
        "sixth": 6,
        "seventh": 7,
        "eighth": 8,
        "ninth": 9,
        "tenth": 10,
        "one": 1,
        "two": 2,
        "three": 3,
        "four": 4,
        "five": 5,
        "six": 6,
        "seven": 7,
        "eight": 8,
        "nine": 9,
        "ten": 10,
        "1": 1,
        "2": 2,
        "3": 3,
        "4": 4,
        "5": 5,
        "6": 6,
        "7": 7,
        "8": 8,
        "9": 9,
        "10": 10,
        "last": -1,
    }

    # Compiled once for the class instead of on every parseOrdinal call.
    _ORDINAL_PATTERN = re.compile(
        r"(\s|^)(?P<num>(first)|(third)|(fourth)|(fifth)|(sixth)|(seventh)"
        r"|(eighth)|(ninth)|(tenth)|(two)|(three)|(four)|(five)|(six)|(seven)"
        r"|(eight)|(nine)|(ten)|(1)|(2)|(3)|(4)|(5)|(6)|(7)|(8)|(9)|(10)"
        r"|(last))(\s|$)",
        re.IGNORECASE,
    )
    # "second" and "one" are only consulted when nothing above matched.
    _SECOND_PATTERN = re.compile(r"(\s|^)(?P<num>(second))(\s|$)", re.IGNORECASE)
    _ONE_PATTERN = re.compile(r"(\s|^)(?P<num>(one))(\s|$)", re.IGNORECASE)

    def __init__(self):
        # word -> (scale, increment), consumed by text2int.
        numwords = {"and": (1, 0)}
        for idx, word in enumerate(self._UNITS):
            numwords[word] = (1, idx)
        for idx, word in enumerate(self._TENS):
            if word:  # skip the placeholders so "" is not a number word
                numwords[word] = (1, idx * 10)
        for idx, word in enumerate(self._SCALES):
            # "hundred" (idx 0) is 10**2; the others are 10**(3 * idx).
            numwords[word] = (10 ** (idx * 3 or 2), 0)
        self.numwords = numwords

    def is_num(self, word):
        """Return True when ``word`` is a known cardinal number word."""
        return word in self.numwords

    def parseOrdinal(self, utterance, **kwargs):
        """Return the integer for the first ordinal-like word in ``utterance``.

        Matching is case-insensitive and requires the word to be delimited by
        whitespace or the string boundaries. "last" maps to -1.

        Returns:
            The integer value, or "" when no known word is found.
        """
        match = (
            self._ORDINAL_PATTERN.search(utterance)
            or self._SECOND_PATTERN.search(utterance)
            or self._ONE_PATTERN.search(utterance)
        )
        if match is None:
            return ""
        # Lower-case before the lookup: the patterns ignore case, the table
        # holds lower-case keys.
        return self._ORDINALS.get(match.group("num").strip().lower(), "")

    def convert(self, sent):
        """Replace each whitespace-separated token that ``parseOrdinal``
        recognizes by its integer, and re-join the tokens with single spaces."""
        converted = []
        for token in sent.split():
            number = self.parseOrdinal(token)
            converted.append(token if number == "" else str(number))
        return " ".join(converted)

    def text2int(self, textnum):
        """Parse a cardinal phrase such as "one thousand two hundred" to an int.

        Raises:
            ValueError: if a word of ``textnum`` is not a known number word.
        """
        current = result = 0
        for word in textnum.split():
            if word not in self.numwords:
                raise ValueError("Illegal word: " + word)

            scale, increment = self.numwords[word]
            current = current * scale + increment
            if scale > 100:
                # A thousand-or-larger scale closes the current group.
                result += current
                current = 0

        return result + current
|
||||||
|
|
||||||
|
|
||||||
|
def is_sub_sequence(str1, str2):
    """Return True when ``str1`` is a subsequence of ``str2``.

    Every element of ``str1`` must occur in ``str2`` in the same relative
    order, not necessarily contiguously. An empty ``str1`` is a subsequence
    of anything. The check is iterative, so input length is not limited by
    the recursion depth.
    """
    remaining = iter(str2)
    # ``ch in remaining`` consumes the iterator up to and including the first
    # match, so later characters can only match further to the right.
    return all(ch in remaining for ch in str1)
|
||||||
|
|
||||||
|
|
||||||
def parallel_apply(fn, iterable, workers=8):
|
def parallel_apply(fn, iterable, workers=8):
|
||||||
with ThreadPoolExecutor(max_workers=workers) as exe:
|
with ThreadPoolExecutor(max_workers=workers) as exe:
|
||||||
print(f"parallelly applying {fn}")
|
print(f"parallelly applying {fn}")
|
||||||
@@ -239,3 +688,12 @@ def parallel_apply(fn, iterable, workers=8):
|
|||||||
exe.map(fn, iterable), position=0, leave=True, total=len(iterable)
|
exe.map(fn, iterable), position=0, leave=True, total=len(iterable)
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Print every code produced by ``random_pnr_generator``, one per line."""
    codes = random_pnr_generator()
    for code in codes:
        print(code)


if __name__ == "__main__":
    main()
|
||||||
|
|||||||
@@ -1,10 +1,13 @@
|
|||||||
import json
|
import json
|
||||||
import shutil
|
import shutil
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
import typer
|
import typer
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
from ..utils import (
|
from ..utils import (
|
||||||
|
alnum_to_asr_tokens,
|
||||||
ExtendedPath,
|
ExtendedPath,
|
||||||
asr_manifest_reader,
|
asr_manifest_reader,
|
||||||
asr_manifest_writer,
|
asr_manifest_writer,
|
||||||
@@ -16,7 +19,9 @@ from ..utils import (
|
|||||||
app = typer.Typer()
|
app = typer.Typer()
|
||||||
|
|
||||||
|
|
||||||
def preprocess_datapoint(idx, rel_root, sample):
|
def preprocess_datapoint(
|
||||||
|
idx, rel_root, sample, use_domain_asr, annotation_only, enable_plots
|
||||||
|
):
|
||||||
from pydub import AudioSegment
|
from pydub import AudioSegment
|
||||||
from nemo.collections.asr.metrics import word_error_rate
|
from nemo.collections.asr.metrics import word_error_rate
|
||||||
from jasper.client import transcribe_gen
|
from jasper.client import transcribe_gen
|
||||||
@@ -26,23 +31,37 @@ def preprocess_datapoint(idx, rel_root, sample):
|
|||||||
res["real_idx"] = idx
|
res["real_idx"] = idx
|
||||||
audio_path = rel_root / Path(sample["audio_filepath"])
|
audio_path = rel_root / Path(sample["audio_filepath"])
|
||||||
res["audio_path"] = str(audio_path)
|
res["audio_path"] = str(audio_path)
|
||||||
|
if use_domain_asr:
|
||||||
|
res["spoken"] = alnum_to_asr_tokens(res["text"])
|
||||||
|
else:
|
||||||
|
res["spoken"] = res["text"]
|
||||||
res["utterance_id"] = audio_path.stem
|
res["utterance_id"] = audio_path.stem
|
||||||
transcriber_pretrained = transcribe_gen(asr_port=8044)
|
if not annotation_only:
|
||||||
|
transcriber_pretrained = transcribe_gen(asr_port=8044)
|
||||||
|
|
||||||
aud_seg = (
|
aud_seg = (
|
||||||
AudioSegment.from_file_using_temporary_files(audio_path)
|
AudioSegment.from_file_using_temporary_files(audio_path)
|
||||||
.set_channels(1)
|
.set_channels(1)
|
||||||
.set_sample_width(2)
|
.set_sample_width(2)
|
||||||
.set_frame_rate(24000)
|
.set_frame_rate(24000)
|
||||||
)
|
)
|
||||||
res["pretrained_asr"] = transcriber_pretrained(aud_seg.raw_data)
|
res["pretrained_asr"] = transcriber_pretrained(aud_seg.raw_data)
|
||||||
res["pretrained_wer"] = word_error_rate([res["text"]], [res["pretrained_asr"]])
|
res["pretrained_wer"] = word_error_rate(
|
||||||
wav_plot_path = (
|
[res["text"]], [res["pretrained_asr"]]
|
||||||
rel_root / Path("wav_plots") / Path(audio_path.name).with_suffix(".png")
|
)
|
||||||
)
|
if use_domain_asr:
|
||||||
if not wav_plot_path.exists():
|
transcriber_speller = transcribe_gen(asr_port=8045)
|
||||||
plot_seg(wav_plot_path, audio_path)
|
res["domain_asr"] = transcriber_speller(aud_seg.raw_data)
|
||||||
res["plot_path"] = str(wav_plot_path)
|
res["domain_wer"] = word_error_rate(
|
||||||
|
[res["spoken"]], [res["pretrained_asr"]]
|
||||||
|
)
|
||||||
|
if enable_plots:
|
||||||
|
wav_plot_path = (
|
||||||
|
rel_root / Path("wav_plots") / Path(audio_path.name).with_suffix(".png")
|
||||||
|
)
|
||||||
|
if not wav_plot_path.exists():
|
||||||
|
plot_seg(wav_plot_path, audio_path)
|
||||||
|
res["plot_path"] = str(wav_plot_path)
|
||||||
return res
|
return res
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
print(f'failed on {idx}: {sample["audio_filepath"]} with {e}')
|
print(f'failed on {idx}: {sample["audio_filepath"]} with {e}')
|
||||||
@@ -50,59 +69,70 @@ def preprocess_datapoint(idx, rel_root, sample):
|
|||||||
|
|
||||||
@app.command()
|
@app.command()
|
||||||
def dump_ui(
|
def dump_ui(
|
||||||
data_name: str = typer.Option("dataname", show_default=True),
|
data_name: str = typer.Option("call_alphanum", show_default=True),
|
||||||
dataset_dir: Path = Path("./data/asr_data"),
|
dataset_dir: Path = Path("./data/asr_data"),
|
||||||
dump_dir: Path = Path("./data/valiation_data"),
|
dump_dir: Path = Path("./data/valiation_data"),
|
||||||
dump_fname: Path = typer.Option(Path("ui_dump.json"), show_default=True),
|
dump_fname: Path = typer.Option(Path("ui_dump.json"), show_default=True),
|
||||||
|
use_domain_asr: bool = False,
|
||||||
|
annotation_only: bool = False,
|
||||||
|
enable_plots: bool = True,
|
||||||
):
|
):
|
||||||
from io import BytesIO
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from pydub import AudioSegment
|
from functools import partial
|
||||||
from ..utils import ui_data_generator
|
|
||||||
|
|
||||||
data_manifest_path = dataset_dir / Path(data_name) / Path("manifest.json")
|
data_manifest_path = dataset_dir / Path(data_name) / Path("manifest.json")
|
||||||
|
dump_path: Path = dump_dir / Path(data_name) / dump_fname
|
||||||
plot_dir = data_manifest_path.parent / Path("wav_plots")
|
plot_dir = data_manifest_path.parent / Path("wav_plots")
|
||||||
plot_dir.mkdir(parents=True, exist_ok=True)
|
plot_dir.mkdir(parents=True, exist_ok=True)
|
||||||
typer.echo(f"Using data manifest:{data_manifest_path}")
|
typer.echo(f"Using data manifest:{data_manifest_path}")
|
||||||
|
with data_manifest_path.open("r") as pf:
|
||||||
|
pnr_jsonl = pf.readlines()
|
||||||
|
pnr_funcs = [
|
||||||
|
partial(
|
||||||
|
preprocess_datapoint,
|
||||||
|
i,
|
||||||
|
data_manifest_path.parent,
|
||||||
|
json.loads(v),
|
||||||
|
use_domain_asr,
|
||||||
|
annotation_only,
|
||||||
|
enable_plots,
|
||||||
|
)
|
||||||
|
for i, v in enumerate(pnr_jsonl)
|
||||||
|
]
|
||||||
|
|
||||||
def asr_data_source_gen():
|
def exec_func(f):
|
||||||
with data_manifest_path.open("r") as pf:
|
return f()
|
||||||
data_jsonl = pf.readlines()
|
|
||||||
for v in data_jsonl:
|
|
||||||
sample = json.loads(v)
|
|
||||||
rel_root = data_manifest_path.parent
|
|
||||||
res = dict(sample)
|
|
||||||
audio_path = rel_root / Path(sample["audio_filepath"])
|
|
||||||
audio_segment = (
|
|
||||||
AudioSegment.from_file_using_temporary_files(audio_path)
|
|
||||||
.set_channels(1)
|
|
||||||
.set_sample_width(2)
|
|
||||||
.set_frame_rate(24000)
|
|
||||||
)
|
|
||||||
wav_plot_path = (
|
|
||||||
rel_root
|
|
||||||
/ Path("wav_plots")
|
|
||||||
/ Path(audio_path.name).with_suffix(".png")
|
|
||||||
)
|
|
||||||
if not wav_plot_path.exists():
|
|
||||||
plot_seg(wav_plot_path, audio_path)
|
|
||||||
res["plot_path"] = str(wav_plot_path)
|
|
||||||
code_fb = BytesIO()
|
|
||||||
audio_segment.export(code_fb, format="wav")
|
|
||||||
wav_data = code_fb.getvalue()
|
|
||||||
duration = audio_segment.duration_seconds
|
|
||||||
asr_final = res["text"]
|
|
||||||
yield asr_final, duration, wav_data, "caller", audio_segment
|
|
||||||
|
|
||||||
dump_data, num_datapoints = ui_data_generator(
|
with ThreadPoolExecutor() as exe:
|
||||||
dataset_dir, data_name, asr_data_source_gen()
|
print("starting all preprocess tasks")
|
||||||
)
|
pnr_data = filter(
|
||||||
ui_dump_file = dataset_dir / Path("ui_dump.json")
|
None,
|
||||||
ExtendedPath(ui_dump_file).write_json({"data": dump_data})
|
list(
|
||||||
|
tqdm(
|
||||||
|
exe.map(exec_func, pnr_funcs),
|
||||||
|
position=0,
|
||||||
|
leave=True,
|
||||||
|
total=len(pnr_funcs),
|
||||||
|
)
|
||||||
|
),
|
||||||
|
)
|
||||||
|
if annotation_only:
|
||||||
|
result = list(pnr_data)
|
||||||
|
else:
|
||||||
|
wer_key = "domain_wer" if use_domain_asr else "pretrained_wer"
|
||||||
|
result = sorted(pnr_data, key=lambda x: x[wer_key], reverse=True)
|
||||||
|
ui_config = {
|
||||||
|
"use_domain_asr": use_domain_asr,
|
||||||
|
"annotation_only": annotation_only,
|
||||||
|
"enable_plots": enable_plots,
|
||||||
|
"data": result,
|
||||||
|
}
|
||||||
|
ExtendedPath(dump_path).write_json(ui_config)
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
@app.command()
|
||||||
def sample_ui(
|
def sample_ui(
|
||||||
data_name: str = typer.Option("dataname", show_default=True),
|
data_name: str = typer.Option("call_upwork_train_cnd", show_default=True),
|
||||||
dump_dir: Path = Path("./data/asr_data"),
|
dump_dir: Path = Path("./data/asr_data"),
|
||||||
dump_file: Path = Path("ui_dump.json"),
|
dump_file: Path = Path("ui_dump.json"),
|
||||||
sample_count: int = typer.Option(80, show_default=True),
|
sample_count: int = typer.Option(80, show_default=True),
|
||||||
@@ -126,9 +156,50 @@ def sample_ui(
|
|||||||
ExtendedPath(sample_path).write_json(processed_data)
|
ExtendedPath(sample_path).write_json(processed_data)
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
def sample_asr_accuracy(
    data_name: str = typer.Option(
        "png_06_2020_week1_numbers_window_customer", show_default=True
    ),
    dump_dir: Path = Path("./data/asr_data"),
    sample_file: Path = Path("sample_dump.json"),
    asr_service: str = "deepgram",
):
    """Count samples whose reference text is a subsequence of the Deepgram text.

    Reads ``dump_dir / data_name / sample_file``. For every entry of its
    ``"data"`` list, the lower-cased ``"deepgram_asr"`` value is passed through
    ``Text2Num.convert`` and ``is_sub_sequence`` checks whether ``"text"`` is a
    subsequence of it. A MATCH/FAIL line is printed per sample, followed by
    the totals. ``asr_service`` is accepted but not used, and the
    ``"deepgram_asr"`` values must already be present in the sample file.
    """
    from ..utils import is_sub_sequence, Text2Num

    normalizer = Text2Num()
    sample_path = dump_dir / Path(data_name) / sample_file
    processed_data = ExtendedPath(sample_path).read_json()
    samples = processed_data["data"]
    total_samples = len(samples)
    match_count = 0
    for dp in tqdm(samples):
        gcp_num = dp["text"]
        dgm_num = normalizer.convert(dp["deepgram_asr"].lower())
        if not is_sub_sequence(gcp_num, dgm_num):
            print(f"FAIL GCP:{gcp_num}\tDGM:{dgm_num}")
            continue
        match_count += 1
        print(f"MATCH GCP:{gcp_num}\tDGM:{dgm_num}")
    typer.echo(
        f"{match_count} from deepgram matches with {total_samples} gcp transcripts."
    )
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
@app.command()
|
||||||
def task_ui(
|
def task_ui(
|
||||||
data_name: str = typer.Option("dataname", show_default=True),
|
data_name: str = typer.Option("call_upwork_train_cnd", show_default=True),
|
||||||
dump_dir: Path = Path("./data/asr_data"),
|
dump_dir: Path = Path("./data/asr_data"),
|
||||||
dump_file: Path = Path("ui_dump.json"),
|
dump_file: Path = Path("ui_dump.json"),
|
||||||
task_count: int = typer.Option(4, show_default=True),
|
task_count: int = typer.Option(4, show_default=True),
|
||||||
@@ -152,7 +223,7 @@ def task_ui(
|
|||||||
@app.command()
|
@app.command()
|
||||||
def dump_corrections(
|
def dump_corrections(
|
||||||
task_uid: str,
|
task_uid: str,
|
||||||
data_name: str = typer.Option("dataname", show_default=True),
|
data_name: str = typer.Option("call_alphanum", show_default=True),
|
||||||
dump_dir: Path = Path("./data/asr_data"),
|
dump_dir: Path = Path("./data/asr_data"),
|
||||||
dump_fname: Path = Path("corrections.json"),
|
dump_fname: Path = Path("corrections.json"),
|
||||||
):
|
):
|
||||||
@@ -170,7 +241,7 @@ def dump_corrections(
|
|||||||
@app.command()
|
@app.command()
|
||||||
def caller_quality(
|
def caller_quality(
|
||||||
task_uid: str,
|
task_uid: str,
|
||||||
data_name: str = typer.Option("dataname", show_default=True),
|
data_name: str = typer.Option("call_upwork_train_cnd", show_default=True),
|
||||||
dump_dir: Path = Path("./data/asr_data"),
|
dump_dir: Path = Path("./data/asr_data"),
|
||||||
dump_fname: Path = Path("ui_dump.json"),
|
dump_fname: Path = Path("ui_dump.json"),
|
||||||
correction_fname: Path = Path("corrections.json"),
|
correction_fname: Path = Path("corrections.json"),
|
||||||
@@ -206,7 +277,7 @@ def caller_quality(
|
|||||||
|
|
||||||
@app.command()
|
@app.command()
|
||||||
def fill_unannotated(
|
def fill_unannotated(
|
||||||
data_name: str = typer.Option("dataname", show_default=True),
|
data_name: str = typer.Option("call_alphanum", show_default=True),
|
||||||
dump_dir: Path = Path("./data/valiation_data"),
|
dump_dir: Path = Path("./data/valiation_data"),
|
||||||
dump_file: Path = Path("ui_dump.json"),
|
dump_file: Path = Path("ui_dump.json"),
|
||||||
corrections_file: Path = Path("corrections.json"),
|
corrections_file: Path = Path("corrections.json"),
|
||||||
@@ -227,9 +298,16 @@ def fill_unannotated(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ExtractionType(str, Enum):
    """Choices accepted by the ``extraction_type`` option of ``split_extract``.

    The ``str`` mix-in makes every member an instance of ``str`` carrying the
    value on its right-hand side.
    """

    date = "dates"
    city = "cities"
    name = "names"
    all = "all"
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
@app.command()
|
||||||
def split_extract(
|
def split_extract(
|
||||||
data_name: str = typer.Option("dataname", show_default=True),
|
data_name: str = typer.Option("call_alphanum", show_default=True),
|
||||||
# dest_data_name: str = typer.Option("call_aldata_namephanum_date", show_default=True),
|
# dest_data_name: str = typer.Option("call_aldata_namephanum_date", show_default=True),
|
||||||
# dump_dir: Path = Path("./data/valiation_data"),
|
# dump_dir: Path = Path("./data/valiation_data"),
|
||||||
dump_dir: Path = Path("./data/asr_data"),
|
dump_dir: Path = Path("./data/asr_data"),
|
||||||
@@ -239,7 +317,7 @@ def split_extract(
|
|||||||
conv_data_path: Path = typer.Option(
|
conv_data_path: Path = typer.Option(
|
||||||
Path("./data/conv_data.json"), show_default=True
|
Path("./data/conv_data.json"), show_default=True
|
||||||
),
|
),
|
||||||
extraction_type: str = "all",
|
extraction_type: ExtractionType = ExtractionType.all,
|
||||||
):
|
):
|
||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
@@ -303,7 +381,7 @@ def split_extract(
|
|||||||
|
|
||||||
@app.command()
|
@app.command()
|
||||||
def update_corrections(
|
def update_corrections(
|
||||||
data_name: str = typer.Option("dataname", show_default=True),
|
data_name: str = typer.Option("call_alphanum", show_default=True),
|
||||||
dump_dir: Path = Path("./data/asr_data"),
|
dump_dir: Path = Path("./data/asr_data"),
|
||||||
manifest_file: Path = Path("manifest.json"),
|
manifest_file: Path = Path("manifest.json"),
|
||||||
corrections_file: Path = Path("corrections.json"),
|
corrections_file: Path = Path("corrections.json"),
|
||||||
|
|||||||
@@ -42,9 +42,7 @@ if not hasattr(st, "mongo_connected"):
|
|||||||
upsert=True,
|
upsert=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
def set_task_fn(mf_path, task_id):
|
def set_task_fn(mf_path):
|
||||||
if task_id:
|
|
||||||
st.task_id = task_id
|
|
||||||
task_path = mf_path.parent / Path(f"task-{st.task_id}.lck")
|
task_path = mf_path.parent / Path(f"task-{st.task_id}.lck")
|
||||||
if not task_path.exists():
|
if not task_path.exists():
|
||||||
print(f"creating task lock at {task_path}")
|
print(f"creating task lock at {task_path}")
|
||||||
@@ -68,22 +66,26 @@ def load_ui_data(validation_ui_data_path: Path):
|
|||||||
|
|
||||||
|
|
||||||
@app.command()
|
@app.command()
|
||||||
def main(manifest: Path, task_id: str = ""):
|
def main(manifest: Path):
|
||||||
st.set_task(manifest, task_id)
|
st.set_task(manifest)
|
||||||
ui_config = load_ui_data(manifest)
|
ui_config = load_ui_data(manifest)
|
||||||
asr_data = ui_config["data"]
|
asr_data = ui_config["data"]
|
||||||
|
use_domain_asr = ui_config.get("use_domain_asr", True)
|
||||||
annotation_only = ui_config.get("annotation_only", False)
|
annotation_only = ui_config.get("annotation_only", False)
|
||||||
|
enable_plots = ui_config.get("enable_plots", True)
|
||||||
sample_no = st.get_current_cursor()
|
sample_no = st.get_current_cursor()
|
||||||
if len(asr_data) - 1 < sample_no or sample_no < 0:
|
if len(asr_data) - 1 < sample_no or sample_no < 0:
|
||||||
print("Invalid samplno resetting to 0")
|
print("Invalid samplno resetting to 0")
|
||||||
st.update_cursor(0)
|
st.update_cursor(0)
|
||||||
sample = asr_data[sample_no]
|
sample = asr_data[sample_no]
|
||||||
|
title_type = "Speller " if use_domain_asr else ""
|
||||||
task_uid = st.task_id.rsplit("-", 1)[1]
|
task_uid = st.task_id.rsplit("-", 1)[1]
|
||||||
if annotation_only:
|
if annotation_only:
|
||||||
st.title(f"ASR Annotation - # {task_uid}")
|
st.title(f"ASR Annotation - # {task_uid}")
|
||||||
else:
|
else:
|
||||||
st.title(f"ASR Validation - # {task_uid}")
|
st.title(f"ASR {title_type}Validation - # {task_uid}")
|
||||||
st.markdown(f"{sample_no+1} of {len(asr_data)} : **{sample['text']}**")
|
addl_text = f"spelled *{sample['spoken']}*" if use_domain_asr else ""
|
||||||
|
st.markdown(f"{sample_no+1} of {len(asr_data)} : **{sample['text']}**" + addl_text)
|
||||||
new_sample = st.number_input(
|
new_sample = st.number_input(
|
||||||
"Go To Sample:", value=sample_no + 1, min_value=1, max_value=len(asr_data)
|
"Go To Sample:", value=sample_no + 1, min_value=1, max_value=len(asr_data)
|
||||||
)
|
)
|
||||||
@@ -92,13 +94,19 @@ def main(manifest: Path, task_id: str = ""):
|
|||||||
st.sidebar.title(f"Details: [{sample['real_idx']}]")
|
st.sidebar.title(f"Details: [{sample['real_idx']}]")
|
||||||
st.sidebar.markdown(f"Gold Text: **{sample['text']}**")
|
st.sidebar.markdown(f"Gold Text: **{sample['text']}**")
|
||||||
if not annotation_only:
|
if not annotation_only:
|
||||||
|
if use_domain_asr:
|
||||||
|
st.sidebar.markdown(f"Expected Spelled: *{sample['spoken']}*")
|
||||||
st.sidebar.title("Results:")
|
st.sidebar.title("Results:")
|
||||||
st.sidebar.markdown(f"Pretrained: **{sample['pretrained_asr']}**")
|
st.sidebar.markdown(f"Pretrained: **{sample['pretrained_asr']}**")
|
||||||
if "caller" in sample:
|
if "caller" in sample:
|
||||||
st.sidebar.markdown(f"Caller: **{sample['caller']}**")
|
st.sidebar.markdown(f"Caller: **{sample['caller']}**")
|
||||||
|
if use_domain_asr:
|
||||||
|
st.sidebar.markdown(f"Domain: **{sample['domain_asr']}**")
|
||||||
|
st.sidebar.title(f"Speller WER: {sample['domain_wer']:.2f}%")
|
||||||
else:
|
else:
|
||||||
st.sidebar.title(f"Pretrained WER: {sample['pretrained_wer']:.2f}%")
|
st.sidebar.title(f"Pretrained WER: {sample['pretrained_wer']:.2f}%")
|
||||||
st.sidebar.image(Path(sample["plot_path"]).read_bytes())
|
if enable_plots:
|
||||||
|
st.sidebar.image(Path(sample["plot_path"]).read_bytes())
|
||||||
st.audio(Path(sample["audio_path"]).open("rb"))
|
st.audio(Path(sample["audio_path"]).open("rb"))
|
||||||
# set default to text
|
# set default to text
|
||||||
corrected = sample["text"]
|
corrected = sample["text"]
|
||||||
@@ -120,12 +128,16 @@ def main(manifest: Path, task_id: str = ""):
|
|||||||
)
|
)
|
||||||
st.update_cursor(sample_no + 1)
|
st.update_cursor(sample_no + 1)
|
||||||
if correction_entry:
|
if correction_entry:
|
||||||
status = correction_entry["value"]["status"]
|
st.markdown(
|
||||||
correction = correction_entry["value"]["correction"]
|
f'Your Response: **{correction_entry["value"]["status"]}** Correction: **{correction_entry["value"]["correction"]}**'
|
||||||
st.markdown(f"Your Response: **{status}** Correction: **{correction}**")
|
)
|
||||||
text_sample = st.text_input("Go to Text:", value="")
|
text_sample = st.text_input("Go to Text:", value="")
|
||||||
if text_sample != "":
|
if text_sample != "":
|
||||||
candidates = [i for (i, p) in enumerate(asr_data) if p["text"] == text_sample]
|
candidates = [
|
||||||
|
i
|
||||||
|
for (i, p) in enumerate(asr_data)
|
||||||
|
if p["text"] == text_sample or p["spoken"] == text_sample
|
||||||
|
]
|
||||||
if len(candidates) > 0:
|
if len(candidates) > 0:
|
||||||
st.update_cursor(candidates[0])
|
st.update_cursor(candidates[0])
|
||||||
real_idx = st.number_input(
|
real_idx = st.number_input(
|
||||||
|
|||||||
58
manifest_preview.py
Normal file
58
manifest_preview.py
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import streamlit as st
|
||||||
|
import typer
|
||||||
|
from jasper.data.utils import ExtendedPath
|
||||||
|
from jasper.data.validation.st_rerun import rerun
|
||||||
|
|
||||||
|
app = typer.Typer()
|
||||||
|
|
||||||
|
if not hasattr(st, "mongo_connected"):
    # One-time setup: the `mongo_connected` attribute set on `st` below makes
    # this branch a no-op once it has run.
    # NOTE(review): nesting reconstructed from a flattened listing; confirm
    # that every statement down to `update_cursor_fn(0)` sits inside this guard.
    # st.task_id = str(uuid4())
    # The cursor is persisted in `preview.lck` through read_json/write_json.
    task_path = ExtendedPath("preview.lck")

    def current_cursor_fn():
        """Return the ``current_cursor`` value stored in the lock file."""
        return task_path.read_json()["current_cursor"]

    def update_cursor_fn(val=0):
        """Store ``val`` as the cursor in the lock file, then call ``rerun()``."""
        task_path.write_json({"current_cursor": val})
        rerun()

    # Expose the helpers as attributes of `st`; `main` calls them from there.
    st.get_current_cursor = current_cursor_fn
    st.update_cursor = update_cursor_fn
    st.mongo_connected = True
    # cursor_obj = mongo_conn.find_one({"type": "current_cursor", "task_id": st.task_id})
    # if not cursor_obj:
    # Start from the first sample: writes cursor 0 to the lock file.
    update_cursor_fn(0)
|
||||||
|
|
||||||
|
|
||||||
|
@st.cache()
def load_ui_data(validation_ui_data_path: Path):
    """Read the JSON-lines file at ``validation_ui_data_path`` into a list.

    The path in use is echoed before reading. The function is wrapped in
    ``st.cache()``.
    """
    typer.echo(f"Using validation ui data from {validation_ui_data_path}")
    manifest_path = ExtendedPath(validation_ui_data_path)
    records = manifest_path.read_jsonl()
    return list(records)
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
def main(manifest: Path):
    """Show one manifest sample at a time in a Streamlit page.

    Loads the manifest through ``load_ui_data``, displays the text and audio
    of the sample at the stored cursor, and offers a number input for jumping
    to another sample.

    Args:
        manifest: Path of the manifest file handed to ``load_ui_data``.
    """
    asr_data = load_ui_data(manifest)
    if not asr_data:
        # With no samples the bounds check below can never pass, so the
        # cursor would be reset on every run; report the problem and stop.
        st.error(f"No samples found in {manifest}")
        return
    sample_no = st.get_current_cursor()
    if len(asr_data) - 1 < sample_no or sample_no < 0:
        print("Invalid samplno resetting to 0")
        st.update_cursor(0)
        # Keep the index valid in case update_cursor returns instead of
        # restarting the script.
        sample_no = 0
    sample = asr_data[sample_no]
    st.title("ASR Manifest Preview")
    st.markdown(f"{sample_no+1} of {len(asr_data)} : **{sample['text']}**")
    new_sample = st.number_input(
        "Go To Sample:", value=sample_no + 1, min_value=1, max_value=len(asr_data)
    )
    if new_sample != sample_no + 1:
        # The widget is 1-based while the stored cursor is 0-based.
        st.update_cursor(new_sample - 1)
    st.sidebar.markdown(f"Gold Text: **{sample['text']}**")
    # read_bytes() opens and closes the file itself, so no handle is left open.
    st.audio(Path(sample["audio_filepath"]).read_bytes())
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Run the Typer app and swallow the SystemExit it raises, so the exit
    # does not propagate out of this script.
    # NOTE(review): presumably needed because the script is run by Streamlit
    # rather than as a plain CLI; confirm.
    try:
        app()
    except SystemExit:
        pass
|
||||||
11
setup.py
11
setup.py
@@ -2,7 +2,7 @@ from setuptools import setup, find_packages
|
|||||||
|
|
||||||
requirements = [
|
requirements = [
|
||||||
"ruamel.yaml",
|
"ruamel.yaml",
|
||||||
"torch==2.8.0",
|
"torch==1.4.0",
|
||||||
"torchvision==0.5.0",
|
"torchvision==0.5.0",
|
||||||
"nemo_toolkit[asr] @ git+https://github.com/NVIDIA/NeMo.git@09e3ba4dfe333f86d6c5c1048e07210924294be9#egg=nemo_toolkit",
|
"nemo_toolkit[asr] @ git+https://github.com/NVIDIA/NeMo.git@09e3ba4dfe333f86d6c5c1048e07210924294be9#egg=nemo_toolkit",
|
||||||
]
|
]
|
||||||
@@ -19,15 +19,13 @@ extra_requirements = {
|
|||||||
"ruamel.yaml==0.16.10",
|
"ruamel.yaml==0.16.10",
|
||||||
"pymongo==3.10.1",
|
"pymongo==3.10.1",
|
||||||
"librosa==0.7.2",
|
"librosa==0.7.2",
|
||||||
"numba==0.48",
|
|
||||||
"matplotlib==3.2.1",
|
"matplotlib==3.2.1",
|
||||||
"pandas==1.0.3",
|
"pandas==1.0.3",
|
||||||
"tabulate==0.8.7",
|
"tabulate==0.8.7",
|
||||||
"natural==0.2.0",
|
"natural==0.2.0",
|
||||||
"num2words==0.5.10",
|
"num2words==0.5.10",
|
||||||
"typer[all]==0.3.1",
|
"typer[all]==0.1.1",
|
||||||
"python-slugify==4.0.0",
|
"python-slugify==4.0.0",
|
||||||
"rpyc~=4.1.4",
|
|
||||||
"lenses @ git+https://github.com/ingolemo/python-lenses.git@b2a2a9aa5b61540992d70b2cf36008d0121e8948#egg=lenses",
|
"lenses @ git+https://github.com/ingolemo/python-lenses.git@b2a2a9aa5b61540992d70b2cf36008d0121e8948#egg=lenses",
|
||||||
],
|
],
|
||||||
"validation": [
|
"validation": [
|
||||||
@@ -70,7 +68,10 @@ setup(
|
|||||||
"jasper_data_tts_generate = jasper.data.tts_generator:main",
|
"jasper_data_tts_generate = jasper.data.tts_generator:main",
|
||||||
"jasper_data_conv_generate = jasper.data.conv_generator:main",
|
"jasper_data_conv_generate = jasper.data.conv_generator:main",
|
||||||
"jasper_data_nlu_generate = jasper.data.nlu_generator:main",
|
"jasper_data_nlu_generate = jasper.data.nlu_generator:main",
|
||||||
"jasper_data_rastrik_recycle = jasper.data.rastrik_recycler:main",
|
"jasper_data_test_generate = jasper.data.test_generator:main",
|
||||||
|
"jasper_data_call_recycle = jasper.data.call_recycler:main",
|
||||||
|
"jasper_data_asr_recycle = jasper.data.asr_recycler:main",
|
||||||
|
"jasper_data_rev_recycle = jasper.data.rev_recycler:main",
|
||||||
"jasper_data_server = jasper.data.server:main",
|
"jasper_data_server = jasper.data.server:main",
|
||||||
"jasper_data_validation = jasper.data.validation.process:main",
|
"jasper_data_validation = jasper.data.validation.process:main",
|
||||||
"jasper_data_preprocess = jasper.data.process:main",
|
"jasper_data_preprocess = jasper.data.process:main",
|
||||||
|
|||||||
Reference in New Issue
Block a user