符號*,torch.max 和 torch.sum, item()方法
阿新 • • 發佈:2018-11-10
*的作用可以參考https://www.cnblogs.com/jony7/p/8035376.html
torch.max可以參考https://blog.csdn.net/Z_lbj/article/details/79766690
a.size() # Out[134]: torch.Size([6, 4, 3]) torch.max(a, 0)[1].size() # Out[135]: torch.Size([4, 3]) torch.max(a, 1)[1].size() # Out[136]: torch.Size([6, 3]) torch.max(a, 2)[1].size() # Out[137]: torch.Size([6, 4])
具體怎麼比較的可以看下面
b tensor([[[ 0., 1., 2., 3.], [ 4., 5., 6., 7.], [ 8., 9., 10., 11.]], [[ 12., 13., 14., 15.], [ 16., 17., 18., 19.], [ 20., 21., 22., 23.]]]) torch.max(b,0)[0] tensor([[ 12., 13., 14., 15.], [ 16., 17., 18., 19.], [ 20., 21., 22., 23.]]) torch.max(b,1)[0] tensor([[ 8., 9., 10., 11.], [ 20., 21., 22., 23.]]) torch.max(b,2)[0] tensor([[ 3., 7., 11.], [ 15., 19., 23.]])
相應的下標可以得到
b tensor([[[ 0., 1., 2., 3.], [ 4., 5., 6., 7.], [ 8., 9., 10., 11.]], [[ 12., 13., 14., 15.], [ 16., 17., 18., 19.], [ 20., 21., 22., 23.]]]) torch.max(b,0)[1] tensor([[ 1, 1, 1, 1], [ 1, 1, 1, 1], [ 1, 1, 1, 1]]) torch.max(b,1)[1] tensor([[ 2, 2, 2, 2], [ 2, 2, 2, 2]]) torch.max(b,2)[1] tensor([[ 3, 3, 3], [ 3, 3, 3]])
torch.sum:
torch.sum(input) → Tensor
torch.sum(input, dim, out=None) → Tensor
引數:
input (Tensor) – 輸入張量
dim (int) – 縮減的維度
out (Tensor, optional) – 結果張量
函式的輸出是一個tensor
match
out:
tensor([[[ 0, 0, 2, 0],
[ 0, 0, 0, 0],
[ 0, 0, 0, 0]],
[[ 0, 0, 0, 0],
[ 0, 0, 0, 0],
[ 0, 0, 0, 0]]], dtype=torch.uint8)
torch.sum(match)
Out:
tensor(2)
torch.sum(match,0)
Out:
tensor([[ 0, 0, 2, 0],
[ 0, 0, 0, 0],
[ 0, 0, 0, 0]])
torch.sum(match,1)
Out:
tensor([[ 0, 0, 2, 0],
[ 0, 0, 0, 0]])
torch.sum(match,2)
Out:
tensor([[ 2, 0, 0],
[ 0, 0, 0]])
還要補充一點的就是item方法的使用:如果tensor只有一個元素那麼呼叫item方法的時候就是將tensor轉換成python的scalars;如果tensor不是單個元素的話那就會引發ValueError,如下面
b.item()
Traceback (most recent call last):
b.item()
ValueError: only one element tensors can be converted to Python scalars
torch.sum(b)
Out: tensor(276.)
torch.sum(b).item()
Out: 276.0
那麼在python中的item方法一般是怎麼樣的呢?可參見https://blog.csdn.net/qq_34941023/article/details/78431376