使用pytorch實現論文中的unet網路
阿新 • 發佈:2020-06-28
設計神經網路的一般步驟:
1. 設計框架
2. 設計骨幹網路
Unet網路設計的步驟:
1. 設計Unet網路工廠模式
2. 設計編解碼結構
3. 設計卷積模組
4. unet例項模組
Unet網路最重要的特徵:
1. 編解碼結構。
2. 解碼結構比FCN更加完善:透過跳躍連線(skip connection)把編碼端對應層的特徵圖裁剪後,與解碼端上取樣得到的特徵圖在通道維度上拼接(concatenate),而不是像FCN那樣逐元素相加。
3. 本質上是一個框架,編碼部分可以替換成許多影像分類網路(例如VGG、ResNet)。
示例程式碼:
import torch
import torch.nn as nn


class Unet(nn.Module):
    """Generic U-Net assembled from encoder blocks, decoder blocks and a bridge.

    This is the "framework" form of U-Net: the encoder may be any list of
    modules (for example the stages of an image-classification backbone), as
    long as the decoder blocks' channel counts match the skip connections.

    Args:
        encoder_blocks: Modules run in order by :class:`Encoder`. The output of
            every block except the last is kept as a skip connection.
        decoder_blocks: Modules run by :class:`Decoder`. There must be exactly
            one more decoder block than there are skip connections.
        bridge: Optional module applied to the deepest encoder output before
            decoding. ``None`` (the default) passes that output straight on.
    """

    def __init__(self, encoder_blocks, decoder_blocks, bridge=None):
        super(Unet, self).__init__()
        self.encoder = Encoder(encoder_blocks)
        self.decoder = Decoder(decoder_blocks)
        self.bridge = bridge

    def forward(self, x):
        # Encoder returns [deepest_output, skip_0, skip_1, ...].
        out, *skips = self.encoder(x)
        if self.bridge is not None:
            out = self.bridge(out)
        return self.decoder(out, skips)


class Encoder(nn.Module):
    """Run blocks in sequence and collect the intermediate outputs as skips.

    ``forward`` returns a list ``[last_output, skip_0, ..., skip_{n-2}]``: the
    output of the final block first, then the outputs of all earlier blocks
    in shallow-to-deep order.

    Raises:
        ValueError: if ``blocks`` is empty.
    """

    def __init__(self, blocks):
        super(Encoder, self).__init__()
        if len(blocks) == 0:
            raise ValueError("Encoder needs at least one block")
        # ModuleList registers every block's parameters with this module but,
        # unlike nn.Sequential, has no forward() of its own.
        self.blocks = nn.ModuleList(blocks)

    def forward(self, x):
        skips = []
        for block in self.blocks[:-1]:
            x = block(x)
            skips.append(x)
        # Lists concatenate with "+": deepest output first, then the skips.
        return [self.blocks[-1](x)] + skips


class Decoder(nn.Module):
    """Upsampling path: merge each skip connection with the running features.

    ``blocks[0]`` is applied to the bottleneck features alone. Each following
    block receives ``torch.cat([skip, x], dim=1)`` after both tensors have been
    center-cropped to a common height/width.

    Raises:
        ValueError: if ``blocks`` is empty.
    """

    def __init__(self, blocks):
        super(Decoder, self).__init__()
        if len(blocks) == 0:
            raise ValueError("Decoder needs at least one block")
        self.blocks = nn.ModuleList(blocks)

    def center_crop(self, skips, x):
        """Crop two (N, C, H, W) tensors to their common central H x W window.

        With unpadded convolutions the encoder feature map (``skips``) is
        larger than the upsampled decoder map, so it must be trimmed before
        the two can be concatenated along the channel axis.

        Returns:
            ``(cropped_skips, cropped_x)``.
        """
        _, _, height1, width1 = skips.shape
        _, _, height2, width2 = x.shape
        ht, wt = min(height1, height2), min(width1, width2)
        # Offsets are zero for whichever tensor is already the smaller one.
        dh1, dw1 = (height1 - ht) // 2, (width1 - wt) // 2
        dh2, dw2 = (height2 - ht) // 2, (width2 - wt) // 2
        return (skips[:, :, dh1:dh1 + ht, dw1:dw1 + wt],
                x[:, :, dh2:dh2 + ht, dw2:dw2 + wt])

    # Backward-compatible alias for the misspelled original method name.
    ceter_crop = center_crop

    def forward(self, x, skips, reverse_skips=True):
        """Decode ``x`` using ``skips``.

        Args:
            x: Bottleneck features.
            skips: Skip tensors; by default given shallow-to-deep (as produced
                by :class:`Encoder`) and reversed so the deepest is used first.
            reverse_skips: Set to False if ``skips`` is already deepest-first.

        Raises:
            ValueError: if ``len(skips) != len(self.blocks) - 1``.
        """
        if len(skips) != len(self.blocks) - 1:
            raise ValueError(
                f"expected {len(self.blocks) - 1} skip tensors, got {len(skips)}"
            )
        if reverse_skips:
            skips = skips[::-1]
        x = self.blocks[0](x)
        for block, skip in zip(self.blocks[1:], skips):
            skip, x = self.center_crop(skip, x)
            x = block(torch.cat([skip, x], dim=1))
        return x


