一種人臉68特徵點檢測的深度學習方法
本方法以VGG為原型進行改造(以下簡稱mini VGG)來實現人臉68特徵點檢測,下文將從資料集的準備、網路模型的構造以及最終的訓練過程三個方面進行介紹,工程原始碼詳見:Github連結
一、資料集的準備
1、資料集的採集
第二類是自己標註的資料集:
這部分主要是用標註工具對自己收集到的圖片進行標註。我用自己的標註工具標註後,生成的是一個包含68個特徵點座標的txt檔案,之後還需要通過以下腳本將其轉換成與公共資料集類似的pts檔案格式:
from __future__ import division
import os
import cv2
from compiler.ast import flatten
# Folder holding the .txt files exported by the annotation tool
# (trans_label reads up to index 135 of a line, i.e. 68 x/y pairs).
txt_dir = "/Users/camlin_z/Data/68landmark/txt/"
# Folder that receives the generated iBUG-style .pts files.
txt_new_dir = "/Users/camlin_z/Data/68landmark/landmark/"
def trans_label(src_dir=None, dst_dir=None):
    """Convert annotation-tool .txt files into iBUG-style .pts files.

    Every ``<name>.<ext>`` file in ``src_dir`` is read line by line. Each
    non-blank line must hold at least 136 whitespace-separated numbers
    (x0 y0 x1 y1 ... x67 y67); it is written to ``<dst_dir>/<name>.pts``.
    If a file has several non-blank lines, the last one wins.

    Args:
        src_dir: folder with the .txt files; defaults to module-level
            ``txt_dir``.
        dst_dir: existing output folder for the .pts files; defaults to
            module-level ``txt_new_dir``.

    Raises:
        ValueError: if a non-blank line holds fewer than 136 numbers.
    """
    if src_dir is None:
        src_dir = txt_dir
    if dst_dir is None:
        dst_dir = txt_new_dir
    for name in os.listdir(src_dir):
        dot = name.find(".")
        if dot <= 0:
            # Hidden files and names without an extension are skipped.
            print(name + "  not exist!")
            continue
        pts_name = name[:dot] + ".pts"
        print(pts_name)
        with open(os.path.join(src_dir, name), "r") as src:
            for raw in src:
                # A list (not map()) so it can be indexed on Python 3 too.
                values = [float(v) for v in raw.split()]
                if not values:
                    # Blank/trailing lines used to truncate the .pts file the
                    # previous line had just written, then crash.
                    continue
                if len(values) < 136:
                    raise ValueError("%s: expected at least 136 values, got %d"
                                     % (name, len(values)))
                # 'with' guarantees the file is flushed and closed.
                with open(os.path.join(dst_dir, pts_name), "w") as dst:
                    dst.write("version: 1" + "\n")
                    dst.write("n_points: 68" + "\n")
                    dst.write("{" + "\n")
                    for i in range(0, 136, 2):
                        dst.write(str(values[i]) + " " + str(values[i + 1]) + "\n")
                    dst.write("}")
if __name__ == '__main__':
    # Convert every annotation file in txt_dir into a .pts file in txt_new_dir.
    trans_label()
通過以上的整理過程,就可以將資料集整理成以下形式:
同時需要將以上資料集分成訓練集和測試集兩個部分。
2、資料集的預處理
準備好上面的五個資料集後,接下來就是對資料集進行一系列處理。由於特徵點檢測是在檢測框檢測出來之後,先將影像crop出只有人臉的部分,再進行特徵點檢測(這樣可以大量減少影像中其他因素的干擾,讓神經網路聚焦在特徵點檢測的任務上),所以需要根據以上資料集中標註的特徵點位置,裁剪出只包含人臉的區域,用於神經網路的訓練。
處理過程主要參考:
https://yinguobing.com/facial-landmark-localization-by-deep-learning-data-collate/
但是影像經過預處理之後,特徵點的位置也會隨之改變,而上面作者分享的程式碼只處理了影像,沒有同步轉換對應的特徵點座標。所以我對原始程式碼進行了改進,同時處理特徵點座標和影像,並生成網路訓練最終需要的label格式,程式碼如下:
# -*- coding: utf-8 -*-
"""
This script shows how to read iBUG pts file and draw all the landmark points on image.
"""
from __future__ import division
import os
import cv2
from compiler.ast import flatten
import face_detector_image as fd
from lxml import etree, objectify
from compiler.ast import flatten
import shutil
# Selects what preview() saves for each crop:
#   0: draw the transformed landmarks on the crop and save it to the "_pts"
#      folder, to verify that the coordinate transform is correct
#   1 (any non-zero value): save the plain crop to the "_crop" folder
test_flag = 0
# Dataset sub-folders, relative to the data root used in main().
filelist_train = ["300W/trainset", "afw", "data2", "data3", "data4/trainset",
                  "helen/trainset", "landmark/trainset", "lfpw/trainset"]
filelist_test = ["300W/testset", "data4/testset", "helen/testset",
                 "landmark/testset", "lfpw/testset"]
