本論文從分布式系統的角度開展針對當前一些機器學習平臺的研究,綜述了這些平臺所使用的架構設計,對這些平臺在通信和控制上的瓶頸、容錯性和開發難度進行分析和對比,并對分布式機器學習平臺的未來研究工作提出了一些建議。文中的工作由MuratDemirbas教授與他的研究生KuoZhang和SalemAlqahtani共同完成。
機器學習,特別是深度學習,已在語音識別、圖像識別和自然語言處理以及近期在推薦及搜索引擎等領域上取得了革命性的成功。這些技術在無人駕駛、數字醫療系統、CRM、廣告、物聯網等領域具有很好的應用前景。當然,是資金引領和驅動了技術的加速推進,使得我們在近期看到了一些機器學習平臺的推出。
考慮到訓練中所涉及的數據集和模型的規模十分龐大,機器學習平臺通常是分布式平臺,部署了數十個乃至數百個并行運行的計算節點對模型做訓練。據估計在不遠的將來,數據中心的大多數任務都會是機器學習任務。
我來自于分布式系統研究領域,因此我們考慮從分布式系統的角度開展針對這些機器學習平臺的研究,分析這些平臺在通信和控制上的瓶頸。我們還考慮了這些平臺的容錯性和易編程性。
我們從設計方法上將機器學習平臺劃分為三個基本類別,分別是:基本數據流、參數-服務器模型和高級數據流。
下面我們將對每類方法做簡要介紹,以ApacheSpark為例介紹基本數據流,以PMLS(Petuum)為例介紹參數服務器模型,而高級數據流則使用TensorFlow和MXNet為例。我們對比了上述各平臺的性能并給出了一系列的評估結果。要了解詳細的評估結果,可參考我們的論文。遺憾的是,作為一個小型研究團隊,我們無法開展大規模的評估。
在本篇博文的最后,我給出了一些結論性要點,并對分布式機器學習平臺的未來研究工作提出了一些建議。對這些分布式機器學習平臺已有一定了解的讀者,可以直接跳到本文結尾。
Spark
在Spark中,計算被建模為一種有向無環圖(DAG),圖中的每個頂點表示一個RDD,每條邊表示了RDD上的一個操作。RDD由一系列被切分的對象(Partition)組成,這些被切分的對象在內存中存儲并完成計算,也會在Shuffle過程中溢出(Overflow)到磁盤上
在DAG中,一條從頂點A到B的有向邊E,表示了RDDB是在RDDA上執行操作E的結果。操作分為轉換(Transformation)和動作(Action)兩類。轉換操作(例如map、filter和join)應用于某個RDD上,轉換操作的輸出是一個新的RDD。
Spark用戶將計算建模為DAG,該DAG表示了在RDD上執行的轉換和動作。DAG進而被編譯為多個Stage。每個Stage執行為一系列并行運行的任務(Task),每個分區(Partition)對應于一個任務。這里,有限(Narrow)的依賴關系將有利于計算的高效執行,而寬泛(Wide)的依賴關系則會引入瓶頸,因為這樣的依賴關系引入了通信密集的Shuffle操作,這打斷了操作流。
Spark的分布式執行是通過將DAGStage劃分到不同的計算節點實現的。上圖清晰地展示了這種主機(master)-工作者(worker)架構。驅動器(Driver)包含有兩個調度器(Scheduler)組件,即DAG調度器和任務調度器。調度器對工作者分配任務,并協調工作者。
Spark是為通用數據處理而設計的,并非專用于機器學習任務。要在Spark上運行機器學習任務,可以使用MLlibforSpark。如果采用基本設置的Spark,那么模型參數存儲在驅動器節點上,在每次迭代后通過工作者和驅動器間的通信更新參數。如果是大規模部署機器學習任務,那么驅動器可能無法存儲所有的模型參數,這時就需要使用RDD去容納所有的參數。這將引入大量的額外開銷,因為為了容納更新的模型參數,需要在每次迭代中創建新的RDD。更新模型會涉及在機器和磁盤間的數據Shuffle,進而限制了Spark的擴展性。這正是基本數據流模型(即DAG)的短板所在。Spark并不能很好地支持機器學習中的迭代運算。
PMLS
PMLS是專門為機器學習任務而設計的。它引入了稱為參數-服務器(Parameter-Server,PS)的抽象,這種抽象是為了支持迭代密集的訓練過程。
PS(在圖中以綠色方框所示)以分布式key-value數據表形式存在于內存中,它是可復制和分片的。每個節點(node)都是模型中某個分片的主節點(參數空間),并作為其它分片的二級節點或復制節點。這樣PS在節點數量上的擴展性很好。
PS節點存儲并更新模型參數,并響應來自于工作者的請求。工作者從自己的本地PS拷貝上請求最新的模型參數,并在分配給它們的數據集分區上執行計算。
PMLS也采用了SSP(StaleSynchronousParallelism)模型。相比于BSP(BulkSynchronousParellelism)模型,SSP放寬了每一次迭代結束時各個機器需做同步的要求。為實現同步,SSP允許工作者間存在一定程度上的不同步,工業機器人維修,并確保了最快的工作者不會領先最慢的工作者s輪迭代以上。由于處理過程處于誤差所允許的范圍內,這種非嚴格的一致性模型依然適用于機器學習。我曾經發表過一篇博文專門介紹這一機制。
TensorFlow
Google給出了一個基于分布式機器學習平臺的參數服務器模型,稱為DistBelief(此處是我對DistBelief論文的綜述)。就我所知,大家對DistBelief的不滿意之處主要在于,它在編寫機器學習應用時需要混合一些底層代碼。Google想使其任一雇員都可以在無需精通分布式執行的情況下編寫機器學習代碼。正是出于同一原因,Google對大數據處理編寫了MapReduce框架。
TensorFlow是一種設計用于實現這一目標的平臺。它采用了一種更高級的數據流處理范式,其中表示計算的圖不再需要是DAG,圖中可以包括環,并支持可變狀態。我認為TensorFlow的設計在一定程度上受到了Naiad設計理念的影響。
TensorFlow將計算表示為一個由節點和邊組成的有向圖。節點表示計算操作或可變狀態(例如Variable),邊表示節點間通信的多維數組,這種多維數據稱為Tensor。TensorFlow需要用戶靜態地聲明邏輯計算圖,并通過將圖重寫和劃分到機器上實現分布式計算。需說明的是,MXNet,特別是DyNet,使用了一種動態定義的圖。這簡化了編程,并提高了編程的靈活性。
如上圖所示,在TensorFlow中,分布式機器學習訓練使用了參數-服務器方法。當在TensorFlow中使用PS抽象時,就使用了參數-服務器和數據并行。TensorFlow聲稱可以完成更復雜的任務,但是這需要用戶編寫代碼以通向那些未探索的領域。
MXNet
MXNet是一個協同開源項目,源自于在2015年出現的CXXNet、Minverva和Purines等深度學習項目。類似于TensorFlow,MXNet也是一種數據流系統,支持具有可變狀態的有環計算圖,并支持使用參數-服務器模型的訓練計算。同樣,MXNet也對多個CPU/GPU上的數據并行提供了很好的支持,并可實現模型并行。MXNet支持同步的和異步的訓練計算。下圖顯示了MXNet的主要組件。其中,運行時依賴引擎分析計算過程中的依賴關系,對不存在相互依賴關系的計算做并行處理。MXNet在運行時依賴引擎之上提供了一個中間層,用于計算圖和內存的優化。
MXNet使用檢查點機制支持基本的容錯,提供了對模型的save和load操作。save操作將模型參數寫入到檢查點文件,load操作從檢查點文件中讀取模型參數。