歡迎訪問我的GitHub
https://github.com/zq2599/blog_demos
內(nèi)容:所有原創(chuàng)文章分類匯總及配套源碼斤程,涉及Java、Docker、Kubernetes喳整、DevOPS等该园;
本篇概覽
- 作為《DL4J》實戰(zhàn)的第三篇每篷,目標是在DL4J框架下創(chuàng)建經(jīng)典的LeNet-5卷積神經(jīng)網(wǎng)絡(luò)模型篇裁,對MNIST數(shù)據(jù)集進行訓(xùn)練和測試菲饼,本篇由以下內(nèi)容構(gòu)成:
- LeNet-5簡介
- MNIST簡介
- 數(shù)據(jù)集簡介
- 關(guān)于版本和環(huán)境
- 編碼
- 驗證
LeNet-5簡介
- 是Yann LeCun于1998年設(shè)計的卷積神經(jīng)網(wǎng)絡(luò)让腹,用于手寫數(shù)字識別远剩,例如當年美國很多銀行用其識別支票上的手寫數(shù)字,LeNet-5是早期卷積神經(jīng)網(wǎng)絡(luò)最有代表性的實驗系統(tǒng)之一
- LeNet-5網(wǎng)絡(luò)結(jié)構(gòu)如下圖所示哨鸭,一共七層:C1 -> S2 -> C3 -> S4 -> C5 -> F6 -> OUTPUT
- 這張圖更加清晰明了(原圖地址:https://cuijiahua.com/blog/2018/01/dl_3.html)民宿,能夠很好的指導(dǎo)咱們在DL4J上的編碼:
- 按照上圖簡單分析一下,用于指導(dǎo)接下來的開發(fā):
- 每張圖片都是28*28的單通道像鸡,矩陣應(yīng)該是[1, 28,28]
- C1是卷積層活鹰,所用卷積核尺寸5*5,滑動步長1只估,卷積核數(shù)目20志群,所以尺寸變化是:28-5+1=24(想象為寬度為5的窗口在寬度為28的窗口內(nèi)滑動,能滑多少次)蛔钙,輸出矩陣是[20,24,24]
- S2是池化層锌云,核尺寸2*2,步長2吁脱,類型是MAX桑涎,池化操作后尺寸減半彬向,變成了[20,12,12]
- C3是卷積層,所用卷積核尺寸5*5攻冷,滑動步長1娃胆,卷積核數(shù)目50,所以尺寸變化是:12-5+1=8等曼,輸出矩陣[50,8,8]
- S4是池化層里烦,核尺寸2*2,步長2禁谦,類型是MAX胁黑,池化操作后尺寸減半,變成了[50,4,4]
- C5是全連接層(FC)州泊,神經(jīng)元數(shù)目500丧蘸,接relu激活函數(shù)
- 最后是全連接層Output,共10個節(jié)點拥诡,代表數(shù)字0到9触趴,激活函數(shù)是softmax
MNIST簡介
- MNIST是經(jīng)典的計算機視覺數(shù)據(jù)集,來源是National Institute of Standards and Technology (NIST渴肉,美國國家標準與技術(shù)研究所)冗懦,包含各種手寫數(shù)字圖片,其中訓(xùn)練集60,000張仇祭,測試集 10,000張披蕉,
- MNIST來源于250 個不同人的手寫,其中 50% 是高中學生, 50% 來自人口普查局 (the Census Bureau) 的工作人員.,測試集(test set) 也是同樣比例的手寫數(shù)字數(shù)據(jù)
- MNIST官網(wǎng):http://yann.lecun.com/exdb/mnist/
數(shù)據(jù)集簡介
- 從MNIST官網(wǎng)下載的原始數(shù)據(jù)并非圖片文件乌奇,需要按官方給出的格式說明做解析處理才能轉(zhuǎn)為一張張圖片没讲,這些事情顯然不是本篇的主題,因此咱們可以直接使用DL4J為我們準備好的數(shù)據(jù)集(下載地址稍后給出)礁苗,該數(shù)據(jù)集中是一張張獨立的圖片爬凑,這些圖片所在目錄的名字就是該圖片具體的數(shù)字,如下圖试伙,目錄<font color="blue">0</font>里面全是數(shù)字0的圖片:
- 上述數(shù)據(jù)集的下載地址有兩個:
- 可以在CSDN下載(0積分):https://download.csdn.net/download/boling_cavalry/19846603
- github:https://raw.githubusercontent.com/zq2599/blog_download_files/master/files/mnist_png.tar.gz
- 下載之后解壓開嘁信,是個名為<font color="blue">mnist_png</font>的文件夾,稍后的實戰(zhàn)中咱們會用到它
關(guān)于DL4J版本
- 《DL4J實戰(zhàn)》系列的源碼采用了maven的父子工程結(jié)構(gòu)疏叨,DL4J的版本在父工程<font color="blue">dlfj-tutorials</font>中定義為<font color="red">1.0.0-beta7</font>
- 本篇的代碼雖然還是<font color="blue">dlfj-tutorials</font>的子工程潘靖,但是DL4J版本卻使用了更低的<font color="red">1.0.0-beta6</font>,之所以這么做蚤蔓,是因為下一篇文章卦溢,咱們會把本篇的訓(xùn)練和測試工作交給GPU來完成,而對應(yīng)的CUDA庫只有<font color="red">1.0.0-beta6</font>
- 扯了這么多,可以開始編碼了
源碼下載
- 本篇實戰(zhàn)中的完整源碼可在GitHub下載到单寂,地址和鏈接信息如下表所示(https://github.com/zq2599/blog_demos):
名稱 | 鏈接 | 備注 |
---|---|---|
項目主頁 | https://github.com/zq2599/blog_demos | 該項目在GitHub上的主頁 |
git倉庫地址(https) | https://github.com/zq2599/blog_demos.git | 該項目源碼的倉庫地址贬芥,https協(xié)議 |
git倉庫地址(ssh) | git@github.com:zq2599/blog_demos.git | 該項目源碼的倉庫地址,ssh協(xié)議 |
- 這個git項目中有多個文件夾宣决,《DL4J實戰(zhàn)》系列的源碼在<font color="blue">dl4j-tutorials</font>文件夾下誓军,如下圖紅框所示:
- <font color="blue">dl4j-tutorials</font>文件夾下有多個子工程,本次實戰(zhàn)代碼在<font color="blue">simple-convolution</font>目錄下疲扎,如下圖紅框:
編碼
- 在父工程 <font color="blue">dl4j-tutorials</font>下新建名為 <font color="red">simple-convolution</font>的子工程,其pom.xml如下捷雕,可見這里的dl4j版本被指定為<font color="red">1.0.0-beta6</font>:
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<parent>
<artifactId>dlfj-tutorials</artifactId>
<groupId>com.bolingcavalry</groupId>
<version>1.0-SNAPSHOT</version>
</parent>
<modelVersion>4.0.0</modelVersion>
<artifactId>simple-convolution</artifactId>
<properties>
<dl4j-master.version>1.0.0-beta6</dl4j-master.version>
</properties>
<dependencies>
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
</dependency>
<dependency>
<groupId>ch.qos.logback</groupId>
<artifactId>logback-classic</artifactId>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-core</artifactId>
<version>${dl4j-master.version}</version>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>${nd4j.backend}</artifactId>
<version>${dl4j-master.version}</version>
</dependency>
</dependencies>
</project>
- 接下來按照前面的分析實現(xiàn)代碼椒丧,已經(jīng)添加了詳細注釋,就不再贅述了:
package com.bolingcavalry.convolution;
import lombok.extern.slf4j.Slf4j;
import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.split.FileSplit;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.schedule.MapSchedule;
import org.nd4j.linalg.schedule.ScheduleType;
import java.io.File;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;
@Slf4j
public class LeNetMNISTReLu {
// 存放文件的地址救巷,請酌情修改
// private static final String BASE_PATH = System.getProperty("java.io.tmpdir") + "/mnist";
private static final String BASE_PATH = "E:\\temp\\202106\\26";
public static void main(String[] args) throws Exception {
// 圖片像素高
int height = 28;
// 圖片像素寬
int width = 28;
// 因為是黑白圖像壶熏,所以顏色通道只有一個
int channels = 1;
// 分類結(jié)果,0-9浦译,共十種數(shù)字
int outputNum = 10;
// 批大小
int batchSize = 54;
// 循環(huán)次數(shù)
int nEpochs = 1;
// 初始化偽隨機數(shù)的種子
int seed = 1234;
// 隨機數(shù)工具
Random randNumGen = new Random(seed);
log.info("檢查數(shù)據(jù)集文件夾是否存在:{}", BASE_PATH + "/mnist_png");
if (!new File(BASE_PATH + "/mnist_png").exists()) {
log.info("數(shù)據(jù)集文件不存在棒假,請下載壓縮包并解壓到:{}", BASE_PATH);
return;
}
// 標簽生成器,將指定文件的父目錄作為標簽
ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
// 歸一化配置(像素值從0-255變?yōu)?-1)
DataNormalization imageScaler = new ImagePreProcessingScaler();
// 不論訓(xùn)練集還是測試集精盅,初始化操作都是相同套路:
// 1. 讀取圖片帽哑,數(shù)據(jù)格式為NCHW
// 2. 根據(jù)批大小創(chuàng)建的迭代器
// 3. 將歸一化器作為預(yù)處理器
log.info("訓(xùn)練集的矢量化操作...");
// 初始化訓(xùn)練集
File trainData = new File(BASE_PATH + "/mnist_png/training");
FileSplit trainSplit = new FileSplit(trainData, NativeImageLoader.ALLOWED_FORMATS, randNumGen);
ImageRecordReader trainRR = new ImageRecordReader(height, width, channels, labelMaker);
trainRR.initialize(trainSplit);
DataSetIterator trainIter = new RecordReaderDataSetIterator(trainRR, batchSize, 1, outputNum);
// 擬合數(shù)據(jù)(實現(xiàn)類中實際上什么也沒做)
imageScaler.fit(trainIter);
trainIter.setPreProcessor(imageScaler);
log.info("測試集的矢量化操作...");
// 初始化測試集,與前面的訓(xùn)練集操作類似
File testData = new File(BASE_PATH + "/mnist_png/testing");
FileSplit testSplit = new FileSplit(testData, NativeImageLoader.ALLOWED_FORMATS, randNumGen);
ImageRecordReader testRR = new ImageRecordReader(height, width, channels, labelMaker);
testRR.initialize(testSplit);
DataSetIterator testIter = new RecordReaderDataSetIterator(testRR, batchSize, 1, outputNum);
testIter.setPreProcessor(imageScaler); // same normalization for better results
log.info("配置神經(jīng)網(wǎng)絡(luò)");
// 在訓(xùn)練中叹俏,將學習率配置為隨著迭代階梯性下降
Map<Integer, Double> learningRateSchedule = new HashMap<>();
learningRateSchedule.put(0, 0.06);
learningRateSchedule.put(200, 0.05);
learningRateSchedule.put(600, 0.028);
learningRateSchedule.put(800, 0.0060);
learningRateSchedule.put(1000, 0.001);
// 超參數(shù)
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(seed)
// L2正則化系數(shù)
.l2(0.0005)
// 梯度下降的學習率設(shè)置
.updater(new Nesterovs(new MapSchedule(ScheduleType.ITERATION, learningRateSchedule)))
// 權(quán)重初始化
.weightInit(WeightInit.XAVIER)
// 準備分層
.list()
// 卷積層
.layer(new ConvolutionLayer.Builder(5, 5)
.nIn(channels)
.stride(1, 1)
.nOut(20)
.activation(Activation.IDENTITY)
.build())
// 下采樣妻枕,即池化
.layer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
.kernelSize(2, 2)
.stride(2, 2)
.build())
// 卷積層
.layer(new ConvolutionLayer.Builder(5, 5)
.stride(1, 1) // nIn need not specified in later layers
.nOut(50)
.activation(Activation.IDENTITY)
.build())
// 下采樣,即池化
.layer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
.kernelSize(2, 2)
.stride(2, 2)
.build())
// 稠密層粘驰,即全連接
.layer(new DenseLayer.Builder().activation(Activation.RELU)
.nOut(500)
.build())
// 輸出
.layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
.nOut(outputNum)
.activation(Activation.SOFTMAX)
.build())
.setInputType(InputType.convolutionalFlat(height, width, channels)) // InputType.convolutional for normal image
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
// 每十個迭代打印一次損失函數(shù)值
net.setListeners(new ScoreIterationListener(10));
log.info("神經(jīng)網(wǎng)絡(luò)共[{}]個參數(shù)", net.numParams());
long startTime = System.currentTimeMillis();
// 循環(huán)操作
for (int i = 0; i < nEpochs; i++) {
log.info("第[{}]個循環(huán)", i);
net.fit(trainIter);
Evaluation eval = net.evaluate(testIter);
log.info(eval.stats());
trainIter.reset();
testIter.reset();
}
log.info("完成訓(xùn)練和測試屡谐,耗時[{}]毫秒", System.currentTimeMillis()-startTime);
// 保存模型
File ministModelPath = new File(BASE_PATH + "/minist-model.zip");
ModelSerializer.writeModel(net, ministModelPath, true);
log.info("最新的MINIST模型保存在[{}]", ministModelPath.getPath());
}
}
- 執(zhí)行上述代碼,日志輸出如下蝌数,訓(xùn)練和測試都順利完成愕掏,準確率達到0.9886:
21:19:15.355 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 1110 is 0.18300625613640034
21:19:15.365 [main] DEBUG org.nd4j.linalg.dataset.AsyncDataSetIterator - Manually destroying ADSI workspace
21:19:16.632 [main] DEBUG org.nd4j.linalg.dataset.AsyncDataSetIterator - Manually destroying ADSI workspace
21:19:16.642 [main] INFO com.bolingcavalry.convolution.LeNetMNISTReLu -
========================Evaluation Metrics========================
# of classes: 10
Accuracy: 0.9886
Precision: 0.9885
Recall: 0.9886
F1 Score: 0.9885
Precision, recall & F1: macro-averaged (equally weighted avg. of 10 classes)
=========================Confusion Matrix=========================
0 1 2 3 4 5 6 7 8 9
---------------------------------------------------
972 0 0 0 0 0 2 2 2 2 | 0 = 0
0 1126 0 3 0 2 1 1 2 0 | 1 = 1
1 1 1019 2 0 0 0 6 3 0 | 2 = 2
0 0 1 1002 0 5 0 1 1 0 | 3 = 3
0 0 2 0 971 0 3 2 1 3 | 4 = 4
0 0 0 3 0 886 2 1 0 0 | 5 = 5
6 2 0 1 1 5 942 0 1 0 | 6 = 6
0 1 6 0 0 0 0 1015 1 5 | 7 = 7
1 0 1 1 0 2 0 2 962 5 | 8 = 8
1 2 1 3 5 3 0 2 1 991 | 9 = 9
Confusion matrix format: Actual (rowClass) predicted as (columnClass) N times
==================================================================
21:19:16.643 [main] INFO com.bolingcavalry.convolution.LeNetMNISTReLu - 完成訓(xùn)練和測試,耗時[27467]毫秒
21:19:17.019 [main] INFO com.bolingcavalry.convolution.LeNetMNISTReLu - 最新的MINIST模型保存在[E:\temp\202106\26\minist-model.zip]
Process finished with exit code 0
關(guān)于準確率
前面的測試結(jié)果顯示準確率為<font color="blue">0.9886</font>顶伞,這是<font color="red">1.0.0-beta6</font>版本DL4J的訓(xùn)練結(jié)果饵撑,如果換成<font color="red">1.0.0-beta7</font>,準確率可以達到<font color="blue">0.99</font>以上枝哄,您可以嘗試一下肄梨;
至此,DL4J框架下的經(jīng)典卷積實戰(zhàn)就完成了挠锥,截止目前众羡,咱們的訓(xùn)練和測試工作都是CPU完成的,工作中CPU使用率的上升十分明顯蓖租,下一篇文章粱侣,咱們把今天的工作交給GPU執(zhí)行試試羊壹,看能否借助CUDA加速訓(xùn)練和測試工作;
你不孤單齐婴,欣宸原創(chuàng)一路相伴
歡迎關(guān)注公眾號:程序員欣宸
微信搜索「程序員欣宸」油猫,我是欣宸,期待與您一同暢游Java世界...
https://github.com/zq2599/blog_demos