1. 程式人生 > >《機器學習實戰》中用matplotlib繪製決策樹, python3

《機器學習實戰》中用matplotlib繪製決策樹, python3

  人笨, 繪製樹形圖那裡的程式碼看了幾次也沒看懂(很多莫名其妙的(全域性?)變數), 然後就自己想辦法寫了個

import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties


def getTreeDB(mytree):
	"""
	利用遞迴獲取字典最大深度, 子葉數目
	:param mytree:一個字典樹, 或者樹的子葉節點(字元型)
	:return:返回 樹的深度, 子葉數目
	"""
	if not isinstance(mytree, dict):  # 如果是子葉節點, 返回1
		return 1, 1
	depth = []  # 儲存每條樹枝的深度
	leafs = 0  # 結點當前的子葉數目
	keys = list(mytree.keys())  # 獲取字典的鍵
	if len(keys) == 1:  # 如果鍵只有一個(說明是個結點而不是樹枝)
		mytree = mytree[keys[0]]  # 結點的value一定是樹枝(判斷的是每條支路的深度而不是結點)
	for key in mytree.keys():  # 遍歷每條樹枝
		res = getTreeDB(mytree[key])  # 獲取子樹的深度, 子葉數目
		depth.append(1 + res[0])  # 把每條樹枝的深度(加上自身)放在節點的深度集合中
		leafs += res[1]  # 累積子葉數目
	return max(depth), leafs  # 返回最大的深度值, 子葉數目


def plotArrow(what, xy1, xy2, which):
	"""
	畫一個帶文字描述的箭頭, 文字在箭頭中間
	:param what: 文字內容
	:param xy1: 箭頭起始座標
	:param xy2: 箭頭終點座標
	:param which: 箭頭所在的圖物件
	:return: suprise
	"""

	# 畫箭頭
	which.arrow(
		xy1[0], xy1[1], xy2[0] - xy1[0], xy2[1] - xy1[1],
		length_includes_head = True,  # 增加的長度包含箭頭部分
		head_width = 0.15, head_length = 0.5, fc = 'r', ec = 'brown')

	tx = (xy1[0] + xy2[0]) / 2
	ty = (xy1[1] + xy2[1]) / 2

	zhfont = FontProperties(fname = 'msyh.ttc')  # 顯示中文的方法

	# 畫文字
	which.annotate(
		what,
		size = 10,
		xy = (tx, ty),
		xytext = (-5, 5),  # 偏移量
		textcoords = 'offset points',
		bbox = dict(boxstyle = "square", ec = (1., 0.5, 0.5), fc = (1., 0.8, 0.8)),  # 外框, fc 內部顏色, ec 邊框顏色
		fontproperties = zhfont)  # 字型


def plotNode(what, xy, which, mod = 'any'):
	"""
	畫樹的節點
	:param what: 節點的內容
	:param xy: 節點的座標
	:param which: 節點所在的圖物件
	:param mod: 判斷節點是子葉還是非子葉(顏色不同)
	:return: suprise
	"""
	zhfont = FontProperties(fname = 'msyh.ttc')  # 顯示中文的方法, msyh.ttc是微軟雅黑的字型檔案
	if mod == 'leaf':
		color = 'yellow'
	else:
		color = 'greenyellow'
	which.text(
		xy[0], xy[1],
		what, size = 18,
		ha = "center", va = "center",
		bbox = dict(boxstyle = "round", ec = (1., 0.5, 0.5), fc = color),
		fontproperties = zhfont)


def plotInfo(what, which):
	"""
	提示圖中內容
	:param what: 子葉標籤
	:param which: 所在的圖物件
	:return: suprise
	"""
	what = '綠色: 特徵,  粉紅: 特徵值,  黃色: ' + what
	zhfont = FontProperties(fname = 'msyh.ttc')  # 顯示中文的方法
	which.text(
		2, 2,
		what, size = 18,
		ha = "center", va = "center",
		bbox = dict(boxstyle = "round", ec = (1., 0.5, 0.5), fc = '#BB91A6'),
		fontproperties = zhfont)


