import os
import subprocess
import pickle as pkl
import sys

# Ensure mlpf is in the path for unpickling
sys.path.append(os.getcwd())

from mlpf.conf import MLPFConfig, ModelType

# Configuration
# Model variants; each name is a `{config}` wildcard value and selects the
# training flags in rule `train`.
CONFIGS = ["standard", "linear", "gnn-lsh-bin32", "gnn-lsh-bin128"]

# Devices to run ONNX validation on. Can be overridden from the command line
# instead of editing this file, e.g. `snakemake --config devices=cuda`
# (comma-separate several, or pass a list from a config file). `config` is the
# configuration dict Snakemake provides to every Snakefile.
_devices = config.get("devices", ["cuda", "cpu"])
if isinstance(_devices, str):
    _devices = [d.strip() for d in _devices.split(",") if d.strip()]
DEVICES = list(_devices)

MODEL_NAME = "pyg-cms-v1"  # passed to mlpf/pipeline.py --model-name
PRODUCTION = "cms_run3"  # passed to mlpf/pipeline.py --production
# NOTE(review): rule `train` declares checkpoint-{NUM_STEPS}.pth as its output,
# which presumably requires NUM_STEPS to be a multiple of CHECKPOINT_FREQ so
# that checkpoint is actually written — confirm against mlpf/pipeline.py.
NUM_STEPS = 2000
CHECKPOINT_FREQ = 2000

# Directory paths
EXPERIMENTS_DIR = "experiments"  # training outputs: training_{config}/...
ONNX_DIR = "onnx_validation"  # validation outputs: {device}/{config}/... and {device}/plots/

def get_onnx_validate_extra_args(wildcards, input=None):
    """Return extra CLI flags for scripts/cms-validate-onnx.py for one config.

    For ``gnn-lsh*`` configs the trained model's ``model_kwargs.pkl`` is read
    and, if it describes a GNN-LSH model, ``" --pad-bin-size <bin_size>"`` is
    returned. Every other config gets no extra flags.

    Snakemake passes ``input`` automatically to params functions that declare
    it. Reading ``input.model_kwargs`` keeps this function pointed at exactly
    the file the rule declares as input. When called with wildcards only, the
    path is rebuilt from ``EXPERIMENTS_DIR``, as before.

    Args:
        wildcards: Snakemake wildcards; ``wildcards.config`` is used.
        input: Snakemake input namespace with a ``model_kwargs`` entry, or
            None. (The name ``input`` is required by Snakemake.)

    Returns:
        A string of extra flags with a leading space, or ``""`` when none
        apply. This includes the case where ``model_kwargs.pkl`` does not
        exist yet.
    """
    if not wildcards.config.startswith("gnn-lsh"):
        return ""

    if input is not None:
        model_kwargs_path = input.model_kwargs
    else:
        model_kwargs_path = os.path.join(EXPERIMENTS_DIR, f"training_{wildcards.config}", "model_kwargs.pkl")

    # Nothing to read yet (e.g. training has not produced the file): no flag.
    if not os.path.exists(model_kwargs_path):
        return ""

    # NOTE: unpickling can execute arbitrary code; this is acceptable only
    # because the file is produced locally by this workflow's training step.
    with open(model_kwargs_path, "rb") as f:
        config_raw = pkl.load(f)

    # Accept either a plain dict or a pydantic model instance. The name
    # `mlpf_config` avoids shadowing Snakemake's global `config` dict.
    raw = config_raw if isinstance(config_raw, dict) else config_raw.model_dump()
    mlpf_config = MLPFConfig.model_validate(raw)

    if mlpf_config.model.type == ModelType.GNN_LSH:
        return f" --pad-bin-size {mlpf_config.model.gnn_lsh.bin_size}"
    return ""

# Helper to retrieve DATA_DIR from the spec file
def get_data_dir(production=None):
    """Return the tfds data directory for a production, read from the spec file.

    Runs ``scripts/get_param.py`` with ``PF_SITE=local`` to look up
    ``productions.<production>.workspace_dir`` in particleflow_spec.yaml, then
    appends ``tfds/``. Falls back to ``data/cms/tfds/`` when the lookup fails
    or prints nothing.

    Args:
        production: Production key in the spec. Defaults to ``PRODUCTION``,
            so the data directory matches the ``--production`` passed to
            training.

    Returns:
        The data directory path as a string.
    """
    if production is None:
        production = PRODUCTION
    fallback = "data/cms/tfds/"
    cmd = [
        "python3",
        "scripts/get_param.py",
        "particleflow_spec.yaml",
        f"productions.{production}.workspace_dir",
    ]
    try:
        # An argument list plus an explicit environment means no shell parsing
        # or quoting is involved.
        output = subprocess.check_output(cmd, env={**os.environ, "PF_SITE": "local"})
        workspace_dir = output.decode().strip()
    except (OSError, subprocess.CalledProcessError, UnicodeDecodeError) as e:
        # stderr, not stdout: this runs at Snakefile parse time, and stdout
        # output would corrupt commands such as `snakemake --dag`.
        print(f"Error retrieving DATA_DIR: {e}", file=sys.stderr)
        return fallback
    if not workspace_dir:
        # An empty answer would otherwise yield the bare relative path "tfds/".
        print("Error retrieving DATA_DIR: empty workspace_dir", file=sys.stderr)
        return fallback
    return os.path.join(workspace_dir, "tfds/")


DATA_DIR = get_data_dir()

