1. 程式人生 > Keras中的多輸入ImageDataGenerator圖片生成器

Keras中的多輸入ImageDataGenerator圖片生成器

  • keras:2.1.2
  • tensorflow:1.4.0
  • python3.6
  • win7
import os

import numpy as np
from keras import backend as K
from keras.preprocessing.image import ImageDataGenerator, Iterator
from keras.utils import np_utils
from PIL import Image
class ImageDataGenerator_Triplet(ImageDataGenerator):
    """ImageDataGenerator whose ``flow`` yields (anchor, positive, negative) batches.

    Only ``flow`` is overridden. The iterator it returns calls this
    generator's ``random_transform`` and ``standardize`` on every image, so
    augmentation/normalisation are configured as for a plain
    ``ImageDataGenerator``.
    """

    def flow(self, basepath, batch_size=32, class_num=731, input_size=299,
             train_vali_flag='train', shuffle=False, seed=None,
             save_to_dir=None, save_prefix='', save_format='png',
             num_samples=8848):
        """Return an iterator that samples triplets from an on-disk image tree.

        Args:
            basepath: Dataset root. Images are read from
                ``os.path.join(basepath, train_vali_flag, str(class_id),
                str(image_id) + '.bmp')``.
            batch_size: Number of triplets per batch; each batch therefore
                holds ``3 * batch_size`` images.
            class_num: Number of classes; class ids are ``0 .. class_num - 1``.
            input_size: Side length every image is resized to.
            train_vali_flag: ``'train'`` or ``'test'``; selects the
                sub-directory and which image ids are sampled per class.
            shuffle, seed: Forwarded to the Keras ``Iterator`` base class.
            save_to_dir, save_prefix, save_format: Accepted for signature
                compatibility with ``ImageDataGenerator.flow``; ignored.
            num_samples: Passed to the Keras ``Iterator`` as ``n`` (nominal
                number of samples per epoch). Defaults to the previously
                hard-coded 8848 — presumably the dataset size; confirm.

        Returns:
            A ``NumpyArrayIterator_Triplet``.

        Raises:
            ValueError: if ``train_vali_flag`` is not ``'train'``/``'test'``.
        """
        return NumpyArrayIterator_Triplet(
            self, class_num=class_num, batch_size=batch_size,
            input_size=input_size, train_vali_flag=train_vali_flag,
            basepath=basepath, shuffle=shuffle, seed=seed,
            num_samples=num_samples)


