import argparse
import json
import os
import re
from collections import defaultdict

import config


# Matches names of iterated (retrained) algorithms such as
# "2021-03-04-hSL-T3-suffix": group 1 is the date-prefixed name up to and
# including "-T", the iteration number follows, and group 2 is the "-..."
# suffix. Replacing the number yields a name shared by all iterations.
REGEX_ITERATED_ALGORITHM = re.compile(r"^(\d\d\d\d-\d\d-\d\d-\w+-T)\d+(-.*)$")


# Properties (JSON keys) of an experiment run of Lab.
PROPERTY_DOMAIN = "domain"
PROPERTY_PROBLEM = "problem"
PROPERTY_ALGORITHM = "algorithm"
PROPERTY_COVERAGE = "coverage"
PROPERTY_TOTAL_TIME = "total_time"
PROPERTY_EXPANSIONS = "expansions"

# Short labels for selected properties.
ATTRIBUTE_ABBREVIATION = {
    PROPERTY_TOTAL_TIME: "T",
    PROPERTY_EXPANSIONS: "E",
}


# Functions used as `type=` callbacks for argparse.
def type_is_file(arg):
    """
    Accept only paths that point to an existing regular file.
    :param arg: command line argument (a path)
    :return: arg, unchanged
    :raises argparse.ArgumentTypeError: if arg is not an existing file
    """
    # argparse must be imported at module level; without it the failure path
    # raises NameError instead of a clean argparse usage error.
    if not os.path.isfile(arg):
        raise argparse.ArgumentTypeError(f"Not a file: {arg}")
    return arg


def natural_sort(l):
    """
    Sort strings in "natural" order: runs of ASCII digits compare
    numerically, everything else compares case-insensitively
    (e.g. "T2" sorts before "t10").
    :param l: iterable of strings
    :return: new sorted list
    """
    def alphanum_key(key):
        # re.split with a capturing group alternates non-digit and digit
        # parts, so odd positions are always runs of [0-9]. Converting by
        # position (rather than str.isdigit, which also accepts e.g. "²"
        # that int() rejects) keeps every key well-typed and comparable.
        return [int(part) if index % 2 else part.lower()
                for index, part in enumerate(re.split(r"([0-9]+)", key))]
    return sorted(l, key=alphanum_key)


# Preferred display order of well-known algorithms.
ALGORITHM_ORDER = ["hBoot", "hBExp", "hAVI", "hSL", "hHGN", "hFF", "LAMA"]


def my_algo_sort(l):
    """
    Sort algorithm names: names listed in ALGORITHM_ORDER come first, in that
    order; all other names follow in natural sort order.
    :param l: iterable of algorithm names
    :return: new sorted list containing every element of l
    """
    items = list(l)  # allow single-pass iterables
    rank = {name: index for index, name in enumerate(ALGORITHM_ORDER)}
    known = sorted((a for a in items if a in rank), key=rank.__getitem__)
    unknown = natural_sort(a for a in items if a not in rank)
    return known + unknown


