Optimized version: implementing the 2D wavelet transform (DWT) and inverse wavelet transform (IDWT) in TensorFlow
Published: 2019-01-05
Two posts ago I described an implementation of the 2D wavelet transform (DWT) and inverse wavelet transform (IDWT) in TensorFlow, but that implementation was worryingly slow and resource-hungry, especially when the number of channels is large. I have therefore optimized the previous code.
The optimization consists of two main changes:
- Remove the original for-loop used to adjust tensor sizes and replace it with ops such as tf.slice (see the sketch right after this list);
- Remove the original per-channel convolution loop and replace it with a single TensorFlow 3D convolution (illustrated after the analysis below).
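To make the first change concrete, here is a minimal sketch of the idea (the shapes and the crop are made up for illustration and are not taken from the code below): a crop built from a Python loop adds one op per iteration to the graph, while a single tf.slice expresses the same result with one node.

import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 260, 260, 64], name='x')

# Looped version (roughly the old style): each iteration adds ops to the graph,
# so the graph and its backprop bookkeeping grow with the loop count.
rows = []
for i in range(2, 258):
    rows.append(x[:, i:i + 1, 2:258, :])   # one slicing op per iteration
x_crop_loop = tf.concat(rows, axis=1)      # plus a concat over 256 inputs

# Single-op version: the same crop as one tf.slice node
# (-1 in `size` means "take everything remaining along that axis").
x_crop = tf.slice(x, begin=[0, 2, 2, 0], size=[-1, 256, 256, -1])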
Analysis
The reason these two changes save computation and speed things up is that TensorFlow keeps the result of every tensor op around for backpropagation. For example, if a for-loop runs tf.concat 64 times, TensorFlow keeps 64 concat nodes and their gradient graphs, roughly tf.concat_1 … tf.concat_64 (the naming may not be exact). Storing all of these intermediate results consumes a large amount of memory and compute that the calculation itself does not need. So the way to save resources is to implement the same functionality with as few tensor ops as possible. The tf.slice op can completely replace the original looped tf.concat structure, and during backpropagation it costs only about as much as a single iteration of the old loop. The same reasoning applies to the looped convolution: a 3D convolution is not free either, but it is still cheaper than the loop.
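As a rough sketch of the second change (the shapes and the random filter stand in for the real wavelet kernels, and tf.transpose is used here only to keep the example short; the code below does the same channel-to-depth move with expand_dims/split/concat), a per-channel loop of tf.nn.conv2d can be replaced by one tf.nn.conv3d with a depth-1 kernel:

import numpy as np
import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 256, 256, 64])      # NHWC, 64 channels
filt_2d = np.random.randn(2, 2, 1, 4).astype('float32')   # stand-in for the 4 wavelet kernels

# Looped version: one conv2d per channel -> 64 conv nodes plus a concat.
outs = []
for c in range(64):
    xc = x[:, :, :, c:c + 1]
    outs.append(tf.nn.conv2d(xc, filt_2d, strides=[1, 2, 2, 1], padding='VALID'))
y_loop = tf.concat(outs, axis=3)                           # [N, 128, 128, 256]

# 3D version: put the channels on the depth axis and convolve once.
x_3d = tf.expand_dims(tf.transpose(x, [0, 3, 1, 2]), -1)   # [N, 64, 256, 256, 1]
filt_3d = filt_2d[None, :, :, :, :]                        # [1, 2, 2, 1, 4]
y_3d = tf.nn.conv3d(x_3d, filt_3d,
                    strides=[1, 1, 2, 2, 1], padding='VALID')  # [N, 64, 128, 128, 4]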
One more small difference from the previous version: the convolution kernel size has been changed, which speeds up the DWT as well. The old default wavelet base was db3, with a kernel size of 6; the new default is haar, with a kernel size of 2. Readers can pass whichever base suits their needs.
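The filter lengths of the two bases can be checked directly with pywt:

import pywt

print(len(pywt.Wavelet('haar').dec_lo))   # 2 -> 2x2 kernels
print(len(pywt.Wavelet('db3').dec_lo))    # 6 -> 6x6 kernels, about 9x the multiplies per output pixel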
Code
# -*- coding: utf-8 -*-
# @Author : Cmy
# @time : 2018/12/5 20:37
# @File : tf_dwt_3d_v2.py
# @Software : PyCharm
import numpy as np
import tensorflow as tf
from PIL import Image
import pywt
import time
import matplotlib.pyplot as plt


# Works on NHWC tensors (C is the channel axis).
# Only a single decomposition level (J = 1) is supported.
def tf_dwt(yl, wave='haar'):
    # Build the four separable analysis kernels (LL, LH, HL, HH) from the wavelet base.
    w = pywt.Wavelet(wave)
    ll = np.outer(w.dec_lo, w.dec_lo)
    lh = np.outer(w.dec_hi, w.dec_lo)
    hl = np.outer(w.dec_lo, w.dec_hi)
    hh = np.outer(w.dec_hi, w.dec_hi)
    d_temp = np.zeros((np.shape(ll)[0], np.shape(ll)[1], 1, 4))
    d_temp[::-1, ::-1, 0, 0] = ll
    d_temp[::-1, ::-1, 0, 1] = lh
    d_temp[::-1, ::-1, 0, 2] = hl
    d_temp[::-1, ::-1, 0, 3] = hh
    filts = d_temp.astype('float32')
    filts = filts[None, :, :, :, :]          # depth-1 3D kernel: [1, k, k, 1, 4]
    filter = tf.convert_to_tensor(filts)
    sz = 2 * (len(w.dec_lo) // 2 - 1)

    with tf.variable_scope('DWT'):
        # Pad odd-length images
        # if in_size[0] % 2 == 1 and tf.shape(yl)[1] % 2 == 1:
        #     yl = tf.pad(yl, tf.constant([[0, 0], [sz, sz + 1], [sz, sz + 1], [0, 0]]), mode='reflect')
        # elif in_size[0] % 2 == 1:
        #     yl = tf.pad(yl, tf.constant([[0, 0], [sz, sz + 1], [sz, sz], [0, 0]]), mode='reflect')
        # elif in_size[1] % 2 == 1:
        #     yl = tf.pad(yl, tf.constant([[0, 0], [sz, sz], [sz, sz + 1], [0, 0]]), mode='reflect')
        # else:
        yl = tf.pad(yl, tf.constant([[0, 0], [sz, sz], [sz, sz], [0, 0]]), mode='reflect')

        # Move the channel axis onto the depth axis, then filter all channels with one conv3d.
        y = tf.expand_dims(yl, 1)
        inputs = tf.split(y, [1] * int(y.shape.dims[4]), 4)
        inputs = tf.concat([x for x in inputs], 1)
        outputs_3d = tf.nn.conv3d(inputs, filter, padding='VALID', strides=[1, 1, 2, 2, 1])
        # Move the depth axis back to channels: 4 sub-bands per input channel.
        outputs = tf.split(outputs_3d, [1] * int(outputs_3d.shape.dims[1]), 1)
        outputs = tf.concat([x for x in outputs], 4)
        outputs = tf.reshape(outputs, (tf.shape(outputs)[0], tf.shape(outputs)[2],
                                       tf.shape(outputs)[3], tf.shape(outputs)[4]))
    return outputs


def tf_idwt(y, wave='haar'):
    # Build the four separable synthesis kernels from the reconstruction filters.
    w = pywt.Wavelet(wave)
    ll = np.outer(w.rec_lo, w.rec_lo)
    lh = np.outer(w.rec_hi, w.rec_lo)
    hl = np.outer(w.rec_lo, w.rec_hi)
    hh = np.outer(w.rec_hi, w.rec_hi)
    d_temp = np.zeros((np.shape(ll)[0], np.shape(ll)[1], 1, 4))
    d_temp[:, :, 0, 0] = ll
    d_temp[:, :, 0, 1] = lh
    d_temp[:, :, 0, 2] = hl
    d_temp[:, :, 0, 3] = hh
    filts = d_temp.astype('float32')
    filts = filts[None, :, :, :, :]
    filter = tf.convert_to_tensor(filts)
    s = 2 * (len(w.dec_lo) // 2 - 1)
    out_size = tf.shape(y)[1]

    with tf.variable_scope('IWT'):
        # Group every 4 sub-band channels onto the depth axis and upsample with one conv3d_transpose.
        y = tf.expand_dims(y, 1)
        inputs = tf.split(y, [4] * int(int(y.shape.dims[4]) / 4), 4)
        inputs = tf.concat([x for x in inputs], 1)
        outputs_3d = tf.nn.conv3d_transpose(inputs, filter,
                                            output_shape=[tf.shape(y)[0], tf.shape(inputs)[1],
                                                          2 * (out_size - 1) + np.shape(ll)[0],
                                                          2 * (out_size - 1) + np.shape(ll)[0], 1],
                                            padding='VALID', strides=[1, 1, 2, 2, 1])
        outputs = tf.split(outputs_3d, [1] * int(int(y.shape.dims[4]) / 4), 1)
        outputs = tf.concat([x for x in outputs], 4)
        outputs = tf.reshape(outputs, (tf.shape(outputs)[0], tf.shape(outputs)[2],
                                       tf.shape(outputs)[3], tf.shape(outputs)[4]))
        # Crop away the border introduced by the padding in tf_dwt.
        outputs = outputs[:, s: 2 * (out_size - 1) + np.shape(ll)[0] - s,
                          s: 2 * (out_size - 1) + np.shape(ll)[0] - s, :]
    return outputs


if __name__ == '__main__':
    # load image
    a = Image.open('12074.jpg')
    X_n = np.array(a).astype('float32')
    X_n = X_n / 255
    X_n = X_n[0:256, 0:256, :]
    X_t = np.zeros((1, 256, 256, 3), dtype='float32')
    X_t[0, :, :, :] = X_n[:, :, :]
    X_tf = tf.convert_to_tensor(X_t)    # convert to tensor

    sess = tf.Session()
    inputs = tf.placeholder(tf.float32, [None, None, None, 3], name='inputs')
    outputs_in = tf.placeholder(tf.float32, [None, None, None, 12], name='outputs')
    outputs = tf_dwt(inputs)
    outputs_mex = tf_idwt(outputs_in)
    sess.run(tf.global_variables_initializer())

    time_start = time.time()
    outputs_dwt = sess.run(outputs, feed_dict={inputs: X_t})
    outputs_mex = sess.run(outputs_mex, feed_dict={outputs_in: outputs_dwt})
    time_end = time.time()
    print('total time cost', time_end - time_start)

    # show the decomposition and reconstruction images
    plt.figure()
    plt.imshow(outputs_dwt[0, :, :, 0], cmap='gray')
    plt.figure()
    plt.imshow(outputs_mex[0, :, :, 0], cmap='gray')

    # pywt reference decomposition of the first channel
    cA, (cH, cV, cD) = pywt.dwt2(X_n[:, :, 0], 'haar')

    # show the difference against pywt and against the original image
    plt.figure()
    plt.imshow(np.abs(cH - outputs_dwt[0, :, :, 1]), cmap='gray')
    plt.figure()
    plt.imshow(np.abs(X_n[:, :, 1] - outputs_mex[0, :, :, 1]), cmap='gray')
    plt.show()
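As an optional follow-up (a hypothetical addition, not part of the original script), the visual comparisons can also be reduced to printed numbers by appending the lines below at the end of the __main__ block. For the haar base and an even-sized input the round-trip error should be around float32 rounding level; how closely the LH band matches pywt's cH depends on pywt's sign and padding convention.

    # Hypothetical extra checks, appended at the end of the __main__ block above.
    print('max |cH - tf LH band|   :', np.max(np.abs(cH - outputs_dwt[0, :, :, 1])))
    print('max |input - idwt(dwt)| :', np.max(np.abs(X_n - outputs_mex[0])))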