# Pohle, J., Langrock, R., van der Schaar, M., King, R. and Jensen, F.H.: #
# A primer on coupled state-switching models for multiple interacting time series #
#''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''#

# R code for real-data dolphin case study described in Section 4.1 #



##### 1) Load packages and data #####

library(Rcpp)
library(RcppArmadillo)
sourceCpp("Pohle_et_al_mllkRcpp.cpp") # C++ code for forward algorithm to evaluate HMM log-likelihood
# Read the two tortuosity series; hist() and the replacement below treat column 1 as the mother,
# and column 2 is handled as the calf.
data <- read.table('Pohle_et_al_Dolphin_Tortuosity_Data.txt', header = TRUE)

# replacement of zeros in the tortuosity time series
# Exact zeros are replaced by small random values drawn uniformly from (0, 0.001)
# (presumably because the beta densities fitted later are not usable at exactly 0 - confirm).
# The seed is fixed so that the replacement values are reproducible.
set.seed(1234)
hist(data[, 1], probability = TRUE, breaks = 50)
ind_mother <- which(data$tortuosity_mother == 0)
# Summary of the non-zero observations. A logical mask is used instead of data[-ind_mother, 1]
# because negative indexing with an empty index vector would select no rows at all.
summary(data[!(seq_len(nrow(data)) %in% ind_mother), 1])
data[ind_mother, 1] <- runif(length(ind_mother), 0, 0.001)

ind_calf <- which(data[, 2] == 0)
data[ind_calf, 2] <- runif(length(ind_calf), 0, 0.001)


##### 2) Functions used for parameter estimation #####
# for Cartesian product CHMM for two individuals with NpV states per individual and state-dependent beta distributions #

# n2w_chmm: maps the natural (constrained) parameters to an unconstrained working parameter vector
# input:  alpha - 2*NpV-dimensional vector of state-dependent alpha parameters for each state and each individual
#                 (order: c(alpha_state1_ind1,...,alpha_stateNpV_ind1,alpha_state1_ind2,...,alpha_stateNpV_ind2)),
#         beta - 2*NpV-dimensional vector of state-dependent beta parameters, same ordering as alpha,
#         gamma - N x N transition probability matrix (with N=NpV^2),
#         N - number of states in the Cartesian product state space
# output: vector c(log(alpha), log(beta), log(gamma_ij/gamma_ii) for all i != j),
#         where the off-diagonal entries are taken in column-major order
n2w_chmm <- function(alpha, beta, gamma, N) {
  # TRUE exactly at the off-diagonal positions of an N x N matrix
  off_diagonal <- !diag(N)
  # diag(gamma) is recycled down the columns, so row i is divided by gamma[i, i]
  log_ratios <- log(gamma / diag(gamma))
  c(log(alpha), log(beta), log_ratios[off_diagonal])
}


# w2n_chmm: maps a working parameter vector back to the natural parameters
# input:  parvect - vector of transformed parameters as given by n2w_chmm
#         N - number of states in the Cartesian product state space (with N=NpV^2)
#         NpV - number of states per individual
# output: list with the natural parameters alpha, beta and gamma (as described for n2w_chmm)
#         and the N-dimensional stationary distribution vector delta
w2n_chmm <- function(parvect, N, NpV) {
  # positions of the first 2*NpV entries; the beta block follows directly after the alpha block
  shape_idx <- 1:(2 * NpV)
  alpha <- exp(parvect[shape_idx])
  beta <- exp(parvect[2 * NpV + shape_idx])

  # start from the identity, fill the off-diagonal cells (column-major order) and normalise each row
  tpm <- diag(N)
  tpm[tpm == 0] <- exp(parvect[4 * NpV + 1:(N * (N - 1))])
  tpm <- tpm / rowSums(tpm)

  # delta solves delta %*% (I - gamma + U) = (1,...,1), with U the N x N matrix of ones
  delta <- solve(t(diag(N) - tpm + 1), rep(1, N))

  list(alpha = alpha, beta = beta, gamma = tpm, delta = delta)
}


