

plotTree <- function(tree, numLab, tips2change = NULL){
  # Draw `tree` with two label layers: the column named by `numLab` (e.g. "n")
  # at each branch on a light-green label, and each node's `name` on a
  # light-blue label.
  #
  # tree: a tree object accepted by ggtree::ggtree() (and by renameTips()).
  # numLab: column name (string) used as the branch label.
  # tips2change: optional mapping handed to renameTips() before plotting.
  # Returns the ggtree/ggplot object.
  shown <- tree
  if (!is.null(tips2change)) {
    shown <- renameTips(shown, tips2change)
  }

  countLayer <- ggtree::geom_label(ggplot2::aes_string(x = "branch", label = numLab),
                                   fill = "lightgreen", hjust = 1)
  nameLayer <- ggtree::geom_label(ggplot2::aes(label = name),
                                  fill = "lightblue", hjust = -0.03)
  # Labels may extend past the panel, so disable clipping and reserve a
  # 2.5 cm right margin for them.
  noClip <- ggplot2::coord_cartesian(clip = "off")
  treeTheme <- ggtree::theme_tree2(plot.margin = ggplot2::margin(0, 2.5, 0, 0, "cm"))

  ggtree::ggtree(shown) + countLayer + nameLayer + noClip + treeTheme
}


renameTips <- function(tree, tips2keep){
  # Rename nodes of `tree` according to `tips2keep`.
  #
  # tips2keep: named character vector whose VALUES are current node names and
  #   whose NAMES are the replacement labels, i.e. c(newLabel = "oldLabel").
  #   Nodes whose name is not among the values keep their original name.
  # Returns a tidytree treedata object built from the relabelled tibble.
  treedf <- tidytree::as_tibble(tree)

  # One vectorised lookup instead of rebuilding the inverted map per row.
  # NA and "" never match (mirrors the original named-vector lookup), and
  # the first occurrence wins for duplicated old labels.
  hit <- match(treedf$name, unname(tips2keep), incomparables = c(NA, ""))
  newNames <- names(tips2keep)[hit]
  treedf$name <- ifelse(is.na(newNames), treedf$name, newNames)

  tidytree::as.treedata(treedf)
}


plotDuplications <- function(trees,
                             cosmic,
                             ensembl_gene_id = NULL, 
                             Gene.Symbol = NULL,
                             tips2change = NULL){
  # Plot the tree of a single gene, with a bold "<symbol>, <ensembl id>" title
  # in the top-left corner.
  #
  # trees: collection indexed by Ensembl gene ID (trees[[ensembl_gene_id]]);
  #   presumably a named list of CAFE trees - TODO confirm.
  # cosmic: lookup table passed to getSymbol() / getEnsemblID().
  # ensembl_gene_id, Gene.Symbol: identify the gene; at least one is required.
  #   When both are given, the ID wins and the symbol is re-derived from it.
  # tips2change: optional label mapping forwarded to plotTree().
  if (is.null(ensembl_gene_id) && is.null(Gene.Symbol)) {
    stop("Either ensembl_gene_id or Gene.Symbol must be supplied!")
  }

  if (!is.null(ensembl_gene_id)) {
    Gene.Symbol <- unique(getSymbol(ensembl_gene_id, cosmic))
  } else {
    ensembl_gene_id <- getEnsemblID(Gene.Symbol, cosmic)
    if (length(ensembl_gene_id) > 1) stop("gene name has multiple ensemblIDs!")
    # Fail early with a clear message instead of a cryptic indexing error.
    if (length(ensembl_gene_id) == 0) {
      stop("no ensemblID found for gene name ", Gene.Symbol, "!")
    }
  }

  tree <- trees[[ensembl_gene_id]]
  if (is.null(tree)) {
    stop("no tree found for ", ensembl_gene_id, "!")
  }

  grob <- grid::grobTree(grid::textGrob(paste0(Gene.Symbol, ", ", ensembl_gene_id),
                                        x = 0.01, y = 0.96, hjust = 0,
                                        gp = grid::gpar(col = "black", fontsize = 10,
                                                        fontface = "bold")))
  p <- plotTree(tree, "n", tips2change)
  p + ggplot2::annotation_custom(grob)
}


