1. 程式人生 > 其它 >動手學深度學習v2-09-03-影象分類資料集

動手學深度學習v2-09-03-影象分類資料集

1 影象分類資料集

採用的是Fashion-MNIST資料集

%matplotlib inline
import torch
import torchvision
from torch.utils import data
from torchvision import transforms #對資料進行操作的模型
from d2l import torch as d2l

d2l.use_svg_display()  #用svg顯示圖片

1.1 讀取資料集

#通過框架中的內建函式將Fashion-MNIST資料集下載並讀取到記憶體中。
# 通過ToTensor例項將影象資料從PIL型別變換成32位浮點數格式
# 併除以255使得所有畫素的數值均在0到1之間
trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(
    root="/data2", train=True, transform=trans, download=True)
# data2 就是自己新建立的一個目錄地址
# 獲取測試集
mnist_test = torchvision.datasets.FashionMNIST(
    root="/data2", train=False, transform=trans, download=True)

存在的問題:訓練集和測試集不能一起下載,太大了,分開下載就能執行

Fashion-MNIST由10個類別的影象組成,每個類別由訓練資料集中的6000張影象和測試資料集中的1000張影象組成。測試資料集(test dataset)不會用於訓練,只用於評估模型效能。訓練集和測試集分別包含60000和10000張影象。

# 訓練集和測試集的大小
len(mnist_train), len(mnist_test)
# 每個輸入影象的高度和寬度均為28畫素。資料集由灰度影象組成,其通道數為1。
mnist_train[0][0].shape

Fashion-MNIST中包含的10個類別分別為t-shirt(T恤)、trouser(褲子)、pullover(套衫)、dress(連衣裙)、coat(外套)、sandal(涼鞋)、shirt(襯衫)、sneaker(運動鞋)、bag(包)和ankle boot(短靴)。

# 下函式用於在數字標籤索引及其文字名稱之間進行轉換。
def get_fashion_mnist_labels(labels):  #@save
    """返回Fashion-MNIST資料集的文字標籤。"""
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]
# 建立一個函式來視覺化這些樣本
def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):  #@save
    """Plot a list of images."""
    figsize = (num_cols * scale, num_rows * scale)
    _, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if torch.is_tensor(img):
            # 圖片張量
            ax.imshow(img.numpy())
        else:
            # PIL圖片
            ax.imshow(img)
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        if titles:
            ax.set_title(titles[i])
    return axes
# 以下是訓練資料集中前幾個樣本的影象及其相應的標籤(文字形式)
X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
3show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y));

1.2 讀取小批量

#在每次迭代中,資料載入器每次都會讀取一小批量資料,大小為batch_size。我們在訓練資料迭代器中還隨機打亂了所有樣本。
batch_size = 256

def get_dataloader_workers():  #@save
    """使用4個程序來讀取資料。"""
    return 4

train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,
                             num_workers=get_dataloader_workers())
# 看一下讀取訓練資料所需的時間。
timer = d2l.Timer()
for X, y in train_iter:
    continue
f'{timer.stop():.2f} sec'

1.3 整合所有元件

現在我們定義load_data_fashion_mnist函式,用於獲取和讀取Fashion-MNIST資料集。它返回訓練集和驗證集的資料迭代器。此外,它還接受一個可選引數,用來將影象大小調整為另一種形狀。

def load_data_fashion_mnist(batch_size, resize=None):  #@save
    """下載Fashion-MNIST資料集,然後將其載入到記憶體中。"""
    trans = [transforms.ToTensor()]
    if resize:
        trans.insert(0, transforms.Resize(resize))
    trans = transforms.Compose(trans)
    mnist_train = torchvision.datasets.FashionMNIST(
        root="/data2", train=True, transform=trans, download=True)
    mnist_test = torchvision.datasets.FashionMNIST(
        root="/data2", train=False, transform=trans, download=True)
    return (data.DataLoader(mnist_train, batch_size, shuffle=True,
                            num_workers=get_dataloader_workers()),
            data.DataLoader(mnist_test, batch_size, shuffle=False,
                            num_workers=get_dataloader_workers()))
# 通過指定resize引數來測試load_data_fashion_mnist函式的影象大小調整功能。
train_iter, test_iter = load_data_fashion_mnist(32, resize=64)
for X, y in train_iter:
    print(X.shape, X.dtype, y.shape, y.dtype)
    break

2 總結

  • Fashion-MNIST是一個服裝分類資料集,由10個類別的影象組成。我們將在後續章節中使用此資料集來評估各種分類演算法。
  • 我們將高度 h 畫素,寬度 w 畫素影象的形狀記為 h×w 或( h , w )。
  • 資料迭代器是獲得更高效能的關鍵元件。依靠實現良好的資料迭代器,利用高效能運算來避免減慢訓練過程。