131 lines
4.5 KiB
Python
131 lines
4.5 KiB
Python
from __future__ import absolute_import
|
|
from __future__ import print_function
|
|
import numpy as np
|
|
from speech_data import read_siamese_tfrecords_generator
|
|
from keras.models import Model,load_model,model_from_yaml
|
|
from keras.layers import Input,Concatenate,Lambda, BatchNormalization, Dropout
|
|
from keras.layers import Dense, LSTM, Bidirectional, GRU
|
|
from keras.losses import categorical_crossentropy
|
|
from keras.utils import to_categorical
|
|
from keras.optimizers import RMSprop
|
|
from keras.callbacks import TensorBoard, ModelCheckpoint
|
|
from keras import backend as K
|
|
from keras.utils import plot_model
|
|
from speech_tools import create_dir,step_count
|
|
|
|
|
|
def create_base_rnn_network(input_dim):
    '''Build the shared recurrent encoder applied to both siamese inputs.

    Stacked LSTMs of 128 and 64 units each return full sequences and are
    each followed by batch normalization (momentum 0.98). A final 32-unit
    LSTM returns only its last output, so a whole sequence is reduced to a
    single 32-dimensional vector.

    Args:
        input_dim: shape tuple passed to ``Input`` (presumably
            ``(n_step, n_features)``, as built in ``train_siamese``).

    Returns:
        A Keras ``Model`` mapping an input sequence to its encoding.
    '''
    sequence_in = Input(shape=input_dim)
    hidden = sequence_in
    # Intermediate layers must return sequences so the next LSTM can
    # consume them.
    for units in (128, 64):
        hidden = LSTM(units, return_sequences=True)(hidden)
        hidden = BatchNormalization(momentum=0.98)(hidden)
    encoding = LSTM(32)(hidden)
    return Model(sequence_in, encoding)
|
|
|
|
|
|
def compute_accuracy(y_true, y_pred, threshold=0.5):
    '''Compute classification accuracy of predictions against targets.

    Two layouts are supported:

    * Multi-column predictions (``y_pred`` whose last axis has size > 1,
      e.g. the 2-way softmax output of ``siamese_model``): the predicted
      class is the argmax over the last axis. ``y_true`` may be one-hot
      with the same shape as ``y_pred``, or integer class labels.
    * Single-score predictions (1-D, or a trailing axis of size 1): each
      score is binarized with ``score > threshold`` and compared with the
      flattened ``y_true``.

    Args:
        y_true: target labels (one-hot, or integer/boolean labels).
        y_pred: predicted scores/probabilities.
        threshold: cut-off used for single-score predictions.

    Returns:
        Fraction of correctly classified samples (NaN, with a NumPy
        warning, for empty input).

    Raises:
        ValueError: if label-style ``y_true`` cannot be matched to the
            number of predicted samples.
    '''
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    if y_pred.ndim > 1 and y_pred.shape[-1] > 1:
        # Flattening an (N, C) prediction and comparing it with (N, C)
        # one-hot targets broadcasts incompatible shapes; compare class
        # indices instead.
        pred_labels = np.argmax(y_pred, axis=-1)
        if y_true.shape == y_pred.shape:
            true_labels = np.argmax(y_true, axis=-1)
        else:
            true_labels = y_true.reshape(pred_labels.shape)
        return np.mean(pred_labels == true_labels)
    pred = y_pred.ravel() > threshold
    # Flatten the targets as well, so (N, 1) targets do not broadcast
    # against the flattened (N,) predictions into an (N, N) comparison.
    return np.mean(pred == y_true.ravel())
|
|
|
|
|
|
def accuracy(y_true, y_pred):
    '''Keras metric: mean element-wise agreement of thresholded predictions.

    Every element of ``y_pred`` is binarized with ``> 0.5``, cast to the
    dtype of ``y_true`` and compared element-wise with ``y_true``. The
    mean of those matches is returned. For one-hot targets and a 2-way
    softmax (as produced by ``dense_classifier``), this equals per-sample
    classification accuracy, barring exact 0.5 ties.
    '''
    binarized = K.cast(y_pred > 0.5, y_true.dtype)
    matches = K.equal(y_true, binarized)
    return K.mean(matches)
|
|
|
|
def dense_classifier(processed):
    '''Classify the concatenated siamese encodings into two classes.

    Args:
        processed: list of encoded tensors (one per siamese branch),
            concatenated along the last axis.

    Returns:
        Output tensor of a 2-unit softmax layer, reached through ReLU
        layers of 64 and 8 units.
    '''
    hidden = Concatenate()(processed)
    for units in (64, 8):
        hidden = Dense(units, activation='relu')(hidden)
    return Dense(2, activation='softmax')(hidden)
