implemented evaluation of test data with the model by overfitting on a smaller dataset

master
Malar Kannan 2017-11-14 17:54:44 +05:30
parent e4b8b4e0a7
commit 10b024866e
7 changed files with 190 additions and 121 deletions

View File

@@ -1,6 +1,5 @@
 import pandas as pd
-from speech_utils import apply_by_multiprocessing
-from speech_utils import threadsafe_iter
+from speech_tools import apply_by_multiprocessing,threadsafe_iter
 # import dask as dd
 # import dask.dataframe as ddf
 import tensorflow as tf
@@ -199,6 +198,12 @@ def audio_samples_word_count(audio_group='audio'):
     audio_samples = pd.read_csv( './outputs/' + audio_group + '.csv')
     return len(audio_samples.groupby(audio_samples['word']))

+def record_generator_count(records_file):
+    record_iterator = tf.python_io.tf_record_iterator(path=records_file)
+    count = len([i for i in record_iterator])
+    record_iterator = tf.python_io.tf_record_iterator(path=records_file)
+    return record_iterator,count
+
 def fix_csv(audio_group='audio'):
     audio_csv_lines = open('./outputs/' + audio_group + '.csv.orig','r').readlines()
     audio_csv_data = [i.strip().split(',') for i in audio_csv_lines]
@@ -237,7 +242,8 @@ if __name__ == '__main__':
     # pickle_constants('story_words')
     # create_spectrogram_tfrecords('audio',sample_count=100)
     # create_spectrogram_tfrecords('story_all',sample_count=25)
-    create_spectrogram_tfrecords('story_words',sample_count=10,train_test_ratio=0.2)
+    # fix_csv('story_words_test')
+    create_spectrogram_tfrecords('story_words_test',sample_count=100,train_test_ratio=0.0)
     # create_spectrogram_tfrecords('audio',sample_count=50)
     # read_siamese_tfrecords_generator('audio')
     # padd_zeros_siamese_tfrecords('audio')
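
The new record_generator_count helper works around the fact that tf.python_io.tf_record_iterator (TensorFlow 1.x API) is a one-shot generator: counting the records exhausts it, so a fresh iterator is created before returning. A minimal usage sketch, assuming the story_words_test TFRecord file generated above exists:

    from speech_data import record_generator_count

    record_iterator, count = record_generator_count('./outputs/story_words_test.train.tfrecords')
    print('records in file:', count)
    for raw_record in record_iterator:  # the fresh iterator, not the exhausted one
        pass                            # each raw_record is a serialized tf.train.Example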

View File

@@ -3,7 +3,7 @@ from __future__ import print_function
 import numpy as np
 # from speech_data import speech_model_data
 from speech_data import read_siamese_tfrecords_generator
-from keras.models import Model,load_model
+from keras.models import Model,load_model,model_from_yaml
 from keras.layers import Input, Dense, Dropout, LSTM, Lambda, Concatenate
 from keras.losses import categorical_crossentropy
 # from keras.losses import binary_crossentropy
@@ -12,7 +12,7 @@ from keras.utils import to_categorical
 from keras.optimizers import RMSprop
 from keras.callbacks import TensorBoard, ModelCheckpoint
 from keras import backend as K
-from speech_utils import create_dir
+from speech_tools import create_dir

 # def euclidean_distance(vects):
 #     x, y = vects
@@ -36,13 +36,13 @@ def create_base_rnn_network(input_dim):
     '''Base network to be shared (eq. to feature extraction).
     '''
     inp = Input(shape=input_dim)
-    ls0 = LSTM(512, return_sequences=True)(inp)
-    ls1 = LSTM(256, return_sequences=True)(ls0)
+    # ls0 = LSTM(512, return_sequences=True)(inp)
+    ls1 = LSTM(256, return_sequences=True)(inp)
     ls2 = LSTM(128, return_sequences=True)(ls1)
     # ls3 = LSTM(32, return_sequences=True)(ls2)
     ls4 = LSTM(64)(ls2)
-    d1 = Dense(128, activation='relu')(ls4)
-    d2 = Dense(64, activation='relu')(d1)
+    # d1 = Dense(128, activation='relu')(ls4)
+    d2 = Dense(64, activation='relu')(ls4)
     return Model(inp, ls4)

@@ -62,8 +62,8 @@ def dense_classifier(processed):
     conc_proc = Concatenate()(processed)
     d1 = Dense(64, activation='relu')(conc_proc)
     # dr1 = Dropout(0.1)(d1)
-    d2 = Dense(128, activation='relu')(d1)
-    d3 = Dense(8, activation='relu')(d2)
+    # d2 = Dense(128, activation='relu')(d1)
+    d3 = Dense(8, activation='relu')(d1)
     # dr2 = Dropout(0.1)(d2)
     return Dense(2, activation='softmax')(d3)

@@ -82,6 +82,16 @@ def siamese_model(input_dim):
     # model = Model([input_a, input_b], distance)
     return model

+def write_model_arch(mod,mod_file):
+    model_f = open(mod_file,'w')
+    model_f.write(mod.to_yaml())
+    model_f.close()
+
+def load_model_arch(mod_file):
+    model_f = open(mod_file,'r')
+    mod = model_from_yaml(model_f.read())
+    model_f.close()
+    return mod

 def train_siamese(audio_group = 'audio'):
     # the data, shuffled and split between train and test sets
@@ -91,7 +101,7 @@ def train_siamese(audio_group = 'audio'):
     create_dir(model_dir)
     log_dir = './logs/'+audio_group
     create_dir(log_dir)
-    tr_gen_fn,te_pairs,te_y,n_step,n_features,n_records = read_siamese_tfrecords_generator(audio_group,batch_size=batch_size)
+    tr_gen_fn,te_pairs,te_y,n_step,n_features,n_records = read_siamese_tfrecords_generator(audio_group,batch_size=batch_size,test_size=batch_size)
     tr_gen = tr_gen_fn()
     # tr_y = to_categorical(tr_y_e, num_classes=2)
     # te_y = to_categorical(te_y_e, num_classes=2)
@@ -123,6 +133,7 @@ def train_siamese(audio_group = 'audio'):
     # train
     rms = RMSprop()#lr=0.001
     model.compile(loss=categorical_crossentropy, optimizer=rms, metrics=[accuracy])
+    write_model_arch(model,model_dir+'/siamese_speech_model_arch.yaml')
     # model.fit(
     #     [tr_pairs[:, 0], tr_pairs[:, 1]],
     #     tr_y,
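
write_model_arch and load_model_arch persist the network topology as YAML, separate from the weight checkpoints, so the evaluation script can rebuild the model without importing the training code. A round-trip sketch under Keras 2.x (to_yaml and model_from_yaml exist there but were removed from much later Keras releases); the toy model and file names are illustrative only:

    from keras.models import Sequential, model_from_yaml
    from keras.layers import Dense

    model = Sequential([Dense(8, activation='relu', input_shape=(4,))])
    model.save_weights('toy_weights.h5')

    with open('toy_arch.yaml', 'w') as f:   # what write_model_arch does
        f.write(model.to_yaml())

    with open('toy_arch.yaml') as f:        # what load_model_arch does
        restored = model_from_yaml(f.read())
    restored.load_weights('toy_weights.h5') # weights come from ModelCheckpoint files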

View File

@@ -12,6 +12,7 @@ import time
 import progressbar
 from generate_similar import similar_phoneme_phrase,similar_phrase
+from speech_tools import format_filename

 OUTPUT_NAME = 'story_all'
 dest_dir = os.path.abspath('.') + '/outputs/' + OUTPUT_NAME + '/'
@@ -40,7 +41,10 @@ def create_dir(direc):
 def dest_filename(w, v, r, t):
-    return '{}-{}-{}-{}-{}.aiff'.format(w, v, r, t, str(random.randint(0, 10000)))
+    rand_no = str(random.randint(0, 10000))
+    fname = '{}-{}-{}-{}-{}.aiff'.format(w, v, r, t, rand_no)
+    sanitized = format_filename(fname)
+    return sanitized

 def dest_path(v, r, n):
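
dest_filename now routes the generated name through format_filename (added to the tools module later in this commit), which whitelists filename characters and replaces spaces with underscores. A sketch of the effect, inlining that whitelist logic, for a word containing an apostrophe:

    import random
    import string

    valid_chars = "-_.() %s%s" % (string.ascii_letters, string.digits)
    fname = '{}-{}-{}-{}-{}.aiff'.format("it's", 'Alex', 150, 'normal',
                                         str(random.randint(0, 10000)))
    sanitized = ''.join(c for c in fname if c in valid_chars).replace(' ', '_')
    print(sanitized)  # e.g. its-Alex-150-normal-4242.aiff (apostrophe dropped)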

View File

@@ -13,6 +13,8 @@ from pysndfile import sndio as snd
 from numpy.lib import stride_tricks

 """ short time fourier transform of audio signal """
+STFT_WINDOWS_MSEC = 20
+STFT_WINDOW_OVERLAP = 1.0 / 3
 def stft(sig, frameSize, overlapFac=0.5, window=np.hanning):
     win = window(frameSize)
@@ -74,7 +76,7 @@ def logscale_spec(spec, sr=44100, factor=20.):
 def generate_spec_frec(samples, samplerate):
     # samplerate, samples = wav.read(audiopath)
     # s = stft(samples, binsize)
-    s = stft(samples, samplerate * 150 // 1000, 1.0 / 3)
+    s = stft(samples, samplerate * STFT_WINDOWS_MSEC // 1000, STFT_WINDOW_OVERLAP)
     sshow, freq = logscale_spec(s, factor=1.0, sr=samplerate)
     ims = 20. * np.log10(np.abs(sshow) / 10e-6)
@@ -141,8 +143,11 @@ def play_sunflower():
 if __name__ == '__main__':
-    play_sunflower()
-    # plot_aiff_stft('./outputs/sunflowers-Alex-150-normal-589.aiff')
+    # play_sunflower()
+    plot_aiff_stft('./outputs/story_words/Agnes/150/chicken-Agnes-150-low-1077.aiff')
+    plot_aiff_stft('./outputs/story_words/Agnes/150/chicken-Agnes-150-medium-1762.aiff')
+    # spec = generate_aiff_spectrogram('./outputs/story_words/Agnes/150/chicken-Agnes-150-low-1077.aiff')
+    # print(spec.shape)
     # plot_aiff_stft('./outputs/sunflowers-Alex-180-normal-4763.aiff')
     # plot_aiff_stft('./outputs/sunflowers-Victoria-180-normal-870.aiff')
     # plot_aiff_stft('./outputs/sunflowers-Fred-180-phoneme-9733.aiff')
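
The hard-coded 150 in generate_spec_frec was a window length in milliseconds; lifting it into STFT_WINDOWS_MSEC makes that explicit and shrinks the analysis window from 150 ms to 20 ms. The frame size in samples follows from samplerate * msec // 1000, for example at 44.1 kHz (assumed here; 44100 is the sr default in logscale_spec):

    SAMPLE_RATE = 44100                     # assumed sample rate
    frame_old = SAMPLE_RATE * 150 // 1000   # 6615 samples per frame
    frame_new = SAMPLE_RATE * 20 // 1000    # 882 samples per frame
    print(frame_old, frame_new)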

View File

@@ -1,5 +1,6 @@
-# from speech_siamese import siamese_model
+from speech_model import load_model_arch
 from speech_tools import record_spectrogram, file_player
+from speech_data import record_generator_count
 # from importlib import reload
 # import speech_data
 # reload(speech_data)
@@ -9,6 +10,7 @@ import os
 import pickle
 import tensorflow as tf
 import csv
+from tqdm import tqdm
 from speech_data import padd_zeros

 def predict_recording_with(m,sample_size=15):
@@ -17,48 +19,40 @@ def predict_recording_with(m,sample_size=15):
     inp = create_test_pair(spec1,spec2,sample_size)
     return m.predict([inp[:, 0], inp[:, 1]])

-# while(True):
-#     print(predict_recording_with(model))
-
 def test_with(audio_group):
     X,Y = speech_data(audio_group)
     print(np.argmax(model.predict([X[:, 0], X[:, 1]]),axis=1))
     print(Y.astype(np.int8))

-def evaluate_siamese(audio_group='audio',model_file = 'siamese_speech_model-305-epoch-0.20-acc.h5'):
+def evaluate_siamese(records_file,audio_group='audio',weights = 'siamese_speech_model-final.h5'):
     # audio_group='audio';model_file = 'siamese_speech_model-305-epoch-0.20-acc.h5'
-    records_file = os.path.join('./outputs',audio_group+'.train.tfrecords')
+    # records_file = os.path.join('./outputs',eval_group+'.train.tfrecords')
     const_file = os.path.join('./outputs',audio_group+'.constants')
-    model_weights_path =os.path.join('./models/story_words/',model_file)
+    arch_file='./models/'+audio_group+'/siamese_speech_model_arch.yaml'
+    weight_file='./models/'+audio_group+'/'+weights
     (n_spec,n_features,n_records) = pickle.load(open(const_file,'rb'))
-    print('evaluating tfrecords({}-train)...'.format(audio_group))
-    model = siamese_model((n_spec, n_features))
-    model.load_weights(model_weights_path)
-    record_iterator = tf.python_io.tf_record_iterator(path=records_file)
-    #tqdm(enumerate(record_iterator),total=n_records)
-    result_csv = open('./outputs/' + audio_group + '.results.csv','w')
-    result_csv_w = csv.writer(result_csv, quoting=csv.QUOTE_MINIMAL)
-    result_csv_w.writerow(["phoneme1","phoneme2","voice1","voice2","rate1","rate2","variant1","variant2","file1","file2"])
-    for (i,string_record) in enumerate(record_iterator):
+    print('evaluating {}...'.format(records_file))
+    model = load_model_arch(arch_file)
+    # model = siamese_model((n_spec, n_features))
+    model.load_weights(weight_file)
+    record_iterator,records_count = record_generator_count(records_file)
+    total,same_success,diff_success,skipped,same_failed,diff_failed = 0,0,0,0,0,0
+    all_results = []
+    for (i,string_record) in tqdm(enumerate(record_iterator),total=records_count):
         # string_record = next(record_iterator)
+        total+=1
         example = tf.train.Example()
         example.ParseFromString(string_record)
         spec_n1 = example.features.feature['spec_n1'].int64_list.value[0]
         spec_n2 = example.features.feature['spec_n2'].int64_list.value[0]
+        if n_spec < spec_n1 or n_spec < spec_n2:
+            skipped+=1
+            continue
         spec_w1 = example.features.feature['spec_w1'].int64_list.value[0]
         spec_w2 = example.features.feature['spec_w2'].int64_list.value[0]
         spec1 = np.array(example.features.feature['spec1'].float_list.value).reshape(spec_n1,spec_w1)
         spec2 = np.array(example.features.feature['spec2'].float_list.value).reshape(spec_n2,spec_w2)
-        p_spec1,p_spec2 = padd_zeros(spec1,n_spec),padd_zeros(spec2,n_spec)
-        input_arr = np.asarray([[p_spec1,p_spec2]])
-        output_arr = np.asarray([example.features.feature['output'].int64_list.value])
-        y_pred = model.predict([input_arr[:, 0], input_arr[:, 1]])
-        predicted = np.asarray(y_pred[0]>0.5).astype(output_arr.dtype)
-        expected = output_arr[0]
-        if np.all(predicted == expected):
-            continue
         word = example.features.feature['word'].bytes_list.value[0].decode()
         phoneme1 = example.features.feature['phoneme1'].bytes_list.value[0].decode()
         phoneme2 = example.features.feature['phoneme2'].bytes_list.value[0].decode()
@@ -71,9 +65,41 @@ def evaluate_siamese(audio_group='audio',model_file = 'siamese_speech_model-305-epoch-0.20-acc.h5'):
         variant2 = example.features.feature['variant2'].bytes_list.value[0].decode()
         file1 = example.features.feature['file1'].bytes_list.value[0].decode()
         file2 = example.features.feature['file2'].bytes_list.value[0].decode()
-        print(phoneme1,phoneme2,voice1,voice2,rate1,rate2,variant1,variant2,file1,file2)
-        result_csv_w.writerow([phoneme1,phoneme2,voice1,voice2,rate1,rate2,variant1,variant2,file1,file2])
-    result_csv.close()
+        p_spec1,p_spec2 = padd_zeros(spec1,n_spec),padd_zeros(spec2,n_spec)
+        input_arr = np.asarray([[p_spec1,p_spec2]])
+        output_arr = np.asarray([example.features.feature['output'].int64_list.value])
+        y_pred = model.predict([input_arr[:, 0], input_arr[:, 1]])
+        predicted = np.asarray(y_pred[0]>0.5).astype(output_arr.dtype)
+        expected = output_arr[0]
+        status = np.all(predicted == expected)
+        result = {"phoneme1":phoneme1,"phoneme2":phoneme2,"voice1":voice1
+                 ,"voice2":voice2,"rate1":rate1,"rate2":rate2
+                 ,"variant1":variant1,"variant2":variant2,"file1":file1
+                 ,"file2":file2,"expected":expected[0],"predicted":y_pred[0][0]
+                 ,"success":status}
+        all_results.append(result)
+        if status:
+            if variant1 == variant2:
+                same_success+=1
+            else:
+                diff_success+=1
+            continue
+        else:
+            if variant1 == variant2:
+                same_failed+=1
+            else:
+                diff_failed+=1
+    print('total-{},same_success-{},diff_success-{},skipped-{},same_failed-{},diff_failed-{}'.format(total,same_success,diff_success,skipped,same_failed,diff_failed))
+    success = same_success+diff_success
+    failure = same_failed+diff_failed
+    print('accuracy-{:.3f}'.format(success*100/(success+failure)))
+    print('same_accuracy-{:.3f}'.format(same_success*100/(same_success+same_failed)))
+    print('diff_accuracy-{:.3f}'.format(diff_success*100/(diff_success+diff_failed)))
+    result_data = pd.DataFrame(all_results,columns=["phoneme1","phoneme2"
+        ,"voice1","voice2","rate1","rate2","variant1","variant2","file1","file2",
+        "expected","predicted","success"])
+    result_data.to_csv('./outputs/' + audio_group + '.results.csv')

 def play_results(audio_group='audio'):
@@ -102,8 +128,10 @@ def play_results(audio_group='audio'):
             break
     close_player()

-# evaluate_siamese('story_words',model_file='siamese_speech_model-305-epoch-0.20-acc.h5')
-play_results('story_words')
+if __name__ == '__main__':
+    evaluate_siamese('./outputs/story_words_test.train.tfrecords',audio_group='story_words',weights ='siamese_speech_model-712-epoch-0.00-acc.h5')
+    # evaluate_siamese('./outputs/story_words.test.tfrecords',audio_group='story_words',weights ='siamese_speech_model-675-epoch-0.00-acc.h5')
+    # play_results('story_words')
 # test_with('rand_edu')
 # sunflower_data,sunflower_result = get_word_pairs_data('sweater',15)
 # print(np.argmax(model.predict([sunflower_data[:, 0], sunflower_data[:, 1]]),axis=1))
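
evaluate_siamese now scores every record instead of logging only the failures: pairs whose spectrogram length exceeds n_spec are skipped, and successes and failures are tallied separately for same-variant and different-variant pairs before everything is dumped to CSV via pd.DataFrame (the hunks shown do not add an import pandas as pd, so presumably it already exists near the top of this file). A worked sketch of how the printed metrics follow from the counters, using hypothetical tallies:

    same_success, diff_success = 40, 35   # hypothetical tallies
    same_failed, diff_failed = 10, 15

    success = same_success + diff_success
    failure = same_failed + diff_failed
    print('accuracy-{:.3f}'.format(success * 100 / (success + failure)))                     # 75.000
    print('same_accuracy-{:.3f}'.format(same_success * 100 / (same_success + same_failed)))  # 80.000
    print('diff_accuracy-{:.3f}'.format(diff_success * 100 / (diff_success + diff_failed)))  # 70.000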

View File

@@ -1,6 +1,10 @@
+import os
+import threading
+import multiprocessing
+import pandas as pd
+import numpy as np
 import pyaudio
 from pysndfile import sndio as snd
-import numpy as np
 # from matplotlib import pyplot as plt

 from speech_spectrum import plot_stft, generate_spec_frec
@@ -61,3 +65,88 @@ def record_spectrogram(n_sec, plot=False, playback=False):
     p_oup.terminate()
     ims, _ = generate_spec_frec(one_channel, SAMPLE_RATE)
     return ims
+
+def _apply_df(args):
+    df, func, num, kwargs = args
+    return num, df.apply(func, **kwargs)
+
+def apply_by_multiprocessing(df,func,**kwargs):
+    cores = multiprocessing.cpu_count()
+    workers = kwargs.pop('workers') if 'workers' in kwargs else cores
+    pool = multiprocessing.Pool(processes=workers)
+    result = pool.map(_apply_df, [(d, func, i, kwargs) for i,d in enumerate(np.array_split(df, workers))])
+    pool.close()
+    result = sorted(result,key=lambda x:x[0])
+    return pd.concat([i[1] for i in result])
+
+def square(x):
+    return x**x
+
+if __name__ == '__main__':
+    df = pd.DataFrame({'a':range(10), 'b':range(10)})
+    apply_by_multiprocessing(df, square, axis=1, workers=4)
+
+def rm_rf(d):
+    for path in (os.path.join(d,f) for f in os.listdir(d)):
+        if os.path.isdir(path):
+            rm_rf(path)
+        else:
+            os.unlink(path)
+    os.rmdir(d)
+
+def create_dir(direc):
+    if not os.path.exists(direc):
+        os.makedirs(direc)
+    else:
+        rm_rf(direc)
+        create_dir(direc)
+
+#################### Now make the data generator threadsafe ####################
+class threadsafe_iter:
+    """Takes an iterator/generator and makes it thread-safe by
+    serializing call to the `next` method of given iterator/generator.
+    """
+    def __init__(self, it):
+        self.it = it
+        self.lock = threading.Lock()
+
+    def __iter__(self):
+        return self
+
+    def __next__(self): # Py3
+        with self.lock:
+            return next(self.it)
+
+    def next(self): # Py2
+        with self.lock:
+            return self.it.next()
+
+def threadsafe_generator(f):
+    """A decorator that takes a generator function and makes it thread-safe.
+    """
+    def g(*a, **kw):
+        return threadsafe_iter(f(*a, **kw))
+    return g
+
+def format_filename(s):
+    """
+    Take a string and return a valid filename constructed from the string.
+    Uses a whitelist approach: any characters not present in valid_chars are
+    removed. Also spaces are replaced with underscores.

+    Note: this method may produce invalid filenames such as ``, `.` or `..`
+    When I use this method I prepend a date string like '2009_01_15_19_46_32_'
+    and append a file extension like '.txt', so I avoid the potential of using
+    an invalid filename.
+    """
+    valid_chars = "-_.() %s%s" % (string.ascii_letters, string.digits)
+    filename = ''.join(c for c in s if c in valid_chars)
+    filename = filename.replace(' ','_') # I don't like spaces in filenames.
+    return filename
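
This file absorbs the helpers from speech_utils (deleted below) alongside the recording utilities: apply_by_multiprocessing splits a DataFrame across worker processes and reassembles the partial results in their original order, and threadsafe_iter/threadsafe_generator serialize access to a shared generator so Keras can pull batches from multiple threads. One caveat: format_filename references string.ascii_letters, but none of the hunks shown add an import string, so it will raise a NameError unless that import exists elsewhere in the file. A usage sketch for the decorator, with an illustrative stand-in generator:

    from speech_tools import threadsafe_generator

    @threadsafe_generator
    def batch_generator(batch_size):
        i = 0
        while True:                               # endless generator, as Keras expects
            yield list(range(i, i + batch_size))  # stand-in for a real (inputs, targets) batch
            i += batch_size

    gen = batch_generator(4)                      # safe to share across threads
    print(next(gen), next(gen))                   # [0, 1, 2, 3] [4, 5, 6, 7]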

View File

@@ -1,74 +0,0 @@
-import os
-import threading
-import multiprocessing
-import pandas as pd
-import numpy as np
-
-def _apply_df(args):
-    df, func, num, kwargs = args
-    return num, df.apply(func, **kwargs)
-
-def apply_by_multiprocessing(df,func,**kwargs):
-    cores = multiprocessing.cpu_count()
-    workers = kwargs.pop('workers') if 'workers' in kwargs else cores
-    pool = multiprocessing.Pool(processes=workers)
-    result = pool.map(_apply_df, [(d, func, i, kwargs) for i,d in enumerate(np.array_split(df, workers))])
-    pool.close()
-    result = sorted(result,key=lambda x:x[0])
-    return pd.concat([i[1] for i in result])
-
-def square(x):
-    return x**x
-
-if __name__ == '__main__':
-    df = pd.DataFrame({'a':range(10), 'b':range(10)})
-    apply_by_multiprocessing(df, square, axis=1, workers=4)
-
-def rm_rf(d):
-    for path in (os.path.join(d,f) for f in os.listdir(d)):
-        if os.path.isdir(path):
-            rm_rf(path)
-        else:
-            os.unlink(path)
-    os.rmdir(d)
-
-def create_dir(direc):
-    if not os.path.exists(direc):
-        os.makedirs(direc)
-    else:
-        rm_rf(direc)
-        create_dir(direc)
-
-#################### Now make the data generator threadsafe ####################
-class threadsafe_iter:
-    """Takes an iterator/generator and makes it thread-safe by
-    serializing call to the `next` method of given iterator/generator.
-    """
-    def __init__(self, it):
-        self.it = it
-        self.lock = threading.Lock()
-
-    def __iter__(self):
-        return self
-
-    def __next__(self): # Py3
-        with self.lock:
-            return next(self.it)
-
-    def next(self): # Py2
-        with self.lock:
-            return self.it.next()
-
-def threadsafe_generator(f):
-    """A decorator that takes a generator function and makes it thread-safe.
-    """
-    def g(*a, **kw):
-        return threadsafe_iter(f(*a, **kw))
-    return g