calcNodePresence <- function(trees, cosmic){
  # Combine per-gene trees into one tree annotated, per node, with the number
  # of genes present (n > 0) and the symbols of those genes.
  #
  # trees: list of gene trees; names(trees) are passed to getSymbol(),
  #   presumably Ensembl gene IDs - TODO confirm. All trees are assumed to
  #   share the same species topology; the last one is used as the template.
  # cosmic: lookup table passed to getSymbol().
  # Returns list(tree = treedata with a `genePresent` count per node,
  #              nodeGenes = data.frame(node, geneName, name)), where geneName
  #              holds the unique gene symbols present at each node.
  #
  # For info see http://www.ensembl.org/Help/View?id=379
  # p-values indicate copy number change events, hence can be ignored.
  if (length(trees) == 0) {
    stop("trees must contain at least one tree!")
  }

  treeIds <- names(trees)
  perTree <- vector("list", length(trees))
  for (nr in seq_along(trees)) {
    treedf <- tidytree::as_tibble(trees[[nr]])
    present <- treedf$n > 0
    treedf$genePresent <- ifelse(present, 1, 0)
    treedf$geneName <- ifelse(present, getSymbol(treeIds[nr], cosmic), "")
    perTree[[nr]] <- treedf
  }
  # Bind once instead of growing a data frame inside the loop.
  combiTree <- do.call(rbind, perTree)

  combiAgg <- aggregate(genePresent ~ node, data = combiTree, FUN = sum, drop = FALSE)

  # Template tree: the last gene tree. Its rows may not be ordered by node,
  # hence the join on `node` rather than positional assignment.
  treedf <- perTree[[length(perTree)]]
  if (nrow(treedf) != nrow(combiAgg)) {
    warning("Some species dropped! Tree may not be accurate!")
  }
  treedf$genePresent <- NULL
  treedf <- dplyr::left_join(treedf, combiAgg, by = "node")
  tree <- tidytree::as.treedata(treedf)

  # unique() presumably needed because several gene IDs can map to the same
  # symbol, so one node can collect the same symbol more than once.
  nodeGenes <- aggregate(geneName ~ node, data = combiTree,
                         FUN = function(x) unique(x[x != ""]),
                         drop = FALSE)
  nodeNames <- unique(combiTree[c("node", "name")])
  nodeGenes <- merge(nodeGenes, nodeNames, by = "node")

  # Label each node's gene set with the node name (the previous code used
  # names(nodeGenes$name), which is NULL for a plain column).
  names(nodeGenes$geneName) <- nodeGenes$name
  list(tree = tree, nodeGenes = nodeGenes)
}

geneGainLossDf <- function(nodeName, nodePrecenceData, tipsOnly = TRUE){
  # Compare the genes at `nodeName` with the genes at each of its children.
  #
  # nodeName: node label, as used by getNodeChildren() and in nodeGenes$name.
  # nodePrecenceData: list(tree, nodeGenes) as returned by calcNodePresence().
  # tipsOnly: passed to getNodeChildren(). When TRUE, the node itself is
  #   also added as a row so its own gene set is available.
  # Returns a data.frame with one row per node: geneName (list of unique
  #   genes), n (number of genes), gainedGenes (genes in the child but not in
  #   nodeName) and lostGenes (genes in nodeName but not in the child).
  tree <- nodePrecenceData$tree
  allNodeGenes <- nodePrecenceData$nodeGenes
  nodeChildren <- getNodeChildren(tree, nodeName, tipsOnly)
  childNodes <- as.data.frame(t(sapply(names(nodeChildren), splitCafeLabel)))
  if (tipsOnly) {
    childNodes <- rbind(childNodes, data.frame(list(name = nodeName, n = NA, pval = NA)))
  }
  nodedf <- merge(childNodes, allNodeGenes, by = "name", all.x = TRUE)

  # lapply (not sapply) keeps these as list columns. sapply simplifies to a
  # matrix when every row has the same number (>1) of genes, and assigning
  # that matrix to a data.frame column fails.
  nodedf$geneName <- lapply(nodedf$geneName, unique)
  nodedf$n <- lengths(nodedf$geneName)
  nodedf$pval <- NULL

  parentGenes <- unlist(nodedf[nodedf$name == nodeName, ][["geneName"]])
  nodedf$gainedGenes <- lapply(nodedf$geneName, function(x) setdiff(x, parentGenes))
  nodedf$lostGenes <- lapply(nodedf$geneName, function(x) setdiff(parentGenes, x))
  nodedf
}

plot_sp_tree_round <- function(p, plotData, ...){
  # Add the summary statistic as an outer ring of bars to a circular/fan tree.
  #
  # p: ggtree plot (with dfSumStat already attached).
  # plotData: list from getPlottingVariables(); uses dfSumStat and confLimits.
  #   dfSumStat must contain `var`, `sp` and `fillColor` columns.
  # `...` is accepted for interface symmetry with plot_sp_tree_linear() and ignored.
  # Returns the updated plot.
  dfSumStat <- plotData$dfSumStat
  confLimits <- plotData$confLimits

  # Only map fill when there is more than one colour to show.
  mainaes <- if (length(unique(dfSumStat$fillColor)) == 1) {
    ggplot2::aes(x = var, y = sp)
  } else {
    ggplot2::aes(x = var, y = sp, fill = fillColor)
  }

  # Passed as a bare symbol on purpose. NOTE(review): ggtreeExtra appears to
  # inspect the `geom` argument expression, so `ggplot2::geom_col` inline may
  # behave differently - confirm before simplifying.
  geom_col <- ggplot2::geom_col

  # Outer ring of bars. The stray trailing comma after `limits` made list()
  # fail with "argument is empty", so the ring could never be drawn.
  p <- p + ggtreeExtra::geom_fruit(data = dfSumStat,
                                   geom = geom_col,
                                   offset = 0.2,
                                   mapping = mainaes,
                                   axis.params = list(
                                     axis       = "x",
                                     text.size  = 4,
                                     hjust      = 0.5,
                                     vjust      = 0.5,
                                     nbreak     = 3,
                                     limits     = c(-1, 1)
                                   ),
                                   grid.params = list()
  )

  if (confLimits %in% c("CI", "SD", "SE")) {
    # TODO: error bars are not implemented yet for circular/fan layouts.
  }

  p
}


plot_sp_tree_linear <- function(p, plotData,
                                splitCol=NULL,
                                facetPos=c(0.3,0.1,0.1, 0.5),
                                legend.position=c(.05, .85),
                                text.size = 20, ...){
  # Add a horizontal-bar panel for the summary statistic next to a
  # linear-layout tree. Error bars and points are added when confLimits is
  # "CI", "SD" or "SE". Then the theme, panel widths and bottom tag are set.
  #
  # p: ggtree plot; plotData: list from getPlottingVariables().
  # splitCol: optional column whose values replace fillColor for the bars.
  # facetPos: relative panel widths given to ggtree::facet_widths().
  # legend.position, text.size: theme settings. `...` is accepted and ignored.
  # Returns the finished plot.
  stats <- plotData$dfSumStat
  panel <- plotData$panelName
  limitKind <- plotData$confLimits
  tag <- plotData$tagLabel

  # The fill decision is made before splitCol (if any) overwrites fillColor.
  multiFill <- length(unique(stats$fillColor)) != 1
  doSplit <- !is.null(splitCol) && splitCol %in% names(stats)
  if (doSplit) {
    stats$fillColor <- stats[[splitCol]]
  }
  barAes <- if (multiFill || doSplit) {
    ggplot2::aes(x = var, fill = fillColor, xmin = lower, xmax = upper)
  } else {
    ggplot2::aes(x = var, xmin = lower, xmax = upper)
  }

  # Each layer gets its own dodge position with the same height so that bars,
  # error bars and points line up.
  dodge <- function() ggstance::position_dodgev(height = .9)

  p <- p + ggnewscale::new_scale_fill()
  p <- p + ggtree::geom_facet(mapping = barAes,
                              panel = panel, data = stats,
                              geom = ggstance::geom_barh,
                              position = dodge(),
                              stat = "identity", width = .5)

  if (limitKind %in% c("CI", "SD", "SE")) {
    p <- p + ggtree::geom_facet(panel = panel, data = stats,
                                geom = ggstance::geom_errorbarh,
                                position = dodge(),
                                mapping = barAes,
                                color = "black",
                                height = .2)
    p <- p + ggtree::geom_facet(panel = panel, data = stats,
                                position = dodge(),
                                geom = ggplot2::geom_point,
                                mapping = barAes,
                                color = "black",
                                size = 1)
  }

  sized <- ggplot2::element_text(size = text.size)
  themed <- p + ggtree::theme_tree2(legend.position = legend.position,
                                    axis.text.x = sized,
                                    strip.text.x = sized,
                                    strip.background = ggplot2::element_blank(),
                                    legend.title = sized,
                                    legend.text = sized)
  widened <- ggtree::facet_widths(themed, facetPos)

  widened +
    ggplot2::labs(tag = paste0("Millions of Years", tag)) +
    ggplot2::theme(plot.tag.position = "bottom",
                   plot.tag = ggplot2::element_text(size = text.size))
}

# For each distinct value of `col2Split`, find the node that anchors its clade
# label: the MRCA of that group's tips (`sp` column), or the parent of the
# group's first tip when getMRCA yields NULL (presumably a one-tip group) or
# fails. Returns data.frame(order, node).
.orderCladeNodes <- function(treeReduced, dfSumStat, col2Split) {
  groups <- dfSumStat[[col2Split]]
  orders <- as.character(unique(groups))
  nodes <- vapply(orders, function(x) {
    # as.character(): a factor would be read by getMRCA as tip indices.
    tipLabs <- as.character(dfSumStat[["sp"]][groups %in% x])
    # The old handler called browser() and assigned mrca only inside the
    # handler, leaving it unbound. Return NULL and use the fallback instead.
    mrca <- tryCatch(ape::getMRCA(treeReduced, tipLabs),
                     error = function(e) NULL)
    if (is.null(mrca)) {
      tipIdx <- match(tipLabs[1], treeReduced$tip.label)
      if (is.na(tipIdx)) {
        stop("tip '", tipLabs[1], "' not found in tree!", call. = FALSE)
      }
      mrca <- treeReduced$edge[treeReduced$edge[, 2] == tipIdx, 1]
    }
    as.numeric(mrca)
  }, numeric(1))
  data.frame(order = orders, node = nodes)
}

plot_sp_tree <- function(treeReduced, dfSumStat,
                         textColorVar = NULL,
                         barColorVar = NULL,
                         dataVar = NULL, confLimits = "CI",
                         dataLab = "LC50",
                         addTipLabels = TRUE,
                         layout = "rectangular",
                         shiftAnnotVar = -2.5,
                         shiftAnnotVarText = -2.5,
                         fanFontSize = 6,
                         fanOffset = 100,
                         annotCol = NULL, ...){
  # Plot a species tree, optionally with tip labels, clade labels for the
  # groups in `annotCol`, and a bar panel/ring for `dataVar`.
  #
  # treeReduced: ape phylo object; dfSumStat: per-species data with an `sp`
  #   column matching the tip labels.
  # dataVar: column to plot. When NULL, only the tree (and tip labels) is returned.
  # layout: linear ("rectangular", "ellipse", "roundrect", "slanted") or round
  #   ("circular", "fan"). Other layouts get no clade labels or data panel.
  # annotCol: optional grouping column (e.g. order) used for clade labels.
  # `...` is forwarded to plot_sp_tree_linear() / plot_sp_tree_round().
  linearLayouts <- c("rectangular", "ellipse", "roundrect", "slanted")
  roundLayouts <- c("circular", "fan")

  p <- ggtree::ggtree(treeReduced, size = 1.1, layout = layout, open.angle = 20)

  # attach data to tree
  p <- ggtree::`%<+%`(p, dfSumStat)

  if (addTipLabels) {
    p <- p + add_tip_labels(layout = layout, textColorVar = textColorVar)
  }

  if (is.null(dataVar)) {
    return(p)
  }
  plotData <- getPlottingVariables(dfSumStat, dataVar, barColorVar, dataLab,
                                   textColorVar, confLimits)

  if (!is.null(annotCol)) {
    if (annotCol %in% colnames(dfSumStat)) {
      orderNodes <- .orderCladeNodes(treeReduced, dfSumStat, annotCol)

      if (layout %in% roundLayouts) {
        p <- p + ggnewscale::new_scale_color()
        p <- p + ggtree::geom_cladelab(data = orderNodes,
                                       mapping = ggplot2::aes(node = node,
                                                              color = order,
                                                              label = order),
                                       fontsize = fanFontSize,
                                       horizontal = FALSE,
                                       offset.text = 10,
                                       align = TRUE,
                                       parse = TRUE,
                                       offset = fanOffset,
                                       hjust = 0,
                                       vjust = 0,
                                       barsize = 1,
                                       size = 1,
                                       linetype = "solid",
                                       angle = "auto",
                                       show.legend = TRUE
        )
      }
      if (layout %in% linearLayouts) {
        # Offset labels by half the tree depth. The previous code referenced
        # an undefined `tree` here instead of `treeReduced`.
        offsetLab <- max(ape::branching.times(treeReduced)) * -0.5
        p <- p + ggnewscale::new_scale_color()
        p <- p + ggtree::geom_cladelab(data = orderNodes,
                                       mapping = ggplot2::aes(node = node,
                                                              color = order,
                                                              label = order),
                                       fontsize = fanFontSize,
                                       horizontal = FALSE,
                                       offset.text = shiftAnnotVarText,
                                       align = TRUE,
                                       parse = TRUE,
                                       offset = offsetLab + shiftAnnotVar,
                                       hjust = 0.5,
                                       vjust = 0.5,
                                       barsize = 2,
                                       size = 2,
                                       linetype = "solid",
                                       angle = 90,
                                       show.legend = FALSE
        )
      }
    } else {
      warning(annotCol, " not in data!")
    }
  }

  if (layout %in% linearLayouts) {
    p <- plot_sp_tree_linear(p, plotData, ...)
  }
  if (layout %in% roundLayouts) {
    p <- plot_sp_tree_round(p, plotData, ...)
  }

  if (!is.null(barColorVar) && barColorVar %in% colnames(dfSumStat)) {
    if (barColorVar == "Order") {
      barColorVar <- "order"
    }
    p <- p + ggplot2::labs(fill = barColorVar)
  }

  if (!is.null(textColorVar) && textColorVar %in% colnames(dfSumStat)) {
    if (textColorVar == "Order") {
      textColorVar <- "order"
    }
    p <- p + ggplot2::labs(color = textColorVar)
  }

  p
}


