自定義開發Spark ML機器學習類
初窺門徑
Spark的MLlib元件內建實現了很多常見的機器學習演算法,包括資料抽取,分類,聚類,關聯分析,協同過濾等等.
然鵝,內建的演算法並不能滿足我們所有的需求,所以我們還是經常需要自定義ML演算法.
MLlib提供的API分為兩類:
- 1.基於DataFrame的API,屬於spark.ml包.
- 2.基於RDD的API, 屬於spark.mllib包.
從Spark 2.0開始,Spark的API全面從RDD轉向DataFrame,MLlib也是如此,官網原話如下:
Announcement: DataFrame-based API is primary API
The MLlib RDD-based API is now in maintenance mode.
所以本文將介紹基於DataFrame的自定義ml類編寫方法.不涉及具體演算法,只講擴充套件ml類的方法.
略知一二
官方文件並沒有介紹如何自定義ml類,所以只有從原始碼入手,看看原始碼裡面是怎麼實現的.
找一個最簡單的內建演算法入手,這個演算法就是內建的分詞器,Tokenizer.
Tokenizer只是簡單的將文字以空白部分進行分割,只適合給英文進行分詞,所以它的實現及其簡短,原始碼如下:
package org.apache.spark.ml.feature
import org.apache.spark.annotation.Since
import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.param._
import org.apache.spark.ml.util._
import org.apache.spark.sql.types.{ArrayType, DataType, StringType}
/**
* A tokenizer that converts the input string to lowercase and then splits it by white spaces.
*
* @see [[RegexTokenizer]]
*/
@Since("1.2.0")
class Tokenizer @Since("1.4.0") (@Since("1.4.0") override val uid: String)
extends UnaryTransformer[String, Seq[String], Tokenizer] with DefaultParamsWritable {
@Since("1.2.0")
def this() = this(Identifiable.randomUID("tok"))
override protected def createTransformFunc: String => Seq[String] = {
_.toLowerCase.split("\\s")
}
override protected def validateInputType(inputType: DataType): Unit = {
require(inputType == StringType, s"Input type must be string type but got $inputType.")
}
override protected def outputDataType: DataType = new ArrayType(StringType, true)
@Since("1.4.1")
override def copy(extra: ParamMap): Tokenizer = defaultCopy(extra)
}
@Since("1.6.0")
object Tokenizer extends DefaultParamsReadable[Tokenizer] {
@Since("1.6.0")
override def load(path: String): Tokenizer = super.load(path)
}
簡單分析下原始碼:
- Tokenizer繼承了UnaryTransformer類.unary是’一元’的意思,也是說這個類實現的是類似一元函式的功能,一個輸入變數,一個輸出.直接看UnaryTransformer的原始碼註釋:
/**
* :: DeveloperApi ::
* Abstract class for transformers that take one input column, apply transformation, and output the
* result as a new column.
*/
DeveloperApi表明這是一個開發級API,開發者可以用,不會有許可權問題(原始碼中有很多private[spark]的類,是不允許外部呼叫的).
註釋的大意就是:這是一個為實現transformers準備的抽象類,以一個欄位(列)為輸入,輸出一個新欄位(列).
所以實際上就是實現一個Transformer,只是這個Transformer有指定的輸入欄位和輸出欄位.
- UnaryTransformer類中只有兩個抽象方法.
一個是createTransformFunc,是最核心的方法,這個方法需要返回一個函式,這個函式的引數即Transformer的輸入欄位的值,返回值為Transformer的輸出欄位的值.看看Tokenizer中的實現,就明白了.
另一個是outputDataType,這個方法用來返回輸出欄位的型別.
validateInputType方法是用來檢查輸入欄位型別的,看需要實現.
Tokenizer混入了DefaultParamsWritable特質,使得自己可以被儲存.
對應的object Tokenizer伴生物件,用來讀取已儲存的Tokenizer.值得注意的是,Transformer類是PipelineStage類的子類,所以Transformer的子類,包括我們自定義的,是可以直接用在ML Pipelines中的.這就厲害了,說明自定義的演算法類,可以無縫與內建機器學習演算法打配合,還能利用Pipeline的調優工具(model selection,Cross-Validation等).
初出茅廬
看完原始碼,基本套路已經明瞭,不如動手抄一個,不,敲一個.
依葫蘆畫瓢,實現一個正則提取的Transformer.
import util.matching.Regex
import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.param.Param
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.types._
/**
* 正則提取器
* 將匹配指定正則表示式的全部子字串,提取到array[string]中.
*/
class RegexExtractor(override val uid: String)
extends UnaryTransformer[String, Seq[String], RegexExtractor] {
def this() = this(Identifiable.randomUID("RegexExtractor"))
/**
* 引數:正則表示式
*
* @group param
*/
final val regex = new Param[Regex](this, "RegexExpr", "正則表示式")
/** @group setParam */
def setRegexExpr(value: String): this.type = set(regex, new Regex(value))
override protected def outputDataType: DataType = new ArrayType(StringType, true)
override protected def validateInputType(inputType: DataType): Unit = {
require(inputType == DataTypes.StringType,
s"Input type must be string type but got $inputType."
)
}
override protected def createTransformFunc: String => Seq[String] = {
parseContent
}
/**
* 資料處理
*/
private def parseContent(text: String): Seq[String] = {
if (text == null || text.isEmpty) {
return Seq.empty[String]
}
$(regex).findAllIn(text).toSeq
}
}
這個類結構與Tokenizer原始碼基本差不多,多用到的Param類,是一個引數的包裝類.
作用是self-contained documentation and optionally default value.
其實就是把引數的值,文件,預設值等屬性組合成一個類,方便呼叫.
比如上面定義的regex引數,就可以用$(regex)這樣的方式直接呼叫.
另外在org.apache.spark.ml.param中有很多內建的Param類,可以直接使用.
同時org.apache.spark.ml.param.shared中有很多輔助引入引數的特質,比如HasInputCols特質,你的自定義Transformer只要混入這個特質就擁有了inputCols引數.不過目前shared中特質的作用域是private[ml],也就是說不能直接引用,而是要copy一份程式碼到自己的專案,並修改作用域才行.
關於這個作用域的問題,有人在spark的jira上提到,提議將其作為DeveloperApi開放出來,我也投了一票表示支援.後來在2017年11月終於resolved,該問題將在Spark2.3.0中解決.詳情戳我
粗懂皮毛
自定義的類寫好了,該怎麼用呢? 當然是跟內建的一樣啦.上栗子:
val regex="nidezhengze"
val tranTitle = new RegexExtractor()
.setInputCol("title")
.setOutputCol("title_price_texts")
.setRegexExpr(regex)
val pipeline = new Pipeline().setStages(Array(
tranTitle
))
val matched = pipeline.fit(data).transform(data)
打完收功
到這裡,開發簡單Transform的套路已經清楚了,不過這裡實現的功能比較類似於一個UDF,只能對dataset的一個欄位進行處理,而且是逐行處理,並不能根據多行資料進行處理,實現視窗函式類似的功能,而且也沒有涉及模型的輸出.如果要開發更復雜的演算法,甚至進行模型訓練,就需要更深入的瞭解MLlib了,閱讀原始碼是個好途徑.
下回再說.