Pytorch底層源碼解讀(二)libtorch源碼淺析

前言

上文我們通過針對性的閱讀pytorch源碼的框架結(jié)構(gòu)猾瘸,同時以__init__.py文件為線索探索了pytorch中多個主要類型的實現(xiàn)和功能。在上文中我們曾提到丢习,pytorch框架是一個以python為前端牵触,C++為后端的框架。如果去除掉pytorch中的python前端咐低,那我們就可以得到一個C++的AI框架——libtorch揽思。沒有特殊說明時,本文所看源碼均是取自libtorch而非pytorch见擦,其中libtorch的源碼版本是libtorch-win-shared-with-deps-2.1.0+cpu钉汗。

一個引子

我們來看看如下代碼:

torch::Tensor x = torch::tensor({1.0});

是否覺得很眼熟,其實這就是當(dāng)我們在python中調(diào)用torch.tensor()這個函數(shù)時鲤屡,其在C++后端實際上調(diào)用的函數(shù)损痰,從這一個例子中我們可以知道,要想真正理解pytorch的底層原理酒来,就需要深入的理解libtorch的源碼卢未。

目錄結(jié)構(gòu)

和上文我們解析pytorch源碼一樣,現(xiàn)在我們來看一下libtorch的源碼結(jié)構(gòu)堰汉,值得注意的是辽社,和pytorch一樣,libtorch也分為cpu版本和gpu版本兩個版本翘鸭,因此其目錄結(jié)構(gòu)稍有不同滴铅。我們主要看cpu版本的include目錄以及lib目錄,lib目錄下存放的是靜態(tài)鏈接文件就乓,而include下存放的是頭文件失息,我們在使用libtorch的時候,實際上是引入頭文件档址,最后在編譯過程中,由編譯器根據(jù)頭文件去找到路徑下的動態(tài)鏈接文件邻梆,并在鏈接期間將其鏈接進(jìn)來守伸。

圖片1.png

include目錄下則是這個樣子,可以很容易地發(fā)現(xiàn)浦妄,這和pytorch的源碼結(jié)構(gòu)是幾乎一樣的尼摹。
圖片2.png

torch.h解讀

torch.h源碼很短见芹,只有如下幾行,從中我們可以看到它其實是引入了all.h這個頭文件蠢涝,另外還引入了一個extension.h文件玄呛。這里這個extension.h文件其實是引入了python.h,當(dāng)然和二,由于這里討論的是無python前端的libtorch版本徘铝,因此不需要注意這個。

#pragma once

#include <torch/all.h>

#ifdef TORCH_API_INCLUDE_EXTENSION_H
#include <torch/extension.h>

#endif // defined(TORCH_API_INCLUDE_EXTENSION_H)

在進(jìn)入all.h后惯吕,可以看到如下代碼惕它,很明顯,這個頭文件就是將所有模塊下的頭文件引入废登,現(xiàn)階段我們要關(guān)注的只有linalg.h這一個文件淹魄。

#pragma once

#if !defined(_MSC_VER) && __cplusplus < 201703L
#error C++17 or later compatible compiler is required to use PyTorch.
#endif

#include <torch/autograd.h>
#include <torch/cuda.h>
#include <torch/data.h>
#include <torch/enum.h>
#include <torch/fft.h>
#include <torch/jit.h>
#include <torch/linalg.h>
#include <torch/mps.h>
#include <torch/nested.h>
#include <torch/nn.h>
#include <torch/optim.h>
#include <torch/serialize.h>
#include <torch/sparse.h>
#include <torch/special.h>
#include <torch/types.h>
#include <torch/utils.h>
#include <torch/version.h>

torch::Tensor淺析

進(jìn)入linalg.h后我們可以看到如下代碼甲锡,這個文件提供了相當(dāng)多的內(nèi)聯(lián)函數(shù)疚俱,主要是一些矩陣運算相關(guān)的函數(shù)。當(dāng)然,這個文件最重要的部分其實是將ATen.h引入了torch.h中,記得上文我們說過ATen目錄是與Tensor類實現(xiàn)相關(guān)的庫蟀架,也就是說苫纤,在這個文件中我們可以找到torch::Tensor的聲明操禀。

#pragma once

#include <ATen/ATen.h>

