1. 程式人生 > >gluon訓練出的模型轉成mx.mode.Module可用的symbol

gluon訓練出的模型轉成mx.mode.Module可用的symbol

6. 儲存成Symbol格式的網路和引數(重點)

要注意儲存網路引數的時候,需要net.collect_params().save()這樣儲存,而不是net.save_params()儲存
最新版的mxnet已經有可以匯出到symbol格式下的介面了。需要mxnet版本在20171015以上
下面示例程式碼也已經改成新版的儲存,載入方式

#新版本的儲存方式
net.export('Gluon_FashionMNIST')

7. 使用Symbol載入網路並繫結

symnet = mx.symbol.load('Gluon_FashionMNIST-symbol.json')
mod = mx.mod.Module(symbol=symnet, context=mx.cpu())
mod.bind(data_shapes=[('data'
, (1, 1, 28, 28))]) mod.load_params('Gluon_FashionMNIST-0000.params') Batch = namedtuple('Batch', ['data'])

8. 預測試試看效果

img,label = fashion_test[random.randint(0, 60000)]
data = img.transpose([2,0,1])
data = data.reshape([1,1,28,28])
mod.forward(Batch([data]))
out = mod.get_outputs()
prob = out[0]
predicted_labels = prob.argmax(axis=1
) plt.imshow(img.reshape((28, 28)).asnumpy()) plt.axis('off') plt.show() print('predicted labels:',get_text_labels(predicted_labels.asnumpy())) print('true labels:',get_text_labels([label]))