DARTS代碼閱讀

0x00 背景知識

先放上一篇綜述文章揩晴,對于理解NAS(網(wǎng)絡(luò)結(jié)構(gòu)搜索)的問題有很大的幫助:https://blog.csdn.net/c9Yv2cf9I06K2A9E/article/details/82321884
另外,DARTS搜索墩莫,強(qiáng)烈建議先看下inception的網(wǎng)絡(luò)結(jié)構(gòu)和nasnet的論文是复,DARTS的論文基礎(chǔ)是建立在之上的懦傍,某種程度上可以看做是對nasnet的優(yōu)化只锭。

0x01 搜索思路

基于前人的經(jīng)驗(yàn)(inception/nasnet)鱼蝉,DARTS使用cell作為模型結(jié)構(gòu)搜索的基礎(chǔ)單元洒嗤,所學(xué)習(xí)的單元堆疊成卷積網(wǎng)絡(luò),也可以遞歸連接形成遞歸網(wǎng)絡(luò)魁亦。
cell內(nèi)節(jié)點(diǎn)間先默認(rèn)所有可能的操作連接渔隶,每個(gè)連接初始化權(quán)重參數(shù)值,結(jié)構(gòu)搜索也就是訓(xùn)練這些權(quán)重參數(shù)洁奈,最終兩節(jié)點(diǎn)間選取權(quán)重最大的操作作為最終結(jié)構(gòu)參數(shù)间唉。

訓(xùn)練過程中,交替訓(xùn)練網(wǎng)絡(luò)結(jié)構(gòu)參數(shù)和網(wǎng)絡(luò)參數(shù)利术。

0x02 代碼定義

genotype結(jié)構(gòu)定義

normal=[(‘sep_conv_3x3’, 0), (‘sep_conv_3x3’, 1), (‘sep_conv_3x3’, 0), (‘sep_conv_3x3’, 1), (‘sep_conv_3x3’, 1), (‘skip_connect’, 0), (‘skip_connect’, 0), (‘dil_conv_3x3’, 2)], normal_concat=[2, 3, 4, 5]

取了genotype里的一個(gè)normal cell的定義及其對應(yīng)的cell結(jié)構(gòu)圖首先說明下呈野,這個(gè)定義的解釋。DARTS搜索的也就是這個(gè)定義印叁。
normal定義里(‘sep_conv_3x3’, 1)的0被冒,1,2轮蜕,3昨悼,4,5對應(yīng)到圖中的紅色字體標(biāo)注的跃洛。
從normal文字定義兩個(gè)元組一組幔戏,映射到圖中一個(gè)藍(lán)色方框的節(jié)點(diǎn)(這個(gè)是作者搜索出來的結(jié)構(gòu),結(jié)構(gòu)不一樣税课,對應(yīng)關(guān)系不一定是這樣的)
sep_conv_xxxx表示操作闲延,0/1表示輸入來源
(‘sep_conv_3x3’, 1), (‘sep_conv_3x3’, 0) —-> 節(jié)點(diǎn)0
(‘sep_conv_3x3’, 0), (‘sep_conv_3x3’, 1) —-> 節(jié)點(diǎn)1
(‘sep_conv_3x3’, 1), (‘skip_connect’, 0) —-> 節(jié)點(diǎn)2
(‘skip_connect’, 0), (‘dil_conv_3x3’, 2) —-> 節(jié)點(diǎn)3
normal_concat=[2, 3, 4, 5] —-> cell輸出c_{k}

DARTS搜索NOTE

首先明確痊剖,DARTS搜索實(shí)際只搜cell內(nèi)結(jié)構(gòu),整個(gè)模型的網(wǎng)絡(luò)結(jié)構(gòu)是預(yù)定好的垒玲,比如多少層陆馁,網(wǎng)絡(luò)寬度,cell內(nèi)幾個(gè)節(jié)點(diǎn)等合愈;
在構(gòu)建搜索的網(wǎng)絡(luò)結(jié)構(gòu)時(shí)叮贩,有幾個(gè)特別的地方:
1.預(yù)構(gòu)建cell時(shí),采用的一個(gè)MixedOp:包含了兩個(gè)節(jié)點(diǎn)所有可能的連接(genotype中的PRIMITIVES)佛析;
2.初始化了一個(gè)alphas矩陣益老,網(wǎng)絡(luò)做forward時(shí),參數(shù)傳入寸莫,在cell里使用捺萌,搜索過程中所有可能連接都在時(shí),計(jì)算mixedOp的輸出膘茎,采用加權(quán)的形式桃纯。
3.訓(xùn)練過程對train數(shù)據(jù)每個(gè)step又切成兩份: train和validate, train用來訓(xùn)練網(wǎng)絡(luò)參數(shù),validate用來訓(xùn)練結(jié)構(gòu)參數(shù)披坏。

0x03 關(guān)鍵代碼片段

以下把代碼中一些關(guān)鍵的态坦,影響到理解DARTS的地方說明一下:

  • file: train_search.py 第149行
    architect.step(input, target, input_search, target_search, lr, optimizer, unrolled=args.unrolled)
  logits = model(input)
  loss = criterion(logits, target)
  loss.backward()
  nn.utils.clip_grad_norm(model.parameters(), args.grad_clip)
  optimizer.step()

這里就是論文里近似后的交叉梯度下降,其中architect.step()是結(jié)構(gòu)參數(shù)weights的梯度下降棒拂,optimizer.step()是網(wǎng)絡(luò)參數(shù)的梯度下降伞梯。

  • file: model_search.py
class MixedOp(nn.Module):
  def __init__(self, C, stride):
    super(MixedOp, self).__init__()
    self._ops = nn.ModuleList()
    for primitive in PRIMITIVES:
      op = OPS[primitive](C, stride, False)
      if 'pool' in primitive:
        op = nn.Sequential(op, nn.BatchNorm2d(C, affine=False))
      self._ops.append(op)
  def forward(self, x, weights):
    return sum(w * op(x) for w, op in zip(weights, self._ops)) # weighted op

這個(gè)是MixedOp,兩節(jié)點(diǎn)間操作把PRIMITIVES里定義的所有操作都連接上帚屉,計(jì)算輸出時(shí)利用傳入的weights進(jìn)行加權(quán)壮锻。

  • file: model_search.py第47行
def forward(self, s0, s1, weights):
    s0 = self.preprocess0(s0)
    s1 = self.preprocess1(s1)
    states = [s0, s1]
    offset = 0
    for i in range(self._steps):
      s = sum(self._ops[offset+j](h, weights[offset+j]) for j, h in enumerate(states)) # all nodes before can be input, mixop.
      offset += len(states) #0, 2, 5, 9
      states.append(s)
    return torch.cat(states[-self._multiplier:], dim=1)

self.ops[], 實(shí)際是14(2+3+4+5)個(gè)MixedOp,2+3+4+5的解釋涮阔,對于第一個(gè)內(nèi)部節(jié)點(diǎn)猜绣,有兩個(gè)可能的輸入(c{k-1}, c_{k-2}),對于第二個(gè)內(nèi)部節(jié)點(diǎn)敬特,有三個(gè)可能的輸入(兩個(gè)同節(jié)點(diǎn)1掰邢,另加上第一個(gè)節(jié)點(diǎn))……
代碼里,weights[]伟阔,也是一個(gè)長度14的list辣之,前2個(gè)對應(yīng)到第一個(gè)節(jié)點(diǎn)的兩個(gè)輸入的權(quán)重,第3~5這3個(gè)元素對應(yīng)到第二個(gè)節(jié)點(diǎn)的三個(gè)輸入的權(quán)重……這就是上面代碼里offset的作用

  • file: architect.py 第11行
class Architect(object):
  def __init__(self, model, args):
    self.network_momentum = args.momentum
    self.network_weight_decay = args.weight_decay
    self.model = model
    self.optimizer = torch.optim.Adam(self.model.arch_parameters(),   #arch_parameters, 
        lr=args.arch_learning_rate, betas=(0.5, 0.999), weight_decay=args.arch_weight_decay) 

需要注意的是Architect里optimizer優(yōu)化器的參數(shù)是model.arch_parameters(), 這個(gè)對應(yīng)到的是model_search.py里定義的._arch_parameters皱炉,及初始化的各節(jié)點(diǎn)連接的權(quán)重怀估。
def _initialize_alphas(self):
k = sum(1 for i in range(self._steps) for n in range(2+i)) # 2+i, 2 for two inputs, i=0,1,2,3, nodes before this. 2+3+4+5
num_ops = len(PRIMITIVES)

self.alphas_normal = Variable(1e-3*torch.randn(k, num_ops).cuda(), requires_grad=True)
    self.alphas_reduce = Variable(1e-3*torch.randn(k, num_ops).cuda(), requires_grad=True)
    self._arch_parameters = [
      self.alphas_normal,
      self.alphas_reduce,
    ]

  • file: model_search.py 第133行
def _parse(weights):
      #  weights: [2 + 3 + 4 + 5][len(PRIMITIVES)]
      gene = []
      n = 2
      start = 0
      for i in range(self._steps): #ch: steps = 4
        end = start + n 
        print('start=', start, 'end=', end, 'n=', n)
        W = weights[start:end].copy()
        print(W) # ch: add
        # chenhua: for x, -max(W[x][...]), W[][] is the parameters for architect. lambda elect out the OP weights most.
        edges = sorted(range(i + 2), key=lambda x: -max(W[x][k] for k in range(len(W[x])) if k != PRIMITIVES.index('none')))[:2]
        print(edges)
        for j in edges: #ch: j, edges mean op, all possible ops between two node
          print(j)
          k_best = None
          for k in range(len(W[j])):  #ch: k, the weights for possible connection?
            if k != PRIMITIVES.index('none'):
              if k_best is None or W[j][k] > W[j][k_best]:
                print('W[j][k]=', W[j][k], 'W[j][k_best]=', W[j][k_best])
                k_best = k
          gene.append((PRIMITIVES[k_best], j))  #ch: find ????
        start = end
        n += 1
      return gene
    # ch: alphas_xxx, parameters for architect??
    gene_normal = _parse(F.softmax(self.alphas_normal, dim=-1).data.cpu().numpy())
    gene_reduce = _parse(F.softmax(self.alphas_reduce, dim=-1).data.cpu().numpy())
    concat = range(2+self._steps-self._multiplier, self._steps+2) #ch: step=4, mltiplier=3
    print('concat', concat)
    genotype = Genotype(
      normal=gene_normal, normal_concat=concat,
      reduce=gene_reduce, reduce_concat=concat
    )
    print('genotype=', genotype)
    return genotype

