基于SVM的思想做CIFAR-10圖像分類

SVM

回顧一下之前的SVM丽惭,找到一個間隔最大的函數(shù),使得正負(fù)樣本離該函數(shù)是最遠(yuǎn)的,是否最遠(yuǎn)不是看哪個點離函數(shù)最遠(yuǎn)沮协,而是找到一個離函數(shù)最近的點看他是不是和該分割函數(shù)離的最近的辆飘。


使用large margin來regularization啦辐。
之前講SVM的算法:http://www.reibang.com/p/8fd28df734a0

線性分類

線性SVM就是一種線性分類的方法。輸入x蜈项,輸出y芹关,每一個樣本的權(quán)重是w,偏置項bias是b紧卒。得分函數(shù)s = wx +b
算出這么多個類別充边,哪一個類別的分?jǐn)?shù)高,那就是哪個類別。比如要做的圖像識別有三個類別[cat,ship,dog]浇冰,假設(shè)這張圖片有4個像素贬媒,拉伸成單列:

得到的結(jié)果很明顯是dog分?jǐn)?shù)最大,cat的分?jǐn)?shù)最低肘习,但是圖片很明顯是貓际乘,什么分類器是錯誤的。
一般來說習(xí)慣會把w和b合并了漂佩,x加上一個全為1的列脖含,于是有
W=[w;b];X = [x;1]

損失函數(shù)

之前的SVM是把正負(fù)樣本離分割函數(shù)有足夠的空間,雖然正確的是貓投蝉,但是貓的得分是最低的养葵,常規(guī)方法是將貓的分?jǐn)?shù)提高,這樣才可以提高貓的正確率瘩缆。但是SVM里面是要求一個間隔最大化关拒,提到這里來說,其實就是cat score不僅僅是要大于其他的分?jǐn)?shù)庸娱,而且是要有一個最低閾值着绊,cat score不能低于這個分?jǐn)?shù)。
所以正確的分類score應(yīng)該是要大于其他的分類score一個閾值:s_{y_i} >= s_j + \triangle
s_{y_i}就是正確分類的分?jǐn)?shù)熟尉,s_j就是其他分類的分?jǐn)?shù)归露。所以,這個損失函數(shù)就是:Loss_{y_i} = \sum_{j != y_i}max(0, s_j - s_{y_i}+\triangle)只有正確的分?jǐn)?shù)比其他的都大于一個閾值才為0斤儿,否則都是有損失的剧包。


只有
s_j-s_{y_i}+\triangle <= 0
損失函數(shù)才是0的。這種損失函數(shù)稱為合頁損失函數(shù)往果,用的就是SVM間隔最大化的思想解決玄捕,如果損失函數(shù)為0,那么不用求解了棚放,如果損失函數(shù)不為0枚粘,就可以用梯度下降求解。max求解梯度下降有點不現(xiàn)實飘蚯,所以自然就有了square的合頁損失函數(shù)馍迄。
Loss_{y_i} = \sum_{j != y_i}max(0, s_j - s_{y_i}+\triangle)^2

這種squared hinge loss SVM與linear hinge loss SVM相比較,特點是對違背間隔閾值要求的點加重懲罰局骤,違背的越大攀圈,懲罰越大。某些實際應(yīng)用中峦甩,squared hinge loss SVM的效果更好一些赘来。具體使用哪個现喳,可以根據(jù)實際問題,進(jìn)行交叉驗證再確定犬辰。
對于
\triangle
的設(shè)置嗦篱,之前SVM其實討論過,對于一個平面是可以隨意伸縮的幌缝,只需要增大w和b就可以隨意把
\triangle
增大灸促,所以把它定為1,也就是設(shè)置
\triangle = 1
涵卵。因為w的增長或縮小完全可以抵消
\triangle
的影響浴栽。這個時候損失函數(shù)就是:
Loss_{y_i} = \sum_{j != y_i}max(0, s_j - s_{y_i}+1)

最后還要增加的就是過擬合,regularization的限制了轿偎。L2正則化:
R(W) = \sum_{k}\sum_{l}w_{k,l}^2

加上正則化之后就是:
Loss = \frac{1}{N}Loss + \lambda R(W)

N是訓(xùn)練樣本的個數(shù)典鸡,取平均損失函數(shù),
\lambda
就是懲罰的力度了坏晦,可以小也可以大萝玷,如果大了可能w不足以抵消正負(fù)樣本之間的間隔,可能會欠擬合英遭,因為
\triangle = 1
是在w可以自由伸縮達(dá)到的條件,如果w太小亦渗,可能就不足以增長到1了挖诸。如果小了,可能就會造成overfit法精。對于參數(shù)b就沒有這么講究了多律。

代碼實現(xiàn)

首先是對CIFAR10的數(shù)據(jù)讀取:


def load_pickle(f):
    version = platform.python_version_tuple()
    if version[0] == '2':
        return pickle.load(f)
    elif version[0] == '3':
        return pickle.load(f, encoding='latin1')
    raise ValueError("invalid python version: {}".format(version))

def loadCIFAR_batch(filename):
    with open(filename, 'rb') as f:
        datadict = load_pickle(f)
        x = datadict['data']
        y = datadict['labels']
        x = x.reshape(10000, 3, 32, 32).transpose(0, 3, 2, 1).astype('float')
        y = np.array(y)
        return x, y

def loadCIFAR10(root):
    xs = []
    ys = []
    for b in range(1, 6):
        f = os.path.join(root, 'data_batch_%d' % (b, ))
        x, y = loadCIFAR_batch(f)
        xs.append(x)
        ys.append(y)
    X = np.concatenate(xs)
    Y = np.concatenate(ys)
    x_test, y_test = loadCIFAR_batch(os.path.join(root, 'test_batch'))
    return X, Y, x_test, y_test

首先要讀入每一個文件的數(shù)據(jù)搂蜓,先用load_pickle把文件讀成字典形式狼荞,取出來。因為常規(guī)的圖片都是(數(shù)量帮碰,高相味,寬,RGB顏色)殉挽,在loadCIFAR_batch要用transpose來把維度調(diào)換一下丰涉。最后把每一個文件的數(shù)據(jù)都集合起來。
之后就是數(shù)據(jù)的格式調(diào)整了:

def data_validation(x_train, y_train, x_test, y_test):
    num_training = 49000
    num_validation = 1000
    num_test = 1000
    num_dev = 500
    mean_image = np.mean(x_train, axis=0)
    x_train -= mean_image
    mask = range(num_training, num_training + num_validation)
    X_val = x_train[mask]
    Y_val = y_train[mask]
    mask = range(num_training)
    X_train = x_train[mask]
    Y_train = y_train[mask]
    mask = np.random.choice(num_training, num_dev, replace=False)
    X_dev = x_train[mask]
    Y_dev = y_train[mask]
    mask = range(num_test)
    X_test = x_test[mask]
    Y_test = y_test[mask]
    X_train = np.reshape(X_train, (X_train.shape[0], -1))
    X_val = np.reshape(X_val, (X_val.shape[0], -1))
    X_test = np.reshape(X_test, (X_test.shape[0], -1))
    X_dev = np.reshape(X_dev, (X_dev.shape[0], -1))
    X_train = np.hstack([X_train, np.ones((X_train.shape[0], 1))])
    X_val = np.hstack([X_val, np.ones((X_val.shape[0], 1))])
    X_test = np.hstack([X_test, np.ones((X_test.shape[0], 1))])
    X_dev = np.hstack([X_dev, np.ones((X_dev.shape[0], 1))])
    return X_val, Y_val, X_train, Y_train, X_dev, Y_dev, X_test, Y_test
    pass

數(shù)據(jù)要變成一個長條斯碌。
先看看數(shù)據(jù)長啥樣:

def showPicture(x_train, y_train):
    classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
    num_classes = len(classes)
    samples_per_classes = 7
    for y, cls in enumerate(classes):
        idxs = np.flatnonzero(y_train == y)
        idxs = np.random.choice(idxs, samples_per_classes, replace=False)
        for i, idx in enumerate(idxs):
            plt_index = i*num_classes +y + 1
            plt.subplot(samples_per_classes, num_classes, plt_index)
            plt.imshow(x_train[idx].astype('uint8'))
            plt.axis('off')
            if i == 0:
                plt.title(cls)
    plt.show()

然后就是使用谷歌的公式了:

    def loss(self, x, y, reg):
        loss = 0.0
        dw = np.zeros(self.W.shape)
        num_train = x.shape[0]
        scores = x.dot(self.W)
        correct_class_score = scores[range(num_train), list(y)].reshape(-1, 1)
        margin = np.maximum(0, scores - correct_class_score + 1)
        margin[range(num_train), list(y)] = 0
        loss = np.sum(margin)/num_train + 0.5 * reg * np.sum(self.W*self.W)

        num_classes = self.W.shape[1]
        inter_mat = np.zeros((num_train, num_classes))
        inter_mat[margin > 0] = 1
        inter_mat[range(num_train), list(y)] = 0
        inter_mat[range(num_train), list(y)] = -np.sum(inter_mat, axis=1)

        dW = (x.T).dot(inter_mat)
        dW = dW/num_train + reg*self.W
        return loss, dW
        pass

操作都是常規(guī)操作一死,算出score然后求loss最后SGD求梯度更新W。

    def train(self, X, y, learning_rate=1e-3, reg=1e-5, num_iters=100,batch_size=200, verbose=False):
        num_train, dim = X.shape
        num_classes = np.max(y) + 1
        if self.W is None:
            self.W = 0.001 * np.random.randn(dim, num_classes)
        # Run stochastic gradient descent to optimize W
        loss_history = []
        for it in range(num_iters):
            X_batch = None
            y_batch = None
            idx_batch = np.random.choice(num_train, batch_size, replace = True)
            X_batch = X[idx_batch]
            y_batch = y[idx_batch]
            # evaluate loss and gradient
            loss, grad = self.loss(X_batch, y_batch, reg)
            loss_history.append(loss)
            self.W -=  learning_rate * grad
            if verbose and it % 100 == 0:
                print('iteration %d / %d: loss %f' % (it, num_iters, loss))
        return loss_history
        pass

預(yù)測:

    def predict(self, X):
        y_pred = np.zeros(X.shape[0])
        scores = X.dot(self.W)
        y_pred = np.argmax(scores, axis = 1)
        return y_pred

最后運行函數(shù):

 svm = LinearSVM()
    tic = time.time()
    cifar10_name = '../Data/cifar-10-batches-py'
    x_train, y_train, x_test, y_test = loadCIFAR10(cifar10_name)
    X_val, Y_val, X_train, Y_train, X_dev, Y_dev, X_test, Y_test = data_validation(x_train, y_train, x_test, y_test)
    loss_hist = svm.train(X_train, Y_train, learning_rate=1e-7, reg=2.5e4,
                          num_iters=3000, verbose=True)
    toc = time.time()
    print('That took %fs' % (toc - tic))
    plt.plot(loss_hist)
    plt.xlabel('Iteration number')
    plt.ylabel('Loss value')
    plt.show()
    y_test_pred = svm.predict(X_test)
    test_accuracy = np.mean(Y_test == y_test_pred)
    print('accuracy: %f' % test_accuracy)
    w = svm.W[:-1, :]  # strip out the bias
    w = w.reshape(32, 32, 3, 10)
    w_min, w_max = np.min(w), np.max(w)
    classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
    for i in range(10):
        plt.subplot(2, 5, i + 1)
        wimg = 255.0 * (w[:, :, :, i].squeeze() - w_min) / (w_max - w_min)
        plt.imshow(wimg.astype('uint8'))
        plt.axis('off')
        plt.title(classes[i])
    plt.show()

首先是畫出整個loss函數(shù)趨勢:



最后再可視化一下w權(quán)值傻唾,看看每一個種類提取處理的特征是什么樣子的:


