
Freezing BN layer parameters in PyTorch

Background: in a PyTorch model, I wanted to freeze the main branch's parameters and train only a sub-branch, but found that the same test data passed through the main branch produced different outputs in different epochs.

Cause: the running_mean and running_var of the main branch's BN layers were not frozen.

Solution: set the BN layers that need to be frozen to eval mode, as demonstrated below.

Problem example

Environment: torch 1.7.0

# -*- coding:utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.bn1 = nn.BatchNorm2d(6)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.bn2 = nn.BatchNorm2d(16)
        # an affine operation: y = Wx + b
        self.fc1 = nn.Linear(16 * 6 * 6, 120)  # 6*6 from image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 5)

    def forward(self, x):
        # Max pooling over a (2, 2) window
        x = F.max_pool2d(F.relu(self.bn1(self.conv1(x))), (2, 2))
        # If the size is a square you can only specify a single number
        x = F.max_pool2d(F.relu(self.bn2(self.conv2(x))), 2)
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features

def print_parameter_grad_info(net):
    print('-------parameters requires grad info--------')
    for name, p in net.named_parameters():
        print(f'{name}:\t{p.requires_grad}')

def print_net_state_dict(net):
    for key in net.state_dict():
        print(key)

if __name__ == "__main__":
    net = Net()

    print_parameter_grad_info(net)
    net.requires_grad_(False)
    print_parameter_grad_info(net)

    torch.random.manual_seed(5)
    test_data = torch.rand(1, 1, 32, 32)
    train_data = torch.rand(5, 1, 32, 32)

    # print(test_data)
    # print(train_data[0, ...])
    for epoch in range(2):
        # training phase; assume each epoch runs only one iteration
        net.train()
        pre = net(train_data)
        # compute the loss, update parameters, etc.
        # ....

        # test phase
        net.eval()
        x = net(test_data)
        print(f'epoch:{epoch}', x)

Output:

-------parameters requires grad info--------
conv1.weight:   True
conv1.bias:     True
bn1.weight:     True
bn1.bias:       True
conv2.weight:   True
conv2.bias:     True
bn2.weight:     True
bn2.bias:       True
fc1.weight:     True
fc1.bias:       True
fc2.weight:     True
fc2.bias:       True
fc3.weight:     True
fc3.bias:       True
-------parameters requires grad info--------
conv1.weight:   False
conv1.bias:     False
bn1.weight:     False
bn1.bias:       False
conv2.weight:   False
conv2.bias:     False
bn2.weight:     False
bn2.bias:       False
fc1.weight:     False
fc1.bias:       False
fc2.weight:     False
fc2.bias:       False
fc3.weight:     False
fc3.bias:       False
epoch:0 tensor([[-0.0755,  0.1138,  0.0966,  0.0564, -0.0224]])
epoch:1 tensor([[-0.0763,  0.1113,  0.0970,  0.0574, -0.0235]])

Two things can be seen:

net.requires_grad_(False) has set every parameter in the network to a state where no gradient update occurs, yet the same test data test_data produces different results when run forward in different epochs.

Calling print_net_state_dict shows that the BN statistics running_mean and running_var do not appear among the optimizable parameters returned by net.parameters():

bn1.weight
bn1.bias
bn1.running_mean
bn1.running_var
bn1.num_batches_tracked
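
The reason they do not appear is that running_mean and running_var are registered as buffers rather than parameters, so requires_grad_ never touches them. As a quick check (a sketch reusing the Net class defined above), named_parameters() and named_buffers() separate the two groups:

net = Net()

print('parameters (affected by requires_grad_):')
for name, _ in net.named_parameters():
    print(f'  {name}')

print('buffers (not affected by requires_grad_):')
for name, _ in net.named_buffers():
    print(f'  {name}')  # bn1.running_mean, bn1.running_var, bn1.num_batches_tracked, ...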

However, these two buffers are updated during the forward pass of the training phase, which is why the same test data produced different results even though the entire network was frozen.
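
This is easy to verify directly. A minimal check (again assuming the Net class above) shows the buffer changing after a single training-mode forward pass, even though every parameter is frozen:

net = Net()
net.requires_grad_(False)
net.train()

before = net.bn1.running_mean.clone()
net(torch.rand(5, 1, 32, 32))                      # one training-mode forward pass
print(torch.equal(before, net.bn1.running_mean))   # False: the buffer was updated anyway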

From the PyTorch BatchNorm2d documentation: "Also by default, during training this layer keeps running estimates of its computed mean and variance, which are then used for normalization during evaluation. The running estimates are kept with a default momentum of 0.1."
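
The update rule behind that sentence is running = (1 - momentum) * running + momentum * batch_stat, where the running variance uses the unbiased batch variance. A small sketch of one update step on a standalone BatchNorm2d (the shapes here are arbitrary):

bn = nn.BatchNorm2d(3)                    # momentum defaults to 0.1
x = torch.randn(4, 3, 8, 8)
bn.train()
bn(x)                                     # one training-mode forward pass

batch_mean = x.mean(dim=(0, 2, 3))
batch_var = x.var(dim=(0, 2, 3), unbiased=True)
# running_mean starts at 0 and running_var starts at 1
expected_mean = 0.9 * torch.zeros(3) + 0.1 * batch_mean
expected_var = 0.9 * torch.ones(3) + 0.1 * batch_var

print(torch.allclose(bn.running_mean, expected_mean, atol=1e-6))  # True
print(torch.allclose(bn.running_var, expected_var, atol=1e-6))    # True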

Therefore, explicitly set the BN layers to eval mode during the training phase:

if __name__ == "__main__":
    net = Net()
    net.requires_grad_(False)

    torch.random.manual_seed(5)
    test_data = torch.rand(1, 1, 32, 32)
    train_data = torch.rand(5, 1, 32, 32)

    # print(test_data)
    # print(train_data[0, ...])
    for epoch in range(2):
        # training phase; assume each epoch runs only one iteration
        net.train()
        net.bn1.eval()
        net.bn2.eval()
        pre = net(train_data)
        # compute the loss, update parameters, etc.
        # ....

        # test phase
        net.eval()
        x = net(test_data)
        print(f'epoch:{epoch}', x)

The results are now consistent across epochs:

epoch:0 tensor([[ 0.0944, -0.0372,  0.0059, -0.0625, -0.0048]])
epoch:1 tensor([[ 0.0944, -0.0372,  0.0059, -0.0625, -0.0048]])
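
Calling .eval() on each BN layer by name works for a small model but does not scale. A more general pattern (a sketch, not taken from the original fix) is to walk the module tree with apply() after switching to train mode:

def set_bn_eval(m):
    # covers BatchNorm1d/2d/3d; add nn.SyncBatchNorm here if your model uses it
    if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
        m.eval()

net.train()
net.apply(set_bn_eval)   # puts every BN layer back into eval mode

Note that apply() must run after net.train(), because train() recursively puts every submodule, including the BN layers, back into training mode.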
