1. 程式人生 > >多輸入通道和多輸出通道

多輸入通道和多輸出通道

多輸入通道

多輸出通道:

設卷積核輸入通道,輸出通道是 ci , co,高寬為 kh , kw

為每個輸出通道分別建立 ci * kh * kw 的核陣列。在輸出通道維上連結,卷積核的形狀即為co * ci *kh*kw

from mxnet import autograd,nd
from mxnet import gluon,init
from mxnet.gluon import nn,loss as gloss
from mxnet.gluon import data as gdata

# 二維卷積層
def corr2d(X,K):
    h, w 
= K.shape Y = nd.zeros((X.shape[0] - h + 1,X.shape[1] - w + 1)) for i in range(Y.shape[0]): for j in range(Y.shape[1]): Y[i,j] = (X[i: i+h,j:j+w]*K).sum() return Y # 多通道輸入 def corr2d_multi_in(X,K): return nd.add_n(*[corr2d(x,k) for x,k in zip(X,K)]) X = nd.array([[[0,1,2],[3,4,5],[6,7,8]], [[
1,2,3],[4,5,6],[7,8,9]]]) print(X) K = nd.array([[[0,1],[2,3]],[[1,2],[3,4]]]) print(corr2d_multi_in(X,K)) # 多通道輸出 # 為每個輸出通道分別建立 ci * kh * kw 的核陣列 # 將他們在輸出通道維上連結,卷積核的形狀即為 co * ci * kh * kw 的核陣列 def corr2d_multi_in_out(X,K): return nd.stack(*[corr2d_multi_in(X,k) for k in K]) # 3 通道核 3 * 2 * 2 * 2
K = nd.stack(K,K+1,K+2) print(K) print(corr2d_multi_in_out(X,K))