1. 前言
我們在訓練之前矩距,先建立好一個圖,然后我們可以在這個圖上做我們想做的優(yōu)化怖竭,這種形式稱為Symbolic Programs锥债。相對應的是Imperative Programs,也就是每一句代碼都對應著程序的執(zhí)行,在這種情況下哮肚,我們可以寫類似于下面的代碼:
a = 2
b= a + 1
d = np.zeros(10)
for i in range(d):
d += np.zeros(10)
這在symbolic的方式下是做不到的登夫,因為在for循環(huán)開始時,程序并不知道d
的值允趟,也就無法判斷循環(huán)的次數(shù)恼策。
因此我們可以說,symbolic更高效潮剪,imperative更靈活涣楷。
MxNet是一個異步式的訓練框架,它支持上面的兩種形式抗碰。我們可以使用NDArray
來進行imperative形式的程序編寫狮斗,也可以使用symbol
來建立圖。
2. op
先來了解operator
弧蝇,不了解operator
可能就很難理解源碼中占據(jù)了很大一部分的operator的定義碳褒。就是通過這些operator來將symbol連接成為了一個圖。
-
OpManager
:單例結構體捍壤,通過OpManager::Global()
總會返回同一個結構體骤视。Op的構造函數(shù)會將OpManager
的op_counter
加一,并且將自己的index_
注冊為當前的op_counter
鹃觉。 -
add_alias
:將別名注冊到`dmlc::Registry<Op>中 -
Get
:根據(jù)name
返回Op
GetAttrMap
2.1 op
-
name
:名字 -
description
:該op的描述 -
num_inputs
:輸入的個數(shù) -
num_outputs
:輸出的個數(shù) -
get_num_outputs, get_num_inputs
:函數(shù)专酗,返回輸出,輸入的個數(shù) -
attr_parser
:函數(shù)盗扇,用于方便返回該op的參數(shù) -
Op& Op::describe(const std::string& descr)
:方法用于將輸入注冊到description變量中祷肯,并返回這個op,方便接著調(diào)用其他方法疗隶。
2.2 幾個宏
-
#define NNVM_REGISTER_VAR_DEF(OpName)
:定義OpName -
#define NNVM_REGISTER_VAR_DEF(TagName)
:定義TagName
#define NNVM_REGISTER_OP(OpName) \
DMLC_STR_CONCAT(NNVM_REGISTER_VAR_DEF(OpName, __COUNTER__) = \
::dmlc::Register<::nnvm::op>::Get()->__REGISTER_OR_GET(#OpName)
注冊op佑笋,并返回該op
3. Node
Node是組成symbol的基本組件。
結構體NodeEntry
包含了:
-
node
:指向node的指針 -
index
:輸出的索引值 -
version
:輸入的version
結構體NodeAttrs
包含了:
-
op
: 指向operator的指針 -
name
: node的名字 -
dict
:attributes的字典
類Node
包含:
-
attrs
:結構體NodeAttrs
成員斑鼻,存儲了op, name, attributes
等信息蒋纬。 -
inputs
:輸入,是一個元素為NodeEntry
的向量 -
control_deps
:保存了應該在該node執(zhí)行之前執(zhí)行的node坚弱。 -
op()
:返回該Node的operator蜀备,就是返回attrs
中保存的op
-
Create()
:類方法,靜態(tài)方法荒叶,用于新建一個Node碾阁,返回指向它的指針 -
num_outputs
:如果是變量,輸出為1些楣,否則返回op
的輸出
幾個函數(shù)
定義在文件op_attr_types.h
中
-
FListinputNames
:返回輸入的名字脂凶,默認return {'data'}
-
FNumVisibleOutputs
:用于隱藏一些輸出 -
FListOutputNames
:返回輸出的名字 -
FMutateInputs
:返回該node會改變的node的索引值 -
FInferNodeEntryAttr
:推理出AttrType
-
FInferShape
:推理shape宪睹,也就是上面的AttrType
為Tshape
-
FInferType
:推理類型 -
TIsBackward
是否是反向傳播 FInplaceOption
-
FGradient
:返回node的梯度節(jié)點 -
FSetInputVarAttrOnCompose
:為輸入設置attribute -
FCorrectLayout
:推理layout -
FInputGraph
:返回輸入,解釋為圖而不是數(shù)據(jù)
這些函數(shù)是在定義具體的op時蚕钦,可以選擇注冊對應的函數(shù)亭病。
4. Symbol
Symbol是為了使用Node建立Graph。Symbol是我們能夠直接接觸的類嘶居,它定義了一系列方法用于更方便地構建圖命贴。在symbol的成員outputs
中,定義了一組由NodeEntry
組成的向量食听。
-
outputs
:該symbol包含的輸出,是一個元素是NodeEntry
的向量 -
Copy
:返回一個深拷貝污茵,方式是通過遍歷Node樱报,每次訪問到的Node保存起來,再建立起node之間的連接泞当,最后將head加入到outputs中迹蛤。 -
Symbol operator[] (size_t index) const
:返回第個輸出。
-
ListInputs
:返回輸入 -
ListInputNames
:返回輸入的名字 -
Compose
:組合symbol -
operator ()
:調(diào)用compose襟士,來組合symbol -
AddControlDeps
:加入控制盗飒,用于有向圖的構建 -
GetInternals
:返回一個symbol,它的輸出是原來symbol的輸出加上所有中間輸出和輸入 -
GetChildren
: -
SetAttrs
:設置attribution -
GetAttrs
: -
CreateFunctor
:給定op和attrs陋桂,返回一個symbol
我認為symbol
中比較重要的函數(shù)是compose逆趣,在調(diào)用的時候我們是通過調(diào)用symbol
的操作符()
函數(shù),也就是operator ()
嗜历,該函數(shù)將參數(shù)傳遞給Compose
宣渗。
5. Graph
類Graph
就是計算的時候使用的圖
-
outputs
:和symbol
的outputs
一樣,類型為std::vector<NodeEntry>
-
attrs
:定義了圖的一些屬性 -
PostOrderDFSVisit
:后序遍歷圖梨州,給定參數(shù)head痕囱,進行拓撲排序。算法暴匠,貌似鞍恢,就是拓撲排序算法。 -
DFSVisit
:調(diào)用PostOrderDFSVisit
每窖,對圖的head進行拓撲排序帮掉。參數(shù)為:const std::vector<NodeEntry>& heads, FVisit fvisit
,其中head
是反向傳播時的頭節(jié)點岛请,fvisit
是訪問時調(diào)用的函數(shù)旭寿,該方法將fvisit(*n)
作為訪問節(jié)點時的函數(shù),[](GNode n)->Node*{return->get();}
作為hash函數(shù)崇败,這個函數(shù)看簽名返回的是一個指向節(jié)點的指針盅称。圖的節(jié)點入度計算如下:
[](GNode n)->uint32_t {
if (!(*n)) return 0;
return (*n)->input.size() + (*n)->control_deps.size();
}
節(jié)點輸入計算如下:
[](GNode n, uint32_t index)->GNode {
if (index < (*n)->input.size()) {
return &(*n)->input.at(index).node;
} else {
return &(*n)->contorl_deps.at(index - (*n)->inputs.size());
}
6. IndexedGraph
IndexedGraph
由Graph
返回肩祥,
-
nodes_
:成員變量,一個指向Node
結構體的向量缩膝,Node
定義如下:
struct Node {
const nnvm::Node* source;
array_view<NodeEntry> inputs;
array_view<uint32_t> control_deps;
std::weak_ptr<nnvm::Node> weak_ref;
};
其中NodeEntry
如下:
struct NoodeEntry {
uint32_t node_id;
uint32_t index;
uint32_t version;
};
成員變量:
-
input_nodes_
:輸入node的索引 mutable_input_nodes_
-
outputs
:輸出節(jié)點 -
node2index
:node到索引的映射 -
entry_rptr_
: -
input_entries_
: -
control_deps_
:
方法: DFSVisit
PostOrderDFSVisti
7. pass
7.1 gradient.cc
-
Gradient
:gradient
會根據(jù)屬于的graph
混狠,返回一個帶反向傳播圖的新圖。它主要由executor
建立圖的時候調(diào)用疾层,調(diào)用方式如下:
nnvm::Graph g_grad = nnvm::pass::Gradient(g,
symbol.outputs, xs, head_grad_entry_, ArggregateGradient,
need_mirror, nullptr, zero_ops, "_copy");
調(diào)用該方法會調(diào)用文件pass_function.h
下的Gradient
函數(shù)将饺。該函數(shù)將傳入的參數(shù)保存在graph
下的attrs
中。再通過applypass
調(diào)用Gradient方法痛黎。也就是在該文件下定義的方法予弧,簽名:Graph Gradient(Graph src)
。
- 根據(jù)DFSVisit進行拓撲排序湖饱,將序列存儲到
topo_order
中 - 將輸出的梯度保存在
output_grads
- 根據(jù)
mirror_fun
在適當?shù)牡胤讲迦胄碌墓?jié)點掖蛤,來實現(xiàn)內(nèi)存的復用
-
DefaultAggregateGradient
: