import os

# Pin the numerical backends (GotoBLAS, MKL, numexpr, OpenMP, OpenBLAS,
# vecLib) to a single thread each. Most job rules below request
# cpus_per_task=1, so this presumably keeps library-level threading from
# oversubscribing that allocation -- assuming the spawned jobs inherit these
# variables (TODO confirm for the executor in use).
_SINGLE_THREAD_ENV_VARS = (
    "GOTO_NUM_THREADS",
    "MKL_NUM_THREADS",
    "NUMEXPR_NUM_THREADS",
    "OMP_NUM_THREADS",
    "OPENBLAS_NUM_THREADS",
    "VECLIB_MAXIMUM_THREADS",
)
os.environ.update({name: "1" for name in _SINGLE_THREAD_ENV_VARS})

# Do not add the per-user site-packages directory to sys.path, presumably so
# user-installed packages cannot shadow the ones inside the container images.
os.environ["PYTHONNOUSERSITE"] = "1"

# Point every temp-dir and Apptainer/Singularity tmp/cache variable at the
# same scratch directory that the rules also pass as their `tmpdir` resource.
_SCRATCH_TMP = "/scratch/local/joosep/tmp"
_SCRATCH_ENV_VARS = (
    "TMPDIR",
    "TEMPDIR",
    "TEMP",
    "TMP",
    "APPTAINER_TMPDIR",
    "APPTAINER_CACHEDIR",
    "SINGULARITY_TMPDIR",
    "SINGULARITY_CACHEDIR",
)
os.environ.update({name: _SCRATCH_TMP for name in _SCRATCH_ENV_VARS})

rule all:
    """Default target: every gen/post marker, every tfds config and the trained model."""
    input:
        # Event generation: one marker per (sample, seed).
        expand("snakemake_jobs/clic/gen/gen_ttbar_{seed}.done", seed=range(300000, 305010)),
        expand("snakemake_jobs/clic/gen/gen_ww_fullhad_{seed}.done", seed=range(400000, 405010)),
        expand("snakemake_jobs/clic/gen/gen_qq_{seed}.done", seed=range(500000, 505010)),
        # Postprocessing of each generated seed.
        expand("snakemake_jobs/clic/post/post_ttbar_{seed}.done", seed=range(300000, 305010)),
        expand("snakemake_jobs/clic/post/post_ww_fullhad_{seed}.done", seed=range(400000, 405010)),
        expand("snakemake_jobs/clic/post/post_qq_{seed}.done", seed=range(500000, 505010)),
        # tfds conversion, configs 1..10 for every sample.
        expand("snakemake_jobs/clic/tfds/tfds_ttbar_tfds_{config}.done", config=range(1, 11)),
        expand("snakemake_jobs/clic/tfds/tfds_ww_fullhad_tfds_{config}.done", config=range(1, 11)),
        expand("snakemake_jobs/clic/tfds/tfds_qq_tfds_{config}.done", config=range(1, 11)),
        # Model training.
        "snakemake_jobs/clic/train/train_pyg-clic-v1_clic.done",

rule gen:
    """Aggregate target: all event-generation jobs for ttbar, ww_fullhad and qq."""
    input:
        expand("snakemake_jobs/clic/gen/gen_ttbar_{seed}.done", seed=range(300000, 305010)),
        expand("snakemake_jobs/clic/gen/gen_ww_fullhad_{seed}.done", seed=range(400000, 405010)),
        expand("snakemake_jobs/clic/gen/gen_qq_{seed}.done", seed=range(500000, 505010)),
    output:
        "snakemake_jobs/clic/gen/all.done"
    shell:
        "touch {output}"

rule gen_task:
    """Run gen_<sample>.sh for one seed inside the key4hep container, then touch the marker."""
    output:
        "snakemake_jobs/clic/gen/gen_{sample}_{seed}.done"
    resources:
        tmpdir="/scratch/local/joosep/tmp",
        mem_mb=4000,
        slurm_partition="main",
        runtime=120,
        slurm_account="hepusers",
        cpus_per_task=1,
        threads=1,
    container:
        "/cvmfs/unpacked.cern.ch/gitlab-registry.cern.ch/key4hep/k4-deploy/alma9:latest"
    shell:
        "snakemake_jobs/clic/gen/gen_{wildcards.sample}.sh {wildcards.seed} && touch {output}"

rule post_ttbar_all:
    """Barrier: every ttbar postprocessing job (seeds 300000-305009) has finished."""
    input:
        expand("snakemake_jobs/clic/post/post_ttbar_{seed}.done", seed=range(300000, 305010))
    output:
        "snakemake_jobs/clic/post/post_ttbar_all.done"
    shell:
        "touch {output}"

rule post_ww_fullhad_all:
    """Barrier: every ww_fullhad postprocessing job (seeds 400000-405009) has finished."""
    input:
        expand("snakemake_jobs/clic/post/post_ww_fullhad_{seed}.done", seed=range(400000, 405010))
    output:
        "snakemake_jobs/clic/post/post_ww_fullhad_all.done"
    shell:
        "touch {output}"

