源碼來自于https://github.com/longcw/faster_rcnn_pytorch
這次只看了roi_pooling文件夾下的代碼,其他的需要編譯的文件跟這個類似。
roi_pooling目錄
-src文件夾下是c和cuda版本的源碼堰怨,其中roi_pooling的操作的foward是c和cuda版本都有的碉碉,而backward僅寫了cuda版本的代碼试幽。
-functions文件夾下的roi_pool.py是繼承了torch.autograd.Function類,實現(xiàn)RoI層的foward和backward函數(shù)锄列。
-modules文件夾下的roi_pool.py是繼承了torch.nn.Modules類,實現(xiàn)了對RoI層的封裝惯悠,此時RoI層就跟ReLU層一樣的使用了邻邮。
-_ext文件夾下還有個roi_pooling文件夾,這個文件夾是存儲src中c克婶,cuda編譯過后的文件的筒严,編譯過后就可以被funcitons中的roi_pool.py調(diào)用了丹泉。
functions/roi_pool.py
# -*- coding:utf8 -*-
import torch
from torch.autograd import Function
from .._ext import roi_pooling
import pdb
# 重寫函數(shù)實現(xiàn)RoI層的正向傳播和反向傳播 modules中的roi_pool實現(xiàn)層的封裝
class RoIPoolFunction(Function):
def __init__(ctx, pooled_height, pooled_width, spatial_scale):
ctx.pooled_width = pooled_width
ctx.pooled_height = pooled_height
ctx.spatial_scale = spatial_scale
ctx.feature_size = None
def forward(ctx, features, rois):
ctx.feature_size = features.size()
batch_size, num_channels, data_height, data_width = ctx.feature_size
num_rois = rois.size(0)
output = features.new(num_rois, num_channels, ctx.pooled_height, ctx.pooled_width).zero_() #new是torch.tensor的方法
ctx.argmax = features.new(num_rois, num_channels, ctx.pooled_height, ctx.pooled_width).zero_().int()
ctx.rois = rois
if not features.is_cuda:
_features = features.permute(0, 2, 3, 1) # permute = transform 也是torch.tensor的方法
roi_pooling.roi_pooling_forward(ctx.pooled_height, ctx.pooled_width, ctx.spatial_scale,
_features, rois, output) #調(diào)用_ext下的編譯好的cpu版本函數(shù)
else:
roi_pooling.roi_pooling_forward_cuda(ctx.pooled_height, ctx.pooled_width, ctx.spatial_scale,
features, rois, output, ctx.argmax) #調(diào)用_ext下的編譯好的gpu版本函數(shù)
return output
def backward(ctx, grad_output):
assert(ctx.feature_size is not None and grad_output.is_cuda)
batch_size, num_channels, data_height, data_width = ctx.feature_size
grad_input = grad_output.new(batch_size, num_channels, data_height, data_width).zero_()
roi_pooling.roi_pooling_backward_cuda(ctx.pooled_height, ctx.pooled_width, ctx.spatial_scale,
grad_output, ctx.rois, grad_input, ctx.argmax) #這個地方就只有g(shù)pu版本的了
return grad_input, None
modules\roi_pool.py
# -*- coding:utf8 -*-
from torch.nn.modules.module import Module
from ..functions.roi_pool import RoIPoolFunction
# 對roi_pooling層的封裝,就是ROI Pooling Layer了
class _RoIPooling(Module):
def __init__(self, pooled_height, pooled_width, spatial_scale):
super(_RoIPooling, self).__init__()
self.pooled_width = int(pooled_width)
self.pooled_height = int(pooled_height)
self.spatial_scale = float(spatial_scale)
def forward(self, features, rois):
return RoIPoolFunction(self.pooled_height, self.pooled_width, self.spatial_scale)(features, rois)
# 直接調(diào)用了functions中的函數(shù)鸭蛙,此時已經(jīng)實現(xiàn)了foward摹恨,backward操作
剩下的src,_ext文件的代碼就可以自己讀讀了规惰,就是用c睬塌,cuda對roi_pooling實現(xiàn)了foward和backward,目的就是為了讓python可以調(diào)用歇万。