trying to overfit 2 samples with model -> doesn't seem to converge

master
Malar Kannan 2017-12-11 15:03:14 +05:30
parent 8d550c58cc
commit cc4fbe45b9
2 changed files with 5 additions and 5 deletions

View File

@ -258,7 +258,7 @@ if __name__ == '__main__':
# plot_segments('story_test_segments') # plot_segments('story_test_segments')
# fix_csv('story_words') # fix_csv('story_words')
# pass # pass
create_segments_tfrecords('story_words', sample_count=0) create_segments_tfrecords('story_words.3', sample_count=3,train_test_ratio=0.33)
# record_generator,input_data,output_data,copy_read_consts = read_segments_tfrecords_generator('story_test') # record_generator,input_data,output_data,copy_read_consts = read_segments_tfrecords_generator('story_test')
# tr_gen = record_generator() # tr_gen = record_generator()
# for i in tr_gen: # for i in tr_gen:

View File

@ -79,13 +79,13 @@ def load_model_arch(mod_file):
def train_segment(collection_name = 'test',resume_weights='',initial_epoch=0): def train_segment(collection_name = 'test',resume_weights='',initial_epoch=0):
# collection_name = 'story_test' # collection_name = 'story_test'
batch_size = 64 # batch_size = 32
# batch_size = 4 batch_size = 1
model_dir = './models/segment/'+collection_name model_dir = './models/segment/'+collection_name
create_dir(model_dir) create_dir(model_dir)
log_dir = './logs/segment/'+collection_name log_dir = './logs/segment/'+collection_name
create_dir(log_dir) create_dir(log_dir)
tr_gen_fn,te_x,te_y,copy_read_consts = read_segments_tfrecords_generator(collection_name,batch_size,2*batch_size) tr_gen_fn,te_x,te_y,copy_read_consts = read_segments_tfrecords_generator(collection_name,batch_size,batch_size)
tr_gen = tr_gen_fn() tr_gen = tr_gen_fn()
n_step,n_features,n_records = copy_read_consts(model_dir) n_step,n_features,n_records = copy_read_consts(model_dir)
input_dim = (n_step, n_features) input_dim = (n_step, n_features)
@ -137,4 +137,4 @@ def train_segment(collection_name = 'test',resume_weights='',initial_epoch=0):
if __name__ == '__main__': if __name__ == '__main__':
# pass # pass
train_segment('story_words')#,'./models/segment/story_phrases.1000/speech_segment_model-final.h5',1001) train_segment('story_words.3')#,'./models/segment/story_phrases.1000/speech_segment_model-final.h5',1001)