mirror of
https://github.com/malarinv/plume-asr.git
synced 2026-03-08 04:12:35 +00:00
massive refactor/rename to plume
This commit is contained in:
64
plume/ui/__init__.py
Normal file
64
plume/ui/__init__.py
Normal file
@@ -0,0 +1,64 @@
|
||||
import typer
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from plume.utils import lazy_module
|
||||
# Streamlit's CLI module is obtained through plume's `lazy_module` helper
# rather than a direct `from streamlit import cli as stcli`
# (presumably so the import is deferred until first use - confirm in
# plume.utils).
stcli = lazy_module("streamlit.cli")

# Typer application that the UI sub-commands below register themselves on.
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
def annotation(data_dir: Path, dump_fname: Path = "ui_dump.json", task_id: str = ""):
    """Launch the Streamlit annotation UI on the dump found in DATA_DIR.

    The process argv is rewritten to a `streamlit run annotation.py -- ...`
    invocation and control is handed to Streamlit's CLI; this function does
    not return.
    """
    annotation_lit_path = Path(__file__).parent / "annotation.py"

    # Arguments placed after "--" are meant for the annotation script itself.
    script_args = [str(data_dir)]
    if task_id:
        script_args += ["--task-id", task_id]
    # argv entries must be plain strings; dump_fname may arrive as a Path,
    # so it is stringified just like data_dir.
    script_args += ["--dump-fname", str(dump_fname)]

    sys.argv = ["streamlit", "run", str(annotation_lit_path), "--", *script_args]
    sys.exit(stcli.main())
|
||||
|
||||
|
||||
@app.command()
def preview(manifest_path: Path):
    """Launch the Streamlit preview UI for MANIFEST_PATH.

    Rewrites the process argv to `streamlit run preview.py -- <manifest>`
    and exits with the result of Streamlit's CLI entry point.
    """
    script_path = Path(__file__).parent / Path("preview.py")
    streamlit_args = ["streamlit", "run", str(script_path)]
    # Everything after "--" is intended for the preview script.
    forwarded_args = ["--", str(manifest_path)]
    sys.argv = streamlit_args + forwarded_args
    sys.exit(stcli.main())
|
||||
|
||||
|
||||
@app.command()
def collection(data_dir: Path, task_id: str = ""):
    """Placeholder for the data-collection web UI (not implemented yet)."""
    # TODO: Implement web ui for data collection
    return None
|
||||
|
||||
|
||||
def main():
    """Run the Typer application, dispatching to the requested UI command."""
    app()


# Allow the package module to be executed directly as a script.
if __name__ == "__main__":
    main()
|
||||
160
plume/ui/annotation.py
Normal file
160
plume/ui/annotation.py
Normal file
@@ -0,0 +1,160 @@
|
||||
# import sys
|
||||
from pathlib import Path
|
||||
from uuid import uuid4
|
||||
|
||||
import streamlit as st
|
||||
import typer
|
||||
|
||||
from plume.utils import ExtendedPath, get_mongo_conn
|
||||
from plume.preview.st_rerun import rerun
|
||||
|
||||
app = typer.Typer()


