完整的pytorch教程(mnist 為例)
'''
設定不同層的學習率、更新學習率(對應更新)、輸出中間層特徵、建立網路、儲存網路、測試網路效果、初始化
A whole Pytorch tutorial : set different layer's lr , update lr (One to one correspondence)
output middle layer's feature and fine-tune
'''
import torch
import torchvision
import numpy as np
from torch import nn
from torch.nn import init
from torch.autograd import Variable
from torch.utils import data
EPOCH=20
BATCH_SIZE=64
LR=1e-4
# mnist download,transform NHWC=>NCHW and 0,255=>0,1
train_data=torchvision.datasets.MNIST(root='./mnist',train=True,
transform=torchvision.transforms.ToTensor(),download=False)
# pytorch's dataset loader
train_loader=data.DataLoader(train_data,batch_size=BATCH_SIZE,shuffle=True)
# test data
test_data=torchvision.datasets.MNIST(root='./mnist',train=False)
# test Variable need transform gpu
test_x=Variable(torch.unsqueeze(test_data.test_data,dim=1).type(torch.FloatTensor)).cuda()/255
test_y=test_data.test_labels.cuda()
# create model
class CNN(nn.Module):
def __init__(self):
super().__init__()
self.conv1=nn.Sequential(
nn.Conv2d(in_channels=1,out_channels=16,kernel_size=4,stride=1,padding=2),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2,stride=2))
self.conv2=nn.Sequential(nn.Conv2d(16,32,4,1,2),nn.ReLU(),nn.MaxPool2d(2,2))
self.out=nn.Linear(32*7*7,10)
# init
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
nn.init.constant_(m.bias, 0)
def forward(self,x):
per_out=[]
x=self.conv1(x)
per_out.append(x)
x=self.conv2(x)
per_out.append(x)
x=x.view(x.size(0),-1)
output=self.out(x)
# can output middle layer's features
return output,per_out
cnn=CNN().cuda()# gpu
# set different layer's learning rate: [conv1 conv2] lr*10 ; [out] lr
def get_10x_lr_params(net):
b=[net.conv1,net.conv2]
for i in b:
for j in i.modules():
for k in j.parameters():
yield k
# generator
# fine-tune
new_params=cnn.state_dict()
pretrain_dict=torch.load('./model/model.pth')
pretrain_dict={k:v for k,v in pretrain_dict.items() if k in new_params and v.size()==new_params[k].size()}#dict gennerator
new_params.update(pretrain_dict)
cnn.load_state_dict(new_params)
cnn.train()# if you want test ,just modify cnn.eval()
# update lr
def lr_poly(base_lr,iters,max_iter,power):
return base_lr*((1-float(iters)/max_iter)**power)
def adjust_lr(optimizer,base_lr,iters,max_iter,power):
lr=lr_poly(base_lr,iters,max_iter,power)
optimizer.param_groups[0]['lr']=lr # first param iterator
if len(optimizer.param_groups)>1:
optimizer.param_groups[1]['lr']=lr*10
Params=get_10x_lr_params(cnn)
# optimizer first params lr=LR Internal overlapping external lr; second params
optimizer=torch.optim.Adam([{'params':cnn.out.parameters()},{'params':Params,'lr':LR*10}],lr=LR)
# loss function
loss_func=nn.CrossEntropyLoss().cuda()
iters=0
for epoch in range(EPOCH):
i_iter=train_data.train_data.shape[0]//BATCH_SIZE
for step,(x,y) in enumerate(train_loader):
optimizer.zero_grad()# clear gradient
adjust_lr(optimizer,LR,iters,EPOCH*i_iter,0.9)# update lr
iters+=1
b_x=Variable(x).cuda()# if channel==1 auto add c=1
b_y=Variable(y).cuda()
# print(cnn.state_dict()['conv1.0.weight'])
output=cnn(b_x)[0]
loss=loss_func(output,b_y)# Variable need to get .data
loss.backward() # backward loss
optimizer.step() # compute per gradient
if step%50==0:
test_output=cnn(test_x)[0]
pred_y=torch.max(test_output,1)[1].cuda().data.squeeze()
'''
why data ,because Variable .data to Tensor;and cuda() not to numpy() ,must to cpu
and to numpy and .float compute decimal
'''
accuracy=torch.sum(pred_y==test_y).data.float()/test_y.size(0)
print('EPOCH: ',epoch,'| train_loss:%.4f'%loss.data[0],'| test accuracy:%.2f'%accuracy)
# loss.data.cpu().numpy().item() get one value
torch.save(cnn.state_dict(),'./model/model.pth')
# test phase
test_output=cnn(test_x[:13])[0]
pred_y=torch.max(test_output,1)[1].cuda().data.squeeze()
print(pred_y)
print(test_y[:13])
相關推薦
完整的pytorch教程(mnist 為例)
''' 設定不同層的學習率、更新學習率(對應更新)、輸出中間層特徵、建立網路、儲存網路、測試網路效果、初始化 A whole Pytorch tutorial : set different layer's lr , update lr (One to one cor
Tensorflow建立資料集(mnist為例)
網上的mnist的demo大部分都是按照實戰google那本來的,但是那個在資料集的處理上用的是TensorFlow的官方api,我們在正常做標籤的時候並不一定要那樣做,本文講解了兩種標籤方式區別於實戰google的demo。 folder方式: ROOT_FOLDER |--------
pytorch建立自己的資料集(以mnist為例)
本文將原始的numpy array資料在pytorch下封裝為Dataset類的資料集,為後續深度網路訓練提供資料。 載入並儲存影象資訊 首先匯入需要的庫,定義各種路徑。 import os import matplotlib from keras.datase
Keras中實現模型載入與測試(以mnist為例)
需要安裝cv2 安裝h5py的命令如下(模型載入模組): sudo pip install cython sudo apt-get install libhdf5-dev sudo pip
多數據源動態關聯報表的制作(birt為例)
處理 center 關閉 主表 等價 兩個 數據 fonts img 使用Jasper或BIRT等報表工具時,常會碰到一些很規的統計,用報表工具本身或SQL都難以處理,比方與主表相關的子表分布在多個數據庫中,報表要展現這些數據源動態關聯的結果。集算器具
linux mount掛載設備(u盤,光盤,iso等 )使用說明(Centos為例)
centos mount 掛載方法集錦 linux mount 對於新手學習,mount 命令,一定會有很多疑問。其實我想疑問來源更多的是對linux系統本身特殊性了解問題。 linux是基於文件系統,所有的設備都會對應於:/dev/下面的設備。如:[chengmo@centos5 dev]$ l
對於單個模型(長方體為例)進行面投影時的消隱
return .com www ++ 類的繼承 投影 逆時針 所有 順序 作者:feiquan 出處:http://www.cnblogs.com/feiquan/ 版權聲明:本文版權歸作者和博客園共有,歡迎轉載,但未經作者同意必須保留此段聲明,且在文章頁面明顯位置給
【總結】遊戲框架與架構設計(Unity為例)
單機 業務 github 事件 概念 lec 集合 架構模式 wid 使用框架開發遊戲 優點:耦合性低,重用性高,部署快,可維護性高,方便管理。提高開發效率,降低開發難度 缺點:增加了系統結構和實現的復雜性,需要額外花費精力維護,不適合小型程序,易影響運行效率 常見
KEGG下載某物種最新的版本信息(斑馬魚為例)
ref nbsp wid 結構 egg 解析 image 版本 ima 步驟一:打開鏈接並選擇物種 http://www.genome.jp/kegg-bin/get_htext?hsa00001+3101 步驟二:對文件進行解析 步驟三:統計信息 一級結構(6大類):
oracle資料庫的徹底解除安裝(11g為例)
1. 關閉oracle所有的服務。可以在windows的服務管理器中關閉;2. 開啟登錄檔:regedit 開啟路徑: HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\Services\ 刪除該路徑下的所有以oracle開始的服務名稱,這個鍵是標
LINUX安裝軟體(FFmpeg為例)
https://trac.ffmpeg.org/wiki/CompilationGuide/Generic This page provides some generic instructions for compiling a project starting from the source
Java-Spring框架實現簡單的檔案上傳(圖片為例)
一、開發環境搭建 下載相應的jar包: 1.commons-fileupload 2.commons-io web.xml 檔案配置: <servlet> &
VLC 模組構造巨集的展開(access_output_http為例)
巨集的定義: vlc_module_begin () set_description( N_("HTTP stream output") ) set_capability( "sout access", 0 )
簡單明瞭的nftables防火牆配置(arch為例)
Arch Linux的核心已經包含了netfilter包過濾框架。 在/etc/nftables.conf預設包含著一個簡單的防火牆設定,但過於簡單, 現在重新編寫nft的設定(這裡列舉的規則適合個人電腦,伺服器或是其它的機器可以參考其它資料配置更加適合的規則)。 # nft list r
簡單明了的nftables防火墻配置(arch為例)
設置 包含 hab 資料 監聽 php style lis 一個 Arch Linux的內核已經包含了netfilter包過濾框架。 在/etc/nftables.conf默認包含著一個簡單的防火墻設置,但過於簡單, 現在重新編寫nft的設置(這裏列舉的規則適合個人電腦,服
RHCS實現高可用中的共享儲存iscisi(mysql為例)
1、實驗環境 server2 172.25.66.2(配置Nginx、ricci、luci) server3 172.25.66.3(Apache) server4 172.25.66.4 (Apache) server5 172.25.66.5(配置Nginx
git中的ssh和https方式的使用(gitee為例)
在使用git管理程式碼,或者使用github,國內的碼雲(gitee)的時候,有兩種方式可以使用,分別是https和ssh,以下均使用gitee為例。 ssh方式 配置ssh,如果不配置ssh的話,clone專案的時候會
構建有向無環圖(DAG)模型解決矩形巢狀問題 以(nyoj16為例)
DAG(Directed Acyclic Graph):在圖論中,如果一個有向圖無法從某個頂點出發經過若干條邊回到該點,則這個圖是一個有向無環圖(DAG圖)。有向無環圖上的動態規劃是學習動態規劃的基礎。很多問題都可以轉化為DAG上的最長路和最短路或計數問題。 分析
Python爬蟲 抓取大資料崗位招聘資訊(51job為例)
簡單介紹一下爬蟲原理。並給出 51job網站完整的爬蟲方案。 爬蟲基礎知識 資料來源 網路爬蟲的資料一般都來自伺服器的響應結果,通常有html和json資料等,這兩種資料也是網路爬蟲的主要資料來源。 其中html資料是網頁的原始碼,通過瀏覽器-檢視原始碼可
兩個及以上塊級元素(div為例)排列一行
1.浮動解決(float)-------> _ >萬能 demo程式碼: <!DOCTYPE html> <html> <head lang="en"> <meta charset="UTF-8">