## find low confidence CO for bulkDNA seq samples
## return the segment information for each chromosome
## author: Ruqian Lyu
## Date: 2020-10-19

suppressPackageStartupMessages({
  library(dplyr)
  library(tidyr)
  library(ggplot2)
  library(doParallel)
  library(foreach)
})

## Parse command-line parameters. Every trailing argument is evaluated as an R
## expression and is expected to assign one of the variables printed below
## (threads, chr, outTSV, sample_meta_file, filePath, bcfResult, transit_prob).
## NOTE(review): eval(parse()) executes whatever code is supplied on the
## command line; only run this script with trusted arguments.
args <- commandArgs(trailingOnly = TRUE)
for (i in seq_along(args)) {
  eval(parse(text = args[[i]]))
}

## Echo the received parameters; this fails early if any of them is missing.
print(threads)
print(chr)
print(outTSV)
#print(fvbRateFile)
#print(greyRateFile)
print(sample_meta_file)
#print(badSampleChr)
print(filePath)
print(bcfResult)
print(transit_prob)
## per 100 SNP in a bin
ncluster <- as.numeric(threads)
transit_prob <- as.numeric(transit_prob)


## Start the worker processes used by the %dopar% loop in get_seg_info().
cl <- makeCluster(ncluster)
registerDoParallel(cl)

## Read the sample metadata as a whitespace-delimited table first.
sampleNames <- read.table(sample_meta_file,
                          stringsAsFactors = FALSE,
                          header = TRUE)
## If values of the first column still contain commas, the file is
## comma-separated and has to be re-read as CSV. Test the column values
## (`[[1]]`), not the one-column data frame: grepl() on a data frame matches
## its deparsed form `c("a", "b")`, which contains a comma whenever there is
## more than one row.
if (any(grepl(",", sampleNames[[1]], fixed = TRUE))) {
  sampleNames <- read.csv(sample_meta_file,
                          stringsAsFactors = FALSE,
                          header = TRUE)
}
## Keep only the distinct sample identifiers.
sampleNames <- unique(sampleNames$sample_name)
if (is.null(sampleNames)) {
  stop("No 'sample_name' column found in ", sample_meta_file, call. = FALSE)
}

#badCells <- read.table(badSampleChr,stringsAsFactors = F,header = T)


## state 1 is FVB(AF=1)
# fvbRate <- read.table(file=fvbRateFile,
#                       stringsAsFactors = F,
#                       header =T)
# ## state 2 is HET
# greyRate <- read.table(file=greyRateFile,
#                       stringsAsFactors = F,
#                       header =T)

## For each sample, collapse consecutive SNPs that share the same inferred
## state into segments and compute a log-likelihood ratio per segment.
##
## Arguments:
##   sampleIDs    vector of sample identifiers; one input file is read per ID.
##   chrs         chromosome label used in the input file name and written to
##                the `Chr` column (defaults to the global `chr`).
##   file_path    directory prefix of the per-sample folders (defaults to the
##                global `filePath`); it is pasted as-is, so it must end with
##                a path separator.
##   bcfResult    file-name infix placed between `chrs` and "_dp2_postvi.tsv"
##                (defaults to the global `bcfResult`).
##   transit_prob numeric factor applied to the inter-SNP distance in the
##                transition terms.
##
## Each input file must provide the columns `Pos`, `state`, `ALT`, `DP`,
## `Rate_0` and `Rate_1`.
##
## Returns the per-sample segment tables combined with rbind(): one row per
## segment with its boundaries, SNP count, emission log-probabilities under the
## inferred and the swapped state, the transition terms and `logllRatio`.
get_seg_info <- function(sampleIDs,
                         chrs = chr,
                         file_path = filePath,
                         # `bcfResult = bcfResult` would be a recursive default
                         # ("promise already under evaluation"), so the global
                         # value is looked up explicitly instead.
                         bcfResult = get("bcfResult", envir = globalenv()),
                         transit_prob) {

  binded_co <- foreach(i_c = sampleIDs,
                       .combine = "rbind",
                       .packages = c("tidyr", "dplyr")) %dopar% {

    file_name <- paste0(file_path, i_c, "/", i_c, "_", chrs, bcfResult,
                        "_dp2_postvi.tsv")
    co_df <- read.table(file = file_name, header = TRUE)

    ## Summed binomial log-probability of the ALT counts of one segment.
    ## With type "original", state "1" is scored with rate0 and any other
    ## state with rate1; any other `type` swaps the two rates.
    calc_emission_change <- function(inferred_state,
                                     alt_count,
                                     total_dp,
                                     rate0,
                                     rate1, type) {
      use_rate0 <- (inferred_state == "1") == (type == "original")
      emission_rate <- if (use_rate0) rate0 else rate1
      sum(dbinom(x = alt_count,
                 size = total_dp,
                 prob = emission_rate,
                 log = TRUE))
    }

    ## Flag state changes between neighbouring SNPs and remember the
    ## neighbouring positions of every SNP.
    to_re <- co_df %>%
      dplyr::mutate(CO = (dplyr::lag(state) != state),
                    Prev = dplyr::lag(Pos),
                    Next = dplyr::lead(Pos))
    ## Running count of state changes = segment index. The first CO is NA
    ## (nothing precedes the first SNP), so it is replaced by 0. `CO[-1]`
    ## rather than `CO[2:n]` keeps this correct for a single-row table,
    ## where `2:1` would count downwards.
    to_re <- to_re %>%
      dplyr::mutate(cum_CO = c(0, cumsum(CO[-1])))

    ## One row per segment.
    to_re_t <- to_re %>%
      dplyr::group_by(cum_CO) %>%
      dplyr::summarise(
        seg_start = dplyr::first(Pos),
        seg_prev_snp = dplyr::first(Prev),
        seg_end = dplyr::last(Pos),
        seg_next_snp = dplyr::last(Next),
        seg_len = seg_end - seg_start + 1,
        seg_state = unique(state),
        SNP_count = length(Pos),
        emission_origin_logprob = calc_emission_change(
          inferred_state = seg_state,
          rate0 = Rate_0,
          rate1 = Rate_1,
          alt_count = ALT,
          total_dp = DP,
          type = "original"),
        emission_reverse_logprob = calc_emission_change(
          inferred_state = seg_state,
          rate0 = Rate_0,
          rate1 = Rate_1,
          alt_count = ALT,
          total_dp = DP,
          type = "reversed"),
        Chr = chrs,
        Sid = i_c)

    ## gap1/gap2 are the distances to the SNPs flanking the segment; they are
    ## NA at the chromosome ends, where fixed fallback values are used.
    ## logllRatio compares (original emission + "transition" terms at both
    ## borders) against (reversed emission + "no transition" terms).
    to_re_t %>%
      dplyr::mutate(
        gap1 = seg_start - seg_prev_snp,
        gap2 = seg_next_snp - seg_end,
        trans_yes_gap1 = ifelse(is.na(gap1),
                                log(0.5),
                                log(1-exp(-gap1*1e-8*transit_prob))),
        trans_yes_gap2 = ifelse(is.na(gap2), 0,
                                log(1-exp(-gap2*1e-8*transit_prob))),
        trans_no_gap1 = ifelse(is.na(gap1), log(0.5),
                               log(exp(-gap1*1e-8*transit_prob))),
        trans_no_gap2 = ifelse(is.na(gap2), 0,
                               log(exp(-gap2*1e-8*transit_prob))),
        logllRatio = (emission_origin_logprob+
                      trans_yes_gap1+trans_yes_gap2)-(emission_reverse_logprob+trans_no_gap1+trans_no_gap2))
  }

  binded_co
}
 
## Compute the segment table for all samples on the registered cluster. The
## worker processes are shut down as soon as the computation ends, whether it
## succeeded or raised an error (errors still propagate).
segments_posterio_ratio <- tryCatch(
  get_seg_info(sampleIDs = sampleNames,
               chrs = chr,
               file_path = filePath,
               bcfResult = bcfResult,
               transit_prob = transit_prob),
  finally = parallel::stopCluster(cl)
)
#ggplot(data=segments_posterio_ratio)
## Write the combined table with a header row, unquoted, without row names.
write.table(segments_posterio_ratio,
            file = outTSV,
            col.names = TRUE,
            quote = FALSE,
            row.names = FALSE)
