torch.gt(a,b)----->pointwise元素比較
阿新 • • 發佈:2020-10-22
""" 兩個同shape的tensor,point-wise級別元素對應比較,前者大於後者,返回位置元素為1;前者小於等於後者,返回位置元素為0 """ >>> import torch >>> a=torch.randn(2,4) >>> a tensor([[-0.5466, 0.9203, -1.3220, -0.7948], [ 2.0300, 1.3090, -0.5527, -0.1326]]) >>> b=torch.randn(2,4) >>> b tensor([[-0.0160, -0.3129, -1.0287, 0.5962], [ 0.3191, 0.7988, 1.4888, -0.3341]]) >>> torch.gt(a,b) #得到a中比b中元素大的位置,由於廣播作用, b也可以為1個數!!! tensor([[0, 1, 0, 0], [1, 1, 0, 1]], dtype=torch.uint8) >>> torch.gt(b,a) #b中比a中大 tensor([[1, 0, 1, 1], [0, 0, 1, 0]], dtype=torch.uint8) >>> torch.gt(a,1) tensor([[0, 0, 0, 0], [1, 1, 0, 0]], dtype=torch.uint8) >>> c=torch.Tensor([[1,2,3],[4,5,6]]) >>> d=torch.Tensor([[1,1,3],[5,5,5]]) >>> torch.gt(c,d) #必須是嚴格大於才為1 tensor([[0, 1, 0], [0, 0, 1]], dtype=torch.uint8)