原理
在神經(jīng)網(wǎng)絡(luò)中,參數(shù)默認(rèn)是進(jìn)行隨機(jī)初始化的食零。不同的初始化參數(shù)往往會(huì)導(dǎo)致不同的結(jié)果滑废。
當(dāng)?shù)玫奖容^好的結(jié)果時(shí)我們通常希望這個(gè)結(jié)果是可以復(fù)現(xiàn)的蝗肪,在pytorch中,通過(guò)設(shè)置全局隨機(jī)數(shù)種子可以實(shí)現(xiàn)這個(gè)目的蠕趁。
具體操作
對(duì)隨機(jī)數(shù)生成器設(shè)置固定種子的操作可以分為四部分薛闪。
1. cuDNN
cudnn中對(duì)卷積操作進(jìn)行了優(yōu)化,犧牲了精度來(lái)?yè)Q取計(jì)算效率俺陋。
如果需要保證可重復(fù)性豁延,可以使用如下設(shè)置:
from torch.backends import cudnn
cudnn.benchmark = False # if benchmark=True, deterministic will be False
cudnn.deterministic = True
不過(guò)實(shí)際上這個(gè)設(shè)置對(duì)精度影響不大,僅僅是小數(shù)點(diǎn)后幾位的差別腊状。所以如果不是對(duì)精度要求極高诱咏,其實(shí)不太建議修改,因?yàn)闀?huì)使計(jì)算效率降低缴挖。
2. PyTorch
seed = 0
torch.manual_seed(seed) # 為CPU設(shè)置隨機(jī)種子
torch.cuda.manual_seed(seed) # 為當(dāng)前GPU設(shè)置隨機(jī)種子
torch.cuda.manual_seed_all(seed) # 為所有GPU設(shè)置隨機(jī)種子
3. Python & NumPy
如果讀取數(shù)據(jù)的過(guò)程采用了隨機(jī)預(yù)處理(如RandomCrop袋狞、RandomHorizontalFlip等),那么對(duì)python映屋、numpy的隨機(jī)數(shù)生成器也需要設(shè)置種子苟鸯。
import os
import random
import numpy as np
seed = 0
random.seed(seed)
np.random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed) # 為了禁止hash隨機(jī)化,使得實(shí)驗(yàn)可復(fù)現(xiàn)秧荆。
4. dataloader
如果dataloader采用了多線(xiàn)程(num_workers > 1), 那么由于讀取數(shù)據(jù)的順序不同倔毙,最終運(yùn)行結(jié)果也會(huì)有差異。也就是說(shuō)乙濒,改變num_workers參數(shù)陕赃,也會(huì)對(duì)實(shí)驗(yàn)結(jié)果產(chǎn)生影響。
目前暫時(shí)沒(méi)有發(fā)現(xiàn)解決這個(gè)問(wèn)題的方法颁股,但是只要固定num_workers數(shù)目(線(xiàn)程數(shù))不變么库,基本上也能夠重復(fù)實(shí)驗(yàn)結(jié)果。
對(duì)于不同線(xiàn)程的隨機(jī)數(shù)種子設(shè)置甘有,主要通過(guò)DataLoader的worker_init_fn參數(shù)來(lái)實(shí)現(xiàn)诉儒。默認(rèn)情況下使用線(xiàn)程ID作為隨機(jī)數(shù)種子。如果需要自己設(shè)定亏掀,可以參考以下代碼:
GLOBAL_SEED = 1
def set_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
GLOBAL_WORKER_ID = None
def worker_init_fn(worker_id):
global GLOBAL_WORKER_ID
GLOBAL_WORKER_ID = worker_id
set_seed(GLOBAL_SEED + worker_id)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=2, worker_init_fn=worker_init_fn)
參考