在基于PyTorch的深度學(xué)習(xí)框架中,通常需要將以下內(nèi)容傳到GPU:
- 模型參數(shù):神經(jīng)網(wǎng)絡(luò)模型的權(quán)重和偏置吭从。
- 輸入數(shù)據(jù):訓(xùn)練和測(cè)試時(shí)的輸入數(shù)據(jù)張量踊赠。
- 標(biāo)簽數(shù)據(jù):對(duì)應(yīng)的標(biāo)簽數(shù)據(jù)張量(如果有)枚钓。
- 損失函數(shù)和優(yōu)化器狀態(tài):雖然損失函數(shù)和優(yōu)化器本身不需要移動(dòng)到GPU唆姐,但它們內(nèi)部使用的數(shù)據(jù)和模型參數(shù)需要在GPU上。
步驟
- 設(shè)置設(shè)備
# 只有一塊GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 指定某個(gè)GPU
device = torch.device("cuda:1")
- 將模型移動(dòng)到GPU:
model = MyNeuralNet()
model.to(device)
- 將輸入數(shù)據(jù)和標(biāo)簽移動(dòng)到GPU:
inputs = inputs.to(device)
labels = labels.to(device)
- 損失函數(shù)和優(yōu)化器:
損失函數(shù)不需要顯式地移動(dòng)到GPU舞竿,因?yàn)樗鼤?huì)自動(dòng)處理張量的位置京景。
優(yōu)化器(optimizer)需要在模型參數(shù)移動(dòng)到GPU之后定義。
顯存管理
由于數(shù)據(jù)是一批一批地讀進(jìn)顯存的骗奖,當(dāng)每個(gè)批次的數(shù)據(jù)被用完之后确徙,GPU顯存中的這些數(shù)據(jù)會(huì)被釋放醒串,為下一個(gè)批次的數(shù)據(jù)騰出空間。
因此鄙皇,適當(dāng)降低批次大形叨摹(batch size)是一個(gè)常用的方法來節(jié)省顯存。