介紹
在Caffe2的設(shè)計中邑跪,一切操作皆是Op画畅。其中數(shù)據(jù)IO操作相關(guān)的ops有CreateDBOp、PrefetchOperator症脂、ImageInputOp等淫僻;用于初始化數(shù)據(jù)的操作ops有UniformFillOp、GaussianFillOp等棕所;
還有真正用于計算的Ops像FullyConnectedOp悯辙、FullyConnectedGradientOp笑撞、ConvOp、ConvGradientOp等坚踩;用于多節(jié)點訓(xùn)練時梯度融合的操作如BroadcastOp瓤狐、AllreduceOp、AllgatherOp等嗓节;
其它則還有些用于各種參數(shù)更新算法的Ops如AdamOp皆警、AdagradOp、MomentumSGDOp等绸罗。
一般所有的Caffe2 Python前端代碼寫畢后豆瘫,兩張建好的static Graph就在背后生成出來了。一張Graph為init_net育灸,上面包含了一些初始化操作像Parameters的filler操作等或者多節(jié)點通信時
整體通信環(huán)境構(gòu)建時的操作像CreateCommonWorld昵宇、CloneCommonWorld等趟薄。通常init_net上的ops都只需執(zhí)行一次完成初始化目的即可典徊。另一張Graph則為net卒落,它上面包含了我們正式模型的基本所
用ops像一些用于計算的Ops如Conv/FC等,還有些則是用于更新參數(shù)的Ops像AdamOp/MomentumSGDOp等也切。它在init_net執(zhí)行完后再執(zhí)行腰湾,需要執(zhí)行多次,最終得到訓(xùn)練好的模型參數(shù)倒槐。
可以說Operator是Caffe2中最核心的元素之一附井,往往也是我們使用一個framework寫AI程序所需接觸最頻繁的一個部件永毅。
Caffe2中核心Operator的實現(xiàn)在兩個class里面:OperatorBase與Operator,其中Operator是OperatorBase的一個子類着逐。大多數(shù)我們用到的Op都是Operator的子類,只需要實現(xiàn)它的若干override
函數(shù)即可峰鄙,但也有些特別情況像PrefetchOp直接從OperatorBase繼承而來吟榴,因為我們想直接掌握它之上Op執(zhí)行時的輸出同步情況囊扳。
以下為caffe2核心代碼中與Operator相關(guān)的一些代碼文件。
core git:(master) ? ls operator
operator_c10wrapper.cc operator.cc operator_gradient.h operator_schema.cc operator_schema_test.cc
operator_c10wrapper.h operator_gpu_test.cc operator.h operator_schema.h operator_test.cc
我們將主要介紹一下OperatorBase與Operator這兩個類的基本APIs及其實現(xiàn)狭瞎,同時稍帶著也會看些其它像schema熊锭、gradient等之類的operator features雪侥。
OperatorBase
Observer和ObservableBase類
首先如下所示速缨,OperatorBase是Observable<OperatorBase>(本質(zhì)上是一個ObserverBase的一個實例類)的一個子類。這里使用了Observer的設(shè)計模式仿粹。即它考慮將每個Operator都設(shè)計為可觀察的對象原茅,以便在它執(zhí)行操作時有相關(guān)的Observer
可以去偵測它的執(zhí)行狀態(tài)擂橘。
class CAFFE2_API OperatorBase;
typedef ObserverBase<OperatorBase> OperatorObserver;
class CAFFE2_API OperatorBase : public Observable<OperatorBase> {
public:
explicit OperatorBase(const OperatorDef& operator_def, Workspace* ws);
virtual ~OperatorBase() noexcept {}
如下為ObservableBase與Observer的定義,還是蠻簡單的契讲,我們可據(jù)此完成我們自己對某一類型T的observer滑频。在Caffe2中真正使用Observer來去監(jiān)控的對象有兩個峡迷,一個為這里介紹的Operator你虹,另一個則為我們以后將會去分析的network彤避。
template <class T>
class ObserverBase {
public:
explicit ObserverBase(T* subject) : subject_(subject) {}
virtual void Start() {}
virtual void Stop() {}
virtual std::string debugInfo() {
return "Not implemented.";
}
virtual ~ObserverBase() noexcept {};
T* subject() const {
return subject_;
}
virtual std::unique_ptr<ObserverBase<T>> rnnCopy(T* subject, int rnn_order)
const {
return nullptr;
};
protected:
T* subject_;
};
/**
* Inherit to make your class observable.
*/
template <class T>
class Observable {
public:
Observable() = default;
Observable(Observable&&) = default;
Observable& operator =(Observable&&) = default;
virtual ~Observable() = default;
C10_DISABLE_COPY_AND_ASSIGN(Observable);
using Observer = ObserverBase<T>;
/* Returns a reference to the observer after addition. */
const Observer* AttachObserver(std::unique_ptr<Observer> observer) {
CAFFE_ENFORCE(observer, "Couldn't attach a null observer.");
std::unordered_set<const Observer*> observers;
for (auto& ob : observers_list_) {
observers.insert(ob.get());
}
const auto* observer_ptr = observer.get();
if (observers.count(observer_ptr)) {
return observer_ptr;
}
observers_list_.push_back(std::move(observer));
UpdateCache();
return observer_ptr;
}
/**
* Returns a unique_ptr to the removed observer. If not found, return a
* nullptr
*/
std::unique_ptr<Observer> DetachObserver(const Observer* observer_ptr) {
for (auto it = observers_list_.begin(); it != observers_list_.end(); ++it) {
if (it->get() == observer_ptr) {
auto res = std::move(*it);
observers_list_.erase(it);
UpdateCache();
return res;
}
}
return nullptr;
}
private:
inline static void StartObserver(Observer* observer) {
try {
observer->Start();
} catch (const std::exception& e) {
LOG(ERROR) << "Exception from observer: " << e.what();
} catch (...) {
LOG(ERROR) << "Exception from observer: unknown";
}
}
inline static void StopObserver(Observer* observer) {
try {
observer->Stop();
} catch (const std::exception& e) {
LOG(ERROR) << "Exception from observer: " << e.what();
} catch (...) {
LOG(ERROR) << "Exception from observer: unknown";
}
}
.............
.............
private:
// an on-stack cache for fast iteration;
// ideally, inside StartAllObservers and StopAllObservers,
// we should never access observers_list_
Observer* observer_cache_;
size_t num_observers_ = 0;
protected:
std::vector<std::unique_ptr<Observer>> observers_list_;
};
Input/Output blobs及Arguments基本操作
OperatorBase里面包含了些最基本的Operator所需的操作像輸入、輸出Blob處理(獲得或其之上類型查詢圆米、Copy等)娄帖,Operator參數(shù)處理等。
以下為它用于參數(shù)處理的兩個示例API函數(shù)诈嘿。
/** @brief Checks if the operator has an argument of the given name.
*/
inline bool HasArgument(const string& name) const {
CAFFE_ENFORCE(operator_def_, "operator_def was null!");
return ArgumentHelper::HasArgument(*operator_def_, name);
}
// Functions that deal with arguments. Basically, this allows us to map an
// argument name to a specific type of argument that we are trying to access.
template <typename T>
inline T GetSingleArgument(const string& name, const T& default_value) const {
CAFFE_ENFORCE(operator_def_, "operator_def was null!");
return ArgumentHelper::GetSingleArgument<OperatorDef, T>(
*operator_def_, name, default_value);
}
下面則為它處理Input/Output blob的一些API函數(shù)奖亚。其中inputs_/outputs_分別是一個operator上所具有的blobs成員屬性佩耳。有了之前幾章講過的Blob與Tensor的知識后谭跨,這些函數(shù)也就比較好懂了螃宙。
// Get the inputs and outputs as specific types.
template <typename T>
inline const T& Input(int idx) {
static_assert(
!std::is_same<T, Tensor>::value,
"You should use Input<Tensor>(int, DeviceType) for "
"Tensor.");
DCHECK_LT(idx, inputs_.size());
try {
return inputs_.at(idx)->template Get<T>();
} catch (::caffe2::EnforceNotMet& enf) {
if (has_debug_def()) {
enf.AppendMessage(".\nOffending Blob name: ");
enf.AppendMessage(debug_def().input(idx));
enf.AppendMessage(".\n");
}
throw enf;
}
}
// TODO(jerryzh): Remove template
// and the type argument?
// This is to keep the API changes minimal and make refactoring
// a bit easier
template <typename T>
inline const T& Input(int idx, DeviceType type) {
static_assert(
std::is_same<T, Tensor>::value,
"Input(int, DeviceType) is only available for Tensor");
DCHECK_LT(idx, inputs_.size());
try {
// TODO(jerryzh): We'll need to check device type in Get<T>() later
// Get<T>() -> Get<T>(type)
const auto& tensor = inputs_.at(idx)->template Get<T>();
return tensor;
} catch (::caffe2::EnforceNotMet& enf) {
if (has_debug_def()) {
enf.AppendMessage(".\nOffending Blob name: ");
enf.AppendMessage(debug_def().input(idx));
enf.AppendMessage(".\n");
}
throw enf;
}
}
template <typename T>
inline T* Output(int idx, DeviceType type) {
static_assert(
std::is_same<T, Tensor>::value,
"Output(int, DeviceType) is only available for Tensor");
// When you get a Tensor here it is not fully initialized
return BlobGetMutableTensor(outputs_.at(idx), type);
}
inline Tensor*
OutputTensor(int idx, at::IntList dims, at::TensorOptions options) {
CAFFE_ENFORCE_WITH_CALLER(
options.device_opt() != c10::nullopt,
"device must be provided in option.");
return BlobGetMutableTensor(outputs_.at(idx), dims, options);
}
template <typename T>
inline T* Output(int idx, T* allocated) {
outputs_.at(idx)->Reset(allocated);
return allocated;
}
其它與輸入挂捅、輸出相關(guān)的API函數(shù)還有如下一些utilities函數(shù)堂湖。(當(dāng)然這里只是一部分无蜂,但大致皆是如此,多是些淺顯易懂的斥季,多是像OutputIsType這樣直接使用了Blob的成員函數(shù))。
template <typename T>
inline bool OutputIsType(int idx) {
static_assert(
!std::is_same<T, Tensor>::value,
"You should use OutputIsTensorType(int, DeviceType) for "
"Tensor.");
return outputs_.at(idx)->template IsType<T>();
}
inline bool OutputIsTensorType(int idx, DeviceType type) {
return BlobIsTensorType(*outputs_.at(idx), type);
}
inline int InputSize() const {
return inputs_.size();
}
inline int OutputSize() const {
return outputs_.size();
Event處理相關(guān)async操作函數(shù)
下面一些與Event相關(guān)的操作才是真正開始有趣的東西谤专。每個Caffe2 op都有一個Event成員(Event在將來我們也會單獨拿出來介紹)午绳,用于同步op間的依賴執(zhí)行操作箱叁。它是一個op async執(zhí)行時保證parent op與child op之間同步的產(chǎn)物。如果讀者
熟悉CUDA編程APIs算色,那么對于cudaEvent一定不陌生螟够,本質(zhì)上在CUDAOperator當(dāng)中,它之上的event_就是一個cudaEvent的wrapper若河。
簡單說來寞宫,async執(zhí)行時辈赋,我們有些op可能需要去check(Query或Wait)父類的event以確定它是否執(zhí)行完了(即自己的輸入有保證了)。同時它們也要在執(zhí)行過后(或者是完成了自己op執(zhí)行命令的提交)悟民,更新自己的event篷就,以對自己的childeren operators發(fā)揮影響竭业。
virtual void WaitEvent(const Event& ev, int /*stream_id */ = -1) {
ev.Finish();
}
inline void Wait(const OperatorBase& other, int stream_id = -1) {
if (!other.IsEventDisabled()) {
WaitEvent(other.event(), stream_id);
}
}
virtual void WaitEvents(
const std::vector<const Event*>& events,
int /*stream_id*/ = -1) {
for (const auto& ev : events) {
ev->Finish();
}
}
virtual void Finish() {
if (event_) {
event_->Finish();
}
}
const Event& event() const {
CAFFE_ENFORCE(event_, "Event is disabled");
return *event_;
}
Event& event() {
CAFFE_ENFORCE(event_, "Event is disabled");
return *event_;
}
void ResetEvent() {
if (event_) {
event_->Reset();
}
}
void DisableEvent() {
event_ = nullptr;
}
bool IsEventDisabled() const {
return !event_;
}
protected:
virtual void RecordEvent(const char* /*err_msg*/ = nullptr) {
CAFFE_NOT_IMPLEMENTED;
}
void SetEventFinished(const char* err_msg = nullptr) {
if (event_) {
event_->SetFinished(err_msg);
}
}
void SetEventFinishedWithException(const char* err_msg = nullptr) {
if (event_) {
event_->SetFinishedWithException(err_msg);
}
}
下面的兩個APIs函數(shù)同樣與op的async執(zhí)行相關(guān)未辆。在OperatorBase中將它一律簡單設(shè)為了false即op默認(rèn)不support這兩種feature(至于feature具體為何義,我們將在Operator講解時進行闡述)钾麸。
可見大部分與async執(zhí)行相關(guān)的Op操作都被放入了Operator中饭尝,OperatorBase這里有API人,但其實形同虛設(shè)实撒。所以如果你在實現(xiàn)自己的Op時涉瘾,如果想遵循Caffe2中已有一套Op async執(zhí)行邏輯,那么可將自己op實現(xiàn)為
operator的子類(這是絕大多數(shù)時候的正常做法负敏,畢竟caffe2是assume用戶會直接遵循自己關(guān)于event/op async執(zhí)行的一套做法的即operator中的那些默認(rèn)做法)秘蛇;但如果你對某一Device context上的Op async執(zhí)行
有著與framework original design略為不同的想法時赁还,那么你就可以將op實現(xiàn)為OperatorBase的子類,然后去重載這幾個與async執(zhí)行蹈胡、Events操作相關(guān)的APIs函數(shù)朋蔫。
virtual bool HasAsyncPart() const {
return false;
}
virtual bool SupportsAsyncScheduling() const {
return false;
}
Utility函數(shù)
下面是一些用于出錯時debug的APIs函數(shù)斑举,無關(guān)核心病涨。
inline const OperatorDef& debug_def() const {
CAFFE_ENFORCE(has_debug_def(), "operator_def was null!");
return *operator_def_;
}
inline void set_debug_def(
const std::shared_ptr<const OperatorDef>& operator_def) {
operator_def_ = operator_def;
}
inline bool has_debug_def() const {
return operator_def_ != nullptr;
}
public:
void RecordLastFailedOpNetPosition() {
if (net_position_ != kNoNetPositionSet) {
VLOG(1) << "Operator with id " << net_position_ << " failed";
operator_ws_->last_failed_op_net_position = net_position_;
} else {
VLOG(1) << "Failed operator doesn't have id set";
}
}
int net_position() const {
return net_position_;
}
void set_net_position(int idx) {
net_position_ = idx;
}
Caffe2中每個Operator都有Device Context的概念即此Op操作是在哪個device之上進行的既穆。Device的具體信息則可從DeviceOption中獲得。
const DeviceOption& device_option() const {
return device_option_;
}
另外每個Operator又可根據(jù)其Engine不同有著不一樣的實現(xiàn)(在同一Device context下面)励两。這個設(shè)計follow了之前Caffe中的思路(如Layer的類型可以為GPU engine或CPU engine等)当悔。
要知道像其它Caffe2組件如Blob/Tensor/Context等使用注冊機制來完成具體定義一樣,Operator也是在完成后需要去注冊在某一device相關(guān)的OperatorRegistry下面的嗅骄。一般Operator name會
作為此operator實現(xiàn)item的key饼疙。而若我們在同一device下面對同一類型op(有同樣的名字op_name)有著不同的實現(xiàn)(比如Allreduce操作在CPU設(shè)計上可以使用MPI窑眯,也可使用Gloo來完成其底層的具體通信aggregation),
那我們就可使用engine來區(qū)別這一不同的實現(xiàn)炊林。
下面是與engine相關(guān)的一些base operator函數(shù)卷要。
void annotate_engine(const std::string& engine) {
engine_ = engine;
}
const std::string& engine() const {
return engine_;
}
像其它許多深度學(xué)習(xí)Framework一樣却妨,Caffe2對于Op操作執(zhí)行在GPU device上有著很強的假設(shè)(very bad assumptions,它讓CUDA生態(tài)在DL領(lǐng)域過于強大了)倍权。如下函數(shù)即為此類假設(shè)的一個縮影捞烟。
// Checks whether stream is ready to execute new computation,
// used in stream allocation optimization to skip stream that is currently
// busy. Depends on context and operator's device, returns true by default
virtual bool IsStreamFree(int /* unused */) const {
return true;
}
核心的Run及RunAsync函數(shù)
最終我們看下關(guān)鍵的Op真正運行的兩個函數(shù)吧题画。如下示,其中Run一般為op sync執(zhí)行時所調(diào)用的缩幸,而RunAsync則是其async執(zhí)行時所調(diào)用的竞思「桥纾可以看出在執(zhí)行時都默認(rèn)我們會將操作提交到一stream上,這顯然亦是
由CUDA編程API中所用到的cuda_stream的思想而來距辆。。
virtual bool Run(int /* unused */ /*stream_id*/ = 0) {
CAFFE_NOT_IMPLEMENTED;
}
// RunAsync, if implemenented by the specific operators, will schedule the
// computation on the corresponding context and record the event in its
// event_ member object. If the specific operator does not support RunAsync,
// it will simply be synchronous as a fallback.
virtual bool RunAsync(int stream_id = 0) {
try {
auto result = Run(stream_id);
if (result) {
if (HasAsyncPart()) {
RecordEvent();
} else {
SetEventFinished();
}
} else {
SetEventFinished(getErrorMsg().c_str());
}
return result;
} catch (EnforceNotMet& err) {
SetEventFinishedWithException(err.what());
throw;
} catch (const std::exception& err) {
SetEventFinishedWithException(err.what());
throw;
} catch (...) {
SetEventFinishedWithException(getErrorMsg().c_str());
throw;
}
}