樸素貝葉斯文字分類java實現
阿新 • • 發佈:2019-01-23
import java.io.File;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import com.data.util.IoUtil;
/**
 * Multinomial naive Bayes text classifier. (The class name is a misspelling of "Naive"; it is
 * kept unchanged so existing callers keep working.)
 *
 * <p>Workflow, as driven by {@link #main(String[])}:
 * <ol>
 *   <li>{@link #splitData(String)} groups the files of a directory by the class name embedded in
 *       each file name and randomly splits every class into a training and a test set;</li>
 *   <li>{@link #loadTrainData()} counts documents per class and token occurrences per class;</li>
 *   <li>{@link #createModel()} turns those counts into smoothed probabilities;</li>
 *   <li>{@link #PredictTestData()} classifies every test file and prints accuracy, precision and
 *       recall.</li>
 * </ol>
 *
 * <p>Documents are tokenized by splitting on single spaces, so input files are presumably
 * pre-tokenized (e.g. word-segmented Chinese text) — TODO confirm against the data source.
 */
public class NativeBayes {

    /** Directory scanned by {@link #splitData(String)} when it is given {@code null}. */
    private static final String DEFAULT_DATA_DIR = "testdata/allfiles";

    /**
     * Extracts the class name from a file name: the lowercase letters that follow a run of digits
     * and precede a dot, e.g. {@code "123sports.txt"} gives {@code "sports"}.
     */
    private static final Pattern CLASS_NAME_PATTERN = Pattern.compile("\\d+([a-z]+?)\\.");

    /** Additive (Lidstone) smoothing pseudo-count added to every token count. */
    private double defaultFreq = 0.1;

    /** Probability that a file is assigned to the training set; the rest go to the test set. */
    private double trainingPercent = 0.8;

    /** All files, keyed by class name. */
    private Map<String, List<String>> files_all = new HashMap<String, List<String>>();
    /** Training files, keyed by class name. */
    private Map<String, List<String>> files_train = new HashMap<String, List<String>>();
    /** Test files, keyed by class name. */
    private Map<String, List<String>> files_test = new HashMap<String, List<String>>();

    public NativeBayes() {
    }

    /** Number of training documents per class. */
    private Map<String, Integer> classFreq = new HashMap<String, Integer>();
    /** Prior probability P(class), estimated from the training document counts. */
    private Map<String, Double> ClassProb = new HashMap<String, Double>();
    /** Vocabulary: every distinct token seen in the training data. */
    private Set<String> WordDict = new HashSet<String>();
    /** Per class: token -> number of occurrences in that class's training documents. */
    private Map<String, Map<String, Integer>> classFeaFreq = new HashMap<String, Map<String, Integer>>();
    /** Per class: token -> smoothed likelihood P(token | class). */
    private Map<String, Map<String, Double>> ClassFeaProb = new HashMap<String, Map<String, Double>>();
    /** Per class: smoothed likelihood used for vocabulary tokens never seen in that class. */
    private Map<String, Double> ClassDefaultProb = new HashMap<String, Double>();

    /**
     * Prints the overall accuracy: the fraction of positions where the predicted class equals the
     * real class. Prints {@code NaN} when the lists are empty.
     *
     * @param reallist real class of each test document
     * @param pridlist predicted class of each test document, index-aligned with {@code reallist}
     */
    public void Evaluate(List<String> reallist, List<String> pridlist) {
        double correctNum = 0.0;
        for (int i = 0; i < reallist.size(); i++) {
            // equals(), not ==: == compares references and only worked by accident because the
            // same map-key objects happened to flow into both lists.
            if (reallist.get(i).equals(pridlist.get(i))) {
                correctNum += 1;
            }
        }
        double accuracy = correctNum / reallist.size();
        System.out.println("準確率為:" + accuracy);
    }

    /**
     * Prints precision (correct / predicted as {@code classname}) and recall (correct / actually
     * {@code classname}) for one class. A ratio whose denominator is zero prints as {@code NaN}.
     *
     * @param reallist  real class of each test document
     * @param pridlist  predicted class of each test document, index-aligned with {@code reallist}
     * @param classname the class to report on
     */
    public void CalPreRec(List<String> reallist, List<String> pridlist, String classname) {
        double correctNum = 0.0; // documents of this class that were predicted as this class
        double allNum = 0.0;     // test documents whose real class is classname
        double preNum = 0.0;     // test documents predicted as classname
        for (int i = 0; i < reallist.size(); i++) {
            boolean isReal = classname.equals(reallist.get(i));
            boolean isPredicted = classname.equals(pridlist.get(i));
            if (isReal) {
                allNum += 1;
                if (isPredicted) {
                    correctNum += 1;
                }
            }
            if (isPredicted) {
                preNum += 1;
            }
        }
        System.out.println(classname + " 精確率(跟預測分類比較):" + correctNum / preNum + " 召回率(跟真實分類比較):" + correctNum / allNum);
    }

    /**
     * Classifies every test file with the trained model, then prints overall accuracy and
     * per-class precision/recall.
     *
     * @throws IllegalStateException if no model has been built (no class priors available)
     */
    public void PredictTestData() {
        if (ClassProb.isEmpty()) {
            throw new IllegalStateException(
                    "No trained model: call loadTrainData() and createModel() first");
        }
        List<String> reallist = new ArrayList<String>();
        List<String> pridlist = new ArrayList<String>();
        for (Entry<String, List<String>> entry : files_test.entrySet()) {
            String realclassname = entry.getKey();
            for (String file : entry.getValue()) {
                reallist.add(realclassname);
                // Read and tokenize each document once (it used to be re-read for every class).
                String[] words = IoUtil.readFromFile(new File(file)).split(" ");
                pridlist.add(classify(words));
            }
        }
        Evaluate(reallist, pridlist);
        for (String cname : files_test.keySet()) {
            CalPreRec(reallist, pridlist, cname);
        }
    }

    /**
     * Returns the class with the highest log-posterior for the given tokens. On ties the class
     * iterated first from {@code ClassProb} wins.
     */
    private String classify(String[] words) {
        String bestClass = null;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (Entry<String, Double> prior : ClassProb.entrySet()) {
            String classname = prior.getKey();
            Map<String, Double> feaProb = ClassFeaProb.get(classname);
            double unseenLogProb = Math.log(ClassDefaultProb.get(classname));
            // Sum of logs instead of a product of probabilities avoids floating-point underflow.
            double score = Math.log(prior.getValue());
            for (String word : words) {
                if (!WordDict.contains(word)) {
                    continue; // tokens never seen in training are ignored
                }
                Double p = feaProb.get(word);
                score += (p != null) ? Math.log(p) : unseenLogProb;
            }
            if (bestClass == null || score > bestScore) {
                bestClass = classname;
                bestScore = score;
            }
        }
        return bestClass;
    }

    /**
     * Builds the model from the counts gathered by {@link #loadTrainData()}.
     *
     * <p>Priors are document counts divided by the total number of training documents. Token
     * likelihoods use Lidstone smoothing: (count + defaultFreq) / (tokens in class + |V| *
     * defaultFreq), where |V| is the vocabulary size.
     */
    public void createModel() {
        double docTotal = 0.0;
        for (Integer count : classFreq.values()) {
            docTotal += count;
        }
        for (Entry<String, Integer> entry : classFreq.entrySet()) {
            ClassProb.put(entry.getKey(), entry.getValue() / docTotal);
        }
        double vocabSize = WordDict.size();
        for (Entry<String, Map<String, Integer>> entry : classFeaFreq.entrySet()) {
            String classname = entry.getKey();
            Map<String, Integer> counts = entry.getValue();
            double tokenTotal = 0.0;
            for (Integer count : counts.values()) {
                tokenTotal += count;
            }
            // Every vocabulary token receives defaultFreq pseudo-occurrences, so the denominator
            // grows by |V| * defaultFreq and the per-class likelihoods still sum to 1.
            double newsum = tokenTotal + vocabSize * defaultFreq;
            Map<String, Double> feaProb = new HashMap<String, Double>();
            for (Entry<String, Integer> wordCount : counts.entrySet()) {
                feaProb.put(wordCount.getKey(), (wordCount.getValue() + defaultFreq) / newsum);
            }
            ClassFeaProb.put(classname, feaProb);
            ClassDefaultProb.put(classname, defaultFreq / newsum);
        }
    }

