torch遷移訓練inceptionv3
阿新 • • 發佈:2021-01-26
注意事項:
1
if pretrained:
if 'transform_input' not in kwargs:
kwargs['transform_input'] = True
if 'aux_logits' in kwargs:
original_aux_logits = kwargs['aux_logits']
kwargs['aux_logits'] = False
else:
original_aux_logits = False
# we are loading weights from a pretrained model
kwargs['init_weights'] = False
model = Inception3(**kwargs)
# state_dict = load_state_dict_from_url(model_urls['inception_v3_google'],
# progress=progress)
state_dict = torch.load(model_path)
model.load_state_dict(state_dict)
if not original_aux_logits:
model.aux_logits = False
del model.AuxLogits
return model
將所有aux涉及到的變數設成false。刪除所有的輔助分類器。
2
調整輸入尺寸為N x 3 x 299 x 299。色彩模式為RGB。
def Inception_loader(path):
# ANTIALIAS:high quality
return Image.open(path).resize((299, 299), Image.ANTIALIAS).convert('RGB')
3
遇到一個問題,如下程式碼:
def _transform_input(self, x):
if self.transform_input:
x_ch0 = torch.unsqueeze(x[:, 0], 1) * \
(0.229 / 0.5) + (0.485 - 0.5) / 0.5
x_ch1 = torch.unsqueeze(x[:, 1], 1) * \
(0.224 / 0.5) + (0.456 - 0.5) / 0.5
x_ch2 = torch.unsqueeze(x[:, 2], 1) * \
(0.225 / 0.5) + (0.406 - 0.5) / 0.5
x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
return x
是在計算三個色彩通道嗎?反正我關掉了。
目前推測,這個是為了將色彩通道前移的方法,但是裡面這些數字的含義仍然讓人無法理解。
(原稿)2021-1-24日夜
注意事項:
1.輸入影象 N x 3 x 299 x 299 的 尺寸必須被保證:
使用如下的自定義loader:
def Inception_loader(path):
# ANTIALIAS:high quality
return Image.open(path).resize((299, 299), Image.ANTIALIAS).convert('RGB')
2.關閉輔助分類器:
if pretrained:
if 'transform_input' not in kwargs:
kwargs['transform_input'] = True
if 'aux_logits' in kwargs:
original_aux_logits = kwargs['aux_logits']
kwargs['aux_logits'] = False
else:
original_aux_logits = False
# we are loading weights from a pretrained model
kwargs['init_weights'] = False
model = Inception3(**kwargs)
# state_dict = load_state_dict_from_url(model_urls['inception_v3_google'],
# progress=progress)
state_dict = torch.load(model_path)
model.load_state_dict(state_dict)
if not original_aux_logits:
model.aux_logits = False
del model.AuxLogits
return model
把所有的aux相關屬性全設成false就好了
3.
def _transform_input(self, x):
if self.transform_input:
x_ch0 = torch.unsqueeze(x[:, 0], 1) * \
(0.229 / 0.5) + (0.485 - 0.5) / 0.5
x_ch1 = torch.unsqueeze(x[:, 1], 1) * \
(0.224 / 0.5) + (0.456 - 0.5) / 0.5
x_ch2 = torch.unsqueeze(x[:, 2], 1) * \
(0.225 / 0.5) + (0.406 - 0.5) / 0.5
x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
return x
這個程式碼一直比較疑惑是怎麼回事:像這種公式是怎麼推出來的?