torch01:torch基礎
MachineLP的部落格目錄:小鵬的部落格目錄
本小節介紹torch的基礎操作和流程:
(1)計算表示式的梯度值。
(2)陣列與tensor。
(3)構建輸入管道。
(4)載入預訓練的模型。
(5)儲存和載入權重。
---------------------------------我是可愛的分割線---------------------------------
程式碼部分:
(0)import
# coding=utf-8 import torch import torchvision import torch.nn as nn import numpy as np import torchvision.transforms as transforms print (torch.__version__)
(1)計算梯度值
# 建立tensor x = torch.tensor(1, requires_grad=True) w = torch.tensor(2, requires_grad=True) b = torch.tensor(3, requires_grad=True) # 構建模型, 建立計算圖 y = w * x + b # 計算梯度 y.backward() # 輸出計算後的梯度值 print ('x:grad', x.grad) print ('w:grad', w.grad) print ('b:grad', b.grad) # 建立兩個tensor x = torch.randn(10, 3) y = torch.randn(10, 2) # 搭建全連線層 linear = nn.Linear(3,2) # 列印模型權重值 print ('w', linear.weight) print ('b', linear.bias) # 構建你需要的損失函式和優化演算法 criterion = nn.MSELoss() optimizer = torch.optim.SGD(linear.parameters(), lr=0.01) # 前向計算 pred = linear(x) # 計算loss loss = criterion(pred, y) print('loss: ', loss.item()) loss.backward() # 列印輸出梯度 print ('dL/dw: ', linear.weight.grad) print ('dL/db: ', linear.bias.grad) # 梯度下降 optimizer.step() # 梯度下降後,再列印權重值就會減小。 print ('w', linear.weight) print ('b', linear.bias) # 梯度下降後的預測值和loss pred = linear(x) loss = criterion(pred, y) print('loss after 1 step optimization: ', loss.item())
(2)陣列與tensor。
# 建立陣列, 轉陣列為tensor
x = np.array([[1, 2], [3, 4]])
y = torch.from_numpy(x)
# 轉tensor為陣列
z = y.numpy()
(3)構建輸入管道。
# 下載 CIFAR-10 資料 train_dataset = torchvision.datasets.CIFAR10(root='../../data/', train=True, transform=transforms.ToTensor(), download=True) # 樣本和標籤 image, label = train_dataset[0] print (image.size()) print (label) # 通過佇列的形式載入資料 train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=64, shuffle=True) # 建立迭代器,為每次訓練提供訓練資料 data_iter = iter(train_loader) # Mini-batch 樣本和標籤 images, labels = data_iter.next() # 另外一種方式 for images, labels in train_loader: # 訓練的程式碼 pass
# 在你自己的資料上構建高效資料載入的方式
class CustomDataset(torch.utils.data.Dataset):
def __init__(self):
# TODO
# 1. Initialize file paths or a list of file names.
pass
def __getitem__(self, index):
# TODO
# 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).
# 2. Preprocess the data (e.g. torchvision.Transform).
# 3. Return a data pair (e.g. image and label).
pass
def __len__(self):
# You should change 0 to the total size of your dataset.
return 0
#
custom_dataset = CustomDataset()
train_loader = torch.utils.data.DataLoader(dataset=custom_dataset,
batch_size=64,
shuffle=True)
(4)載入預訓練的模型。
# 下載和載入預訓練的模型ResNet-18.
resnet = torchvision.models.resnet18(pretrained=True)
# 只進行fine-tune top層:
for param in resnet.parameters():
param.requires_grad = False
# Replace the top layer for finetuning.
resnet.fc = nn.Linear(resnet.fc.in_features, 100) # 100 is an example.
# Forward pass.
images = torch.randn(64, 3, 224, 224)
outputs = resnet(images)
print (outputs.size()) # (64, 100)
(5)儲存和載入權重。
# 儲存和載入模型
torch.save(resnet, 'model.ckpt')
model = torch.load('model.ckpt')
# 只儲存和載入模型引數
torch.save(resnet.state_dict(), 'params.ckpt')
resnet.load_state_dict(torch.load('params.ckpt'))
---------------------------------我是可愛的分割線---------------------------------
總結:
加餐:
在資料上進行載入資料:
其中,train.txt中的資料格式:
gender/0male/0(2).jpg 1
gender/0male/0(3).jpeg 1
gender/0male/0(1).jpg 0
# coding = utf-8
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import matplotlib.pyplot as plt
from PIL import Image
def default_loader(path):
# 注意要保證每個batch的tensor大小時候一樣的。
return Image.open(path).convert('RGB')
class MyDataset(Dataset):
def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
fh = open(txt, 'r')
imgs = []
for line in fh:
line = line.strip('\n')
# line = line.rstrip()
words = line.split(' ')
imgs.append((words[0],int(words[1])))
self.imgs = imgs
self.transform = transform
self.target_transform = target_transform
self.loader = loader
def __getitem__(self, index):
fn, label = self.imgs[index]
img = self.loader(fn)
if self.transform is not None:
img = self.transform(img)
return img,label
def __len__(self):
return len(self.imgs)
def get_loader(dataset='train.txt', crop_size=178, image_size=128, batch_size=2, mode='train', num_workers=1):
"""Build and return a data loader."""
transform = []
if mode == 'train':
transform.append(transforms.RandomHorizontalFlip())
transform.append(transforms.CenterCrop(crop_size))
transform.append(transforms.Resize(image_size))
transform.append(transforms.ToTensor())
transform.append(transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
transform = transforms.Compose(transform)
train_data=MyDataset(txt=dataset, transform=transform)
data_loader = DataLoader(dataset=train_data,
batch_size=batch_size,
shuffle=(mode=='train'),
num_workers=num_workers)
return data_loader
# 注意要保證每個batch的tensor大小時候一樣的。
# data_loader = DataLoader(train_data, batch_size=2,shuffle=True)
data_loader = get_loader('train.txt')
print(len(data_loader))
def show_batch(imgs):
grid = utils.make_grid(imgs)
plt.imshow(grid.numpy().transpose((1, 2, 0)))
plt.title('Batch from dataloader')
for i, (batch_x, batch_y) in enumerate(data_loader):
if(i<4):
print(i, batch_x.size(),batch_y.size())
show_batch(batch_x)
plt.axis('off')
plt.show()
總結:
以上是torch的基礎部分,總體的流程已經有了,上手就很快了。
相關推薦
torch01:torch基礎
MachineLP的部落格目錄:小鵬的部落格目錄本小節介紹torch的基礎操作和流程:(1)計算表示式的梯度值。(2)陣列與tensor。(3)構建輸入管道。(4)載入預訓練的模型。(5)儲存和載入權重。---------------------------------我是可
Day1:計算機基礎
允許 ade 模式 劃分 width 指針 bsp 方法 是我 今天是正式上課的第一天,聽瞎驢老師講課還是很容易聽懂的。雖然接觸計算機比較早,大學也學過一點相關內容,今天的課也是很有收獲的,需要一定的時間來整理記錄一下今天所學的東西。 一、編程語言的作用
資源整合:java基礎課程第一天
nds his java 不能 1.7 public hello 單行 bst jdk的安裝配置 (1)JAVA_HOME E:\Java\jdk1.7.0 (2) path %JAVA_HOME%\bin; CMD
nodejs零基礎詳細教程1:安裝+基礎概念
img res 安裝過程 pkg 實時 linkedin 圖標 過程 好的 第一章 建議學習時間2小時 課程共10章 學習方式:詳細閱讀,並手動實現相關代碼 學習目標:此教程將教會大家 安裝Node、搭建服務器、express、mysql、mongodb、編寫後臺業務邏輯
Python開發【第六篇】:Python基礎條件和循環
ora back strong als 重復執行 操作 enume 條件表達式 服務 目錄 一、if語句 1、功能 2、語法 單分支,單重條件判斷 多分支,多重條件判斷 if + else 多分支if + elif + else 語句小結 + 案例 三元表達式 二、whil
Python開發【第五篇】:Python基礎之2
對齊方式 dex 字符串 後退 ring lag nic 有效 func 字符串格式化 Python的字符串格式化有兩種方式: 百分號方式、format方式 百分號的方式相對來說比較老,而format方式則是比較先進的方式,企圖替換古老的方式,目前兩者並存。[PEP-310
Python開發【第四篇】:Python基礎之函數
nco pos *args 更強 三元 sequence hunk ins att 三元運算 三元運算(三目運算),是對簡單的條件語句的縮寫。 # 書寫格式 result = 值1 if 條件 else 值2 # 如果條件成立,那麽將 “值1” 賦值給result
css3動畫效果:1基礎
prop 包含 rop lin tex color 變換 百分比 css屬性 css動畫分兩種:過渡效果transition 、關鍵幀動畫keyframes 一、過渡效果transition 需觸發一個事件(如hover、click)時,才改變其css屬性。 過渡效果通常在
1:python基礎
python基礎 spa 靜態 clas 編程 編譯器 mar 編譯 gin python基礎 參考:https://www.cnblogs.com/alex3714/articles/5465198.html 編程語言主要從以下幾個角度為進行分類,編譯型和解釋型、靜態語
httpd學習:http基礎
http html.js http:hyper text transfer protocol,80tcphtml: 編程語言,超文本標記語言;CSS: Cascading Style Sheet,層疊樣式表js:javascript,客戶端腳本MIME: Multipurpose Internet M
python數據處理:pandas基礎
log eat ges 處理 保留 sed lang sce rop 本文資料來源: Python for Data Anylysis: Chapter 5 10 mintues to pandas: http://pandas.pydata.org/pandas-
openstack項目【day23】:glance基礎
/var/ 默認 write 強調 alt p s 版本 星期 .cn 本節內容 一 什麽是glance 二 為何要有glance 三 glance的功能 四 glance的兩個版本 五 鏡像的數據存放 六 鏡像的訪問權限 七 鏡像及任務的各種狀態 八 glance包含的
linux常用命令整理(五):shell基礎
程序猿 逆向 多條 希望 正則表達 group 運行 ls命令 交互式 大家好,我是會唱歌的程序猿~~~~~~ 最近在學習linux,閑暇之余就把這些基本的命令進行了整理,希望大家能用的上,整理的的目的是在忘了的時候翻出來看看^?_?^,前後一共分為五個部分
第五篇:python基礎_5
執行過程 ini 間接 ray 復雜 func 時間 基於 time 本篇內容 協程函數 遞歸 二分法 import語句 from...import語句 模塊搜索路徑 包的導入 軟件開發規範 logging模塊的使用 一、 協程函數 1.定義 協程函數就是使用了y
python之路:python基礎3
bar 匿名函數 發送 函數式 edit 系統 概念 作用域 opened ---恢復內容開始--- 本節內容 1. 函數基本語法及特性 2. 參數與局部變量 3. 返回值 嵌套函數 4.遞歸 5.匿名函數 6.函數式編程介紹 7.高階函數 8.內置函數 溫故知新 1.
前端知識學習一 :CSS基礎
分隔 color html元素 http 方式 瀏覽器 單位 工作 分離 一.CSS概述 css指的是層疊樣式表,樣式定義如何顯示HTML元素,樣式通常存儲在樣式表中, 把樣式添加到HTML4.0中,是為了解決內容和表現分離的問題。外部樣式表通常存儲在css文件
轉:JMeter基礎之一 一個簡單的性能測試
cat 自動生成 html enc 兩個 導致 自己的 線程數 網絡 QPS 解釋 QPS : Query Per Second 每秒查詢率。是一臺查詢服務器每秒能夠處理的查詢次數。在因特網上,作為域名系統服務器的機器的性能經常用每秒查詢率來衡量。 為了達成預期
第八篇:python基礎_8 面向對象與網絡編程
pro size 賬單 socket 基礎 發生 多態 proc client 本篇內容 接口與歸一化設計 多態與多態性 封裝 面向對象高級 異常處理 網絡編程 一、 接口與歸一化設計 1.定義 (1)歸一化讓使用者無需關心對象的類是什麽,只需要知道這些對象都具備某
課堂筆記:Python基礎-字典
更新 tabs with numeric ide rfi form pda [] Python字典的兩大特點:無序、鍵唯一 #字典創建dic={‘name‘:‘alex‘} #第一種形式 dic2=dict(((‘name‘,‘alex‘),)) #
課堂筆記:Python基礎-文件操作
def cas elf 擴展 中文 tell enum new span 對文件操作流程 打開文件,得到文件句柄並賦值給一個變量 通過句柄對文件進行操作 關閉文件 現有文件如下: 昨夜寒蛩不住鳴。 驚回千裏夢,已三更。 起來獨自繞階行。 人悄悄,簾外月朧明。