resuming segment training
parent
03edd935ea
commit
52bbb69c65
|
|
@ -54,8 +54,6 @@ def segment_model(input_dim):
|
|||
return Model(inp, oup)
|
||||
|
||||
def simple_segment_model(input_dim):
|
||||
# input_dim = (100,100)
|
||||
# input_dim = (506,743)
|
||||
inp = Input(shape=input_dim)
|
||||
b_gr1 = Bidirectional(GRU(256, return_sequences=True),merge_mode='sum')(inp)
|
||||
# b_gr1
|
||||
|
|
@ -75,7 +73,7 @@ def load_model_arch(mod_file):
|
|||
model_f.close()
|
||||
return mod
|
||||
|
||||
def train_segment(collection_name = 'test'):
|
||||
def train_segment(collection_name = 'test',resume_weights='',initial_epoch=0):
|
||||
# collection_name = 'story_test'
|
||||
batch_size = 64
|
||||
# batch_size = 4
|
||||
|
|
@ -101,7 +99,7 @@ def train_segment(collection_name = 'test'):
|
|||
embeddings_freq=0,
|
||||
embeddings_layer_names=None,
|
||||
embeddings_metadata=None)
|
||||
cp_file_fmt = model_dir+'/siamese_speech_model-{epoch:02d}-epoch-{val_loss:0.2f}\
|
||||
cp_file_fmt = model_dir+'/speech_segment_model-{epoch:02d}-epoch-{val_loss:0.2f}\
|
||||
-acc.h5'
|
||||
|
||||
cp_cb = ModelCheckpoint(
|
||||
|
|
@ -115,14 +113,16 @@ def train_segment(collection_name = 'test'):
|
|||
# train
|
||||
rms = RMSprop()
|
||||
model.compile(loss=categorical_crossentropy, optimizer=rms, metrics=[accuracy])
|
||||
write_model_arch(model,model_dir+'/siamese_speech_model_arch.yaml')
|
||||
write_model_arch(model,model_dir+'/speech_segment_model_arch.yaml')
|
||||
epoch_n_steps = step_count(n_records,batch_size)
|
||||
if resume_weights != '':
|
||||
model.load_weights(resume_weights)
|
||||
model.fit_generator(tr_gen
|
||||
, epochs=1000
|
||||
, epochs=10000
|
||||
, steps_per_epoch=epoch_n_steps
|
||||
, validation_data=(te_x, te_y)
|
||||
, max_queue_size=32
|
||||
, callbacks=[tb_cb, cp_cb])
|
||||
, callbacks=[tb_cb, cp_cb],initial_epoch=initial_epoch)
|
||||
model.save(model_dir+'/speech_segment_model-final.h5')
|
||||
|
||||
# y_pred = model.predict([te_pairs[:, 0], te_pairs[:, 1]])
|
||||
|
|
@ -133,4 +133,4 @@ def train_segment(collection_name = 'test'):
|
|||
|
||||
if __name__ == '__main__':
|
||||
# pass
|
||||
train_segment('story_phrases')
|
||||
train_segment('story_phrases','./models/segment/story_phrases.1000/speech_segment_model-final.h5',1001)
|
||||
|
|
|
|||
Loading…
Reference in New Issue