1. 程式人生 > >Spark機器學習(java):ALS交替最小二乘演算法

Spark機器學習(java):ALS交替最小二乘演算法

楔子

使用 Spark 機器學習(spark.ml)實作電影推薦,採用 ALS(交替最小二乘)演算法。

Spark中ml和mllib的區別
Spark機器學習(10):ALS交替最小二乘演算法

demo

import java.io.Serializable;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.recommendation.ALS;
import org.apache.spark.
ml.recommendation.ALSModel; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.SparkSession; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import com.zhuzi.utils.SparkUtils; // $example off$ /** * @Title: JavaALS.java * @Package com.zhuzi.mlZ * https://www.cnblogs.com/mstk/p/7208674.html */
public class JavaALS { private static Logger log = LoggerFactory.getLogger(JavaALS.class); public static class Rating implements Serializable { private static final long serialVersionUID = 1L; private int userId; private int movieId; private float rating; private long timestamp; public
Rating() { } public Rating(int userId, int movieId, float rating, long timestamp) { this.userId = userId; this.movieId = movieId; this.rating = rating; this.timestamp = timestamp; } public int getUserId() { return userId; } public int getMovieId() { return movieId; } public float getRating() { return rating; } public long getTimestamp() { return timestamp; } public static Rating parseRating(String str) { String[] fields = str.split("::"); if (fields.length != 4) { throw new IllegalArgumentException("每行必須是 4 fields"); } int userId = Integer.parseInt(fields[0]); int movieId = Integer.parseInt(fields[1]); float rating = Float.parseFloat(fields[2]); long timestamp = Long.parseLong(fields[3]); return new Rating(userId, movieId, rating, timestamp); } } public static void main(String[] args) { SparkSession spark = SparkUtils.buildSparkSession(); JavaRDD<Rating> ratingsRDD = spark.read().textFile(SparkUtils.getFilePath("data/mllib/als/sample_movielens_ratings.txt")).javaRDD().map(Rating::parseRating); ratingsRDD.cache(); Dataset<Row> ratings = spark.createDataFrame(ratingsRDD, Rating.class); Dataset<Row>[] splits = ratings.randomSplit(new double[] { 0.8, 0.2 }); Dataset<Row> training = splits[0]; Dataset<Row> test = splits[1]; // 利用訓練資料建立ALS推薦模型 ALS als = new ALS().setMaxIter(5).setRegParam(0.01).setUserCol("userId").setItemCol("movieId").setRatingCol("rating"); ALSModel model = als.fit(training); // Evaluate the model by computing the RMSE on the test data Note we set // cold start strategy to 'drop' to ensure we don't get NaN evaluation // metrics model.setColdStartStrategy("drop"); Dataset<Row> predictions = model.transform(test); RegressionEvaluator evaluator = new RegressionEvaluator().setMetricName("rmse").setLabelCol("rating").setPredictionCol("prediction"); Double rmse = evaluator.evaluate(predictions); System.out.println("Root-mean-square error = " + rmse); // 為每個使用者生成前10個電影推薦 Dataset<Row> userRecs = model.recommendForAllUsers(5); userRecs.cache(); 
System.out.println(userRecs.count()); for (Row dataset : userRecs.collectAsList()) { log.warn(dataset.toString()); } log.warn("以上是:為每個使用者生成前10個電影推薦"); // //////////////////////////////////////////////////////////// // 為每部電影生成十大使用者推薦 Dataset<Row> movieRecs = model.recommendForAllItems(5); movieRecs.cache(); long count = movieRecs.count(); System.out.println(count); for (Row dataset : movieRecs.collectAsList()) { log.warn(dataset.toString()); } log.warn("以上是:為每部電影生成十大使用者推薦"); // 為指定的一組使用者生成前10個電影推薦 Dataset<Row> users = ratings.select(als.getUserCol()).distinct().limit(3); Dataset<Row> userSubsetRecs = model.recommendForUserSubset(users, 10); // 為指定的電影集生成前10個使用者推薦 Dataset<Row> movies = ratings.select(als.getItemCol()).distinct().limit(3); Dataset<Row> movieSubSetRecs = model.recommendForItemSubset(movies, 10); } }

資料

資料採用 Spark 官方發行包中 data/mllib/als 目錄下的範例資料(sample_movielens_ratings.txt,每行格式為 userId::movieId::rating::timestamp)。

結果(recommendForAllUsers(5) 的輸出:每位使用者的前 5 個電影推薦,格式為 [userId, [(movieId, 預測評分), ...]])

[0,WrappedArray([28,4.2635937], [92,3.9503636], [76,3.7858698], [39,3.3879426], [2,2.999741])]
[WARN][10,WrappedArray([2,3.4163194], [53,3.241712], [25,3.0074506], [42,2.8061774], [87,2.728731])]
[WARN][20,WrappedArray([94,4.0460143], [22,3.6154926], [77,3.612016], [46,3.5636027], [88,3.4248996])]
[WARN][1,WrappedArray([8,5.209615], [55,4.414016], [39,4.104884], [68,3.8478084], [83,3.808718])]
[WARN][11,WrappedArray([23,5.2716155], [30,4.875259], [79,4.8541327], [46,4.577077], [66,4.0911093])]
[WARN][21,WrappedArray([53,5.354313], [2,4.206762], [74,3.9279747], [87,3.827361], [4,2.956087])]
[WARN][2,WrappedArray([39,5.098097], [93,5.0468936], [83,5.011127], [8,4.8800406], [63,4.520872])]
[WARN][12,WrappedArray([43,6.6539083], [46,5.720525], [35,5.1227217], [64,4.977849], [27,4.922715])]
[WARN][22,WrappedArray([30,5.1885133], [51,5.1620464], [75,4.8899145], [22,4.611838], [23,4.3775067])]
[WARN][3,WrappedArray([51,4.9186983], [77,4.49184], [80,4.014811], [18,3.9958167], [88,3.9081008])]
[WARN][13,WrappedArray([93,3.5169842], [41,3.1136718], [70,3.0926945], [92,2.9845672], [83,2.8850255])]
[WARN][23,WrappedArray([55,5.4747477], [62,5.2821107], [32,5.1058264], [49,4.683562], [48,4.589707])]
[WARN][4,WrappedArray([29,3.9669492], [83,3.8033702], [93,3.7849288], [52,3.7831538], [41,3.7319896])]
[WARN][14,WrappedArray([46,6.290769], [8,6.004507], [92,5.8273807], [52,4.9534802], [76,4.79875])]
[WARN][24,WrappedArray([30,5.1304946], [98,5.1174364], [90,5.0042753], [96,4.570305], [53,4.503151])]
[WARN][5,WrappedArray([62,4.6407185], [55,4.5774517], [32,4.542769], [49,4.09161], [68,3.8641832])]
[WARN][15,WrappedArray([46,4.871684], [90,4.0514927], [76,3.8044252], [92,3.6267335], [1,3.590583])]
[WARN][25,WrappedArray([25,4.522156], [89,4.2942066], [28,4.2211885], [39,4.097394], [44,3.761376])]
[WARN][6,WrappedArray([25,4.8586254], [58,3.9096472], [43,3.267941], [47,3.1259787], [93,3.1119614])]
[WARN][16,WrappedArray([90,4.995305], [85,4.9076405], [51,4.708634], [76,4.307944], [39,3.9640565])]
[WARN][26,WrappedArray([51,5.7545366], [22,5.566987], [94,5.0986133], [30,5.0660195], [46,4.901876])]
[WARN][7,WrappedArray([25,5.101813], [47,3.8634279], [29,3.861642], [85,3.8062382], [58,3.6149733])]
[WARN][17,WrappedArray([7,5.6873364], [77,5.4233136], [46,5.0695214], [17,4.9160004], [90,4.6762195])]
[WARN][27,WrappedArray([68,4.560275], [30,4.4277897], [49,4.311203], [88,3.8314245], [89,3.653545])]
[WARN][8,WrappedArray([29,5.076687], [53,5.0071025], [52,4.8734946], [41,4.29465], [70,4.116517])]
[WARN][18,WrappedArray([28,4.99047], [39,4.778125], [89,4.430275], [33,4.2363925], [44,3.5871475])]
[WARN][28,WrappedArray([48,5.7651014], [55,5.6089025], [18,5.2874613], [91,5.203947], [92,5.1611648])]
[WARN][9,WrappedArray([27,5.7635093], [49,4.927815], [7,4.890088], [85,4.3361597], [17,4.1737056])]
[WARN][19,WrappedArray([90,4.075626], [94,3.6695158], [98,3.4434488], [46,3.3110263], [32,3.1066573])]
[WARN][29,WrappedArray([30,5.3171983], [90,5.259403], [46,4.761197], [63,4.126344], [32,4.1161675])]
[WARN]以上是:為每個使用者生成前5個電影推薦

pom.xml 依賴

<!-- spark-mllib for the ALS demo. No <scope> element, so Maven's default "compile" scope applies:
     the demo imports org.apache.spark.ml.* and runs main() locally, so the classes must be on the
     compile classpath. The previous "runtime" scope hides them from javac, so the code would not
     compile. Switch to <scope>provided</scope> when the jar is submitted with spark-submit. -->
<dependency>
	<groupId>org.apache.spark</groupId>
	<artifactId>spark-mllib_2.11</artifactId>
	<version>2.4.0</version>
</dependency>