Introduction
The dispatcher is an internal component of PyTorch that figures out which code should actually be executed when a function such as torch::add is called. This is needed because PyTorch operations have to handle many cross-cutting concerns that are layered on top of one another, for example:
- switching between the CPU and CUDA implementations of an operator, depending on the device of the input tensors
- switching between the autograd and backend implementations of an operator, depending on whether autograd handling is required
- whether autocast is necessary for mixed-precision execution
- whether batching rules should be applied to an operator running under a vmap call
- whether the execution of the operator should be traced
PyTorch represents these different concerns with DispatchKey:
enum class DispatchKey : uint8_t {
Undefined = 0,
CatchAll = Undefined,
CPU, // registered at build/aten/src/ATen/RegisterCPU.cpp
CUDA, // registered at build/aten/src/ATen/RegisterCUDA.cpp
HIP, // NB: I think this is not actually used, due to Note [Masquerading as
// CUDA]
FPGA, // Xilinx support lives out of tree at
ORT,
XLA, // lives out of tree at https://github.com/pytorch/xla
MLC, // lives out of tree at https://github.com/pytorch/MLCompute
Vulkan,
Metal,
XPU, // For out of tree Intel's heterogeneous computing plug-in
HPU, // For out of tree & closed source integration of HPU / Habana
VE, // For out of tree & closed source integration of SX-Aurora / NEC
Lazy, // For lazy tensor backends
SWAI, // For out of tree SWAI backend
Meta,
QuantizedCPU, // registered at build/aten/src/ATen/RegisterQuantizedCPU.cpp
QuantizedCUDA, // registered at build/aten/src/ATen/RegisterQuantizedCUDA.cpp
QuantizedXPU, // For out of tree Intel's heterogeneous computing plug-in
CustomRNGKeyId,
...
};
In short, the dispatcher answers one question: which kernel should be called?
Naively, as in the example below, the different cases could be handled with an if statement:
class MyAddFunction : ... {
public:
static Tensor forward(
AutogradContext *ctx, torch::Tensor self, torch::Tensor other) {
if (self.device().type() == DeviceType::CPU) {
return add_cpu(self, other);
} else if (self.device().type() == DeviceType::CUDA) {
return add_cuda(self, other);
} else {
TORCH_CHECK(0, "Unsupported device ", self.device().type());
}
}
...
}
So why use the dispatcher?
- It is decentralized: a new operator does not need its own hand-written chain of if statements covering every case. Moreover, when a third party adds an implementation of an operator for a new case (for example, a new device), they do not need to patch the operator's original implementation.
- Dispatch keys cover more concerns than just CPU, CUDA, and Autograd; c10/core/DispatchKey.h already defines a whole series of dispatch keys.
- It supports boxed fallback functions, which are written once and apply to all operators (see the registration sketch after this list).
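To illustrate the last point, here is a hedged sketch of a boxed fallback, following the pattern from the official "extending the dispatcher" tutorial. The dispatch key (PrivateUse1), the name logging_fallback, and its body are illustrative assumptions, not code from PyTorch itself:
void logging_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  // One boxed function serves every operator: inspect the schema, then drop
  // the current key and re-enter the dispatcher to run the real kernel.
  std::cout << "called: " << op.schema().name() << std::endl;
  c10::impl::ExcludeDispatchKeyGuard guard(c10::DispatchKeySet(c10::DispatchKey::PrivateUse1));
  op.callBoxed(stack);
}

// The wildcard namespace "_" registers the fallback for operators in every namespace.
TORCH_LIBRARY_IMPL(_, PrivateUse1, m) {
  m.fallback(torch::CppFunction::makeFromBoxedFunction<&logging_fallback>());
}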
The dispatch mechanism
Concepts
First, some terminology:
operator: an operation such as add
kernel: a concrete implementation of an operator for a particular device (CPU, CUDA), input layout (dense, sparse), and with or without autograd
Idea
The dispatch mechanism replaces the if statements above with table lookups (backed by hash maps under the hood), and the Dispatcher controls the dispatch of every operator:
The first level of dispatch maps an operator name to an OperatorHandle (each operator is managed by one OperatorHandle).
The second level of dispatch maps a dispatch key to a kernel function (each device, input layout, etc. corresponds to a dispatch key), as sketched in the toy model below.
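The following is a toy model of this two-level lookup, not PyTorch code: it only shows the shape of the data structures. The real implementation, described below, uses ska::flat_hash_map, OperatorHandle/OperatorEntry, and a fixed-size array indexed by DispatchKey.
#include <array>
#include <cstdint>
#include <functional>
#include <stdexcept>
#include <string>
#include <unordered_map>

enum class ToyDispatchKey : uint8_t { CPU, CUDA, Autograd, NumKeys };

struct ToyOperatorEntry {
  // Second level: dispatch key -> kernel.
  std::array<std::function<void()>, static_cast<size_t>(ToyDispatchKey::NumKeys)> table;
};

struct ToyDispatcher {
  // First level: operator name -> per-operator entry.
  std::unordered_map<std::string, ToyOperatorEntry> ops;

  void call(const std::string& name, ToyDispatchKey key) {
    auto& kernel = ops.at(name).table[static_cast<size_t>(key)];
    if (!kernel) throw std::runtime_error("no kernel registered for this key");
    kernel();
  }
};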
Implementation
The main dispatcher code lives under aten/src/ATen/core/dispatch.
The main classes are:
● Dispatcher: maps an operator name to an OperatorHandle
● OperatorHandle: registers, looks up, and calls the kernels of a specific operator
● OperatorEntry: maps a dispatch key to a KernelFunction
● KernelFunction: wraps a backend kernel
The Dispatcher class
The operatorLookupTable_ table stores the mapping from operator name to OperatorHandle:
class TORCH_API Dispatcher final{
std::list<OperatorDef> operators_;
LeftRight<ska::flat_hash_map<OperatorName, OperatorHandle>> operatorLookupTable_;
};
The following methods all look up an OperatorHandle by operator name; see the source for the concrete implementations:
c10::optional<OperatorHandle> findSchema(const OperatorName& operator_name);
OperatorHandle findSchemaOrThrow(const char* name, const char* overload_name);
c10::optional<OperatorHandle> findOp(const OperatorName& operator_name);
const std::vector<OperatorName> getAllOpNames();
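A hedged usage sketch of these lookup methods, the same pattern the generated at::_ops code uses in the relu walkthrough later: resolve the handle by name, obtain a typed interface, and call it. The schema name and function type are the real ones for aten::relu; in practice the handle is cached in a function-local static.
auto op = c10::Dispatcher::singleton()
    .findSchemaOrThrow("aten::relu", "")      // operator name + overload name
    .typed<at::Tensor(const at::Tensor&)>();  // typed (unboxed) interface
at::Tensor y = op.call(torch::randn({4}));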
The OperatorHandle class
An OperatorHandle contains an OperatorDef, which in turn contains an OperatorEntry; the actual key-to-kernel mapping is handled by OperatorEntry:
class TORCH_API OperatorHandle {
Dispatcher::OperatorDef* operatorDef_;
void callBoxed(Stack* stack) const {
c10::Dispatcher::singleton().callBoxed(*this, stack);
}
void callBoxed(Stack& stack) const {
callBoxed(&stack);
}
void redispatchBoxed(DispatchKeySet ks, Stack* stack) const {
c10::Dispatcher::singleton().redispatchBoxed(*this, ks, stack);
}
}
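A hedged sketch of the boxed calling convention that callBoxed implements: arguments and results travel on a torch::jit::Stack of IValues instead of typed C++ parameters. aten::relu is used only as an example here.
c10::OperatorHandle h =
    c10::Dispatcher::singleton().findSchemaOrThrow("aten::relu", "");
torch::jit::Stack stack;
torch::jit::push(stack, torch::randn({2, 3}));   // push the arguments as IValues
h.callBoxed(&stack);                             // results are left on the stack
at::Tensor out = torch::jit::pop(stack).toTensor();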
The OperatorEntry class
Like the Dispatcher, OperatorEntry also holds a lookup table, which stores the mapping from dispatch key to KernelFunction. Since a dispatch key is a uint8_t enum value, the table is implemented as an array:
class TORCH_API OperatorEntry final {
std::array<KernelFunction, static_cast<uint8_t>(DispatchKey::NumDispatchKeys)> dispatchTable_;
};
Looking up a dispatch key returns the corresponding kernel function:
const KernelFunction& lookup(DispatchKey k) const {
const auto& kernel = dispatchTable_[static_cast<uint8_t>(k)];
// A valid kernel *always* has a boxed kernel and *may* have an
// unboxed kernel. However, we typically do unboxed calls in at::
// APIs, where the kernel 1) will very likely be valid and 2)
// should have an unboxed kernel. Checking the unboxed kernel
// first will allow us to avoid touching the boxed kernel at all
// in the common case.
if (C10_UNLIKELY(!kernel.isValidUnboxed())) {
if (!kernel.isValid()) {
reportError(k);
}
}
return kernel;
}
The KernelFunction class
KernelFunction wraps a backend kernel and exposes it through both a boxed and an unboxed entry point.
functor_ points to the backend kernel function:
class TORCH_API KernelFunction final {
OperatorKernel* getFunctor_() const;
std::shared_ptr<OperatorKernel> functor_;
InternalBoxedKernelFunction* boxed_kernel_func_;
void* unboxed_kernel_func_;
};
call() ultimately invokes the concrete implementation held by functor_:
template<class Return, class... Args>
C10_ALWAYS_INLINE Return KernelFunction::call(const OperatorHandle& opHandle, DispatchKeySet dispatchKeySet, Args... args) const {
// note: Args above is intentionally not Args&&. We don't want perfect
// forwarding, which would require Args to be deduced, but instead we
// want callers to explicitly specify the Args.
if (C10_LIKELY(unboxed_kernel_func_ != nullptr)) {
return callUnboxedKernelFunction<Return, Args...>(unboxed_kernel_func_, functor_.get(), dispatchKeySet, std::forward<Args>(args)...);
}
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
boxed_kernel_func_ != nullptr,
"Tried to call KernelFunction::call() on an uninitialized KernelFunction."
);
return impl::BoxedKernelWrapper<Return(Args...)>::call(
boxed_kernel_func_,
functor_.get(),
opHandle,
dispatchKeySet,
std::forward<Args>(args)...
);
}
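Before moving on to registration, a hedged sketch of how such a KernelFunction is usually produced: the CppFunction used by m.impl (see below) wraps an ordinary C++ function via KernelFunction::makeFromUnboxedFunction, which typically populates both the unboxed entry point and a boxing wrapper. myadd_cpu is an assumed example function, not part of PyTorch.
at::Tensor myadd_cpu(const at::Tensor& self, const at::Tensor& other);  // assumed to exist

// TORCH_FN turns the function into a compile-time function pointer so that the
// boxed/unboxed wrappers can be generated around it.
auto kf = c10::KernelFunction::makeFromUnboxedFunction(TORCH_FN(myadd_cpu));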
The registration mechanism
Registering an operator
An example of registering an operator:
TORCH_LIBRARY(myops, m) { m.def("myadd(Tensor self, Tensor other) -> Tensor"); }
The macro is defined in torch/library.h; tracing through the code, the concrete implementation is in aten/src/ATen/core/library.cpp:
#define DEF_PRELUDE "def(\"", schema.operator_name(), "\"): "
Library& Library::_def(c10::FunctionSchema&& schema, c10::OperatorName* out_name) & {
TORCH_CHECK(kind_ == DEF || kind_ == FRAGMENT,
DEF_PRELUDE,
"Cannot define an operator inside of a ", toString(kind_), " block. "
"All def()s should be placed in the (unique) TORCH_LIBRARY block for their namespace. ",
ERROR_CONTEXT
);
TORCH_INTERNAL_ASSERT(ns_.has_value(), ERROR_CONTEXT);
TORCH_INTERNAL_ASSERT(!dispatch_key_.has_value(), ERROR_CONTEXT);
auto ns_opt = schema.getNamespace();
if (ns_opt.has_value()) {
// Note [Redundancy in registration code is OK]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// In an earlier version of this code, I made it an error to explicitly
// specify the namespace, even when the namespaces match. I've decided
// to relax this constraint because sometimes we code generate registrations
// and you cannot conveniently tell what the enclosing context will be;
// in these cases, it is simpler (and less error prone) to place all
// of the information in the registration site, which will be cross-checked
// in the end in any case (and if it turns out you DON'T have the right
// information at the site, as is the case with backend specific
// per-op registrations, you will get the right behavior!)
TORCH_CHECK(*ns_opt == *ns_,
"Explicitly provided namespace (", *ns_opt, ") in schema string "
"does not match namespace of enclosing ", toString(kind_), " block (", *ns_, "). "
"Move this definition to the (unique) TORCH_LIBRARY block corresponding to this namespace "
"(and consider deleting the namespace from your schema string.) ",
ERROR_CONTEXT
);
} else {
bool b = schema.setNamespaceIfNotSet(ns_->c_str());
TORCH_INTERNAL_ASSERT(b, ERROR_CONTEXT);
}
if (out_name) {
*out_name = schema.operator_name(); // copy!
}
registrars_.emplace_back(
c10::Dispatcher::singleton().registerDef(
std::move(schema),
debugString("", file_, line_)
)
);
return *this;
}
Finally, the Dispatcher's registerDef method is called. It records the mapping from operator name to operator, i.e. it inserts an entry into the lookup table:
RegistrationHandleRAII Dispatcher::registerDef(FunctionSchema schema, std::string debug) {
// we need a lock to avoid concurrent writes
std::lock_guard<std::mutex> lock(mutex_);
OperatorName op_name = schema.operator_name();
auto op = findOrRegisterName_(op_name);
TORCH_CHECK(op.operatorDef_->def_count == 0, "Tried to register an operator (", schema, ") with the same name and overload name multiple times.",
" Each overload's schema should only be registered with a single call to def().",
" Duplicate registration: ", debug, ". Original registration: ", op.operatorDef_->op.debug());
op.operatorDef_->op.registerSchema(std::move(schema), std::move(debug));
listeners_->callOnOperatorRegistered(op);
// NB: do not increment the counts until AFTER error checking
++op.operatorDef_->def_count;
++op.operatorDef_->def_and_impl_count;
return RegistrationHandleRAII([this, op, op_name] {
deregisterDef_(op, op_name);
});
}
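A hedged illustration of the def_count check above: each (name, overload_name) pair may only be def()'ed once, but distinct overloads of the same operator are fine. The .alpha overload below is hypothetical, and since TORCH_LIBRARY may appear only once per namespace, this block stands in for the earlier one rather than adding to it.
TORCH_LIBRARY(myops, m) {
  m.def("myadd(Tensor self, Tensor other) -> Tensor");
  m.def("myadd.alpha(Tensor self, Tensor other, Scalar alpha) -> Tensor");  // hypothetical overload
  // m.def("myadd(Tensor self, Tensor other) -> Tensor");  // duplicate: would trip the def_count check
}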
Registering a kernel
The code below registers the CPU kernel of the myadd operator with the dispatcher:
TORCH_LIBRARY_IMPL(myops, CPU, m) {
m.impl("myadd", myadd_cpu);
}
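Further kernels for the same operator go in their own TORCH_LIBRARY_IMPL blocks, one per dispatch key. A hedged sketch following the official dispatcher tutorial; myadd_cuda and myadd_autograd are assumed to exist (the latter would typically wrap a torch::autograd::Function and redispatch to the backend kernel):
TORCH_LIBRARY_IMPL(myops, CUDA, m) {
  m.impl("myadd", myadd_cuda);        // assumed CUDA kernel
}

TORCH_LIBRARY_IMPL(myops, Autograd, m) {
  m.impl("myadd", myadd_autograd);    // assumed autograd wrapper kernel
}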
As with operator registration, tracing through the kernel-registration macro leads to the concrete implementation:
#define IMPL_PRELUDE "impl(\"", name_str, "\", ...): "
Library& Library::_impl(const char* name_str, CppFunction&& f) & {
auto name = torch::jit::parseName(name_str);
auto ns_opt = name.getNamespace();
// This is kind of similar to the checking in def(), but the error
// messages are a little different for this call site
if (ns_opt.has_value()) {
// See Note [Redundancy in registration code is OK]
TORCH_CHECK(*ns_opt == *ns_,
IMPL_PRELUDE,
"Explicitly provided namespace (", *ns_opt, ") in operator name "
"does not match namespace of enclosing ", toString(kind_), " block (", *ns_, "). "
"Move this definition to the ", toString(kind_), " block corresponding to this namespace "
"(and consider deleting the namespace from your schema string.) ",
ERROR_CONTEXT
);
} else {
bool b = name.setNamespaceIfNotSet(ns_->c_str());
TORCH_INTERNAL_ASSERT(b, ERROR_CONTEXT);
}
// See Note [Redundancy in registration code is OK]
TORCH_CHECK(!(f.dispatch_key_.has_value() &&
dispatch_key_.has_value() &&
*f.dispatch_key_ != *dispatch_key_),
IMPL_PRELUDE,
"Explicitly provided dispatch key (", *f.dispatch_key_, ") is inconsistent "
"with the dispatch key of the enclosing ", toString(kind_), " block (", *dispatch_key_, "). "
"Please declare a separate ", toString(kind_), " block for this dispatch key and "
"move your impl() there. "
ERROR_CONTEXT
);
auto dispatch_key = f.dispatch_key_.has_value() ? f.dispatch_key_ : dispatch_key_;
registrars_.emplace_back(
c10::Dispatcher::singleton().registerImpl(
std::move(name),
dispatch_key,
std::move(f.func_),
// NOLINTNEXTLINE(performance-move-const-arg)
std::move(f.cpp_signature_),
std::move(f.schema_),
debugString(std::move(f.debug_), file_, line_)
)
);
return *this;
}
This calls Dispatcher::registerImpl:
RegistrationHandleRAII Dispatcher::registerImpl(
OperatorName op_name,
c10::optional<DispatchKey> dispatch_key,
KernelFunction kernel,
c10::optional<impl::CppSignature> cpp_signature,
std::unique_ptr<FunctionSchema> inferred_function_schema,
std::string debug
) {
std::lock_guard<std::mutex> lock(mutex_);
auto op = findOrRegisterName_(op_name);
auto handle = op.operatorDef_->op.registerKernel(
*this,
dispatch_key,
std::move(kernel),
// NOLINTNEXTLINE(performance-move-const-arg)
std::move(cpp_signature),
std::move(inferred_function_schema),
std::move(debug)
);
++op.operatorDef_->def_and_impl_count;
return RegistrationHandleRAII([this, op, op_name, dispatch_key, handle] {
deregisterImpl_(op, op_name, dispatch_key, handle);
});
}
registerImpl finds (or registers) the corresponding OperatorHandle and calls OperatorEntry::registerKernel.
The key part of registerKernel is the dispatch-table update at the end (the updateDispatchTable_ / updateDispatchTableFull_ calls):
OperatorEntry::AnnotatedKernelContainerIterator OperatorEntry::registerKernel(
const c10::Dispatcher& dispatcher,
c10::optional<DispatchKey> dispatch_key,
KernelFunction kernel,
c10::optional<CppSignature> cpp_signature,
std::unique_ptr<FunctionSchema> inferred_function_schema,
std::string debug
) {
// NB: cpp_signature doesn't get cleared even after the kernel that populated
// it is deleted. This means you could poison the value of cpp_signature_
// with a bad signature value, and then it would permanently stay there until
// you deregister the schema. This can't really be fixed, because we
// only do a typed() test once in the lifetime of a TypedOperatorHandle,
// which means if you could validly change the type of a cpp_signature, then
// that would also invalidate the old TypedOperatorHandles.
if (cpp_signature.has_value()) {
if (cpp_signature_.has_value()) {
TORCH_CHECK(*cpp_signature == cpp_signature_->signature,
"\nMismatch in kernel C++ signatures\n",
" operator: ", (this->schema_.has_value() ? toString(this->schema_->schema) : toString(name_)), "\n",
" ", (this->schema_.has_value() ? this->schema_->debug : "no debug info"), "\n",
" kernel 1: ", cpp_signature_->signature.name(), "\n",
" dispatch key: ", toString(cpp_signature_->dispatch_key), "\n",
" ", cpp_signature_->debug, "\n",
" kernel 2: ", cpp_signature->name(), "\n",
" dispatch key: ", toString(dispatch_key), "\n",
" ", debug, "\n"
);
} else {
cpp_signature_ = CppSignatureWithDebug { *cpp_signature, debug, dispatch_key };
}
}
if (schema_ && inferred_function_schema) {
checkSchema(name_, schema_->schema, schema_->debug, *inferred_function_schema, debug);
}
// Add the kernel to the kernels list,
// possibly creating the list if this is the first kernel.
// Redirect catchAll registrations to CompositeImplicitAutograd.
auto& k = dispatch_key.has_value() ? kernels_[*dispatch_key] : kernels_[DispatchKey::CompositeImplicitAutograd];
#ifdef C10_DISPATCHER_ONE_KERNEL_PER_DISPATCH_KEY
if (k[0].kernel.isValid()) {
#else
if (k.size() > 0) {
#endif
TORCH_WARN("Overriding a previously registered kernel for the same operator and the same dispatch key\n",
" operator: ", (schema_.has_value() ? toString(schema_->schema) : toString(name_)), "\n",
" ", (this->schema_.has_value() ? this->schema_->debug : "no debug info"), "\n",
" dispatch key: ", toString(dispatch_key), "\n",
" previous kernel: ", (cpp_signature_.has_value() ? cpp_signature_->debug : "no debug info"), "\n",
" new kernel: ", debug
);
}
#ifdef C10_DISPATCHER_ONE_KERNEL_PER_DISPATCH_KEY
k[0].kernel = std::move(kernel);
k[0].inferred_function_schema = std::move(inferred_function_schema);
k[0].debug = std::move(debug);
#else
k.emplace_front(std::move(kernel), std::move(inferred_function_schema), std::move(debug));
#endif
AnnotatedKernelContainerIterator inserted = k.begin();
// update the dispatch table, i.e. re-establish the invariant
// that the dispatch table points to the newest kernel
if (dispatch_key.has_value()) {
updateDispatchTable_(dispatcher, *dispatch_key);
} else {
updateDispatchTableFull_(dispatcher);
}
return inserted;
}
The dispatch call process
build/aten/src/ATen/Functions.h contains the entry functions for the operators.
Taking torch::relu as an example, here is the call path from the entry function down to the final backend kernel function.
TORCH_API inline at::Tensor relu(const at::Tensor & self) {
return at::_ops::relu::call(self);
}
build/aten/src/ATen/Operators_4.cpp
The operator name is used to look up a TypedOperatorHandle object, and its call method is then invoked:
STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA(relu, name, "aten::relu")
STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA(relu, overload_name, "")
STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA(relu, schema_str, "relu(Tensor self) -> Tensor")
// aten::relu(Tensor self) -> Tensor
static C10_NOINLINE c10::TypedOperatorHandle<relu::schema> create_relu_typed_handle() {
return c10::Dispatcher::singleton()
.findSchemaOrThrow(relu::name, relu::overload_name)
.typed<relu::schema>();
}
// aten::relu(Tensor self) -> Tensor
at::Tensor relu::call(const at::Tensor & self) {
static auto op = create_relu_typed_handle();
return op.call(self);
}
// aten::relu(Tensor self) -> Tensor
at::Tensor relu::redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
static auto op = create_relu_typed_handle();
return op.redispatch(dispatchKeySet, self);
}
TypedOperatorHandle::call forwards to Dispatcher::call:
C10_ALWAYS_INLINE Return call(Args... args) const {
return c10::Dispatcher::singleton().call<Return, Args...>(*this, std::forward<Args>(args)...);
}
Dispatcher::call uses the dispatch key to find the KernelFunction and then calls KernelFunction::call:
template<class Return, class... Args>
C10_DISPATCHER_INLINE_UNLESS_MOBILE Return Dispatcher::call(const TypedOperatorHandle<Return(Args...)>& op, Args... args) const {
detail::unused_arg_(args...); // workaround for a false-positive warning about unused parameters in gcc 5
auto dispatchKeySet = op.operatorDef_->op.dispatchKeyExtractor()
.template getDispatchKeySetUnboxed<Args...>(args...);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!c10::isAliasDispatchKey(dispatchKeySet.highestPriorityTypeId()));
const KernelFunction& kernel = op.operatorDef_->op.lookup(dispatchKeySet.highestPriorityTypeId());
#ifndef PYTORCH_DISABLE_PER_OP_PROFILING
// By default, when there're no high-frequency or non-sampled callbacks,
// RecordFunction is pre-sampled as a perf optimization;
// shouldRunRecordFunction checks whether RecordFunction should be executed,
// and sets pre_sampled boolean argument value to whether pre-sampling was used -
// this boolean is passed into RecordFunction to adjust the sampling rates of
// the callbacks
bool pre_sampled = false;
if (C10_UNLIKELY(at::shouldRunRecordFunction(&pre_sampled))) {
return callWithDispatchKeySlowPath<Return, Args...>(op, pre_sampled, dispatchKeySet, kernel, std::forward<Args>(args)...);
}
#endif // PYTORCH_DISABLE_PER_OP_PROFILING
return kernel.template call<Return, Args...>(op, dispatchKeySet, std::forward<Args>(args)...);
}
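A hedged sketch of where that dispatch key set comes from: roughly, it is derived from the DispatchKeySet carried by each Tensor argument (adjusted by thread-local included/excluded keys), and the highest-priority key in the union selects the kernel. The snippet below only illustrates the idea:
at::Tensor a = torch::randn({2}, torch::requires_grad());
at::Tensor b = torch::randn({2});
// Each tensor carries a DispatchKeySet (typically including CPU and AutogradCPU
// for dense CPU tensors); the union over all arguments is what call() extracts.
c10::DispatchKeySet ks = a.key_set() | b.key_set();
std::cout << c10::toString(ks.highestPriorityTypeId()) << std::endl;  // e.g. "AutogradCPU"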
KernelFunction finally calls the backend kernel function:
template<class Return, class... Args>
C10_ALWAYS_INLINE Return KernelFunction::call(const OperatorHandle& opHandle, DispatchKeySet dispatchKeySet, Args... args) const {
// note: Args above is intentionally not Args&&. We don't want perfect
// forwarding, which would require Args to be deduced, but instead we
// want callers to explicitly specify the Args.
if (C10_LIKELY(unboxed_kernel_func_ != nullptr)) {
return callUnboxedKernelFunction<Return, Args...>(unboxed_kernel_func_, functor_.get(), dispatchKeySet, std::forward<Args>(args)...);
}
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
boxed_kernel_func_ != nullptr,
"Tried to call KernelFunction::call() on an uninitialized KernelFunction."
);
return impl::BoxedKernelWrapper<Return(Args...)>::call(
boxed_kernel_func_,
functor_.get(),
opHandle,
dispatchKeySet,
std::forward<Args>(args)...
);
}
autograd
Autograd is currently among the highest-priority dispatch keys, and most operators need autograd handling, so the first backend kernel function an operator call enters is usually its autograd function. Again taking relu as an example, its autograd function looks like this:
at::Tensor relu(c10::DispatchKeySet ks, const at::Tensor & self) {
auto& self_ = unpack(self, "self", 0);
auto _any_requires_grad = compute_requires_grad( self );
(void)_any_requires_grad;
std::shared_ptr<ReluBackward0> grad_fn;
if (_any_requires_grad) {
grad_fn = std::shared_ptr<ReluBackward0>(new ReluBackward0(), deleteNode);
grad_fn->set_next_edges(collect_next_edges( self ));
}
#ifndef NDEBUG
c10::optional<Storage> self__storage_saved =
self_.has_storage() ? c10::optional<Storage>(self_.storage()) : c10::nullopt;
c10::intrusive_ptr<TensorImpl> self__impl_saved;
if (self_.defined()) self__impl_saved = self_.getIntrusivePtr();
#endif
auto _tmp = ([&]() {
at::AutoDispatchBelowADInplaceOrView guard;
return at::redispatch::relu(ks & c10::after_autograd_keyset, self_);
})();
auto result = std::move(_tmp);
#ifndef NDEBUG
if (self__storage_saved.has_value())
AT_ASSERT(self__storage_saved.value().is_alias_of(self_.storage()));
if (self__impl_saved) AT_ASSERT(self__impl_saved == self_.getIntrusivePtr());
if (result.has_storage()) AT_ASSERT(result.storage().use_count() == 1, "function: relu");
AT_ASSERT(result.use_count() <= 1, "function: relu");
#endif
if (grad_fn) {
set_history(flatten_tensor_args( result ), grad_fn);
}
throw_error_for_complex_autograd(result, "relu");
TORCH_CHECK_NOT_IMPLEMENTED(!(isFwGradDefined(self)), "Trying to use forward AD with relu that does not support it.");
if (grad_fn) {
grad_fn->result_ = SavedVariable(result, true);
}
return result;
}
After the autograd function has done its bookkeeping, it calls at::redispatch::relu to dispatch again. Now that autograd has been handled, the dispatch key set is updated (the autograd keys are masked out via after_autograd_keyset), and the call chain ends up executing the backend kernel that performs the actual relu forward computation:
template<class Return, class... Args>
inline Return Dispatcher::redispatch(const TypedOperatorHandle<Return (Args...)>& op, DispatchKeySet currentDispatchKeySet, Args... args) const {
detail::unused_arg_(args...); // workaround for a false-positive warning about unused parameters in gcc 5
// do not use RecordFunction on redispatch
const KernelFunction& kernel = op.operatorDef_->op.lookup(currentDispatchKeySet.highestPriorityTypeId());
return kernel.template call<Return, Args...>(op, currentDispatchKeySet, std::forward<Args>(args)...);
}
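To tie the pieces together, here is a hedged end-to-end illustration (assuming a small program built against libtorch): the relu call first hits the Autograd kernel shown above, which records ReluBackward0 and redispatches to the CPU backend kernel; the later backward() call runs the recorded node.
#include <torch/torch.h>
#include <iostream>

int main() {
  auto x = torch::randn({3}, torch::requires_grad());
  auto y = torch::relu(x);   // Autograd kernel -> redispatch -> CPU kernel
  y.sum().backward();        // runs the ReluBackward0 node recorded in the forward pass
  std::cout << x.grad() << std::endl;
}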