1. Object Detection Basics
Generating bounding boxes
# bbox is short for bounding box
dog_bbox, cat_bbox = [60, 45, 378, 516], [400, 112, 655, 493]
def bbox_to_rect(bbox, color):
    # Convert a bounding box in (top-left x, top-left y, bottom-right x, bottom-right y)
    # format into the matplotlib format: ((top-left x, top-left y), width, height)
    return d2l.plt.Rectangle(
        xy=(bbox[0], bbox[1]), width=bbox[2]-bbox[0], height=bbox[3]-bbox[1],
        fill=False, edgecolor=color, linewidth=2)
What is an anchor box?
Detection methods sample a large number of regions in the input image and adjust their edges so that they predict the ground-truth bounding boxes of targets more accurately. Anchor boxes are one such sampling scheme: centered on each pixel, we generate multiple boxes with different sizes and aspect ratios, and these boxes are called anchor boxes. Object detection can then be built on top of them. Anchors are usually chosen in one of three ways: by hand from experience, by k-means clustering, or learned as hyperparameters.
Generating multiple anchor boxes at once
For an image with height $h$ and width $w$, we generate anchor boxes of different shapes centered on every pixel. Given a size $s \in (0, 1]$ and an aspect ratio $r > 0$, the anchor box has width $ws\sqrt{r}$ and height $hs/\sqrt{r}$.
Given a set of sizes $s_1, \dots, s_n$ and a set of aspect ratios $r_1, \dots, r_m$: if we generated one anchor per size-ratio combination at every pixel, we would end up with $whnm$ boxes, which is far too expensive. So in practice we only keep the combinations containing $s_1$ or $r_1$, i.e. we generate
$(s_1, r_1), (s_1, r_2), \dots, (s_1, r_m), (s_2, r_1), (s_3, r_1), \dots, (s_n, r_1),$
that is, $n + m - 1$ anchors per pixel, so the image ends up with $wh(n + m - 1)$ anchor boxes in total.
The algorithm below generates multiple anchors at once. feature_map is the image tensor with four dimensions N, C, H, W: batch size, channels, height, and width; sizes is a list of anchor sizes relative to the image, and ratios is a list of aspect ratios. We first build the list of (size, sqrt(ratio)) pairs, n + m - 1 in total, then compute the half-extents ss1/2 and ss2/2 relative to each anchor center, and finally combine them with the grid of center coordinates.
The return value is a three-dimensional tensor: the first dimension is 1 (the number of images), the second is the number of anchors, and the third has length 4, the coordinates of the two corner points (xmin, ymin, xmax, ymax).
def MultiBoxPrior(feature_map, sizes=[0.75, 0.5, 0.25], ratios=[1, 2, 0.5]):
    """
    # Implemented as described in "9.4.1. Generating Multiple Anchor Boxes";
    # each anchor is represented as (xmin, ymin, xmax, ymax).
    https://zh.d2l.ai/chapter_computer-vision/anchor.html
    Args:
        feature_map: torch tensor, Shape: [N, C, H, W].
        sizes: List of sizes (0~1) of generated MultiBoxPriores.
        ratios: List of aspect ratios (non-negative) of generated MultiBoxPriores.
    Returns:
        anchors of shape (1, num_anchors, 4). Identical for every image in the batch, hence the leading 1.
    """
    pairs = []  # pairs of (size, sqrt(ratio))
    # generate the n + m - 1 base boxes
    for r in ratios:
        pairs.append([sizes[0], math.sqrt(r)])
    for s in sizes[1:]:
        pairs.append([s, math.sqrt(ratios[0])])
    pairs = np.array(pairs)
    # extents of the base boxes relative to their centers, as (x, y, x, y)
    ss1 = pairs[:, 0] * pairs[:, 1]  # size * sqrt(ratio), i.e. the relative width
    ss2 = pairs[:, 0] / pairs[:, 1]  # size / sqrt(ratio), i.e. the relative height
    base_anchors = np.stack([-ss1, -ss2, ss1, ss2], axis=1) / 2
    # combine the grid of centers with the base anchors: h*w*(n+m-1) boxes in total
    h, w = feature_map.shape[-2:]
    shifts_x = np.arange(0, w) / w
    shifts_y = np.arange(0, h) / h
    shift_x, shift_y = np.meshgrid(shifts_x, shifts_y)
    shift_x = shift_x.reshape(-1)
    shift_y = shift_y.reshape(-1)
    shifts = np.stack((shift_x, shift_y, shift_x, shift_y), axis=1)
    anchors = shifts.reshape((-1, 1, 4)) + base_anchors.reshape((1, -1, 4))
    return torch.tensor(anchors, dtype=torch.float32).view(1, -1, 4)
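As a quick sanity check (assuming torch, numpy as np, and math are imported, since the function uses all three), a hypothetical 561x728 input yields 561*728*(3+3-1) anchors:

X = torch.rand(1, 3, 561, 728)  # dummy batch: N=1, C=3, H=561, W=728
Y = MultiBoxPrior(X, sizes=[0.75, 0.5, 0.25], ratios=[1, 2, 0.5])
print(Y.shape)  # torch.Size([1, 2042040, 4])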
A helper to visualize the boxes. Here the private function _make_list wraps any non-list argument into a list:
def show_bboxes(axes, bboxes, labels=None, colors=None):
    def _make_list(obj, default_values=None):
        if obj is None:
            obj = default_values
        elif not isinstance(obj, (list, tuple)):
            obj = [obj]
        return obj
    labels = _make_list(labels)
    colors = _make_list(colors, ['b', 'g', 'r', 'm', 'c'])
    for i, bbox in enumerate(bboxes):
        color = colors[i % len(colors)]
        rect = bbox_to_rect(bbox.detach().cpu().numpy(), color)
        axes.add_patch(rect)
        if labels and len(labels) > i:
            text_color = 'k' if color == 'w' else 'w'
            axes.text(rect.xy[0], rect.xy[1], labels[i],
                      va='center', ha='center', fontsize=6, color=text_color,
                      bbox=dict(facecolor=color, lw=0))
Intersection over Union (IoU)
Suppose an anchor box covers the target fairly well; how do we quantify "how well"? We introduce the intersection over union to measure the similarity of two sets (assuming the ground-truth box is known). Given two sets $\mathcal{A}$ and $\mathcal{B}$, their Jaccard index is the size of their intersection divided by the size of their union:
$J(\mathcal{A}, \mathcal{B}) = \dfrac{|\mathcal{A} \cap \mathcal{B}|}{|\mathcal{A} \cup \mathcal{B}|}.$
We treat the region inside a bounding box as a set of pixels, so the intersection consists of the overlapping pixels and the union of all pixels covered by either box.
torch.clamp takes four arguments: input, min, max, out=None. It clamps every element of the input tensor into the interval [min, max] (it does not rescale them).
tensor.unsqueeze(k) inserts a dimension of size 1 at position k; e.g., for a tensor of shape (2, 3), unsqueeze(0) gives shape (1, 2, 3) and unsqueeze(1) gives (2, 1, 3).
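A tiny demo of both operations (the values and shapes follow directly from the definitions above):

t = torch.tensor([[1.0, -2.0], [5.0, 0.5]])
print(torch.clamp(t, min=0, max=1))  # tensor([[1.0000, 0.0000], [1.0000, 0.5000]])
x = torch.zeros(2, 3)
print(x.unsqueeze(0).shape, x.unsqueeze(1).shape)  # torch.Size([1, 2, 3]) torch.Size([2, 1, 3])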
Computing the intersection: the lower corner of the overlap is the elementwise maximum of the two top-left corners, the upper corner is the elementwise minimum of the two bottom-right corners, and negative extents are clamped to 0:
def compute_intersection(set_1, set_2):
    """
    Compute the intersection area of anchors in set_1 with anchors in set_2.
    Args:
        set_1: a tensor of dimensions (n1, 4), anchors represented as (xmin, ymin, xmax, ymax)
        set_2: a tensor of dimensions (n2, 4), anchors represented as (xmin, ymin, xmax, ymax)
    Returns:
        intersection of each of the boxes in set 1 with respect to each of the boxes in set 2, shape: (n1, n2)
    """
    # PyTorch auto-broadcasts singleton dimensions
    lower_bounds = torch.max(set_1[:, :2].unsqueeze(1), set_2[:, :2].unsqueeze(0))  # (n1, n2, 2)
    upper_bounds = torch.min(set_1[:, 2:].unsqueeze(1), set_2[:, 2:].unsqueeze(0))  # (n1, n2, 2)
    intersection_dims = torch.clamp(upper_bounds - lower_bounds, min=0)  # (n1, n2, 2)
    return intersection_dims[:, :, 0] * intersection_dims[:, :, 1]  # (n1, n2)
Compute the Jaccard index of two sets, i.e. their similarity; the union area equals the sum of the two areas minus the intersection area:
def compute_jaccard(set_1, set_2):
    """
    Compute the Jaccard index (IoU) between anchors.
    Args:
        set_1: a tensor of dimensions (n1, 4), anchors represented as (xmin, ymin, xmax, ymax)
        set_2: a tensor of dimensions (n2, 4), anchors represented as (xmin, ymin, xmax, ymax)
    Returns:
        Jaccard Overlap of each of the boxes in set 1 with respect to each of the boxes in set 2, shape: (n1, n2)
    """
    # Find intersections
    intersection = compute_intersection(set_1, set_2)  # (n1, n2)
    # Find areas of each box in both sets
    areas_set_1 = (set_1[:, 2] - set_1[:, 0]) * (set_1[:, 3] - set_1[:, 1])  # (n1)
    areas_set_2 = (set_2[:, 2] - set_2[:, 0]) * (set_2[:, 3] - set_2[:, 1])  # (n2)
    # Find the union
    # PyTorch auto-broadcasts singleton dimensions
    union = areas_set_1.unsqueeze(1) + areas_set_2.unsqueeze(0) - intersection  # (n1, n2)
    return intersection / union  # (n1, n2)
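A minimal worked example with two hand-picked boxes: the overlap is a 1x1 square, the union has area 4 + 4 - 1 = 7, so the IoU is 1/7:

a = torch.tensor([[0.0, 0.0, 2.0, 2.0]])
b = torch.tensor([[1.0, 1.0, 3.0, 3.0]])
print(compute_jaccard(a, b))  # tensor([[0.1429]])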
Labeling anchor boxes
We treat each anchor box as a training example, so each anchor needs two labels: 1. its true class, and 2. the offset between the anchor and the ground-truth bounding box. At test time, i.e. when actually detecting, we predict a class and an offset for each anchor, then shift the anchor by the predicted offset to obtain the predicted bounding box.
So how do we assign a ground-truth box to each anchor?
Suppose the anchor boxes are $A_1, A_2, \dots, A_{n_a}$ and the ground-truth boxes are $B_1, B_2, \dots, B_{n_b}$; computing all pairwise IoUs gives a matrix $X \in \mathbb{R}^{n_a \times n_b}$. We repeatedly find the largest remaining element of this matrix, assign the corresponding ground-truth box to the corresponding anchor, and clear that element's row and column, recursing until the matrix is exhausted. Each remaining anchor is then assigned the ground-truth box with which it has the highest IoU, provided that IoU exceeds a threshold (this is the thresholded second pass in the code below).
And how do we define the offsets?
Since the boxes vary in position and size, the offsets need to be normalized. A common choice, and the one the code below uses, is to label anchor $A$ (center $(x_a, y_a)$, width $w_a$, height $h_a$) assigned to ground-truth box $B$ (center $(x_b, y_b)$, width $w_b$, height $h_b$) with the offset
$\left( 10 \cdot \dfrac{x_b - x_a}{w_a},\; 10 \cdot \dfrac{y_b - y_a}{h_a},\; 5 \cdot \log \dfrac{w_b}{w_a},\; 5 \cdot \log \dfrac{h_b}{h_a} \right).$
In particular, if an anchor is not assigned any ground-truth box, we set its class to background; such anchors are called negative anchors, and the rest positive anchors.
An example:
Here ground_truth holds the ground-truth boxes; the first element of each row is the class, and the other four are the top-left and bottom-right coordinates (normalized to [0, 1]).
bbox_scale = torch.tensor((w, h, w, h), dtype=torch.float32)
ground_truth = torch.tensor([[0, 0.1, 0.08, 0.52, 0.92],
[1, 0.55, 0.2, 0.9, 0.88]])
anchors = torch.tensor([[0, 0.1, 0.2, 0.3], [0.15, 0.2, 0.4, 0.4],
[0.63, 0.05, 0.88, 0.98], [0.66, 0.45, 0.8, 0.8],
[0.57, 0.3, 0.92, 0.9]])
fig = d2l.plt.imshow(img)
show_bboxes(fig.axes, ground_truth[:, 1:] * bbox_scale, ['dog', 'cat'], 'k')
show_bboxes(fig.axes, anchors * bbox_scale, ['0', '1', '2', '3', '4']);
Below we use the MultiBoxTarget function to label a class and an offset for every anchor, with the background class set to 0.
The implementation of anchor assignment and of labeling classes and offsets follows.
def assign_anchor(bb, anchor, jaccard_threshold=0.5):
    """
    # Assign a ground-truth bounding box to each anchor, as described in
    # "9.4.1. Generating Multiple Anchor Boxes" (Fig. 9.3); anchors are
    # normalized (xmin, ymin, xmax, ymax).
    https://zh.d2l.ai/chapter_computer-vision/anchor.html
    Args:
        bb: ground-truth bounding boxes, shape: (nb, 4)
        anchor: anchors to be assigned, shape: (na, 4)
        jaccard_threshold: the pre-set IoU threshold
    Returns:
        assigned_idx: shape (na,), index of the ground-truth bb assigned to each anchor, -1 if none
    """
    na = anchor.shape[0]
    nb = bb.shape[0]
    jaccard = compute_jaccard(anchor, bb).detach().cpu().numpy()  # shape: (na, nb)
    assigned_idx = np.ones(na) * -1  # initialize all assignments to -1
    # first assign one anchor to each bb (without requiring jaccard_threshold)
    jaccard_cp = jaccard.copy()
    for j in range(nb):
        i = np.argmax(jaccard_cp[:, j])
        assigned_idx[i] = j
        jaccard_cp[i, :] = float("-inf")  # set the row to -inf, effectively removing it
    # handle anchors that are still unassigned, requiring IoU >= jaccard_threshold
    for i in range(na):
        if assigned_idx[i] == -1:
            j = np.argmax(jaccard[i, :])
            if jaccard[i, j] >= jaccard_threshold:
                assigned_idx[i] = j
    return torch.tensor(assigned_idx, dtype=torch.long)
def xy_to_cxcy(xy):
    """
    Convert anchors from (x_min, y_min, x_max, y_max) form to (center_x, center_y, w, h) form.
    https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Object-Detection/blob/master/utils.py
    Args:
        xy: bounding boxes in boundary coordinates, a tensor of size (n_boxes, 4)
    Returns:
        bounding boxes in center-size coordinates, a tensor of size (n_boxes, 4)
    """
    return torch.cat([(xy[:, 2:] + xy[:, :2]) / 2,  # c_x, c_y
                      xy[:, 2:] - xy[:, :2]], 1)  # w, h
def MultiBoxTarget(anchor, label):
    """
    # Implemented as described in "9.4.1. Generating Multiple Anchor Boxes";
    # anchors are normalized (xmin, ymin, xmax, ymax).
    https://zh.d2l.ai/chapter_computer-vision/anchor.html
    Args:
        anchor: torch tensor, the input anchors, usually produced by MultiBoxPrior, shape: (1, num_anchors, 4)
        label: ground-truth labels, shape (bn, max number of ground-truth boxes per image, 5)
               If an image has fewer boxes, the second dimension can be padded with -1;
               each entry of the last dimension is [class label, four coordinates]
    Returns:
        a list [bbox_offset, bbox_mask, cls_labels]
        bbox_offset: the labeled offsets of each anchor, shape (bn, num_anchors*4)
        bbox_mask: same shape as bbox_offset; a mask aligned one-to-one with the offsets,
                   all 0 for negative (background) anchors and all 1 for positive anchors
        cls_labels: the labeled class of each anchor, where 0 denotes background, shape (bn, num_anchors)
    """
    assert len(anchor.shape) == 3 and len(label.shape) == 3
    bn = label.shape[0]
    def MultiBoxTarget_one(anc, lab, eps=1e-6):
        """
        Helper for MultiBoxTarget that handles one image of the batch.
        Args:
            anc: shape (num_anchors, 4)
            lab: shape (num_ground_truth, 5), each row is [class label, four coordinates]
            eps: a tiny value to avoid log(0)
        Returns:
            offset: (num_anchors*4,)
            bbox_mask: (num_anchors*4,), 0 for background, 1 otherwise
            cls_labels: (num_anchors,), 0 for background
        """
        an = anc.shape[0]
        # assign a ground-truth box (or -1) to every anchor
        assigned_idx = assign_anchor(lab[:, 1:], anc)  # (num_anchors,)
        bbox_mask = ((assigned_idx >= 0).float().unsqueeze(-1)).repeat(1, 4)  # (num_anchors, 4)
        cls_labels = torch.zeros(an, dtype=torch.long)  # 0 denotes background
        assigned_bb = torch.zeros((an, 4), dtype=torch.float32)  # bb coordinates assigned to each anchor
        for i in range(an):
            bb_idx = assigned_idx[i]
            if bb_idx >= 0:  # i.e. not background
                cls_labels[i] = lab[bb_idx, 0].long().item() + 1  # add 1: class 0 is reserved for background
                assigned_bb[i, :] = lab[bb_idx, 1:]
        # compute the normalized offsets described above
        center_anc = xy_to_cxcy(anc)  # (center_x, center_y, w, h)
        center_assigned_bb = xy_to_cxcy(assigned_bb)
        offset_xy = 10.0 * (center_assigned_bb[:, :2] - center_anc[:, :2]) / center_anc[:, 2:]
        offset_wh = 5.0 * torch.log(eps + center_assigned_bb[:, 2:] / center_anc[:, 2:])
        offset = torch.cat([offset_xy, offset_wh], dim=1) * bbox_mask  # (num_anchors, 4)
        return offset.view(-1), bbox_mask.view(-1), cls_labels
    # assemble the batch outputs
    batch_offset = []
    batch_mask = []
    batch_cls_labels = []
    for b in range(bn):
        offset, bbox_mask, cls_labels = MultiBoxTarget_one(anchor[0, :, :], label[b, :, :])
        batch_offset.append(offset)
        batch_mask.append(bbox_mask)
        batch_cls_labels.append(cls_labels)
    bbox_offset = torch.stack(batch_offset)
    bbox_mask = torch.stack(batch_mask)
    cls_labels = torch.stack(batch_cls_labels)
    return [bbox_offset, bbox_mask, cls_labels]
Test:
Since the first dimension must be the batch (image-count) dimension, we add one with unsqueeze:
labels = MultiBoxTarget(anchors.unsqueeze(dim=0),
ground_truth.unsqueeze(dim=0))
The function returns a list: the first item holds the offsets, the second the mask variable (shape: batch size by four times the number of anchors, one entry per anchor offset), and the third the anchor classes.
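With the 5 anchors above, the shapes can be checked directly:

bbox_offset, bbox_mask, cls_labels = labels
print(bbox_offset.shape, bbox_mask.shape, cls_labels.shape)
# torch.Size([1, 20]) torch.Size([1, 20]) torch.Size([1, 5])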
Outputting predicted bounding boxes
Non-maximum suppression
When there are many anchors, many similar predicted bounding boxes may land on the same target, so we remove the similar ones.
Principle: for a predicted bounding box B, the model computes a predicted probability for each class; the largest such probability p determines B's predicted class, and p is called B's confidence. All non-background predicted boxes are sorted by confidence in descending order into a list L. We take the highest-confidence box remaining in L as the baseline, remove every non-baseline box whose IoU with it exceeds a given threshold, and repeat until every box left in L has served as a baseline.
Below is the implementation of non_max_suppression and MultiBoxDetection:
from collections import namedtuple
Pred_BB_Info = namedtuple("Pred_BB_Info", ["index", "class_id", "confidence", "xyxy"])
def non_max_suppression(bb_info_list, nms_threshold=0.5):
    """
    Non-maximum suppression over predicted bounding boxes.
    Args:
        bb_info_list: list of Pred_BB_Info with the predicted class, confidence, etc.
        nms_threshold: the IoU threshold
    Returns:
        output: list of Pred_BB_Info, keeping only the boxes that survive filtering
    """
    output = []
    # sort by confidence from high to low
    sorted_bb_info_list = sorted(bb_info_list, key=lambda x: x.confidence, reverse=True)
    # repeatedly pop the most confident box and drop the redundant ones
    while len(sorted_bb_info_list) != 0:
        best = sorted_bb_info_list.pop(0)
        output.append(best)
        if len(sorted_bb_info_list) == 0:
            break
        bb_xyxy = []
        for bb in sorted_bb_info_list:
            bb_xyxy.append(bb.xyxy)
        iou = compute_jaccard(torch.tensor([best.xyxy]),
                              torch.tensor(bb_xyxy))[0]  # shape: (len(sorted_bb_info_list),)
        n = len(sorted_bb_info_list)
        sorted_bb_info_list = [sorted_bb_info_list[i] for i in range(n) if iou[i] <= nms_threshold]
    return output
def MultiBoxDetection(cls_prob, loc_pred, anchor, nms_threshold=0.5):
    """
    # Implemented as described in "9.4.1. Generating Multiple Anchor Boxes";
    # anchors are normalized (xmin, ymin, xmax, ymax).
    https://zh.d2l.ai/chapter_computer-vision/anchor.html
    Args:
        cls_prob: predicted probabilities (after softmax) for each anchor, shape: (bn, num_classes + 1, num_anchors)
        loc_pred: predicted offsets for each anchor, shape: (bn, num_anchors * 4)
        anchor: default anchors produced by MultiBoxPrior, shape: (1, num_anchors, 4)
        nms_threshold: the IoU threshold used in non-maximum suppression
    Returns:
        information for all anchors, shape: (bn, num_anchors, 6)
        each anchor's information is [class_id, confidence, xmin, ymin, xmax, ymax]
        class_id = -1 means background or removed by non-maximum suppression
    """
    assert len(cls_prob.shape) == 3 and len(loc_pred.shape) == 2 and len(anchor.shape) == 3
    bn = cls_prob.shape[0]
    def MultiBoxDetection_one(c_p, l_p, anc, nms_threshold=0.5):
        """
        Helper for MultiBoxDetection that handles one image of the batch.
        Args:
            c_p: (num_classes + 1, num_anchors)
            l_p: (num_anchors * 4,)
            anc: (num_anchors, 4)
            nms_threshold: the IoU threshold used in non-maximum suppression
        Return:
            output: (num_anchors, 6)
        """
        pred_bb_num = c_p.shape[1]
        anc = (anc + l_p.view(pred_bb_num, 4)).detach().cpu().numpy()  # apply the predicted offsets
        confidence, class_id = torch.max(c_p, 0)
        confidence = confidence.detach().cpu().numpy()
        class_id = class_id.detach().cpu().numpy()
        pred_bb_info = [Pred_BB_Info(
                            index=i,
                            class_id=class_id[i] - 1,  # positive-class labels start from 0
                            confidence=confidence[i],
                            xyxy=[*anc[i]])  # xyxy is a list
                        for i in range(pred_bb_num)]
        # indices of the boxes kept as positive predictions
        obj_bb_idx = [bb.index for bb in non_max_suppression(pred_bb_info, nms_threshold)]
        output = []
        for bb in pred_bb_info:
            output.append([
                (bb.class_id if bb.index in obj_bb_idx else -1.0),
                bb.confidence,
                *bb.xyxy
            ])
        return torch.tensor(output)  # shape: (num_anchors, 6)
    batch_output = []
    for b in range(bn):
        batch_output.append(MultiBoxDetection_one(cls_prob[b], loc_pred[b], anchor[0], nms_threshold))
    return torch.stack(batch_output)
Test:
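The test inputs are not defined in this excerpt; reconstructed from the printed output below, they match the d2l example (four anchors, all-zero predicted offsets, and softmax probabilities for background/dog/cat):

anchors = torch.tensor([[0.1, 0.08, 0.52, 0.92], [0.08, 0.2, 0.56, 0.95],
                        [0.15, 0.3, 0.62, 0.91], [0.55, 0.2, 0.9, 0.88]])
offset_preds = torch.tensor([0.0] * (4 * len(anchors)))
cls_probs = torch.tensor([[0., 0., 0., 0.],      # background probability per anchor
                          [0.9, 0.8, 0.7, 0.1],  # dog probability per anchor
                          [0.1, 0.2, 0.3, 0.9]]) # cat probability per anchor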
output = MultiBoxDetection(
cls_probs.unsqueeze(dim=0), offset_preds.unsqueeze(dim=0),
anchors.unsqueeze(dim=0), nms_threshold=0.5)
output
Output: note that every input gained a batch dimension (dim 0). In each row of the output, the first element is the class (-1 means background or removed by NMS), the second is the predicted box's confidence, and the last four are the top-left and bottom-right coordinates:
tensor([[[ 0.0000, 0.9000, 0.1000, 0.0800, 0.5200, 0.9200],
[-1.0000, 0.8000, 0.0800, 0.2000, 0.5600, 0.9500],
[-1.0000, 0.7000, 0.1500, 0.3000, 0.6200, 0.9100],
[ 1.0000, 0.9000, 0.5500, 0.2000, 0.9000, 0.8800]]])
In practice we can also discard low-confidence predicted boxes before non-maximum suppression to reduce the amount of computation that follows.
Multiscale object detection
In reality, generating anchors centered on every pixel quickly becomes computationally prohibitive.
Ways to reduce the number of anchors:
1. Uniformly sample a small subset of pixels in the input image as anchor centers.
2. Generate anchors of different numbers and sizes at different scales (see the sketch below).
For example: when detecting smaller targets with smaller anchors we can sample more regions, while for larger targets with larger anchors fewer regions suffice.
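As a minimal sketch, we can reuse MultiBoxPrior on a small hypothetical feature map: treating a coarse 4x4 map as the sampling grid, with a single size per pixel, gives only 4*4*(1+3-1) = 48 anchors instead of one set per image pixel:

fmap = torch.zeros((1, 10, 4, 4))  # hypothetical coarse feature map
anchors = MultiBoxPrior(fmap, sizes=[0.15], ratios=[1, 2, 0.5])
print(anchors.shape)  # torch.Size([1, 48, 4])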
2. Neural Style Transfer
Style transfer synthesizes the content of one image with the style of another: one image is the content image, the other the style image.
The idea is as follows:
First we initialize a synthesized image, then pick a pretrained convolutional network to extract features. Generally, features near the input layer capture image detail, while features near the output layer capture the global content of the image.
Taking the book's figure as an example: the pretrained network has three convolutional layers; the second layer's output serves as the content features, and the first and third layers' outputs serve as the style features. We run the forward pass (the solid lines in the figure) and compute the loss, which has three parts: content loss, style loss, and total variation loss; the total variation loss helps reduce noise in the synthesized image. When training ends, the model "parameters", i.e. the synthesized image itself, are the output of style transfer.
Import some libraries:
%matplotlib inline
import time
import torch
import torch.nn.functional as F
import torchvision
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import sys
sys.path.append("/home/kesci/input")
import d2len9900 as d2l
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # both tested
print(device, torch.__version__)
Read the images:
content_img = Image.open('/home/kesci/input/NeuralStyle5603/rainier.jpg')
plt.imshow(content_img);
style_img = Image.open('/home/kesci/input/NeuralStyle5603/autumn_oak.jpg')
plt.imshow(style_img);
Next we preprocess the images: standardize each of the three RGB channels and reshape the result into the (batch, channel, height, width) format the network accepts. We also need a postprocessing function that maps the output image's pixel values back to their pre-standardization values and clamps them into [0, 1].
rgb_mean = np.array([0.485, 0.456, 0.406])
rgb_std = np.array([0.229, 0.224, 0.225])
def preprocess(PIL_img, image_shape):
    process = torchvision.transforms.Compose([
        torchvision.transforms.Resize(image_shape),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=rgb_mean, std=rgb_std)])
    return process(PIL_img).unsqueeze(dim=0)  # (batch_size, 3, H, W)
def postprocess(img_tensor):
    inv_normalize = torchvision.transforms.Normalize(
        mean=-rgb_mean / rgb_std,
        std=1 / rgb_std)
    to_PIL_image = torchvision.transforms.ToPILImage()
    return to_PIL_image(inv_normalize(img_tensor[0].cpu()).clamp(0, 1))
Load the VGG-19 model to extract features.
VGG-19 contains five VGG blocks. We select the last convolutional layer of the fourth block as the content layer and the first convolutional layer of each of the five blocks as the style layers. Finally we pull exactly the layers we need out of the VGG network and build a new network with nn.Sequential().
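The snippet below assumes the pretrained network has already been loaded; a standard torchvision call (not shown in the original) would be:

pretrained_net = torchvision.models.vgg19(pretrained=True)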
style_layers, content_layers = [0, 5, 10, 19, 28], [25]
net_list = []
for i in range(max(content_layers + style_layers) + 1):
    net_list.append(pretrained_net.features[i])
net = torch.nn.Sequential(*net_list)
Since a normal forward pass only gives us the final layer's output, we compute the network layer by layer, keeping the outputs we need:
def extract_features(X, content_layers, style_layers):
    contents = []
    styles = []
    for i in range(len(net)):
        X = net[i](X)
        if i in style_layers:
            styles.append(X)
        if i in content_layers:
            contents.append(X)
    return contents, styles
Define two functions that return the content features and the style features:
def get_contents(image_shape, device):
    content_X = preprocess(content_img, image_shape).to(device)
    contents_Y, _ = extract_features(content_X, content_layers, style_layers)
    return content_X, contents_Y
def get_styles(image_shape, device):
    style_X = preprocess(style_img, image_shape).to(device)
    _, styles_Y = extract_features(style_X, content_layers, style_layers)
    return style_X, styles_Y
Next we compute the three parts of the loss.
The content loss uses mean squared error:
def content_loss(Y_hat, Y):
    return F.mse_loss(Y_hat, Y)
The style loss also uses mean squared error, but the style-layer outputs are transformed first. An output has 1 example, $c$ channels, and height and width $h$ and $w$; we reshape it into a matrix $X$ with $c$ rows and $hw$ columns, where row $\mathbf{x}_i$ represents the style feature of channel $i$. The Gram matrix $XX^\top \in \mathbb{R}^{c \times c}$ of this matrix has entry $x_{ij} = \mathbf{x}_i^\top \mathbf{x}_j$, the inner product of channels $i$ and $j$, which expresses the correlation of the style features on those two channels. Since these inner products easily become large, we divide by the total number of elements $chw$:
def gram(X):
    num_channels, n = X.shape[1], X.shape[2] * X.shape[3]
    X = X.view(num_channels, n)
    return torch.matmul(X, X.t()) / (num_channels * n)
def style_loss(Y_hat, gram_Y):
    return F.mse_loss(gram(Y_hat), gram_Y)
The total variation loss reduces noise in the synthesized image (isolated particularly bright or dark pixels). A common choice is total variation denoising: letting $x_{i,j}$ denote the pixel value at coordinate $(i, j)$, the total variation loss is
$\sum_{i,j} \left( \left| x_{i,j} - x_{i+1,j} \right| + \left| x_{i,j} - x_{i,j+1} \right| \right),$
whose purpose is to make neighboring pixel values as close as possible:
def tv_loss(Y_hat):
    return 0.5 * (F.l1_loss(Y_hat[:, :, 1:, :], Y_hat[:, :, :-1, :]) +
                  F.l1_loss(Y_hat[:, :, :, 1:], Y_hat[:, :, :, :-1]))
Finally there is the overall style-transfer loss, a weighted sum of the three losses; via the weight hyperparameters we can tune the relative importance of content, style, and noise:
content_weight, style_weight, tv_weight = 1, 1e3, 10
def compute_loss(X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram):
    # compute the content, style, and total variation losses separately
    contents_l = [content_loss(Y_hat, Y) * content_weight for Y_hat, Y in zip(
        contents_Y_hat, contents_Y)]
    styles_l = [style_loss(Y_hat, Y) * style_weight for Y_hat, Y in zip(
        styles_Y_hat, styles_Y_gram)]
    tv_l = tv_loss(X) * tv_weight
    # sum all the losses
    l = sum(styles_l) + sum(contents_l) + tv_l
    return contents_l, styles_l, tv_l, l
Finally we create and initialize the synthesized image, which is the only variable updated during training:
class GeneratedImage(torch.nn.Module):
    def __init__(self, img_shape):
        super(GeneratedImage, self).__init__()
        self.weight = torch.nn.Parameter(torch.rand(*img_shape))
    def forward(self):
        return self.weight
def get_inits(X, device, lr, styles_Y):
    gen_img = GeneratedImage(X.shape).to(device)
    gen_img.weight.data = X.data
    optimizer = torch.optim.Adam(gen_img.parameters(), lr=lr)
    styles_Y_gram = [gram(Y) for Y in styles_Y]
    return gen_img(), styles_Y_gram, optimizer
Training:
def train(X, contents_Y, styles_Y, device, lr, max_epochs, lr_decay_epoch):
    print("training on ", device)
    X, styles_Y_gram, optimizer = get_inits(X, device, lr, styles_Y)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, lr_decay_epoch, gamma=0.1)
    for i in range(max_epochs):
        start = time.time()
        contents_Y_hat, styles_Y_hat = extract_features(
            X, content_layers, style_layers)
        contents_l, styles_l, tv_l, l = compute_loss(
            X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram)
        optimizer.zero_grad()
        l.backward(retain_graph=True)
        optimizer.step()
        scheduler.step()
        if i % 50 == 0 and i != 0:
            print('epoch %3d, content loss %.2f, style loss %.2f, '
                  'TV loss %.2f, %.2f sec'
                  % (i, sum(contents_l).item(), sum(styles_l).item(), tv_l.item(),
                     time.time() - start))
    return X.detach()
image_shape = (150, 225)
net = net.to(device)
content_X, contents_Y = get_contents(image_shape, device)
style_X, styles_Y = get_styles(image_shape, device)
output = train(content_X, contents_Y, styles_Y, device, 0.01, 500, 200)
plt.imshow(postprocess(output));
3. Batch Normalization and Residual Networks
Batch normalization (BatchNorm)
Batch normalization uses the mean and standard deviation of each mini-batch to keep adjusting the network's intermediate outputs, making the values of those intermediate outputs across layers more stable.
For a fully connected layer, batch normalization is applied between the affine transformation and the activation function.
The BN function is
$\text{BN}(x) = \gamma \odot \dfrac{x - \mu_\mathcal{B}}{\sqrt{\sigma_\mathcal{B}^2 + \epsilon}} + \beta,$
where $\mu_\mathcal{B}$ and $\sigma_\mathcal{B}^2$ are the mean and variance over the mini-batch $\mathcal{B}$, and the scale parameter $\gamma$ and shift parameter $\beta$ are learnable.
If a convolutional layer has multiple channels, each channel is normalized separately, and each channel gets its own scale and shift parameters, shared by all positions within that channel.
In short: at training time, the mean and variance are computed per batch; at prediction time, the mean and variance over the whole training set are estimated via moving averages.
The concrete implementation:
def batch_norm(is_training, X, gamma, beta, moving_mean, moving_var, eps, momentum):
    # decide whether we are in training or prediction mode
    if not is_training:
        # in prediction mode, use the moving-average mean and variance directly
        X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)
    else:
        assert len(X.shape) in (2, 4)
        if len(X.shape) == 2:
            # fully connected layer: compute mean and variance over the feature dimension
            mean = X.mean(dim=0)
            var = ((X - mean) ** 2).mean(dim=0)
        else:
            # 2-D convolutional layer: compute mean and variance per channel (axis=1),
            # keeping X's shape so that broadcasting works later
            mean = X.mean(dim=0, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)
            var = ((X - mean) ** 2).mean(dim=0, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)
        # in training mode, standardize with the current mean and variance
        X_hat = (X - mean) / torch.sqrt(var + eps)
        # update the moving averages of the mean and variance
        moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
        moving_var = momentum * moving_var + (1.0 - momentum) * var
    Y = gamma * X_hat + beta  # scale and shift
    return Y, moving_mean, moving_var
class BatchNorm(nn.Module):
    def __init__(self, num_features, num_dims):
        super(BatchNorm, self).__init__()
        if num_dims == 2:
            shape = (1, num_features)  # fully connected layer: number of output neurons
        else:
            shape = (1, num_features, 1, 1)  # convolutional layer: number of channels
        # scale and shift parameters involved in gradients and iteration,
        # initialized to 1 and 0 respectively
        self.gamma = nn.Parameter(torch.ones(shape))
        self.beta = nn.Parameter(torch.zeros(shape))
        # variables not involved in gradients or iteration, initialized to 0 in main memory
        self.moving_mean = torch.zeros(shape)
        self.moving_var = torch.zeros(shape)
    def forward(self, X):
        # if X is not in main memory, copy moving_mean and moving_var to the device X lives on
        if self.moving_mean.device != X.device:
            self.moving_mean = self.moving_mean.to(X.device)
            self.moving_var = self.moving_var.to(X.device)
        # save the updated moving_mean and moving_var; a Module's training attribute
        # defaults to True and is set to False by calling .eval()
        Y, self.moving_mean, self.moving_var = batch_norm(self.training,
            X, self.gamma, self.beta, self.moving_mean,
            self.moving_var, eps=1e-5, momentum=0.9)
        return Y
Plugging the BatchNorm layer into a LeNet:
net = nn.Sequential(
nn.Conv2d(1, 6, 5), # in_channels, out_channels, kernel_size
BatchNorm(6, num_dims=4),
nn.Sigmoid(),
nn.MaxPool2d(2, 2), # kernel_size, stride
nn.Conv2d(6, 16, 5),
BatchNorm(16, num_dims=4),
nn.Sigmoid(),
nn.MaxPool2d(2, 2),
d2l.FlattenLayer(),
nn.Linear(16*4*4, 120),
BatchNorm(120, num_dims=2),
nn.Sigmoid(),
nn.Linear(120, 84),
BatchNorm(84, num_dims=2),
nn.Sigmoid(),
nn.Linear(84, 10)
)
print(net)
There is of course a concise version as well:
nn has built-in BatchNorm layers, BatchNorm2d and BatchNorm1d:
net = nn.Sequential(
nn.Conv2d(1, 6, 5), # in_channels, out_channels, kernel_size
nn.BatchNorm2d(6),
nn.Sigmoid(),
nn.MaxPool2d(2, 2), # kernel_size, stride
nn.Conv2d(6, 16, 5),
nn.BatchNorm2d(16),
nn.Sigmoid(),
nn.MaxPool2d(2, 2),
d2l.FlattenLayer(),
nn.Linear(16*4*4, 120),
nn.BatchNorm1d(120),
nn.Sigmoid(),
nn.Linear(120, 84),
nn.BatchNorm1d(84),
nn.Sigmoid(),
nn.Linear(84, 10)
)
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)
Residual networks (ResNet)
A weakness of plain CNNs is that when the network gets very deep, both convergence and accuracy degrade.
A residual block adds its input directly to its output, which lightens the training burden and lets signals propagate faster.
"Residual" means that the layers inside the block only need to fit the residual between the desired output and the input X.
class Residual(nn.Module):  # this class is saved in the d2lzh_pytorch package for later use
    # We can set the number of output channels, whether to use an extra 1x1 convolution
    # to change the channel count, and the stride of the convolutions.
    def __init__(self, in_channels, out_channels, use_1x1conv=False, stride=1):
        super(Residual, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=stride)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        if use_1x1conv:
            self.conv3 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)
        else:
            self.conv3 = None
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
    def forward(self, X):
        Y = F.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        if self.conv3:
            X = self.conv3(X)
        return F.relu(Y + X)
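A quick shape check of the block (the values follow from the convolution arithmetic; assumes torch, torch.nn as nn, and torch.nn.functional as F are imported):

blk = Residual(3, 3)
X = torch.rand((4, 3, 6, 6))
print(blk(X).shape)   # torch.Size([4, 3, 6, 6]): shape preserved
blk2 = Residual(3, 6, use_1x1conv=True, stride=2)
print(blk2(X).shape)  # torch.Size([4, 6, 3, 3]): channels doubled, H and W halved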
The ResNet model is built from the following structure:
net = nn.Sequential(
nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
def resnet_block(in_channels, out_channels, num_residuals, first_block=False):
    if first_block:
        assert in_channels == out_channels  # the first block keeps the channel count of the input
    blk = []
    for i in range(num_residuals):
        if i == 0 and not first_block:
            blk.append(Residual(in_channels, out_channels, use_1x1conv=True, stride=2))
        else:
            blk.append(Residual(out_channels, out_channels))
    return nn.Sequential(*blk)
net.add_module("resnet_block1", resnet_block(64, 64, 2, first_block=True))
net.add_module("resnet_block2", resnet_block(64, 128, 2))
net.add_module("resnet_block3", resnet_block(128, 256, 2))
net.add_module("resnet_block4", resnet_block(256, 512, 2))
net.add_module("global_avg_pool", d2l.GlobalAvgPool2d()) # GlobalAvgPool2d的輸出: (Batch, 512, 1, 1)
net.add_module("fc", nn.Sequential(d2l.FlattenLayer(), nn.Linear(512, 10)))
lr, num_epochs = 0.001, 5
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)
Densely connected networks (DenseNet)
DenseNet is a variant of ResNet: instead of adding the input to the output, it concatenates them along the channel dimension.
It consists mainly of dense blocks and transition layers: dense blocks define how inputs and outputs are concatenated, while transition layers control the number of channels.
A dense block in Python:
def conv_block(in_channels, out_channels):
    blk = nn.Sequential(nn.BatchNorm2d(in_channels),
                        nn.ReLU(),
                        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
    return blk
class DenseBlock(nn.Module):
    def __init__(self, num_convs, in_channels, out_channels):
        super(DenseBlock, self).__init__()
        net = []
        for i in range(num_convs):
            in_c = in_channels + i * out_channels
            net.append(conv_block(in_c, out_channels))
        self.net = nn.ModuleList(net)
        self.out_channels = in_channels + num_convs * out_channels  # compute the output channel count
    def forward(self, X):
        for blk in self.net:
            Y = blk(X)
            X = torch.cat((X, Y), dim=1)  # concatenate input and output on the channel dimension
        return X
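A shape check: with 2 conv blocks of 10 output channels on a 3-channel input, the output has 3 + 2*10 = 23 channels:

blk = DenseBlock(2, 3, 10)
X = torch.rand(4, 3, 8, 8)
Y = blk(X)
print(Y.shape)  # torch.Size([4, 23, 8, 8])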
The transition layer uses a 1x1 convolution to reduce the number of channels and an average-pooling layer with stride 2 to halve the height and width:
def transition_block(in_channels, out_channels):
    blk = nn.Sequential(
        nn.BatchNorm2d(in_channels),
        nn.ReLU(),
        nn.Conv2d(in_channels, out_channels, kernel_size=1),
        nn.AvgPool2d(kernel_size=2, stride=2))
    return blk
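Continuing the example above, a transition layer cuts the 23 channels down to 10 and halves the spatial size:

blk = transition_block(23, 10)
print(blk(Y).shape)  # torch.Size([4, 10, 4, 4])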
The full DenseNet model is then assembled as follows:
net = nn.Sequential(
nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
num_channels, growth_rate = 64, 32  # num_channels is the current channel count
num_convs_in_dense_blocks = [4, 4, 4, 4]
for i, num_convs in enumerate(num_convs_in_dense_blocks):
    DB = DenseBlock(num_convs, num_channels, growth_rate)
    net.add_module("DenseBlosk_%d" % i, DB)
    # the output channel count of the previous dense block
    num_channels = DB.out_channels
    # between dense blocks, insert a transition layer that halves the channel count
    if i != len(num_convs_in_dense_blocks) - 1:
        net.add_module("transition_block_%d" % i, transition_block(num_channels, num_channels // 2))
        num_channels = num_channels // 2
net.add_module("BN", nn.BatchNorm2d(num_channels))
net.add_module("relu", nn.ReLU())
net.add_module("global_avg_pool", d2l.GlobalAvgPool2d())  # output of GlobalAvgPool2d: (Batch, num_channels, 1, 1)
net.add_module("fc", nn.Sequential(d2l.FlattenLayer(), nn.Linear(num_channels, 10)))
X = torch.rand((1, 1, 96, 96))
for name, layer in net.named_children():
    X = layer(X)
    print(name, ' output shape:\t', X.shape)
batch_size = 16
# if an "out of memory" error appears, reduce batch_size or the resize value
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=96)
lr, num_epochs = 0.001, 5
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)
4. GAN (Generative Adversarial Networks)
A generative adversarial network contains two neural networks: a Discriminator (classifier) and a Generator. The generator is trained to produce data with the same distribution as the real data so as to fool the discriminator, while the discriminator is trained to tell generated data apart from real data.
The discriminator's loss is the cross entropy, with real data labeled $y = 1$ and generated data labeled $y = 0$:
$\min_D \left\{ -y \log D(\mathbf{x}) - (1 - y)\log\left(1 - D(\mathbf{x})\right) \right\}$
The generator's loss feeds its fakes through the discriminator with label 1:
$\min_G \left\{ -\log D(G(\mathbf{z})) \right\}$
Import some packages:
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torch import nn
import numpy as np
from torch.autograd import Variable
import torch
The generator implementation:
class net_G(nn.Module):
    def __init__(self):
        super(net_G, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(2, 2),
        )
        self._initialize_weights()
    def forward(self, x):
        x = self.model(x)
        return x
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.02)
                m.bias.data.zero_()
The discriminator implementation:
class net_D(nn.Module):
    def __init__(self):
        super(net_D, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(2, 5),
            nn.Tanh(),
            nn.Linear(5, 3),
            nn.Tanh(),
            nn.Linear(3, 1),
            nn.Sigmoid()
        )
        self._initialize_weights()
    def forward(self, x):
        x = self.model(x)
        return x
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.02)
                m.bias.data.zero_()
def update_D(X, Z, net_D, net_G, loss, trainer_D):
    """Update the discriminator: real X is labeled 1, generated data is labeled 0."""
    batch_size = X.shape[0]
    Tensor = torch.FloatTensor
    ones = Variable(Tensor(np.ones(batch_size))).view(batch_size, 1)
    zeros = Variable(Tensor(np.zeros(batch_size))).view(batch_size, 1)
    real_Y = net_D(X.float())
    fake_X = net_G(Z)
    fake_Y = net_D(fake_X)
    loss_D = (loss(real_Y, ones) + loss(fake_Y, zeros)) / 2
    loss_D.backward()
    trainer_D.step()
    return float(loss_D.sum())
def update_G(Z, net_D, net_G, loss, trainer_G):
    """Update the generator: try to make the discriminator label its fakes as 1."""
    batch_size = Z.shape[0]
    Tensor = torch.FloatTensor
    ones = Variable(Tensor(np.ones((batch_size,)))).view(batch_size, 1)
    fake_X = net_G(Z)
    fake_Y = net_D(fake_X)
    loss_G = loss(fake_Y, ones)
    loss_G.backward()
    trainer_G.step()
    return float(loss_G.sum())
Finally, the training loop:
def train(net_D, net_G, data_iter, num_epochs, lr_D, lr_G, latent_dim, data):
    loss = nn.BCELoss()
    Tensor = torch.FloatTensor
    trainer_D = torch.optim.Adam(net_D.parameters(), lr=lr_D)
    trainer_G = torch.optim.Adam(net_G.parameters(), lr=lr_G)
    plt.figure(figsize=(7, 4))
    d_loss_point = []
    g_loss_point = []
    d_loss = 0
    g_loss = 0
    for epoch in range(1, num_epochs + 1):
        d_loss_sum = 0
        g_loss_sum = 0
        batch = 0
        for X in data_iter:
            batch += 1
            X = Variable(X)
            batch_size = X.shape[0]
            Z = Variable(Tensor(np.random.normal(0, 1, (batch_size, latent_dim))))
            trainer_D.zero_grad()
            d_loss = update_D(X, Z, net_D, net_G, loss, trainer_D)
            d_loss_sum += d_loss
            trainer_G.zero_grad()
            g_loss = update_G(Z, net_D, net_G, loss, trainer_G)
            g_loss_sum += g_loss
        d_loss_point.append(d_loss_sum / batch)
        g_loss_point.append(g_loss_sum / batch)
    plt.ylabel('Loss', fontdict={'size': 14})
    plt.xlabel('epoch', fontdict={'size': 14})
    plt.xticks(range(0, num_epochs + 1, 3))
    plt.plot(range(1, num_epochs + 1), d_loss_point, color='orange', label='discriminator')
    plt.plot(range(1, num_epochs + 1), g_loss_point, color='blue', label='generator')
    plt.legend()
    plt.show()
    print(d_loss, g_loss)
    Z = Variable(Tensor(np.random.normal(0, 1, size=(100, latent_dim))))
    fake_X = net_G(Z).detach().numpy()
    plt.figure(figsize=(3.5, 2.5))
    plt.scatter(data[:, 0], data[:, 1], color='blue', label='real')
    plt.scatter(fake_X[:, 0], fake_X[:, 1], color='orange', label='generated')
    plt.legend()
    plt.show()
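The call below assumes a toy dataset: data is a set of 2-D points and data_iter batches them, neither of which is defined in this excerpt. A minimal sketch following the d2l example (the affine map A, b is that example's choice):

X = np.random.normal(size=(1000, 2))
A = np.array([[1.0, 2.0], [-0.1, 0.5]])
b = np.array([1.0, 2.0])
data = X.dot(A) + b  # "real" data: a linearly transformed Gaussian
data_iter = DataLoader(torch.Tensor(data), batch_size=8)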
Training:
lr_D,lr_G,latent_dim,num_epochs=0.05,0.005,2,20
generator=net_G()
discriminator=net_D()
train(discriminator,generator,data_iter,num_epochs,lr_D,lr_G,latent_dim,data)
5. Data Augmentation
For images, augmentation mainly includes flipping, cropping, and changing colors (contrast, brightness, and so on). Augmenting the images reduces the model's dependence on particular attributes and improves its ability to generalize.
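A minimal sketch of these augmentations with torchvision.transforms (the composition and parameter values here are illustrative, not from the original):

import torchvision.transforms as T
augment = T.Compose([
    T.RandomHorizontalFlip(),                                      # flipping
    T.RandomResizedCrop(200, scale=(0.1, 1.0), ratio=(0.5, 2.0)),  # cropping
    T.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5),          # color changes
    T.ToTensor()])
# usage on a PIL image: augmented = augment(img)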
對(duì)圖像來說蝶防,主要包括圖像增廣、翻轉(zhuǎn)明吩、裁剪间学、變化顏色、改變對(duì)比度、亮度等等低葫。進(jìn)行圖像增強(qiáng)可以減少模型對(duì)某種屬性的依賴详羡,增強(qiáng)泛化能力。