# The split main() processes; switch to filelist_test for the test sets.
filelist = filelist_train
def mkr(dr):
    """Create the single directory ``dr`` unless the path already exists.

    Only the last path component is created (``os.mkdir``), so the parent
    must already exist.
    """
    if os.path.exists(dr):
        return
    os.mkdir(dr)
def read_points(file_name=None):
    """
    Read landmark coordinates from an iBUG ``.pts`` file.

    Lines containing "version", "points", "{" or "}" (the header and the
    braces) are skipped, and so are blank lines. Every other line must hold
    exactly two numbers "x y".

    Args:
        file_name: path of the .pts file.

    Returns:
        A list of ``[x, y]`` float pairs in file order.
    """
    markers = ("version", "points", "{", "}")
    points = []
    with open(file_name) as pts_file:
        for line in pts_file:
            if not line.strip():
                # A blank line (e.g. after "}") used to crash the unpack below.
                continue
            if any(marker in line for marker in markers):
                continue
            loc_x, loc_y = line.split()
            points.append([float(loc_x), float(loc_y)])
    return points
def draw_landmark_point(image, points):
    """
    Paint every landmark as a filled, anti-aliased green dot of radius 2.

    ``image`` is modified in place and also returned for convenience.
    """
    green = (0, 255, 0)
    for px, py in ((int(p[0]), int(p[1])) for p in points):
        cv2.circle(image, (px, py), 2, green, -1, cv2.LINE_AA)
    return image
def points_are_valid(points, image):
    """Return True when the minimal bounding box of ``points`` lies inside ``image``."""
    return bool(box_in_image(get_minimal_box(points), image))
def get_square_box(box):
    """
    Pad the shorter side of ``[left, top, right, bottom]`` to make a square.

    The padding is split evenly between the two opposite edges; an odd
    remainder goes to the right (or bottom) edge. An already-square box is
    returned as the very same object.
    """
    left, top, right, bottom = box[0], box[1], box[2], box[3]
    diff = (bottom - top) - (right - left)
    if diff == 0:
        return box
    half = int(abs(diff) / 2)
    extra = 1 if diff % 2 == 1 else 0
    if diff > 0:
        # Taller than wide: widen.
        left, right = left - half, right + half + extra
    else:
        # Wider than tall: heighten.
        top, bottom = top - half, bottom + half + extra
    assert ((right - left) == (bottom - top)), 'Box is not square.'
    return [left, top, right, bottom]
def get_minimal_box(points):
    """
    Return the tight box ``[min_x, min_y, max_x, max_y]`` around ``points``.

    Each extreme is truncated with ``int()`` after the min/max is taken.
    """
    xs = [pt[0] for pt in points]
    ys = [pt[1] for pt in points]
    return [int(min(xs)), int(min(ys)), int(max(xs)), int(max(ys))]
def move_box(box, offset):
    """Translate ``[left, top, right, bottom]`` by ``offset = (dx, dy)``."""
    dx, dy = offset[0], offset[1]
    return [box[0] + dx, box[1] + dy, box[2] + dx, box[3] + dy]
def expand_box(square_box, scale_ratio=1.2):
    """
    Enlarge a square box about its center.

    Every edge moves outward by ``int(width * (scale_ratio - 1) / 2)``,
    where ``width`` is ``square_box[2] - square_box[0]``.
    """
    assert (scale_ratio >= 1), "Scale ratio should be greater than 1."
    pad = int((square_box[2] - square_box[0]) * (scale_ratio - 1) / 2)
    return [square_box[0] - pad, square_box[1] - pad,
            square_box[2] + pad, square_box[3] + pad]
def points_in_box(points, box):
    """Return True when ``box`` encloses every point (edges inclusive)."""
    left, top, right, bottom = get_minimal_box(points)
    return (left >= box[0] and top >= box[1]
            and right <= box[2] and bottom <= box[3])
def box_in_image(box, image):
    """Return True when ``box`` lies within the image's rows x cols extent."""
    n_rows = image.shape[0]
    n_cols = image.shape[1]
    top_left_ok = box[0] >= 0 and box[1] >= 0
    return top_left_ok and box[2] <= n_cols and box[3] <= n_rows
def box_is_valid(image, points, box):
    """Return True when ``box`` holds all ``points``, lies inside ``image`` and is square."""
    contains_points = points_in_box(points, box)
    inside_image = box_in_image(box, image)
    is_square = (box[2] - box[0]) == (box[3] - box[1])
    return inside_image and contains_points and is_square
def fit_by_shifting(box, rows, cols):
    """
    Method 1: slide ``box`` back inside a rows x cols image without resizing it.

    If the box is wider than ``cols`` or taller than ``rows`` it cannot be
    shifted into place, and its coordinates are returned unchanged (as a new
    list).
    """
    left, top, right, bottom = box[0], box[1], box[2], box[3]
    if not (right - left <= cols and bottom - top <= rows):
        return [left, top, right, bottom]
    # Horizontal: undo a left overflow first, then a right overflow.
    if left < 0:
        right, left = right - left, 0
    if right > cols:
        left, right = left - (right - cols), cols
    # Vertical: the same for the top and bottom edges.
    if top < 0:
        bottom, top = bottom - top, 0
    if bottom > rows:
        top, bottom = top - (bottom - rows), rows
    return [left, top, right, bottom]
