背景介紹
我們的系統(tǒng)有一小部分機(jī)器學(xué)習(xí)模型識(shí)別需求享怀,因?yàn)榉N種原因,最終選用了Spark MLlib來進(jìn)行訓(xùn)練和預(yù)測趟咆。MLlib的Pipeline設(shè)計(jì)很好地契合了一個(gè)機(jī)器學(xué)習(xí)流水線添瓷,在模型訓(xùn)練和效果驗(yàn)證階段,pipeline可以簡化開發(fā)流程值纱,然而在預(yù)測階段鳞贷,MLlib pipeline的表現(xiàn)有點(diǎn)差強(qiáng)人意。
問題描述
某個(gè)模型的輸入為一個(gè)字符串虐唠,假設(shè)長度為N搀愧,在我們的場景里面這個(gè)N一般不會(huì)大于10。特征也很簡單疆偿,對(duì)于每一個(gè)輸入咱筛,可以在O(N)的時(shí)間計(jì)算出特征向量,分類器選用的是隨機(jī)森林杆故。
對(duì)于這樣的預(yù)測任務(wù)迅箩,直觀上感覺應(yīng)該非常快处铛,初步估計(jì)10ms以內(nèi)出結(jié)果饲趋。但是通MLlib pipeline的transform預(yù)測結(jié)果預(yù)測時(shí)拐揭,性能在86ms左右(2000條query平均響應(yīng)時(shí)間)。而且奕塑,query和query之間在輸入相同的情況下堂污,也存在響應(yīng)時(shí)間波動(dòng)的問題。
預(yù)測性能優(yōu)化
先說說響應(yīng)時(shí)間波動(dòng)的問題龄砰,每一條query的輸入都是一樣的盟猖,也就排除了特征加工時(shí)的計(jì)算量波動(dòng)的問題,因?yàn)檎麄€(gè)計(jì)算中消耗內(nèi)存極少换棚,且測試時(shí)內(nèi)存足夠扒披,因?yàn)橐才懦齡c導(dǎo)致預(yù)測性能抖動(dòng)的問題。那么剩下的只有Spark了圃泡,Spark可能在做某些事情導(dǎo)致了預(yù)測性能抖動(dòng)。通過查看log信息愿险,可以印證這個(gè)觀點(diǎn)颇蜡。
從日志中截取了一小段,里面有大量的清理broadcast變量信息辆亏。這也為后續(xù)性能優(yōu)化提供了一個(gè)方向风秤。(下面會(huì)有部分MLlib源碼,源碼基于Spark2.3)
在MLlib中扮叨,是調(diào)用PipelineModel的transform方法進(jìn)行預(yù)測缤弦,該方法會(huì)調(diào)用pipeline的每一個(gè)stage內(nèi)的Transformer的transform方法來對(duì)輸入的DataFrame/DataSet進(jìn)行轉(zhuǎn)換。
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
stages.foldLeft(dataset.toDF)((cur, transformer) => transformer.transform(cur))
}
下面彻磁,我們先看看訓(xùn)練好的隨機(jī)森林模型(RandomForestClassificationModel)在預(yù)測時(shí)做了些什么吧
override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
val bcastModel = dataset.sparkSession.sparkContext.broadcast(this)
val predictUDF = udf { (features: Any) =>
bcastModel.value.predict(features.asInstanceOf[Vector])
}
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
}
重點(diǎn)來了碍沐,終于找到前面說的broadcast的'罪魁禍'了,每次預(yù)測時(shí)衷蜓,MLlib都會(huì)把模型廣播到集群累提。這樣做的好處是方便批處理,但對(duì)于小計(jì)算量磁浇,壓根不需要集群的預(yù)測場景這樣的做法就有點(diǎn)浪費(fèi)資源了:
- 每次預(yù)測都廣播顯然太多余斋陪。
- 因?yàn)槊看味紡V播,所以之前的廣播變量也會(huì)逐漸回收置吓,在回收時(shí)无虚,又反過來影響預(yù)測的性能。
解決辦法
從上述代碼中可以看到衍锚,RandomForestClassificationModel 預(yù)測最根本的地方是在于調(diào)用predict方法友题,輸入是一個(gè)Vector」谷看看predict干了什么
override protected def predict(features: FeaturesType): Double = {
raw2prediction(predictRaw(features))
}
predict分為兩步走:
override protected def predictRaw(features: Vector): Vector = {
// TODO: When we add a generic Bagging class, handle transform there: SPARK-7128
// Classifies using majority votes.
// Ignore the tree weights since all are 1.0 for now.
val votes = Array.fill[Double](numClasses)(0.0)
_trees.view.foreach { tree =>
val classCounts: Array[Double] = tree.rootNode.predictImpl(features).impurityStats.stats
val total = classCounts.sum
if (total != 0) {
var i = 0
while (i < numClasses) {
votes(i) += classCounts(i) / total
i += 1
}
}
}
Vectors.dense(votes)
}
override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
rawPrediction match {
case dv: DenseVector =>
ProbabilisticClassificationModel.normalizeToProbabilitiesInPlace(dv)
dv
case sv: SparseVector =>
throw new RuntimeException("Unexpected error in RandomForestClassificationModel:" +
" raw2probabilityInPlace encountered SparseVector")
}
}
這兩個(gè)方法的輸入和輸出均為vector咆爽,那么我們?nèi)绻堰@兩個(gè)方法反射出來直接用在預(yù)測的特征向量上是不是就可以了梁棠?答案是肯定的。
注意其中的raw2probability
在Spark2.3中的RandomForestClassificationModel中斗埂,簽名變?yōu)榱?code>raw2probabilityInPlace
全面繞開pipeline
前面解決了分類器預(yù)測的性能問題符糊,另外一個(gè)問題就來了。輸入的特征向量怎么來呢呛凶?在一個(gè)MLlib Pipeline流程中男娄,分類器預(yù)測只是最后一步,前面還有多種多樣的特征加工節(jié)點(diǎn)漾稀。我嘗試了將一個(gè)pipeline拆解成兩個(gè)模闲,一個(gè)用于特征加工,一個(gè)用于分類預(yù)測崭捍。用第一個(gè)pipeline加工特征尸折,只繞開第二個(gè),性能顯然是提升了殷蛇,但還沒達(dá)到預(yù)期效果实夹。于是,我有了另外一個(gè)想法:全面繞開pipeline粒梦,對(duì)pipeline的每一步亮航,都避免調(diào)用原生transform接口。這樣的弊端就是匀们,必須重寫pipeline的每一步預(yù)測方法缴淋,然后人肉還原pipeline的預(yù)測流程。流程大致跟上面類似泄朴。
例如:OneHot(說句題外話重抖,這東西在Spark2.3之前的版本是有bug的,詳情參考官方文檔)叼旋。
OneHotEncoderModel的transform方法如下:
@Since("2.3.0")
override def transform(dataset: Dataset[_]): DataFrame = {
val transformedSchema = transformSchema(dataset.schema, logging = true)
val keepInvalid = $(handleInvalid) == OneHotEncoderEstimator.KEEP_INVALID
val encodedColumns = $(inputCols).indices.map { idx =>
val inputColName = $(inputCols)(idx)
val outputColName = $(outputCols)(idx)
val outputAttrGroupFromSchema =
AttributeGroup.fromStructField(transformedSchema(outputColName))
val metadata = if (outputAttrGroupFromSchema.size < 0) {
OneHotEncoderCommon.createAttrGroupForAttrNames(outputColName,
categorySizes(idx), $(dropLast), keepInvalid).toMetadata()
} else {
outputAttrGroupFromSchema.toMetadata()
}
encoder(col(inputColName).cast(DoubleType), lit(idx))
.as(outputColName, metadata)
}
dataset.withColumns($(outputCols), encodedColumns)
}
里面對(duì)feature進(jìn)行轉(zhuǎn)換的關(guān)鍵代碼行是 encoder...
private def encoder: UserDefinedFunction = {
val keepInvalid = getHandleInvalid == OneHotEncoderEstimator.KEEP_INVALID
val configedSizes = getConfigedCategorySizes
val localCategorySizes = categorySizes
// The udf performed on input data. The first parameter is the input value. The second
// parameter is the index in inputCols of the column being encoded.
udf { (label: Double, colIdx: Int) =>
val origCategorySize = localCategorySizes(colIdx)
// idx: index in vector of the single 1-valued element
val idx = if (label >= 0 && label < origCategorySize) {
label
} else {
if (keepInvalid) {
origCategorySize
} else {
if (label < 0) {
throw new SparkException(s"Negative value: $label. Input can't be negative. " +
s"To handle invalid values, set Param handleInvalid to " +
s"${OneHotEncoderEstimator.KEEP_INVALID}")
} else {
throw new SparkException(s"Unseen value: $label. To handle unseen values, " +
s"set Param handleInvalid to ${OneHotEncoderEstimator.KEEP_INVALID}.")
}
}
}
val size = configedSizes(colIdx)
if (idx < size) {
Vectors.sparse(size, Array(idx.toInt), Array(1.0))
} else {
Vectors.sparse(size, Array.empty[Int], Array.empty[Double])
}
}
}
encoder里面關(guān)鍵的是這個(gè)udf仇哆,將其摳出重寫之后直接作用于特征向量。
效果
經(jīng)過測試夫植,全面繞開pipeline之后讹剔,響應(yīng)時(shí)間下降到16ms左右。(2000條query平均響應(yīng)時(shí)間)详民,且不再有抖動(dòng)延欠。