Spark機器學習(java):ALS交替最小二乘演算法
阿新 • • 發佈:2018-12-22
楔子
Spark機器學習,推薦電影,採用ALS交替最小二乘演算法
Spark中ml和mllib的區別
Spark機器學習(10):ALS交替最小二乘演算法
demo
import java.io.Serializable;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.recommendation.ALS;
import org.apache.spark. ml.recommendation.ALSModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.zhuzi.utils.SparkUtils;
// $example off$
/**
* @Title: JavaALS.java
* @Package com.zhuzi.mlZ
* https://www.cnblogs.com/mstk/p/7208674.html
*/
public class JavaALS {
private static Logger log = LoggerFactory.getLogger(JavaALS.class);
/**
 * One rating record: which user rated which movie, the score given and when.
 * Instances are produced from "::"-separated text lines by {@link #parseRating(String)}
 * and are exposed to callers through read-only getters.
 */
public static class Rating implements Serializable {
    private static final long serialVersionUID = 1L;

    private int userId;
    private int movieId;
    private float rating;
    private long timestamp;

    /** Creates a rating with every field at its zero value. */
    public Rating() {
        this(0, 0, 0.0f, 0L);
    }

    /**
     * Creates a fully populated rating.
     *
     * @param userId    id of the rating user
     * @param movieId   id of the rated movie
     * @param rating    score given by the user
     * @param timestamp timestamp attached to the rating
     */
    public Rating(int userId, int movieId, float rating, long timestamp) {
        this.timestamp = timestamp;
        this.rating = rating;
        this.movieId = movieId;
        this.userId = userId;
    }

    /**
     * Parses one line of the form {@code userId::movieId::rating::timestamp}.
     *
     * @param str the raw text line
     * @return the parsed rating
     * @throws IllegalArgumentException if the line does not split into exactly four fields
     * @throws NumberFormatException    if a field is not a valid number
     */
    public static Rating parseRating(String str) {
        final String[] parts = str.split("::");
        if (parts.length == 4) {
            // Arguments are evaluated left to right, so fields are parsed in file order.
            return new Rating(
                    Integer.parseInt(parts[0]),
                    Integer.parseInt(parts[1]),
                    Float.parseFloat(parts[2]),
                    Long.parseLong(parts[3]));
        }
        throw new IllegalArgumentException("每行必須是 4 fields");
    }

    /** @return id of the rating user */
    public int getUserId() {
        return this.userId;
    }

    /** @return id of the rated movie */
    public int getMovieId() {
        return this.movieId;
    }

    /** @return score given by the user */
    public float getRating() {
        return this.rating;
    }

    /** @return timestamp attached to the rating */
    public long getTimestamp() {
        return this.timestamp;
    }
}
/**
 * Demo entry point: fits an ALS model on the ratings file, prints the RMSE measured on a
 * held-out split, and logs recommendations for all users, all movies, and small subsets
 * of each. The Spark session is stopped when the demo finishes or fails.
 *
 * @param args command-line arguments (unused)
 */
public static void main(String[] args) {
    // Recommendations requested per user / per movie when recommending for everyone.
    final int topN = 5;
    // How many distinct users / movies form a subset, and recommendations per subset member.
    final int subsetSize = 3;
    final int subsetTopN = 10;

    SparkSession spark = SparkUtils.buildSparkSession();
    try {
        JavaRDD<Rating> ratingsRDD = spark.read()
                .textFile(SparkUtils.getFilePath("data/mllib/als/sample_movielens_ratings.txt"))
                .javaRDD()
                .map(Rating::parseRating);
        ratingsRDD.cache();
        Dataset<Row> ratings = spark.createDataFrame(ratingsRDD, Rating.class);
        Dataset<Row>[] splits = ratings.randomSplit(new double[] { 0.8, 0.2 });
        Dataset<Row> training = splits[0];
        Dataset<Row> test = splits[1];

        // Build the ALS recommendation model from the training data.
        ALS als = new ALS()
                .setMaxIter(5)
                .setRegParam(0.01)
                .setUserCol("userId")
                .setItemCol("movieId")
                .setRatingCol("rating");
        ALSModel model = als.fit(training);

        // Evaluate the model by computing the RMSE on the test data. The cold start
        // strategy is set to 'drop' so the evaluation metric does not come out as NaN.
        model.setColdStartStrategy("drop");
        Dataset<Row> predictions = model.transform(test);
        RegressionEvaluator evaluator = new RegressionEvaluator()
                .setMetricName("rmse")
                .setLabelCol("rating")
                .setPredictionCol("prediction");
        double rmse = evaluator.evaluate(predictions);
        System.out.println("Root-mean-square error = " + rmse);

        // Top-N movie recommendations for every user.
        Dataset<Row> userRecs = model.recommendForAllUsers(topN);
        userRecs.cache();
        System.out.println(userRecs.count());
        logRows(userRecs);
        log.warn("以上是:為每個使用者生成前{}個電影推薦", topN);

        // Top-N user recommendations for every movie.
        Dataset<Row> movieRecs = model.recommendForAllItems(topN);
        movieRecs.cache();
        System.out.println(movieRecs.count());
        logRows(movieRecs);
        log.warn("以上是:為每部電影生成前{}個使用者推薦", topN);

        // Movie recommendations for a small set of users. Datasets are lazy, so the
        // rows are collected and logged; otherwise this step would never execute.
        Dataset<Row> users = ratings.select(als.getUserCol()).distinct().limit(subsetSize);
        Dataset<Row> userSubsetRecs = model.recommendForUserSubset(users, subsetTopN);
        logRows(userSubsetRecs);
        log.warn("以上是:為指定的一組使用者生成前{}個電影推薦", subsetTopN);

        // User recommendations for a small set of movies.
        Dataset<Row> movies = ratings.select(als.getItemCol()).distinct().limit(subsetSize);
        Dataset<Row> movieSubSetRecs = model.recommendForItemSubset(movies, subsetTopN);
        logRows(movieSubSetRecs);
        log.warn("以上是:為指定的電影集生成前{}個使用者推薦", subsetTopN);
    } finally {
        // Release the session even when a step above throws.
        spark.stop();
    }
}

/** Collects every row of {@code rows} to the driver and logs each one at WARN level. */
private static void logRows(Dataset<Row> rows) {
    for (Row row : rows.collectAsList()) {
        log.warn(row.toString());
    }
}
}
資料
資料採用從 Spark 官方下載的發行包中
data/mllib/als
目錄裡的 sample_movielens_ratings.txt
結果
[0,WrappedArray([28,4.2635937], [92,3.9503636], [76,3.7858698], [39,3.3879426], [2,2.999741])]
[WARN][10,WrappedArray([2,3.4163194], [53,3.241712], [25,3.0074506], [42,2.8061774], [87,2.728731])]
[WARN][20,WrappedArray([94,4.0460143], [22,3.6154926], [77,3.612016], [46,3.5636027], [88,3.4248996])]
[WARN][1,WrappedArray([8,5.209615], [55,4.414016], [39,4.104884], [68,3.8478084], [83,3.808718])]
[WARN][11,WrappedArray([23,5.2716155], [30,4.875259], [79,4.8541327], [46,4.577077], [66,4.0911093])]
[WARN][21,WrappedArray([53,5.354313], [2,4.206762], [74,3.9279747], [87,3.827361], [4,2.956087])]
[WARN][2,WrappedArray([39,5.098097], [93,5.0468936], [83,5.011127], [8,4.8800406], [63,4.520872])]
[WARN][12,WrappedArray([43,6.6539083], [46,5.720525], [35,5.1227217], [64,4.977849], [27,4.922715])]
[WARN][22,WrappedArray([30,5.1885133], [51,5.1620464], [75,4.8899145], [22,4.611838], [23,4.3775067])]
[WARN][3,WrappedArray([51,4.9186983], [77,4.49184], [80,4.014811], [18,3.9958167], [88,3.9081008])]
[WARN][13,WrappedArray([93,3.5169842], [41,3.1136718], [70,3.0926945], [92,2.9845672], [83,2.8850255])]
[WARN][23,WrappedArray([55,5.4747477], [62,5.2821107], [32,5.1058264], [49,4.683562], [48,4.589707])]
[WARN][4,WrappedArray([29,3.9669492], [83,3.8033702], [93,3.7849288], [52,3.7831538], [41,3.7319896])]
[WARN][14,WrappedArray([46,6.290769], [8,6.004507], [92,5.8273807], [52,4.9534802], [76,4.79875])]
[WARN][24,WrappedArray([30,5.1304946], [98,5.1174364], [90,5.0042753], [96,4.570305], [53,4.503151])]
[WARN][5,WrappedArray([62,4.6407185], [55,4.5774517], [32,4.542769], [49,4.09161], [68,3.8641832])]
[WARN][15,WrappedArray([46,4.871684], [90,4.0514927], [76,3.8044252], [92,3.6267335], [1,3.590583])]
[WARN][25,WrappedArray([25,4.522156], [89,4.2942066], [28,4.2211885], [39,4.097394], [44,3.761376])]
[WARN][6,WrappedArray([25,4.8586254], [58,3.9096472], [43,3.267941], [47,3.1259787], [93,3.1119614])]
[WARN][16,WrappedArray([90,4.995305], [85,4.9076405], [51,4.708634], [76,4.307944], [39,3.9640565])]
[WARN][26,WrappedArray([51,5.7545366], [22,5.566987], [94,5.0986133], [30,5.0660195], [46,4.901876])]
[WARN][7,WrappedArray([25,5.101813], [47,3.8634279], [29,3.861642], [85,3.8062382], [58,3.6149733])]
[WARN][17,WrappedArray([7,5.6873364], [77,5.4233136], [46,5.0695214], [17,4.9160004], [90,4.6762195])]
[WARN][27,WrappedArray([68,4.560275], [30,4.4277897], [49,4.311203], [88,3.8314245], [89,3.653545])]
[WARN][8,WrappedArray([29,5.076687], [53,5.0071025], [52,4.8734946], [41,4.29465], [70,4.116517])]
[WARN][18,WrappedArray([28,4.99047], [39,4.778125], [89,4.430275], [33,4.2363925], [44,3.5871475])]
[WARN][28,WrappedArray([48,5.7651014], [55,5.6089025], [18,5.2874613], [91,5.203947], [92,5.1611648])]
[WARN][9,WrappedArray([27,5.7635093], [49,4.927815], [7,4.890088], [85,4.3361597], [17,4.1737056])]
[WARN][19,WrappedArray([90,4.075626], [94,3.6695158], [98,3.4434488], [46,3.3110263], [32,3.1066573])]
[WARN][29,WrappedArray([30,5.3171983], [90,5.259403], [46,4.761197], [63,4.126344], [32,4.1161675])]
[WARN]以上是:為每個使用者生成前5個電影推薦
pom
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-mllib_2.11</artifactId>
<version>2.4.0</version>
<scope>compile</scope>
</dependency>