I'm currently working through Stanford's CS231n course, which uses the CIFAR-10 dataset.
The dataset can be downloaded from its official website:
http://www.cs.toronto.edu/~kriz/cifar.html
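After downloading and extracting the archive, you get the cifar-10-batches-py folder. It contains five training batches (data_batch_1 to data_batch_5), one test batch (test_batch), the metadata file batches.meta and readme.html.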
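The download and extraction can also be scripted. Below is a minimal sketch that uses only the standard library. The archive URL in CIFAR10_URL (the python version, cifar-10-python.tar.gz) and the function name download_and_extract are assumptions, so check them against the download page before relying on them. If the archive is already in the target folder, the download step is skipped.

```python
import os
import tarfile
import urllib.request

# Assumed location of the python-version archive; verify on the download page
CIFAR10_URL = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'

def download_and_extract(dest_dir, url=CIFAR10_URL):
    """Hypothetical helper: fetch the archive if missing, then extract it.

    Returns the path of the extracted cifar-10-batches-py folder.
    """
    os.makedirs(dest_dir, exist_ok=True)
    archive = os.path.join(dest_dir, os.path.basename(url))
    if not os.path.exists(archive):
        # Only download when the archive is not already on disk
        urllib.request.urlretrieve(url, archive)
    with tarfile.open(archive, 'r:gz') as tar:
        tar.extractall(dest_dir)
    return os.path.join(dest_dir, 'cifar-10-batches-py')
```

For example, download_and_extract('D:/Download') should leave the batch files in D:/Download/cifar-10-batches-py, the folder used in the loading code that follows.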
How to load the data
The CIFAR-10 batch files were written with pickle, so they are read back with pickle as well:
import pickle

def load_file(filename):
    # encoding='latin1' lets Python 3 read the batch files, which were pickled with Python 2
    with open(filename, 'rb') as fo:
        data = pickle.load(fo, encoding='latin1')
    return data

filename = 'D:/Download/cifar-10-batches-py/data_batch_1'
data = load_file(filename)
print(data.keys())  # show some basic information about this file
Output (the basic information stored in this file):
dict_keys(['batch_label', 'labels', 'data', 'filenames'])
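The 'data' entry is a 10000 x 3072 uint8 array. According to the CIFAR-10 page, each row holds one 32x32 colour image stored as 1024 red values, then 1024 green, then 1024 blue, each channel in row-major order. A minimal sketch for turning rows back into images (for example to display them) follows; rows_to_images is a hypothetical helper, and the zero array only stands in for data['data'] so the sketch runs without the dataset.

```python
import numpy as np

def rows_to_images(batch_data):
    """Hypothetical helper: (N, 3072) CIFAR-10 rows -> (N, 32, 32, 3) images."""
    n = batch_data.shape[0]
    # Each row is [1024 red | 1024 green | 1024 blue], so reshape to (N, 3, 32, 32)
    images = batch_data.reshape(n, 3, 32, 32)
    # Move the channel axis last, the usual (height, width, channel) layout
    return images.transpose(0, 2, 3, 1)

# Synthetic stand-in for data['data']
fake = np.zeros((2, 3072), dtype=np.uint8)
imgs = rows_to_images(fake)
print(imgs.shape)  # (2, 32, 32, 3)
```

With the real file, rows_to_images(data['data'])[0] would be the first image of the batch and data['labels'][0] its class index.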
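The idea behind NN classification
NN (nearest neighbor) classification needs no real training: the classifier simply stores the training data, and classifies a new image by comparing it with every stored image.
For the comparison, compute the L1 distance between the target image and each training image, i.e. the L1 norm of their pixel-wise difference. The class of the training image with the smallest L1 distance is taken as the class of the target image.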
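Written as formulas, with the sum running over all pixels p, the L1 distance is d1(I1, I2) = Σp |I1p - I2p| and the L2 distance is d2(I1, I2) = sqrt(Σp (I1p - I2p)^2). The toy sketch below applies both to 4-pixel "images"; the helper nn_predict_one and the data are hypothetical, chosen so that the two distances pick different neighbours.

```python
import numpy as np

def nn_predict_one(Xtr, ytr, x, metric="l1"):
    """Hypothetical helper: label of the training row closest to x, plus all distances."""
    diff = Xtr - x  # x is broadcast against every training row
    if metric == "l1":
        distances = np.sum(np.abs(diff), axis=1)
    elif metric == "l2":
        distances = np.sqrt(np.sum(diff ** 2, axis=1))
    else:
        raise ValueError(f"unknown metric: {metric}")
    return ytr[np.argmin(distances)], distances

Xtr = np.array([[3, 3, 0, 0],   # label 0
                [5, 0, 0, 0],   # label 1
                [4, 4, 4, 4]])  # label 2
ytr = np.array([0, 1, 2])
x = np.array([0, 0, 0, 0])

# L1 distances are 6, 5, 16: the nearest row is row 1, so the prediction is label 1
label_l1, d1 = nn_predict_one(Xtr, ytr, x, "l1")
# L2 distances are sqrt(18) ~ 4.24, 5, 8: the nearest row is row 0, so label 0
label_l2, d2 = nn_predict_one(Xtr, ytr, x, "l2")
print(label_l1, label_l2)
```

The L2 distance squares each difference, so it penalizes one large pixel difference more heavily than several moderate ones; that is why the two metrics disagree here.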
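The classifier code follows. It uses the L1 distance; switching to the L2 distance only requires changing the line that computes the distances.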
import numpy as np
import pickle

filename = 'xxx'       # path to a training batch file
filename_test = 'xxx'  # path to the test batch file

class NearestNeighbor:
    """Nearest-neighbor classifier using the L1 distance."""

    def __init__(self):
        pass

    # Load one pickled CIFAR-10 batch
    def load_file(self, filename):
        with open(filename, 'rb') as fo:
            data = pickle.load(fo, encoding='latin1')
        return data

    # "Training" for NN just stores the data: X is the n*3072 data, y holds the n labels
    def train(self, X, y):
        # Cast from uint8 so that pixel differences cannot wrap around
        self.Xtr = X.astype(np.int32)
        self.ytr = np.asarray(y)

    # Predict labels for X, the test set data
    def predict(self, X):
        X = X.astype(np.int32)
        num_test = X.shape[0]  # number of test images
        Ypred = np.zeros(num_test, dtype=self.ytr.dtype)  # predicted labels
        for i in range(num_test):
            # L1 distance from test image i to every training image
            distances = np.sum(np.abs(self.Xtr - X[i, :]), axis=1)
            # For the L2 distance instead:
            # distances = np.sqrt(np.sum(np.square(self.Xtr - X[i, :]), axis=1))
            min_index = np.argmin(distances)  # training image with the smallest distance
            Ypred[i] = self.ytr[min_index]    # its label is the prediction
        return Ypred

net = NearestNeighbor()
data = net.load_file(filename)
test_batch = net.load_file(filename_test)
net.train(data['data'], data['labels'])
result = net.predict(test_batch['data'])
print(result)
print('accuracy:', np.mean(result == np.array(test_batch['labels'])))
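The classifier above is trained on a single batch file. A hedged sketch for stacking all five training batches and measuring accuracy is below; load_batch, load_cifar10 and accuracy are hypothetical helper names, and root is assumed to be the extracted cifar-10-batches-py folder with its original file names.

```python
import os
import pickle
import numpy as np

def load_batch(path):
    """Hypothetical helper: load one pickled CIFAR-10 batch as (X, y)."""
    with open(path, 'rb') as fo:
        batch = pickle.load(fo, encoding='latin1')
    return np.asarray(batch['data']), np.asarray(batch['labels'])

def load_cifar10(root):
    """Stack data_batch_1..5 into one training set and read test_batch."""
    xs, ys = [], []
    for b in range(1, 6):
        X, y = load_batch(os.path.join(root, 'data_batch_%d' % b))
        xs.append(X)
        ys.append(y)
    Xte, yte = load_batch(os.path.join(root, 'test_batch'))
    return np.concatenate(xs), np.concatenate(ys), Xte, yte

def accuracy(pred, labels):
    """Fraction of predictions that equal the true labels."""
    return float(np.mean(np.asarray(pred) == np.asarray(labels)))

# Usage with the NearestNeighbor class above (path as in the first example):
# Xtr, ytr, Xte, yte = load_cifar10('D:/Download/cifar-10-batches-py')
# net = NearestNeighbor()
# net.train(Xtr, ytr)
# pred = net.predict(Xte[:100])
# print(accuracy(pred, yte[:100]))
```

Predicting on a subset of the test images keeps the run manageable: every prediction step builds a temporary array as large as the whole training set, and the loop runs once per test image.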