【脱初心者】LSTMの数式とあの図を攻略!超絶分かりやすい図解でステップ解説【E資格合格者のノート公開】

左半分が薄ピンク、右半分が濃いピンク色に分かれた背景。左にはAIを表すアイコン、右には「Emma's note~AI資格編~」の白文字と、吹き出しの中に書かれた「イメージを掴む」「脱初心者」の文字。勉強法

こんにちは、E資格保持者のエマです。

本記事は、

「LSTMがRNNの進化系ってことくらいは知ってるけど、技術的な細かいところは分からない。。。」

という方を対象に、例のあの図の意味を数式ごと理解しちゃおう!という趣旨の解説記事となっています。

LSTMと言えば、この図ですよね↓

LSTMの回路図

いきなり見ると「なんじゃこりゃ」ですが、落ち着いてブロックごとに一つずつ追っていくと、実は超カンタンで、しかもオモシロさがいっぱい詰まったものなんです。

このページの最後の方まで読んだ頃には、上の図をすべて説明できるようになっているはずですよ。

それでは、解説していきます!

本記事は、非エンジニア・AI初心者の私がDeep Learning for ENGINEER 2022#1を受験する際にまとめた勉強ノート「エマノート(Emma’s note)」の一部を抜粋しつつ、技術や専門用語を解説する記事です。

LSTM(Long Short Term Memory)の概要

まずは、LSTM登場の背景や基本アイデアをおさらいします。

なぜLSTMというものが誕生たのか、ここを踏まえた上で先に進んだ方が、数式や図を理解するときに何倍も差があります。

既に知っているという方も、復習だと思ってサラっと見てください。

【背景】従来RNNの課題とLSTM開発の動機

左ページにRNNの課題、右ページにLSTM開発の動機を表すイメージ図が描かれた見開きノート

従来のSimple RNN(Recurrent Neural Network)の欠点は、「長期依存性」。

くだけた表現で言えば「記憶力が悪い」こと。

初めの方でインプットされた情報が、時刻が進むにつれて薄められてしまいます。

これは、xt(ある時刻の入力ベクトル)とht-1(前時刻の隠れ状態ベクトル)をtanhで整えただけのシンプルな構成が原因です。

簡単なのは良いことですが、短期記憶しかできず、低機能です。

そこで、何とか短期記憶を伸ばして記憶力UPできないか(長期記憶を形成できないか)と考えられたのかLSTMになります。

勘違いされやすいので補足しておくと、LSTMというのは、短期記憶を上手いこと伸ばして長期記憶を可能にする、つまり「長い短期記憶(LongなSTM)」という意味です。

【解決策】根本アイデアは“うなぎのタレ”

左ページに「ずっと繋がった回路」のイメージ図とうなぎのタレの絵が描かれ、右ページに忘れる+覚えるの重要性を示すイメージ図が描かれた見開きノート

LSTMをつくるモチベーションは記憶力UPでした。

そのために、「初めから終わりまでずっと繋がったような『長期記憶を担ってくれる回路』を導入しよう!」というのがLSTMのベースとなるアイデアです。

後述しますが、この回路こそが「記憶セル」になります。

私個人的にはうなぎのタレのイメージを持っています。

従来のRNNは、一般家庭のように毎回タレを作り直しているカンジ、LSTMは、老舗うなぎ屋のように何年も前の伝統の味が「継ぎ足し継ぎ足し」で確かに後世へと受け継がれていく感じです。

記憶を長持ちさせるためのとても大切な工夫が”忘れる&覚える”のコントロールです。

覚えさせるだけでは、メモリがパンクしてしまいます。
(うなぎのタレも、足していくだけでは溢れます)

なので、ぜんぶを覚えようとはせず、適度に忘れて、加えるべき必要な情報のみ選んで、を繰り返すわけです。

イイカンジに忘れ、イイカンジに覚える。

これはまさしく人間の普段の脳の働きのように思えます。

背景はざっくりこんな感じです。

LSTMの数式と図の意味を理解する

いよいよ、数式とあの図の説明へ移ります。