namespace torch {
namespace linalg {

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {

inline Tensor cholesky(const Tensor& self) {
  return torch::linalg_cholesky(self);
}
...

進(jìn)入到ATen.h文件后,我們看到如下代碼钩杰,其中出現(xiàn)了我們在上一節(jié)提到過的兩個核心概念怎披,TensorStorage,它們被包含在ATen/Tensor.h以及c10/core/Storage.h中菩佑。

#pragma once

#if !defined(_MSC_VER) && __cplusplus < 201703L
#error C++17 or later compatible compiler is required to use ATen.
#endif

#include <ATen/Context.h>
#include <ATen/Device.h>
#include <ATen/DeviceGuard.h>
#include <ATen/DimVector.h>
#include <ATen/Dispatch.h>
#include <ATen/Formatting.h>
#include <ATen/Functions.h>
#include <ATen/NamedTensor.h>
#include <ATen/ScalarOps.h>
#include <ATen/Tensor.h>
#include <ATen/TensorGeometry.h>
#include <ATen/TensorIndexing.h>
#include <ATen/TensorOperators.h>
#include <ATen/Version.h>
#include <ATen/core/ATenGeneral.h>
#include <ATen/core/Generator.h>
#include <ATen/core/Reduction.h>
#include <ATen/core/Scalar.h>
#include <ATen/core/UnsafeFromTH.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/jit_type.h>
#include <c10/core/Allocator.h>
#include <c10/core/InferenceMode.h>
#include <c10/core/Layout.h>
#include <c10/core/Storage.h>
#include <c10/core/TensorOptions.h>
#include <c10/util/Exception.h>

// TODO: try to remove this
// There is some back story, see https://github.com/pytorch/pytorch/issues/48684
#include <ATen/NativeFunctions.h>

進(jìn)入ATen/Tensor.h中咧党,可以看到該文件其實引入了ATen/core/Tensor.h遮糖,因此我們進(jìn)入該文件敬惦,進(jìn)入后我們可以看到如下代碼:

#pragma once

#include <ATen/core/TensorBody.h>
#include <c10/util/Exception.h>

namespace at {
class TORCH_API OptionalTensorRef {
 public:
  OptionalTensorRef() = default;

  ~OptionalTensorRef() {
    ref_.unsafeReleaseTensorImpl();
  }

  OptionalTensorRef(const TensorBase& src)
      : ref_(Tensor::unsafe_borrow_t{}, src) {
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src.defined());
  }

  OptionalTensorRef(const OptionalTensorRef& rhs)
      : ref_(Tensor::unsafe_borrow_t{}, rhs.ref_) {}

  OptionalTensorRef& operator=(OptionalTensorRef rhs) {
    std::swap(ref_, rhs.ref_);
    return *this;
  }

  bool has_value() const {
    return ref_.defined();
  }

  const Tensor& getTensorRef() const & {
    return ref_;
  }

  const Tensor& operator*() const & {
    return ref_;
  }

  const Tensor* operator->() const & {
    return &ref_;
  }

  operator bool() const {
    return ref_.defined();
  }

 private:
  Tensor ref_;
};

// Use to convert a TensorBase (that may be undefined) to an at::Tensor
// without bumping refcount.
class TORCH_API TensorRef {
 public:
  ~TensorRef() {
    ref_.unsafeReleaseTensorImpl();
  }

  TensorRef(const TensorBase& src)
      : ref_(Tensor::unsafe_borrow_t{}, src) {}

  const Tensor& operator*() const & {
    return ref_;
  }
 private:
  Tensor ref_;
};

template <typename T>
auto Tensor::register_hook(T&& hook) const -> Tensor::hook_return_void_t<T> {
  // Return the grad argument in case of a hook with void return type to have an
  // std::function with Tensor return type
  static_assert(std::is_same<decltype(hook(Tensor())), void>::value,
                "Expected hook to return void");
  return _register_hook([fn=std::forward<T>(hook)](const TensorBase& grad_base) {
    TensorRef grad(grad_base);
    fn(*grad);
    return Tensor();
  });
}

template <typename T>
auto Tensor::register_hook(T&& hook) const -> Tensor::hook_return_var_t<T> {
  return _register_hook([fn=std::forward<T>(hook)](const TensorBase& grad_base) {
    TensorRef grad(grad_base);
    Tensor ret = fn(*grad);
    return TensorBase(std::move(ret));
  });
}

} // namespace at

這段代碼其實主要是實現(xiàn)了ATen/core/TensorBody.h中聲明的部分類型盼理,同時,最重要的torch::Tensor類型也是在ATen/core/TensorBody.h這個文件中聲明的俄删,因此我們進(jìn)入這個文件宏怔。該文件中第92行到1458行為torch::Tensor的定義。從中我們可以看到畴椰,該類繼承于TensorBase類臊诊,同時沒有子類,也就是說在libtorch中斜脂,我們不能使用類似于pytorch中的torch.FloatTensor()這類函數(shù)初始化一個Tensor對象抓艳。這個現(xiàn)象在上文中其實有所體現(xiàn),因為這些類型是在python中進(jìn)行派生的帚戳。

class TORCH_API Tensor: public TensorBase {
 protected:
  // Create a Tensor with a +0 reference count. Special care must be
  // taken to avoid decrementing this reference count at destruction
  // time. Intended to support MaybeOwnedTraits<Tensor>.
  explicit Tensor(unsafe_borrow_t, const TensorBase& rhs): TensorBase(unsafe_borrow_t{}, rhs) {}
  friend MaybeOwnedTraits<Tensor>;
  friend OptionalTensorRef;
  friend TensorRef;

 public:
  Tensor() = default;
  // This constructor should not be used by end users and is an implementation
  // detail invoked by autogenerated code.
  explicit Tensor(
      c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> tensor_impl)
      : TensorBase(std::move(tensor_impl)) {}
  Tensor(const Tensor &tensor) = default;
  Tensor(Tensor &&tensor) = default;
...

torch::Storage淺析

現(xiàn)在我們來看看Storage類玷或,拋去那些繁瑣的查找儡首,我們看到c10/core/Storage.h頭文件的內(nèi)容,該文件主要是Storage類的定義偏友。

#pragma once

#include <c10/core/StorageImpl.h>

namespace c10 {

struct C10_API Storage {
 public:
  struct use_byte_size_t {};

  Storage() = default;
  Storage(c10::intrusive_ptr<StorageImpl> ptr)
      : storage_impl_(std::move(ptr)) {}

  // Allocates memory buffer using given allocator and creates a storage with it
  Storage(
      use_byte_size_t /*use_byte_size*/,
      SymInt size_bytes,
      Allocator* allocator = nullptr,
      bool resizable = false)
      : storage_impl_(c10::make_intrusive<StorageImpl>(
            StorageImpl::use_byte_size_t(),
            std::move(size_bytes),
            allocator,
            resizable)) {}