# mllk_chmm: minus log-likelihood of the model, evaluated at the transformed parameter vector parvect
# the forward recursion itself is delegated to the compiled function mllk_Rcpp to speed up calculation
# input:  parvect - transformed parameter vector as given by n2w_chmm
#         data - data matrix of observed values, first column corresponds to individual 1, second column to individual 2
#         NpV - number of states per individual,
#         N - number of states in Cartesian product state vector (N=NpV^2),
#         ind_NpV - index array describing state combinations (defined in function mle_chmm):
#                   ind_NpV[,j,m] holds the product states in which individual m is in state j,
#         n - number of observations
# output: minus log-likelihood value
mllk_chmm <- function(parvect, data, NpV, N, ind_NpV, n) {
  nat <- w2n_chmm(parvect, N, NpV)
  # allprobs[t, k]: product over both individuals of the beta densities of observation t under product state k
  allprobs <- matrix(1, n, N)
  for (m in 1:2) {
    # parameters of individual m occupy positions (m-1)*NpV + 1, ..., m*NpV
    offset <- (m - 1) * NpV
    for (j in 1:NpV) {
      cols <- ind_NpV[, j, m]
      dens <- dbeta(data[, m], nat$alpha[offset + j], nat$beta[offset + j])
      # the length-n density vector is recycled over the selected columns
      allprobs[, cols] <- allprobs[, cols] * dens
    }
  }
  # the chain is started in its stationary distribution
  -mllk_Rcpp(allprobs, gamma = nat$gamma, nat$delta, n)
}


# mle_chmm: maximum likelihood estimation by numerical minimisation of mllk_chmm using nlm
# input:  data - data matrix of observed values, first column corresponds to individual 1, second column to individual 2
#         alpha, beta, gamma - starting values for natural parameters as described for n2w_chmm
#         NpV - number of states per individual
# output: list with the estimates alpha, beta, delta and gamma found by nlm,
#         the minimised minus log-likelihood (mllk) and the nlm characteristics
#         (iterations - number of iterations needed, code - convergence code)
mle_chmm <- function(data, alpha, beta, gamma, NpV) {
  N <- NpV^2
  n <- dim(data)[1]

  # row k of ind_states: (state of individual 1, state of individual 2) in product state k
  ind_states <- rev(expand.grid(1:NpV, 1:NpV))
  # ind_NpV[,j,m]: the NpV product states in which individual m is in state j
  ind_NpV <- array(0, dim = c(NpV, NpV, 2))
  for (m in 1:2) {
    for (j in 1:NpV) {
      ind_NpV[, j, m] <- which(ind_states[, m] == j)
    }
  }

  start <- n2w_chmm(alpha, beta, gamma, N)
  # the unnamed arguments after the starting vector are handed on by nlm to mllk_chmm
  fit <- nlm(mllk_chmm, start, data, NpV, N, ind_NpV, n,
             print.level = 2, iterlim = 10000, stepmax = 150)
  est <- w2n_chmm(fit$estimate, N, NpV)

  list(alpha = est$alpha, beta = est$beta, delta = est$delta, gamma = est$gamma,
       mllk = fit$minimum, iterations = fit$iterations, code = fit$code)
}


# viterbi_chmm: Viterbi algorithm
# input:  data - data matrix of observed values, first column corresponds to individual 1, second column to individual 2
#         mod - model object as given by function mle_chmm (elements alpha, beta, gamma and delta are used)
#         NpV - number of states per individual
# output: Viterbi sequence indicating the most likely state sequence given the model,
#         values from 1 to NpV^2, the corresponding state vectors can be identified using ind_states<-rev(expand.grid(1:NpV,1:NpV)),
#         which row-wise shows the different state-vectors that can be realised;
#         a series with a single observation yields a single decoded state, an empty series yields numeric(0)
viterbi_chmm <- function(data, mod, NpV) {
  N <- NpV^2
  n <- dim(data)[1]
  # nothing to decode for an empty series
  if (n == 0) {
    return(numeric(0))
  }

  # ind_NpV[,j,m]: the NpV product states in which individual m is in state j
  ind_states <- rev(expand.grid(1:NpV, 1:NpV))
  ind_NpV <- array(0, dim = c(NpV, NpV, 2))
  for (m in 1:2) {
    for (j in 1:NpV) {
      ind_NpV[, j, m] <- which(ind_states[, m] == j)
    }
  }

  # allprobs[t, k]: product over both individuals of the beta densities of observation t under product state k
  allprobs <- matrix(1, n, N)
  for (m in 1:2) {
    for (j in 1:NpV) {
      cols <- ind_NpV[, j, m]
      allprobs[, cols] <- allprobs[, cols] *
        dbeta(data[, m], mod$alpha[(m - 1) * NpV + j], mod$beta[(m - 1) * NpV + j])
    }
  }

  # forward pass: yi[t, ] holds the rescaled maximal path probabilities (each row sums to one)
  yi <- matrix(0, n, N)
  foo <- mod$delta * allprobs[1, ]
  yi[1, ] <- foo / sum(foo)
  # seq_len(n)[-1] is empty for n == 1, whereas 2:n would run over c(2, 1)
  for (i in seq_len(n)[-1]) {
    # yi[i-1, ] is recycled down the columns, i.e. row k of gamma is weighted by yi[i-1, k]
    foo <- apply(yi[i - 1, ] * mod$gamma, 2, max) * allprobs[i, ]
    yi[i, ] <- foo / sum(foo)
  }

  # backward pass: trace back the most likely predecessor of each decoded state
  iv <- numeric(n)
  iv[n] <- which.max(yi[n, ])
  # rev(seq_len(n - 1)) is empty for n == 1, whereas (n-1):1 would run over c(0, 1)
  for (i in rev(seq_len(n - 1))) {
    iv[i] <- which.max(mod$gamma[, iv[i + 1]] * yi[i, ])
  }
  iv
}



