寶具滑 / JS簡單實現決策樹(ID3演算法)
阿新 • 發佈:2019-01-23
// Article: https://www.jianshu.com/p/2b50a98cd75c

/**
 * Binary decision-tree classifier that picks splits by information gain
 * (ID3 style).
 *
 * Every internal node tests a single attribute against a pivot value taken
 * from the training data: numeric pivots use `>=`, every other pivot uses
 * loose `==`. Leaves carry the majority category of the rows that reached
 * them.
 *
 * NOTE: categories are tallied as object keys, so the category stored in a
 * leaf (and returned by `predict`) is always a string, even when the training
 * rows hold numbers or booleans.
 *
 * @param {object} [config] - Optional training configuration; when a
 *   non-null, non-array object is given the tree is trained immediately
 *   (see `training`).
 */
function DecisionTree(config) {
  // `typeof null` is "object", so null has to be excluded explicitly.
  if (typeof config === 'object' && config !== null && !Array.isArray(config)) {
    this.training(config);
  }
}

DecisionTree.prototype = {
  /** Split predicates, keyed by the name stored in each tree node. */
  _predicates: {
    // Used for non-numeric pivots. Loose equality is deliberate: the
    // predicate is published under the name '=='.
    '==': function (a, b) { return a == b; },
    // Used for numeric pivots.
    '>=': function (a, b) { return a >= b; },
  },

  /**
   * Counts how often each distinct value of `attr` occurs in `items`.
   *
   * @param {object[]} items - Rows to scan.
   * @param {string} attr - Attribute name to tally.
   * @returns {Object<string, number>} Map from (stringified) value to count.
   */
  countUniqueValues(items, attr) {
    const counter = {};
    for (const item of items) {
      const key = item[attr];
      // Own-property check: values such as "constructor" or "toString" must
      // not pick up the members inherited from Object.prototype.
      const seen = Object.prototype.hasOwnProperty.call(counter, key);
      counter[key] = (seen ? counter[key] : 0) + 1;
    }
    return counter;
  },

  /**
   * Returns the key with the largest value, e.g. `{a: 9, b: 2}` gives "a".
   * On ties the key enumerated first wins.
   *
   * @param {Object<string, number>} counter - Value-to-count map.
   * @returns {string|undefined} Winning key, or undefined for an empty map.
   */
  getMaxKey(counter) {
    let maxKey;
    for (const key in counter) {
      // Compare against undefined rather than falsiness so that an
      // empty-string key can be the winner.
      if (maxKey === undefined || counter[key] > counter[maxKey]) {
        maxKey = key;
      }
    }
    return maxKey;
  },

  /**
   * Finds the most frequent value of `attr` among `items`.
   *
   * @param {object[]} items - Rows to scan.
   * @param {string} attr - Attribute name.
   * @returns {string|undefined} Most frequent value as a string, or
   *   undefined when `items` is empty.
   */
  mostFrequentValue(items, attr) {
    return this.getMaxKey(this.countUniqueValues(items, attr));
  },

  /**
   * Partitions `items` by whether `predicate(item[attr], pivot)` holds.
   *
   * @param {object[]} items - Rows to partition.
   * @param {string} attr - Attribute to test.
   * @param {function(*, *): boolean} predicate - Test function.
   * @param {*} pivot - Second argument handed to the predicate.
   * @returns {{match: object[], notMatch: object[]}} The two partitions.
   */
  split(items, attr, predicate, pivot) {
    const data = {
      match: [],    // rows that satisfy the predicate
      notMatch: [], // rows that do not
    };
    for (const item of items) {
      if (predicate(item[attr], pivot)) {
        data.match.push(item);
      } else {
        data.notMatch.push(item);
      }
    }
    return data;
  },

  /**
   * Shannon entropy of the distribution of `attr` over `items`:
   * H(S) = -sum(P(x) * log2(P(x))).
   *
   * @param {object[]} items - Rows to measure.
   * @param {string} attr - Attribute whose distribution is measured.
   * @returns {number} Entropy in bits; 0 for an empty or pure set.
   */
  entropy(items, attr) {
    const counter = this.countUniqueValues(items, attr);
    let entropy = 0;
    for (const value in counter) {
      const p = counter[value] / items.length; // P(x)
      entropy += -p * Math.log2(p);
    }
    return entropy;
  },

  /**
   * Recursively builds a (sub)tree for `config.trainingSet`.
   *
   * The passed `config` is never modified; every child receives its own copy
   * with `maxTreeDepth` reduced by one, so the depth limit applies to each
   * root-to-leaf path independently.
   *
   * @param {object} config
   * @param {object[]} config.trainingSet - Rows to learn from.
   * @param {string} config.categoryAttr - Attribute holding the class label.
   * @param {number} config.minItemsCount - Sets of this size or smaller
   *   become leaves.
   * @param {number} config.entropyThrehold - Sets with entropy at or below
   *   this become leaves (spelling kept for compatibility).
   * @param {number} config.maxTreeDepth - Remaining depth budget; 0 forces a
   *   leaf.
   * @param {string[]} [config.ignoredAttributes] - Attributes never used for
   *   splitting.
   * @returns {object} Either a leaf `{category}` or an internal node
   *   `{attribute, predicate, predicateName, pivot, match, notMatch,
   *   matchedCount, notMatchedCount}`.
   */
  buildDecisionTree(config) {
    const trainingSet = config.trainingSet;
    const minItemsCount = config.minItemsCount;
    const categoryAttr = config.categoryAttr;
    const entropyThrehold = config.entropyThrehold;
    const maxTreeDepth = config.maxTreeDepth;
    const ignoredAttributes = config.ignoredAttributes || [];

    const makeLeaf = () => ({
      category: this.mostFrequentValue(trainingSet, categoryAttr),
    });

    // Stop when the depth budget is used up or the set is small enough.
    if (maxTreeDepth === 0 || trainingSet.length <= minItemsCount) {
      return makeLeaf();
    }

    // H(S): entropy of the set before splitting. Stop when it is already
    // (almost) pure.
    const initialEntropy = this.entropy(trainingSet, categoryAttr);
    if (initialEntropy <= entropyThrehold) {
      return makeLeaf();
    }

    const alreadyChecked = new Set(); // attribute/predicate/pivot combos tried
    let bestSplit = null;
    let bestGain = 0;

    for (const item of trainingSet) {
      for (const attr in item) {
        // Never split on the label itself or on ignored attributes.
        if (attr == categoryAttr || ignoredAttributes.includes(attr)) continue;

        const pivot = item[attr];
        const predicateName = typeof pivot === 'number' ? '>=' : '==';

        const attrPredPivot = attr + predicateName + pivot;
        if (alreadyChecked.has(attrPredPivot)) continue;
        alreadyChecked.add(attrPredPivot);

        const predicate = this._predicates[predicateName];
        const currSplit = this.split(trainingSet, attr, predicate, pivot);

        // A split that leaves one side empty separates nothing. Its true
        // gain is zero, but rounding in the formula below can make it look
        // slightly positive, so reject it up front.
        if (currSplit.match.length === 0 || currSplit.notMatch.length === 0) {
          continue;
        }

        const matchEntropy = this.entropy(currSplit.match, categoryAttr);
        const notMatchEntropy = this.entropy(currSplit.notMatch, categoryAttr);

        // Information gain: IG(A, S) = H(S) - sum(P(t) * H(t)), where t
        // ranges over the two partitions and P(t) = |t| / |S|.
        const iGain = initialEntropy -
          ((matchEntropy * currSplit.match.length +
            notMatchEntropy * currSplit.notMatch.length) / trainingSet.length);

        if (iGain > bestGain) {
          bestGain = iGain;
          bestSplit = currSplit;
          bestSplit.predicateName = predicateName;
          bestSplit.predicate = predicate;
          bestSplit.attribute = attr;
          bestSplit.pivot = pivot;
          bestSplit.gain = iGain;
        }
      }
    }

    // No split improves on the current set.
    if (bestSplit === null) {
      return makeLeaf();
    }

    // Each branch gets its own config copy so that building the match
    // subtree cannot consume the depth budget of the notMatch subtree.
    const childDepth = maxTreeDepth - 1;
    const matchSubTree = this.buildDecisionTree(Object.assign({}, config, {
      maxTreeDepth: childDepth,
      trainingSet: bestSplit.match,
    }));
    const notMatchSubTree = this.buildDecisionTree(Object.assign({}, config, {
      maxTreeDepth: childDepth,
      trainingSet: bestSplit.notMatch,
    }));

    return {
      attribute: bestSplit.attribute,
      predicate: bestSplit.predicate,
      predicateName: bestSplit.predicateName,
      pivot: bestSplit.pivot,
      match: matchSubTree,
      notMatch: notMatchSubTree,
      matchedCount: bestSplit.match.length,
      notMatchedCount: bestSplit.notMatch.length,
    };
  },

  /**
   * Trains the tree and stores it in `this.root`.
   *
   * @param {object} config
   * @param {object[]} config.trainingSet - Rows to learn from (required).
   * @param {string[]} [config.ignoredAttributes=[]] - Attributes to skip,
   *   e.g. names or ids.
   * @param {string} [config.categoryAttr='category'] - Label attribute.
   * @param {number} [config.minItemsCount=1] - Leaf size threshold.
   * @param {number} [config.entropyThrehold=0.01] - Leaf entropy threshold.
   * @param {number} [config.maxTreeDepth=70] - Maximum tree depth.
   * @throws {TypeError} When `config.trainingSet` is not an array.
   */
  training(config) {
    if (!Array.isArray(config.trainingSet)) {
      throw new TypeError('config.trainingSet must be an array');
    }
    // Fall back only on null/undefined so that an explicit 0 is honoured.
    const withDefault = (value, fallback) =>
      (value === undefined || value === null ? fallback : value);

    this.root = this.buildDecisionTree({
      trainingSet: config.trainingSet,
      ignoredAttributes: config.ignoredAttributes || [],
      categoryAttr: config.categoryAttr || 'category',
      minItemsCount: withDefault(config.minItemsCount, 1),
      entropyThrehold: withDefault(config.entropyThrehold, 0.01),
      maxTreeDepth: withDefault(config.maxTreeDepth, 70),
    });
  },

  /**
   * Classifies one row by walking the trained tree.
   *
   * @param {object} data - Row to classify; attributes missing from the row
   *   are compared as `undefined`.
   * @returns {string|undefined} Predicted category (undefined only when the
   *   tree was trained on an empty set).
   * @throws {Error} When called before the tree has been trained.
   */
  predict(data) {
    let tree = this.root;
    if (!tree) {
      throw new Error('DecisionTree.predict() called before training()');
    }
    // Only leaves own a `category` property; testing for the property rather
    // than its truthiness keeps leaves with "" or undefined categories from
    // being mistaken for internal nodes.
    while (!Object.prototype.hasOwnProperty.call(tree, 'category')) {
      const value = data[tree.attribute];
      tree = tree.predicate(value, tree.pivot) ? tree.match : tree.notMatch;
    }
    return tree.category;
  },
};

