#---------------- Siqueira et al ----------------#
#----- Evolution of fish-coral interactions -----#
#-- Reconstruction of association through time --#

#----- Loading packages -----#
library(ape)        #Version 5.7.1
library(phytools)   #Version 1.0.3
library(plyr)       #Version 1.8.7
library(tidyverse)  #Version 2.0.0
library(RevGadgets) #Version 1.0.0


#----- Setting WD -----#
# Replace "XXX" with the root folder of the project: the relative paths used
# below ("Data/...", "RevBayes/...", "R_scripts/...") are resolved against it.
wd <- "XXX"
setwd(wd)


#----- Reading data -----#
# Consensus coral-association table; text columns are read as factors.
data.all <- read.csv(
  file = "Data/Siqueira_etal_coral_association_Consensus.csv",
  header = TRUE,
  sep = ",",
  stringsAsFactors = TRUE
)

# Fix the order of the association categories (it drives legend/stack order
# in the plots), then keep only the four columns used afterwards.
data.all$Association <- factor(
  data.all$Association,
  levels = c('Non-associated', 'Moderate', 'Strong')
)
data.all <- data.all[, c("Family", "Genus", "Species", "Association")]


# Stacked barplot input: one row per Family x Association combination holding
# the number of distinct species names in that combination.
data.plot <- as.data.frame(
  dplyr::summarise(
    dplyr::group_by(data.all, Family, Association),
    distinct_species = n_distinct(Species)
  )
)

#----- Plotting Figure 1 -----#
# Horizontal stacked bars of species counts per family, coloured by
# association category. reorder(..., sum) sorts families by the sum of the
# negated counts, i.e. by their total number of species.
#pdf('Figures/Siqueira_etal_Fig1B.pdf', height = 5, width = 7, useDingbats = F)
ggplot(
  data.plot,
  aes(
    x = reorder(Family, -distinct_species, sum),
    y = distinct_species,
    fill = Association
  )
) +
  geom_bar(stat = "identity", position = "stack") +
  labs(x = '', y = 'Number of Species') +
  scale_fill_manual(values = c("#86BBD8", "#1E84AD", "#2F4858")) +
  scale_y_continuous(limits = c(0, 800), expand = c(0, 0)) +
  coord_flip() +
  theme_classic()
#dev.off()


#----- Reconstructing the number of lineages through time -----#
#----- Reading data pruned and preparing for analysis -----#
# Pruned species table; rows are indexed by species name.
data.pruned <- read.csv(
  "Data/Siqueira_etal_coral_association_sub_clean.csv",
  header = TRUE, sep = ",", stringsAsFactors = TRUE
)
data.pruned$Association <- factor(
  data.pruned$Association,
  levels = c('Non-associated', 'Moderate', 'Strong')
)
rownames(data.pruned) <- data.pruned$Species

# Values of the `Model` column, named by species.
fish.coral <- setNames(data.pruned$Model, data.pruned$Species)

# Association factor with any unused categories dropped.
fish.coral.names <- droplevels(data.pruned$Association)

# Reading the tree
tr.pruned <- read.nexus("RevBayes/data/reef_fish_wdata.nex")


#----- Reading results from RevBayes and preparing for analysis -----#
# Read and process the ancestral states from the RevBayes output.
# Numeric states 0-2 and 3-5 receive the same three labels, so each
# association category is represented by two numeric states
# (presumably the two hidden-state copies of the MuHiSSE model - confirm).
HiSSE_file <- "RevBayes/MuHiSSE/output/coral_association_MuHiSSE_coral_association_anc_states_results.tree"
p_anc <- processAncStates(
  HiSSE_file,
  state_labels = setNames(
    rep(c("Non-associated", "Moderate", "Strong"), times = 2),
    0:5
  )
)


# preparing data for reconstruction
# Coerce the posterior-probability columns and the node index to numeric, keep
# internal nodes only (node index above the number of tips) and sort by node.
anc_data <- p_anc@data %>% 
  mutate(anc_state_1_pp = as.numeric(anc_state_1_pp),
         anc_state_2_pp = as.numeric(anc_state_2_pp),
         anc_state_3_pp = as.numeric(anc_state_3_pp),
         anc_state_other_pp = as.numeric(anc_state_other_pp),
         node = as.numeric(node)) %>% 
  filter(node > Ntip(tr.pruned)) %>% 
  arrange(node) %>% as.data.frame()