    /**
     * Reads every training file and records the number of documents per class, the
     * per-class token counts, and the global vocabulary. Then prints the class and vocabulary
     * sizes.
     */
    public void loadTrainData() {
        for (Entry<String, List<String>> entry : files_train.entrySet()) {
            String classname = entry.getKey();
            List<String> docs = entry.getValue();
            classFreq.put(classname, docs.size());
            Map<String, Integer> feaFreq = new HashMap<String, Integer>();
            classFeaFreq.put(classname, feaFreq);
            for (String doc : docs) {
                for (String word : IoUtil.readFromFile(new File(doc)).split(" ")) {
                    WordDict.add(word);
                    Integer count = feaFreq.get(word);
                    feaFreq.put(word, count == null ? 1 : count + 1);
                }
            }
        }
        System.out.println(classFreq.size() + " 分類, " + WordDict.size() + " 特徵詞");
    }

    /**
     * Groups the files in {@code dataDir} by the class name embedded in their file names and
     * randomly splits each class into training and test files. Each file is sent to training with
     * probability {@code trainingPercent}, so the split ratio is only approximate.
     *
     * @param dataDir directory holding the labelled files; {@code null} means
     *                {@code "testdata/allfiles"}
     * @throws IllegalArgumentException if {@code dataDir} is not a readable directory
     */
    public void splitData(String dataDir) {
        // The argument used to be overwritten unconditionally; now it is honored and the old
        // hard-coded path is only the fallback.
        if (dataDir == null) {
            dataDir = DEFAULT_DATA_DIR;
        }
        File[] files = new File(dataDir).listFiles();
        if (files == null) {
            throw new IllegalArgumentException("Not a readable directory: " + dataDir);
        }
        for (File file : files) {
            Matcher m = CLASS_NAME_PATTERN.matcher(file.getName());
            if (m.find()) {
                String cname = m.group(1);
                List<String> bucket = files_all.get(cname);
                if (bucket == null) {
                    bucket = new ArrayList<String>();
                    files_all.put(cname, bucket);
                }
                bucket.add(file.toString());
            } else {
                System.out.println("err: " + file);
            }
        }
        System.out.println("統計資料:");
        for (Entry<String, List<String>> entry : files_all.entrySet()) {
            String cname = entry.getKey();
            List<String> train = new ArrayList<String>();
            List<String> test = new ArrayList<String>();
            for (String path : entry.getValue()) {
                if (Math.random() <= trainingPercent) {
                    train.add(path);
                } else {
                    test.add(path);
                }
            }
            files_train.put(cname, train);
            files_test.put(cname, test);
        }
        System.out.println("所有檔案數:");
        printStatistics(files_all);
        System.out.println("訓練檔案數:");
        printStatistics(files_train);
        System.out.println("測試檔案數:");
        printStatistics(files_test);
    }

    /**
     * Prints one {@code "class : fileCount"} line per entry, followed by a separator line.
     *
     * @param m files keyed by class name
     */
    public void printStatistics(Map<String, List<String>> m) {
        for (Entry<String, List<String>> entry : m.entrySet()) {
            String cname = entry.getKey();
            List<String> value = entry.getValue();
            System.out.println(cname + " : " + value.size());
        }
        System.out.println("--------------------------------");
    }

    /** Runs the full pipeline on the default data directory. */
    public static void main(String[] args) {
        NativeBayes bayes = new NativeBayes();
        bayes.splitData(null);
        bayes.loadTrainData();
        bayes.createModel();
        bayes.PredictTestData();
    }
}
統計資料:
所有檔案數:
sports : 1018
auto : 1020
business : 1028
--------------------------------
訓練檔案數:
sports : 791
auto : 812
business : 808
--------------------------------
測試檔案數:
sports : 227
auto : 208
business : 220
--------------------------------
3 分類, 39613 特徵詞
準確率為:0.9801526717557252
sports 精確率(跟預測分類比較):0.9956140350877193 召回率(跟真實分類比較):1.0
auto 精確率(跟預測分類比較):0.9579439252336449 召回率(跟真實分類比較):0.9855769230769231
business 精確率(跟預測分類比較):0.9859154929577465 召回率(跟真實分類比較):0.9545454545454546
統計資料:
所有檔案數:
sports : 1018
auto : 1020
business : 1028
--------------------------------
訓練檔案數:
sports : 827
auto : 833
business : 825
--------------------------------
測試檔案數:
sports : 191
auto : 187
business : 203
--------------------------------
3 分類, 39907 特徵詞
準確率為:0.9759036144578314
sports 精確率(跟預測分類比較):0.9894736842105263 召回率(跟真實分類比較):0.9842931937172775
auto 精確率(跟預測分類比較):0.9836956521739131 召回率(跟真實分類比較):0.9679144385026738
business 精確率(跟預測分類比較):0.9565217391304348 召回率(跟真實分類比較):0.9753694581280788