This chapter covers algorithms for classification and regression. It also includes sections discussing specific classes of algorithms, such as linear methods, trees, and ensembles.
Below is the content outline of the full API documentation. Not every item is covered in detail here; only the parts actually used are introduced, and more will be added later as they come up. (The links below point to the official documentation rather than to the corresponding locations in these notes, and some items do not appear in these notes at all.)
Classification
Logistic regression
Logistic regression is a popular algorithm for predicting categorical responses. It is a special case of generalized linear models that predicts the probability of the outcomes. In spark.ml, logistic regression can predict a binary outcome by using binomial logistic regression, or predict a multiclass outcome by using multinomial logistic regression. Use the family parameter to select between these two algorithms, or leave it unset and Spark will infer the correct variant.
Multinomial logistic regression can also be used for binary classification by setting the family parameter to "multinomial". It will produce two sets of coefficients and two intercepts.
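The reason two coefficient sets appear is that the multinomial model keeps one coefficient vector w_k and one intercept b_k per class, and computes class probabilities with the standard softmax form (this is the textbook formulation of multinomial logistic regression, not Spark-specific notation):

```latex
P(y = k \mid x) \;=\; \frac{e^{\,w_k \cdot x + b_k}}{\sum_{j=1}^{K} e^{\,w_j \cdot x + b_j}}, \qquad k = 1, \dots, K
```

With K = 2 classes this reduces to the binomial model, which is why a "multinomial" fit on a binary problem yields two coefficient vectors and two intercepts rather than one.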
When a LogisticRegressionModel is fit without an intercept on a dataset with a constant nonzero column, Spark MLlib outputs zero coefficients for the constant nonzero column. This behavior is the same as R glmnet but different from LIBSVM.
Binomial logistic regression
For more background and details about the implementation of binomial logistic regression, refer to the documentation of logistic regression in spark.mllib.
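As a refresher on the model itself (this is just the underlying math, not Spark's implementation), binomial logistic regression maps the linear score w·x + b through the sigmoid function to get the probability of the positive class. A minimal plain-Java sketch, with made-up coefficients standing in for what lrModel.coefficients() and lrModel.intercept() would return:

```java
public class LogisticSketch {
    // Sigmoid: maps the linear score w·x + b to a probability in (0, 1)
    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    // Probability of the positive class under a binomial logistic model
    static double predictProbability(double[] w, double b, double[] x) {
        double z = b;
        for (int i = 0; i < w.length; i++) {
            z += w[i] * x[i];
        }
        return sigmoid(z);
    }

    public static void main(String[] args) {
        // Hypothetical coefficients and intercept for a 2-feature model
        double[] w = {0.5, -1.2};
        double b = 0.1;
        double p = predictProbability(w, b, new double[]{2.0, 1.0});
        System.out.println(p);  // probability of the positive class
    }
}
```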
Code example:
The following example shows how to train binomial and multinomial logistic regression models for binary classification with elastic net regularization. elasticNetParam corresponds to α and regParam corresponds to λ. (See Linear methods for the definitions of these two parameters.)
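For reference, the elastic net regularization term that these two parameters control can be written as follows (this matches the form given in Spark's Linear methods documentation, where w denotes the coefficient vector):

```latex
\lambda \left( \alpha \lVert w \rVert_1 + \frac{1 - \alpha}{2} \lVert w \rVert_2^2 \right),
\qquad \lambda \ge 0,\ \alpha \in [0, 1]
```

Setting α = 0 gives pure L2 (ridge) regularization, α = 1 gives pure L1 (lasso), and values in between blend the two; λ scales the overall strength of the penalty.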
Java example:
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class JavaLogisticRegressionWithElasticNetExample {
  public static void main(String[] args) {
    SparkSession spark = SparkSession
      .builder()
      .appName("JavaLogisticRegressionWithElasticNetExample")
      .getOrCreate();

    // $example on$
    // Load training data
    Dataset<Row> training = spark.read().format("libsvm")
      .load("/home/paul/spark/spark-2.1.0-bin-hadoop2.7/data/mllib/sample_libsvm_data.txt");

    LogisticRegression lr = new LogisticRegression()
      .setMaxIter(10)
      .setRegParam(0.3)
      .setElasticNetParam(0.8);

    // Fit the model
    LogisticRegressionModel lrModel = lr.fit(training);

    // Print the coefficients and intercept for logistic regression
    System.out.println("\n---------- Binomial logistic regression's Coefficients: "
        + lrModel.coefficients() + "\nBinomial Intercept: " + lrModel.intercept());

    // We can also use the multinomial family for binary classification
    LogisticRegression mlr = new LogisticRegression()
      .setMaxIter(10)
      .setRegParam(0.3)
      .setElasticNetParam(0.8)
      .setFamily("multinomial");

    // Fit the model
    LogisticRegressionModel mlrModel = mlr.fit(training);

    // Print the coefficients and intercepts for logistic regression with multinomial family
    System.out.println("\n+++++++++ Multinomial coefficients: " + mlrModel.coefficientMatrix()
        + "\nMultinomial intercepts: " + mlrModel.interceptVector());
    // $example off$

    spark.stop();
  }
}
The logistic regression implementation in spark.ml also supports extracting a summary of the model over the training set (which is helpful for analyzing the model's performance on the training data). Note that the predictions and metrics stored in BinaryLogisticRegressionSummary are DataFrames annotated @transient, so they are only available on the driver. LogisticRegressionTrainingSummary provides the summary for a LogisticRegressionModel. Currently, only binary classification is supported, and the summary must be explicitly cast to BinaryLogisticRegressionTrainingSummary. Support for multiclass model summaries will be added in a future release.
Java example:
import org.apache.spark.ml.classification.BinaryLogisticRegressionSummary;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.ml.classification.LogisticRegressionTrainingSummary;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.functions;

public class JavaLogisticRegressionSummaryExample {
  public static void main(String[] args) {
    SparkSession spark = SparkSession
      .builder()
      .appName("JavaLogisticRegressionSummaryExample")
      .getOrCreate();

    // Load training data
    Dataset<Row> training = spark.read().format("libsvm")
      .load("/home/paul/spark/spark-2.1.0-bin-hadoop2.7/data/mllib/sample_libsvm_data.txt");

    LogisticRegression lr = new LogisticRegression()
      .setMaxIter(10)
      .setRegParam(0.3)
      .setElasticNetParam(0.8);

    // Fit the model
    LogisticRegressionModel lrModel = lr.fit(training);

    // $example on$
    // Extract the summary from the returned LogisticRegressionModel instance trained in the
    // earlier example
    LogisticRegressionTrainingSummary trainingSummary = lrModel.summary();

    // Obtain the loss per iteration.
    double[] objectiveHistory = trainingSummary.objectiveHistory();
    for (double lossPerIteration : objectiveHistory) {
      System.out.println(lossPerIteration);
    }

    // Obtain the metrics useful to judge performance on test data.
    // We cast the summary to a BinaryLogisticRegressionSummary since the problem is a binary
    // classification problem.
    BinaryLogisticRegressionSummary binarySummary =
      (BinaryLogisticRegressionSummary) trainingSummary;

    // Obtain the receiver-operating characteristic as a dataframe and areaUnderROC.
    Dataset<Row> roc = binarySummary.roc();
    roc.show();
    roc.select("FPR").show();
    System.out.println(binarySummary.areaUnderROC());

    // Get the threshold corresponding to the maximum F-Measure and rerun LogisticRegression with
    // this selected threshold.
    Dataset<Row> fMeasure = binarySummary.fMeasureByThreshold();
    double maxFMeasure = fMeasure.select(functions.max("F-Measure")).head().getDouble(0);
    double bestThreshold = fMeasure.where(fMeasure.col("F-Measure").equalTo(maxFMeasure))
      .select("threshold").head().getDouble(0);
    lrModel.setThreshold(bestThreshold);
    // $example off$

    spark.stop();
  }
}
The output is:
0.6833149135741672
0.6662875751473734
0.6217068546034618
0.6127265245887887
0.6060347986802873
0.6031750687571562
0.5969621534836274
0.5940743031983118
0.5906089243339022
0.5894724576491042
0.5882187775729587
+---+--------------------+
|FPR| TPR|
+---+--------------------+
|0.0| 0.0|
|0.0|0.017543859649122806|
|0.0| 0.03508771929824561|
|0.0| 0.05263157894736842|
|0.0| 0.07017543859649122|
|0.0| 0.08771929824561403|
|0.0| 0.10526315789473684|
|0.0| 0.12280701754385964|
|0.0| 0.14035087719298245|
|0.0| 0.15789473684210525|
|0.0| 0.17543859649122806|
|0.0| 0.19298245614035087|
|0.0| 0.21052631578947367|
|0.0| 0.22807017543859648|
|0.0| 0.24561403508771928|
|0.0| 0.2631578947368421|
|0.0| 0.2807017543859649|
|0.0| 0.2982456140350877|
|0.0| 0.3157894736842105|
|0.0| 0.3333333333333333|
+---+--------------------+
only showing top 20 rows
+---+
|FPR|
+---+
|0.0|
|0.0|
|0.0|
|0.0|
|0.0|
|0.0|
|0.0|
|0.0|
|0.0|
|0.0|
|0.0|
|0.0|
|0.0|
|0.0|
|0.0|
|0.0|
|0.0|
|0.0|
|0.0|
|0.0|
+---+
only showing top 20 rows
1.0
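The final 1.0 is the areaUnderROC value printed by the example. The threshold-selection step at the end of that example (finding the threshold that maximizes F-measure) can also be mirrored without Spark; here is a sketch over plain arrays, where the hypothetical thresholds/fMeasures pairs stand in for the two columns of binarySummary.fMeasureByThreshold():

```java
public class BestThresholdSketch {
    // Returns the threshold whose F-measure is largest, mimicking the
    // DataFrame max/filter logic (functions.max + where/equalTo) used
    // in the Spark example above.
    static double bestThreshold(double[] thresholds, double[] fMeasures) {
        int best = 0;
        for (int i = 1; i < fMeasures.length; i++) {
            if (fMeasures[i] > fMeasures[best]) {
                best = i;
            }
        }
        return thresholds[best];
    }

    public static void main(String[] args) {
        // Hypothetical (threshold, F-measure) pairs
        double[] thresholds = {0.2, 0.4, 0.6, 0.8};
        double[] fMeasures  = {0.70, 0.85, 0.90, 0.75};
        System.out.println(bestThreshold(thresholds, fMeasures));  // 0.6
    }
}
```

Passing the winning threshold to lrModel.setThreshold(...) then shifts the model's decision boundary accordingly, exactly as the Spark example does.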