#!/usr/bin/env python3
from __future__ import print_function
import sys
sys.path.append("../../")

import argparse
import json
import os
import re
import subprocess

class ExitCodes:
    """Exit statuses this script terminates with on failure."""

    # Used when a command-line argument is unusable, e.g. a root directory
    # that does not exist or is not a directory.
    INPUT_ERROR = 1

def exit_with_error(msg, code):
    """Report *msg* on standard error and terminate with exit status *code*.

    :param msg: message to print (followed by a newline).
    :param code: exit status handed to the interpreter.
    :raises SystemExit: always.
    """
    print(msg, file=sys.stderr)
    raise SystemExit(code)

# Regular expressions that extract the statistics from the text printed by the
# called script. They are written as raw strings so that the regex escapes
# are handed to `re` unchanged instead of being (invalid) string escapes.

# Captures the integer in "Data sizes: <int>,".
PATTERN_COUNT = re.compile(r"Data sizes: (\d+),")
# Captures the number in "Data set memory consumption: <number>MB".
PATTERN_MEMORY_DATA = re.compile(
    r"Data set memory consumption: (\d+\.?\d*)MB")
# Captures the number of the "memory: <number>MB" line that directly follows
# the "Network initialization time: <number>s" line.
PATTERN_MEMORY_NETWORK = re.compile(r"Network initialization time: "
                                    r"\d+\.?\d*s\n\s+memory: (\d+\.?\d*)MB",
                                    re.DOTALL)

# Command-line interface. The "--call" token and everything after it are cut
# off by parse_args() before argparse sees the arguments, which is why
# "--call" is only described in the text below and not declared as an option.
parser = argparse.ArgumentParser(
    description="Use as final argument --call and write there the call "
                "arguments for ../../fast-training.py. Use {PATH} for the "
                "argument describing the domain directory.")
# nargs="+" enforces the "at least one" promised by the help text.
parser.add_argument("directory", type=str, action="store", nargs="+",
                    help="Any number (at least one) of directories where all "
                         "domains in there shall be analysed.")
parser.add_argument("-p", "--pattern", type=str, action="append", default=[],
                    help="Define a regular expression which has to be matched "
                         "by any domain.")
parser.add_argument("--anti-pattern", type=str, action="append", default=[],
                    help="Define a regular expression which may NOT match "
                         "domains.")
parser.add_argument("--store", type=str, action="store", default=None,
                    help="Path to store the output as json")


def parse_args(argv):
    """Parse and validate the command-line arguments.

    Everything following the first ``--call`` token is taken verbatim as the
    arguments of the script to run and is hidden from argparse; the arguments
    in front of it are parsed with the module-level ``parser``.

    :param argv: list of argument strings (without the program name).
    :return: the argparse namespace, extended by ``call`` (the command as a
             list starting with "../../fast-training.py") and with
             ``pattern`` / ``anti_pattern`` holding compiled regular
             expressions instead of strings.

    Terminates through exit_with_error with ExitCodes.INPUT_ERROR if
    ``--call`` is missing, if a given directory does not exist or is not a
    directory, or if a pattern is not a valid regular expression.
    """
    # Split at the FIRST "--call" only: later occurrences belong to the
    # arguments of the called script and must be passed through untouched.
    try:
        split_at = argv.index("--call")
    except ValueError:
        call = None
    else:
        call = ["../../fast-training.py"] + argv[split_at + 1:]
        argv = argv[:split_at]

    options = parser.parse_args(argv)

    # Checked after argparse ran so that "-h" keeps working without --call.
    if call is None:
        exit_with_error("No call given: add --call followed by the arguments "
                        "for the called script",
                        ExitCodes.INPUT_ERROR)
    options.call = call

    for d in options.directory:
        if not os.path.exists(d):
            exit_with_error("A given root directory does not exist: %s" % d,
                            ExitCodes.INPUT_ERROR)
        elif not os.path.isdir(d):
            exit_with_error("A given root directory is not a directory: %s" % d,
                            ExitCodes.INPUT_ERROR)

    # Replace the pattern strings by compiled expressions; a malformed
    # expression is reported like any other invalid input.
    for attribute in ("pattern", "anti_pattern"):
        compiled = []
        for expression in getattr(options, attribute):
            try:
                compiled.append(re.compile(expression))
            except re.error as e:
                exit_with_error("Invalid regular expression %r: %s"
                                % (expression, e),
                                ExitCodes.INPUT_ERROR)
        setattr(options, attribute, compiled)

    print(options)
    return options


def check_valid_domain(options, dir_data):
    """Decide whether a directory counts as a domain to analyse.

    :param options: object whose ``pattern`` and ``anti_pattern`` attributes
                    are iterables of compiled regular expressions.
    :param dir_data: indexable entry as produced by os.walk; index 0 is the
                     directory path, index 2 the file names in it.
    :return: True if the directory contains a file named "domain.pddl", every
             pattern matches the path (anchored at its start) and no
             anti-pattern does; False otherwise.
    """
    if "domain.pddl" not in dir_data[2]:
        return False

    path = dir_data[0]
    # Anti-patterns are only consulted once all required patterns matched.
    return (all(required.match(path) for required in options.pattern)
            and not any(banned.match(path) for banned in options.anti_pattern))

def get_domains(options):
    """Collect all domain directories below the given root directories.

    Every directory found by walking each entry of ``options.directory`` is
    kept if check_valid_domain accepts it.

    :param options: parsed options (``directory`` plus the pattern lists used
                    by check_valid_domain).
    :return: sorted list of the accepted directory paths without duplicates.
    """
    found = set()
    for root in options.directory:
        for entry in os.walk(root):
            if check_valid_domain(options, entry):
                found.add(entry[0])
    return sorted(found)