class NumpyArrayIterator_Triplet(Iterator):
    """Keras ``Iterator`` that builds randomly sampled triplet batches.

    A batch of ``3 * t`` images is laid out as ``t`` anchors, then ``t``
    positives (same class as their anchor), then ``t`` negatives (a
    different class). The index array supplied by Keras only fixes the
    batch length; the triplets themselves are drawn with ``np.random``.

    NOTE(review): ``seed`` is only forwarded to the base class; the sampling
    in ``_get_batches_of_transformed_samples`` uses the global ``np.random``
    state directly.
    """

    # Image ids (file stems) that may be sampled inside each class folder.
    _IMAGE_IDS = {'train': (1, 2, 4, 5), 'test': (3, 6)}

    def __init__(self, image_data_generator, class_num, input_size,
                 train_vali_flag, basepath, batch_size=32, shuffle=False,
                 seed=None, num_samples=8848):
        self.image_data_generator = image_data_generator
        self.class_num = class_num
        self.input_size = input_size
        self.train_vali_flag = train_vali_flag
        # NOTE(review): stored but unused; batches are always built
        # channels-last (N, H, W, 3).
        self.data_format = K.image_data_format()
        self.basepath = basepath
        # Fail fast on a bad split name rather than on the first batch.
        self._image_ids()
        # Keras counts images, and every triplet is 3 images.
        super(NumpyArrayIterator_Triplet, self).__init__(
            num_samples, batch_size * 3, shuffle, seed)

    def _image_ids(self):
        """Return the image ids sampled for the configured split.

        Raises:
            ValueError: if ``train_vali_flag`` is neither 'train' nor 'test'.
        """
        ids = self._IMAGE_IDS.get(self.train_vali_flag)
        if ids is None:
            raise ValueError(
                "param train_vali_flag must be chosen from 'train' and "
                "'test', got %r" % (self.train_vali_flag,))
        return ids

    def _load_image(self, class_id, image_id):
        """Read one BMP, resize it, then apply random_transform and standardize."""
        path = os.path.join(self.basepath, self.train_vali_flag,
                            str(class_id), str(image_id) + '.bmp')
        # Context manager closes the file handle once pixels are loaded.
        with Image.open(path) as img:
            # Force 3 channels so grayscale/palette BMPs fit the batch slot.
            resized = img.convert('RGB').resize(
                (self.input_size, self.input_size))
        x = np.array(resized).astype(K.floatx())
        x = self.image_data_generator.random_transform(x)
        return self.image_data_generator.standardize(x)

    def _get_batches_of_transformed_samples(self, index_array):
        """Build one batch of randomly sampled triplets.

        Args:
            index_array: Indices chosen by Keras. Only its length is used,
                rounded down to a whole number of triplets.

        Returns:
            ``(batch_x, [batch_y, batch_z])``:
            ``batch_x`` has shape ``(3 * t, input_size, input_size, 3)``;
            ``batch_y`` is the one-hot class label of each image
            (``class_num`` columns); ``batch_z`` is an all-zero ``(3 * t, 1)``
            array.

        Raises:
            ValueError: if ``class_num < 2`` (no distinct negative class
                exists) or the split name is invalid.
        """
        n_triplets = len(index_array) // 3
        n_images = 3 * n_triplets  # drop rows that cannot form a full triplet
        side = self.input_size
        batch_x = np.zeros((n_images, side, side, 3), dtype=K.floatx())
        labels = np.zeros([n_images, 1])
        batch_z = np.zeros([n_images, 1])
        image_ids = self._image_ids()
        for i in range(n_triplets):
            # Distinct classes: one for anchor/positive, one for the negative.
            pos_cls, neg_cls = np.random.choice(
                self.class_num, 2, replace=False)
            # Distinct ids so the positive is not the anchor's own file.
            anchor_id, other_id = np.random.choice(
                image_ids, 2, replace=False)
            batch_x[i] = self._load_image(pos_cls, anchor_id)
            batch_x[i + n_triplets] = self._load_image(pos_cls, other_id)
            batch_x[i + 2 * n_triplets] = self._load_image(neg_cls, other_id)
            labels[i] = pos_cls
            labels[i + n_triplets] = pos_cls
            labels[i + 2 * n_triplets] = neg_cls
        batch_y = np_utils.to_categorical(labels, self.class_num)
        return batch_x, [batch_y, batch_z]

    def next(self):
        """For python 2.x.

        # Returns
            The next batch.
        """
        # Keeps under lock only the mechanism which advances
        # the indexing of each batch.
        with self.lock:
            index_array = next(self.index_generator)
        # The transformation of images is not under thread lock
        # so it can be done in parallel
        return self._get_batches_of_transformed_samples(index_array)

    def __getitem__(self, idx):
        """Return a full-size random batch; ``idx`` is ignored by design."""
        if self.index_array is None:
            self._set_index_array()
        index_array = self.index_array[0: self.batch_size]
        return self._get_batches_of_transformed_samples(index_array)

    def _flow_index(self):
        """Yield the leading ``batch_size`` indices forever.

        Every batch is therefore full-sized. The indices only set the batch
        length, because the triplets themselves are sampled at random.
        """
        # Ensure self.batch_index is 0.
        self.reset()
        while True:
            # Same lazy initialisation as __getitem__, so the slice below
            # never reads from None.
            if self.index_array is None:
                self._set_index_array()
            yield self.index_array[0: self.batch_size]