library(FNN)
#' Pairwise dissimilarities between observations or between samples
#'
#' For a matrix (or data frame) `X`, returns the full square matrix of
#' distances between its rows, computed with `dist()`. For a list of
#' samples, calls `KL.divergence()` (package FNN) on every ordered pair
#' of samples and stores the returned vectors in a 3-D array.
#'
#' @param X A numeric matrix or data frame (one observation per row), or a
#'   non-empty list whose elements are matrices (one point per row).
#' @param method Distance measure forwarded to `dist()`; only used for
#'   matrix / data-frame input.
#' @param k Requested neighbour count for `KL.divergence()`; only used for
#'   list input. It is capped at (smallest number of rows - 1).
#' @return For matrix / data-frame input, an n x n matrix. For list input,
#'   an n x n x K array where `[i, j, ]` holds the vector returned by
#'   `KL.divergence(X[[i]], X[[j]], k = K)`.
my.dist.KL <- function(X, method = "euclidean", k = 10) {
  # A data frame is technically a list, but its columns are not samples;
  # dist() handles it directly, so treat it like a matrix.
  if (is.matrix(X) || is.data.frame(X)) {
    return(as.matrix(dist(X, method = method)))
  }

  if (!is.list(X) || length(X) == 0L) {
    stop("`X` must be a matrix, a data frame, or a non-empty list of ",
         "matrices.", call. = FALSE)
  }
  has.rows <- vapply(X, function(s) !is.null(nrow(s)), logical(1))
  if (!all(has.rows)) {
    stop("Every element of `X` must be a matrix (one point per row).",
         call. = FALSE)
  }

  n <- length(X)
  num.sample <- vapply(X, nrow, integer(1))
  # The neighbour count cannot exceed the smallest sample size minus one.
  K <- min(k, min(num.sample) - 1)
  if (K < 1) {
    stop("Each sample needs at least 2 rows and `k` must be at least 1.",
         call. = FALSE)
  }

  ret <- array(0, c(n, n, K))
  # The divergence is directional, so every ordered pair (i, j) is computed.
  for (i in seq_len(n)) {
    for (j in seq_len(n)) {
      ret[i, j, ] <- KL.divergence(X[[i]], X[[j]], k = K)
    }
  }
  ret
}
#' Permutation comparison of the neighbour structure of two data sets
#'
#' Computes dissimilarity matrices for `X` and `Y` with `my.dist.KL()`,
#' then, for each neighbour rank r = 1..k2, sums the log-dissimilarity in
#' one data set between every observation and its rank-r neighbour as
#' determined in the other data set (both directions are added). The same
#' statistic is recomputed `n.iter` times after randomly relabelling the
#' observations of one data set relative to the other.
#'
#' Side effect: draws a plot of the sorted permutation statistics for
#' rank 1 followed by the observed value, with a horizontal line at the
#' observed value.
#'
#' @param X,Y Inputs accepted by `my.dist.KL()` that yield a 2-D
#'   dissimilarity matrix, with the same number of observations.
#' @param n.iter Number of random permutations.
#' @param method Distance measure forwarded to `my.dist.KL()`.
#' @param k1 Passed to `my.dist.KL()` as `k`.
#' @param k2 Number of neighbour ranks to evaluate; must satisfy
#'   1 <= k2 <= n - 1.
#' @return A list with `St.K` (observed statistic per rank), `St.Perm`
#'   (an n.iter x k2 matrix of permuted statistics) and `p.value` (per
#'   rank, the fraction of permuted statistics strictly below the observed
#'   one).
my.perm.KL <- function(X, Y, n.iter = 100, method = "euclidean",
                       k1 = 10, k2 = 10) {
  # The ranking below needs an n x n matrix; for list input my.dist.KL()
  # returns a 3-D array, which cannot be indexed that way. Fail early
  # with an explicit message instead of a dimension error further down.
  check.2d <- function(D, label) {
    if (length(dim(D)) != 2L) {
      stop("my.perm.KL() needs a 2-D dissimilarity matrix for `", label,
           "`; a 3-D array (list input) is not supported.", call. = FALSE)
    }
    D
  }

  # Forward `method` so that a non-default distance is actually used.
  D.X <- check.2d(my.dist.KL(X, method = method, k = k1), "X")
  D.Y <- check.2d(my.dist.KL(Y, method = method, k = k1), "Y")

  n <- nrow(D.X)
  if (nrow(D.Y) != n) {
    stop("`X` and `Y` must contain the same number of observations.",
         call. = FALSE)
  }
  # Rank r is read from column r + 1 of the ordering, so k2 + 1 <= n.
  if (length(k2) != 1L || is.na(k2) || k2 < 1 || k2 > n - 1) {
    stop("`k2` must be a single number between 1 and n - 1.",
         call. = FALSE)
  }
  K <- as.integer(k2)

  # Zero dissimilarities (e.g. duplicated rows) become log(0) = -Inf and
  # propagate into the sums.
  log.D.X <- log(D.X)
  log.D.Y <- log(D.Y)

  # Row i holds the indices of all observations ordered by increasing
  # dissimilarity from observation i.
  k.NN.X <- t(apply(D.X, 1, order))
  k.NN.Y <- t(apply(D.Y, 1, order))

  rows <- seq_len(n)
  ranks <- seq_len(K)

  # Sum over observations of log.D between each observation and the entry
  # at position r + 1 of its ordering in `nn`. Position 1 is skipped; it
  # is the observation itself whenever its zero self-distance is the
  # unique row minimum.
  nn.sum <- function(log.D, nn, r) {
    sum(log.D[cbind(rows, nn[, r + 1L])])
  }

  St.K <- vapply(
    ranks,
    function(r) nn.sum(log.D.Y, k.NN.X, r) + nn.sum(log.D.X, k.NN.Y, r),
    numeric(1)
  )

  St.Perm <- matrix(0, n.iter, K)
  for (it in seq_len(n.iter)) {
    sh <- sample.int(n)
    # Relabel Y with `sh`; the second term applies the inverse
    # permutation to X so both terms describe the same relabelling.
    inv <- order(sh)
    sh.log.D.Y <- log.D.Y[sh, sh]
    sh.log.D.X <- log.D.X[inv, inv]
    St.Perm[it, ] <- vapply(
      ranks,
      function(r) nn.sum(sh.log.D.Y, k.NN.X, r) + nn.sum(sh.log.D.X, k.NN.Y, r),
      numeric(1)
    )
  }

  plot(c(sort(St.Perm[,1]),St.K[1]))
  abline(h = St.K[1])

  # Fraction of permuted statistics strictly smaller than the observed one.
  ps <- vapply(
    ranks,
    function(r) sum(St.Perm[, r] < St.K[r], na.rm = TRUE) / n.iter,
    numeric(1)
  )

  list(St.K = St.K, St.Perm = St.Perm, p.value = ps)
}
# ---- Example 1: simulate matrix-valued X and Y ----
# N observations, D predictor columns, and a response array with d columns
# in each of t slices.
N <- 100
D <- 100
d <- 70
t <- 3

# Predictors: N x D standard-normal draws.
X <- matrix(rnorm(N*D),ncol=D)

Y <- array(0,c(N,d,t))
# NOTE(review): only d values are drawn and recycled over the N x d slice;
# confirm that rnorm(N * d) was not intended.
Y[,,1] <- rnorm(d)

# Slices 2 and 3 add to slice 1 a row-specific term built from columns 1-3
# of X plus the row sum of columns 4-7 (tripled in slice 3). The length-N
# vectors recycle down the rows of the N x d slice, so row i uses X[i, ].
Y[, , 2] <- Y[, , 1] +
  (Y[, , 1] + 1) * (X[, 1] * X[, 2] + X[, 3]) +
  rowSums(X[, 4:7])
Y[, , 3] <- Y[, , 1] +
  (Y[, , 1] + 1) * (X[, 1] * X[, 2] + X[, 3]) +
  rowSums(X[, 4:7]) * 3

# Additive Gaussian noise on every cell.
# NOTE(review): the third argument of rnorm() is a standard deviation, but
# var(Y) * er is passed; confirm that sd(Y) * er was not intended.
er <- 0.1
Y <- Y + rnorm(length(Y),0,var(Y)*er)

# Flatten each observation's d x t block into one row of length d * t
# (slice 1 first, then slice 2, then slice 3).
Y2 <- matrix(as.vector(Y), nrow = N, ncol = d * t)
# Run the permutation procedure on the simulated matrices and show the
# result.
my.perm.out <- my.perm.KL(X,Y2)
my.perm.out

# One plot per neighbour rank: sorted permutation statistics followed by
# the observed statistic, with a horizontal line at the observed value.
# par() returns the previous setting, which is restored afterwards rather
# than forcing ask = FALSE.
old.par <- par(ask = TRUE)
for (i in seq_along(my.perm.out$St.K)) {
  plot(c(sort(my.perm.out$St.Perm[, i]), my.perm.out$St.K[i]))
  abline(h = my.perm.out$St.K[i])
}
par(old.par)
# ---- Example 2: lists of samples ----
# X holds n.x samples of 25 points in 2 dimensions, Y holds n.y samples of
# 25 points in 3 dimensions; all coordinates are standard-normal draws.
n.x <- 30
n.y <- 30
# lapply() builds each list in one pass (no growing inside a loop) and
# draws the random numbers in the same order as an index loop would.
X <- lapply(seq_len(n.x), function(i) matrix(rnorm(50), ncol = 2))
Y <- lapply(seq_len(n.y), function(i) matrix(rnorm(75), ncol = 3))

# NOTE(review): for list input my.dist.KL() returns an n x n x K array,
# whereas my.perm.KL() treats the dissimilarities as a 2-D matrix; verify
# that this call completes.
my.perm.out <- my.perm.KL(X,Y)
# One plot per neighbour rank: sorted permutation statistics followed by
# the observed statistic. The horizontal line marks the observed value so
# it can be located among the permuted ones. The previous par() setting is
# restored afterwards rather than forcing ask = FALSE.
old.par <- par(ask = TRUE)
for (i in seq_along(my.perm.out$St.K)) {
  plot(c(sort(my.perm.out$St.Perm[, i]), my.perm.out$St.K[i]))
  abline(h = my.perm.out$St.K[i])
}
par(old.par)