1. removed the transcriber_pretrained/speller from utils
2. introduced get_mongo_coll to get the collection object directly from mongo uri 3. removed processing of correction entries to remove space/upper casing
parent
e3a01169c2
commit
bca227a7d7
|
|
@ -1,6 +1,7 @@
|
||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
import rpyc
|
import rpyc
|
||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||||
|
|
@ -12,12 +13,9 @@ ASR_HOST = os.environ.get("JASPER_ASR_RPYC_HOST", "localhost")
|
||||||
ASR_PORT = int(os.environ.get("JASPER_ASR_RPYC_PORT", "8045"))
|
ASR_PORT = int(os.environ.get("JASPER_ASR_RPYC_PORT", "8045"))
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
def transcribe_gen(asr_host=ASR_HOST, asr_port=ASR_PORT):
|
def transcribe_gen(asr_host=ASR_HOST, asr_port=ASR_PORT):
|
||||||
logger.info(f"connecting to asr server at {asr_host}:{asr_port}")
|
logger.info(f"connecting to asr server at {asr_host}:{asr_port}")
|
||||||
asr = rpyc.connect(asr_host, asr_port).root
|
asr = rpyc.connect(asr_host, asr_port).root
|
||||||
logger.info(f"connected to asr server successfully")
|
logger.info(f"connected to asr server successfully")
|
||||||
return asr.transcribe
|
return asr.transcribe
|
||||||
|
|
||||||
|
|
||||||
transcriber_pretrained = transcribe_gen(asr_port=8044)
|
|
||||||
transcriber_speller = transcribe_gen(asr_port=8045)
|
|
||||||
|
|
|
||||||
|
|
@ -1,15 +1,7 @@
|
||||||
# import argparse
|
|
||||||
|
|
||||||
# import logging
|
|
||||||
import typer
|
import typer
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
app = typer.Typer()
|
app = typer.Typer()
|
||||||
# leader_app = typer.Typer()
|
|
||||||
# app.add_typer(leader_app, name="leaderboard")
|
|
||||||
# plot_app = typer.Typer()
|
|
||||||
# app.add_typer(plot_app, name="plot")
|
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
@app.command()
|
||||||
def export_all_logs(call_logs_file: Path = Path("./call_sia_logs.yaml")):
|
def export_all_logs(call_logs_file: Path = Path("./call_sia_logs.yaml")):
|
||||||
|
|
@ -18,7 +10,7 @@ def export_all_logs(call_logs_file: Path = Path("./call_sia_logs.yaml")):
|
||||||
from ruamel.yaml import YAML
|
from ruamel.yaml import YAML
|
||||||
|
|
||||||
yaml = YAML()
|
yaml = YAML()
|
||||||
mongo_coll = get_mongo_conn().test.calls
|
mongo_coll = get_mongo_conn()
|
||||||
caller_calls = defaultdict(lambda: [])
|
caller_calls = defaultdict(lambda: [])
|
||||||
for call in mongo_coll.find():
|
for call in mongo_coll.find():
|
||||||
sysid = call["SystemID"]
|
sysid = call["SystemID"]
|
||||||
|
|
@ -46,7 +38,7 @@ def export_calls_between(
|
||||||
from .utils import get_mongo_conn
|
from .utils import get_mongo_conn
|
||||||
|
|
||||||
yaml = YAML()
|
yaml = YAML()
|
||||||
mongo_coll = get_mongo_conn(port=mongo_port).test.calls
|
mongo_coll = get_mongo_conn(port=mongo_port)
|
||||||
start_meta = mongo_coll.find_one({"SystemID": start_cid})
|
start_meta = mongo_coll.find_one({"SystemID": start_cid})
|
||||||
end_meta = mongo_coll.find_one({"SystemID": end_cid})
|
end_meta = mongo_coll.find_one({"SystemID": end_cid})
|
||||||
|
|
||||||
|
|
@ -77,23 +69,21 @@ def analyze(
|
||||||
plot_calls: bool = False,
|
plot_calls: bool = False,
|
||||||
extract_data: bool = False,
|
extract_data: bool = False,
|
||||||
download_only: bool = False,
|
download_only: bool = False,
|
||||||
call_logs_file: Path = Path("./call_logs.yaml"),
|
call_logs_file: Path = typer.Option(Path("./call_logs.yaml"), show_default=True),
|
||||||
output_dir: Path = Path("./data"),
|
output_dir: Path = Path("./data"),
|
||||||
mongo_port: int = 27017,
|
data_name: str = None,
|
||||||
|
mongo_uri: str = typer.Option("mongodb://localhost:27017/test.calls", show_default=True),
|
||||||
):
|
):
|
||||||
|
|
||||||
from urllib.parse import urlsplit
|
from urllib.parse import urlsplit
|
||||||
from functools import reduce
|
from functools import reduce
|
||||||
import boto3
|
import boto3
|
||||||
|
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
import json
|
import json
|
||||||
from ruamel.yaml import YAML
|
from ruamel.yaml import YAML
|
||||||
import re
|
import re
|
||||||
from google.protobuf.timestamp_pb2 import Timestamp
|
from google.protobuf.timestamp_pb2 import Timestamp
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
|
|
||||||
# from concurrent.futures import ThreadPoolExecutor
|
|
||||||
import librosa
|
import librosa
|
||||||
import librosa.display
|
import librosa.display
|
||||||
from lenses import lens
|
from lenses import lens
|
||||||
|
|
@ -102,23 +92,17 @@ def analyze(
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import matplotlib
|
import matplotlib
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from .utils import asr_data_writer, get_mongo_conn
|
from .utils import asr_data_writer, get_mongo_coll
|
||||||
from pydub import AudioSegment
|
from pydub import AudioSegment
|
||||||
from natural.date import compress
|
from natural.date import compress
|
||||||
|
|
||||||
# from itertools import product, chain
|
|
||||||
|
|
||||||
matplotlib.rcParams["agg.path.chunksize"] = 10000
|
matplotlib.rcParams["agg.path.chunksize"] = 10000
|
||||||
|
|
||||||
matplotlib.use("agg")
|
matplotlib.use("agg")
|
||||||
|
|
||||||
# logging.basicConfig(
|
|
||||||
# level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
|
||||||
# )
|
|
||||||
# logger = logging.getLogger(__name__)
|
|
||||||
yaml = YAML()
|
yaml = YAML()
|
||||||
s3 = boto3.client("s3")
|
s3 = boto3.client("s3")
|
||||||
mongo_collection = get_mongo_conn(port=mongo_port).test.calls
|
mongo_collection = get_mongo_coll(mongo_uri)
|
||||||
call_media_dir: Path = output_dir / Path("call_wavs")
|
call_media_dir: Path = output_dir / Path("call_wavs")
|
||||||
call_media_dir.mkdir(exist_ok=True, parents=True)
|
call_media_dir.mkdir(exist_ok=True, parents=True)
|
||||||
call_meta_dir: Path = output_dir / Path("call_metas")
|
call_meta_dir: Path = output_dir / Path("call_metas")
|
||||||
|
|
@ -127,6 +111,7 @@ def analyze(
|
||||||
call_plot_dir.mkdir(exist_ok=True, parents=True)
|
call_plot_dir.mkdir(exist_ok=True, parents=True)
|
||||||
call_asr_data: Path = output_dir / Path("asr_data")
|
call_asr_data: Path = output_dir / Path("asr_data")
|
||||||
call_asr_data.mkdir(exist_ok=True, parents=True)
|
call_asr_data.mkdir(exist_ok=True, parents=True)
|
||||||
|
dataset_name = call_logs_file.stem if not data_name else data_name
|
||||||
|
|
||||||
call_logs = yaml.load(call_logs_file.read_text())
|
call_logs = yaml.load(call_logs_file.read_text())
|
||||||
|
|
||||||
|
|
@ -183,7 +168,7 @@ def analyze(
|
||||||
call_events = call_meta["Events"]
|
call_events = call_meta["Events"]
|
||||||
|
|
||||||
def is_writer_uri_event(ev):
|
def is_writer_uri_event(ev):
|
||||||
return ev["Author"] == "AUDIO_WRITER" and 's3://' in ev["Msg"]
|
return ev["Author"] == "AUDIO_WRITER" and "s3://" in ev["Msg"]
|
||||||
|
|
||||||
writer_events = list(filter(is_writer_uri_event, call_events))
|
writer_events = list(filter(is_writer_uri_event, call_events))
|
||||||
s3_wav_url = re.search(r"(s3://.*)", writer_events[0]["Msg"]).groups(0)[0]
|
s3_wav_url = re.search(r"(s3://.*)", writer_events[0]["Msg"]).groups(0)[0]
|
||||||
|
|
@ -268,8 +253,10 @@ def analyze(
|
||||||
meta = mongo_collection.find_one({"SystemID": cid})
|
meta = mongo_collection.find_one({"SystemID": cid})
|
||||||
duration = meta["EndTS"] - meta["StartTS"]
|
duration = meta["EndTS"] - meta["StartTS"]
|
||||||
process_meta = process_call(meta)
|
process_meta = process_call(meta)
|
||||||
data_points = get_data_points(process_meta['utter_events'], process_meta['first_event_fn'])
|
data_points = get_data_points(
|
||||||
process_meta['data_points'] = data_points
|
process_meta["utter_events"], process_meta["first_event_fn"]
|
||||||
|
)
|
||||||
|
process_meta["data_points"] = data_points
|
||||||
return {"url": uri, "meta": meta, "duration": duration, "process": process_meta}
|
return {"url": uri, "meta": meta, "duration": duration, "process": process_meta}
|
||||||
|
|
||||||
def download_meta_audio():
|
def download_meta_audio():
|
||||||
|
|
@ -355,7 +342,7 @@ def analyze(
|
||||||
for dp in gen_data_values(saved_wav_path, data_points):
|
for dp in gen_data_values(saved_wav_path, data_points):
|
||||||
yield dp
|
yield dp
|
||||||
|
|
||||||
asr_data_writer(call_asr_data, "call_alphanum", data_source())
|
asr_data_writer(call_asr_data, dataset_name, data_source())
|
||||||
|
|
||||||
def show_leaderboard():
|
def show_leaderboard():
|
||||||
def compute_user_stats(call_stat):
|
def compute_user_stats(call_stat):
|
||||||
|
|
@ -383,14 +370,14 @@ def analyze(
|
||||||
leader_board = leader_df.rename(
|
leader_board = leader_df.rename(
|
||||||
columns={
|
columns={
|
||||||
"rank": "Rank",
|
"rank": "Rank",
|
||||||
"num_samples": "Codes",
|
"num_samples": "Count",
|
||||||
"name": "Name",
|
"name": "Name",
|
||||||
"samples_rate": "SpeechRate",
|
"samples_rate": "SpeechRate",
|
||||||
"duration_str": "Duration",
|
"duration_str": "Duration",
|
||||||
}
|
}
|
||||||
)[["Rank", "Name", "Codes", "Duration"]]
|
)[["Rank", "Name", "Count", "Duration"]]
|
||||||
print(
|
print(
|
||||||
"""ASR Speller Dataset Leaderboard :
|
"""ASR Dataset Leaderboard :
|
||||||
---------------------------------"""
|
---------------------------------"""
|
||||||
)
|
)
|
||||||
print(leader_board.to_string(index=False))
|
print(leader_board.to_string(index=False))
|
||||||
|
|
|
||||||
|
|
@ -104,10 +104,16 @@ class ExtendedPath(type(Path())):
|
||||||
return json.dump(data, jf, indent=2)
|
return json.dump(data, jf, indent=2)
|
||||||
|
|
||||||
|
|
||||||
def get_mongo_conn(host="", port=27017):
|
def get_mongo_coll(uri="mongodb://localhost:27017/test.calls"):
|
||||||
|
ud = pymongo.uri_parser.parse_uri(uri)
|
||||||
|
conn = pymongo.MongoClient(uri)
|
||||||
|
return conn[ud['database']][ud['collection']]
|
||||||
|
|
||||||
|
|
||||||
|
def get_mongo_conn(host="", port=27017, db="test", col="calls"):
|
||||||
mongo_host = host if host else os.environ.get("MONGO_HOST", "localhost")
|
mongo_host = host if host else os.environ.get("MONGO_HOST", "localhost")
|
||||||
mongo_uri = f"mongodb://{mongo_host}:{port}/"
|
mongo_uri = f"mongodb://{mongo_host}:{port}/"
|
||||||
return pymongo.MongoClient(mongo_uri)
|
return pymongo.MongoClient(mongo_uri)[db][col]
|
||||||
|
|
||||||
|
|
||||||
def strip_silence(sound):
|
def strip_silence(sound):
|
||||||
|
|
|
||||||
|
|
@ -24,6 +24,7 @@ def preprocess_datapoint(
|
||||||
import librosa.display
|
import librosa.display
|
||||||
from pydub import AudioSegment
|
from pydub import AudioSegment
|
||||||
from nemo.collections.asr.metrics import word_error_rate
|
from nemo.collections.asr.metrics import word_error_rate
|
||||||
|
from jasper.client import transcribe_gen
|
||||||
|
|
||||||
try:
|
try:
|
||||||
res = dict(sample)
|
res = dict(sample)
|
||||||
|
|
@ -36,7 +37,7 @@ def preprocess_datapoint(
|
||||||
res["spoken"] = res["text"]
|
res["spoken"] = res["text"]
|
||||||
res["utterance_id"] = audio_path.stem
|
res["utterance_id"] = audio_path.stem
|
||||||
if not annotation_only:
|
if not annotation_only:
|
||||||
from jasper.client import transcriber_pretrained, transcriber_speller
|
transcriber_pretrained = transcribe_gen(asr_port=8044)
|
||||||
|
|
||||||
aud_seg = (
|
aud_seg = (
|
||||||
AudioSegment.from_file_using_temporary_files(audio_path)
|
AudioSegment.from_file_using_temporary_files(audio_path)
|
||||||
|
|
@ -49,6 +50,7 @@ def preprocess_datapoint(
|
||||||
[res["text"]], [res["pretrained_asr"]]
|
[res["text"]], [res["pretrained_asr"]]
|
||||||
)
|
)
|
||||||
if use_domain_asr:
|
if use_domain_asr:
|
||||||
|
transcriber_speller = transcribe_gen(asr_port=8045)
|
||||||
res["domain_asr"] = transcriber_speller(aud_seg.raw_data)
|
res["domain_asr"] = transcriber_speller(aud_seg.raw_data)
|
||||||
res["domain_wer"] = word_error_rate(
|
res["domain_wer"] = word_error_rate(
|
||||||
[res["spoken"]], [res["pretrained_asr"]]
|
[res["spoken"]], [res["pretrained_asr"]]
|
||||||
|
|
@ -74,19 +76,19 @@ def preprocess_datapoint(
|
||||||
|
|
||||||
@app.command()
|
@app.command()
|
||||||
def dump_validation_ui_data(
|
def dump_validation_ui_data(
|
||||||
data_manifest_path: Path = typer.Option(
|
dataset_path: Path = typer.Option(
|
||||||
Path("./data/asr_data/call_alphanum/manifest.json"), show_default=True
|
Path("./data/asr_data/call_alphanum"), show_default=True
|
||||||
),
|
),
|
||||||
dump_path: Path = typer.Option(
|
dump_name: str = typer.Option("ui_dump.json", show_default=True),
|
||||||
Path("./data/valiation_data/ui_dump.json"), show_default=True
|
use_domain_asr: bool = False,
|
||||||
),
|
annotation_only: bool = False,
|
||||||
use_domain_asr: bool = True,
|
|
||||||
annotation_only: bool = True,
|
|
||||||
enable_plots: bool = True,
|
enable_plots: bool = True,
|
||||||
):
|
):
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
|
data_manifest_path = dataset_path / Path("manifest.json")
|
||||||
|
dump_path: Path = Path(f"./data/valiation_data/{dataset_path.stem}/{dump_name}")
|
||||||
plot_dir = data_manifest_path.parent / Path("wav_plots")
|
plot_dir = data_manifest_path.parent / Path("wav_plots")
|
||||||
plot_dir.mkdir(parents=True, exist_ok=True)
|
plot_dir.mkdir(parents=True, exist_ok=True)
|
||||||
typer.echo(f"Using data manifest:{data_manifest_path}")
|
typer.echo(f"Using data manifest:{data_manifest_path}")
|
||||||
|
|
@ -137,7 +139,7 @@ def dump_validation_ui_data(
|
||||||
|
|
||||||
@app.command()
|
@app.command()
|
||||||
def dump_corrections(dump_path: Path = Path("./data/valiation_data/corrections.json")):
|
def dump_corrections(dump_path: Path = Path("./data/valiation_data/corrections.json")):
|
||||||
col = get_mongo_conn().test.asr_validation
|
col = get_mongo_conn(col='asr_validation')
|
||||||
|
|
||||||
cursor_obj = col.find({"type": "correction"}, projection={"_id": False})
|
cursor_obj = col.find({"type": "correction"}, projection={"_id": False})
|
||||||
corrections = [c for c in cursor_obj]
|
corrections = [c for c in cursor_obj]
|
||||||
|
|
@ -154,7 +156,7 @@ def fill_unannotated(
|
||||||
annotated_codes = {c["code"] for c in corrections}
|
annotated_codes = {c["code"] for c in corrections}
|
||||||
all_codes = {c["gold_chars"] for c in processed_data}
|
all_codes = {c["gold_chars"] for c in processed_data}
|
||||||
unann_codes = all_codes - annotated_codes
|
unann_codes = all_codes - annotated_codes
|
||||||
mongo_conn = get_mongo_conn().test.asr_validation
|
mongo_conn = get_mongo_conn(col='asr_validation')
|
||||||
for c in unann_codes:
|
for c in unann_codes:
|
||||||
mongo_conn.find_one_and_update(
|
mongo_conn.find_one_and_update(
|
||||||
{"type": "correction", "code": c},
|
{"type": "correction", "code": c},
|
||||||
|
|
@ -232,7 +234,7 @@ def update_corrections(
|
||||||
def clear_mongo_corrections():
|
def clear_mongo_corrections():
|
||||||
delete = typer.confirm("are you sure you want to clear mongo collection it?")
|
delete = typer.confirm("are you sure you want to clear mongo collection it?")
|
||||||
if delete:
|
if delete:
|
||||||
col = get_mongo_conn().test.asr_validation
|
col = get_mongo_conn(col='asr_validation')
|
||||||
col.delete_many({"type": "correction"})
|
col.delete_many({"type": "correction"})
|
||||||
typer.echo("deleted mongo collection.")
|
typer.echo("deleted mongo collection.")
|
||||||
typer.echo("Aborted")
|
typer.echo("Aborted")
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,7 @@ app = typer.Typer()
|
||||||
|
|
||||||
|
|
||||||
if not hasattr(st, "mongo_connected"):
|
if not hasattr(st, "mongo_connected"):
|
||||||
st.mongoclient = get_mongo_conn().test.asr_validation
|
st.mongoclient = get_mongo_conn(col='asr_validation')
|
||||||
mongo_conn = st.mongoclient
|
mongo_conn = st.mongoclient
|
||||||
|
|
||||||
def current_cursor_fn():
|
def current_cursor_fn():
|
||||||
|
|
@ -111,9 +111,8 @@ def main(manifest: Path):
|
||||||
if selected == "Inaudible":
|
if selected == "Inaudible":
|
||||||
corrected = ""
|
corrected = ""
|
||||||
if st.button("Submit"):
|
if st.button("Submit"):
|
||||||
correct_code = corrected.replace(" ", "").upper()
|
|
||||||
st.update_entry(
|
st.update_entry(
|
||||||
sample["utterance_id"], {"status": selected, "correction": correct_code}
|
sample["utterance_id"], {"status": selected, "correction": corrected}
|
||||||
)
|
)
|
||||||
st.update_cursor(sample_no + 1)
|
st.update_cursor(sample_no + 1)
|
||||||
if correction_entry:
|
if correction_entry:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue