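# 程式人生 > 其它 > 【leetcode LCP 34】二叉樹染色
#
# 【leetcode LCP 34】二叉樹染色
# NOTE: the breadcrumb/title above was carried over from the web page this
# script was copied from. The code below does not solve LeetCode LCP 34; it
# downloads the Fashion-MNIST dataset and previews one training batch.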

# -*- coding: utf-8 -*-
"""
Created on Wed Sep 15 17:15:50 2021

@author: 11651
"""
import os
# NOTE(review): presumably set to get past Intel OpenMP's "OMP: Error #15"
# abort, raised when two copies of the OpenMP runtime are loaded into one
# process. Intel's own message calls this an unsafe, unsupported workaround
# that may crash or silently give wrong results, so prefer fixing the
# environment. Kept ahead of the torch import so it is in place before torch
# loads its runtime.
os.environ.update(KMP_DUPLICATE_LIB_OK="TRUE")

import torch
import torchvision
from torch.utils import data
from torchvision import transforms
import matplotlib.pyplot as plt


# Directory where torchvision stores (and, on first run, downloads) the
# Fashion-MNIST files.
DATA_ROOT = "./data"

# transforms.ToTensor converts each PIL image into a 32-bit float tensor and
# divides by 255, so every pixel value lies in [0, 1].
trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(
    root=DATA_ROOT, train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.FashionMNIST(
    root=DATA_ROOT, train=False, transform=trans, download=True)

# Sanity check: dataset sizes and the shape of one transformed image.
# (The original bare expressions only displayed something in a notebook;
# in a script they were silently discarded.)
print(f"train/test sizes: {len(mnist_train)}, {len(mnist_test)}; "
      f"image tensor shape: {tuple(mnist_train[0][0].shape)}")
# Preview grid layout; one mini-batch exactly fills the grid.
PREVIEW_ROWS, PREVIEW_COLS = 2, 9
BATCH_SIZE = PREVIEW_ROWS * PREVIEW_COLS  # 18

# Wrap the datasets in mini-batch loaders.
data_loader_train = torch.utils.data.DataLoader(
    dataset=mnist_train, batch_size=BATCH_SIZE, shuffle=True)
# The test loader is not shuffled. Aggregate test metrics do not depend on
# sample order, and a fixed order keeps test-set iteration reproducible.
data_loader_test = torch.utils.data.DataLoader(
    dataset=mnist_test, batch_size=BATCH_SIZE, shuffle=False)


def get_fashion_mnist_labels(labels):  #@save
    """Return the Fashion-MNIST text labels for a sequence of class indices.

    Args:
        labels: iterable of integer class indices 0-9 (plain ints or 0-d
            tensors; anything ``int()`` accepts).

    Returns:
        list[str]: the class name for each index, in input order.
    """
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]


def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):  #@save
    """Plot a list of images on a ``num_rows`` x ``num_cols`` grid.

    Args:
        imgs: sequence of images. Each is either a ``torch.Tensor`` in a
            layout ``Axes.imshow`` accepts (e.g. ``(H, W)``, drawn with a gray
            colormap) or another image ``imshow`` accepts, such as a PIL
            image. Images beyond ``num_rows * num_cols`` are ignored.
        num_rows: number of grid rows.
        num_cols: number of grid columns.
        titles: optional per-image titles, indexed in parallel with ``imgs``.
        scale: width/height of each grid cell, in inches.

    Returns:
        numpy.ndarray: flat array of the grid's ``Axes`` objects.
    """
    figsize = (num_cols * scale, num_rows * scale)
    # squeeze=False always yields a 2-D array of Axes, so .flatten() also
    # works for a 1x1 grid (where plt.subplots would return a bare Axes).
    _, axes = plt.subplots(num_rows, num_cols, figsize=figsize, squeeze=False)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if torch.is_tensor(img):
            # Tensor image: detach and move to CPU so .numpy() also works for
            # tensors that require grad or live on a GPU.
            ax.imshow(img.detach().cpu().numpy(), cmap='gray')
        else:
            # PIL image (or other array-like)
            ax.imshow(img)
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        if titles:
            ax.set_title(titles[i])
    return axes


# Preview the first training batch, one image per grid cell.
X, y = next(iter(data_loader_train))
# Reshape each image to 28 x 28 for imshow; -1 infers the batch dimension
# instead of repeating the batch size.
show_images(X.reshape(-1, 28, 28), PREVIEW_ROWS, PREVIEW_COLS,
            titles=get_fashion_mnist_labels(y))
# Needed to actually display the figure when run as a plain script.
plt.show()