用TensorFlow實現2D小波變換(DWT)和小波逆變換(IDWT),梯度可以反向傳播
阿新 • • 發佈:2019-01-05
使用TensorFlow實現小波變換和小波逆變換,並且梯度可以反向傳播,因此可以方便地將小波變換嵌入到網路結構中。
本程式碼參考PyTorch實現的小波變換移植至TensorFlow。PyTorch實現連結:https://github.com/fbcotter/pytorch_wavelets 。
實現中存在的一個問題是TensorFlow不能直接實現分組卷積,因此這裡只能對每個通道迴圈呼叫一次2D卷積來實現,這會增加執行時間。關於分組卷積,TensorFlow的issue中有討論,連結:https://github.com/tensorflow/tensorflow/issues/3332 。
但是目前本人在TensorFlow上還沒有找到很好的解決方法:後來雖然用3D卷積實現了分組卷積,但需要過多的tf.reshape、tf.slice和tf.concat操作,所以依然沒有解決問題。希望有更好的分組卷積解決方法的小夥伴們教教我。
下面的程式碼包括二維小波變換、小波逆變換以及測試程式碼。需要注意的是,這裡函式間傳遞的都是4-D tensor(NHWC格式)。另外,必須安裝pywt才能使用。
# -*- coding: utf-8 -*-
# @Author : Chen Meiya
# @time : 2018/12/9 21:46
# @File : tf_dwt_release.py
# @Software : PyCharm
"""Single-level 2-D DWT / IDWT built from TensorFlow 1.x graph-mode ops.

Both transforms are expressed as (transposed) strided convolutions, so
gradients propagate through them and they can be embedded in a network.
All tensors exchanged between the functions are 4-D, channels-last (NHWC).
The pywt package is required: it supplies the wavelet filter coefficients.
"""
import time

import matplotlib.pyplot as plt
import numpy as np
import pywt
import tensorflow as tf
from PIL import Image


def _filter_bank(lo, hi, flip=False):
    """Return the four separable 2-D filters as a float32 array of shape (L, L, 1, 4).

    Channel order is LL, LH, HL, HH, built as outer(lo, lo), outer(hi, lo),
    outer(lo, hi) and outer(hi, hi); the first factor runs along axis 0
    (height) and the second along axis 1 (width).  With ``flip=True`` both
    spatial axes are reversed so that TensorFlow's cross-correlation
    (tf.nn.conv2d / depthwise_conv2d) computes a true convolution.
    """
    lo = np.asarray(lo, dtype=np.float64)
    hi = np.asarray(hi, dtype=np.float64)
    bank = np.stack([np.outer(lo, lo), np.outer(hi, lo),
                     np.outer(lo, hi), np.outer(hi, hi)], axis=-1)
    if flip:
        bank = bank[::-1, ::-1]
    # Insert the singleton "input channel" axis expected by conv filters.
    return np.ascontiguousarray(bank[:, :, np.newaxis, :], dtype=np.float32)


