DL4J實戰(zhàn)之三:經(jīng)典卷積實例(LeNet-5)

歡迎訪問我的GitHub

https://github.com/zq2599/blog_demos

內(nèi)容:所有原創(chuàng)文章分類匯總及配套源碼斤程,涉及Java、Docker、Kubernetes喳整、DevOPS等该园;

本篇概覽

  • 作為《DL4J》實戰(zhàn)的第三篇每篷,目標是在DL4J框架下創(chuàng)建經(jīng)典的LeNet-5卷積神經(jīng)網(wǎng)絡(luò)模型篇裁,對MNIST數(shù)據(jù)集進行訓(xùn)練和測試菲饼,本篇由以下內(nèi)容構(gòu)成:
  1. LeNet-5簡介
  2. MNIST簡介
  3. 數(shù)據(jù)集簡介
  4. 關(guān)于版本和環(huán)境
  5. 編碼
  6. 驗證

LeNet-5簡介

  • 是Yann LeCun于1998年設(shè)計的卷積神經(jīng)網(wǎng)絡(luò)让腹,用于手寫數(shù)字識別远剩,例如當年美國很多銀行用其識別支票上的手寫數(shù)字,LeNet-5是早期卷積神經(jīng)網(wǎng)絡(luò)最有代表性的實驗系統(tǒng)之一
  • LeNet-5網(wǎng)絡(luò)結(jié)構(gòu)如下圖所示哨鸭,一共七層:C1 -> S2 -> C3 -> S4 -> C5 -> F6 -> OUTPUT
在這里插入圖片描述
在這里插入圖片描述
  • 按照上圖簡單分析一下,用于指導(dǎo)接下來的開發(fā):
  1. 每張圖片都是28*28的單通道像鸡,矩陣應(yīng)該是[1, 28,28]
  2. C1是卷積層活鹰,所用卷積核尺寸5*5,滑動步長1只估,卷積核數(shù)目20志群,所以尺寸變化是:28-5+1=24(想象為寬度為5的窗口在寬度為28的窗口內(nèi)滑動,能滑多少次)蛔钙,輸出矩陣是[20,24,24]
  3. S2是池化層锌云,核尺寸2*2,步長2吁脱,類型是MAX桑涎,池化操作后尺寸減半彬向,變成了[20,12,12]
  4. C3是卷積層,所用卷積核尺寸5*5攻冷,滑動步長1娃胆,卷積核數(shù)目50,所以尺寸變化是:12-5+1=8等曼,輸出矩陣[50,8,8]
  5. S4是池化層里烦,核尺寸2*2,步長2禁谦,類型是MAX胁黑,池化操作后尺寸減半,變成了[50,4,4]
  6. C5是全連接層(FC)州泊,神經(jīng)元數(shù)目500丧蘸,接relu激活函數(shù)
  7. 最后是全連接層Output,共10個節(jié)點拥诡,代表數(shù)字0到9触趴,激活函數(shù)是softmax

MNIST簡介

  • MNIST是經(jīng)典的計算機視覺數(shù)據(jù)集,來源是National Institute of Standards and Technology (NIST渴肉,美國國家標準與技術(shù)研究所)冗懦,包含各種手寫數(shù)字圖片,其中訓(xùn)練集60,000張仇祭,測試集 10,000張披蕉,
  • MNIST來源于250 個不同人的手寫,其中 50% 是高中學生, 50% 來自人口普查局 (the Census Bureau) 的工作人員.,測試集(test set) 也是同樣比例的手寫數(shù)字數(shù)據(jù)
  • MNIST官網(wǎng):http://yann.lecun.com/exdb/mnist/

數(shù)據(jù)集簡介

  • 從MNIST官網(wǎng)下載的原始數(shù)據(jù)并非圖片文件乌奇,需要按官方給出的格式說明做解析處理才能轉(zhuǎn)為一張張圖片没讲,這些事情顯然不是本篇的主題,因此咱們可以直接使用DL4J為我們準備好的數(shù)據(jù)集(下載地址稍后給出)礁苗,該數(shù)據(jù)集中是一張張獨立的圖片爬凑,這些圖片所在目錄的名字就是該圖片具體的數(shù)字,如下圖试伙,目錄<font color="blue">0</font>里面全是數(shù)字0的圖片:
在這里插入圖片描述
  • 上述數(shù)據(jù)集的下載地址有兩個:
  1. 可以在CSDN下載(0積分):https://download.csdn.net/download/boling_cavalry/19846603
  2. github:https://raw.githubusercontent.com/zq2599/blog_download_files/master/files/mnist_png.tar.gz
  • 下載之后解壓開嘁信,是個名為<font color="blue">mnist_png</font>的文件夾,稍后的實戰(zhàn)中咱們會用到它

關(guān)于DL4J版本

  • 《DL4J實戰(zhàn)》系列的源碼采用了maven的父子工程結(jié)構(gòu)疏叨,DL4J的版本在父工程<font color="blue">dlfj-tutorials</font>中定義為<font color="red">1.0.0-beta7</font>
  • 本篇的代碼雖然還是<font color="blue">dlfj-tutorials</font>的子工程潘靖,但是DL4J版本卻使用了更低的<font color="red">1.0.0-beta6</font>,之所以這么做蚤蔓,是因為下一篇文章卦溢,咱們會把本篇的訓(xùn)練和測試工作交給GPU來完成,而對應(yīng)的CUDA庫只有<font color="red">1.0.0-beta6</font>
  • 扯了這么多,可以開始編碼了

源碼下載

名稱 鏈接 備注
項目主頁 https://github.com/zq2599/blog_demos 該項目在GitHub上的主頁
git倉庫地址(https) https://github.com/zq2599/blog_demos.git 該項目源碼的倉庫地址贬芥,https協(xié)議
git倉庫地址(ssh) git@github.com:zq2599/blog_demos.git 該項目源碼的倉庫地址,ssh協(xié)議
  • 這個git項目中有多個文件夾宣决,《DL4J實戰(zhàn)》系列的源碼在<font color="blue">dl4j-tutorials</font>文件夾下誓军,如下圖紅框所示:
在這里插入圖片描述
  • <font color="blue">dl4j-tutorials</font>文件夾下有多個子工程,本次實戰(zhàn)代碼在<font color="blue">simple-convolution</font>目錄下疲扎,如下圖紅框:
在這里插入圖片描述

編碼

  • 在父工程 <font color="blue">dl4j-tutorials</font>下新建名為 <font color="red">simple-convolution</font>的子工程,其pom.xml如下捷雕,可見這里的dl4j版本被指定為<font color="red">1.0.0-beta6</font>:
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <parent>
        <artifactId>dlfj-tutorials</artifactId>
        <groupId>com.bolingcavalry</groupId>
        <version>1.0-SNAPSHOT</version>
    </parent>
    <modelVersion>4.0.0</modelVersion>

    <artifactId>simple-convolution</artifactId>

    <properties>
        <dl4j-master.version>1.0.0-beta6</dl4j-master.version>
    </properties>

    <dependencies>
        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
        </dependency>

        <dependency>
            <groupId>ch.qos.logback</groupId>
            <artifactId>logback-classic</artifactId>
        </dependency>

        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-core</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>

        <dependency>
            <groupId>org.nd4j</groupId>
            <artifactId>${nd4j.backend}</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>
    </dependencies>
</project>
  • 接下來按照前面的分析實現(xiàn)代碼椒丧,已經(jīng)添加了詳細注釋,就不再贅述了:
package com.bolingcavalry.convolution;

import lombok.extern.slf4j.Slf4j;
import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.split.FileSplit;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.schedule.MapSchedule;
import org.nd4j.linalg.schedule.ScheduleType;
import java.io.File;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;

@Slf4j
public class LeNetMNISTReLu {

    // 存放文件的地址救巷,請酌情修改
//    private static final String BASE_PATH = System.getProperty("java.io.tmpdir") + "/mnist";

    private static final String BASE_PATH = "E:\\temp\\202106\\26";