  // Creates storage with pre-allocated memory buffer. Allocator is given for
  // potential future reallocations, however it can be nullptr if the storage
  // is non-resizable
  Storage(
      use_byte_size_t /*use_byte_size*/,
      size_t size_bytes,
      at::DataPtr data_ptr,
      at::Allocator* allocator = nullptr,
      bool resizable = false)
      : storage_impl_(c10::make_intrusive<StorageImpl>(
            StorageImpl::use_byte_size_t(),
            size_bytes,
            std::move(data_ptr),
            allocator,
            resizable)) {}
...

在這部分位他,我們主要粗略地看了torch::Tensor以及torch::Storage的定義和部分實現(xiàn),在繼續(xù)翻看源碼之前鹅髓,我們在中間插入一些理論上的部分,這個部分主要和pytorch的設(shè)計理念有關(guān)窿冯。

Tensor原理介紹

我們先來學(xué)習(xí)下Tensor的實現(xiàn)原理,即官方在實現(xiàn)Tensor的過程中遵循了怎樣的思想靡菇。Tensor 是PyTorch的核心數(shù)據(jù)結(jié)構(gòu)重归,它是包含若干個標(biāo)量(標(biāo)量可以是各種數(shù)據(jù)類型如浮點型鼻吮、整形等)的n-維的數(shù)據(jù)結(jié)構(gòu)椎木。我們可以認(rèn)為tensor包含了數(shù)據(jù)和元數(shù)據(jù)(metadata)博烂,元數(shù)據(jù)用來描述tensor的大小禽篱、其包含內(nèi)部數(shù)據(jù)的類型躺率、存儲的位置(CPU內(nèi)存或是CUDA顯存),而數(shù)據(jù)則是tensor真正的物理存儲的數(shù)據(jù)慎框。簡單來說笨枯,我們對一個Tensor對象進(jìn)行操作的時候馅精,比如切分洲敢,resize等操作,實際上并不會改變這個Tensor對象在物理上的存儲位置,而是通過操作metadata來改變這個Tensor對象的“邏輯表示”哮塞。

tensor原理.png

元數(shù)據(jù)metadata中有我們已經(jīng)熟知的一些屬性忆畅,如device,sizes,dtype家凯,同樣的也存在layout绊诲,strides這些我們以前并未了解過的屬性掂之。
我們先來解釋strides步長的概念脆丁,首先我們在上方已經(jīng)提到槽卫,操作Tensor對象實際上是操作metadata的過程歼培,例如執(zhí)行以下代碼:

import torch
a = torch.tensor([[1, 2],[3, 4]])
print(a[1, 0])

我們可以很容易的知道這段代碼打印的結(jié)果是3丐怯,因為這符合我們的編程直覺和習(xí)慣读跷。那么,當(dāng)我們對a這個變量進(jìn)行索引的時候无切,其底層究竟是怎么做的呢哆键。根據(jù)我們在上文所講述的籍嘹,一個Tensor對象分為數(shù)據(jù)和元數(shù)據(jù)辱士,數(shù)據(jù)其實是連續(xù)分配在某設(shè)備device上的颂碘,而元數(shù)據(jù)中的strides可以實現(xiàn)索引與數(shù)據(jù)物理位置的一一映射头岔,具體我們看下方這張圖:

stride原理.png

Tensor是一個數(shù)學(xué)概念峡竣。當(dāng)用計算機表示數(shù)學(xué)概念的時候澎胡,通常我們需要定義一種物理存儲方式攻谁。最常見的表示方式是將Tensor中的每個元素按照次序連續(xù)的在內(nèi)存中鋪開戚宦,將每一行寫到相應(yīng)內(nèi)存位置里受楼。如上圖所示艳汽,假設(shè)tensor包含的是32位的整數(shù)河狐,因此每個整數(shù)占據(jù)一塊物理內(nèi)存,每個整數(shù)的地址都和上下相鄰整數(shù)相差4個字節(jié)迈套。為了記住tensor的實際維度桑李,我們需要將tensor的維度大小記錄在額外的元數(shù)據(jù)中贵白。假設(shè)我想要訪問位于tensor [1, 0]位置處的元素戒洼,如何將這個邏輯地址轉(zhuǎn)化到物理內(nèi)存的地址上呢?步長就是用來解決這樣的問題:當(dāng)我們根據(jù)下標(biāo)索引查找tensor中的任意元素時靴寂,將某維度的下標(biāo)索引和對應(yīng)的步長相乘百炬,然后將所有維度乘積相加就可以了剖踊。在上圖中我將第一維(行)標(biāo)為紅色德澈,第二維(列)標(biāo)為藍(lán)色梆造,因此你能夠在計算中方便的觀察下標(biāo)和步長的對應(yīng)關(guān)系镇辉。求和返回了一個0維的標(biāo)量2忽肛,而內(nèi)存中地址偏移量為2的位置正好儲存了元素3屹逛。
到目前為止煎源,細(xì)心的讀者應(yīng)該已經(jīng)可以意識到手销,為什么一個Storage對象可以對應(yīng)多個Tensor對象了锋拖。很顯然兽埃,對于上面所舉的例子柄错,其中物理位置就是由Storage對象進(jìn)行管理的售貌,而邏輯表示則是由Tensor對象進(jìn)行管理颂跨,通過不同的strides恒削,sizes钓丰,我們可以很輕松地實現(xiàn)將一個Storage對象映射到多個Tensor對象中去携丁。

Tensor源碼解讀

TensorBase類

在開始講這個類之前则北,我們先來看一下core\TensorBase.h引入的頭文件涌矢,可以看到當(dāng)中有我們之前提到的一些熟悉的面孔娜庇,像Layout.h名秀,Storage.h等匕得。

#include <c10/core/Device.h>
#include <c10/core/Layout.h>
#include <c10/core/MemoryFormat.h>
#include <c10/core/ScalarType.h>
#include <c10/core/ScalarTypeToTypeMeta.h>
#include <c10/core/Storage.h>
#include <c10/core/SymIntArrayRef.h>
#include <c10/core/TensorImpl.h>
#include <c10/core/TensorOptions.h>
#include <c10/core/UndefinedTensorImpl.h>
#include <c10/core/WrapDimMinimal.h>
#include <c10/util/Exception.h>
#include <c10/util/ExclusivelyOwned.h>
#include <c10/util/ExclusivelyOwnedTensorTraits.h>
#include <c10/util/MaybeOwned.h>
#include <c10/util/Optional.h>
#include <c10/util/intrusive_ptr.h>

#include <ATen/core/NamedTensor.h>
#include <ATen/core/QuantizerBase.h>
#include <ATen/core/TensorAccessor.h>
#include <ATen/StorageUtils.h>

我們在ATen\core\TensorBase.h中可以看到TensorBase類的聲明和部分實現(xiàn)略吨,這是所有Tensor類的基類翠忠,本文不去細(xì)究該類每一個成員函數(shù)的作用秽之,而是從宏觀的角度介紹這個類政溃。

class TORCH_API TensorBase {
 public:
  struct unsafe_borrow_t { explicit unsafe_borrow_t() = default; };

