added batch normalization

master
Malar Kannan 2017-12-11 14:09:04 +05:30
parent fea9184aec
commit 05242d5991
2 changed files with 13 additions and 11 deletions

View File

@ -256,9 +256,9 @@ if __name__ == '__main__':
# plot_random_phrases()
# fix_csv('story_test_segments')
# plot_segments('story_test_segments')
# fix_csv('story_phrases')
# fix_csv('story_words')
# pass
create_segments_tfrecords('story_phrases_full', sample_count=0)
create_segments_tfrecords('story_words', sample_count=0)
# record_generator,input_data,output_data,copy_read_consts = read_segments_tfrecords_generator('story_test')
# tr_gen = record_generator()
# for i in tr_gen:

View File

@ -49,16 +49,18 @@ def segment_model(input_dim):
return Model(inp, oup)
def simple_segment_model(input_dim):
# input_dim = (1000,300)
inp = Input(shape=input_dim)
b_gr1 = Bidirectional(GRU(256, return_sequences=True),merge_mode='sum')(inp)
# b_gr1
# b_gr2 = Bidirectional(GRU(64, return_sequences=True),merge_mode='sum')(b_gr1)
b_gr3 = Bidirectional(GRU(64, return_sequences=True),merge_mode='sum')(b_gr1)
d1 = Dense(32, activation='relu')(b_gr3)
d2 = Dense(8, activation='relu')(d1)
d3 = Dense(1, activation='softmax')(d2)
oup = Reshape(target_shape=(input_dim[0],))(d3)
bn_b_gr1 = BatchNormalization(momentum=0.98)(b_gr1)
b_gr2 = Bidirectional(GRU(64, return_sequences=True),merge_mode='sum')(bn_b_gr1)
bn_b_gr2 = BatchNormalization(momentum=0.98)(b_gr2)
d1 = Dense(32, activation='relu')(bn_b_gr2)
bn_d1 = BatchNormalization(momentum=0.98)(d1)
d2 = Dense(8, activation='relu')(bn_d1)
bn_d2 = BatchNormalization(momentum=0.98)(d2)
d3 = Dense(1, activation='softmax')(bn_d2)
bn_d3 = BatchNormalization(momentum=0.98)(d3)
oup = Reshape(target_shape=(input_dim[0],))(bn_d3)
return Model(inp, oup)
def write_model_arch(mod,mod_file):
@ -132,4 +134,4 @@ def train_segment(collection_name = 'test',resume_weights='',initial_epoch=0):
if __name__ == '__main__':
# pass
train_segment('story_phrases_full')#,'./models/segment/story_phrases.1000/speech_segment_model-final.h5',1001)
train_segment('story_words')#,'./models/segment/story_phrases.1000/speech_segment_model-final.h5',1001)