Skip to article frontmatterSkip to article content
Site not loading correctly?

This may be due to an incorrect BASE_URL configuration. See the MyST Documentation for reference.

LSTM

RNNの学習においては勾配消失問題を解決するためには、RNNレイヤのアーキテクチャを根本から変える必要があります。

ここで登場するのは、シンプルのRNNを改良した「ゲート付きRNN」です。この「ゲート付きRNN」では多くのアーキテクチャが提案されており、その代表的な仕組みはLSTM(Long Short-Term Memory)になります。

1LSTMのインタフェース

1.1LSTMの全体像

RNNとLSTMレイヤのインタフェースの違いは、LSTMにはccという経路があることです。このccは記憶セルと呼ばれ、これを通じてネットワークを流れる情報の流れを制御します。セル状態は、ネットワークの一種の「記憶」であり、重要な情報を長期間にわたって保持する能力を持っています。これにより、LSTMは長期間にわたる依存関係を捉えることができます。

ctc_tには、時刻ttにおけるLSTMの記憶が格納されています。具体的に言えば、現在の記憶セルctc_tは、三つの入力(ct1,ht1,xt)(c_{t-1},h_{t-1},x_t)から何らかの計算によって求められています。そのため、これに過去から時刻ttまでにおいて必要な情報が全て格納されていると考えられます。

必要な情報が詰まった記憶を元に、外部のレイヤへ隠れ状態hth_tを出力します。

1.2ゲート機構

ゲートは、セル状態に流れる情報を制御するために使用される仕組みです。

ゲートの「開き具合」は0.01.00.0 - 1.0までの実数で表されます。そしてその数値によって、必要な情報を保持し、不要な情報を排除し、適切な時に適切な情報を出力することができます。

ここで重要なのは、ゲートの「開き具合」ということは、データから自動的に学ばせるということです。

2LSTMの構造

2.1forgetゲート

LSTMのforgetゲートでは、長期記憶から不要な情報を忘却するための制御を行っている。

ここで、forgetゲートで行う一連の計算をσ\sigmaで表すことにします。計算は次の式で表されます。σ\sigmaはsigmoid関数を表します。

ft=σ(xtWx(f)+ht1Wh(f)+b(f))f_t = \sigma(x_tW_x^{(f)}+h_{t-1}W_h^{(f)} + b^{(f)})
  • 入力

    • 現在の入力xtx_t

    • 前の時点の隠れ層の出力ht1h_{t-1}

  • 重みとバイアス

    • 現在の入力xtx_tに適用されるゲートの重み行列Wx(f)W_x^{(f)}

    • 前の時点の隠れ層の出力ht1h_{t-1}に適用されるゲートの重み行列Wh(f)W_h^{(f)}

み付けされた入力とバイアスの合計にシグモイド関数によって計算されるため、0 から 1 の値をとります。

2.2inputゲート

新しい情報を追加する際、何も考えずに追加するのではなく、追加する情報としてどれだけ価値があるかを判断する上で、追加する情報を選択します。これにより、長期間にわたる依存関係をより効果的に管理し、複雑なシーケンスデータを扱うことができるようになります。

具体的には、inputゲートによって重みつけされた情報が新たに追加されることになります。

it=σ(xtWx(i)+ht1Wh(i)+b(i))i_t = \sigma(x_tW_x^{(i)}+h_{t-1}W_h^{(i)} + b^{(i)})

2.3新しい記憶セル

LSTMではセルの長期記憶を保つための変数ctc_tが用意されています。長期記憶ctc_tに対して、古くなった記憶を削除したり、新しい情報を新規追加したりすることで、適当な長期記憶を可能にしています。

具体的には、

  • 入力ゲートの計算

    • 現在の入力xtx_tと前の時点の隠れ状態ht1h_{t-1}から、inputゲートiti_tがシグモイド関数を用いて計算されます。

    • 同時に、tanh関数を用いて新しい候補セル状態c~t\tilde{c}_tが生成されます。

g=tanh(xtWx(g)+ht1Wh(g)+b(g))g = tanh(x_tW_x^{(g)}+h_{t-1}W_h^{(g)} + b^{(g)})
  • セル状態の更新

    • inputゲートiti_tと新しい候補セル状態c~t\tilde{c}_tがアダマール積によって組み合わされます。

    • forgetゲートftf_tを用いて、前のセル状態ct1c_{t-1}が更新されます。

  • 最終的なセル状態

    • 更新された前のセル状態と新しく生成されたセル状態が加算され、新しいセル状態ctc_tが生成されます。

2.4outputゲート

outputゲートは隠れ状態hth_tの形成に使用されます。

現在の記憶セルctc_tは、(ct1,ht1,xt)(c_{t-1},h_{t-1},x_t)を入力として求められます。そして、更新されたctc_tを使って、隠れ状態のhth_t計算されます。ここで、tanh(ct)tanh(c_t)の各要素に対して、「それらが次時刻の隠れ状態としてどれだけ重要か」ということを調整します。

なお、outputゲートの開き具合は、入力xtx_tと前の状態ht1h_{t-1}から求めます。

ot=σ(xtWx(o)+ht1Wh(o)+b(o))o_t = \sigma(x_tW_x^{(o)}+h_{t-1}W_h^{(o)} + b^{(o)})

sigmoidsigmoid関数による出力とtanhtanh関数によるセル状態の出力を掛け合わせ、新しい隠れ状態hth_tを生成します。

ht=ottanh(ct)h_t = o_t \odot \tanh(c_t)

3発展的なLSTM

3.1多層LSTM

多層LSTM(Multi-layer LSTM, Stacked LSTMとも呼ばれます)は、複数のリカレント層を積み重ねたニューラルネットワークの構造です。各層は独自の隠れ状態を持ち、前の層からの出力を次の層の入力として受け取ります。

各層が異なるレベルの特徴を学習できるため、多層RNNは単層RNNよりも複雑なパターンを捉えることができます。

fishy

3.2双方向LSTM(Bidirectional LSTM)

通常のLSTMは、時系列データを順方向に学習していますので、後ろにある単語やフレーズの情報を取り込むことができません。そのため、過去の情報を活用して現在の出力を決定するのには有効ですが、「未来」の情報を考慮に入れることができません。

テキスト処理のタスクでは、文の意味を完全に理解するためには、未来の情報が過去の情報と同じくらい、またはそれ以上に重要な場合があります。

Apple is something that I like to eat.

双方向LSTMはこれらの欠点を補うために設計されており,シーケンスデータの処理において、時間の前後の両方の方向から情報を捉えるために設計されます。

具体的は、双方向LSTMは、シーケンスデータを順方向(前から後ろへ)と逆方向(後ろから前へ)の両方で処理する二つのLSTM層から構成されます。

  • 順方向のLSTM: 一方のLSTM層がシーケンスを通常の時間の流れに沿って処理し、各時点での隠れ状態を更新します。

  • 逆方向のLSTM: もう一方のLSTM層がシーケンスを逆順に処理し、別の隠れ状態シーケンスを生成します

  • 出力の結合: 二つの層の出力(隠れ状態)は、各時点で結合されて最終的な出力を形成します。