1. 程式人生 > >Naive Bayes 樸素貝葉斯的JAVA程式碼實現

Naive Bayes 樸素貝葉斯的JAVA程式碼實現

1.關於貝葉斯分類

bayes 是一種統計學分類方法,它基於貝葉斯定理,它假定一個屬性值對給定類的影響獨立於其它屬性點的值。該假定稱作類條件獨立。做次假定是為了簡化所需計算,並在此意義下稱為“樸素的”。

bayes分類的演算法大致如下:

(1)對於屬性值是離散的,並且目標label值也是離散的情況下。分別計算label不同取值的概率,以及樣本在label情況下的概率值,然後將這些概率值相乘最後得到一個概率的乘積,選擇概率乘積最大的那個值對應的label值就為預測的結果。

例如以下:是預測蘋果在給定屬性的情況是甜還是不甜的情況:

color={0,1,2,3} weight={2,3,4};是屬性序列,為離散型。sweet={yes,no}是目標值,也為離散型;

這時我們要預測在color=3,weight=3的情況下的目標值,計算過程如下:

P{y=yes}=2/5=0.4;P{color=3|yes}=1/2=0.5;P{weight=3|yes}=1/2=0.5;   故F{color=3,weight=3}取yesd的概率為 0.4*0.5*0.5=0.1;

P{y=no}=3/5=0.6;P{color=3|no}=1/3 P{weight=3|no}=1/3;  故P{color=3,weight=3}取no為 0.6*1/3*1/3=1/15;

0.1>1/15 所以認為 F{color=3,weight=3}=yes;

(2)對於屬性值是連續的情況,思想和離散是相同的,只是這時候我們計算屬性的概率用的是高斯密度:

這裡的Xk就是樣本的取值,u是樣本所在列的均值,kesi是標準差;

最後程式碼如下:

/*
 * To change this template, choose Tools | Templates
 * and open the template in the editor.
 */
package auxiliary;

import java.util.ArrayList;

/**
 *
 * @author Michael Kong
 */
public class NaiveBayes extends Classifier {

