Module mici.utils
Utility functions and classes.
Expand source code Browse git
"""Utility functions and classes."""
import numpy as np
from math import log, exp, log1p, expm1, inf, nan
import mici
# xxhash is an optional dependency: record whether the import succeeded so
# that code using it can fall back to an alternative when it is missing.
try:
    import xxhash
except ImportError:
    XXHASH_AVAILABLE = False
else:
    XXHASH_AVAILABLE = True
# tqdm is an optional dependency: the logging helpers below are only defined
# when it can be imported.
try:
    import tqdm
    import logging

    class TqdmHandler(logging.StreamHandler):
        """Simple log handler which uses tqdm write method."""

        def __init__(self):
            super().__init__()

        def emit(self, record):
            """Format `record` and output the message with `tqdm.tqdm.write`.

            Args:
                record (logging.LogRecord): Log record to output.
            """
            msg = self.format(record)
            tqdm.tqdm.write(msg)

    def setup_tqdm_logger():
        """Returns a logger which redirects log output to `tqdm.write`.

        A `TqdmHandler` is attached to the root logger only if the root logger
        does not already have one, so that calling this function repeatedly
        does not cause each log message to be written multiple times.

        Returns:
            logging.Logger: The root logger.
        """
        logger = logging.getLogger()
        already_attached = any(
            isinstance(handler, TqdmHandler) for handler in logger.handlers)
        if not already_attached:
            logger.addHandler(TqdmHandler())
        return logger

except ImportError:
    pass
# arviz is an optional dependency: the conversion helper below is only defined
# when it can be imported.
try:
    import arviz

    def convert_to_arviz_inference_data(
            traces, chain_stats, sample_stats_key=None):
        """Wrap chain outputs in an `arviz.InferenceData` container object.

        The `traces` and `chain_stats` arguments should correspond to a
        multiple-chain sampler output i.e. the returned values from a
        `sample_chains` call.

        Args:
            traces (Dict[str, List[array]]): Trace arrays, with one entry per
                function in `trace_funcs` passed to sampler method. Each entry
                consists of a list of arrays, one per chain, with the first
                axes of the arrays corresponding to the sampling (draw) index.
            chain_stats (Dict[str, List[array]]): Chain integration transition
                statistics as a dictionary with string keys describing the
                statistics recorded and values corresponding to a list of
                arrays with one array per chain and the first axis of the
                arrays corresponding to the sampling index. A dictionary whose
                values are themselves such statistics dictionaries (one per
                transition) is also accepted. This dictionary is not modified.
            sample_stats_key (str): Optional. Key of transition in
                `chain_stats` to use the recorded statistics of to populate the
                `sample_stats` group in the returned `InferenceData` object.

        Returns:
            arviz.InferenceData:
                An arviz data container with groups `posterior` and
                `sample_stats`, both instances of `xarray.Dataset`. The
                `posterior` group corresponds to the chain variable traces
                provided in the `traces` argument and the `sample_stats`
                group corresponds to the chain transitions statistics passed
                in the `chain_stats` argument (if multiple transition
                statistics dictionaries are present the `sample_stats_key`
                argument should be specified to indicate which to use).

        Raises:
            ValueError: If `sample_stats_key` is specified but is not a key of
                `chain_stats`, or if it is not specified and `chain_stats`
                contains more than one nested statistics dictionary.
        """
        if sample_stats_key is not None:
            if sample_stats_key not in chain_stats:
                raise ValueError(
                    f'Specified `sample_stats_key` ({sample_stats_key}) does '
                    f'not match any transition in `chain_stats`.')
            sample_stats = chain_stats[sample_stats_key]
        else:
            first_value = next(iter(chain_stats.values()))
            if not isinstance(first_value, dict):
                # chain_stats dictionary value not another dictionary
                # therefore assume corresponds to statistics for a single
                # transition
                sample_stats = chain_stats
            elif len(chain_stats) == 1:
                # single transition statistics dictionary in chain_stats
                # therefore unambiguous to set sample_stats; the value is read
                # rather than popped so the caller's dictionary is left intact
                sample_stats = first_value
            else:
                raise ValueError(
                    '`sample_stats_key` must be specified as `chain_stats` '
                    'contains multiple transtitiion statistics dictionaries.')
        return arviz.InferenceData(
            posterior=arviz.dict_to_dataset(traces, library=mici),
            sample_stats=arviz.dict_to_dataset(sample_stats, library=mici))

