I recently read two papers on federated learning, "Communication-Efficient Learning of Deep Networks from Decentralized Data" and "FedMD: Heterogenous Federated Learning via Model Distillation", which propose the FedAvg and FedMD algorithms respectively.
Communication-Efficient Learning of Deep Networks from Decentralized Data
Background
Modern mobile devices have access to a wealth of data suitable for learning models, which in turn can greatly improve the user experience on the device. For example, language models can improve speech recognition and text entry, and image models can automatically select good photos. However, this rich data is often privacy sensitive, large in quantity, or both, which may preclude logging to the data center and training there using conventional approaches.
Data that is privacy sensitive or very large, and therefore cannot be uploaded directly to the data center, may make conventional centralized training impossible.
Main Contributions
We advocate an alternative that leaves the training data distributed on the mobile devices, and learns a shared model by aggregating locally-computed updates. We term this decentralized approach Federated Learning.
The paper proposes an alternative model-training approach: Federated Learning.
We present a practical method for the federated learning of deep networks based on iterative model averaging, and conduct an extensive empirical evaluation, considering five different model architectures and four datasets. These experiments demonstrate the approach is robust to the unbalanced and non-IID data distributions that are a defining characteristic of this setting. Communication costs are the principal constraint, and we show a reduction in required communication rounds by 10–100× as compared to synchronized stochastic gradient descent.
It presents an algorithm based on iterative model averaging that remains effective on non-IID data and reduces the required communication rounds by 10-100x compared with synchronized stochastic gradient descent.
More concretely, we introduce the FederatedAveraging algorithm, which combines local stochastic gradient descent (SGD) on each client with a server that performs model averaging. We perform extensive experiments on this algorithm, demonstrating it is robust to unbalanced and non-IID data distributions, and can reduce the rounds of communication needed to train a deep network on decentralized data by orders of magnitude.
The FederatedAveraging (FedAvg) algorithm combines local stochastic gradient descent (SGD) on each client with a server that performs model averaging. It is robust to unbalanced and non-IID data distributions and reduces the number of communication rounds needed to train a deep network on decentralized data by orders of magnitude.
Federated Learning
Federated Learning Ideal problems for federated learning have the following properties: 1) Training on real-world data from mobile devices provides a distinct advantage over training on proxy data that is generally available in the data center. 2) This data is privacy sensitive or large in size (compared to the size of the model), so it is preferable not to log it to the data center purely for the purpose of model training (in service of the focused collection principle). 3) For supervised tasks, labels on the data can be inferred naturally from user interaction.
Ideal problems for federated learning have the following properties:
- Training on real-world data from mobile devices provides a distinct advantage over training on the proxy data that is generally available in the data center.
- This data is privacy sensitive or large in size (compared to the size of the model), so it is preferable not to log it to the data center purely for the purpose of model training (in service of the focused collection principle).
- For supervised tasks, labels on the data can be inferred naturally from user interaction.
Privacy
Privacy Federated learning has distinct privacy advantages compared to data center training on persisted data. Holding even an “anonymized” dataset can still put user privacy at risk via joins with other data (Sweeney, 2000). In contrast, the information transmitted for federated learning is the minimal update necessary to improve a particular model (naturally, the strength of the privacy benefit depends on the content of the updates.) The updates themselves can (and should) be ephemeral. They will never contain more information than the raw training data (by the data processing inequality), and will generally contain much less. Further, the source of the updates is not needed by the aggregation algorithm, so updates can be transmitted without identifying meta-data over a mix network such as Tor (Chaum, 1981) or via a trusted third party. We briefly discuss the possibility of combining federated learning with secure multiparty computation and differential privacy at the end of the paper.
- FL 傳輸?shù)男畔⑹歉倪M(jìn)特定模型所必需的最小更新(隱私利益的強(qiáng)度取決于更新的內(nèi)容);
- 更新本身是短暫的,所包含的信息絕不會超過原始訓(xùn)練數(shù)據(jù)且通常會少得多稚铣;
- 聚合算法不需要更新源(不需要知道用戶是誰箱叁?),因此惕医,可以通過混合網(wǎng)絡(luò)(例如Tor)或通過受信任的第三方傳輸更新而無需標(biāo)識元數(shù)據(jù)耕漱。
- 將聯(lián)合學(xué)習(xí)與安全的多方計(jì)算及差分隱私相結(jié)合
Federated Optimization
The optimization problem that arises in federated learning.
Several key properties:
Non-IID The training data on a given client is typically based on the usage of the mobile device by a particular user, and hence any particular user’s local dataset will not be representative of the population distribution.
Unbalanced Similarly, some users will make much heavier use of the service or app than others, leading to varying amounts of local training data.
Massively distributed We expect the number of clients participating in an optimization to be much larger than the average number of examples per client.
Limited communication Mobile devices are frequently offline or on slow or expensive connections.
- Non-IID data: a particular user's local dataset is not representative of the population distribution;
- Unbalanced data: the amount of local training data varies, because some users use the service or app much more heavily than others;
- Massively distributed: the number of clients participating in an optimization is much larger than the average number of examples per client;
- Limited communication: mobile devices are frequently offline or on slow or expensive connections.
The most important of these are the non-IID, unbalanced data and the communication constraints; a sketch of a non-IID partition follows below.
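To make the non-IID point concrete, here is a minimal sketch (my own, though it mirrors the pathological label-sorted partition the paper uses for MNIST) of how such a split can be built; `labels`, `num_clients`, and `shards_per_client` are assumed inputs:

```python
import numpy as np

def pathological_non_iid_split(labels, num_clients=100, shards_per_client=2, seed=0):
    """Sort examples by label, cut the sorted indices into equal shards, and hand
    each client a few shards, so most clients see only one or two classes."""
    rng = np.random.default_rng(seed)
    order = np.argsort(labels)                                  # indices sorted by label
    shards = np.array_split(order, num_clients * shards_per_client)
    shard_ids = rng.permutation(len(shards))
    return [np.concatenate([shards[s] for s in
                            shard_ids[c * shards_per_client:(c + 1) * shards_per_client]])
            for c in range(num_clients)]

# Example: client_indices = pathological_non_iid_split(y_train)  # y_train holds the labels
```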
Practical issues:
A deployed federated optimization system must also address a myriad of practical issues: client datasets that change as data is added and deleted; client availability that correlates with the local data distribution in complex ways; and clients that never respond or send corrupted updates.
- Client datasets that change as data is added and deleted;
- Client availability that correlates with the local data distribution in complex ways;
- Clients that never respond or that send corrupted updates.
These issues are outside the scope of this paper.
Optimization Approach
Update protocol
We assume a synchronous update scheme that proceeds in rounds of communication. There is a fixed set of K clients, each with a fixed local dataset. At the beginning of each round, a random fraction C of clients is selected, and the server sends the current global algorithm state to each of these clients (e.g., the current model parameters). Each client then performs local computation based on the global state and its local dataset, and sends an update to the server. The server then applies these updates to its global state, and the process repeats.
A synchronous update scheme that proceeds in rounds of communication is assumed; there is a fixed set of K clients, each with a fixed local dataset.
- At the beginning of each round, a random fraction C (C ≤ 1) of clients is selected;
- The server sends the current global algorithm state (e.g., the current model parameters) to each of these clients;
- Each client performs local computation based on the global state and its local dataset and sends an update to the server;
- Finally, the server applies these updates to its global state, and the process repeats.
Objective for non-convex neural networks
While we focus on non-convex neural network objectives, the algorithm we consider is applicable to any finite-sum objective of the form
For a machine learning problem, we typically take f_i(w) = ?(x_i, y_i; w), that is, the loss of the prediction on example (x_i, y_i) made with model parameters w. We assume there are K clients over which the data is partitioned, with P_k the set of indexes of data points on client k, with n_k = |P_k|.
Thus, we can re-write the objective (1) as
For a machine learning problem we typically take f_i(w) = ?(x_i, y_i; w), i.e., the loss of the prediction on example (x_i, y_i) made with model parameters w. Assuming the data is partitioned over K clients, with P_k the set of indexes of data points on client k and n_k = |P_k|, the objective can be re-written as a weighted sum of per-client objectives F_k(w).
If the partition P_k were formed by distributing the training examples over the clients uniformly at random, then f(w) would equal the expectation of F_k(w) over the random partition (this is the IID assumption made in traditional distributed optimization).
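The equations (1) and (2) referenced above appeared as images in the original post; reconstructed from the surrounding definitions, they read:

$$\min_{w \in \mathbb{R}^d} f(w), \qquad f(w) = \frac{1}{n}\sum_{i=1}^{n} f_i(w) \tag{1}$$

$$f(w) = \sum_{k=1}^{K} \frac{n_k}{n} F_k(w), \qquad F_k(w) = \frac{1}{n_k}\sum_{i \in P_k} f_i(w) \tag{2}$$

and under the IID assumption, $\mathbb{E}_{P_k}\left[F_k(w)\right] = f(w)$.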
Communication cost vs. computation cost
**Thus, our goal is to use additional computation in order to decrease the number of rounds of communication needed to train a model.** There are two primary ways we can add computation: **1) increased parallelism,** where we use more clients working independently between each communication round; and, **2) increased computation on each client,** where rather than performing a simple computation like a gradient calculation, each client performs a more complex calculation between each communication round. We investigate both of these approaches, but the speedups we achieve are due primarily to adding more computation on each client, once a minimum level of parallelism over clients is used.
In data-center optimization, communication costs are relatively small and computation costs dominate, with much of the recent emphasis on using GPUs to lower those costs.
In federated optimization, communication costs dominate:
- upload bandwidth is typically limited to 1 MB/s or less;
- clients usually only volunteer for optimization when they are charging, plugged in, and on an unmetered Wi-Fi connection;
- each client is expected to participate in only a small number of update rounds per day.
Computation, by contrast, is relatively cheap:
- the dataset on any single device is small compared with the total dataset size;
- modern smartphones have relatively fast processors (including GPUs).
Thus, the goal is to use additional computation in order to decrease the number of communication rounds needed to train a model.
Two ways of adding computation are considered:
- Increased parallelism: use more clients working independently between communication rounds;
- Increased computation per client: rather than performing a simple computation such as a single gradient calculation, each client performs a more complex computation between communication rounds.
Once a minimum level of parallelism over clients is used, the speedups achieved come primarily from adding more computation on each client.
Asynchronous distributed forms of SGD have also been applied to training neural networks, e.g., Dean et al. (2012), but these approaches require a prohibitive number of updates in the federated setting. One endpoint of the (parameterized) algorithm family we consider is simple one-shot averaging, where each client solves for the model that minimizes (possibly regularized) loss on their local data, and these models are averaged to produce the final global model. This approach has been studied extensively in the convex case with IID data, and it is known that in the worst-case, the global model produced is no better than training a model on a single client (Zhang et al., 2012; Arjevani and Shamir, 2015; Zinkevich et al., 2010).
Asynchronous distributed forms of SGD have also been applied to training neural networks, but they require a prohibitive number of updates in the federated setting.
One endpoint of the (parameterized) algorithm family considered here is simple one-shot averaging, where each client solves for the model that minimizes (possibly regularized) loss on its local data, and these models are averaged to produce the final global model. This approach has been studied extensively in the convex case with IID data, and it is known that in the worst case the resulting global model is no better than a model trained on a single client.
Federated Averaging Algorithm
The recent multitude of successful applications of deep learning have almost exclusively relied on variants of stochastic gradient descent (SGD) for optimization; in fact, many advances can be understood as adapting the structure of the model (and hence the loss function) to be more amenable to optimization by simple gradient-based methods (Goodfellow et al., 2016). Thus, it is natural that we build algorithms for federated optimization by starting from SGD.
- Recent successful applications of deep learning have almost exclusively relied on variants of stochastic gradient descent (SGD) for optimization;
- In fact, many advances can be understood as adapting the structure of the model (and hence the loss function) to be more amenable to optimization by simple gradient-based methods.
Baseline algorithm: FederatedSGD
SGD can be applied naively to the federated optimization problem, where a single batch gradient calculation (say on a randomly selected client) is done per round of communication. This approach is computationally efficient, but requires very large numbers of rounds of training to produce good models (e.g., even using an advanced approach like batch normalization, Ioffe and Szegedy (2015) trained MNIST for 50000 steps on minibatches of size 60). We consider this baseline in our CIFAR-10 experiments.
Problem: this is computationally efficient but requires a very large number of training rounds to produce good models.
In the federated setting, there is little cost in wall-clock time to involving more clients, and so for our baseline we use large-batch synchronous SGD; experiments by Chen et al. (2016) show this approach is state-of-the-art in the data center setting, where it outperforms asynchronous approaches. To apply this approach in the federated setting, we select a C-fraction of clients on each round, and compute the gradient of the loss over all the data held by these clients. Thus, C controls the global batch size, with C = 1 corresponding to full-batch (non-stochastic) gradient descent. We refer to this baseline algorithm as FederatedSGD (or FedSGD).
SGD can be applied naively to federated optimization: a single batch gradient calculation (say, on one randomly selected client) is done per round of communication.
Baseline: large-batch synchronous SGD (state-of-the-art in the data-center setting, where it outperforms asynchronous approaches).
Federated form: on each round, select a C-fraction of the clients and compute the gradient of the loss over all the data held by these clients.
The parameter C controls the global batch size; C = 1 corresponds to full-batch (non-stochastic) gradient descent. This baseline is called FederatedSGD (FedSGD).
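Written out in the paper's notation, with $\eta$ the learning rate and $g_k = \nabla F_k(w_t)$ the average gradient on client $k$'s local data at the current model $w_t$, one round of FedSGD and its FedAvg rearrangement are:

$$\text{FedSGD:}\quad w_{t+1} \leftarrow w_t - \eta \sum_{k=1}^{K} \frac{n_k}{n}\, g_k$$

$$\text{FedAvg:}\quad w_{t+1}^k \leftarrow w_t - \eta\, g_k \;\;\text{(on each client)}, \qquad w_{t+1} \leftarrow \sum_{k=1}^{K} \frac{n_k}{n}\, w_{t+1}^k$$

The two are equivalent for a single local step; FedAvg then simply lets each client iterate the local step $w^k \leftarrow w^k - \eta \nabla F_k(w^k)$ several times before the models are averaged.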
FederatedAveraging
The amount of computation is controlled by three key parameters: C, the fraction of clients that perform computation on each round; E, the number of training passes each client makes over its local dataset on each round; and B, the local minibatch size used for the client updates. We write B = ∞ to indicate that the full local dataset is treated as a single minibatch. Thus, at one endpoint of this algorithm family, we can take B = ∞ and E = 1 which corresponds exactly to FedSGD. Complete pseudo-code is given in Algorithm 1.
Personally, I think FedAvg can be viewed as FedSGD in which each client performs multiple local gradient updates before the models are averaged.
Key parameters: C (fraction of clients per round), B (local minibatch size), E (number of local epochs).
With B = ∞ and E = 1, the full local dataset is treated as a single minibatch and only one local pass is made, which corresponds exactly to FedSGD.
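A minimal PyTorch-style sketch of the server loop and ClientUpdate along the lines of Algorithm 1 (the helper `client_loaders` and the cross-entropy loss are my own illustrative assumptions, and all state_dict entries are assumed to be float tensors):

```python
import copy
import random
import torch

def client_update(model, loader, E, lr):
    """ClientUpdate(k, w): E local epochs of minibatch SGD on the client's own data."""
    opt = torch.optim.SGD(model.parameters(), lr=lr)
    for _ in range(E):
        for x, y in loader:                                   # minibatches of size B
            opt.zero_grad()
            torch.nn.functional.cross_entropy(model(x), y).backward()
            opt.step()
    return model.state_dict(), len(loader.dataset)

def fed_avg(global_model, client_loaders, rounds, C=0.1, E=5, lr=0.1):
    """Server loop: sample a C-fraction of clients each round, run ClientUpdate on
    each, and replace the global weights with the n_k-weighted average."""
    K = len(client_loaders)
    m = max(int(C * K), 1)
    for _ in range(rounds):
        selected = random.sample(range(K), m)
        updates, sizes = [], []
        for k in selected:
            local = copy.deepcopy(global_model)               # client receives w_t
            w_k, n_k = client_update(local, client_loaders[k], E, lr)
            updates.append(w_k)
            sizes.append(n_k)
        n = sum(sizes)
        averaged = {name: sum((n_k / n) * w[name] for w, n_k in zip(updates, sizes))
                    for name in updates[0]}
        global_model.load_state_dict(averaged)
    return global_model
```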
Analysis of model averaging
For general non-convex objectives, averaging models in parameter space could produce an arbitrarily bad model. This is exactly what happens when two MNIST digit-recognition models trained from different random initializations are averaged (Figure 1, left).
Experiments show, however, that when two models are started from the same random initialization and then trained independently on different subsets of the data, naive parameter averaging works surprisingly well (Figure 1, right).
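A rough sketch of that Figure 1 style experiment (the evaluation callback and mixing weights below are my own placeholders, not the paper's):

```python
import torch

@torch.no_grad()
def eval_interpolated(model, w1, w2, eval_loss, thetas=(0.0, 0.25, 0.5, 0.75, 1.0)):
    """Load the mixed parameters theta*w1 + (1-theta)*w2 into `model` and record the
    loss at each mixing weight, as in the parameter-averaging study of Figure 1."""
    losses = {}
    for theta in thetas:
        mixed = {name: theta * w1[name] + (1 - theta) * w2[name] for name in w1}
        model.load_state_dict(mixed)
        losses[theta] = eval_loss(model)      # e.g., loss over the full MNIST training set
    return losses
```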
The success of dropout training also provides some intuition for the success of our model averaging scheme; dropout training can be interpreted as averaging models of different architectures which share parameters, and the inference-time scaling of the model parameters is analogous to the model averaging used in FedAvg (Srivastava et al., 2014).
Dropout training can be interpreted as averaging models of different architectures that share parameters, and the inference-time scaling of the model parameters is analogous to the model averaging used in FedAvg.
Experimental Results
Experiments cover both image classification and language modeling, and both IID and non-IID partitions.
Summary
For all three model classes, FedAvg converges to a higher level of test-set accuracy than the baseline FedSGD models. This trend continues even if the lines are extended beyond the plotted ranges. For example, for the CNN the B = ∞, E = 1 FedSGD model eventually reaches 99.22% accuracy after 1200 rounds (and had not improved further after 6000 rounds), while the B = 10, E = 20 FedAvg model reaches an accuracy of 99.44% after 300 rounds. We conjecture that in addition to lowering communication costs, model averaging produces a regularization benefit similar to that achieved by dropout (Srivastava et al., 2014).
We are primarily concerned with generalization performance, but FedAvg is effective at optimizing the training loss as well, even beyond the point where test-set accuracy plateaus. We observed similar behavior for all three model classes, and present plots for the MNIST CNN in Figure 6 in Appendix A
FedAvg converges to a higher level of test-set accuracy than the baseline FedSGD models, and this trend continues even beyond the plotted ranges. For example, for the CNN, the B = ∞, E = 1 FedSGD model eventually reaches 99.22% accuracy after 1200 rounds (with no further improvement after 6000 rounds), while the B = 10, E = 20 FedAvg model reaches 99.44% accuracy after only 300 rounds.
The authors therefore conjecture that, in addition to lowering communication costs, model averaging produces a regularization benefit similar to that of dropout.
The primary concern is generalization performance, but FedAvg is also effective at optimizing the training loss, even beyond the point where test-set accuracy plateaus.
That is, we would expect that while one round of averaging might produce a reasonable model, additional rounds of communication (and averaging) would not produce further improvements.
The current model parameters only influence the optimization performed in each ClientUpdate through initialization. As E → ∞, at least for convex problems, each client would eventually reach the global minimum regardless of initialization; for non-convex problems, the algorithm would still converge to the same local minimum as long as the initializations lie in the same basin.
For some models, especially in the later stages of convergence, it may be useful to decay the amount of local computation per round (moving to smaller E or larger B) in the same way that decaying the learning rate is useful; however, no significant degradation in the convergence rate is observed for large values of E.
Other Findings
Comparing SGD and FedAvg with a minibatch size of B = 50, accuracy can be viewed as a function of the number of minibatch gradient computations. One might expect SGD to do better, because it takes a sequential step after every minibatch computation. However, as Figure 9 shows, for appropriate values of C and E, FedAvg makes similar progress per minibatch computation. Moreover, when both SGD and FedAvg use only one client per round (green), accuracy fluctuates noticeably, whereas averaging over more clients smooths this out (yellow).
FedAvg vs. FedSGD
Monotone learning curves are shown for the best learning rates: FedSGD with η = 0.4 needs 310 rounds to reach 8.1% accuracy, while FedAvg with η = 18.0 reaches 8.5% accuracy in only 20 rounds (roughly 15x fewer than FedSGD).
FedAvg's accuracy is also much less sensitive to the learning rate.
Outlook
FedAvg can train high-quality models in relatively few rounds of communication, showing that federated learning is practical.
Although federated learning already offers many privacy advantages, combining it with differential privacy, secure multiparty computation, or both is the direction for future work.
FedMD: Heterogenous Federated Learning via Model Distillation
Approach
By comparison, the idea here is fairly simple and builds on model distillation.
It allows different clients to train different model architectures locally.
The key ingredient is a public dataset shared by all participants.
Before a participant starts the collaboration phase, its model must first undergo the entire transfer learning process. It will be trained fully on the public dataset and then on its own private data. Therefore any future improvements are compared to this baseline.
Before the collaboration phase starts, every client's model goes through the full transfer-learning process: it is trained on the public dataset and then on the client's own private dataset. Any future improvement is measured against this baseline.
We re-purpose the public dataset D0 as the basis of communication between models, which is realized using knowledge distillation. Each learner f_k expresses its knowledge by sharing the class scores f_k(x_i^0) computed on the public dataset D0. The central server collects these class scores and computes their average f?(x_i^0). Each party then trains f_k to approach the consensus f?(x_i^0). In this way, the knowledge of one participant can be understood by others without explicitly sharing its private data or model architecture. Using the entire large public dataset can cause a large communication burden. In practice, the server may randomly select a much smaller subset d_j ? D0 at each round as the basis of communication. In this way, the cost is under control and does not scale with the complexity of participating models.
After the transfer-learning phase, the collaboration phase begins: in each round a small subset of the public dataset is selected as the basis of communication.
Each client computes the class scores (the raw output vector, before the final activation) of its own model on this subset and uploads them to the server.
The server averages these score vectors into a consensus and sends it back to every client; each client then updates its weights so that its outputs move toward the consensus.
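A minimal sketch of one collaboration round under these assumptions (the L1 distillation loss and the helper names `models`, `optimizers`, `public_batch` are my own placeholders; the paper does not pin them down in the quoted text):

```python
import torch

def fedmd_round(models, optimizers, public_batch, digest_epochs=1):
    """One FedMD round: communicate class scores on a shared public batch,
    average them into a consensus, then have every client digest the consensus.
    The models may have completely different architectures."""
    # 1) Communicate: each client computes class scores (logits) on the public batch.
    with torch.no_grad():
        scores = [model(public_batch) for model in models]
    # 2) Aggregate: the server averages the score vectors into a consensus.
    consensus = torch.stack(scores).mean(dim=0)
    # 3) Distribute + digest: each client trains its own model toward the consensus.
    for model, opt in zip(models, optimizers):
        for _ in range(digest_epochs):
            opt.zero_grad()
            loss = torch.nn.functional.l1_loss(model(public_batch), consensus)
            loss.backward()
            opt.step()
    return consensus
```

Per the paper, each client would then also revisit its own private data for a few epochs before the next round.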
Results
Figure 2: FedMD improves the test accuracy of participating models beyond their baselines. A dashed line (on the left) represents the test accuracy of a model after full transfer learning with the public dataset and its own small private dataset. This baseline is our starting point and overlaps with the beginning of the corresponding learning curve. A dash-dot line (on the right) represents the would-be performance of a model if private datasets from all participants were declassified and made available to every participant of the group.
FedMD improves the test accuracy of the participating models beyond their baselines. The dashed line (left) is a model's test accuracy after full transfer learning on the public dataset and its own small private dataset; this baseline is the starting point and overlaps with the beginning of the corresponding learning curve. The dash-dot line (right) is the would-be performance of a model if the private datasets of all participants were declassified and made available to every participant of the group.