名稱(chēng):Sub-Image Anomaly Detection with Deep Pyramid Correspondences
SPADE是一種通過(guò)特征對(duì)比的方法進(jìn)行異常檢測(cè)的算法,主要核心是通過(guò)K近鄰進(jìn)行檢索(kNN)÷椋基于KNN的異常檢測(cè)一般只能區(qū)分整體特性久脯,無(wú)法精確得到缺陷的位置。本文提出了一種利用KNN和多尺度特征的方法來(lái)進(jìn)行異常的缺陷檢測(cè)與定位歧胁。
- 優(yōu)點(diǎn)
不需要訓(xùn)練滋饲,只要有正樣本圖像就行。 - 缺點(diǎn)
需要存儲(chǔ)所有訓(xùn)練集的特征喊巍,對(duì)于內(nèi)存的需求很高屠缭。
SPADE整個(gè)過(guò)程分為3部分:圖像深度特征提取、K近鄰正常圖像檢索和特征金字塔像素對(duì)齊崭参。
1.圖像深度特征提取
就是使用一個(gè)在imagenet上預(yù)訓(xùn)練過(guò)的模型進(jìn)行特征提取呵曹,論文中使用的是pytorch框架自帶的wide_resnet50_2,對(duì)layer1何暮,layer2奄喂,layer3,avepool層的結(jié)果進(jìn)行了輸出海洼。
這個(gè)步驟會(huì)把訓(xùn)練集中的所有圖像都做一遍特征提取跨新,然后分別把各個(gè)層提取出來(lái)的特征存儲(chǔ)起來(lái),等需要用的時(shí)候再全部載入內(nèi)存(所以如果數(shù)據(jù)集很大的話(huà)坏逢,對(duì)內(nèi)存的需求就很高)域帐。
2.K近鄰正常圖像檢索
這個(gè)步驟是在整圖層面上判定這個(gè)圖像有沒(méi)有異常,但不會(huì)告訴你異常具體在那個(gè)位置是整。主要是使用上一步avepool層的輸出特征肖揣,分別測(cè)試圖像的avepool層的輸出特征分別和訓(xùn)練集中avepool層的輸出特征計(jì)算歐式近距離,然后再取距離最近的K個(gè)圖像浮入,作為訓(xùn)練集中與測(cè)試圖像最接近的K個(gè)圖像龙优。代碼中使用topk進(jìn)行選取,參數(shù)largest=False代表降序排列事秀。
# calculate distance matrix
dist_matrix = calc_dist_matrix(torch.flatten(test_outputs['avgpool'], 1),
torch.flatten(train_outputs['avgpool'], 1))
# select K nearest neighbor and take average
topk_values, topk_indexes = torch.topk(dist_matrix, k=args.top_k, dim=1, largest=False)
3.特征金字塔像素對(duì)齊
這個(gè)步驟是用來(lái)確定異常在圖像的具體哪個(gè)位置彤断,以測(cè)試圖像的layer1特征為例野舶,layer1上的各個(gè)像素都會(huì)與步驟2中篩選出來(lái)的K個(gè)圖像的layer1上的像素做歐式距離計(jì)算,然后輸出2者之間最短的距離瓦糟,遍歷整張?zhí)卣鲌D就能得到測(cè)試圖像的layer1上的特征與篩選出來(lái)的K個(gè)圖像的layer1上的特征在像素層面的最短距離筒愚,然后layer2,layer3特分別做相同的計(jì)算菩浙,由于layer1巢掺,layer2,layer3他們的特征圖尺寸不一樣劲蜻,所以會(huì)將他們r(jià)eszie到一樣的尺寸再通道拼接在一起陆淀,之后在通道層面求平均,就能得到mask圖先嬉,作者還對(duì)這個(gè)mas圖做了一個(gè)高斯濾波用于平滑圖像轧苫。代碼里面除以100,主要是用來(lái)分段計(jì)算疫蔓,所有像素的特征一起計(jì)算歐氏距離含懊,內(nèi)存容易溢出。
# construct a gallery of features at all pixel locations of the K nearest neighbors
topk_feat_map = train_outputs[layer_name][topk_indexes[t_idx]]
test_feat_map = test_outputs[layer_name][t_idx:t_idx + 1]
feat_gallery = topk_feat_map.transpose(3, 1).flatten(0, 2).unsqueeze(-1).unsqueeze(-1)
# calculate distance matrix
dist_matrix_list = []
for d_idx in range(feat_gallery.shape[0] // 100):
dist_matrix = torch.pairwise_distance(feat_gallery[d_idx * 100:d_idx * 100 + 100], test_feat_map)
dist_matrix_list.append(dist_matrix)
dist_matrix = torch.cat(dist_matrix_list, 0)
# k nearest features from the gallery (k=1)
score_map = torch.min(dist_matrix, dim=0)[0]
score_map = F.interpolate(score_map.unsqueeze(0).unsqueeze(0), size=224,
mode='bilinear', align_corners=False)
score_maps.append(score_map)
# average distance between the features
score_map = torch.mean(torch.cat(score_maps, 0), dim=0)
# apply gaussian smoothing on the score map
score_map = gaussian_filter(score_map.squeeze().cpu().detach().numpy(), sigma=4)