#################################################################################################
## Compare ancestral state reconstruction based on
## 1- two states according to the pollination syndrome of extant species
## 2- four states according to the pollination syndrome of extant species + evolutionary history
#################################################################################################

library(corHMM)
library(phytools)
library(beepr)
library(dplyr)
library(ggplot2)

###############
## functions
###############

# Classify one tip by the pollination states of its ancestors on a
# stochastic character map.
#
# tip:   tip number (1..Ntip) in `phylo`.
# phylo: a simmap tree whose internal-node states are read with
#        phytools::getStates(type = "nodes").
# Returns "from_bee" when every internal node on the root-to-tip path is in
# a bee state, and "from_bird" otherwise (at least one bird ancestor).
detect_origin <- function(tip, phylo) {
  n_tips <- Ntip(phylo)
  # ape numbers the root node Ntip + 1; Nnode + 2 only equals this for
  # fully bifurcating trees
  root <- n_tips + 1
  path <- nodepath(phylo, from = root, to = tip)
  # internal nodes only (the tip itself has no node state), indexed in
  # root-to-tip path order rather than in node-number order
  path_nodes <- path[path > n_tips]
  node_states <- getStates(phylo, type = "nodes")
  path_states <- node_states[as.character(path_nodes)]
  # states "3" and "4" are collapsed onto "1" (bee) and "2" (bird)
  # NOTE(review): presumably hidden-rate copies of the two syndromes - confirm
  if (all(path_states %in% c("1", "3"))) "from_bee" else "from_bird"
}

# Apply a per-tip classifier to every tip of a tree.
#
# phylo: a tree with a `tip.label` vector (e.g. a simmap).
# FUN:   function(tip, phylo) returning a single character label for tip
#        number `tip` (e.g. detect_origin).
# Returns a data.frame with one row per tip, in tip-number order:
# `states` (FUN's label) and `species` (the tip label).
detect_classify <- function(phylo, FUN) {
  # count tips directly: Nnode + 1 equals the tip count only for fully
  # bifurcating trees
  tips <- seq_along(phylo$tip.label)
  states <- vapply(tips, FUN, character(1), phylo)
  data.frame(states = states, species = phylo$tip.label)
}

################
## data loading 
################

# MCC tree from Kriebel et al.
all_salvia_tree <- read.tree("Salvia_flower_evol_May2020/Phylogeny/Beast_Yule_MCC_newick.tre")

# Load pollinator and clade data from Kriebel et al.
all_salvia <- read.csv("Salvia_flower_evol_May2020/Data/groups.csv", header = TRUE)
row.names(all_salvia) <- all_salvia$species

# Erase species not in tree; rows are also put in tip order
# (a tip missing from the CSV would produce an all-NA row)
all_salvia <- all_salvia[all_salvia_tree$tip.label, ]

# Recode pollinator syndromes
# 1 codes for bee pollination ("sister" species are treated as bee)
# 2 codes for bird pollination
# 1&2 indicates ambiguity ("bee&bird" in the raw data)
# The second substitution must read pollinator2, not the raw column,
# otherwise the "bee&bird" recoding is discarded.
all_salvia$pollinator2 <- gsub("bee&bird", "intermediate", all_salvia$pollinator, fixed = TRUE)
all_salvia$pollinator2 <- gsub("sister", "bee", all_salvia$pollinator2, fixed = TRUE)

# Map labels to codes by name rather than by alphabetical level position,
# and fail on any unexpected category instead of silently shifting codes.
syndrome_levels <- c("bee", "bird", "intermediate")
stopifnot(all(is.na(all_salvia$pollinator2) | all_salvia$pollinator2 %in% syndrome_levels))
all_salvia$pollinator2 <- factor(all_salvia$pollinator2,
                                 levels = syndrome_levels,
                                 labels = c("1", "2", "1&2"))

# Final dataset handed to corHMM: species names plus the recoded syndrome
b0 <- all_salvia[, c("species", "pollinator2")]

# Two-state SMM_ARD fit (model = "ARD", one rate category, marginal node
# states), as in Kriebel et al. 2019, 2020.
# WARNING: adjust n.cores according to your system
SMM_ARD <- corHMM(
  all_salvia_tree,
  b0,
  rate.cat = 1,
  model = "ARD",
  node.states = "marginal",
  root.p = "yang",
  get.tip.states = TRUE,
  n.cores = 7
)


