1. 程式人生 > 實用技巧 > 基於mmdetection的熱力圖繪製

基於mmdetection的熱力圖繪製

#coding: utf-8
import cv2
import matplotlib.pyplot as plt
import mmcv
import numpy as np
import os
import torch
import torch.nn as nn
import warnings

from mmcv.ops import RoIAlign, RoIPool
from mmcv.parallel import collate, scatter
from mmcv.runner import load_checkpoint

from mmdet.apis import inference_detector, init_detector
from mmdet.core import get_classes
from mmdet.datasets.pipelines import Compose
from mmdet.models import build_detector
from mmdet.models.dense_heads import *  # noqa: F401,F403

from argparse import ArgumentParser


def featuremap_2_heatmap(feature_map):
    """Collapse a feature-map tensor into a normalised 2-D heat map.

    Channels (dim 1) are summed, the result is averaged over dim 0,
    negative values are clamped to zero and the map is scaled so that its
    maximum becomes 1.

    :param feature_map: a 4-D ``torch.Tensor`` indexed as
        ``feature_map[:, c, :, :]``; presumably (N, C, H, W) -- confirm
        against the patched detector that produces it.
    :return: a 2-D numpy array with values in ``[0, 1]``. When every
        clamped value is 0 the all-zero map is returned as-is instead of
        dividing by zero (which previously produced NaNs).
    """
    assert isinstance(feature_map, torch.Tensor)
    # Sum over channels, then average over the leading dimension.
    heatmap = feature_map.detach().sum(dim=1).cpu().numpy()
    heatmap = np.mean(heatmap, axis=0)
    heatmap = np.maximum(heatmap, 0)
    peak = np.max(heatmap)
    if peak > 0:
        heatmap /= peak
    return heatmap


def draw_feature_map(model, img_path, save_dir):
    """Overlay the model's feature-map heat maps on an image and save them.

    One PNG, ``featuremap_<i>.png``, is written to ``save_dir`` for each
    feature map returned by ``inference_detector``.

    :param model: detector with its checkpoint weights loaded.
    :param img_path: file path of the test image.
    :param save_dir: directory for the generated images; created if missing.
    :return: None
    """
    img = mmcv.imread(img_path)
    model.eval()
    # NOTE(review): stock mmdet's inference_detector returns detection
    # results, not feature maps. This relies on a locally patched detector
    # that honours ``draw_heatmap`` and returns its feature maps -- confirm.
    model.draw_heatmap = True
    featuremaps = inference_detector(model, img)

    # cv2.imwrite fails silently (returns False) if the directory is missing.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    for i, featuremap in enumerate(featuremaps):
        heatmap = featuremap_2_heatmap(featuremap)
        # Resize the heat map to the original image size (cv2 takes (w, h)).
        heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0]))
        heatmap = np.uint8(255 * heatmap)
        # Map to a 3-channel colour image (OpenCV colormaps are BGR).
        heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
        # 0.4 is the heat-map intensity factor. The blend is float and can
        # exceed 255, so clip and cast to uint8 explicitly before saving.
        superimposed_img = np.clip(heatmap * 0.4 + img, 0, 255).astype(np.uint8)
        out_path = os.path.join(save_dir, 'featuremap_' + str(i) + '.png')
        if not cv2.imwrite(out_path, superimposed_img):
            warnings.warn('failed to write heat map to ' + out_path)


def main():
    """Parse CLI arguments, build the detector and draw its heat maps."""
    parser = ArgumentParser()
    parser.add_argument('img', help='Image file')
    parser.add_argument('save_dir', help='Dir to save heatmap')
    parser.add_argument('config', help='Config file')
    parser.add_argument('checkpoint', help='Checkpoint file')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    args = parser.parse_args()
    # Build the model from a config file and a checkpoint file.
    model = init_detector(args.config, args.checkpoint, device=args.device)
    draw_feature_map(model, args.img, args.save_dir)


if __name__ == '__main__':
    main()

用例: