1. 程式人生 > 其它 >pytorch逆亞畫素

pytorch逆亞畫素

技術標籤:神經網路

#-*-encoding:utf-8-*-
"""
# function/功能 : 
# @File : 測試亞畫素.py 
# @Time : 2021/1/26 9:33 
# @Author : kf
# @Software: PyCharm
"""
import torch
seed=10
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True



# 逆亞畫素卷積1
def de_subpix(y): # print('索引為偶數的項:', demo_list[::2]) # print('索引為奇數的項:', demo_list[1::2]) d1=y[:, :, ::2,::2] d2=y[:, :, 1::2,::2] d3=y[:, :, ::2,1::2] d4=y[:, :, 1::2,1::2] out = torch.cat([d1, d2, d3, d4], 1) return out # 逆亞畫素卷積2 def de_subpix2(y): (b, c, h, w)
= y.shape h1 = int(h // 2) w1 = int(w // 2) d1 = torch.zeros((b, c, h1, w1)) d2 = torch.zeros((b, c, h1, w1)) d3 = torch.zeros((b, c, h1, w1)) d4 = torch.zeros((b, c, h1, w1)) for i in range(0, h1, 1): for j in range(0, w1, 1): d1[:, :, i, j] = y[:, :, 2 * i,
2 * j] d2[:, :, i, j] = y[:, :, 2 * i + 1, 2 * j] d3[:, :, i, j] = y[:, :, 2 * i, 2 * j + 1] d4[:, :, i, j] = y[:, :, 2 * i + 1, 2 * j + 1] out = torch.cat([d1, d2, d3, d4], 1) # print(out.shape) return out # 逆亞畫素卷積3 def de_pixelshuffle(input, downscale_factor): # channal batch_size, channels, in_height, in_width = input.size() out_height = in_height // downscale_factor out_width = in_width // downscale_factor input_view = input.contiguous().view(batch_size, channels, out_height, downscale_factor, out_width, downscale_factor) channels = channels *downscale_factor ** 2 shuffle_out = input_view.permute(0, 1, 3, 5, 2, 4).contiguous() shuffle_out = shuffle_out.view(batch_size, channels, out_height, out_width) return shuffle_out test_x = (torch.rand(1, 16, 16, 16)) print('test_x: {}'.format(test_x.shape)) out1=de_pixelshuffle(test_x,2) print('output: {}'.format(out1.shape)) out=de_subpix(test_x) out2=de_subpix2(test_x) print(out) print('output: {}'.format(out.shape)) # 亞畫素卷積 ps = torch.nn.PixelShuffle(2) outup=ps(test_x) print('outup: {}'.format(outup.shape))

在這裡插入圖片描述
在這裡插入圖片描述

在使用3個不同逆亞畫素過程中,發現

de_subpix和de_subpix2結果相同,de_pixelshuffle結果不同,這是因為de_pixelshuffle被壓縮成一維,不能夠提取準確位置,因此不能夠使用進行逆亞畫素。

不使用for迴圈,是因為索引更快。

時間對比,單位是s:

de_subpix:0.0
de_subpix2:0.00401616096496582

通過::x進行索引,得到x的整數倍索引

demo_list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
x=demo_list[::2]
demo_list[1::2]
Out[4]: [1, 3, 5, 7, 9]
demo_list[::3]
Out[5]: [0, 3, 6, 9]
demo_list[1::3]
Out[6]: [1, 4, 7]
demo_list[2::3]
Out[7]: [2, 5, 8]

最終採用方法:de_subpix

參考1:https://blog.csdn.net/aaa958099161/article/details/90230541?utm_medium=distribute.pc_relevant.none-task-blog-searchFromBaidu-2.control&depth_1-utm_source=distribute.pc_relevant.none-task-blog-searchFromBaidu-2.control

參考2:https://blog.csdn.net/qq_38818384/article/details/106904989