From bca227a7d767a022b2e76022f42010fec093478b Mon Sep 17 00:00:00 2001 From: Malar Kannan Date: Thu, 4 Jun 2020 17:49:16 +0530 Subject: [PATCH] 1. removed the module-level transcriber_pretrained/transcriber_speller from jasper/client.py 2. introduced get_mongo_coll to get the collection object directly from a MongoDB URI 3. stopped stripping spaces from and upper-casing correction entries --- jasper/client.py | 6 ++-- jasper/data/call_recycler.py | 47 +++++++++++-------------------- jasper/data/utils.py | 10 +++++-- jasper/data/validation/process.py | 24 ++++++++-------- jasper/data/validation/ui.py | 5 ++-- 5 files changed, 42 insertions(+), 50 deletions(-) diff --git a/jasper/client.py b/jasper/client.py index 84fd465..6c474a5 100644 --- a/jasper/client.py +++ b/jasper/client.py @@ -1,6 +1,7 @@ import os import logging import rpyc +from functools import lru_cache logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" @@ -12,12 +13,9 @@ ASR_HOST = os.environ.get("JASPER_ASR_RPYC_HOST", "localhost") ASR_PORT = int(os.environ.get("JASPER_ASR_RPYC_PORT", "8045")) +@lru_cache() def transcribe_gen(asr_host=ASR_HOST, asr_port=ASR_PORT): logger.info(f"connecting to asr server at {asr_host}:{asr_port}") asr = rpyc.connect(asr_host, asr_port).root logger.info(f"connected to asr server successfully") return asr.transcribe - - -transcriber_pretrained = transcribe_gen(asr_port=8044) -transcriber_speller = transcribe_gen(asr_port=8045) diff --git a/jasper/data/call_recycler.py b/jasper/data/call_recycler.py index c7ccde6..93cb023 100644 --- a/jasper/data/call_recycler.py +++ b/jasper/data/call_recycler.py @@ -1,15 +1,7 @@ -# import argparse - -# import logging import typer from pathlib import Path app = typer.Typer() -# leader_app = typer.Typer() -# app.add_typer(leader_app, name="leaderboard") -# plot_app = typer.Typer() -# app.add_typer(plot_app, name="plot") - @app.command() def export_all_logs(call_logs_file: Path = Path("./call_sia_logs.yaml")): @@ -18,7 +10,7 
@@ def export_all_logs(call_logs_file: Path = Path("./call_sia_logs.yaml")): from ruamel.yaml import YAML yaml = YAML() - mongo_coll = get_mongo_conn().test.calls + mongo_coll = get_mongo_conn() caller_calls = defaultdict(lambda: []) for call in mongo_coll.find(): sysid = call["SystemID"] @@ -46,7 +38,7 @@ def export_calls_between( from .utils import get_mongo_conn yaml = YAML() - mongo_coll = get_mongo_conn(port=mongo_port).test.calls + mongo_coll = get_mongo_conn(port=mongo_port) start_meta = mongo_coll.find_one({"SystemID": start_cid}) end_meta = mongo_coll.find_one({"SystemID": end_cid}) @@ -77,23 +69,21 @@ def analyze( plot_calls: bool = False, extract_data: bool = False, download_only: bool = False, - call_logs_file: Path = Path("./call_logs.yaml"), + call_logs_file: Path = typer.Option(Path("./call_logs.yaml"), show_default=True), output_dir: Path = Path("./data"), - mongo_port: int = 27017, + data_name: str = None, + mongo_uri: str = typer.Option("mongodb://localhost:27017/test.calls", show_default=True), ): from urllib.parse import urlsplit from functools import reduce import boto3 - from io import BytesIO import json from ruamel.yaml import YAML import re from google.protobuf.timestamp_pb2 import Timestamp from datetime import timedelta - - # from concurrent.futures import ThreadPoolExecutor import librosa import librosa.display from lenses import lens @@ -102,23 +92,17 @@ def analyze( import matplotlib.pyplot as plt import matplotlib from tqdm import tqdm - from .utils import asr_data_writer, get_mongo_conn + from .utils import asr_data_writer, get_mongo_coll from pydub import AudioSegment from natural.date import compress - # from itertools import product, chain - matplotlib.rcParams["agg.path.chunksize"] = 10000 matplotlib.use("agg") - # logging.basicConfig( - # level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" - # ) - # logger = logging.getLogger(__name__) yaml = YAML() s3 = boto3.client("s3") - mongo_collection = 
get_mongo_conn(port=mongo_port).test.calls + mongo_collection = get_mongo_coll(mongo_uri) call_media_dir: Path = output_dir / Path("call_wavs") call_media_dir.mkdir(exist_ok=True, parents=True) call_meta_dir: Path = output_dir / Path("call_metas") @@ -127,6 +111,7 @@ def analyze( call_plot_dir.mkdir(exist_ok=True, parents=True) call_asr_data: Path = output_dir / Path("asr_data") call_asr_data.mkdir(exist_ok=True, parents=True) + dataset_name = call_logs_file.stem if not data_name else data_name call_logs = yaml.load(call_logs_file.read_text()) @@ -183,7 +168,7 @@ def analyze( call_events = call_meta["Events"] def is_writer_uri_event(ev): - return ev["Author"] == "AUDIO_WRITER" and 's3://' in ev["Msg"] + return ev["Author"] == "AUDIO_WRITER" and "s3://" in ev["Msg"] writer_events = list(filter(is_writer_uri_event, call_events)) s3_wav_url = re.search(r"(s3://.*)", writer_events[0]["Msg"]).groups(0)[0] @@ -268,8 +253,10 @@ def analyze( meta = mongo_collection.find_one({"SystemID": cid}) duration = meta["EndTS"] - meta["StartTS"] process_meta = process_call(meta) - data_points = get_data_points(process_meta['utter_events'], process_meta['first_event_fn']) - process_meta['data_points'] = data_points + data_points = get_data_points( + process_meta["utter_events"], process_meta["first_event_fn"] + ) + process_meta["data_points"] = data_points return {"url": uri, "meta": meta, "duration": duration, "process": process_meta} def download_meta_audio(): @@ -355,7 +342,7 @@ def analyze( for dp in gen_data_values(saved_wav_path, data_points): yield dp - asr_data_writer(call_asr_data, "call_alphanum", data_source()) + asr_data_writer(call_asr_data, dataset_name, data_source()) def show_leaderboard(): def compute_user_stats(call_stat): @@ -383,14 +370,14 @@ def analyze( leader_board = leader_df.rename( columns={ "rank": "Rank", - "num_samples": "Codes", + "num_samples": "Count", "name": "Name", "samples_rate": "SpeechRate", "duration_str": "Duration", } - )[["Rank", "Name", 
"Codes", "Duration"]] + )[["Rank", "Name", "Count", "Duration"]] print( - """ASR Speller Dataset Leaderboard : + """ASR Dataset Leaderboard : ---------------------------------""" ) print(leader_board.to_string(index=False)) diff --git a/jasper/data/utils.py b/jasper/data/utils.py index 76ba597..1f5f5b1 100644 --- a/jasper/data/utils.py +++ b/jasper/data/utils.py @@ -104,10 +104,16 @@ class ExtendedPath(type(Path())): return json.dump(data, jf, indent=2) -def get_mongo_conn(host="", port=27017): +def get_mongo_coll(uri="mongodb://localhost:27017/test.calls"): + ud = pymongo.uri_parser.parse_uri(uri) + conn = pymongo.MongoClient(uri) + return conn[ud['database']][ud['collection']] + + +def get_mongo_conn(host="", port=27017, db="test", col="calls"): mongo_host = host if host else os.environ.get("MONGO_HOST", "localhost") mongo_uri = f"mongodb://{mongo_host}:{port}/" - return pymongo.MongoClient(mongo_uri) + return pymongo.MongoClient(mongo_uri)[db][col] def strip_silence(sound): diff --git a/jasper/data/validation/process.py b/jasper/data/validation/process.py index 832e6a8..217dffc 100644 --- a/jasper/data/validation/process.py +++ b/jasper/data/validation/process.py @@ -24,6 +24,7 @@ def preprocess_datapoint( import librosa.display from pydub import AudioSegment from nemo.collections.asr.metrics import word_error_rate + from jasper.client import transcribe_gen try: res = dict(sample) @@ -36,7 +37,7 @@ def preprocess_datapoint( res["spoken"] = res["text"] res["utterance_id"] = audio_path.stem if not annotation_only: - from jasper.client import transcriber_pretrained, transcriber_speller + transcriber_pretrained = transcribe_gen(asr_port=8044) aud_seg = ( AudioSegment.from_file_using_temporary_files(audio_path) @@ -49,6 +50,7 @@ def preprocess_datapoint( [res["text"]], [res["pretrained_asr"]] ) if use_domain_asr: + transcriber_speller = transcribe_gen(asr_port=8045) res["domain_asr"] = transcriber_speller(aud_seg.raw_data) res["domain_wer"] = word_error_rate( 
[res["spoken"]], [res["pretrained_asr"]] @@ -74,19 +76,19 @@ def preprocess_datapoint( @app.command() def dump_validation_ui_data( - data_manifest_path: Path = typer.Option( - Path("./data/asr_data/call_alphanum/manifest.json"), show_default=True + dataset_path: Path = typer.Option( + Path("./data/asr_data/call_alphanum"), show_default=True ), - dump_path: Path = typer.Option( - Path("./data/valiation_data/ui_dump.json"), show_default=True - ), - use_domain_asr: bool = True, - annotation_only: bool = True, + dump_name: str = typer.Option("ui_dump.json", show_default=True), + use_domain_asr: bool = False, + annotation_only: bool = False, enable_plots: bool = True, ): from concurrent.futures import ThreadPoolExecutor from functools import partial + data_manifest_path = dataset_path / Path("manifest.json") + dump_path: Path = Path(f"./data/valiation_data/{dataset_path.stem}/{dump_name}") plot_dir = data_manifest_path.parent / Path("wav_plots") plot_dir.mkdir(parents=True, exist_ok=True) typer.echo(f"Using data manifest:{data_manifest_path}") @@ -137,7 +139,7 @@ def dump_validation_ui_data( @app.command() def dump_corrections(dump_path: Path = Path("./data/valiation_data/corrections.json")): - col = get_mongo_conn().test.asr_validation + col = get_mongo_conn(col='asr_validation') cursor_obj = col.find({"type": "correction"}, projection={"_id": False}) corrections = [c for c in cursor_obj] @@ -154,7 +156,7 @@ def fill_unannotated( annotated_codes = {c["code"] for c in corrections} all_codes = {c["gold_chars"] for c in processed_data} unann_codes = all_codes - annotated_codes - mongo_conn = get_mongo_conn().test.asr_validation + mongo_conn = get_mongo_conn(col='asr_validation') for c in unann_codes: mongo_conn.find_one_and_update( {"type": "correction", "code": c}, @@ -232,7 +234,7 @@ def update_corrections( def clear_mongo_corrections(): delete = typer.confirm("are you sure you want to clear mongo collection it?") if delete: - col = get_mongo_conn().test.asr_validation 
+ col = get_mongo_conn(col='asr_validation') col.delete_many({"type": "correction"}) typer.echo("deleted mongo collection.") typer.echo("Aborted") diff --git a/jasper/data/validation/ui.py b/jasper/data/validation/ui.py index f013677..0179cdd 100644 --- a/jasper/data/validation/ui.py +++ b/jasper/data/validation/ui.py @@ -9,7 +9,7 @@ app = typer.Typer() if not hasattr(st, "mongo_connected"): - st.mongoclient = get_mongo_conn().test.asr_validation + st.mongoclient = get_mongo_conn(col='asr_validation') mongo_conn = st.mongoclient def current_cursor_fn(): @@ -111,9 +111,8 @@ def main(manifest: Path): if selected == "Inaudible": corrected = "" if st.button("Submit"): - correct_code = corrected.replace(" ", "").upper() st.update_entry( - sample["utterance_id"], {"status": selected, "correction": correct_code} + sample["utterance_id"], {"status": selected, "correction": corrected} ) st.update_cursor(sample_no + 1) if correction_entry: