1. removed the transcriber_pretrained/speller from utils

2. introduced get_mongo_coll to get the collection object directly from mongo uri
3. removed the processing of correction entries that stripped spaces and upper-cased them
Malar Kannan 2020-06-04 17:49:16 +05:30
parent e3a01169c2
commit bca227a7d7
5 changed files with 42 additions and 50 deletions

View File

@@ -1,6 +1,7 @@
import os import os
import logging import logging
import rpyc import rpyc
from functools import lru_cache
logging.basicConfig( logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
@@ -12,12 +13,9 @@ ASR_HOST = os.environ.get("JASPER_ASR_RPYC_HOST", "localhost")
ASR_PORT = int(os.environ.get("JASPER_ASR_RPYC_PORT", "8045")) ASR_PORT = int(os.environ.get("JASPER_ASR_RPYC_PORT", "8045"))
@lru_cache()
def transcribe_gen(asr_host=ASR_HOST, asr_port=ASR_PORT): def transcribe_gen(asr_host=ASR_HOST, asr_port=ASR_PORT):
logger.info(f"connecting to asr server at {asr_host}:{asr_port}") logger.info(f"connecting to asr server at {asr_host}:{asr_port}")
asr = rpyc.connect(asr_host, asr_port).root asr = rpyc.connect(asr_host, asr_port).root
logger.info(f"connected to asr server successfully") logger.info(f"connected to asr server successfully")
return asr.transcribe return asr.transcribe
transcriber_pretrained = transcribe_gen(asr_port=8044)
transcriber_speller = transcribe_gen(asr_port=8045)

View File

@@ -1,15 +1,7 @@
# import argparse
# import logging
import typer import typer
from pathlib import Path from pathlib import Path
app = typer.Typer() app = typer.Typer()
# leader_app = typer.Typer()
# app.add_typer(leader_app, name="leaderboard")
# plot_app = typer.Typer()
# app.add_typer(plot_app, name="plot")
@app.command() @app.command()
def export_all_logs(call_logs_file: Path = Path("./call_sia_logs.yaml")): def export_all_logs(call_logs_file: Path = Path("./call_sia_logs.yaml")):
@@ -18,7 +10,7 @@ def export_all_logs(call_logs_file: Path = Path("./call_sia_logs.yaml")):
from ruamel.yaml import YAML from ruamel.yaml import YAML
yaml = YAML() yaml = YAML()
mongo_coll = get_mongo_conn().test.calls mongo_coll = get_mongo_conn()
caller_calls = defaultdict(lambda: []) caller_calls = defaultdict(lambda: [])
for call in mongo_coll.find(): for call in mongo_coll.find():
sysid = call["SystemID"] sysid = call["SystemID"]
@@ -46,7 +38,7 @@ def export_calls_between(
from .utils import get_mongo_conn from .utils import get_mongo_conn
yaml = YAML() yaml = YAML()
mongo_coll = get_mongo_conn(port=mongo_port).test.calls mongo_coll = get_mongo_conn(port=mongo_port)
start_meta = mongo_coll.find_one({"SystemID": start_cid}) start_meta = mongo_coll.find_one({"SystemID": start_cid})
end_meta = mongo_coll.find_one({"SystemID": end_cid}) end_meta = mongo_coll.find_one({"SystemID": end_cid})
@@ -77,23 +69,21 @@ def analyze(
plot_calls: bool = False, plot_calls: bool = False,
extract_data: bool = False, extract_data: bool = False,
download_only: bool = False, download_only: bool = False,
call_logs_file: Path = Path("./call_logs.yaml"), call_logs_file: Path = typer.Option(Path("./call_logs.yaml"), show_default=True),
output_dir: Path = Path("./data"), output_dir: Path = Path("./data"),
mongo_port: int = 27017, data_name: str = None,
mongo_uri: str = typer.Option("mongodb://localhost:27017/test.calls", show_default=True),
): ):
from urllib.parse import urlsplit from urllib.parse import urlsplit
from functools import reduce from functools import reduce
import boto3 import boto3
from io import BytesIO from io import BytesIO
import json import json
from ruamel.yaml import YAML from ruamel.yaml import YAML
import re import re
from google.protobuf.timestamp_pb2 import Timestamp from google.protobuf.timestamp_pb2 import Timestamp
from datetime import timedelta from datetime import timedelta
# from concurrent.futures import ThreadPoolExecutor
import librosa import librosa
import librosa.display import librosa.display
from lenses import lens from lenses import lens
@@ -102,23 +92,17 @@ def analyze(
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import matplotlib import matplotlib
from tqdm import tqdm from tqdm import tqdm
from .utils import asr_data_writer, get_mongo_conn from .utils import asr_data_writer, get_mongo_coll
from pydub import AudioSegment from pydub import AudioSegment
from natural.date import compress from natural.date import compress
# from itertools import product, chain
matplotlib.rcParams["agg.path.chunksize"] = 10000 matplotlib.rcParams["agg.path.chunksize"] = 10000
matplotlib.use("agg") matplotlib.use("agg")
# logging.basicConfig(
# level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
# )
# logger = logging.getLogger(__name__)
yaml = YAML() yaml = YAML()
s3 = boto3.client("s3") s3 = boto3.client("s3")
mongo_collection = get_mongo_conn(port=mongo_port).test.calls mongo_collection = get_mongo_coll(mongo_uri)
call_media_dir: Path = output_dir / Path("call_wavs") call_media_dir: Path = output_dir / Path("call_wavs")
call_media_dir.mkdir(exist_ok=True, parents=True) call_media_dir.mkdir(exist_ok=True, parents=True)
call_meta_dir: Path = output_dir / Path("call_metas") call_meta_dir: Path = output_dir / Path("call_metas")
@@ -127,6 +111,7 @@ def analyze(
call_plot_dir.mkdir(exist_ok=True, parents=True) call_plot_dir.mkdir(exist_ok=True, parents=True)
call_asr_data: Path = output_dir / Path("asr_data") call_asr_data: Path = output_dir / Path("asr_data")
call_asr_data.mkdir(exist_ok=True, parents=True) call_asr_data.mkdir(exist_ok=True, parents=True)
dataset_name = call_logs_file.stem if not data_name else data_name
call_logs = yaml.load(call_logs_file.read_text()) call_logs = yaml.load(call_logs_file.read_text())
@@ -183,7 +168,7 @@ def analyze(
call_events = call_meta["Events"] call_events = call_meta["Events"]
def is_writer_uri_event(ev): def is_writer_uri_event(ev):
return ev["Author"] == "AUDIO_WRITER" and 's3://' in ev["Msg"] return ev["Author"] == "AUDIO_WRITER" and "s3://" in ev["Msg"]
writer_events = list(filter(is_writer_uri_event, call_events)) writer_events = list(filter(is_writer_uri_event, call_events))
s3_wav_url = re.search(r"(s3://.*)", writer_events[0]["Msg"]).groups(0)[0] s3_wav_url = re.search(r"(s3://.*)", writer_events[0]["Msg"]).groups(0)[0]
@@ -268,8 +253,10 @@ def analyze(
meta = mongo_collection.find_one({"SystemID": cid}) meta = mongo_collection.find_one({"SystemID": cid})
duration = meta["EndTS"] - meta["StartTS"] duration = meta["EndTS"] - meta["StartTS"]
process_meta = process_call(meta) process_meta = process_call(meta)
data_points = get_data_points(process_meta['utter_events'], process_meta['first_event_fn']) data_points = get_data_points(
process_meta['data_points'] = data_points process_meta["utter_events"], process_meta["first_event_fn"]
)
process_meta["data_points"] = data_points
return {"url": uri, "meta": meta, "duration": duration, "process": process_meta} return {"url": uri, "meta": meta, "duration": duration, "process": process_meta}
def download_meta_audio(): def download_meta_audio():
@@ -355,7 +342,7 @@ def analyze(
for dp in gen_data_values(saved_wav_path, data_points): for dp in gen_data_values(saved_wav_path, data_points):
yield dp yield dp
asr_data_writer(call_asr_data, "call_alphanum", data_source()) asr_data_writer(call_asr_data, dataset_name, data_source())
def show_leaderboard(): def show_leaderboard():
def compute_user_stats(call_stat): def compute_user_stats(call_stat):
@@ -383,14 +370,14 @@ def analyze(
leader_board = leader_df.rename( leader_board = leader_df.rename(
columns={ columns={
"rank": "Rank", "rank": "Rank",
"num_samples": "Codes", "num_samples": "Count",
"name": "Name", "name": "Name",
"samples_rate": "SpeechRate", "samples_rate": "SpeechRate",
"duration_str": "Duration", "duration_str": "Duration",
} }
)[["Rank", "Name", "Codes", "Duration"]] )[["Rank", "Name", "Count", "Duration"]]
print( print(
"""ASR Speller Dataset Leaderboard : """ASR Dataset Leaderboard :
---------------------------------""" ---------------------------------"""
) )
print(leader_board.to_string(index=False)) print(leader_board.to_string(index=False))

View File

@@ -104,10 +104,16 @@ class ExtendedPath(type(Path())):
return json.dump(data, jf, indent=2) return json.dump(data, jf, indent=2)
def get_mongo_conn(host="", port=27017): def get_mongo_coll(uri="mongodb://localhost:27017/test.calls"):
ud = pymongo.uri_parser.parse_uri(uri)
conn = pymongo.MongoClient(uri)
return conn[ud['database']][ud['collection']]
def get_mongo_conn(host="", port=27017, db="test", col="calls"):
mongo_host = host if host else os.environ.get("MONGO_HOST", "localhost") mongo_host = host if host else os.environ.get("MONGO_HOST", "localhost")
mongo_uri = f"mongodb://{mongo_host}:{port}/" mongo_uri = f"mongodb://{mongo_host}:{port}/"
return pymongo.MongoClient(mongo_uri) return pymongo.MongoClient(mongo_uri)[db][col]
def strip_silence(sound): def strip_silence(sound):

View File

@@ -24,6 +24,7 @@ def preprocess_datapoint(
import librosa.display import librosa.display
from pydub import AudioSegment from pydub import AudioSegment
from nemo.collections.asr.metrics import word_error_rate from nemo.collections.asr.metrics import word_error_rate
from jasper.client import transcribe_gen
try: try:
res = dict(sample) res = dict(sample)
@@ -36,7 +37,7 @@ def preprocess_datapoint(
res["spoken"] = res["text"] res["spoken"] = res["text"]
res["utterance_id"] = audio_path.stem res["utterance_id"] = audio_path.stem
if not annotation_only: if not annotation_only:
from jasper.client import transcriber_pretrained, transcriber_speller transcriber_pretrained = transcribe_gen(asr_port=8044)
aud_seg = ( aud_seg = (
AudioSegment.from_file_using_temporary_files(audio_path) AudioSegment.from_file_using_temporary_files(audio_path)
@@ -49,6 +50,7 @@ def preprocess_datapoint(
[res["text"]], [res["pretrained_asr"]] [res["text"]], [res["pretrained_asr"]]
) )
if use_domain_asr: if use_domain_asr:
transcriber_speller = transcribe_gen(asr_port=8045)
res["domain_asr"] = transcriber_speller(aud_seg.raw_data) res["domain_asr"] = transcriber_speller(aud_seg.raw_data)
res["domain_wer"] = word_error_rate( res["domain_wer"] = word_error_rate(
[res["spoken"]], [res["pretrained_asr"]] [res["spoken"]], [res["pretrained_asr"]]
@@ -74,19 +76,19 @@ def preprocess_datapoint(
@app.command() @app.command()
def dump_validation_ui_data( def dump_validation_ui_data(
data_manifest_path: Path = typer.Option( dataset_path: Path = typer.Option(
Path("./data/asr_data/call_alphanum/manifest.json"), show_default=True Path("./data/asr_data/call_alphanum"), show_default=True
), ),
dump_path: Path = typer.Option( dump_name: str = typer.Option("ui_dump.json", show_default=True),
Path("./data/valiation_data/ui_dump.json"), show_default=True use_domain_asr: bool = False,
), annotation_only: bool = False,
use_domain_asr: bool = True,
annotation_only: bool = True,
enable_plots: bool = True, enable_plots: bool = True,
): ):
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from functools import partial from functools import partial
data_manifest_path = dataset_path / Path("manifest.json")
dump_path: Path = Path(f"./data/valiation_data/{dataset_path.stem}/{dump_name}")
plot_dir = data_manifest_path.parent / Path("wav_plots") plot_dir = data_manifest_path.parent / Path("wav_plots")
plot_dir.mkdir(parents=True, exist_ok=True) plot_dir.mkdir(parents=True, exist_ok=True)
typer.echo(f"Using data manifest:{data_manifest_path}") typer.echo(f"Using data manifest:{data_manifest_path}")
@@ -137,7 +139,7 @@ def dump_validation_ui_data(
@app.command() @app.command()
def dump_corrections(dump_path: Path = Path("./data/valiation_data/corrections.json")): def dump_corrections(dump_path: Path = Path("./data/valiation_data/corrections.json")):
col = get_mongo_conn().test.asr_validation col = get_mongo_conn(col='asr_validation')
cursor_obj = col.find({"type": "correction"}, projection={"_id": False}) cursor_obj = col.find({"type": "correction"}, projection={"_id": False})
corrections = [c for c in cursor_obj] corrections = [c for c in cursor_obj]
@@ -154,7 +156,7 @@ def fill_unannotated(
annotated_codes = {c["code"] for c in corrections} annotated_codes = {c["code"] for c in corrections}
all_codes = {c["gold_chars"] for c in processed_data} all_codes = {c["gold_chars"] for c in processed_data}
unann_codes = all_codes - annotated_codes unann_codes = all_codes - annotated_codes
mongo_conn = get_mongo_conn().test.asr_validation mongo_conn = get_mongo_conn(col='asr_validation')
for c in unann_codes: for c in unann_codes:
mongo_conn.find_one_and_update( mongo_conn.find_one_and_update(
{"type": "correction", "code": c}, {"type": "correction", "code": c},
@@ -232,7 +234,7 @@ def update_corrections(
def clear_mongo_corrections(): def clear_mongo_corrections():
delete = typer.confirm("are you sure you want to clear mongo collection it?") delete = typer.confirm("are you sure you want to clear mongo collection it?")
if delete: if delete:
col = get_mongo_conn().test.asr_validation col = get_mongo_conn(col='asr_validation')
col.delete_many({"type": "correction"}) col.delete_many({"type": "correction"})
typer.echo("deleted mongo collection.") typer.echo("deleted mongo collection.")
typer.echo("Aborted") typer.echo("Aborted")

View File

@@ -9,7 +9,7 @@ app = typer.Typer()
if not hasattr(st, "mongo_connected"): if not hasattr(st, "mongo_connected"):
st.mongoclient = get_mongo_conn().test.asr_validation st.mongoclient = get_mongo_conn(col='asr_validation')
mongo_conn = st.mongoclient mongo_conn = st.mongoclient
def current_cursor_fn(): def current_cursor_fn():
@@ -111,9 +111,8 @@ def main(manifest: Path):
if selected == "Inaudible": if selected == "Inaudible":
corrected = "" corrected = ""
if st.button("Submit"): if st.button("Submit"):
correct_code = corrected.replace(" ", "").upper()
st.update_entry( st.update_entry(
sample["utterance_id"], {"status": selected, "correction": correct_code} sample["utterance_id"], {"status": selected, "correction": corrected}
) )
st.update_cursor(sample_no + 1) st.update_cursor(sample_no + 1)
if correction_entry: if correction_entry: