mirror of https://github.com/malarinv/tacotron2
1. create a class for the tts api
2. implement a grpc server for tts 3. add a demo client for the grpc service 4. update gitignore and requriements 5. cleanupmaster
parent
ccd8ab42e7
commit
505af768a0
|
|
@ -0,0 +1,166 @@
|
||||||
|
|
||||||
|
# Created by https://www.gitignore.io/api/python
|
||||||
|
# Edit at https://www.gitignore.io/?templates=python
|
||||||
|
|
||||||
|
### Python ###
|
||||||
|
# Byte-compiled / optimized / DLL files
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
# C extensions
|
||||||
|
*.so
|
||||||
|
|
||||||
|
# Distribution / packaging
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
pip-wheel-metadata/
|
||||||
|
share/python-wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
MANIFEST
|
||||||
|
|
||||||
|
# PyInstaller
|
||||||
|
# Usually these files are written by a python script from a template
|
||||||
|
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||||
|
*.manifest
|
||||||
|
*.spec
|
||||||
|
|
||||||
|
# Installer logs
|
||||||
|
pip-log.txt
|
||||||
|
pip-delete-this-directory.txt
|
||||||
|
|
||||||
|
# Unit test / coverage reports
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
.nox/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
.cache
|
||||||
|
nosetests.xml
|
||||||
|
coverage.xml
|
||||||
|
*.cover
|
||||||
|
.hypothesis/
|
||||||
|
.pytest_cache/
|
||||||
|
|
||||||
|
# Translations
|
||||||
|
*.mo
|
||||||
|
*.pot
|
||||||
|
|
||||||
|
# Django stuff:
|
||||||
|
*.log
|
||||||
|
local_settings.py
|
||||||
|
db.sqlite3
|
||||||
|
db.sqlite3-journal
|
||||||
|
|
||||||
|
# Flask stuff:
|
||||||
|
instance/
|
||||||
|
.webassets-cache
|
||||||
|
|
||||||
|
# Scrapy stuff:
|
||||||
|
.scrapy
|
||||||
|
|
||||||
|
# Sphinx documentation
|
||||||
|
docs/_build/
|
||||||
|
|
||||||
|
# PyBuilder
|
||||||
|
target/
|
||||||
|
|
||||||
|
# Jupyter Notebook
|
||||||
|
.ipynb_checkpoints
|
||||||
|
|
||||||
|
# IPython
|
||||||
|
profile_default/
|
||||||
|
ipython_config.py
|
||||||
|
|
||||||
|
# pyenv
|
||||||
|
.python-version
|
||||||
|
|
||||||
|
# pipenv
|
||||||
|
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||||
|
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||||
|
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||||
|
# install all needed dependencies.
|
||||||
|
#Pipfile.lock
|
||||||
|
|
||||||
|
# celery beat schedule file
|
||||||
|
celerybeat-schedule
|
||||||
|
|
||||||
|
# SageMath parsed files
|
||||||
|
*.sage.py
|
||||||
|
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
.venv
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
|
# Spyder project settings
|
||||||
|
.spyderproject
|
||||||
|
.spyproject
|
||||||
|
|
||||||
|
# Rope project settings
|
||||||
|
.ropeproject
|
||||||
|
|
||||||
|
# mkdocs documentation
|
||||||
|
/site
|
||||||
|
|
||||||
|
# mypy
|
||||||
|
.mypy_cache/
|
||||||
|
.dmypy.json
|
||||||
|
dmypy.json
|
||||||
|
|
||||||
|
# Pyre type checker
|
||||||
|
.pyre/
|
||||||
|
|
||||||
|
# End of https://www.gitignore.io/api/python
|
||||||
|
|
||||||
|
# Created by https://www.gitignore.io/api/macos
|
||||||
|
# Edit at https://www.gitignore.io/?templates=macos
|
||||||
|
|
||||||
|
### macOS ###
|
||||||
|
# General
|
||||||
|
.DS_Store
|
||||||
|
.AppleDouble
|
||||||
|
.LSOverride
|
||||||
|
|
||||||
|
# Icon must end with two \r
|
||||||
|
Icon
|
||||||
|
|
||||||
|
# Thumbnails
|
||||||
|
._*
|
||||||
|
|
||||||
|
# Files that might appear in the root of a volume
|
||||||
|
.DocumentRevisions-V100
|
||||||
|
.fseventsd
|
||||||
|
.Spotlight-V100
|
||||||
|
.TemporaryItems
|
||||||
|
.Trashes
|
||||||
|
.VolumeIcon.icns
|
||||||
|
.com.apple.timemachine.donotpresent
|
||||||
|
|
||||||
|
# Directories potentially created on remote AFP share
|
||||||
|
.AppleDB
|
||||||
|
.AppleDesktop
|
||||||
|
Network Trash Folder
|
||||||
|
Temporary Items
|
||||||
|
.apdisk
|
||||||
|
|
||||||
|
# End of https://www.gitignore.io/api/macos
|
||||||
|
|
||||||
|
*.pkl
|
||||||
|
|
@ -0,0 +1,19 @@
|
||||||
|
import grpc
|
||||||
|
from sia.proto import tts_pb2
|
||||||
|
from sia.proto import tts_pb2_grpc
|
||||||
|
from tts import player_gen
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
channel = grpc.insecure_channel('localhost:50060')
|
||||||
|
stub = tts_pb2_grpc.ServerStub(channel)
|
||||||
|
test_text = tts_pb2.TextInput(text='How may I help you today?')
|
||||||
|
speech = stub.TextToSpeechAPI(test_text)
|
||||||
|
player = player_gen()
|
||||||
|
player(speech.response)
|
||||||
|
import pdb
|
||||||
|
pdb.set_trace()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
10
final.py
10
final.py
|
|
@ -22,16 +22,17 @@ import pyaudio
|
||||||
import klepto
|
import klepto
|
||||||
import IPython.display as ipd
|
import IPython.display as ipd
|
||||||
import time
|
import time
|
||||||
|
from sia.file_utils import cached_model_path
|
||||||
|
|
||||||
sys.path.append('waveglow/')
|
sys.path.append('waveglow/')
|
||||||
hparams = create_hparams()
|
hparams = create_hparams()
|
||||||
hparams.sampling_rate = 22050
|
hparams.sampling_rate = 22050
|
||||||
checkpoint_path = "checkpoint_15000"
|
|
||||||
model = load_model(hparams)
|
model = load_model(hparams)
|
||||||
|
tacotron2_path = cached_model_path("tacotron2_model")
|
||||||
model.load_state_dict(
|
model.load_state_dict(
|
||||||
torch.load(checkpoint_path, map_location='cpu')['state_dict'])
|
torch.load(tacotron2_path, map_location='cpu')['state_dict'])
|
||||||
model.eval()
|
model.eval()
|
||||||
waveglow_path = 'waveglow_256channels.pt'
|
waveglow_path = cached_model_path('waveglow_model')
|
||||||
waveglow = torch.load(waveglow_path, map_location='cpu')['model']
|
waveglow = torch.load(waveglow_path, map_location='cpu')['model']
|
||||||
waveglow.eval()
|
waveglow.eval()
|
||||||
for k in waveglow.convinv:
|
for k in waveglow.convinv:
|
||||||
|
|
@ -93,9 +94,10 @@ def player_gen():
|
||||||
|
|
||||||
def synthesize_corpus():
|
def synthesize_corpus():
|
||||||
all_data = []
|
all_data = []
|
||||||
for line in open('corpus.txt').readlines():
|
for (i, line) in enumerate(open('corpus.txt').readlines()):
|
||||||
print('synthesizing... "{}"'.format(line.strip()))
|
print('synthesizing... "{}"'.format(line.strip()))
|
||||||
data = speech(line.strip())
|
data = speech(line.strip())
|
||||||
|
sf.write('tts_{}.wav'.format(i), data, 16000)
|
||||||
all_data.append(data)
|
all_data.append(data)
|
||||||
return all_data
|
return all_data
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,8 @@
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from text import symbols
|
from text import symbols
|
||||||
|
|
||||||
#changed path, sampling rate and batch size
|
|
||||||
|
# changed path, sampling rate and batch size
|
||||||
def create_hparams(hparams_string=None, verbose=False):
|
def create_hparams(hparams_string=None, verbose=False):
|
||||||
"""Create model hyperparameters. Parse nondefault from given string."""
|
"""Create model hyperparameters. Parse nondefault from given string."""
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,15 @@
|
||||||
|
pip==18.1
|
||||||
|
bumpversion==0.5.3
|
||||||
|
wheel==0.32.1
|
||||||
|
watchdog==0.9.0
|
||||||
|
flake8==3.5.0
|
||||||
|
tox==3.5.2
|
||||||
|
coverage==4.5.1
|
||||||
|
Sphinx==1.8.1
|
||||||
|
twine==1.12.1
|
||||||
|
|
||||||
|
pytest==3.8.2
|
||||||
|
pytest-runner==4.2
|
||||||
|
pre-commit==1.16.1
|
||||||
|
python-language-server[all]
|
||||||
|
ipdb
|
||||||
|
|
@ -0,0 +1,39 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
import grpc
|
||||||
|
import time
|
||||||
|
from sia.proto import tts_pb2
|
||||||
|
from sia.proto import tts_pb2_grpc
|
||||||
|
from concurrent import futures
|
||||||
|
from sia.instruments import do_time
|
||||||
|
from tts import TTSModel
|
||||||
|
|
||||||
|
|
||||||
|
class TTSServer():
|
||||||
|
def __init__(self):
|
||||||
|
self.tts_model = TTSModel()
|
||||||
|
|
||||||
|
def TextToSpeechAPI(self, request, context):
|
||||||
|
while (True):
|
||||||
|
input_text = request.text
|
||||||
|
speech_response = self.tts_model.synth_speech(input_text)
|
||||||
|
return tts_pb2.SpeechResponse(response=speech_response)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
server = grpc.server(futures.ThreadPoolExecutor(max_workers=1))
|
||||||
|
tts_server = TTSServer()
|
||||||
|
tts_pb2_grpc.add_ServerServicer_to_server(tts_server, server)
|
||||||
|
server.add_insecure_port('localhost:50060')
|
||||||
|
server.start()
|
||||||
|
print('TTSServer started!')
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
time.sleep(10000)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
server.start()
|
||||||
|
# server.stop(0)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
BIN
tensorboard.png
BIN
tensorboard.png
Binary file not shown.
|
Before Width: | Height: | Size: 170 KiB |
|
|
@ -0,0 +1,128 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
# coding: utf-8
|
||||||
|
|
||||||
|
# import matplotlib
|
||||||
|
# import matplotlib.pylab as plt
|
||||||
|
|
||||||
|
# import IPython.display as ipd
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from hparams import create_hparams
|
||||||
|
from model import Tacotron2
|
||||||
|
from layers import TacotronSTFT, STFT
|
||||||
|
# from audio_processing import griffin_lim
|
||||||
|
from train import load_model
|
||||||
|
from text import text_to_sequence
|
||||||
|
# from denoiser import Denoiser
|
||||||
|
import os
|
||||||
|
import soundfile as sf
|
||||||
|
import pyaudio
|
||||||
|
import klepto
|
||||||
|
import IPython.display as ipd
|
||||||
|
import time
|
||||||
|
from sia.file_utils import cached_model_path
|
||||||
|
|
||||||
|
sys.path.append('waveglow/')
|
||||||
|
|
||||||
|
|
||||||
|
class TTSModel(object):
|
||||||
|
"""docstring for TTSModel."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super(TTSModel, self).__init__()
|
||||||
|
hparams = create_hparams()
|
||||||
|
hparams.sampling_rate = 22050
|
||||||
|
self.model = load_model(hparams)
|
||||||
|
tacotron2_path = cached_model_path("tacotron2_model")
|
||||||
|
self.model.load_state_dict(
|
||||||
|
torch.load(tacotron2_path, map_location='cpu')['state_dict'])
|
||||||
|
self.model.eval()
|
||||||
|
waveglow_path = cached_model_path('waveglow_model')
|
||||||
|
self.waveglow = torch.load(waveglow_path, map_location='cpu')['model']
|
||||||
|
self.waveglow.eval()
|
||||||
|
for k in self.waveglow.convinv:
|
||||||
|
k.float()
|
||||||
|
self.k_cache = klepto.archives.file_archive(cached=False)
|
||||||
|
self.synth_speech = klepto.safe.inf_cache(cache=self.k_cache)(
|
||||||
|
self.synth_speech)
|
||||||
|
|
||||||
|
# https://github.com/NVIDIA/waveglow/issues/127
|
||||||
|
for m in self.waveglow.modules():
|
||||||
|
if 'Conv' in str(type(m)):
|
||||||
|
setattr(m, 'padding_mode', 'zeros')
|
||||||
|
|
||||||
|
def synth_speech(self, t):
|
||||||
|
start = time.time()
|
||||||
|
text = t
|
||||||
|
sequence = np.array(text_to_sequence(text,
|
||||||
|
['english_cleaners']))[None, :]
|
||||||
|
sequence = torch.autograd.Variable(torch.from_numpy(sequence)).long()
|
||||||
|
mel_outputs, mel_outputs_postnet, _, alignments = self.model.inference(
|
||||||
|
sequence)
|
||||||
|
with torch.no_grad():
|
||||||
|
audio = self.waveglow.infer(mel_outputs_postnet, sigma=0.666)
|
||||||
|
# import ipdb; ipdb.set_trace()
|
||||||
|
data = convert(audio[0].data.cpu().numpy())
|
||||||
|
# _audio_stream.write(data.astype('float32'))
|
||||||
|
# _audio_stream.write(data)
|
||||||
|
end = time.time()
|
||||||
|
print(end - start)
|
||||||
|
return data.tobytes()
|
||||||
|
|
||||||
|
|
||||||
|
def convert(array):
|
||||||
|
sf.write('sample.wav', array, 22050)
|
||||||
|
os.system('ffmpeg -i {0} -filter:a "atempo=0.80" -ar 16k {1}'.format(
|
||||||
|
'sample.wav', 'sample0.wav'))
|
||||||
|
data, rate = sf.read('sample0.wav', dtype='int16')
|
||||||
|
os.remove('sample.wav')
|
||||||
|
os.remove('sample0.wav')
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def display(data):
|
||||||
|
aud = ipd.Audio(data, rate=16000)
|
||||||
|
return aud
|
||||||
|
|
||||||
|
|
||||||
|
def player_gen():
|
||||||
|
audio_interface = pyaudio.PyAudio()
|
||||||
|
_audio_stream = audio_interface.open(format=pyaudio.paInt16,
|
||||||
|
channels=1,
|
||||||
|
rate=16000,
|
||||||
|
output=True)
|
||||||
|
|
||||||
|
def play_device(data):
|
||||||
|
_audio_stream.write(data)
|
||||||
|
# _audio_stream.close()
|
||||||
|
|
||||||
|
return play_device
|
||||||
|
|
||||||
|
|
||||||
|
def synthesize_corpus():
|
||||||
|
tts_model = TTSModel()
|
||||||
|
all_data = []
|
||||||
|
for (i, line) in enumerate(open('corpus.txt').readlines()):
|
||||||
|
print('synthesizing... "{}"'.format(line.strip()))
|
||||||
|
data = tts_model.synth_speech(line.strip())
|
||||||
|
all_data.append(data)
|
||||||
|
return all_data
|
||||||
|
|
||||||
|
|
||||||
|
def play_corpus(corpus_synths):
|
||||||
|
player = player_gen()
|
||||||
|
for d in corpus_synths:
|
||||||
|
player(d)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
corpus_synth_data = synthesize_corpus()
|
||||||
|
play_corpus(corpus_synth_data)
|
||||||
|
import ipdb
|
||||||
|
ipdb.set_trace()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
Loading…
Reference in New Issue