1. create a class for the tts api

2. implement a grpc server for tts
3. add a demo client for the grpc service
4. update gitignore and requirements
5. cleanup
master
Malar Kannan 2019-07-01 14:47:55 +05:30
parent ccd8ab42e7
commit 505af768a0
8 changed files with 375 additions and 5 deletions

166
.gitignore vendored Normal file

@@ -0,0 +1,166 @@
# Created by https://www.gitignore.io/api/python
# Edit at https://www.gitignore.io/?templates=python
### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# End of https://www.gitignore.io/api/python
# Created by https://www.gitignore.io/api/macos
# Edit at https://www.gitignore.io/?templates=macos
### macOS ###
# General
.DS_Store
.AppleDouble
.LSOverride
# Icon must end with two \r
Icon
# Thumbnails
._*
# Files that might appear in the root of a volume
.DocumentRevisions-V100
.fseventsd
.Spotlight-V100
.TemporaryItems
.Trashes
.VolumeIcon.icns
.com.apple.timemachine.donotpresent
# Directories potentially created on remote AFP share
.AppleDB
.AppleDesktop
Network Trash Folder
Temporary Items
.apdisk
# End of https://www.gitignore.io/api/macos
*.pkl

19
demo_client.py Normal file

@@ -0,0 +1,19 @@
import grpc

from sia.proto import tts_pb2
from sia.proto import tts_pb2_grpc
from tts import player_gen


def main():
    # Connect to the local TTS server and request synthesis of one sentence.
    channel = grpc.insecure_channel('localhost:50060')
    stub = tts_pb2_grpc.ServerStub(channel)
    test_text = tts_pb2.TextInput(text='How may I help you today?')
    speech = stub.TextToSpeechAPI(test_text)

    # Play the returned int16 PCM bytes on the default audio device.
    player = player_gen()
    player(speech.response)

    # Debugging hook: drop into pdb so the process (and the audio stream)
    # stays alive for interactive experimentation.
    import pdb
    pdb.set_trace()


if __name__ == '__main__':
    main()
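For quick offline checks, the reply can be written to disk instead of played. A minimal sketch, assuming the server returns headerless little-endian int16 PCM at 16 kHz (which is what tts.convert() produces in this commit); save_response and reply.wav are illustrative names:

import numpy as np
import soundfile as sf

def save_response(speech_bytes, path='reply.wav'):
    # The payload is raw PCM, so the sample rate must be supplied here.
    pcm = np.frombuffer(speech_bytes, dtype=np.int16)
    sf.write(path, pcm, 16000)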

View File

@@ -22,16 +22,17 @@ import pyaudio
 import klepto
 import IPython.display as ipd
 import time
+from sia.file_utils import cached_model_path
 sys.path.append('waveglow/')
 hparams = create_hparams()
 hparams.sampling_rate = 22050
-checkpoint_path = "checkpoint_15000"
 model = load_model(hparams)
+tacotron2_path = cached_model_path("tacotron2_model")
 model.load_state_dict(
-    torch.load(checkpoint_path, map_location='cpu')['state_dict'])
+    torch.load(tacotron2_path, map_location='cpu')['state_dict'])
 model.eval()
-waveglow_path = 'waveglow_256channels.pt'
+waveglow_path = cached_model_path('waveglow_model')
 waveglow = torch.load(waveglow_path, map_location='cpu')['model']
 waveglow.eval()
 for k in waveglow.convinv:
@@ -93,9 +94,10 @@ def player_gen():
 def synthesize_corpus():
     all_data = []
-    for line in open('corpus.txt').readlines():
+    for (i, line) in enumerate(open('corpus.txt').readlines()):
         print('synthesizing... "{}"'.format(line.strip()))
         data = speech(line.strip())
+        sf.write('tts_{}.wav'.format(i), data, 16000)
         all_data.append(data)
     return all_data
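The change above replaces hard-coded checkpoint filenames with sia.file_utils.cached_model_path, so model weights are resolved from a shared cache instead of the working directory. The helper's implementation is not part of this diff; a hypothetical sketch of the shape it likely has (the cache directory and download behavior are assumptions):

import os

def cached_model_path(name, cache_dir=os.path.expanduser('~/.cache/sia')):
    # Hypothetical stand-in for sia.file_utils.cached_model_path:
    # resolve a model name to its expected location in a local cache.
    # The real helper may also fetch the file if it is missing.
    return os.path.join(cache_dir, name)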

hparams.py

@@ -1,7 +1,8 @@
 import tensorflow as tf
 from text import symbols
-#changed path, sampling rate and batch size
+# changed path, sampling rate and batch size
 def create_hparams(hparams_string=None, verbose=False):
     """Create model hyperparameters. Parse nondefault from given string."""

15
requirements_dev.txt Normal file

@@ -0,0 +1,15 @@
pip==18.1
bumpversion==0.5.3
wheel==0.32.1
watchdog==0.9.0
flake8==3.5.0
tox==3.5.2
coverage==4.5.1
Sphinx==1.8.1
twine==1.12.1
pytest==3.8.2
pytest-runner==4.2
pre-commit==1.16.1
python-language-server[all]
ipdb

39
server.py Normal file

@@ -0,0 +1,39 @@
# -*- coding: utf-8 -*-
import grpc
import time
from sia.proto import tts_pb2
from sia.proto import tts_pb2_grpc
from concurrent import futures
from sia.instruments import do_time
from tts import TTSModel


class TTSServer():

    def __init__(self):
        self.tts_model = TTSModel()

    def TextToSpeechAPI(self, request, context):
        # Synthesize the requested text and return the raw PCM bytes.
        input_text = request.text
        speech_response = self.tts_model.synth_speech(input_text)
        return tts_pb2.SpeechResponse(response=speech_response)


def main():
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=1))
    tts_server = TTSServer()
    tts_pb2_grpc.add_ServerServicer_to_server(tts_server, server)
    server.add_insecure_port('localhost:50060')
    server.start()
    print('TTSServer started!')
    try:
        # grpc serves from its thread pool; keep the main thread alive.
        while True:
            time.sleep(10000)
    except KeyboardInterrupt:
        server.stop(0)


if __name__ == "__main__":
    main()
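Two notes on main(). max_workers=1 serializes requests, which is likely the safe choice here since a single TTSModel instance is shared across all calls. And the sleep loop is the classic grpcio keep-alive idiom; if the installed grpcio is recent enough (1.20+, an assumption about the environment), the tail of main() can be simplified:

server.start()
print('TTSServer started!')
# Blocks until server.stop() is called or the process is interrupted.
server.wait_for_termination()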

Binary file not shown (image, 170 KiB before).
128
tts.py Normal file

@@ -0,0 +1,128 @@
#!/usr/bin/env python
# coding: utf-8
# import matplotlib
# import matplotlib.pylab as plt
# import IPython.display as ipd
import sys
import numpy as np
import torch
from hparams import create_hparams
from model import Tacotron2
from layers import TacotronSTFT, STFT
# from audio_processing import griffin_lim
from train import load_model
from text import text_to_sequence
# from denoiser import Denoiser
import os
import soundfile as sf
import pyaudio
import klepto
import IPython.display as ipd
import time
from sia.file_utils import cached_model_path

sys.path.append('waveglow/')


class TTSModel(object):
    """Tacotron 2 acoustic model + WaveGlow vocoder, CPU inference."""

    def __init__(self):
        super(TTSModel, self).__init__()
        hparams = create_hparams()
        hparams.sampling_rate = 22050

        # Load the Tacotron 2 weights from the shared model cache.
        self.model = load_model(hparams)
        tacotron2_path = cached_model_path("tacotron2_model")
        self.model.load_state_dict(
            torch.load(tacotron2_path, map_location='cpu')['state_dict'])
        self.model.eval()

        # Load the WaveGlow vocoder.
        waveglow_path = cached_model_path('waveglow_model')
        self.waveglow = torch.load(waveglow_path, map_location='cpu')['model']
        self.waveglow.eval()
        for k in self.waveglow.convinv:
            k.float()

        # Memoize synth_speech so a repeated sentence is synthesized only once.
        self.k_cache = klepto.archives.file_archive(cached=False)
        self.synth_speech = klepto.safe.inf_cache(cache=self.k_cache)(
            self.synth_speech)

        # Work around checkpoints saved before Conv layers had a
        # padding_mode attribute:
        # https://github.com/NVIDIA/waveglow/issues/127
        for m in self.waveglow.modules():
            if 'Conv' in str(type(m)):
                setattr(m, 'padding_mode', 'zeros')

    def synth_speech(self, t):
        start = time.time()
        text = t
        sequence = np.array(text_to_sequence(text,
                                             ['english_cleaners']))[None, :]
        sequence = torch.autograd.Variable(torch.from_numpy(sequence)).long()
        mel_outputs, mel_outputs_postnet, _, alignments = self.model.inference(
            sequence)
        with torch.no_grad():
            audio = self.waveglow.infer(mel_outputs_postnet, sigma=0.666)
        # import ipdb; ipdb.set_trace()
        data = convert(audio[0].data.cpu().numpy())
        # _audio_stream.write(data.astype('float32'))
        # _audio_stream.write(data)
        end = time.time()
        print(end - start)
        return data.tobytes()


def convert(array):
    # Slow the audio down slightly and resample 22.05 kHz -> 16 kHz via ffmpeg.
    sf.write('sample.wav', array, 22050)
    os.system('ffmpeg -i {0} -filter:a "atempo=0.80" -ar 16k {1}'.format(
        'sample.wav', 'sample0.wav'))
    data, rate = sf.read('sample0.wav', dtype='int16')
    os.remove('sample.wav')
    os.remove('sample0.wav')
    return data


def display(data):
    aud = ipd.Audio(data, rate=16000)
    return aud


def player_gen():
    # Open one 16 kHz mono output stream and return a closure that plays
    # raw int16 PCM bytes on it.
    audio_interface = pyaudio.PyAudio()
    _audio_stream = audio_interface.open(format=pyaudio.paInt16,
                                         channels=1,
                                         rate=16000,
                                         output=True)

    def play_device(data):
        _audio_stream.write(data)
        # _audio_stream.close()

    return play_device


def synthesize_corpus():
    tts_model = TTSModel()
    all_data = []
    for (i, line) in enumerate(open('corpus.txt').readlines()):
        print('synthesizing... "{}"'.format(line.strip()))
        data = tts_model.synth_speech(line.strip())
        all_data.append(data)
    return all_data


def play_corpus(corpus_synths):
    player = player_gen()
    for d in corpus_synths:
        player(d)


def main():
    corpus_synth_data = synthesize_corpus()
    play_corpus(corpus_synth_data)
    import ipdb
    ipdb.set_trace()


if __name__ == '__main__':
    main()
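Putting it together, the class is meant to be used roughly the way the demo client uses the server, but in-process. A minimal sketch (the sentence is illustrative):

from tts import TTSModel, player_gen

model = TTSModel()   # loads Tacotron 2 and WaveGlow once, CPU-only
play = player_gen()  # opens a 16 kHz mono PyAudio output stream

pcm = model.synth_speech('How may I help you today?')
play(pcm)            # raw int16 PCM bytes at 16 kHz

# A repeated sentence is served from the klepto cache, so the second
# call returns without running the neural models again.
play(model.synth_speech('How may I help you today?'))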