# Default target: for every (device, config) pair in DEVICES x CONFIGS, the
# fp32 ONNX model produced by rule `onnx_validate`; plus, per device, the
# runtime-scaling summary plot produced by rule `plot_onnx_summary`.
rule all:
    input:
        expand(os.path.join(ONNX_DIR, "{device}", "{config}", "model_math_fp32.onnx"), device=DEVICES, config=CONFIGS),
        expand(os.path.join(ONNX_DIR, "{device}", "plots", "runtime_scaling_summary.png"), device=DEVICES)

# Train one model variant ({config}) for NUM_STEPS steps with mlpf/pipeline.py,
# writing into EXPERIMENTS_DIR/training_{config}. Training does not depend on
# {device}; both devices' validation jobs reuse the same trained model.
rule train:
    # NOTE(review): these paths are presumably written by pipeline.py under
    # --experiment-dir; the checkpoint name assumes a checkpoint is saved at
    # exactly step NUM_STEPS — confirm.
    output:
        checkpoint = os.path.join(EXPERIMENTS_DIR, "training_{config}", "checkpoints", f"checkpoint-{NUM_STEPS}.pth"),
        model_kwargs = os.path.join(EXPERIMENTS_DIR, "training_{config}", "model_kwargs.pkl")
    # Per-variant model flags. The GNN-LSH variants also pad inputs to a
    # multiple of their LSH bin size (the --pad_to_multiple_elements value
    # equals --model.gnn_lsh.bin_size).
    params:
        extra_args = lambda wildcards: {
            "standard": "--model.attention.attention_type flash",
            "linear": "--model.attention.attention_type linear",
            "gnn-lsh-bin32": "--model.type gnn_lsh --model.gnn_lsh.num_convs 3 --model.gnn_lsh.bin_size 32 --pad_to_multiple_elements 32",
            "gnn-lsh-bin128": "--model.type gnn_lsh --model.gnn_lsh.num_convs 3 --model.gnn_lsh.bin_size 128 --pad_to_multiple_elements 128"
        }[wildcards.config]
    # Train/test/valid are each limited to 1000 samples, and checkpoints are
    # saved every CHECKPOINT_FREQ steps.
    shell:
        """
        PF_SITE=local ./scripts/local/wrapper.sh python3 mlpf/pipeline.py \
            --spec-file particleflow_spec.yaml \
            --model-name {MODEL_NAME} \
            --production {PRODUCTION} \
            --data-dir {DATA_DIR} \
            --experiment-dir {EXPERIMENTS_DIR}/training_{wildcards.config} \
            train \
            --gpu_batch_multiplier 6 \
            --checkpoint_freq {CHECKPOINT_FREQ} \
            --ntrain 1000 --ntest 1000 --nvalid 1000 \
            --num_steps {NUM_STEPS} \
            {params.extra_args}
        """

# Export and validate the trained checkpoint for one (device, config) with
# scripts/cms-validate-onnx.py on the cms_pf_ttbar dataset.
rule onnx_validate:
    # Both inputs are the outputs of rule `train` for the same {config}.
    input:
        checkpoint = os.path.join(EXPERIMENTS_DIR, "training_{config}", "checkpoints", f"checkpoint-{NUM_STEPS}.pth"),
        model_kwargs = os.path.join(EXPERIMENTS_DIR, "training_{config}", "model_kwargs.pkl")
    # summary.json is what rule `plot_onnx_summary` aggregates. Both files live
    # in the same directory that is passed to the script as --outdir.
    output:
        onnx_model = os.path.join(ONNX_DIR, "{device}", "{config}", "model_math_fp32.onnx"),
        summary = os.path.join(ONNX_DIR, "{device}", "{config}", "summary.json")
    # configs_args: for the linear and GNN-LSH variants, restrict validation to
    #   the PT/ONNX MATH fp32/fp16 configurations; "standard" passes no
    #   --configs and so uses the script's default set.
    # extra_args: " --pad-bin-size <bin_size>" for GNN-LSH models, read from
    #   model_kwargs.pkl by get_onnx_validate_extra_args ("" otherwise).
    # num_events: 100 events on cuda, 10 on cpu.
    params:
        configs_args = lambda wildcards: "--configs PT_MATH_FP32 PT_MATH_FP16 ONNX_MATH_FP32 ONNX_MATH_FP16" if (wildcards.config == "linear" or wildcards.config.startswith("gnn-lsh")) else "",
        extra_args = get_onnx_validate_extra_args,
        num_events = lambda wildcards: 100 if wildcards.device == "cuda" else 10
    shell:
        """
        PF_SITE=local ./scripts/local/wrapper.sh python3 scripts/cms-validate-onnx.py \
            --checkpoint {input.checkpoint} \
            --model-kwargs {input.model_kwargs} \
            --dataset cms_pf_ttbar \
            --data-dir {DATA_DIR} \
            --num-events {params.num_events} \
            --outdir {ONNX_DIR}/{wildcards.device}/{wildcards.config} \
            --device {wildcards.device} \
            {params.configs_args} \
            {params.extra_args}
        """

# Plot the per-config validation summaries for one device into
# ONNX_DIR/{device}/plots.
rule plot_onnx_summary:
    # One summary.json per entry in CONFIGS for this {device}.
    input:
        lambda wildcards: expand(os.path.join(ONNX_DIR, wildcards.device, "{config}", "summary.json"), config=CONFIGS)
    # Only this plot is tracked by Snakemake. NOTE(review): the script may
    # write additional files into the plots directory — confirm.
    output:
        os.path.join(ONNX_DIR, "{device}", "plots", "runtime_scaling_summary.png")
    shell:
        """
        PF_SITE=local ./scripts/local/wrapper.sh python3 scripts/plot-onnx-summary.py --indir {ONNX_DIR}/{wildcards.device} --outdir {ONNX_DIR}/{wildcards.device}/plots
        """
