181 lines
6.1 KiB
Python
181 lines
6.1 KiB
Python
|
|
import typer
|
||
|
|
from pathlib import Path
|
||
|
|
import json
|
||
|
|
|
||
|
|
# from .utils import generate_dates, asr_test_writer
|
||
|
|
|
||
|
|
# Typer application object; commands in this module register on it via @app.command().
app = typer.Typer()
|
||
|
|
|
||
|
|
|
||
|
|
def _is_test_event(ev):
    """Return a truthy value for events that run_test scores.

    An event qualifies when its Type is not "DEBUG" and it is either authored
    by "NLU" or is an "ASR"-authored "ASR_RESULT" whose first result is final.
    """
    if ev["Type"] == "DEBUG":
        return False
    if ev["Author"] == "NLU":
        return True
    if ev["Author"] != "ASR" or ev["Type"] != "ASR_RESULT":
        return False
    # Only ASR_RESULT payloads are parsed, so other ASR events do not have to
    # carry JSON at all.
    return json.loads(ev["Payload"])["AsrResult"]["Results"][0]["IsFinal"]


def _blank_row(expected, result):
    """Build a result row from the expected test case with every measured field empty."""
    return {
        "expected_entity": expected["entity"],
        "text": expected["text"],
        "audio_filepath": expected["audio_filepath"],
        "pretrained_asr": expected["pretrained_asr"],
        "entity_asr": "",
        "google_asr": "",
        "nlu_result": "",
        "asr_entity": "",
        "nlu_entity": "",
        "result": result,
    }


def run_test(reg_path, coll, s3, call_meta_dir, city_code, test_path):
    """Run one regression test file and score the entity it produced.

    The script at ``reg_path`` is run through bash for ``test_path``. The
    record whose ``CallID`` equals ``test_path.name`` is then polled from
    ``coll``, its logs are fetched with ``get_call_logs``, and the selected
    ASR/NLU events are compared with the first entry of the sibling
    ``.result.json`` file.

    Args:
        reg_path: Path of the regression runner script.
        coll: Collection object; ``delete_many`` and ``find_one`` are called on it.
        s3: Passed through to ``get_call_logs``.
        call_meta_dir: Passed through to ``get_call_logs``.
        city_code: Mapping from expected entity name to its code; entities
            missing from it are scored as ``"UNKNOWN"``.
        test_path: Path of the test file; its ``name`` is used as the CallID.

    Returns:
        A dict with the keys ``expected_entity``, ``text``, ``audio_filepath``,
        ``pretrained_asr``, ``entity_asr``, ``google_asr``, ``nlu_result``,
        ``asr_entity``, ``nlu_entity`` and ``result``. ``result`` is
        ``"Success"`` or ``"Fail"`` when both events were parsed, ``"Error"``
        when parsing them raised, and ``"Empty"`` when the call did not yield
        exactly two scored events. Returns ``None`` when the regression script
        exits with a non-zero status.
    """
    import shlex
    import subprocess
    from time import sleep

    from .utils import ExtendedPath, get_call_logs

    call_id = test_path.name
    # Remove any existing records for this CallID before running, so the
    # polling below cannot pick up a record that predates this run.
    coll.delete_many({"CallID": call_id})

    # Quote both paths: the command line is interpreted by bash, so spaces or
    # shell metacharacters in a path would otherwise split or alter it.
    script = shlex.quote(str(reg_path))
    target = shlex.quote(str(test_path))
    test_output = subprocess.run(
        ["/bin/bash", "-c", f"{script} --addr [::]:15400 -f {target}"]
    )
    if test_output.returncode != 0:
        print(f"Error running test {test_path}")
        return None

    # Poll until the call record shows up. This is a loop rather than
    # recursion so a slow call cannot exhaust the interpreter's recursion
    # limit. NOTE(review): there is still no timeout here.
    call_meta = coll.find_one({"CallID": call_id})
    while not call_meta:
        sleep(2)
        call_meta = coll.find_one({"CallID": call_id})

    call_logs = get_call_logs(call_meta, s3, call_meta_dir)
    call_events = call_logs["Events"]

    test_data_path = test_path.with_suffix(".result.json")
    test_data = ExtendedPath(test_data_path).read_json()

    test_evs = [ev for ev in call_events if _is_test_event(ev)]
    expected = test_data[0]
    if len(test_evs) != 2:
        return _blank_row(expected, "Empty")

    try:
        # Assumes the first scored event is the ASR result and the second the
        # NLU result; any other shape raises below and becomes an "Error" row.
        asr_event, nlu_event = test_evs
        asr_result = json.loads(asr_event["Payload"])["AsrResult"]["Results"][0]
        gcp_result = "|".join(
            alt["Transcript"] for alt in asr_result["Alternatives"]
        )
        entity_asr = asr_result["AsrDynamicResults"][0]["Candidate"]["Transcript"]
        # NluResults is itself a JSON-encoded string, hence the second decode.
        nlu_result_payload = json.loads(nlu_event["Payload"])["NluResults"]
        nlu_entity = list(json.loads(nlu_result_payload)["Entities"].values())[0]
        asr_entity = city_code.get(expected["entity"], "UNKNOWN")
        outcome = "Success" if asr_entity == nlu_entity else "Fail"
    except Exception:
        # Best effort: one malformed call must not abort the whole evaluation.
        return _blank_row(expected, "Error")

    row = _blank_row(expected, outcome)
    row.update(
        entity_asr=entity_asr,
        google_asr=gcp_result,
        nlu_result=nlu_result_payload,
        asr_entity=asr_entity,
        nlu_entity=nlu_entity,
    )
    return row
