Weakly Supervised Instance Segmentation using Class Peak Response: reproducing the paper and the problems I ran into

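Abstract: Weakly supervised instance segmentation with image-level labels, instead of expensive pixel-level masks, remains unexplored. In this paper, we tackle this challenging problem by exploiting class peak responses to enable a classification network for instance mask extraction. With image-label supervision only, a CNN classifier applied in a fully convolutional manner can produce class response maps, which specify the classification confidence at each image location. We observed that local maximums, i.e., peaks, in a class response map typically correspond to strong visual cues residing inside each instance. Motivated by this, we first design a process to stimulate peaks to emerge from a class response map. The emerged peaks are then back-propagated and effectively mapped to highly informative regions of each object instance, such as instance boundaries. We refer to the maps generated from class peak responses as Peak Response Maps (PRMs). PRMs provide a fine-detailed instance-level representation, which allows instance masks to be extracted even with some off-the-shelf methods. To the best of our knowledge, we are the first to report results for the challenging image-level supervised instance segmentation task. Extensive experiments show that our method also boosts weakly supervised pointwise localization as well as semantic segmentation performance, and reports state-of-the-art results on popular benchmarks, including PASCAL VOC 2012 and MS COCO.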

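For the reproduction I simply followed what the official repository says:

Windows 10, or Ubuntu 14.04 or later

GPU version:

NVIDIA GPU + CUDA + cuDNN

A CPU works as well.

One more thing: I ran everything on Windows 10, without GPU acceleration.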

Installation:

1. Install Nest:

pip install git+https://github.com/ZhouYanzhao/Nest.git

2. Install PRM via Nest's CLI tool:

nest module install github@ZhouYanzhao/PRM:pytorch prm

Check that the installation succeeded:

nest module list --filter prm

Output:

# 3 Nest modules found.
# [0] prm.fc_resnet50 (1.0.0)
# [1] prm.peak_response_mapping (1.0.0)
# [2] prm.prm_visualize (1.0.0)

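Next, run the demo.

1. Install Nest's built-in PyTorch modules:

nest module install github@ZhouYanzhao/Nest:pytorch pytorch

2. Download the PASCAL VOC 2012 dataset. Just open the URL below in a browser, then extract the archive inside the demo folder so that it ends up at the data_dir set in config.yml (./datasets/VOCdevkit):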

http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
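If you would rather script this step, here is a minimal sketch that uses only the standard library. It assumes the demo directory is the working directory and that the archive unpacks into a top-level VOCdevkit/ folder, which matches data_dir in config.yml. The helper names download and extract_voc are my own and are not part of Nest or PRM.

```python
import os
import tarfile
import urllib.request

VOC_URL = "http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar"


def download(url, dest_path):
    """Download url to dest_path, skipping the download if the file already exists."""
    if not os.path.exists(dest_path):
        urllib.request.urlretrieve(url, dest_path)  # the archive is roughly 2 GB
    return dest_path


def extract_voc(tar_path, dest_dir="./datasets"):
    """Extract the VOC tarball into dest_dir and return the VOCdevkit path inside it."""
    os.makedirs(dest_dir, exist_ok=True)
    with tarfile.open(tar_path) as tar:
        tar.extractall(dest_dir)
    return os.path.join(dest_dir, "VOCdevkit")
```

Usage, run from the demo folder: extract_voc(download(VOC_URL, "VOCtrainval_11-May-2012.tar")). It should return ./datasets/VOCdevkit.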

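3. Run the demo experiment via demo/main.ipynb, opened with Jupyter Notebook.

This notebook does the training. Training writes .pt weight files, which are PyTorch checkpoints. You may also come across .pkl files, and they are really the same thing: as long as both were written with torch.save, the file extension makes no difference.

The config.yml here is where the parameters are set: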

_name: network_trainer
data_loaders:
  _name: fetch_data
  dataset: 
    _name: pascal_voc_classification
    data_dir: ./datasets/VOCdevkit
    year: 2012
  batch_size: 16
  num_workers: 4
  transform:
    _name: image_transform
    image_size: [448, 448]
    mean: [0.485, 0.456, 0.406]
    std: [0.229, 0.224, 0.225]
  train_augmentation:
    horizontal_flip: 0.5
  train_splits:
    - trainval
model:
  _name: peak_response_mapping
  backbone:
    _name: fc_resnet50
  win_size: 3
  sub_pixel_locating_factor: 8
  enable_peak_stimulation: true
criterion:
  _name: multilabel_soft_margin_loss
  difficult_samples: yes
optimizer:
  _name: sgd_optimizer
  lr: 0.01
  momentum: 0.9
  weight_decay: 1.0e-4
parameter:
  _name: finetune
  base_lr: 0.01
  groups:
    'features': 0.01
meters:
  loss:
    _name: loss_meter
max_epoch: 20
device: cuda
hooks:
  on_start:
    -
      _name: print_state
      formats:
        - '@CONFIG'
        - 'Model: {model}'
      join_str: '\n'
  on_end_epoch: 
    - 
      _name: print_state
      formats:
        - 'epoch: {epoch_idx}'
        - 'classification_loss: {metrics[trainval_loss]:.4f}'
    -
      _name: checkpoint
      save_dir: './snapshots'
      save_latest: yes
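The criterion in this config is multilabel_soft_margin_loss, and the training log further down reports its value as classification_loss. The sketch below assumes Nest's module behaves like PyTorch's MultiLabelSoftMarginLoss. I have not checked what difficult_samples changes, so the sketch ignores that flag. Under that assumption, each image contributes the mean over its classes of a binary cross-entropy on the raw class score, and the batch loss is the mean over images. A minimal pure-Python sketch:

```python
import math


def softplus(z):
    """Numerically stable log(1 + exp(z))."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def multilabel_soft_margin_loss(scores, targets):
    """Mean over images of the per-class binary cross-entropy on raw scores.

    scores:  per-image lists of raw class scores (logits)
    targets: per-image lists of 0/1 labels with the same shape
    """
    per_image = []
    for x_row, y_row in zip(scores, targets):
        # -log(sigmoid(x)) == softplus(-x) and -log(1 - sigmoid(x)) == softplus(x)
        terms = [y * softplus(-x) + (1 - y) * softplus(x) for x, y in zip(x_row, y_row)]
        per_image.append(sum(terms) / len(terms))
    return sum(per_image) / len(per_image)
```

For intuition: if every score is 0, the loss is log 2 ≈ 0.693. A value around 0.03 therefore means the network puts high probability on the correct label for nearly every image/class pair.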

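Because I was training on the CPU, I set batch_size to 4 (the file above still shows the default of 16). With 16 the machine slowed to a crawl. Then training began.

Training ran for more than three days, and it would have kept going if an unexpected power outage had not stopped it. By then 12 of the 20 epochs (epochs 0 to 11) had finished, and the loss had dropped to about 0.03: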
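To leave the original config.yml untouched, the CPU-specific changes can also be applied as overrides on the parsed config. This minimal sketch assumes the YAML has already been loaded into a plain nested dict, for example with PyYAML's yaml.safe_load. apply_overrides is a hypothetical helper, not a Nest function. Switching device to cpu is also an assumption about what Nest's trainer accepts, not something I verified.

```python
import copy


def apply_overrides(config, overrides):
    """Return a deep copy of config with dotted-path keys replaced.

    "data_loaders.batch_size" sets config["data_loaders"]["batch_size"].
    """
    result = copy.deepcopy(config)
    for dotted_key, value in overrides.items():
        *parents, leaf = dotted_key.split(".")
        node = result
        for key in parents:
            node = node[key]  # raises KeyError if the path does not exist
        node[leaf] = value
    return result


# A trimmed copy of the relevant part of config.yml
config = {
    "data_loaders": {"batch_size": 16, "num_workers": 4},
    "max_epoch": 20,
    "device": "cuda",
}
cpu_config = apply_overrides(config, {"data_loaders.batch_size": 4, "device": "cpu"})
```

Returning a copy keeps the GPU settings available for a later run on a machine that has one.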

[2018-07-29 15:41:05,699] epoch: 0 | classification_loss: 0.1172
[2018-07-29 15:41:05,918] checkpoint created at C:\Users\Mars\PRM-pytorch\demo\snapshots\model_latest.pt
[2018-07-29 20:29:30,672] epoch: 1 | classification_loss: 0.0758
[2018-07-29 20:29:30,812] checkpoint created at C:\Users\Mars\PRM-pytorch\demo\snapshots\model_latest.pt
[2018-07-30 01:17:34,042] epoch: 2 | classification_loss: 0.0645
[2018-07-30 01:17:34,214] checkpoint created at C:\Users\Mars\PRM-pytorch\demo\snapshots\model_latest.pt
[2018-07-30 06:04:50,313] epoch: 3 | classification_loss: 0.0566
[2018-07-30 06:04:50,485] checkpoint created at C:\Users\Mars\PRM-pytorch\demo\snapshots\model_latest.pt
[2018-07-30 11:03:36,152] epoch: 4 | classification_loss: 0.0498
[2018-07-30 11:03:36,293] checkpoint created at C:\Users\Mars\PRM-pytorch\demo\snapshots\model_latest.pt
[2018-07-30 16:09:31,468] epoch: 5 | classification_loss: 0.0444
[2018-07-30 16:09:31,683] checkpoint created at C:\Users\Mars\PRM-pytorch\demo\snapshots\model_latest.pt
[2018-07-30 21:07:03,611] epoch: 6 | classification_loss: 0.0409
[2018-07-30 21:07:03,751] checkpoint created at C:\Users\Mars\PRM-pytorch\demo\snapshots\model_latest.pt
[2018-07-31 01:55:32,188] epoch: 7 | classification_loss: 0.0414
[2018-07-31 01:55:32,422] checkpoint created at C:\Users\Mars\PRM-pytorch\demo\snapshots\model_latest.pt
[2018-07-31 06:43:39,614] epoch: 8 | classification_loss: 0.0388
[2018-07-31 06:43:39,761] checkpoint created at C:\Users\Mars\PRM-pytorch\demo\snapshots\model_latest.pt
[2018-07-31 11:32:24,822] epoch: 9 | classification_loss: 0.0351
[2018-07-31 11:32:25,025] checkpoint created at C:\Users\Mars\PRM-pytorch\demo\snapshots\model_latest.pt
[2018-07-31 16:20:31,651] epoch: 10 | classification_loss: 0.0351
[2018-07-31 16:20:31,858] checkpoint created at C:\Users\Mars\PRM-pytorch\demo\snapshots\model_latest.pt
[2018-07-31 21:08:55,310] epoch: 11 | classification_loss: 0.0330
[2018-07-31 21:08:55,508] checkpoint created at C:\Users\Mars\PRM-pytorch\demo\snapshots\model_latest.pt
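The timestamps in this log also show how slow CPU training is. Below is a small sketch that parses epoch lines in the format shown above. The regex is written for exactly this log format. The sketch then estimates the average time per epoch and the time still needed to reach max_epoch: 20.

```python
import re
from datetime import datetime

EPOCH_LINE = re.compile(
    r"\[(?P<ts>[\d\- :,]+)\] epoch: (?P<epoch>\d+) \| classification_loss: (?P<loss>[\d.]+)"
)


def parse_epochs(log_text):
    """Return (epoch, finish_time, loss) tuples for every 'epoch:' line in the log."""
    records = []
    for match in EPOCH_LINE.finditer(log_text):
        finished = datetime.strptime(match.group("ts"), "%Y-%m-%d %H:%M:%S,%f")
        records.append((int(match.group("epoch")), finished, float(match.group("loss"))))
    return records


def seconds_per_epoch(records):
    """Average wall-clock time between the first and last logged epochs."""
    (first_epoch, first_time, _), (last_epoch, last_time, _) = records[0], records[-1]
    return (last_time - first_time).total_seconds() / (last_epoch - first_epoch)


log = """\
[2018-07-29 15:41:05,699] epoch: 0 | classification_loss: 0.1172
[2018-07-31 21:08:55,310] epoch: 11 | classification_loss: 0.0330
"""
records = parse_epochs(log)
per_epoch = seconds_per_epoch(records)
remaining_epochs = 20 - (records[-1][0] + 1)  # max_epoch: 20, epochs are 0-based
print("%.2f h per epoch, about %.1f h left" % (per_epoch / 3600, remaining_epochs * per_epoch / 3600))
# prints: 4.86 h per epoch, about 38.9 h left
```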

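Next, loading the model:

This is where things got serious. The checkpoint stores only the model's parameters, so you have to build the network again and then load the parameters into it. Loading failed, though. The newly built model expects keys named like 'module.0.features.0.weight', while every key in the saved weight file looks like '0.features.0.weight'. In other words, my weight file lacked the 'module.' prefix. nn.DataParallel adds that prefix to every parameter name of the model it wraps, and the script below wraps the network with model = nn.DataParallel(model). My first idea was to edit the names directly in Notepad++. But the file is binary, and after editing it could not be read at all.

Searching for the error KeyError: 'unexpected key "module.encoder.embedding.weight"' turned up a similar question describing exactly the reverse of my situation. There, the newly built model had no 'module.' prefix while the saved weights did, so people stripped those first characters from every key in the weight file: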

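import torch
from collections import OrderedDict

# Load the saved weight file
state_dict = torch.load('myfile.pth.tar')
# Build a state_dict whose keys no longer start with `module.`
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    # remove `module.` only where it is actually present
    name = k[len('module.'):] if k.startswith('module.') else k
    new_state_dict[name] = v
# load into the freshly built (unwrapped) model
model.load_state_dict(new_state_dict)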

My case was exactly the opposite, so I needed to add those characters instead.

Prepend `module.` to every key:

from collections import OrderedDict

n_state = state['model']  # saved parameters, keys like '0.features.0.weight'
new_state_dict = OrderedDict()
for k, v in n_state.items():
    # the DataParallel-wrapped model expects 'module.0.features.0.weight'
    new_state_dict['module.' + k] = v
# load params
model.load_state_dict(new_state_dict)

Problem solved!
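Both directions are the same operation on an ordinary dict, because a state_dict is just an OrderedDict of tensors. The sketch below is plain Python. The helper names are mine, and in the usage line strings stand in for tensors. A simpler route also exists: nn.DataParallel keeps the wrapped network in its .module attribute, so you can skip the renaming and call model.module.load_state_dict(state['model']), or load the weights before wrapping the model at all.

```python
from collections import OrderedDict

PREFIX = "module."


def add_prefix(state_dict, prefix=PREFIX):
    """Return a copy whose keys all start with prefix (as nn.DataParallel expects)."""
    return OrderedDict(
        (key if key.startswith(prefix) else prefix + key, value)
        for key, value in state_dict.items()
    )


def strip_prefix(state_dict, prefix=PREFIX):
    """Return a copy with a leading prefix removed from every key that has it."""
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )


# Illustrative keys; the string values stand in for tensors
saved = OrderedDict([("0.features.0.weight", "w0"), ("0.features.1.weight", "w1")])
wrapped = add_prefix(saved)
```

Unlike slicing with k[7:], both helpers are idempotent: a key that already has (or lacks) the prefix is left alone.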

After that, only minor issues remained. Here is the complete code:

%matplotlib inline
from matplotlib import pyplot as plt
import json
import warnings
from collections import OrderedDict

import PIL.Image
import torch
import torch.nn as nn
from nest import modules

warnings.filterwarnings("ignore")

class_names = modules.pascal_voc_object_categories()
image_size = 448
# image pre-processor (same normalization as in config.yml)
transformer = modules.image_transform(
    image_size = [image_size, image_size],
    augmentation = dict(),
    mean = [0.485, 0.456, 0.406],
    std = [0.229, 0.224, 0.225])

# rebuild the network, then load the trained parameters into it
backbone = modules.fc_resnet50(num_classes=20, pretrained=False)
model = modules.peak_response_mapping(backbone)
model = nn.DataParallel(model)
# map_location='cpu' lets the checkpoint load on a machine without CUDA
state = torch.load('11.pt', map_location='cpu')
n_state = state['model']
new_state_dict = OrderedDict()
for k, v in n_state.items():
    # the checkpoint keys lack the 'module.' prefix that nn.DataParallel expects
    new_state_dict['module.' + k] = v
model.load_state_dict(new_state_dict)
# unwrap DataParallel again and run on the CPU
model = model.module.cpu()

idx = 4
raw_img = PIL.Image.open('F:/python/PRM-pytorch/demo/data/sample%d.jpg' % idx).convert('RGB')
input_var = transformer(raw_img).unsqueeze(0).cpu().requires_grad_()
# precomputed instance proposals for this image, decoded with rle_decode
with open('F:/python/PRM-pytorch/demo/data/sample%d.json' % idx, 'r') as f:
    proposals = list(map(modules.rle_decode, json.load(f)))
# raw image resized to the network input size
# (replaces scipy.misc.imresize, which was removed in SciPy 1.3)
resized_img = raw_img.resize((image_size, image_size), PIL.Image.BICUBIC)

# image-level classification
model = model.eval()
print('Object categories in the image:')
confidence = model(input_var)
for idx in range(len(class_names)):
    if confidence.data[0, idx] > 0:
        print('    [class_idx: %d] %s (%.2f)' % (idx, class_names[idx], confidence[0, idx].item()))

# class response map and peak response maps
model = model.inference()
visual_cues = model(input_var)
if visual_cues is None:
    print('No class peak response detected')
else:
    confidence, class_response_maps, class_peak_responses, peak_response_maps = visual_cues
    _, class_idx = torch.max(confidence, dim=1)
    class_idx = class_idx.item()
    num_plots = 2 + len(peak_response_maps)
    f, axarr = plt.subplots(1, num_plots, figsize=(num_plots * 4, 4))
    axarr[0].imshow(resized_img)
    axarr[0].set_title('Image')
    axarr[0].axis('off')
    axarr[1].imshow(class_response_maps[0, class_idx].cpu(), interpolation='bicubic')
    axarr[1].set_title('Class Response Map ("%s")' % class_names[class_idx])
    axarr[1].axis('off')
    for idx, (prm, peak) in enumerate(sorted(zip(peak_response_maps, class_peak_responses), key=lambda v: v[-1][-1])):
        axarr[idx + 2].imshow(prm.cpu(), cmap=plt.cm.jet)
        axarr[idx + 2].set_title('Peak Response Map ("%s")' % (class_names[peak[1].item()]))
        axarr[idx + 2].axis('off')

# predict instance masks via proposal retrieval
instance_list = model(input_var, retrieval_cfg=dict(proposals=proposals, param=(0.95, 1e-5, 0.8)))

# visualization
if instance_list is None:
    print('No object detected')
else:
    # peak response maps are merged if they select similar proposals
    vis = modules.prm_visualize(instance_list, class_names=class_names)
    f, axarr = plt.subplots(1, 3, figsize=(12, 5))
    axarr[0].imshow(resized_img)
    axarr[0].set_title('Image')
    axarr[0].axis('off')
    axarr[1].imshow(vis[0])
    axarr[1].set_title('Prediction')
    axarr[1].axis('off')
    axarr[2].imshow(vis[1])
    axarr[2].set_title('Peak Response Maps')
    axarr[2].axis('off')
    plt.show()

Results:

Tests on other images:

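This is actually where my question came up. To draw the masks, the author relies on precomputed proposals saved as .json files. I have worked with .json files before, namely the annotation files labelme generates, but when I compared the two, they turned out to be different.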
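The difference lies in what is stored. A labelme file keeps each object's outline as polygon vertices, in the points list of the example below. The demo's proposals, by contrast, are instance masks that the code decodes with modules.rle_decode, which suggests run-length-encoded binary masks. Both can be brought to the same form, an H×W array of 0/1. The sketch below is my own illustration, not PRM's code. polygon_to_mask rasterises a polygon with an even-odd test at pixel centres. rle_decode_uncompressed decodes a simple column-major run-length encoding of the COCO "uncompressed RLE" kind, where the counts alternate between runs of 0 and runs of 1, starting with 0. The actual layout inside the demo's sample*.json, and what rle_decode expects, may well differ.

```python
import numpy as np


def polygon_to_mask(points, height, width):
    """Rasterise a polygon given as [[x, y], ...] with the even-odd rule at pixel centres."""
    ys, xs = np.mgrid[0:height, 0:width]
    px, py = xs + 0.5, ys + 0.5
    inside = np.zeros((height, width), dtype=bool)
    n = len(points)
    for i in range(n):
        x1, y1 = points[i]
        x2, y2 = points[(i + 1) % n]
        if y1 == y2:
            continue  # a horizontal edge never crosses a horizontal ray
        # does this edge straddle the horizontal line through the pixel centre?
        crosses = (y1 > py) != (y2 > py)
        # x where the edge meets that line; count only crossings to the right
        x_at = x1 + (py - y1) * (x2 - x1) / (y2 - y1)
        inside ^= crosses & (px < x_at)
    return inside.astype(np.uint8)


def rle_decode_uncompressed(counts, height, width):
    """Decode alternating 0/1 run lengths (starting with 0) in column-major order."""
    flat = np.zeros(height * width, dtype=np.uint8)
    pos, value = 0, 0
    for run in counts:
        flat[pos:pos + run] = value
        pos += run
        value = 1 - value
    return flat.reshape((height, width), order="F")
```

For a labelme file you would call polygon_to_mask(shape["points"], h, w) for each entry in shapes, and then compare or combine the result with the decoded proposal masks.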

A labelme annotation file looks like this:

{
  "flags": {},
  "shapes": [
    {
      "label": "car",
      "line_color": null,
      "fill_color": null,
      "points": [
        [64, 77],
        [50, 91],
        [41, 102],
        [37, 115],
        [36, 130],
        [36, 138],
        [37, 153],
        [40, 163],
        [47, 167],
        [55, 167],
        [60, 165],
        [63, 157],
        [64, 154],
        [110, 160],
        [113, 167],
        [120, 174],
        [127, 178],
        [146, 178],
        [155, 174],
        [159, 167],
        [160, 165],
        [218, 164],
        [220, 170],
        [232, 175],
        [247, 174],
        [255, 169],
        [258, 163],
        [258, 161],
        [267, 160],
        [268, 155],
        [267, 154],
        [269, 141],
        [268, 131],
        [264, 121],
        [260, 113],
        [252, 109],
        [238, 103],
        [224, 99],
        [217, 98],
        [219, 96],
        [213, 93],
        [208, 91],
        [194, 80],
        [185, 74],
        [175, 71],
        [162, 69],
        [140, 68],
        [122, 67],
        [106, 67],
        [93, 67],
        [80, 70]
      ]
    }
  ],
  "lineColor": [0, 255, 0, 128],
  "fillColor": [255, 0, 0, 128],
  "imagePath": "..\\..\\..\\get_img\\China_Vehicle_Raw\\0001_6.jpg",
  "imageData": "