fixed input_dim
parent
a7f1451a7f
commit
03edd935ea
|
|
@ -257,7 +257,7 @@ if __name__ == '__main__':
|
|||
# plot_segments('story_test_segments')
|
||||
# fix_csv('story_phrases')
|
||||
# pass
|
||||
create_segments_tfrecords('story_phrases', sample_count=1000)
|
||||
create_segments_tfrecords('story_phrases', sample_count=100)
|
||||
# record_generator,input_data,output_data,copy_read_consts = read_segments_tfrecords_generator('story_test')
|
||||
# tr_gen = record_generator()
|
||||
# for i in tr_gen:
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ def ctc_lambda_func(args):
|
|||
return K.ctc_batch_cost(labels, y_pred, input_length, label_length)
|
||||
|
||||
def segment_model(input_dim):
|
||||
input_dim = (100,100,1)
|
||||
# input_dim = (100,100,1)
|
||||
inp = Input(shape=input_dim)
|
||||
cnv1 = Conv2D(filters=32, kernel_size=(5,9))(inp)
|
||||
cnv2 = Conv2D(filters=1, kernel_size=(5,9))(cnv1)
|
||||
|
|
@ -55,7 +55,7 @@ def segment_model(input_dim):
|
|||
|
||||
def simple_segment_model(input_dim):
|
||||
# input_dim = (100,100)
|
||||
input_dim = (506,743)
|
||||
# input_dim = (506,743)
|
||||
inp = Input(shape=input_dim)
|
||||
b_gr1 = Bidirectional(GRU(256, return_sequences=True),merge_mode='sum')(inp)
|
||||
# b_gr1
|
||||
|
|
@ -77,7 +77,7 @@ def load_model_arch(mod_file):
|
|||
|
||||
def train_segment(collection_name = 'test'):
|
||||
# collection_name = 'story_test'
|
||||
batch_size = 128
|
||||
batch_size = 64
|
||||
# batch_size = 4
|
||||
model_dir = './models/segment/'+collection_name
|
||||
create_dir(model_dir)
|
||||
|
|
@ -133,4 +133,4 @@ def train_segment(collection_name = 'test'):
|
|||
|
||||
if __name__ == '__main__':
|
||||
# pass
|
||||
train_segment('test')
|
||||
train_segment('story_phrases')
|
||||
|
|
|
|||
Loading…
Reference in New Issue