torch代碼運(yùn)行時(shí)顯存溢出問(wèn)題

在實(shí)驗(yàn)室參與開(kāi)發(fā)了一個(gè)評(píng)測(cè)平臺(tái)搬男。在使用的時(shí)候有部分攻擊算法會(huì)出現(xiàn)顯存溢出的情況钞它。同時(shí)隨著樣本數(shù)增加募谎,這種顯存占用比會(huì)同比增加扶关。而不是和預(yù)先設(shè)定的一樣,僅和設(shè)置的batch_size相關(guān)数冬。如此一來(lái)节槐,對(duì)于一些占用顯存本身就較大的算法搀庶,當(dāng)樣本數(shù)增加時(shí),Docker虛環(huán)境肯定會(huì)崩潰铜异。下面將給出我的整個(gè)搜索+解決問(wèn)題的過(guò)程哥倔。

寫(xiě)在最前面的話(huà)

這個(gè)問(wèn)題目前已經(jīng)解決,最終發(fā)現(xiàn)報(bào)錯(cuò)的原因是揍庄,開(kāi)發(fā)那邊沒(méi)有正確的把batch_size傳入攻擊算法中咆蒿,導(dǎo)致出現(xiàn)了可能只有1張圖像,但是開(kāi)了一個(gè)65倍圖像尺寸的空間(實(shí)際上3就夠了)蚂子,然后這個(gè)空間作為輸入傳入模型沃测,導(dǎo)致占用顯存過(guò)多。

在排查問(wèn)題的過(guò)程中食茎,馬佬告訴我蒂破,其實(shí)Pytorch之類(lèi)的都會(huì)有自動(dòng)回收機(jī)制,需要保證的其實(shí)是

for循環(huán)中的變量董瞻,如果是顯存上的寞蚌,盡量不要讓他離開(kāi)for循環(huán)范圍!
按照GC的原理钠糊,是引用計(jì)數(shù)的挟秤,當(dāng)某個(gè)局部變量不存在引用的時(shí)候,會(huì)自動(dòng)回收抄伍。因此如果for循環(huán)內(nèi)部/外部有引用艘刚,都會(huì)導(dǎo)致某些中間變量一直被持有。

舉個(gè)例子:

losses = []
for i in range(233):
    x = Variable(input).to(device)  此時(shí)x在GPU上
    output = self.model(x)          此時(shí)output也在GPU上
    losses.append(output)           這句話(huà)將可能導(dǎo)致存儲(chǔ)了output梯度截珍,并由于持有output對(duì)象導(dǎo)致他不會(huì)在每次for循環(huán)后釋放
y = x + ...         這句話(huà)在for循環(huán)外攀甚,等于for循環(huán)結(jié)束的時(shí)候,x仍存在未來(lái)的引用可能岗喉,此時(shí)的x不會(huì)被回收

可以修改的方式有很多秋度,比如在for循環(huán)內(nèi)部losses.append一句中,可以把output轉(zhuǎn)成cpu上資源钱床。以及將y = 這一句考慮能不能刪去荚斯。


下面是正文:首先列舉全部搜索到的問(wèn)題:

問(wèn)題一 記錄累計(jì)信息時(shí)直接使用了輸出的Variable

這個(gè)問(wèn)題的發(fā)現(xiàn),是參考了這篇知乎回答《pytorch的坑---loss沒(méi)寫(xiě)好查牌,顯存爆炸》
原貼就問(wèn)題的描述:

算是動(dòng)態(tài)圖的一個(gè)坑吧事期。記錄loss信息的時(shí)候直接使用了輸出的Variable。

for data, label in trainloader:
    out = model(data)
    loss = criterion(out, label)
    loss_sum += loss     # <--- 這里

運(yùn)行著就發(fā)現(xiàn)顯存炸了纸颜。觀察了一下發(fā)現(xiàn)隨著每個(gè)batch顯存消耗在不斷增大..
參考了別人的代碼發(fā)現(xiàn)那句loss一般是這樣寫(xiě):

loss_sum += loss.data[0]

