1. 程式人生 > 實用技巧 >Spark ML 之 推薦演算法專案(上)

Spark ML 之 推薦演算法專案(上)

一、整體流程

二、具體召回流程

三、程式碼實現

0、過濾不可推薦的商品(已下架、成人用品、菸酒等)

package com.njbdqn.filter

import com.njbdqn.util.{HDFSConnection, MYSQLConnection}
import org.apache.spark.sql.SparkSession

object BanGoodFilter {

  /**
   * Removes goods that must not be recommended and stores the remaining goods
   * on HDFS under /myshops/dwd_good.
   *
   * Only rows of the MySQL `goods` table with `is_sale=0` are kept. According to
   * the original author's note, those are the goods that are still on sale
   * (not yet sold / taken off the shelf).
   * NOTE(review): the section heading also mentions adult products and tobacco,
   * but no such filter is applied here.
   *
   * @param spark active SparkSession used for the MySQL read and the HDFS write
   */
  def ban(spark: SparkSession): Unit = {
    // Read the raw goods table.
    val goodsDf = MYSQLConnection.readMySql(spark, "goods")
    // Drop goods that are off the shelf (already sold); keep the unsold ones.
    val gd = goodsDf.filter("is_sale=0")
    HDFSConnection.writeDataToHDFS("/myshops/dwd_good", gd)
  }
}

1、全局熱點召回:計算全站熱賣商品,並 cross join 到每個使用者(使每個使用者都有可推薦的商品)

package com.njbdqn.call

import com.njbdqn.util.{HDFSConnection, MYSQLConnection}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{desc, row_number, sum}

/**
 * Global hot-sale recall: recommends the best-selling goods to every user.
 */
object GlobalHotCall {

  /**
   * Ranks goods by total quantity sold (sum of `buy_num` over `orderItems`).
   * It keeps the `topN` best sellers and cross joins them with every customer,
   * so that each user has at least these candidates.
   * The result (cust_id, good_id, rank) is written to HDFS under
   * /myshops/dwd_hotsell.
   *
   * @param spark active SparkSession
   * @param topN  number of hot goods given to each user; defaults to 100, the
   *              value previously hard-coded (an old comment said 30, which
   *              did not match the code)
   */
  def hotSell(spark: SparkSession, topN: Int = 100): Unit = {
    import spark.implicits._
    val oitab = MYSQLConnection.readMySql(spark, "orderItems")
    // Total quantity sold per good, ranked from the best seller downwards.
    // Filtering on the rank (rather than limit()) states explicitly which rows
    // are kept.
    val topGoods = oitab
      .groupBy("good_id")
      .agg(sum("buy_num").alias("sellnum"))
      .withColumn("rank", row_number().over(Window.orderBy(desc("sellnum"))))
      .filter($"rank" <= topN)
    // Every customer receives the same hot-sale candidates.
    val custstab = MYSQLConnection.readMySql(spark, "customs").select($"cust_id")
    val res = custstab.crossJoin(topGoods)
      .select($"cust_id", $"good_id", $"rank")
    HDFSConnection.writeDataToHDFS("/myshops/dwd_hotsell", res)
  }
}

2、分組召回

詳細見https://www.cnblogs.com/sabertobih/p/13824739.html

資料處理,歸一化:

package com.njbdqn.datahandler

import com.njbdqn.util.{HDFSConnection, MYSQLConnection}
import org.apache.spark.ml.clustering.KMeans
import org.apache.spark.ml.feature.{MinMaxScaler, StringIndexer, VectorAssembler}
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, count, current_date, datediff, desc, min, row_number, sum, udf}
import org.apache.spark.sql.types.DoubleType

object KMeansHandler {
  /** Buckets a membership score into levels 1-4 (<100, <500, <1000, otherwise). */
  val func_membership = udf {
    (score: Int) =>
      score match {
        case i if i < 100  => 1
        case i if i < 500  => 2
        case i if i < 1000 => 3
        case _             => 4
      }
  }

  /**
   * Age in whole years, derived from an ID number and the current date.
   *
   * The birth date is read from characters [6, 14) of `idno` as yyyyMMdd.
   * `now` is split on "-" into year, month and day.
   * The result is SQL null (None) when either input is null, too short or
   * not numeric, instead of throwing and failing the whole job. In
   * user_act_info the resulting null age is turned into 0 by na.fill(0).
   */
  val func_bir = udf {
    (idno: String, now: String) =>
      scala.util.Try {
        val year = idno.substring(6, 10).toInt
        val month = idno.substring(10, 12).toInt
        val day = idno.substring(12, 14).toInt

        val dts = now.split("-")
        val nowYear = dts(0).toInt
        val nowMonth = dts(1).toInt
        val nowDay = dts(2).toInt

        // One year less if this year's birthday has not been reached yet.
        val hadBirthday = nowMonth > month || (nowMonth == month && nowDay >= day)
        if (hadBirthday) nowYear - year else nowYear - 1 - year
      }.toOption
  }

