import argparse
import pathlib
import xml.etree.ElementTree as ET
import lxml.etree as L_ET

import pandas as pd
from ogs6py import ogs


def get_parameter_dict(process):
    """Return the csv-column -> parameter-name mapping for ``process``.

    ``process`` must be one of 'H', 'T', 'M', 'HM', 'TM', 'TH' or 'THM';
    any other key raises ``KeyError``. Coupled processes receive the merged
    mappings of their building blocks, in the order listed below, so the
    key order of the result is deterministic.
    """
    hydraulic = {
        'hydraulic_permeability_[m^2]_mean': 'permeability',
        'Porosity_[-]_mean': 'porosity',
    }
    thermal = {
        'thermal_conductivity_[W/m/K]_mean': 'thermal_conductivity',
        'specific_isobar_heat_capacity_[J/kg/K]_mean':
            'specific_heat_capacity',
        'solid_mass_density_[kg/m^3]_mean': 'density',
        'Porosity_[-]_mean': 'porosity',
    }
    mechanical = {
        'solid_mass_density_[kg/m^3]_mean': 'density',
        'Young_modulus_[Pa]_mean': 'YoungModulus',
        'Poisson_ratio_[-]_mean': 'PoissonRatio',
        'internal_friction_[deg]_mean': 'FrictionAngle',
        'Porosity_[-]_mean': 'porosity',
    }
    hydro_mechanical = {'Biot_coefficient_[-]_mean': 'biot_coefficient'}
    thermo_mechanical = {
        'thermal_expansion_coefficient_[1/K]_mean': 'thermal_expansivity',
    }

    # Which partial mappings make up each supported process (merge order
    # matters for the resulting key order).
    building_blocks = {
        'H': (hydraulic,),
        'T': (thermal,),
        'M': (mechanical,),
        'HM': (hydraulic, mechanical, hydro_mechanical),
        'TM': (thermal, mechanical, thermo_mechanical),
        'TH': (thermal, hydraulic),
        'THM': (thermal, hydraulic, mechanical, hydro_mechanical,
                thermo_mechanical),
    }

    merged = {}
    for part in building_blocks[process]:
        merged.update(part)
    return merged


def csv2pandas(set_id: int, param_table, set_table, process):
    """Select and rename the material parameters of one model set.

    The set definitions are read from ``set_table`` and reduced to the rows
    whose ``set_id`` column equals ``set_id``. The ``material_id`` values of
    those rows (except the first one) select the rows of ``param_table`` by
    their ``layer_id``. The parameter columns relevant for ``process`` are
    renamed according to ``get_parameter_dict(process)``.

    Args:
        set_id: value looked up in the ``set_id`` column of ``set_table``.
        param_table: csv file (anything ``pandas.read_csv`` accepts) with
            the material parameters.
        set_table: csv file with the set definitions.
        process: process key understood by ``get_parameter_dict``.

    Returns:
        DataFrame with the columns ``layer_id``, ``model_unit``,
        ``rock_type``, ``H_deactivated`` followed by the renamed parameter
        columns, indexed 0..n-1 under the index name ``row_num``.

    Raises:
        ValueError: if ``set_table`` contains no row for ``set_id``.
        KeyError: if ``process`` is unknown or an expected column is missing.
    """
    param_dict = get_parameter_dict(process)

    # choose subset of parameter table according to the set definitions
    sets = pd.read_csv(set_table)
    sets = sets[sets["set_id"] == set_id]
    if sets.empty:
        raise ValueError(f'no model defined with {set_id=}')
    # NOTE(review): the first material id of the set is skipped on purpose
    # here (kept from the original code) - presumably it does not refer to
    # a layer of the parameter table; confirm against the set table layout.
    mat_ids = sets["material_id"].to_numpy()[1:]

    # filter and rename
    params = pd.read_csv(param_table)
    params = params[params['layer_id'].isin(mat_ids)]
    params = params.rename(param_dict, axis='columns')
    index = ["layer_id", "model_unit", "rock_type", "H_deactivated"]
    params = params[index + list(param_dict.values())]

    # Renumber the remaining rows 0..n-1. Building a fresh frame instead of
    # assigning a column onto the filtered slice avoids pandas'
    # SettingWithCopyWarning.
    return params.reset_index(drop=True).rename_axis('row_num')


def add_deactivated_subdomains_block(prj: ogs.OGS, process_var, id_list):
    """Add a deactivated-subdomain entry to one process variable of ``prj``.

    Below the ``process_variable`` named ``process_var`` a
    ``deactivated_subdomains`` block with one (initially empty)
    ``deactivated_subdomain`` child is created. That child receives a
    ``time_interval`` from "0" to "1e15" and a ``material_ids`` entry
    holding the space separated items of ``id_list``.
    """
    variable_xpath = (
        "./process_variables/process_variable" + f"[name='{process_var}']"
    )
    subdomain_xpath = (
        variable_xpath + "/deactivated_subdomains/deactivated_subdomain"
    )

    # container with a single empty <deactivated_subdomain/>
    prj.add_block("deactivated_subdomains", None, variable_xpath,
                  ["deactivated_subdomain"], [""])
    # time span during which the subdomain is deactivated
    prj.add_block("time_interval", None, subdomain_xpath,
                  ["start", "end"], ["0", "1e15"])
    # the affected material ids as one whitespace separated string
    material_ids = " ".join(str(mat_id) for mat_id in id_list)
    prj.add_entry(subdomain_xpath, "material_ids", material_ids)


def pandas2prj(param_df, template_prj, lin_solvers, time_stepping,
               output_prj, process, mat_law_params):
    """Create an OGS prj file from a template and a parameter table.

    Steps performed on the project loaded from ``template_prj``:

    * the root element of ``lin_solvers`` is inserted at index -1 of the
      project root (i.e. before its last child),
    * the root element of ``time_stepping`` becomes the first child of the
      first ``time_loop/processes/process`` element,
    * layers flagged in the ``H_deactivated`` column are deactivated for the
      ``pressure`` process variable,
    * unless ``process`` is "M", an ``<include file="media.xml"/>`` is added
      to ``./media``,
    * if ``process`` contains "M", an include of
      ``constitutive_relations.xml`` is added to the process, all parameters
      found in the ``mat_law_params`` files are copied to ``./parameters``
      and per-layer ``index_values`` are added for the mechanical
      parameters.

    Args:
        param_df: parameter table; needs the columns ``layer_id`` and
            ``H_deactivated`` (used as a boolean mask) and, for processes
            containing "M", ``YoungModulus``, ``PoissonRatio``, ``density``,
            ``FrictionAngle`` and ``porosity``.
        template_prj: path of the template prj file.
        lin_solvers: xml file holding the linear solvers block.
        time_stepping: xml file holding the time stepping block.
        output_prj: path of the prj file to write; its parent directory is
            created if necessary.
        process: process key, e.g. "H", "TM" or "THM".
        mat_law_params: iterable of xml files with material law parameters.

    Returns:
        ``output_prj`` as passed in.
    """
    pathlib.Path(output_prj).parent.mkdir(parents=True, exist_ok=True)
    main_prj = ogs.OGS(INPUT_FILE=template_prj, PROJECT_FILE=output_prj)

    lin_solvers_root = L_ET.parse(lin_solvers).getroot()
    main_prj._get_root().insert(-1, lin_solvers_root)

    time_stepping_root = L_ET.parse(time_stepping).getroot()
    xp_time_stepping = ".//time_loop/processes/process"
    main_prj._get_root().find(xp_time_stepping).insert(0, time_stepping_root)

    h_deactivated = param_df[param_df['H_deactivated']]["layer_id"].tolist()
    if h_deactivated:
        add_deactivated_subdomains_block(main_prj, "pressure", h_deactivated)

    # the M process does not require any media properties
    if process != "M":
        main_prj.add_entry("./media", "include", None, "file", "media.xml")

    # constitutive relation only required if the process includes M
    if "M" in process:
        # "./processes/process" is the well-formed spelling of the original
        # ".processes/process" (missing "/" after the leading dot).
        main_prj.add_entry("./processes/process", tag="include",
                           attrib="file",
                           attrib_value="constitutive_relations.xml")

        for mat_law_param_file in mat_law_params:
            for param in ET.parse(mat_law_param_file).getroot():
                # param.iter() yields the parameter element itself first;
                # only its descendants become entries of the new block.
                descendants = list(param.iter())[1:]
                main_prj.add_block(param.tag, None, "./parameters",
                                   [e.tag for e in descendants],
                                   [e.text for e in descendants])

        # Add index values for group based parameters
        # (group based parameters are not yet supported natively by ogs6py).
        # Each pair is (column of param_df, name of the prj parameter).
        group_parameters = (
            ('YoungModulus', 'YoungModulus'),
            ('PoissonRatio', 'PoissonRatio'),
            ('density', 'density'),
            ('FrictionAngle', 'CriticalStateLineSlope'),
            ('porosity', 'InitialPorosity'),
        )
        for column, pname in group_parameters:
            xp_parameter = f"./parameters/parameter[name='{pname}']"
            for value, layer_id in zip(param_df[column],
                                       param_df["layer_id"]):
                if column == 'FrictionAngle':
                    # critical state line slope M from the friction angle:
                    # value / 25 ~= 6*sin(value) / (3-sin(value))
                    value = value / 25
                text = f"<index>{layer_id}</index><value>{value}</value>"
                main_prj.add_entry(xp_parameter, tag="index_values",
                                   text=text)

    main_prj.write_input()

    # ogs6py codes "<" as "&lt;" and ">" as  "&gt;",
    # although XML readers should accept this, we replace it for sake of beauty
    with open(output_prj, 'r') as main_file:
        main_filedata = main_file.read()
    main_filedata = main_filedata.replace('&lt;', '<').replace('&gt;', '>')
    with open(output_prj, 'w') as main_file:
        main_file.write(main_filedata)
    return output_prj