def unet_convs(in_channels, out_channels, padding=0):
    """Return (3x3 conv -> BatchNorm -> ReLU) x 2 as an ``nn.Sequential``.

    nn.Sequential, unlike nn.ModuleList, provides its own forward(). The convs
    use ``bias=False`` because each one is followed by BatchNorm. With the
    default ``padding=0`` the block shrinks height and width by 4 in total.
    """
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3,
                  padding=padding, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, kernel_size=3,
                  padding=padding, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    )


def unet(in_channels, out_channels, base_channels=64):
    """Build the paper-style U-Net from the generic :class:`Unet` framework.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the output map (raw logits; no activation
            is applied).
        base_channels: Width of the first level; each deeper level doubles it
            (64 -> 1024 by default, as in the paper).

    With the default unpadded convolutions a 572x572 input yields a 388x388
    output.
    """
    w = [base_channels * 2 ** i for i in range(5)]

    def down(cin, cout):
        # ceil_mode=True rounds the pooled size up for odd-sized maps.
        return nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True),
            unet_convs(cin, cout),
        )

    def up(cin, cout):
        return nn.ConvTranspose2d(cin, cout, kernel_size=2, stride=2)

    # Outputs of the first four blocks become the skips; the final pooling
    # feeds the bridge, which plays the role of the bottleneck.
    encoder_blocks = [
        unet_convs(in_channels, w[0]),
        down(w[0], w[1]),
        down(w[1], w[2]),
        down(w[2], w[3]),
        nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True),
    ]
    bridge = unet_convs(w[3], w[4])
    # Each block after the first sees [skip, upsampled] concatenated, i.e.
    # twice the channels of its upsampled input.
    decoder_blocks = [
        up(w[4], w[3]),
        nn.Sequential(unet_convs(w[4], w[3]), up(w[3], w[2])),
        nn.Sequential(unet_convs(w[3], w[2]), up(w[2], w[1])),
        nn.Sequential(unet_convs(w[2], w[1]), up(w[1], w[0])),
        nn.Sequential(unet_convs(w[1], w[0]),
                      nn.Conv2d(w[0], out_channels, kernel_size=1)),
    ]
    return Unet(encoder_blocks, decoder_blocks, bridge)
補充知識:Pytorch搭建U-Net網路
U-Net: Convolutional Networks for Biomedical Image Segmentation
import torch
import torch.nn as nn
from torch import autograd  # not used below; kept from the original listing


class DoubleConv(nn.Module):
    """(3x3 conv -> BatchNorm -> ReLU) applied twice.

    Both convolutions use ``padding=0`` ("valid" convolutions, as in the U-Net
    paper), so the block shrinks height and width by 4 pixels in total.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels of both convolutions.
    """

    def __init__(self, in_ch, out_ch):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=0),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=0),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, input):
        return self.conv(input)


def center_crop(tensor, height, width):
    """Return the central ``height x width`` window of a (N, C, H, W) tensor.

    Used to trim an encoder feature map to the smaller size of the upsampled
    decoder map before concatenation. For a 568x568 map cropped to 392x392
    this is exactly ``tensor[:, :, 88:480, 88:480]``.
    """
    top = (tensor.shape[2] - height) // 2
    left = (tensor.shape[3] - width) // 2
    return tensor[:, :, top:top + height, left:left + width]


