
#packages(?) check package loader in case something is missing
# library(Seurat)
# library(randomForest)
# library(future)
# library(doMC)



#Notes for use
# - This will take data from the RNA assay not the integrated assay, because we want to compare expression rather than distance
# - clustering should be the name of the metadata column where the cluster labels are
# - Example call: ScoreMembership(seuratObject, 'louvain_clustering') where your cluster labels are in seuratObject@meta.data[['louvain_clustering']]

#Minimum cluster representative (minClusterRep) controls how many cells you require as a minimum from each cluster identity, in each test set. If repeated sampling keeps leaving too few cells, the function gives up.

#Function to apply RF validation and score the class membership. Iterates random forests to predict cluster identity
ScoreMembership <- function(seuratObject,
                            clustering,
                            minClusterRep=5,
                            ntree=100,
                            logFC.thresh=log(1.5),
                            nGene=10,
                            nIteration=10,
                            nCore=10,
                            ...){
  
  #relabel clusters into numeric, sequential labels
  
  # clusters<-unique(seuratObject[[clustering]][,1])
  # names(clusters)<-1:length(clusters)
  # seuratObject[[clustering]][,1]<-names(clusters)[match(seuratObject[[clustering]][,1],clusters)]
  # 
  cells<-rownames(seuratObject[[clustering]])
  seuratObject <- subset(seuratObject, cells = cells[which(!is.na(seuratObject[[clustering]]))] )
  DefaultAssay(seuratObject)<-'RNA'
  Input<-as.matrix(seuratObject@assays[['RNA']]@data)
  cate <- as.data.frame(as.factor(seuratObject[[clustering]][colnames(Input),]))
  rownames(cate)<-colnames(Input)
  feature <- t(as.matrix(Input))
  
  #size of test set
  ntest<-floor(dim(feature)[1]/5)
  #remainder (add on to the last group)
  rem<-dim(feature)[1]%%5
  
  
  cell_names<-rownames(feature)
  
  #Record possible combinations of classes
  classes<-unique(seuratObject[[clustering]])
  combos<-combn(classes[,1],m=2)
  
  
  #store predictions for each cell
  #predictions<-matrix(nrow=length(cell_names),ncol=10*dim(combos)[2])
  
  registerDoMC(nCore)
  
  out<-foreach(iteration = 1:nIteration) %dopar% {
    
    #for reproducibility
    set.seed(iteration)
    
    #matrix holding individual cell predictions for this iteration
    predictions<-matrix(nrow=length(cell_names),ncol=dim(combos)[2])
    rownames(predictions)<-cell_names
    colnames(predictions)<-1:dim(combos)[2]
    
    print(paste0("starting iteration: ",as.character(iteration)))
    
    test_sets<-list()
    
    waitCounter<-0
    successfulTestSets<-FALSE
    
    while(successfulTestSets==FALSE){
      possibilities<-1:dim(feature)[1]
      #Generate 5 test sets, that are each a different fifth of the data
      for(i in 1:5){
        if(i==5){
          #not divisible by 5, so make up the remainder in the last round
          set<-base::sample(possibilities,size=ntest+rem,replace=FALSE)
        }else{
          set<-base::sample(possibilities,size=ntest,replace=FALSE)
        }
        test_sets[[i]]<-set        
        possibilities<-possibilities[is.na(match(possibilities,set))]
      }
      
      #Check: Does each test set contain at least 5 samples from each class in our clustering? 
      keeptry=TRUE
      for(i in 1:5){
        if(keeptry){
          
          counts<-table(seuratObject[[clustering]][test_sets[[i]],])
          
          if(sum(counts<=minClusterRep) > 0 ){
            test_sets<-list()
            waitCounter<-waitCounter+1
            keeptry=FALSE;
          }
          else if(waitCounter>10){
            print("class proportions too low! Generating test sets failed")
            break()
          }
        }
        
      }
      
      if(length(test_sets)==5){
        successfulTestSets=TRUE
      }
    }
    
    
    for(setnum in 1:5){
      test<-test_sets[[setnum]]
      test_feat<-feature[test,]
      test_cate<-cate[test,]
      
      training_feat<-feature[-test,]
      training_cate<-cate[-test,]
      
      #Get DE genes for each pair
      Idents(seuratObject)<-clustering
      DefaultAssay(seuratObject)<-'RNA'
      seuratObject_training<-subset(seuratObject,cells=rownames(training_feat))
      preds<-list()
      
      plan("multiprocess", workers = 1)
      
      #Get DE genes to characterise clusters 
      
      #DE_Markers<-FindAllMarkers(seuratObject_training,verbose=FALSE,logfc.threshold = log(2),min.diff.pct = 0.2)
      DE_Markers<-FindAllMarkers(seuratObject_training,verbose=T,logfc.threshold = logFC.thresh,only.pos=T,min.pct=0.5)
      
      

      
      for(i in 1:dim(combos)[2]){
        comb<-combos[,i]
        
        #grab the genes to do with the pair of clusters we're looking at in this loop  
        DE_Markers_sub<-DE_Markers[c(which(DE_Markers$cluster==comb[1]),which(DE_Markers$cluster==comb[2])),]
        DE_Markers_sub<-DE_Markers_sub  %>% group_by(cluster) %>%top_n(n=nGene, abs(avg_logFC))
        DE_genes<-DE_Markers_sub$gene
        
        ind<-c(which(!is.na(match(training_cate,comb[1]))),which(!is.na(match(training_cate,comb[2]))))
        
        training_feat_sub<-training_feat[ind,DE_genes]
        training_cate_sub<-droplevels(training_cate[ind])
        
        #for subsampling
        minSize<-min(table(training_cate_sub))
        
        randf<-randomForest::randomForest(x=training_feat_sub, y=training_cate_sub, ntree=ntree,
                                          sampsize=rep(minSize,2)
                                          #classwt=table(training_cate_sub)
                                          ) 
        
        #Train forest for every pair of groups. 3 groups -> 3 forests
        # randf<- foreach(ntree=rep(5, 10), .combine=randomForest::combine,
        #                 .multicombine=TRUE, .packages='randomForest') %dopar%
        #   randomForest(training_feat_sub, training_cate_sub, ntree=ntree) 
        
        #Classify test data
        preds[[i]]<-predict(randf,test_feat,importance=T)
        
        #update objects
        #rfs[[length(rfs)+1]]<-list(randf,comb,setnum)
        
      }
      
      #save predictions to master table
      for(j in 1:dim(combos)[2]){
        dat<-preds[[j]]
        ind<-match(names(dat),rownames(predictions))
        predictions[ind,j]<-as.numeric(as.character(dat))
      }
      
      
    }
    
    return(predictions)
    
  }
  
  #End of 10 iterations. After this we have 10 predictions for each cell. 
  
  
  #Merge all the predictions into one object
  predictionTable<-vector()
  for(i in 1:dim(combos)[2]){
    block<-vector()
    for(j in 1:length(out)){
      block<-cbind(block,out[[j]][,i])
    }
    predictionTable<-cbind(predictionTable,block)
  }
  
  #This has all the predictions together
  predictions<-predictionTable
  
  #This matrix will hold a class membership score for each cell, for each class
  membershipScore<-matrix(0,nrow=dim(predictions)[1],ncol=length(classes[,1]))
  rownames(membershipScore)<-rownames(predictions)
  colnames(membershipScore)<-sort(classes[,1])
  

  
  
  for(i in  1:dim(predictions)[1]){
    #grab the predictions just for this cell
  
    cell<-predictions[i,]
    
    #what are the possible class identities of the cell?
    class_set<-classes[,1]
    
    for(class in 1:dim(combos)[2]){
      #for each pairwise class comparison, we grab the predictions corresponding to that comparison. They are in blocks of nIteration
      chunk<-cell[((nIteration*class)-(nIteration-1)):(nIteration*class)]
      
      #record any class in this pairwise comparison which does not appear in the predictions. It has been dominated by the other possibility 
      dominated<-combos[which(is.na(match(combos[,class],chunk))),class]
      
      #remove any dominated classes from the set of possible identities
      if(length(dominated!=0)){
        class_set[match(dominated,class_set)]<-NA
      }
      
    }
    
    #consider the classes that remain
    class_set<-class_set[!is.na(class_set)]
    
    # evaluated only the predictions for the remaining classes
    cell<-cell[!is.na(match(cell,class_set))]
    
    #the membership score is the proportion of predicitons for each class
    scores<-table(cell)/length(cell)
    #record in matrix
    membershipScore[i,as.character(names(scores))]<-as.numeric(scores)
  }
  
  
  #We need to give a certainty category, based on the value of scores
  #We also need a final classification, taking highest valued score. 
  certainty<-vector(length=dim(membershipScore)[1])
  classification<-vector(length=dim(membershipScore)[1])
  
  
  for(i in 1:length(certainty)){
    
    #This holds all the scores for the cell. Most of them will be zero.
    cell<-membershipScore[i,]
    #seek the class with the highest membership score
    find<-grep(max(as.numeric(cell[1:length(classes[,1])])),as.numeric(cell[1:length(classes[,1])]))
    
    if(sum(as.numeric(as.character(cell[1:length(classes[,1])]))==1)==1){
      #If there is one class with a score of 1, we have identified a single dominant identity: this is a core cell
      certainty[i]<-'Core'
    }else if(sum(as.numeric(as.character(cell[1:length(classes[,1])])))==0  || length(find)>1){
      #if no identities have any positive score (all identities are dominated), or if the maximal score is found in more than one column (multiple classes have the same score) then we cannot discern identity.
      certainty[i]<-'Failure'
    }else{
      #The remaining cases are those with non-zero scores that are less than one, and one class has the highest score. These cells have multiple identity ties
      certainty[i]<-'Intermediate'
    }
    
    if(certainty[i]=='Failure'){
      classification[i]<-"U"
    }else{
      classification[i]<-find
    }
  }
  
  membershipScore<- as.data.frame(cbind(membershipScore,certainty,classification))
  
  #returns the scores, along with the big prediction table (for me to check if the predictions look right)
  return(list(membershipScore,predictions))
  
  
}