 protected:
  // Create a Tensor with a +0 reference count. Special care must be
  // taken to avoid decrementing this reference count at destruction
  // time. Intended to support MaybeOwnedTraits<Tensor>.
  explicit TensorBase(unsafe_borrow_t, const TensorBase& rhs)
      : impl_(c10::intrusive_ptr<at::TensorImpl, UndefinedTensorImpl>::reclaim(rhs.impl_.get())) {}
  friend MaybeOwnedTraits<TensorBase>;
...

首先,對于部分C++了解不多的讀者可能不清楚在class和類名TensorBase之間的TORCH_API是什么。其實TORCH_API是一個宏捐友,我們知道宏是在程序的預(yù)編譯時期起作用的匣砖,自然這里的TORCH_API也一定是由于某種需要在預(yù)編譯時期進(jìn)行處理的需求而出現(xiàn)的猴鲫。我們再來看到開發(fā)者在contributing.md中說到的一段話:

Symbols are NOT exported by default on Windows; instead, you have to explicitly mark a symbol as exported/imported in a header file with __declspec(dllexport) / __declspec(dllimport). We have codified this pattern into a set of macros which follow the convention *_API, e.g., TORCH_API inside Caffe2, Aten and Torch. (Every separate shared library needs a unique macro name, because symbol visibility is on a per shared library basis. See c10/macros/Macros.h for more details.)

其實這里也就說明了為什么要使用TORCH_API的原因了谣殊,因為pytorch是以C姻几、C++為后端,python為前端的框架。想要在python中使用C++回溺,就不免需要用到動態(tài)鏈接庫馅而,即將C++的代碼導(dǎo)出為.dll.so格式瓮恭,而在windows上這些函數(shù)不能直接導(dǎo)出屯蹦,必須在頭文件中使用__declspec(dllexport)/__declsspec(dllimport)顯式地將符號標(biāo)記為導(dǎo)出/導(dǎo)入登澜,而標(biāo)記方式就是在類前加上這些宏。因此官方采用了一套*_API的宏用以編碼谴仙,TORCH_API就是其中一種晃跺。
去查找TORCH_API這個宏掀虎,不難發(fā)現(xiàn)其實該宏的定義來自于C10_IMPORT這個宏烹玉。

#define TORCH_API C10_IMPORT

C10_IMPORT這個宏又定義于__declspec(dllimport),這也印證了前文所述址儒。

#define C10_IMPORT __declspec(dllimport)

除了TORCH_API以外莲趣,這里還有個常用的宏函數(shù)TORCH_CHECK用于對數(shù)據(jù)合法性進(jìn)行檢查喧伞。

intrusive_ptr_target類

TensorBase.h文件的剩下部分都是對TensorBase類的成員函數(shù)進(jìn)行定義潘鲫,幾乎每一個成員函數(shù)都用到了一個成員變量impl_溉仑,同時這些用到該成員變量的成員函數(shù)基本都可以認(rèn)為是作用在這個成員變量上實現(xiàn)的浊竟,找到該文件的第890行必怜,這里給出了該變量的定義:

protected:
  void enforce_invariants();
  c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> impl_;

在這里梳庆,對這個成員變量而言膏执,我們需要注意兩個部分,一是它的數(shù)據(jù)類型,是個典型的模板類镇草,二是它的泛型聲明為了TensorImplUndefinedTensorImpl兩個類型梯啤。
我們首先來看TensorImpl是什么因宇,實際上在TensorBody.h中有這么一段注釋打厘,這段注釋位于75到91行户盯,它的大意可以理解為莽鸭,Tensor本質(zhì)上是一個采用引用計數(shù)的對象,多個Tensor可以同時指向同一個TensorImpl捺球。也就是說TensorBase這個類本質(zhì)上是通過引用計數(shù)的方式指向TensorImpl對象來實現(xiàn)底層操作的氮兵,即TensorBase可以理解為對TensorImpl的進(jìn)一步封裝。當(dāng)然南片,這樣的設(shè)計思想在各種工程項目中也隨處可見疼进。

// Tensor is a "generic" object holding a pointer to the underlying TensorImpl object, which
// has an embedded reference count. In this way, Tensor is similar to boost::intrusive_ptr.
//
// For example:
//
// void func(Tensor a) {
//   Tensor b = a;
//   ...
// }
//
// In this example, when we say Tensor b = a, we are creating a new object that points to the
// same underlying TensorImpl, and bumps its reference count. When b goes out of scope, the
// destructor decrements the reference count by calling release() on the TensorImpl it points to.
// The existing constructors, operator overloads, etc. take care to implement the correct semantics.
//
// Note that Tensor can also be NULL, i.e. it is not associated with any underlying TensorImpl, and
// special care must be taken to handle this.

另一方面,它的類型是c10::intrusive_ptr嚼锄,而這是pytorch框架最基礎(chǔ)区丑,最核心的數(shù)據(jù)結(jié)構(gòu)代碼可霎,我們先來看一下官方是怎么描述的:

/**
 * intrusive_ptr<T> is an alternative to shared_ptr<T> that has better
 * performance because it does the refcounting intrusively
 * (i.e. in a member of the object itself).
 * Your class T needs to inherit from intrusive_ptr_target to allow it to be
 * used in an intrusive_ptr<T>. Your class's constructor should not allow
 *`this` to escape to other threads or create an intrusive_ptr from `this`.
 */

我們?nèi)ふ疫@個類型的定義啥纸,會發(fā)現(xiàn)在c10/util/intrusive_ptr.h這個文件下,第54到90行有如下代碼婴氮,從這段代碼可以知道c10::intrusive_ptrintrusive_ptr_target的友元類斯棒,除此之外,還有個值得注意的友元類weak_intrusive_ptr主经。實際上荣暮,PyTorch中使用intrusive_ptr來管理TensorStorage的引用計數(shù),其中引用分為強引用和弱引用(弱引用為了解決循環(huán)引用問題)穗酥,對應(yīng)的類名 intrusive_ptrweak_intrusive_ptr

class C10_API intrusive_ptr_target {
  // Note [Weak references for intrusive refcounting]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // Here's the scheme:
  //
  //  - refcount == number of strong references to the object
  //    weakcount == number of weak references to the object,
  //      plus one more if refcount > 0
  //    An invariant: refcount > 0  =>  weakcount > 0
  //
  //  - c10::StorageImpl stays live as long as there are any strong
  //    or weak pointers to it (weakcount > 0, since strong
  //    references count as a +1 to weakcount)
  //
  //  - finalizers are called and data_ptr is deallocated when refcount == 0
  //
  //  - Once refcount == 0, it can never again be > 0 (the transition
  //    from > 0 to == 0 is monotonic)
  //
  //  - When you access c10::StorageImpl via a weak pointer, you must
  //    atomically increment the use count, if it is greater than 0.
  //    If it is not, you must report that the storage is dead.
  //
  mutable std::atomic<size_t> refcount_;
  mutable std::atomic<size_t> weakcount_;

  template <typename T, typename NullType>
  friend class intrusive_ptr;
  friend inline void raw::intrusive_ptr::incref(intrusive_ptr_target* self);

  template <typename T, typename NullType>
  friend class weak_intrusive_ptr;
  friend inline void raw::weak_intrusive_ptr::incref(
      intrusive_ptr_target* self);

  template <typename T>
  friend struct ExclusivelyOwnedTensorTraits;

當(dāng)然惠遏,不僅如此砾跃,除了intrusive_ptrweak_intrusive_ptrintrusive_ptr_target的友元類之外,還可以知道的是节吮,TensorImplUndefinedTensorImpl均是intrusive_ptr_target的子類抽高,從這里我們也可以看出intrusive_ptr_target是pytorch得以實現(xiàn)的極為底層的核心數(shù)據(jù)結(jié)構(gòu)。

TensorImpl類

現(xiàn)在我們來看一下TensorImpl類是如何組織的透绩,這部分定義位于c10/core/TensorImpl.h的497到3060行翘骂,在這段代碼里,我們可以看到一些上文提到過的概念帚豪,比如TensorImpl在初始化的時候其實傳入了一個Storage對象碳竟,同時TensorImpl也出現(xiàn)了上文提到過的sizes。由于TensorBase本質(zhì)上是對TesorImpl的引用狸臣,那么Tensor也就是對TensorImpl的引用莹桅,所有當(dāng)我們創(chuàng)建torch::Tensor類型時,實際上是執(zhí)行了這段代碼烛亦。

struct C10_API TensorImpl : public c10::intrusive_ptr_target {
  TensorImpl() = delete;
  ~TensorImpl() override;
  // Note [Enum ImplType]
  // This enum is temporary. In the followup refactor we should
  // think about how to specialize TensorImpl creation for view
  // tensors. Currently we only special case its key_set_ but
  // there's also potential to share version_counter_ directly
  // without creating first and then override in as_view.
  enum ImplType { VIEW };

