基於NaiveBayes的文字分類之Spark實現
阿新 • • 發佈:2019-01-03
在嘗試了 Python 下用 sklearn 進行文字分類(http://blog.csdn.net/a_step_further/article/details/50189727)之後,我們再來看看如何用 Spark 實現文字分類,採用的演算法同樣是樸素貝葉斯。
此前,我們已經實現了在 Hadoop 叢集環境下使用 MapReduce 進行中文分詞(http://blog.csdn.net/a_step_further/article/details/50333961),因此文字分類的過程同樣在叢集環境中進行;相對於 Python 的單機版本實現,無疑更方便一些。
上程式碼:
import org.apache.spark.mllib.classification.NaiveBayes
import org.apache.spark.mllib.feature.{IDFModel, HashingTF, IDF}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkContext, SparkConf}

/**
 * Trains and evaluates a Naive Bayes text classifier with Spark MLlib.
 *
 * Input is a text file with one document per line in the form
 * `label<TAB>token token token ...`; lines that do not split into exactly two
 * tab-separated fields are dropped. Tokens are hashed into term-frequency
 * vectors, re-weighted with IDF, split 60/40 into train/test sets, and the
 * trained model is saved to the given location.
 *
 * Usage: `textClassify <inputLoc> <modelSaveLoc>`
 */
object textClassify {

  /**
   * Program entry point.
   *
   * @param args exactly two elements: the input data location and the
   *             location where the trained model is saved
   */
  def main(args: Array[String]): Unit = {
    // Validate arguments before starting Spark so a bad invocation does not
    // spin up (and then abandon) a SparkContext. Only two arguments are read,
    // so the usage text lists exactly those two.
    if (args.length != 2) {
      println("Usage: textClassify <inputLoc> <modelSaveLoc>")
      System.exit(-1)
    }
    val inputLoc = args(0)
    val modelSaveLoc = args(1)

    val conf = new SparkConf().setAppName("text_classify").set("spark.akka.frameSize", "20")
    val sc = new SparkContext(conf)

    // try/finally guarantees the context is stopped even if a stage fails.
    try {
      // Each kept record is Array(label, spaceSeparatedTokens).
      val inputData: RDD[Array[String]] =
        sc.textFile(inputLoc).map(_.split("\t")).filter(_.length == 2).cache()

      // Cached because the TF vectors derived from it are traversed twice:
      // once to fit the IDF model and once to apply it.
      val features = inputData.map(fields => fields(1).split(" ").toSeq).cache()

      val hashingTF = new HashingTF()
      val tf = hashingTF.transform(features)
      val idf: IDFModel = new IDF(minDocFreq = 2).fit(tf)
      val tfIdf = idf.transform(tf)

      // zip is applied to two RDDs that are both element-wise maps of
      // inputData, so labels and vectors are paired record by record.
      val labeledData = inputData.map(_(0)).zip(tfIdf).map {
        case (label, vector) => LabeledPoint(label.toDouble, vector)
      }.cache()

      // NOTE: labeledData must stay persisted until the split RDDs have been
      // materialised. RDDs are lazy, so unpersisting right after randomSplit
      // (before any action) would discard the cache without it ever being used.
      val Array(trainSplit, testData) = labeledData.randomSplit(Array(0.6, 0.4), seed = 10L)

      val trainData = trainSplit.cache()
      val model = NaiveBayes.train(trainData, lambda = 0.1)
      trainData.unpersist()

      // Prediction: (predictedLabel, actualLabel) pairs. Cached because three
      // separate count() actions run over it; without the cache every count
      // would re-run prediction and the upstream pipeline.
      val predictions = testData.map(point => (model.predict(point.features), point.label)).cache()

      val actualPositives = predictions.filter(_._2 == 1.0).count()
      val predictedPositives = predictions.filter(_._1 == 1.0).count()
      val truePositives = predictions.filter {
        case (predicted, actual) => predicted == 1.0 && actual == 1.0
      }.count()

      // The first figure is truePositives / predictedPositives (precision of
      // class 1.0); the second is truePositives / actualPositives (recall).
      // A zero denominator implies a zero numerator, so the result is NaN
      // rather than an exception.
      val precision = truePositives.toDouble / predictedPositives
      val recall = truePositives.toDouble / actualPositives

      // String interpolation instead of println(a, b): the two-argument form
      // is auto-tupled and prints "(label,value)".
      println("results------------------------------------------------")
      println(s"準確率: $precision")
      println(s"召回率: $recall")
      println("------------------------------------------------")

      predictions.unpersist()
      labeledData.unpersist()

      model.save(sc, modelSaveLoc)
    } finally {
      sc.stop()
    }
  }
}