Swing公式
Swing公式
U_i、U_j 分別為 item1 和 item2 的點擊個數。下圖主要講解紅框部分的構建思路，難點也正在於紅框部分的構建；本文重點講解紅框的計算方法，這是我個人試驗了很久後找到的一種較好的解決方式。
思路
思路舉例
注：圖中的兩次過濾可以過濾掉大量數據。解法比較有意思的地方在於用求根公式反推 user1 和 user2 共同點擊的 item 數目：若兩人共同點擊了 n 個 item，則（在未被過濾的情況下）這兩人會共同出現在 n(n-1)/2 個 itemPair 中，由 n(n-1)/2 = c（c 為該用戶對出現的次數）可解得 n = (1 + √(1 + 8c)) / 2。經過我粗略實驗，發現直接使用 itemPair 出現的次數效果反而更好，或許值得調整原模型的 alpha 後再查看效果。
Swing模型構建流程
Swing模型構建流程
思路舉例
代碼直接調用 fitOnline 即可：按照 PvEntity 給出的數據格式構造數據，params 為文件中 SwingParams 的廣播變量。
package com.sohu.mp.rec.itemBased.Swing.main
import com.sohu.mp.rec.itemBased.ItemCF.main.ItemCFManager.{computeSimilarities, loadData}
import com.sohu.mp.rec.itemBased.ItemCF.main.PathManager
import com.sohu.mp.rec.itemBased.Swing.entity.{Item, User}
//import com.sohu.mp.rec.itemBased.Swing.util.SwingUtil.SwingParams
import com.sohu.mp.rec.recall.common.entity.application.base.PvEntity
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
import scala.collection.mutable.ArrayBuffer
object SwingManagerTmp {

  /**
   * Configuration of the Swing pipeline (broadcast to executors by [[fit]]).
   *
   * Filters (lower bounds are exclusive):
   *  - minItemClick / maxItemClick: distinct (user, timestamp) clicks per item, upper bound inclusive.
   *  - minUserClick / maxUserClick: distinct (item, timestamp) clicks per user, upper bound exclusive.
   *  - minPairClickNum / maxPairClickNum: distinct users per item pair, upper bound inclusive.
   * Scoring:
   *  - swingAlpha: smoothing term in 1 / (swingAlpha + |I_u ∩ I_v|).
   *  - userClickLengthFlag: enables the user-activity damping in [[computePairScore]].
   *  - itemTopN: number of similar items kept per item.
   *  - timeSeqFlag: sorts co-clicked items by timestamp in [[constructUserPair]].
   * NOTE(review): sessionLimitFlag, sessionLimitStamp, alpha, directionFlag, zeta, beta,
   * topNItemNum, defaultParallelism and update are not read by any active code in this object.
   */
  case class SwingParams(minItemClick: Int = 100,
                         maxItemClick: Int = 10000000,
                         minUserClick: Int = 2,
                         maxUserClick: Int = 1000,
                         minPairClickNum: Int = 2,
                         maxPairClickNum: Int = 100000,
                         sessionLimitFlag: Boolean = false,
                         sessionLimitStamp: Long = 60L * 60 * 1000,
                         timeSeqFlag: Boolean = false,
                         alpha: Double = -15000D,
                         directionFlag: Boolean = false,
                         zeta: Double = 0.7D,
                         beta: Double = 0.8D,
                         userClickLengthFlag: Boolean = false,
                         itemTopN: Int = 150,
                         swingAlpha: Int = 5,
                         topNItemNum: Int = 5000,
                         defaultParallelism: Int = 500,
                         update: Boolean = false
                        )

  /**
   * Groups raw clicks by item and keeps items whose number of distinct (userId, timestamp)
   * clicks lies in (minItemClick, maxItemClick].
   *
   * @return one row per surviving click, re-keyed by user: (userId, (itemId, timestamp, itemClickCount)).
   */
  def filterItemClickCnt(rawData: RDD[PvEntity],
                         params: Broadcast[SwingParams]):
  RDD[(String, (String, Long, Int))] = {
    rawData
      .map(pv => (pv.item_id, (pv.user_id, pv.timeStamp)))
      .groupByKey()
      .flatMap { case (itemId, clicks) =>
        val p = params.value
        // Measure the full distinct set; truncating it first (as a take(max) would) makes the
        // upper bound a no-op and distorts the click count used later for normalisation.
        val distinctClicks = clicks.toSet
        val size = distinctClicks.size
        if (size > p.minItemClick && size <= p.maxItemClick) {
          distinctClicks.iterator.map { case (userId, time) => (userId, (itemId, time, size)) }
        } else {
          Iterator.empty
        }
      }
  }

