1. 程式人生 > 其它 >torch遷移訓練inceptionv3

torch遷移訓練inceptionv3

技術標籤:cv計算機視覺深度學習pytorch遷移學習

注意事項:

1

if pretrained:
        if 'transform_input' not in kwargs:
            kwargs['transform_input'] = True
        if 'aux_logits' in kwargs:
            original_aux_logits = kwargs['aux_logits']
            kwargs['aux_logits'] = False
        else:
            original_aux_logits =
False # we are loading weights from a pretrained model kwargs['init_weights'] = False model = Inception3(**kwargs) # state_dict = load_state_dict_from_url(model_urls['inception_v3_google'], # progress=progress) state_dict =
torch.load(model_path) model.load_state_dict(state_dict) if not original_aux_logits: model.aux_logits = False del model.AuxLogits return model

將所有aux涉及到的變數設成false。刪除所有的輔助分類器。

2

調整輸入尺寸為N x 3 x 299 x 299。色彩模式為RGB。

def Inception_loader(path):

    # ANTIALIAS:high quality
return Image.open(path).resize((299, 299), Image.ANTIALIAS).convert('RGB')

3

遇到一個問題,如下程式碼:

def _transform_input(self, x):
        if self.transform_input:
            x_ch0 = torch.unsqueeze(x[:, 0], 1) * \
                (0.229 / 0.5) + (0.485 - 0.5) / 0.5
            x_ch1 = torch.unsqueeze(x[:, 1], 1) * \
                (0.224 / 0.5) + (0.456 - 0.5) / 0.5
            x_ch2 = torch.unsqueeze(x[:, 2], 1) * \
                (0.225 / 0.5) + (0.406 - 0.5) / 0.5
            x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
        return x

是在計算三個色彩通道嗎?反正我關掉了。
目前推測,這個是為了將色彩通道前移的方法,但是裡面這些數字的含義仍然讓人無法理解。

(原稿)2021-1-24日夜
注意事項:
1.輸入影象 N x 3 x 299 x 299 的 尺寸必須被保證:
使用如下的自定義loader:

def Inception_loader(path):

    # ANTIALIAS:high quality
    return Image.open(path).resize((299, 299), Image.ANTIALIAS).convert('RGB')

2.關閉輔助分類器:

    if pretrained:
        if 'transform_input' not in kwargs:
            kwargs['transform_input'] = True
        if 'aux_logits' in kwargs:
            original_aux_logits = kwargs['aux_logits']
            kwargs['aux_logits'] = False
        else:
            original_aux_logits = False
        # we are loading weights from a pretrained model
        kwargs['init_weights'] = False
        model = Inception3(**kwargs)
        # state_dict = load_state_dict_from_url(model_urls['inception_v3_google'],
        #                                       progress=progress)
        state_dict = torch.load(model_path)
        model.load_state_dict(state_dict)
        if not original_aux_logits:
            model.aux_logits = False
            del model.AuxLogits
        return model

把所有的aux相關屬性全設成false就好了
3.

def _transform_input(self, x):
        if self.transform_input:
            x_ch0 = torch.unsqueeze(x[:, 0], 1) * \
                (0.229 / 0.5) + (0.485 - 0.5) / 0.5
            x_ch1 = torch.unsqueeze(x[:, 1], 1) * \
                (0.224 / 0.5) + (0.456 - 0.5) / 0.5
            x_ch2 = torch.unsqueeze(x[:, 2], 1) * \
                (0.225 / 0.5) + (0.406 - 0.5) / 0.5
            x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
        return x
        

這個程式碼一直比較疑惑是怎麼回事:像這種公式是怎麼推出來的?