Hourglass PyTorch Implementation
阿新 · Published 2018-11-05
This post is divided into the following parts:
1. Dataset loading
2. The hourglass model
3. Training
4. The code is mainly a combination of several GitHub authors' work: the training code of @bearpaw and @roytseng-tw, and the evaluation code of @anibali. The code from the first two closely matches the original authors' Lua/Torch7 source and is a faithful reproduction.
5. The training, validation, and test splits from the original hourglass authors (@umich-vl) are used.
6. I will also release a Caffe version of hourglass on GitHub, taken mainly from the GitHub repo of the RMPE paper.
7. My current training result only reaches 89.3 on the MPII validation set at a threshold of 0.5 (PCKh@0.5; sketched below).
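For reference, PCKh counts a predicted joint as correct if its distance to the ground truth is within a threshold times the head segment size. A minimal sketch of the metric; the function name and array layouts here are illustrative, not taken from the repos above:

import numpy as np

def pckh(preds, gts, head_sizes, thr=0.5):
    """Illustrative PCKh: fraction of predictions within thr * head size of ground truth.
    preds, gts: (N, nJoints, 2) arrays of (x, y); head_sizes: (N,) head segment lengths.
    Joints without annotation (ground truth <= 0) are excluded, as in the MPII protocol.
    """
    dists = np.linalg.norm(preds - gts, axis=2)           # (N, nJoints)
    valid = (gts > 0).all(axis=2)                         # mask out unannotated joints
    correct = (dists <= thr * head_sizes[:, None]) & valid
    return correct.sum() / valid.sum()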
I. Data loading
1. Data augmentation
The main operations involved here are crop, scale, flip, and rotate.
""" Random """ def randn(): return random.gauss(0, 1) def rand(): return random.random() def rnd(x): '''umich hourglass mpii random function''' return max(-2 * x, min(2 * x, randn() * x)) """ Visualization """ def show_sample(img, label): # FIXME: color blending is not right, diff color for each joint nJoints = label.shape[0] white = np.ones((4,) + img.shape[1:3]) new_img = white.copy() new_img[:3] = img * 0.5 for i in range(nJoints): new_img += 0.5 * white * sktf.resize(label[i], img.shape[1:3], preserve_range=True) # print(label[i].max()) # plt.subplot(121) # plt.imshow(np.transpose(new_img, [1, 2, 0])) # plt.subplot(122) # plt.imshow(label[i]) # plt.show() return np.transpose(new_img, [1, 2, 0]) """ Label """ def create_label(imsize, pt, sigma, distro_type='Gaussian'): label = np.zeros(imsize) # Check that any part of the distro is in-bounds ul = np.math.floor(pt[0] - 3 * sigma), np.math.floor(pt[1] - 3 * sigma) br = np.math.floor(pt[0] + 3 * sigma), np.math.floor(pt[1] + 3 * sigma) # If not, return the blank label if ul[0] >= imsize[1] or ul[1] >= imsize[0] or br[0] < 0 or br[1] < 0: return label # Generate distro size = 6 * sigma + 1 x = np.arange(0, size, 1, float) y = x[:, np.newaxis] x0 = y0 = size // 2 '''Note: original torch impl: `local g = image.gaussian(size)` equals to `gaussian(size, sigma=0.25*size)` here ''' if distro_type == 'Gaussian': distro = np.exp(-((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2)) elif distro_type == 'Cauchy': # IS THIS CORRECT ??? distro = sigma / (((x - x0) ** 2 + (y - y0) ** 2 + sigma ** 2) ** 1.5) # distro = sigma / (((x - x0) ** 2 + (y - y0) ** 2 + sigma ** 2) * np.pi) # Usable distro range distro_x = max(0, -ul[0]), min(br[0], imsize[1]) - ul[0] distro_y = max(0, -ul[1]), min(br[1], imsize[0]) - ul[1] assert (distro_x[0] >= 0 and distro_y[0] >= 0), '{}, {}'.format(distro_x, distro_y) # label range label_x = max(0, ul[0]), min(br[0], imsize[1]) label_y = max(0, ul[1]), min(br[1], imsize[0]) label[label_y[0]:label_y[1], label_x[0]:label_x[1]] = \ distro[distro_y[0]:distro_y[1], distro_x[0]:distro_x[1]] return label """ Flip """ def fliplr_labels(labels, matchedParts, joint_dim=1, width_dim=3): """fliplr the joint labels, defaults (B, C, H, W) """ # flip horizontally labels = np.flip(labels, axis=width_dim) # Change left-right parts perm = np.arange(labels.shape[joint_dim]) for i, j in matchedParts: perm[i] = j perm[j] = i labels = np.take(labels, perm, axis=joint_dim) return labels def fliplr_coords(pts, width, matchedParts): # Flip horizontally (only flip valid points) pts = np.array([(width - x, y) if x > 0 else (x, y) for x, y in pts]) # Change left-right parts perm = np.arange(pts.shape[0]) for i, j in matchedParts: perm[i] = j perm[j] = i pts = pts[perm] return pts """ Transform, Crop """ def get_transform(center, scale, rot, res, invert=False): '''Prepare transformation matrix (scale, rot). 
''' h = 200 * scale t = np.eye(3) # transformation matrix # scale t[0, 0] = res[1] / h t[1, 1] = res[0] / h # translation t[0, 2] = res[1] * (-center[0] / h + .5) t[1, 2] = res[0] * (-center[1] / h + .5) # rotation if rot != 0: rot = -rot # To match direction of rotation from cropping rot_mat = np.zeros((3, 3)) rot_rad = rot * np.pi / 180 sn, cs = np.sin(rot_rad), np.cos(rot_rad) rot_mat[:2, :2] = [[cs, -sn], [sn, cs]] rot_mat[2, 2] = 1 # Need to make sure rotation is around center t_mat = np.eye(3) t_mat[0, 2] = -res[1] / 2 t_mat[1, 2] = -res[0] / 2 t_inv = t_mat.copy() t_inv[:2, 2] *= -1 t = np.dot(t_inv, np.dot(rot_mat, np.dot(t_mat, t))) if invert: t = np.linalg.inv(t) return t def transform(pts, center, scale, rot, res, invert=False): """ Transform points from original coord to new coord pts: 2 * n array """ t = get_transform(center, scale, rot, [res, res], invert) pts = np.array(pts) assert pts.shape[0] == 2, pts.shape if pts.ndim == 1: pts = np.array([pts[0], pts[1], 1]) else: pts = np.concatenate([pts, np.ones((1, pts.shape[1]))], axis=0) new_pt = np.dot(t, pts) return new_pt[:2].astype(int) def crop(img, center, scale, rot, res): ''' res: single value of targeted output image resolution rot: in degrees ''' # Preprocessing for efficient cropping ht, wd = img.shape[0], img.shape[1] # print(center, scale, rot, ht, wd) sf = scale * 200.0 / res # print(sf) if sf < 2: sf = 1 else: new_size = int(np.math.floor(max(ht, wd) / sf)) new_ht = int(np.math.floor(ht / sf)) new_wd = int(np.math.floor(wd / sf)) if new_size < 2: # Zoomed out so much that the image is now a single pixel or less return np.zeros(res, res) if img.ndim == 2 \ else np.zeros(res, res, img.shape[2]) else: img = sktf.resize(img, [new_ht, new_wd], preserve_range=True) ht, wd = img.shape[0], img.shape[1] # print(ht, wd) # Calculate upper left and bottom right coordinates defining crop region center = center / sf scale = scale / sf # print(center, scale) ul = transform([0, 0], center, scale, 0, res, invert=True) br = transform([res, res], center, scale, 0, res, invert=True) if sf >= 2: br += - (br - ul - res) # print(ul, br) # Padding so that when rotated proper amount of context is included pad = np.math.ceil(np.linalg.norm(br - ul) / 2 - (br[0] - ul[0]) / 2) # print(pad) if rot != 0: ul -= pad br += pad # print(ul, br) # Define the range of pixels to take from the old image old_x = max(0, ul[0]), min(br[0], wd) old_y = max(0, ul[1]), min(br[1], ht) # print(old_x, old_y) # And where to put them in the new image new_x = max(0, -ul[0]), min(br[0], wd) - ul[0] new_y = max(0, -ul[1]), min(br[1], ht) - ul[1] # print(new_x, new_y) # Initialize new image and copy pixels over new_shape = [br[1] - ul[1], br[0] - ul[0]] # print(new_shape) if len(img.shape) > 2: new_shape += [img.shape[2]] new_img = np.zeros(new_shape) new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1], old_x[0]:old_x[1]] if rot != 0: # Rotate the image and remove padded area new_img = sktf.rotate(new_img, rot, preserve_range=True) new_img = new_img[pad:-pad, pad:-pad] if sf < 2: new_img = sktf.resize(new_img, [res, res], preserve_range=True) return new_img
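To make the helpers above concrete, here is a quick sanity check exercising crop, transform, and create_label together; the image, center, and scale values are made up for illustration:

import numpy as np

img = np.random.rand(480, 640, 3)    # dummy input image (H, W, C)
center = np.array([320.0, 240.0])    # person center in original coords
scale = 1.2                          # person scale (MPII convention: height / 200)

patch = crop(img, center, scale, rot=15, res=256)
print(patch.shape)                   # (256, 256, 3)

# project a keypoint into 64x64 heatmap space and render its Gaussian
pt = transform([330, 250], center, scale, rot=15, res=64)
hm = create_label((64, 64), pt, sigma=1)
print(hm.shape, hm.max())            # (64, 64), peak near 1.0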
2. Reading data batches from the dataset
This script, combined with the data-augmentation script above, covers essentially all of the data handling.
import os

import h5py
import numpy as np
import skimage as skim
import skimage.io as skio
import skimage.transform as sktf
import torch.utils.data

# uses rnd, rand, crop, transform, create_label, fliplr_coords from the augmentation script above


class MPII_Dataset(torch.utils.data.Dataset):
    def __init__(self, data_root, split, inp_res=256, out_res=64, sigma=1,
                 scale_factor=0.25, rot_factor=30, return_meta=False, small_image=True):
        self.data_root = data_root
        self.split = split
        self.inp_res = inp_res
        self.out_res = out_res
        self.sigma = sigma
        self.scale_factor = scale_factor
        self.rot_factor = rot_factor
        self.return_meta = return_meta
        self.small_image = small_image
        self.nJoints = 16
        self.accIdxs = [0, 1, 2, 3, 4, 5, 10, 11, 14, 15]  # joint idxs for accuracy calculation
        self.flipRef = [[0, 5], [1, 4], [2, 3],  # noqa
                        [10, 15], [11, 14], [12, 13]]
        self.annot = {}
        tags = ['imgname', 'part', 'center', 'scale']
        f = h5py.File('{}/mpii/{}.h5'.format(data_root, split), 'r')
        for tag in tags:
            self.annot[tag] = np.asarray(f[tag]).copy()
        f.close()

    def _getPartInfo(self, index):
        # get a COPY
        pts = self.annot['part'][index].copy()
        c = self.annot['center'][index].copy()
        s = self.annot['scale'][index].copy()
        # Small adjustment so cropping is less likely to take feet out
        c[1] = c[1] + 15 * s
        s = s * 1.25
        return pts, c, s

    def _loadImage(self, index):
        impath = os.path.join(self.data_root, 'mpii/images',
                              self.annot['imgname'][index].decode('utf-8'))
        im = skim.img_as_float(skio.imread(impath))
        return im

    def __getitem__(self, index):
        im = self._loadImage(index)
        pts, c, s = self._getPartInfo(index)
        r = 0
        if self.split == 'train':
            # scale and rotation
            s = s * (2 ** rnd(self.scale_factor))
            r = 0 if rand() < 0.6 else rnd(self.rot_factor)
            # flip LR
            if rand() < 0.5:
                im = im[:, ::-1, :]
                pts = fliplr_coords(pts, width=im.shape[1], matchedParts=self.flipRef)
                c[0] = im.shape[1] - c[0]  # flip center point also
            # Color jitter
            im = np.clip(im * np.random.uniform(0.6, 1.4, size=3), 0, 1)
        # Prepare image
        im = crop(im, c, s, r, self.inp_res)
        if im.ndim == 2:
            im = np.tile(im[:, :, np.newaxis], [1, 1, 3])  # grayscale to 3-channel
        if self.small_image:
            # small size image
            im_s = sktf.resize(im, [self.out_res, self.out_res], preserve_range=True)
        # (h, w, c) to (c, h, w)
        im = np.transpose(im, [2, 0, 1])
        if self.small_image:
            im_s = np.transpose(im_s, [2, 0, 1])
        # Prepare label
        labels = np.zeros((self.nJoints, self.out_res, self.out_res))
        new_pts = transform(pts.T, c, s, r, self.out_res).T
        for i in range(self.nJoints):
            if pts[i, 0] > 0:
                labels[i] = create_label(labels.shape[1:], new_pts[i], self.sigma)
        ret_list = [im.astype(np.float32), labels.astype(np.float32)]
        if self.small_image:
            ret_list.append(im_s)
        if self.return_meta:
            meta = [pts, c, s, r]
            ret_list.append(meta)
        return tuple(ret_list)

    def __len__(self):
        return len(self.annot['imgname'])
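Assuming the MPII images and the converted h5 annotation files are laid out under data_root/mpii/ as the class expects, a quick smoke test looks like this (the data path is illustrative):

ds = MPII_Dataset('./data', split='valid')
im, labels, im_s = ds[0]
print(len(ds))            # number of samples in the split
print(im.shape)           # (3, 256, 256) float32 network input
print(labels.shape)       # (16, 64, 64) per-joint Gaussian heatmaps
print(im_s.shape)         # (3, 64, 64) downsized image (small_image=True)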
II. Model code
1. First, define the basic residual module
import torch.nn as nn


class HgResBlock(nn.Module):
    ''' Hourglass residual block '''

    def __init__(self, inplanes, outplanes, stride=1):
        super().__init__()
        self.inplanes = inplanes
        self.outplanes = outplanes
        midplanes = outplanes // 2
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.conv1 = nn.Conv2d(inplanes, midplanes, 1, stride)  # bias=False
        self.bn2 = nn.BatchNorm2d(midplanes)
        self.conv2 = nn.Conv2d(midplanes, midplanes, 3, stride, 1)
        self.bn3 = nn.BatchNorm2d(midplanes)
        self.conv3 = nn.Conv2d(midplanes, outplanes, 1, stride)  # bias=False
        self.relu = nn.ReLU(inplace=True)
        if inplanes != outplanes:
            self.conv_skip = nn.Conv2d(inplanes, outplanes, 1, 1)

    def forward(self, x):
        residual = x
        out = self.bn1(x)
        out = self.relu(out)
        out = self.conv1(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn3(out)
        out = self.relu(out)
        out = self.conv3(out)
        if self.inplanes != self.outplanes:
            residual = self.conv_skip(residual)
        out += residual
        return out
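With the default stride=1, the block preserves spatial resolution and only changes the channel count (the 1x1 skip conv kicks in when inplanes != outplanes), which a dummy forward pass confirms:

import torch

block = HgResBlock(inplanes=128, outplanes=256)
x = torch.randn(2, 128, 64, 64)
print(block(x).shape)   # torch.Size([2, 256, 64, 64]) -- skip conv matches channels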
2. Define the basic hourglass structure
class Hourglass(nn.Module):
    def __init__(self, depth, nFeat, nModules, resBlock):
        super().__init__()
        self.depth = depth
        self.nFeat = nFeat
        self.nModules = nModules  # num residual modules per location
        self.resBlock = resBlock
        self.hg = self._make_hour_glass()
        self.downsample = nn.MaxPool2d(2, 2)
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')

    def _make_hour_glass(self):
        hg = []
        for i in range(self.depth):
            # skip (upper branch); down_path, up_path (lower branch)
            res = [self._make_residual(self.nModules) for _ in range(3)]
            if i == (self.depth - 1):
                res.append(self._make_residual(self.nModules))  # extra one for the middle
            hg.append(nn.ModuleList(res))
        return nn.ModuleList(hg)

    def _make_residual(self, n):
        return nn.Sequential(*[self.resBlock(self.nFeat, self.nFeat) for _ in range(n)])

    def forward(self, x):
        return self._hour_glass_forward(0, x)

    def _hour_glass_forward(self, depth_id, x):
        up1 = self.hg[depth_id][0](x)
        low1 = self.downsample(x)
        low1 = self.hg[depth_id][1](low1)
        if depth_id == (self.depth - 1):
            low2 = self.hg[depth_id][3](low1)
        else:
            low2 = self._hour_glass_forward(depth_id + 1, low1)
        low3 = self.hg[depth_id][2](low2)
        up2 = self.upsample(low3)
        return up1 + up2
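Since every MaxPool2d in the recursion is matched by an Upsample, a single Hourglass returns a feature map at the input resolution; with depth=4 the spatial size should be divisible by 16:

import torch

hg = Hourglass(depth=4, nFeat=256, nModules=1, resBlock=HgResBlock)
x = torch.randn(1, 256, 64, 64)
print(hg(x).shape)   # torch.Size([1, 256, 64, 64]) -- same resolution in and out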
class HourglassNet(nn.Module):
    '''Hourglass model from Newell et al ECCV 2016'''

    def __init__(self, nStacks, nModules, nFeat, nClasses, resBlock=HgResBlock, inplanes=3):
        super().__init__()
        self.nStacks = nStacks
        self.nModules = nModules
        self.nFeat = nFeat
        self.nClasses = nClasses
        self.resBlock = resBlock
        self.inplanes = inplanes
        self._make_head()
        hg, res, fc, score, fc_, score_ = [], [], [], [], [], []
        for i in range(nStacks):
            hg.append(Hourglass(4, nFeat, nModules, resBlock))
            res.append(self._make_residual(nModules))
            fc.append(self._make_fc(nFeat, nFeat))
            score.append(nn.Conv2d(nFeat, nClasses, 1))
            if i < (nStacks - 1):
                fc_.append(nn.Conv2d(nFeat, nFeat, 1))
                score_.append(nn.Conv2d(nClasses, nFeat, 1))
        self.hg = nn.ModuleList(hg)
        self.res = nn.ModuleList(res)
        self.fc = nn.ModuleList(fc)
        self.score = nn.ModuleList(score)
        self.fc_ = nn.ModuleList(fc_)
        self.score_ = nn.ModuleList(score_)

    def _make_head(self):
        self.conv1 = nn.Conv2d(self.inplanes, 64, 7, 2, 3)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.res1 = self.resBlock(64, 128)
        self.pool = nn.MaxPool2d(2, 2)
        self.res2 = self.resBlock(128, 128)
        self.res3 = self.resBlock(128, self.nFeat)

    def _make_residual(self, n):
        return nn.Sequential(*[self.resBlock(self.nFeat, self.nFeat) for _ in range(n)])

    def _make_fc(self, inplanes, outplanes):
        return nn.Sequential(
            nn.Conv2d(inplanes, outplanes, 1),
            nn.BatchNorm2d(outplanes),
            nn.ReLU(True))

    def forward(self, x):
        # head
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.res1(x)
        x = self.pool(x)
        x = self.res2(x)
        x = self.res3(x)
        out = []
        for i in range(self.nStacks):
            y = self.hg[i](x)
            y = self.res[i](y)
            y = self.fc[i](y)
            score = self.score[i](y)
            out.append(score)
            if i < (self.nStacks - 1):
                fc_ = self.fc_[i](y)
                score_ = self.score_[i](score)
                x = x + fc_ + score_
        return out
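The head downsamples by a factor of 4 (stride-2 7x7 conv plus max pooling), so a 256x256 input produces 64x64 heatmaps, and forward returns one score tensor per stack for intermediate supervision:

import torch

net = HourglassNet(nStacks=2, nModules=1, nFeat=256, nClasses=16)
x = torch.randn(1, 3, 256, 256)
outs = net(x)
print(len(outs))        # 2 -- one heatmap set per stack
print(outs[0].shape)    # torch.Size([1, 16, 64, 64])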
III. Training
Initialize the data and the network
# FLAGS holds the command-line arguments of the full training script
train_set = MPII_Dataset(
    FLAGS.dataDir, split='train',
    inp_res=FLAGS.inputRes, out_res=FLAGS.outputRes,
    scale_factor=FLAGS.scale, rot_factor=FLAGS.rotate, sigma=FLAGS.hmSigma)
train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=FLAGS.trainBatch, shuffle=True,
    num_workers=FLAGS.nThreads, pin_memory=True)
netHg = nn.DataParallel(HourglassNet(
    nStacks=FLAGS.nStacks, nModules=FLAGS.nModules, nFeat=FLAGS.nFeats,
    nClasses=train_set.nJoints))  # ref `nClasses` from dataset
criterion = nn.MSELoss()
if cuda:
    torch.backends.cudnn.benchmark = True
    netHg.cuda()
    criterion.cuda()
optimHg = torch.optim.RMSprop(
    netHg.parameters(),
    lr=FLAGS.lr,
    alpha=FLAGS.alpha, eps=FLAGS.eps)
Run the network for training
def run(epoch, iter_start=0):
    netHg.train()
    global global_step
    pbar = tqdm.tqdm(train_loader, desc='Epoch %02d' % epoch, dynamic_ncols=True)
    pbar_info = tqdm.tqdm(bar_format='{bar}{postfix}')
    avg_acc = 0
    for it, sample in enumerate(pbar, start=iter_start):
        global_step += 1
        image, label, image_s = sample
        image = Variable(image)
        label = Variable(label)
        image_s = Variable(image_s)
        if FLAGS.cuda:
            # TODO: check the effect of async (renamed to non_blocking=True in later PyTorch)
            image = image.cuda(async=True)
            label = label.cuda(async=True)
            image_s = image_s.cuda(async=True)
        # generator
        outputs = netHg(image)
        loss_hg_content = 0
        for out in outputs:  # intermediate supervision: sum loss over all stacks
            loss_hg_content += criterion(out, label)
        loss_hg = loss_hg_content
        optimHg.zero_grad()
        loss_hg.backward()
        optimHg.step()
        accs = accuracy(outputs[-1].data.cpu(), label.data.cpu(), train_set.accIdxs)
        sumWriter.add_scalar('loss_hg', loss_hg, global_step)
        sumWriter.add_scalar('acc', accs[0], global_step)
        # TODO: learning rate scheduling
        # sumWriter.add_scalar('lr', lr, global_step)
        pbar_info.set_postfix({
            'loss_hg': getValue(loss_hg),
            'acc': accs[0]
        })
        pbar_info.update()
        avg_acc += accs[0] / len(train_loader)
    pbar_info.set_postfix_str('avg_acc: {}'.format(avg_acc))
    pbar.close()
    pbar_info.close()
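The accuracy function called in the loop is not shown in this post (it comes from @bearpaw's code). The idea is PCK computed directly on heatmaps: decode each joint's argmax location and count predictions within a normalized distance of the ground-truth peak. A simplified sketch of that idea, not the exact implementation:

import torch

def heatmap_accuracy(output, target, idxs, thr=0.5, norm=6.4):
    """Simplified PCK on heatmaps (sketch of a bearpaw-style `accuracy`).
    output, target: (N, nJoints, H, W) heatmap tensors.
    idxs: joint indices to evaluate (e.g. train_set.accIdxs).
    norm: distance normalizer, roughly out_res / 10 pixels.
    """
    n, c, h, w = output.shape

    def decode(hm):
        idx = hm.reshape(n, c, -1).argmax(dim=2)                # flat index of each peak
        return torch.stack([idx % w, idx // w], dim=2).float()  # (N, nJoints, 2) as (x, y)

    pred, gt = decode(output), decode(target)
    correct, total = 0, 0
    for j in idxs:
        vis = target[:, j].reshape(n, -1).max(dim=1)[0] > 0     # joints with a label
        d = (pred[:, j] - gt[:, j]).norm(dim=1)                 # pixel distance per sample
        correct += ((d <= thr * norm) & vis).sum().item()
        total += vis.sum().item()
    return correct / max(total, 1)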