from snakemake.utils import R
from os.path import join
from os import listdir
import yaml
import os
import re
import pandas as pd
import numpy as np


# Root directory under which every rule builds its input and output paths.
workpath = "."

# Genome label forwarded to the R scripts through `params.specie`.
# It is a fixed value: inside a Snakefile, sys.argv holds Snakemake's own
# command-line arguments, so it cannot be used to pass the genome in.
specie = "hg38"

# Sample sheet: tab-separated, no header row; the four column names are
# assigned right after loading.
data = pd.read_csv(
    "groups.tab",
    sep="\t",
    header=None,
    encoding="unicode_escape",
    low_memory=False,
)
data.columns = ["sample", "group", "name", "type"]

# Keep only the rows whose type is "gex"; their sample values are the
# per-sample targets requested by the rules below.
data = data[data["type"].eq("gex")]
toRun = data["sample"].tolist()



# Each non-empty line of contrasts.tab becomes one {myGroups} value: its
# tab-separated fields are joined with "-".
contList = []
with open("contrasts.tab") as contrasts:
    for line in contrasts:
        # Drop the line terminator, including "\r" from CRLF files.
        line = line.rstrip("\r\n")
        # A blank line would otherwise yield an empty name and a bogus
        # "<dir>/.rds" target in rule all.
        if not line.strip():
            continue
        contList.append(line.replace("\t", "-"))


rule all:
	"""Aggregate target requesting every per-sample and per-contrast output file."""
	params:
		batch='--time=12:00:00',
	input:
		# One filtered matrix and one seuratOut .rds per entry of toRun.
		expand(join(workpath, "cellrangerOut", "{sample}", "outs", "filtered_feature_bc_matrix.h5"), sample=toRun),
		expand(join(workpath, "seuratOut", "{sample}.rds"), sample=toRun),
		# One .rds per integration output directory and per entry of contList;
		# expand() iterates `method` slowest, so files are listed directory by directory.
		expand(
			join(workpath, "integration", "{method}", "{contrast}.rds"),
			method=["merged", "seuratIntegrated", "RPCA", "harmonyGroup", "harmonySample"],
			contrast=contList
		)

rule runCR:
	"""Declare cellrangerOut/{name}/outs/filtered_feature_bc_matrix.h5 as a rule output.

	The rule has no input and no command attached; it only names the file.
	NOTE(review): presumably the file is produced outside this workflow -- confirm.
	"""
	output:
		join(workpath, "cellrangerOut", "{name}", "outs", "filtered_feature_bc_matrix.h5")

rule perSample:
	"""Run workflow/scripts/scRNA.R on one sample's filtered matrix.

	Script arguments, in order: the input .h5 file, the genome label
	(params.specie) and the output seuratOut/{name}.rds path. The command
	first activates the scRNA4 conda environment and loads the R/4.2.0 module.
	"""
	input:
		join(workpath, "cellrangerOut", "{name}", "outs", "filtered_feature_bc_matrix.h5")
	output:
		join(workpath, "seuratOut", "{name}.rds")
	params:
		specie=specie
	shell: """
. "/data/CCBR_Pipeliner/db/PipeDB/Conda/etc/profile.d/conda.sh"
#conda init scRNA4
conda activate scRNA4
module load R/4.2.0;
Rscript workflow/scripts/scRNA.R {input} {params.specie}  {output}
	"""

rule seuratIntegration:
	"""Run workflow/scripts/integrateBatches.R for one {myGroups} value.

	Depends on the seuratOut .rds of every sample in toRun and writes
	integration/seuratIntegrated/{myGroups}.rds.
	NOTE(review): the script receives the literal directory "seuratOut" rather
	than the input list or a path built from workpath -- keep them in sync.
	"""
	input:
		expand(join(workpath, "seuratOut", "{sample}.rds"), sample=toRun)
	output:
		join(workpath, "integration", "seuratIntegrated", "{myGroups}.rds")
	params:
		batch='--cpus-per-task=4 --mem=300g --time=24:00:00',
		specie = specie,
		# String params are formatted with the job's wildcards, so this
		# resolves to the current {myGroups} value.
		contrasts = "{myGroups}"
	shell: """
		. "/data/CCBR_Pipeliner/db/PipeDB/Conda/etc/profile.d/conda.sh"
		conda activate scRNA4
		module load R/4.2.0;
		Rscript workflow/scripts/integrateBatches.R "seuratOut" {output}  {params.specie} {params.contrasts} "0.1,0.2,0.3,0.5,0.6,0.8,1"
	"""


rule rpca:
	"""Run workflow/scripts/rpca.R for one {myGroups} value.

	Depends on the seuratOut .rds of every sample in toRun and writes
	integration/RPCA/{myGroups}.rds.
	NOTE(review): the script receives the literal directory seuratOut/ rather
	than the input list or a path built from workpath -- keep them in sync.
	"""
	input:
		expand(join(workpath, "seuratOut", "{sample}.rds"), sample=toRun)
	output:
		join(workpath, "integration", "RPCA", "{myGroups}.rds")
	params:
		batch='--cpus-per-task=4 --mem=300g --time=24:00:00',
		specie = specie,
		# String params are formatted with the job's wildcards, so this
		# resolves to the current {myGroups} value.
		contrasts = "{myGroups}"
	shell: """
		. "/data/CCBR_Pipeliner/db/PipeDB/Conda/etc/profile.d/conda.sh"
		conda activate scRNA4
		module load R/4.2.0;
		Rscript workflow/scripts/rpca.R seuratOut/ {output}  {params.specie} {params.contrasts} "0.1,0.2,0.3,0.5,0.6,0.8,1"
	"""

rule harmony:
	"""Run workflow/scripts/harmony.R for one {myGroups} value.

	Depends on the seuratOut .rds of every sample in toRun. One invocation
	writes three files named {myGroups}.rds, under integration/merged,
	integration/harmonyGroup and integration/harmonySample.
	NOTE(review): the script receives the literal directory seuratOut/ rather
	than the input list or a path built from workpath -- keep them in sync.
	"""
	input:
		expand(join(workpath, "seuratOut", "{sample}.rds"), sample=toRun)
	output:
		merged=join(workpath, "integration", "merged", "{myGroups}.rds"),
		harmonyGroup=join(workpath, "integration", "harmonyGroup", "{myGroups}.rds"),
		harmonySample=join(workpath, "integration", "harmonySample", "{myGroups}.rds")
	params:
		batch='--cpus-per-task=4 --mem=300g --time=24:00:00',
		specie = specie,
		# String params are formatted with the job's wildcards, so this
		# resolves to the current {myGroups} value.
		contrasts = "{myGroups}"
	shell: """
	. "/data/CCBR_Pipeliner/db/PipeDB/Conda/etc/profile.d/conda.sh"
	conda activate scRNA4
	module load R/4.2.0;
	Rscript workflow/scripts/harmony.R seuratOut/ {output.merged} {output.harmonyGroup} {output.harmonySample}  {params.specie} {params.contrasts} "0.1,0.2,0.3,0.5,0.6,0.8,1"
	"""
