focal loss
阿新 • • 發佈:2018-11-13
Focal Loss 就是一個解決分類問題中類別不平衡、分類難度差異的一個 loss.
Kaiming 大神的 Focal Loss ,二分類形式,是:
如果落實到 ŷ =σ(x) 這個預測,那麼就有:
通過一系列調參,得到 α=0.25, γ=2(在他的模型上)的效果最好。
多分類:
Focal Loss 在多分類中的形式也很容易得到,其實就是:
ŷt 是目標的預測值,一般就是經過 softmax 後的結果。那我自己構思的 L∗∗ 怎麼推廣到多分類?也很簡單:
這裡 xt 也是目標的預測值,但它是 softmax 前的結果。
tensorlfow實現的multi-class, multi-label 如下:
def focal_loss(self, labels, logits, gamma=2.0, alpha=0.25, normalize=True): labels = tf.where(labels > 0, tf.ones_like(labels), tf.zeros_like(labels)) labels = tf.cast(labels, tf.float32) probs = tf.sigmoid(logits) ce_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits) alpha_t = tf.ones_like(logits) * alpha alpha_t = tf.where(labels > 0, alpha_t, 1.0 - alpha_t) probs_t = tf.where(labels > 0, probs, 1.0 - probs) # tf.where(input, a,b),其中a,b均為尺寸一致的tensor,作用是將a中對應input中true的位置的元素值不變,其餘元素進行替換,替換成b中對應位置的元素值 focal_matrix = alpha_t * tf.pow((1.0 - probs_t), gamma) loss = focal_matrix * ce_loss loss = tf.reduce_sum(loss) if normalize: n_pos = tf.reduce_sum(labels) # total_weights = tf.stop_gradient(tf.reduce_sum(focal_matrix)) # total_weights = tf.Print(total_weights, [n_pos, total_weights]) # loss = loss / total_weights def has_pos(): return loss / tf.cast(n_pos, tf.float32) def no_pos(): #total_weights = tf.stop_gradient(tf.reduce_sum(focal_matrix)) #return loss / total_weights return loss loss = tf.cond(n_pos > 0, has_pos, no_pos) return loss