# Rules of this file that are declared as local rules.
localrules:
    constitutive_relations_xml,
    media_xml,
    project_file,
    project_includes_optional,


def get_process_subtype(project):
    """Split a project name into its process part and its optional subtype.

    The name is split at the first "-" only, so the result is always a
    2-tuple and can safely be unpacked as ``process, subtype = ...``.

    Args:
        project: Project name, e.g. "simTH" or "simTH-freezing".

    Returns:
        Tuple ``(process, subtype)``; ``subtype`` is "" when the name
        contains no "-". Any further "-" stays part of ``subtype``.
    """
    # partition() never yields more than one split, unlike split("-"), which
    # returned a list of arbitrary length for names with several dashes.
    process, _, subtype = project.partition("-")
    return process, subtype


# Different time-stepping files are used for different models and processes:
# with salt models, the creep model requires small time steps at the beginning
# to converge successfully, whereas with the "Ton" models such small time steps
# would cause numerical instabilities.
def time_stepping(wcs):
    """Return the path of the time-stepping XML file for the given wildcards.

    A file named after the stage, the mesh name and the process (project
    name without subtype and without "sim") is preferred. If no such file
    exists on disk, the stage's default file is returned instead.
    """
    base_name = get_process_subtype(wcs.project)[0]
    process = base_name.replace("sim", "")
    stage = stage_type(wcs)
    ts_dir = "input/processes/timestepping"
    specific = f"{ts_dir}/{stage}_{wcs.mesh_name}_{process}.xml"
    if os.path.isfile(specific):
        return specific
    return f"{ts_dir}/{stage}_default.xml"


# returns a subtype-specific medium.xml file if it exists,
# otherwise the base medium.xml file for that process
def medium(wcs):
    """Return the path of the medium XML template for the given wildcards.

    The file named after the full project name (including any subtype) is
    used when it exists on disk; otherwise the file named after the process
    part of the project name is returned.
    """
    base_process = get_process_subtype(wcs.project)[0]
    medium_dir = "input/processes/medium"
    candidate = f"{medium_dir}/{wcs.project}_medium.xml"
    if os.path.isfile(candidate):
        return candidate
    return f"{medium_dir}/{base_process}_medium.xml"


def material_laws(wcs):
    """Return one constitutive-relation XML path per "rock_type" entry.

    The parameter and set table paths are obtained by filling the mesh name
    into PARAM_TABLE and SET_TABLE; csv2pandas (defined elsewhere) is then
    called with the layer set id, both paths and the process name. Each value
    of the "rock_type" column of its result is mapped to a material law.

    Raises:
        KeyError: If a rock type is not one of "-", "salt" or "clay".
    """
    param_table = PARAM_TABLE.format(mesh_name=wcs.mesh_name)
    set_table = SET_TABLE.format(mesh_name=wcs.mesh_name)
    # Process name: project without "sim" and without the "-subtype" suffix.
    process = wcs.project.replace("sim", "").split("-")[0]
    layers = csv2pandas(int(wcs.layer_set_id), param_table, set_table, process)
    law_by_rock_type = {"-": "Elastic", "salt": "PLLC", "clay": "Elastic"}
    relation_files = []
    for rock_type in layers["rock_type"]:
        law_name = law_by_rock_type[rock_type]
        relation_files.append(
            f"input/processes/material_laws/{law_name}_constitutive_relation.xml"
        )
    return relation_files


def material_law_params(wcs):
    """Return the inelastic-parameter XML files needed for the given wildcards.

    For every constitutive-relation file returned by material_laws() whose
    path does not contain "Elastic", the matching "inelastic_params" file
    path is derived by renaming. Duplicates are removed.

    Returns:
        Sorted list of unique file paths (empty if all laws are elastic).
    """
    params = {
        law.replace("constitutive_relation", "inelastic_params")
        for law in material_laws(wcs)
        if "Elastic" not in law
    }
    # Sort so the order is reproducible: iteration order of a set of strings
    # depends on hash randomization and may differ between runs.
    return sorted(params)


# returns a subtype-specific prj file if it exists,
# otherwise the base prj file for that process
def prj_template(wcs):
    """Return the path of the prj template for the given wildcards.

    The template named after the full project name (including subtype) is
    returned if it exists on disk, otherwise the one named after the process
    part only.

    Raises:
        Exception: If the subtype is not "", "freezing", "phaseless", "nopy"
            or "sig0".
    """
    process, subtype = get_process_subtype(wcs.project)
    stage = stage_type(wcs)
    template_dir = f"input/processes/{wcs.rank}D"
    fallback = f"{template_dir}/{process}_{stage}.prj"
    candidate = f"{template_dir}/{wcs.project}_{stage}.prj"
    known_subtypes = ("", "freezing", "phaseless", "nopy", "sig0")
    if subtype not in known_subtypes:
        raise Exception(f"Could not recognize the {subtype=}.")
    return candidate if os.path.isfile(candidate) else fallback


# Writes constitutive_relations.xml into PRJ_PATH by running the helper script
# write_constitutive_relations_xml.py.
rule constitutive_relations_xml:
    input:
        write_c_rels_py="ogsworkflowhelper/write_constitutive_relations_xml.py",
        csv=PARAM_TABLE,
        sets=SET_TABLE,
        # input function: one constitutive-relation XML path per rock type
        mat_laws=material_laws,
    output:
        xml=f"{PRJ_PATH}/constitutive_relations.xml",
    # Positional arguments passed to the script, in order: parameter table,
    # set table, project, layer set id, output file, then all material-law
    # files (mat_laws may expand to several paths).
    shell:
        """
        python {input.write_c_rels_py} {input.csv} {input.sets} {wildcards.project} {wildcards.layer_set_id} {output.xml} {input.mat_laws}
        """


# Writes media.xml into PRJ_PATH by running the helper script
# write_media_xml.py.
rule media_xml:
    input:
        write_media_py="ogsworkflowhelper/write_media_xml.py",
        csv=PARAM_TABLE,
        sets=SET_TABLE,
        # input function: medium template (subtype-specific or process default)
        medium=medium,
    output:
        xml=f"{PRJ_PATH}/media.xml",
    # Inputs are referenced by name rather than through a bare {input}, so
    # that adding or reordering inputs cannot silently change the positional
    # arguments of the script. Argument order: parameter table, set table,
    # medium template, project, layer set id, output file.
    shell:
        """
        python {input.write_media_py} {input.csv} {input.sets} {input.medium} {wildcards.project} {wildcards.layer_set_id} {output.xml}
        """


def optional_constitutive_relations_xml(wcs):
    """Return the constitutive_relations.xml path if the process contains "M".

    The process is the project name without "sim" and without the
    "-subtype" suffix. For processes without an "M", an empty list is
    returned.
    """
    process = wcs.project.replace("sim", "").split("-")[0]
    if "M" not in process:
        return []
    return f"{PRJ_PATH}/constitutive_relations.xml"


def optional_media_xml(wcs):
    """Return the media.xml path unless the process is exactly "M".

    The process is the project name without "sim" and without the
    "-subtype" suffix. For the "M" process, an empty list is returned.
    """
    process = wcs.project.replace("sim", "").split("-")[0]
    return [] if process == "M" else f"{PRJ_PATH}/media.xml"


# Collects the optional XML include files of a project (media.xml and/or
# constitutive_relations.xml, depending on the process) into xml_includes.txt.
rule project_includes_optional:
    input:
        # Both input functions return either a single path or an empty list.
        media_xml=optional_media_xml,
        c_rels=optional_constitutive_relations_xml,
    output:
        xml_list=f"{PRJ_PATH}/xml_includes.txt",
    run:
        # aggregate() is defined elsewhere; presumably it writes the given
        # input paths to the output file - TODO confirm.
        aggregate(input, output.xml_list)


# Creates the project file sim.prj and its boundary_conditions.py in PRJ_PATH.
rule project_file:
    input:
        csv2prj_py="ogsworkflowhelper/csv2prj.py",
        csv=PARAM_TABLE,
        sets=SET_TABLE,
        # input function: subtype-specific prj template or process default
        prj_template=prj_template,
        # boundary-conditions template chosen by the stage type
        py_boundary_conditions=lambda wcs: f"input/processes/boundary_conditions/boundary_conditions_{stage_type(wcs)}_template.py",
        lin_solvers="input/processes/linear_solvers.xml",
        # not referenced in the shell command; listed so that this rule
        # depends on the output of project_includes_optional
        includes=rules.project_includes_optional.output,
        time_stepping=time_stepping,
        # may be empty when only elastic material laws are used
        mat_law_params=material_law_params,
    output:
        prj=f"{PRJ_PATH}/sim.prj",
        py=f"{PRJ_PATH}/boundary_conditions.py",
    # First copies the boundary-conditions template to its output location,
    # then runs csv2prj.py with the tables, the prj template, the linear
    # solvers, the time stepping file and the material-law parameter files as
    # positional arguments.
    shell:
        """
        cp {input.py_boundary_conditions} {output.py}
        python {input.csv2prj_py} {input.csv} {input.sets} {input.prj_template} \
            {input.lin_solvers} {input.time_stepping} {input.mat_law_params} \
            --id {wildcards.layer_set_id} --type {wildcards.project} --out {output.prj}
        """
