seri::diary

日常

分散深層ニューラルネットワークの実装アプローチまとめ(2018年6月版)

これは何か

  • 自分が研究テーマとして扱っている分散深層ニューラルネットワークには、「分散」処理の部分において複数のアプローチが存在する
  • このエントリでは、自分の知識の整理のためにこれまで調べたことをまとめておく
  • (2018年6月版) と書いたのは、深層学習業界は変化が激しすぎて半年後には状況が変わっていてこのエントリが役に立たなくなることを想定しているためである(その時はまた新しい版を書こうかなと)

分散深層ニューラルネットワークとは

一般に「深層学習」と呼ばれる機械学習手法においては、隠れ層が多数連なる「深層ニューラルネットワーク」が使用される。
昨今では隠れ層の数を大規模に増加させて学習を行う手法が、主に画像認識の分野で効果を上げており*1、それに伴って訓練に要する時間も増加傾向にある。この問題を解決するために、スループットを向上させるためのアプローチとして分散処理を深層ニューラルネットワークに導入する手法が注目されている。

分散深層ニューラルネットワークでは大きくわけて2つのアプローチ、データ並列分散訓練モデル並列分散訓練 が存在する。以下、それぞれについて解説する。

データ並列分散訓練

f:id:serihiro:20180602214523p:plain

同一のネットワークの複製を複数用意し、それぞれのネットワーク(以下レプリカと呼称)に対して訓練データを適用して並列に訓練を行うアプローチである。 各レプリカは、通常のニューラルネットワークと同様に訓練を行い、個別に勾配を計算して重み行列やフィルタなどのモデルパラメータを更新する。

ここまでは並列でない通常のニューラルネットワークと変わらないが、データ並列分散訓練においては、各レプリカの訓練で得た勾配を、数iterationごと、あるいは毎iterationごとに集約して、平均化したもので各レプリカの勾配を置き変える。これにより、訓練データを分割して並列に訓練を行いつつも、逐次実行で全訓練データを用いて訓練したかのような結果を得ることができる。

この際、レプリカ間で訓練に用いる勾配の「鮮度」が課題となる。常に最新の(つまり鮮度が高い)勾配を用いて訓練を行うには、毎iterationごとにレプリカ間で処理同期を取った上で、勾配を集計し同期する必要がある。しかし、この集計処理がボトルネックとなり、ネットワークのスループット低下につながる。

一方で、勾配の更新頻度を下げすぎると、勾配が同期されるまでは、各レプリカ上のローカルな勾配を用いて訓練が進み、勾配の「鮮度」が劣化する。これにより、レプリカ間のモデルパラメータの差異が大きくなり、損失の収束を遅らせるなどの悪影響をもたらすという問題が生じる。

実際の深層学習フレームワークにおいては、同期処理によるスループット低下を防ぐために、非同期でレプリカ毎に異なるタイミングで勾配を更新する手法を採用しているものもある。以下、同期、非同期それぞれの手法について解説する。

勾配の同期更新

f:id:serihiro:20180602215407p:plain

ニューラルネットワークにおいて、勾配は訓練におけるBackwardフェーズでレイヤー毎に計算される。そして求めた勾配に学習係数をかけた値を使ってモデルパラメータを更新するのが一般的な手法である。

勾配の同期更新においては、モデルパラメータを更新する前に、レプリカ間で同期を取り、各レプリカで求めた勾配の平均値(以下、平均勾配と呼称)を用いて全パラメータを更新する。つまり、各レプリカで計算した勾配をモデルパラメータの更新に用いるのではなく、平均勾配を用いて各レプリカのモデルパラメータを更新する。これにより、常に最新の平均勾配を用いて訓練を行うことができる。

このアプローチは、具体的なプロダクトとしてはChainerMNが採用している。バージョン1.3では、同期更新時に他の処理と一部オーバーラップさせることで同期による処理遅延を低下させる DOUBLE BUFFERING という機能が搭載された。この機能を使うと、一時的に古い勾配のまま訓練を行うワーカーが生じるが、精度には影響のないレベルであるとのことである。 *2

勾配の非同期更新

f:id:serihiro:20180602221811p:plain

非同期に更新を行う場合、各レプリカ間で同期を取らずに平均勾配を求めるため、勾配の集約と平均勾配の計算を行うための別のプロセスが必要である。かつてGoogle社内で使われていたDistBeliefという深層学習フレームワークにおいては、パラメーターサーバという専用プロセスがこの役割を果たしている。Googleのプロダクトとしては後発であるTensorFlowを分散実行させる場合においてもこのアプローチは同じようではある。

非同期更新においては、各レプリカが求めた勾配をパラメーターサーバに送信し、その勾配を用いてパラメータサーバが平均勾配を更新する。各レプリカは次のiterationの訓練を開始する前に、パラメーターサーバから最新の鮮度の高い勾配を取得し、それを用いて訓練を行う。

各レプリカは同期を取らずに平均勾配が更新されるため、レプリカ間で使用している勾配に差が生じている状態が発生する。このことにより、より良い勾配を古い勾配を使った訓練により生じたより劣悪な勾配によって、悪化させてしまう事も考えられる。 このような最新のパラメーターサーバーにある勾配と比べて古い勾配を「陳腐化した勾配(Stale gradient)」と呼ぶ。陳腐化した勾配の影響を緩和するには、学習率を下げる、陳腐化した勾配を定期的に捨てる、ミニバッチサイズを調整する、等が考えられる。*3

モデル並列

f:id:serihiro:20180602220457p:plain

レプリカを作らず、ネットワーク上のパラメータを複数のプロセスに分割し訓練を行う並列化手法。 これにより、例えば前結合ネットワークにおいて、巨大な行列となった重み行列と入力ベクトルとの計算を複数のマシンで分散して実行することで、要求されるマシンスペック要求の低下*4と、並列化による高速化が望める。

一方で、ネットワーク的に分断されたマシンクラスタで実現する場合、各プロセス間のデータ通信が大きなボトルネックになると考えられる。そのため、全結合ネットワークやcNNではスループットが大きく低下する可能性が考えられる。この手法を適用して高速化を行うには各プロセス上での高い密度を高くすることができるネットワークが適していると考えられる。例えば、TensorFlow: Large-Scale Machine Learning on Heterogeneous Distributed Systemsではモデル並列の実装例としてLSTMを挙げている。

自分が論文を調べた限りでは、Apache Spark上に全結合ネットワークやcNNをモデル並列で実際に実装している例*5や、レイヤーごとに複数のマシンに分割し、訓練処理をパイプライン化して並列効率を向上させる実装*6もあるが、まだ実例が少ないアプローチであると考えられる。

2018年6月時点では同期更新・非同期更新どちらが良いのか

様々な条件下で両者のアプローチを比較した2016年の論文Revisiting Distributed Synchronous SGDによると、同期更新の方が非同期更新よりも収束速度も正解精度も高いことを示している。また、訓練を実行するワーカー数を増加させた場合も同期更新の方が収束速度がより効率的にスケールすることを示している。近年の画像コンペにおいても上位入賞者チームにおいて同期型が支持されているという意見もある。*7

よって、非同期型に特化した新しい手法が開発されたら状況は変わる可能性はあるが、2018年6月現在では同期型の方を選択する方が正解である可能性が高い。例えば2017年9月にarXivに投稿されたImageNetの訓練に関する論文、ImageNet Training in Minutesでも、同期型データ並列アプローチを使っていることが明記されている。

まとめ

訓練精度という観点で見ると、データ並列・同期更新が現時点では最も高い精度を得つつ、高いパフォーマンスを発揮できるアプローチだと考えられる。しかし、メモリに乗り切らない大規模なパラメータを扱う場合や、マシンリソースの有効活用という観点では、モデル並列を活用できるシーンもあると考えられる。*8

実際、自分は大規模なパラメータを想定したネットワークにおいてモデル並列訓練を行った場合の性能特性について研究する予定であり、現在c++Blas系ライブラリとMPIで、モデル並列ネットワークを実装している最中である(作るのはとりあえず全結合ネットワークのみ)。自分の研究成果については追ってまた別のエントリで紹介したい。

参考文献

*1:例えばここ数年の画像コンペでかなり使われているcNNの1つであるResNetは、作者らの報告によるとImageNetで152層のネットワーク、CIPHER-10で1202層のネットワークを構築した報告がある

*2:https://chainer.org/general/2018/05/25/chainermn-v1-3.html

*3:https://www.oreilly.co.jp/books/9784873118345/

*4:例えば大規模な行列計算を行うノードとそうでないノードとでGPUの搭載数を変えたりすることで、スペックを要求されない部分には低スペックなマシンを採用することができる

*5:https://arxiv.org/abs/1708.05840

*6:https://arxiv.org/abs/1705.09786

*7:https://logmi.jp/285424

*8:完全に余談だが、過去に非同期処理を多用するマイクロサービスを開発してきた元ウェブアプリケーションエンジニアの感想として、DistBeleifに代表されるパラメータサーバを用いたアーキテクチャの方が馴染みがあるので作ってみたい気持ちにはなる。