Keras中的多輸入ImageDataGenerator圖片生成器
阿新 • 發佈:2019-02-16
- keras: 2.1.2
- tensorflow: 1.4.0
- python: 3.6
- OS: Windows 7
import os

import numpy as np
from keras import backend as K
from keras.preprocessing.image import ImageDataGenerator, Iterator
from keras.utils import np_utils
from PIL import Image
class ImageDataGenerator_Triplet(ImageDataGenerator):
    """``ImageDataGenerator`` whose ``flow`` yields triplet batches read from disk.

    The returned iterator receives this generator as its first argument and
    calls its ``random_transform`` / ``standardize`` on every image, so the
    usual constructor options control augmentation and normalization.
    """

    def flow(self, basepath, batch_size=32, class_num=731, input_size=299,
             train_vali_flag='train',
             shuffle=False, seed=None,
             save_to_dir=None, save_prefix='', save_format='png'):
        """Return a ``NumpyArrayIterator_Triplet`` over images under ``basepath``.

        Args:
            basepath: dataset root handed to the iterator.
            batch_size: number of triplets per batch.
            class_num: number of classes to sample from.
            input_size: side length every image is resized to.
            train_vali_flag: split to read, ``'train'`` or ``'test'``.
            shuffle: forwarded to the iterator.
            seed: forwarded to the iterator.
            save_to_dir, save_prefix, save_format: accepted but NOT forwarded;
                they currently have no effect.

        Returns:
            A new ``NumpyArrayIterator_Triplet`` bound to this generator.
        """
        options = {
            'class_num': class_num,
            'batch_size': batch_size,
            'input_size': input_size,
            'train_vali_flag': train_vali_flag,
            'basepath': basepath,
            'shuffle': shuffle,
            'seed': seed,
        }
        return NumpyArrayIterator_Triplet(self, **options)
