TensorFlow 2.0——影像增強的訓練相關程式碼
阿新 • 發佈:2020-12-17
# Train a ResNet-18 on CIFAR-10 with on-the-fly image augmentation.
#
# Names used here but defined outside this snippet:
#   tf        -- the TensorFlow module (`import tensorflow as tf`)
#   resnet18  -- model factory; presumably returns an uncompiled tf.keras.Model
#   opt       -- optimizer passed to compile()
#   epochs    -- number of training epochs
#   bat       -- batch size

NUM_CLASSES = 10  # CIFAR-10 has 10 classes

(x, y), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

# Scale pixels to [0, 1]. Casting to float32 *before* dividing avoids the
# float64 temporary that `x / 255.` creates on a uint8 array. Keeping the data
# as NumPy arrays avoids a round trip through tf.Tensor, since
# ImageDataGenerator converts its inputs to NumPy anyway.
x = x.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0

# load_data() returns integer labels of shape (N, 1). to_categorical drops that
# trailing singleton axis and yields one-hot float32 arrays of shape
# (N, NUM_CLASSES). Unlike a bare tf.squeeze, it never collapses the batch axis
# when N == 1.
y = tf.keras.utils.to_categorical(y, NUM_CLASSES)
y_test = tf.keras.utils.to_categorical(y_test, NUM_CLASSES)

datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    featurewise_center=False,             # set the dataset-wide input mean to 0
    samplewise_center=False,              # set each sample's mean to 0
    featurewise_std_normalization=False,  # divide inputs by the dataset std
    samplewise_std_normalization=False,   # divide each input by its own std
    zca_whitening=False,                  # apply ZCA whitening
    zca_epsilon=1e-06,                    # epsilon for ZCA whitening
    rotation_range=0,                     # random rotation range, degrees (0-180)
    width_shift_range=0.1,                # random horizontal shift
    height_shift_range=0.1,               # random vertical shift
    shear_range=0.0,                      # random shear range
    zoom_range=0.0,                       # random zoom range
    channel_shift_range=0.0,              # random channel-shift range
    fill_mode="nearest",                  # fill for points outside the boundaries
    cval=0.0,                             # fill value when fill_mode="constant"
    horizontal_flip=True,                 # randomly flip images horizontally
    vertical_flip=False,                  # randomly flip images vertically
    rescale=None,                         # rescale factor, applied before others
    preprocessing_function=None,          # function applied to each input
    data_format=None,                     # "channels_first" / "channels_last"
    validation_split=0.0,                 # fraction reserved for validation
)

# Only the featurewise_* / zca_whitening options need dataset statistics. With
# all of them disabled above there is nothing to compute. Keeping the call
# means enabling one of those flags works without further changes.
datagen.fit(x)

res_model = resnet18()
# NOTE(review): from_logits=True assumes resnet18() outputs raw logits (no
# final softmax) -- confirm; if the model ends in softmax, use from_logits=False.
res_model.compile(
    optimizer=opt,
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)

# Model.fit consumes the augmenting iterator directly. Model.fit_generator is
# deprecated since TF 2.1 and no longer exists in Keras 3.
res_model.fit(
    datagen.flow(x, y, batch_size=bat),
    validation_data=(x_test, y_test),
    epochs=epochs,
    validation_freq=1,
)