// ---------------------------------------------------------------------------
// Demo
// ---------------------------------------------------------------------------

const data = [
  { "姓名": "餘夏", "年齡": 29, "長相": "帥", "體型": "瘦", "收入": "高", 見面: "見" },
  { "姓名": "豆豆", "年齡": 25, "長相": "帥", "體型": "瘦", "收入": "高", 見面: "見" },
  { "姓名": "帥常榮", "年齡": 26, "長相": "帥", "體型": "胖", "收入": "高", 見面: "見" },
  { "姓名": "王濤", "年齡": 22, "長相": "帥", "體型": "瘦", "收入": "高", 見面: "見" },
  { "姓名": "李東", "年齡": 23, "長相": "帥", "體型": "瘦", "收入": "高", 見面: "見" },
  { "姓名": "王五五", "年齡": 23, "長相": "帥", "體型": "瘦", "收入": "低", 見面: "見" },
  { "姓名": "王小濤", "年齡": 22, "長相": "帥", "體型": "瘦", "收入": "低", 見面: "見" },
  { "姓名": "李繽", "年齡": 21, "長相": "帥", "體型": "胖", "收入": "高", 見面: "見" },
  { "姓名": "劉明", "年齡": 21, "長相": "帥", "體型": "胖", "收入": "低", 見面: "不見" },
  { "姓名": "紅鶴", "年齡": 21, "長相": "不帥", "體型": "胖", "收入": "高", 見面: "不見" },
  { "姓名": "李理", "年齡": 32, "長相": "帥", "體型": "瘦", "收入": "高", 見面: "不見" },
  { "姓名": "周州", "年齡": 31, "長相": "帥", "體型": "瘦", "收入": "高", 見面: "不見" },
  { "姓名": "李樂", "年齡": 27, "長相": "不帥", "體型": "胖", "收入": "高", 見面: "不見" },
  { "姓名": "韓明", "年齡": 24, "長相": "不帥", "體型": "瘦", "收入": "高", 見面: "不見" },
  { "姓名": "小呂", "年齡": 28, "長相": "帥", "體型": "瘦", "收入": "低", 見面: "不見" },
  { "姓名": "李四", "年齡": 25, "長相": "帥", "體型": "瘦", "收入": "低", 見面: "不見" },
  { "姓名": "王鵬", "年齡": 30, "長相": "帥", "體型": "瘦", "收入": "低", 見面: "不見" },
];

