1. 程式人生 > 程式設計 >PyTorch實現更新部分網路,其他不更新

PyTorch實現更新部分網路,其他不更新

torch.Tensor.detach()的使用

detach()的官方說明如下:

Returns a new Tensor,detached from the current graph.
The result will never require gradient.

假設有模型A和模型B,我們需要將A的輸出作為B的輸入,但訓練時我們只訓練模型B. 那麼可以這樣做:

input_B = output_A.detach()

它可以使兩個計算圖的梯度傳遞斷開,從而實現我們所需的功能。

以上這篇PyTorch實現更新部分網路,其他不更新就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。