1. 程式人生 > 其它 >詳解計算miou的程式碼以及混淆矩陣的意義

詳解計算miou的程式碼以及混淆矩陣的意義

詳解計算miou的程式碼以及混淆矩陣的意義

miou的定義

'''
    Mean Intersection over Union(MIoU,均交併比):為語義分割的標準度量。其計算兩個集合的交集和並集之比.
    在語義分割的問題中,這兩個集合為真實值(ground truth)和預測值(predicted segmentation)。
    這個比例可以變形為正真數(intersection)比上真正、假負、假正(並集)之和。在每個類上計算IoU,之後平均。
    
    對於21個類別,分別求IOU:
        例如,對於類別1的IOU定義如下:
            (1)統計在ground truth中屬於類別1的畫素數
            (2)統計在預測結果中每個類別1的畫素數
                (1) + (2)就是二者的並集畫素數(類比於兩塊區域的面積加和, 注:二者交集部分的面積加重複了)
                再減去二者的交集(既在ground truth集合中又在預測結果集合中的畫素),得到的就是二者的並集(所有跟類別1有關係的畫素:包括TP,FP,FN)
        擴充套件提示:
            TP(真正): 預測正確, 預測結果是正類, 真實是正類  
            FP(假正): 預測錯誤, 預測結果是正類, 真實是負類
            FN(假負): 預測錯誤, 預測結果是負類, 真實是正類
            
            TN(真負): 預測正確, 預測結果是負類, 真實是負類   #跟類別1無關,所以不包含在並集中
            (本例中, 正類:是類別1, 負類:不是類別1)
                
    mIoU:
        對於每個類別計算出的IoU求和取平均
    
    '''
————————————————
版權宣告:本文為CSDN博主「你吃過滷汁牛肉嗎」的原創文章,遵循CC 4.0 BY-SA版權協議,轉載請附上原文出處連結及本宣告。
原文連結:https://blog.csdn.net/u012370185/article/details/94409933
class mIOU:
    def __init__(self, num_classes):
        self.num_classes = num_classes
        self.hist = np.zeros((num_classes, num_classes))


    # 返回的是混淆矩陣
    def _fast_hist(self, label_pred, label_true):
        # 去除背景
        # ground truth中所有正確(值在[0, classe_num])的畫素label的mask
        mask = (label_true >= 0) & (label_true < self.num_classes)
        # 計算出每一類(0-n**2-1)中對應的數(0-n**2-1)出現的次數,返回值為(n,n)
        # confusion_matrix是一個[num_classes, num_classes]的矩陣,
        # confusion_matrix矩陣中(x, y)位置的元素代表該張圖片中真實類別為x, 被預測為y的畫素個數
        '''
        關於下面的混淆矩陣如何計算出來的可能會有些初學者不大理解,筆者根據自己的想法對下面的程式碼有一定的見解,
        可能有一定錯誤,歡迎指出
        我們之前得到的是兩張由0-num_class-1的數字組成的label,分別對應我們的類別總數
        self.num_classes * label_true[mask].astype(int),這段程式碼通過將label_true[mask]乘以num_class
        第0類還是0,第一類的數字變成num_class(注意這是在原來的圖上操作),以此類推,
        +label_pred[mask],對於這一步我舉個栗子,比如groundtrue是第一類,num_class=21,在之前操作已經將該畫素塊
        變成21了,如果我預測的還是第一類,則這一畫素塊變成了22,在bincount函式中,使得數字22的次數增加了1,在後面的reshape中
        數字22對於的就是第二行第第二列,也就是對角線上的(因為混淆矩陣的定義就是對角線上的就是預測正確的,即TP),所以得到了
        hist就是混淆矩陣
        '''
        hist = np.bincount(
            self.num_classes * label_true[mask].astype(int) +
            label_pred[mask], minlength=self.num_classes ** 2).reshape(self.num_classes, self.num_classes)
        return hist

    def add_batch(self, predictions, gts):
        for lp, lt in zip(predictions, gts):
            self.hist += self._fast_hist(lp.flatten(), lt.flatten())

    def evaluate(self):
        '''
        miou = TP / (TP+FN+TN)
        因此下面的式子顯然是計算miou的
        :return:
        '''
        iu = np.diag(self.hist) / (self.hist.sum(axis=1) + self.hist.sum(axis=0) - np.diag(self.hist))
        return np.nanmean(iu[1:])