Model
Four variants
- CNN-rand: word vectors are randomly initialized and updated as the model trains
- CNN-static: pre-trained word vectors are used and kept fixed during training
- CNN-non-static: pre-trained word vectors are used and fine-tuned during training
- CNN-multichannel: static + fine-tuned. Both channels are initialized from the same pre-trained vectors; every convolution filter is applied to both channels, but backpropagation updates only one of them (see the sketch after this list)
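The four variants differ only in how the embedding layer is configured. A minimal sketch of my own (not from the original code), where `pretrained` stands in for a `(vocab_size, embed_dim)` tensor of word2vec/GloVe vectors:

```python
import torch
import torch.nn as nn

vocab_size, embed_dim = 1000, 100
pretrained = torch.randn(vocab_size, embed_dim)  # placeholder for real pre-trained vectors

# CNN-rand: random initialization, trained together with the model
emb_rand = nn.Embedding(vocab_size, embed_dim)

# CNN-static: pre-trained vectors, kept frozen
emb_static = nn.Embedding.from_pretrained(pretrained, freeze=True)

# CNN-non-static: pre-trained vectors, fine-tuned during training
emb_non_static = nn.Embedding.from_pretrained(pretrained, freeze=False)

# CNN-multichannel: two copies of the pre-trained vectors;
# gradients update only the first copy
emb_tuned = nn.Embedding.from_pretrained(pretrained, freeze=False)
emb_frozen = nn.Embedding.from_pretrained(pretrained, freeze=True)
```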
Code
```python
if args.static:  # use pre-trained static word vectors
    args.embedding_dim = text_field.vocab.vectors.size()[-1]
    args.vectors = text_field.vocab.vectors
if args.multichannel:
    args.static = True
    args.non_static = True
# args.class_num = len(label_field.vocab)
args.class_num = len(label_field.vocab) - 1  # the label vocab has an extra '<unk>' entry
```
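For context, `text_field.vocab.vectors` and `label_field.vocab` come from the legacy torchtext 0.x `Field`/`Dataset` API (available as `torchtext.legacy.data` in 0.9–0.11 and removed afterwards). A tiny in-memory sketch, assuming that API, shows why the label vocabulary carries one extra entry:

```python
from torchtext import data  # legacy API; use torchtext.legacy.data on 0.9-0.11

text_field = data.Field(lower=True, batch_first=True)
label_field = data.Field(sequential=False)
fields = [('text', text_field), ('label', label_field)]

examples = [
    data.Example.fromlist(['a great movie', 'positive'], fields),
    data.Example.fromlist(['a boring movie', 'negative'], fields),
]
train = data.Dataset(examples, fields)

# passing vectors="glove.6B.100d" here would also populate text_field.vocab.vectors
text_field.build_vocab(train)
label_field.build_vocab(train)

print(len(label_field.vocab))  # 3: '<unk>' plus the two real labels
print(label_field.vocab.itos)  # e.g. ['<unk>', 'negative', 'positive']
```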
The model definition:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TextCNN(nn.Module):
    def __init__(self, args):
        super(TextCNN, self).__init__()
        self.args = args

        class_num = args.class_num
        channel_num = 1
        filter_num = args.filter_num
        filter_sizes = args.filter_sizes
        vocabulary_size = args.vocabulary_size
        embedding_dimension = args.embedding_dim

        self.embedding = nn.Embedding(vocabulary_size, embedding_dimension)
        if args.static:
            # pre-trained vectors; frozen unless non_static (fine-tuning) is requested
            self.embedding = nn.Embedding.from_pretrained(args.vectors, freeze=not args.non_static)
        if args.multichannel:
            # multichannel: non_static=True and static=True
            # channel 1 (self.embedding): fine-tuned; channel 2 (self.embedding2): static
            self.embedding2 = nn.Embedding.from_pretrained(args.vectors)
            channel_num += 1
        else:
            self.embedding2 = None

        # ModuleList is a special Module that holds sub-modules and can be indexed
        # like a Python list, but input cannot be passed to the ModuleList itself.
        # Each Conv2d maps (N, channel_num, seq_len, embedding_dimension)
        # => (N, filter_num, seq_len - size + 1, 1).
        self.convs = nn.ModuleList(
            [nn.Conv2d(channel_num, filter_num, (size, embedding_dimension)) for size in filter_sizes])
        self.dropout = nn.Dropout(args.dropout)
        self.fc = nn.Linear(len(filter_sizes) * filter_num, class_num)

    def forward(self, x):
        if self.embedding2 is not None:
            # two channels: (N, seq_len) => (N, 2, seq_len, embedding_dimension)
            x = torch.stack([self.embedding(x), self.embedding2(x)], dim=1)
        else:
            x = self.embedding(x)
            # unsqueeze adds a size-1 dimension at the given position:
            # (N, seq_len, embedding_dimension) => (N, 1, seq_len, embedding_dimension)
            x = x.unsqueeze(1)
        # convolve, then squeeze away the width dimension (which is now 1):
        # (N, filter_num, seq_len - size + 1, 1) => (N, filter_num, seq_len - size + 1)
        x = [F.relu(conv(x)).squeeze(3) for conv in self.convs]
        # max-pool over time, then squeeze the remaining size-1 dimension:
        # (N, filter_num, 1) => (N, filter_num)
        x = [F.max_pool1d(item, item.size(2)).squeeze(2) for item in x]
        # concatenate the outputs of all filter sizes: 3 sizes, filter_num (e.g. 100) values each
        x = torch.cat(x, 1)
        x = self.dropout(x)
        logits = self.fc(x)
        return logits
```
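A quick smoke test of the model above, using a hypothetical `argparse.Namespace` in place of the script's real argument parsing, and random token indices as input:

```python
import argparse
import torch

# hypothetical settings; the real script fills args via argparse and torchtext
args = argparse.Namespace(
    class_num=2, filter_num=100, filter_sizes=[3, 4, 5],
    vocabulary_size=1000, embedding_dim=100, dropout=0.5,
    static=False, non_static=False, multichannel=False, vectors=None,
)

model = TextCNN(args)
batch = torch.randint(0, args.vocabulary_size, (8, 20))  # (batch_size, seq_len)
logits = model(batch)
print(logits.shape)  # torch.Size([8, 2]): one score per class
```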
Questions
- Why target = target.data.sub(1)?
- Why is len(label_field.vocab) == 3 for a two-class task?
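Both questions seem to have the same answer: the legacy torchtext label vocabulary reserves index 0 for '<unk>', so the two real classes receive indices 1 and 2. Subtracting 1 maps them onto 0 and 1, which matches class_num = len(label_field.vocab) - 1 = 2. A minimal sketch of my own, with random logits standing in for model output:

```python
import torch
import torch.nn.functional as F

# label indices as produced by the label vocab: 0 is '<unk>', real classes are 1 and 2
target = torch.tensor([1, 2, 2, 1])
target = target.sub(1)              # shift to {0, 1}, matching class_num = 3 - 1 = 2
logits = torch.randn(4, 2)          # hypothetical model output for a batch of 4
loss = F.cross_entropy(logits, target)
print(target.tolist(), loss.item())
```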