一文詳解對抗訓(xùn)練方法

對抗訓(xùn)練方法

Adversarial learning主要是用于樣本生成或者對抗攻擊領(lǐng)域雪标,主要方法是通過添加鑒別器或者根據(jù)梯度回傳生成新樣本村刨,其主要是為了提升當(dāng)前主干模型生成樣本的能力或者魯棒性

一. 對抗訓(xùn)練定義

==對抗訓(xùn)練是一種引入噪聲的訓(xùn)練方式,可以對參數(shù)進行正則化逆粹,提升模型魯棒性和泛化能力==

1.1 對抗訓(xùn)練特點

  • 相對于原始輸入阿浓,所添加的擾動是微小的
  • 添加的噪聲可以使得模型預(yù)測錯誤

1.2 對抗訓(xùn)練的基本概念

就是在原始輸入樣本x上加上一個擾動\Delta x得到對抗樣本,再用其進行訓(xùn)練退敦,這個問題可以抽象成這樣一個模型:
\max _{\theta} P(y \mid x+\Delta x ; \theta)\tag{1}
其中,yground truth,\theta是模型參數(shù)。意思就是即使在擾動的情況下求使得預(yù)測出y的概率最大的參數(shù)网梢,擾動可以被定義為:
\Delta x=ε \cdot \operatorname{sign}\left(\nabla_{x} L(x, y ; \theta)\right)\tag{2}
其中赂毯,sign為符號函數(shù)战虏,L為損失函數(shù)

最后,GoodFellow還總結(jié)了對抗訓(xùn)練的兩個作用:

  1. 提高模型應(yīng)對惡意對抗樣本時的魯棒性
  2. 作為一種regularization党涕,減少overfitting烦感,提高泛化能力

1.3 Min-Max公式

Madry在2018年的ICLR論文Towards Deep Learning Models Resistant to Adversarial Attacks中總結(jié)了之前的工作,對抗訓(xùn)練可以統(tǒng)一寫成如下格式:
\min _{\theta} \mathbb{E}_{(x, y) \sim \mathcal{D}}\left[\max _{\Delta x \in \Omega} L(x+\Delta x, y ; \theta)\right]\tag{3}
其中\mathcal{D}代表輸入樣本的分布膛堤,x代表輸入手趣,y代表標簽中符,\theta是模型參數(shù)档插,L(x+y; \theta)是單個樣本的loss,\Delta x是擾動轴咱,\Omega是擾動空間。這個式子可以分布理解如下:

  1. 內(nèi)部max是指往x中添加擾動\Delta x跳昼,\Delta x的目的是讓L(x+\Delta x, y ; \theta)越大越好,也就是說盡可能讓現(xiàn)有模型預(yù)測出錯。但是,\Delta x也是有約束的恼除,要在\Omega范圍內(nèi). 常規(guī)的約束是|| \Delta x|| \leq ε,其中ε是一個常數(shù)
  2. 外部min是指找到最魯棒的參數(shù)\theta是預(yù)測的分布符合原數(shù)據(jù)集的分布

這就解決了兩個問題:如何構(gòu)建足夠強的對抗樣本旷痕、和如何使得分布仍然盡可能接近原始分布

1.4 NLP領(lǐng)域的對抗訓(xùn)練

對于CV領(lǐng)域报强,圖像被認為是連續(xù)的哮缺,因此可以直接在原始圖像上添加擾動直撤;而對于NLP悄雅,它的輸入是文本的本質(zhì)是one-hot,而one-hot之間的歐式距離恒為\sqrt{2}习蓬,理論上不存在微小的擾動笋婿,而且遇革,在Embedding向量上加上微小擾動可能就找不到與之對應(yīng)的詞了昂勒,不是真正意義上的對抗樣本难捌,因為對抗樣本依舊能對應(yīng)一個合理的原始輸入徘公,既然不能對Embedding向量添加擾動,可以對Embedding層添加擾動竿音,使其產(chǎn)生更魯棒的Embedding向量

二. 對抗訓(xùn)練方法

2.1 FGM(Fast Gradient Method) ICLR2017

FGM是根據(jù)具體的梯度進行scale,得到更好的對抗樣本:
r_{adv}=εg/\|g\|_2\tag{4}
整個對抗訓(xùn)練的過程如下,偽代碼如下:

  1. 計算x的前向loss、反向傳播得到梯度
  2. 根據(jù)embedding矩陣的梯度計算出r猿涨,并加到當(dāng)前embedding上,相當(dāng)于x+r
  3. 計算x+r的前向loss藐石,反向傳播得到對抗的梯度,累加到(1)的梯度上
  4. 將embedding恢復(fù)為(1)時的值
  5. 根據(jù)(3)的梯度對參數(shù)進行更新
class FGM:
    def __init__(self, model: nn.Module, eps=1.):
        self.model = (model.module if hasattr(model, "module") else model)
        self.eps = eps
        self.backup = {}
    # only attack word embedding
    def attack(self, emb_name='word_embeddings'):
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name:
                self.backup[name] = param.data.clone()
                norm = torch.norm(param.grad)
                if norm and not torch.isnan(norm):
                    r_at = self.eps * param.grad / norm
                    param.data.add_(r_at)
    def restore(self, emb_name='word_embeddings'):
        for name, para in self.model.named_parameters():
            if para.requires_grad and emb_name in name:
                assert name in self.backup
                para.data = self.backup[name]
        self.backup = {}

2.2 FGSM (Fast Gradient Sign Method) ICLR2015

FGSM的全稱是Fast Gradient Sign Method. FGSM和FGM的核心區(qū)別在計算擾動的方式不一樣于微,F(xiàn)GSM擾動的計算方式如下:
r_{adv}=ε \cdot \operatorname{sign}\left(\nabla_{x} L(x, y ; \theta)\right)\tag{5}

def FGSM(image, epsilon, data_grad):
    """
    :param image: 需要攻擊的圖像
    :param epsilon: 擾動值的范圍
    :param data_grad: 圖像的梯度
    :return: 擾動后的圖像
    """
    # 收集數(shù)據(jù)梯度的元素符號
    sign_data_grad = data_grad.sign()
    # 通過調(diào)整輸入圖像的每個像素來創(chuàng)建擾動圖像
    perturbed_image = image + epsilon*sign_data_grad
    # 添加剪切以維持[0,1]范圍
    perturbed_image = torch.clamp(perturbed_image, 0, 1)
    # 返回被擾動的圖像
    return perturbed_image

2.3 PGD(Projected Gradient Descent)

FGM直接通過epsilon參數(shù)算出了對抗擾動逗嫡,這樣得到的可能不是最優(yōu)的。因此PGD進行了改進株依,通過迭代慢慢找到最優(yōu)的擾動

r_{adv|t+1}=\alpha g_t/\|g_t\|_2\tag{6}

并且\|r\|_2≤ε

PGD整個對抗訓(xùn)練的過程如下

  1. 計算x的前向loss驱证、反向傳播得到梯度并備份

  2. 對于每步t:

    1. 根據(jù)embedding矩陣的梯度計算出r,并加到當(dāng)前embedding上勺三,相當(dāng)于x+r(超出范圍則投影回epsilon內(nèi))
    2. if t不是最后一步: 將梯度歸0雷滚,根據(jù)(1)x+r計算前后向并得到梯度
    3. if t是最后一步: 恢復(fù)(1)的梯度,計算最后的x+r并將梯度累加到(1)
  3. 將embedding恢復(fù)為(1)時的值

  4. 根據(jù)(5)的梯度對參數(shù)進行更新

在循環(huán)中r是逐漸累加的吗坚,要注意的是最后更新參數(shù)只使用最后一個x+r算出來的梯度

