from . import FieldParser, FIELD_PARSERS

from .base_field_parser import get_argument_for_parser

from .....translate import pddl

import numpy as np

# Argument keys used by StateFieldParser.
# Key under which the state's current representation format is looked up
# (via get_argument_for_parser). After a conversion, the field description
# is updated under this key with the name of the resulting format.
ARG_CURRENT_FORMAT = "format"
# Key of the format the state shall be converted into. If the lookup yields
# None, the state is kept in its current format.
ARG_NEW_FORMAT = "new_format"


# Definition of the different representation formats for a state
class StateFormat(object):
    """
    A named representation format for states, acting as a small enum.

    Every instance registers itself on construction: its main name and all
    synonyms become class attributes (e.g. ``StateFormat.Full``) and lookup
    keys for :meth:`get`.
    """
    main_name2obj = {}  # Single main name per StateFormat
    all_name2obj = {}  # names (including synonyms) referring to StateFormats

    def __init__(self, name, description, synonymies=None):
        """
        :param name: main name of the format (also a class attribute name)
        :param description: human readable description of the format
        :param synonymies: None, a single alternative name, or a list/tuple
                           of alternative names
        """
        self.name = name
        self.description = description
        if synonymies is None:
            self.synonymies = []
        elif isinstance(synonymies, (list, tuple)):
            # Copy so later changes to the caller's list cannot desync the
            # registry; tuples are accepted as well.
            self.synonymies = list(synonymies)
        else:
            self.synonymies = [synonymies]
        self._add_to_enum()

    def _add_to_enum(self):
        """Register this format under its main name and all synonyms."""
        StateFormat.main_name2obj[self.name] = self
        for name in [self.name] + self.synonymies:
            setattr(StateFormat, name, self)
            StateFormat.all_name2obj[name] = self

    @staticmethod
    def get(name):
        """
        Resolve a format name (main name or synonym) to its StateFormat.

        A StateFormat instance is returned unchanged, so already-resolved
        values can be passed through.
        :raises ValueError: if ``name`` is not a registered name
        """
        if isinstance(name, StateFormat):
            return name
        try:
            return StateFormat.all_name2obj[name]
        except KeyError:
            raise ValueError("Unknown key for StateFormat: " + str(name))

    @staticmethod
    def get_formats():
        """Return all registered formats (one entry per main name)."""
        return list(StateFormat.main_name2obj.values())

    @staticmethod
    def get_format_names():
        """Return the main names of all registered formats."""
        return list(StateFormat.main_name2obj.keys())

    def __str__(self):
        return self.name

    def __repr__(self):
        return "StateFormat(%r)" % self.name


# Each instantiation registers itself (see StateFormat._add_to_enum): the main
# name and any synonyms become attributes of StateFormat.
StateFormat("All_A_Name01",
            "All atoms (excluding negated atoms) are given in alphabetical "
            "order suffixed with \"1\" if true and \"0\" if false.",
            "Full")
StateFormat("All_A_01",
            "All atoms (excluding negated atoms) are given in alphabetical "
            "order represented by \"1\" if true and \"0\" if false.",
            "Full01")
StateFormat("All_T_Name",
            "All true atoms (excluding negated atoms) are given in alphabetical "
            "order by their name",
            "FullTrue")

StateFormat("FD_A_Name01",
            "True and False atoms (excluding negated atoms) which are "
            "reachable and NOT static (determined"
            " by FastDownward's translator module) are stored (no negated "
            "atoms introduced by FastDownward are used). True atoms are "
            "suffixed with \"1\", False ones with \"0\". "
            "Atoms are alphabetically sorted.")
StateFormat("FD_A_01",
            "True and False atoms (excluding negated atoms) which are "
            "reachable and NOT static (determined"
            " by FastDownward's translator module) are stored (no negated "
            "atoms introduced by FastDownward are used). True atoms are "
            "represented by \"1\", False ones by \"0\". "
            "Atoms are alphabetically sorted.",
            "Short")
