# This file contains the source code for the asynchronous generations coevolutionary model. Functions are as follows:
# 
# mismatch_single_gen - handles events for each time step. This includes interactions, reproduction, and death. 
#
# mismatch_n_generations - calls the single generation function n times where n is the number of time steps the simulation
# will run for.
#
# mismatch_build_output - builds a dataframe to store the summary output for a single simulation. Has n rows, one
# for each time step, and one column for each summary statistic. Also includes columns for age structured summaries, up
# to age class 500.
#
# aged_summaries - calculates the summary stats for each population as well as age structured summary stats if either has
# mortality lower than 1. 
#
# interact - handles the random pairing of individuals for pairwise interactions. Calls the offset match function
#
# offset_match - function for calculating fitness under the offset match mechanism. Returns fitness for each partner in a 
# given interaction.
#
# marked_for_death - determines which individuals will die if either species has mortality lower than 1. Marks those 
# individuals in the population level dataframe.
#
# stochastic_death_mate - mating function for stochastic death model. Sets number of new offspring equal to number of 
# individuals marked for death.

# stochastic_death - removes individuals that were marked, adds +1 buff to age of all survivors.


### Loading in libraries ###

library(dplyr) # data management tools
library(tidyr) # more data management tools



####### Population creation for single pop tests #########



# 
# test_df <- data.frame(trait=rnorm(100,1,1),
#                       fit=rnorm(100,4,0.2),
#                       age=c(rep(1,25),rep(2,25), rep(3,25),rep(4,25))
# )
# 
# 
# 
# 

# This function builds a data frame for summary stats of both populations at each generation. It is, in general,
# good to preallocate memory in R particularly for large iterative processes like simulations, so this is created 
# at the start of each simulation. It has a row for each generation and a column for each summary statistic. When
# mortality for either population is below 1 (so any time we have more than one age class co-existing), this also
# creates columns for the same summary statistics but structured by age-class. Since the number of age classes
# varies stochastically based on mortality rates, I opt to create a very large number of age-classes in these summary
# tables (500 age classes). This is likely an unnecessarily large number for most simulations, but I opted to be conservative
# in case any simulation happened to have an individual who really beat the odds and lived beyond our expectations.

mismatch_build_output <- function(n, mort) {
  # Preallocate the per-generation summary tables for both species.
  #
  # n    - number of time steps; each table gets one row per step.
  # mort - mortality values, plants first and pollinators second. Read with
  #        [[ ]] so that both a numeric vector and a list work.
  #
  # Returns list(plants = <data.frame>, polls = <data.frame>). A species whose
  # mortality is below 1 additionally gets one zero-filled column per
  # (statistic, age class) pair, named "A<age>_<statistic>".

  n_age_classes <- 500

  base_frame <- data.frame(
    generation = seq_len(n),
    trait_means = numeric(n),
    trait_sd = numeric(n),
    int_means = numeric(n),
    int_sd = numeric(n),
    rel_fit_mean = numeric(n),
    rel_fit_sd = numeric(n),
    sd_ratio = numeric(n),
    pop_size = numeric(n),
    fit_corr = numeric(n),
    rel_fit_corr = numeric(n),
    sel_differential = numeric(n),
    non_rel_sel_diff = numeric(n)
  )

  stat_names <- names(base_frame)[-1]

  add_age_columns <- function(frame) {
    # Age must vary fastest within each statistic (A1_x, A2_x, ..., A500_x,
    # A1_y, ...): aged_summaries() emits every age class for one statistic
    # before moving on to the next, and the rows are filled by position.
    name_grid <- expand.grid(
      age = seq_len(n_age_classes),
      stat = stat_names,
      stringsAsFactors = FALSE
    )
    age_cols <- data.frame(matrix(data = 0, nrow = n, ncol = nrow(name_grid)))
    names(age_cols) <- paste0("A", name_grid$age, "_", name_grid$stat)
    cbind(frame, age_cols)
  }

  plant_data <- base_frame
  poll_data <- base_frame

  if (mort[[1]] < 1) {
    plant_data <- add_age_columns(plant_data)
  }
  if (mort[[2]] < 1) {
    poll_data <- add_age_columns(poll_data)
  }

  list(plants = plant_data, polls = poll_data)
}