class PGD():
    def __init__(self, model):
        self.model = model
        self.emb_backup = {}
        self.grad_backup = {}
    def attack(self, epsilon=1., alpha=0.3, emb_name='emb.', is_first_attack=False):
        # emb_name這個參數(shù)要換成你模型中embedding的參數(shù)名
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name:
                if is_first_attack:
                    self.emb_backup[name] = param.data.clone()
                norm = torch.norm(param.grad)
                if norm != 0 and not torch.isnan(norm):
                    r_at = alpha * param.grad / norm
                    param.data.add_(r_at)
                    param.data = self.project(name, param.data, epsilon)
    def restore(self, emb_name='emb.'):
        # emb_name這個參數(shù)要換成你模型中embedding的參數(shù)名
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name: 
                assert name in self.emb_backup
                param.data = self.emb_backup[name]
        self.emb_backup = {}
    def project(self, param_name, param_data, epsilon):
        r = param_data - self.emb_backup[param_name]
        if torch.norm(r) > epsilon:
            r = epsilon * r / torch.norm(r)
        return self.emb_backup[param_name] + r
    def backup_grad(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.grad_backup[name] = param.grad.clone()
    def restore_grad(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                param.grad = self.grad_backup[name]

2.4 FreeAT(Free Adversarial Training)

從FGSM到PGD祈远,主要是優(yōu)化對抗擾動的計算呆万,雖然取得了更好的效果,但計算量也一步步增加车份。對于每個樣本谋减,F(xiàn)GSM和FGM都只用計算兩次,一次是計算x的前后向扫沼,一次是計算x+r的前后向出爹。而PGD則計算了K+1次,消耗了更多的計算資源缎除。因此FreeAT被提了出來严就,在PGD的基礎(chǔ)上進行訓(xùn)練速度的優(yōu)化

FreeAT的思想是在對每個樣本x連續(xù)重復(fù)m次訓(xùn)練,計算r時復(fù)用上一步的梯度器罐,為了保證速度梢为,整體epoch會除以mr的更新公式為:
r_{t+1}=r_t+ε \cdot sign(g)\tag{7}
FreeAT的訓(xùn)練過程如下:

  1. 初始化r=0
  2. 對于epoch=1...N/m:
    1. 對于每個x:
      1. 對于每步m:
        1. 利用上一步的r轰坊,計算x+r的前后向铸董,得到梯度
        2. 根據(jù)梯度更新參數(shù)
        3. 根據(jù)梯度更新r

FreeAT的問題在于每次的r對于當(dāng)前的參數(shù)都是次優(yōu)的(無法最大化loss),因為當(dāng)前r是由r_{t-1}\theta_{t-1}計算出來的肴沫,是對于\theta_{t-1}的最優(yōu)

2.5 YOPO(You Only Propagate Once)

YOPO的出發(fā)點是利用神經(jīng)網(wǎng)絡(luò)的結(jié)構(gòu)來降低梯度計算的計算量粟害。從極大值原理PMP(Pontryagin’s maximum principle)出發(fā),對抗擾動只和網(wǎng)絡(luò)的第0層有關(guān)颤芬,即在embedding層上添加擾動悲幅。再加之層之間是解耦合的,那就不需要每次都計算完整的前后向傳播

基于這個想法站蝠,復(fù)用后面幾層的梯度夺艰,減少非必要的完整傳播〕烈拢可以將PGDr次攻擊拆成m\times n次:
p=\nabla_{g_\tilde \theta}(l(g_\tilde \theta(f_0(x_i+r_i^{j,0},\theta_0)),y_i))\cdot \nabla_{f_0}(g_\tilde \theta(f_0(x_i+r_i^{j,0},\theta_0)))\tag{8}
則對r的更新就可以變?yōu)?
r_i^{j,s+1}=r_i^{j,s}+\alpha_1\cdot\nabla_{r_i}f_0(x_i+r_i^{j,s},\theta_0)\tag{9}

其算法流程為:

對于每個樣本x,初始化r(1,0)减牺,對于j=1,2,…,m:

  1. 根據(jù)r(j,0),計算p對于s=0,1,…,n-1:
  2. 計算r(j,s+1)
  3. r(j+1,0)=r(j,n)

2.6 FreeLB (Free Large-Batch)

YOPO的假設(shè)對于ReLU-based網(wǎng)絡(luò)來說是不成立的豌习,因為YOPO要求損失是兩次可微的,于是拔疚,F(xiàn)reeLB在FreeAT的基礎(chǔ)上將每次inner-max中更新模型參數(shù)這一操作換掉肥隆,利用K步之后累積的參數(shù)梯度進行更新,于是總體任務(wù)的目標函數(shù)就記為:
\underset{\theta}{min}\mathbb E_{(Z,y)\sim \mathcal D}\left[\frac{1}{K}\sum_{t=0}^{K-1}\underset{\delta_t\in\mathcal I_t}{max}\ L(f_\theta(X+\delta_t),y)\right]\\ \mathcal I_t=\mathcal B_{X+\delta_0}(\alpha t)\cap\mathcal B_X(\epsilon)\tag{10}
X+\delta_t可以看成兩個球形鄰域的交上局部最大的近似稚失。同時栋艳,通過累積參數(shù)梯度的操作,可以看作是輸入了[X+\delta_0,\cdots,X+\delta_{K-1}]這樣一個虛擬的K倍大小的batch句各。其中input subwords的one-hot representations記為Z吸占,embedding matrix記為V晴叨,subwords embedding記為X = V Z

依據(jù)下面算法中的數(shù)學(xué)符號,PGD需要進行N_{ep}\cdot(K+1)次梯度計算矾屯,F(xiàn)reeAT需要進行N_{ep}次兼蕊,F(xiàn)reeLB需要N_{ep}\cdot K次。雖然FreeLB在效率上并沒有特別大的優(yōu)勢件蚕,但是其效果十分不錯

另外孙技,論文中指出對抗訓(xùn)練和dropout不能同時使用,加上dropout相當(dāng)于改變了網(wǎng)絡(luò)的結(jié)果排作,影響擾動的計算牵啦。如果一定要加入dropout操作,需要在K步中都使用同一個mask

2.7 SMART(SMoothness-inducing Adversarial Regularization)

SMART放棄了Min-Max公式妄痪,選擇通過正則項Smoothness-inducing Adversarial Regularization完成對抗學(xué)習(xí)哈雏。為了解決這個新的目標函數(shù)作者又提出了優(yōu)化算法Bregman Proximal Point Optimization,這就是SMART的兩個主要內(nèi)容

SMART的主要想法是強制模型在neighboring data points上作出相似的預(yù)測拌夏,加入正則項后的目標函數(shù)如下所示:
\underset{\theta}{min}\ \mathcal F(\theta)=\mathcal L(\theta)+\lambda_s\mathcal R_s(\theta))\\ \mathcal L(\theta)=\frac{1}{n}\sum_{i=1}^{n}\ell\left(f(x_i;\theta),y_i\right)\\ \mathcal R_s(\theta)=\frac{1}{n}\sum_{i=1}^{n}\underset{||\tilde x_i-x_i||_p\leq\epsilon}{max}\ \ell_s\left[f(\tilde x_i;\theta),f(x_i;\theta)\right]\tag{11}

