sagantaf

メモレベルの技術記事を書くブログ。

RNNとLSTMを理解する



この記事の目的

RNN, LSTMの理論を理解し、Kerasで実装できるようにするために、理論部分をまとめた記事。



0. 通常のNeural NetworkやConvolutional Neural Networkの問題

これまでNNやCNNは入力サイズが固定だった。そのため、毎回同じ入力をするしかなく、時系列情報を入力させることができなかった。

RNN・LSTMは、入力を可変にして過去の出力結果も学習に利用できるようにしたニューラルネットワークのモデルである。



1. RNN (Recurrent Neural Network)

RNNの基本的な構造は、「現在入力値と前回の出力値を合計し、活性化関数としてtanhを適用して出力とする」という流れ。

f:id:sagantaf:20190603233207p:plain
RNNの基本的な構造

上図でいうと、現在の入力がx_tとなり、前回の出力値がh_{t-1}となる。

f:id:sagantaf:20190603233243p:plain
1つのセルの中身

隠れ層では活性化関数としてtanhが一般的に選ばれる。

数式で表すと、
 h_t = tanh(x_t + h_{t-1})
となる。

tanhx_th_{t-1}の和を-1〜+1の範囲に正規化する。

tanhが選ばれる理由は、2次微分の減衰がかなりゆっくりとゼロになり勾配消失問題が起きにくいため。



2. 勾配消失(爆発)問題

過去データに対する重みが発散、消失する問題のこと。下記サイトの解説が分かりやすい。

ニューラルネットワークと深層学習

発散・消失する理由は、最初の方のレイヤーの逆伝播計算時に、何度も重みwとシグモイドの微分を掛け合わせるからである。

シグモイド関数微分は0~0.25の範囲である。そのため、重みwが1以下だったら、最大でも w*σ'(z)= 1 * 0.25 = 0.25 となるので、0.25を繰り返し乗算することで、勾配がほぼ0になってしまう。これが勾配消失と呼ばれる理由。

また、重みwが0.25をかけたらかなり大きくなる値(100とか)の場合、0 ~ 0.25 × 100 = 0 ~ 25 の範囲になるため、最悪25×25×25×…となってしまい、勾配が発散してしまう、勾配爆発問題が起きる。(これはバイアスでも同じ話)

このように最初のレイヤーの重みとバイアスがかなり小さくなってしまうことにより、新たな入力値を学習しようとしてもかなり小さくなってしまうため、モデルの学習に寄与しなくなってしまう。

RNNでも数十ステップの短期依存(short-term dependencies)には対応できる。しかし、1000ステップのような長期の系列になると、tanhを使って上記の問題を和らげるとはいえ、無視できないくらい勾配が小さく(もしくは大きく)なってしまい、上述の勾配消失(爆発)問題が発生する。



3. LSTM (Long-short term model)

この対処法としては下記のようにいくつかある。

  • 行列Wを適切に初期化する
  • tanhではなくReLUを使用する
  • 教師なしで層を事前学習する
  • LSTM, GRUなどを利用する

中でも4番目に記載した「LSTM」では過去のデータを保持する仕組みを使って、勾配消失問題を回避している。

LSTMは過去のデータをsigmoidやtanhではなく「線形和」で保持するため、逆伝播しても勾配が極端に大きくなったり小さくなったりしないため、勾配消失問題が発生しない。

LSTMの様々な構成パターンが存在するが、以下では基本的な構成要素をまとめた。

  • CEC:過去のデータを保存するためのユニット
  • 入力ゲート:「前のユニット(1つ前の時間のユニット)の入力をどの程度受け取るか」を調整するためのゲート
  • 出力ゲート:「前のユニット(1つ前の時間のユニット)の出力をどの程度受け取るか」を調整するためのゲート
  • 忘却ゲート:「過去の情報が入っているCECの中身をどの程度残すか」を調整するためのゲート

f:id:sagantaf:20190604220957p:plain
LSTMの構成要素

基本的な流れは緑の線であり、入力x_ttanhによって活性化され出力h_tとなることは変わらない。そこにCEC C_tと各種ゲートが追加された構成になっている。

図の上の部分C_tが長期記憶(Long)を担い、下の部分h_tが短期記憶(Short)を担う。これがLong-Short Time Modelと言われる所以となっている。

入力ゲートと出力ゲートはなんのために用意されたか?

RNNでは、入力重み衝突、出力重み衝突という問題が内在していた。

