Compare commits

..

2 Commits

Author SHA1 Message Date
Malar Kannan c75ff4d109 failure visualization wip 2017-11-15 15:17:37 +05:30
Malar Kannan a9b244a50c the pair generation order is randomized 2017-11-15 14:43:39 +05:30
2 changed files with 17 additions and 3 deletions

View File

@ -19,8 +19,8 @@ from tqdm import tqdm
def siamese_pairs(rightGroup, wrongGroup):
group1 = [r for (i, r) in rightGroup.iterrows()]
group2 = [r for (i, r) in wrongGroup.iterrows()]
rightWrongPairs = [(g1, g2) for g2 in group2 for g1 in group1]
rightRightPairs = [i for i in itertools.combinations(group1, 2)]#+[i for i in itertools.combinations(group2, 2)]
rightWrongPairs = [(g1, g2) for g2 in group2 for g1 in group1]+[(g2, g1) for g2 in group2 for g1 in group1]
rightRightPairs = [i for i in itertools.permutations(group1, 2)]#+[i for i in itertools.combinations(group2, 2)]
# random.shuffle(rightWrongPairs)
# random.shuffle(rightRightPairs)
# return rightRightPairs[:10],rightWrongPairs[:10]

View File

@ -128,10 +128,24 @@ def play_results(audio_group='audio'):
break
close_player()
def visualize_results(audio_group='audio'):
# %matplotlib inline
audio_group = 'story_words'
result = pd.read_csv('./outputs/' + audio_group + '.results.csv',index_col=0)
result.groupby('success').size().plot(kind='bar')
result.describe(include=['object'])
failed = result[result['success'] == False]
same_failed = failed[failed['variant1'] == failed['variant2']]
diff_failed = failed[failed['variant1'] != failed['variant2']]
same_failed[same_failed['voice1'] != same_failed['voice2']]
diff_failed[diff_failed['voice1'] != diff_failed['voice2']]
if __name__ == '__main__':
evaluate_siamese('./outputs/story_words_test.train.tfrecords',audio_group='story_words',weights ='siamese_speech_model-712-epoch-0.00-acc.h5')
# evaluate_siamese('./outputs/story_words_test.train.tfrecords',audio_group='story_words',weights ='siamese_speech_model-712-epoch-0.00-acc.h5')
# evaluate_siamese('./outputs/story_words.test.tfrecords',audio_group='story_words',weights ='siamese_speech_model-675-epoch-0.00-acc.h5')
# play_results('story_words')
visualize_results('story_words')
# test_with('rand_edu')
# sunflower_data,sunflower_result = get_word_pairs_data('sweater',15)
# print(np.argmax(model.predict([sunflower_data[:, 0], sunflower_data[:, 1]]),axis=1))