fixed input_dim
parent
a7f1451a7f
commit
03edd935ea
|
|
@ -257,7 +257,7 @@ if __name__ == '__main__':
|
||||||
# plot_segments('story_test_segments')
|
# plot_segments('story_test_segments')
|
||||||
# fix_csv('story_phrases')
|
# fix_csv('story_phrases')
|
||||||
# pass
|
# pass
|
||||||
create_segments_tfrecords('story_phrases', sample_count=1000)
|
create_segments_tfrecords('story_phrases', sample_count=100)
|
||||||
# record_generator,input_data,output_data,copy_read_consts = read_segments_tfrecords_generator('story_test')
|
# record_generator,input_data,output_data,copy_read_consts = read_segments_tfrecords_generator('story_test')
|
||||||
# tr_gen = record_generator()
|
# tr_gen = record_generator()
|
||||||
# for i in tr_gen:
|
# for i in tr_gen:
|
||||||
|
|
|
||||||
|
|
@ -36,7 +36,7 @@ def ctc_lambda_func(args):
|
||||||
return K.ctc_batch_cost(labels, y_pred, input_length, label_length)
|
return K.ctc_batch_cost(labels, y_pred, input_length, label_length)
|
||||||
|
|
||||||
def segment_model(input_dim):
|
def segment_model(input_dim):
|
||||||
input_dim = (100,100,1)
|
# input_dim = (100,100,1)
|
||||||
inp = Input(shape=input_dim)
|
inp = Input(shape=input_dim)
|
||||||
cnv1 = Conv2D(filters=32, kernel_size=(5,9))(inp)
|
cnv1 = Conv2D(filters=32, kernel_size=(5,9))(inp)
|
||||||
cnv2 = Conv2D(filters=1, kernel_size=(5,9))(cnv1)
|
cnv2 = Conv2D(filters=1, kernel_size=(5,9))(cnv1)
|
||||||
|
|
@ -55,7 +55,7 @@ def segment_model(input_dim):
|
||||||
|
|
||||||
def simple_segment_model(input_dim):
|
def simple_segment_model(input_dim):
|
||||||
# input_dim = (100,100)
|
# input_dim = (100,100)
|
||||||
input_dim = (506,743)
|
# input_dim = (506,743)
|
||||||
inp = Input(shape=input_dim)
|
inp = Input(shape=input_dim)
|
||||||
b_gr1 = Bidirectional(GRU(256, return_sequences=True),merge_mode='sum')(inp)
|
b_gr1 = Bidirectional(GRU(256, return_sequences=True),merge_mode='sum')(inp)
|
||||||
# b_gr1
|
# b_gr1
|
||||||
|
|
@ -77,7 +77,7 @@ def load_model_arch(mod_file):
|
||||||
|
|
||||||
def train_segment(collection_name = 'test'):
|
def train_segment(collection_name = 'test'):
|
||||||
# collection_name = 'story_test'
|
# collection_name = 'story_test'
|
||||||
batch_size = 128
|
batch_size = 64
|
||||||
# batch_size = 4
|
# batch_size = 4
|
||||||
model_dir = './models/segment/'+collection_name
|
model_dir = './models/segment/'+collection_name
|
||||||
create_dir(model_dir)
|
create_dir(model_dir)
|
||||||
|
|
@ -133,4 +133,4 @@ def train_segment(collection_name = 'test'):
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
# pass
|
# pass
|
||||||
train_segment('test')
|
train_segment('story_phrases')
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue