Module mici.utils

Utility functions and classes.

Expand source code Browse git
"""Utility functions and classes."""

import numpy as np
from math import log, exp, log1p, expm1, inf, nan
import mici
try:
    import xxhash
    XXHASH_AVAILABLE = True
except ImportError:
    XXHASH_AVAILABLE = False

try:

    import tqdm
    import logging

    class TqdmHandler(logging.StreamHandler):
        """Simple log handler which uses tqdm write method."""

        def __init__(self):
            super().__init__()

        def emit(self, record):
            msg = self.format(record)
            tqdm.tqdm.write(msg)

    def setup_tqdm_logger():
        """Returns a logger which redirects log output to `tqdm.write`."""
        logger = logging.getLogger()
        handler = TqdmHandler()
        logger.addHandler(handler)
        return logger

except ImportError:
    pass


try:

    import arviz

    def convert_to_arviz_inference_data(
            traces, chain_stats, sample_stats_key=None):
        """Wrap chain outputs in an `arviz.InferenceData` container object.

        The `traces` and `chain_stats` arguments should correspond to a
        multiple-chain sampler output i.e. the returned values from a
        `sample_chains` call.

        Args:
            traces (Dict[str, List[array]]): Trace arrays, with one entry per
                function in `trace_funcs` passed to sampler method. Each entry
                consists of a list of arrays, one per chain, with the first
                axes of the arrays corresponding to the sampling (draw) index.
            chain_stats (Dict[str, List[array]]): Chain integration transition
                statistics as a dictionary with string keys describing the
                statistics recorded and values corresponding to a list of
                arrays with one array per chain and the first axis of the
                arrays corresponding to the sampling index.
            sample_stats_key (str): Optional. Key of transition in
                `chain_stats` to use the recorded statistics of to populate the
                `sampling_stats` group in the returned `InferenceData` object.

        Returns:
            arviz.InferenceData:
                An arviz data container with groups `posterior` and
                'sample_stats', both of instances of `xarray.Dataset`. The
                `posterior` group corresponds to the chain variable traces
                provides in the `traces` argument and the `sample_stats`
                group corresponds to the chain transitions statistics passed
                in the `chain_stats` argument (if multiple transition
                statistics dictionaries are present the `sample_stats_key`
                argument should be specified to indicate which to use).
        """
        if (sample_stats_key is not None and
                sample_stats_key not in chain_stats):
            raise ValueError(
                f'Specified `sample_stats_key` ({sample_stats_key}) does '
                f'not match any transition in `chain_stats`.')
        if sample_stats_key is not None:
            return arviz.InferenceData(
                posterior=arviz.dict_to_dataset(traces, library=mici),
                sample_stats=arviz.dict_to_dataset(
                    chain_stats[sample_stats_key], library=mici))
        elif not isinstance(next(iter(chain_stats.values())), dict):
            # chain_stats dictionary value not another dictionary therefore
            # assume corresponds to statistics for a single transition
            return arviz.InferenceData(
                posterior=arviz.dict_to_dataset(traces, library=mici),
                sample_stats=arviz.dict_to_dataset(chain_stats, library=mici))
        elif len(chain_stats) == 1:
            # single transtition statistics dictionary in chain_stats therefore
            # unambiguous to set sample_stats
            return arviz.InferenceData(
                posterior=arviz.dict_to_dataset(traces, library=mici),
                sample_stats=arviz.dict_to_dataset(
                    chain_stats.popitem()[1], library=mici))
        else:
            raise ValueError(
                '`sample_stats_key` must be specified as `chain_stats` '
                'contains multiple transtitiion statistics dictionaries.')

except ImportError:
    pass


def hash_array(array):
    """Compute hash of a NumPy array by hashing data as a byte sequence.

    Args:
        array (array): NumPy array to compute hash of.

    Returns:
        hash (int): Computed hash as an integer.
    """
    if XXHASH_AVAILABLE:
        # If fast Python wrapper of fast xxhash implementation is available use
        # in preference to built in hash function
        h = xxhash.xxh64()
        # Update hash by viewing array as byte sequence - no copy required
        h.update(array.view(np.byte).data)
        # Also update hash by array dtype, shape and strides to avoid clashes
        # between different views of same array
        h.update(bytes(f'{array.dtype}{array.shape}{array.strides}', 'utf-8'))
        return h.intdigest()
    else:
        # Evaluate built-in hash function on *copy* of data as a byte sequence
        return hash(array.tobytes())


LOG_2 = log(2.)


