library(emdist)
#' Earth mover's distance between the largest entries of two matrices
#'
#' Pools the values of `A` and `B` and takes the `r`-th largest pooled value
#' as a cutoff. It then keeps every cell of each input whose value is
#' >= that cutoff; ties can make this more than `r` cells in total. Each kept
#' set becomes a signature for `emd()`:
#' - first column: the kept cell values, rescaled to sum to 1;
#' - remaining columns: the cells' array indices (coordinates).
#'
#' If only one input has cells above the cutoff, the other input gets a
#' placeholder signature with the same number of points. All of those points
#' sit at index 1 (cell [1, 1] for a matrix) and are weighted by that cell's
#' value.
#' NOTE(review): the placeholder assumes that cell's value is positive; a
#' zero there yields NaN weights.
#'
#' @param A,B Numeric matrices (or higher-dimensional arrays). Plain vectors
#'   are treated as 1-D, with positions used as coordinates.
#' @param r Rank of the pooled value used as the cutoff. If `r` exceeds the
#'   number of pooled non-NA values, the cutoff is `NA`, no cell is kept,
#'   and the distance is reported as 0.
#' @return A list with:
#'   - `em.d`: the value returned by `emd()`, or 0 when no cell is kept;
#'   - `r` and `cutoff`;
#'   - the kept index sets `A.` and `B.`.
my.emd <- function(A, B, r = 100) {
  pooled <- sort(c(A, B), decreasing = TRUE)  # sort() drops NA values
  cutoff <- pooled[r]

  A. <- which(A >= cutoff, arr.ind = TRUE)
  B. <- which(B >= cutoff, arr.ind = TRUE)
  # NROW() covers both the index matrix (array input) and the plain index
  # vector that which() returns for dimensionless input.
  n.A <- NROW(A.)
  n.B <- NROW(B.)

  if (n.A == 0 && n.B == 0) {
    # Nothing passes the cutoff (e.g. r is larger than the pooled length).
    return(list(em.d = 0, r = r, cutoff = cutoff,
                A. = matrix(NA, 0, 0), B. = matrix(NA, 0, 0)))
  }
  # Give the empty side a placeholder with as many points as the other
  # side, every coordinate set to 1.
  if (n.A == 0) {
    A. <- B.
    A.[] <- 1L
  } else if (n.B == 0) {
    B. <- A.
    B.[] <- 1L
  }

  A.w <- A[A.]
  B.w <- B[B.]
  A.2 <- cbind(A.w, A.)
  B.2 <- cbind(B.w, B.)
  # Normalise the weight column so each signature carries total mass 1.
  A.2[, 1] <- A.2[, 1] / sum(A.2[, 1])
  B.2[, 1] <- B.2[, 1] / sum(B.2[, 1])

  list(em.d = emd(A.2, B.2), r = r, cutoff = cutoff, A. = A., B. = B.)
}
# Simulation setup: a 100 x 200 field of exp(N(0, 1)) values with its cells
# randomly permuted.
n <- 100
m <- 200
A <- matrix(rnorm(n * m), n, m)
A[sample.int(length(A))] <- A
A <- exp(A)
image(A)

# For each repeat:
# - build B by shuffling the values of 1000 randomly chosen cells of A;
# - record the distance that my.emd() reports between A and B.
n.iter <- 100
em.ds <- numeric(n.iter)
for (i in seq_len(n.iter)) {
  ss <- sample.int(length(A), 1000)
  B <- A
  B[ss] <- B[sample(ss)]

  # Show A and B side by side, then go back to a single panel.
  par(mfcol = c(1, 2))
  image(A)
  image(B)
  par(mfcol = c(1, 1))

  my.emd.out <- my.emd(A, B)
  em.ds[i] <- my.emd.out$em.d
}
plot(sort(em.ds))