##### 3) Parameter estimation #####

# initialisation

# number of states
NpV <- 3 # per individual
N <- NpV^2 # for state vector; derived from NpV so that the two cannot get out of sync

# order of state vectors:
# row k gives the pair (state of individual 1, state of individual 2) that makes up product state k
ind_states <- rev(expand.grid(1:NpV, 1:NpV))

# starting values:
# transition probability matrix with 0.02 in every off-diagonal cell, rows summing to one
gamma0 <- matrix(0.02, N, N)
diag(gamma0) <- 1 - (N - 1) * 0.02
# state-dependent beta distribution parameters, the same starting values for both individuals
alpha0 <- rep(c(3.5, 1, 3), 2)
beta0 <- rep(c(138, 6, 7.5), 2)
# one alpha and one beta value is needed per state and individual
stopifnot(length(alpha0) == 2 * NpV, length(beta0) == 2 * NpV)

# estimation:
mod_chmm <- mle_chmm(data, alpha0, beta0, gamma0, NpV)



##### 4) Results #####

# parameter estimates

# mean values of the state-dependent beta distributions: alpha/(alpha+beta)
mod_chmm$mean <- mod_chmm$alpha / (mod_chmm$alpha + mod_chmm$beta)
# mother states 1-3
round(mod_chmm$mean[1:3], 3)
# calf states 1-3
round(mod_chmm$mean[4:6], 3)

# standard deviations of the state-dependent beta distributions:
# sqrt(alpha*beta/((alpha+beta)^2*(alpha+beta+1)))
mod_chmm$sd <- sqrt(
  mod_chmm$alpha * mod_chmm$beta /
    ((mod_chmm$alpha + mod_chmm$beta)^2 * (mod_chmm$alpha + mod_chmm$beta + 1))
)
# mother states 1-3
round(mod_chmm$sd[1:3], 3)
# calf states 1-3
round(mod_chmm$sd[4:6], 3)

# stationary distribution
round(mod_chmm$delta, 3) # order as indicated by ind_states


# plot state-dependent beta-distributions
# log-scale for y-axis: every density value d is drawn at height log(d+1)

color <- c(
  rgb(176, 48, 96, maxColorValue = 255),
  rgb(0, 139, 139, maxColorValue = 255),
  'orange'
)

# mother
# each state-dependent density is weighted by the stationary probability of the mother being in that state
z <- seq(0, 1, 0.001)
sdd1 <- sum(mod_chmm$delta[which(ind_states[, 1] == 1)]) *
  dbeta(z, mod_chmm$alpha[1], mod_chmm$beta[1])
sdd2 <- sum(mod_chmm$delta[which(ind_states[, 1] == 2)]) *
  dbeta(z, mod_chmm$alpha[2], mod_chmm$beta[2])
sdd3 <- sum(mod_chmm$delta[which(ind_states[, 1] == 3)]) *
  dbeta(z, mod_chmm$alpha[3], mod_chmm$beta[3])
# histogram (computed only, drawn by hand below)
hist_a <- hist(data[, 1], breaks = 40, plot = FALSE)
n_breaks <- length(hist_a$breaks)
# plot
par(mar = c(4, 4, 1, 1))
plot(
  hist_a$breaks, c(log(hist_a$density + 1), 0),
  type = 's', ylab = '', xlab = '', yaxt = 'n', col = 'darkgrey', bty = 'n', xaxt = 'n',
  xlim = c(0, 1), ylim = c(0, 3.5)
)
lines(hist_a$breaks[1:(n_breaks - 1)], log(hist_a$density + 1), type = 'h', col = 'darkgrey')
lines(hist_a$breaks[c(1, n_breaks)], c(0, 0), type = 'l', col = 'darkgrey')
axis(1, seq(0, 1, 0.2))
# tick positions log(d+1) are labelled with the density values d = 0, 1, 5, 20
axis(2, c(0, log(2), log(6), log(21)), c(0, 1, 5, 20))
mtext('tortuosity mother', side = 1, line = 2.2)
mtext('density', side = 2, line = 2.1)
lines(z, log(sdd1 + 1), lwd = 3, col = color[1])
lines(z, log(sdd2 + 1), lwd = 3, col = color[2])
lines(z, log(sdd3 + 1), lwd = 3, col = color[3])
legend(x = 0.6, y = log(21), c('state 1', 'state 2', 'state 3'), lwd = 3, col = color, bty = 'n') # cex=0.8,