	boolean isClassfication[];
   ArrayList lblClass=new ArrayList();  //儲存目標值的種類
   ArrayListlblCount=new ArrayList();//儲存目標值的個數
   ArrayListlblProba=new ArrayList();//儲存對應的label的概率
   CountProbility countlblPro;
   /*@ClassListBasedLabel是將訓練陣列按照 label的順序來分類儲存*/
   ArrayList>> ClassListBasedLabel=new  ArrayList>> ();
    public NaiveBayes() {
    }
    @Override
    /**
     * @train主要完成求一些概率
     * 1.labels中的不同取值的概率f(Yi);  對應28,29行兩段程式碼
     * 2.將訓練陣列按目標值分類儲存   第37行程式碼

     * */
    public void train(boolean[] isCategory, double[][] features, double[] labels){
    	isClassfication=isCategory;
    	countlblPro=new CountProbility(isCategory,features,labels);
    	countlblPro.getlblClass(lblClass, lblCount, lblProba);  	
    	ArrayList> trainingList=countlblPro.UnionFeaLbl(features, labels); //union the features[][] and labels[]
    	ClassListBasedLabel=countlblPro.getClassListBasedLabel(lblClass, trainingList);
    }
    @Override
    /**3.在Y的條件下,計算Xi的概率 f(Xi/Y);
     * 4.返回使得Yi*Xi*...概率最大的那個label的取值
     * */
    public double predict(double[] features) {
    	
    	int max_index; //用於記錄使概率取得最大的那個索引
    	int index=0; //這個索引是 標識不同的labels 所對應的概率
    	ArrayList pro_=new ArrayList(); //這個概率陣列是儲存features[] 在不同labels下對應的概率
		for(ArrayList> elements: ClassListBasedLabel)  //依次取不同的label值對應的元祖集合
		{
		 	ArrayList pro=new ArrayList();//存同一個label對應的所有概率,之後其中的元素自乘
			double probility=1.0; //計算概率的乘積
			
			for(int i=0;i element:elements) //依次取labels中的所有元祖
					{
						if(element.get(i).equals(features[i])) //如果這個元祖的第index資料和b相等,那麼就count就加1
							count++;
					}
					if(count==0)
					{
						pro.add(1/(double)(elements.size()+1));
					}
					else
						pro.add(count/(double)elements.size()); //統計完所有之後  計算概率值 並加入
				}
				else
				{

		  			double Sdev;
		    		double Mean;
	    			double probi=1.0;
    				Mean=countlblPro.getMean(elements, i);
    				Sdev=countlblPro.getSdev(elements, i);
    				if(Sdev!=0)
    				{
    					probi*=((1/(Math.sqrt(2*Math.PI)*Sdev))*(Math.exp(-(features[i]-Mean)*(features[i]-Mean)/(2*Sdev*Sdev))));
    	    			pro.add(probi);
    				}
    				else
    					pro.add(1.5);
    				
	        	}
			}
			for(double pi:pro)
				probility*=pi; //將所有概率相乘
			probility*=lblProba.get(index);//最後再乘以一個 Yi
			pro_.add(probility);// 放入pro_ 至此 一個迴圈結束,	
			index++;
		}
		double max_pro=pro_.get(0);
		max_index=0;
		
		
		for(int i=1;i=max_pro)
			{
				max_pro=pro_.get(i);
				max_index=i;
			}	
		}  
		return  lblClass.get(max_index);
    }
    
    
    
    
    public class CountProbility
    {
    	boolean []isCatory;
    	double[][]features;
    	private double[]labels;
    	public CountProbility(boolean[] isCategory, double[][] features, double[] labels)
    	{
    		this.isCatory=isCategory;
    		this.features=features;
    		this.labels=labels;
    	}
    	//獲取label中取值情況
    	public void getlblClass(  ArrayList lblClass,ArrayListlblCount,ArrayListlblProba)
    	{
    		int j=0;
            for(double i:labels)
            {
            	//如果當前的label不存在於lblClass則加入
            	if(!lblClass.contains(i))
            	{
            		lblClass.add(j,i);
            		lblCount.add(j++,1);
            	}
            	else //如果label中已經存在,就將其計數加1
            	{
            		int index=lblClass.indexOf(i); 
            		int count=lblCount.get(index);
            		lblCount.set(index,++count);
            	}
            		
            }
            for(int i=0;i>  UnionFeaLbl(double[][] features, double[] labels)
    	{
    		ArrayList>traingList=new  ArrayList>();
    		for(int i=0;ielements=new ArrayList();
    			for(int j=0;j>> getClassListBasedLabel (ArrayList lblClass,ArrayList>trainingList)
    	{
    		ArrayList>> ClassListBasedLabel=new ArrayList>> () ;
        		for(double num:lblClass)
        		{
    				ArrayList> elements=new ArrayList>();
	    			for(ArrayListelement:trainingList)
	    			{
	    				if(element.get(element.size()-1).equals(num))
	    					elements.add(element);
	    			}
    			ClassListBasedLabel.add(elements);
        		}
    			return ClassListBasedLabel;
    	}
    	public double getMean(ArrayList> elements,int index)
    	{
    		double sum=0.0;
    		double Mean;
    		
    		for(ArrayList element:elements)
    		{
    			sum+=element.get(index);
    			
    		}
    		Mean=sum/(double)elements.size();
    		return  Mean;
    	}
    	public double getSdev(ArrayList> elements,int index)
    	{
    		double dev=0.0;
    		double Mean;
    		Mean=getMean(elements,index);
    		for(ArrayList element:elements)
    		{
    			dev+=Math.pow((element.get(index)-Mean),2);
    		}
    		dev=Math.sqrt(dev/elements.size());
    		return  dev;
    	}
    	
    	
    }
}