機器學(xué)習(xí)筆記-基于梯度下降的曲線擬合

背景

7月份的時候?qū)煵贾昧藗€作業(yè),他給了一條用程序生成的曲線,然后讓我們用代碼實現(xiàn)一個梯度下降算法來擬合曲線论泛。具體要求:

data.csv文件中包含兩列用逗號分隔的數(shù)據(jù)。第一列是x蛹屿,第二列是y屁奏。完成如下工作:
(1)在data.csv中隨機選擇80%的數(shù)據(jù)作為訓(xùn)練集,剩余20%作為測試集错负。
(2)構(gòu)造模型坟瓢,采用梯度下降算法訓(xùn)練模型。
(3)用測試集對訓(xùn)練的模型進(jìn)行評估犹撒,將測試集中的x作為輸入折联,用模型計算y,計算預(yù)測值與實際值的RMSE识颊。
(4)繪制data.csv中的點诚镰,繪制x ∈ [0,1] 之間模型的對應(yīng)曲線。

數(shù)據(jù)格式如下:

0.000000000000000000,0.000045401991009684
0.010010010010010010,0.000067487908347918
0.020020020020020020,0.000099516665248245
0.030030030030030030,0.000145574221405758
0.040040040040040040,0.000211247752152538
0.050050050050050046,0.000304101936049645
0.060060060060060060,0.000434277611628926
0.070070070070070073,0.000615236631426893
0.080080080080080079,0.000864687227990188
0.090090090090090086,0.001205760122738213
0.100100100100100092,0.001668621265042236

上面的csv文件一共有1000行數(shù)據(jù)祥款,在xy平面上繪制出來的曲線如下:


思路

老師的意思是先猜這條曲線是什么函數(shù)的曲線(先確定函數(shù)的基本形式)清笨,一開始函數(shù)的具體參數(shù)是不知道的,需要猜幾個初始值镰踏,那么猜出來的曲線一定和實際曲線有較大差異函筋,再用最優(yōu)化的方法找到使差異最小化的函數(shù)參數(shù),從而實現(xiàn)曲線的擬合奠伪。這里要求實現(xiàn)梯度下降算法來求解最小值跌帐。

從曲線的圖像來看原始數(shù)據(jù)應(yīng)該是幾個均值方差不同的高斯函數(shù)疊加而成的,圖中有4個峰绊率,因此可以假設(shè)曲線的模型為:f(x)=\alpha_1e^{-\frac{(x-\mu_1)^2}{2\sigma^2_1}}+\alpha_2e^{-\frac{(x-\mu_2)^2}{2\sigma^2_2}}+\alpha_3e^{-\frac{(x-\mu_3)^2}{2\sigma^2_3}}+\alpha_4e^{-\frac{(x-\mu_4)^2}{2\sigma^2_4}}谨敛。
令誤差函數(shù)為E=\sum\limits_{i=1}^{n} (f(x_i) - y_i)^2。則理想的模型參數(shù):
(\alpha_1,\mu_1,\sigma_1,\alpha_2,...,\sigma_4)=\min\limits_{\alpha_1,...,\sigma_4}E

梯度下降算法每次求出函數(shù)(E)在某個點(當(dāng)前參數(shù))的梯度滤否,因為梯度就是函數(shù)值增長最快的那個方向脸狸,所以讓參數(shù)沿著梯度的負(fù)方向乘以一定的步長進(jìn)行更新,就一定能抵達(dá)一個局部極小點藐俺。所以只要給定了這里的誤差函數(shù)E(\alpha_1,\mu_1,\sigma_1,\alpha_2,\mu_2,\sigma_2,\alpha_3,\mu_3,\sigma_3,\alpha_4,\mu_4,\sigma_4)炊甲,就可以通過梯度下降算法來找到使誤差函數(shù)達(dá)到局部極小的12個參數(shù)。

