#! /usr/bin/env python3

import argparse
from collections import defaultdict
import json
import numpy as np
import os
import sys
import re

import config

def type_is_file(arg):
    """argparse ``type`` callback that accepts only paths of existing files.

    Returns ``arg`` unchanged when it names a regular file; otherwise raises
    ``argparse.ArgumentTypeError`` so argparse reports a usage error.
    """
    if os.path.isfile(arg):
        return arg
    raise argparse.ArgumentTypeError(f"Not a file: {arg}")

# Name of the per-run property that can be reported (the only `attribute` choice).
ATT_COVERAGE = "coverage"
# Aggregation functions by name; not currently selectable on the command line.
STATISTICS = dict(sum=sum)

def _build_parser():
    """Create the command-line parser for this script.

    Positional arguments, in order: the task set name (a key of
    ``config.TASK_SETS``), the properties JSON file, and the attribute to
    report. ``--algorithm-filter``/``-a`` is mandatory and may be repeated;
    each value is compiled with ``re.compile``. ``--states-per-task`` is a
    mandatory integer.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("task_set", choices=list(config.TASK_SETS))
    cli.add_argument("properties", type=type_is_file)
    cli.add_argument(
        "--algorithm-filter", "-a",
        type=re.compile, action="append", default=[], required=True,
    )
    cli.add_argument("attribute", choices=[ATT_COVERAGE])
    # A --statistic option (choosing from STATISTICS) is currently disabled.
    cli.add_argument("--states-per-task", type=int, required=True)
    return cli


parser = _build_parser()


def natural_sort(l):
    """Return the items of ``l`` as a new list in "natural" order.

    Each item is split into runs of ASCII digits and the text between them.
    Digit runs compare as integers and text compares case-insensitively, so
    ``"task2"`` sorts before ``"task10"``. Equal keys keep their input order.
    """
    def _natural_key(item):
        pieces = re.split('([0-9]+)', item)
        return [int(piece) if piece.isdigit() else piece.lower()
                for piece in pieces]

    return sorted(l, key=_natural_key)


def nan_median(x):
    """Return ``np.median(x)``, or the placeholder ``"-"`` if ``x`` is empty."""
    return np.median(x) if len(x) else "-"

def fmt_float(x):
    """Format a number with one decimal place; strings (e.g. ``"-"``) pass through."""
    return x if isinstance(x, str) else "%.1f" % x



def _count_states_per_domain(tasks, states_per_task):
    """Map each domain (task directory) to ``#tasks in it * states_per_task``."""
    tasks_per_domain = defaultdict(int)
    for task in tasks:
        tasks_per_domain[os.path.dirname(task)] += 1
    return {domain: count * states_per_task
            for domain, count in tasks_per_domain.items()}


def _load_attribute_data(options, tasks, all_domains):
    """Collect ``options.attribute`` from the properties file per algorithm filter.

    Returns a list with one entry per ``options.algorithm_filter`` (same
    order). Each entry maps ``domain -> {(task, problem): (algorithm, value)}``.
    Here ``task`` is the basename of the run's ``domain`` property, taken
    before ``config.rename_task``, and ``problem`` is the run's ``problem``
    property.

    A run is skipped when any of these holds:
    - its ``problem`` contains "source";
    - its domain or renamed task is not in the selected task set;
    - its algorithm matches no filter;
    - it has no value for the attribute.

    If several algorithms matched by one filter report the same
    (task, problem), the one whose name sorts last under ``natural_sort`` wins.

    Raises:
        ValueError: if a run's renamed task is in none of ``config.TASK_SETS``.
    """
    data = [defaultdict(dict) for _ in options.algorithm_filter]
    # Sets built once: these memberships are tested for every run.
    selected_domains = set(all_domains)
    selected_tasks = set(tasks)
    known_tasks = set().union(*config.TASK_SETS.values())

    with open(options.properties, "r") as f:
        properties = json.load(f)

    for props in properties.values():
        if "source" in props["problem"]:
            continue  # this is not a pddl we wanted to evaluate
        domain = os.path.dirname(props["domain"])
        task = os.path.basename(props["domain"])
        renamed_task = os.path.join(domain, config.rename_task(task))
        if domain not in selected_domains:
            continue
        # Explicit check (not `assert`) so it still runs under `python -O`.
        if renamed_task not in known_tasks:
            raise ValueError(f"Task {renamed_task} is not part of any task set")
        if renamed_task not in selected_tasks:
            continue

        algorithm = props["algorithm"]
        matching = [no_af for no_af, af in enumerate(options.algorithm_filter)
                    if af.match(algorithm)]
        if not matching:
            continue

        att = props.get(options.attribute)
        if att is None:
            continue

        key = (task, props["problem"])
        for no_af in matching:
            previous = data[no_af][domain].get(key)
            # Keep the algorithm whose name sorts last. The original intent:
            # later sorted -> higher index -> retrained model for a previous
            # not successful model.
            if (previous is not None
                    and natural_sort([previous[0], algorithm])[-1] != algorithm):
                continue
            data[no_af][domain][key] = (algorithm, att)
    # `properties` is local, so it is released when this function returns.
    return data


def _print_table(options, data, all_domains, states_per_domain):
    """Print a header and the per-domain table as comma-separated rows.

    Each cell is 100 * (sum of the attribute over the collected entries) /
    ``states_per_domain[domain]``, followed by "%". A cell is "-" when the
    filter has no data for the domain or no states are expected there. The
    "Average" row is the mean over the cells that had a value, or "-" if there
    were none.
    """
    print("Task Set", options.task_set)
    print("\tAttribute: ", options.attribute)

    # Use the filter's pattern text. str() of a compiled regex gives
    # "re.compile('...')", which can contain commas and break the CSV.
    header = ["Domain"] + [getattr(af, "pattern", af)
                           for af in options.algorithm_filter]
    rows = [header]
    to_average = [[] for _ in data]
    for domain in all_domains:
        row = [domain]
        total_states = states_per_domain[domain]
        for values, adata in zip(to_average, data):
            if domain not in adata or total_states == 0:
                row.append("-")
                continue
            attribute_sum = sum(att for _, att in adata[domain].values())
            aval = 100 * attribute_sum / float(total_states)
            row.append(fmt_float(aval) + "%")
            values.append(aval)
        rows.append(row)
    # np.mean([]) is nan and emits a RuntimeWarning; show "-" instead.
    rows.append(["Average"] + [fmt_float(np.mean(values)) if values else "-"
                               for values in to_average])
    print("\n".join(",".join(str(x) for x in row) for row in rows))


def run(options):
    """Print ``options.attribute`` per domain and algorithm filter as a CSV table.

    Reads the runs in ``options.properties``. It keeps those that belong to
    ``config.TASK_SETS[options.task_set]`` and match an ``--algorithm-filter``.
    For each domain it prints 100 * (sum of the attribute) /
    (#tasks in the domain * ``options.states_per_task``). That is a percentage
    assuming the attribute is a 0/1 coverage flag per run (TODO confirm).

    Args:
        options: parsed command-line namespace providing ``task_set``,
            ``properties``, ``algorithm_filter`` (compiled patterns),
            ``attribute`` and ``states_per_task``.

    Raises:
        ValueError: if a run refers to a task that is in no configured task set.
    """
    tasks = config.TASK_SETS[options.task_set]
    all_domains = natural_sort({os.path.dirname(t) for t in tasks})
    states_per_domain = _count_states_per_domain(tasks, options.states_per_task)
    data = _load_attribute_data(options, tasks, all_domains)
    _print_table(options, data, all_domains, states_per_domain)


if __name__ == "__main__":
    # Parse everything after the program name and generate the table.
    cli_options = parser.parse_args(sys.argv[1:])
    run(cli_options)