rule post_qq_all:
    """Barrier: every qq postprocessing job (seeds 500000-505009) has finished."""
    input:
        expand("snakemake_jobs/clic/post/post_qq_{seed}.done", seed=range(500000, 505010))
    output:
        "snakemake_jobs/clic/post/post_qq_all.done"
    shell:
        "touch {output}"

rule post:
    """Aggregate target: all postprocessing jobs for ttbar, ww_fullhad and qq."""
    input:
        expand("snakemake_jobs/clic/post/post_ttbar_{seed}.done", seed=range(300000, 305010)),
        expand("snakemake_jobs/clic/post/post_ww_fullhad_{seed}.done", seed=range(400000, 405010)),
        expand("snakemake_jobs/clic/post/post_qq_{seed}.done", seed=range(500000, 505010)),
    output:
        "snakemake_jobs/clic/post/all.done"
    shell:
        "touch {output}"

rule post_task:
    """Postprocess one generated seed: run post_<sample>.sh <seed>, then touch the marker.

    `seed` is restricted to digits. Without that constraint the output pattern
    also matches post_<sample>_all.done (with seed="all"). Those files are the
    outputs of the post_<sample>_all barrier rules, so the choice of producer
    would be left to Snakemake's ambiguity resolution.
    """
    input:
        "snakemake_jobs/clic/gen/gen_{sample}_{seed}.done"
    output:
        "snakemake_jobs/clic/post/post_{sample}_{seed}.done"
    wildcard_constraints:
        seed=r"\d+"
    resources:
        tmpdir="/scratch/local/joosep/tmp",
        mem_mb=1000,
        slurm_partition="main",
        runtime=120,
        slurm_account="hepusers",
        cpus_per_task=1,
        threads=1,
    container:
        "/scratch/persistent/joosep/singularity/pytorch-20260305-08d6950.sif"
    shell:
        "snakemake_jobs/clic/post/post_{wildcards.sample}.sh {wildcards.seed} && touch {output}"

rule tfds_task:
    """Build one tfds config for a sample once all of that sample's postprocessing is done."""
    input:
        "snakemake_jobs/clic/post/post_{sample}_all.done"
    output:
        "snakemake_jobs/clic/tfds/tfds_{sample}_tfds_{config}.done"
    resources:
        tmpdir="/scratch/local/joosep/tmp",
        mem_mb=16000,
        slurm_partition="main",
        runtime=120,
        slurm_account="hepusers",
        cpus_per_task=1,
        threads=1,
    container:
        "/scratch/persistent/joosep/singularity/pytorch-20260305-08d6950.sif"
    shell:
        "snakemake_jobs/clic/tfds/tfds_{wildcards.sample}.sh {wildcards.config} && touch {output}"

rule tfds:
    """Aggregate target: tfds configs 1..10 for ttbar, ww_fullhad and qq."""
    input:
        expand("snakemake_jobs/clic/tfds/tfds_ttbar_tfds_{config}.done", config=range(1, 11)),
        expand("snakemake_jobs/clic/tfds/tfds_ww_fullhad_tfds_{config}.done", config=range(1, 11)),
        expand("snakemake_jobs/clic/tfds/tfds_qq_tfds_{config}.done", config=range(1, 11)),
    output:
        "snakemake_jobs/clic/tfds/all.done"
    shell:
        "touch {output}"

rule tfds_hit_task:
    """Build one hit-level tfds config for a sample after its postprocessing is done.

    Not reachable from `all`: these outputs are built only when requested as
    explicit targets.
    """
    input:
        "snakemake_jobs/clic/post/post_{sample}_all.done"
    output:
        "snakemake_jobs/clic/tfds_hit/tfds_hit_{sample}_tfds_hit_{config}.done"
    resources:
        tmpdir="/scratch/local/joosep/tmp",
        mem_mb=16000,
        slurm_partition="main",
        runtime=120,
        slurm_account="hepusers",
        cpus_per_task=1,
        threads=1,
    container:
        "/scratch/persistent/joosep/singularity/pytorch-20260305-08d6950.sif"
    shell:
        "snakemake_jobs/clic/tfds_hit/tfds_hit_{wildcards.sample}.sh {wildcards.config} && touch {output}"

rule train_pyg_clic_v1:
    """Train pyg-clic-v1 on one A100 once every tfds conversion has finished."""
    input:
        "snakemake_jobs/clic/tfds/all.done"
    output:
        "snakemake_jobs/clic/train/train_pyg-clic-v1_clic.done"
    threads: 16
    # NOTE(review): both mem_mb and mem_per_gpu are set; confirm which one the
    # cluster executor actually honors.
    resources:
        tmpdir="/scratch/local/joosep/tmp",
        mem_mb=8000,
        mem_per_gpu=80000,
        slurm_partition="gpu",
        runtime=2880,
        slurm_account="hepusers",
        gres="gpu:a100:1",
        cpus_per_task=16,
        threads=16,
    container:
        "/scratch/persistent/joosep/singularity/pytorch-20260305-08d6950.sif"
    shell:
        "snakemake_jobs/clic/train/train_pyg-clic-v1_clic.sh && touch {output}"
