不得不看的圖片搜索系統(tǒng)實(shí)現(xiàn)

原創(chuàng):余曉龍

圖片搜索系統(tǒng)主要分為特征提取和特征匹配兩個(gè)部分,其中特征提取是深度學(xué)習(xí)模型中進(jìn)行數(shù)據(jù)處理的主要環(huán)節(jié)巡语,本文將通過一種基于無監(jiān)督方式---最大化深度互信息(DIM)方法來進(jìn)行特征提取似踱,并利用提取出來的低維特征實(shí)現(xiàn)圖片搜索系統(tǒng)。

1. DIM模型原理

DIM模型是通過計(jì)算輸入樣本與編碼器輸出的特征向量之間的互信息沐旨,利用最大化互信息來實(shí)現(xiàn)模型的訓(xùn)練。DIM模型在無監(jiān)督訓(xùn)練中使用兩種約束來表示學(xué)習(xí)榨婆。

(1)最大化輸入信息和高級特征向量之間的互信息:如果模型輸出的低維特征能夠代表輸入樣本磁携,那么該特征分布與輸入樣本分布的互信息一定是最大的。

(2)對抗匹配先驗(yàn)分布:編碼器輸出的高級特征要更接近高斯分布良风,判別器要將編碼器生成的數(shù)據(jù)分布與高斯分布進(jìn)行區(qū)分谊迄。

在實(shí)現(xiàn)的時(shí)候,DIM模型使用了3個(gè)判別器烟央,分別從局部互信息的最大化统诺、全局互信息的最大化和先驗(yàn)分布匹配的最小化3個(gè)角度來對編碼器的輸出結(jié)果進(jìn)行約束。

2. 局部互信息和全局互信息最大化約束的原理

局部特征可以理解為進(jìn)行卷積后得到的特征圖疑俭,全局特征可以理解為對特征圖進(jìn)行編碼得到的特征向量粮呢。對于圖片,它的相關(guān)性更多的體現(xiàn)在局部钞艇。圖像識別啄寡、分類是一個(gè)從局部到整體的過程、即全局特征更適用于重構(gòu)香璃,局部特征更適用于分類任務(wù)这难。DIM模型從局部和全局兩個(gè)角度對輸入和輸出計(jì)算互信息,而先驗(yàn)匹配的目的是對編碼器生成的向量形式進(jìn)行約束葡秒,使其更接近高斯分布姻乓。

3. 先驗(yàn)分布匹配最小化約束的原理

DIM模型的編碼器主要思想是對輸入數(shù)據(jù)進(jìn)行編碼成特征向量的同時(shí)嵌溢,還希望該特征向量服從于標(biāo)準(zhǔn)的高斯分布,這樣做的主要作用是使的編碼空間更加規(guī)范蹋岩,有利于解藕特征以便后續(xù)學(xué)習(xí)赖草。

4. 代碼實(shí)現(xiàn)

本文通過使用Fashion-MNIST數(shù)據(jù)集來實(shí)現(xiàn)圖片搜素器。Fashion-MNIST的單個(gè)樣本大小為28*28像素的灰度圖剪个,其中包含訓(xùn)練集60000張圖片秧骑、測試集10000張圖片。樣本的標(biāo)簽一共分為10類扣囊,包括T-shirt(T恤)乎折、Trouser(褲子??)、Pullover(套衫)侵歇、Dress(裙子??)骂澄、Coat(外套??)、Sandal(涼鞋??)惕虑、Shirt(襯衫??)坟冲、Sneaker(運(yùn)動鞋??)、Bag(包??)溃蔫、Ankle boot(踝靴??)健提。

4.1 加載并顯示Fashion-MNIST數(shù)據(jù)集

import torch
from torch import nn
import torch.nn.functional as F
import torchvision
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader

from torchvision.datasets.mnist import FashionMNIST
from torch.optim import Adam
from matplotlib import pyplot as plt
import numpy as np
from tqdm import tqdm
from pathlib import Path
from torchvision.transforms import ToPILImage
import os
os.environ["CUDA_VISIBLE_DEVICES"] = '1, 2, 3'

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(device)