# One column per internal node, one row per association category.
coral_st <- matrix(nrow = 3, ncol = nrow(anc_data), dimnames = list(levels(fish.coral.names), anc_data$node))

# Row order of coral_st.
state.order <- c("Non-associated", "Moderate", "Strong")

for (i in seq_len(nrow(anc_data))) {
  # For each category, pick the cell(s) immediately to the right of every cell
  # of this row that equals the category label (assumes each state column is
  # directly followed by its probability column - confirm against the
  # processAncStates output).
  vals <- lapply(state.order, function(st) {
    anc_data[i, which(anc_data[i, ] == st) + 1]
  })
  n.hits <- vapply(vals, length, integer(1))

  # Category not reported for this node: use the "other" probability.
  # Looping over the missing slots works for one, two or all three absent
  # categories (the previous code indexed with several positions at once in
  # the all-absent case, which is recursive indexing and fails).
  for (k in which(n.hits == 0)) {
    vals[[k]] <- anc_data[i, "anc_state_other_pp"]
  }

  # Category reported more than once (two numeric states share the label):
  # add their probabilities.
  for (k in which(n.hits > 1)) {
    vals[[k]] <- sum(vals[[k]])
  }

  coral_st[, i] <- c(vals[[1]], vals[[2]], vals[[3]])
}

#----- Sourcing functions -----#
# slice.discrete(), called below, is not defined in this script; it is expected
# to be provided by the sourced file.
source('R_scripts/Siqueira_etal_functions.R')

#----- Time-slices loop -----#
# NOTE: max_age is not referenced again by the visible code of this section;
# the slice times below are hard-coded up to 125.
max_age <- max(nodeHeights(tr.pruned))

## Assisting objects ##
# time slices
timevec <- c(seq(0.5,125,0.5))

# output: zero-row template that fixes the column names
tempvec <- data.frame (time = NA, lineage = NA, Association = NA)
slices.all <- tempvec [-1,]

## Loop to reconstruct ancestral states ##
# Each slice is stored in a pre-allocated list and everything is bound once
# after the loop, instead of re-copying the growing data frame on every pass.
slice.list <- vector("list", length(timevec))

for (s in seq_along(timevec)) {
  j <- timevec[s]

  Association <- slice.discrete (tr.pruned, j, fish.coral, fish.coral.names, coral_st)

  # Slices whose result contains any NA contribute no rows.
  if (!any(is.na(Association))) {
    slice.list[[s]] <- data.frame (time = j, lineage = names (Association), Association = as.data.frame (Association), row.names= NULL)
  }
}

slices.all <- do.call(rbind, c(list(slices.all), Filter(Negate(is.null), slice.list)))


#----- Combining data frame from extant and reconstructed lineages -----#
# Present-day states (time 0), with columns renamed to match slices.all.
data_Extant <- data.pruned %>% 
  mutate(time = 0) %>% 
  select(time, Species, Association)
oldnames <- colnames(data_Extant)
newnames <- colnames(slices.all)
data_Extant <- data_Extant %>% rename_with(~ newnames, .cols = all_of(oldnames))

# Single root entry placed at the maximum node height of the tree.
origin <- data.frame(time = max(nodeHeights(tr.pruned)), lineage = 'Root', Association = 'Non-associated')

slices.combined <- rbind(slices.all, data_Extant, origin)

# One data frame per time value.
slices <- split(slices.combined, slices.combined$time)

# For every time value: count lineages per association category and return a
# character matrix with the count, the time and the category name.
slices.coral <- lapply(unname(slices), function(one.slice) {
  counts <- as.matrix(table(as.character(one.slice$Association)))
  out <- cbind(counts, rep(unique(one.slice$time), nrow(counts)), rownames(counts))
  colnames(out) <- c("NLineages", "TimeSlice","Association")
  out
})


