1. 程式人生 > 其它 >tensorflow學習025——常見的預訓練網路模型及使用示例

tensorflow學習025——常見的預訓練網路模型及使用示例

在ImageNet上與訓練過的用於影象分類的模型:VGG16, VGG19, ResNet50, InceptionV3, InceptionResNetV2, Xception, MobileNet, MobileNetV2, DenseNet, NASNet

是在ImageNet上1000個分類的準確率
top1——你預測的label取最後概率向量裡面最大的那一個作為預測結果,如果你的預測結果中概率最大的那個分類正確,則預測正確,否則預測錯誤
top5——就是最後概率向量最大的前五名中,只要出現了正確預測,則為預測準確,否則預測錯誤。
從上面途中我們可以看出,VGG16 VGG19是比較落後的,訓練引數多但是準確率低。MobileNet MobileNetV2 僅僅十幾M,但是精度也不低,可以部署於手機上。
在ImageNet上預訓練的XceptionV1模型,在ImageNet上,該模型取得了驗證集top1 0.790和top 5 0.945的準確率。需要注意的是該模型只支援channels_last的維度順序(高度、寬度、通道),該模型預設輸入尺寸是299*299
其它訓練網路引數可通過網址(

https://keras.io/zh/applications)中檢視。
Xception網路訓練貓狗資料集程式碼
資料鏈接:https://pan.baidu.com/s/1-nLlW6Nng1pAvrxwEfHttA
提取碼:vt8p

點選檢視程式碼
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import glob
import os


train_image_path = glob.glob(r'E:\WORK\tensorflow\dataset\dc_2000\train\*\*.jpg')
print(len(train_image_path))
print(train_image_path[:5])
train_image_label = [int(p.split("\\")[-2] == 'cat') for p in train_image_path]

def load_preprosess_image(path,label):
    image = tf.io.read_file(path)
    image = tf.image.decode_jpeg(image,channels=3)
    image = tf.image.resize(image,[256,256])
    image = tf.cast(image,tf.float32)
    image = image / 255
    return image,label

train_image_ds = tf.data.Dataset.from_tensor_slices((train_image_path,train_image_label))
AUTOTUNE = tf.data.experimental.AUTOTUNE
train_image_ds = train_image_ds.map(load_preprosess_image,num_parallel_calls=AUTOTUNE)

BATCH_SZIE = 32
train_count = len(train_image_path)
train_image_ds = train_image_ds.shuffle(train_count).repeat().batch(BATCH_SZIE)

test_image_path = glob.glob(r'E:\WORK\tensorflow\dataset\dc_2000\test\*\*.jpg')
test_image_label = [int(p.split("\\")[-2] == 'cat') for p in test_image_path]
test_image_ds = tf.data.Dataset.from_tensor_slices((test_image_path,test_image_label))
test_image_ds = test_image_ds.map(load_preprosess_image,num_parallel_calls=AUTOTUNE)
test_image_ds = test_image_ds.repeat().batch(BATCH_SZIE)

test_count = len(test_image_path)

conv_base = tf.keras.applications.xception.Xception(weights='imagenet',
                                                    include_top=False,
                                                    input_shape=(256,256,3),
                                                    pooling='avg')
conv_base.trainable = False

model = tf.keras.Sequential()
model.add(conv_base)
model.add(tf.keras.layers.Dense(512,activation='relu'))
model.add(tf.keras.layers.Dense(1,activation='sigmoid'))

model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0005),
              loss='binary_crossentropy',
              metrics=['acc'])
initial_epoches = 5

histroy = model.fit(train_image_ds,
                    steps_per_epoch=train_count//BATCH_SZIE,
                    epochs=initial_epoches,
                    validation_data=test_image_ds,
                    validation_steps=test_count//BATCH_SZIE)

conv_base.trainable = True

for layer in conv_base.layers[:-33]:  # 原先弓133層
    layer.trainable = False

model.compile(loss='binary_crossentropy',
              optimizer=tf.keras.optimizers.Adam(learning_rate=0.0005/10),
              metrics=['accuracy'])

fine_tune_epoches = 5
total_epoches = initial_epoches + fine_tune_epoches

histroy = model.fit(train_image_ds,
                    steps_per_epoch=train_count//BATCH_SZIE,
                    epochs=total_epoches,
                    initial_epoch=initial_epoches,
                    validation_data=test_image_ds,
                    validation_steps=test_count//BATCH_SZIE)


作者:孫建釗
出處:http://www.cnblogs.com/sunjianzhao/
本文版權歸作者和部落格園共有,歡迎轉載,但未經作者同意必須保留此段宣告,且在文章頁面明顯位置給出原文連線,否則保留追究法律責任的權利。