## Run the Viterbi algorithm using the estimated rates, the transition function and the emission function

 
suppressPackageStartupMessages({
  library(dplyr)
  library(ggplot2)
  })

## Collect the trailing command-line arguments and evaluate each one as R code
## at the top level of the script. The variables used further down
## (input_txt, chr, sid, ...) are not defined anywhere else in this file, so
## presumably every argument is an assignment such as name='value'.
## NOTE(review): eval(parse()) runs whatever code is passed on the command
## line -- only invoke this script with trusted arguments.
args <- commandArgs(trailingOnly = TRUE)

for (i in seq_along(args)) {
  eval(parse(text = args[[i]]))
}

## Echo every run parameter in a fixed order. A parameter that was not
## supplied makes the script fail here with an "object not found" error.
##   input_txt    : the obtained DP/AD data.frame such as in
##                  bulkDNAseq_BGI/F20FTSAPHT0350_MUSyfqR/chr_cos/92/92_chr1_dp5.tsv
##   chr, sid     : chromosome and sample id
##   use_post     : whether to use posterior estimates from the FVB sample
##   out_png      : the Viterbi path plot
##   transit_prob : the defined transition probability to use, how many cM per Mb
##   co_out       : the output txt for Viterbi path states
##   rate0, rate1 : gating cut-offs (described below)
##   rate0_dir, rate1_dir : path prefixes of the rate tables
for (param in c("input_txt", "chr", "sid", "use_post", "out_png",
                "transit_prob", "co_out", "rate0", "rate1",
                "rate0_dir", "rate1_dir")) {
  print(get(param))
}

## The numeric parameters may arrive as strings; coerce them once here.
## rate0: the cut off of success rate of observing a FVB count for genotype
##        1/1, like 0.9
## rate1: the cut off of success rate of observing a FVB count for genotype
##        0/1, like 0.6
rate0 <- as.numeric(rate0)
rate1 <- as.numeric(rate1)
transit_prob <- as.numeric(transit_prob)
# c("Chr","Pos","Ref","Alt","DP","AD","GT","Ref_count","Alt_count","total_counts","Rate_1"),

## estimated by bulk_BL6_FVB_rate0_estimate.R 

## "output/bulkDNAseq/haplotypecaller/20_HH5NTDRXX_clean_"
## Load the per-SNP rate tables of this chromosome into rate_0 and rate_1.
## Each file is found at <rateN_dir><chr><suffix> and parsed with an explicit
## compact column-type string.
read_rate_table <- function(path, types) {
  data.frame(readr::read_table2(file = path,
                                col_names = TRUE,
                                col_types = types))
}

if (use_post == "true") {
  ## posterior estimates: the map_e column is used as the rate
  rate_0 <- read_rate_table(paste0(rate0_dir, chr, "_postRate0.tsv"),
                            "cdccdccdddddddc")
  rate_0$Rate_0 <- rate_0$map_e
  rate_1 <- read_rate_table(paste0(rate1_dir, chr, "_postRate1.tsv"),
                            "ccdccddddddddd")
  rate_1$Rate_1 <- rate_1$map_e
} else {
  ## NOTE(review): Rate_0 is not assigned in this branch; presumably the
  ## *_rate0.tsv file already has a Rate_0 column -- confirm.
  rate_0 <- read_rate_table(paste0(rate0_dir, chr, "_rate0.tsv"),
                            "cnccdccdddd")
  rate_1 <- read_rate_table(paste0(rate1_dir, chr, "_rate1.tsv"),
                            "ccdccdddd")
  ## the AF_mean column is used as the rate for the 0/1 table
  rate_1$Rate_1 <- rate_1$AF_mean
}





## only keep sites that have passed the gated thresholds
## 

## Keep only the sites that pass the gating thresholds: Rate_0 strictly above
## rate0 and Rate_1 strictly below rate1.
## which() drops sites whose rate is NA. A bare logical index would turn every
## NA into an all-NA row, and the row-name assignment below would then fail on
## missing or duplicated names.
rate_0 <- rate_0[which(rate_0$Rate_0 > rate0), ]
rate_1 <- rate_1[which(rate_1$Rate_1 < rate1), ]

## Use the SNP id as row names, building it as "<Chr>_<Pos>" when the table
## does not already carry an SNP column.
## NOTE(review): as.character() on a numeric Pos can print round positions in
## scientific notation (e.g. "1e+05"); confirm the ids match the SNP column of
## the other table and the ids built for the sample calls.
if (is.null(rate_0$SNP)) {
  rate_0$SNP <- paste0(rate_0$Chr, "_", as.character(rate_0$Pos))
}
rownames(rate_0) <- rate_0$SNP

