#!/usr/bin/env python
"""
This script can either be executed directly or be invoked from within
run_compare_experiments.py.

This script checks which setting of the parameters P performs best while
keeping all parameters outside of P unchanged.

./compare_parameter_changes.py -csv /home/ferber/TMP/eval/comparison_coverage.csv -f (reg|cls)_(ns|full)_(bal|ubal)_(h5)_(sigmoid)_(init|inter)_(gen_)?(opt_|sat_)?(drp0) --compare-0-1
"""
import argparse
from collections import defaultdict
import itertools
import numpy as np
import os
import re
import sys


def fac_finder(sub):
    """Return a predicate telling whether a configuration name contains *sub*.

    :param sub: Substring to look for.
    :return: Callable ``f(name) -> bool`` that is True iff *sub* occurs
             anywhere in *name*.
    """
    return lambda x: sub in x


# Each PARAMETER_* entry is a tuple
#   ({parameter value: predicate(configuration name) -> bool}, default value)
# A configuration gets the value whose predicate matches its name; if no
# predicate matches, it gets the default value. With a default of None every
# configuration must match exactly one of the listed values.
# If adding a new parameter, also update the regex in compare_experiments.py
PARAMETER_NETWORK_TYPE = ({"classification": fac_finder("cls_"),
                           "regression": fac_finder("reg_")}, None)
PARAMETER_STATE_FORMAT = ({"full": fac_finder("_full_"),
                           "non_static": fac_finder("_ns_")}, None)
PARAMETER_BALANCE = ({"balanced": fac_finder("_bal_"),
                      "unbalanced": fac_finder("_ubal_")}, None)
PARAMETER_HIDDEN = ({"h5": fac_finder("_h5_")}, None)
PARAMETER_ACTIVATION = ({"sigmoid": fac_finder("_sigmoid_"),
                         "relu": fac_finder("_relu_")}, None)
PARAMETER_DROPOUT = ({"0": fac_finder("_drp0"),
                      "25": fac_finder("_drp25")}, None)
PARAMETER_L2REGULARIZATION = ({"L2_1e-2": fac_finder("_rlt1e-2")}, "L2_0")
PARAMETER_SAMPLE_STATE = ({"initial state": fac_finder("_init_"),
                           "intermediate state": fac_finder("_inter_"),
                           "full plan": fac_finder("_plan_")}, None)
PARAMETER_SAMPLE_SOLVING = ({"optimal": fac_finder("_opt_"),
                             "satisficing": fac_finder("_sat_")}, "optimal")
PARAMETER_SAMPLE_CREATION = ({"generator": fac_finder("_gen_")}, "uniform")

PARAMETER_DATASET_RATIO = ({"all_available": fac_finder("Kall"),
                            "75%": fac_finder("K75"),
                            "50%": fac_finder("K50")}, "not applicable")
PARAMETER_TEST_VALID_SPLIT = ({"Random": fac_finder("_MK"),
                               "ByProblems": fac_finder("_K")},
                              "not applicable")
PARAMETER_PRUNING = ({"off": fac_finder("pruneOff")}, "full")

# Parameter name (as used on the command line) -> parameter definition.
# The insertion order determines the order of DEFAULT_CHOSEN_PARAMETERS.
PARAMETERS = {
    "network type": PARAMETER_NETWORK_TYPE,
    "state format": PARAMETER_STATE_FORMAT,
    "balance": PARAMETER_BALANCE,
    "hidden layers": PARAMETER_HIDDEN,
    "activations": PARAMETER_ACTIVATION,
    "sampled state": PARAMETER_SAMPLE_STATE,
    "solving approach": PARAMETER_SAMPLE_SOLVING,
    "sample creation": PARAMETER_SAMPLE_CREATION,
    "dropout": PARAMETER_DROPOUT,
    "l2 regularizer": PARAMETER_L2REGULARIZATION,
    "fraction of samples (calculated of minimum applicable data set size)":
        PARAMETER_DATASET_RATIO,
    "Train-Validation-Test-Split": PARAMETER_TEST_VALID_SPLIT,
    "pruning": PARAMETER_PRUNING,
}
DEFAULT_CHOSEN_PARAMETERS = list(PARAMETERS)

# Command line interface; parsed by run_from_cmd().
parser = argparse.ArgumentParser()
parser.add_argument("-csv", type=str, action="append", default=[],
                    help="Semicolon-separated CSV file containing an NxN "
                         "matrix of pairwise comparisons between N "
                         "configurations (first row and first column hold "
                         "the configuration names). Positive values mean "
                         "that the configuration of the row is better than "
                         "the configuration of the column. Can be given "
                         "multiple times.")