# testing the function with 10 generations and plant mortality of 0.5.

#mismatch_build_output(10, c(0.5,1))

# This function calculates the actual summary stats for each species at a given timestep, including the selection
# differential. There are also summary statistics collected here that are not included in the manuscript. The function
# takes the population dataframe for a single species and the mortality value for that parameterization. Since mortality
# is stochastic, some age classes may have no individuals in them at a given time step (i.e. a given cohort may die out while
# older individuals survive). To account for this, I check if there are any age classes with length 0, and if so, I create
# dummy individuals with trait values of 0. This is just to keep output vectors the correct size for the summary 
# data frame, otherwise it throws errors because of mismatched lengths etc.

aged_summaries <- function(df, mort_val) {
  # Summary statistics for one species at one time step.
  #
  # df       - population data frame with numeric columns trait, fit and age.
  # mort_val - mortality of this species; values below 1 switch on the
  #            age-structured part of the output.
  #
  # Returns list(full_sums, age_distro):
  #   full_sums  - 12 population-level statistics, followed (when
  #                mort_val < 1) by 12 blocks of 500 values, one block per
  #                statistic and one value per age class 1..500.
  #   age_distro - table() of the ages present in df.
  #
  # Age-class statistics are computed from the individuals of that class only.
  # An empty class reports mean 0 and size 0; standard deviations and
  # correlations are NA for classes with fewer than two individuals. The
  # sd ratio and both selection differentials are population-level values
  # repeated for every class.

  # Number of age classes reported; must match the summary table width.
  max_age <- 500L

  df$relative_fit <- df$fit / mean(df$fit)

  # Selection differentials: slope of (relative) fitness on trait value.
  sel_differential <- unname(lm(relative_fit ~ trait, data = df)$coefficients[2])
  non_rel_sel_diff <- unname(lm(fit ~ trait, data = df)$coefficients[2])

  pop_trait_sd <- sd(df$trait)
  pop_rel_fit_sd <- sd(df$relative_fit)
  sd_ratio <- pop_trait_sd / pop_rel_fit_sd

  pop_sums <- c(
    mean(df$trait), pop_trait_sd,
    mean(df$fit), sd(df$fit),
    mean(df$relative_fit), pop_rel_fit_sd,
    sd_ratio, nrow(df),
    cor(df$trait, df$fit), cor(df$trait, df$relative_fit),
    sel_differential, non_rel_sel_diff
  )
  age_distro <- table(df$age)

  if (!(mort_val < 1)) {
    return(list(full_sums = pop_sums, age_distro = age_distro))
  }

  # Fixing the factor levels keeps empty age classes in the split, so the
  # output always has exactly max_age entries per statistic.
  age_class <- factor(df$age, levels = seq_len(max_age))
  if (anyNA(age_class)) {
    stop(
      "aged_summaries(): ages must be whole numbers between 1 and ", max_age,
      "; widen the summary table if individuals can grow older.",
      call. = FALSE
    )
  }
  rows_by_age <- split(seq_len(nrow(df)), age_class)

  per_age <- function(stat) {
    vapply(rows_by_age, stat, numeric(1), USE.NAMES = FALSE)
  }
  class_mean <- function(x) {
    per_age(function(rows) if (length(rows) > 0L) mean(x[rows]) else 0)
  }
  class_sd <- function(x) {
    per_age(function(rows) if (length(rows) > 1L) sd(x[rows]) else NA_real_)
  }
  class_cor <- function(x, y) {
    per_age(function(rows) {
      if (length(rows) > 1L) cor(x[rows], y[rows]) else NA_real_
    })
  }

  full_sums <- c(
    pop_sums,
    class_mean(df$trait), class_sd(df$trait),
    class_mean(df$fit), class_sd(df$fit),
    class_mean(df$relative_fit), class_sd(df$relative_fit),
    rep(sd_ratio, max_age),
    lengths(rows_by_age, use.names = FALSE),
    class_cor(df$trait, df$fit), class_cor(df$trait, df$relative_fit),
    rep(sel_differential, max_age),
    rep(non_rel_sel_diff, max_age)
  )

  list(full_sums = full_sums, age_distro = age_distro)
}


# Here is the actual offset fitness function. This is vectorised such that it can take the full column of each parameter 
# from the population data frame. This is faster but slightly more difficult to read as coded. This returns both fitness
# functions from a given interaction.

offset_match <- function(beta_plant, beta_poll, off_plant, off_poll, b_plant, b_poll, plant, poll) {
  # Offset-match fitness for paired plant and pollinator traits (vectorised).
  #
  # Each partner's fitness is a Gaussian function of the squared distance
  # between its own trait and the partner's trait shifted by that partner's
  # offset, scaled by b. Returns list(fit_plant, fit_poll).

  scaled_gaussian <- function(b, beta, sq_distance) {
    b * exp((-beta / 2) * sq_distance)
  }

  list(
    fit_plant = scaled_gaussian(b_plant, beta_plant, (poll + off_plant - plant)^2),
    fit_poll = scaled_gaussian(b_poll, beta_poll, (plant + off_poll - poll)^2)
  )
}

# this function handles the randomized interactions between individuals. The pop and inds parameters here
# are for ensuring interactions are randomized across each population. The inds lines just generate a randomized
# permutation of the indices for each population and the pop objects just say how large those permutations should be.

interact <- function(dfs, param_df) {
  # Pair plants and pollinators at random and record each partner's fitness.
  #
  # dfs      - list with data frames `plant` and `poll`, each holding a `trait`
  #            column; their `fit` column is reset and then overwritten.
  # param_df - parameter table with rows "plant" and "poll" and columns
  #            beta, offset and b.
  #
  # Returns dfs with the updated `fit` columns.

  plant_pars <- param_df["plant", ]
  poll_pars <- param_df["poll", ]

  dfs$plant$fit <- 0
  dfs$poll$fit <- 0

  # A random permutation of each population decides who meets whom: the k-th
  # shuffled plant interacts with the k-th shuffled pollinator.
  plant_order <- sample.int(n = length(dfs$plant$trait))
  poll_order <- sample.int(n = length(dfs$poll$trait))

  pair_fitness <- offset_match(
    beta_plant = plant_pars$beta,
    beta_poll = poll_pars$beta,
    off_plant = plant_pars$offset,
    off_poll = poll_pars$offset,
    b_plant = plant_pars$b,
    b_poll = poll_pars$b,
    plant = dfs$plant$trait[plant_order],
    poll = dfs$poll$trait[poll_order]
  )

  # Write the fitness values back to the rows they were computed for.
  dfs$plant$fit[plant_order] <- pair_fitness$fit_plant
  dfs$poll$fit[poll_order] <- pair_fitness$fit_poll

  return(dfs)
}






# This function handles simulating over n generations. It builds the final output summary dataframe using
# mismatch_build_output, and initializes some other lists/vectors for recording information and output. It then loops
# through n generations, calling the single-generation function (below), separating plant and poll outputs, and placing
# the results of each generation into its relevant row in the output dataframes. Lastly, it combines the two results
# dataframes into a single dataframe along with the number of new plants (i.e. plant deaths) at each time step, and
# calculates the slope of the mean trait value over time for each species.

