
# http://www.petrkeil.com/?p=1050

# --------------------------------------------------------------------------------------------------------------------------------
# functions

# Full outer join of every data frame in `dflist` on the key column(s) named
# in `by.var`, folding the list from left to right. Rows that have no match
# in one of the tables are kept and filled with NA (merge(..., all = TRUE)).
merge_all <- function(dflist, by.var) {
    outer_join <- function(...) {
        merge(..., by = by.var, all = TRUE)
    }
    Reduce(outer_join, dflist)
}

# Negated %in%: TRUE for every element of the left operand that does not
# occur in the right operand.
`%!in%` <- function(...) {
    !`%in%`(...)
}


# -----------------------------------------------------------------------------------------------------
# options

# Keep character columns as character vectors in data.frame()/read.csv()
# instead of converting them to factors. This is already the default from
# R 4.0 onwards; the call only matters on older R versions.
options(stringsAsFactors = FALSE)


# -----------------------------------------------------------------------------------------------------
# packages

# Packages attached by this script (modelling, spatial statistics,
# cross-validation and plotting helpers).
packages <- c(
    "texreg", "lme4", "nlme", "plyr", "MASS", "geosphere",
    "mgcv", "ncf", "spdep", "arm", "gstat", "ape", "automap",
    "MuMIn", "cvTools", "caret", "e1071", "xgboost", "subselect",
    "ipred", "parallel", "doParallel", "corrplot"
)
# Attach each package by name. require() returns TRUE/FALSE rather than
# failing, so the printed named logical vector shows which packages loaded.
vapply(packages, require, logical(1), character.only = TRUE)


# --------------------------------------------------------------------------------------------------------------------------------
# load data

# Placeholder paths: `project_dir` becomes the working directory, so every
# relative path used later ("data/...", "figures/...") resolves against it;
# `data_dir` is where the raw county table is read from.
project_dir <- "path-to-output-directory"
data_dir <- "path-to-input-directory"
setwd(project_dir)

# Restores previously saved objects into the global environment.
# NOTE(review): presumably this is where `combined_PP`, `logged_variables`
# and `responseColumns` (all used below without being defined in this
# script) come from - confirm against the script that wrote the file.
load("data/data_PP.Rdata")

# county data
countyDat <- read.csv(file.path(data_dir, "County Data/Master County Dataset.csv")) 
# lower-case copy of the key so it matches the `fips` column used for merging
countyDat$fips <- countyDat$FIPS

# merge in human variables
# (left join: every row of combined_PP is kept; counties without a match in
# countyDat get NA for pop2010 and popdensity)
human_variables <- c("fips", "pop2010", "popdensity")
combined_PP <- merge(combined_PP, countyDat[, human_variables], by = "fips", all.x = TRUE)

# exponentiate logged variables
# (the commented-out lines are an earlier variant that left some variables
# on the log scale; the active line back-transforms all of them)
#stay_logged <- c("species.i", "genus.i", "family.i", "PD.INTRO")
#exp_variables <- logged_variables[logged_variables %!in% stay_logged]                  
# combined_PP[, exp_variables] <- lapply(combined_PP[, exp_variables], exp)                 
combined_PP[, logged_variables] <- lapply(combined_PP[, logged_variables], exp)  

# number of counties in each state
# (result is only printed, not stored)
ddply(combined_PP, .(state), summarise, count = length(fips))


# -----------------------------------------------------------------------------------------------------
# standardize some predictor variables

# Z-score helper: centre on the mean and scale by the standard deviation.
# Missing values are ignored when computing the mean and sd and stay NA in
# the result, so a single NA (e.g. pop2010 for a county without a match in
# the county table) no longer turns the whole standardized column into NA.
# For columns without missing values the result is unchanged.
z_standardize <- function(x) {
    (x - mean(x, na.rm = TRUE)) / sd(x, na.rm = TRUE)
}

