1. 程式人生 > >完整的pytorch教程(mnist 為例)

完整的pytorch教程(mnist 為例)

'''

設定不同層的學習率、更新學習率(對應更新)、輸出中間層特徵、建立網路、儲存網路、測試網路效果、初始化

A whole Pytorch tutorial : set different layer's lr , update lr (One to one correspondence)
                           output middle layer's feature and fine-tune

'''
import torch
import torchvision
import numpy as np
from torch import nn
from torch.nn import init
from torch.autograd import Variable
from torch.utils import data

EPOCH=20
BATCH_SIZE=64
LR=1e-4

# mnist download,transform NHWC=>NCHW and 0,255=>0,1
train_data=torchvision.datasets.MNIST(root='./mnist',train=True,
                                      transform=torchvision.transforms.ToTensor(),download=False)
# pytorch's dataset loader
train_loader=data.DataLoader(train_data,batch_size=BATCH_SIZE,shuffle=True)
# test data
test_data=torchvision.datasets.MNIST(root='./mnist',train=False)
# test Variable  need transform  gpu
test_x=Variable(torch.unsqueeze(test_data.test_data,dim=1).type(torch.FloatTensor)).cuda()/255
test_y=test_data.test_labels.cuda()
# create model
class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1=nn.Sequential(
                nn.Conv2d(in_channels=1,out_channels=16,kernel_size=4,stride=1,padding=2),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2,stride=2))
        self.conv2=nn.Sequential(nn.Conv2d(16,32,4,1,2),nn.ReLU(),nn.MaxPool2d(2,2))
        self.out=nn.Linear(32*7*7,10)
        
        # init
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)
        
    def forward(self,x):
        per_out=[]
        x=self.conv1(x)
        per_out.append(x)
        x=self.conv2(x)
        per_out.append(x)
        x=x.view(x.size(0),-1)
        output=self.out(x)
        # can output middle layer's features
        return output,per_out
cnn=CNN().cuda()# gpu

# set different layer's learning rate: [conv1 conv2] lr*10 ; [out]  lr
def get_10x_lr_params(net):
    b=[net.conv1,net.conv2]
    for i in b:
        for j in i.modules():
            for k in j.parameters():
                yield k
                # generator
# fine-tune
new_params=cnn.state_dict()
pretrain_dict=torch.load('./model/model.pth')
pretrain_dict={k:v for k,v in pretrain_dict.items() if k in new_params and v.size()==new_params[k].size()}#dict gennerator
new_params.update(pretrain_dict)
cnn.load_state_dict(new_params)

cnn.train()# if you want test ,just modify cnn.eval()
         
# update lr
def lr_poly(base_lr,iters,max_iter,power):
    return base_lr*((1-float(iters)/max_iter)**power)
def adjust_lr(optimizer,base_lr,iters,max_iter,power):
    lr=lr_poly(base_lr,iters,max_iter,power)
    optimizer.param_groups[0]['lr']=lr # first param iterator
    if len(optimizer.param_groups)>1:
        optimizer.param_groups[1]['lr']=lr*10

Params=get_10x_lr_params(cnn)
# optimizer             first params  lr=LR Internal overlapping external lr; second params
optimizer=torch.optim.Adam([{'params':cnn.out.parameters()},{'params':Params,'lr':LR*10}],lr=LR)
# loss function
loss_func=nn.CrossEntropyLoss().cuda()

iters=0
for epoch in range(EPOCH):
    i_iter=train_data.train_data.shape[0]//BATCH_SIZE
    for step,(x,y) in enumerate(train_loader):
        optimizer.zero_grad()# clear gradient
        adjust_lr(optimizer,LR,iters,EPOCH*i_iter,0.9)# update lr
        iters+=1
        b_x=Variable(x).cuda()# if channel==1 auto add c=1
        b_y=Variable(y).cuda()
#        print(cnn.state_dict()['conv1.0.weight'])
        output=cnn(b_x)[0]
        loss=loss_func(output,b_y)# Variable need to get .data
        loss.backward() # backward loss
        optimizer.step() # compute per gradient
        
        if step%50==0:
            test_output=cnn(test_x)[0]
            pred_y=torch.max(test_output,1)[1].cuda().data.squeeze()
            '''
            why data ,because Variable .data to Tensor;and cuda() not to numpy() ,must to cpu
            and to numpy and .float compute decimal
            '''
            accuracy=torch.sum(pred_y==test_y).data.float()/test_y.size(0)
            print('EPOCH: ',epoch,'| train_loss:%.4f'%loss.data[0],'| test accuracy:%.2f'%accuracy)
        #                                           loss.data.cpu().numpy().item() get one value
    torch.save(cnn.state_dict(),'./model/model.pth')
# test phase
test_output=cnn(test_x[:13])[0]
pred_y=torch.max(test_output,1)[1].cuda().data.squeeze()
print(pred_y)
print(test_y[:13])

相關推薦

完整pytorch教程mnist

