tensorflow的ckpt檔案轉為npy檔案
阿新 • • 發佈:2021-01-20
我們遷移模型引數時,需要將ckpt檔案轉為npy檔案,以VGG19為例:
""" 將ckpt檔案轉化為npy檔案 """ import numpy as np import tensorflow as tf from tensorflow.python import pywrap_tensorflow def ckpt2npy(): checkpoint_path = 'model/model.ckpt-4999' reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path) var_to_shape_map = reader.get_variable_to_shape_map() # 定義好ckpt模型的每一層 vgg19 = {'conv1_1': [[], []], 'conv1_2': [[], []], 'conv2_1': [[], []], 'conv2_2': [[], []], 'conv3_1': [[], []], 'conv3_2': [[], []], 'conv3_3': [[], []], 'conv3_4': [[], []], 'conv4_1': [[], []], 'conv4_2': [[], []], 'conv4_3': [[], []], 'conv4_4': [[], []], 'conv5_1': [[], []], 'conv5_2': [[], []], 'conv5_3': [[], []], 'conv5_4': [[], []], 'fc6': [[], []], 'fc7': [[], []], 'fc8': [[], []]} for key in var_to_shape_map: str_name = key print('tensor_name:', str_name) if str_name.find('/') > -1: names = str_name.split('/') # first layer name and weight, bias layer_name = names[0] layer_add_info = names[1] else: layer_name = str_name layer_add_info = None if layer_add_info == 'filter': vgg19[layer_name][0] = reader.get_tensor(key) if layer_add_info == 'weights': vgg19[layer_name][0] = reader.get_tensor(key) elif layer_add_info == 'bias': vgg19[layer_name][1] = reader.get_tensor(key) else: vgg19[layer_name] = reader.get_tensor(key) np.save('vgg19_2.npy', vgg19)
如果你的神經網路的引數每一層都是隨機初始化,沒有使用預訓練模型,直接用上面的方法就行。
然而,我們經常使用預訓練模型,例如VGG19,凍結前面的卷積層,只訓練全連線層。如果你使用了預訓練模型,且預訓練的模型檔案也是npy,那你生成的ckpt檔案裡面的卷積層引數是空的,因為你訓練時是直接載入的原npy檔案引數,如下圖:
這時,我們還需要將原npy檔案的卷積層引數,搬到新的npy檔案中,方法如下:
# 先載入兩個npy檔案 vgg19 = np.load('vgg19.npy', encoding='latin1').item() # 別人預訓練的npy檔案 vgg19_2 = np.load('vgg19_2.npy', encoding='latin1').item() # ckpt2npy()中儲存的npy檔案(沒有卷積層引數) conv_layer = ['conv1_1', 'conv1_2', 'conv2_1', 'conv2_2', 'conv3_1', 'conv3_2', 'conv3_3', 'conv3_4', 'conv4_1', 'conv4_2', 'conv4_3', 'conv4_4', 'conv5_1', 'conv5_2', 'conv5_3', 'conv5_4'] for con in conv_layer: vgg19_2[con][0] = vgg19[con][0] vgg19_2[con][1] = vgg19[con][1] np.save('vgg19_3.npy', vgg19_2) # 儲存新的npy檔案,有卷積層和全連線層引數