PyTorch knowledge summary (continuously updated)
阿新 · Published: 2018-12-09
1. Using argparse (one of three ways to pass arguments to a Python script: https://blog.csdn.net/u012426298/article/details/80263507)
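import argparse  # required
import os

import torch
import torchvision.models as models

# Names of the architectures torchvision provides (the approach used in the official PyTorch ImageNet example)
model_names = sorted(name for name in models.__dict__
                     if name.islower() and not name.startswith("__")
                     and callable(models.__dict__[name]))

parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')  # required
parser.add_argument('data', metavar='DIR', help='path to dataset')  # positional argument
parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18', choices=model_names,
                    help='model architecture: ' + ' | '.join(model_names) + ' (default: resnet18)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=90, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, metavar='LR',
                    help='initial learning rate')
# momentum and weight decay are read by the optimizer below, so they must be declared too
parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')
parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, metavar='W',
                    dest='weight_decay', help='weight decay (default: 1e-4)')
args = parser.parse_args()  # required

# Parsed values are available as attributes of args
traindir = os.path.join(args.data, 'train')
valdir = os.path.join(args.data, 'val')

model = models.__dict__[args.arch]()  # build the architecture chosen with -a/--arch
optimizer = torch.optim.SGD(model.parameters(), args.lr,
                            momentum=args.momentum, weight_decay=args.weight_decay)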
Run the script:
python main.py -a alexnet --lr 0.01 [imagenet-folder with train and val folders]
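The snippet below is a minimal, self-contained sketch (not part of the original script) of what `parse_args()` returns for a command like the one above. It keeps only a few of the options and swaps `model_names` for a hypothetical two-item `choices` list; passing a list to `parse_args()` instead of reading `sys.argv` makes it easy to try in an interpreter.

```python
import argparse

# Trimmed-down copy of the parser above; the choices list is a hypothetical stand-in for model_names
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('data', metavar='DIR', help='path to dataset')
parser.add_argument('-a', '--arch', default='resnet18', choices=['alexnet', 'resnet18'])
parser.add_argument('-j', '--workers', default=4, type=int)
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float)

# Equivalent to: python main.py /data/imagenet -a alexnet --lr 0.01
args = parser.parse_args(['/data/imagenet', '-a', 'alexnet', '--lr', '0.01'])
print(args.data, args.arch, args.workers, args.lr)  # /data/imagenet alexnet 4 0.01

# The long alias fills the same attribute, because dest comes from the first long option (--lr)
args2 = parser.parse_args(['/data/imagenet', '--learning-rate', '0.5'])
print(args2.lr)  # 0.5
```

Options that are not given keep their defaults (here `workers` stays 4), and `type=float` converts the string `'0.01'` to a float. A value outside `choices` makes argparse print an error and exit.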
2. Restricting which GPU is used
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # only GPU 1 is visible; set this before CUDA is initialized
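A slightly fuller sketch, assuming a machine with several GPUs: the variable only takes effect if it is set before PyTorch initializes CUDA, so it is placed before `import torch`. Setting `CUDA_DEVICE_ORDER=PCI_BUS_ID` is an optional addition (not in the original) so that the index matches the numbering shown by `nvidia-smi`. Inside the process, the single visible GPU is renumbered as `cuda:0`.

```python
import os

# Configure visibility before torch touches CUDA; doing it before the import is the safest place
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # number GPUs by PCI bus, like nvidia-smi
os.environ["CUDA_VISIBLE_DEVICES"] = "1"        # expose only physical GPU 1

import torch  # noqa: E402  (imported after the environment is configured)

# Physical GPU 1 now appears as cuda:0; fall back to the CPU if no GPU is visible
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(torch.cuda.device_count(), device)
```

The same restriction can be applied from the shell without changing the code, e.g. `CUDA_VISIBLE_DEVICES=1 python main.py ...`.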
3. Reading the values stored in a Tensor
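This item has no example, so here is a short sketch of the usual options: `.item()` for a one-element tensor, `.tolist()` for nested Python lists, and `.numpy()` for a NumPy array. The tensor values are made up for illustration.

```python
import torch

scalar = torch.tensor(3.5)
print(scalar.item())        # 3.5 (a plain Python float)

# .item() works on any one-element tensor, even one that is part of the autograd graph
loss = torch.tensor([0.25], requires_grad=True) * 2
print(loss.item())          # 0.5

t = torch.tensor([[1, 2], [3, 4]])
print(t.tolist())           # [[1, 2], [3, 4]]
print(t[1, 0].item())       # 3

# .numpy() needs a CPU tensor that does not require grad, hence detach() and cpu()
w = torch.ones(2, requires_grad=True)
arr = w.detach().cpu().numpy()
print(arr.tolist())         # [1.0, 1.0]
```

For a tensor that already lives on the CPU, the array returned by `.numpy()` shares memory with the tensor, so in-place changes to one are visible in the other; use `.copy()` on the array if an independent copy is needed.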
4. Finding the accuracy of each individual class in a classification task
import torch

# net, testloader and classes are assumed to be defined by the training code
num_classes = 10  # number of classes
class_correct = list(0. for i in range(num_classes))
class_total = list(0. for i in range(num_classes))
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs, 1)
        c = (predicted == labels)  # per-sample correctness for this batch
        for i in range(labels.size(0)):  # actual batch size; the last batch may be smaller
            label = labels[i].item()
            class_correct[label] += c[i].item()
            class_total[label] += 1
for i in range(num_classes):
    print('Accuracy of %5s : %2d %%' % (
        classes[i], 100 * class_correct[i] / class_total[i]))  # accuracy of each class
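The loop above visits every sample one at a time. A vectorized alternative (a different technique from the original, built on `torch.bincount`) counts correct and total samples per class for a whole batch at once. `per_class_accuracy` and the toy `batches` list are hypothetical; in a real test loop each pair would be `(predicted, labels)` from one batch, moved to the CPU first.

```python
import torch

def per_class_accuracy(batches, num_classes):
    """Accumulate per-class correct/total counts over (predicted, labels) pairs.

    All tensors are assumed to be 1-D integer tensors on the CPU.
    """
    hits = torch.zeros(num_classes, dtype=torch.long)
    total = torch.zeros(num_classes, dtype=torch.long)
    for predicted, labels in batches:
        # bincount counts occurrences of each class index; minlength fixes the output length
        total += torch.bincount(labels, minlength=num_classes)
        hits += torch.bincount(labels[predicted == labels], minlength=num_classes)
    # Classes that never occur give 0 / 0, which is NaN
    return hits.float() / total.float()

# Toy data: two batches for 4 classes (class 3 never appears)
batches = [
    (torch.tensor([0, 1, 1]), torch.tensor([0, 1, 0])),
    (torch.tensor([2, 2, 2]), torch.tensor([2, 2, 1])),
]
acc = per_class_accuracy(batches, num_classes=4)
print(acc.tolist())  # [0.5, 0.5, 1.0, nan]
```

Returning NaN for an empty class, rather than dividing by zero in plain Python as the loop version would, makes missing classes easy to spot.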
5.