// Decision tree algorithm written in Java (a data mining algorithm)
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.Iterator;
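// Four errors were found while debugging, thanks to the almighty debugging tool: print.
// 1. A wrong array index in selectAtrribute  2. The equality test between two strings
// 3. One mistake in the input data  4. A forgotten i++ in the last loop of selectAtrribute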
/**
 * A node of the decision tree.
 *
 * <p>All three fields start out as {@code null} unless a value is passed to the constructor.
 */
class TreeNode {
    /**
     * Label of this node. DecisionTree.buildDecisionTree stores the name of the splitting
     * attribute here for internal nodes and the class label for leaf nodes.
     */
    String element;

    /** Value that the parent's splitting attribute takes on the branch leading to this node. */
    String value;

    /** Child nodes, kept in insertion order; {@code null} when none have been set. */
    LinkedHashSet<TreeNode> childs;

    /** Creates a node with no label, no branch value and no children. */
    public TreeNode() {
        this((String) null);
    }

    /**
     * Creates a node reached through the given branch value.
     *
     * @param value value of the parent's splitting attribute for this node
     */
    public TreeNode(String value) {
        // element and childs keep their default of null.
        this.value = value;
    }

    /** @return the label of this node, or {@code null} if it has not been set */
    public String getElement() {
        return element;
    }

    /** @param e the new label of this node */
    public void setElement(String e) {
        element = e;
    }

    /** @return the branch value leading to this node, or {@code null} if it has not been set */
    public String getValue() {
        return value;
    }

    /** @param v the new branch value leading to this node */
    public void setValue(String v) {
        value = v;
    }

    /** @return the child set exactly as stored (not a copy), or {@code null} if none was set */
    public LinkedHashSet<TreeNode> getChilds() {
        return childs;
    }

    /** @param childs the child set to store; the reference is kept as-is */
    public void setChilds(LinkedHashSet<TreeNode> childs) {
        this.childs = childs;
    }
}
/**
 * Decision tree over a table of string-valued records, grown by repeatedly splitting on the
 * attribute with the largest information gain.
 *
 * <p>Every row of the data table holds the attribute values followed by the class label in the
 * last column. A {@code boolean[]} mask with one entry per row selects the rows that belong to
 * the node currently being processed. The table must contain at least one row, because the
 * class column index is derived from {@code deData[0]}.
 */
class DecisionTree {
    /** Root node of the tree. */
    TreeNode root;

    /** Creates a tree whose root is a fresh, empty node. */
    public DecisionTree() {
        root = new TreeNode();
    }

    /**
     * Creates a tree with the given root node.
     *
     * @param root node to use as the root
     */
    public DecisionTree(TreeNode root) {
        this.root = root;
    }

    /** @return the root node of the tree */
    public TreeNode getRoot() {
        return root;
    }

    /** @param root the new root node of the tree */
    public void setRoot(TreeNode root) {
        this.root = root;
    }

    /**
     * Picks the attribute with the largest information gain over the selected rows.
     *
     * @param node the node being split; not used by the computation
     * @param deData data table, one row per record, class label in the last column
     * @param flags row mask; only rows whose flag is {@code true} are considered
     * @param atrributes candidate attributes that have not been used for a split yet
     * @param attrIndexMap maps each attribute name to its column index in {@code deData}
     * @return the candidate with the largest gain (the first one in iteration order on ties),
     *     or {@code null} only when {@code atrributes} is empty
     */
    public String selectAtrribute(TreeNode node, String[][] deData, boolean flags[],
            LinkedHashSet<String> atrributes, HashMap<String,Integer> attrIndexMap) {
        // gain[c] is the information gain of the c-th candidate attribute.
        double[] gain = new double[atrributes.size()];
        // The class label is the last value of every row.
        int classIndex = deData[0].length - 1;

        int count = 0;
        for (String atrribute : atrributes) {
            int index = attrIndexMap.get(atrribute);

            // Distinct attribute values and class labels among the selected rows.
            LinkedHashSet<String> values = new LinkedHashSet<String>();
            LinkedHashSet<String> classes = new LinkedHashSet<String>();
            for (int i = 0; i < deData.length; i++) {
                if (flags[i]) {
                    values.add(deData[i][index]);
                    classes.add(deData[i][classIndex]);
                }
            }
            int valuesCount = values.size();
            int classCount = classes.size();

            // valuesVector is a flattened (value x class) contingency table;
            // classVector holds the number of selected rows per class.
            int[] valuesVector = new int[valuesCount * classCount];
            int[] classVector = new int[classCount];
            for (int i = 0; i < deData.length; i++) {
                if (flags[i]) {
                    int j = positionOf(values, deData[i][index]);
                    int k = positionOf(classes, deData[i][classIndex]);
                    valuesVector[j * classCount + k]++;
                    classVector[k]++;
                }
            }

            // InfoD: entropy of the class distribution over the selected rows.
            double infoD = 0.0;
            double classTotal = 0.0;
            for (int i = 0; i < classVector.length; i++) {
                classTotal += classVector[i];
            }
            for (int i = 0; i < classVector.length; i++) {
                if (classVector[i] != 0) {
                    double d = Math.log(classVector[i] / classTotal) / Math.log(2.0)
                            * classVector[i] / classTotal;
                    infoD = infoD - d;
                }
            }

            // InfoA: entropy remaining after the split, weighted by partition size.
            double infoA = 0.0;
            for (int i = 0; i < valuesCount; i++) {
                double middle = 0.0;
                double attrCount = 0.0;
                for (int j = 0; j < classCount; j++) {
                    attrCount += valuesVector[i * classCount + j];
                }
                for (int j = 0; j < classCount; j++) {
                    if (valuesVector[i * classCount + j] != 0) {
                        double k = valuesVector[i * classCount + j];
                        middle = middle - Math.log(k / attrCount) / Math.log(2.0) * k / attrCount;
                    }
                }
                infoA += middle * attrCount / classTotal;
            }

            gain[count] = infoD - infoA;
            count++;
        }

        // Start below every possible gain so that an attribute is still returned when no
        // candidate has a strictly positive gain; starting at 0.0 returned null in that case.
        String best = null;
        double max = Double.NEGATIVE_INFINITY;
        int i = 0;
        for (String atrribute : atrributes) {
            if (gain[i] > max) {
                max = gain[i];
                best = atrribute;
            }
            i++;
        }
        return best;
    }

    /**
     * Recursively builds the subtree rooted at {@code node} and prints one line per node.
     *
     * <p>The node becomes a leaf labelled with the majority class when no candidate attribute
     * is left, and a leaf labelled with the single remaining class when the selected rows all
     * share one class (the label is {@code null} when no row is selected). Otherwise the node
     * is labelled with the chosen attribute and gets one child per distinct value of that
     * attribute among the selected rows.
     *
     * @param node node at which the subtree is built
     * @param deData data table, one row per record, class label in the last column
     * @param flags row mask; only rows whose flag is {@code true} belong to this node
     * @param attributes candidate attributes not yet used on the path to this node; this set
     *     is not modified
     * @param attrIndexMap maps each attribute name to its column index in {@code deData}
     */
    public void buildDecisionTree(TreeNode node, String[][] deData, boolean flags[],
            LinkedHashSet<String> attributes, HashMap<String,Integer> attrIndexMap) {
        int classIndex = deData[0].length - 1;

        // No attribute left to split on: label the node with the majority class.
        if (attributes.isEmpty()) {
            node.setElement(majorityClass(deData, flags, classIndex));
            node.setChilds(null);
            System.out.println("yezhi:" + node.getElement() + ":" + node.getValue());
            return;
        }

        // Collect the classes present among the selected rows.
        String className = null;
        HashSet<String> classSet = new HashSet<String>();
        for (int i = 0; i < deData.length; i++) {
            if (flags[i]) {
                className = deData[i][classIndex];
                classSet.add(className);
            }
        }
        // A single class (or no selected row at all) cannot be split any further.
        if (classSet.size() <= 1) {
            node.setElement(className);
            node.setChilds(null);
            System.out.println("leaf:" + node.getElement() + ":" + node.getValue());
            return;
        }

        // Choose the splitting attribute and label the node with it.
        String attribute = selectAtrribute(node, deData, flags, attributes, attrIndexMap);
        node.setElement(attribute);
        if (node == root) {
            System.out.println("root:" + node.getElement() + ":" + node.getValue());
        } else {
            System.out.println("branch:" + node.getElement() + ":" + node.getValue());
        }

        // Create one child per distinct value of the splitting attribute.
        int attrIndex = attrIndexMap.get(attribute);
        LinkedHashSet<String> attrValues = new LinkedHashSet<String>();
        for (int i = 0; i < deData.length; i++) {
            if (flags[i]) {
                attrValues.add(deData[i][attrIndex]);
            }
        }
        LinkedHashSet<TreeNode> childs = new LinkedHashSet<TreeNode>();
        for (String attrValue : attrValues) {
            childs.add(new TreeNode(attrValue));
        }
        node.setChilds(childs);

        // Drop the used attribute from a copy so the caller's set stays untouched.
        LinkedHashSet<String> remaining = new LinkedHashSet<String>(attributes);
        remaining.remove(attribute);

        for (TreeNode child : childs) {
            // A row belongs to the child when it belongs to this node and carries the child's
            // value. The comparison must use equals(): != compares references and only works
            // by accident for interned string literals.
            boolean[] newFlags = new boolean[deData.length];
            for (int i = 0; i < deData.length; i++) {
                newFlags[i] = flags[i] && child.getValue().equals(deData[i][attrIndex]);
            }
            // Each child receives its own copy of the remaining candidates.
            LinkedHashSet<String> newAttributes = new LinkedHashSet<String>(remaining);
            buildDecisionTree(child, deData, newFlags, newAttributes, attrIndexMap);
        }
    }

    /** Prints the decision tree; currently a stub that does nothing. */
    public void printDecisionTree() {
    }

    /** Returns the zero-based position of {@code target} in the iteration order of {@code items}. */
    private static int positionOf(LinkedHashSet<String> items, String target) {
        int position = 0;
        for (String item : items) {
            if (target.equals(item)) {
                break;
            }
            position++;
        }
        return position;
    }

    /** Returns the most frequent class label among the selected rows, or null if none is selected. */
    private static String majorityClass(String[][] deData, boolean[] flags, int classIndex) {
        HashMap<String,Integer> classCounts = new HashMap<String,Integer>();
        for (int i = 0; i < deData.length; i++) {
            if (flags[i]) {
                String label = deData[i][classIndex];
                Integer seen = classCounts.get(label);
                classCounts.put(label, seen == null ? 1 : seen + 1);
            }
        }
        String mostClass = null;
        int mostCount = 0;
        for (String label : classCounts.keySet()) {
            int labelCount = classCounts.get(label);
            if (labelCount > mostCount) {
                mostClass = label;
                mostCount = labelCount;
            }
        }
        return mostClass;
    }
}
/**
 * Demo entry point: builds a decision tree from a small hard-coded data set. The nodes are
 * printed to standard output by DecisionTree.buildDecisionTree while the tree is built.
 */
public class DeTree {
    public static void main(String[] args) {
        /*
        // Alternative input: data set 1
        String[][] deData = {
            {"Yes","No","No","Yes","Some","high","No","Yes","French","0~10","Yes"},
            {"Yes","No","No","Yes","Full","low","No","No","Thai","30~60","No"},
            {"No","Yes","No","No","Some","low","No","No","Burger","0~10","Yes"},
            {"Yes","No","Yes","Yes","Full","low","Yes","No","Thai","10~30","Yes"},
            {"Yes","No","Yes","No","Full","high","No","Yes","French",">60","No"},
            {"No","Yes","No","Yes","Some","middle","Yes","Yes","Italian","0~10","Yes"},
            {"No","Yes","No","No","None","low","Yes","No","Burger","0~10","No"},
            {"No","No","No","Yes","Some","middle","Yes","Yes","Thai","0~10","Yes"},
            {"No","Yes","Yes","No","Full","low","Yes","No","Burger",">60","No"},
            {"Yes","Yes","Yes","Yes","Full","high","No","Yes","Italian","10~30","No"},
            {"No","No","No","No","None","low","No","No","Thai","0~10","No"},
            {"Yes","Yes","Yes","Yes","Full","low","No","No","Burger","30~60","Yes"},
        };
        // Candidate attributes for data set 1
        String[] attr = {"alt", "bar", "fri", "hun", "pat", "price", "rain", "res", "type", "est"};
        */

        // Input: data set 2. The last value of every row is the class label.
        String[][] deData = {
            {"youth", "high", "no", "fair", "no"},
            {"youth", "high", "no", "excellent", "no"},
            {"middle_aged", "high", "no", "fair", "yes"},
            {"senior", "medium", "no", "fair", "yes"},
            {"senior", "low", "yes", "fair", "yes"},
            {"senior", "low", "yes", "excellent", "no"},
            {"middle_aged", "low", "yes", "excellent", "yes"},
            {"youth", "medium", "no", "fair", "no"},
            {"youth", "low", "yes", "fair", "yes"},
            {"senior", "medium", "yes", "fair", "yes"},
            {"youth", "medium", "yes", "excellent", "yes"},
            {"middle_aged", "medium", "no", "excellent", "yes"},
            {"middle_aged", "high", "yes", "fair", "yes"},
            {"senior", "medium", "no", "excellent", "no"},
        };

        // Candidate attributes for data set 2, listed in column order.
        String[] attr = {"age", "income", "student", "credit_rating"};

        // Fill the candidate set and the attribute-name -> column-index map in one pass.
        LinkedHashSet<String> attributes = new LinkedHashSet<String>();
        HashMap<String,Integer> attrIndexMap = new HashMap<String,Integer>();
        for (int column = 0; column < attr.length; column++) {
            String name = attr[column];
            attributes.add(name);
            attrIndexMap.put(name, column);
        }

        // Initially every row of the data set takes part in the classification.
        boolean[] flags = new boolean[deData.length];
        for (int row = 0; row < flags.length; row++) {
            flags[row] = true;
        }

        // Build the decision tree starting from an empty root node.
        TreeNode root = new TreeNode();
        DecisionTree decisionTree = new DecisionTree(root);
        decisionTree.buildDecisionTree(root, deData, flags, attributes, attrIndexMap);
    }
}