?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末投慈,一起剝皮案震驚了整個濱河市,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌伪煤,老刑警劉巖加袋,帶你破解...
    沈念sama閱讀 210,978評論 6 490
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場離奇詭異带族,居然都是意外死亡锁荔,警方通過查閱死者的電腦和手機,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 89,954評論 2 384
  • 文/潘曉璐 我一進(jìn)店門蝙砌,熙熙樓的掌柜王于貴愁眉苦臉地迎上來阳堕,“玉大人,你說我怎么就攤上這事择克√褡埽” “怎么了?”我有些...
    開封第一講書人閱讀 156,623評論 0 345
  • 文/不壞的土叔 我叫張陵肚邢,是天一觀的道長壹堰。 經(jīng)常有香客問我,道長骡湖,這世上最難降的妖魔是什么贱纠? 我笑而不...
    開封第一講書人閱讀 56,324評論 1 282
  • 正文 為了忘掉前任,我火速辦了婚禮响蕴,結(jié)果婚禮上谆焊,老公的妹妹穿的比我還像新娘。我一直安慰自己浦夷,他們只是感情好辖试,可當(dāng)我...
    茶點故事閱讀 65,390評論 5 384
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著劈狐,像睡著了一般罐孝。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發(fā)上肥缔,一...
    開封第一講書人閱讀 49,741評論 1 289
  • 那天莲兢,我揣著相機與錄音,去河邊找鬼续膳。 笑死怒见,一個胖子當(dāng)著我的面吹牛,可吹牛的內(nèi)容都是我干的姑宽。 我是一名探鬼主播遣耍,決...
    沈念sama閱讀 38,892評論 3 405
  • 文/蒼蘭香墨 我猛地睜開眼,長吁一口氣:“原來是場噩夢啊……” “哼炮车!你這毒婦竟也來了舵变?” 一聲冷哼從身側(cè)響起酣溃,我...
    開封第一講書人閱讀 37,655評論 0 266
  • 序言:老撾萬榮一對情侶失蹤,失蹤者是張志新(化名)和其女友劉穎纪隙,沒想到半個月后赊豌,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 44,104評論 1 303
  • 正文 獨居荒郊野嶺守林人離奇死亡绵咱,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 36,451評論 2 325
  • 正文 我和宋清朗相戀三年碘饼,在試婚紗的時候發(fā)現(xiàn)自己被綠了。 大學(xué)時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片悲伶。...
    茶點故事閱讀 38,569評論 1 340
  • 序言:一個原本活蹦亂跳的男人離奇死亡艾恼,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出麸锉,到底是詐尸還是另有隱情钠绍,我是刑警寧澤,帶...
    沈念sama閱讀 34,254評論 4 328
  • 正文 年R本政府宣布花沉,位于F島的核電站柳爽,受9級特大地震影響,放射性物質(zhì)發(fā)生泄漏碱屁。R本人自食惡果不足惜磷脯,卻給世界環(huán)境...
    茶點故事閱讀 39,834評論 3 312
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望娩脾。 院中可真熱鬧赵誓,春花似錦、人聲如沸晦雨。這莊子的主人今日做“春日...
    開封第一講書人閱讀 30,725評論 0 21
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽闹瞧。三九已至,卻和暖如春展辞,著一層夾襖步出監(jiān)牢的瞬間奥邮,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 31,950評論 1 264
  • 我被黑心中介騙來泰國打工罗珍, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留洽腺,地道東北人。 一個月前我還...
    沈念sama閱讀 46,260評論 2 360
  • 正文 我出身青樓覆旱,卻偏偏與公主長得像蘸朋,于是被迫代替她去往敵國和親。 傳聞我的和親對象是個殘疾皇子扣唱,可洞房花燭夜當(dāng)晚...
    茶點故事閱讀 43,446評論 2 348

推薦閱讀更多精彩內(nèi)容

  • 以西瓜書為主線藕坯,以其他書籍作為參考進(jìn)行補充团南,例如《統(tǒng)計學(xué)習(xí)方法》,《PRML》等 第一章 緒論 1.2 基本術(shù)語 ...
    danielAck閱讀 4,500評論 0 6
  • 機器學(xué)習(xí)是做NLP和計算機視覺這類應(yīng)用算法的基礎(chǔ)炼彪,雖然現(xiàn)在深度學(xué)習(xí)模型大行其道吐根,但是懂一些傳統(tǒng)算法的原理和它們之間...
    在河之簡閱讀 20,487評論 4 65
  • 不止筆記的更新停了一周拷橘,并非不止君偷懶不愿寫,最初的計劃是每天寫一篇1000字左右的文章喜爷,嘗試之后發(fā)現(xiàn)不現(xiàn)實冗疮,一方...
    張軒銘閱讀 2,612評論 9 122
  • 昨晚沒睡好赌厅,那時候,疲憊著的身體帶著倦怠的雙眼轿塔,靜靜躺著等待夢之遇見特愿,我想就這樣睡去:夢見田里野花開放,竹...
    四火哥哥閱讀 198評論 1 1
  • 這話得從頭說起勾缭,卻不知從何說起揍障。 不知道是不是年紀(jì)越大,越喜歡對晚輩(反正就是比之年紀(jì)小的人)說這說那俩由。像最常見的...
    剪佳閱讀 392評論 0 0