Tensorflow2.0でResNet1D, TimeDistributed ResNet2D, ResNet3Dを作る

この記事ではtf.kerasを使ってResNet(2D)や、ビデオデータに適用できるTimeDistributed ResNet(2D), ResNet3Dをどうやって構築するかを解説します。

tensorflow2.0がリリースされましたね。2.x系がリリースされてからtf.kerasとtf.dataを使って機械学習をしてきましたが、たまにResNet3Dを使いたいのに、ググってもパッとgithubが出てこないので、自分で作るついでに解説記事を書いてみました。この記事のコードや結果はGithubに上げておきますので、必要な方は自由に活用してください。

ResNetのおさらい

ResNetの論文: https://arxiv.org/pdf/1512.03385.pdf

ResNetはSkip Connectionを用いることで勾配消失を防ぎ、ImageNetで152層という当時では異例の深さのネットワークでSOTAを更新しました。ResNetでは5段階のStage(下図のconv1~conv5_x)があり、それぞれのStage内で複数のConvolution層とSkip connectionを有したResidual Blockが存在しています。

例えば、一番左の”18-layer”と書いてある水準は”ResNet18”と呼ばれ、各2層のConvとSkip connectionをもつResidual Blockを各Stageに2つずつ配置しています。

さらなる詳細は、他に良質な記事があるためこの記事では触れません。

tf.kerasを使ったネットワーク構築方法

tf.kerasでは、tf.keras.Modelを継承してモデルを作ることができます。

__init__関数でパラメータを定義し、call関数でネットワークを定義します。

__init___内部に学習パラメータを入れなければならないことに注意しましょう。call関数に入れてしまうと、パラメータとして認識されず、学習されなくなってしまいます。

tf.kerasを使ってResNetを作る

では、このtf.keras.Modelを活用してResNetを作っていきます。ここではtf.keras.Modelを使って各ResidualBlockを作り、さらにtf.keras.Modelを使ってStage毎にそれをつなげていく戦略をとります。

ここではResNet1D, TimeDistributed ResNet2D, ResNet2D, ResNet3Dを作りたいので、色々なConvolution layerを使わなければなりません。別々のclassを作りたくないので、どのConvolution layerを使うかをmodeというパラメータで決める関数を作り、同じclass内で構築できるようにします。

ついでに1st stageにMaxPoolingがあるので、そこも同様に実装します。

tf.kerasを使って、ResBlockを作る

最初にResNet18~ResNet152で共通するsize=7のConvolutionとMaxPoolingから成るBlockを作ります。上述するように、レイヤーの定義は__init__関数、ネットワークの定義(伝播の順序)はcall関数で実施していることに注意してください

次に、channel数と解像度がIn-Outで変化しないIdentity Blockを作ります。ここではResNet18、ResNet34は2層のConvolution+Skip connectionのResblock, ResNet50以降は3層のConvolution+Skip connectionのResblockになっています。

先ほど作ったdefine_ConvLayer関数のおかげで、1D,2D,3Dをまとめることができました。また、この記事では詳細に触れませんが、NormalizationとしてBatchNormalizationとGroupNormalizationをdefine_Norm関数を使って選択しています。

次にchannel数と解像度がIn-outで変化するBlockを作ります。ResNetではMaxPoolingによる解像度圧縮は最初のレイヤー以外では実施せず、Convolution Layerのstride幅を2にすることで実現しています。また、In-Outでチャネル数が変化するためfilter_size=1,stride=2のConvolutionをskip connection側にかけることによりチャネル数変化を実現しています。

最後にResNet全てで共通するGlobalAveragePoolingから各クラスの確率を出力するstageを構築します。kerasのResNet50に倣って、modelをGlobalAveragePoolingで止めるinclude_top、出力クラス数を変えるnum_classesを導入しています。

tf.kerasを使ってResBlockを繋げる

上で作ったConvBlockをつなげていきます。ここでもLayerとして定義するときは__init__関数内、伝播の順序定義はcall関数内というルールを守ります。

ここで作ったResnetBuilderを使うと下図のようにResNetを構築することができます。下の例は1DのResNet18です。他の例はGithubに載せていますので興味がある方はご覧ください。

CIFAR10を学習させる

せっかくResNetを作ったので、CIFAR10で学習させてみます。tensorflowには、様々なデータ操作・データ拡張ができるtf.dataという協力なツールが用意されています。tf.kerasのfitメソッドにも簡単に適用できるので、本記事ではtf.dataとfitを使ってCIFAR10を学習させます。

tf.dataを使ってデータ拡張をする

まず、CIFAR10をダウンロードし、one_hot化等の整形を加えます

データ拡張として「左右のランダム反転」と「ランダムな切り抜き」を適用することを考えます。まず「左右のランダム反転」ですが、tf.data.imageにそれを行う専用の関数があるのでそれを使います。map関数を使って適用します。

ランダム切り抜きですが、こちらは「(i)ずらし幅に従ってpaddingする (ii)ランダムに切り抜きをする」という手順を踏みます。ただ単に切り取りをすると画像サイズが小さくなってしまうので、最初にpaddingが必要です。これも関数を作ってmap関数を適用することで適用できます。

どのようになるのか、可視化してみましょう。tensorflow2.0からEager Extensionがデフォルトになったため、tf.dataによるデータ出力確認が非常に簡単になりました。今まではイチイチsess.runをしなければならなかったことを考えると大きな進歩です。tf2.0では作ったdatasetをiterator化でき、nextを使うとデータが呼び出せるので非常に簡単です。(効果を見やすくためshift_range=0.5に設定しています)

tf.dataとfitメソッドを使って学習させる

先ほど作った関数を使って、学習用・テスト用のデータセットを作ります。Input dataとTarget dataをまとめるときはtf.data.Dataset.zipを使いましょう。fitメソッドに適用させるためには、データを最後まで呼び出した後に再び最初から繰り返す.repeat()の適用が必要です。また、.batch()を使えば任意のバッチサイズでミニバッチ学習ができます。

ちなみに.batch()を2回適用させると、下記のようになります。特定のシーケンス長をもつビデオデータを作るときに便利です

作ったデータセットは直接fitメソッドを使って適用することができます。ResNet18を使って学習させてみます。

学習させると下記のようになります。1epoch 40s強で学習できています。最高valid accは82.04%でした。

tf.dataの処理を高速化する方法

このままでもいいですが、tf.dataの処理を高速化してみましょう。

tf.dataのmap関数は並列化できます。具体的には、map関数でnum_parallel_calls=tf.data.experimental.AUTOTUNEとすることで適用できます。何をしているかの詳細はこのURLを参照してください。map関数の適用を並列化することで読み込みの待ち時間を大幅に短縮することができます。

適用してみると、1epochを30秒以下に短縮することができました。25%以上の高速化です。

まとめ

この記事では、tf.keras.Modelを継承してResNetを構築するという記事を書きました。上記の結果、コードは全て私のGithubに上げてあります。

tf.kerasもかなり書きやすくなり、様々なデータ拡張がしやすいtf.dataと組み合わせることでtensorflowもかなり書きやすくなったのではないかと思います。

もしこの記事で間違っている点等ありましたら、遠慮なくご指摘ください。

Data Scientist (Engineer) in Japan Twitter : https://twitter.com/AkiraTOSEI LinkedIn : https://www.linkedin.com/mwlite/in/亮宏-藤井-999868122