#!/usr/bin/env Rscript

#----------------------------------------
# Script to run seGMM
#----------------------------------------

library(stringr)

# get args from command line
# Expected positional arguments:
#   1. downsampled counts table
#   2. full counts table
#   3. BeXY state posteriors for the full data (one column per sample, named "s_<sample>")
#   4. prefix for the output file
args <- commandArgs(trailingOnly = TRUE)
if (length(args) < 4) {
  # Without this check the missing entries would silently become NA file names.
  stop("Expected 4 arguments: <downsampled_counts> <full_counts> <bexy_statePos_s_full> <prefixOut>",
       call. = FALSE)
}
filename_downsampled_counts <- args[1]
filename_full_counts <- args[2]
filename_bexy_statePos_s_full <- args[3]
prefixOut <- args[4]

# Read full counts
n_full <- read.table(filename_full_counts, header = TRUE, check.names = FALSE)
samples_full <- n_full$individual
# Drop the non-count columns so that only per-chromosome counts remain
n_full$individual <- NULL
n_full$sequencing_type <- NULL

# Read sex classification by BeXY
threshold_certainty <- 0.9

statePos_s_full <- read.table(filename_bexy_statePos_s_full, header = TRUE, check.names = FALSE)
# A sample counts as "certain" if any state has posterior above the threshold.
# Work column by column (vapply) rather than apply(), which would first coerce
# the whole data frame to a matrix.
certain <- vapply(statePos_s_full,
                  function(posterior) any(posterior > threshold_certainty),
                  logical(1))
# Zero-based index of the most probable state per sample (0 and 1 are used below as XY and XX)
max_s <- as.numeric(vapply(statePos_s_full, which.max, integer(1)) - 1)

# Make sure s and n match in samples
# Strip only the leading "s_" prefix; splitting on every "s_" would truncate
# sample names that themselves contain "s_".
samplesFromS <- sub("^s_", "", names(statePos_s_full))
if (length(samplesFromS) != length(samples_full) ||
    any(samplesFromS != as.character(samples_full))) {
  stop("Mismatch in samples!")
}

# Reference sets: confidently classified XY and XX samples
counts_XY <- n_full[max_s == 0 & certain,]
counts_XX <- n_full[max_s == 1 & certain,]

# Normalize:
# Xmap/Ymap was computed as the fraction of high-quality reads (mapq > 30) that mapped to the X/Y chromosome divided by the total number of high-quality reads that mapped to the genome using the samtools algorithm
# Here: X (or Y) count of each reference sample divided by the sum over all count columns of that sample.

total_XY <- rowSums(counts_XY)
total_XX <- rowSums(counts_XX)

Xmap_XY <- counts_XY$X / total_XY
Xmap_XX <- counts_XX$X / total_XX
Ymap_XY <- counts_XY$Y / total_XY
Ymap_XX <- counts_XX$Y / total_XX

# sd() needs at least two observations; with fewer, the gate statistics below
# become NA/NaN and the classification fails later with an uninformative error.
if (nrow(counts_XY) < 2 || nrow(counts_XX) < 2) {
  warning("Fewer than two confidently classified XY or XX reference samples; ",
          "seGMM gate statistics will be NA", call. = FALSE)
}

# Calculate mean and sd per sex
# (m = XY reference samples, f = XX reference samples)

mean_m_xmap <- mean(Xmap_XY)
sd_m_xmap <- sd(Xmap_XY)

mean_f_xmap <- mean(Xmap_XX)
sd_f_xmap <- sd(Xmap_XX)

mean_m_ymap <- mean(Ymap_XY)
sd_m_ymap <- sd(Ymap_XY)

mean_f_ymap <- mean(Ymap_XX)
sd_f_ymap <- sd(Ymap_XX)

