1. 程式人生 > >frustum pointnets訓練程式碼學習筆記——kitti_object.py

frustum pointnets訓練程式碼學習筆記——kitti_object.py

frustum pointnets訓練程式碼學習筆記——kitti_object.py

本文記錄了博主學習frustum pointnets過程中遇到的2D和3D資料庫顯示程式。為了畫出輸出結果,博主希望在這個程式的基礎上修改一個可以顯示結果的程式。更新於2018.09.22。

本文首先給出程式碼原文的學習筆記,隨後整理出修改後的結果顯示程式,如果公開,會在這裡放上鏈接,如果有幫助,請在程式碼頁面點一下小星星哦。附可能有用的資訊:各個集合所用的檔名在kitti/image_sets資料夾下。

文章目錄

總結

把總結寫在前面,根據需要判斷是否需要詳細看原始碼分析。

  • 這個檔案的主要功能就是將KITTI庫中的2d和3d結果畫出來(至於是畫training還是testing,檔案的kitti_object函式的初始函式和get_label_objects分別有定義和判斷)。
  • 檔案個數是人為在kitti_object函式中設定的,並非自動提取。
  • 畫圖是機械地從第一張圖片一直向後顯示,且如果該圖片中有多個目標,僅顯示txt檔案中排在第一的那個的資料。通過修改objects[0].print_object()中[]裡面的標號可以指定畫第幾個目標,不過要注意的是,每個圖片中含有的目標個數是不同的。

用到的語法規則

這一部分記錄了程式碼原文中出現的語法規則,並不影響程式碼功能的理解,但是可能方便日後的使用,因此在這裡記錄下來。

from __future__ import print_function

加上這句話以後,即使在python2.X也要像python3.X一樣的語法使用print函式(加括號)。類似地,如果有其他新的功能特性且該特性與當前版本中的使用不相容,就可以從future模組匯入。詳細說明參考這裡

from PIL import Image

PIL已經是python平臺事實上的影象處理標準庫了,全稱為Python Imaging Library。具體的使用方法說明可以參考這裡

BASE_DIR = os.path.dirname(os.path.abspath(__file__))

其中,__file__就是當前所執行的檔案,也就是kitti_object.py。os.path.abspath命令獲取的是當前檔案的絕對路徑,比如博主的執行結果:

>>> print os.path.abspath("/home/galaxy/Work/XXX/Pointnets/frustum-pointnets-master/kitti/kitti_object.py")
/home/galaxy/Work/XXX/Pointnets/frustum-pointnets-master/kitti/kitti_object.py

而前面的os.path.dirname獲取的則是當前路徑所存在於的資料夾,因此,BASE_DIR指向的就是kitti_object.py所處的檔案夾了。執行結果為:

>>> print os.path.dirname(os.path.abspath("/home/galaxy/Work/XXX/Pointnets/frustum-pointnets-master/kitti/kitti_object.py"))
/home/galaxy/Work/XXX/Pointnets/frustum-pointnets-master/kitti

sys.path.append(os.path.join(ROOT_DIR, 'mayavi'))

其中,os.path.join用於路徑拼接。
sys.path.append:在匯入一個模組時,預設情況下python會搜尋當前目錄、已安裝的內建模組和第三方模組,搜尋路徑存放在sys模組的path中。如果要用的模組和當前指令碼不在一個目錄下,就需要將其新增到path中。這種修改是臨時的,指令碼執行後失效。

img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

opencv中提供了cvtColor函式用於實現影象格式型別的相互轉換,具體說明可以參照這裡


程式碼原文分析

#程式碼作者資訊
''' Helper class and functions for loading KITTI objects

Author: Charles R. Qi
Date: September 2017
'''

#載入必要的庫
from __future__ import print_function			

import os
import sys
import numpy as np
import cv2
from PIL import Image
#定義基礎路徑
BASE_DIR = os.path.dirname(os.path.abspath(__file__))			#指向當前檔案所在資料夾(kitti)
ROOT_DIR = os.path.dirname(BASE_DIR)			#指向frustum資料夾
sys.path.append(os.path.join(ROOT_DIR, 'mayavi'))
import kitti_util as utils			#載入論文作者寫的庫

# Compatibility shim: make ``raw_input`` usable under both Python 2 and 3.
try:
    raw_input  # exists only on Python 2
except NameError:
    raw_input = input  # Python 3: input() already returns a string


#用於獲取資料庫各項路徑路徑(training或testing)
class kitti_object(object):
    """Load and parse KITTI object-detection data into a usable format.

    ``root_dir`` must contain the ``training`` and ``testing`` folders of
    the KITTI object benchmark; ``split`` selects which one to use.
    """

    def __init__(self, root_dir, split='training'):
        '''root_dir contains training and testing folders'''
        self.root_dir = root_dir
        self.split = split
        self.split_dir = os.path.join(root_dir, split)

        # Official KITTI object benchmark sample counts, fixed per split.
        known_counts = {'training': 7481, 'testing': 7518}
        if split not in known_counts:
            print('Unknown split: %s' % (split))
            exit(-1)
        self.num_samples = known_counts[split]

        # Sub-directories holding images, calibration, lidar and labels.
        self.image_dir = os.path.join(self.split_dir, 'image_2')
        self.calib_dir = os.path.join(self.split_dir, 'calib')
        self.lidar_dir = os.path.join(self.split_dir, 'velodyne')
        self.label_dir = os.path.join(self.split_dir, 'label_2')

    def __len__(self):
        """Total number of samples in the chosen split."""
        return self.num_samples

    def get_image(self, idx):
        """Load and return the image for sample ``idx``."""
        assert(idx < self.num_samples)
        filename = os.path.join(self.image_dir, '%06d.png' % (idx))
        return utils.load_image(filename)

    def get_lidar(self, idx):
        """Load and return the velodyne point cloud for sample ``idx``."""
        assert(idx < self.num_samples)
        filename = os.path.join(self.lidar_dir, '%06d.bin' % (idx))
        return utils.load_velo_scan(filename)

    def get_calibration(self, idx):
        """Return the Calibration object for sample ``idx``."""
        assert(idx < self.num_samples)
        filename = os.path.join(self.calib_dir, '%06d.txt' % (idx))
        return utils.Calibration(filename)

    def get_label_objects(self, idx):
        """Read the label file for sample ``idx`` (training split only)."""
        assert(idx < self.num_samples and self.split == 'training')
        filename = os.path.join(self.label_dir, '%06d.txt' % (idx))
        return utils.read_label(filename)

    def get_depth_map(self, idx):
        # Not implemented in the original code.
        pass

    def get_top_down(self, idx):
        # Not implemented in the original code.
        pass

class kitti_object_video(object):
    """Load synchronized image/lidar frames for a KITTI raw video sequence."""

    def __init__(self, img_dir, lidar_dir, calib_dir):
        # One calibration is shared by every frame of the sequence.
        self.calib = utils.Calibration(calib_dir, from_video=True)
        self.img_dir = img_dir
        self.lidar_dir = lidar_dir
        # Sort so frame order follows filename order.
        self.img_filenames = sorted(
            os.path.join(img_dir, name) for name in os.listdir(img_dir))
        self.lidar_filenames = sorted(
            os.path.join(lidar_dir, name) for name in os.listdir(lidar_dir))
        print(len(self.img_filenames))
        print(len(self.lidar_filenames))
        #assert(len(self.img_filenames) == len(self.lidar_filenames))
        self.num_samples = len(self.img_filenames)

    def __len__(self):
        """Number of frames (taken from the image file count)."""
        return self.num_samples

    def get_image(self, idx):
        """Load and return the image for frame ``idx``."""
        assert(idx < self.num_samples)
        return utils.load_image(self.img_filenames[idx])

    def get_lidar(self, idx):
        """Load and return the velodyne scan for frame ``idx``."""
        assert(idx < self.num_samples)
        return utils.load_velo_scan(self.lidar_filenames[idx])

    def get_calibration(self, unused):
        """Return the shared sequence calibration (arg kept for API parity)."""
        return self.calib

def viz_kitti_video():
    """Step through a KITTI raw video sequence, frame by frame.

    For each frame: show the camera image, draw the raw velodyne point
    cloud, wait for a key press, then draw the same cloud with its XYZ
    converted to the rectified camera frame and wait again.
    Relies on module globals ``draw_lidar`` / ``raw_input`` being in scope.
    """
    video_path = os.path.join(ROOT_DIR, 'dataset/2011_09_26/')
    dataset = kitti_object_video(
        os.path.join(video_path, '2011_09_26_drive_0023_sync/image_02/data'),
        os.path.join(video_path, '2011_09_26_drive_0023_sync/velodyne_points/data'),
        video_path)
    print(len(dataset))
    for i in range(len(dataset)):
        # BUG FIX: the original always fetched frame 0, so the loop never
        # advanced through the sequence; use the loop index instead.
        img = dataset.get_image(i)
        pc = dataset.get_lidar(i)
        Image.fromarray(img).show()
        draw_lidar(pc)
        raw_input()
        # BUG FIX: get_calibration takes a (ignored) frame argument; the
        # original called it with no argument, raising a TypeError.
        pc[:, 0:3] = dataset.get_calibration(i).project_velo_to_rect(pc[:, 0:3])
        draw_lidar(pc)
        raw_input()
    return

def show_image_with_boxes(img, objects, calib, show3d=True):
    """Display an image with 2D (and optionally projected 3D) boxes drawn."""
    img_2d = np.copy(img)  # canvas for axis-aligned 2D boxes
    img_3d = np.copy(img)  # canvas for projected 3D boxes
    for obj in objects:
        if obj.type == 'DontCare':
            continue
        # 2D box straight from the label's pixel coordinates.
        top_left = (int(obj.xmin), int(obj.ymin))
        bottom_right = (int(obj.xmax), int(obj.ymax))
        cv2.rectangle(img_2d, top_left, bottom_right, (0, 255, 0), 2)
        # 3D box: corners computed in 3D, then projected into the image.
        box3d_pts_2d, _ = utils.compute_box_3d(obj, calib.P)
        img_3d = utils.draw_projected_box3d(img_3d, box3d_pts_2d)
    Image.fromarray(img_2d).show()
    if show3d:
        Image.fromarray(img_3d).show()

def get_lidar_in_image_fov(pc_velo, calib, xmin, ymin, xmax, ymax,
                           return_more=False, clip_distance=2.0):
    """Filter lidar points, keeping only those inside the image FOV.

    Points must project into the rectangle [xmin, xmax) x [ymin, ymax)
    and lie further than ``clip_distance`` along the velodyne x axis.
    With ``return_more`` also returns all projections and the boolean mask.
    """
    pts_2d = calib.project_velo_to_image(pc_velo)
    # Inside the image rectangle.
    in_x = (pts_2d[:, 0] >= xmin) & (pts_2d[:, 0] < xmax)
    in_y = (pts_2d[:, 1] >= ymin) & (pts_2d[:, 1] < ymax)
    fov_inds = in_x & in_y
    # Discard points too close to the sensor.
    fov_inds = fov_inds & (pc_velo[:, 0] > clip_distance)
    imgfov_pc_velo = pc_velo[fov_inds, :]
    if return_more:
        return imgfov_pc_velo, pts_2d, fov_inds
    return imgfov_pc_velo

