#!/usr/bin/env Rscript

# extract dispersal counts between areas from corHMM/Mk1 stochastic maps

# load the stochastic maps; the rest of the script expects this file to
# provide the list `stoch_maps`
load(file = "corHMM_ER_stoch_maps.Rdata")

# area lookup table (tab-separated, with header row); its `code` and
# `abbreviation` columns define area order and labels further down
areas <- read.table(
  file = "../4_its_metadata_R1/area_codes.tab",
  header = TRUE,
  sep = "\t",
  fill = TRUE,
  comment.char = "",
  stringsAsFactors = FALSE
)

# (1) function to list dispersals for a single branch in a tree/stochastic map
get_disp <- function(n, i, simmaps = stoch_maps) {
  # n: index of the stochastic map in `simmaps`
  # i: index of the edge (branch) within that map
  # simmaps: list of stochastic maps; defaults to the global `stoch_maps`
  #   (lazily looked up at call time, as before)
  # returns a two-column character matrix (column 1 - FROM, column 2 - TO)
  # with one row per state change along the branch, or NULL when the branch
  # stays in a single state
  # the segment names of a branch are the states/areas it passes through
  states <- names(simmaps[[n]]$maps[[i]])
  nst <- length(states)
  # fewer than two segments -> no dispersal on this branch
  if (nst < 2) {
    return(NULL)
  }
  # each pair of consecutive segments is one dispersal: FROM -> TO
  cbind(states[-nst], states[-1])
}

# (2) get all dispersals for all trees/stochastic maps

# function to output all dispersals for one tree/map as one matrix
# (column 1 - FROM, column 2 - TO, one row per dispersal event)
get_disp_tree <- function(n) {
  # n: index of the stochastic map in `stoch_maps`
  ne <- nrow(stoch_maps[[n]]$edge)
  # one element per edge; edges without a state change give NULL,
  # which rbind() ignores
  all_d <- lapply(seq_len(ne), get_disp, n = n)
  all_d <- do.call(rbind, all_d)
  # if no edge changed state, rbind() returns NULL; return an empty
  # two-column matrix instead so callers can still index columns (mat[, 1])
  if (is.null(all_d)) {
    all_d <- matrix(character(0), nrow = 0, ncol = 2)
  }
  all_d
}

# function to convert line-based matrix to one/two-sided count matrix
make_count_matrix <- function(mat, states = areas$code,
                              dimn = areas$abbreviation, onesided) {
  # mat: two-column matrix (column 1 - FROM, column 2 - TO), one row per
  #   dispersal event; NULL is accepted and means "no dispersals"
  # states: area codes, defining the row/column order of the result
  # dimn: labels used as row and column names
  # onesided: FALSE -> directed counts (row = FROM, column = TO);
  #   TRUE -> both directions summed into the lower triangle, upper set NA
  # the diagonal (no dispersal within an area) is always NA
  nst <- length(states)
  # a map without any state change yields NULL from rbind(); treat it as an
  # empty event list instead of failing on mat[, 1]
  if (is.null(mat)) {
    mat <- matrix(character(0), nrow = 0, ncol = 2)
  }
  from <- mat[, 1]
  to <- mat[, 2]
  cmat <- matrix(NA, ncol = nst, nrow = nst, dimnames = list(dimn, dimn))
  # directed counts for every off-diagonal area pair
  for (i in seq_len(nst)) {
    for (j in seq_len(nst)) {
      if (i != j) {
        cmat[i, j] <- sum(from == states[i] & to == states[j])
      }
    }
  }
  # if one-sided, add each upper-triangle count (t(cmat)[i, j] == cmat[j, i])
  # to its mirror cell in the lower triangle, then blank the upper triangle
  if (onesided) {
    low <- lower.tri(cmat)
    cmat[low] <- cmat[low] + t(cmat)[low]
    cmat[upper.tri(cmat)] <- NA
  }
  cmat
}

# apply to all trees/maps: list the dispersals of each stochastic map, then
# turn each list into a one-sided count matrix (both directions pooled);
# seq_along() keeps this safe for an empty list of maps
disp_all_maps <- lapply(seq_along(stoch_maps), get_disp_tree)
disp_all_maps <- lapply(disp_all_maps, make_count_matrix, onesided = TRUE)


# (3) make summary matrices

make_summary_matrix <- function(mat_list, states = areas$code,
                                cnames = areas$abbreviation, FUN, ...) {
  # Summarise a list of equally-sized square count matrices cell by cell.
  # mat_list: list of nst x nst matrices (e.g. from make_count_matrix), all
  #   in the same area order; cells are addressed by position
  # states: area codes; only its length is used, to size the output
  #   (defaults to the same areas as make_count_matrix instead of a fixed 9)
  # cnames: row/column names of the output matrix (previously ignored)
  # FUN: summary function applied to the per-map values of each cell
  # ...: further arguments passed on to FUN (e.g. probs for quantile)
  # returns an nst x nst matrix; cells that are NA in every input matrix
  # (or all cells, if mat_list is empty) stay NA
  nst <- length(states)
  smat <- matrix(NA, ncol = nst, nrow = nst, dimnames = list(cnames, cnames))
  for (i in seq_len(nst)) {
    for (j in seq_len(nst)) {
      # value of cell [i, j] in every map; vapply fixes the result type
      d <- vapply(mat_list, function(m) m[i, j], numeric(1))
      # skip cells that are NA everywhere (diagonal / upper triangle of
      # one-sided matrices)
      if (!all(is.na(d))) {
        smat[i, j] <- FUN(d, ...)
      }
    }
  }
  smat
}

# summarise the per-map dispersal counts (both directions pooled):
# the median and the 2.5% / 97.5% quantiles across all stochastic maps
disp_median <- make_summary_matrix(disp_all_maps, FUN = median)
disp_025 <- make_summary_matrix(disp_all_maps, FUN = quantile, probs = 0.025)
disp_975 <- make_summary_matrix(disp_all_maps, FUN = quantile, probs = 0.975)

# write each summary matrix as a tab-separated text file
write.table(disp_median, file = "disp_matrix_median.txt", sep = "\t")
write.table(disp_025, file = "disp_matrix_025.txt", sep = "\t")
write.table(disp_975, file = "disp_matrix_975.txt", sep = "\t")



