# Snakemake `localrules` directive: the plotting/aggregation rules listed here
# are run on the host executing Snakemake instead of being submitted as
# cluster/cloud jobs.
localrules: plot_timeseries, plot_timeseries_pdf, plot_slices, plot_slices_pdf, plot_slices_joined_pdf

def timestepsForPlots(stage):
    """Return the zero-padded time labels used for the slice plots of a stage.

    Parameters
    ----------
    stage : str
        Name of the simulation stage. ``"init"`` selects eleven evenly
        spaced times from 0 to 1e6; every other value selects a fixed,
        hand-picked list of fourteen times between 500 and 1e5.

    Returns
    -------
    list of str
        Each time formatted as an integer and left-padded with zeros to the
        number of digits of the largest time, so that all labels of a stage
        have the same length (e.g. ``"000500"`` ... ``"100000"``).
    """
    if stage == "init":
        times = [1e5 * step for step in range(11)]
    else:
        times = [
            1e3 * factor
            for factor in [0.5, 1, 3, 5, 15, 25, 35, 45, 55, 70, 80, 90, 95, 100]
        ]
    # The padding width is the same for every label, so compute it once;
    # taking the maximum keeps it correct even if the list order changes.
    width = len(str(int(max(times))))
    return [f"{t:.0f}".zfill(width) for t in times]


def get_all_param_funcs(wcs):
    """List every ``param_func`` name to plot for the project in ``wcs``.

    The process identifier is the first element returned by
    ``get_process_subtype(wcs.project)`` with the substring ``"sim"``
    removed. Names are added depending on which of the letters ``T``, ``H``
    and ``M`` occur in it (presumably thermal / hydraulic / mechanical --
    confirm against ``get_process_subtype``).

    Parameters
    ----------
    wcs
        Wildcards object providing ``project`` and ``rank``; ``rank`` is
        converted with ``int`` and used as the number of vector components.

    Returns
    -------
    list of str
        Names in the order T, H, M. Vector quantities get one entry per
        component (``rank`` entries), ``epsilon`` and ``sigma`` get
        ``2 * rank`` component entries.
    """
    process_letters = get_process_subtype(wcs.project)[0].replace("sim", "")
    n_dim = int(wcs.rank)
    vector_components = range(n_dim)
    tensor_components = range(2 * n_dim)

    names = []
    if "T" in process_letters:
        names.append("temperature")
    if "H" in process_letters:
        names.extend(
            ["pressure", "pressure_hydraulic-head", "velocity", "velocity_log"]
        )
        names.extend(f"velocity_{i}" for i in vector_components)
    if "M" in process_letters:
        names.append("displacement")
        names.extend(f"displacement_{i}" for i in vector_components)
        names.extend(["epsilon", "epsilon_trace"])
        names.extend(f"epsilon_{i}" for i in tensor_components)
        names.extend(
            [
                "sigma",
                "sigma_effective-pressure",
                "sigma_von-Mises-stress",
                "sigma_qp-ratio",
            ]
        )
        names.extend(f"sigma_{i}" for i in tensor_components)
    return names


# Create one PNG per {param_func} for a simulation run by calling
# plot_timeseries.py. Arguments passed, in order: output PNG, input PVD,
# rank, mesh_name, param_func, parameter table.
rule plot_timeseries:
    input:
        # The plotting script is declared as an input so the rule depends on it.
        plot_py="ogsworkflowhelper/post_processing/plot_timeseries.py",
        # Doubled braces ({{rank}} ...) survive the f-string as Snakemake
        # wildcards; MODEL and MESH are interpolated when the file is parsed.
        pvd=OUT_DIR + f"sim/{{rank}}D/{MODEL}_{MESH}/{{stage}}/{{project}}{{optional_parts}}/timeseries.pvd",
        csv=PARAM_TABLE
    output:
        png=PLOT_DIR + f"{{rank}}D/{MODEL}_{MESH}/{{stage}}/{{project}}{{optional_parts}}/temporal/{{param_func}}.png" 
    shell:
        # NOTE(review): {wildcards.mesh_name} does not appear literally in the
        # output pattern above; presumably MODEL/MESH contribute it -- confirm.
        "python {input.plot_py} {output.png} {input.pvd} {wildcards.rank} {wildcards.mesh_name} {wildcards.param_func} {input.csv}"


# Collect the time-series PNGs of all param_funcs of a run into one PDF by
# calling aggregate_pngs.py with the output PDF followed by the PNG paths.
rule plot_timeseries_pdf:
    input:
        aggregate_py="ogsworkflowhelper/post_processing/aggregate_pngs.py",
        # Input function: expands only {param_func} over the names returned by
        # get_all_param_funcs(wcs); allow_missing=True keeps the remaining
        # wildcards of the plot_timeseries output pattern unexpanded.
        pngs=lambda wcs: expand(rules.plot_timeseries.output.png, 
                                param_func=get_all_param_funcs(wcs), 
                                allow_missing=True)
    output:
        pdf = PLOT_DIR + f"{{rank}}D/{MODEL}_{MESH}/{{stage}}/{{project}}{{optional_parts}}/temporal/report.pdf"
    shell:
        "python {input.aggregate_py} {output.pdf} {input.pngs}"


# Create one spatial PNG for a given {param_func} and time label {t} by
# calling plot_slices.py. Arguments passed, in order: output PNG, input PVD,
# rank, mesh_name, param_func, parameter table, t and (optionally) repo.
rule plot_slices:
    input:
        # Unnamed input: not referenced in the shell command, only declared as
        # a dependency (presumably imported by plot_slices.py -- confirm).
        "ogsworkflowhelper/post_processing/SlicePlotter.py",
        plot_py="ogsworkflowhelper/post_processing/plot_slices.py",
        pvd=OUT_DIR + f"sim/{{rank}}D/{MODEL}_{MESH}/{{stage}}/{{project}}{{optional_parts}}/timeseries.pvd",
        # repo.vtu is requested only for the "glacialcycle" stage; otherwise
        # the empty list makes {input.repo} expand to nothing. ancient() marks
        # the file so that its modification time is ignored.
        repo=lambda wcs: ancient(MSH_DIR + f"{{rank}}D/{MODEL}_{MESH}/glacialcycle/{{project}}/repo.vtu") if wcs.stage == "glacialcycle" else [],
        csv=PARAM_TABLE
    output:
        png=PLOT_DIR + f"{{rank}}D/{MODEL}_{MESH}/{{stage}}/{{project}}{{optional_parts}}/spatial/{{param_func}}/plot_{{t}}.png"
    shell:
        # NOTE(review): {wildcards.mesh_name} does not appear literally in the
        # output pattern above; presumably MODEL/MESH contribute it -- confirm.
        "python {input.plot_py} {output.png} {input.pvd} {wildcards.rank} {wildcards.mesh_name} {wildcards.param_func} {input.csv} {wildcards.t} {input.repo}"
        
        
# Collect the slice PNGs of one {param_func} over all plot times of the stage
# into a single PDF by calling aggregate_pngs.py with the output PDF followed
# by the PNG paths.
rule plot_slices_pdf:
    input:
        aggregate_py="ogsworkflowhelper/post_processing/aggregate_pngs.py",
        # Input function: expands only {t} over timestepsForPlots(wcs.stage);
        # allow_missing=True keeps the remaining wildcards unexpanded. The
        # pattern is the same string as the plot_slices PNG output.
        pngs=lambda wcs: expand(PLOT_DIR + f"{{rank}}D/{MODEL}_{MESH}/{{stage}}/{{project}}{{optional_parts}}/spatial/{{param_func}}/plot_{{t}}.png", 
                                t=timestepsForPlots(wcs.stage), 
                                allow_missing=True) 
    output:
        pdf = PLOT_DIR + f"{{rank}}D/{MODEL}_{MESH}/{{stage}}/{{project}}{{optional_parts}}/spatial/{{param_func}}.pdf"
    shell:
        "python {input.aggregate_py} {output.pdf} {input.pngs}"


# Join the per-param_func slice PDFs of a run into one report by calling
# aggregate_pdfs.py with the output PDF followed by the input PDF paths.
rule plot_slices_joined_pdf:
    input:
        aggregate_py="ogsworkflowhelper/post_processing/aggregate_pdfs.py",
        # Input function: expands only {param_func} over the names returned by
        # get_all_param_funcs(wcs); allow_missing=True keeps the remaining
        # wildcards of the plot_slices_pdf output pattern unexpanded.
        pdfs = lambda wcs: expand(rules.plot_slices_pdf.output, 
                                  param_func=get_all_param_funcs(wcs), 
                                  allow_missing=True)
    output:
        # NOTE(review): this pattern spells the mesh wildcards out by hand,
        # while the input pattern is built from MODEL/MESH; presumably both
        # resolve to the same directory layout -- confirm.
        pdf = PLOT_DIR + "{rank}D/{mesh_name}_{mesh_type}-id_{layer_set_id}-xres_{x_res}/{stage}/{project}{optional_parts}/spatial/report.pdf"
    shell:
        "python {input.aggregate_py} {output.pdf} {input.pdfs}"
