resuming segment training
parent
03edd935ea
commit
52bbb69c65
|
|
@ -54,8 +54,6 @@ def segment_model(input_dim):
|
||||||
return Model(inp, oup)
|
return Model(inp, oup)
|
||||||
|
|
||||||
def simple_segment_model(input_dim):
|
def simple_segment_model(input_dim):
|
||||||
# input_dim = (100,100)
|
|
||||||
# input_dim = (506,743)
|
|
||||||
inp = Input(shape=input_dim)
|
inp = Input(shape=input_dim)
|
||||||
b_gr1 = Bidirectional(GRU(256, return_sequences=True),merge_mode='sum')(inp)
|
b_gr1 = Bidirectional(GRU(256, return_sequences=True),merge_mode='sum')(inp)
|
||||||
# b_gr1
|
# b_gr1
|
||||||
|
|
@ -75,7 +73,7 @@ def load_model_arch(mod_file):
|
||||||
model_f.close()
|
model_f.close()
|
||||||
return mod
|
return mod
|
||||||
|
|
||||||
def train_segment(collection_name = 'test'):
|
def train_segment(collection_name = 'test',resume_weights='',initial_epoch=0):
|
||||||
# collection_name = 'story_test'
|
# collection_name = 'story_test'
|
||||||
batch_size = 64
|
batch_size = 64
|
||||||
# batch_size = 4
|
# batch_size = 4
|
||||||
|
|
@ -101,7 +99,7 @@ def train_segment(collection_name = 'test'):
|
||||||
embeddings_freq=0,
|
embeddings_freq=0,
|
||||||
embeddings_layer_names=None,
|
embeddings_layer_names=None,
|
||||||
embeddings_metadata=None)
|
embeddings_metadata=None)
|
||||||
cp_file_fmt = model_dir+'/siamese_speech_model-{epoch:02d}-epoch-{val_loss:0.2f}\
|
cp_file_fmt = model_dir+'/speech_segment_model-{epoch:02d}-epoch-{val_loss:0.2f}\
|
||||||
-acc.h5'
|
-acc.h5'
|
||||||
|
|
||||||
cp_cb = ModelCheckpoint(
|
cp_cb = ModelCheckpoint(
|
||||||
|
|
@ -115,14 +113,16 @@ def train_segment(collection_name = 'test'):
|
||||||
# train
|
# train
|
||||||
rms = RMSprop()
|
rms = RMSprop()
|
||||||
model.compile(loss=categorical_crossentropy, optimizer=rms, metrics=[accuracy])
|
model.compile(loss=categorical_crossentropy, optimizer=rms, metrics=[accuracy])
|
||||||
write_model_arch(model,model_dir+'/siamese_speech_model_arch.yaml')
|
write_model_arch(model,model_dir+'/speech_segment_model_arch.yaml')
|
||||||
epoch_n_steps = step_count(n_records,batch_size)
|
epoch_n_steps = step_count(n_records,batch_size)
|
||||||
|
if resume_weights != '':
|
||||||
|
model.load_weights(resume_weights)
|
||||||
model.fit_generator(tr_gen
|
model.fit_generator(tr_gen
|
||||||
, epochs=1000
|
, epochs=10000
|
||||||
, steps_per_epoch=epoch_n_steps
|
, steps_per_epoch=epoch_n_steps
|
||||||
, validation_data=(te_x, te_y)
|
, validation_data=(te_x, te_y)
|
||||||
, max_queue_size=32
|
, max_queue_size=32
|
||||||
, callbacks=[tb_cb, cp_cb])
|
, callbacks=[tb_cb, cp_cb],initial_epoch=initial_epoch)
|
||||||
model.save(model_dir+'/speech_segment_model-final.h5')
|
model.save(model_dir+'/speech_segment_model-final.h5')
|
||||||
|
|
||||||
# y_pred = model.predict([te_pairs[:, 0], te_pairs[:, 1]])
|
# y_pred = model.predict([te_pairs[:, 0], te_pairs[:, 1]])
|
||||||
|
|
@ -133,4 +133,4 @@ def train_segment(collection_name = 'test'):
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
# pass
|
# pass
|
||||||
train_segment('story_phrases')
|
train_segment('story_phrases','./models/segment/story_phrases.1000/speech_segment_model-final.h5',1001)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue