
#############################################################################
##########  TESTING THE MODEL FITTING AGAINST SIMULATED DATA ################
#############################################################################

library(deSolve)
library(bbmle)
library(plyr)
library(gplots)

##########################################################################
#################   SIMULATE THE LAGGED-DISPERSAL MODEL   ################
##########################################################################

#  lagged dispersal model AKA global dispersal AKA N=2
#
#  Derivative function with the (t, y, params) signature used by the deSolve
#  solvers (it is passed as `func` to lsoda() elsewhere in this script).
#    t      : current time; not used in any of the equations.
#    y      : state vector, read BY POSITION:
#               y[1], y[2] = V1, V2     y[3], y[4] = W1, W2
#               y[5], y[6] = S1, S2     y[7], y[8] = I1, I2
#             (suffix 1/2 = patch; V/W are presumably the two aphid classes and
#              S/I the susceptible/infected hosts - the script later uses
#              V2+W2 as aphid count and I2/(S2+I2) as prevalence.)
#    params : named vector or list providing r, K, d, beta_vh and beta_hv.
#  Returns a list holding one numeric vector of derivatives, ordered like y.
modN2 <- function(t, y, params) {
  V1 <- y[1]; V2 <- y[2]
  W1 <- y[3]; W2 <- y[4]
  S1 <- y[5]; S2 <- y[6]
  I1 <- y[7]; I2 <- y[8]

  with(as.list(params), {
    # fraction of all hosts (both patches) that sit in each patch
    hosts_total <- S1 + I1 + S2 + I2
    share1 <- (S1 + I1) / hosts_total
    share2 <- (S2 + I2) / hosts_total

    # dispersal inflow; the same term enters both patches' equations
    inflow_V <- d * (V1 * share1 + V2 * share2)
    inflow_W <- d * (W1 * share1 + W2 * share2)

    # V -> W transfer through contact with infected hosts
    to_W1 <- I1 * V1 * beta_hv
    to_W2 <- I2 * V2 * beta_hv

    # logistic growth of the whole patch population (V + W) feeds class V
    grow1 <- r * (V1 + W1) * (1 - ((V1 + W1) / K))
    grow2 <- r * (V2 + W2) * (1 - ((V2 + W2) / K))

    # S -> I transfer driven by class-W aphids in the same patch
    new_inf1 <- beta_vh * S1 * W1
    new_inf2 <- beta_vh * S2 * W2

    dV1 <- grow1 - (d * V1) + inflow_V - to_W1
    dV2 <- grow2 - (d * V2) + inflow_V - to_W2
    dW1 <- to_W1 - (d * W1) + inflow_W
    dW2 <- to_W2 - (d * W2) + inflow_W
    dS1 <- -new_inf1
    dS2 <- -new_inf2
    dI1 <- new_inf1
    dI2 <- new_inf2

    list(c(dV1, dV2, dW1, dW2, dS1, dS2, dI1, dI2))
  })
}

##################################
# simulate the deterministic model
# below, will use this model output to simulate data, so then we test the model
# fitting machinery by seeing whether it spits out these underlying params

# daily output times, the "true" parameter values, and the initial state
t <- seq(0, 80, by = 1)
par2 <- c(r = 0.2, K = 100, d = 0.01, beta_vh = 0.005, beta_hv = 0.68)
start2 <- c(V1 = 0, V2 = 0, W1 = 2.5, W2 = 0,
            S1 = 4, S2 = 96, I1 = 0, I2 = 0)

# integrate the ODEs; one row per time point, one column per state variable
sim2 <- data.frame(lsoda(y = start2, times = t, func = modN2, parms = par2))

# patch-2 summaries used further down to generate the synthetic observations
sim2[["prev"]] <- with(sim2, I2 / (S2 + I2))  # infected fraction of patch-2 hosts
sim2[["aph"]] <- with(sim2, V2 + W2)          # total aphids in patch 2


########################################################################
######################   LIKELIHOOD FUNCTIONS   ########################
########################################################################

#########################
### GLOBAL DISPERSAL ####
#########################