  /**
   * Projects each click row to (itemId, itemClickCount).
   * There is one output row per click, so callers de-duplicate (see `distinct()` in [[fitOnline]]).
   */
  def getItemClickCnt(userItemClick: RDD[(String, (String, Long, Int))]): RDD[(String, Int)] =
    userItemClick.map { case (_, (itemId, _, itemClickCnt)) => (itemId, itemClickCnt) }

  /**
   * Builds each user's distinct set of clicked [[Item]]s (itemId + timestamp) and keeps users whose
   * set size lies strictly between minUserClick and maxUserClick.
   */
  def genUserItemSet(rawData: RDD[(String, (String, Long, Int))], params: Broadcast[SwingParams]):
  RDD[(String, Array[Item])] = {
    rawData
      .map { case (userId, (itemId, time, _)) => (userId, Item(itemId, time)) }
      .groupByKey()
      .mapValues(_.toSet.toArray)
      .filter { case (_, itemSet) =>
        itemSet.length > params.value.minUserClick && itemSet.length < params.value.maxUserClick
      }
  }

  /**
   * Scores every user pair with the Swing weight 1 / (swingAlpha + n), where n is the
   * number of items both users clicked.
   *
   * n is not computed directly. If two users share n items, they co-occur in n(n-1)/2 item
   * pairs, so n is recovered from the pair count c as n = (1 + sqrt(1 + 8c)) / 2. Pairs removed
   * by the pair filters are not counted, so this is an estimate.
   */
  def genUserPairScore(userPairItemPair: RDD[((String, String), (String, String))], params: Broadcast[SwingParams]):
  RDD[((String, String), Double)] = {
    userPairItemPair
      .mapValues(_ => 1)
      .reduceByKey(_ + _)
      .mapValues { coPairCount =>
        // 8.0 keeps the product in Double so very large pair counts cannot overflow Int.
        val coClickNum = (1 + math.sqrt(1 + 8.0 * coPairCount)) / 2
        1.0d / (params.value.swingAlpha + coClickNum)
        // Alternative the author reported working better: 1.0d / coPairCount
      }
  }

  /**
   * Expands each item pair into one row per user pair that clicked it:
   * ((userA, userB), (xItemId, yItemId)).
   * User pairs are ordered (userA < userB) so the same two users always share one key.
   */
  def genUserPairItemPair(itemPairUserSet: RDD[((String, String), Array[String])],
                          params: Broadcast[SwingParams]):
  RDD[((String, String), (String, String))] = {
    itemPairUserSet.flatMap { case (itemPair, userIds) =>
      unorderedPairs(userIds.distinct.sorted).map(userPair => (userPair, itemPair))
    }
  }

  /** Looks up the item ids clicked by `userId`; empty when the user is unknown. */
  def getUserItemIds(userId: String,
                     userItemSetMap: Broadcast[scala.collection.Map[String, Array[Item]]]):
  Array[String] =
    userItemSetMap.value.get(userId).map(_.map(_.itemId)).getOrElse(Array.empty[String])

  // Disabled optional re-scoring step (was gated on params.update in fitOnline):
  //def updateItemScore(userPairScore: RDD[((String, String), Double)],
  //                    userItemSetMap: Broadcast[scala.collection.Map[String, Array[Item]]],
  //                    params: Broadcast[SwingParams]):
  //RDD[((String, String), Double)] = {
  //  userPairScore.map { case ((xUserId, yUserId), score) =>
  //    val xUserClickedItem = getUserItemIds(xUserId, userItemSetMap)
  //    val yUserClickedItem = getUserItemIds(yUserId, userItemSetMap)
  //    val coClickedItem = xUserClickedItem.intersect(yUserClickedItem)
  //    ((xUserId, yUserId), computePairScore(coClickedItem, xUserClickedItem.size,
  //      yUserClickedItem.size, score, params))
  //  }
  //}

  /**
   * Optionally damps a user-pair score by both users' activity:
   * score / (ln(1 + |I_x|) * ln(1 + |I_y|)) when userClickLengthFlag is set.
   * This rewards pairs of low-activity users who still clicked the same items.
   *
   * @param coClickedItem currently unused; kept for interface compatibility.
   */
  def computePairScore(coClickedItem: Array[String],
                       xUserClickedSize: Int,
                       yUserClickedSize: Int,
                       score: Double,
                       params: Broadcast[SwingParams]): Double = {
    // ln(1 + 0) == 0: skip damping for empty histories instead of producing Infinity/NaN.
    if (params.value.userClickLengthFlag && xUserClickedSize > 0 && yUserClickedSize > 0) {
      score / (math.log(1 + xUserClickedSize) * math.log(1 + yUserClickedSize))
    } else {
      score
    }
  }

