from snakemake.utils import R
from os.path import join
from os import listdir
import yaml
import os
import re
import pandas as pd
import numpy as np


workpath="."

# groups.tab is read as a headerless, tab-separated table; its four columns
# are named right after loading.
data = pd.read_csv("groups.tab", sep='\t', low_memory=False, encoding='unicode_escape', header=None)
data.columns = ["sample","group","name","type"]

# Every entry of the "sample" column becomes a cellrangerOut/<sample> target.
toRun = data["sample"].tolist()

specie = "mm10"

print(toRun)

# Table column -> pattern looked for in the "type" column. Series.str.contains
# treats the pattern as a regular expression and matches anywhere in the value.
assayPatterns = {
	"rna": "gex",
	"cite": "cite_seq",
	"tcr": "VDJ",
	"atac": "atac",
	"hash": "hashing",
}
# Only these columns, in this order, make up the pipeline label. "hash" is
# counted and kept in the table but is not part of the label.
pipelineAssays = ["rna","cite","tcr","atac"]

# Count, for every distinct "name", how many rows of each assay type it has.
# Rows are collected in a list and turned into a DataFrame once, because
# DataFrame.append was removed in pandas 2.0.
records = []
for sample in data['name'].unique():
	tempData = data.loc[data['name'] == sample]
	# na=False: a missing "type" value counts as "no match" instead of
	# producing a NaN in the mask.
	counts = {
		column: int(tempData["type"].str.contains(pattern, na=False).sum())
		for column, pattern in assayPatterns.items()
	}
	# Label = names of the assays with at least one row, joined by "_",
	# e.g. "rna_cite". It is empty when none of the four assays is present.
	pipeline = "_".join(assay for assay in pipelineAssays if counts[assay] > 0)
	# str() keeps the later string concatenation valid for non-string names.
	records.append({"Sample": str(sample), **counts, "pipeline": pipeline})

seqPerSample = pd.DataFrame(records, columns=["Sample"] + list(assayPatterns) + ["pipeline"])

# Key used as the {name} wildcard by the rules: "<Sample>__<pipeline>".
seqPerSample["sampPipe"] = seqPerSample["Sample"] + "__" + seqPerSample["pipeline"]

samples = seqPerSample["sampPipe"]

# Map each key back to its parts. The parts are taken from the table columns
# rather than by splitting the key, so a sample name that itself contains
# "__" is still resolved correctly.
d = {
	key: {"sample": name, "pipeline": pipeline}
	for key, name, pipeline in zip(samples, seqPerSample["Sample"], seqPerSample["pipeline"])
}

print(samples)



# Default target: requests the Cell Ranger output of every entry in toRun plus
# the flag file and the .rds file of every key in d.
rule all:
	input:
		# one cellrangerOut/<sample> path per entry of toRun
		expand(
			join(workpath, "cellrangerOut", "{mySample}"),
			mySample=toRun,
		),
		# one flag file per key of d
		expand(
			join(workpath, "flags", "{name}.txt"),
			name=list(d),
		),
		# one .rds file per key of d
		expand(
			join(workpath, "seuratOut", "{name}.rds"),
			name=list(d),
		)
	params:
		batch='--time=12:00:00',

# Runs workflow/scripts/runSeurat.py for one key of d (the {name} wildcard,
# "<sample>__<pipeline>") and then writes the flag file.
#   input  : the cellrangerOut path of every entry in toRun, so the job waits
#            for all of them, not only those belonging to this sample.
#   output : flags/{name}.txt and seuratOut/{name}.rds
#   params : sample / pipeline are looked up in d from the wildcard; org is
#            the module-level specie value.
rule seurat:
	input:
		expand(join(workpath,"cellrangerOut","{mySample}"), mySample=toRun)
	output:
		flag=join(workpath,"flags","{name}.txt"),
		rds=join(workpath,"seuratOut","{name}.rds")
	params:
		sample=lambda wildcards: d[wildcards.name]["sample"],
		pipeline=lambda wildcards: d[wildcards.name]["pipeline"],
		org=specie,
		workpath=workpath
	# Every substituted value is shell-quoted with ":q" so that an empty
	# pipeline label or a value containing spaces stays a single positional
	# argument of runSeurat.py. Only the declared flag file is touched.
	shell: """
mkdir -p {params.workpath:q}/flags
module load R/4.2.0;
module load python;
python workflow/scripts/runSeurat.py {params.sample:q} {params.workpath:q} {params.pipeline:q} {params.org:q} {output.rds:q}
touch {output.flag:q}
	"""
