【PyTorch】:LeNet實現CIFAR-10分類
阿新 • 發佈:2019-01-08
# PyTorch 0.4.0: LeNet for CIFAR-10 classification.
# @Time: 2018/6/15
# @Author: xfLi
import torchvision as tv
import torch.nn as nn
import torch as t
from torch.autograd import Variable  # no longer needed since PyTorch 0.4 (tensors carry autograd state); kept for compatibility
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch import optim

MAX_EPOCH = 2
CLASS_NUM = 10


class Net(nn.Module):
    """LeNet-style CNN: two conv + max-pool stages followed by three linear layers.

    ``fc1`` expects 16 * 5 * 5 flattened features, which is what the conv/pool
    stack produces for 3x32x32 CIFAR-10 images:
    32 -conv5-> 28 -pool2-> 14 -conv5-> 10 -pool2-> 5, with 16 channels.
    """

    def __init__(self, class_num=CLASS_NUM):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, class_num)

    def forward(self, x):
        """Return unnormalized class scores (logits), one row per input sample."""
        x = F.max_pool2d(F.relu(self.conv1(x)), kernel_size=(2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), kernel_size=(2, 2))
        x = x.view(x.size(0), -1)  # flatten all dims except the batch dim
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)


def getData(root='/data/', batch_size=4):
    """Build the CIFAR-10 train/test DataLoaders.

    Images are converted to tensors, then each RGB channel is normalized
    with mean 0.5 and std 0.5. The dataset is downloaded into ``root`` if
    it is not already present.

    Args:
        root: directory that holds (or will receive) the CIFAR-10 files.
        batch_size: mini-batch size used by both loaders.

    Returns:
        ``(train_loader, test_loader, classes)``. ``classes`` is the tuple of
        human-readable label names in CIFAR-10 label order.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    # Training set: reshuffled every epoch.
    train_set = tv.datasets.CIFAR10(root=root, train=True, transform=transform, download=True)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    # Test set: fixed order so successive evaluations see the same data.
    test_set = tv.datasets.CIFAR10(root=root, train=False, transform=transform, download=True)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)
    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
    return train_loader, test_loader, classes


def train(max_epoch=MAX_EPOCH, log_every=1000):
    """Train a fresh ``Net`` on CIFAR-10 with SGD (lr=0.001, momentum=0.9).

    Every ``log_every`` steps (must be a positive int) the current mini-batch
    loss and the accuracy on the full test set are printed.

    Args:
        max_epoch: number of passes over the training set.
        log_every: logging/evaluation interval in optimizer steps.

    Returns:
        The trained network.
    """
    net = Net()
    train_dataloader, test_dataloader, _ = getData()  # load data; class names unused here
    criterion = nn.CrossEntropyLoss()  # cross-entropy over raw logits
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

    for epoch in range(max_epoch):
        net.train()
        for step, (inputs, labels) in enumerate(train_dataloader):
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            if step % log_every == log_every - 1:
                acc = test_net(net, test_dataloader)
                print('Epoch: ', epoch, ' |step: ', step, ' |train_loss: ', loss.item(),
                      '|test accuracy:%.4f' % acc)
    print('Finished Training')
    return net


def test_net(net, test_dataloader):
    """Return the top-1 accuracy of ``net`` on ``test_dataloader`` as a float in [0, 1].

    Inference runs under ``torch.no_grad()`` so no autograd graph is built.
    The model is switched to eval mode for the duration and its previous
    train/eval mode is restored afterwards. This matters because the
    function is called in the middle of training.

    Raises:
        ValueError: if the loader yields no samples (accuracy is undefined).
    """
    was_training = net.training
    net.eval()
    correct, total = 0, 0
    try:
        with t.no_grad():
            for inputs, label in test_dataloader:
                output = net(inputs)
                _, predicted = t.max(output, 1)  # index of the highest logit = predicted class
                total += label.size(0)  # number of samples seen
                correct += (predicted == label).sum().item()  # number classified correctly
    finally:
        net.train(was_training)
    if total == 0:
        raise ValueError('test_dataloader yielded no samples')
    return float(correct) / total


if __name__ == '__main__':
    net = train()