#----- Plotting LTT - Figure 2 -----#
# Stack the per-slice matrices into one data frame, then convert the text
# columns: counts and times to numeric, category to an ordered-level factor.
df <- ldply(slices.coral, data.frame, stringsAsFactors = FALSE)

df.slice.ori <- df
df.slice.ori$Association <- factor(
  df.slice.ori$Association,
  levels = c('Non-associated', 'Moderate', 'Strong')
)
df.slice.ori$NLineages <- as.numeric(df.slice.ori$NLineages)
df.slice.ori$TimeSlice <- as.numeric(df.slice.ori$TimeSlice)


#### Coral lineage through time
# Read in and filter a set of 100 coral trees from Huang et al 2017
trees <- read.nexus("Data/Huang_etal_2017_Scleractinia.tre")
data.huang <- read.csv("Data/Huang_etal_2017_CoralSpecies.csv", header = TRUE, sep = ",", stringsAsFactors = TRUE)

trees <- trees[1:100]
class(trees) <- "multiPhylo"

# Lineage through time for Scleractinia
ltt.scle <- ltt(trees, log.lineages = FALSE, plot = FALSE)

# Largest node height found in any of the trees (vapply guarantees one
# numeric value per tree regardless of tree size).
max.age.scle <- max(vapply(trees, function(tr) max(nodeHeights(tr)), numeric(1)))

# One data frame per tree, bound in a single step. Each tree's times are
# shifted so that its latest time equals max.age.scle, which makes all curves
# end at the same point.
ltt.coral.scle <- do.call(rbind, lapply(seq_along(ltt.scle), function(i) {
  df <- data.frame(tree = rep(i, length(ltt.scle[[i]]$ltt)), lineages = ltt.scle[[i]]$ltt, time = ltt.scle[[i]]$times)
  age.corr <- max.age.scle - max(df$time)
  df$time <- df$time + age.corr
  df
}))

# Scleractinia fossil data from Siqueira et al 2022
scleractinia_fossils <- read.csv(
  "Data/scleractinia_fossil_ltt.txt",
  header = TRUE, sep = "\t", stringsAsFactors = TRUE
)

## Plotting
ymax <- 1800   # upper limit of the y axis (also read when building the epoch boxes)

# Base plot restricted to the last 66 time units; x is time before present,
# expressed as negative numbers. Layers in order: stacked areas of fish
# lineages per association, one line per coral tree, the fossil diversity
# curve, and an asterisk annotation.
lttP <- ltt.coral.scle %>%
  filter(time > max(ltt.coral.scle$time) - 66) %>%
  ggplot(aes(x = time - max(time), y = lineages)) +
  geom_area(
    data = df.slice.ori[df.slice.ori$TimeSlice < 66, ],
    mapping = aes(x = -TimeSlice, y = NLineages, fill = Association),
    linewidth = .1, alpha = 1, inherit.aes = FALSE
  ) +
  geom_line(aes(group = factor(tree)), alpha = 0.2, colour = '#F7AF0F') +
  geom_line(
    data = scleractinia_fossils,
    mapping = aes(x = -time, y = diversity),
    linewidth = 1, colour = "#F4D967"
  ) +
  annotate(geom = "text", x = -20.5, y = 320, label = "*", size = 7) +
  scale_x_continuous(
    limits = c(-66, 0), expand = c(0, 0),
    breaks = seq(-60, 0, 10), labels = rev(seq(0, 60, 10))
  ) +
  scale_y_continuous(limits = c(0, ymax), expand = c(0, 0), position = 'right') +
  theme_bw() +
  theme(
    panel.grid.major = element_blank(),
    panel.grid.minor = element_blank()
  )

# creating data frame with geologic epochs
# One row per epoch: horizontal extent (negative time), vertical extent
# (0 to ymax), fill colour, and the position of the text label.
epochs <- data.frame(
  epoch = c("Paleocene", "Eocene", "Oligocene", "Miocene", "Pli.", "Ple."),
  xmin = c(-66, -56, -33.9, -23, -5.3, -2.58),
  xmax = c(-56, -33.9, -23, -5.3, -2.58, 0),
  ymin = 0,
  ymax = ymax,
  colr = c("gray88", "#A0A0A0", "gray88", "#A0A0A0", "gray88", "#A0A0A0"),
  stringsAsFactors = FALSE
)
epochs$xlab <- (epochs$xmin + epochs$xmax) / 2   # horizontal centre of each box
epochs$ylab <- epochs$ymax                        # label sits at the top edge

# adding epochs to plot
# A single fill scale serves both the epoch boxes and the association areas.
lttP2 <- lttP +
  geom_rect(
    data = epochs,
    mapping = aes(xmin = xmin, xmax = xmax, ymin = ymin, ymax = ymax, fill = epoch),
    colour = "white", alpha = 0.3, inherit.aes = FALSE
  ) +
  scale_fill_manual(
    values = c(
      setNames(epochs$colr, epochs$epoch),
      setNames(c("#86BBD8", "#1E84AD", "#2F4858"), levels(df.slice.ori$Association))
    ),
    guide = "none"
  ) +
  labs(x = "Time (Ma)", y = "Number of lineages")

# send the epoch box layer to the back of the plot
# (the rectangles are the 5th layer: four layers come from lttP)
lttP2$layers <- c(lttP2$layers[5], lttP2$layers[-5])

# add epoch label to plot
lttP3 <- lttP2 +
  geom_text(
    data = epochs,
    mapping = aes(x = xlab, y = ylab, label = epoch),
    size = 3, vjust = 1.3, hjust = 0.5, inherit.aes = FALSE
  )

#pdf('Figures/Siqueira_etal_Fig2.pdf', height = 5, width = 7, useDingbats = F)
lttP3
#dev.off()


### Inset Figure 2
# Taxa listed under family Acroporidae in the coral species table.
acro <- data.huang %>%
  filter(Family == 'Acroporidae') %>% select(Family, Species, Taxon)

# Prune every tree to those tips. Trees are extracted with `[[` so that the
# multiPhylo extraction method is used, as in the original index loop.
acro.tips <- as.character(droplevels(acro$Taxon))
tr.acro <- lapply(seq_along(trees), function(i) keep.tip(trees[[i]], acro.tips))

class(tr.acro) <- "multiPhylo"

# Stored as ltt.acro (not `ltt`) so the phytools function ltt() is not masked
# by a data object of the same name.
ltt.acro <- ltt(tr.acro, log.lineages = FALSE, plot = FALSE)

# Largest node height found in any of the pruned trees.
max.age.acro <- max(vapply(tr.acro, function(tr) max(nodeHeights(tr)), numeric(1)))

# One data frame per tree, bound in a single step; times are shifted so that
# every tree's latest time equals max.age.acro.
ltt.coral.acro <- do.call(rbind, lapply(seq_along(ltt.acro), function(i) {
  df <- data.frame(tree = rep(i, length(ltt.acro[[i]]$ltt)), lineages = ltt.acro[[i]]$ltt, time = ltt.acro[[i]]$times)
  age.corr <- max.age.acro - max(df$time)
  df$time <- df$time + age.corr
  df
}))


# Fish lineages excluding the non-associated category, scaled by the single
# largest count remaining in the table.
df.slice.rel <- df.slice.ori %>%
  filter(Association != "Non-associated")
df.slice.rel$relative <- df.slice.rel$NLineages / max(df.slice.rel$NLineages)

# Coral trees: convert time to "time before the latest point", keep values
# below 66, and scale each tree by its own maximum within that window.
# The result stays grouped by tree.
ltt.coral.acro$time <- max(ltt.coral.acro$time) - ltt.coral.acro$time
ltt.coral.acro <- ltt.coral.acro %>%
  filter(time < 66) %>%
  group_by(tree) %>%
  mutate(relative = lineages / max(lineages))

# Acroporidae fossil data from Siqueira et al 2022
# Same window, scaled by the maximum diversity inside it.
acroporidae_fossils <- read.csv(
  "Data/acroporidae_fossils_ltt.txt",
  header = TRUE, sep = "\t", stringsAsFactors = TRUE
) %>%
  filter(time < 66) %>%
  mutate(relative = diversity / max(diversity))