  /** Buckets an age into bands 1-7 (<10, <18, <23, <35, <50, <70, otherwise). */
  val func_age = udf {
    (num: Int) =>
      num match {
        case n if n < 10 => 1
        case n if n < 18 => 2
        case n if n < 23 => 3
        case n if n < 35 => 4
        case n if n < 50 => 5
        case n if n < 70 => 6
        case _           => 7
      }
  }

  /** Buckets a points balance into levels 1-3 (<100, <500, otherwise). */
  val func_userscore = udf {
    (sc: Int) =>
      sc match {
        case s if s < 100 => 1
        case s if s < 500 => 2
        case _            => 3
      }
  }

  /** Buckets a login count into 1 (<500) or 2 (otherwise). */
  val func_logincount = udf {
    (sc: Int) =>
      sc match {
        case s if s < 500 => 1
        case _            => 2
      }
  }

  /**
   * Combines each active customer's attributes with their purchase behaviour.
   *
   * Produces one row per `customs` row with `active != 0`. Columns:
   *  - the selected raw attributes;
   *  - compId: the company name indexed by StringIndexer;
   *  - mslevel: bucketed membership level;
   *  - reg_date / lasttime: days since the earliest create_at / last_login_time;
   *  - age, user_score, logincount: bucketed values;
   *  - buycount: number of orders;
   *  - pay: sum of buy_num * price over goods found in /myshops/dwd_good.
   * Customers without orders get 0 through the left joins followed by na.fill(0).
   */
  def user_act_info(spark: SparkSession): DataFrame = {
    val featureDataTable = MYSQLConnection.readMySql(spark, "customs").filter("active!=0")
      .select("cust_id", "company", "province_id", "city_id", "district_id"
        , "membership_level", "create_at", "last_login_time", "idno", "biz_point", "sex", "marital_status", "education_id"
        , "login_count", "vocation", "post")
    // Goods with their unit price.
    val goodTable = HDFSConnection.readDataToHDFS(spark, "/myshops/dwd_good").select("good_id", "price")
    // Orders.
    val orderTable = MYSQLConnection.readMySql(spark, "orders").select("ord_id", "cust_id")
    // Order line items.
    val orddetailTable = MYSQLConnection.readMySql(spark, "orderItems").select("ord_id", "good_id", "buy_num")
    // Turn the company name into a numeric index.
    val compIndex = new StringIndexer().setInputCol("company").setOutputCol("compId")
    import spark.implicits._
    // Number of orders per customer.
    val tmp_bc = orderTable.groupBy("cust_id").agg(count($"ord_id").as("buycount"))
    // Money spent per customer (quantity * unit price).
    val tmp_pay = orderTable.join(orddetailTable, Seq("ord_id"), "inner")
      .join(goodTable, Seq("good_id"), "inner")
      .groupBy("cust_id")
      .agg(sum($"buy_num" * $"price").as("pay"))

    compIndex.fit(featureDataTable).transform(featureDataTable)
      .withColumn("mslevel", func_membership($"membership_level"))
      .withColumn("min_reg_date", min($"create_at").over())
      .withColumn("reg_date", datediff($"create_at", $"min_reg_date"))
      .withColumn("min_login_time", min("last_login_time").over())
      .withColumn("lasttime", datediff($"last_login_time", $"min_login_time"))
      .withColumn("age", func_age(func_bir($"idno", current_date())))
      .withColumn("user_score", func_userscore($"biz_point"))
      .withColumn("logincount", func_logincount($"login_count"))
      // Some customers never ordered or spent nothing, so they are missing from
      // the right-hand tables: left join keeps every customer.
      .join(tmp_bc, Seq("cust_id"), "left")
      .join(tmp_pay, Seq("cust_id"), "left")
      .na.fill(0)
      .drop("company", "membership_level", "create_at", "min_reg_date"
        , "last_login_time", "min_login_time", "idno", "biz_point", "login_count")
  }

  /** Columns assembled into the clustering feature vector (everything except cust_id). */
  private val featureCols = Array("province_id", "city_id", "district_id", "sex", "marital_status",
    "education_id", "vocation", "post", "compId", "mslevel", "reg_date", "lasttime", "age",
    "user_score", "logincount", "buycount", "pay")

