1. 程式人生 > 實用技巧 >[Paddle學習筆記][13][基於YOLOv3的昆蟲檢測-測試模型]

[Paddle學習筆記][13][基於YOLOv3的昆蟲檢測-測試模型]

說明:

本例程使用YOLOv3進行昆蟲檢測。例程分為資料處理、模型設計、損失函式、訓練模型、模型預測和測試模型六個部分。本篇為第六部分,儲存非極大值抑制輸出的結果到預測結果檔案,然後通過完整插值方法計算mAP。非極大值閾值的預測得分需要設定一個低的得分,使得計算mAP時能比較更多的平均精度。

實驗程式碼:

測試模型:

import json
import paddle.fluid as fluid
from paddle.fluid.dygraph.base import to_variable

from source.data import multip_test_reader
from
source.model import YOLOv3 from source.infer import get_nms_infer from source.test import test num_classes = 7 # 類別數量 anchor_size = [10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198, 373, 326] # 錨框大小 anchor_mask = [[6, 7, 8], [3, 4, 5], [0, 1, 2]] #
錨框掩碼 downsample_ratio = 32 # 下采樣率 test_path = './dataset/val/images' # 測試目錄路徑 json_path = './output/infer.json' # 結果儲存路徑 xmls_path='./dataset/val/annotations/xmls' # 標籤目錄路徑 model_path = './output/darknet53-yolov3' # 網路權重路徑
sco_threshold = 0.01 # 預測得分閾值:設定一個小值,使得測試能夠比較更多的準確率 nms_threshold = 0.45 # 非極大值閾值:消除重疊大於該閾值的的預測邊框 iou_threshold = 0.50 # 測試交併比值:保留與真實邊框大於該閾值的預測邊框 with fluid.dygraph.guard(): # 準備資料 test_reader = multip_test_reader(test_path, batch_size=8, scale_size=(608, 608)) # 載入模型 model = YOLOv3(num_classes=num_classes, anchor_mask=anchor_mask) # 載入模型 model_dict, _ = fluid.load_dygraph(model_path) # 載入權重 model.load_dict(model_dict) # 設定權重 model.eval() # 設定驗證 # 模型預測 infer_list = [] # 預測結果列表 for test_data in test_reader(): # 讀取影象 image_name, image, image_size = test_data # 讀取資料 image = to_variable(image) # 轉換格式 image_size = to_variable(image_size) # 轉換格式 # 前向傳播 infer = model(image) # 獲取結果 infer = get_nms_infer(infer, image_size, num_classes, anchor_size, anchor_mask, downsample_ratio, sco_threshold, nms_threshold) # 新增列表 for i in range(len(infer)): # 遍歷批次 if(len(infer[i]) > 0): # 是否存在物體 infer_list.append([image_name[i], infer[i].tolist()]) print('Processed {} images...'.format(len(infer_list)), end='\r') # 儲存結果 print('Svae {} results to infer.json.'.format(len(infer_list))) json.dump(infer_list, open(json_path, 'w')) # 測試模型 test(json_path, xmls_path, num_classes, iou_threshold)

結果:

Svae 245 results to infer.json

Detection mAP(0.50) = 87.63%

測試結果

darknet53-yolov3_050 Detection mAP(0.50) = 64.87%

darknet53-yolov3_100 Detection mAP(0.50) = 81.02%

darknet53-yolov3_150 Detection mAP(0.50) = 87.63%

test.py檔案

import os
import json
import math
import numpy as np
import xml.etree.ElementTree as ET