## Plotting
# Inset: relative number of coral-associated fish lineages (stacked areas),
# relative lineage counts of each Acroporidae tree (thin lines) and the
# relative fossil diversity curve (thick line) on a shared time axis.
ggplot(df.slice.rel, aes(x = -TimeSlice, y = relative, fill = Association)) +
  geom_area(colour = "black", linewidth = .1, alpha = 1) +
  geom_line(
    data = ltt.coral.acro,
    mapping = aes(x = -time, y = relative, group = factor(tree)),
    alpha = 0.2, colour = '#F15E22', inherit.aes = FALSE
  ) +
  geom_line(
    data = acroporidae_fossils,
    mapping = aes(x = -time, y = relative),
    linewidth = 1, colour = "#F4A971", inherit.aes = FALSE
  ) +
  xlab("Time (Ma)") +
  ylab("Proportion of lineages") +
  scale_x_continuous(
    limits = c(-66, 0), expand = c(0, 0),
    breaks = seq(-60, 0, 10), labels = rev(seq(0, 60, 10))
  ) +
  scale_y_continuous(limits = c(0, 1.35), expand = c(0, 0), position = 'right') +
  scale_fill_manual(
    values = c("#1E84AD", "#2F4858"),
    breaks = c("Moderate", "Strong"),
    guide = "none"
  ) +
  theme_bw() +
  theme(
    text = element_text(size = 14),
    plot.background = element_rect(fill = "transparent", colour = NA)
  )

#ggsave("Figures/Siqueira_etal_Fig2_inset.pdf")


#----- Performing simulations -----#
# For each replicate: simulate a three-state trait on the fish tree using the
# observed category frequencies, reconstruct ancestral states with ace()
# (model "ER"), slice the reconstruction through time and count lineages per
# category. Per-replicate tables are stored in a list and bound once at the end.
# NOTE: no set.seed() call is visible in this script, so simulated traits will
# differ between runs unless the RNG is seeded elsewhere.
# NOTE: fish.coral and fish.coral.names are overwritten with simulated values
# inside the loop, as in the original code.

nboot <- 100

## Quantities that are identical in every replicate ##
max_age <- max(nodeHeights(tr.pruned))

# time slices
timevec <- c(seq(0.5,125,0.5))

# zero-row template that fixes the column names of the slice table
tempvec <- data.frame (time = NA, lineage = NA, Association = NA)

# observed proportion of species in each association category
state.counts <- summary(data.pruned$Association)
state.freq <- state.counts/sum(state.counts)

# single root entry at the maximum node height of the tree
origin <- data.frame(time = max_age, lineage = 'Root', Association = 'Non-associated')

boot.slices <- vector("list", nboot)