  /**
   * Builds min-max normalised feature vectors for user clustering.
   *
   * @return DataFrame with columns cust_id (cast to Double) and feature (vector)
   */
  def user_group(spark: SparkSession): DataFrame = {
    val df = user_act_info(spark)
    // VectorAssembler needs numeric input, so every column (cust_id included)
    // is cast to Double.
    val num_fmt = df.select(df.columns.map(f => col(f).cast(DoubleType)): _*)
      // Non-numeric source columns are not touched by na.fill(0), and values that
      // fail the cast become null. VectorAssembler throws on null by default, so
      // such features default to 0.
      .na.fill(0.0, featureCols)
    val ofdf = new VectorAssembler()
      .setInputCols(featureCols)
      .setOutputCol("orign_feature")
      .transform(num_fmt)
      .select("cust_id", "orign_feature")
    // Rescale every feature to [0, 1]; fit learns per-feature min/max on ofdf.
    new MinMaxScaler().setInputCol("orign_feature").setOutputCol("feature")
      .fit(ofdf)
      .transform(ofdf)
      .select("cust_id", "feature")
  }
}

kmeans計算分組召回:

package com.njbdqn.call

import com.njbdqn.datahandler.KMeansHandler
import com.njbdqn.util.{HDFSConnection, MYSQLConnection}
import org.apache.spark.ml.clustering.KMeans
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._


/**
 * Cluster-based recall: users are grouped with KMeans and each user is
 * recommended the goods most bought within their group.
 */
object GroupCall {

  /**
   * Clusters users on the features from KMeansHandler.user_group.
   * For every group it keeps the `topN` goods with the most order-item rows,
   * then joins them back to each user.
   * The result (groups, cust_id, good_id, group_buy_count, rank) is written to
   * HDFS under /myshops/dwd_kMeans.
   *
   * @param spark active SparkSession
   * @param k     number of KMeans clusters (default 40, previously hard-coded)
   * @param topN  goods kept per group (default 30, previously hard-coded)
   */
  def calc_groups(spark: SparkSession, k: Int = 40, topN: Int = 30): Unit = {
    // How K was chosen: the original author trained KMeans for K = 2..40 and
    // plotted the clustering cost against K (elbow method).
    // NOTE(review): KMeansModel.computeCost, used by that experiment, was
    // removed in Spark 3.0; use model.summary.trainingCost or
    // ClusteringEvaluator if the experiment is revived.
    import spark.implicits._
    val orderTable = MYSQLConnection.readMySql(spark, "orders").select("ord_id", "cust_id")
    val orderItems = MYSQLConnection.readMySql(spark, "orderItems").select("ord_id", "good_id", "buy_num")
    // Used by both fit() and transform(): cache so the feature pipeline is not
    // recomputed from MySQL for the transform.
    val features = KMeansHandler.user_group(spark).cache()

    // (cust_id, groups): cluster index of every user.
    val userGroups = new KMeans().setFeaturesCol("feature").setK(k)
      .fit(features)
      .transform(features)
      .drop("feature")
      .withColumnRenamed("prediction", "groups")
      .cache()

    // Per group, rank goods by the number of order-item rows of the group's users.
    val byGroupPopularity = Window.partitionBy("groups").orderBy(desc("group_buy_count"))
    val groupGoods = userGroups
      .join(orderTable, Seq("cust_id"), "inner")
      .join(orderItems, Seq("ord_id"), "inner")
      .na.fill(0)
      .groupBy("groups", "good_id")
      .agg(count("ord_id").as("group_buy_count"))
      .withColumn("rank", row_number().over(byGroupPopularity))
      .filter($"rank" <= topN)

    // Every user receives the goods recommended for their group.
    val df5 = userGroups.join(groupGoods, Seq("groups"), "inner")
    HDFSConnection.writeDataToHDFS("/myshops/dwd_kMeans", df5)
  }
}

3、ALS協同過濾召回

ALS資料預處理:使用者行為需先量化成數字評分(score),且 User-Item 稀疏矩陣的每一列(使用者編號、商品編號、評分)都必須是數字,最後再把稀疏表轉換成 Rating 集合。

package com.njbdqn.datahandler

import com.njbdqn.util.{HDFSConnection, MYSQLConnection}
import org.apache.spark.mllib.recommendation.Rating
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{row_number, sum, udf}

object ALSDataHandler {
  /**
   * Assigns every good a dense 1-based number `gid`, ordered by good_id.
   * User/good ids may contain non-digits, while ALS needs Int ids.
   */
  def goods_to_num(spark: SparkSession): DataFrame = {
    import spark.implicits._
    val byGoodId = Window.orderBy("good_id")
    HDFSConnection.readDataToHDFS(spark, "/myshops/dwd_good").select("good_id", "price")
      .select($"good_id", row_number().over(byGoodId).alias("gid")).cache()
  }

  /** Assigns every customer a dense 1-based number `uid`, ordered by cust_id. */
  def user_to_num(spark: SparkSession): DataFrame = {
    import spark.implicits._
    val byCustId = Window.orderBy("cust_id")
    MYSQLConnection.readMySql(spark, "customs")
      .select($"cust_id", row_number().over(byCustId).alias("uid")).cache()
  }

