小米實(shí)驗(yàn)室 AutoML 團(tuán)隊(duì)的NAS工作,論文題目:Fair DARTS: Eliminating Unfair Advantages in Differentiable Architecture Search唠倦。 針對(duì)現(xiàn)有DARTS框架在搜索階段訓(xùn)練過程中存在 skip-connection 富集現(xiàn)象,導(dǎo)致最終模型出現(xiàn)大幅度的性能損失問題的問題墩朦,提出了Sigmoid替代Softmax的方法拔莱,使搜索階段候選操作由競(jìng)爭(zhēng)關(guān)系轉(zhuǎn)化為合作關(guān)系鳖敷。 并提出 0-1 loss 提高了架構(gòu)參數(shù)的二值性。
動(dòng)機(jī)
skip-connection 富集現(xiàn)象
本文指出 skip connections 富集的原因主要有兩個(gè)方面:
skip connections 的不公平優(yōu)勢(shì)
在超網(wǎng)絡(luò)訓(xùn)練架構(gòu)參數(shù)過程中挽唉,兩個(gè)節(jié)點(diǎn)之間是八個(gè)操作同時(shí)作用的滤祖, skip connections 作為操作的其中一員,相較于其他的操作來講是起到了跳躍連接的作用瓶籽。在ResNet 中已經(jīng)明確指出了跳躍連接在深層網(wǎng)絡(luò)的訓(xùn)練過程中中起到了良好的梯度疏通效果匠童,進(jìn)而有效減緩了梯度消失現(xiàn)象。因此塑顺,在超網(wǎng)絡(luò)的搜索訓(xùn)練過程中汤求,skip connections可以借助其他操作的關(guān)系達(dá)到疏通效果,使得严拒,skip connections 相較于其他操作存在不公平優(yōu)勢(shì)扬绪。
softmax 的排外競(jìng)爭(zhēng)
由于 softmax 是典型的歸一化操作,是一種潛在的排外競(jìng)爭(zhēng)方式裤唠,致使一個(gè)架構(gòu)參數(shù)增大必然抑制其他參數(shù)挤牛。
部署訓(xùn)練的離散化差異(discretization discrepancy)
搜索過程結(jié)束后,在部署訓(xùn)練選取網(wǎng)絡(luò)架構(gòu)時(shí)种蘸,直接將 softmax 后最大 α 值對(duì)應(yīng)的操作保留而拋棄其它的操作墓赴,從而使得選出的網(wǎng)絡(luò)結(jié)構(gòu)和原始包含所有結(jié)構(gòu)的超網(wǎng)二者的表現(xiàn)能力存在差距。離散化差異問題主要在于兩點(diǎn)航瞭,一方面Softmax歸一化八種操作參數(shù)后诫硕,DARTS 最后選擇時(shí)的 α 值基本都在 0.1 到 0.3 之間,另一方面判定好壞的范圍比較窄刊侯,因?yàn)椴煌僮?α 值的 top1 和 top2 可能差距特別小痘括,例如 0.26 和 0.24,很難說 0.26 就一定比 0.24 好,如下圖所示:
方法
sigmoid 函數(shù)替換 softmax
class Network(nn.Module):
def __init__(self, C, num_classes, layers, criterion, steps=4, multiplier=4, stem_multiplier=3,parse_method='darts', op_threshold=None):
pass
def forward(self, input):
s0 = s1 = self.stem(input)
for i, cell in enumerate(self.cells):
if cell.reduction:
weights = F.sigmoid(self.alphas_reduce) # sigmoid 替換softmax
else:
weights = F.sigmoid(self.alphas_normal) # sigmoid 替換softmax
s0, s1 = s1, cell(s0, s1, weights)
out = self.global_pooling(s1)
logits = self.classifier(out.view(out.size(0),-1))
return logits
0-1 損失函數(shù)
l2 0-1 損失函數(shù)
l1 0-1 損失函數(shù)
# l2
class ConvSeparateLoss(nn.modules.loss._Loss):
"""Separate the weight value between each operations using L2"""
def __init__(self, weight=0.1, size_average=None, ignore_index=-100,reduce=None, reduction='mean'):
super(ConvSeparateLoss, self).__init__(size_average, reduce, reduction)
self.ignore_index = ignore_index
self.weight = weight
def forward(self, input1, target1, input2):
loss1 = F.cross_entropy(input1, target1)
loss2 = -F.mse_loss(input2, torch.tensor(0.5, requires_grad=False).cuda())
return loss1 + self.weight*loss2, loss1.item(), loss2.item()
# l1
class TriSeparateLoss(nn.modules.loss._Loss):
"""Separate the weight value between each operations using L1"""
def __init__(self, weight=0.1, size_average=None, ignore_index=-100,
reduce=None, reduction='mean'):
super(TriSeparateLoss, self).__init__(size_average, reduce, reduction)
self.ignore_index = ignore_index
self.weight = weight
def forward(self, input1, target1, input2):
loss1 = F.cross_entropy(input1, target1)
loss2 = -F.l1_loss(input2, torch.tensor(0.5, requires_grad=False).cuda())
return loss1 + self.weight*loss2, loss1.item(), loss2.item()
使用上述損失函數(shù)就可以使得不同操作之間的差距增大纲菌,二者的 α 值要么逼近 0 要么逼近 1 如下圖曲線所示
實(shí)驗(yàn)
CIFAR-10
精度比較
FairDARTS 搜索 7 次均可得到魯棒性的結(jié)果:
skip connections 數(shù)量比較
DARTS 和 Fair DARTS 搜索出來的 cell 中所包含的 skip connections 數(shù)量比較:
ImageNet
精度比較
注意模型 A、B 是遷移比較疮绷,C翰舌、D 是直接搜索比較。
sigmoid 函數(shù)的共存性
熱力圖可看出使用 sigmoid 函數(shù)可讓其他操作和 skip connections 共存:
消融實(shí)驗(yàn)
去掉 Skip Connections
由于不公平的優(yōu)勢(shì)主要來自 Skip Connections冬骚,因此椅贱,搜索空間去掉 Skip Connections,那么即使在排他性競(jìng)爭(zhēng)中只冻,其他操作也應(yīng)該期待公平競(jìng)爭(zhēng)庇麦。 去掉 Skip Connections搜索得到的最佳模型(96.88±0.18%)略高于DARTS(96.76±0.32%),但低于FairDARTS(97.41±0.14%)喜德。 降低的精度表明足夠的 Skip Connections 確實(shí)對(duì)精度有益山橄,因此也不能簡(jiǎn)單去掉。
0-1 損失函數(shù)分析
- 如果去掉 0-1 損失函數(shù)會(huì)使得 α 值不再集中于兩端舍悯,不利于離散化;
- 損失靈敏度航棱,即通過超參來控制
損失函數(shù)的靈敏度
討論
對(duì)于 skip connections 使用 dropout 可以減少了不公平性;
對(duì)所有操作使用 dropout 同樣是有幫助的萌衬;
早停機(jī)制同樣關(guān)鍵(相當(dāng)于是在不公平出現(xiàn)以前及時(shí)止損)饮醇;
限制 skip connections 的數(shù)量需要極大的人為先驗(yàn),因?yàn)橹灰薅?skip connections 的數(shù)量為 2秕豫,隨機(jī)搜索也能獲得不錯(cuò)的結(jié)果朴艰;
高斯噪聲或許也能打破不公平優(yōu)勢(shì)(孕育出了后面的NoisyDARTS~)。
參考
[1] Fair DARTS: Eliminating Unfair Advantages in Differentiable Architecture Search
[2] DARTS+: Improved Differentiable Architecture Search with Early Stopping
[3] Noisy Differentiable Architecture Search
[4] Fair DARTS:公平的可微分神經(jīng)網(wǎng)絡(luò)搜索
[5] Fair darts代碼解析