  /**
   * Construct a 1-dim 0-size tensor backed by the given storage.
   */
  TensorImpl(
      Storage&& storage,
      DispatchKeySet,
      const caffe2::TypeMeta data_type);

  // See Note [Enum ImplType]
  TensorImpl(
      ImplType,
      Storage&& storage,
      DispatchKeySet,
      const caffe2::TypeMeta data_type);
...
  TensorImpl(const TensorImpl&) = delete;
  TensorImpl& operator=(const TensorImpl&) = delete;
  TensorImpl(TensorImpl&&) = delete;
  TensorImpl& operator=(TensorImpl&&) = delete;
...
 public:
  /**
   * Return a reference to the sizes of this tensor.  This reference remains
   * valid as long as the tensor is live and not resized.
   */
  IntArrayRef sizes() const {
    if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomSizes))) {
      return sizes_custom();
    }
    return sizes_and_strides_.sizes_arrayref();
  }
...

StorageImpl類

StorageImpl類位于c10/core/StorageImpl.h中诈泼,它也是intrusive_ptr_target的子類,它的構(gòu)造函數(shù)里有不少pytorch中的核心概念此洲,如Allocator厂汗,resizable等委粉,這些概念這篇文章不做討論呜师,將在后續(xù)的系列談及。

struct C10_API StorageImpl : public c10::intrusive_ptr_target {
 public:
  struct use_byte_size_t {};

  StorageImpl(
      use_byte_size_t /*use_byte_size*/,
      SymInt size_bytes,
      at::DataPtr data_ptr,
      at::Allocator* allocator,
      bool resizable)
      : data_ptr_(std::move(data_ptr)),
        size_bytes_(std::move(size_bytes)),
        size_bytes_is_heap_allocated_(size_bytes_.is_heap_allocated()),
        resizable_(resizable),
        received_cuda_(false),
        allocator_(allocator) {
    if (resizable) {
      TORCH_INTERNAL_ASSERT(
          allocator_, "For resizable storage, allocator must be provided");
    }
  }

  StorageImpl(
      use_byte_size_t /*use_byte_size*/,
      const SymInt& size_bytes,
      at::Allocator* allocator,
      bool resizable)
      : StorageImpl(
            use_byte_size_t(),
            size_bytes,
            size_bytes.is_heap_allocated()
                ? allocator->allocate(0)
                : allocator->allocate(size_bytes.as_int_unchecked()),
            allocator,
            resizable) {}

