frustum pointnets訓練程式碼學習筆記——kitti_object.py
frustum pointnets訓練程式碼學習筆記——kitti_object.py
本文記錄了博主學習frustum pointnets過程中遇到的2D和3D資料庫顯示程式。為了畫出輸出結果,博主希望在這個程式的基礎上修改一個可以顯示結果的程式。更新於2018.09.22。
本文首先給出程式碼原文的學習筆記,隨後整理出修改後的結果顯示程式,如果公開,會在這裡放上鍊接,如果有幫助,請在程式碼頁面點一下小星星哦。附可能有用的資訊:各個集合所用的檔名在kitti/image_sets
資料夾下。
文章目錄
總結
把總結寫在前面,根據需要判斷是否需要詳細看原始碼分析。
- 這個檔案的主要功能就是將KITTI庫中的2d和3d結果畫出來(至於是畫training還是testing,檔案的kitti_object函式的初始函式和get_label_objects分別有定義和判斷)。
- 檔案個數是人為在kitti_object函式中設定的,並非自動提取。
- 畫圖是機械第從第一個圖片一直向後顯示,且如果該圖片中有多個目標,僅顯示txt檔案中排在第一的那個的資料。通過修改
objects[0].print_object()
中[]裡面的標號可以指定畫第幾個目標,不過要注意的是,每個圖片中含有的目標個數是不同的。
用到的語法規則
這一部分記錄了程式碼原文中出現的語法規則,並不影響程式碼功能的理解,但是可能方便日後的使用,因此在這裡記錄下來。
from __future__ import print_function
加上這句話以後,即使在python2.X也要像python3.X一樣的語法使用print函式(加括號)。類似地,如果有其他新的功能特性且該特性與當前版本中的使用不相容,就可以從future模組匯入。詳細說明參考這裡。
from PIL import Image
PIL已經是python平臺事實上的影象處理標準庫了,全稱為Python Imaging Library。具體的使用方法說明可以參考這裡。
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
其中,__file__
就是當前所執行的檔案,也就是kitti_object.py
。os.path.abspath
命令獲取的是當前檔案的絕對路徑,比如博主的執行結果:
>>> print os.path.abspath("/home/galaxy/Work/XXX/Pointnets/frustum-pointnets-master/kitti/kitti_object.py")
/home/galaxy/Work/XXX/Pointnets/frustum-pointnets-master/kitti/kitti_object.py
而前面的os.path.dirname
獲取的則是當前路徑所存在於的資料夾,因此,BASE_DIR指向的就是kitti_object.py
所處的檔案夾了。執行結果為:
>>> print os.path.dirname(os.path.abspath("/home/galaxy/Work/XXX/Pointnets/frustum-pointnets-master/kitti/kitti_object.py"))
/home/galaxy/Work/XXX/Pointnets/frustum-pointnets-master/kitti
sys.path.append(os.path.join(ROOT_DIR, 'mayavi'))
其中,os.path.join
用於路徑拼接。
sys.path.append
:在匯入一個模組時,預設情況下python會搜尋當前目錄、已安裝的內建模組和第三方模組,搜尋路徑存放在sys模組的path中。如果要用的模組和當前指令碼不在一個目錄下,就需要將其新增到path中。這種修改是臨時的,指令碼執行後失效。
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
opencv中提供了cvtColor函式用於實現影象格式型別的相互轉換,具體說明可以參照這裡。
程式碼原文分析
#程式碼作者資訊
''' Helper class and functions for loading KITTI objects
Author: Charles R. Qi
Date: September 2017
'''
#載入必要的庫
from __future__ import print_function
import os
import sys
import numpy as np
import cv2
from PIL import Image
#定義基礎路徑
BASE_DIR = os.path.dirname(os.path.abspath(__file__)) #指向當前檔案所在資料夾(kitti)
ROOT_DIR = os.path.dirname(BASE_DIR) #指向frustum資料夾
sys.path.append(os.path.join(ROOT_DIR, 'mayavi'))
import kitti_util as utils #載入論文作者寫的庫
try:
raw_input # Python 2
except NameError:
raw_input = input # Python 3
#用於獲取資料庫各項路徑路徑(training或testing)
class kitti_object(object):
'''Load and parse object data into a usable format.'''
def __init__(self, root_dir, split='training'):
'''root_dir contains training and testing folders'''
self.root_dir = root_dir
self.split = split
self.split_dir = os.path.join(root_dir, split)
if split == 'training':
self.num_samples = 7481
elif split == 'testing':
self.num_samples = 7518
else:
print('Unknown split: %s' % (split))
exit(-1)
self.image_dir = os.path.join(self.split_dir, 'image_2')
self.calib_dir = os.path.join(self.split_dir, 'calib')
self.lidar_dir = os.path.join(self.split_dir, 'velodyne')
self.label_dir = os.path.join(self.split_dir, 'label_2')
# 用於後面獲取樣本庫內的樣本總數
def __len__(self):
return self.num_samples
def get_image(self, idx):
assert(idx<self.num_samples)
img_filename = os.path.join(self.image_dir, '%06d.png'%(idx))
return utils.load_image(img_filename)
def get_lidar(self, idx):
assert(idx<self.num_samples)
lidar_filename = os.path.join(self.lidar_dir, '%06d.bin'%(idx))
return utils.load_velo_scan(lidar_filename)
def get_calibration(self, idx):
assert(idx<self.num_samples)
calib_filename = os.path.join(self.calib_dir, '%06d.txt'%(idx))
return utils.Calibration(calib_filename)
# 獲取idx指示的樣本對應的label檔案路徑,並按行讀取檔案,返回檔案內容
def get_label_objects(self, idx):
assert(idx<self.num_samples and self.split=='training')
label_filename = os.path.join(self.label_dir, '%06d.txt'%(idx))
return utils.read_label(label_filename)
def get_depth_map(self, idx):
pass
def get_top_down(self, idx):
pass
class kitti_object_video(object):
''' Load data for KITTI videos '''
def __init__(self, img_dir, lidar_dir, calib_dir):
self.calib = utils.Calibration(calib_dir, from_video=True)
self.img_dir = img_dir
self.lidar_dir = lidar_dir
self.img_filenames = sorted([os.path.join(img_dir, filename) \
for filename in os.listdir(img_dir)])
self.lidar_filenames = sorted([os.path.join(lidar_dir, filename) \
for filename in os.listdir(lidar_dir)])
print(len(self.img_filenames))
print(len(self.lidar_filenames))
#assert(len(self.img_filenames) == len(self.lidar_filenames))
self.num_samples = len(self.img_filenames)
def __len__(self):
return self.num_samples
def get_image(self, idx):
assert(idx<self.num_samples)
img_filename = self.img_filenames[idx]
return utils.load_image(img_filename)
def get_lidar(self, idx):
assert(idx<self.num_samples)
lidar_filename = self.lidar_filenames[idx]
return utils.load_velo_scan(lidar_filename)
def get_calibration(self, unused):
return self.calib
def viz_kitti_video():
video_path = os.path.join(ROOT_DIR, 'dataset/2011_09_26/')
dataset = kitti_object_video(\
os.path.join(video_path, '2011_09_26_drive_0023_sync/image_02/data'),
os.path.join(video_path, '2011_09_26_drive_0023_sync/velodyne_points/data'),
video_path)
print(len(dataset))
for i in range(len(dataset)):
img = dataset.get_image(0)
pc = dataset.get_lidar(0)
Image.fromarray(img).show()
draw_lidar(pc)
raw_input()
pc[:,0:3] = dataset.get_calibration().project_velo_to_rect(pc[:,0:3])
draw_lidar(pc)
raw_input()
return
def show_image_with_boxes(img, objects, calib, show3d=True):
''' Show image with 2D bounding boxes '''
img1 = np.copy(img) # for 2d bbox
img2 = np.copy(img) # for 3d bbox
for obj in objects:
if obj.type=='DontCare':continue
cv2.rectangle(img1, (int(obj.xmin),int(obj.ymin)),
(int(obj.xmax),int(obj.ymax)), (0,255,0), 2)
box3d_pts_2d, box3d_pts_3d = utils.compute_box_3d(obj, calib.P)
img2 = utils.draw_projected_box3d(img2, box3d_pts_2d)
Image.fromarray(img1).show()
if show3d:
Image.fromarray(img2).show()
def get_lidar_in_image_fov(pc_velo, calib, xmin, ymin, xmax, ymax,
return_more=False, clip_distance=2.0):
''' Filter lidar points, keep those in image FOV '''
pts_2d = calib.project_velo_to_image(pc_velo)
fov_inds = (pts_2d[:,0]<xmax) & (pts_2d[:,0]>=xmin) & \
(pts_2d[:,1]<ymax) & (pts_2d[:,1]>=ymin)
fov_inds = fov_inds & (pc_velo[:,0]>clip_distance)
imgfov_pc_velo = pc_velo[fov_inds,:]
if return_more:
return imgfov_pc_velo, pts_2d, fov_inds
else:
return imgfov_pc_velo
def show_lidar_with_boxes(pc_velo, objects, calib,
img_fov=False, img_width=None, img_height=None):
''' Show all LiDAR points.
Draw 3d box in LiDAR point cloud (in velo coord system) '''
if 'mlab' not in sys.modules: import mayavi.mlab as mlab
from viz_util import draw_lidar_simple, draw_lidar, draw_gt_boxes3d
print(('All point num: ', pc_velo.shape[0]))
fig = mlab.figure(figure=None, bgcolor=(0,0,0),
fgcolor=None, engine=None, size=(1000, 500))
if img_fov:
pc_velo = get_lidar_in_image_fov(pc_velo, calib, 0, 0,
img_width, img_height)
print(('FOV point num: ', pc_velo.shape[0]))
draw_lidar(pc_velo, fig=fig)
for obj in objects:
if obj.type=='DontCare':continue
# Draw 3d bounding box
box3d_pts_2d, box3d_pts_3d = utils.compute_box_3d(obj, calib.P)
box3d_pts_3d_velo = calib.project_rect_to_velo(box3d_pts_3d)
# Draw heading arrow
ori3d_pts_2d, ori3d_pts_3d = utils.compute_orientation_3d(obj, calib.P)
ori3d_pts_3d_velo = calib.project_rect_to_velo(ori3d_pts_3d)
x1,y1,z1 = ori3d_pts_3d_velo[0,:]
x2,y2,z2 = ori3d_pts_3d_velo[1,:]
draw_gt_boxes3d([box3d_pts_3d_velo], fig=fig)
mlab.plot3d([x1, x2], [y1, y2], [z1,z2], color=(0.5,0.5,0.5),
tube_radius=None, line_width=1, figure=fig)
mlab.show(1)
def show_lidar_on_image(pc_velo, img, calib, img_width, img_height):
''' Project LiDAR points to image '''
imgfov_pc_velo, pts_2d, fov_inds = get_lidar_in_image_fov(pc_velo,
calib, 0, 0, img_width, img_height, True)
imgfov_pts_2d = pts_2d[fov_inds,:]
imgfov_pc_rect = calib.project_velo_to_rect(imgfov_pc_velo)
import matplotlib.pyplot as plt
cmap = plt.cm.get_cmap('hsv', 256)
cmap = np.array([cmap(i) for i in range(256)])[:,:3]*255
for i in range(imgfov_pts_2d.shape[0]):
depth = imgfov_pc_rect[i,2]
color = cmap[int(640.0/depth),:]
cv2.circle(img, (int(np.round(imgfov_pts_2d[i,0])),
int(np.round(imgfov_pts_2d[i,1]))),
2, color=tuple(color), thickness=-1)
Image.fromarray(img).show()
return img
def dataset_viz():
dataset = kitti_object(os.path.join(ROOT_DIR, 'dataset/KITTI/object')) #獲取資料庫各項路徑
for data_idx in range(len(dataset)): #從0開始到len獲取的資料庫樣本總數
# 從資料庫中載入資料
objects = dataset.get_label_objects(data_idx) #獲取data_idx對應的結果
objects[0].print_object() #在螢幕上輸出data_idx對應的第一個結果(如果有多個,修改[]內的值就可以變成對應的結果
img = dataset.get_image(data_idx) #獲取data_idx對應的圖片
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) #影象格式轉換
img_height, img_width, img_channel = img.shape
print(('Image shape: ', img.shape))
pc_velo = dataset.get_lidar(data_idx)[:,0:3] #獲取data_idx對應的3D點雲
calib = dataset.get_calibration(data_idx)
# 在影象上畫出2d和3dboxes
show_image_with_boxes(img, objects, calib, False)
raw_input()
# Show all LiDAR points. Draw 3d box in LiDAR point cloud
show_lidar_with_boxes(pc_velo, objects, calib, True, img_width, img_height)
raw_input()
if __name__=='__main__':
import mayavi.mlab as mlab
from viz_util import draw_lidar_simple, draw_lidar, draw_gt_boxes3d
dataset_viz() #顯示資料