1. 程式人生 > >機器學習筆記(6)——C4.5決策樹中的剪枝處理和Python實現

機器學習筆記(6)——C4.5決策樹中的剪枝處理和Python實現

1. 為什麼要剪枝

還記得決策樹的構造過程嗎?為了儘可能正確分類訓練樣本,節點的劃分過程會不斷重複直到不能再分,這樣就可能對訓練樣本學習的“太好”了,把訓練樣本的一些特點當做所有資料都具有的一般性質,cong從而導致過擬合。這時就可以通過剪枝處理去掉yi一些分支來降低過擬合的風險。

剪枝的基本策略有“預剪枝”(prepruning)和“後剪枝”(post-pruning):

預剪枝是在決策樹的生成過程中,對每個結點劃分前先做評估,如果劃分不能提升決策樹的泛化效能,就停止劃分並將此節點記為葉節點;

後剪枝是在決策樹構造完成後,自底向上對非葉節點進行評估,如果將其換成葉節點能提升泛化效能,則將該子樹換成葉節點。

那麼怎麼判斷泛化效能是否提升呢?這時需要將資料集分為訓練集和驗證集,利用訓練集構造決策樹,利用驗證集來評估剪枝前後的驗證集精度(即正確分類的比例)。

下面我們把之前的西瓜資料集劃分為訓練集和驗證集,之後在分別詳細演示預剪枝和後剪枝的處理過程。

首先利用訓練集資料,構造一個未做剪枝處理的決策樹,以便於與剪枝後的決策樹做對比。

注意:這裡構造的決策樹與《機器學習》中的不一樣,因為色澤、根蒂、臍部三個屬性的資訊增益是相等的,都可以作為最優劃分屬性。

2. 預剪枝

我們先學習預剪枝的過程:

(1)根據資訊增益準則,選取“色澤”作為根節點進行劃分,會產生3個分支(青綠、烏黑、淺白)。

對根節點“色澤”,若不劃分,該節點被標記為葉節點,訓練集中正負樣本數相等,我們將其標記為“是”好瓜(當樣本最多的類不唯一時,可任選其中一類,我們預設都選正類)。那麼訓練集的7個樣本中,3個正樣本被正確分類,驗證集精度為3/7*100%=42.9%。

對根節點“色澤”劃分後,產生圖中的3個分支,訓練集中的7個樣本中,編號為{8,11,12,4}的4個樣本被正確分類,驗證集精度為4/7*100%=57.1%。

於是節點“色澤”應該進行劃分。

(2)再看“色澤”為烏黑這個分支,如果對其進行劃分,選擇“根蒂”作為劃分屬性

對“根蒂”這個分支節點,如果不劃分,驗證集精度為57.1%。如果劃分,進入此分支的兩個樣例{8,9},編號為8的樣例分類正確,編號為9的樣例分類錯誤,所以對整棵樹來說編號為{4,8,11,12}的4個樣本分類正確,驗證集精度仍為57.1%。

按預剪枝的策略,驗證集精度沒有提升的話,不再劃分。

(3)“色澤”為淺白的分支只有一個類別,無法再劃分。再評估“色澤”為青綠的分支,如果對其進行劃分,選擇“敲聲”作為劃分sh屬性,產生3個分支。

對“敲聲”這個分支節點,如果不劃分,驗證集精度為57.1%。如果劃分,進入此分支的兩個樣例{4,13},編號為4的樣例分類錯誤,編號為13的樣例也分類錯誤,所以對整棵樹來說編號為{8,11,12}的4個樣本分類正確,驗證集精度仍為42.9%。

按預剪枝的策略,驗證集精度沒有提升的話,不再劃分。

因此,通過預剪枝處理生成的樹只有一個根節點,這種樹也稱“決策樹樁”。

優缺點分析:預剪枝使得決策樹的很多分支沒有展開,可以降低過擬合的風險,減少決策樹的訓練時間和測試時間。但是,儘管有些分支的劃分不能提升泛化效能,但是後續劃分可能使效能顯著提高,由於預剪枝沒有展開這些分支,帶來了欠擬合的風險。

3. 後剪枝

後剪枝是在決策樹構造完成後,自底向上對非葉節點進行評估,為了方便分析,我們對樹中的非葉子節點進行編號,然後依次評估其是否需要剪枝。

