1. 程式人生 > >【MXNet Gluon】使用預訓練好的模型fine-tune

【MXNet Gluon】使用預訓練好的模型fine-tune

finetune關鍵程式碼

# Build two networks: `prenet` with the pre-trained model's class count (466)
# and `net` with the new dataset's class count (3400).
prenet = ResNet(466)
net = ResNet(3400)
# Spread parameters across 3 GPUs.
ctx = [mx.gpu(i) for i in range(3)]
if finetune == 1:
    # Load the pre-trained weights into prenet.
    prenet.load_params('params/net-%d.params' % (start_iter), ctx)
    # `features` holds the model parameters we want to keep;
    # `output` is the Dense layer replaced to match the new dataset's
    # number of classes, so only it needs a fresh initialization.
    net.features = prenet.features
    net.output.initialize(mx.init.Xavier(), ctx)
else:
    # Training from scratch: initialize every parameter.
    net.collect_params().initialize(mx.init.Xavier(magnitude=2.24), ctx)

ResNet程式碼參考
class ResNetV2(HybridBlock):
    """ResNet V2 (pre-activation) network.

    Parameters
    ----------
    block : type
        Residual block class used to build each stage.
    layers : list of int
        Number of residual blocks in each stage.
    channels : list of int
        Channel widths; ``channels[0]`` is the stem width and
        ``channels[i+1]`` the width of stage ``i``.
        Must satisfy ``len(layers) == len(channels) - 1``.
    classes : int, default 1000
        Number of output classes of the final Dense layer.
    thumbnail : bool, default False
        If True, use a single 3x3 conv stem (small inputs, e.g. CIFAR)
        instead of the 7x7 conv + max-pool ImageNet stem.
    """

    def __init__(self, block, layers, channels, classes=1000, thumbnail=False, **kwargs):
        super(ResNetV2, self).__init__(**kwargs)
        assert len(layers) == len(channels) - 1
        with self.name_scope():
            self.features = nn.HybridSequential(prefix='')
            # Input normalization only: no learned scale/shift.
            self.features.add(nn.BatchNorm(scale=False, center=False))
            if thumbnail:
                self.features.add(_conv3x3(channels[0], 1, 0))
            else:
                self.features.add(nn.Conv2D(channels[0], 7, 2, 3, use_bias=False))
                self.features.add(nn.BatchNorm())
                self.features.add(nn.Activation('relu'))
                self.features.add(nn.MaxPool2D(3, 2, 1))

            in_channels = channels[0]
            for i, num_layer in enumerate(layers):
                # First stage keeps resolution; later stages downsample by 2.
                stride = 1 if i == 0 else 2
                self.features.add(self._make_layer(block, num_layer, channels[i+1],
                                                   stride, i+1, in_channels=in_channels))
                in_channels = channels[i+1]
            # Final pre-activation BN + ReLU before global pooling (V2 style).
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))
            self.features.add(nn.GlobalAvgPool2D())
            self.features.add(nn.Flatten())

            # Classifier head; replaced/re-initialized when fine-tuning.
            self.output = nn.Dense(classes, in_units=in_channels)

    def _make_layer(self, block, layers, channels, stride, stage_index, in_channels=0):
        """Build one stage of `layers` residual blocks.

        Only the first block may change stride/width; the rest keep
        stride 1 and the stage's channel count.
        """
        layer = nn.HybridSequential(prefix='stage%d_' % stage_index)
        with layer.name_scope():
            layer.add(block(channels, stride, channels != in_channels, in_channels=in_channels,
                            prefix=''))
            for _ in range(layers-1):
                layer.add(block(channels, 1, False, in_channels=channels, prefix=''))
        return layer

    def hybrid_forward(self, F, x):
        """Forward pass: feature extractor followed by the classifier head."""
        x = self.features(x)
        x = self.output(x)
        return x