• 検索結果がありません。

EM アルゴリズム

ドキュメント内 uda2008/main.tex 2008/05/ (ページ 117-126)

> ii2x3 <- apply(pr2x3,1,function(p) order(-p)[1])

> ii2x3[1:30] # 推定された信号源(最初の30個)

[1] 3 2 3 2 1 3 1 2 1 3 1 3 1 1 1 2 2 2 2 2 1 2 2 3 1 1 2 2 2 3

> sum(ii2x3 == ii2)/length(ii2) # 正解率 [1] 0.9033333

> ### 誤差の比較(小さいほど良い)

> sum((c(pr1,mu1,ss1)-c(pr,mu,ss))^2) # n1 個の (x,i) から推定 [1] 0.5636234

> sum((c(pr2,mu2,ss2)-c(pr,mu,ss))^2) # n2 個の x から推定 [1] 0.6626585

> sum((c(pr3,mu3,ss3)-c(pr,mu,ss))^2) # n1 個の (x,i) n2 個の x から推定 [1] 0.1497713

[

注意

]

例 2.8と例 2.9では,対数尤度関数を数値的に最大化した.この数値的最適化には R の組み込み関数

optim を用いた.これに実装されているアルゴリズムは汎用で十分性能の高いものであるけれど,パラメタ

の次元が大きくなってくると不安定になりやすい.実際,例 2.8や例 2.9でも,データや初期値によっては不 適解が得られる.混合分布のモデルの最尤推定では,もっと実装が簡単で安定な方法が知られている.これは

EM

アルゴリズムと呼ばれる.

Histogram of xx2

xx2

Density

−5 0 5

0.000.050.100.15

Histogram of xx2

xx2

Density

−5 0 5

0.000.050.100.15

23 n2 個の xt だけから推定した密度関数(青).n1 個の (xt, it) n2 個の xt から推定した密度関数(オレンジ).

への更新ルールを次で定める.以下,

n = n

1

+ n

2 とおく.

ˆ E ステップ (Expectation step).とりあえず θ(r) を信用して,

t = 1, . . . , n

1

+ n

2 について

i

の事後 確率

p(i

t|

x

t

) = π

i

(x

t

;

θ(r)

)

を計算する.

π

i

(x

t

;

θ(r)

) = I (i = i

t

), t = 1, . . . , n

1X1 では

i

t を観測している)

π

i

(x

t

;

θ(r)

) = f

i

(x

t

;

θ(r)i

i(r)

f (x

t

;

θ(r)

) , t = n

1

+ 1, . . . , n

ˆ M ステップ (Maximization step)

π

1

, . . . , π

k を事後確率の平均値で推定する

(

ただし P

i=1

π

i

= 1)

π

i(r+1)

=

n1

Xn t=1

π

i

(x

t

;

θ(r)

)

各サンプル

x

t の「重み」を次式で計算する.

w

t

= π

i

(x

t

;

θ(r)

)

Pn

t0=1

π

i

(x

t0

;

θ(r)

) , t = 1, . . . , n

そして,各成分

i = 1, . . . , k

で場合分けして重みつきの最尤推定を行う.正規混合分布では次式で推定 する.

µ

(r+1)i

=

Xn

t=1

w

t

x

t

, t = 1, . . . , n

σ

i2 (r+1)

=

Xn

t=1

w

t

(x

t

µ

(r+1)i

)

2

, t = 1, . . . , n

上記の

E

ステップと

M

ステップを反復するだけである.計算はきわめて単純であるが,以下の実行結果を見

ると,optimによって得られた結果と等価なものが得られている.また,反復のたびに対数尤度がかならず増

加している.なお,反復計算そのものには,対数尤度の

log L(θ

(r)|X1

,

X2

)

を計算する必要さえない.

> ## EM アルゴリズムによる正規混合分布の最尤推定

> ## データ: xx1, ii1, xx2

> ## 成分数: k

> ## 推定結果は,pr4,mu4,ss4 に保存

> pr4 <- c(1/3,1/3,1/3); mu4 <- c(0,2,-2); ss4 <- c(1,1,1) # 初期値

> nr <- 30 # とりあえず反復回数をきめておく

> mystat <- function(pr,mu,ss) { # 反復の途中経過を表示する関数を準備しておく + lik <- mylik3(c(pr[-k],mu,ss)) # 目的関数

+ cat(format(lik,digits=10),round(c(pr,mu,ss),3),"\n") # 一行でサマリを表示 + c(lik,pr,mu,ss)

+ }

> stat3 <- mystat(pr3,mu3,ss3) # 参考のため,optim mylik3 を最適化した結果を表示しておく 998.7260493 0.487 0.308 0.205 -0.086 4.056 -3.12 1.124 3.676 1.066

1372.011849 0.333 0.333 0.333 0 2 -2 1 1 1

> xx <- c(xx1,xx2) # 推定に利用するデータを結合しておく

> pp1 <- matrix(0,n1,k); pp1[seq(n1)+(ii1-1)*n1] <- 1 # xx1における「事後確率」

> t(pp1[1:10,]) # 最初の10個のメール

[,1] [,2] [,3] [,4] [,5] [,6] [,7] [,8] [,9] [,10]

[1,] 1 1 0 1 1 1 0 1 0 1

[2,] 0 0 1 0 0 0 0 0 1 0

[3,] 0 0 0 0 0 0 1 0 0 0

> ## EM アルゴリズムの反復計算はここから

> for(r in 1:nr) { # 反復計算:本当は収束判定して break するべき.

+ a <- drawnormmix(xx2,k,pr4,mu4,ss4,FALSE); pp2 <- a$fi/a$f # xx2 における事後確率 + pp <- rbind(pp1,pp2) # xx1 xx2の事後確率を結合しておく

+ pr4 <- apply(pp,2,sum)/(n1+n2) # pr の推定

+ # wt <- pp/rep(apply(pp,2,sum),rep(n1+n2,k)) #「重み」を事後確率で定義

+ wt <- sweep(pp,2,apply(pp,2,sum),"/") # 上記と同じだけどsweep 関数を使ってみた + mu4 <- apply(xx*wt,2,sum) # mu の推定

+ # ss4 <- apply((xx-rep(mu4,rep(n1+n2,k)))^2*wt,2,sum) # ss の推定

+ ss4 <- apply(sweep(matrix(xx,n1+n2,k),2,mu4,"-")^2*wt,2,sum) # 上記と同じ + cat(r,": ")

+ stat4[r+1,] <- mystat(pr4,mu4,ss4) # 目的関数を保存しておく(反復計算には必要ない)

+ }

1 : 1009.451812 0.352 0.364 0.283 -0.056 3.607 -2.558 0.883 4.279 1.776 2 : 1004.582255 0.399 0.343 0.257 -0.034 3.733 -2.706 0.916 4.368 1.691 3 : 1002.053705 0.425 0.332 0.243 -0.034 3.825 -2.819 0.955 4.215 1.51 4 : 1000.541512 0.442 0.325 0.233 -0.04 3.892 -2.905 0.985 4.062 1.363 5 : 999.6855876 0.454 0.32 0.226 -0.047 3.94 -2.968 1.009 3.945 1.26 6 : 999.2271227 0.463 0.317 0.22 -0.055 3.975 -3.013 1.03 3.861 1.191 7 : 998.9887642 0.47 0.314 0.216 -0.062 3.999 -3.044 1.048 3.803 1.147 8 : 998.865351 0.474 0.313 0.213 -0.068 4.015 -3.066 1.064 3.763 1.119

9 : 998.8007836 0.478 0.311 0.211 -0.072 4.027 -3.081 1.077 3.737 1.101 10 : 998.766521 0.48 0.311 0.209 -0.076 4.035 -3.092 1.088 3.718 1.09 11 : 998.748113 0.482 0.31 0.208 -0.078 4.041 -3.1 1.097 3.706 1.083 12 : 998.738132 0.483 0.31 0.207 -0.08 4.045 -3.105 1.103 3.697 1.078 13 : 998.7326863 0.484 0.309 0.207 -0.082 4.048 -3.109 1.109 3.691 1.074 14 : 998.7297025 0.485 0.309 0.206 -0.083 4.05 -3.112 1.112 3.687 1.072 15 : 998.728063 0.486 0.309 0.206 -0.084 4.052 -3.114 1.115 3.684 1.07 16 : 998.7271603 0.486 0.309 0.205 -0.084 4.053 -3.115 1.118 3.682 1.069 17 : 998.7266627 0.486 0.309 0.205 -0.085 4.054 -3.116 1.119 3.68 1.068 18 : 998.7263881 0.486 0.308 0.205 -0.085 4.055 -3.117 1.121 3.679 1.068 19 : 998.7262365 0.487 0.308 0.205 -0.085 4.055 -3.118 1.121 3.678 1.067 20 : 998.7261528 0.487 0.308 0.205 -0.085 4.055 -3.118 1.122 3.678 1.067 21 : 998.7261065 0.487 0.308 0.205 -0.086 4.056 -3.119 1.123 3.677 1.066 22 : 998.726081 0.487 0.308 0.205 -0.086 4.056 -3.119 1.123 3.677 1.066 23 : 998.7260668 0.487 0.308 0.205 -0.086 4.056 -3.119 1.123 3.677 1.066 24 : 998.726059 0.487 0.308 0.205 -0.086 4.056 -3.119 1.124 3.677 1.066 25 : 998.7260546 0.487 0.308 0.205 -0.086 4.056 -3.119 1.124 3.677 1.066 26 : 998.7260522 0.487 0.308 0.205 -0.086 4.056 -3.12 1.124 3.676 1.066 27 : 998.726051 0.487 0.308 0.205 -0.086 4.056 -3.12 1.124 3.676 1.066 28 : 998.7260502 0.487 0.308 0.205 -0.086 4.056 -3.12 1.124 3.676 1.066 29 : 998.7260498 0.487 0.308 0.205 -0.086 4.056 -3.12 1.124 3.676 1.066 30 : 998.7260496 0.487 0.308 0.205 -0.086 4.056 -3.12 1.124 3.676 1.066

> pr4 # 推定した pr

[1] 0.4870391 0.3082993 0.2046617

> mu4 # 推定した mu

[1] -0.08593303 4.05631674 -3.11967172

> sqrt(ss4) # 推定した sqrt(ss) [1] 1.060217 1.917354 1.032377

> plot(0:nr,stat4[,1],type="b",xlab="iteration",ylab="lik") # 目的関数のグラフ

> abline(h=stat3[1],lty=2,col="pink") # optim の結果を赤線で表示

> matplot(0:nr,stat4[,-1],type="b",xlab="iteration",ylab="parameters") # パラメタ推定値の変化

> abline(h=stat3[-1],lty=2,col="pink") # optim の結果を赤線で表示

[

定義

2.7]

観測したデータを X,未観測データを Y で表す(対応する確率変数も同じ文字で書く).課題 2.9では,X

=

{X1

,

X2}

=

{

x

1

, . . . , x

n1+n2

, i

1

, . . . , n

n1}Y

=

{

i

n1+1

, . . . , i

n1+n2} である.

(

X

,

Y

)

の同時 分布を表す尤度を

L(θ

|X

,

Y

)

X の周辺分布を表す尤度を

L(θ

|X

) =

R

L(θ

|X

,

Y

) d

Y で表す.ただし積分Y の取りうるすべての値について取る(課題 2.9では離散分布なので積分は和になる).

L(θ

|X

)

を最大 化するための

EM

アルゴリズムでは,

r = 0

を初期値として,現在のパラメタ値 θ(r) から次のパラメタ値

θ(r+1) への更新ルールを次で定める.

ˆ E ステップ (Expectation step).とりあえず θ(r) を信用して,X を与えたときの Y の条件付分布

f (

Y|X

;

θ(r)

)

を計算する.

f (

Y|X

;

θ(r)

) = L(θ

(r)|X

,

Y

) L(θ

(r)|X

)

この条件付分布(つまり Y の事後確率)に関して,

log(L(θ

|X

,

Y

))

の期待値を定義する.

Q(θ,

θ(r)

) =

Z

log(L(θ

|X

,

Y

))f (

Y|X

;

θ(r)

) d

Y

0 5 10 15 20 25 30

1000110012001300

iteration

lik 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1

0 5 10 15 20 25 30

−2024

iteration

parameters 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 23 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3

4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 5

5 55 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5

6

6 6 66 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 8

8 8 8 8

8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8

9

9 99 9 9 99 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9

24 ( 左 )logL(θ(r)|X1,X2), r = 0, . . . ,30 の プ ロ ッ ト .実 用 的 に は 最 初 の 10 程 度 の 反 復 で 収 束 し た .optim に よ る 値 は ピ ン ク の 線 で 示 し た .( 右 )パ ラ メ タ θ(r) = (π1(r), π2(r), µ(r)1 , µ(r)2 , µ(r)3 , σ12 (r), σ22 (r), σ32 (r)) のプロット

ˆ M ステップ (Maximization step)

Q(θ,

θ(r)

)

θ の関数とみなし最大化する.

θ(r+1)

= arg max

θ

Q(θ,

θ(r)

)

これを反復すると,

L(θ

(r+1)|X

)

L(θ

(r)|X

)

である.

[

証明

] L(θ

|X

) = L(θ

|X

,

Y

)/f (

Y|X

;

θ) の対数をとると

log L(θ

|X

) = log L(θ

|X

,

Y

)

log f (

Y|X

;

θ)

この両辺を条件付分布

f (

Y|X

;

θ(r)

)

に関して期待値を計算すると左辺はそのままなので,

log L(θ

|X

) = Q(θ,

θ(r)

)

H (θ,

θ(r)

)

ただし

H (θ,

θ(r)

) =

Z

log(f (

Y|X

;

θ))f

(

Y|X

;

θ(r)

) d

Y

とおく.一般に任意の θ

H (θ,

θ(r)

)

H

(r)

,

θ(r)

)

であることに注意すると,

Q(θ,

θ(r)

)

Q(θ

(r)

,

θ(r)

)

となるように θ を選びさえすれば,

log L(θ

|X

)

log L(θ

(r)|X

)

であることが分かる.

[

注意

] (i) EM

アルゴリズムは

log L(θ

|X

)

を増加させるが,その最大値に収束するとは言えない.適当な条

件下で極大値に収束することが示されている.この問題を回避する現実的な方法は,いくつかの初期値から

EM

アルゴリズムを実行して,そのなかで尤度を最大にするものが選ぶ.

(ii)

正規混合モデルでは,もし

µ

1

をデータのどれかの点

x

t に一致させると

log L(θ

|X

)

は無限大になる.つまり最大化は意味をなさず,極大解 の中で尤度を最大にするものを求めることが目的になる.

[

課題

2.10]

密度関数

f (x)

g(x)

−∞

< x <

f (x) > 0, g(x) > 0

とする.このとき Z

−∞

log(g(x))f (x) dx

Z

−∞

log(f (x))f (x) dx

を示せ.

[

課題

2.11]

定義 2.7の

EM

アルゴリズムを正規混合モデルに適用すると例 2.10のアルゴリズムが得られる

ことを示せ.

ドキュメント内 uda2008/main.tex 2008/05/ (ページ 117-126)

関連したドキュメント