[GiantPandaCV Intro] A collection of the distillation methods in RepDistiller, with each strategy explained as simply as possible and the implementation source code provided.
1. KD: Knowledge Distillation
Full title: Distilling the Knowledge in a Neural Network
Link: https://arxiv.org/pdf/1503.02531.pdf
Published: NIPS 2014
The most classic work, and the one that explicitly coined the term knowledge distillation: a temperature-scaled softmax softens the teacher network's logit outputs, which then serve as supervision for the student network, and a KL divergence measures the gap between the student's and the teacher's distributions. The overall pipeline is illustrated in the accompanying figure (from Knowledge Distillation: A Survey).
For the student network, part of the supervision comes from the hard ground-truth labels, and the other part comes from the soft labels provided by the teacher.
Code implementation:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# These imports are shared by all code snippets in this post.

class DistillKL(nn.Module):
    """Distilling the Knowledge in a Neural Network"""
    def __init__(self, T):
        super(DistillKL, self).__init__()
        self.T = T  # temperature

    def forward(self, y_s, y_t):
        p_s = F.log_softmax(y_s / self.T, dim=1)  # student log-probabilities
        p_t = F.softmax(y_t / self.T, dim=1)      # softened teacher probabilities
        # T^2 compensates for the gradient scaling introduced by the temperature;
        # reduction='sum' replaces the deprecated size_average=False
        loss = F.kl_div(p_s, p_t, reduction='sum') * (self.T ** 2) / y_s.shape[0]
        return loss
The core is a single kl_div call, which measures the gap between the student's and the teacher's distributions. A usage sketch follows.
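In training, the KD term is combined with the ordinary cross-entropy on the hard labels. A minimal usage sketch; the weight alpha and temperature T = 4 here are illustrative choices, not values prescribed by the paper:

# Hypothetical usage: alpha * CE(hard labels) + (1 - alpha) * KD(soft labels)
criterion_ce = nn.CrossEntropyLoss()
criterion_kd = DistillKL(T=4.0)

logits_s = torch.randn(8, 10)        # student logits: batch of 8, 10 classes
logits_t = torch.randn(8, 10)        # teacher logits (normally from a frozen teacher)
labels = torch.randint(0, 10, (8,))  # hard ground-truth labels

alpha = 0.9
loss = alpha * criterion_ce(logits_s, labels) + (1 - alpha) * criterion_kd(logits_s, logits_t)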
2. FitNet: Hints for thin deep nets
Full title: FitNets: Hints for Thin Deep Nets
Link: https://arxiv.org/pdf/1412.6550.pdf
Published: ICLR 2015 (poster)
The pioneering work on distilling intermediate layers: the student network's feature map is first projected to the same size as the teacher network's feature map, and the difference between the two is then measured with a mean squared error (MSE) loss.
Implementation:
class HintLoss(nn.Module):
    """Fitnets: hints for thin deep nets, ICLR 2015"""
    def __init__(self):
        super(HintLoss, self).__init__()
        self.crit = nn.MSELoss()

    def forward(self, f_s, f_t):
        loss = self.crit(f_s, f_t)
        return loss
The core of the implementation is simply MSELoss; matching the shapes is handled by a separate regressor, sketched below.
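HintLoss assumes f_s and f_t already have identical shapes; in FitNets the student feature first passes through a regressor. A minimal sketch of such a regressor (the 1x1-convolution design is an illustrative assumption; RepDistiller uses a similar ConvReg module with shapes taken from the two networks):

class ConvReg(nn.Module):
    """Illustrative hint regressor: maps student channels to teacher channels."""
    def __init__(self, s_channels, t_channels):
        super(ConvReg, self).__init__()
        self.conv = nn.Conv2d(s_channels, t_channels, kernel_size=1)

    def forward(self, f_s):
        return self.conv(f_s)

# Usage: regress the student feature, then apply the hint loss
# reg = ConvReg(64, 128)
# loss = HintLoss()(reg(f_s), f_t)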
3. AT: Attention Transfer
Full title: Paying More Attention to Attention: Improving the Performance of Convolutional Neural Networks via Attention Transfer
Link: https://arxiv.org/pdf/1612.03928.pdf
Published: ICLR 2017
To boost the student's performance, AT uses attention maps as the carrier of transferred knowledge. The paper proposes two kinds of attention: activation-based attention transfer and gradient-based attention transfer. Experiments show the first is both simpler and more effective.
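The activation-based attention map, as computed by the at method in the code below (the paper sums |A_c|^p over channels; this implementation takes the mean, which only changes a constant factor):

$$Q=\frac{q}{\lVert q\rVert_{2}},\qquad q=\mathrm{vec}\Big(\frac{1}{C}\sum_{c=1}^{C}A_{c}^{\,p}\Big),\qquad \mathcal{L}_{AT}=\big\lVert Q_{s}-Q_{t}\big\rVert_{2}^{2}$$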
Implementation:
class Attention(nn.Module):
    """Paying More Attention to Attention: Improving the Performance of Convolutional Neural Networks
    via Attention Transfer
    code: https://github.com/szagoruyko/attention-transfer"""
    def __init__(self, p=2):
        super(Attention, self).__init__()
        self.p = p

    def forward(self, g_s, g_t):
        # one loss per (student, teacher) feature-map pair
        return [self.at_loss(f_s, f_t) for f_s, f_t in zip(g_s, g_t)]

    def at_loss(self, f_s, f_t):
        # pool the larger map so both have the same spatial size
        s_H, t_H = f_s.shape[2], f_t.shape[2]
        if s_H > t_H:
            f_s = F.adaptive_avg_pool2d(f_s, (t_H, t_H))
        elif s_H < t_H:
            f_t = F.adaptive_avg_pool2d(f_t, (s_H, s_H))
        return (self.at(f_s) - self.at(f_t)).pow(2).mean()

    def at(self, f):
        # attention map: channel-wise p-th power mean, flattened and L2-normalized
        return F.normalize(f.pow(self.p).mean(1).view(f.size(0), -1))
Adaptive average pooling first brings the two feature maps to the same spatial size; the attention maps are then compared with a mean squared error.
4. SP: Similarity-Preserving
Full title: Similarity-Preserving Knowledge Distillation
Link: https://arxiv.org/pdf/1907.09682.pdf
Published: ICCV 2019
SP belongs to the relation-based family of distillation methods. The idea is to transfer similarity-preserving knowledge, so that the teacher and the student produce similar activations for the same inputs. The pipeline: the corresponding feature maps of teacher and student are flattened, their inner products are computed to obtain bs x bs similarity matrices, and the two similarity matrices are compared with a mean squared error.
The final loss, where G denotes the bs x bs similarity matrix, is given below.
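A reconstruction in the paper's notation (G̃ is the row-wise L2-normalized Gram matrix, matching F.normalize in the code below):

$$\mathcal{L}_{SP}=\frac{1}{b^{2}}\left\lVert \widetilde{G}_{s}-\widetilde{G}_{t}\right\rVert_{F}^{2},\qquad G=f\,f^{\top},\qquad \widetilde{G}_{i,:}=G_{i,:}\,/\,\lVert G_{i,:}\rVert_{2}$$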
Implementation:
class Similarity(nn.Module):
    """Similarity-Preserving Knowledge Distillation, ICCV2019, verified by original author"""
    def __init__(self):
        super(Similarity, self).__init__()

    def forward(self, g_s, g_t):
        return [self.similarity_loss(f_s, f_t) for f_s, f_t in zip(g_s, g_t)]

    def similarity_loss(self, f_s, f_t):
        bsz = f_s.shape[0]
        f_s = f_s.view(bsz, -1)
        f_t = f_t.view(bsz, -1)
        # bs x bs Gram (similarity) matrices, row-wise L2-normalized
        # alternative: G_s = G_s / G_s.norm(2)
        G_s = torch.mm(f_s, torch.t(f_s))
        G_s = torch.nn.functional.normalize(G_s)
        G_t = torch.mm(f_t, torch.t(f_t))
        G_t = torch.nn.functional.normalize(G_t)
        # mean squared difference over all bs * bs entries
        G_diff = G_t - G_s
        loss = (G_diff * G_diff).view(-1, 1).sum(0) / (bsz * bsz)
        return loss
5. CC: Correlation Congruence
Full title: Correlation Congruence for Knowledge Distillation
Link: https://arxiv.org/pdf/1904.01802.pdf
Published: ICCV 2019
CC is also a relation-based distillation method. The argument is that distillation should not only align teacher and student on individual sample embeddings, but also transfer the correlations between pairs of samples; correlation congruence measures the distance between the teacher's and the student's correlation matrices.
The overall loss is given below.
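A reconstruction (C_t and C_s denote the batch correlation matrices of teacher and student embeddings; in the paper the congruence term is combined with cross-entropy and vanilla KD):

$$\mathcal{L}_{CC}=\frac{1}{n^{2}}\big\lVert C_{t}-C_{s}\big\rVert_{F}^{2},\qquad \mathcal{L}_{CCKD}=\alpha\,\mathcal{L}_{CE}+(1-\alpha)\,\mathcal{L}_{KD}+\beta\,\mathcal{L}_{CC}$$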
Implementation:
class Correlation(nn.Module):
    """Similarity-preserving loss. My original reimplementation
    based on the paper before emailing the original authors."""
    def __init__(self):
        super(Correlation, self).__init__()

    def forward(self, f_s, f_t):
        return self.similarity_loss(f_s, f_t)

    def similarity_loss(self, f_s, f_t):
        bsz = f_s.shape[0]
        f_s = f_s.view(bsz, -1)
        f_t = f_t.view(bsz, -1)
        # Gram matrices, each normalized by its norm as a whole
        G_s = torch.mm(f_s, torch.t(f_s))
        G_s = G_s / G_s.norm(2)
        G_t = torch.mm(f_t, torch.t(f_t))
        G_t = G_t / G_t.norm(2)
        G_diff = G_t - G_s
        loss = (G_diff * G_diff).view(-1, 1).sum(0) / (bsz * bsz)
        return loss
6. VID: Variational Information Distillation
Full title: Variational Information Distillation for Knowledge Transfer
Link: https://arxiv.org/pdf/1904.05835.pdf
Published: CVPR 2019
VID measures the gap between student and teacher with mutual information (MI). MI quantifies how much two variables depend on each other: the larger its value, the stronger the dependence. It is computed as follows.
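Reconstructed from the description that follows:

$$I(t;s)=H(t)-H(t\mid s)$$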
The mutual information is the entropy of the teacher minus the entropy of the teacher conditioned on the student. The objective is to maximize it: a larger I(t;s) means a smaller H(t|s), i.e. once the student is known there is little uncertainty left about the teacher, which indicates the student has learned the teacher sufficiently well.
The overall loss follows from this objective. Since the true p(t|s) is hard to compute, a variational distribution q(t|s) is used to approximate it,
where q(t|s) is modeled as a Gaussian with learnable variance (the log_scale parameter in the code):
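A sketch of the resulting term, matching neg_log_prob in the code below (μ(s) is the regressor output, σ_c² the learned per-channel variance; additive constants dropped):

$$-\log q(t\mid s)\;\propto\;\frac{1}{2}\sum_{c,h,w}\left[\frac{\big(t_{c,h,w}-\mu_{c,h,w}(s)\big)^{2}}{\sigma_{c}^{2}}+\log \sigma_{c}^{2}\right]$$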
Implementation:
class VIDLoss(nn.Module):
    """Variational Information Distillation for Knowledge Transfer (CVPR 2019),
    code from author: https://github.com/ssahn0215/variational-information-distillation"""
    def __init__(self,
                 num_input_channels,
                 num_mid_channel,
                 num_target_channels,
                 init_pred_var=5.0,
                 eps=1e-5):
        super(VIDLoss, self).__init__()

        def conv1x1(in_channels, out_channels, stride=1):
            return nn.Conv2d(
                in_channels, out_channels,
                kernel_size=1, padding=0,
                bias=False, stride=stride)

        # mean predictor mu(s): three 1x1 convolutions
        self.regressor = nn.Sequential(
            conv1x1(num_input_channels, num_mid_channel),
            nn.ReLU(),
            conv1x1(num_mid_channel, num_mid_channel),
            nn.ReLU(),
            conv1x1(num_mid_channel, num_target_channels),
        )
        # per-channel variance, kept positive through a softplus parameterization
        self.log_scale = torch.nn.Parameter(
            np.log(np.exp(init_pred_var - eps) - 1.0) * torch.ones(num_target_channels)
        )
        self.eps = eps

    def forward(self, input, target):
        # pool for dimension match
        s_H, t_H = input.shape[2], target.shape[2]
        if s_H > t_H:
            input = F.adaptive_avg_pool2d(input, (t_H, t_H))
        elif s_H < t_H:
            target = F.adaptive_avg_pool2d(target, (s_H, s_H))

        pred_mean = self.regressor(input)
        pred_var = torch.log(1.0 + torch.exp(self.log_scale)) + self.eps  # softplus
        pred_var = pred_var.view(1, -1, 1, 1)
        # Gaussian negative log-likelihood (up to constants)
        neg_log_prob = 0.5 * (
            (pred_mean - target) ** 2 / pred_var + torch.log(pred_var)
        )
        loss = torch.mean(neg_log_prob)
        return loss
7. RKD: Relational Knowledge Distillation
Full title: Relational Knowledge Distillation
Link: http://arxiv.org/pdf/1904.05068
Published: CVPR 2019
RKD is another relation-based distillation method. It proposes two loss functions, a second-order distance-wise loss and a third-order angle-wise loss, reconstructed after the list below.
- Distance-wise Loss
- Angle-wise Loss
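A reconstruction of the two potentials and losses (μ is the mean pairwise distance used for normalization, matching mean_td in the code; l_δ is the smooth-L1 / Huber loss):

$$\psi_{D}(t_{i},t_{j})=\frac{1}{\mu}\lVert t_{i}-t_{j}\rVert_{2},\qquad \mathcal{L}_{RKD\text{-}D}=\sum_{(i,j)}l_{\delta}\big(\psi_{D}(s_{i},s_{j}),\,\psi_{D}(t_{i},t_{j})\big)$$

$$\psi_{A}(t_{i},t_{j},t_{k})=\cos\angle\, t_{i}t_{j}t_{k},\qquad \mathcal{L}_{RKD\text{-}A}=\sum_{(i,j,k)}l_{\delta}\big(\psi_{A}(s_{i},s_{j},s_{k}),\,\psi_{A}(t_{i},t_{j},t_{k})\big)$$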
Implementation:
class RKDLoss(nn.Module):
    """Relational Knowledge Distillation, CVPR2019"""
    def __init__(self, w_d=25, w_a=50):
        super(RKDLoss, self).__init__()
        self.w_d = w_d  # weight of the distance-wise term
        self.w_a = w_a  # weight of the angle-wise term

    def forward(self, f_s, f_t):
        student = f_s.view(f_s.shape[0], -1)
        teacher = f_t.view(f_t.shape[0], -1)

        # RKD distance loss: pairwise distances, normalized by their mean
        with torch.no_grad():
            t_d = self.pdist(teacher, squared=False)
            mean_td = t_d[t_d > 0].mean()
            t_d = t_d / mean_td
        d = self.pdist(student, squared=False)
        mean_d = d[d > 0].mean()
        d = d / mean_d
        loss_d = F.smooth_l1_loss(d, t_d)

        # RKD angle loss: cosines of angles formed by sample triplets
        with torch.no_grad():
            td = (teacher.unsqueeze(0) - teacher.unsqueeze(1))
            norm_td = F.normalize(td, p=2, dim=2)
            t_angle = torch.bmm(norm_td, norm_td.transpose(1, 2)).view(-1)
        sd = (student.unsqueeze(0) - student.unsqueeze(1))
        norm_sd = F.normalize(sd, p=2, dim=2)
        s_angle = torch.bmm(norm_sd, norm_sd.transpose(1, 2)).view(-1)
        loss_a = F.smooth_l1_loss(s_angle, t_angle)

        loss = self.w_d * loss_d + self.w_a * loss_a
        return loss

    @staticmethod
    def pdist(e, squared=False, eps=1e-12):
        # pairwise (squared) Euclidean distance matrix with zeroed diagonal
        e_square = e.pow(2).sum(dim=1)
        prod = e @ e.t()
        res = (e_square.unsqueeze(1) + e_square.unsqueeze(0) - 2 * prod).clamp(min=eps)
        if not squared:
            res = res.sqrt()
        res = res.clone()
        res[range(len(e)), range(len(e))] = 0
        return res
8. PKT: Probabilistic Knowledge Transfer
Full title: Probabilistic Knowledge Transfer for deep representation learning
Link: https://arxiv.org/abs/1803.10837
Published: CoRR 2018
PKT proposes a probabilistic knowledge-transfer method, using mutual information to motivate the modeling. Its advantages include cross-modal knowledge transfer, independence from the task type, and the ability to fold hand-crafted features into the network.
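A sketch of what the code below computes: cosine similarities within a batch are rescaled and turned into row-wise probability distributions, and the student's distribution is matched to the teacher's with a KL divergence:

$$p_{ij}=\frac{k(s_{i},s_{j})}{\sum_{j'}k(s_{i},s_{j'})},\qquad q_{ij}=\frac{k(t_{i},t_{j})}{\sum_{j'}k(t_{i},t_{j'})},\qquad \mathcal{L}_{PKT}=\sum_{i,j}q_{ij}\log\frac{q_{ij}}{p_{ij}}$$

where k(·,·) is the cosine similarity rescaled to [0, 1].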
Implementation:
class PKT(nn.Module):
    """Probabilistic Knowledge Transfer for deep representation learning
    Code from author: https://github.com/passalis/probabilistic_kt"""
    def __init__(self):
        super(PKT, self).__init__()

    def forward(self, f_s, f_t):
        return self.cosine_similarity_loss(f_s, f_t)

    @staticmethod
    def cosine_similarity_loss(output_net, target_net, eps=0.0000001):
        # Normalize each vector by its norm
        output_net_norm = torch.sqrt(torch.sum(output_net ** 2, dim=1, keepdim=True))
        output_net = output_net / (output_net_norm + eps)
        output_net[output_net != output_net] = 0  # zero out NaNs

        target_net_norm = torch.sqrt(torch.sum(target_net ** 2, dim=1, keepdim=True))
        target_net = target_net / (target_net_norm + eps)
        target_net[target_net != target_net] = 0

        # Calculate the cosine similarity
        model_similarity = torch.mm(output_net, output_net.transpose(0, 1))
        target_similarity = torch.mm(target_net, target_net.transpose(0, 1))

        # Scale cosine similarity to 0..1
        model_similarity = (model_similarity + 1.0) / 2.0
        target_similarity = (target_similarity + 1.0) / 2.0

        # Transform them into probabilities
        model_similarity = model_similarity / torch.sum(model_similarity, dim=1, keepdim=True)
        target_similarity = target_similarity / torch.sum(target_similarity, dim=1, keepdim=True)

        # Calculate the KL-divergence
        loss = torch.mean(target_similarity * torch.log((target_similarity + eps) / (model_similarity + eps)))
        return loss
9. AB: Activation Boundaries
Full title: Knowledge Transfer via Distillation of Activation Boundaries Formed by Hidden Neurons
Link: https://arxiv.org/pdf/1811.03233.pdf
Published: AAAI 2019
Goal: make the activation boundaries of the student's neurons match those of the teacher as closely as possible. An activation boundary is the separating hyperplane that decides whether a neuron is activated or not (for ReLU-style activations). AB proposes an activation transfer loss that pushes the separating boundaries of teacher and student to agree, reconstructed below.
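A reconstruction of the element-wise alternative L2 loss implemented by criterion_alternative_l2 below (s is the student's pre-activation, t the teacher's, m the margin):

$$\mathcal{L}(s,t)=\sum\Big[(s+m)^{2}\,\mathbb{1}[s>-m]\,\mathbb{1}[t\le 0]+(s-m)^{2}\,\mathbb{1}[s\le m]\,\mathbb{1}[t>0]\Big]$$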
Implementation:
class ABLoss(nn.Module):
    """Knowledge Transfer via Distillation of Activation Boundaries Formed by Hidden Neurons
    code: https://github.com/bhheo/AB_distillation
    """
    def __init__(self, feat_num, margin=1.0):
        super(ABLoss, self).__init__()
        # exponentially increasing weights for deeper feature pairs
        self.w = [2 ** (i - feat_num + 1) for i in range(feat_num)]
        self.margin = margin

    def forward(self, g_s, g_t):
        bsz = g_s[0].shape[0]
        losses = [self.criterion_alternative_l2(s, t) for s, t in zip(g_s, g_t)]
        losses = [w * l for w, l in zip(self.w, losses)]
        # original scaling: divide by batch size, then by 1000 * 3
        losses = [l / bsz for l in losses]
        losses = [l / 1000 * 3 for l in losses]
        return losses

    def criterion_alternative_l2(self, source, target):
        # penalize student pre-activations on the wrong side of the teacher's activation boundary
        loss = ((source + self.margin) ** 2 * ((source > -self.margin) & (target <= 0)).float() +
                (source - self.margin) ** 2 * ((source <= self.margin) & (target > 0)).float())
        return torch.abs(loss).sum()
10. FT: Factor Transfer
Full title: Paraphrasing Complex Network: Network Compression via Factor Transfer
Link: https://arxiv.org/pdf/1802.04977.pdf
Published: NeurIPS 2018
FT proposes factor transfer. The "factor" is a matrix extracted by running the model's final feature output through an encoder-decoder; the teacher's factor is then used to guide the student's factor.
The FT loss is computed as:
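A reconstruction matching factor_loss in the code below (F_T and F_S are the teacher and student factors; both are normalized before comparison, and p is the distance order, p = 1 by default):

$$\mathcal{L}_{FT}=\Big\lVert \frac{F_{T}}{\lVert F_{T}\rVert_{2}}-\frac{F_{S}}{\lVert F_{S}\rVert_{2}}\Big\rVert_{p}$$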
Implementation:
class FactorTransfer(nn.Module):
    """Paraphrasing Complex Network: Network Compression via Factor Transfer, NeurIPS 2018"""
    def __init__(self, p1=2, p2=1):
        super(FactorTransfer, self).__init__()
        self.p1 = p1  # power used when building the factor
        self.p2 = p2  # order of the distance between factors

    def forward(self, f_s, f_t):
        return self.factor_loss(f_s, f_t)

    def factor_loss(self, f_s, f_t):
        # pool the larger map so both have the same spatial size
        s_H, t_H = f_s.shape[2], f_t.shape[2]
        if s_H > t_H:
            f_s = F.adaptive_avg_pool2d(f_s, (t_H, t_H))
        elif s_H < t_H:
            f_t = F.adaptive_avg_pool2d(f_t, (s_H, s_H))
        if self.p2 == 1:
            return (self.factor(f_s) - self.factor(f_t)).abs().mean()
        else:
            return (self.factor(f_s) - self.factor(f_t)).pow(self.p2).mean()

    def factor(self, f):
        # channel-wise power mean, flattened and L2-normalized
        return F.normalize(f.pow(self.p1).mean(1).view(f.size(0), -1))
11. FSP: Flow of Solution Procedure
Full title: A Gift from Knowledge Distillation: Fast Optimization, Network Minimization and Transfer Learning
Link: https://openaccess.thecvf.com/content_cvpr_2017/papers/Yim_A_Gift_From_CVPR_2017_paper.pdf
Published: CVPR 2017
FSP argues that teaching the student the relationships between feature maps of different layers works better than teaching it the final outputs directly.
It defines the FSP matrix, a Gram matrix between two feature layers, to capture how the network transforms information between layers, which is taken to reflect how the teacher solves the problem; see the reconstruction below.
An L2 loss constrains the student's FSP matrices to match the teacher's.
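The FSP matrix between two layers with feature maps F¹ (m channels) and F² (n channels) of spatial size h × w:

$$G_{i,j}(x;W)=\sum_{s=1}^{h}\sum_{t=1}^{w}\frac{F^{1}_{s,t,i}(x;W)\times F^{2}_{s,t,j}(x;W)}{h\times w}$$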
Implementation:
class FSP(nn.Module):
    """A Gift from Knowledge Distillation:
    Fast Optimization, Network Minimization and Transfer Learning"""
    def __init__(self, s_shapes, t_shapes):
        super(FSP, self).__init__()
        assert len(s_shapes) == len(t_shapes), 'unequal length of feat list'
        s_c = [s[1] for s in s_shapes]
        t_c = [t[1] for t in t_shapes]
        if np.any(np.asarray(s_c) != np.asarray(t_c)):
            raise ValueError('num of channels not equal (error in FSP)')

    def forward(self, g_s, g_t):
        s_fsp = self.compute_fsp(g_s)
        t_fsp = self.compute_fsp(g_t)
        loss_group = [self.compute_loss(s, t) for s, t in zip(s_fsp, t_fsp)]
        return loss_group

    @staticmethod
    def compute_loss(s, t):
        return (s - t).pow(2).mean()

    @staticmethod
    def compute_fsp(g):
        fsp_list = []
        for i in range(len(g) - 1):
            bot, top = g[i], g[i + 1]
            # pool the larger map so both layers share a spatial size
            b_H, t_H = bot.shape[2], top.shape[2]
            if b_H > t_H:
                bot = F.adaptive_avg_pool2d(bot, (t_H, t_H))
            elif b_H < t_H:
                top = F.adaptive_avg_pool2d(top, (b_H, b_H))
            # Gram matrix between the two layers, averaged over spatial positions
            bot = bot.unsqueeze(1)
            top = top.unsqueeze(2)
            bot = bot.view(bot.shape[0], bot.shape[1], bot.shape[2], -1)
            top = top.view(top.shape[0], top.shape[1], top.shape[2], -1)
            fsp = (bot * top).mean(-1)
            fsp_list.append(fsp)
        return fsp_list
12. NST: Neuron Selectivity Transfer
Full title: Like What You Like: Knowledge Distill via Neuron Selectivity Transfer
Link: https://arxiv.org/pdf/1707.01219.pdf
Published: CoRR 2017
NST minimizes a new loss, the Maximum Mean Discrepancy (MMD) between teacher and student; specifically, it aligns the distributions of the two networks' neuron selectivity patterns.
Applying the kernel trick (the poly kernel below) and expanding the squared MMD gives:
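A reconstruction of the expanded squared MMD (t^i and s^j are the normalized per-channel activation patterns of teacher and student, C_T and C_S the channel counts, k a kernel):

$$\mathcal{L}_{\mathrm{MMD}^{2}}=\frac{1}{C_{T}^{2}}\sum_{i=1}^{C_{T}}\sum_{i'=1}^{C_{T}}k(t^{i},t^{i'})+\frac{1}{C_{S}^{2}}\sum_{j=1}^{C_{S}}\sum_{j'=1}^{C_{S}}k(s^{j},s^{j'})-\frac{2}{C_{T}C_{S}}\sum_{i=1}^{C_{T}}\sum_{j=1}^{C_{S}}k(t^{i},s^{j})$$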
The repo actually provides three kernels: linear, polynomial, and Gaussian. Only the polynomial one is shown here, because the poly kernel complements vanilla KD well, and combining the two gives very good overall results.
Implementation:
class NSTLoss(nn.Module):
    """like what you like: knowledge distill via neuron selectivity transfer"""
    def __init__(self):
        super(NSTLoss, self).__init__()

    def forward(self, g_s, g_t):
        return [self.nst_loss(f_s, f_t) for f_s, f_t in zip(g_s, g_t)]

    def nst_loss(self, f_s, f_t):
        # pool the larger map so both have the same spatial size
        s_H, t_H = f_s.shape[2], f_t.shape[2]
        if s_H > t_H:
            f_s = F.adaptive_avg_pool2d(f_s, (t_H, t_H))
        elif s_H < t_H:
            f_t = F.adaptive_avg_pool2d(f_t, (s_H, s_H))

        # per-channel activation patterns, L2-normalized over spatial positions
        f_s = f_s.view(f_s.shape[0], f_s.shape[1], -1)
        f_s = F.normalize(f_s, dim=2)
        f_t = f_t.view(f_t.shape[0], f_t.shape[1], -1)
        f_t = F.normalize(f_t, dim=2)

        # set full_loss as False to avoid unnecessary computation
        full_loss = True
        if full_loss:
            return (self.poly_kernel(f_t, f_t).mean().detach() + self.poly_kernel(f_s, f_s).mean()
                    - 2 * self.poly_kernel(f_s, f_t).mean())
        else:
            return self.poly_kernel(f_s, f_s).mean() - 2 * self.poly_kernel(f_s, f_t).mean()

    def poly_kernel(self, a, b):
        # polynomial kernel: k(a, b) = (a . b)^2
        a = a.unsqueeze(1)
        b = b.unsqueeze(2)
        res = (a * b).sum(-1).pow(2)
        return res
13. CRD: Contrastive Representation Distillation
Full title: Contrastive Representation Distillation
Link: https://arxiv.org/abs/1910.10699v2
Published: ICLR 2020
CRD brings contrastive learning into knowledge distillation. The objective becomes: learn a representation that pulls the teacher's and student's embeddings of positive pairs as close together as possible, while pushing the embeddings of negative pairs apart.
The contrastive problem, and the resulting distillation loss, take an NCE-like form, reconstructed below.
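A sketch of the NCE-style objective that the ContrastLoss module below implements, reconstructed from the code (h denotes the critic scores stored in x, N = nce_k the number of negatives per positive, and P_n = 1/n the noise distribution over the n training samples):

$$\mathcal{L}=-\frac{1}{b}\sum_{\text{pos}}\log\frac{h}{h+N P_{n}}-\frac{1}{b}\sum_{\text{neg}}\log\frac{N P_{n}}{h'+N P_{n}},\qquad P_{n}=\frac{1}{n}$$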
Implementation (full code: https://github.com/HobbitLong/RepDistiller):
eps = 1e-7  # numerical stability constant (module-level in the original repo)

class ContrastLoss(nn.Module):
    """
    contrastive loss, corresponding to Eq (18)
    """
    def __init__(self, n_data):
        super(ContrastLoss, self).__init__()
        self.n_data = n_data

    def forward(self, x):
        bsz = x.shape[0]
        m = x.size(1) - 1  # number of negatives

        # noise distribution
        Pn = 1 / float(self.n_data)

        # loss for positive pair
        P_pos = x.select(1, 0)
        log_D1 = torch.div(P_pos, P_pos.add(m * Pn + eps)).log_()

        # loss for K negative pair
        P_neg = x.narrow(1, 1, m)
        log_D0 = torch.div(P_neg.clone().fill_(m * Pn), P_neg.add(m * Pn + eps)).log_()

        loss = - (log_D1.sum(0) + log_D0.view(-1, 1).sum(0)) / bsz
        return loss
class CRDLoss(nn.Module):
    """CRD Loss function
    includes two symmetric parts:
    (a) using teacher as anchor, choose positive and negatives over the student side
    (b) using student as anchor, choose positive and negatives over the teacher side
    Args:
        opt.s_dim: the dimension of student's feature
        opt.t_dim: the dimension of teacher's feature
        opt.feat_dim: the dimension of the projection space
        opt.nce_k: number of negatives paired with each positive
        opt.nce_t: the temperature
        opt.nce_m: the momentum for updating the memory buffer
        opt.n_data: the number of samples in the training set, therefore the memory buffer is: opt.n_data x opt.feat_dim
    """
    def __init__(self, opt):
        super(CRDLoss, self).__init__()
        # Embed and ContrastMemory are defined elsewhere in the repo linked above
        self.embed_s = Embed(opt.s_dim, opt.feat_dim)
        self.embed_t = Embed(opt.t_dim, opt.feat_dim)
        self.contrast = ContrastMemory(opt.feat_dim, opt.n_data, opt.nce_k, opt.nce_t, opt.nce_m)
        self.criterion_t = ContrastLoss(opt.n_data)
        self.criterion_s = ContrastLoss(opt.n_data)

    def forward(self, f_s, f_t, idx, contrast_idx=None):
        """
        Args:
            f_s: the feature of student network, size [batch_size, s_dim]
            f_t: the feature of teacher network, size [batch_size, t_dim]
            idx: the indices of these positive samples in the dataset, size [batch_size]
            contrast_idx: the indices of negative samples, size [batch_size, nce_k]
        Returns:
            The contrastive loss
        """
        f_s = self.embed_s(f_s)
        f_t = self.embed_t(f_t)
        out_s, out_t = self.contrast(f_s, f_t, idx, contrast_idx)
        s_loss = self.criterion_s(out_s)
        t_loss = self.criterion_t(out_t)
        loss = s_loss + t_loss
        return loss
14. Overhaul
Full title: A Comprehensive Overhaul of Feature Distillation
Link: http://openaccess.thecvf.com/content_ICCV_2019/papers/Heo_A_Comprehensive_Overhaul_of_Feature_Distillation_ICCV_2019_paper.pdf
Published: ICCV 2019
- Teacher transform: a margin ReLU activation (reconstructed after this list).
- Student transform: a 1x1 convolution.
- Distillation feature position: pre-ReLU.
- Distance function: a partial L2 loss.
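A reconstruction of the margin ReLU and its margin, matching get_margin in the code below (m_c is the per-channel expectation of the negative teacher responses):

$$\sigma_{m}(x)=\max(x,\,m_{c}),\qquad m_{c}=\mathbb{E}\big[F^{t}_{c}\mid F^{t}_{c}<0\big]$$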
A partial implementation:
class OFD(nn.Module):
    '''
    A Comprehensive Overhaul of Feature Distillation
    http://openaccess.thecvf.com/content_ICCV_2019/papers/Heo_A_Comprehensive_Overhaul_of_Feature_Distillation_ICCV_2019_paper.pdf
    '''
    def __init__(self, in_channels, out_channels):
        super(OFD, self).__init__()
        # student transform: 1x1 convolution + batch norm
        self.connector = nn.Sequential(*[
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(out_channels)
        ])

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, fm_s, fm_t):
        # teacher transform: margin ReLU
        margin = self.get_margin(fm_t)
        fm_t = torch.max(fm_t, margin)
        fm_s = self.connector(fm_s)

        # partial L2: skip positions where the student is already below an inactive teacher
        mask = 1.0 - ((fm_s <= fm_t) & (fm_t <= 0.0)).float()
        loss = torch.mean((fm_s - fm_t) ** 2 * mask)
        return loss

    def get_margin(self, fm, eps=1e-6):
        # per-channel mean of the negative teacher responses
        mask = (fm < 0.0).float()
        masked_fm = fm * mask
        margin = masked_fm.sum(dim=(0, 2, 3), keepdim=True) / (mask.sum(dim=(0, 2, 3), keepdim=True) + eps)
        return margin