1. 程式人生 > 其它 >PyTorch基礎知識學習

PyTorch基礎知識學習

學習了張量、自動求導和平行計算。

深度學習流程

1.張量

  • 0維張量=標量 是一個數字
  • 1維張量=向量
  • 2維張量=矩陣
  • 3維張量=時間序列/股價/文字資料/彩色圖片(RGB)
  • 4維張量=影象
  • 5維張量=視訊

建立tensor

import torch
# 建立tensor,用dtype指定型別。注意型別要匹配
a = torch.tensor(1.0, dtype=torch.float)
b = torch.tensor(1, dtype=torch.long)
# 使用指定型別函式隨機初始化指定大小的tensor
c = torch.FloatTensor(2,3)
d = torch.IntTensor(2)

tensor和numpy array之間的相互轉換

import numpy as np
x = np.array([[1,2,3],[4,5,6]])
y = torch.tensor(x)
print(y)
z = torch.from_numpy(x)
print(z)
w = y.numpy()
print(w)

常見的構造Tensor的函式

函式 功能
Tensor(**sizes*) 基礎建構函式
tensor(data) 類似於np.array
ones(**sizes*) 全1
zeros(**sizes*) 全0
eye(**sizes*) 對角為1,其餘為0
arange(s,e,step) 從s到e,步長為step
linspace(s,e,steps) 從s到e,均勻分成step份
rand/randn(**sizes*)
normal(mean,std)/uniform(from,to) 正態分佈/均勻分佈
randperm(m) 隨機排列
i = torch.rand(2, 3) 
j = torch.ones(2, 3)

操作

# 檢視tensor的維度資訊(兩種方式)
print(i.shape)
print(i.size())
# tensor的運算
e = torch.add(i,j)
print(e)
# tensor的索引方式與numpy類似
print(e[:,1])
print(e[0,:])
# 改變tensor形狀的神器:view
print(e.view((3,2)))
print(e.view(-1,2))
# tensor的廣播機制(使用時要注意這個特性)
f = torch.arange(1, 3).view(1, 2)
print(f)
g = torch.arange(1, 4).view(3, 1)
print(g)
print(f + g)
# 擴充套件&壓縮tensor的維度:squeeze
print(e)
h = e.unsqueeze(1)
print(h)
print(h.shape)

2.自動求導

PyTorch實現模型訓練

  • 輸入資料,正向傳播(同時建立計算圖DCG)
  • 計算損失函式
  • 損失函式反向傳播
  • 更新模型引數

Tensor資料結構是實現自動求導的基礎

PyTorch 中,所有神經網路的核心是 autograd 包。autograd包為張量上的所有操作提供了自動求導機制。它是一個在執行時定義 ( define-by-run )的框架,這意味著反向傳播是根據程式碼如何執行來決定的,並且每次迭代可以是不同的。

torch.Tensor 是這個包的核心類。如果設定它的屬性 .requires_gradTrue,那麼它將會追蹤對於該張量的所有操作。當完成計算後可以通過呼叫 .backward(),來自動計算所有的梯度。這個張量的所有梯度將會自動累加到.grad屬性。

3.平行計算

為什麼需要平行計算?

  • 能計算——視訊記憶體佔用
  • 算得快——計算速度
  • 效果好——大batch提升訓練效果

怎麼並行?

CUDA

  • GPU廠商NVIDIA提供的GPU計算框架
  • GPU本身的程式設計基於CUDA語言實現
  • 在PyTorch中,CUDA指的是模型或者資料開始使用GPU(而不是CPU)

並行的方法

  • 網路結構分佈到不同裝置中(Network partitioning)

  • 同一層的任務分佈到不同資料中(Layer-wise partitioning)

  • 不同資料分佈到不同的裝置中(Data parallelism)

cuDNN與CUDA

  • cuDNN是用於深度神經網路的加速庫
  • cuDNN是基於CUDA完成深度學習的加速

參考連結:

  1. 深入淺出Pytorch視訊:https://www.bilibili.com/video/BV1e341127Lt?p=2
  2. Datawhale專案:https://github.com/datawhalechina/thorough-pytorch/tree/main/%E7%AC%AC%E4%BA%8C%E7%AB%A0%20PyTorch%E5%9F%BA%E7%A1%80%E7%9F%A5%E8%AF%86
  3. 李巨集毅2021春機器學習課程:https://www.bilibili.com/video/BV1Wv411h7kN?p=5