1. 程式人生 > 實用技巧 >pytorch(11)模型建立步驟與nn.Module

pytorch(11)模型建立步驟與nn.Module

模型建立與nn.Module

  1. 網路模型建立步驟
    nn.Module
graph LR 模型 --> 模型建立 模型建立 --> 構建網路層 構建網路層 --> id[卷積層,池化層,啟用函式層] 模型 --> 權值初始化 權值初始化 --> id1[Xavier,Kaiming,均勻分佈,正太分佈] 模型建立 --> 拼接網路層 拼接網路層 --> id2[LeNet,AlexNet,ResNet]

LeNet
Conv1 --> pool1 --> Conv2 --> pool2 --> fc1-->fc2 --> fc3

模型建立步驟:

graph LR A[模型構建兩要素] --> B[構建子模組] B[構建子模組] --> C["__init__()"] A[模型構建兩要素] --> E[拼接子模組] E[拼接子模組] --> F["forwar()"]
  1. nn.Module屬性
    在模型模組中,有一個非常的重要概念是nn.Module,所有的模型、網路層都會繼承nn.Module類的
graph LR A["torch.nn"] --> B["nn.Parameter"] B["nn.Parameter"] --> D["張量子類,表示可學習引數,如weights,bias"] A["torch.nn"] --> E["nn.Module"] E["nn.Module"] --> F["所有網路層基類,管理網路屬性"] A["torch.nn"] --> G["nn.functional"] G["nn.functional"] --> H["函式具體實現,如卷積,池化,啟用函式等"] A["torch.nn"] --> I["nn.init"] I["nn.init"] --> K["引數初始化方法"]

nn.Module八個重要的屬性,用於管理整個模型

  • parameters:儲存管理nn.Parameter類,例如權值、偏置等引數
  • modules:儲存管理nn.Module類。例如LeNet,它會構建它的子模組,卷積層、池化層。LeNet的modules示例,就會儲存它的卷積層、池化層。
  • buffers:儲存管理緩衝屬性,如BN層中的running_mean
  • ***_hooks:儲存管理鉤子函式
self._parameters = OrderedDict()
self._buffers = OrderedDict()
self._backward_hooks = OrderedDict()
self._forward_hooks = OrderedDict()
self._forward_pre_hooks = OrderedDict()
self._state_dict_hooks = OrderedDict()
self._load_state_dict_pre_hooks = OrderedDict()
self._modules = OrderedDict()

net = LeNet(classes=2)

  1. 此時呼叫lenet.py中的class LeNet(nn.Module)的__init__,繼承super(LeNet,self).__init__方法,這是一個module的方法,初始化parameters,buffers等引數。因此首先將LeNet作為一個module初始化完成。
  2. self.conv1 = nn.Conv2d(3, 6, 5)。此時呼叫conv中的Conv2d類,繼承的是_ConvNde類,使用其__init__方法,初始化引數(通道數、輸出通道數、padding等),super的初始化方法,即_ConvNd類,此類繼承Module類,因此也是一個module類,進行初始化。進入它的super方法,最終也是Module.
  3. self.conv1 = nn.Conv2d(3, 6, 5)。再呼叫賦值的方法時並不是直接賦值,而是在判斷是否是引數,如果是引數,那就賦值給parameter。如果不是引數,如果是module,那就把value的值賦給modules。如果是buffers,那就把value的值賦給buffers.
  4. self.fc1 = nn.Linear(1655, 120)。呼叫Linear的類,這個類也是繼承的Module類,super(Linear,self)類,也是一個Module的類。Linear的方法中的賦值都用到了module.py中的setattr方法,用來判斷是引數還是模型

nn.Module
一個module可以包含多個子module
一個module相當於一個運算,必須實現forward()函式
每個modeule都有8個字典管理它的屬性

pytorch
資料模組,將資料轉換為張量形式輸入模型,在深度學習模型中,對輸入的張量進行復雜的數學運算,進行分類、分割、目標檢測的輸出。