|
||
|
|
|
||
|
|
|
||
|
|
@app.command()
def evaluate_slu(
    data_name: str = typer.Option("call_upwork_test_cnd_cities", show_default=True),
    dump_dir: Path = Path("./data/asr_data"),
    call_meta_dir: Path = Path("./data/call_metas"),
    test_file_pref: str = "asr_test",
    mongo_uri: str = typer.Option(
        "mongodb://localhost:27017/test.calls", show_default=True
    ),
    test_results: Path = Path("./data/results.csv"),
    airport_codes: Path = Path("./airports_code.csv"),
    reg_path: Path = Path("../saas_reg/regression/run.sh"),
    test_id: str = "5ef481f27031edf6910e94e0",
):
    """Run every matching regression test and write the scored rows to a CSV file.

    Test files are the ``<test_file_pref>*.reg`` files under
    ``dump_dir / data_name``. Each one is passed to ``run_test`` on a pool of
    8 worker threads, and the returned rows are written to ``test_results``.
    ``airport_codes`` must be a CSV with ``iata`` and ``city`` columns; it
    supplies the city -> code mapping handed to ``run_test``. ``test_id`` is
    accepted but not used by this command.
    """
    import csv
    from concurrent.futures import ThreadPoolExecutor
    from functools import partial

    import boto3
    import pandas as pd
    from tqdm import tqdm

    from .utils import get_mongo_coll

    s3 = boto3.client("s3")
    df = pd.read_csv(airport_codes)[["iata", "city"]]
    city_code = pd.Series(df["iata"].values, index=df["city"]).to_dict()

    test_files = list((dump_dir / data_name).glob(test_file_pref + "*.reg"))
    coll = get_mongo_coll(mongo_uri)

    fieldnames = [
        "expected_entity",
        "text",
        "audio_filepath",
        "pretrained_asr",
        "entity_asr",
        "google_asr",
        "nlu_result",
        "asr_entity",
        "nlu_entity",
        "result",
    ]
    worker = partial(run_test, reg_path, coll, s3, call_meta_dir, city_code)

    # newline="" is required by the csv module so that it controls line
    # endings itself; without it, extra blank rows appear on some platforms.
    with test_results.open("w", newline="") as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()
        with ThreadPoolExecutor(max_workers=8) as exe:
            print("starting all loading tasks")
            results = exe.map(worker, test_files)
            for test_result in tqdm(
                results, position=0, leave=True, total=len(test_files)
            ):
                # run_test returns None when the regression script fails;
                # writerow(None) would raise, so those tests are skipped.
                if test_result is None:
                    continue
                writer.writerow(test_result)
|
||
|
|
|
||
|
|
|
||
|
|
def main():
    """Run this module's Typer application."""
    app()
|
||
|
|
|
||
|
|
|
||
|
|
# Start the CLI only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()
|