
Creating MXNet operators (network layers) with Python

This post builds a custom MXNet operator (network layer) in Python: a "reorg" (space-to-depth) layer that rearranges every 2x2 spatial block of a feature map into 4 consecutive channels. The layer subclasses mx.operator.NDArrayOp, compiles its forward and backward passes as CUDA kernels at runtime through mx.rtc, and is then dropped into a small convolutional network trained on MNIST.

# -*- coding: utf-8 -*-
import mxnet as mx
import numpy as np
import logging

class NDArrayReorg(mx.operator.NDArrayOp):
    """Reorg (space-to-depth) layer: each 2x2 spatial block of every input
    channel is rearranged into 4 consecutive output channels."""

    def __init__(self, stride=2):
        super(NDArrayReorg, self).__init__(True)
        self.stride = stride
        self.fwd_kernel = None
        self.bwd_kernel = None

    def list_arguments(self):
        return ['data']

    def list_outputs(self):
        return ['output']

    def infer_shape(self, in_shape):
        data_shape = in_shape[0]
        # (N, C, H, W) -> (N, 4*C, H/stride, W/stride); integer division keeps
        # the shape entries ints under Python 3
        output_shape = [in_shape[0][0], in_shape[0][1] * 4,
                        in_shape[0][2] // self.stride, in_shape[0][3] // self.stride]
        return [data_shape], [output_shape]

    def forward(self, in_data, out_data):
        x = in_data[0]
        y = out_data[0]
        if self.fwd_kernel is None:
            # One CUDA block per (batch, channel) pair; note the kernel assumes
            # stride == 2 (the step of 2 is hard-coded below).
            self.fwd_kernel = mx.rtc('reorg', [('x', x)], [('y', y)], """
            int i = threadIdx.x + blockIdx.x * blockDim.x;
            int yw = y_dims[3];
            int yh = y_dims[2];
            int N = yw * yh;
            int xw = x_dims[3];
            int xh = x_dims[2];
            int len_block = x_dims[2] * x_dims[3];
            for (int j = 0; j < xh; j += 2)
                for (int k = 0; k < xw; k += 2) {
                    int t = j / 2;
                    y[i*len_block + t*yw + k/2]       = x[i*len_block + j*xw + k];
                    y[i*len_block + t*yw + k/2 + N]   = x[i*len_block + j*xw + k + 1];
                    y[i*len_block + t*yw + k/2 + 2*N] = x[i*len_block + (j+1)*xw + k];
                    y[i*len_block + t*yw + k/2 + 3*N] = x[i*len_block + (j+1)*xw + k + 1];
                }
            """)
        self.fwd_kernel.push([x], [y], (x.shape[0] * x.shape[1], 1, 1), (1, 1, 1))

    def backward(self, out_grad, in_data, out_data, in_grad):
        y = out_grad[0]
        dx = in_grad[0]
        if self.bwd_kernel is None:
            # Exact inverse of the forward kernel: phase k of each output
            # channel group (k = 2*row_offset + col_offset) scatters back to
            # the corresponding position of the 2x2 input block.
            self.bwd_kernel = mx.rtc('reorg_grad', [('y', y)], [('dx', dx)], """
            int i = threadIdx.x + blockIdx.x * blockDim.x;
            int yh = y_dims[2];
            int yw = y_dims[3];
            int N = yw * yh;
            int xw = dx_dims[3];
            int old_block = dx_dims[2] * dx_dims[3];
            for (int k = 0; k < 4; ++k)
                for (int j = 0; j < yh; ++j)
                    for (int t = 0; t < yw; ++t) {
                        dx[i*old_block + (2*j + k/2)*xw + 2*t + k%2] =
                            y[i*old_block + k*N + j*yw + t];
                    }
            """)
        self.bwd_kernel.push([y], [dx], (y.shape[0] * y.shape[1] // 4, 1, 1), (1, 1, 1))
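The flat indexing in the forward kernel is easy to get wrong, so a host-side reference is handy for checking it. The function below is an addition of mine, not part of the original post: a NumPy sketch of the same rearrangement, against which the kernel's output can be compared on a small random input.

def reorg_reference(x, stride=2):
    """NumPy reference for the reorg forward pass.

    x: array of shape (B, C, H, W) with H and W divisible by stride.
    Returns shape (B, C*stride**2, H//stride, W//stride), where output channel
    c*stride**2 + dy*stride + dx holds x[:, c, dy::stride, dx::stride],
    matching the channel ordering of the CUDA kernel above for stride=2.
    """
    B, C, H, W = x.shape
    s = stride
    out = np.empty((B, C * s * s, H // s, W // s), dtype=x.dtype)
    for m in range(s * s):
        dy, dx = divmod(m, s)
        # phase m of every channel group <- the (dy, dx) spatial sub-grid
        out[:, m::s * s, :, :] = x[:, :, dy::s, dx::s]
    return out

# e.g. compare a GPU result named reorg_output against the reference:
# np.testing.assert_allclose(reorg_output.asnumpy(), reorg_reference(x_np))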
mnist = mx.test_utils.get_mnist()
batch_size = 100
train_iter = mx.io.NDArrayIter(mnist['train_data'], mnist['train_label'], batch_size, shuffle=True)
val_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'], batch_size)

data = mx.sym.var('data')
conv1 = mx.sym.Convolution(data=data, kernel=(5,5), num_filter=20)
tanh1 = mx.sym.Activation(data=conv1, act_type="tanh")            # 20x24x24
# pool1 = mx.sym.Pooling(data=tanh1, pool_type="max", kernel=(2,2), stride=(2,2))

reorg = NDArrayReorg(stride=2)
reg = reorg(data=tanh1, name='reorg')                             # 80x12x12
conv2 = mx.sym.Convolution(data=reg, kernel=(5,5), num_filter=20)
tanh2 = mx.sym.Activation(data=conv2, act_type="tanh")            # 20x8x8

conv2 = mx.sym.Convolution(data=tanh2, kernel=(5,5), num_filter=50)
tanh2 = mx.sym.Activation(data=conv2, act_type="tanh")            # 50x4x4
# pool2 = mx.sym.Pooling(data=tanh2, pool_type="max", kernel=(2,2), stride=(2,2))

flatten = mx.sym.flatten(data=tanh2)
fc1 = mx.sym.FullyConnected(data=flatten, num_hidden=500)
tanh3 = mx.sym.Activation(data=fc1, act_type="tanh")

fc2 = mx.sym.FullyConnected(data=tanh3, num_hidden=10)

mynet = mx.sym.SoftmaxOutput(data=fc2, name='softmax')

print(mynet.infer_shape(data=(100, 1, 28, 28)))
mynet_model = mx.mod.Module(symbol=mynet, context=mx.gpu())

mynet_model.fit(train_iter,
                eval_data=val_iter,
                optimizer='sgd',
                optimizer_params={'learning_rate': 0.1},
                eval_metric='acc',
                batch_end_callback=mx.callback.Speedometer(100, 100),
                num_epoch=10)

test_iter = mx.io.NDArrayIter(mnist['test_data'], None, batch_size)
prob = mynet_model.predict(test_iter)
test_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'], batch_size)
# accuracy on the test set
acc = mx.metric.Accuracy()
mynet_model.score(test_iter, acc)
print(acc)
# The network and its hyperparameters were chosen arbitrarily, so the accuracy
# printed here has no reference value; the point is only to verify that CUDA
# kernels built with mx.rtc work inside a custom operator.
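A caveat for anyone running this today: both mx.operator.NDArrayOp and this form of mx.rtc come from early (pre-1.0) MXNet and were later deprecated. On MXNet 1.x the usual route for a Python-defined layer is mx.operator.CustomOp plus a registered CustomOpProp. Below is a rough sketch of mine (not from the original post) of the same reorg layer in that style, using NDArray reshape/transpose instead of a hand-written kernel.

class ReorgOp(mx.operator.CustomOp):
    def __init__(self, stride):
        super(ReorgOp, self).__init__()
        self.stride = stride

    def forward(self, is_train, req, in_data, out_data, aux):
        s = self.stride
        b, c, h, w = in_data[0].shape
        # (B,C,H,W) -> (B,C,H/s,s,W/s,s) -> (B,C,s,s,H/s,W/s) -> (B,C*s*s,H/s,W/s)
        y = mx.nd.transpose(in_data[0].reshape((b, c, h // s, s, w // s, s)),
                            axes=(0, 1, 3, 5, 2, 4)).reshape((b, c * s * s, h // s, w // s))
        self.assign(out_data[0], req[0], y)

    def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
        s = self.stride
        b, c, h, w = in_data[0].shape
        # invert the forward rearrangement
        g = mx.nd.transpose(out_grad[0].reshape((b, c, s, s, h // s, w // s)),
                            axes=(0, 1, 4, 2, 5, 3)).reshape((b, c, h, w))
        self.assign(in_grad[0], req[0], g)

@mx.operator.register("reorg")
class ReorgProp(mx.operator.CustomOpProp):
    def __init__(self, stride=2):
        super(ReorgProp, self).__init__(need_top_grad=True)
        self.stride = int(stride)   # Custom op kwargs arrive as strings

    def list_arguments(self):
        return ['data']

    def list_outputs(self):
        return ['output']

    def infer_shape(self, in_shape):
        b, c, h, w = in_shape[0]
        s = self.stride
        return [in_shape[0]], [[b, c * s * s, h // s, w // s]], []

    def create_operator(self, ctx, shapes, dtypes):
        return ReorgOp(self.stride)

# usage in the symbol graph, replacing the NDArrayReorg call above:
# reg = mx.sym.Custom(data=tanh1, stride=2, op_type='reorg', name='reorg')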