1. 程式人生 > >torch01:torch基礎

torch01:torch基礎

MachineLP的部落格目錄:小鵬的部落格目錄

本小節介紹torch的基礎操作和流程:

(1)計算表示式的梯度值。

(2)陣列與tensor。

(3)構建輸入管道。

(4)載入預訓練的模型。

(5)儲存和載入權重。

---------------------------------我是可愛的分割線---------------------------------

程式碼部分:

(0)import

# coding=utf-8
import torch
import torchvision
import torch.nn as nn
import numpy as np 
import torchvision.transforms as transforms

print (torch.__version__)

(1)計算梯度值

# 建立tensor
x = torch.tensor(1, requires_grad=True)
w = torch.tensor(2, requires_grad=True)
b = torch.tensor(3, requires_grad=True)

# 構建模型, 建立計算圖
y = w * x + b

# 計算梯度
y.backward()

# 輸出計算後的梯度值
print ('x:grad', x.grad)
print ('w:grad', w.grad)
print ('b:grad', b.grad)

# 建立兩個tensor
x = torch.randn(10, 3)
y = torch.randn(10, 2)

# 搭建全連線層
linear = nn.Linear(3,2)

# 列印模型權重值
print ('w', linear.weight)
print ('b', linear.bias)

# 構建你需要的損失函式和優化演算法
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(linear.parameters(), lr=0.01)

# 前向計算
pred = linear(x)


# 計算loss
loss = criterion(pred, y)
print('loss: ', loss.item())

loss.backward()
# 列印輸出梯度
print ('dL/dw: ', linear.weight.grad) 
print ('dL/db: ', linear.bias.grad)

# 梯度下降
optimizer.step()

# 梯度下降後,再列印權重值就會減小。
print ('w', linear.weight)
print ('b', linear.bias)


# 梯度下降後的預測值和loss
pred = linear(x)
loss = criterion(pred, y)
print('loss after 1 step optimization: ', loss.item())

(2)陣列與tensor。

# 建立陣列, 轉陣列為tensor
x = np.array([[1, 2], [3, 4]])
y = torch.from_numpy(x)
# 轉tensor為陣列
z = y.numpy()

(3)構建輸入管道。

# 下載 CIFAR-10 資料
train_dataset = torchvision.datasets.CIFAR10(root='../../data/',
                                             train=True, 
                                             transform=transforms.ToTensor(),
                                             download=True)

# 樣本和標籤
image, label = train_dataset[0]
print (image.size())
print (label)

# 通過佇列的形式載入資料
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=64, 
                                           shuffle=True)

# 建立迭代器,為每次訓練提供訓練資料
data_iter = iter(train_loader)

# Mini-batch 樣本和標籤
images, labels = data_iter.next()

# 另外一種方式
for images, labels in train_loader:
    # 訓練的程式碼
    pass
# 在你自己的資料上構建高效資料載入的方式
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self):
        # TODO
        # 1. Initialize file paths or a list of file names. 
        pass
    def __getitem__(self, index):
        # TODO
        # 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).
        # 2. Preprocess the data (e.g. torchvision.Transform).
        # 3. Return a data pair (e.g. image and label).
        pass
    def __len__(self):
        # You should change 0 to the total size of your dataset.
        return 0 

# 
custom_dataset = CustomDataset()
train_loader = torch.utils.data.DataLoader(dataset=custom_dataset,
                                           batch_size=64, 
                                           shuffle=True)

(4)載入預訓練的模型。

# 下載和載入預訓練的模型ResNet-18.
resnet = torchvision.models.resnet18(pretrained=True)

# 只進行fine-tune top層:
for param in resnet.parameters():
    param.requires_grad = False

# Replace the top layer for finetuning.
resnet.fc = nn.Linear(resnet.fc.in_features, 100)  # 100 is an example.

# Forward pass.
images = torch.randn(64, 3, 224, 224)
outputs = resnet(images)
print (outputs.size())     # (64, 100)

(5)儲存和載入權重。

# 儲存和載入模型
torch.save(resnet, 'model.ckpt')
model = torch.load('model.ckpt')

# 只儲存和載入模型引數
torch.save(resnet.state_dict(), 'params.ckpt')
resnet.load_state_dict(torch.load('params.ckpt'))

---------------------------------我是可愛的分割線---------------------------------

總結:

加餐:

在資料上進行載入資料:

