程式人生 > mxnet簡單網路示例——深度學習

mxnet簡單網路示例——深度學習

# Network definition.
# `mx` was used below without ever being imported, which raises NameError
# on the first mx.* call; import it explicitly.
import mxnet as mx

# Multilayer perceptron: three fully connected hidden layers with ReLU
# activations (128 -> 256 -> 128 units), followed by a 2-unit output layer.
data = mx.sym.Variable('data')
fc1 = mx.sym.FullyConnected(data, name='fc1', num_hidden=128)
act1 = mx.sym.Activation(fc1, name='relu1', act_type='relu')
fc2 = mx.sym.FullyConnected(act1, name='fc2', num_hidden=256)
act2 = mx.sym.Activation(fc2, name='relu2', act_type='relu')
fc3 = mx.sym.FullyConnected(act2, name='fc3', num_hidden=128)
act3 = mx.sym.Activation(fc3, name='relu3', act_type='relu')
fc4 = mx.sym.FullyConnected(act3, name='fc4', num_hidden=2)
# Softmax output/loss over the 2 output units. MXNet names this layer's
# label input after the layer ('softmax' -> 'softmax_label').
out = mx.sym.SoftmaxOutput(fc4, name='softmax')
# Wrap the symbol in a Module; binding to concrete input shapes and
# parameter initialisation happen later, during training.
mod = mx.mod.Module(out)
# Train / validation split: the last 10000 entries are held out for
# validation and everything before them stays in the training set.
# Each right-hand side slices the *original* x_train / y_train before the
# names are rebound, so both halves come from the same source data.
# NOTE(review): with 10000 or fewer rows, the training slice is empty and
# everything lands in the validation set, with no error raised.
split = -10000
x_train, x_valid = x_train[:split], x_train[split:]
y_train, y_valid = y_train[:split], y_train[split:]
# Training and validation.
# One batch size shared by both iterators and the Speedometer callback.
# The original reported throughput with batch_size=512 while iterating in
# batches of 128, which skewed the samples/sec figure.
batch_size = 128

# DataFrame.as_matrix() was removed in pandas 1.0. `.values` returns the
# same underlying NumPy array and is available in every pandas version.
x_arr = mx.io.NDArrayIter(x_train.values, label=y_train.values,
                          batch_size=batch_size)
# The original script scored on an undefined name `valid` (NameError).
# Build it from the held-out validation split instead.
valid = mx.io.NDArrayIter(x_valid.values, label=y_valid.values,
                          batch_size=batch_size)

# Module.fit() binds the module to the iterator's data/label shapes and
# initialises the parameters itself (its default initializer matches
# init_params()'s default). The separate bind()/init_params() calls that
# used to precede it were therefore redundant.
mod.fit(x_arr, num_epoch=10)

# NOTE(review): 'mse' compares integer class labels against the 2-column
# softmax probabilities, so it is probably not a meaningful metric here.
# Consider replacing it with 'ce' (cross-entropy).
valid_scores = mod.score(
    valid,
    ['acc', 'mse'],
    batch_end_callback=mx.callback.Speedometer(batch_size=batch_size,
                                               frequent=10),
)
# score() returns a list of (metric_name, value) pairs.
print(valid_scores)