一移斩、TVM的工作流程
TVM主要由兩個(gè)部分組成:
(1)TVM編譯器:負(fù)責(zé)編譯和優(yōu)化模型
(2)TVM runtime:提供目標(biāo)設(shè)備上運(yùn)行模型的API
1伶唯、整體流程
如圖所示赖瞒,TVM的工作流程包括4個(gè)主要部分:
前端導(dǎo)入(Import):前端部件將不同神經(jīng)網(wǎng)絡(luò)框架所訓(xùn)練得到的模型文件轉(zhuǎn)化為IRModule,IRModule是TVM核心的數(shù)據(jù)結(jié)構(gòu)之一,它包含了可以表述模型的函數(shù)集合原朝。
編譯轉(zhuǎn)化(Transformation):編譯器對(duì)IRModule通過(guò)各種Relay Passes的優(yōu)化規(guī)則進(jìn)行優(yōu)化踊淳,比如對(duì)模型進(jìn)行量化等假瞬。
目標(biāo)代碼轉(zhuǎn)換(Target Translation):既然IRModule是一個(gè)函數(shù)集合,那編譯器就可以將IRModule交叉編譯為目標(biāo)設(shè)備可運(yùn)行的格式迂尝,并提供了導(dǎo)出脱茉,加載和執(zhí)行的API給目標(biāo)設(shè)備調(diào)用。
運(yùn)行(Runtime Execution):用戶將交叉編譯得到的Module加載到設(shè)備上并執(zhí)行其中的函數(shù)集合垄开。
2琴许、關(guān)鍵數(shù)據(jù)結(jié)構(gòu)
IRModule (intermediate representation module)是貫穿整個(gè)TVM的數(shù)據(jù)結(jié)構(gòu),重要性不言而喻溉躲。它是一系列Function的集合榜田,用于表述一個(gè)神經(jīng)網(wǎng)絡(luò)模型寸认,目前TVM支持兩種主要的變體函數(shù):
relay::Function 是高層級(jí)的函數(shù)編程表示,一個(gè)relay.Function通常對(duì)應(yīng)一個(gè)端到端的模型串慰∑可以將它理解為一個(gè)支持控制流、遞歸和復(fù)雜數(shù)據(jù)結(jié)構(gòu)的計(jì)算圖邦鲫。
tir::PrimFunc 是低層級(jí)的函數(shù)編程表示灸叼,它包括循環(huán)嵌套、多維加載與存儲(chǔ)庆捺,線程處理以及向量和張量操作指令古今。它通常用以定義一個(gè)算子的操作,對(duì)應(yīng)一個(gè)模型中的某一層滔以。
在整個(gè)編譯過(guò)程中捉腥,一個(gè)relay function可能會(huì)被優(yōu)化為多個(gè)tir::PrimFunc。
3你画、Transformations
transformation的作用有兩個(gè):
(1)優(yōu)化(optimization):將程序轉(zhuǎn)換為等效的抵碟、或者更優(yōu)化的版本;
(2)底層表示(lowering):將程序轉(zhuǎn)換為更接近目標(biāo)設(shè)備的低層級(jí)表示坏匪。
relay/transform 包含一組優(yōu)化模型的passes拟逮。優(yōu)化包括constant folding和dead-code消除,以及針對(duì)張量計(jì)算的優(yōu)化适滓,如layout轉(zhuǎn)換和scaling factor folding敦迄。
在relay優(yōu)化的pipeline的最后,會(huì)運(yùn)行一個(gè) FuseOps的pass凭迹,將一個(gè)完整的Function(對(duì)應(yīng)一個(gè)端到端的模型如 MobileNet)分解為多個(gè)子Funcions(例如 conv2d-relu)段罚屋。這樣做的好處是將問(wèn)題分成了兩個(gè)子問(wèn)題:
編譯和優(yōu)化可以針對(duì)每個(gè)子Function。TVM使用低層級(jí)的 tir 來(lái)編譯和優(yōu)化每個(gè)子功能嗅绸。對(duì)于特定目標(biāo)設(shè)備脾猛,也可以直接使用外部代碼生成器進(jìn)行目標(biāo)代碼轉(zhuǎn)換。
整體運(yùn)行時(shí)需要調(diào)用所有的子Function朽砰。TVM支持幾種運(yùn)行方式尖滚,所有運(yùn)行模式都封裝在一個(gè)統(tǒng)一的 runtime.Module 接口中:
對(duì)于形狀已知且沒有控制流的簡(jiǎn)單模型,我們可以降級(jí)為將執(zhí)行結(jié)構(gòu)存儲(chǔ)在圖中的圖執(zhí)行器瞧柔;
支持用于動(dòng)態(tài)執(zhí)行的虛擬機(jī)后端漆弄;
后續(xù)計(jì)劃支持提前編譯,將高層級(jí)的運(yùn)行結(jié)構(gòu)編譯為可運(yùn)行的原始functions造锅。
tir/transform 包含 TIR 層級(jí)functions的轉(zhuǎn)換passes撼唾。例如,將multi-dimensional access flatten到一維訪問(wèn)哥蔚,將內(nèi)在函數(shù)擴(kuò)展為特定于目標(biāo)的函數(shù)倒谷,以及修飾函數(shù)入口以滿足運(yùn)行時(shí)調(diào)用約束蛛蒙。除此之外,也有優(yōu)化passes渤愁,如access index簡(jiǎn)化和dead-code消除牵祟。
4、搜索空間和基于機(jī)器學(xué)習(xí)的轉(zhuǎn)換
前面描述的轉(zhuǎn)換都是基于規(guī)則和確定的抖格,而TVM的設(shè)計(jì)目標(biāo)之一是支持對(duì)于不同的硬件平臺(tái)都可以進(jìn)行高性能的代碼優(yōu)化诺苹。因此,需要對(duì)盡可能多的優(yōu)化進(jìn)行選擇雹拄,每個(gè)優(yōu)化又需要選擇最優(yōu)的參數(shù)收奔,從這個(gè)角度來(lái)看,這個(gè)的工作量無(wú)疑是巨大的滓玖。TVM采用了基于空間搜索和機(jī)器學(xué)習(xí)的方法來(lái)解決這個(gè)優(yōu)化選擇與調(diào)參的問(wèn)題坪哄。
顧名思義,空間搜索需要在特定的空間势篡,所以首先需要定義一系列的轉(zhuǎn)換操作翩肌,比如循環(huán)轉(zhuǎn)換、內(nèi)聯(lián)殊霞、矢量化等摧阅。這些操作稱為調(diào)度原語(yǔ)(scheduling primitives)。調(diào)度原語(yǔ)的集合定義了可以對(duì)程序進(jìn)行優(yōu)化的搜索空間绷蹲,然后TVM搜索不同的調(diào)度順序以挑選最佳調(diào)度組合。這個(gè)搜索過(guò)程通常由機(jī)器學(xué)習(xí)算法完成顾孽,TVM使用的是xgboost算法祝钢。在搜索完成后,記錄下每個(gè)算子最優(yōu)的調(diào)度順序若厚,然后編譯器就可以將此調(diào)度序列應(yīng)用到程序中拦英。TVM使用基于搜索的優(yōu)化方法來(lái)處理初始 tir function生成問(wèn)題。這部分模塊稱為 AutoTVM(auto_scheduler)测秸。
5疤估、目標(biāo)代碼轉(zhuǎn)化
目標(biāo)代碼轉(zhuǎn)換階段主要是將 IRModule 轉(zhuǎn)換為可以相應(yīng)的目標(biāo)設(shè)備上運(yùn)行的格式:
對(duì)于 x86 和 ARM 等后端,使用 LLVM IRBuilder 來(lái)構(gòu)建 LLVM IR霎冯;
支持生成例如 CUDA C 和 OpenCL等源碼級(jí)的代碼铃拇;
支持通過(guò)外部代碼生成器將Relay function(子圖)直接轉(zhuǎn)換為特定目標(biāo)代碼。
代碼生成階段需要盡可能地輕量化沈撞,所以絕大多數(shù)的轉(zhuǎn)換和降層級(jí)都應(yīng)該放在目標(biāo)代碼轉(zhuǎn)換之前執(zhí)行慷荔。
二、邏輯架構(gòu)組件
上圖顯示了TVM中的主要邏輯組件:
-
tvm/support
包含tvm最常用的實(shí)用工具函數(shù)缠俺,例如通用的 arena 分配器显晶、套接字和日志記錄等贷岸。
-
tvm/runtime
runtime作為tvm的基礎(chǔ)組件,提供了加載和執(zhí)行編譯的機(jī)制磷雇,它定義了一組標(biāo)準(zhǔn)的C API與前端高級(jí)語(yǔ)言如Python和Rust進(jìn)行交互偿警。在runtime中,runtime::Object是主要的數(shù)據(jù)結(jié)構(gòu)之一唯笙,它是一個(gè)帶有類型索引的引用計(jì)數(shù)基類螟蒸,用于支持運(yùn)行時(shí)類型檢查和向下類型轉(zhuǎn)換。通過(guò)它可以向runtime引入新的數(shù)據(jù)結(jié)構(gòu)睁本,如 Array尿庐、Map 和新的 IR 數(shù)據(jù)結(jié)構(gòu)。
編譯器本身也大量使用了 TVM 的runtime機(jī)制呢堰。所有 IR 數(shù)據(jù)結(jié)構(gòu)都是runtime::Object 的子類抄瑟,因此,它們可以直接通過(guò) Python 前端進(jìn)行操作枉疼,tvm使用 PackedFunc 機(jī)制向前端公開各種 API皮假。
runtime/rpc實(shí)現(xiàn)了對(duì) PackedFunc 的 RPC 支持,由此可以將交叉編譯的庫(kù)發(fā)送到遠(yuǎn)程設(shè)備并測(cè)試性能骂维。因?yàn)閞pc架構(gòu)支持從各種遠(yuǎn)程硬件后端收集數(shù)據(jù)惹资,所以它是基于機(jī)器學(xué)習(xí)優(yōu)化方法的基礎(chǔ)。
-
tvm/node
node 模塊在runtime::Object之上為 IR 數(shù)據(jù)結(jié)構(gòu)添加了附加功能航闺。主要功能包括反射褪测、序列化、結(jié)構(gòu)等價(jià)和散列潦刃。還可以將任意 IR 節(jié)點(diǎn)序列化為 JSON 格式侮措,然后將它們加載回來(lái)。保存/存儲(chǔ)和檢查 IR 節(jié)點(diǎn)的能力為使編譯器更易于訪問(wèn)奠定了基礎(chǔ)乖杠。
-
tvm/ir
在tvm/ir文件夾中包含跨所有IR功能變異體的統(tǒng)一的數(shù)據(jù)結(jié)構(gòu)和接口分扎。tvm/ir中的組件由tvm/relay和tvm/tir共享,包括IRModule胧洒、Type畏吓、PassContext 和 Pass、OP卫漫。
-
tvm/target
target模塊包含將 IRModule 轉(zhuǎn)換為目標(biāo) runtime.Module 的所有代碼生成器菲饼。它還提供了一個(gè)通用的Target類來(lái)描述目標(biāo)。通過(guò)查詢target中的屬性信息和注冊(cè)到每個(gè)target id(cuda, opencl)的內(nèi)置信息汛兜,可以根據(jù)target定制編譯流水線巴粪。
-
tvm/tir
tir 包含低層級(jí)程序表示的定義,使用tir::PrimFunc來(lái)表示可以通過(guò) tir 通道轉(zhuǎn)換的函數(shù)。除了 IR 數(shù)據(jù)結(jié)構(gòu)之外肛根,tir 模塊還通過(guò)公共 Op 注冊(cè)表定義了一組內(nèi)置函數(shù)及屬性辫塌,以及tir/transform 中的轉(zhuǎn)換passes。
-
tvm/arith
該模塊與 tir 密切相關(guān)派哲。低層級(jí)代碼生成中的關(guān)鍵問(wèn)題之一是對(duì)索引算術(shù)屬性的分析臼氨。arith 模塊提供了一組(主要是整數(shù))分析工具。tir 通過(guò)可以使用這些分析來(lái)簡(jiǎn)化和優(yōu)化代碼芭届。
-
tvm/te
te 代表“張量表達(dá)式”储矩,通過(guò)編寫張量表達(dá)式可以快速構(gòu)建tir::PrimFunc變體。te/schedule提供了一組調(diào)度原語(yǔ)來(lái)控制正在生成的函數(shù)褂乍。
-
tvm/topi
雖然可以為每個(gè)用例直接通過(guò) TIR 或張量表達(dá)式 (TE) 構(gòu)造運(yùn)算符持隧,但這樣做很乏味。topi(張量運(yùn)算符清單)提供了一組由 numpy 定義并在常見深度學(xué)習(xí)工作負(fù)載中找到的預(yù)定義運(yùn)算符(在 TE 或 TIR 中)逃片。我們還提供了一組通用計(jì)劃模板屡拨,以獲得跨不同目標(biāo)平臺(tái)的高性能實(shí)現(xiàn)。
-
tvm/relay
relay 是用于表示完整模型的高級(jí)功能 IR褥实。在relay.transform中定義了各種優(yōu)化呀狼。relay 編譯器定義了多種優(yōu)化策略,每種策略都旨在支持特定的優(yōu)化方式损离。
-
tvm/autotvm
AutoTVM和AutoScheduler是自動(dòng)搜索優(yōu)化所必須的兩個(gè)組件哥艇。主要包括:cost models和特征提取僻澎;用于存儲(chǔ)運(yùn)行cost models性能結(jié)果的格式以及一組變換搜索策略貌踏。
三、運(yùn)行TVM實(shí)例
1窟勃、交叉編譯runtime
想要在目標(biāo)設(shè)備上運(yùn)行模型的前提是交叉編譯模型和runtime庫(kù)哩俭。以Raspberry Pi為例,首先需要在主機(jī)安裝Raspberry Pi的編譯工具鏈:
sudo apt-get update
sudo apt-get install gcc-aarch64-linux-gnu g++-aarch64-linux-gnu
sudo apt-get install gcc-multilib-arm-linux-gnueabihf g++-multilib-arm-linux-gnueabihf
然后交叉編譯TVM runtime庫(kù):
cmake .. \
-DCMAKE_SYSTEM_NAME=Linux \
-DCMAKE_SYSTEM_VERSION=1 \
-DCMAKE_C_COMPILER=/usr/bin/aarch64-linux-gnu-gcc \
-DCMAKE_CXX_COMPILER=/usr/bin/aarch64-linux-gnu-g++ \
-DCMAKE_FIND_ROOT_PATH=/usr/aarch64-linux-gnu \
-DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \
-DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \
-DMACHINE_NAME=aarch64-linux-gnu
make -j2 runtime
編譯完成后使用file命令查看編譯出來(lái)的runtime庫(kù)是否OK:
2拳恋、編譯模型
在主機(jī)上構(gòu)造一個(gè)簡(jiǎn)單的kernel,并在主機(jī)上編譯砸捏,示例代碼如下:
import numpy as np
import tvm
from tvm import te
from tvm import rpc
from tvm.contrib import utils
# 構(gòu)造計(jì)算核
n = tvm.runtime.convert(1024)
A = te.placeholder((n,), name="A")
B = te.compute((n,), lambda i: A[i] + 1.0, name="B")
s = te.create_schedule(B.op)
# 編譯并保存結(jié)果:local_demo為True表示編譯target為主機(jī)端運(yùn)行谬运,否則為raspbarry pi
local_demo = True
if local_demo:
target = "llvm"
else:
target = "llvm -mtriple=armv7l-linux-gnueabihf"
func = tvm.build(s, [A, B], target=target, name="add_one") # 為目標(biāo)設(shè)備生成代碼
print(func)
path = "./tvm_test_lib.tar"
func.export_library(path)
運(yùn)行代碼后會(huì)得到tvm_test_lib.tar的編譯結(jié)果,func的打印輸出為:
Module(llvm, 56334d7e8738)
它是一個(gè) tvm.runtime.PackedFunc 類型垦藏,TVM使用Function開作為前后端的黏合梆暖,一個(gè)編譯后的module返回Function,TVM后端同樣也以Functions的方式注冊(cè)和暴露API掂骏。
3轰驳、運(yùn)行模型
將它用rpc的方式運(yùn)行在設(shè)備上,需要將編譯的lib上傳到設(shè)備,然后使用設(shè)備端的編譯器重新鏈接之后级解,func就是一個(gè)設(shè)備端的模型對(duì)象了冒黑。
if local_demo:
remote = rpc.LocalSession()
else:
host = "192.168.1.111"
port = 9090
remote = rpc.connect(host, port)
remote.upload(path)
func = remote.load_module("tvm_test_lib.tar")
print(func)
dev = remote.cpu()
a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), dev)
b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), dev)
func(a, b)
np.testing.assert_equal(b.numpy(), a.numpy() + 1)
time_f = func.time_evaluator(func.entry_name, dev, number=10)
cost = time_f(a, b).mean
print("%g secs/op" % cost)
此時(shí)的func打印輸出是:
Module(rpc, 56334d6d7148)
四、總結(jié)
本文介紹了TVM的工作流程和內(nèi)部的邏輯框架組件勤哗,通過(guò)運(yùn)行TVM的一個(gè)實(shí)例了解和熟悉TVM的Python API使用抡爹。