1. 程式人生 > 實用技巧 >torch.gt(a,b)----->pointwise元素比較

torch.gt(a,b)----->pointwise元素比較

"""
兩個同shape的tensor,point-wise級別元素對應比較,前者大於後者,返回位置元素為1;前者小於等於後者,返回位置元素為0
"""
>>> import torch
>>> a=torch.randn(2,4)
>>> a
tensor([[-0.5466,  0.9203, -1.3220, -0.7948],
        [ 2.0300,  1.3090, -0.5527, -0.1326]])
>>> b=torch.randn(2,4)
>>> b
tensor([[-0.0160, -0.3129, -1.0287,  0.5962],
        [ 0.3191,  0.7988,  1.4888, -0.3341]])
>>> torch.gt(a,b)            #得到a中比b中元素大的位置,由於廣播作用, b也可以為1個數!!!
tensor([[0, 1, 0, 0],
        [1, 1, 0, 1]], dtype=torch.uint8)
>>> torch.gt(b,a)           #b中比a中大
tensor([[1, 0, 1, 1],
        [0, 0, 1, 0]], dtype=torch.uint8)
>>> torch.gt(a,1)
tensor([[0, 0, 0, 0],
        [1, 1, 0, 0]], dtype=torch.uint8)
>>> c=torch.Tensor([[1,2,3],[4,5,6]])
>>> d=torch.Tensor([[1,1,3],[5,5,5]])
>>> torch.gt(c,d)                   #必須是嚴格大於才為1
tensor([[0, 1, 0],
        [0, 0, 1]], dtype=torch.uint8)