基於mmdetection的熱力圖繪製
阿新 • • 發佈:2020-10-05
# coding: utf-8
"""Render the feature maps of an mmdetection model as heatmaps over an image.

Usage: python <this script> IMG SAVE_DIR CONFIG CHECKPOINT [--device cuda:0]
"""
import os
import warnings
from argparse import ArgumentParser

import cv2
import matplotlib.pyplot as plt
import mmcv
import numpy as np
import torch
import torch.nn as nn
from mmcv.ops import RoIAlign, RoIPool
from mmcv.parallel import collate, scatter
from mmcv.runner import load_checkpoint
from mmdet.apis import inference_detector, init_detector
from mmdet.core import get_classes
from mmdet.datasets.pipelines import Compose
from mmdet.models import build_detector
from mmdet.models.dense_heads import *


def featuremap_2_heatmap(feature_map):
    """Collapse a feature-map tensor into a 2-D heatmap normalised to [0, 1].

    The tensor is summed over dim 1, averaged over dim 0, clamped at zero and
    divided by its maximum.

    :param feature_map: ``torch.Tensor`` indexed as ``[:, c, :, :]``
        (presumably (batch, channels, height, width) -- confirm with caller)
    :return: 2-D ``numpy.ndarray``; all zeros when no activation is positive
    """
    assert isinstance(feature_map, torch.Tensor)
    feature_map = feature_map.detach()
    # Summing over dim 1 replaces the explicit per-channel accumulation loop.
    heatmap = feature_map.sum(dim=1).cpu().numpy()
    heatmap = np.mean(heatmap, axis=0)
    heatmap = np.maximum(heatmap, 0)
    peak = np.max(heatmap)
    # Guard the normalisation: an all-non-positive map would otherwise divide
    # by zero and fill the heatmap with NaNs.
    if peak > 0:
        heatmap /= peak
    return heatmap


def draw_feature_map(model, img_path, save_dir, alpha=0.4):
    """Overlay each feature map returned for an image and save it as a PNG.

    Files are written as ``featuremap_<index>.png`` inside ``save_dir``.

    NOTE(review): this sets ``model.draw_heatmap = True`` and then iterates
    the result of ``inference_detector`` as feature-map tensors. That
    presumably requires a detector patched to honour the flag and return
    feature maps instead of detections -- confirm against the model code.

    :param model: model with its weights already loaded
    :param img_path: path of the test image
    :param save_dir: directory that receives the generated images
    :param alpha: heatmap intensity factor used when blending with the image
    :return: None
    """
    img = mmcv.imread(img_path)
    if save_dir:
        # cv2.imwrite does not create missing directories.
        os.makedirs(save_dir, exist_ok=True)

    model.eval()
    model.draw_heatmap = True
    featuremaps = inference_detector(model, img)

    height, width = img.shape[:2]
    for index, featuremap in enumerate(featuremaps):
        heatmap = featuremap_2_heatmap(featuremap)
        # Resize the heatmap to the size of the original image.
        heatmap = cv2.resize(heatmap, (width, height))
        # Scale the normalised heatmap to 8-bit values.
        heatmap = np.uint8(255 * heatmap)
        # Map the 8-bit heatmap to JET colours.
        heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
        # Blend with the original image. The float sum can exceed 255;
        # converting it back to 8-bit is left to cv2.imwrite.
        superimposed_img = heatmap * alpha + img
        out_path = os.path.join(save_dir, 'featuremap_' + str(index) + '.png')
        # Save the image to disk; imwrite reports failure by returning False.
        if not cv2.imwrite(out_path, superimposed_img):
            warnings.warn('Failed to write heatmap image: ' + out_path)


def main():
    """Parse the command line, build the detector and draw its heatmaps."""
    parser = ArgumentParser()
    parser.add_argument('img', help='Image file')
    parser.add_argument('save_dir', help='Dir to save heatmap')
    parser.add_argument('config', help='Config file')
    parser.add_argument('checkpoint', help='Checkpoint file')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    args = parser.parse_args()

    # build the model from a config file and a checkpoint file
    model = init_detector(args.config, args.checkpoint, device=args.device)
    draw_feature_map(model, args.img, args.save_dir)


if __name__ == '__main__':
    main()
用例: