在實(shí)驗(yàn)室參與開(kāi)發(fā)了一個(gè)評(píng)測(cè)平臺(tái)搬男。在使用的時(shí)候有部分攻擊算法會(huì)出現(xiàn)顯存溢出的情況钞它。同時(shí)隨著樣本數(shù)增加募谎,這種顯存占用比會(huì)同比增加扶关。而不是和預(yù)先設(shè)定的一樣,僅和設(shè)置的batch_size相關(guān)数冬。如此一來(lái)节槐,對(duì)于一些占用顯存本身就較大的算法搀庶,當(dāng)樣本數(shù)增加時(shí),Docker虛環(huán)境肯定會(huì)崩潰铜异。下面將給出我的整個(gè)搜索+解決問(wèn)題的過(guò)程哥倔。
寫(xiě)在最前面的話(huà)
這個(gè)問(wèn)題目前已經(jīng)解決,最終發(fā)現(xiàn)報(bào)錯(cuò)的原因是揍庄,開(kāi)發(fā)那邊沒(méi)有正確的把batch_size傳入攻擊算法中咆蒿,導(dǎo)致出現(xiàn)了可能只有1張圖像,但是開(kāi)了一個(gè)65倍圖像尺寸的空間(實(shí)際上3就夠了)蚂子,然后這個(gè)空間作為輸入傳入模型沃测,導(dǎo)致占用顯存過(guò)多。
在排查問(wèn)題的過(guò)程中食茎,馬佬告訴我蒂破,其實(shí)Pytorch之類(lèi)的都會(huì)有自動(dòng)回收機(jī)制,需要保證的其實(shí)是
for循環(huán)中的變量董瞻,如果是顯存上的寞蚌,盡量不要讓他離開(kāi)for循環(huán)范圍!
按照GC的原理钠糊,是引用計(jì)數(shù)的挟秤,當(dāng)某個(gè)局部變量不存在引用的時(shí)候,會(huì)自動(dòng)回收抄伍。因此如果for循環(huán)內(nèi)部/外部有引用艘刚,都會(huì)導(dǎo)致某些中間變量一直被持有。
舉個(gè)例子:
losses = []
for i in range(233):
x = Variable(input).to(device) 此時(shí)x在GPU上
output = self.model(x) 此時(shí)output也在GPU上
losses.append(output) 這句話(huà)將可能導(dǎo)致存儲(chǔ)了output梯度截珍,并由于持有output對(duì)象導(dǎo)致他不會(huì)在每次for循環(huán)后釋放
y = x + ... 這句話(huà)在for循環(huán)外攀甚,等于for循環(huán)結(jié)束的時(shí)候,x仍存在未來(lái)的引用可能岗喉,此時(shí)的x不會(huì)被回收
可以修改的方式有很多秋度,比如在for循環(huán)內(nèi)部losses.append一句中,可以把output轉(zhuǎn)成cpu上資源钱床。以及將y = 這一句考慮能不能刪去荚斯。
下面是正文:首先列舉全部搜索到的問(wèn)題:
問(wèn)題一 記錄累計(jì)信息時(shí)直接使用了輸出的Variable
這個(gè)問(wèn)題的發(fā)現(xiàn),是參考了這篇知乎回答《pytorch的坑---loss沒(méi)寫(xiě)好查牌,顯存爆炸》
原貼就問(wèn)題的描述:
算是動(dòng)態(tài)圖的一個(gè)坑吧事期。記錄loss信息的時(shí)候直接使用了輸出的Variable。
for data, label in trainloader: out = model(data) loss = criterion(out, label) loss_sum += loss # <--- 這里
運(yùn)行著就發(fā)現(xiàn)顯存炸了纸颜。觀察了一下發(fā)現(xiàn)隨著每個(gè)batch顯存消耗在不斷增大..
參考了別人的代碼發(fā)現(xiàn)那句loss一般是這樣寫(xiě):loss_sum += loss.data[0]
這是因?yàn)檩敵龅膌oss的數(shù)據(jù)類(lèi)型是Variable兽泣。而PyTorch的動(dòng)態(tài)圖機(jī)制就是通過(guò)Variable來(lái)構(gòu)建圖。主要是使用Variable計(jì)算的時(shí)候胁孙,會(huì)記錄下新產(chǎn)生的Variable的運(yùn)算符號(hào)唠倦,在反向傳播求導(dǎo)的時(shí)候進(jìn)行使用称鳞。
如果這里直接將loss加起來(lái),系統(tǒng)會(huì)認(rèn)為這里也是計(jì)算圖的一部分牵敷,也就是說(shuō)網(wǎng)絡(luò)會(huì)一直延伸變大胡岔,那么消耗的顯存也就越來(lái)越大
總之使用Variable的數(shù)據(jù)時(shí)候要非常小心法希。不是必要的話(huà)盡量使用Tensor來(lái)進(jìn)行計(jì)算...
問(wèn)題二 for循環(huán)過(guò)程中的迭代變量
參考討論帖《Tensor to Variable and memory freeing best practices》
在這篇帖子中有提到枷餐,Variable和Tensor實(shí)際上共用的是一塊內(nèi)存空間。所以在使用了Variable之后苫亦,del掉相應(yīng)的Variable毛肋。不會(huì)帶來(lái)明顯的內(nèi)存釋放。唯一可能帶來(lái)一定效果的屋剑,是在for循環(huán)過(guò)程中润匙,如
for i, (x, y) in enumerate(train_loader):
x = Variable(x)
y = Variable(y)
# compute model and update
del x, y, output
x和y本身作為train_loader中內(nèi)容,會(huì)占用一塊內(nèi)存唉匾,而循環(huán)時(shí)孕讳,會(huì)產(chǎn)生一塊臨時(shí)內(nèi)存。帖子中回復(fù)認(rèn)為巍膘,此處可以節(jié)省一點(diǎn)點(diǎn)厂财。需要注意的是,還需要額外刪去引用到x和y的變量峡懈,否則仍然存在占用璃饱。
問(wèn)題三 多次訓(xùn)練,GPU未釋放
參考自討論帖《How can we release GPU memory cache?》
這個(gè)帖子中描述的解決辦法為肪康,當(dāng)GPU計(jì)算完畢后荚恶,把相應(yīng)的變量和結(jié)果轉(zhuǎn)成CPU,然后調(diào)用GC磷支,調(diào)用torch.cuda.empty_cache()
def wipe_memory(self): # DOES WORK
self._optimizer_to(torch.device('cpu'))
del self.optimizer
gc.collect()
torch.cuda.empty_cache()
def _optimizer_to(self, device):
for param in self.optimizer.state.values():
# Not sure there are any global tensors in the state dict
if isinstance(param, torch.Tensor):
param.data = param.data.to(device)
if param._grad is not None:
param._grad.data = param._grad.data.to(device)
elif isinstance(param, dict):
for subparam in param.values():
if isinstance(subparam, torch.Tensor):
subparam.data = subparam.data.to(device)
if subparam._grad is not None:
subparam._grad.data = subparam._grad.data.to(device)
問(wèn)題四 torch.load的坑
參考自知乎回答《PyTorch 有哪些坑/bug谒撼? - 知乎用戶(hù)的回答》
該回答中描述,當(dāng)你使用:
checkpoint = torch.load("checkpoint.pth")
model.load_state_dict(checkpoint["state_dict"])
這樣load一個(gè) pretrained model 的時(shí)候雾狈,torch.load() 會(huì)默認(rèn)把load進(jìn)來(lái)的數(shù)據(jù)放到0卡上廓潜,這樣每個(gè)進(jìn)程全部會(huì)在0卡占用一部分顯存。解決的方法也很簡(jiǎn)單箍邮,就是把load進(jìn)來(lái)的數(shù)據(jù)map到cpu上:
checkpoint = torch.load("checkpoint.pth", map_location=torch.device('cpu'))
model.load_state_dict(checkpoint["state_dict"])
按照馬佬的建議茉帅,此處如果不想用到cpu的話(huà),也可以map_location=rank锭弊。具體的寫(xiě)法參考了《pytorch源碼》以及《pytorch 分布式訓(xùn)練 distributed parallel 筆記》
# 獲取GPU的rank號(hào)
gpu = torch.distributed.get_rank(group=group) # group是可選參數(shù)堪澎,返回int,執(zhí)行該腳本的進(jìn)程的rank
# 獲取了進(jìn)程號(hào)后
rank = 'cuda:{}'.format(gpu)
checkpoint = torch.load(args.resume, map_location=rank)
問(wèn)題五 pretrain weights問(wèn)題
參考自之乎回答《PyTorch 有哪些坑/bug味滞? - 鯤China的回答》
在做交叉驗(yàn)證的時(shí)候樱蛤,每折初始化模型钮呀,由于用到了pretrained weights,這時(shí)候顯存不會(huì)被釋放昨凡,幾折過(guò)后顯存就爆炸了~爽醋,這時(shí)候用三行代碼就可以解決這個(gè)問(wèn)題
del model
gc.collect()
torch.cuda.empty_cache()
問(wèn)題六 不做backward,中間變量會(huì)保存
參考自《PyTorch 有哪些坑/bug便脊? - hjy666的回答》
但是上述方法是0.4中的解決方法蚂四。pytorch0.4到pytrch1.0跨度有點(diǎn)大,variable跟tensor合并成tensor了哪痰,不能設(shè)置volatile 參數(shù)遂赠,所以在做evaluation時(shí)很容易出現(xiàn)out of memory的問(wèn)題。所以你需要在最后的loss和predict輸出設(shè)置
.cpu().detach()
比如說(shuō):
total_loss.append(loss.cpu().detach().numpy())
total_finish_loss.append(finish_loss.cpu().detach().numpy())
嘗試解決問(wèn)題
方法一:全局查找字符串
全局查找累計(jì)過(guò)程晌杰,由于主要是+=
的問(wèn)題跷睦,所以grep +=
試試:
$ grep -rn "+=" ./
得到結(jié)果
zaozhe@ /d/LABOR/SUIBUAA_AIEP (dev_aiep)
$ grep -rn "+=" ./
Binary file ./Datasets/ImageNet/images/ILSVRC2012_val_00000005.JPEG matches
Binary file ./Datasets/ImageNet/images/ILSVRC2012_val_00000006.JPEG matches
Binary file ./Datasets/ImageNet/images/ILSVRC2012_val_00000007.JPEG matches
Binary file ./Datasets/ImageNet/images/ILSVRC2012_val_00000008.JPEG matches
./EvalBox/Defense/anp.py:100: total += inputs.shape[0]
./EvalBox/Defense/anp.py:101: correct += (preds == labels).sum().item()
./EvalBox/Defense/eat.py:223: total += inputs.shape[0]
./EvalBox/Defense/eat.py:224: correct += (preds == labels).sum().item()
...
但是搜索結(jié)果中存在很多的Binary file文件,把所有搜索結(jié)果拷貝到sublime中肋演,ctrl + F
搜索全部包含"Binary file"字樣的搜索行抑诸,使用ctrl + shift + K
一鍵刪除所有匹配行。
$ grep -rn "+=" ./
./EvalBox/Analysis/grad-cam.py:36: outputs += [x]
./EvalBox/Analysis/grad-cam.py:134: cam += w * target[i, :, :]
./EvalBox/Analysis/grand_CAM.py:32: outputs += [x]
./EvalBox/Analysis/grand_CAM.py:109: cam += w * target[i, :, :]
./EvalBox/Analysis/Rebust_Defense.py:66: total += inputs.shape[0]
./EvalBox/Analysis/Rebust_Defense.py:67: correct += (preds == labels).sum().item()
./EvalBox/Attack/AdvAttack/deepfool.py:105: loop_i += 1
./EvalBox/Attack/AdvAttack/deepfool.py:137: loop_i += 1
./EvalBox/Attack/AdvAttack/ead.py:208: cnt += 1
...
然后再手動(dòng)篩選掉與該問(wèn)題無(wú)關(guān)的行爹殊,如上方示例中deepfool中的+=1蜕乡,這里并不會(huì)產(chǎn)生問(wèn)題一中,無(wú)用梯度不釋放問(wèn)題边灭。然后這里我很快就定位到了具體的py文件中异希,有這么一行
output = model(xs)
方法二:確定輸入輸出尺寸
這一步很簡(jiǎn)單,就是在你覺(jué)得不妥的變量上绒瘦,輸出一下他的尺寸看看
print("in line xxx, the var xs 's shape = ", xs.shape)
加一些提示語(yǔ)称簿,然后看看會(huì)不會(huì)是傳入的圖像太大了。
我遇到的實(shí)際問(wèn)題就是因?yàn)槎杳保羲惴▓?zhí)行過(guò)程中憨降,用于做擾動(dòng)處理的預(yù)空間維度太高。按照馬佬的測(cè)試该酗,1張3 * 244 * 244的ImageNet圖像授药,在VGG模型上執(zhí)行預(yù)測(cè),約占用顯存1.6G呜魄。而我傳入的是80 * 3 * 375 * 500的輸入悔叽,所以顯存爆炸。改為3 * 3 * 375 * 500之后爵嗅,顯存就可以正常供給了娇澎。
方法三:如何查看實(shí)時(shí)的GPU使用率
這個(gè)也是debug過(guò)程中很苦惱的東西,想知道是不是在某一步的時(shí)候睹晒,傳到顯存上的東西太多了趟庄,但是又不方便單步調(diào)試括细。
使用指令nvidia-smi
可以看到當(dāng)前的GPU使用率,大致如圖:
但是我想要的是在執(zhí)行過(guò)程中戚啥,執(zhí)行的同時(shí)奋单,獲取具體的GPU使用情況。這里我參考了這篇博客《使用python中的GPUtil庫(kù)從NVIDA GPU獲取GPU狀態(tài)》
這里面用到了一個(gè)第三方庫(kù)叫GPUtil猫十,執(zhí)行pip install gputil
即可完成下載览濒。然后我封裝了一個(gè)函數(shù):
def get_gpu_info(self, text = ""):
print("當(dāng)前行為為:", text)
GPUtil.showUtilization()
def predict(self, xs, model):
var_xs = Variable(xs.to(device))
self.get_gpu_info("將xs傳入GPU")
for i in range(100):
for j in range(200):
output = model(var_xs)
self.get_gpu_info("執(zhí)行一次預(yù)測(cè)過(guò)程")
some work there ...
self.get_gpu_info("內(nèi)層循環(huán)迭代完畢,查看是否正確釋放顯存")
而這個(gè)的輸出類(lèi)似于下圖炫彩。最好還是添加一個(gè)輸出提示匾七,因?yàn)樗绻麤](méi)有提示做分割的話(huà)絮短,其實(shí)不是很方便看到底執(zhí)行到哪里了江兢。
在這篇參考博客中,我看到有這么一段代碼
import GPUtil
import time
while True:
Gpus = GPUtil.getGPUs()
for gpu in Gpus:
print('GPU總量', gpu.memoryTotal)
print('GPU使用量', gpu.memortUsed)
time.sleep(5)
他這里的意思是不停的輸出GPU的總量和使用量丁频。但是我實(shí)際使用過(guò)程中發(fā)現(xiàn)杉允,好像并不是非常的好用,具體情況見(jiàn)下圖席里。
可以看到右邊是我的一個(gè)實(shí)測(cè)結(jié)果叔磷,雖然我中間改變了GPU的使用情況,但是輸出的值基本沒(méi)變奖磁。不知道是更新不夠快還是如何改基。我一開(kāi)始以為是更新不夠快,但是我發(fā)現(xiàn)哪怕程序一開(kāi)始咖为,他都可能會(huì)顯示已經(jīng)占用了部分的顯存資源秕狰。所以我就改用了上面的那個(gè)GPUtil.showUtilization()
方法四:找大腿問(wèn)問(wèn)
如果debug實(shí)在是太難了,也不要一門(mén)心思去找自己的問(wèn)題躁染,找個(gè)小伙伴問(wèn)問(wèn)鸣哀。描述給他人的同時(shí)你也會(huì)更了解問(wèn)題所在,而且有可能對(duì)方一語(yǔ)道破吞彤!
最后放個(gè)圖紀(jì)念一下這篇博客的誕生: