from __future__ import print_function

from . import canonicalization

from .. import parser
from .. import parser_tools as parset
from .. parser_tools import ArgumentException, main_register

from ...translate import pddl, pddl_parser, translator
from ...translate.pddl import Atom, TypedObject, Predicate

import collections
import json
import multiprocessing
import os
import sys
import time

if sys.version_info < (3,):
    # Python 2: the print function (enabled by the __future__ import) does not
    # accept ``flush``. Wrap it so the ``print(..., flush=True)`` calls in this
    # module behave the same on Python 2 and 3.
    py_print = print

    def print(*args, **kwargs):
        flush = kwargs.pop("flush", False)
        py_print(*args, **kwargs)
        if flush:
            # As for the builtin, ``file=None`` means sys.stdout.
            stream = kwargs.get("file")
            (sys.stdout if stream is None else stream).flush()


def forward_translator(args):
    """
    Worker function used by the parallel problem loading.

    :param args: list ``[counter] + translator_args``. ``counter`` is a shared
                 list ``[verbosity, loaded_so_far, total]`` and the remaining
                 entries are passed unchanged to ``translator.main``.
    :return: whatever ``translator.main`` returns for these arguments
    """
    counter = args[0]
    result = translator.main(args[1:])
    # NOTE: this read-modify-write on a shared proxy list is not atomic, so the
    # progress count may be inaccurate with concurrent workers. It is only
    # used for the progress message below.
    counter[1] += 1
    # Same threshold as the sequential loading path: verbosity >= 5 reports
    # every loaded problem.
    if counter[0] > 4:
        print("Problem loaded: %i/%i" % (counter[1], counter[2]))
    return result


class ProblemIterationWrapper(object):
    """
    Uniform iteration over a non-empty collection of planning problems.

    The problems are either all ``(pddl_task, sas_task)`` tuples or all paths
    to problem files. Paths are parsed again on every iteration, which trades
    time for memory.

    During iteration, ``any_pddl``/``any_sas`` are set to the most recently
    yielded PDDL/SAS task. They are None before anything of that kind was
    yielded.
    """
    HAS_TUPLE = "tuples"
    HAS_STR = "str"

    def __init__(self, files_problem):
        """
        :param files_problem: non-empty sequence of only (pddl_task, sas_task)
                              tuples or only problem file paths
        """
        assert len(files_problem) > 0, "At least one problem has to be given."
        assert all(type(files_problem[0]) == type(file_problem)
                   for file_problem in files_problem), \
            ("All given problems have to be of the same kind "
             "(task objects or file paths)")
        assert (isinstance(files_problem[0], tuple)
                or isinstance(files_problem[0], str)),\
            ("All given problems have to be either only (pddl_task, sas_task) "
             "tuples or paths to problems")
        self.is_type = (ProblemIterationWrapper.HAS_TUPLE
                        if isinstance(files_problem[0], tuple) else
                        ProblemIterationWrapper.HAS_STR)
        self.files_problem = files_problem
        self.any_pddl = None
        self.any_sas = None

    def _load_problem(self, file_domain, file_problem, return_sas):
        """
        Parse one problem.

        :return: ``(pddl_task, sas_task)``. Without ``return_sas`` only the
                 PDDL task is parsed and ``sas_task`` is None. Otherwise the
                 translator produces both.
        """
        if not return_sas:
            return pddl_parser.open(domain_filename=file_domain,
                                    task_filename=file_problem), None
        translator_args = [file_domain, file_problem, "--no-sas-file",
                           "--log-verbosity", "ERROR"]
        pddl_task, sas_task = translator.main(translator_args)
        return pddl_task, sas_task

    def _find_domain(self, file_problem):
        """
        Locate the domain file in the directory of a problem file.

        ``domain.pddl`` is preferred. Otherwise the first ``*.pddl`` file (in
        alphabetical order) whose name contains "domain" is used.

        :raises IOError: if no domain file is found
        """
        dir_problem = os.path.dirname(file_problem)
        file_domain = os.path.join(dir_problem, "domain.pddl")
        if os.path.isfile(file_domain):
            return file_domain
        name_problem = os.path.basename(file_problem)
        # os.listdir("") fails, so use the current directory for bare names.
        for item in sorted(os.listdir(dir_problem or os.curdir)):
            path_item = os.path.join(dir_problem, item)
            if (item.endswith(".pddl") and "domain" in item
                    and item != name_problem and os.path.isfile(path_item)):
                return path_item
        # An explicit raise (not ``assert False``) so that the error is not
        # silently skipped when running with ``python -O``.
        raise IOError("Unable to detect the domain file for %s" % file_problem)

    def __len__(self):
        return len(self.files_problem)

    def iter(self, yield_pddl=True, yield_sas=True):
        """
        Iterate over the problems.

        :param yield_pddl: include the PDDL task in the yielded values
        :param yield_sas: include the SAS task in the yielded values
        :return: generator of PDDL tasks, SAS tasks, or ``(pddl, sas)`` tuples
                 if both flags are set
        """
        assert yield_pddl or yield_sas
        for problem in self.files_problem:
            if self.is_type == ProblemIterationWrapper.HAS_TUPLE:
                pddl_task, sas_task = problem
            elif self.is_type == ProblemIterationWrapper.HAS_STR:
                file_domain = self._find_domain(problem)
                pddl_task, sas_task = self._load_problem(
                    file_domain, problem, return_sas=yield_sas)
            else:
                assert False, "Internal Error: Invalid problem type"

            if yield_pddl:
                self.any_pddl = pddl_task
            if yield_sas:
                self.any_sas = sas_task

            if yield_pddl and not yield_sas:
                yield pddl_task
            elif not yield_pddl and yield_sas:
                yield sas_task
            else:
                yield pddl_task, sas_task


