Spark實現簡單的垃圾郵件分類--JAVA原始碼
阿新 • 發佈:2021-01-26
Spark實現簡單的垃圾郵件分類–JAVA原始碼
程式碼部分
package cn.cc.spark;
import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.logging.Logger;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.classification.LogisticRegressionModel;
import org.apache.spark.mllib.classification.LogisticRegressionWithSGD;
import org.apache.spark.mllib.feature.HashingTF;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint;
/**
 * Trains a simple spam classifier with Spark MLlib: e-mails are tokenized,
 * hashed into fixed-size term-frequency vectors ({@link HashingTF}), and fed
 * to logistic regression trained via SGD. One known-spam and one known-ham
 * sample are then scored with the resulting model.
 *
 * <p>Training and prediction MUST share the same tokenizer and the same
 * {@code HashingTF} instance, otherwise feature indices do not line up with
 * the model's weights.
 */
public final class Spam {

    /**
     * Word tokenizer shared by training and prediction. {@code \w+} (one or
     * more word chars) is used instead of {@code \w*}, which also produces
     * zero-length matches that would pollute the hashed feature space.
     * Compiled once instead of per e-mail.
     */
    private static final Pattern WORD = Pattern.compile("\\w+");

    public static void main(String[] args) {
        SparkConf sparkConf = new SparkConf().setAppName("垃圾郵件訓練").setMaster("local[*]");
        JavaSparkContext sc = new JavaSparkContext(sparkConf);
        try {
            // Raw training corpora: one spam file, one normal-mail file.
            JavaRDD<String> spam = sc.textFile("hdfs://localhost:8020/sample/spam/spmsga1.eml");
            JavaRDD<String> mail = sc.textFile("hdfs://localhost:8020/sample/mail/3-1msg1.txt");
            // Maps each e-mail's tokens into a 10000-dimension TF feature vector.
            final HashingTF tf = new HashingTF(10000);
            // Label spam 1.0 (positive) and normal mail 0.0 (negative).
            JavaRDD<LabeledPoint> positiveData = setLabeledPoint(spam, tf, 1.0);
            JavaRDD<LabeledPoint> negativeData = setLabeledPoint(mail, tf, 0.0);
            // Union both labeled sets into the training set and cache it —
            // SGD makes multiple passes over the data.
            JavaRDD<LabeledPoint> trainingData = positiveData.union(negativeData);
            trainingData.cache();
            // Train logistic regression with stochastic gradient descent.
            LogisticRegressionWithSGD lrLearner = new LogisticRegressionWithSGD();
            LogisticRegressionModel model = lrLearner.run(trainingData.rdd());
            // Score one spam and one ham sample, transformed with the SAME
            // HashingTF and the SAME tokenizer used during training.
            Vector positiveSample =
                    tf.transform(filterText(sc, "hdfs://localhost:8020/sample/spam/spmsga2.eml"));
            predictionResult("spmsga2.eml", model.predict(positiveSample));
            Vector negativeSample =
                    tf.transform(filterText(sc, "hdfs://localhost:8020/sample/mail/3-1msg2.txt"));
            predictionResult("3-1msg2.txt", model.predict(negativeSample));
        } finally {
            // close() also stops the underlying SparkContext; a separate
            // stop() call before it is redundant.
            sc.close();
        }
    }

    /**
     * Logs the prediction outcome for one e-mail.
     *
     * @param emailName    file name of the scored e-mail (used in the log line)
     * @param predictValue model output: 1.0 = classified as spam, otherwise ham
     */
    public static void predictionResult(String emailName, double predictValue) {
        if (predictValue == 1.0) {
            Logger.getGlobal().info(emailName + "郵件的預測結果為垃圾郵件" + predictValue);
        } else {
            Logger.getGlobal().info(emailName + "預測結果為正常郵件" + predictValue);
        }
    }

    /**
     * Converts each raw e-mail text in {@code rdd} into a {@link LabeledPoint}
     * carrying the given class label and its hashed TF feature vector.
     *
     * @param rdd   raw e-mail texts, one e-mail per element
     * @param tf    the pipeline-wide {@link HashingTF} instance
     * @param label class label: 1.0 = spam (positive), 0.0 = normal (negative)
     * @return RDD of labeled training points
     */
    public static JavaRDD<LabeledPoint> setLabeledPoint(JavaRDD<String> rdd, HashingTF tf, Double label) {
        return rdd.map(new Function<String, LabeledPoint>() {
            private static final long serialVersionUID = 1L;

            // Transform one e-mail (String) into a LabeledPoint.
            @Override
            public LabeledPoint call(String email) {
                return new LabeledPoint(label, tf.transform(tokenize(email)));
            }
        });
    }

    /**
     * Reads the text file at {@code path} and returns its word tokens, skipping
     * blank lines, using the same tokenizer as training so that the resulting
     * feature vector matches the model's feature space.
     *
     * <p>Note: the list returned by {@code collect()} may be unmodifiable, so it
     * is only iterated — never mutated (the previous {@code Iterator.remove()}
     * could throw {@link UnsupportedOperationException}).
     *
     * @param sc   active Spark context
     * @param path HDFS (or local) path of the e-mail to tokenize
     * @return all word tokens of the file, in order
     */
    public static List<String> filterText(JavaSparkContext sc, String path) {
        List<String> tokens = new LinkedList<>();
        for (String line : sc.textFile(path).collect()) {
            if (!line.trim().isEmpty()) {
                tokens.addAll(tokenize(line));
            }
        }
        return tokens;
    }

    /** Extracts the {@code \w+} word tokens from one piece of text, in order. */
    private static List<String> tokenize(String text) {
        List<String> tokens = new LinkedList<>();
        Matcher matcher = WORD.matcher(text);
        while (matcher.find()) {
            tokens.add(matcher.group());
        }
        return tokens;
    }
}
相關依賴
<dependencies>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-core_2.10</artifactId>
<version>1.3.1</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-sql_2.10</artifactId>
<version>1.3.1</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-hive_2.10</artifactId>
<version>1.3.1</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-streaming_2.10</artifactId>
<version>1.3.1</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-streaming-kafka_2.10</artifactId>
<version>1.3.1</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-mllib_2.10</artifactId>
<version>1.3.1</version>
</dependency>
<dependency>
<groupId>com.datastax.spark</groupId>
<artifactId>spark-cassandra-connector_2.10</artifactId>
<version>1.0.0-rc5</version>
</dependency>
<dependency>
<groupId>com.datastax.spark</groupId>
<artifactId>spark-cassandra-connector-java_2.10</artifactId>
<version>1.0.0-rc5</version>
</dependency>
<dependency>
<groupId>org.elasticsearch</groupId>
<artifactId>elasticsearch-hadoop-mr</artifactId>
<version>2.0.0.RC1</version>
</dependency>
<dependency>
<groupId>org.eclipse.jetty</groupId>
<artifactId>jetty-client</artifactId>
<version>8.1.14.v20131031</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
<version>2.3.3</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
<version>3.0</version>
</dependency>
<dependency>
<groupId>net.sf.opencsv</groupId>
<artifactId>opencsv</artifactId>
<version>2.0</version>
</dependency>
<dependency>
<groupId>org.scalatest</groupId>
<artifactId>scalatest_${scala.binary.version}</artifactId>
<version>2.2.1</version>
</dependency>
</dependencies>