for(i in seq_len(nboot)) {
  ##Loop starts here
  
  #----- Simulating data -----#
  sim <- rTraitDisc(tr.pruned, k = 3, 
                    freq = state.freq,
                    states = levels(data.pruned$Association))
  
  fish.coral <- as.character(droplevels(sim)); names(fish.coral) <- names(sim)
  
  fish.coral.names <- droplevels(sim)
  
  
  #----- Making simulated ancestral reconstruction -----#
  fit <- ace(fish.coral,tr.pruned,model="ER",type="discrete")
  
  fit_st <- t(fit$lik.anc[,c("Non-associated","Moderate","Strong")]) #Reordering columns
  
  
  #----- Time-slices loop -----#
  # Slices whose result contains any NA contribute no rows.
  slice.list <- vector("list", length(timevec))
  
  for (s in seq_along(timevec)) {
    j <- timevec[s]
    
    Association <- slice.discrete (tr.pruned, j, fish.coral, fish.coral.names, fit_st)
    
    if (!any(is.na(Association))) {
      slice.list[[s]] <- data.frame (time = j, lineage = names (Association), Association = as.data.frame (Association), row.names= NULL)
    }
  }
  
  slices.all <- do.call(rbind, c(list(tempvec [-1,]), Filter(Negate(is.null), slice.list)))
  
  
  #----- Combining data frame -----#
  # Tip states at time 0 are the simulated ones, not the observed ones.
  data_Extant <- data.pruned %>% 
    mutate(time = rep(0,length(data.pruned$Species)),
           Association = sim) %>% 
    select(time, Species, Association)
  
  oldnames <- colnames(data_Extant)
  newnames <- colnames(slices.all)
  data_Extant <- data_Extant %>% rename_at(vars(all_of(oldnames)), ~ newnames)
  
  slices.combined <- rbind(slices.all,data_Extant,origin)
  
  slices <- slices.combined %>% split(slices.combined$time)
  
  # Lineage counts per category for every time value, as character matrices.
  slices.coral <- lapply(unname(slices), function(one.slice) {
    mat <- as.matrix(table(as.character(one.slice$Association)))
    out <- cbind(mat,rep(unique(one.slice$time), nrow(mat)),rownames(mat))
    colnames(out) <- c("NLineages", "TimeSlice","Association")
    out
  })
  
  
  # LTT data
  df <- ldply (slices.coral, data.frame, stringsAsFactors = FALSE)
  df.slice <- df %>%  
    mutate(Association = factor(Association , levels=c('Non-associated','Moderate','Strong')),
           NLineages = as.numeric(NLineages),
           TimeSlice = as.numeric(TimeSlice))
  
  
  df.slice <- df.slice[df.slice$TimeSlice < 66,]
  
  df.slice$boot <- rep(i,nrow(df.slice))
  
  boot.slices[[i]] <- df.slice
}

df.slice.all <- do.call(rbind, boot.slices) %>% 
  mutate(Association = factor(Association , levels=c('Non-associated','Moderate','Strong')))

#write.csv(df.slice.all, "Data/simulation.csv", row.names = F)

## Getting pre-cooked results - uncomment this to get the results presented in the paper
#df.slice.all <- read.csv("Data/simulation.csv", header=T, sep = ",", stringsAsFactors = T)


#----- Plotting simulation - Figure 3 -----#
## Organising data for plotting
# Reconstructed lineage counts per category and time slice (slices below 66).
dat <- df.slice.ori[df.slice.ori$TimeSlice < 66,] %>% 
  group_by(Association, TimeSlice) %>% 
  summarise(NLineages = mean(NLineages)) %>% 
  as.data.frame()

# Mean simulated lineage count per category and time slice across replicates.
df.slice.mean <- df.slice.all %>% 
  mutate(Association = factor(Association , levels=c('Non-associated','Moderate','Strong'))) %>% 
  group_by(Association, TimeSlice) %>% 
  summarise(NLineages_sim = mean(NLineages)) %>% 
  as.data.frame()

# Reconstructed vs simulated, with their difference computed row by row.
# The difference used to be copied by position from a separately grouped (and
# therefore re-sorted) table; full_join() appends combinations that exist only
# in the simulations after all matched rows, so the two tables were not in the
# same order and values could land on the wrong rows.
dff <- dat %>% 
  full_join(df.slice.mean, by = c("Association", "TimeSlice")) %>% 
  mutate(diff = NLineages - NLineages_sim)

# Kept as a stand-alone table (sorted by category and time slice) for any
# code that still refers to it.
dff_diff <- dff %>% 
  select(Association, TimeSlice, diff) %>% 
  arrange(Association, TimeSlice) %>% 
  as.data.frame()

# Row(s) with the largest absolute difference within each category.
dff_diff_max <- dff %>%
  group_by(Association) %>% 
  filter(abs(diff) == max(abs(diff), na.rm = TRUE))

## Plotting
ymax <- 1300   # upper limit of the y axis (also read when building the epoch boxes)

