github代碼地址:https://github.com/aditya12agd5/convcap
論文:Convolutional Image Captioning
該網(wǎng)絡(luò)簡(jiǎn)單地說(shuō)就是使用VGG16提取特征，通過Attention+CNN（掩碼卷積解碼器）進(jìn)行語(yǔ)句生成的端到端網(wǎng)絡(luò)。不多說(shuō)了，先上網(wǎng)絡(luò)總體結(jié)構(gòu)圖。
論文讓我懵逼，我還是從代碼說(shuō)起吧。
1.特征提取網(wǎng)絡(luò)VGG16
特征提取模塊就是一個(gè)VGG16.
vggfeats.py
import torch
import torch.nn as nn
from torchvision import models
from torch.autograd import Variable
# Load an ImageNet-pretrained VGG-16 once at import time; Vgg16Feats below
# slices its layers. NOTE(review): this downloads weights on first use and
# runs as a module-level side effect.
pretrained_model = models.vgg16(pretrained=True)
class Vgg16Feats(nn.Module):
    """Wrap pretrained VGG-16 to expose two feature views of an image batch.

    forward() returns:
      - the pooled convolutional feature map (output of VGG's last max-pool),
      - the 4096-d "fc7" vector (classifier with its final 1000-way layer
        removed).
    """

    def __init__(self):
        super(Vgg16Feats, self).__init__()
        conv_layers = list(pretrained_model.features.children())
        # Every conv/pool layer except the final max-pool.
        self.features_nopool = nn.Sequential(*conv_layers[:-1])
        # Keep the final max-pool separate so the pooled map is accessible.
        self.features_pool = conv_layers[-1]
        # Classifier minus its last layer -> 4096-d feature instead of logits.
        self.classifier = nn.Sequential(*list(pretrained_model.classifier.children())[:-1])

    def forward(self, x):
        """Return (pooled conv feature map, fc7 feature vector) for images x."""
        pre_pool = self.features_nopool(x)
        pooled = self.features_pool(pre_pool)
        flattened = pooled.view(pooled.size(0), -1)
        fc7 = self.classifier(flattened)
        return pooled, fc7
2.convcap主體網(wǎng)絡(luò)
我繪制的convcap主體網(wǎng)絡(luò)流程圖，很難看。
convcap主體網(wǎng)絡(luò)流程手稿:
attention流程手稿
convcap.py
# -*- coding: utf-8 -*-
import sys
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
#Layers adapted for captioning from https://arxiv.org/abs/1705.03122
def Conv1d(in_channels, out_channels, kernel_size, padding, dropout=0):
    """Build a weight-normalized 1-D convolution with conv-seq2seq init.

    Weights are drawn from N(0, std) with std chosen to preserve variance
    under dropout (https://arxiv.org/abs/1705.03122); bias starts at zero.
    """
    conv = nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding)
    init_std = math.sqrt(4.0 * (1.0 - dropout) / (kernel_size * in_channels))
    conv.weight.data.normal_(mean=0, std=init_std)
    conv.bias.data.zero_()
    return nn.utils.weight_norm(conv)
def Embedding(num_embeddings, embedding_dim, padding_idx):
    """Build an embedding table initialized N(0, 0.1) with a zero padding row.

    Fix: nn.Embedding zeroes the padding row at construction, but the
    subsequent normal_() overwrote it, so the padding token embedded to a
    random vector that never trains (its gradient is masked by padding_idx).
    Re-zero the row so padding truly embeds to the zero vector.
    """
    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    m.weight.data.normal_(0, 0.1)
    m.weight.data[padding_idx].zero_()
    return m
def Linear(in_features, out_features, dropout=0.):
    """Build a weight-normalized fully-connected layer.

    Weights use the conv-seq2seq scaled normal init (variance shrinks with
    fan-in, grows with retained units under dropout); bias starts at zero.
    """
    layer = nn.Linear(in_features, out_features)
    init_std = math.sqrt((1 - dropout) / in_features)
    layer.weight.data.normal_(mean=0, std=init_std)
    layer.bias.data.zero_()
    return nn.utils.weight_norm(layer)
