Naive Bayes 樸素貝葉斯的JAVA程式碼實現
阿新 • • 發佈:2019-01-25
1.關於貝葉斯分類
bayes 是一種統計學分類方法,它基於貝葉斯定理,它假定一個屬性值對給定類的影響獨立於其它屬性點的值。該假定稱作類條件獨立。做次假定是為了簡化所需計算,並在此意義下稱為“樸素的”。
bayes分類的演算法大致如下:
(1)對於屬性值是離散的,並且目標label值也是離散的情況下。分別計算label不同取值的概率,以及樣本在label情況下的概率值,然後將這些概率值相乘最後得到一個概率的乘積,選擇概率乘積最大的那個值對應的label值就為預測的結果。
例如以下:是預測蘋果在給定屬性的情況是甜還是不甜的情況:
color={0,1,2,3} weight={2,3,4};是屬性序列,為離散型。sweet={yes,no}是目標值,也為離散型;
這時我們要預測在color=3,weight=3的情況下的目標值,計算過程如下:
P{y=yes}=2/5=0.4;P{color=3|yes}=1/2=0.5;P{weight=3|yes}=1/2=0.5; 故F{color=3,weight=3}取yesd的概率為 0.4*0.5*0.5=0.1;
P{y=no}=3/5=0.6;P{color=3|no}=1/3 P{weight=3|no}=1/3; 故P{color=3,weight=3}取no為 0.6*1/3*1/3=1/15;
0.1>1/15 所以認為 F{color=3,weight=3}=yes;
(2)對於屬性值是連續的情況,思想和離散是相同的,只是這時候我們計算屬性的概率用的是高斯密度:
這裡的Xk就是樣本的取值,u是樣本所在列的均值,kesi是標準差;
最後程式碼如下:
/* * To change this template, choose Tools | Templates * and open the template in the editor. */ package auxiliary; import java.util.ArrayList; /** * * @author Michael Kong */ public class NaiveBayes extends Classifier { boolean isClassfication[]; ArrayList lblClass=new ArrayList(); //儲存目標值的種類 ArrayListlblCount=new ArrayList();//儲存目標值的個數 ArrayListlblProba=new ArrayList();//儲存對應的label的概率 CountProbility countlblPro; /*@ClassListBasedLabel是將訓練陣列按照 label的順序來分類儲存*/ ArrayList>> ClassListBasedLabel=new ArrayList>> (); public NaiveBayes() { } @Override /** * @train主要完成求一些概率 * 1.labels中的不同取值的概率f(Yi); 對應28,29行兩段程式碼 * 2.將訓練陣列按目標值分類儲存 第37行程式碼 * */ public void train(boolean[] isCategory, double[][] features, double[] labels){ isClassfication=isCategory; countlblPro=new CountProbility(isCategory,features,labels); countlblPro.getlblClass(lblClass, lblCount, lblProba); ArrayList> trainingList=countlblPro.UnionFeaLbl(features, labels); //union the features[][] and labels[] ClassListBasedLabel=countlblPro.getClassListBasedLabel(lblClass, trainingList); } @Override /**3.在Y的條件下,計算Xi的概率 f(Xi/Y); * 4.返回使得Yi*Xi*...概率最大的那個label的取值 * */ public double predict(double[] features) { int max_index; //用於記錄使概率取得最大的那個索引 int index=0; //這個索引是 標識不同的labels 所對應的概率 ArrayList pro_=new ArrayList(); //這個概率陣列是儲存features[] 在不同labels下對應的概率 for(ArrayList> elements: ClassListBasedLabel) //依次取不同的label值對應的元祖集合 { ArrayList pro=new ArrayList();//存同一個label對應的所有概率,之後其中的元素自乘 double probility=1.0; //計算概率的乘積 for(int i=0;i element:elements) //依次取labels中的所有元祖 { if(element.get(i).equals(features[i])) //如果這個元祖的第index資料和b相等,那麼就count就加1 count++; } if(count==0) { pro.add(1/(double)(elements.size()+1)); } else pro.add(count/(double)elements.size()); //統計完所有之後 計算概率值 並加入 } else { double Sdev; double Mean; double probi=1.0; Mean=countlblPro.getMean(elements, i); Sdev=countlblPro.getSdev(elements, i); if(Sdev!=0) { probi*=((1/(Math.sqrt(2*Math.PI)*Sdev))*(Math.exp(-(features[i]-Mean)*(features[i]-Mean)/(2*Sdev*Sdev)))); pro.add(probi); } else pro.add(1.5); } } for(double pi:pro) probility*=pi; //將所有概率相乘 probility*=lblProba.get(index);//最後再乘以一個 Yi pro_.add(probility);// 放入pro_ 至此 一個迴圈結束, index++; } double max_pro=pro_.get(0); max_index=0; for(int i=1;i=max_pro) { max_pro=pro_.get(i); max_index=i; } } return lblClass.get(max_index); } public class CountProbility { boolean []isCatory; double[][]features; private double[]labels; public CountProbility(boolean[] isCategory, double[][] features, double[] labels) { this.isCatory=isCategory; this.features=features; this.labels=labels; } //獲取label中取值情況 public void getlblClass( ArrayList lblClass,ArrayListlblCount,ArrayListlblProba) { int j=0; for(double i:labels) { //如果當前的label不存在於lblClass則加入 if(!lblClass.contains(i)) { lblClass.add(j,i); lblCount.add(j++,1); } else //如果label中已經存在,就將其計數加1 { int index=lblClass.indexOf(i); int count=lblCount.get(index); lblCount.set(index,++count); } } for(int i=0;i> UnionFeaLbl(double[][] features, double[] labels) { ArrayList>traingList=new ArrayList>(); for(int i=0;ielements=new ArrayList(); for(int j=0;j>> getClassListBasedLabel (ArrayList lblClass,ArrayList>trainingList) { ArrayList>> ClassListBasedLabel=new ArrayList>> () ; for(double num:lblClass) { ArrayList> elements=new ArrayList>(); for(ArrayListelement:trainingList) { if(element.get(element.size()-1).equals(num)) elements.add(element); } ClassListBasedLabel.add(elements); } return ClassListBasedLabel; } public double getMean(ArrayList> elements,int index) { double sum=0.0; double Mean; for(ArrayList element:elements) { sum+=element.get(index); } Mean=sum/(double)elements.size(); return Mean; } public double getSdev(ArrayList> elements,int index) { double dev=0.0; double Mean; Mean=getMean(elements,index); for(ArrayList element:elements) { dev+=Math.pow((element.get(index)-Mean),2); } dev=Math.sqrt(dev/elements.size()); return dev; } } }