Compare commits

...

2 Commits

Author SHA1 Message Date
Malar Kannan 5a30069f0a update tempo and output sample rate 2019-10-09 17:25:51 +05:30
Malar Kannan dcc9ab3625 add tts cli args 2019-10-09 16:23:21 +05:30
1 changed file with 26 additions and 7 deletions

View File

@@ -5,6 +5,8 @@ import numpy as np
import torch import torch
import pyaudio import pyaudio
import klepto import klepto
import argparse
from pathlib import Path
from .model import Tacotron2 from .model import Tacotron2
from glow import WaveGlow from glow import WaveGlow
from .hparams import HParams from .hparams import HParams
@@ -14,7 +16,8 @@ from .denoiser import Denoiser
from .audio_processing import griffin_lim, postprocess_audio from .audio_processing import griffin_lim, postprocess_audio
TTS_SAMPLE_RATE = 22050 TTS_SAMPLE_RATE = 22050
OUTPUT_SAMPLE_RATE = 16000 OUTPUT_SAMPLE_RATE = 22050
# OUTPUT_SAMPLE_RATE = 16000
# config from # config from
# https://github.com/NVIDIA/waveglow/blob/master/config.json # https://github.com/NVIDIA/waveglow/blob/master/config.json
@@ -116,7 +119,7 @@ class TTSModel(object):
audio = audio.cpu().numpy() audio = audio.cpu().numpy()
return postprocess_audio( return postprocess_audio(
audio, src_rate=TTS_SAMPLE_RATE, dst_rate=OUTPUT_SAMPLE_RATE audio, tempo=0.6, src_rate=TTS_SAMPLE_RATE, dst_rate=OUTPUT_SAMPLE_RATE
) )
@@ -133,10 +136,7 @@ def player_gen():
return play_device return play_device
def repl(): def repl(tts_model):
tts_model = TTSModel(
"/Users/malar/Work/tacotron2_r4_83000.pt", "/Users/malar/Work/waveglow_484000"
)
player = player_gen() player = player_gen()
def loop(): def loop():
@@ -148,7 +148,26 @@ def repl():
def main(): def main():
interactive_loop = repl() parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"-t",
"--tacotron2_path",
type=Path,
default="./tacotron.pt",
help="Path to a tacotron2 model",
)
parser.add_argument(
"-w",
"--waveglow_path",
type=Path,
default="./waveglow_256channels.pt",
help="Path to a waveglow model",
)
args = parser.parse_args()
tts_model = TTSModel(**vars(args))
interactive_loop = repl(tts_model)
while True: while True:
interactive_loop() interactive_loop()