updated test code
parent
6fbf06814c
commit
4188585488
|
|
@ -4,8 +4,9 @@ import numpy as np
|
|||
from speech_data import speech_model_data
|
||||
from keras.models import Model,load_model
|
||||
from keras.layers import Input, Dense, Dropout, LSTM, Lambda, Concatenate
|
||||
# from keras.losses import categorical_crossentropy
|
||||
from keras.losses import binary_crossentropy
|
||||
from keras.losses import categorical_crossentropy
|
||||
# from keras.losses import binary_crossentropy
|
||||
from keras.utils import to_categorical
|
||||
# from keras.utils.np_utils import to_categorical
|
||||
from keras.optimizers import RMSprop
|
||||
from keras.callbacks import TensorBoard, ModelCheckpoint
|
||||
|
|
@ -30,15 +31,14 @@ def contrastive_loss(y_true, y_pred):
|
|||
return K.mean(y_true * K.square(y_pred) +
|
||||
(1 - y_true) * K.square(K.maximum(1 - y_pred, 0)))
|
||||
|
||||
|
||||
def create_base_rnn_network(input_dim):
|
||||
'''Base network to be shared (eq. to feature extraction).
|
||||
'''
|
||||
inp = Input(shape=input_dim)
|
||||
ls1 = LSTM(1024, return_sequences=True)(inp)
|
||||
ls2 = LSTM(512, return_sequences=True)(ls1)
|
||||
ls1 = LSTM(256, return_sequences=True)(inp)
|
||||
ls2 = LSTM(128, return_sequences=True)(ls1)
|
||||
# ls3 = LSTM(32, return_sequences=True)(ls2)
|
||||
ls4 = LSTM(32)(ls2)
|
||||
ls4 = LSTM(64)(ls2)
|
||||
return Model(inp, ls4)
|
||||
|
||||
|
||||
|
|
@ -52,15 +52,15 @@ def compute_accuracy(y_true, y_pred):
|
|||
def accuracy(y_true, y_pred):
|
||||
'''Compute classification accuracy with a fixed threshold on distances.
|
||||
'''
|
||||
return K.mean(K.equal(y_true, K.cast(y_pred < 0.5, y_true.dtype)))
|
||||
return K.mean(K.equal(y_true, K.cast(y_pred > 0.5, y_true.dtype)))
|
||||
|
||||
def dense_classifier(processed):
|
||||
conc_proc = Concatenate()(processed)
|
||||
d1 = Dense(8, activation='relu')(conc_proc)
|
||||
dr1 = Dropout(0.1)(d1)
|
||||
# d2 = Dense(8, activation='relu')(dr1)
|
||||
d1 = Dense(16, activation='relu')(conc_proc)
|
||||
# dr1 = Dropout(0.1)(d1)
|
||||
d2 = Dense(8, activation='relu')(d1)
|
||||
# dr2 = Dropout(0.1)(d2)
|
||||
return Dense(1, activation='sigmoid')(dr1)
|
||||
return Dense(2, activation='softmax')(d2)
|
||||
|
||||
def siamese_model(input_dim):
|
||||
# input_dim = (15, 1654)
|
||||
|
|
@ -80,7 +80,9 @@ def siamese_model(input_dim):
|
|||
|
||||
def train_siamese():
|
||||
# the data, shuffled and split between train and test sets
|
||||
tr_pairs, te_pairs, tr_y, te_y = speech_model_data()
|
||||
tr_pairs, te_pairs, tr_y_e, te_y_e = speech_model_data()
|
||||
tr_y = to_categorical(tr_y_e, num_classes=2)
|
||||
te_y = to_categorical(te_y_e, num_classes=2)
|
||||
input_dim = (tr_pairs.shape[2], tr_pairs.shape[3])
|
||||
|
||||
model = siamese_model(input_dim)
|
||||
|
|
@ -96,7 +98,7 @@ def train_siamese():
|
|||
embeddings_layer_names=None,
|
||||
embeddings_metadata=None)
|
||||
cp_file_fmt = './models/siamese_speech_model-{epoch:02d}-epoch-{val_loss:0.2f}\
|
||||
-acc.h5'
|
||||
-acc.h5'
|
||||
|
||||
cp_cb = ModelCheckpoint(
|
||||
cp_file_fmt,
|
||||
|
|
@ -108,7 +110,7 @@ def train_siamese():
|
|||
period=1)
|
||||
# train
|
||||
rms = RMSprop(lr=0.001)
|
||||
model.compile(loss=binary_crossentropy, optimizer=rms, metrics=[accuracy])
|
||||
model.compile(loss=categorical_crossentropy, optimizer=rms, metrics=[accuracy])
|
||||
model.fit(
|
||||
[tr_pairs[:, 0], tr_pairs[:, 1]],
|
||||
tr_y,
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import speech_data
|
|||
reload(speech_data)
|
||||
from speech_data import create_test_pair,get_word_pairs_data
|
||||
import numpy as np
|
||||
from keras.utils import to_categorical
|
||||
|
||||
model = siamese_model((15, 1654))
|
||||
model.load_weights('./models/siamese_speech_model-final.h5')
|
||||
|
|
@ -17,6 +18,6 @@ def predict_recording_with(m,sample_size=15):
|
|||
|
||||
predict_recording_with(model)
|
||||
|
||||
sunflower_data,sunflower_result = get_word_pairs_data('sunflowers',15)
|
||||
sunflower_result
|
||||
model.predict([sunflower_data[:, 0], sunflower_data[:, 1]]) < 0.5
|
||||
sunflower_data,sunflower_result = get_word_pairs_data('sweater',15)
|
||||
print(np.argmax(model.predict([sunflower_data[:, 0], sunflower_data[:, 1]]),axis=1))
|
||||
print(sunflower_result)
|
||||
|
|
|
|||
Loading…
Reference in New Issue