code cleanup

master
Malar Kannan 2017-10-26 12:48:31 +05:30
parent 5824158af2
commit 49e6a46efd
2 changed files with 17 additions and 11 deletions

View File

@ -31,6 +31,7 @@ def create_test_pair(l, r, max_samples):
r_sample = append_zeros(r, max_samples) r_sample = append_zeros(r, max_samples)
return np.asarray([[l_sample, r_sample]]) return np.asarray([[l_sample, r_sample]])
def create_X(sp, max_samples): def create_X(sp, max_samples):
return create_pair(sp[0]['spectrogram'], sp[1]['spectrogram'], max_samples) return create_pair(sp[0]['spectrogram'], sp[1]['spectrogram'], max_samples)
@ -106,8 +107,8 @@ def create_speech_pairs_data(audio_group='audio'):
def speech_model_data(): def speech_model_data():
tr_pairs = np.load('outputs/tr_pairs.npy') / 255.0 tr_pairs = np.load('outputs/tr_pairs.npy') / 255.0
te_pairs = np.load('outputs/te_pairs.npy') / 255.0 te_pairs = np.load('outputs/te_pairs.npy') / 255.0
tr_pairs[tr_pairs < 0] = 0 # tr_pairs[tr_pairs < 0] = 0
te_pairs[te_pairs < 0] = 0 # te_pairs[te_pairs < 0] = 0
tr_y = np.load('outputs/tr_y.npy') tr_y = np.load('outputs/tr_y.npy')
te_y = np.load('outputs/te_y.npy') te_y = np.load('outputs/te_y.npy')
return tr_pairs, te_pairs, tr_y, te_y return tr_pairs, te_pairs, tr_y, te_y

View File

@ -6,12 +6,17 @@ reload(speech_data)
from speech_data import create_test_pair,get_word_pairs_data from speech_data import create_test_pair,get_word_pairs_data
import numpy as np import numpy as np
sunflower_data,sunflower_result = get_word_pairs_data('sunflowers',15)
sunflower_result
model = siamese_model((15, 1654)) model = siamese_model((15, 1654))
model.load_weights('./models/siamese_speech_model-final.h5') model.load_weights('./models/siamese_speech_model-final.h5')
def predict_recording_with(m,sample_size=15):
spec1 = record_spectrogram(n_sec=1.4) spec1 = record_spectrogram(n_sec=1.4)
spec2 = record_spectrogram(n_sec=1.4) spec2 = record_spectrogram(n_sec=1.4)
inp = create_test_pair(spec1,spec2,16) inp = create_test_pair(spec1,spec2,sample_size)
model.predict([inp[:, 0], inp[:, 1]]) return m.predict([inp[:, 0], inp[:, 1]])
predict_recording_with(model)
sunflower_data,sunflower_result = get_word_pairs_data('sunflowers',15)
sunflower_result
model.predict([sunflower_data[:, 0], sunflower_data[:, 1]]) < 0.5 model.predict([sunflower_data[:, 0], sunflower_data[:, 1]]) < 0.5