from __future__ import print_function

from . import SamplerBridge, StateFormat, StopSamplingException, field_parser
from .common import generator_context_load_and_convert

from ... import main_register
from ... import parser
from ... import parser_tools as parset

from ...misc import StreamContext

import os
import psutil
import sys


class LoadSampleBridge(SamplerBridge):
    """SamplerBridge which obtains its samples by loading them from streams.

    The actual loading and conversion is delegated to
    ``generator_context_load_and_convert``; this class stores the
    configuration, enforces the optional container-size limit, and tracks how
    much the resident memory of the process grew while loading so that an
    optional memory limit can be enforced across calls.
    """

    # Each entry is (name, optional, default, converter[, description]).
    arguments = parset.ClassArguments(
        'LoadSampleBridge',
        SamplerBridge.arguments,
        ('streams', True, None, main_register.get_register(StreamContext),
         "StreamContext (usually with StreamDefinitions) for loading the "
         "associated data"),
        ('write_streams', True, None,
         main_register.get_register(StreamContext),
         "StreamContext (usually with StreamDefinitions) for writing the "
         "loaded data"),
        ("format", True, StateFormat.FD, StateFormat.get,
         "Format to represent the sampled state"),
        ("prune", True, True, parser.convert_bool, "Prune duplicate samples"),
        ("fprune", True, None, None,
         "Callable which given an item tells if the item shall be pruned."),
        ("skip", True, True, parser.convert_bool,
         "Skip problem if no samples exists, else raise error"),
        ("max_mem", True, None, float,
         "Maximum memory the loaded data in total may consume"),
        ("skip_magic", True, False, parser.convert_bool,
         "Skip magic word check (no guarantees on opening the files with the"
         " right tool (DEPRECATED)"),
        ("provide", True, True, parser.convert_bool),
        ("sample_types", True, None, str,
         "None to load all samples or an object which supports 'in'. "
         "Only samples where the sample type of an object is 'in' the given "
         "sample types will be loaded"),
        # The converter slot was missing here, which put the description in
        # the converter position; `int` matches "max_container_samples".
        ("samples_per_problem", True, None, int,
         "Loads at most the specified number of samples per "
         "(problem_hash, modification_hash) pair. ASSUMPTION: SAMPLES FROM THE "
         "SAME PAIR ARE IN CONSECUTIVE ORDER IN THE DATA FILES"),
        ("max_container_samples", True, None, int,
         "Maximum number of samples it will load into a container (this "
         "includes samples already in the data_container and for "
         "_generator_sample this will not be updated between generator calls"
         "aka if you empty the data_container in between, it will not be "
         "noticed.)"),
        order=["streams", "write_streams", "fields", "parse_kwargs",
               "unparse_kwargs",
               "format", "prune", "fprune", "forget", "max_container_samples",
               "skip", "max_mem", "sample_types",
               "samples_per_problem",
               "tmp_dir", "provide", "domain", "domain_properties",
               "domain_properties_loader",
               "makedir", "skip_magic",
               "environment", "id"]
    )

    # Read-only view on the accumulated RSS growth measured over all fully
    # consumed _generator_sample runs (only tracked when max_mem is set).
    current_memory_usage = property(lambda self: self._cur_mem)

    def __init__(self, streams=None, write_streams=None, fields=None,
                 parse_kwargs=None, unparse_kwargs=None,
                 format=StateFormat.FD, prune=True, fprune=None, forget=0.0,
                 max_container_samples=None, skip=True,
                 max_mem=None, sample_types=None, samples_per_problem=None,
                 tmp_dir=None, provide=True, domain=None,
                 domain_properties=None, domain_properties_loader=None,
                 makedir=False, skip_magic=False, environment=None, id=None,
                 reference_states=None):
        """Store the loading configuration.

        ``tmp_dir``, ``fields``, ``parse_kwargs``, ``unparse_kwargs``,
        ``provide``, ``forget``, ``domain``, ``domain_properties``,
        ``domain_properties_loader``, ``makedir``, ``environment`` and ``id``
        are passed on to ``SamplerBridge.__init__``.

        :param streams: StreamContext to load from; a fresh, empty
            StreamContext is used if None.
        :param write_streams: StreamContext handed to the loader as
            ``write_context`` (may be None).
        :param format: state format handed to the loader.
        :param prune: forwarded to the default container and to the loader.
        :param fprune: forwarded to the default container.
        :param max_container_samples: None for no limit, else a positive
            upper bound on the number of samples in the data container.
        :param skip: forwarded to the loader.
        :param max_mem: None for no limit, else the limit compared against
            the accumulated RSS growth (psutil reports RSS in bytes).
        :param sample_types: forwarded to the loader.
        :param samples_per_problem: forwarded to the loader.
        :param skip_magic: forwarded to the loader.
        :param reference_states: forwarded to the loader.
        """
        SamplerBridge.__init__(self, tmp_dir, fields,
                               parse_kwargs, unparse_kwargs, provide, forget,
                               domain, domain_properties,
                               domain_properties_loader, makedir, environment,
                               id)
        self._streams = StreamContext() if streams is None else streams
        self._write_streams = write_streams
        self._format = format
        self._prune = prune
        self._fprune = fprune
        self._skip = skip
        self._skip_magic = skip_magic

        self._max_mem = max_mem
        self._sample_types = sample_types
        self._samples_per_problem = samples_per_problem
        self._maximum_container_samples = max_container_samples
        assert (self._maximum_container_samples is None
                or self._maximum_container_samples > 0)
        # Accumulated RSS growth caused by loading (see _generator_sample).
        self._cur_mem = 0
        self._reference_states = reference_states

    def _initialize(self):
        """Nothing to set up for this bridge."""
        pass

    def _generator_sample(self, path_problem, path_dir_tmp, path_domain,
                          data_container, skip_lines, count_lines, batch_size,
                          **kwargs):
        """Load samples for a problem and yield the loader's results.

        Every ``(c_samples, c_problems, data_container)`` triple produced by
        ``generator_context_load_and_convert`` is yielded unchanged.
        ``path_dir_tmp`` is accepted but not used by this method.

        :raises StopSamplingException: if a memory limit is configured and
            the accumulated memory growth already exceeds it.
        """
        assert len(kwargs) == 0, ("Unknown parameters for %s: %s" %
                                  (self.__class__.__name__,
                                   ", ".join(kwargs.keys())))

        track_memory = self._max_mem is not None
        if track_memory:
            if self._cur_mem > self._max_mem:
                raise StopSamplingException(
                    "SamplingBridge memory limit reached.")
            process = psutil.Process(os.getpid())
            base_memory = process.memory_info().rss

        # Without `provide` no container is used at all.
        data_container = (
            None if not self._provide else self.get_default_container(
                data_container, path_problem,
                prune=self._prune, fprune=self._fprune))

        if self._maximum_container_samples is None:
            max_samples = None
        else:
            # A missing container (provide=False) holds no samples yet;
            # calling len(None) here would raise a TypeError.
            present = 0 if data_container is None else len(data_container)
            max_samples = self._maximum_container_samples - present
            if max_samples <= 0:
                # Container is already full: nothing may be loaded.
                return

        # Only the memory budget still left is handed to the loader.
        remaining_memory = (self._max_mem - self._cur_mem
                            if track_memory else None)

        for c_samples, c_problems, data_container in (
                generator_context_load_and_convert(
                    self._streams, data_container,
                    format=self._format, field_filter=self._fields,
                    prune=self._prune,
                    parse_kwargs=self._parse_kwargs,
                    unparse_kwargs=self._unparse_kwargs,
                    path_problem=path_problem, path_domain=path_domain,
                    skip=self._skip, skip_magic=self._skip_magic,
                    write_context=self._write_streams,
                    forget=self._forget,
                    domain_properties=self._domain_properties,
                    max_memory=remaining_memory,
                    sample_types=self._sample_types,
                    samples_per_problem=self._samples_per_problem,
                    max_samples=max_samples,
                    batch_size=batch_size,
                    skip_lines=skip_lines,
                    count_lines=count_lines,
                    reference_states=self._reference_states)):
            yield c_samples, c_problems, data_container

        # NOTE: the statements below only run once the loader is exhausted;
        # they are skipped if the consumer abandons this generator early.
        field_parser.clear_caches()

        if track_memory:
            # Book the RSS growth of this run against the memory budget.
            self._cur_mem += process.memory_info().rss - base_memory

    def _finalize(self):
        """Finalize the streams the samples are loaded from."""
        self._streams.finalize()

    @staticmethod
    def parse(tree, item_cache):
        """Build a LoadSampleBridge from a parse tree via the parser helper."""
        return parser.try_whole_obj_parse_process(tree, item_cache,
                                                  LoadSampleBridge)


# Register LoadSampleBridge in the main register under the name "loadbridge".
main_register.append_register(LoadSampleBridge, "loadbridge")