def show_lidar_with_boxes(pc_velo, objects, calib,
                      img_fov=False, img_width=None, img_height=None): 
    ''' Show all LiDAR points.
        Draw 3d box in LiDAR point cloud (in velo coord system) '''
    # NOTE(review): this tests for the key 'mlab' in sys.modules, but
    # `import mayavi.mlab` registers 'mayavi.mlab', not 'mlab' -- so the
    # condition is normally True and mlab is imported as a function-local
    # name here. If 'mlab' ever were present, the later uses of mlab would
    # hit UnboundLocalError (any assignment makes it local). Confirm before
    # restructuring.
    if 'mlab' not in sys.modules: import mayavi.mlab as mlab
    from viz_util import draw_lidar_simple, draw_lidar, draw_gt_boxes3d

    print(('All point num: ', pc_velo.shape[0]))
    fig = mlab.figure(figure=None, bgcolor=(0,0,0),
        fgcolor=None, engine=None, size=(1000, 500))
    if img_fov:
        # Optionally keep only points that project into the image rectangle.
        pc_velo = get_lidar_in_image_fov(pc_velo, calib, 0, 0,
            img_width, img_height)
        print(('FOV point num: ', pc_velo.shape[0]))
    draw_lidar(pc_velo, fig=fig)

    for obj in objects:
        if obj.type=='DontCare':continue
        # Draw 3d bounding box: corners are computed in the rect camera
        # frame, then converted to the velodyne frame for plotting.
        box3d_pts_2d, box3d_pts_3d = utils.compute_box_3d(obj, calib.P) 
        box3d_pts_3d_velo = calib.project_rect_to_velo(box3d_pts_3d)
        # Draw heading arrow (object orientation as a two-point segment).
        ori3d_pts_2d, ori3d_pts_3d = utils.compute_orientation_3d(obj, calib.P)
        ori3d_pts_3d_velo = calib.project_rect_to_velo(ori3d_pts_3d)
        x1,y1,z1 = ori3d_pts_3d_velo[0,:]
        x2,y2,z2 = ori3d_pts_3d_velo[1,:]
        draw_gt_boxes3d([box3d_pts_3d_velo], fig=fig)
        mlab.plot3d([x1, x2], [y1, y2], [z1,z2], color=(0.5,0.5,0.5),
            tube_radius=None, line_width=1, figure=fig)
    mlab.show(1)

def show_lidar_on_image(pc_velo, img, calib, img_width, img_height):
    """Project lidar points onto the image, colored by depth, and show it."""
    imgfov_pc_velo, pts_2d, fov_inds = get_lidar_in_image_fov(
        pc_velo, calib, 0, 0, img_width, img_height, True)
    imgfov_pts_2d = pts_2d[fov_inds, :]
    imgfov_pc_rect = calib.project_velo_to_rect(imgfov_pc_velo)

    import matplotlib.pyplot as plt
    # 256-entry HSV colormap expanded to 0-255 RGB triples.
    cmap = plt.cm.get_cmap('hsv', 256)
    cmap = np.array([cmap(i) for i in range(256)])[:, :3] * 255

    for i in range(imgfov_pts_2d.shape[0]):
        depth = imgfov_pc_rect[i, 2]
        # Map depth to a colormap index: nearer points get higher indices.
        color = cmap[int(640.0 / depth), :]
        center = (int(np.round(imgfov_pts_2d[i, 0])),
                  int(np.round(imgfov_pts_2d[i, 1])))
        cv2.circle(img, center, 2, color=tuple(color), thickness=-1)
    Image.fromarray(img).show()
    return img

def dataset_viz():
    """Iterate over every KITTI object sample, showing labels, image, lidar."""
    # Resolve the dataset root and collect the per-split sub-directories.
    dataset = kitti_object(os.path.join(ROOT_DIR, 'dataset/KITTI/object'))

    for data_idx in range(len(dataset)):
        # Load the labels and print the first object of this frame (change
        # the index to inspect another one -- object counts vary per frame).
        objects = dataset.get_label_objects(data_idx)
        objects[0].print_object()
        img = dataset.get_image(data_idx)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR
        img_height, img_width, img_channel = img.shape
        print(('Image shape: ', img.shape))
        pc_velo = dataset.get_lidar(data_idx)[:, 0:3]  # keep XYZ only
        calib = dataset.get_calibration(data_idx)

        # Draw 2D and 3D boxes on the image.
        show_image_with_boxes(img, objects, calib, False)
        raw_input()
        # Show all lidar points and 3D boxes in the velodyne frame.
        show_lidar_with_boxes(pc_velo, objects, calib, True, img_width, img_height)
        raw_input()

if __name__=='__main__':
    # Script entry point: import mayavi and the project's 3D drawing helpers
    # at module level (show_lidar_with_boxes relies on these names being in
    # scope as globals), then run the dataset visualization loop.
    import mayavi.mlab as mlab
    from viz_util import draw_lidar_simple, draw_lidar, draw_gt_boxes3d
    dataset_viz()  # visualize the dataset