前言
上文我們通過針對性的閱讀pytorch源碼的框架結(jié)構(gòu)猾瘸,同時以__init__.py
文件為線索探索了pytorch中多個主要類型的實現(xiàn)和功能。在上文中我們曾提到丢习,pytorch框架是一個以python為前端牵触,C++為后端的框架。如果去除掉pytorch中的python前端咐低,那我們就可以得到一個C++的AI框架——libtorch
揽思。沒有特殊說明時,本文所看源碼均是取自libtorch
而非pytorch
见擦,其中libtorch
的源碼版本是libtorch-win-shared-with-deps-2.1.0+cpu
钉汗。
一個引子
我們來看看如下代碼:
torch::Tensor x = torch::tensor({1.0});
是否覺得很眼熟,其實這就是當(dāng)我們在python中調(diào)用torch.tensor()
這個函數(shù)時鲤屡,其在C++后端實際上調(diào)用的函數(shù)损痰,從這一個例子中我們可以知道,要想真正理解pytorch的底層原理酒来,就需要深入的理解libtorch的源碼卢未。
目錄結(jié)構(gòu)
和上文我們解析pytorch源碼一樣,現(xiàn)在我們來看一下libtorch的源碼結(jié)構(gòu)堰汉,值得注意的是辽社,和pytorch一樣,libtorch
也分為cpu版本和gpu版本兩個版本翘鸭,因此其目錄結(jié)構(gòu)稍有不同滴铅。我們主要看cpu版本的include
目錄以及lib
目錄,lib
目錄下存放的是靜態(tài)鏈接文件就乓,而include
下存放的是頭文件失息,我們在使用libtorch
的時候,實際上是引入頭文件档址,最后在編譯過程中,由編譯器根據(jù)頭文件去找到路徑下的動態(tài)鏈接文件邻梆,并在鏈接期間將其鏈接進(jìn)來守伸。
include
目錄下則是這個樣子,可以很容易地發(fā)現(xiàn)浦妄,這和pytorch的源碼結(jié)構(gòu)是幾乎一樣的尼摹。torch.h解讀
torch.h
源碼很短见芹,只有如下幾行,從中我們可以看到它其實是引入了all.h
這個頭文件蠢涝,另外還引入了一個extension.h
文件玄呛。這里這個extension.h
文件其實是引入了python.h
,當(dāng)然和二,由于這里討論的是無python前端的libtorch
版本徘铝,因此不需要注意這個。
#pragma once
#include <torch/all.h>
#ifdef TORCH_API_INCLUDE_EXTENSION_H
#include <torch/extension.h>
#endif // defined(TORCH_API_INCLUDE_EXTENSION_H)
在進(jìn)入all.h
后惯吕,可以看到如下代碼惕它,很明顯,這個頭文件就是將所有模塊下的頭文件引入废登,現(xiàn)階段我們要關(guān)注的只有linalg.h
這一個文件淹魄。
#pragma once
#if !defined(_MSC_VER) && __cplusplus < 201703L
#error C++17 or later compatible compiler is required to use PyTorch.
#endif
#include <torch/autograd.h>
#include <torch/cuda.h>
#include <torch/data.h>
#include <torch/enum.h>
#include <torch/fft.h>
#include <torch/jit.h>
#include <torch/linalg.h>
#include <torch/mps.h>
#include <torch/nested.h>
#include <torch/nn.h>
#include <torch/optim.h>
#include <torch/serialize.h>
#include <torch/sparse.h>
#include <torch/special.h>
#include <torch/types.h>
#include <torch/utils.h>
#include <torch/version.h>
torch::Tensor淺析
進(jìn)入linalg.h
后我們可以看到如下代碼甲锡,這個文件提供了相當(dāng)多的內(nèi)聯(lián)函數(shù)疚俱,主要是一些矩陣運算相關(guān)的函數(shù)。當(dāng)然,這個文件最重要的部分其實是將ATen.h
引入了torch.h
中,記得上文我們說過ATen
目錄是與Tensor
類實現(xiàn)相關(guān)的庫蟀架,也就是說苫纤,在這個文件中我們可以找到torch::Tensor
的聲明操禀。
#pragma once
#include <ATen/ATen.h>
namespace torch {
namespace linalg {
#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
inline Tensor cholesky(const Tensor& self) {
return torch::linalg_cholesky(self);
}
...
進(jìn)入到ATen.h
文件后,我們看到如下代碼钩杰,其中出現(xiàn)了我們在上一節(jié)提到過的兩個核心概念怎披,Tensor
和Storage
,它們被包含在ATen/Tensor.h
以及c10/core/Storage.h
中菩佑。
#pragma once
#if !defined(_MSC_VER) && __cplusplus < 201703L
#error C++17 or later compatible compiler is required to use ATen.
#endif
#include <ATen/Context.h>
#include <ATen/Device.h>
#include <ATen/DeviceGuard.h>
#include <ATen/DimVector.h>
#include <ATen/Dispatch.h>
#include <ATen/Formatting.h>
#include <ATen/Functions.h>
#include <ATen/NamedTensor.h>
#include <ATen/ScalarOps.h>
#include <ATen/Tensor.h>
#include <ATen/TensorGeometry.h>
#include <ATen/TensorIndexing.h>
#include <ATen/TensorOperators.h>
#include <ATen/Version.h>
#include <ATen/core/ATenGeneral.h>
#include <ATen/core/Generator.h>
#include <ATen/core/Reduction.h>
#include <ATen/core/Scalar.h>
#include <ATen/core/UnsafeFromTH.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/jit_type.h>
#include <c10/core/Allocator.h>
#include <c10/core/InferenceMode.h>
#include <c10/core/Layout.h>
#include <c10/core/Storage.h>
#include <c10/core/TensorOptions.h>
#include <c10/util/Exception.h>
// TODO: try to remove this
// There is some back story, see https://github.com/pytorch/pytorch/issues/48684
#include <ATen/NativeFunctions.h>
進(jìn)入ATen/Tensor.h
中咧党,可以看到該文件其實引入了ATen/core/Tensor.h
遮糖,因此我們進(jìn)入該文件敬惦,進(jìn)入后我們可以看到如下代碼:
#pragma once
#include <ATen/core/TensorBody.h>
#include <c10/util/Exception.h>
namespace at {
class TORCH_API OptionalTensorRef {
public:
OptionalTensorRef() = default;
~OptionalTensorRef() {
ref_.unsafeReleaseTensorImpl();
}
OptionalTensorRef(const TensorBase& src)
: ref_(Tensor::unsafe_borrow_t{}, src) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src.defined());
}
OptionalTensorRef(const OptionalTensorRef& rhs)
: ref_(Tensor::unsafe_borrow_t{}, rhs.ref_) {}
OptionalTensorRef& operator=(OptionalTensorRef rhs) {
std::swap(ref_, rhs.ref_);
return *this;
}
bool has_value() const {
return ref_.defined();
}
const Tensor& getTensorRef() const & {
return ref_;
}
const Tensor& operator*() const & {
return ref_;
}
const Tensor* operator->() const & {
return &ref_;
}
operator bool() const {
return ref_.defined();
}
private:
Tensor ref_;
};
// Use to convert a TensorBase (that may be undefined) to an at::Tensor
// without bumping refcount.
class TORCH_API TensorRef {
public:
~TensorRef() {
ref_.unsafeReleaseTensorImpl();
}
TensorRef(const TensorBase& src)
: ref_(Tensor::unsafe_borrow_t{}, src) {}
const Tensor& operator*() const & {
return ref_;
}
private:
Tensor ref_;
};
template <typename T>
auto Tensor::register_hook(T&& hook) const -> Tensor::hook_return_void_t<T> {
// Return the grad argument in case of a hook with void return type to have an
// std::function with Tensor return type
static_assert(std::is_same<decltype(hook(Tensor())), void>::value,
"Expected hook to return void");
return _register_hook([fn=std::forward<T>(hook)](const TensorBase& grad_base) {
TensorRef grad(grad_base);
fn(*grad);
return Tensor();
});
}
template <typename T>
auto Tensor::register_hook(T&& hook) const -> Tensor::hook_return_var_t<T> {
return _register_hook([fn=std::forward<T>(hook)](const TensorBase& grad_base) {
TensorRef grad(grad_base);
Tensor ret = fn(*grad);
return TensorBase(std::move(ret));
});
}
} // namespace at
這段代碼其實主要是實現(xiàn)了ATen/core/TensorBody.h
中聲明的部分類型盼理,同時,最重要的torch::Tensor
類型也是在ATen/core/TensorBody.h
這個文件中聲明的俄删,因此我們進(jìn)入這個文件宏怔。該文件中第92行到1458行為torch::Tensor
的定義。從中我們可以看到畴椰,該類繼承于TensorBase
類臊诊,同時沒有子類,也就是說在libtorch中斜脂,我們不能使用類似于pytorch中的torch.FloatTensor()
這類函數(shù)初始化一個Tensor
對象抓艳。這個現(xiàn)象在上文中其實有所體現(xiàn),因為這些類型是在python中進(jìn)行派生的帚戳。
class TORCH_API Tensor: public TensorBase {
protected:
// Create a Tensor with a +0 reference count. Special care must be
// taken to avoid decrementing this reference count at destruction
// time. Intended to support MaybeOwnedTraits<Tensor>.
explicit Tensor(unsafe_borrow_t, const TensorBase& rhs): TensorBase(unsafe_borrow_t{}, rhs) {}
friend MaybeOwnedTraits<Tensor>;
friend OptionalTensorRef;
friend TensorRef;
public:
Tensor() = default;
// This constructor should not be used by end users and is an implementation
// detail invoked by autogenerated code.
explicit Tensor(
c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> tensor_impl)
: TensorBase(std::move(tensor_impl)) {}
Tensor(const Tensor &tensor) = default;
Tensor(Tensor &&tensor) = default;
...
torch::Storage淺析
現(xiàn)在我們來看看Storage
類玷或,拋去那些繁瑣的查找儡首,我們看到c10/core/Storage.h
頭文件的內(nèi)容,該文件主要是Storage
類的定義偏友。
#pragma once
#include <c10/core/StorageImpl.h>
namespace c10 {
struct C10_API Storage {
public:
struct use_byte_size_t {};
Storage() = default;
Storage(c10::intrusive_ptr<StorageImpl> ptr)
: storage_impl_(std::move(ptr)) {}
// Allocates memory buffer using given allocator and creates a storage with it
Storage(
use_byte_size_t /*use_byte_size*/,
SymInt size_bytes,
Allocator* allocator = nullptr,
bool resizable = false)
: storage_impl_(c10::make_intrusive<StorageImpl>(
StorageImpl::use_byte_size_t(),
std::move(size_bytes),
allocator,
resizable)) {}
// Creates storage with pre-allocated memory buffer. Allocator is given for
// potential future reallocations, however it can be nullptr if the storage
// is non-resizable
Storage(
use_byte_size_t /*use_byte_size*/,
size_t size_bytes,
at::DataPtr data_ptr,
at::Allocator* allocator = nullptr,
bool resizable = false)
: storage_impl_(c10::make_intrusive<StorageImpl>(
StorageImpl::use_byte_size_t(),
size_bytes,
std::move(data_ptr),
allocator,
resizable)) {}
...
在這部分位他,我們主要粗略地看了torch::Tensor
以及torch::Storage
的定義和部分實現(xiàn),在繼續(xù)翻看源碼之前鹅髓,我們在中間插入一些理論上的部分,這個部分主要和pytorch的設(shè)計理念有關(guān)窿冯。
Tensor原理介紹
我們先來學(xué)習(xí)下Tensor的實現(xiàn)原理,即官方在實現(xiàn)Tensor的過程中遵循了怎樣的思想靡菇。Tensor 是PyTorch的核心數(shù)據(jù)結(jié)構(gòu)重归,它是包含若干個標(biāo)量(標(biāo)量可以是各種數(shù)據(jù)類型如浮點型鼻吮、整形等)的n-維的數(shù)據(jù)結(jié)構(gòu)椎木。我們可以認(rèn)為tensor包含了數(shù)據(jù)和元數(shù)據(jù)(metadata)博烂,元數(shù)據(jù)用來描述tensor的大小禽篱、其包含內(nèi)部數(shù)據(jù)的類型躺率、存儲的位置(CPU內(nèi)存或是CUDA顯存),而數(shù)據(jù)則是tensor真正的物理存儲的數(shù)據(jù)慎框。簡單來說笨枯,我們對一個Tensor
對象進(jìn)行操作的時候馅精,比如切分洲敢,resize
等操作,實際上并不會改變這個Tensor
對象在物理上的存儲位置,而是通過操作metadata來改變這個Tensor
對象的“邏輯表示”哮塞。
元數(shù)據(jù)metadata中有我們已經(jīng)熟知的一些屬性忆畅,如
device
,sizes
,dtype
家凯,同樣的也存在layout
绊诲,strides
這些我們以前并未了解過的屬性掂之。我們先來解釋
strides
步長的概念脆丁,首先我們在上方已經(jīng)提到槽卫,操作Tensor
對象實際上是操作metadata的過程歼培,例如執(zhí)行以下代碼:
import torch
a = torch.tensor([[1, 2],[3, 4]])
print(a[1, 0])
我們可以很容易的知道這段代碼打印的結(jié)果是3丐怯,因為這符合我們的編程直覺和習(xí)慣读跷。那么,當(dāng)我們對a
這個變量進(jìn)行索引的時候无切,其底層究竟是怎么做的呢哆键。根據(jù)我們在上文所講述的籍嘹,一個Tensor對象分為數(shù)據(jù)和元數(shù)據(jù)辱士,數(shù)據(jù)其實是連續(xù)分配在某設(shè)備device
上的颂碘,而元數(shù)據(jù)中的strides
可以實現(xiàn)索引與數(shù)據(jù)物理位置的一一映射头岔,具體我們看下方這張圖:
Tensor是一個數(shù)學(xué)概念峡竣。當(dāng)用計算機表示數(shù)學(xué)概念的時候澎胡,通常我們需要定義一種物理存儲方式攻谁。最常見的表示方式是將Tensor中的每個元素按照次序連續(xù)的在內(nèi)存中鋪開戚宦,將每一行寫到相應(yīng)內(nèi)存位置里受楼。如上圖所示艳汽,假設(shè)tensor包含的是32位的整數(shù)河狐,因此每個整數(shù)占據(jù)一塊物理內(nèi)存,每個整數(shù)的地址都和上下相鄰整數(shù)相差4個字節(jié)迈套。為了記住tensor的實際維度桑李,我們需要將tensor的維度大小記錄在額外的元數(shù)據(jù)中贵白。假設(shè)我想要訪問位于tensor [1, 0]位置處的元素戒洼,如何將這個邏輯地址轉(zhuǎn)化到物理內(nèi)存的地址上呢?步長就是用來解決這樣的問題:當(dāng)我們根據(jù)下標(biāo)索引查找tensor中的任意元素時靴寂,將某維度的下標(biāo)索引和對應(yīng)的步長相乘百炬,然后將所有維度乘積相加就可以了剖踊。在上圖中我將第一維(行)標(biāo)為紅色德澈,第二維(列)標(biāo)為藍(lán)色梆造,因此你能夠在計算中方便的觀察下標(biāo)和步長的對應(yīng)關(guān)系镇辉。求和返回了一個0維的標(biāo)量2忽肛,而內(nèi)存中地址偏移量為2的位置正好儲存了元素3屹逛。
到目前為止煎源,細(xì)心的讀者應(yīng)該已經(jīng)可以意識到手销,為什么一個Storage
對象可以對應(yīng)多個Tensor
對象了锋拖。很顯然兽埃,對于上面所舉的例子柄错,其中物理位置就是由Storage
對象進(jìn)行管理的售貌,而邏輯表示則是由Tensor
對象進(jìn)行管理颂跨,通過不同的strides恒削,sizes钓丰,我們可以很輕松地實現(xiàn)將一個Storage
對象映射到多個Tensor
對象中去携丁。
Tensor源碼解讀
TensorBase類
在開始講這個類之前则北,我們先來看一下core\TensorBase.h
引入的頭文件涌矢,可以看到當(dāng)中有我們之前提到的一些熟悉的面孔娜庇,像Layout.h
名秀,Storage.h
等匕得。
#include <c10/core/Device.h>
#include <c10/core/Layout.h>
#include <c10/core/MemoryFormat.h>
#include <c10/core/ScalarType.h>
#include <c10/core/ScalarTypeToTypeMeta.h>
#include <c10/core/Storage.h>
#include <c10/core/SymIntArrayRef.h>
#include <c10/core/TensorImpl.h>
#include <c10/core/TensorOptions.h>
#include <c10/core/UndefinedTensorImpl.h>
#include <c10/core/WrapDimMinimal.h>
#include <c10/util/Exception.h>
#include <c10/util/ExclusivelyOwned.h>
#include <c10/util/ExclusivelyOwnedTensorTraits.h>
#include <c10/util/MaybeOwned.h>
#include <c10/util/Optional.h>
#include <c10/util/intrusive_ptr.h>
#include <ATen/core/NamedTensor.h>
#include <ATen/core/QuantizerBase.h>
#include <ATen/core/TensorAccessor.h>
#include <ATen/StorageUtils.h>
我們在ATen\core\TensorBase.h
中可以看到TensorBase
類的聲明和部分實現(xiàn)略吨,這是所有Tensor類的基類翠忠,本文不去細(xì)究該類每一個成員函數(shù)的作用秽之,而是從宏觀的角度介紹這個類政溃。
class TORCH_API TensorBase {
public:
struct unsafe_borrow_t { explicit unsafe_borrow_t() = default; };
protected:
// Create a Tensor with a +0 reference count. Special care must be
// taken to avoid decrementing this reference count at destruction
// time. Intended to support MaybeOwnedTraits<Tensor>.
explicit TensorBase(unsafe_borrow_t, const TensorBase& rhs)
: impl_(c10::intrusive_ptr<at::TensorImpl, UndefinedTensorImpl>::reclaim(rhs.impl_.get())) {}
friend MaybeOwnedTraits<TensorBase>;
...
首先,對于部分C++了解不多的讀者可能不清楚在class
和類名TensorBase
之間的TORCH_API
是什么。其實TORCH_API
是一個宏捐友,我們知道宏是在程序的預(yù)編譯時期起作用的匣砖,自然這里的TORCH_API
也一定是由于某種需要在預(yù)編譯時期進(jìn)行處理的需求而出現(xiàn)的猴鲫。我們再來看到開發(fā)者在contributing.md
中說到的一段話:
Symbols are NOT exported by default on Windows; instead, you have to explicitly mark a symbol as exported/imported in a header file with __declspec(dllexport) / __declspec(dllimport). We have codified this pattern into a set of macros which follow the convention *_API, e.g., TORCH_API inside Caffe2, Aten and Torch. (Every separate shared library needs a unique macro name, because symbol visibility is on a per shared library basis. See c10/macros/Macros.h for more details.)
其實這里也就說明了為什么要使用TORCH_API
的原因了谣殊,因為pytorch是以C姻几、C++為后端,python為前端的框架。想要在python中使用C++回溺,就不免需要用到動態(tài)鏈接庫馅而,即將C++的代碼導(dǎo)出為.dll
或.so
格式瓮恭,而在windows上這些函數(shù)不能直接導(dǎo)出屯蹦,必須在頭文件中使用__declspec(dllexport)
/__declsspec(dllimport)
顯式地將符號標(biāo)記為導(dǎo)出/導(dǎo)入登澜,而標(biāo)記方式就是在類前加上這些宏。因此官方采用了一套*_API
的宏用以編碼谴仙,TORCH_API
就是其中一種晃跺。
去查找TORCH_API
這個宏掀虎,不難發(fā)現(xiàn)其實該宏的定義來自于C10_IMPORT
這個宏烹玉。
#define TORCH_API C10_IMPORT
而C10_IMPORT
這個宏又定義于__declspec(dllimport)
,這也印證了前文所述址儒。
#define C10_IMPORT __declspec(dllimport)
除了TORCH_API
以外莲趣,這里還有個常用的宏函數(shù)TORCH_CHECK
用于對數(shù)據(jù)合法性進(jìn)行檢查喧伞。
intrusive_ptr_target類
TensorBase.h
文件的剩下部分都是對TensorBase
類的成員函數(shù)進(jìn)行定義潘鲫,幾乎每一個成員函數(shù)都用到了一個成員變量impl_
溉仑,同時這些用到該成員變量的成員函數(shù)基本都可以認(rèn)為是作用在這個成員變量上實現(xiàn)的浊竟,找到該文件的第890行必怜,這里給出了該變量的定義:
protected:
void enforce_invariants();
c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> impl_;
在這里梳庆,對這個成員變量而言膏执,我們需要注意兩個部分,一是它的數(shù)據(jù)類型,是個典型的模板類镇草,二是它的泛型聲明為了TensorImpl
和UndefinedTensorImpl
兩個類型梯啤。
我們首先來看TensorImpl
是什么因宇,實際上在TensorBody.h
中有這么一段注釋打厘,這段注釋位于75到91行户盯,它的大意可以理解為莽鸭,Tensor
本質(zhì)上是一個采用引用計數(shù)的對象,多個Tensor可以同時指向同一個TensorImpl
捺球。也就是說TensorBase
這個類本質(zhì)上是通過引用計數(shù)的方式指向TensorImpl
對象來實現(xiàn)底層操作的氮兵,即TensorBase
可以理解為對TensorImpl
的進(jìn)一步封裝。當(dāng)然南片,這樣的設(shè)計思想在各種工程項目中也隨處可見疼进。
// Tensor is a "generic" object holding a pointer to the underlying TensorImpl object, which
// has an embedded reference count. In this way, Tensor is similar to boost::intrusive_ptr.
//
// For example:
//
// void func(Tensor a) {
// Tensor b = a;
// ...
// }
//
// In this example, when we say Tensor b = a, we are creating a new object that points to the
// same underlying TensorImpl, and bumps its reference count. When b goes out of scope, the
// destructor decrements the reference count by calling release() on the TensorImpl it points to.
// The existing constructors, operator overloads, etc. take care to implement the correct semantics.
//
// Note that Tensor can also be NULL, i.e. it is not associated with any underlying TensorImpl, and
// special care must be taken to handle this.
另一方面,它的類型是c10::intrusive_ptr
嚼锄,而這是pytorch框架最基礎(chǔ)区丑,最核心的數(shù)據(jù)結(jié)構(gòu)代碼可霎,我們先來看一下官方是怎么描述的:
/**
* intrusive_ptr<T> is an alternative to shared_ptr<T> that has better
* performance because it does the refcounting intrusively
* (i.e. in a member of the object itself).
* Your class T needs to inherit from intrusive_ptr_target to allow it to be
* used in an intrusive_ptr<T>. Your class's constructor should not allow
*`this` to escape to other threads or create an intrusive_ptr from `this`.
*/
我們?nèi)ふ疫@個類型的定義啥纸,會發(fā)現(xiàn)在c10/util/intrusive_ptr.h
這個文件下,第54到90行有如下代碼婴氮,從這段代碼可以知道c10::intrusive_ptr
是intrusive_ptr_target
的友元類斯棒,除此之外,還有個值得注意的友元類weak_intrusive_ptr
主经。實際上荣暮,PyTorch中使用intrusive_ptr
來管理Tensor
和Storage
的引用計數(shù),其中引用分為強引用和弱引用(弱引用為了解決循環(huán)引用問題)穗酥,對應(yīng)的類名 intrusive_ptr
和weak_intrusive_ptr
。
class C10_API intrusive_ptr_target {
// Note [Weak references for intrusive refcounting]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Here's the scheme:
//
// - refcount == number of strong references to the object
// weakcount == number of weak references to the object,
// plus one more if refcount > 0
// An invariant: refcount > 0 => weakcount > 0
//
// - c10::StorageImpl stays live as long as there are any strong
// or weak pointers to it (weakcount > 0, since strong
// references count as a +1 to weakcount)
//
// - finalizers are called and data_ptr is deallocated when refcount == 0
//
// - Once refcount == 0, it can never again be > 0 (the transition
// from > 0 to == 0 is monotonic)
//
// - When you access c10::StorageImpl via a weak pointer, you must
// atomically increment the use count, if it is greater than 0.
// If it is not, you must report that the storage is dead.
//
mutable std::atomic<size_t> refcount_;
mutable std::atomic<size_t> weakcount_;
template <typename T, typename NullType>
friend class intrusive_ptr;
friend inline void raw::intrusive_ptr::incref(intrusive_ptr_target* self);
template <typename T, typename NullType>
friend class weak_intrusive_ptr;
friend inline void raw::weak_intrusive_ptr::incref(
intrusive_ptr_target* self);
template <typename T>
friend struct ExclusivelyOwnedTensorTraits;
當(dāng)然惠遏,不僅如此砾跃,除了intrusive_ptr
和weak_intrusive_ptr
是intrusive_ptr_target
的友元類之外,還可以知道的是节吮,TensorImpl
和UndefinedTensorImpl
均是intrusive_ptr_target
的子類抽高,從這里我們也可以看出intrusive_ptr_target
是pytorch得以實現(xiàn)的極為底層的核心數(shù)據(jù)結(jié)構(gòu)。
TensorImpl類
現(xiàn)在我們來看一下TensorImpl
類是如何組織的透绩,這部分定義位于c10/core/TensorImpl.h
的497到3060行翘骂,在這段代碼里,我們可以看到一些上文提到過的概念帚豪,比如TensorImpl
在初始化的時候其實傳入了一個Storage
對象碳竟,同時TensorImpl
也出現(xiàn)了上文提到過的sizes
。由于TensorBase
本質(zhì)上是對TesorImpl
的引用狸臣,那么Tensor
也就是對TensorImpl
的引用莹桅,所有當(dāng)我們創(chuàng)建torch::Tensor
類型時,實際上是執(zhí)行了這段代碼烛亦。
struct C10_API TensorImpl : public c10::intrusive_ptr_target {
TensorImpl() = delete;
~TensorImpl() override;
// Note [Enum ImplType]
// This enum is temporary. In the followup refactor we should
// think about how to specialize TensorImpl creation for view
// tensors. Currently we only special case its key_set_ but
// there's also potential to share version_counter_ directly
// without creating first and then override in as_view.
enum ImplType { VIEW };
/**
* Construct a 1-dim 0-size tensor backed by the given storage.
*/
TensorImpl(
Storage&& storage,
DispatchKeySet,
const caffe2::TypeMeta data_type);
// See Note [Enum ImplType]
TensorImpl(
ImplType,
Storage&& storage,
DispatchKeySet,
const caffe2::TypeMeta data_type);
...
TensorImpl(const TensorImpl&) = delete;
TensorImpl& operator=(const TensorImpl&) = delete;
TensorImpl(TensorImpl&&) = delete;
TensorImpl& operator=(TensorImpl&&) = delete;
...
public:
/**
* Return a reference to the sizes of this tensor. This reference remains
* valid as long as the tensor is live and not resized.
*/
IntArrayRef sizes() const {
if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomSizes))) {
return sizes_custom();
}
return sizes_and_strides_.sizes_arrayref();
}
...
StorageImpl類
StorageImpl
類位于c10/core/StorageImpl.h
中诈泼,它也是intrusive_ptr_target
的子類,它的構(gòu)造函數(shù)里有不少pytorch中的核心概念此洲,如Allocator厂汗,resizable等委粉,這些概念這篇文章不做討論呜师,將在后續(xù)的系列談及。
struct C10_API StorageImpl : public c10::intrusive_ptr_target {
public:
struct use_byte_size_t {};
StorageImpl(
use_byte_size_t /*use_byte_size*/,
SymInt size_bytes,
at::DataPtr data_ptr,
at::Allocator* allocator,
bool resizable)
: data_ptr_(std::move(data_ptr)),
size_bytes_(std::move(size_bytes)),
size_bytes_is_heap_allocated_(size_bytes_.is_heap_allocated()),
resizable_(resizable),
received_cuda_(false),
allocator_(allocator) {
if (resizable) {
TORCH_INTERNAL_ASSERT(
allocator_, "For resizable storage, allocator must be provided");
}
}
StorageImpl(
use_byte_size_t /*use_byte_size*/,
const SymInt& size_bytes,
at::Allocator* allocator,
bool resizable)
: StorageImpl(
use_byte_size_t(),
size_bytes,
size_bytes.is_heap_allocated()
? allocator->allocate(0)
: allocator->allocate(size_bytes.as_int_unchecked()),
allocator,
resizable) {}
StorageImpl& operator=(StorageImpl&& other) = delete;
StorageImpl& operator=(const StorageImpl&) = delete;
StorageImpl() = delete;
StorageImpl(StorageImpl&& other) = delete;
StorageImpl(const StorageImpl&) = delete;
~StorageImpl() override = default;
...
總結(jié)贾节,本文主要講述了Tensor
的源碼實現(xiàn)機理汁汗,以及Storage
和Tensor
的關(guān)系衷畦,下一篇文章將會講述pytorch是如何實現(xiàn)C++與python綁定的。