《機器學習實戰》中用matplotlib繪製決策樹, python3
阿新 • • 發佈:2018-12-12
人笨, 繪製樹形圖那裡的程式碼看了幾次也沒看懂(很多莫名其妙的(全域性?)變數), 然後就自己想辦法寫了個
import matplotlib.pyplot as plt from matplotlib.font_manager import FontProperties def getTreeDB(mytree): """ 利用遞迴獲取字典最大深度, 子葉數目 :param mytree:一個字典樹, 或者樹的子葉節點(字元型) :return:返回 樹的深度, 子葉數目 """ if not isinstance(mytree, dict): # 如果是子葉節點, 返回1 return 1, 1 depth = [] # 儲存每條樹枝的深度 leafs = 0 # 結點當前的子葉數目 keys = list(mytree.keys()) # 獲取字典的鍵 if len(keys) == 1: # 如果鍵只有一個(說明是個結點而不是樹枝) mytree = mytree[keys[0]] # 結點的value一定是樹枝(判斷的是每條支路的深度而不是結點) for key in mytree.keys(): # 遍歷每條樹枝 res = getTreeDB(mytree[key]) # 獲取子樹的深度, 子葉數目 depth.append(1 + res[0]) # 把每條樹枝的深度(加上自身)放在節點的深度集合中 leafs += res[1] # 累積子葉數目 return max(depth), leafs # 返回最大的深度值, 子葉數目 def plotArrow(what, xy1, xy2, which): """ 畫一個帶文字描述的箭頭, 文字在箭頭中間 :param what: 文字內容 :param xy1: 箭頭起始座標 :param xy2: 箭頭終點座標 :param which: 箭頭所在的圖物件 :return: suprise """ # 畫箭頭 which.arrow( xy1[0], xy1[1], xy2[0] - xy1[0], xy2[1] - xy1[1], length_includes_head = True, # 增加的長度包含箭頭部分 head_width = 0.15, head_length = 0.5, fc = 'r', ec = 'brown') tx = (xy1[0] + xy2[0]) / 2 ty = (xy1[1] + xy2[1]) / 2 zhfont = FontProperties(fname = 'msyh.ttc') # 顯示中文的方法 # 畫文字 which.annotate( what, size = 10, xy = (tx, ty), xytext = (-5, 5), # 偏移量 textcoords = 'offset points', bbox = dict(boxstyle = "square", ec = (1., 0.5, 0.5), fc = (1., 0.8, 0.8)), # 外框, fc 內部顏色, ec 邊框顏色 fontproperties = zhfont) # 字型 def plotNode(what, xy, which, mod = 'any'): """ 畫樹的節點 :param what: 節點的內容 :param xy: 節點的座標 :param which: 節點所在的圖物件 :param mod: 判斷節點是子葉還是非子葉(顏色不同) :return: suprise """ zhfont = FontProperties(fname = 'msyh.ttc') # 顯示中文的方法, msyh.ttc是微軟雅黑的字型檔案 if mod == 'leaf': color = 'yellow' else: color = 'greenyellow' which.text( xy[0], xy[1], what, size = 18, ha = "center", va = "center", bbox = dict(boxstyle = "round", ec = (1., 0.5, 0.5), fc = color), fontproperties = zhfont) def plotInfo(what, which): """ 提示圖中內容 :param what: 子葉標籤 :param which: 所在的圖物件 :return: suprise """ what = '綠色: 特徵, 粉紅: 特徵值, 黃色: ' + what zhfont = FontProperties(fname = 'msyh.ttc') # 顯示中文的方法 which.text( 2, 2, what, size = 18, ha = "center", va = "center", bbox = dict(boxstyle = "round", ec = (1., 0.5, 0.5), fc = '#BB91A6'), fontproperties = zhfont) def plotTree(mytree, figxsize, figysize, what): """ 利用遞迴畫決策樹 所有子葉節點兩兩之間的間距都是xsize 每一層節點之間的間距都是ysize 子葉節點的數目都是確定的, 所以橫座標也是確定的, 從左往右第leafnum個子葉節點的橫座標x = leafs * xsize 非子葉節點的橫座標由該節點孩子的橫座標確定, x = 孩子橫座標平均值 每一層節點的縱座標由層數deep確定, y = ylen - deep * ysize, 其中ylen為畫板高度 :param mytree: 要畫的字典樹 :param figxsize: 畫布的x長度 (兩者會影響顯示效果) :param figysize: 畫布的y長度 (這兩個值很影響樹的分佈,(不宜過大)(?) )) :param what: 子葉的標籤(用於提示圖的結果是什麼) :return: suprise """ def plotAll(subtree, deep, leafnum): """ 內部函式, 遞迴畫圖, 會使用外部的變數 :param subtree: 要畫的子樹 :param deep: 子樹根節點所在的深度 :param leafnum: 下一個子葉節點從左到右的排號(用來決定下一個子葉節點的橫座標) :return:suprise """ if not isinstance(subtree, dict): # 如果是子葉節點(非字典) x = leafnum * xsize # 計算橫座標 y = ylen - deep * ysize # 計算縱座標 plotNode(subtree, (x, y), ax, 'leaf') # 畫節點 return x, y, leafnum + 1 # 返回子葉節點的座標, 已畫子葉數目+1 key = list(subtree.keys()) # 獲取子樹的根節點的鍵(節點的名稱) if len(key) != 1: # 傳進來的子樹應該只有一個根節點 raise TypeError("非字典樹") # 不滿足就報錯 xlist = [] # 儲存根節點孩子的橫座標 ylist = [] # 儲存根節點孩子的縱座標 keyvalue = subtree[key[0]] # 根節點的孩子(子字典, 子字典的key為權值, value為子樹) for k in keyvalue: # k為每一格權值(每一個選擇) res = plotAll(keyvalue[k], deep + 1, leafnum) # 獲取這個孩子的座標 leafnum = res[2] # 更新已畫的子葉樹 xlist.append(res[0]) # 儲存孩子的座標 ylist.append(res[1]) x = sum(xlist) / len(xlist) # 求平均得出該根節點的橫座標 y = ylen - deep * 3 # 計算該根節點的縱座標 plotNode(key[0], (x, y), ax) # 畫該節點 i = 0 for k in keyvalue: # 依次畫出根節點與孩子之間的箭頭 plotArrow(k, (x, y), (xlist[i], ylist[i]), ax) i += 1 return x, y, leafnum # 返回該節點的座標 xsize, ysize = 4, 3 # 預設子葉間距為4, 每層的間距為3 (設定為這兩個值的原因...我覺得這樣好看些...可以試試別的值) fig = plt.figure(figsize = (figxsize, figysize)) # 一張畫布 axprops = dict(xticks = [], yticks = []) # 橫縱座標顯示的數字(設定為空, 不顯示) ax = fig.add_subplot(111, frameon = False, **axprops) # 隱藏座標軸 depth, leaf = getTreeDB(mytree) # 獲取深度, 子葉節點數目 xlen, ylen = 4 * (leaf + 1), 3 * (depth + 1) # 計算橫縱間距 ax.set_xlim(0, xlen) # 設定座標系x, y的範圍 ax.set_ylim(0, ylen) plotAll(mytree, 1, 1) # 畫樹 plotInfo(what, ax) # 提示標籤 plt.show() # show show show show show testtree = {'有自己的房子': {0: {'有工作': {0: 'no', 1: 'yes'}}, 1: 'yes'}} # 一個樹 testlabel = ['年齡', '有工作', '有自己的房子', '信貸情況'] #訓練資料的標籤 plotTree(testtree, 10, 6, testlabel[-1])
看起來還是不錯
程式碼的註釋可能有(fei)點(chang)令人費解... 有問題的地方很多...