batch_size = 256
data_dir = r'./fashon_mnist/'
train_dataset = FashionMNIST(data_dir, download=True, transform=ToTensor())
train_loader = DataLoader(train_dataset, batch_size=batch_size,
                          shuffle=True, drop_last=True,
                          pin_memory=torch.cuda.is_available())
print('train:', len(train_dataset))


def imshowrow(imgs, nrow):
    plt.figure(dpi=200)
    _img = ToPILImage()(torchvision.utils.make_grid(imgs, nrow=nrow))
    plt.axis('off')
    plt.imshow(_img)
    plt.show()



classes = ('T-shirt', 'Trouser', 'Pullover', 'Dress', 'Coat',
           'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle_Boot')

sample = iter(train_loader)
images, labels = sample.next()
print('sample shape:', np.shape(images))
print('sample label:', ','.join('%2d:%-5s' % (labels[j],
                                              classes[labels[j]])
                                for j in range(len(images[:10]))))
imshowrow(images[:10], nrow=10)

4.2 實(shí)現(xiàn)DIM模型

定義編碼器模型類Encoder與判別器類DeepInfoMaxLoss

Encoder:通過多個(gè)卷積層對輸入數(shù)據(jù)進(jìn)行編碼,生成64維特征向量伟叛,

DeepInfoMaxLoss:實(shí)現(xiàn)全局私痹、局部、先驗(yàn)判別器三個(gè)模型結(jié)構(gòu)痪伦,合并損失函數(shù)得到總損失侄榴。

class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.c0 = nn.Conv2d(1, 64, kernel_size=4, stride=1)
        self.c1 = nn.Conv2d(64, 128, kernel_size=4, stride=1)
        self.c2 = nn.Conv2d(128, 256, kernel_size=4, stride=1)
        self.c3 = nn.Conv2d(256, 512, kernel_size=4, stride=1)
    
        self.l1 = nn.Linear(512*16*16, 64)

        self.b1 = nn.BatchNorm2d(128)
        self.b2 = nn.BatchNorm2d(256)
        self.b3 = nn.BatchNorm2d(512)

    def forward(self, x):
        # print('x', x.shape)  # torch.Size([256, 1, 28, 28])
        h = F.relu(self.c0(x))
        # print('h1', h.size())  # torch.Size([256, 64, 25, 25])
        features = F.relu(self.b1(self.c1(h)))
        # print('features', features.size())  # torch.Size([256, 128, 22, 22])
        h = F.relu(self.b2(self.c2(features)))
        # print('h2', h.size())  # torch.Size([256, 256, 19, 19])
        h = F.relu(self.b3(self.c3(h)))
        # print('h3', h.size())  # torch.Size([256, 512, 16, 16])
        encoder = self.l1(h.view(x.shape[0], -1))
        return encoder, features


class DeepInfoMaxLoss(nn.Module):
    def __init__(self, alpha=0.5, beta=1.0, gamma=0.1):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma

        self.local_d = nn.Sequential(
            nn.Conv2d(192, 512, kernel_size=1),
            nn.ReLU(True),
            nn.Conv2d(512, 512, kernel_size=1),
            nn.ReLU(True),
            nn.Conv2d(512, 1, kernel_size=1)
        )

        self.prior_d = nn.Sequential(
            nn.Linear(64, 1000),
            nn.ReLU(True),
            nn.Linear(1000, 200),
            nn.ReLU(True),
            nn.Linear(200, 1),
            nn.Sigmoid()
        )

        self.global_d_M = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3),
            nn.ReLU(True),
            nn.Conv2d(64, 32, kernel_size=3),
            nn.Flatten()
        )

        self.global_d_fc = nn.Sequential(
            nn.Linear(32 * 18 * 18 + 64, 512),
            nn.ReLU(True),
            nn.Linear(512, 512),
            nn.ReLU(True),
            nn.Linear(512, 1)
        )

    def GlobalD(self, y, M):
        h = self.global_d_M(M)
        h = torch.cat((y, h), dim=1)
        return self.global_d_fc(h)

    def forward(self, y, M, M_prime):
        y_exp = y.unsqueeze(-1).unsqueeze(-1)
        # print('y_exp', y_exp.shape)
        # y_exp torch.Size([256, 64, 1, 1])
        y_exp = y_exp.expand(-1, -1, 22, 22)
        # print('y_exp', y_exp.shape)
        # y_exp torch.Size([256, 64, 22, 22])
        y_M = torch.cat((M, y_exp), dim=1)
        # print('y_M', y_M.shape)
        # y_M torch.Size([256, 192, 22, 22])
        y_M_prime = torch.cat((M_prime, y_exp), dim=1)
        # print('y_M_prime', y_M_prime.shape)
        # y_M_prime torch.Size([256, 192, 22, 22])

        Ej = -F.softplus(-self.local_d(y_M)).mean()
        Em = F.softplus(self.local_d(y_M_prime)).mean()
        Local = (Em - Ej) * self.beta

        Ej = -F.softplus(-self.GlobalD(y, M)).mean()
        Em = F.softplus(self.GlobalD(y, M_prime)).mean()
        Global = (Em - Ej) * self.alpha

        prior = torch.rand_like(y)
        term_a = torch.log(self.prior_d(prior)).mean()
        term_b = torch.log(1.0 - self.prior_d(y)).mean()
        Prior = -(term_a + term_b) * self.gamma

        return Local + Global + Prior