mismatch_n_generations <- function(n_generations = 100, spp_info, params_df, Mort_list, burn_in = 2000) {
  # Run the simulation for n_generations time steps and collect the output.
  #
  # n_generations - number of time steps to simulate.
  # spp_info      - starting populations, passed to mismatch_single_gen() as dfs.
  # params_df     - parameter table, passed to mismatch_single_gen().
  # Mort_list     - mortality values for both species, passed on unchanged.
  # burn_in       - only generations after this one are used for the trait
  #                 slopes (default 2000, the value previously hard-coded).
  #
  # Returns list(big_df, age_distributions). big_df holds the plant and
  # pollinator summary tables side by side (prefixed "plants." / "polls."),
  # the number of new plants per step (demo_change) and the two slope columns.
  # The slopes are NA when fewer than two generations lie beyond burn_in,
  # which previously made lm() fail for any run of 2000 steps or fewer.

  output_frames <- mismatch_build_output(n_generations, Mort_list)
  plant_results <- output_frames$plants
  poll_results <- output_frames$polls

  demo_change_plant <- numeric(n_generations)
  age_distributions <- list(
    plant = vector("list", n_generations),
    poll = vector("list", n_generations)
  )

  for (i in seq_len(n_generations)) {
    one_output <- mismatch_single_gen(dfs = spp_info, param_df = params_df, Mort_list = Mort_list)

    # Summary vectors are written by position into everything but the
    # generation column.
    plant_results[i, -1] <- one_output$plant_summ
    poll_results[i, -1] <- one_output$poll_summ

    age_distributions$plant[[i]] <- one_output$plant_age_distro
    age_distributions$poll[[i]] <- one_output$poll_age_distro

    # The populations returned by this step are the input of the next one.
    spp_info <- one_output$dfs

    demo_change_plant[i] <- one_output$new_plants
  }

  big_df <- do.call(
    cbind.data.frame,
    list(plants = plant_results, polls = poll_results, demo_change = demo_change_plant)
  )

  post_burn_in <- big_df[which(big_df$plants.generation > burn_in), , drop = FALSE]

  if (nrow(post_burn_in) < 2L) {
    plant_slope <- NA_real_
    poll_slope <- NA_real_
  } else {
    plant_mod <- lm(plants.trait_means ~ plants.generation, data = post_burn_in)
    poll_mod <- lm(polls.trait_means ~ polls.generation, data = post_burn_in)
    plant_slope <- unname(plant_mod$coefficients[2])
    poll_slope <- unname(poll_mod$coefficients[2])
  }

  big_df$plant_slope <- rep(plant_slope, nrow(big_df))
  big_df$poll_slope <- rep(poll_slope, nrow(big_df))

  list(big_df = big_df, age_distributions = age_distributions)
}


# This function handles a single generation of the life cycle for both species. It handles interactions, aged summaries,
# both mortality and reproduction, merging of offspring data and survivor data, then returns the pop dfs and summary vectors.


mismatch_single_gen <- function(dfs, param_df, Mort_list) {
  # One time step of the life cycle for both species.
  #
  # Order of events: interact -> summarise -> mark deaths -> mate ->
  # remove the dead and age the survivors -> combine offspring and survivors.
  #
  # Returns the updated population data frames (dfs), the summary vector and
  # age distribution of each species, and the number of new plants.

  # Builds the newborn cohort: age 1, fitness not yet assigned.
  new_cohort <- function(traits, n_new) {
    data.frame(trait = traits, age = rep(1, n_new), fit = numeric(n_new))
  }

  dfs <- interact(dfs, param_df)

  plant_stats <- aged_summaries(dfs$plant, Mort_list$plant)
  poll_stats <- aged_summaries(dfs$poll, Mort_list$poll)

  dfs$plant <- marked_for_death(dfs$plant, Mort_list$plant)
  dfs$poll <- marked_for_death(dfs$poll, Mort_list$poll)

  # Every individual marked for death is replaced by one offspring.
  n_new_plants <- sum(dfs$plant$marked)
  n_new_polls <- sum(dfs$poll$marked)

  # Mating happens before removal, so marked individuals still reproduce.
  plant_offspring <- new_cohort(
    stochastic_death_mate(df = dfs$plant, additive_variance = param_df["plant", ]$additive_variance),
    n_new_plants
  )
  poll_offspring <- new_cohort(
    stochastic_death_mate(df = dfs$poll, additive_variance = param_df["poll", ]$additive_variance),
    n_new_polls
  )

  if (Mort_list$plant < 1) {
    dfs$plant <- rbind(plant_offspring, stochastic_death(dfs$plant))
  } else {
    dfs$plant <- plant_offspring
  }

  if (Mort_list$poll < 1) {
    dfs$poll <- rbind(poll_offspring, stochastic_death(dfs$poll))
  } else {
    dfs$poll <- poll_offspring
  }

  list(
    dfs = dfs,
    plant_summ = plant_stats$full_sums,
    poll_summ = poll_stats$full_sums,
    new_plants = n_new_plants,
    plant_age_distro = plant_stats$age_distro,
    poll_age_distro = poll_stats$age_distro
  )
}
  











