Overhead
PyTorch 執(zhí)行 eager 操作時,例如,torch.add(a, b)
此疹,調(diào)度器(c10::Dispatcher
)會根據(jù)分派鍵(DispatchKey
) 來查找并執(zhí)行 add op 的 op kernel (理解PyTorch分發(fā)機(jī)制的內(nèi)部工作原理)。因此,算子注冊過程就是在調(diào)度器中定義 op舔亭,并將 kernel function 注冊到 op 的指定分派鍵條目中。
Torch Library
torch::Library
是算子注冊用的 helper蟀俊,通過它注冊的算子有著相同的命名空間钦铺、dispatch key等。
TORCH_LIBRARY(myops, m) {
m.def("myadd(Tensor self, Tensor other) -> Tensor");
m.def("mysub(Tensor self, Tensor other) -> Tensor", mysub_func);
m.impl("myadd", myadd_func);
}
m
就是命名空間為 myops
的 library肢预,它通過 m.def
定義了 myadd 和 mysub 這兩個 op 的靜態(tài)信息 schema矛洞。mysub 在定義的同時也將 mysub_func
函數(shù)注冊到 op,而 myadd 的 op kernel 則是通過 m.impl
單獨(dú)注冊的烫映。由于 TORCH_LIBRARY 宏沒有指定 dispatch key沼本,因此,這兩個 op kernel 都是 CatchAll
函數(shù)锭沟。
如果要將 kernel function 注冊到指定的 dispatch key抽兆,需要用到 TORCH_LIBRARY_IMPL 宏:
TORCH_LIBRARY_IMPL(myops, CUDA, m) {
m.impl("myadd", myadd_cuda);
m.impl("mysub", mysub_cuda);
}
所有通過 m
注冊的 kernel function 都會注冊到 op 的 CUDA
key 條目中,它執(zhí)行的優(yōu)先級會比 CatchAll 更高族淮。
OperatorDef
OperatorDef
用于描述調(diào)度器中 op 的靜態(tài)信息辫红,它會提供 registerSchema()
、registerKernel()
方法給 m.def() 和 m.impl() 分別用于注冊 op 和 kernel祝辣。
Kernel list
通過 m.impl() 注冊的 kernel function 會插入到指定 dispatch key 的 kernel list(kernels_
)的頭部贴妻,而調(diào)度器則會從列表中的首元素中獲取 kernel。也就是說蝙斜,PyTorch 允許為 op 的同一個 dispatch key 注冊多個 kernel揍瑟,而新 kernel 會覆蓋舊 kernel。
class TORCH_API OperatorEntry final {
...
ska::flat_hash_map<DispatchKey, std::list<AnnotatedKernel>> kernels_;
};
const AnnotatedKernel* OperatorEntry::getKernelForDispatchKey(DispatchKey dispatch_key) const{
auto kern_it = kernels_.find(dispatch_key);
if (kern_it != kernels_.end()) {
TORCH_INTERNAL_ASSERT(!kern_it->second.empty());
TORCH_INTERNAL_ASSERT(kern_it->second.front().kernel.isValid());
return &kern_it->second.front();
}
return nullptr;
}