getPlottingVariables <- function(dfSumStat, dataVar,barColorVar,dataLab, textColorVar,  confLimits="CI"){
  # Prepare the data columns and labels used by the plot_sp_tree_* helpers.
  #
  # dfSumStat: per-species summary data.frame.
  # dataVar: name of the column to plot; copied into `var`.
  # barColorVar / textColorVar: optional column names copied into
  #   `fillColor` / `textColor`. "black" is used when NULL or absent.
  # dataLab: label used when dataVar is not a recognised conc_* column.
  # confLimits: "CI" uses <dataVar>.upper/.lower. "SD" and "SE" use
  #   <dataVar>.mean +/- <dataVar>.sd / <dataVar>.se. Anything else sets
  #   upper/lower to 0.
  # Returns list(dfSumStat, tagLabel, panelName, confLimits, barColorVar,
  #   textColorVar). tagLabel always starts with a space because it is
  #   appended directly to other text (e.g. "Millions of Years").
  if (!(dataVar %in% colnames(dfSumStat))) {
    stop(dataVar, " not in data!")
  }
  dfSumStat$var <- dfSumStat[[dataVar]]

  tagLabel <- if (grepl("conc_z", dataVar)) {
    " LC50 (z-score)"
  } else if (grepl("conc_mean", dataVar)) {
    " LC50 (mg/L)"
  } else if (grepl("conc_adj", dataVar)) {
    " LC50 (proportion of mean)"
  } else {
    # Leading space for consistency with the branches above; without it the
    # tag read e.g. "Millions of YearsLC50".
    paste0(" ", dataLab)
  }

  # tagLabel already starts with a space, so no extra separator is added.
  # (Adding one produced double spaces such as "mean  LC50 (mg/L)".)
  if (grepl("mean", dataVar)) {
    dataLab <- paste0("mean", tagLabel)
    tagLabel <- paste0(" and mean", tagLabel)
  }
  if (grepl("med", dataVar)) {
    dataLab <- paste0("median", tagLabel)
    tagLabel <- paste0(" and median", tagLabel)
  }

  dfSumStat$fillColor <- if (!is.null(barColorVar) && barColorVar %in% colnames(dfSumStat)) {
    dfSumStat[[barColorVar]]
  } else {
    "black"
  }
  dfSumStat$textColor <- if (!is.null(textColorVar) && textColorVar %in% colnames(dfSumStat)) {
    dfSumStat[[textColorVar]]
  } else {
    "black"
  }

  # Column holding the interval info for dataVar, e.g. "conc_mean.upper".
  limitCol <- function(suffix) dfSumStat[[paste0(dataVar, suffix)]]

  if (confLimits == "CI") {
    dfSumStat$upper <- limitCol(".upper")
    dfSumStat$lower <- limitCol(".lower")
    panelName <- paste0(dataLab, " ± 95% CI")
  } else if (confLimits == "SD") {
    dfSumStat$upper <- limitCol(".mean") + limitCol(".sd")
    dfSumStat$lower <- limitCol(".mean") - limitCol(".sd")
    panelName <- paste0(dataLab, " ± SD")
  } else if (confLimits == "SE") {
    dfSumStat$upper <- limitCol(".mean") + limitCol(".se")
    dfSumStat$lower <- limitCol(".mean") - limitCol(".se")
    panelName <- paste0(dataLab, " ± SE")
  } else {
    panelName <- dataLab
    dfSumStat$upper <- 0
    dfSumStat$lower <- 0
  }

  list(dfSumStat = dfSumStat,
       tagLabel = tagLabel,
       panelName = panelName,
       confLimits = confLimits,
       barColorVar = barColorVar,
       textColorVar = textColorVar)
}


add_tip_labels <- function(layout, textColorVar){
  # Build the tip-label layer used by plot_sp_tree().
  #
  # layout: ggtree layout name. "circular"/"fan" get left-aligned labels
  #   (offset 110). "rectangular"/"ellipse"/"roundrect"/"slanted" get
  #   right-aligned labels drawn with geom = "label" (offset 180).
  # textColorVar: optional column name mapped to the label colour.
  # Returns a ggtree::geom_tiplab() layer. For any other layout it returns
  # NULL with a warning; adding NULL to a ggplot leaves the plot unchanged.
  # (Previously this failed with "object 'lab' not found".)
  mainAes <- if (is.null(textColorVar)) {
    ggplot2::aes()
  } else {
    ggplot2::aes_string(color = textColorVar)
  }

  if (layout %in% c("circular", "fan")) {
    return(ggtree::geom_tiplab(align = TRUE,
                               mapping = mainAes,
                               hjust = 0,
                               offset = 110,
                               fontface = "italic",
                               label.size = NA,
                               size = 8))
  }
  if (layout %in% c("rectangular", "ellipse", "roundrect", "slanted")) {
    return(ggtree::geom_tiplab(align = TRUE,
                               mapping = mainAes,
                               hjust = 1,
                               offset = 180,
                               fontface = "italic",
                               geom = "label",
                               label.size = NA,
                               size = 8))
  }

  warning("No tip labels added: unsupported layout '", layout, "'", call. = FALSE)
  NULL
}