def fit_by_shrinking(box, rows, cols):
    """
    Method 2: clip ``box`` to a rows x cols image, then trim the longer side.

    After clipping, the longer dimension is shortened by the width/height
    difference so the box becomes square. If neither edge of that dimension
    touches the image border, the trim is split between both edges.
    Otherwise it is taken from the edge that does not touch the border.
    """
    def _trim(lo, hi, limit, delta, odd_to_lo):
        # Shorten [lo, hi] by delta, keeping any edge pinned at 0 or limit.
        if lo != 0 and hi != limit:
            half, odd = int(delta / 2), delta % 2
            if odd_to_lo:
                return lo + (half + odd), hi - half
            return lo + half, hi - (half + odd)
        if lo == 0:
            return lo, hi - delta
        return lo + delta, hi

    # Intersect the box with the image area.
    left = max(box[0], 0)
    top = max(box[1], 0)
    right = min(box[2], cols)
    bottom = min(box[3], rows)

    width = right - left
    height = bottom - top
    delta = abs(width - height)
    if width > height:
        left, right = _trim(left, right, cols, delta, False)
    else:
        top, bottom = _trim(top, bottom, rows, delta, True)
    return [left, top, right, bottom]
def fit_box(box, image, points):
    """
    Adjust ``box`` so it is square, inside ``image`` and contains ``points``.

    Shifting (fit_by_shifting) is tried first and shrinking
    (fit_by_shrinking) second. The first result that passes box_is_valid is
    returned. If neither does, "Fitting failed!" is printed and None is
    returned.
    """
    rows = image.shape[0]
    cols = image.shape[1]
    for strategy in (fit_by_shifting, fit_by_shrinking):
        candidate = strategy(box, rows, cols)
        if box_is_valid(image, points, candidate):
            return candidate
    print("Fitting failed!")
    return None
def get_valid_box(image, points):
    """
    Try to get a square face box inside ``image`` that contains ``points``.

    The function follows these steps:
    1. Method 1: ask the face detector (fd.get_facebox, threshold 0.5) for
       boxes. Take the first one that, after being shifted down and made
       square, contains all points. Use it directly if it lies inside the
       image; otherwise try to fit it with fit_box.
    2. Method 0 (if method 1 yields no usable box): square the minimal box
       of the points, enlarge it with expand_box, and fit it into the image
       if needed.
    3. Return None if both methods fail.
    """
    def _get_positive_box(raw_boxes, points):
        """Return the first detector box covering all points once shifted down and squared, or None."""
        for box in raw_boxes:
            # Move the box down by half of (height - width) before squaring.
            # NOTE(review): presumably compensates for detector boxes that sit
            # higher than the landmark region — confirm.
            diff_height_width = (box[3] - box[1]) - (box[2] - box[0])
            offset_y = int(abs(diff_height_width / 2))
            square_box = get_square_box(move_box(box, [0, offset_y]))
            # Reject false-positive detections that miss some landmarks.
            if points_in_box(points, square_box):
                return square_box
        return None

    # Method 1: detector-based box.
    _, raw_boxes = fd.get_facebox(image, threshold=0.5)
    positive_box = _get_positive_box(raw_boxes, points)
    if positive_box is not None:
        # Truthiness instead of "is True" so numpy bools are handled too.
        if box_in_image(positive_box, image):
            return positive_box
        fitted = fit_box(positive_box, image, points)
        if fitted is not None:
            return fitted
        # Fitting the detector box failed: fall back to method 0 as the
        # docstring promises instead of returning None straight away.

    # Method 0: landmark-based box.
    min_box = get_minimal_box(points)
    sqr_box = get_square_box(min_box)
    epd_box = expand_box(sqr_box)
    if box_in_image(epd_box, image):
        return epd_box
    return fit_box(epd_box, image, points)
def get_new_pts(facebox, raw_points, label_txt, image_file, flag, ratio_w, ratio_h):
    """
    Map landmarks into the resized crop and append them as one label line.

    Each point is shifted so that the facebox's top-left corner becomes the
    origin, scaled by (ratio_w, ratio_h) and truncated with int(). The line
    ``<flag><image_file>.jpg x0 y0 x1 y1 ...`` (single spaces, no trailing
    space) plus a newline is written to ``label_txt``.

    Args:
        facebox: ``[left, top, right, bottom]`` crop box in the source image.
        raw_points: sequence of ``[x, y]`` landmark pairs (any count; the
            caller currently passes 68).
        label_txt: writable text file object.
        image_file: image base name (without extension).
        flag: prefix written before the image name (e.g. "300W/trainset/").
        ratio_w, ratio_h: horizontal/vertical scale from crop to output size.

    Returns:
        The transformed points as a list of ``[x, y]`` ints.
    """
    x0 = facebox[0]
    y0 = facebox[1]
    # Iterating the pairs directly replaces compiler.ast.flatten, which only
    # exists on Python 2, and drops the hard-coded 68-point loop.
    new_point = [[int((px - x0) * ratio_w), int((py - y0) * ratio_h)]
                 for px, py in raw_points]
    coords = " ".join(str(x) + " " + str(y) for x, y in new_point)
    label_txt.write(flag + image_file + ".jpg " + coords + "\n")
    return new_point
