diff --git a/segment_data.py b/segment_data.py index de86d61..9d95f57 100644 --- a/segment_data.py +++ b/segment_data.py @@ -257,7 +257,7 @@ if __name__ == '__main__': # plot_segments('story_test_segments') # fix_csv('story_phrases') # pass - create_segments_tfrecords('story_phrases', sample_count=1000) + create_segments_tfrecords('story_phrases', sample_count=100) # record_generator,input_data,output_data,copy_read_consts = read_segments_tfrecords_generator('story_test') # tr_gen = record_generator() # for i in tr_gen: diff --git a/segment_model.py b/segment_model.py index f470d68..cd3aecf 100644 --- a/segment_model.py +++ b/segment_model.py @@ -36,7 +36,7 @@ def ctc_lambda_func(args): return K.ctc_batch_cost(labels, y_pred, input_length, label_length) def segment_model(input_dim): - input_dim = (100,100,1) + # input_dim = (100,100,1) inp = Input(shape=input_dim) cnv1 = Conv2D(filters=32, kernel_size=(5,9))(inp) cnv2 = Conv2D(filters=1, kernel_size=(5,9))(cnv1) @@ -55,7 +55,7 @@ def segment_model(input_dim): def simple_segment_model(input_dim): # input_dim = (100,100) - input_dim = (506,743) + # input_dim = (506,743) inp = Input(shape=input_dim) b_gr1 = Bidirectional(GRU(256, return_sequences=True),merge_mode='sum')(inp) # b_gr1 @@ -77,7 +77,7 @@ def load_model_arch(mod_file): def train_segment(collection_name = 'test'): # collection_name = 'story_test' - batch_size = 128 + batch_size = 64 # batch_size = 4 model_dir = './models/segment/'+collection_name create_dir(model_dir) @@ -133,4 +133,4 @@ def train_segment(collection_name = 'test'): if __name__ == '__main__': # pass - train_segment('test') + train_segment('story_phrases')