1. added a test generator and slu evaluator
2. ui dump now include gcp results 3. showing default option for more args validation process commands
parent
515e9c1037
commit
069392d098
|
|
@ -147,7 +147,7 @@ def analyze(
|
|||
import matplotlib.pyplot as plt
|
||||
import matplotlib
|
||||
from tqdm import tqdm
|
||||
from .utils import ui_dump_manifest_writer, strip_silence, get_mongo_coll
|
||||
from .utils import ui_dump_manifest_writer, strip_silence, get_mongo_coll, get_call_logs
|
||||
from pydub import AudioSegment
|
||||
from natural.date import compress
|
||||
|
||||
|
|
@ -170,18 +170,6 @@ def analyze(
|
|||
|
||||
call_logs = yaml.load(call_logs_file.read_text())
|
||||
|
||||
def get_call_meta(call_obj):
|
||||
meta_s3_uri = call_obj["DataURI"]
|
||||
s3_event_url_p = urlsplit(meta_s3_uri)
|
||||
saved_meta_path = call_meta_dir / Path(Path(s3_event_url_p.path).name)
|
||||
if not saved_meta_path.exists():
|
||||
print(f"downloading : {saved_meta_path} from {meta_s3_uri}")
|
||||
s3.download_file(
|
||||
s3_event_url_p.netloc, s3_event_url_p.path[1:], str(saved_meta_path)
|
||||
)
|
||||
call_metas = json.load(saved_meta_path.open())
|
||||
return call_metas
|
||||
|
||||
def gen_ev_fev_timedelta(fev):
|
||||
fev_p = Timestamp()
|
||||
fev_p.FromJsonString(fev["CreatedTS"])
|
||||
|
|
@ -283,7 +271,7 @@ def analyze(
|
|||
return spoken
|
||||
|
||||
def process_call(call_obj):
|
||||
call_meta = get_call_meta(call_obj)
|
||||
call_meta = get_call_logs(call_obj, s3, call_meta_dir)
|
||||
call_events = call_meta["Events"]
|
||||
|
||||
def is_writer_uri_event(ev):
|
||||
|
|
|
|||
|
|
@ -0,0 +1,180 @@
|
|||
import typer
|
||||
from pathlib import Path
|
||||
import json
|
||||
|
||||
# from .utils import generate_dates, asr_test_writer
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
def run_test(reg_path, coll, s3, call_meta_dir, city_code, test_path):
|
||||
from time import sleep
|
||||
import subprocess
|
||||
from .utils import ExtendedPath, get_call_logs
|
||||
|
||||
coll.delete_many({"CallID": test_path.name})
|
||||
# test_path = dump_dir / data_name / test_file
|
||||
# "../saas_reg/regression/run.sh -f data/asr_data/call_upwork_test_cnd_cities/asr_test.reg"
|
||||
test_output = subprocess.run(
|
||||
["/bin/bash", "-c", f"{str(reg_path)} --addr [::]:15400 -f {str(test_path)}"]
|
||||
)
|
||||
if test_output.returncode != 0:
|
||||
print("Error running test {test_file}")
|
||||
return
|
||||
|
||||
def get_meta():
|
||||
call_meta = coll.find_one({"CallID": test_path.name})
|
||||
if call_meta:
|
||||
return call_meta
|
||||
else:
|
||||
sleep(2)
|
||||
return get_meta()
|
||||
|
||||
call_meta = get_meta()
|
||||
call_logs = get_call_logs(call_meta, s3, call_meta_dir)
|
||||
call_events = call_logs["Events"]
|
||||
|
||||
test_data_path = test_path.with_suffix(".result.json")
|
||||
test_data = ExtendedPath(test_data_path).read_json()
|
||||
|
||||
def is_final_asr_event_or_spoken(ev):
|
||||
pld = json.loads(ev["Payload"])
|
||||
return (
|
||||
pld["AsrResult"]["Results"][0]["IsFinal"]
|
||||
if ev["Type"] == "ASR_RESULT"
|
||||
else False
|
||||
)
|
||||
|
||||
def is_test_event(ev):
|
||||
return (
|
||||
ev["Author"] == "NLU"
|
||||
or (ev["Author"] == "ASR" and is_final_asr_event_or_spoken(ev))
|
||||
) and (ev["Type"] != "DEBUG")
|
||||
|
||||
test_evs = list(filter(is_test_event, call_events))
|
||||
if len(test_evs) == 2:
|
||||
try:
|
||||
asr_payload = test_evs[0]["Payload"]
|
||||
asr_result = json.loads(asr_payload)["AsrResult"]["Results"][0]
|
||||
alt_tscripts = [alt["Transcript"] for alt in asr_result["Alternatives"]]
|
||||
gcp_result = "|".join(alt_tscripts)
|
||||
entity_asr = asr_result["AsrDynamicResults"][0]["Candidate"]["Transcript"]
|
||||
nlu_payload = test_evs[1]["Payload"]
|
||||
nlu_result_payload = json.loads(nlu_payload)["NluResults"]
|
||||
entity = test_data[0]["entity"]
|
||||
text = test_data[0]["text"]
|
||||
audio_filepath = test_data[0]["audio_filepath"]
|
||||
pretrained_asr = test_data[0]["pretrained_asr"]
|
||||
nlu_entity = list(json.loads(nlu_result_payload)["Entities"].values())[0]
|
||||
asr_entity = city_code[entity] if entity in city_code else "UNKNOWN"
|
||||
entities_match = asr_entity == nlu_entity
|
||||
result = "Success" if entities_match else "Fail"
|
||||
return {
|
||||
"expected_entity": entity,
|
||||
"text": text,
|
||||
"audio_filepath": audio_filepath,
|
||||
"pretrained_asr": pretrained_asr,
|
||||
"entity_asr": entity_asr,
|
||||
"google_asr": gcp_result,
|
||||
"nlu_result": nlu_result_payload,
|
||||
"asr_entity": asr_entity,
|
||||
"nlu_entity": nlu_entity,
|
||||
"result": result,
|
||||
}
|
||||
except Exception:
|
||||
return {
|
||||
"expected_entity": test_data[0]["entity"],
|
||||
"text": test_data[0]["text"],
|
||||
"audio_filepath": test_data[0]["audio_filepath"],
|
||||
"pretrained_asr": test_data[0]["pretrained_asr"],
|
||||
"entity_asr": "",
|
||||
"google_asr": "",
|
||||
"nlu_result": "",
|
||||
"asr_entity": "",
|
||||
"nlu_entity": "",
|
||||
"result": "Error",
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"expected_entity": test_data[0]["entity"],
|
||||
"text": test_data[0]["text"],
|
||||
"audio_filepath": test_data[0]["audio_filepath"],
|
||||
"pretrained_asr": test_data[0]["pretrained_asr"],
|
||||
"entity_asr": "",
|
||||
"google_asr": "",
|
||||
"nlu_result": "",
|
||||
"asr_entity": "",
|
||||
"nlu_entity": "",
|
||||
"result": "Empty",
|
||||
}
|
||||
|
||||
|
||||
@app.command()
|
||||
def evaluate_slu(
|
||||
# conv_src: Path = typer.Option(Path("./conv_data.json"), show_default=True),
|
||||
data_name: str = typer.Option("call_upwork_test_cnd_cities", show_default=True),
|
||||
# extraction_key: str = "Cities",
|
||||
dump_dir: Path = Path("./data/asr_data"),
|
||||
call_meta_dir: Path = Path("./data/call_metas"),
|
||||
test_file_pref: str = "asr_test",
|
||||
mongo_uri: str = typer.Option(
|
||||
"mongodb://localhost:27017/test.calls", show_default=True
|
||||
),
|
||||
test_results: Path = Path("./data/results.csv"),
|
||||
airport_codes: Path = Path("./airports_code.csv"),
|
||||
reg_path: Path = Path("../saas_reg/regression/run.sh"),
|
||||
test_id: str = "5ef481f27031edf6910e94e0",
|
||||
):
|
||||
# import json
|
||||
from .utils import get_mongo_coll
|
||||
import pandas as pd
|
||||
import boto3
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from functools import partial
|
||||
|
||||
# import subprocess
|
||||
# from time import sleep
|
||||
import csv
|
||||
from tqdm import tqdm
|
||||
|
||||
s3 = boto3.client("s3")
|
||||
df = pd.read_csv(airport_codes)[["iata", "city"]]
|
||||
city_code = pd.Series(df["iata"].values, index=df["city"]).to_dict()
|
||||
|
||||
test_files = list((dump_dir / data_name).glob(test_file_pref + "*.reg"))
|
||||
coll = get_mongo_coll(mongo_uri)
|
||||
with test_results.open("w") as csvfile:
|
||||
fieldnames = [
|
||||
"expected_entity",
|
||||
"text",
|
||||
"audio_filepath",
|
||||
"pretrained_asr",
|
||||
"entity_asr",
|
||||
"google_asr",
|
||||
"nlu_result",
|
||||
"asr_entity",
|
||||
"nlu_entity",
|
||||
"result",
|
||||
]
|
||||
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
|
||||
writer.writeheader()
|
||||
with ThreadPoolExecutor(max_workers=8) as exe:
|
||||
print("starting all loading tasks")
|
||||
for test_result in tqdm(
|
||||
exe.map(
|
||||
partial(run_test, reg_path, coll, s3, call_meta_dir, city_code),
|
||||
test_files,
|
||||
),
|
||||
position=0,
|
||||
leave=True,
|
||||
total=len(test_files),
|
||||
):
|
||||
writer.writerow(test_result)
|
||||
|
||||
|
||||
def main():
|
||||
app()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -0,0 +1,99 @@
|
|||
import typer
|
||||
from pathlib import Path
|
||||
from .utils import generate_dates, asr_test_writer
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
|
||||
def export_test_reg(
|
||||
conv_src: Path = typer.Option(Path("./conv_data.json"), show_default=True),
|
||||
data_name: str = typer.Option("call_upwork_test_cnd_cities", show_default=True),
|
||||
extraction_key: str = "Cities",
|
||||
dump_dir: Path = Path("./data/asr_data"),
|
||||
dump_file: Path = Path("ui_dump.json"),
|
||||
manifest_file: Path = Path("manifest.json"),
|
||||
test_file: Path = Path("asr_test.reg"),
|
||||
):
|
||||
from .utils import (
|
||||
ExtendedPath,
|
||||
asr_manifest_reader,
|
||||
gcp_transcribe_gen,
|
||||
parallel_apply,
|
||||
)
|
||||
from ..client import transcribe_gen
|
||||
from pydub import AudioSegment
|
||||
from queue import PriorityQueue
|
||||
|
||||
jasper_map = {
|
||||
"PNRs": 8045,
|
||||
"Cities": 8046,
|
||||
"Names": 8047,
|
||||
"Dates": 8048,
|
||||
}
|
||||
# jasper_map = {"PNRs": 8050, "Cities": 8050, "Names": 8050, "Dates": 8050}
|
||||
transcriber_gcp = gcp_transcribe_gen()
|
||||
transcriber_trained = transcribe_gen(asr_port=jasper_map[extraction_key])
|
||||
transcriber_all_trained = transcribe_gen(asr_port=8050)
|
||||
transcriber_libri_all_trained = transcribe_gen(asr_port=8051)
|
||||
|
||||
def find_ent(dd, conv_data):
|
||||
ents = PriorityQueue()
|
||||
for ent in conv_data:
|
||||
if ent in dd["text"]:
|
||||
ents.put((-len(ent), ent))
|
||||
return ents.get_nowait()[1]
|
||||
|
||||
def process_data(d):
|
||||
orig_seg = AudioSegment.from_wav(d["audio_path"])
|
||||
jas_seg = orig_seg.set_channels(1).set_sample_width(2).set_frame_rate(24000)
|
||||
gcp_seg = orig_seg.set_channels(1).set_sample_width(2).set_frame_rate(16000)
|
||||
deepgram_file = Path("/home/shubham/voice_auto/pnrs/wav/") / Path(
|
||||
d["audio_path"].stem + ".txt"
|
||||
)
|
||||
if deepgram_file.exists():
|
||||
d["deepgram"] = "".join(
|
||||
[s.replace("CHANNEL 0:", "") for s in deepgram_file.read_text().split("\n")]
|
||||
)
|
||||
else:
|
||||
d["deepgram"] = 'Not Found'
|
||||
d["audio_path"] = str(d["audio_path"])
|
||||
d["gcp_transcript"] = transcriber_gcp(gcp_seg.raw_data)
|
||||
d["jasper_trained"] = transcriber_trained(jas_seg.raw_data)
|
||||
d["jasper_all"] = transcriber_all_trained(jas_seg.raw_data)
|
||||
d["jasper_libri"] = transcriber_libri_all_trained(jas_seg.raw_data)
|
||||
return d
|
||||
|
||||
conv_data = ExtendedPath(conv_src).read_json()
|
||||
conv_data["Dates"] = generate_dates()
|
||||
|
||||
dump_data_path = dump_dir / Path(data_name) / dump_file
|
||||
ui_dump_data = ExtendedPath(dump_data_path).read_json()["data"]
|
||||
ui_dump_map = {i["utterance_id"]: i for i in ui_dump_data}
|
||||
manifest_path = dump_dir / Path(data_name) / manifest_file
|
||||
test_points = list(asr_manifest_reader(manifest_path))
|
||||
test_data_objs = [{**(ui_dump_map[t["audio_path"].stem]), **t} for t in test_points]
|
||||
test_data = parallel_apply(process_data, test_data_objs)
|
||||
# test_data = [process_data(t) for t in test_data_objs]
|
||||
test_path = dump_dir / Path(data_name) / test_file
|
||||
|
||||
def dd_gen(dump_data):
|
||||
for dd in dump_data:
|
||||
ent = find_ent(dd, conv_data[extraction_key])
|
||||
dd["entity"] = ent
|
||||
if ent:
|
||||
yield dd
|
||||
|
||||
asr_test_writer(test_path, dd_gen(test_data))
|
||||
# for i, b in enumerate(batch(test_data, 1)):
|
||||
# test_fname = Path(f"{test_file.stem}_{i}.reg")
|
||||
# test_path = dump_dir / Path(data_name) / test_fname
|
||||
# asr_test_writer(test_path, dd_gen(test_data))
|
||||
|
||||
|
||||
def main():
|
||||
app()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -7,6 +7,7 @@ from itertools import product
|
|||
from functools import partial
|
||||
from math import floor
|
||||
from uuid import uuid4
|
||||
from urllib.parse import urlsplit
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
import numpy as np
|
||||
|
|
@ -99,6 +100,7 @@ def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=F
|
|||
"data": [],
|
||||
}
|
||||
data_funcs = []
|
||||
transcriber_gcp = gcp_transcribe_gen()
|
||||
transcriber_pretrained = transcribe_gen(asr_port=8044)
|
||||
with asr_manifest.open("w") as mf:
|
||||
print(f"writing manifest to {asr_manifest}")
|
||||
|
|
@ -115,6 +117,8 @@ def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=F
|
|||
rel_pnr_path,
|
||||
):
|
||||
pretrained_result = transcriber_pretrained(aud_seg.raw_data)
|
||||
gcp_seg = aud_seg.set_frame_rate(16000)
|
||||
gcp_result = transcriber_gcp(gcp_seg.raw_data)
|
||||
pretrained_wer = word_error_rate([transcript], [pretrained_result])
|
||||
wav_plot_path = (
|
||||
dataset_dir / Path("wav_plots") / Path(fname).with_suffix(".png")
|
||||
|
|
@ -130,6 +134,7 @@ def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=F
|
|||
"spoken": transcript,
|
||||
"caller": caller_name,
|
||||
"utterance_id": fname,
|
||||
"gcp_asr": gcp_result,
|
||||
"pretrained_asr": pretrained_result,
|
||||
"pretrained_wer": pretrained_wer,
|
||||
"plot_path": str(wav_plot_path),
|
||||
|
|
@ -194,6 +199,32 @@ def asr_manifest_writer(asr_manifest_path: Path, manifest_str_source):
|
|||
mf.write(manifest)
|
||||
|
||||
|
||||
def asr_test_writer(out_file_path: Path, source):
|
||||
def dd_str(dd, idx):
|
||||
path = dd["audio_filepath"]
|
||||
# dur = dd["duration"]
|
||||
# return f"SAY {idx}\nPAUSE 3\nPLAY {path}\nPAUSE 3\n\n"
|
||||
return f"PAUSE 2\nPLAY {path}\nPAUSE 60\n\n"
|
||||
|
||||
res_file = out_file_path.with_suffix(".result.json")
|
||||
with out_file_path.open("w") as of:
|
||||
print(f"opening {out_file_path} for writing test")
|
||||
results = []
|
||||
idx = 0
|
||||
for ui_dd in source:
|
||||
results.append(ui_dd)
|
||||
out_str = dd_str(ui_dd, idx)
|
||||
of.write(out_str)
|
||||
idx += 1
|
||||
of.write("DO_HANGUP\n")
|
||||
ExtendedPath(res_file).write_json(results)
|
||||
|
||||
|
||||
def batch(iterable, n=1):
|
||||
ls = len(iterable)
|
||||
return [iterable[ndx : min(ndx + n, ls)] for ndx in range(0, ls, n)]
|
||||
|
||||
|
||||
class ExtendedPath(type(Path())):
|
||||
"""docstring for ExtendedPath."""
|
||||
|
||||
|
|
@ -278,6 +309,181 @@ def generate_dates():
|
|||
return [dm for d, m in product(days, months) for dm in canon_vars(d, m)]
|
||||
|
||||
|
||||
def get_call_logs(call_obj, s3, call_meta_dir):
|
||||
meta_s3_uri = call_obj["DataURI"]
|
||||
s3_event_url_p = urlsplit(meta_s3_uri)
|
||||
saved_meta_path = call_meta_dir / Path(Path(s3_event_url_p.path).name)
|
||||
if not saved_meta_path.exists():
|
||||
print(f"downloading : {saved_meta_path} from {meta_s3_uri}")
|
||||
s3.download_file(
|
||||
s3_event_url_p.netloc, s3_event_url_p.path[1:], str(saved_meta_path)
|
||||
)
|
||||
call_metas = json.load(saved_meta_path.open())
|
||||
return call_metas
|
||||
|
||||
|
||||
def gcp_transcribe_gen():
|
||||
from google.cloud import speech_v1
|
||||
from google.cloud.speech_v1 import enums
|
||||
|
||||
# import io
|
||||
client = speech_v1.SpeechClient()
|
||||
# local_file_path = 'resources/brooklyn_bridge.raw'
|
||||
|
||||
# The language of the supplied audio
|
||||
language_code = "en-US"
|
||||
model = "phone_call"
|
||||
|
||||
# Sample rate in Hertz of the audio data sent
|
||||
sample_rate_hertz = 16000
|
||||
|
||||
# Encoding of audio data sent. This sample sets this explicitly.
|
||||
# This field is optional for FLAC and WAV audio formats.
|
||||
encoding = enums.RecognitionConfig.AudioEncoding.LINEAR16
|
||||
config = {
|
||||
"language_code": language_code,
|
||||
"sample_rate_hertz": sample_rate_hertz,
|
||||
"encoding": encoding,
|
||||
"model": model,
|
||||
"enable_automatic_punctuation": True,
|
||||
"max_alternatives": 10,
|
||||
"enable_word_time_offsets": True, # used to detect start and end time of utterances
|
||||
"speech_contexts": [
|
||||
{
|
||||
"phrases": [
|
||||
"$OOV_CLASS_ALPHANUMERIC_SEQUENCE",
|
||||
"$OOV_CLASS_DIGIT_SEQUENCE",
|
||||
"$TIME",
|
||||
"$YEAR",
|
||||
]
|
||||
},
|
||||
{
|
||||
"phrases": [
|
||||
"A",
|
||||
"B",
|
||||
"C",
|
||||
"D",
|
||||
"E",
|
||||
"F",
|
||||
"G",
|
||||
"H",
|
||||
"I",
|
||||
"J",
|
||||
"K",
|
||||
"L",
|
||||
"M",
|
||||
"N",
|
||||
"O",
|
||||
"P",
|
||||
"Q",
|
||||
"R",
|
||||
"S",
|
||||
"T",
|
||||
"U",
|
||||
"V",
|
||||
"W",
|
||||
"X",
|
||||
"Y",
|
||||
"Z",
|
||||
]
|
||||
},
|
||||
{
|
||||
"phrases": [
|
||||
"PNR is $OOV_CLASS_ALPHANUMERIC_SEQUENCE",
|
||||
"my PNR is $OOV_CLASS_ALPHANUMERIC_SEQUENCE",
|
||||
"my PNR number is $OOV_CLASS_ALPHANUMERIC_SEQUENCE",
|
||||
"PNR number is $OOV_CLASS_ALPHANUMERIC_SEQUENCE",
|
||||
"It's $OOV_CLASS_ALPHANUMERIC_SEQUENCE",
|
||||
"$OOV_CLASS_ALPHANUMERIC_SEQUENCE is my PNR",
|
||||
]
|
||||
},
|
||||
{"phrases": ["my name is"]},
|
||||
{"phrases": ["Number $ORDINAL", "Numeral $ORDINAL"]},
|
||||
{
|
||||
"phrases": [
|
||||
"John Smith",
|
||||
"Carina Hu",
|
||||
"Travis Lim",
|
||||
"Marvin Tan",
|
||||
"Samuel Tan",
|
||||
"Dawn Mathew",
|
||||
"Dawn",
|
||||
"Mathew",
|
||||
]
|
||||
},
|
||||
{
|
||||
"phrases": [
|
||||
"Beijing",
|
||||
"Tokyo",
|
||||
"London",
|
||||
"19 August",
|
||||
"7 October",
|
||||
"11 December",
|
||||
"17 September",
|
||||
"19th August",
|
||||
"7th October",
|
||||
"11th December",
|
||||
"17th September",
|
||||
"ABC123",
|
||||
"KWXUNP",
|
||||
"XLU5K1",
|
||||
"WL2JV6",
|
||||
"KBS651",
|
||||
]
|
||||
},
|
||||
{
|
||||
"phrases": [
|
||||
"first flight",
|
||||
"second flight",
|
||||
"third flight",
|
||||
"first option",
|
||||
"second option",
|
||||
"third option",
|
||||
"first one",
|
||||
"second one",
|
||||
"third one",
|
||||
]
|
||||
},
|
||||
],
|
||||
"metadata": {
|
||||
"industry_naics_code_of_audio": 481111,
|
||||
"interaction_type": enums.RecognitionMetadata.InteractionType.PHONE_CALL,
|
||||
},
|
||||
}
|
||||
|
||||
def sample_recognize(content):
|
||||
"""
|
||||
Transcribe a short audio file using synchronous speech recognition
|
||||
|
||||
Args:
|
||||
local_file_path Path to local audio file, e.g. /path/audio.wav
|
||||
"""
|
||||
|
||||
# with io.open(local_file_path, "rb") as f:
|
||||
# content = f.read()
|
||||
audio = {"content": content}
|
||||
|
||||
response = client.recognize(config, audio)
|
||||
for result in response.results:
|
||||
# First alternative is the most probable result
|
||||
return "/".join([alt.transcript for alt in result.alternatives])
|
||||
# print(u"Transcript: {}".format(alternative.transcript))
|
||||
return ""
|
||||
|
||||
return sample_recognize
|
||||
|
||||
|
||||
def parallel_apply(fn, iterable, workers=8):
|
||||
with ThreadPoolExecutor(max_workers=workers) as exe:
|
||||
print(f"parallelly applying {fn}")
|
||||
return [
|
||||
res
|
||||
for res in tqdm(
|
||||
exe.map(fn, iterable), position=0, leave=True, total=len(iterable)
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def main():
|
||||
for c in random_pnr_generator():
|
||||
print(c)
|
||||
|
|
|
|||
|
|
@ -271,7 +271,7 @@ def split_extract(
|
|||
dump_file: Path = Path("ui_dump.json"),
|
||||
manifest_file: Path = Path("manifest.json"),
|
||||
corrections_file: str = typer.Option("corrections.json", show_default=True),
|
||||
conv_data_path: Path = Path("./data/conv_data.json"),
|
||||
conv_data_path: Path = typer.Option(Path("./data/conv_data.json"), show_default=True),
|
||||
extraction_type: ExtractionType = ExtractionType.all,
|
||||
):
|
||||
import shutil
|
||||
|
|
@ -299,10 +299,16 @@ def split_extract(
|
|||
asr_manifest_writer(dest_manifest_path, extract_manifest(manifest_gen))
|
||||
|
||||
ui_data_path = dump_dir / Path(data_name) / dump_file
|
||||
ui_data = json.load(ui_data_path.open())["data"]
|
||||
orig_ui_data = ExtendedPath(ui_data_path).read_json()
|
||||
ui_data = orig_ui_data["data"]
|
||||
file_ui_map = {Path(u["audio_filepath"]).stem: u for u in ui_data}
|
||||
extracted_ui_data = list(filter(lambda u: u["text"] in extraction_vals, ui_data))
|
||||
ExtendedPath(dest_ui_path).write_json(extracted_ui_data)
|
||||
final_data = []
|
||||
for i, d in enumerate(extracted_ui_data):
|
||||
d['real_idx'] = i
|
||||
final_data.append(d)
|
||||
orig_ui_data['data'] = final_data
|
||||
ExtendedPath(dest_ui_path).write_json(orig_ui_data)
|
||||
|
||||
if corrections_file:
|
||||
dest_correction_path = dest_data_dir / corrections_file
|
||||
|
|
@ -331,7 +337,7 @@ def update_corrections(
|
|||
manifest_file: Path = Path("manifest.json"),
|
||||
corrections_file: Path = Path("corrections.json"),
|
||||
ui_dump_file: Path = Path("ui_dump.json"),
|
||||
skip_incorrect: bool = True,
|
||||
skip_incorrect: bool = typer.Option(True, show_default=True),
|
||||
):
|
||||
data_manifest_path = dump_dir / Path(data_name) / manifest_file
|
||||
corrections_path = dump_dir / Path(data_name) / corrections_file
|
||||
|
|
|
|||
3
setup.py
3
setup.py
|
|
@ -39,6 +39,7 @@ extra_requirements = {
|
|||
"streamlit==0.58.0",
|
||||
"natural==0.2.0",
|
||||
"stringcase==1.2.0",
|
||||
"google-cloud-speech~=1.3.1",
|
||||
]
|
||||
# "train": [
|
||||
# "torchaudio==0.5.0",
|
||||
|
|
@ -66,12 +67,14 @@ setup(
|
|||
"jasper_data_tts_generate = jasper.data.tts_generator:main",
|
||||
"jasper_data_conv_generate = jasper.data.conv_generator:main",
|
||||
"jasper_data_nlu_generate = jasper.data.nlu_generator:main",
|
||||
"jasper_data_test_generate = jasper.data.test_generator:main",
|
||||
"jasper_data_call_recycle = jasper.data.call_recycler:main",
|
||||
"jasper_data_asr_recycle = jasper.data.asr_recycler:main",
|
||||
"jasper_data_rev_recycle = jasper.data.rev_recycler:main",
|
||||
"jasper_data_server = jasper.data.server:main",
|
||||
"jasper_data_validation = jasper.data.validation.process:main",
|
||||
"jasper_data_preprocess = jasper.data.process:main",
|
||||
"jasper_data_slu_evaluate = jasper.data.slu_evaluator:main",
|
||||
]
|
||||
},
|
||||
zip_safe=False,
|
||||
|
|
|
|||
Loading…
Reference in New Issue