動手學深度學習v2-09-03-影象分類資料集
阿新 • • 發佈:2021-11-04
1 影象分類資料集
採用的是Fashion-MNIST資料集
%matplotlib inline
import torch
import torchvision
from torch.utils import data
from torchvision import transforms #對資料進行操作的模型
from d2l import torch as d2l
d2l.use_svg_display() #用svg顯示圖片
1.1 讀取資料集
#通過框架中的內建函式將Fashion-MNIST資料集下載並讀取到記憶體中。 # 通過ToTensor例項將影象資料從PIL型別變換成32位浮點數格式 # 併除以255使得所有畫素的數值均在0到1之間 trans = transforms.ToTensor() mnist_train = torchvision.datasets.FashionMNIST( root="/data2", train=True, transform=trans, download=True) # data2 就是自己新建立的一個目錄地址
# 獲取測試集
mnist_test = torchvision.datasets.FashionMNIST(
root="/data2", train=False, transform=trans, download=True)
存在的問題:訓練集和測試集不能一起下載,太大了,分開下載就能執行
Fashion-MNIST由10個類別的影象組成,每個類別由訓練資料集中的6000張影象和測試資料集中的1000張影象組成。測試資料集(test dataset)不會用於訓練,只用於評估模型效能。訓練集和測試集分別包含60000和10000張影象。
# 訓練集和測試集的大小 len(mnist_train), len(mnist_test)
# 每個輸入影象的高度和寬度均為28畫素。資料集由灰度影象組成,其通道數為1。
mnist_train[0][0].shape
Fashion-MNIST中包含的10個類別分別為t-shirt(T恤)、trouser(褲子)、pullover(套衫)、dress(連衣裙)、coat(外套)、sandal(涼鞋)、shirt(襯衫)、sneaker(運動鞋)、bag(包)和ankle boot(短靴)。
# 下函式用於在數字標籤索引及其文字名稱之間進行轉換。 def get_fashion_mnist_labels(labels): #@save """返回Fashion-MNIST資料集的文字標籤。""" text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot'] return [text_labels[int(i)] for i in labels]
# 建立一個函式來視覺化這些樣本
def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5): #@save
"""Plot a list of images."""
figsize = (num_cols * scale, num_rows * scale)
_, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)
axes = axes.flatten()
for i, (ax, img) in enumerate(zip(axes, imgs)):
if torch.is_tensor(img):
# 圖片張量
ax.imshow(img.numpy())
else:
# PIL圖片
ax.imshow(img)
ax.axes.get_xaxis().set_visible(False)
ax.axes.get_yaxis().set_visible(False)
if titles:
ax.set_title(titles[i])
return axes
# 以下是訓練資料集中前幾個樣本的影象及其相應的標籤(文字形式)
X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
3show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y));
1.2 讀取小批量
#在每次迭代中,資料載入器每次都會讀取一小批量資料,大小為batch_size。我們在訓練資料迭代器中還隨機打亂了所有樣本。
batch_size = 256
def get_dataloader_workers(): #@save
"""使用4個程序來讀取資料。"""
return 4
train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,
num_workers=get_dataloader_workers())
# 看一下讀取訓練資料所需的時間。
timer = d2l.Timer()
for X, y in train_iter:
continue
f'{timer.stop():.2f} sec'
1.3 整合所有元件
現在我們定義load_data_fashion_mnist函式,用於獲取和讀取Fashion-MNIST資料集。它返回訓練集和驗證集的資料迭代器。此外,它還接受一個可選引數,用來將影象大小調整為另一種形狀。
def load_data_fashion_mnist(batch_size, resize=None): #@save
"""下載Fashion-MNIST資料集,然後將其載入到記憶體中。"""
trans = [transforms.ToTensor()]
if resize:
trans.insert(0, transforms.Resize(resize))
trans = transforms.Compose(trans)
mnist_train = torchvision.datasets.FashionMNIST(
root="/data2", train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.FashionMNIST(
root="/data2", train=False, transform=trans, download=True)
return (data.DataLoader(mnist_train, batch_size, shuffle=True,
num_workers=get_dataloader_workers()),
data.DataLoader(mnist_test, batch_size, shuffle=False,
num_workers=get_dataloader_workers()))
# 通過指定resize引數來測試load_data_fashion_mnist函式的影象大小調整功能。
train_iter, test_iter = load_data_fashion_mnist(32, resize=64)
for X, y in train_iter:
print(X.shape, X.dtype, y.shape, y.dtype)
break
2 總結
- Fashion-MNIST是一個服裝分類資料集,由10個類別的影象組成。我們將在後續章節中使用此資料集來評估各種分類演算法。
- 我們將高度 h 畫素,寬度 w 畫素影象的形狀記為 h×w 或( h , w )。
- 資料迭代器是獲得更高效能的關鍵元件。依靠實現良好的資料迭代器,利用高效能運算來避免減慢訓練過程。