#!/usr/bin/env python
import sys
sys.path.append("../../")

from src.training.misc import StreamContext, GzipStreamDefinition
from src.training.samplers import IterableFileSampler
from src.training.bridges import StateFormat, LoadSampleBridge
from src.training.networks.keras_networks.keras_tools import KerasDataGenerator

import keras
import numpy as np
import os
import re



# Extracts the problem id ("p" followed by digits) from a data file path.
# Because the leading ".*" is greedy, the LAST "p<digits>" in the string wins.
# Raw strings keep "\d" a regex escape rather than an invalid string escape.
PATTERN_DATA_TO_PROBLEM = re.compile(r".*(p\d+).*?")
# Extracts the fold number from a name containing "_<digits>_fold".
PATTERN_FOLD = re.compile(r".*_(\d+)_fold.*?")

# Directory scanned by detect_all_fixed_universe; expected to contain one
# sub-directory per domain, each holding one sub-directory per fixed universe.
ROOT = "../../../DeePDown/data/FixedWorlds/opt/"
# File-name prefixes identifying the model (.h5) files to evaluate.
# No prefix may itself be a prefix of another entry: get_prefix_to_model
# returns the first entry that matches, so overlapping prefixes are ambiguous.
PREFIXES = ["uniform_mlp_h5_"]
# Sample fields handed to the LoadSampleBridge. In evaluate, all but the last
# field are used as network inputs and the last field as the label.
FIELDS = ["current_state", "goals", "hplan"]
# State format handed to the LoadSampleBridge.
FORMAT = StateFormat.All_A_01

# Directory into which the CSV reports are written.
OUT_DIR = "scores"

# If True, evaluate_all_iteratively fails when a second model with the same
# prefix yields a result for a data file that already has one.
ERROR_IF_MULTIPLE_FOR_SAME = True

def detect_all_fixed_universe(path, prefixes):
    """Collect problems, sample data and models of every fixed universe.

    The directory layout walked is ``path/<domain>/<universe>/<files>``.
    Within a universe directory (sub-directories are ignored):

    * ``*.pddl`` files containing "domain" in their name mark the universe as
      having a domain file; all other ``*.pddl`` files are problems,
    * ``*.data.gz`` files are sample data; their problem id is extracted with
      ``PATTERN_DATA_TO_PROBLEM``,
    * ``*.h5`` files starting with one of ``prefixes`` are models.

    A problem is kept only if at least one data file carries its id (the
    problem file name without ".pddl"); data files without a matching problem
    are dropped.

    :param path: root directory containing the domain directories.
    :param prefixes: iterable of accepted model file-name prefixes.
    :return: ``{domain: {universe: (p2d, models)}}`` where ``p2d`` maps a
             problem path to the list of its data file paths and ``models``
             is the list of model paths. Universes lacking a domain file,
             problems with data, or models are omitted, as are domains left
             without any universe.
    :raises AssertionError: if a data file path does not yield exactly one
             problem id.
    """
    mapping = {}
    for domain_name in os.listdir(path):
        path_domain = os.path.join(path, domain_name)
        if not os.path.isdir(path_domain):
            continue

        universes = {}
        for fixed_universe in os.listdir(path_domain):
            path_universe = os.path.join(path_domain, fixed_universe)
            if not os.path.isdir(path_universe):
                continue

            problems, datas, models = [], [], []
            has_domain = False
            for item in os.listdir(path_universe):
                path_item = os.path.join(path_universe, item)
                if os.path.isdir(path_item):
                    continue
                if item.endswith(".pddl"):
                    if "domain" in item:
                        has_domain = True
                    else:
                        problems.append(path_item)
                elif item.endswith(".data.gz"):
                    datas.append(path_item)
                elif item.endswith(".h5"):
                    if any(item.startswith(p) for p in prefixes):
                        models.append(path_item)

            # Group the data files by the problem id found in their own path.
            # (The id must be taken from each file individually; reusing the
            # id of another file would keep or drop all data files at once.)
            pid2d = {}
            for d in datas:
                pid = PATTERN_DATA_TO_PROBLEM.findall(d)
                assert len(pid) == 1, "Cannot determine problem id of: %s" % d
                pid2d.setdefault(pid[0], []).append(d)

            # Keep only the problems for which data exists; data whose problem
            # is missing is never referenced and thereby dropped implicitly.
            p2d = {}
            for p in problems:
                pid = os.path.basename(p[:-len(".pddl")])
                if pid in pid2d:
                    p2d[p] = pid2d[pid]

            if has_domain and len(p2d) > 0 and len(models) > 0:
                universes[fixed_universe] = (p2d, models)

        if len(universes) > 0:
            mapping[domain_name] = universes
    return mapping


def get_prefix_to_model(prefixes, path_model):
    """Return the prefix the file name of a model starts with.

    :param prefixes: iterable of candidate prefixes; the first match wins.
    :param path_model: path to the model file; only its base name is checked.
    :return: the first entry of ``prefixes`` the base name starts with.
    :raises AssertionError: if no prefix matches. Raised explicitly so the
             check also holds when Python runs with assertions disabled (-O).
    """
    basename = os.path.basename(path_model)
    for p in prefixes:
        if basename.startswith(p):
            return p
    raise AssertionError("No known prefix matches model file: %s" % basename)

def estimate_suffix_for_stream(problem, data):
    """Derive the stream suffix which turns a problem path into its data path.

    The last ``len(".pddl")`` characters are cut off ``problem``; whatever
    ``data`` has beyond the remaining stem is the suffix.

    :param problem: path of the problem file.
    :param data: path of the data file; has to start with the problem stem.
    :return: the part of ``data`` following the problem stem.
    """
    extension_length = len(".pddl")
    stem = problem[:-extension_length]
    assert data.startswith(stem)
    stem_length = len(stem)
    return data[stem_length:]


def evaluate(path_problem, path_model, path_data, domain_properties=None):
    """Compute the mean absolute error of one model on one problem's data.

    The samples are loaded through a LoadSampleBridge/IterableFileSampler for
    ``path_problem``, the Keras model is loaded from ``path_model`` and its
    predictions are compared against the labels collected while predicting.

    :param path_problem: path of the problem file given to the sampler.
    :param path_model: path of the Keras model file to load.
    :param path_data: path of the data file; has to end with ".gz". It is only
                      used to derive the stream suffix.
    :param domain_properties: forwarded unchanged to the LoadSampleBridge.
    :return: mean of the absolute differences between predictions and labels,
             or None if the sampler returned no data.
    """
    if not path_data.endswith(".gz"):
        stream = None
        assert False, "Unknown suffix for stream"
    else:
        stream_suffix = estimate_suffix_for_stream(path_problem, path_data)
        stream = GzipStreamDefinition(None, None, None, None,
                                      suffix=stream_suffix)

    stream_context = StreamContext(streams=[stream])
    bridge = LoadSampleBridge(streams=stream_context,
                              fields=FIELDS,
                              format=FORMAT, prune=True,
                              skip=False, skip_magic=False,
                              domain_properties=domain_properties)
    sampler = IterableFileSampler(sampler_bridge=bridge,
                                  iterable=[path_problem])

    # Load everything the sampler provides for this single problem.
    sampler.initialize()
    dtest = sampler.sample()
    sampler.finalize()
    if len(dtest) == 0:
        return None

    for data_set in dtest:
        data_set.finalize()

    model = keras.models.load_model(path_model)
    output_units = model.outputs[0].shape[1].value
    # Handed to the generator as y_remember; presumably it appends the labels
    # of every produced batch to this list (confirm in KerasDataGenerator).
    remembered_labels = []
    assert output_units > 0

    def to_one_hot(y):
        # One row per sample; the class index is the first label column,
        # clipped into the range of available output units.
        sample_count = y.shape[0]
        class_indices = y[:, 0].astype(int).clip(0, output_units - 1)
        one_hot = np.zeros((sample_count, output_units))
        one_hot[np.arange(sample_count), class_indices] = 1
        return one_hot

    def stack_inputs(x):
        # One stacked array per input field (all fields but the last).
        return [np.stack(x[:, i], axis=0) for i in range(len(FIELDS) - 1)]

    first_set = dtest[0]
    input_fields = [getattr(first_set, "field_" + name)
                    for name in FIELDS[:-1]]
    label_fields = [getattr(first_set, "field_" + FIELDS[-1])]

    kdg_test = KerasDataGenerator(
        dtest,
        x_fields=input_fields,
        y_fields=label_fields,
        x_converter=stack_inputs,
        y_converter=to_one_hot if output_units != 1 else None,
        y_remember=remembered_labels
    )

    predictions = model.predict_generator(kdg_test, max_queue_size=10,
                                          workers=1,
                                          use_multiprocessing=False)
    labels = np.concatenate(remembered_labels)
    if labels.shape[1] == 1:
        labels = labels.squeeze(axis=1)
        predictions = predictions.squeeze(axis=1)
    # Drop any labels beyond the number of predictions.
    labels = labels[:len(predictions)]

    absolute_errors = np.absolute(predictions - labels)
    return np.mean(absolute_errors)