繰り返しになりますが、難しそうに見えて一つ一つ分解して考えると実はそうでもないので、リラックスしてオーケーです。

  • まずは全体像を把握
  • それから一つ一つ細かく分けて見ていく

の流れで説明していきます。

【全体像】6つの数式と4つの構成部品

左ページに数式、右ページにLSTMユニットの概略図が描かれた見開きノート

LSTMは、6つの数式で記述され、特徴的な4つの部品から構成されています。

上図の左6つが、LSTMという存在を表す全数式で、逆にいえば、これさえ理解すればLSTMの全てが分かります。

LSTMというユニットを外から見たものが上図右の簡易的な図です。

LSTMということばを、RNNと同じように「ネットワーク」の名前として使う人がいますが、本来はネットワークを構成する一つの「ユニット」を指します。

Cという線が増えているだけで、シンプルなRNNユニットとそこまで変わりません。

このCこそが記憶セル」と呼ばれる長期記憶の役割を持つ部分で、LSTMのカギとなります。

ちなみに隠れ状態hは、LSTMにおいて短期記憶の役割を担うことになります(あとで説明しますね)。

記憶セルの他に、3つの「ゲート」と呼ばれる部品がありますが、これらは外からは見えません。

しかし、これら「ゲートこそが記憶を操る司令官たちなんです(こちらもあとで分かります)。

ここまで全体像を説明しましたので、ここから詳細に移ります。

LSTMがどのように動くのかを分かりやすく説明するために、以下の4つの工程に分割して解説していくことにします。

★LSTMが動作する手順★
  • STEP 1
    ゲートの準備
  • STEP 2
    忘却
  • STEP 3
    入力
  • STEP 4
    出力

それでは、順番に見ていきましょう。

【ゲート】記憶を制御する3つの門を準備する

入力を受けたLSTMユニットがまず初めにやること、それは3つのゲートの準備です。

これは水門のようなもので、0なら全く通さず、1なら全て通す、そういうものです。

何のため?と思うかもしれませんが、これは後でハッキリ理解できますので、今はとりあえず水門のような働きをする部品を作るんだな、くらいの認識でだいじょうぶです。

3つのゲートの概要を書きます↓

1. 忘却ゲート(forget gate):旧情報の忘れ具合を制御

左ページに複数の数式と一部ハイライトされた数式、右ページにLSTMの回路図全体と一部ハイライトされた回路が描かれた見開きノート

一つ目が「忘却ゲート」とよばれるもので、「古い情報をどれくらい忘れさせるか?」をコントロールする働きがあります。

ゲートが0であれば「閉じる=忘れる」、1であれば「開ける=覚える」を意味します。

ゲートの値はxt(その時刻の入力)とht-1(前時刻の隠れ状態)を材料として決めます。

数式に少し踏み込むと、xtベクトルに重みをかけ、ht-1ベクトルに重みをかけ、バイアスを加えたものに対して、シグモイド関数をとります。
(シグモイド関数をつかうのは、最終的にゲートの値を0~1にしたいからです)

こうして0~1の値を持つFt(forget gate vectorとよぶ)ができます。

2. 入力ゲート(input gate):新情報の覚え具合を制御

左ページに複数の数式と一部ハイライトされた数式、右ページにLSTMの回路図全体と一部ハイライトされた回路が描かれた見開きノート

二つ目が「入力ゲート」とよばれるもので、「新しい情報をどれくらい覚えさせるか?」をコントロールする働きがあります。

忘却ゲートと同じく、値が0であれば「閉じる=忘れる」、1であれば「開ける=覚える」を意味します。

数式の意味も忘却ゲートの場合と同様で、0~1の値を持つIt(input gate vectorとよぶ)ができます。

ただし、重みやバイアスはゲートごとに特有のもの(forget gateであれば下付き文字がFWF、input gateであれば下付き文字がIWIなど)なので、もちろん忘却ゲートと入力ゲートの値は異なります。