def tf_dwt(yl, in_size, wave='db3'):
    """Single-level 2-D discrete wavelet transform of an NHWC tensor.

    Args:
        yl: 4-D tensor (batch, height, width, C); C must be known statically.
        in_size: (height, width) of ``yl`` as integers; decides whether an
            extra row / column of padding is needed for odd sizes.
        wave: pywt wavelet name.

    Returns:
        4-D tensor with 4*C channels.  For input channel c, output channels
        4*c .. 4*c+3 hold the LL, LH, HL and HH sub-bands.

    Only one decomposition level (J=1) is supported.
    """
    w = pywt.Wavelet(wave)
    dec_bank = _filter_bank(w.dec_lo, w.dec_hi, flip=True)
    # Border extension per side (4 for the 6-tap 'db3' filters).
    sz = 2 * (len(w.dec_lo) // 2 - 1)
    with tf.variable_scope('DWT'):
        # Odd dimensions get one extra padded sample at the far end.  Using
        # the static in_size for BOTH axes fixes the old check, which compared
        # a graph Tensor with `==` (always False in TF1 graph mode) and so
        # left odd-by-odd images with an odd column count.
        # NOTE: 'reflect' excludes the edge sample, whereas pywt's default
        # mode is 'symmetric', so edge coefficients may differ from pywt.dwt2.
        paddings = [[0, 0],
                    [sz, sz + in_size[0] % 2],
                    [sz, sz + in_size[1] % 2],
                    [0, 0]]
        yl = tf.pad(yl, tf.constant(paddings), mode='reflect')
        # Grouped convolution: depthwise_conv2d applies the same 4-filter
        # bank to every input channel and lays the outputs out as c*4 + k,
        # i.e. the same order as concatenating one conv2d per channel.
        n_channels = int(yl.shape.dims[3])
        depthwise_filter = tf.convert_to_tensor(
            np.tile(dec_bank, (1, 1, n_channels, 1)))
        outputs = tf.nn.depthwise_conv2d(yl, depthwise_filter,
                                         strides=[1, 2, 2, 1],
                                         padding='VALID')
    return outputs


def tf_idwt(y, wave='db3'):
    """Single-level 2-D inverse wavelet transform; inverse of ``tf_dwt``.

    Args:
        y: 4-D NHWC tensor of coefficients with 4*C channels, laid out as
            produced by ``tf_dwt`` (LL, LH, HL, HH per original channel).
            The channel count must be known statically.  Height and width
            may differ.
        wave: pywt wavelet name; should match the one used for ``tf_dwt``.

    Returns:
        4-D tensor with C channels: the transposed-convolution output
        cropped by s = 2 * (L // 2 - 1) on every side (L = filter length).

    Raises:
        ValueError: if the channel count of ``y`` is not a positive
            multiple of 4.
    """
    w = pywt.Wavelet(wave)
    rec_bank = _filter_bank(w.rec_lo, w.rec_hi)
    rec_filter = tf.convert_to_tensor(rec_bank)
    filt_len = rec_bank.shape[0]
    s = 2 * (len(w.dec_lo) // 2 - 1)
    n_coeffs = int(y.shape.dims[-1])
    if n_coeffs <= 0 or n_coeffs % 4 != 0:
        raise ValueError(
            'tf_idwt expects a positive multiple of 4 channels, got %d'
            % n_coeffs)
    with tf.variable_scope('IWT'):
        batch = tf.shape(y)[0]
        # Height and width are tracked separately so non-square inputs work.
        full_h = 2 * (tf.shape(y)[1] - 1) + filt_len
        full_w = 2 * (tf.shape(y)[2] - 1) + filt_len
        # A (L, L, 1, 4) transposed conv sums the four sub-bands of one
        # group into a single output channel.
        groups = [
            tf.nn.conv2d_transpose(y[:, :, :, start:start + 4], rec_filter,
                                   output_shape=[batch, full_h, full_w, 1],
                                   padding='VALID', strides=[1, 2, 2, 1])
            for start in range(0, n_coeffs, 4)
        ]
        outputs = tf.concat(groups, axis=3)
        # Drop the border extension that tf_dwt added.
        outputs = outputs[:, s:full_h - s, s:full_w - s, :]
    return outputs


if __name__ == '__main__':
    # Load a test image as RGB float32 in [0, 1] and crop to at most 256x256.
    img = Image.open('22090.jpg').convert('RGB')  # change the image path
    X_n = np.asarray(img, dtype='float32') / 255
    X_n = X_n[0:256, 0:256, :]
    height, width = X_n.shape[:2]
    X_t = X_n[np.newaxis, ...]  # add the batch axis: (1, H, W, 3)

    # Build the graph once; DWT and IDWT are fed through separate placeholders.
    inputs = tf.placeholder(tf.float32, [None, None, None, 3], name='inputs')
    coeffs_in = tf.placeholder(tf.float32, [None, None, None, 12],
                               name='outputs')
    dwt_op = tf_dwt(inputs, in_size=[height, width])
    idwt_op = tf_idwt(coeffs_in)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        time_start = time.time()
        outputs_dwt = sess.run(dwt_op, feed_dict={inputs: X_t})
        outputs_rec = sess.run(idwt_op, feed_dict={coeffs_in: outputs_dwt})
        time_end = time.time()
    print('totally cost', time_end - time_start)

    # The IDWT output may be larger than an odd-sized input; crop before
    # comparing against the original image.
    rec_err = np.abs(outputs_rec[0, :height, :width, :] - X_n).max()
    print('max reconstruction error', rec_err)

    # Show the LL sub-band of the first input channel.
    plt.figure()
    plt.imshow(outputs_dwt[0, :, :, 0], cmap='gray')
    # Reference decomposition of the same channel computed by pywt.
    cA, (cH, cV, cD) = pywt.dwt2(X_n[:, :, 0], 'db3')
    # Compare to the ground truth.  NOTE: pywt.dwt2 defaults to
    # mode='symmetric' while tf_dwt pads with 'reflect', so borders may differ.
    plt.figure()
    plt.imshow(np.abs(cA - outputs_dwt[0, :, :, 0]), cmap='gray')
    plt.show()
- 可參見後面的第二篇部落格檢視優化後的版本