#!/usr/bin/env python
from __future__ import print_function
import sys
sys.path.append("../../")

from src.training.misc import DomainProperties
from src.translate import pddl, pddl_parser

import argparse
import json
import math
import os
import re

class ExitCodes(object):
    """Process exit codes used by this script."""
    INPUT_ERROR = 1  # invalid command line input


def exit_with_error(msg, code=ExitCodes.INPUT_ERROR):
    """Print *msg* to stderr and terminate the process with exit status *code*.

    :param msg: error message to report
    :param code: process exit status; defaults to ExitCodes.INPUT_ERROR so
                 that callers reporting invalid input may omit it
    :raises SystemExit: always
    """
    print(msg, file=sys.stderr)
    sys.exit(code)


# Command line interface. Values are validated and converted in parse_args().
parser = argparse.ArgumentParser()
parser.add_argument("directory", type=str, action="store", nargs="*",
                    help="Any number (at least one) of directories; all "
                         "domains below them are analysed.")
parser.add_argument("-p", "--pattern", type=str, action="append", default=[],
                    help="Define a regular expression which every domain "
                         "path has to match (matched from the start of the "
                         "path). May be given multiple times.")
parser.add_argument("--anti-pattern", type=str, action="append", default=[],
                    help="Define a regular expression which domain paths may "
                         "NOT match (matched from the start of the path). "
                         "May be given multiple times.")
parser.add_argument("--data", type=str, action="append", nargs=3, default=[],
                    help="[NAME] [COUNT] [NUMBER BIT SIZE]\n"
                         "Define a data storage for which the required "
                         "capacity shall be calculated.")
parser.add_argument("--network", type=str, action="append", nargs="+",
                    default=[],
                    help="[NAME] [TYPE] [ARGS FOR TYPE]*\n"
                         "Network description for which its minimum "
                         "required memory is calculated.")
parser.add_argument("--store", type=str, action="store", default=None,
                    help="Path to store the output as json")

# Network type name (as given on the command line) -> Network subclass.
networks = {}


class Network(object):
    """Base class of a network description whose memory demand is computed.

    Subclasses are registered in ``networks`` under the type name that is
    expected on the command line (``--network NAME TYPE ARGS...``).
    """

    @staticmethod
    def prepare_args(args):
        """Create an instance from the type's command line arguments (strings)."""
        raise NotImplementedError

    def compute(self, input_units):
        """Return the memory (in bits) required for ``input_units`` inputs."""
        raise NotImplementedError


class DynMLP(Network):
    """MLP whose hidden layer widths move in equal integer steps from the
    input width towards the output width."""

    def __init__(self, hidden_layers, output_units, bias=1, bits=64):
        self.hidden_layers = hidden_layers  # number of hidden layers
        self.output_units = output_units    # width of the output layer
        self.bias = bias                    # 1: one extra bias weight per unit, 0: none
        self.bits = bits                    # bits used to store one weight

    @staticmethod
    def prepare_args(args):
        """Build a DynMLP from its command line arguments.

        Expected (all non-negative integers, BIAS only 0 or 1):
        [hidden layers] [output units] [0/1 BIAS]? [BITS]?

        Invalid input terminates the process via exit_with_error.
        """
        # Fixed: the original used ``and``, which can never be true.
        if len(args) < 2 or len(args) > 4:
            exit_with_error("The DynMLP requires as arguments [hidden layers] "
                            "[output units] [0/1 BIAS]? [BITS]?: %s"
                            % ", ".join(args), ExitCodes.INPUT_ERROR)

        arg_names = ["[hidden layers]", "[output units]", "[0/1 BIAS]?",
                     "[BITS]"]
        converted = []
        for v, name in zip(args, arg_names):
            try:
                converted.append(int(v))
            except ValueError:
                exit_with_error("%s has to be an int: %s" % (name, v),
                                ExitCodes.INPUT_ERROR)
            if converted[-1] < 0:
                exit_with_error("%s has to be at least 0: %i" %
                                (name, converted[-1]), ExitCodes.INPUT_ERROR)
        if len(converted) >= 3 and converted[2] > 1:
            exit_with_error("[0-1 BIAS]? has to be 0 or 1",
                            ExitCodes.INPUT_ERROR)

        return DynMLP(*converted)

    def compute(self, input_units):
        """Return the memory in bits needed for the weights of this MLP.

        The step is |input_units - output_units| / hidden_layers, truncated
        towards zero and signed towards the output width. Each hidden layer
        is one step wider or narrower than its predecessor.

        Every layer, including the output layer, contributes
        ``(previous width + bias) * width`` weights. The total weight count
        is multiplied by ``self.bits``.
        """
        if self.hidden_layers > 0:
            diff = math.fabs(input_units - self.output_units)
            direction = -1 if input_units > self.output_units else 1
            step = int((diff / self.hidden_layers) * direction)
        else:
            # No hidden layers: the input is connected directly to the output
            # (the original divided by zero here).
            step = 0

        parameters = 0
        for _ in range(self.hidden_layers):
            next_input_units = input_units + step
            parameters += (input_units + self.bias) * next_input_units
            input_units = next_input_units
        parameters += (input_units + self.bias) * self.output_units
        return parameters * self.bits


networks["DynMLP"] = DynMLP


def _compile_pattern(pattern):
    """Compile the user supplied regular expression *pattern*.

    Exits via exit_with_error instead of a traceback if it is invalid.
    """
    try:
        return re.compile(pattern)
    except re.error as e:
        exit_with_error("Invalid regular expression '%s': %s" % (pattern, e),
                        ExitCodes.INPUT_ERROR)


def parse_args(argv):
    """Parse and validate the command line arguments.

    Besides plain argparse parsing this:
    * requires at least one directory and checks that all given ones exist
      and are directories,
    * compiles --pattern / --anti-pattern into regular expression objects,
    * rejects names used by more than one --data / --network definition,
    * converts COUNT and NUMBER BIT SIZE of every --data entry to int (>= 1),
    * replaces every --network entry by (NAME, TYPE, network object), built
      by ``prepare_args`` of the type registered in ``networks``.

    Invalid input terminates the process via exit_with_error.

    :param argv: command line arguments without the program name
    :return: the validated argparse namespace
    """
    options = parser.parse_args(argv)

    # The help promises "at least one" directory, but nargs="*" accepts none.
    if not options.directory:
        exit_with_error("At least one root directory has to be given.",
                        ExitCodes.INPUT_ERROR)
    for d in options.directory:
        if not os.path.exists(d):
            exit_with_error("A given root directory does not exist: %s" % d,
                            ExitCodes.INPUT_ERROR)
        elif not os.path.isdir(d):
            exit_with_error("A given root directory is not a directory: %s" % d,
                            ExitCodes.INPUT_ERROR)

    options.pattern = [_compile_pattern(p) for p in options.pattern]
    options.anti_pattern = [_compile_pattern(p) for p in options.anti_pattern]

    # --data and --network definitions share one name space.
    names = set()
    for item in options.data + options.network:
        if item[0] in names:
            exit_with_error("Cannot reuse a definition name: %s" % item[0],
                            ExitCodes.INPUT_ERROR)
        names.add(item[0])

    for item in options.data:
        for i, name in [(1, "Data set"), (2, "Bit")]:
            try:
                item[i] = int(item[i])
            except ValueError:
                exit_with_error("%s size has to be an integer: %s" %
                                (name, item[i]), ExitCodes.INPUT_ERROR)
            if item[i] < 1:
                exit_with_error("%s size has to be at least 1: %i" %
                                (name, item[i]), ExitCodes.INPUT_ERROR)

    for no, item in enumerate(options.network):
        if len(item) <= 2:
            exit_with_error("Every network needs a name, a type and the "
                            "type's arguments [NAME] [TYPE] "
                            "[ARGS FOR TYPE]+: %s" % " ".join(item),
                            ExitCodes.INPUT_ERROR)

        if item[1] not in networks:
            exit_with_error("Unknown network type: %s" % item[1],
                            ExitCodes.INPUT_ERROR)
        network = networks[item[1]].prepare_args(item[2:])
        options.network[no] = (item[0], item[1], network)

    # NOTE(review): prints the parsed namespace; looks like leftover debug output.
    print(options)
    return options


