Operator
Basic characteristics
Most of the operators we deal with in Caffe2 are subclasses of class Operator, which is in turn a subclass of the class OperatorBase covered earlier in this series. Below we walk through the main interfaces it adds and what they mean.
Unlike OperatorBase, Operator is a template class parameterized on the device context; accordingly, it holds a member named context_.
// Operator is the class that you usually want to derive, if your operator will
// run on different devices. You should then implement the RunOnDevice()
// function.
template <class Context>
class Operator : public OperatorBase {
 public:
  explicit Operator(const OperatorDef& operator_def, Workspace* ws)
      : OperatorBase(operator_def, ws), context_(operator_def.device_option()) {
    // In the constructor, we switch to the device so that the child class
    // constructors will run on that device.
    context_.SwitchToDevice(0);
  }
  ~Operator() noexcept override {}

  // ...

  const Context* getContext() const {
    return &context_;
  }

 protected:
  void RecordEvent(const char* err_msg = nullptr) final {
    if (event_) {
      context_.Record(event_.get(), err_msg);
    }
  }

  Context context_;
};
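As the header comment says, a concrete operator normally just derives from Operator&lt;Context&gt; and implements RunOnDevice(). A minimal sketch of what that looks like (the operator name ScaleByTwoOp and its float-only I/O are invented for illustration, not taken from the Caffe2 sources):

// Hypothetical example of a derived operator: a CPU op that multiplies its
// single float input by two.
class ScaleByTwoOp final : public Operator<CPUContext> {
 public:
  USE_OPERATOR_FUNCTIONS(CPUContext);  // pulls in Input()/Output()/context_ helpers

  ScaleByTwoOp(const OperatorDef& def, Workspace* ws)
      : Operator<CPUContext>(def, ws) {}

  bool RunOnDevice() override {
    const auto& X = Input(0);                            // input tensor on CPU
    auto* Y = Output(0, X.sizes(), at::dtype<float>());  // output with the same shape
    const float* x = X.data<float>();
    float* y = Y->mutable_data<float>();
    for (int64_t i = 0; i < X.numel(); ++i) {
      y[i] = 2.0f * x[i];
    }
    return true;  // returning false marks the op as failed
  }
};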
Ordinary input/output operator functionality, on the other hand, is simply delegated to the parent class OperatorBase, as shown below:
inline const Tensor& Input(
    int idx,
    DeviceType type = Context::GetDeviceType()) {
  return OperatorBase::template Input<Tensor>(idx, type);
}

inline Tensor* Output(int idx, at::IntList dims, at::TensorOptions options) {
  if (options.device_opt() == c10::nullopt) {
    return OperatorBase::OutputTensor(
        idx, dims, options.device(context_.device()));
  }
  return OperatorBase::OutputTensor(idx, dims, options);
}

inline Tensor* Output(int idx, DeviceType type = Context::GetDeviceType()) {
  return OperatorBase::template Output<Tensor>(idx, type);
}
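Note the defaulted device argument: inside a GPU op's RunOnDevice() these wrappers hand back tensors on the op's own device, but an auxiliary CPU-resident input can still be requested explicitly. A hypothetical usage fragment inside some Operator&lt;CUDAContext&gt;::RunOnDevice() (assuming the second input really does live on the CPU):

// Inside a hypothetical Operator<CUDAContext>::RunOnDevice():
const auto& X = Input(0);             // defaults to this op's device, i.e. a CUDA tensor
const auto& lengths = Input(1, CPU);  // explicitly fetch a CPU-resident tensor
auto* Y = Output(0, X.sizes(), at::dtype<float>().device(CUDA));  // options carry the device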
Event-related async handling
The handling of Events and the other pieces related to asynchronous operator execution is not very different from OperatorBase, but it ties event handling much more closely to the concrete device context. So if your async operator derives from Operator, you can use the event-handling functions it provides directly; if you derive from OperatorBase instead, you also have to account for the operator's actual device type and pass the device context along when handling events.
void WaitEvent(const Event& ev, int stream_id = -1) final {
  if (stream_id >= 0) {
    context_.SwitchToDevice(stream_id);
  }
  context_.WaitEvent(ev);
}

void WaitEvents(const std::vector<const Event*>& events, int stream_id = -1)
    final {
  if (stream_id >= 0) {
    context_.SwitchToDevice(stream_id);
  }
  for (const auto& ev : events) {
    context_.WaitEvent(*ev);
  }
}
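To make the binding concrete, here is a rough sketch of how an event-driven executor could chain two ops through these hooks. This is not the actual Caffe2 async net code; it only assumes the OperatorBase::event() accessor plus the WaitEvent/RunAsync methods shown in this post, and the stream ids are illustrative.

// producer/consumer are operators already created and owned by a net.
void ScheduleConsumerAfterProducer(OperatorBase* producer, OperatorBase* consumer) {
  producer->RunAsync(/*stream_id=*/0);  // records producer's event if it has an async part
  // Make the consumer's stream wait on the producer's event; on devices with
  // streams (e.g. CUDA) this wait is non-blocking, on CPU it simply blocks.
  consumer->WaitEvent(producer->event(), /*stream_id=*/1);
  consumer->RunAsync(/*stream_id=*/1);  // stream ordering now enforces the dependency
}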
Run and RunAsync
If any two functions are the core of an operator and the ones users touch most often, they are surely Run and RunAsync. Both of them wrap the RunOnDevice function defined by the concrete Operator subclass; Run executes the op synchronously, while RunAsync executes it asynchronously.
Below is the synchronous execution path.
// The run function of Operator switches to the device, and then carries out
// the actual computation with RunOnDevice(). You should implement RunOnDevice
// instead of Run().
// Note: Run does not update operator's event and can be used only with
// non-async executors that do not rely on events
bool Run(int stream_id = 0) final {
  try {
    StartAllObservers();

    context_.SwitchToDevice(stream_id);
    bool result = RunOnDevice();
    if (!result) {
      this->RecordLastFailedOpNetPosition();
    }
    context_.FinishDeviceComputation(); // throws on error

    StopAllObservers();

    return result;
  } catch (EnforceNotMet& err) {
    if (has_debug_def()) {
      err.AppendMessage(
          "Error from operator: \n" + ProtoDebugString(debug_def()));
      AddRelatedBlobInfo(&err);
    }
    this->RecordLastFailedOpNetPosition();
    StopAllObservers();
    throw;
  } catch (...) {
    this->RecordLastFailedOpNetPosition();
    StopAllObservers();
    throw;
  }
}
Below is RunAsync, the function that performs asynchronous execution. As you can see, most of the event-related functions we defined in Operator are used here.
bool RunAsync(int stream_id = 0) final {
  try {
    StartAllObservers();

    context_.SwitchToDevice(stream_id);
    auto result = RunOnDevice();
    if (result) {
      if (HasAsyncPart()) {
        RecordEvent();
      } else {
        // Manually set CPU operator's event status to finished,
        // unless this is an async CPU operator
        SetEventFinished();
      }
    } else {
      SetEventFinished(getErrorMsg().c_str());
      this->RecordLastFailedOpNetPosition();
    }

    StopAllObservers();

    return result;
  } catch (EnforceNotMet& err) {
    if (has_debug_def()) {
      err.AppendMessage(
          "Error from operator: \n" + ProtoDebugString(debug_def()));
      AddRelatedBlobInfo(&err);
    }
    SetEventFinishedWithException(err.what());
    this->RecordLastFailedOpNetPosition();
    StopAllObservers();
    throw;
  } catch (const std::exception& err) {
    SetEventFinishedWithException(err.what());
    this->RecordLastFailedOpNetPosition();
    StopAllObservers();
    throw;
  } catch (...) {
    SetEventFinishedWithException(getErrorMsg().c_str());
    this->RecordLastFailedOpNetPosition();
    StopAllObservers();
    throw;
  }
}
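As a rough usage sketch of the two entry points (op is an already-created operator; the real executors in Caffe2 are of course more involved, and the event().Finish()/Query() calls are assumed from the Event API rather than shown in this file):

// Synchronous path: Run() only returns after FinishDeviceComputation(),
// so the outputs are ready when it returns.
bool ok = op->Run(/*stream_id=*/0);

// Asynchronous path: RunAsync() may return before the device work is done;
// the recorded event is what tells us when the op actually finished.
bool scheduled = op->RunAsync(/*stream_id=*/0);
if (scheduled) {
  op->event().Finish();  // block until the async part completes
  ok = (op->event().Query() == EventStatus::EVENT_SUCCESS);
}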
RunOnDevice, the function that does the real work, is a pure virtual function here; it does nothing itself and is only the interface.
virtual bool RunOnDevice() = 0;
Operator async properties and async execution
There are two concepts to keep straight about asynchronous Operator execution. The first is HasAsyncPart, which is async execution in the usual sense: instead of waiting for the operation to actually finish, the call returns a handle immediately, and you check later whether the operation has really completed. The second is AsyncScheduling, which is about whether an op can be scheduled onto a pool or stream in an available state without waiting for its inputs to be ready (i.e., without waiting for its parent ops to finish); this design is clearly inspired by the CUDA programming model. Of course, it cannot truly ignore whether the inputs are ready; it merely hands the responsibility for synchronization from the framework over to CUDA.
// Events of operators that don't have async parts are automatically set
// to finished state by RunAsync.
// Defaulting to the value from context (true for CUDA, false for CPU).
// Override in case of async CPU operators
// Async CPU operators are expected to catch all exceptions in async parts
// and set Event to finished/failed state with Event::SetFinished or
// SetFinishedWithException call.
bool HasAsyncPart() const override {
  return context_.HasAsyncPartDefault();
}

// Returns whether operator's RunOnDevice schedules async on device part and
// can be run without waiting for parent operator's async part to be finished
// on the same device.
// Note: when true, RunOnDevice must not access the content of the input blobs
// as they might not be computed yet
// Note: when true, operator's device needs to support async scheduling:
//  - supports concept of streams: async ops scheduled on the same stream are
//    guaranteed to be executed in the same order they were scheduled
//  - provides non-blocking cross device/cross stream synchronization
//    primitives
//
// By default, assuming an op with an async part can be scheduled
// asynchronously if device supports async scheduling
bool SupportsAsyncScheduling() const override {
  return HasAsyncPart() && context_.SupportsAsyncScheduling();
}
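Tying this back to the comment above: an async CPU operator is one that overrides HasAsyncPart() to return true and then finishes its own event when the deferred work completes. A hand-wavy sketch follows; the detached std::thread and the empty work body are only for illustration (real async CPU ops in Caffe2 go through the executor's task queues), and the Event::SetFinished/SetFinishedWithException calls are the ones the comment refers to.

#include <thread>

// Hypothetical async CPU operator: RunOnDevice() only launches the work and
// returns immediately; the background thread later finishes the event.
class MyAsyncCPUOp final : public Operator<CPUContext> {
 public:
  using Operator<CPUContext>::Operator;

  bool HasAsyncPart() const override {
    return true;  // so RunAsync() will not SetEventFinished() on our behalf
  }

  bool RunOnDevice() override {
    std::thread([this]() {
      try {
        // ... the actual deferred computation would go here ...
        event().SetFinished();                       // report success via the event
      } catch (const std::exception& e) {
        event().SetFinishedWithException(e.what());  // report failure via the event
      }
    }).detach();
    return true;  // scheduling succeeded; completion is signalled by the event
  }
};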
Normally, after running an op synchronously we still need a barrier to guarantee that it has really finished on the device, which again borrows from the CUDA asynchronous programming model. The framework is practically a wrapper around the CUDA programming model! Heh, this is also what other CPU/ASIC vendors dislike about it: the so-called moat of the NVIDIA CUDA ecosystem.
void SyncDeviceBarrierForObservers() override {
  context_.FinishDeviceComputation();
}
Operator-related utilities
Below are some utilities related to registering operators for different device contexts. The approach is very similar to the factory pattern used for Layer/Net/Solver in the original Caffe, which is no surprise since they come from the same author.
The following are the utilities for registering CPU operators; the CUDA/HIP/IDEEP and other variants are quite similar.
// The operator registry. Since we are not expecting a great number of devices,
// we will simply have an if-then type command and allocate the actual
// generation to device-specific registerers.
// Note that although we have CUDA and CUDNN here, the registerers themselves do
// not depend on specific cuda or cudnn libraries. This means that we will be
// able to compile it even when there is no cuda available - we simply do not
// link any cuda or cudnn operators.
C10_DECLARE_REGISTRY(
    CPUOperatorRegistry,
    OperatorBase,
    const OperatorDef&,
    Workspace*);

#define REGISTER_CPU_OPERATOR_CREATOR(key, ...) \
  C10_REGISTER_CREATOR(CPUOperatorRegistry, key, __VA_ARGS__)

#define REGISTER_CPU_OPERATOR(name, ...)                           \
  C10_IMPORT void CAFFE2_PLEASE_ADD_OPERATOR_SCHEMA_FOR_##name();  \
  static void CAFFE2_UNUSED CAFFE_ANONYMOUS_VARIABLE_CPU##name() { \
    CAFFE2_PLEASE_ADD_OPERATOR_SCHEMA_FOR_##name();                \
  }                                                                \
  C10_REGISTER_CLASS(CPUOperatorRegistry, name, __VA_ARGS__)

#define REGISTER_CPU_OPERATOR_STR(str_name, ...) \
  C10_REGISTER_TYPED_CLASS(CPUOperatorRegistry, str_name, __VA_ARGS__)

#define REGISTER_CPU_OPERATOR_WITH_ENGINE(name, engine, ...) \
  C10_REGISTER_CLASS(CPUOperatorRegistry, name##_ENGINE_##engine, __VA_ARGS__)
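As a usage example, registering the hypothetical ScaleByTwoOp from earlier would look roughly like this in its .cc file; OPERATOR_SCHEMA is the companion macro that emits the CAFFE2_PLEASE_ADD_OPERATOR_SCHEMA_FOR_##name symbol the registration macro references:

// Hypothetical registration, reusing the invented ScaleByTwoOp from above.
OPERATOR_SCHEMA(ScaleByTwo)
    .NumInputs(1)
    .NumOutputs(1)
    .IdenticalTypeAndShape();

REGISTER_CPU_OPERATOR(ScaleByTwo, ScaleByTwoOp);
// A CUDA implementation would be registered analogously with
// REGISTER_CUDA_OPERATOR(ScaleByTwo, ...) against the CUDA registry.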