def log1p_exp(val):
    """Numerically stable implementation of `log(1 + exp(val))`."""
    if val > 0.:
        return val + log1p(exp(-val))
    else:
        return log1p(exp(val))


def log1m_exp(val):
    """Numerically stable implementation of `log(1 - exp(val))`."""
    if val >= 0.:
        return nan
    elif val > LOG_2:
        return log(-expm1(val))
    else:
        return log1p(-exp(val))


def log_sum_exp(val1, val2):
    """Numerically stable implementation of `log(exp(val1) + exp(val2))`."""
    if val1 > val2:
        return val1 + log1p_exp(val2 - val1)
    else:
        return val2 + log1p_exp(val1 - val2)


def log_diff_exp(val1, val2):
    """Numerically stable implementation of `log(exp(val1) - exp(val2))`."""
    if val1 < val2:
        return nan
    elif val1 == val2:
        return -inf
    else:
        return val1 + log1m_exp(val2 - val1)


class LogRepFloat(object):
    """Numerically stable logarithmic representation of positive float values.

    Stores logarithm of value and overloads arithmetic operators to use more
    numerically stable implementations where possible.
    """

    def __init__(self, val=None, log_val=None):
        if log_val is None:
            if val is None:
                raise ValueError('One of val or log_val must be specified.')
            elif val > 0:
                self.log_val = log(val)
            elif val == 0.:
                self.log_val = -inf
            else:
                raise ValueError('val must be non-negative.')
        else:
            if val is not None:
                raise ValueError('Specify only one of val and log_val.')
            else:
                self.log_val = log_val

    @property
    def val(self):
        try:
            return exp(self.log_val)
        except OverflowError:
            return inf

    def __add__(self, other):
        if isinstance(other, LogRepFloat):
            return LogRepFloat(
                log_val=log_sum_exp(self.log_val, other.log_val))
        else:
            return self.val + other

    def __radd__(self, other):
        return self.__add__(other)

    def __iadd__(self, other):
        if isinstance(other, LogRepFloat):
            self.log_val = log_sum_exp(self.log_val, other.log_val)
        else:
            self.log_val = log_sum_exp(self.log_val, log(other))
        return self

    def __sub__(self, other):
        if isinstance(other, LogRepFloat):
            if self.log_val >= other.log_val:
                return LogRepFloat(
                    log_val=log_diff_exp(self.log_val, other.log_val))
            else:
                return self.val - other.val
        else:
            return self.val - other

    def __rsub__(self, other):
        return (-self).__radd__(other)

    def __mul__(self, other):
        if isinstance(other, LogRepFloat):
            return LogRepFloat(log_val=self.log_val + other.log_val)
        else:
            return self.val * other

    def __rmul__(self, other):
        return self.__mul__(other)

    def __truediv__(self, other):
        if isinstance(other, LogRepFloat):
            return LogRepFloat(log_val=self.log_val - other.log_val)
        else:
            return self.val / other

    def __neg__(self):
        return -self.val

    def __eq__(self, other):
        if isinstance(other, LogRepFloat):
            return self.log_val == other.log_val
        else:
            return self.val == other

    def __ne__(self, other):
        if isinstance(other, LogRepFloat):
            return self.log_val != other.log_val
        else:
            return self.val != other

    def __lt__(self, other):
        if isinstance(other, LogRepFloat):
            return self.log_val < other.log_val
        else:
            return self.val < other

    def __gt__(self, other):
        if isinstance(other, LogRepFloat):
            return self.log_val > other.log_val
        else:
            return self.val > other

    def __le__(self, other):
        if isinstance(other, LogRepFloat):
            return self.log_val <= other.log_val
        else:
            return self.val <= other

    def __ge__(self, other):
        if isinstance(other, LogRepFloat):
            return self.log_val >= other.log_val
        else:
            return self.val >= other

    def __str__(self):
        return str(self.val)

    def __repr__(self):
        return 'LogRepFloat(val={0})'.format(self.val)

Functions

def convert_to_arviz_inference_data(traces, chain_stats, sample_stats_key=None)

Wrap chain outputs in an arviz.InferenceData container object.

The traces and chain_stats arguments should correspond to a multiple-chain sampler output i.e. the returned values from a sample_chains call.

Args

traces : Dict[str, List[array]]
Trace arrays, with one entry per function in trace_funcs passed to sampler method. Each entry consists of a list of arrays, one per chain, with the first axes of the arrays corresponding to the sampling (draw) index.
chain_stats : Dict[str, List[array]]
Chain integration transition statistics as a dictionary with string keys describing the statistics recorded and values corresponding to a list of arrays with one array per chain and the first axis of the arrays corresponding to the sampling index.
sample_stats_key : str
Optional. Key of transition in chain_stats to use the recorded statistics of to populate the sampling_stats group in the returned InferenceData object.

