Merge branch 'master' of ssh://invnuc/~/Public/Repos/speech_scoring
commit
c8a07b3d7b
|
|
@ -74,7 +74,7 @@ def load_model_arch(mod_file):
|
||||||
model_f.close()
|
model_f.close()
|
||||||
return mod
|
return mod
|
||||||
|
|
||||||
def train_siamese(audio_group = 'audio'):
|
def train_siamese(audio_group = 'audio',resume_weights='',initial_epoch=0):
|
||||||
batch_size = 128
|
batch_size = 128
|
||||||
model_dir = './models/'+audio_group
|
model_dir = './models/'+audio_group
|
||||||
create_dir(model_dir)
|
create_dir(model_dir)
|
||||||
|
|
@ -114,19 +114,22 @@ def train_siamese(audio_group = 'audio'):
|
||||||
model.compile(loss=categorical_crossentropy, optimizer=rms, metrics=[accuracy])
|
model.compile(loss=categorical_crossentropy, optimizer=rms, metrics=[accuracy])
|
||||||
write_model_arch(model,model_dir+'/siamese_speech_model_arch.yaml')
|
write_model_arch(model,model_dir+'/siamese_speech_model_arch.yaml')
|
||||||
epoch_n_steps = step_count(n_records,batch_size)
|
epoch_n_steps = step_count(n_records,batch_size)
|
||||||
|
if resume_weights != '':
|
||||||
|
model.load_weights(resume_weights)
|
||||||
model.fit_generator(tr_gen
|
model.fit_generator(tr_gen
|
||||||
, epochs=1000
|
, epochs=1000
|
||||||
, steps_per_epoch=epoch_n_steps
|
, steps_per_epoch=epoch_n_steps
|
||||||
, validation_data=([te_pairs[:, 0], te_pairs[:, 1]], te_y)
|
, validation_data=([te_pairs[:, 0], te_pairs[:, 1]], te_y)
|
||||||
, max_queue_size=8
|
, max_queue_size=8
|
||||||
, callbacks=[tb_cb, cp_cb])
|
, callbacks=[tb_cb, cp_cb],initial_epoch=initial_epoch)
|
||||||
model.save(model_dir+'/siamese_speech_model-final.h5')
|
model.save(model_dir+'/siamese_speech_model-final.h5')
|
||||||
|
|
||||||
y_pred = model.predict([te_pairs[:, 0], te_pairs[:, 1]])
|
# y_pred = model.predict([te_pairs[:, 0], te_pairs[:, 1]])
|
||||||
te_acc = compute_accuracy(te_y, y_pred)
|
# te_acc = compute_accuracy(te_y, y_pred)
|
||||||
print('* Accuracy on test set: %0.2f%%' % (100 * te_acc))
|
# print('* Accuracy on test set: %0.2f%%' % (100 * te_acc))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
train_siamese('story_words_pitch')
|
train_siamese('story_words_pitch')
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue