混合分布をBUGSで推定する

  • ここ数日の記事はBUGS
  • 混合分布をやってみる(参考はこちら) Mixture distribution model
  • 今、全部でN人がエントリーしている100m走大会があるとする N runnners registered for a 100m race
  • このうち、pi[1]の割合で招待選手(群 g=1)、pi[2]の割合で一般参加選手(群 g=2)であるとする N runners are from two groups; g1 is a group of invited and g2 is a group of general population
  • 招待選手のタイムは平均mm[1],標準偏差ss[1]、一般参加選手はmm[2],ss[2]の正規分布に従うとする Assume 100m-race records of g1 and g2 are in normal distribution with mm and ss.
  • ただし、この招待選手というのは、100m走の速さを基準に招待されたわけではなく、マラソンが得意であることで招待されたという。The g1 is invited not for their good performance in 100m race but in 42.195 Km race.
  • 招待選手はアスリートなので、一般参加選手よりも100m走のパフォーマンスがよいようにも思えるし、大差ないかもしれないし、かえって遅いかもしれない You may believe g1 should be faster in 100m race than g2; may not.
  • pi[1],pi[2]の事前確率は0-1の範囲で全く不明だという You don't know the fraction of g1 in the registrants.
  • データを作成してみる
    • 招待選手の割合が20%で、全1000人の大会。招待選手の方が結構速い Generate an data set with g1 being 20% who is faster than g2.

N <- 1000
pi <- c(0.2,0.8)
n1 <- N*pi[1]
n2 <- N-n1

mm <- c(12,16)
ss <- c(1,2)
# 100m走のタイム
y1 <- rnorm(n1,mm[1],ss[1])
y2 <- rnorm(n2,mm[2],ss[2])
y12 <- c(y1,y2)
# グループ
g12 <- c(rep(1,n1),rep(2,n2))
# 選手番号をランダムにつける
ord <- sample(1:N)
# 選手番号順に並べ替える。これが観察データ
y <- y12[ord]
g <- g12[ord]

h <- hist(y,plot=FALSE)
hist(y1,breaks=h$breaks,density=20,col=2,ylim=c(0,max(h$counts)))
hist(y2,breaks=h$breaks,density=17,col=3,add=TRUE)
  • さて、これをBUGSで推定してやる MCMC can estimate the fraction of g1, mm and ss of g1 and g2 by random sampling-based procedure.
  • モデルファイルはこんな感じ
    • N人のタイムyとグループgを個別に分布から採取する
    • ただし、グループはpiという確率ベクトルでカテゴリカル割り当て(dcat()関数)
    • 割り当てられたグループg[i]によって、タイムの正規分布の平均とtauとが付値され
    • それに基づいて、正規分布からタイムが採取される
    • 2グループの割合は、外から与える長さ2のベクトルalphaに従うディリクレ分布(つまりベータ分布)で2群のそれぞれの割合が採取される
    • 2グループのそれぞれの平均は9から30の一様分布から採取し、tauはさっぱりわからない、ということで大きな分散のガンマ分布から採取する
model
{
    for (i in 1:N) {
        y[i] ~ dnorm(m[i], tau[i])
        m[i] <- mm[g[i]]
        tau[i] <- tt[g[i]]
        g[i] ~ dcat(pi[])
    }
    pi[1:2] ~ ddirch(alpha[])
    mm[1] ~ dunif(9, 30)
    mm[2] ~ dunif(9, 30)
    tt[1] ~ dgamma(1.00000E-04, 1.00000E-04)
    tt[2] ~ dgamma(1.00000E-04, 1.00000E-04)
}
  • このモデルファイルを"model4.txt"として保存しておく。以下のようにR2WinBUGSパッケージのwrite.model()関数を使って、Rにテキストファイルを作らせてもよい
model4 <- function()
{
    for (i in 1:N) {
        y[i] ~ dnorm(m[i], tau[i])
        m[i] <- mm[g[i]]
        tau[i] <- tt[g[i]]
        g[i] ~ dcat(pi[])
    }
    pi[1:2] ~ ddirch(alpha[])
    mm[1] ~ dunif(9,30)
    mm[2] ~ dunif(9,30)
    tt[1] ~ dgamma(0.0001,0.0001)
    tt[2] ~ dgamma(0.0001,0.0001)
}
file.name <- "model4.txt"
write.model(model4,file.name)
  • 実行するには、タイムのデータと人数を与える必要がある。また、だれがどちらのグループに属するかは不明ながら、「不明」という情報を与えることになる。ちなみに、「速い群」を仮に1(招待選手群とは限らない)、「遅い群」を仮に2とするために、群情報は次のようにして与える。また、モデルでは、2群の比率をディリクレ分布から採用するように書いている。ディリクレ分布が平坦になる(どんな場合も同確率で想定)するには、引数としてc(1,1)を与えるのがよいから、それをalphaというベクトルとして与えている
    • タイムのベクトルをソートする。グループのベクトルは、第1要素に1を、最終要素に2を与え、それ以外にはNAを与える。これにより、全選手のうち、最も速いヒトと最も遅いヒトは必ず別の群に属することになるが、まあ、それは許せる範囲とみなす
  • その上で、初期値を与える必要がある。2群の比率は0.5:0.5で与えるのが、よいだろう。比率が全く不明なときの期待値がそれだから。2群の平均タイムは同じ値を適当にそれらしい値として選び、分散に対応するtauの初期値もそのように与える
alpha <- c(1,1)
data1 <- list(y=sort(y),g=c(1,rep(NA,N-2),2),alpha=alpha,N=N) 
in1 <- list(pi=c(0.5,0.5),mm=c(15,15),tt=c(1,1)) 
inits <- list(in1) 
param <- c("pi","mm","tt") 
file.name <- "model4.txt"
model1 <- file.name
# 実行
bugs4 <- bugs(data1, inits, param, model.file=model1, 
n.chains=1, n.iter=1100, n.burnin=100, n.thin=1, debug=TRUE, program="OpenBUGS", bugs.directory="C:/Program Files(x86)/OpenBUGS/OpenBUGS323")
print(bugs4)
# 推定したtauからsdを出す
mean(1/sqrt(post4[,5]))
mean(1/sqrt(post4[,6]))
  • 結果
    • 比率も平均もいい感じ
> print(bugs4)
Inference for Bugs model at "model4.txt", fit using OpenBUGS,
 1 chains, each with 1100 iterations (first 100 discarded)
 n.sims = 1000 iterations saved
           mean   sd   2.5%    25%    50%    75%  97.5%
pi[1]       0.2  0.1    0.1    0.2    0.2    0.2    0.3
pi[2]       0.8  0.1    0.7    0.8    0.8    0.8    0.9
mm[1]      12.0  0.3   11.6   11.8   12.0   12.1   12.7
mm[2]      16.0  0.2   15.7   15.8   15.9   16.0   16.4
tt[1]       1.0  0.3    0.5    0.8    0.9    1.1    1.6
tt[2]       0.2  0.0    0.2    0.2    0.2    0.2    0.3
deviance 4090.1 87.2 3912.6 4032.2 4097.4 4147.5 4251.2
> mean(1/sqrt(post4[,5]))
[1] 1.050577
> mean(1/sqrt(post4[,6]))
[1] 2.113046
DIC info (using the rule, pD = var(deviance)/2)
pD = 3802.0 and DIC = 7892.1
DIC is an estimate of expected predictive error (lower deviance is better).