對於完整的決策樹,在剪枝前,編號為{11,12}的兩個樣本被正確分類,因此其驗證集精度為2/7*100%=28.6%。

(1)第一步先考察編號為4的結點,如果剪掉該分支,該結點應被標記為“是”。進入該分支的驗證集樣本有{8,9},樣本8被正確分類,對整個驗證集,編號為{8,11,12}的樣本正確分類,因此驗證集精度提升為42.9%,決定剪掉該分支。

(2)再來考察編號為3的結點,如果剪掉該分支,該結點應標記為“是”,進入該分支的樣本有{4,13},其中樣本4被正確分類,對整個驗證集,編號為{4,8,11,12}的樣本正確分類,因此驗證集精度提升為57.1%,決定剪掉該分支。

(3)再看編號為2的結點,如果剪掉該分支,該結點應標記為“是”,進入該分支的樣本有{8,9},其中樣本8被正確分類,樣本9被錯誤分類,對整個驗證集,編號為{4,8,11,12}的樣本正確分類,驗證集精度仍為57.1%,沒有提升,因此不做剪枝。

(4)對編號為1的結點,如果對其剪枝,其驗證集精度為42.9%(同預剪枝的第一步),因此也不剪枝。

後剪枝得到的決策樹就是第(3)步的樣子。

優缺點分析:後剪枝通常比預剪枝保留更多的分支,欠擬合風險小。但是後剪枝shi是在決策樹構造完成後進行的,其訓練時間的開銷會大於預剪枝。

4. 後剪枝的Python實現

由於後剪枝的泛化能力高於預剪枝,這裡只對後剪枝程式設計。為了方便評估,上述過程並沒有包含連續屬性,但是C4.5決策樹是可以處理連續sh屬性的,因此我們在程式設計中把連續屬性也一併考慮進去。

def postPruningTree(inputTree, dataSet, data_test, labels, labelProperties):
    """ 
    type: (dict, list, list, list, list) -> dict
    inputTree: 已構造的樹
    dataSet: 訓練集
    data_test: 驗證集
    labels: 屬性標籤
    labelProperties: 屬性類別
    """
    firstStr = inputTree.keys()[0]
    secondDict = inputTree[firstStr]
    classList = [example[-1] for example in dataSet]
    featkey = copy.deepcopy(firstStr)
    if '<' in firstStr:  # 對連續的特徵值,使用正則表示式獲得特徵標籤和value
        featkey = re.compile("(.+<)").search(firstStr).group()[:-1]
        featvalue = float(re.compile("(<.+)").search(firstStr).group()[1:])
    labelIndex = labels.index(featkey)
    temp_labels = copy.deepcopy(labels)
    temp_labelProperties = copy.deepcopy(labelProperties)
    if labelProperties[labelIndex] == 0:  # 離散特徵
        del (labels[labelIndex])
        del (labelProperties[labelIndex])
    for key in secondDict.keys():  # 對每個分支
        if type(secondDict[key]).__name__ == 'dict':  # 如果不是葉子節點
            if temp_labelProperties[labelIndex] == 0:  # 離散的
                subDataSet = splitDataSet(dataSet, labelIndex, key)
                subDataTest = splitDataSet(data_test, labelIndex, key)
            else:
                if key == 'yes':
                    subDataSet = splitDataSet_c(dataSet, labelIndex, featvalue,
                                               'L')
                    subDataTest = splitDataSet_c(data_test, labelIndex,
                                                featvalue, 'L')
                else:
                    subDataSet = splitDataSet_c(dataSet, labelIndex, featvalue,
                                               'R')
                    subDataTest = splitDataSet_c(data_test, labelIndex,
                                                featvalue, 'R')
            inputTree[firstStr][key] = postPruningTree(secondDict[key],
                                                       subDataSet, subDataTest,
                                                       copy.deepcopy(labels),
                                                       copy.deepcopy(
                                                           labelProperties))
    if testing(inputTree, data_test, temp_labels,
               temp_labelProperties) <= testingMajor(majorityCnt(classList),
                                                     data_test):
        return inputTree
    return majorityCnt(classList)

執行程式繪製出剪枝後的決策樹,與上面人工繪製的一致。

參考:

周志華《機器學習》