通信理論に特化した深層学習 第6回ゼミ資料
誤差逆伝播法
豊橋技術科学大学 電気・電子情報工学系
准教授 竹内啓悟
勾配の計算
損失関数を局所最小化するためには、順伝播型ネットワークに含まれるすべ てのパラメータに関する損失関数の偏微分を効率的に計算する必要がある。
「深層」合成関数のパラメータに関する偏微分を効率的に計算する問題
「深層」合成関数
𝒈𝒈 𝑡𝑡 ⋅, 𝜽𝜽 𝑡𝑡 : ℝ 𝑁𝑁𝑡𝑡−1 × ℝ 𝑁𝑁𝑡𝑡をパラメータ𝜽𝜽 𝑡𝑡を持つ微分可能なベクトル値関数とし、
𝜽𝜽 𝑡𝑡を持つ微分可能なベクトル値関数とし、
𝑮𝑮 𝑡𝑡 𝑡𝑡12 𝒛𝒛, 𝚯𝚯 𝑡𝑡 𝑡𝑡12 = 𝒈𝒈 𝑡𝑡2 ∘ 𝒈𝒈 𝑡𝑡2−1 ∘ ⋯ ∘ 𝒈𝒈 𝑡𝑡
1 (𝒛𝒛, 𝜽𝜽 𝑡𝑡1), 𝚯𝚯 𝑡𝑡 𝑡𝑡12 = 𝜽𝜽 𝑡𝑡1, … , 𝜽𝜽 𝑡𝑡2 .
注意
= 𝒈𝒈 𝑡𝑡2 ∘ 𝒈𝒈 𝑡𝑡2−1 ∘ ⋯ ∘ 𝒈𝒈 𝑡𝑡
1 (𝒛𝒛, 𝜽𝜽 𝑡𝑡1), 𝚯𝚯 𝑡𝑡 𝑡𝑡12 = 𝜽𝜽 𝑡𝑡1, … , 𝜽𝜽 𝑡𝑡2 .
注意
−1 ∘ ⋯ ∘ 𝒈𝒈 𝑡𝑡
1(𝒛𝒛, 𝜽𝜽 𝑡𝑡1), 𝚯𝚯 𝑡𝑡 𝑡𝑡12 = 𝜽𝜽 𝑡𝑡1, … , 𝜽𝜽 𝑡𝑡2 .
注意
= 𝜽𝜽 𝑡𝑡1, … , 𝜽𝜽 𝑡𝑡2 .
注意
.
注意𝑁𝑁 𝑇𝑇 = 1
として、スカラー関数𝑔𝑔 𝑇𝑇 ⋅, 𝜽𝜽 𝑇𝑇 : ℝ 𝑁𝑁𝑇𝑇−1 → ℝ
に損失関数を含める。
偏微分の計算 偏微分の連鎖測
𝜕𝜕
𝜕𝜕𝜃𝜃 𝑖𝑖 𝑡𝑡 𝐺𝐺 1 𝑇𝑇 𝒙𝒙, 𝚯𝚯 1 𝑇𝑇 = �
𝑛𝑛=1 𝑁𝑁
𝑡𝑡𝜕𝜕 𝑛𝑛 𝐺𝐺 𝑡𝑡+1 𝑇𝑇 𝒛𝒛 𝑡𝑡 , 𝚯𝚯 𝑡𝑡+1 𝑇𝑇 𝜕𝜕𝑔𝑔 𝑛𝑛 𝑡𝑡
𝜕𝜕𝜃𝜃 𝑖𝑖 𝑡𝑡 𝒛𝒛 𝑡𝑡−1 , 𝜽𝜽 𝑡𝑡 , 𝑔𝑔 𝑛𝑛 𝑡𝑡 = 𝒈𝒈 𝑡𝑡 𝑛𝑛 . 𝐺𝐺 1 𝑇𝑇 𝒙𝒙, 𝚯𝚯 1 𝑇𝑇 = 𝐺𝐺 𝑡𝑡+1 𝑇𝑇 𝒈𝒈 𝑡𝑡 𝒛𝒛 𝑡𝑡−1 , 𝜽𝜽 𝑡𝑡 , 𝚯𝚯 𝑡𝑡+1 𝑇𝑇 , 𝒛𝒛 𝑡𝑡 = 𝑮𝑮 1 𝑡𝑡 𝒙𝒙, 𝚯𝚯 1 𝑡𝑡 .
上記の表現の両辺を
𝜃𝜃 𝑖𝑖 𝑡𝑡 = 𝜽𝜽 𝑡𝑡 𝑖𝑖に関して偏微分すると、
後者の因子は順伝播時に計算可能
𝐺𝐺 𝑡𝑡+1 𝑇𝑇 ⋅, 𝚯𝚯 𝑡𝑡+1 𝑇𝑇 𝒈𝒈 𝑡𝑡 (⋅, 𝜽𝜽 𝑡𝑡 )
𝑮𝑮 1 𝑡𝑡−1 ⋅, 𝚯𝚯 1 𝑡𝑡−1
𝒙𝒙 𝒛𝒛 𝑡𝑡−1 𝒛𝒛 𝑡𝑡
𝑡𝑡
番目の関数のパラメータ𝜽𝜽 𝑡𝑡に注目する。
𝐺𝐺 𝑡𝑡 𝑇𝑇 𝒛𝒛, 𝚯𝚯 𝑡𝑡 𝑇𝑇 = 𝐺𝐺 𝑡𝑡+1 𝑇𝑇 𝒈𝒈 𝑡𝑡 𝒛𝒛, 𝜽𝜽 𝑡𝑡 , 𝚯𝚯 𝑡𝑡+1 𝑇𝑇
𝜕𝜕 𝑛𝑛 𝐺𝐺 𝑡𝑡 𝑇𝑇 𝒛𝒛 𝑡𝑡−1 , 𝚯𝚯 𝑡𝑡 𝑇𝑇 = �
𝑁𝑁
𝑡𝑡𝜕𝜕 𝑛𝑛 𝑔𝑔 𝑡𝑡′ 𝒛𝒛 𝑡𝑡−1 , 𝜽𝜽 𝑡𝑡 𝜕𝜕 𝑛𝑛′𝐺𝐺 𝑡𝑡+1 𝑇𝑇 𝒛𝒛 𝑡𝑡 , 𝚯𝚯 𝑡𝑡+1 𝑇𝑇
逆伝播
𝐺𝐺 𝑡𝑡+1 𝑇𝑇 𝒛𝒛 𝑡𝑡 , 𝚯𝚯 𝑡𝑡+1 𝑇𝑇 逆伝播
𝐺𝐺 𝑡𝑡+1 𝑇𝑇 ⋅, 𝚯𝚯 𝑡𝑡+1 𝑇𝑇 𝐺𝐺 𝑡𝑡 𝑇𝑇 ⋅, 𝚯𝚯 𝑡𝑡 𝑇𝑇
𝒈𝒈 𝑡𝑡 (⋅, 𝜽𝜽 𝑡𝑡 ) 𝒛𝒛 𝑡𝑡−1
上記の表現の両辺を
𝑧𝑧 𝑛𝑛 = 𝒛𝒛 𝑛𝑛に関して点𝒛𝒛 = 𝒛𝒛 𝑡𝑡−1で偏微分すると、
𝒛𝒛 𝑡𝑡 表記法
関数
𝑓𝑓
の𝑛𝑛
番目の変数に関する偏微分を𝜕𝜕 𝑛𝑛 𝑓𝑓
と書く。誤差逆伝播法(
Back propagation
)𝐺𝐺 𝑡𝑡+1 𝑇𝑇 ⋅, 𝚯𝚯 𝑡𝑡+1 𝑇𝑇 𝒈𝒈 𝑡𝑡 ⋅, 𝜽𝜽 𝑡𝑡
𝑮𝑮 1 𝑡𝑡−1 ⋅, 𝚯𝚯 1 𝑡𝑡−1
𝒙𝒙 𝒛𝒛 𝑡𝑡−1 𝒛𝒛 𝑡𝑡
𝜕𝜕 𝑛𝑛 𝐺𝐺 𝑡𝑡+1 𝑇𝑇 𝒛𝒛 𝑡𝑡 , 𝚯𝚯 𝑡𝑡+1 𝑇𝑇 {𝜕𝜕 𝑛𝑛 𝐺𝐺 𝑡𝑡 𝑇𝑇 𝒛𝒛 𝑡𝑡−1 , 𝚯𝚯 𝑡𝑡 𝑇𝑇 }
逆伝播により、出力側から勾配を順に計算する。
𝜕𝜕
𝜕𝜕𝜃𝜃 𝑖𝑖 𝑡𝑡 𝐺𝐺 1 𝑇𝑇 𝒙𝒙, 𝚯𝚯 1 𝑇𝑇 手順1
順伝播により、
𝒛𝒛 𝑡𝑡 , 𝜕𝜕𝒈𝒈 𝑡𝑡 𝒛𝒛 𝑡𝑡−1 , 𝜽𝜽 𝑡𝑡 /𝜕𝜕𝜃𝜃 𝑖𝑖 𝑡𝑡 , 𝜕𝜕 𝑛𝑛 𝒈𝒈 𝑡𝑡 𝒛𝒛 𝑡𝑡−1 , 𝜽𝜽 𝑡𝑡 を計算する。
手順2
計算グラフ
計算結果が誤差逆伝播法による計算結果と等価になるように、
ノード間でやり取りされるメッセージを定義する。
𝑔𝑔 𝑛𝑛 𝑡𝑡−1 𝑔𝑔 𝑛𝑛 𝑡𝑡′
𝑔𝑔 𝑡𝑡+1 𝑔𝑔 1 𝑡𝑡+1
𝛿𝛿 𝑛𝑛←𝑛𝑛 𝑡𝑡 ′
𝛿𝛿 𝑛𝑛 𝑡𝑡+1′←1
𝛿𝛿 𝑛𝑛 𝑡𝑡+1′←𝑁𝑁
𝑡𝑡+1
𝛿𝛿 𝑡𝑡 = 𝜕𝜕 𝑔𝑔 𝑡𝑡 (𝒛𝒛 𝑡𝑡−1 , 𝜽𝜽 𝑡𝑡 ) �
𝑁𝑁
𝑡𝑡+1𝛿𝛿 𝑡𝑡+1 𝑔𝑔 𝑡𝑡−1
𝑔𝑔 1 𝑡𝑡−1
𝑧𝑧 1 𝑡𝑡−1
𝑧𝑧 𝑁𝑁 𝑡𝑡−1𝑡𝑡−1
関数
𝑔𝑔 𝑛𝑛 𝑡𝑡を𝑡𝑡
層𝑛𝑛
番目のノードに対応付ける。
𝛿𝛿 𝑛𝑛←𝑛𝑛 𝑡𝑡 ′:𝑡𝑡
層𝑛𝑛 ′番目のノードから𝑡𝑡 − 1
層𝑛𝑛
番目のノードに送られるメッセージ
𝑡𝑡 − 1
層𝑛𝑛
番目のノードに送られるメッセージ・・ ・
・・ ・ ・・ ・
計算結果の等価性
𝛿𝛿 𝑛𝑛←𝑛𝑛 𝑡𝑡 ′ = 𝜕𝜕 𝑛𝑛 𝑔𝑔 𝑛𝑛 𝑡𝑡′ 𝒛𝒛 𝑡𝑡−1 , 𝜽𝜽 𝑡𝑡 𝜕𝜕 𝑛𝑛′𝐺𝐺 𝑡𝑡+1 𝑇𝑇 𝒛𝒛 𝑡𝑡 , 𝚯𝚯 𝑡𝑡+1 𝑇𝑇 .
誤差逆伝播法
𝒛𝒛 𝑡𝑡−1 , 𝜽𝜽 𝑡𝑡 𝜕𝜕 𝑛𝑛′𝐺𝐺 𝑡𝑡+1 𝑇𝑇 𝒛𝒛 𝑡𝑡 , 𝚯𝚯 𝑡𝑡+1 𝑇𝑇 .
誤差逆伝播法
𝜕𝜕 𝑛𝑛 𝐺𝐺 𝑡𝑡 𝑇𝑇 𝒛𝒛 𝑡𝑡−1 , 𝚯𝚯 𝑡𝑡 𝑇𝑇 = �
𝑛𝑛
′=1 𝑁𝑁
𝑡𝑡𝜕𝜕 𝑛𝑛 𝑔𝑔 𝑛𝑛 𝑡𝑡′ 𝒛𝒛 𝑡𝑡−1 , 𝜽𝜽 𝑡𝑡 𝜕𝜕 𝑛𝑛′𝐺𝐺 𝑡𝑡+1 𝑇𝑇 𝒛𝒛 𝑡𝑡 , 𝚯𝚯 𝑡𝑡+1 𝑇𝑇 .
計算グラフ
𝐺𝐺 𝑡𝑡+1 𝑇𝑇 𝒛𝒛 𝑡𝑡 , 𝚯𝚯 𝑡𝑡+1 𝑇𝑇 .
計算グラフ𝛿𝛿 𝑛𝑛←𝑛𝑛 𝑡𝑡−1′ = 𝜕𝜕 𝑛𝑛 𝑔𝑔 𝑛𝑛 𝑡𝑡−1′ (𝒛𝒛 𝑡𝑡−2 , 𝜽𝜽 𝑡𝑡−1 ) �
(𝒛𝒛 𝑡𝑡−2 , 𝜽𝜽 𝑡𝑡−1 ) �
𝑛𝑛
′′=1 𝑁𝑁
𝑡𝑡𝛿𝛿 𝑛𝑛 𝑡𝑡′←𝑛𝑛
′′
= 𝜕𝜕 𝑛𝑛 𝑔𝑔 𝑛𝑛 𝑡𝑡−1′ 𝒛𝒛 𝑡𝑡−2 , 𝜽𝜽 𝑡𝑡−1 𝜕𝜕 𝑛𝑛′𝐺𝐺 𝑡𝑡 𝑇𝑇 𝒛𝒛 𝑡𝑡−1 , 𝚯𝚯 𝑡𝑡 𝑇𝑇 .
帰納法による証明のスケッチ
𝐺𝐺 𝑡𝑡 𝑇𝑇 𝒛𝒛 𝑡𝑡−1 , 𝚯𝚯 𝑡𝑡 𝑇𝑇 .
帰納法による証明のスケッチ∎
三つの添え字を持つメッセージ
𝛿𝛿 𝑛𝑛←𝑛𝑛 𝑡𝑡 ′(3階のテンソル)が計算グラフ上を流れる。
TensorFlow
の名前の由来最後の等号は、帰納法の仮定と逆伝播の定義式から従う。
計算グラフによる勾配計算
𝑔𝑔 𝑛𝑛 𝑡𝑡 ・・ ・
𝜃𝜃 𝑖𝑖 𝑡𝑡
𝑔𝑔 1 𝑡𝑡+1
𝑔𝑔 𝑁𝑁 𝑡𝑡+1𝑡𝑡+1
入力メッセージの和を取り、宛先ノードに関する偏微分をかける。
𝛿𝛿 𝑛𝑛←1 𝑡𝑡+1
𝛿𝛿 𝑛𝑛←𝑁𝑁 𝑡𝑡+1 𝑡𝑡+1
𝜕𝜕𝑔𝑔 𝑛𝑛 𝑡𝑡
𝜕𝜕𝜃𝜃 𝑖𝑖 𝑡𝑡 𝜕𝜕 𝑛𝑛 𝐺𝐺 𝑡𝑡+1 𝑇𝑇
𝜃𝜃 𝑖𝑖 𝑡𝑡
𝑔𝑔 1 𝑡𝑡
𝑔𝑔 𝑁𝑁 𝑡𝑡𝑡𝑡
・・ ・
𝜕𝜕𝑔𝑔 1 𝑡𝑡
𝜕𝜕𝜃𝜃 𝑖𝑖 𝑡𝑡 𝜕𝜕 1 𝐺𝐺 𝑡𝑡+1 𝑇𝑇
𝜕𝜕𝑔𝑔 𝑁𝑁 𝑡𝑡𝑡𝑡
𝜕𝜕𝜃𝜃 𝑖𝑖 𝑡𝑡 𝜕𝜕 𝑁𝑁𝑡𝑡𝐺𝐺 𝑡𝑡+1 𝑇𝑇
入力メッセージの和を取る。