# Source columns that each get a z-scored companion named "<column>_std".
vars_to_standardize <- c(
    "gmted2010.elev_mean.mn",
    "hwsd.t_clay.mn", "hwsd.t_ece.mn", "hwsd.t_gravel.mn",
    "hwsd.t_oc.mn", "hwsd.t_ph.mn", "hwsd.t_sand.mn",
    "bio01.mn", "bio07.mn", "bio08.mn", "bio12.mn", "bio15.mn",
    "pop2010", "popdensity", "area"
)

combined_PP[paste0(vars_to_standardize, "_std")] <-
    lapply(combined_PP[vars_to_standardize], z_standardize)

# Two-level factor: "no" where `glaciation` equals 1 and "yes" for every
# other value. %in% never returns NA, so missing values are coded "yes".
combined_PP$glaciation_binary <-
    factor(ifelse(combined_PP$glaciation %in% 1, "no", "yes"))


#######################################################################################################
# create autocovariate term

# EPSG:4326 = lat/long, WGS 84
# EPSG:2163 = onshore and offshore equal area projection, where coordinates are in meters
# ESRI:102008 = Albers continental US equal area projection, where coordinates are in meters
# https://source.opennews.org/articles/choosing-right-map-projection/
# https://groups.google.com/forum/#!topic/unmarked/KYrrgnxVzAg

# proj4 projection strings
# Geographic (longitude/latitude, WGS 84) coordinate system of the raw data.
proj4_lat_lon <- "+proj=longlat +ellps=WGS84 +datum=WGS84 +no_defs"
# Albers equal-area conic projection (units = metres); this is the one the
# coordinates are transformed to below.
proj4_US_contig_Albers_eqArea_conic <- "+proj=aea +lat_1=29.5 +lat_2=45.5 +lat_0=37.5 +lon_0=-96 +x_0=0 +y_0=0 +ellps=GRS80 +datum=NAD83 +units=m +no_defs" 
# Lambert azimuthal equal-area alternative (units = metres); defined but not
# used anywhere in the code visible in this script.
proj4_US_Nat_Altas_eqArea <- "+proj=laea +lat_0=45 +lon_0=-100 +x_0=0 +y_0=0 +a=6370997 +b=6370997 +units=m +no_defs" 

# create SpatialPointsDataFrame
# (coordinates<- promotes the copy of combined_PP to a spatial object using
# its "longitude" and "latitude" columns as point coordinates)
coords_lat_lon <- combined_PP
coordinates(coords_lat_lon) <- c("longitude", "latitude")
proj4string(coords_lat_lon) <- CRS(proj4_lat_lon) 
# project coordinates
coords_eqArea <- spTransform(coords_lat_lon, CRS(proj4_US_contig_Albers_eqArea_conic)) 
# create a matrix of coordinates
# (the columns keep the names "longitude"/"latitude" but now hold the
# projected values; with +units=m in the target projection these are metres)
coords_mat_eqArea <- as.matrix(as.data.frame(coords_eqArea)[, c("longitude", "latitude")])
# store in data frame
# (disabled alternative: keep the projected coordinates as extra columns of
# combined_PP and build the matrix from those)
#coords_eqArea_dat <- as.data.frame(coords_eqArea)
#combined_PP[, c("longitude_eqArea", "latitude_eqArea")] <- coords_eqArea_dat[, c("longitude", "latitude")]
# create a matrix of coordinates
#coords_mat_eqArea <- as.matrix(combined_PP[, c("longitude_eqArea", "latitude_eqArea")])


# -----------------------------------------------------------------------------------------------------
# determine range of autocorrelation

# range <- autofitVariogram(formula(paste(i, "~ 1")), input_data = coords_eqArea)$var_model[2, "range"]
# range <- 200000
# create neighbour list
# nb_list <- dnearneigh(coords_mat_eqArea, 0, range)
# create weights list
# nb_weights <- nb2listw(nb_list, zero.policy = TRUE)

