In this chapter we take a look at the implementation of XLA (Accelerated Linear Algebra) inside TensorFlow. The code lives under tensorflow/compiler.
XLA
Before XLA, the execution of a computation graph in TensorFlow was driven by runtime code: the runtime was responsible for loading the graph definition, creating the graph, partitioning it, optimizing it, assigning devices, managing the dependencies between nodes, and scheduling the execution of node kernels; the graph is the data part, and the runtime is the code part. We already walked through this process in some detail in chapter five, the analysis of the Session class implementation. With XLA there is now another option: the computation graph can be compiled directly into executable code for the target platform and executed as-is, with no runtime code involved.
In this chapter I will analyze how XLA compiles a tensorflow.GraphDef into executable code.
XLA currently offers two modes: AOT (ahead-of-time compilation) and JIT (just-in-time compilation).
AOT
In compiler terminology, AOT (ahead-of-time compilation) means that everything is compiled into target instructions before the execution phase; once execution starts, no further compilation takes place.
The TensorFlow website already walks through an example of using AOT, which I reproduce here. The code is in tensorflow/compiler/aot/tests/make_test_graphs.py, where the function tfmatmul builds a simple network as follows:
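The function body looks roughly like the following sketch (paraphrased with the public tf.* API rather than the internal module aliases the file actually uses; the node names must match the feeds and fetches configured in step 1 below):

import tensorflow as tf

def tfmatmul(_):
  # 'x_hold' and 'y_hold' are the placeholder names referenced as feeds,
  # and 'x_y_prod' is the matmul node referenced as the fetch, in the config.
  x = tf.placeholder(tf.float32, name='x_hold')
  y = tf.placeholder(tf.float32, name='y_hold')
  tf.matmul(x, y, name='x_y_prod')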
In this example, we will use XLA's AOT mode to compile this graph into an executable. This takes four steps:
Step 1: write the configuration
Configure the input and output nodes of the network; they correspond to the input and output arguments of the generated function.
/* tensorflow/compiler/aot/tests/test_graph_tfmatmul.config.pbtxt */
# Each feed is a positional input argument for the generated function. The order
# of each entry matches the order of each input argument. Here "x_hold" and "y_hold"
# refer to the names of placeholder nodes defined in the graph.
feed {
  id { node_name: "x_hold" }
  shape {
    dim { size: 2 }
    dim { size: 3 }
  }
}
feed {
  id { node_name: "y_hold" }
  shape {
    dim { size: 3 }
    dim { size: 2 }
  }
}

# Each fetch is a positional output argument for the generated function. The order
# of each entry matches the order of each output argument. Here "x_y_prod"
# refers to the name of a matmul node defined in the graph.
fetch {
  id { node_name: "x_y_prod" }
}
Step 2: use the tf_library build macro to compile the subgraph into a static library
load("//third_party/tensorflow/compiler/aot:tfcompile.bzl", "tf_library")

# Use the tf_library macro to compile your graph into executable code.
tf_library(
    # name is used to generate the following underlying build rules:
    # <name>           : cc_library packaging the generated header and object files
    # <name>_test      : cc_test containing a simple test and benchmark
    # <name>_benchmark : cc_binary containing a stand-alone benchmark with minimal deps;
    #                    can be run on a mobile device
    name = "test_graph_tfmatmul",
    # cpp_class specifies the name of the generated C++ class, with namespaces allowed.
    # The class will be generated in the given namespace(s), or if no namespaces are
    # given, within the global namespace.
    cpp_class = "foo::bar::MatMulComp",
    # graph is the input GraphDef proto, by default expected in binary format. To
    # use the text format instead, just use the '.pbtxt' suffix. A subgraph will be
    # created from this input graph, with feeds as inputs and fetches as outputs.
    # No Placeholder or Variable ops may exist in this subgraph.
    graph = "test_graph_tfmatmul.pb",
    # config is the input Config proto, by default expected in binary format. To
    # use the text format instead, use the '.pbtxt' suffix. This is where the
    # feeds and fetches were specified above, in the previous step.
    config = "test_graph_tfmatmul.config.pbtxt",
)
Step 3: write code to invoke the subgraph
Step 2 generates a header file and an object file. The header test_graph_tfmatmul.h looks like this:
/* test_graph_tfmatmul.h */
namespace foo {
namespace bar {

// MatMulComp represents a computation previously specified in a
// TensorFlow graph, now compiled into executable code.
class MatMulComp {
 public:
  // AllocMode controls the buffer allocation mode.
  enum class AllocMode {
    ARGS_RESULTS_AND_TEMPS,  // Allocate arg, result and temp buffers
    RESULTS_AND_TEMPS_ONLY,  // Only allocate result and temp buffers
  };

  MatMulComp(AllocMode mode = AllocMode::ARGS_RESULTS_AND_TEMPS);
  ~MatMulComp();

  // Runs the computation, with inputs read from arg buffers, and outputs
  // written to result buffers. Returns true on success and false on failure.
  bool Run();

  // Arg methods for managing input buffers. Buffers are in row-major order.
  // There is a set of methods for each positional argument.
  void** args();
  void set_arg0_data(float* data);
  float* arg0_data();
  float& arg0(size_t dim0, size_t dim1);
  void set_arg1_data(float* data);
  float* arg1_data();
  float& arg1(size_t dim0, size_t dim1);

  // Result methods for managing output buffers. Buffers are in row-major order.
  // Must only be called after a successful Run call. There is a set of methods
  // for each positional result.
  void** results();
  float* result0_data();
  float& result0(size_t dim0, size_t dim1);
};

}  // end namespace bar
}  // end namespace foo
Include the header and write the calling code:
#define EIGEN_USE_THREADS
#define EIGEN_USE_CUSTOM_THREAD_POOL
#include <iostream>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmul.h" // generated
int main(int argc, char** argv) {
  Eigen::ThreadPool tp(2);  // Size the thread pool as appropriate.
  Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());

  foo::bar::MatMulComp matmul;
  matmul.set_thread_pool(&device);

  // Set up args and run the computation.
  const float args[12] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
  std::copy(args + 0, args + 6, matmul.arg0_data());
  std::copy(args + 6, args + 12, matmul.arg1_data());
  matmul.Run();

  // Check result
  if (matmul.result0(0, 0) == 58) {
    std::cout << "Success" << std::endl;
  } else {
    std::cout << "Failed. Expected value 58 at 0,0. Got:"
              << matmul.result0(0, 0) << std::endl;
  }

  return 0;
}
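As a quick sanity check on the expected value: arg0 holds the row-major 2x3 matrix [[1, 2, 3], [4, 5, 6]] and arg1 the 3x2 matrix [[7, 8], [9, 10], [11, 12]], so entry (0, 0) of their product is 1*7 + 2*9 + 3*11 = 58, which is exactly what the code above tests for.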
Step 4: create the final executable binary with cc_binary
# Example of linking your binary
# Also see //third_party/tensorflow/compiler/aot/tests/BUILD
load("//third_party/tensorflow/compiler/aot:tfcompile.bzl", "tf_library")

# The same tf_library call from step 2 above.
tf_library(
    name = "test_graph_tfmatmul",
    ...
)

# The executable code generated by tf_library can then be linked into your code.
cc_binary(
    name = "my_binary",
    srcs = [
        "my_code.cc",  # include test_graph_tfmatmul.h to access the generated header
    ],
    deps = [
        ":test_graph_tfmatmul",  # link in the generated object file
        "//third_party/eigen3",
    ],
    linkopts = [
        "-lpthread",
    ],
)
These four steps produce the executable. But in fact, the output of the tf_library macro in step 2 is already the executable code for the graph: a header file plus an object file. So the real work of compiling the graph is done inside tf_library. Let's analyze its implementation; tf_library is defined in tensorflow/compiler/aot/tfcompile.bzl:
/* tensorflow/compiler/aot/tfcompile.bzl */
...
def tf_library(name, graph, config,
               freeze_checkpoint=None, freeze_saver=None,
               cpp_class=None, gen_test=True, gen_benchmark=True,
               visibility=None, testonly=None,
               tfcompile_flags=None,
               tfcompile_tool="//tensorflow/compiler/aot:tfcompile",
               deps=None, tags=None):
  ...
  # Rule that runs tfcompile to produce the header and object file.
  header_file = name + ".h"
  object_file = name + ".o"
  ep = ("__" + PACKAGE_NAME + "__" + name).replace("/", "_")
  native.genrule(
      name=("gen_" + name),
      srcs=[
          tfcompile_graph,
          config,
      ],
      outs=[
          header_file,
          object_file,
      ],
      cmd=("$(location " + tfcompile_tool + ")" +
           " --graph=$(location " + tfcompile_graph + ")" +
           " --config=$(location " + config + ")" +
           " --entry_point=" + ep +
           " --cpp_class=" + cpp_class +
           " --target_triple=" + target_llvm_triple() +
           " --out_header=$(@D)/" + header_file +
           " --out_object=$(@D)/" + object_file +
           " " + (tfcompile_flags or "")),
      tools=[tfcompile_tool],
      visibility=visibility,
      testonly=testonly,
      # Run tfcompile on the build host since it's typically faster on the local
      # machine.
      #
      # Note that setting the local=1 attribute on a *test target* causes the
      # test infrastructure to skip that test. However this is a genrule, not a
      # test target, and runs with --genrule_strategy=forced_forge, meaning the
      # local=1 attribute is ignored, and the genrule is still run.
      #
      # https://www.bazel.io/versions/master/docs/be/general.html#genrule
      local=1,
      tags=tags,
  )
...
Above I excerpted the key step of tf_library: it runs the tfcompile_tool command-line tool to generate the header file and the object file. Note the flags passed to tfcompile_tool: --graph, --config, and so on.
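For the example above, the genrule expands to an invocation roughly like the following (a sketch: the entry_point value follows the ep expression in the Starlark code, assuming the package tensorflow/compiler/aot/tests, and the target_triple value is an illustrative placeholder for whatever target_llvm_triple() returns on the build host):

tfcompile --graph=test_graph_tfmatmul.pb \
          --config=test_graph_tfmatmul.config.pbtxt \
          --entry_point=__tensorflow_compiler_aot_tests__test_graph_tfmatmul \
          --cpp_class=foo::bar::MatMulComp \
          --target_triple=x86_64-pc-linux \
          --out_header=test_graph_tfmatmul.h \
          --out_object=test_graph_tfmatmul.o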
The main entry point of tfcompile_tool is defined in tensorflow/compiler/aot/tfcompile_main.cc. The compilation process consists of four main steps:
1. Build a tensorflow.Graph from the GraphDef.
2. Call xla.XlaCompiler.CompileGraph to compile the tensorflow.Graph into an xla.Computation.
3. Call xla.CompileOnlyClient.CompileAheadOfTime to compile the xla.Computation into executable code.
4. Save the compilation result to the header and object files.
The platforms TensorFlow currently supports for AOT compilation are x86-64 and ARM.
JIT
JIT stands for Just In Time. With just-in-time compilation, the computation graph is not compiled into executable code before the run phase; instead, it is compiled at a suitable moment after execution has begun, and from then on can be invoked directly.
A comparison of the pros and cons of JIT versus AOT compilation is not the topic of this chapter, and space does not permit a deeper analysis here. Let's look directly at the implementation of JIT in TensorFlow.
The Python API offers several ways to turn on JIT support:
Method 1: via the Session configuration:
This setting applies to the whole Session; the kernel will compile as many nodes as possible.
# Config to turn on JIT compilation
config = tf.ConfigProto()
config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
sess = tf.Session(config=config)
Method 2: via tf.contrib.compiler.jit.experimental_jit_scope():
This affects all nodes inside the scope: each of them gets the attribute _XlaCompile set to true.
jit_scope = tf.contrib.compiler.jit.experimental_jit_scope

x = tf.placeholder(np.float32)
with jit_scope():
  y = tf.add(x, x)  # The "add" will be compiled with XLA.
Method 3: via the device:
Turn on JIT support by placing ops on an XLA device.
with tf.device("/job:localhost/replica:0/task:0/device:XLA_GPU:0"):
  output = tf.add(input1, input2)
Next, let's look at the question: how do these API-level settings ultimately affect the computation of the graph inside the kernel?
First, recall figure 4 of TensorFlow Internals (5): analysis of the implementation of core concepts, in the section on local execution of a Session: before a graph runs, it goes through a series of optimizations and rewrites (including the optimizations of the grappler module analyzed in the previous chapter). One of these steps involves the class tensorflow.OptimizationPassRegistry, through which the registered subclasses of tensorflow.GraphOptimizationPass are run; each subclass implements one piece of graph optimization or rewriting logic. The XLA JIT graph optimizations and rewrites are executed through this entry point as well.
The JIT-related tensorflow.GraphOptimizationPass registrations live in:
/* tensorflow/compiler/jit/jit_compilation_pass_registration.cc */
...
namespace tensorflow {
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 10,
                      MarkForCompilationPass);

// The EncapsulateSubgraphs pass must run after the MarkForCompilationPass. We
// also need to run it after the graph has been rewritten to have _Send nodes
// added for fetches. Before the _Send nodes are added, fetch nodes are
// identified by name, and encapsulation might remove that node from the graph.
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 20,
                      EncapsulateSubgraphsPass);

// Must run after EncapsulateSubgraphsPass.
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 30,
                      BuildXlaLaunchOpsPass);
} // namespace tensorflow
...
As you can see, there are three JIT-related tensorflow.GraphOptimizationPass passes:
1. tensorflow.MarkForCompilationPass:
The three JIT-enabling settings described above are checked in this class. Based on those settings, the pass first picks out every node that has JIT enabled and is compilable in the current version, and then runs a clustering analysis that partitions these to-be-compiled nodes into a number of clusters. Consider the following example:
Nodes B and C are marked as cluster 1, and nodes E and F as cluster 0; node A is not assigned to any cluster because it is not compilable.
2. tensorflow.EncapsulateSubgraphsPass:
This pass proceeds in three steps.
Step one: create a SubGraph object for each cluster produced by the previous pass, MarkForCompilationPass.
Step two: create a FunctionDef for each SubGraph object, and add the created FunctionDef to the FunctionLibrary.
A quick refresher on the concept of a Function in TensorFlow. FunctionDef is defined as follows:
/* tensorflow/core/framework/function.proto */
message FunctionDef {
  // The definition of the function's name, arguments, return values,
  // attrs etc.
  OpDef signature = 1;

  map<string, AttrValue> attr = 5;

  repeated NodeDef node_def = 3;

  map<string, string> ret = 4;
}
A Function can be viewed as an independent computation graph, and node_def holds all the nodes this subgraph contains. A Function can be instantiated and called; this is done by inserting a call node into the caller's graph, and the kernel (OpKernel) of such a node is CallOp.
We know that the execution of a computation graph is ultimately driven by an Executor object. CallOp is the bridge between the caller graph's Executor and the graph inside the Function: externally it responds to calls from the Executor; internally it creates, for each invocation, a separate Executor that drives the computation of the Function's internal graph.
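To make the Function concept concrete, here is a minimal sketch using the (non-public) Defun decorator from the TF 1.x codebase; it registers a FunctionDef in the graph's function library, and calling it inserts a single call node rather than inlining the matmul:

import tensorflow as tf
from tensorflow.python.framework import function

# Tracing the decorated Python function once produces a FunctionDef
# that is stored in the graph's function library.
@function.Defun(tf.float32, tf.float32)
def MyMatMul(x, y):
  return tf.matmul(x, y)

a = tf.placeholder(tf.float32, shape=[2, 3])
b = tf.placeholder(tf.float32, shape=[3, 2])

# The call site becomes one call node in the calling graph, driven at
# run time by a CallOp-style kernel as described above.
c = MyMatMul(a, b)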
Step three: build a new graph. First, the nodes of the original graph that were not marked are copied over directly; then a CallOp node is created for each SubGraph's Function; finally, the data and control dependencies of the graph are re-created.
In the example below, nodes C and c together are replaced by node F1, which calls Function F1:
3. tensorflow.BuildXlaLaunchOpsPass:
In a graph already optimized by EncapsulateSubgraphsPass, every function-call node is replaced with an XlaLaunch node.
This XlaLaunch node is the key to JIT. Its op name is "_XlaLaunch" and its kernel is XlaLocalLaunchOp which, as required of an op kernel, derives from OpKernel.
Externally, XlaLocalLaunchOp responds to call requests from the Executor; internally it calls the JIT-related API classes to compile and execute the FunctionDef. The compilation result is cached, of course, so that not every call has to go through compilation:
Step one: call XlaCompilationCache to compile the FunctionDef into an xla.LocalExecutable; on a cache miss, it calls xla.LocalClient to perform the actual compilation.
Step two: call xla.LocalExecutable.Run.
The platforms currently supported by the JIT path are x86-64 and NVIDIA GPU.
Summary
The above is an analysis of the two ways XLA is invoked in TensorFlow: the AOT mode and the JIT mode.
Both modes compile the whole computation graph, or a part of it, directly into executable code. The differences between the two are also fairly clear: besides compiling at different times, they differ in how much the runtime is involved. AOT dispenses with the runtime entirely, while JIT still requires it; JIT, however, fuses and rewrites the nodes of the original graph, inserting XlaLaunch nodes to speed up the graph's execution.
In later articles we will look at the internal implementation of the XLA compiler in detail.