def preview(point_file, test_flag, bbox_new_file=None, label_txt=None,
            flag="", data_dst="", data_test_dst=""):
    """
    Crop the face described by one .pts file and write the results.

    The 68 points are read and the image with the same base name is loaded
    (.jpg, then .jpeg, then .png). A face box is chosen with get_valid_box,
    falling back to the minimal point box. The crop is cut out and resized
    to 224x224, and the transformed points are appended to ``label_txt`` via
    get_new_pts. Then either the annotated crop (test_flag == 0, into
    ``data_test_dst``) or the plain crop (otherwise, into ``data_dst``) is
    saved, and the crop is shown in a window.

    Args:
        point_file: path of the .pts file.
        test_flag: 0 to save the crop with landmarks drawn (to check the
            coordinate transform); any other value saves the plain crop.
        bbox_new_file: unused; kept for backward compatibility.
        label_txt: open, writable text file receiving the label line.
        flag: prefix written before the image name in the label line.
        data_dst: output path prefix for plain crops.
        data_test_dst: output path prefix for annotated crops.

    Returns:
        None. Returns early without writing anything if no image is found
        or a landmark lies outside the image.

    Raises:
        ValueError: if ``label_txt`` is not given.
    """
    # These values used to be read as globals, but they only existed as
    # locals of main(), so every call raised NameError.
    if label_txt is None:
        raise ValueError("preview() needs an open label_txt file")
    raw_points = read_points(point_file)
    # Safe guard, make sure point importing goes well.
    assert len(raw_points) == 68, "The landmarks should contain 68 points."

    head, tail = os.path.split(point_file)
    # splitext keeps names such as "a.b.pts" intact (split('.')[-2] gave "b").
    image_file = os.path.splitext(tail)[0]
    img = None
    for ext in (".jpg", ".jpeg", ".png"):
        candidate = os.path.join(head, image_file + ext)
        if os.path.exists(candidate):
            img = cv2.imread(candidate)
            break
    print(image_file)
    if img is None:
        # Previously a missing/unreadable image crashed later on img.shape.
        print("No readable image for " + point_file)
        return None

    # Fast check: all points are in image.
    if not points_are_valid(raw_points, img):
        return None

    facebox = get_valid_box(img, raw_points)
    if facebox is None:
        print("Using minimal box.")
        facebox = get_minimal_box(raw_points)
    face_area = img[facebox[1]:facebox[3], facebox[0]:facebox[2]]

    rw = 1
    rh = 1
    width = facebox[2] - facebox[0]
    height = facebox[3] - facebox[1]
    print("{} {}".format(width, height))
    if width != height:
        print("opps! {} {}".format(width, height))
    if (width != 224) or (height != 224):
        face_area = cv2.resize(face_area, (224, 224))
        rw = 224 / width
        rh = 224 / height

    # Generate the label line for this crop.
    new_point = get_new_pts(facebox, raw_points, label_txt,
                            image_file, flag, rw, rh)
    if test_flag == 0:
        # Draw the transformed points to verify they match the crop.
        face_area = draw_landmark_point(face_area, new_point)
        cv2.imwrite(data_test_dst + image_file + ".jpg", face_area)
    else:
        cv2.imwrite(data_dst + image_file + ".jpg", face_area)

    cv2.imshow("Crop face", face_area)
    if cv2.waitKey(10) == 27:
        cv2.waitKey()


def main():
    """
    Crop every .pts-annotated image in each folder of ``filelist``.

    For each ``<root><folder>/`` this writes:
      * ``<root><folder>_crop/`` -- plain 224x224 crops (test_flag != 0),
      * ``<root><folder>_pts/``  -- crops with landmarks drawn (test_flag == 0),
      * ``<root><folder>.txt``   -- one label line per processed image.
    """
    root = "/Users/camlin_z/Data/data/"
    for file_string in filelist:
        # Folder holding the source images and .pts files.
        data_dir = root + file_string + "/"
        # Folder for the cropped images.
        data_dst = root + file_string + "_crop/"
        # Folder for crops with the transformed points drawn (for checking).
        data_test_dst = root + file_string + "_pts/"
        # Label file consumed by network training.
        point_new_file = root + file_string + ".txt"
        flag = file_string + "/"

        pts_file_list = []
        for file_path, _, file_names in os.walk(data_dir):
            for file_name in file_names:
                if file_name.split(".")[-1] in ["pts"]:
                    pts_file_list.append(os.path.join(file_path, file_name))

        mkr(data_dst)
        mkr(data_test_dst)
        # 'with' closes (and flushes) the label file after each folder.
        with open(point_new_file, 'w') as label_txt:
            for file_name in pts_file_list:
                preview(file_name, test_flag, label_txt=label_txt, flag=flag,
                        data_dst=data_dst, data_test_dst=data_test_dst)