# One-time setup, guarded by the `mongo_connected` marker attribute on the
# `st` module: the mongo handle, the task id and the helper callbacks are all
# attached to `st` so later executions of this script can reuse them.
if not hasattr(st, "mongo_connected"):
    st.mongoclient = get_mongo_conn(col="asr_validation")
    mongo_conn = st.mongoclient
    # Start with a random task id; `set_task_fn` may replace it later.
    st.task_id = str(uuid4())

    def current_cursor_fn():
        """Return the stored cursor of the current task, or 0 if none exists."""
        cursor_obj = mongo_conn.find_one(
            {"type": "current_cursor", "task_id": st.task_id}
        )
        # `set_task_fn` can switch to a task id that has no cursor document
        # yet; start at the first sample instead of failing on a None result.
        if cursor_obj is None:
            return 0
        return cursor_obj["cursor"]

    def update_cursor_fn(val=0):
        """Store *val* as the current task's cursor (upsert), then rerun."""
        mongo_conn.find_one_and_update(
            {"type": "current_cursor", "task_id": st.task_id},
            {"$set": {"type": "current_cursor", "task_id": st.task_id, "cursor": val}},
            upsert=True,
        )
        # NOTE: rerun() is expected to raise in order to restart the script,
        # so callers should not rely on code placed after this call.
        rerun()

    def get_correction_entry_fn(code):
        """Return the correction document stored for *code* without its `_id`."""
        return mongo_conn.find_one(
            {"type": "correction", "code": code}, projection={"_id": False}
        )

    def update_entry_fn(code, value):
        """Upsert the correction *value* for *code*, tagged with the task id."""
        mongo_conn.find_one_and_update(
            {"type": "correction", "code": code},
            {"$set": {"value": value, "task_id": st.task_id}},
            upsert=True,
        )

    def set_task_fn(data_path, task_id):
        """Adopt *task_id* when given and make sure its lock file exists."""
        if task_id:
            st.task_id = task_id
        task_path = data_path / Path(f"task-{st.task_id}.lck")
        if not task_path.exists():
            print(f"creating task lock at {task_path}")
            task_path.touch()

    st.get_current_cursor = current_cursor_fn
    st.update_cursor = update_cursor_fn
    st.get_correction_entry = get_correction_entry_fn
    st.update_entry = update_entry_fn
    st.set_task = set_task_fn
    # Set the marker before the initial cursor write below, because that
    # write ends in rerun().
    st.mongo_connected = True
    cursor_obj = mongo_conn.find_one({"type": "current_cursor", "task_id": st.task_id})
    if not cursor_obj:
        update_cursor_fn(0)
|
||||
|
||||
|
||||
@st.cache()
def load_ui_data(data_dir: Path, dump_fname: Path):
    """Read the UI dump JSON at ``data_dir / dump_fname`` and return it.

    The call is wrapped in ``st.cache`` and announces the path it reads
    through ``typer.echo``.
    """
    dump_path = data_dir / dump_fname
    typer.echo(f"Using validation ui data from {dump_path}")
    dump_file = ExtendedPath(dump_path)
    return dump_file.read_json()
|
||||
|
||||
|
||||
def show_key(sample, key, trail=""):
    """Show ``sample[key]`` in the sidebar as ``"<Title>: <value><trail>"``.

    The title is *key* with underscores replaced by spaces, in title case.
    Float values are rendered with two decimals. Nothing is shown when
    *key* is missing from *sample*.

    Args:
        sample: mapping holding the values of the current sample.
        key: name of the entry to display.
        trail: suffix appended after the value (e.g. ``"%"``).
    """
    if key not in sample:
        return
    title = key.replace("_", " ").title()
    value = sample[key]
    # isinstance (rather than an exact type comparison) also accepts float
    # subclasses; the suffix is applied to every value, not only to floats.
    rendered = f"{value:.2f}" if isinstance(value, float) else f"{value}"
    st.sidebar.markdown(f"{title}: {rendered}{trail}")
