####################################################################################################################
# Study Title: Turning lances into shields: flower mantids stretch their raptorial forelegs to avert and deflect predator attack
# Script Name: Model Comparison, Ancestral State Reconstruction and Rate Change Across Time Plot. 
# Description: This script performs a model comparison (ER, SYM, ARD) based on the provided
#              polymorphic character state matrix and phylogenetic tree, selects the best
#              model, and then uses stochastic mapping to reconstruct
#              ancestral states with 'fitpolyMk' and 'make.simmap'.
#              It also plots the estimated transition rates among states,
#              merges mapped states, and generates density plots with the 95% HPD of state
#              transitions.
#              The final section of this script uses 'fitMk' and 'make.simmap' to compare
#              the three models again and then uses the selected best model ("SYM") to
#              analyze the rate of state changes across time.
#              Due to differences in random number generators and numerical 
#              algorithms across different versions of R and its packages, slight variations 
#              in results may occur when running this script on different systems or with 
#              different R versions. 
#              
# Reference:
# Model Comparison: http://blog.phytools.org/2015/07/integrating-stochastic-character-maps.html
# Ancestral State Reconstruction & 95% HPD Plot: http://blog.phytools.org/2019/07/stochastic-character-mapping-with.html
# Rate Changes Across Time:  http://blog.phytools.org/search?q=ltt
####################################################################################################################

# Run this script in a fresh R session (e.g. restart R, or run it with Rscript)
# instead of clearing the workspace with rm(list = ls()). That call only removes
# visible objects from the global environment: loaded packages, options() and
# par() settings survive it. It also deletes the work of anyone who source()s
# this file.
# Release unused memory before the analyses.
gc()

# Load required libraries
library(phytools)

# ==============================
# Data Import and Preprocessing
# ==============================

# Read the character state matrix and phylogenetic tree.
# Assumes the first CSV column holds each taxon's state and the row names are
# the taxon names -- TODO confirm against Mantidea_state_matrix.csv.
char_data <- read.csv("Mantidea_state_matrix.csv", row.names = 1)
x <- setNames(char_data[, 1], rownames(char_data))
x

# Fix the RNG seed so that the random re-ordering below and every later
# stochastic map (make.simmap) are reproducible from run to run. Change the
# value to check how sensitive the results are to the random draws.
set.seed(1)

# To remove order bias in polymorphic states ("B+S" vs "S+B"), recode a random
# half of the "B+S" taxa as "S+B".
state <- setNames(as.character(x), names(x))
ii <- which(state == "B+S")
# Draw positions with sample.int() and index ii with them. This gives the same
# draws as sample(ii, ...), but sample() silently switches to sampling from
# 1:ii when ii has length one.
state[ii[sample.int(length(ii), round(0.5 * length(ii)))]] <- "S+B"
state <- as.factor(state)
state

tree <- read.tree("Mantidea_pruned_outgroups.tree")

# ==============================
# Model Comparison: ER, SYM, ARD
# ==============================

# Akaike information criterion for a model with k free parameters and
# maximized log-likelihood logL: AIC = 2k - 2 logL.
aic <- function(logL, k) {
  penalty <- 2 * k
  penalty - 2 * logL
}
# Akaike weights: each model's relative likelihood exp(-delta / 2), where delta
# is its AIC minus the smallest AIC, normalized to sum to one. Names of
# aic_values are carried through to the result.
aic.w <- function(aic_values) {
  rel_lik <- exp(-0.5 * (aic_values - min(aic_values)))
  rel_lik / sum(rel_lik)
}

# Fit each candidate model once with fitpolyMk and keep the fitted objects, so
# the log-likelihoods and parameter counts below come from the same fits.
models <- c("ER", "SYM", "ARD")
poly_fits <- lapply(setNames(models, models), function(model) {
  fitpolyMk(tree, state, model = model, ordered = FALSE)
})

# Maximized log-likelihood of each model. Read the element by its full name
# rather than relying on `$` partial matching of `fit$logL`.
logL_values <- vapply(poly_fits, function(f) as.numeric(f$logLik), numeric(1))
logL_values

# Number of free parameters = largest entry of each fitted index matrix (its
# entries number the estimated rates). Hard-coding ER = 1, SYM = 3, ARD = 6 is
# only right for a three-state character. fitpolyMk's state space also
# contains the polymorphic combinations (e.g. "B+S"), so SYM and ARD have many
# more rates.
k_values <- vapply(poly_fits, function(f) max(f$index.matrix, na.rm = TRUE),
                   numeric(1))
k_values
AIC_values <- mapply(aic, logL = logL_values, k = k_values)
AIC_values

# AIC weights: relative support for each model, summing to one.
AIC_weights <- aic.w(AIC_values)
AIC_weights

# ==============================
# Ancestral State Reconstruction using the optimal model ARD
# ==============================

# Fit the ARD model, which the authors selected as best (based on logL, AIC and
# AIC weights). The choice is hard-coded, so warn if the AIC comparison
# disagrees rather than silently reconstructing under a model it does not favour.
if (exists("AIC_values") && !identical(names(which.min(AIC_values)), "ARD")) {
  warning("Ancestral states are reconstructed under ARD, but the lowest-AIC ",
          "model is ", names(which.min(AIC_values)), call. = FALSE)
}
fit <- fitpolyMk(tree, state, model = "ARD", ordered = FALSE)

# Plot the fitpolyMk results (the fitted transition rates among states).
plot(fit, color = TRUE, width = TRUE, tol = 1e-3, show.zeros = FALSE,
     offset = 0.03, mar = c(0, 0, 2.1, 0))
title(main = " Estimation of transition rates among different states using ARD Model", 
      font.main = 3, line = -1)

# Print the rate index matrix (its entries number the free rate parameters)
# and the data matrix stored on the fit.
fit$index.matrix
fit$data

# ==============================
# Stochastic Mapping for Ancestral State Reconstruction
# Using make.simmap (nsim = 1000 simulations)
# ==============================

# Draw 1000 stochastic character maps from the polymorphic data matrix stored
# on the ARD fit, using that fit's rate index matrix as the model structure.
# The Q argument is left at make.simmap()'s default.
simmap_trees <- make.simmap(
  tree = tree,
  x = fit$data,
  model = fit$index.matrix,
  nsim = 1000
)
simmap_trees

# Optional: plot the summary of the stochastic mapping results.
#cols <- setNames(colorRampPalette(c("black", "yellow", "red", "blue", 
#                                    "green", "grey", "orange", "purple"))(15),
#                 colnames(fit$data))
#plot(summary(simmap_trees), colors = cols, ftype = "off")
#legend("topleft", legend = colnames(fit$data), pt.cex = 2.4, pch = 21, pt.bg = cols)

# ==============================
# Merging Mapped States
# Merge similar states based on data characteristics to facilitate further analyses.
# ==============================

# Fold each state containing "N" into the same state without "N". Each entry
# maps the merged state's name to the mapped states collapsed into it. Merges
# are applied in the order listed. This scheme is an example; adjust it to
# your own state names.
merged_trees <- local({
  merge_plan <- list(
    "B"     = c("B+N", "B"),
    "S"     = c("N+S", "S"),
    "W"     = c("N+W", "W"),
    "B+S"   = c("B+N+S", "B+S"),
    "B+W"   = c("B+N+W", "B+W"),
    "S+W"   = c("N+S+W", "S+W"),
    "B+S+W" = c("B+N+S+W", "B+S+W")
  )
  Reduce(
    function(trees, new_state) {
      mergeMappedStates(trees, merge_plan[[new_state]], new_state)
    },
    names(merge_plan),
    simmap_trees
  )
})

# One colour per state remaining after the merges (add entries if you merge
# differently).
cols_merged <- c(
  "N" = "purple", "B" = "red", "S" = "green", "W" = "blue",
  "B+S" = "orange", "B+W" = "black", "S+W" = "grey", "B+S+W" = "yellow"
)
plot(summary(merged_trees), colors = cols_merged, ftype = "off")
legend(x = "bottomleft", legend = names(cols_merged),
       pch = 21, pt.bg = cols_merged, pt.cex = 2.4)

# ==============================
# Plotting 95% HPD Density for State Transitions
# Use the density function to calculate the transition rate density and plot the 95%
# HPD intervals for selected transitions.
# ==============================

density_obj <- density(x = merged_trees)
density_obj

# Shared y-axis limits for all panels (adjust based on your data).
ylim_range <- c(0, 1.5)

# Transition -> panel title, plotted in this order; each forward transition is
# followed by its reverse.
local({
  panel_titles <- c(
    "B+S->S" = "Density of Transition: B+S->S",
    "S->B+S" = "Density of Transition: S->B+S",
    "B->B+S" = "Density of Transition:B->B+S ",
    "B+S->B" = "Density of Transition: B+S->B",
    "S+W->S" = "Density of Transition:S+W->S ",
    "S->S+W" = "Density of Transition: S->S+W",
    "W->S+W" = "Density of Transition:W->S+W ",
    "S+W->W" = "Density of Transition: S+W->W",
    "N->S"   = "Density of Transition:N->S ",
    "S->N"   = "Density of Transition: S->N"
  )
  for (tr in names(panel_titles)) {
    plot(density_obj, transition = tr, ylim = ylim_range,
         main = panel_titles[[tr]])
  }
})

# ==============================
# Plotting the Rate Changes Across Time.
# ==============================

library(RColorBrewer)

# ------------------------------
# 2. Model Comparison
# ------------------------------

# AIC for a model with k free parameters and maximized log-likelihood logL.
aic <- function(logL, k) -2 * logL + 2 * k
# Akaike weights: exp(-delta / 2) for each model (delta = AIC - min AIC),
# rescaled to sum to one; input names are preserved.
aic.w <- function(aic_values) {
  delta <- aic_values - min(aic_values)
  rel <- exp(delta * -0.5)
  rel / sum(rel)
}

