問題描述:
最近一直在使用pytorch, 由于深度學(xué)習(xí)的網(wǎng)絡(luò)往往需要設(shè)置驗(yàn)證集來驗(yàn)證模型是否穩(wěn)定.
我一直再做一個(gè)關(guān)于醫(yī)學(xué)影像分割的課題,為了查看自己的模型是否穩(wěn)定,于是設(shè)置了驗(yàn)證集.
但是在運(yùn)行的過程中,當(dāng)程序執(zhí)行到 validatioon時(shí),顯存立即上升,我可憐的顯卡只有8GB顯存,瞬間爆炸.
怎么辦呢?實(shí)驗(yàn)得做呀.于是找了不少方法,比如設(shè)置各個(gè)網(wǎng)絡(luò)變量requires_grad=False,但是并不管用,顯存依然爆炸.
后來百度了一番,終于解決了顯存爆炸的問題.
解決方案:
假設(shè)訓(xùn)練程序是這樣的:
for train_data, train_label in ?train_dataloader:
? ? do?
? ? ? ? ? ?trainning
then
for valid_data,valid_label in valid_dataloader:
? ? do?
? ? ? ? ? ? validtion
當(dāng)程序執(zhí)行到validation時(shí),顯存忽然上升,幾乎是之前的兩倍.
只需要這樣改:
for train_data, train_label in?train_dataloader:
????????do
????????????trainning
then
with torch.no_grad():
????for valid_data,valid_label in valid_dataloader:
????????????do
? ? ? ? ? ????? validtion
當(dāng)程序執(zhí)行到validation時(shí),顯存將不再上升.問題得到解決.真的是非常簡單.