1. 程式人生 > 其它 > pytorch實現unet

PyTorch 實現 U-Net

技術標籤:深度學習、U-Net

U-Net 是非常經典的影像分割網路,因網路結構形似字母 U 而得名。

實現起來不是很複雜,程式碼如下:

# -*- coding: utf-8 -*-
import torch
import torch.nn as nn

class unet(nn.Module):
    """U-Net style encoder/decoder for image segmentation.

    Every 3x3 convolution is unpadded ("valid"), so each conv shrinks the
    spatial size by 2 and the output is smaller than the input: an
    ``(N, in_channels, 572, 572)`` input yields an
    ``(N, num_classes, 386, 386)`` output.  Each decoder stage upsamples the
    previous features bilinearly (x2) and concatenates, along channels, the
    *center-cropped* encoder feature map of the same level.

    Args:
        in_channels: number of channels of the input image (default 1).
        num_classes: number of output channels (default 2).
    """

    def __init__(self, in_channels=1, num_classes=2):
        super().__init__()
        # Modules are created in the same order and with the same Sequential
        # indices as before, so state_dict keys (and seeded init) are unchanged.
        # Encoder: two unpadded 3x3 conv + BN + ReLU layers per level.
        self.conv1 = self._double_conv(in_channels, 64, 64)
        self.conv2 = self._double_conv(64, 128, 128)
        self.conv3 = self._double_conv(128, 256, 256)
        self.conv4 = self._double_conv(256, 512, 512)
        # Decoder: the input is [upsampled features, cropped skip] concatenated
        # on channels, hence the doubled input channel counts.
        self.conv5 = self._double_conv(1024, 512, 256)
        self.conv6 = self._double_conv(512, 256, 128)
        self.conv7 = self._double_conv(256, 128, 64)
        # Bottleneck between the deepest pooling and the first upsampling.
        self.trans = self._double_conv(512, 1024, 512)
        # Head: one more double conv, then a 3x3 projection to num_classes.
        # NOTE(review): the head ends in BatchNorm + ReLU, so outputs are
        # clamped to >= 0; segmentation logits usually come from a bare 1x1
        # conv. Kept as-is so existing checkpoints (end.6 / end.7) still load.
        self.end = nn.Sequential(
            *self._double_conv(128, 64, 64),
            nn.Conv2d(64, num_classes, 3),
            nn.BatchNorm2d(num_classes),
            nn.ReLU(inplace=True),
        )
        # Attribute name (sic) kept for backward compatibility with callers.
        self.unSample = nn.Upsample(mode='bilinear', scale_factor=2)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    @staticmethod
    def _double_conv(c_in, c_mid, c_out):
        """Return Conv(c_in->c_mid)-BN-ReLU-Conv(c_mid->c_out)-BN-ReLU (3x3, no padding)."""
        return nn.Sequential(
            nn.Conv2d(c_in, c_mid, 3),
            nn.BatchNorm2d(c_mid),
            nn.ReLU(inplace=True),
            nn.Conv2d(c_mid, c_out, 3),
            nn.BatchNorm2d(c_out),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def _center_crop(skip, ref):
        """Crop ``skip`` on dims 2/3 to ``ref``'s height/width, keeping the center.

        A top-left crop would shift the skip features relative to the
        upsampled decoder features; the centered window keeps them aligned.
        """
        h, w = ref.shape[2], ref.shape[3]
        # Clamp at 0: if skip were smaller than ref, torch.cat reports the
        # size mismatch instead of negative indices silently wrapping.
        top = max((skip.shape[2] - h) // 2, 0)
        left = max((skip.shape[3] - w) // 2, 0)
        return skip[:, :, top:top + h, left:left + w]

    def _up_and_concat(self, out, skip):
        """Upsample ``out`` x2, then concatenate the center-cropped ``skip`` on channels."""
        out = self.unSample(out)
        return torch.cat((out, self._center_crop(skip, out)), 1)

    def forward(self, x):
        """Run the network on ``x``, a 4-D ``(N, in_channels, H, W)`` tensor.

        Returns a ``(N, num_classes, H', W')`` tensor with ``H' < H`` and
        ``W' < W`` because all convolutions are unpadded.
        """
        # Encoder: keep each level's output for the skip connections.
        out_conv1 = self.conv1(x)
        out_conv2 = self.conv2(self.pool(out_conv1))
        out_conv3 = self.conv3(self.pool(out_conv2))
        out_conv4 = self.conv4(self.pool(out_conv3))
        out = self.trans(self.pool(out_conv4))
        # Decoder: upsample, merge the matching encoder level, convolve.
        out = self.conv5(self._up_and_concat(out, out_conv4))
        out = self.conv6(self._up_and_concat(out, out_conv3))
        out = self.conv7(self._up_and_concat(out, out_conv2))
        return self.end(self._up_and_concat(out, out_conv1))
    
if __name__ == '__main__':
    # Demo with untrained (randomly initialised) weights.
    x = torch.randn(1, 1, 572, 572)
    net = unet()
    # eval(): BatchNorm uses its running statistics instead of updating them
    # from this single random batch; in train mode the forward pass would
    # overwrite running_mean/running_var before the model is saved/exported.
    net.eval()
    with torch.no_grad():
        output = net(x)
    print(output)
    # Pickles the whole module; loading it requires this class definition.
    torch.save(net, 'unet.pth')
    # opset 11: ONNX Resize only matches PyTorch's bilinear interpolation
    # (align_corners=False) from opset 11 onward.
    torch.onnx.export(
        net,
        x,
        "unet.onnx",
        export_params=True,
        opset_version=11,
        do_constant_folding=True,
        input_names=['input'],
        output_names=['output'],
    )
    

執行結果如下:

註:此模型沒有經過訓練,上述結果是直接使用隨機初始化權重得到的;對 1×1×572×572 的輸入,輸出張量的形狀為 1×2×386×386。