MNIST Handwritten Digit Dataset
阿新 • Published: 2018-12-27
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

# training parameters
batch_size = 128
lr = 0.0002
train_epoch = 20

# data_loader
img_size = 64
transform = transforms.Compose([
    transforms.Resize(img_size),  # transforms.Scale is deprecated; Resize is the current API
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,))  # MNIST is single-channel, so one mean/std value
])
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=True, download=True, transform=transform),
    batch_size=batch_size, shuffle=True)

# network (generator/discriminator are the DCGAN models defined elsewhere in this post)
G = generator(128)
D = discriminator(128)
G.weight_init(mean=0.0, std=0.02)
D.weight_init(mean=0.0, std=0.02)
G.cuda()
D.cuda()

# Binary Cross Entropy loss
BCE_loss = nn.BCELoss()

# Adam optimizer
G_optimizer = optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
D_optimizer = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))
The MNIST dataset is already included in torchvision. From descriptions found online: the images are 28x28, with 60,000 training samples and 10,000 test samples.
transforms.Normalize(mean=(0.5,), std=(0.5,))

For this normalization, the formula applied per channel is: output[channel] = (input[channel] - mean[channel]) / std[channel]. With mean = std = 0.5, the [0, 1] pixel values produced by ToTensor are mapped to [-1, 1].
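A quick plain-Python sanity check (not part of the original code) of what this normalization does at the boundaries of the input range:

```python
def normalize(pixel, mean=0.5, std=0.5):
    # torchvision's Normalize applies (input - mean) / std per channel
    return (pixel - mean) / std

# ToTensor gives pixels in [0, 1]; mean = std = 0.5 maps them to [-1, 1]
print(normalize(0.0), normalize(0.5), normalize(1.0))  # -1.0 0.0 1.0
```

The [-1, 1] range matters for DCGAN training because the generator typically ends in a tanh, whose output lives in the same interval.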
Initialization:
# weight_init
def weight_init(self, mean, std):
    for m in self._modules:
        normal_init(self._modules[m], mean, std)

def normal_init(m, mean, std):
    if isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Conv2d):
        m.weight.data.normal_(mean, std)
        m.bias.data.zero_()
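The effect of normal_init can be sketched without torch. The helper below is a hypothetical stand-in (not the real PyTorch calls) that mimics weight.data.normal_(mean, std) and bias.data.zero_() with the DCGAN defaults mean=0.0, std=0.02:

```python
import random

def normal_init_sketch(n_weights, n_biases=8, mean=0.0, std=0.02, seed=0):
    # mimic m.weight.data.normal_(mean, std): each weight drawn from N(mean, std)
    rng = random.Random(seed)
    weights = [rng.gauss(mean, std) for _ in range(n_weights)]
    # mimic m.bias.data.zero_(): all biases set to zero
    biases = [0.0] * n_biases
    return weights, biases

w, b = normal_init_sketch(10000)
sample_mean = sum(w) / len(w)
print(abs(sample_mean) < 0.01, all(v == 0.0 for v in b))  # True True
```

Small-std Gaussian initialization like this is the scheme recommended in the original DCGAN paper; it keeps early activations in a stable range for both the generator and the discriminator.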