slim 讀取並使用預訓練模型 inception_v3 遷移學習
轉自:https://blog.csdn.net/amanfromearth/article/details/79155926#commentBox
在使用Tensorflow做讀取並finetune的時候,發現在讀取官方給的inception_v3預訓練模型總是出現各種錯誤,現記錄其正確的讀取方式和各種錯誤做法:
關鍵程式碼如下:
import tensorflow as tf import tensorflow.contrib.slim as slim from tensorflow.contrib.slim.python.slim.nets import inception_v3 from research.slim.preprocessing import inception_preprocessing Pretrained_model_dir = '/Users/apple/tensorflow_model/models-master/research/slim/pre_train/inception_v3.ckpt' image_size = 299 # 讀取網路 with slim.arg_scope(inception_v3.inception_v3_arg_scope()): imgPath = 'test.jpg' testImage_string = tf.gfile.FastGFile(imgPath, 'rb').read() testImage = tf.image.decode_jpeg(testImage_string, channels=3) processed_image = inception_preprocessing.preprocess_image(testImage, image_size, image_size, is_training=False) processed_images = tf.expand_dims(processed_image, 0) logits, end_points = inception_v3.inception_v3(processed_images, num_classes=128, is_training=False) w1 = tf.Variable(tf.truncated_normal([128, 5], stddev=tf.sqrt(1/128))) b1 = tf.Variable(tf.zeros([5])) logits = tf.nn.leaky_relu(tf.matmul(logits, w1) + b1) with tf.Session() as sess: # 先初始化所有變數,避免有些變數未讀取而產生錯誤 init = tf.global_variables_initializer() sess.run(init) #載入預訓練模型 print('Loading model check point from {:s}'.format(Pretrained_model_dir)) #這裡的exclusions是不需要讀取預訓練模型中的Logits,因為預設的類別數目是1000,當你的類別數目不是1000的時候,如果還要讀取的話,就會報錯 exclusions = ['InceptionV3/Logits', 'InceptionV3/AuxLogits'] #建立一個列表,包含除了exclusions之外所有需要讀取的變數 inception_except_logits = slim.get_variables_to_restore(exclude=exclusions) #建立一個從預訓練模型checkpoint中讀取上述列表中的相應變數的引數的函式 init_fn = slim.assign_from_checkpoint_fn(Pretrained_model_dir, inception_except_logits,ignore_missing_vars=True) #執行該函式 init_fn(sess) print('Loaded.') out = sess.run(logits) print(out.shape) print(out)
其中可能會出現的錯誤如下:
錯誤1
- 1
- 2
- 3
原因:
預訓練模型中的類別數class_num=1000,這裡輸入的class_num=5,當讀取完整模型的時候當然會出錯。
解決方案:
選擇不讀取包含類別數的Logits層和AuxLogits層:
- 1
- 2
錯誤2
Tensor name “xxxx” not found in checkpoint files
- 1
- 2
- 3
- 4
這裡的Tensor name可以是所有inception_v3中變數的名字,出現這種情況的各種原因和解決方案是:
1.建立圖的時候沒有用arg_scope,是這樣建立的:
- 1
解決方案:
在這裡加上arg_scope,裡面呼叫的是庫中自帶的inception_v3_arg_scope
- 1
- 2
2.在讀取checkpoint的時候未初始化所有變數,即未執行
- 1
- 2
這樣會導致有一些checkpoint中不存在的變數未被初始化,比如使用Momentum時的每一層的Momentum引數等。
3.使用slim.assign_from_checkpoint_fn()
函式時,沒有新增ignore_missing_vars=True
屬性,由於預設ignore_missing_vars=False,所以,當使用非SGD的optimizer的時候(如Momentum、RMSProp等)時,會提示Momentum或者RMSProp的引數在checkpoint中無法找到,如:
使用Momentum時:
- 1
- 2
- 3
- 4
使用RMSProp時:
- 1
- 2
- 3
- 4
解決方法很簡單,就是把ignore_missing_vars=True
- 1
注意:一定要在之前的步驟都完成之後才能設成True,不然如果變數名稱全部出錯的話,會忽視掉checkpoint中所有的變數,從而不讀取任何引數。
以上就是我碰見的問題,希望有所幫助。