# 計算平均精度
class DetectionMAP(object):
    def __init__(self, num_classes, iou_threshold=0.5):
        """
        功能: 
            初始化計算平均精度方法
        輸入: 
            num_classes   - 預測類別數量
            iou_threshold - 測試交併比值
        輸出:
        """
        self.num_classes = num_classes                     # 預測類別數量
        self.iou_threshold = iou_threshold                 # 測試交併比值
        self.count = [0] * self.num_classes                # 數量統計列表
        self.score = [[] for _ in range(self.num_classes)] # 得分統計列表
        
    def update(self, infer, gtbox, gtcls):
        """
        功能: 
            統計各類數量和得分
        輸入: 
            infer - 預測結果
            gtbox - 物體邊框
            gtcls - 物體類別
        輸出:
        """
        # 統計各類數量
        for gtcls_item in gtcls:
            self.count[int(np.array(gtcls_item))] += 1
        
        # 統計各類得分
        visited = [False] * len(gtcls) # 各類訪問標識
        for infer_item in infer:
            # 獲取預測資料
            pdcls, pdsco, xmin, ymin, xmax, ymax = infer_item.tolist() # 獲取預測資料
            pdbox = [xmin, ymin, xmax, ymax]                           # 獲取預測邊框
            
            # 計算最大邊框
            max_index = -1 # 最大交併索引
            max_iou = -1.0 # 最大交併比值
            for i, gtcls_item in enumerate(gtcls): # 遍歷真實類別列表
                if int(gtcls_item) == int(pdcls): # 如果真實類別等於預測類別,則計算交併比值
                    iou = self.get_box_iou_xyxy(pdbox, gtbox[i])
                    if iou > max_iou: # 如果交併比值大於最大交併比值,則更新最大交併比值和索引
                        max_index = i
                        max_iou = iou
            
            # 統計各類得分
            if max_iou > self.iou_threshold: # 如果最大交併比值大於測試交併比值
                if not visited[max_index]: # 如果該物體沒有被統計,則新增到列表,並設定訪問標識
                    self.score[int(pdcls)].append([pdsco, 1.0]) # 新增各類正確正例
                    visited[max_index] = True                   # 設定訪問標識為真
                else: # 如果該物體已經被統計,則新增到列表,並設定為成錯誤正例
                    self.score[int(pdcls)].append([pdsco, 0.0]) # 新增各類錯誤正例
            else: # 如果最大交併比值不大於測試交併比值,則新增到列表,並設定成錯誤正例
                self.score[int(pdcls)].append([pdsco, 0.0])     # 新增各類錯誤正例
        
    def get_box_iou_xyxy(self, box1, box2):
        """
        功能: 
            計算邊框交併比值
        輸入: 
            box1 - 邊界框1
            box2 - 邊界框2
        輸出:
            iou  - 交併比值
        """
        # 計算交集面積
        x1_min, y1_min, x1_max, y1_max = box1[0], box1[1], box1[2], box1[3]
        x2_min, y2_min, x2_max, y2_max = box2[0], box2[1], box2[2], box2[3]

        x_min = np.maximum(x1_min, x2_min)
        y_min = np.maximum(y1_min, y2_min)
        x_max = np.minimum(x1_max, x2_max)
        y_max = np.minimum(y1_max, y2_max)

        w = np.maximum(x_max - x_min + 1.0, 0)
        h = np.maximum(y_max - y_min + 1.0, 0)

        intersection = w * h # 交集面積

        # 計算並集面積
        s1 = (y1_max - y1_min + 1.0) * (x1_max - x1_min + 1.0)
        s2 = (y2_max - y2_min + 1.0) * (x2_max - x2_min + 1.0)

        union = s1 + s2 - intersection # 並集面積

        # 計算交併比
        iou = intersection / union

        return iou
    
    def get_mAP(self):
        """
        功能:
            計算各類平均精度
        輸入:
        輸出:
            mAP - 各類平均精度
        """
        # 計算每類精度
        mAP = 0 # 各類平均精度
        cnt = 0 # 各類類別計數
        for score, count in zip(self.score, self.count): # 遍歷每類物體
            # 統計正誤正例
            if count == 0 or len(score) == 0: # 如果該類數量為0,或得分列表為空,則繼續下一個類別
                continue
            tp_list, fp_list = self.get_tp_fp_list(score) # 統計正誤正例
            
            # 計算預測的準確率和召回率
            precision = [] # 準確率列表
            recall = []    # 召回率列表
            for tp, fp in zip(tp_list, fp_list):
                precision.append(float(tp) / (tp + fp)) # 新增準確率
                recall.append(float(tp) / count)        # 新增召回率
            
            # 計算平均精度
            AP = 0.0         # 平均精度
            pre_recall = 0.0 # 前召回率
            for i in range(len(precision)): # 遍歷正確率列表
                recall_gap = math.fabs(recall[i] - pre_recall) # 計算召回率差值
                if recall_gap > 1e-6: # 如果召回率改變,則計算平均精度,更新前召回率
                    AP += precision[i] * recall_gap # 累加平均精度
                    pre_recall = recall[i]          # 更新前召回率
            
            # 更新各類精度
            mAP += AP # 累加各類精度
            cnt += 1  # 增加類別計數
            
        # 計算平均精度
        mAP = (mAP / float(cnt)) if cnt > 0 else mAP
        
        return mAP

    def get_tp_fp_list(self, score):
        """
        功能:
            對得分列表進行從大到小排序,按排序統計正確正例和錯誤正例數量
        輸入:
            score   - 得分列表
        輸出:
            tp_list - 正確正例列表
            fp_list - 錯誤正例列表
        """
        tp = 0       # 正確正例數量
        fp = 0       # 錯誤正例數量
        tp_list = [] # 正確正例列表
        fp_list = [] # 錯誤正例列表
        
        score_list = sorted(score, key=lambda s: s[0], reverse=True) # 對得分列表按從大到小排序
        for (score, label) in score_list:
            tp += int(label)     # 統計正確正例
            tp_list.append(tp)   # 新增正確正例
            fp += 1 - int(label) # 統計錯誤正例
            fp_list.append(fp)   # 新增錯誤正例
        
        return tp_list, fp_list
    
