tacotron2/inference.ipynb

235 lines
202 KiB
Plaintext
Raw Normal View History

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Tacotron 2 inference code \n",
"Edit the variables **checkpoint_path**, **waveglow_path** and **text** to match yours, then run the entire notebook to plot the mel outputs and alignments and to synthesize audio from the generated mel-spectrogram using WaveGlow."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Import libraries and setup matplotlib"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": []
}
],
"source": [
"import matplotlib\n",
"matplotlib.use(\"Agg\")\n",
"import matplotlib.pylab as plt\n",
"%matplotlib inline\n",
"import IPython.display as ipd\n",
"\n",
"import sys\n",
"sys.path.append('waveglow/')\n",
"import numpy as np\n",
"import torch\n",
"\n",
"from hparams import create_hparams\n",
"from model import Tacotron2\n",
"from layers import TacotronSTFT\n",
"from audio_processing import griffin_lim\n",
"from train import load_model\n",
"from text import text_to_sequence\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def plot_data(data, figsize=(16, 4)):\n",
"    \"\"\"Show every image in `data` side by side in a single row.\n",
"\n",
"    Args:\n",
"        data: sequence of array-like images accepted by `imshow`.\n",
"        figsize: (width, height) of the whole figure, in inches.\n",
"    \"\"\"\n",
"    # squeeze=False keeps `axes` 2-D, so a single-panel call stays indexable.\n",
"    fig, axes = plt.subplots(1, len(data), figsize=figsize, squeeze=False)\n",
"    for ax, image in zip(axes[0], data):\n",
"        # 'lower' draws row 0 at the bottom; 'bottom' is not a valid origin\n",
"        # and is rejected by newer matplotlib releases.\n",
"        ax.imshow(image, aspect='auto', origin='lower',\n",
"                  interpolation='none')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Setup hparams"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
2018-11-27 20:04:36 +00:00
"outputs": [],
"source": [
2018-11-27 20:04:36 +00:00
"hparams = create_hparams()\n",
"hparams.sampling_rate = 22050"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Load model from checkpoint"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"checkpoint_path = \"tacotron2_statedict\"\n",
"model = load_model(hparams)\n",
2018-11-27 20:04:36 +00:00
"model.load_state_dict(torch.load(checkpoint_path)['state_dict'])\n",
"_ = model.eval()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Load WaveGlow for mel2audio synthesis"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": []
}
],
"source": [
"waveglow_path = 'waveglow_old.pt'\n",
"# NOTE(review): torch.load unpickles the whole model object, which can run\n",
"# arbitrary code -- only load WaveGlow checkpoints from a trusted source.\n",
"waveglow = torch.load(waveglow_path)['model']\n",
"# Inference only: move to the GPU, switch to eval mode (as done for the\n",
"# Tacotron 2 model) and convert to half precision.\n",
"waveglow.cuda().eval().half()\n",
"# Cast the `convinv` modules back to float32 after the half conversion.\n",
"for k in waveglow.convinv:\n",
"    k.float()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Prepare text input"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"text = \"Waveglow is really awesome!\"\n",
"# Encode the text and add a leading batch axis.\n",
"sequence = np.array(text_to_sequence(text, ['english_cleaners']))[None, :]\n",
"# torch.autograd.Variable is a deprecated no-op wrapper; use the tensor directly.\n",
"sequence = torch.from_numpy(sequence).cuda().long()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Decode text input and plot results"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
2018-11-27 20:04:36 +00:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA6IAAAD8CAYAAABtlBmdAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4wLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvqOYd8AAAIABJREFUeJzsvVusbVl6HvT9Y8y51r6cS126u7rabdLYDo4MUhJiRFAECpiAkgdMXizyEBlkyXmAgKU8JPDEYx4SR3mKZByQI4VLSIISAYpIrOTBQrJkjEViN8jCtOVLd7XLXVXn7L3XWnPOMX4exvgvc+21T53qPrVP1an/k85Zc881L+M2v7nm/L/x/cTMCAQCgUAgEAgEAoFA4L6QXnYBAoFAIBAIBAKBQCDw2UI8iAYCgUAgEAgEAoFA4F4RD6KBQCAQCAQCgUAgELhXxINoIBAIBAKBQCAQCATuFfEgGggEAoFAIBAIBAKBe0U8iAYCgUAgEAgEAoFA4F4RD6KBQCAQCAQCgUAgELhXxINoIBAIBAKBQCAQCATuFfEgGggEAoFAIBAIBAKBe8Vwnyfb0JbPcAlKCcgZAMDbAXVsz8PLGcC5b6yPyAyQrSPitpbtuES2rOuZAFmuBKp9dQJA8gXZKUo/xWzLVAGqvNq0bNzJ+jZA24f6yZlIy89k30uBiG0/MIOTbOQOLKdlO8Z6vWsA+SoROJPWU8tW7Rx17PXYAjS2DXKqSGT1LLUVviy9EoW0TbAquz+HfUqbkSsiH9VPyiP9zcm1FdxydX1T+gGrOxa5YyX4Lu07AVXGlCsTFSAtsszaTjhRdrjyrAbbiX4B2Tbczy/1WJVLK9rXD4TlvC9v28ZDrhhzcafiXv2EpMsE6suF+3VUEmrvO1qotddR/fXvCv1+1XduPGuhmdf7nxiDbdsT65/x/W55gqns6OT2n1LcxXVl06+t7R1cJ8tHHUUnWsea3/VxdcsrrpOd6DTXsV0HgrqhFfcYL7JxgL/+yH2v3OvGEbnjnbp2+PR6OkH2nGBc56pJlRv/Aqibtm7FdbnqtQMAS+e6WhzXaVs6XmOsrhOtp7bfuu3YdZhyXL7NU776VIGk95MP4Tqy9dKmrU3suM/LdX5bf7yPzHXS/0IbdLTts7huqBiT3GTWkFP7Fv5IXOfupSe5zt/nP2auA4An0zffZebP373jpwvCdYF7BvmL6qWW5JOPUzdQ4PR1/R2fS/8DEek5+OM418lzr/8gIkCfMZItp8abnAg89HtmJr2H1AFgeUIcKoZszwtydOHsszRj7DfDkQqGvvzVfzo/F9fd64PoGS7xr9IPIV1cIr32GAAwfc9buP7SFgDw3vcnzI9aJetZ77zMQP8RQZuKYWwVLCWBUtsmJbv7VPlhMSfw3JbT0wH50Bp6OWewHE/uuAth837b9vwdYPtB+37cMfKuH7v/Tnn65cHd+Ambp60Mm6uKNPeHg5FQtv18/XP7tOrNLk2MYdcfMCpjOWs9X/uPVP8DJx8Kyrb/kCUgzX0wzO6O27GcZRxeG/oyYdi3843XVctz9aV2rCffVzF+8QYA8PjBDhfj3M5NjPdv2q+ED95rNxd6b8Rw3fbPO0KeepNMwLCTc/TPG8Z4Xfr31f1YItRRLgbCctHqenjUPudLQu2jkRgo/UfkeMXYftCOvf2g9Dap+vKiDoT5Ul5kkPvR1z7LlnB43BsoAdR/kG3fZ5x9q/fdk34BPZ2Rd70dlmrkxQweex8kAuc+GNyPwrT0cTtm+8GZCWXbth1uipZXH0oHAi2tDLvPj/jWv9j2m79nDwB447VrvP3wSStjWvRlwc2ywVlu5dyXUddfz63RfufqEk/fbX23eWdElmc8svrLb768A4abtv+wZ2yu+g/D66JjjKb2maaltQsAFAZVNwaLVIqN3B3x2g89Bvx+lfG/f+O/xauGU1w3f+ULePp7zgAAH3xvwv
RaHzPbznUDA9t+7QwVKff1MI5z9zWwe2HEhzY+03VG6lxXzit4PHqgnZJx3TeN64YdY7hZc8r124O+uGLCmusmKbNxnbykG6+NC9NUdRwxEYpwXecCKqwPR2lm44j+N9A4EFjfMJeLjOlhP1aGct2wqyhnrX5Pv9y57nuN615/eIOzYdFzvNe57ukH/cno/Q3GK8d1B/QynOK6iuHKOMkKDtTBxEbzg1aOw+P2OV9gVU/hqvGarY2fOA6Vl4iO68rGXjqWbT/PJWF+4IrRq7l5n3EuXPdBWzlcL8j7vkFh99JpzXXKcdnqQ/1ar0MCj/ZjRjhZ+quOyR52E2k/7z434ls/0Dn7+3YAgM+/8QSfv7huZSN7IF04Y9MrMlX7uSJc9+7VJZ54rts7Tp7X7TDsgLw7wXW7qmVO/ZPmsuY6aZ9arU4fxnW1nvyh+w++9ld+/dbKTzGE6z6VcPf4ez0fHYkR+fbvOfvu6EVc35dy1v1Yx9zplzkfuYzH5XvWse+qUyvYh7ftx/2QmHJrK8AewgCgMri4OkkfSD2e1SenIP2SSF8+02YDzI2I6qHfTF5EvVYvCR03+/pJGXIGXbT7G2234PN2w+Cz9lkebTG93vh0/3rGfNmOsfsCYf+FzotvTHj9cePn184bZ4+p4K3zpwCA33f5Dbw9vgcA+OLwAb6QrwAAf+grv/FcXBfS3EAgEAgEAoFAIBAI3CvuNSIKoD3Jp6RREVosUsi5y5bQowMAkBjokYGUGSl7rVTfJLG9IJAweOKVDMxUbLYf9+XVywUnyaHKGoGUt89ptmhdk6Cx7Seyp8EkVCIn9bI3ToSaXdhczr1Inb1k1eRseaoaedNt3duVlBOGvUR7k56zjvb2XCIcyIzUI8o5VQ23J2KLMCdpCGufJg9zy7dkZ+xkt65ulTUCXUZCGZ1MDb3vZTRWWy5b0vaWSAcV9yaaLAq6Kk+yT+51rpmRetnmS9I35tQjS3nKFp3B4mRcFTT3N2dDsiEkb8wHG88JAA4SLUpIXTaWdu21PG2ySddm0ldBbZys3wwSscoIBxf1T8TYnJCxTalVfswFNNg1pVHpBD2fHm0LUHER05p0GTc9MiRjNyW7fhKD+8GIGZCICfM6YtDBZOu0lsyv9qsw4Tq5xp0Usg4nuC4zyHFdHnokkEmnJKTEqLVfy8I9xNq+XrLZVpwolpcsSncW4zpBmo0vUvHyV4v6MBs/SeQpLWyKsUyqUvHlUj70ctPE2lb5UFo0sFW0n5dVN5rmhHyQtjKe5UwqxdcA2h1cV5lMZpSk/Uza67mOncS/KtcRuEc+q5cguzYv26SRYilPHcl4GMZZy5lFYD3X+Zkk7M5dj7kuW39x5sYvAJZLwkG4rvTI+cKm0JkrWO7Hc7VIYKYm9QZAh2Lrej1TZeVATqSReL1Hze7iJrsHUWUdTD444LmuSj+zcV0i1vVL55vtuKy4bnVvkvaGtC9gg7XdI6Xssnoorm7C/6hgkWB7DvwQriNgrf4IBAKBwC28yj8DA4FAIBAIBAKBQCDwCcRLiIi26BEvfW7N9QHDrmmV04S1QYBAIgCFUKjPXwH07WYt0B1rfwvLUwYdejTqQDonDvuEOtrbeKDNm9P5JBOQ+/ynvKs6l1Pe2g+7Qd9wpwVmbAH3opRI54ZKlACElbGFRQtZ50LJq3Q/f4gK6/yjvF9uRS38G1cfwa2bZMYo52llaNE+raFLTWZQhBblAKBRGFc0gLGKqKgJksxn89Hcuo7wSRnKGWE5X89vWs7d22y4Y1xAowRpkWiJGa7UgVB1ojVuRx82dt5VPdyyj4ArXISDClvUtep/FiVli6LQUgFnCMISHSjShxb1aeXzEYoe7egRSmbCvpy+RJfeWJVJ504V34AuQqZtSbAr3oqupi4gcmM02XKPopArOtWk9VzN+6wA6SRYmbdydFGfdBd7BdG5DksjgXwzIR/aHNEVd/jrqUfQa2JgMeeZk7Q4dS6ck3Jd3p
NxDk5w3UygznX5wBh6VHG4qRhubO4kAAz7QedyU1lznZrUZOgcUeFQTmagxSCndGCNtin7ZbIge2Ggz9lJ+4Lkrq9
"text/plain": [
"<Figure size 1152x288 with 3 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Inference needs no gradients; no_grad() avoids building the autograd graph\n",
"# (the WaveGlow synthesis cell does the same).\n",
"with torch.no_grad():\n",
"    mel_outputs, mel_outputs_postnet, _, alignments = model.inference(sequence)\n",
"# Plot the first batch item of each output; the alignment is transposed for display.\n",
"plot_data((mel_outputs.data.cpu().numpy()[0],\n",
"           mel_outputs_postnet.data.cpu().numpy()[0],\n",
"           alignments.data.cpu().numpy()[0].T))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Synthesize audio from spectrogram using WaveGlow"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" <audio controls=\"controls\" >\n",
2018-11-27 20:04:36 +00:00
" <source src=\"data:audio/wav;base64,UklGRiRQAQBXQVZFZm10IBAAAAABAAEAIlYAAESsAAACABAAZGF0YQBQAQAiACQAIwAhAB8AIQAfAB4AGwARABEACQABAAAA+P/z/+v/3//j/+H/4v/e/+H/4f/f/9L/1f/b/9H/zv/U/9r/1f/E/8n/1P/U/9D/2v/c/93/2P/d/+r/5f/q//P/8v/v/+r/7v/5//L/8v////b/+P/z//r/AAAAAAUAAwAAAP//+P8CAAMABQADAAkACwAEAAIABQAPAAwAFgAZAB8AHwAcAC0ALwA3ADcAQgBAAD0AOQA/AEQAQAA9AEEAPAApACAAGwAQAAgAAgAAAPP/5P/U/9T/z//C/7r/xv/H/8b/wP/P/9b/2v/g//H/AAD5/wAADgAWABYAFgAgABwAFAALABEAEAACAP7/BwADAP//BAAEAAMACQATAB4AGgAcABQAIQAdABIAHAAYAB0AGwANABUAFgAKABIAEwAVABMACwAWABQAGQAeACEAGwAfABcAGwAcABUAHgAfABQAEgALAAoAAQD8//7/9f/w//D/4v/c/+b/4//d/+T/4f/i/+X/7f/3/wIACwATABYAGQAUABMAGwAeACAAHAAYABcAEQAWABIADgASABUAEgAHAP////8AAAAAAAAMAAsAAAD9//v/BQAEAAAABwAFAPf/8//9//L/6v/k/+7/5//l/9H/0v/b/9D/1//W/9X/0//K/9D/3//l/+v/8v/+/wEAAAAXACIAJwAzADwAQQBHAEwAVABcAGYAaQBsAGcAZQBdAF8AZwBZAF4AYQBUAFoAUABMAFMATQBPAEwASgBHADAANwA5ADQAMgA7AEEAOQA3ADwATABFAEgASwBGAEQAPAA7ADUAJQAOAP//6f/U/7b/n/+J/3H/Zf9J/y//Iv8J/wH/+v7f/sP+sf6l/pT+bf5k/lj+Pf4y/iH+D/4D/gP+/f0G/hH+Ef4c/iP+M/49/lv+cP6I/qz+xf7i/v7+G/9C/3D/j/+x/+f/EAA/AFoAigDFAN8AGgFaAXsBoQHTAQwCPAJmApsCzgL7AikDSQOBA7oD4AMKBD0EYwSKBKAEwwTfBOsE/gQUBRUFCgUPBQsF+gTVBKkEdwQzBOkDkgM9A9sCZgLmAWAByQAwAJT/+f5j/rn9Hv2C/Of7WfvI+kb6wflE+dT4bfgL+Kb3TPcQ9832lfZx9lL2VPZb9mf2hfa59vP2RPed9wD4h/j/+Jn5M/rY+pL7TvwX/eL9s/6N/3IASAEdAukCqQNeBPgEmQUlBpUG8QZCB3YHngerB7UHwge5B74HugexB6sHrAfJB/EHFAg9CFYIZwh1CHQIcwhVCC8I4AeOByIHpQYaBm8FywQPBFkDnQLcASgBdwDP/zX/iv7v/WH9zvxT/MH7RPu++jP6yPla+fD4hPgO+OL3qvdt92H3RPdU93D3j/fD9wP4Uvim+CH5pPkj+qr6Lvu6+z38rPwp/Zj96v0s/nn+vP7d/vD++P70/uL+2P6//pv+cP5U/jD+AP7L/ZX9k/2b/Zj9pf2x/eP9KP5y/t3+af8LAM4AsAGWAogDjgTCBfcGGQhBCWQKdgtuDEMNCg6aDhAPTA9gD0IP5Q5eDr0N4gzFC38KEgmPB+QFKgRwAqsA7/5S/dP7ePpY+Vr4rfc69/72BvdD97j3V/gK+eb5uPqb+2z8MP3V/V7+1f4E/zT/G//k/qr+Pf6//TL9v/w//M/7dvsy+/n6zPq/+tH68voT+1P7lPvh+x38U/yN/Kj8u/y6/Kz8evw9/Pb7sPtj+/36rPps+jn6IPou+mP6v/o/++j71/zY/QP/XADCAUkDyARVBswHLgmACqMLxAyeDTYOrg7+DhYP9w6aDg8OeQ2pDK0LqAqLCVYIBAe8BWoECAO9AXkANv8R/uv88Pv8+iP6ZPnE+Fb4+vfC97/30PcF+GL43/hz+RD6vP
pj+x78zfxg/ej9Q/6S/sz+5f7k/sT+kv5Z/gf+r/1g/Rf94/yx/I38dfxW/Ej8SPxL/EP8Nfwo/CH8//vY+7D7iPtc+yv79fra+t/65foD+zv7jPvj+0780vxv/RH+3f7E/60AywHxAkMElwUDB3gI+QmBC/cMVg6qD70QpRFdEt0SKRMKE60S2hHIEH4P5Q1ADGkKYAhmBm8EXQJ+AKb+A/2P+076M/lQ+Hv3r/YY9p/1MfXP9Hz0SPQc9Pbz7PP/8zz0kvQf9dP1n/aE95X4u/n0+jv8Zv1m/jT/3v9lALMAwgCvAFMA3f9H/4/+vP3x/EH8q/tH++T6mfpo+k/6bvq0+u/6Lvtn+4/7tvvV+9X7u/uG+0X7+vqK+if65vnY+Q36uvrU+3D9jP/tAdoENAjSC64PZxPnFjcaGh1DH/og6SEFIoAhECDMHRMbphfVE8QPjwuGB5cD2v9a/Dv5oPab9CHzDPJl8fzw2fD48Ebxc/G/8QjyI/Jk8nLyXfKJ8pvywfI6867zXvQ59Tz2hffl+Gz6APyG/Rz/nwD6AToDLgTNBCgFPAUDBYgEsgObAkgBwf8H/iv8Zvq0+Cn37/Xe9PbzVPMN8znzvvN19Ez1H/bn9sr3b/j/+Gn5o/n1+UX6YPpq+tT6y/u+/T0AbAPVBpkK9g7UE+UYqB0MIh0mgSkOLFUtHC0dLEQqsSdVJPMf1BpGFWoPygmHBHf/8vqs9tfyre8q7Tfr7en36G/oL+gM6AjoLuie6EPpMeo66zbsQ+2t7kjwMfJQ9IP26vhZ+6z95f8KAt4DlQUkB00IKQmZCZ0JeAkhCaYIzQe/BmAFrAPYAdb/lv1J+5/4zfUM8znwl+1E6xXpOefc5TLlL+XR5QTnj+jG6jLtDPDn8n718/c++jT8Pf69/14B8AMbB1MLXRAGFdkZnh9wJTgsaDLiNoI6/zwEPhc+oTumN7Myhiy5JTYesRV/DacFlf6T+FPz/e4D65Dn3eSa4nLhZOAR3/Ddo9yv2xfbQdr32Zba09tW3ljh7uR26aLug/Qm+7AB0gdqDc4RbxUbGMcZXRrCGYQYbBbaEyIRIw5WC7gIVQY7BOoBbv/0/NT5sfZS8yHvFevI5hfi89032jrXt9Uz1fvV89fR2pfeOeOR6OntjPLb9jP6Wf2d/38ABgEfAvwE7QklENMVDxxGI5QsrjapP2lGzUvhT2FS5lFwTSRGFT7gNeQshyI2F/sLPAIc+jLzyu0n6ZjliuKn3wDdutr/1/LU3NHNzrLMtsq5yPzHEMnJzODSxdme4TnqhvOe/egGgQ8/F6Qd3CKfJUEmWSU2I70g5h29GrMXoxSFESkP9gw4C+kJlgdgBJQApfvS9RPwlukm45TdhtcR0kzOqstayzHNp8+c01jYfNw34ajkX+cv6vHsde+F8UDy/fEj8/n5SQT8ENEdcCZsMQU/kkvUVapcUV5AYfZisF0YVeNI9TwvNWMt0CN1GrwP/QZV/2L4cPKg7MnmJeCS2HfRS8qOwqq8LLjctjy3x7d7uKW7V8JrzGfXGOHv6q/0/f4zCOoOnRSPGtQfayUZKDkprinwKTcqHSpvKesmzSPFHlAZiBP2DX4HzACa+bzz7+0y6BziD9yH2HrVdtPO0InOp8x1zXrNhc+00NbSqdS91l7ZBt3y4fvlcekI6w7ulO+B+QcD7BKUIV8roTRFPytJ9lOvWm1bgGBpY2Jil12xU8BK5UnxRB9AODjhKxMiRBUeBaX5Te2m4qjZEc15w8O947UJscWuMq8ctYC4ArkYusK8CcQzzVHVRt5Y6NXycP0SB9cPrRkDI5wrIzLKNUg3xDWiM5gxVC+mLb0pTSKFHA4V0A6KCUUBfvod9QzvfOn34qPantYX0x3SRNCFzUrKY8gCyFDKrcvYzpXQAdPo15PZo9373kDikOcb6o/p5fJH9vsMByDrKXA5HEGKS6ZXZF4TXK5fkWLeYZJfmVqATnxNtUzXRpZCLT
pjKeoaJAqK97frMN46zybEer1JtwK1oa/jq3+tc7MKuG+5crgKuWK+xsXNzsDW4eRo8V8
" Your browser does not support the audio element.\n",
" </audio>\n",
" "
],
"text/plain": [
"<IPython.lib.display.Audio object>"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Run WaveGlow on the half-precision post-net mel output, without gradients.\n",
"with torch.no_grad():\n",
"    audio = waveglow.infer(mel_outputs_postnet.half(), sigma=0.666)\n",
"# Play back the first batch item at the configured sampling rate.\n",
"waveform = audio[0].data.cpu().numpy()\n",
"ipd.Audio(waveform, rate=hparams.sampling_rate)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}