Returns

arviz.InferenceData:
An arviz data container with groups posterior and 'sample_stats', both of instances of xarray.Dataset. The posterior group corresponds to the chain variable traces provides in the traces argument and the sample_stats group corresponds to the chain transitions statistics passed in the chain_stats argument (if multiple transition statistics dictionaries are present the sample_stats_key argument should be specified to indicate which to use).
Expand source code Browse git
def convert_to_arviz_inference_data(
        traces, chain_stats, sample_stats_key=None):
    """Wrap chain outputs in an `arviz.InferenceData` container object.

    The `traces` and `chain_stats` arguments should correspond to a
    multiple-chain sampler output i.e. the returned values from a
    `sample_chains` call.

    Args:
        traces (Dict[str, List[array]]): Trace arrays, with one entry per
            function in `trace_funcs` passed to sampler method. Each entry
            consists of a list of arrays, one per chain, with the first
            axes of the arrays corresponding to the sampling (draw) index.
        chain_stats (Dict[str, List[array]]): Chain integration transition
            statistics as a dictionary with string keys describing the
            statistics recorded and values corresponding to a list of
            arrays with one array per chain and the first axis of the
            arrays corresponding to the sampling index.
        sample_stats_key (str): Optional. Key of transition in
            `chain_stats` to use the recorded statistics of to populate the
            `sampling_stats` group in the returned `InferenceData` object.

    Returns:
        arviz.InferenceData:
            An arviz data container with groups `posterior` and
            'sample_stats', both of instances of `xarray.Dataset`. The
            `posterior` group corresponds to the chain variable traces
            provides in the `traces` argument and the `sample_stats`
            group corresponds to the chain transitions statistics passed
            in the `chain_stats` argument (if multiple transition
            statistics dictionaries are present the `sample_stats_key`
            argument should be specified to indicate which to use).
    """
    if (sample_stats_key is not None and
            sample_stats_key not in chain_stats):
        raise ValueError(
            f'Specified `sample_stats_key` ({sample_stats_key}) does '
            f'not match any transition in `chain_stats`.')
    if sample_stats_key is not None:
        return arviz.InferenceData(
            posterior=arviz.dict_to_dataset(traces, library=mici),
            sample_stats=arviz.dict_to_dataset(
                chain_stats[sample_stats_key], library=mici))
    elif not isinstance(next(iter(chain_stats.values())), dict):
        # chain_stats dictionary value not another dictionary therefore
        # assume corresponds to statistics for a single transition
        return arviz.InferenceData(
            posterior=arviz.dict_to_dataset(traces, library=mici),
            sample_stats=arviz.dict_to_dataset(chain_stats, library=mici))
    elif len(chain_stats) == 1:
        # single transtition statistics dictionary in chain_stats therefore
        # unambiguous to set sample_stats
        return arviz.InferenceData(
            posterior=arviz.dict_to_dataset(traces, library=mici),
            sample_stats=arviz.dict_to_dataset(
                chain_stats.popitem()[1], library=mici))
    else:
        raise ValueError(
            '`sample_stats_key` must be specified as `chain_stats` '
            'contains multiple transtitiion statistics dictionaries.')
def hash_array(array)

Compute hash of a NumPy array by hashing data as a byte sequence.

Args

array : array
NumPy array to compute hash of.

Returns

hash : int
Computed hash as an integer.
Expand source code Browse git
def hash_array(array):
    """Compute hash of a NumPy array by hashing data as a byte sequence.

    Args:
        array (array): NumPy array to compute hash of.

    Returns:
        hash (int): Computed hash as an integer.
    """
    if XXHASH_AVAILABLE:
        # If fast Python wrapper of fast xxhash implementation is available use
        # in preference to built in hash function
        h = xxhash.xxh64()
        # Update hash by viewing array as byte sequence - no copy required
        h.update(array.view(np.byte).data)
        # Also update hash by array dtype, shape and strides to avoid clashes
        # between different views of same array
        h.update(bytes(f'{array.dtype}{array.shape}{array.strides}', 'utf-8'))
        return h.intdigest()
    else:
        # Evaluate built-in hash function on *copy* of data as a byte sequence
        return hash(array.tobytes())
def log1p_exp(val)

Numerically stable implementation of log(1 + exp(val)).

Expand source code Browse git
def log1p_exp(val):
    """Numerically stable implementation of `log(1 + exp(val))`."""
    if val > 0.:
        return val + log1p(exp(-val))
    else:
        return log1p(exp(val))
