前言
一般機(jī)器學(xué)習(xí)框架都使用MNIST作為入門(mén)闺魏。就像"Hello World"對(duì)于任何一門(mén)編程語(yǔ)言一樣堡妒,要想入門(mén)機(jī)器學(xué)習(xí)搀庶,就先要掌握MNIST锯仪。
筆者在學(xué)習(xí)的時(shí)候Tensorflow已經(jīng)成為十分流行的機(jī)器學(xué)習(xí)框架泵督,網(wǎng)上有大量的“資源”,但是大多都限于皮毛卵酪。
很多教程就是給你一段代碼然后隨便講兩句幌蚊,這樣對(duì)新手并不友好谤碳。
因此我萌生了寫(xiě)一個(gè)詳解的想法溃卡。
筆者是一名網(wǎng)絡(luò)工程在讀大學(xué)生,知識(shí)水平有限蜒简,未必能做到面面俱到且處處正確瘸羡,如有錯(cuò)誤請(qǐng)指出。
源代碼
- 訓(xùn)練集
請(qǐng)點(diǎn)擊此處下載搓茬。
提取碼:xgpy - 源代碼
在源代碼同一目錄下新建文件夾“訓(xùn)練集”犹赖,把百度云連接里面的.gz文件放入該文件夾。
# -*- coding: utf-8 -*-
import tensorflow as tf
import input_data
mnist = input_data.read_data_sets('./訓(xùn)練集', one_hot=True)
'''
#構(gòu)建運(yùn)算圖
'''
# X Y 都是占位符 占位而已 不表示具體的數(shù)據(jù)
x = tf.placeholder("float",[None,784]) # 圖像的大小為784;None表示第一個(gè)維度可以是任意長(zhǎng)度
# 一個(gè)Variable代表一個(gè)可修改的張量,它們可以用于計(jì)算輸入值卷仑,也可以在計(jì)算中被修改
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
y = tf.nn.softmax(tf.matmul(x,W) + b)
# 計(jì)算交叉熵
y_ = tf.placeholder("float", [None,10])
cross_entropy = -tf.reduce_sum(y_*tf.log(y))
# 梯度下降算法(gradient descent algorithm)
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
# 在運(yùn)行計(jì)算之前峻村,我們需要添加一個(gè)操作來(lái)初始化我們創(chuàng)建的變量:
init = tf.global_variables_initializer()
# 在一個(gè)Session里面啟動(dòng)我們的模型,并且初始化變量:
sess = tf.Session()
sess.run(init)
# 訓(xùn)練模型1000次
for i in range(1000):
batch_xs, batch_ys = mnist.train.next_batch(100)
sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
#print('-**-',accuracy,type(accuracy))
print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))
詳解
這一大段代碼實(shí)現(xiàn)的功能是:
建立 y = w*x+b 的模型锡凝,其中x是輸入的
可以直觀的看到粘昨,以上代碼分為三部分:構(gòu)建圖、定義會(huì)話、啟動(dòng)圖张肾。
構(gòu)建圖
構(gòu)建圖也分為定義變量芭析、定義交叉熵、定義優(yōu)化方法吞瞪。
- 定義變量
由定義方法分類馁启,本實(shí)例中主要有兩種變量。
第一類是由tf.Variable()定義的w芍秆、b
第二類是由tf.placeholder()定義的y_惯疙、x
順帶提一句y = tf.nn.softmax(tf.matmul(x,W) + b)是這兩者結(jié)合起來(lái)的。
那么這兩類有什么區(qū)別呢妖啥?
一般而言螟碎,Varibale主要用來(lái)保存tensorflow圖中的一些結(jié)構(gòu)中的參數(shù),如本例中的w權(quán)重迹栓,b偏置掉分。需要初始化。
plceholder主要用來(lái)把要訓(xùn)練/測(cè)試的數(shù)據(jù)輸入模型克伊,每次訓(xùn)練plceholder都有不一樣的值酥郭。在Session.run(feed_dict={})中的參數(shù)確定實(shí)際的值。
可視化網(wǎng)頁(yè)
https://www.cs.ryerson.ca/~aharley/vis/fc/
https://www.cs.ryerson.ca/~aharley/vis/conv/flat.html
https://poloclub.github.io/cnn-explainer/