0x00 背景知識
先放上一篇綜述文章揩晴,對于理解NAS(網(wǎng)絡(luò)結(jié)構(gòu)搜索)的問題有很大的幫助:https://blog.csdn.net/c9Yv2cf9I06K2A9E/article/details/82321884
另外,DARTS搜索墩莫,強(qiáng)烈建議先看下inception的網(wǎng)絡(luò)結(jié)構(gòu)和nasnet的論文是复,DARTS的論文基礎(chǔ)是建立在之上的懦傍,某種程度上可以看做是對nasnet的優(yōu)化只锭。
0x01 搜索思路
基于前人的經(jīng)驗(yàn)(inception/nasnet)鱼蝉,DARTS使用cell作為模型結(jié)構(gòu)搜索的基礎(chǔ)單元洒嗤,所學(xué)習(xí)的單元堆疊成卷積網(wǎng)絡(luò),也可以遞歸連接形成遞歸網(wǎng)絡(luò)魁亦。
cell內(nèi)節(jié)點(diǎn)間先默認(rèn)所有可能的操作連接渔隶,每個(gè)連接初始化權(quán)重參數(shù)值,結(jié)構(gòu)搜索也就是訓(xùn)練這些權(quán)重參數(shù)洁奈,最終兩節(jié)點(diǎn)間選取權(quán)重最大的操作作為最終結(jié)構(gòu)參數(shù)间唉。
訓(xùn)練過程中,交替訓(xùn)練網(wǎng)絡(luò)結(jié)構(gòu)參數(shù)和網(wǎng)絡(luò)參數(shù)利术。
0x02 代碼定義
genotype結(jié)構(gòu)定義
normal=[(‘sep_conv_3x3’, 0), (‘sep_conv_3x3’, 1), (‘sep_conv_3x3’, 0), (‘sep_conv_3x3’, 1), (‘sep_conv_3x3’, 1), (‘skip_connect’, 0), (‘skip_connect’, 0), (‘dil_conv_3x3’, 2)], normal_concat=[2, 3, 4, 5]
取了genotype里的一個(gè)normal cell的定義及其對應(yīng)的cell結(jié)構(gòu)圖首先說明下呈野,這個(gè)定義的解釋。DARTS搜索的也就是這個(gè)定義印叁。
normal定義里(‘sep_conv_3x3’, 1)的0被冒,1,2轮蜕,3昨悼,4,5對應(yīng)到圖中的紅色字體標(biāo)注的跃洛。
從normal文字定義兩個(gè)元組一組幔戏,映射到圖中一個(gè)藍(lán)色方框的節(jié)點(diǎn)(這個(gè)是作者搜索出來的結(jié)構(gòu),結(jié)構(gòu)不一樣税课,對應(yīng)關(guān)系不一定是這樣的)
sep_conv_xxxx表示操作闲延,0/1表示輸入來源
(‘sep_conv_3x3’, 1), (‘sep_conv_3x3’, 0) —-> 節(jié)點(diǎn)0
(‘sep_conv_3x3’, 0), (‘sep_conv_3x3’, 1) —-> 節(jié)點(diǎn)1
(‘sep_conv_3x3’, 1), (‘skip_connect’, 0) —-> 節(jié)點(diǎn)2
(‘skip_connect’, 0), (‘dil_conv_3x3’, 2) —-> 節(jié)點(diǎn)3
normal_concat=[2, 3, 4, 5] —-> cell輸出c_{k}
DARTS搜索NOTE
首先明確痊剖,DARTS搜索實(shí)際只搜cell內(nèi)結(jié)構(gòu),整個(gè)模型的網(wǎng)絡(luò)結(jié)構(gòu)是預(yù)定好的垒玲,比如多少層陆馁,網(wǎng)絡(luò)寬度,cell內(nèi)幾個(gè)節(jié)點(diǎn)等合愈;
在構(gòu)建搜索的網(wǎng)絡(luò)結(jié)構(gòu)時(shí)叮贩,有幾個(gè)特別的地方:
1.預(yù)構(gòu)建cell時(shí),采用的一個(gè)MixedOp:包含了兩個(gè)節(jié)點(diǎn)所有可能的連接(genotype中的PRIMITIVES)佛析;
2.初始化了一個(gè)alphas矩陣益老,網(wǎng)絡(luò)做forward時(shí),參數(shù)傳入寸莫,在cell里使用捺萌,搜索過程中所有可能連接都在時(shí),計(jì)算mixedOp的輸出膘茎,采用加權(quán)的形式桃纯。
3.訓(xùn)練過程對train數(shù)據(jù)每個(gè)step又切成兩份: train和validate, train用來訓(xùn)練網(wǎng)絡(luò)參數(shù),validate用來訓(xùn)練結(jié)構(gòu)參數(shù)披坏。
0x03 關(guān)鍵代碼片段
以下把代碼中一些關(guān)鍵的态坦,影響到理解DARTS的地方說明一下:
- file: train_search.py 第149行
architect.step(input, target, input_search, target_search, lr, optimizer, unrolled=args.unrolled)
logits = model(input)
loss = criterion(logits, target)
loss.backward()
nn.utils.clip_grad_norm(model.parameters(), args.grad_clip)
optimizer.step()
這里就是論文里近似后的交叉梯度下降,其中architect.step()是結(jié)構(gòu)參數(shù)weights的梯度下降棒拂,optimizer.step()是網(wǎng)絡(luò)參數(shù)的梯度下降伞梯。
- file: model_search.py
class MixedOp(nn.Module):
def __init__(self, C, stride):
super(MixedOp, self).__init__()
self._ops = nn.ModuleList()
for primitive in PRIMITIVES:
op = OPS[primitive](C, stride, False)
if 'pool' in primitive:
op = nn.Sequential(op, nn.BatchNorm2d(C, affine=False))
self._ops.append(op)
def forward(self, x, weights):
return sum(w * op(x) for w, op in zip(weights, self._ops)) # weighted op
這個(gè)是MixedOp,兩節(jié)點(diǎn)間操作把PRIMITIVES里定義的所有操作都連接上帚屉,計(jì)算輸出時(shí)利用傳入的weights進(jìn)行加權(quán)壮锻。
- file: model_search.py第47行
def forward(self, s0, s1, weights):
s0 = self.preprocess0(s0)
s1 = self.preprocess1(s1)
states = [s0, s1]
offset = 0
for i in range(self._steps):
s = sum(self._ops[offset+j](h, weights[offset+j]) for j, h in enumerate(states)) # all nodes before can be input, mixop.
offset += len(states) #0, 2, 5, 9
states.append(s)
return torch.cat(states[-self._multiplier:], dim=1)
self.ops[], 實(shí)際是14(2+3+4+5)個(gè)MixedOp,2+3+4+5的解釋涮阔,對于第一個(gè)內(nèi)部節(jié)點(diǎn)猜绣,有兩個(gè)可能的輸入(c{k-1}, c_{k-2}),對于第二個(gè)內(nèi)部節(jié)點(diǎn)敬特,有三個(gè)可能的輸入(兩個(gè)同節(jié)點(diǎn)1掰邢,另加上第一個(gè)節(jié)點(diǎn))……
代碼里,weights[]伟阔,也是一個(gè)長度14的list辣之,前2個(gè)對應(yīng)到第一個(gè)節(jié)點(diǎn)的兩個(gè)輸入的權(quán)重,第3~5這3個(gè)元素對應(yīng)到第二個(gè)節(jié)點(diǎn)的三個(gè)輸入的權(quán)重……這就是上面代碼里offset的作用
- file: architect.py 第11行
class Architect(object):
def __init__(self, model, args):
self.network_momentum = args.momentum
self.network_weight_decay = args.weight_decay
self.model = model
self.optimizer = torch.optim.Adam(self.model.arch_parameters(), #arch_parameters,
lr=args.arch_learning_rate, betas=(0.5, 0.999), weight_decay=args.arch_weight_decay)
需要注意的是Architect里optimizer優(yōu)化器的參數(shù)是model.arch_parameters(), 這個(gè)對應(yīng)到的是model_search.py里定義的._arch_parameters皱炉,及初始化的各節(jié)點(diǎn)連接的權(quán)重怀估。
def _initialize_alphas(self):
k = sum(1 for i in range(self._steps) for n in range(2+i)) # 2+i, 2 for two inputs, i=0,1,2,3, nodes before this. 2+3+4+5
num_ops = len(PRIMITIVES)
self.alphas_normal = Variable(1e-3*torch.randn(k, num_ops).cuda(), requires_grad=True)
self.alphas_reduce = Variable(1e-3*torch.randn(k, num_ops).cuda(), requires_grad=True)
self._arch_parameters = [
self.alphas_normal,
self.alphas_reduce,
]
- file: model_search.py 第133行
def _parse(weights):
# weights: [2 + 3 + 4 + 5][len(PRIMITIVES)]
gene = []
n = 2
start = 0
for i in range(self._steps): #ch: steps = 4
end = start + n
print('start=', start, 'end=', end, 'n=', n)
W = weights[start:end].copy()
print(W) # ch: add
# chenhua: for x, -max(W[x][...]), W[][] is the parameters for architect. lambda elect out the OP weights most.
edges = sorted(range(i + 2), key=lambda x: -max(W[x][k] for k in range(len(W[x])) if k != PRIMITIVES.index('none')))[:2]
print(edges)
for j in edges: #ch: j, edges mean op, all possible ops between two node
print(j)
k_best = None
for k in range(len(W[j])): #ch: k, the weights for possible connection?
if k != PRIMITIVES.index('none'):
if k_best is None or W[j][k] > W[j][k_best]:
print('W[j][k]=', W[j][k], 'W[j][k_best]=', W[j][k_best])
k_best = k
gene.append((PRIMITIVES[k_best], j)) #ch: find ????
start = end
n += 1
return gene
# ch: alphas_xxx, parameters for architect??
gene_normal = _parse(F.softmax(self.alphas_normal, dim=-1).data.cpu().numpy())
gene_reduce = _parse(F.softmax(self.alphas_reduce, dim=-1).data.cpu().numpy())
concat = range(2+self._steps-self._multiplier, self._steps+2) #ch: step=4, mltiplier=3
print('concat', concat)
genotype = Genotype(
normal=gene_normal, normal_concat=concat,
reduce=gene_reduce, reduce_concat=concat
)
print('genotype=', genotype)
return genotype
搜索過程中搜索出的結(jié)果(節(jié)點(diǎn)間的op)的打印,就是靠這個(gè)函數(shù)。
核心是找出兩個(gè)節(jié)點(diǎn)間不為none的所有ops中權(quán)重最大的多搀,就是最終的結(jié)果歧蕉。
注意:weights[][]的size是[2 + 3 + 4 + 5][len(PRIMITIVES)]