# Fixed-effect predictors of the non-spatial models (standardized climate,
# elevation and soil variables plus glaciation, area and coast indicator).
explain_predictors <- c("bio01.mn_std", "bio07.mn_std", "bio08.mn_std", "bio12.mn_std", "bio15.mn_std", "gmted2010.elev_mean.mn_std",
    "gmted2010.elev_mean.sd", "hwsd.t_clay.mn_std", "hwsd.t_ece.mn_std", "hwsd.t_gravel.mn_std", "hwsd.t_oc.mn_std", "hwsd.t_ph.mn_std", 
    "hwsd.t_sand.mn_std", "glaciation_binary", "area_std", "enow.coast") # "longitude", "latitude", "long_x_lat"
    
# get autocovariate ranges
# (exploratory pass: one fitted variogram range per response, collected in a
# list named by response and printed in km after the loop)
ranges <- list()

# This is a top-level loop, so `i`, `fixed`, `random`, `nonSpatial_resid` and
# `coords_eqArea_NA` stay in the global environment afterwards, and
# `coords_eqArea` keeps the residual column of the last response. pad_NA()'s
# default `resid` argument refers to the global `nonSpatial_resid`.
for (i in responseColumns) {
	print(i)
	# create formula for non-spatial model
    fixed <- paste(i, "~", paste(explain_predictors, collapse = " + "))
    random <- "+ (1 | state)"
    # fit non-spatial models and store residuals
    # (lme4::lmer with a random intercept per state; rac.model() fits the
    # analogous model with nlme::lme. na.exclude keeps the residual vector
    # aligned with the rows of combined_PP, with NA for rows that were dropped)
    nonSpatial_resid <- as.vector(resid(lmer(formula(paste(fixed, random)), data = combined_PP, na.action = "na.exclude")))
    coords_eqArea@data$nonSpatial_resid <- nonSpatial_resid
    # exclude missing values for the residuals
    coords_eqArea_NA <- coords_eqArea[complete.cases(coords_eqArea@data$nonSpatial_resid), ]
	# determine range of autocorrelation
    # NOTE(review): row 2 of var_model is presumably the fitted (non-nugget)
    # variogram component - confirm against the automap documentation.
    ranges[[i]] <- autofitVariogram(formula(paste("nonSpatial_resid", "~ 1")), input_data = coords_eqArea_NA)$var_model[2, "range"]
}

# ranges sorted ascending, divided by 1000 (metres -> km for the projected
# coordinates) and rounded; printed only
round(sort(unlist(ranges))/1000)

# --------------------------------------------------------------------------------------------------------------------------------
# set up cross validation and perform pre-processing

# run algorithms using 10-fold cross validation
# control <- trainControl(method = "cv", number = 10, allowParallel = TRUE)
# run algorithms using repeated 10-fold cross validation
# (10 folds x 10 repeats). `control` and `metric` are globals that
# rac.model() passes to caret::train() as trControl and metric.
# NOTE(review): allowParallel = TRUE only has an effect if a parallel backend
# is registered; none is registered in the code visible here.
control <- trainControl(method = "repeatedcv", number = 10, repeats = 10, allowParallel = TRUE)
metric <- "RMSE"


# -----------------------------------------------------------------------------------------------------

# pad out predictions with NAs to length of original data (3066)
# Re-expand a vector that was computed on the complete cases only back to
# the length of the original data.
#
# x     : values, one per non-missing element of `resid`, in the same order.
# resid : full-length vector whose NA positions mark the rows that were
#         dropped. When not supplied, `nonSpatial_resid` is looked up from
#         the enclosing (global) environment.
#
# Returns a vector as long as `resid` that holds `x` at the non-missing
# positions of `resid` and NA everywhere else. Stops if `x` does not have
# exactly one value per non-missing element, instead of silently padding or
# truncating it.
pad_NA <- function(x, resid = nonSpatial_resid) {
    present <- which(!is.na(resid))
    if (length(x) != length(present)) {
        stop("pad_NA: length(x) (", length(x),
             ") must equal the number of non-missing values in `resid` (",
             length(present), ")", call. = FALSE)
    }
    # position i of the result takes x[k] when i is the k-th non-missing
    # position of `resid`; match() yields NA for the missing positions and
    # indexing with NA yields NA
    x[match(seq_along(resid), present)]
}