def evaluate_all_iteratively(mapping, prefixes):
    """Evaluate every suitable model on every data file of every problem.

    A problem named ``p<N>`` is associated with fold ``max(0, N - 1) // 20``;
    problems with any other name have no fold. A model whose path matches
    ``PATTERN_FOLD`` is associated with that fold. A model is skipped for a
    problem only if both have a fold and the folds differ.

    :param mapping: ``{domain: {universe: (p2d, models)}}`` as produced by
                    detect_all_fixed_universe.
    :param prefixes: candidate model prefixes, see get_prefix_to_model.
    :return: ``{data_path: {prefix: mae}}``; evaluations returning None are
             left out.
    :raises AssertionError: if ERROR_IF_MULTIPLE_FOR_SAME is set and a second
             model with the same prefix yields a result for the same data.
    """
    all_mae = {}
    for domain in mapping:
        print("Next Domain: ", domain)
        for universe in mapping[domain]:
            print("\tNext Universe:", universe)
            entries = mapping[domain][universe]
            for p in entries[0]:
                pid = os.path.basename(p[:-len(".pddl")])
                # Only a non-numeric problem name is an expected failure here;
                # anything else (e.g. KeyboardInterrupt) must propagate.
                try:
                    assoc_fold = max(0, int(pid[1:]) - 1) // 20
                except ValueError:
                    assoc_fold = None

                for m in entries[1]:
                    match = PATTERN_FOLD.match(m)
                    # If the problem is not associated to a fold -> eval always
                    # Elif model not associated to fold -> eval always
                    # Else if associated fold match -> eval
                    if (assoc_fold is not None and match is not None
                            and assoc_fold != int(match.group(1))):
                        continue

                    assoc_prefix = get_prefix_to_model(prefixes, m)

                    for d in entries[0][p]:
                        mae = evaluate(p, m, d)
                        if mae is None:
                            continue
                        mae_of_data = all_mae.setdefault(d, {})
                        if assoc_prefix not in mae_of_data:
                            mae_of_data[assoc_prefix] = mae
                        elif ERROR_IF_MULTIPLE_FOR_SAME:
                            # Raised explicitly so the check survives "-O".
                            raise AssertionError(
                                "Multiple networks would evaluate the same data")
    return all_mae


def write_csv(mae, path):
    """Write the MAE values as a table and return that table.

    The table has one row per data file (sorted by name) and one column per
    prefix (sorted by name). Row 0 holds the prefixes, column 0 the data file
    names, cell (0, 0) is the empty string and combinations without a value
    are NaN. Labels and values use the same row/column lookup so that every
    value ends up next to the name it belongs to.

    :param mae: ``{data_path: {prefix: mae_value}}``.
    :param path: file the table is written to (comma separated).
    :return: the table as a 2D numpy array of dtype object.
    """
    map_data = {d: no for no, d in enumerate(sorted(mae))}
    used_prefixes = set()
    for d in mae:
        used_prefixes.update(mae[d].keys())
    # Sorted so that the column order is reproducible between runs.
    map_prefixes = {p: no for no, p in enumerate(sorted(used_prefixes))}

    ary = np.full((len(map_data) + 1, len(map_prefixes) + 1), np.nan,
                  dtype=object)
    ary[0, 0] = ""
    for d, no in map_data.items():
        ary[no + 1, 0] = d
    for p, no in map_prefixes.items():
        ary[0, no + 1] = p

    for d in mae:
        idx_row = map_data[d] + 1
        for p in mae[d]:
            idx_col = map_prefixes[p] + 1
            ary[idx_row, idx_col] = mae[d][p]

    np.savetxt(path, ary, delimiter=",", fmt="%s")

    return ary