為了便于計算欲芹,可以把\sigma^2當(dāng)成一個整體卿啡,此時需要求出E在某個點的梯度的一般表示:(\frac{\partial E}{\partial \alpha_1},\frac{\partial E}{\partial \mu_1},\frac{\partial E}{\partial \sigma_1^2},\frac{\partial E}{\partial \alpha_2},\frac{\partial E}{\partial \mu_2},\frac{\partial E}{\partial \sigma_2^2},\frac{\partial E}{\partial \alpha_3},\frac{\partial E}{\partial \mu_3},\frac{\partial E}{\partial \sigma_3^2},\frac{\partial E}{\partial \alpha_4},\frac{\partial E}{\partial \mu_4},\frac{\partial E}{\partial \sigma_4^2},)。其中\frac{\partial E}{\partial \alpha_1}=2\sum\limits_{i=1}^{n}((f(x_i)-y_i)e^{-\frac{(x_i-\mu_1)^2}{2\sigma_1^2}}) \frac{\partial E}{\partial \mu_1}=2\sum\limits_{i=1}^{n}(\frac{\alpha_1(x_i-\mu_1)}{\sigma_1^2}(f(x_i)-y_i)e^{-\frac{(x_i-\mu_1)^2}{2\sigma_1^2}}) \frac{\partial E}{\partial \sigma_1^2}=2\sum\limits_{i=1}^{n}(\frac{\alpha_1(x_i-\mu_1)^2}{2\sigma_1^4}(f(x_i)-y_i)e^{-\frac{(x_i-\mu_1)^2}{2\sigma_1^2}})菱父,其余參數(shù)的偏導(dǎo)數(shù)以此類推颈娜。

設(shè)定一個迭代次數(shù),每次求出誤差函數(shù)的梯度后浙宜,設(shè)定步長\eta官辽,讓參數(shù)沿梯度的負(fù)方向更新,如:\alpha_1=\alpha_1-\eta\frac{\partial E}{\partial \alpha_1}粟瞬,\mu_1=\mu_1-\eta\frac{\partial E}{\partial \mu_1}同仆,然后重復(fù)這個步驟,直到達(dá)到一定迭代次數(shù)或者總誤差小于一定閾值停止迭代裙品。

程序

程序使用Java實現(xiàn)乓梨。(C++寫起來麻煩而且沒有合適的圖表顯示庫,Python太慢清酥,Java寫起來最順手)

一開始我面臨的問題就是選擇一個圖表顯示庫扶镀,簡單地調(diào)研了一下選了XChart,但是去了該項目的Github主頁發(fā)現(xiàn)居然沒有打包好的 jar 包焰轻,于是需要 clone 下來然后用 mvn package 命令把 jar 包打出來臭觉。

然后我定義了一個模型類 Model,這個模型類的成員變量是 double數(shù)組辱志,用來放待調(diào)的參數(shù)蝠筑,比如上文中的f(x)對應(yīng)的參數(shù)數(shù)組長度就為12。Model類有一些待實現(xiàn)的方法如函數(shù)的求值(val)揩懒、梯度的求值(grad)等什乙,其派生類GaussianModel就是上文中的模型。另外已球,因為梯度下降會抵達(dá)最近的極小點而不是全局最小點臣镣,最終的收斂點極大依賴于參數(shù)的初始值辅愿,我每次隨機選取了一部分?jǐn)?shù)據(jù)點來求梯度以跳出局部極小。

Java代碼如下:

package com.company;

import org.knowm.xchart.QuickChart;
import org.knowm.xchart.SwingWrapper;
import org.knowm.xchart.XYChart;

import java.io.File;
import java.io.FileNotFoundException;
import java.util.*;
import java.util.function.Function;
import java.util.stream.Collectors;
import static java.lang.Math.E;
import static java.lang.Math.pow;
import static java.lang.Math.sqrt;
import static java.lang.System.exit;


public class Solver {

    private List<Point> rawData = new ArrayList<>();
    private List<Point> trainData = new ArrayList<>();
    private List<Point> testData = new ArrayList<>();
    private Model model = null;
    private Function<Model, Double> loss = null;

    public Solver(String csvPath) throws FileNotFoundException {
        Scanner scanner = new Scanner(new File(csvPath));
        while (scanner.hasNextLine()) {
            String[] xy = scanner.nextLine().split(",");
            rawData.add(new Point(Double.valueOf(xy[0]), Double.valueOf(xy[1])));
        }
    }

