引入依賴
<!-- Spark MLlib (Scala 2.12 build); provides the ALS recommender used by the training and prediction jobs. -->
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-mllib_2.12</artifactId>
<version>2.4.4</version>
<!-- Drop the Guava that spark-mllib pulls in transitively; an explicit version is declared below. -->
<exclusions>
<exclusion>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
</exclusion>
</exclusions>
</dependency>
<!-- Explicitly pinned Guava version. NOTE(review): presumably chosen to avoid a Guava version conflict with other project dependencies; confirm. -->
<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
<version>14.0.1</version>
</dependency>
數據準備
門店數據
- 通過 dml.sql 導入了 400 條數據;
行為數據
- 保存在文件 behavior.csv 中,總共 3 列,第一列 userId,第二列 shopId,第三列用戶對這個門店的鍾愛度打分;
- behavior.csv 中大概有 2 萬多條數據;
離線 ALS 召回模型的訓練
離線 ALS 召回模型的訓練 | 過程
- 讀行為數據 behavior.csv 到內存中;
- 轉換數據結構:JavaRDD<String> -> JavaRDD<Rating> -> Dataset<Row>;
- 按 8-2 分,將行為數據集分成 2 份,一份訓練用,一份測試用;
- 設置 ALS 模型的參數:.setMaxIter(10),.setRank(5),.setRegParam(0.01);
- 生成模型;
- 生成模型測評器;
- 用測試行為數據,測試生成的模型,得到 rmse 得分;
- 生成的模型可以保存在磁盤;
模型生成的結果
- alsmodel
- itemFactors - 存儲門店訓練出來的特征值;
- metadata
- userFactors - 存儲用戶訓練出來的特征值,二進制的;
離線 ALS 召回模型的訓練 | 代碼
package tech.lixinlei.dianping.recommand;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.recommendation.ALS;
import org.apache.spark.ml.recommendation.ALSModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import java.io.IOException;
import java.io.Serializable;
/**
 * Offline training job for the ALS (alternating least squares) recall model.
 *
 * <p>Reads the behavior data, trains an {@link ALS} model on roughly 80% of the rows, prints the
 * RMSE measured on the remaining rows and writes the fitted model to disk.
 *
 * <p>Implements {@link Serializable} because a Spark program may run on different machines.
 */
public class AlsRecallTrain implements Serializable {

    /** Behavior data: one {@code userId,shopId,rating} record per line. */
    private static final String BEHAVIOR_CSV_PATH =
            "file:///home/lixinlei/project/gitee/dianping/src/main/resources/behavior.csv";

    /** Directory the trained model is written to. */
    private static final String MODEL_PATH =
            "file:///home/lixinlei/project/gitee/dianping/src/main/resources/alsmodel";

    /**
     * Runs the training job end to end.
     *
     * @param args ignored
     * @throws IOException if the trained model cannot be written
     */
    public static void main(String[] args) throws IOException {
        // Initialise the Spark runtime.
        SparkSession spark = SparkSession.builder().master("local").appName("DianpingApp").getOrCreate();
        try {
            // Turn every line of behavior.csv from a String into a Rating.
            JavaRDD<String> csvFile = spark.read().textFile(BEHAVIOR_CSV_PATH).toJavaRDD();
            JavaRDD<Rating> ratingJavaRDD = csvFile.map(Rating::parseRating);

            // A Dataset can be thought of as a MySQL table; its columns follow the Rating bean.
            Dataset<Row> rating = spark.createDataFrame(ratingJavaRDD, Rating.class);

            // 80/20 split: 80% of the ratings for training, 20% for testing.
            Dataset<Row>[] splits = rating.randomSplit(new double[]{0.8, 0.2});
            Dataset<Row> trainingData = splits[0];
            Dataset<Row> testingData = splits[1];

            // setMaxIter(10)    maximum number of fitting iterations
            // setRank(5)        number of features after the matrix factorisation
            // setRegParam(0.01) regularisation coefficient; raising it helps against overfitting
            //
            // Overfitting: the model follows the training data too closely, so as soon as real
            //   data deviates a little the predictions get worse.
            //   Remedies: 1) more data 2) lower rank 3) larger regularisation coefficient.
            // Underfitting: the model has not converged towards the real data, so predictions
            //   are far from the real results.
            //   Remedies: 1) higher rank 2) smaller regularisation coefficient.
            //
            // coldStartStrategy "drop": a random split can put a user or shop only into the test
            // set. The model has no factors for it and predicts NaN, which would turn the whole
            // RMSE into NaN. Dropping those rows keeps the metric meaningful.
            ALS als = new ALS()
                    .setMaxIter(10)
                    .setRank(5)
                    .setRegParam(0.01)
                    .setUserCol("userId")
                    .setItemCol("shopId")
                    .setRatingCol("rating")
                    .setColdStartStrategy("drop");

            // Train the model.
            ALSModel alsModel = als.fit(trainingData);

            // Evaluate: transform() adds a "prediction" column to the test rows.
            Dataset<Row> predictions = alsModel.transform(testingData);

            // RMSE (root mean squared error) between the "rating" and "prediction" columns.
            // The smaller the value, the better the model performs on the test set.
            RegressionEvaluator evaluator = new RegressionEvaluator()
                    .setMetricName("rmse")
                    .setLabelCol("rating")
                    .setPredictionCol("prediction");
            double rmse = evaluator.evaluate(predictions);
            System.out.println("rmse = " + rmse);

            // overwrite() lets the job be re-run; a plain save() fails once the directory exists.
            alsModel.write().overwrite().save(MODEL_PATH);
        } finally {
            // Always release the Spark context, even when training fails.
            spark.stop();
        }
    }

