mxnet——模型載入與儲存
阿新 • 發佈:2018-12-10
一、載入模型:網路結構(network)與 pretrain 模型相同
# Load a pretrained checkpoint into a Module with the same network and keep
# training it (fine-tuning).
# NOTE: assumes `import mxnet as mx`, and that get_iterators,
# Multi_mnist_iterator, MAE_zz, batch_size, device, lr, num_epochs and the
# `checkpoint` callback (built in section 3) are defined elsewhere.
import os
from collections import namedtuple

data_shape_G = 96
# Record type for feeding a batch to Module.forward at prediction time
# (not used within this snippet).
Batch = namedtuple('Batch', ['data'])

# The prefix is handed to MXNet's file I/O unchanged, which does not expand
# "~"; expand it here so the checkpoint is looked up in the home directory
# rather than in a literal "~" directory under the working directory.
sym, arg_params, aux_params = mx.model.load_checkpoint(
    prefix=os.path.expanduser("~/meh_cla"), epoch=2)

# Reuse data_shape_G instead of repeating the literal image size.
train, val = get_iterators(batch_size=batch_size,
                           data_shape=(3, data_shape_G, data_shape_G))
# Wrap the iterators so batches carry the three labels the multi-output
# network expects (presumably -- confirm against Multi_mnist_iterator).
train = Multi_mnist_iterator(train)
val = Multi_mnist_iterator(val)

# Build a Module around the pretrained symbol.
model = mx.mod.Module(
    symbol=sym,
    context=device,
    data_names=['data'],
    # Must match the label variable names used by the loaded symbol.
    label_names=['softmax1_label', 'softmax2_label', 'softmax3_label'],
)
model.bind(data_shapes=train.provide_data, label_shapes=train.provide_label)
# NOTE(review): with allow_missing=True, Module.set_params copies the given
# params but does not run an initializer for any that are absent -- fine
# here because the network matches the checkpoint; verify if it changes.
model.set_params(arg_params, aux_params, allow_missing=True)

# The module is already bound and its params are initialized, so fit() keeps
# the loaded weights and continues training from them.
model.fit(train, val,
          optimizer_params={'learning_rate': lr, 'momentum': 0.9},
          num_epoch=num_epochs,
          eval_metric=MAE_zz(name="mae"),
          batch_end_callback=mx.callback.Speedometer(batch_size, 2),
          epoch_end_callback=checkpoint)
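二、載入模型:網路結構(network)與 pretrain 模型不同
網路不同時,先用新網路的 symbol 建立 Module 並 bind,再把 load_checkpoint 讀出的參數傳入 model.init_params(initializer=..., arg_params=arg_params, aux_params=aux_params, allow_missing=True, allow_extra=True):名稱相同的層沿用預訓練權重,新增的層由 initializer 初始化,預訓練模型中多出的參數則被忽略。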
三、模型的儲存
# Save the parameters once after every epoch by using a checkpoint callback.
# (Assumes `mx`, `net` and `train_iter` are defined elsewhere.)
model_prefix = 'mx_mlp'
mod = mx.mod.Module(symbol=net)

# do_checkpoint builds an epoch-end callback; period=1 (the default) means
# a checkpoint is written at the end of every epoch.
checkpoint = mx.callback.do_checkpoint(model_prefix, period=1)
mod.fit(
    train_iter,
    num_epoch=5,
    epoch_end_callback=checkpoint,
)
# First, a look at the relevant part of the fit() code.
# Excerpt from MXNet's BaseModule.fit, run at the end of each epoch; `self`,
# `epoch`, `epoch_end_callback` and `_as_list` come from the surrounding
# MXNet method/module, so this is illustrative rather than standalone.
# sync aux params across devices
arg_params, aux_params = self.get_params()
self.set_params(arg_params, aux_params)
if epoch_end_callback is not None:
    # _as_list presumably accepts either one callback or a list of them.
    for callback in _as_list(epoch_end_callback):
        # Each callback (e.g. the one from do_checkpoint) receives the epoch
        # index, the symbol and the current parameters.
        callback(epoch, self.symbol, arg_params, aux_params)
參考博文
https://blog.csdn.net/u012436149/article/details/78174260?utm_source=blogxgwz7