
library(randomForest)
library(caTools)

# Apply to each gene module: compute a diffusion map and diffusion pseudotime
# (DPT) on the given genes, pick as root the DPT tip whose neighbourhood lies
# earliest in developmental time, and store the results on the Seurat object.
#
# Arguments:
#   seuratData        Seurat object. Expression is read from the 'RNA' assay
#                     `data` slot; sample labels are read from `orig.ident`.
#   genes             Rows of the RNA data matrix to build the map from.
#   find.branch       If TRUE, store DPT branch labels (first branching level)
#                     in `seuratData$dptbranch`.
#   embryonic.offset  Subtracted from the numeric part of sample labels that
#                     start with "E", so embryonic and postnatal samples share
#                     one time axis. The default of 20 is the original value;
#                     make sure it fits the species you are using.
#
# Returns a named list:
#   dm          the destiny DiffusionMap
#   dpt         the destiny DPT object
#   seuratData  input object with metadata `timeval` and `dptval`, optionally
#               `dptbranch`, and a 3-component 'DC' dimensional reduction
#   rootcell    name of the tip cell used as the pseudotime root
ApplyDiffusionMap <- function(seuratData, genes, find.branch = TRUE,
                              embryonic.offset = 20) {
  # Cells x genes matrix restricted to the module genes. drop = FALSE keeps a
  # matrix even for a single gene.
  dat <- seuratData@assays[["RNA"]]@data
  cell_by_gene <- t(as.matrix(dat[genes, , drop = FALSE]))

  # Diffusion map; for 100+ genes, reduce to 50 PCs first.
  if (ncol(cell_by_gene) < 100) {
    dm <- destiny::DiffusionMap(cell_by_gene)
  } else {
    dm <- destiny::DiffusionMap(cell_by_gene, n_pcs = 50)
  }

  dpt <- destiny::DPT(dm)

  if (find.branch) {
    seuratData$dptbranch <- dpt@branch[, 1]
    # TODO: classify cells left unassigned (NA) by DPT, e.g. with a random
    # forest trained on the assigned cells' top marker genes.
  }

  # Place every sample on a single numeric time axis: the number in its label,
  # shifted by embryonic.offset for embryonic ("E...") samples. Lookup is by
  # label as character so factor `orig.ident` is not indexed by level codes.
  samples <- sort(unique(as.character(seuratData$orig.ident)))
  times <- readr::parse_number(samples)
  is_embryonic <- startsWith(samples, "E")
  times[is_embryonic] <- times[is_embryonic] - embryonic.offset
  names(times) <- samples

  timeval <- unname(times[as.character(seuratData$orig.ident)])
  names(timeval) <- colnames(seuratData)
  seuratData$timeval <- timeval

  # Score each DPT tip by the summed time values of its k nearest cells in
  # diffusion space (the tip itself included); the lowest score is the
  # earliest tip. Assumes rows of the eigenvectors follow the Seurat cell
  # order, as they were built from the columns of the RNA data.
  ev <- dm@eigenvectors
  k <- min(nrow(ev), max(1L, floor(ncol(seuratData) / 10)))

  tip_idx <- which(dpt@tips[, 1])
  tip_scores <- vapply(tip_idx, function(i) {
    # Euclidean distance from tip i to every cell; avoids a full n x n matrix.
    d <- sqrt(rowSums(sweep(ev, 2, ev[i, ])^2))
    sum(timeval[order(d)[seq_len(k)]])
  }, numeric(1))

  # which.min skips NA scores and returns a single index on ties.
  root_idx <- tip_idx[which.min(tip_scores)]
  if (length(root_idx) != 1L) {
    stop("Could not determine a root tip: no tip has a finite time score.",
         call. = FALSE)
  }
  rootcell <- rownames(ev)[root_idx]

  # TODO: alternative branch computation (K-branches on the first two DCs
  # with c0 = root index) was explored here and left disabled.

  # Pseudotime measured from the chosen root tip.
  pseudotime <- dpt[[paste0("DPT", root_idx)]]
  names(pseudotime) <- rownames(ev)
  seuratData$dptval <- pseudotime

  # First three diffusion components as a Seurat reduction. CreateDimReducObject
  # is expected to be available from the caller's attached Seurat.
  seuratData[["DC"]] <- CreateDimReducObject(ev[, 1:3], key = "DC_",
                                             assay = "RNA")

  list(dm = dm, dpt = dpt, seuratData = seuratData, rootcell = rootcell)
}
