Mirror of https://github.com/malarinv/tacotron2 — synced 2026-03-08 01:32:35 +00:00

Enable GPU support if available

This commit is contained in:
2019-11-27 21:42:10 +05:30
parent 5a30069f0a
commit ac5ffcf6d5
6 changed files with 87 additions and 48 deletions

View File

@@ -9,9 +9,14 @@ class Denoiser(torch.nn.Module):
def __init__(self, waveglow, filter_length=1024, n_overlap=4,
win_length=1024, mode='zeros', n_mel_channels=80,):
super(Denoiser, self).__init__()
self.stft = STFT(filter_length=filter_length,
hop_length=int(filter_length/n_overlap),
win_length=win_length).cpu()
if torch.cuda.is_available():
self.stft = STFT(filter_length=filter_length,
hop_length=int(filter_length/n_overlap),
win_length=win_length).cuda()
else:
self.stft = STFT(filter_length=filter_length,
hop_length=int(filter_length/n_overlap),
win_length=win_length).cpu()
if mode == 'zeros':
mel_input = torch.zeros(
(1, n_mel_channels, 88),
@@ -32,7 +37,10 @@ class Denoiser(torch.nn.Module):
self.register_buffer('bias_spec', bias_spec[:, :, 0][:, :, None])
def forward(self, audio, strength=0.1):
audio_spec, audio_angles = self.stft.transform(audio.cpu().float())
if torch.cuda.is_available():
audio_spec, audio_angles = self.stft.transform(audio.cuda().float())
else:
audio_spec, audio_angles = self.stft.transform(audio.cpu().float())
audio_spec_denoised = audio_spec - self.bias_spec * strength
audio_spec_denoised = torch.clamp(audio_spec_denoised, 0.0)
audio_denoised = self.stft.inverse(audio_spec_denoised, audio_angles)

View File

@@ -39,9 +39,9 @@ class HParams(object):
filter_length = 1024
hop_length = 256
win_length = 1024
n_mel_channels: int = 40
n_mel_channels: int = 80
mel_fmin: float = 0.0
mel_fmax: float = 4000.0
mel_fmax: float = 8000.0
################################
# Model Parameters #
################################

View File

@@ -40,6 +40,7 @@ from scipy.signal import get_window
from librosa.util import pad_center, tiny
from .audio_processing import window_sumsquare
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
class STFT(torch.nn.Module):
"""
@@ -147,9 +148,7 @@ class STFT(torch.nn.Module):
window_sum = torch.autograd.Variable(
torch.from_numpy(window_sum), requires_grad=False
)
# window_sum = window_sum.cuda() if magnitude.is_cuda else
# window_sum
# initially not commented out
window_sum = window_sum.to(DEVICE)
inverse_transform[:, :, approx_nonzero_indices] /= window_sum[
approx_nonzero_indices
]

View File

@@ -39,20 +39,29 @@ class TTSModel(object):
hparams = HParams(**kwargs)
hparams.sampling_rate = TTS_SAMPLE_RATE
self.model = Tacotron2(hparams)
self.model.load_state_dict(
torch.load(tacotron2_path, map_location="cpu")["state_dict"]
)
self.model.eval()
if torch.cuda.is_available():
self.model.load_state_dict(torch.load(tacotron2_path)["state_dict"])
self.model.cuda().eval()
else:
self.model.load_state_dict(
torch.load(tacotron2_path, map_location="cpu")["state_dict"]
)
self.model.eval()
self.k_cache = klepto.archives.file_archive(cached=False)
if waveglow_path:
wave_params = torch.load(waveglow_path, map_location="cpu")
if torch.cuda.is_available():
wave_params = torch.load(waveglow_path)
else:
wave_params = torch.load(waveglow_path, map_location="cpu")
try:
self.waveglow = WaveGlow(**WAVEGLOW_CONFIG)
self.waveglow.load_state_dict(wave_params)
self.waveglow.eval()
except:
self.waveglow = wave_params["model"]
self.waveglow = self.waveglow.remove_weightnorm(self.waveglow)
if torch.cuda.is_available():
self.waveglow.cuda().eval()
else:
self.waveglow.eval()
# workaround from
# https://github.com/NVIDIA/waveglow/issues/127
@@ -60,7 +69,7 @@ class TTSModel(object):
if "Conv" in str(type(m)):
setattr(m, "padding_mode", "zeros")
for k in self.waveglow.convinv:
k.float()
k.float().half()
self.denoiser = Denoiser(
self.waveglow, n_mel_channels=hparams.n_mel_channels
)
@@ -82,20 +91,27 @@ class TTSModel(object):
def generate_mel_postnet(self, text):
    """Convert *text* to a symbol sequence and run Tacotron2 inference.

    Args:
        text: input string to synthesize.

    Returns:
        The post-net mel spectrogram tensor produced by
        ``self.model.inference`` (device matches the model: CUDA when
        available, otherwise CPU).
    """
    sequence = np.array(text_to_sequence(text, ["english_cleaners"]))[None, :]
    # Build the tensor once; only the device placement differs by branch
    # (the original duplicated the whole construction in both branches and
    # left a dead pre-branch assignment that was immediately overwritten).
    seq_t = torch.autograd.Variable(torch.from_numpy(sequence)).long()
    if torch.cuda.is_available():
        # Inputs must live on the GPU to match the CUDA model weights.
        seq_t = seq_t.cuda()
    with torch.no_grad():
        mel_outputs, mel_outputs_postnet, _, alignments = self.model.inference(
            seq_t
        )
    return mel_outputs_postnet
def synth_speech(self, text):
def synth_speech_array(self, text):
    """Generate speech audio for *text* and return it as a numpy array."""
    mel_postnet = self.generate_mel_postnet(text)
    with torch.no_grad():
        # WaveGlow vocoding followed by spectral denoising.
        waveform = self.waveglow.infer(mel_postnet, sigma=0.666)
        waveform = self.denoiser(waveform, 0.1)[0]
    return waveform[0].data.cpu().numpy()
def synth_speech(self, text):
audio = self.synth_speech_array(text)
return postprocess_audio(
audio, src_rate=TTS_SAMPLE_RATE, dst_rate=OUTPUT_SAMPLE_RATE

View File

@@ -27,6 +27,6 @@ def load_filepaths_and_text(filename, split="|"):
def to_gpu(x):
    """Return *x* as a contiguous Variable, moved to the GPU when available.

    Args:
        x: a torch tensor (any device).

    Returns:
        ``torch.autograd.Variable`` wrapping the contiguous (and, when CUDA
        is available, GPU-resident) tensor.
    """
    x = x.contiguous()
    if torch.cuda.is_available():
        # non_blocking lets the host proceed while the H2D copy is in
        # flight (only effective for pinned-memory source tensors).
        x = x.cuda(non_blocking=True)
    return torch.autograd.Variable(x)