# Negative log-likelihood of the global-dispersal (N=2) model, written in the
# one-scalar-argument-per-parameter form that mle2() is given further down.
#
# Arguments
#   r_*, K_*, d_*, theta_* : one value per treatment label (D0R0, D0R1, D1R0,
#                            D1R1); beta_vh_* has two (R0, R1); beta_hv is shared.
#   * D0R0 / R0 are the baseline values and are used as given.
#   * Any other value that is not > 0 (e.g. held at 0 via `fixed=`) means "no
#     separate parameter": the next candidate in the chain is used instead
#     (D1R1 -> D1R0 -> D0R1 -> D0R0; beta_vh_R1 -> beta_vh_R0).
#   * K_* are supplied in thousands and multiplied by 1000 here.
#   * theta_* is the negative-binomial dispersion; dnbinom() gets size = 1/theta.
#   * beta_vh is forced to 0 in the two D0 parameter sets.
#
# Read from the enclosing/global environment (NOT arguments):
#   data   : data frame with columns ttmt, day, aph, ring, inf, I
#   treats : treatment codes; treats[i] is paired BY POSITION with the
#            parameter sets in the order D0R0, D1R0, D0R1, D1R1
#   modN2  : ODE derivative function handed to lsoda()
#
# Returns a single number: over all treatments,
#   -sum(log NB(aph | mu = V2 + W2)) - sum(log Bernoulli(I | I2 / (S2 + I2))).
ll.func.N2 <- function(r_D0R0, r_D0R1, r_D1R0, r_D1R1, K_D0R0, K_D0R1, K_D1R0, K_D1R1,
                       d_D0R0, d_D0R1, d_D1R0, d_D1R1, 
                       beta_vh_R0, beta_vh_R1, theta_D0R0, theta_D0R1, theta_D1R0, theta_D1R1,
                       beta_hv){
  # exp(log(x)) round trip, kept from the original code. NOTE: this does not
  # enforce positivity - it returns x for x > 0, 0 for x == 0 and NaN (with a
  # warning) for x < 0. Positivity is really provided by the optimiser bounds.
  r_D0R0 <- exp(log(r_D0R0)); r_D1R0 <- exp(log(r_D1R0))
  r_D0R1 <- exp(log(r_D0R1)); r_D1R1 <- exp(log(r_D1R1))
  K_D0R0 <- exp(log(K_D0R0)); K_D1R0 <- exp(log(K_D1R0))
  K_D0R1 <- exp(log(K_D0R1)); K_D1R1 <- exp(log(K_D1R1))
  d_D0R0 <- exp(log(d_D0R0)); d_D1R0 <- exp(log(d_D1R0))
  d_D0R1 <- exp(log(d_D0R1)); d_D1R1 <- exp(log(d_D1R1))
  beta_vh_R0 <- exp(log(beta_vh_R0)); beta_vh_R1 <- exp(log(beta_vh_R1))
  theta_D0R0 <- exp(log(theta_D0R0)); theta_D1R0 <- exp(log(theta_D1R0))
  theta_D0R1 <- exp(log(theta_D0R1)); theta_D1R1 <- exp(log(theta_D1R1))

  # rescale K's:
  K_D0R0 <- 1000 * K_D0R0; K_D0R1 <- 1000 * K_D0R1
  K_D1R0 <- 1000 * K_D1R0; K_D1R1 <- 1000 * K_D1R1

  # Fallback rule used for backwards model selection: return the first
  # candidate that is > 0; the last candidate is returned unchecked. An NA/NaN
  # met before a positive value gives NA, exactly as nested ifelse() would.
  first_positive <- function(...) {
    cands <- c(...)
    for (k in seq_len(length(cands) - 1L)) {
      if (is.na(cands[k])) return(NA_real_)
      if (cands[k] > 0) return(cands[k])
    }
    cands[length(cands)]
  }

  # which parameters are used for each simulation
  par_D0R0 <- c(r = r_D0R0, K = K_D0R0, d = d_D0R0,
                beta_vh = 0, theta = theta_D0R0, beta_hv = beta_hv)
  par_D1R0 <- c(r = first_positive(r_D1R0, r_D0R0),
                K = first_positive(K_D1R0, K_D0R0),
                d = first_positive(d_D1R0, d_D0R0),
                beta_hv = beta_hv,
                beta_vh = beta_vh_R0,
                theta = first_positive(theta_D1R0, theta_D0R0))
  par_D0R1 <- c(r = first_positive(r_D0R1, r_D0R0),
                K = first_positive(K_D0R1, K_D0R0),
                d = first_positive(d_D0R1, d_D0R0),
                beta_hv = beta_hv,
                beta_vh = 0,
                theta = first_positive(theta_D0R1, theta_D0R0))
  par_D1R1 <- c(r = first_positive(r_D1R1, r_D1R0, r_D0R1, r_D0R0),
                K = first_positive(K_D1R1, K_D1R0, K_D0R1, K_D0R0),
                d = first_positive(d_D1R1, d_D1R0, d_D0R1, d_D0R0),
                beta_vh = first_positive(beta_vh_R1, beta_vh_R0),
                beta_hv = beta_hv,
                theta = first_positive(theta_D1R1, theta_D1R0, theta_D0R1, theta_D0R0))
  par_lists <- list(par_D0R0, par_D1R0, par_D0R1, par_D1R1) # order must match treatments

  # initial states: identical except that the 2.5 aphids placed in patch 1
  # start in class V (D0) or class W (D1)
  start_D0 <- c(V1 = 2.5, V2 = 0, W1 = 0, W2 = 0, S1 = 4, S2 = 96, I1 = 0, I2 = 0)
  start_D1 <- c(V1 = 0, V2 = 0, W1 = 2.5, W2 = 0, S1 = 4, S2 = 96, I1 = 0, I2 = 0)
  # Daily output for days 0..60. Observations whose `day` is not one of these
  # values find no simulated row in merge() below and are silently dropped.
  sim_times <- seq(from = 0, to = 60, by = 1)

  # one negative log-likelihood per treatment (preallocated)
  treat_nll <- numeric(length(treats))
  for (i in seq_along(treats)) {
    tdata <- data[data$ttmt == treats[i], ]
    tdata <- tdata[!is.na(tdata$aph), ] # omit cases with unknown aphid #

    # treatments flagged inf > 0 start from the D1 initial state
    y0 <- if (max(tdata$inf) > 0) start_D1 else start_D0
    sim <- data.frame(lsoda(y = y0, times = sim_times, func = modN2,
                            parms = as.list(par_lists[[i]])))

    # pair every observation with the simulated state on the same day
    obs <- data.frame(time = tdata$day, aph = tdata$aph,
                      ring = tdata$ring, inf = tdata$I)
    res.df <- merge(obs, sim, by = "time")

    # aphid counts ~ negative binomial around the simulated patch-2 total
    ll.aph <- dnbinom(res.df$aph, mu = res.df$V2 + res.df$W2,
                      size = 1 / par_lists[[i]][["theta"]], log = TRUE)

    # infection status ~ Bernoulli(simulated patch-2 prevalence). Clamp to
    # [0, 1]: solver error can push the ratio marginally outside that range,
    # and dbinom() returns NaN for such a prob (which the -Inf patch below
    # would not catch).
    prev <- pmin(pmax(res.df$I2 / (res.df$S2 + res.df$I2), 0), 1)
    ll.inf <- dbinom(res.df$inf, size = 1, prob = prev, log = TRUE)
    ll.inf <- ifelse(ll.inf == -Inf, -100, ll.inf) # replace -inf (from rounding error) with big negative

    treat_nll[i] <- -sum(ll.aph) + -sum(ll.inf)
  }
  sum(treat_nll)
}


