From a9b244a50ce8a7913d27add027610b56c12d4269 Mon Sep 17 00:00:00 2001 From: Malar Kannan Date: Wed, 15 Nov 2017 14:43:39 +0530 Subject: [PATCH] the pair generation order is randomized --- speech_data.py | 4 ++-- speech_test.py | 8 +++++++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/speech_data.py b/speech_data.py index b7b5067..21f0466 100644 --- a/speech_data.py +++ b/speech_data.py @@ -19,8 +19,8 @@ from tqdm import tqdm def siamese_pairs(rightGroup, wrongGroup): group1 = [r for (i, r) in rightGroup.iterrows()] group2 = [r for (i, r) in wrongGroup.iterrows()] - rightWrongPairs = [(g1, g2) for g2 in group2 for g1 in group1] - rightRightPairs = [i for i in itertools.combinations(group1, 2)]#+[i for i in itertools.combinations(group2, 2)] + rightWrongPairs = [(g1, g2) for g2 in group2 for g1 in group1]+[(g2, g1) for g2 in group2 for g1 in group1] + rightRightPairs = [i for i in itertools.permutations(group1, 2)]#+[i for i in itertools.combinations(group2, 2)] # random.shuffle(rightWrongPairs) # random.shuffle(rightRightPairs) # return rightRightPairs[:10],rightWrongPairs[:10] diff --git a/speech_test.py b/speech_test.py index 1ee7789..58f07e7 100644 --- a/speech_test.py +++ b/speech_test.py @@ -128,10 +128,16 @@ def play_results(audio_group='audio'): break close_player() +def visualize_results(audio_group='audio'): + result = pd.read_csv('./outputs/' + audio_group + '.results.csv',index_col=0) + import pdb; pdb.set_trace() + + if __name__ == '__main__': - evaluate_siamese('./outputs/story_words_test.train.tfrecords',audio_group='story_words',weights ='siamese_speech_model-712-epoch-0.00-acc.h5') + # evaluate_siamese('./outputs/story_words_test.train.tfrecords',audio_group='story_words',weights ='siamese_speech_model-712-epoch-0.00-acc.h5') # evaluate_siamese('./outputs/story_words.test.tfrecords',audio_group='story_words',weights ='siamese_speech_model-675-epoch-0.00-acc.h5') # play_results('story_words') + visualize_results('story_words') # test_with('rand_edu') # sunflower_data,sunflower_result = get_word_pairs_data('sweater',15) # print(np.argmax(model.predict([sunflower_data[:, 0], sunflower_data[:, 1]]),axis=1))