except ImportError:
    pass
def hash_array(array):
    """Compute hash of a NumPy array by hashing data as a byte sequence.

    Args:
        array (array): NumPy array to compute hash of.

    Returns:
        hash (int): Computed hash as an integer.
    """
    if not XXHASH_AVAILABLE:
        # Fall back to the built-in hash function evaluated on a *copy* of the
        # array data as a byte sequence
        return hash(array.tobytes())
    # Fast Python wrapper of xxhash implementation is available so use it in
    # preference to the built-in hash function
    hasher = xxhash.xxh64()
    # View the array as a byte sequence to update the hash - no copy required
    hasher.update(array.view(np.byte).data)
    # Also mix in the array dtype, shape and strides to avoid clashes between
    # different views of the same array
    layout = f'{array.dtype}{array.shape}{array.strides}'
    hasher.update(bytes(layout, 'utf-8'))
    return hasher.intdigest()
# Natural logarithm of two, used as the branch threshold in `log1m_exp`.
LOG_2 = log(2.)


def log1p_exp(val):
    """Numerically stable implementation of `log(1 + exp(val))`.

    Args:
        val (float): Value to evaluate function at.

    Returns:
        float: Value of `log(1 + exp(val))`.
    """
    if val > 0.:
        # Rewrite as val + log(1 + exp(-val)) so that the exponential is only
        # ever evaluated at a non-positive value and so cannot overflow
        return val + log1p(exp(-val))
    else:
        return log1p(exp(val))


def log1m_exp(val):
    """Numerically stable implementation of `log(1 - exp(val))`.

    Args:
        val (float): Value to evaluate function at.

    Returns:
        float: Value of `log(1 - exp(val))` if `val` is negative, otherwise
            `nan`.
    """
    if val >= 0.:
        return nan
    elif val > -LOG_2:
        # For val in (-log(2), 0) the value exp(val) is close to one, so
        # computing 1 - exp(val) directly would lose precision (and round to
        # exactly zero for val sufficiently close to zero); expm1 evaluates
        # exp(val) - 1 accurately in this range
        return log(-expm1(val))
    else:
        # For val <= -log(2) the value exp(val) is at most one half and log1p
        # accurately evaluates log(1 - exp(val))
        return log1p(-exp(val))
def log_sum_exp(val1, val2):
    """Numerically stable implementation of `log(exp(val1) + exp(val2))`.

    Args:
        val1 (float): Logarithm of first term in sum.
        val2 (float): Logarithm of second term in sum.

    Returns:
        float: Value of `log(exp(val1) + exp(val2))`. If both arguments are
            the same infinite value that value is returned, so for example
            the sum of two zeros (both logarithms `-inf`) is `-inf`.
    """
    if val1 > val2:
        larger, smaller = val1, val2
    else:
        larger, smaller = val2, val1
    if larger == smaller and (larger == inf or larger == -inf):
        # The difference of two equal infinities is nan, which would otherwise
        # propagate to the result
        return larger
    # Factor out the larger term so that the exponential is only evaluated at
    # a non-positive value and so cannot overflow
    return larger + log1p(exp(smaller - larger))
def log_diff_exp(val1, val2):
    """Numerically stable implementation of `log(exp(val1) - exp(val2))`.

    Args:
        val1 (float): Logarithm of term subtracted from.
        val2 (float): Logarithm of term being subtracted.

    Returns:
        float: `nan` if `val1` is less than `val2`, `-inf` if the two values
            are equal and otherwise `val1 + log1m_exp(val2 - val1)`.
    """
    if val1 < val2:
        # Difference of exponentials is negative so logarithm is undefined
        return nan
    if val1 == val2:
        # Difference of exponentials is exactly zero
        return -inf
    log_ratio = val2 - val1
    return val1 + log1m_exp(log_ratio)
