# 形狀 [307]
x = tf.placeholder(tf.float32, [None, 3072])
# [None]
y = tf.placeholder(tf.int64, [None])
# get_variable 表示如果已經(jīng)定義了 w 就使用定義好的 w杈湾,如果沒(méi)有定義新建一個(gè)變量 w
# (3072,1)
w = tf.get_variable('w', [x.get_shape()[-1], 1], initializer=tf.random_normal)
# (1,)
b = tf.get_variable('b', [1], initializer=tf.constant_initializer(0.0))
# y_ 將 w * x + b 現(xiàn)在
# [None,3072],[3072,1] = [None,1] (None,1)
y_ = tf.matmul(x, w) + b
# y_ 還只是一個(gè)內(nèi)積值,我們可以將其變成一個(gè)概率值骄瓣,變成概率值的方法是使用函數(shù) sigmoid 中對(duì)其進(jìn)行壓縮
# [None,1]
p_y_1 = tf.nn.sigmoid(y_)
# 得到概率為 1 的值就可以和真正 y 進(jìn)行差別分析,以為 y 的形狀(None) 和 p_y_1(None,1)不一樣需要進(jìn)行形狀修改
y_reshaped = tf.reshape(y, (-1, 1))
# 以為在 tensorFlow 對(duì)數(shù)據(jù)類型比較敏感派近,我們需要將 y_resphapded 類型從 int64 修改Wie float32
y_reshaped_float = tf.cast(y_reshaped, tf.float32)
# reduce_mean 是就均值而 square 是求平方
loss = tf.reduce_mean(tf.square(y_reshaped_float - p_y_1))
# 預(yù)測(cè)值通過(guò)將 p_y_1 和 0.5 進(jìn)行比較得到 true 或 false 來(lái)表預(yù)測(cè)值
predict = p_y_1 > 0.5
# [1,0,1,1,0,0,1]
correct_prediction = tf.equal(tf.cast(predict, tf.int64), y_reshaped)
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float64))
class CifarData:
def __init__(self,filenames, need_shuffle):
all_data = []
all_labels = []
for filename in filenames:
data, labels = load_data(filename)
for item, label in zip(data, labels):
if label in [0,1]:
all_data.append(item)
all_labels.append(label)
self._data = np.vstack(all_data)
self._labels = np.hstack(all_labels)
self._num_examples = self._data.shape[0]
print(self._num_examples)
self._need_shuffle = need_shuffle
self._indicator = 0
if self._need_shuffle:
self._shuffle_data()
定義一個(gè)類 CifarData 來(lái)控制數(shù)據(jù)涧郊,need_shuffle 作為一個(gè)控制是否對(duì)數(shù)據(jù)進(jìn)行重新排序(洗牌)的標(biāo)識(shí)飒泻,當(dāng)我們處理訓(xùn)練數(shù)據(jù)集時(shí)候可以通過(guò)開(kāi)啟 need_shuffle 來(lái)得到更多隨機(jī)樣本,而對(duì)于測(cè)試數(shù)據(jù)集則不會(huì)開(kāi)啟該開(kāi)關(guān)噪窘。
all_data = []
all_labels = []
for filename in filenames:
data, labels = load_data(filename)
使用之間的 data_load 方法將文件中數(shù)據(jù)加載進(jìn)來(lái)笋庄。
for item, label in zip(data, labels):
if label in [0,1]:
all_data.append(item)
all_labels.append(label)
因?yàn)槲覀兲幚矶诸惖臄?shù)據(jù)集所有通過(guò)過(guò)濾得到標(biāo)簽為 0 或 1 的數(shù)據(jù)集
self._data = np.vstack(all_data)
self._labels = np.hstack(all_labels)
通過(guò)合并后轉(zhuǎn)換為 numpy 中的矩陣,vstack 將按縱向進(jìn)行合并形成一個(gè)矩陣倔监,而 hstack 則是按橫向進(jìn)行合并形成矩陣直砂。
self._num_examples = self._data.shape[0]
定義有多少個(gè)向量,然后就是定義 suffle 函數(shù)浩习,
def _shuffle_data(self):
p = np.random.permutation(self._num_examples)
self._data = self._data[p]
self._labels = self._labels[p]
首先通過(guò) random.permutation 得到一個(gè)排列哆键,這個(gè)函數(shù)從 0 到 _num_examples 進(jìn)行一個(gè)混排,然后用 p 對(duì) data 和 p 集合進(jìn)行洗牌瘦锹。
def next_batch(self,batch_size):
end_indicator = self._indicator + batch_size
if end_indicator > self._num_examples:
if self._need_shuffle:
self._shuffle_data()
else:
raise Exception("have no more examples")
if end_indicator > self._num_examples:
raise Exception("batch size is larger than all examples")
batch_data = self._data[self._indicator:end_indicator]
batch_labels = self._labels[self._indicator:end_indicator]
self._indicator = end_indicator
return batch_data,batch_labels
定義 next_batch 樣本籍嘹,會(huì)返回 batch_size 個(gè)樣本,
end_indicator = self._indicator + batch_size
if end_indicator > self._num_examples:
if self._need_shuffle:
self._shuffle_data()
self._indicator = 0
end_indicator = batch_size
else:
raise Exception("have no more examples")
如果 end_indicator 大于 self._num_examples 表示我們?nèi)≈禈颖镜慕刂刮恢贸隽藰颖緮?shù)弯院,這時(shí)如果是訓(xùn)練數(shù)據(jù)集就需要重新洗牌然后繼續(xù)獲取數(shù)據(jù)辱士,但如果不是訓(xùn)練數(shù)據(jù)集則拋出一個(gè)異常。
batch_data = self._data[self._indicator:end_indicator]
batch_labels = self._labels[self._indicator:end_indicator]
self._indicator = end_indicator
return batch_data, batch_labels
將這batch_size 間數(shù)據(jù)返回去听绳。
self._data = np.vstack(all_data)
self._labels = np.hstack(all_labels)
# 測(cè)試
print(self._data.shape)
print(self._labels.shape)
self._num_examples = self._data.shape[0]
在這個(gè)位置輸出一下颂碘,測(cè)試一下我們創(chuàng)建好的類是否正常工作
train_filenames = [os.path.join(CIFAR_DIR, 'data_batch_%d' % i)
for i in range(1, 6)]
test_filenames = [os.path.join(CIFAR_DIR, 'test_batch')]
train_data = CifarData(train_filenames, True)
我們知道訓(xùn)練數(shù)據(jù)集應(yīng)該有 50000 樣本以為每一個(gè) data_batch 有一個(gè) 10000 樣本,而又 0 - 9 十個(gè)類別(也就是圖片的類別)而因?yàn)檫^(guò)濾為 0椅挣,1 所以只要 10000 個(gè)數(shù)據(jù)
(10000, 3072)
(10000,)
batch_data, batch_labels = train_data.next_batch(10)
print(batch_data, batch_labels)
[[208 186 128 ... 100 97 97]
[ 55 59 65 ... 55 52 52]
[223 223 226 ... 61 58 52]
...
[160 111 71 ... 48 48 51]
[105 105 105 ... 50 50 49]
[252 248 248 ... 93 98 97]] [1 0 1 1 0 0 1 1 0 1]