StateFormat("FD_T_Name",
            "True atoms (excluding negated atoms) which are reachable and NOT "
            "static (determined by "
            "FastDownward's translator module) are stored (no negated atoms "
            "introduced by FastDownward are used). Atoms are alphabetically "
            "sorted.")

StateFormat("NFD_T_Name",
            "True atoms which are reachable and NOT static (determined by "
            "FastDownward's translator module) are stored (including negated "
            "atoms introduced by Fast Downward). Atoms are alphabetically "
            "sorted.",
            "FD")

StateFormat("NonStatic_A_01",
            "All atoms (excluding negated atoms) in alphabetical order which "
            "are not static over the given problems represented as \"1\" if "
            "true and as \"0\" if false.")

StateFormat("NonStatic_T_Name",
            "All true atoms (excluding negated atoms) which "
            "are not static over the given problems.")

# Disabled format, kept for reference:
# StateFormat("Objects", "Adds of the problem to the meta tag and then provides"
#                       "only the present grounded predicates.")

# Caching

CACHE_STR_GROUNDINGS = {}


def get_cached_groundings(groundable, keep_negatives=False):
    """
    Return the (memoized) sorted string groundings of ``groundable``.

    ``groundable`` must provide ``str_grounded_predicates(sort=True)`` (in
    this module a pddl task or ``sas_task.variables`` is passed). The entry
    "<none of those>" is always dropped; entries starting with "Negated" are
    dropped unless ``keep_negatives`` is set.
    :param groundable: hashable object providing ``str_grounded_predicates``
    :param keep_negatives: keep entries starting with "Negated"
    :return: list of atom strings (shared cached object; do not mutate)
    """
    # keep_negatives is part of the key: the filtered lists differ per flag.
    key = (groundable, keep_negatives)
    if key not in CACHE_STR_GROUNDINGS:
        CACHE_STR_GROUNDINGS[key] = [
            atom for atom in groundable.str_grounded_predicates(sort=True)
            if atom != "<none of those>"
            and (keep_negatives or not atom.startswith("Negated"))]
    return CACHE_STR_GROUNDINGS[key]



CACHE_STR_STATIC_GROUNDINGS = {}


def get_cached_static_groundings(domain_properties):
    """
    Return the set of ``str()`` forms of ``domain_properties.gnd_static``.

    The set is computed once per ``domain_properties`` object and the same
    cached set object is returned on later calls.
    """
    try:
        return CACHE_STR_STATIC_GROUNDINGS[domain_properties]
    except KeyError:
        pass
    static_names = {str(fact) for fact in domain_properties.gnd_static}
    CACHE_STR_STATIC_GROUNDINGS[domain_properties] = static_names
    return static_names


CACHE_STR_NON_STATIC_GROUNDINGS = {}


def get_cached_non_static_groundings(pddl_task, domain_properties):
    """
    Return ``str()`` of every element of ``sorted(domain_properties.gnd_flexible)``.

    Memoized per ``(pddl_task, domain_properties)``. The task only takes part
    in validation: the domain must be a fixed universe and the task's objects
    must be a subset of ``domain_properties.fixed_objects`` (checked with
    ``assert``). A warning is printed when the task uses a strict subset.
    :return: list of atom strings (shared cached object; do not mutate)
    """
    cache_key = (pddl_task, domain_properties)
    try:
        return CACHE_STR_NON_STATIC_GROUNDINGS[cache_key]
    except KeyError:
        pass

    assert domain_properties.fixed_world, "Only implemented for fixed universe domains"
    task_objects = set(pddl_task.objects)
    assert task_objects.issubset(domain_properties.fixed_objects), "Requires pddl_task to be from same fixed universe (or subuniverse) as DomainProperties object."
    if task_objects != domain_properties.fixed_objects:
        print("Warning: PDDL is from subuniverse")

    flexible_names = [str(fact) for fact in sorted(domain_properties.gnd_flexible)]
    CACHE_STR_NON_STATIC_GROUNDINGS[cache_key] = flexible_names
    return flexible_names


CACHE_STR_INITS = {}


