This document describes the parameter definitions and usage of RunOptions; the TensorFlow version discussed is 1.12.
1. RunOptions Parameters
RunOptions provides configuration parameters that are consumed when Session::Run is called, including:
- trace_level: the level of runtime tracing to collect for this step (NO_TRACE / SOFTWARE_TRACE / HARDWARE_TRACE / FULL_TRACE); the timeline example in Section 4.1 uses FULL_TRACE.
- timeout_in_ms: how long to wait for the operation to complete, in milliseconds.
- inter_op_thread_pool: if session_inter_op_thread_pool was configured when the session was created, this parameter selects which of those thread pools to use. The proto comment notes that -1 means "use the caller's thread", which suits small, simple graphs because it avoids thread-switching overhead. Beware of version differences: in versions before TF 1.10, setting -1 raises an InvalidArgument error.
- output_partition_graphs: boolean flag indicating whether the partition graphs executed in this step should be output to RunMetadata.
- debug_options: debugging-related options.
- report_tensor_allocations_upon_oom: when the allocator runs out of memory, include information about the tensor allocations in the error message; enabling this slows down Session::Run.
- experimental: these fields are unstable, so compatibility must be checked across versions. The two experimental fields in RunOptions are still present as of TensorFlow 2.1. Among them, use_run_handler_pool is recommended for CPU-heavy scenarios such as inference: thread-pool work is scheduled centrally across sessions, which reduces latency. (A Python sketch of how these fields are set follows the proto definition below.)
message RunOptions {
  enum TraceLevel {
    NO_TRACE = 0;
    SOFTWARE_TRACE = 1;
    HARDWARE_TRACE = 2;
    FULL_TRACE = 3;
  }
  TraceLevel trace_level = 1;
  int64 timeout_in_ms = 2;
  int32 inter_op_thread_pool = 3;
  bool output_partition_graphs = 5;
  DebugOptions debug_options = 6;
  bool report_tensor_allocations_upon_oom = 7;

  message Experimental {
    int64 collective_graph_key = 1;
    bool use_run_handler_pool = 2;
  };

  Experimental experimental = 8;

  reserved 4;
}
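As referenced above, here is a minimal Python sketch of how these fields are populated. The field values are arbitrary examples, not recommendations; inter_op_thread_pool and the experimental fields are covered separately in Section 3.

import tensorflow as tf

# Arbitrary example values; every field corresponds to the proto above.
options = tf.RunOptions(
    trace_level=tf.RunOptions.FULL_TRACE,
    timeout_in_ms=60000,
    output_partition_graphs=True,
    report_tensor_allocations_upon_oom=True)

run_metadata = tf.RunMetadata()
# sess.run(fetches, options=options, run_metadata=run_metadata)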
2. RunMetadata Parameters
The RunMetadata fields, like those of RunOptions, are defined in config.proto. RunMetadata is usually used together with the corresponding RunOptions settings to collect tracing information from the execution, such as latency and memory overhead.
message RunMetadata {
  StepStats step_stats = 1;
  CostGraphDef cost_graph = 2;
  repeated GraphDef partition_graphs = 3;
}
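A minimal sketch (the toy graph here is an arbitrary example) of collecting RunMetadata together with RunOptions and inspecting the partition graphs and per-device step stats:

import tensorflow as tf

a = tf.constant([1.0, 2.0])
b = tf.constant([3.0, 4.0])
c = a + b

with tf.Session() as sess:
    options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE,
                            output_partition_graphs=True)
    run_metadata = tf.RunMetadata()
    sess.run(c, options=options, run_metadata=run_metadata)

    # partition_graphs: per-device GraphDefs that were actually executed.
    print(len(run_metadata.partition_graphs))
    # step_stats: per-node timing, grouped by device.
    for dev_stats in run_metadata.step_stats.dev_stats:
        print(dev_stats.device, len(dev_stats.node_stats))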
3. Source Code Analysis
session.h defines the Session::Run() API; the overload that accepts a RunOptions argument is shown below:
virtual Status Run(const RunOptions& run_options,
                   const std::vector<std::pair<string, Tensor> >& inputs,
                   const std::vector<string>& output_tensor_names,
                   const std::vector<string>& target_node_names,
                   std::vector<Tensor>* outputs, RunMetadata* run_metadata);
This section focuses on the source code behind two parameters: inter_op_thread_pool and use_run_handler_pool.
3.1 The inter_op_thread_pool parameter
As described in the earlier document on the NewSession flow, the thread pools created there are stored in the vector thread_pools_.
std::vector<std::pair<thread::ThreadPool*, bool>> thread_pools_;
When Session::Run is called, the parameters are validated first: inter_op_thread_pool must be no less than -1 and smaller than thread_pools_.size(), otherwise an error is returned.
if (run_options.inter_op_thread_pool() < -1 ||
    run_options.inter_op_thread_pool() >=
        static_cast<int32>(thread_pools_.size())) {
  run_state.executors_done.Notify();
  delete barrier;
  return errors::InvalidArgument("Invalid inter_op_thread_pool: ",
                                 run_options.inter_op_thread_pool());
}
For a valid value, TensorFlow uses the specified thread pool for the subsequent computation.
In the TensorFlow 1.12 code, inter_op_thread_pool = -1 is accepted; the computation is then carried out on the caller's (main) thread. Note from the snippet below that when -1 is given but the step involves more than one executor, the default pool thread_pools_[0] is still used.
thread::ThreadPool* pool =
    run_options.inter_op_thread_pool() >= 0
        ? thread_pools_[run_options.inter_op_thread_pool()].first
        : nullptr;

if (pool == nullptr) {
  if (executors_and_keys->items.size() > 1) {
    pool = thread_pools_[0].first;
  } else {
    VLOG(1) << "Executing Session::Run() synchronously!";
  }
}
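To see the parameter end to end, here is a hedged Python sketch (pool sizes and the graph are arbitrary examples): two inter-op thread pools are registered via ConfigProto.session_inter_op_thread_pool at session creation, and RunOptions.inter_op_thread_pool selects one per Run call.

import tensorflow as tf

# Register two inter-op thread pools at session creation time.
config = tf.ConfigProto()
config.session_inter_op_thread_pool.add(num_threads=8)  # pool 0
config.session_inter_op_thread_pool.add(num_threads=2)  # pool 1

a = tf.random_normal([1000, 1000])
b = tf.random_normal([1000, 1000])
c = tf.matmul(a, b)

with tf.Session(config=config) as sess:
    # Run this step on pool 1; -1 would request the caller's thread.
    options = tf.RunOptions(inter_op_thread_pool=1)
    sess.run(c, options=options)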
3.2 use_run_handler_pool
This configuration parameter only takes effect when the global thread pool is used.
Note: an introduction to the global thread pool can be found at http://www.reibang.com/p/e9fd4f0d6bd1
std::unique_ptr<RunHandler> handler;
if (ShouldUseRunHandlerPool() &&
    run_options.experimental().use_run_handler_pool()) {
  handler = GetOrCreateRunHandlerPool(options_)->Get();
}
auto* handler_ptr = handler.get();
The setting mainly affects the RunHandler used during Session::Run(); the class is defined in tensorflow/core/framework/run_handler.h:
class RunHandler {
 public:
  void ScheduleInterOpClosure(std::function<void()> fn);

  ~RunHandler();

 private:
  class Impl;
  friend class RunHandlerPool::Impl;

  explicit RunHandler(Impl* impl);

  Impl* impl_;  // NOT OWNED.
};
When use_run_handler_pool is set, a RunHandler is obtained through GetOrCreateRunHandlerPool. A single global RunHandlerPool is maintained across sessions, which is where the performance benefit comes from.
static RunHandlerPool* GetOrCreateRunHandlerPool(
    const SessionOptions& options) {
  static RunHandlerPool* pool =
      new RunHandlerPool(NumInterOpThreadsFromSessionOptions(options));
  return pool;
}
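A minimal Python sketch of enabling the flag. Based on the checks above, it is assumed here that the session keeps the default global thread pool (no use_per_session_threads and no session_inter_op_thread_pool entries); otherwise ShouldUseRunHandlerPool() will not take this path.

import tensorflow as tf

a = tf.random_normal([1000, 1000])
b = tf.random_normal([1000, 1000])
c = tf.matmul(a, b)

# Default session config: the global inter-op thread pool is used.
with tf.Session() as sess:
    options = tf.RunOptions(
        experimental=tf.RunOptions.Experimental(use_run_handler_pool=True))
    sess.run(c, options=options)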
4. Usage Examples
4.1 timeline
The runtime trace can be saved as a JSON file and analyzed by opening it in chrome://tracing:
import tensorflow as tf
from tensorflow.python.client import timeline

a = tf.random_normal([1, 100])
b = tf.random_normal([1, 100])
res = tf.add(a, b)

with tf.Session() as sess:
    options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
    run_metadata = tf.RunMetadata()
    sess.run(res, options=options, run_metadata=run_metadata)

    # Convert the collected step stats into the Chrome trace (JSON) format.
    fetched_timeline = timeline.Timeline(run_metadata.step_stats)
    chrome_trace = fetched_timeline.generate_chrome_trace_format()
    with open('timeline.json', 'w') as f:
        f.write(chrome_trace)
To merge the traces from multiple session.run calls, the TimeLiner class below can be used: after each session.run, convert the trace to its Chrome trace JSON string, pass it to TimeLiner's update_timeline method, and finally call save to write the merged timeline to a JSON file (a usage sketch follows the class definition):
import json


class TimeLiner:
    _timeline_dict = None

    def update_timeline(self, chrome_trace):
        chrome_trace_dict = json.loads(chrome_trace)
        if self._timeline_dict is None:
            # First trace: keep everything, including metadata events.
            self._timeline_dict = chrome_trace_dict
        else:
            # Later traces: append only the timed events.
            for event in chrome_trace_dict['traceEvents']:
                if 'ts' in event:
                    self._timeline_dict['traceEvents'].append(event)

    def save(self, f_name):
        print(f_name)
        with open(f_name, 'w') as f:
            json.dump(self._timeline_dict, f)
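A usage sketch, reusing res, the timeline module, and the FULL_TRACE options from 4.1 (the number of runs and the output file name are arbitrary choices):

many_runs_timeline = TimeLiner()

with tf.Session() as sess:
    options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
    for _ in range(5):
        run_metadata = tf.RunMetadata()
        sess.run(res, options=options, run_metadata=run_metadata)
        fetched_timeline = timeline.Timeline(run_metadata.step_stats)
        chrome_trace = fetched_timeline.generate_chrome_trace_format()
        many_runs_timeline.update_timeline(chrome_trace)

many_runs_timeline.save('timeline_merged.json')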