1. 程式人生 > >TSN演算法的PyTorch程式碼解讀(測試部分)

TSN演算法的PyTorch程式碼解讀(測試部分)

這篇部落格介紹TSN演算法的PyTorch程式碼的測試部分,建議先看訓練部分的程式碼解讀:TSN演算法的PyTorch程式碼解讀(訓練部分),test_moels.py是測試模型的入口。

前面模組匯入和命令列引數配置方面和訓練程式碼類似,不細講。

import argparse
import time

import numpy as np
import torch.nn.parallel
import torch.optim
from sklearn.metrics import confusion_matrix

from dataset import TSNDataSet
from models import TSN
from transforms import *
from ops import ConsensusModule

# options
parser
= argparse.ArgumentParser( description="Standard video-level testing") parser.add_argument('dataset', type=str, choices=['ucf101', 'hmdb51', 'kinetics']) parser.add_argument('modality', type=str, choices=['RGB', 'Flow', 'RGBDiff']) parser.add_argument('test_list', type=str) parser.add_argument('weights'
, type=str) parser.add_argument('--arch', type=str, default="resnet101") parser.add_argument('--save_scores', type=str, default=None) parser.add_argument('--test_segments', type=int, default=25) parser.add_argument('--max_num', type=int, default=-1) parser.add_argument('--test_crops', type=int, default=10
) parser.add_argument('--input_size', type=int, default=224) parser.add_argument('--crop_fusion_type', type=str, default='avg', choices=['avg', 'max', 'topk']) parser.add_argument('--k', type=int, default=3) parser.add_argument('--dropout', type=float, default=0.7) parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', help='number of data loading workers (default: 4)') parser.add_argument('--gpus', nargs='+', type=int, default=None) parser.add_argument('--flow_prefix', type=str, default='') args = parser.parse_args()

接下來先是根據資料集來確定類別數。然後通過models.py指令碼中的TSN類來匯入網路結構。另外如果想檢視得到的網路net的各層資訊,可以通過net.state_dict()來檢視。checkpoint = torch.load(args.weights)是匯入預訓練的模型,在PyTorch中,匯入模型都是採用torch.load()介面實現,輸入args.weights就是.pth檔案,也就是預訓練模型。base_dict = {'.'.join(k.split('.')[1:]): v for k,v in list(checkpoint.state_dict().items())}就是讀取預訓練模型的層和具體引數並存到base_dict這個字典中。net.load_state_dict(base_dict)就是通過呼叫torch.nn.Module類的load_state_dict方法,達到用預訓練模型初始化net網路的過程。需要注意的是load_state_dict方法還有一個輸入:strict,如果該引數為True,就表示網路結構的層資訊要和預訓練模型的層資訊嚴格相等,反之亦然,該引數預設是True。那麼什麼時候會用到False呢?就是當你只想用預訓練網路初始化你的網路的部分層引數或者說你的預訓練網路的層資訊和你要被初始化的網路的層資訊不完全一致,那樣就只會初始化層資訊相同的層。

if args.dataset == 'ucf101':
    num_class = 101
elif args.dataset == 'hmdb51':
    num_class = 51
elif args.dataset == 'kinetics':
    num_class = 400
else:
    raise ValueError('Unknown dataset '+args.dataset)

net = TSN(num_class, 1, args.modality,
          base_model=args.arch,
          consensus_type=args.crop_fusion_type,
          dropout=args.dropout)

checkpoint = torch.load(args.weights)

base_dict = {'.'.join(k.split('.')[1:]): v for k,v in list(checkpoint.state_dict().items())}
net.load_state_dict(base_dict)

接下來關於args.test_crops的條件語句是用來對資料做不同的crop操作:簡單crop操作和重複取樣的crop操作。如果args.test_crops等於1,就先resize到指定尺寸(比如從400resize到256),然後再做center crop操作,最後得到的是net.input_size的尺寸(比如224),注意這裡一張圖片做完這些crop操作後輸出還是一張圖片。如果args.test_crops等於10,那麼就呼叫該專案下的transforms.py指令碼中的GroupOverSample類進行重複取樣的crop操作,最終一張影象得到10張crop的結果,後面會詳細介紹GroupOverSample這個類。接下來的資料讀取部分和訓練時候類似,需要注意的是:1、num_segments的引數預設是25,比訓練時候要多的多。2、test_mode=True,所以在呼叫TSNDataSet類的__getitem__方法時和訓練時候有些差別。

