Implementing Simple Spam Email Classification with Spark: Java Source Code

Tags: Spark, machine learning, Hadoop, Java, big data


Code

The complete example below trains a logistic regression model on one spam file and one normal-mail file read from HDFS, then classifies two further messages. An evaluation sketch follows the listing.

package cn.cc.spark;

import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.logging.Logger;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.classification.LogisticRegressionModel;
import org.apache.spark.mllib.classification.LogisticRegressionWithSGD;
import org.apache.spark.mllib.feature.HashingTF;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint;

public final class Spam {

    public static void main(String[] args) {
        SparkConf sparkConf = new SparkConf().setAppName("SpamTraining").setMaster("local[*]");
        JavaSparkContext sc = new JavaSparkContext(sparkConf);

        JavaRDD<String> spam = sc.textFile("hdfs://localhost:8020/sample/spam/spmsga1.eml");
        JavaRDD<String> mail = sc.textFile("hdfs://localhost:8020/sample/mail/3-1msg1.txt");

        // Instantiate HashingTF, which maps email text to vectors of 10,000 features.
        final HashingTF tf = new HashingTF(10000);

        // map: applies the Function to every element of the RDD and returns the transformed RDD.
        // Function<A, B> takes input of type A and produces output of type B.
        JavaRDD<LabeledPoint> positiveData = setLabeledPoint(spam, tf, 1.0);
        JavaRDD<LabeledPoint> negativeData = setLabeledPoint(mail, tf, 0.0);

        // Union the positive and negative examples into a single training set.
        JavaRDD<LabeledPoint> trainingData = positiveData.union(negativeData);
        // Cache the training RDD, since logistic regression iterates over it.
        trainingData.cache();

        // Train a logistic regression model using stochastic gradient descent.
        LogisticRegressionWithSGD lrLearner = new LogisticRegressionWithSGD();
        final LogisticRegressionModel model = lrLearner.run(trainingData.rdd());

        // Test with one spam message and one normal message as examples,
        // transformed with the same HashingTF used for training.
        Vector positiveSample = tf.transform(filterText(sc, "hdfs://localhost:8020/sample/spam/spmsga2.eml"));
        predictionResult("spmsga2.eml", model.predict(positiveSample));

        Vector negativeSample = tf.transform(filterText(sc, "hdfs://localhost:8020/sample/mail/3-1msg2.txt"));
        predictionResult("3-1msg2.txt", model.predict(negativeSample));

        sc.stop();
        sc.close();
    }

    public static void predictionResult(String emailName, double predictValue) {
        if (predictValue == 1.0) {
            Logger.getGlobal().info(emailName + " is predicted to be spam: " + predictValue);
        } else {
            Logger.getGlobal().info(emailName + " is predicted to be a normal email: " + predictValue);
        }
    }

    /**
     * Builds labeled points from raw email lines.
     * @param rdd lines of email text
     * @param tf the shared HashingTF feature mapper
     * @param label 1.0 for spam (positive), 0.0 for normal mail (negative)
     * @return RDD of labeled term-frequency vectors
     */
    public static JavaRDD<LabeledPoint> setLabeledPoint(JavaRDD<String> rdd, final HashingTF tf, final Double label) {
        return rdd.map(new Function<String, LabeledPoint>() {
            private static final long serialVersionUID = 1L;

            // Override call to do the conversion: the input is a String, the output a LabeledPoint.
            @Override
            public LabeledPoint call(String email) {
                List<String> list = new LinkedList<>();
                // Match runs of word characters ("+" rather than "*", so empty matches are skipped).
                Pattern pattern = Pattern.compile("[\\w]+");
                Matcher matcher = pattern.matcher(email);
                while (matcher.find()) {
                    list.add(matcher.group());
                }
                // Attach the label to the term-frequency vector.
                return new LabeledPoint(label, tf.transform(list));
            }
        });
    }

    public static List<String> filterText(JavaSparkContext sc, String path) {
        List<String> list = new LinkedList<>();
        List<String> data = sc.textFile(path).collect();
        // Skip blank lines and split the rest on spaces. (Note: for consistent
        // features this should ideally use the same tokenization as setLabeledPoint.)
        for (String str : data) {
            if (!str.trim().isEmpty()) {
                list.addAll(Arrays.asList(str.split(" ")));
            }
        }
        return list;
    }
}

Dependencies

 <dependencies>
    <dependency> 
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-core_2.10</artifactId>
      <version>1.3.1</version>
      <scope>provided</scope>
    </dependency>
    <dependency> 
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-sql_2.10</artifactId>
      <version>1.3.1</version>
      <scope>provided</scope>
    </dependency>
    <dependency> 
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-hive_2.10</artifactId>
      <version>1.3.1</version>
      <scope>provided</scope>
    </dependency>
    <dependency> 
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-streaming_2.10</artifactId>
      <version>1.3.1</version>
    </dependency>
    <dependency> 
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-streaming-kafka_2.10</artifactId>
      <version>1.3.1</version>
    </dependency>
    <dependency> 
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-mllib_2.10</artifactId>
      <version>1.3.1</version>
    </dependency>
    <dependency> 
      <groupId>com.datastax.spark</groupId>
      <artifactId>spark-cassandra-connector_2.10</artifactId>
      <version>1.0.0-rc5</version>
    </dependency>
    <dependency> 
      <groupId>com.datastax.spark</groupId>
      <artifactId>spark-cassandra-connector-java_2.10</artifactId>
      <version>1.0.0-rc5</version>
    </dependency>
    <dependency> 
      <groupId>org.elasticsearch</groupId>
      <artifactId>elasticsearch-hadoop-mr</artifactId>
      <version>2.0.0.RC1</version>
    </dependency>
    <dependency> 
      <groupId>org.eclipse.jetty</groupId>
      <artifactId>jetty-client</artifactId>
      <version>8.1.14.v20131031</version>
    </dependency>
    <dependency>
      <groupId>com.fasterxml.jackson.core</groupId>
      <artifactId>jackson-databind</artifactId>
      <version>2.3.3</version>
    </dependency>
    <dependency>
      <groupId>org.apache.commons</groupId>
      <artifactId>commons-lang3</artifactId>
      <version>3.0</version>
    </dependency>
    <dependency>
      <groupId>net.sf.opencsv</groupId>
      <artifactId>opencsv</artifactId>
      <version>2.0</version>
    </dependency>
    <dependency>
      <groupId>org.scalatest</groupId>
      <artifactId>scalatest_2.10</artifactId>
      <version>2.2.1</version>
    </dependency>
  </dependencies>
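
Most of the artifacts above (Kafka, Cassandra, Elasticsearch, Jetty, OpenCSV, ScalaTest) are not used by this example. A minimal dependency block for the code shown, keeping the versions from above, would be:

  <dependencies>
    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-core_2.10</artifactId>
      <version>1.3.1</version>
      <scope>provided</scope>
    </dependency>
    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-mllib_2.10</artifactId>
      <version>1.3.1</version>
    </dependency>
  </dependencies>

Note that with the provided scope, spark-core is expected on the runtime classpath (e.g. supplied by spark-submit); if you run the class directly from an IDE with setMaster("local[*]"), drop the provided scope so Spark is packaged with the application.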