torch.nn與torch.nn.functional之間的區別和聯絡
阿新 • • 發佈:2018-12-18
原文地址:https://blog.csdn.net/GZHermit/article/details/78730856
迷惑的地方是在於forward的函式的定義方法。為什麼要把網路中的一部分層在__init__()函式裡定義出來,而另一部分層則是在__forward()__函式裡定義?並且一個用的是nn,另一個用的是nn.functional。同一種層的API定義有兩種,這樣看似冗餘的設計是為了什麼呢?
nn.Conv2d是一個類,而F.conv2d()是一個函式,而nn.Conv2d的forward()函式實現是用F.conv2d()實現的(在Module類裡的__call__實現了forward()函式的呼叫,所以當例項化nn.Conv2d類時,forward()函式也被執行了,詳細可閱讀torch原始碼),所以兩者功能並無區別,那麼為什麼要有這樣的兩種實現方式同時存在呢?
原因其實在於,為了兼顧靈活性和便利性。
在建圖過程中,往往有兩種層,一種如全連線層,卷積層等,當中有Variable,另一種如Pooling層,Relu層等,當中沒有Variable。
如果所有的層都用nn.functional來定義,那麼所有的Variable,如weights,bias等,都需要使用者來手動定義,非常不方便。
而如果所有的層都換成nn來定義,那麼即便是簡單的計算都需要建類來做,而這些可以用更為簡單的函式來代替的。