大きなミニバッチの効力と、Tensorflowを使って大きなミニバッチを学習させる方法(Gradient Accumulation)

この記事について

この記事では、大きなミニバッチで学習さえることの効力と、Tensorflow2.0を用いてメモリに乗り切らない大きなミニバッチを学習させる方法を紹介します。この記事に書いてあることは以下の通りです。

  • Tensorflow2.0を用いた標準的なモデル学習の方法

今回使った実験のコードはGithubにアップロードしてます。

Tensorflow2.0の標準的なモデル学習の方法

tensorflow2.0では、kerasのfitを用いた学習のほかに、gradient tapeを用いた学習方法がチュートリアルで提案されています。以下のコードはチュートリアル[1]からの引用です。

Gradient Tape内で行われた計算は勾配を計算できるように記録され、tape.gradientを用いることによって勾配を計算できます。

大きなミニバッチで学習させることが効果的な場面

当然ですが、この形式ではメモリに乗り切らない大きなミニバッチで学習させることができません。しかし、大きなミニバッチを使ってすることが効果的な場面があります。例えば以下の2場面です。

  1. GPUのメモリが小さいせいで、大きなモデルを乗せると小さなミニバッチしか乗らない

1. GPUのメモリが小さいせいで、大きなモデルを乗せると小さなミニバッチしか乗らない

V100等の高性能かつ大容量メモリを持ったGPUを常時使えればいいのですが、現実は厳しいです。特に大きな画像はメモリ占有量が大きいので、DenseNet201で学習させると256x256の画像が4つしかメモリに乗らない、といった事態がありえます。また、ResNetやDenseNetのようにBatchNormalizationが使われている場合、小さいミニバッチで学習させると性能が低下するという報告[2]があります。

[2]より引用。batch sizeが小さくなるほど性能が低下しているのがわかる(左)。一方この論文で提案されているGroupNormalizationを用いると性能が低下していない。

この論文で提案されているGroup Normalizationを使ってもいいのですが、keras.applicationsで使えるResNet50等を使えなくなるので、自分で実装することになって面倒です。

2. より良いモデルを得るため、学習率を下げる代わりにバッチサイズを大きくしたい

ある時点で学習率を下げてより良いモデルを得る方法がありますが、代わりにバッチサイズを上げて同様の効果を得るという方法[3]もあります。この手法では指数的にバッチサイズを大きくするので、最初の段階ではメモリに乗っても最後はメモリに乗らないという状況が起こり得ます。

[3]より引用。学習率を下げる代わりにバッチサイズを下げても同様の効果得られることがわかる。

また、自然言語処理で頻繁に用いられるTransformerモデルでは、バッチサイズを上げた方がより良い結果が得られたという研究結果[4]もあり、Learning Rateを下げていく手法では効果がなかったという研究結果[5]もあります。

[4]より引用。ミニバッチサイズが大きいとスコアが高くなる。
[5]より引用。学習率を下げても精度はほとんど上がらないず、むしろ下がる。

また、Generative Adversarial Networksにおいても大きなバッチサイズを用いた方が生成画像の質が良くなったという報告[6]がなされています。

[6]より引用。バッチサイズが生成画像の質向上に効いていることがわかる。

大きなミニバッチで学習させる方法

大きなミニバッチを使った方が良い場面があることはご理解頂いたのではないかと思います。では、Tensorflowを使って大きなミニバッチサイズで学習させるためにはどうすれば良いのでしょうか?答えの1つとして、”小さなミニバッチ毎の勾配の和を保存しておき、その勾配の和を使ってパラメータ更新をすることにより、擬似的に大きなミニバッチサイズの学習をさせる”という戦略が取れます。

まず、下記のようなResNetを学習させるコードを考えてみます。データはCIFAR10を使っており、コード全体はGithubに上げてあります。また、ResNetは以前ブログで紹介したものを使っていますので、よろしければこちらもご覧ください。

コードは100行ほどありますが、注目して頂きたいのは37~43行目のtrain_step関数とそれを内包した46~48行目のminibatch_training関数、そしてミニバッチ学習を行なっている76行目周辺です。このクラスでは、tensorflowのチュートリアル通り、1つのミニバッチで勾配を取る形式になっています。

このコードで様々なミニバッチサイズで学習させてみると、GroupNormalizationの提案論文で書かれているように、ミニバッチサイズ1で大きく性能が落ちていることがわかります。

大きなミニバッチで学習させるために、このクラスを継承してminibatch_training関数を下記のように改良します。これを使うと、仮想的にbatch_size * batch_accumulate_numのサイズのミニバッチで学習させることができます。コアな部分は45~67行目のaccumulate_train_step関数で、下記のような処理をしています。

  1. ミニバッチの大きさがbach_sizeのデータを使って、train_step関数でgradientを計算する(52行目)

以上の処理をすることで、仮想的に大きなミニバッチで学習したときと同じような学習ができます。実際にbatch_size=1でbatch_accumulate_num(勾配を蓄積するバッチの数)=32,64,128で学習させてみた実験結果を下記に記します。それぞれ仮想的にbatch_size=32,64,128で学習させられていることがわかります。実線が実際のミニバッチの大きさを変えた水準、点線がbatch_size=1で、batch_accumulate_numを変えることで、仮想的なミニバッチ学習を実施した水準です。

まとめ

本記事では、バッチサイズを上げた方がいい場面の紹介と、tensorflow(2.0)でメモリに乗り切らないバッチサイズの学習をどう実現するかを紹介しました。特にTransformerモデルでは、学習率を定期的に下げる代わりにバッチサイズを大きくすることが効果的なようなので、これを機に試してみてはいかがでしょうか。

Twitter。論文の一言紹介とかやってます。

Reference

  1. https://www.tensorflow.org/tutorials/quickstart/advanced

Data Scientist (Engineer) in Japan Twitter : https://twitter.com/AkiraTOSEI LinkedIn : https://www.linkedin.com/mwlite/in/亮宏-藤井-999868122