1. Implemented corpus WAV generator

2. Refactored TTSModel: moved float2pcm/postprocess_audio into audio_processing.py and split mel generation from vocoding
master
Malar Kannan 2019-10-07 16:00:35 +05:30
parent 36f229449f
commit dde35048b7
3 changed files with 177 additions and 100 deletions

taco2/audio_processing.py
View File

@@ -3,6 +3,8 @@ import torch
import numpy as np
from scipy.signal import get_window
import librosa.util as librosa_util
from librosa import resample
from librosa.effects import time_stretch
def window_sumsquare(
@@ -60,9 +62,7 @@ def window_sumsquare(
# Fill the envelope
for i in range(n_frames):
sample = i * hop_length
x[sample : min(n, sample + n_fft)] += win_sq[
: max(0, min(n_fft, n - sample))
]
x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))]
return x
@@ -101,3 +101,48 @@ def dynamic_range_decompression(x, C=1):
C: compression factor used to compress
"""
return torch.exp(x) / C
# adapted from
# https://github.com/mgeier/python-audio/blob/master/audio-files/utility.py
def float2pcm(sig, dtype="int16"):
"""Convert floating point signal with a range from -1 to 1 to PCM.
Any signal values outside the interval [-1.0, 1.0) are clipped.
No dithering is used.
Note that there are different possibilities for scaling floating
point numbers to PCM numbers, this function implements just one of
them. For an overview of alternatives see
http://blog.bjornroche.com/2009/12/int-float-int-its-jungle-out-there.html
Parameters
----------
sig : array_like
Input array, must have floating point type.
dtype : data type, optional
Desired (integer) data type.
Returns
-------
numpy.ndarray
Integer data, scaled and clipped to the range of the given
*dtype*.
See Also
--------
pcm2float, dtype
"""
sig = np.asarray(sig)
if sig.dtype.kind != "f":
raise TypeError("'sig' must be a float array")
dtype = np.dtype(dtype)
if dtype.kind not in "iu":
raise TypeError("'dtype' must be an integer type")
i = np.iinfo(dtype)
abs_max = 2 ** (i.bits - 1)
offset = i.min + abs_max
return (sig * abs_max + offset).clip(i.min, i.max).astype(dtype)
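An illustrative sanity-check sketch of the int16 scaling this implements (the input values are assumptions for the example, not part of the commit; numpy is already imported in this module):

    x = np.array([-1.0, -0.5, 0.0, 0.5, 1.5])
    # abs_max = 2**15 = 32768, offset = 0 for int16
    print(float2pcm(x))  # -> [-32768 -16384 0 16384 32767] (1.5 clipped to int16 max)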
def postprocess_audio(audio, tempo=0.8, src_rate=22050, dst_rate=16000):
"""Slow audio down by `tempo`, resample from src_rate to dst_rate, and return int16 PCM bytes."""
slow_data = time_stretch(audio, tempo)
float_data = resample(slow_data, src_rate, dst_rate)
data = float2pcm(float_data)
return data.tobytes()
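A usage sketch for this helper (the 440 Hz tone and the sample-count arithmetic are assumptions for the example; time_stretch with tempo 0.8 lengthens the signal by a factor of 1/0.8 before resampling):

    t = np.linspace(0, 1.0, 22050, endpoint=False)
    tone = (0.5 * np.sin(2 * np.pi * 440 * t)).astype(np.float32)
    pcm_bytes = postprocess_audio(tone)
    # roughly 22050 / 0.8 * (16000 / 22050) = 20000 samples of int16, i.e. ~40000 bytes
    print(len(pcm_bytes))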

taco2/generate_corpus.py Normal file
View File

@@ -0,0 +1,63 @@
from .tts import TTSModel, OUTPUT_SAMPLE_RATE
import argparse
from pathlib import Path
import wave
def synthesize_corpus(
corpus_path=Path("corpus.txt"),
tacotron_path=Path("/path/to/tacotron.pt"),
waveglow_path=Path("/path/to/waveglow.pt"),
output_dir=Path("./out_dir"),
):
tts_model = TTSModel(str(tacotron_path), str(waveglow_path))
output_dir.mkdir(exist_ok=True)
for i, line in enumerate(open(str(corpus_path))):
print(f'synthesizing... "{line.strip()}"')
data = tts_model.synth_speech(line.strip())
out_file = str(output_dir / f"{i}.wav")
with wave.open(out_file, "w") as out_file_h:
out_file_h.setnchannels(1) # mono
out_file_h.setsampwidth(2) # PCM int16: 2 bytes per sample
out_file_h.setframerate(OUTPUT_SAMPLE_RATE)
out_file_h.writeframes(data)
def main():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"-t",
"--tacotron_path",
type=Path,
default="./tacotron.pt",
help="Path to a tacotron2 model",
)
parser.add_argument(
"-w",
"--waveglow_path",
type=Path,
default="./waveglow_256channels.pt",
help="Path to a waveglow model",
)
parser.add_argument(
"-c",
"--corpus_path",
type=Path,
default="./corpus.txt",
help="Path to a corpus file",
)
parser.add_argument(
"-o",
"--output_dir",
type=Path,
default="./synth",
help="Path to a output directory",
)
args = parser.parse_args()
synthesize_corpus(**vars(args))
if __name__ == "__main__":
main()
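The same corpus synthesis can also be driven programmatically rather than via the CLI; a minimal sketch, assuming placeholder checkpoint paths:

    from pathlib import Path
    from taco2.generate_corpus import synthesize_corpus

    # the paths below are placeholders for real checkpoints
    synthesize_corpus(
        corpus_path=Path("corpus.txt"),
        tacotron_path=Path("./tacotron.pt"),
        waveglow_path=Path("./waveglow_256channels.pt"),
        output_dir=Path("./synth"),
    )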

taco2/tts.py
View File

@@ -4,14 +4,14 @@
import numpy as np
import torch
import pyaudio
from librosa import resample
from librosa.effects import time_stretch
import klepto
from .model import Tacotron2
from glow import WaveGlow
from .hparams import HParams
from .layers import TacotronSTFT
from .text import text_to_sequence
from .denoiser import Denoiser
from .audio_processing import griffin_lim, postprocess_audio
TTS_SAMPLE_RATE = 22050
OUTPUT_SAMPLE_RATE = 16000
@@ -31,62 +31,44 @@ WAVEGLOW_CONFIG = {
class TTSModel(object):
"""docstring for TTSModel."""
def __init__(self, tacotron2_path, waveglow_path):
def __init__(self, tacotron2_path, waveglow_path, **kwargs):
super(TTSModel, self).__init__()
hparams = HParams()
hparams = HParams(**kwargs)
hparams.sampling_rate = TTS_SAMPLE_RATE
self.model = Tacotron2(hparams)
self.model.load_state_dict(
torch.load(tacotron2_path, map_location="cpu")["state_dict"]
)
self.model.eval()
wave_params = torch.load(waveglow_path, map_location="cpu")
try:
self.waveglow = WaveGlow(**WAVEGLOW_CONFIG)
self.waveglow.load_state_dict(wave_params)
self.waveglow.eval()
except Exception:
self.waveglow = wave_params["model"]
self.waveglow = self.waveglow.remove_weightnorm(self.waveglow)
self.waveglow.eval()
# workaround from
# https://github.com/NVIDIA/waveglow/issues/127
for m in self.waveglow.modules():
if "Conv" in str(type(m)):
setattr(m, "padding_mode", "zeros")
for k in self.waveglow.convinv:
k.float()
self.k_cache = klepto.archives.file_archive(cached=False)
self.synth_speech = klepto.safe.inf_cache(cache=self.k_cache)(self.synth_speech)
self.denoiser = Denoiser(self.waveglow)
def synth_speech(self, text):
sequence = np.array(text_to_sequence(text, ["english_cleaners"]))[None, :]
sequence = torch.autograd.Variable(torch.from_numpy(sequence)).long()
mel_outputs, mel_outputs_postnet, _, alignments = self.model.inference(sequence)
# width = mel_outputs_postnet.shape[2]
# wave_glow_input = torch.randn(1, 80, width)*0.00001
# wave_glow_input[:,40:,:] = mel_outputs_postnet
with torch.no_grad():
audio_t = self.waveglow.infer(mel_outputs_postnet, sigma=0.666)
audio_t = self.denoiser(audio_t, 0.1)[0]
audio = audio_t[0].data.cpu().numpy()
# data = convert(audio)
slow_data = time_stretch(audio, 0.8)
float_data = resample(slow_data, TTS_SAMPLE_RATE, OUTPUT_SAMPLE_RATE)
data = float2pcm(float_data)
return data.tobytes()
def synth_speech_algo(self, text, griffin_iters=60):
sequence = np.array(text_to_sequence(text, ["english_cleaners"]))[None, :]
sequence = torch.autograd.Variable(torch.from_numpy(sequence)).long()
mel_outputs, mel_outputs_postnet, _, alignments = self.model.inference(sequence)
from .hparams import HParams
from .layers import TacotronSTFT
from .audio_processing import griffin_lim
hparams = HParams()
taco_stft = TacotronSTFT(
if waveglow_path:
wave_params = torch.load(waveglow_path, map_location="cpu")
try:
self.waveglow = WaveGlow(**WAVEGLOW_CONFIG)
self.waveglow.load_state_dict(wave_params)
self.waveglow.eval()
except Exception:
# fall back for checkpoints that store the whole model object under "model"
self.waveglow = wave_params["model"]
self.waveglow = self.waveglow.remove_weightnorm(self.waveglow)
self.waveglow.eval()
# workaround from
# https://github.com/NVIDIA/waveglow/issues/127
for m in self.waveglow.modules():
if "Conv" in str(type(m)):
setattr(m, "padding_mode", "zeros")
for k in self.waveglow.convinv:
k.float()
self.denoiser = Denoiser(
self.waveglow, n_mel_channels=hparams.n_mel_channels
)
self.synth_speech = klepto.safe.inf_cache(cache=self.k_cache)(
self.synth_speech
)
else:
self.synth_speech = klepto.safe.inf_cache(cache=self.k_cache)(
self.synth_speech_gl
)
self.taco_stft = TacotronSTFT(
hparams.filter_length,
hparams.hop_length,
hparams.win_length,
@@ -94,63 +76,48 @@ class TTSModel(object):
sampling_rate=hparams.sampling_rate,
mel_fmax=4000,
)
mel_decompress = taco_stft.spectral_de_normalize(mel_outputs_postnet)
def generate_mel_postnet(self, text):
sequence = np.array(text_to_sequence(text, ["english_cleaners"]))[None, :]
sequence = torch.autograd.Variable(torch.from_numpy(sequence)).long()
with torch.no_grad():
mel_outputs, mel_outputs_postnet, _, alignments = self.model.inference(
sequence
)
return mel_outputs_postnet
def synth_speech(self, text):
mel_outputs_postnet = self.generate_mel_postnet(text)
with torch.no_grad():
audio_t = self.waveglow.infer(mel_outputs_postnet, sigma=0.666)
audio_t = self.denoiser(audio_t, 0.1)[0]
audio = audio_t[0].data.cpu().numpy()
return postprocess_audio(
audio, src_rate=TTS_SAMPLE_RATE, dst_rate=OUTPUT_SAMPLE_RATE
)
def synth_speech_gl(self, text, griffin_iters=60):
mel_outputs_postnet = self.generate_mel_postnet(text)
mel_decompress = self.taco_stft.spectral_de_normalize(mel_outputs_postnet)
mel_decompress = mel_decompress.transpose(1, 2).data.cpu()
spec_from_mel_scaling = 1000
spec_from_mel = torch.mm(mel_decompress[0], taco_stft.mel_basis)
spec_from_mel = torch.mm(mel_decompress[0], self.taco_stft.mel_basis)
spec_from_mel = spec_from_mel.transpose(0, 1).unsqueeze(0)
spec_from_mel = spec_from_mel * spec_from_mel_scaling
audio = griffin_lim(
torch.autograd.Variable(spec_from_mel[:, :, :-1]),
taco_stft.stft_fn,
self.taco_stft.stft_fn,
griffin_iters,
)
audio = audio.squeeze()
audio = audio.cpu().numpy()
slow_data = time_stretch(audio, 0.8)
float_data = resample(slow_data, TTS_SAMPLE_RATE, OUTPUT_SAMPLE_RATE)
data = float2pcm(float_data)
return data.tobytes()
# adapted from
# https://github.com/mgeier/python-audio/blob/master/audio-files/utility.py
def float2pcm(sig, dtype="int16"):
"""Convert floating point signal with a range from -1 to 1 to PCM.
Any signal values outside the interval [-1.0, 1.0) are clipped.
No dithering is used.
Note that there are different possibilities for scaling floating
point numbers to PCM numbers, this function implements just one of
them. For an overview of alternatives see
http://blog.bjornroche.com/2009/12/int-float-int-its-jungle-out-there.html
Parameters
----------
sig : array_like
Input array, must have floating point type.
dtype : data type, optional
Desired (integer) data type.
Returns
-------
numpy.ndarray
Integer data, scaled and clipped to the range of the given
*dtype*.
See Also
--------
pcm2float, dtype
"""
sig = np.asarray(sig)
if sig.dtype.kind != "f":
raise TypeError("'sig' must be a float array")
dtype = np.dtype(dtype)
if dtype.kind not in "iu":
raise TypeError("'dtype' must be an integer type")
i = np.iinfo(dtype)
abs_max = 2 ** (i.bits - 1)
offset = i.min + abs_max
return (sig * abs_max + offset).clip(i.min, i.max).astype(dtype)
return postprocess_audio(
audio, src_rate=TTS_SAMPLE_RATE, dst_rate=OUTPUT_SAMPLE_RATE
)
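An illustrative sketch of how the refactored constructor selects the vocoder path (checkpoint paths are placeholders): a truthy waveglow_path wires synth_speech to WaveGlow plus the denoiser, while a falsy one caches synth_speech_gl (Griffin-Lim) under the same name.

    tts = TTSModel("./tacotron.pt", "./waveglow_256channels.pt")
    wav_bytes = tts.synth_speech("hello world")        # WaveGlow vocoder + denoiser

    tts_gl = TTSModel("./tacotron.pt", None)           # no vocoder checkpoint
    wav_bytes_gl = tts_gl.synth_speech("hello world")  # wraps synth_speech_gl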
def player_gen():
@@ -167,7 +134,9 @@ def player_gen():
def repl():
tts_model = TTSModel("/path/to/tacotron2.pt", "/path/to/waveglow.pt")
tts_model = TTSModel(
"/Users/malar/Work/tacotron2_r4_83000.pt", "/Users/malar/Work/waveglow_484000"
)
player = player_gen()
def loop():