def check_valid_domain(options, dir_data):
    """Decide whether an os.walk entry is a domain directory to analyse.

    :param options: parsed options providing the compiled ``pattern`` and
                    ``anti_pattern`` lists
    :param dir_data: os.walk entry (dirpath, dirnames, filenames)
    :return: True iff the directory contains ``domain.pddl``, every pattern
             matches its path (``re.match``, anchored at the start) and no
             anti-pattern does
    """
    if "domain.pddl" not in dir_data[2]:
        return False
    dir_path = dir_data[0]
    return (all(regex.match(dir_path) for regex in options.pattern) and
            not any(regex.match(dir_path) for regex in options.anti_pattern))

def get_domains(options):
    """Collect all domain directories below the root directories in *options*.

    Every directory visited by os.walk below one of ``options.directory`` is
    kept if check_valid_domain accepts it.

    :return: sorted list of unique directory paths
    """
    found = set()
    for root in options.directory:
        for walk_entry in os.walk(root):
            if check_valid_domain(options, walk_entry):
                found.add(walk_entry[0])
    return sorted(found)

def get_nb_possible_objects(slot, type_counts, inv_type_hierarchy):
    """Count the objects that can fill a parameter of type *slot*.

    Objects of *slot* itself and of every type reachable from it through
    *inv_type_hierarchy* are summed up.

    :param slot: type name of the parameter
    :param type_counts: type name -> number of objects of exactly that type
                        (types without entry count as 0)
    :param inv_type_hierarchy: type name -> list of type names below it
                               (presumably the direct subtypes)
    :return: total number of matching objects
    """
    pending = [slot]
    total = 0
    while pending:
        current_type = pending.pop()
        total += type_counts.get(current_type, 0)
        pending.extend(inv_type_hierarchy.get(current_type, []))
    return total


def get_domain_sizes(domain):
    """Compute two size measures for the domain in directory *domain*.

    Domain properties are obtained via DomainProperties.get_property_for with
    ``no_gnd_static=True``. The file ``domain_properties_no_statics.json``
    inside the domain directory is always passed as ``load``. It is passed as
    ``store`` only if it does not exist yet (presumably a cache; its handling
    is not visible here).

    :param domain: path of the domain directory (containing ``domain.pddl``)
    :return: tuple (full, non_constant):
             full -- sum over ``dp.real_predicates`` of the product, over
                     their arguments, of the number of fixed objects of the
                     argument type or its subtypes;
             non_constant -- ``len(dp.gnd_flexible)``
    :raises AssertionError: if the properties do not report a fixed world
    """
    cache_file = os.path.join(domain, "domain_properties_no_statics.json")
    cache_missing = not os.path.exists(cache_file)
    dp = DomainProperties.get_property_for(
        domain, no_gnd_static=True, verbose=6, parallize=True,
        load=cache_file,
        store=cache_file if cache_missing else None)
    assert dp.fixed_world

    # Number of fixed objects per exact type name.
    objs = {}
    for typed_obj in dp.fixed_objects:
        objs[typed_obj.type_name] = objs.get(typed_obj.type_name, 0) + 1

    # Parse the domain file; only its type hierarchy is used below.
    (domain_name, domain_requirements, types, _type_dict, constants,
     predicates, _predicate_dict, functions, actions, axioms) = \
        pddl_parser.parsing_functions.parse_domain_pddl(
            pddl_parser.pddl_file.parse_pddl_file(
                "domain", os.path.join(domain, "domain.pddl")))
    pddl_domain = pddl.tasks.Domain(
        domain_name, domain_requirements, types, constants, predicates,
        functions, actions, axioms)

    full = 0
    for predicate in dp.real_predicates:
        groundings = 1
        for raw_arg in predicate.args:
            typed_arg = pddl.TypedObject.from_string(raw_arg)
            groundings *= get_nb_possible_objects(
                typed_arg.type_name, objs, pddl_domain.inv_type_hierarchy)
        full += groundings

    return full, len(dp.gnd_flexible)