  StorageImpl& operator=(StorageImpl&& other) = delete;
  StorageImpl& operator=(const StorageImpl&) = delete;
  StorageImpl() = delete;
  StorageImpl(StorageImpl&& other) = delete;
  StorageImpl(const StorageImpl&) = delete;
  ~StorageImpl() override = default;
...

總結(jié)贾节,本文主要講述了Tensor的源碼實現(xiàn)機理汁汗,以及StorageTensor的關(guān)系衷畦,下一篇文章將會講述pytorch是如何實現(xiàn)C++與python綁定的。

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末知牌,一起剝皮案震驚了整個濱河市祈争,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌角寸,老刑警劉巖菩混,帶你破解...
    沈念sama閱讀 210,978評論 6 490
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場離奇詭異扁藕,居然都是意外死亡沮峡,警方通過查閱死者的電腦和手機,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 89,954評論 2 384
  • 文/潘曉璐 我一進(jìn)店門亿柑,熙熙樓的掌柜王于貴愁眉苦臉地迎上來邢疙,“玉大人,你說我怎么就攤上這事望薄∨庇危” “怎么了?”我有些...
    開封第一講書人閱讀 156,623評論 0 345
  • 文/不壞的土叔 我叫張陵痕支,是天一觀的道長颁虐。 經(jīng)常有香客問我,道長采转,這世上最難降的妖魔是什么聪廉? 我笑而不...
    開封第一講書人閱讀 56,324評論 1 282
  • 正文 為了忘掉前任,我火速辦了婚禮故慈,結(jié)果婚禮上板熊,老公的妹妹穿的比我還像新娘。我一直安慰自己察绷,他們只是感情好干签,可當(dāng)我...
    茶點故事閱讀 65,390評論 5 384
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著拆撼,像睡著了一般容劳。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發(fā)上闸度,一...
    開封第一講書人閱讀 49,741評論 1 289
  • 那天竭贩,我揣著相機與錄音,去河邊找鬼莺禁。 笑死留量,一個胖子當(dāng)著我的面吹牛,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播楼熄,決...
    沈念sama閱讀 38,892評論 3 405
  • 文/蒼蘭香墨 我猛地睜開眼忆绰,長吁一口氣:“原來是場噩夢啊……” “哼!你這毒婦竟也來了可岂?” 一聲冷哼從身側(cè)響起错敢,我...
    開封第一講書人閱讀 37,655評論 0 266
  • 序言:老撾萬榮一對情侶失蹤,失蹤者是張志新(化名)和其女友劉穎缕粹,沒想到半個月后稚茅,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 44,104評論 1 303
  • 正文 獨居荒郊野嶺守林人離奇死亡平斩,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 36,451評論 2 325
  • 正文 我和宋清朗相戀三年峰锁,在試婚紗的時候發(fā)現(xiàn)自己被綠了。 大學(xué)時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片双戳。...
    茶點故事閱讀 38,569評論 1 340
  • 序言:一個原本活蹦亂跳的男人離奇死亡虹蒋,死狀恐怖房午,靈堂內(nèi)的尸體忽然破棺而出纤勒,到底是詐尸還是另有隱情脆淹,我是刑警寧澤灶伊,帶...
    沈念sama閱讀 34,254評論 4 328
  • 正文 年R本政府宣布乌奇,位于F島的核電站陪竿,受9級特大地震影響迄汛,放射性物質(zhì)發(fā)生泄漏寂恬。R本人自食惡果不足惜扣墩,卻給世界環(huán)境...
    茶點故事閱讀 39,834評論 3 312
  • 文/蒙蒙 一哲银、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧呻惕,春花似錦荆责、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 30,725評論 0 21
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至濒持,卻和暖如春键耕,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背柑营。 一陣腳步聲響...
    開封第一講書人閱讀 31,950評論 1 264
  • 我被黑心中介騙來泰國打工屈雄, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留,地道東北人官套。 一個月前我還...
    沈念sama閱讀 46,260評論 2 360
  • 正文 我出身青樓酒奶,卻偏偏與公主長得像蓖议,于是被迫代替她去往敵國和親。 傳聞我的和親對象是個殘疾皇子讥蟆,可洞房花燭夜當(dāng)晚...
    茶點故事閱讀 43,446評論 2 348

推薦閱讀更多精彩內(nèi)容