# This function handles mating when we have stochastic death and constant population sizes. It counts the number of 
# individuals marked for death (N_birth) and generates offspring trait values for that many new individuals.
# I toyed initially with a possible alternative mechanism for the demographic constraint on trait change - 
# looking at how much the pop trait changes moving from parents to an equivalently sized offspring pool, then the
# change from this large offspring pool to a very small sample of offspring equivalent to the number of deaths.
# It's not an overly interesting metric, but I've left the code here for it if anyone wants to examine it.



stochastic_death_mate <- function(df, additive_variance) {
  # Offspring trait values for the individuals that will replace the dead.
  #
  # df                - population data frame with columns trait, fit and
  #                     marked (TRUE for individuals that are about to die).
  # additive_variance - additive genetic variance; offspring deviate from the
  #                     mid-parent value with sd sqrt(additive_variance) / 2.
  #
  # Returns a numeric vector with one trait value per marked individual.
  # A full-sized offspring pool is generated first and the required number of
  # offspring is then drawn from it without replacement.

  n_birth <- sum(df$marked)
  size <- length(df$trait)

  # An empty population has no parents and needs no offspring.
  if (size == 0L) {
    return(numeric(0))
  }

  offspring_pool <- numeric(size)

  for (i in seq_len(size)) {
    # Two distinct parents, chosen with probability proportional to fitness.
    parents <- sample.int(size, 2, replace = FALSE, prob = df$fit)
    mid_parent <- (df$trait[parents[1]] + df$trait[parents[2]]) / 2
    offspring_pool[i] <- mid_parent + rnorm(1, 0, (sqrt(additive_variance) / 2))
  }

  # Index-based draw: sample(x, n) would treat a length-one x as 1:x.
  offspring_pool[sample.int(size, n_birth)]
}




# This function marks individuals for death based on a constant mortality rate mort and random numbers generated between
# 0 and 1. If the random value drawn is less than the mortality rate, then an individual is marked for death and will
# be removed from the population after they have reproduced. 

marked_for_death <- function(spp_df, mort) {
  # Flag the individuals that die this time step.
  #
  # spp_df - population data frame, one row per individual.
  # mort   - mortality rate; an individual is marked when its uniform(0, 1)
  #          draw falls below this value.
  #
  # Returns spp_df with two added columns: r_d (the random draw) and marked
  # (logical). The comparison is element-wise, so no grouping by age is needed
  # and the result stays an ordinary, ungrouped data frame.

  spp_df$r_d <- runif(nrow(spp_df), 0, 1)
  spp_df$marked <- spp_df$r_d < mort

  spp_df
}

#stochastic_death_mate(marked_for_death(test_df,0.1), 0.1)


# This function removes individuals who are marked for death and increases the age for any individuals who are surviving. 
# Note that this is always called after reproduction such that individuals who are 'marked' still contribute to the next
# generation. 

stochastic_death <- function(spp_df) {
  # Remove the individuals marked for death and age the survivors by one.
  #
  # spp_df - population data frame carrying the r_d and marked columns added
  #          by marked_for_death().
  #
  # Returns the surviving rows with age increased by 1 and the r_d / marked
  # bookkeeping columns removed. The columns are dropped by name rather than
  # by position, so any other column layout is left intact.

  # which() skips rows whose mark is NA instead of producing NA rows.
  survivors <- spp_df[which(!spp_df$marked), , drop = FALSE]
  survivors$age <- survivors$age + 1

  survivors <- survivors[setdiff(names(survivors), c("r_d", "marked"))]

  # Reset row names so they do not accumulate across generations when the
  # survivors are row-bound with the next cohort.
  rownames(survivors) <- NULL

  survivors
}

#test_y <- marked_for_death(test_df, 0.1)

#stochastic_death(test_y)





