1. Added a data extraction type argument (see the usage sketch below)

2. Cleanup/refactor
Malar Kannan 2020-06-09 19:16:24 +05:30
parent 8db1be0083
commit 6d149d282d
1 changed file with 102 additions and 71 deletions
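
Usage note: the new extraction-type option can be smoke-tested through Typer's CliRunner, roughly as in the sketch below. The import path call_logs_cli is a placeholder (the module's real name is not part of this diff); command and option names follow Typer's default underscore-to-dash conversion.

    # Sketch only: exercises the new --extraction-type option via Typer's CliRunner.
    # call_logs_cli is a hypothetical module path exposing the `app` defined in this file.
    from typer.testing import CliRunner

    from call_logs_cli import app

    runner = CliRunner()
    result = runner.invoke(
        app,
        [
            "analyze",
            "--extraction-type", "flow",  # or "data" (the default)
            "--extract-data",
            "--call-logs-file", "./call_logs.yaml",
        ],
    )
    print(result.exit_code, result.output)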


@@ -1,10 +1,16 @@
 import typer
 from pathlib import Path
+from enum import Enum

+
 app = typer.Typer()

+
 @app.command()
-def export_all_logs(call_logs_file: Path = Path("./call_sia_logs.yaml")):
+def export_all_logs(
+    call_logs_file: Path = typer.Option(Path("./call_logs.yaml"), show_default=True),
+    domain: str = typer.Option("sia-data.agaralabs.com", show_default=True),
+):
     from .utils import get_mongo_conn
     from collections import defaultdict
     from ruamel.yaml import YAML
@@ -14,14 +20,14 @@ def export_all_logs(call_logs_file: Path = Path("./call_sia_logs.yaml")):
     caller_calls = defaultdict(lambda: [])
     for call in mongo_coll.find():
         sysid = call["SystemID"]
-        call_uri = f"http://sia-data.agaralabs.com/calls/{sysid}"
+        call_uri = f"http://{domain}/calls/{sysid}"
         caller = call["Caller"]
         caller_calls[caller].append(call_uri)
     caller_list = []
     for caller in caller_calls:
         caller_list.append({"name": caller, "calls": caller_calls[caller]})
     output_yaml = {"users": caller_list}
-    typer.echo("exporting call logs to yaml file")
+    typer.echo(f"exporting call logs to yaml file at {call_logs_file}")
     with call_logs_file.open("w") as yf:
         yaml.dump(output_yaml, yf)
@@ -30,7 +36,8 @@ def export_all_logs(call_logs_file: Path = Path("./call_sia_logs.yaml")):
 def export_calls_between(
     start_cid: str,
     end_cid: str,
-    call_logs_file: Path = Path("./call_sia_logs.yaml"),
+    call_logs_file: Path = typer.Option(Path("./call_logs.yaml"), show_default=True),
+    domain: str = typer.Option("sia-data.agaralabs.com", show_default=True),
     mongo_port: int = 27017,
 ):
     from collections import defaultdict
@@ -51,28 +58,38 @@ def export_calls_between(
     )
     for call in call_query:
         sysid = call["SystemID"]
-        call_uri = f"http://sia-data.agaralabs.com/calls/{sysid}"
+        call_uri = f"http://{domain}/calls/{sysid}"
         caller = call["Caller"]
         caller_calls[caller].append(call_uri)
     caller_list = []
     for caller in caller_calls:
         caller_list.append({"name": caller, "calls": caller_calls[caller]})
     output_yaml = {"users": caller_list}
-    typer.echo("exporting call logs to yaml file")
+    typer.echo(f"exporting call logs to yaml file at {call_logs_file}")
     with call_logs_file.open("w") as yf:
         yaml.dump(output_yaml, yf)


+class ExtractionType(str, Enum):
+    flow = "flow"
+    data = "data"
+
+
 @app.command()
 def analyze(
     leaderboard: bool = False,
     plot_calls: bool = False,
     extract_data: bool = False,
+    extraction_type: ExtractionType = typer.Option(
+        ExtractionType.data, show_default=True
+    ),
     download_only: bool = False,
     call_logs_file: Path = typer.Option(Path("./call_logs.yaml"), show_default=True),
     output_dir: Path = Path("./data"),
     data_name: str = None,
-    mongo_uri: str = typer.Option("mongodb://localhost:27017/test.calls", show_default=True),
+    mongo_uri: str = typer.Option(
+        "mongodb://localhost:27017/test.calls", show_default=True
+    ),
 ):
     from urllib.parse import urlsplit
@@ -145,6 +162,15 @@ def analyze(
     def chunk_n(evs, n):
         return [evs[i * n : (i + 1) * n] for i in range((len(evs) + n - 1) // n)]

+    if extraction_type == ExtractionType.data:
+
+        def is_utter_event(ev):
+            return (
+                (ev["Author"] == "CONV" or ev["Author"] == "ASR")
+                and (ev["Type"] != "DEBUG")
+                and ev["Type"] != "ASR_RESULT"
+            )
+
         def get_data_points(utter_events, td_fn):
             data_points = []
             for evs in chunk_n(utter_events, 3):
@@ -154,15 +180,60 @@ def analyze(
                     assert evs[2]["Type"] == "STOPPED_SPEAKING"
                     start_time = td_fn(evs[1]).total_seconds() - 1.5
                     end_time = td_fn(evs[2]).total_seconds()
-                    code = evs[0]["Msg"]
+                    spoken = evs[0]["Msg"]
                     data_points.append(
-                        {"start_time": start_time, "end_time": end_time, "code": code}
+                        {"start_time": start_time, "end_time": end_time, "code": spoken}
                     )
                 except AssertionError:
                     # skipping invalid data_points
                     pass
             return data_points
+
+        def text_extractor(spoken):
+            return re.search(r"'(.*)'", spoken).groups(0)[0] if len(spoken) > 6 else spoken
+
+    elif extraction_type == ExtractionType.flow:
+
+        def is_final_asr_event_or_spoken(ev):
+            pld = json.loads(ev["Payload"])
+            return (
+                pld["AsrResult"]["Results"][0]["IsFinal"]
+                if ev["Type"] == "ASR_RESULT"
+                else True
+            )
+
+        def is_utter_event(ev):
+            return (
+                ev["Author"] == "CONV"
+                or (ev["Author"] == "ASR" and is_final_asr_event_or_spoken(ev))
+            ) and (ev["Type"] != "DEBUG")
+
+        def get_data_points(utter_events, td_fn):
+            data_points = []
+            for evs in chunk_n(utter_events, 4):
+                try:
+                    assert len(evs) == 4
+                    assert evs[0]["Type"] == "CONV_RESULT"
+                    assert evs[1]["Type"] == "STARTED_SPEAKING"
+                    assert evs[2]["Type"] == "ASR_RESULT"
+                    assert evs[3]["Type"] == "STOPPED_SPEAKING"
+                    start_time = td_fn(evs[1]).total_seconds() - 1.5
+                    end_time = td_fn(evs[2]).total_seconds()
+                    conv_msg = evs[0]["Msg"]
+                    if 'full name' in conv_msg.lower():
+                        pld = json.loads(evs[2]["Payload"])
+                        spoken = pld["AsrResult"]["Results"][0]["Alternatives"][0]['Transcript']
+                        data_points.append(
+                            {"start_time": start_time, "end_time": end_time, "code": spoken}
+                        )
+                except AssertionError:
+                    # skipping invalid data_points
+                    pass
+            return data_points
+
+        def text_extractor(spoken):
+            return spoken

     def process_call(call_obj):
         call_meta = get_call_meta(call_obj)
         call_events = call_meta["Events"]
@@ -184,13 +255,6 @@ def analyze(
         get_ev_fev_timedelta = gen_ev_fev_timedelta(first_audio_ev)

-        def is_utter_event(ev):
-            return (
-                (ev["Author"] == "CONV" or ev["Author"] == "ASR")
-                and (ev["Type"] != "DEBUG")
-                and ev["Type"] != "ASR_RESULT"
-            )
-
         uevs = list(filter(is_utter_event, call_events))
         ev_count = len(uevs)
         utter_events = uevs[: ev_count - ev_count % 3]
@@ -201,36 +265,6 @@ def analyze(
             s3_wav_url_p.netloc, s3_wav_url_p.path[1:], str(saved_wav_path)
         )

-        # %config InlineBackend.figure_format = "retina"
-        def plot_events(y, sr, utter_events, file_path):
-            plt.figure(figsize=(16, 12))
-            librosa.display.waveplot(y=y, sr=sr)
-            # plt.tight_layout()
-            for evs in chunk_n(utter_events, 3):
-                assert evs[0]["Type"] == "CONV_RESULT"
-                assert evs[1]["Type"] == "STARTED_SPEAKING"
-                assert evs[2]["Type"] == "STOPPED_SPEAKING"
-                for ev in evs:
-                    # print(ev["Type"])
-                    ev_type = ev["Type"]
-                    pos = get_ev_fev_timedelta(ev).total_seconds()
-                    if ev_type == "STARTED_SPEAKING":
-                        pos = pos - 1.5
-                    plt.axvline(pos) # , label="pyplot vertical line")
-                    plt.text(
-                        pos,
-                        0.2,
-                        f"event:{ev_type}:{ev['Msg']}",
-                        rotation=90,
-                        horizontalalignment="left"
-                        if ev_type != "STOPPED_SPEAKING"
-                        else "right",
-                        verticalalignment="center",
-                    )
-            plt.title("Monophonic")
-            plt.savefig(file_path, format="png")
-
         return {
             "wav_path": saved_wav_path,
             "num_samples": len(utter_events) // 3,
@@ -263,7 +297,6 @@ def analyze(
     call_lens = lens["users"].Each()["calls"].Each()
     call_lens.modify(ensure_call)(call_logs)

-    # @plot_app.command()
     def plot_calls_data():
         def plot_data_points(y, sr, data_points, file_path):
             plt.figure(figsize=(16, 12))
@@ -317,16 +350,14 @@ def analyze(
             .set_frame_rate(24000)
         )
         for dp_id, dp in enumerate(data_points):
-            start, end, code = dp["start_time"], dp["end_time"], dp["code"]
-            code_seg = call_seg[start * 1000 : end * 1000]
-            code_fb = BytesIO()
-            code_seg.export(code_fb, format="wav")
-            code_wav = code_fb.getvalue()
+            start, end, spoken = dp["start_time"], dp["end_time"], dp["code"]
+            spoken_seg = call_seg[start * 1000 : end * 1000]
+            spoken_fb = BytesIO()
+            spoken_seg.export(spoken_fb, format="wav")
+            spoken_wav = spoken_fb.getvalue()
             # search for actual pnr code and handle plain codes as well
-            extracted_code = (
-                re.search(r"'(.*)'", code).groups(0)[0] if len(code) > 6 else code
-            )
-            yield extracted_code, code_seg.duration_seconds, code_wav
+            extracted_code = text_extractor(spoken)
+            yield extracted_code, spoken_seg.duration_seconds, spoken_wav

     call_lens = lens["users"].Each()["calls"].Each()
     call_stats = call_lens.modify(retrieve_processed_callmeta)(call_logs)
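
For reference, the flow branch above assumes the ASR_RESULT payload carries the nested AsrResult structure it indexes into. A minimal standalone check, with a fabricated sample payload (values are made up; only the keys the code reads are populated):

    import json

    # Fabricated ASR_RESULT event; only the keys read by the flow branch are filled in.
    sample_event = {
        "Author": "ASR",
        "Type": "ASR_RESULT",
        "Payload": json.dumps(
            {
                "AsrResult": {
                    "Results": [
                        {"IsFinal": True, "Alternatives": [{"Transcript": "john doe"}]}
                    ]
                }
            }
        ),
    }

    pld = json.loads(sample_event["Payload"])
    is_final = pld["AsrResult"]["Results"][0]["IsFinal"]
    transcript = pld["AsrResult"]["Results"][0]["Alternatives"][0]["Transcript"]
    print(is_final, transcript)  # True john doe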