parent
bb72c4045e
commit
d978272bdb
|
|
@ -93,6 +93,8 @@ def train_siamese(audio_group = 'audio'):
|
|||
batch_size = 128
|
||||
model_dir = './models/'+audio_group
|
||||
create_dir(model_dir)
|
||||
log_dir = './logs/'+audio_group
|
||||
create_dir(log_dir)
|
||||
tr_gen_fn,te_pairs,te_y,n_step,n_features,n_records = read_siamese_tfrecords_generator(audio_group,batch_size,256)
|
||||
tr_gen = tr_gen_fn()
|
||||
# tr_y = to_categorical(tr_y_e, num_classes=2)
|
||||
|
|
@ -102,7 +104,7 @@ def train_siamese(audio_group = 'audio'):
|
|||
model = siamese_model(input_dim)
|
||||
|
||||
tb_cb = TensorBoard(
|
||||
log_dir='./logs/siamese_logs',
|
||||
log_dir=log_dir,
|
||||
histogram_freq=1,
|
||||
batch_size=32,
|
||||
write_graph=True,
|
||||
|
|
@ -136,8 +138,8 @@ def train_siamese(audio_group = 'audio'):
|
|||
,epochs=1000
|
||||
,steps_per_epoch=n_records//batch_size
|
||||
,validation_data=([te_pairs[:, 0], te_pairs[:, 1]], te_y)
|
||||
,use_multiprocessing=True)
|
||||
# ,callbacks=[tb_cb, cp_cb])
|
||||
,use_multiprocessing=True
|
||||
,callbacks=[tb_cb, cp_cb])
|
||||
model.save(model_dir+'/siamese_speech_model-final.h5')
|
||||
# compute final accuracy on training and test sets
|
||||
# y_pred = model.predict([tr_pairs[:, 0], tr_pairs[:, 1]])
|
||||
|
|
|
|||
Loading…
Reference in New Issue