概述
tf.nn.conv2d是TensorFlow里面實(shí)現(xiàn)卷積的函數(shù),參考文檔對(duì)它的介紹并不是很詳細(xì),實(shí)際上這是搭建卷積神經(jīng)網(wǎng)絡(luò)比較核心的一個(gè)方法烙如,非常重要么抗。
說明
tf.nn.conv2d(input, filter, strides, padding, use_cudnn_on_gpu=None, name=None)
參數(shù)
- input:指需要做卷積的輸入圖像,它要求是一個(gè)Tensor亚铁,具有[batch, in_height, in_width, in_channels]這樣的shape蝇刀,具體含義是[訓(xùn)練時(shí)一個(gè)batch的圖片數(shù)量, 圖片高度, 圖片寬度, 圖像通道數(shù)],注意這是一個(gè)4維的Tensor徘溢,要求類型為float32和float64其中之一
- filter:相當(dāng)于CNN中的卷積核吞琐,它要求是一個(gè)Tensor,具有[filter_height, filter_width, in_channels, out_channels]這樣的shape然爆,具體含義是[卷積核的高度站粟,卷積核的寬度,圖像通道數(shù)曾雕,卷積核個(gè)數(shù)]奴烙,要求類型與參數(shù)input相同,有一個(gè)地方需要注意剖张,第三維in_channels缸沃,就是參數(shù)input的第四維
- strides:卷積時(shí)在圖像每一維的步長,這是一個(gè)一維的向量修械,長度4
- padding:string類型的量,只能是"SAME","VALID"其中之一检盼,這個(gè)值決定了不同的卷積方法肯污,當(dāng)其為‘SAME’時(shí),表示卷積核可以停留在圖像邊緣吨枉。
- use_cudnn_on_gpu:bool類型蹦渣,是否使用cudnn加速,默認(rèn)為true
- name:指定該操作的name
返回
結(jié)果返回一個(gè)Tensor貌亭,這個(gè)輸出柬唯,就是我們常說的feature map
實(shí)例
1.考慮一種最簡單的情況,現(xiàn)在有一張3×3單通道的圖像(對(duì)應(yīng)的shape:[1圃庭,3锄奢,3,1])剧腻,用一個(gè)1×1的卷積核(對(duì)應(yīng)的shape:[1拘央,1,1书在,1])去做卷積灰伟,最后會(huì)得到一張3×3的feature map。輸出:[1,3, 3, 1]
input_arg = tf.Variable(tf.ones([1, 3, 3, 1]))
filter_arg = tf.Variable(tf.ones([1, 1, 1, 1]))
op = tf.nn.conv2d(input_arg, filter_arg, strides=[1, 3, 3, 1],
use_cudnn_on_gpu=False, padding='VALID')
--------------case1--------------
[[[[ 1.]
[ 1.]
[ 1.]]
[[ 1.]
[ 1.]
[ 1.]]
[[ 1.]
[ 1.]
[ 1.]]]]
2.增加圖片的通道數(shù)儒旬,使用一張3×3五通道的圖像(對(duì)應(yīng)的shape:[1栏账,3帖族,3,5])挡爵,用一個(gè)1×1的卷積核(對(duì)應(yīng)的shape:[1竖般,1,1了讨,1])去做卷積捻激,仍然是一張3×3的feature map,這就相當(dāng)于每一個(gè)像素點(diǎn)前计,卷積核都與該像素點(diǎn)的每一個(gè)通道做點(diǎn)積胞谭。輸出:[1, 3, 3, 1]
input_arg = tf.Variable(tf.ones([1, 3, 3, 5]))
filter_arg = tf.Variable(tf.ones([1, 1, 5, 1]))
op = tf.nn.conv2d(input_arg, filter_arg, strides=[1, 1, 1, 1],
use_cudnn_on_gpu=False, padding='VALID')
--------------case2--------------
[[[[ 5.]
[ 5.]
[ 5.]]
[[ 5.]
[ 5.]
[ 5.]]
[[ 5.]
[ 5.]
[ 5.]]]]
3.把卷積核擴(kuò)大,現(xiàn)在用3×3的卷積核做卷積男杈,最后的輸出是一個(gè)值丈屹,相當(dāng)于情況2的feature map所有像素點(diǎn)的值求和。輸出:[1, 1, 1, 1]
input_arg = tf.Variable(tf.ones([1, 3, 3, 5]))
filter_arg = tf.Variable(tf.ones([3, 3, 5, 1]))
op = tf.nn.conv2d(input_arg, filter_arg, strides=[1, 1, 1, 1],
use_cudnn_on_gpu=False, padding='VALID')
--------------case3--------------
[[[[ 45.]]]]
4.使用更大的圖片將情況2的圖片擴(kuò)大到5×5伶棒,仍然是3×3的卷積核旺垒,令步長為1,輸出3×3的feature map肤无。
注意我們可以把這種情況看成情況2和情況3的中間狀態(tài)先蒋,卷積核以步長1滑動(dòng)遍歷全圖,以下x表示的位置宛渐,表示卷積核停留的位置竞漾,每停留一個(gè),輸出feature map的一個(gè)像素窥翩。輸出:[1, 3, 3, 1]
.....
.xxx.
.xxx.
.xxx.
.....
input_arg = tf.Variable(tf.ones([1, 5, 5, 5]))
filter_arg = tf.Variable(tf.ones([3, 3, 5, 1]))
op = tf.nn.conv2d(input_arg, filter_arg, strides=[1, 1, 1, 1],
use_cudnn_on_gpu=False, padding='VALID')
--------------case4--------------
[[[[ 45.]
[ 45.]
[ 45.]]
[[ 45.]
[ 45.]
[ 45.]]
[[ 45.]
[ 45.]
[ 45.]]]]
5.上面我們一直令參數(shù)padding的值為‘VALID’业岁,當(dāng)其為‘SAME’時(shí),表示卷積核可以停留在圖像邊緣寇蚊,輸出:[1, 5, 5, 1]
xxxxx
xxxxx
xxxxx
xxxxx
xxxxx
input_arg = tf.Variable(tf.ones([1, 5, 5, 5]))
filter_arg = tf.Variable(tf.ones([3, 3, 5, 1]))
op = tf.nn.conv2d(input_arg, filter_arg, strides=[1, 1, 1, 1],
use_cudnn_on_gpu=False, padding='SAME')
--------------case5--------------
[[[[ 20.]
[ 30.]
[ 30.]
[ 30.]
[ 20.]]
[[ 30.]
[ 45.]
[ 45.]
[ 45.]
[ 30.]]
[[ 30.]
[ 45.]
[ 45.]
[ 45.]
[ 30.]]
[[ 30.]
[ 45.]
[ 45.]
[ 45.]
[ 30.]]
[[ 20.]
[ 30.]
[ 30.]
[ 30.]
[ 20.]]]]
6.如果卷積核有多個(gè)笔时,此時(shí)輸出7張5×5的feature map。輸出:[1, 5, 5, 7]
input_arg = tf.Variable(tf.ones([1, 5, 5, 5]))
filter_arg = tf.Variable(tf.ones([3, 3, 5, 7]))
op = tf.nn.conv2d(input_arg, filter_arg, strides=[1, 1, 1, 1],
use_cudnn_on_gpu=False, padding='SAME')
oplist.append([op, 'case6'])
--------------case6--------------
[[[[ 20. 20. 20. ..., 20. 20. 20.]
[ 30. 30. 30. ..., 30. 30. 30.]
[ 30. 30. 30. ..., 30. 30. 30.]
[ 30. 30. 30. ..., 30. 30. 30.]
[ 20. 20. 20. ..., 20. 20. 20.]]
[[ 30. 30. 30. ..., 30. 30. 30.]
[ 45. 45. 45. ..., 45. 45. 45.]
[ 45. 45. 45. ..., 45. 45. 45.]
[ 45. 45. 45. ..., 45. 45. 45.]
[ 30. 30. 30. ..., 30. 30. 30.]]
[[ 30. 30. 30. ..., 30. 30. 30.]
[ 45. 45. 45. ..., 45. 45. 45.]
[ 45. 45. 45. ..., 45. 45. 45.]
[ 45. 45. 45. ..., 45. 45. 45.]
[ 30. 30. 30. ..., 30. 30. 30.]]
[[ 30. 30. 30. ..., 30. 30. 30.]
[ 45. 45. 45. ..., 45. 45. 45.]
[ 45. 45. 45. ..., 45. 45. 45.]
[ 45. 45. 45. ..., 45. 45. 45.]
[ 30. 30. 30. ..., 30. 30. 30.]]
[[ 20. 20. 20. ..., 20. 20. 20.]
[ 30. 30. 30. ..., 30. 30. 30.]
[ 30. 30. 30. ..., 30. 30. 30.]
[ 30. 30. 30. ..., 30. 30. 30.]
[ 20. 20. 20. ..., 20. 20. 20.]]]]
7.步長不為1的情況仗岸,文檔里說了對(duì)于圖片允耿,因?yàn)橹挥袃删S,通常strides取[1扒怖,stride右犹,stride,1]姚垃。輸出:[1, 3, 3, 7]
x.x.x
.....
x.x.x
.....
x.x.x
input_arg = tf.Variable(tf.ones([1, 5, 5, 5]))
filter_arg = tf.Variable(tf.ones([3, 3, 5, 7]))
op = tf.nn.conv2d(input_arg, filter_arg, strides=[1, 2, 2, 1],
use_cudnn_on_gpu=False, padding='SAME')
oplist.append([op, 'case7'])
--------------case7--------------
[[[[ 20. 20. 20. ..., 20. 20. 20.]
[ 30. 30. 30. ..., 30. 30. 30.]
[ 20. 20. 20. ..., 20. 20. 20.]]
[[ 30. 30. 30. ..., 30. 30. 30.]
[ 45. 45. 45. ..., 45. 45. 45.]
[ 30. 30. 30. ..., 30. 30. 30.]]
[[ 20. 20. 20. ..., 20. 20. 20.]
[ 30. 30. 30. ..., 30. 30. 30.]
[ 20. 20. 20. ..., 20. 20. 20.]]]]
8.如果batch值不為1念链,同時(shí)輸入4張圖,輸出的每張圖,都有7張3×3的feature map掂墓。輸出:[4, 3, 3, 7]
input_arg = tf.Variable(tf.ones([4, 5, 5, 5]))
filter_arg = tf.Variable(tf.ones([3, 3, 5, 7]))
op = tf.nn.conv2d(input_arg, filter_arg, strides=[1, 2, 2, 1],
use_cudnn_on_gpu=False, padding='SAME')
--------------case8--------------
[[[[ 20. 20. 20. ..., 20. 20. 20.]
[ 30. 30. 30. ..., 30. 30. 30.]
[ 20. 20. 20. ..., 20. 20. 20.]]
[[ 30. 30. 30. ..., 30. 30. 30.]
[ 45. 45. 45. ..., 45. 45. 45.]
[ 30. 30. 30. ..., 30. 30. 30.]]
[[ 20. 20. 20. ..., 20. 20. 20.]
[ 30. 30. 30. ..., 30. 30. 30.]
[ 20. 20. 20. ..., 20. 20. 20.]]]
[[[ 20. 20. 20. ..., 20. 20. 20.]
[ 30. 30. 30. ..., 30. 30. 30.]
[ 20. 20. 20. ..., 20. 20. 20.]]
[[ 30. 30. 30. ..., 30. 30. 30.]
[ 45. 45. 45. ..., 45. 45. 45.]
[ 30. 30. 30. ..., 30. 30. 30.]]
[[ 20. 20. 20. ..., 20. 20. 20.]
[ 30. 30. 30. ..., 30. 30. 30.]
[ 20. 20. 20. ..., 20. 20. 20.]]]
[[[ 20. 20. 20. ..., 20. 20. 20.]
[ 30. 30. 30. ..., 30. 30. 30.]
[ 20. 20. 20. ..., 20. 20. 20.]]
[[ 30. 30. 30. ..., 30. 30. 30.]
[ 45. 45. 45. ..., 45. 45. 45.]
[ 30. 30. 30. ..., 30. 30. 30.]]
[[ 20. 20. 20. ..., 20. 20. 20.]
[ 30. 30. 30. ..., 30. 30. 30.]
[ 20. 20. 20. ..., 20. 20. 20.]]]
[[[ 20. 20. 20. ..., 20. 20. 20.]
[ 30. 30. 30. ..., 30. 30. 30.]
[ 20. 20. 20. ..., 20. 20. 20.]]
[[ 30. 30. 30. ..., 30. 30. 30.]
[ 45. 45. 45. ..., 45. 45. 45.]
[ 30. 30. 30. ..., 30. 30. 30.]]
[[ 20. 20. 20. ..., 20. 20. 20.]
[ 30. 30. 30. ..., 30. 30. 30.]
[ 20. 20. 20. ..., 20. 20. 20.]]]]
代碼
import tensorflow as tf
oplist = []
# input_arg = [batch, in_height, in_width, in_channels]
# filter_arg = [filter_height, filter_width, in_channels, out_channels]
# case 1
input_arg = tf.Variable(tf.ones([1, 3, 3, 1]))
filter_arg = tf.Variable(tf.ones([1, 1, 1, 1]))
op = tf.nn.conv2d(input_arg, filter_arg, strides=[1, 1, 1, 1],
use_cudnn_on_gpu=False, padding='VALID')
oplist.append([op, 'case1'])
# case 2
input_arg = tf.Variable(tf.ones([1, 3, 3, 5]))
filter_arg = tf.Variable(tf.ones([1, 1, 5, 1]))
op = tf.nn.conv2d(input_arg, filter_arg, strides=[1, 1, 1, 1],
use_cudnn_on_gpu=False, padding='VALID')
oplist.append([op, 'case2'])
# case 3
input_arg = tf.Variable(tf.ones([1, 3, 3, 5]))
filter_arg = tf.Variable(tf.ones([3, 3, 5, 1]))
op = tf.nn.conv2d(input_arg, filter_arg, strides=[1, 1, 1, 1],
use_cudnn_on_gpu=False, padding='VALID')
oplist.append([op, 'case3'])
# case 4
input_arg = tf.Variable(tf.ones([1, 5, 5, 5]))
filter_arg = tf.Variable(tf.ones([3, 3, 5, 1]))
op = tf.nn.conv2d(input_arg, filter_arg, strides=[1, 1, 1, 1],
use_cudnn_on_gpu=False, padding='VALID')
oplist.append([op, 'case4'])
# case 5
input_arg = tf.Variable(tf.ones([1, 5, 5, 5]))
filter_arg = tf.Variable(tf.ones([3, 3, 5, 1]))
op = tf.nn.conv2d(input_arg, filter_arg, strides=[1, 1, 1, 1],
use_cudnn_on_gpu=False, padding='SAME')
oplist.append([op, 'case5'])
# case 6
input_arg = tf.Variable(tf.ones([1, 5, 5, 5]))
filter_arg = tf.Variable(tf.ones([3, 3, 5, 7]))
op = tf.nn.conv2d(input_arg, filter_arg, strides=[1, 1, 1, 1],
use_cudnn_on_gpu=False, padding='SAME')
oplist.append([op, 'case6'])
# case 7
input_arg = tf.Variable(tf.ones([1, 5, 5, 5]))
filter_arg = tf.Variable(tf.ones([3, 3, 5, 7]))
op = tf.nn.conv2d(input_arg, filter_arg, strides=[1, 2, 2, 1],
use_cudnn_on_gpu=False, padding='SAME')
oplist.append([op, 'case7'])
# case 8
input_arg = tf.Variable(tf.ones([4, 5, 5, 5]))
filter_arg = tf.Variable(tf.ones([3, 3, 5, 7]))
op = tf.nn.conv2d(input_arg, filter_arg, strides=[1, 2, 2, 1],
use_cudnn_on_gpu=False, padding='SAME')
oplist.append([op, 'case8'])
with tf.Session() as sess:
init = tf.global_variables_initializer()
sess.run(init)
for aop in oplist:
print('--------------{}--------------'.format(aop[1]))
print(sess.run(aop[0]))
print('\n')
--------------case1--------------
[[[[ 1.]
[ 1.]
[ 1.]]
[[ 1.]
[ 1.]
[ 1.]]
[[ 1.]
[ 1.]
[ 1.]]]]
--------------case2--------------
[[[[ 5.]
[ 5.]
[ 5.]]
[[ 5.]
[ 5.]
[ 5.]]
[[ 5.]
[ 5.]
[ 5.]]]]
--------------case3--------------
[[[[ 45.]]]]
--------------case4--------------
[[[[ 45.]
[ 45.]
[ 45.]]
[[ 45.]
[ 45.]
[ 45.]]
[[ 45.]
[ 45.]
[ 45.]]]]
--------------case5--------------
[[[[ 20.]
[ 30.]
[ 30.]
[ 30.]
[ 20.]]
[[ 30.]
[ 45.]
[ 45.]
[ 45.]
[ 30.]]
[[ 30.]
[ 45.]
[ 45.]
[ 45.]
[ 30.]]
[[ 30.]
[ 45.]
[ 45.]
[ 45.]
[ 30.]]
[[ 20.]
[ 30.]
[ 30.]
[ 30.]
[ 20.]]]]
--------------case6--------------
[[[[ 20. 20. 20. ..., 20. 20. 20.]
[ 30. 30. 30. ..., 30. 30. 30.]
[ 30. 30. 30. ..., 30. 30. 30.]
[ 30. 30. 30. ..., 30. 30. 30.]
[ 20. 20. 20. ..., 20. 20. 20.]]
[[ 30. 30. 30. ..., 30. 30. 30.]
[ 45. 45. 45. ..., 45. 45. 45.]
[ 45. 45. 45. ..., 45. 45. 45.]
[ 45. 45. 45. ..., 45. 45. 45.]
[ 30. 30. 30. ..., 30. 30. 30.]]
[[ 30. 30. 30. ..., 30. 30. 30.]
[ 45. 45. 45. ..., 45. 45. 45.]
[ 45. 45. 45. ..., 45. 45. 45.]
[ 45. 45. 45. ..., 45. 45. 45.]
[ 30. 30. 30. ..., 30. 30. 30.]]
[[ 30. 30. 30. ..., 30. 30. 30.]
[ 45. 45. 45. ..., 45. 45. 45.]
[ 45. 45. 45. ..., 45. 45. 45.]
[ 45. 45. 45. ..., 45. 45. 45.]
[ 30. 30. 30. ..., 30. 30. 30.]]
[[ 20. 20. 20. ..., 20. 20. 20.]
[ 30. 30. 30. ..., 30. 30. 30.]
[ 30. 30. 30. ..., 30. 30. 30.]
[ 30. 30. 30. ..., 30. 30. 30.]
[ 20. 20. 20. ..., 20. 20. 20.]]]]
--------------case7--------------
[[[[ 20. 20. 20. ..., 20. 20. 20.]
[ 30. 30. 30. ..., 30. 30. 30.]
[ 20. 20. 20. ..., 20. 20. 20.]]
[[ 30. 30. 30. ..., 30. 30. 30.]
[ 45. 45. 45. ..., 45. 45. 45.]
[ 30. 30. 30. ..., 30. 30. 30.]]
[[ 20. 20. 20. ..., 20. 20. 20.]
[ 30. 30. 30. ..., 30. 30. 30.]
[ 20. 20. 20. ..., 20. 20. 20.]]]]
--------------case8--------------
[[[[ 20. 20. 20. ..., 20. 20. 20.]
[ 30. 30. 30. ..., 30. 30. 30.]
[ 20. 20. 20. ..., 20. 20. 20.]]
[[ 30. 30. 30. ..., 30. 30. 30.]
[ 45. 45. 45. ..., 45. 45. 45.]
[ 30. 30. 30. ..., 30. 30. 30.]]
[[ 20. 20. 20. ..., 20. 20. 20.]
[ 30. 30. 30. ..., 30. 30. 30.]
[ 20. 20. 20. ..., 20. 20. 20.]]]
[[[ 20. 20. 20. ..., 20. 20. 20.]
[ 30. 30. 30. ..., 30. 30. 30.]
[ 20. 20. 20. ..., 20. 20. 20.]]
[[ 30. 30. 30. ..., 30. 30. 30.]
[ 45. 45. 45. ..., 45. 45. 45.]
[ 30. 30. 30. ..., 30. 30. 30.]]
[[ 20. 20. 20. ..., 20. 20. 20.]
[ 30. 30. 30. ..., 30. 30. 30.]
[ 20. 20. 20. ..., 20. 20. 20.]]]
[[[ 20. 20. 20. ..., 20. 20. 20.]
[ 30. 30. 30. ..., 30. 30. 30.]
[ 20. 20. 20. ..., 20. 20. 20.]]
[[ 30. 30. 30. ..., 30. 30. 30.]
[ 45. 45. 45. ..., 45. 45. 45.]
[ 30. 30. 30. ..., 30. 30. 30.]]
[[ 20. 20. 20. ..., 20. 20. 20.]
[ 30. 30. 30. ..., 30. 30. 30.]
[ 20. 20. 20. ..., 20. 20. 20.]]]
[[[ 20. 20. 20. ..., 20. 20. 20.]
[ 30. 30. 30. ..., 30. 30. 30.]
[ 20. 20. 20. ..., 20. 20. 20.]]
[[ 30. 30. 30. ..., 30. 30. 30.]
[ 45. 45. 45. ..., 45. 45. 45.]
[ 30. 30. 30. ..., 30. 30. 30.]]
[[ 20. 20. 20. ..., 20. 20. 20.]
[ 30. 30. 30. ..., 30. 30. 30.]
[ 20. 20. 20. ..., 20. 20. 20.]]]]