def get_cached_inits(pddl_task):
    """
    Return the string forms of the initial facts of ``pddl_task``.

    Facts that are ``pddl.Assign`` instances or whose predicate is "=" are
    skipped. The strings are computed once per task and memoized.
    :param pddl_task: hashable task object with an ``init`` iterable
    :return: a fresh, mutable set of fact strings; changing it does not
             affect the cache
    """
    if pddl_task not in CACHE_STR_INITS:
        # Keyed by the task itself; stored frozen so the cache can't be
        # altered through a returned reference.
        CACHE_STR_INITS[pddl_task] = frozenset(
            str(fact) for fact in pddl_task.init
            if not isinstance(fact, pddl.Assign) and fact.predicate != "=")
    # Callers may modify the returned set, so hand out a copy.
    return set(CACHE_STR_INITS[pddl_task])


CACHE_STR_INITS_NOT_IN_SAS = {}


def get_cached_inits_not_in_sas(pddl_task, sas_task):
    """
    Return the initial facts of ``pddl_task`` not named by any SAS value.

    Starts from ``get_cached_inits(pddl_task)`` and removes every name found
    in ``sas_task.variables.value_names``. Memoized per
    ``(pddl_task, sas_task)``.
    :return: set of fact strings (shared cached object; callers should not
             mutate it)
    """
    key = (pddl_task, sas_task)
    if key not in CACHE_STR_INITS_NOT_IN_SAS:
        # Work on a private copy so the set obtained from get_cached_inits is
        # never modified in place.
        remaining = set(get_cached_inits(pddl_task))
        for value_names in sas_task.variables.value_names:
            remaining.difference_update(value_names)
        CACHE_STR_INITS_NOT_IN_SAS[key] = remaining
    return CACHE_STR_INITS_NOT_IN_SAS[key]



def clear_caches():
    """Empty all module-level memoization dictionaries of this module."""
    all_caches = (
        CACHE_STR_GROUNDINGS,
        CACHE_STR_STATIC_GROUNDINGS,
        CACHE_STR_NON_STATIC_GROUNDINGS,
        CACHE_STR_INITS,
        CACHE_STR_INITS_NOT_IN_SAS,
    )
    for cache in all_caches:
        cache.clear()


# NOTE(review): disabled code. This module-level string literal is evaluated
# and discarded, so neither CACHE_STR_TYPE_OBJECTS_PDDL nor
# get_cached_type_obj_pddl is defined at runtime. The body builds a
# tab-separated "type_name(obj, ...)" summary of a task's objects, presumably
# for the disabled "Objects" state format. If re-enabled: `s[:-1]` only drops
# the trailing space of ", ", leaving a dangling comma before ")".
'''
CACHE_STR_TYPE_OBJECTS_PDDL = {}
def get_cached_type_obj_pddl(pddl):
    if pddl not in CACHE_STR_TYPE_OBJECTS_PDDL:
        s = ""
        objs = {}
        for obj in pddl.objects:
            if obj.type_name not in objs:
                objs[obj.type_name] = set()
            objs[obj.type_name].add(obj.name)
        for type_name in objs:
            s += type_name + "("
            for obj in objs[type_name]:
                s += obj + ", "
            s = s[:-1] + ")\t"
        s = s[:-1]
        CACHE_STR_TYPE_OBJECTS_PDDL[pddl] = s
    return CACHE_STR_TYPE_OBJECTS_PDDL[pddl]
'''


# Parsing functions
def parse_plain_atoms(state, as_string=False, keep_negative=False):
    """
    Parses an iterable of atom describing strings.

    Every entry is stripped of surrounding whitespace before it is inspected,
    so padded entries are filtered exactly like clean ones. Empty entries
    (e.g. from splitting an empty line) are skipped.
    :param state: [Atom a(...), Atom b(...), ...]
    :param as_string: If true, keep entries as atom describing strings,
                      otherwise convert them to atom objects
    :param keep_negative: If false, entries starting with "Negated" are dropped
    :return: set of atom strings or set of atom objects
    """
    atoms = set()
    for entry in state:
        entry = entry.strip()
        if not entry:
            continue
        # Check after stripping; a leading blank must not hide "Negated".
        if not keep_negative and entry.startswith("Negated"):
            continue
        atoms.add(entry if as_string else pddl.Atom.from_string(entry))
    return atoms


