程式人生 > keras SegNet 使用池化索引(pooling indices)

keras SegNet使用池化索引(pooling indices)

Keras 中無法直接使用池化索引。最近在學習 SegNet(網上許多實作是錯的,並沒有使用池化索引),而 SegNet 的下採樣與上採樣正需要用到池化索引,因此這裡以自訂層來實現。
(原文此處為一張圖片)
完整測試程式碼如下。

"""
@author: LiShiHang
@software: PyCharm
@file: utils.py
@time: 2018/12/18 14:58
"""
from keras.engine import Layer
import keras.backend as K


class MaxPoolingWithArgmax2D(Layer):
    """2-D max pooling that also returns the pooling indices (argmax).

    SegNet-style decoders need to know where each maximum came from so that
    ``MaxUnpooling2D`` can write the values back to those positions. Only the
    TensorFlow backend is supported because the layer relies on
    ``tf.nn.max_pool_with_argmax``; the window/stride vectors built in
    ``call`` pool over axes 1 and 2, i.e. inputs are 4-D channels-last.

    The layer returns two tensors of identical shape, ``[pooled, argmax]``.
    ``argmax`` holds the flat indices from ``tf.nn.max_pool_with_argmax``,
    cast to ``K.floatx()``.

    # Arguments
        pool_size: (rows, cols) size of the pooling window.
        strides: (rows, cols) step of the pooling window.
        padding: 'same' or 'valid' (case-insensitive).

    NOTE(review): casting integer indices to a float dtype is exact only
    while they stay below the mantissa limit (2**24 for float32); very large
    feature maps may lose index precision.
    """

    def __init__(
            self,
            pool_size=(2, 2),
            strides=(2, 2),
            padding='same',
            **kwargs):
        super(MaxPoolingWithArgmax2D, self).__init__(**kwargs)
        self.padding = padding
        self.pool_size = pool_size
        self.strides = strides

    def call(self, inputs, **kwargs):
        """Return ``[pooled, argmax]`` for a 4-D channels-last ``inputs``."""
        padding = self.padding
        pool_size = self.pool_size
        strides = self.strides
        if K.backend() == 'tensorflow':
            # TF wants NHWC-shaped window/stride vectors and 'SAME'/'VALID'.
            ksize = [1, pool_size[0], pool_size[1], 1]
            padding = padding.upper()
            strides = [1, strides[0], strides[1], 1]
            output, argmax = K.tf.nn.max_pool_with_argmax(
                inputs, ksize=ksize, strides=strides, padding=padding)
        else:
            errmsg = '{} backend is not supported for layer {}'.format(
                K.backend(), type(self).__name__)
            raise NotImplementedError(errmsg)
        # Indices travel as floats; MaxUnpooling2D casts them back to int32.
        argmax = K.cast(argmax, K.floatx())
        return [output, argmax]

    @staticmethod
    def _pooled_length(length, pool, stride, padding):
        """Return the pooled size of one spatial dimension (None stays None)."""
        if length is None:
            return None
        padding = padding.lower()
        if padding == 'same':
            # ceil(length / stride)
            return (length + stride - 1) // stride
        if padding == 'valid':
            # ceil((length - pool + 1) / stride)
            return (length - pool + stride) // stride
        raise ValueError('Unsupported padding: {}'.format(padding))

    def compute_output_shape(self, input_shape):
        """Both outputs share the pooled shape (batch, rows', cols', channels).

        The spatial sizes follow the configured pool size, strides and
        padding, matching what ``tf.nn.max_pool_with_argmax`` produces.
        """
        rows = self._pooled_length(
            input_shape[1], self.pool_size[0], self.strides[0], self.padding)
        cols = self._pooled_length(
            input_shape[2], self.pool_size[1], self.strides[1], self.padding)
        output_shape = (input_shape[0], rows, cols, input_shape[3])
        return [output_shape, output_shape]

    def compute_mask(self, inputs, mask=None):
        # One (absent) mask per output tensor.
        return [None, None]

    def get_config(self):
        """Serialize constructor arguments so the layer can be reloaded."""
        config = super(MaxPoolingWithArgmax2D, self).get_config()
        config.update({
            'pool_size': self.pool_size,
            'strides': self.strides,
            'padding': self.padding,
        })
        return config


