added batch normalization
parent
fea9184aec
commit
05242d5991
|
|
@ -256,9 +256,9 @@ if __name__ == '__main__':
|
|||
# plot_random_phrases()
|
||||
# fix_csv('story_test_segments')
|
||||
# plot_segments('story_test_segments')
|
||||
# fix_csv('story_phrases')
|
||||
# fix_csv('story_words')
|
||||
# pass
|
||||
create_segments_tfrecords('story_phrases_full', sample_count=0)
|
||||
create_segments_tfrecords('story_words', sample_count=0)
|
||||
# record_generator,input_data,output_data,copy_read_consts = read_segments_tfrecords_generator('story_test')
|
||||
# tr_gen = record_generator()
|
||||
# for i in tr_gen:
|
||||
|
|
|
|||
|
|
@ -49,16 +49,18 @@ def segment_model(input_dim):
|
|||
return Model(inp, oup)
|
||||
|
||||
def simple_segment_model(input_dim):
|
||||
# input_dim = (1000,300)
|
||||
inp = Input(shape=input_dim)
|
||||
b_gr1 = Bidirectional(GRU(256, return_sequences=True),merge_mode='sum')(inp)
|
||||
# b_gr1
|
||||
# b_gr2 = Bidirectional(GRU(64, return_sequences=True),merge_mode='sum')(b_gr1)
|
||||
b_gr3 = Bidirectional(GRU(64, return_sequences=True),merge_mode='sum')(b_gr1)
|
||||
d1 = Dense(32, activation='relu')(b_gr3)
|
||||
d2 = Dense(8, activation='relu')(d1)
|
||||
d3 = Dense(1, activation='softmax')(d2)
|
||||
oup = Reshape(target_shape=(input_dim[0],))(d3)
|
||||
bn_b_gr1 = BatchNormalization(momentum=0.98)(b_gr1)
|
||||
b_gr2 = Bidirectional(GRU(64, return_sequences=True),merge_mode='sum')(bn_b_gr1)
|
||||
bn_b_gr2 = BatchNormalization(momentum=0.98)(b_gr2)
|
||||
d1 = Dense(32, activation='relu')(bn_b_gr2)
|
||||
bn_d1 = BatchNormalization(momentum=0.98)(d1)
|
||||
d2 = Dense(8, activation='relu')(bn_d1)
|
||||
bn_d2 = BatchNormalization(momentum=0.98)(d2)
|
||||
d3 = Dense(1, activation='softmax')(bn_d2)
|
||||
bn_d3 = BatchNormalization(momentum=0.98)(d3)
|
||||
oup = Reshape(target_shape=(input_dim[0],))(bn_d3)
|
||||
return Model(inp, oup)
|
||||
|
||||
def write_model_arch(mod,mod_file):
|
||||
|
|
@ -132,4 +134,4 @@ def train_segment(collection_name = 'test',resume_weights='',initial_epoch=0):
|
|||
|
||||
if __name__ == '__main__':
|
||||
# pass
|
||||
train_segment('story_phrases_full')#,'./models/segment/story_phrases.1000/speech_segment_model-final.h5',1001)
|
||||
train_segment('story_words')#,'./models/segment/story_phrases.1000/speech_segment_model-final.h5',1001)
|
||||
|
|
|
|||
Loading…
Reference in New Issue