def get_nb_possible_objects(slot, type_counts, inv_type_hierarchy):
    """Sum the counts of ``slot`` and of everything reachable from it.

    :param slot: name to start from.
    :param type_counts: mapping from a name to its count; missing names
                        count as 0.
    :param inv_type_hierarchy: mapping from a name to an iterable of names to
                               visit next (presumably its direct subtypes);
                               missing names have no successors.
    :return: the accumulated count.

    Visited names are not tracked: a name reachable along several paths is
    counted once per path, and a cyclic hierarchy never terminates.
    """
    total = 0
    pending = [slot]
    while pending:
        current = pending.pop()
        total += type_counts.get(current, 0)
        pending += inv_type_hierarchy.get(current, [])
    return total


def get_domain_sizes(options, domain):
    """Run the configured command for one domain and parse its statistics.

    Each "{PATH}" item of ``options.call`` is replaced by ``domain`` (the
    template itself is left untouched), the command is executed with stderr
    merged into stdout, and the output is searched with PATTERN_COUNT,
    PATTERN_MEMORY_DATA and PATTERN_MEMORY_NETWORK.

    :param options: parsed options providing the command template ``call``.
    :param domain: path of the domain directory.
    :return: three floats (count, data set memory, network memory) as a
             list; a tuple of three NaN values if the command exits with a
             non-zero status.
    :raises ValueError: if the output does not contain exactly one match for
                        one of the patterns.
    """
    call = [domain if item == "{PATH}" else item for item in options.call]

    try:
        output = subprocess.check_output(call, stderr=subprocess.STDOUT).decode()
    except subprocess.CalledProcessError as e:
        captured = e.output
        # The output was captured as bytes; decode it so the report is
        # readable instead of a b'...' representation.
        if isinstance(captured, bytes):
            captured = captured.decode(errors="replace")
        print("Err:", captured, "\n", e.returncode)
        return float("nan"), float("nan"), float("nan")

    res = [PATTERN_COUNT.findall(output),
           PATTERN_MEMORY_DATA.findall(output),
           PATTERN_MEMORY_NETWORK.findall(output)]

    # Explicit check instead of `assert`: it validates external program
    # output and therefore must not vanish when Python runs with -O.
    labels = ("sample count", "data set memory", "network memory")
    for label, matches in zip(labels, res):
        if len(matches) != 1:
            raise ValueError("Expected exactly one %s entry in the output for "
                             "%s, found %d" % (label, domain, len(matches)))

    return [float(matches[0]) for matches in res]


def get_domains_sizes(options, domains):
    """Measure every domain and arrange the results as nested dictionaries.

    The result has a single top-level key, the common parent directory of all
    domains. Below it the dictionaries mirror the remaining directory
    components of each domain path; the dictionary reached for a domain
    receives the keys "samples", "memory_samples", "memory_1000_samples" and
    "memory_network". The path and the stored values of every domain are
    printed while processing.

    :param options: parsed options, passed on to get_domain_sizes.
    :param domains: list of domain directory paths.
    :return: the nested dictionary; an empty dictionary if ``domains`` is
             empty.
    """
    if not domains:
        return {}

    # commonprefix works character-wise, so the prefix may end in the middle
    # of a directory name (or be empty); cut it back to a whole directory.
    common_prefix = os.path.commonprefix(domains)
    if common_prefix.endswith(os.sep):
        common_prefix = common_prefix[:-1]
    else:
        common_prefix = os.path.dirname(common_prefix)

    tree = {}
    data = {common_prefix: tree}

    for domain in domains:
        path = domain[:-1] if domain.endswith(os.sep) else domain
        count, dataset, network = get_domain_sizes(options, path)
        # Strip the separator only if there is one: with an empty prefix the
        # relative part starts right at the first character of the path.
        relative = path[len(common_prefix):].lstrip(os.sep)
        print(path)

        node = tree
        for dirname in relative.split(os.sep):
            node = node.setdefault(dirname, {})

        node["samples"] = count
        node["memory_samples"] = dataset
        # A count of 0 has no meaningful per-1000-samples value; NaN counts
        # (failed runs) propagate through the division on their own.
        node["memory_1000_samples"] = (1000 * dataset / count if count
                                       else float("nan"))
        node["memory_network"] = network

        print("\t", node)

    return data


def disp(sizes, current="", indent="    "):
    """Print a nested dictionary as an indented tree, keys in sorted order.

    :param sizes: dictionary with string keys; dictionary values are printed
                  as sub-trees, all other values on the line of their key.
    :param current: prefix put in front of every line of this level.
    :param indent: text appended to ``current`` for each nesting level.
    """
    for key in sorted(sizes):
        value = sizes[key]
        label = current + key + ":"
        if not isinstance(value, dict):
            print(label + "\t" + str(value))
            continue
        print(label)
        disp(value, current + indent, indent)


def run(argv):
    """Analyse all selected domains and report their sizes.

    Parses ``argv``, measures every domain found, optionally writes the
    result as JSON to the path given by ``--store``, and prints the result
    both as an indented tree and as a raw dictionary.

    :param argv: list of command-line arguments (without the program name).
    """
    opts = parse_args(argv)
    sizes = get_domains_sizes(opts, get_domains(opts))

    store_path = opts.store
    if store_path is not None:
        with open(store_path, "w") as handle:
            json.dump(sizes, handle, sort_keys=True, indent=4)

    disp(sizes)
    print(sizes)


# Script entry point: hand the command-line arguments (without the program
# name) to run().
if __name__ == "__main__":
    run(sys.argv[1:])

