#!/usr/bin/env python

import argparse
import os
import re
import subprocess
import sys
from typing import Any, Callable, List, Optional, Set, Tuple

# Regexes for the parts of a translator ``output.sas`` file that are needed.
# Raw strings are used so that regex escapes such as ``\d`` are not parsed as
# (invalid) Python string escapes; ``\n`` is then interpreted by ``re``.

# One variable: groups are (number in "var<k>", value of the line after the
# name, declared number of values, the block of value lines, ...).
PATTERN_VARIABLE = re.compile(
    r"begin_variable\n"
    r"var(\d+)\n"
    r"(-?\d+)\n"
    r"(\d+)\n"
    r"(((Atom|Negated|<none of )[^\n]+\n)+)"
    r"end_variable"
)

# One mutex group: groups are (declared number of entries, the block of
# "<var> <val>" lines, ...).
PATTERN_MUTEX_GROUP = re.compile(
    r"begin_mutex_group\n"
    r"(\d+)\n"
    r"((\d+ \d+\n)+)"
    r"end_mutex_group"
)

# A single "<var> <val>" line inside a mutex group.
PATTERN_VAR_VAL_PAIR = re.compile(
    r"(\d+) (\d+)\n"
)


# Command line interface. The text is passed as ``description``: the first
# positional parameter of ArgumentParser is ``prog`` (the program name shown
# in the usage line), not the description.
parser = argparse.ArgumentParser(
    description="""Find the common mutexes (detected by the translator)""")

parser.add_argument("directory", action="store", type=str,
                    help="Directory within which to work.")
parser.add_argument("-fd", "--fast-downward", action="store", type=str,
                    default="./fast-downward.py",
                    help="Path to the fast-downward script.")
parser.add_argument("-b", "--build", action="store", type=str,
                    default="release64dynamic",
                    help="Name of the fast downward build to use")
parser.add_argument("--skip", action="store_true",
                    help="Skips the generation for a directory, if it contains "
                         "already a 'mutexes.sas' file.")

class Mutex(object):
    """A set of atoms that are pairwise mutually exclusive.

    Atoms are plain strings; ``calculate_mutexes`` uses the textual value
    names of a translator ``output.sas`` file (lines starting with ``Atom``,
    ``Negated...`` or ``<none of ...``).
    """

    def __init__(self, atoms: Optional[Set[str]] = None) -> None:
        # NOTE: a given set is stored as-is (not copied); use copy() to obtain
        # an independent instance.
        self.atoms = set() if atoms is None else atoms

    def add(self, atom: str) -> None:
        """Add ``atom`` to this mutex."""
        self.atoms.add(atom)

    def size(self) -> int:
        """Return the number of atoms."""
        return len(self.atoms)

    def __len__(self) -> int:
        return self.size()

    def __contains__(self, atom: str) -> bool:
        return atom in self.atoms

    def empty(self) -> bool:
        """Return True if the mutex contains no atom."""
        return not self.atoms

    def has(self, atom: str) -> bool:
        """Return True if ``atom`` belongs to this mutex."""
        return atom in self.atoms

    def is_mutex(self, atom1: str, atom2: str) -> bool:
        """Return True if this mutex states that both atoms are exclusive."""
        return atom1 in self.atoms and atom2 in self.atoms

    def intersect(self, other: 'Mutex') -> 'Mutex':
        """Return a new Mutex with the atoms contained in both mutexes."""
        return Mutex(self.atoms & other.atoms)

    def is_submutex_of(self, other: 'Mutex') -> bool:
        """Return True if every atom of this mutex is also in ``other``."""
        return self.atoms <= other.atoms

    def copy(self) -> 'Mutex':
        """Return an independent copy of this mutex."""
        return Mutex(set(self.atoms))

    def dumps(self) -> str:
        """Serialize as a ``begin_mutex_group`` ... ``end_mutex_group`` block.

        The atoms are written one per line in sorted order. Sorting matters:
        iterating a set of strings yields a hash-seed dependent order, which
        would make files written from identical input differ between runs.
        """
        lines = ["begin_mutex_group", str(len(self.atoms))]
        lines.extend(sorted(self.atoms))
        lines.append("end_mutex_group")
        return "\n".join(lines) + "\n"

    def dump(self) -> None:
        """Print ``dumps()`` to stdout."""
        print(self.dumps(), end="")

    def __repr__(self) -> str:
        return "Mutex(%r)" % (sorted(self.atoms),)


class MutexCollection(object):
    """An ordered list of Mutex objects.

    ``merge``/``merge2`` reduce two collections to the mutexes they have in
    common; this is how the mutexes shared by several problems are computed.
    """

    def __init__(self, mutexes: Optional[List[Mutex]] = None) -> None:
        # A given list is stored as-is (not copied).
        self.mutexes = [] if mutexes is None else mutexes

    def add(self, mutex: Mutex) -> None:
        """Append ``mutex`` to the collection."""
        self.mutexes.append(mutex)

    def copy(self) -> 'MutexCollection':
        """Return a copy in which every contained Mutex is copied as well."""
        return MutexCollection([mutex.copy() for mutex in self.mutexes])

    def merge(self, other: Optional['MutexCollection']) -> None:
        """Replace this collection by the mutexes it shares with ``other``.

        A ``None`` argument leaves this collection unchanged.
        """
        if other is None:
            return
        self.mutexes = MutexCollection.merge2(self, other).mutexes

    @staticmethod
    def merge2(first: Optional['MutexCollection'],
               second: Optional['MutexCollection']
               ) -> Optional['MutexCollection']:
        """Return the mutexes common to ``first`` and ``second``.

        If one argument is None, a copy of the other is returned; if both are
        None, None is returned.
        """
        if first is None and second is None:
            return None
        if first is None:
            return second.copy()
        if second is None:
            return first.copy()

        # Two atoms are mutex in both collections iff some mutex of ``first``
        # and some mutex of ``second`` both contain them, i.e. iff they lie in
        # one pairwise intersection. Intersections with fewer than two atoms
        # state no mutex and are dropped.
        common = [m.intersect(n) for m in first.mutexes for n in second.mutexes]
        common = sorted((c for c in common if len(c) > 1), key=len)

        # Drop duplicates and mutexes contained in a later one. After sorting
        # by size, any proper superset sits at a later index; of several equal
        # mutexes only the last one survives.
        reduced = [
            mutex for i, mutex in enumerate(common)
            if not any(mutex.is_submutex_of(other) for other in common[i + 1:])
        ]
        return MutexCollection(reduced)

    def dumps(self) -> str:
        """Serialize as the number of mutexes followed by each mutex block."""
        return ("%i\n" % len(self.mutexes)
                + "".join(mutex.dumps() for mutex in self.mutexes))

    def dump(self) -> None:
        """Print ``dumps()`` to stdout."""
        print(self.dumps(), end="")


def calculate_mutexes(path_fd, path_problem, path_domain=None, build=None):
    """Run Fast Downward on a task and collect the translator's mutexes.

    The script ``path_fd`` is executed in the current working directory and is
    expected to write ``output.sas`` there. Its exit status is ignored; success
    is judged only by whether ``output.sas`` exists afterwards. The file is
    deleted after it has been read.

    Every variable of ``output.sas`` becomes one Mutex of all its values, and
    every mutex group becomes one Mutex of the referenced values. Atoms are
    identified by their textual value names.

    :param path_fd: path of the fast-downward script (executed directly, so it
        must be executable).
    :param path_problem: path of the problem file.
    :param path_domain: optional path of the domain file, passed before the
        problem file.
    :param build: optional build name, passed as ``--build <build>``.
    :return: MutexCollection with one Mutex per variable and mutex group.
    :raises FileNotFoundError: if no ``output.sas`` was produced.
    """
    assert path_fd is not None
    assert path_problem is not None
    path_sas = "output.sas"
    # Remove a stale file so that it cannot be mistaken for this run's output.
    if os.path.exists(path_sas):
        os.remove(path_sas)

    command = [path_fd]
    if build is not None:
        command += ["--build", build]
    if path_domain is not None:
        command.append(path_domain)
    command.append(path_problem)
    # subprocess.call never raises CalledProcessError: a non-zero exit status
    # is deliberately tolerated and the existence check below decides.
    subprocess.call(command)

    if not os.path.exists(path_sas):
        raise FileNotFoundError("SAS+ Representation not computed")

    try:
        with open(path_sas, "r") as f:
            sas = f.read()
    finally:
        # Clean up even if reading fails.
        os.remove(path_sas)

    collection = MutexCollection()
    # Value names per variable, indexed by the variable's position in the
    # file; mutex group entries "<var> <val>" are resolved against it.
    values_per_var = []
    for match in PATTERN_VARIABLE.findall(sas):
        values = match[3].strip().split("\n")
        # match[2] is the declared number of values of the variable.
        assert len(values) == int(match[2])
        values_per_var.append(values)
        collection.add(Mutex(set(values)))

    for match in PATTERN_MUTEX_GROUP.findall(sas):
        group = Mutex()
        for var, val in PATTERN_VAR_VAL_PAIR.findall(match[1]):
            group.add(values_per_var[int(var)][int(val)])
        collection.add(group)

    return collection


class Node(object):
    """A directory of the benchmark tree.

    A node knows its parent and child directories, whether it contains a
    domain file, the problem files directly inside it, and the mutexes common
    to those problems (and, via apply_children_mutexes, to its children).
    """

    def __init__(self, dir_name: str, parent: Optional['Node'],
                 path: Optional[str] = None,
                 children: Optional[List['Node']] = None,
                 has_domain: bool = False,
                 problems: Optional[List[str]] = None,
                 common_mutexes: Optional['MutexCollection'] = None,
                 fd_path: Optional[str] = None,
                 fd_build: Optional[str] = None):
        self.dir_name = dir_name
        self.path = path
        self._parent = parent
        self._children = [] if children is None else children
        if self._parent is not None:
            self._parent.add_child(self)
        self.has_domain = has_domain
        self._problems = [] if problems is None else problems
        # Per problem: cached MutexCollection (None => not computed yet).
        self._problem_mutexes = [None] * len(self._problems)
        # Per problem: already merged into common_mutexes? (same initial value
        # as used by add_problem)
        self._applied_problem_mutex = [False] * len(self._problems)
        # None => not processed, empty collection => nothing common
        self.common_mutexes = common_mutexes

        self.fd_path = fd_path
        self.fd_build = fd_build

    def get_path(self) -> str:
        """Return the directory of this node.

        An explicitly given ``path`` wins; otherwise the parent's path is
        joined with ``dir_name`` (a root without path returns ``dir_name``).
        """
        if self.path is not None:
            return self.path
        if self._parent is None:
            return self.dir_name
        return os.path.join(self._parent.get_path(), self.dir_name)

    def get_parent(self) -> Optional['Node']:
        """Return the parent node, or None for a root."""
        return self._parent

    def add_child(self, c: 'Node') -> None:
        """Attach ``c`` as a child (``c`` must not belong to another node)."""
        assert c._parent is None or c._parent is self
        if c._parent is None:
            c._parent = self
        self._children.append(c)

    def rmv_child(self, c: 'Node') -> None:
        """Detach the child ``c``.

        Removing a child cannot undo merging its mutexes into this node's
        mutexes, so do not remove children after mutex processing started.

        :param c: a current child of this node.
        """
        assert c in self._children
        c._parent = None
        self._children.remove(c)

    def get_children(self) -> List['Node']:
        """Return the (live) list of children."""
        return self._children

    def get_nb_children(self) -> int:
        """Return the number of children."""
        return len(self._children)

    def sort_children(self, key: Callable[['Node'], Any] = lambda x: x.dir_name,
                      recursive=False) -> None:
        """Sort the children by ``key`` (default: directory name)."""
        self._children = sorted(self._children, key=key)
        if recursive:
            for child in self._children:
                child.sort_children(key=key, recursive=recursive)

    def add_problem(self, path: str) -> None:
        """Register the problem file ``path`` (mutexes not computed yet)."""
        self._problems.append(path)
        self._problem_mutexes.append(None)
        self._applied_problem_mutex.append(False)

    def get_nb_problems(self) -> int:
        """Return the number of registered problem files."""
        return len(self._problems)

    def calculate_problem_mutexes(self, index: int) -> 'MutexCollection':
        """Return the mutexes of problem ``index``, computing them once."""
        assert 0 <= index < self.get_nb_problems()
        if self._problem_mutexes[index] is None:
            self._problem_mutexes[index] = calculate_mutexes(
                self.fd_path, self._problems[index], build=self.fd_build)
        return self._problem_mutexes[index]

    def calculate_all_problem_mutexes(self) -> None:
        """Compute (or reuse) the mutexes of every registered problem."""
        for i in range(self.get_nb_problems()):
            self.calculate_problem_mutexes(i)

    def apply_problem_mutexes(self) -> Optional['MutexCollection']:
        """Merge the mutexes of every problem into this node's collection.

        Problem mutexes are computed if necessary; each problem is merged at
        most once.

        :return: the MutexCollection of this node (None if it has no problems
            and nothing was merged before).
        """
        self.calculate_all_problem_mutexes()
        for i in range(self.get_nb_problems()):
            if not self._applied_problem_mutex[i]:
                self.merge_mutexes(self._problem_mutexes[i])
                self._applied_problem_mutex[i] = True
        return self.common_mutexes

    def apply_children_mutexes(self, recursive=True,
                               apply_problems=False) -> None:
        """Merge the mutexes of all direct children into this node.

        :param recursive: first let every child merge its own children.
        :param apply_problems: also apply the problem mutexes of every visited
            node (children before this node).
        """
        for child in self._children:
            if recursive:
                child.apply_children_mutexes(recursive=recursive,
                                             apply_problems=apply_problems)
            # merge_mutexes copes with this node still having no collection.
            self.merge_mutexes(child.common_mutexes)
        if apply_problems:
            self.apply_problem_mutexes()

    def merge_mutexes(self,
                      other_mutexes: Optional['MutexCollection']) -> None:
        """Intersect this node's mutexes with ``other_mutexes``.

        If this node has no collection yet, it adopts a copy of the other one;
        a None argument changes nothing.
        """
        if self.common_mutexes is None:
            if other_mutexes is not None:
                self.common_mutexes = other_mutexes.copy()
        else:
            self.common_mutexes.merge(other_mutexes)

    def store_mutexes(self):
        """Write the common mutexes to ``mutexes.sas`` in this directory."""
        directory = self.get_path()
        assert os.path.exists(directory)
        assert self.common_mutexes is not None
        with open(os.path.join(directory, "mutexes.sas"), "w") as f:
            f.write(self.common_mutexes.dumps())

    def dumps(self, depth=0) -> str:
        """Return the subtree as an indented text tree."""
        if depth == 0:
            s = self.dir_name + "\n"
        else:
            s = (depth - 1) * "  " + "|-" + self.dir_name + "\n"
        for child in self._children:
            s += child.dumps(depth + 1)
        return s

    def dump(self) -> None:
        """Print ``dumps()`` to stdout."""
        print(self.dumps(), end="")

    def __str__(self) -> str:
        return "({dir_name}, {_parent})".format(**self.__dict__)


def is_pddl(path):
    """Tell whether ``path`` names a PDDL file (``.pddl`` suffix)."""
    suffix = ".pddl"
    return path[-len(suffix):] == suffix


