• 検索結果がありません。

Truncated Backpropagation Through Time

6. 深層学習 [6][7]

6.3. Truncated Backpropagation Through Time

①,②を考慮した学習方法を構築するには,一般的なRNNの学習方法とは明らかに異な る。その決定的な違いは時系列データと教師データの比率が1:1対応していない点にある。

本来,RNNの学習にはある時刻tに対して学習データと教師データの両方が存在し,比率 は1:1にする。しかし,我々が予測しようとしているキュウリの収穫量は何日分(1日で 96データ)もの学習データに対して教師データが1データとなっているため,その比率は n:1である。これを考慮して学習する方法として,Truncated Backpropagation Through

Time(以降,Truncated BPTTと略す)という技術の応用を提案する。そもそも,Truncated

BPTT とはどのような技術であるのか示した上でどのような応用をしていくのか本節では 示していく。

6.3.1. RNN とは

Truncated BPTT について説明するためには RNN の基本構造を知っている必要がある

ため,RNNから順に説明を行っていく。

RNNの特徴はループする経路(閉じた経路)を持つことである。このループする経路を 持つことによって,データは絶えず循環することができる。そしてデータが循環することに より,過去の情報を記憶しながら最新のデータへと更新される。RNNで用いられるレイヤ を「RNN レイヤ」という名前で呼ぶことにすると,RNN レイヤは図6.2 のように書くこ とができる。図6.2に示すとおり,RNNレイヤはループする経路を持つ。このループする 経路によって,データがレイヤ内を循環することができるようになる。図6.2では時刻をt として,𝑥𝑡を入力としている。これは時系列データとしてレイヤに入力されることを示して いる。そしてその入力に対応する形でℎ𝑡が出力される。

61

図6.2を見て分かる通り,各時刻のRNNレイヤはそのレイヤへの入力と1つ前のRNN レイヤからの出力を受け取る。そして,その 2 つの情報を元にその時刻の出力が計算され る。このとき行う計算は以下に示す(6.1)式であり,模式的に図で表したものが図 6.3 であ る。

𝑡= tanh(ℎ𝑡−1 + 𝑥𝑡 𝑥+ 𝑏) (6.1) RNNでは重みが2つあり,1つは入力𝑥を出力ℎに変換するための重み 𝑥,もう1つは1 つ前のRNNの出力を次時刻の出力に変換するための重み である。また,バイアス𝑏があ る。

6.3.2. Backpropagation Through Time とは

図6.4に示す通り,ループを展開した後のRNNは誤差逆伝播法を使うことができる。つ まり最初に順伝播を行い,続いて逆伝播を行うことで目的とする勾配を求めることができ る。ここでの誤差逆伝播法は,「時間方向に展開したニューラルネットワークの誤差逆伝播

図6.2:RNNレイヤ

RNN

𝑥𝑡𝑡

RNN

𝑥11

RNN

𝑥22

RNN

𝑥𝑡𝑡

RNN

𝑥00

=

図6.3:RNNレイヤの内部構造

𝑛𝑒𝑥𝑡

𝑛𝑒𝑥𝑡 𝑡𝑎𝑛ℎ

+ 𝑎𝑡 +

𝑏

𝑥

𝑥

𝑝𝑟𝑒

62

よってRNNの学習は行えるように見える。しかし,その前に解決しなければならない問題 がある。それは,長い時系列データを学習する場合である。なぜそれが問題であるのかは時 系列データの時間サイズが大きくなるに比例して,BPTT で消費するコンピュータの計算 リソースも増加することになるからである。また,時間サイズが長くなると逆伝播時の勾配 が不安定になることも問題である。これを解決する為にTruncated BPTTという技術があ る。

6.3.3. Truncated BPTT とは

大きな時系列データを扱うときに通常用いられるのが,ネットワークのつながりを適当 な長さで“断ち切る”ことである。これは,時間軸方向に長くなりすぎたネットワークを適 当な長さに切り取ることで,小さなネットワークを複数作るというアイディアで,それらひ とつひとつに対して誤差逆伝播法を行う。これがTruncated BPTTと呼ばれる手法である。

Truncated BPTT ではネットワークのつながりを断ち切るが,正しくはネットワークの

「逆伝播」のつながりだけを断ち切るということである。すなわち,順伝播のつながりは維 持されたままになるのである。一方,逆伝播のつながりは適当な長さで切り取り,その切り 取られたネットワーク単位で学習を行う。

ここからは具体的に例を挙げて説明していく。例えば,1000個の時系列データがあった とする。この時系列データを扱うときRNNレイヤを展開すると,横方向に1000個のレイ ヤが並んだネットワークになる。もちろん,どれだけレイヤが並んだとしても誤差逆伝播法 によって勾配を計算することは可能である。しかしそれがあまりにも長いと,計算量やメモ リの使用量などの点で問題になる。またレイヤが長くなるに従い,勾配が徐々に小さくなる ことがあり,勾配が前時刻へと届かなくなる。そこで図6.5に示すように横方向に長く伸び たネットワークの逆伝播のつながりを適当な長さに断ち切ることを考える。このように逆 伝播のつながりを切ってしまえば,それより未来のデータについて考える必要がなくなり,

ブロック単位で誤差逆伝播法を完結することができるようになる。

図6.4:RNNレイヤに対する誤差逆伝播法

RNN

𝑥

1

1

RNN

𝑥

2

2

RNN

𝑥

𝑡

𝑡

RNN

𝑥

0

0

63

関連したドキュメント