Learning Torch-nn: Table Layers
1. ConcatTable: applies the same input to each member module.
As illustrated:
                  +-----------+
             +----> {member1, |
+-------+    |    |           |
| input +----+---->  member2, |
+-------+    |    |           |
    or       +---->  member3} |
 {input}          +-----------+
Example:
mlp = nn.ConcatTable()
mlp:add(nn.Linear(5, 2))
mlp:add(nn.Linear(5, 3))

pred = mlp:forward(torch.randn(5))
for i, k in ipairs(pred) do print(i, k) end
2. ParallelTable: applies the i-th member module to the i-th input.
As illustrated:
+----------+         +-----------+
| {input1, +---------> {member1, |
|          |         |           |
|  input2, +--------->  member2, |
|          |         |           |
|  input3} +--------->  member3} |
+----------+         +-----------+
mlp = nn.ParallelTable()
mlp:add(nn.Linear(10, 2))
mlp:add(nn.Linear(5, 3))

x = torch.randn(10)
y = torch.rand(5)

pred = mlp:forward{x, y}
for i, k in pairs(pred) do print(i, k) end
3. MapTable: applies a single module to every element of the input table, cloning the module when there are more inputs than modules; all clones share their parameters (weight, bias, gradWeight and gradBias).
As illustrated:
+----------+         +-----------+
| {input1, +---------> {member,  |
|          |         |           |
|  input2, +--------->  clone,   |
|          |         |           |
|  input3} +--------->  clone}   |
+----------+         +-----------+
map = nn.MapTable()
map:add(nn.Linear(10, 3))

x1 = torch.rand(10)
x2 = torch.rand(10)

y = map:forward{x1, x2}
for i, k in pairs(y) do print(i, k) end
4. SplitTable: splits the input Tensor along the given dimension into a table of Tensors; the dimension index may also be negative, counting back from the last dimension (a negative-index sketch follows the basic examples below).
module = SplitTable(dimension, nInputDims)
    +----------+         +-----------+
    | input[1] +---------> {member1, |
  +----------+-+         |           |
  | input[2] +----------->  member2, |
+----------+-+           |           |
| input[3] +------------->  member3} |
+----------+             +-----------+
mlp = nn.SplitTable(2)
x = torch.randn(4, 3)
pred = mlp:forward(x)
for i, k in ipairs(pred) do print(i, k) end
mlp = nn.SplitTable(1, 2)

pred = mlp:forward(torch.randn(2, 4, 3))  -- four 2x3 tensors (first dimension is treated as the batch)
for i, k in ipairs(pred) do print(i, k) end

pred = mlp:forward(torch.randn(4, 3))     -- four tensors of size 3
for i, k in ipairs(pred) do print(i, k) end
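A minimal sketch of the negative-index case, assuming negative dimensions count back from the last dimension as stated above (the expected output shapes in the comments are inferred, not taken from the original):

mlp = nn.SplitTable(-1)      -- -1 refers to the last dimension of the input
x = torch.randn(4, 3)
pred = mlp:forward(x)        -- expected: three tensors of size 4
for i, k in ipairs(pred) do print(i, k) end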
A more complicated example
mlp = nn.Sequential()       -- Create a network that takes a Tensor as input
mlp:add(nn.SplitTable(2))
c = nn.ParallelTable()      -- The two Tensor slices go through two different Linear
c:add(nn.Linear(10, 3))     -- Layers in Parallel
c:add(nn.Linear(10, 7))
mlp:add(c)                  -- Outputting a table with 2 elements
p = nn.ParallelTable()      -- These tables go through two more linear layers separately
p:add(nn.Linear(3, 2))
p:add(nn.Linear(7, 1))
mlp:add(p)
mlp:add(nn.JoinTable(1))    -- Finally, the tables are joined together and output.

pred = mlp:forward(torch.randn(10, 2))
print(pred)

for i = 1, 100 do           -- A few steps of training such a network..
   x = torch.ones(10, 2)
   y = torch.Tensor(3)
   y:copy(x:select(2, 1):narrow(1, 1, 3))
   pred = mlp:forward(x)

   criterion = nn.MSECriterion()
   local err = criterion:forward(pred, y)
   local gradCriterion = criterion:backward(pred, y)
   mlp:zeroGradParameters()
   mlp:backward(x, gradCriterion)
   mlp:updateParameters(0.05)
   print(err)
end
5. JoinTable: takes a table of Tensors and concatenates them along the given dimension into a single Tensor.
module = JoinTable(dimension, nInputDims)
+----------+             +-----------+
| {input1, +-------------> output[1] |
|          |           +-----------+-+
|  input2, +----------->  output[2] |
|          |         +-----------+-+
|  input3} +--------->  output[3] |
+----------+         +-----------+
x = torch.randn(5, 1)
y = torch.randn(5, 1)
z = torch.randn(2, 1)

print(nn.JoinTable(1):forward{x, y})  -- 10x1
print(nn.JoinTable(2):forward{x, y})  -- 5x2
module = nn.JoinTable(2, 2)

x = torch.randn(3, 1)
y = torch.randn(3, 1)
mx = torch.randn(2, 3, 1)
my = torch.randn(2, 3, 1)

print(module:forward{x, y})    -- 3x2
print(module:forward{mx, my})  -- 2x3x2
6. MixtureTable: takes a table {gater, experts} as input and outputs the mixture of the experts weighted by the gater:
output = G[1]*E[1] + G[2]*E[2] + ... + G[n]*E[n]
where dim = 1, n = E:size(dim) = G:size(dim) and G:dim() == 1. Note that E:dim() >= 2, such that output:dim() = E:dim() - 1.
n = 3  -- number of experts (value assumed here; not fixed in the original)

experts = nn.ConcatTable()
for i = 1, n do
   local expert = nn.Sequential()
   expert:add(nn.Linear(3, 4))
   expert:add(nn.Tanh())
   expert:add(nn.Linear(4, 5))
   expert:add(nn.Tanh())
   experts:add(expert)
end

gater = nn.Sequential()
gater:add(nn.Linear(3, 7))
gater:add(nn.Tanh())
gater:add(nn.Linear(7, n))
gater:add(nn.SoftMax())

trunk = nn.ConcatTable()
trunk:add(gater)
trunk:add(experts)

moe = nn.Sequential()
moe:add(trunk)
moe:add(nn.MixtureTable())
Forwarding a batch of 2 examples gives us something like this:
> =moe:forward(torch.randn(2, 3))
-0.2152  0.3141  0.3280 -0.3772  0.2284
 0.2568  0.3511  0.0973 -0.0912 -0.0599
[torch.DoubleTensor of dimension 2x5]

7. SelectTable: selects and passes on the index-th element of the input table; on the backward pass, the non-selected elements receive zero gradients of the same shape.
> input = {torch.randn(2, 3), torch.randn(2, 1)}
> =nn.SelectTable(1):forward(input)
-0.3060  0.1398  0.2707
 0.0576  1.5455  0.0610
> input = {torch.randn(2, 3), {torch.randn(2, 1), {torch.randn(2, 2)}}}
> gradInput = nn.SelectTable(1):backward(input, torch.randn(2, 3))
> =gradInput
{
  1 : DoubleTensor - size: 2x3
  2 :
    {
      1 : DoubleTensor - size: 2x1
      2 :
        {
          1 : DoubleTensor - size: 2x2
        }
    }
}
8. NarrowTable: takes a table as input and outputs the length elements of that table starting at position offset.
module = NarrowTable(offset [, length])
> input = {torch.randn(2, 3), torch.randn(2, 1), torch.randn(1, 2)}
> =nn.NarrowTable(2, 2):forward(input)
{
  1 : DoubleTensor - size: 2x1
  2 : DoubleTensor - size: 1x2
}

> =nn.NarrowTable(1):forward(input)
{
  1 : DoubleTensor - size: 2x3
}
9. FlattenTable: flattens a nested table of Tensors into a flat table; the flattened elements are numbered in the order they are visited by a depth-first traversal of the nested input.
x = {torch.rand(1), {torch.rand(2), {torch.rand(3)}}, torch.rand(4)}
print(x)
print(nn.FlattenTable():forward(x))
{
  1 : DoubleTensor - size: 1
  2 :
    {
      1 : DoubleTensor - size: 2
      2 :
        {
          1 : DoubleTensor - size: 3
        }
    }
  3 : DoubleTensor - size: 4
}
{
  1 : DoubleTensor - size: 1
  2 : DoubleTensor - size: 2
  3 : DoubleTensor - size: 3
  4 : DoubleTensor - size: 4
}
10. PairwiseDistance: outputs the p-norm distance between the two input Tensors.
module = PairwiseDistance(p)  -- p is the norm degree
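A minimal usage sketch (the value in the comment follows from the p = 2 norm of the difference between the two vectors):

mlp = nn.PairwiseDistance(2)  -- Euclidean (p = 2) distance
x = torch.Tensor({1, 2, 3})
y = torch.Tensor({4, 5, 6})
print(mlp:forward({x, y}))    -- sqrt(3^2 + 3^2 + 3^2) ≈ 5.196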
11. DotProduct: outputs the dot product of the two input Tensors.
mlp = nn.DotProduct()
x = torch.Tensor({1, 2, 3})
y = torch.Tensor({4, 5, 6})
print(mlp:forward({x, y}))
12. CosineDistance: outputs the cosine of the angle between the two input Tensors.
mlp = nn.CosineDistance()
x = torch.Tensor({1, 2, 3})
y = torch.Tensor({4, 5, 6})
print(mlp:forward({x, y}))
13. CriterionTable: wraps a Criterion module so that it can accept a table of inputs.
mlp = nn.CriterionTable(nn.MSECriterion())
x = torch.randn(5)
y = torch.randn(5)
print(mlp:forward{x, x})  -- 0
print(mlp:forward{x, y})  -- 1.9028918413199
14. CAddTable, CSubTable, CMulTable, CDivTable, CMaxTable, CMinTable: element-wise addition, subtraction, multiplication, division, max and min over a table of input Tensors of the same size.
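A minimal sketch of these element-wise table modules (the expected outputs in the comments are assumptions from simple arithmetic; CMaxTable and CMinTable only exist in more recent versions of nn):

x = torch.Tensor({1, 2, 3})
y = torch.Tensor({4, 5, 6})

print(nn.CAddTable():forward({x, y}))  --  5   7   9
print(nn.CSubTable():forward({x, y}))  -- -3  -3  -3
print(nn.CMulTable():forward({x, y}))  --  4  10  18
print(nn.CDivTable():forward({x, y}))  --  0.25  0.4  0.5
print(nn.CMaxTable():forward({x, y}))  --  4   5   6
print(nn.CMinTable():forward({x, y}))  --  1   2   3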