numpyro/infer/mcmc.py
Killed 328 out of 537 mutants.
Timeouts
Mutants that made the test suite take much longer, so the tests were killed.
Mutant 83
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -282,7 +282,7 @@
vv_state_new = fori_loop(0, num_steps,
lambda i, val: vv_update(step_size, inverse_mass_matrix, val),
vv_state)
- energy_old = vv_state.potential_energy + kinetic_fn(inverse_mass_matrix, vv_state.r)
+ energy_old = vv_state.potential_energy - kinetic_fn(inverse_mass_matrix, vv_state.r)
energy_new = vv_state_new.potential_energy + kinetic_fn(inverse_mass_matrix, vv_state_new.r)
delta_energy = energy_new - energy_old
delta_energy = jnp.where(jnp.isnan(delta_energy), jnp.inf, delta_energy)
Mutant 87
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -284,7 +284,7 @@
vv_state)
energy_old = vv_state.potential_energy + kinetic_fn(inverse_mass_matrix, vv_state.r)
energy_new = vv_state_new.potential_energy + kinetic_fn(inverse_mass_matrix, vv_state_new.r)
- delta_energy = energy_new - energy_old
+ delta_energy = energy_new + energy_old
delta_energy = jnp.where(jnp.isnan(delta_energy), jnp.inf, delta_energy)
accept_prob = jnp.clip(jnp.exp(-delta_energy), a_max=1.0)
diverging = delta_energy > max_delta_energy
Mutant 135
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -454,7 +454,7 @@
adapt_step_size=True,
adapt_mass_matrix=True,
dense_mass=False,
- target_accept_prob=0.8,
+ target_accept_prob=1.8,
trajectory_length=2 * math.pi,
init_strategy=init_to_uniform,
find_heuristic_step_size=False):
Survived
These mutants survived mutation testing; they show holes in your test suite.
Mutant 1
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -33,7 +33,7 @@
from numpyro.infer.util import ParamInfo, init_to_uniform, initialize_model
from numpyro.util import cond, copy_docs_from, fori_collect, fori_loop, identity, cached_by
-HMCState = namedtuple('HMCState', ['i', 'z', 'z_grad', 'potential_energy', 'energy', 'num_steps', 'accept_prob',
+HMCState = namedtuple('XXHMCStateXX', ['i', 'z', 'z_grad', 'potential_energy', 'energy', 'num_steps', 'accept_prob',
'mean_accept_prob', 'diverging', 'adapt_state', 'rng_key'])
"""
A :func:`~collections.namedtuple` consisting of the following fields:
Mutant 6
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -33,7 +33,7 @@
from numpyro.infer.util import ParamInfo, init_to_uniform, initialize_model
from numpyro.util import cond, copy_docs_from, fori_collect, fori_loop, identity, cached_by
-HMCState = namedtuple('HMCState', ['i', 'z', 'z_grad', 'potential_energy', 'energy', 'num_steps', 'accept_prob',
+HMCState = namedtuple('HMCState', ['i', 'z', 'z_grad', 'potential_energy', 'XXenergyXX', 'num_steps', 'accept_prob',
'mean_accept_prob', 'diverging', 'adapt_state', 'rng_key'])
"""
A :func:`~collections.namedtuple` consisting of the following fields:
Mutant 8
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -33,7 +33,7 @@
from numpyro.infer.util import ParamInfo, init_to_uniform, initialize_model
from numpyro.util import cond, copy_docs_from, fori_collect, fori_loop, identity, cached_by
-HMCState = namedtuple('HMCState', ['i', 'z', 'z_grad', 'potential_energy', 'energy', 'num_steps', 'accept_prob',
+HMCState = namedtuple('HMCState', ['i', 'z', 'z_grad', 'potential_energy', 'energy', 'num_steps', 'XXaccept_probXX',
'mean_accept_prob', 'diverging', 'adapt_state', 'rng_key'])
"""
A :func:`~collections.namedtuple` consisting of the following fields:
Mutant 14
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -65,7 +65,7 @@
def _get_num_steps(step_size, trajectory_length):
- num_steps = jnp.clip(trajectory_length / step_size, a_min=1)
+ num_steps = jnp.clip(trajectory_length * step_size, a_min=1)
# NB: casting to jnp.int64 does not take effect (returns jnp.int32 instead)
# if jax_enable_x64 is False
return num_steps.astype(canonicalize_dtype(jnp.int64))
Mutant 15
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -65,7 +65,7 @@
def _get_num_steps(step_size, trajectory_length):
- num_steps = jnp.clip(trajectory_length / step_size, a_min=1)
+ num_steps = jnp.clip(trajectory_length / step_size, a_min=2)
# NB: casting to jnp.int64 does not take effect (returns jnp.int32 instead)
# if jax_enable_x64 is False
return num_steps.astype(canonicalize_dtype(jnp.int64))
Mutant 18
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -73,7 +73,7 @@
def momentum_generator(prototype_r, mass_matrix_sqrt, rng_key):
_, unpack_fn = ravel_pytree(prototype_r)
- eps = random.normal(rng_key, jnp.shape(mass_matrix_sqrt)[:1])
+ eps = random.normal(rng_key, jnp.shape(mass_matrix_sqrt)[:2])
if mass_matrix_sqrt.ndim == 1:
r = jnp.multiply(mass_matrix_sqrt, eps)
return unpack_fn(r)
Mutant 23
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -86,7 +86,7 @@
def get_diagnostics_str(mcmc_state):
if isinstance(mcmc_state, HMCState):
- return '{} steps of size {:.2e}. acc. prob={:.2f}'.format(mcmc_state.num_steps,
+ return 'XX{} steps of size {:.2e}. acc. prob={:.2f}XX'.format(mcmc_state.num_steps,
mcmc_state.adapt_state.step_size,
mcmc_state.mean_accept_prob)
else:
Mutant 24
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -94,7 +94,7 @@
def get_progbar_desc_str(num_warmup, i):
- if i < num_warmup:
+ if i <= num_warmup:
return 'warmup'
return 'sample'
Mutant 25
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -95,7 +95,7 @@
def get_progbar_desc_str(num_warmup, i):
if i < num_warmup:
- return 'warmup'
+ return 'XXwarmupXX'
return 'sample'
Mutant 26
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -96,7 +96,7 @@
def get_progbar_desc_str(num_warmup, i):
if i < num_warmup:
return 'warmup'
- return 'sample'
+ return 'XXsampleXX'
def hmc(potential_fn=None, potential_fn_gen=None, kinetic_fn=None, algo='NUTS'):
Mutant 27
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -99,7 +99,7 @@
return 'sample'
-def hmc(potential_fn=None, potential_fn_gen=None, kinetic_fn=None, algo='NUTS'):
+def hmc(potential_fn=None, potential_fn_gen=None, kinetic_fn=None, algo='XXNUTSXX'):
r"""
Hamiltonian Monte Carlo inference, using either fixed number of
steps or the No U-Turn Sampler (NUTS) with adaptive path length.
Mutant 30
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -171,7 +171,7 @@
"""
if kinetic_fn is None:
kinetic_fn = euclidean_kinetic_energy
- vv_update = None
+ vv_update = ""
trajectory_len = None
max_treedepth = None
wa_update = None
Mutant 31
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -172,7 +172,7 @@
if kinetic_fn is None:
kinetic_fn = euclidean_kinetic_energy
vv_update = None
- trajectory_len = None
+ trajectory_len = ""
max_treedepth = None
wa_update = None
wa_steps = None
Mutant 32
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -173,7 +173,7 @@
kinetic_fn = euclidean_kinetic_energy
vv_update = None
trajectory_len = None
- max_treedepth = None
+ max_treedepth = ""
wa_update = None
wa_steps = None
max_delta_energy = 1000.
Mutant 33
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -174,7 +174,7 @@
vv_update = None
trajectory_len = None
max_treedepth = None
- wa_update = None
+ wa_update = ""
wa_steps = None
max_delta_energy = 1000.
if algo not in {'HMC', 'NUTS'}:
Mutant 34
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -175,7 +175,7 @@
trajectory_len = None
max_treedepth = None
wa_update = None
- wa_steps = None
+ wa_steps = ""
max_delta_energy = 1000.
if algo not in {'HMC', 'NUTS'}:
raise ValueError('`algo` must be one of `HMC` or `NUTS`.')
Mutant 35
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -176,7 +176,7 @@
max_treedepth = None
wa_update = None
wa_steps = None
- max_delta_energy = 1000.
+ max_delta_energy = 1001.0
if algo not in {'HMC', 'NUTS'}:
raise ValueError('`algo` must be one of `HMC` or `NUTS`.')
Mutant 40
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -182,7 +182,7 @@
def init_kernel(init_params,
num_warmup,
- step_size=1.0,
+ step_size=2.0,
inverse_mass_matrix=None,
adapt_step_size=True,
adapt_mass_matrix=True,
Mutant 42
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -185,7 +185,7 @@
step_size=1.0,
inverse_mass_matrix=None,
adapt_step_size=True,
- adapt_mass_matrix=True,
+ adapt_mass_matrix=False,
dense_mass=False,
target_accept_prob=0.8,
trajectory_length=2*math.pi,
Mutant 43
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -186,7 +186,7 @@
inverse_mass_matrix=None,
adapt_step_size=True,
adapt_mass_matrix=True,
- dense_mass=False,
+ dense_mass=True,
target_accept_prob=0.8,
trajectory_length=2*math.pi,
max_tree_depth=10,
Mutant 45
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -188,7 +188,7 @@
adapt_mass_matrix=True,
dense_mass=False,
target_accept_prob=0.8,
- trajectory_length=2*math.pi,
+ trajectory_length=3*math.pi,
max_tree_depth=10,
find_heuristic_step_size=False,
model_args=(),
Mutant 46
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -188,7 +188,7 @@
adapt_mass_matrix=True,
dense_mass=False,
target_accept_prob=0.8,
- trajectory_length=2*math.pi,
+ trajectory_length=2/math.pi,
max_tree_depth=10,
find_heuristic_step_size=False,
model_args=(),
Mutant 47
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -189,7 +189,7 @@
dense_mass=False,
target_accept_prob=0.8,
trajectory_length=2*math.pi,
- max_tree_depth=10,
+ max_tree_depth=11,
find_heuristic_step_size=False,
model_args=(),
model_kwargs=None,
Mutant 48
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -190,7 +190,7 @@
target_accept_prob=0.8,
trajectory_length=2*math.pi,
max_tree_depth=10,
- find_heuristic_step_size=False,
+ find_heuristic_step_size=True,
model_args=(),
model_kwargs=None,
rng_key=random.PRNGKey(0)):
Mutant 49
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -193,7 +193,7 @@
find_heuristic_step_size=False,
model_args=(),
model_kwargs=None,
- rng_key=random.PRNGKey(0)):
+ rng_key=random.PRNGKey(1)):
"""
Initializes the HMC sampler.
Mutant 57
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -242,7 +242,7 @@
if pe_fn is not None:
raise ValueError('Only one of `potential_fn` or `potential_fn_gen` must be provided.')
else:
- kwargs = {} if model_kwargs is None else model_kwargs
+ kwargs = {} if model_kwargs is not None else model_kwargs
pe_fn = potential_fn_gen(*model_args, **kwargs)
find_reasonable_ss = None
Mutant 61
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -247,10 +247,7 @@
find_reasonable_ss = None
if find_heuristic_step_size:
- find_reasonable_ss = partial(find_reasonable_step_size,
- pe_fn,
- kinetic_fn,
- momentum_generator)
+ find_reasonable_ss = None
wa_init, wa_update = warmup_adapter(num_warmup,
adapt_step_size=adapt_step_size,
Mutant 72
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -268,7 +268,7 @@
vv_state = vv_init(z, r, potential_energy=pe, z_grad=z_grad)
energy = kinetic_fn(wa_state.inverse_mass_matrix, vv_state.r)
hmc_state = HMCState(0, vv_state.z, vv_state.z_grad, vv_state.potential_energy, energy,
- 0, 0., 0., False, wa_state, rng_key_hmc)
+ 1, 0., 0., False, wa_state, rng_key_hmc)
return device_put(hmc_state)
def _hmc_next(step_size, inverse_mass_matrix, vv_state,
Mutant 73
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -268,7 +268,7 @@
vv_state = vv_init(z, r, potential_energy=pe, z_grad=z_grad)
energy = kinetic_fn(wa_state.inverse_mass_matrix, vv_state.r)
hmc_state = HMCState(0, vv_state.z, vv_state.z_grad, vv_state.potential_energy, energy,
- 0, 0., 0., False, wa_state, rng_key_hmc)
+ 0, 1.0, 0., False, wa_state, rng_key_hmc)
return device_put(hmc_state)
def _hmc_next(step_size, inverse_mass_matrix, vv_state,
Mutant 74
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -268,7 +268,7 @@
vv_state = vv_init(z, r, potential_energy=pe, z_grad=z_grad)
energy = kinetic_fn(wa_state.inverse_mass_matrix, vv_state.r)
hmc_state = HMCState(0, vv_state.z, vv_state.z_grad, vv_state.potential_energy, energy,
- 0, 0., 0., False, wa_state, rng_key_hmc)
+ 0, 0., 1.0, False, wa_state, rng_key_hmc)
return device_put(hmc_state)
def _hmc_next(step_size, inverse_mass_matrix, vv_state,
Mutant 75
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -268,7 +268,7 @@
vv_state = vv_init(z, r, potential_energy=pe, z_grad=z_grad)
energy = kinetic_fn(wa_state.inverse_mass_matrix, vv_state.r)
hmc_state = HMCState(0, vv_state.z, vv_state.z_grad, vv_state.potential_energy, energy,
- 0, 0., 0., False, wa_state, rng_key_hmc)
+ 0, 0., 0., True, wa_state, rng_key_hmc)
return device_put(hmc_state)
def _hmc_next(step_size, inverse_mass_matrix, vv_state,
Mutant 91
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -286,7 +286,7 @@
energy_new = vv_state_new.potential_energy + kinetic_fn(inverse_mass_matrix, vv_state_new.r)
delta_energy = energy_new - energy_old
delta_energy = jnp.where(jnp.isnan(delta_energy), jnp.inf, delta_energy)
- accept_prob = jnp.clip(jnp.exp(-delta_energy), a_max=1.0)
+ accept_prob = jnp.clip(jnp.exp(-delta_energy), a_max=2.0)
diverging = delta_energy > max_delta_energy
transition = random.bernoulli(rng_key, accept_prob)
vv_state, energy = cond(transition,
Mutant 93
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -287,7 +287,7 @@
delta_energy = energy_new - energy_old
delta_energy = jnp.where(jnp.isnan(delta_energy), jnp.inf, delta_energy)
accept_prob = jnp.clip(jnp.exp(-delta_energy), a_max=1.0)
- diverging = delta_energy > max_delta_energy
+ diverging = delta_energy >= max_delta_energy
transition = random.bernoulli(rng_key, accept_prob)
vv_state, energy = cond(transition,
(vv_state_new, energy_new), identity,
Mutant 107
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -327,7 +327,7 @@
Hamiltonian dynamics given existing state.
"""
- model_kwargs = {} if model_kwargs is None else model_kwargs
+ model_kwargs = {} if model_kwargs is not None else model_kwargs
rng_key, rng_key_momentum, rng_key_transition = random.split(hmc_state.rng_key, 3)
r = momentum_generator(hmc_state.z, hmc_state.adapt_state.mass_matrix_sqrt, rng_key_momentum)
vv_state = IntegratorState(hmc_state.z, r, hmc_state.potential_energy, hmc_state.z_grad)
Mutant 120
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -345,7 +345,7 @@
identity)
itr = hmc_state.i + 1
- n = jnp.where(hmc_state.i < wa_steps, itr, itr - wa_steps)
+ n = jnp.where(hmc_state.i <= wa_steps, itr, itr - wa_steps)
mean_accept_prob = hmc_state.mean_accept_prob + (accept_prob - hmc_state.mean_accept_prob) / n
return HMCState(itr, vv_state.z, vv_state.z_grad, vv_state.potential_energy, energy, num_steps,
Mutant 121
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -345,7 +345,7 @@
identity)
itr = hmc_state.i + 1
- n = jnp.where(hmc_state.i < wa_steps, itr, itr - wa_steps)
+ n = jnp.where(hmc_state.i < wa_steps, itr, itr + wa_steps)
mean_accept_prob = hmc_state.mean_accept_prob + (accept_prob - hmc_state.mean_accept_prob) / n
return HMCState(itr, vv_state.z, vv_state.z_grad, vv_state.potential_energy, energy, num_steps,
Mutant 123
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -346,7 +346,7 @@
itr = hmc_state.i + 1
n = jnp.where(hmc_state.i < wa_steps, itr, itr - wa_steps)
- mean_accept_prob = hmc_state.mean_accept_prob + (accept_prob - hmc_state.mean_accept_prob) / n
+ mean_accept_prob = hmc_state.mean_accept_prob - (accept_prob - hmc_state.mean_accept_prob) / n
return HMCState(itr, vv_state.z, vv_state.z_grad, vv_state.potential_energy, energy, num_steps,
accept_prob, mean_accept_prob, diverging, adapt_state, rng_key)
Mutant 124
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -346,7 +346,7 @@
itr = hmc_state.i + 1
n = jnp.where(hmc_state.i < wa_steps, itr, itr - wa_steps)
- mean_accept_prob = hmc_state.mean_accept_prob + (accept_prob - hmc_state.mean_accept_prob) / n
+ mean_accept_prob = hmc_state.mean_accept_prob + (accept_prob + hmc_state.mean_accept_prob) / n
return HMCState(itr, vv_state.z, vv_state.z_grad, vv_state.potential_energy, energy, num_steps,
accept_prob, mean_accept_prob, diverging, adapt_state, rng_key)
Mutant 125
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -346,7 +346,7 @@
itr = hmc_state.i + 1
n = jnp.where(hmc_state.i < wa_steps, itr, itr - wa_steps)
- mean_accept_prob = hmc_state.mean_accept_prob + (accept_prob - hmc_state.mean_accept_prob) / n
+ mean_accept_prob = hmc_state.mean_accept_prob + (accept_prob - hmc_state.mean_accept_prob) * n
return HMCState(itr, vv_state.z, vv_state.z_grad, vv_state.potential_energy, energy, num_steps,
accept_prob, mean_accept_prob, diverging, adapt_state, rng_key)
Mutant 127
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -353,7 +353,7 @@
# Make `init_kernel` and `sample_kernel` visible from the global scope once
# `hmc` is called for sphinx doc generation.
- if 'SPHINX_BUILD' in os.environ:
+ if 'XXSPHINX_BUILDXX' in os.environ:
hmc.init_kernel = init_kernel
hmc.sample_kernel = sample_kernel
Mutant 128
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -353,7 +353,7 @@
# Make `init_kernel` and `sample_kernel` visible from the global scope once
# `hmc` is called for sphinx doc generation.
- if 'SPHINX_BUILD' in os.environ:
+ if 'SPHINX_BUILD' not in os.environ:
hmc.init_kernel = init_kernel
hmc.sample_kernel = sample_kernel
Mutant 129
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -376,7 +376,6 @@
"""
return identity
- @abstractmethod
def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
"""
Initialize the `MCMCKernel` and return an initial state to begin sampling
Mutant 130
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -393,7 +393,6 @@
"""
raise NotImplementedError
- @abstractmethod
def sample(self, state, model_args, model_kwargs):
"""
Given the current `state`, return the next `state` using the given
Mutant 131
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -450,7 +450,7 @@
model=None,
potential_fn=None,
kinetic_fn=None,
- step_size=1.0,
+ step_size=2.0,
adapt_step_size=True,
adapt_mass_matrix=True,
dense_mass=False,
Mutant 133
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -452,7 +452,7 @@
kinetic_fn=None,
step_size=1.0,
adapt_step_size=True,
- adapt_mass_matrix=True,
+ adapt_mass_matrix=False,
dense_mass=False,
target_accept_prob=0.8,
trajectory_length=2 * math.pi,
Mutant 134
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -453,7 +453,7 @@
step_size=1.0,
adapt_step_size=True,
adapt_mass_matrix=True,
- dense_mass=False,
+ dense_mass=True,
target_accept_prob=0.8,
trajectory_length=2 * math.pi,
init_strategy=init_to_uniform,
Mutant 136
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -455,7 +455,7 @@
adapt_mass_matrix=True,
dense_mass=False,
target_accept_prob=0.8,
- trajectory_length=2 * math.pi,
+ trajectory_length=3 * math.pi,
init_strategy=init_to_uniform,
find_heuristic_step_size=False):
if not (model is None) ^ (potential_fn is None):
Mutant 137
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -455,7 +455,7 @@
adapt_mass_matrix=True,
dense_mass=False,
target_accept_prob=0.8,
- trajectory_length=2 * math.pi,
+ trajectory_length=2 / math.pi,
init_strategy=init_to_uniform,
find_heuristic_step_size=False):
if not (model is None) ^ (potential_fn is None):
Mutant 138
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -457,7 +457,7 @@
target_accept_prob=0.8,
trajectory_length=2 * math.pi,
init_strategy=init_to_uniform,
- find_heuristic_step_size=False):
+ find_heuristic_step_size=True):
if not (model is None) ^ (potential_fn is None):
raise ValueError('Only one of `model` or `potential_fn` must be specified.')
self._model = model
Mutant 144
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -461,7 +461,7 @@
if not (model is None) ^ (potential_fn is None):
raise ValueError('Only one of `model` or `potential_fn` must be specified.')
self._model = model
- self._potential_fn = potential_fn
+ self._potential_fn = None
self._kinetic_fn = kinetic_fn if kinetic_fn is not None else euclidean_kinetic_energy
self._step_size = step_size
self._adapt_step_size = adapt_step_size
Mutant 145
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -462,7 +462,7 @@
raise ValueError('Only one of `model` or `potential_fn` must be specified.')
self._model = model
self._potential_fn = potential_fn
- self._kinetic_fn = kinetic_fn if kinetic_fn is not None else euclidean_kinetic_energy
+ self._kinetic_fn = kinetic_fn if kinetic_fn is None else euclidean_kinetic_energy
self._step_size = step_size
self._adapt_step_size = adapt_step_size
self._adapt_mass_matrix = adapt_mass_matrix
Mutant 146
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -462,7 +462,7 @@
raise ValueError('Only one of `model` or `potential_fn` must be specified.')
self._model = model
self._potential_fn = potential_fn
- self._kinetic_fn = kinetic_fn if kinetic_fn is not None else euclidean_kinetic_energy
+ self._kinetic_fn = None
self._step_size = step_size
self._adapt_step_size = adapt_step_size
self._adapt_mass_matrix = adapt_mass_matrix
Mutant 150
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -466,7 +466,7 @@
self._step_size = step_size
self._adapt_step_size = adapt_step_size
self._adapt_mass_matrix = adapt_mass_matrix
- self._dense_mass = dense_mass
+ self._dense_mass = None
self._target_accept_prob = target_accept_prob
self._trajectory_length = trajectory_length
self._algo = 'HMC'
Mutant 155
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -470,7 +470,7 @@
self._target_accept_prob = target_accept_prob
self._trajectory_length = trajectory_length
self._algo = 'HMC'
- self._max_tree_depth = 10
+ self._max_tree_depth = 11
self._init_strategy = init_strategy
self._find_heuristic_step_size = find_heuristic_step_size
# Set on first call to init
Mutant 156
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -470,7 +470,7 @@
self._target_accept_prob = target_accept_prob
self._trajectory_length = trajectory_length
self._algo = 'HMC'
- self._max_tree_depth = 10
+ self._max_tree_depth = None
self._init_strategy = init_strategy
self._find_heuristic_step_size = find_heuristic_step_size
# Set on first call to init
Mutant 157
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -471,7 +471,7 @@
self._trajectory_length = trajectory_length
self._algo = 'HMC'
self._max_tree_depth = 10
- self._init_strategy = init_strategy
+ self._init_strategy = None
self._find_heuristic_step_size = find_heuristic_step_size
# Set on first call to init
self._init_fn = None
Mutant 158
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -472,7 +472,7 @@
self._algo = 'HMC'
self._max_tree_depth = 10
self._init_strategy = init_strategy
- self._find_heuristic_step_size = find_heuristic_step_size
+ self._find_heuristic_step_size = None
# Set on first call to init
self._init_fn = None
self._postprocess_fn = None
Mutant 160
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -475,7 +475,7 @@
self._find_heuristic_step_size = find_heuristic_step_size
# Set on first call to init
self._init_fn = None
- self._postprocess_fn = None
+ self._postprocess_fn = ""
self._sample_fn = None
def _init_state(self, rng_key, model_args, model_kwargs, init_params):
Mutant 161
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -476,7 +476,7 @@
# Set on first call to init
self._init_fn = None
self._postprocess_fn = None
- self._sample_fn = None
+ self._sample_fn = ""
def _init_state(self, rng_key, model_args, model_kwargs, init_params):
if self._model is not None:
Mutant 167
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -486,7 +486,7 @@
dynamic_args=True,
model_args=model_args,
model_kwargs=model_kwargs)
- if any(v['type'] == 'param' for v in model_trace.values()):
+ if any(v['type'] == 'XXparamXX' for v in model_trace.values()):
warnings.warn("'param' sites will be treated as constants during inference. To define "
"an improper variable, please use a 'sample' site with log probability "
"masked out. For example, `sample('x', dist.LogNormal(0, 1).mask(False)` "
Mutant 171
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -503,7 +503,6 @@
return init_params
- @property
def model(self):
return self._model
Mutant 172
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -507,7 +507,6 @@
def model(self):
return self._model
- @copy_docs_from(MCMCKernel.init)
def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwargs={}):
# non-vectorized
if rng_key.ndim == 1:
Mutant 177
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -516,7 +516,7 @@
else:
rng_key, rng_key_init_model = jnp.swapaxes(vmap(random.split)(rng_key), 0, 1)
init_params = self._init_state(rng_key_init_model, model_args, model_kwargs, init_params)
- if self._potential_fn and init_params is None:
+ if self._potential_fn and init_params is not None:
raise ValueError('Valid value of `init_params` must be provided with'
' `potential_fn`.')
Mutant 178
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -516,7 +516,7 @@
else:
rng_key, rng_key_init_model = jnp.swapaxes(vmap(random.split)(rng_key), 0, 1)
init_params = self._init_state(rng_key_init_model, model_args, model_kwargs, init_params)
- if self._potential_fn and init_params is None:
+ if self._potential_fn or init_params is None:
raise ValueError('Valid value of `init_params` must be provided with'
' `potential_fn`.')
Mutant 182
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -546,7 +546,6 @@
self._sample_fn = sample_fn
return init_state
- @copy_docs_from(MCMCKernel.postprocess_fn)
def postprocess_fn(self, args, kwargs):
if self._postprocess_fn is None:
return identity
Mutant 187
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -616,7 +616,7 @@
step_size=1.0,
adapt_step_size=True,
adapt_mass_matrix=True,
- dense_mass=False,
+ dense_mass=True,
target_accept_prob=0.8,
trajectory_length=None,
max_tree_depth=10,
Mutant 198
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -639,7 +639,7 @@
# proposal_loc.shape[0] = proposal_loc.shape[0] = N
# Here, we use the numerical stability procedure in Appendix 6 of [1].
weight = 1 / samples.shape[0]
- if scale.ndim > loc.ndim:
+ if scale.ndim >= loc.ndim:
new_scale = cholesky_update(scale, new_sample - loc, weight)
proposal_scale = cholesky_update(new_scale, samples - loc, -weight)
proposal_scale = cholesky_update(proposal_scale, new_sample - samples, - (weight ** 2))
Mutant 199
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -640,7 +640,7 @@
# Here, we use the numerical stability procedure in Appendix 6 of [1].
weight = 1 / samples.shape[0]
if scale.ndim > loc.ndim:
- new_scale = cholesky_update(scale, new_sample - loc, weight)
+ new_scale = cholesky_update(scale, new_sample + loc, weight)
proposal_scale = cholesky_update(new_scale, samples - loc, -weight)
proposal_scale = cholesky_update(proposal_scale, new_sample - samples, - (weight ** 2))
else:
Mutant 202
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -641,7 +641,7 @@
weight = 1 / samples.shape[0]
if scale.ndim > loc.ndim:
new_scale = cholesky_update(scale, new_sample - loc, weight)
- proposal_scale = cholesky_update(new_scale, samples - loc, -weight)
+ proposal_scale = cholesky_update(new_scale, samples - loc, +weight)
proposal_scale = cholesky_update(proposal_scale, new_sample - samples, - (weight ** 2))
else:
var = jnp.square(scale) + weight * jnp.square(new_sample - loc)
Mutant 205
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -642,7 +642,7 @@
if scale.ndim > loc.ndim:
new_scale = cholesky_update(scale, new_sample - loc, weight)
proposal_scale = cholesky_update(new_scale, samples - loc, -weight)
- proposal_scale = cholesky_update(proposal_scale, new_sample - samples, - (weight ** 2))
+ proposal_scale = cholesky_update(proposal_scale, new_sample - samples, + (weight ** 2))
else:
var = jnp.square(scale) + weight * jnp.square(new_sample - loc)
proposal_var = var - weight * jnp.square(samples - loc)
Mutant 207
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -642,7 +642,7 @@
if scale.ndim > loc.ndim:
new_scale = cholesky_update(scale, new_sample - loc, weight)
proposal_scale = cholesky_update(new_scale, samples - loc, -weight)
- proposal_scale = cholesky_update(proposal_scale, new_sample - samples, - (weight ** 2))
+ proposal_scale = cholesky_update(proposal_scale, new_sample - samples, - (weight ** 3))
else:
var = jnp.square(scale) + weight * jnp.square(new_sample - loc)
proposal_var = var - weight * jnp.square(samples - loc)
Mutant 220
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -658,7 +658,7 @@
if inv_mass_matrix_sqrt.ndim == 1:
r = jnp.multiply(inv_mass_matrix_sqrt, eps)
elif inv_mass_matrix_sqrt.ndim == 2:
- r = jnp.matmul(inv_mass_matrix_sqrt, eps[..., None])[..., 0]
+ r = jnp.matmul(inv_mass_matrix_sqrt, eps[..., None])[..., 1]
else:
raise ValueError("Mass matrix has incorrect number of dims.")
return r
Mutant 222
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -668,7 +668,7 @@
# because we might lose precision after many iterations of using _get_proposal_loc_and_scale;
# If we recompute, we don't need to store `loc` and `inv_mass_matrix_sqrt` here.
# We may also update those values every 10D iterations...
-SAAdaptState = namedtuple('SAAdaptState', ['zs', 'pes', 'loc', 'inv_mass_matrix_sqrt'])
+SAAdaptState = namedtuple('XXSAAdaptStateXX', ['zs', 'pes', 'loc', 'inv_mass_matrix_sqrt'])
SAState = namedtuple('SAState', ['i', 'z', 'potential_energy', 'accept_prob',
'mean_accept_prob', 'diverging', 'adapt_state', 'rng_key'])
"""
Mutant 223
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -668,7 +668,7 @@
# because we might lose precision after many iterations of using _get_proposal_loc_and_scale;
# If we recompute, we don't need to store `loc` and `inv_mass_matrix_sqrt` here.
# We may also update those values every 10D iterations...
-SAAdaptState = namedtuple('SAAdaptState', ['zs', 'pes', 'loc', 'inv_mass_matrix_sqrt'])
+SAAdaptState = namedtuple('SAAdaptState', ['XXzsXX', 'pes', 'loc', 'inv_mass_matrix_sqrt'])
SAState = namedtuple('SAState', ['i', 'z', 'potential_energy', 'accept_prob',
'mean_accept_prob', 'diverging', 'adapt_state', 'rng_key'])
"""
Mutant 224
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -668,7 +668,7 @@
# because we might lose precision after many iterations of using _get_proposal_loc_and_scale;
# If we recompute, we don't need to store `loc` and `inv_mass_matrix_sqrt` here.
# We may also update those values every 10D iterations...
-SAAdaptState = namedtuple('SAAdaptState', ['zs', 'pes', 'loc', 'inv_mass_matrix_sqrt'])
+SAAdaptState = namedtuple('SAAdaptState', ['zs', 'XXpesXX', 'loc', 'inv_mass_matrix_sqrt'])
SAState = namedtuple('SAState', ['i', 'z', 'potential_energy', 'accept_prob',
'mean_accept_prob', 'diverging', 'adapt_state', 'rng_key'])
"""
Mutant 225
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -668,7 +668,7 @@
# because we might lose precision after many iterations of using _get_proposal_loc_and_scale;
# If we recompute, we don't need to store `loc` and `inv_mass_matrix_sqrt` here.
# We may also update those values every 10D iterations...
-SAAdaptState = namedtuple('SAAdaptState', ['zs', 'pes', 'loc', 'inv_mass_matrix_sqrt'])
+SAAdaptState = namedtuple('SAAdaptState', ['zs', 'pes', 'XXlocXX', 'inv_mass_matrix_sqrt'])
SAState = namedtuple('SAState', ['i', 'z', 'potential_energy', 'accept_prob',
'mean_accept_prob', 'diverging', 'adapt_state', 'rng_key'])
"""
Mutant 226
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -668,7 +668,7 @@
# because we might lose precision after many iterations of using _get_proposal_loc_and_scale;
# If we recompute, we don't need to store `loc` and `inv_mass_matrix_sqrt` here.
# We may also update those values every 10D iterations...
-SAAdaptState = namedtuple('SAAdaptState', ['zs', 'pes', 'loc', 'inv_mass_matrix_sqrt'])
+SAAdaptState = namedtuple('SAAdaptState', ['zs', 'pes', 'loc', 'XXinv_mass_matrix_sqrtXX'])
SAState = namedtuple('SAState', ['i', 'z', 'potential_energy', 'accept_prob',
'mean_accept_prob', 'diverging', 'adapt_state', 'rng_key'])
"""
Mutant 228
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -669,7 +669,7 @@
# If we recompute, we don't need to store `loc` and `inv_mass_matrix_sqrt` here.
# We may also update those values every 10D iterations...
SAAdaptState = namedtuple('SAAdaptState', ['zs', 'pes', 'loc', 'inv_mass_matrix_sqrt'])
-SAState = namedtuple('SAState', ['i', 'z', 'potential_energy', 'accept_prob',
+SAState = namedtuple('XXSAStateXX', ['i', 'z', 'potential_energy', 'accept_prob',
'mean_accept_prob', 'diverging', 'adapt_state', 'rng_key'])
"""
A :func:`~collections.namedtuple` used in Sample Adaptive MCMC.
Mutant 232
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -669,7 +669,7 @@
# If we recompute, we don't need to store `loc` and `inv_mass_matrix_sqrt` here.
# We may also update those values every 10D iterations...
SAAdaptState = namedtuple('SAAdaptState', ['zs', 'pes', 'loc', 'inv_mass_matrix_sqrt'])
-SAState = namedtuple('SAState', ['i', 'z', 'potential_energy', 'accept_prob',
+SAState = namedtuple('SAState', ['i', 'z', 'potential_energy', 'XXaccept_probXX',
'mean_accept_prob', 'diverging', 'adapt_state', 'rng_key'])
"""
A :func:`~collections.namedtuple` used in Sample Adaptive MCMC.
Mutant 271
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -732,7 +732,7 @@
z = init_params
z_flat, unravel_fn = ravel_pytree(z)
if inverse_mass_matrix is None:
- inverse_mass_matrix = jnp.identity(z_flat.shape[-1]) if dense_mass else jnp.ones(z_flat.shape[-1])
+ inverse_mass_matrix = jnp.identity(z_flat.shape[-1]) if dense_mass else jnp.ones(z_flat.shape[+1])
inv_mass_matrix_sqrt = jnp.linalg.cholesky(inverse_mass_matrix) if dense_mass \
else jnp.sqrt(inverse_mass_matrix)
if adapt_state_size is None:
Mutant 272
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -732,7 +732,7 @@
z = init_params
z_flat, unravel_fn = ravel_pytree(z)
if inverse_mass_matrix is None:
- inverse_mass_matrix = jnp.identity(z_flat.shape[-1]) if dense_mass else jnp.ones(z_flat.shape[-1])
+ inverse_mass_matrix = jnp.identity(z_flat.shape[-1]) if dense_mass else jnp.ones(z_flat.shape[-2])
inv_mass_matrix_sqrt = jnp.linalg.cholesky(inverse_mass_matrix) if dense_mass \
else jnp.sqrt(inverse_mass_matrix)
if adapt_state_size is None:
Mutant 275
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -737,7 +737,7 @@
else jnp.sqrt(inverse_mass_matrix)
if adapt_state_size is None:
# XXX: heuristic choice
- adapt_state_size = 2 * z_flat.shape[-1]
+ adapt_state_size = 3 * z_flat.shape[-1]
else:
assert adapt_state_size > 1, 'adapt_state_size should be greater than 1.'
# NB: mean is init_params
Mutant 280
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -739,7 +739,7 @@
# XXX: heuristic choice
adapt_state_size = 2 * z_flat.shape[-1]
else:
- assert adapt_state_size > 1, 'adapt_state_size should be greater than 1.'
+ assert adapt_state_size >= 1, 'adapt_state_size should be greater than 1.'
# NB: mean is init_params
zs = z_flat + _sample_proposal(inv_mass_matrix_sqrt, rng_key_zs, (adapt_state_size,))
# compute potential energies
Mutant 281
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -739,7 +739,7 @@
# XXX: heuristic choice
adapt_state_size = 2 * z_flat.shape[-1]
else:
- assert adapt_state_size > 1, 'adapt_state_size should be greater than 1.'
+ assert adapt_state_size > 2, 'adapt_state_size should be greater than 1.'
# NB: mean is init_params
zs = z_flat + _sample_proposal(inv_mass_matrix_sqrt, rng_key_zs, (adapt_state_size,))
# compute potential energies
Mutant 282
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -739,7 +739,7 @@
# XXX: heuristic choice
adapt_state_size = 2 * z_flat.shape[-1]
else:
- assert adapt_state_size > 1, 'adapt_state_size should be greater than 1.'
+ assert adapt_state_size > 1, 'XXadapt_state_size should be greater than 1.XX'
# NB: mean is init_params
zs = z_flat + _sample_proposal(inv_mass_matrix_sqrt, rng_key_zs, (adapt_state_size,))
# compute potential energies
Mutant 283
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -741,7 +741,7 @@
else:
assert adapt_state_size > 1, 'adapt_state_size should be greater than 1.'
# NB: mean is init_params
- zs = z_flat + _sample_proposal(inv_mass_matrix_sqrt, rng_key_zs, (adapt_state_size,))
+ zs = z_flat - _sample_proposal(inv_mass_matrix_sqrt, rng_key_zs, (adapt_state_size,))
# compute potential energies
pes = lax.map(lambda z: pe_fn(unravel_fn(z)), zs)
if dense_mass:
Mutant 288
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -745,7 +745,7 @@
# compute potential energies
pes = lax.map(lambda z: pe_fn(unravel_fn(z)), zs)
if dense_mass:
- cov = jnp.cov(zs, rowvar=False, bias=True)
+ cov = jnp.cov(zs, rowvar=False, bias=False)
if cov.shape == (): # JAX returns scalar for 1D input
cov = cov.reshape((1, 1))
inv_mass_matrix_sqrt = jnp.linalg.cholesky(cov)
Mutant 294
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -752,7 +752,7 @@
else:
inv_mass_matrix_sqrt = jnp.std(zs, 0)
adapt_state = SAAdaptState(zs, pes, jnp.mean(zs, 0), inv_mass_matrix_sqrt)
- k = random.categorical(rng_key_z, jnp.zeros(zs.shape[0]))
+ k = random.categorical(rng_key_z, jnp.zeros(zs.shape[1]))
z = unravel_fn(zs[k])
pe = pes[k]
sa_state = SAState(0, z, pe, 0., 0., False, adapt_state, rng_key_sa)
Mutant 298
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -755,7 +755,7 @@
k = random.categorical(rng_key_z, jnp.zeros(zs.shape[0]))
z = unravel_fn(zs[k])
pe = pes[k]
- sa_state = SAState(0, z, pe, 0., 0., False, adapt_state, rng_key_sa)
+ sa_state = SAState(1, z, pe, 0., 0., False, adapt_state, rng_key_sa)
return device_put(sa_state)
def sample_kernel(sa_state, model_args=(), model_kwargs=None):
Mutant 299
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -755,7 +755,7 @@
k = random.categorical(rng_key_z, jnp.zeros(zs.shape[0]))
z = unravel_fn(zs[k])
pe = pes[k]
- sa_state = SAState(0, z, pe, 0., 0., False, adapt_state, rng_key_sa)
+ sa_state = SAState(0, z, pe, 1.0, 0., False, adapt_state, rng_key_sa)
return device_put(sa_state)
def sample_kernel(sa_state, model_args=(), model_kwargs=None):
Mutant 300
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -755,7 +755,7 @@
k = random.categorical(rng_key_z, jnp.zeros(zs.shape[0]))
z = unravel_fn(zs[k])
pe = pes[k]
- sa_state = SAState(0, z, pe, 0., 0., False, adapt_state, rng_key_sa)
+ sa_state = SAState(0, z, pe, 0., 1.0, False, adapt_state, rng_key_sa)
return device_put(sa_state)
def sample_kernel(sa_state, model_args=(), model_kwargs=None):
Mutant 301
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -755,7 +755,7 @@
k = random.categorical(rng_key_z, jnp.zeros(zs.shape[0]))
z = unravel_fn(zs[k])
pe = pes[k]
- sa_state = SAState(0, z, pe, 0., 0., False, adapt_state, rng_key_sa)
+ sa_state = SAState(0, z, pe, 0., 0., True, adapt_state, rng_key_sa)
return device_put(sa_state)
def sample_kernel(sa_state, model_args=(), model_kwargs=None):
Mutant 309
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -766,7 +766,7 @@
rng_key, rng_key_z, rng_key_reject, rng_key_accept = random.split(sa_state.rng_key, 4)
_, unravel_fn = ravel_pytree(sa_state.z)
- z = loc + _sample_proposal(scale, rng_key_z)
+ z = loc - _sample_proposal(scale, rng_key_z)
pe = pe_fn(unravel_fn(z))
pe = jnp.where(jnp.isnan(pe), jnp.inf, pe)
diverging = (pe - sa_state.potential_energy) > max_delta_energy
Mutant 313
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -769,7 +769,7 @@
z = loc + _sample_proposal(scale, rng_key_z)
pe = pe_fn(unravel_fn(z))
pe = jnp.where(jnp.isnan(pe), jnp.inf, pe)
- diverging = (pe - sa_state.potential_energy) > max_delta_energy
+ diverging = (pe + sa_state.potential_energy) > max_delta_energy
# NB: all terms having the pattern *s will have shape N x ...
# and all terms having the pattern *s_ will have shape (N + 1) x ...
Mutant 314
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -769,7 +769,7 @@
z = loc + _sample_proposal(scale, rng_key_z)
pe = pe_fn(unravel_fn(z))
pe = jnp.where(jnp.isnan(pe), jnp.inf, pe)
- diverging = (pe - sa_state.potential_energy) > max_delta_energy
+ diverging = (pe - sa_state.potential_energy) >= max_delta_energy
# NB: all terms having the pattern *s will have shape N x ...
# and all terms having the pattern *s_ will have shape (N + 1) x ...
Mutant 333
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -792,7 +792,7 @@
adapt_state = SAAdaptState(zs, pes, loc, scale)
# NB: weights[-1] / sum(weights) is the probability of rejecting the new sample `z`.
- accept_prob = 1 - jnp.exp(log_weights_[-1] - logsumexp(log_weights_))
+ accept_prob = 2 - jnp.exp(log_weights_[-1] - logsumexp(log_weights_))
itr = sa_state.i + 1
n = jnp.where(sa_state.i < wa_steps, itr, itr - wa_steps)
mean_accept_prob = sa_state.mean_accept_prob + (accept_prob - sa_state.mean_accept_prob) / n
Mutant 334
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -792,7 +792,7 @@
adapt_state = SAAdaptState(zs, pes, loc, scale)
# NB: weights[-1] / sum(weights) is the probability of rejecting the new sample `z`.
- accept_prob = 1 - jnp.exp(log_weights_[-1] - logsumexp(log_weights_))
+ accept_prob = 1 + jnp.exp(log_weights_[-1] - logsumexp(log_weights_))
itr = sa_state.i + 1
n = jnp.where(sa_state.i < wa_steps, itr, itr - wa_steps)
mean_accept_prob = sa_state.mean_accept_prob + (accept_prob - sa_state.mean_accept_prob) / n
Mutant 335
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -792,7 +792,7 @@
adapt_state = SAAdaptState(zs, pes, loc, scale)
# NB: weights[-1] / sum(weights) is the probability of rejecting the new sample `z`.
- accept_prob = 1 - jnp.exp(log_weights_[-1] - logsumexp(log_weights_))
+ accept_prob = 1 - jnp.exp(log_weights_[+1] - logsumexp(log_weights_))
itr = sa_state.i + 1
n = jnp.where(sa_state.i < wa_steps, itr, itr - wa_steps)
mean_accept_prob = sa_state.mean_accept_prob + (accept_prob - sa_state.mean_accept_prob) / n
Mutant 336
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -792,7 +792,7 @@
adapt_state = SAAdaptState(zs, pes, loc, scale)
# NB: weights[-1] / sum(weights) is the probability of rejecting the new sample `z`.
- accept_prob = 1 - jnp.exp(log_weights_[-1] - logsumexp(log_weights_))
+ accept_prob = 1 - jnp.exp(log_weights_[-2] - logsumexp(log_weights_))
itr = sa_state.i + 1
n = jnp.where(sa_state.i < wa_steps, itr, itr - wa_steps)
mean_accept_prob = sa_state.mean_accept_prob + (accept_prob - sa_state.mean_accept_prob) / n
Mutant 337
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -792,7 +792,7 @@
adapt_state = SAAdaptState(zs, pes, loc, scale)
# NB: weights[-1] / sum(weights) is the probability of rejecting the new sample `z`.
- accept_prob = 1 - jnp.exp(log_weights_[-1] - logsumexp(log_weights_))
+ accept_prob = 1 - jnp.exp(log_weights_[-1] + logsumexp(log_weights_))
itr = sa_state.i + 1
n = jnp.where(sa_state.i < wa_steps, itr, itr - wa_steps)
mean_accept_prob = sa_state.mean_accept_prob + (accept_prob - sa_state.mean_accept_prob) / n
Mutant 339
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -793,7 +793,7 @@
# NB: weights[-1] / sum(weights) is the probability of rejecting the new sample `z`.
accept_prob = 1 - jnp.exp(log_weights_[-1] - logsumexp(log_weights_))
- itr = sa_state.i + 1
+ itr = sa_state.i - 1
n = jnp.where(sa_state.i < wa_steps, itr, itr - wa_steps)
mean_accept_prob = sa_state.mean_accept_prob + (accept_prob - sa_state.mean_accept_prob) / n
Mutant 340
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -793,7 +793,7 @@
# NB: weights[-1] / sum(weights) is the probability of rejecting the new sample `z`.
accept_prob = 1 - jnp.exp(log_weights_[-1] - logsumexp(log_weights_))
- itr = sa_state.i + 1
+ itr = sa_state.i + 2
n = jnp.where(sa_state.i < wa_steps, itr, itr - wa_steps)
mean_accept_prob = sa_state.mean_accept_prob + (accept_prob - sa_state.mean_accept_prob) / n
Mutant 342
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -794,7 +794,7 @@
# NB: weights[-1] / sum(weights) is the probability of rejecting the new sample `z`.
accept_prob = 1 - jnp.exp(log_weights_[-1] - logsumexp(log_weights_))
itr = sa_state.i + 1
- n = jnp.where(sa_state.i < wa_steps, itr, itr - wa_steps)
+ n = jnp.where(sa_state.i <= wa_steps, itr, itr - wa_steps)
mean_accept_prob = sa_state.mean_accept_prob + (accept_prob - sa_state.mean_accept_prob) / n
# XXX: we make a modification of SA sampler in [1]
Mutant 343
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -794,7 +794,7 @@
# NB: weights[-1] / sum(weights) is the probability of rejecting the new sample `z`.
accept_prob = 1 - jnp.exp(log_weights_[-1] - logsumexp(log_weights_))
itr = sa_state.i + 1
- n = jnp.where(sa_state.i < wa_steps, itr, itr - wa_steps)
+ n = jnp.where(sa_state.i < wa_steps, itr, itr + wa_steps)
mean_accept_prob = sa_state.mean_accept_prob + (accept_prob - sa_state.mean_accept_prob) / n
# XXX: we make a modification of SA sampler in [1]
Mutant 345
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -795,7 +795,7 @@
accept_prob = 1 - jnp.exp(log_weights_[-1] - logsumexp(log_weights_))
itr = sa_state.i + 1
n = jnp.where(sa_state.i < wa_steps, itr, itr - wa_steps)
- mean_accept_prob = sa_state.mean_accept_prob + (accept_prob - sa_state.mean_accept_prob) / n
+ mean_accept_prob = sa_state.mean_accept_prob - (accept_prob - sa_state.mean_accept_prob) / n
# XXX: we make a modification of SA sampler in [1]
# in [1], each MCMC state contains N points `zs`
Mutant 346
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -795,7 +795,7 @@
accept_prob = 1 - jnp.exp(log_weights_[-1] - logsumexp(log_weights_))
itr = sa_state.i + 1
n = jnp.where(sa_state.i < wa_steps, itr, itr - wa_steps)
- mean_accept_prob = sa_state.mean_accept_prob + (accept_prob - sa_state.mean_accept_prob) / n
+ mean_accept_prob = sa_state.mean_accept_prob + (accept_prob + sa_state.mean_accept_prob) / n
# XXX: we make a modification of SA sampler in [1]
# in [1], each MCMC state contains N points `zs`
Mutant 347
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -795,7 +795,7 @@
accept_prob = 1 - jnp.exp(log_weights_[-1] - logsumexp(log_weights_))
itr = sa_state.i + 1
n = jnp.where(sa_state.i < wa_steps, itr, itr - wa_steps)
- mean_accept_prob = sa_state.mean_accept_prob + (accept_prob - sa_state.mean_accept_prob) / n
+ mean_accept_prob = sa_state.mean_accept_prob + (accept_prob - sa_state.mean_accept_prob) * n
# XXX: we make a modification of SA sampler in [1]
# in [1], each MCMC state contains N points `zs`
Mutant 349
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -800,7 +800,7 @@
# XXX: we make a modification of SA sampler in [1]
# in [1], each MCMC state contains N points `zs`
# here we do resampling to pick randomly a point from those N points
- k = random.categorical(rng_key_accept, jnp.zeros(zs.shape[0]))
+ k = random.categorical(rng_key_accept, jnp.zeros(zs.shape[1]))
z = unravel_fn(zs[k])
pe = pes[k]
return SAState(itr, z, pe, accept_prob, mean_accept_prob, diverging, adapt_state, rng_key)
Mutant 353
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -843,7 +843,7 @@
See :ref:`init_strategy` section for available functions.
"""
def __init__(self, model=None, potential_fn=None, adapt_state_size=None,
- dense_mass=True, init_strategy=init_to_uniform):
+ dense_mass=False, init_strategy=init_to_uniform):
if not (model is None) ^ (potential_fn is None):
raise ValueError('Only one of `model` or `potential_fn` must be specified.')
self._model = model
Mutant 360
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -848,7 +848,7 @@
raise ValueError('Only one of `model` or `potential_fn` must be specified.')
self._model = model
self._potential_fn = potential_fn
- self._adapt_state_size = adapt_state_size
+ self._adapt_state_size = None
self._dense_mass = dense_mass
self._init_strategy = init_strategy
self._init_fn = None
Mutant 390
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -953,7 +953,7 @@
def _collect_fn(collect_fields):
- @cached_by(_collect_fn, collect_fields)
+
def collect(x):
return attrgetter(*collect_fields)(x[0])
Mutant 397
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1016,7 +1016,7 @@
self.num_warmup = num_warmup
self.num_samples = num_samples
self.num_chains = num_chains
- self.postprocess_fn = postprocess_fn
+ self.postprocess_fn = None
self.chain_method = chain_method
self.progress_bar = progress_bar
# TODO: We should have progress bars (maybe without diagnostics) for num_chains > 1
Mutant 398
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1017,7 +1017,7 @@
self.num_samples = num_samples
self.num_chains = num_chains
self.postprocess_fn = postprocess_fn
- self.chain_method = chain_method
+ self.chain_method = None
self.progress_bar = progress_bar
# TODO: We should have progress bars (maybe without diagnostics) for num_chains > 1
if (chain_method == 'parallel' and num_chains > 1) or (
Mutant 400
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1020,7 +1020,7 @@
self.chain_method = chain_method
self.progress_bar = progress_bar
# TODO: We should have progress bars (maybe without diagnostics) for num_chains > 1
- if (chain_method == 'parallel' and num_chains > 1) or (
+ if (chain_method != 'parallel' and num_chains > 1) or (
"CI" in os.environ or "PYTEST_XDIST_WORKER" in os.environ):
self.progress_bar = False
self._jit_model_args = jit_model_args
Mutant 401
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1020,7 +1020,7 @@
self.chain_method = chain_method
self.progress_bar = progress_bar
# TODO: We should have progress bars (maybe without diagnostics) for num_chains > 1
- if (chain_method == 'parallel' and num_chains > 1) or (
+ if (chain_method == 'XXparallelXX' and num_chains > 1) or (
"CI" in os.environ or "PYTEST_XDIST_WORKER" in os.environ):
self.progress_bar = False
self._jit_model_args = jit_model_args
Mutant 403
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1020,7 +1020,7 @@
self.chain_method = chain_method
self.progress_bar = progress_bar
# TODO: We should have progress bars (maybe without diagnostics) for num_chains > 1
- if (chain_method == 'parallel' and num_chains > 1) or (
+ if (chain_method == 'parallel' and num_chains > 2) or (
"CI" in os.environ or "PYTEST_XDIST_WORKER" in os.environ):
self.progress_bar = False
self._jit_model_args = jit_model_args
Mutant 405
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1021,7 +1021,7 @@
self.progress_bar = progress_bar
# TODO: We should have progress bars (maybe without diagnostics) for num_chains > 1
if (chain_method == 'parallel' and num_chains > 1) or (
- "CI" in os.environ or "PYTEST_XDIST_WORKER" in os.environ):
+ "XXCIXX" in os.environ or "PYTEST_XDIST_WORKER" in os.environ):
self.progress_bar = False
self._jit_model_args = jit_model_args
self._states = None
Mutant 407
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1021,7 +1021,7 @@
self.progress_bar = progress_bar
# TODO: We should have progress bars (maybe without diagnostics) for num_chains > 1
if (chain_method == 'parallel' and num_chains > 1) or (
- "CI" in os.environ or "PYTEST_XDIST_WORKER" in os.environ):
+ "CI" in os.environ or "XXPYTEST_XDIST_WORKERXX" in os.environ):
self.progress_bar = False
self._jit_model_args = jit_model_args
self._states = None
Mutant 409
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1021,7 +1021,7 @@
self.progress_bar = progress_bar
# TODO: We should have progress bars (maybe without diagnostics) for num_chains > 1
if (chain_method == 'parallel' and num_chains > 1) or (
- "CI" in os.environ or "PYTEST_XDIST_WORKER" in os.environ):
+ "CI" in os.environ and "PYTEST_XDIST_WORKER" in os.environ):
self.progress_bar = False
self._jit_model_args = jit_model_args
self._states = None
Mutant 410
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1020,7 +1020,7 @@
self.chain_method = chain_method
self.progress_bar = progress_bar
# TODO: We should have progress bars (maybe without diagnostics) for num_chains > 1
- if (chain_method == 'parallel' and num_chains > 1) or (
+ if (chain_method == 'parallel' and num_chains > 1) and (
"CI" in os.environ or "PYTEST_XDIST_WORKER" in os.environ):
self.progress_bar = False
self._jit_model_args = jit_model_args
Mutant 411
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1023,7 +1023,7 @@
if (chain_method == 'parallel' and num_chains > 1) or (
"CI" in os.environ or "PYTEST_XDIST_WORKER" in os.environ):
self.progress_bar = False
- self._jit_model_args = jit_model_args
+ self._jit_model_args = None
self._states = None
self._states_flat = None
# HMCState returned by last run
Mutant 412
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1024,7 +1024,7 @@
"CI" in os.environ or "PYTEST_XDIST_WORKER" in os.environ):
self.progress_bar = False
self._jit_model_args = jit_model_args
- self._states = None
+ self._states = ""
self._states_flat = None
# HMCState returned by last run
self._last_state = None
Mutant 413
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1025,7 +1025,7 @@
self.progress_bar = False
self._jit_model_args = jit_model_args
self._states = None
- self._states_flat = None
+ self._states_flat = ""
# HMCState returned by last run
self._last_state = None
# HMCState returned by last warmup
Mutant 414
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1027,7 +1027,7 @@
self._states = None
self._states_flat = None
# HMCState returned by last run
- self._last_state = None
+ self._last_state = ""
# HMCState returned by last warmup
self._warmup_state = None
# HMCState returned by hmc.init_kernel
Mutant 422
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1041,7 +1041,7 @@
args, kwargs = (None,), (None,)
else:
args = tree_map(lambda x: _hashable(x), self._args)
- kwargs = tree_map(lambda x: _hashable(x), tuple(sorted(self._kwargs.items())))
+ kwargs = tree_map(lambda x: None, tuple(sorted(self._kwargs.items())))
key = args + kwargs
try:
fn = self._cache.get(key, None)
Mutant 425
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1042,7 +1042,7 @@
else:
args = tree_map(lambda x: _hashable(x), self._args)
kwargs = tree_map(lambda x: _hashable(x), tuple(sorted(self._kwargs.items())))
- key = args + kwargs
+ key = None
try:
fn = self._cache.get(key, None)
# If unhashable arguments are provided, proceed normally
Mutant 426
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1044,7 +1044,7 @@
kwargs = tree_map(lambda x: _hashable(x), tuple(sorted(self._kwargs.items())))
key = args + kwargs
try:
- fn = self._cache.get(key, None)
+ fn = None
# If unhashable arguments are provided, proceed normally
# without caching
except TypeError:
Mutant 430
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1055,7 +1055,7 @@
else:
fn = partial(_sample_fn_nojit_args, sampler=self.sampler,
args=self._args, kwargs=self._kwargs)
- if key is not None:
+ if key is None:
self._cache[key] = fn
return fn
Mutant 431
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1056,7 +1056,7 @@
fn = partial(_sample_fn_nojit_args, sampler=self.sampler,
args=self._args, kwargs=self._kwargs)
if key is not None:
- self._cache[key] = fn
+ self._cache[key] = None
return fn
def _get_cached_init_state(self, rng_key, args, kwargs):
Mutant 433
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1061,7 +1061,7 @@
def _get_cached_init_state(self, rng_key, args, kwargs):
rng_key = (_hashable(rng_key),)
- args = tree_map(lambda x: _hashable(x), args)
+ args = tree_map(lambda x: None, args)
kwargs = tree_map(lambda x: _hashable(x), tuple(sorted(kwargs.items())))
key = rng_key + args + kwargs
try:
Mutant 435
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1062,7 +1062,7 @@
def _get_cached_init_state(self, rng_key, args, kwargs):
rng_key = (_hashable(rng_key),)
args = tree_map(lambda x: _hashable(x), args)
- kwargs = tree_map(lambda x: _hashable(x), tuple(sorted(kwargs.items())))
+ kwargs = tree_map(lambda x: None, tuple(sorted(kwargs.items())))
key = rng_key + args + kwargs
try:
return self._init_state_cache.get(key, None)
Mutant 439
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1063,7 +1063,7 @@
rng_key = (_hashable(rng_key),)
args = tree_map(lambda x: _hashable(x), args)
kwargs = tree_map(lambda x: _hashable(x), tuple(sorted(kwargs.items())))
- key = rng_key + args + kwargs
+ key = None
try:
return self._init_state_cache.get(key, None)
# If unhashable arguments are provided, return None
Mutant 440
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1070,7 +1070,7 @@
except TypeError:
return None
- def _single_chain_mcmc(self, rng_key, init_state, init_params, args, kwargs, collect_fields=('z',)):
+ def _single_chain_mcmc(self, rng_key, init_state, init_params, args, kwargs, collect_fields=('XXzXX',)):
if init_state is None:
init_state = self.sampler.init(rng_key, self.num_warmup, init_params,
model_args=args, model_kwargs=kwargs)
Mutant 446
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1078,7 +1078,7 @@
postprocess_fn = self.sampler.postprocess_fn(args, kwargs)
else:
postprocess_fn = self.postprocess_fn
- diagnostics = lambda x: get_diagnostics_str(x[0]) if rng_key.ndim == 1 else None # noqa: E731
+ diagnostics = lambda x: get_diagnostics_str(x[0]) if rng_key.ndim != 1 else None # noqa: E731
init_val = (init_state, args, kwargs) if self._jit_model_args else (init_state,)
lower_idx = self._collection_params["lower"]
upper_idx = self._collection_params["upper"]
Mutant 447
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1078,7 +1078,7 @@
postprocess_fn = self.sampler.postprocess_fn(args, kwargs)
else:
postprocess_fn = self.postprocess_fn
- diagnostics = lambda x: get_diagnostics_str(x[0]) if rng_key.ndim == 1 else None # noqa: E731
+ diagnostics = lambda x: get_diagnostics_str(x[0]) if rng_key.ndim == 2 else None # noqa: E731
init_val = (init_state, args, kwargs) if self._jit_model_args else (init_state,)
lower_idx = self._collection_params["lower"]
upper_idx = self._collection_params["upper"]
Mutant 448
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1078,7 +1078,7 @@
postprocess_fn = self.sampler.postprocess_fn(args, kwargs)
else:
postprocess_fn = self.postprocess_fn
- diagnostics = lambda x: get_diagnostics_str(x[0]) if rng_key.ndim == 1 else None # noqa: E731
+ diagnostics = lambda x: None # noqa: E731
init_val = (init_state, args, kwargs) if self._jit_model_args else (init_state,)
lower_idx = self._collection_params["lower"]
upper_idx = self._collection_params["upper"]
Mutant 449
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1078,7 +1078,7 @@
postprocess_fn = self.sampler.postprocess_fn(args, kwargs)
else:
postprocess_fn = self.postprocess_fn
- diagnostics = lambda x: get_diagnostics_str(x[0]) if rng_key.ndim == 1 else None # noqa: E731
+ diagnostics = None # noqa: E731
init_val = (init_state, args, kwargs) if self._jit_model_args else (init_state,)
lower_idx = self._collection_params["lower"]
upper_idx = self._collection_params["upper"]
Mutant 467
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1102,7 +1102,7 @@
states = dict(zip(collect_fields, states))
# Apply constraints if number of samples is non-zero
site_values = tree_flatten(states['z'])[0]
- if len(site_values) > 0 and site_values[0].size > 0:
+ if len(site_values) >= 0 and site_values[0].size > 0:
states['z'] = lax.map(postprocess_fn, states['z'])
return states, last_state
Mutant 470
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1102,7 +1102,7 @@
states = dict(zip(collect_fields, states))
# Apply constraints if number of samples is non-zero
site_values = tree_flatten(states['z'])[0]
- if len(site_values) > 0 and site_values[0].size > 0:
+ if len(site_values) > 0 and site_values[0].size >= 0:
states['z'] = lax.map(postprocess_fn, states['z'])
return states, last_state
Mutant 471
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1102,7 +1102,7 @@
states = dict(zip(collect_fields, states))
# Apply constraints if number of samples is non-zero
site_values = tree_flatten(states['z'])[0]
- if len(site_values) > 0 and site_values[0].size > 0:
+ if len(site_values) > 0 and site_values[0].size > 1:
states['z'] = lax.map(postprocess_fn, states['z'])
return states, last_state
Mutant 472
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1102,7 +1102,7 @@
states = dict(zip(collect_fields, states))
# Apply constraints if number of samples is non-zero
site_values = tree_flatten(states['z'])[0]
- if len(site_values) > 0 and site_values[0].size > 0:
+ if len(site_values) > 0 or site_values[0].size > 0:
states['z'] = lax.map(postprocess_fn, states['z'])
return states, last_state
Mutant 476
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1106,7 +1106,7 @@
states['z'] = lax.map(postprocess_fn, states['z'])
return states, last_state
- def _single_chain_jit_args(self, init, collect_fields=('z',)):
+ def _single_chain_jit_args(self, init, collect_fields=('XXzXX',)):
return self._single_chain_mcmc(*init, collect_fields=collect_fields)
def _single_chain_nojit_args(self, init, model_args, model_kwargs, collect_fields=('z',)):
Mutant 477
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1109,7 +1109,7 @@
def _single_chain_jit_args(self, init, collect_fields=('z',)):
return self._single_chain_mcmc(*init, collect_fields=collect_fields)
- def _single_chain_nojit_args(self, init, model_args, model_kwargs, collect_fields=('z',)):
+ def _single_chain_nojit_args(self, init, model_args, model_kwargs, collect_fields=('XXzXX',)):
return self._single_chain_mcmc(*init, model_args, model_kwargs, collect_fields=collect_fields)
def _set_collection_params(self, lower=None, upper=None, collection_size=None):
Mutant 486
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1115,7 +1115,7 @@
def _set_collection_params(self, lower=None, upper=None, collection_size=None):
self._collection_params["lower"] = self.num_warmup if lower is None else lower
self._collection_params["upper"] = self.num_warmup + self.num_samples if upper is None else upper
- self._collection_params["collection_size"] = collection_size
+ self._collection_params["collection_size"] = None
def _compile(self, rng_key, *args, extra_fields=(), init_params=None, **kwargs):
self._set_collection_params(0, 0, self.num_samples)
Mutant 487
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1130,7 +1130,7 @@
except TypeError:
pass
- def warmup(self, rng_key, *args, extra_fields=(), collect_warmup=False, init_params=None, **kwargs):
+ def warmup(self, rng_key, *args, extra_fields=(), collect_warmup=True, init_params=None, **kwargs):
"""
Run the MCMC warmup adaptation phase. After this call, the :meth:`run` method
will skip the warmup adaptation phase. To run `warmup` again for the new data,
Mutant 490
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1181,7 +1181,7 @@
"""
self._args = args
self._kwargs = kwargs
- init_state = self._get_cached_init_state(rng_key, args, kwargs)
+ init_state = None
if self.num_chains > 1 and rng_key.ndim == 1:
rng_key = random.split(rng_key, self.num_chains)
Mutant 492
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1182,7 +1182,7 @@
self._args = args
self._kwargs = kwargs
init_state = self._get_cached_init_state(rng_key, args, kwargs)
- if self.num_chains > 1 and rng_key.ndim == 1:
+ if self.num_chains > 2 and rng_key.ndim == 1:
rng_key = random.split(rng_key, self.num_chains)
if self._warmup_state is not None:
Mutant 493
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1182,7 +1182,7 @@
self._args = args
self._kwargs = kwargs
init_state = self._get_cached_init_state(rng_key, args, kwargs)
- if self.num_chains > 1 and rng_key.ndim == 1:
+ if self.num_chains > 1 and rng_key.ndim != 1:
rng_key = random.split(rng_key, self.num_chains)
if self._warmup_state is not None:
Mutant 494
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1182,7 +1182,7 @@
self._args = args
self._kwargs = kwargs
init_state = self._get_cached_init_state(rng_key, args, kwargs)
- if self.num_chains > 1 and rng_key.ndim == 1:
+ if self.num_chains > 1 and rng_key.ndim == 2:
rng_key = random.split(rng_key, self.num_chains)
if self._warmup_state is not None:
Mutant 497
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1189,7 +1189,7 @@
self._set_collection_params(0, self.num_samples, self.num_samples)
init_state = self._warmup_state._replace(rng_key=rng_key)
- chain_method = self.chain_method
+ chain_method = None
if chain_method == 'parallel' and xla_bridge.device_count() < self.num_chains:
chain_method = 'sequential'
warnings.warn('There are not enough devices to run parallel chains: expected {} but got {}.'
Mutant 498
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1190,7 +1190,7 @@
init_state = self._warmup_state._replace(rng_key=rng_key)
chain_method = self.chain_method
- if chain_method == 'parallel' and xla_bridge.device_count() < self.num_chains:
+ if chain_method != 'parallel' and xla_bridge.device_count() < self.num_chains:
chain_method = 'sequential'
warnings.warn('There are not enough devices to run parallel chains: expected {} but got {}.'
' Chains will be drawn sequentially. If you are running MCMC in CPU,'
Mutant 499
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1190,7 +1190,7 @@
init_state = self._warmup_state._replace(rng_key=rng_key)
chain_method = self.chain_method
- if chain_method == 'parallel' and xla_bridge.device_count() < self.num_chains:
+ if chain_method == 'XXparallelXX' and xla_bridge.device_count() < self.num_chains:
chain_method = 'sequential'
warnings.warn('There are not enough devices to run parallel chains: expected {} but got {}.'
' Chains will be drawn sequentially. If you are running MCMC in CPU,'
Mutant 502
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1198,7 +1198,7 @@
' of your program.'
.format(self.num_chains, xla_bridge.device_count(), self.num_chains))
- if init_params is not None and self.num_chains > 1:
+ if init_params is None and self.num_chains > 1:
prototype_init_val = tree_flatten(init_params)[0][0]
if jnp.shape(prototype_init_val)[0] != self.num_chains:
raise ValueError('`init_params` must have the same leading dimension'
Mutant 503
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1198,7 +1198,7 @@
' of your program.'
.format(self.num_chains, xla_bridge.device_count(), self.num_chains))
- if init_params is not None and self.num_chains > 1:
+ if init_params is not None and self.num_chains >= 1:
prototype_init_val = tree_flatten(init_params)[0][0]
if jnp.shape(prototype_init_val)[0] != self.num_chains:
raise ValueError('`init_params` must have the same leading dimension'
Mutant 504
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1198,7 +1198,7 @@
' of your program.'
.format(self.num_chains, xla_bridge.device_count(), self.num_chains))
- if init_params is not None and self.num_chains > 1:
+ if init_params is not None and self.num_chains > 2:
prototype_init_val = tree_flatten(init_params)[0][0]
if jnp.shape(prototype_init_val)[0] != self.num_chains:
raise ValueError('`init_params` must have the same leading dimension'
Mutant 505
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1198,7 +1198,7 @@
' of your program.'
.format(self.num_chains, xla_bridge.device_count(), self.num_chains))
- if init_params is not None and self.num_chains > 1:
+ if init_params is not None or self.num_chains > 1:
prototype_init_val = tree_flatten(init_params)[0][0]
if jnp.shape(prototype_init_val)[0] != self.num_chains:
raise ValueError('`init_params` must have the same leading dimension'
Mutant 519
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1254,7 +1254,7 @@
but can be any :func:`jaxlib.pytree`, more generally (e.g. when defining a
`potential_fn` for HMC that takes `list` args).
"""
- return self._states['z'] if group_by_chain else self._states_flat['z']
+ return self._states['XXzXX'] if group_by_chain else self._states_flat['z']
def get_extra_fields(self, group_by_chain=False):
"""
Mutant 521
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1256,7 +1256,7 @@
"""
return self._states['z'] if group_by_chain else self._states_flat['z']
- def get_extra_fields(self, group_by_chain=False):
+ def get_extra_fields(self, group_by_chain=True):
"""
Get extra fields from the MCMC run.
Mutant 523
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1266,7 +1266,7 @@
`extra_fields` keyword of :meth:`run`.
"""
states = self._states if group_by_chain else self._states_flat
- return {k: v for k, v in states.items() if k != 'z'}
+ return {k: v for k, v in states.items() if k == 'z'}
def print_summary(self, prob=0.9, exclude_deterministic=True):
# Exclude deterministic sites by default
Mutant 524
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1266,7 +1266,7 @@
`extra_fields` keyword of :meth:`run`.
"""
states = self._states if group_by_chain else self._states_flat
- return {k: v for k, v in states.items() if k != 'z'}
+ return {k: v for k, v in states.items() if k != 'XXzXX'}
def print_summary(self, prob=0.9, exclude_deterministic=True):
# Exclude deterministic sites by default
Mutant 529
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1271,7 +1271,7 @@
def print_summary(self, prob=0.9, exclude_deterministic=True):
# Exclude deterministic sites by default
sites = self._states['z']
- if isinstance(sites, dict) and exclude_deterministic:
+ if isinstance(sites, dict) or exclude_deterministic:
sites = {k: v for k, v in self._states['z'].items() if k in self._last_state.z}
print_summary(sites, prob=prob)
extra_fields = self.get_extra_fields()
Mutant 534
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1275,6 +1275,6 @@
sites = {k: v for k, v in self._states['z'].items() if k in self._last_state.z}
print_summary(sites, prob=prob)
extra_fields = self.get_extra_fields()
- if 'diverging' in extra_fields:
+ if 'XXdivergingXX' in extra_fields:
print("Number of divergences: {}".format(jnp.sum(extra_fields['diverging'])))
Mutant 535
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1275,6 +1275,6 @@
sites = {k: v for k, v in self._states['z'].items() if k in self._last_state.z}
print_summary(sites, prob=prob)
extra_fields = self.get_extra_fields()
- if 'diverging' in extra_fields:
+ if 'diverging' not in extra_fields:
print("Number of divergences: {}".format(jnp.sum(extra_fields['diverging'])))
Mutant 536
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1276,5 +1276,5 @@
print_summary(sites, prob=prob)
extra_fields = self.get_extra_fields()
if 'diverging' in extra_fields:
- print("Number of divergences: {}".format(jnp.sum(extra_fields['diverging'])))
-
+ print("XXNumber of divergences: {}XX".format(jnp.sum(extra_fields['diverging'])))
+
Suspicious
Mutants that made the test suite take longer, but otherwise seemed OK.
Mutant 44
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -187,7 +187,7 @@
adapt_step_size=True,
adapt_mass_matrix=True,
dense_mass=False,
- target_accept_prob=0.8,
+ target_accept_prob=1.8,
trajectory_length=2*math.pi,
max_tree_depth=10,
find_heuristic_step_size=False,
Mutant 85
--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -283,7 +283,7 @@
lambda i, val: vv_update(step_size, inverse_mass_matrix, val),
vv_state)
energy_old = vv_state.potential_energy + kinetic_fn(inverse_mass_matrix, vv_state.r)
- energy_new = vv_state_new.potential_energy + kinetic_fn(inverse_mass_matrix, vv_state_new.r)
+ energy_new = vv_state_new.potential_energy - kinetic_fn(inverse_mass_matrix, vv_state_new.r)
delta_energy = energy_new - energy_old
delta_energy = jnp.where(jnp.isnan(delta_energy), jnp.inf, delta_energy)
accept_prob = jnp.clip(jnp.exp(-delta_energy), a_max=1.0)