例えば、推理小説の文章データを入力として、その小説を要約するというタスクを考えた場合、分割して「このt=1/2/3/ある4/…」と入力していくわけですが、小説の後半で「犯人t-3/t-2/○○t-1/t」というパンチラインが出てきたら重要度爆上がりですよね。

時刻のとき、xt = 「」 が入力され、ht-1には「犯人は○○」くらいの短期記憶が埋め込まれているわけで、xtht-1から判断すると、それまでの古い情報よりも圧倒的に最近の情報が重要なので、忘却ゲートFt0に近く、入力ゲートIt1に近くなるでしょう。

3. 出力ゲート(output gate):抜き出し具合を制御

左ページに複数の数式と一部ハイライトされた数式、右ページにLSTMの回路図全体と一部ハイライトされた回路が描かれた見開きノート

三つ目が「出力ゲート」とよばれるもので、「長期記憶の中からどれくらい情報を抜き出すか?」をコントロールする働きを持ちます。

何を抜き出すかと言うと、短期記憶として覚えておくべき部分です。

前二つのゲートは、記憶セルに対して作用(TO memory cell)しており、長期記憶の形成に関わっていました。

一方でこの出力ゲートは、記憶セルからの方向に作用(FROM memory cell)し、実をいうと短期記憶の形成に関わることになります(解説まであともう少しです!)。

値が0であれば「閉じる=抜き出さない」、1であれば「開ける=抜き出す」を意味します。

役目はちがいますが、同じゲートなので数式の意味はほぼ同じで、結果的に0~1の値を持つOt(output gate vectorとよぶ)ができます。

ここまでゲートの形成に関わる部分の数式と図を説明しました。

まだ面白さに欠けると思いますが、ここからはついに記憶の形成に関わる部分を見ていきます。

【忘却】古い情報から要らない情報を“忘れる”

左ページに複数の数式と一部ハイライトされた数式、右ページにLSTMの回路図全体と一部ハイライトされた回路が描かれた見開きノート

さきほど決められた忘却ゲートFtの値(0~1)が、前時刻から流れてきたCt-1にかけられます。

つまり、古い情報が適度に忘れ去られます

上図の右下に描かれたベルトコンベアがイメージしやすいと思います。

記憶セルは前述の通りずっと繋がっていて、ベルトコンベアのように情報が流れていると考えられます。

忘却ゲートによって、一部の情報が忘れ去られ、次の「入力」工程へと流れていくわけです。

Ftの値がCt-1にかけられる」と説明しましたが、かけ算はかけ算でも、ここでは要素ごとの掛け算を意味する「アダマール積」です。

まるの中にポチっと点がある記号がアダマール積を表します。
AIを勉強していると頻出ですね。

要素ごとのかけ算の意義を理解するために下のイメージ図を見てください。

右上に「アダマール積のイメージ」というタイトルが書かれ、文章が書かれたマス目と、0または1の数字が描かれたマス目がアダマール積でかけられ、0の部分の文章が消え、1の部分だけ文章がハイライトされたマス目が出力されたイメージ図

7×8の文脈ベクトル(実際は数字が格納されたベクトルですが、ここではイメージしやすいように文章を書いています)に対して、7×8のゲートベクトル(実際は0~1の範囲の実数ですが、ここでは分かりやすく極端にゼロイチの場合にしています)アダマール積で作用させるケースを考えます。

要素ごとのかけ算とは、「同じ場所の要素同士でかけ算する」という意味です。

例えば、7×8マスの一番左上に注目すると、文脈ベクトルの左上の要素「吾」とゲートベクトルの左上の要素「1」がかけられて、出力ベクトルの左上の要素は「吾」になります。

一方で、7×8マスの一番右下に注目すると、文脈ベクトルの右下の要素「て」とゲートベクトルの右下の要素「0」がかけられて、出力ベクトルの右下の要素は「0」になります(「て」が消えます)。

文脈ベクトルの上にゲートベクトルを重ねるイメージ(フィルターを被せるイメージ)が分かりやすいです。

このように、アダマール積を作用させると、注目させたい部分にスポットライトを当てることができるので、忘れたい部分を忘れ、覚えたい部分を残す、ということができるわけです。

【入力】新しく加えたい情報を“覚える”

左ページに複数の数式と一部ハイライトされた数式、右ページにLSTMの回路図全体と一部ハイライトされた回路が描かれた見開きノート

次に「入力」工程です。

まずは、xt(その時刻の入力)とht-1(前時刻の隠れ状態)を基にして、tanh関数を通して新情報を作り出します。

Cの上にニョロっとしたものが付いていますが、これが新情報を表すベクトルの一般的な表記の仕方になります。
(よく「仮の文脈ベクトル(context vector)」などと呼ばれます)

σ関数とtanh関数を使い分けている理由が分からない、という方が居ますが、目的を考えると納得できるはずです。

σ関数は、値が0~1ですよね。
ゲートの目的は閉じるか開くか(0か1か)を制御することなので、σ関数がピッタリというわけです。

一方で今回のC言葉の意味を表すことが目的です。

言語を数値化する場合、ゼロイチよりかはプラマイの方が都合が良いことが多々あります。
(極端な例を言えば、「熱い」と「寒い」というような対義語は、-1と+1の関係ですよね)

よって、σ関数よりもtanh関数の方がマッチしているというわけです。

Ctをそのまま記憶セルにぶち込む訳ではありません。

その前に、さきほど準備した入力ゲートに通す、つまり、入力ゲートItの値をCtにかけます↓

左ページに複数の数式と一部ハイライトされた数式、右ページにLSTMの回路図全体と一部ハイライトされた回路が描かれた見開きノート

Ctの内のどれくらい覚えさせるか厳選するわけです。

こうして厳選された新情報が、ベルトコンベアを流れてきた情報の器に加えられ、長期記憶(long-term memory)が完成します↓

左ページに複数の数式と一部ハイライトされた数式、右ページにLSTMの回路図全体と一部ハイライトされた回路が描かれた見開きノート

完成した長期記憶Ctは、次の時刻へと流れていく、という繰り返しです。

【出力】出力したい情報を“抽出する”

左ページに複数の数式と一部ハイライトされた数式、右ページにLSTMの回路図全体と一部ハイライトされた回路が描かれた見開きノート

長期記憶に関しては終わりましたが、最後に出力ytと隠れ状態htを作る必要があります。

LSTMの場合、ythtは同じ成分です。

これらは、時刻tというその一瞬における出力を意味するので、短期記憶的な存在と言えます。

どうやって作られるかというと、さっき触れたように、長期記憶の情報から一部抜き出すという方法をとります。

記憶セルのベルトコンベアを流れてきた長期記憶Ctを、tanh関数で整えたのち、出力ゲートOtに通し、出力したい部分=短期記憶として残したい部分を抽出するわけです。

以上がLSTMの全てです。

一つ一つ分解して考えると、はじめは複雑に見えた数式や入り組んだ回路図も、ハッキリとその意味を理解できたのではないでしょうか。

LSTMに限らず、一気にぜんぶ見ようとするのではなく、細分化してみると意外とすんなり頭に入ってくるのでおすすめです。

ポイントまとめ:LSTMは記憶セルの導入とゲート制御によって長期記憶・短期記憶形成を可能に!

左ページにLSTMにおける長期記憶形成に関わる部分がハイライトされた回路図、右ページにLSTMにおける短期記憶形成に関わる部分がハイライトされた回路図が描かれた見開きノート

本記事の内容をまとめます。

LSTMとは、

  • うなぎのタレのように過去からずっと繋がった記憶セルを導入した。
  • 忘却ゲート&入力ゲートで忘れる&覚えるを制御し、記憶セルに長期記憶の機能を持たせた。
  • 出力ゲートで長期記憶の中から重要部分を取り出し、短期記憶も可能にした。

という技術になります。

あなたのAIの勉強に少しでも役立っていれば嬉しいです。

もし余裕があれば、いっしょにLSTMの類似技術であるGRU(Gated Recurrent Unit)についても理解しちゃってください!

脱初心者を目指して、これからも勉強がんばってください!