#!/usr/bin/env Rscript

#library(weights) # for wtd.hist()
library(ape)

# count dispersal to/from focal areas over time from stochastic mapping

# area code table: tab-separated with a header row; comment.char is disabled
# so that '#' inside fields is kept as data
areas <- read.table(
  "../4_its_metadata_R1/area_codes.tab",
  header = TRUE, sep = "\t", fill = TRUE,
  comment.char = "", stringsAsFactors = FALSE
)

# stochastic maps from corHMM; the code below expects this file to provide
# the object `stoch_maps`
load("corHMM_ER_stoch_maps.Rdata")

# time periods: `per` is the length of one period, `timeperiods` holds the
# period boundaries in decreasing order (1 down to 0)
per <- 0.1
timeperiods <- rev(seq(from = 0, to = 1, by = per))

# node depths taken from the first stochastic map (same tree used for all
# maps), indexed so that nd[i] belongs to the child node of edge i
nd <- node.depth.edgelength(stoch_maps[[1]])[stoch_maps[[1]]$edge[, 2]]

# (1) list the dispersals, with their times, on a single branch of one
# tree/stochastic map
#   n: index of the stochastic map in `stoch_maps`
#   i: index of the edge
#   returns NULL when the branch is in a single state, otherwise a data frame
#   with one row per state change and the columns `from`, `to`, `evttime`
get_disp <- function(n, i) {
  segments <- stoch_maps[[n]]$maps[[i]]
  # only one segment: the branch never changes state -> no dispersal
  if (length(segments) == 1) {
    return(NULL)
  }
  # states/areas along the branch; each consecutive pair is one dispersal
  branch_states <- names(segments)
  events <- data.frame(from = branch_states[-length(branch_states)],
                       to = branch_states[-1])
  # summed segment lengths from each state change to the end of the branch
  to_branch_end <- rev(cumsum(rev(segments)))[-1]
  # absolute event time: offset by the child node of the edge (1 - nd[i])
  events$evttime <- to_branch_end + 1 - nd[i]
  events
}

# collect all dispersals of one tree/stochastic map in one data frame
#   n: index of the stochastic map in `stoch_maps`
#   returns the row-bound get_disp() results of every edge of map n
#   (columns from, to, evttime), or NULL if no branch has a dispersal
get_disp_tree <- function(n) {
  ne <- nrow(stoch_maps[[n]]$edge)
  # seq_len() instead of 1:ne, which would count down (1, 0) for ne == 0
  per_edge <- lapply(seq_len(ne), get_disp, n = n)
  # rbind() skips the NULL entries of branches without dispersal
  do.call(rbind, per_edge)
}

# get dispersals and their times for all stochastic maps: a list with one
# get_disp_tree() result per map (seq_along() is safe for any list length)
disp_all_maps <- lapply(seq_along(stoch_maps), get_disp_tree)

# count the dispersals of one map within each time interval
#   disp_matrix: data frame with the columns from, to and evttime (as built by
#                get_disp_tree()), or NULL for a map without any dispersal
#   from, to:    restrict the count to dispersals out of / into this area;
#                "" (default) applies no restriction
#   breaks:      interval boundaries handed to hist(); defaults to the global
#                `timeperiods`
#   returns an integer vector with one count per interval, in the reverse of
#   the order in which hist() reports its bins
count_disp <- function(disp_matrix, from = "", to = "", breaks = timeperiods) {
  n_bins <- length(breaks) - 1L
  # a map without dispersals has nothing to count (hist() would fail on NULL)
  if (is.null(disp_matrix)) {
    return(integer(n_bins))
  }
  if (from != "") {
    disp_matrix <- disp_matrix[disp_matrix$from == from, ]
  }
  if (to != "") {
    disp_matrix <- disp_matrix[disp_matrix$to == to, ]
  }
  # no event left after filtering: every interval has a count of zero
  if (nrow(disp_matrix) == 0) {
    return(integer(n_bins))
  }
  counts <- hist(disp_matrix$evttime, breaks = breaks, plot = FALSE)$counts
  # return reverse count vector
  rev(counts)
}

# function to get medians + 2.5%/97.5% quantiles (across all stochastic maps)
# of the dispersal counts per time period to and from an area, or between two
# specific areas
#   area1: focal area code
#   area2: second area code; '' (default) counts dispersals between area1 and
#          any other area
#   returns a data frame with the columns time (= timeperiods[-1]), from, to,
#   median, q025 and q975; the first block of rows holds dispersals TO area1,
#   the second block dispersals FROM area1
count_disp_to_from <- function(area1, area2 = '') {
  period_time <- timeperiods[-1]
  n_per <- length(period_time)

  # counts as a matrix: one row per time period, one column per map.
  # vapply() checks the length of every count vector, and matrix() keeps the
  # two dimensions even when there is only a single time period.
  count_matrix <- function(from, to) {
    counts <- vapply(disp_all_maps, count_disp, integer(n_per),
                     from = from, to = to)
    matrix(counts, nrow = n_per)
  }
  disp_to <- count_matrix(from = area2, to = area1)
  disp_from <- count_matrix(from = area1, to = area2)

  # rows for dispersals TO area1 first, then rows for dispersals FROM area1
  out <- data.frame(
    time = rep(period_time, 2),
    from = c(rep(area2, n_per), rep(area1, n_per)),
    to = c(rep(area1, n_per), rep(area2, n_per)),
    stringsAsFactors = FALSE
  )

  # per-period summary over maps, in the same TO-then-FROM order as the rows
  row_stat <- function(fun, ...) {
    c(apply(disp_to, 1, fun, ...), apply(disp_from, 1, fun, ...))
  }
  out$median <- row_stat(median)
  out$q025 <- row_stat(quantile, probs = 0.025)
  out$q975 <- row_stat(quantile, probs = 0.975)
  out
}
 
# dispersal counts to and from every focal area, stacked into one data frame
disp_to_from <- lapply(areas$code[areas$focal], count_disp_to_from)
disp_to_from <- do.call(rbind.data.frame, disp_to_from)

# dispersal counts specifically between tr. S America and Afrotropic
disp_sam_afr <- count_disp_to_from("1", "9")

# write both data frames as tab-separated tables without row names
write.table(disp_to_from, file = "disp_counts_focal.tab",
            sep = "\t", row.names = FALSE)
write.table(disp_sam_afr, file = "disp_counts_sam_afr.tab",
            sep = "\t", row.names = FALSE)