    public static void main(String[] args) throws Exception {
        // 圖片像素高
        int height = 28;
        // 圖片像素寬
        int width = 28;
        // 因為是黑白圖像壶熏,所以顏色通道只有一個
        int channels = 1;
        // 分類結(jié)果,0-9浦译,共十種數(shù)字
        int outputNum = 10;
        // 批大小
        int batchSize = 54;
        // 循環(huán)次數(shù)
        int nEpochs = 1;
        // 初始化偽隨機數(shù)的種子
        int seed = 1234;

        // 隨機數(shù)工具
        Random randNumGen = new Random(seed);
        
        log.info("檢查數(shù)據(jù)集文件夾是否存在:{}", BASE_PATH + "/mnist_png");

        if (!new File(BASE_PATH + "/mnist_png").exists()) {
            log.info("數(shù)據(jù)集文件不存在棒假,請下載壓縮包并解壓到:{}", BASE_PATH);
            return;
        }

        // 標簽生成器,將指定文件的父目錄作為標簽
        ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
        // 歸一化配置(像素值從0-255變?yōu)?-1)
        DataNormalization imageScaler = new ImagePreProcessingScaler();

        // 不論訓(xùn)練集還是測試集精盅,初始化操作都是相同套路:
        // 1. 讀取圖片帽哑,數(shù)據(jù)格式為NCHW
        // 2. 根據(jù)批大小創(chuàng)建的迭代器
        // 3. 將歸一化器作為預(yù)處理器

        log.info("訓(xùn)練集的矢量化操作...");
        // 初始化訓(xùn)練集
        File trainData = new File(BASE_PATH + "/mnist_png/training");
        FileSplit trainSplit = new FileSplit(trainData, NativeImageLoader.ALLOWED_FORMATS, randNumGen);
        ImageRecordReader trainRR = new ImageRecordReader(height, width, channels, labelMaker);
        trainRR.initialize(trainSplit);
        DataSetIterator trainIter = new RecordReaderDataSetIterator(trainRR, batchSize, 1, outputNum);
        // 擬合數(shù)據(jù)(實現(xiàn)類中實際上什么也沒做)
        imageScaler.fit(trainIter);
        trainIter.setPreProcessor(imageScaler);

        log.info("測試集的矢量化操作...");
        // 初始化測試集,與前面的訓(xùn)練集操作類似
        File testData = new File(BASE_PATH + "/mnist_png/testing");
        FileSplit testSplit = new FileSplit(testData, NativeImageLoader.ALLOWED_FORMATS, randNumGen);
        ImageRecordReader testRR = new ImageRecordReader(height, width, channels, labelMaker);
        testRR.initialize(testSplit);
        DataSetIterator testIter = new RecordReaderDataSetIterator(testRR, batchSize, 1, outputNum);
        testIter.setPreProcessor(imageScaler); // same normalization for better results

        log.info("配置神經(jīng)網(wǎng)絡(luò)");

        // 在訓(xùn)練中叹俏,將學習率配置為隨著迭代階梯性下降
        Map<Integer, Double> learningRateSchedule = new HashMap<>();
        learningRateSchedule.put(0, 0.06);
        learningRateSchedule.put(200, 0.05);
        learningRateSchedule.put(600, 0.028);
        learningRateSchedule.put(800, 0.0060);
        learningRateSchedule.put(1000, 0.001);

        // 超參數(shù)
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(seed)
            // L2正則化系數(shù)
            .l2(0.0005)
            // 梯度下降的學習率設(shè)置
            .updater(new Nesterovs(new MapSchedule(ScheduleType.ITERATION, learningRateSchedule)))
            // 權(quán)重初始化
            .weightInit(WeightInit.XAVIER)
            // 準備分層
            .list()
            // 卷積層
            .layer(new ConvolutionLayer.Builder(5, 5)
                .nIn(channels)
                .stride(1, 1)
                .nOut(20)
                .activation(Activation.IDENTITY)
                .build())
            // 下采樣妻枕,即池化
            .layer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                .kernelSize(2, 2)
                .stride(2, 2)
                .build())
            // 卷積層
            .layer(new ConvolutionLayer.Builder(5, 5)
                .stride(1, 1) // nIn need not specified in later layers
                .nOut(50)
                .activation(Activation.IDENTITY)
                .build())
            // 下采樣,即池化
            .layer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                .kernelSize(2, 2)
                .stride(2, 2)
                .build())
            // 稠密層粘驰,即全連接
            .layer(new DenseLayer.Builder().activation(Activation.RELU)
                .nOut(500)
                .build())
            // 輸出
            .layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                .nOut(outputNum)
                .activation(Activation.SOFTMAX)
                .build())
            .setInputType(InputType.convolutionalFlat(height, width, channels)) // InputType.convolutional for normal image
            .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

        // 每十個迭代打印一次損失函數(shù)值
        net.setListeners(new ScoreIterationListener(10));

        log.info("神經(jīng)網(wǎng)絡(luò)共[{}]個參數(shù)", net.numParams());

        long startTime = System.currentTimeMillis();
        // 循環(huán)操作
        for (int i = 0; i < nEpochs; i++) {
            log.info("第[{}]個循環(huán)", i);
            net.fit(trainIter);
            Evaluation eval = net.evaluate(testIter);
            log.info(eval.stats());
            trainIter.reset();
            testIter.reset();
        }
        log.info("完成訓(xùn)練和測試屡谐,耗時[{}]毫秒", System.currentTimeMillis()-startTime);

        // 保存模型
        File ministModelPath = new File(BASE_PATH + "/minist-model.zip");
        ModelSerializer.writeModel(net, ministModelPath, true);
        log.info("最新的MINIST模型保存在[{}]", ministModelPath.getPath());
    }
}
  • 執(zhí)行上述代碼,日志輸出如下蝌数,訓(xùn)練和測試都順利完成愕掏,準確率達到0.9886:
21:19:15.355 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 1110 is 0.18300625613640034
21:19:15.365 [main] DEBUG org.nd4j.linalg.dataset.AsyncDataSetIterator - Manually destroying ADSI workspace
21:19:16.632 [main] DEBUG org.nd4j.linalg.dataset.AsyncDataSetIterator - Manually destroying ADSI workspace
21:19:16.642 [main] INFO com.bolingcavalry.convolution.LeNetMNISTReLu - 

========================Evaluation Metrics========================
 # of classes:    10
 Accuracy:        0.9886
 Precision:       0.9885
 Recall:          0.9886
 F1 Score:        0.9885
Precision, recall & F1: macro-averaged (equally weighted avg. of 10 classes)


=========================Confusion Matrix=========================
    0    1    2    3    4    5    6    7    8    9
---------------------------------------------------
  972    0    0    0    0    0    2    2    2    2 | 0 = 0
    0 1126    0    3    0    2    1    1    2    0 | 1 = 1
    1    1 1019    2    0    0    0    6    3    0 | 2 = 2
    0    0    1 1002    0    5    0    1    1    0 | 3 = 3
    0    0    2    0  971    0    3    2    1    3 | 4 = 4
    0    0    0    3    0  886    2    1    0    0 | 5 = 5
    6    2    0    1    1    5  942    0    1    0 | 6 = 6
    0    1    6    0    0    0    0 1015    1    5 | 7 = 7
    1    0    1    1    0    2    0    2  962    5 | 8 = 8
    1    2    1    3    5    3    0    2    1  991 | 9 = 9

Confusion matrix format: Actual (rowClass) predicted as (columnClass) N times
==================================================================
21:19:16.643 [main] INFO com.bolingcavalry.convolution.LeNetMNISTReLu - 完成訓(xùn)練和測試,耗時[27467]毫秒
21:19:17.019 [main] INFO com.bolingcavalry.convolution.LeNetMNISTReLu - 最新的MINIST模型保存在[E:\temp\202106\26\minist-model.zip]

Process finished with exit code 0

關(guān)于準確率

  • 前面的測試結(jié)果顯示準確率為<font color="blue">0.9886</font>顶伞,這是<font color="red">1.0.0-beta6</font>版本DL4J的訓(xùn)練結(jié)果饵撑,如果換成<font color="red">1.0.0-beta7</font>,準確率可以達到<font color="blue">0.99</font>以上枝哄,您可以嘗試一下肄梨;

  • 至此,DL4J框架下的經(jīng)典卷積實戰(zhàn)就完成了挠锥,截止目前众羡,咱們的訓(xùn)練和測試工作都是CPU完成的,工作中CPU使用率的上升十分明顯蓖租,下一篇文章粱侣,咱們把今天的工作交給GPU執(zhí)行試試羊壹,看能否借助CUDA加速訓(xùn)練和測試工作;

你不孤單齐婴,欣宸原創(chuàng)一路相伴

  1. Java系列
  2. Spring系列
  3. Docker系列
  4. kubernetes系列
  5. 數(shù)據(jù)庫+中間件系列
  6. DevOps系列

歡迎關(guān)注公眾號:程序員欣宸

微信搜索「程序員欣宸」油猫,我是欣宸,期待與您一同暢游Java世界...
https://github.com/zq2599/blog_demos

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末柠偶,一起剝皮案震驚了整個濱河市情妖,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌诱担,老刑警劉巖毡证,帶你破解...
    沈念sama閱讀 217,277評論 6 503
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場離奇詭異蔫仙,居然都是意外死亡料睛,警方通過查閱死者的電腦和手機,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 92,689評論 3 393
  • 文/潘曉璐 我一進店門摇邦,熙熙樓的掌柜王于貴愁眉苦臉地迎上來恤煞,“玉大人,你說我怎么就攤上這事施籍【影牵” “怎么了?”我有些...
    開封第一講書人閱讀 163,624評論 0 353
  • 文/不壞的土叔 我叫張陵法梯,是天一觀的道長苔货。 經(jīng)常有香客問我,道長立哑,這世上最難降的妖魔是什么夜惭? 我笑而不...
    開封第一講書人閱讀 58,356評論 1 293
  • 正文 為了忘掉前任,我火速辦了婚禮铛绰,結(jié)果婚禮上诈茧,老公的妹妹穿的比我還像新娘。我一直安慰自己捂掰,他們只是感情好敢会,可當我...
    茶點故事閱讀 67,402評論 6 392
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著这嚣,像睡著了一般鸥昏。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發(fā)上姐帚,一...
    開封第一講書人閱讀 51,292評論 1 301
  • 那天吏垮,我揣著相機與錄音,去河邊找鬼。 笑死膳汪,一個胖子當著我的面吹牛唯蝶,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播遗嗽,決...
    沈念sama閱讀 40,135評論 3 418
  • 文/蒼蘭香墨 我猛地睜開眼粘我,長吁一口氣:“原來是場噩夢啊……” “哼!你這毒婦竟也來了痹换?” 一聲冷哼從身側(cè)響起征字,我...
    開封第一講書人閱讀 38,992評論 0 275
  • 序言:老撾萬榮一對情侶失蹤,失蹤者是張志新(化名)和其女友劉穎娇豫,沒想到半個月后柔纵,有當?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 45,429評論 1 314
  • 正文 獨居荒郊野嶺守林人離奇死亡锤躁,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 37,636評論 3 334
  • 正文 我和宋清朗相戀三年,在試婚紗的時候發(fā)現(xiàn)自己被綠了或详。 大學時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片系羞。...
    茶點故事閱讀 39,785評論 1 348
  • 序言:一個原本活蹦亂跳的男人離奇死亡,死狀恐怖霸琴,靈堂內(nèi)的尸體忽然破棺而出椒振,到底是詐尸還是另有隱情,我是刑警寧澤梧乘,帶...
    沈念sama閱讀 35,492評論 5 345
  • 正文 年R本政府宣布澎迎,位于F島的核電站,受9級特大地震影響选调,放射性物質(zhì)發(fā)生泄漏夹供。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點故事閱讀 41,092評論 3 328
  • 文/蒙蒙 一仁堪、第九天 我趴在偏房一處隱蔽的房頂上張望哮洽。 院中可真熱鬧,春花似錦弦聂、人聲如沸鸟辅。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,723評論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽匪凉。三九已至,卻和暖如春捺檬,著一層夾襖步出監(jiān)牢的瞬間再层,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 32,858評論 1 269
  • 我被黑心中介騙來泰國打工, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留树绩,地道東北人萨脑。 一個月前我還...
    沈念sama閱讀 47,891評論 2 370
  • 正文 我出身青樓,卻偏偏與公主長得像饺饭,于是被迫代替她去往敵國和親渤早。 傳聞我的和親對象是個殘疾皇子,可洞房花燭夜當晚...
    茶點故事閱讀 44,713評論 2 354

推薦閱讀更多精彩內(nèi)容