class LogRepFloat(object):
    """Numerically stable logarithmic representation of positive float values.

    Stores logarithm of value and overloads arithmetic operators to use more
    numerically stable implementations where possible.

    Addition, multiplication and division of two `LogRepFloat` instances are
    evaluated on the stored logarithms and return a `LogRepFloat`, as is
    subtraction when the result is non-negative. Operations with any other
    operand type are evaluated using the `val` property and return the result
    of that ordinary arithmetic.

    Attributes:
        log_val (float): Logarithm of the represented value.
    """

    def __init__(self, val=None, log_val=None):
        """
        Args:
            val (float): Non-negative value to represent. Exactly one of
                `val` and `log_val` must be specified.
            log_val (float): Logarithm of value to represent.

        Raises:
            ValueError: If neither or both of `val` and `log_val` are
                specified, or if `val` is not non-negative.
        """
        if log_val is None:
            if val is None:
                raise ValueError('One of val or log_val must be specified.')
            elif val > 0:
                self.log_val = log(val)
            elif val == 0.:
                self.log_val = -inf
            else:
                raise ValueError('val must be non-negative.')
        else:
            if val is not None:
                raise ValueError('Specify only one of val and log_val.')
            else:
                self.log_val = log_val

    @property
    def val(self):
        """Represented value, or `inf` if too large to store as a float."""
        try:
            return exp(self.log_val)
        except OverflowError:
            return inf

    def _comparands(self, other):
        """Return pair of values by which to compare `self` and `other`."""
        if isinstance(other, LogRepFloat):
            return self.log_val, other.log_val
        return self.val, other

    def __add__(self, other):
        if isinstance(other, LogRepFloat):
            return LogRepFloat(
                log_val=log_sum_exp(self.log_val, other.log_val))
        else:
            return self.val + other

    def __radd__(self, other):
        return self.__add__(other)

    def __iadd__(self, other):
        if isinstance(other, LogRepFloat):
            self.log_val = log_sum_exp(self.log_val, other.log_val)
        elif other == 0:
            # Adding zero leaves the value unchanged; handled explicitly as
            # log(0) raises an exception rather than evaluating to -inf
            pass
        else:
            self.log_val = log_sum_exp(self.log_val, log(other))
        return self

    def __sub__(self, other):
        if isinstance(other, LogRepFloat):
            if self.log_val >= other.log_val:
                return LogRepFloat(
                    log_val=log_diff_exp(self.log_val, other.log_val))
            else:
                # Result is negative so cannot be represented by logarithm
                return self.val - other.val
        else:
            return self.val - other

    def __rsub__(self, other):
        return other - self.val

    def __mul__(self, other):
        if isinstance(other, LogRepFloat):
            return LogRepFloat(log_val=self.log_val + other.log_val)
        else:
            return self.val * other

    def __rmul__(self, other):
        return self.__mul__(other)

    def __truediv__(self, other):
        if isinstance(other, LogRepFloat):
            return LogRepFloat(log_val=self.log_val - other.log_val)
        else:
            return self.val / other

    def __rtruediv__(self, other):
        return other / self.val

    def __neg__(self):
        return -self.val

    def __eq__(self, other):
        lhs, rhs = self._comparands(other)
        return lhs == rhs

    def __ne__(self, other):
        lhs, rhs = self._comparands(other)
        return lhs != rhs

    def __lt__(self, other):
        lhs, rhs = self._comparands(other)
        return lhs < rhs

    def __gt__(self, other):
        lhs, rhs = self._comparands(other)
        return lhs > rhs

    def __le__(self, other):
        lhs, rhs = self._comparands(other)
        return lhs <= rhs

    def __ge__(self, other):
        lhs, rhs = self._comparands(other)
        return lhs >= rhs

    def __str__(self):
        return str(self.val)

    def __repr__(self):
        return 'LogRepFloat(val={0})'.format(self.val)
