pytorch學習:準備自己的圖片資料
阿新 • • 發佈:2018-12-29
圖片資料一般有兩種情況:
1. 所有圖片放在一個資料夾內,另外有一個txt檔案顯示標籤。
2. 不同類別的圖片放在不同的資料夾內,資料夾就是圖片的類別。
兩種情況,第一種可以自定義Dataset,第二種情況直接呼叫torchvision.datasets.ImageFolder處理,具體如下:
一、 所有圖片均放在一個資料夾內
以mnist資料集的10000個test為例,先將test集裡面的10000圖片儲存出來,並生著對應的txt標籤檔案。先在當前目錄建立一個空資料夾mnist_test,用於儲存10000張圖片,接著執行程式碼:
import torch
import torchvision
import matplotlib.pyplot as plt
from skimage import io
mnist_test= torchvision.datasets.MNIST(
‘./mnist‘, train=False, download=True
)
print(‘test set:‘, len(mnist_test))
f=open(‘mnist_test.txt‘,‘w‘)
for i,(img,label) in enumerate(mnist_test):
img_path="./mnist_test/"+str(i)+".jpg"
io.imsave(img_path,img)
f.write(img_path+‘ ‘+str(label)+‘\n‘)
f.close()
如此,圖片就儲存mnist_test資料夾裡面,並在當前目錄下生成了一個mnist_test.txt檔案,大致如下:
然後就正式開始處理資料:
from torchvision import transforms, utils
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from PIL import Image
def default_loader(path):
return Image.open(path).convert(‘RGB‘)
class MyDataset(Dataset):
def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
fh = open(txt, ‘r‘)
imgs = []
for line in fh:
line = line.strip(‘\n‘)
line = line.rstrip()
words = line.split()
imgs.append((words[0],int(words[1])))
self.imgs = imgs
self.transform = transform
self.target_transform = target_transform
self.loader = loader
def __getitem__(self, index):
fn, label = self.imgs[index]
img = self.loader(fn)
if self.transform is not None:
img = self.transform(img)
return img,label
def __len__(self):
return len(self.imgs)
train_data=MyDataset(txt=‘mnist_test.txt‘, transform=transforms.ToTensor())
data_loader = DataLoader(train_data, batch_size=100,shuffle=True)
print(len(data_loader))
def show_batch(imgs):
grid = utils.make_grid(imgs)
plt.imshow(grid.numpy().transpose((1, 2, 0)))
plt.title(‘Batch from dataloader‘)
for i, (batch_x, batch_y) in enumerate(data_loader):
if(i<4):
print(i, batch_x.size(),batch_y.size())
show_batch(batch_x)
plt.axis(‘off‘)
plt.show()
二、 不同類別圖片放在不同的資料夾內
首先依舊是準備資料,以flowers資料集為例,下載地址為:
http://download.tensorflow.org/example_images/flower_photos.tgz
一共五類,分別放在5個資料夾中,大致如下圖:
路徑為d:/flowers/。那麼處理資料如下:
import torch
import torchvision
from torchvision import transforms, utils
import matplotlib.pyplot as plt
img_data = torchvision.datasets.ImageFolder(‘D:/bnu/database/flower‘,
transform=transforms.Compose([
transforms.Scale(256),
transforms.CenterCrop(224),
transforms.ToTensor()])
)
print(len(img_data))
data_loader = torch.utils.data.DataLoader(img_data, batch_size=20,shuffle=True)
print(len(data_loader))
def show_batch(imgs):
grid = utils.make_grid(imgs,nrow=5)
plt.imshow(grid.numpy().transpose((1, 2, 0)))
plt.title(‘Batch from dataloader‘)
for i, (batch_x, batch_y) in enumerate(data_loader):
if(i<4):
print(i, batch_x.size(), batch_y.size())
show_batch(batch_x)
plt.axis(‘off‘)
plt.show()
轉載連結:http://www.bubuko.com/infodetail-2304938.html