多輸入通道和多輸出通道
阿新 • • 發佈:2018-11-29
多輸入通道
多輸出通道:
設卷積核輸入通道,輸出通道是 ci , co,高寬為 kh , kw
為每個輸出通道分別建立 ci * kh * kw 的核陣列。在輸出通道維上連結,卷積核的形狀即為co * ci *kh*kw
from mxnet import autograd,nd from mxnet import gluon,init from mxnet.gluon import nn,loss as gloss from mxnet.gluon import data as gdata # 二維卷積層 def corr2d(X,K): h, w= K.shape Y = nd.zeros((X.shape[0] - h + 1,X.shape[1] - w + 1)) for i in range(Y.shape[0]): for j in range(Y.shape[1]): Y[i,j] = (X[i: i+h,j:j+w]*K).sum() return Y # 多通道輸入 def corr2d_multi_in(X,K): return nd.add_n(*[corr2d(x,k) for x,k in zip(X,K)]) X = nd.array([[[0,1,2],[3,4,5],[6,7,8]], [[1,2,3],[4,5,6],[7,8,9]]]) print(X) K = nd.array([[[0,1],[2,3]],[[1,2],[3,4]]]) print(corr2d_multi_in(X,K)) # 多通道輸出 # 為每個輸出通道分別建立 ci * kh * kw 的核陣列 # 將他們在輸出通道維上連結,卷積核的形狀即為 co * ci * kh * kw 的核陣列 def corr2d_multi_in_out(X,K): return nd.stack(*[corr2d_multi_in(X,k) for k in K]) # 3 通道核 3 * 2 * 2 * 2K = nd.stack(K,K+1,K+2) print(K) print(corr2d_multi_in_out(X,K))