1
0
mirror of https://github.com/malarinv/plume-asr.git synced 2026-03-08 04:12:35 +00:00

1. Self contained typers

2. Asr force-aligner visualization
3. streamlit state management abstraction
4. new utils / reorganize
5. added verbose flags
6. add tts by name
This commit is contained in:
2021-03-23 13:27:35 +05:30
parent f72c6bbe5b
commit c474aa5f5a
22 changed files with 1097 additions and 146 deletions

View File

@@ -3,12 +3,20 @@ import sys
from pathlib import Path
from plume.utils import lazy_module
# from streamlit import cli as stcli
stcli = lazy_module('streamlit.cli')
stcli = lazy_module("streamlit.cli")
app = typer.Typer()
@app.callback()
def ui():
"""
ui sub commands
"""
@app.command()
def annotation(data_dir: Path, dump_fname: Path = "ui_dump.json", task_id: str = ""):
annotation_lit_path = Path(__file__).parent / Path("annotation.py")
@@ -40,13 +48,7 @@ def annotation(data_dir: Path, dump_fname: Path = "ui_dump.json", task_id: str =
@app.command()
def preview(manifest_path: Path):
annotation_lit_path = Path(__file__).parent / Path("preview.py")
sys.argv = [
"streamlit",
"run",
str(annotation_lit_path),
"--",
str(manifest_path)
]
sys.argv = ["streamlit", "run", str(annotation_lit_path), "--", str(manifest_path)]
sys.exit(stcli.main())
@@ -56,6 +58,18 @@ def collection(data_dir: Path, task_id: str = ""):
pass
@app.command()
def alignment(preview_dir: Path, port: int = 8010):
from RangeHTTPServer import RangeRequestHandler
from functools import partial
from http.server import HTTPServer
server_address = ("", port)
handler_class = partial(RangeRequestHandler, directory=str(preview_dir))
httpd = HTTPServer(server_address, handler_class)
httpd.serve_forever()
def main():
app()

View File

@@ -1,66 +1,14 @@
# import sys
from pathlib import Path
from uuid import uuid4
import streamlit as st
import typer
from plume.utils import ExtendedPath, get_mongo_conn
from plume.preview.st_rerun import rerun
from plume.utils import ExtendedPath
from plume.utils.ui_persist import setup_mongo_asr_validation_state
app = typer.Typer()
if not hasattr(st, "mongo_connected"):
st.mongoclient = get_mongo_conn(col="asr_validation")
mongo_conn = st.mongoclient
st.task_id = str(uuid4())
def current_cursor_fn():
# mongo_conn = st.mongoclient
cursor_obj = mongo_conn.find_one(
{"type": "current_cursor", "task_id": st.task_id}
)
cursor_val = cursor_obj["cursor"]
return cursor_val
def update_cursor_fn(val=0):
mongo_conn.find_one_and_update(
{"type": "current_cursor", "task_id": st.task_id},
{"$set": {"type": "current_cursor", "task_id": st.task_id, "cursor": val}},
upsert=True,
)
rerun()
def get_correction_entry_fn(code):
return mongo_conn.find_one(
{"type": "correction", "code": code}, projection={"_id": False}
)
def update_entry_fn(code, value):
mongo_conn.find_one_and_update(
{"type": "correction", "code": code},
{"$set": {"value": value, "task_id": st.task_id}},
upsert=True,
)
def set_task_fn(data_path, task_id):
if task_id:
st.task_id = task_id
task_path = data_path / Path(f"task-{st.task_id}.lck")
if not task_path.exists():
print(f"creating task lock at {task_path}")
task_path.touch()
st.get_current_cursor = current_cursor_fn
st.update_cursor = update_cursor_fn
st.get_correction_entry = get_correction_entry_fn
st.update_entry = update_entry_fn
st.set_task = set_task_fn
st.mongo_connected = True
cursor_obj = mongo_conn.find_one({"type": "current_cursor", "task_id": st.task_id})
if not cursor_obj:
update_cursor_fn(0)
setup_mongo_asr_validation_state(st)
@st.cache()

View File

@@ -3,27 +3,11 @@ from pathlib import Path
import streamlit as st
import typer
from plume.utils import ExtendedPath
from plume.preview.st_rerun import rerun
from plume.utils.ui_persist import setup_file_state
app = typer.Typer()
if not hasattr(st, "state_lock"):
# st.task_id = str(uuid4())
task_path = ExtendedPath("preview.lck")
def current_cursor_fn():
return task_path.read_json()["current_cursor"]
def update_cursor_fn(val=0):
task_path.write_json({"current_cursor": val})
rerun()
st.get_current_cursor = current_cursor_fn
st.update_cursor = update_cursor_fn
st.state_lock = True
# cursor_obj = mongo_conn.find_one({"type": "current_cursor", "task_id": st.task_id})
# if not cursor_obj:
update_cursor_fn(0)
setup_file_state(st)
@st.cache()
@@ -40,7 +24,7 @@ def main(manifest: Path):
print("Invalid samplno resetting to 0")
st.update_cursor(0)
sample = asr_data[sample_no]
st.title(f"ASR Manifest Preview")
st.title("ASR Manifest Preview")
st.markdown(f"{sample_no+1} of {len(asr_data)} : **{sample['text']}**")
new_sample = st.number_input(
"Go To Sample:", value=sample_no + 1, min_value=1, max_value=len(asr_data)

View File

@@ -1,45 +0,0 @@
try:
# Before Streamlit 0.65
from streamlit.ReportThread import get_report_ctx
from streamlit.server.Server import Server
from streamlit.ScriptRequestQueue import RerunData
from streamlit.ScriptRunner import RerunException
except ModuleNotFoundError:
# After Streamlit 0.65
from streamlit.report_thread import get_report_ctx
from streamlit.server.server import Server
from streamlit.script_request_queue import RerunData
from streamlit.script_runner import RerunException
def rerun():
"""Rerun a Streamlit app from the top!"""
widget_states = _get_widget_states()
raise RerunException(RerunData(widget_states))
def _get_widget_states():
# Hack to get the session object from Streamlit.
ctx = get_report_ctx()
session = None
current_server = Server.get_current()
if hasattr(current_server, '_session_infos'):
# Streamlit < 0.56
session_infos = Server.get_current()._session_infos.values()
else:
session_infos = Server.get_current()._session_info_by_id.values()
for session_info in session_infos:
if session_info.session.enqueue == ctx.enqueue:
session = session_info.session
if session is None:
raise RuntimeError(
"Oh noes. Couldn't get your Streamlit Session object"
"Are you doing something fancy with threads?"
)
# Got the session object!
return session._widget_states