def csv2prj(param_table, set_table, template_prj, lin_solvers, time_stepping,
            mat_law_params=None, set_id=0, project="", output_prj=""):
    """Generate an OGS prj file from csv parameter and set tables.

    The process key is derived from ``project`` by removing "sim" and taking
    everything before the first "-", e.g. "simTHM-subtype" -> "THM".

    Args:
        param_table: material parameter csv.
        set_table: set definitions csv.
        template_prj: template OGS prj file.
        lin_solvers: xml file with the linear solvers block.
        time_stepping: xml file with the time stepping block.
        mat_law_params: optional iterable of material law parameter xml
            files; ``None`` (the default) means no such files.
        set_id: id of the model set to extract from ``set_table``.
        project: project name, e.g. "simTHM[-subtype]".
        output_prj: path of the prj file to write.

    Returns:
        The path of the written prj file (``output_prj``).
    """
    # None instead of a mutable ``[]`` default; behaves the same for callers.
    if mat_law_params is None:
        mat_law_params = []
    # f-string so that path-like objects can be printed as well
    print(f"csv2prj: {template_prj}")
    process = project.replace("sim", "").split("-")[0]
    param_df = csv2pandas(set_id, param_table, set_table, process)
    return pandas2prj(param_df, template_prj, lin_solvers, time_stepping,
                      output_prj, process, mat_law_params)


if __name__ == '__main__':
    # Command line entry point: take the rows of the chosen set from the
    # csv tables and write an OGS prj file based on the template prj.
    parser = argparse.ArgumentParser(
        description="Generate an OGS prj file from csv parameter and set "
                    "tables, based on a template prj file.")
    parser.add_argument("param_table", help="material parameter csv")
    parser.add_argument("set_table", help="set definitions csv")
    parser.add_argument("template_prj", help="template OGS prj")
    parser.add_argument("lin_solvers", help="linear_solvers xml")
    parser.add_argument("time_stepping", help="timestepping xml")
    parser.add_argument("ml_params", help="material law params xml", nargs="*")
    # Same default as csv2prj; an omitted option would otherwise arrive as
    # None and no set would ever match.
    parser.add_argument("--id", help="set ID of model", type=int, default=0)
    # Without these two the run cannot succeed (None has no process key and
    # is no valid output path), so let argparse report them as missing.
    parser.add_argument("--type", required=True,
                        help="project name, e.g. simTHM[-subtype]")
    parser.add_argument("--out", required=True, help="resulting OGS prj")

    args = parser.parse_args()
    # Pass the arguments explicitly instead of relying on the attribute
    # order of the argparse namespace.
    csv2prj(args.param_table, args.set_table, args.template_prj,
            args.lin_solvers, args.time_stepping,
            mat_law_params=args.ml_params, set_id=args.id,
            project=args.type, output_prj=args.out)