def log1m_exp(val)

Numerically stable implementation of log(1 - exp(val)).

Expand source code Browse git
def log1m_exp(val):
    """Numerically stable implementation of `log(1 - exp(val))`."""
    if val >= 0.:
        return nan
    elif val > LOG_2:
        return log(-expm1(val))
    else:
        return log1p(-exp(val))
def log_sum_exp(val1, val2)

Numerically stable implementation of log(exp(val1) + exp(val2)).

Expand source code Browse git
def log_sum_exp(val1, val2):
    """Numerically stable implementation of `log(exp(val1) + exp(val2))`."""
    if val1 > val2:
        return val1 + log1p_exp(val2 - val1)
    else:
        return val2 + log1p_exp(val1 - val2)
def log_diff_exp(val1, val2)

Numerically stable implementation of log(exp(val1) - exp(val2)).

Expand source code Browse git
def log_diff_exp(val1, val2):
    """Numerically stable implementation of `log(exp(val1) - exp(val2))`."""
    if val1 < val2:
        return nan
    elif val1 == val2:
        return -inf
    else:
        return val1 + log1m_exp(val2 - val1)

Classes

class LogRepFloat (val=None, log_val=None)

Numerically stable logarithmic representation of positive float values.

Stores logarithm of value and overloads arithmetic operators to use more numerically stable implementations where possible.

Expand source code Browse git
class LogRepFloat(object):
    """Numerically stable logarithmic representation of positive float values.

    Stores logarithm of value and overloads arithmetic operators to use more
    numerically stable implementations where possible.
    """

    def __init__(self, val=None, log_val=None):
        if log_val is None:
            if val is None:
                raise ValueError('One of val or log_val must be specified.')
            elif val > 0:
                self.log_val = log(val)
            elif val == 0.:
                self.log_val = -inf
            else:
                raise ValueError('val must be non-negative.')
        else:
            if val is not None:
                raise ValueError('Specify only one of val and log_val.')
            else:
                self.log_val = log_val

    @property
    def val(self):
        try:
            return exp(self.log_val)
        except OverflowError:
            return inf

    def __add__(self, other):
        if isinstance(other, LogRepFloat):
            return LogRepFloat(
                log_val=log_sum_exp(self.log_val, other.log_val))
        else:
            return self.val + other

    def __radd__(self, other):
        return self.__add__(other)

    def __iadd__(self, other):
        if isinstance(other, LogRepFloat):
            self.log_val = log_sum_exp(self.log_val, other.log_val)
        else:
            self.log_val = log_sum_exp(self.log_val, log(other))
        return self

    def __sub__(self, other):
        if isinstance(other, LogRepFloat):
            if self.log_val >= other.log_val:
                return LogRepFloat(
                    log_val=log_diff_exp(self.log_val, other.log_val))
            else:
                return self.val - other.val
        else:
            return self.val - other

    def __rsub__(self, other):
        return (-self).__radd__(other)

    def __mul__(self, other):
        if isinstance(other, LogRepFloat):
            return LogRepFloat(log_val=self.log_val + other.log_val)
        else:
            return self.val * other

    def __rmul__(self, other):
        return self.__mul__(other)

    def __truediv__(self, other):
        if isinstance(other, LogRepFloat):
            return LogRepFloat(log_val=self.log_val - other.log_val)
        else:
            return self.val / other

    def __neg__(self):
        return -self.val

    def __eq__(self, other):
        if isinstance(other, LogRepFloat):
            return self.log_val == other.log_val
        else:
            return self.val == other

    def __ne__(self, other):
        if isinstance(other, LogRepFloat):
            return self.log_val != other.log_val
        else:
            return self.val != other

    def __lt__(self, other):
        if isinstance(other, LogRepFloat):
            return self.log_val < other.log_val
        else:
            return self.val < other

    def __gt__(self, other):
        if isinstance(other, LogRepFloat):
            return self.log_val > other.log_val
        else:
            return self.val > other

    def __le__(self, other):
        if isinstance(other, LogRepFloat):
            return self.log_val <= other.log_val
        else:
            return self.val <= other

    def __ge__(self, other):
        if isinstance(other, LogRepFloat):
            return self.log_val >= other.log_val
        else:
            return self.val >= other

    def __str__(self):
        return str(self.val)

    def __repr__(self):
        return 'LogRepFloat(val={0})'.format(self.val)

Instance variables

var val
Expand source code Browse git
@property
def val(self):
    try:
        return exp(self.log_val)
    except OverflowError:
        return inf