參考CS231n,將KNN 跑起來了,成功將系統搞死,內存和計算能力開銷太大。
以下代碼切記不要輕易跑。
http://www.cs.toronto.edu/~kriz/cifar.html
code:
import os
import sys
import numpy as np
import pickle
def load_CIFAR_batch(filename):
    """
    Load a single CIFAR-10 batch file.

    The CIFAR-10 dataset is stored as several pickled batches; this reads
    one of them. The file name is printed as a progress indicator.

    @param filename: path of the pickled batch file
    @return: X, Y -- X is a float array of shape (N, 32, 32, 3) built from
             the batch's 'data' entry (N is inferred from the file),
             Y is an array built from the batch's 'labels' entry
    """
    # NOTE(security): unpickling can execute arbitrary code; only load batch
    # files obtained from a trusted source.
    with open(filename, "rb") as f:
        # The explicit encoding controls how 8-bit strings pickled by
        # Python 2 are decoded (presumably how the batch files were written).
        datadict = pickle.load(f, encoding='iso-8859-1')
    print(filename)
    X = np.asarray(datadict['data'])
    Y = datadict['labels']
    # Each row holds three 32x32 channel planes. -1 infers the number of
    # images instead of assuming exactly 10000 per batch; the transpose
    # moves the channel axis last: (N, 3, 32, 32) -> (N, 32, 32, 3).
    X = X.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1).astype("float")
    Y = np.array(Y)
    return X, Y
def load_CIFAR10(ROOT):
    """
    Load the complete CIFAR-10 dataset from a directory.

    Reads the five training batches "data_batch_1" .. "data_batch_5" and
    the "test_batch" file found under ROOT.

    @param ROOT: directory containing the batch files
    @return: X_train, Y_train: training data and labels (the five training
             batches concatenated in order)
             X_test, Y_test: test data and labels
    """
    train_paths = [
        os.path.join(ROOT, "data_batch_%d" % (idx, )) for idx in range(1, 6)
    ]
    batches = [load_CIFAR_batch(path) for path in train_paths]
    # Join the per-batch arrays along the first (sample) axis.
    X_train = np.concatenate([images for images, _ in batches])
    Y_train = np.concatenate([labels for _, labels in batches])
    X_test, Y_test = load_CIFAR_batch(os.path.join(ROOT, "test_batch"))
    return X_train, Y_train, X_test, Y_test
# Load the training and test splits from disk.
X_train, Y_train, X_test, Y_test = load_CIFAR10('data/cifar/')
# Flatten every 32*32*3 image array into a single row vector.
Xtr_rows = np.reshape(X_train, (X_train.shape[0], 32 * 32 * 3))  # Xtr_rows: one row of 3072 values per training image
Xte_rows = np.reshape(X_test, (X_test.shape[0], 32 * 32 * 3))  # Xte_rows: one row of 3072 values per test image
class NearestNeighbor:
    """1-nearest-neighbour classifier using the L1 distance
    (sum of absolute differences)."""

    def __init__(self):
        pass

    def train(self, X, y):
        """
        Memorise the training set; no actual fitting takes place.

        @param X: array with one training sample per row
        @param y: array of labels, y[i] being the label of X[i]
        """
        # The nearest neighbor classifier simply remembers all the training data.
        self.Xtr = X
        self.ytr = y

    def predict(self, X):
        """
        Label every row of X with the label of its closest training row.

        Each test row is compared against the whole training set, so the
        cost grows with (number of test rows) x (number of training rows).

        @param X: 2-D array, one test sample per row, with the same number
                  of features as the training data
        @return: array of predicted labels, one per row of X, with the same
                 dtype as the training labels
        @raise ValueError: if X is not 2-D or its feature shape differs from
                           that of the training data
        """
        if X.ndim != 2 or X.shape[1:] != self.Xtr.shape[1:]:
            raise ValueError(
                "expected X of shape (n, %s), got %s"
                % (", ".join(str(d) for d in self.Xtr.shape[1:]), X.shape)
            )
        num_test = X.shape[0]
        # Keep the output dtype consistent with the training labels.
        Ypred = np.zeros(num_test, dtype=self.ytr.dtype)
        # One scratch buffer reused for every test row, instead of allocating
        # fresh training-set-sized temporaries on each iteration.
        diff = np.empty(self.Xtr.shape, dtype=np.float64)
        for i in range(num_test):
            # Subtract in float64: in an unsigned dtype such as uint8 the
            # difference would wrap around (0 - 10 == 246) and corrupt the
            # distances.
            np.subtract(self.Xtr, X[i], out=diff, dtype=np.float64)
            np.abs(diff, out=diff)
            distances = diff.sum(axis=1)  # L1 distance to every training row
            min_index = np.argmin(distances)  # index of the closest training row
            Ypred[i] = self.ytr[min_index]  # record its label
        return Ypred
# NOTE: brute-force scan of the full training set for every test row --
# expensive in both time and memory.
nn = NearestNeighbor()  # create the nearest-neighbour classifier
nn.train(Xtr_rows, Y_train)  # "training" only stores the training set
Yte_predict = nn.predict(Xte_rows)  # predict a label for every test row
# Compare against the ground-truth labels and report the accuracy.
print('accuracy: %f' % np.mean(Yte_predict == Y_test))