# Fit residual-autocovariate (RAC) models for a set of response variables.
#
# For every response the function
#   1. fits a non-spatial mixed model (fixed effects = `predictors`, random
#      intercept per state) with lme() and extracts its residuals;
#   2. fits a variogram to those residuals with autofitVariogram() and uses
#      its range, clamped to [70000, 250000] coordinate units (metres for the
#      default equal-area coordinates), as the neighbourhood distance;
#   3. computes two autocovariates of the residuals with autocov_dist()
#      (inverse and inverse-squared distance weighting) and z-scores them;
#   4. splits `dat` 80/20 into a training and a validation set (fixed seed);
#   5. trains two linear models with caret::train() that add `state` and one
#      of the two autocovariates to the fixed effects.
# Per response, the model with the higher cross-validated R-squared is kept.
#
# Arguments
#   response    : character vector of response column names in `dat`.
#   mat         : coordinate matrix, one row per row of `dat`.
#   coords_SPDF : SpatialPointsDataFrame with the same rows as `dat`.
#   dat         : data frame with the responses, predictors and `state`.
#   predictors  : character vector of fixed-effect predictor names.
# Also reads the globals `control` and `metric` (passed to train()) and
# calls pad_NA(), which is defined elsewhere in this script.
#
# Returns list(models, validation, dataset), each a list named by response:
# the selected train object, the 20% validation rows and the 80% training rows.
rac.model <- function(response = responseColumns, mat = coords_mat_eqArea,
                                    coords_SPDF = coords_eqArea, dat = combined_PP,
                                    predictors = explain_predictors) {
    models <- list()
    models_2 <- list()
    validation <- list()
    dataset <- list()

    # z-score helper for the autocovariates (NA rows are ignored and stay NA)
    z_score <- function(v) (v - mean(v, na.rm = TRUE)) / sd(v, na.rm = TRUE)

    for (i in response) {
        print(i)
        # create formula for non-spatial model
        fixed <- paste(i, "~", paste(predictors, collapse = " + "))
        # fit non-spatial model and store residuals; na.exclude keeps the
        # residual vector aligned with the rows of `dat` (NA for dropped rows)
        nonSpatial_resid <- as.vector(resid(lme(formula(fixed), random = ~ 1 | state,
                                                data = dat, na.action = "na.exclude")))
        coords_SPDF@data$nonSpatial_resid <- nonSpatial_resid

        # restrict residuals, points and coordinates to rows with a residual
        NA_index <- complete.cases(nonSpatial_resid)
        nonSpatial_resid_NA <- nonSpatial_resid[NA_index]
        coords_SPDF_NA <- coords_SPDF[NA_index, ]
        mat_NA <- mat[NA_index, ]

        # determine range of autocorrelation and clamp it to 70000-250000
        # (`ac_range` rather than `range`, which would shadow base::range)
        ac_range <- autofitVariogram(formula("nonSpatial_resid ~ 1"),
                                     input_data = coords_SPDF_NA)$var_model[2, "range"]
        ac_range <- min(max(as.numeric(ac_range), 70000), 250000)
        print(round(ac_range / 1000))

        # The neighbour list / weights list that used to be built here with
        # dnearneigh() and nb2listw() were never used: autocov_dist() is
        # given the coordinates and the distance `nbs` directly.

        # residual autocovariates, padded back to the full number of rows
        dat$rac <- pad_NA(autocov_dist(nonSpatial_resid_NA, mat_NA, nbs = ac_range,
                                       type = "inverse", zero.policy = TRUE),
                          resid = nonSpatial_resid)
        dat$rac_std <- z_score(dat$rac)
        dat$rac2 <- pad_NA(autocov_dist(nonSpatial_resid_NA, mat_NA, nbs = ac_range,
                                        type = "inverse.squared", zero.policy = TRUE),
                           resid = nonSpatial_resid)
        dat$rac2_std <- z_score(dat$rac2)

        # row indices of the 80% of the data used for training
        set.seed(7)
        validation_index <- createDataPartition(dat[, i], times = 1, p = 0.80,
                                                list = FALSE, groups = 5)
        # the remaining 20% is held out for validation
        validation[[i]] <- dat[-validation_index, ]
        # 80% used for training and cross-validation
        dataset[[i]] <- dat[validation_index, ]

        # linear models with state as fixed effect plus one autocovariate;
        # earlier attempts with cvTools::cvFit and lme/lmer were disabled
        # in favour of this formulation
        models[[i]] <- train(formula(paste(fixed, "+ state + rac_std")), method = "lm",
                             data = dataset[[i]], na.action = "na.exclude",
                             metric = metric, trControl = control)
        models_2[[i]] <- train(formula(paste(fixed, "+ state + rac2_std")), method = "lm",
                               data = dataset[[i]], na.action = "na.exclude",
                               metric = metric, trControl = control)
    }

    # Cross-validated R-squared of each fit, read from the `results` table of
    # the train object (method "lm" has a single row there).
    cv_rsquared <- function(fit) fit$results[["Rsquared"]]
    r2_inverse <- vapply(models, cv_rsquared, numeric(1))
    r2_inverse_squared <- vapply(models_2, cv_rsquared, numeric(1))

    # Keep the inverse-squared model wherever it has the higher R-squared.
    # The previous code subtracted the R-squared of `models` from itself, so
    # the difference was always 0 and `models_2` was never selected.
    # which() also drops NA comparisons instead of failing on them.
    use_squared <- which(r2_inverse_squared > r2_inverse)
    models[use_squared] <- models_2[use_squared]

    list(models = models, validation = validation, dataset = dataset)
}


