1. 程式人生 > 程式設計 >pytorch中torch.max和Tensor.view函式用法詳解

pytorch中torch.max和Tensor.view函式用法詳解

torch.max()

1.

torch.max()簡單來說是返回一個tensor中的最大值。

例如:

>>> si=torch.randn(4,5)
>>> print(si)
tensor([[ 1.1659,-1.5195,0.0455,1.7610,-0.2064],[-0.3443,2.0483,0.6303,0.9475,0.4364],[-1.5268,-1.0833,1.6847,0.0145,-0.2088],[-0.8681,0.1516,-0.7764,0.8244,-1.2194]])

>>> print(torch.max(si))
tensor(2.0483)

2.

這個函式的引數中還有一個dim引數,使用方法為re = torch.max(Tensor,dim),返回的re為一個二維向量,其中re[0]為最大值的Tensor,re[1]為最大值對應的index的Tensor。

例如:

>>> print(torch.max(si,0)[0])
tensor([1.1659,0.4364])

注意,Tensor的維度從0開始算起。在torch.max()中指定了dim之後,比如對於一個3x4x5的Tensor,指定dim為0後,得到的結果是維度為0的“每一行”對應位置求最大的那個值,此時輸出的Tensor的維度是4x5.

對於簡單的二維Tensor,如上面例子的這個4x5的Tensor。指定dim為0,則給出的結果是4行做比較之後的最大值;如果指定dim為1,則給出的結果是5列做比較之後的最大值,且此處做比較時是按照位置分別做比較,得到一個新的Tensor。

Tensor.view()

簡單說就是一個把tensor 進行reshape的操作。

>>> a=torch.randn(3,4,5,7)
>>> b = a.view(1,-1)
>>> print(b.size())
torch.Size([1,420])

其中引數-1表示剩下的值的個數一起構成一個維度。如上例中,第一個引數1將第一個維度的大小設定成1,後一個-1就是說第二個維度的大小=元素總數目/第一個維度的大小,此例中為3*4*5*7/1=420.

>>> d = a.view(a.size(0),a.size(1),-1)
>>> print(d.size())
torch.Size([3,35])

 

>>> e=a.view(4,-1,5)
>>> print(e.size())
torch.Size([4,21,5])

以上這篇pytorch中torch.max和Tensor.view函式用法詳解就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。