if __name__ == "__main__":
    # Run the cropping / label-generation pass over every folder in filelist.
    main()
3、資料增強
由於以上資料集加起來總共只有五千張左右,對於需要大量資料訓練的神經網路顯然不夠,所以這裡對上面的資料集進行資料增強;因應專案需要,主要採用旋轉的方式進行資料增強。
以上的旋轉主要可以分為兩種策略:
1、保持原始影像大小,直接進行旋轉
2、旋轉原始影像時將畫布的四個邊向外擴充,使旋轉後的影像不會切掉原始影像的四個角。
旋轉角度主要分為±15°、±30°、±45°、±60°四種類型。在進行資料增強的過程中,主要有三個問題需要解決:
(1)旋轉後產生的黑色區域可能影響卷積學習特徵
(2)對以上產生的只有人臉的影像進行旋轉,可能會把之前標註的特徵點旋轉到影像外面,導致部分特徵點遺失
(3)旋轉後特徵點座標的生成
針對第一個問題:
可以參考:https://blog.csdn.net/guyuealian/article/details/77993410中的方法,對影像的黑色區域利用其邊緣值的二次插值來進行填充,但是上面的處理過程可能會產生一些奇怪的邊緣效果,如下圖所示:
我擔心上面這些奇怪的特徵會不會影響卷積網路最終的學習結果,但暫時還沒有找到合適的解決方法,如果有大牛知道,歡迎留言指教。
針對第二個問題:
可以參考:https://www.oschina.net/translate/opencv-rotation中的程式碼,將影像旋轉後依旋轉後產生的新寬高來儲存圖片,保證旋轉後的圖片不會切掉原始圖片的四個角,上面展示的圖片就是利用這種方法旋轉-60°之後的結果。
將以上問題一一解決之後,由於採用策略二進行旋轉時會產生上面圖片所示的大塊奇怪特徵,而策略一則不會產生那麼大塊的奇怪特徵,所以我的旋轉資料增強整體邏輯是:先用策略一(保持原始大小)旋轉;只有當旋轉後有特徵點超出影像範圍時,才改用策略二(擴充畫布後再縮放回224×224)。整體邏輯如下:
按照上面的處理過程,可以寫出下面的程式碼:
#-*- coding: UTF-8 -*-
from __future__ import division
import cv2
import os
import numpy as np
import math
# Dataset sub-folders (relative to img_dir) whose images are augmented.
filelist = ["300W/trainset", "afw/trainset", "data2/trainset", "data3/trainset",
            "data4/trainset", "helen/trainset", "landmark/trainset", "lfpw/trainset",
            "300W/testset", "afw/testset", "data2/testset", "data3/testset",
            "data4/testset", "helen/testset", "landmark/testset", "lfpw/testset"]
# Root folder holding the images and .pts files to rotate.
img_dir = "/Users/camlin_z/Data/data_output/"
# Rotation magnitude buckets; main() draws a random angle from each bucket.
angles = [15, 30, 45, 60]
def mkr(dr):
    """Create the single directory ``dr`` unless the path already exists.

    Only the last path component is created (``os.mkdir``), so the parent
    must already exist.
    """
    if os.path.exists(dr):
        return
    os.mkdir(dr)
def read_points(file_name=None):
    """
    Read landmark coordinates from an iBUG ``.pts`` file.

    Lines containing "version", "points", "{" or "}" (the header and the
    braces) are skipped, and so are blank lines. Every other line must hold
    exactly two numbers "x y".

    Args:
        file_name: path of the .pts file.

    Returns:
        A list of ``[x, y]`` float pairs in file order.
    """
    markers = ("version", "points", "{", "}")
    points = []
    with open(file_name) as pts_file:
        for line in pts_file:
            if not line.strip():
                # A blank line (e.g. after "}") used to crash the unpack below.
                continue
            if any(marker in line for marker in markers):
                continue
            loc_x, loc_y = line.split()
            points.append([float(loc_x), float(loc_y)])
    return points
def draw_save_landmark(image, points, dst):
    """
    Draw each landmark as a filled green dot (radius 2) on ``image``, then
    write the annotated image to the path ``dst``.

    ``image`` is modified in place.
    """
    for cx, cy in ((int(p[0]), int(p[1])) for p in points):
        cv2.circle(image, (cx, cy), 2, (0, 255, 0), -1, cv2.LINE_AA)
    cv2.imwrite(dst, image)
def trans_label(txt, label):
    """
    Write ``label`` to the path ``txt`` as an iBUG-style ``.pts`` file.

    Args:
        txt: output file path (overwritten).
        label: sequence of ``(x, y)`` pairs.

    The ``n_points`` header now reflects ``len(label)``; it is 68 for the
    68-point data this script handles, so the output is unchanged there. The
    file is closed, and therefore flushed, before returning.
    """
    with open(txt, 'w') as pts_file:
        pts_file.write("version: 1" + "\n")
        pts_file.write("n_points: %d" % len(label) + "\n")
        pts_file.write("{" + "\n")
        for point in label:
            pts_file.write(str(point[0]) + " " + str(point[1]) + "\n")
        pts_file.write("}")
