[PyTorch小試牛刀]實戰三·DNN實現邏輯迴歸對FashionMNIST資料集進行分類
阿新 • • 發佈:2018-12-22
[PyTorch小試牛刀]實戰三·DNN實現邏輯迴歸對FashionMNIST資料集進行分類
內容還包括了網路模型引數的保存與載入。
資料集
下載地址
程式碼部分
import torch as t
import torchvision as tv
import numpy as np
# Hyperparameters
EPOCH = 10
BATCH_SIZE = 100
DOWNLOAD_MNIST = True # set to False once the dataset has already been downloaded
N_TEST_IMG = 10 # number of sample images to show — NOTE(review): comment/value disagreed in the original (said 5) and this constant appears unused below; confirm intent
class DNN(t.nn.Module):
    """Fully-connected (MLP) classifier for FashionMNIST.

    Besides the network itself, the constructor also builds the train/test
    ``DataLoader``s and the loss/optimizer as attributes, so the training
    loop can reach everything through one object. The zero-argument
    constructor and all attribute names are kept for backward compatibility.
    """

    def __init__(self):
        super(DNN, self).__init__()

        # Datasets; downloaded to ./mnist/ on first run (DOWNLOAD_MNIST=True).
        train_data = tv.datasets.FashionMNIST(
            root="./mnist/",
            train=True,
            transform=tv.transforms.ToTensor(),
            download=DOWNLOAD_MNIST,
        )
        test_data = tv.datasets.FashionMNIST(
            root="./mnist/",
            train=False,
            transform=tv.transforms.ToTensor(),
            download=DOWNLOAD_MNIST,
        )
        print(test_data)

        # Mini-batch loaders; an image batch has shape (BATCH_SIZE, 1, 28, 28).
        self.train_loader = t.utils.data.DataLoader(
            dataset=train_data,
            batch_size=BATCH_SIZE,
            shuffle=True,
        )
        self.test_loader = t.utils.data.DataLoader(
            dataset=test_data,
            batch_size=1000,
            shuffle=True,
        )

        # 784 -> 256 -> 64 -> 10 MLP with dropout regularisation.
        # Kept as an index-based Sequential so state_dict keys stay stable
        # for the save/load code in train().
        self.dnn = t.nn.Sequential(
            t.nn.Linear(28 * 28, 256),
            t.nn.Dropout(0.5),
            t.nn.ELU(),
            t.nn.Linear(256, 64),
            t.nn.Dropout(0.5),
            t.nn.ELU(),
            t.nn.Linear(64, 10),
        )

        self.lr = 0.001
        # CrossEntropyLoss expects raw logits, hence no softmax in the net.
        self.loss = t.nn.CrossEntropyLoss()
        self.opt = t.optim.Adam(self.parameters(), lr=self.lr)

    def forward(self, x):
        """Map a flattened image batch of shape (N, 784) to class logits (N, 10)."""
        return self.dnn(x)
def train():
    """Train the DNN on FashionMNIST, report periodic test accuracy, then
    demonstrate saving/loading the model.

    Side effects: downloads the dataset on first run, prints progress, and
    writes ./model.pkl (whole network) and ./model_params.pkl (state_dict).
    """
    model = DNN()
    print(model)
    loss = model.loss
    opt = model.opt
    dataloader = model.train_loader
    testloader = model.test_loader

    for e in range(EPOCH):
        for step, (x, y) in enumerate(dataloader, start=1):
            model.train()  # re-enable dropout after any eval pass below
            b_x = x.view(-1, 28 * 28)  # flatten images to (batch, 784)
            b_y = y
            out = model(b_x)
            losses = loss(out, b_y)
            opt.zero_grad()
            losses.backward()
            opt.step()
            if step % 100 == 0:
                print(e, step, losses.item())
                # Quick accuracy check on one 1000-image test batch,
                # with dropout disabled and no autograd bookkeeping.
                model.eval()
                with t.no_grad():
                    for tx, ty in testloader:
                        t_out = model(tx.view(-1, 28 * 28))
                        acc = (t_out.argmax(dim=1) == ty).float().mean().item()
                        print(acc)
                        break  # only the first 1000 test images

    t.save(model, './model.pkl')  # save the whole network (pickle)
    t.save(model.state_dict(), './model_params.pkl')  # parameters only: faster, smaller

    # Alternative: load parameters into a fresh instance
    # net = DNN()
    # net.load_state_dict(t.load('./model_params.pkl'))
    # net.eval()

    # Load the whole pickled model. NOTE(review): t.load unpickles arbitrary
    # objects — only load checkpoint files you trust.
    net = t.load('./model.pkl')
    net.eval()
    with t.no_grad():
        for tx, ty in testloader:
            t_out = net(tx.view(-1, 28 * 28))
            acc = (t_out.argmax(dim=1) == ty).float().mean().item()
            print(acc)


if __name__ == "__main__":
    train()
輸出結果
9 500 0.42454192
0.875
9 600 0.4553349
0.888
0.876
0.868
0.868
0.881
0.864
0.87
0.87
0.854
0.871
0.879