
Torch-nn Study Notes: Table Layers

1. ConcatTable: applies the same input to every member module.

As shown:

                  +-----------+
             +----> {member1, |
+-------+    |    |           |
| input +----+---->  member2, |
+-------+    |    |           |
   or        +---->  member3} |
 {input}          +-----------+
Example:
mlp = nn.ConcatTable()
mlp:add(nn.Linear(5, 2))
mlp:add(nn.Linear(5, 3))
pred = mlp:forward(torch.randn(5))
for i, k in ipairs(pred) do print(i, k) end

2. ParallelTable: applies the i-th member module to the i-th input.

As shown:

+----------+         +-----------+
| {input1, +---------> {member1, |
|          |         |           |
|  input2, +--------->  member2, |
|          |         |           |
|  input3} +--------->  member3} |
+----------+         +-----------+
mlp = nn.ParallelTable()
mlp:add(nn.Linear(10, 2))
mlp:add(nn.Linear(5, 3))

x = torch.randn(10)
y = torch.rand(5)

pred = mlp:forward{x, y}
for i, k in pairs(pred) do print(i, k) end

3. MapTable: applies a single member module to every input, cloning it when there are more inputs than members. All clones share their parameters (weight, bias, gradWeight and gradBias); see the check after the example below.

As shown:

+----------+         +-----------+
| {input1, +---------> {member,  |
|          |         |           |
|  input2, +--------->  clone,   |
|          |         |           |
|  input3} +--------->  clone}   |
+----------+         +-----------+
map = nn.MapTable()
map:add(nn.Linear(10, 3))

x1 = torch.rand(10)
x2 = torch.rand(10)
y = map:forward{x1, x2}

for i, k in pairs(y) do print(i, k) end
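
Since the parameters are shared, changing the base module's weights is visible through every clone. A minimal check, assuming the standard nn.Container accessor get and that the base module is the container's first member (implementation details, not from the original page):

map:forward{x1, x2}        -- builds the clones for the two inputs
map:get(1).weight:zero()   -- zero the base module's weight...
map:get(1).bias:zero()     -- ...and bias; the clones share this storage
print(map:forward{x1, x2}) -- both outputs are now all-zero vectors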

4. SplitTable: splits the input tensor into a table of tensors along the given dimension; the dimension index may be negative, counting from the last dimension (see the sketch after the examples below).

module = SplitTable(dimension, nInputDims)
    +----------+         +-----------+
    | input[1] +---------> {member1, |
  +----------+-+         |           |
  | input[2] +----------->  member2, |
+----------+-+           |           |
| input[3] +------------->  member3} |
+----------+             +-----------+
mlp = nn.SplitTable(2)
x = torch.randn(4, 3)
pred = mlp:forward(x) -- three vectors of size 4
for i, k in ipairs(pred) do print(i, k) end
mlp = nn.SplitTable(1, 2)
pred = mlp:forward(torch.randn(2, 4, 3)) -- four 2x3 matrices (first dimension is the batch)
for i, k in ipairs(pred) do print(i, k) end
pred = mlp:forward(torch.randn(4, 3)) -- four vectors of size 3 (no batch dimension)
for i, k in ipairs(pred) do print(i, k) end
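
A small sketch of the negative-index form (my example, not from the original page); -1 addresses the last dimension:

mlp = nn.SplitTable(-1)               -- split along the last dimension
pred = mlp:forward(torch.randn(4, 3)) -- three vectors of size 4, same as SplitTable(2) here
for i, k in ipairs(pred) do print(i, k) end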

A more complicated example

mlp = nn.Sequential()       -- Create a network that takes a Tensor as input
mlp:add(nn.SplitTable(2))
c = nn.ParallelTable()      -- The two Tensor slices go through two different Linear
c:add(nn.Linear(10, 3))     -- Layers in Parallel
c:add(nn.Linear(10, 7))
mlp:add(c)                  -- Outputting a table with 2 elements
p = nn.ParallelTable()      -- These tables go through two more linear layers separately
p:add(nn.Linear(3, 2))
p:add(nn.Linear(7, 1))
mlp:add(p)
mlp:add(nn.JoinTable(1))    -- Finally, the tables are joined together and output.

pred = mlp:forward(torch.randn(10, 2))
print(pred)

for i = 1, 100 do           -- A few steps of training such a network..
   x = torch.ones(10, 2)
   y = torch.Tensor(3)
   y:copy(x:select(2, 1):narrow(1, 1, 3))
   pred = mlp:forward(x)

   criterion = nn.MSECriterion()
   local err = criterion:forward(pred, y)
   local gradCriterion = criterion:backward(pred, y)
   mlp:zeroGradParameters()
   mlp:backward(x, gradCriterion)
   mlp:updateParameters(0.05)

   print(err)
end

5. JoinTable: concatenates a table of tensors into a single tensor along the given dimension.

module = JoinTable(dimension, nInputDims)

+----------+             +-----------+
| {input1, +-------------> output[1] |
|          |           +-----------+-+
|  input2, +-----------> output[2] |
|          |         +-----------+-+
|  input3} +---------> output[3] |
+----------+         +-----------+
x = torch.randn(5, 1)
y = torch.randn(5, 1)
z = torch.randn(2, 1)

print(nn.JoinTable(1):forward{x, y}) -- 10x1
print(nn.JoinTable(2):forward{x, y}) -- 5x2
print(nn.JoinTable(1):forward{x, z}) -- 7x1 (z was unused above; sizes may differ along the join dimension)
module = nn.JoinTable(2, 2)

x = torch.randn(3, 1)
y = torch.randn(3, 1)

mx = torch.randn(2, 3, 1)
my = torch.randn(2, 3, 1)

print(module:forward{x, y})   -- 3x2
print(module:forward{mx, my}) -- 2x3x2 (first dimension treated as a batch, since nInputDims = 2)

6.MixtureTable:

input: a table {gater, experts}

output = G[1]*E[1] + G[2]*E[2] + ... + G[n]*E[n]

where dim = 1, n = E:size(dim) = G:size(dim), and G:dim() == 1. Note that E:dim() >= 2, such that output:dim() = E:dim() - 1.

n = 2 -- number of experts (the original uses n without defining it)
experts = nn.ConcatTable()
for i = 1, n do
   local expert = nn.Sequential()
   expert:add(nn.Linear(3, 4))
   expert:add(nn.Tanh())
   expert:add(nn.Linear(4, 5))
   expert:add(nn.Tanh())
   experts:add(expert)
end

gater = nn.Sequential()
gater:add(nn.Linear(3, 7))
gater:add(nn.Tanh())
gater:add(nn.Linear(7, n))
gater:add(nn.SoftMax())

trunk = nn.ConcatTable()
trunk:add(gater)
trunk:add(experts)

moe = nn.Sequential()
moe:add(trunk)
moe:add(nn.MixtureTable())

Forwarding a batch of 2 examples gives us something like this:

> =moe:forward(torch.randn(2, 3))
-0.2152  0.3141  0.3280 -0.3772  0.2284
 0.2568  0.3511  0.0973 -0.0912 -0.0599
[torch.DoubleTensor of dimension 2x5]
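
To make the weighted-sum formula concrete, here is a hand-checkable sketch with my own toy values, assuming the table-of-experts, non-batch form:

g  = torch.Tensor{0.7, 0.3}       -- gater weights, n = 2
e1 = torch.Tensor{1, 1, 1, 1, 1}  -- expert 1 output
e2 = torch.Tensor{0, 0, 10, 0, 0} -- expert 2 output
print(nn.MixtureTable():forward{g, {e1, e2}})
-- 0.7*e1 + 0.3*e2 = 0.7  0.7  3.7  0.7  0.7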
7. SelectTable: outputs the element at the given index of the input table; during backward, gradients for the non-selected elements are filled with zeros, as the gradInput printout below shows.
> input = {torch.randn(2, 3), torch.randn(2, 1)}
> =nn.SelectTable(1):forward(input)
-0.3060  0.1398  0.2707
 0.0576  1.5455  0.0610
> input = {torch.randn(2, 3), {torch.randn(2, 1), {torch.randn(2, 2)}}}
> gradInput = nn.SelectTable(1):backward(input, torch.randn(2, 3))

> =gradInput
{
  1 : DoubleTensor - size: 2x3
  2 :
    {
      1 : DoubleTensor - size: 2x1
      2 :
        {
          1 : DoubleTensor - size: 2x2
        }
    }
}
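
The index may also be negative, counting from the end of the table; a quick sketch (this relies on nn.SelectTable accepting negative indices):

> =nn.SelectTable(-1):forward(input) -- selects the last element, here the nested subtable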

8. NarrowTable: takes a table as input and outputs the subtable of length elements starting at index offset (length defaults to 1).

module = NarrowTable(offset [, length])

> input = {torch.randn(2, 3), torch.randn(2, 1), torch.randn(1, 2)}
> =nn.NarrowTable(2,2):forward(input)
{
  1 : DoubleTensor - size: 2x1
  2 : DoubleTensor - size: 1x2
}

> =nn.NarrowTable(1):forward(input)
{
  1 : DoubleTensor - size: 2x3
}

9. FlattenTable: flattens a nested table of tensors into a flat table; elements are numbered in the order a depth-first traversal visits them.

x = {torch.rand(1), {torch.rand(2), {torch.rand(3)}}, torch.rand(4)}
print(x)
print(nn.FlattenTable():forward(x))
{
  1 : DoubleTensor - size: 1
  2 :
    {
      1 : DoubleTensor - size: 2
      2 :
        {
          1 : DoubleTensor - size: 3
        }
    }
  3 : DoubleTensor - size: 4
}
{
  1 : DoubleTensor - size: 1
  2 : DoubleTensor - size: 2
  3 : DoubleTensor - size: 3
  4 : DoubleTensor - size: 4
}

10. PairwiseDistance:

module = PairwiseDistance(p) -- outputs the p-norm distance between the two inputs
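
The page gives no example here, so a minimal sketch for the Euclidean case (p = 2):

mlp = nn.PairwiseDistance(2)
x = torch.Tensor({1, 2, 3})
y = torch.Tensor({4, 5, 6})
print(mlp:forward({x, y})) -- ||x - y||_2 = sqrt(9 + 9 + 9) ≈ 5.1962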

11. DotProduct: outputs the dot product of its two input tensors.

mlp = nn.DotProduct()
x = torch.Tensor({1, 2, 3})
y = torch.Tensor({4, 5, 6})
print(mlp:forward({x, y})) -- 1*4 + 2*5 + 3*6 = 32
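
DotProduct also accepts batches; a sketch with my own sizes, where two n x d inputs yield one dot product per row:

a = torch.randn(4, 3)
b = torch.randn(4, 3)
print(mlp:forward({a, b})) -- a vector of 4 row-wise dot products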

12. CosineDistance: outputs the cosine similarity between its two input tensors.

mlp = nn.CosineDistance()
x = torch.Tensor({1, 2, 3})
y = torch.Tensor({4, 5, 6})
print(mlp:forward({x, y})) -- 32 / (sqrt(14) * sqrt(77)) ≈ 0.9746

13. CriterionTable: wraps a Criterion module so that it can accept a table of inputs.

mlp = nn.CriterionTable(nn.MSECriterion())
x = torch.randn(5)
y = torch.randn(5)
print(mlp:forward{x, x}) -- 0
print(mlp:forward{x, y}) -- e.g. 1.9028918413199 (depends on the random y)

14. CAddTable, CSubTable, CMulTable, CDivTable, CMaxTable, CMinTable: element-wise addition, subtraction, multiplication, division, maximum and minimum over a table of same-sized tensors; a quick sketch follows.
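
A sketch with my own values; note that CSubTable and CDivTable expect exactly two tensors, while the others accept any number:

a = torch.Tensor{1, 2}
b = torch.Tensor{3, 4}
print(nn.CAddTable():forward{a, b}) --  4   6
print(nn.CSubTable():forward{a, b}) -- -2  -2
print(nn.CMulTable():forward{a, b}) --  3   8
print(nn.CDivTable():forward{a, b}) --  0.3333  0.5
print(nn.CMaxTable():forward{a, b}) --  3   4
print(nn.CMinTable():forward{a, b}) --  1   2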