PyTorch學習(十七)--- 模型load各種問題解決
簡單的模型load
一般來說,儲存模型時是先用 model.cpu().state_dict() 取出全部引數,再用 torch.save 存檔;載入模型時一般用 model.load_state_dict(torch.load(model_path))。值得注意的是:torch.load 返回的是一個 OrderedDict。
import torch
import torch.nn as nn
class Net_old(nn.Module):
    """Three 3x3 convolutions with in-place ReLUs held in one ``nn.Sequential``.

    Because every layer lives inside ``self.nets``, the state_dict keys are
    ``nets.0.*``, ``nets.2.*`` and ``nets.4.*``; indices 1 and 3 are the
    ReLU layers, which contribute no entries.
    """

    def __init__(self):
        super(Net_old, self).__init__()
        self.nets = nn.Sequential(
            torch.nn.Conv2d(1, 2, 3),
            torch.nn.ReLU(True),
            torch.nn.Conv2d(2, 1, 3),
            torch.nn.ReLU(True),
            torch.nn.Conv2d(1, 1, 3)
        )

    def forward(self, x):
        """Return ``x`` passed through the sequential stack."""
        return self.nets(x)
class Net_new(nn.Module):
    """The same conv/ReLU stack, but with every layer stored as its own attribute.

    The state_dict keys therefore become ``conv1.*``, ``conv2.*`` and
    ``conv3.*`` instead of the ``nets.N.*`` keys of a Sequential container.
    """

    def __init__(self):
        # Must name this class: super(Net_old, self) raises TypeError here
        # because a Net_new instance is not an instance of Net_old.
        super(Net_new, self).__init__()
        self.conv1 = torch.nn.Conv2d(1, 2, 3)
        self.r1 = torch.nn.ReLU(True)
        self.conv2 = torch.nn.Conv2d(2, 1, 3)
        self.r2 = torch.nn.ReLU(True)
        self.conv3 = torch.nn.Conv2d(1, 1, 3)

    def forward(self, x):
        """Apply conv1 -> ReLU -> conv2 -> ReLU -> conv3 and return the result."""
        x = self.conv1(x)
        x = self.r1(x)
        x = self.conv2(x)
        x = self.r2(x)
        x = self.conv3(x)
        return x
# Save the Sequential-based network's parameters, then reload and inspect them.
network = Net_old()
torch.save(network.cpu().state_dict(), 't.pth')
pretrained_net = torch.load('t.pth')
print(pretrained_net)
# Iterating a dict yields its keys, so enumerate() gives (index, key name):
# here ``key`` is the position and ``v`` is the parameter name, not the tensor.
for key, v in enumerate(pretrained_net):
    print('%s %s' % (key, v))
可以看到
OrderedDict([('nets.0.weight',
(0 ,0 ,.,.) =
-0.2436 0.2523 0.3097
-0.0315 -0.1307 0.0759
0.0750 0.1894 -0.0761
(1 ,0 ,.,.) =
0.0280 -0.2178 0.0914
0.3227 -0.0121 -0.0016
-0.0654 -0.0584 -0.1655
[torch.FloatTensor of size 2x1x3x3]
), ('nets.0.bias',
-0.0507
-0.2836
[torch.FloatTensor of size 2]
), ('nets.2.weight',
(0 ,0 ,.,.) =
-0.2233 0.0279 -0.0511
-0.0242 -0.1240 -0.0511
0.2266 0.1385 -0.1070
(0 ,1 ,.,.) =
-0.0943 -0.1403 0.0979
-0.2163 0.1906 -0.2269
-0.1984 0.0843 -0.0719
[torch.FloatTensor of size 1x2x3x3]
), ('nets.2.bias',
-0.1420
[torch.FloatTensor of size 1]
), ('nets.4.weight',
(0 ,0 ,.,.) =
0.1981 -0.0250 0.2429
0.3012 0.2428 -0.0114
0.2878 -0.2134 0.1173
[torch.FloatTensor of size 1x1x3x3]
), ('nets.4.bias',
1.00000e-02 *
-5.8426
[torch.FloatTensor of size 1]
)])
0 nets.0.weight
1 nets.0.bias
2 nets.2.weight
3 nets.2.bias
4 nets.4.weight
5 nets.4.bias
說明 .state_dict() 只是把模型的所有引數以 OrderedDict 的形式存下來。通過
# enumerate() over the dict yields (index, key name) pairs.
for key, v in enumerate(pretrained_net):
    print('%s %s' % (key, v))