if (is.null(rate_1$SNP)) {
  rate_1$SNP <- paste0(rate_1$Chr, "_", as.character(rate_1$Pos))
}
rownames(rate_1) <- rate_1$SNP



## Read the sample's calls: a headerless table with the columns Pos, DP, AD
## and GT (two numeric columns followed by two character columns).
cell <- data.frame(readr::read_table2(file = paste0(input_txt),
                                      col_names = c("Pos", "DP", "AD", "GT"),
                                      col_types = c("ddcc")))

## remove the non-biallelic sites which must be error/noise: only genotypes
## 1/1 and 0/1 are kept. %in% never returns NA, so rows with a missing GT are
## dropped instead of becoming all-NA rows (which would later collide on the
## row name "<chr>_NA").
cell <- cell[cell$GT %in% c("1/1", "0/1"), ]

## ALT is the second comma-separated field of AD; sites whose AD has no comma
## keep an ALT count of 0. vapply() keeps the result a character vector even
## when no site has a comma, so ALT stays numeric throughout.
cell$ALT <- 0
has_alt <- grepl(",", cell$AD)
cell$ALT[has_alt] <- as.numeric(
  vapply(strsplit(cell$AD[has_alt], ","), `[[`, character(1), 2)
)

head(cell)
rownames(cell) <- paste0(chr, "_", cell$Pos)


## Sites usable for the HMM: present in both gated rate tables and in the
## sample calls. intersect() keeps the order of its first argument, so the
## result follows the row order of `cell`.
good_sites <- intersect(rate_1$SNP, rate_0$SNP)
good_sites <- intersect(rownames(cell), good_sites)

length(good_sites)

## Fail early with a clear message: with no shared site the Viterbi step
## further down would only fail with an obscure subscript error.
if (length(good_sites) == 0) {
  stop("no SNP is shared between the gated rate tables and the sample calls",
       call. = FALSE)
}

## ensure ordering of SNPs is correct
## The position is the text after the LAST underscore of the id, so the check
## also works when the chromosome name itself contains "_". Taking the second
## "_"-separated field would yield NA positions there and the check would
## pass without verifying anything.
site_pos <- as.numeric(sub("^.*_", "", good_sites))
stopifnot(identical(good_sites[order(site_pos, decreasing = FALSE)],
                    good_sites))

## Align all three tables on the same sites, in the same order.
cell <- cell[good_sites, ]
rate_0 <- rate_0[good_sites, ]
rate_1 <- rate_1[good_sites, ]



## Positions of the retained SNPs, in table order.
pos <- cell$Pos

## Switch probability for each gap between neighbouring SNPs:
## 1 - exp(-d * 1e-8) for a gap of d (for autosomes 1 cM per Mb), then scaled
## by transit_prob. rec has one entry fewer than there are SNPs.
snp_gaps <- diff(pos)
rec <- (1 - exp(-snp_gaps * 1e-8)) * transit_prob

## Observations as a plain 2 x N matrix without dimnames:
## row 1 = DP, row 2 = ALT, one column per SNP.
observes <- unname(t(as.matrix(cell[, c("DP", "ALT")])))

## i index for SNP along the chromosome
## log scale transition prob
# rec <- 1-exp(-diff(pos)*1e-8) # for autosomes 1 cM per Mb
## Off-diagonal indicator matrix: 1 where the state changes, 0 where it stays.
transbasis <- rbind(c(0, 1), c(1, 0))

## Log transition matrix for the step from SNP i to SNP i + 1.
## Reads the script-level vector `rec`: staying in a state has probability
## 1 - rec[i], switching has probability rec[i].
## Returns a 2 x 2 matrix of log probabilities (rows = previous state,
## columns = next state; the matrix is symmetric).
gettrans <- function(i) {
  p_switch <- rec[i]
  stay_part <- diag(2) * (1 - p_switch)
  switch_part <- transbasis * p_switch
  log(stay_part + switch_part)
}

# observes, rate_0, rate_1
# rates for each SNP

## Binomial log-likelihood of the observed ALT counts under each state.
## observes : matrix whose row 1 is the number of trials (size) and whose
##            row 2 is the number of successes (x), one column per SNP
## rate_0   : success probability per SNP for the first state
## rate_1   : success probability per SNP for the second state
## Returns a matrix with one row per SNP and the columns ems_0 and ems_1.
get_emission <- function(observes, rate_0, rate_1) {
  depth <- observes[1, ]
  alt_count <- observes[2, ]
  log_lik <- function(p) {
    dbinom(x = alt_count, size = depth, prob = p, log = TRUE)
  }
  cbind(ems_0 = log_lik(rate_0), ems_1 = log_lik(rate_1))
}

## Most likely hidden-state path (Viterbi), computed on the log scale.
## observes        : 2 x N matrix; only its number of columns (N, the number
##                   of SNPs) is used here
## init_prob       : initial state probabilities
## emission_matrix : N x S matrix of log emission probabilities
##                   (one row per SNP, one column per state)
## Transition log-probabilities come from the script-level gettrans(i), which
## is called for the step from SNP i to SNP i + 1.
## Returns a list with
##   states_seq : the inferred state index for every SNP
##   pathScores : exp() of the best log score per SNP and state
##   pathStates : back-pointers (best previous state per SNP and state)
run_viterbi <- function(observes, init_prob = c(.5, .5), emission_matrix) {
  ## number of steps/SNPs
  N <- ncol(observes)
  ## number of states
  S <- ncol(emission_matrix)
  if (N < 1) {
    stop("run_viterbi() needs at least one SNP", call. = FALSE)
  }

  ## initiate
  pathScore <- matrix(1, N, S)
  pathStates <- matrix(1, N, S)
  states_seq <- rep(0, N)

  observe_seq_prob <- t(emission_matrix)

  ## starting score on a log scale
  pathScore[1, ] <- log(init_prob) + observe_seq_prob[, 1]

  ## Steps 2..N. seq_len() yields an empty range when N == 1, whereas 2:N
  ## would count down (2, 1) and index past the end of the matrices.
  later_steps <- seq_len(N - 1L) + 1L

  for (step in later_steps) {
    ## belief[i, j] = score of previous state i + log P(i -> j)
    belief <- pathScore[step - 1, ] + gettrans(step - 1)

    ## for each current state, record which previous state gives the max
    ## score (back-pointer) and that score plus the emission
    pathStates[step, ] <- apply(belief, 2, which.max)
    pathScore[step, ] <- apply(belief, 2, max) + observe_seq_prob[, step]
  }

  ## infer most likely last state
  states_seq[N] <- which.max(pathScore[N, ])

  ## back trace: follow the back-pointers from the last SNP to the first
  for (step in rev(later_steps)) {
    states_seq[step - 1] <- pathStates[step, states_seq[step]]
  }

  ## scores are turned back into probabilities
  list(states_seq = states_seq, pathScores = exp(pathScore),
       pathStates = pathStates)
}

## An incorrect calculation would occur if an estimated rate is exactly 0 or
## 1, so a Rate_0 of exactly 1 is mapped to 0.999 and a Rate_1 of exactly 0 is
## mapped to 0.001.
mapp_rate0 <- ifelse(rate_0$Rate_0 == 1, 0.999, rate_0$Rate_0)
mapp_rate1 <- ifelse(rate_1$Rate_1 == 0, 0.001, rate_1$Rate_1)
head(mapp_rate0)
head(mapp_rate1)
## Preview at most the first five SNPs; a fixed 1:5 column index fails with
## "subscript out of bounds" when fewer than five SNPs were retained.
head(observes[1:2, seq_len(min(5, ncol(observes))), drop = FALSE])

emission_matrix <- get_emission(observes = observes, rate_0 = mapp_rate0,
                                rate_1 = mapp_rate1)

dim(emission_matrix)

paths <- run_viterbi(observes = observes, init_prob = c(0.5, 0.5),
                     emission_matrix = emission_matrix)
head(paths$states_seq)

## One row per retained SNP: the sample calls plus the inferred state, the
## sample id, the ALT/DP ratio and the (clamped) rates used for the emissions.
plt_df <- cell
plt_df$state <- as.factor(paths$states_seq)
plt_df$sid <- sid
plt_df$ALT_ratio <- plt_df$ALT / plt_df$DP
plt_df$Rate_0 <- mapp_rate0
plt_df$Rate_1 <- mapp_rate1

## Viterbi path along the chromosome, coloured by the ALT ratio. The plot is
## passed to ggsave() explicitly instead of relying on last_plot().
path_plot <- ggplot(data = plt_df) +
  geom_point(mapping = aes(x = Pos, y = state, color = ALT_ratio),
             size = 0.3) +
  scale_color_continuous(type = "viridis") +
  theme_classic() +
  xlab(chr) +
  ylab(paste0(sid, " State"))

ggsave(filename = out_png, plot = path_plot, dpi = 70, width = 14,
       height = 4)

write.table(plt_df, file = co_out, quote = FALSE, col.names = TRUE,
            row.names = FALSE)