# Layers in order: coral tree curves, ribbon between reconstructed and mean
# simulated counts, one line per replicate and category, mean simulated curve
# (white), reconstructed curve (dashed black), and double-headed arrows at the
# rows held in dff_diff_max.
lttP.sim <- ggplot(data = df.slice.all, aes(x = -TimeSlice, y = NLineages)) +
  geom_line(
    data = ltt.coral.scle[ltt.coral.scle$time > max(ltt.coral.scle$time) - 66, ],
    mapping = aes(x = time - max(time), y = lineages, group = factor(tree)),
    alpha = 0.2, colour = '#F7AF0F', inherit.aes = FALSE
  ) +
  geom_ribbon(
    data = dff,
    mapping = aes(ymin = NLineages, ymax = NLineages_sim, fill = Association),
    alpha = 0.7
  ) +
  geom_line(
    mapping = aes(group = interaction(boot, Association), colour = Association),
    alpha = 0.5
  ) +
  geom_line(
    data = df.slice.mean,
    mapping = aes(x = -TimeSlice, y = NLineages_sim, group = Association),
    colour = "white", inherit.aes = FALSE
  ) +
  geom_line(
    data = dat,
    mapping = aes(x = -TimeSlice, y = NLineages, group = Association),
    linetype = "dashed", linewidth = 0.8, colour = "black", inherit.aes = FALSE
  ) +
  geom_segment(
    data = dff_diff_max,
    mapping = aes(x = -TimeSlice, y = NLineages, xend = -TimeSlice, yend = NLineages_sim, group = Association),
    arrow = arrow(length = unit(0.15, "cm"), type = "closed", ends = "both"),
    linewidth = 0.8, inherit.aes = FALSE
  ) +
  scale_color_manual(
    values = setNames(
      c("#86BBD8", "#1E84AD", "#2F4858"),
      c("Non-associated", "Moderate", "Strong")
    ),
    guide = "none"
  ) +
  scale_x_continuous(
    limits = c(-66, 0), expand = c(0, 0),
    breaks = seq(-60, 0, 10), labels = rev(seq(0, 60, 10))
  ) +
  scale_y_continuous(limits = c(0, ymax), expand = c(0, 0), position = 'right') +
  theme_bw() +
  theme(
    panel.grid.major = element_blank(),
    panel.grid.minor = element_blank()
  )

# creating data frame with geologic epochs
# One row per epoch: horizontal extent (negative time), vertical extent
# (0 to ymax), fill colour, and the position of the text label.
epochs <- data.frame(
  epoch = c("Paleocene", "Eocene", "Oligocene", "Miocene", "Pli.", "Ple."),
  xmin = c(-66, -56, -33.9, -23, -5.3, -2.58),
  xmax = c(-56, -33.9, -23, -5.3, -2.58, 0),
  ymin = 0,
  ymax = ymax,
  colr = c("gray88", "#A0A0A0", "gray88", "#A0A0A0", "gray88", "#A0A0A0"),
  stringsAsFactors = FALSE
)
epochs$xlab <- (epochs$xmin + epochs$xmax) / 2   # horizontal centre of each box
epochs$ylab <- epochs$ymax                        # label sits at the top edge

# add epochs to plot
# A single fill scale serves both the epoch boxes and the ribbons.
lttP2.sim <- lttP.sim +
  geom_rect(
    data = epochs,
    mapping = aes(xmin = xmin, xmax = xmax, ymin = ymin, ymax = ymax, fill = epoch),
    colour = "white", alpha = 0.3, inherit.aes = FALSE
  ) +
  scale_fill_manual(
    values = c(
      setNames(epochs$colr, epochs$epoch),
      setNames(
        c("#F6511D", "#A9D877", "#A9D877"),
        c("Non-associated", "Moderate", "Strong")
      )
    ),
    guide = "none"
  ) +
  labs(x = "Time (Ma)", y = "Number of lineages")

# send the epoch box layer to the back of the plot (looks better)
# (the rectangles are the 7th layer: six layers come from lttP.sim)
lttP2.sim$layers <- c(lttP2.sim$layers[7], lttP2.sim$layers[-7])

# add epoch label to plot
lttP3.sim <- lttP2.sim +
  geom_text(
    data = epochs,
    mapping = aes(x = xlab, y = ylab, label = epoch),
    size = 3, vjust = 1.3, hjust = 0.5, inherit.aes = FALSE
  )

#pdf('Figures/Siqueira_etal_Fig3.pdf', height = 5, width = 7, useDingbats = F)
lttP3.sim
#dev.off()
#dev.off()

####################End of the script###########################