1. 程式人生 > >R語言 決策樹及其實現

R語言 決策樹及其實現

一顆決策樹包含一個根結點、若干個內部結點和若干個葉結點;葉結點對應於決策結果,其他每個結點則對應於一個屬性測試;每個結點包含的樣本集合根據屬性測試的結果被劃分到子結點中;根結點包含樣本全集。從根結點到葉結點的路徑對應於了一個判定測試序列。

目的:為了產生一顆泛化能力強,即處理未見示例能力強的據決策樹。

特別注意幾點:

1)通常所說的屬性是離散,若屬性是連續,則要把屬性離散化,最簡單的是是採用二分法(找劃分點)

2)缺失值處理

決策樹是一個遞迴過程,以下三種情形會導致遞迴返回:

1)當前結點包含的樣本屬於同一類別,無需劃分;

2)當前屬性集為空,或是所有樣本在所有屬性上取值相同,無法劃分;

3)當前結點包含的樣本集合為空,不能劃分。

資訊增益:一般而言,資訊增益越大,則意味著使用屬性a來劃分所獲得的“純度提升”越大

增益率:與資訊增益的原理一樣,但增益率可以校正存在偏向於選擇取值較多的特徵的問題

剪枝處理

1)預剪枝

在決策樹生成過程中,對每個結點在劃分前先進行估計,若當前結點的劃分不能帶來決策樹泛化效能提升,則停止劃分並將當前結點標記為葉結點。

2)後剪枝

先從訓練集生成一顆完整的決策樹,然後自底向上地對非葉結點進行考察,若將該結點對應的子樹替換為葉結點能帶來決策樹泛化效能提升,則將該子樹替換為葉結點。

R語言實現 

  library(C50); library(rpart); library(party); library(rpart.plot)
  library(caret)
  
  # 載入資料
  car <- read.table('./data/car.data', sep = ',')
  colnames(car) <- c('buy', 'main', 'doors', 'capacity', 'lug_boot', 'safety', 'accept')
  
  # 資料集分為測試和訓練
  ind <- createDataPartition(car$accept, times = 1, p = 0.75, list = FALSE)
  carTR <- car[ind, ]
  carTE <- car[-ind, ]
  
  # 建立模型

  # 決策樹
  # rpart包
  # 在rpart包中有函式rpart.control預剪枝,prune後剪枝
  #
  # 預剪枝:
  # rpart.control對樹進行一些設定  
  # minsplit是最小分支節點數,這裡指大於等於20,那麼該節點會繼續分劃下去,否則停止  
  # minbucket:樹中葉節點包含的最小樣本數  
  # maxdepth:決策樹最大深度 
  # xval:交叉驗證的次數
  # cp全稱為complexity parameter,指某個點的複雜度,對每一步拆分,模型的擬合優度必須提高的程度
  #
  # 後剪枝:
  # 主要是調節引數是cp
  # prune函式可以實現最小代價複雜度剪枝法,對於CART的結果,每個節點均輸出一個對應的cp
  # prune函式通過設定cp引數來對決策樹進行修剪,cp為複雜度係數
  tc <- rpart.control(minsplit = 20, minbucket = 20, maxdepth = 10, xval = 5, cp = 0.005) # 預剪枝
  rpart.model <- rpart(accept ~ ., data = carTR, control = tc)
  rpart.model <- prune(rpart.model, 
                       cp = rpart.model$cptable[which.min(rpart.model$cptable[,"xerror"]),"CP"]) # 後剪枝
  rpart.plot(rpart.model, under = TRUE, faclen = 0, cex = 0.5, main = "決策樹") # 畫圖
  
  # C5.0
  # C5.0包
  c5.0.model <- C5.0(accept ~ ., data = carTR) # C5.0
  plot(c5.0.model)
  
  # 使用ctree函式實現條件推理決策樹演算法
  # party包
  ctree.model <- ctree(accept ~ ., data = carTR)
  
  # 預測結果,並構建混淆矩陣,查看準確率
  # 構建result,存放預測結果
  result <- data.frame(arithmetic = c('C5.0', 'CART', 'ctree'), errTR = rep(0, 3),errTE = rep(0, 3))
  
  for (i in 1:3) {
    # 預測結果
    carTR_predict <- predict(switch(i, c5.0.model, rpart.model, ctree.model), newdata = carTR,
                             type = switch(i, 'class', 'class', 'response'))
    carTE_predict <- predict(switch(i, c5.0.model, rpart.model, ctree.model), newdata = carTE,
                             type = switch(i, 'class', 'class', 'response'))
    # 混淆矩陣
    tableTR <- table(actual = carTR$accept, predict = carTR_predict)
    tableTE <- table(actual = carTE$accept, predict = carTE_predict)
    
    # 計算誤差矩陣
    result[i, 2] <- paste(round((sum(tableTR) - sum(diag(tableTR)))*100/sum(tableTR), 2), '%')
    result[i, 3] <- paste(round((sum(tableTE) - sum(diag(tableTE)))*100/sum(tableTE), 2), '%')
  }
  #檢視誤差率
> result
  arithmetic  errTR  errTE
1       C5.0 1.16 % 3.25 %
2       CART 5.94 % 7.89 %
3      ctree 4.47 %  5.8 %