# http://topepo.github.io/caret/using-your-own-model-in-train.html





#######################################################################################################
######################### MIXED MODELS WITH SEPARATE AUTOCOVARIATE TERM #######################
#######################################################################################################

# -----------------------------------------------------------------------------------------------------
# which environmental variables explain diversity?

# lon, lat, enow.coast and area should be included as geographic controls. Glaciation should also be included as environmental variable. 

# Fixed-effect predictors for the explanatory models.
explain_predictors <- c("bio01.mn_std", "bio07.mn_std", "bio08.mn_std", "bio12.mn_std", "bio15.mn_std", "gmted2010.elev_mean.mn_std",
    "gmted2010.elev_mean.sd", "hwsd.t_clay.mn_std", "hwsd.t_ece.mn_std", "hwsd.t_gravel.mn_std", "hwsd.t_oc.mn_std", "hwsd.t_ph.mn_std",
    "hwsd.t_sand.mn_std", "glaciation_binary", "area_std", "enow.coast")  # "longitude", "latitude", "long_x_lat"

# Taxonomic richness responses use the base predictor set ...
explain_models_taxon <- rac.model(response = c("species", "genus", "family"),
    predictors = explain_predictors)
# ... the phylogenetic diversity responses additionally get "totalOntree".
explain_models_all <- rac.model(response = c("PD.ALL",  "MPD.ALL", "PD.s.ALL", "MPD.s.ALL"),
    predictors = c(explain_predictors, "totalOntree"))

# Combine both runs into one list per output type, named by response.
explain_models_CV <- c(explain_models_taxon$models, explain_models_all$models)
explain_models_validation <- c(explain_models_taxon$validation, explain_models_all$validation)
explain_models_dataset <- c(explain_models_taxon$dataset, explain_models_all$dataset)

# Persist models, validation sets and training sets (relative to the
# working directory).
save(explain_models_CV, file = "data/explain_models_varImp_ALL_EXP_CV.Rdata", compress = "gzip")
# load("data/explain_models_varImp_ALL_EXP_CV.Rdata")
save(explain_models_validation, file = "data/explain_models_varImp_ALL_EXP_validation.Rdata", compress = "gzip")
save(explain_models_dataset, file = "data/explain_models_varImp_ALL_EXP_dataset.Rdata", compress = "gzip")

# Summaries of the fitted models. This used to read `explain_models`, which
# is not defined anywhere in this script; the list of trained models built
# above is `explain_models_CV`.
explain_models_summary <- lapply(explain_models_CV, summary)



