Background
Back in July, my advisor assigned us some homework: he provided a curve generated by a program and asked us to implement a gradient descent algorithm in code to fit it. The specific requirements:
The data.csv file contains two columns of comma-separated data. The first column is x, the second is y. Complete the following:
(1) Randomly select 80% of the data in data.csv as the training set; the remaining 20% is the test set.
(2) Construct a model and train it with a gradient descent algorithm.
(3) Evaluate the trained model on the test set: take each x in the test set as input, compute y with the model, and report the RMSE between the predicted and actual values.
(4) Plot the points in data.csv, and plot the model's curve for x ∈ [0, 1].
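The RMSE in step (3) is just the square root of the mean squared prediction error over the test set. A minimal sketch (the class and method names here are my own, for illustration only):

```java
public class Rmse {
    // Root-mean-square error between predictions and ground truth.
    static double rmse(double[] predicted, double[] actual) {
        double sum = 0.0;
        for (int i = 0; i < predicted.length; i++) {
            double diff = predicted[i] - actual[i];
            sum += diff * diff;
        }
        return Math.sqrt(sum / predicted.length);
    }

    public static void main(String[] args) {
        double[] pred = {1.0, 2.0, 3.0};
        double[] truth = {1.0, 2.0, 5.0};
        // Differences are 0, 0, 2, so RMSE = sqrt(4/3) ≈ 1.1547.
        System.out.println(rmse(pred, truth));
    }
}
```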
The data format is as follows:
0.000000000000000000,0.000045401991009684
0.010010010010010010,0.000067487908347918
0.020020020020020020,0.000099516665248245
0.030030030030030030,0.000145574221405758
0.040040040040040040,0.000211247752152538
0.050050050050050046,0.000304101936049645
0.060060060060060060,0.000434277611628926
0.070070070070070073,0.000615236631426893
0.080080080080080079,0.000864687227990188
0.090090090090090086,0.001205760122738213
0.100100100100100092,0.001668621265042236
The CSV file has 1000 rows in total; plotted in the xy-plane, the curve looks like this:
Approach
The teacher's intent is to first guess which kind of function the curve comes from (i.e., fix the basic form of the function). The concrete parameters are unknown at the start, so a few initial values have to be guessed; the guessed curve will certainly differ considerably from the actual one. An optimization method is then used to find the function parameters that minimize that difference, which fits the curve. The assignment requires implementing gradient descent to find the minimum.
Judging from the plot, the original data look like a superposition of several Gaussian functions with different means and variances. There are 4 peaks in the figure, so the model can be assumed to be:

f(x) = Σ_{i=1..4} a_i · exp(−(x − μ_i)² / (2σ_i²))
Let the error function over the data points (x_j, y_j) be E = Σ_j (f(x_j) − y_j)². The ideal model parameters are then the ones that minimize it: θ* = argmin_θ E(θ).
At each step, gradient descent computes the gradient of the error function E at the current point (the current parameters). The gradient is the direction in which the function value grows fastest, so moving the parameters along the negative gradient, scaled by a step size, eventually reaches a local minimum (provided the step size is small enough). Once the error function is fixed, gradient descent can therefore find the 12 parameters that make the error locally minimal.
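The descent rule can be sketched on a toy one-dimensional function before applying it to the 12-parameter model. A minimal sketch (the function f(x) = (x − 3)² and all names are my own, not part of the assignment):

```java
public class DescentDemo {
    // Gradient descent on f(x) = (x - 3)^2, whose gradient is f'(x) = 2 (x - 3).
    static double minimize(double x0, double eta, int iters) {
        double x = x0;
        for (int i = 0; i < iters; i++) {
            double grad = 2 * (x - 3);
            x -= eta * grad; // step along the negative gradient
        }
        return x;
    }

    public static void main(String[] args) {
        // Starting from 0 with step size 0.1, the iterate approaches the minimum at x = 3.
        System.out.println(minimize(0.0, 0.1, 100)); // ≈ 3.0
    }
}
```

Each update contracts the distance to the minimum by a factor of (1 − 2η), which is why a step size that is too large makes the iteration diverge instead.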
To simplify the calculation, σ² can be treated as a single variable. We then need the general form of the gradient of E at a point:

∇E = (∂E/∂a_1, ∂E/∂μ_1, ∂E/∂σ_1², …, ∂E/∂a_4, ∂E/∂μ_4, ∂E/∂σ_4²)

where, for example,

∂E/∂a_i = Σ_j 2 (f(x_j) − y_j) · exp(−(x_j − μ_i)² / (2σ_i²))

and the partial derivatives with respect to the remaining parameters follow analogously.
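An analytic partial derivative like this is easy to get wrong, so it is worth sanity-checking against a central finite difference of E. A sketch for the ∂E/∂a term of a single Gaussian component (the class name, data, and parameter values are made up for illustration):

```java
public class GradCheck {
    // Single Gaussian f(x) = a * exp(-(x - mu)^2 / (2 * sigma2)),
    // with sigma2 = σ² treated as one parameter, as in the derivation above.
    static double f(double a, double mu, double sigma2, double x) {
        return a * Math.exp(-Math.pow(x - mu, 2) / (2 * sigma2));
    }

    // Analytic dE/da for E = Σ (f(x_j) - y_j)^2 over (x, y) pairs.
    static double analyticDa(double a, double mu, double sigma2, double[][] data) {
        double g = 0.0;
        for (double[] p : data) {
            g += 2 * (f(a, mu, sigma2, p[0]) - p[1])
                   * Math.exp(-Math.pow(p[0] - mu, 2) / (2 * sigma2));
        }
        return g;
    }

    // Central finite-difference approximation of the same derivative.
    static double numericDa(double a, double mu, double sigma2, double[][] data) {
        double h = 1e-6, ePlus = 0.0, eMinus = 0.0;
        for (double[] p : data) {
            ePlus  += Math.pow(f(a + h, mu, sigma2, p[0]) - p[1], 2);
            eMinus += Math.pow(f(a - h, mu, sigma2, p[0]) - p[1], 2);
        }
        return (ePlus - eMinus) / (2 * h);
    }

    public static void main(String[] args) {
        double[][] data = {{0.1, 0.2}, {0.3, 0.6}, {0.5, 0.5}};
        double analytic = analyticDa(0.7, 0.4, 0.05, data);
        double numeric = numericDa(0.7, 0.4, 0.05, data);
        System.out.println(analytic + " vs " + numeric); // the two should agree closely
    }
}
```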
Fix a number of iterations. After each computation of the gradient of the error function, choose a step size η and move the parameters along the negative gradient, e.g. a_i ← a_i − η · ∂E/∂a_i, μ_i ← μ_i − η · ∂E/∂μ_i. Repeat until the iteration limit is reached or the total error drops below a threshold.
Program
The program is implemented in Java. (C++ is tedious to write and lacks a convenient charting library, Python is too slow, and Java is what I write most fluently.)
The first problem I faced was choosing a charting library. After a quick survey I picked XChart, only to find that its GitHub page offered no prebuilt jar, so I had to clone the project and build the jar myself with mvn package.
I then defined a model class Model whose member is a double array holding the parameters to be tuned; for the model above, the parameter array has length 12. Model declares methods to be implemented, such as evaluating the function (val) and evaluating the gradient (grad); its subclass GaussianModel is the model described above. Also, since gradient descent reaches the nearest local minimum rather than the global one, the final convergence point depends heavily on the initial parameter values, so at each step I compute the gradient on a randomly chosen subset of the data points to help jump out of local minima.
The Java code is as follows:
package com.company;
import org.knowm.xchart.QuickChart;
import org.knowm.xchart.SwingWrapper;
import org.knowm.xchart.XYChart;
import java.io.File;
import java.io.FileNotFoundException;
import java.util.*;
import java.util.function.Function;
import java.util.stream.Collectors;
import static java.lang.Math.E;
import static java.lang.Math.pow;
import static java.lang.Math.sqrt;
import static java.lang.System.exit;
public class Solver {

    private List<Point> rawData = new ArrayList<>();
    private List<Point> trainData = new ArrayList<>();
    private List<Point> testData = new ArrayList<>();
    private Model model = null;
    private Function<Model, Double> loss = null;

    public Solver(String csvPath) throws FileNotFoundException {
        Scanner scanner = new Scanner(new File(csvPath));
        while (scanner.hasNextLine()) {
            String[] xy = scanner.nextLine().split(",");
            rawData.add(new Point(Double.valueOf(xy[0]), Double.valueOf(xy[1])));
        }
        scanner.close(); // release the file handle
    }
    // Half the sum of squared errors over the training set, used as the loss.
    private Function<Model, Double> mse = (m) -> {
        double lossSum = 0.0;
        for (Point p : trainData) {
            double diff = m.val(p.x) - p.y;
            lossSum += (diff * diff);
        }
        return lossSum / 2.0;
    };
    private void divide(float ratio4Train) {
        trainData.clear();
        testData.clear();
        if (ratio4Train <= 0) throw new IllegalArgumentException("Ratio <= 0");
        int testCount = (int) (rawData.size() * (1 - ratio4Train));
        Random rand = new Random(System.currentTimeMillis());
        Set<Integer> exclusiveIndices4Test = new HashSet<>();
        while (exclusiveIndices4Test.size() < testCount) {
            int index = rand.nextInt(rawData.size());
            if (!exclusiveIndices4Test.contains(index)) {
                testData.add(rawData.get(index));
                exclusiveIndices4Test.add(index);
            }
        }
        for (int i = 0; i < rawData.size(); i++) {
            if (!exclusiveIndices4Test.contains(i)) {
                trainData.add(rawData.get(i));
            }
        }
    }
    private void train() {
        System.out.println("Train data size: " + trainData.size());
        System.out.println("Test data size: " + testData.size());
        // model = new PolyModel(4);
        model = new GaussianModel(5);
        loss = mse;
        // ==========================================================
        for (int i = 0; i < 10000; i++) {
            double lossVal = loss.apply(model);
            double[] gradVal = model.grad(trainData);
            System.out.println(String.format("Iter: %d, loss: %f ", i, lossVal));
            System.out.println(String.format("Theta: %f, %f, %f", model.theta[0], model.theta[1], model.theta[2]));
            System.out.println(String.format("Grad: %f, %f, %f\n", gradVal[0], gradVal[1], gradVal[2]));
            if (Double.isNaN(lossVal)) {
                // Diverged: re-randomize the parameters and restart.
                model.randomize();
                i = 0;
                continue;
            }
            for (int j = 0; j < gradVal.length; j++) {
                double delta = model.rate(j) * gradVal[j];
                model.theta[j] -= delta;
            }
            // if (lossVal < 1.06) break;
        }
        System.out.println(String.format("Theta: %f, %f, %f", model.theta[0], model.theta[1], model.theta[2]));
    }
    private void validate() {
        double RMSE = 0.0;
        for (Point p : testData) {
            double diff = model.val(p.x) - p.y;
            RMSE += (diff * diff);
        }
        RMSE /= testData.size();
        RMSE = sqrt(RMSE);
        System.out.println("RMSE: " + RMSE);
    }
    private void plot() {
        XYChart chart = QuickChart.getChart(
                "Result", "X", "Y", "y(x)",
                trainData.stream().map(point -> point.x).collect(Collectors.toList()),
                trainData.stream().map(point -> point.y).collect(Collectors.toList()));
        double[] xPoints = new double[150];
        double[] yPoints = new double[150];
        for (int i = 0; i < 150; i++) {
            xPoints[i] = i * 10.0 / 150;
            yPoints[i] = model.val(xPoints[i]);
        }
        chart.addSeries("model", xPoints, yPoints);
        new SwingWrapper<XYChart>(chart).displayChart();
    }
    public void solve() {
        divide(0.8f);
        train();
        validate();
        plot();
    }

    public static void main(String[] args) throws FileNotFoundException {
        if (args.length < 1) {
            System.out.println("Usage: java -jar GradientDesent.jar data.csv");
            exit(0);
        }
        new Solver(args[0]).solve();
    }
    private static class Point {
        double x;
        double y;

        public Point(double x, double y) { this.x = x; this.y = y; }
    }

    private static abstract class Model {
        double[] theta = null;

        abstract double val(double x);
        abstract double[] grad(List<Point> trainData);
        abstract void randomize();
        abstract double rate(int i);
    }
    private static class PolyModel extends Model {

        public PolyModel(int n) {
            if (n < 2) throw new IllegalArgumentException("n must be at least 2.");
            theta = new double[n];
            randomize();
        }

        @Override
        double val(double x) {
            double result = 0.0;
            for (int i = 0; i < theta.length; i++) {
                result += theta[i] * pow(x, i);
            }
            return result;
        }

        @Override
        double[] grad(List<Point> trainData) {
            double[] gradVec = new double[theta.length];
            for (int i = 0; i < gradVec.length; i++) {
                gradVec[i] = 0.0;
                // Estimate each component on a random mini-batch of 50 points.
                Random r = new Random();
                List<Point> data = new ArrayList<>();
                for (int k = 0; k < 50; k++)
                    data.add(trainData.get(r.nextInt(trainData.size())));
                for (Point p : data) {
                    double diff = val(p.x) - p.y;
                    gradVec[i] += (diff * pow(p.x, i));
                }
            }
            return gradVec;
        }

        @Override
        void randomize() {
            Random rand = new Random(System.currentTimeMillis());
            for (int i = 0; i < theta.length; i++) {
                theta[i] = rand.nextDouble();
            }
        }

        @Override
        double rate(int i) {
            return 0.00000002;
        }
    }
    private static class GaussianModel extends Model {

        /**
         * f(x) = a * e ^ (- (x - μ)^2 / (2 * σ^2))
         * Parameters per component: (a, μ, σ²)
         * @param n number of Gaussian components
         */
        public GaussianModel(int n) {
            if (n < 1) throw new IllegalArgumentException("n must be at least 1.");
            theta = new double[n * 3];
            randomize();
        }

        @Override
        double val(double x) {
            double result = 0.0;
            for (int i = 0; i < theta.length / 3; i++) {
                double alpha = theta[i * 3];
                double miu = theta[i * 3 + 1];
                double sigma2 = theta[i * 3 + 2];
                result += (alpha * pow(E, -pow(x - miu, 2) / sigma2 / 2));
            }
            return result;
        }

        @Override
        double[] grad(List<Point> trainData) {
            double[] gradVec = new double[theta.length];
            for (int i = 0; i < theta.length / 3; i++) {
                gradVec[i * 3] = 0;
                gradVec[i * 3 + 1] = 0;
                gradVec[i * 3 + 2] = 0;
                double alpha = theta[i * 3];
                double miu = theta[i * 3 + 1];
                double sigma2 = theta[i * 3 + 2];
                // Estimate the gradient on a random mini-batch of 30 points.
                Random r = new Random();
                List<Point> stochasticData = new ArrayList<>();
                for (int k = 0; k < 30; k++)
                    stochasticData.add(trainData.get(r.nextInt(trainData.size())));
                for (Point p : stochasticData) {
                    double val = val(p.x);
                    // ∂E/∂a
                    gradVec[i * 3] += 2
                            * (val - p.y)
                            * (pow(E, -pow(p.x - miu, 2) / sigma2 / 2));
                    // ∂E/∂μ
                    gradVec[i * 3 + 1] += (2
                            * alpha
                            * (val - p.y)
                            * pow(E, -pow(p.x - miu, 2) / sigma2 / 2)
                            * ((p.x - miu) / sigma2));
                    // ∂E/∂σ² (sigma squared is treated as a single parameter)
                    gradVec[i * 3 + 2] += (2
                            * alpha
                            * (val - p.y)
                            * pow(E, -pow(p.x - miu, 2) / sigma2 / 2)
                            * (pow(p.x - miu, 2) / pow(sigma2, 2) / 2));
                }
            }
            return gradVec;
        }

        @Override
        void randomize() {
            Random rand = new Random(System.currentTimeMillis());
            for (int i = 0; i < theta.length / 3; i++) {
                theta[i * 3] = rand.nextDouble();
                theta[i * 3 + 1] = rand.nextDouble() * 5;
                theta[i * 3 + 2] = rand.nextDouble();
            }
        }

        @Override
        double rate(int i) {
            if (i % 3 == 0) {        // a
                return 0.0005;
            } else if (i % 3 == 1) { // μ
                return 0.0005;
            } else {                 // σ²
                return 0.00005;
            }
        }

        @Override
        public String toString() {
            StringBuilder builder = new StringBuilder("Theta: ");
            for (double t : theta) {
                builder.append(t);
                builder.append(", ");
            }
            return builder.toString();
        }
    }
}
Whether the final result is good still depends a bit on luck; the fit is not equally good on every run. Here is a plot of one result:
I have put the data and code on my GitHub: https://github.com/Jimmie00x0000/gradient_desent_demo.