自由度2の正確確率を2段階で計算(2)



一昨日の記事Java版。

2段階処理にすると、サンプルサイズが小さいときは、オーバーヘッドの部分のせいで、かえって遅くなるようだが、10000を越えるあたりから、2段階処理の方が速い(ようだ)


/*
* 自由度2のフィッシャーは自由度1→2と2段階でやることで計算量
* を減らすことができる(はず)
* そのためのシリーズ
* 2x3の表にのみ対応するシリーズを
* df2シリーズとする
*/
public static void main(String[] args) {
// TODO 自動生成されたメソッド・スタブ
int[][] d = {{20000,20100,10000},{20000,20000,10000}};
double time1 = System.currentTimeMillis();
double Panswer = StatUtils.Fisher.Fisherdf2Step2(d);
double time2 = System.currentTimeMillis();
System.out.println("P="+Panswer);
int[][] d2 = {{d[0][0],d[0][1]},{d[1][0],d[1][1]}};
double time3 = System.currentTimeMillis();
double P2x2 = StatUtils.Fisher.Fishernxm2(d2);
double time4 = System.currentTimeMillis();
System.out.println("P2="+P2x2);
double t1 = time2-time1;
double t2 = time4-time3;
System.out.println("time1="+t1);
System.out.println("time2="+t2);
}
public static int[][] marginaldf2(int[][] d){
int[][] ret = {
{d[0][0]+d[0][1]+d[0][2],d[1][0]+d[1][1]+d[1][2],
0},
{d[0][0]+d[1][0],d[0][1]+d[1][1],d[0][2]+d[1][2]}
};
ret[0][2]=ret[0][0]+ret[0][1];
return ret;
}
public static double[] serialLogdf2(int[][] marginal){
int tmp = marginal[0][2]*2;
double[] ret = serialLogfact(tmp);

return ret;
}
public static double lnpr(int[][] d,int[][] m,double[] s){
double ret =(s[m[0][0]]+s[m[0][1]]+s[m[1][0]]+s[m[1][1]]+s[m[1][2]])-
(s[m[0][2]]+s[d[0][0]]+s[d[0][1]]+s[d[0][2]]+s[d[1][0]]+s[d[1][1]]+s[d[1][2]]);

return ret;
}
public static int[] minmaxA(int[][] d,int[][] m){
int[] ret = {0,0};
ret[1] = Math.min(m[0][0], m[1][0]);
int maxBC = Math.min(m[0][0], m[1][1]+m[1][2]);
//System.out.println("maxBC="+maxBC);
//int maxC = Math.min(m[0][0], m[1][2]);
ret[0] = Math.max(0, m[0][0]-(maxBC));
return ret;
}
public static int[] minmaxB(int[][] d,int[][] m2){
int[] ret = {0,0};
ret[1] = Math.min(m2[0][0], m2[1][0]);
int maxBC = Math.min(m2[0][0], m2[1][1]);
//System.out.println("maxBC="+maxBC);

ret[0] = Math.max(0, m2[0][0]-(maxBC));
return ret;
}
public static double prPerA(int a,int[][] d,int[][] m,double lnpr,double[]s){
double ret = 0;

double lnPrA = lnPrA(a,d,m,s);
double tmptmp = Math.exp(lnPrA);
//System.out.println("AAAA\t\t\t"+tmptmp);
if(lnPrA<lnpr){
//System.out.println("less "+a);
ret = Math.exp(lnPrA);

}else{
//System.out.println("more "+a);
int[][] m2 = {
{m[0][0]-a,m[0][1]-(m[1][0]-a),m[0][2]-m[1][0]},
{m[1][1],m[1][2]}
};
int d2 = m[1][0]-a;
//System.out.println("a="+a+"d="+d2);
/*
for(int i=0;i<m2.length;i++){
for(int j=0;j<m2[i].length;j++){
System.out.println(m2[i][j]);
}
}
*/
ret = PrAB(a,m[1][0]-a,d,m,m2,lnpr,s);

}
//System.out.println("PrPerA\t"+ret);
return ret;
}
public static double PrAB(int a,int d2,int[][] d,int[][] m,int[][] m2,double lnpr,double[]s){
double ret = 0;
int[] minmaxB = minmaxB(d,m2);
int[] seriald ={d[0][0],d[0][1],d[0][2],d[1][0],d[1][1],d[1][2]};
double lnconst = s[m[0][0]]+s[m[0][1]]+s[m[1][0]]+s[m[1][1]]+s[m[1][2]]-
(s[m[0][2]]+s[a]+s[d2]);
int here =minmaxB[1];
for(int i=minmaxB[0];i<=minmaxB[1];i++){
//System.out.println("i="+i);
int[] serialtmp ={a,d2,i,m2[0][0]-i,m2[1][0]-i,m2[1][1]-m2[0][0]+i};
boolean identical = Fisher.identicalTableElem1(seriald,serialtmp);
if(identical){
ret += Math.exp(lnpr);
//System.out.println("identical");
}else{
//System.out.println("NOT identical");
double tmplnpr=lnconst-(s[serialtmp[2]]+s[serialtmp[3]]+s[serialtmp[4]]+s[serialtmp[5]]);
if(tmplnpr<=lnpr){
//System.out.println("Less");
ret += Math.exp(tmplnpr);
}else{
//System.out.println("More");
here = i;
break;
}
}
}
for(int i=minmaxB[1];i>here;i--){
//System.out.println("i="+i);
int[] serialtmp ={a,d2,i,m2[0][0]-i,m2[1][0]-i,m2[1][1]-m2[0][0]+i};
boolean identical = Fisher.identicalTableElem1(seriald,serialtmp);
if(identical){
ret += Math.exp(lnpr);
}else{
double tmplnpr=lnconst-(s[serialtmp[2]]+s[serialtmp[3]]+s[serialtmp[4]]+s[serialtmp[5]]);
if(tmplnpr<=lnpr){
ret += Math.exp(tmplnpr);
}else{
here = i;
break;
}
}
}
return ret;
}
public static double lnPrA(int a,int[][] d,int[][] m, double[]s){

int[][] t = {
{a,m[0][0]-a},
{m[1][0]-a,m[0][2]-m[0][0]-m[1][0]+a}
};
//System.out.println("t");
//for(int i=0;i<t.length;i++){
//for(int j=0;j<t[i].length;j++){
//System.out.println(t[i][j]);
//}
//}
int[][] m2 = {
{m[0][0],m[0][1],m[0][2]},
{m[1][0],m[1][1]+m[1][2]}
};
//System.out.println("m2");
//for(int i=0;i<m2.length;i++){
//for(int j=0;j<m2[i].length;j++){
//System.out.println(m2[i][j]);
//}
//}
double ret = (s[m2[0][0]]+s[m2[0][1]]+s[m2[1][0]]+s[m2[1][1]])-
(s[m2[0][2]]+s[t[0][0]]+s[t[0][1]]+s[t[1][0]]+s[t[1][1]]);
;
return ret;
}