''' 設定不同層的學習率、更新學習率(對應更新)、輸出中間層特徵、建立網路、儲存網路、測試網路效果、初始化 A whole Pytorch tutorial : set different layer's lr , update lr (One to one cor

Tensorflow建立資料集mnist

網上的mnist的demo大部分都是按照實戰google那本來的,但是那個在資料集的處理上用的是TensorFlow的官方api,我們在正常做標籤的時候並不一定要那樣做,本文講解了兩種標籤方式區別於實戰google的demo。 folder方式: ROOT_FOLDER |--------

pytorch建立自己的資料集mnist

本文將原始的numpy array資料在pytorch下封裝為Dataset類的資料集,為後續深度網路訓練提供資料。 載入並儲存影象資訊 首先匯入需要的庫,定義各種路徑。 import os import matplotlib from keras.datase

Keras中實現模型載入與測試mnist

 需要安裝cv2 安裝h5py的命令如下(模型載入模組): sudo pip install cython sudo apt-get install libhdf5-dev sudo pip 

多數據源動態關聯報表的制作birt

處理 center 關閉 主表 等價 兩個 數據 fonts img 使用Jasper或BIRT等報表工具時,常會碰到一些很規的統計,用報表工具本身或SQL都難以處理,比方與主表相關的子表分布在多個數據庫中,報表要展現這些數據源動態關聯的結果。集算器具

linux mount掛載設備u盤,光盤,iso等 使用說明Centos

centos mount 掛載方法集錦 linux mount 對於新手學習,mount 命令,一定會有很多疑問。其實我想疑問來源更多的是對linux系統本身特殊性了解問題。 linux是基於文件系統,所有的設備都會對應於:/dev/下面的設備。如:[chengmo@centos5 dev]$ l

對於單個模型長方體進行面投影時的消隱

return .com www ++ 類的繼承 投影 逆時針 所有 順序 作者:feiquan 出處:http://www.cnblogs.com/feiquan/ 版權聲明:本文版權歸作者和博客園共有,歡迎轉載,但未經作者同意必須保留此段聲明,且在文章頁面明顯位置給

【總結】遊戲框架與架構設計Unity

單機 業務 github 事件 概念 lec 集合 架構模式 wid 使用框架開發遊戲 優點:耦合性低,重用性高,部署快,可維護性高,方便管理。提高開發效率,降低開發難度 缺點:增加了系統結構和實現的復雜性,需要額外花費精力維護,不適合小型程序,易影響運行效率 常見

KEGG下載某物種最新的版本信息斑馬魚

ref nbsp wid 結構 egg 解析 image 版本 ima 步驟一:打開鏈接並選擇物種 http://www.genome.jp/kegg-bin/get_htext?hsa00001+3101 步驟二:對文件進行解析 步驟三:統計信息 一級結構(6大類):

oracle資料庫的徹底解除安裝11g

1. 關閉oracle所有的服務。可以在windows的服務管理器中關閉;2. 開啟登錄檔:regedit 開啟路徑: HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\Services\ 刪除該路徑下的所有以oracle開始的服務名稱,這個鍵是標

LINUX安裝軟體FFmpeg

https://trac.ffmpeg.org/wiki/CompilationGuide/Generic This page provides some generic instructions for compiling a project starting from the source

Java-Spring框架實現簡單的檔案上傳圖片

一、開發環境搭建 下載相應的jar包: 1.commons-fileupload    2.commons-io web.xml 檔案配置: <servlet>       &

VLC 模組構造巨集的展開access_output_http

巨集的定義: vlc_module_begin ()     set_description( N_("HTTP stream output") )     set_capability( "sout access", 0 )    

簡單明瞭的nftables防火牆配置arch

Arch Linux的核心已經包含了netfilter包過濾框架。 在/etc/nftables.conf預設包含著一個簡單的防火牆設定,但過於簡單, 現在重新編寫nft的設定(這裡列舉的規則適合個人電腦,伺服器或是其它的機器可以參考其它資料配置更加適合的規則)。   # nft list r

簡單明了的nftables防火墻配置arch

設置 包含 hab 資料 監聽 php style lis 一個 Arch Linux的內核已經包含了netfilter包過濾框架。 在/etc/nftables.conf默認包含著一個簡單的防火墻設置,但過於簡單, 現在重新編寫nft的設置(這裏列舉的規則適合個人電腦,服

RHCS實現高可用中的共享儲存iscisimysql

1、實驗環境 server2 172.25.66.2(配置Nginx、ricci、luci) server3 172.25.66.3(Apache) server4 172.25.66.4 (Apache) server5 172.25.66.5(配置Nginx

git中的ssh和https方式的使用gitee

      在使用git管理程式碼,或者使用github,國內的碼雲(gitee)的時候,有兩種方式可以使用,分別是https和ssh,以下均使用gitee為例。 ssh方式    配置ssh,如果不配置ssh的話,clone專案的時候會

構建有向無環圖DAG模型解決矩形巢狀問題 以nyoj16

 DAG(Directed Acyclic Graph):在圖論中,如果一個有向圖無法從某個頂點出發經過若干條邊回到該點,則這個圖是一個有向無環圖(DAG圖)。有向無環圖上的動態規劃是學習動態規劃的基礎。很多問題都可以轉化為DAG上的最長路和最短路或計數問題。 分析

Python爬蟲 抓取大資料崗位招聘資訊51job

簡單介紹一下爬蟲原理。並給出 51job網站完整的爬蟲方案。 爬蟲基礎知識 資料來源 網路爬蟲的資料一般都來自伺服器的響應結果,通常有html和json資料等,這兩種資料也是網路爬蟲的主要資料來源。 其中html資料是網頁的原始碼,通過瀏覽器-檢視原始碼可

兩個及以上塊級元素div排列一行

1.浮動解決(float)-------> _ >萬能 demo程式碼: <!DOCTYPE html> <html> <head lang="en"> <meta charset="UTF-8">