Author: Tyan
Blog: noahsnail.com | CSDN | 简书
This post analyzes and visualizes the structure of the LeNet model trained in Caffe on the MNIST dataset.
import caffe
import numpy as np
import matplotlib.pyplot as plt
# Define the LeNet model files
deploy = 'lenet.prototxt'
model = 'lenet_iter_10000.caffemodel'
# Load the model
net = caffe.Net(deploy, model, caffe.TEST)
# Compute the mean (kept commented out; no mean subtraction is used here)
# blob = caffe.proto.caffe_pb2.BlobProto()
# bin_mean = open(mean_file, 'rb' ).read()
# blob.ParseFromString(bin_mean)
# arr = np.array(caffe.io.blobproto_to_array(blob))
# npy_mean = arr[0]
# mu = npy_mean.mean(1).mean(1)
# init transformer
transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape})
transformer.set_transpose('data', (2, 0, 1))
# transformer.set_mean('data', mu)
transformer.set_raw_scale('data', 255)
# transformer.set_channel_swap('data', (2, 1, 0))
# Run a forward pass on an image and return the output blob of the given layer
def init(pimg, lay_name):
    global transformer
    global net
    # load the image as single-channel (grayscale)
    image = caffe.io.load_image(pimg, color = False)
    # preprocess and copy it into the data blob
    transformed_image = transformer.preprocess('data', image)
    net.blobs['data'].data[...] = transformed_image
    # forward pass; net.forward() returns the network's output blobs
    output = net.forward()
    result = output[lay_name]
    return result
# Test
result = init('test.jpg', 'prob')
print result.shape
print result
(1, 10)
[[ 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]
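Since prob is a 10-way softmax over the digit classes, the predicted digit is simply the index of the largest entry; a minimal example:
# The predicted digit is the index of the largest probability
print result.argmax()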
All layers in the LeNet network and the output data of each layer
data: input image, size 28*28
conv1: 20 convolution kernels; the feature maps after convolution are 24*24
pool1: after pooling the feature maps are 12*12, 20 channels
conv2: 50 convolution kernels; the feature maps after convolution are 8*8
pool2: after pooling the feature maps are 4*4, 50 channels
ip1: first fully connected layer, 500 nodes
ip2: second fully connected layer, 10 nodes
prob: softmax over ip2
Note: conv1 produces 20 feature maps. Each of conv2's 50 kernels convolves all 20 of these feature maps, and the 20 per-channel results are summed element-wise (plus the bias) to give that kernel's single output feature map. Therefore conv2 produces 50 feature maps, not 50*20. (A small numerical check follows below.)
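A minimal check of this, assuming the forward pass above has already been run. Caffe convolutions are cross-correlations with no padding and stride 1 here, so the output of kernel 0 at position (0, 0) should equal the 5*5 window of every pool1 channel multiplied by that channel's kernel, summed over all 20 channels, plus the bias:
# Manually reproduce one conv2 output value (kernel 0, position (0, 0))
w2 = net.params['conv2'][0].data        # (50, 20, 5, 5)
b2 = net.params['conv2'][1].data        # (50,)
pool1_out = net.blobs['pool1'].data[0]  # (20, 12, 12)
manual = np.sum(pool1_out[:, 0:5, 0:5] * w2[0]) + b2[0]
print manual, net.blobs['conv2'].data[0, 0, 0, 0]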
# all layer name and blob shape
# blob shape is (batch_size, channel_dim, height, width).
for layer_name, blob in net.blobs.iteritems():
    print layer_name + '\t' + str(blob.data.shape)
data (1, 1, 28, 28)
conv1 (1, 20, 24, 24)
pool1 (1, 20, 12, 12)
conv2 (1, 50, 8, 8)
pool2 (1, 50, 4, 4)
ip1 (1, 500)
ip2 (1, 10)
prob (1, 10)
The LeNet network's parameters (weights + biases)
conv1: 20 convolution kernels; each kernel has 5*5 weights (single input channel); 20 biases
conv2: 50 convolution kernels; each kernel has a 5*5 weight matrix per input channel (20 channels); 50 biases
ip1: after pool2 there are 50 feature maps of size 4*4, flattened into 800 values and fully connected to ip1's 500 nodes; 500*800 weights and 500 biases
ip2: ip1's 500 nodes are fully connected to ip2's 10 nodes; 10*500 weights and 10 biases
# all layer name and parameters shape
# param[0] is weights, param[1] is biases
# weights shape is (output_channels, input_channels, filter_height, filter_width)
# biases shape is (output_channels,)
for layer_name, param in net.params.iteritems():
    print layer_name + '\t' + str(param[0].data.shape) + '\t' + str(param[1].data.shape)
conv1 (20, 1, 5, 5) (20,)
conv2 (50, 20, 5, 5) (50,)
ip1 (500, 800) (500,)
ip2 (10, 500) (10,)
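Given these shapes, the number of learnable parameters per layer is just the size of the weight blob plus the size of the bias blob; a minimal tally using the already-loaded net:
# Count parameters (weights + biases) per layer and in total
total = 0
for layer_name, param in net.params.iteritems():
    count = param[0].data.size + param[1].data.size
    total += count
    print layer_name + '\t' + str(count)
print 'total\t' + str(total)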
numpy pad
The padding tuple has four parts (a small example follows after this list):
Part 1: (0, n ** 2 - data.shape[0]) pads the filter axis: nothing in front, n ** 2 - data.shape[0] blank filters at the end, so the number of filters fills an n*n grid
Part 2: (0, 1) pads nothing before and one entry after along the height axis, i.e. one extra row at the bottom of each filter
Part 3: (0, 1) pads nothing before and one entry after along the width axis, i.e. one extra column at the right of each filter
Part 4: (0, 0) repeated for any remaining dimensions, which are left unpadded
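A small standalone example of such a padding tuple, using a hypothetical batch of 3 filters of size 2*2 (so n = 2, and one blank filter, one row, and one column are added):
# Hypothetical example: 3 filters of size 2*2
a = np.zeros((3, 2, 2))
n = int(np.ceil(np.sqrt(a.shape[0])))                    # n = 2
padding = ((0, n ** 2 - a.shape[0]), (0, 1), (0, 1))     # ((0, 1), (0, 1), (0, 1))
print np.pad(a, padding, mode = 'constant', constant_values = 1).shape   # (4, 3, 3)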
# tile a stack of 2D maps (filters or feature maps) into one image for display
def visualization(data):
    # normalize data for display
    data = (data - data.min()) / (data.max() - data.min())
    # force the number of filters to be square
    n = int(np.ceil(np.sqrt(data.shape[0])))
    # add some space between filters
    padding = (((0, n ** 2 - data.shape[0]), (0, 1), (0, 1)) + ((0, 0),) * (data.ndim - 3))
    data = np.pad(data, padding, mode = 'constant', constant_values = 1)
    # tile the filters into an image
    data = data.reshape((n, n) + data.shape[1:]).transpose((0, 2, 1, 3) + tuple(range(4, data.ndim + 1)))
    data = data.reshape((n * data.shape[1], n * data.shape[3]) + data.shape[4:])
    plt.imshow(data, cmap='gray')
    plt.axis('off')
    plt.show()
# feature map visualization
feature_map = net.blobs['conv1'].data[0]
visualization(feature_map)
# filter visualization
filters = net.params['conv1'][0].data
visualization(filters.reshape(20, 5, 5))
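The conv2 filters can be tiled with the same function; since their shape is (50, 20, 5, 5), one simple option (a sketch that flattens the input-channel axis) is to show all 50*20 single-channel 5*5 kernels:
# conv2 filter visualization (all 50*20 single-channel kernels tiled together)
filters2 = net.params['conv2'][0].data
visualization(filters2.reshape(50 * 20, 5, 5))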