    /**
     * Bean holding one line of behavior.csv. Spark derives the DataFrame columns
     * ({@code userId}, {@code shopId}, {@code rating}) from the getters.
     */
    public static class Rating implements Serializable {

        private final int userId;
        private final int shopId;
        private final int rating;

        /**
         * Builds a {@code Rating} from one line of behavior.csv.
         *
         * <p>Double quotes are removed, the line is split on commas and the first three fields
         * are parsed as integers. Whitespace around a field is ignored.
         *
         * @param str one line of behavior.csv, e.g. {@code "1","2","3"}
         * @return the parsed rating
         * @throws IllegalArgumentException if the line has fewer than three fields
         * @throws NumberFormatException if a field is not an integer
         */
        public static Rating parseRating(String str) {
            String[] fields = str.replace("\"", "").split(",");
            if (fields.length < 3) {
                throw new IllegalArgumentException(
                        "Expected userId,shopId,rating but got: " + str);
            }
            int userId = Integer.parseInt(fields[0].trim());
            int shopId = Integer.parseInt(fields[1].trim());
            int rating = Integer.parseInt(fields[2].trim());
            return new Rating(userId, shopId, rating);
        }

        /**
         * @param userId id of the user who gave the score
         * @param shopId id of the scored shop
         * @param rating the score
         */
        public Rating(int userId, int shopId, int rating) {
            this.userId = userId;
            this.shopId = shopId;
            this.rating = rating;
        }

        /** @return the user id */
        public int getUserId() {
            return userId;
        }

        /** @return the shop id */
        public int getShopId() {
            return shopId;
        }

        /** @return the score */
        public int getRating() {
            return rating;
        }
    }
}
使用離線 ALS 召回模型為活躍的 5 個用戶召回(粗排)門店信息
召回 | 步驟
- 先加載訓練出的離線模型 ALSModel;
- 再加載行為數據 behavior.csv;
- 再選 5 個用戶做預測;
- 解析預測結果存入數據庫;
召回 | 代碼實現
package tech.lixinlei.dianping.recommand;
import org.apache.commons.lang3.StringUtils;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.ForeachPartitionFunction;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.ml.recommendation.ALSModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema;
import java.io.Serializable;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.util.*;
/**
 * Loads the trained ALS model, produces recommendations for a handful of users and stores them
 * in the {@code recommend} table as the offline candidate set.
 */
public class AlsRecallPredict {

    /** Directory the trained model is loaded from. */
    private static final String MODEL_PATH =
            "file:///home/lixinlei/project/gitee/dianping/src/main/resources/alsmodel";

    /** Behavior data: one {@code userId,shopId,rating} record per line. */
    private static final String BEHAVIOR_CSV_PATH =
            "file:///home/lixinlei/project/gitee/dianping/src/main/resources/behavior.csv";

    // NOTE(review): database credentials are hard-coded in source; move them to external
    // configuration before this leaves a local development setup.
    private static final String JDBC_URL = "jdbc:mysql://127.0.0.1:3306/dianping?"
            + "user=root&password=Jiangdi_2018&useUnicode=true&characterEncoding=UTF-8";

    private static final String INSERT_SQL = "insert into recommend(id, recommend) values (?, ?)";

    /** How many users get recommendations. */
    private static final int USER_COUNT = 5;

    /** How many shops are recommended per user. */
    private static final int RECOMMENDATIONS_PER_USER = 20;

