【吳恩達Tensorflow 2.0實踐課】2.2 Transfer learning
阿新 • • 發佈:2021-01-17
技術標籤:TensorFlow卷積深度學習tensorflow
2.2.1 Transfer learning - the concepts & coding
遷移學習就是把已經訓練好的模型、引數,遷移至另外的一個新模型上使得我們不需要從零開始重新訓練一個新model。
使用 Image-net.org世界上影象識別最大的資料庫
已經訓練好的模型“快照”存放在這裡:https://storage.googleapis.com/mledu-datasets/inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5
1.下載模型"快照"
import ssl import urllib url = 'https://storage.googleapis.com/mledu-datasets/inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5' ssl._create_default_https_context = ssl._create_unverified_context urllib.request.urlopen(url) wget.download(url, out='inception_v3_weights_tf_dim_ordering_tf_kernels_notop')
可將引數載入到模型的骨架中,使之程式設計訓練好的模型。keras有內建的模型定義,指定不需要的權重層。
Inception-V3: 在頂部具有全連線的層。include_top設定為false,將忽略全連線層,直接進入卷積層。
這裡建議將檔名刪除後面的 '.h5' 不然老報錯
2.載入模型"快照"
local_weights_file = 'inception_v3_weights_tf_dim_ordering_tf_kernels_notop' pre_trained_model = InceptionV3(input_shape=(150, 150, 3), include_top=False, weights=None) # 遍歷所有層並鎖定它們 pre_trained_model.load_weights(local_weights_file) for layer in pre_trained_model.layers: layer.trainable = False # 列印預模型摘要 # pre_trained_model.summary()
最後一層已卷積到3x3
希望保留更多資訊,所以將最底層卷積到7x7
last_layer = pre_trained_model.get_layer('mixed7')
print('last layer output shape: ', last_layer.output_shape)
last_output = last_layer.output
3. 編譯
from tensorflow.keras.optimizers import RMSprop
# 將輸出層扁平化到1維
x = layers.Flatten()(last_output)
# 增加一層1024的全連線層
x = layers.Dense(1024, activation='relu')(x)
# 新增 dropout 值 0.2,意味著圖片中不重要的資訊/特徵就不參與計算
x = layers.Dropout(0.2)(x)
# 新增最後一層作為分類
x = layers.Dense (1, activation='sigmoid')(x)
model = Model(pre_trained_model.input, x)
model.compile( optimizer=RMSprop(lr=0.0001),
loss = 'binary_crossentropy',
metrics=['acc'])
4. 訓練貓狗資料集
# -------------------------------------------------------- #
# 4. 訓練、驗證資料集
# -------------------------------------------------------- #
# import ssl
# import urllib
# url = 'https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip'
# ssl._create_default_https_context = ssl._create_unverified_context
# urllib.request.urlopen(url)
# wget.download(url, out='./tmp/cats_and_dogs_filtered.zip')
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import os
import zipfile
local_zip = './tmp/cats_and_dogs_filtered.zip'
zip_ref = zipfile.ZipFile(local_zip, 'r')
zip_ref.extractall('./temp')
zip_ref.close()
# 定義目錄
base_dir = './tmp/cats_and_dogs_filtered'
train_dir = os.path.join(base_dir, 'train')
validation_dir = os.path.join(base_dir, 'validation')
train_cat_dir = os.path.join(train_dir, 'cats')
train_dog_dir = os.path.join(train_dir, 'dogs')
validation_cat_dir = os.path.join(validation_dir, 'cats')
validation_dog_dir = os.path.join(validation_dir, 'dogs')
train_cat_fnames = os.listdir(train_cat_dir)
train_dog_fnames = os.listdir(train_dog_dir)
# 訓練資料集圖片生成器
train_datagen = ImageDataGenerator(
rescale=1.0/255.,
rotation_range=40,
width_shift_range=0.2,
height_shift_range=0.2,
shear_range=0.2,
zoom_range=0.2,
horizontal_flip=True
)
test_datagen = ImageDataGenerator(rescale=1.0/255.)
train_generator = train_datagen.flow_from_directory(train_dir,
batch_size=20,
class_mode='binary',
target_size=(150, 150))
validation_generator = test_datagen.flow_from_directory(validation_dir,
batch_size=20,
class_mode='binary',
target_size=(150, 150))
# -------------------------------------------------------- #
# 5. 訓練
# -------------------------------------------------------- #
history = model.fit_generator(
train_generator,
validation_data=validation_generator,
steps_per_epoch=100,
epochs=20,
validation_steps=50,
verbose=2)
# -------------------------------------------------------- #
# 6. 顯示結果
# -------------------------------------------------------- #
import matplotlib.pyplot as plt
acc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs=range(len(acc))
plt.plot(epochs, acc, 'r', label='Training accruracy')
plt.plot(epochs, val_acc, 'b', label='Validation accuracy')
plt.title('Training and validation accuracy')
plt.figure()
plt.plot(epochs, loss, 'r', label='Training loss')
plt.plot(epochs, val_loss, 'b', label='Validation loss')
plt.title('Training and validation loss')
plt.figure()
plt.show()