機器學習_8.決策樹演算法
1.ID3演算法
預備知識
1.資訊熵:
2.資訊增益
演算法內容
引入了資訊理論中的互資訊(資訊增益)作為選擇判別因素的度量,即:以資訊增益的下降速度作為選取分類屬性的標準,所選的測試屬性是從根節點到當前節點的路徑上從沒有被考慮過的具有最高的資訊增益的屬性。這就需要計算各個屬性的資訊增益的值,找出最大的作為判別的屬性:
1. 計算先驗熵,沒有接收到其他的屬性值時的平均不確定性,
2. 計算後驗墒,在接收到輸出符號yi時關於信源的不確定性,
3. 條件熵,對後驗熵在輸出符號集Y中求期望,接收到全部的付好後對信源的不確定性,
4. 互資訊,先驗熵和條件熵的差,
例項
是否適合打壘球的決策表如下
天氣 | 溫度 | 溼度 | 風速 | 活動 |
---|---|---|---|---|
晴 | 炎熱 | 高 | 弱 | 取消 |
晴 | 炎熱 | 高 | 強 | 取消 |
陰 | 炎熱 | 高 | 弱 | 進行 |
雨 | 適中 | 高 | 弱 | 進行 |
雨 | 寒冷 | 正常 | 弱 | 進行 |
雨 | 寒冷 | 正常 | 強 | 取消 |
陰 | 寒冷 | 正常 | 強 | 進行 |
晴 | 適中 | 高 | 弱 | 取消 |
晴 | 寒冷 | 正常 | 弱 | 進行 |
雨 | 適中 | 正常 | 弱 | 進行 |
晴 | 適中 | 正常 | 強 | 進行 |
陰 | 適中 | 高 | 強 | 進行 |
陰 | 炎熱 | 正常 | 弱 | 進行 |
雨 | 適中 | 高 | 強 | 取消 |
1.計算先驗熵:在沒有接收到其他的任何的屬性值時候,活動進行與否的熵根據下表進行計算。
2.分別將各個屬性作為決策屬性時的條件熵(先計算後驗墒,在計算條件熵)
(1) 計算已知天氣情況下活動是否進行的條件熵(已知天氣情況下對於活動的不確定性)
先計算後驗墒:
再計算條件熵:(知道了Y之後,對X的不確定性:知道了天氣之後,對活動的不確定性,越小是越好的)
(2)計算已知溫度情況時對活動的條件熵(不確定性)
(3)已知溼度情況下對於活動是否進行的條件熵(不確定性)
(4)已知風速情況下對於活動是否進行的條件熵(不確定性)
3.計算資訊增益
所以選擇天氣作為第一個判別因素
在選擇了天氣作為第一個判別因素之後,我們很容易看出(計算的方法和上面提到的一樣),針對上圖的中間的三張子表來說,第一張子表在選擇溼度作為劃分資料的feature的時候,分類問題可以完全解決:溼度正常的情況下進行活動,溼度高的時候取消(在天氣狀態為晴的條件下);第二個子表不需要劃分,即,天氣晴的情況下不管其他的因素是什麼,活動都要進行;第三張子表當選擇風速作為劃分的feature時,分類問題也完全解決:風速弱的時候進行,風速強的時候取消(在天氣狀況為雨的條件下)。
Python實現
import math
import operator
def calcShannonEnt(dataset):
numEntries = len(dataset)
labelCounts = {}
for featVec in dataset:
currentLabel = featVec[-1]
if currentLabel not in labelCounts.keys():
labelCounts[currentLabel] = 0
labelCounts[currentLabel] +=1
shannonEnt = 0.0
for key in labelCounts:
prob = float(labelCounts[key])/numEntries
shannonEnt -= prob*math.log(prob, 2)
return shannonEnt
def CreateDataSet():
dataset = [[1, 1, 'yes' ],
[1, 1, 'yes' ],
[1, 0, 'no'],
[0, 1, 'no'],
[0, 1, 'no']]
labels = ['no surfacing', 'flippers']
return dataset, labels
def splitDataSet(dataSet, axis, value):
retDataSet = []
for featVec in dataSet:
if featVec[axis] == value:
reducedFeatVec = featVec[:axis]
reducedFeatVec.extend(featVec[axis+1:])
retDataSet.append(reducedFeatVec)
return retDataSet
def chooseBestFeatureToSplit(dataSet):
numberFeatures = len(dataSet[0])-1
baseEntropy = calcShannonEnt(dataSet)
bestInfoGain = 0.0;
bestFeature = -1;
for i in range(numberFeatures):
featList = [example[i] for example in dataSet]
uniqueVals = set(featList)
newEntropy =0.0
for value in uniqueVals:
subDataSet = splitDataSet(dataSet, i, value)
prob = len(subDataSet)/float(len(dataSet))
newEntropy += prob * calcShannonEnt(subDataSet)
infoGain = baseEntropy - newEntropy
if(infoGain > bestInfoGain):
bestInfoGain = infoGain
bestFeature = i
return bestFeature
def majorityCnt(classList):
classCount ={}
for vote in classList:
if vote not in classCount.keys():
classCount[vote]=0
classCount[vote]=1
sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0]
def createTree(dataSet, labels):
classList = [example[-1] for example in dataSet]
if classList.count(classList[0])==len(classList):
return classList[0]
if len(dataSet[0])==1:
return majorityCnt(classList)
bestFeat = chooseBestFeatureToSplit(dataSet)
bestFeatLabel = labels[bestFeat]
myTree = {bestFeatLabel:{}}
del(labels[bestFeat])
featValues = [example[bestFeat] for example in dataSet]
uniqueVals = set(featValues)
for value in uniqueVals:
subLabels = labels[:]
myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
return myTree
myDat,labels = CreateDataSet()
createTree(myDat,labels)
Java實現
1.計算給定資料集的夏農熵
ID3演算法實現中,訓練資料和測試資料都是用ArrayList<ArrayList<String>> 存放,每一個子ArrayList是一個sample(feature+label)。即,data中的一列是一個屬性,一行是一個樣本。
uniqueLabels用來統計不同的label出現的個數。
public double calculateShannonEntropy(ArrayList<ArrayList<String>> data) {
double shannon = 0.0;
int length = data.get(0).size(); // length-1就是label的index
HashMap<String, Integer> uniqueLabels = new HashMap<>();
for (int i = 0; i < data.size(); i++) {
if (uniqueLabels.containsKey(data.get(i).get(length - 1))) {
uniqueLabels.replace(data.get(i).get(length - 1), uniqueLabels.get(data.get(i).get(length - 1)) + 1);
} else {
uniqueLabels.put(data.get(i).get(length - 1), 1);
}
}
for (String one : uniqueLabels.keySet()) {
shannon += -(((double) (uniqueLabels.get(one)) / (data.size()))
* Math.log((double) (uniqueLabels.get(one)) / (data.size())) / Math.log(2));
}
return shannon;
2 按照給定的feature的取值劃分資料集
三個引數(data, index, value)的含義: 將data中第index列上值為value的樣本返回,並且在返回的結果中樣本不包括index列的特徵
public ArrayList<ArrayList<String>> splitDataSetByFeature(ArrayList<ArrayList<String>> data, int index,
String value) {
ArrayList<ArrayList<String>> subData = new ArrayList<>();
for (int i = 0; i < data.size(); i++) {
ArrayList<String> newSample = new ArrayList<>();
if (data.get(i).get(index).equals(value)) {
for (int j = 0; j < data.get(i).size(); j++) {
if (j != index) {
newSample.add(data.get(i).get(j));
}
}
subData.add(newSample);
}
}
return subData;
}
3.選擇最好的資料集劃分方式
對於一個數據集data,要選擇其中的最好的feature來劃分資料, 所以需要一列一列(data中的一列是一個屬性,一行是一個樣本)的比較(比較使用哪個特徵來劃分得到的資訊增益最大)。對於每一列來說,計算該列中的屬性值有多少種,然後計算每種屬性值的熵的大小,然後按照比例求和。最後比較每一列的熵值的總和,資訊增益最大的屬性就是我們想要找的最好的屬性。
featureStatistic用來統計某一個特徵可能的取值以及這些取值的個數
public int chooseBestFeature(ArrayList<ArrayList<String>> data, ArrayList<String> featureName) {
int featureSize = data.get(0).size();
int dataSize = data.size();
int bestFuatrue = -1;
double bestInfoGain = 0.0;
double infoGain = 0.0;
double baseShannon = this.calculateShannonEntropy(data);
double shannon = 0.0;
HashMap<String, Integer> featureStatistic = new HashMap<>();
for (int i = 0; i < featureSize - 1; i++) {
for (int j = 0; j < data.size(); j++) {
if (featureStatistic.containsKey(data.get(j).get(i))) {
featureStatistic.replace(data.get(j).get(i), featureStatistic.get(data.get(j).get(i)) + 1);
} else {
featureStatistic.put(data.get(j).get(i), 1);
}
}
ArrayList<ArrayList<String>> subdata;
for (String featureValue : featureStatistic.keySet()) {
subdata = this.splitDataSetByFeature(data, i, featureValue);
shannon += this.calculateShannonEntropy(subdata)
* ((double) featureStatistic.get(featureValue) / dataSize);
}
infoGain = baseShannon - shannon;
if (infoGain > bestInfoGain) {
bestInfoGain = infoGain;
bestFuatrue = i;
}
shannon = 0.0;
featureStatistic.clear();
}
return bestFuatrue;
}
4.構造決策樹
遞迴的構造決策樹,注意函式的返回型別是object,而不是DecisionTree(該類的定義下面給出),這是因為當我們構造到葉子結點的時候,我們可能返回的是String(正例還是反例,yes or no,而不再是棵子樹),所以使用Object
public Object createDecisionTree(ArrayList<ArrayList<String>> data, ArrayList<String> featureName) {
int dataSize = data.size();
int featureSize = data.get(0).size();
// 如果沒有特徵了,data.get(0).size = 1 說明只剩下標籤了, 開始投票。
if (data.get(0).size() == 1) {
return vote(data);
}
// 判斷是不是所有的sample的label都一致了, 如果是,返回這個統一的類別標籤。
HashSet<String> labels = new HashSet<>();
for (int i = 0; i < dataSize; i++) {
if (!labels.contains(data.get(i).get(featureSize - 1))) {
labels.add(data.get(i).get(featureSize - 1));
}
}
if (labels.size() == 1) {
return data.get(0).get(featureSize - 1);
}
// 選擇最好的feature來進行決策樹(子決策樹)的構建
int bestFeatureIndex = this.chooseBestFeature(data, featureName);
String bestFeature = featureName.get(bestFeatureIndex);
featureName.remove(bestFeatureIndex);
// 統計上一步選出的最好的屬性,都有那些取值。
HashSet<String> bestFeatureValuesSet = new HashSet<>();
for (int i = 0; i < data.size(); i++) {
if (!bestFeatureValuesSet.contains(data.get(i).get(bestFeatureIndex))) {
bestFeatureValuesSet.add(data.get(i).get(bestFeatureIndex));
}
}
DecisionTree tree = new DecisionTree();
tree.setAttributeName(bestFeature);
// 最好的屬性的每一個取值,都形成一個子樹的root, 開始遞迴。
Iterator<String> iterator = bestFeatureValuesSet.iterator();
while (iterator.hasNext()) {
ArrayList<String> subFeatureName = new ArrayList<>();
for (int i = 0; i < featureName.size(); i++) {
subFeatureName.add(featureName.get(i));
} // 遞迴的一個關鍵問題。
String featureValue = iterator.next();
tree.children.put(featureValue,
createDecisionTree(splitDataSetByFeature(data, bestFeatureIndex, featureValue), subFeatureName));
}
return tree;
}
5.投票函式
當已經沒有屬性可以作為劃分的依據了, 但是這些樣本的類的標籤依然不同, 那麼這個時候就要投票決定了。這個時候data的形式應該是隻有一列標籤了。那麼我們就找這一列標籤中最多的,作為類別返回。
public String vote(ArrayList<ArrayList<String>> data) {
String voteResult = null;
int dataSize = data.size();
int length = data.get(0).size();
HashMap<String, Integer> sta = new HashMap<>();
for (int i = 0; i < dataSize; i++) {
if (!sta.keySet().contains(data.get(i).get(length - 1))) {
sta.put(data.get(i).get(length - 1), 1);
} else {
sta.replace(data.get(i).get(length - 1), sta.get(data.get(i).get(length - 1)) + 1);
}
}
int maxValue = Collections.max(sta.values());
for (String key : sta.keySet()) {
if (maxValue == sta.get(key)) {
voteResult = key;
}
}
return voteResult;
}
6.決策樹的資料結構
不像python中有一個功能比較強大的字典,所以這裡自定義了一個決策樹的資料結構(類DecisionTree),兩個域:
(1)String:用來表示該樹(子樹)的屬性(feature)。
(2) HashMap<String, Object> : key的值表示feature的取值,Object是子樹(DecisionTree)或者是最終的label。
典型的一個遞迴的定義。並且在該類中提供了:
(1)遍歷樹的方法。
(2)將構造的樹輸出到指定的檔案中。
public class DecisionTree implements Serializable{
private static final long serialVersionUID = 1L;
private String attributeName;
public HashMap<String, Object> children;
private String decisionTree = "./outputTree/decisionTree.data";
public void printTree(Object tree, ArrayList<String> record, BufferedWriter bufferedWriter) {
if (tree instanceof String) {
record.add((String) tree);
System.out.println(record);
try {
bufferedWriter.write(record.toString());
bufferedWriter.newLine();
} catch (IOException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
record.remove(record.size() - 1);
record.remove(record.size() - 1);
return;
}
record.add(((DecisionTree) tree).getAttributeName());
for (String key : ((DecisionTree) tree).children.keySet()) {
record.add(key);
printTree(((DecisionTree) tree).children.get(key), record, bufferedWriter);
}
int count = 1;
while( record.size() > 0 && count <= 2){
record.remove(record.size() - 1);
count++;
}
}
public void saveDecisionTree(Object tree)
{
try {
BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(this.decisionTree));
this.printTree(tree, new ArrayList<>(), bufferedWriter);
bufferedWriter.close();
System.out.println("\r\nthe decision tree has saved in the file: './outputTree/decisionTree.data'");
} catch (IOException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
}
}
演算法測試和完整程式碼見:https://blog.csdn.net/robin_Xu_shuai/article/details/74011205
2.C4.5
C4.5是Ross Quinlan在1993年在ID3的基礎上改進而提出的。.ID3採用的資訊增益度量存在一個缺點,它一般會優先選擇有較多屬性值的Feature,因為屬性值多的Feature會有相對較大的資訊增益?(資訊增益反映的給定一個條件以後不確定性減少的程度,必然是分得越細的資料集確定性更高,也就是條件熵越小,資訊增益越大).為了避免這個不足C4.5中是用資訊增益比率(gain ratio)來作為選擇分支的準則。資訊增益比率通過引入一個被稱作分裂資訊(Split information)的項來懲罰取值較多的Feature。除此之外,C4.5還彌補了ID3中不能處理特徵屬性值連續的問題。但是,對連續屬性值需要掃描排序,會使C4.5效能下降。
C4.5並不一個演算法,而是一組演算法—C4.5,非剪枝C4.5和C4.5規則。下圖中的演算法將給出C4.5的基本工作流程:
判斷物件的屬性是有順序的,屬性選擇度量又稱分裂規則,因為它們決定給定節點上的元組如何分裂。屬性選擇度量提供了每個屬性描述給定訓練元組的秩評定,具有最好度量得分的屬性被選作給定元組的分裂屬性。目前比較流行的屬性選擇度量有--資訊增益、增益率和Gini指標。
在ID3已介紹的關於資訊理論部分的基礎上,介紹資訊增益率。
資訊增益率使用“分裂資訊”值將資訊增益規範化。分類資訊類似於Info(D),定義如下:
這個值表示通過將訓練資料集D劃分成對應於屬性A測試的v個輸出的v個劃分產生的資訊。資訊增益率定義:
選擇具有最大增益率的屬性作為分裂屬性。
建立樹類:
package C45Test;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
public class DecisionTree {
public TreeNode createDT(List<ArrayList<String>> data,List<String> attributeList){
System.out.println("當前的DATA為");
for(int i=0;i<data.size();i++){
ArrayList<String> temp = data.get(i);
for(int j=0;j<temp.size();j++){
System.out.print(temp.get(j)+ " ");
}
System.out.println();
}
System.out.println("---------------------------------");
System.out.println("當前的ATTR為");
for(int i=0;i<attributeList.size();i++){
System.out.print(attributeList.get(i)+ " ");
}
System.out.println();
System.out.println("---------------------------------");
TreeNode node = new TreeNode();
String result = InfoGain.IsPure(InfoGain.getTarget(data));
if(result != null){
node.setNodeName("leafNode");
node.setTargetFunValue(result);
return node;
}
if(attributeList.size() == 0){
node.setTargetFunValue(result);
return node;
}else{
InfoGain gain = new InfoGain(data,attributeList);
double maxGain = 0.0;
int attrIndex = -1;
for(int i=0;i<attributeList.size();i++){
double tempGain = gain.getGainRatio(i);
if(maxGain < tempGain){
maxGain = tempGain;
attrIndex = i;
}
}
System.out.println("選擇出的最大增益率屬性為: " + attributeList.get(attrIndex));
node.setAttributeValue(attributeList.get(attrIndex));
List<ArrayList<String>> resultData = null;
Map<String,Long> attrvalueMap = gain.getAttributeValue(attrIndex);
for(Map.Entry<String, Long> entry : attrvalueMap.entrySet()){
resultData = gain.getData4Value(entry.getKey(), attrIndex);
TreeNode leafNode = null;
System.out.println("當前為"+attributeList.get(attrIndex)+"的"+entry.getKey()+"分支。");
if(resultData.size() == 0){
leafNode = new TreeNode();
leafNode.setNodeName(attributeList.get(attrIndex));
leafNode.setTargetFunValue(result);
leafNode.setAttributeValue(entry.getKey());
}else{
for (int j = 0; j < resultData.size(); j++) {
resultData.get(j).remove(attrIndex);
}
ArrayList<String> resultAttr = new ArrayList<String>(attributeList);
resultAttr.remove(attrIndex);
leafNode = createDT(resultData,resultAttr);
}
node.getChildTreeNode().add(leafNode);
node.getPathName().add(entry.getKey());
}
}
return node;
}
class TreeNode{
private String attributeValue;
private List<TreeNode> childTreeNode;
private List<String> pathName;
private String targetFunValue;
private String nodeName;
public TreeNode(String nodeName){
this.nodeName = nodeName;
this.childTreeNode = new ArrayList<TreeNode>();
this.pathName = new ArrayList<String>();
}
public TreeNode(){
this.childTreeNode = new ArrayList<TreeNode>();
this.pathName = new ArrayList<String>();
}
public String getAttributeValue() {
return attributeValue;
}
public void setAttributeValue(String attributeValue) {
this.attributeValue = attributeValue;
}
public List<TreeNode> getChildTreeNode() {
return childTreeNode;
}
public void setChildTreeNode(List<TreeNode> childTreeNode) {
this.childTreeNode = childTreeNode;
}
public String getTargetFunValue() {
return targetFunValue;
}
public void setTargetFunValue(String targetFunValue) {
this.targetFunValue = targetFunValue;
}
public String getNodeName() {
return nodeName;
}
public void setNodeName(String nodeName) {
this.nodeName = nodeName;
}
public List<String> getPathName() {
return pathName;
}
public void setPathName(List<String> pathName) {
this.pathName = pathName;
}
}
}
增益率計算類
package C45Test;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
//C 4.5 實現
public class InfoGain {
private List<ArrayList<String>> data;
private List<String> attribute;
public InfoGain(List<ArrayList<String>> data,List<String> attribute){
this.data = new ArrayList<ArrayList<String>>();
for(int i=0;i<data.size();i++){
List<String> temp = data.get(i);
ArrayList<String> t = new ArrayList<String>();
for(int j=0;j<temp.size();j++){
t.add(temp.get(j));
}
this.data.add(t);
}
this.attribute = new ArrayList<String>();
for(int k=0;k<attribute.size();k++){
this.attribute.add(attribute.get(k));
}
/*this.data = data;
this.attribute = attribute;*/
}
//獲得熵
public double getEntropy(){
Map<String,Long> targetValueMap = getTargetValue();
Set<String> targetkey = targetValueMap.keySet();
double entropy = 0.0;
for(String key : targetkey){
double p = MathUtils.div((double)targetValueMap.get(key), (double)data.size());
entropy += (-1) * p * Math.log(p);
}
return entropy;
}
//獲得InfoA
public double getInfoAttribute(int attributeIndex){
Map<String,Long> attributeValueMap = getAttributeValue(attributeIndex);
double infoA = 0.0;
for(Map.Entry<String, Long> entry : attributeValueMap.entrySet()){
int size = data.size();
double attributeP = MathUtils.div((double)entry.getValue() , (double) size);
Map<String,Long> targetValueMap = getAttributeValueTargetValue(entry.getKey(),attributeIndex);
long totalCount = 0L;
for(Map.Entry<String, Long> entryValue :targetValueMap.entrySet()){
totalCount += entryValue.getValue();
}
double valueSum = 0.0;
for(Map.Entry<String, Long> entryTargetValue : targetValueMap.entrySet()){
double p = MathUtils.div((double)entryTargetValue.getValue(), (double)totalCount);
valueSum += Math.log(p) * p;
}
infoA += (-1) * attributeP * valueSum;
}
return infoA;
}
//得到屬性值在決策空間的比例
public Map<String,Long> getAttributeValueTargetValue(String attributeName,int attributeIndex){
Map<String,Long> targetValueMap = new HashMap<String,Long>();
Iterator<ArrayList<String>> iterator = data.iterator();
while(iterator.hasNext()){
List<String> tempList = iterator.next();
if(attributeName.equalsIgnoreCase(tempList.get(attributeIndex))){
int size = tempList.size();
String key = tempList.get(size - 1);
Long value = targetValueMap.get(key);
targetValueMap.put(key, value != null ? ++value :1L);
}
}
return targetValueMap;
}
//得到屬性在決策空間上的數量
public Map<String,Long> getAttributeValue(int attributeIndex){
Map<String,Long> attributeValueMap = new HashMap<String,Long>();
for(ArrayList<String> note : data){
String key = note.get(attributeIndex);
Long value = attributeValueMap.get(key);
attributeValueMap.put(key, value != null ? ++value :1L);
}
return attributeValueMap;
}
public List<ArrayList<String>> getData4Value(String attrValue,int attrIndex){
List<ArrayList<String>> resultData = new ArrayList<ArrayList<String>>();
Iterator<ArrayList<String>> iterator = data.iterator();
for(;iterator.hasNext();){
ArrayList<String> templist = iterator.next();
if(templist.get(attrIndex).equalsIgnoreCase(attrValue)){
ArrayList<String> temp = (ArrayList<String>) templist.clone();
resultData.add(temp);
}
}
return resultData;
}
//獲得增益率
public double getGainRatio(int attributeIndex){
return MathUtils.div(getGain(attributeIndex), getSplitInfo(attributeIndex));
}
//獲得增益量
public double getGain(int attributeIndex){
return getEntropy() - getInfoAttribute(attributeIndex);
}
//得到懲罰因子
public double getSplitInfo(int attributeIndex){
Map<String,Long> attributeValueMap = getAttributeValue(attributeIndex);
double splitA = 0.0;
for(Map.Entry<String, Long> entry : attributeValueMap.entrySet()){
int size = data.size();
double attributeP = MathUtils.div((double)entry.getValue() , (double) size);
splitA += attributeP * Math.log(attributeP) * (-1);
}
return splitA;
}
//得到目標函式在當前集合範圍內的離散的值
public Map<String,Long> getTargetValue(){
Map<String,Long> targetValueMap = new HashMap<String,Long>();
Iterator<ArrayList<String>> iterator = data.iterator();
while(iterator.hasNext()){
List<String> tempList = iterator.next();
String key = tempList.get(tempList.size() - 1);
Long value = targetValueMap.get(key);
targetValueMap.put(key, value != null ? ++value : 1L);
}
return targetValueMap;
}
//獲得TARGET值
public static List<String> getTarget(List<ArrayList<String>> data){
List<String> list = new ArrayList<String>();
for(ArrayList<String> temp : data){
int index = temp.size() -1;
String value = temp.get(index);
list.add(value);
}
return list;
}
//判斷當前純度是否100%
public static String IsPure(List<String> list){
Set<String> set = new HashSet<String>();
for(String name :list){
set.add(name);
}
if(set.size() > 1) return null;
Iterator<String> iterator = set.iterator();
return iterator.next();
}
}
測試類,資料集讀取以上的分別放到2個List中。
package C45Test;
import java.util.ArrayList;
import java.util.List;
import C45Test.DecisionTree.TreeNode;
public class MainC45 {
private static final List<ArrayList<String>> dataList = new ArrayList<ArrayList<String>>();
private static final List<String> attributeList = new ArrayList<String>();
public static void main(String args[]){
DecisionTree dt = new DecisionTree();
TreeNode node = dt.createDT(configData(),configAttribute());
System.out.println();
}
}
大數運算工具類
package C45Test;
import java.math.BigDecimal;
public abstract class MathUtils {
//預設餘數長度
private static final int DIV_SCALE = 10;
//受限於DOUBLE長度
public static double add(double value1,double value2){
BigDecimal big1 = new BigDecimal(String.valueOf(value1));
BigDecimal big2 = new BigDecimal(String.valueOf(value2));
return big1.add(big2).doubleValue();
}
//大數加法
public static double add(String value1,String value2){
BigDecimal big1 = new BigDecimal(value1);
BigDecimal big2 = new BigDecimal(value2);
return big1.add(big2).doubleValue();
}
public static double div(double value1,double value2){
BigDecimal big1 = new BigDecimal(String.valueOf(value1));
BigDecimal big2 = new BigDecimal(String.valueOf(value2));
return big1.divide(big2,DIV_SCALE,BigDecimal.ROUND_HALF_UP).doubleValue();
}
public static double mul(double value1,double value2){
BigDecimal big1 = new BigDecimal(String.valueOf(value1));
BigDecimal big2 = new BigDecimal(String.valueOf(value2));
return big1.multiply(big2).doubleValue();
}
public static double sub(double value1,double value2){
BigDecimal big1 = new BigDecimal(String.valueOf(value1));
BigDecimal big2 = new BigDecimal(String.valueOf(value2));
return big1.subtract(big2).doubleValue();
}
public static double returnMax(double value1, double value2) {
BigDecimal big1 = new BigDecimal(value1);
BigDecimal big2 = new BigDecimal(value2);
return big1.max(big2).doubleValue();
}
}
3.CART演算法
原理:
分類迴歸樹演算法:CART(Classification And Regression Tree)演算法採用一種二分遞迴分割的技術,將當前的樣本集分為兩個子樣本集,使得生成的的每個非葉子節點都有兩個分支。因此,CART演算法生成的決策樹是結構簡潔的二叉樹。
分類樹兩個基本思想:第一個是將訓練樣本進行遞迴地劃分自變數空間進行建樹的想法,第二個想法是用驗證資料進行剪枝。
建樹:在分類迴歸樹中,我們把類別集Result表示因變數,選取的屬性集attributelist表示自變數,通過遞迴的方式把attributelist把p維空間劃分為不重疊的矩形,具體建樹的基本步驟參見:http://baike.baidu.com/view/3075445.htm。
CART演算法是怎樣進行樣本劃分的呢?它檢查每個變數和該變數所有可能的劃分值來發現最好的劃分,對離散值如{x,y,x},則在該屬性上的劃分有三種情況({{x,y},{z}},{{x,z},y},{{y,z},x}),空集和全集的劃分除外;對於連續值處理引進“分裂點”的思想,假設樣本集中某個屬性共n個連續值,則有n-1個分裂點,每個“分裂點”為相鄰兩個連續值的均值 (a[i] + a[i+1]) / 2。將每個屬性的所有劃分按照他們能減少的雜質(合成物中的異質,不同成分)量來進行排序,雜質的減少被定義為劃分前的雜質減去劃分之後每個節點的雜質量*劃分所佔樣本比率之和,目前最流行的雜質度量方法是:GINI指標,如果我們用k,k=1,2,3……C表示類,其中C是類別集Result的因變數數目,一個節點A的GINI不純度定義為:
其中,Pk表示觀測點中屬於k類得概率,當Gini(A)=0時所有樣本屬於同一類,當所有類在節點中以相同的概率出現時,Gini(A)最大化,此時值為(C-1)C/2。
對於分類迴歸樹,A如果它不滿足“T都屬於同一類別or T中只剩下一個樣本”,則此節點為非葉節點,所以嘗試根據樣本的每一個屬性及可能的屬性值,對樣本的進行二元劃分,假設分類後A分為B和C,其中B佔A中樣本的比例為p,C為q(顯然p+q=1)。則雜質改變數:Gini(A) -p*Gini(B)-q*Gini(C),每次劃分該值應為非負,只有這樣劃分才有意義,對每個屬性值嘗試劃分的目的就是找到雜質該變數最大的一個劃分,該屬性值劃分子樹即為最優分支。
剪枝:在CART過程中第二個關鍵的思想是用獨立的驗證資料集對訓練集生長的樹進行剪枝。
分析分類迴歸樹的遞迴建樹過程,不難發現它實質上存在著一個數據過度擬合問題。在決策樹構造時,由於訓練資料中的噪音或孤立點,許多分枝反映的是訓練資料中的異常,使用這樣的判定樹對類別未知的資料進行分類,分類的準確性不高。因此試圖檢測和減去這樣的分支,檢測和減去這些分支的過程被稱為樹剪枝。樹剪枝方法用於處理過分適應資料問題。通常,這種方法使用統計度量,減去最不可靠的分支,這將導致較快的分類,提高樹獨立於訓練資料正確分類的能力。
決策樹常用的剪枝常用的簡直方法有兩種:事前剪枝和事後剪枝,CART演算法經常採用事後剪枝方法:該方法是通過在完全生長的樹上剪去分枝實現的,通過刪除節點的分支來剪去樹節點。最下面未被剪枝的節點成為樹葉。
CART用的成本複雜性標準是分類樹的簡單誤分(基於驗證資料的)加上一個對樹的大小的懲罰因素。懲罰因素是有引數的,我們用a表示,每個節點的懲罰。成本複雜性標準對於一個數來說是Err(T)+a|L(T)|,其中Err(T)是驗證資料被樹誤分部分,L(T)是樹T的葉節點樹,a是每個節點的懲罰成本:一個從0向上變動的數字。當a=0對樹有太多的節點沒有懲罰,用的成本複雜性標準是完全生長的沒有剪枝的樹。在剪枝形成的一系列樹中,從其中選擇一個在驗證資料集上具有最小誤分的樹是很自然的,我們把這個樹成為最小誤分樹。
演算法實現:
本文根據一個樣本集,進行了CART演算法的簡單實現。該樣本集中每個樣本有十六個特徵屬性和一個結果屬性,為了降低劃分的難度,每個特徵屬性取兩個不同的離散值,結果屬性有兩個離散值:Yes和No。
資料結構定義:在該演算法中定義了三種資料結構:儲存樣本屬性名稱及取值的Node屬性,儲存單個樣本的EXampleSet屬性,樹的節點屬性dataNode;存放在DataStructure.h中,程式碼如下:
typedef struct tagNode
{//儲存屬性
string name;//屬性的名稱
string value;//屬性取值
}Node;
typedef struct tagExampleSet
{//樣本儲存
string example[16];//樣本的每個屬性上的屬性值
string decision;//樣本的結果類
}ExampleSet;
typedef struct Data_Node{
//節點的資料結構,結果分為兩類yes類和No類
int Yesnum;//類yes得樣本數目
int Nonum;//類no得樣本數
vector<ExampleSet> myVector;//儲存樣本
Data_Node *LeftNode;//左子樹
Data_Node *RightNode;//右子樹
int Property;//劃分選取的屬性
string Proper_value;//所選的屬性的值
int nodenum;//標示節點
bool leavenode;//標示葉節點
}dataNode;
樣本讀取及處理:用兩個檔案分別儲存樣本的屬性及所有樣本。檔案t儲存樣本的十六個自變數屬性、類別屬性的名稱和離散值集合,檔案t1是所有樣本的集合,用ReadFile類讀取檔案,並把它們分別儲存在兩個向量中。建樹的過程在MySufan類中,該類地方法列表如下:
MySuanfa();
~MySuanfa();
void Method();//呼叫建樹、剪枝方法
void BuildTree(Data_Node*thisNode);//建樹方法,每次呼叫DeviceTree對非葉節點進行劃分
void DeviceTree(Data_Node*thisNode,int i);//對非葉結點進行劃分,分出左節點,有節點
int Choose_Property(Data_Node* thisNode);//返回選擇的屬性值
double pure(int i1,int i2,int i3);//純度計算函式,每次計算最優劃分時用
void Deal(Data_Node* d);//剪枝函式,此函式對建好的樹用測試樣本進行剪枝
void levelorder(Data_Node * p);//層次遍歷,此方法按曾給決策點分配序號,用於剪枝
void inorder(Data_Node *p);//中序遍歷,和建樹的前序遍歷用於確定樹的結構
void BuildTest(Data_Node *d,int t);//此方法用於計算當取不同決策點時,建樹樣本的錯誤樣本數,t為決策點數目
void CutTree(Data_Node *d,int k,int t);//k為單個樣本,t為決策點數,根據決策點對測試樣本集進行測試
void ClassOfNode(vector<ExampleSet>);//本方法用於切割原始樣本集,將樣本分為測試樣本和建樹樣本
遞迴建樹:建樹按照遞迴方式進行建樹,採用全部樣本的2/3進行建樹,首先找到一個劃分值,如果不存在返回-1,然後判斷一個樹是否為葉子節點,不為葉子節點按照劃分值進行劃分,關鍵程式碼如下:
void MySuanfa::BuildTree(Data_Node* thisNode)
{
if(thisNode!=NULL){// //節點不為空
nodenum++;
thisNode->nodenum=nodenum;
int getProperty=Choose_Property(thisNode);//找到劃分
thisNode->Property=getProperty;
if((thisNode->Yesnum*thisNode->Nonum==0)||getProperty==-1)
{//如果劃分為-1,則無法再次劃分
thisNode->Property=-1;
thisNode->leavenode=true;
}
else
{//遞迴建樹
thisNode->leavenode=false;
DeviceTree(thisNode,getProperty);//將父節點按照劃分屬性進行劃分
BuildTree(thisNode->LeftNode);//遞迴建立左子樹
BuildTree(thisNode->RightNode);//遞迴建立右子樹
}
}
}
分析上面程式碼,Choose_Property(thisNode);函式的作用是將thisNode中的樣本嘗試進行最優劃分,劃分的依據就是雜質最大該變數,如果劃分成功返回屬性下標,否則返回-1,我們在樣本中每個屬性預設取兩個離散值。注意到方法中對書中定義的leavenode和nodenum兩個變數的操作,他們的用途是什麼呢?nodenum的第一個作用是樹的遍歷,將每一個節點賦予一個唯一的值,建樹的過程是前序建樹,建樹結束後根據樹的中序遍歷可以唯一確定樹的結構,nodenum的第二個作用和leavenode的作用將會在剪枝過程中用到,後面將會提到。
當建樹結束後,樹的前序即為nodenum從小到大的排序,然後通過呼叫中序遍歷函式輸出樹的中序序列,確定樹的結構。該樹含有17個決策點(非葉子節點),18個葉子節點。
樹中決策點的劃分程式碼對應的屬性名稱:
0————handicapped-infants ; 1————water-project-cost-sharing
2————adoption-of-the-budget-resolution ; 3————physician-fee-freeze
4————el-salvador-aid ; 5————religious-groups-in-schools
6————anti-satellite-test-ban; 7————aid-to-nicaraguan-contras
8————mx-missile ; 9————immigration
10————synfuels-corporation-cutback ; 11————education-spending
12————superfund-right-to-sue ; 13————crime
14————duty-free-exports ; 15—export-administration-act-south-africa
按照遞迴分類的演算法,最終生成的樹的葉子節點中或者同屬一類或者只有一個樣本,分析樹的結構我們可以發現,有兩個葉子節點8和23不符合這種情況,卻成了葉子節點。這與所選樣本有關,在這兩個葉節點中兩個樣本的十六個特徵屬性值都相同,只有所屬類別不同,所以無法根據遞迴演算法進行分類。另當選取physician-fee-freeze 和adoption-of-the-budget-resolution兩種屬性進行決策時,樣本所屬的類別已經基本判定,造成這種情況我們可認為這兩種屬性在樣本中所佔的權重很大,只要確定這兩種情況,樹的大部分樣本的分類就已確定。
剪枝:用訓練樣本建樹結束後,就是進行樹的剪枝階段,本演算法採用樣本集的後1/3作為測試進行剪枝。
樹的決策點:如果一個節點為非葉節點,則稱該節點為一個樹的決策點。樹的剪枝就是減去過分擬合給樹帶來的的冗餘,用盡可能少的決策點、儘可能低的樹高獲取儘可能大的正確率。
如何獲取樹的決策點?逐層確定樹的決策點,並根據決策點數目進行剪枝是剪枝的關鍵。
根據二叉樹的特性可知樹的非葉節點=葉節點-1;所以可以從樹的節點數中得知樹種非葉結點的數量。本程式根據這一特性將樹的決策點逐層賦值,根節點賦值1,根節點的左節點賦值2……,這一過程通過層次遍歷實現。並將該值賦給nodenum,對於葉子節點nodenum為0關鍵程式碼如下:
void MySuanfa::levelorder(Data_Node* p)
{
int node=1;
list<Data_Node *>q;
if(p)q.push_back(p);
p->nodenum=node;
while(!q.empty())
{
p=q.front();
q.pop_front();
if(p->LeftNode)
{
if(p->LeftNode->leavenode)
{//如果該節點的左節點是子節點,則將nodenum賦0
p->LeftNode->nodenum=0;
}
else
{//否則將該節點賦一個node值,該值表示此決策點的順序
node++;
p->LeftNode->nodenum=node;
q.push_back(p->LeftNode);
}
}
if(p->RightNode)
{
if(p->RightNode->leavenode)//
{//如果該節點的右節點是子節點,則將nodenum賦0
p->RightNode->nodenum=0;
}
else
{//否則將該節點賦一個node值,該值表示此決策點的順序
node++;
p->RightNode->nodenum=node;
q.push_back(p->RightNode);
}
}
}
}
遍歷結束後,每一個決策點數目可以確定一個樹,我們就可以根據樹的決策點數對訓練樣本和測試樣本的誤差進行統計,怎樣根據決策點數確定樹的結構?可以將樹的前序遍歷進行改進,對於t個決策點,節點為0或大於t的都是葉子節點,一旦確定葉子節點,樹的結構就清楚了,下圖為重新賦值後的樹,在該圖中,如當有3個決策點時,2的子節點和3的子節點都是葉子節點,當用改進的前序遍歷便立時會輸出有3個決策點:(1,2,3);4個葉子節點(4,5,0,6)的子樹:
不同決策點可對應不同子樹,通過前序遍歷可以將葉子節點中的錯誤樣本統計出來計算該樹情況下錯誤樣本的個數,然後再用測試樣本遍歷樹,統計測試樣本再改樹下錯誤樣本個數最後得出結果集如下:
通過比較可知當樹有8和9個決策點時,測試誤差最小,我們取8,因為此時樹比9個決策點簡單,我們取含有8個決策點為最小誤分樹。最小誤分樹結構如下:
上圖中最小誤分樹非葉節點中的兩個值,第一個表示決策點表示,第二個表示選擇的屬性的程式碼,葉子節點中兩數表示每一類的數目。
我們定義最優剪枝的方法是在剪枝序列中含有誤差在最小誤差樹的一個標準差之內的最小樹,算出的最小誤差率被砍做一個帶有標準差等於的隨機變數的觀測值,其中Emin對最小誤差樹的錯誤率,Nval是驗證集的個數:Emin=5.41%,Nval=148,所以到當樹有4個決策點時,為最優剪枝。
參考:
1.https://blog.csdn.net/robin_Xu_shuai/article/details/74011205
2.https://blog.csdn.net/qq_36330643/article/details/77415451
3.https://www.cnblogs.com/yjd_hycf_space/p/6940068.html
4.https://www.cnblogs.com/sumuncle/p/5610877.html
5.https://blog.csdn.net/jbfsdzpp/article/details/44036349
6.https://blog.csdn.net/happyblogs/article/details/6843520#