diff --git a/segment_data.py b/segment_data.py index 521fe02..eaceec3 100644 --- a/segment_data.py +++ b/segment_data.py @@ -258,7 +258,7 @@ if __name__ == '__main__': # plot_segments('story_test_segments') # fix_csv('story_words') # pass - create_segments_tfrecords('story_words.3', sample_count=3,train_test_ratio=0.33) + create_segments_tfrecords('story_words.30', sample_count=36,train_test_ratio=0.1) # record_generator,input_data,output_data,copy_read_consts = read_segments_tfrecords_generator('story_test') # tr_gen = record_generator() # for i in tr_gen: diff --git a/segment_model.py b/segment_model.py index d6c8f37..de21753 100644 --- a/segment_model.py +++ b/segment_model.py @@ -7,7 +7,7 @@ from keras.layers import Dense,Conv2D, LSTM, Bidirectional, GRU from keras.layers import BatchNormalization,Activation from keras.losses import categorical_crossentropy from keras.utils import to_categorical -from keras.optimizers import RMSprop +from keras.optimizers import RMSprop,Adadelta,Adagrad,Adam,Nadam from keras.callbacks import TensorBoard, ModelCheckpoint from keras import backend as K from keras.utils import plot_model @@ -50,20 +50,24 @@ def segment_model(input_dim): def simple_segment_model(input_dim): inp = Input(shape=input_dim) - b_gr1 = Bidirectional(GRU(256, return_sequences=True),merge_mode='sum')(inp) + b_gr1 = Bidirectional(LSTM(32, return_sequences=True))(inp) + b_gr1 = Bidirectional(LSTM(16, return_sequences=True),merge_mode='sum')(b_gr1) + b_gr1 = LSTM(1, return_sequences=True,activation='softmax')(b_gr1) + # b_gr1 = LSTM(4, return_sequences=True)(b_gr1) + # b_gr1 = LSTM(2, return_sequences=True)(b_gr1) # bn_b_gr1 = BatchNormalization(momentum=0.98)(b_gr1) - b_gr2 = Bidirectional(GRU(64, return_sequences=True),merge_mode='sum')(b_gr1) + # b_gr2 = GRU(64, return_sequences=True)(b_gr1) # bn_b_gr2 = BatchNormalization(momentum=0.98)(b_gr2) - d1 = Dense(32)(b_gr2) - bn_d1 = BatchNormalization(momentum=0.98)(d1) - bn_da1 = Activation('relu')(bn_d1) - d2 = Dense(8)(bn_da1) - bn_d2 = BatchNormalization(momentum=0.98)(d2) - bn_da2 = Activation('relu')(bn_d2) - d3 = Dense(1)(bn_da2) - bn_d3 = BatchNormalization(momentum=0.98)(d3) - bn_da3 = Activation('softmax')(bn_d3) - oup = Reshape(target_shape=(input_dim[0],))(bn_da3) + # d1 = Dense(32)(b_gr2) + # bn_d1 = BatchNormalization(momentum=0.98)(d1) + # bn_da1 = Activation('relu')(bn_d1) + # d2 = Dense(8)(bn_da1) + # bn_d2 = BatchNormalization(momentum=0.98)(d2) + # bn_da2 = Activation('relu')(bn_d2) + # d3 = Dense(1)(b_gr1) + # # bn_d3 = BatchNormalization(momentum=0.98)(d3) + # bn_da3 = Activation('softmax')(d3) + oup = Reshape(target_shape=(input_dim[0],))(b_gr1) return Model(inp, oup) def write_model_arch(mod,mod_file): @@ -79,13 +83,13 @@ def load_model_arch(mod_file): def train_segment(collection_name = 'test',resume_weights='',initial_epoch=0): # collection_name = 'story_test' - # batch_size = 32 - batch_size = 1 + batch_size = 128 + # batch_size = 4 model_dir = './models/segment/'+collection_name create_dir(model_dir) log_dir = './logs/segment/'+collection_name create_dir(log_dir) - tr_gen_fn,te_x,te_y,copy_read_consts = read_segments_tfrecords_generator(collection_name,batch_size,batch_size) + tr_gen_fn,te_x,te_y,copy_read_consts = read_segments_tfrecords_generator(collection_name,batch_size,2*batch_size) tr_gen = tr_gen_fn() n_step,n_features,n_records = copy_read_consts(model_dir) input_dim = (n_step, n_features) @@ -115,8 +119,8 @@ def train_segment(collection_name = 'test',resume_weights='',initial_epoch=0): mode='auto', period=1) # train - rms = RMSprop() - model.compile(loss=categorical_crossentropy, optimizer=rms, metrics=[accuracy]) + opt = RMSprop() + model.compile(loss=categorical_crossentropy, optimizer=opt, metrics=[accuracy]) write_model_arch(model,model_dir+'/speech_segment_model_arch.yaml') epoch_n_steps = step_count(n_records,batch_size) if resume_weights != '': @@ -137,4 +141,4 @@ def train_segment(collection_name = 'test',resume_weights='',initial_epoch=0): if __name__ == '__main__': # pass - train_segment('story_words.3')#,'./models/segment/story_phrases.1000/speech_segment_model-final.h5',1001) + train_segment('story_words')#,'./models/segment/story_phrases.1000/speech_segment_model-final.h5',1001)