# Define gates
# Code from https://github.com/liusihan/seGMM/blob/main/code/script/seGMM.r
#
# Classify a single sample by comparing its X and Y read fractions against
# gates built from the reference means and standard deviations.
#
# Arguments:
#   x, y                    scalar X and Y read fractions of the sample
#   mean_m_xmap, sd_m_xmap  mean / sd of the X fraction in the XY reference
#   mean_f_xmap, sd_f_xmap  mean / sd of the X fraction in the XX reference
#   mean_m_ymap, sd_m_ymap  mean / sd of the Y fraction in the XY reference
#   mean_f_ymap, sd_f_ymap  mean / sd of the Y fraction in the XX reference
#
# Returns:
#   0 = XY, 1 = XX, 2 = X, 3 = XXY, 4 = XYY, 5 = XXX;
#   NA if no gate matches or if x or y is missing (e.g. 0/0 for a sample
#   without any reads).
#
# Errors:
#   stops with "karyotype = Y" when the sample falls into the Y-only gate.
getKaryotype <- function(x, y,
                         mean_m_xmap, sd_m_xmap, 
                         mean_f_xmap, sd_f_xmap, 
                         mean_m_ymap, sd_m_ymap, 
                         mean_f_ymap, sd_f_ymap){
  # A sample without usable fractions cannot be gated: report it as unclassified
  # instead of failing inside if() with "missing value where TRUE/FALSE needed".
  if (is.na(x) || is.na(y)) {
    return(NA)
  }

  # TRUE if value lies strictly inside mu +/- 3 * sigma
  within_3sd <- function(value, mu, sigma) {
    value > (mu - 3 * sigma) && value < (mu + 3 * sigma)
  }

  x_like_male <- within_3sd(x, mean_m_xmap, sd_m_xmap)
  y_like_male <- within_3sd(y, mean_m_ymap, sd_m_ymap)
  x_like_female <- within_3sd(x, mean_f_xmap, sd_f_xmap)
  y_like_female <- within_3sd(y, mean_f_ymap, sd_f_ymap)

  # The gates are tested in this fixed order; the first match wins.
  if (x_like_male && y_like_male){
    return(0) # XY
  } else if (x_like_male && y > (2 * mean_m_ymap)){
    return(4) # XYY
  } else if (x > (2 * mean_m_xmap) && y_like_male){
    return(3) # XXY
  } else if (x < (mean_m_xmap / 5) && y_like_male){
    # There is no return code for a Y-only karyotype, so this aborts the run.
    stop("karyotype = Y") # Y
  } else if (x_like_female && y_like_female){
    return(1) # XX
  } else if (y > (3 * mean_f_ymap) && x_like_female){
    return(3) # XXY
  } else if (x > (1.5 * mean_f_xmap) && y_like_female){ ##### <----- Note: I fixed a bug here, previously was x>(2*mean_f_xmap)
    return(5) # XXX
  } else if (x < (0.5 * mean_f_xmap) && y_like_female){
    return(2) # X
  } else{
    return(NA)
  }
}

# Now read downsampled counts
n_downsampled <- read.table(filename_downsampled_counts, header = TRUE, check.names = FALSE)
samples_downsampled <- n_downsampled$individual
# Drop the non-count columns so that only per-chromosome counts remain
n_downsampled$individual <- NULL
n_downsampled$sequencing_type <- NULL

# Run
# Classify every downsampled sample against the gates derived from the full data.

karyotypes <- numeric(nrow(n_downsampled))
# seq_len() instead of 1:nrow(): an empty table must give zero iterations, not c(1, 0)
for (i in seq_len(nrow(n_downsampled))){
  total_i <- sum(n_downsampled[i,])
  x <- n_downsampled$X[i] / total_i
  y <- n_downsampled$Y[i] / total_i
  karyotypes[i] <- getKaryotype(x, y, 
                                mean_m_xmap, sd_m_xmap, 
                                mean_f_xmap, sd_f_xmap, 
                                mean_m_ymap, sd_m_ymap, 
                                mean_f_ymap, sd_f_ymap)
}

# Write results
# NOTE(review): with row.names = FALSE the sample names set here are not written;
# the file holds one karyotype code per line, in input order, under a single header line.
names(karyotypes) <- samples_downsampled
write.table(karyotypes, file = paste0(prefixOut, "_seGMM.txt"), append = FALSE, quote = FALSE, sep = "\t", row.names = FALSE, col.names = TRUE)