# calf
# same construction as for the mother, using the second column of ind_states and parameters 4-6
sdd1 <- sum(mod_chmm$delta[which(ind_states[, 2] == 1)]) *
  dbeta(z, mod_chmm$alpha[4], mod_chmm$beta[4])
sdd2 <- sum(mod_chmm$delta[which(ind_states[, 2] == 2)]) *
  dbeta(z, mod_chmm$alpha[5], mod_chmm$beta[5])
sdd3 <- sum(mod_chmm$delta[which(ind_states[, 2] == 3)]) *
  dbeta(z, mod_chmm$alpha[6], mod_chmm$beta[6])
# histogram (computed only, drawn by hand below)
hist_b <- hist(data[, 2], breaks = 40, plot = FALSE)
n_breaks <- length(hist_b$breaks)
par(mar = c(4, 4, 1, 1))
plot(
  hist_b$breaks, c(log(hist_b$density + 1), 0),
  type = 's', ylab = '', xlab = '', yaxt = 'n', col = 'darkgrey', bty = 'n', xaxt = 'n',
  xlim = c(0, 1), ylim = c(0, 3.5)
)
lines(hist_b$breaks[1:(n_breaks - 1)], log(hist_b$density + 1), type = 'h', col = 'darkgrey')
lines(hist_b$breaks[c(1, n_breaks)], c(0, 0), type = 'l', col = 'darkgrey')
axis(1, seq(0, 1, 0.2))
axis(2, c(0, log(2), log(6), log(21)), c(0, 1, 5, 20))
mtext('tortuosity calf', side = 1, line = 2.2)
mtext('density', side = 2, line = 2.1)
lines(z, log(sdd1 + 1), lwd = 3, col = color[1])
lines(z, log(sdd2 + 1), lwd = 3, col = color[2])
lines(z, log(sdd3 + 1), lwd = 3, col = color[3])


# Viterbi sequence

# Viterbi decoding - decoded states take values from 1 to 9 according to the different state combinations that can be realised
s_chmm <- viterbi_chmm(data, mod_chmm, NpV)

# individual states for mother - states take values from 1 to 3
s_chmm_mother <- ind_states[s_chmm, 1]

# individual states for calf - states take values from 1 to 3
s_chmm_calf <- ind_states[s_chmm, 2]

# percentage of Viterbi states that indicate different behavioural modes
round(sum(s_chmm_mother != s_chmm_calf) / nrow(data), 2) * 100

# plot emphasising observations associated with Viterbi states indicating different behavioural modes
# all observations are drawn in gray; those where the decoded states of mother and calf differ
# are overdrawn in the colour of the respective individual's decoded state
ind_diff <- which(s_chmm_mother != s_chmm_calf)
color <- c(
  rgb(176, 48, 96, maxColorValue = 255),
  rgb(0, 139, 139, maxColorValue = 255),
  'orange'
)

# mother
plot(
  data[, 1],
  type = 'h', bty = 'n', main = '', xaxt = 'n', ylim = c(0, 1), col = 'gray',
  cex.axis = 0.8, xlab = '', ylab = ''
)
lines(ind_diff, data[ind_diff, 1], type = 'h', col = color[s_chmm_mother[ind_diff]])
# ticks every 2 hours, placed 360 observations apart per hour
indx <- (0:9) * 2
axis(1, indx * 360, indx, cex.axis = 0.8)
mtext('hour', side = 1, line = 2, cex = 0.8)
mtext('tortuosity mother', side = 2, line = 2, cex = 0.8)

# calf
plot(
  data[, 2],
  type = 'h', bty = 'n', main = '', xaxt = 'n', ylim = c(0, 1), col = 'gray',
  cex.axis = 0.8, xlab = '', ylab = ''
)
lines(ind_diff, data[ind_diff, 2], type = 'h', col = color[s_chmm_calf[ind_diff]])
indx <- (0:9) * 2
axis(1, indx * 360, indx, cex.axis = 0.8)
mtext('hour', side = 1, line = 2, cex = 0.8)
mtext('tortuosity calf', side = 2, line = 2, cex = 0.8)




