diff --git a/speech_data.py b/speech_data.py index a47627c..2860961 100644 --- a/speech_data.py +++ b/speech_data.py @@ -11,11 +11,12 @@ import os import random import csv import gc -import progressbar +# import progressbar +from tqdm import tqdm -def prog_bar(title): - widgets = [title, progressbar.Counter(), ' [', progressbar.Bar(), '] - ', progressbar.ETA()] - return progressbar.ProgressBar(widgets=widgets) +# def prog_bar(title): +# widgets = [title, progressbar.Counter(), ' [', progressbar.Bar(), '] - ', progressbar.ETA()] +# return progressbar.ProgressBar(widgets=widgets) def siamese_pairs(rightGroup, wrongGroup): group1 = [r for (i, r) in rightGroup.iterrows()] @@ -26,7 +27,7 @@ def siamese_pairs(rightGroup, wrongGroup): random.shuffle(rightRightPairs) # return (random.sample(same,10), random.sample(diff,10)) # return rightRightPairs[:10],rightWrongPairs[:10] - return rightRightPairs[:16],rightWrongPairs[:16] + return rightRightPairs[:32],rightWrongPairs[:32] # return rightRightPairs,rightWrongPairs def create_spectrogram_tfrecords(audio_group='audio'): @@ -55,16 +56,21 @@ def create_spectrogram_tfrecords(audio_group='audio'): return tf.train.Feature(bytes_list=tf.train.BytesList(value=value)) writer = tf.python_io.TFRecordWriter('./outputs/' + audio_group + '.tfrecords') - prog = prog_bar('Generating siamese pairs : ') - for (w, word_group) in prog(audio_samples.groupby(audio_samples['word'])): + prog = tqdm(audio_samples.groupby(audio_samples['word']),desc='Computing spectrogram') + for (w, word_group) in prog: + prog.set_postfix(word=w) g = word_group.reset_index() g['spectrogram'] = apply_by_multiprocessing(g['file_path'],generate_aiff_spectrogram) - sample_right = g.loc[audio_samples['variant'] == 'low'] - sample_wrong = g.loc[audio_samples['variant'] == 'medium'] + sample_right = g.loc[g['variant'] == 'low'] + sample_wrong = g.loc[g['variant'] == 'medium'] same, diff = siamese_pairs(sample_right, sample_wrong) groups = [([0,1],same),([1,0],diff)] for 
(output,group) in groups: - for sample1,sample2 in group: + group_prog = tqdm(group,desc='Writing Spectrogram') + for sample1,sample2 in group_prog: + group_prog.set_postfix(output=output + ,var1=sample1['variant'] + ,var2=sample2['variant']) spectro1,spectro2 = sample1['spectrogram'],sample2['spectrogram'] spec_n1,spec_n2 = spectro1.shape[0],spectro2.shape[0] spec_w1,spec_w2 = spectro1.shape[1],spectro2.shape[1] @@ -93,6 +99,7 @@ def create_spectrogram_tfrecords(audio_group='audio'): } )) writer.write(example.SerializeToString()) + prog.close() writer.close() def padd_zeros(spgr, max_samples): @@ -141,7 +148,7 @@ def read_siamese_tfrecords(audio_group='audio'): # if len(input_pairs) > 50: # break input_data,output_data = np.asarray(input_pairs),np.asarray(output_class) - # import pdb; pdb.set_trace() + # import pdb; pdb.set_trace() # tr_x1,te_x1,tr_x2,te_x2,tr_y,te_y = train_test_split(input1,input2,output_class) tr_pairs,te_pairs,tr_y,te_y = train_test_split(input_data,output_data) # return (tr_x1,te_x1,tr_x2,te_x2,tr_y,te_y) @@ -186,9 +193,9 @@ if __name__ == '__main__': # sunflower_pairs_data() # create_spectrogram_data() # create_spectrogram_data('story_words') - # create_spectrogram_tfrecords('story_words') - # create_spectrogram_tfrecords('story_all') - read_siamese_tfrecords('story_all') + create_spectrogram_tfrecords('story_words') + # create_spectrogram_tfrecords('story_words_test') + # read_siamese_tfrecords('story_all') # create_padded_spectrogram() # create_speech_pairs_data() # print(speech_model_data()) diff --git a/speech_siamese.py b/speech_siamese.py index e9ad718..353fa4b 100644 --- a/speech_siamese.py +++ b/speech_siamese.py @@ -114,11 +114,11 @@ def train_siamese(): rms = RMSprop(lr=0.001) model.compile(loss=categorical_crossentropy, optimizer=rms, metrics=[accuracy]) model.fit( - [tr_x1, tr_x2], + [tr_pairs[:, 0], tr_pairs[:, 1]], tr_y, batch_size=128, - epochs=50, - validation_data=([tr_pairs[:, 0], tr_pairs[:, 1]], te_y), + epochs=100, + 
validation_data=([te_pairs[:, 0], te_pairs[:, 1]], te_y), callbacks=[tb_cb, cp_cb]) model.save('./models/siamese_speech_model-final.h5')