447 lines
88 KiB
Plaintext
447 lines
88 KiB
Plaintext
|
|
{
|
||
|
|
"cells": [
|
||
|
|
{
|
||
|
|
"cell_type": "code",
|
||
|
|
"execution_count": 7,
|
||
|
|
"metadata": {},
|
||
|
|
"outputs": [
|
||
|
|
{
|
||
|
|
"name": "stdout",
|
||
|
|
"output_type": "stream",
|
||
|
|
"text": [
|
||
|
|
"1\n"
|
||
|
|
]
|
||
|
|
},
|
||
|
|
{
|
||
|
|
"data": {
|
||
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAC7CAYAAAB1qmWGAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4wLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvpW3flQAAEYJJREFUeJzt3XuQ1fV5x/HPw7KwKAgoihvEQgze\nEzHuoFXTghqrTuKljlYbHYxk0KBTNGkStBpjJipN1ZjUxIqVQhqrpPFujC2DF3Q04EKRqwJVrEsQ\nVDRQVFh2n/6xh87q97twds/vXH5f3q8ZZ895zvec3/M7++zDz9/ta+4uAED+9ap2AgCAbNDQASAR\nNHQASAQNHQASQUMHgETQ0AEgETR0AEgEDR0AElFSQzez08zsNTNbbWZTskoKqDZqG3lkPb1S1Mzq\nJK2U9GVJLZJelnShuy/PLj2g8qht5FXvEt47RtJqd39dkszsAUlnSeqy6PtYX2/QniUsEujax9qi\nbb7VMvgoahs1pdjaLqWhD5P0VqfnLZKO3dkbGrSnjrWTS1gk0LV5Pierj6K2UVOKre1SGnpRzGyi\npImS1KA9yr04oGKobdSaUg6KrpU0vNPzAwqxT3D3ae7e5O5N9epbwuKAiqG2kUulNPSXJY0ys5Fm\n1kfSBZIeyyYtoKqobeRSj3e5uPt2M7tS0n9IqpM03d2XZZYZUCXUNvKqpH3o7v6kpCczygWoGdQ2\n8ogrRQEgETR0AEgEDR0AEkFDB4BE0NABIBE0dABIBA0dABJBQweARNDQASARNHQASAQNHQASQUMH\ngETQ0AEgETR0AEhE2aegQ6hu0MAg9tqdnw1ir4775+j7r9twTBBb8rWDo2Pblq/sZnbArn381THR\neL/fLQxi3nR4EHvjzPiE2l86aUkQe/7pzxedV+NLbdF4w+Pzi/6MPGMLHQASQUMHgETQ0AEgETR0\nAEhESQdFzWyNpM2S2iRtd/emLJJKXfvIA4LYkrF3B7FWj7//R/stCGJHnXN8dOxwDor2yO5Y23VD\n9onG22b1C2IPjLo9OnZ9W30QG9jr2SB2YO89ik9s/Nyih2646MNo/A8/6xPELrt5cnTsPve8VPTy\nak0WZ7mMc/d3M/gcoNZQ28gVdrkAQCJKbegu6T/NbIGZTcwiIaBGUNvInVJ3uZzo7mvNbD9Js83s\nVXf/xA6vwh/DRElqUDf2mwHVRW0jd0raQnf3tYWfGyQ9LCm4fMzdp7l7k7s31atvKYsDKobaRh71\neAvdzPaU1MvdNxcenyrph5llloDew8OzWSRp5LTVFc4E3bG71vbKnx4Yjb926L2RaPz/SParC2O/\n+CC8LcXCzfFltWwZ1GV+n1Zn7UHst4c8XnRes677h+jYy1dcGcR6vbCo6LyqqZRdLkMlPWxmOz7n\n39z9qUyyAqqL2kYu9bihu/vrko7KMBegJlDbyCtOWwSARNDQASAR3A89I//z/fDS+2NOWx4d++PG\n5zNffv/j34nG37o+zGvI4u3Rsf0e3T3uGQ3J/zTcozTr+PD2Ex3CNvHUR/GDolO/Mz6IDVgWudj2\nnY3R9/d6/60ucgh5r/BI58G3TYqOXX7+Pwaxg+r7R8d+dN2mIDbwkqHRsdvfXr+zFCuOLXQASAQN\nHQASQUMHgETQ0AEgETR0AEgEZ7lkZPFl4VH0Vo/PQF4Ozx51X/yFyOUxD29pjA6dvvnsINb76XAy\nDeRf68BwwofRfeLtoF3hTCvf+ZdLo2OHP/xiECvbX0F7+Mmfu/r30aGH9Qkv51981k+jY5/7/G+C\n2AmnxM+eGfgrznIBAJQBDR0AEkFDB4BE0NABIBEcFO2m+mfjBxTrLXLD5TL5r23hfaDXtO4bHXvO\nnuEl1uf33xAde/6/TgtiXxl2TDezQx60NVjRY7/w4iVB7MCbwoOftWzUFfOC2BOnxP+Wz+v/XhD7\n4Mwt0bEDf1VaXlljCx0AEkFDB4BE0NABIBE0dABIxC4buplNN7MNZra0U2xvM5ttZqsKPweXN00g\ne9Q2UlPMWS4zJN0p6ZedYlMkzXH3qWY2pfD8e9mnV10fnT0miH298d+jY2OX+Zd66f+Rcy6Pxved\n0zeI9f1jfFnXjA3/zV5y3s+KzqHlmnCCDEk64JZ8neXQhRnaTWv7kGuWFT22bsGAMmZSPX/3cnir\nC0k6b9y9QeyKI+ZGxz6h2vr3fpdb6O4+V9Knz307S9LMwuOZkuLfDFDDqG2kpqf70Ie6+7rC47cl\nxednAvKH2kZulXxQ1N1dityOrcDMJppZs5k1t2prqYsDKobaRt70tKGvN7NGSSr8jF96KMndp7l7\nk7s31Svc9wvUGGobudXTS/8fkzRe0tTCz0czy6gK6o44JBr/0e3hpfBNfbZ19SlFLy92P/Lrnjk3\niB323Vej72/bFM5K3pVDVh0cxOaf2RAdO6bvx0Hsd9/8cXTsqQ3fDWIjbo7fO9235mrrNana7vWF\nQ6PxsYNmB7GVreHvX5KGLG7NNKdaMfi5+N+BxlU2jywVc9ri/ZJeknSImbWY2QR1FPuXzWyVpFMK\nz4FcobaRml1uobv7hV28dHLGuQAVRW0jNVwpCgCJoKEDQCJo6ACQCCa4kNTexWznXZ/RUpxL3zwt\nGt/8V/2C2MEt84NYFrOlty1fGcQmzYjfUqD5sjuCWGNdmKskLZwQjj33ofHRsf7Kip2liDJaNX5Q\nNH5B/3eC2ImLL46O3evJlzPNCeXDFjoAJIKGDgCJoKEDQCJo6ACQCA6KZuTa9U1BbNM39omObWtZ\nVe50dmrEg+9G49effVwQm7o/B8Ty7OrTfxuNxy7z7/PzeL1K/51hRignttABIBE0dABIBA0dABJB\nQweARHBQdCfqrfh7nC/+Ymxim+oe/OySWTTcu1d7EOvOd/CHG+Px/ZmVs+bc/d6fBbGGJ8KrlZEv\nbKEDQCJo6ACQCBo6ACSChg4AiShmTtHpZrbBzJZ2iv3AzNaa2aLCf2eUN00ge9Q2UlPMWS4zJN0p\n6Zefiv/E3W/NPKMqeO2be0TjrZ7FHclrz5q/jF/i/Zt9w7McWj1+lkvsu/nMDfHlhefO1IwZSqi2\n6wYNDGIDerVUIRNUyy630N19rqSNFcgFqChqG6kpZR/6lWa2uPC/rYMzywioPmobudTThn6XpIMk\njZa0TtJtXQ00s4lm1mxmza3a2sPFARVDbSO3etTQ3X29u7e5e7ukeySN2cnYae7e5O5N9erb0zyB\niqC2kWc9uvTfzBrdfV3h6TmSlu5sfK277kuPVzuFkvUefkA0vvmYzwSxf/r6L0pe3vytDUHMtm0v\n+XOrLc+13TLhiCD2tQHPRMcu3DKizNnUvq1n/LHosR+29yljJtnZZUM3s/sljZU0xMxaJN0gaayZ\njZbkktZIuqyMOQJlQW0jNbts6O5+YSR8bxlyASqK2kZquFIUABJBQweARNDQASARTHCRiOU37h+N\nLzv1zpI+98H/HRKN3/W35wWxhhVMkIDatP2kY4LYA0d39bcRnoL68N+fHB05UL8vJa3MsYUOAImg\noQNAImjoAJAIGjoAJIKDojlU/2xjELul8cGyLGvG2uOj8YbHOQCK2hM7+ClJGydvCWKH1sfvvzNp\n7QlBbNCshdGx3o3cKoEtdABIBA0dABJBQweARNDQASARNHQASARnuUiqs/i89PUWn/E+ZtNfH1f0\n2Bt/GN6hdVy/j4t+fyyvVm/rYnTx6xDjJ60t6f2onL3WhDWwZvuHVcikMqx32L4+uHpzdGzzFx8I\nYrM/6hcdu/L6cKKQPq3N3cyuOthCB4BE0NABIBE0dABIxC4bupkNN7NnzGy5mS0zs8mF+N5mNtvM\nVhV+Di5/ukB2qG2kxtx3fvGqmTVKanT3hWY2QNICSWdLukTSRnefamZTJA129+/t7LP2sr39WIvf\nV7ia3rwxfnn7wgl3FP0
|
||
|
|
"text/plain": [
|
||
|
|
"<matplotlib.figure.Figure at 0x7f3fcdaefb38>"
|
||
|
|
]
|
||
|
|
},
|
||
|
|
"metadata": {},
|
||
|
|
"output_type": "display_data"
|
||
|
|
}
|
||
|
|
],
|
||
|
|
"source": [
|
||
|
|
"'''Train a Siamese MLP on pairs of digits from the MNIST dataset.\n",
|
||
|
|
"\n",
|
||
|
|
"It follows Hadsell-et-al.'06 [1] by computing the Euclidean distance on the\n",
|
||
|
|
"output of the shared network and by optimizing the contrastive loss (see paper\n",
|
||
|
|
"for mode details).\n",
|
||
|
|
"\n",
|
||
|
|
"[1] \"Dimensionality Reduction by Learning an Invariant Mapping\"\n",
|
||
|
|
" http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf\n",
|
||
|
|
"\n",
|
||
|
|
"Gets to 97.2% test accuracy after 20 epochs.\n",
|
||
|
|
"2 seconds per epoch on a Titan X Maxwell GPU\n",
|
||
|
|
"'''\n",
|
||
|
|
"from __future__ import absolute_import\n",
|
||
|
|
"from __future__ import print_function\n",
|
||
|
|
"import numpy as np\n",
|
||
|
|
"\n",
|
||
|
|
"import random\n",
|
||
|
|
"from keras.datasets import mnist\n",
|
||
|
|
"from keras.models import Model\n",
|
||
|
|
"from keras.layers import Dense, Dropout, Input, Lambda\n",
|
||
|
|
"from keras.optimizers import RMSprop\n",
|
||
|
|
"from keras import backend as K\n",
|
||
|
|
"\n",
|
||
|
|
"%matplotlib inline\n",
|
||
|
|
"import matplotlib.pyplot as plt\n",
|
||
|
|
"\n",
|
||
|
|
"num_classes = 10\n",
|
||
|
|
"\n",
|
||
|
|
"\n",
|
||
|
|
"def euclidean_distance(vects):\n",
|
||
|
|
" x, y = vects\n",
|
||
|
|
" return K.sqrt(K.maximum(K.sum(K.square(x - y), axis=1, keepdims=True), K.epsilon()))\n",
|
||
|
|
"\n",
|
||
|
|
"\n",
|
||
|
|
"def eucl_dist_output_shape(shapes):\n",
|
||
|
|
" shape1, shape2 = shapes\n",
|
||
|
|
" return (shape1[0], 1)\n",
|
||
|
|
"\n",
|
||
|
|
"\n",
|
||
|
|
"def contrastive_loss(y_true, y_pred):\n",
|
||
|
|
" '''Contrastive loss from Hadsell-et-al.'06\n",
|
||
|
|
" http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf\n",
|
||
|
|
" '''\n",
|
||
|
|
" margin = 1\n",
|
||
|
|
" return K.mean(y_true * K.square(y_pred) +\n",
|
||
|
|
" (1 - y_true) * K.square(K.maximum(margin - y_pred, 0)))\n",
|
||
|
|
"\n",
|
||
|
|
"\n",
|
||
|
|
"def create_pairs(x, digit_indices):\n",
|
||
|
|
" '''Positive and negative pair creation.\n",
|
||
|
|
" Alternates between positive and negative pairs.\n",
|
||
|
|
" '''\n",
|
||
|
|
" pairs = []\n",
|
||
|
|
" labels = []\n",
|
||
|
|
" n = min([len(digit_indices[d]) for d in range(num_classes)]) - 1\n",
|
||
|
|
" for d in range(num_classes):\n",
|
||
|
|
" for i in range(n):\n",
|
||
|
|
" z1, z2 = digit_indices[d][i], digit_indices[d][i + 1]\n",
|
||
|
|
" pairs += [[x[z1], x[z2]]]\n",
|
||
|
|
" inc = random.randrange(1, num_classes)\n",
|
||
|
|
" dn = (d + inc) % num_classes\n",
|
||
|
|
" z1, z2 = digit_indices[d][i], digit_indices[dn][i]\n",
|
||
|
|
" pairs += [[x[z1], x[z2]]]\n",
|
||
|
|
" labels += [1, 0]\n",
|
||
|
|
" return np.array(pairs), np.array(labels)\n",
|
||
|
|
"\n",
|
||
|
|
"\n",
|
||
|
|
"def create_base_network(input_dim):\n",
|
||
|
|
" '''Base network to be shared (eq. to feature extraction).\n",
|
||
|
|
" '''\n",
|
||
|
|
" input = Input(shape=(input_dim,))\n",
|
||
|
|
" x = Dense(128, activation='relu')(input)\n",
|
||
|
|
" x = Dropout(0.1)(x)\n",
|
||
|
|
" x = Dense(128, activation='relu')(x)\n",
|
||
|
|
" x = Dropout(0.1)(x)\n",
|
||
|
|
" x = Dense(128, activation='relu')(x)\n",
|
||
|
|
" return Model(input, x)\n",
|
||
|
|
"\n",
|
||
|
|
"\n",
|
||
|
|
"def compute_accuracy(y_true, y_pred):\n",
|
||
|
|
" '''Compute classification accuracy with a fixed threshold on distances.\n",
|
||
|
|
" '''\n",
|
||
|
|
" pred = y_pred.ravel() < 0.5\n",
|
||
|
|
" return np.mean(pred == y_true)\n",
|
||
|
|
"\n",
|
||
|
|
"\n",
|
||
|
|
"def accuracy(y_true, y_pred):\n",
|
||
|
|
" '''Compute classification accuracy with a fixed threshold on distances.\n",
|
||
|
|
" '''\n",
|
||
|
|
" return K.mean(K.equal(y_true, K.cast(y_pred < 0.5, y_true.dtype)))\n",
|
||
|
|
"\n",
|
||
|
|
"\n",
|
||
|
|
"# the data, shuffled and split between train and test sets\n",
|
||
|
|
"(x_train, y_train), (x_test, y_test) = mnist.load_data()\n",
|
||
|
|
"x_train = x_train.reshape(60000, 784)\n",
|
||
|
|
"x_test = x_test.reshape(10000, 784)\n",
|
||
|
|
"x_train = x_train.astype('float32')\n",
|
||
|
|
"x_test = x_test.astype('float32')\n",
|
||
|
|
"x_train /= 255\n",
|
||
|
|
"x_test /= 255\n",
|
||
|
|
"input_dim = 784\n",
|
||
|
|
"epochs = 20\n",
|
||
|
|
"\n",
|
||
|
|
"# create training+test positive and negative pairs\n",
|
||
|
|
"digit_indices = [np.where(y_train == i)[0] for i in range(num_classes)]\n",
|
||
|
|
"tr_pairs, tr_y = create_pairs(x_train, digit_indices)\n",
|
||
|
|
"\n",
|
||
|
|
"digit_indices = [np.where(y_test == i)[0] for i in range(num_classes)]\n",
|
||
|
|
"te_pairs, te_y = create_pairs(x_test, digit_indices)\n",
|
||
|
|
"def show_nth(n):\n",
|
||
|
|
" plt.subplot(1,2,1)\n",
|
||
|
|
" plt.imshow(te_pairs[n][0].reshape(28,28))\n",
|
||
|
|
" print(te_y[n])\n",
|
||
|
|
" plt.subplot(1,2,2)\n",
|
||
|
|
" plt.imshow(te_pairs[n][1].reshape(28,28))\n",
|
||
|
|
"show_nth(0)\n",
|
||
|
|
"# # network definition\n",
|
||
|
|
"# base_network = create_base_network(input_dim)\n",
|
||
|
|
"\n",
|
||
|
|
"# input_a = Input(shape=(input_dim,))\n",
|
||
|
|
"# input_b = Input(shape=(input_dim,))\n",
|
||
|
|
"\n",
|
||
|
|
"# # because we re-use the same instance `base_network`,\n",
|
||
|
|
"# # the weights of the network\n",
|
||
|
|
"# # will be shared across the two branches\n",
|
||
|
|
"# processed_a = base_network(input_a)\n",
|
||
|
|
"# processed_b = base_network(input_b)\n",
|
||
|
|
"\n",
|
||
|
|
"# distance = Lambda(euclidean_distance,\n",
|
||
|
|
"# output_shape=eucl_dist_output_shape)([processed_a, processed_b])\n",
|
||
|
|
"\n",
|
||
|
|
"# model = Model([input_a, input_b], distance)\n",
|
||
|
|
"\n",
|
||
|
|
"# # train\n",
|
||
|
|
"# rms = RMSprop()\n",
|
||
|
|
"# model.compile(loss=contrastive_loss, optimizer=rms, metrics=[accuracy])\n",
|
||
|
|
"# model.fit([tr_pairs[:, 0], tr_pairs[:, 1]], tr_y,\n",
|
||
|
|
"# batch_size=128,\n",
|
||
|
|
"# epochs=epochs,\n",
|
||
|
|
"# validation_data=([te_pairs[:, 0], te_pairs[:, 1]], te_y))\n",
|
||
|
|
"\n",
|
||
|
|
"# # compute final accuracy on training and test sets\n",
|
||
|
|
"# y_pred = model.predict([tr_pairs[:, 0], tr_pairs[:, 1]])\n",
|
||
|
|
"# tr_acc = compute_accuracy(tr_y, y_pred)\n",
|
||
|
|
"# y_pred = model.predict([te_pairs[:, 0], te_pairs[:, 1]])\n",
|
||
|
|
"# te_acc = compute_accuracy(te_y, y_pred)\n",
|
||
|
|
"\n",
|
||
|
|
"# print('* Accuracy on training set: %0.2f%%' % (100 * tr_acc))\n",
|
||
|
|
"# print('* Accuracy on test set: %0.2f%%' % (100 * te_acc))\n"
|
||
|
|
]
|
||
|
|
},
|
||
|
|
{
|
||
|
|
"cell_type": "code",
|
||
|
|
"execution_count": 13,
|
||
|
|
"metadata": {},
|
||
|
|
"outputs": [
|
||
|
|
{
|
||
|
|
"name": "stdout",
|
||
|
|
"output_type": "stream",
|
||
|
|
"text": [
|
||
|
|
"0\n"
|
||
|
|
]
|
||
|
|
},
|
||
|
|
{
|
||
|
|
"data": {
|
||
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAC7CAYAAAB1qmWGAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4wLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvpW3flQAAEYpJREFUeJzt3X+QVfV5x/HPAyzLL2tB44oILk0U\nxEQXWaBGneCvDnEiaKsZmWrRJEOSqtFIrIxOJG3shBkTNf6IFQUhrcpgFcGOk0AorWZUZEXkt1H5\noTD8SkyVGkWWffrHXjor3++Fy95z7+758n7NOHvvc7/3nuewDw/H+z3ne8zdBQDIvy4dnQAAIBs0\ndABIBA0dABJBQweARNDQASARNHQASAQNHQASQUMHgESU1dDNbKyZvWlmb5vZlKySAjoatY08svZe\nKWpmXSX9TtJFkrZIWiZpgruvzS49oPqobeRVtzLeO0rS2+6+QZLMbI6k8ZKKFn13q/Ue6l3GJoHi\nPtFH+tT3WAYfRW2jUym1tstp6AMkvdfm+RZJow/2hh7qrdF2QRmbBIpb6ouz+ihqG51KqbVdTkMv\niZlNkjRJknqoV6U3B1QNtY3OppxJ0a2SBrZ5fmIh9hnuPt3dG929sUa1ZWwOqBpqG7lUTkNfJulk\nMxtsZt0lXSlpQTZpAR2K2kYutfsrF3dvNrPrJf1aUldJM919TWaZAR2E2kZelfUdurs/L+n5jHIB\nOg1qG3nElaIAkAgaOgAkgoYOAImgoQNAImjoAJAIGjoAJIKGDgCJoKEDQCJo6ACQCBo6ACSChg4A\niaChA0AiaOgAkAgaOgAkgoYOAImgoQNAImjoAJAIGjoAJIKGDgCJKOueoma2SdJuSfskNbt7YxZJ\npa7b8XVB7IOz64PY1os8+v6N46YHsb2+Lzr27BVXBrFd7/WNjh02bXsQa970bnRs6qht5FFZDb3g\nPHf/fQafA3Q21DZyha9cACAR5TZ0l7TQzF4zs0lZJAR0EtQ2cqfcr1zOcfetZnacpEVmtt7dX2g7\noPCXYZIk9VCvMjcHVA21jdwp6wjd3bcWfu6UNE/SqMiY6e7e6O6NNaotZ3NA1VDbyKN2H6GbWW9J\nXdx9d+HxX0n6p8wyyxmrDf9Cb/jHM6NjH7j80SD2lZ5/Knlbez38d7hFLdGxLzY8EQYb4p/bcMw3\ngtigK0pOKxnUNvKqnK9c6iTNM7P9n/OEu/8qk6yAjkVtI5fa3dDdfYOkMzLMBegUqG3kFactAkAi\naOgAkIgsrhSFpHdvGRHEVl3984ps69rNFwSxGSctKvtzV3x5ZhAbp5Flfy6OHPvGhCcCdLtjRxB7\nbsiC6PtrrGsQO5xlLY65vSY61jZtDWJ/uGRYdGy/Z1cHsZbdu6NjOxuO0AEgETR0AEgEDR0AEkFD\nB4BE0NABIBGc5XKY/Kz49SYzv3F/5ts6/bHvReODf7w8iA2957ro2PXjH8w0Jxx5Ysta7B4XXz9i\n6k/CM6Viy1rEF6qQ9kbu6XI4y1qc+cNromPPOD48dp1f/0B07Mg/vyGI1d3/UnRsZ8MROgAkgoYO\nAImgoQNAImjoAJAIJkUPIjYB6ne+Hx07InJ/g2ITP/P+97ggNvOacUGsfumr8bxawkuhh3z/jejY\nrz773SD243+ZHh3bWBt+7oWr45c8/+aLR0XjSM+eMV8KYv95b3xCMWbJx32C2B13hmvvS1LNnyKz\nokV8eFJ4PNq9yG0F/uEH4WTtBy3N0bF9tsWXGsgDjtABIBE0dABIBA0dABJBQweARByyoZvZTDPb\naWar28T6mdkiM3ur8LNvZdMEskdtIzWlnOUyS9IDkn7ZJjZF0mJ3n2ZmUwrPb80+vY61c2TvILZs\naDhbLsUX5v+g5dPo2Klzw4X5619++TCz+yzfsyee18KmIHbVr78THbvmkvDMhVv6vRMd+8iTE4PY\n4AnxM206sVk6Qms7ptiyFj956OGSP2PCOxcHsQ+nDgxifZeUV++SdPQXBgexhqfi9Xpq9/DYdej8\n70fHnvLvS8tLrAMd8gjd3V+QdOC5euMlzS48ni3p0ozzAiqO2kZq2vsdep27bys83i6pLqN8gI5G\nbSO3yp4UdXeXVPRqADObZGZNZta0V/GvBYDOiNpG3rS3oe8ws/6SVPi5s9hAd5/u7o3u3lijyOWU\nQOdCbSO32nvp/wJJEyVNK/ycn1lGnUiXC/8QxIqtzRxbx/naDeHl/JJU/8PyJ4TKccp340sK3H/O\naUHs5n7ro2P/dtiyIPaSupeXWOdwRNR2zB9v/zgajy1rcfH6v46O7fqDPwtjr4fr92fhf0aE34ZN\nPW5uye8fuDDLbDqHUk5bfFLSy5KGmNkWM/umWov9IjN7S9KFhedArlDbSM0hj9DdfUKRly7IOBeg\nqqhtpIYrRQEgETR0AEgEDR0AEsENLiR1G3BCND55yG/K+twNT50cjddpV1mfWykz518YxG6+Nn6W\nC/Jt45zTg9ia4Y9Fx25pDs9+6XJ7fIkbf31leYlFWG38lNAv3LQ2iHUpcox67eZwWqTns/GzvfKM\nI3QASAQNHQASQUMHgETQ0AEgEUyKSvrjOYOi8cv7lH7V96T3xgSxAUXWZo7fazxfvthzSxB79S/O\nj45t3rCpwtngcP3dsHBCsNiyFpubw8v59Ur2k59SfAL0zXvj67TPH/RgEIvvgbT5riFBrJfyu+55\nMRyhA0AiaOgAkAgaOgAkgoYOAIlgUlTSrjOt7M94Z9qpQazn9vSuRNvva73DteLvbjw+OrYPk6I4\nQNfTwklKSVp3w9FBbP0l4eRnMUs+7hONH/XSxiC2r+RPzQ+O0AEgETR0AEgEDR0AEkFDB4BElHJP\n0ZlmttPMVreJ/cjMtprZisJ/F1c2TSB71DZSU8pZLrMkPSDplwfE73H3n2aeUQfY1yt+wXCxtZVj\nUlxbWZJqrGs0vternEhlzFLitV3M0xsbgtgtx6yKjh1e+1EQO3flJ2Vtf1SvZ6Lx83qGn1vscv6Y\nyW9cHo2fuGPNYXxKfh2yY7n7C5Ler0IuQFVR20hNOd+hX29mKwv/2xq/fQmQT9Q2cqm9Df0hSZ+X\n1CBpm6SfFRtoZpPMrMnMmvZqTzs3B1QNtY3caldDd/cd7r7P3VskPSJp1EHGTnf3RndvrFH83oBA\nZ0FtI8/adem/mfV3922Fp5dJWn2w8Z3d6advisaLrQ99JNnr8QukU/2zSa22izn+qq1BbNyzl0XH\n/sfQ8L4AxSZQy3XurTcEsZYJ4TITkvRiwxNB7LhHemWeU54csqGb2ZOSxkg61sy2SJoqaYyZNUhy\nSZskfbuCOQIVQW0jNYds6O4+IRKeUYFcgKqitpEarhQFgETQ0AEgETR0AEgEN7hAu2xu/jSI9dwV\nxtA5tezeHQYviMQknX/Z3wexnSNKPxbsuy5cJ+Lox1+Jjt31r+H5/Osb5kTHzvigPoj1WrMtHCip\n+SD5pYQjdABIBA0dABJBQweARNDQASARTIri/33r0oUljx3/2C1BbNCSl7JMB51Er3lLg1j9vMps\na/35jwaxYstMPPjmV4LYCe+tzTynPOEIHQASQUMHgETQ0AEgETR0AEgEDR0AEsFZLpI+uuOEaLzp\nsfCO94218Rs+vPvUl4LYoCsqcxOAShnZc2MQe3WPRcfW3/VGEEvzlheohK6nDSnyymtBJLbMhCTV\n3dcjw4zSwBE6ACSChg4AiaChA0AiDtnQzWygmS0xs7VmtsbMbizE+5nZIjN7q/Czb+XTBbJDbSM1\npUyKNkua7O7LzewoSa+Z2SJJ10ha7O7TzGyKpCmSbq1cqpXT5b9fj8avu/f6ILbs1vujYxeNfiiI\nXXPe96Jjuy5ZfhjZZW/jnNOj8bN7hBNSX349dttNqd9Hv8s0pw6SfG13Vhumdi957BWvfysaP76D\n/x51Roc8Qnf3be6+vPB
|
||
|
|
"text/plain": [
|
||
|
|
"<matplotlib.figure.Figure at 0x7f3fcf264a20>"
|
||
|
|
]
|
||
|
|
},
|
||
|
|
"metadata": {},
|
||
|
|
"output_type": "display_data"
|
||
|
|
}
|
||
|
|
],
|
||
|
|
"source": [
|
||
|
|
"show_nth(5)"
|
||
|
|
]
|
||
|
|
},
|
||
|
|
{
|
||
|
|
"cell_type": "code",
|
||
|
|
"execution_count": 14,
|
||
|
|
"metadata": {},
|
||
|
|
"outputs": [
|
||
|
|
{
|
||
|
|
"name": "stdout",
|
||
|
|
"output_type": "stream",
|
||
|
|
"text": [
|
||
|
|
"0.0\n"
|
||
|
|
]
|
||
|
|
},
|
||
|
|
{
|
||
|
|
"data": {
|
||
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAAkCAYAAACZmsEQAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4wLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvpW3flQAACeFJREFUeJztnFuMZEUZx39f1Tl9mZ5ZmGF1g4ou\nKDESHwwQJMQQEw23F9QHA8a4URJe5E0fMDzIg4mK8cWgJGtCIj5A8MG4CTG6Eg1PXtaE5WKyslyM\nrMiwl5mdnek+l6rPh6ruHYZdZ2d2prunt35J55yuPtX1r6p/fXVOnz4lqkoikUgkdj5m1AISiUQi\nsTWkgJ5IJBITQgroiUQiMSGkgJ5IJBITQgroiUQiMSGkgJ5IJBITwkUFdBG5Q0SOiMhREXlwq0Ql\nEqMmeTuxE1k3oIvI4yIyLyIvrUqbE5GDwAHgHeAW4F4RuW77pCYSW0vydmLSkPUeLBKRW4EzwBOq\n+smY9ggwDXwU+CMwCywAqOr3t1NwIrFVJG8nJo1svQNU9TkR2bsm+W7gR0AD+AXwJ+B7wKfX5heR\n+4H7AUwrv2Hvx3IaUnPKdchwtE1JT3MWqzYAXmWwdc5grSKieC+oCiLhvRElNw5FaBiHoHiEnstx\nXjCiWFGcCl6FpnUo4Lyh9gYRxZUWyRT1AgLGeLwXRGLdnZA3aqraAmCM4p0BUUQI+SqBhgJxYqwM\nSHwrIDVoroiTcITV8LkXkJhHBUzcryUco3EL4AEXRWV6Nq8JZYCGz53EtFiGxLxWz+bv79tYh/58\nXhvIfEirJZShgGVQN6nCd2gek2JeUVAT35tVx5pQDYnH2gJ8FnSZEnwOGh0oDsQHvaYGWzg0M6EZ\nVZGiAmPxDYupHDgP1oAIaoPWXvcUpVuRtR48Hxfj7dW+llZ+w+V7L2NPfpoMT42h0IzFeooq+s2r\noNFb6kJ/m9hWqhL6IXrfZg5VIbOelq1waijqDCNKWWSY3JNZjxWPoBQuw1X2bL/GMSOiuNoOfGat\nYsSTGQ+AU4P3oczaWYzxmOgdE8cTgPeCL+27fVUL0vCDsQM68FD4cjk7DqLXbCy37mXBx9E8ZsXg\nW8HPErX2txrbRCuDbdY4F00Vi8GG8SBZqJfrj+3aIjboM0bPtjGAF0zmQ1otIQbErTUep4IRcLUB\nBZOFdvEqfcujLvy4IYM+ZND+YezE8TMYpzHdC+KD32PVBuNnMAYiWddTdQy2gOWFN4+r6vtYh3UD\n+nnYQzxrAf4b369LvqvFJ362j72dEyEgi2M2W+HZ+Y8zvzRNp1ly8vQUWeax1lMUGa1WRa+X46Ph\nQu1DkG63S3Lr+PBlCxzvdvoxlMVui1Ze46JZK2fJjGeuswJAJy850Z3idLeFqtBuVJS1JbNhIPXK\nnHazxMbOKqqMqWbJdKPkrcVdNPMKAGuCAfZMneE/Z3bRqzJ63QZZ7ui0Soo6lHtqfgYEZq5YJjOe\nss5YPtWmtavg8ukVTp7u0GjU5NbRK3PqyjJ32TKnlqYwJg7MaCDnDHnumGoVeB/SlleaIEoWB1Nd\nh4nQe6GuLFOdAu+FoshxvYxGp0Q1DPC6suSNmu5ii6xd452gziALeQjGHnS2Qlcs0p8wpqswkAsL\nDY89keN21aGT40Rkliy2EKpZBx6kMogHU4Q+EQ+urUgt+KbHFKEu+RnBFtCeV1b2hH0EXANaJxXb\ng7oNjSUlXwn9YyolW3EcOvToJqz8Hjbs7Xymzecf/xIdWzLXWOZ03aL2ln+c2oOqcGJhGps5vDNn\nT1hqgxil0awplhuIDX0nBj6we4HSWdp5hRGlqDO6VUblLL1uA0SxVsnz0Oa7WgV19ML8O7todUqq\nMsPVBi1N6I9+AKtMCDy5J2s6xHhmOj0WFjp0ZnoURU4WJ5S6shjryXPHmeMdpuZWEIGyyLCZp5HX\n1M6gKpRlmHCmOz0y6+k0SpbLRugfUZa6TTqtchAYiyocv9xt4J3FV+FEqdGusNZjTPC8qmBM0N5u\nlhRVjjGeurY4F4J3XVtarYqZdo9emSOiLHebcVILgdnYkMcYT7XYZGr3Ct3lJngha1WDydUYpejm\nYSJRBnq6y83QhvFExxcWyT2NdkWx2EJ6JpzA5D4E7VJozVvqaUXjHIQo+elwwmdCCME1IeuGgJ4t\nh/DmmmF85EuKawmmVA4/9q1/XYh5NxvQAY4BV6mqSphOPxTT3oWq7gf2A4jI0jO3PnrkvV/1zEXI\ngMMXlfuC2A0c3/5izvLq1nzN0HVvEZvV/ZEtKn9db6/19a9u2X8OX2+O17fqi9bnUvPHOLCt3t5s\nQH8b+DdwrYjcBMwD9wBfWSffEVW9cZNljgwROZR0D48R696Mt5Ovh8hO1Q3br32zf1s8AHwVeIBw\nev1+4GlVfXmrhCUSIyJ5O7FjWfcMXUSeBD4L7BaRN4HvAj8AngbuI/zi8WVVPbmNOhOJLSd5OzFp\nXMi/XO49z0ef20R5+zeRZxxIuofLUHRvobdTOw+Xnaobtln7uv9DTyQSicTOIK3lkkgkEhNCCuiJ\nRCIxIQwloI/7Qkci8oaIvCgiz4vIoZg2JyIHReSVuJ2N6SIiP4l1eUFErh+izvOuPbIRnSKyLx7/\niojsG5Huh0XkWGzz50XkrlWffSfqPiIit69KHzsfjaOm1SRvj0T36Lytqtv6Ijw8/ipwDeFx6sPA\nddtd7gY1vgHsXpP2CPBg3H8Q+GHcvwv4LeGh1JuBvwxR563A9cBLm9UJzAGvxe1s3J8dge6HgW+f\n49jrokeawNXRO3YcfTSOms6hMXl7+LpH5u1hnKHfBBxV1ddUtQSeIqyXMe7cTVjLg7j9wqr0JzTw\nZ+ByEblyGIJU9Tlg7V/oNqrzduCgqp5U1VPAQeCOEeg+H3cDT6lqoaqvA0cJHhpHH42jpgsheXt7\ndZ+Pbff2MAL6BwlP3vV5M6aNEwr8XkT+LmHRJYA9qvpW3F+9pse41WejOsdJ/wPxkvnx/uU0O0N3\nn3HUtJbk7dEwEm+nm6KBz6jq9cCdwDclLKs6QMP10tj/v3On6Iw8Rlii9lPAW8CPRytnYkneHj4j\n8/YwAvox4KpV78+5iNcoUdVjcTsP/JpwCfR2/3Izbufj4eNWn43qHAv9qvq2qjpV9cDPCW3O/9E3\nFrrXMI6a3kXy9qXl7WEE9L8RFjq6WkQahIWODgyh3AtCRDoiMtPfB24DXiJo7N8l3wf8Ju4fAL4W\n77TfDCyuuiwcBRvV+TvgNhGZjZeCt8W0obLmt9kvEtocgu57RKQpIlcD1wJ/ZTx9NI6aBiRvX4Le\n3s47wKvu7t4F/JNwJ/ehYZS5AW3XEO4qHwZe7usDrgCeBV4B/gDMxXQBfhrr8iJw4xC1Pkm4hKsI\nv7PdtxmdwDcIN2SOAl8fke5fRl0vRPNeuer4h6LuI8Cd4+yjcdSUvH3pejs9+p9IJBITQropmkgk\nEhNCCuiJRCIxIaSAnkgkEhNCCuiJRCIxIaSAnkgkEhNCCuiJRCIxIaSAnkgkEhPC/wCGQslVMqHG\n+wAAAABJRU5ErkJggg==\n",
|
||
|
|
"text/plain": [
|
||
|
|
"<matplotlib.figure.Figure at 0x7f3fcf1c84e0>"
|
||
|
|
]
|
||
|
|
},
|
||
|
|
"metadata": {},
|
||
|
|
"output_type": "display_data"
|
||
|
|
}
|
||
|
|
],
|
||
|
|
"source": [
|
||
|
|
"'''Train a Siamese MLP on pairs of digits from the MNIST dataset.\n",
|
||
|
|
"\n",
|
||
|
|
"It follows Hadsell-et-al.'06 [1] by computing the Euclidean distance on the\n",
|
||
|
|
"output of the shared network and by optimizing the contrastive loss (see paper\n",
|
||
|
|
"for mode details).\n",
|
||
|
|
"\n",
|
||
|
|
"[1] \"Dimensionality Reduction by Learning an Invariant Mapping\"\n",
|
||
|
|
" http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf\n",
|
||
|
|
"\n",
|
||
|
|
"Gets to 97.2% test accuracy after 20 epochs.\n",
|
||
|
|
"2 seconds per epoch on a Titan X Maxwell GPU\n",
|
||
|
|
"'''\n",
|
||
|
|
"from __future__ import absolute_import\n",
|
||
|
|
"from __future__ import print_function\n",
|
||
|
|
"import numpy as np\n",
|
||
|
|
"\n",
|
||
|
|
"# import random\n",
|
||
|
|
"# from keras.datasets import mnist\n",
|
||
|
|
"from speech_data import speech_model_data\n",
|
||
|
|
"from keras.models import Model\n",
|
||
|
|
"from keras.layers import Input, Dense, Dropout, SimpleRNN, LSTM, Lambda\n",
|
||
|
|
"# Dense, Dropout, Input, Lambda, LSTM, SimpleRNN\n",
|
||
|
|
"from keras.optimizers import RMSprop, SGD\n",
|
||
|
|
"from keras.callbacks import TensorBoard\n",
|
||
|
|
"from keras import backend as K\n",
|
||
|
|
"\n",
|
||
|
|
"\n",
|
||
|
|
"def euclidean_distance(vects):\n",
|
||
|
|
" x, y = vects\n",
|
||
|
|
" return K.sqrt(K.maximum(K.sum(K.square(x - y), axis=1, keepdims=True),\n",
|
||
|
|
" K.epsilon()))\n",
|
||
|
|
"\n",
|
||
|
|
"\n",
|
||
|
|
"def eucl_dist_output_shape(shapes):\n",
|
||
|
|
" shape1, shape2 = shapes\n",
|
||
|
|
" return (shape1[0], 1)\n",
|
||
|
|
"\n",
|
||
|
|
"\n",
|
||
|
|
"def contrastive_loss(y_true, y_pred):\n",
|
||
|
|
" '''Contrastive loss from Hadsell-et-al.'06\n",
|
||
|
|
" http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf\n",
|
||
|
|
" '''\n",
|
||
|
|
" margin = 1\n",
|
||
|
|
" # print(y_true, y_pred)\n",
|
||
|
|
" return K.mean(y_true * K.square(y_pred) +\n",
|
||
|
|
" (1 - y_true) * K.square(K.maximum(margin - y_pred, 0)))\n",
|
||
|
|
"\n",
|
||
|
|
"\n",
|
||
|
|
"def create_base_rnn_network(input_dim):\n",
|
||
|
|
" '''Base network to be shared (eq. to feature extraction).\n",
|
||
|
|
" '''\n",
|
||
|
|
" inp = Input(shape=input_dim)\n",
|
||
|
|
" # d1 = Dense(1024, activation='sigmoid')(inp)\n",
|
||
|
|
" # # d2 = Dense(2, activation='sigmoid')(d1)\n",
|
||
|
|
" ls1 = LSTM(1024, return_sequences=True)(inp)\n",
|
||
|
|
" ls2 = LSTM(512, return_sequences=True)(ls1)\n",
|
||
|
|
" ls3 = LSTM(32)(ls2) # , return_sequences=True\n",
|
||
|
|
" # sr2 = SimpleRNN(128, return_sequences=True)(sr1)\n",
|
||
|
|
" # sr3 = SimpleRNN(32)(sr2)\n",
|
||
|
|
" # x = Dense(128, activation='relu')(sr1)\n",
|
||
|
|
" return Model(inp, ls3)\n",
|
||
|
|
"\n",
|
||
|
|
"def create_base_network(input_dim):\n",
|
||
|
|
" '''Base network to be shared (eq. to feature extraction).\n",
|
||
|
|
" '''\n",
|
||
|
|
" input = Input(shape=input_dim)\n",
|
||
|
|
" x = Dense(128, activation='relu')(input)\n",
|
||
|
|
" x = Dropout(0.1)(x)\n",
|
||
|
|
" x = Dense(128, activation='relu')(x)\n",
|
||
|
|
" x = Dropout(0.1)(x)\n",
|
||
|
|
" x = Dense(128, activation='relu')(x)\n",
|
||
|
|
" return Model(input, x)\n",
|
||
|
|
"\n",
|
||
|
|
"def compute_accuracy(y_true, y_pred):\n",
|
||
|
|
" '''Compute classification accuracy with a fixed threshold on distances.\n",
|
||
|
|
" '''\n",
|
||
|
|
" pred = y_pred.ravel() < 0.5\n",
|
||
|
|
" return np.mean(pred == y_true)\n",
|
||
|
|
"\n",
|
||
|
|
"\n",
|
||
|
|
"def accuracy(y_true, y_pred):\n",
|
||
|
|
" '''Compute classification accuracy with a fixed threshold on distances.\n",
|
||
|
|
" '''\n",
|
||
|
|
" return K.mean(K.equal(y_true, K.cast(y_pred < 0.5, y_true.dtype)))\n",
|
||
|
|
"\n",
|
||
|
|
"\n",
|
||
|
|
"# the data, shuffled and split between train and test sets\n",
|
||
|
|
"tr_pairs, te_pairs, tr_y, te_y = speech_model_data()\n",
|
||
|
|
"\n",
|
||
|
|
"# y_train.shape,y_test.shape\n",
|
||
|
|
"# x_train.shape,x_test.shape\n",
|
||
|
|
"# x_train = x_train.reshape(60000, 784)\n",
|
||
|
|
"# x_test = x_test.reshape(10000, 784)\n",
|
||
|
|
"# x_train = x_train.astype('float32')\n",
|
||
|
|
"# x_test = x_test.astype('float32')\n",
|
||
|
|
"# x_train /= 255\n",
|
||
|
|
"# x_test /= 255\n",
|
||
|
|
"\n",
|
||
|
|
"# input_dim = (tr_pairs.shape[2], tr_pairs.shape[3])\n",
|
||
|
|
"# epochs = 20\n",
|
||
|
|
"\n",
|
||
|
|
"# # network definition\n",
|
||
|
|
"# base_network = create_base_rnn_network(input_dim)\n",
|
||
|
|
"# input_a = Input(shape=input_dim)\n",
|
||
|
|
"# input_b = Input(shape=input_dim)\n",
|
||
|
|
"\n",
|
||
|
|
"# # because we re-use the same instance `base_network`,\n",
|
||
|
|
"# # the weights of the network\n",
|
||
|
|
"# # will be shared across the two branches\n",
|
||
|
|
"# processed_a = base_network(input_a)\n",
|
||
|
|
"# processed_b = base_network(input_b)\n",
|
||
|
|
"\n",
|
||
|
|
"# distance = Lambda(euclidean_distance,\n",
|
||
|
|
"# output_shape=eucl_dist_output_shape)(\n",
|
||
|
|
"# [processed_a, processed_b]\n",
|
||
|
|
"# )\n",
|
||
|
|
"\n",
|
||
|
|
"# model = Model([input_a, input_b], distance)\n",
|
||
|
|
"\n",
|
||
|
|
"# tb_cb = TensorBoard(log_dir='./siamese_logs', histogram_freq=1, batch_size=32,\n",
|
||
|
|
"# write_graph=True, write_grads=True, write_images=True,\n",
|
||
|
|
"# embeddings_freq=0, embeddings_layer_names=None,\n",
|
||
|
|
"# embeddings_metadata=None)\n",
|
||
|
|
"# # train\n",
|
||
|
|
"# rms = RMSprop(lr=0.00001) # lr=0.001)\n",
|
||
|
|
"# sgd = SGD(lr=0.001)\n",
|
||
|
|
"# model.compile(loss=contrastive_loss, optimizer=rms, metrics=[accuracy])\n",
|
||
|
|
"# model.fit([tr_pairs[:, 0], tr_pairs[:, 1]], tr_y,\n",
|
||
|
|
"# batch_size=128,\n",
|
||
|
|
"# epochs=epochs,\n",
|
||
|
|
"# validation_data=([te_pairs[:, 0], te_pairs[:, 1]], te_y),\n",
|
||
|
|
"# callbacks=[tb_cb])\n",
|
||
|
|
"\n",
|
||
|
|
"# # compute final accuracy on training and test sets\n",
|
||
|
|
"# y_pred = model.predict([tr_pairs[:, 0], tr_pairs[:, 1]])\n",
|
||
|
|
"# tr_acc = compute_accuracy(tr_y, y_pred)\n",
|
||
|
|
"# y_pred = model.predict([te_pairs[:, 0], te_pairs[:, 1]])\n",
|
||
|
|
"# te_acc = compute_accuracy(te_y, y_pred)\n",
|
||
|
|
"\n",
|
||
|
|
"# print('* Accuracy on training set: %0.2f%%' % (100 * tr_acc))\n",
|
||
|
|
"# print('* Accuracy on test set: %0.2f%%' % (100 * te_acc))\n"
|
||
|
|
]
|
||
|
|
},
|
||
|
|
{
|
||
|
|
"cell_type": "code",
|
||
|
|
"execution_count": 28,
|
||
|
|
"metadata": {},
|
||
|
|
"outputs": [
|
||
|
|
{
|
||
|
|
"name": "stdout",
|
||
|
|
"output_type": "stream",
|
||
|
|
"text": [
|
||
|
|
"0.0\n"
|
||
|
|
]
|
||
|
|
},
|
||
|
|
{
|
||
|
|
"data": {
|
||
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA2UAAAG6CAYAAACIge6AAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4wLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvpW3flQAAIABJREFUeJzsvXuYleV5qH9/s2aYgQEHBxwQHB0U\nhCiWMRAxokFJPMQkmtRkG9McTOqOTZNt293spmlyxaSpjWnTdue02/izxpo0xsacjHVvjfEQDxFF\nheABFARFURAQ5DjMrPl+f8ws6xEYc38zHy/PfV29GsfxZn0wrGc97/scsjzPCYIgCIIgCIIgCIaG\nuqF+AUEQBEEQBEEQBPsykZQFQRAEQRAEQRAMIZGUBUEQBEEQBEEQDCGRlAVBEARBEARBEAwhkZQF\nQRAEQRAEQRAMIZGUBUEQBEEQBEEQDCGRlAVBEARBEARBEAwhkZQFQRAEQRAEQRAMIZGUBUEQBEEQ\nBEEQDCH1Q/0CdkXWODZnRIcnbPFUL7BJ9u2wfdtkIXDoCNe3wdUFQWnYeO+6PM8PMFSTsyw3/zY/\nDdfneX6aqAwGmSwbm8MhnrAj81w1Vru6ylHdqq/aXVF9ACMa3Ljbg/sau6sNqg+AvICfnSBp8lVP\nkK9fr/zg2PERhiZGljop46AO+NsFnm+Lp3qB22Xf5bnre3cBb5Sfl33Xyr4gKAtfzB63VNuA8y0Z\n8EUYK+qCoWBYB0zwYuSMFXdprhqL2o9VfdX3qDr4oJvkAUw7ZL7q28j+qu+xNx+p+gD/c0GQPn86\nS1PZ8RGGJkaWOymrAKNFXxG/vTfKvk/JSdS3/tn1ASO//17Vt+Uu5SIhCJImo+xv2MGgs7MbVq7R\ndItudRMowP+h9T7HAVDXtNMVAlX5oUfg3gGM/81jqg/gmQcP1Z1B4gzzVKnEx3I/g52U2QkUwBLZ\nZyeOIz8hC2GLeHkJwELZFwRBsC8wogGOGKfpmjr9WvId01pV38jjn1V927bI5fh7Ac/82E+gGk58\nXncGadPTUB3ql1A6yp2U7QSeFH1vE101lsm+78u+kbIPP3DvuMsN2kGQIhlQQCdIsDczEjje0x3R\n8pAn6+e+leILBHp63P6q3o3Nqg+gOs59jV3mlQLALL9ksyr/uQTpk4t9iKnEx3InZSOATrHH6q4C\n+qs2yj63MhCekX1Afb18ulFEr18QJEYq5RmByAig09NNYqUn6+e+LW5StuMZ+RCvyU9QKrgxchxr\nVd+jO2aoPoDennh3CgZIrzcAPpX4WO5nqOulbqRXS9073j8R02+i7D5rO2kEtjwp94CZJapBEAT7\nCj3AOk93OI94shrTZZ9d7j6r3B+DAG5/6gTV1zDWLzXs3jJcdwaJk8mD7RKg3O9GO+roXSImUld7\nqhcYL/vsnjK7vBIYP9VtEn7mX6JBOAh2RyrlGYGIXOJflUevA9iXby3vdss/Nm8cpfqK4KCJq1Tf\nk8snqz6AprHP6c4gbboq3o1yKvGx3ElZUy9107Zqut53F3BTVvZx7v57Lxs3yVdbBbzGIAiC5GkB\nxC06hSRl8sHlppWusGl8+Rdlns51qu/Ow45TfQAbo+QlGCBrxKQsFUqdlNU39LD/uPWa79nRBSRl\ndnngXT92feed5foooKcsCILdkkrNfCCS01fCKDFcHr0OwC3eyH4ANnrTJgG/2gU/uZ2Nu/dsPrNV\nH/gLroP0MYsXU4mPu32GLMsuA94JrM3zfPqLvv4/gE8CVeA/8zz/i/6vfxb4w/6vX5Dn+fX9Xz8N\n+Dp9g+4vzfP84t392j3d9axfM2bAD/WaFDD0wgyIAIyUk6gHXB1AxU7KCuh7C4LUSKU8IzWGMkay\nFRBXlNz9Dv/DOpPdJKql0w3kejwDhtGl+uzBIWdzleoDuJTzdGeQNnX0aq5U4uOeJJaXA98Crqh9\nIcuyk4AzgRl5nndlWdbW//UjgPcDRwITgBuzLDu8/z/7NnAyfRXw92RZdk2e57uev9tTR+8z4u1W\npz9lie/LPwZ2T5mdNAKblshHi1H1EATB3svlDFWMXAN8zXuQIy70R+LfMO1M1bdpgRt/xs/xFynb\ny6PH4d422suoASpFfNgIkqaAeeh7Pbt958jz/NdZlnW87MufAC7O87yr/3tq81rPBH7Y//UVWZYt\nA47p/3fL8jx/DCDLsh/2f++uI8Bq4It78hh7yAcLyKPtEfa3yz5zz1s/rdOfUn0brp6o+oIgRVIp\nz0iNIY2RYwCxuMK+kSmEg9zD1e1dBSyPbnR1q2hXfUuZqvoAdtoPHSRPLqZlqcTH1/sMhwMnZFl2\nEbAD+HSe5/cAE3npUPcn+78GsOplX3/VOoksyz4OfByAMQe7SU+H6KohjiMG/Ne4UvZRwJJIcc9O\nEKRKKuUZ+wiDEyP3O1itNGjHnfIH9D29SOtB7s6unTvkxcygJ2X3c7Tquw13xH4QDDWpxMfXm5TV\nA63AscCbgP/IskyZa57n+SXAJQDZjJk5J4rv6M80ea6iuEu+Knu3u7gT/OlXeslmEATB0DI4MXLi\nrNzclbmNAm6N5LUsHZUVqm99c/kDkH1TVsSkxEKGxARJk6mjPtLg9SZlTwI/yfM8B+7OsqyXvo/W\nT8FL3j0O6v8au/j6a7+4hp2Mnbj6db7EV1Kd6F9u2lORrm16n+orgoOmPqr6uqZG2UOQJs+KrlTK\nM/YRBiVG2myjgAXAcs5z3+PuMJIDDnla9QGMlqdX2dMcn7x1iuoDaOos/2qBoFzsrHp3W6nEx9f7\nDD8DTgJu7m9SHkZfId81wA+yLPtH+pqYpwB30/f7NSXLskn0BZr3Ax/Y/S+TqQ2zmzeJR4r9XHut\nnES929UVMXHyyTvkN/SYvhgEQVoMTozsRu0b3kIBi5Q7ZN9Kt0hpfZM44bmfCeO8w2SANnnQBwf5\nNxQ79oIl3EHJqMYahZezJyPxrwROBMZmWfYkcCFwGXBZlmUPADuBj/SfCD6YZdl/0Nec3AN8Ms/z\nar/nU8D19I37vSzP8wd392vbI/Hbx/n18qP/4BHVt+iDx6o+PuXqAH9a4vi4wg6C3ZFKzXxqDGWM\ntPeUdRTRhHyQqztk7hLVt2qNWxpYBPY0R9b5c+9aOr2dssG+wZYG780rlfi4J9MXz3mNf/XB1/j+\ni4CLXuXr18HA1tJX6quMHL15IP/JLnn8O9M01wtOfUiFXJc9q4AegS2yb2UMRg2C3ZFK0EmNoYyR\nNAFiWFuDvJgZ9PJFewdY7zJx7U4N+bfRHjdf17FV9UEx+96CxBE/+qUSH0tdgpllOcMad3rCIp72\nZ7JvtJxEFTHbRG7cjvUmQRAEr4MKauWCPVAC0CcUj8I7qAV4w5z7VF8R2DdlvUv8RHTCXH/HXZA2\n2xA/3ydCqZOynu3DePbBgzVfw7uf11w1uq/dzxXal3mz/NLAGYe5w00WLZVLNoMgUUr9hh0MPr2o\nlQuFLACWd2WOxS2T6ymgqXm9fD1oJ6LmxM4a+msMkqdCr+pLIT6W+hkqw7tpmeYNoNrwqQKWFE+W\nfaX+E+lj6abDXaF98xYECZJKeUYgI8aMu199NdrvhnxTdsNTp6q+6RMXqz7wl3Db0xcPmen25QGs\nwx+YEqRNj/jmlUp8LHUKUN3RwIZlYiL16QIGSjwj90PJr7G1w50CBfDeytWq75LJf6L6giAIgoFT\nxPLoh0e+UfW9YaJbJmcnPOAnZUfgPvMCZqo+gJ32xuwgeTL5piwFSp2UVRq7GdnhzXTftExeegy0\nHuuuktkgl2ZUKv6SyFksUH1XTx70dTxBMCiYm3tS2cMSiPSg3kQdzlJP1s8N8qCPLvnD/wT8g8vN\n8mqBlfJegSISKDsRDdLHvNJIJT6W+hnqsl5GNXp1yu868heaq4b9Znn7iServhH4ZQo98slie8U/\nnQ2CMmAnZSmUZwQidaj
|
||
|
|
"text/plain": [
|
||
|
|
"<matplotlib.figure.Figure at 0x7f3fcd509f60>"
|
||
|
|
]
|
||
|
|
},
|
||
|
|
"metadata": {},
|
||
|
|
"output_type": "display_data"
|
||
|
|
}
|
||
|
|
],
|
||
|
|
"source": [
|
||
|
|
"def plot_spec(ims):\n",
|
||
|
|
" timebins, freqbins = np.shape(ims)\n",
|
||
|
|
" # import pdb;pdb.set_trace()\n",
|
||
|
|
"# plt.figure(figsize=(15, 7.5))\n",
|
||
|
|
" plt.imshow(np.transpose(ims), origin=\"lower\", aspect=\"auto\", cmap=\"jet\", interpolation=\"none\")\n",
|
||
|
|
" plt.colorbar()\n",
|
||
|
|
" xlocs = np.float32(np.linspace(0, timebins-1, 5))\n",
|
||
|
|
" plt.xticks(xlocs, [\"%.02f\" % l for l in ((xlocs*15/timebins)+(0.5*2**10))/22100])\n",
|
||
|
|
" ylocs = np.int16(np.round(np.linspace(0, freqbins-1, 10)))\n",
|
||
|
|
"# plt.yticks(ylocs, [\"%.02f\" % freq[i] for i in ylocs])\n",
|
||
|
|
" \n",
|
||
|
|
"def show_nth(n):\n",
|
||
|
|
" plt.figure(figsize=(15,7.5))\n",
|
||
|
|
" plt.subplot(1,2,1)\n",
|
||
|
|
" plot_spec(te_pairs[n][0].reshape(15,1654))\n",
|
||
|
|
" print(te_y[n])\n",
|
||
|
|
" plt.subplot(1,2,2)\n",
|
||
|
|
" plot_spec(te_pairs[n][1].reshape(15,1654))\n",
|
||
|
|
"show_nth(0)"
|
||
|
|
]
|
||
|
|
},
|
||
|
|
{
|
||
|
|
"cell_type": "code",
|
||
|
|
"execution_count": null,
|
||
|
|
"metadata": {},
|
||
|
|
"outputs": [],
|
||
|
|
"source": []
|
||
|
|
}
|
||
|
|
],
|
||
|
|
"metadata": {
|
||
|
|
"kernelspec": {
|
||
|
|
"display_name": "Python 3",
|
||
|
|
"language": "python",
|
||
|
|
"name": "python3"
|
||
|
|
},
|
||
|
|
"language_info": {
|
||
|
|
"codemirror_mode": {
|
||
|
|
"name": "ipython",
|
||
|
|
"version": 3
|
||
|
|
},
|
||
|
|
"file_extension": ".py",
|
||
|
|
"mimetype": "text/x-python",
|
||
|
|
"name": "python",
|
||
|
|
"nbconvert_exporter": "python",
|
||
|
|
"pygments_lexer": "ipython3",
|
||
|
|
"version": "3.5.2"
|
||
|
|
}
|
||
|
|
},
|
||
|
|
"nbformat": 4,
|
||
|
|
"nbformat_minor": 2
|
||
|
|
}
|