1. 程式人生 > 程式設計 >pytorch如何凍結某層引數的實現

pytorch如何凍結某層引數的實現

在遷移學習finetune時我們通常需要凍結前幾層的引數不參與訓練,在Pytorch中的實現如下:

class Model(nn.Module):
 def __init__(self):
  super(Transfer_model,self).__init__()
  self.linear1 = nn.Linear(20,50)
  self.linear2 = nn.Linear(50,20)
  self.linear3 = nn.Linear(20,2)

 def forward(self,x):
 pass

假如我們想要凍結linear1層,需要做如下操作:

model = Model()
# 這裡是一般情況,共享層往往不止一層,所以做一個for迴圈
for para in model.linear1.parameters():
 para.requires_grad = False
# 假如真的只有一層也可以這樣操作:
# model.linear1.weight.requires_grad = False

最後我們需要將需要優化的引數傳入優化器,不需要傳入的引數過濾掉,所以要用到filter()函式。

optimizer = optim.Adam(filter(lambda p: p.requires_grad,model.parameters()),lr=0.1)

其它的部落格中都沒有講解filter()函式的作用,在這裡我簡單講一下有助於更好的理解。

filter(function,iterable)

  • function: 判斷函式
  • iterable: 可迭代物件

filter() 函式用於過濾序列,過濾掉不符合條件的元素,返回一個迭代器物件,如果要轉換為列表,可以使用 list() 來轉換。

該接收兩個引數,第一個為函式,第二個為序列,序列的每個元素作為引數傳遞給函式進行判,然後返回 True 或 False,最後將返回 True 的元素放到新列表中。

filter()函式將requires_grad = True的引數傳入優化器進行反向傳播,requires_grad = False的則被過濾掉。

以上就是本文的全部內容,希望對大家的學習有所幫助,也希望大家多多支援我們。