1. 程式人生 > >focal loss

focal loss

Focal Loss 就是一個解決分類問題中類別不平衡、分類難度差異的一個 loss.

Kaiming 大神的 Focal Loss ,二分類形式,是:

VBcD02jFhglbdajMCsZiameIjv6vJgibJl9gRk1yFSQeU66nlwqC856HBGqibtsoyXCKtPeOumoRmdg3PAGLl5vWA

如果落實到 ŷ =σ(x) 這個預測,那麼就有:

VBcD02jFhglbdajMCsZiameIjv6vJgibJl1CuI26775Cyp4CibIjKDuPzOOabGwicggdIUCWj3P5y9aeDhA5cAVkCw

通過一系列調參,得到 α=0.25, γ=2(在他的模型上)的效果最好。

多分類:

Focal Loss 在多分類中的形式也很容易得到,其實就是:

VBcD02jFhglbdajMCsZiameIjv6vJgibJlgichcUBg0FibMjoZe7eTaEC11Cj0HVvHicak38mr25ud0SzpMfALtWAwg

 

ŷt 是目標的預測值,一般就是經過 softmax 後的結果。那我自己構思的 L∗∗ 怎麼推廣到多分類?也很簡單:

VBcD02jFhglbdajMCsZiameIjv6vJgibJlq8dAb8xUUDZxsicConHLjdzxQ37vBoCEtoZEJpjTVXkNLZRSlQSVCQg

 

這裡 xt 也是目標的預測值,但它是 softmax 前的結果。

tensorlfow實現的multi-class, multi-label 如下:

def focal_loss(self, labels, logits, gamma=2.0, alpha=0.25, normalize=True):
        labels = tf.where(labels > 0, tf.ones_like(labels), tf.zeros_like(labels))
        labels = tf.cast(labels, tf.float32)
        probs = tf.sigmoid(logits)
        ce_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits)

        alpha_t = tf.ones_like(logits) * alpha
        alpha_t = tf.where(labels > 0, alpha_t, 1.0 - alpha_t)
        probs_t = tf.where(labels > 0, probs, 1.0 - probs)
        # tf.where(input, a,b),其中a,b均為尺寸一致的tensor,作用是將a中對應input中true的位置的元素值不變,其餘元素進行替換,替換成b中對應位置的元素值
        focal_matrix = alpha_t * tf.pow((1.0 - probs_t), gamma)
        loss = focal_matrix * ce_loss

        loss = tf.reduce_sum(loss)
        if normalize:
            n_pos = tf.reduce_sum(labels)
            # total_weights = tf.stop_gradient(tf.reduce_sum(focal_matrix))
            # total_weights = tf.Print(total_weights, [n_pos, total_weights])
            #         loss = loss / total_weights
            def has_pos():
                return loss / tf.cast(n_pos, tf.float32)
            def no_pos():
                #total_weights = tf.stop_gradient(tf.reduce_sum(focal_matrix))
                #return loss / total_weights
                return loss
            loss = tf.cond(n_pos > 0, has_pos, no_pos)
        return loss