def parse_positive_and_negative_atoms(state, atoms=None, as_string=False):
    """
    Parses an iterable of atom describing strings annotated with "0" or "1" to
    indicate that the atom is False or True and separates the atoms into the
    sets of True and False atoms. Negated atoms are ignored.
    If atoms is given, then the state shall consist of only "0" and "1" (or
    0 and 1) and the i-th entry of the atoms iterable describes the atom of
    the i-th entry in state.
    If as_string is True, the True and False atom sets store the atoms as their
    string description, otherwise Atom objects are stored.
    :param state: [Atom a(...)1, Atom b(...)0, ...] resp.
                  [1, 0, ...] if 'atoms' given
    :param atoms: list of atoms describing strings
    :param as_string: store atoms in output sets as string instead of as objects
    :return: set of true atoms and set of false atoms
    :raises ValueError: if atoms and state differ in length or a flag is not
                        an integer
    """
    if atoms is None:
        # Strip first so trailing whitespace is not mistaken for the flag.
        entries = [entry.strip() for entry in state]
        entries = [entry for entry in entries if entry]
        atoms = [entry[:-1] for entry in entries]
        flags = [entry[-1] for entry in entries]
    else:
        flags = state
        if len(atoms) != len(flags):
            raise ValueError("State parsing error. Unequal amount of atoms in "
                             "the state to parse and the given atom list.")

    pos = set()
    neg = set()
    for atom, flag in zip(atoms, flags):
        atom = atom.strip()
        if atom.startswith("Negated"):
            continue
        # Flags are "0"/"1" strings when parsed from text, or ints; compare
        # numerically so both work (a plain `== 1` never matches "1").
        target = pos if int(flag) == 1 else neg
        target.add(atom if as_string else pddl.Atom.from_string(atom))

    return pos, neg


def convert_from_NFD_T_Name(state, format, pddl_task, sas_task,
                            domain_properties):
    """
    Converts a state given in NFD_T_Name format into ``format``.

    The input lists true atoms as strings; negated atoms it contains are
    dropped. For All_* and NonStatic_* targets, the initial facts of
    ``pddl_task`` not covered by any SAS variable are added as true atoms.
    :param state: iterable of atom strings
    :param format: target StateFormat
    :param pddl_task: pddl task (used by All_* and NonStatic_* targets)
    :param sas_task: sas task whose variables define the FD_* atom list
    :param domain_properties: required for NonStatic_* targets
    :return: tuple describing the state in the target format
    :raises ValueError: NonStatic_* target without domain_properties
    :raises NotImplementedError: unsupported target format
    """
    if format == StateFormat.NFD_T_Name:
        return tuple(state)

    positive_atoms = parse_plain_atoms(state, as_string=True,
                                       keep_negative=False)

    if format == StateFormat.All_A_Name01:
        positive_atoms = positive_atoms | get_cached_inits_not_in_sas(pddl_task, sas_task)
        new_state = [atom + ("1" if atom in positive_atoms else "0")
                     for atom in get_cached_groundings(pddl_task)]

    elif format == StateFormat.All_T_Name:
        positive_atoms = positive_atoms | get_cached_inits_not_in_sas(pddl_task, sas_task)
        new_state = sorted(positive_atoms)

    elif format == StateFormat.All_A_01:
        positive_atoms = positive_atoms | get_cached_inits_not_in_sas(pddl_task, sas_task)
        new_state = [1 if atom in positive_atoms else 0
                     for atom in get_cached_groundings(pddl_task)]

    elif format == StateFormat.FD_A_Name01:
        new_state = [atom + ("1" if atom in positive_atoms else "0")
                     for atom in get_cached_groundings(sas_task.variables)]

    elif format == StateFormat.FD_A_01:
        new_state = [1 if atom in positive_atoms else 0
                     for atom in get_cached_groundings(sas_task.variables)]

    elif format == StateFormat.FD_T_Name:
        new_state = [atom for atom in get_cached_groundings(sas_task.variables)
                     if atom in positive_atoms]

    elif format == StateFormat.NonStatic_A_01:
        if domain_properties is None:
            raise ValueError(
                "Requires DomainProperty for conversion to NonStatic_A_01")
        positive_atoms = positive_atoms | get_cached_inits_not_in_sas(pddl_task, sas_task)
        new_state = [1 if atom in positive_atoms else 0
                     for atom in get_cached_non_static_groundings(
                         pddl_task, domain_properties)]

    elif format == StateFormat.NonStatic_T_Name:
        if domain_properties is None:
            raise ValueError(
                "Requires DomainProperty for conversion to NonStatic_T_Name")
        positive_atoms = positive_atoms | get_cached_inits_not_in_sas(pddl_task, sas_task)
        new_state = [atom for atom in get_cached_non_static_groundings(
                         pddl_task, domain_properties)
                     if atom in positive_atoms]

    else:
        raise NotImplementedError("The conversion from NFD_T_Name is not "
                                  "implemented to: " + str(format))

    return tuple(new_state)


