
How to write a triplet loss in PyTorch

triplet loss

PyTorch already provides a ready-made triplet loss criterion in the class TripletMarginLoss:

class TripletMarginLoss(Module):
    r"""Creates a criterion that measures the triplet loss given an input
    tensors x1, x2, x3 and a margin with a value greater than 0.
    This is used for measuring a relative similarity between samples. A triplet
    is composed by `a`, `p` and `n`: anchor, positive examples and negative
    example respectively. The shape of all input variables should be
    :math:`(N, D)`.

    The distance swap is described in detail in the paper `Learning shallow
    convolutional feature descriptors with triplet losses`_ by
    V. Balntas, E. Riba et al.

    Args:
        anchor: anchor input tensor
        positive: positive input tensor
        negative: negative input tensor
        p: the norm degree. Default: 2

    Shape:
        - Input: :math:`(N, D)` where `D = vector dimension`
        - Output: :math:`(N, 1)`
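The docstring excerpt above lists the inputs but does not show the formula. Here is a sketch of what the criterion computes for a batch of N triplets. It assumes the batch-mean reduction that current PyTorch uses by default, and it ignores the small eps that the implementation adds inside the distance. d is the p-norm distance and margin is the constructor argument:

```latex
d(x_i, y_i) = \lVert x_i - y_i \rVert_p

\ell_i = \max\{\, d(a_i, p_i) - d(a_i, n_i) + \mathrm{margin},\; 0 \,\}

L(a, p, n) = \frac{1}{N} \sum_{i=1}^{N} \ell_i

% with distance swap (swap=True), d(a_i, n_i) above is replaced by
d_{\mathrm{neg}}(a_i, p_i, n_i) = \min\{\, d(a_i, n_i),\; d(p_i, n_i) \,\}
```

A triplet contributes zero loss once the negative is at least margin farther from the anchor than the positive is. The swap variant uses whichever of the anchor or the positive is closer to the negative, which gives the harder of the two negative distances.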

Usage example:

    >>> import torch
    >>> import torch.nn as nn
    >>> triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2)
    >>> # requires_grad=True so that backward() has leaf tensors to compute gradients for
    >>> input1 = torch.randn(100, 128, requires_grad=True)
    >>> input2 = torch.randn(100, 128, requires_grad=True)
    >>> input3 = torch.randn(100, 128, requires_grad=True)
    >>> output = triplet_loss(input1, input2, input3)
    >>> output.backward()
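The sketch below rewrites the criterion from scratch to make it concrete. The function name triplet_margin_loss_manual is hypothetical and not part of PyTorch. It relies on the real torch.nn.functional.pairwise_distance and compares its result with nn.TripletMarginLoss, both with and without swap=True. It assumes a reasonably recent PyTorch release that accepts the reduction argument.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def triplet_margin_loss_manual(anchor, positive, negative,
                               margin=1.0, p=2.0, eps=1e-6, swap=False):
    """Hypothetical helper: batch-mean triplet margin loss written by hand."""
    # Per-triplet p-norm distances, each of shape (N,).
    # pairwise_distance adds eps to the difference before taking the norm.
    d_ap = F.pairwise_distance(anchor, positive, p=p, eps=eps)
    d_an = F.pairwise_distance(anchor, negative, p=p, eps=eps)
    if swap:
        # Distance swap: take the smaller (harder) of d(a, n) and d(p, n).
        d_pn = F.pairwise_distance(positive, negative, p=p, eps=eps)
        d_an = torch.min(d_an, d_pn)
    # Hinge: triplets that already satisfy the margin contribute 0.
    losses = torch.clamp(d_ap - d_an + margin, min=0.0)
    return losses.mean()


torch.manual_seed(0)
anchor = torch.randn(100, 128)
positive = torch.randn(100, 128)
negative = torch.randn(100, 128)

for swap in (False, True):
    builtin = nn.TripletMarginLoss(margin=1.0, p=2, swap=swap)(anchor, positive, negative)
    manual = triplet_margin_loss_manual(anchor, positive, negative, margin=1.0, p=2, swap=swap)
    print(f"swap={swap}: builtin={builtin.item():.6f} manual={manual.item():.6f}")

# reduction="none" keeps one loss value per triplet instead of averaging.
per_triplet = nn.TripletMarginLoss(margin=1.0, p=2, reduction="none")(anchor, positive, negative)
print(per_triplet.shape)  # torch.Size([100])
```

In current PyTorch the default reduction='mean' makes the built-in criterion return a 0-dimensional (scalar) tensor. That scalar is what output.backward() is called on in the usage example. Passing reduction='none' returns the per-triplet losses, with shape (N,).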
