pytorch 讀取和儲存模型引數
阿新 • • 發佈:2020-08-22
只儲存引數資訊
載入
checkpoint = torch.load(opt.resume)
model.load_state_dict(checkpoint)
儲存
torch.save(self.state_dict(),file_path)
這而只儲存了引數資訊,讀取時也只有引數資訊,模型結構需要手動編寫
儲存整個模型
儲存
torch.save(the_model, PATH)
載入:
the_model = torch.load(PATH)
有時候會看到載入時
model.load_state_dict(checkpoint['state_dic'])
這是因為checkpoint是一個字典,儲存的key可以自己定義。
儲存
torch.save({ 'epoch': epoch + 1, 'arch': args.arch, 'state_dict': model.state_dict(), 'best_prec1': best_prec1, }, 'checkpoint.tar' )
載入
if args.resume: if os.path.isfile(args.resume): print("=> loading checkpoint '{}'".format(args.resume)) checkpoint = torch.load(args.resume) args.start_epoch = checkpoint['epoch'] best_prec1 = checkpoint['best_prec1'] model.load_state_dict(checkpoint['state_dict']) print("=> loaded checkpoint '{}' (epoch {})" .format(args.evaluate, checkpoint['epoch']))
state_dict參考連結:
https://www.cnblogs.com/tingtin/p/13544489.html