diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..6318ddd
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,166 @@
+
+# Created by https://www.gitignore.io/api/python
+# Edit at https://www.gitignore.io/?templates=python
+
+### Python ###
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# End of https://www.gitignore.io/api/python
+
+# Created by https://www.gitignore.io/api/macos
+# Edit at https://www.gitignore.io/?templates=macos
+
+### macOS ###
+# General
+.DS_Store
+.AppleDouble
+.LSOverride
+
+# Icon must end with two \r
+Icon
+
+# Thumbnails
+._*
+
+# Files that might appear in the root of a volume
+.DocumentRevisions-V100
+.fseventsd
+.Spotlight-V100
+.TemporaryItems
+.Trashes
+.VolumeIcon.icns
+.com.apple.timemachine.donotpresent
+
+# Directories potentially created on remote AFP share
+.AppleDB
+.AppleDesktop
+Network Trash Folder
+Temporary Items
+.apdisk
+
+# End of https://www.gitignore.io/api/macos
+
+*.pkl
diff --git a/demo_client.py b/demo_client.py
new file mode 100644
index 0000000..bc2a12e
--- /dev/null
+++ b/demo_client.py
@@ -0,0 +1,19 @@
+import grpc
+from sia.proto import tts_pb2
+from sia.proto import tts_pb2_grpc
+from tts import player_gen
+
+
+def main():
+    channel = grpc.insecure_channel('localhost:50060')
+    stub = tts_pb2_grpc.ServerStub(channel)
+    test_text = tts_pb2.TextInput(text='How may I help you today?')
+    speech = stub.TextToSpeechAPI(test_text)
+    player = player_gen()
+    player(speech.response)
+    import pdb
+    pdb.set_trace()
+
+
+if __name__ == '__main__':
+    main()
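demo_client.py issues a single blocking unary call and then drops into pdb. For reference, a minimal sketch of the same call with a client-side deadline and basic error handling; the `synthesize` helper and its `timeout_s` parameter are illustrative and not part of this diff:

```python
import grpc

from sia.proto import tts_pb2
from sia.proto import tts_pb2_grpc


def synthesize(text, target='localhost:50060', timeout_s=30.0):
    # Same unary TextToSpeechAPI call as demo_client.main(), but with a
    # deadline so a stalled server does not hang the caller forever.
    channel = grpc.insecure_channel(target)
    stub = tts_pb2_grpc.ServerStub(channel)
    try:
        reply = stub.TextToSpeechAPI(tts_pb2.TextInput(text=text),
                                     timeout=timeout_s)
        # reply.response carries the raw 16 kHz int16 PCM bytes produced
        # by tts.convert() on the server side.
        return reply.response
    except grpc.RpcError as err:
        print('TTS request failed: {}'.format(err.code()))
        return None
```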
diff --git a/final.py b/final.py
index 31d63b7..0b826b2 100644
--- a/final.py
+++ b/final.py
@@ -22,16 +22,17 @@ import pyaudio
 import klepto
 import IPython.display as ipd
 import time
+from sia.file_utils import cached_model_path
 
 sys.path.append('waveglow/')
 hparams = create_hparams()
 hparams.sampling_rate = 22050
-checkpoint_path = "checkpoint_15000"
 model = load_model(hparams)
+tacotron2_path = cached_model_path("tacotron2_model")
 model.load_state_dict(
-    torch.load(checkpoint_path, map_location='cpu')['state_dict'])
+    torch.load(tacotron2_path, map_location='cpu')['state_dict'])
 model.eval()
-waveglow_path = 'waveglow_256channels.pt'
+waveglow_path = cached_model_path('waveglow_model')
 waveglow = torch.load(waveglow_path, map_location='cpu')['model']
 waveglow.eval()
 for k in waveglow.convinv:
@@ -93,9 +94,10 @@ def player_gen():
 
 def synthesize_corpus():
     all_data = []
-    for line in open('corpus.txt').readlines():
+    for (i, line) in enumerate(open('corpus.txt').readlines()):
         print('synthesizing... "{}"'.format(line.strip()))
         data = speech(line.strip())
+        sf.write('tts_{}.wav'.format(i), data, 16000)
         all_data.append(data)
     return all_data
 
diff --git a/hparams.py b/hparams.py
index 9905136..9704bbd 100644
--- a/hparams.py
+++ b/hparams.py
@@ -1,7 +1,8 @@
 import tensorflow as tf
 from text import symbols
 
-#changed path, sampling rate and batch size
+
+# changed path, sampling rate and batch size
 def create_hparams(hparams_string=None, verbose=False):
     """Create model hyperparameters. Parse nondefault from given string."""
 
diff --git a/requirements_dev.txt b/requirements_dev.txt
new file mode 100644
index 0000000..c5473c1
--- /dev/null
+++ b/requirements_dev.txt
@@ -0,0 +1,15 @@
+pip==18.1
+bumpversion==0.5.3
+wheel==0.32.1
+watchdog==0.9.0
+flake8==3.5.0
+tox==3.5.2
+coverage==4.5.1
+Sphinx==1.8.1
+twine==1.12.1
+
+pytest==3.8.2
+pytest-runner==4.2
+pre-commit==1.16.1
+python-language-server[all]
+ipdb
diff --git a/server.py b/server.py
new file mode 100644
index 0000000..7611e04
--- /dev/null
+++ b/server.py
@@ -0,0 +1,38 @@
+# -*- coding: utf-8 -*-
+import grpc
+import time
+from sia.proto import tts_pb2
+from sia.proto import tts_pb2_grpc
+from concurrent import futures
+from sia.instruments import do_time
+from tts import TTSModel
+
+
+class TTSServer():
+    def __init__(self):
+        self.tts_model = TTSModel()
+
+    def TextToSpeechAPI(self, request, context):
+        # Unary RPC: synthesize the requested text and return raw audio bytes.
+        input_text = request.text
+        speech_response = self.tts_model.synth_speech(input_text)
+        return tts_pb2.SpeechResponse(response=speech_response)
+
+
+def main():
+    server = grpc.server(futures.ThreadPoolExecutor(max_workers=1))
+    tts_server = TTSServer()
+    tts_pb2_grpc.add_ServerServicer_to_server(tts_server, server)
+    server.add_insecure_port('localhost:50060')
+    server.start()
+    print('TTSServer started!')
+
+    try:
+        while True:
+            time.sleep(10000)
+    except KeyboardInterrupt:
+        server.stop(0)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tensorboard.png b/tensorboard.png
deleted file mode 100644
index 59223cf..0000000
Binary files a/tensorboard.png and /dev/null differ
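server.py creates its thread pool with max_workers=1, so requests are handled one at a time by the single shared TTSModel. If the pool were ever widened, the shared model would need its own synchronization, since nothing in tts.py guarantees that Tacotron2/WaveGlow inference is thread-safe. A hypothetical sketch of such a guard; the LockedTTSModel wrapper is not part of this diff:

```python
import threading

from tts import TTSModel


class LockedTTSModel:
    """Hypothetical wrapper: serializes calls into a shared TTSModel so the
    gRPC thread pool could be widened without interleaving inference calls."""

    def __init__(self):
        self._model = TTSModel()
        self._lock = threading.Lock()

    def synth_speech(self, text):
        # Only one synthesis runs at a time, regardless of the pool size.
        with self._lock:
            return self._model.synth_speech(text)
```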
diff --git a/tts.py b/tts.py
new file mode 100644
index 0000000..a946fa0
--- /dev/null
+++ b/tts.py
@@ -0,0 +1,128 @@
+#!/usr/bin/env python
+# coding: utf-8
+
+# import matplotlib
+# import matplotlib.pylab as plt
+
+# import IPython.display as ipd
+
+import sys
+import numpy as np
+import torch
+from hparams import create_hparams
+from model import Tacotron2
+from layers import TacotronSTFT, STFT
+# from audio_processing import griffin_lim
+from train import load_model
+from text import text_to_sequence
+# from denoiser import Denoiser
+import os
+import soundfile as sf
+import pyaudio
+import klepto
+import IPython.display as ipd
+import time
+from sia.file_utils import cached_model_path
+
+sys.path.append('waveglow/')
+
+
+class TTSModel(object):
+    """Tacotron2 + WaveGlow text-to-speech wrapper with memoised synthesis."""
+
+    def __init__(self):
+        super(TTSModel, self).__init__()
+        hparams = create_hparams()
+        hparams.sampling_rate = 22050
+        self.model = load_model(hparams)
+        tacotron2_path = cached_model_path("tacotron2_model")
+        self.model.load_state_dict(
+            torch.load(tacotron2_path, map_location='cpu')['state_dict'])
+        self.model.eval()
+        waveglow_path = cached_model_path('waveglow_model')
+        self.waveglow = torch.load(waveglow_path, map_location='cpu')['model']
+        self.waveglow.eval()
+        for k in self.waveglow.convinv:
+            k.float()
+        self.k_cache = klepto.archives.file_archive(cached=False)
+        self.synth_speech = klepto.safe.inf_cache(cache=self.k_cache)(
+            self.synth_speech)
+
+        # https://github.com/NVIDIA/waveglow/issues/127
+        for m in self.waveglow.modules():
+            if 'Conv' in str(type(m)):
+                setattr(m, 'padding_mode', 'zeros')
+
+    def synth_speech(self, t):
+        start = time.time()
+        text = t
+        sequence = np.array(text_to_sequence(text,
+                                             ['english_cleaners']))[None, :]
+        sequence = torch.autograd.Variable(torch.from_numpy(sequence)).long()
+        mel_outputs, mel_outputs_postnet, _, alignments = self.model.inference(
+            sequence)
+        with torch.no_grad():
+            audio = self.waveglow.infer(mel_outputs_postnet, sigma=0.666)
+        # import ipdb; ipdb.set_trace()
+        data = convert(audio[0].data.cpu().numpy())
+        # _audio_stream.write(data.astype('float32'))
+        # _audio_stream.write(data)
+        end = time.time()
+        print(end - start)
+        return data.tobytes()
+
+
+def convert(array):
+    sf.write('sample.wav', array, 22050)
+    os.system('ffmpeg -i {0} -filter:a "atempo=0.80" -ar 16k {1}'.format(
+        'sample.wav', 'sample0.wav'))
+    data, rate = sf.read('sample0.wav', dtype='int16')
+    os.remove('sample.wav')
+    os.remove('sample0.wav')
+    return data
+
+
+def display(data):
+    aud = ipd.Audio(data, rate=16000)
+    return aud
+
+
+def player_gen():
+    audio_interface = pyaudio.PyAudio()
+    _audio_stream = audio_interface.open(format=pyaudio.paInt16,
+                                         channels=1,
+                                         rate=16000,
+                                         output=True)
+
+    def play_device(data):
+        _audio_stream.write(data)
+        # _audio_stream.close()
+
+    return play_device
+
+
+def synthesize_corpus():
+    tts_model = TTSModel()
+    all_data = []
+    for (i, line) in enumerate(open('corpus.txt').readlines()):
+        print('synthesizing... "{}"'.format(line.strip()))
+        data = tts_model.synth_speech(line.strip())
+        all_data.append(data)
+    return all_data
+
+
+def play_corpus(corpus_synths):
+    player = player_gen()
+    for d in corpus_synths:
+        player(d)
+
+
+def main():
+    corpus_synth_data = synthesize_corpus()
+    play_corpus(corpus_synth_data)
+    import ipdb
+    ipdb.set_trace()
+
+
+if __name__ == '__main__':
+    main()
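tts.convert() resamples the WaveGlow output to 16 kHz with ffmpeg and reads it back as int16 samples, and synth_speech() returns those samples via ndarray.tobytes(); the demo client plays the same bytes through a paInt16 PyAudio stream. Assuming that layout (native-endian 16-bit PCM at 16 kHz), a caller could also write the response straight to a WAV file, roughly as below; the save_response helper is illustrative, not part of this diff:

```python
import numpy as np
import soundfile as sf


def save_response(pcm_bytes, path='reply.wav', rate=16000):
    # The server returns convert()'s int16 samples serialised with tobytes(),
    # so the payload can be reinterpreted as a flat int16 array and written out.
    samples = np.frombuffer(pcm_bytes, dtype=np.int16)
    sf.write(path, samples, rate)
```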