gluon訓練出的模型轉成mx.mode.Module可用的symbol
阿新 • • 發佈:2018-12-13
6. 儲存成Symbol格式的網路和引數(重點)
要注意儲存網路引數的時候,需要net.collect_params().save()
這樣儲存,而不是net.save_params()
儲存
最新版的mxnet已經有可以匯出到symbol格式下的介面了。需要mxnet版本在20171015以上
下面示例程式碼也已經改成新版的儲存,載入方式
#新版本的儲存方式
net.export('Gluon_FashionMNIST')
7. 使用Symbol載入網路並繫結
symnet = mx.symbol.load('Gluon_FashionMNIST-symbol.json')
mod = mx.mod.Module(symbol=symnet, context=mx.cpu())
mod.bind(data_shapes=[('data' , (1, 1, 28, 28))])
mod.load_params('Gluon_FashionMNIST-0000.params')
Batch = namedtuple('Batch', ['data'])
8. 預測試試看效果
img,label = fashion_test[random.randint(0, 60000)]
data = img.transpose([2,0,1])
data = data.reshape([1,1,28,28])
mod.forward(Batch([data]))
out = mod.get_outputs()
prob = out[0]
predicted_labels = prob.argmax(axis=1 )
plt.imshow(img.reshape((28, 28)).asnumpy())
plt.axis('off')
plt.show()
print('predicted labels:',get_text_labels(predicted_labels.asnumpy()))
print('true labels:',get_text_labels([label]))