fixed input_dim

master
Malar Kannan 2017-12-07 17:15:44 +05:30
parent a7f1451a7f
commit 03edd935ea
2 changed files with 5 additions and 5 deletions

View File

@ -257,7 +257,7 @@ if __name__ == '__main__':
# plot_segments('story_test_segments')
# fix_csv('story_phrases')
# pass
create_segments_tfrecords('story_phrases', sample_count=1000)
create_segments_tfrecords('story_phrases', sample_count=100)
# record_generator,input_data,output_data,copy_read_consts = read_segments_tfrecords_generator('story_test')
# tr_gen = record_generator()
# for i in tr_gen:

View File

@ -36,7 +36,7 @@ def ctc_lambda_func(args):
return K.ctc_batch_cost(labels, y_pred, input_length, label_length)
def segment_model(input_dim):
input_dim = (100,100,1)
# input_dim = (100,100,1)
inp = Input(shape=input_dim)
cnv1 = Conv2D(filters=32, kernel_size=(5,9))(inp)
cnv2 = Conv2D(filters=1, kernel_size=(5,9))(cnv1)
@ -55,7 +55,7 @@ def segment_model(input_dim):
def simple_segment_model(input_dim):
# input_dim = (100,100)
input_dim = (506,743)
# input_dim = (506,743)
inp = Input(shape=input_dim)
b_gr1 = Bidirectional(GRU(256, return_sequences=True),merge_mode='sum')(inp)
# b_gr1
@ -77,7 +77,7 @@ def load_model_arch(mod_file):
def train_segment(collection_name = 'test'):
# collection_name = 'story_test'
batch_size = 128
batch_size = 64
# batch_size = 4
model_dir = './models/segment/'+collection_name
create_dir(model_dir)
@ -133,4 +133,4 @@ def train_segment(collection_name = 'test'):
if __name__ == '__main__':
# pass
train_segment('test')
train_segment('story_phrases')