58 lines
2.0 KiB
Python
58 lines
2.0 KiB
Python
%matplotlib inline
|
|
from tensorflow.examples.tutorials.mnist import input_data
|
|
import numpy as np
|
|
import pandas as pd
|
|
from sklearn.model_selection import train_test_split
|
|
import matplotlib.pyplot as plt
|
|
|
|
mnist = input_data.read_data_sets('./SecondSunday/mnist_data', one_hot=True)

# Labels were loaded one-hot; argmax over the class axis recovers the
# integer digit 0..9 for every training image.
label_number = mnist.train.labels.argmax(axis=1)

# Map each digit character '0'..'9' to the training images carrying that label.
# np.flatnonzero always yields a 1-D index array, so every value keeps the
# 2-D layout (count, pixels). The previous argwhere(...).squeeze() collapsed a
# single match to a 0-d index, turning that entry into one 1-D pixel vector.
number_imgs = {
    str(digit): mnist.train.images[np.flatnonzero(label_number == digit)]
    for digit in range(10)
}
|
|
DATA_COUNT = 10240

# One synthetic phone number per row: 10 independent random digits in 0..9
# (leading zeros allowed). The global NumPy RNG is used unseeded, so the
# data differs on every run.
phone_number_digits = np.random.randint(10, size=(DATA_COUNT, 10))

# Render each row as a 10-character digit string, e.g. [0, 4, 2, ...] -> '042...'.
# dtype=object keeps elements as plain Python str, matching what the earlier
# pandas-based construction returned from `.values`. (The old standalone
# `phone_number_digits.astype(str)` line was a no-op: astype returns a copy.)
phone_number_strings = np.array(
    [''.join(row) for row in phone_number_digits.astype(str)],
    dtype=object,
)
|
|
|
|
def pick_img(num):
    """Return one randomly chosen image of digit ``num``, reshaped to 28x28.

    ``num`` is used as a key into the module-level ``number_imgs`` dict (its
    keys are digit characters such as ``'7'``). Each value is indexed by row,
    and the selected row is reshaped to 28x28, so rows must hold 784 values.
    The row is drawn with the global NumPy RNG.
    """
    pool = number_imgs[num]
    pick = np.random.randint(pool.shape[0])
    return pool[pick].reshape(28, 28)
|
|
|
|
def create_phone_img(phon_no):
    """Build one image for a digit string by tiling a random sample of each digit.

    Each character of ``phon_no`` is passed to ``pick_img``, which returns a
    28x28 tile. The tiles are joined left to right, so an n-digit string
    yields a 28 x (28*n) array.
    """
    tiles = [pick_img(digit) for digit in phon_no]
    return np.hstack(tiles)
|
|
|
|
def create_phone_images(phone_array):
    """Render every digit string in ``phone_array`` and flatten each image into one row.

    Parameters
    ----------
    phone_array : iterable of str
        Digit strings. All of them must have the same length, so that their
        images can be stacked.

    Returns
    -------
    numpy.ndarray
        One flattened 28 x (28 * n_digits) image per input string, i.e. shape
        ``(len(phone_array), 784 * n_digits)``. For 10-digit numbers this is
        ``(N, 7840)``, the same layout as before.

    Raises
    ------
    ValueError
        If the strings differ in length (their images cannot be stacked).
    """
    images = [create_phone_img(phon_no) for phon_no in phone_array]
    if not images:
        # Preserve the historical empty result: zero rows of the 10-digit width.
        return np.empty((0, 28 * 280))
    # Derive the row width from the images instead of hard-coding 28*280.
    # With any other digit count, the fixed reshape silently packed several
    # images into one row (or failed), misaligning rows with their labels.
    stacked = np.stack(images)
    return stacked.reshape(len(images), -1)
|
|
|
|
# Full image dataset: one flattened phone-number image per generated string.
phone_number_imgs = create_phone_images(phone_number_strings)

# Shuffle and split the images together with their per-digit integer labels.
# The call uses sklearn's default split proportions, and since no
# random_state is passed, the split changes on every run.
(train_imgs, test_imgs,
 train_digits, test_digits) = train_test_split(
    phone_number_imgs,
    phone_number_digits,
)
|
|
|
|
from keras.models import Sequential
|
|
from keras.layers import Dense, Activation
|
|
|
|
# model = Sequential([
|
|
# Dense(32, input_shape=(7840,)),
|
|
# Activation('relu'),
|
|
# # Dense(24, input_shape=(32,)),
|
|
# # Activation('relu'),
|
|
# Dense(10),
|
|
# Activation('linear'),
|
|
# ])
|
|
|
|
# model.compile(optimizer='sgd',
|
|
# loss='mean_squared_error',
|
|
# metrics=['accuracy'])
|
|
#
|
|
# model.fit(train_imgs, train_digits,
|
|
# batch_size=128,
|
|
# epochs=100,
|
|
# validation_data=(test_imgs, test_digits))
|
|
|
|
# img_idx = np.random.randint(phone_number_imgs.shape[0])
|
|
# print(phone_number_strings[img_idx])
|
|
# plt.imshow(phone_number_imgs[img_idx].reshape(28,280))
|