##############################################################################################################

object_names = ['Boerner', 'Leconte', 'Linnaeus', 'acuminatus', 'armandi', 'coleoptera', 'linnaeus'] # 物體名稱
def get_object_gtcls():
    """
    功能:
        將物體名稱對映成物體類別
    輸入:
    輸出:
        object_gtcls - 物體類別
    """
    object_gtcls = {} # 物體類別字典
    for key, value in enumerate(object_names):
        object_gtcls[value] = key # 將物體名稱對映成物體類別
    return object_gtcls

def test(json_path, xmls_path, num_classes, iou_threshold):
    """
    功能:
        測試模型平均精度
    輸入:
        json_path     - 預測結果路徑
        xmls_path     - 標籤目錄路徑
        num_classes   - 預測類別數量
        iou_threshold - 測試交併比值
    輸出:
    """
    # 宣告計算方法
    mAP = DetectionMAP(num_classes, iou_threshold)
    
    # 統計預測得分
    json_list = json.load(open(json_path))               # 讀取預測結果
    for json_item in json_list: # 遍歷預測結果
        # 讀取預測檔案
        image_name = str(json_item[0])                   # 讀取檔名稱
        infer = np.array(json_item[1]).astype('float32') # 讀取預測結果
        
        # 讀取標籤檔案
        tree = ET.parse(os.path.join(xmls_path, image_name + '.xml')) # 解析檔案
        image_w = float(tree.find('size').find('width').text)         # 影象寬度
        image_h = float(tree.find('size').find('height').text)        # 影象高度
        
        object_list = tree.findall('object')                     # 物體列表
        gtbox = np.zeros((len(object_list), 4), dtype='float32') # 物體邊框
        gtcls = np.zeros((len(object_list),  ), dtype='int32')   # 物體類別
        
        for i, object_item in enumerate(object_list):
            # 讀取物體邊框
            x_min = float(object_item.find('bndbox').find('xmin').text) # 物體邊框x1
            y_min = float(object_item.find('bndbox').find('ymin').text) # 物體邊框y1
            x_max = float(object_item.find('bndbox').find('xmax').text) # 物體邊框x2
            y_max = float(object_item.find('bndbox').find('ymax').text) # 物體邊框y2
            
            x_min = max(0.0, x_min)
            y_min = max(0.0, y_min)
            x_max = min(x_max, image_w - 1.0)
            y_max = min(y_max, image_h - 1.0)
            
            gtbox[i] = [x_min, y_min, x_max, y_max] # 設定物體邊框
            
            # 讀取物體類別
            object_name = object_item.find('name').text # 讀取物體名稱
            gtcls[i] = get_object_gtcls()[object_name]  # 將物體名稱對映成物體類別
        
        # 統計預測得分
        mAP.update(infer, gtbox, gtcls)
        
    # 計算平均精度
    mAP_value = mAP.get_mAP() * 100 # 計算平均精度
    print("Detection mAP({:.2f}) = {:.2f}%".format(iou_threshold, mAP_value))

參考資料:

https://blog.csdn.net/qq_31511955/article/details/89022037

https://blog.csdn.net/weixin_41278720/article/details/88774411

https://blog.csdn.net/wc996789331/article/details/83785993

https://blog.csdn.net/litt1e/article/details/88814417

https://blog.csdn.net/litt1e/article/details/88852745

https://blog.csdn.net/litt1e/article/details/88907542

https://aistudio.baidu.com/aistudio/projectdetail/742781

https://aistudio.baidu.com/aistudio/projectdetail/672017

https://aistudio.baidu.com/aistudio/projectdetail/868589

https://aistudio.baidu.com/aistudio/projectdetail/122277