1. 程式人生 > 程式設計 >pytorch梯度剪裁方式

pytorch梯度剪裁方式

我就廢話不多說,看例子吧!

import torch.nn as nn

outputs = model(data)
loss= loss_fn(outputs,target)
optimizer.zero_grad()
loss.backward()
nn.utils.clip_grad_norm_(model.parameters(),max_norm=20,norm_type=2)
optimizer.step()

nn.utils.clip_grad_norm_ 的引數:

parameters – 一個基於變數的迭代器,會進行梯度歸一化

max_norm – 梯度的最大範數

norm_type – 規定範數的型別,預設為L2

以上這篇pytorch梯度剪裁方式就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。