parent
b8a9f87031
commit
7cbfebbf1a
|
|
@ -11,11 +11,12 @@ import os
|
||||||
import random
|
import random
|
||||||
import csv
|
import csv
|
||||||
import gc
|
import gc
|
||||||
import progressbar
|
# import progressbar
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
def prog_bar(title):
|
# def prog_bar(title):
|
||||||
widgets = [title, progressbar.Counter(), ' [', progressbar.Bar(), '] - ', progressbar.ETA()]
|
# widgets = [title, progressbar.Counter(), ' [', progressbar.Bar(), '] - ', progressbar.ETA()]
|
||||||
return progressbar.ProgressBar(widgets=widgets)
|
# return progressbar.ProgressBar(widgets=widgets)
|
||||||
|
|
||||||
def siamese_pairs(rightGroup, wrongGroup):
|
def siamese_pairs(rightGroup, wrongGroup):
|
||||||
group1 = [r for (i, r) in rightGroup.iterrows()]
|
group1 = [r for (i, r) in rightGroup.iterrows()]
|
||||||
|
|
@ -26,7 +27,7 @@ def siamese_pairs(rightGroup, wrongGroup):
|
||||||
random.shuffle(rightRightPairs)
|
random.shuffle(rightRightPairs)
|
||||||
# return (random.sample(same,10), random.sample(diff,10))
|
# return (random.sample(same,10), random.sample(diff,10))
|
||||||
# return rightRightPairs[:10],rightWrongPairs[:10]
|
# return rightRightPairs[:10],rightWrongPairs[:10]
|
||||||
return rightRightPairs[:16],rightWrongPairs[:16]
|
return rightRightPairs[:32],rightWrongPairs[:32]
|
||||||
# return rightRightPairs,rightWrongPairs
|
# return rightRightPairs,rightWrongPairs
|
||||||
|
|
||||||
def create_spectrogram_tfrecords(audio_group='audio'):
|
def create_spectrogram_tfrecords(audio_group='audio'):
|
||||||
|
|
@ -55,16 +56,21 @@ def create_spectrogram_tfrecords(audio_group='audio'):
|
||||||
return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))
|
return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))
|
||||||
|
|
||||||
writer = tf.python_io.TFRecordWriter('./outputs/' + audio_group + '.tfrecords')
|
writer = tf.python_io.TFRecordWriter('./outputs/' + audio_group + '.tfrecords')
|
||||||
prog = prog_bar('Generating siamese pairs : ')
|
prog = tqdm(audio_samples.groupby(audio_samples['word']),desc='Computing spectrogram')
|
||||||
for (w, word_group) in prog(audio_samples.groupby(audio_samples['word'])):
|
for (w, word_group) in prog:
|
||||||
|
prog.set_postfix(word=w)
|
||||||
g = word_group.reset_index()
|
g = word_group.reset_index()
|
||||||
g['spectrogram'] = apply_by_multiprocessing(g['file_path'],generate_aiff_spectrogram)
|
g['spectrogram'] = apply_by_multiprocessing(g['file_path'],generate_aiff_spectrogram)
|
||||||
sample_right = g.loc[audio_samples['variant'] == 'low']
|
sample_right = g.loc[g['variant'] == 'low']
|
||||||
sample_wrong = g.loc[audio_samples['variant'] == 'medium']
|
sample_wrong = g.loc[g['variant'] == 'medium']
|
||||||
same, diff = siamese_pairs(sample_right, sample_wrong)
|
same, diff = siamese_pairs(sample_right, sample_wrong)
|
||||||
groups = [([0,1],same),([1,0],diff)]
|
groups = [([0,1],same),([1,0],diff)]
|
||||||
for (output,group) in groups:
|
for (output,group) in groups:
|
||||||
for sample1,sample2 in group:
|
group_prog = tqdm(group,desc='Writing Spectrogram')
|
||||||
|
for sample1,sample2 in group_prog:
|
||||||
|
group_prog.set_postfix(output=output
|
||||||
|
,var1=sample1['variant']
|
||||||
|
,var2=sample2['variant'])
|
||||||
spectro1,spectro2 = sample1['spectrogram'],sample2['spectrogram']
|
spectro1,spectro2 = sample1['spectrogram'],sample2['spectrogram']
|
||||||
spec_n1,spec_n2 = spectro1.shape[0],spectro2.shape[0]
|
spec_n1,spec_n2 = spectro1.shape[0],spectro2.shape[0]
|
||||||
spec_w1,spec_w2 = spectro1.shape[1],spectro2.shape[1]
|
spec_w1,spec_w2 = spectro1.shape[1],spectro2.shape[1]
|
||||||
|
|
@ -93,6 +99,7 @@ def create_spectrogram_tfrecords(audio_group='audio'):
|
||||||
}
|
}
|
||||||
))
|
))
|
||||||
writer.write(example.SerializeToString())
|
writer.write(example.SerializeToString())
|
||||||
|
prog.close()
|
||||||
writer.close()
|
writer.close()
|
||||||
|
|
||||||
def padd_zeros(spgr, max_samples):
|
def padd_zeros(spgr, max_samples):
|
||||||
|
|
@ -141,7 +148,7 @@ def read_siamese_tfrecords(audio_group='audio'):
|
||||||
# if len(input_pairs) > 50:
|
# if len(input_pairs) > 50:
|
||||||
# break
|
# break
|
||||||
input_data,output_data = np.asarray(input_pairs),np.asarray(output_class)
|
input_data,output_data = np.asarray(input_pairs),np.asarray(output_class)
|
||||||
# import pdb; pdb.set_trace()
|
import pdb; pdb.set_trace()
|
||||||
# tr_x1,te_x1,tr_x2,te_x2,tr_y,te_y = train_test_split(input1,input2,output_class)
|
# tr_x1,te_x1,tr_x2,te_x2,tr_y,te_y = train_test_split(input1,input2,output_class)
|
||||||
tr_pairs,te_pairs,tr_y,te_y = train_test_split(input_data,output_data)
|
tr_pairs,te_pairs,tr_y,te_y = train_test_split(input_data,output_data)
|
||||||
# return (tr_x1,te_x1,tr_x2,te_x2,tr_y,te_y)
|
# return (tr_x1,te_x1,tr_x2,te_x2,tr_y,te_y)
|
||||||
|
|
@ -186,9 +193,9 @@ if __name__ == '__main__':
|
||||||
# sunflower_pairs_data()
|
# sunflower_pairs_data()
|
||||||
# create_spectrogram_data()
|
# create_spectrogram_data()
|
||||||
# create_spectrogram_data('story_words')
|
# create_spectrogram_data('story_words')
|
||||||
# create_spectrogram_tfrecords('story_words')
|
create_spectrogram_tfrecords('story_words')
|
||||||
# create_spectrogram_tfrecords('story_all')
|
# create_spectrogram_tfrecords('story_words_test')
|
||||||
read_siamese_tfrecords('story_all')
|
# read_siamese_tfrecords('story_all')
|
||||||
# create_padded_spectrogram()
|
# create_padded_spectrogram()
|
||||||
# create_speech_pairs_data()
|
# create_speech_pairs_data()
|
||||||
# print(speech_model_data())
|
# print(speech_model_data())
|
||||||
|
|
|
||||||
|
|
@ -114,11 +114,11 @@ def train_siamese():
|
||||||
rms = RMSprop(lr=0.001)
|
rms = RMSprop(lr=0.001)
|
||||||
model.compile(loss=categorical_crossentropy, optimizer=rms, metrics=[accuracy])
|
model.compile(loss=categorical_crossentropy, optimizer=rms, metrics=[accuracy])
|
||||||
model.fit(
|
model.fit(
|
||||||
[tr_x1, tr_x2],
|
[tr_pairs[:, 0], tr_pairs[:, 1]],
|
||||||
tr_y,
|
tr_y,
|
||||||
batch_size=128,
|
batch_size=128,
|
||||||
epochs=50,
|
epochs=100,
|
||||||
validation_data=([tr_pairs[:, 0], tr_pairs[:, 1]], te_y),
|
validation_data=([te_pairs[:, 0], te_pairs[:, 1]], te_y),
|
||||||
callbacks=[tb_cb, cp_cb])
|
callbacks=[tb_cb, cp_cb])
|
||||||
|
|
||||||
model.save('./models/siamese_speech_model-final.h5')
|
model.save('./models/siamese_speech_model-final.h5')
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue