centernet的資料增強操作--仿射變換
其實這個操作之前在這裡也分析過,奈何當初寫的程式碼不知道放到哪裡去了。
先看下效果圖:
從圖上可以看到,在原圖上隨機確定的三個點都會對映到變換之後的圖上,而且這三個點的外接矩形區域在仿射變換之後一定會保留在輸出圖中。整體呈現出平移加縮放(放大或縮小)的效果。
畫圖改動程式碼如下:
在CenterNet-master/src/lib/utils/image.py中複製函式get_affine_transform,改名為get_affine_transform_point_src_dst,讓它除了變換矩陣之外,還額外返回src和dst這三對點。
def get_affine_transform_point_src_dst(center,
                                       scale,
                                       rot,
                                       output_size,
                                       shift=np.array([0, 0], dtype=np.float32),
                                       inv=0):
    """Build a 2x3 affine matrix and also return the point triples defining it.

    Three source points (around ``center`` in the input image) and three
    destination points (around the centre of the output image) are
    constructed, and the affine transform mapping one triple onto the other
    is computed with ``cv2.getAffineTransform``.

    Args:
        center: (x, y) of the crop centre in the source image.
        scale: Source extent; a scalar is expanded to ``[scale, scale]``.
            Only ``scale[0]`` is used as the source width.
        rot: Rotation in degrees (converted to radians before use).
        output_size: ``[width, height]`` of the output image.
        shift: Offset applied to the source points, expressed as a fraction
            of ``scale``.
        inv: If truthy, return the transform from destination to source.

    Returns:
        Tuple ``(trans, src, dst)`` where ``trans`` is the matrix returned by
        ``cv2.getAffineTransform`` and ``src`` / ``dst`` are ``(3, 2)``
        float32 arrays holding the point triples.
    """
    if not isinstance(scale, (np.ndarray, list)):
        scale = np.array([scale, scale], dtype=np.float32)

    src_w = scale[0]
    dst_w, dst_h = output_size[0], output_size[1]

    # Direction from the first point to the second: half the width, pointing
    # towards negative y. get_dir is given the angle in radians (presumably
    # it rotates the vector by that angle -- defined outside this block).
    rot_rad = np.pi * rot / 180
    src_dir = get_dir([0, src_w * -0.5], rot_rad)
    dst_dir = np.array([0, dst_w * -0.5], np.float32)

    offset = scale * shift
    dst_center = np.array([dst_w * 0.5, dst_h * 0.5], np.float32)

    src = np.zeros((3, 2), dtype=np.float32)
    dst = np.zeros((3, 2), dtype=np.float32)

    # Point 0: the centres. Point 1: centre moved along the direction vector.
    src[0, :] = center + offset
    src[1, :] = center + src_dir + offset
    dst[0, :] = dst_center
    dst[1, :] = dst_center + dst_dir

    # Point 2: derived from points 0 and 1 by get_3rd_point.
    src[2:, :] = get_3rd_point(src[0, :], src[1, :])
    dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :])

    src_pts, dst_pts = np.float32(src), np.float32(dst)
    if inv:
        trans = cv2.getAffineTransform(dst_pts, src_pts)
    else:
        trans = cv2.getAffineTransform(src_pts, dst_pts)

    return trans, src, dst
在/CenterNet-master/src/lib/datasets/sample/ctdet.py中,畫圖, 新增show_3pt函式
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import torch.utils.data as data
import numpy as np
import torch
import json
import cv2
import os
from utils.image import flip, color_aug
from utils.image import get_affine_transform, affine_transform, get_affine_transform_point_src_dst
from utils.image import gaussian_radius, draw_umich_gaussian, draw_msra_gaussian
from utils.image import draw_dense_reg
import math
def show_3pt(src_img, inp, src_3pt, dst_3pt):
    """Display the three affine control points on the source and warped images.

    A black canvas is built that is large enough to hold the source image
    plus any source control point with a negative x or y coordinate. The
    source image is pasted into it and the source points are drawn in red.
    The destination points are drawn in yellow onto ``inp``. Three windows
    are shown and the call blocks until a key is pressed.

    Args:
        src_img: Source image with shape (height, width, channels).
        inp: Warped image; the destination points are drawn onto it in place.
        src_3pt: Array whose rows are (x, y) source control points. It is
            not modified.
        dst_3pt: Array whose rows are (x, y) destination control points.
    """
    h, w, c = src_img.shape

    # Padding on the left/top so that control points with negative
    # coordinates still fall inside the canvas.
    pad_x = max(0.0, -float(np.min(src_3pt[:, 0])))
    pad_y = max(0.0, -float(np.min(src_3pt[:, 1])))
    left, top = int(pad_x), int(pad_y)

    canvas_shape = [int(h + pad_y + 2), int(w + pad_x + 2), int(c)]
    new_img = np.zeros(canvas_shape, dtype=np.uint8)
    new_img[top:top + h, left:left + w, :] = src_img.astype(np.uint8)

    # Work on a shifted copy so the caller's array is left untouched.
    shifted_src = np.asarray(src_3pt, dtype=np.float64) + [pad_x, pad_y]

    # cv2.circle needs integer pixel coordinates; the control points are
    # floating point, so round them explicitly.
    for x, y in shifted_src[:3]:
        cv2.circle(new_img, (int(round(x)), int(round(y))), 14, (0, 0, 255), -1)

    for x, y in np.asarray(dst_3pt, dtype=np.float64)[:3]:
        cv2.circle(inp, (int(round(x)), int(round(y))), 14, (0, 255, 255), -1)

    cv2.imshow("new_img", new_img)
    cv2.imshow("inp", inp)
    cv2.imshow("src_img", src_img)
    cv2.waitKey(0)
在這裡呼叫:
def __getitem__(self, index):
    """Load one image, sample a random affine crop and visualise its points.

    The crop is described by a centre ``c`` and a scale ``s``. In the train
    split both are randomised (random crop, or shift/scale jitter), and the
    image may be flipped horizontally. The resulting affine transform is
    applied with ``cv2.warpAffine`` and the control points are shown with
    ``show_3pt``.

    NOTE(review): the block nesting below is reconstructed -- the ``else``
    is paired with the random-crop ``if`` and the flip is inside the train
    branch. Confirm against the original file. No value is returned by the
    code in this block.
    """
    img_id = self.images[index]
    img_info = self.coco.loadImgs(ids=[img_id])[0]
    img_path = os.path.join(self.img_dir, img_info['file_name'])
    ann_ids = self.coco.getAnnIds(imgIds=[img_id])
    anns = self.coco.loadAnns(ids=ann_ids)
    num_objs = min(len(anns), self.max_objs)

    img = cv2.imread(img_path)
    height, width = img.shape[0], img.shape[1]

    # Start from the image centre.
    c = np.array([width / 2., height / 2.], dtype=np.float32)
    if self.opt.keep_res:
        input_h = (height | self.opt.pad) + 1
        input_w = (width | self.opt.pad) + 1
        s = np.array([input_w, input_h], dtype=np.float32)
    else:
        # Scale is the longer image side; the output size comes from options.
        s = max(height, width) * 1.0
        input_h, input_w = self.opt.input_h, self.opt.input_w

    flipped = False
    if self.split == 'train':
        if not self.opt.not_rand_crop:
            # Random crop: pick a scale factor, then a centre that stays at
            # least one border width away from each image edge.
            s = s * np.random.choice(np.arange(0.6, 1.4, 0.1))
            w_border = self._get_border(128, width)
            h_border = self._get_border(128, height)
            c[0] = np.random.randint(low=w_border, high=width - w_border)
            c[1] = np.random.randint(low=h_border, high=height - h_border)
        else:
            # Jitter: clipped Gaussian shift of the centre and of the scale.
            sf = self.opt.scale
            cf = self.opt.shift
            c[0] += s * np.clip(np.random.randn() * cf, -2 * cf, 2 * cf)
            c[1] += s * np.clip(np.random.randn() * cf, -2 * cf, 2 * cf)
            s = s * np.clip(np.random.randn() * sf + 1, 1 - sf, 1 + sf)

        if np.random.random() < self.opt.flip:
            # Horizontal flip: mirror the image and the centre's x coordinate.
            flipped = True
            img = img[:, ::-1, :]
            c[0] = width - c[0] - 1

    out_size = [input_w, input_h]
    trans_input, src_3pt, dst_3pt = get_affine_transform_point_src_dst(
        c, s, 0, out_size)
    inp = cv2.warpAffine(
        img, trans_input, (input_w, input_h), flags=cv2.INTER_LINEAR)
    show_3pt(img, inp, src_3pt, dst_3pt)
這裡的關鍵其實是確定三對點,而這三對點由兩個關鍵引數c和s決定:
# Base scale: the longer side of the image.
s = max(img.shape[0], img.shape[1]) * 1.0
if self.split == 'train':
if not self.opt.not_rand_crop:# random-crop branch (taken in the author's run)
# Multiply the scale by a factor drawn from np.arange(0.6, 1.4, 0.1).
s = s * np.random.choice(np.arange(0.6, 1.4, 0.1))
w_border = self._get_border(128, img.shape[1])
h_border = self._get_border(128, img.shape[0])
# The centre is drawn at least one border width away from each image edge.
c[0] = np.random.randint(low=w_border, high=img.shape[1] - w_border) # author's note: w_border = 128
c[1] = np.random.randint(low=h_border, high=img.shape[0] - h_border) # author's note: h_border = 128
這裡的c代表center(中心點):影象四周各去掉寬度為border(這裡是128)的邊框之後,剩下的內部區域就是c的取值範圍。s是影象最長邊再隨機乘以一個係數,係數從np.arange(0.6, 1.4, 0.1)中選取,也就是0.6、0.7、…、1.3(步長0.1,不含1.4)。
這三對點中的第一對:src取中心點c(再加上scale×shift的偏移),dst取輸出圖的中心。
def get_affine_transform_point_src_dst(center,
                                       scale,
                                       rot,
                                       output_size,
                                       shift=np.array([0, 0], dtype=np.float32),
                                       inv=0):
    """Build a 2x3 affine matrix and also return the point triples defining it.

    Three source points (around ``center`` in the input image) and three
    destination points (around the centre of the output image) are
    constructed, and the affine transform mapping one triple onto the other
    is computed with ``cv2.getAffineTransform``.

    Args:
        center: (x, y) of the crop centre in the source image.
        scale: Source extent; a scalar is expanded to ``[scale, scale]``.
            Only ``scale[0]`` is used as the source width.
        rot: Rotation in degrees (converted to radians before use).
        output_size: ``[width, height]`` of the output image.
        shift: Offset applied to the source points, expressed as a fraction
            of ``scale``.
        inv: If truthy, return the transform from destination to source.

    Returns:
        Tuple ``(trans, src, dst)`` where ``trans`` is the matrix returned by
        ``cv2.getAffineTransform`` and ``src`` / ``dst`` are ``(3, 2)``
        float32 arrays holding the point triples.
    """
    if not isinstance(scale, (np.ndarray, list)):
        scale = np.array([scale, scale], dtype=np.float32)

    src_w = scale[0]
    dst_w, dst_h = output_size[0], output_size[1]

    # Direction from the first point to the second: half the width, pointing
    # towards negative y. get_dir is given the angle in radians (presumably
    # it rotates the vector by that angle -- defined outside this block).
    rot_rad = np.pi * rot / 180
    src_dir = get_dir([0, src_w * -0.5], rot_rad)
    dst_dir = np.array([0, dst_w * -0.5], np.float32)

    offset = scale * shift
    dst_center = np.array([dst_w * 0.5, dst_h * 0.5], np.float32)

    src = np.zeros((3, 2), dtype=np.float32)
    dst = np.zeros((3, 2), dtype=np.float32)

    # Point 0: the centres. Point 1: centre moved along the direction vector.
    src[0, :] = center + offset
    src[1, :] = center + src_dir + offset
    dst[0, :] = dst_center
    dst[1, :] = dst_center + dst_dir

    # Point 2: derived from points 0 and 1 by get_3rd_point.
    src[2:, :] = get_3rd_point(src[0, :], src[1, :])
    dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :])

    src_pts, dst_pts = np.float32(src), np.float32(dst)
    if inv:
        trans = cv2.getAffineTransform(dst_pts, src_pts)
    else:
        trans = cv2.getAffineTransform(src_pts, dst_pts)

    return trans, src, dst
然後第二個點是從c出發沿y軸向上偏移0.5×src_w(rot為0時,即y座標加上src_w * -0.5):
# Offset vector for the second source point: half the source width towards
# negative y, passed through get_dir together with the rotation in radians.
src_dir = get_dir([0, src_w * -0.5], rot_rad)
# Second source point: the centre plus that offset (plus the scaled shift).
src[1, :] = center + src_dir + scale_tmp * shift
# Second destination point: the output centre plus dst_dir.
dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5], np.float32) + dst_dir
圖片上方之所以出現黑邊,就是因為0.5×src_w大於c的y,導致y - 0.5×src_w為負數,也就是第二個點落在了原圖之外。
第三對點
def get_3rd_point(a, b):
    """Return the third control point derived from points ``a`` and ``b``.

    The vector from ``b`` to ``a`` is turned by 90 degrees, i.e. (dx, dy)
    becomes (-dy, dx), and added to ``b``.

    Args:
        a: First point, an array-like supporting ``a - b``.
        b: Second point; the result is offset from this point.

    Returns:
        ``b`` plus the turned difference vector.
    """
    delta = a - b
    turned = np.array([-delta[1], delta[0]], dtype=np.float32)
    return b + turned
# The third point of each triple is computed from the first two points.
src[2:, :] = get_3rd_point(src[0, :], src[1, :])
dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :])
這裡是根據前面2個點來計算得到的。這裡其實很簡單,比如src_pt[0]=[500,500], src_pt[1]=[500,250]
那麼direct=[0, 250]
return( [500,250] + [-250, 0])
即[250,250]
有沒有發現!這裡其實就是之前0-->1的時候向上偏移了比如h,然後這裡在1的基礎上又向左偏移h。
所以,以上就是三對點產生的過程!
上方產生黑邊,是因為向上的偏移量0.5×src_w大於c的y;左邊產生黑邊,則是因為向左的偏移量0.5×src_w大於c的x。
# The scale starts as the longer image side and is multiplied by a random factor.
s = max(img.shape[0], img.shape[1]) * 1.0
s = s * np.random.choice(np.arange(0.6, 1.4, 0.1))
src_dir = get_dir([0, src_w * -0.5], rot_rad) # here src_w is s