代碼結(jié)構(gòu):
1.master文件夾
(1) dice_loss.py
(2) eval.py
(3) predict.py: ** 完全沒涉及pruning后的網(wǎng)絡(luò)**
(4) pruning.py:
(5) submit.py:** 完全沒涉及pruning后的網(wǎng)絡(luò)**
(6) train.py:** 完全沒涉及pruning后的網(wǎng)絡(luò)**
(7) 文件夾unet
<1> prune_layers.py
<2> prune_unet_model.py : class PruneUNet
<3> prune_unet_parts.py : 對p_double_conv, p_inconv, p_outconv, p_down, p_up 進(jìn)行了定義
<4> unet_model.py
<5> unet_parts.py
(8) 文件夾util
<1> load.py
<2> util.py
<3> data_vis.py
<4> crf.py
閱讀目的:
能用,能跑,放自己的數(shù)據(jù)能跑
閱讀筆記
A. pruning.py
閱讀
- pruning.py中的line 91的net.train()的理解
net的定義:
net = PruneUNet(n_channels=3, n_classes=1)
- 中心for循環(huán)的理解
for循環(huán)做了以下四件事情:
對每個epoch:
(1) reset the generator for training data and validation data.
這里實際上對每一個epoch瘦赫,在開始時對數(shù)據(jù)都做了traditional augmentation,但是我們的數(shù)據(jù)量足夠多耐薯,不需要這么做。待改進(jìn)
(2) 取validation dataset的前四項隘世,進(jìn)行prediction和計算accuracy可柿。
這里因為validation dataset是自動生成的,所以雖然都是前4項丙者,但是validation dataset是不一樣的。
(3) 用PruneUNet訓(xùn)練training dataset的前兩個batch:
PruneUNet位于unet文件夾的prune_unet_model.py
中
-
model.eval()
:Pytorch會自動把BN和Dropout固定住营密,不會取平均械媒,而是用訓(xùn)練好的值 -
model.train()
:讓model變成訓(xùn)練模式,此時 dropout和batch normalization的操作在訓(xùn)練時起到防止網(wǎng)絡(luò)過擬合的作用评汰。
訓(xùn)練纷捞,算loss,反向傳播
然后進(jìn)行prune被去,
對每個epoch主儡,都要循環(huán)num_prune_iterations次,每一次運(yùn)行一遍net.prune惨缆。net.prune的具體內(nèi)容見B糜值,總結(jié)下來是去掉一個channel丰捷。
疑問:如果一直執(zhí)行net.prune,都是去掉最小值寂汇,但是去掉最小值之后如果不刪掉對應(yīng)prune_feature_map 里的值棘脐,那么每次刪掉的module里的filter不是一樣的嗎昏滴?
回答:每一次找到對應(yīng)layer_idx和filter_idx之后,對conv2d層執(zhí)行prune,都需要運(yùn)行位于prune_layer.py中的函數(shù)prune_feature_map煞抬,在這個函數(shù)中,執(zhí)行了下面兩步:
indices = Variable(torch.LongTensor([i for i in range(self.out_channels) if i != map_index]))
self.weight = nn.Parameter(self.weight.index_select(0, indices).data)
對bias和對weight有一樣的操作霹崎。
最后將輸出channel減一胯陋。
這里重點(diǎn)理解這個index_select函數(shù):
函數(shù)格式:
index_select(
dim,
index)
參數(shù)含義:
dim:表示從第幾維挑選數(shù)據(jù),類型為int值扒磁;index:表示從第一個參數(shù)維度中的哪個位置挑選數(shù)據(jù)硼被,類型為torch.Tensor類的實例;
(4) 繼續(xù)對第一次循環(huán)里的validation的數(shù)據(jù)用pruned的代碼進(jìn)行預(yù)測和計算loss渗磅。
(5) if save_cp時嚷硫,保存net.state_dict()
B. prune_unet_model.py
閱讀
在PruneUNet
這個class中定義了4個函數(shù):
__ init __,forward始鱼,set_pruning和prune仔掸。
其中 __ init __里,所有的down和up layer 以及output layer 都是pruned layer医清。
其中prune是具體進(jìn)性layer prune的函數(shù)起暮,做了以下事情:
(1) 去掉model里的大的block
(2) 找到泰勒估計中,最小估計值所對應(yīng)的layer和filter的位置会烙,用prune_feature_map
函數(shù)進(jìn)行prune负懦。
(3) 如果下一層不是最后一層,對應(yīng)去drop掉下一層的輸入channel
(4) down layer的channel改變之后柏腻,對應(yīng)up layer的channel也要改變纸厉,這里用hard code去寫。
進(jìn)一步去看:line68的taylor_estimates_by_module 和 estimates_by_f_map是怎么計算得到的五嫂。
對每一個module list的module颗品,在line64進(jìn)行了module.taylor_estimates,去進(jìn)行排序沃缘。
先取出每個module_list 的module.taylor_estimates和idx躯枢,
再從module.taylor_estimates里取出f_map_idx和對應(yīng)的估計值estimate。
C. prune_layers.py
閱讀
在prune_layers.py中槐臀,定義了class PrunableConv2d(nn.Conv2d)
和class PrunableBatchNorm2d(nn.BatchNorm2d)
锄蹂,對PrunableConv2d(nn.Conv2d)
,定義了屬性taylor_estimates
D. 提問:
問題1:
pruning.py基于前幾個training batch和幾個epoch和手動輸入的num_prune_iterations對unet進(jìn)行pruning水慨,那么如何用prune好的網(wǎng)絡(luò)對我們的數(shù)據(jù)進(jìn)行計算呢得糜?
num_prune_iterations = 100敬扛,
epochs=5
這里又涉及兩個問題:(a) prune完需要retrain嗎? (b) 如何用pruned的網(wǎng)絡(luò)進(jìn)行inference?
對問題(a),其實在pruning.py里掀亩,有反向傳播更新梯度值和權(quán)重值的過程了舔哪,未必要重新去train。
對問題(b)槽棍,需要繼續(xù)閱讀代碼捉蚤。
代碼里并沒有寫,可能需要自己在prediction的代碼里導(dǎo)入pruned_unet
問題2: 如何計算每個module的taylor_estimates
在prune_layers.py
的class PrunableConv2d(nn.Conv2d)
中有一個函數(shù)_calculate_taylor_estimate(self, _, grad_input, grad_output)
專門計算taylor_estimates炼七。
這里有注釋:# skip dim 1 as it is kernel size
其中缆巧,_recent_activations是forward之后該conv2d層的output。
mul_(value)
mul()的直接運(yùn)算形式豌拙,即直接執(zhí)行并且返回修改后的張量
# skip dim 1 as it is kernel size
estimates = self._recent_activations.mul_(grad_output[0])
estimates = estimates.mean(dim=(0, 2, 3))
# normalization
self.taylor_estimates = torch.abs(estimates) / torch.sqrt(torch.sum(estimates * estimates))
修改代碼據(jù)為己用:
A. pruning.py
改動記錄
- 把optimizer從SGD改成Adam陕悬,和自己的UNet保持一致。(已完成)
-
line61
的criterion = nn.BCELoss()
改成自己定義的Diceloss
- line183的net的定義
n_channel
從3改成1 -
summary(net, (3, 640, 640))
注釋掉這一步可視化按傅,因為暫時不探究其參數(shù)含義捉超。
B.