  /** Weight of one user action: BROWSE=1, COLLECT=2, BUYCAR=3, anything else 8. */
  val actToNum = udf {
    (str: String) =>
      str match {
        case "BROWSE"  => 1
        case "COLLECT" => 2
        case "BUYCAR"  => 3
        case _         => 8
      }
  }

  /** One space-separated record of the user action log. */
  case class UserAction(act: String, act_time: String, cust_id: String, good_id: String, browse: String)

  /**
   * Builds the ALS training set from the user action logs.
   *
   * The score of a (user, good) pair is the sum of actToNum over all its
   * actions. Ids are mapped to the dense uid/gid numbers; pairs whose
   * good/customer is unknown are dropped by the inner joins.
   *
   * @param spark   active SparkSession
   * @param logPath glob of the log files (defaults to the previously
   *                hard-coded local path)
   * @return one Rating(uid, gid, score) per (user, good) pair
   */
  def als_data(spark: SparkSession, logPath: String = "file:///D:/logs/virtualLogs/*.log"): RDD[Rating] = {
    val goodstab: DataFrame = goods_to_num(spark)
    val custstab: DataFrame = user_to_num(spark)
    import spark.implicits._
    val df = spark.sparkContext.textFile(logPath)
      .map(_.split(" "))
      // Skip blank or truncated lines instead of failing the job with an
      // ArrayIndexOutOfBoundsException.
      .filter(_.length >= 5)
      .map(arr => UserAction(arr(0), arr(1), arr(2), arr(3), arr(4)))
      .toDF()
      .drop("act_time", "browse")
      .select($"cust_id", $"good_id", actToNum($"act").alias("score"))
      .groupBy("cust_id", "good_id")
      .agg(sum($"score").alias("score"))
    // Keep only (gid, uid, score).
    val df2 = df.join(goodstab, Seq("good_id"), "inner")
      .join(custstab, Seq("cust_id"), "inner")
      .select("gid", "uid", "score")
    // Typed getAs: an untyped getAs("x") is inferred as Nothing and fails with a
    // ClassCastException at runtime.
    df2.rdd.map { row =>
      Rating(
        row.getAs[Number]("uid").intValue,
        row.getAs[Number]("gid").intValue,
        row.getAs[Number]("score").doubleValue
      )
    }
  }
}

ALS訓練:訓練完成後,需要把數字編號(uid/gid)還原成原始的使用者/商品編號(cust_id/good_id)。

package com.njbdqn.call

import com.njbdqn.datahandler.ALSDataHandler
import com.njbdqn.datahandler.ALSDataHandler.{goods_to_num, user_to_num}
import com.njbdqn.util.{HDFSConnection, MYSQLConnection}
import org.apache.spark.mllib.recommendation.{ALS, Rating}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SparkSession}

object ALSCall {

  /**
   * ALS collaborative-filtering recall.
   *
   * Trains an ALS model on all ratings from ALSDataHandler.als_data, using
   * scores derived from user action logs. It then:
   *  - recommends the `topN` goods to every user;
   *  - maps the dense uid/gid numbers back to cust_id/good_id;
   *  - writes (cust_id, good_id, rank) to HDFS under /myshops/dwd_ALS_Iter20.
   *
   * NOTE(review): the "rank" column holds the predicted rating (higher means a
   * better match), not a 1-based position. The name is presumably kept for
   * downstream readers.
   *
   * @param spark active SparkSession
   * @param topN  goods recommended per user (default 30, previously hard-coded)
   */
  def als_call(spark: SparkSession, topN: Int = 30): Unit = {
    val goodstab: DataFrame = goods_to_num(spark)
    val custstab: DataFrame = user_to_num(spark)
    val alldata: RDD[Rating] = ALSDataHandler.als_data(spark).cache()
    // Trains on ALL ratings: no hold-out split is made. An 80/20 randomSplit
    // for evaluation was left commented out by the original author.
    // NOTE(review): setCheckpointInterval is ignored unless a checkpoint
    // directory has been set with SparkContext.setCheckpointDir.
    val model = new ALS()
      .setCheckpointInterval(2)
      .setRank(10)
      .setIterations(20)
      .setLambda(0.01)
      .setImplicitPrefs(false)
      .run(alldata)
    import spark.implicits._
    // Top-N recommendations per user as (uid, gid, predicted rating).
    val df5 = model.recommendProductsForUsers(topN)
      .flatMap { case (user, ratings) => ratings.map(r => (user, r.product, r.rating)) }
      .toDF("uid", "gid", "rank")
      // Map the dense numbers back to the original ids.
      .join(goodstab, Seq("gid"), "inner")
      .join(custstab, Seq("uid"), "inner")
      .select($"cust_id", $"good_id", $"rank")
    HDFSConnection.writeDataToHDFS("/myshops/dwd_ALS_Iter20", df5)
  }
}