Caffe2核心代碼解析系列之六:Operator其一

介紹

在Caffe2的設(shè)計中邑跪,一切操作皆是Op画畅。其中數(shù)據(jù)IO操作相關(guān)的ops有CreateDBOp、PrefetchOperator症脂、ImageInputOp等淫僻;用于初始化數(shù)據(jù)的操作ops有UniformFillOp、GaussianFillOp等棕所;
還有真正用于計算的Ops像FullyConnectedOp悯辙、FullyConnectedGradientOp笑撞、ConvOp、ConvGradientOp等坚踩;用于多節(jié)點訓(xùn)練時梯度融合的操作如BroadcastOp瓤狐、AllreduceOp、AllgatherOp等嗓节;
其它則還有些用于各種參數(shù)更新算法的Ops如AdamOp皆警、AdagradOp、MomentumSGDOp等绸罗。

一般所有的Caffe2 Python前端代碼寫畢后豆瘫,兩張建好的static Graph就在背后生成出來了。一張Graph為init_net育灸,上面包含了一些初始化操作像Parameters的filler操作等或者多節(jié)點通信時
整體通信環(huán)境構(gòu)建時的操作像CreateCommonWorld昵宇、CloneCommonWorld等趟薄。通常init_net上的ops都只需執(zhí)行一次完成初始化目的即可典徊。另一張Graph則為net卒落,它上面包含了我們正式模型的基本所
用ops像一些用于計算的Ops如Conv/FC等,還有些則是用于更新參數(shù)的Ops像AdamOp/MomentumSGDOp等也切。它在init_net執(zhí)行完后再執(zhí)行腰湾,需要執(zhí)行多次,最終得到訓(xùn)練好的模型參數(shù)倒槐。

可以說Operator是Caffe2中最核心的元素之一附井,往往也是我們使用一個framework寫AI程序所需接觸最頻繁的一個部件永毅。

Caffe2中核心Operator的實現(xiàn)在兩個class里面:OperatorBase與Operator,其中Operator是OperatorBase的一個子類着逐。大多數(shù)我們用到的Op都是Operator的子類,只需要實現(xiàn)它的若干override
函數(shù)即可峰鄙,但也有些特別情況像PrefetchOp直接從OperatorBase繼承而來吟榴,因為我們想直接掌握它之上Op執(zhí)行時的輸出同步情況囊扳。

以下為caffe2核心代碼中與Operator相關(guān)的一些代碼文件。

core git:(master) ? ls operator
operator_c10wrapper.cc   operator.cc              operator_gradient.h      operator_schema.cc       operator_schema_test.cc
operator_c10wrapper.h    operator_gpu_test.cc     operator.h               operator_schema.h        operator_test.cc

我們將主要介紹一下OperatorBase與Operator這兩個類的基本APIs及其實現(xiàn)狭瞎,同時稍帶著也會看些其它像schema熊锭、gradient等之類的operator features雪侥。

OperatorBase

Observer和ObservableBase類

首先如下所示速缨,OperatorBase是Observable<OperatorBase>(本質(zhì)上是一個ObserverBase的一個實例類)的一個子類。這里使用了Observer的設(shè)計模式仿粹。即它考慮將每個Operator都設(shè)計為可觀察的對象原茅,以便在它執(zhí)行操作時有相關(guān)的Observer
可以去偵測它的執(zhí)行狀態(tài)擂橘。

class CAFFE2_API OperatorBase;
typedef ObserverBase<OperatorBase> OperatorObserver;

class CAFFE2_API OperatorBase : public Observable<OperatorBase> {
 public:
  explicit OperatorBase(const OperatorDef& operator_def, Workspace* ws);
  virtual ~OperatorBase() noexcept {}

如下為ObservableBase與Observer的定義,還是蠻簡單的契讲,我們可據(jù)此完成我們自己對某一類型T的observer滑频。在Caffe2中真正使用Observer來去監(jiān)控的對象有兩個峡迷,一個為這里介紹的Operator你虹,另一個則為我們以后將會去分析的network彤避。

template <class T>
class ObserverBase {
 public:
  explicit ObserverBase(T* subject) : subject_(subject) {}

