# The MIT License (MIT)
# Copyright (c) 2014-2017 University of Bristol
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
# OR OTHER DEALINGS IN THE SOFTWARE.
from __future__ import absolute_import
from __future__ import print_function
from __future__ import unicode_literals
from future.utils import python_2_unicode_compatible
# import pprint
from treelib.tree import Tree, NodePropertyAbsentError, NodeIDAbsentError
import simplejson as json
from bidict import bidict, ValueDuplicationError
import logging
import os
import sys
# import coloredlogs
from .. import __version__
# Work around the Python 2 "UnicodeDecodeError: 'ascii' codec can't decode byte" error by forcing the
# process-wide default encoding to UTF-8.
# http://stackoverflow.com/questions/21129020/how-to-fix-unicodedecodeerror-ascii-codec-cant-decode-byte
# This is Python 2 only: on Python 3 `reload` is not a builtin and `sys.setdefaultencoding` does not exist
# (the default encoding is always UTF-8). Running these lines there would make importing this module fail.
if sys.version_info[0] < 3:
    reload(sys)  # noqa: F821 -- Python 2 builtin; makes sys.setdefaultencoding available again
    sys.setdefaultencoding('utf8')
# Applied exactly once: on Python 2 each application re-wraps __str__, so stacking the decorator
# would make __unicode__ return already-encoded bytes.
@python_2_unicode_compatible
class Printable(object):
    """
    A base class for default printing.

    Both ``str()`` and ``repr()`` render an instance as ``ClassName(attr1=..., attr2=...)``.
    The output lists the instance's public attributes (names not starting with an underscore),
    sorted by name, each value shown via ``repr``.
    """

    def __str__(self):
        return repr(self)

    def __repr__(self):
        name = self.__class__.__name__
        values = ", ".join(
            "{}={}".format(k, repr(v))
            for k, v in sorted(self.__dict__.items())
            # startswith (rather than k[0]) cannot raise on an empty attribute name
            if not k.startswith("_"))
        return "{}({})".format(name, values)
class Hashable(object):
    """
    A base class that creates hashes based on the __dict__.

    Requires keys to be strings to work properly. It will first try to use json.dumps, but if that fails because
    one of the values is not json serializable (e.g. datetime.datetime) then it will fall back on repr.

    NOTE(review): no ``__eq__`` is defined here, so two instances with equal hashes still compare by identity.
    """

    # Explicit display name; None means "fall back to the class name".
    _name = None

    @property
    def name(self):
        """The explicitly set name, or the class name if none has been set."""
        return self._name if self._name is not None else self.__class__.__name__

    @name.setter
    def name(self, name):
        # Coerce whatever is given (e.g. an int) to a string.
        self._name = str(name)

    def __hash__(self):
        """
        Hash the name together with a deterministic, key-sorted rendering of the instance ``__dict__``.

        :return: the hash value
        """
        try:
            return hash((self.name, json.dumps(self.__dict__, sort_keys=True)))
        except TypeError:
            # A value is not JSON serializable (e.g. datetime.datetime): fall back on repr of the sorted items.
            return hash((self.name, repr(sorted(self.__dict__.items()))))