def load_properties(files_properties, algorithm_filters=None, property_name=PROPERTY_TOTAL_TIME):
    """
    Load one property (by default the total time) of every solved task, plus
    the set of attempted tasks, from Lab properties files.
    :param files_properties: iterable of paths to JSON properties files, or
                             None
    :param algorithm_filters: iterable of regular expressions (compiled
                              patterns or pattern strings) used to filter
                              which algorithms to keep; an algorithm has to
                              match all filters (``match`` is anchored at the
                              start of the name)
    :param property_name: key of the value stored for each solved task
    :return: None if files_properties is None, otherwise
             {Planning Domain: {State Space Task: {Algorithm:
                 [{solved task id: property value}, set(attempted task ids)]}}}
    """
    if files_properties is None:
        return None
    # Materialise the filters once: a generator would be exhausted after the
    # first property and silently stop filtering. Strings are compiled here.
    filters = [re.compile(af) if isinstance(af, str) else af
               for af in (algorithm_filters or [])]
    print("Load properties...")
    # {Planning Domain: {State Space Task: {Algorithm: [{solved task id: value}, set(attempted task ids)]}}}
    data = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: [dict(), set()])))
    all_algorithms = set()

    for file_properties in files_properties:
        print("\t", file_properties)
        with open(file_properties, "r", encoding="utf-8") as f:
            properties = json.load(f)

        for props in properties.values():
            algorithm = props[PROPERTY_ALGORITHM]
            if not all(af.match(algorithm) for af in filters):
                continue
            all_algorithms.add(algorithm)

            planning_domain = os.path.dirname(props[PROPERTY_DOMAIN])
            state_space_task = config.rename_task(os.path.basename(props[PROPERTY_DOMAIN]))
            if not state_space_task.endswith(".pddl"):
                state_space_task += ".pddl"
            task_id = props[PROPERTY_PROBLEM]
            # Problems whose name contains "source.pddl" are skipped
            # (presumably the source task itself — TODO confirm).
            if "source.pddl" in task_id:
                continue
            solved, attempted = data[planning_domain][state_space_task][algorithm]
            assert task_id not in attempted, (
                f"Duplicate run of {task_id} for {algorithm} in "
                f"{planning_domain}/{state_space_task}")
            if props.get(PROPERTY_COVERAGE, 0):
                solved[task_id] = props[property_name]
            attempted.add(task_id)
    print("\n\tAlgorithms found:", ", ".join(natural_sort(all_algorithms)))
    print("Load properties...Done.")
    return data


def _merge_iteration_results(algorithm_results, iterations):
    """
    Merge the results of the retraining iterations of one algorithm.

    Iterations are visited from the last (highest iteration number) to the
    first; each task keeps the result of the latest iteration that attempted
    it, whether that iteration solved it or not.
    :param algorithm_results: {Algorithm: [{solved task id: time}, set(attempted task ids)]}
    :param iterations: names of the iterated algorithms (keys of algorithm_results)
    :return: [{solved task id: time}, set(attempted task ids)] of the merged algorithm
    """
    coverage, attempted = dict(), set()
    for name in natural_sort(iterations)[::-1]:
        curr_coverage, curr_attempted = algorithm_results[name]
        # Only fill in tasks that no later iteration has attempted.
        coverage.update({
            task_id: value
            for task_id, value in curr_coverage.items()
            if task_id not in attempted})
        attempted.update(curr_attempted)
    return [coverage, attempted]


def reduce_retraining_iterations(data):
    """
    Given the loaded property data, algorithms which are retraining from the
    same algorithm are reduced to a single one (e.g. something is trained until
    the validation performance is good enough, then retraining is stopped. In
    this case it is reduced by removing the performance of the retrained
    iterations). The iterations (names matching REGEX_ITERATED_ALGORITHM) are
    replaced by one algorithm whose iteration number is replaced by "X".
    data is modified in place and returned.
    :param data: {Planning Domain: {State Space Task: {Algorithm: [{solved task id: time}, set(attempted task ids)]}}}
    :return: {Planning Domain: {State Space Task: {Algorithm: [{solved task id: time}, set(attempted task ids)]}}}
             or None if data is None
    """
    if data is None:
        return None
    print("Reducing iterated algorithms...")
    all_algorithms = set()

    for task_algorithm_results in data.values():
        for algorithm_results in task_algorithm_results.values():
            # Group reiterations of the same algorithm under a shared name.
            groups = defaultdict(set)
            for algorithm in algorithm_results:
                match = REGEX_ITERATED_ALGORITHM.match(algorithm)
                if match:
                    groups[match.expand(r"\1X\2")].add(algorithm)

            # Merge every group before mutating algorithm_results.
            merged = {
                group_name: _merge_iteration_results(algorithm_results, iterations)
                for group_name, iterations in groups.items()}
            for iterations in groups.values():
                for algorithm in iterations:
                    del algorithm_results[algorithm]
            algorithm_results.update(merged)
            all_algorithms.update(algorithm_results)

    # Sorted for deterministic, readable output (as in load_properties).
    print("\tAlgorithms remaining:", ", ".join(natural_sort(all_algorithms)))
    print("Reducing iterated algorithms...Done.")
    return data