def rotate_with_adjust_size(img, theta):
    """
    Rotate an image file on an enlarged canvas, then resize it to 224x224.

    The canvas is grown to the bounding size of the rotated image, so no
    corner of the original is cut off. Uncovered areas are filled by
    reflecting the image border.

    Args:
        img: path of the image file to read.
        theta: rotation angle in degrees (cv2.getRotationMatrix2D convention).

    Returns:
        (img_rotate, center, offset_w, offset_h, rw, rh):
        - img_rotate: the 224x224 rotated image;
        - center: the original image center (x, y);
        - offset_w, offset_h: the original image's offset inside the
          enlarged canvas;
        - rw, rh: the x/y scale factors from the canvas to 224x224.
    """
    img_raw = cv2.imread(img)
    height, width = img_raw.shape[:2]
    center = (width / 2, height / 2)
    scale = 1
    rangle = np.deg2rad(theta)  # angle in radians
    # Bounding size of the rotated image.
    nw = (abs(np.sin(rangle) * height) + abs(np.cos(rangle) * width)) * scale
    nh = (abs(np.cos(rangle) * height) + abs(np.sin(rangle) * width)) * scale
    rot_mat = cv2.getRotationMatrix2D((nw * 0.5, nh * 0.5), theta, scale)
    # Add R * offset to the translation. The overall mapping becomes
    # p -> R (p - center) + center + offset, i.e. rotate about the old center
    # and then move into the enlarged canvas.
    rot_move = np.dot(rot_mat, np.array([(nw - width) * 0.5, (nh - height) * 0.5, 0]))
    rot_mat[0, 2] += rot_move[0]
    rot_mat[1, 2] += rot_move[1]
    # math.ceil instead of np.math.ceil: np.math was removed in NumPy 2.0.
    out_w = int(math.ceil(nw))
    out_h = int(math.ceil(nh))
    # flags/borderMode must be keywords. The 4th positional parameter of
    # warpAffine is ``dst``, so the old positional call silently ran with
    # INTER_CUBIC and BORDER_REPLICATE instead of LANCZOS4/REFLECT.
    img_rotate = cv2.warpAffine(img_raw, rot_mat, (out_w, out_h),
                                flags=cv2.INTER_LANCZOS4,
                                borderMode=cv2.BORDER_REFLECT)
    offset_w = (nw - width) / 2
    offset_h = (nh - height) / 2
    img_rotate = cv2.resize(img_rotate, (224, 224))
    # Scale relative to the canvas that was actually produced (ceil'ed size).
    rw = 224 / out_w
    rh = 224 / out_h
    return img_rotate, center, offset_w, offset_h, rw, rh
def rotate_with_original_size(img, theta):
    """
    Rotate an image file about its center, keeping the original canvas size.

    Corners rotated outside the canvas are cut off. Uncovered areas are
    filled by reflecting the image border.

    Args:
        img: path of the image file to read.
        theta: rotation angle in degrees (cv2.getRotationMatrix2D convention).

    Returns:
        (img_rotate, center): the rotated image and the rotation center (x, y).
    """
    img_raw = cv2.imread(img)
    height, width = img_raw.shape[:2]
    center = (width / 2, height / 2)
    rot_mat = cv2.getRotationMatrix2D(center, theta, 1)
    # Keyword arguments: positionally the 4th parameter is ``dst``, so the old
    # call used INTER_CUBIC and BORDER_REPLICATE rather than the intended
    # LANCZOS4 interpolation and reflected border.
    img_rotate = cv2.warpAffine(img_raw, rot_mat, (width, height),
                                flags=cv2.INTER_LANCZOS4,
                                borderMode=cv2.BORDER_REFLECT)
    return img_rotate, center
def rotate_pts_original_size(img, points, center, angle):
    """
    Rotate landmark coordinates to match rotate_with_original_size.

    Args:
        img: the rotated image; only its shape is used (for the bounds check).
        points: sequence of ``(x, y)`` pixel coordinates.
        center: ``(x, y)`` rotation center in pixel coordinates.
        angle: rotation angle in degrees (cv2.getRotationMatrix2D convention).

    Returns:
        (new_points, flag). new_points holds the rotated points as rounded
        ints. flag is 1 when a rotated point leaves the image (x or y <= 0,
        x >= width or y >= height); new_points then stops before that point.
        Otherwise flag is 0.
    """
    flag = 0
    new_points = []
    height, width = img.shape[:2]
    theta = np.deg2rad(angle)
    cos_t = math.cos(theta)
    sin_t = math.sin(theta)
    center_x, center_y = center
    # Work in a y-up frame (y -> height - y) so the standard counter-clockwise
    # rotation formula agrees with OpenCV's image-space rotation.
    center_y = height - center_y
    for x_raw, y_raw in points:
        y_raw = height - y_raw
        x = round((x_raw - center_x) * cos_t - (y_raw - center_y) * sin_t + center_x)
        y = round((x_raw - center_x) * sin_t + (y_raw - center_y) * cos_t + center_y)
        x = int(x)
        y = int(height - y)
        # The right/bottom edges were previously not checked at all, so
        # points rotated past them were kept.
        if x <= 0 or y <= 0 or x >= width or y >= height:
            flag = 1
            break
        new_points.append([x, y])
    return new_points, flag
