1. 程式人生 > 其它 > CNN每層卷積結果視覺化展示（以3D-IRCADb肝臟資料為例）

CNN每層卷積結果視覺化展示（以3D-IRCADb肝臟資料為例）

本文嘗試將肝臟CT影像經過每一層卷積之後的結果視覺化。注意：下方程式只輸出最後一層（第六層）的結果；若要取得前幾層的結果，需讓 forward 在對應層之後提前返回。程式碼如下：

import torch
import torch.nn as nn
import SimpleITK as sitk
import numpy as np

def change_indenty(ct, lower=40, upper=400):
    """Clamp the intensities of ``ct`` into ``[lower, upper]``, in place.

    The name is kept for backward compatibility (presumably "change
    intensity"). The defaults reproduce the original hard-coded window
    of 40..400.

    Args:
        ct: Array supporting boolean-mask assignment, e.g. the NumPy array
            returned by ``sitk.GetArrayFromImage``. It is modified in place.
        lower: Values below this are replaced by ``lower``.
        upper: Values above this are replaced by ``upper``.

    Returns:
        The same object ``ct``, after clamping.

    Raises:
        ValueError: If ``lower > upper``. An inverted window would silently
            collapse every value to ``upper``.
    """
    if lower > upper:
        raise ValueError(f"lower ({lower}) must not exceed upper ({upper})")
    ct[ct < lower] = lower
    ct[ct > upper] = upper
    return ct

class Dialte(nn.Module):
    """Stack of six single-channel 3-D convolution blocks, used to visualize
    what each successive layer does to a volume.

    Each block is ``Conv3d(1, 1, kernel_size=3, stride=1, padding=dilation,
    dilation=dilation) -> ReLU -> BatchNorm3d(1)``. With stride 1 and
    ``padding == dilation``, the spatial size of the input is preserved.
    Activation is applied *before* normalization here, the reverse of the more
    common Conv -> BN -> ReLU order.

    Despite the class name, the default ``dilation=1`` applies no dilation.
    Layers use PyTorch's default (random) initialization.

    The input is expected as ``(N, 1, D, H, W)``, since every Conv3d has a
    single input channel.
    """

    _NUM_LAYERS = 6

    def __init__(self, dilation=1):
        super(Dialte, self).__init__()
        # A single ReLU instance is shared by all blocks; it has no state.
        self.act = nn.ReLU(inplace=False)
        # Stored as a class (factory), not a module: one BatchNorm3d per block.
        self.norm = nn.BatchNorm3d
        # Registered as conv1..conv6, in order, so state_dict keys and the
        # order of random weight initialization match the hand-written version.
        for idx in range(1, self._NUM_LAYERS + 1):
            setattr(self, f"conv{idx}", self._make_block(dilation))

    def _make_block(self, dilation):
        """Build one Conv3d -> ReLU -> BatchNorm3d block that keeps the spatial size."""
        return nn.Sequential(
            nn.Conv3d(1, 1, kernel_size=3, stride=1, padding=dilation, dilation=dilation),
            self.act,
            self.norm(1),
        )

    def forward(self, x, return_intermediates=False):
        """Run ``x`` through all blocks.

        Args:
            x: Input tensor, expected as ``(N, 1, D, H, W)``.
            return_intermediates: If True, return a list containing the output
                of every block (layer 1 first), which allows visualizing each
                layer. Keeping all outputs of a full CT volume costs memory
                proportional to the number of layers.

        Returns:
            The last block's output by default, otherwise the list of all
            block outputs.
        """
        outputs = []
        for idx in range(1, self._NUM_LAYERS + 1):
            x = getattr(self, f"conv{idx}")(x)
            if return_intermediates:
                outputs.append(x)
        return outputs if return_intermediates else x

if __name__ == "__main__":
    import os

    path = r"D:\myProject\HDC_vessel_seg\datasets\nii\image_2.nii"
    image = sitk.ReadImage(path)
    img_num = sitk.GetArrayFromImage(image)
    # Clamp intensities into the fixed 40..400 window (modifies img_num in place).
    img_num = change_indenty(img_num)
    # Add batch and channel axes for Conv3d (assumes a single-channel 3-D volume).
    img_num = np.expand_dims(np.expand_dims(img_num, axis=0), axis=0).astype(np.float32)
    img_num = torch.from_numpy(img_num)
    print(img_num.shape)

    model = Dialte()
    # Inference only: no_grad avoids building an autograd graph over the whole
    # volume. The model is left in training mode as before, so BatchNorm
    # normalizes with this volume's own batch statistics.
    with torch.no_grad():
        x = model(img_num)
    # Drop the batch and channel axes again before converting back to an image.
    x = x[0, 0, ...].cpu().numpy()
    predict_seg = sitk.GetImageFromArray(x)
    # Copy the geometry so the output overlays the input in a viewer.
    predict_seg.SetSpacing(image.GetSpacing())
    predict_seg.SetOrigin(image.GetOrigin())
    predict_seg.SetDirection(image.GetDirection())
    # Rename only the file name. Replacing "image" across the whole path would
    # also rewrite any directory whose name contains "image".
    folder, filename = os.path.split(path)
    sitk.WriteImage(predict_seg, os.path.join(folder, filename.replace("image", "pre_image")))

  

結果:

原始影像:

第一層:

第二層:

第三層:

第四層:

第五層:

第六層: