import codecs
import logging
import os
import random
from copy import deepcopy
from dataclasses import dataclass
import dill

from contextlib import nullcontext
from math import inf
from os.path import abspath, dirname
from shutil import rmtree

from grammarinator.tool.default_population import DefaultTree
from grammarinator.runtime.rule import Rule

# Module-level logger; SynthFuzzGeneratorTool.generate uses it to warn when a
# start rule has no 'min_depth' attribute.
logger = logging.getLogger(__name__)

@dataclass(slots=True)
class NodePair:
    """
    Two same-named nodes aligned with each other while the recipient's
    context is matched against the donor's context in
    :meth:`SynthFuzzGeneratorTool.edit`.
    """

    concrete: Rule  # node taken from the recipient (concrete) tree
    abstract: Rule  # node taken from the donor (abstract) tree

class SynthFuzzGeneratorTool:
    """
    Class to create new test cases using the generator produced by ``grammarinator-process``.

    Besides generating from scratch, mutating and recombining, it supports an
    ``edit`` strategy (see :meth:`edit`). This is a recombination where the
    nodes of the inserted donor fragment that also occur in the donor's
    surrounding context are re-bound to nodes of the recipient's context.
    """

    def __init__(self, generator_factory, out_format, lock=None, rule=None, max_depth=inf,
                 population=None, generate=True, mutate=True, recombine=True, edit=True, keep_trees=False,
                 transformers=None, serializer=None,
                 cleanup=True, encoding='utf-8', errors='strict', edit_seed=None, edit_log=None):
        """
        :param generator_factory: A callable that can produce instances of a
            generator. It is a generalization of a generator class: it has to
            instantiate a generator object, and it may also set the decision
            model and the listeners of the generator as well. In the simplest
            case, it can be a ``grammarinator-process``-created subclass of
            :class:`~grammarinator.runtime.Generator`, but in more complex
            scenarios a factory can be used, e.g., an instance of
            :class:`DefaultGeneratorFactory`.
        :param str rule: Name of the rule to start generation from (default: the
            default rule of the generator).
        :param str out_format: Test output description. It can be a file path pattern possibly including the ``%d``
               placeholder which will be replaced by the index of the test case. Otherwise, it can be an empty string,
               which will result in printing the test case to the stdout (i.e., not saving to file system).
        :param multiprocessing.Lock lock: Lock object necessary when printing test cases in parallel (optional).
        :param int or float max_depth: Maximum recursion depth during generation (default: ``inf``).
        :param ~grammarinator.tool.Population population: Tree pool for mutation
            and recombination.
        :param bool generate: Enable generating new test cases from scratch, i.e., purely based on grammar.
        :param bool mutate: Enable mutating existing test cases, i.e., re-generate part of an existing test case based on grammar.
        :param bool recombine: Enable recombining existing test cases, i.e., replace part of a test case with a compatible part from another test case.
        :param bool edit: Enable editing existing test cases, i.e., recombine a fragment of one test case into another
               while re-binding the fragment's parameters to nodes found in the recipient's context (see :meth:`edit`).
        :param bool keep_trees: Keep generated trees to participate in further mutations or recombinations
               (otherwise, only the initial population will be mutated or recombined). It has effect only if
               population is defined.
        :param list transformers: List of transformers to be applied to postprocess
               the generated tree before serializing it.
        :param serializer: A serializer that takes a tree and produces a string from it (default: :class:`str`).
               See :func:`grammarinator.runtime.simple_space_serializer` for a simple solution that concatenates tokens with spaces.
        :param bool cleanup: Enable deleting the generated tests at :meth:`__exit__`.
        :param str encoding: Output file encoding.
        :param str errors: Encoding error handling scheme.
        :param edit_seed: Seed of the random number generator that :meth:`edit` uses to choose parameter values
               (default: ``None``, i.e., not reproducible).
        :param edit_log: Directory (``str`` or path-like) where a ``dill``-pickled description of every created
               test case is written as ``<index>.pkl``. It is created if missing. Logging is disabled if not set.
        """

        self._generator_factory = generator_factory
        self._transformers = transformers or []
        self._serializer = serializer or str
        self._rule = rule

        if out_format:
            os.makedirs(abspath(dirname(out_format)), exist_ok=True)

        self._out_format = out_format
        self._lock = lock or nullcontext()
        self._max_depth = max_depth
        self._population = population
        self._enable_generation = generate
        self._enable_mutation = mutate
        self._enable_recombination = recombine
        self._enable_edit = edit
        self._keep_trees = keep_trees
        self._cleanup = cleanup
        self._encoding = encoding
        self._errors = errors

        # Dedicated RNG so that edit choices can be reproduced via ``edit_seed``
        # independently of the strategy selection in :meth:`create`.
        self._edit_rand = random.Random(edit_seed)
        self._edit_log = edit_log
        if edit_log:
            # Otherwise the first log write in :meth:`create` fails.
            os.makedirs(edit_log, exist_ok=True)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Delete the output directory if the tests were saved to files and if ``cleanup`` was enabled.

        If ``out_format`` has no directory component (tests are written into the
        current working directory), nothing is deleted.
        """
        if not (self._cleanup and self._out_format):
            return
        out_dir = dirname(self._out_format)
        # rmtree('') would raise, and removing the working directory instead
        # would also wipe unrelated files, so skip that case.
        if out_dir:
            rmtree(out_dir)

    def create(self, index):
        """
        Create new test case with a randomly selected generator method from the available
        options (i.e., via :meth:`generate`, :meth:`mutate`, :meth:`recombine`, or :meth:`edit`).
        The generated tree is transformed, serialized and saved according to the parameters
        used to initialize the current generator object.

        If ``edit_log`` was set, a dict is written into ``<index>.pkl`` in that directory
        using ``dill``. The dict holds the chosen strategy and the nodes involved.

        :param int index: Index of the test case to be generated.
        :return: Path to the generated serialized test file, or the empty string if the
               test case was printed to the stdout.
        :rtype: str
        """
        creators = []
        if self._enable_generation:
            creators.append(("generate", self.generate))
        if self._population:
            if self._enable_mutation and self._population.can_mutate():
                creators.append(("mutate", lambda: self.mutate(self._population.select_to_mutate(self._max_depth))))
            if self._enable_recombination and self._population.can_recombine():
                creators.append(("recombine", lambda: self.recombine(*self._population.select_to_recombine(self._max_depth))))
            if self._enable_edit and self._population.can_recombine():
                creators.append(("edit", lambda: self.edit(*self._population.select_to_recombine(self._max_depth))))
        # NOTE: raises IndexError if no strategy is enabled or applicable.
        strategy, creator = random.choice(creators)

        # Describe how the test case was created (only written if edit_log is set).
        info = {"strategy": strategy}
        if strategy == "edit":
            root, info["donor"], info["recipient"], info["substitutions"] = creator()
        elif strategy == "recombine":
            root, info["donor"], info["recipient"] = creator()
        elif strategy == "mutate":
            root, info["mutated"], info["original"] = creator()
        else:
            root = creator()

        for transformer in self._transformers:
            root = transformer(root)

        test = self._serializer(root)
        test_fn = self._out_format % index if '%d' in self._out_format else self._out_format

        if self._edit_log:
            # os.path.join accepts both str and path-like directories.
            with open(os.path.join(self._edit_log, f'{index}.pkl'), 'wb') as f:
                dill.dump(info, f)

        if self._population and self._keep_trees:
            self._population.add_individual(root, path=test_fn)

        if test_fn:
            with codecs.open(test_fn, 'w', self._encoding, self._errors) as f:
                f.write(test)
        else:
            with self._lock:
                print(test)

        return test_fn

    def generate(self, *, rule=None, max_depth=None):
        """
        Instantiate a new generator and generate a new tree from scratch.

        :param str rule: Name of the rule to start generation from.
        :param int max_depth: Maximum recursion depth during generation.
        :return: The root of the generated tree.
        :rtype: Rule
        :raises ValueError: If the start rule's ``min_depth`` exceeds ``max_depth``.
        """
        max_depth = max_depth if max_depth is not None else self._max_depth
        generator = self._generator_factory(max_depth=max_depth)

        rule = rule or self._rule or generator._default_rule.__name__
        start_rule = getattr(generator, rule)
        if not hasattr(start_rule, 'min_depth'):
            logger.warning('The \'min_depth\' property of %s is not set. Fallback to 0.', rule)
        elif start_rule.min_depth > max_depth:
            raise ValueError(f'{rule} cannot be generated within the given depth: {max_depth} (min needed: {start_rule.min_depth}).')

        return start_rule()

    def mutate(self, mutated_node):
        """
        Mutate a tree at a given position, i.e., discard and re-generate its
        sub-tree at the specified node.

        :param Rule mutated_node: The root of the sub-tree that should be
            re-generated.
        :return: The root of the mutated tree, the newly generated sub-tree (now
            in place of ``mutated_node``), and a deep copy of the original
            sub-tree taken before the mutation.
        :rtype: tuple[Rule, Rule, Rule]
        """
        original_node = deepcopy(mutated_node)
        # The node's depth limits the depth left for the re-generated sub-tree.
        level = self._depth(mutated_node)
        new_node = mutated_node.replace(self.generate(rule=mutated_node.name, max_depth=self._max_depth - level))
        return self._root_of(new_node), new_node, original_node

    def recombine(self, recipient_node, donor_node):
        """
        Recombine two trees at given positions where the nodes are compatible
        with each other (i.e., they share the same node name). One of the trees
        is called the recipient while the other is the donor. The sub-tree
        rooted at the specified node of the recipient is discarded and replaced
        by the sub-tree rooted at the specified node of the donor.

        :param Rule recipient_node: The root of the sub-tree in the recipient.
        :param Rule donor_node: The root of the sub-tree in the donor.
        :return: The root of the recombined tree, plus deep copies of the donor
            and the recipient sub-trees taken before the recombination.
        :rtype: tuple[Rule, Rule, Rule]
        :raises ValueError: If the two nodes have different names.
        """
        if recipient_node.name != donor_node.name:
            raise ValueError(f'{recipient_node.name} cannot be replaced with {donor_node.name}')

        original_donor = deepcopy(donor_node)
        original_recipient = deepcopy(recipient_node)
        return self._root_of(recipient_node.replace(donor_node)), original_donor, original_recipient

    def edit(self, recipient_node, donor_node):
        """
        Recombine two trees like :meth:`recombine`, but first adapt the donor
        fragment (the sub-tree under ``donor_node``) to the recipient's context.

        The method works in four steps:

        - It locates the fragment's *parameters*: nodes of the donor context (the
          donor tree outside the fragment) whose name and serialized text also
          occur inside the fragment.
        - Along the chain of common, same-named ancestors of the two nodes, it
          aligns the donor's and the recipient's siblings. The alignment is
          greedy, in order and by node name, and it descends into every aligned
          non-parameter pair.
        - Each parameter aligned with a recipient node gets a value chosen by
          the edit RNG.
        - A deep copy of that value replaces the fragment nodes with the same
          name and text.

        :param Rule recipient_node: The root of the sub-tree in the recipient.
        :param Rule donor_node: The root of the sub-tree in the donor.
        :return: The root of the edited tree; deep copies of the donor and the
            recipient sub-trees taken before the edit; and a dict that maps
            each substituted donor-context node to the chosen recipient node.
        :rtype: tuple[Rule, Rule, Rule, dict]
        :raises ValueError: If the two nodes have different names.
        """
        if recipient_node.name != donor_node.name:
            raise ValueError(f'{recipient_node.name} cannot be replaced with {donor_node.name}')

        substitutions = {}
        # A donor without children has nothing that could be adapted.
        if not donor_node.children:
            return (*self.recombine(recipient_node, donor_node), substitutions)

        # Snapshot before the fragment is modified by the substitutions below.
        original_donor = deepcopy(donor_node)
        original_recipient = deepcopy(recipient_node)

        parameters = self._locate_parameters(donor_node)
        parameter_values = self._collect_parameter_values(recipient_node, donor_node, parameters)

        for abstract_node, candidates in parameter_values.items():
            value = self._edit_rand.choice(candidates)
            substitutions[abstract_node] = value
            for param_node in parameters[abstract_node]:
                # Insert a copy. The chosen node stays in the recipient's context,
                # and sharing the object would give it several tree positions
                # but only one parent link.
                param_node.replace(deepcopy(value))

        root = self._root_of(recipient_node.replace(donor_node))
        return root, original_donor, original_recipient, substitutions

    @staticmethod
    def _root_of(node):
        """Return the root of the tree containing ``node`` by following parent links."""
        while node.parent:
            node = node.parent
        return node

    @staticmethod
    def _depth(node):
        """Return the number of parent links between ``node`` and its root."""
        depth = 0
        while node.parent:
            node = node.parent
            depth += 1
        return depth

    @staticmethod
    def _index_nodes(roots, exclude=None):
        """
        Group the nodes of the sub-trees rooted at ``roots`` by name.

        Nodes are visited in pre-order and the sub-tree rooted at ``exclude``
        (compared by identity) is skipped.

        :return: Dict mapping node names to lists of nodes in visiting order.
        """
        nodes_by_name = {}
        # Explicit stack instead of recursion, because trees may be deeper than
        # the recursion limit. Reversed pushes keep the pre-order.
        pending = list(reversed(roots))
        while pending:
            node = pending.pop()
            if node is exclude:
                continue
            nodes_by_name.setdefault(node.name, []).append(node)
            if node.children:
                pending.extend(reversed(node.children))
        return nodes_by_name

    @classmethod
    def _locate_parameters(cls, donor_node):
        """
        Find the donor-context nodes that reappear inside the donor fragment.

        A context node is a parameter if some fragment node has the same name
        and the same ``str()`` text. Context nodes with equal name and text share
        one list of fragment nodes.

        :return: Dict mapping each parameter (a donor-context node) to the list
            of fragment nodes it corresponds to.
        """
        fragment_nodes = cls._index_nodes(donor_node.children)
        context_nodes = cls._index_nodes([cls._root_of(donor_node)], exclude=donor_node)

        parameters = {}
        for name in fragment_nodes.keys() & context_nodes.keys():
            fragment_by_text = {}
            for node in fragment_nodes[name]:
                fragment_by_text.setdefault(str(node), []).append(node)
            for node in context_nodes[name]:
                matches = fragment_by_text.get(str(node))
                if matches:
                    parameters[node] = matches
        return parameters

    @classmethod
    def _collect_parameter_values(cls, recipient_node, donor_node, parameters):
        """
        Align the recipient's context with the donor's context and collect
        candidate values for the parameters.

        The walk goes up while both nodes have parents with the same name. At
        each level, the left siblings of the two nodes are aligned, then their
        right siblings.

        :return: Dict mapping parameters to the lists of recipient nodes aligned
            with them, in discovery order.
        """
        parameter_values = {}
        concrete, abstract = recipient_node, donor_node
        while concrete.parent and abstract.parent and concrete.parent.name == abstract.parent.name:
            concrete_left, concrete_right = cls._split_siblings(concrete)
            abstract_left, abstract_right = cls._split_siblings(abstract)
            cls._match_subtrees(abstract_left, concrete_left, parameters, parameter_values)
            cls._match_subtrees(abstract_right, concrete_right, parameters, parameter_values)
            concrete, abstract = concrete.parent, abstract.parent
        return parameter_values

    @staticmethod
    def _split_siblings(node):
        """Return the siblings of ``node`` (which must have a parent) to its left and to its right."""
        siblings = node.parent.children
        idx = siblings.index(node)
        return siblings[:idx], siblings[idx + 1:]

    @staticmethod
    def _match_siblings(abstract_nodes, concrete_nodes, parameters, parameter_values):
        """
        Greedily align two sibling lists by node name, preserving their order.

        Each abstract node is paired with the first same-named concrete node
        after the previous match. Each concrete node is used at most once. An
        abstract node without a match leaves the cursor where it was.
        Parameter matches are appended to ``parameter_values``. Other matches
        are returned as :class:`NodePair` objects so that their children can be
        aligned too.
        """
        pairs = []
        start = 0
        for a_node in abstract_nodes:
            for c_idx in range(start, len(concrete_nodes)):
                c_node = concrete_nodes[c_idx]
                if a_node.name != c_node.name:
                    continue
                if a_node in parameters:
                    parameter_values.setdefault(a_node, []).append(c_node)
                else:
                    pairs.append(NodePair(concrete=c_node, abstract=a_node))
                start = c_idx + 1
                break
        return pairs

    @classmethod
    def _match_subtrees(cls, abstract_nodes, concrete_nodes, parameters, parameter_values):
        """
        Align two sibling lists, then, depth-first, the children of every
        aligned non-parameter pair, collecting parameter values on the way.
        """
        # Explicit stack instead of recursion, because trees may be deeper than
        # the recursion limit. Reversed pushes keep the recursive visiting
        # order, so the edit RNG is consumed in the same order.
        pending = [(abstract_nodes, concrete_nodes)]
        while pending:
            abstract_level, concrete_level = pending.pop()
            pairs = cls._match_siblings(abstract_level, concrete_level, parameters, parameter_values)
            for pair in reversed(pairs):
                if pair.abstract.children and pair.concrete.children:
                    pending.append((pair.abstract.children, pair.concrete.children))