這是因?yàn)檩敵龅膌oss的數(shù)據(jù)類(lèi)型是Variable兽泣。而PyTorch的動(dòng)態(tài)圖機(jī)制就是通過(guò)Variable來(lái)構(gòu)建圖。主要是使用Variable計(jì)算的時(shí)候胁孙,會(huì)記錄下新產(chǎn)生的Variable的運(yùn)算符號(hào)唠倦,在反向傳播求導(dǎo)的時(shí)候進(jìn)行使用称鳞。
如果這里直接將loss加起來(lái),系統(tǒng)會(huì)認(rèn)為這里也是計(jì)算圖的一部分牵敷,也就是說(shuō)網(wǎng)絡(luò)會(huì)一直延伸變大胡岔,那么消耗的顯存也就越來(lái)越大
總之使用Variable的數(shù)據(jù)時(shí)候要非常小心法希。不是必要的話(huà)盡量使用Tensor來(lái)進(jìn)行計(jì)算...

問(wèn)題二 for循環(huán)過(guò)程中的迭代變量

參考討論帖《Tensor to Variable and memory freeing best practices》
在這篇帖子中有提到枷餐,Variable和Tensor實(shí)際上共用的是一塊內(nèi)存空間。所以在使用了Variable之后苫亦,del掉相應(yīng)的Variable毛肋。不會(huì)帶來(lái)明顯的內(nèi)存釋放。唯一可能帶來(lái)一定效果的屋剑,是在for循環(huán)過(guò)程中润匙,如

for i, (x, y) in enumerate(train_loader):
    x = Variable(x)
    y = Variable(y)
    # compute model and update
    del x, y, output 

x和y本身作為train_loader中內(nèi)容,會(huì)占用一塊內(nèi)存唉匾,而循環(huán)時(shí)孕讳,會(huì)產(chǎn)生一塊臨時(shí)內(nèi)存。帖子中回復(fù)認(rèn)為巍膘,此處可以節(jié)省一點(diǎn)點(diǎn)厂财。需要注意的是,還需要額外刪去引用到x和y的變量峡懈,否則仍然存在占用璃饱。

問(wèn)題三 多次訓(xùn)練,GPU未釋放

參考自討論帖《How can we release GPU memory cache?》
這個(gè)帖子中描述的解決辦法為肪康,當(dāng)GPU計(jì)算完畢后荚恶,把相應(yīng)的變量和結(jié)果轉(zhuǎn)成CPU,然后調(diào)用GC磷支,調(diào)用torch.cuda.empty_cache()

def wipe_memory(self): # DOES WORK
    self._optimizer_to(torch.device('cpu'))
    del self.optimizer
    gc.collect()
    torch.cuda.empty_cache()

def _optimizer_to(self, device):
    for param in self.optimizer.state.values():
        # Not sure there are any global tensors in the state dict
        if isinstance(param, torch.Tensor):
            param.data = param.data.to(device)
            if param._grad is not None:
                param._grad.data = param._grad.data.to(device)
        elif isinstance(param, dict):
            for subparam in param.values():
                if isinstance(subparam, torch.Tensor):
                    subparam.data = subparam.data.to(device)
                    if subparam._grad is not None:
                        subparam._grad.data = subparam._grad.data.to(device)

問(wèn)題四 torch.load的坑

參考自知乎回答《PyTorch 有哪些坑/bug谒撼? - 知乎用戶(hù)的回答》
該回答中描述,當(dāng)你使用:

checkpoint = torch.load("checkpoint.pth")
model.load_state_dict(checkpoint["state_dict"])

這樣load一個(gè) pretrained model 的時(shí)候雾狈,torch.load() 會(huì)默認(rèn)把load進(jìn)來(lái)的數(shù)據(jù)放到0卡上廓潜,這樣每個(gè)進(jìn)程全部會(huì)在0卡占用一部分顯存。解決的方法也很簡(jiǎn)單箍邮,就是把load進(jìn)來(lái)的數(shù)據(jù)map到cpu上:

checkpoint = torch.load("checkpoint.pth", map_location=torch.device('cpu'))
model.load_state_dict(checkpoint["state_dict"])

按照馬佬的建議茉帅,此處如果不想用到cpu的話(huà),也可以map_location=rank锭弊。具體的寫(xiě)法參考了《pytorch源碼》以及《pytorch 分布式訓(xùn)練 distributed parallel 筆記》

    # 獲取GPU的rank號(hào)
    gpu = torch.distributed.get_rank(group=group)  # group是可選參數(shù)堪澎,返回int,執(zhí)行該腳本的進(jìn)程的rank
    # 獲取了進(jìn)程號(hào)后
    rank = 'cuda:{}'.format(gpu)
    checkpoint = torch.load(args.resume, map_location=rank)

