在使用PyTorch進行PointCNN的構建和實現(xiàn)中娘扩,發(fā)現(xiàn)模型在訓練過程中Loss保持穩(wěn)定下降着茸,但是在驗證過程中壮锻,出現(xiàn)完全不合理的10e9級別的Loss′汤考慮到訓練集和驗證集是完全從同一數據集中采樣出來的猜绣,不可能會在數據分布上出現(xiàn)明顯的差異,因此排除數據不一致的原因敬特。
詳細檢查了模型在訓練和驗證過程中的輸出掰邢,發(fā)現(xiàn)由于最后的一層BatchNormalization
,模型在訓練過程中的輸出是接近均值為零伟阔,方差為一的辣之。而驗證過程中,模型的輸出完全沒有遵從這個分布皱炉。因此怀估,可以認為,BatchNormalization
在驗證過程中合搅,沒有發(fā)揮它的作用多搀。
考慮到模型內部,顯式對數據分布進行調整的計算灾部,還是主要在BatchNormalization
層康铭,因此首先調查這一方面。
結果發(fā)現(xiàn)赌髓,PyTorch Forum上有人提到了相似的問題Model.eval() gives incorrect loss for model with batchnorm layers麻削。在這里,PyTorch Dev, Facebook AI Research的smth提到
it is possible that your training in general is unstable, so BatchNorm’s running_mean and running_var dont represent true batch statistics.
http://pytorch.org/docs/master/nn.html?highlight=batchnorm#torch.nn.BatchNorm1d
Try the following:
- change the momentum term in BatchNorm constructor to higher.
- before you set model.eval(), run a few inputs through model (just forward pass, you dont need to backward). This will help stabilize the running_mean / running_std values.
即春弥,BatchNormalization
層內的呛哟,隨訓練而不斷更新的擬合的數據分布,沒有能匹配真實的batch數據分布匿沛。
推薦將BatchNorm
內的momentum
項設置的比較高扫责,或是在將模型調到model.eval()
模式前,先將部分測試的數據在模型內前向傳播一下逃呼,讓BatchNorm
層可以更新一下這個估計鳖孤。
本文試驗了一下調高momentum項,但沒有明顯的效果抡笼。
本文解決這個問題通過另一位網友cakeeatingpolarbear提出的方法苏揣,將BatchNorm
函數內的track_running_stats
設置為False
,則模型會在任何模式下保持進行對數據分布的擬合推姻。