#library(Rcpp, lib.loc = "C:/Program Files/R/R-3.6.1/library")
library(raster)
library(rgdal)
library(sf)
library(dplyr)
library(ggforce)
library(stringr)
library(reshape)
library(viridis)
library(ggsci)
library(erer)
library(ggplot2)
library(ggridges)
library(ggpubr)
library(maditr)
library(RColorBrewer)    
library(cowplot)
library(ggspatial)
library(ggsn)
library(ggpmisc)
library(tidyverse)
library(Rfast)
library(matrixStats)
library(devtools)
library(reldist)

library(cluster.datasets)
library(factoextra)
library(NbClust)
library(car)
library(arm)
library(MuMIn)
library(lme4)
library(tidyverse)
library(corrplot)
library(corrgram)


library(scattermore)

########################## Purpose of code #######################
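# This code reads the .rds data created in step 6 to:
# i.   check whether the relationships between the variables are linear
#      (correlation matrices and scatter plots against AGB);
# ii.  carry out the multiple linear regression analysis, obtaining standardised
#      coefficients that describe the relationship of each variable with AGB/AGC;
# iii. plot histograms of the environmental variables (Supplementary Figure 2).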

#----------------------------------------------------------------#



######## All my functions 

# Read a set of .dbf attribute tables and stack them into one data frame.
#
# Args:
#   variable_files: character vector of paths to .dbf files; each table is
#     expected to have exactly four columns.
#   var_string: name given to the fourth column of every table.
# Returns:
#   A data frame with columns Value, Count, Area, <var_string> and fileNum
#   (the position of the source file in `variable_files`, as a string), or
#   NULL when `variable_files` is empty or every table has zero rows.
variable_table = function(variable_files, var_string){
  # Preallocated; empty tables leave a NULL slot, which rbind() ignores.
  age_list <- vector("list", length(variable_files))
  # seq_along() (not 1:length()) so an empty input does not try to read file NA
  for (i in seq_along(variable_files)) {
    print(i)
    file <- variable_files[i]
    print(file)
    dbf_data <- foreign::read.dbf(file)
    print(nrow(dbf_data))
    if (nrow(dbf_data) == 0) {
      print("Empty data")
    } else {
      colnames(dbf_data) <- c("Value", "Count", "Area", var_string)
      dbf_data$fileNum <- toString(i)
      age_list[[i]] <- dbf_data
    }
  }
  do.call(rbind, age_list)
}

# Count-weighted mean AGB for every combination of `age` and the column `col`.
#
# Args:
#   data: data frame with columns age, AGB, Count and `col`.
#   col: name (string) of the second grouping column.
# Returns:
#   Summarised data frame (still grouped by age) with column weighted_mean.
calc_mean = function(data, col){
  # Uses the `data` argument; previously this read the global `all_data`.
  data %>%
    group_by(age, !! sym(col)) %>%
    summarise(weighted_mean = weighted.mean(AGB, Count))
}

# Count-weighted median AGB for every combination of `age` and the column `col`.
#
# Args:
#   data: data frame with columns age, AGB, Count and `col`.
#   col: name (string) of the second grouping column.
# Returns:
#   Summarised data frame (still grouped by age) with column weighted_median.
calc_median = function(data, col){
  # Uses the `data` argument; previously this read the global `all_data`.
  data %>%
    group_by(age, !! sym(col)) %>%
    summarise(weighted_median = weightedMedian(AGB, Count))
}

# Number of rows for every combination of `age` and the column `col`
# (an unweighted row count; Count is not used).
#
# Args:
#   data: data frame with columns age and `col`.
#   col: name (string) of the second grouping column.
# Returns:
#   Summarised data frame (still grouped by age) with column freq.
calc_freq = function(data, col){
  # Uses the `data` argument; previously this read the global `all_data`.
  data %>%
    group_by(age, !! sym(col)) %>%
    summarise(freq = n())
}


########################## Test if data is linear ################
main_folder = "D:\\Chapter3\\JRC_recovery\\"

biome = c("amazon_basin","borneo", "congo_basin")
forest_type = c("degraded_for","secondary_for")

# Display name of each region and the panel prefix used in the correlation
# matrix titles (named `panel_letters` so base R's `letters` is not masked).
region_names = c(amazon_basin = "Amazon", borneo = "Borneo",
                 congo_basin = "Central Africa")
panel_letters = c("a)", "b)", "c)")

set.seed(123) # 123

for (i in seq_along(biome)) {
  b = biome[i]
  print(b)
  letter = panel_letters[i]
  name = region_names[[b]]

  for (j in seq_along(forest_type)) {
    t = forest_type[j]
    print(t)

    folder = paste0(main_folder, b, "\\", b, "_", t, "\\")
    setwd(folder)

    all_data = readRDS(paste0(folder, b, "_", t, "_all_vars_geom_agb_fix.rds"))

    # Convert raster NODATA codes to R NA and drop incomplete rows
    all_data$HAND[all_data$HAND == -32768] = NA
    all_data$Temperature[all_data$Temperature == 0] = NA
    all_data$MCWD[all_data$MCWD == -121188] = NA
    all_data = na.omit(all_data)
    # NOTE(review): stored temperatures appear to be degrees C x 10 -- confirm
    all_data$Temperature = all_data$Temperature / 10

    all_data = subset(all_data, select = -oilPalm)

    # Pearson correlation matrix of columns 5:10; cells with p >= 0.05 are blanked
    pearson = cor(all_data[5:10], method = "pearson")
    testRes = cor.mtest(all_data[5:10], conf.level = 0.95)

    title = paste(": ", sub("_", " ", t))
    # Capitalise the first letter of every word (perl's \U in the replacement)
    title1 = gsub("(^|[[:space:]])([[:alpha:]])", "\\1\\U\\2",
                  title,
                  perl = TRUE)
    title2 = paste(letter, name, title1)

    pearPlot = corrplot(pearson, p.mat = testRes$p, method = "color", type = "lower",
                        insig = "blank", addCoef.col = "black", mar = c(0, 0, 0, 0),
                        sig.level = 0.05)
    mtext(title2, at = 2.5, line = -4.0, cex = 1.2, font = 2)

    # Scatter plots of weighted-mean AGB against each predictor in columns 7:10
    plot_list = list()
    for (col_idx in 7:10) {
      k = names(all_data)[col_idx]

      variable_mean = calc_mean(all_data, k)

      # x position of the correlation labels. MCWD is presumably <= 0, so its
      # labels go to the left; Borneo uses a smaller offset. (`labelX` is now
      # set on every path -- the Borneo branch used to assign `labelx`.)
      if (k == "MCWD") {
        labelX = if (name == "Borneo") -200 else -400
      } else {
        labelX = 20
      }

      quick_plot = ggplot(variable_mean, aes(x = !! sym(k), y = weighted_mean, colour = age)) +
        geom_point() +
        theme_bw() +
        theme(panel.grid.major = element_blank(), panel.grid.minor = element_blank(),
              axis.text.y = element_text(color = "black", size = 12),
              text = element_text(size = 12),
              strip.background = element_rect(fill = NA, colour = NA),
              panel.spacing = unit(0, "lines")) +
        geom_smooth(method = "lm", color = "red") +
        # Pearson (upper) and Spearman (lower) correlation annotations
        stat_cor(r.digits = 2, method = "pearson", label.x = labelX, label.y = 500) +
        stat_cor(method = "spearman", label.x = labelX, label.y = 450) +
        ylab(expression('AGB (Mg ha'^-1*')')) +
        labs(title = paste(name, title1))

      plot_list[[col_idx - 6]] = quick_plot
    }
    all_subplots = cowplot::plot_grid(plotlist = plot_list, nrow = 2,
                                      labels = c("a", "b", "c", "d"))
    print(all_subplots)
  }
}


 ###################### scaled multi linear analysis by region/biome ########################
main_folder = "D:\\Chapter3\\JRC_recovery\\"

biome = c("amazon_basin", "borneo", "congo_basin")
forest_type = c("degraded_for","secondary_for")

# Grid size (degrees) per region. One pixel is sampled per grid cell; the sizes
# come from the spatial-autocorrelation / semi-variogram analysis (step 7, part 1).
grid_degree = c(amazon_basin = 0.5, borneo = 0.3, congo_basin = 0.5)

# Labels for the rows of summary(lm)$coefficients, in model-formula order
coef_names = c("Intercept","Age", "HAND","Av. Max. \n Temperature",  "MCWD", "Distance from \n Undisturbed TMF")

# Per-coefficient N, mean, sd, se and 95% CI of column `measurevar` across the
# bootstrap draws, grouped by `row`. It gives the same columns as
# Rmisc::summarySE(draws, measurevar, groupvars = "row", na.rm = TRUE), but in
# base R. This avoids attaching Rmisc, which also attaches plyr, and plyr masks
# dplyr verbs such as summarise() and mutate() for the rest of the session.
summarise_draws = function(draws, measurevar) {
  by_row = split(draws[[measurevar]], draws$row)
  out = data.frame(
    row = as.numeric(names(by_row)),
    N = vapply(by_row, function(x) sum(!is.na(x)), numeric(1)),
    mean = vapply(by_row, mean, numeric(1), na.rm = TRUE),
    sd = vapply(by_row, sd, numeric(1), na.rm = TRUE),
    row.names = NULL
  )
  names(out)[names(out) == "mean"] = measurevar
  out$se = out$sd / sqrt(out$N)
  out$ci = out$se * qt(0.975, out$N - 1)
  out
}

all_biome_data = vector("list", length(biome))
set.seed(123) # 123, 1

for (i in seq_along(biome)) {
  b = biome[i]
  print(b)

  all_forest_data = vector("list", length(forest_type))
  for (j in seq_along(forest_type)) {
    t = forest_type[j]
    print(t)

    folder = paste0(main_folder, b, "\\", b, "_", t, "\\")
    setwd(folder)

    all_data = readRDS(paste0(folder, b, "_", t, "_all_vars_geom_agb_fix.rds"))

    # Convert raster NODATA codes to R NA and drop incomplete rows. The codes are
    # the same as in the linearity section, including Temperature == 0.
    all_data$HAND[all_data$HAND == -32768] = NA
    all_data$Temperature[all_data$Temperature == 0] = NA
    all_data$MCWD[all_data$MCWD == -121188] = NA
    all_data = na.omit(all_data)

    all_data$Temperature2 = all_data$Temperature^2

    # Snap every pixel to its grid cell so one pixel per cell can be drawn
    degree = grid_degree[[b]]
    all_data$latRound = round(all_data$latitude / degree) * degree
    all_data$lonRound = round(all_data$longitude / degree) * degree

    all_data$latInt = all_data$latRound %% 1 == 0
    all_data$lonInt = all_data$lonRound %% 1 == 0

    # One draw up front gives the number of occupied cells. The draw is then
    # repeated until ~100,000 pixels have been modelled in total. This first draw
    # also advances the RNG, so it is kept for reproducibility with seed 123.
    sampled_test = all_data %>%
      group_by(latRound,
               lonRound) %>%
      sample_n(1)
    n_iterations = ceiling(100000 / nrow(sampled_test))

    counter = 0
    looped_info = vector("list", n_iterations)
    for (iter in seq_len(n_iterations)) {
      print(iter)

      sampled_test = all_data %>%
        group_by(latRound,
                 lonRound) %>%
        sample_n(1)
      counter = counter + nrow(sampled_test)
      print(counter)

      # Response and predictors are all scaled, so coefficients are standardised
      model = lm(scale(AGB) ~ scale(age) + scale(HAND) + scale(Temperature) +
                   scale(MCWD) + scale(distance_from_undisturbed), data = sampled_test)

      model_info = data.frame(summary(model)$coefficients)
      model_info$names = coef_names
      model_info$biome = b
      model_info$forest_type = t
      model_info$loop = toString(iter)
      model_info$row = seq_len(nrow(model_info))

      looped_info[[iter]] = model_info
    }
    coef_draws = do.call(rbind, looped_info)

    # Average each coefficient's estimate (and its p-value) over all draws
    mean_coeficients = summarise_draws(coef_draws, "Estimate")
    mean_pvalue = summarise_draws(coef_draws, "Pr...t..")
    mean_coeficients$pValue = mean_pvalue$Pr...t..
    mean_coeficients$names = coef_names
    mean_coeficients$biome = b
    mean_coeficients$forest_type = t
    all_forest_data[[j]] = mean_coeficients
  }
  combined_data = do.call(rbind, all_forest_data)
  all_biome_data[[i]] = combined_data
}

all_data = do.call(rbind, all_biome_data)
all_data = all_data[all_data$names != "Intercept",]

# Copies of Estimate kept only where the mean p-value is below 0.1 / 0.05 / 0.01.
# star2 and star3 are overlaid on the point ranges as significance markers.
all_data$star = all_data$Estimate
all_data$star[all_data$pValue >= 0.1] = NA

all_data$star2 = all_data$Estimate
all_data$star2[all_data$pValue >= 0.05] = NA

all_data$star3 = all_data$Estimate
all_data$star3[all_data$pValue >= 0.01] = NA

# Display labels
all_data$biome[all_data$biome == "amazon_basin"] = "Amazon"
all_data$biome[all_data$biome == "borneo"] = "Borneo"
all_data$biome[all_data$biome == "congo_basin"] = "Central Africa"

all_data$forest_type[all_data$forest_type == "degraded_for"] = "Degraded Forest"
all_data$forest_type[all_data$forest_type == "secondary_for"] = "Secondary Forest"

all_data$names[all_data$names == "Age"] = "YSLD \n (~Age)"

# Order of the predictors along the (flipped) x axis
coef_levels = c("Distance from \n Undisturbed TMF", "HAND","MCWD", "Av. Max. \n Temperature","YSLD \n (~Age)")

# Spread the region / forest-type points within each predictor. height = 0 keeps
# the jitter from also shifting the coefficient values. The fixed seed makes every
# layer jitter x identically, so the significance dots stay on their point ranges.
dodge = position_jitter(width = 0.4, height = 0, seed = 926015923)

plot_coefs = ggplot(data = all_data,
                    mapping = aes(x = factor(names, levels = coef_levels),
                                  y = Estimate, ymin = Estimate - sd, ymax = Estimate + sd,
                                  fill = biome, color = biome, shape = forest_type)) +
  geom_pointrange(size = 1, position = dodge) +
  # White dot: p < 0.05; black dot: p < 0.01 (x is inherited from the plot mapping)
  geom_point(mapping = aes(y = star2), size = 1, color = "white", position = dodge) +
  geom_point(mapping = aes(y = star3), size = 1, color = "black", position = dodge) +
  scale_x_discrete("") +
  scale_y_continuous("Standardised Coefficient") +
  geom_hline(yintercept = 0, color = "grey", size = 1, linetype = "dashed") +
  theme_bw() +
  theme(panel.grid.major = element_blank(), panel.grid.minor = element_blank(),
        axis.text.x = element_text(color = "black", size = 12),
        axis.text.y = element_text(color = "black", size = 12),
        text = element_text(size = 10)) +
  # Separators between the predictor groups
  geom_vline(xintercept = c(1.5, 2.5, 3.5, 4.5), color = "grey", size = 0.5) +
  scale_color_manual(values = c("Amazon" = "#18BECF" , "Borneo" = "#BCBD23", "Central Africa" = "#7F7F7F")) +
  coord_flip() +
  labs(fill = "Region", shape = "Forest Type", colour = "Region")

print(plot_coefs)

# Output folder; the file name is appended directly, so end it with a separator
results_folder2 = "change to your results folder"

ggsave(paste0(results_folder2, "multi_linear_coefficients_", Sys.Date(), "_grid_agb_fix.png"),
       plot_coefs, width = 8, height = 4, units = "in", dpi = 900)



########################## Histograms of variables (supplementary Figure 2) ####################


main_folder = "D:\\Chapter3\\JRC_recovery\\"

biome = c("amazon_basin", "borneo", "congo_basin")
forest_type = c("degraded_for","secondary_for")

all_biome_data = vector("list", length(biome))
set.seed(123)

for (i in seq_along(biome)) {
  b = biome[i]
  print(b)

  all_forest_data = vector("list", length(forest_type))
  for (j in seq_along(forest_type)) {
    t = forest_type[j]
    print(t)

    folder = paste0(main_folder, b, "\\", b, "_", t, "\\")
    setwd(folder)

    # NOTE(review): this reads "_all_vars_geom.rds", whereas the analysis sections
    # read "_all_vars_geom_agb_fix.rds" -- confirm the older file is intended here.
    all_data = readRDS(paste0(folder, b, "_", t, "_all_vars_geom.rds"))

    # Convert raster NODATA codes to R NA and drop incomplete rows
    all_data$HAND[all_data$HAND == -32768] = NA
    all_data$MCWD[all_data$MCWD == -121188] = NA
    all_data = na.omit(all_data)
    all_data$forest_type = t
    all_data$biome = b
    print(nrow(all_data))

    all_forest_data[[j]] = all_data
  }
  combined_data = do.call(rbind, all_forest_data)
  all_biome_data[[i]] = combined_data
}

all_data = do.call(rbind, all_biome_data)

all_data = na.omit(all_data)

# Region display names; these must match the colour names used in the histograms
all_data$biome[all_data$biome == "amazon_basin"] = "Amazon"
all_data$biome[all_data$biome == "borneo"] = "Borneo"
all_data$biome[all_data$biome == "congo_basin"] = "Central Africa"

# Temperature / 10, plotted with a degrees-C axis label
all_data$Temperature1 = all_data$Temperature / 10


# Count-weighted histogram of one variable, overlaid by region.
#
# Bars show density * width, i.e. the share of each region's Count-weighted
# total falling in each bin. Colour and fill legends are both switched off.
#
# Args:
#   binW: bin width, in the units of column `var_str`.
#   xaxisName: x-axis label.
#   all_data: data frame with columns `var_str`, Count (histogram weight) and
#     biome ("Amazon", "Borneo" or "Central Africa").
#   var_str: name (string) of the column to plot.
#   xmin, xmax: x-axis limits passed to xlim(); NA uses the data range, and
#     values outside the limits are dropped.
# Returns:
#   The ggplot object, which is also printed.
build_histogram = function(binW, xaxisName, all_data, var_str, xmin, xmax){
  # Named so colours follow the region label, not the alphabetical factor order
  region_cols = c("Amazon" = "#18BECF", "Borneo" = "#BCBD23",
                  "Central Africa" = "#7F7F7F")

  p = ggplot() +
    geom_histogram(data = all_data,
                   aes(x = !! sym(var_str), y = stat(density * width),
                       weight = Count, color = biome, fill = biome),
                   position = "identity", alpha = 0.4, binwidth = binW) +
    theme_bw() +
    theme(panel.grid.major = element_blank(), panel.grid.minor = element_blank(),
          axis.text.x = element_text(color = "black", size = 12),
          axis.text.y = element_text(color = "black", size = 12),
          text = element_text(size = 12),
          legend.direction = "vertical", legend.box = "horizontal",
          legend.background = element_rect(fill = "white", color = "grey")) +
    scale_fill_manual(values = region_cols) +
    scale_color_manual(values = region_cols) +
    labs(x = xaxisName, y = "Density", fill = "Region", color = "Region") +
    guides(colour = "none", fill = "none") +
    xlim(xmin, xmax)

  print(p)

  p
}


# Extract the legend grob (the gtable grob named "guide-box") from a ggplot.
#
# Args:
#   a.gplot: a ggplot object.
# Returns:
#   The legend grob, or NULL when the plot has no "guide-box" grob (e.g. all
#   guides are turned off). Previously that case errored with
#   "attempt to select less than one element".
# NOTE(review): ggplot2 >= 3.5 appears to name these grobs "guide-box-<position>",
# for which this returns NULL -- confirm against the installed ggplot2 version.
get_legend <- function(a.gplot) {
  tmp <- ggplot_gtable(ggplot_build(a.gplot))
  grob_names <- vapply(tmp$grobs, function(x) {
    if (is.null(x$name)) NA_character_ else as.character(x$name)[1]
  }, character(1))
  leg <- which(grob_names == "guide-box")
  if (length(leg) == 0L) {
    return(NULL)
  }
  tmp$grobs[[leg[1]]]
}

MCWD_hist = build_histogram(10, "MCWD (mm)", all_data, "MCWD", NA, 0)

temp_hist = build_histogram(0.5, "Av. Max. Temperature (\u00B0C)", all_data, "Temperature1", 20, 33)

hand_hist = build_histogram(5, "HAND (m)", all_data, "HAND", 0, 300)

dist_hist = build_histogram(100, "Distance from undisturbed TMF (m)", all_data, "distance_from_undisturbed", 0, 1500)

age_hist = build_histogram(1, "Years since last disturbance event", all_data, "age", 0, 34)

# build_histogram() already turns every legend off, so the age panel is reused as-is.
# A legend-less plot has no "guide-box" grob to extract, so the shared legend
# comes from the dedicated legend-only plot below.
age_hist_no_leg = age_hist

region_cols = c("Amazon" = "#18BECF", "Borneo" = "#BCBD23", "Central Africa" = "#7F7F7F")

# Legend-only panel: lims() of 0..0 leave effectively nothing but the region legend
legendV2 <- ggplot(all_data) +
  geom_histogram(aes(x = age, y = stat(density * width), weight = Count, color = biome, fill = biome),
                 position = "identity", alpha = 0.4) +
  lims(x = c(0, 0), y = c(0, 0)) +
  theme_void() +
  theme(legend.position = c(0.5, 0.5),
        legend.key.size = unit(0.2, "cm"),
        legend.text = element_text(size = 12),
        legend.title = element_text(size = 15),
        legend.margin = margin(6, 6, 6, 6),
        legend.background = element_rect(fill = "white", color = "grey")) +
  scale_fill_manual(values = region_cols) +
  scale_color_manual(values = region_cols) +
  labs(fill = "Region", color = "Region") +
  guides(colour = guide_legend(override.aes = list(size = 8)))

legendV2

all_hists = list(MCWD_hist, temp_hist, hand_hist, dist_hist, age_hist_no_leg, legendV2)

p.allHist = ggarrange(plotlist = all_hists, labels = c("a", "b", "c", "d", "e", ""), ncol = 3, nrow = 2)

print(p.allHist)

ggsave(paste0(main_folder, "results\\enviro_variables_", Sys.Date(), "_2.png"), p.allHist,
       width = 12, height = 6, units = "in", dpi = 900)
