saving model and tensorboard

checkpointing model
master
Malar Kannan 2017-11-10 18:06:45 +05:30
parent bb72c4045e
commit d978272bdb
1 changed files with 5 additions and 3 deletions

View File

@ -93,6 +93,8 @@ def train_siamese(audio_group = 'audio'):
batch_size = 128
model_dir = './models/'+audio_group
create_dir(model_dir)
log_dir = './logs/'+audio_group
create_dir(log_dir)
tr_gen_fn,te_pairs,te_y,n_step,n_features,n_records = read_siamese_tfrecords_generator(audio_group,batch_size,256)
tr_gen = tr_gen_fn()
# tr_y = to_categorical(tr_y_e, num_classes=2)
@ -102,7 +104,7 @@ def train_siamese(audio_group = 'audio'):
model = siamese_model(input_dim)
tb_cb = TensorBoard(
log_dir='./logs/siamese_logs',
log_dir=log_dir,
histogram_freq=1,
batch_size=32,
write_graph=True,
@ -136,8 +138,8 @@ def train_siamese(audio_group = 'audio'):
,epochs=1000
,steps_per_epoch=n_records//batch_size
,validation_data=([te_pairs[:, 0], te_pairs[:, 1]], te_y)
,use_multiprocessing=True)
# ,callbacks=[tb_cb, cp_cb])
,use_multiprocessing=True
,callbacks=[tb_cb, cp_cb])
model.save(model_dir+'/siamese_speech_model-final.h5')
# compute final accuracy on training and test sets
# y_pred = model.predict([tr_pairs[:, 0], tr_pairs[:, 1]])