keras中卷積核維數與輸入維數的關係
阿新 • • 發佈:2019-02-04
裡面如果有錯的話,還望指正,謝謝。
在使用keras的卷積層時,需要選擇卷積核的維數,對應的函式分別為Conv1D和Conv2D等,剛開始使用時一直沒搞明白為啥輸入影象是3維的RGB影象,但是在配置卷積層時卻選擇Conv2D函式,而不是Conv3D,後來看keras的原始碼,才知道原因(我等菜鳥還是要多看原始碼才能進步。。)
在keras.layers中的convolutional.py中,基類函式_Conv中卷積核的尺寸是kernel_shape,可以看到程式碼中:
kernel_shape = self.kernel_size + (input_dim, self.filters)
self.kernel_size由conv_utils.normalize_tuple函式得到,該函式根據輸入的引數返回一個元組,引數分別是:kernel_size, rank, 'kernel_size'。
等式的第二個部分是一個元組,(input_dim, self.filters)分別表示輸入資料的維度和卷積核的個數,兩個元組相加,得到最終的卷積核的shape。例如在輸入為RGB影象時,input_dim=3(由輸入引數input_shape得出),filters=64,則:
kernel_shape = self.kernel_size + (input_dim, self.filters)=(3,3,3,64)
因此在計算引數個數時,若輸入尺寸為W*H*D,卷積核移動步長為s,同時padding,保持輸入輸出的W和H不變,考慮偏置項b,卷積核大小為(k,k),卷積核個數為n,則此時卷積層需要確定的引數個數為:n*(k*k*D+1),引數個數就與輸入影象的大小就無關啦,少了很多引數,因此上述例子中引數個數為64*(3*3*3+1),括號中前面的27為卷積核的引數,後面的加1為偏置項b。總的連線數為:[(W-k+2*padding)/s+1]*[(H-k+2*padding)/s+1]*n*(k*k*D+1)(這裡的偏置項不知道算不算是連線的一部分),(W-k+2*padding)/s+1和(H-k+2*padding)/s+1得到輸出的大小
裡面如果有錯的話,還望指正,謝謝。
在使用keras的卷積層時,需要選擇卷積核的維數,對應的函式分別為Conv1D和Conv2D等,剛開始使用時一直沒搞明白為啥輸入影象是3維的RGB影象,但是在配置卷積層時卻選擇Conv2D函式,而不是Conv3D,後來看keras的原始碼,才知道原因(我等菜鳥還是要多看原始碼才能進步。。)
在keras.layers中的convolutional.py中,基類函式_Conv中卷積核的尺寸是kernel_shape,可以看到程式碼中:
kernel_shape = self.kernel_size + (input_dim, self.filters)
self.kernel_size由conv_utils.normalize_tuple函式得到,該函式根據輸入的引數返回一個元組,引數分別是:kernel_size, rank, 'kernel_size'。
1、kernel_size是卷基層中人為指定的,可能為元組,也可能是一個整數,用於確定返回的元組kernel_size中元素的值;
2、rank為卷積核的維數(https://www.zhihu.com/question/21605094,第一個回答把秩的概念將的非常清楚,厲害),也是一個輸入時確定的引數,但是在繼承類中都已經指定了。用於確定返回的元組中kernel_size的大小;3、'kernel_size'僅僅用於報錯時指出錯誤。
例如在使用Conv2D時,假設設定的kernel_size =3,或者(3,3),rank已經設定為2,則返回的self.kernel_size=(3,3)等式的第二個部分是一個元組,(input_dim, self.filters)分別表示輸入資料的維度和卷積核的個數,兩個元組相加,得到最終的卷積核的shape。例如在輸入為RGB影象時,input_dim=3(由輸入引數input_shape得出),filters=64,則:
kernel_shape = self.kernel_size + (input_dim, self.filters)=(3,3,3,64)
因此在計算引數個數時,若輸入尺寸為W*H*D,卷積核移動步長為s,同時padding,保持輸入輸出的W和H不變,考慮偏置項b,卷積核大小為(k,k),卷積核個數為n,則此時卷積層需要確定的引數個數為:n*(k*k*D+1),引數個數就與輸入影象的大小就無關啦,少了很多引數,因此上述例子中引數個數為64*(3*3*3+1),括號中前面的27為卷積核的引數,後面的加1為偏置項b。總的連線數為:[(W-k+2*padding)/s+1]*[(H-k+2*padding)/s+1]*n*(k*k*D+1)(這裡的偏置項不知道算不算是連線的一部分),(W-k+2*padding)/s+1和(H-k+2*padding)/s+1得到輸出的大小
裡面如果有錯的話,還望指正,謝謝。