    /**
     * Runs the prediction job end to end.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        // Initialise the Spark runtime.
        SparkSession spark = SparkSession.builder().master("local").appName("DianpingApp").getOrCreate();
        try {
            // Load the trained model into memory.
            ALSModel alsModel = ALSModel.load(MODEL_PATH);

            JavaRDD<String> csvFile = spark.read().textFile(BEHAVIOR_CSV_PATH).toJavaRDD();
            JavaRDD<Rating> ratingJavaRDD = csvFile.map(Rating::parseRating);
            Dataset<Row> rating = spark.createDataFrame(ratingJavaRDD, Rating.class);

            // Five distinct user ids from the behavior data. No ordering is applied, so which
            // five are picked is not fixed by this code.
            Dataset<Row> users = rating.select(alsModel.getUserCol()).distinct().limit(USER_COUNT);

            // userRecs holds the predictions: one row per user.
            Dataset<Row> userRecs = alsModel.recommendForUserSubset(users, RECOMMENDATIONS_PER_USER);

            // An anonymous class (not a lambda) keeps the call unambiguous between the
            // foreachPartition overloads of Dataset.
            userRecs.foreachPartition(new ForeachPartitionFunction<Row>() {
                @Override
                public void call(Iterator<Row> t) throws Exception {
                    if (!t.hasNext()) {
                        // Nothing to write: do not open a connection for an empty partition.
                        return;
                    }
                    // try-with-resources closes statement and connection even on failure.
                    try (Connection connection = DriverManager.getConnection(JDBC_URL);
                         PreparedStatement insert = connection.prepareStatement(INSERT_SQL)) {
                        while (t.hasNext()) {
                            Row userRow = t.next();
                            // Column 0: user id. Column 1: list of recommendation rows whose
                            // first field is the shop id.
                            int userId = userRow.getInt(0);
                            List<Row> recommendations = userRow.getList(1);
                            List<Integer> shopIds = new ArrayList<Integer>(recommendations.size());
                            for (Row recommendation : recommendations) {
                                shopIds.add(recommendation.getInt(0));
                            }
                            insert.setInt(1, userId);
                            insert.setString(2, StringUtils.join(shopIds, ","));
                            // A SQLException here propagates and fails the task instead of
                            // silently dropping the row.
                            insert.addBatch();
                        }
                        insert.executeBatch();
                    }
                }
            });
        } finally {
            // Always release the Spark context, even when the job fails.
            spark.stop();
        }
    }

    /**
     * Bean holding one line of behavior.csv. Spark derives the DataFrame columns
     * ({@code userId}, {@code shopId}, {@code rating}) from the getters.
     */
    public static class Rating implements Serializable {

        private final int userId;
        private final int shopId;
        private final int rating;

        /**
         * Builds a {@code Rating} from one line of behavior.csv.
         *
         * <p>Double quotes are removed, the line is split on commas and the first three fields
         * are parsed as integers. Whitespace around a field is ignored.
         *
         * @param str one line of behavior.csv, e.g. {@code "1","2","3"}
         * @return the parsed rating
         * @throws IllegalArgumentException if the line has fewer than three fields
         * @throws NumberFormatException if a field is not an integer
         */
        public static Rating parseRating(String str) {
            String[] fields = str.replace("\"", "").split(",");
            if (fields.length < 3) {
                throw new IllegalArgumentException(
                        "Expected userId,shopId,rating but got: " + str);
            }
            int userId = Integer.parseInt(fields[0].trim());
            int shopId = Integer.parseInt(fields[1].trim());
            int rating = Integer.parseInt(fields[2].trim());
            return new Rating(userId, shopId, rating);
        }

        /**
         * @param userId id of the user who gave the score
         * @param shopId id of the scored shop
         * @param rating the score
         */
        public Rating(int userId, int shopId, int rating) {
            this.userId = userId;
            this.shopId = shopId;
            this.rating = rating;
        }

        /** @return the user id */
        public int getUserId() {
            return userId;
        }

        /** @return the shop id */
        public int getShopId() {
            return shopId;
        }

        /** @return the score */
        public int getRating() {
            return rating;
        }
    }
}
召回的結果
SELECT * FROM dianping.recommend;
# id, recommend
'148', '400,216,145,131,421,464,128,257,332,479,283,248,447,138,494,292,228,186,231,378'
'463', '202,323,479,420,255,154,484,318,405,135,206,345,324,382,262,199,123,494,201,388'
'471', '216,479,127,191,464,172,202,125,389,494,411,303,455,226,249,369,291,105,211,434'
'1088', '324,465,402,135,294,199,163,203,255,185,147,323,130,430,388,313,112,145,219,481'
'1238', '268,438,130,383,313,324,465,203,180,148,222,353,252,402,481,368,142,428,448,198'