1. added a data extraction type argument
2. cleanup/refactor
parent
8db1be0083
commit
6d149d282d
|
|
@ -1,10 +1,16 @@
|
||||||
import typer
|
import typer
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
app = typer.Typer()
|
app = typer.Typer()
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
@app.command()
|
||||||
def export_all_logs(call_logs_file: Path = Path("./call_sia_logs.yaml")):
|
def export_all_logs(
|
||||||
|
call_logs_file: Path = typer.Option(Path("./call_logs.yaml"), show_default=True),
|
||||||
|
domain: str = typer.Option("sia-data.agaralabs.com", show_default=True),
|
||||||
|
):
|
||||||
from .utils import get_mongo_conn
|
from .utils import get_mongo_conn
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from ruamel.yaml import YAML
|
from ruamel.yaml import YAML
|
||||||
|
|
@ -14,14 +20,14 @@ def export_all_logs(call_logs_file: Path = Path("./call_sia_logs.yaml")):
|
||||||
caller_calls = defaultdict(lambda: [])
|
caller_calls = defaultdict(lambda: [])
|
||||||
for call in mongo_coll.find():
|
for call in mongo_coll.find():
|
||||||
sysid = call["SystemID"]
|
sysid = call["SystemID"]
|
||||||
call_uri = f"http://sia-data.agaralabs.com/calls/{sysid}"
|
call_uri = f"http://{domain}/calls/{sysid}"
|
||||||
caller = call["Caller"]
|
caller = call["Caller"]
|
||||||
caller_calls[caller].append(call_uri)
|
caller_calls[caller].append(call_uri)
|
||||||
caller_list = []
|
caller_list = []
|
||||||
for caller in caller_calls:
|
for caller in caller_calls:
|
||||||
caller_list.append({"name": caller, "calls": caller_calls[caller]})
|
caller_list.append({"name": caller, "calls": caller_calls[caller]})
|
||||||
output_yaml = {"users": caller_list}
|
output_yaml = {"users": caller_list}
|
||||||
typer.echo("exporting call logs to yaml file")
|
typer.echo(f"exporting call logs to yaml file at {call_logs_file}")
|
||||||
with call_logs_file.open("w") as yf:
|
with call_logs_file.open("w") as yf:
|
||||||
yaml.dump(output_yaml, yf)
|
yaml.dump(output_yaml, yf)
|
||||||
|
|
||||||
|
|
@ -30,7 +36,8 @@ def export_all_logs(call_logs_file: Path = Path("./call_sia_logs.yaml")):
|
||||||
def export_calls_between(
|
def export_calls_between(
|
||||||
start_cid: str,
|
start_cid: str,
|
||||||
end_cid: str,
|
end_cid: str,
|
||||||
call_logs_file: Path = Path("./call_sia_logs.yaml"),
|
call_logs_file: Path = typer.Option(Path("./call_logs.yaml"), show_default=True),
|
||||||
|
domain: str = typer.Option("sia-data.agaralabs.com", show_default=True),
|
||||||
mongo_port: int = 27017,
|
mongo_port: int = 27017,
|
||||||
):
|
):
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
@ -51,28 +58,38 @@ def export_calls_between(
|
||||||
)
|
)
|
||||||
for call in call_query:
|
for call in call_query:
|
||||||
sysid = call["SystemID"]
|
sysid = call["SystemID"]
|
||||||
call_uri = f"http://sia-data.agaralabs.com/calls/{sysid}"
|
call_uri = f"http://{domain}/calls/{sysid}"
|
||||||
caller = call["Caller"]
|
caller = call["Caller"]
|
||||||
caller_calls[caller].append(call_uri)
|
caller_calls[caller].append(call_uri)
|
||||||
caller_list = []
|
caller_list = []
|
||||||
for caller in caller_calls:
|
for caller in caller_calls:
|
||||||
caller_list.append({"name": caller, "calls": caller_calls[caller]})
|
caller_list.append({"name": caller, "calls": caller_calls[caller]})
|
||||||
output_yaml = {"users": caller_list}
|
output_yaml = {"users": caller_list}
|
||||||
typer.echo("exporting call logs to yaml file")
|
typer.echo(f"exporting call logs to yaml file at {call_logs_file}")
|
||||||
with call_logs_file.open("w") as yf:
|
with call_logs_file.open("w") as yf:
|
||||||
yaml.dump(output_yaml, yf)
|
yaml.dump(output_yaml, yf)
|
||||||
|
|
||||||
|
|
||||||
|
class ExtractionType(str, Enum):
    """Extraction modes selectable through the ``extraction_type`` option of ``analyze``.

    Subclassing ``str`` makes every member compare equal to its plain string
    value, so the raw command-line text maps directly onto a member.
    """

    # Four-event groups that include the ASR_RESULT transcript.
    flow = "flow"
    # Three-event groups: CONV_RESULT, STARTED_SPEAKING, STOPPED_SPEAKING.
    data = "data"
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
@app.command()
|
||||||
def analyze(
|
def analyze(
|
||||||
leaderboard: bool = False,
|
leaderboard: bool = False,
|
||||||
plot_calls: bool = False,
|
plot_calls: bool = False,
|
||||||
extract_data: bool = False,
|
extract_data: bool = False,
|
||||||
|
extraction_type: ExtractionType = typer.Option(
|
||||||
|
ExtractionType.data, show_default=True
|
||||||
|
),
|
||||||
download_only: bool = False,
|
download_only: bool = False,
|
||||||
call_logs_file: Path = typer.Option(Path("./call_logs.yaml"), show_default=True),
|
call_logs_file: Path = typer.Option(Path("./call_logs.yaml"), show_default=True),
|
||||||
output_dir: Path = Path("./data"),
|
output_dir: Path = Path("./data"),
|
||||||
data_name: str = None,
|
data_name: str = None,
|
||||||
mongo_uri: str = typer.Option("mongodb://localhost:27017/test.calls", show_default=True),
|
mongo_uri: str = typer.Option(
|
||||||
|
"mongodb://localhost:27017/test.calls", show_default=True
|
||||||
|
),
|
||||||
):
|
):
|
||||||
|
|
||||||
from urllib.parse import urlsplit
|
from urllib.parse import urlsplit
|
||||||
|
|
@ -145,23 +162,77 @@ def analyze(
|
||||||
def chunk_n(evs, n):
|
def chunk_n(evs, n):
|
||||||
return [evs[i * n : (i + 1) * n] for i in range((len(evs) + n - 1) // n)]
|
return [evs[i * n : (i + 1) * n] for i in range((len(evs) + n - 1) // n)]
|
||||||
|
|
||||||
def get_data_points(utter_events, td_fn):
|
if extraction_type == ExtractionType.data:
|
||||||
data_points = []
|
|
||||||
for evs in chunk_n(utter_events, 3):
|
def is_utter_event(ev):
|
||||||
try:
|
return (
|
||||||
assert evs[0]["Type"] == "CONV_RESULT"
|
(ev["Author"] == "CONV" or ev["Author"] == "ASR")
|
||||||
assert evs[1]["Type"] == "STARTED_SPEAKING"
|
and (ev["Type"] != "DEBUG")
|
||||||
assert evs[2]["Type"] == "STOPPED_SPEAKING"
|
and ev["Type"] != "ASR_RESULT"
|
||||||
start_time = td_fn(evs[1]).total_seconds() - 1.5
|
)
|
||||||
end_time = td_fn(evs[2]).total_seconds()
|
|
||||||
code = evs[0]["Msg"]
|
def get_data_points(utter_events, td_fn):
|
||||||
data_points.append(
|
data_points = []
|
||||||
{"start_time": start_time, "end_time": end_time, "code": code}
|
for evs in chunk_n(utter_events, 3):
|
||||||
)
|
try:
|
||||||
except AssertionError:
|
assert evs[0]["Type"] == "CONV_RESULT"
|
||||||
# skipping invalid data_points
|
assert evs[1]["Type"] == "STARTED_SPEAKING"
|
||||||
pass
|
assert evs[2]["Type"] == "STOPPED_SPEAKING"
|
||||||
return data_points
|
start_time = td_fn(evs[1]).total_seconds() - 1.5
|
||||||
|
end_time = td_fn(evs[2]).total_seconds()
|
||||||
|
spoken = evs[0]["Msg"]
|
||||||
|
data_points.append(
|
||||||
|
{"start_time": start_time, "end_time": end_time, "code": spoken}
|
||||||
|
)
|
||||||
|
except AssertionError:
|
||||||
|
# skipping invalid data_points
|
||||||
|
pass
|
||||||
|
return data_points
|
||||||
|
|
||||||
|
def text_extractor(spoken):
    """Return the label text for a spoken message.

    Messages of at most six characters are returned unchanged. Longer
    messages are searched for a single-quoted section and the quoted text is
    returned. A long message without any quoted section is returned
    unchanged instead of raising ``AttributeError`` on the failed match.

    Args:
        spoken: The message string attached to the data point.

    Returns:
        The quoted text when present, otherwise ``spoken`` itself.
    """
    if len(spoken) <= 6:
        return spoken
    quoted = re.search(r"'(.*)'", spoken)
    # re.search returns None when nothing is quoted; fall back to the raw text.
    return quoted.group(1) if quoted else spoken
|
||||||
|
|
||||||
|
elif extraction_type == ExtractionType.flow:
|
||||||
|
|
||||||
|
def is_final_asr_event_or_spoken(ev):
    """Tell whether an event should be kept: any non-ASR_RESULT event, or a final ASR result.

    The JSON payload is parsed only for ``ASR_RESULT`` events, so other event
    types do not need a ``Payload`` field at all.

    Args:
        ev: Event mapping with a ``"Type"`` key and, for ``ASR_RESULT``
            events, a JSON string under ``"Payload"``.

    Returns:
        ``True`` for every event whose type is not ``ASR_RESULT``; otherwise
        the ``IsFinal`` value of the first entry in the payload's
        ``AsrResult.Results`` list.
    """
    if ev["Type"] != "ASR_RESULT":
        return True
    pld = json.loads(ev["Payload"])
    return pld["AsrResult"]["Results"][0]["IsFinal"]
|
||||||
|
|
||||||
|
def is_utter_event(ev):
    """Select the events that take part in utterance grouping for flow extraction.

    An event is kept when it is not a ``DEBUG`` event and was authored either
    by ``CONV`` or by ``ASR`` with ``is_final_asr_event_or_spoken`` accepting
    it. The ``DEBUG`` check runs first so that rejected events never reach
    the payload-parsing helper.

    Args:
        ev: Event mapping with ``"Type"`` and ``"Author"`` keys.

    Returns:
        A truthy value when the event should be kept, falsy otherwise.
    """
    if ev["Type"] == "DEBUG":
        return False
    author = ev["Author"]
    if author == "CONV":
        return True
    return author == "ASR" and is_final_asr_event_or_spoken(ev)
|
||||||
|
|
||||||
|
def get_data_points(utter_events, td_fn):
    """Build data points from consecutive four-event utterance groups.

    The events are split into chunks of four with ``chunk_n``. A chunk is
    used only when its event types are, in order, ``CONV_RESULT``,
    ``STARTED_SPEAKING``, ``ASR_RESULT`` and ``STOPPED_SPEAKING``; any other
    chunk, including a short trailing one, is skipped. The check is explicit
    rather than ``assert``-based so that it still runs under ``python -O``.

    A data point is emitted only when the chunk's ``CONV_RESULT`` message
    contains "full name" (case-insensitive). Its ``"code"`` is the transcript
    of the first alternative of the first ASR result.

    Args:
        utter_events: Sequence of event mappings with ``"Type"`` keys; the
            first event of a chunk also needs ``"Msg"`` and the third
            ``"Payload"`` (a JSON string).
        td_fn: Callable taking an event and returning an object that has
            ``total_seconds()``.

    Returns:
        List of dicts with ``"start_time"``, ``"end_time"`` and ``"code"``.
    """
    expected_types = (
        "CONV_RESULT",
        "STARTED_SPEAKING",
        "ASR_RESULT",
        "STOPPED_SPEAKING",
    )
    data_points = []
    for evs in chunk_n(utter_events, len(expected_types)):
        if tuple(ev["Type"] for ev in evs) != expected_types:
            # skipping invalid data_points
            continue
        # The start is shifted 1.5 seconds before the STARTED_SPEAKING event.
        start_time = td_fn(evs[1]).total_seconds() - 1.5
        # NOTE(review): the end is taken from the ASR_RESULT event (evs[2]),
        # not STOPPED_SPEAKING (evs[3]) — confirm this is intended.
        end_time = td_fn(evs[2]).total_seconds()
        conv_msg = evs[0]["Msg"]
        if 'full name' in conv_msg.lower():
            pld = json.loads(evs[2]["Payload"])
            spoken = pld["AsrResult"]["Results"][0]["Alternatives"][0]['Transcript']
            data_points.append(
                {"start_time": start_time, "end_time": end_time, "code": spoken}
            )
    return data_points
|
||||||
|
|
||||||
|
def text_extractor(spoken):
    """Return ``spoken`` unchanged; this branch applies no post-processing."""
    return spoken
|
||||||
|
|
||||||
def process_call(call_obj):
|
def process_call(call_obj):
|
||||||
call_meta = get_call_meta(call_obj)
|
call_meta = get_call_meta(call_obj)
|
||||||
|
|
@ -184,13 +255,6 @@ def analyze(
|
||||||
|
|
||||||
get_ev_fev_timedelta = gen_ev_fev_timedelta(first_audio_ev)
|
get_ev_fev_timedelta = gen_ev_fev_timedelta(first_audio_ev)
|
||||||
|
|
||||||
def is_utter_event(ev):
|
|
||||||
return (
|
|
||||||
(ev["Author"] == "CONV" or ev["Author"] == "ASR")
|
|
||||||
and (ev["Type"] != "DEBUG")
|
|
||||||
and ev["Type"] != "ASR_RESULT"
|
|
||||||
)
|
|
||||||
|
|
||||||
uevs = list(filter(is_utter_event, call_events))
|
uevs = list(filter(is_utter_event, call_events))
|
||||||
ev_count = len(uevs)
|
ev_count = len(uevs)
|
||||||
utter_events = uevs[: ev_count - ev_count % 3]
|
utter_events = uevs[: ev_count - ev_count % 3]
|
||||||
|
|
@ -201,36 +265,6 @@ def analyze(
|
||||||
s3_wav_url_p.netloc, s3_wav_url_p.path[1:], str(saved_wav_path)
|
s3_wav_url_p.netloc, s3_wav_url_p.path[1:], str(saved_wav_path)
|
||||||
)
|
)
|
||||||
|
|
||||||
# %config InlineBackend.figure_format = "retina"
|
|
||||||
|
|
||||||
def plot_events(y, sr, utter_events, file_path):
|
|
||||||
plt.figure(figsize=(16, 12))
|
|
||||||
librosa.display.waveplot(y=y, sr=sr)
|
|
||||||
# plt.tight_layout()
|
|
||||||
for evs in chunk_n(utter_events, 3):
|
|
||||||
assert evs[0]["Type"] == "CONV_RESULT"
|
|
||||||
assert evs[1]["Type"] == "STARTED_SPEAKING"
|
|
||||||
assert evs[2]["Type"] == "STOPPED_SPEAKING"
|
|
||||||
for ev in evs:
|
|
||||||
# print(ev["Type"])
|
|
||||||
ev_type = ev["Type"]
|
|
||||||
pos = get_ev_fev_timedelta(ev).total_seconds()
|
|
||||||
if ev_type == "STARTED_SPEAKING":
|
|
||||||
pos = pos - 1.5
|
|
||||||
plt.axvline(pos) # , label="pyplot vertical line")
|
|
||||||
plt.text(
|
|
||||||
pos,
|
|
||||||
0.2,
|
|
||||||
f"event:{ev_type}:{ev['Msg']}",
|
|
||||||
rotation=90,
|
|
||||||
horizontalalignment="left"
|
|
||||||
if ev_type != "STOPPED_SPEAKING"
|
|
||||||
else "right",
|
|
||||||
verticalalignment="center",
|
|
||||||
)
|
|
||||||
plt.title("Monophonic")
|
|
||||||
plt.savefig(file_path, format="png")
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"wav_path": saved_wav_path,
|
"wav_path": saved_wav_path,
|
||||||
"num_samples": len(utter_events) // 3,
|
"num_samples": len(utter_events) // 3,
|
||||||
|
|
@ -263,7 +297,6 @@ def analyze(
|
||||||
call_lens = lens["users"].Each()["calls"].Each()
|
call_lens = lens["users"].Each()["calls"].Each()
|
||||||
call_lens.modify(ensure_call)(call_logs)
|
call_lens.modify(ensure_call)(call_logs)
|
||||||
|
|
||||||
# @plot_app.command()
|
|
||||||
def plot_calls_data():
|
def plot_calls_data():
|
||||||
def plot_data_points(y, sr, data_points, file_path):
|
def plot_data_points(y, sr, data_points, file_path):
|
||||||
plt.figure(figsize=(16, 12))
|
plt.figure(figsize=(16, 12))
|
||||||
|
|
@ -317,16 +350,14 @@ def analyze(
|
||||||
.set_frame_rate(24000)
|
.set_frame_rate(24000)
|
||||||
)
|
)
|
||||||
for dp_id, dp in enumerate(data_points):
|
for dp_id, dp in enumerate(data_points):
|
||||||
start, end, code = dp["start_time"], dp["end_time"], dp["code"]
|
start, end, spoken = dp["start_time"], dp["end_time"], dp["code"]
|
||||||
code_seg = call_seg[start * 1000 : end * 1000]
|
spoken_seg = call_seg[start * 1000 : end * 1000]
|
||||||
code_fb = BytesIO()
|
spoken_fb = BytesIO()
|
||||||
code_seg.export(code_fb, format="wav")
|
spoken_seg.export(spoken_fb, format="wav")
|
||||||
code_wav = code_fb.getvalue()
|
spoken_wav = spoken_fb.getvalue()
|
||||||
# search for actual pnr code and handle plain codes as well
|
# search for actual pnr code and handle plain codes as well
|
||||||
extracted_code = (
|
extracted_code = text_extractor(spoken)
|
||||||
re.search(r"'(.*)'", code).groups(0)[0] if len(code) > 6 else code
|
yield extracted_code, spoken_seg.duration_seconds, spoken_wav
|
||||||
)
|
|
||||||
yield extracted_code, code_seg.duration_seconds, code_wav
|
|
||||||
|
|
||||||
call_lens = lens["users"].Each()["calls"].Each()
|
call_lens = lens["users"].Each()["calls"].Each()
|
||||||
call_stats = call_lens.modify(retrieve_processed_callmeta)(call_logs)
|
call_stats = call_lens.modify(retrieve_processed_callmeta)(call_logs)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue