基于tensorflow搭建一個(gè)簡(jiǎn)單的CNN模型(code)


我們將要搭建一個(gè)簡(jiǎn)單的卷積神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu)去提高手寫(xiě)數(shù)字的預(yù)測(cè)結(jié)果精度。

# Introductory CNN Model: MNIST Digits

# In this example, we will download the MNIST?handwritten

# digits and create a simple CNN network to predict the

# digit category (0-9)

主要分為以下幾個(gè)步驟:導(dǎo)入數(shù)據(jù)潦匈;創(chuàng)建模型的變量玩裙;搭建模型惫叛;采用批量化訓(xùn)練網(wǎng)絡(luò)昌阿;可視化loss殖告,accuracy等結(jié)果烘挫。


1.導(dǎo)入必要的庫(kù)和開(kāi)始一個(gè)圖譜會(huì)話

import tensorflow as tf

import numpy as np

import matplotlib as plt

from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets

sess = tf.Session()

2.導(dǎo)入數(shù)據(jù)集和將圖片裝換為28 * 28大小的矩陣

data_dir = 'temp'?? #數(shù)據(jù)集存放的文件夾

mnist = read_data_sets(data_dir)????? #讀取數(shù)據(jù)集

#將訓(xùn)練和測(cè)試數(shù)據(jù)集圖片歸一化為28*28大小

train_xdata = np.array([np.reshape(x, (28,28)) for x in mnist.train.images])

test_xdata = np.array([np.reshape(x, (28,28)) for x in mnist.test.images])

train_labels = mnist.train.labels??? #訓(xùn)練數(shù)據(jù)集標(biāo)簽

test_labels = mnist.test.labels?????? #測(cè)試數(shù)據(jù)集標(biāo)簽

3.定義模型參數(shù)

batch_size = 100?????? #一個(gè)批量的圖片數(shù)量

learning_rate = 0.005?????????? #學(xué)習(xí)率

evaluation_size = 500????????? #模型驗(yàn)證數(shù)據(jù)集一個(gè)批量的數(shù)量

image_width = train_xdata[0].shape[0]????? #圖片的長(zhǎng) 28

image_height = train_xdata[0].shape[1]???? #圖片的寬 28

target_size = max(train_labels)+1??? #輸出類(lèi)別的個(gè)數(shù) 10

num_channels = 1???????????????????????????? # 通道數(shù)為1

generations = 500????????????????????????????? #迭代代數(shù)

eval_every = 5 ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?#每次5個(gè)generation

conv1_features = 25????????????????????????? #卷積核的個(gè)數(shù)

conv2_features = 50????????????????????????? #卷積核的個(gè)數(shù)

max_pool_size1 = 2????????????????????????? #池化層窗口大小

max_pool_size2 = 2????????????????????????? #池化層窗口大小

fully_connected_size1 = 100??????????? #全連接層大小

4.定義數(shù)據(jù)集的占位符

#輸入數(shù)據(jù)的張量大小

x_input_shape = (batch_size, image_width, image_height, num_channels)

#創(chuàng)建輸入訓(xùn)練數(shù)據(jù)的占位符

x_input = tf.placeholder(tf.float32, shape=x_input_shape)? ? ? ? ?

#創(chuàng)建一個(gè)批量訓(xùn)練結(jié)果的占位符

y_target = tf.placeholder(tf.int32, shape=batch_size)? ? ? ? ?

#驗(yàn)證圖片輸入張量

eval_input_shape = (evaluation_size, image_width, image_height,num_channels)

#創(chuàng)建輸入驗(yàn)證數(shù)據(jù)的占位符

eval_input = tf.placeholder(tf.float32, shape=eval_input_shape)? ??

#創(chuàng)建一個(gè)批量驗(yàn)證結(jié)果的占位符

eval_target = tf.placeholder(tf.int32, shape= evaluation_size )??

5.定義訓(xùn)練權(quán)重和偏置的變量

#定義第一個(gè)卷積核的參數(shù)诀艰,其中用tf.truncated_normal生成正太分布的數(shù)據(jù),#stddev(正態(tài)分布標(biāo)準(zhǔn)差)為0.1

conv1_weight = tf.Variable(tf.truncated_normal([4, 4, num_channels, conv1_features], stddev=0.1, dtype = tf.float32))

#定義第一個(gè)卷積核對(duì)應(yīng)的偏置

conv1_bias = tf.Variable(tf.zeros([conv1_features], dtype=tf.float32))

#定義第二個(gè)卷積核的參數(shù)饮六,其中用tf.truncated_normal生成正太分布的數(shù)據(jù)其垄,#stddev(正態(tài)分布標(biāo)準(zhǔn)差)為0.1

conv2_weight = tf.Variable(tf.truncated_normal([4, 4, num_channels, conv2_features], stddev=0.1, dtype = tf.float32))

#定義第二個(gè)卷積核對(duì)應(yīng)的偏置

conv2_bias = tf.Variable(tf.zeros([conv2_features], dtype=tf.float32))

6.定義全連接層的權(quán)重和偏置

#輸出卷積特征圖的大小

resulting_width = image_width // (max_pool_size1 * max_pool_size2)

resulting_height = image_height // (max_pool_size1 * max_pool_size2)

#將卷積層特征圖拉成一維向量

full1_input_size = resulting_width * resulting_height * conv2_features

#創(chuàng)建第一個(gè)全連接層權(quán)重和偏置

full1_weight =tf.Variable(tf.truncated_normal([full1_input_size,fully_connected_size1],

?????????????????????????????????????????????? stddev=0.1, dtype=tf.float32))

full1_bias = tf.Variable(tf.truncated_normal([fully_connected_size1], stddev=0.1,

?????????????????????????????????????????????? dtype=tf.float32))

#創(chuàng)建第二個(gè)全連接層權(quán)重和偏置

full2_weight = tf.Variable(tf.truncated_normal([fully_connected_size1,target_size],

?????????????????????????????????????????????? stddev=0.1, dtype=tf.float32))

full2_bias = tf.Variable(tf.truncated_normal([target_size], stddev=0.1,

?????????????????????????????????????????????? dtype=tf.float32))

7.定義網(wǎng)絡(luò)模型

def my_conv_net(input_data):

??? #First Conv-relu-maxpool layer

??? conv1 = tf.nn.conv2d(input_data, conv1_weight, strides=[1, 1, 1, 1], padding='SAME')

??? relu1 = tf.nn.relu(tf.nn.bias_add(conv1, conv1_bias))

??? max_pool1 = tf.nn.max_pool(relu1, ksize=[1, max_pool_size1, max_pool_size1, 1],? strides=[1, max_pool_size1, max_pool_size1, 1], padding='SAME')

??? # Second Conv-relu-maxpool layer

??? conv2 = tf.nn.conv2d(max_pool1, conv2_weight, strides=[1, 1, 1, 1], padding='SAME')

??? relu2 = tf.nn.relu(tf.nn.bias_add(conv2, conv2_bias))

??? max_pool2 = tf.nn.max_pool(relu2, ksize=[1, max_pool_size1, max_pool_size1, 1],? strides=[1, max_pool_size2, max_pool_size2, 1], padding='SAME')

??? #將輸出轉(zhuǎn)換為一個(gè)[1xN],為下一個(gè)全連接層輸入做準(zhǔn)備

??? final_conv_shape = max_pool2.get_shape().as_list()

??? final_shape = final_conv_shape[1] * final_conv_shape[2] * final_conv_shape[3]

??? flat_output = tf.reshape(max_pool2, [final_conv_shape[0], final_shape])

??? #First fully-connected layer

??? fully_connected1 = tf.nn.relu(tf.add(tf.add(tf.matmul(flat_output,full1_weight), full1_bias)))

?? ?# Second fully-connected layer

??? final_model_output = tf.add(tf.matmul(fully_connected1, full2_weight), full2_bias)

??? return (final_model_output)

8.定義網(wǎng)絡(luò)的訓(xùn)練數(shù)據(jù)和測(cè)數(shù)據(jù)

model_output = my_conv_net(x_input)

test_model_output = my_conv_net(eval_input)

9.使用Softmax函數(shù)作為loss function

loss = loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=model_output, labels=y_target))

10.接下來(lái)創(chuàng)建一個(gè)訓(xùn)練和測(cè)試的函數(shù)

prediction = tf.nn.softmax(model_output)

test_prediction = tf.nn.softmax(test_model_output)

# Create accuracy function

def get_accuracy(logits, targets):

? ? batch_predictions = np.argmax(logits, axis=1)

? ? num_correct = np.sum(np.equal(batch_predictions, targets))

return(100. * num_correct/batch_predictions.shape[0])

11.創(chuàng)建一個(gè)optimizer function

my_optimizer = tf.train.MomentumOptimizer(learning_rate, 0.9)

train_step = my_optimizer.minimize(loss)

# Initialize Variables

init = tf.initialize_all_variables()

sess.run(init)

12.開(kāi)始訓(xùn)練模型

train_loss = [ ]

train_acc = [ ]

test_acc = [ ]

for i in range(generations):

? ? rand_index = np.random.choice(len(train_xdata), size=batch_size)

? ? rand_x = train_xdata[rand_index]

? ? rand_x = np.expand_dims(rand_x, 3)

? ? rand_y = train_labels[rand_index]

? ? train_dict = {x_input: rand_x, y_target: rand_y}

? ? sess.run(train_step, feed_dict=train_dict)

? ? temp_train_loss, temp_train_preds = sess.run([loss,?prediction], feed_dict=train_dict)

? ? temp_train_acc = get_accuracy(temp_train_preds, rand_y)

? ? if (i+1) % eval_every == 0:

? ? ? ? eval_index = np.random.choice(len(test_xdata),?size=evaluation_size)

? ? ? ? eval_x = test_xdata[eval_index]

? ? ? ? eval_x = np.expand_dims(eval_x, 3)

? ? ? ? eval_y = test_labels[eval_index]

? ? ? ? test_dict = {eval_input: eval_x, eval_target: eval_y}

? ? ? ? test_preds = sess.run(test_prediction, feed_dict=test_dict)

? ? ? ? temp_test_acc = get_accuracy(test_preds, eval_y)

? ? ? ? # Record and print results

? ? ? ? train_loss.append(temp_train_loss)

? ? ? ? train_acc.append(temp_train_acc)

? ? ? ? test_acc.append(temp_test_acc)

? ? ? ? acc_and_loss = [(i+1), temp_train_loss, temp_train_acc,?temp_test_acc]

? ? ? ? acc_and_loss = [np.round(x,2) for x in acc_and_loss]

13.輸出結(jié)果

print('Generation # {}. Train Loss: {:.2f}. Train Acc (Test Acc):?{:.2f} ({:.2f})'.format(*acc_and_loss))

14.使用matplotlib顯示loss-accuracies曲線

eval_indices = range(0, generations, eval_every)

# Plot loss over time

plt.plot(eval_indices, train_loss, 'k-')

plt.title('Softmax Loss per Generation')

plt.xlabel('Generation')

plt.ylabel('Softmax Loss')

plt.show()

# Plot train and test accuracy

plt.plot(eval_indices, train_acc, 'k-', label='Train Set Accuracy')

plt.plot(eval_indices, test_acc, 'r--', label='Test Set Accuracy')

plt.title('Train and Test Accuracy')

plt.xlabel('Generation')

plt.ylabel('Accuracy')

plt.legend(loc='lower right')

plt.show()

圖1. 左圖是我們500 generations時(shí)的訓(xùn)練精度曲線。右圖是在500 generations時(shí)的softmax loss值

15.顯示最新一個(gè)批量的預(yù)測(cè)結(jié)果

# Plot the 6 of the last batch results:

actuals = rand_y[0:6]

predictions = np.argmax(temp_train_preds,axis=1)[0:6]

images = np.squeeze(rand_x[0:6])

Nrows = 2

Ncols = 3

for i in range(6):

? ? plt.subplot(Nrows, Ncols, i+1)

? ? plt.imshow(np.reshape(images[i], [28,28]), cmap='Greys_r')

? ? plt.title('Actual: ' + str(actuals[i]) + ' Pred: ' + str(predi

? ? ctions[i]),fontsize=10)

? ? frame = plt.gca()

? ? frame.axes.get_xaxis().set_visible(False)

? ? frame.axes.get_yaxis().set_visible(False)

圖2.?顯示出6張隨機(jī)圖片的真實(shí)結(jié)果和預(yù)測(cè)結(jié)果
最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末卤橄,一起剝皮案震驚了整個(gè)濱河市绿满,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌窟扑,老刑警劉巖喇颁,帶你破解...
    沈念sama閱讀 216,470評(píng)論 6 501
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場(chǎng)離奇詭異辜膝,居然都是意外死亡无牵,警方通過(guò)查閱死者的電腦和手機(jī),發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 92,393評(píng)論 3 392
  • 文/潘曉璐 我一進(jìn)店門(mén)厂抖,熙熙樓的掌柜王于貴愁眉苦臉地迎上來(lái)茎毁,“玉大人,你說(shuō)我怎么就攤上這事忱辅∑咧” “怎么了?”我有些...
    開(kāi)封第一講書(shū)人閱讀 162,577評(píng)論 0 353
  • 文/不壞的土叔 我叫張陵墙懂,是天一觀的道長(zhǎng)橡卤。 經(jīng)常有香客問(wèn)我,道長(zhǎng)损搬,這世上最難降的妖魔是什么碧库? 我笑而不...
    開(kāi)封第一講書(shū)人閱讀 58,176評(píng)論 1 292
  • 正文 為了忘掉前任,我火速辦了婚禮巧勤,結(jié)果婚禮上嵌灰,老公的妹妹穿的比我還像新娘。我一直安慰自己颅悉,他們只是感情好沽瞭,可當(dāng)我...
    茶點(diǎn)故事閱讀 67,189評(píng)論 6 388
  • 文/花漫 我一把揭開(kāi)白布。 她就那樣靜靜地躺著剩瓶,像睡著了一般驹溃。 火紅的嫁衣襯著肌膚如雪城丧。 梳的紋絲不亂的頭發(fā)上,一...
    開(kāi)封第一講書(shū)人閱讀 51,155評(píng)論 1 299
  • 那天豌鹤,我揣著相機(jī)與錄音亡哄,去河邊找鬼。 笑死傍药,一個(gè)胖子當(dāng)著我的面吹牛磺平,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播拐辽,決...
    沈念sama閱讀 40,041評(píng)論 3 418
  • 文/蒼蘭香墨 我猛地睜開(kāi)眼拣挪,長(zhǎng)吁一口氣:“原來(lái)是場(chǎng)噩夢(mèng)啊……” “哼!你這毒婦竟也來(lái)了俱诸?” 一聲冷哼從身側(cè)響起菠劝,我...
    開(kāi)封第一講書(shū)人閱讀 38,903評(píng)論 0 274
  • 序言:老撾萬(wàn)榮一對(duì)情侶失蹤,失蹤者是張志新(化名)和其女友劉穎睁搭,沒(méi)想到半個(gè)月后赶诊,有當(dāng)?shù)厝嗽跇?shù)林里發(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 45,319評(píng)論 1 310
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡园骆,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 37,539評(píng)論 2 332
  • 正文 我和宋清朗相戀三年舔痪,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片锌唾。...
    茶點(diǎn)故事閱讀 39,703評(píng)論 1 348
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡锄码,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出晌涕,到底是詐尸還是另有隱情滋捶,我是刑警寧澤,帶...
    沈念sama閱讀 35,417評(píng)論 5 343
  • 正文 年R本政府宣布余黎,位于F島的核電站重窟,受9級(jí)特大地震影響,放射性物質(zhì)發(fā)生泄漏惧财。R本人自食惡果不足惜巡扇,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,013評(píng)論 3 325
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望垮衷。 院中可真熱鬧厅翔,春花似錦、人聲如沸帘靡。這莊子的主人今日做“春日...
    開(kāi)封第一講書(shū)人閱讀 31,664評(píng)論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽(yáng)描姚。三九已至涩赢,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間轩勘,已是汗流浹背筒扒。 一陣腳步聲響...
    開(kāi)封第一講書(shū)人閱讀 32,818評(píng)論 1 269
  • 我被黑心中介騙來(lái)泰國(guó)打工, 沒(méi)想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留绊寻,地道東北人花墩。 一個(gè)月前我還...
    沈念sama閱讀 47,711評(píng)論 2 368
  • 正文 我出身青樓,卻偏偏與公主長(zhǎng)得像澄步,于是被迫代替她去往敵國(guó)和親冰蘑。 傳聞我的和親對(duì)象是個(gè)殘疾皇子,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 44,601評(píng)論 2 353

推薦閱讀更多精彩內(nèi)容