#!/usr/bin/env python
from __future__ import print_function
import argparse
import os
import sys
import re

# Regex for one "begin_variable ... end_variable" section of a SAS file
# (non-greedy; "." also matches newlines so the section may span lines).
PATTERN_VARIABLE = re.compile(r"begin_variable.*?end_variable", flags=re.S)
# Regex for the "begin_state ... end_state" section of a SAS file.
PATTERN_STATE = re.compile(r"begin_state.*?end_state", flags=re.S)
# Prefix of a (non-negated) atom name in the SAS file.
SAS_ATOM_PREFIX = "Atom "


def type_is_file(arg):
    """
    argparse 'type' callback which accepts only paths to existing files.

    :param arg: command line value
    :return: arg unchanged if it is an existing regular file
    :raises argparse.ArgumentTypeError: otherwise (argparse reports it)
    """
    if os.path.isfile(arg):
        return arg
    raise argparse.ArgumentTypeError("Argument is not a file: {}".format(arg))


# Command line interface: two positional arguments, both of which must be
# paths to existing files (validated by type_is_file).
parser = argparse.ArgumentParser()
parser.add_argument(
    "original_pddl",
    help="Path to the original PDDL instance file.",
    type=type_is_file,
)
parser.add_argument(
    "sampled_sas",
    help="Path to the sampled SAS task.",
    type=type_is_file,
)


def natural_sort(l):
    """
    Sort strings in natural order: text is compared case-insensitively and
    runs of ASCII digits are compared by their numeric value, so "a2" comes
    before "a10".

    :param l: iterable of strings
    :return: new sorted list
    """
    def alphanum_key(key):
        # re.split with a capturing group alternates between non-digit text
        # (even positions) and [0-9]+ runs (odd positions). Converting exactly
        # the odd positions keeps int vs int and str vs str at every position.
        # str.isdigit() would also accept non-ASCII digits such as "²", which
        # int() rejects or which would end up compared against a str.
        chunks = re.split(r"([0-9]+)", key)
        return [int(chunk) if position % 2 else chunk.lower()
                for position, chunk in enumerate(chunks)]

    return sorted(l, key=alphanum_key)


def get_block(content, keys, start=0, inc_level=("(",), dec_level=(")",),
              return_indices=False):
    """
    Extract a bracket-delimited text block from a string.

    The block starts at the first occurrence (at or after 'start') of one of
    'keys'; exactly one of the keys must occur in the searched range. The key
    itself is taken to open the first nesting level (e.g. "(:init"). The block
    ends with the character that brings the nesting level back to zero.

    :param content: string in which the block is searched
    :param keys: candidate strings marking the start of the block; exactly
                 one of them must be found
    :param start: position in content to start searching
    :param inc_level: characters which increase the nesting level
    :param dec_level: characters which decrease the nesting level
    :param return_indices: if True, also return the block's position
    :return: the block (including the key) if not return_indices, otherwise
             Tuple(block, Tuple(index of the first char of the block,
             index one past the last char of the block)), so that
             content[idx_start:idx_end] == block
    :raises AssertionError: if not exactly one key is found, or if the block
                            is not closed before the end of content
    """
    found = [(content.find(key, start), key) for key in keys]
    found = [x for x in found if x[0] != -1]
    assert len(found) == 1, str(found)
    idx_start, used_key = found[0]

    # Scan forward from behind the key until the nesting returns to zero.
    idx_end = idx_start + len(used_key)
    level = 1  # The key opened the first nesting level.
    while idx_end < len(content):
        char = content[idx_end]
        idx_end += 1
        if char in inc_level:
            level += 1
        elif char in dec_level:
            level -= 1
            if level == 0:
                break

    assert level == 0, "Block starting at {} is not closed".format(idx_start)
    block = content[idx_start:idx_end]
    return (block, (idx_start, idx_end)) if return_indices else block


def convert_sas2pddl_atoms(atom):
    """
    Translate one SAS fact name into PDDL syntax, e.g.
    "Atom NAME(param1, param2, ...)" -> "(NAME param1 param2 ...)".

    :param atom: fact name as written in the SAS file
    :return: the atom in PDDL syntax, or None for "<none of those>" and for
             negated atoms
    """
    is_placeholder = atom == "<none of those>"
    if is_placeholder or atom.startswith("NegatedAtom"):
        return None
    assert atom.startswith(SAS_ATOM_PREFIX), "Err: %s" % atom

    # Drop the prefix and the argument separators: "at(a, b)" -> "at(a b)"
    body = atom[len(SAS_ATOM_PREFIX):].replace(",", "")
    bracket = body.find("(")
    name, arguments = body[:bracket], body[bracket + 1:]
    # A nullary atom "name()" gets no space before the closing bracket.
    separator = "" if arguments.strip() == ")" else " "
    return "(" + name + separator + arguments


def get_sas_initial_facts(sas, sas_variables):
    """
    Collect the facts of the initial state of a SAS task.

    :param sas: content of a SAS file
    :param sas_variables: per variable, the list of its values converted to
                          PDDL (None for values without a PDDL atom)
    :return: naturally sorted list [fact1, fact2, ...]
    """
    states = PATTERN_STATE.findall(sas)
    assert len(states) == 1, states
    # Between the begin/end marker lines there is one value index per
    # variable, in variable order.
    values = [int(line) for line in states[0].split("\n")[1:-1]]
    facts = []
    for var_index, value in enumerate(values):
        fact = sas_variables[var_index][value]
        if fact is not None:
            facts.append(fact)
    return natural_sort(facts)