def get_domains_sizes(domains):
    """Compute the sizes of all *domains* and nest them by directory path.

    The longest common parent directory of all domains becomes the single
    top-level key (without trailing separator; ``""`` for the file system
    root or when the paths share no directory). Below it there is one nested
    dict per remaining path component. The innermost dict holds ``"full"``
    and ``"non_constant"``, the two values returned by get_domain_sizes.

    :param domains: iterable of domain directory paths
    :return: nested dict as described above; ``{}`` if *domains* is empty
    """
    # Strip trailing separators up front. Otherwise a single domain "bench/a/"
    # (or "bench/" next to "bench/a") is taken as its own parent directory.
    domains = [d.rstrip(os.sep) or d for d in domains]
    if not domains:
        return {}

    # commonprefix works character-wise; cut back to a directory boundary
    # unless the prefix already ends on one.
    prefix = os.path.commonprefix(domains)
    if not prefix.endswith(os.sep):
        prefix = os.path.dirname(prefix)
    root_key = prefix[:-1] if prefix.endswith(os.sep) else prefix
    data = {root_key: {}}

    for domain in domains:
        full, small = get_domain_sizes(domain)
        # An empty prefix means there is no common directory: use the whole path.
        relative = domain[len(prefix):].lstrip(os.sep) if prefix else domain

        node = data[root_key]
        for dirname in relative.split(os.sep):
            node = node.setdefault(dirname, {})
        node["full"] = full
        node["non_constant"] = small

    return data


UNITS = ["B", "KB", "MB", "GB", "TB"]


def conv_units(size):
    """Format a memory size given in bits as a human readable string.

    The value is converted to bytes and then divided by 1024 until it is below
    1024 or the largest unit (TB) is reached. It is formatted with two
    decimals, e.g. ``conv_units(8 * 1536) == "1.50KB"``.

    :param size: memory size in bits
    :return: formatted size with unit suffix from UNITS
    """
    # Float division, so that Python 2 (int / int truncates) gives the same result.
    size = size / 8.0
    for unit in UNITS[:-1]:
        if size < 1024:
            return "%.2f%s" % (size, unit)
        size /= 1024.0
    return "%.2f%s" % (size, UNITS[-1])


def get_memory(options, size):
    """Compute the memory demand of every --data and --network definition.

    :param options: parsed options. ``options.data`` holds [NAME, COUNT, BIT SIZE]
                    entries; ``options.network`` holds (NAME, TYPE, network) tuples.
    :param size: the size value being evaluated; used as multiplier for data
                 storages and as ``input_units`` for networks
    :return: dict NAME -> formatted memory string (see conv_units). Network
             entries overwrite data entries of the same name.
    """
    memory = dict((entry[0], conv_units(size * entry[2] * entry[1]))
                  for entry in options.data)
    for name, _network_type, network in options.network:
        memory[name] = conv_units(network.compute(size))
    return memory


def add_memories(options, sizes):
    """Replace every int leaf of the nested dict *sizes* in place.

    Each int is replaced by the dict returned by get_memory for it, extended
    by the original value under ``"size"``. Nested dicts are processed
    recursively.

    :raises AssertionError: on a leaf that is neither dict nor int
    """
    for key in sizes:
        value = sizes[key]
        if isinstance(value, dict):
            add_memories(options, value)
        elif not isinstance(value, int):
            assert False, str(value)
        else:
            memory = get_memory(options, value)
            memory["size"] = value
            sizes[key] = memory



def disp(sizes, current="", indent="    "):
    """Print the nested dict *sizes* as an indented tree, keys sorted per level.

    A dict value opens a ``key:`` section whose content is indented by
    *indent* more. Any other value is printed as ``key:<TAB>value``.
    """
    for key in sorted(sizes):
        value = sizes[key]
        if not isinstance(value, dict):
            print(current + key + ":\t" + str(value))
            continue
        print(current + key + ":")
        disp(value, current + indent, indent)


def run(argv):
    """Command line entry point: analyse all matching domains and report them.

    Parses *argv*, computes the domain sizes and their memory demands, and
    writes them as JSON to ``--store`` if given. Finally prints them twice:
    as an indented tree and as the raw dict.
    """
    options = parse_args(argv)
    sizes = get_domains_sizes(get_domains(options))
    add_memories(options, sizes)

    store_path = options.store
    if store_path is not None:
        with open(store_path, "w") as out:
            json.dump(sizes, out, sort_keys=True, indent=4)

    disp(sizes)
    print(sizes)


if __name__ == "__main__":
    # Pass everything except the program name on to the CLI entry point.
    cli_args = sys.argv[1:]
    run(cli_args)