class MaxUnpooling2D(Layer):
    """Scatter pooled values back to the positions recorded by argmax.

    Takes ``[updates, mask]``: the pooled values and the argmax tensor from
    ``MaxPoolingWithArgmax2D``. It returns a tensor of the unpooled size
    (built with ``tf.scatter_nd``, so positions that receive no value are
    zero) with every value written at its recorded position.

    # Arguments
        up_size: (rows, cols) upsampling factor; presumably should equal the
            strides of the matching pooling layer.

    NOTE(review): each flat index is decoded as ``(y * W + x) * C + c``,
    i.e. without a batch term; this assumes the TF version in use returns
    argmax with ``include_batch_in_index=False`` semantics — confirm.
    The default output shape is ``input * up_size``, which is only right
    when the pre-pooling height/width were exact multiples of ``up_size``;
    otherwise pass ``output_shape`` to ``call`` explicitly.
    """

    def __init__(self, up_size=(2, 2), **kwargs):
        super(MaxUnpooling2D, self).__init__(**kwargs)
        self.up_size = up_size

    def call(self, inputs, output_shape=None):
        """Unpool ``inputs = [updates, mask]`` into ``output_shape``."""
        updates, mask = inputs[0], inputs[1]
        with K.tf.variable_scope(self.name):
            mask = K.cast(mask, 'int32')
            input_shape = K.tf.shape(updates, out_type='int32')
            # Default target: scale height and width by up_size.
            if output_shape is None:
                output_shape = (
                    input_shape[0],
                    input_shape[1] * self.up_size[0],
                    input_shape[2] * self.up_size[1],
                    input_shape[3])

            # Build (batch, row, col, channel) coordinates for every element.
            one_like_mask = K.ones_like(mask, dtype='int32')
            batch_shape = K.concatenate(
                [[input_shape[0]], [1], [1], [1]], axis=0)
            batch_range = K.reshape(
                K.tf.range(output_shape[0], dtype='int32'),
                shape=batch_shape)
            b = one_like_mask * batch_range
            y = mask // (output_shape[2] * output_shape[3])
            x = (mask // output_shape[3]) % output_shape[2]
            feature_range = K.tf.range(output_shape[3], dtype='int32')
            f = one_like_mask * feature_range

            # Flatten into an (N, 4) index list plus N values and scatter.
            updates_size = K.tf.size(updates)
            indices = K.transpose(K.reshape(
                K.stack([b, y, x, f]), [4, updates_size]))
            values = K.reshape(updates, [updates_size])
            return K.tf.scatter_nd(indices, values, output_shape)

    def compute_output_shape(self, input_shape):
        """Scale the mask's spatial dims by ``up_size``; unknown dims stay None."""
        mask_shape = input_shape[1]
        rows = mask_shape[1]
        cols = mask_shape[2]
        return (
            mask_shape[0],
            None if rows is None else rows * self.up_size[0],
            None if cols is None else cols * self.up_size[1],
            mask_shape[3],
        )

    def get_config(self):
        """Serialize constructor arguments so the layer can be reloaded."""
        config = super(MaxUnpooling2D, self).get_config()
        config.update({'up_size': self.up_size})
        return config


if __name__ == '__main__':
    import keras
    import numpy as np

    # Round trip: pool with indices, then unpool. To inspect the pooling
    # layer alone, build the model with the pooling layer's output instead.
    inputs = keras.layers.Input((4, 4, 3))
    pooled = MaxPoolingWithArgmax2D()(inputs)
    unpooled = MaxUnpooling2D()(pooled)
    model = keras.Model(inputs=inputs, outputs=unpooled)
    model.compile(optimizer="adam", loss='categorical_crossentropy')
    x = np.random.randint(0, 100, (3, 4, 4, 3))  # debug here
    m = model.predict(x)  # debug here
    print(m)

有興趣的讀者可以在上述程式碼中標有註解的位置設定中斷點,逐步除錯並觀察各層的輸出。
(原文此處為一張圖片)