# --------------------------------------------------------------------------------------------------------------------------------
# validation

# Score a trained model on held-out data.
#
# cv_trained_model : fitted model object with a predict() method that
#                    accepts `newdata`.
# validation_data  : data frame with the predictors and the observed response.
# response_var     : name of the observed response column in `validation_data`.
#
# Returns list(preds, rmse, R2): the raw predictions, the root mean squared
# error and the squared Pearson correlation between predictions and
# observations. RMSE and R2 are computed over the rows where both the
# prediction and the observation are non-missing, so a single NA no longer
# turns both metrics into NA.
validate <- function(cv_trained_model, validation_data, response_var) {
    preds <- predict(cv_trained_model, newdata = validation_data)
    observed <- validation_data[[response_var]]

    # predict() methods that drop incomplete rows return fewer predictions
    # than there are observations; subtracting would then recycle silently
    # and pair predictions with the wrong rows.
    if (length(preds) != length(observed)) {
        stop("validate: got ", length(preds), " predictions for ",
             length(observed), " validation rows; remove rows with missing ",
             "predictors from `validation_data` before calling validate()",
             call. = FALSE)
    }

    complete_pair <- !is.na(preds) & !is.na(observed)
    rmse <- sqrt(mean((preds[complete_pair] - observed[complete_pair])^2))
    R2 <- cor(preds[complete_pair], observed[complete_pair])^2
    list(preds = preds, rmse = rmse, R2 = R2)
}

# responses to validate, in reporting order
val_responses <- c("species", "genus", "family",
    "PD.ALL", "MPD.ALL", "PD.s.ALL", "MPD.s.ALL")

# score every selected model on the validation set of its own response
val_by_response <- lapply(
    setNames(val_responses, val_responses),
    function(resp) {
        validate(explain_models_CV[[resp]],
            validation_data = explain_models_validation[[resp]],
            response_var = resp)
    }
)

# individual results under their usual names (read by the plotting code)
val_species <- val_by_response[["species"]]
val_genus <- val_by_response[["genus"]]
val_family <- val_by_response[["family"]]
#
val_PD.ALL <- val_by_response[["PD.ALL"]]
val_MPD.ALL <- val_by_response[["MPD.ALL"]]
val_PD.s.ALL <- val_by_response[["PD.s.ALL"]]
val_MPD.s.ALL <- val_by_response[["MPD.s.ALL"]]

# same results as an unnamed list, in reporting order
val_estimates <- unname(val_by_response)

# one row per response with the validation R2 rounded to two decimals
val_df <- data.frame(
    response = val_responses,
    # rmse = sapply(val_estimates, function(x) x$rmse),
    R2 = vapply(val_estimates, function(est) round(est$R2, digits = 2), numeric(1))
    )

write.csv(val_df, file = "data/validation_LMM_R2.csv", row.names = FALSE)

# plot observed versus predicted
# Write a 7 x 4 inch PDF with one scatterplot per response: predictions on
# the x axis, observed validation values on the y axis. The panels are laid
# out on a 2 x 4 grid (seven panels used); axis labels are the default ones
# derived from the plotted expressions.
pdf("figures/validation_LMM.pdf", height = 4, width = 7)
# remember the previous graphical parameters so they can be restored below
op <- par(mfrow = c(2, 4), mar = c(4, 4, 1, 1) + 0.1)
  plot(val_species$preds, explain_models_validation$species$species)
  plot(val_genus$preds, explain_models_validation$genus$genus) 
  plot(val_family$preds, explain_models_validation$family$family)   
  #
  plot(val_PD.ALL$preds, explain_models_validation$PD.ALL$PD.ALL) 
  plot(val_MPD.ALL$preds, explain_models_validation$MPD.ALL$MPD.ALL) 
  plot(val_PD.s.ALL$preds, explain_models_validation$PD.s.ALL$PD.s.ALL) 
  plot(val_MPD.s.ALL$preds, explain_models_validation$MPD.s.ALL$MPD.s.ALL) 
# restore graphical parameters and close the PDF device
par(op)
dev.off()

       
         