pytorch實現unet
阿新 • • 發佈:2021-02-10
U-Net 是非常經典的影像分割網路,因網路結構形似字母 U 而得名。
實現起來不是很複雜,程式碼如下:
# -*- coding: utf-8 -*-
"""U-Net style segmentation network built from unpadded 3x3 convolutions."""
import torch
import torch.nn as nn


def _conv_bn_relu(in_channels, out_channels):
    """Return the layers of one unpadded 3x3 Conv -> BatchNorm -> ReLU stage."""
    return [
        nn.Conv2d(in_channels, out_channels, kernel_size=3),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    ]


def _double_conv(in_channels, mid_channels, out_channels):
    """Stack two Conv-BN-ReLU stages: in -> mid -> out channels."""
    return nn.Sequential(
        *_conv_bn_relu(in_channels, mid_channels),
        *_conv_bn_relu(mid_channels, out_channels),
    )


class unet(nn.Module):
    """Encoder/decoder network with four skip connections.

    Every convolution is 3x3 without padding, so each one trims one pixel
    from every border. The output is therefore smaller than the input: the
    572x572 single-channel demo input produces a 2-channel 386x386 map.
    """

    def __init__(self):
        super().__init__()
        # Encoder (contracting path). Submodules are created in the same
        # order and under the same attribute names as before, so existing
        # state_dict keys and seeded initialisation stay the same.
        self.conv1 = _double_conv(1, 64, 64)
        self.conv2 = _double_conv(64, 128, 128)
        self.conv3 = _double_conv(128, 256, 256)
        self.conv4 = _double_conv(256, 512, 512)
        # Decoder (expanding path). The input width of each stage is the
        # upsampled feature channels plus the concatenated skip channels.
        self.conv5 = _double_conv(1024, 512, 256)
        self.conv6 = _double_conv(512, 256, 128)
        self.conv7 = _double_conv(256, 128, 64)
        # Bottleneck between encoder and decoder.
        self.trans = _double_conv(512, 1024, 512)
        # Output head: three Conv-BN-ReLU stages ending in 2 channels.
        # NOTE(review): the last BatchNorm + ReLU clamp the 2-channel output
        # to be non-negative; if these are meant to be raw class logits for a
        # loss such as cross-entropy, confirm that this is intended.
        self.end = nn.Sequential(
            *_conv_bn_relu(128, 64),
            *_conv_bn_relu(64, 64),
            *_conv_bn_relu(64, 2),
        )
        # align_corners=False is the default for bilinear mode; it is spelled
        # out so the choice is explicit.
        self.unSample = nn.Upsample(
            mode='bilinear', scale_factor=2, align_corners=False)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    @staticmethod
    def _center_crop(skip, height, width):
        """Return the centered ``height`` x ``width`` window of ``skip``.

        ``skip`` is indexed as (batch, channels, rows, cols). The unpadded
        convolutions remove the same number of pixels from every side, so the
        centered window (not the top-left one) is the part of the encoder
        feature map that lines up with the upsampled decoder feature map.
        """
        top = (skip.shape[2] - height) // 2
        left = (skip.shape[3] - width) // 2
        return skip[:, :, top:top + height, left:left + width]

    def _upsample_and_merge(self, out, skip):
        """Upsample ``out`` by 2 and concatenate the matching crop of ``skip``."""
        out = self.unSample(out)
        cropped = self._center_crop(skip, out.shape[2], out.shape[3])
        return torch.cat((out, cropped), 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network on ``x``.

        Args:
            x: tensor indexed as (batch, 1, height, width); the first
                convolution expects exactly one input channel.

        Returns:
            A tensor with 2 channels whose spatial size is smaller than the
            input because no convolution uses padding.
        """
        # Contracting path: keep each stage's output for the skip connections.
        out_conv1 = self.conv1(x)
        out_conv2 = self.conv2(self.pool(out_conv1))
        out_conv3 = self.conv3(self.pool(out_conv2))
        out_conv4 = self.conv4(self.pool(out_conv3))

        out = self.trans(self.pool(out_conv4))

        # Expanding path: upsample, merge with the skip tensor, convolve.
        decoder_stages = (
            (self.conv5, out_conv4),
            (self.conv6, out_conv3),
            (self.conv7, out_conv2),
        )
        for stage, skip in decoder_stages:
            out = stage(self._upsample_and_merge(out, skip))

        return self.end(self._upsample_and_merge(out, out_conv1))


def main():
    """Run one forward pass on random data, then save and export the model."""
    dummy_input = torch.randn(1, 1, 572, 572)
    net = unet()
    output = net(dummy_input)
    print(output)
    # NOTE(review): this pickles the whole module object, so loading the file
    # later requires this class to be importable; the PyTorch docs recommend
    # saving net.state_dict() instead. Kept as-is so the file format on disk
    # does not change.
    torch.save(net, 'unet.pth')
    torch.onnx.export(
        net,
        dummy_input,
        "unet.onnx",
        export_params=True,
        opset_version=10,
        do_constant_folding=True,
        input_names=['input'],
        output_names=['output'],
    )


if __name__ == '__main__':
    main()
執行結果如下:
此程式碼沒有進行訓練,是直接使用初始化權重的結果