最近碰到pytorch分布式訓(xùn)練時候问慎,memory幾乎線性增加萍摊,撐炸機器的問題。
pytorch中內(nèi)存泄漏常見的原因大概是以下幾點:
- 不恰當(dāng)?shù)膌oss累加
- 有些人累加梯度會直接把梯度拿過來加如叼,但是由于每個梯度都存儲了很多東西冰木,導(dǎo)致隨著step增加,梯度累計的越來越多笼恰,造成了內(nèi)存泄漏
- 做法就是把loss的數(shù)值取出來踊沸,而不是累計整個梯度
- 直接把list轉(zhuǎn)化成tensor
- 常出現(xiàn)在dataloader的dataset類實現(xiàn)上,data是list社证,在get_item的時候直接轉(zhuǎn)化成tensor了逼龟,這樣好像每次都會造成數(shù)據(jù)多覆蓋。不過這個好像是python天然的原因追葡,算不上pytorch的鍋
- 標(biāo)準做法就是dataset類中存儲的是np類型腺律,然后再轉(zhuǎn)tensor(list->np->tensor)
- dataloader中num_worker的設(shè)置
- 有些版本中,num_worker設(shè)置大于0宜肉,就會有內(nèi)存泄漏匀钧,改成0就沒問題了
參考鏈接:https://github.com/pytorch/pytorch/issues/13246
有時候可能會遇到不同的問題,具體問題可以通過python的內(nèi)存分析工具做分析(不過講道理不是太管用)比如:https://www.pythonf.cn/read/108519谬返,https://zhuanlan.zhihu.com/p/121003986
我的心情隨著第一個github的issue答案起起伏伏榴捡,試了幾遍都不行,然后忽然想到朱浴,這些bug官方都回復(fù)修了吊圾,怎么還能有問題呢…然后轉(zhuǎn)頭把sagemaker上pytorch的版本從1.6降到了1.5,世界安靜了…
最近一天一個bug翰蠢,踩坑美滋滋