問(wèn)題五 pretrain weights問(wèn)題

參考自之乎回答《PyTorch 有哪些坑/bug味滞? - 鯤China的回答》
在做交叉驗(yàn)證的時(shí)候樱蛤,每折初始化模型钮呀,由于用到了pretrained weights,這時(shí)候顯存不會(huì)被釋放昨凡,幾折過(guò)后顯存就爆炸了~爽醋,這時(shí)候用三行代碼就可以解決這個(gè)問(wèn)題

del model
gc.collect()
torch.cuda.empty_cache()

問(wèn)題六 不做backward,中間變量會(huì)保存

參考自《PyTorch 有哪些坑/bug便脊? - hjy666的回答》

但是上述方法是0.4中的解決方法蚂四。pytorch0.4到pytrch1.0跨度有點(diǎn)大,variable跟tensor合并成tensor了哪痰,不能設(shè)置volatile 參數(shù)遂赠,所以在做evaluation時(shí)很容易出現(xiàn)out of memory的問(wèn)題。所以你需要在最后的loss和predict輸出設(shè)置

.cpu().detach()

比如說(shuō):

total_loss.append(loss.cpu().detach().numpy())
total_finish_loss.append(finish_loss.cpu().detach().numpy())

嘗試解決問(wèn)題

方法一:全局查找字符串

全局查找累計(jì)過(guò)程晌杰,由于主要是+=的問(wèn)題跷睦,所以grep +=試試:

$ grep -rn "+=" ./

得到結(jié)果

zaozhe@ /d/LABOR/SUIBUAA_AIEP (dev_aiep)
$ grep -rn "+=" ./
Binary file ./Datasets/ImageNet/images/ILSVRC2012_val_00000005.JPEG matches
Binary file ./Datasets/ImageNet/images/ILSVRC2012_val_00000006.JPEG matches
Binary file ./Datasets/ImageNet/images/ILSVRC2012_val_00000007.JPEG matches
Binary file ./Datasets/ImageNet/images/ILSVRC2012_val_00000008.JPEG matches
./EvalBox/Defense/anp.py:100:                total += inputs.shape[0]
./EvalBox/Defense/anp.py:101:                correct += (preds == labels).sum().item()
./EvalBox/Defense/eat.py:223:                total += inputs.shape[0]
./EvalBox/Defense/eat.py:224:                correct += (preds == labels).sum().item()
...

但是搜索結(jié)果中存在很多的Binary file文件,把所有搜索結(jié)果拷貝到sublime中肋演,ctrl + F搜索全部包含"Binary file"字樣的搜索行抑诸,使用ctrl + shift + K一鍵刪除所有匹配行。

$ grep -rn "+=" ./
./EvalBox/Analysis/grad-cam.py:36:                outputs += [x]
./EvalBox/Analysis/grad-cam.py:134:            cam += w * target[i, :, :]
./EvalBox/Analysis/grand_CAM.py:32:                outputs += [x]
./EvalBox/Analysis/grand_CAM.py:109:            cam += w * target[i, :, :]
./EvalBox/Analysis/Rebust_Defense.py:66:                total += inputs.shape[0]
./EvalBox/Analysis/Rebust_Defense.py:67:                correct += (preds == labels).sum().item()
./EvalBox/Attack/AdvAttack/deepfool.py:105:                loop_i += 1
./EvalBox/Attack/AdvAttack/deepfool.py:137:            loop_i += 1
./EvalBox/Attack/AdvAttack/ead.py:208:                cnt += 1
...

然后再手動(dòng)篩選掉與該問(wèn)題無(wú)關(guān)的行爹殊,如上方示例中deepfool中的+=1蜕乡,這里并不會(huì)產(chǎn)生問(wèn)題一中,無(wú)用梯度不釋放問(wèn)題边灭。然后這里我很快就定位到了具體的py文件中异希,有這么一行

    output = model(xs)

方法二:確定輸入輸出尺寸

這一步很簡(jiǎn)單,就是在你覺(jué)得不妥的變量上绒瘦,輸出一下他的尺寸看看

print("in line xxx, the var xs 's shape = ", xs.shape)

加一些提示語(yǔ)称簿,然后看看會(huì)不會(huì)是傳入的圖像太大了。

我遇到的實(shí)際問(wèn)題就是因?yàn)槎杳保羲惴▓?zhí)行過(guò)程中憨降,用于做擾動(dòng)處理的預(yù)空間維度太高。按照馬佬的測(cè)試该酗,1張3 * 244 * 244的ImageNet圖像授药,在VGG模型上執(zhí)行預(yù)測(cè),約占用顯存1.6G呜魄。而我傳入的是80 * 3 * 375 * 500的輸入悔叽,所以顯存爆炸。改為3 * 3 * 375 * 500之后爵嗅,顯存就可以正常供給了娇澎。

方法三:如何查看實(shí)時(shí)的GPU使用率

這個(gè)也是debug過(guò)程中很苦惱的東西,想知道是不是在某一步的時(shí)候睹晒,傳到顯存上的東西太多了趟庄,但是又不方便單步調(diào)試括细。

使用指令nvidia-smi可以看到當(dāng)前的GPU使用率,大致如圖:

但是我想要的是在執(zhí)行過(guò)程中戚啥,執(zhí)行的同時(shí)奋单,獲取具體的GPU使用情況。這里我參考了這篇博客《使用python中的GPUtil庫(kù)從NVIDA GPU獲取GPU狀態(tài)》
這里面用到了一個(gè)第三方庫(kù)叫GPUtil猫十,執(zhí)行pip install gputil即可完成下載览濒。然后我封裝了一個(gè)函數(shù):

   def get_gpu_info(self, text = ""):
      print("當(dāng)前行為為:", text)
      GPUtil.showUtilization()
   def predict(self, xs, model):
       var_xs = Variable(xs.to(device))
       self.get_gpu_info("將xs傳入GPU")
       for i in range(100):
           for j in range(200):
               output = model(var_xs)
               self.get_gpu_info("執(zhí)行一次預(yù)測(cè)過(guò)程")
               some work there ...
           self.get_gpu_info("內(nèi)層循環(huán)迭代完畢,查看是否正確釋放顯存")

而這個(gè)的輸出類(lèi)似于下圖炫彩。最好還是添加一個(gè)輸出提示匾七,因?yàn)樗绻麤](méi)有提示做分割的話(huà)絮短,其實(shí)不是很方便看到底執(zhí)行到哪里了江兢。


在這篇參考博客中,我看到有這么一段代碼

import GPUtil
import time
while True:
    Gpus = GPUtil.getGPUs()
    for gpu in Gpus:
        print('GPU總量', gpu.memoryTotal)
        print('GPU使用量', gpu.memortUsed)
    time.sleep(5)

他這里的意思是不停的輸出GPU的總量和使用量丁频。但是我實(shí)際使用過(guò)程中發(fā)現(xiàn)杉允,好像并不是非常的好用,具體情況見(jiàn)下圖席里。


可以看到右邊是我的一個(gè)實(shí)測(cè)結(jié)果叔磷,雖然我中間改變了GPU的使用情況,但是輸出的值基本沒(méi)變奖磁。不知道是更新不夠快還是如何改基。我一開(kāi)始以為是更新不夠快,但是我發(fā)現(xiàn)哪怕程序一開(kāi)始咖为,他都可能會(huì)顯示已經(jīng)占用了部分的顯存資源秕狰。所以我就改用了上面的那個(gè)GPUtil.showUtilization()

方法四:找大腿問(wèn)問(wèn)

如果debug實(shí)在是太難了,也不要一門(mén)心思去找自己的問(wèn)題躁染,找個(gè)小伙伴問(wèn)問(wèn)鸣哀。描述給他人的同時(shí)你也會(huì)更了解問(wèn)題所在,而且有可能對(duì)方一語(yǔ)道破吞彤!


最后放個(gè)圖紀(jì)念一下這篇博客的誕生:


