## Functions for fitting the quantile foliation model
## Part of this code is reproduced here, kindly provided by I. Currie

#     Howie transform: apply the matrix L. along the first index of W.
#     W. is unfolded into a matrix whose rows are its first index,
#     premultiplied by L., and folded back, so the first dimension of the
#     result has nrow(L.) levels and the remaining dimensions are unchanged.
#     Arguments
#     W. = an array
#     L. = matrix with ncol = first dimension of W.
#     Value: array with dim c(nrow(L.), dim(W.)[-1])
H <- function(L., W.){
  dims <- dim(W.)
  unfolded <- matrix(W., nrow = dims[1])
  product <- L. %*% unfolded
  array(product, dim = c(nrow(product), dims[-1]))
}

#     Rotate an array: the first (fastest varying) index becomes the last
#     (slowest varying) one, and every other index moves down one place.
#     Argument: W. = an array
#     Value: aperm(W., c(2, 3, ..., k, 1)) where k = length(dim(W.))
#
Rotate <- function(W.){
  k <- length(dim(W.))
  perm <- c(seq_len(k)[-1], 1L)
  aperm(W., perm)
}

#     Rotated Howie transform: apply L. along the first index of W. (via H)
#     and then rotate, so the transformed index becomes the last one and
#     the next index is ready to be transformed (as chained in Turbo.Bz).
RH <- function(L., W.) {
  transformed <- H(L., W.)
  Rotate(transformed)
}

#     Row tensor of a matrix: for X. with c columns, returns a matrix with
#     c^2 columns, column (j - 1) * c + k being X.[, j] * X.[, k].
#     Argument: X. = a matrix
Rten <- function(X.){
  n.col <- ncol(X.)
  ones.row <- matrix(1, nrow = 1, ncol = n.col)
  # each column of X. repeated n.col times in a row: x1 x1 .. x2 x2 ..
  repeated.each <- kronecker(X., ones.row)
  # the whole of X. tiled n.col times: x1 x2 .. x1 x2 ..
  repeated.block <- kronecker(ones.row, X.)
  repeated.each * repeated.block
}

#     Compute B-spline regression matrix with equally spaced knots.
#     Arguments
#     X. = x-domain vector (spline.des rejects values outside [XL., XR.])
#     XL. = left of x-domain, XR. = right of x-domain
#     NDX. = number of intervals in domain
#     BDEG. = degree of B-spline (quadratic = 2, etc)
#     Value: length(X.) x (NDX. + BDEG.) basis matrix
bspline <- function(X., XL., XR., NDX., BDEG.){
  dx <- (XR. - XL.) / NDX.
  # BDEG. extra knots beyond each end so the basis is complete on [XL., XR.]
  knots <- seq(XL. - BDEG. * dx, XR. + BDEG. * dx, by = dx)
  # spline.des lives in the 'splines' package, which is not attached by
  # default; call it through its namespace so no library() call is needed.
  splines::spline.des(knots, X., BDEG. + 1, 0 * X.)$design
}

#     Computes the B-spline basis B and the difference penalty D'D
#     Arguments: as bspline( ).  Also PORD. = order of the penalty
#     (PORD. = 0 leaves D as the identity, i.e. a ridge penalty)
#     Value: list(B = basis matrix, DtD = ncol(B) x ncol(B) penalty matrix)
setup.P <- function(X., XL., XR., BDEG., PORD., NDX.){
  B <- bspline(X., XL., XR., NDX., BDEG.)
  D <- diag(ncol(B))
  # seq_len() so PORD. = 0 applies no differencing; 1:0 would loop twice
  # and silently build a second-order penalty instead.
  for (k in seq_len(PORD.)) D <- diff(D)
  list(B = B, DtD = crossprod(D))
}

#     Computation of (A3xA2xA1)'W(A3xA2xA1) using rotated Howie transform
#      The transposed row tensors are applied to W. one dimension at a
#      time, giving an array with dims (C1.^2, C2.^2, C3.^2). Each squared
#      index is then split into its (row, column) pair and the indices are
#      permuted into a (C1.*C2.*C3.) x (C1.*C2.*C3.) matrix.
#      Arguments:
#      W.  = weight array
#      C1. = the number of columns in a marginal regression matrix
#            (and C2. and C3.)
#      B1. = the row tensor of a marginal regression matrix
#            (and B2. and B3.)
Turbo.BWB <- function(W., C1., C2., C3., B1., B2., B3.){
  tensor <- RH(t(B1.), W.)
  tensor <- RH(t(B2.), tensor)
  tensor <- RH(t(B3.), tensor)
  # split every squared index back into a (row, column) pair
  dim(tensor) <- c(C1., C1., C2., C2., C3., C3.)
  # all "row" indices first, then all "column" indices
  n.coef <- C1. * C2. * C3.
  matrix(aperm(tensor, c(1, 3, 5, 2, 4, 6)), nrow = n.coef)
}

