Multi label 多標簽分類問題
阿新 • • 發佈:2018-09-19
image col ros 標簽 true 針對 nbsp 實例 object
調用函數
Pytorch使用torch.nn.BCEloss
Tensorflow使用tf.losses.sigmoid_cross_entropy
在output和target之間構建binary cross entropy,其中i為每一個類。
m = nn.Sigmoid() loss = nn.BCELoss() input = autograd.Variable(torch.randn(3), requires_grad=True) target = autograd.Variable(torch.FloatTensor(3).random_(2)) output= loss(m(input), target) output.backward()
註意target的形式,要寫成01編碼形式,eg:如果同時為第一類和第三類則,[1, 0, 1]
主要是結合sigmoid來使用,經過classifier分類過後的輸出為(batch_size,num_class)為每個數據的標簽, 標簽不是one-hot的主要體現在sigmoid(output)之後進行bceloss計算時:sigmoid輸出之後,仍然為(batch_size,num_class),但是是每個類別的分數,對於一個實例,它的各個label的分數加起來不一定要等於1,bceloss在每個類維度上求cross entropy loss然後加和求平均得到,這裏就體現了多標簽的思想。
[CVPR2015] Is object localization for free? – Weakly-supervised learning with convolutional neural networks這篇論文裏設計了針對多標簽問題的loss,傳統的類別分類不適用,作者把這個任務視為多個二分類問題,loss function和分類的分數如下:
Multi label 多標簽分類問題