Functions
def convert_to_arviz_inference_data(traces, chain_stats, sample_stats_key=None)
-
Wrap chain outputs in an `arviz.InferenceData` container object.
The `traces` and `chain_stats` arguments should correspond to a multiple-chain sampler output, i.e. the returned values from a `sample_chains` call.
Args
traces : Dict[str, List[array]]
- Trace arrays, with one entry per function in `trace_funcs` passed to sampler method. Each entry consists of a list of arrays, one per chain, with the first axes of the arrays corresponding to the sampling (draw) index.
chain_stats : Dict[str, List[array]]
- Chain integration transition statistics as a dictionary with string keys describing the statistics recorded and values corresponding to a list of arrays with one array per chain and the first axis of the arrays corresponding to the sampling index.
sample_stats_key : str
- Optional. Key of the transition in `chain_stats` whose recorded statistics are used to populate the `sample_stats` group in the returned `InferenceData` object.
Returns
arviz.InferenceData
- An arviz data container with groups `posterior` and `sample_stats`, both instances of `xarray.Dataset`. The `posterior` group corresponds to the chain variable traces provided in the `traces` argument and the `sample_stats` group corresponds to the chain transition statistics passed in the `chain_stats` argument (if multiple transition statistics dictionaries are present the `sample_stats_key` argument should be specified to indicate which to use).
Expand source code Browse git
def convert_to_arviz_inference_data(
        traces, chain_stats, sample_stats_key=None):
    """Wrap chain outputs in an `arviz.InferenceData` container object.

    The `traces` and `chain_stats` arguments should correspond to a
    multiple-chain sampler output i.e. the returned values from a
    `sample_chains` call.

    Args:
        traces (Dict[str, List[array]]): Trace arrays, with one entry per
            function in `trace_funcs` passed to sampler method. Each entry
            consists of a list of arrays, one per chain, with the first
            axes of the arrays corresponding to the sampling (draw) index.
        chain_stats (Dict[str, List[array]]): Chain integration transition
            statistics as a dictionary with string keys describing the
            statistics recorded and values corresponding to a list of
            arrays with one array per chain and the first axis of the
            arrays corresponding to the sampling index. A dictionary whose
            values are themselves such statistics dictionaries (one per
            transition) is also accepted. This dictionary is not modified.
        sample_stats_key (str): Optional. Key of transition in
            `chain_stats` to use the recorded statistics of to populate the
            `sample_stats` group in the returned `InferenceData` object.

    Returns:
        arviz.InferenceData:
            An arviz data container with groups `posterior` and
            `sample_stats`, both instances of `xarray.Dataset`. The
            `posterior` group corresponds to the chain variable traces
            provided in the `traces` argument and the `sample_stats`
            group corresponds to the chain transitions statistics passed
            in the `chain_stats` argument (if multiple transition
            statistics dictionaries are present the `sample_stats_key`
            argument should be specified to indicate which to use).

    Raises:
        ValueError: If `sample_stats_key` is specified but is not a key of
            `chain_stats`, or if it is not specified and `chain_stats`
            contains more than one nested statistics dictionary.
    """
    if sample_stats_key is not None:
        if sample_stats_key not in chain_stats:
            raise ValueError(
                f'Specified `sample_stats_key` ({sample_stats_key}) does '
                f'not match any transition in `chain_stats`.')
        sample_stats = chain_stats[sample_stats_key]
    else:
        first_value = next(iter(chain_stats.values()))
        if not isinstance(first_value, dict):
            # chain_stats dictionary value not another dictionary therefore
            # assume corresponds to statistics for a single transition
            sample_stats = chain_stats
        elif len(chain_stats) == 1:
            # single transition statistics dictionary in chain_stats
            # therefore unambiguous to set sample_stats; the value is read
            # rather than popped so the caller's dictionary is left intact
            sample_stats = first_value
        else:
            raise ValueError(
                '`sample_stats_key` must be specified as `chain_stats` '
                'contains multiple transtitiion statistics dictionaries.')
    return arviz.InferenceData(
        posterior=arviz.dict_to_dataset(traces, library=mici),
        sample_stats=arviz.dict_to_dataset(sample_stats, library=mici))
