diff --git a/taco2/tts.py b/taco2/tts.py index ce1695f..1574404 100644 --- a/taco2/tts.py +++ b/taco2/tts.py @@ -5,6 +5,8 @@ import numpy as np import torch import pyaudio import klepto +import argparse +from pathlib import Path from .model import Tacotron2 from glow import WaveGlow from .hparams import HParams @@ -133,10 +135,7 @@ def player_gen(): return play_device -def repl(): - tts_model = TTSModel( - "/Users/malar/Work/tacotron2_r4_83000.pt", "/Users/malar/Work/waveglow_484000" - ) +def repl(tts_model): player = player_gen() def loop(): @@ -148,7 +147,26 @@ def repl(): def main(): - interactive_loop = repl() + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument( + "-t", + "--tacotron2_path", + type=Path, + default="./tacotron.pt", + help="Path to a tacotron2 model", + ) + parser.add_argument( + "-w", + "--waveglow_path", + type=Path, + default="./waveglow_256channels.pt", + help="Path to a waveglow model", + ) + args = parser.parse_args() + tts_model = TTSModel(**vars(args)) + interactive_loop = repl(tts_model) while True: interactive_loop()