def convert_from_FD_T_Name(state, format, pddl_task, sas_task,
                           domain_properties):
    """
    Converts a state given in FD_T_Name format (true, non-negated FD atoms)
    into ``format``.

    For All_* and NonStatic_* targets, the initial facts of ``pddl_task`` not
    covered by any SAS variable are added as true atoms.
    :param state: iterable of atom strings
    :param format: target StateFormat
    :param pddl_task: pddl task (used by All_* and NonStatic_* targets)
    :param sas_task: sas task whose variables define the FD_* atom list
    :param domain_properties: required for NonStatic_* targets
    :return: tuple describing the state in the target format
    :raises ValueError: NonStatic_* target without domain_properties
    :raises NotImplementedError: unsupported target format
    """
    if format == StateFormat.FD_T_Name:
        return tuple(state)

    # A set makes the per-atom membership tests below O(1) instead of
    # scanning the state list for every grounded atom.
    true_atoms = {atom.strip() for atom in state if atom.strip()}

    if format == StateFormat.All_A_Name01:
        true_atoms = true_atoms | get_cached_inits_not_in_sas(pddl_task, sas_task)
        new_state = [atom + ("1" if atom in true_atoms else "0")
                     for atom in get_cached_groundings(pddl_task)]

    elif format == StateFormat.All_A_01:
        true_atoms = true_atoms | get_cached_inits_not_in_sas(pddl_task, sas_task)
        new_state = [1 if atom in true_atoms else 0
                     for atom in get_cached_groundings(pddl_task)]

    elif format == StateFormat.All_T_Name:
        # Same construction as the NFD_T_Name -> All_T_Name conversion.
        true_atoms = true_atoms | get_cached_inits_not_in_sas(pddl_task, sas_task)
        new_state = sorted(true_atoms)

    elif format == StateFormat.FD_A_Name01:
        new_state = [atom + ("1" if atom in true_atoms else "0")
                     for atom in get_cached_groundings(sas_task.variables)]

    elif format == StateFormat.FD_A_01:
        new_state = [1 if atom in true_atoms else 0
                     for atom in get_cached_groundings(sas_task.variables)]

    elif format == StateFormat.NonStatic_A_01:
        if domain_properties is None:
            raise ValueError(
                "Requires DomainProperty for conversion to NonStatic_A_01")
        true_atoms = true_atoms | get_cached_inits_not_in_sas(pddl_task, sas_task)
        new_state = [1 if atom in true_atoms else 0
                     for atom in get_cached_non_static_groundings(
                         pddl_task, domain_properties)]

    elif format == StateFormat.NonStatic_T_Name:
        if domain_properties is None:
            raise ValueError(
                "Requires DomainProperty for conversion to NonStatic_T_Name")
        true_atoms = true_atoms | get_cached_inits_not_in_sas(pddl_task, sas_task)
        new_state = [atom for atom in get_cached_non_static_groundings(
                         pddl_task, domain_properties)
                     if atom in true_atoms]

    else:
        raise NotImplementedError("The conversion from FD_T_Name is not "
                                  "implemented to: " + str(format))

    return tuple(new_state)


def convert_from_FD_A_01(state, format, pddl_task, sas_task,
                         domain_properties):
    """
    Converts a state given in FD_A_01 format into ``format``.

    The i-th entry of ``state`` is the truth value (0/1 or "0"/"1") of the
    i-th atom of ``get_cached_groundings(sas_task.variables)``. For All_* and
    NonStatic_* targets, the initial facts of ``pddl_task`` not covered by
    any SAS variable are added as true atoms.
    :param state: iterable of 0/1 values (ints or numeric strings)
    :param format: target StateFormat
    :param pddl_task: pddl task (used by All_* and NonStatic_* targets)
    :param sas_task: sas task whose variables define the atom order
    :param domain_properties: required for NonStatic_* targets
    :return: tuple describing the state in the target format
    :raises ValueError: NonStatic_* target without domain_properties
    :raises NotImplementedError: unsupported target format
    """
    bits = [int(x) for x in state]
    if format == StateFormat.FD_A_01:
        return tuple(bits)

    fd_atoms = get_cached_groundings(sas_task.variables)
    assert len(bits) == len(fd_atoms), \
        "State length does not match the number of FD atoms"
    positive_atoms = {atom for atom, bit in zip(fd_atoms, bits) if bit == 1}

    if format == StateFormat.All_A_Name01:
        positive_atoms = positive_atoms | get_cached_inits_not_in_sas(pddl_task, sas_task)
        new_state = [atom + ("1" if atom in positive_atoms else "0")
                     for atom in get_cached_groundings(pddl_task)]

    elif format == StateFormat.All_A_01:
        positive_atoms = positive_atoms | get_cached_inits_not_in_sas(pddl_task, sas_task)
        new_state = [1 if atom in positive_atoms else 0
                     for atom in get_cached_groundings(pddl_task)]

    elif format == StateFormat.All_T_Name:
        # Same construction as the NFD_T_Name -> All_T_Name conversion.
        positive_atoms = positive_atoms | get_cached_inits_not_in_sas(pddl_task, sas_task)
        new_state = sorted(positive_atoms)

    elif format == StateFormat.FD_A_Name01:
        new_state = [atom + ("1" if atom in positive_atoms else "0")
                     for atom in fd_atoms]

    elif format == StateFormat.FD_T_Name:
        # Keep the FD atom order, as the NFD_T_Name -> FD_T_Name conversion does.
        new_state = [atom for atom, bit in zip(fd_atoms, bits) if bit == 1]

    elif format == StateFormat.NonStatic_A_01:
        if domain_properties is None:
            raise ValueError(
                "Requires DomainProperty for conversion to NonStatic_A_01")
        positive_atoms = positive_atoms | get_cached_inits_not_in_sas(pddl_task, sas_task)
        new_state = [1 if atom in positive_atoms else 0
                     for atom in get_cached_non_static_groundings(
                         pddl_task, domain_properties)]

    elif format == StateFormat.NonStatic_T_Name:
        if domain_properties is None:
            raise ValueError(
                "Requires DomainProperty for conversion to NonStatic_T_Name")
        positive_atoms = positive_atoms | get_cached_inits_not_in_sas(pddl_task, sas_task)
        new_state = [atom for atom in get_cached_non_static_groundings(
                         pddl_task, domain_properties)
                     if atom in positive_atoms]

    else:
        raise NotImplementedError("The conversion from FD_A_01 is not "
                                  "implemented to: " + str(format))

    return tuple(new_state)


def convert_from_All_A_Name01(state, format, pddl_task, sas_task,
                              domain_properties):
    """
    Converts a state given in All_A_Name01 format into ``format``.

    Each entry is an atom string whose last character is its truth flag.
    Only the All_A_01 target is supported.
    :param state: iterable of flagged atom strings
    :param format: target StateFormat
    :return: tuple describing the state in the target format
    :raises NotImplementedError: unsupported target format
    """
    if format == StateFormat.All_A_Name01:
        return tuple(state)

    if format == StateFormat.All_A_01:
        # Strip first so a trailing newline/blank is not read as the flag
        # (anything other than "0" would otherwise count as true).
        return tuple(0 if atom.strip()[-1] == "0" else 1 for atom in state)

    raise NotImplementedError("The conversion from All_A_Name01 is not "
                              "implemented to: " + str(format))


def convert_from_All_T_Name(state, format, pddl_task, sas_task,
                            domain_properties):
    """
    Converts a state given in All_T_Name format (all true atoms by name)
    into ``format``.
    :param state: iterable of atom strings
    :param format: target StateFormat
    :param pddl_task: pddl task whose groundings define the All_* atom list
    :param sas_task: unused here; kept for a uniform converter signature
    :param domain_properties: required for NonStatic_* targets
    :return: tuple describing the state in the target format
    :raises ValueError: NonStatic_* target without domain_properties
    :raises NotImplementedError: unsupported target format
    """
    if format == StateFormat.All_T_Name:
        return tuple(state)

    true_atoms = {atom.strip() for atom in state}

    if format == StateFormat.All_A_01:
        new_state = [1 if atom in true_atoms else 0
                     for atom in get_cached_groundings(pddl_task)]

    elif format == StateFormat.All_A_Name01:
        new_state = [atom + ("1" if atom in true_atoms else "0")
                     for atom in get_cached_groundings(pddl_task)]

    elif format == StateFormat.NonStatic_A_01:
        if domain_properties is None:
            raise ValueError(
                "Requires DomainProperty for conversion to NonStatic_A_01")
        new_state = [1 if atom in true_atoms else 0
                     for atom in get_cached_non_static_groundings(
                         pddl_task, domain_properties)]

    elif format == StateFormat.NonStatic_T_Name:
        if domain_properties is None:
            raise ValueError(
                "Requires DomainProperty for conversion to NonStatic_T_Name")
        new_state = [atom for atom in get_cached_non_static_groundings(
                         pddl_task, domain_properties)
                     if atom in true_atoms]

    else:
        raise NotImplementedError("The conversion from All_T_Name is not "
                                  "implemented to: " + str(format))
    return tuple(new_state)


def convert_from_NonStatic_A_01(state, format, pddl_task, sas_task,
                                domain_properties):
    """
    Converts a state given in NonStatic_A_01 format into ``format``.

    The i-th entry of ``state`` is the truth value of the i-th atom of
    ``get_cached_non_static_groundings(pddl_task, domain_properties)``.
    :param state: iterable of 0/1 values (ints or numeric strings)
    :param format: target StateFormat
    :return: tuple describing the state in the target format
    :raises ValueError: NonStatic_T_Name target without domain_properties
    :raises NotImplementedError: unsupported target format
    """
    state = tuple(int(x) for x in state)
    if format == StateFormat.NonStatic_A_01:
        return state

    if format == StateFormat.NonStatic_T_Name:
        if domain_properties is None:
            raise ValueError(
                "Requires DomainProperty for conversion to NonStatic_T_Name")
        nsg = get_cached_non_static_groundings(pddl_task, domain_properties)
        assert len(nsg) == len(state)
        # Tuple, like every other converter's result.
        return tuple(atom for atom, bit in zip(nsg, state) if bit == 1)

    raise NotImplementedError("The conversion from NonStatic_A_01 is not "
                              "implemented to: " + str(format))