|
|
|
|
def siamese_model(input_dim):
    '''Assemble the two-input siamese classifier.

    One encoder instance is applied to both inputs, so the two branches
    share weights. The two encodings feed ``dense_classifier``.

    Args:
        input_dim: shape tuple for each of the two inputs.

    Returns:
        ``(model, base_network)``: the full two-input model and the
        shared encoder model.
    '''
    encoder = create_base_rnn_network(input_dim)
    inputs = [Input(shape=input_dim), Input(shape=input_dim)]
    encodings = [encoder(branch_in) for branch_in in inputs]
    prediction = dense_classifier(encodings)
    return Model(inputs, prediction), encoder
|
|
|
|
def write_model_arch(mod, mod_file):
    '''Write a model's YAML architecture description to a file.

    Args:
        mod: object exposing ``to_yaml()`` (a Keras model here).
        mod_file: destination path; an existing file is overwritten.
    '''
    # ``with`` closes the handle even if to_yaml() or the write raises;
    # the original open/close pair leaked it in that case.
    with open(mod_file, 'w') as model_f:
        model_f.write(mod.to_yaml())
|
|
|
|
def load_model_arch(mod_file):
    '''Rebuild a model from a YAML architecture file.

    Args:
        mod_file: path to a file such as one written by
            ``write_model_arch``.

    Returns:
        The model returned by ``model_from_yaml``.
    '''
    # Read inside ``with`` so the handle is closed even on error; the
    # original leaked it if model_from_yaml raised.
    with open(mod_file, 'r') as model_f:
        yaml_text = model_f.read()
    # NOTE(review): this deserializes YAML into layers; only load files
    # from trusted sources.
    return model_from_yaml(yaml_text)
|
|
|
|
def _build_callbacks(model_dir, log_dir):
    '''Create the TensorBoard and best-val_loss checkpoint callbacks.'''
    tb_cb = TensorBoard(
        log_dir=log_dir,
        histogram_freq=1,
        batch_size=32,
        write_graph=True,
        write_grads=True,
        write_images=True,
        embeddings_freq=0,
        embeddings_layer_names=None,
        embeddings_metadata=None)
    # NOTE(review): the filename labels the val_loss value as "acc". The
    # pattern is kept byte-identical in case other tooling globs for it.
    cp_file_fmt = (model_dir +
                   '/siamese_speech_model-{epoch:02d}-epoch-{val_loss:0.2f}-acc.h5')
    cp_cb = ModelCheckpoint(
        cp_file_fmt,
        monitor='val_loss',
        verbose=0,
        save_best_only=True,
        save_weights_only=True,
        mode='auto',
        period=1)
    return [tb_cb, cp_cb]


def train_siamese(audio_group='audio', batch_size=128, epochs=1000):
    '''Train the siamese speech model on the records of ``audio_group``.

    Side effects:
    * creates ``./models/<audio_group>`` and ``./logs/<audio_group>``;
    * writes model diagrams, the YAML architecture and TensorBoard logs;
    * writes weight checkpoints whenever val_loss improves;
    * saves the final model;
    * prints the test-set accuracy.

    Args:
        audio_group: name passed to the record reader; also used as the
            models/logs sub-directory name.
        batch_size: training batch size, also used as the test-set size
            requested from the reader.
        epochs: number of training epochs.
    '''
    model_dir = './models/' + audio_group
    create_dir(model_dir)
    log_dir = './logs/' + audio_group
    create_dir(log_dir)

    tr_gen_fn, te_pairs, te_y, copy_read_consts = read_siamese_tfrecords_generator(
        audio_group, batch_size=batch_size, test_size=batch_size)
    n_step, n_features, n_records = copy_read_consts(model_dir)
    tr_gen = tr_gen_fn()
    input_dim = (n_step, n_features)

    model, base_model = siamese_model(input_dim)
    plot_model(model, show_shapes=True, to_file=model_dir + '/model.png')
    plot_model(base_model, show_shapes=True, to_file=model_dir + '/base_model.png')

    callbacks = _build_callbacks(model_dir, log_dir)

    # train
    model.compile(loss=categorical_crossentropy, optimizer=RMSprop(),
                  metrics=[accuracy])
    write_model_arch(model, model_dir + '/siamese_speech_model_arch.yaml')

    # Split the test pairs once into the two branch inputs. They are used
    # both for validation during training and for the final evaluation.
    te_inputs = [te_pairs[:, 0], te_pairs[:, 1]]
    epoch_n_steps = step_count(n_records, batch_size)
    model.fit_generator(tr_gen,
                        epochs=epochs,
                        steps_per_epoch=epoch_n_steps,
                        validation_data=(te_inputs, te_y),
                        max_queue_size=8,
                        callbacks=callbacks)
    model.save(model_dir + '/siamese_speech_model-final.h5')

    y_pred = model.predict(te_inputs)
    te_acc = compute_accuracy(te_y, y_pred)
    print('* Accuracy on test set: %0.2f%%' % (100 * te_acc))
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: train on the 'story_words_test' audio group.
    train_siamese(audio_group='story_words_test')
|