Pytorch訓練模型常用操作
阿新 • • 發佈:2021-10-21
One-hot編碼
將標籤轉換為one-hot編碼形式
def to_categorical(y, num_classes):
    """One-hot encode an integer class-index tensor.

    Args:
        y: integer tensor of class indices, any shape. Each value selects a
            row of an identity matrix, so values must lie in
            [-num_classes, num_classes).
        num_classes: number of classes, i.e. width of the one-hot dimension.

    Returns:
        float32 tensor of shape ``y.shape + (num_classes,)`` on the same
        device as ``y``.
    """
    # Build the identity matrix directly on y's device. This avoids a
    # GPU -> CPU -> numpy round trip, and unlike a bare ``.cuda()`` (which
    # uses the *current* CUDA device) it keeps the result on y's exact
    # device, e.g. cuda:1.
    eye = torch.eye(num_classes, device=y.device)
    # Advanced indexing needs an integer (long) index tensor.
    return eye[y.long()]
- 示例
>>> y = np.array([1, 2, 3])
>>> y
array([1, 2, 3])
>>> torch.eye(4)[y,]
tensor([[0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.]])
>>> y = np.array([[1, 2, 2], [1, 2, 3]])
>>> y
array([[1, 2, 2],
       [1, 2, 3]])
>>> torch.eye(4)[y,]
tensor([[[0., 1., 0., 0.],
         [0., 0., 1., 0.],
         [0., 0., 1., 0.]],

        [[0., 1., 0., 0.],
         [0., 0., 1., 0.],
         [0., 0., 0., 1.]]])
>>> torch.eye(4)[y]
tensor([1., 1., 0.])
注意：最後一例少了逗號，二維陣列被當成（列索引, 行索引）的元組，取出的是 eye[1,1]、eye[2,2]、eye[2,3] 三個元素，而不是 one-hot 矩陣，因此程式中需寫成 `[y, ]`（此行為可能隨 PyTorch 版本而異）。
按層類型分別初始化權重（Conv2d、Linear）
def weights_init(m):
    """Xavier-normal init for Conv2d/Linear weights and zero their biases.

    Intended for ``nn.Module.apply``, which calls it once per submodule.
    Modules are matched by class *name*, so any class whose name contains
    'Conv2d' or 'Linear' is initialized; all other modules are untouched.
    """
    classname = m.__class__.__name__
    if 'Conv2d' in classname or 'Linear' in classname:
        # torch.nn.init functions already run under no_grad, so the
        # parameter can be passed directly (no need for ``.data``).
        torch.nn.init.xavier_normal_(m.weight)
        # Layers built with bias=False have ``m.bias is None``; touching
        # ``m.bias.data`` unconditionally would raise AttributeError.
        if getattr(m, 'bias', None) is not None:
            torch.nn.init.constant_(m.bias, 0.0)


classifier = classifier.apply(weights_init)
檢查是否存在 checkpoint，以接續先前的訓練
# Resume from the best saved checkpoint if one exists; otherwise train from
# scratch.
checkpoint_path = str(exp_dir) + '/checkpoints/best_model.pth'
try:
    # map_location='cpu' lets a GPU-saved checkpoint load on any machine;
    # load_state_dict then copies the tensors onto the model's own device.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
except FileNotFoundError:
    # Only a missing file means "no checkpoint". Other failures (corrupt
    # file, missing keys, shape mismatch) now propagate instead of silently
    # restarting training on a possibly half-loaded model, as a bare
    # ``except:`` would.
    log_string('No existing model, starting training from scratch...')
    start_epoch = 0
else:
    start_epoch = checkpoint['epoch']
    classifier.load_state_dict(checkpoint['model_state_dict'])
    log_string('Use pretrain model')
根據訓練輪數（epoch）調整學習率與 BatchNorm 動量
def bn_momentum_adjust(m, momentum):
    """Set ``m.momentum`` when ``m`` is a BatchNorm1d or BatchNorm2d layer.

    Meant to be passed (wrapped in a lambda) to ``nn.Module.apply``; every
    other module type is left untouched.
    """
    bn_types = (torch.nn.BatchNorm2d, torch.nn.BatchNorm1d)
    if not isinstance(m, bn_types):
        return
    m.momentum = momentum
# Step-decay the learning rate by args.lr_decay every args.step_size epochs,
# never going below LEARNING_RATE_CLIP, then push it into every optimizer
# parameter group.
decay_steps = epoch // args.step_size
lr = max(args.learning_rate * (args.lr_decay ** decay_steps), LEARNING_RATE_CLIP)
log_string('Learning rate:%f' % lr)
for param_group in optimizer.param_groups:
    param_group['lr'] = lr

# Step-decay the BatchNorm momentum the same way, floored at 0.01.
momentum = max(MOMENTUM_ORIGINAL * (MOMENTUM_DECCAY ** (epoch // MOMENTUM_DECCAY_STEP)), 0.01)
print('BN momentum updated to: %f' % momentum)
# nn.Module.apply and .train() both return the module itself, so they chain.
classifier = classifier.apply(lambda x: bn_momentum_adjust(x, momentum)).train()
批量資料維度不一致
自定義 torch.utils.data.DataLoader(dataset, collate_fn=collate_fn) 中的 collate_fn，將長度不一的樣本補零對齊後再組成批次
def my_collate_fn(batch_data):
    """Zero-pad variable-length samples so a batch stacks into tensors.

    Intended to be passed as ``collate_fn`` to ``torch.utils.data.DataLoader``.

    :param batch_data: list of samples. Each sample is indexed as
        ``sample[0] == (data, label)``, i.e. the (data, label) pair is the
        sample's first element. NOTE(review): presumably the dataset wraps
        the pair in an outer tuple; confirm against the Dataset. ``data`` is
        a sequence of equal-width feature rows; ``label`` has one entry per
        row. The list is sorted in place by data length, ascending.
    :return: tuple ``(data_tensor, label_tensor)`` of float32 tensors with
        shapes ``(B, max_len, width)`` and ``(B, max_len)``. Shorter samples
        are padded at the end with all-zero rows and zero labels.
    """
    # Sort samples by data length, ascending; this is the output batch order.
    batch_data.sort(key=lambda x: len(x[0][0]), reverse=False)

    # After an ascending sort the LONGEST sample is the last one, and that is
    # the length every sample must be padded to. (Taking batch_data[0] would
    # give the shortest length, so nothing would be padded.)
    longest = batch_data[-1][0][0]
    max_len = len(longest)
    # Padding rows match the width of the real feature rows (e.g. 3 for xyz).
    # If max_len is 0 every sample is empty and no padding row is needed.
    pad_width = len(longest[0]) if max_len else 0

    data_list = []
    label_list = []
    for sample in batch_data:
        data, label = sample[0][0], sample[0][1]
        diff = max_len - len(data)
        # Build new padded lists rather than appending in place, so the
        # dataset's own (possibly cached) lists are never mutated.
        data_list.append(list(data) + [[0] * pad_width for _ in range(diff)])
        label_list.append(list(label) + [0] * diff)

    data_tensor = torch.tensor(data_list, dtype=torch.float32)
    label_tensor = torch.tensor(label_list, dtype=torch.float32)
    return data_tensor, label_tensor