TensorFlow技術解析與實戰(zhàn) 9.2 Mnist分類問題

# -*- coding: utf-8 -*-

import sys

import importlib

importlib.reload(sys)

from tensorflow.examples.tutorials.mnist import input_data

import tensorflow as tf

# 加載數據

mnist = input_data.read_data_sets("./", one_hot=True)

# 構建回歸模型

x = tf.placeholder(tf.float32, [None, 784])

W = tf.Variable(tf.zeros([784, 10]))

b = tf.Variable(tf.zeros([10]))

y = tf.matmul(x, W) + b? # 預測值

# 定義損失函數和優(yōu)化器

y_ = tf.placeholder(tf.float32, [None, 10])

cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=y, labels=y_))

# 采用SGD作為優(yōu)化器

train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

# 訓練模型

sess = tf.InteractiveSession()

tf.global_variables_initializer().run()

for _ in range(1000):

batch_xs, batch_ys = mnist.train.next_batch(100)

sess.run(train_step, feed_dict={x: batch_xs, y_:batch_ys})

# 評估訓練好的模型

correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))? #計算預測值和真實值

accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))? #布爾型轉化為浮點數椿猎,并取平均值迟蜜,得到準確率

print(sess.run(accuracy, feed_dict={x:mnist.test.images, y_:mnist.test.labels}))? #計算模型在測試集上的準確率

0.9179


卷積神經網絡實現

# -*- coding: utf-8 -*-

import sys

import importlib

importlib.reload(sys)

from tensorflow.examples.tutorials.mnist import input_data

import tensorflow as tf

# 加載數據

mnist = input_data.read_data_sets("./", one_hot=True)

trX, trY, teX, teY = mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels

trX = trX.reshape(-1, 28, 28, 1)? # 28x28x1 input img

teX = teX.reshape(-1, 28, 28, 1)? # 28x28x1 input img

X = tf.placeholder("float", [None, 28, 28, 1])

Y = tf.placeholder("float", [None, 10])

def init_weights(shape):

return tf.Variable(tf.random_normal(shape, stddev=0.01))

w = init_weights([3, 3, 1, 32])? ? ? ? # patch 大小為 3 × 3 ,輸入維度為 1 ,輸出維度為 32

w2 = init_weights([3, 3, 32, 64])? ? ? # patch 大小為 3 × 3 ,輸入維度為 32 ,輸出維度為 64

w3 = init_weights([3, 3, 64, 128])? ? # patch 大小為 3 × 3 ,輸入維度為 64 ,輸出維度為 128

w4 = init_weights([128 * 4 * 4, 625])? # 全連接層,輸入維度為 128 × 4 × 4, 是上一層的輸出數據又三維的轉變成一維, 輸出維度為 625

w_o = init_weights([625, 10])? ? ? ? ? # 輸出層,輸入維度為 625, 輸出維度為 10 ,代表 10 類 (labels)

def model(X, w, w2, w3, w4, w_o, p_keep_conv, p_keep_hidden):

# 第一組卷積層及池化層

l1a = tf.nn.relu(tf.nn.conv2d(X, w, strides=[1, 1, 1, 1], padding='SAME'))

l1 = tf.nn.max_pool(l1a, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

l1 = tf.nn.dropout(l1, p_keep_conv)? # dropout 一些神經元

# 第二組卷積層及池化層

l2a = tf.nn.relu(tf.nn.conv2d(l1, w2, strides=[1, 1, 1, 1], padding='SAME'))

l2 = tf.nn.max_pool(l2a, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

l2 = tf.nn.dropout(l2, p_keep_conv)? # dropout 一些神經元

# 第三組卷積層及池化層

l3a = tf.nn.relu(tf.nn.conv2d(l2, w3, strides=[1, 1, 1, 1], padding='SAME'))

l3 = tf.nn.max_pool(l3a, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME') # l3 shape=(?, 4, 4, 128)

l3 = tf.reshape(l3, [-1, w4.get_shape().as_list()[0]])? # reshape to (?, 128 * 4 * 4 = 2048)

l3 = tf.nn.dropout(l3, p_keep_conv)

# 全連接層,最后dropout

l4 = tf.nn.relu(tf.matmul(l3, w4))

l4 = tf.nn.dropout(l4, p_keep_hidden)

# 輸出層

pyx = tf.matmul(l4, w_o)

return pyx? #返回預測值

p_keep_conv = tf.placeholder("float")

p_keep_hidden = tf.placeholder("float")

py_x = model(X, w, w2, w3, w4, w_o, p_keep_conv, p_keep_hidden)

# 定義損失函數

cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=py_x, labels=Y))

train_op = tf.train.RMSPropOptimizer(0.001, 0.9).minimize(cost)

predict_op = tf.argmax(py_x, 1)

#訓練模型和評估模型

batch_size = 128

test_size = 256

with tf.Session() as sess:

tf.global_variables_initializer().run()

for i in range(100):

training_batch = zip(range(0, len(trX), batch_size), range(batch_size, len(trX)+1, batch_size))

for start, end in training_batch:

sess.run(train_op, feed_dict={X:trX[start:end], Y:trY[start:end], p_keep_conv:0.8, p_keep_hidden:0.5})

test_indices = np.arange(len(teX))

np.random.shuffle(test_indices)

test_indices = test_indices[0:test_size]

print(i, np.mean(np.argmax(teY[test_indices], axis=1) ==

? ? ? ? ? ? sess.run(predict_op, feed_dict={X: teX[test_indices],

? ? ? ? ? ? p_keep_conv:1.0, p_keep_hidden: 1.0})))



0 0.953125

1 0.98046875

2 0.984375

3 0.9921875

4 0.98828125

5 0.9921875

6 1.0

7 0.99609375

8 0.9921875

9 0.99609375

10 0.99609375

11 0.984375

12 0.9921875

?著作權歸作者所有,轉載或內容合作請聯系作者
  • 序言:七十年代末碟嘴,一起剝皮案震驚了整個濱河市的圆,隨后出現的幾起案子鼓拧,更是在濱河造成了極大的恐慌,老刑警劉巖越妈,帶你破解...
    沈念sama閱讀 212,816評論 6 492
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件季俩,死亡現場離奇詭異,居然都是意外死亡叮称,警方通過查閱死者的電腦和手機,發(fā)現死者居然都...
    沈念sama閱讀 90,729評論 3 385
  • 文/潘曉璐 我一進店門藐鹤,熙熙樓的掌柜王于貴愁眉苦臉地迎上來瓤檐,“玉大人,你說我怎么就攤上這事娱节∧域龋” “怎么了?”我有些...
    開封第一講書人閱讀 158,300評論 0 348
  • 文/不壞的土叔 我叫張陵肄满,是天一觀的道長谴古。 經常有香客問我,道長稠歉,這世上最難降的妖魔是什么掰担? 我笑而不...
    開封第一講書人閱讀 56,780評論 1 285
  • 正文 為了忘掉前任,我火速辦了婚禮怒炸,結果婚禮上带饱,老公的妹妹穿的比我還像新娘。我一直安慰自己阅羹,他們只是感情好勺疼,可當我...
    茶點故事閱讀 65,890評論 6 385
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著捏鱼,像睡著了一般执庐。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發(fā)上导梆,一...
    開封第一講書人閱讀 50,084評論 1 291
  • 那天轨淌,我揣著相機與錄音迂烁,去河邊找鬼。 笑死猿诸,一個胖子當著我的面吹牛婚被,可吹牛的內容都是我干的。 我是一名探鬼主播梳虽,決...
    沈念sama閱讀 39,151評論 3 410
  • 文/蒼蘭香墨 我猛地睜開眼址芯,長吁一口氣:“原來是場噩夢啊……” “哼!你這毒婦竟也來了窜觉?” 一聲冷哼從身側響起谷炸,我...
    開封第一講書人閱讀 37,912評論 0 268
  • 序言:老撾萬榮一對情侶失蹤,失蹤者是張志新(化名)和其女友劉穎禀挫,沒想到半個月后旬陡,有當地人在樹林里發(fā)現了一具尸體,經...
    沈念sama閱讀 44,355評論 1 303
  • 正文 獨居荒郊野嶺守林人離奇死亡语婴,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內容為張勛視角 年9月15日...
    茶點故事閱讀 36,666評論 2 327
  • 正文 我和宋清朗相戀三年描孟,在試婚紗的時候發(fā)現自己被綠了。 大學時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片砰左。...
    茶點故事閱讀 38,809評論 1 341
  • 序言:一個原本活蹦亂跳的男人離奇死亡匿醒,死狀恐怖,靈堂內的尸體忽然破棺而出缠导,到底是詐尸還是另有隱情廉羔,我是刑警寧澤,帶...
    沈念sama閱讀 34,504評論 4 334
  • 正文 年R本政府宣布僻造,位于F島的核電站憋他,受9級特大地震影響,放射性物質發(fā)生泄漏髓削。R本人自食惡果不足惜竹挡,卻給世界環(huán)境...
    茶點故事閱讀 40,150評論 3 317
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望立膛。 院中可真熱鬧此迅,春花似錦、人聲如沸旧巾。這莊子的主人今日做“春日...
    開封第一講書人閱讀 30,882評論 0 21
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽鲁猩。三九已至坎怪,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間廓握,已是汗流浹背搅窿。 一陣腳步聲響...
    開封第一講書人閱讀 32,121評論 1 267
  • 我被黑心中介騙來泰國打工嘁酿, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留,地道東北人男应。 一個月前我還...
    沈念sama閱讀 46,628評論 2 362
  • 正文 我出身青樓闹司,卻偏偏與公主長得像,于是被迫代替她去往敵國和親沐飘。 傳聞我的和親對象是個殘疾皇子游桩,可洞房花燭夜當晚...
    茶點故事閱讀 43,724評論 2 351

推薦閱讀更多精彩內容