parser.add_argument("-p", "--parameter", choices=list(PARAMETERS),
                    action="append", default=None,
                    help="Selection of parameters which shall be checked. Can "
                         "be given multiple times. By default all parameters "
                         "are checked.")
parser.add_argument("-f", "--filter", type=str, action="append", default=[],
                    help="Python regex. Only configurations matching all given "
                         "regexes are considered (each regex is matched from "
                         "the start of the configuration name). Can be given "
                         "multiple times.")
parser.add_argument("--compare-0-1", action="store_true",
                    help="If not set, the scores are calculated by summing the "
                         "pairwise comparison scores from the CSV (e.g. use if "
                         "those values are already normalized). If set, each "
                         "pairwise comparison provides +1 for the better "
                         "configuration.")
parser.add_argument("--invert-comparison", action="store_true",
                    help="Multiplies the pairwise comparison values by -1. Use "
                         "if a negative value represents a better setting.")


def get_associated_output_path(csv_file, compare_01=False):
    """Return the path of the report file written for *csv_file*.

    The extension of *csv_file* is replaced by ``_parameters.txt`` (or by
    ``_parameters_01.txt`` if *compare_01* is set), so the report is placed
    next to the CSV file.

    :param csv_file: Path of the analysed CSV file.
    :param compare_01: Whether the report was produced with 0/1 comparisons.
    :return: Output path as string.
    """
    base = os.path.splitext(csv_file)[0]
    return base + "_parameters" + ("_01" if compare_01 else "") + ".txt"


def get_parameter_assignment(configurations, parameters,
                             parameter_definitions=None):
    """Determine which value every configuration has for each parameter.

    A configuration's value for a parameter is the single value whose checker
    accepts the configuration name; if no checker accepts it, the parameter's
    default value is used.

    :param configurations: Iterable of configuration names (iterated once per
                           parameter).
    :param parameters: Iterable of parameter names (keys of
                       *parameter_definitions*).
    :param parameter_definitions: Mapping
                                  ``{parameter: ({value: checker}, default)}``
                                  where default may be None. Defaults to the
                                  module-level ``PARAMETERS`` table.
    :return: Tuple ``(param2conf, conf2param)`` with
             ``param2conf = {parameter: {value: set(configurations)}}`` and
             ``conf2param = {configuration: {parameter: value}}``.
    :raises ValueError: If a configuration matches more than one value of a
                        parameter, or matches none while the parameter has no
                        default value.
    """
    if parameter_definitions is None:
        parameter_definitions = PARAMETERS

    param2conf = {}  # {parameter: {parameter value: set(configurations)}}
    conf2param = {}  # {configuration: {parameter: value}}
    for parameter in parameters:
        pvalues, pdefault = parameter_definitions[parameter]
        passignment = defaultdict(set)
        for configuration in configurations:
            matched = [pvalue for pvalue, checker in pvalues.items()
                       if checker(configuration)]
            if len(matched) > 1:
                raise ValueError(
                    "One configuration has two values (%s, %s) for a "
                    "single parameter (%s): %s"
                    % (matched[0], matched[1], parameter, configuration))
            cvalue = matched[0] if matched else pdefault
            if cvalue is None:
                # Raise instead of assert: an assert vanishes under -O and the
                # configuration would silently lack this parameter.
                raise ValueError(
                    "Configuration %s has no value for parameter %s and the "
                    "parameter has no default" % (configuration, parameter))
            passignment[cvalue].add(configuration)
            conf2param.setdefault(configuration, {})[parameter] = cvalue
        param2conf[parameter] = passignment

    return param2conf, conf2param


def _matches_filters(configuration, filters):
    """Tell whether *configuration* is accepted by every regex in *filters*.

    Each compiled pattern is applied with ``match``, i.e. anchored at the start
    of the configuration name. An empty *filters* sequence accepts everything.
    """
    return all(pattern.match(configuration) for pattern in filters)


def get_generator_all_pairs(l):
    """Lazily yield every unordered pair ``(l[i], l[j])`` with ``i < j``.

    Pairs come in lexicographic index order: (0, 1), (0, 2), ..., (1, 2), ...

    :param l: Sequence (or any finite iterable) of items.
    :return: Generator of 2-tuples.
    """
    for pair in itertools.combinations(l, 2):
        yield pair


def compare_setting(setting, data, configurations, conf2idx, conf2param,
                    compare_0_1, invert_comparison,
                    parameter_definitions=None):
    """Score every value combination of the parameters in *setting*.

    Configurations are grouped so that within a group all parameters listed in
    *conf2param* that are NOT part of *setting* have the same value. Within
    each group every pair of configurations is compared via *data*, and the
    value combination (of the *setting* parameters) of the winning
    configuration is credited with the absolute comparison value.

    :param setting: Iterable of parameter names whose value combinations are
                    compared against each other.
    :param data: 2-D array indexable as ``data[i, j]``; a positive entry means
                 configuration i is better than configuration j. Entries that
                 ``float`` cannot convert (e.g. 'NA', '-') are skipped.
    :param configurations: List of configuration names sorted like in data array
    :param conf2idx: {Configuration Name: index of configuration names row in data}
    :param conf2param: {Configuration Name: { Parameter Name: Parameter Value}}
    :param compare_0_1: If True, a comparison counts +1 for the winner instead
                        of its magnitude.
    :param invert_comparison: If True, comparison values are negated first.
    :param parameter_definitions: Parameter table in the format of
                                  ``PARAMETERS`` (the default).
    :return: Tuple ``(N, E, S, Z, cp_scores)``: number of counted comparisons,
             number of skipped entries, sum of absolute comparison values,
             number of zero-valued comparisons, and
             ``{value combination tuple: score}`` whose tuples are ordered
             like ``sorted(setting)``.
    """
    if parameter_definitions is None:
        parameter_definitions = PARAMETERS
    setting = sorted(setting)

    # Every value combination (listed values plus the default, if any) starts
    # with score 0 so that combinations without any win are still reported.
    value_lists = []
    for param in setting:
        pvalues, pdefault = parameter_definitions[param]
        value_lists.append(list(pvalues.keys()) +
                           ([] if pdefault is None else [pdefault]))
    cp_scores = {cp: 0 for cp in itertools.product(*value_lists)}

    # Configuration name -> value combination it belongs to
    conf2cp = {}
    for configuration in conf2idx:
        key = tuple(conf2param[configuration][parameter]
                    for parameter in setting)
        assert key in cp_scores, \
            "Unexpected value combination %s for %s" % (key, configuration)
        conf2cp[configuration] = key

    # Group configurations whose parameters agree EXCEPT for those currently
    # under consideration.
    same_other_params = defaultdict(list)
    for configuration in configurations:
        key = tuple(sorted(
            (param, pvalue)
            for param, pvalue in conf2param[configuration].items()
            if param not in setting
        ))
        same_other_params[key].append(configuration)

    N = 0  # Number of counted pairwise comparisons
    E = 0  # Number of skipped comparisons (non-numeric entries, e.g. 'NA')
    S = 0  # Sum of absolute comparison values
    Z = 0  # Number of comparisons with value zero

    # Within each group, compare every pair and credit the winner's value
    # combination.
    for configs in same_other_params.values():
        for config1, config2 in itertools.combinations(configs, 2):
            cp_key1, cp_key2 = conf2cp[config1], conf2cp[config2]
            # NOTE(review): cp_key1 may equal cp_key2 when two configurations
            # agree on every parameter tracked in conf2param; such a pair then
            # credits the same combination. Confirm whether it should be
            # skipped.
            raw_value = data[conf2idx[config1], conf2idx[config2]]
            try:
                comparison_value = float(raw_value)
            except ValueError:
                # E.g. 'NA', '-' entries
                E += 1
                continue
            if invert_comparison:
                comparison_value = -1 * comparison_value
            if compare_0_1:
                comparison_value = (1 if comparison_value > 0 else
                                    (-1 if comparison_value < 0 else 0))

            if comparison_value > 0:
                cp_scores[cp_key1] += comparison_value
            elif comparison_value < 0:
                cp_scores[cp_key2] -= comparison_value
            else:
                Z += 1
            N += 1
            S += abs(comparison_value)
    return N, E, S, Z, cp_scores