def write_more_csv(ary, outdir):
    """Write statistics derived from the table produced by write_csv.

    Three files are written into ``outdir``:

    * ``count_not_nan.csv``: first row the prefixes, second row the number of
      non-NaN values in the column of each prefix,
    * ``network_scores_all.csv``: the table with every row min-max normalised
      to [0, 1]; rows without values, or whose minimum equals their maximum,
      are set to NaN,
    * ``network_scores_domain.csv``: per domain and prefix the NaN-ignoring
      mean of the normalised scores. The domain of a row is the directory two
      levels above the data file named in column 0.

    :param ary: 2D object array; row 0 holds the prefixes, column 0 the data
                file paths and the remaining cells the values.
    :param outdir: existing directory to write the files into.
    """
    values = ary[1:, 1:].astype("float")
    count_not_nan = (~np.isnan(values)).sum(axis=0)
    # Label the counts with their prefixes instead of writing bare numbers.
    labelled_counts = np.stack((ary[0, 1:], count_not_nan))
    np.savetxt(os.path.join(outdir, "count_not_nan.csv"), labelled_counts,
               delimiter=",", fmt="%s")

    scores = np.copy(ary)
    for i in range(1, scores.shape[0]):
        row = scores[i, 1:].astype("float")
        # nanmin/nanmax warn on all-NaN rows and fail on empty ones.
        if row.size == 0 or np.isnan(row).all():
            scores[i, 1:] = np.nan
            continue
        minimum = np.nanmin(row)
        maximum = np.nanmax(row)
        if maximum == minimum:
            # No spread -> a relative score is undefined.
            scores[i, 1:] = np.nan
        else:
            scores[i, 1:] = (row - minimum) / (maximum - minimum)
    np.savetxt(os.path.join(outdir, "network_scores_all.csv"), scores,
               delimiter=",", fmt="%s")

    # Row indices of the table grouped by domain directory.
    domains = {}
    for i in range(1, ary.shape[0]):
        domain = os.path.dirname(os.path.dirname(ary[i, 0]))
        domains.setdefault(domain, []).append(i)

    domain_scores = np.ndarray(shape=(len(domains) + 1, ary.shape[1]),
                               dtype="object")
    domain_scores[0, :] = ary[0, :]
    for no, domain in enumerate(sorted(domains.keys())):
        domain_scores[no + 1, 0] = domain
        for idx_col in range(1, scores.shape[1]):
            column = scores[domains[domain], idx_col].astype("float")
            # Avoid the "mean of empty slice" warning for all-NaN columns.
            if np.isnan(column).all():
                domain_scores[no + 1, idx_col] = np.nan
            else:
                domain_scores[no + 1, idx_col] = np.nanmean(column)
    np.savetxt(os.path.join(outdir, "network_scores_domain.csv"), domain_scores,
               delimiter=",", fmt="%s")


if __name__ == "__main__":
    # Discover the available data and models, then evaluate all of them.
    universes = detect_all_fixed_universe(ROOT, PREFIXES)
    mae_by_data = evaluate_all_iteratively(universes, PREFIXES)

    # Create the output directory only after the evaluation has finished.
    if not os.path.exists(OUT_DIR):
        os.makedirs(OUT_DIR)

    # The raw MAE table is written first; the derived reports build on it.
    mae_file = os.path.join(OUT_DIR, "mae.csv")
    mae_table = write_csv(mae_by_data, mae_file)
    write_more_csv(mae_table, OUT_DIR)