# Fit three models (ER, SYM, ARD) to the raw states in x with fitMk, keeping
# each fitted object so log-likelihoods and parameter counts come from the
# same fit.
models <- c("ER", "SYM", "ARD")
mk_fits <- lapply(setNames(models, models), function(model) {
  fitMk(tree, x, model = model, ordered = FALSE)
})

# Maximized log-likelihoods. Read the element by its full name rather than
# relying on `$` partial matching of `fit$logL`.
logL_values <- vapply(mk_fits, function(f) as.numeric(f$logLik), numeric(1))
logL_values

# Number of free parameters = largest entry of each fitted index matrix.
# Hard-coding ER = 1, SYM = 3, ARD = 6 is only right for a three-state
# character. Here each distinct value of x (including polymorphic codes such
# as "B+S") is presumably its own state, so SYM and ARD can have many more
# rates.
param_counts <- vapply(mk_fits,
                       function(f) max(f$index.matrix, na.rm = TRUE),
                       numeric(1))
AIC_values <- mapply(aic, logL = logL_values, k = param_counts)
cat("AIC values for models:\n")
print(AIC_values)

# Compute AIC weights.
AIC_weights <- aic.w(AIC_values)
cat("AIC weights for models:\n")
print(AIC_weights)

# SYM is the authors' chosen model; warn if the AIC comparison disagrees so
# the hard-coded choice cannot silently drift from the evidence.
best_model <- "SYM"
if (!identical(names(which.min(AIC_values)), best_model)) {
  warning("Rate analysis uses ", best_model, ", but the lowest-AIC model is ",
          names(which.min(AIC_values)), call. = FALSE)
}

# ------------------------------
# 3. Rate Change Analysis with Stochastic Mapping
# ------------------------------

# Perform stochastic mapping using the best model (SYM) with 1000 simulations.
smap_trees <- make.simmap(tree, x, model = best_model, nsim = 1000)
smap_trees

# Colours for the change marks, named by state. markChanges() looks a colour up
# by the name of the state entered at each change, so the names must be the
# states of x. Placeholder names such as letters a-e match none of them, and
# every mark would then be drawn with colour NA, i.e. not at all. Set1 has at
# most 9 colours, so interpolate when there are more states.
colors <- local({
  states <- sort(unique(as.character(x)))
  set1 <- brewer.pal(9, "Set1")
  pal <- if (length(states) <= length(set1)) {
    set1[seq_along(states)]
  } else {
    colorRampPalette(set1)(length(states))
  }
  setNames(pal, states)
})

# Draw the tree, then overlay every change from every stochastic map as a
# semi-transparent tick. lapply() (not sapply()) guarantees `changes` is a list
# with one element per map. sapply() would collapse it into a single matrix if
# every map happened to have the same number of changes, which breaks any
# per-map processing of `changes`.
plotTree(tree, ftype = "off", lwd = 1)
changes <- lapply(smap_trees, markChanges,
                  colors = vapply(colors, make.transparent, character(1),
                                  alpha = 0.3))

head(changes, n = 2)


# Split the tree's total depth h (root to the present) into b equal-width time
# bins; row i of segs holds the lower and upper bounds of bin i, measured
# from the root.
h <- max(nodeHeights(tree))
b <- 50
segs <- cbind(
  seq(from = 0, to = h - h / b, by = h / b),
  seq(from = 1 / b * h, to = h, by = h / b)
)
segs


# Compute the average number of changes for each time segment. Count every
# change time t with segs[k, 1] < t <= segs[k, 2], pooled over all stochastic
# maps, then divide by the number of maps.
# Assumes each element of `changes` is a markChanges() matrix whose first
# column holds the change positions along the root-to-tip axis -- confirm if
# the phytools version changes.
nchanges <- local({
  change_times <- unlist(lapply(changes, function(m) m[, 1]),
                         use.names = FALSE)
  # Vectorised counting over seq_len() also copes with maps that have no
  # changes. An index loop over 1:nrow(m) would visit c(1, 0) there and
  # index past the end.
  counts <- vapply(seq_len(nrow(segs)), function(k) {
    sum(change_times > segs[k, 1] & change_times <= segs[k, 2])
  }, numeric(1))
  counts / length(changes)
})

# Step plot of the mean number of changes per bin against time before the
# present. Each bin contributes a horizontal segment: the x values are the
# bin bounds interleaved bin by bin (t(segs)), and each y value appears twice.
# The reversed xlim puts the root on the left and the present on the right.
plot(
  x = h - as.vector(t(segs)),
  y = rbind(nchanges, nchanges),
  type = "l", lwd = 2, lend = 0,
  xlim = c(max(segs), min(segs)),
  xlab = "time since the present", ylab = "mean number of changes"
)

# Overlay a faint copy of the tree on the same (reversed) time axis.
plotTree(tree, add = TRUE, ftype = "off", lwd = 1,
         color = make.transparent("blue", 0.1),
         mar = par("mar"), direction = "leftwards",
         xlim = c(max(segs), min(segs)))


####################################################################################################################
# End of Script
####################################################################################################################
