影象上取樣
阿新 • • 發佈:2018-11-25
記錄常使用的函式避免遺忘
def upsample(x, scale=2, features=64, activation=tf.nn.relu):
    """Upsample a feature map by ``scale`` using conv + pixel-shuffle stages.

    A 3x3 convolution to ``features`` channels is applied first.  Each
    upsampling stage is then a 3x3 convolution producing ``3 * r**2``
    channels followed by ``PS(..., r, color=True)``, which rearranges those
    channels into a 3-channel map that is ``r`` times larger spatially.

    Args:
        x: Input tensor; assumed to be NHWC as expected by ``slim.conv2d``
            and ``PS`` -- TODO confirm against callers.
        scale: Upsampling factor; must be 2, 3 or 4.
        features: Number of output channels of the first convolution.
        activation: Activation function passed to every ``slim.conv2d``.

    Returns:
        The upsampled tensor produced by the last ``PS`` call.
    """
    assert scale in [2, 3, 4]
    x = slim.conv2d(x, features, [3, 3], activation_fn=activation)

    # x4 is performed as two consecutive x2 stages; x2 and x3 need one stage.
    factor, stages = (2, 2) if scale == 4 else (scale, 1)

    # Number of conv filters per stage: r*r sub-pixel phases for each of the
    # 3 planes that PS(color=True) splits out (12 for r=2, 27 for r=3).
    ps_features = 3 * (factor ** 2)
    for _ in range(stages):
        x = slim.conv2d(x, ps_features, [3, 3], activation_fn=activation)
        x = PS(x, factor, color=True)
    return x


def PS(X, r, color=False):
    """Apply the phase shift (pixel shuffle) with ratio ``r`` to ``X``.

    Args:
        X: 4-D tensor.  With ``color=True`` its last axis is split into 3
            equal groups; each group (or the whole tensor when
            ``color=False``) must have ``r * r`` channels.
        r: Upsampling ratio.
        color: If true, shuffle 3 channel groups independently and
            concatenate the 3 single-channel results on the last axis.

    Returns:
        Tensor whose two middle axes are ``r`` times larger, with 3 channels
        when ``color`` is true and 1 channel otherwise.
    """
    if color:
        # Split axis 3 into 3 equal parts, e.g. 10x50x50x12 -> 3 x 10x50x50x4.
        Xc = tf.split(X, 3, 3)
        # Shuffle each part on its own, then stack the results back on axis 3.
        X = tf.concat([_phase_shift(x, r) for x in Xc], 3)
    else:
        X = _phase_shift(X, r)
    return X


def _phase_shift(I, r):
    """Rearrange a ``[bsize, a, b, r*r]`` tensor into ``[bsize, a*r, b*r, 1]``.

    ``a`` and ``b`` are read from the static shape and used as split counts,
    so they presumably have to be known at graph-construction time; only the
    batch dimension is taken dynamically.
    """
    bsize, a, b, c = I.get_shape().as_list()  # e.g. bsize=10, a=50, b=50, c=4
    bsize = tf.shape(I)[0]  # dynamic batch size: handles Dimension(None)
    X = tf.reshape(I, (bsize, a, b, r, r))  # requires c == r * r
    X = tf.transpose(X, (0, 1, 2, 4, 3))  # swap the two r axes: bsize, a, b, r, r
    X = tf.split(X, a, 1)  # a * [bsize, 1, b, r, r]
    # tf.squeeze(axis=1) drops the size-1 axis left by the split; the named
    # axis must really have size 1, otherwise TensorFlow raises an error.
    X = tf.concat([tf.squeeze(x, axis=1) for x in X], 2)  # bsize, b, a*r, r
    X = tf.split(X, b, 1)  # b * [bsize, 1, a*r, r]
    X = tf.concat([tf.squeeze(x, axis=1) for x in X], 2)  # bsize, a*r, b*r
    return tf.reshape(X, (bsize, a * r, b * r, 1))


def my_anti_shuffle(input_image, ratio):
    """Fold ``ratio x ratio`` spatial neighbourhoods of an image into channels.

    Args:
        input_image: Array of shape ``(height, width, channels)``.
        ratio: Downsampling ratio; must divide both height and width.

    Returns:
        A ``uint8`` array of shape
        ``(height // ratio, width // ratio, channels * ratio * ratio)``, or
        ``None`` (after printing an error message) when height or width is
        not a multiple of ``ratio``.
    """
    shape = input_image.shape
    ori_height = int(shape[0])
    ori_width = int(shape[1])
    ori_channels = int(shape[2])
    if ori_height % ratio != 0 or ori_width % ratio != 0:
        print("Error! Height and width must be divided by ratio!")
        return

    height = ori_height // ratio
    width = ori_width // ratio
    channels = ori_channels * ratio * ratio
    anti_shuffle = np.zeros((height, width, channels), dtype=np.uint8)
    for c in range(ori_channels):
        for dy in range(ratio):
            for dx in range(ratio):
                # Take every ratio-th pixel starting at offset (dy, dx).
                anti_shuffle[:, :, c * ratio * ratio + dy * ratio + dx] = \
                    input_image[dy::ratio, dx::ratio, c]
    return anti_shuffle


def shuffle(input_image, ratio):
    """Unfold channels into ``ratio x ratio`` spatial blocks.

    This is the inverse layout of ``my_anti_shuffle``: output pixel
    ``(i, j, k)`` is read from input pixel ``(i // ratio, j // ratio)`` at
    channel ``k * ratio * ratio + (i % ratio) * ratio + (j % ratio)``.

    Args:
        input_image: Array of shape ``(height, width, channels)``.
        ratio: Upsampling ratio.

    Returns:
        A ``uint8`` array of shape
        ``(height * ratio, width * ratio, channels // (ratio * ratio))``.
        Input channels beyond the last full group of ``ratio * ratio`` are
        ignored.
    """
    shape = input_image.shape
    height = int(shape[0]) * ratio
    width = int(shape[1]) * ratio
    # Floor division keeps the channel count an int; true division yields a
    # float on Python 3, which np.zeros and range() reject.
    channels = int(shape[2]) // ratio // ratio
    shuffled = np.zeros((height, width, channels), dtype=np.uint8)
    for k in range(channels):
        for dy in range(ratio):
            for dx in range(ratio):
                # Scatter one input channel onto every ratio-th output pixel
                # starting at offset (dy, dx).
                shuffled[dy::ratio, dx::ratio, k] = \
                    input_image[:, :, k * ratio * ratio + dy * ratio + dx]
    return shuffled