class TypedBiDict(Printable):
    """
    Custom strongly typed bi-directional dictionary where keys and values must be a specific type.

    Raises ValueDuplicationError if the same value is added again.
    """

    def __init__(self, key_type, value_type, *args, **kwargs):
        """
        :param key_type: the type that every key must be an instance of
        :param value_type: the type that every value must be an instance of
        :param args: positional arguments forwarded to ``bidict``
        :param kwargs: keyword arguments forwarded to ``bidict``
        :raises ValueError: if ``key_type`` or ``value_type`` is not a type
        """
        if not isinstance(key_type, type):
            # The message is formatted here; previously the type was passed as a second exception argument.
            raise ValueError("expected type, got {}".format(type(key_type)))
        if not isinstance(value_type, type):
            raise ValueError("expected type, got {}".format(type(value_type)))
        # NOTE(review): any initial contents passed through args/kwargs are not type-checked.
        self._store = bidict(*args, **kwargs)
        self.key_type = key_type
        self.value_type = value_type

    def __repr__(self):
        return "{}(key_type={}, value_type={})".format(
            self.__class__.__name__,
            repr(self.key_type),
            repr(self.value_type))

    def __iter__(self):
        return iter(self._store)

    def __len__(self):
        return len(self._store)

    def __getitem__(self, key):
        """
        Return the value stored under ``key``.

        :raises TypeError: if ``key`` is not an instance of ``key_type``
        :raises KeyError: if ``key`` is not present
        """
        if not isinstance(key, self.key_type):
            raise TypeError("expected {}, got {}".format(self.key_type, type(key)))
        # Let the bidict's KeyError propagate untouched (keeps the original traceback on Python 2).
        return self._store[key]

    def __setitem__(self, key, value):
        """
        Store ``value`` under ``key``.

        :raises TypeError: if ``key`` is not an instance of ``key_type``
        :raises ValueError: if ``value`` is not an instance of ``value_type``
        :raises ValueDuplicationError: (from bidict) if the same value is added again
        """
        if not isinstance(key, self.key_type):
            raise TypeError("expected {}, got {}".format(self.key_type, type(key)))
        if not isinstance(value, self.value_type):
            raise ValueError("expected {}, got {}".format(self.value_type, type(value)))
        self._store[key] = value

    def __contains__(self, item):
        return item in self._store

    def keys(self):
        return self._store.keys()

    def values(self):
        return self._store.values()

    def items(self):
        return self._store.items()

    # The iter* helpers are built on the portable keys/values/items accessors. bidict may not
    # provide iterkeys/itervalues/iteritems on Python 3.
    def iterkeys(self):
        return iter(self._store.keys())

    def itervalues(self):
        return iter(self._store.values())

    def iteritems(self):
        return iter(self._store.items())
class FrozenKeyDict(dict):
    """
    A dict whose keys, once assigned via ``d[key] = value``, cannot be assigned a different value.

    If both the stored value and the new value are dicts, they are merged:
    - sub-keys not yet present are added to the stored dict;
    - sub-keys already present must have an equal value.

    Any conflict, or re-assigning a key whose old/new values are not both dicts, raises ``KeyError``.
    In that case the stored value is left unchanged.

    NOTE(review): only ``__setitem__`` is guarded; ``dict.__init__``, ``update`` and ``setdefault`` bypass it.
    """

    def __setitem__(self, key, value):
        if key not in self:
            super(FrozenKeyDict, self).__setitem__(key, value)
            return

        old = self[key]
        if not (isinstance(value, dict) and isinstance(old, dict)):
            raise KeyError("Key {} has already been set with value {}, new value {}".format(key, old, value))

        # Validate every sub-key before mutating anything, so a conflict cannot leave ``old`` half-merged.
        for k in value:
            if k in old and value[k] != old[k]:
                raise KeyError("Key {} has already been set with value {}, new value {}".format(key, old, value))

        # Item assignment (not dict.update) so a nested FrozenKeyDict still applies its own checks.
        for k in value:
            if k not in old:
                old[k] = value[k]
class TypedFrozenKeyDict(FrozenKeyDict):
    """
    A FrozenKeyDict that also requires every key to be an instance of one given type.

    The key type is enforced on ``d[key] = value``; a key of the wrong type raises ``KeyError``.
    NOTE(review): contents passed to the constructor go through ``dict.__init__`` and are not type-checked.
    """

    def __init__(self, key_type, *args, **kwargs):
        """
        :param key_type: the type that every key must be an instance of
        :param args: forwarded to the dict constructor
        :param kwargs: forwarded to the dict constructor
        :raises ValueError: if ``key_type`` is not a type
        """
        if isinstance(key_type, type):
            self.key_type = key_type
        else:
            raise ValueError("Expected type, got {}".format(type(key_type)))
        super(TypedFrozenKeyDict, self).__init__(*args, **kwargs)

    def __setitem__(self, key, value):
        if isinstance(key, self.key_type):
            super(TypedFrozenKeyDict, self).__setitem__(key, value)
            return
        raise KeyError("Expected type {}, got {}".format(self.key_type, type(key)))
def touch(full_name, times=None):
    """
    Create ``full_name`` if it is missing and set its access/modification times.

    The file is opened in append mode, so existing content is never truncated.

    :param full_name: path of the file to touch
    :param times: ``None`` for the current time, otherwise an ``(atime, mtime)`` pair passed to ``os.utime``
    """
    handle = open(full_name, 'a')
    try:
        os.utime(full_name, times)
    finally:
        handle.close()
def handle_exception(exc_type, exc_value, exc_traceback):
    """
    ``sys.excepthook`` replacement that logs uncaught exceptions at ERROR level.

    A ``KeyboardInterrupt`` (or subclass) is not logged; it is passed to the interpreter's original hook instead.

    :param exc_type: the exception class
    :param exc_value: the exception instance
    :param exc_traceback: the traceback object
    """
    details = (exc_type, exc_value, exc_traceback)
    if not issubclass(exc_type, KeyboardInterrupt):
        logging.error("Uncaught exception", exc_info=details)
    else:
        sys.__excepthook__(*details)
class HyperStreamLogger(Printable):
    """
    Configure the (process-wide) root logger for HyperStream.

    Creating an instance does the following:
    - sets the root logger level;
    - attaches a file handler writing to ``<path>/<filename>.log`` and a console handler, both sharing one formatter;
    - routes ``warnings`` through logging;
    - installs ``handle_exception`` as ``sys.excepthook``;
    - logs the HyperStream version at DEBUG level.

    Repeated instantiation does not attach duplicate handlers. The file handler is skipped when the root logger
    already has a ``FileHandler`` for the same file. Only one console handler created by this class is ever
    attached.
    """

    def __init__(self, path='/tmp', filename='hyperstream', loglevel=logging.DEBUG):
        """
        :param path: directory for the log file; created if it does not exist
        :param filename: log file name; ``.log`` is appended if missing
        :param loglevel: level set on the root logger
        """
        log_formatter = logging.Formatter("%(asctime)s [%(levelname)-5.5s] %(message)s")
        root_logger = logging.getLogger()
        root_logger.setLevel(loglevel)

        # Create the directory without a check-then-create race (makedirs has no exist_ok on Python 2).
        try:
            os.makedirs(path)
        except OSError:
            if not os.path.isdir(path):
                raise

        if not filename.endswith('.log'):
            filename += '.log'
        full_name = os.path.join(path, filename)
        touch(full_name)

        # Handlers live on the root logger, so without these checks every new instance would add another
        # pair and each record would be emitted several times.
        abs_name = os.path.abspath(full_name)
        has_file_handler = any(
            isinstance(h, logging.FileHandler) and getattr(h, 'baseFilename', None) == abs_name
            for h in root_logger.handlers)
        if not has_file_handler:
            file_handler = logging.FileHandler(full_name)
            file_handler.setFormatter(log_formatter)
            root_logger.addHandler(file_handler)

        has_console_handler = any(getattr(h, '_hyperstream_console', False) for h in root_logger.handlers)
        if not has_console_handler:
            console_handler = logging.StreamHandler()
            console_handler.setFormatter(log_formatter)
            # Marker so later instances recognise the console handler this class attached.
            console_handler._hyperstream_console = True
            root_logger.addHandler(console_handler)

        # Capture warnings
        logging.captureWarnings(True)

        # Capture uncaught exceptions
        sys.excepthook = handle_exception

        logging.debug("HyperStream version: %s", __version__)