  virtual void Start() {}
  virtual void Stop() {}

  virtual std::string debugInfo() {
    return "Not implemented.";
  }

  virtual ~ObserverBase() noexcept {};

  T* subject() const {
    return subject_;
  }

  virtual std::unique_ptr<ObserverBase<T>> rnnCopy(T* subject, int rnn_order)
      const {
    return nullptr;
  };

 protected:
  T* subject_;
};

/**
 *  Inherit to make your class observable.
 */
template <class T>
class Observable {
 public:
  Observable() = default;

  Observable(Observable&&) = default;
  Observable& operator =(Observable&&) = default;

  virtual ~Observable() = default;

  C10_DISABLE_COPY_AND_ASSIGN(Observable);

  using Observer = ObserverBase<T>;
  /* Returns a reference to the observer after addition. */
  const Observer* AttachObserver(std::unique_ptr<Observer> observer) {
    CAFFE_ENFORCE(observer, "Couldn't attach a null observer.");
    std::unordered_set<const Observer*> observers;
    for (auto& ob : observers_list_) {
      observers.insert(ob.get());
    }

    const auto* observer_ptr = observer.get();
    if (observers.count(observer_ptr)) {
      return observer_ptr;
    }
    observers_list_.push_back(std::move(observer));
    UpdateCache();

    return observer_ptr;
  }
  /**
   * Returns a unique_ptr to the removed observer. If not found, return a
   * nullptr
   */
  std::unique_ptr<Observer> DetachObserver(const Observer* observer_ptr) {
    for (auto it = observers_list_.begin(); it != observers_list_.end(); ++it) {
      if (it->get() == observer_ptr) {
        auto res = std::move(*it);
        observers_list_.erase(it);
        UpdateCache();
        return res;
      }
    }
    return nullptr;
  }
 private:
  inline static void StartObserver(Observer* observer) {
    try {
      observer->Start();
    } catch (const std::exception& e) {
      LOG(ERROR) << "Exception from observer: " << e.what();
    } catch (...) {
      LOG(ERROR) << "Exception from observer: unknown";
    }
  }

  inline static void StopObserver(Observer* observer) {
    try {
      observer->Stop();
    } catch (const std::exception& e) {
      LOG(ERROR) << "Exception from observer: " << e.what();
    } catch (...) {
      LOG(ERROR) << "Exception from observer: unknown";
    }
  }
  .............
  .............
 private:
  // an on-stack cache for fast iteration;
  // ideally, inside StartAllObservers and StopAllObservers,
  // we should never access observers_list_
  Observer* observer_cache_;
  size_t num_observers_ = 0;

 protected:
  std::vector<std::unique_ptr<Observer>> observers_list_;
};

Input/Output blobs及Arguments基本操作

OperatorBase里面包含了些最基本的Operator所需的操作像輸入、輸出Blob處理(獲得或其之上類型查詢圆米、Copy等)娄帖,Operator參數(shù)處理等。

以下為它用于參數(shù)處理的兩個示例API函數(shù)诈嘿。

  /** @brief Checks if the operator has an argument of the given name.
   */
  inline bool HasArgument(const string& name) const {
    CAFFE_ENFORCE(operator_def_, "operator_def was null!");
    return ArgumentHelper::HasArgument(*operator_def_, name);
  }

  // Functions that deal with arguments. Basically, this allows us to map an
  // argument name to a specific type of argument that we are trying to access.
  template <typename T>
  inline T GetSingleArgument(const string& name, const T& default_value) const {
    CAFFE_ENFORCE(operator_def_, "operator_def was null!");
    return ArgumentHelper::GetSingleArgument<OperatorDef, T>(
        *operator_def_, name, default_value);
  }

下面則為它處理Input/Output blob的一些API函數(shù)奖亚。其中inputs_/outputs_分別是一個operator上所具有的blobs成員屬性佩耳。有了之前幾章講過的Blob與Tensor的知識后谭跨,這些函數(shù)也就比較好懂了螃宙。

  // Get the inputs and outputs as specific types.
  template <typename T>
  inline const T& Input(int idx) {
    static_assert(
        !std::is_same<T, Tensor>::value,
        "You should use Input<Tensor>(int, DeviceType) for "
        "Tensor.");
    DCHECK_LT(idx, inputs_.size());
    try {
      return inputs_.at(idx)->template Get<T>();
    } catch (::caffe2::EnforceNotMet& enf) {
      if (has_debug_def()) {
        enf.AppendMessage(".\nOffending Blob name: ");
        enf.AppendMessage(debug_def().input(idx));
        enf.AppendMessage(".\n");
      }
      throw enf;
    }
  }

  // TODO(jerryzh): Remove template
  // and the type argument?
  // This is to keep the API changes minimal and make refactoring
  // a bit easier
  template <typename T>
  inline const T& Input(int idx, DeviceType type) {
    static_assert(
        std::is_same<T, Tensor>::value,
        "Input(int, DeviceType) is only available for Tensor");
    DCHECK_LT(idx, inputs_.size());
    try {
      // TODO(jerryzh): We'll need to check device type in Get<T>() later
      // Get<T>() -> Get<T>(type)
      const auto& tensor = inputs_.at(idx)->template Get<T>();
      return tensor;
    } catch (::caffe2::EnforceNotMet& enf) {
      if (has_debug_def()) {
        enf.AppendMessage(".\nOffending Blob name: ");
        enf.AppendMessage(debug_def().input(idx));
        enf.AppendMessage(".\n");
      }
      throw enf;
    }
  }

  template <typename T>
  inline T* Output(int idx, DeviceType type) {
    static_assert(
        std::is_same<T, Tensor>::value,
        "Output(int, DeviceType) is only available for Tensor");
    // When you get a Tensor here it is not fully initialized
    return BlobGetMutableTensor(outputs_.at(idx), type);
  }

  inline Tensor*
  OutputTensor(int idx, at::IntList dims, at::TensorOptions options) {
    CAFFE_ENFORCE_WITH_CALLER(
        options.device_opt() != c10::nullopt,
        "device must be provided in option.");
    return BlobGetMutableTensor(outputs_.at(idx), dims, options);
  }

  template <typename T>
  inline T* Output(int idx, T* allocated) {
    outputs_.at(idx)->Reset(allocated);
    return allocated;
  }

其它與輸入挂捅、輸出相關(guān)的API函數(shù)還有如下一些utilities函數(shù)堂湖。(當(dāng)然這里只是一部分无蜂,但大致皆是如此,多是些淺顯易懂的斥季,多是像OutputIsType這樣直接使用了Blob的成員函數(shù))。

template <typename T>
  inline bool OutputIsType(int idx) {
    static_assert(
        !std::is_same<T, Tensor>::value,
        "You should use OutputIsTensorType(int, DeviceType) for "
        "Tensor.");
    return outputs_.at(idx)->template IsType<T>();
  }

  inline bool OutputIsTensorType(int idx, DeviceType type) {
    return BlobIsTensorType(*outputs_.at(idx), type);
  }