def plotTree(mytree, figxsize, figysize, what):
	"""
	利用遞迴畫決策樹
	所有子葉節點兩兩之間的間距都是xsize
	每一層節點之間的間距都是ysize
	子葉節點的數目都是確定的, 所以橫座標也是確定的, 從左往右第leafnum個子葉節點的橫座標x = leafs * xsize
	非子葉節點的橫座標由該節點孩子的橫座標確定, x = 孩子橫座標平均值
	每一層節點的縱座標由層數deep確定, y = ylen - deep * ysize, 其中ylen為畫板高度
	:param mytree: 要畫的字典樹
	:param figxsize: 畫布的x長度    (兩者會影響顯示效果)
	:param figysize: 畫布的y長度    (這兩個值很影響樹的分佈,(不宜過大)(?) ))
	:param what: 子葉的標籤(用於提示圖的結果是什麼)
	:return: suprise
	"""

	def plotAll(subtree, deep, leafnum):
		"""
		內部函式, 遞迴畫圖, 會使用外部的變數
		:param subtree: 要畫的子樹
		:param deep: 子樹根節點所在的深度
		:param leafnum: 下一個子葉節點從左到右的排號(用來決定下一個子葉節點的橫座標)
		:return:suprise
		"""
		if not isinstance(subtree, dict):  # 如果是子葉節點(非字典)
			x = leafnum * xsize  # 計算橫座標
			y = ylen - deep * ysize  # 計算縱座標
			plotNode(subtree, (x, y), ax, 'leaf')  # 畫節點
			return x, y, leafnum + 1  # 返回子葉節點的座標, 已畫子葉數目+1

		key = list(subtree.keys())  # 獲取子樹的根節點的鍵(節點的名稱)
		if len(key) != 1:  # 傳進來的子樹應該只有一個根節點
			raise TypeError("非字典樹")  # 不滿足就報錯
		xlist = []  # 儲存根節點孩子的橫座標
		ylist = []  # 儲存根節點孩子的縱座標
		keyvalue = subtree[key[0]]  # 根節點的孩子(子字典, 子字典的key為權值, value為子樹)
		for k in keyvalue:  # k為每一格權值(每一個選擇)
			res = plotAll(keyvalue[k], deep + 1, leafnum)  # 獲取這個孩子的座標
			leafnum = res[2]  # 更新已畫的子葉樹
			xlist.append(res[0])  # 儲存孩子的座標
			ylist.append(res[1])
		x = sum(xlist) / len(xlist)  # 求平均得出該根節點的橫座標
		y = ylen - deep * 3  # 計算該根節點的縱座標
		plotNode(key[0], (x, y), ax)  # 畫該節點

		i = 0
		for k in keyvalue:  # 依次畫出根節點與孩子之間的箭頭
			plotArrow(k, (x, y), (xlist[i], ylist[i]), ax)
			i += 1

		return x, y, leafnum  # 返回該節點的座標

	xsize, ysize = 4, 3  # 預設子葉間距為4, 每層的間距為3 (設定為這兩個值的原因...我覺得這樣好看些...可以試試別的值)
	fig = plt.figure(figsize = (figxsize, figysize))  # 一張畫布
	axprops = dict(xticks = [], yticks = [])  # 橫縱座標顯示的數字(設定為空, 不顯示)
	ax = fig.add_subplot(111, frameon = False, **axprops)  # 隱藏座標軸
	depth, leaf = getTreeDB(mytree)  # 獲取深度, 子葉節點數目
	xlen, ylen = 4 * (leaf + 1), 3 * (depth + 1)  # 計算橫縱間距
	ax.set_xlim(0, xlen)  # 設定座標系x, y的範圍
	ax.set_ylim(0, ylen)
	plotAll(mytree, 1, 1)  # 畫樹
	plotInfo(what, ax)  # 提示標籤
	plt.show()  # show show show show show


testtree = {'有自己的房子': {0: {'有工作': {0: 'no', 1: 'yes'}}, 1: 'yes'}}  # 一個樹
testlabel = ['年齡', '有工作', '有自己的房子', '信貸情況']  #訓練資料的標籤
plotTree(testtree, 10, 6, testlabel[-1])

看起來還是不錯

程式碼的註釋可能有(fei)點(chang)令人費解... 有問題的地方很多...