可以得知這些引數的順序!當然,若要看具體的值,可以用:
# .items() yields (key name, tensor) pairs, so this prints the actual values.
for key, v in pretrained_net.items():
    print('%s %s' % (key, v))
nets.0.weight
(0 ,0 ,.,.) =
-0.2444 -0.3148 0.1626
0.2531 -0.0859 -0.0236
0.1635 0.1113 -0.1110
(1 ,0 ,.,.) =
0.2374 -0.2931 -0.1806
-0.1456 0.2264 -0.0114
0.1813 0.1134 -0.2095
[torch.FloatTensor of size 2x1x3x3]
nets.0.bias
-0.3087
-0.2407
[torch.FloatTensor of size 2]
nets.2.weight
(0 ,0 ,.,.) =
-0.2206 -0.1151 -0.0783
0.0723 -0.2008 0.0568
-0.0964 -0.1505 -0.1203
(0 ,1 ,.,.) =
0.0131 0.1329 -0.1763
0.1276 -0.2025 -0.0075
-0.1167 -0.1833 0.1103
[torch.FloatTensor of size 1x2x3x3]
nets.2.bias
-0.1858
[torch.FloatTensor of size 1]
nets.4.weight
(0 ,0 ,.,.) =
-0.1019 0.0534 0.2018
-0.0600 -0.1389 -0.0275
0.0696 0.0360 0.1560
[torch.FloatTensor of size 1x1x3x3]
nets.4.bias
1.00000e-03 *
-5.6003
[torch.FloatTensor of size 1]
如果哪一天我們需要重新寫這個網路,比如使用 Net_new,這個網路是將每一層都作為類的一個屬性。如果直接 load:
import torch
import torch.nn as nn
class Net_old(nn.Module):
    """Three 3x3 convolutions with in-place ReLUs held in one ``nn.Sequential``.

    All parameters are registered under the ``nets`` attribute, so the
    state_dict keys look like ``nets.0.weight``, ``nets.2.bias``, etc.
    """

    def __init__(self):
        super(Net_old, self).__init__()
        self.nets = nn.Sequential(
            torch.nn.Conv2d(1, 2, 3),
            torch.nn.ReLU(True),
            torch.nn.Conv2d(2, 1, 3),
            torch.nn.ReLU(True),
            torch.nn.Conv2d(1, 1, 3)
        )

    def forward(self, x):
        """Return ``x`` passed through the sequential stack."""
        return self.nets(x)
class Net_new(nn.Module):
    """Conv/ReLU stack with every layer stored as a separate attribute.

    The state_dict keys are ``conv1.*``, ``conv2.*`` and ``conv3.*``.
    """

    def __init__(self):
        super(Net_new, self).__init__()
        self.conv1 = torch.nn.Conv2d(1, 2, 3)
        self.r1 = torch.nn.ReLU(True)
        self.conv2 = torch.nn.Conv2d(2, 1, 3)
        self.r2 = torch.nn.ReLU(True)
        self.conv3 = torch.nn.Conv2d(1, 1, 3)

    def forward(self, x):
        """Apply conv1 -> ReLU -> conv2 -> ReLU -> conv3 and return the result."""
        x = self.conv1(x)
        x = self.r1(x)
        x = self.conv2(x)
        x = self.r2(x)
        x = self.conv3(x)
        return x
network = Net_old()
torch.save(network.cpu().state_dict(), 't.pth')
pretrained_net = torch.load('t.pth')
# Show keys of pretrained model
for key in pretrained_net:
    print(key)
# Define new network, and directly load the state_dict.
# This call is expected to fail: the checkpoint keys are 'nets.N.*' while
# Net_new only has 'convN.*' keys, so strict loading rejects them.
new_network = Net_new()
new_network.load_state_dict(pretrained_net)
會出現unexpected key
nets.0.weight
nets.0.bias
nets.2.weight
nets.2.bias
nets.4.weight
nets.4.bias
Traceback (most recent call last):
File "Blog.py", line 44, in <module>
new_network.load_state_dict(pretrained_net)
File "/home/vis/xxx/anaconda2/lib/python2.7/site-packages/torch/nn/modules/module.py", line 522, in load_state_dict
.format(name))
KeyError: 'unexpected key "nets.0.weight" in state_dict'
這是因為,我們新的網路,都是“屬性形式的”,檢視新網路的state_dict
conv1.weight
conv1.bias
conv2.weight
conv2.bias
conv3.weight
conv3.bias
strict=False載入模型的正確解讀
你可能會決定使用 strict=False 來載入:
import torch
import torch.nn as nn
class Net_old(nn.Module):
    """Three 3x3 convolutions with in-place ReLUs held in one ``nn.Sequential``.

    All parameters are registered under the ``nets`` attribute, so the
    state_dict keys look like ``nets.0.weight``, ``nets.2.bias``, etc.
    """

    def __init__(self):
        super(Net_old, self).__init__()
        self.nets = nn.Sequential(
            torch.nn.Conv2d(1, 2, 3),
            torch.nn.ReLU(True),
            torch.nn.Conv2d(2, 1, 3),
            torch.nn.ReLU(True),
            torch.nn.Conv2d(1, 1, 3)
        )

    def forward(self, x):
        """Return ``x`` passed through the sequential stack."""
        return self.nets(x)