# Attention layer: lets every decoder position attend over the spatial
# image feature map (adapted from https://arxiv.org/abs/1705.03122).
class AttentionLayer(nn.Module):
    def __init__(self, conv_channels, embed_dim):
        super(AttentionLayer, self).__init__()
        # Project the decoder state into embedding space and back.
        self.in_projection = Linear(conv_channels, embed_dim)
        self.out_projection = Linear(embed_dim, conv_channels)
        self.bmm = torch.bmm

    def forward(self, x, wordemb, imgsfeats):
        """Attend over spatial image features.

        x:         (batch, seq_len, conv_channels) decoder states
        wordemb:   (batch, seq_len, embed_dim) word embeddings
        imgsfeats: (batch, channels, h, w) conv feature map
                   (channels must equal embed_dim for the bmm below)

        Returns (attended states with residual, attention scores).
        """
        residual = x
        # Combine projected state with the word embedding; sqrt(0.5) keeps
        # the variance of the sum stable, as in conv-seq2seq.
        x = (self.in_projection(x) + wordemb) * math.sqrt(0.5)
        b, c, f_h, f_w = imgsfeats.size()
        # Flatten the spatial grid: (b, c, h*w).
        y = imgsfeats.view(b, c, f_h * f_w)
        # Attention logits: one score per (word position, spatial location).
        x = self.bmm(x, y)
        sz = x.size()
        # Softmax over spatial locations. Fix: pass dim explicitly — the
        # implicit-dim form of F.softmax is deprecated (and an error on
        # modern PyTorch); dim=1 on the flattened 2-D view matches the
        # original behavior exactly.
        x = F.softmax(x.view(sz[0] * sz[1], sz[2]), dim=1)
        x = x.view(sz)
        attn_scores = x
        # (b, h*w, c) so the weighted sum is over spatial locations.
        y = y.permute(0, 2, 1)
        x = self.bmm(x, y)
        s = y.size(1)
        # Scale by sqrt(s): attention averages s values, so this restores
        # their variance (conv-seq2seq trick).
        x = x * (s * math.sqrt(1.0 / s))
        # Residual connection, again variance-normalized.
        x = (self.out_projection(x) + residual) * math.sqrt(0.5)
        return x, attn_scores
