diff --git a/segment_data.py b/segment_data.py index 51530f5..fb3c28a 100644 --- a/segment_data.py +++ b/segment_data.py @@ -256,9 +256,9 @@ if __name__ == '__main__': # plot_random_phrases() # fix_csv('story_test_segments') # plot_segments('story_test_segments') - # fix_csv('story_phrases') + # fix_csv('story_words') # pass - create_segments_tfrecords('story_phrases_full', sample_count=0) + create_segments_tfrecords('story_words', sample_count=0) # record_generator,input_data,output_data,copy_read_consts = read_segments_tfrecords_generator('story_test') # tr_gen = record_generator() # for i in tr_gen: diff --git a/segment_model.py b/segment_model.py index cd6caa2..e6d5071 100644 --- a/segment_model.py +++ b/segment_model.py @@ -49,16 +49,18 @@ def segment_model(input_dim): return Model(inp, oup) def simple_segment_model(input_dim): - # input_dim = (1000,300) inp = Input(shape=input_dim) b_gr1 = Bidirectional(GRU(256, return_sequences=True),merge_mode='sum')(inp) - # b_gr1 - # b_gr2 = Bidirectional(GRU(64, return_sequences=True),merge_mode='sum')(b_gr1) - b_gr3 = Bidirectional(GRU(64, return_sequences=True),merge_mode='sum')(b_gr1) - d1 = Dense(32, activation='relu')(b_gr3) - d2 = Dense(8, activation='relu')(d1) - d3 = Dense(1, activation='softmax')(d2) - oup = Reshape(target_shape=(input_dim[0],))(d3) + bn_b_gr1 = BatchNormalization(momentum=0.98)(b_gr1) + b_gr2 = Bidirectional(GRU(64, return_sequences=True),merge_mode='sum')(bn_b_gr1) + bn_b_gr2 = BatchNormalization(momentum=0.98)(b_gr2) + d1 = Dense(32, activation='relu')(bn_b_gr2) + bn_d1 = BatchNormalization(momentum=0.98)(d1) + d2 = Dense(8, activation='relu')(bn_d1) + bn_d2 = BatchNormalization(momentum=0.98)(d2) + d3 = Dense(1, activation='softmax')(bn_d2) + bn_d3 = BatchNormalization(momentum=0.98)(d3) + oup = Reshape(target_shape=(input_dim[0],))(bn_d3) return Model(inp, oup) def write_model_arch(mod,mod_file): @@ -132,4 +134,4 @@ def train_segment(collection_name = 'test',resume_weights='',initial_epoch=0): if __name__ == '__main__': # pass - train_segment('story_phrases_full')#,'./models/segment/story_phrases.1000/speech_segment_model-final.h5',1001) + train_segment('story_words')#,'./models/segment/story_phrases.1000/speech_segment_model-final.h5',1001)