1. 程式人生 > 其它 >Pytorch訓練模型常用操作

Pytorch訓練模型常用操作

One-hot編碼

將標籤轉換為one-hot編碼形式

def to_categorical(y, num_classes):
    """One-hot encode a tensor of integer class labels.

    Args:
        y: ``torch.Tensor`` of class indices, any shape. Values are used as
            row indices into an identity matrix, so they must be valid
            indices for ``num_classes`` rows. Non-``long`` dtypes are cast
            with ``.long()`` before indexing.
        num_classes: Number of classes, i.e. the length of the one-hot axis.

    Returns:
        Tensor of shape ``y.shape + (num_classes,)`` with the default float
        dtype, located on the same device as ``y``.
    """
    # Row i of eye(num_classes) is the one-hot vector for class i, so indexing
    # the identity matrix with the labels yields the encoding directly.
    # Creating the matrix on y.device avoids a host round-trip through NumPy
    # and keeps the result on y's own device (a bare .cuda() would move it to
    # the current default GPU instead).
    return torch.eye(num_classes, device=y.device)[y.long()]
  • 示例
>>> y = np.array([1,2,3])
>>> y
array([1, 2, 3])
>>> torch.eye(4)[y,]
tensor([[0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.]])

>>> y = np.array([[1, 2, 2], [1, 2, 3]])
>>> y
array([[1, 2, 2],
       [1, 2, 3]])
>>> torch.eye(4)[y,]
tensor([[[0., 1., 0., 0.],
         [0., 0., 1., 0.],
         [0., 0., 1., 0.]],

        [[0., 1., 0., 0.],
         [0., 0., 1., 0.],
         [0., 0., 0., 1.]]])
>>> torch.eye(4)[y]
tensor([1., 1., 0.])

分別初始化

def weights_init(m):
    """Initialise a Conv2d / Linear layer in place.

    A module is selected when its class name contains ``'Conv2d'`` or
    ``'Linear'``. Selected modules get Xavier-normal weights and, if they
    have a bias, a zero bias. Any other module is left untouched.

    Args:
        m: The module to initialise; selected modules must expose ``weight``.

    Returns:
        None. The module's parameters are modified in place.
    """
    classname = m.__class__.__name__
    if 'Conv2d' not in classname and 'Linear' not in classname:
        return

    torch.nn.init.xavier_normal_(m.weight.data)
    # Layers constructed with bias=False expose ``bias`` as None; skip them
    # instead of failing on ``None.data``.
    bias = getattr(m, 'bias', None)
    if bias is not None:
        torch.nn.init.constant_(bias.data, 0.0)

# Run weights_init on the model through its apply() method and rebind the result.
# NOTE(review): presumably classifier is a torch.nn.Module, whose apply() calls
# the function on every submodule and returns the module itself -- confirm.
classifier = classifier.apply(weights_init)

checkpoint檢查是否接著訓練

# Resume from the best saved checkpoint when one can be loaded; otherwise
# start training from epoch 0.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source (see the weights_only option in the torch.load docs).
try:
    checkpoint = torch.load(str(exp_dir) + '/checkpoints/best_model.pth')
    start_epoch = checkpoint['epoch']
    classifier.load_state_dict(checkpoint['model_state_dict'])
    log_string('Use pretrain model')
except Exception:
    # Any load failure (missing file, missing key, incompatible state dict)
    # falls back to a fresh run. Catching Exception rather than using a bare
    # except lets KeyboardInterrupt / SystemExit propagate.
    log_string('No existing model, starting training from scratch...')
    start_epoch = 0

根據迭代次數調整學習率


def bn_momentum_adjust(m, momentum):
    """Set ``m.momentum`` to ``momentum`` when ``m`` is a BatchNorm1d/BatchNorm2d.

    Modules of any other type are left unchanged.
    """
    if not isinstance(m, (torch.nn.BatchNorm2d, torch.nn.BatchNorm1d)):
        return
    m.momentum = momentum

# Step decay of the learning rate: scale by args.lr_decay once per
# args.step_size epochs, never going below LEARNING_RATE_CLIP.
lr = max(
    args.learning_rate * args.lr_decay ** (epoch // args.step_size),
    LEARNING_RATE_CLIP,
)
log_string('Learning rate:%f' % lr)
for param_group in optimizer.param_groups:
    param_group['lr'] = lr

# The BatchNorm momentum follows the same stepwise decay, floored at 0.01.
momentum = max(
    MOMENTUM_ORIGINAL * MOMENTUM_DECCAY ** (epoch // MOMENTUM_DECCAY_STEP),
    0.01,
)
print('BN momentum updated to: %f' % momentum)
# Hand the new momentum to bn_momentum_adjust through classifier.apply(),
# then call train() on the model.
classifier = classifier.apply(lambda module: bn_momentum_adjust(module, momentum))
classifier = classifier.train()

批量資料維度不一致

自定義 torch.utils.data.DataLoader(dataset, collate_fn=collate_fn) 中的 collate_fn

def my_collate_fn(batch_data):
    """Pad the variable-length samples of a batch to one length and stack them.

    Samples are ordered by ascending data length. Each sample is then padded
    at the end, with ``[0, 0, 0]`` rows for the data and ``0`` for the labels,
    until it is as long as the longest sample in the batch. The input
    samples are not modified.

    :param batch_data: list of samples. For each sample the code reads
        ``sample[0][0]`` as the data sequence and ``sample[0][1]`` as the
        label sequence. The padding row ``[0, 0, 0]`` implies that every data
        row holds three values -- TODO confirm against the dataset.
    :return: tuple ``(data_tensor, label_tensor)``, both ``torch.float32``,
        with the batch in ascending-length order.
    """
    # sorted() instead of list.sort() so the caller's batch list is untouched.
    ordered = sorted(batch_data, key=lambda x: len(x[0][0]))
    # Pad to the LONGEST sample; taking the first element of an ascending
    # sort would pick the shortest one and no padding would ever happen.
    max_len = max((len(sample[0][0]) for sample in ordered), default=0)

    data_list = []
    label_list = []
    for sample in ordered:
        data, label = sample[0][0], sample[0][1]
        diff = max_len - len(data)
        # Build new lists rather than appending to the sample's own lists,
        # otherwise the padding would accumulate in the dataset across epochs.
        data_list.append(list(data) + [[0, 0, 0]] * diff)
        label_list.append(list(label) + [0] * diff)

    data_tensor = torch.tensor(data_list, dtype=torch.float32)
    label_tensor = torch.tensor(label_list, dtype=torch.float32)
    return data_tensor, label_tensor