1. 程式人生 > 其它 >tensorflow的ckpt檔案轉為npy檔案

tensorflow的ckpt檔案轉為npy檔案

技術標籤:tensorflow機器學習神經網路深度學習

我們遷移模型引數時,需要將ckpt檔案轉為npy檔案,以VGG19為例:

"""
將ckpt檔案轉化為npy檔案
"""

import numpy as np
import tensorflow as tf
from tensorflow.python import pywrap_tensorflow


def ckpt2npy():
    checkpoint_path = 'model/model.ckpt-4999'
    reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
    var_to_shape_map = reader.get_variable_to_shape_map()

    # 定義好ckpt模型的每一層
    vgg19 = {'conv1_1': [[], []], 'conv1_2': [[], []], 'conv2_1': [[], []], 'conv2_2': [[], []], 'conv3_1': [[], []],
             'conv3_2': [[], []], 'conv3_3': [[], []], 'conv3_4': [[], []], 'conv4_1': [[], []], 'conv4_2': [[], []],
             'conv4_3': [[], []], 'conv4_4': [[], []], 'conv5_1': [[], []], 'conv5_2': [[], []], 'conv5_3': [[], []],
             'conv5_4': [[], []], 'fc6': [[], []], 'fc7': [[], []], 'fc8': [[], []]}

    for key in var_to_shape_map:
        str_name = key
        print('tensor_name:', str_name)

        if str_name.find('/') > -1:
            names = str_name.split('/')
            # first layer name and weight, bias
            layer_name = names[0]
            layer_add_info = names[1]
        else:
            layer_name = str_name
            layer_add_info = None

        if layer_add_info == 'filter':
            vgg19[layer_name][0] = reader.get_tensor(key)
        if layer_add_info == 'weights':
            vgg19[layer_name][0] = reader.get_tensor(key)
        elif layer_add_info == 'bias':
            vgg19[layer_name][1] = reader.get_tensor(key)
        else:
            vgg19[layer_name] = reader.get_tensor(key)

    np.save('vgg19_2.npy', vgg19)

如果你的神經網路的引數每一層都是隨機初始化,沒有使用預訓練模型,直接用上面的方法就行。
然而,我們經常使用預訓練模型,例如VGG19,凍結前面的卷積層,只訓練全連線層。如果你使用了預訓練模型,且預訓練的模型檔案也是npy,那你生成的ckpt檔案裡面的卷積層引數是空的,因為你訓練時是直接載入的原npy檔案引數,如下圖:
在這裡插入圖片描述
這時,我們還需要將原npy檔案的卷積層引數,搬到新的npy檔案中,方法如下:

#  先載入兩個npy檔案
vgg19 = np.load('vgg19.npy', encoding='latin1').item()  # 別人預訓練的npy檔案
vgg19_2 = np.load('vgg19_2.npy', encoding='latin1').item()  # ckpt2npy()中儲存的npy檔案(沒有卷積層引數)
conv_layer = ['conv1_1', 'conv1_2', 'conv2_1', 'conv2_2', 'conv3_1', 'conv3_2', 'conv3_3', 'conv3_4',
              'conv4_1', 'conv4_2', 'conv4_3', 'conv4_4', 'conv5_1', 'conv5_2', 'conv5_3', 'conv5_4']

for con in conv_layer:
    vgg19_2[con][0] = vgg19[con][0]
    vgg19_2[con][1] = vgg19[con][1]

np.save('vgg19_3.npy', vgg19_2)  # 儲存新的npy檔案,有卷積層和全連線層引數