  /**
   * Inverts user -> items into item pair -> users.
   *
   * Each user's distinct item ids are sorted, so every unordered pair is emitted once as
   * (smaller, larger) and never as a (x, x) self-pair. This holds even when the user clicked
   * the same item at several timestamps. Pairs whose distinct user count lies in
   * (minPairClickNum, maxPairClickNum] are kept.
   */
  def genItemPairUserSetRdd(userItemSet: RDD[(String, Array[Item])], params: Broadcast[SwingParams]):
  RDD[((String, String), Array[String])] = {
    userItemSet
      .flatMap { case (userId, items) =>
        unorderedPairs(items.map(_.itemId).distinct.sorted).map(itemPair => (itemPair, userId))
      }
      .groupByKey()
      .mapValues(_.toArray.distinct)
      .filter { case (_, userSet) =>
        userSet.length > params.value.minPairClickNum && userSet.length <= params.value.maxPairClickNum
      }
  }

  /**
   * Builds, for every pair of distinct users, the items they both clicked:
   * (coClickedItems, userA, userB), with userA.userId < userB.userId.
   * Items are sorted by timestamp when timeSeqFlag is set.
   * NOTE(review): not called by fitOnline.
   */
  def constructUserPair(filteredRowData: RDD[(String, (Item, User))], params: Broadcast[SwingParams]):
  RDD[(Array[Item], User, User)] = {
    filteredRowData
      .map { case (_, (item, user)) => (item.itemId, (item, user)) }
      .groupByKey()
      .flatMap { case (_, itemUsers) =>
        unorderedPairs(itemUsers.toArray.distinct).collect {
          case ((item, xUser), (_, yUser)) if xUser.userId != yUser.userId =>
            // Orient by user id so (a, b) and (b, a) land in the same group below.
            if (xUser.userId < yUser.userId) ((xUser.userId, yUser.userId), (item, xUser, yUser))
            else ((yUser.userId, xUser.userId), (item, yUser, xUser))
        }
      }
      .groupByKey()
      .map { case (_, itemUserTriples) =>
        val (_, userA, userB) = itemUserTriples.head
        val coClickedItems = itemUserTriples.map(_._1).toArray.distinct
        val items =
          if (params.value.timeSeqFlag) coClickedItems.sortBy(_.timeStamp) else coClickedItems
        (items, userA, userB)
      }
  }

  /**
   * Sums user-pair scores onto the item pairs those users co-clicked.
   * The result contains both (x, y) and (y, x) with the same score, so that downstream
   * per-item top-N lists are symmetric.
   */
  def constructItemPair(userPairItemPair: RDD[((String, String), (String, String))],
                        userPairScoreRdd: RDD[((String, String), Double)],
                        params: Broadcast[SwingParams]):
  RDD[((String, String), Double)] = {
    userPairItemPair
      .join(userPairScoreRdd)
      .map { case (_, (itemPair, score)) => (itemPair, score) }
      .reduceByKey(_ + _)
      .flatMap { case ((xItem, yItem), score) =>
        Iterator(((xItem, yItem), score), ((yItem, xItem), score))
      }
  }

  /**
   * Keeps the itemTopN highest-scoring candidates per item, in descending score order.
   * Item ids are parsed with `toLong`, so non-numeric ids throw NumberFormatException.
   */
  def selectItemTopN(pairScore: RDD[(String, (String, Double))], params: Broadcast[SwingParams]):
  RDD[(Long, Seq[Long])] = {
    pairScore.groupByKey().map { case (itemId, candidates) =>
      val topCandidates = candidates.toArray
        .sortBy { case (_, score) => -score }
        .take(params.value.itemTopN)
        .map { case (candidateId, _) => candidateId.toLong }
        .toSeq
      (itemId.toLong, topCandidates)
    }
  }

  /**
   * Loads training data for `dates` (e.g. List("20200718", "20200720")) from the Swing HDFS
   * path, broadcasts `params` and runs [[fitOnline]].
   */
  def fit(spark: SparkSession, dates: List[String], params: SwingParams):
  RDD[(Long, Seq[Long])] = {
    val modelParams = spark.sparkContext.broadcast(params)
    // Location of the training data on HDFS.
    val dataPath = PathManager.getSwingDataPath()
    println(s"train data path: ${dataPath}")
    val rowDataEntityRdd = loadData(spark, dates, dataPath)
    logSample("rowDataEntityRdd", rowDataEntityRdd)
    fitOnline(spark, rowDataEntityRdd, modelParams)
  }