################################################################
#################### SIMULATE DATA & FIT #######################
################################################################

iterations <- 100 # seems sufficient

# one row per likelihood-function argument (same order as the arguments of
# ll.func.N2); one column is appended per simulated data set
sim.fits <- data.frame(parm=c("r_D0R0", "r_D0R1", "r_D1R0", "r_D1R1", 
                                "K_D0R0", "K_D0R1", "K_D1R0", "K_D1R1", 
                                "d_D0R0", "d_D0R1", "d_D1R0", "d_D1R1",
                                "beta_vh_R0", "beta_vh_R1", 
                                "theta_D0R0", "theta_D0R1", "theta_D1R0", "theta_D1R1",
                                "beta_hv"))
# character copy for name-based lookup (safe even if `parm` was made a factor)
parm_names <- as.character(sim.fits$parm)

# --- everything below up to the loop is identical for every iteration ---

# sampling design: days 0, 7, ..., 56 with 80 rows per day; ring, block and
# ttmt are shorter vectors that data.frame() recycles to the full length
simdat <- data.frame(day=sort(rep(seq(0,8)*7, 80)),
                     ring=c(0,1,2,3), 
                     block=sort(rep(seq(1,5),16)),
                     ttmt=sort(rep(c(2,3,5,6),4)))

# deterministic predictions (from sim2) on the sampled days
predictions <- data.frame(day=unique(simdat$day))
on_sampled_day <- sim2$time %in% predictions$day
predictions$predI <- sim2$prev[on_sampled_day]
predictions$predA <- sim2$aph[on_sampled_day]