const decisionTree = new DecisionTree();

// Exercise the helper methods on the raw data.
console.log("函式 countUniqueValues 測試:");
console.log(" 長相", decisionTree.countUniqueValues(data, "長相"));
console.log(" 年齡", decisionTree.countUniqueValues(data, "年齡"));
console.log(" 收入", decisionTree.countUniqueValues(data, "收入"));

console.log("函式 entropy 測試:");
console.log(" 長相", decisionTree.entropy(data, "長相"));
console.log(" 年齡", decisionTree.entropy(data, "年齡"));
console.log(" 收入", decisionTree.entropy(data, "收入"));

console.log("函式 mostFrequentValue 測試:");
console.log(" 年齡", decisionTree.mostFrequentValue(data, "年齡"));
console.log(" 長相", decisionTree.mostFrequentValue(data, "長相"));
console.log(" 收入", decisionTree.mostFrequentValue(data, "收入"));

console.log("函式 split 測試:");
console.log(" 長相", decisionTree.split(data, "長相", (a, b) => { return a == b }, "不帥"));
console.log(" 年齡", decisionTree.split(data, "年齡", (a, b) => { return a >= b }, 30));
console.log(" 年齡", decisionTree.split(data, "年齡", (a, b) => { return a < b }, 25));

decisionTree.training({
  trainingSet: data,          // rows to learn from
  categoryAttr: '見面',        // label attribute
  ignoredAttributes: ['姓名'], // names carry no predictive signal
});

// Classify a few unseen samples with the trained tree.
const samples = [
  { "姓名": "劉建1", "年齡": 21, "長相": "帥", "體型": "瘦", "收入": "高" },
  { "姓名": "劉建2", "年齡": 22, "長相": "不帥", "體型": "瘦", "收入": "高" },
  { "姓名": "劉建3", "年齡": 27, "長相": "帥", "體型": "瘦", "收入": "低" },
  { "姓名": "劉建4", "年齡": 30, "長相": "帥", "體型": "瘦", "收入": "高" },
  { "姓名": "劉建5", "年齡": 29, "長相": "帥", "體型": "胖", "收入": "高" },
  { "姓名": "劉建6", "年齡": 29, "長相": "帥", "體型": "胖", "收入": "低" },
  { "姓名": "劉建7", "年齡": 40, "長相": "帥", "體型": "瘦", "收入": "低" },
];
for (const comic of samples) {
  console.log(comic, decisionTree.predict(comic));
}