1. KVStore里的Barrier
在mxnet的分布式訓(xùn)練里四濒,主要模式就是參數(shù)服務(wù)器换况。每個(gè)worker或者agent就是一臺(tái)machine,server用于參數(shù)的更新峻黍。那么复隆,當(dāng)我們期望在不同的worker之間進(jìn)行同步的時(shí)候拨匆,就會(huì)需要到barrier
這個(gè)方法姆涩。
當(dāng)代碼運(yùn)行在worker的時(shí)候,我們可以通過(guò)調(diào)用kv._barrier()
來(lái)進(jìn)行同步惭每。它的作用就是骨饿,會(huì)阻塞代碼運(yùn)行亏栈,直到每個(gè)worker都運(yùn)行了kv._barrier()
。然后接著運(yùn)行宏赘。這樣就實(shí)現(xiàn)了同步绒北。
那么它是怎么做到的呢?
通過(guò)源碼察署,我們不難發(fā)現(xiàn)闷游,python端的接口調(diào)用了c++端的方法:
void Barrier() override {
ps::Postoffice::Get()->Barrier(ps_worker_->get_customer()->customer_id(), ps::kWorkerGroup);
}
這個(gè)全局的Postoffice
的Barrier
方法的部分源碼如下:
void Postoffice::Barrier(int customer_id, int node_group) {
// 省略部分代碼
// 省略部分代碼
std::unique_lock<std::mutex> ulk(barrier_mu_);
barrier_done_[0][customer_id] = false;
Message req;
req.meta.recver = kScheduler;
req.meta.request = true;
req.meta.control.cmd = Control::BARRIER;
req.meta.app_id = 0;
req.meta.customer_id = customer_id;
req.meta.control.barrier_group = node_group;
req.meta.timestamp = van_->GetTimestamp();
CHECK_GT(van_->Send(req), 0);
barrier_cond_.wait(ulk, [this, customer_id] {
return barrier_done_[0][customer_id];
});
}
可以看到該方法會(huì)首先對(duì)barrier_mu_
上鎖,之后將對(duì)應(yīng)的barrier_done_
設(shè)置為false
贴汪。然后將這次的barrier信息發(fā)送給scheduler脐往。告訴scheduler需要進(jìn)行一次barrier。然后就阻塞等待barrier_done_
被設(shè)置為true
扳埂,代表完成了barrier业簿,也就是其他的worker也都進(jìn)行了barrier。
那么問(wèn)題就變成了阳懂,每個(gè)worker都是怎么直到其他worker也進(jìn)行了barrier的梅尤?
首先我們要知道,在參數(shù)服務(wù)器也就是PS中岩调,每個(gè)進(jìn)程都會(huì)建立kvstore巷燥。如果是worker,會(huì)在構(gòu)造函數(shù)中運(yùn)行如下代碼:
if (IsWorkerNode()) {
int new_customer_id = GetNewCustomerId();
ps_worker_ = new ps::KVWorker<char>(0, new_customer_id);
ps::StartAsync(new_customer_id, "mxnet\0");
if (!ps::Postoffice::Get()->is_recovery()) {
ps::Postoffice::Get()->Barrier(
new_customer_id,
ps::kWorkerGroup + ps::kServerGroup + ps::kScheduler);
}
}
其中ps::StartAsync
如下:
inline void StartAsync(int customer_id, const char* argv0 = nullptr) {
Postoffice::Get()->Start(customer_id, argv0, false);
}
也就是說(shuō)誊辉,worker在建立起ps_worker_
后矾湃,開始運(yùn)行postoffice,而postoffice的Start
會(huì)進(jìn)行一系列的操作堕澄,并調(diào)用van_->Start
邀跃,接著van
的Start
會(huì)進(jìn)行一系列的初始化后,開啟接受消息的線程蛙紫,也就是
receiver_thread_ = std::unique_ptr<std::thread>(
new std::thread(&Van::Receiving, this));
而receiving
函數(shù)會(huì)使用ProcessBarrierCommand
處理barrier信號(hào)拍屑,該函數(shù)會(huì)++barrier_count_[group]
,也就是將對(duì)應(yīng)group的barrier次數(shù)進(jìn)行統(tǒng)計(jì)坑傅。當(dāng)barrier_count_[group]
等于這個(gè)group的個(gè)數(shù)的時(shí)候僵驰。它會(huì)發(fā)送類似于ACK的返回信息。
然后worker會(huì)調(diào)用Manage
方法來(lái)處理該message唁毒。Manage
發(fā)現(xiàn)是barrier的返回信息蒜茴,將barrier_done_設(shè)置為true
,然后將等待的線程喚醒浆西。也就是python端調(diào)用barrier后被阻塞的地方粉私。
至此,就完成了一次worker之間的barrier近零。