4.3 實(shí)例化模型并進(jìn)行訓(xùn)練

totalepoch = 100
if __name__ == '__main__':
    encoder = Encoder().to(device)
    loss_fn = DeepInfoMaxLoss().to(device)
    optim = Adam(encoder.parameters(), lr=1e-4)
    loss_optim = Adam(loss_fn.parameters(), lr=1e-4)

    epoch_loss = []
    for epoch in range(totalepoch + 1):
        batch = tqdm(train_loader, total=len(train_dataset) // batch_size)
        train_loss = []
        for x, target in batch:
            x = x.to(device)
            optim.zero_grad()
            loss_optim.zero_grad()
            y, M = encoder(x)

            M_prime = torch.cat((M[1:], M[0].unsqueeze(0)), dim=0)
            loss = loss_fn(y, M, M_prime)
            train_loss.append(loss.item())
            batch.set_description(
                str(epoch) + ' Loss:%.4f' % np.mean(train_loss[-20:]
            ))
            loss.backward()
            optim.step()
            loss_optim.step()

        if epoch % 10 == 0:
            root = Path(r'./DIMmodel2/')
            enc_file = root / Path('encoder' + str(epoch) + '.pth')
            loss_file = root / Path('loss' + str(epoch) + '.pth')
            enc_file.parent.mkdir(parents=True, exist_ok=True)
            torch.save(encoder.state_dict(), str(enc_file))
            torch.save(loss_fn.state_dict(), str(loss_file))
       
        epoch_loss.append(np.mean(train_loss[-20:]))
    plt.plot(np.arange(len(epoch_loss)), epoch_loss, 'r')
    plt.show()

訓(xùn)練完成后得到模型文件,在DIMmodel2文件夾下生成encoder100.pth和loss.pth网沾。

4.4 加載模型實(shí)現(xiàn)圖像搜索

import random

model_path = r'./DIMmodel2/encoder%d.pth' % (totalepoch)
encoder = Encoder().to(device)
encoder.load_state_dict(torch.load(model_path, map_location=device))

batchesimg, batchesenc = [], []
batch = tqdm(train_loader, total=len(train_dataset) // batch_size)

for images, target in batch:
    images = images.to(device)
    with torch.no_grad():
        encoded, features = encoder(images)
    batchesimg.append(images)
    batchesenc.append(encoded)

batchesenc = torch.cat(batchesenc, axis=0)
batchesimg = torch.cat(batchesimg, axis=0)

index = random.randrange(0, len(batchesenc))
batchesenc[index].repeat(len(batchesenc), 1)

l2_dis = F.mse_loss(batchesenc[index].repeat(len(batchesenc), 1),
                    batchesenc, reduction='none').sum(1)

findnum = 5   # 設(shè)置需要查找圖片的個(gè)數(shù)
_, indices = l2_dis.topk(findnum, largest=False) # 查找出5個(gè)最相似的圖片

indices = torch.cat([torch.tensor([index]).to(device), indices])

rel = batchesimg[indices]
imshowrow(rel.cpu(), nrow=len(indices))

從結(jié)果圖像可以看出癞蚕,查找出的最相似的5張圖片與查找的圖像是一樣的。通過最大化深度互信息模型實(shí)現(xiàn)的圖像搜索是有效的辉哥。大家可以修改數(shù)據(jù)集桦山,實(shí)現(xiàn)自己的圖片搜素系統(tǒng)。

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末醋旦,一起剝皮案震驚了整個(gè)濱河市恒水,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌饲齐,老刑警劉巖钉凌,帶你破解...
    沈念sama閱讀 221,820評論 6 515
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場離奇詭異捂人,居然都是意外死亡御雕,警方通過查閱死者的電腦和手機(jī)矢沿,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 94,648評論 3 399
  • 文/潘曉璐 我一進(jìn)店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來酸纲,“玉大人捣鲸,你說我怎么就攤上這事∶銎拢” “怎么了栽惶?”我有些...
    開封第一講書人閱讀 168,324評論 0 360
  • 文/不壞的土叔 我叫張陵,是天一觀的道長疾嗅。 經(jīng)常有香客問我外厂,道長,這世上最難降的妖魔是什么宪迟? 我笑而不...
    開封第一講書人閱讀 59,714評論 1 297
  • 正文 為了忘掉前任酣衷,我火速辦了婚禮,結(jié)果婚禮上次泽,老公的妹妹穿的比我還像新娘。我一直安慰自己席爽,他們只是感情好意荤,可當(dāng)我...
    茶點(diǎn)故事閱讀 68,724評論 6 397
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著只锻,像睡著了一般玖像。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發(fā)上齐饮,一...
    開封第一講書人閱讀 52,328評論 1 310
  • 那天捐寥,我揣著相機(jī)與錄音,去河邊找鬼祖驱。 笑死握恳,一個(gè)胖子當(dāng)著我的面吹牛,可吹牛的內(nèi)容都是我干的捺僻。 我是一名探鬼主播乡洼,決...
    沈念sama閱讀 40,897評論 3 421
  • 文/蒼蘭香墨 我猛地睜開眼,長吁一口氣:“原來是場噩夢啊……” “哼匕坯!你這毒婦竟也來了束昵?” 一聲冷哼從身側(cè)響起,我...
    開封第一講書人閱讀 39,804評論 0 276
  • 序言:老撾萬榮一對情侶失蹤葛峻,失蹤者是張志新(化名)和其女友劉穎锹雏,沒想到半個(gè)月后,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體术奖,經(jīng)...
    沈念sama閱讀 46,345評論 1 318
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡礁遵,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 38,431評論 3 340
  • 正文 我和宋清朗相戀三年匿辩,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片榛丢。...
    茶點(diǎn)故事閱讀 40,561評論 1 352
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡铲球,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出晰赞,到底是詐尸還是另有隱情稼病,我是刑警寧澤,帶...
    沈念sama閱讀 36,238評論 5 350
  • 正文 年R本政府宣布掖鱼,位于F島的核電站然走,受9級特大地震影響,放射性物質(zhì)發(fā)生泄漏戏挡。R本人自食惡果不足惜芍瑞,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,928評論 3 334
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望褐墅。 院中可真熱鬧拆檬,春花似錦、人聲如沸妥凳。這莊子的主人今日做“春日...
    開封第一講書人閱讀 32,417評論 0 24
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽逝钥。三九已至屑那,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間艘款,已是汗流浹背持际。 一陣腳步聲響...
    開封第一講書人閱讀 33,528評論 1 272
  • 我被黑心中介騙來泰國打工, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留哗咆,地道東北人蜘欲。 一個(gè)月前我還...
    沈念sama閱讀 48,983評論 3 376
  • 正文 我出身青樓,卻偏偏與公主長得像岳枷,于是被迫代替她去往敵國和親芒填。 傳聞我的和親對象是個(gè)殘疾皇子,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 45,573評論 2 359

推薦閱讀更多精彩內(nèi)容