PyTorch Dispatch

Introduction

The dispatcher is an internal component of PyTorch responsible for figuring out which code should actually run when a function such as torch::add is called. PyTorch operators have to handle many cross-cutting concerns, which are layered on top of one another. Some examples:

  1. Routing between the CPU and CUDA implementations of an operator, depending on the device of the input tensors (see the sketch after this list)
  2. Routing between the autograd and backend implementations of an operator, depending on whether autograd handling is required
  3. Whether autocast is needed for mixed precision
  4. Whether to apply batching rules to an operator running under a vmap call
  5. Whether to trace the execution of the operator
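
Concern 1, for example, is directly visible from ordinary user code: the same call routes to different kernels purely based on where its inputs live. A minimal libtorch sketch:

#include <torch/torch.h>

int main() {
  torch::Tensor a = torch::ones({2, 2});
  torch::Tensor b = torch::ones({2, 2});
  torch::Tensor c = torch::add(a, b);  // CPU inputs -> dispatched to the CPU kernel

  if (torch::cuda::is_available()) {
    torch::Tensor d = torch::add(a.cuda(), b.cuda());  // CUDA inputs -> CUDA kernel
  }
  return 0;
}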

PyTorch uses DispatchKey to represent these different concerns:

enum class DispatchKey : uint8_t {
  Undefined = 0,
  CatchAll = Undefined,
  CPU, // registered at build/aten/src/ATen/RegisterCPU.cpp
  CUDA, // registered at build/aten/src/ATen/RegisterCUDA.cpp
  HIP, // NB: I think this is not actually used, due to Note [Masquerading as
  // CUDA]
  FPGA, // Xilinx support lives out of tree at
  ORT,
  XLA, // lives out of tree at https://github.com/pytorch/xla
  MLC, // lives out of tree at https://github.com/pytorch/MLCompute
  Vulkan,
  Metal,
  XPU, // For out of tree Intel's heterogeneous computing plug-in
  HPU, // For out of tree & closed source integration of HPU / Habana
  VE, // For out of tree & closed source integration of SX-Aurora / NEC
  Lazy, // For lazy tensor backends
  SWAI, // For out of tree SWAI backend
  Meta,
  QuantizedCPU, // registered at build/aten/src/ATen/RegisterQuantizedCPU.cpp
  QuantizedCUDA, // registered at build/aten/src/ATen/RegisterQuantizedCUDA.cpp
  QuantizedXPU, // For out of tree Intel's heterogeneous computing plug-in
  CustomRNGKeyId,
  ...
};
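
Every tensor carries a DispatchKeySet recording which of these keys apply to it, and the highest-priority key in the set decides where a call goes first. A quick way to inspect this (a small sketch using libtorch):

#include <torch/torch.h>
#include <iostream>

int main() {
  torch::Tensor t = torch::ones({2}, torch::requires_grad());
  // Prints the tensor's dispatch key set, e.g. one containing CPU and AutogradCPU.
  std::cout << t.key_set() << std::endl;
  return 0;
}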

Overall, the dispatcher answers a single question: which kernel should be called.
Naively, as in the example below, the different cases could be handled with if statements:

class MyAddFunction : ... {
public:
  static Tensor forward(
    AutogradContext *ctx, torch::Tensor self, torch::Tensor other) {

    if (self.device().type() == DeviceType::CPU) {
      return add_cpu(self, other);
    } else if (self.device().type() == DeviceType::CUDA) {
      return add_cuda(self, other);
    } else {
      TORCH_CHECK(0, "Unsupported device ", self.device().type());
    }
  }
  ...
}

So why use a dispatcher?

  1. It is decentralized: there is no need to write a separate if statement for every new operator. Furthermore, when a third party implements an operator for a new case (e.g., a new device type), they do not have to patch the operator's original implementation.
  2. It supports more concerns than just CPU, CUDA, and Autograd; c10/core/DispatchKey.h already defines a whole series of dispatch keys.
  3. It supports boxed fallback functions, which are implemented once and apply to all operators, as sketched below.
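
As an illustration of point 3, a boxed fallback is written once against the boxed calling convention, where arguments arrive as IValues on a stack, and is registered for an entire dispatch key rather than per operator. A minimal sketch (the PrivateUse1 key and the function body are purely illustrative):

#include <iostream>
#include <torch/library.h>
#include <c10/core/impl/LocalDispatchKeySet.h>

// Runs for every operator that hits this dispatch key.
void my_logging_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  std::cout << "called: " << op.schema().name() << std::endl;
  // Exclude this key and re-enter dispatch so the real kernel still runs.
  c10::impl::ExcludeDispatchKeyGuard guard(c10::DispatchKey::PrivateUse1);
  op.callBoxed(stack);
}

TORCH_LIBRARY_IMPL(_, PrivateUse1, m) {
  m.fallback(torch::CppFunction::makeFromBoxedFunction<&my_logging_fallback>());
}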

The dispatch mechanism

Concepts

First, some definitions:
operator: an operator, e.g. add
kernels: kernel functions, i.e. the different implementations of an operator for different devices (CPU, CUDA), different input layouts (dense, sparse), and with or without autograd

Approach

The dispatch mechanism replaces if-based branching with table lookups (a hash map at the first level, an array at the second); the Dispatcher controls dispatching for all operators.
First-level dispatch maps the operator name to an OperatorHandle (every operator is handled by one OperatorHandle).
Second-level dispatch maps the dispatch key to a kernel function (each device, input layout, etc. corresponds to a dispatch key), as modeled in the sketch below.
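
As a toy model of this two-level lookup (illustrative only, not the real PyTorch types):

#include <array>
#include <cstdint>
#include <functional>
#include <stdexcept>
#include <string>
#include <unordered_map>

using Kernel = std::function<void()>;   // stand-in for KernelFunction
constexpr std::size_t kNumKeys = 64;    // stand-in for DispatchKey::NumDispatchKeys

// Second level: dispatch key -> kernel, an array indexed by the key's integer value.
struct ToyOperatorEntry {
  std::array<Kernel, kNumKeys> dispatchTable;
};

// First level: operator name -> entry, a hash map.
struct ToyDispatcher {
  std::unordered_map<std::string, ToyOperatorEntry> operators;
  void call(const std::string& name, std::uint8_t key) {
    const Kernel& k = operators.at(name).dispatchTable.at(key);
    if (!k) throw std::runtime_error("no kernel registered for this key");
    k();  // two lookups, then invoke
  }
};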

Implementation

The main dispatch code lives under aten/src/ATen/core/dispatch.
The main classes are:
● Dispatcher: handles the mapping from operator name to OperatorHandle
● OperatorHandle: registers, looks up, and invokes the kernels of a concrete operator
● OperatorEntry: handles the mapping from DispatchKey to KernelFunction
● KernelFunction: wraps a backend kernel

The Dispatcher class

The operatorLookupTable_ table stores the mapping from operator name to OperatorHandle (wrapped in a LeftRight structure so lookups can proceed concurrently with registrations):

class TORCH_API Dispatcher final{
    std::list<OperatorDef> operators_;
    LeftRight<ska::flat_hash_map<OperatorName, OperatorHandle>> operatorLookupTable_;
};

The following methods all look up an OperatorHandle by operator name; see the source for the concrete implementations:

c10::optional<OperatorHandle> findSchema(const OperatorName& operator_name);
OperatorHandle findSchemaOrThrow(const char* name, const char* overload_name);
c10::optional<OperatorHandle> findOp(const OperatorName& operator_name);
const std::vector<OperatorName> getAllOpNames();
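
A small usage sketch of these lookup methods:

#include <ATen/core/dispatch/Dispatcher.h>

void lookup_demo() {
  auto& dispatcher = c10::Dispatcher::singleton();
  // Throws if aten::relu (with an empty overload name) was never registered.
  c10::OperatorHandle op = dispatcher.findSchemaOrThrow("aten::relu", "");
  // findSchema returns an optional instead of throwing.
  c10::optional<c10::OperatorHandle> maybe_op =
      dispatcher.findSchema({"aten::relu", ""});
}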

The OperatorHandle class

OperatorHandle holds a pointer to an OperatorDef, which in turn contains an OperatorEntry; the actual key-to-kernel mapping is handled by OperatorEntry:

class TORCH_API OperatorHandle {
    Dispatcher::OperatorDef* operatorDef_;
    void callBoxed(Stack* stack) const {
        c10::Dispatcher::singleton().callBoxed(*this, stack);
    }

    void callBoxed(Stack& stack) const {
        callBoxed(&stack);
    }

    void redispatchBoxed(DispatchKeySet ks, Stack* stack) const {
        c10::Dispatcher::singleton().redispatchBoxed(*this, ks, stack);
    }
};

The OperatorEntry class

Like the Dispatcher, OperatorEntry also has a lookup table, which stores the mapping from dispatch key to KernelFunction. Since a dispatch key is a uint8_t enum value, the table here is implemented as an array:

class TORCH_API OperatorEntry final {
    std::array<KernelFunction, static_cast<uint8_t>(DispatchKey::NumDispatchKeys)> dispatchTable_;
};

Looking up a dispatch key returns the corresponding kernel function:

const KernelFunction& lookup(DispatchKey k) const {
    const auto& kernel = dispatchTable_[static_cast<uint8_t>(k)];
    // A valid kernel *always* has a boxed kernel and *may* have an
    // unboxed kernel. However, we typically do unboxed calls in at::
    // APIs, where the kernel 1) will very likely be valid and 2)
    // should have an unboxed kernel. Checking the unboxed kernel
    // first will allow us to avoid touching the boxed kernel at all
    // in the common case.
    if (C10_UNLIKELY(!kernel.isValidUnboxed())) {
      if (!kernel.isValid()) {
        reportError(k);
      }
    }
    return kernel;
  }

The KernelFunction class

KernelFunction wraps the backend kernel, holding both a boxed kernel (boxed_kernel_func_) and an unboxed kernel (unboxed_kernel_func_); functor_ points to the backend kernel function:

class TORCH_API KernelFunction final {
  OperatorKernel* getFunctor_() const;
  std::shared_ptr<OperatorKernel> functor_;
  InternalBoxedKernelFunction* boxed_kernel_func_;
  void* unboxed_kernel_func_;
};
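
To make the boxed/unboxed distinction concrete, a hypothetical sketch (these are not PyTorch's actual kernels):

#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/stack.h>

// Unboxed: an ordinary typed C++ function. unboxed_kernel_func_ ultimately
// points at something like this through a type-erased void*.
at::Tensor my_add_unboxed(const at::Tensor& a, const at::Tensor& b) {
  return a + b;
}

// Boxed: one fixed signature for all operators; arguments are type-erased
// IValues on a stack. boxed_kernel_func_ follows this convention.
void my_add_boxed(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  at::Tensor b = torch::jit::pop(*stack).toTensor();  // popped in reverse order
  at::Tensor a = torch::jit::pop(*stack).toTensor();
  torch::jit::push(*stack, my_add_unboxed(a, b));
}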

call ultimately invokes the concrete implementation behind functor_:

template<class Return, class... Args>
C10_ALWAYS_INLINE Return KernelFunction::call(const OperatorHandle& opHandle, DispatchKeySet dispatchKeySet, Args... args) const {
    // note: Args above is intentionally not Args&&. We don't want perfect
    // forwarding, which would require Args to be deduced, but instead we
    // want callers to explicitly specify the Args.

    if (C10_LIKELY(unboxed_kernel_func_ != nullptr)) {
        return callUnboxedKernelFunction<Return, Args...>(unboxed_kernel_func_, functor_.get(), dispatchKeySet, std::forward<Args>(args)...);
    }

    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        boxed_kernel_func_ != nullptr,
        "Tried to call KernelFunction::call() on an uninitialized KernelFunction."
    );

    return impl::BoxedKernelWrapper<Return(Args...)>::call(
        boxed_kernel_func_,
        functor_.get(),
        opHandle,
        dispatchKeySet,
        std::forward<Args>(args)...
    );
}

The registration mechanism

Registering an operator

An example of registering an operator:

TORCH_LIBRARY(myops, m) {
  m.def("myadd(Tensor self, Tensor other) -> Tensor");
}

The macro is defined in torch/library.h; tracing through the code, the concrete implementation lives in aten/src/ATen/core/library.cpp:

#define DEF_PRELUDE "def(\"", schema.operator_name(), "\"): "
Library& Library::_def(c10::FunctionSchema&& schema, c10::OperatorName* out_name) & {
  TORCH_CHECK(kind_ == DEF || kind_ == FRAGMENT,
    DEF_PRELUDE,
    "Cannot define an operator inside of a ", toString(kind_), " block.  "
    "All def()s should be placed in the (unique) TORCH_LIBRARY block for their namespace.  ",
    ERROR_CONTEXT
  );
  TORCH_INTERNAL_ASSERT(ns_.has_value(), ERROR_CONTEXT);
  TORCH_INTERNAL_ASSERT(!dispatch_key_.has_value(), ERROR_CONTEXT);
  auto ns_opt = schema.getNamespace();
  if (ns_opt.has_value()) {
    // Note [Redundancy in registration code is OK]
    // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    // In an earlier version of this code, I made it an error to explicitly
    // specify the namespace, even when the namespaces match.  I've decided
    // to relax this constraint because sometimes we code generate registrations
    // and you cannot conveniently tell what the enclosing context will be;
    // in these cases, it is simpler (and less error prone) to place all
    // of the information in the registration site, which will be cross-checked
    // in the end in any case (and if it turns out you DON'T have the right
    // information at the site, as is the case with backend specific
    // per-op registrations, you will get the right behavior!)
    TORCH_CHECK(*ns_opt == *ns_,
      "Explicitly provided namespace (", *ns_opt, ") in schema string "
      "does not match namespace of enclosing ", toString(kind_), " block (", *ns_, ").  "
      "Move this definition to the (unique) TORCH_LIBRARY block corresponding to this namespace "
      "(and consider deleting the namespace from your schema string.)  ",
      ERROR_CONTEXT
    );
  } else {
    bool b = schema.setNamespaceIfNotSet(ns_->c_str());
    TORCH_INTERNAL_ASSERT(b, ERROR_CONTEXT);
  }
  if (out_name) {
    *out_name = schema.operator_name(); // copy!
  }
  registrars_.emplace_back(
    c10::Dispatcher::singleton().registerDef(
      std::move(schema),
      debugString("", file_, line_)
    )
  );
  return *this;
}

This ends up calling Dispatcher::registerDef, which records the mapping from operator name to operator, i.e., performs the insertion into the lookup table:

RegistrationHandleRAII Dispatcher::registerDef(FunctionSchema schema, std::string debug) {
  // we need a lock to avoid concurrent writes
  std::lock_guard<std::mutex> lock(mutex_);

  OperatorName op_name = schema.operator_name();
  auto op = findOrRegisterName_(op_name);

  TORCH_CHECK(op.operatorDef_->def_count == 0, "Tried to register an operator (", schema, ") with the same name and overload name multiple times.",
                                                    " Each overload's schema should only be registered with a single call to def().",
                                                    " Duplicate registration: ", debug, ". Original registration: ", op.operatorDef_->op.debug());
  op.operatorDef_->op.registerSchema(std::move(schema), std::move(debug));
  listeners_->callOnOperatorRegistered(op);

  // NB: do not increment the counts until AFTER error checking
  ++op.operatorDef_->def_count;
  ++op.operatorDef_->def_and_impl_count;

  return RegistrationHandleRAII([this, op, op_name] {
    deregisterDef_(op, op_name);
  });
}

Registering a kernel

The following code registers the CPU kernel of the myadd operator with the dispatcher:

TORCH_LIBRARY_IMPL(myops, CPU, m) {
  m.impl("myadd", myadd_cpu);
}
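
Here myadd_cpu is an ordinary C++ function; a minimal version (adapted from the PyTorch dispatcher tutorial, assuming float tensors) could look like:

#include <torch/torch.h>

at::Tensor myadd_cpu(const at::Tensor& self_, const at::Tensor& other_) {
  TORCH_CHECK(self_.sizes() == other_.sizes());
  TORCH_INTERNAL_ASSERT(self_.device().type() == at::DeviceType::CPU);
  TORCH_INTERNAL_ASSERT(other_.device().type() == at::DeviceType::CPU);
  at::Tensor self = self_.contiguous();
  at::Tensor other = other_.contiguous();
  at::Tensor result = torch::empty(self.sizes(), self.options());
  const float* self_ptr = self.data_ptr<float>();
  const float* other_ptr = other.data_ptr<float>();
  float* result_ptr = result.data_ptr<float>();
  // Plain elementwise loop over the contiguous data.
  for (int64_t i = 0; i < result.numel(); i++) {
    result_ptr[i] = self_ptr[i] + other_ptr[i];
  }
  return result;
}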

As with operator registration, tracing through the kernel-registration macro leads to this implementation:

#define IMPL_PRELUDE "impl(\"", name_str, "\", ...): "
Library& Library::_impl(const char* name_str, CppFunction&& f) & {
  auto name = torch::jit::parseName(name_str);
  auto ns_opt = name.getNamespace();
  // This is kind of similar to the checking in def(), but the error
  // messages are a little different for this call site
  if (ns_opt.has_value()) {
    // See Note [Redundancy in registration code is OK]
    TORCH_CHECK(*ns_opt == *ns_,
      IMPL_PRELUDE,
      "Explicitly provided namespace (", *ns_opt, ") in operator name "
      "does not match namespace of enclosing ", toString(kind_), " block (", *ns_, ").  "
      "Move this definition to the ", toString(kind_), " block corresponding to this namespace "
      "(and consider deleting the namespace from your schema string.)  ",
      ERROR_CONTEXT
    );
  } else {
    bool b = name.setNamespaceIfNotSet(ns_->c_str());
    TORCH_INTERNAL_ASSERT(b, ERROR_CONTEXT);
  }
  // See Note [Redundancy in registration code is OK]
  TORCH_CHECK(!(f.dispatch_key_.has_value() &&
                dispatch_key_.has_value() &&
                *f.dispatch_key_ != *dispatch_key_),
    IMPL_PRELUDE,
    "Explicitly provided dispatch key (", *f.dispatch_key_, ") is inconsistent "
    "with the dispatch key of the enclosing ", toString(kind_), " block (", *dispatch_key_, ").  "
    "Please declare a separate ", toString(kind_), " block for this dispatch key and "
    "move your impl() there.  "
    ERROR_CONTEXT
  );
  auto dispatch_key = f.dispatch_key_.has_value() ? f.dispatch_key_ : dispatch_key_;
  registrars_.emplace_back(
    c10::Dispatcher::singleton().registerImpl(
      std::move(name),
      dispatch_key,
      std::move(f.func_),
      // NOLINTNEXTLINE(performance-move-const-arg)
      std::move(f.cpp_signature_),
      std::move(f.schema_),
      debugString(std::move(f.debug_), file_, line_)
    )
  );
  return *this;
}

This calls Dispatcher::registerImpl:

RegistrationHandleRAII Dispatcher::registerImpl(
  OperatorName op_name,
  c10::optional<DispatchKey> dispatch_key,
  KernelFunction kernel,
  c10::optional<impl::CppSignature> cpp_signature,
  std::unique_ptr<FunctionSchema> inferred_function_schema,
  std::string debug
) {
  std::lock_guard<std::mutex> lock(mutex_);

  auto op = findOrRegisterName_(op_name);

  auto handle = op.operatorDef_->op.registerKernel(
    *this,
    dispatch_key,
    std::move(kernel),
    // NOLINTNEXTLINE(performance-move-const-arg)
    std::move(cpp_signature),
    std::move(inferred_function_schema),
    std::move(debug)
  );

  ++op.operatorDef_->def_and_impl_count;

  return RegistrationHandleRAII([this, op, op_name, dispatch_key, handle] {
    deregisterImpl_(op, op_name, dispatch_key, handle);
  });
}

registerImpl finds the corresponding OperatorHandle and calls OperatorEntry::registerKernel.
The key step in registerKernel is the dispatch-table update at the end, via updateDispatchTable_ / updateDispatchTableFull_:

OperatorEntry::AnnotatedKernelContainerIterator OperatorEntry::registerKernel(
  const c10::Dispatcher& dispatcher,
  c10::optional<DispatchKey> dispatch_key,
  KernelFunction kernel,
  c10::optional<CppSignature> cpp_signature,
  std::unique_ptr<FunctionSchema> inferred_function_schema,
  std::string debug
) {
  // NB: cpp_signature doesn't get cleared even after the kernel that populated
  // it is deleted.  This means you could poison the value of cpp_signature_
  // with a bad signature value, and then it would permanently stay there until
  // you deregister the schema.  This can't really be fixed, because we
  // only do a typed() test once in the lifetime of a TypedOperatorHandle,
  // which means if you could validly change the type of a cpp_signature, then
  // that would also invalidate the old TypedOperatorHandles.
  if (cpp_signature.has_value()) {
    if (cpp_signature_.has_value()) {
      TORCH_CHECK(*cpp_signature == cpp_signature_->signature,
        "\nMismatch in kernel C++ signatures\n",
        "  operator: ", (this->schema_.has_value() ? toString(this->schema_->schema) : toString(name_)), "\n",
        "    ", (this->schema_.has_value() ? this->schema_->debug : "no debug info"), "\n",
        "  kernel 1: ", cpp_signature_->signature.name(), "\n",
        "    dispatch key: ", toString(cpp_signature_->dispatch_key), "\n",
        "    ", cpp_signature_->debug, "\n",
        "  kernel 2: ", cpp_signature->name(), "\n",
        "    dispatch key: ", toString(dispatch_key), "\n",
        "    ", debug, "\n"
      );
    } else {
      cpp_signature_ = CppSignatureWithDebug { *cpp_signature, debug, dispatch_key };
    }
  }

  if (schema_ && inferred_function_schema) {
    checkSchema(name_, schema_->schema, schema_->debug, *inferred_function_schema, debug);
  }

  // Add the kernel to the kernels list,
  // possibly creating the list if this is the first kernel.
  // Redirect catchAll registrations to CompositeImplicitAutograd.
  auto& k = dispatch_key.has_value() ? kernels_[*dispatch_key] : kernels_[DispatchKey::CompositeImplicitAutograd];

#ifdef C10_DISPATCHER_ONE_KERNEL_PER_DISPATCH_KEY
  if (k[0].kernel.isValid()) {
#else
  if (k.size() > 0) {
#endif
    TORCH_WARN("Overriding a previously registered kernel for the same operator and the same dispatch key\n",
               "  operator: ", (schema_.has_value() ? toString(schema_->schema) : toString(name_)), "\n",
               "    ", (this->schema_.has_value() ? this->schema_->debug : "no debug info"), "\n",
               "  dispatch key: ", toString(dispatch_key), "\n",
               "  previous kernel: ", (cpp_signature_.has_value() ? cpp_signature_->debug : "no debug info"), "\n",
               "       new kernel: ", debug
    );
  }

#ifdef C10_DISPATCHER_ONE_KERNEL_PER_DISPATCH_KEY
  k[0].kernel = std::move(kernel);
  k[0].inferred_function_schema = std::move(inferred_function_schema);
  k[0].debug = std::move(debug);
#else
  k.emplace_front(std::move(kernel), std::move(inferred_function_schema), std::move(debug));
#endif
  AnnotatedKernelContainerIterator inserted = k.begin();
  // update the dispatch table, i.e. re-establish the invariant
  // that the dispatch table points to the newest kernel
  if (dispatch_key.has_value()) {
    updateDispatchTable_(dispatcher, *dispatch_key);
  } else {
    updateDispatchTableFull_(dispatcher);
  }
  return inserted;
}

The dispatch call path

build/aten/src/ATen/Functions.h contains the entry functions for the operators.
Taking torch::relu as an example, here is the path from the entry function to the final backend kernel function:

TORCH_API inline at::Tensor relu(const at::Tensor & self) {
    return at::_ops::relu::call(self);
}

build/aten/src/ATen/Operators_4.cpp finds the TypedOperatorHandle by operator name and then invokes its call method:

STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA(relu, name, "aten::relu")
STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA(relu, overload_name, "")
STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA(relu, schema_str, "relu(Tensor self) -> Tensor")
// aten::relu(Tensor self) -> Tensor
static C10_NOINLINE c10::TypedOperatorHandle<relu::schema> create_relu_typed_handle() {
  return c10::Dispatcher::singleton()
      .findSchemaOrThrow(relu::name, relu::overload_name)
      .typed<relu::schema>();
}
// aten::relu(Tensor self) -> Tensor
at::Tensor relu::call(const at::Tensor & self) {
    static auto op = create_relu_typed_handle();
    return op.call(self);
}
// aten::relu(Tensor self) -> Tensor
at::Tensor relu::redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
    static auto op = create_relu_typed_handle();
    return op.redispatch(dispatchKeySet, self);
}

TypedOperatorHandle::call invokes Dispatcher::call:

C10_ALWAYS_INLINE Return call(Args... args) const {
    return c10::Dispatcher::singleton().call<Return, Args...>(*this, std::forward<Args>(args)...);
}

Dispatcher::call looks up the KernelFunction via the dispatch key, then invokes KernelFunction::call:

template<class Return, class... Args>
C10_DISPATCHER_INLINE_UNLESS_MOBILE Return Dispatcher::call(const TypedOperatorHandle<Return(Args...)>& op, Args... args) const {
  detail::unused_arg_(args...);  // workaround for a false-positive warning about unused parameters in gcc 5
  auto dispatchKeySet = op.operatorDef_->op.dispatchKeyExtractor()
    .template getDispatchKeySetUnboxed<Args...>(args...);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!c10::isAliasDispatchKey(dispatchKeySet.highestPriorityTypeId()));
  const KernelFunction& kernel = op.operatorDef_->op.lookup(dispatchKeySet.highestPriorityTypeId());
#ifndef PYTORCH_DISABLE_PER_OP_PROFILING
  // By default, when there're no high-frequency or non-sampled callbacks,
  // RecordFunction is pre-sampled as a perf optimization;
  // shouldRunRecordFunction checks whether RecordFunction should be executed,
  // and sets pre_sampled boolean argument value to whether pre-sampling was used -
  // this boolean is passed into RecordFunction to adjust the sampling rates of
  // the callbacks
  bool pre_sampled = false;
  if (C10_UNLIKELY(at::shouldRunRecordFunction(&pre_sampled))) {
    return callWithDispatchKeySlowPath<Return, Args...>(op, pre_sampled, dispatchKeySet, kernel, std::forward<Args>(args)...);
  }
#endif  // PYTORCH_DISABLE_PER_OP_PROFILING
  return kernel.template call<Return, Args...>(op, dispatchKeySet, std::forward<Args>(args)...);
}
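
The dispatch key set passed to lookup is extracted from the arguments. Conceptually (a simplification of what getDispatchKeySetUnboxed does), it is the union of the key sets of all tensor arguments, adjusted by the thread-local included/excluded key sets, with the highest-priority key deciding which kernel runs:

#include <ATen/ATen.h>

// Simplified illustration of key extraction for a binary op; the real
// extractor also applies the thread-local include/exclude key sets.
c10::DispatchKeySet extract_keys(const at::Tensor& a, const at::Tensor& b) {
  return a.key_set() | b.key_set();
}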

KernelFunction::call finally invokes the backend kernel function:

template<class Return, class... Args>
C10_ALWAYS_INLINE Return KernelFunction::call(const OperatorHandle& opHandle, DispatchKeySet dispatchKeySet, Args... args) const {
    // note: Args above is intentionally not Args&&. We don't want perfect
    // forwarding, which would require Args to be deduced, but instead we
    // want callers to explicitly specify the Args.

    if (C10_LIKELY(unboxed_kernel_func_ != nullptr)) {
        return callUnboxedKernelFunction<Return, Args...>(unboxed_kernel_func_, functor_.get(), dispatchKeySet, std::forward<Args>(args)...);
    }

    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        boxed_kernel_func_ != nullptr,
        "Tried to call KernelFunction::call() on an uninitialized KernelFunction."
    );

    return impl::BoxedKernelWrapper<Return(Args...)>::call(
        boxed_kernel_func_,
        functor_.get(),
        opHandle,
        dispatchKeySet,
        std::forward<Args>(args)...
    );
}

Autograd

Autograd currently has the highest priority among the dispatch keys, and most operators have an autograd step, so the first kernel entered by an operator call is usually its autograd function. Taking relu as an example again, its autograd function looks like this:

at::Tensor relu(c10::DispatchKeySet ks, const at::Tensor & self) {
  auto& self_ = unpack(self, "self", 0);
  auto _any_requires_grad = compute_requires_grad( self );
  
  (void)_any_requires_grad;
  std::shared_ptr<ReluBackward0> grad_fn;
  if (_any_requires_grad) {
    grad_fn = std::shared_ptr<ReluBackward0>(new ReluBackward0(), deleteNode);
    grad_fn->set_next_edges(collect_next_edges( self ));
  }
  #ifndef NDEBUG
  c10::optional<Storage> self__storage_saved =
    self_.has_storage() ? c10::optional<Storage>(self_.storage()) : c10::nullopt;
  c10::intrusive_ptr<TensorImpl> self__impl_saved;
  if (self_.defined()) self__impl_saved = self_.getIntrusivePtr();
  #endif
  auto _tmp = ([&]() {
    at::AutoDispatchBelowADInplaceOrView guard;
    return at::redispatch::relu(ks & c10::after_autograd_keyset, self_);
  })();
  auto result = std::move(_tmp);
  #ifndef NDEBUG
  if (self__storage_saved.has_value())
    AT_ASSERT(self__storage_saved.value().is_alias_of(self_.storage()));
  if (self__impl_saved) AT_ASSERT(self__impl_saved == self_.getIntrusivePtr());
  if (result.has_storage()) AT_ASSERT(result.storage().use_count() == 1, "function: relu");
  AT_ASSERT(result.use_count() <= 1, "function: relu");
  #endif
  if (grad_fn) {
      set_history(flatten_tensor_args( result ), grad_fn);
  }
  throw_error_for_complex_autograd(result, "relu");
  TORCH_CHECK_NOT_IMPLEMENTED(!(isFwGradDefined(self)), "Trying to use forward AD with relu that does not support it.");
  if (grad_fn) {
    grad_fn->result_ = SavedVariable(result, true);
  }
  return result;
}

After the autograd function has done its bookkeeping, it calls at::redispatch::relu to dispatch again. Since autograd has now been handled, the dispatch key set is updated (the autograd keys are masked out via ks & c10::after_autograd_keyset), and the call chain then executes the backend kernel function for relu's forward computation:

template<class Return, class... Args>
inline Return Dispatcher::redispatch(const TypedOperatorHandle<Return (Args...)>& op, DispatchKeySet currentDispatchKeySet, Args... args) const {
  detail::unused_arg_(args...);  // workaround for a false-positive warning about unused parameters in gcc 5
  // do not use RecordFunction on redispatch
  const KernelFunction& kernel = op.operatorDef_->op.lookup(currentDispatchKeySet.highestPriorityTypeId());
  return kernel.template call<Return, Args...>(op, currentDispatchKeySet, std::forward<Args>(args)...);
}