updated to third week
This commit is contained in:
@@ -1,16 +1,18 @@
|
||||
%matplotlib inline
|
||||
from tensorflow.examples.tutorials.mnist import input_data
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from sklearn.model_selection import train_test_split
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
mnist = input_data.read_data_sets('./SecondSunday/mnist_data', one_hot=True)
|
||||
label_number = mnist.train.labels.argmax(axis=1)
|
||||
number_imgs = {str(i):mnist.train.images[np.argwhere(label_number == i).squeeze()] for i in range(10)}
|
||||
DATA_COUNT = 100
|
||||
phone_number_digits = np.random.randint(10**9,10**10,size=(DATA_COUNT,10))
|
||||
DATA_COUNT = 10240
|
||||
# phone_number_digits = np.random.randint(10**9,10**10,size=(DATA_COUNT,10))
|
||||
phone_number_digits = np.random.randint(10,size=(DATA_COUNT,10))
|
||||
phone_number_digits.astype(str)
|
||||
phone_number_strings = phone_number_digits.astype(str)
|
||||
phone_number_strings = pd.DataFrame(phone_number_digits.astype(str).T).apply(lambda x: ''.join(x)).values
|
||||
|
||||
def pick_img(num):
|
||||
rand_idx = np.random.randint(number_imgs[num].shape[0])
|
||||
@@ -27,20 +29,29 @@ def create_phone_images(phone_array):
|
||||
return np.array(phone_number_images).reshape(-1,28*280)
|
||||
|
||||
phone_number_imgs = create_phone_images(phone_number_strings)
|
||||
train_imgs,test_imgs,train_digits,test_digits = train_test_split(phone_number_imgs,phone_number_digits)
|
||||
|
||||
from keras.models import Sequential
|
||||
from keras.layers import Dense, Activation
|
||||
|
||||
model = Sequential([
|
||||
Dense(32, input_shape=(7840,)),
|
||||
Activation('relu'),
|
||||
Dense(10),
|
||||
Activation('linear'),
|
||||
])
|
||||
# model = Sequential([
|
||||
# Dense(32, input_shape=(7840,)),
|
||||
# Activation('relu'),
|
||||
# # Dense(24, input_shape=(32,)),
|
||||
# # Activation('relu'),
|
||||
# Dense(10),
|
||||
# Activation('linear'),
|
||||
# ])
|
||||
|
||||
model.compile(optimizer='rmsprop',
|
||||
loss='categorical_crossentropy',
|
||||
metrics=['accuracy'])
|
||||
# model.compile(optimizer='sgd',
|
||||
# loss='mean_squared_error',
|
||||
# metrics=['accuracy'])
|
||||
#
|
||||
# model.fit(train_imgs, train_digits,
|
||||
# batch_size=128,
|
||||
# epochs=100,
|
||||
# validation_data=(test_imgs, test_digits))
|
||||
|
||||
model.fit()
|
||||
# plt.imshow(phone_number_imgs[np.random.randint(phone_number_imgs.shape[0])])
|
||||
# img_idx = np.random.randint(phone_number_imgs.shape[0])
|
||||
# print(phone_number_strings[img_idx])
|
||||
# plt.imshow(phone_number_imgs[img_idx].reshape(28,280))
|
||||
|
||||
Reference in New Issue
Block a user