  /**
   * Runs the Swing pipeline on already-loaded click data:
   * item filter -> user filter -> item pairs -> user pairs -> user-pair scores
   * -> item-pair scores -> normalised similarities -> per-item top N.
   */
  def fitOnline(spark: SparkSession, rowDataEntityRdd: RDD[PvEntity], params: Broadcast[SwingParams]):
  RDD[(Long, Seq[Long])] = {
    val filteredItemRdd = filterItemClickCnt(rowDataEntityRdd, params)
    logSample("filteredItemRdd", filteredItemRdd)
    val itemClickNumRdd = getItemClickCnt(filteredItemRdd).distinct()
    logSample("itemClickNumRdd", itemClickNumRdd)
    val userItemSetRdd = genUserItemSet(filteredItemRdd, params)
    logSample("userItemSetRdd", userItemSetRdd)
    val itemPairUserSet = genItemPairUserSetRdd(userItemSetRdd, params)
    logSample("itemPairUserSet", itemPairUserSet)
    val userPairItemPair = genUserPairItemPair(itemPairUserSet, params)
    logSample("userPairItemPair", userPairItemPair)
    val userPairScoreRdd = genUserPairScore(userPairItemPair, params)
    val itemPairs = constructItemPair(userPairItemPair, userPairScoreRdd, params)
    logSample("itemPairs", itemPairs)
    val itemSimilarities = computeSimilarities(itemClickNumRdd, itemPairs)
    logSample("itemSimilarities", itemSimilarities)
    val itemsTopN = selectItemTopN(itemSimilarities, params)
    logSample("itemsTopN", itemsTopN)
    itemsTopN
  }

  /** All unordered pairs (xs(i), xs(j)) with i < j, generated lazily. */
  private def unorderedPairs[A](xs: Array[A]): Iterator[(A, A)] =
    for {
      i <- xs.indices.iterator
      j <- (i + 1 until xs.length).iterator
    } yield (xs(i), xs(j))

  /** Debug logging: prints the element count and first 10 elements (runs two Spark jobs). */
  private def logSample[T](name: String, rdd: RDD[T]): Unit = {
    println(s"$name length and data: ${rdd.count()}")
    rdd.take(10).foreach(println)
  }
}
/**
 * Normalises accumulated item-pair scores into similarities:
 * sim(x, y) = pairScore(x, y) / sqrt(clicks(x) * clicks(y)).
 *
 * Pairs whose items are missing from `validItemRdd` are dropped by the inner joins.
 *
 * @param validItemRdd (itemId, clickCount); expected to contain one row per item.
 * @param itemPairs    ((xItemId, yItemId), accumulated Swing score).
 * @return (xItemId, (yItemId, similarity)).
 */
def computeSimilarities(validItemRdd: RDD[(String, Int)],
                        itemPairs: RDD[((String, String), Double)]):
RDD[(String, (String, Double))] = {
  itemPairs
    .map { case ((xItemId, yItemId), pairScore) => (xItemId, (yItemId, pairScore)) }
    .join(validItemRdd)
    .map { case (xItemId, ((yItemId, pairScore), xItemClickCnt)) =>
      (yItemId, (pairScore, xItemId, xItemClickCnt))
    }
    .join(validItemRdd)
    .map { case (yItemId, ((pairScore, xItemId, xItemClickCnt), yItemClickCnt)) =>
      // Multiply in Double: the Int product overflows once both counts exceed ~46341,
      // which turns the sqrt argument negative and the similarity into NaN.
      val cosine = pairScore / math.sqrt(xItemClickCnt.toDouble * yItemClickCnt)
      (xItemId, (yItemId, cosine))
    }
}
參考文章:
https://mp.weixin.qq.com/s?__biz=MjM5MzY4NzE3MA==&mid=2247485008&idx=1&sn=ca0549a346bc9879c48fc99628410621&chksm=a69275bd91e5fcab7a779eccbaee6d1715eb9611c7f9e4c32e1c5c814f5f9e1d49000602476e&mpshare=1&scene=1&srcid=&sharer_sharetime=1592800738855&sharer_shareid=d1a917c43153309de51a76d5d54e85ef#rd
https://zhuanlan.zhihu.com/p/67126386?from_voters_page=true