Writing triplet loss in PyTorch
阿新 • Published: 2019-02-06
triplet loss
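For reference, the current PyTorch documentation defines the triplet margin loss for the i-th triplet (anchor a_i, positive p_i, negative n_i) as shown below. The subscript p on the norm is the norm degree, and with default settings the batch loss is the mean over the N triplets:

$$
L(a_i, p_i, n_i) = \max\{\, d(a_i, p_i) - d(a_i, n_i) + \text{margin},\ 0 \,\},
\qquad d(x_i, y_i) = \lVert x_i - y_i \rVert_p
$$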
PyTorch already provides a class that implements the triplet loss criterion, torch.nn.TripletMarginLoss:
class TripletMarginLoss(Module):
    r"""Creates a criterion that measures the triplet loss given an input
    tensors x1, x2, x3 and a margin with a value greater than 0.
    This is used for measuring a relative similarity between samples. A triplet
    is composed by `a`, `p` and `n`: anchor, positive examples and negative
    example respectively. The shape of all input variables should be
    :math:`(N, D)`.

    The distance swap is described in detail in the paper `Learning shallow
    convolutional feature descriptors with triplet losses`_ by
    V. Balntas, E. Riba et al.

    Args:
        anchor: anchor input tensor
        positive: positive input tensor
        negative: negative input tensor
        p: the norm degree. Default: 2

    Shape:
        - Input: :math:`(N, D)` where `D = vector dimension`
        - Output: :math:`(N, 1)`
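The Shape section above gives the output as (N, 1). In recent PyTorch releases the output shape depends on the reduction argument, which the quoted Args list does not mention. With the default reduction='mean' the criterion returns a scalar. With reduction='none' it returns one loss value per triplet, with shape (N,). The following minimal check uses random embeddings as stand-in data, assuming a recent PyTorch version:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# stand-in embeddings: N = 100 triplets, D = 128 dimensions
anchor = torch.randn(100, 128)
positive = torch.randn(100, 128)
negative = torch.randn(100, 128)

# default reduction='mean': a single scalar for the whole batch
mean_criterion = nn.TripletMarginLoss(margin=1.0, p=2)
# reduction='none': one hinge loss per triplet
per_triplet_criterion = nn.TripletMarginLoss(margin=1.0, p=2, reduction="none")

mean_loss = mean_criterion(anchor, positive, negative)
per_triplet = per_triplet_criterion(anchor, positive, negative)

print(mean_loss.shape)    # torch.Size([])
print(per_triplet.shape)  # torch.Size([100])
```

The per-triplet form is handy for inspecting how many triplets still violate the margin, for example with (per_triplet > 0).sum().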
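Usage example (updated for PyTorch 0.4 and later, where autograd.Variable is deprecated and tensors track gradients directly; requires_grad=True is needed, otherwise output.backward() raises an error because nothing in the graph requires gradients):
>>> import torch
>>> import torch.nn as nn
>>> triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2)
>>> # anchor, positive, negative: 100 samples of 128-dimensional embeddings
>>> input1 = torch.randn(100, 128, requires_grad=True)
>>> input2 = torch.randn(100, 128, requires_grad=True)
>>> input3 = torch.randn(100, 128, requires_grad=True)
>>> output = triplet_loss(input1, input2, input3)
>>> output.backward()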
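The following hand-written sketch shows what the criterion computes. It assumes the current PyTorch formulation: per-row distances from torch.nn.functional.pairwise_distance (which adds a small eps, 1e-6 by default, for numerical stability), a hinge at the margin, and a mean over the batch. The function name manual_triplet_margin_loss is hypothetical. The swap flag follows the distance swap from the Balntas et al. paper cited in the docstring, which replaces d(a, n) with the smaller of d(a, n) and d(p, n):

```python
import torch
import torch.nn.functional as F


def manual_triplet_margin_loss(anchor, positive, negative,
                               margin=1.0, p=2.0, eps=1e-6, swap=False):
    """Hypothetical hand-written triplet margin loss, averaged over the batch.

    anchor, positive, negative: tensors of shape (N, D).
    """
    # d(a, p) and d(a, n): per-row p-norm distances, each of shape (N,)
    d_ap = F.pairwise_distance(anchor, positive, p=p, eps=eps)
    d_an = F.pairwise_distance(anchor, negative, p=p, eps=eps)
    if swap:
        # distance swap: use the harder (smaller) of d(a, n) and d(p, n)
        d_pn = F.pairwise_distance(positive, negative, p=p, eps=eps)
        d_an = torch.min(d_an, d_pn)
    # hinge: zero loss once the negative is at least `margin` farther than the positive
    losses = torch.clamp(margin + d_ap - d_an, min=0.0)
    return losses.mean()
```

In practice, nn.TripletMarginLoss (or torch.nn.functional.triplet_margin_loss) is the better choice. The sketch is mainly useful for checking the formula against the library or as a starting point when modifying it.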