其中,train.txt中的資料格式:

gender/0male/0(2).jpg 1
gender/0male/0(3).jpeg 1
gender/0male/0(1).jpg 0

# coding = utf-8
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import matplotlib.pyplot as plt
from PIL import Image


def default_loader(path):
    # 注意要保證每個batch的tensor大小時候一樣的。
    return Image.open(path).convert('RGB')


class MyDataset(Dataset):
    def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
        fh = open(txt, 'r')
        imgs = []
        for line in fh:
            line = line.strip('\n')
            # line = line.rstrip()
            words = line.split(' ')
            imgs.append((words[0],int(words[1])))
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader
    
    def __getitem__(self, index):
        fn, label = self.imgs[index]
        img = self.loader(fn)
        if self.transform is not None:
            img = self.transform(img)
        return img,label
    
    def __len__(self):
        return len(self.imgs)

def get_loader(dataset='train.txt', crop_size=178, image_size=128, batch_size=2, mode='train', num_workers=1):
    """Build and return a data loader."""
    transform = []
    if mode == 'train':
        transform.append(transforms.RandomHorizontalFlip())
    transform.append(transforms.CenterCrop(crop_size))
    transform.append(transforms.Resize(image_size))
    transform.append(transforms.ToTensor())
    transform.append(transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
    transform = transforms.Compose(transform)
    train_data=MyDataset(txt=dataset, transform=transform)
    data_loader = DataLoader(dataset=train_data,
                                  batch_size=batch_size,
                                  shuffle=(mode=='train'),
                                  num_workers=num_workers)
    return data_loader
# 注意要保證每個batch的tensor大小時候一樣的。
# data_loader = DataLoader(train_data, batch_size=2,shuffle=True)
data_loader = get_loader('train.txt')
print(len(data_loader))

def show_batch(imgs):
    grid = utils.make_grid(imgs)
    plt.imshow(grid.numpy().transpose((1, 2, 0)))
    plt.title('Batch from dataloader')


for i, (batch_x, batch_y) in enumerate(data_loader):
    if(i<4):
        print(i, batch_x.size(),batch_y.size())
        show_batch(batch_x)
        plt.axis('off')
        plt.show()

總結:

以上是torch的基礎部分,總體的流程已經有了,上手就很快了。

相關推薦

torch01torch基礎

MachineLP的部落格目錄:小鵬的部落格目錄本小節介紹torch的基礎操作和流程:(1)計算表示式的梯度值。(2)陣列與tensor。(3)構建輸入管道。(4)載入預訓練的模型。(5)儲存和載入權重。---------------------------------我是可

Day1計算機基礎

允許 ade 模式 劃分 width 指針 bsp 方法 是我 今天是正式上課的第一天,聽瞎驢老師講課還是很容易聽懂的。雖然接觸計算機比較早,大學也學過一點相關內容,今天的課也是很有收獲的,需要一定的時間來整理記錄一下今天所學的東西。 一、編程語言的作用

資源整合java基礎課程第一天

nds his java 不能 1.7 public hello 單行 bst jdk的安裝配置 (1)JAVA_HOME E:\Java\jdk1.7.0 (2) path %JAVA_HOME%\bin; CMD

nodejs零基礎詳細教程1安裝+基礎概念

img res 安裝過程 pkg 實時 linkedin 圖標 過程 好的 第一章 建議學習時間2小時 課程共10章 學習方式:詳細閱讀,並手動實現相關代碼 學習目標:此教程將教會大家 安裝Node、搭建服務器、express、mysql、mongodb、編寫後臺業務邏輯

Python開發【第六篇】Python基礎條件和循環

ora back strong als 重復執行 操作 enume 條件表達式 服務 目錄 一、if語句 1、功能 2、語法 單分支,單重條件判斷 多分支,多重條件判斷 if + else 多分支if + elif + else 語句小結 + 案例 三元表達式 二、whil

Python開發【第五篇】Python基礎之2

對齊方式 dex 字符串 後退 ring lag nic 有效 func 字符串格式化 Python的字符串格式化有兩種方式: 百分號方式、format方式 百分號的方式相對來說比較老,而format方式則是比較先進的方式,企圖替換古老的方式,目前兩者並存。[PEP-310

Python開發【第四篇】Python基礎之函數

nco pos *args 更強 三元 sequence hunk ins att 三元運算 三元運算(三目運算),是對簡單的條件語句的縮寫。 # 書寫格式 result = 值1 if 條件 else 值2 # 如果條件成立,那麽將 “值1” 賦值給result

css3動畫效果1基礎

prop 包含 rop lin tex color 變換 百分比 css屬性 css動畫分兩種:過渡效果transition 、關鍵幀動畫keyframes 一、過渡效果transition 需觸發一個事件(如hover、click)時,才改變其css屬性。 過渡效果通常在

1python基礎

python基礎 spa 靜態 clas 編程 編譯器 mar 編譯 gin python基礎 參考:https://www.cnblogs.com/alex3714/articles/5465198.html 編程語言主要從以下幾個角度為進行分類,編譯型和解釋型、靜態語

httpd學習http基礎

http html.js http:hyper text transfer protocol,80tcphtml: 編程語言,超文本標記語言;CSS: Cascading Style Sheet,層疊樣式表js:javascript,客戶端腳本MIME: Multipurpose Internet M

python數據處理pandas基礎

log eat ges 處理 保留 sed lang sce rop 本文資料來源:   Python for Data Anylysis: Chapter 5   10 mintues to pandas: http://pandas.pydata.org/pandas-

openstack項目【day23】glance基礎

/var/ 默認 write 強調 alt p s 版本 星期 .cn 本節內容 一 什麽是glance 二 為何要有glance 三 glance的功能 四 glance的兩個版本 五 鏡像的數據存放 六 鏡像的訪問權限 七 鏡像及任務的各種狀態 八 glance包含的

linux常用命令整理(五)shell基礎

程序猿 逆向 多條 希望 正則表達 group 運行 ls命令 交互式 大家好,我是會唱歌的程序猿~~~~~~ 最近在學習linux,閑暇之余就把這些基本的命令進行了整理,希望大家能用的上,整理的的目的是在忘了的時候翻出來看看^?_?^,前後一共分為五個部分

第五篇python基礎_5

執行過程 ini 間接 ray 復雜 func 時間 基於 time 本篇內容 協程函數 遞歸 二分法 import語句 from...import語句 模塊搜索路徑 包的導入 軟件開發規範 logging模塊的使用 一、 協程函數 1.定義 協程函數就是使用了y

python之路python基礎3

bar 匿名函數 發送 函數式 edit 系統 概念 作用域 opened ---恢復內容開始--- 本節內容 1. 函數基本語法及特性 2. 參數與局部變量 3. 返回值 嵌套函數 4.遞歸 5.匿名函數 6.函數式編程介紹 7.高階函數 8.內置函數 溫故知新 1.

前端知識學習一 CSS基礎

分隔 color html元素 http 方式 瀏覽器 單位 工作 分離 一.CSS概述     css指的是層疊樣式表,樣式定義如何顯示HTML元素,樣式通常存儲在樣式表中,   把樣式添加到HTML4.0中,是為了解決內容和表現分離的問題。外部樣式表通常存儲在css文件

JMeter基礎之一 一個簡單的性能測試

cat 自動生成 html enc 兩個 導致 自己的 線程數 網絡 QPS 解釋   QPS : Query Per Second 每秒查詢率。是一臺查詢服務器每秒能夠處理的查詢次數。在因特網上,作為域名系統服務器的機器的性能經常用每秒查詢率來衡量。   為了達成預期

第八篇python基礎_8 面向對象與網絡編程

pro size 賬單 socket 基礎 發生 多態 proc client 本篇內容 接口與歸一化設計 多態與多態性 封裝 面向對象高級 異常處理 網絡編程 一、 接口與歸一化設計 1.定義 (1)歸一化讓使用者無需關心對象的類是什麽,只需要知道這些對象都具備某

課堂筆記Python基礎-字典

更新 tabs with numeric ide rfi form pda [] Python字典的兩大特點:無序、鍵唯一 #字典創建dic={‘name‘:‘alex‘} #第一種形式 dic2=dict(((‘name‘,‘alex‘),)) #

課堂筆記Python基礎-文件操作

def cas elf 擴展 中文 tell enum new span 對文件操作流程 打開文件,得到文件句柄並賦值給一個變量 通過句柄對文件進行操作 關閉文件   現有文件如下:    昨夜寒蛩不住鳴。 驚回千裏夢,已三更。 起來獨自繞階行。 人悄悄,簾外月朧明。