finding number of records by streaming one pass

master
Malar Kannan 2017-11-20 12:07:13 +05:30
parent 3ae8dc50a2
commit a5d4ede35d
2 changed files with 13 additions and 13 deletions

View File

@@ -143,15 +143,16 @@ def read_siamese_tfrecords_generator(audio_group='audio',batch_size=32,test_size
     output_class = []
     const_file = os.path.join('./outputs',audio_group+'.constants')
     (n_spec,n_features,n_records) = pickle.load(open(const_file,'rb'))
-    print('reading tfrecords({}-train)...'.format(audio_group))
     # @threadsafe_iter
     def record_generator():
+        print('reading tfrecords({}-train)...'.format(audio_group))
         input_data = []
         output_data = []
         while True:
-            record_iterator = tf.python_io.tf_record_iterator(path=records_file)
-            #tqdm(enumerate(record_iterator),total=n_records)
+            record_iterator,records_count = record_generator_count(records_file)
+            #tqdm(enumerate(record_iterator),total=records_count)
+            #enumerate(record_iterator)
             for (i,string_record) in enumerate(record_iterator):
                 example = tf.train.Example()
                 example.ParseFromString(string_record)
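The training generator above now asks record_generator_count (reworked later in this commit) for the record count before each pass instead of opening the iterator directly. A minimal sketch of that pattern against the TF 1.x API; the feature keys and the yield shape are illustrative assumptions, since the body of read_siamese_tfrecords_generator is only partly visible here.

import numpy as np
import tensorflow as tf

def batch_generator(records_file, batch_size=32):
    """Stream serialized tf.train.Example records and yield fixed-size batches."""
    while True:  # endless epochs, as in the while True loop above
        # project helper: one streaming pass for the count, then a fresh iterator
        record_iterator, records_count = record_generator_count(records_file)
        print('starting pass over {} records'.format(records_count))
        inputs, outputs = [], []
        for string_record in record_iterator:
            example = tf.train.Example()
            example.ParseFromString(string_record)
            # feature keys here are assumptions, not the project's real schema
            spec = np.array(example.features.feature['spectrogram'].float_list.value)
            label = example.features.feature['label'].int64_list.value[0]
            inputs.append(spec)
            outputs.append(label)
            if len(inputs) == batch_size:
                yield np.array(inputs), np.array(outputs)
                inputs, outputs = [], []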
@@ -173,11 +174,9 @@ def read_siamese_tfrecords_generator(audio_group='audio',batch_size=32,test_size
                     output_data = []
     # Read test in one-shot
-    te_records_file = os.path.join('./outputs',audio_group+'.test.tfrecords')
-    te_re_iterator = tf.python_io.tf_record_iterator(path=records_file)
-    te_n_records = len([i for i in te_re_iterator])
-    te_re_iterator = tf.python_io.tf_record_iterator(path=records_file)
     print('reading tfrecords({}-test)...'.format(audio_group))
+    te_records_file = os.path.join('./outputs',audio_group+'.test.tfrecords')
+    te_re_iterator,te_n_records = record_generator_count(records_file)
     test_size = min([test_size,te_n_records]) if test_size > 0 else te_n_records
     input_data = np.zeros((test_size,2,n_spec,n_features))
     output_data = np.zeros((test_size,2))
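The test split, by contrast, is read in one shot: the streamed count sizes the numpy buffers before any record is parsed. A small sketch of that setup, assuming the test tfrecords path is the intended argument (the hunk as committed still passes records_file, not te_records_file, to record_generator_count).

import os
import numpy as np

def allocate_test_buffers(audio_group, n_spec, n_features, test_size=0):
    """Count the test records once, then preallocate the evaluation arrays."""
    te_records_file = os.path.join('./outputs', audio_group + '.test.tfrecords')
    # project helper shown later in this diff
    te_re_iterator, te_n_records = record_generator_count(te_records_file)
    # cap at the requested test_size, or take everything when test_size == 0
    test_size = min(test_size, te_n_records) if test_size > 0 else te_n_records
    input_data = np.zeros((test_size, 2, n_spec, n_features))
    output_data = np.zeros((test_size, 2))
    return te_re_iterator, test_size, input_data, output_data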
@@ -204,7 +203,9 @@ def audio_samples_word_count(audio_group='audio'):
 def record_generator_count(records_file):
     record_iterator = tf.python_io.tf_record_iterator(path=records_file)
-    count = len([i for i in record_iterator])
+    count = 0
+    for i in record_iterator:
+        count+=1
     record_iterator = tf.python_io.tf_record_iterator(path=records_file)
     return record_iterator,count
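This is the heart of the commit: len([i for i in record_iterator]) builds a list holding every serialized record before measuring it, while the new loop streams the file and keeps only the running count, then reopens the iterator so the caller gets an unconsumed one. An equivalent standalone formulation against the TF 1.x API (later TensorFlow versions replace tf.python_io.tf_record_iterator with tf.data.TFRecordDataset):

import tensorflow as tf

def record_generator_count(records_file):
    # first pass: stream the file and count, without retaining the records
    count = sum(1 for _ in tf.python_io.tf_record_iterator(path=records_file))
    # the counting pass exhausts the iterator, so open a fresh one to return
    record_iterator = tf.python_io.tf_record_iterator(path=records_file)
    return record_iterator, count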
@@ -248,7 +249,7 @@ if __name__ == '__main__':
     # create_spectrogram_tfrecords('story_all',sample_count=25)
     # fix_csv('story_words_test')
     #fix_csv('story_phrases')
-    create_spectrogram_tfrecords('story_phrases',sample_count=10,train_test_ratio=0.1)
+    create_spectrogram_tfrecords('story_phrases',sample_count=0,train_test_ratio=0.1)
     # create_spectrogram_tfrecords('audio',sample_count=50)
     # read_siamese_tfrecords_generator('audio')
     # padd_zeros_siamese_tfrecords('audio')

View File

@@ -29,14 +29,13 @@ def test_with(audio_group):
 def evaluate_siamese(records_file,audio_group='audio',weights = 'siamese_speech_model-final.h5'):
     # audio_group='audio';model_file = 'siamese_speech_model-305-epoch-0.20-acc.h5'
     # records_file = os.path.join('./outputs',eval_group+'.train.tfrecords')
-    const_file = os.path.join('./outputs',audio_group+'.constants')
+    const_file = os.path.join('./models/'+audio_group+'/',audio_group+'.constants')
     arch_file='./models/'+audio_group+'/siamese_speech_model_arch.yaml'
     weight_file='./models/'+audio_group+'/'+weights
     (n_spec,n_features,n_records) = pickle.load(open(const_file,'rb'))
     print('evaluating {}...'.format(records_file))
     model = load_model_arch(arch_file)
     # model = siamese_model((n_spec, n_features))
-    n_spec = 422
     model.load_weights(weight_file)
     record_iterator,records_count = record_generator_count(records_file)
     total,same_success,diff_success,skipped,same_failed,diff_failed = 0,0,0,0,0,0
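evaluate_siamese now reads the constants pickle from the per-model directory instead of ./outputs, which also makes the hard-coded n_spec = 422 override unnecessary. A sketch of that asset layout; load_model_arch is the project's own helper, and model_from_yaml (available in older Keras releases) is an assumption about what it wraps.

import os
import pickle
from keras.models import model_from_yaml

def load_eval_assets(audio_group, weights='siamese_speech_model-final.h5'):
    """Load constants, architecture and weights from ./models/<audio_group>/."""
    model_dir = os.path.join('./models', audio_group)
    const_file = os.path.join(model_dir, audio_group + '.constants')
    (n_spec, n_features, n_records) = pickle.load(open(const_file, 'rb'))
    with open(os.path.join(model_dir, 'siamese_speech_model_arch.yaml')) as f:
        model = model_from_yaml(f.read())   # assumed equivalent of load_model_arch
    model.load_weights(os.path.join(model_dir, weights))
    return model, (n_spec, n_features, n_records)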
@@ -179,9 +178,9 @@ def visualize_results(audio_group='audio'):
 if __name__ == '__main__':
     # evaluate_siamese('./outputs/story_words_test.train.tfrecords',audio_group='story_words.gpu',weights ='siamese_speech_model-58-epoch-0.00-acc.h5')
     # evaluate_siamese('./outputs/story_words.test.tfrecords',audio_group='story_words',weights ='siamese_speech_model-675-epoch-0.00-acc.h5')
-    # evaluate_siamese('./outputs/story_phrases.test.tfrecords',audio_group='story_phrases',weights ='siamese_speech_model-329-epoch-0.00-acc.h5')
+    evaluate_siamese('./outputs/story_words_test.train.tfrecords',audio_group='story_phrases',weights ='siamese_speech_model-231-epoch-0.00-acc.h5')
     # play_results('story_words')
-    inspect_tfrecord('./outputs/story_phrases.test.tfrecords',audio_group='story_phrases')
+    #inspect_tfrecord('./outputs/story_phrases.test.tfrecords',audio_group='story_phrases')
     # visualize_results('story_words.gpu')
     # test_with('rand_edu')
     # sunflower_data,sunflower_result = get_word_pairs_data('sweater',15)