class Net_new(nn.Module):
    """Conv/ReLU stack with every layer stored as a separate attribute.

    The state_dict keys are ``conv1.*``, ``conv2.*`` and ``conv3.*``.
    """

    def __init__(self):
        super(Net_new, self).__init__()
        self.conv1 = torch.nn.Conv2d(1, 2, 3)
        self.r1 = torch.nn.ReLU(True)
        self.conv2 = torch.nn.Conv2d(2, 1, 3)
        self.r2 = torch.nn.ReLU(True)
        self.conv3 = torch.nn.Conv2d(1, 1, 3)

    def forward(self, x):
        """Apply conv1 -> ReLU -> conv2 -> ReLU -> conv3 and return the result."""
        x = self.conv1(x)
        x = self.r1(x)
        x = self.conv2(x)
        x = self.r2(x)
        x = self.conv3(x)
        return x
old_network = Net_old()
torch.save(old_network.cpu().state_dict(), 't.pth')
pretrained_net = torch.load('t.pth')
# Show keys of pretrained model
for key in pretrained_net:
    print(key)
print('****Before loading********')
new_network = Net_new()
print(torch.sum(old_network.nets[0].weight.data))
print(torch.sum(new_network.conv1.weight.data))
for key in new_network.state_dict():
    print(key)
print('-----After loading------')
# No checkpoint key ('nets.N.*') matches any key of Net_new ('convN.*'),
# so with strict=False nothing at all is copied.
new_network.load_state_dict(pretrained_net, strict=False)
# So you think that this two values are the same?? Hah!
print(torch.sum(old_network.nets[0].weight.data))
print(torch.sum(new_network.conv1.weight.data))
for key in new_network.state_dict():
    print(key)
輸出
nets.0.weight
nets.0.bias
nets.2.weight
nets.2.bias
nets.4.weight
nets.4.bias
****Before loading********
-0.882688805461
0.34207585454
conv1.weight
conv1.bias
conv2.weight
conv2.bias
conv3.weight
conv3.bias
-----After loading------
-0.882688805461
0.34207585454
conv1.weight
conv1.bias
conv2.weight
conv2.bias
conv3.weight
conv3.bias
數值一點變化都沒有,說明“strict=False”沒有那麼智慧!它直接忽略那些不匹配的 key:有相同的 key 就複製,沒有就直接放棄賦值!下面在 Net_new 中也加入一個 nets 屬性來驗證:
import torch
import torch.nn as nn
class Net_old(nn.Module):
    """Three 3x3 convolutions with in-place ReLUs held in one ``nn.Sequential``.

    All parameters are registered under the ``nets`` attribute, so the
    state_dict keys look like ``nets.0.weight``, ``nets.2.bias``, etc.
    """

    def __init__(self):
        super(Net_old, self).__init__()
        self.nets = nn.Sequential(
            torch.nn.Conv2d(1, 2, 3),
            torch.nn.ReLU(True),
            torch.nn.Conv2d(2, 1, 3),
            torch.nn.ReLU(True),
            torch.nn.Conv2d(1, 1, 3)
        )

    def forward(self, x):
        """Return ``x`` passed through the sequential stack."""
        return self.nets(x)
class Net_new(nn.Module):
    """Attribute-style network that additionally carries a ``nets`` Sequential.

    The extra ``nets`` attribute gives this model the keys ``nets.0.weight``
    and ``nets.0.bias``, which coincide with two keys of a Sequential-based
    checkpoint; the remaining keys are ``conv1.*``, ``conv2.*``, ``conv3.*``.
    """

    def __init__(self):
        super(Net_new, self).__init__()
        self.conv1 = torch.nn.Conv2d(1, 2, 3)
        self.r1 = torch.nn.ReLU(True)
        self.conv2 = torch.nn.Conv2d(2, 1, 3)
        self.r2 = torch.nn.ReLU(True)
        # forward() uses conv3, so it has to be registered here; without it
        # forward() fails with AttributeError.
        self.conv3 = torch.nn.Conv2d(1, 1, 3)
        # Net_new also gets a 'nets' attribute whose first layer has the same
        # shape as the first layer of the Sequential-based network.
        self.nets = nn.Sequential(
            torch.nn.Conv2d(1, 2, 3)
        )

    def forward(self, x):
        """Apply conv1 -> ReLU -> conv2 -> ReLU -> conv3 -> nets and return the result."""
        x = self.conv1(x)
        x = self.r1(x)
        x = self.conv2(x)
        x = self.r2(x)
        x = self.conv3(x)
        x = self.nets(x)
        return x