class convcap(nn.Module):
  """Convolutional captioning decoder (https://arxiv.org/abs/1705.03122 style).

  Consumes the VGG-16 conv feature map (`imgsfeats`), the 4096-d fc7 vector
  (`imgsfc7`) and teacher-forced word indices (`wordclass`), and produces
  per-timestep word-class logits, optionally with per-layer attention.
  """
  def __init__(self, num_wordclass, num_layers=1, is_attention=True, nfeats=512, dropout=.1):
    super(convcap, self).__init__()
    # Uses the 4096-d fc7 (fully-connected) feature of VGG-16.
    self.nimgfeats = 4096
    self.is_attention = is_attention
    # Feature dimension per word position.
    self.nfeats = nfeats
    # Dropout probability (default 10%).
    self.dropout = dropout
    # Word embedding table; index 0 is the padding token.
    self.emb_0 = Embedding(num_wordclass, nfeats, padding_idx=0)
    # Linear layer applied on top of the raw embeddings (nfeats -> nfeats).
    self.emb_1 = Linear(nfeats, nfeats, dropout=dropout)
    # Projects the 4096-d image feature down to the word-feature size.
    self.imgproj = Linear(self.nimgfeats, self.nfeats, dropout=dropout)
    # Projects the concatenated (word, image) feature (2*nfeats -> nfeats)
    # to form the residual of the first decoder layer.
    self.resproj = Linear(nfeats*2, self.nfeats, dropout=dropout)
    n_in = 2*self.nfeats
    n_out = self.nfeats
    self.n_layers = num_layers
    # Stacks of causal 1-D convolutions and (optionally) attention layers.
    self.convs = nn.ModuleList()
    self.attention = nn.ModuleList()
    # Convolution kernel size over the time (token) axis.
    self.kernel_size = 5
    # Symmetric padding of kernel_size-1; forward() trims the right side so
    # each position only sees current and past tokens (causal masking).
    self.pad = self.kernel_size - 1
    for i in range(self.n_layers):
      # Each conv outputs 2*n_out channels, halved back by the GLU.
      self.convs.append(Conv1d(n_in, 2*n_out, self.kernel_size, self.pad, dropout))
      if(self.is_attention):
        self.attention.append(AttentionLayer(n_out, nfeats))
      n_in = n_out
    # Final two linear layers map features to word-class scores.
    self.classifier_0 = Linear(self.nfeats, (nfeats // 2))
    self.classifier_1 = Linear((nfeats // 2), num_wordclass, dropout=dropout)
  def forward(self, imgsfeats, imgsfc7, wordclass):
    """Return (logits of shape (batch, num_wordclass, maxtokens), attention).

    attention is None when is_attention is False; otherwise it holds the
    attention scores of the last decoder layer.
    """
    attn_buffer = None
    # Embed the token indices: (batch, maxtokens, nfeats).
    wordemb = self.emb_0(wordclass)
    # Extra linear transform of the embeddings.
    wordemb = self.emb_1(wordemb)
    # Channels-first for Conv1d: (batch, nfeats, maxtokens).
    x = wordemb.transpose(2, 1)
    batchsize, wordembdim, maxtokens = x.size()
    # Project fc7 to nfeats and broadcast it to every token position:
    # (batch, nfeats, maxtokens).
    y = F.relu(self.imgproj(imgsfc7))
    y = y.unsqueeze(2).expand(batchsize, self.nfeats, maxtokens)
    # Concatenate word and image features along channels: (batch, 2*nfeats, maxtokens).
    x = torch.cat([x, y], 1)
    for i, conv in enumerate(self.convs):
      if(i == 0):
        # First layer: residual must be projected from 2*nfeats to nfeats.
        # resproj is a Linear, so flip to channels-last and back.
        x = x.transpose(2, 1)
        residual = self.resproj(x)
        residual = residual.transpose(2, 1)
        x = x.transpose(2, 1)
      else:
        residual = x
      x = F.dropout(x, p=self.dropout, training=self.training)
      # Padded 1-D convolution over time.
      x = conv(x)
      # Trim the right padding -> causal: position t sees only tokens <= t.
      x = x[:,:,:-self.pad]
      # Gated linear unit halves the channel count back to nfeats.
      x = F.glu(x, dim=1)
      if(self.is_attention):
        attn = self.attention[i]
        x = x.transpose(2, 1)
        # x: decoder state; wordemb: word embeddings; imgsfeats: spatial
        # conv feature map the layer attends over.
        x, attn_buffer = attn(x, wordemb, imgsfeats)
        x = x.transpose(2, 1)
      # Residual connection, variance-normalized by sqrt(0.5).
      x = (x+residual)*math.sqrt(.5)
    # Classifier operates channels-last: (batch, maxtokens, nfeats).
    x = x.transpose(2, 1)
    x = self.classifier_0(x)
    x = F.dropout(x, p=self.dropout, training=self.training)
    x = self.classifier_1(x)
    # Back to (batch, num_wordclass, maxtokens) for cross-entropy.
    x = x.transpose(2, 1)
    return x, attn_buffer
3.訓(xùn)練
train.py
僅為部分代碼
# Training loop (partial excerpt from train.py). Indentation restored;
# names such as batchsize, model_imgcnn, args, etc. are defined earlier
# in train.py, outside this excerpt.
for batch_idx, (imgs, captions, wordclass, mask, _) in \
  tqdm(enumerate(train_data_loader), total=nbatches):
  imgs = imgs.view(batchsize, 3, 224, 224)
  wordclass = wordclass.view(batchsize_cap, max_tokens)
  mask = mask.view(batchsize_cap, max_tokens)
  imgs_v = Variable(imgs).cuda()
  wordclass_v = Variable(wordclass).cuda()
  optimizer.zero_grad()
  if(img_optimizer):
    img_optimizer.zero_grad()
  # Extract image features (conv map + fc7) with the CNN.
  imgsfeats, imgsfc7 = model_imgcnn(imgs_v)
  # Repeat each image's features once per caption of that image.
  imgsfeats, imgsfc7 = repeat_img_per_cap(imgsfeats, imgsfc7, ncap_per_img)
  _, _, feat_h, feat_w = imgsfeats.size()
  # Run the convcap decoder to get word activations (and attention).
  if(args.attention == True):
    wordact, attn = model_convcap(imgsfeats, imgsfc7, wordclass_v)
    attn = attn.view(batchsize_cap, max_tokens, feat_h, feat_w)
  else:
    wordact, _ = model_convcap(imgsfeats, imgsfc7, wordclass_v)
  # Align predictions with targets: drop the last prediction and the
  # first (start) token of the ground truth / mask.
  wordact = wordact[:,:,:-1]
  wordclass_v = wordclass_v[:,1:]
  mask = mask[:,1:].contiguous()
  wordact_t = wordact.permute(0, 2, 1).contiguous().view(\
    batchsize_cap*(max_tokens-1), -1)
  wordclass_t = wordclass_v.contiguous().view(\
    batchsize_cap*(max_tokens-1), 1)
  # Flat indices of the non-padding (meaningful) target positions.
  maskids = torch.nonzero(mask.view(-1)).numpy().reshape(-1)
  if(args.attention == True):
    # Cross-entropy loss plus a regularizer pushing each spatial
    # location's attention, summed over time steps (dim 1), toward 1.
    loss = F.cross_entropy(wordact_t[maskids, ...], \
      wordclass_t[maskids, ...].contiguous().view(maskids.shape[0])) \
      + (torch.sum(torch.pow(1. - torch.sum(attn, 1), 2)))\
      /(batchsize_cap*feat_h*feat_w)
  else:
    loss = F.cross_entropy(wordact_t[maskids, ...], \
      wordclass_t[maskids, ...].contiguous().view(maskids.shape[0]))