class Unet(nn.Module):
    """U-Net (Ronneberger et al., 2015) with unpadded convolutions.

    With the default widths a (N, 1, 572, 572) input produces a
    (N, out_ch, 388, 388) map passed through a sigmoid. Skip connections are
    center-cropped from the actual tensor sizes rather than hard-coded
    offsets, so other input sizes also work as long as every intermediate
    map stays non-empty (e.g. 188x188 in gives 4x4 out).

    Args:
        in_ch: Channels of the input image.
        out_ch: Channels of the output map.
        base_ch: Width of the first level; each deeper level doubles it
            (64 -> 1024 by default, as in the paper).
    """

    def __init__(self, in_ch, out_ch, base_ch=64):
        super(Unet, self).__init__()
        ch1, ch2, ch3, ch4, ch5 = (base_ch * 2 ** i for i in range(5))
        # Contracting path.
        self.conv1 = DoubleConv(in_ch, ch1)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = DoubleConv(ch1, ch2)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = DoubleConv(ch2, ch3)
        self.pool3 = nn.MaxPool2d(2)
        self.conv4 = DoubleConv(ch3, ch4)
        self.pool4 = nn.MaxPool2d(2)
        self.conv5 = DoubleConv(ch4, ch5)
        # Expanding path: transposed convolution (upsampling would also work).
        # Each DoubleConv sees [upsampled, skip] concatenated, i.e. twice the
        # channels of the upsampled tensor.
        self.up6 = nn.ConvTranspose2d(ch5, ch4, 2, stride=2)
        self.conv6 = DoubleConv(ch5, ch4)
        self.up7 = nn.ConvTranspose2d(ch4, ch3, 2, stride=2)
        self.conv7 = DoubleConv(ch4, ch3)
        self.up8 = nn.ConvTranspose2d(ch3, ch2, 2, stride=2)
        self.conv8 = DoubleConv(ch3, ch2)
        self.up9 = nn.ConvTranspose2d(ch2, ch1, 2, stride=2)
        self.conv9 = DoubleConv(ch2, ch1)
        self.conv10 = nn.Conv2d(ch1, out_ch, 1)

    @staticmethod
    def _merge(up, skip):
        """Crop *skip* to *up*'s H x W and concatenate on channels (up first)."""
        skip = center_crop(skip, up.shape[2], up.shape[3])
        return torch.cat([up, skip], dim=1)

    def forward(self, x):
        # Contracting path: keep each level's output for the skip connections.
        c1 = self.conv1(x)
        c2 = self.conv2(self.pool1(c1))
        c3 = self.conv3(self.pool2(c2))
        c4 = self.conv4(self.pool3(c3))
        c5 = self.conv5(self.pool4(c4))
        # Expanding path: upsample, merge with the cropped skip, convolve.
        c6 = self.conv6(self._merge(self.up6(c5), c4))
        c7 = self.conv7(self._merge(self.up7(c6), c3))
        c8 = self.conv8(self._merge(self.up8(c7), c2))
        c9 = self.conv9(self._merge(self.up9(c8), c1))
        return torch.sigmoid(self.conv10(c9))