old_network = Net_old()
torch.save(old_network.cpu().state_dict(), 't.pth')
pretrained_net = torch.load('t.pth')
# Show keys of pretrained model
for key in pretrained_net:
    print(key)
print('****Before loading********')
new_network = Net_new()
print(torch.sum(old_network.nets[0].weight.data))
print(torch.sum(new_network.conv1.weight.data))
print(torch.sum(new_network.nets[0].weight.data))
for key in new_network.state_dict():
    print(key)
print('-----After loading------')
# Only 'nets.0.weight' / 'nets.0.bias' exist in both models, so only those
# two tensors are copied; every other checkpoint key is ignored.
new_network.load_state_dict(pretrained_net, strict=False)
print(torch.sum(old_network.nets[0].weight.data))
print(torch.sum(new_network.conv1.weight.data))
# Hopefully, this value equals to 'old_network.nets[0].weight'
print(torch.sum(new_network.nets[0].weight.data))
for key in new_network.state_dict():
    print(key)
結果:
nets.0.weight
nets.0.bias
nets.2.weight
nets.2.bias
nets.4.weight
nets.4.bias
****Before loading********
-0.197643771768
0.862508803606
1.21658478677
conv1.weight
conv1.bias
conv2.weight
conv2.bias
conv3.weight
conv3.bias
nets.0.weight
nets.0.bias
-----After loading------
-0.197643771768
0.862508803606
-0.197643771768
conv1.weight
conv1.bias
conv2.weight
conv2.bias
conv3.weight
conv3.bias
nets.0.weight
nets.0.bias
發現After loading之後,預期的兩個值一致。
總結:用strict=False進行載入模型,則“能塞則塞,不能塞則丟”。load一般是依據key來載入的,一旦有key不匹配則出錯。如果設定strict=False,則直接忽略不匹配的key,對於匹配的key則進行正常的賦值。
Strict=False的用途
所以說,當你一個模型訓練好之後,你想往裡面加幾層,那麼 strict=False 可以很容易地載入預訓練的引數(注意檢查 key 是否匹配)。只要 key 能夠匹配,就可以進行正確的賦值。
出現unexpected key module.xxx.weight問題
有時候你的模型儲存時含有 nn.DataParallel,這時就會發現 state_dict 中所有的 key 都會有 module. 的字首。載入含有 module. 字首的模型時,可能會出錯。其實你只要移除這些字首即可:
from collections import OrderedDict

# Net_old is the class defined in this article; `save_path` is the path of a
# checkpoint that was saved from a model wrapped in nn.DataParallel.
pretrained_net = Net_old()
pretrained_net_dict = torch.load(save_path)
prefix = 'module.'
new_state_dict = OrderedDict()
for k, v in pretrained_net_dict.items():
    # Strip the prefix only when it is really there, so keys that were saved
    # without it are not truncated by a blind k[7:].
    name = k[len(prefix):] if k.startswith(prefix) else k
    new_state_dict[name] = v
# load params
pretrained_net.load_state_dict(new_state_dict)
總結
- 儲存的 Dict 是按照 net.屬性.weight 的 key 來儲存的。如果這個屬性是一個 Sequential,我們可以用類似 net.seqConvs.0.weight 這樣的 key 來獲得。當然在定義的類中,拿到 Sequential 的某一層要用 [],比如 self.seqConvs[0].weight。
- strict=False 並沒有那麼智慧,它遵循:有相同的 key 則賦值,否則直接丟棄。
附加
由於第一段的問題還沒解決,即如何將 Sequential 定義的網路的模型引數,載入到用“屬性一層層”定義的網路中?
下面是一種比較 ugly 的方法:
import torch
import torch.nn as nn
class Net_old(nn.Module):
    """Three 3x3 convolutions with in-place ReLUs held in one ``nn.Sequential``.

    All parameters are registered under the ``nets`` attribute, so the
    state_dict keys look like ``nets.0.weight``, ``nets.2.bias``, etc.
    """

    def __init__(self):
        super(Net_old, self).__init__()
        self.nets = nn.Sequential(
            torch.nn.Conv2d(1, 2, 3),
            torch.nn.ReLU(True),
            torch.nn.Conv2d(2, 1, 3),
            torch.nn.ReLU(True),
            torch.nn.Conv2d(1, 1, 3)
        )

    def forward(self, x):
        """Return ``x`` passed through the sequential stack."""
        return self.nets(x)