if args.test_crops == 1:
    cropping = torchvision.transforms.Compose([
        GroupScale(net.scale_size),
        GroupCenterCrop(net.input_size),
    ])
elif args.test_crops == 10:
    cropping = torchvision.transforms.Compose([
        GroupOverSample(net.input_size, net.scale_size)
    ])
else:
    raise ValueError("Only 1 and 10 crops are supported while we got {}".format(args.test_crops))

data_loader = torch.utils.data.DataLoader(
        TSNDataSet("", args.test_list, num_segments=args.test_segments,
                   new_length=1 if args.modality == "RGB" else 5,
                   modality=args.modality,
                   image_tmpl="img_{:05d}.jpg" if args.modality in ['RGB', 'RGBDiff'] else args.flow_prefix+"{}_{:05d}.jpg",
                   test_mode=True,
                   transform=torchvision.transforms.Compose([
                       cropping,
                       Stack(roll=args.arch == 'BNInception'),
                       ToTorchFormatTensor(div=args.arch != 'BNInception'),
                       GroupNormalize(net.input_mean, net.input_std),
                   ])),
        batch_size=1, shuffle=False,
        num_workers=args.workers * 2, pin_memory=True)

transforms.py指令碼中的GroupOverSample類。首先__init__中的GroupScale類也是在transforms.py中定義的,其實是對輸入的n張影象都做torchvision.transforms.Scale操作,也就是resize到指定尺寸。GroupMultiScaleCrop.fill_fix_offset返回的offsets是一個長度為5的列表,每個值都是一個tuple,其中前4個是四個點座標,最後一個是中心點座標,目的是以這5個點為左上角座標時可以在原圖的四個角和中心部分crop出指定尺寸的圖,後面有例子介紹。crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h))是按照crop_w*crop_h的大小去crop原影象,這裡採用的是224*224。flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT)是對crop得到的影象做左右翻轉。最後把未翻轉的和翻轉後的列表合併,這樣一張輸入影象就可以得到10張輸出了(5張crop,5張crop加翻轉)。舉個例子,假設image_w=340,image_h=256,crop_w=224,crop_h=224,那麼offsets就是[(0,0),(116,0),(0,32),(116,32),(58,16)],因此第一個crop的結果就是原圖上左上角座標為(0,0),右下角座標為(224,224)的圖,這也就是原圖的左上角部分圖;第二個crop的結果就是原圖上左上角座標為(116,0),右下角座標為(340,224)的圖,這也就是原圖的右上角部分圖,其他依次類推分別是原圖的左下角部分圖和右下角部分圖,最後一個是原圖正中央crop出來的224*224圖。這就是論文中說的corner crop,而且是4個corner和1個center。

class GroupOverSample(object):
    def __init__(self, crop_size, scale_size=None):
        self.crop_size = crop_size if not isinstance(crop_size, int) else (crop_size, crop_size)

        if scale_size is not None:
            self.scale_worker = GroupScale(scale_size)
        else:
            self.scale_worker = None

    def __call__(self, img_group):

        if self.scale_worker is not None:
            img_group = self.scale_worker(img_group)

        image_w, image_h = img_group[0].size
        crop_w, crop_h = self.crop_size

        offsets = GroupMultiScaleCrop.fill_fix_offset(False, image_w, image_h, crop_w, crop_h)
        oversample_group = list()
        for o_w, o_h in offsets:
            normal_group = list()
            flip_group = list()
            for i, img in enumerate(img_group):
                crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h))
                normal_group.append(crop)
                flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT)

                if img.mode == 'L' and i % 2 == 0:
                    flip_group.append(ImageOps.invert(flip_crop))
                else:
                    flip_group.append(flip_crop)

            oversample_group.extend(normal_group)
            oversample_group.extend(flip_group)
        return oversample_group

接下來是設定GPU模式、設定模型為驗證模式、初始化資料等。

if args.gpus is not None:
    devices = [args.gpus[i] for i in range(args.workers)]
else:
    devices = list(range(args.workers))

net = torch.nn.DataParallel(net.cuda(devices[0]), device_ids=devices)
net.eval()

data_gen = enumerate(data_loader)

total_num = len(data_loader.dataset)
output = []

開始迴圈讀取資料,每執行一次迴圈表示讀取一個video的資料。在迴圈中主要是呼叫eval_video函式來測試。預測結果和真實標籤的結果都儲存在output列表中。

