from abc import ABCMeta

from .keras_network import KerasNetwork

from ... import parser_tools as parset
from ... import parser
from ... import main_register

from ...parser_tools import ArgumentException


class KerasFieldsNetwork(KerasNetwork):
    """KerasNetwork variant whose inputs and outputs are named data fields.

    The network input and output are described by field names. For every
    name ``N`` the matching extractor reads the attribute ``field_N`` from
    the object it is given, so the order of the names determines the order
    of the extracted values.
    """
    # Python 2 style metaclass declaration. Under Python 3 this is only a
    # plain class attribute and does not change the metaclass.
    __metaclass__ = ABCMeta

    # Extends the KerasNetwork arguments by the two field arguments.
    # NOTE(review): the meaning of the positional entries in each argument
    # tuple (False, None, str) is defined by parset.ClassArguments, which is
    # not visible here - confirm before changing them.
    arguments = parset.ClassArguments(
        'KerasFieldsNetwork',
        KerasNetwork.arguments,
        ("x_fields", False, None, str,
         "Single or list of field names which make up the network input "
         "(in the order they appear in the input)."),
        ("y_fields", False, None, str,
         "Single or list of field names which make up the network output "
         "(in the order they appear in the input)."),
        order=["x_fields", "y_fields", "tparams", "load", "store",
               "learner_formats", "out",
               "count_samples", "test_similarity", "graphdef",
               "variables", "id"])

    def __init__(self, x_fields, y_fields, tparams=None, load=None, store=None,
                 learner_formats=None, out=".",
                 count_samples=False, test_similarity=None,
                 graphdef=None,
                 variables=None, id=None):
        """Initialise the base network and set up the field extractors.

        :param x_fields: single field name, or list/tuple of field names,
                         which make up the network input.
        :param y_fields: single field name, or list/tuple of field names,
                         which make up the network output.

        All remaining parameters are forwarded unchanged (positionally) to
        ``KerasNetwork.__init__``.
        """
        KerasNetwork.__init__(
            self, tparams, load, store, learner_formats, out, count_samples,
            test_similarity, graphdef, variables, id)

        self._x_field_names = self._as_field_list(x_fields)
        self._y_field_names = self._as_field_list(y_fields)

        # The extractors read the name lists at call time (not at creation
        # time), so later changes to the lists are picked up.
        self._x_fields_extractor = lambda ds: [getattr(ds, "field_" + xfn)
                                               for xfn in self._x_field_names]
        self._y_fields_extractor = lambda ds: [getattr(ds, "field_" + yfn)
                                               for yfn in self._y_field_names]

    @staticmethod
    def _as_field_list(fields):
        """Normalise a field specification to a list of field names.

        A list is returned as is (same object), a tuple is converted to a
        list, and any other value is treated as a single field name and
        wrapped in a one-element list.
        """
        if isinstance(fields, list):
            return fields
        if isinstance(fields, tuple):
            # Without this a tuple would be wrapped as one "name" and fail
            # later on the string concatenation in the extractors.
            return list(fields)
        return [fields]

    @staticmethod
    def parse(tree, item_cache):
        """Look up a previously defined network; never construct a new one.

        :param tree: passed on to ``parser.try_lookup_obj``.
        :param item_cache: passed on to ``parser.try_lookup_obj``.
        :return: the object found by the look-up.
        :raises ArgumentException: if the look-up yields ``None``.
        """
        # NOTE(review): the look-up is done for KerasNetwork although the
        # error message below speaks of KerasFieldsNetwork - confirm whether
        # KerasFieldsNetwork was intended here.
        obj = parser.try_lookup_obj(tree, item_cache, KerasNetwork, None)
        if obj is not None:
            return obj
        else:
            raise ArgumentException("The definition of KerasFieldsNetwork can "
                                    "only be used for look up of any previously"
                                    " defined KerasFieldsNetwork via "
                                    "'keras_fields_network(id=ID)'")


# Register the class with main_register under the key "keras_fields_network",
# the same name the parse() error message uses ('keras_fields_network(id=ID)').
main_register.append_register(KerasFieldsNetwork, "keras_fields_network")