def convert_from_NonStatic_T_Name(state, format, pddl_task, sas_task,
                                  domain_properties):
    """
    Converts a state given in NonStatic_T_Name format (true non-static atoms
    by name) into ``format``.
    :param state: iterable of atom strings
    :param format: target StateFormat
    :return: tuple describing the state in the target format
    :raises ValueError: NonStatic_A_01 target without domain_properties
    :raises NotImplementedError: unsupported target format
    """
    if format == StateFormat.NonStatic_T_Name:
        return tuple(state)

    if format == StateFormat.NonStatic_A_01:
        if domain_properties is None:
            raise ValueError(
                "Requires DomainProperty for conversion to NonStatic_A_01")
        true_atoms = {atom.strip() for atom in state}
        return tuple(1 if atom in true_atoms else 0
                     for atom in get_cached_non_static_groundings(
                         pddl_task, domain_properties))

    raise NotImplementedError("The conversion from NonStatic_T_Name is not "
                              "implemented to: " + str(format))


def convert_from_X_to_Y(state, in_format, out_format,
                        pddl_task=None, sas_task=None, domain_properties=None):
    """
    Dispatch a state conversion to the converter for ``in_format``.

    A tuple already in ``out_format`` is returned unchanged.
    :raises NotImplementedError: if no converter exists for ``in_format``
    """
    if in_format == out_format and isinstance(state, tuple):
        return state

    converters_by_source = (
        (StateFormat.NFD_T_Name, convert_from_NFD_T_Name),
        (StateFormat.FD_T_Name, convert_from_FD_T_Name),
        (StateFormat.FD_A_01, convert_from_FD_A_01),
        (StateFormat.All_A_Name01, convert_from_All_A_Name01),
        (StateFormat.All_T_Name, convert_from_All_T_Name),
        (StateFormat.NonStatic_A_01, convert_from_NonStatic_A_01),
        (StateFormat.NonStatic_T_Name, convert_from_NonStatic_T_Name),
    )
    for source_format, converter in converters_by_source:
        if in_format == source_format:
            return converter(state, out_format, pddl_task, sas_task,
                             domain_properties)

    raise NotImplementedError("Conversions for the given formats is not "
                              "supported: %s -> %s" % (str(in_format),
                                                       str(out_format)))


class StateFieldParser(FieldParser):
    """
    Field parser for tab-separated state fields.

    The current format (ARG_CURRENT_FORMAT) and an optional target format
    (ARG_NEW_FORMAT) are resolved through get_argument_for_parser; the state
    is converted with convert_from_X_to_Y. The name of the resulting format
    is written back into the description.
    """

    @staticmethod
    def _convert(data, description, list_kwargs, unparse=False,
                 domain_properties=None, **kwargs):
        """
        Convert ``data`` between state formats.

        ``kwargs`` must contain "pddl_task" and "sas_task".
        :return: (updated description, converted state)
        :raises ValueError: if the current format cannot be determined
        """
        first_list_kwargs = None if list_kwargs is None else list_kwargs[0]

        def resolve(argument_name):
            return get_argument_for_parser(
                argument_name, description, first_list_kwargs, kwargs,
                converter_user=StateFormat.get,
                converter_description=StateFormat.get)

        current_format = resolve(ARG_CURRENT_FORMAT)
        new_format = resolve(ARG_NEW_FORMAT)

        if current_format is None:
            raise ValueError("Unable to determine the current format of the "
                             "state data to load")
        target_format = current_format if new_format is None else new_format

        converted = convert_from_X_to_Y(data, current_format, target_format,
                                        kwargs["pddl_task"], kwargs["sas_task"],
                                        domain_properties=domain_properties)
        description[ARG_CURRENT_FORMAT] = target_format.name
        return description, converted

    def _parse(self, data, description, list_kwargs, **kwargs):
        fields = data.split("\t")
        return StateFieldParser._convert(fields, description, list_kwargs,
                                         **kwargs)

    def _unparse(self, data, description, list_kwargs, **kwargs):
        description, state = StateFieldParser._convert(
            data, description, list_kwargs, unparse=True, **kwargs)
        return description, "\t".join(map(str, state))

    def _prepare_kwargs(self, description, kwargs):
        return kwargs


FIELD_PARSERS["state"] = StateFieldParser()