def domain_from_list_representation(list_domain):
    """
    Build a ``pddl.tasks.Domain`` from the nested-list form of a domain PDDL.

    The type and predicate dictionaries returned by the parser are not passed
    to the ``Domain`` constructor.

    :param list_domain: list representation of the domain PDDL file
    :return: the constructed ``pddl.tasks.Domain``
    """
    parsed = pddl_parser.parsing_functions.parse_domain_pddl(list_domain)
    (name, requirements, types, _type_dict, constants,
     predicates, _predicate_dict, functions, actions, axioms) = parsed
    return pddl.tasks.Domain(name, requirements, types, constants,
                             predicates, functions, actions, axioms)


class DomainProperties(object):
    """
    Manage the analysed properties of a domain.
    """
    def _set_problems(self, problems):
        if problems is None or len(problems) == 0:
            self.__problems = None
        else:
            self.__problems = ProblemIterationWrapper(problems)

    def _get_problems(self):
        return self.__problems

    problems = property(_get_problems, _set_problems)

    def __init__(self, domain, problems=None,
                 fixed_world=None, fixed_objects=None, instantiated_types=None,
                 gnd_static=None, gnd_flexible=None, no_gnd_static=False,
                 tpl_pred_static=None, pred_static=None, state_space_size=None,
                 upper_bound_reachable_state_space_size=None,
                 run_analysis=False, path_load=None, analysis_time=None, total_time=None,
                 analysed=False, domain_raw=None):

        """
        If no problems are given, the output is undefined.
        :param domain: Domain object (from translate) or list representation of
                       the domain PDDL (which is then translated to Domain object)
                       Attention: Domain object cannot be stored, list representation
                       will be stored.
        :param problems: List of tuple of PDDL and SAS Problem objects (from translate)
        :param real_predicates: List of predicate names without "=" predicate
        :param fixed_world: True if for all problems the same objects are used
        :param fixed_objects: Set of objects shared by all problems
        :param instantiated_types: object types instantiated by at least one problem
        :param gnd_static: Set of predicates static for ALL problems (in which
                           they exist)
        :param gnd_flexible: Set of predicates flexible in at least one problem
        :param no_gnd_static: Does not calculate gnd_static (and everything
                              depending)
        :param tpl_pred_static: Set of predicate templates which appear only
                                static and never flexible in the problems
        :param pred_static:     Set of predicates which are always static
                                (all its templates are static)
        :param run_analysis: Run analysis after construction
        :param path_load: Load DomainProperties from there (disables run_analysis,
                     will have no self.domain and self.problems)
        :param analysed: flag shows if the given domain and problems have been
                         analysed.
        :param domain_raw: The list representation of the domain file (if
                           'domain' is already given as list representation, then
                           the value is taken from that argument and 'domain_raw'
                           has no effect.
        """
        if isinstance(domain, pddl.tasks.Domain) or domain is None:
            self.domain = domain
            self.domain_raw = domain_raw
        else:
            self.domain_raw = domain
            self.domain = domain_from_list_representation(self.domain_raw)

        self.domain = domain
        self.problems = problems
        self.real_predicates = None if self.domain is None else [x for x in self.domain.predicates if x.name != "="]

        self.fixed_world = fixed_world
        self.fixed_objects = fixed_objects
        self.__fixed_objects_by_type = None
        self.instantiated_types = instantiated_types


        self.gnd_flexible = gnd_flexible
        self.gnd_static = gnd_static
        self.no_gnd_static = no_gnd_static

        self.tpl_pred_static = tpl_pred_static
        self.pred_static = pred_static

        self.state_space_size = state_space_size
        self.upper_bound_reachable_state_space_size = upper_bound_reachable_state_space_size
        self.analysed = analysed

        self.analysis_time = analysis_time
        self.total_time = total_time

        self.path_load = path_load
        if self.path_load is not None:
            self.load(self.path_load)

        if run_analysis:
            self.analyse()

    def __get_fixed_objects_by_type(self):
        assert self.fixed_world
        if self.__fixed_objects_by_type is None:
            self.__fixed_objects_by_type = collections.defaultdict(list)
            for obj in self.fixed_objects:
                type_name = obj.type_name
                while type_name is not None:
                    self.__fixed_objects_by_type[type_name].append(obj)
                    type_name = self.domain.type_hierarchy[type_name]
        return self.__fixed_objects_by_type
    fixed_objects_by_type = property(__get_fixed_objects_by_type)



    def _analyse_fixed_world(self):
        if self.problems is None:
            self.fixed_world = None
        else:
            self.fixed_world = True
            self.fixed_objects = None
            for pddl in self.problems.iter(yield_sas=False):
                new_objects = set(pddl.objects)
                if self.fixed_objects is None:
                    self.fixed_objects = new_objects
                else:
                    if new_objects != self.fixed_objects:
                        self.fixed_world = False
                        self.fixed_objects &= new_objects


    def _analyse_instantiated_types(self):
        self.instantiated_types = set()
        for pddl in self.problems.iter(yield_sas=False):
            for obj in pddl.objects:
                self.instantiated_types.add(obj.type_name)

    def _analyse_static_flexible_grounded_predicates_fixed_world(self):
        # Get flexible because changeable
        some_problem_flexible = set()
        for (pddl_task, sas_task) in self.problems.iter():
            problem_flexible = set()
            for var_names in sas_task.variables.value_names:
                for var_name in var_names:
                    if var_name.startswith("NegatedAtom") or var_name == "<none of those>":
                        continue
                    var_atom = Atom.from_string(var_name)
                    problem_flexible.add(var_atom)
            some_problem_flexible.update(problem_flexible)



        # Get flexible because different initial states
        init_inter = None
        init_union = set()
        for pddl_task in self.problems.iter(yield_sas=False):
            new_init = set([x for x in pddl_task.init
                            if isinstance(x, Atom) and not x.predicate == "="])
            if init_inter is None:
                init_inter = new_init
            else:
                init_inter &= new_init
            init_union |= new_init
        init_flexible = init_union - init_inter

        self.gnd_flexible = some_problem_flexible | init_flexible
        if not self.no_gnd_static:
            self.gnd_static = (
                    set(self.problems.any_pddl.get_grounded_predicates()) -
                               self.gnd_flexible)

    def _analyse_static_flexible_grounded_predicates(self):
        if self.fixed_world:
            self._analyse_static_flexible_grounded_predicates_fixed_world()
        else:
            raise NotImplementedError("Analyse static flexible predicates not implemented for not fixed worlds. (e.g. world is variable OR no problems for determining were provided")

    def _analyse_static_predicates_and_templates(self):
        """
        1. Analyse which predicate templates are always static
        2. Analyse which predicate has only static templates = predicate only
           used in a static way.
        :return:
        """
        if self.no_gnd_static:
            return
        # Find predicate templates which COULD be always static and to ignore
        predicate_types = {}
        static_predicate_templates = set()
        for gp in self.gnd_static:
            if gp.predicate not in predicate_types:
                args = self.domain.predicate_dict[gp.predicate].arguments
                predicate_types[gp.predicate] = [x.type_name for x in args]
            canonized = canonicalization.canonize_object_lists(
                [gp],
                input_format=canonicalization.Format.ATOM,
                output_format=canonicalization.Format.ATOM,
                types=[predicate_types[gp.predicate]])
            static_predicate_templates.add(canonized[0])

        # Remove predicate templates which are not static
        flexible_predicates = set()
        for gp in self.gnd_flexible:
            if gp.predicate not in predicate_types:
                continue #  predicate not previously encountered, no need to rmv
            flexible_predicates.add(gp.predicate)
            canonized = canonicalization.canonize_object_lists(
                [gp],
                input_format=canonicalization.Format.ATOM,
                output_format=canonicalization.Format.ATOM,
                types=[predicate_types[gp.predicate]])
            static_predicate_templates.discard(canonized[0])
        self.tpl_pred_static = static_predicate_templates
        self.pred_static = set(predicate_types.keys()) - flexible_predicates


    def _analyse_combined_state_space_sizes(self):
        if self.fixed_world:
            self.combined_state_space_size = 2 ** len(self.gnd_flexible)
            self.combined_reachable_state_space_upper_bound = (
                self.problems.any_sas.variables.get_state_space_size())
            diff_goals = set()
            for pddl_task in self.problems.iter(yield_sas=False):
                diff_goals.add(pddl_task.goal)

            self.combined_reachable_state_space_upper_bound *= len(diff_goals)
            self.combined_state_space_size *= len(diff_goals)



        else:
            raise NotImplementedError("Calc combined state space size for not "
                                      "fixed world")
            # TODO Implement
            pass

    def analyse(self):
        start_time = time.time()
        self._analyse_fixed_world()
        if self.problems is not None:
            self._analyse_instantiated_types()
            self._analyse_static_flexible_grounded_predicates()
            self._analyse_static_predicates_and_templates()
            self._analyse_combined_state_space_sizes()
        self.analysis_time = time.time() - start_time
        self.analysed = True

    def get_flexible_state_size(self):
        assert self.fixed_world, "Not sure if everything is valid otherwise"
        assert self.analysed, "Properties needs to be analysed to get this information"
        return len(self.gnd_flexible)

    def get_flexible_state_atoms(self):
        return set(self.gnd_flexible)

    def get_full_state_size(self):
        assert self.fixed_world, "Not sure if everything is valid otherwise"
        if self.gnd_static is not None:
            return self.get_flexible_state_size() + len(self.gnd_static)
        else:
            typed_object_counts = {}
            for to in self.fixed_objects:
                if to.type_name not in typed_object_counts:
                    typed_object_counts[to.type_name] = 0
                typed_object_counts[to.type_name] += 1
            for type_name, count in dict(typed_object_counts).items():
                while True:
                    assert type_name in self.domain.type_hierarchy
                    type_name = self.domain.type_hierarchy[type_name]
                    if type_name is None:
                        break
                    if type_name not in typed_object_counts:
                        typed_object_counts[type_name] = 0
                    typed_object_counts[type_name] += count


            size = 0
            for predicate in self.real_predicates:
                pcount = 1
                for to in predicate.arguments:
                    pcount *= typed_object_counts.get(to.type_name, 0)
                size += pcount
            return size

        assert False

    def get_full_state_atoms(self):
        assert self.fixed_world, "Not sure if everything is valid otherwise"
        if self.gnd_static is not None:
            return self.gnd_flexible + self.gnd_static
        else:
            assert False, "Not Implemented"

    def store(self, path):
        def lst_str_ifnn(x):
            return None if x is None else [str(y) for y in x]
        d = {}
        d["domain_raw"] = self.domain_raw
        d["real_predicates"] = lst_str_ifnn(self.real_predicates)
        d["fixed_world"] = self.fixed_world
        d["fixed_objects"] = lst_str_ifnn(self.fixed_objects)
        d["instantiated_types"] = [x for x in self.instantiated_types]
        d["gnd_flexible"] = lst_str_ifnn(self.gnd_flexible)
        d["gnd_static"] = lst_str_ifnn(self.gnd_static)
        d["no_gnd_static"] = self.no_gnd_static
        d["tpl_pred_static"] = lst_str_ifnn(self.tpl_pred_static)
        d["pred_static"] = lst_str_ifnn(self.pred_static)
        d["state_space_size"] = self.state_space_size
        d["upper_bound_reachable_state_space_size"] = self.upper_bound_reachable_state_space_size
        d["analysed"] = self.analysed
        d["analysis_time"] = self.analysis_time
        d["total_time"] = self.total_time

        with open(path, "w") as f:
            json.dump(d, f)

    def load(self, path):
        def from_strings(x, clazz, as_set=False):
            if x is None:
                return None
            atoms = [clazz.from_string(y) for y in x]
            return set(atoms) if as_set else atoms

        with open(path, "r") as f:
            d = json.load(f)

        self.domain_raw = d["domain_raw"]
        if self.domain_raw is not None:
            self.domain = domain_from_list_representation(self.domain_raw)
        self.real_predicates = from_strings(d["real_predicates"], Predicate)

        self.fixed_world = d["fixed_world"]
        self.fixed_objects = from_strings(d["fixed_objects"], TypedObject, as_set=True)
        self.instantiated_types = set(d["instantiated_types"])

        self.gnd_flexible = from_strings(d["gnd_flexible"], Atom, as_set=True)
        self.gnd_static = from_strings(d["gnd_static"], Atom, as_set=True)
        self.no_gnd_static = d["no_gnd_static"]

        self.tpl_pred_static = from_strings(d["tpl_pred_static"], Atom,
                                            as_set=True)
        self.pred_static = None if d["pred_static"] is None else set(d["pred_static"])

        self.state_space_size = d["state_space_size"]
        self.upper_bound_reachable_state_space_size = d["upper_bound_reachable_state_space_size"]
        self.analysed = d["analysed"]

        self.analysis_time = d["analysis_time"]
        self.total_time = d["total_time"]

    @staticmethod
    def sload(path):
        dp = DomainProperties(None)
        dp.load(path)
        return dp

    @staticmethod
    def get_property_for(*paths, **kwargs):
        """
        Analyses the problems in the given paths and returns a DomainProperty
        containing the analysis. Every given path is interpreter as a directory
        in which problem files are searched for. ALL directories are expected to
        contain problems of the SAME domain.
        :param paths: Sequence of directory paths in which all problem files are
                      analysed
        :param paths_problems: Iterable of paths to problem files to analyse
        :param path_domain: Path to the domain file. If not given, then a domain file
                            is searched in the given paths
        :param preload_tasks: True => Load task objects ONCE prior to creating
                              the DomainProperties object. False => pass problem
                              paths to constructor. Every time a tasks object is
                              needed, it is parsed again (Time VS Memory)
        :param no_gnd_static does not provide the data for the gnd static predicates
        :param verbose: Verbosity level (0 No outputs, 1 timings, 5 current problem loading)
        :param parallize: Parallizes the problem loading
        :param load: Loads DomainProperty from there (does not load the problems
                     anymore nor analyse again the properties)
        :param store: Stores DomainProperty there
        :param store_atoms: if not None and the DomainProperty is a fixed
                            universe, then the atoms of of the Fixed Universe
                            are stored in it as json (see atoms.json for example)
        :return: DomainProperty containing analysis results
        """
        start_time = time.time()

        paths_problems = kwargs.pop("paths_problems", None)
        path_domain = kwargs.pop("path_domain", None)
        preload_tasks = kwargs.pop("preload_tasks", False)
        no_gnd_static = kwargs.pop("no_gnd_static", False)

        load = kwargs.pop("load", None)
        store = kwargs.pop("store", None)

        verbose = kwargs.pop("verbose", 0)
        parallize = kwargs.pop("parallize", 0)

        store_atoms = kwargs.pop("store_atoms", False)

        dp = None
        # If load => Load previous instance
        if load is not None:
            load_time = time.time()
            if verbose > 0:
                print("Loading DomainProperties...", end="", flush=True)
            dp = DomainProperties.sload(load)
            load_time = time.time() - load_time
            if verbose > 0:
                print("Done (%.2f)." % load_time)

        # Analyse a new instance
        else:
            if len(paths) == 0 and path_domain is None:
                raise ValueError("No domain file given and no directories to look "
                                 "for it.")

            # Find domain file
            if path_domain is None:
                for dir in paths:
                    tmp = os.path.join(dir, "domain.pddl")
                    if os.path.isfile(tmp):
                        path_domain = tmp
                        break
            if path_domain is None:
                for dir in paths:
                    for item in os.listdir(dir):
                        path_item = os.path.join(dir, item)
                        if (item.endswith(".pddl")
                                and item.find("domain") != -1)\
                                and os.path.isfile(path_item):
                            path_domain = path_item
                            break
                    if path_domain is not None:
                        break

            # Detect problem files to analyse
            # TODO improve problem file checks
            paths_problems = [] if paths_problems is None else [p for p in paths_problems]
            for dir in paths:
                for item in os.listdir(dir):
                    path_item = os.path.join(dir, item)
                    if (path_item.endswith(".pddl")
                        and path_item.find("domain") == -1
                        and os.path.isfile(path_item)):
                        paths_problems.append(path_item)


            translator_args = [path_domain, None, "--no-sas-file",
                               "--log-verbosity", "ERROR"]

            domain_time = time.time()
            if verbose > 0:
                print("Loading Domain...", end="", flush=True)
            domain_raw = pddl_parser.pddl_file.parse_pddl_file("domain", path_domain)
            domain = domain_from_list_representation(domain_raw)
            domain_time = time.time() - domain_time
            if verbose > 0:
                print("Done (%.2fs)." % domain_time)


            problems_time =  time.time()
            if verbose > 0:
                print("Translating Problems...", end="", flush=True)
            if len(paths_problems) == 0:
                problems = None
            elif not preload_tasks:
                problems = paths_problems
            elif parallize:
                m = multiprocessing.Manager()
                mp_counter = m.list([verbose, 0, len(paths_problems)])
                with multiprocessing.Pool(processes=None) as pool:
                    arguments = [[mp_counter, translator_args[0]] + [p] + translator_args[2:] for p in paths_problems]
                    problems = pool.map(forward_translator, arguments)
            else:
                problems = []
                for no, path_problem in enumerate(paths_problems):
                    if verbose > 4:
                        print("\t%i/%i\t%s" % (no + 1, len(paths_problems), path_problem))
                    translator_args[1] = path_problem
                    problems.append(translator.main(translator_args))

            problems_time = time.time() - problems_time
            if verbose > 0:
                print("Done (%.2f)." % problems_time)


            properties_time = time.time()
            if verbose > 0:
                print("Analysing Domain Properties...", end="", flush=True)
            dp = DomainProperties(domain, problems, no_gnd_static=no_gnd_static,
                                  run_analysis=True, domain_raw=domain_raw)
            properties_time = time.time() - properties_time
            if verbose > 0:
                print("Done (%.2f)." % properties_time)
            dp.total_time = time.time() - start_time

        # Store
        if store is not None:
            store_time = time.time()
            if verbose > 0:
                print("Storing DomainProperties...", end="", flush=True)
            dp.store(store)
            store_time = time.time() - store_time
            if verbose > 0:
                print("Done (%.2f)." % store_time)

        if store_atoms:
            assert dp.fixed_world
            store_time = time.time()
            if verbose > 0:
                print("Storing Atoms...", end="", flush=True)
            atoms = {}
            atoms["PDDL_ATOMS_FLEXIBLE"] = [str(x)
                                            for x in sorted(dp.gnd_flexible)]
            if dp.problems is not None and len(dp.problems) > 0:
                atoms["PDDL_ATOMS"] = dp.problems.any_pddl.str_grounded_predicates(
                    sort=True)
            elif dp.gnd_static is not None:
                atoms["PDDL_ATOMS"] = sorted(
                    str(x) for x in dp.gnd_static | dp.gnd_flexible)
            else:
                print("Cannot store 'PDDL_ATOMS'.")
            with open(store_atoms, "w") as f:
                json.dump(atoms, f, sort_keys=True)
            store_time = time.time() - store_time
            if verbose > 0:
                print("Done (%.2f)." % store_time)

        return dp



class DomainPropertyParserWrapper(object):
    """
    Parser-facing wrapper that analyses a domain file and an optional list of
    problem files, and exposes the result as ``self.dp``.
    """
    # Argument specification for the parser. The field semantics are defined
    # by parset.ClassArguments (presumably name, optional flag, default,
    # type, description — confirm there).
    argument = parset.ClassArguments(
        "DomainProperty", None,
        ("domain", False, None, str, "Path to the domain file"),
        ("problems", True, None, str, "List of paths to problems to analyse"),
        ("id", True, None, str, "ID"))

    def __init__(self, domain, problems=None, id=None):
        # Normalise problems to a list: None -> [], single item -> [item].
        if problems is None:
            problems = []
        elif not isinstance(problems, list):
            problems = [problems]

        self.domain = domain
        self.problems = problems
        self.id = id

        self.dp = DomainProperties.get_property_for(
            path_domain=self.domain, paths_problems=self.problems)

    @staticmethod
    def parse(tree, item_cache):
        """Delegate parsing of this object to the generic parser."""
        return parser.try_whole_obj_parse_process(
            tree, item_cache, DomainPropertyParserWrapper)


main_register.append_register(DomainPropertyParserWrapper, "dp", "domain_property")