class Net_new(nn.Module):
    """Attribute-style network that can copy its weights from a Net_old checkpoint.

    Layers are stored as ``conv1``/``r1``/``conv2``/``r2``/``conv3``; the
    convolution weights are filled in layer by layer from a saved Net_old
    state_dict by ``_initialize_weights_from_net``.
    """

    def __init__(self):
        super(Net_new, self).__init__()
        self.conv1 = torch.nn.Conv2d(1, 2, 3)
        self.r1 = torch.nn.ReLU(True)
        self.conv2 = torch.nn.Conv2d(2, 1, 3)
        self.r2 = torch.nn.ReLU(True)
        self.conv3 = torch.nn.Conv2d(1, 1, 3)

    def forward(self, x):
        """Apply conv1 -> ReLU -> conv2 -> ReLU -> conv3 and return the result."""
        # This class defines no `nets` attribute, so forward() must not call
        # self.nets(x); doing so raised AttributeError.
        x = self.conv1(x)
        x = self.r1(x)
        x = self.conv2(x)
        x = self.r2(x)
        x = self.conv3(x)
        return x

    def _initialize_weights_from_net(self, save_path='t.pth'):
        """Copy the conv weights of a saved Net_old checkpoint into this network.

        Args:
            save_path: path of a file written with ``torch.save`` from a
                Net_old state_dict. Defaults to ``'t.pth'``.
        """
        # First rebuild the old network and load the checkpoint into it.
        pretrained_net = Net_old()
        pretrained_net_dict = torch.load(save_path)
        # load params
        pretrained_net.load_state_dict(pretrained_net_dict)
        # Report success only after the checkpoint has actually been loaded.
        print('Successfully load model '+save_path)
        # nn.Sequential can be iterated to visit its layers in order; keep
        # only the convolutions and pair them positionally with our own.
        old_convs = [layer for layer in pretrained_net.nets
                     if isinstance(layer, torch.nn.Conv2d)]
        for old_conv, new_conv in zip(old_convs, self.get_convs()):
            print('Assign weight of pretrained model layer : ', old_conv, ' to layer: ', new_conv)
            # copy_ writes into the existing tensors rather than rebinding
            # .data, so the new layers do not alias the temporary network.
            new_conv.weight.data.copy_(old_conv.weight.data)
            new_conv.bias.data.copy_(old_conv.bias.data)

    def get_convs(self):
        """Return this network's convolution layers in forward order."""
        return [self.conv1, self.conv2, self.conv3]
old_network = Net_old()
torch.save(old_network.cpu().state_dict(), 't.pth')
pretrained_net = torch.load('t.pth')
# Show keys of pretrained model
for key in pretrained_net:
    print(key)
print('****Before loading********')
new_network = Net_new()
print(torch.sum(old_network.nets[0].weight.data))
print(torch.sum(new_network.conv1.weight.data))
print('-----New loading method------')
# Copies the conv weights layer by layer from the 't.pth' checkpoint.
new_network._initialize_weights_from_net()
print(torch.sum(old_network.nets[0].weight.data))
print(torch.sum(new_network.conv1.weight.data))
輸出:
nets.0.weight
nets.0.bias
nets.2.weight
nets.2.bias
nets.4.weight
nets.4.bias
****Before loading********
0.510313585401
0.198701560497
-----New loading method------
Successfully load model t.pth
('Assign weight of pretrained model layer : ', Conv2d(1, 2, kernel_size=(3, 3), stride=(1, 1)), ' to layer: ', Conv2d(1, 2, kernel_size=(3, 3), stride=(1, 1)))
('Assign weight of pretrained model layer : ', Conv2d(2, 1, kernel_size=(3, 3), stride=(1, 1)), ' to layer: ', Conv2d(2, 1, kernel_size=(3, 3), stride=(1, 1)))
('Assign weight of pretrained model layer : ', Conv2d(1, 1, kernel_size=(3, 3), stride=(1, 1)), ' to layer: ', Conv2d(1, 1, kernel_size=(3, 3), stride=(1, 1)))
0.510313585401
0.510313585401
搞定!