trying to overfit 2 samples with model -> doesn't seem to converge
parent
8d550c58cc
commit
cc4fbe45b9
|
|
@ -258,7 +258,7 @@ if __name__ == '__main__':
|
||||||
# plot_segments('story_test_segments')
|
# plot_segments('story_test_segments')
|
||||||
# fix_csv('story_words')
|
# fix_csv('story_words')
|
||||||
# pass
|
# pass
|
||||||
create_segments_tfrecords('story_words', sample_count=0)
|
create_segments_tfrecords('story_words.3', sample_count=3,train_test_ratio=0.33)
|
||||||
# record_generator,input_data,output_data,copy_read_consts = read_segments_tfrecords_generator('story_test')
|
# record_generator,input_data,output_data,copy_read_consts = read_segments_tfrecords_generator('story_test')
|
||||||
# tr_gen = record_generator()
|
# tr_gen = record_generator()
|
||||||
# for i in tr_gen:
|
# for i in tr_gen:
|
||||||
|
|
|
||||||
|
|
@ -79,13 +79,13 @@ def load_model_arch(mod_file):
|
||||||
|
|
||||||
def train_segment(collection_name = 'test',resume_weights='',initial_epoch=0):
|
def train_segment(collection_name = 'test',resume_weights='',initial_epoch=0):
|
||||||
# collection_name = 'story_test'
|
# collection_name = 'story_test'
|
||||||
batch_size = 64
|
# batch_size = 32
|
||||||
# batch_size = 4
|
batch_size = 1
|
||||||
model_dir = './models/segment/'+collection_name
|
model_dir = './models/segment/'+collection_name
|
||||||
create_dir(model_dir)
|
create_dir(model_dir)
|
||||||
log_dir = './logs/segment/'+collection_name
|
log_dir = './logs/segment/'+collection_name
|
||||||
create_dir(log_dir)
|
create_dir(log_dir)
|
||||||
tr_gen_fn,te_x,te_y,copy_read_consts = read_segments_tfrecords_generator(collection_name,batch_size,2*batch_size)
|
tr_gen_fn,te_x,te_y,copy_read_consts = read_segments_tfrecords_generator(collection_name,batch_size,batch_size)
|
||||||
tr_gen = tr_gen_fn()
|
tr_gen = tr_gen_fn()
|
||||||
n_step,n_features,n_records = copy_read_consts(model_dir)
|
n_step,n_features,n_records = copy_read_consts(model_dir)
|
||||||
input_dim = (n_step, n_features)
|
input_dim = (n_step, n_features)
|
||||||
|
|
@ -137,4 +137,4 @@ def train_segment(collection_name = 'test',resume_weights='',initial_epoch=0):
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
# pass
|
# pass
|
||||||
train_segment('story_words')#,'./models/segment/story_phrases.1000/speech_segment_model-final.h5',1001)
|
train_segment('story_words.3')#,'./models/segment/story_phrases.1000/speech_segment_model-final.h5',1001)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue