library(tidyverse)
library(mvtnorm)
library(gtools)

# set working directory 
#setwd("")
R <- 36 # number of sites
S <- 9 # number of visits/site
dist <- as.matrix(read.csv("ORdist.csv", header = FALSE)) # Oregon County Distance Matrix
pops <- read.csv("chlym.csv")[, 18:26] # Load Oregon county population data
diag_dist <- diag(R) # Optional diagonal distance matrix
#obs_abund=as.matrix(read.csv("chlym.csv")[,7:15]) #If not doing simulations, this is the observed data

# Fail early if the loaded data do not match the declared dimensions; the
# model functions index `dist` and `pops` by R sites and up to S - 1 visits.
stopifnot(
  nrow(dist) == R, ncol(dist) == R,
  nrow(pops) == R, ncol(pops) >= S - 1
)

# Index of the simulation scenario to run, taken from the trailing
# command-line arguments (e.g. a cluster array task id).
task <- as.numeric(commandArgs(TRUE))

## Simulation scenarios for b1, p, omega, gamma, respectively.
## The 24 distinct scenarios are repeated 10 times, giving 240 task ids.
simoptos <- rep(list(c(.005,.4,.5,.3),c(.005,.4,.8,.6),c(.005,.7,.5,.3),c(.005,.7,.8,.6),
              c(.005,.9,.5,.3),c(.005,.9,.8,.6),c(.05,.4,.5,.3),c(.05,.4,.8,.6),
              c(.05,.7,.5,.3),c(.05,.7,.8,.6),c(.05,.9,.5,.3),c(.05,.9,.8,.6),
              c(.005,.4,.5,.6),c(.005,.4,.8,.3),c(.005,.7,.5,.6),c(.005,.7,.8,.3),
              c(.005,.9,.5,.6),c(.005,.9,.8,.3),c(.05,.4,.5,.6),c(.05,.4,.8,.3),
              c(.05,.7,.5,.6),c(.05,.7,.8,.3),c(.05,.9,.5,.6),c(.05,.9,.8,.3)),10)

# A missing, non-numeric or out-of-range task id would otherwise surface as a
# cryptic "subscript" error from `[[`.
if (length(task) != 1 || is.na(task) || task < 1 || task > length(simoptos)) {
  stop("Expected exactly one task id between 1 and ", length(simoptos),
       " as a command-line argument.", call. = FALSE)
}

params <- simoptos[[task]] # define parameters for specific task

## Asymptotic mean vector
##
## Builds the R x S matrix of expected observed counts. Column 1 of the latent
## mean is b1 * pops[, 1]; each later column is omega times the previous column
## plus gamma * previous population * (dist-weighted prevalence divided by the
## row sums of dist). The latent mean is scaled by p before being returned.
## Reads the globals R (sites) and S (visits).
MuGen <- function(b1, p, omega, gamma, pops, dist) {
  row_totals <- rowSums(dist) # same for every visit
  latent <- matrix(0, R, S)
  latent[, 1] <- b1 * pops[, 1]
  for (prev in 1:(S - 1)) {
    carried <- omega * latent[, prev]
    pressure <- (dist %*% (latent[, prev] / pops[, prev])) / row_totals
    latent[, prev + 1] <- gamma * pops[, prev] * pressure + carried
  }
  p * latent
}
## Asymptotic variance covariance matrix
##
## Returns the (R*S) x (R*S) covariance matrix of the observed counts, laid out
## as S x S blocks of size R x R (block (T1, T2) holds site-by-site covariances
## between visits T1 and T2). Blocks are filled recursively in the order
## (1,2), (1,3), ..., (1,S), (2,2), (2,3), ..., (S,S); each block is derived
## from the block one visit earlier in its second index, so the order matters.
## The latent covariance is scaled by p^2 and p * (1 - p) * latent mean is
## added to the diagonal.
##
## Arguments match MuGen(). Reads the globals R (sites) and S (visits) and
## calls MuGen() for the latent mean.
MuSig <- function(b1, p, omega, gamma, pops, dist) {
  row_tot <- rowSums(dist) # loop-invariant row sums of the distance matrix
  blk <- function(k) ((k - 1) * R + 1):(k * R) # row/column indices of visit k

  # mult[[j]] is the outer product of column j of dist with itself. This is
  # what the R^2 x R^2 outer(dist, dist) slices used to provide, without
  # allocating the large intermediate matrix.
  mult <- lapply(seq_len(R), function(j) tcrossprod(dist[, j]))

  covmat <- matrix(0, R * S, R * S)
  curmean <- MuGen(b1, p, omega, gamma, pops, dist) / p # latent (unthinned) mean
  diag(covmat)[1:R] <- curmean[, 1]

  # All visit pairs T1 <= T2 in lexicographic order, dropping (1, 1), which is
  # initialised above.
  t1_seq <- rep(seq_len(S), times = rev(seq_len(S)))[-1]
  t2_seq <- unlist(lapply(seq_len(S), function(a) a:S))[-1]

  for (each in seq_along(t1_seq)) {
    T1 <- t1_seq[each]
    T2 <- t2_seq[each]
    prev_pop <- pops[, T2 - 1]
    poptile <- matrix(prev_pop, R, R, byrow = TRUE) # poptile[i, k] = prev_pop[k]
    stepdown <- t(dist / row_tot / poptile)

    lagged <- covmat[blk(T1), blk(T2 - 1)]
    entry <- gamma * poptile * (lagged %*% stepdown) + omega * lagged
    covmat[blk(T1), blk(T2)] <- entry
    covmat[blk(T2), blk(T1)] <- t(entry)

    if (T1 == T2) {
      prev_cov <- covmat[blk(T1 - 1), blk(T1 - 1)]
      quad <- vapply(
        mult,
        function(m) sum(m * prev_cov / poptile / t(poptile)),
        numeric(1)
      )
      l <- (prev_pop * gamma / row_tot)^2 * quad +
        omega^2 * diag(prev_cov) +
        2 * (prev_pop * omega * gamma / row_tot) * rowSums(prev_cov / poptile * dist) +
        curmean[, T1 - 1] * omega * (1 - omega) +
        gamma * prev_pop * as.numeric(dist %*% (curmean[, T1 - 1] / prev_pop) / row_tot)

      # Write the variances onto the diagonal of the block by position.
      # Matching on value (which(val == l)) overwrote off-diagonal entries
      # whenever two sites had the same variance.
      idx <- blk(T1)
      covmat[cbind(idx, idx)] <- as.numeric(l)
    }
  }

  master_sig <- covmat * p^2
  diag(master_sig) <- diag(master_sig) + p * (1 - p) * as.numeric(curmean)
  master_sig
}

# Negative log likelihood function
#
# params is c(b1, p, omega, gamma). Returns the constant penalty 9999999 when
# any parameter is exactly 0, or when b1, p or omega is exactly 1; otherwise
# prints the four parameter values followed by "---" and returns minus the
# multivariate-normal log density of the vectorised obs_abund, with mean
# MuGen(...) and covariance MuSig(...). Uses the global `pops`.
MaxLE <- function(params, obs_abund, dist) {
  b1 <- params[1]
  pr <- params[2]
  o <- params[3]
  g <- params[4]

  hits_zero <- b1 == 0 | pr == 0 | o == 0 | g == 0
  if (hits_zero) {
    return(9999999)
  }
  hits_one <- b1 == 1 | pr == 1 | o == 1
  if (hits_one) {
    return(9999999)
  }

  # Trace of the current evaluation point
  for (shown in list(b1, pr, o, g, "---")) {
    print(shown)
  }

  mu_vec <- as.numeric(MuGen(b1, pr, o, g, pops, dist))
  sig_mat <- MuSig(b1, pr, o, g, pops, dist)
  -dmvnorm(as.vector(obs_abund), mean = mu_vec, sigma = sig_mat, log = T)
}

## Forward finite-difference partial derivative of FUN with respect to the
## argument named `partial_var` (R's functions seem to have a hard time doing
## this when there's a recursive loop in the likelihood). All arguments of FUN
## are passed by name through `...`; the named one is shifted by h.
ben_par_der <- function(h = 2.22e-10, partial_var = NULL, FUN = NULL, ...) {
  inputs <- list(...)
  inputsh <- inputs
  idx <- which(names(inputs) == partial_var)
  inputsh[[idx]] <- inputsh[[idx]] + h
  (do.call(FUN, inputsh) - do.call(FUN, inputs)) / h
}

## Simulates one latent trajectory from parameters x = c(b1, p, omega, gamma)
## and returns the total abundance over all sites at the last visit (S).
## x[2] (detection) is not used because the latent abundance is returned.
## Uses the globals R, S, pops and dist.
## NOTE(review): draws of x outside the valid parameter range make
## rpois()/rbinom() return NA with a warning - confirm this is acceptable.
boot_fun <- function(x) {
  b <- x[1]; o <- x[3]; g <- x[4]
  row_tot <- rowSums(dist)
  cur <- matrix(0, nrow = R, ncol = S)
  cur[, 1] <- rpois(R, pops[, 1] * b)
  for (each in 1:(S - 1)) {
    weights <- cur[, each] / pops[, each]
    cur[, (each + 1)] <- rbinom(R, cur[, each], o) +
      rpois(R, g * pops[, each] * ((dist %*% weights) / row_tot))
  }
  sum(cur[, S])
}

## Runs L-BFGS-B from three starting values of p (.9, .5, .1) and returns the
## fit with the smallest negative log likelihood. A start that raises an error
## contributes an all-NA placeholder; returns NULL when every start failed.
fit_best <- function(obs_abund, dist_mat, stepsize) {
  b1_start <- mean(obs_abund[, 1] / pops[, 1])
  fits <- lapply(c(.9, .5, .1), function(p_start) {
    tryCatch(
      optim(c(b1_start, p_start, .5, .5), MaxLE, obs_abund = obs_abund, dist = dist_mat,
            method = "L-BFGS-B", lower = c(.00001, .01, .01, .01), upper = c(.99, .99, .99, 2),
            control = list(maxit = 500, ndeps = c(stepsize, .001, .001, .001))),
      error = function(cond) list(par = c(NA, NA, NA, NA), value = NA, convergence = NA)
    )
  })
  values <- vapply(fits, function(f) as.numeric(f$value), numeric(1))
  if (all(is.na(values))) {
    return(NULL)
  }
  fits[[which.min(values)]] # which.min() ignores NA entries
}

## Mean abundance for each site and time point implied by the fitted
## parameters `par` (b1, p, omega, gamma) under distance matrix `dist_mat`.
expected_abund <- function(par, dist_mat) {
  row_tot <- rowSums(dist_mat)
  out <- matrix(0, nrow = R, ncol = S)
  out[, 1] <- pops[, 1] * par[1]
  for (each in 1:(S - 1)) {
    weights <- out[, each] / pops[, each]
    out[, (each + 1)] <- par[3] * out[, each] +
      par[4] * pops[, each] * ((dist_mat %*% weights) / row_tot)
  }
  out
}

for (iter in 1:1000) { # number of iterations in simulation; skip the data-simulation step if using a fixed dataset
  ## simulates data set
  b1 <- params[1]
  p <- params[2]
  omega <- params[3]
  gamma <- params[4]
  lam <- b1 * (pops[, 1])
  dist_row_tot <- rowSums(dist)

  true <- matrix(0, nrow = R, ncol = S)
  true[, 1] <- rpois(R, lam)

  for (each in 2:S) {
    Su <- rbinom(R, true[, each - 1], omega) # survivors from the previous visit
    prevalences <- true[, each - 1] / pops[, each - 1]
    G <- rpois(R, gamma * pops[, each - 1] * ((dist %*% prevalences) / dist_row_tot)) # new cases
    true[, each] <- Su + G
  }

  obs_abund <- matrix(rbinom(R * S, true, p), nrow = R, ncol = S) # binomial thinning with detection p

  # Negative log likelihood at the true (data-generating) parameters
  LL <- -dmvnorm(as.vector(obs_abund), mean = as.numeric(MuGen(b1, p, omega, gamma, pops, dist)),
                 sigma = MuSig(b1, p, omega, gamma, pops, dist), log = TRUE)

  # specify stepsize for b1 parameter based on observed values and population
  stepsize <- if (mean(obs_abund[, 1] / pops[, 1]) < .01) .00001 else .0001

  ## Optimization (L-BFGS-B) for both the diagonal distance matrix and the
  ## Oregon distance matrix, at 3 starting points each
  start1 <- Sys.time()
  results_non <- fit_best(obs_abund, diag_dist, stepsize) # best diagonal distance matrix fit
  end1 <- Sys.time()

  start2 <- Sys.time()
  results <- fit_best(obs_abund, dist, stepsize) # best Oregon distance matrix fit
  end2 <- Sys.time()

  if (is.null(results) || is.null(results_non)) {
    warning("iteration ", iter, ": every optimisation start failed; skipping", call. = FALSE)
    next
  }

  ## These give mean abundance estimates for each site and time point
  cur <- expected_abund(results$par, dist)
  cur2 <- expected_abund(results_non$par, diag_dist)

  ## Partial derivatives of the mean and covariance with respect to each
  ## parameter, used to calculate the asymptotic information
  h <- 2.22e-10
  par_names <- c("b1", "p", "omega", "gamma")

  dfMu <- lapply(par_names, function(v) {
    ben_par_der(h, partial_var = v, FUN = MuGen, b1 = results$par[1], p = results$par[2],
                omega = results$par[3], gamma = results$par[4], pops = pops, dist = dist)
  })

  dfSig <- lapply(par_names, function(v) {
    ben_par_der(h, partial_var = v, FUN = MuSig, b1 = results$par[1], p = results$par[2],
                omega = results$par[3], gamma = results$par[4], pops = pops, dist = dist)
  })

  inf <- matrix(0, nrow = 4, ncol = 4) # initiate information matrix

  sigplug <- MuSig(b1 = results$par[1], p = results$par[2], omega = results$par[3],
                   gamma = results$par[4], pops, dist)
  sig_inv <- solve(sigplug) # invert once; reused for every entry below
  sig_inv_dsig <- lapply(dfSig, function(d) sig_inv %*% d)

  # upper-triangle index pairs (m, n) of the 4 x 4 information matrix
  combos <- matrix(c(1,1,1,1,2,2,2,3,3,4,1,2,3,4,2,3,4,3,4,4), ncol = 2)

  for (combs in 1:10) {
    m <- combos[combs, 1]
    n <- combos[combs, 2]
    out <- as.vector(dfMu[[m]]) %*% sig_inv %*% as.vector(dfMu[[n]]) +
      .5 * sum(diag(sig_inv_dsig[[m]] %*% sig_inv_dsig[[n]]))
    inf[m, n] <- out
    inf[n, m] <- out
  }

  cov <- solve(inf) # asymptotic covariance

  ## Bootstrap to calculate 95% intervals for ninth time period total abundance
  size <- 10000

  sim_boot <- rmvnorm(size, results$par, cov)
  bt <- apply(sim_boot, 1, boot_fun)

  # Same bootstrap using the numeric Hessian of the negative log likelihood.
  # optim() was not asked for a Hessian, so it is computed here at the optimum.
  sim_boot2 <- tryCatch({
    results$hessian <- optimHess(results$par, MaxLE, obs_abund = obs_abund, dist = dist,
                                 control = list(ndeps = c(stepsize, .001, .001, .001)))
    rmvnorm(size, results$par, solve(results$hessian))
  }, error = function(cond) {
    warning("iteration ", iter, ": numeric-Hessian bootstrap unavailable: ",
            conditionMessage(cond), call. = FALSE)
    NULL
  })
  bt2 <- if (is.null(sim_boot2)) NULL else apply(sim_boot2, 1, boot_fun)

  # write results as desired
}