|
||||
|
||||
|
||||
@app.command()
def main(data_dir: Path, dump_fname: Path = "ui_dump.json", task_id: str = ""):
    """Render the ASR annotation/validation page for the current sample.

    Loads the UI dump from DATA_DIR, shows the sample the stored cursor
    points at, and records the reviewer's verdict on submit.
    """
    st.set_task(data_dir, task_id)
    ui_config = load_ui_data(data_dir, dump_fname)
    asr_data = ui_config["data"]
    annotation_only = ui_config.get("annotation_only", False)
    asr_result_key = ui_config.get("asr_result_key", "pretrained_asr")

    sample_no = st.get_current_cursor()
    if len(asr_data) - 1 < sample_no or sample_no < 0:
        print("Invalid samplno resetting to 0")
        st.update_cursor(0)
    sample = asr_data[sample_no]

    # Only the last dash-separated segment of the task id goes in the title;
    # [-1] also copes with ids that contain no dash at all.
    task_uid = st.task_id.rsplit("-", 1)[-1]
    if annotation_only:
        st.title(f"ASR Annotation - # {task_uid}")
    else:
        st.title(f"ASR Validation - # {task_uid}")
    st.markdown(f"{sample_no+1} of {len(asr_data)} : **{sample['text']}**")

    # Jump by 1-based sample number.
    new_sample = st.number_input(
        "Go To Sample:", value=sample_no + 1, min_value=1, max_value=len(asr_data)
    )
    if new_sample != sample_no + 1:
        st.update_cursor(new_sample - 1)

    st.sidebar.title(f"Details: [{sample['real_idx']}]")
    st.sidebar.markdown(f"Gold Text: **{sample['text']}**")
    show_key(sample, "caller")
    # NOTE(review): the nesting below is reconstructed - the three ASR
    # related keys are assumed to be hidden in annotation-only mode; confirm.
    if not annotation_only:
        show_key(sample, asr_result_key)
        show_key(sample, "asr_wer", trail="%")
        show_key(sample, "correct_candidate")

    st.sidebar.image((data_dir / Path(sample["plot_path"])).read_bytes())
    # Pass the bytes instead of an open file object so no handle is leaked.
    st.audio((data_dir / Path(sample["audio_path"])).read_bytes())

    # Defaults: the gold text and the first option, unless a correction was
    # already stored for this utterance.
    corrected = sample["text"]
    correction_entry = st.get_correction_entry(sample["utterance_id"])
    selected_idx = 0
    options = ("Correct", "Incorrect", "Inaudible")
    if correction_entry:
        selected_idx = options.index(correction_entry["value"]["status"])
        corrected = correction_entry["value"]["correction"]
    selected = st.radio("The Audio is", options, index=selected_idx)
    if selected == "Incorrect":
        corrected = st.text_input("Actual:", value=corrected)
    if selected == "Inaudible":
        corrected = ""
    if st.button("Submit"):
        st.update_entry(
            sample["utterance_id"], {"status": selected, "correction": corrected}
        )
        st.update_cursor(sample_no + 1)
    if correction_entry:
        status = correction_entry["value"]["status"]
        correction = correction_entry["value"]["correction"]
        st.markdown(f"Your Response: **{status}** Correction: **{correction}**")

    # Jump to the first sample whose text matches exactly.
    text_sample = st.text_input("Go to Text:", value="")
    if text_sample != "":
        candidates = [i for (i, p) in enumerate(asr_data) if p["text"] == text_sample]
        if candidates:
            st.update_cursor(candidates[0])

    # Jump to the sample carrying the requested real index.
    real_idx = st.number_input(
        "Go to real-index",
        value=sample["real_idx"],
        min_value=0,
        max_value=len(asr_data) - 1,
    )
    if real_idx != int(sample["real_idx"]):
        matches = [i for (i, p) in enumerate(asr_data) if p["real_idx"] == real_idx]
        # The requested value need not belong to any sample; stay on the
        # current one rather than failing on an empty match list.
        if matches:
            st.update_cursor(matches[0])


if __name__ == "__main__":
    try:
        app()
    except SystemExit:
        # Deliberately ignored so that the exit raised once the command has
        # finished does not propagate out of this script.
        pass
|
||||
58
plume/ui/preview.py
Normal file
58
plume/ui/preview.py
Normal file
@@ -0,0 +1,58 @@
|
||||
from pathlib import Path
|
||||
|
||||
import streamlit as st
|
||||
import typer
|
||||
from plume.utils import ExtendedPath
|
||||
from plume.preview.st_rerun import rerun
|
||||
|
||||
app = typer.Typer()


# One-time setup, guarded by the `state_lock` marker attribute on the `st`
# module: the cursor helpers are attached to `st` and the cursor is reset.
if not hasattr(st, "state_lock"):
    # The cursor is kept in a JSON file named "preview.lck" (relative path).
    task_path = ExtendedPath("preview.lck")

    def current_cursor_fn():
        """Return the cursor value stored in the lock file."""
        state = task_path.read_json()
        return state["current_cursor"]

    def update_cursor_fn(val=0):
        """Write *val* to the lock file as the cursor, then call rerun()."""
        new_state = {"current_cursor": val}
        task_path.write_json(new_state)
        rerun()

    st.get_current_cursor = current_cursor_fn
    st.update_cursor = update_cursor_fn
    # The marker is set before the initial cursor write below.
    st.state_lock = True
    update_cursor_fn(0)
