diff --git a/speech_data.py b/speech_data.py index 8b3cd84..e41cb88 100644 --- a/speech_data.py +++ b/speech_data.py @@ -8,7 +8,7 @@ import numpy as np from speech_spectrum import generate_aiff_spectrogram from sklearn.model_selection import train_test_split import itertools -import os +import os,shutil import random import csv import gc @@ -144,6 +144,9 @@ def read_siamese_tfrecords_generator(audio_group='audio',batch_size=32,test_size const_file = os.path.join('./outputs',audio_group+'.constants') (n_spec,n_features,n_records) = pickle.load(open(const_file,'rb')) + def copy_read_consts(dest_dir): + shutil.copy2(const_file,dest_dir) + return (n_spec,n_features,n_records) # @threadsafe_iter def record_generator(): print('reading tfrecords({}-train)...'.format(audio_group)) @@ -195,7 +198,7 @@ def read_siamese_tfrecords_generator(audio_group='audio',batch_size=32,test_size output = example.features.feature['output'].int64_list.value output_data[i] = np.asarray(output) - return record_generator,input_data,output_data,n_spec,n_features,n_records + return record_generator,input_data,output_data,copy_read_consts def audio_samples_word_count(audio_group='audio'): audio_samples = pd.read_csv( './outputs/' + audio_group + '.csv') @@ -249,7 +252,7 @@ if __name__ == '__main__': # create_spectrogram_tfrecords('story_all',sample_count=25) # fix_csv('story_words_test') #fix_csv('story_phrases') - create_spectrogram_tfrecords('story_phrases',sample_count=0,train_test_ratio=0.1) + create_spectrogram_tfrecords('story_phrases',sample_count=100,train_test_ratio=0.1) # create_spectrogram_tfrecords('audio',sample_count=50) # read_siamese_tfrecords_generator('audio') # padd_zeros_siamese_tfrecords('audio') diff --git a/speech_model.py b/speech_model.py index 5fe9b70..42a6fb3 100644 --- a/speech_model.py +++ b/speech_model.py @@ -74,7 +74,8 @@ def train_siamese(audio_group = 'audio'): create_dir(model_dir) log_dir = './logs/'+audio_group create_dir(log_dir) - 
tr_gen_fn,te_pairs,te_y,n_step,n_features,n_records = read_siamese_tfrecords_generator(audio_group,batch_size=batch_size,test_size=batch_size) + tr_gen_fn,te_pairs,te_y,copy_read_consts = read_siamese_tfrecords_generator(audio_group,batch_size=batch_size,test_size=batch_size) + n_step,n_features,n_records = copy_read_consts(model_dir) tr_gen = tr_gen_fn() input_dim = (n_step, n_features)