def hash_array(array)
-
Compute hash of a NumPy array by hashing data as a byte sequence.
Args
array : array
- NumPy array to compute hash of.
Returns
hash : int
- Computed hash as an integer.
Expand source code Browse git
def hash_array(array):
    """Compute hash of a NumPy array by hashing data as a byte sequence.

    Args:
        array (array): NumPy array to compute hash of.

    Returns:
        hash (int): Computed hash as an integer.
    """
    if not XXHASH_AVAILABLE:
        # Fall back to the built-in hash function evaluated on a *copy* of the
        # array data as a byte sequence
        return hash(array.tobytes())
    # Fast Python wrapper of xxhash implementation is available so use it in
    # preference to the built-in hash function
    hasher = xxhash.xxh64()
    # View the array as a byte sequence to update the hash - no copy required
    hasher.update(array.view(np.byte).data)
    # Also mix in the array dtype, shape and strides to avoid clashes between
    # different views of the same array
    layout = f'{array.dtype}{array.shape}{array.strides}'
    hasher.update(bytes(layout, 'utf-8'))
    return hasher.intdigest()
def log1p_exp(val)
-
Numerically stable implementation of `log(1 + exp(val))`.
Expand source code Browse git
def log1p_exp(val):
    """Numerically stable implementation of `log(1 + exp(val))`.

    Args:
        val (float): Value to evaluate function at.

    Returns:
        float: Value of `log(1 + exp(val))`.
    """
    if val > 0.:
        # Rewrite as val + log(1 + exp(-val)) so that the exponential is only
        # ever evaluated at a non-positive value and so cannot overflow
        return val + log1p(exp(-val))
    return log1p(exp(val))
def log1m_exp(val)
-
Numerically stable implementation of `log(1 - exp(val))`.
Expand source code Browse git
def log1m_exp(val):
    """Numerically stable implementation of `log(1 - exp(val))`.

    Args:
        val (float): Value to evaluate function at.

    Returns:
        float: Value of `log(1 - exp(val))` if `val` is negative, otherwise
            `nan`.
    """
    if val >= 0.:
        return nan
    elif val > -LOG_2:
        # For val in (-log(2), 0) the value exp(val) is close to one, so
        # computing 1 - exp(val) directly would lose precision (and round to
        # exactly zero for val sufficiently close to zero); expm1 evaluates
        # exp(val) - 1 accurately in this range
        return log(-expm1(val))
    else:
        # For val <= -log(2) the value exp(val) is at most one half and log1p
        # accurately evaluates log(1 - exp(val))
        return log1p(-exp(val))
def log_sum_exp(val1, val2)
-
Numerically stable implementation of `log(exp(val1) + exp(val2))`.
Expand source code Browse git
def log_sum_exp(val1, val2):
    """Numerically stable implementation of `log(exp(val1) + exp(val2))`.

    Args:
        val1 (float): Logarithm of first term in sum.
        val2 (float): Logarithm of second term in sum.

    Returns:
        float: Value of `log(exp(val1) + exp(val2))`. If both arguments are
            the same infinite value that value is returned, so for example
            the sum of two zeros (both logarithms `-inf`) is `-inf`.
    """
    if val1 > val2:
        larger, smaller = val1, val2
    else:
        larger, smaller = val2, val1
    if larger == smaller and (larger == inf or larger == -inf):
        # The difference of two equal infinities is nan, which would otherwise
        # propagate to the result
        return larger
    # Factor out the larger term so that the exponential is only evaluated at
    # a non-positive value and so cannot overflow
    return larger + log1p(exp(smaller - larger))
def log_diff_exp(val1, val2)
-
Numerically stable implementation of `log(exp(val1) - exp(val2))`.
Expand source code Browse git
def log_diff_exp(val1, val2):
    """Numerically stable implementation of `log(exp(val1) - exp(val2))`.

    Args:
        val1 (float): Logarithm of term subtracted from.
        val2 (float): Logarithm of term being subtracted.

    Returns:
        float: `nan` if `val1` is less than `val2`, `-inf` if the two values
            are equal and otherwise `val1 + log1m_exp(val2 - val1)`.
    """
    if val1 < val2:
        # Difference of exponentials is negative so logarithm is undefined
        return nan
    if val1 == val2:
        # Difference of exponentials is exactly zero
        return -inf
    log_ratio = val2 - val1
    return val1 + log1m_exp(log_ratio)