if __name__ == "__main__":
    # torchsummary is an optional debugging aid; importing it here keeps the
    # model importable without it.
    from torchsummary import summary

    test_input = torch.rand(1, 1, 572, 572)
    model = Unet(in_ch=1, out_ch=2)
    # The model stays on CPU, so build summary's dummy input on CPU too
    # (summary defaults to device="cuda").
    summary(model, (1, 572, 572), device="cpu")
    output = model(test_input)
    print(output.size())
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 570, 570]             640
       BatchNorm2d-2         [-1, 64, 570, 570]             128
              ReLU-3         [-1, 64, 570, 570]               0
            Conv2d-4         [-1, 64, 568, 568]          36,928
       BatchNorm2d-5         [-1, 64, 568, 568]             128
              ReLU-6         [-1, 64, 568, 568]               0
        DoubleConv-7         [-1, 64, 568, 568]               0
         MaxPool2d-8         [-1, 64, 284, 284]               0
            Conv2d-9        [-1, 128, 282, 282]          73,856
      BatchNorm2d-10        [-1, 128, 282, 282]             256
             ReLU-11        [-1, 128, 282, 282]               0
           Conv2d-12        [-1, 128, 280, 280]         147,584
      BatchNorm2d-13        [-1, 128, 280, 280]             256
             ReLU-14        [-1, 128, 280, 280]               0
       DoubleConv-15        [-1, 128, 280, 280]               0
        MaxPool2d-16        [-1, 128, 140, 140]               0
           Conv2d-17        [-1, 256, 138, 138]         295,168
      BatchNorm2d-18        [-1, 256, 138, 138]             512
             ReLU-19        [-1, 256, 138, 138]               0
           Conv2d-20        [-1, 256, 136, 136]         590,080
      BatchNorm2d-21        [-1, 256, 136, 136]             512
             ReLU-22        [-1, 256, 136, 136]               0
       DoubleConv-23        [-1, 256, 136, 136]               0
        MaxPool2d-24          [-1, 256, 68, 68]               0
           Conv2d-25          [-1, 512, 66, 66]       1,180,160
      BatchNorm2d-26          [-1, 512, 66, 66]           1,024
             ReLU-27          [-1, 512, 66, 66]               0
           Conv2d-28          [-1, 512, 64, 64]       2,359,808
      BatchNorm2d-29          [-1, 512, 64, 64]           1,024
             ReLU-30          [-1, 512, 64, 64]               0
       DoubleConv-31          [-1, 512, 64, 64]               0
        MaxPool2d-32          [-1, 512, 32, 32]               0
           Conv2d-33         [-1, 1024, 30, 30]       4,719,616
      BatchNorm2d-34         [-1, 1024, 30, 30]           2,048
             ReLU-35         [-1, 1024, 30, 30]               0
           Conv2d-36         [-1, 1024, 28, 28]       9,438,208
      BatchNorm2d-37         [-1, 1024, 28, 28]           2,048
             ReLU-38         [-1, 1024, 28, 28]               0
       DoubleConv-39         [-1, 1024, 28, 28]               0
  ConvTranspose2d-40          [-1, 512, 56, 56]       2,097,664
           Conv2d-41          [-1, 512, 54, 54]       4,719,104
      BatchNorm2d-42          [-1, 512, 54, 54]           1,024
             ReLU-43          [-1, 512, 54, 54]               0
           Conv2d-44          [-1, 512, 52, 52]       2,359,808
      BatchNorm2d-45          [-1, 512, 52, 52]           1,024
             ReLU-46          [-1, 512, 52, 52]               0
       DoubleConv-47          [-1, 512, 52, 52]               0
  ConvTranspose2d-48        [-1, 256, 104, 104]         524,544
           Conv2d-49        [-1, 256, 102, 102]       1,179,904
      BatchNorm2d-50        [-1, 256, 102, 102]             512
             ReLU-51        [-1, 256, 102, 102]               0
           Conv2d-52        [-1, 256, 100, 100]         590,080
      BatchNorm2d-53        [-1, 256, 100, 100]             512
             ReLU-54        [-1, 256, 100, 100]               0
       DoubleConv-55        [-1, 256, 100, 100]               0
  ConvTranspose2d-56        [-1, 128, 200, 200]         131,200
           Conv2d-57        [-1, 128, 198, 198]         295,040
      BatchNorm2d-58        [-1, 128, 198, 198]             256
             ReLU-59        [-1, 128, 198, 198]               0
           Conv2d-60        [-1, 128, 196, 196]         147,584
      BatchNorm2d-61        [-1, 128, 196, 196]             256
             ReLU-62        [-1, 128, 196, 196]               0
       DoubleConv-63        [-1, 128, 196, 196]               0
  ConvTranspose2d-64         [-1, 64, 392, 392]          32,832
           Conv2d-65         [-1, 64, 390, 390]          73,792
      BatchNorm2d-66         [-1, 64, 390, 390]             128
             ReLU-67         [-1, 64, 390, 390]               0
           Conv2d-68         [-1, 64, 388, 388]          36,928
      BatchNorm2d-69         [-1, 64, 388, 388]             128
             ReLU-70         [-1, 64, 388, 388]               0
       DoubleConv-71         [-1, 64, 388, 388]               0
           Conv2d-72          [-1, 2, 388, 388]             130
================================================================
Total params: 31,042,434
Trainable params: 31,042,434
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 1.25
Forward/backward pass size (MB): 3280.59
Params size (MB): 118.42
Estimated Total Size (MB): 3400.26
----------------------------------------------------------------
torch.Size([1, 2, 388, 388])
以上這篇使用pytorch實現論文中的unet網路就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。