proc_start_time = time.time()
max_num = args.max_num if args.max_num > 0 else len(data_loader.dataset)

for i, (data, label) in data_gen:
    if i >= max_num:
        break
    rst = eval_video((i, data, label))
    output.append(rst[1:])
    cnt_time = time.time() - proc_start_time
    print('video {} done, total {}/{}, average {} sec/video'.format(i, i+1,
                                                                    total_num,
                                                                    float(cnt_time) / (i+1)))

eval_video函式是測試的主體,當準備好了測試資料和模型後就通過這個函式進行預測。輸入video_data是一個tuple:(i, data, label)。data.view(-1, length, data.size(2), data.size(3))是將原本輸入為(1,3*args.test_crops*args.test_segments,224,224)變換到(args.test_crops*args.test_segments,3,224,224),相當於batch size為args.test_crops*args.test_segments。然後用torch.autograd.Variable介面封裝成Variable型別資料並作為模型的輸入。net(input_var)得到的結果是Variable,如果要讀取Tensor內容,需讀取data變數,cpu()表示儲存到cpu,numpy()表示Tensor轉為numpy array,copy()表示拷貝。rst.reshape((num_crop, args.test_segments, num_class))表示將輸入維數(二維)變化到指定維數(三維),mean(axis=0)表示對num_crop維度取均值,也就是原來對某幀影象的10張crop或clip影象做預測,最後是取這10張預測結果的均值作為該幀影象的結果。最後再執行一個reshape操作。最後返回的是3個值,分別表示video的index,預測結果和video的真實標籤。

def eval_video(video_data):
    i, data, label = video_data
    num_crop = args.test_crops

    if args.modality == 'RGB':
        length = 3
    elif args.modality == 'Flow':
        length = 10
    elif args.modality == 'RGBDiff':
        length = 18
    else:
        raise ValueError("Unknown modality "+args.modality)

    input_var = torch.autograd.Variable(data.view(-1, length, data.size(2), data.size(3)),
                                        volatile=True)
    rst = net(input_var).data.cpu().numpy().copy()
    return i, rst.reshape((num_crop, args.test_segments, num_class)).mean(axis=0).reshape(
        (args.test_segments, 1, num_class)
    ), label[0]

接下來要計算video-level的預測結果,這裡從np.mean(x[0], axis=0)可以看出對args.test_segments幀影象的結果採取的也是均值方法來計算video-level的預測結果,然後通過np.argmax將概率最大的那個類別作為該video的預測類別。video_labels則是真實類別。cf = confusion_matrix(video_labels, video_pred).astype(float)是呼叫了混淆矩陣生成結果(numpy array),舉個例子,y_true=[2,0,2,2,0,1],y_pred=[0,0,2,2,0,2],那麼confusion_matrix(y_true, y_pred)的結果就是array([[2,0,0],[0,0,1],[1,0,2]]),每行表示真實類別,每列表示預測類別。因此cls_cnt = cf.sum(axis=1)表示每個真實類別有多少個video,cls_hit = np.diag(cf)就是將cf的對角線資料取出,表示每個類別的video中各預測對了多少個,因此cls_acc = cls_hit / cls_cnt就是每個類別的video預測準確率。np.mean(cls_acc)就是各類別的平均準確率。最後的if args.save_scores is not None:語句只是用來將預測結果儲存成檔案。

video_pred = [np.argmax(np.mean(x[0], axis=0)) for x in output]

video_labels = [x[1] for x in output]

cf = confusion_matrix(video_labels, video_pred).astype(float)

cls_cnt = cf.sum(axis=1)
cls_hit = np.diag(cf)

cls_acc = cls_hit / cls_cnt

print(cls_acc)

print('Accuracy {:.02f}%'.format(np.mean(cls_acc) * 100))

if args.save_scores is not None:

    # reorder before saving
    name_list = [x.strip().split()[0] for x in open(args.test_list)]

    order_dict = {e:i for i, e in enumerate(sorted(name_list))}

    reorder_output = [None] * len(output)
    reorder_label = [None] * len(output)

    for i in range(len(output)):
        idx = order_dict[name_list[i]]
        reorder_output[idx] = output[i]
        reorder_label[idx] = video_labels[i]

    np.savez(args.save_scores, scores=reorder_output, labels=reorder_label)