1. 程式人生 > 程式設計 >Pytorch 使用 nii資料做輸入資料的操作

Pytorch 使用 nii資料做輸入資料的操作

使用pix2pix-gan做醫學影象合成的時候,如果把nii資料轉成png格式會損失很多資訊,以為png格式影象的灰度值有256階,因此直接使用nii的醫學影象做輸入會更好一點。

但是Pythorch中的Dataloader是不能直接讀取nii影象的,因此加一個CreateNiiDataset的類。

先來了解一下pytorch中讀取資料的主要途徑——Dataset類。在自己構建資料層時都要基於這個類,類似於C++中的虛基類。

自己構建的資料層包含三個部分

class Dataset(object):
"""An abstract class representing a Dataset.
All other datasets should subclass it. All subclasses should override
``__len__``,that provides the size of the dataset,and ``__getitem__``,supporting integer indexing in range from 0 to len(self) exclusive.
"""
def __getitem__(self,index):
 raise NotImplementedError
def __len__(self):
 raise NotImplementedError
def __add__(self,other):
 return ConcatDataset([self,other])

根據自己的需要編寫CreateNiiDataset子類:

因為我是基於https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix

做pix2pix-gan的實驗,資料包含兩個部分mr 和 ct,不需要標籤,因此上面的 def getitem(self,index):中不需要index這個引數了,類似地,根據需要,加入自己的引數,去掉不需要的引數。

class CreateNiiDataset(Dataset):
 def __init__(self,opt,transform = None,target_transform = None):
  self.path1 = opt.dataroot # parameter passing
  self.A = 'MR' 
  self.B = 'CT'
  lines = os.listdir(os.path.join(self.path1,self.A))
  lines.sort()
  imgs = []
  for line in lines:
   imgs.append(line)
  self.imgs = imgs
  self.transform = transform
  self.target_transform = target_transform

 def crop(self,image,crop_size):
  shp = image.shape
  scl = [int((shp[0] - crop_size[0]) / 2),int((shp[1] - crop_size[1]) / 2)]
  image_crop = image[scl[0]:scl[0] + crop_size[0],scl[1]:scl[1] + crop_size[1]]
  return image_crop

 def __getitem__(self,item):
  file = self.imgs[item]
  img1 = sitk.ReadImage(os.path.join(self.path1,self.A,file))
  img2 = sitk.ReadImage(os.path.join(self.path1,self.B,file))
  data1 = sitk.GetArrayFromImage(img1)
  data2 = sitk.GetArrayFromImage(img2)

  if data1.shape[0] != 256:
   data1 = self.crop(data1,[256,256])
   data2 = self.crop(data2,256])
  if self.transform is not None:
   data1 = self.transform(data1)
   data2 = self.transform(data2)

  if np.min(data1)<0:
   data1 = (data1 - np.min(data1))/(np.max(data1)-np.min(data1))

  if np.min(data2)<0:
   #data2 = data2 - np.min(data2)
   data2 = (data2 - np.min(data2))/(np.max(data2)-np.min(data2))

  data = {}
  data1 = data1[np.newaxis,np.newaxis,:,:]
  data1_tensor = torch.from_numpy(np.concatenate([data1,data1,data1],1))
  data1_tensor = data1_tensor.type(torch.FloatTensor)
  data['A'] = data1_tensor # should be a tensor in Float Tensor Type

  data2 = data2[np.newaxis,:]
  data2_tensor = torch.from_numpy(np.concatenate([data2,data2,data2],1))
  data2_tensor = data2_tensor.type(torch.FloatTensor)
  data['B'] = data2_tensor # should be a tensor in Float Tensor Type
  data['A_paths'] = [os.path.join(self.path1,file)] # should be a list,with path inside
  data['B_paths'] = [os.path.join(self.path1,file)]
  return data

 def load_data(self):
  return self

 def __len__(self):
  return len(self.imgs)

注意:最後輸出的data是一個字典,裡面有四個keys=[‘A',‘B',‘A_paths',‘B_paths'],一定要注意資料要轉成FloatTensor。

其次是data[‘A_paths'] 接收的值是一個list,一定要加[ ] 擴起來,要不然測試存圖的時候會有問題,找這個問題找了好久才發現。

然後直接在train.py的主函式裡面把資料載入那行改掉就好了

data_loader = CreateNiiDataset(opt)
dataset = data_loader.load_data()

Over!

補充知識:nii格式影象存為npy格式

我就廢話不多說了,大家還是直接看程式碼吧!

import nibabel as nib
import os
import numpy as np
 
img_path = '/home/lei/train/img/'
seg_path = '/home/lei/train/seg/'
saveimg_path = '/home/lei/train/npy_img/'
saveseg_path = '/home/lei/train/npy_seg/'
 
img_names = os.listdir(img_path)
seg_names = os.listdir(seg_path)
 
for img_name in img_names:
 print(img_name)
 img = nib.load(img_path + img_name).get_data() #載入
 img = np.array(img)
 np.save(saveimg_path + str(img_name).split('.')[0] + '.npy',img) #儲存
 
for seg_name in seg_names:
 print(seg_name)
 seg = nib.load(seg_path + seg_name).get_data()
 seg = np.array(seg)
 np.save(saveseg_path + str(seg_name).split('.')[0] + '.npy

以上這篇Pytorch 使用 nii資料做輸入資料的操作就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。