
MXNet: Loading and Saving Models

1. Loading a model when the network is the same as the pretrained model

# load a saved checkpoint for prediction / fine-tuning
import mxnet as mx
from collections import namedtuple

data_shape_G = 96
Batch = namedtuple('Batch', ['data'])
sym, arg_params, aux_params = mx.model.load_checkpoint(prefix=r"~/meh_cla", epoch=2)

train, val = get_iterators(batch_size=batch_size, data_shape=(3, 96, 96))
train = Multi_mnist_iterator(train)
val = Multi_mnist_iterator(val)

model = mx.mod.Module(          # load the pre-trained model
    symbol=sym,                 # network structure
    context=device,
    data_names=['data'],
    label_names=['softmax1_label', 'softmax2_label', 'softmax3_label']
)
model.bind(data_shapes=train.provide_data, label_shapes=train.provide_label)
model.set_params(arg_params, aux_params, allow_missing=True)
model.fit(train, val,
          optimizer_params={'learning_rate': lr, 'momentum': 0.9},
          num_epoch=num_epochs,
          eval_metric=MAE_zz(name="mae"),
          batch_end_callback=mx.callback.Speedometer(batch_size, 2),
          epoch_end_callback=checkpoint)

2. Loading a model when the network differs from the pretrained model
3. Saving a model

# Use the checkpoint callback to save the parameters once after every epoch.
# construct a callback function to save checkpoints
model_prefix = 'mx_mlp'
checkpoint = mx.callback.do_checkpoint(model_prefix)
mod = mx.mod.Module(symbol=net)
mod.fit(train_iter, num_epoch=5, epoch_end_callback=checkpoint)

# First, a look at the relevant part of the fit() source
# sync aux params across devices
arg_params, aux_params = self.get_params()
self.set_params(arg_params, aux_params)
if epoch_end_callback is not None:
    for callback in _as_list(epoch_end_callback):
        callback(epoch, self.symbol, arg_params, aux_params)

Reference

https://blog.csdn.net/u012436149/article/details/78174260?utm_source=blogxgwz7