From 52bbb69c6556f6e829246b8891ab1a61e69d3d26 Mon Sep 17 00:00:00 2001 From: Malar Kannan Date: Sun, 10 Dec 2017 21:58:55 +0530 Subject: [PATCH] resuming segment training --- segment_model.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/segment_model.py b/segment_model.py index cd3aecf..3b9995f 100644 --- a/segment_model.py +++ b/segment_model.py @@ -54,8 +54,6 @@ def segment_model(input_dim): return Model(inp, oup) def simple_segment_model(input_dim): - # input_dim = (100,100) - # input_dim = (506,743) inp = Input(shape=input_dim) b_gr1 = Bidirectional(GRU(256, return_sequences=True),merge_mode='sum')(inp) # b_gr1 @@ -75,7 +73,7 @@ def load_model_arch(mod_file): model_f.close() return mod -def train_segment(collection_name = 'test'): +def train_segment(collection_name = 'test',resume_weights='',initial_epoch=0): # collection_name = 'story_test' batch_size = 64 # batch_size = 4 @@ -101,7 +99,7 @@ def train_segment(collection_name = 'test'): embeddings_freq=0, embeddings_layer_names=None, embeddings_metadata=None) - cp_file_fmt = model_dir+'/siamese_speech_model-{epoch:02d}-epoch-{val_loss:0.2f}\ + cp_file_fmt = model_dir+'/speech_segment_model-{epoch:02d}-epoch-{val_loss:0.2f}\ -acc.h5' cp_cb = ModelCheckpoint( @@ -115,14 +113,16 @@ def train_segment(collection_name = 'test'): # train rms = RMSprop() model.compile(loss=categorical_crossentropy, optimizer=rms, metrics=[accuracy]) - write_model_arch(model,model_dir+'/siamese_speech_model_arch.yaml') + write_model_arch(model,model_dir+'/speech_segment_model_arch.yaml') epoch_n_steps = step_count(n_records,batch_size) + if resume_weights != '': + model.load_weights(resume_weights) model.fit_generator(tr_gen - , epochs=1000 + , epochs=10000 , steps_per_epoch=epoch_n_steps , validation_data=(te_x, te_y) , max_queue_size=32 - , callbacks=[tb_cb, cp_cb]) + , callbacks=[tb_cb, cp_cb],initial_epoch=initial_epoch) model.save(model_dir+'/speech_segment_model-final.h5') # y_pred = model.predict([te_pairs[:, 0], te_pairs[:, 1]]) @@ -133,4 +133,4 @@ def train_segment(collection_name = 'test'): if __name__ == '__main__': # pass - train_segment('story_phrases') + train_segment('story_phrases','./models/segment/story_phrases.1000/speech_segment_model-final.h5',1001)