###########################################
## Stochastic character mapping on mcc tree
## based on a two-state ASR
###########################################

# smm_ard: 150 stochastic character maps on the MCC tree, drawn from the
# two-state model fitted above
simmap_smm_ard_t1 <- makeSimmap(tree = all_salvia_tree, data = b0,
                                model = SMM_ARD$solution,
                                rate.cat = 1, nSim = 150,
                                nCores = 7)

# Classify every tip as "from_bee"/"from_bird" in each simmap.
# Preallocate one slot per simmap: list(n) would build a one-element list
# holding the number n, not n empty slots.
origins <- vector("list", length(simmap_smm_ard_t1))
for (i in seq_along(simmap_smm_ard_t1)) {
  origins[[i]] <- detect_classify(simmap_smm_ard_t1[[i]], detect_origin)
}
beep(5)

###########################################
# obtain the four groups: 
# bee-from-bee
# bee-from-bird
# bird
# polymorphic
###########################################

# Build the four-group tip coding from the two-state syndrome and the
# inferred origin of each species.
#
# data: data.frame with `species` and `pollinator2` ("1", "2" or "1&2").
# org:  data.frame with `species` and `states` ("from_bee"/"from_bird"),
#       as produced by detect_classify().
# Returns a data.frame (row names = species, sorted by species via merge)
# with `species` and factor `poll`:
#   "1" bee from bee, "2" bee from bird, "3" bird, "4" polymorphic;
#   species matching none of these keep the placeholder "0".
obtain4 <- function(data, org) {
  joined <- merge(org, data, by = "species")
  is_bee <- joined$pollinator2 == "1"

  code <- rep("0", nrow(joined))
  code[is_bee & joined$states == "from_bee"] <- "1"
  code[is_bee & joined$states == "from_bird"] <- "2"
  code[joined$pollinator2 == "2"] <- "3"
  code[joined$pollinator2 == "1&2"] <- "4"

  joined$poll <- as.factor(code)
  rownames(joined) <- joined$species
  joined[, c("species", "poll")]
}

# Four-group tip coding derived from each two-state simmap
newclass <- lapply(origins, function(org) obtain4(data = b0, org = org))

## CHECK: some simmap simulations (~20%) end in fewer than 4 groups
## (e.g. when the root is reconstructed as bird-pollinated)
check4 <- vapply(newclass, function(x) length(unique(x$poll)), numeric(1))

## simmaps with fewer than 4 groups are discarded and the whole set is
## reduced to 100 simmaps. Stop if too few remain: indexing past the end
## of a list silently pads it with NULL elements.
newclass <- newclass[check4 == 4]
if (length(newclass) < 100) {
  stop("only ", length(newclass), " simmaps yield all four groups; ",
       "100 are needed (increase nSim)", call. = FALSE)
}
newclass <- newclass[1:100]


###########################################
## Stochastic character mappings on mcc tree
## based on a four-state ASR
###########################################

# new classification of internal nodes (ancestors): for each four-group tip
# coding, fit a four-state ARD model and draw 10 stochastic maps, keeping
# the internal-node states of every map.
# NOTE(review): getStates() on the 10 maps is presumably a nodes x 10
# matrix; the comparison below pairs columns of `pru` with `pru2`.
newsim <- vector("list", length(newclass))
for (i in seq_along(newclass)) {
  SMM_ARD4 <- corHMM(all_salvia_tree, newclass[[i]], model = "ARD", n.cores = 7,
                     node.states = "marginal", rate.cat = 1,
                     root.p = "yang", get.tip.states = TRUE)
  simm4 <- makeSimmap(tree = all_salvia_tree, data = newclass[[i]],
                      model = SMM_ARD4$solution,
                      rate.cat = 1, nSim = 10,
                      nCores = 7)
  newsim[[i]] <- getStates(simm4, type = "nodes")
}
pru <- do.call(cbind, newsim)
beep(5)