def is_domain(path):
    """Tell whether ``path`` looks like a domain file (contains "domain")."""
    return "domain" in path


def prepare_args(argv) -> argparse.Namespace:
    """Parse ``argv`` with the module-level parser and validate the paths.

    Exits through ``parser.error`` (usage message, exit status 2) if the
    fast-downward script is not a file or the directory does not exist.
    """
    options = parser.parse_args(argv)
    # parser.error instead of assert: assertions are stripped under -O and
    # would let invalid input through.
    if not os.path.isfile(options.fast_downward):
        parser.error("fast-downward script not found: %s"
                     % options.fast_downward)
    if not os.path.isdir(options.directory):
        parser.error("directory not found: %s" % options.directory)
    return options


def find_problems(path: str, path_fd: str, build: str) -> Tuple[Node, List[Node]]:
    """Build the directory tree below ``path`` and prune irrelevant branches.

    Every directory becomes a Node configured with ``path_fd``/``build``.
    ``*.pddl`` files whose name contains "domain" mark their directory as
    having a domain; all other ``*.pddl`` files are registered as problems.
    Leaf directories lacking a domain or problems are then removed repeatedly
    (a parent whose children were all removed is checked in turn).

    :return: the root node and the remaining leaves, i.e. leaf directories
        with a domain file and at least one problem.
    """
    # normpath strips trailing separators so that basename yields the last
    # directory component (e.g. "bench/" -> "bench").
    root = Node(os.path.basename(os.path.normpath(path)), None, path,
                fd_path=path_fd, fd_build=build)
    todo = [root]
    while todo:
        next_node = todo.pop()
        node_path = next_node.get_path()
        for item in os.listdir(node_path):
            path_item = os.path.join(node_path, item)
            if os.path.isdir(path_item):
                child = Node(item, next_node, path_item,
                             fd_path=path_fd, fd_build=build)
                todo.append(child)
            elif os.path.isfile(path_item) and is_pddl(item):
                if is_domain(item):
                    next_node.has_domain = True
                else:
                    next_node.add_problem(path_item)

    # Detect leaf nodes.
    candidates = set()
    todo = [root]
    while todo:
        next_node = todo.pop()
        if next_node.get_nb_children() == 0:
            candidates.add(next_node)
        else:
            todo.extend(next_node.get_children())

    # Prune leaves without relevant files (no problems or no domain file).
    leaves = []
    while candidates:
        leaf = candidates.pop()
        if leaf.get_nb_problems() > 0 and leaf.has_domain:
            leaves.append(leaf)
            continue
        parent = leaf.get_parent()
        if parent is None:
            # The root itself is irrelevant; there is nothing above to prune.
            continue
        parent.rmv_child(leaf)
        if parent.get_nb_children() == 0:
            candidates.add(parent)

    return root, leaves




def run(argv):
    """Compute and store the common mutexes of every directory with a domain.

    For each such directory, the mutexes shared by all problems directly in it
    are written to ``mutexes.sas`` inside that directory. With ``--skip``,
    directories that already contain ``mutexes.sas`` are left untouched.
    """
    options = prepare_args(argv)
    root, _ = find_problems(options.directory,
                            options.fast_downward, options.build)

    root.sort_children(recursive=True)

    pending = [root]
    while pending:
        node = pending.pop()
        pending.extend(node.get_children())
        if not node.has_domain:
            continue
        target = os.path.join(node.get_path(), "mutexes.sas")
        if options.skip and os.path.exists(target):
            continue
        node.apply_problem_mutexes()
        if node.common_mutexes is None:
            # A domain file without problem files yields no mutexes at all;
            # store_mutexes would fail its assertion on such a node.
            print("No problems in %s, nothing to store." % node.get_path(),
                  file=sys.stderr)
            continue
        node.store_mutexes()



if __name__ == "__main__":
    # Forward everything after the script name to the argument parser.
    cli_args = sys.argv[1:]
    run(cli_args)