\ell是具體任務(wù)的損失函數(shù)僧著,\tilde x_i是generated neighbors of training points,\ell_s在分類任務(wù)中使用對稱的KL散度障簿,即\ell_s(P,Q)=\mathcal D_{KL}(P||Q)+D_{KL}(Q||L)盹愚;在回歸任務(wù)中使用平方損失,\ell_s(p,q)=(p-q)^2此時可以看到對抗發(fā)生在正則化項上站故,對抗的目標是最大擾動前后的輸出

Bregman Proximal Point Optimization也可以看作是一個正則項皆怕,防止更新的時候\theta_{t+1}和前面的\theta_t變化過大讨惩。在第t+1次迭代時对竣,采用vanilla Bregman proximal point (VBPP) method
\theta_{t+1}=argmin_{\theta}\mathcal F(\theta)+\mu\mathcal D_{Breg}(\theta,\theta_t)\tag{12}
其中\mathcal D_{Breg}表示Bregman divergence定義為:
\mathcal D_{Breg}(\theta,\theta_t)=\frac{1}{n}\sum_{i=1}^n\ell_s\left(f(x_i;\theta),f(x_i;\theta_t)\right)\tag{13}
\ell_s是上面給出的對稱KL散度

使用動量來加速VBPP,此時定義\beta為動量堕虹,記\tilde\theta=(1-\beta)\theta_t+\beta\tilde\theta_{t-1}表示指數(shù)移動平均岂津,那么momentum Bregman proximal point (MBPP) method就可以表示為:
\theta_{t+1}=argmin_{\theta}\mathcal F(\theta)+\mu\mathcal D_{Breg}(\theta,\tilde\theta_t)\tag{14}
下面是SMART的完整算法流程:

  1. 對于t輪迭代:
    1. 備份theta虱黄,作為Bregman divergence計算的\theta_t
    2. 對于每一個batch
      1. 使用正態(tài)分布隨機初始化擾動,結(jié)合x得到x\_tilde
      2. 循環(huán)m小步:計
        1. 算擾動下的梯度g\_tilde
        2. 基于g\_tilde和學(xué)習(xí)率更新x\_tilde
      3. 基于x\_tilde重新計算梯度吮成,更新參數(shù)\theta
    3. 更新\theta_t