def rotate_pts_adjust_size(img, points, center, angle, offset_w, offset_h, rate_w, rate_h):
    """
    Rotate landmarks to match rotate_with_adjust_size.

    Each point is rotated about ``center`` by ``angle`` degrees and rounded.
    It is then moved by (offset_w, offset_h) into the enlarged canvas,
    scaled by (rate_w, rate_h) and truncated to int. Only ``img.shape`` is
    read from ``img``.
    """
    frame_h, _ = img.shape[:2]
    rad = np.deg2rad(angle)
    cx, cy = center
    # Center expressed in the y-up frame used by the rotation formula.
    cy_up = frame_h - cy
    rotated = []
    for px, py in points:
        dx = px - cx
        dy = (frame_h - py) - cy_up
        rx = round(dx * math.cos(rad) - dy * math.sin(rad) + cx)
        ry = round(dx * math.sin(rad) + dy * math.cos(rad) + cy_up)
        rotated.append([int((rx + offset_w) * rate_w),
                        int((frame_h - ry + offset_h) * rate_h)])
    return rotated
def main():
    """
    Write rotated copies of every image (and its .pts landmarks) in filelist.

    For each magnitude in ``angles`` and each folder in ``filelist`` the
    output folder is ``<img_dir>/<folder>_<angle>/``. It receives:
      * the rotated image, under the same file name,
      * the rotated landmarks as ``<name>.pts``,
      * in ``out/``, the rotated image with the landmarks drawn (for checking).
    The angle is drawn at random from the bucket's positive or negative range
    with equal probability. The image is first rotated on its original
    canvas. If a landmark leaves the image, the enlarged-canvas rotation
    (resized to 224x224) is used instead.
    """
    # Lower bound of the positive range for each bucket; anything else uses 45.
    bucket_low = {15: 0, 30: 15, 45: 30}
    for angle in angles:
        for file_string in filelist:
            src_dir = os.path.join(img_dir, file_string)
            out_dir = os.path.join(img_dir, file_string + "_" + str(abs(angle)))
            out_verify_dir = out_dir + "/out/"
            mkr(out_dir)
            mkr(out_verify_dir)
            for file_path, _, file_names in os.walk(src_dir):
                for file_name in file_names:
                    if file_name.split(".")[-1] not in ["jpg", "png", "jpeg"]:
                        continue
                    print(file_name)
                    # Use the folder os.walk is currently in; rebuilding the
                    # path from src_dir broke for files in sub-folders.
                    img_file_path = os.path.join(file_path, file_name)
                    # splitext keeps multi-dot names ("a.b.jpg" -> "a.b.pts").
                    pts_file_name = os.path.splitext(file_name)[0] + ".pts"
                    pts_file_path = os.path.join(file_path, pts_file_name)
                    pts_new_dir = os.path.join(out_dir, pts_file_name)

                    # Random angle within the bucket (same draws as before).
                    low = bucket_low.get(angle, 45)
                    high = low + 15
                    theta_pos = np.random.randint(low, high)
                    theta_neg = np.random.randint(-high, -low)
                    theta = theta_pos if np.random.randint(0, 2) == 0 else theta_neg
                    print(theta)

                    # Rotate on the original canvas first.
                    img, center = rotate_with_original_size(img_file_path, theta)
                    points = read_points(pts_file_path)
                    new_points, flag = rotate_pts_original_size(img, points, center, theta)
                    # If a point left the image, rotate on an enlarged canvas.
                    if flag == 1:
                        print(img_file_path + " warning!!!")
                        img, center, offset_w, offset_h, rw, rh = rotate_with_adjust_size(img_file_path, theta)
                        new_points = rotate_pts_adjust_size(img, points, center, theta,
                                                            offset_w, offset_h, rw, rh)
                    cv2.imwrite(os.path.join(out_dir, file_name), img)
                    trans_label(pts_new_dir, new_points)
                    # Draw the points on the image to verify their positions.
                    draw_save_landmark(img, new_points,
                                       os.path.join(out_verify_dir, file_name))
if __name__ == '__main__':
    # Run the rotation augmentation over every folder in filelist.
    main()