    private Function<Model, Double> mse = (m) -> {
        double lossSum = 0.0;
        for (Point p : trainData) {
            double diff = m.val(p.x) - p.y;
            lossSum += (diff * diff);
        }
        return lossSum / 2.0;
    };

    private void divide(float ratio4Train) {
        trainData.clear();
        testData.clear();
        if (ratio4Train <= 0) throw new IllegalArgumentException("Ratio <= 0");
        int testCount = (int) (rawData.size() * (1 - ratio4Train));
        Random rand = new Random(System.currentTimeMillis());
        Set<Integer> exclusiveIndices4Test = new HashSet<>();
        while (exclusiveIndices4Test.size() < testCount) {
            int index = rand.nextInt(rawData.size());
            if (! exclusiveIndices4Test.contains(index)) {
                testData.add(rawData.get(index));
                exclusiveIndices4Test.add(index);
            }
        }
        for (int i = 0; i < rawData.size(); i ++) {
            if (! exclusiveIndices4Test.contains(i)) {
                trainData.add(rawData.get(i));
            }
        }
    }

    private void train() {
        System.out.println("Train data size: " + trainData.size());
        System.out.println("Test data size: " + testData.size());
//        model = new PolyModel(4);
        model = new GaussianModel(5);
        loss = mse;
        // ==========================================================
        for (int i = 0; i < 10000; i ++) {
            double lossVal = loss.apply(model);
            double[] gradVal = model.grad(trainData);
            System.out.println(String.format("Iter: %d, loss: %f ", i, lossVal));
            System.out.println(String.format("Theta: %f, %f, %f", model.theta[0], model.theta[1], model.theta[2]));
            System.out.println(String.format("Grad: %f, %f, %f\n", gradVal[0], gradVal[1], gradVal[2]));
            if (Double.isNaN(lossVal)) {
                model.randomize(); i = 0;
                continue;
            }
            for (int j = 0; j < gradVal.length; j ++) {
                double delta = model.rate(j) * gradVal[j];
                model.theta[j] -= delta;
            }
//            if (lossVal < 1.06) break;
        }
        System.out.println(String.format("Theta: %f, %f, %f", model.theta[0], model.theta[1], model.theta[2]));
    }

    private void validate() {
        double RMSE = 0.0;
        for (Point p : testData) {
            double diff = model.val(p.x) - p.y;
            RMSE += (diff * diff);
        }
        RMSE /= testData.size();
        RMSE = sqrt(RMSE);
        System.out.println("RMSE: " + RMSE);
    }

    private void plot() {
        XYChart chart = QuickChart.getChart(
                "Result", "X", "Y", "y(x)",
                trainData.stream().map(point -> point.x).collect(Collectors.toList()),
                trainData.stream().map(point -> point.y).collect(Collectors.toList()));

        double[] xPoints = new double[150];
        double[] yPoints = new double[150];
        for (int i = 0; i < 150; i ++) {
            xPoints[i] = i * 10.0 / 150;
            yPoints[i] = model.val(xPoints[i]);
        }
        chart.addSeries("model", xPoints, yPoints);

        new SwingWrapper<XYChart>(chart).displayChart();
    }

    public void solve() {
        divide(0.8f);
        train();
        validate();
        plot();
    }

    public static void main(String[] args) throws FileNotFoundException {
    // write your code here
        if (args.length < 1) {
            System.out.println("Usage: java -jar GradientDesent.jar data.csv");
            exit(0);
        }
        new Solver(args[0]).solve();
    }

    private static class Point {
        double x;
        double y;
        public Point(double x, double y) {this.x = x; this.y = y;}

    }

    private static abstract class Model {
        double theta[] = null;
        abstract double val(double x);
        abstract double[] grad(List<Point> trainData);
        abstract void randomize();
        abstract double rate(int i);
    }

    private static class PolyModel extends Model{

        public PolyModel(int n) {
            if (n < 2) throw new IllegalArgumentException("n MUST be larger than 2.");
            theta = new double[n];
            randomize();
        }

        double val(double x) {
            double result = 0.0;
            for (int i = 0; i < theta.length; i ++) {
                result += theta[i] * pow(x, i);
            }
            return result;
        }

        @Override
        double[] grad(List<Point> trainData) {
            double []gradVec = new double[theta.length];
            for (int i = 0; i < gradVec.length; i ++) {
                gradVec[i] = 0.0;
                Random r = new Random();
                List<Point> data = new ArrayList<>();
                for (int k = 0; k < 50; k ++)
                    data.add(trainData.get(r.nextInt(trainData.size())));
                for (Point p : data) {
                    double diff = val(p.x) - p.y;
                    gradVec[i] += (diff * pow(p.x, i));
                }
            }
            return gradVec;
        }

        @Override
        void randomize() {
            Random rand = new Random(System.currentTimeMillis());
            for (int i = 0; i < theta.length; i ++) {
                theta[i] = rand.nextDouble() ;
            }
        }

        @Override
        double rate(int i) {
            return 0.00000002;
        }
    }

    private static class GaussianModel extends Model{

        /**
         * f(x) = a * e ^ (- (x - μ)^2 / σ^2)
         * (a, μ, σ2) <<----
         * @param n number of gaussian function
         */
        public GaussianModel(int n) {
            if (n < 1) throw new IllegalArgumentException("n MUST be larger than 1.");
            theta = new double[n * 3];
            randomize();
        }

        @Override
        double val(double x) {
            double result = 0.0;
            for (int i = 0; i < theta.length / 3; i ++) {
                double alpha = theta[i * 3 + 0];
                double miu = theta[i * 3 + 1];
                double sigma2 = theta[i * 3 + 2];
                result += (alpha * pow(E, - pow((x - miu), 2) / sigma2 / 2));
            }
            return result;
        }

        @Override
        double[] grad(List<Point> trainData) {
            double[] gradVec = new double[theta.length];
            for (int i = 0; i < theta.length / 3; i ++) {
                gradVec[i * 3 + 0] = 0;
                gradVec[i * 3 + 1] = 0;
                gradVec[i * 3 + 2] = 0;
                double alpha = theta[i * 3 + 0];
                double miu = theta[i * 3 + 1];
                double sigma2 = theta[i * 3 + 2];
                Random r = new Random();
                List<Point> stochasticData = new ArrayList<>();
                for (int k = 0; k < 30; k ++)
                    stochasticData.add(trainData.get(r.nextInt(trainData.size())));
                for (Point p : stochasticData) {
                    double val = val(p.x);
                    gradVec[i * 3 + 0] += 2
                            * (val - p.y)
                            * (pow(E, - pow((p.x - miu), 2) / sigma2 / 2));
                    gradVec[i * 3 + 1] += (2
                            * alpha
                            * (val - p.y)
                            * pow(E, - pow((p.x - miu), 2) / sigma2 / 2)
                            * ((p.x - miu) / sigma2));
                    gradVec[i * 3 + 2] += (2
                            * alpha
                            * (val - p.y)
                            * pow(E, - pow((p.x - miu), 2) / sigma2 / 2)
                            * (pow((p.x - miu), 2) / pow(sigma2, 2) / 2)); //把sigma平方當(dāng)成了一個整體
                }
            }
            return gradVec;
        }

        @Override
        void randomize() {
            Random rand = new Random(System.currentTimeMillis());
            for (int i = 0; i < theta.length / 3; i ++) {
                theta[i * 3 + 0] = rand.nextDouble();
                theta[i * 3 + 1] = rand.nextDouble() * 5;
                theta[i * 3 + 2] = rand.nextDouble();
            }
        }

        @Override
        double rate(int i) {
            if (i % 3 == 0) {
                return 0.0005;
            } else if (i % 3 == 1) { // miu
                return 0.0005;
            } else {
                return 0.00005;
            }
        }

        public String toString() {
            StringBuilder builder = new StringBuilder("Theta: ");
            for (double t : theta) {
                builder.append(t);
                builder.append(", ");
            }
            builder.append("\nGrad: ");
            return builder.toString();
        }
    }
}