|
||||
|
||||
|
||||
@st.cache()
def load_ui_data(validation_ui_data_path: Path):
    """Read the JSON-lines manifest at the given path into a list.

    The call is wrapped in ``st.cache`` and announces the path it reads
    through ``typer.echo``.
    """
    typer.echo(f"Using validation ui data from {validation_ui_data_path}")
    manifest_file = ExtendedPath(validation_ui_data_path)
    records = manifest_file.read_jsonl()
    return list(records)
|
||||
|
||||
|
||||
@app.command()
def main(manifest: Path):
    """Render the preview page for the manifest entry at the stored cursor.

    Audio paths in the manifest are resolved relative to the manifest's
    parent directory.
    """
    asr_data = load_ui_data(manifest)

    sample_no = st.get_current_cursor()
    if len(asr_data) - 1 < sample_no or sample_no < 0:
        print("Invalid samplno resetting to 0")
        st.update_cursor(0)
    sample = asr_data[sample_no]

    st.title("ASR Manifest Preview")
    st.markdown(f"{sample_no+1} of {len(asr_data)} : **{sample['text']}**")

    # Jump by 1-based sample number.
    new_sample = st.number_input(
        "Go To Sample:", value=sample_no + 1, min_value=1, max_value=len(asr_data)
    )
    if new_sample != sample_no + 1:
        st.update_cursor(new_sample - 1)

    st.sidebar.markdown(f"Gold Text: **{sample['text']}**")
    # Pass the bytes instead of an open file object so no handle is leaked.
    st.audio((manifest.parent / Path(sample["audio_filepath"])).read_bytes())


if __name__ == "__main__":
    try:
        app()
    except SystemExit:
        # Deliberately ignored so that the exit raised once the command has
        # finished does not propagate out of this script.
        pass
|
||||
45
plume/ui/st_rerun.py
Normal file
45
plume/ui/st_rerun.py
Normal file
@@ -0,0 +1,45 @@
|
||||
try:
|
||||
# Before Streamlit 0.65
|
||||
from streamlit.ReportThread import get_report_ctx
|
||||
from streamlit.server.Server import Server
|
||||
from streamlit.ScriptRequestQueue import RerunData
|
||||
from streamlit.ScriptRunner import RerunException
|
||||
except ModuleNotFoundError:
|
||||
# After Streamlit 0.65
|
||||
from streamlit.report_thread import get_report_ctx
|
||||
from streamlit.server.server import Server
|
||||
from streamlit.script_request_queue import RerunData
|
||||
from streamlit.script_runner import RerunException
|
||||
|
||||
|
||||
def rerun():
    """Rerun a Streamlit app from the top!

    Raises a ``RerunException`` that carries the current widget states.
    """
    rerun_request = RerunData(_get_widget_states())
    raise RerunException(rerun_request)
|
||||
|
||||
|
||||
def _get_widget_states():
    """Return the widget states of the session that runs the calling script.

    The session is located by comparing each known session's ``enqueue``
    with the one on the current report context.

    Raises:
        RuntimeError: if no session of the current server matches.
    """
    # Hack to get the session object from Streamlit.
    ctx = get_report_ctx()

    current_server = Server.get_current()
    if hasattr(current_server, "_session_infos"):
        # Streamlit < 0.56
        session_infos = current_server._session_infos.values()
    else:
        session_infos = current_server._session_info_by_id.values()

    session = None
    for session_info in session_infos:
        if session_info.session.enqueue == ctx.enqueue:
            session = session_info.session

    if session is None:
        raise RuntimeError(
            "Oh noes. Couldn't get your Streamlit Session object. "
            "Are you doing something fancy with threads?"
        )
    # Got the session object!
    return session._widget_states
|
||||
Reference in New Issue
Block a user