mirror of https://github.com/malarinv/tacotron2
1. create a class for the tts api
2. implement a grpc server for tts 3. add a demo client for the grpc service 4. update gitignore and requriements 5. cleanupmaster
parent
ccd8ab42e7
commit
505af768a0
|
|
@ -0,0 +1,166 @@
|
|||
|
||||
# Created by https://www.gitignore.io/api/python
|
||||
# Edit at https://www.gitignore.io/?templates=python
|
||||
|
||||
### Python ###
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
pip-wheel-metadata/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
.python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# celery beat schedule file
|
||||
celerybeat-schedule
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# End of https://www.gitignore.io/api/python
|
||||
|
||||
# Created by https://www.gitignore.io/api/macos
|
||||
# Edit at https://www.gitignore.io/?templates=macos
|
||||
|
||||
### macOS ###
|
||||
# General
|
||||
.DS_Store
|
||||
.AppleDouble
|
||||
.LSOverride
|
||||
|
||||
# Icon must end with two \r
|
||||
Icon
|
||||
|
||||
# Thumbnails
|
||||
._*
|
||||
|
||||
# Files that might appear in the root of a volume
|
||||
.DocumentRevisions-V100
|
||||
.fseventsd
|
||||
.Spotlight-V100
|
||||
.TemporaryItems
|
||||
.Trashes
|
||||
.VolumeIcon.icns
|
||||
.com.apple.timemachine.donotpresent
|
||||
|
||||
# Directories potentially created on remote AFP share
|
||||
.AppleDB
|
||||
.AppleDesktop
|
||||
Network Trash Folder
|
||||
Temporary Items
|
||||
.apdisk
|
||||
|
||||
# End of https://www.gitignore.io/api/macos
|
||||
|
||||
*.pkl
|
||||
|
|
@ -0,0 +1,19 @@
|
|||
import grpc
|
||||
from sia.proto import tts_pb2
|
||||
from sia.proto import tts_pb2_grpc
|
||||
from tts import player_gen
|
||||
|
||||
|
||||
def main():
|
||||
channel = grpc.insecure_channel('localhost:50060')
|
||||
stub = tts_pb2_grpc.ServerStub(channel)
|
||||
test_text = tts_pb2.TextInput(text='How may I help you today?')
|
||||
speech = stub.TextToSpeechAPI(test_text)
|
||||
player = player_gen()
|
||||
player(speech.response)
|
||||
import pdb
|
||||
pdb.set_trace()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
10
final.py
10
final.py
|
|
@ -22,16 +22,17 @@ import pyaudio
|
|||
import klepto
|
||||
import IPython.display as ipd
|
||||
import time
|
||||
from sia.file_utils import cached_model_path
|
||||
|
||||
sys.path.append('waveglow/')
|
||||
hparams = create_hparams()
|
||||
hparams.sampling_rate = 22050
|
||||
checkpoint_path = "checkpoint_15000"
|
||||
model = load_model(hparams)
|
||||
tacotron2_path = cached_model_path("tacotron2_model")
|
||||
model.load_state_dict(
|
||||
torch.load(checkpoint_path, map_location='cpu')['state_dict'])
|
||||
torch.load(tacotron2_path, map_location='cpu')['state_dict'])
|
||||
model.eval()
|
||||
waveglow_path = 'waveglow_256channels.pt'
|
||||
waveglow_path = cached_model_path('waveglow_model')
|
||||
waveglow = torch.load(waveglow_path, map_location='cpu')['model']
|
||||
waveglow.eval()
|
||||
for k in waveglow.convinv:
|
||||
|
|
@ -93,9 +94,10 @@ def player_gen():
|
|||
|
||||
def synthesize_corpus():
|
||||
all_data = []
|
||||
for line in open('corpus.txt').readlines():
|
||||
for (i, line) in enumerate(open('corpus.txt').readlines()):
|
||||
print('synthesizing... "{}"'.format(line.strip()))
|
||||
data = speech(line.strip())
|
||||
sf.write('tts_{}.wav'.format(i), data, 16000)
|
||||
all_data.append(data)
|
||||
return all_data
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
import tensorflow as tf
|
||||
from text import symbols
|
||||
|
||||
#changed path, sampling rate and batch size
|
||||
|
||||
# changed path, sampling rate and batch size
|
||||
def create_hparams(hparams_string=None, verbose=False):
|
||||
"""Create model hyperparameters. Parse nondefault from given string."""
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,15 @@
|
|||
pip==18.1
|
||||
bumpversion==0.5.3
|
||||
wheel==0.32.1
|
||||
watchdog==0.9.0
|
||||
flake8==3.5.0
|
||||
tox==3.5.2
|
||||
coverage==4.5.1
|
||||
Sphinx==1.8.1
|
||||
twine==1.12.1
|
||||
|
||||
pytest==3.8.2
|
||||
pytest-runner==4.2
|
||||
pre-commit==1.16.1
|
||||
python-language-server[all]
|
||||
ipdb
|
||||
|
|
@ -0,0 +1,39 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
import grpc
|
||||
import time
|
||||
from sia.proto import tts_pb2
|
||||
from sia.proto import tts_pb2_grpc
|
||||
from concurrent import futures
|
||||
from sia.instruments import do_time
|
||||
from tts import TTSModel
|
||||
|
||||
|
||||
class TTSServer():
|
||||
def __init__(self):
|
||||
self.tts_model = TTSModel()
|
||||
|
||||
def TextToSpeechAPI(self, request, context):
|
||||
while (True):
|
||||
input_text = request.text
|
||||
speech_response = self.tts_model.synth_speech(input_text)
|
||||
return tts_pb2.SpeechResponse(response=speech_response)
|
||||
|
||||
|
||||
def main():
|
||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=1))
|
||||
tts_server = TTSServer()
|
||||
tts_pb2_grpc.add_ServerServicer_to_server(tts_server, server)
|
||||
server.add_insecure_port('localhost:50060')
|
||||
server.start()
|
||||
print('TTSServer started!')
|
||||
|
||||
try:
|
||||
while True:
|
||||
time.sleep(10000)
|
||||
except KeyboardInterrupt:
|
||||
server.start()
|
||||
# server.stop(0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
BIN
tensorboard.png
BIN
tensorboard.png
Binary file not shown.
|
Before Width: | Height: | Size: 170 KiB |
|
|
@ -0,0 +1,128 @@
|
|||
#!/usr/bin/env python
|
||||
# coding: utf-8
|
||||
|
||||
# import matplotlib
|
||||
# import matplotlib.pylab as plt
|
||||
|
||||
# import IPython.display as ipd
|
||||
|
||||
import sys
|
||||
import numpy as np
|
||||
import torch
|
||||
from hparams import create_hparams
|
||||
from model import Tacotron2
|
||||
from layers import TacotronSTFT, STFT
|
||||
# from audio_processing import griffin_lim
|
||||
from train import load_model
|
||||
from text import text_to_sequence
|
||||
# from denoiser import Denoiser
|
||||
import os
|
||||
import soundfile as sf
|
||||
import pyaudio
|
||||
import klepto
|
||||
import IPython.display as ipd
|
||||
import time
|
||||
from sia.file_utils import cached_model_path
|
||||
|
||||
sys.path.append('waveglow/')
|
||||
|
||||
|
||||
class TTSModel(object):
|
||||
"""docstring for TTSModel."""
|
||||
|
||||
def __init__(self):
|
||||
super(TTSModel, self).__init__()
|
||||
hparams = create_hparams()
|
||||
hparams.sampling_rate = 22050
|
||||
self.model = load_model(hparams)
|
||||
tacotron2_path = cached_model_path("tacotron2_model")
|
||||
self.model.load_state_dict(
|
||||
torch.load(tacotron2_path, map_location='cpu')['state_dict'])
|
||||
self.model.eval()
|
||||
waveglow_path = cached_model_path('waveglow_model')
|
||||
self.waveglow = torch.load(waveglow_path, map_location='cpu')['model']
|
||||
self.waveglow.eval()
|
||||
for k in self.waveglow.convinv:
|
||||
k.float()
|
||||
self.k_cache = klepto.archives.file_archive(cached=False)
|
||||
self.synth_speech = klepto.safe.inf_cache(cache=self.k_cache)(
|
||||
self.synth_speech)
|
||||
|
||||
# https://github.com/NVIDIA/waveglow/issues/127
|
||||
for m in self.waveglow.modules():
|
||||
if 'Conv' in str(type(m)):
|
||||
setattr(m, 'padding_mode', 'zeros')
|
||||
|
||||
def synth_speech(self, t):
|
||||
start = time.time()
|
||||
text = t
|
||||
sequence = np.array(text_to_sequence(text,
|
||||
['english_cleaners']))[None, :]
|
||||
sequence = torch.autograd.Variable(torch.from_numpy(sequence)).long()
|
||||
mel_outputs, mel_outputs_postnet, _, alignments = self.model.inference(
|
||||
sequence)
|
||||
with torch.no_grad():
|
||||
audio = self.waveglow.infer(mel_outputs_postnet, sigma=0.666)
|
||||
# import ipdb; ipdb.set_trace()
|
||||
data = convert(audio[0].data.cpu().numpy())
|
||||
# _audio_stream.write(data.astype('float32'))
|
||||
# _audio_stream.write(data)
|
||||
end = time.time()
|
||||
print(end - start)
|
||||
return data.tobytes()
|
||||
|
||||
|
||||
def convert(array):
|
||||
sf.write('sample.wav', array, 22050)
|
||||
os.system('ffmpeg -i {0} -filter:a "atempo=0.80" -ar 16k {1}'.format(
|
||||
'sample.wav', 'sample0.wav'))
|
||||
data, rate = sf.read('sample0.wav', dtype='int16')
|
||||
os.remove('sample.wav')
|
||||
os.remove('sample0.wav')
|
||||
return data
|
||||
|
||||
|
||||
def display(data):
|
||||
aud = ipd.Audio(data, rate=16000)
|
||||
return aud
|
||||
|
||||
|
||||
def player_gen():
|
||||
audio_interface = pyaudio.PyAudio()
|
||||
_audio_stream = audio_interface.open(format=pyaudio.paInt16,
|
||||
channels=1,
|
||||
rate=16000,
|
||||
output=True)
|
||||
|
||||
def play_device(data):
|
||||
_audio_stream.write(data)
|
||||
# _audio_stream.close()
|
||||
|
||||
return play_device
|
||||
|
||||
|
||||
def synthesize_corpus():
|
||||
tts_model = TTSModel()
|
||||
all_data = []
|
||||
for (i, line) in enumerate(open('corpus.txt').readlines()):
|
||||
print('synthesizing... "{}"'.format(line.strip()))
|
||||
data = tts_model.synth_speech(line.strip())
|
||||
all_data.append(data)
|
||||
return all_data
|
||||
|
||||
|
||||
def play_corpus(corpus_synths):
|
||||
player = player_gen()
|
||||
for d in corpus_synths:
|
||||
player(d)
|
||||
|
||||
|
||||
def main():
|
||||
corpus_synth_data = synthesize_corpus()
|
||||
play_corpus(corpus_synth_data)
|
||||
import ipdb
|
||||
ipdb.set_trace()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
Loading…
Reference in New Issue