#     Computation of (A3xA2xA1) z using rotated Howie transform
#      A1., A2. and A3. are applied in turn, each along the current first
#      index of the array; for a three-dimensional Z. the three rotations
#      bring the dimensions back to their original order.
#      Arguments:
#      Z.  = an array
#      A1. = a marginal regression matrix (and A2. and A3.)
Turbo.Bz <- function(Z., A1., A2., A3.) {
  Reduce(function(acc, A) RH(A, acc), list(A1., A2., A3.), Z.)
}

## Computation of Foliation array
#      Fits the array by iteratively reweighted penalized least squares:
#      each iteration computes residuals r = Y - Mu and asymmetric weights
#      x.p / sqrt(b + r^2) where r > 0, (1 - x.p) / sqrt(b + r^2) otherwise,
#      then solves the penalized normal equations for new coefficients.
#      Arguments:
#      Y           = three-dimensional array of data; dim(Y)[3] must equal
#                    length(x.p)
#      B.w, B.m, B.p = marginal basis matrices for dims 1, 2, 3 of Y
#      P.w, P.m, P.p = penalty matrices on the vectorised coefficients
#      x.p         = vector of quantiles, one per slice Y[, , k]
#      lambda1...  = penalty weights for P.w, P.m and P.p
#      cri         = stopping criterion: stop when max |change in Mu| <= cri
#      mon         = monitor steps (print progress each iteration)
#      max.iter    = maximum number of iterations (default: no limit); a
#                    warning is issued if it is reached before convergence
#      Value: list(Mu = fitted array, Bic, Dev,
#                  Tr = trace of (B'WB + P)^-1 B'WB,
#                  l1, l2, l3 = penalty weights, a = coefficient array)

EstFoliation <- function(Y, B.w, B.m, B.p, P.w, P.m, P.p, x.p,
                         lambda1, lambda2, lambda3, cri = 1e-2, mon = TRUE,
                         max.iter = Inf){
  c.w <- ncol(B.w)
  c.m <- ncol(B.m)
  c.p <- ncol(B.p)
  # Row tensors do not change between iterations; build them once.
  Rten.B.w <- Rten(B.w)
  Rten.B.m <- Rten(B.m)
  Rten.B.p <- Rten(B.p)

  A.init <- matrix(rep(.000002, c.w * c.m * c.p), ncol = 1)
  P <- lambda1 * P.w + lambda2 * P.m + lambda3 * P.p
  b <- .001  # constant in the weight denominator sqrt(b + r^2)
  # Quantile level for every cell: OP[i, j, k] == x.p[k]
  OP <- outer(matrix(rep(1, ncol(Y) * nrow(Y)), ncol = ncol(Y)), x.p)
  a <- array(A.init, c(c.w, c.m, c.p))
  Mu <- Turbo.Bz(a, B.w, B.m, B.p)
  iter <- 0
  # repeat (instead of while (dz > cri) with dz primed to 1) guarantees at
  # least one update, so Bt.W.B is defined below even when cri >= 1.
  repeat {
    iter <- iter + 1
    R <- Y - Mu
    Wt <- ifelse(R > 0, OP, 1 - OP) / sqrt(b + R^2)
    Bt.W.B <- Turbo.BWB(Wt, c.w, c.m, c.p, Rten.B.w, Rten.B.m, Rten.B.p)
    Rhs1 <- matrix(Turbo.Bz(Y * Wt, t(B.w), t(B.m), t(B.p)), ncol = 1)
    # NOTE(review): Bt.W.B %*% A.init adds an offset built from the fixed
    # constant A.init (2e-6 everywhere), not from the current a; confirm intent.
    Rhs <- Bt.W.B %*% A.init + Rhs1
    a <- array(solve(Bt.W.B + P, Rhs), c(c.w, c.m, c.p))
    old <- Mu
    Mu <- Turbo.Bz(a, B.w, B.m, B.p)
    dz <- max(abs(old - Mu))
    if (!is.finite(dz)) {
      stop("EstFoliation: non-finite change in fitted values at iteration ",
           iter, call. = FALSE)
    }
    if (mon) {
      cat("Iteration", iter,
          ", Maximum abs difference", round(dz, 4), "stopping when <", cri, "\r")
    }
    if (dz <= cri) break
    if (iter >= max.iter) {
      warning("EstFoliation: no convergence after ", iter, " iterations",
              call. = FALSE)
      break
    }
  }
  # Finish the "\r"-updated progress line so later output does not overwrite it.
  if (mon) cat("\n")
  Tr <- sum(diag(solve(Bt.W.B + P, Bt.W.B)))
  y.init <- Y
  y.init[Y == 0] <- 10^(-4)  # keeps 0 * log(...) finite where Y == 0
  Dev <- 2 * sum(Y * log(y.init / Mu))
  # NOTE(review): log(sum(dim(Y))) adds the array dimensions; BIC usually
  # uses log of the number of observations (length(Y)). Confirm intent.
  Bic <- Dev + log(sum(dim(Y))) * Tr
  list(Mu = Mu, Bic = Bic, Dev = Dev, Tr = Tr,
       l1 = lambda1, l2 = lambda2, l3 = lambda3, a = a)
}

# Load the R array data on male performance. Presumably this creates the
# three-dimensional array Y used below; its row and column names are
# assumed to hold numeric weight and age values.
load("MaleDataGrid.Rdata")

# Each margin gets a cubic B-spline basis (BDEG. = 3) on 8 intervals with a
# second-order difference penalty (PORD. = 2). The domain is widened by
# 1e-8 on each side so the extreme data points lie inside it.

# create B basis and difference matrix for weight
x.weight <- as.numeric(rownames(Y))
Calc.P <- setup.P(x.weight, min(x.weight) - 1e-8, max(x.weight) + 1e-8, 3, 2, 8)
B.weight <- Calc.P$B
DtD.weight <- Calc.P$DtD

# create B basis and difference matrix for age
x.age <- as.numeric(colnames(Y))
Calc.P <- setup.P(x.age, min(x.age) - 1e-8, max(x.age) + 1e-8, 3, 2, 8)
B.age <- Calc.P$B
DtD.age <- Calc.P$DtD

# create B basis and difference matrix for quantiles: one equally spaced
# level in [0, 1] per slice of the third dimension of Y
x.p <- seq(0, 1, length.out = dim(Y)[3])
Calc.P <- setup.P(x.p, min(x.p) - 1e-8, max(x.p) + 1e-8, 3, 2, 8)
B.p <- Calc.P$B
DtD.p <- Calc.P$DtD

#    Calculate penalty matrices on the vectorised coefficient array, which
#    EstFoliation reshapes as c(weight, age, quantile) (weight index fastest).
#    Each margin's D'D sits in its own Kronecker position, with identities
#    sized by the number of basis columns of the other two margins.
P.weight <- kronecker(diag(ncol(B.p)), kronecker(diag(ncol(B.age)), DtD.weight))
P.age <- kronecker(diag(ncol(B.p)), kronecker(DtD.age, diag(ncol(B.weight))))
P.p <- kronecker(DtD.p, kronecker(diag(ncol(B.age)), diag(ncol(B.weight))))

# Run foliation model: penalty weights 0.25 (weight), 0.01 (age) and
# 0.05 (quantile), stopping tolerance 1e-4, with progress printed.
rs <- EstFoliation(Y, B.weight, B.age, B.p, P.weight, P.age, P.p, x.p,
                   lambda1 = 0.25, lambda2 = 0.01, lambda3 = 0.05,
                   cri = 1e-4, mon = TRUE)

## Plot surfaces: fitted slices Mu[, , 40] and Mu[, , 1] as 3-D surfaces.
# NOTE(review): slice 40 is hard-coded and assumes dim(Y)[3] >= 40.
library(plotly)
Mu <- rs[["Mu"]]
a <- rs[["a"]]
axx <- list(title = "Body Weight", range = c(40, 120))
axy <- list(title = "Age", range = c(14, 90))
axz <- list(title = "Total", range = c(50, 420))

p <- plot_ly(x = rev(x.weight), y = rev(x.age), z = ~Y, showscale = FALSE)
p <- add_surface(p, z = ~Mu[, , 40], colors = "#af8dc3")
p <- add_surface(p, z = ~Mu[, , 1], opacity = 0.95)
p <- layout(p, scene = list(xaxis = axx, yaxis = axy, zaxis = axz))
p