搜索過程中搜索出的結(jié)果(節(jié)點(diǎn)間的op)的打印,就是靠這個(gè)函數(shù)。
核心是找出兩個(gè)節(jié)點(diǎn)間不為none的所有ops中權(quán)重最大的多搀,就是最終的結(jié)果歧蕉。
注意:weights[][]的size是[2 + 3 + 4 + 5][len(PRIMITIVES)]

參考鏈接

  1. https://cloud.tencent.com/developer/article/1348049
  2. https://blog.csdn.net/srdlaplace/article/details/80863346
  3. https://www.jiqizhixin.com/articles/2018-06-27-6
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個(gè)濱河市康铭,隨后出現(xiàn)的幾起案子惯退,更是在濱河造成了極大的恐慌,老刑警劉巖从藤,帶你破解...
    沈念sama閱讀 216,651評論 6 501
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件催跪,死亡現(xiàn)場離奇詭異,居然都是意外死亡夷野,警方通過查閱死者的電腦和手機(jī)懊蒸,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 92,468評論 3 392
  • 文/潘曉璐 我一進(jìn)店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來悯搔,“玉大人骑丸,你說我怎么就攤上這事”罟拢” “怎么了者娱?”我有些...
    開封第一講書人閱讀 162,931評論 0 353
  • 文/不壞的土叔 我叫張陵抡笼,是天一觀的道長苏揣。 經(jīng)常有香客問我,道長推姻,這世上最難降的妖魔是什么平匈? 我笑而不...
    開封第一講書人閱讀 58,218評論 1 292
  • 正文 為了忘掉前任,我火速辦了婚禮藏古,結(jié)果婚禮上增炭,老公的妹妹穿的比我還像新娘。我一直安慰自己拧晕,他們只是感情好隙姿,可當(dāng)我...
    茶點(diǎn)故事閱讀 67,234評論 6 388
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著厂捞,像睡著了一般输玷。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發(fā)上靡馁,一...
    開封第一講書人閱讀 51,198評論 1 299
  • 那天欲鹏,我揣著相機(jī)與錄音,去河邊找鬼臭墨。 笑死赔嚎,一個(gè)胖子當(dāng)著我的面吹牛,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播尤误,決...
    沈念sama閱讀 40,084評論 3 418
  • 文/蒼蘭香墨 我猛地睜開眼侠畔,長吁一口氣:“原來是場噩夢啊……” “哼!你這毒婦竟也來了袄膏?” 一聲冷哼從身側(cè)響起践图,我...
    開封第一講書人閱讀 38,926評論 0 274
  • 序言:老撾萬榮一對情侶失蹤,失蹤者是張志新(化名)和其女友劉穎沉馆,沒想到半個(gè)月后码党,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 45,341評論 1 311
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡斥黑,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 37,563評論 2 333
  • 正文 我和宋清朗相戀三年揖盘,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片锌奴。...
    茶點(diǎn)故事閱讀 39,731評論 1 348
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡兽狭,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出鹿蜀,到底是詐尸還是另有隱情箕慧,我是刑警寧澤,帶...
    沈念sama閱讀 35,430評論 5 343
  • 正文 年R本政府宣布茴恰,位于F島的核電站颠焦,受9級特大地震影響,放射性物質(zhì)發(fā)生泄漏往枣。R本人自食惡果不足惜伐庭,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,036評論 3 326
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望分冈。 院中可真熱鬧圾另,春花似錦、人聲如沸雕沉。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,676評論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽坡椒。三九已至扰路,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間肠牲,已是汗流浹背幼衰。 一陣腳步聲響...
    開封第一講書人閱讀 32,829評論 1 269
  • 我被黑心中介騙來泰國打工, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留缀雳,地道東北人渡嚣。 一個(gè)月前我還...
    沈念sama閱讀 47,743評論 2 368
  • 正文 我出身青樓,卻偏偏與公主長得像,于是被迫代替她去往敵國和親识椰。 傳聞我的和親對象是個(gè)殘疾皇子绝葡,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 44,629評論 2 354

推薦閱讀更多精彩內(nèi)容