# original classification of internal nodes (ancestors): the node states of
# the two-state reconstruction, repeated so that column j of `pru2` lines up
# with column j of `pru` (10 maps per four-state fit x length(newclass) fits).
# NOTE(review): every column uses the FIRST two-state simmap
# (simmap_smm_ard_t1[1]), although each newclass[[i]] was derived from a
# different simmap - confirm this single reference is intended.
ref_states <- getStates(simmap_smm_ard_t1[1])
n_maps_per_fit <- 10
# node count taken from the data instead of the hard-coded 527
ref_block <- matrix(rep(ref_states, n_maps_per_fit),
                    nrow = length(ref_states), ncol = n_maps_per_fit)
original <- rep(list(ref_block), length(newclass))
pru2 <- do.call(cbind, original)

# comparison: for each simulation column, cross-tabulate node states of the
# two-state ASR (rows: "1" bee, "2" bird) against the four-state ASR
# (columns: "1" bee from bee, "2" bee from bird, "3" bird, "4" polymorphic)
stopifnot(ncol(pru) == ncol(pru2))
two_state_levels <- c("1", "2")
four_state_levels <- c("1", "2", "3", "4")
# unexpected states would otherwise silently misalign the cell names
stopifnot(all(pru2 %in% two_state_levels), all(pru %in% four_state_levels))
cell_names <- c("bee_bee_from_bee", "bird_bee_from_bee",
                "bee_bee_from_bird", "bird_bee_from_bird",
                "bee_bird", "bird_bird",
                "bee_poly", "bird_poly")
comparisson <- vector("list", ncol(pru))
for (i in seq_len(ncol(pru))) {
  # fixed factor levels always give a full 2 x 4 table (zeros for states
  # absent in this column), so the 8 column-major cells match cell_names
  tab <- table(factor(pru2[, i], levels = two_state_levels),
               factor(pru[, i], levels = four_state_levels))
  comparisson[[i]] <- setNames(c(tab), cell_names)
}
final <- as.data.frame(do.call(rbind, comparisson))


# Turn the counts into row-wise percentages: each bee_* cell is divided by
# the number of bee nodes in the two-state ASR and each bird_* cell by the
# number of bird nodes, then multiplied by 100.
finalB <- local({
  pct <- final
  groups <- list(
    c("bee_bee_from_bee", "bee_bee_from_bird", "bee_bird", "bee_poly"),
    c("bird_bee_from_bee", "bird_bee_from_bird", "bird_bird", "bird_poly")
  )
  for (cols in groups) {
    total <- Reduce(`+`, final[cols])
    for (col in cols) {
      pct[[col]] <- final[[col]] / total * 100
    }
  }
  pct
})

# Mean and standard deviation of each percentage across simulations,
# arranged as two-state class (rows) x four-state class (columns); the
# column-major fill follows the bee/bird alternation of finalB's columns.
M <- matrix(vapply(finalB, mean, numeric(1)), nrow = 2, ncol = 4,
            dimnames = list(c("bee", "bird"),
                            c("bee from bee", "bee from bird", "bird", "polymorphic")))
round(M, 2)

S <- matrix(vapply(finalB, sd, numeric(1)), nrow = 2, ncol = 4,
            dimnames = list(c("bee", "bird"),
                            c("bee from bee", "bee from bird", "bird", "polymorphic")))
round(S, 2)


## percentage of bird-pollinated internal nodes (ancestors) according to
## the ancestral state reconstruction based on 2 states that are
## reconstructed under the "bee from bird" regime according to the
## ancestral state reconstruction based on 4 states

# print() explicitly: top-level values are not auto-printed under source()
print(summary(final$bird_bee_from_bird /
                (final$bird_bee_from_bee + final$bird_bee_from_bird +
                   final$bird_bird + final$bird_poly) * 100))

g2 <- ggplot(final, aes(bird_bee_from_bird /
                          (bird_bee_from_bee + bird_bee_from_bird +
                             bird_bird + bird_poly) * 100)) +
  geom_density(fill = "red") +
  xlab("percentage of bird ancestors\nassigned to the 'bee from bird' regime") +
  theme_bw()

print(g2)

save(pru, pru2, final, file = "compareASR.RData")

# a ggplot is only drawn when printed; without print() the SVG would be
# empty whenever the script is run via source()
svg(filename = "compASR.svg", width = 5, height = 4)
print(g2)
dev.off()


#############
## END
#############