最后的結(jié)果還是比較看人品的忆某,并不是每次都能擬合地比較好点待,貼一個結(jié)果的圖:


結(jié)果

數(shù)據(jù)和代碼我放到了我的Github:https://github.com/Jimmie00x0000/gradient_desent_demo

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末弃舒,一起剝皮案震驚了整個濱河市癞埠,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌聋呢,老刑警劉巖苗踪,帶你破解...
    沈念sama閱讀 216,843評論 6 502
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場離奇詭異削锰,居然都是意外死亡通铲,警方通過查閱死者的電腦和手機,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 92,538評論 3 392
  • 文/潘曉璐 我一進(jìn)店門喂窟,熙熙樓的掌柜王于貴愁眉苦臉地迎上來测暗,“玉大人,你說我怎么就攤上這事磨澡⊥胱模” “怎么了?”我有些...
    開封第一講書人閱讀 163,187評論 0 353
  • 文/不壞的土叔 我叫張陵稳摄,是天一觀的道長稚字。 經(jīng)常有香客問我,道長厦酬,這世上最難降的妖魔是什么胆描? 我笑而不...
    開封第一講書人閱讀 58,264評論 1 292
  • 正文 為了忘掉前任,我火速辦了婚禮仗阅,結(jié)果婚禮上昌讲,老公的妹妹穿的比我還像新娘。我一直安慰自己减噪,他們只是感情好短绸,可當(dāng)我...
    茶點故事閱讀 67,289評論 6 390
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著筹裕,像睡著了一般醋闭。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發(fā)上朝卒,一...
    開封第一講書人閱讀 51,231評論 1 299
  • 那天证逻,我揣著相機與錄音,去河邊找鬼抗斤。 笑死囚企,一個胖子當(dāng)著我的面吹牛丈咐,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播洞拨,決...
    沈念sama閱讀 40,116評論 3 418
  • 文/蒼蘭香墨 我猛地睜開眼扯罐,長吁一口氣:“原來是場噩夢啊……” “哼负拟!你這毒婦竟也來了烦衣?” 一聲冷哼從身側(cè)響起,我...
    開封第一講書人閱讀 38,945評論 0 275
  • 序言:老撾萬榮一對情侶失蹤掩浙,失蹤者是張志新(化名)和其女友劉穎花吟,沒想到半個月后,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體厨姚,經(jīng)...
    沈念sama閱讀 45,367評論 1 313
  • 正文 獨居荒郊野嶺守林人離奇死亡衅澈,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 37,581評論 2 333
  • 正文 我和宋清朗相戀三年,在試婚紗的時候發(fā)現(xiàn)自己被綠了谬墙。 大學(xué)時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片今布。...
    茶點故事閱讀 39,754評論 1 348
  • 序言:一個原本活蹦亂跳的男人離奇死亡,死狀恐怖拭抬,靈堂內(nèi)的尸體忽然破棺而出部默,到底是詐尸還是另有隱情,我是刑警寧澤造虎,帶...
    沈念sama閱讀 35,458評論 5 344
  • 正文 年R本政府宣布傅蹂,位于F島的核電站,受9級特大地震影響算凿,放射性物質(zhì)發(fā)生泄漏份蝴。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點故事閱讀 41,068評論 3 327
  • 文/蒙蒙 一氓轰、第九天 我趴在偏房一處隱蔽的房頂上張望婚夫。 院中可真熱鬧,春花似錦署鸡、人聲如沸案糙。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,692評論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽侍筛。三九已至,卻和暖如春撒穷,著一層夾襖步出監(jiān)牢的瞬間匣椰,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 32,842評論 1 269
  • 我被黑心中介騙來泰國打工端礼, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留禽笑,地道東北人入录。 一個月前我還...
    沈念sama閱讀 47,797評論 2 369
  • 正文 我出身青樓,卻偏偏與公主長得像佳镜,于是被迫代替她去往敵國和親僚稿。 傳聞我的和親對象是個殘疾皇子,可洞房花燭夜當(dāng)晚...
    茶點故事閱讀 44,654評論 2 354

推薦閱讀更多精彩內(nèi)容