pytorch逆亞畫素
阿新 • • 發佈:2021-01-28
技術標籤:神經網路
#-*-encoding:utf-8-*-
"""
# function/功能 :
# @File : 測試亞畫素.py
# @Time : 2021/1/26 9:33
# @Author : kf
# @Software: PyCharm
"""
import torch
seed=10
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
# 逆亞畫素卷積1
def de_subpix(y):
# print('索引為偶數的項:', demo_list[::2])
# print('索引為奇數的項:', demo_list[1::2])
d1=y[:, :, ::2,::2]
d2=y[:, :, 1::2,::2]
d3=y[:, :, ::2,1::2]
d4=y[:, :, 1::2,1::2]
out = torch.cat([d1, d2, d3, d4], 1)
return out
# 逆亞畫素卷積2
def de_subpix2(y):
(b, c, h, w) = y.shape
h1 = int(h // 2)
w1 = int(w // 2)
d1 = torch.zeros((b, c, h1, w1))
d2 = torch.zeros((b, c, h1, w1))
d3 = torch.zeros((b, c, h1, w1))
d4 = torch.zeros((b, c, h1, w1))
for i in range(0, h1, 1):
for j in range(0, w1, 1):
d1[:, :, i, j] = y[:, :, 2 * i, 2 * j]
d2[:, :, i, j] = y[:, :, 2 * i + 1, 2 * j]
d3[:, :, i, j] = y[:, :, 2 * i, 2 * j + 1]
d4[:, :, i, j] = y[:, :, 2 * i + 1, 2 * j + 1]
out = torch.cat([d1, d2, d3, d4], 1)
# print(out.shape)
return out
# 逆亞畫素卷積3
def de_pixelshuffle(input, downscale_factor): # channal
batch_size, channels, in_height, in_width = input.size()
out_height = in_height // downscale_factor
out_width = in_width // downscale_factor
input_view = input.contiguous().view(batch_size, channels, out_height, downscale_factor, out_width, downscale_factor)
channels = channels *downscale_factor ** 2
shuffle_out = input_view.permute(0, 1, 3, 5, 2, 4).contiguous()
shuffle_out = shuffle_out.view(batch_size, channels, out_height, out_width)
return shuffle_out
test_x = (torch.rand(1, 16, 16, 16))
print('test_x: {}'.format(test_x.shape))
out1=de_pixelshuffle(test_x,2)
print('output: {}'.format(out1.shape))
out=de_subpix(test_x)
out2=de_subpix2(test_x)
print(out)
print('output: {}'.format(out.shape))
# 亞畫素卷積
ps = torch.nn.PixelShuffle(2)
outup=ps(test_x)
print('outup: {}'.format(outup.shape))
在使用3個不同逆亞畫素過程中,發現
de_subpix和de_subpix2結果相同,de_pixelshuffle結果不同,這是因為de_pixelshuffle被壓縮成一維,不能夠提取準確位置,因此不能夠使用進行逆亞畫素。
不使用for迴圈,是因為索引更快。
時間對比,單位是s:
de_subpix:0.0
de_subpix2:0.00401616096496582
通過::x進行索引,得到x的整數倍索引
demo_list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
x=demo_list[::2]
demo_list[1::2]
Out[4]: [1, 3, 5, 7, 9]
demo_list[::3]
Out[5]: [0, 3, 6, 9]
demo_list[1::3]
Out[6]: [1, 4, 7]
demo_list[2::3]
Out[7]: [2, 5, 8]
最終採用方法:de_subpix
參考1:https://blog.csdn.net/aaa958099161/article/details/90230541?utm_medium=distribute.pc_relevant.none-task-blog-searchFromBaidu-2.control&depth_1-utm_source=distribute.pc_relevant.none-task-blog-searchFromBaidu-2.control
參考2:https://blog.csdn.net/qq_38818384/article/details/106904989