  inline int InputSize() const {
    return inputs_.size();
  }
  inline int OutputSize() const {
    return outputs_.size();

Event處理相關(guān)async操作函數(shù)

下面一些與Event相關(guān)的操作才是真正開始有趣的東西谤专。每個Caffe2 op都有一個Event成員(Event在將來我們也會單獨拿出來介紹)午绳,用于同步op間的依賴執(zhí)行操作箱叁。它是一個op async執(zhí)行時保證parent op與child op之間同步的產(chǎn)物。如果讀者
熟悉CUDA編程APIs算色,那么對于cudaEvent一定不陌生螟够,本質(zhì)上在CUDAOperator當(dāng)中,它之上的event_就是一個cudaEvent的wrapper若河。

簡單說來寞宫,async執(zhí)行時辈赋,我們有些op可能需要去check(Query或Wait)父類的event以確定它是否執(zhí)行完了(即自己的輸入有保證了)。同時它們也要在執(zhí)行過后(或者是完成了自己op執(zhí)行命令的提交)悟民,更新自己的event篷就,以對自己的childeren operators發(fā)揮影響竭业。

  virtual void WaitEvent(const Event& ev, int /*stream_id */ = -1) {
    ev.Finish();
  }

  inline void Wait(const OperatorBase& other, int stream_id = -1) {
    if (!other.IsEventDisabled()) {
      WaitEvent(other.event(), stream_id);
    }
  }

  virtual void WaitEvents(
      const std::vector<const Event*>& events,
      int /*stream_id*/ = -1) {
    for (const auto& ev : events) {
      ev->Finish();
    }
  }

  virtual void Finish() {
    if (event_) {
      event_->Finish();
    }
  }
  
    const Event& event() const {
    CAFFE_ENFORCE(event_, "Event is disabled");
    return *event_;
  }

  Event& event() {
    CAFFE_ENFORCE(event_, "Event is disabled");
    return *event_;
  }

  void ResetEvent() {
    if (event_) {
      event_->Reset();
    }
  }

  void DisableEvent() {
    event_ = nullptr;
  }

  bool IsEventDisabled() const {
    return !event_;
  }

 protected:
  virtual void RecordEvent(const char* /*err_msg*/ = nullptr) {
    CAFFE_NOT_IMPLEMENTED;
  }

  void SetEventFinished(const char* err_msg = nullptr) {
    if (event_) {
      event_->SetFinished(err_msg);
    }
  }

  void SetEventFinishedWithException(const char* err_msg = nullptr) {
    if (event_) {
      event_->SetFinishedWithException(err_msg);
    }
  }

下面的兩個APIs函數(shù)同樣與op的async執(zhí)行相關(guān)未辆。在OperatorBase中將它一律簡單設(shè)為了false即op默認(rèn)不support這兩種feature(至于feature具體為何義,我們將在Operator講解時進行闡述)钾麸。
可見大部分與async執(zhí)行相關(guān)的Op操作都被放入了Operator中饭尝,OperatorBase這里有API人,但其實形同虛設(shè)实撒。所以如果你在實現(xiàn)自己的Op時涉瘾,如果想遵循Caffe2中已有一套Op async執(zhí)行邏輯,那么可將自己op實現(xiàn)為
operator的子類(這是絕大多數(shù)時候的正常做法负敏,畢竟caffe2是assume用戶會直接遵循自己關(guān)于event/op async執(zhí)行的一套做法的即operator中的那些默認(rèn)做法)秘蛇;但如果你對某一Device context上的Op async執(zhí)行
有著與framework original design略為不同的想法時赁还,那么你就可以將op實現(xiàn)為OperatorBase的子類,然后去重載這幾個與async執(zhí)行蹈胡、Events操作相關(guān)的APIs函數(shù)朋蔫。

  virtual bool HasAsyncPart() const {
    return false;
  }

  virtual bool SupportsAsyncScheduling() const {
    return false;
  }

Utility函數(shù)

下面是一些用于出錯時debug的APIs函數(shù)斑举,無關(guān)核心病涨。

  inline const OperatorDef& debug_def() const {
    CAFFE_ENFORCE(has_debug_def(), "operator_def was null!");
    return *operator_def_;
  }

  inline void set_debug_def(
      const std::shared_ptr<const OperatorDef>& operator_def) {
    operator_def_ = operator_def;
  }

  inline bool has_debug_def() const {
    return operator_def_ != nullptr;
  }
 public:
  void RecordLastFailedOpNetPosition() {
    if (net_position_ != kNoNetPositionSet) {
      VLOG(1) << "Operator with id " << net_position_ << " failed";
      operator_ws_->last_failed_op_net_position = net_position_;
    } else {
      VLOG(1) << "Failed operator doesn't have id set";
    }
  }

  int net_position() const {
    return net_position_;
  }

  void set_net_position(int idx) {
    net_position_ = idx;
  }

Caffe2中每個Operator都有Device Context的概念即此Op操作是在哪個device之上進行的既穆。Device的具體信息則可從DeviceOption中獲得。

  const DeviceOption& device_option() const {
    return device_option_;
  }

另外每個Operator又可根據(jù)其Engine不同有著不一樣的實現(xiàn)(在同一Device context下面)励两。這個設(shè)計follow了之前Caffe中的思路(如Layer的類型可以為GPU engine或CPU engine等)当悔。
要知道像其它Caffe2組件如Blob/Tensor/Context等使用注冊機制來完成具體定義一樣,Operator也是在完成后需要去注冊在某一device相關(guān)的OperatorRegistry下面的嗅骄。一般Operator name會
作為此operator實現(xiàn)item的key饼疙。而若我們在同一device下面對同一類型op(有同樣的名字op_name)有著不同的實現(xiàn)(比如Allreduce操作在CPU設(shè)計上可以使用MPI窑眯,也可使用Gloo來完成其底層的具體通信aggregation),
那我們就可使用engine來區(qū)別這一不同的實現(xiàn)炊林。

下面是與engine相關(guān)的一些base operator函數(shù)卷要。

  void annotate_engine(const std::string& engine) {
    engine_ = engine;
  }

  const std::string& engine() const {
    return engine_;
  }

像其它許多深度學(xué)習(xí)Framework一樣却妨,Caffe2對于Op操作執(zhí)行在GPU device上有著很強的假設(shè)(very bad assumptions,它讓CUDA生態(tài)在DL領(lǐng)域過于強大了)倍权。如下函數(shù)即為此類假設(shè)的一個縮影捞烟。

  // Checks whether stream is ready to execute new computation,
  // used in stream allocation optimization to skip stream that is currently
  // busy. Depends on context and operator's device, returns true by default
  virtual bool IsStreamFree(int /* unused */) const {
    return true;
  }

核心的Run及RunAsync函數(shù)

最終我們看下關(guān)鍵的Op真正運行的兩個函數(shù)吧题画。如下示,其中Run一般為op sync執(zhí)行時所調(diào)用的缩幸,而RunAsync則是其async執(zhí)行時所調(diào)用的竞思「桥纾可以看出在執(zhí)行時都默認(rèn)我們會將操作提交到一stream上,這顯然亦是
由CUDA編程API中所用到的cuda_stream的思想而來距辆。。

  virtual bool Run(int /* unused */ /*stream_id*/ = 0) {
    CAFFE_NOT_IMPLEMENTED;
  }

    
  // RunAsync, if implemenented by the specific operators, will schedule the
  // computation on the corresponding context and record the event in its
  // event_ member object. If the specific operator does not support RunAsync,
  // it will simply be synchronous as a fallback.
  virtual bool RunAsync(int stream_id = 0) {
    try {
      auto result = Run(stream_id);
      if (result) {
        if (HasAsyncPart()) {
          RecordEvent();
        } else {
          SetEventFinished();
        }
      } else {
        SetEventFinished(getErrorMsg().c_str());
      }
      return result;
    } catch (EnforceNotMet& err) {
      SetEventFinishedWithException(err.what());
      throw;
    } catch (const std::exception& err) {
      SetEventFinishedWithException(err.what());
      throw;
    } catch (...) {
      SetEventFinishedWithException(getErrorMsg().c_str());
      throw;
    }
  }

參考文獻

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個濱河市雾消,隨后出現(xiàn)的幾起案子挫望,更是在濱河造成了極大的恐慌,老刑警劉巖桑腮,帶你破解...
    沈念sama閱讀 218,682評論 6 507
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件破讨,死亡現(xiàn)場離奇詭異奕纫,居然都是意外死亡匹层,警方通過查閱死者的電腦和手機,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 93,277評論 3 395
  • 文/潘曉璐 我一進店門撑柔,熙熙樓的掌柜王于貴愁眉苦臉地迎上來铅忿,“玉大人灵汪,你說我怎么就攤上這事≈叮” “怎么了担锤?”我有些...
    開封第一講書人閱讀 165,083評論 0 355
  • 文/不壞的土叔 我叫張陵肛循,是天一觀的道長多糠。 經(jīng)常有香客問我,道長被盈,這世上最難降的妖魔是什么搭伤? 我笑而不...
    開封第一講書人閱讀 58,763評論 1 295
  • 正文 為了忘掉前任怜俐,我火速辦了婚禮,結(jié)果婚禮上贴谎,老公的妹妹穿的比我還像新娘季稳。我一直安慰自己景鼠,他們只是感情好,可當(dāng)我...
    茶點故事閱讀 67,785評論 6 392
  • 文/花漫 我一把揭開白布谭确。 她就那樣靜靜地躺著逐哈,像睡著了一般问顷。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發(fā)上肠骆,一...
    開封第一講書人閱讀 51,624評論 1 305
  • 那天蚀腿,我揣著相機與錄音,去河邊找鬼廓脆。 笑死磁玉,一個胖子當(dāng)著我的面吹牛蚊伞,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播氧枣,決...
    沈念sama閱讀 40,358評論 3 418
  • 文/蒼蘭香墨 我猛地睜開眼便监,長吁一口氣:“原來是場噩夢啊……” “哼碳想!你這毒婦竟也來了胧奔?” 一聲冷哼從身側(cè)響起,我...
    開封第一講書人閱讀 39,261評論 0 276
  • 序言:老撾萬榮一對情侶失蹤胳泉,失蹤者是張志新(化名)和其女友劉穎岩遗,沒想到半個月后宿礁,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 45,722評論 1 315
  • 正文 獨居荒郊野嶺守林人離奇死亡控汉,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 37,900評論 3 336
  • 正文 我和宋清朗相戀三年姑子,在試婚紗的時候發(fā)現(xiàn)自己被綠了街佑。 大學(xué)時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片。...
    茶點故事閱讀 40,030評論 1 350
  • 序言:一個原本活蹦亂跳的男人離奇死亡岳服,死狀恐怖希俩,靈堂內(nèi)的尸體忽然破棺而出,到底是詐尸還是另有隱情纲辽,我是刑警寧澤颜武,帶...
    沈念sama閱讀 35,737評論 5 346
  • 正文 年R本政府宣布,位于F島的核電站拖吼,受9級特大地震影響鳞上,放射性物質(zhì)發(fā)生泄漏。R本人自食惡果不足惜吊档,卻給世界環(huán)境...
    茶點故事閱讀 41,360評論 3 330
  • 文/蒙蒙 一篙议、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧怠硼,春花似錦鬼贱、人聲如沸香璃。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,941評論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽葡秒。三九已至姻乓,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間眯牧,已是汗流浹背蹋岩。 一陣腳步聲響...
    開封第一講書人閱讀 33,057評論 1 270
  • 我被黑心中介騙來泰國打工, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留学少,地道東北人星澳。 一個月前我還...
    沈念sama閱讀 48,237評論 3 371
  • 正文 我出身青樓,卻偏偏與公主長得像旱易,于是被迫代替她去往敵國和親禁偎。 傳聞我的和親對象是個殘疾皇子,可洞房花燭夜當(dāng)晚...
    茶點故事閱讀 44,976評論 2 355

推薦閱讀更多精彩內(nèi)容