1. 程式人生 > 程式設計 >Pytorch 多維陣列運算過程的索引處理方式

Pytorch 多維陣列運算過程的索引處理方式

背景:對 python 不熟悉,能看懂程式碼,也能實現一些簡單的功能,但是對 python 的核心思想和程式設計技巧不熟,所以使 Pytorch 寫 loss 的時候遇到很多麻煩,尤其是在 batch_size > 1 的時候,做矩陣乘法之類的運算會覺得特別不順手。

所幸,在邊查邊寫的過程中,理解了 python 中多維運算的實現規則。

1、python 的基本索引規則

從 0 開始

對於給定的範圍,如 b = a[m:n],那麼 b 為由 (n-m)個數據組成的新陣列,由 a[m],a[m+1],...,a[n-1] 構成。(若 n<m,得到空)

2、單個 tensor 運算,使用 dim 引數

torch 中對 tensor 的操作方法,若不加 dim 引數表示對整體的 tensr 進行操作,若增加 dim 參則表示按維操作。

例:

a = [[1,2],[3,4],[5,6]] (torch.tensor)

  torch.mean(a) => 3.5

  torch.mean(a,dim=0) => [1.5,3.5,5.5]

  torch.mean(a,dim=1) => [[3],[4]]

  torch.mean(a,dim=0) => [3,4]

  torch.mean(a,dim=1) => [1.5,5.5] 

注: torch.mean() 是一個降維的操作,所以不會出現在取均值後保持跟原 Tensor 同維的情況。 dim 引數存在時降一維,不存在時得到的是整個 Tensor 的均值。

3、兩個 tensor 運算,構造對應形狀

以乘法為例:

3.1 矩陣乘向量

a = [[1,6]]

b = [1,1]

計算乘法 c = a@b

若 a 拓展為 (N,3,2) N 為 batch_size,計算 c2 = a@b

若 a,b 同時拓展, 變成(N,2),那麼需要做一個變換 b = b.view(N,2,1),計算 c3 = a@b

3.2 矩陣乘矩陣

a = [[1,6]]

b =[ [1,1],[1,1]]

計算乘法 c = a@b

若 a 拓展為 (N,計算 c3 = a@b

以上這篇Pytorch 多維陣列運算過程的索引處理方式就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。