added a resume parameter for training

master
Malar Kannan 2017-12-07 12:00:42 +05:30
parent 3f76207f0d
commit 435c4a4aa6
1 changed files with 8 additions and 5 deletions

View File

@ -74,7 +74,7 @@ def load_model_arch(mod_file):
model_f.close()
return mod
def train_siamese(audio_group = 'audio'):
def train_siamese(audio_group = 'audio',resume_weights='',initial_epoch=0):
batch_size = 128
model_dir = './models/'+audio_group
create_dir(model_dir)
@ -114,19 +114,22 @@ def train_siamese(audio_group = 'audio'):
model.compile(loss=categorical_crossentropy, optimizer=rms, metrics=[accuracy])
write_model_arch(model,model_dir+'/siamese_speech_model_arch.yaml')
epoch_n_steps = step_count(n_records,batch_size)
if resume_weights != '':
model.load_weights(resume_weights)
model.fit_generator(tr_gen
, epochs=1000
, steps_per_epoch=epoch_n_steps
, validation_data=([te_pairs[:, 0], te_pairs[:, 1]], te_y)
, max_queue_size=8
, callbacks=[tb_cb, cp_cb])
, callbacks=[tb_cb, cp_cb],initial_epoch=initial_epoch)
model.save(model_dir+'/siamese_speech_model-final.h5')
y_pred = model.predict([te_pairs[:, 0], te_pairs[:, 1]])
te_acc = compute_accuracy(te_y, y_pred)
print('* Accuracy on test set: %0.2f%%' % (100 * te_acc))
# y_pred = model.predict([te_pairs[:, 0], te_pairs[:, 1]])
# te_acc = compute_accuracy(te_y, y_pred)
# print('* Accuracy on test set: %0.2f%%' % (100 * te_acc))
if __name__ == '__main__':
train_siamese('story_words_pitch')