最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末我衬,一起剝皮案震驚了整個(gè)濱河市,隨后出現(xiàn)的幾起案子饰恕,更是在濱河造成了極大的恐慌挠羔,老刑警劉巖,帶你破解...
    沈念sama閱讀 222,252評(píng)論 6 516
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件埋嵌,死亡現(xiàn)場(chǎng)離奇詭異破加,居然都是意外死亡,警方通過(guò)查閱死者的電腦和手機(jī)莉恼,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 94,886評(píng)論 3 399
  • 文/潘曉璐 我一進(jìn)店門(mén)拌喉,熙熙樓的掌柜王于貴愁眉苦臉地迎上來(lái)速那,“玉大人,你說(shuō)我怎么就攤上這事尿背《搜觯” “怎么了?”我有些...
    開(kāi)封第一講書(shū)人閱讀 168,814評(píng)論 0 361
  • 文/不壞的土叔 我叫張陵田藐,是天一觀的道長(zhǎng)荔烧。 經(jīng)常有香客問(wèn)我,道長(zhǎng)汽久,這世上最難降的妖魔是什么鹤竭? 我笑而不...
    開(kāi)封第一講書(shū)人閱讀 59,869評(píng)論 1 299
  • 正文 為了忘掉前任,我火速辦了婚禮景醇,結(jié)果婚禮上臀稚,老公的妹妹穿的比我還像新娘。我一直安慰自己三痰,他們只是感情好吧寺,可當(dāng)我...
    茶點(diǎn)故事閱讀 68,888評(píng)論 6 398
  • 文/花漫 我一把揭開(kāi)白布。 她就那樣靜靜地躺著散劫,像睡著了一般稚机。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發(fā)上获搏,一...
    開(kāi)封第一講書(shū)人閱讀 52,475評(píng)論 1 312
  • 那天赖条,我揣著相機(jī)與錄音,去河邊找鬼常熙。 笑死纬乍,一個(gè)胖子當(dāng)著我的面吹牛,可吹牛的內(nèi)容都是我干的症概。 我是一名探鬼主播蕾额,決...
    沈念sama閱讀 41,010評(píng)論 3 422
  • 文/蒼蘭香墨 我猛地睜開(kāi)眼,長(zhǎng)吁一口氣:“原來(lái)是場(chǎng)噩夢(mèng)啊……” “哼彼城!你這毒婦竟也來(lái)了诅蝶?” 一聲冷哼從身側(cè)響起,我...
    開(kāi)封第一講書(shū)人閱讀 39,924評(píng)論 0 277
  • 序言:老撾萬(wàn)榮一對(duì)情侶失蹤募壕,失蹤者是張志新(化名)和其女友劉穎调炬,沒(méi)想到半個(gè)月后,有當(dāng)?shù)厝嗽跇?shù)林里發(fā)現(xiàn)了一具尸體舱馅,經(jīng)...
    沈念sama閱讀 46,469評(píng)論 1 319
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡缰泡,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 38,552評(píng)論 3 342
  • 正文 我和宋清朗相戀三年,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片棘钞。...
    茶點(diǎn)故事閱讀 40,680評(píng)論 1 353
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡缠借,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出宜猜,到底是詐尸還是另有隱情泼返,我是刑警寧澤,帶...
    沈念sama閱讀 36,362評(píng)論 5 351
  • 正文 年R本政府宣布姨拥,位于F島的核電站绅喉,受9級(jí)特大地震影響,放射性物質(zhì)發(fā)生泄漏叫乌。R本人自食惡果不足惜柴罐,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 42,037評(píng)論 3 335
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望憨奸。 院中可真熱鬧革屠,春花似錦、人聲如沸膀藐。這莊子的主人今日做“春日...
    開(kāi)封第一講書(shū)人閱讀 32,519評(píng)論 0 25
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽(yáng)额各。三九已至,卻和暖如春吧恃,著一層夾襖步出監(jiān)牢的瞬間虾啦,已是汗流浹背。 一陣腳步聲響...
    開(kāi)封第一講書(shū)人閱讀 33,621評(píng)論 1 274
  • 我被黑心中介騙來(lái)泰國(guó)打工痕寓, 沒(méi)想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留傲醉,地道東北人。 一個(gè)月前我還...
    沈念sama閱讀 49,099評(píng)論 3 378
  • 正文 我出身青樓呻率,卻偏偏與公主長(zhǎng)得像硬毕,于是被迫代替她去往敵國(guó)和親。 傳聞我的和親對(duì)象是個(gè)殘疾皇子礼仗,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 45,691評(píng)論 2 361