class NumpyArrayIterator_Triplet(Iterator):
    """Keras ``Iterator`` that yields randomly sampled triplet batches.

    Images are read from
    ``<basepath>/<train_vali_flag>/<class_id>/<image_id>.bmp``. A batch built
    from ``3 * batch_size`` indices holds ``batch_size`` triplets stacked as
    three consecutive blocks along axis 0: anchors, then positives (same class
    as the anchor, different image), then negatives (a different class).

    The indices handed out by the index machinery are never used to choose
    samples -- only their count matters; every triplet is drawn at random.
    """

    # Image ids (file stems inside each class folder) usable for each split.
    _IMAGE_IDS = {'train': [1, 2, 4, 5], 'test': [3, 6]}

    def __init__(self, image_data_generator, class_num, input_size,
                 train_vali_flag, basepath,
                 batch_size=32, shuffle=False, seed=None, num_samples=8848):
        """Set up the iterator.

        Args:
            image_data_generator: object providing ``random_transform`` and
                ``standardize`` (the ``ImageDataGenerator`` that created us).
            class_num: number of classes; class folders are named
                ``0 .. class_num - 1``. Must be >= 2 so a negative class exists.
            input_size: every image is resized to ``input_size x input_size``.
            train_vali_flag: ``'train'`` or ``'test'``; also the sub-folder of
                ``basepath`` that images are read from.
            basepath: dataset root directory.
            batch_size: number of triplets per batch (the base ``Iterator``
                is given ``3 * batch_size``).
            shuffle: forwarded to the base ``Iterator``.
            seed: if not None, NumPy's global RNG is reseeded before each batch
                with ``seed + total_batches_seen`` so batches are reproducible.
            num_samples: passed as ``n`` to the base ``Iterator``. Since index
                values are ignored here, it should only influence how many
                batches the base class treats as one epoch.
        """
        self.image_data_generator = image_data_generator
        self.class_num = class_num
        self.input_size = input_size
        self.train_vali_flag = train_vali_flag
        # NOTE(review): stored but unused -- batches are always built
        # channels-last as (H, W, 3) regardless of this setting.
        self.data_format = K.image_data_format()
        self.basepath = basepath
        super(NumpyArrayIterator_Triplet, self).__init__(
            num_samples, batch_size * 3, shuffle, seed)

    def _sample_triplet_ids(self, image_ids):
        """Draw ``(pos_class, neg_class, anchor_id, other_id)`` for one triplet."""
        if self.class_num < 2:
            raise ValueError(
                'class_num must be at least 2 to sample a negative class')
        # Two distinct classes: the anchor/positive class and the negative one.
        pos_class, neg_class = np.random.choice(self.class_num, 2, replace=False)
        # Two distinct images so the positive is never the anchor file itself;
        # the negative reuses the positive's image id from another class.
        anchor_id, other_id = np.random.choice(image_ids, 2, replace=False)
        return pos_class, neg_class, anchor_id, other_id

    def _load_image(self, class_id, image_id):
        """Read one ``.bmp`` as RGB, resized to ``input_size x input_size``."""
        path = os.path.join(self.basepath, self.train_vali_flag,
                            str(class_id), str(image_id) + '.bmp')
        # ``with`` closes the underlying file handle; convert() forces three
        # channels so grayscale/palette BMPs fit the (H, W, 3) batch slot.
        with Image.open(path) as img:
            resized = img.convert('RGB').resize(
                (self.input_size, self.input_size))
        return np.array(resized)

    def _augment(self, x):
        """Apply the generator's random transform, then its standardization."""
        x = self.image_data_generator.random_transform(x.astype(K.floatx()))
        return self.image_data_generator.standardize(x)

    def _get_batches_of_transformed_samples(self, index_array):
        """Build one triplet batch.

        Args:
            index_array: indices from the index machinery; only its length is
                used. ``len(index_array) // 3`` triplets are generated; any
                leftover rows stay all-zero with label 0.

        Returns:
            ``(batch_x, [batch_y, batch_z])``: ``batch_x`` has shape
            ``(len(index_array), input_size, input_size, 3)``; ``batch_y`` is
            the ``to_categorical`` encoding of each row's class; ``batch_z`` is
            an all-zero ``(len(index_array), 1)`` array (presumably a dummy
            target for a triplet-loss output -- confirm with the model).

        Raises:
            ValueError: if ``train_vali_flag`` is not ``'train'``/``'test'``,
                or if ``class_num < 2``.
        """
        if self.train_vali_flag not in self._IMAGE_IDS:
            raise ValueError(
                'param train_vali_flag must be chosen from train and test')
        image_ids = self._IMAGE_IDS[self.train_vali_flag]

        num_rows = len(index_array)
        num_triplets = num_rows // 3
        batch_x = np.zeros((num_rows, self.input_size, self.input_size, 3),
                           dtype=K.floatx())
        batch_y = np.zeros([num_rows, 1])
        batch_z = np.zeros([num_rows, 1])

        for i in range(num_triplets):
            pos_class, neg_class, anchor_id, other_id = \
                self._sample_triplet_ids(image_ids)
            anchor_row = i
            pos_row = i + num_triplets
            neg_row = i + 2 * num_triplets

            batch_x[anchor_row] = self._augment(
                self._load_image(pos_class, anchor_id))
            batch_x[pos_row] = self._augment(
                self._load_image(pos_class, other_id))
            batch_x[neg_row] = self._augment(
                self._load_image(neg_class, other_id))

            batch_y[anchor_row] = pos_class
            batch_y[pos_row] = pos_class
            batch_y[neg_row] = neg_class

        batch_y = np_utils.to_categorical(batch_y, self.class_num)
        return batch_x, [batch_y, batch_z]

    def _maybe_reseed(self):
        """Reseed NumPy per batch when ``seed`` is set and count the batch."""
        if self.seed is not None:
            np.random.seed(self.seed + self.total_batches_seen)
        self.total_batches_seen += 1

    def _first_batch_indices(self):
        """Return the first ``self.batch_size`` indices, creating them if needed.

        Always returning the same full-size slice guarantees every batch has
        ``3 * batch_size`` rows (no short final batch); the values are unused.
        """
        if self.index_array is None:
            self._set_index_array()
        return self.index_array[0:self.batch_size]

    def next(self):
        """Return the next batch (Python 2.x-style iterator protocol).

        Only advancing the index generator happens under ``self.lock``;
        loading and augmenting images runs outside it so it can overlap
        across threads.
        """
        with self.lock:
            index_array = next(self.index_generator)
        return self._get_batches_of_transformed_samples(index_array)

    def __getitem__(self, idx):
        """Return a freshly sampled batch; ``idx`` is intentionally ignored."""
        self._maybe_reseed()
        return self._get_batches_of_transformed_samples(
            self._first_batch_indices())

    def _flow_index(self):
        """Yield the same full-size index slice forever."""
        # Ensure self.batch_index is 0.
        self.reset()
        while True:
            self._maybe_reseed()
            # Previously this read self.index_array while it was still None on a
            # fresh iterator, making next() fail with TypeError.
            yield self._first_batch_indices()