add tts cli args

master
Malar Kannan 2019-10-09 16:03:29 +05:30
parent dde35048b7
commit dcc9ab3625
1 changed files with 23 additions and 5 deletions

View File

@ -5,6 +5,8 @@ import numpy as np
import torch import torch
import pyaudio import pyaudio
import klepto import klepto
import argparse
from pathlib import Path
from .model import Tacotron2 from .model import Tacotron2
from glow import WaveGlow from glow import WaveGlow
from .hparams import HParams from .hparams import HParams
@ -133,10 +135,7 @@ def player_gen():
return play_device return play_device
def repl(): def repl(tts_model):
tts_model = TTSModel(
"/Users/malar/Work/tacotron2_r4_83000.pt", "/Users/malar/Work/waveglow_484000"
)
player = player_gen() player = player_gen()
def loop(): def loop():
@ -148,7 +147,26 @@ def repl():
def main(): def main():
interactive_loop = repl() parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"-t",
"--tacotron2_path",
type=Path,
default="./tacotron.pt",
help="Path to a tacotron2 model",
)
parser.add_argument(
"-w",
"--waveglow_path",
type=Path,
default="./waveglow_256channels.pt",
help="Path to a waveglow model",
)
args = parser.parse_args()
tts_model = TTSModel(**vars(args))
interactive_loop = repl(tts_model)
while True: while True:
interactive_loop() interactive_loop()