これは、学習過程で新たな入力・出力がきた時に、

  • 前と同じパターンが来た
     →今回のパターンにさらに適合するように重みW1を大きくしてW2を小さくしよう
  • 前と違うパターンが来た
     →新しいパターンに適合するようにW1を小さくしてW2を大きくしよう

という動きが繰り返され、重みWが上下してなかなか精度が上がらないという問題である。

この問題に対処するために、入力ゲート・出力ゲートを用意することで、

  • 前と同じパターンが来た
     →入力ゲートW_{in}は前に来たパターンを通すように適合しているので、そのままそのパターンを通そう
     →そして、今回のパターンにさらに適合するようにW1を大きくしてW2を小さくしよう
  • 前と違うパターンが来た
     →入力ゲートW_{in}は違うパターンを通すような内容になっていないので、通さない
     →変化0なので、W1、W2に変化はない

という動きを実現できる。

ゲートとして通す・通さない仕組みは、重みとしてほぼ0か1への活性化ができるSigmoid関数を利用することで実現している。 Tanh関数の場合0か1ではなく、-1〜1になってしまうためゲートとしての役割を担えない。

下記のサイトで詳しく解説されている。
わかるLSTM ~ 最近の動向と共に - Qiita

Long Short-term Memory

忘却ゲートはなんのために用意されたか?

忘却ゲートがないモデルの場合、大きな状況の変化に対応できないという欠点が存在する。

CECは線形和で過去のすべての入力を重み1で渡しつづけるため、ある時に大きな変化のある入力が来たとしても、相対的にその入力の影響は小さくなり、今まで通りの結果しか出力されなくなってしまう。

この問題に対処するためにSigmoid関数による忘却ゲートを導入し、入力系列のパターンががらりと変わったとき、セルの状態を一気に更新することを実現した。

例えば、現在CECの中身が [0,0,0,10,10,10,0,0,0] であり、今回の入力が [100,100,0,0,0,0,0,100,100] であった場合、忘却ゲートの値は [1,1,0,0,0,0,0,1,1] となり、CECは [0,0,0,0,0,0,0,0,0] になる。つまり過去の長期記憶がきちんと消えている。

図にすると以下のようなイメージ。

f:id:sagantaf:20190604223730p:plain
忘却ゲートの前後

そのほか

入力はtanhによって活性化されているが、ReLUに変更しても対して精度は変わらないことがわかっている。(RNNにReLUを適用するのは効果的)

LSTMモデルに対してドロップアウトを適用する場合は、non-recurrentな接続にのみ適用する必要がある。 つまり入力ゲートなどの再帰的な処理ではない部分に適用することになる。

CECh_{t-1}ドロップアウトを適用した場合、過去のデータをランダムに落としてしまうことになるため、学習がうまくできなくなってしまう。



最後に

今回RNNとLSTMについてまとめてみて、理解するのに時間がかかると感じた。

いまだに研究されている領域なので、何が最適かわかっていない部分も多々あるらしい。

そのため完璧に理論を理解しようとすることは避け、業務で使えるレベルの知識に留めておき、あとは目の前のデータに合う形で試行錯誤を繰り返すことが最短ルートになると思われる。

また、画像データに対する深層学習のアルゴリズムであるConvolutional Neural Network(CNN)のまとめは下記。

sagantaf.hatenablog.com



参考にした書籍やサイト

深層学習 (機械学習プロフェッショナルシリーズ)

深層学習 (機械学習プロフェッショナルシリーズ)

詳解 ディープラーニング ~TensorFlow・Kerasによる時系列データ処理~

詳解 ディープラーニング ~TensorFlow・Kerasによる時系列データ処理~

直感 Deep Learning ―Python×Kerasでアイデアを形にするレシピ

直感 Deep Learning ―Python×Kerasでアイデアを形にするレシピ

TensorFlowではじめるDeepLearning実装入門 (impress top gear)

TensorFlowではじめるDeepLearning実装入門 (impress top gear)

PythonとKerasによるディープラーニング

PythonとKerasによるディープラーニング

わかるLSTM ~ 最近の動向と共に - Qiita

わかるLSTM ~ 最近の動向と共に - Qiita

再帰型ニューラルネット in 機械学習プロフェッショナルシリーズ輪読会

LSTM (Long short-term memory) 概要

LSTMネットワークの概要 - Qiita

http://isw3.naist.jp/~neubig/student/2015/seitaro-s/161025neuralnet_study_LSTM.pdf

LSTM 〜Long Short-Term Memory〜(Vol.18)

ニューラルネットワークと深層学習