def _normalize_pddl_fact(fact):
    """
    Return the fact with canonical whitespace: every whitespace run becomes a
    single space and there is no space right after "(" or before ")", e.g.
    "( at  a\\n b )" -> "(at a b)". This is the shape convert_sas2pddl_atoms
    produces for SAS atoms like "Atom at(a, b)", so facts from both sources
    can be compared as plain strings.
    """
    fact = " ".join(fact.split())
    return fact.replace("( ", "(").replace(" )", ")")


def get_pddl_initial_facts(pddl_init):
    """
    Extract the facts of a PDDL initial state block.

    :param pddl_init: the complete "(:init ...)" block
    :return: [fact1, fact2, ...]: first the numeric assignments
             "(= (FUNCTION ...) NUMBER)", then the atoms, each with
             normalized whitespace (see _normalize_pddl_fact)
    :raises AssertionError: if the block is malformed or contains anything
                            that is neither an atom nor a numeric assignment
    """
    assert pddl_init.lower().startswith("(:init")
    assert pddl_init.endswith(")")
    pddl_init = pddl_init[len("(:init"):-1]

    # Remove comments
    lines = []
    for line in pddl_init.splitlines():
        idx_comment = line.find(";")
        if idx_comment != -1:
            line = line[:idx_comment].strip()
        lines.append(line)
    pddl_init = "\n".join(lines)

    # Extract the numeric assignments (= (FUNCTION ...) NUMBER) first. The
    # generic atom pattern would otherwise match only their prefix
    # "(= (FUNCTION ...)". The value may be negative and/or have decimals.
    init_facts = []
    for regex in [r"\(\s*=\s*\([^)]+?\)\s*-?\d+(?:\.\d+)?\s*\)",
                  r"\([^)]+?\)"]:
        new_facts = re.findall(regex, pddl_init)
        # Remove the facts to verify at the end that nothing was overlooked
        for fact in new_facts:
            pddl_init = pddl_init.replace(fact, "")
        init_facts.extend(new_facts)
    pddl_init = pddl_init.strip()
    assert pddl_init == "", pddl_init  # Have we forgotten anything?

    # Normalize whitespace so that callers can compare the facts with SAS
    # atoms as strings (e.g. "( clear  c )" must equal "(clear c)").
    return [_normalize_pddl_fact(fact) for fact in init_facts]


def run(options):
    """
    Print the original PDDL problem with its initial state replaced by the
    initial state of the sampled SAS task.

    The new initial state contains the facts of the SAS initial state plus
    every fact of the original PDDL initial state that occurs in no SAS
    variable (e.g. static facts or numeric assignments).

    :param options: parsed arguments providing the file paths
                    'original_pddl' and 'sampled_sas'
    """
    # Read the files. The PDDL is lowercased so that its facts can be compared
    # with the SAS atoms (assumes the SAS task uses lowercase names - confirm).
    with open(options.original_pddl, "r") as pddl_file:
        pddl = pddl_file.read().lower()
    with open(options.sampled_sas, "r") as sas_file:
        sas = sas_file.read()

    # Extract initial state block from the PDDL file
    pddl_init, (init_start, init_end) = get_block(
        pddl, ["(:init", "(:INIT"], return_indices=True)
    pddl_init_facts = set(get_pddl_initial_facts(pddl_init))

    # Indentation of the line on which "(:init" starts. Use only the leading
    # whitespace: "(:init" need not be the first token on its line (e.g. a
    # one-line problem file), and the text before it must not be copied in
    # front of every fact.
    line_start = pddl.rfind("\n", 0, init_start) + 1
    line_prefix = pddl[line_start:init_start]
    init_indent = line_prefix[:len(line_prefix) - len(line_prefix.lstrip())]

    # [Variable1=[fact11, fact12, ...], Variable2, ...]
    # Every fact is on its own line. The first 4 lines of a variable block
    # contain other information, and the last line closes the block.
    sas_variables = [var.group().split("\n")[4:-1]
                     for var in PATTERN_VARIABLE.finditer(sas)]
    sas_variables = [[convert_sas2pddl_atoms(atom) for atom in variable]
                     for variable in sas_variables]
    sas_all_facts = set(fact for variable in sas_variables for fact in variable)
    # Facts which no SAS variable covers keep their original value
    static_facts = pddl_init_facts - sas_all_facts  # e.g. an unchanging road map

    sas_init_facts = get_sas_initial_facts(sas, sas_variables)

    # One fact per line, indented one level deeper than "(:init"; the closing
    # bracket is aligned with "(:init". An empty state yields "(:init\n)".
    fact_indent = init_indent + "    "
    new_facts = natural_sort(list(sas_init_facts) + list(static_facts))
    new_init = ("(:init\n" +
                "".join(fact_indent + fact + "\n" for fact in new_facts) +
                init_indent + ")")

    print(pddl[:init_start] + new_init + pddl[init_end:])


if __name__ == "__main__":
    # Parse everything after the script name, then rewrite the problem file.
    cli_options = parser.parse_args(sys.argv[1:])
    run(cli_options)