def compare_array(ary, settings, filters, compare_0_1, invert_comparison):
    """Build the comparison report for one pairwise-comparison matrix.

    :param ary: 2-D object array as read from the CSV file. ``ary[0, 1:]``
                and ``ary[1:, 0]`` hold the configuration names (both must be
                equal) and ``ary[1:, 1:]`` holds the pairwise comparison
                values.
    :param settings: Sequence of parameter-name collections; the parameters
                     of each collection are evaluated jointly.
    :param filters: Compiled regexes; only configurations matched (via
                    ``match``) by all of them are kept.
    :param compare_0_1: Count every comparison as +1 for the winner instead of
                        using its magnitude.
    :param invert_comparison: Negate all comparison values first.
    :return: Report text with one section per setting.
    """
    assert np.array_equal(ary[1:, 0], ary[0, 1:])
    configurations = ary[1:, 0]
    data = ary[1:, 1:]

    # An explicit bool dtype keeps the mask valid for indexing even when no
    # configuration is present (np.array([]) would default to float64).
    mask = np.array([_matches_filters(c, filters) for c in configurations],
                    dtype=bool)
    configurations = configurations[mask]
    data = data[mask, :][:, mask]
    conf2idx = {conf: idx for idx, conf in enumerate(configurations)}
    all_parameters = set(param for setting in settings for param in setting)
    _, conf2param = get_parameter_assignment(configurations, all_parameters)

    print("Detected configurations:\n\t%s" % "\n\t".join(sorted(configurations)))

    sections = []
    for setting in settings:
        N, E, S, Z, scores = compare_setting(setting, data, configurations,
                                             conf2idx, conf2param,
                                             compare_0_1, invert_comparison)
        sections.append("X".join(setting))
        sections.append("\nSum: %.2f\tCounted: %i\tZeros: %i\tSkipped: %i\n"
                        % (S, N, Z, E))
        sections.append("\t".join("%s(%.2f)" % (label, score)
                                  for label, score in scores.items()))
        sections.append("\n\n\n")
    return "".join(sections)


def compare_parameters(csv_files, parameters=None, filters=None,
                       compare_0_1=False, invert_comparison=False):
    """Evaluate parameter settings for each CSV file and write a report.

    For every CSV file the report is written to the path returned by
    ``get_associated_output_path`` (next to the CSV file).

    :param csv_files: CSV file path or list of CSV file paths to analyse.
    :param parameters: List of parameter settings. Each entry is either a
                       single parameter name or a collection of parameter
                       names whose value combinations (cross product of
                       their values) are compared jointly. Defaults to
                       ``DEFAULT_CHOSEN_PARAMETERS``, each parameter on its
                       own.
    :param filters: List of python regexes (string or compiled). A
                    configuration of the CSV file has to match (``re.match``)
                    all regexes.
    :param compare_0_1: If True, every pairwise comparison counts +1 for the
                        better configuration instead of its magnitude.
    :param invert_comparison: If True, all comparison values are negated.
    :return: None
    :raises ValueError: If no CSV file is given, a CSV file does not exist, or
                        an unknown parameter name is requested.
    """
    csv_files = [csv_files] if isinstance(csv_files, str) else csv_files
    if len(csv_files) == 0:
        raise ValueError("At least one csv file to compare should be given")
    missing = [x for x in csv_files if not os.path.isfile(x)]
    if missing:
        raise ValueError("CSV file(s) not found: %s" % ", ".join(missing))

    if parameters is None:
        parameters = DEFAULT_CHOSEN_PARAMETERS
    parameter_settings = [(para,) if isinstance(para, str) else para
                          for para in parameters]
    unknown = sorted(set(x for setting in parameter_settings
                         for x in setting if x not in PARAMETERS))
    if unknown:
        raise ValueError("Unknown parameter(s): %s" % ", ".join(unknown))

    # re.compile() returns an already compiled pattern unchanged, so strings
    # and compiled patterns can be mixed on every Python version (the private
    # re._pattern_type alias no longer exists in Python >= 3.7).
    filters = [re.compile(f) for f in (filters or [])]

    for csv_file in csv_files:
        path_out = get_associated_output_path(csv_file, compare_0_1)
        ary = np.loadtxt(csv_file, delimiter=";", dtype=object)
        comparisons = compare_array(ary, parameter_settings, filters,
                                    compare_0_1, invert_comparison)
        with open(path_out, "w") as f:
            f.write(comparisons)


def run_from_cmd(args):
    """Parse the command line *args* and run compare_parameters with them.

    :param args: Command line arguments without the program name.
    :raises ValueError: If no ``-csv`` file was given.
    """
    options = parser.parse_args(args)
    if not options.csv:
        raise ValueError("At least one csv file to compare should be given")

    if options.parameter is None:
        parameters = DEFAULT_CHOSEN_PARAMETERS
    else:
        parameters = options.parameter

    options.filter = [re.compile(pattern) for pattern in options.filter]
    compare_parameters(options.csv, parameters, options.filter,
                       options.compare_0_1, options.invert_comparison)


if __name__ == "__main__":
    run_from_cmd(sys.argv[1:])