Classes
class LogRepFloat (val=None, log_val=None)
-
Numerically stable logarithmic representation of positive float values.
Stores logarithm of value and overloads arithmetic operators to use more numerically stable implementations where possible.
Expand source code Browse git
class LogRepFloat(object):
    """Numerically stable logarithmic representation of positive float values.

    Stores logarithm of value and overloads arithmetic operators to use more
    numerically stable implementations where possible.

    Addition, multiplication and division of two `LogRepFloat` instances are
    evaluated on the stored logarithms and return a `LogRepFloat`, as is
    subtraction when the result is non-negative. Operations with any other
    operand type are evaluated using the `val` property and return the result
    of that ordinary arithmetic.

    Attributes:
        log_val (float): Logarithm of the represented value.
    """

    def __init__(self, val=None, log_val=None):
        """
        Args:
            val (float): Non-negative value to represent. Exactly one of
                `val` and `log_val` must be specified.
            log_val (float): Logarithm of value to represent.

        Raises:
            ValueError: If neither or both of `val` and `log_val` are
                specified, or if `val` is not non-negative.
        """
        if log_val is None:
            if val is None:
                raise ValueError('One of val or log_val must be specified.')
            elif val > 0:
                self.log_val = log(val)
            elif val == 0.:
                self.log_val = -inf
            else:
                raise ValueError('val must be non-negative.')
        else:
            if val is not None:
                raise ValueError('Specify only one of val and log_val.')
            else:
                self.log_val = log_val

    @property
    def val(self):
        """Represented value, or `inf` if too large to store as a float."""
        try:
            return exp(self.log_val)
        except OverflowError:
            return inf

    def _comparands(self, other):
        """Return pair of values by which to compare `self` and `other`."""
        if isinstance(other, LogRepFloat):
            return self.log_val, other.log_val
        return self.val, other

    def __add__(self, other):
        if isinstance(other, LogRepFloat):
            return LogRepFloat(
                log_val=log_sum_exp(self.log_val, other.log_val))
        else:
            return self.val + other

    def __radd__(self, other):
        return self.__add__(other)

    def __iadd__(self, other):
        if isinstance(other, LogRepFloat):
            self.log_val = log_sum_exp(self.log_val, other.log_val)
        elif other == 0:
            # Adding zero leaves the value unchanged; handled explicitly as
            # log(0) raises an exception rather than evaluating to -inf
            pass
        else:
            self.log_val = log_sum_exp(self.log_val, log(other))
        return self

    def __sub__(self, other):
        if isinstance(other, LogRepFloat):
            if self.log_val >= other.log_val:
                return LogRepFloat(
                    log_val=log_diff_exp(self.log_val, other.log_val))
            else:
                # Result is negative so cannot be represented by logarithm
                return self.val - other.val
        else:
            return self.val - other

    def __rsub__(self, other):
        return other - self.val

    def __mul__(self, other):
        if isinstance(other, LogRepFloat):
            return LogRepFloat(log_val=self.log_val + other.log_val)
        else:
            return self.val * other

    def __rmul__(self, other):
        return self.__mul__(other)

    def __truediv__(self, other):
        if isinstance(other, LogRepFloat):
            return LogRepFloat(log_val=self.log_val - other.log_val)
        else:
            return self.val / other

    def __rtruediv__(self, other):
        return other / self.val

    def __neg__(self):
        return -self.val

    def __eq__(self, other):
        lhs, rhs = self._comparands(other)
        return lhs == rhs

    def __ne__(self, other):
        lhs, rhs = self._comparands(other)
        return lhs != rhs

    def __lt__(self, other):
        lhs, rhs = self._comparands(other)
        return lhs < rhs

    def __gt__(self, other):
        lhs, rhs = self._comparands(other)
        return lhs > rhs

    def __le__(self, other):
        lhs, rhs = self._comparands(other)
        return lhs <= rhs

    def __ge__(self, other):
        lhs, rhs = self._comparands(other)
        return lhs >= rhs

    def __str__(self):
        return str(self.val)

    def __repr__(self):
        return 'LogRepFloat(val={0})'.format(self.val)
Instance variables
var val
-
Expand source code Browse git
@property
def val(self):
    """Value represented, evaluated as the exponential of `log_val`.

    Evaluates to `inf` if the exponential overflows.
    """
    try:
        value = exp(self.log_val)
    except OverflowError:
        value = inf
    return value