mxnet簡單網路示例——深度學習
阿新 • 發佈:2018-11-07
# Simple MXNet multilayer perceptron trained as a 2-class softmax classifier.
#
# Assumes `mx` (mxnet) and `x_train` / `y_train` are defined earlier in the
# file. NOTE(review): the original called `.as_matrix()` on them, so they are
# presumably pandas DataFrame/Series objects -- confirm against the loader.

# One batch size shared by the data iterators and the Speedometer callback,
# so the reported samples/sec matches the real batch size.
BATCH_SIZE = 128

# --- network: 128 -> 256 -> 128 hidden ReLU units, 2-way softmax output ---
data = mx.sym.Variable('data')
fc1 = mx.sym.FullyConnected(data, name='fc1', num_hidden=128)
act1 = mx.sym.Activation(fc1, name='relu1', act_type='relu')
fc2 = mx.sym.FullyConnected(act1, name='fc2', num_hidden=256)
act2 = mx.sym.Activation(fc2, name='relu2', act_type='relu')
fc3 = mx.sym.FullyConnected(act2, name='fc3', num_hidden=128)
act3 = mx.sym.Activation(fc3, name='relu3', act_type='relu')
fc4 = mx.sym.FullyConnected(act3, name='fc4', num_hidden=2)
# Named 'softmax', so the symbol expects a label input called 'softmax_label',
# which is the default label name NDArrayIter provides.
out = mx.sym.SoftmaxOutput(fc4, name='softmax')
mod = mx.mod.Module(out)

# --- train / validation split: hold out the last 10000 rows ---
split = -10000
if len(x_train) <= -split:
    # Without this check the training slice below would silently be empty.
    raise ValueError(
        'need more than %d rows to hold out a validation set, got %d'
        % (-split, len(x_train))
    )
x_valid = x_train[split:]
y_valid = y_train[split:]
x_train = x_train[:split]
y_train = y_train[:split]

# --- train ---
# `.values` replaces `.as_matrix()`, which was removed in pandas 1.0.
x_arr = mx.io.NDArrayIter(x_train.values, batch_size=BATCH_SIZE,
                          label=y_train.values)
# fit() binds the module and initialises its parameters itself, so separate
# bind()/init_params() calls beforehand are unnecessary.
mod.fit(x_arr, num_epoch=10)

# --- evaluate on the held-out rows ---
valid = mx.io.NDArrayIter(x_valid.values, batch_size=BATCH_SIZE,
                          label=y_valid.values)
# NOTE(review): 'mse' compares the class label against the 2-column softmax
# output; confirm this metric is meaningful here ('acc' is the primary one).
for metric_name, metric_value in mod.score(
        valid, ['acc', 'mse'],
        batch_end_callback=mx.callback.Speedometer(batch_size=BATCH_SIZE,
                                                   frequent=10)):
    print('%s: %f' % (metric_name, metric_value))