1. 程式人生 > 程式設計 >pytorch SENet實現案例

pytorch SENet實現案例

我就廢話不多說了,大家還是直接看程式碼吧~

from torch import nn

class SELayer(nn.Module):
 def __init__(self,channel,reduction=16):
  super(SELayer,self).__init__()

  //返回1X1大小的特徵圖,通道數不變
  self.avg_pool = nn.AdaptiveAvgPool2d(1)
  self.fc = nn.Sequential(
   nn.Linear(channel,channel // reduction,bias=False),nn.ReLU(inplace=True),nn.Linear(channel // reduction,nn.Sigmoid()
  )

 def forward(self,x):
  b,c,_,_ = x.size()

  //全域性平均池化,batch和channel和原來一樣保持不變
  y = self.avg_pool(x).view(b,c)

  //全連線層+池化
  y = self.fc(y).view(b,1,1)

  //和原特徵圖相乘
  return x * y.expand_as(x)

補充知識:pytorch 實現 SE Block

論文模組圖

pytorch SENet實現案例

程式碼

import torch.nn as nn
class SE_Block(nn.Module):
 def __init__(self,ch_in,reduction=16):
  super(SE_Block,self).__init__()
  self.avg_pool = nn.AdaptiveAvgPool2d(1)				# 全域性自適應池化
  self.fc = nn.Sequential(
   nn.Linear(ch_in,ch_in // reduction,nn.Linear(ch_in // reduction,_ = x.size()
  y = self.avg_pool(x).view(b,c)
  y = self.fc(y).view(b,1)
  return x * y.expand_as(x)

現在還有許多關於SE的變形,但大都大同小異

以上這篇pytorch SENet實現案例就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。