design <- merge(simdat, predictions, by.x="day", all.x=TRUE, all.y=TRUE)
# inf is 0 for treatments 2 and 5 and 1 for treatments 3 and 6; only the
# latter keep a non-zero infection probability
design$inf <- ifelse(design$ttmt==2 | design$ttmt==5, 0, 1) 
design$predI <- ifelse(design$inf==1, design$predI, 0) # pred
treats <- unique(design$ttmt) # so fitting function loop through each treatment

# starting values for the five free parameters (K is in thousands)
start <- list(r_D0R0=0.2, d_D0R0=0.01, K_D0R0=0.1, beta_vh_R0=0.005, theta_D0R0=0.3)

for(j in seq_len(iterations)){
  print(j)
  
  # 1) SIMULATE DATA
  # `data` and `treats` are globals read by ll.func.N2
  data <- design
  data$I <- rbinom(n=nrow(data), size=1, prob=data$predI)
  data$aph <- rnbinom(n=nrow(data), mu=data$predA, size=1/0.3)
  
  # 2 FIT MODEL TO SIMULATED DATA
  
  fit.sim <- mle2(ll.func.N2, method="L-BFGS-B", skip.hessian=FALSE, 
                  start=start,  
                  fixed=c(r_D0R1=0, r_D1R0=0, r_D1R1=0, 
                          K_D1R0=0, K_D1R1=0, K_D0R1=0,
                          beta_vh_R1=0, 
                          theta_D1R0=0, theta_D1R1=0, theta_D0R1=0, 
                          d_D1R0=0, d_D0R1=0, d_D1R1=0,
                          beta_hv=0.68),
                  upper=as.list(unlist(start)*c(5,5,5,5,5)), # 5x the start value for every parameter
                  lower=as.list(unlist(start)/c(5,100,5,5,5)), # start/5, except d (start/100): more flexibility for d
                  control=list(parscale=unlist(start), maxit=500))
  
  # store estimates matched BY NAME, so a row can never end up holding another
  # parameter's value if coef() ever returns a different ordering
  est <- coef(fit.sim)
  stopifnot(all(parm_names %in% names(est)))
  sim.fits[,j+1] <- est[parm_names]
}

# reshape to one row per iteration, one column per parameter
simfit.summary <- as.data.frame(matrix(t(sim.fits[,-1]), nrow=iterations))
colnames(simfit.summary) <- sim.fits$parm
rownames(simfit.summary) <- seq_len(iterations)
simfit.summary

# rescale K's:
simfit.summary$K_D0R0 <- 1000*simfit.summary$K_D0R0 


################################################################
###################### PLOT RESULTS ############################
################################################################


par(mfrow=c(1,5),mar=c(0.5,3,0,0), oma=c(1,0,0.5,0.5))

# Draw one panel: the mean of the fitted values (filled dot), a vertical bar
# spanning their 2.5% - 97.5% quantiles, and a dashed line at the value that
# was used to simulate the data.
#   est   : fitted values of one parameter, one per simulated data set
#   truth : value used in the simulation
#   ylim  : y-axis limits of the panel
#   ticks : positions of the y-axis tick marks
plot_recovery <- function(est, truth, ylim, ticks) {
  centre <- mean(est)
  bounds <- quantile(est, seq(0.025, 0.975, 0.95))
  # empty frame first, so the reference line ends up underneath the point
  plot(centre, ylim = ylim, cex = 0, xaxt = "n", yaxt = "n", ann = FALSE)
  abline(h = truth, lty = 2, col = "dark gray")
  points(centre, pch = 21, bg = "black", cex = 2)
  plotCI(centre, li = bounds[1], ui = bounds[2],
         add = TRUE, err = "y", sfrac = 0, gap = 0, cex = 0)
  axis(side = 2, at = ticks, cex.axis = 0.9)
  invisible(NULL)
}

# r
plot_recovery(simfit.summary$r_D0R0, truth = 0.2,
              ylim = c(0.17, 0.23), ticks = c(0.18, 0.2, 0.22))

# K
plot_recovery(simfit.summary$K_D0R0, truth = 100,
              ylim = c(85, 115), ticks = c(90, 100, 110))

# d
plot_recovery(simfit.summary$d_D0R0, truth = 0.01,
              ylim = c(0.003, 0.017), ticks = c(0.005, 0.01, 0.015))

# beta
plot_recovery(simfit.summary$beta_vh_R0, truth = 0.005,
              ylim = c(0.002, 0.008), ticks = c(0.003, 0.005, 0.007))

# theta
plot_recovery(simfit.summary$theta_D0R0, truth = 0.3,
              ylim = c(.22, .38), ticks = c(0.25, 0.3, 0.35))

