mirror of
https://github.com/malarinv/tacotron2
synced 2026-03-08 01:32:35 +00:00
1. clean-up
2. update readme and release info
This commit is contained in:
67
taco2/tts.py
67
taco2/tts.py
@@ -46,7 +46,7 @@ class TTSModel(object):
|
||||
self.waveglow.load_state_dict(wave_params)
|
||||
self.waveglow.eval()
|
||||
except:
|
||||
self.waveglow = wave_params['model']
|
||||
self.waveglow = wave_params["model"]
|
||||
self.waveglow = self.waveglow.remove_weightnorm(self.waveglow)
|
||||
self.waveglow.eval()
|
||||
# workaround from
|
||||
@@ -60,7 +60,6 @@ class TTSModel(object):
|
||||
self.synth_speech = klepto.safe.inf_cache(cache=self.k_cache)(self.synth_speech)
|
||||
self.denoiser = Denoiser(self.waveglow)
|
||||
|
||||
|
||||
def synth_speech(self, text):
|
||||
sequence = np.array(text_to_sequence(text, ["english_cleaners"]))[None, :]
|
||||
sequence = torch.autograd.Variable(torch.from_numpy(sequence)).long()
|
||||
@@ -78,15 +77,23 @@ class TTSModel(object):
|
||||
data = float2pcm(float_data)
|
||||
return data.tobytes()
|
||||
|
||||
def synth_speech_algo(self,text,griffin_iters=60):
|
||||
def synth_speech_algo(self, text, griffin_iters=60):
|
||||
sequence = np.array(text_to_sequence(text, ["english_cleaners"]))[None, :]
|
||||
sequence = torch.autograd.Variable(torch.from_numpy(sequence)).long()
|
||||
mel_outputs, mel_outputs_postnet, _, alignments = self.model.inference(sequence)
|
||||
from .hparams import HParams
|
||||
from .layers import TacotronSTFT
|
||||
from .audio_processing import griffin_lim
|
||||
|
||||
hparams = HParams()
|
||||
taco_stft = TacotronSTFT(hparams.filter_length, hparams.hop_length, hparams.win_length, n_mel_channels=hparams.n_mel_channels, sampling_rate=hparams.sampling_rate, mel_fmax=4000)
|
||||
taco_stft = TacotronSTFT(
|
||||
hparams.filter_length,
|
||||
hparams.hop_length,
|
||||
hparams.win_length,
|
||||
n_mel_channels=hparams.n_mel_channels,
|
||||
sampling_rate=hparams.sampling_rate,
|
||||
mel_fmax=4000,
|
||||
)
|
||||
mel_decompress = taco_stft.spectral_de_normalize(mel_outputs_postnet)
|
||||
mel_decompress = mel_decompress.transpose(1, 2).data.cpu()
|
||||
spec_from_mel_scaling = 1000
|
||||
@@ -94,7 +101,11 @@ class TTSModel(object):
|
||||
spec_from_mel = spec_from_mel.transpose(0, 1).unsqueeze(0)
|
||||
spec_from_mel = spec_from_mel * spec_from_mel_scaling
|
||||
|
||||
audio = griffin_lim(torch.autograd.Variable(spec_from_mel[:, :, :-1]), taco_stft.stft_fn, griffin_iters)
|
||||
audio = griffin_lim(
|
||||
torch.autograd.Variable(spec_from_mel[:, :, :-1]),
|
||||
taco_stft.stft_fn,
|
||||
griffin_iters,
|
||||
)
|
||||
audio = audio.squeeze()
|
||||
audio = audio.cpu().numpy()
|
||||
|
||||
@@ -102,6 +113,8 @@ class TTSModel(object):
|
||||
float_data = resample(slow_data, TTS_SAMPLE_RATE, OUTPUT_SAMPLE_RATE)
|
||||
data = float2pcm(float_data)
|
||||
return data.tobytes()
|
||||
|
||||
|
||||
# adapted from
|
||||
# https://github.com/mgeier/python-audio/blob/master/audio-files/utility.py
|
||||
def float2pcm(sig, dtype="int16"):
|
||||
@@ -140,13 +153,6 @@ def float2pcm(sig, dtype="int16"):
|
||||
return (sig * abs_max + offset).clip(i.min, i.max).astype(dtype)
|
||||
|
||||
|
||||
def display(data):
|
||||
import IPython.display as ipd
|
||||
|
||||
aud = ipd.Audio(data, rate=16000)
|
||||
return aud
|
||||
|
||||
|
||||
def player_gen():
|
||||
audio_interface = pyaudio.PyAudio()
|
||||
_audio_stream = audio_interface.open(
|
||||
@@ -160,51 +166,22 @@ def player_gen():
|
||||
return play_device
|
||||
|
||||
|
||||
def synthesize_corpus():
|
||||
tts_model = TTSModel(
|
||||
"/Users/malar/Work/tacotron2_statedict.pt",
|
||||
"/Users/malar/Work/waveglow.pt",
|
||||
)
|
||||
all_data = []
|
||||
for (i, line) in enumerate(open("corpus.txt").readlines()):
|
||||
print(f'synthesizing... "{line.strip()}"')
|
||||
data = tts_model.synth_speech(line.strip())
|
||||
all_data.append(data)
|
||||
return all_data
|
||||
|
||||
def repl():
|
||||
tts_model = TTSModel(
|
||||
# "/Users/malar/Work/tacotron2_statedict.pt",
|
||||
# "/Users/malar/Work/tacotron2_80_22000.pt",
|
||||
"/path/to/tacotron2.pt",
|
||||
# "/Users/malar/Work/tacotron2_40_22000.pt",
|
||||
# "/Users/malar/Work/tacotron2_16000.pt",
|
||||
"/path/to/waveglow.pt",
|
||||
# "/Users/malar/Work/waveglow.pt",
|
||||
# "/Users/malar/Work/waveglow_38000",
|
||||
)
|
||||
tts_model = TTSModel("/path/to/tacotron2.pt", "/path/to/waveglow.pt")
|
||||
player = player_gen()
|
||||
|
||||
def loop():
|
||||
text = input('tts >')
|
||||
text = input("tts >")
|
||||
data = tts_model.synth_speech(text.strip())
|
||||
player(data)
|
||||
|
||||
return loop
|
||||
|
||||
|
||||
def play_corpus(corpus_synths):
|
||||
player = player_gen()
|
||||
for d in corpus_synths:
|
||||
player(d)
|
||||
|
||||
|
||||
def main():
|
||||
# corpus_synth_data = synthesize_corpus()
|
||||
# play_corpus(corpus_synth_data)
|
||||
interactive_loop = repl()
|
||||
while True:
|
||||
interactive_loop()
|
||||
# import pdb
|
||||
# pdb.set_trace()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user