Original work by: Yu Xiaolong
An image search system has two main parts: feature extraction and feature matching, with feature extraction being the main data-processing stage handled by the deep learning model. This article extracts features with an unsupervised method, Deep InfoMax (DIM, maximizing deep mutual information), and uses the resulting low-dimensional features to build an image search system.
1. Principles of the DIM Model
The DIM model is trained by computing the mutual information between an input sample and the feature vector the encoder produces for it, and then maximizing that mutual information. In unsupervised training, DIM imposes two kinds of constraints on representation learning:
(1) Maximize the mutual information between the input and the high-level feature vector: if the low-dimensional feature the model outputs truly represents the input sample, the mutual information between the feature distribution and the input distribution should be maximal.
(2) Adversarially match a prior distribution: the high-level features produced by the encoder should stay close to a Gaussian distribution, while a discriminator tries to tell the encoder's output distribution apart from that Gaussian.
In the implementation, the DIM model uses three discriminators, which constrain the encoder's output from three angles: maximizing local mutual information, maximizing global mutual information, and minimizing the divergence from the prior distribution.
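Written out, the combined objective (in the spirit of the DIM paper by Hjelm et al., sketched here rather than quoted exactly) is

maximize  alpha * I_global(X; E(X)) + (beta / M^2) * sum_i I_local(X^(i); E(X))  -  gamma * D_prior

where I_global and I_local are mutual-information lower bounds estimated by discriminators, D_prior is the prior-matching divergence, and alpha, beta, gamma are the three weights that appear under the same names in the code in Section 4.2.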
2. Local and Global Mutual-Information Maximization Constraints
Local features can be understood as the feature maps obtained after convolution, while the global feature is the feature vector obtained by encoding those feature maps. For images, correlation shows up mostly locally. Image recognition and classification work from the local to the global; in other words, global features are better suited to reconstruction, while local features are better suited to classification tasks. The DIM model therefore computes mutual information between input and output from both the local and the global perspective (using the estimator sketched below), while prior matching constrains the form of the vector the encoder produces so that it stays close to a Gaussian distribution.
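Both the local and the global term use the same Jensen-Shannon-style lower bound on mutual information. For a discriminator T, a positive pair (x, E(x)) and a negative pair (x', E(x)) built by mismatching an input with another sample's features, the bound is

I_JSD(X; E(X)) = E_P[ -softplus(-T(x, E(x))) ] - E_P'[ softplus(T(x', E(x))) ],  with softplus(z) = log(1 + e^z).

This is exactly what the Ej and Em terms compute in the DeepInfoMaxLoss class below.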
3. The Prior-Distribution Matching Minimization Constraint
The main idea of the DIM encoder is that, while encoding the input into a feature vector, we also want that vector to follow a standard Gaussian distribution. This makes the encoding space more regular and helps disentangle features, which benefits later learning. Concretely, the constraint is implemented as a GAN-style objective, as sketched below.
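With D the prior discriminator, z a sample drawn from the prior, and y = E(x) the encoder output, the term added to the total loss is

L_prior = -gamma * ( E_z[ log D(z) ] + E_x[ log(1 - D(y)) ] )

which corresponds to term_a and term_b in the code below. One caveat worth noting: the code samples the prior with torch.rand_like, i.e. from a uniform distribution on [0, 1), so the prior actually being matched in this implementation is uniform rather than Gaussian.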
4. Code Implementation
This article builds the image searcher on the Fashion-MNIST dataset. Each sample is a 28×28-pixel grayscale image; the training set holds 60,000 images and the test set 10,000. The labels span 10 classes: T-shirt, Trouser, Pullover, Dress, Coat, Sandal, Shirt, Sneaker, Bag, and Ankle boot.
4.1 Load and Display the Fashion-MNIST Dataset
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
from torchvision.datasets.mnist import FashionMNIST
from torch.optim import Adam
from matplotlib import pyplot as plt
import numpy as np
from tqdm import tqdm
from pathlib import Path
from torchvision.transforms import ToPILImage
import os
os.environ["CUDA_VISIBLE_DEVICES"] = '1,2,3'  # GPUs visible to this process (comma-separated, no spaces)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
batch_size = 256
data_dir = r'./fashion_mnist/'
train_dataset = FashionMNIST(data_dir, download=True, transform=ToTensor())
train_loader = DataLoader(train_dataset, batch_size=batch_size,
                          shuffle=True, drop_last=True,
                          pin_memory=torch.cuda.is_available())
print('train:', len(train_dataset))
def imshowrow(imgs, nrow):
    # Display a batch of images in a single row using a torchvision grid.
    plt.figure(dpi=200)
    _img = ToPILImage()(torchvision.utils.make_grid(imgs, nrow=nrow))
    plt.axis('off')
    plt.imshow(_img)
    plt.show()
classes = ('T-shirt', 'Trouser', 'Pullover', 'Dress', 'Coat',
           'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle_Boot')
sample = iter(train_loader)
images, labels = next(sample)  # the .next() method was removed in Python 3; use the built-in next()
print('sample shape:', np.shape(images))
print('sample label:', ','.join('%2d:%-5s' % (labels[j], classes[labels[j]])
                                for j in range(len(images[:10]))))
imshowrow(images[:10], nrow=10)
4.2 Implement the DIM Model
Define the encoder class Encoder and the discriminator class DeepInfoMaxLoss.
Encoder: encodes the input through a stack of convolutional layers and outputs a 64-dimensional feature vector.
DeepInfoMaxLoss: implements the three discriminator structures (global, local, and prior) and merges their losses into the total loss.
class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.c0 = nn.Conv2d(1, 64, kernel_size=4, stride=1)
        self.c1 = nn.Conv2d(64, 128, kernel_size=4, stride=1)
        self.c2 = nn.Conv2d(128, 256, kernel_size=4, stride=1)
        self.c3 = nn.Conv2d(256, 512, kernel_size=4, stride=1)
        self.l1 = nn.Linear(512 * 16 * 16, 64)
        self.b1 = nn.BatchNorm2d(128)
        self.b2 = nn.BatchNorm2d(256)
        self.b3 = nn.BatchNorm2d(512)

    def forward(self, x):
        # x: torch.Size([256, 1, 28, 28])
        h = F.relu(self.c0(x))
        # h: torch.Size([256, 64, 25, 25])
        features = F.relu(self.b1(self.c1(h)))
        # features: torch.Size([256, 128, 22, 22]) -- the local feature map M
        h = F.relu(self.b2(self.c2(features)))
        # h: torch.Size([256, 256, 19, 19])
        h = F.relu(self.b3(self.c3(h)))
        # h: torch.Size([256, 512, 16, 16])
        encoder = self.l1(h.view(x.shape[0], -1))  # the global feature vector y
        return encoder, features
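Before moving on to the loss, a quick sanity check (not part of the original article; the underscore-prefixed names are made up for illustration) confirms the shapes noted in the comments above:

# Hypothetical smoke test: push a small dummy batch through the encoder.
_enc = Encoder()
_y, _M = _enc(torch.randn(4, 1, 28, 28))
print(_y.shape)  # torch.Size([4, 64])          -- global feature vector y
print(_M.shape)  # torch.Size([4, 128, 22, 22]) -- local feature map M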
class DeepInfoMaxLoss(nn.Module):
    def __init__(self, alpha=0.5, beta=1.0, gamma=0.1):
        super().__init__()
        self.alpha = alpha   # weight of the global MI term
        self.beta = beta     # weight of the local MI term
        self.gamma = gamma   # weight of the prior matching term
        # Local discriminator: 1x1 convolutions over the concatenation of the
        # local feature map (128 channels) and the tiled global vector (64 channels).
        self.local_d = nn.Sequential(
            nn.Conv2d(192, 512, kernel_size=1),
            nn.ReLU(True),
            nn.Conv2d(512, 512, kernel_size=1),
            nn.ReLU(True),
            nn.Conv2d(512, 1, kernel_size=1)
        )
        # Prior discriminator: tells prior samples apart from encoder outputs.
        self.prior_d = nn.Sequential(
            nn.Linear(64, 1000),
            nn.ReLU(True),
            nn.Linear(1000, 200),
            nn.ReLU(True),
            nn.Linear(200, 1),
            nn.Sigmoid()
        )
        # Global discriminator: convolves the feature map, then scores it
        # together with the global vector through fully connected layers.
        self.global_d_M = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3),
            nn.ReLU(True),
            nn.Conv2d(64, 32, kernel_size=3),
            nn.Flatten()
        )
        self.global_d_fc = nn.Sequential(
            nn.Linear(32 * 18 * 18 + 64, 512),
            nn.ReLU(True),
            nn.Linear(512, 512),
            nn.ReLU(True),
            nn.Linear(512, 1)
        )

    def GlobalD(self, y, M):
        h = self.global_d_M(M)
        h = torch.cat((y, h), dim=1)
        return self.global_d_fc(h)

    def forward(self, y, M, M_prime):
        # Tile the global vector so it can be concatenated with the 22x22 feature map.
        y_exp = y.unsqueeze(-1).unsqueeze(-1)           # torch.Size([256, 64, 1, 1])
        y_exp = y_exp.expand(-1, -1, 22, 22)            # torch.Size([256, 64, 22, 22])
        y_M = torch.cat((M, y_exp), dim=1)              # positive pairs: [256, 192, 22, 22]
        y_M_prime = torch.cat((M_prime, y_exp), dim=1)  # negative pairs: [256, 192, 22, 22]
        # JSD-style mutual-information estimate: Ej scores positive pairs and
        # Em scores negative pairs; minimizing (Em - Ej) maximizes the MI bound.
        Ej = -F.softplus(-self.local_d(y_M)).mean()
        Em = F.softplus(self.local_d(y_M_prime)).mean()
        Local = (Em - Ej) * self.beta
        Ej = -F.softplus(-self.GlobalD(y, M)).mean()
        Em = F.softplus(self.GlobalD(y, M_prime)).mean()
        Global = (Em - Ej) * self.alpha
        # GAN-style prior matching. Note: torch.rand_like samples from U(0, 1),
        # so the prior being matched here is uniform.
        prior = torch.rand_like(y)
        term_a = torch.log(self.prior_d(prior)).mean()
        term_b = torch.log(1.0 - self.prior_d(y)).mean()
        Prior = -(term_a + term_b) * self.gamma
        return Local + Global + Prior
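As a minimal sketch of how the three inputs fit together (again a sanity check of my own, not code from the article): M_prime is simply the batch of feature maps shifted by one position, so every global vector y gets paired with a feature map from a different image as its negative sample.

# Hypothetical smoke test for DeepInfoMaxLoss on random tensors.
_loss_fn = DeepInfoMaxLoss()
_y = torch.randn(4, 64)             # global feature vectors
_M = torch.randn(4, 128, 22, 22)    # local feature maps (positive pairs)
_M_prime = torch.cat((_M[1:], _M[0].unsqueeze(0)), dim=0)  # shifted batch: negative pairs
print(_loss_fn(_y, _M, _M_prime))   # a scalar total loss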
4.3 Instantiate and Train the Model
totalepoch = 100
if __name__ == '__main__':
    encoder = Encoder().to(device)
    loss_fn = DeepInfoMaxLoss().to(device)
    optim = Adam(encoder.parameters(), lr=1e-4)
    loss_optim = Adam(loss_fn.parameters(), lr=1e-4)
    epoch_loss = []
    for epoch in range(totalepoch + 1):
        batch = tqdm(train_loader, total=len(train_dataset) // batch_size)
        train_loss = []
        for x, target in batch:
            x = x.to(device)
            optim.zero_grad()
            loss_optim.zero_grad()
            y, M = encoder(x)
            # Shift the batch by one position so each y is paired with a
            # feature map from a different image: these are the negative samples.
            M_prime = torch.cat((M[1:], M[0].unsqueeze(0)), dim=0)
            loss = loss_fn(y, M, M_prime)
            train_loss.append(loss.item())
            batch.set_description(
                str(epoch) + ' Loss:%.4f' % np.mean(train_loss[-20:]))
            loss.backward()
            optim.step()       # update the encoder
            loss_optim.step()  # update the three discriminators
        if epoch % 10 == 0:
            root = Path(r'./DIMmodel2/')
            enc_file = root / Path('encoder' + str(epoch) + '.pth')
            loss_file = root / Path('loss' + str(epoch) + '.pth')
            enc_file.parent.mkdir(parents=True, exist_ok=True)
            torch.save(encoder.state_dict(), str(enc_file))
            torch.save(loss_fn.state_dict(), str(loss_file))
            epoch_loss.append(np.mean(train_loss[-20:]))
    plt.plot(np.arange(len(epoch_loss)), epoch_loss, 'r')
    plt.show()
After training, checkpoint files appear in the DIMmodel2 folder, one pair every 10 epochs, ending with encoder100.pth and loss100.pth.
4.4 Load the Model and Run the Image Search
import random

model_path = r'./DIMmodel2/encoder%d.pth' % (totalepoch)
encoder = Encoder().to(device)
encoder.load_state_dict(torch.load(model_path, map_location=device))
encoder.eval()  # put the BatchNorm layers into evaluation mode for inference

# Encode the whole training set.
batchesimg, batchesenc = [], []
batch = tqdm(train_loader, total=len(train_dataset) // batch_size)
for images, target in batch:
    images = images.to(device)
    with torch.no_grad():
        encoded, features = encoder(images)
    batchesimg.append(images)
    batchesenc.append(encoded)
batchesenc = torch.cat(batchesenc, dim=0)
batchesimg = torch.cat(batchesimg, dim=0)

# Pick a random query and rank all images by squared L2 distance in feature space.
index = random.randrange(0, len(batchesenc))
l2_dis = F.mse_loss(batchesenc[index].repeat(len(batchesenc), 1),
                    batchesenc, reduction='none').sum(1)
findnum = 5  # number of similar images to retrieve
_, indices = l2_dis.topk(findnum, largest=False)  # the 5 most similar images
indices = torch.cat([torch.tensor([index]).to(device), indices])
rel = batchesimg[indices]
imshowrow(rel.cpu(), nrow=len(indices))
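The query above compares a single encoding against all others with an elementwise squared error. To run several queries at once, an equivalent formulation with torch.cdist is possible; the following is a sketch of an alternative, not the article's code:

# Hypothetical batched variant: Euclidean distances for several queries at once.
queries = batchesenc[:3]                  # e.g. take the first three encodings as queries
dists = torch.cdist(queries, batchesenc)  # shape [3, N]; same ranking as squared L2
_, nearest = dists.topk(findnum, largest=False, dim=1)
print(nearest)                            # per query: indices of the 5 closest images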
As the resulting image shows, the five most similar images retrieved are of the same kind as the query image, so image search built on maximizing deep mutual information is effective. You can swap in your own dataset and build your own image search system.