Weka演算法Classifier-trees-RandomTree原始碼分析
阿新 • • 發佈:2019-01-12
一、RandomTree演算法
在網上搜了一下,並沒有找到RandomTree的嚴格意義上的演算法描述,因此我覺得RandomTree充其量只是一種構建樹的思路,和普通決策樹相比,RandomTree會隨機的選擇若干屬性來進行構建而不是選取所有的屬性。
Weka在實現上,對於隨機屬性的選取、生成分裂點的過程是這樣的:
1、設定一個要選取的屬性的數量K
2、在全域屬性中無放回的對屬性進行抽樣
3、算出該屬性的資訊增益(注意不是資訊增益率)
4、重複K次,選出資訊增益最大的當分裂節點。
5、構建該節點的孩子子樹。
二、具體程式碼實現
(1)buildClassifier
- publicvoid buildClassifier(Instances data) throws Exception {
- // 如果傳入的K不合理,把K放到一個合理的範圍裡
- if (m_KValue > data.numAttributes() - 1)
- m_KValue = data.numAttributes() - 1;
- if (m_KValue < 1)
- m_KValue = (int) Utils.log2(data.numAttributes()) + 1;//這個是K的預設值
-
// 判斷一下該分類器是否有能力處理這個資料集,如果沒能力直接就在testWithFail裡拋異常退出了
- getCapabilities().testWithFail(data);
- // 刪除掉missClass
- data = new Instances(data);
- data.deleteWithMissingClass();
- // 如果只有一列,就build一個ZeroR模型,之後就結束了。ZeroR模型分類是這樣的:如果是連續型,總是返回期望,如果離散型,總是返回訓練集中出現最多的那個
- if (data.numAttributes() == 1) {
- System.err
-
.println("Cannot build model (only class attribute present in data!), "
- + "using ZeroR model instead!");
- m_zeroR = new weka.classifiers.rules.ZeroR();
- m_zeroR.buildClassifier(data);
- return;
- } else {
- m_zeroR = null;
- }
- // 如果m_NumFlods大於0,則會把資料集分為兩部分,一部分用於train,一部分用於test,也就是backfit
- //分的方式和多折交叉驗證是一樣的,例如m_NumFlods是10的話,則train佔90%,backfit佔10%
- Instances train = null;
- Instances backfit = null;
- Random rand = data.getRandomNumberGenerator(m_randomSeed);
- if (m_NumFolds <= 0) {
- train = data;
- } else {
- data.randomize(rand);
- data.stratify(m_NumFolds);
- train = data.trainCV(m_NumFolds, 1, rand);
- backfit = data.testCV(m_NumFolds, 1);
- }
- // 生成所有的可選屬性
- int[] attIndicesWindow = newint[data.numAttributes() - 1];
- int j = 0;
- for (int i = 0; i < attIndicesWindow.length; i++) {
- if (j == data.classIndex())
- j++; // 忽略掉classIndex
- attIndicesWindow[i] = j++;//這段程式碼有點奇怪,i和j是相等的,為啥不用attIndicesWindow=i?
- }
- // 算出每個class的頻率,也就是每個分類出現的次數(更正確的說法應該是權重,但權重預設都是1)
- double[] classProbs = newdouble[train.numClasses()];
- for (int i = 0; i < train.numInstances(); i++) {
- Instance inst = train.instance(i);
- classProbs[(int) inst.classValue()] += inst.weight();
- }
- // Build tree
- m_Tree = new Tree();
- m_Info = new Instances(data, 0);
- m_Tree.buildTree(train, classProbs, attIndicesWindow, rand, 0);//呼叫tree的build方法,在後面單獨分析
- // Backfit if required
- if (backfit != null) {
- m_Tree.backfitData(backfit);//在後面單獨分析
- }
- }
這個Tree物件是RandomTree的一個子類,之前我還以為會複用其餘的決策樹模型(比如J48),但weka沒這麼做,很驚奇的是RandomTree和J48的作者還是同一個,不知道為啥這麼設計。
(2)tree.buildTree
- protectedvoid buildTree(Instances data, double[] classProbs,
- int[] attIndicesWindow, Random random, int depth) throws Exception {
- //首先判斷一下是否有instance,如果沒有的話直接就返回
- if (data.numInstances() == 0) {
- m_Attribute = -1;
- m_ClassDistribution = null;
- m_Prop = null;
- return;
- }
- m_ClassDistribution = classProbs.clone();
- if (Utils.sum(m_ClassDistribution) < 2 * m_MinNum
- || Utils.eq(m_ClassDistribution[Utils.maxIndex(m_ClassDistribution)],
- Utils.sum(m_ClassDistribution))
- || ((getMaxDepth() > 0) && (depth >= getMaxDepth()))) {
- // 遞迴結束的條件有3個 1、instance數量小於2*m_Minnum 2、instance都已經在同一個類中 3、達到最大的深度
- //前兩個條件和j48的遞迴結束條件很相似,相關內容可參考我之前的幾篇部落格。
- m_Attribute = -1;
- m_Prop = null;
- return;
- }
- double val = -Double.MAX_VALUE;
- double split = -Double.MAX_VALUE;
- double[][] bestDists = null;
- double[] bestProps = null;
- int bestIndex = 0;
- double[][] props = newdouble[1][0];
- double[][][] dists = newdouble[1][0][0];//這個陣列第一列只有下標為0的被用到,不知道為啥這麼設計
- int attIndex = 0;//儲存被選擇到的屬性
- int windowSize = attIndicesWindow.length;//儲存目前可選擇的屬性的數量
- int k = m_KValue;//k代表還能選擇的屬性的數量
- boolean gainFound = false;//是否發現了一個有資訊增益的節點
- while ((windowSize > 0) && (k-- > 0 || !gainFound)) {//此迴圈退出條件有2個 1、沒有節點可以選了 2、已經選了k個屬性了並且找到了一個有用的屬性 換句話說,如果K次迭代沒有找到可以分裂的隨機節點,迴圈也會繼續下去