以上過程處理完成後,就完成了所有的資料預處理過程了。
二、網路模型的構造
由於Caffe的圖片輸入層只支援單一標籤的輸入,所以本文對Caffe的image data layer做了一定程度的修改,使其可以接受136個label值(68個點的x、y座標)的輸入:
#ifdef USE_OPENCV
#include <opencv2/core/core.hpp>
#include <fstream> // NOLINT(readability/streams)
#include <iostream> // NOLINT(readability/streams)
#include <string>
#include <utility>
#include <vector>
#include "caffe/data_transformer.hpp"
#include "caffe/layers/base_data_layer.hpp"
#include "caffe/layers/image_data_layer.hpp"
#include "caffe/util/benchmark.hpp"
#include "caffe/util/io.hpp"
#include "caffe/util/math_functions.hpp"
#include "caffe/util/rng.hpp"
namespace caffe {
// Destructor: stops the layer's internal worker thread (StopInternalThread)
// before the layer's members are destroyed.
template <typename Dtype>
ImageDataLayer<Dtype>::~ImageDataLayer<Dtype>() {
  this->StopInternalThread();
}
// Returns a pseudo-random integer in [0, n) drawn from prefetch_rng_.
// NOTE(review): for n < 1 this returns 1, which is outside that range;
// callers presumably never pass n < 1 — confirm.
// NOTE(review): prefetch_rng_ is only reset in DataLayerSetUp when shuffle
// is enabled; calling Rand with shuffle off would use an unset RNG —
// verify the call sites.
template <typename Dtype>
int ImageDataLayer<Dtype>::Rand(int n) {
  if (n < 1) return 1;
  caffe::rng_t* rng =
      static_cast<caffe::rng_t*>(prefetch_rng_->generator());
  return ((*rng)() % n);
}
template <typename Dtype>
void ImageDataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
const int new_height = this->layer_param_.image_data_param().new_height();
const int new_width = this->layer_param_.image_data_param().new_width();
const bool is_color = this->layer_param_.image_data_param().is_color();
const bool shuffleflag = this->layer_param_.image_data_param().shuffle();
string root_folder = this->layer_param_.image_data_param().root_folder();
CHECK((new_height == 0 && new_width == 0) ||
(new_height > 0 && new_width > 0)) << "Current implementation requires "
"new_height and new_width to be set at the same time.";
// Read the file with filenames and labels
const string& source = this->layer_param_.image_data_param().source();
LOG(INFO) << "Opening file " << source;
std::ifstream infile(source.c_str());
string line;
int pos;
int label_dim = 0;
bool gfirst = true;
int rd = shuffleflag?4:0;
while (std::getline(infile, line)) {
if(line.find_last_of(' ')==line.size()-2) line.erase(line.find_last_not_of(' ')-1);
pos = line.find_first_of(' ');
string str = line.substr(0, pos);
int p0 = pos + 1;
vector<float> vl;
while (pos != -1){
pos = line.find_first_of(' ', p0);
vl.push_back(atof(line.substr(p0, pos).c_str()));
p0 = pos + 1;
}
if (shuffleflag) {
float minx = vl[0];
float maxx = minx;
float miny = vl[1];
float maxy = miny;
for (int i = 2; i < vl.size(); i += 2){
if (vl[i] < minx) minx = vl[i];
else if (vl[i] > maxx) maxx = vl[i];
if (vl[i + 1] < miny) miny = vl[i + 1];
else if (vl[i + 1] > maxy) maxy = vl[i + 1];
}
vl.push_back(minx);
vl.push_back(maxx + 1);
vl.push_back(miny);
vl.push_back(maxy + 1);
}
if (gfirst){
label_dim = vl.size();
gfirst = false;
LOG(INFO) << "label dim: " << label_dim - rd;
//LOG(INFO) << line;
}
CHECK_EQ(vl.size(), label_dim) << "label dim not match in: " << lines_.size()<<", "<<lines_[lines_.size()-1].first;
lines_.push_back(std::make_pair(str, vl));
}
CHECK(!lines_.empty()) << "File is empty";
if (shuffleflag) {
// randomly shuffle data
LOG(INFO) << "Shuffling data & randomly crop image";
const unsigned int prefetch_rng_seed = caffe_rng_rand();
prefetch_rng_.reset(new Caffe::RNG(prefetch_rng_seed));
ShuffleImages();
} else {
if (this->phase_ == TRAIN && Caffe::solver_rank() > 0 &&
this->layer_param_.image_data_param().rand_skip() == 0) {
LOG(WARNING) << "Shuffling or skipping recommended for multi-GPU";
}
}
LOG(INFO) << "A total of " << lines_.size() << " images.";
lines_id_ = 0;
// Check if we would need to randomly skip a few data points
if (this->layer_param_.image_data_param().rand_skip()) {
unsigned int skip = caffe_rng_rand() %
this->layer_param_.image_data_param().rand_skip();
LOG(INFO) << "Skipping first " << skip << " data points.";
CHECK_GT(lines_.size(), skip) << "Not enough points to skip";
lines_id_ = skip;
}
// Read an image, and use it to initialize the top blob.
cv::Mat cv_img = ReadImageToCVMat(root_folder + lines_[lines_id_].first,
0, 0, is_color);
CHECK(cv_img.data) << "Could not load " << lines_[lines_id_].first;
// Use data_transformer to infer the expected blob shape from a cv_image.
vector<int> top_shape(4);
top_shape[0] = 1;
top_shape[1] = cv_img.channels();
top_shape[2] = shuffleflag ? ne