三. Reference

  1. Madry A, Makelov A, Schmidt L, et al. Towards deep learning models resistant to adversarial attacks[J]. arXiv preprint arXiv:1706.06083, 2017.

  2. Goodfellow I J, Shlens J, Szegedy C. Explaining and harnessing adversarial examples[J]. arXiv preprint arXiv:1412.6572, 2014.

  3. Miyato T, Dai A M, Goodfellow I. Adversarial training methods for semi-supervised text classification[J]. arXiv preprint arXiv:1605.07725, 2016.

  4. Shafahi A, Najibi M, Ghiasi A, et al. Adversarial training for free![J]. arXiv preprint arXiv:1904.12843, 2019.

  5. Zhang D, Zhang T, Lu Y, et al. You only propagate once: Accelerating adversarial training via maximal principle[J]. arXiv preprint arXiv:1905.00877, 2019.

  6. Zhu C, Cheng Y, Gan Z, et al. Freelb: Enhanced adversarial training for natural language understanding[J]. arXiv preprint arXiv:1909.11764, 2019.

  7. Jiang H, He P, Chen W, et al. Smart: Robust and efficient fine-tuning for pre-trained natural language models through principled regularized optimization[J]. arXiv preprint arXiv:1911.03437, 2019.

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末橱乱,一起剝皮案震驚了整個濱河市,隨后出現(xiàn)的幾起案子粱甫,更是在濱河造成了極大的恐慌泳叠,老刑警劉巖,帶你破解...
    沈念sama閱讀 222,183評論 6 516
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件茶宵,死亡現(xiàn)場離奇詭異危纫,居然都是意外死亡,警方通過查閱死者的電腦和手機,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 94,850評論 3 399
  • 文/潘曉璐 我一進店門种蝶,熙熙樓的掌柜王于貴愁眉苦臉地迎上來契耿,“玉大人,你說我怎么就攤上這事蛤吓∠梗” “怎么了?”我有些...
    開封第一講書人閱讀 168,766評論 0 361
  • 文/不壞的土叔 我叫張陵会傲,是天一觀的道長锅棕。 經(jīng)常有香客問我淌山,道長,這世上最難降的妖魔是什么泼疑? 我笑而不...
    開封第一講書人閱讀 59,854評論 1 299
  • 正文 為了忘掉前任,我火速辦了婚禮退渗,結(jié)果婚禮上移稳,老公的妹妹穿的比我還像新娘。我一直安慰自己会油,他們只是感情好,可當(dāng)我...
    茶點故事閱讀 68,871評論 6 398
  • 文/花漫 我一把揭開白布翻翩。 她就那樣靜靜地躺著,像睡著了一般嫂冻。 火紅的嫁衣襯著肌膚如雪胶征。 梳的紋絲不亂的頭發(fā)上,一...
    開封第一講書人閱讀 52,457評論 1 311
  • 那天桨仿,我揣著相機與錄音睛低,去河邊找鬼服傍。 笑死,一個胖子當(dāng)著我的面吹牛伴嗡,可吹牛的內(nèi)容都是我干的从铲。 我是一名探鬼主播,決...
    沈念sama閱讀 40,999評論 3 422
  • 文/蒼蘭香墨 我猛地睜開眼阱扬,長吁一口氣:“原來是場噩夢啊……” “哼泣懊!你這毒婦竟也來了?” 一聲冷哼從身側(cè)響起馍刮,我...
    開封第一講書人閱讀 39,914評論 0 277
  • 序言:老撾萬榮一對情侶失蹤窃蹋,失蹤者是張志新(化名)和其女友劉穎,沒想到半個月后警没,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 46,465評論 1 319
  • 正文 獨居荒郊野嶺守林人離奇死亡亡脸,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 38,543評論 3 342
  • 正文 我和宋清朗相戀三年树酪,在試婚紗的時候發(fā)現(xiàn)自己被綠了浅碾。 大學(xué)時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片续语。...
    茶點故事閱讀 40,675評論 1 353
  • 序言:一個原本活蹦亂跳的男人離奇死亡垂谢,死狀恐怖绵载,靈堂內(nèi)的尸體忽然破棺而出,到底是詐尸還是另有隱情娃豹,我是刑警寧澤,帶...
    沈念sama閱讀 36,354評論 5 351
  • 正文 年R本政府宣布鹃栽,位于F島的核電站躯畴,受9級特大地震影響民鼓,放射性物質(zhì)發(fā)生泄漏蓬抄。R本人自食惡果不足惜丰嘉,卻給世界環(huán)境...
    茶點故事閱讀 42,029評論 3 335
  • 文/蒙蒙 一嚷缭、第九天 我趴在偏房一處隱蔽的房頂上張望耍贾。 院中可真熱鬧,春花似錦荐开、人聲如沸简肴。這莊子的主人今日做“春日...
    開封第一講書人閱讀 32,514評論 0 25
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至仍翰,卻和暖如春赫粥,著一層夾襖步出監(jiān)牢的瞬間予借,已是汗流浹背越平。 一陣腳步聲響...
    開封第一講書人閱讀 33,616評論 1 274
  • 我被黑心中介騙來泰國打工灵迫, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留,地道東北人挣跋。 一個月前我還...
    沈念sama閱讀 49,091評論 3 378
  • 正文 我出身青樓狞换,卻偏偏與公主長得像避咆,于是被迫代替她去往敵國和親修噪。 傳聞我的和親對象是個殘疾皇子查库,可洞房花燭夜當(dāng)晚...
    茶點故事閱讀 45,685評論 2 360

推薦閱讀更多精彩內(nèi)容