numpyro/infer/mcmc.py

Killed 327 out of 537 mutants

Timeouts

Mutants that made the test suite run so much longer than normal that the test run exceeded its timeout and was killed.

Mutant 83

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -282,7 +282,7 @@
         vv_state_new = fori_loop(0, num_steps,
                                  lambda i, val: vv_update(step_size, inverse_mass_matrix, val),
                                  vv_state)
-        energy_old = vv_state.potential_energy + kinetic_fn(inverse_mass_matrix, vv_state.r)
+        energy_old = vv_state.potential_energy - kinetic_fn(inverse_mass_matrix, vv_state.r)
         energy_new = vv_state_new.potential_energy + kinetic_fn(inverse_mass_matrix, vv_state_new.r)
         delta_energy = energy_new - energy_old
         delta_energy = jnp.where(jnp.isnan(delta_energy), jnp.inf, delta_energy)

Mutant 87

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -284,7 +284,7 @@
                                  vv_state)
         energy_old = vv_state.potential_energy + kinetic_fn(inverse_mass_matrix, vv_state.r)
         energy_new = vv_state_new.potential_energy + kinetic_fn(inverse_mass_matrix, vv_state_new.r)
-        delta_energy = energy_new - energy_old
+        delta_energy = energy_new + energy_old
         delta_energy = jnp.where(jnp.isnan(delta_energy), jnp.inf, delta_energy)
         accept_prob = jnp.clip(jnp.exp(-delta_energy), a_max=1.0)
         diverging = delta_energy > max_delta_energy

Mutant 135

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -454,7 +454,7 @@
                  adapt_step_size=True,
                  adapt_mass_matrix=True,
                  dense_mass=False,
-                 target_accept_prob=0.8,
+                 target_accept_prob=1.8,
                  trajectory_length=2 * math.pi,
                  init_strategy=init_to_uniform,
                  find_heuristic_step_size=False):

Survived

These mutants survived mutation testing: the test suite still passed with them applied, which reveals gaps in test coverage.

Mutant 1

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -33,7 +33,7 @@
 from numpyro.infer.util import ParamInfo, init_to_uniform, initialize_model
 from numpyro.util import cond, copy_docs_from, fori_collect, fori_loop, identity, cached_by
 
-HMCState = namedtuple('HMCState', ['i', 'z', 'z_grad', 'potential_energy', 'energy', 'num_steps', 'accept_prob',
+HMCState = namedtuple('XXHMCStateXX', ['i', 'z', 'z_grad', 'potential_energy', 'energy', 'num_steps', 'accept_prob',
                                    'mean_accept_prob', 'diverging', 'adapt_state', 'rng_key'])
 """
 A :func:`~collections.namedtuple` consisting of the following fields:

Mutant 6

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -33,7 +33,7 @@
 from numpyro.infer.util import ParamInfo, init_to_uniform, initialize_model
 from numpyro.util import cond, copy_docs_from, fori_collect, fori_loop, identity, cached_by
 
-HMCState = namedtuple('HMCState', ['i', 'z', 'z_grad', 'potential_energy', 'energy', 'num_steps', 'accept_prob',
+HMCState = namedtuple('HMCState', ['i', 'z', 'z_grad', 'potential_energy', 'XXenergyXX', 'num_steps', 'accept_prob',
                                    'mean_accept_prob', 'diverging', 'adapt_state', 'rng_key'])
 """
 A :func:`~collections.namedtuple` consisting of the following fields:

Mutant 8

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -33,7 +33,7 @@
 from numpyro.infer.util import ParamInfo, init_to_uniform, initialize_model
 from numpyro.util import cond, copy_docs_from, fori_collect, fori_loop, identity, cached_by
 
-HMCState = namedtuple('HMCState', ['i', 'z', 'z_grad', 'potential_energy', 'energy', 'num_steps', 'accept_prob',
+HMCState = namedtuple('HMCState', ['i', 'z', 'z_grad', 'potential_energy', 'energy', 'num_steps', 'XXaccept_probXX',
                                    'mean_accept_prob', 'diverging', 'adapt_state', 'rng_key'])
 """
 A :func:`~collections.namedtuple` consisting of the following fields:

Mutant 14

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -65,7 +65,7 @@
 
 
 def _get_num_steps(step_size, trajectory_length):
-    num_steps = jnp.clip(trajectory_length / step_size, a_min=1)
+    num_steps = jnp.clip(trajectory_length * step_size, a_min=1)
     # NB: casting to jnp.int64 does not take effect (returns jnp.int32 instead)
     # if jax_enable_x64 is False
     return num_steps.astype(canonicalize_dtype(jnp.int64))

Mutant 15

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -65,7 +65,7 @@
 
 
 def _get_num_steps(step_size, trajectory_length):
-    num_steps = jnp.clip(trajectory_length / step_size, a_min=1)
+    num_steps = jnp.clip(trajectory_length / step_size, a_min=2)
     # NB: casting to jnp.int64 does not take effect (returns jnp.int32 instead)
     # if jax_enable_x64 is False
     return num_steps.astype(canonicalize_dtype(jnp.int64))

Mutant 18

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -73,7 +73,7 @@
 
 def momentum_generator(prototype_r, mass_matrix_sqrt, rng_key):
     _, unpack_fn = ravel_pytree(prototype_r)
-    eps = random.normal(rng_key, jnp.shape(mass_matrix_sqrt)[:1])
+    eps = random.normal(rng_key, jnp.shape(mass_matrix_sqrt)[:2])
     if mass_matrix_sqrt.ndim == 1:
         r = jnp.multiply(mass_matrix_sqrt, eps)
         return unpack_fn(r)

Mutant 23

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -86,7 +86,7 @@
 
 def get_diagnostics_str(mcmc_state):
     if isinstance(mcmc_state, HMCState):
-        return '{} steps of size {:.2e}. acc. prob={:.2f}'.format(mcmc_state.num_steps,
+        return 'XX{} steps of size {:.2e}. acc. prob={:.2f}XX'.format(mcmc_state.num_steps,
                                                                   mcmc_state.adapt_state.step_size,
                                                                   mcmc_state.mean_accept_prob)
     else:

Mutant 24

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -94,7 +94,7 @@
 
 
 def get_progbar_desc_str(num_warmup, i):
-    if i < num_warmup:
+    if i <= num_warmup:
         return 'warmup'
     return 'sample'
 

Mutant 25

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -95,7 +95,7 @@
 
 def get_progbar_desc_str(num_warmup, i):
     if i < num_warmup:
-        return 'warmup'
+        return 'XXwarmupXX'
     return 'sample'
 
 

Mutant 26

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -96,7 +96,7 @@
 def get_progbar_desc_str(num_warmup, i):
     if i < num_warmup:
         return 'warmup'
-    return 'sample'
+    return 'XXsampleXX'
 
 
 def hmc(potential_fn=None, potential_fn_gen=None, kinetic_fn=None, algo='NUTS'):

Mutant 27

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -99,7 +99,7 @@
     return 'sample'
 
 
-def hmc(potential_fn=None, potential_fn_gen=None, kinetic_fn=None, algo='NUTS'):
+def hmc(potential_fn=None, potential_fn_gen=None, kinetic_fn=None, algo='XXNUTSXX'):
     r"""
     Hamiltonian Monte Carlo inference, using either fixed number of
     steps or the No U-Turn Sampler (NUTS) with adaptive path length.

Mutant 30

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -171,7 +171,7 @@
     """
     if kinetic_fn is None:
         kinetic_fn = euclidean_kinetic_energy
-    vv_update = None
+    vv_update = ""
     trajectory_len = None
     max_treedepth = None
     wa_update = None

Mutant 31

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -172,7 +172,7 @@
     if kinetic_fn is None:
         kinetic_fn = euclidean_kinetic_energy
     vv_update = None
-    trajectory_len = None
+    trajectory_len = ""
     max_treedepth = None
     wa_update = None
     wa_steps = None

Mutant 32

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -173,7 +173,7 @@
         kinetic_fn = euclidean_kinetic_energy
     vv_update = None
     trajectory_len = None
-    max_treedepth = None
+    max_treedepth = ""
     wa_update = None
     wa_steps = None
     max_delta_energy = 1000.

Mutant 33

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -174,7 +174,7 @@
     vv_update = None
     trajectory_len = None
     max_treedepth = None
-    wa_update = None
+    wa_update = ""
     wa_steps = None
     max_delta_energy = 1000.
     if algo not in {'HMC', 'NUTS'}:

Mutant 34

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -175,7 +175,7 @@
     trajectory_len = None
     max_treedepth = None
     wa_update = None
-    wa_steps = None
+    wa_steps = ""
     max_delta_energy = 1000.
     if algo not in {'HMC', 'NUTS'}:
         raise ValueError('`algo` must be one of `HMC` or `NUTS`.')

Mutant 35

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -176,7 +176,7 @@
     max_treedepth = None
     wa_update = None
     wa_steps = None
-    max_delta_energy = 1000.
+    max_delta_energy = 1001.0
     if algo not in {'HMC', 'NUTS'}:
         raise ValueError('`algo` must be one of `HMC` or `NUTS`.')
 

Mutant 40

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -182,7 +182,7 @@
 
     def init_kernel(init_params,
                     num_warmup,
-                    step_size=1.0,
+                    step_size=2.0,
                     inverse_mass_matrix=None,
                     adapt_step_size=True,
                     adapt_mass_matrix=True,

Mutant 42

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -185,7 +185,7 @@
                     step_size=1.0,
                     inverse_mass_matrix=None,
                     adapt_step_size=True,
-                    adapt_mass_matrix=True,
+                    adapt_mass_matrix=False,
                     dense_mass=False,
                     target_accept_prob=0.8,
                     trajectory_length=2*math.pi,

Mutant 43

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -186,7 +186,7 @@
                     inverse_mass_matrix=None,
                     adapt_step_size=True,
                     adapt_mass_matrix=True,
-                    dense_mass=False,
+                    dense_mass=True,
                     target_accept_prob=0.8,
                     trajectory_length=2*math.pi,
                     max_tree_depth=10,

Mutant 45

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -188,7 +188,7 @@
                     adapt_mass_matrix=True,
                     dense_mass=False,
                     target_accept_prob=0.8,
-                    trajectory_length=2*math.pi,
+                    trajectory_length=3*math.pi,
                     max_tree_depth=10,
                     find_heuristic_step_size=False,
                     model_args=(),

Mutant 46

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -188,7 +188,7 @@
                     adapt_mass_matrix=True,
                     dense_mass=False,
                     target_accept_prob=0.8,
-                    trajectory_length=2*math.pi,
+                    trajectory_length=2/math.pi,
                     max_tree_depth=10,
                     find_heuristic_step_size=False,
                     model_args=(),

Mutant 47

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -189,7 +189,7 @@
                     dense_mass=False,
                     target_accept_prob=0.8,
                     trajectory_length=2*math.pi,
-                    max_tree_depth=10,
+                    max_tree_depth=11,
                     find_heuristic_step_size=False,
                     model_args=(),
                     model_kwargs=None,

Mutant 48

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -190,7 +190,7 @@
                     target_accept_prob=0.8,
                     trajectory_length=2*math.pi,
                     max_tree_depth=10,
-                    find_heuristic_step_size=False,
+                    find_heuristic_step_size=True,
                     model_args=(),
                     model_kwargs=None,
                     rng_key=random.PRNGKey(0)):

Mutant 49

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -193,7 +193,7 @@
                     find_heuristic_step_size=False,
                     model_args=(),
                     model_kwargs=None,
-                    rng_key=random.PRNGKey(0)):
+                    rng_key=random.PRNGKey(1)):
         """
         Initializes the HMC sampler.
 

Mutant 57

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -242,7 +242,7 @@
             if pe_fn is not None:
                 raise ValueError('Only one of `potential_fn` or `potential_fn_gen` must be provided.')
             else:
-                kwargs = {} if model_kwargs is None else model_kwargs
+                kwargs = {} if model_kwargs is not None else model_kwargs
                 pe_fn = potential_fn_gen(*model_args, **kwargs)
 
         find_reasonable_ss = None

Mutant 61

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -247,10 +247,7 @@
 
         find_reasonable_ss = None
         if find_heuristic_step_size:
-            find_reasonable_ss = partial(find_reasonable_step_size,
-                                         pe_fn,
-                                         kinetic_fn,
-                                         momentum_generator)
+            find_reasonable_ss = None
 
         wa_init, wa_update = warmup_adapter(num_warmup,
                                             adapt_step_size=adapt_step_size,

Mutant 72

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -268,7 +268,7 @@
         vv_state = vv_init(z, r, potential_energy=pe, z_grad=z_grad)
         energy = kinetic_fn(wa_state.inverse_mass_matrix, vv_state.r)
         hmc_state = HMCState(0, vv_state.z, vv_state.z_grad, vv_state.potential_energy, energy,
-                             0, 0., 0., False, wa_state, rng_key_hmc)
+                             1, 0., 0., False, wa_state, rng_key_hmc)
         return device_put(hmc_state)
 
     def _hmc_next(step_size, inverse_mass_matrix, vv_state,

Mutant 73

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -268,7 +268,7 @@
         vv_state = vv_init(z, r, potential_energy=pe, z_grad=z_grad)
         energy = kinetic_fn(wa_state.inverse_mass_matrix, vv_state.r)
         hmc_state = HMCState(0, vv_state.z, vv_state.z_grad, vv_state.potential_energy, energy,
-                             0, 0., 0., False, wa_state, rng_key_hmc)
+                             0, 1.0, 0., False, wa_state, rng_key_hmc)
         return device_put(hmc_state)
 
     def _hmc_next(step_size, inverse_mass_matrix, vv_state,

Mutant 74

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -268,7 +268,7 @@
         vv_state = vv_init(z, r, potential_energy=pe, z_grad=z_grad)
         energy = kinetic_fn(wa_state.inverse_mass_matrix, vv_state.r)
         hmc_state = HMCState(0, vv_state.z, vv_state.z_grad, vv_state.potential_energy, energy,
-                             0, 0., 0., False, wa_state, rng_key_hmc)
+                             0, 0., 1.0, False, wa_state, rng_key_hmc)
         return device_put(hmc_state)
 
     def _hmc_next(step_size, inverse_mass_matrix, vv_state,

Mutant 75

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -268,7 +268,7 @@
         vv_state = vv_init(z, r, potential_energy=pe, z_grad=z_grad)
         energy = kinetic_fn(wa_state.inverse_mass_matrix, vv_state.r)
         hmc_state = HMCState(0, vv_state.z, vv_state.z_grad, vv_state.potential_energy, energy,
-                             0, 0., 0., False, wa_state, rng_key_hmc)
+                             0, 0., 0., True, wa_state, rng_key_hmc)
         return device_put(hmc_state)
 
     def _hmc_next(step_size, inverse_mass_matrix, vv_state,

Mutant 91

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -286,7 +286,7 @@
         energy_new = vv_state_new.potential_energy + kinetic_fn(inverse_mass_matrix, vv_state_new.r)
         delta_energy = energy_new - energy_old
         delta_energy = jnp.where(jnp.isnan(delta_energy), jnp.inf, delta_energy)
-        accept_prob = jnp.clip(jnp.exp(-delta_energy), a_max=1.0)
+        accept_prob = jnp.clip(jnp.exp(-delta_energy), a_max=2.0)
         diverging = delta_energy > max_delta_energy
         transition = random.bernoulli(rng_key, accept_prob)
         vv_state, energy = cond(transition,

Mutant 93

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -287,7 +287,7 @@
         delta_energy = energy_new - energy_old
         delta_energy = jnp.where(jnp.isnan(delta_energy), jnp.inf, delta_energy)
         accept_prob = jnp.clip(jnp.exp(-delta_energy), a_max=1.0)
-        diverging = delta_energy > max_delta_energy
+        diverging = delta_energy >= max_delta_energy
         transition = random.bernoulli(rng_key, accept_prob)
         vv_state, energy = cond(transition,
                                 (vv_state_new, energy_new), identity,

Mutant 107

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -327,7 +327,7 @@
             Hamiltonian dynamics given existing state.
 
         """
-        model_kwargs = {} if model_kwargs is None else model_kwargs
+        model_kwargs = {} if model_kwargs is not None else model_kwargs
         rng_key, rng_key_momentum, rng_key_transition = random.split(hmc_state.rng_key, 3)
         r = momentum_generator(hmc_state.z, hmc_state.adapt_state.mass_matrix_sqrt, rng_key_momentum)
         vv_state = IntegratorState(hmc_state.z, r, hmc_state.potential_energy, hmc_state.z_grad)

Mutant 120

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -345,7 +345,7 @@
                            identity)
 
         itr = hmc_state.i + 1
-        n = jnp.where(hmc_state.i < wa_steps, itr, itr - wa_steps)
+        n = jnp.where(hmc_state.i <= wa_steps, itr, itr - wa_steps)
         mean_accept_prob = hmc_state.mean_accept_prob + (accept_prob - hmc_state.mean_accept_prob) / n
 
         return HMCState(itr, vv_state.z, vv_state.z_grad, vv_state.potential_energy, energy, num_steps,

Mutant 121

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -345,7 +345,7 @@
                            identity)
 
         itr = hmc_state.i + 1
-        n = jnp.where(hmc_state.i < wa_steps, itr, itr - wa_steps)
+        n = jnp.where(hmc_state.i < wa_steps, itr, itr + wa_steps)
         mean_accept_prob = hmc_state.mean_accept_prob + (accept_prob - hmc_state.mean_accept_prob) / n
 
         return HMCState(itr, vv_state.z, vv_state.z_grad, vv_state.potential_energy, energy, num_steps,

Mutant 123

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -346,7 +346,7 @@
 
         itr = hmc_state.i + 1
         n = jnp.where(hmc_state.i < wa_steps, itr, itr - wa_steps)
-        mean_accept_prob = hmc_state.mean_accept_prob + (accept_prob - hmc_state.mean_accept_prob) / n
+        mean_accept_prob = hmc_state.mean_accept_prob - (accept_prob - hmc_state.mean_accept_prob) / n
 
         return HMCState(itr, vv_state.z, vv_state.z_grad, vv_state.potential_energy, energy, num_steps,
                         accept_prob, mean_accept_prob, diverging, adapt_state, rng_key)

Mutant 124

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -346,7 +346,7 @@
 
         itr = hmc_state.i + 1
         n = jnp.where(hmc_state.i < wa_steps, itr, itr - wa_steps)
-        mean_accept_prob = hmc_state.mean_accept_prob + (accept_prob - hmc_state.mean_accept_prob) / n
+        mean_accept_prob = hmc_state.mean_accept_prob + (accept_prob + hmc_state.mean_accept_prob) / n
 
         return HMCState(itr, vv_state.z, vv_state.z_grad, vv_state.potential_energy, energy, num_steps,
                         accept_prob, mean_accept_prob, diverging, adapt_state, rng_key)

Mutant 125

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -346,7 +346,7 @@
 
         itr = hmc_state.i + 1
         n = jnp.where(hmc_state.i < wa_steps, itr, itr - wa_steps)
-        mean_accept_prob = hmc_state.mean_accept_prob + (accept_prob - hmc_state.mean_accept_prob) / n
+        mean_accept_prob = hmc_state.mean_accept_prob + (accept_prob - hmc_state.mean_accept_prob) * n
 
         return HMCState(itr, vv_state.z, vv_state.z_grad, vv_state.potential_energy, energy, num_steps,
                         accept_prob, mean_accept_prob, diverging, adapt_state, rng_key)

Mutant 127

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -353,7 +353,7 @@
 
     # Make `init_kernel` and `sample_kernel` visible from the global scope once
     # `hmc` is called for sphinx doc generation.
-    if 'SPHINX_BUILD' in os.environ:
+    if 'XXSPHINX_BUILDXX' in os.environ:
         hmc.init_kernel = init_kernel
         hmc.sample_kernel = sample_kernel
 

Mutant 128

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -353,7 +353,7 @@
 
     # Make `init_kernel` and `sample_kernel` visible from the global scope once
     # `hmc` is called for sphinx doc generation.
-    if 'SPHINX_BUILD' in os.environ:
+    if 'SPHINX_BUILD' not in os.environ:
         hmc.init_kernel = init_kernel
         hmc.sample_kernel = sample_kernel
 

Mutant 129

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -376,7 +376,6 @@
         """
         return identity
 
-    @abstractmethod
     def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
         """
         Initialize the `MCMCKernel` and return an initial state to begin sampling

Mutant 130

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -393,7 +393,6 @@
         """
         raise NotImplementedError
 
-    @abstractmethod
     def sample(self, state, model_args, model_kwargs):
         """
         Given the current `state`, return the next `state` using the given

Mutant 131

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -450,7 +450,7 @@
                  model=None,
                  potential_fn=None,
                  kinetic_fn=None,
-                 step_size=1.0,
+                 step_size=2.0,
                  adapt_step_size=True,
                  adapt_mass_matrix=True,
                  dense_mass=False,

Mutant 133

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -452,7 +452,7 @@
                  kinetic_fn=None,
                  step_size=1.0,
                  adapt_step_size=True,
-                 adapt_mass_matrix=True,
+                 adapt_mass_matrix=False,
                  dense_mass=False,
                  target_accept_prob=0.8,
                  trajectory_length=2 * math.pi,

Mutant 134

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -453,7 +453,7 @@
                  step_size=1.0,
                  adapt_step_size=True,
                  adapt_mass_matrix=True,
-                 dense_mass=False,
+                 dense_mass=True,
                  target_accept_prob=0.8,
                  trajectory_length=2 * math.pi,
                  init_strategy=init_to_uniform,

Mutant 136

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -455,7 +455,7 @@
                  adapt_mass_matrix=True,
                  dense_mass=False,
                  target_accept_prob=0.8,
-                 trajectory_length=2 * math.pi,
+                 trajectory_length=3 * math.pi,
                  init_strategy=init_to_uniform,
                  find_heuristic_step_size=False):
         if not (model is None) ^ (potential_fn is None):

Mutant 137

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -455,7 +455,7 @@
                  adapt_mass_matrix=True,
                  dense_mass=False,
                  target_accept_prob=0.8,
-                 trajectory_length=2 * math.pi,
+                 trajectory_length=2 / math.pi,
                  init_strategy=init_to_uniform,
                  find_heuristic_step_size=False):
         if not (model is None) ^ (potential_fn is None):

Mutant 138

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -457,7 +457,7 @@
                  target_accept_prob=0.8,
                  trajectory_length=2 * math.pi,
                  init_strategy=init_to_uniform,
-                 find_heuristic_step_size=False):
+                 find_heuristic_step_size=True):
         if not (model is None) ^ (potential_fn is None):
             raise ValueError('Only one of `model` or `potential_fn` must be specified.')
         self._model = model

Mutant 144

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -461,7 +461,7 @@
         if not (model is None) ^ (potential_fn is None):
             raise ValueError('Only one of `model` or `potential_fn` must be specified.')
         self._model = model
-        self._potential_fn = potential_fn
+        self._potential_fn = None
         self._kinetic_fn = kinetic_fn if kinetic_fn is not None else euclidean_kinetic_energy
         self._step_size = step_size
         self._adapt_step_size = adapt_step_size

Mutant 145

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -462,7 +462,7 @@
             raise ValueError('Only one of `model` or `potential_fn` must be specified.')
         self._model = model
         self._potential_fn = potential_fn
-        self._kinetic_fn = kinetic_fn if kinetic_fn is not None else euclidean_kinetic_energy
+        self._kinetic_fn = kinetic_fn if kinetic_fn is  None else euclidean_kinetic_energy
         self._step_size = step_size
         self._adapt_step_size = adapt_step_size
         self._adapt_mass_matrix = adapt_mass_matrix

Mutant 146

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -462,7 +462,7 @@
             raise ValueError('Only one of `model` or `potential_fn` must be specified.')
         self._model = model
         self._potential_fn = potential_fn
-        self._kinetic_fn = kinetic_fn if kinetic_fn is not None else euclidean_kinetic_energy
+        self._kinetic_fn = None
         self._step_size = step_size
         self._adapt_step_size = adapt_step_size
         self._adapt_mass_matrix = adapt_mass_matrix

Mutant 150

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -466,7 +466,7 @@
         self._step_size = step_size
         self._adapt_step_size = adapt_step_size
         self._adapt_mass_matrix = adapt_mass_matrix
-        self._dense_mass = dense_mass
+        self._dense_mass = None
         self._target_accept_prob = target_accept_prob
         self._trajectory_length = trajectory_length
         self._algo = 'HMC'

Mutant 155

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -470,7 +470,7 @@
         self._target_accept_prob = target_accept_prob
         self._trajectory_length = trajectory_length
         self._algo = 'HMC'
-        self._max_tree_depth = 10
+        self._max_tree_depth = 11
         self._init_strategy = init_strategy
         self._find_heuristic_step_size = find_heuristic_step_size
         # Set on first call to init

Mutant 156

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -470,7 +470,7 @@
         self._target_accept_prob = target_accept_prob
         self._trajectory_length = trajectory_length
         self._algo = 'HMC'
-        self._max_tree_depth = 10
+        self._max_tree_depth = None
         self._init_strategy = init_strategy
         self._find_heuristic_step_size = find_heuristic_step_size
         # Set on first call to init

Mutant 157

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -471,7 +471,7 @@
         self._trajectory_length = trajectory_length
         self._algo = 'HMC'
         self._max_tree_depth = 10
-        self._init_strategy = init_strategy
+        self._init_strategy = None
         self._find_heuristic_step_size = find_heuristic_step_size
         # Set on first call to init
         self._init_fn = None

Mutant 158

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -472,7 +472,7 @@
         self._algo = 'HMC'
         self._max_tree_depth = 10
         self._init_strategy = init_strategy
-        self._find_heuristic_step_size = find_heuristic_step_size
+        self._find_heuristic_step_size = None
         # Set on first call to init
         self._init_fn = None
         self._postprocess_fn = None

Mutant 160

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -475,7 +475,7 @@
         self._find_heuristic_step_size = find_heuristic_step_size
         # Set on first call to init
         self._init_fn = None
-        self._postprocess_fn = None
+        self._postprocess_fn = ""
         self._sample_fn = None
 
     def _init_state(self, rng_key, model_args, model_kwargs, init_params):

Mutant 161

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -476,7 +476,7 @@
         # Set on first call to init
         self._init_fn = None
         self._postprocess_fn = None
-        self._sample_fn = None
+        self._sample_fn = ""
 
     def _init_state(self, rng_key, model_args, model_kwargs, init_params):
         if self._model is not None:

Mutant 167

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -486,7 +486,7 @@
                 dynamic_args=True,
                 model_args=model_args,
                 model_kwargs=model_kwargs)
-            if any(v['type'] == 'param' for v in model_trace.values()):
+            if any(v['type'] == 'XXparamXX' for v in model_trace.values()):
                 warnings.warn("'param' sites will be treated as constants during inference. To define "
                               "an improper variable, please use a 'sample' site with log probability "
                               "masked out. For example, `sample('x', dist.LogNormal(0, 1).mask(False)` "

Mutant 171

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -503,7 +503,6 @@
 
         return init_params
 
-    @property
     def model(self):
         return self._model
 

Mutant 172

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -507,7 +507,6 @@
     def model(self):
         return self._model
 
-    @copy_docs_from(MCMCKernel.init)
     def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwargs={}):
         # non-vectorized
         if rng_key.ndim == 1:

Mutant 177

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -516,7 +516,7 @@
         else:
             rng_key, rng_key_init_model = jnp.swapaxes(vmap(random.split)(rng_key), 0, 1)
         init_params = self._init_state(rng_key_init_model, model_args, model_kwargs, init_params)
-        if self._potential_fn and init_params is None:
+        if self._potential_fn and init_params is not None:
             raise ValueError('Valid value of `init_params` must be provided with'
                              ' `potential_fn`.')
 

Mutant 178

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -516,7 +516,7 @@
         else:
             rng_key, rng_key_init_model = jnp.swapaxes(vmap(random.split)(rng_key), 0, 1)
         init_params = self._init_state(rng_key_init_model, model_args, model_kwargs, init_params)
-        if self._potential_fn and init_params is None:
+        if self._potential_fn or init_params is None:
             raise ValueError('Valid value of `init_params` must be provided with'
                              ' `potential_fn`.')
 

Mutant 182

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -546,7 +546,6 @@
             self._sample_fn = sample_fn
         return init_state
 
-    @copy_docs_from(MCMCKernel.postprocess_fn)
     def postprocess_fn(self, args, kwargs):
         if self._postprocess_fn is None:
             return identity

Mutant 187

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -616,7 +616,7 @@
                  step_size=1.0,
                  adapt_step_size=True,
                  adapt_mass_matrix=True,
-                 dense_mass=False,
+                 dense_mass=True,
                  target_accept_prob=0.8,
                  trajectory_length=None,
                  max_tree_depth=10,

Mutant 190

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -621,7 +621,7 @@
                  trajectory_length=None,
                  max_tree_depth=10,
                  init_strategy=init_to_uniform,
-                 find_heuristic_step_size=False):
+                 find_heuristic_step_size=True):
         super(NUTS, self).__init__(potential_fn=potential_fn, model=model, kinetic_fn=kinetic_fn,
                                    step_size=step_size, adapt_step_size=adapt_step_size,
                                    adapt_mass_matrix=adapt_mass_matrix, dense_mass=dense_mass,

Mutant 198

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -639,7 +639,7 @@
     #   proposal_loc.shape[0] = proposal_loc.shape[0] = N
     # Here, we use the numerical stability procedure in Appendix 6 of [1].
     weight = 1 / samples.shape[0]
-    if scale.ndim > loc.ndim:
+    if scale.ndim >= loc.ndim:
         new_scale = cholesky_update(scale, new_sample - loc, weight)
         proposal_scale = cholesky_update(new_scale, samples - loc, -weight)
         proposal_scale = cholesky_update(proposal_scale, new_sample - samples, - (weight ** 2))

Mutant 199

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -640,7 +640,7 @@
     # Here, we use the numerical stability procedure in Appendix 6 of [1].
     weight = 1 / samples.shape[0]
     if scale.ndim > loc.ndim:
-        new_scale = cholesky_update(scale, new_sample - loc, weight)
+        new_scale = cholesky_update(scale, new_sample + loc, weight)
         proposal_scale = cholesky_update(new_scale, samples - loc, -weight)
         proposal_scale = cholesky_update(proposal_scale, new_sample - samples, - (weight ** 2))
     else:

Mutant 202

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -641,7 +641,7 @@
     weight = 1 / samples.shape[0]
     if scale.ndim > loc.ndim:
         new_scale = cholesky_update(scale, new_sample - loc, weight)
-        proposal_scale = cholesky_update(new_scale, samples - loc, -weight)
+        proposal_scale = cholesky_update(new_scale, samples - loc, +weight)
         proposal_scale = cholesky_update(proposal_scale, new_sample - samples, - (weight ** 2))
     else:
         var = jnp.square(scale) + weight * jnp.square(new_sample - loc)

Mutant 205

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -642,7 +642,7 @@
     if scale.ndim > loc.ndim:
         new_scale = cholesky_update(scale, new_sample - loc, weight)
         proposal_scale = cholesky_update(new_scale, samples - loc, -weight)
-        proposal_scale = cholesky_update(proposal_scale, new_sample - samples, - (weight ** 2))
+        proposal_scale = cholesky_update(proposal_scale, new_sample - samples, + (weight ** 2))
     else:
         var = jnp.square(scale) + weight * jnp.square(new_sample - loc)
         proposal_var = var - weight * jnp.square(samples - loc)

Mutant 207

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -642,7 +642,7 @@
     if scale.ndim > loc.ndim:
         new_scale = cholesky_update(scale, new_sample - loc, weight)
         proposal_scale = cholesky_update(new_scale, samples - loc, -weight)
-        proposal_scale = cholesky_update(proposal_scale, new_sample - samples, - (weight ** 2))
+        proposal_scale = cholesky_update(proposal_scale, new_sample - samples, - (weight ** 3))
     else:
         var = jnp.square(scale) + weight * jnp.square(new_sample - loc)
         proposal_var = var - weight * jnp.square(samples - loc)

Mutant 220

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -658,7 +658,7 @@
     if inv_mass_matrix_sqrt.ndim == 1:
         r = jnp.multiply(inv_mass_matrix_sqrt, eps)
     elif inv_mass_matrix_sqrt.ndim == 2:
-        r = jnp.matmul(inv_mass_matrix_sqrt, eps[..., None])[..., 0]
+        r = jnp.matmul(inv_mass_matrix_sqrt, eps[..., None])[..., 1]
     else:
         raise ValueError("Mass matrix has incorrect number of dims.")
     return r

Mutant 222

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -668,7 +668,7 @@
 # because we might lose precision after many iterations of using _get_proposal_loc_and_scale;
 # If we recompute, we don't need to store `loc` and `inv_mass_matrix_sqrt` here.
 # We may also update those values every 10D iterations...
-SAAdaptState = namedtuple('SAAdaptState', ['zs', 'pes', 'loc', 'inv_mass_matrix_sqrt'])
+SAAdaptState = namedtuple('XXSAAdaptStateXX', ['zs', 'pes', 'loc', 'inv_mass_matrix_sqrt'])
 SAState = namedtuple('SAState', ['i', 'z', 'potential_energy', 'accept_prob',
                                  'mean_accept_prob', 'diverging', 'adapt_state', 'rng_key'])
 """

Mutant 223

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -668,7 +668,7 @@
 # because we might lose precision after many iterations of using _get_proposal_loc_and_scale;
 # If we recompute, we don't need to store `loc` and `inv_mass_matrix_sqrt` here.
 # We may also update those values every 10D iterations...
-SAAdaptState = namedtuple('SAAdaptState', ['zs', 'pes', 'loc', 'inv_mass_matrix_sqrt'])
+SAAdaptState = namedtuple('SAAdaptState', ['XXzsXX', 'pes', 'loc', 'inv_mass_matrix_sqrt'])
 SAState = namedtuple('SAState', ['i', 'z', 'potential_energy', 'accept_prob',
                                  'mean_accept_prob', 'diverging', 'adapt_state', 'rng_key'])
 """

Mutant 224

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -668,7 +668,7 @@
 # because we might lose precision after many iterations of using _get_proposal_loc_and_scale;
 # If we recompute, we don't need to store `loc` and `inv_mass_matrix_sqrt` here.
 # We may also update those values every 10D iterations...
-SAAdaptState = namedtuple('SAAdaptState', ['zs', 'pes', 'loc', 'inv_mass_matrix_sqrt'])
+SAAdaptState = namedtuple('SAAdaptState', ['zs', 'XXpesXX', 'loc', 'inv_mass_matrix_sqrt'])
 SAState = namedtuple('SAState', ['i', 'z', 'potential_energy', 'accept_prob',
                                  'mean_accept_prob', 'diverging', 'adapt_state', 'rng_key'])
 """

Mutant 225

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -668,7 +668,7 @@
 # because we might lose precision after many iterations of using _get_proposal_loc_and_scale;
 # If we recompute, we don't need to store `loc` and `inv_mass_matrix_sqrt` here.
 # We may also update those values every 10D iterations...
-SAAdaptState = namedtuple('SAAdaptState', ['zs', 'pes', 'loc', 'inv_mass_matrix_sqrt'])
+SAAdaptState = namedtuple('SAAdaptState', ['zs', 'pes', 'XXlocXX', 'inv_mass_matrix_sqrt'])
 SAState = namedtuple('SAState', ['i', 'z', 'potential_energy', 'accept_prob',
                                  'mean_accept_prob', 'diverging', 'adapt_state', 'rng_key'])
 """

Mutant 226

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -668,7 +668,7 @@
 # because we might lose precision after many iterations of using _get_proposal_loc_and_scale;
 # If we recompute, we don't need to store `loc` and `inv_mass_matrix_sqrt` here.
 # We may also update those values every 10D iterations...
-SAAdaptState = namedtuple('SAAdaptState', ['zs', 'pes', 'loc', 'inv_mass_matrix_sqrt'])
+SAAdaptState = namedtuple('SAAdaptState', ['zs', 'pes', 'loc', 'XXinv_mass_matrix_sqrtXX'])
 SAState = namedtuple('SAState', ['i', 'z', 'potential_energy', 'accept_prob',
                                  'mean_accept_prob', 'diverging', 'adapt_state', 'rng_key'])
 """

Mutant 228

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -669,7 +669,7 @@
 # If we recompute, we don't need to store `loc` and `inv_mass_matrix_sqrt` here.
 # We may also update those values every 10D iterations...
 SAAdaptState = namedtuple('SAAdaptState', ['zs', 'pes', 'loc', 'inv_mass_matrix_sqrt'])
-SAState = namedtuple('SAState', ['i', 'z', 'potential_energy', 'accept_prob',
+SAState = namedtuple('XXSAStateXX', ['i', 'z', 'potential_energy', 'accept_prob',
                                  'mean_accept_prob', 'diverging', 'adapt_state', 'rng_key'])
 """
 A :func:`~collections.namedtuple` used in Sample Adaptive MCMC.

Mutant 232

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -669,7 +669,7 @@
 # If we recompute, we don't need to store `loc` and `inv_mass_matrix_sqrt` here.
 # We may also update those values every 10D iterations...
 SAAdaptState = namedtuple('SAAdaptState', ['zs', 'pes', 'loc', 'inv_mass_matrix_sqrt'])
-SAState = namedtuple('SAState', ['i', 'z', 'potential_energy', 'accept_prob',
+SAState = namedtuple('SAState', ['i', 'z', 'potential_energy', 'XXaccept_probXX',
                                  'mean_accept_prob', 'diverging', 'adapt_state', 'rng_key'])
 """
 A :func:`~collections.namedtuple` used in Sample Adaptive MCMC.

Mutant 271

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -732,7 +732,7 @@
         z = init_params
         z_flat, unravel_fn = ravel_pytree(z)
         if inverse_mass_matrix is None:
-            inverse_mass_matrix = jnp.identity(z_flat.shape[-1]) if dense_mass else jnp.ones(z_flat.shape[-1])
+            inverse_mass_matrix = jnp.identity(z_flat.shape[-1]) if dense_mass else jnp.ones(z_flat.shape[+1])
         inv_mass_matrix_sqrt = jnp.linalg.cholesky(inverse_mass_matrix) if dense_mass \
             else jnp.sqrt(inverse_mass_matrix)
         if adapt_state_size is None:

Mutant 272

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -732,7 +732,7 @@
         z = init_params
         z_flat, unravel_fn = ravel_pytree(z)
         if inverse_mass_matrix is None:
-            inverse_mass_matrix = jnp.identity(z_flat.shape[-1]) if dense_mass else jnp.ones(z_flat.shape[-1])
+            inverse_mass_matrix = jnp.identity(z_flat.shape[-1]) if dense_mass else jnp.ones(z_flat.shape[-2])
         inv_mass_matrix_sqrt = jnp.linalg.cholesky(inverse_mass_matrix) if dense_mass \
             else jnp.sqrt(inverse_mass_matrix)
         if adapt_state_size is None:

Mutant 275

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -737,7 +737,7 @@
             else jnp.sqrt(inverse_mass_matrix)
         if adapt_state_size is None:
             # XXX: heuristic choice
-            adapt_state_size = 2 * z_flat.shape[-1]
+            adapt_state_size = 3 * z_flat.shape[-1]
         else:
             assert adapt_state_size > 1, 'adapt_state_size should be greater than 1.'
         # NB: mean is init_params

Mutant 280

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -739,7 +739,7 @@
             # XXX: heuristic choice
             adapt_state_size = 2 * z_flat.shape[-1]
         else:
-            assert adapt_state_size > 1, 'adapt_state_size should be greater than 1.'
+            assert adapt_state_size >= 1, 'adapt_state_size should be greater than 1.'
         # NB: mean is init_params
         zs = z_flat + _sample_proposal(inv_mass_matrix_sqrt, rng_key_zs, (adapt_state_size,))
         # compute potential energies

Mutant 281

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -739,7 +739,7 @@
             # XXX: heuristic choice
             adapt_state_size = 2 * z_flat.shape[-1]
         else:
-            assert adapt_state_size > 1, 'adapt_state_size should be greater than 1.'
+            assert adapt_state_size > 2, 'adapt_state_size should be greater than 1.'
         # NB: mean is init_params
         zs = z_flat + _sample_proposal(inv_mass_matrix_sqrt, rng_key_zs, (adapt_state_size,))
         # compute potential energies

Mutant 282

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -739,7 +739,7 @@
             # XXX: heuristic choice
             adapt_state_size = 2 * z_flat.shape[-1]
         else:
-            assert adapt_state_size > 1, 'adapt_state_size should be greater than 1.'
+            assert adapt_state_size > 1, 'XXadapt_state_size should be greater than 1.XX'
         # NB: mean is init_params
         zs = z_flat + _sample_proposal(inv_mass_matrix_sqrt, rng_key_zs, (adapt_state_size,))
         # compute potential energies

Mutant 283

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -741,7 +741,7 @@
         else:
             assert adapt_state_size > 1, 'adapt_state_size should be greater than 1.'
         # NB: mean is init_params
-        zs = z_flat + _sample_proposal(inv_mass_matrix_sqrt, rng_key_zs, (adapt_state_size,))
+        zs = z_flat - _sample_proposal(inv_mass_matrix_sqrt, rng_key_zs, (adapt_state_size,))
         # compute potential energies
         pes = lax.map(lambda z: pe_fn(unravel_fn(z)), zs)
         if dense_mass:

Mutant 288

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -745,7 +745,7 @@
         # compute potential energies
         pes = lax.map(lambda z: pe_fn(unravel_fn(z)), zs)
         if dense_mass:
-            cov = jnp.cov(zs, rowvar=False, bias=True)
+            cov = jnp.cov(zs, rowvar=False, bias=False)
             if cov.shape == ():  # JAX returns scalar for 1D input
                 cov = cov.reshape((1, 1))
             inv_mass_matrix_sqrt = jnp.linalg.cholesky(cov)

Mutant 294

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -752,7 +752,7 @@
         else:
             inv_mass_matrix_sqrt = jnp.std(zs, 0)
         adapt_state = SAAdaptState(zs, pes, jnp.mean(zs, 0), inv_mass_matrix_sqrt)
-        k = random.categorical(rng_key_z, jnp.zeros(zs.shape[0]))
+        k = random.categorical(rng_key_z, jnp.zeros(zs.shape[1]))
         z = unravel_fn(zs[k])
         pe = pes[k]
         sa_state = SAState(0, z, pe, 0., 0., False, adapt_state, rng_key_sa)

Mutant 298

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -755,7 +755,7 @@
         k = random.categorical(rng_key_z, jnp.zeros(zs.shape[0]))
         z = unravel_fn(zs[k])
         pe = pes[k]
-        sa_state = SAState(0, z, pe, 0., 0., False, adapt_state, rng_key_sa)
+        sa_state = SAState(1, z, pe, 0., 0., False, adapt_state, rng_key_sa)
         return device_put(sa_state)
 
     def sample_kernel(sa_state, model_args=(), model_kwargs=None):

Mutant 299

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -755,7 +755,7 @@
         k = random.categorical(rng_key_z, jnp.zeros(zs.shape[0]))
         z = unravel_fn(zs[k])
         pe = pes[k]
-        sa_state = SAState(0, z, pe, 0., 0., False, adapt_state, rng_key_sa)
+        sa_state = SAState(0, z, pe, 1.0, 0., False, adapt_state, rng_key_sa)
         return device_put(sa_state)
 
     def sample_kernel(sa_state, model_args=(), model_kwargs=None):

Mutant 300

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -755,7 +755,7 @@
         k = random.categorical(rng_key_z, jnp.zeros(zs.shape[0]))
         z = unravel_fn(zs[k])
         pe = pes[k]
-        sa_state = SAState(0, z, pe, 0., 0., False, adapt_state, rng_key_sa)
+        sa_state = SAState(0, z, pe, 0., 1.0, False, adapt_state, rng_key_sa)
         return device_put(sa_state)
 
     def sample_kernel(sa_state, model_args=(), model_kwargs=None):

Mutant 301

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -755,7 +755,7 @@
         k = random.categorical(rng_key_z, jnp.zeros(zs.shape[0]))
         z = unravel_fn(zs[k])
         pe = pes[k]
-        sa_state = SAState(0, z, pe, 0., 0., False, adapt_state, rng_key_sa)
+        sa_state = SAState(0, z, pe, 0., 0., True, adapt_state, rng_key_sa)
         return device_put(sa_state)
 
     def sample_kernel(sa_state, model_args=(), model_kwargs=None):

Mutant 309

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -766,7 +766,7 @@
         rng_key, rng_key_z, rng_key_reject, rng_key_accept = random.split(sa_state.rng_key, 4)
         _, unravel_fn = ravel_pytree(sa_state.z)
 
-        z = loc + _sample_proposal(scale, rng_key_z)
+        z = loc - _sample_proposal(scale, rng_key_z)
         pe = pe_fn(unravel_fn(z))
         pe = jnp.where(jnp.isnan(pe), jnp.inf, pe)
         diverging = (pe - sa_state.potential_energy) > max_delta_energy

Mutant 313

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -769,7 +769,7 @@
         z = loc + _sample_proposal(scale, rng_key_z)
         pe = pe_fn(unravel_fn(z))
         pe = jnp.where(jnp.isnan(pe), jnp.inf, pe)
-        diverging = (pe - sa_state.potential_energy) > max_delta_energy
+        diverging = (pe + sa_state.potential_energy) > max_delta_energy
 
         # NB: all terms having the pattern *s will have shape N x ...
         # and all terms having the pattern *s_ will have shape (N + 1) x ...

Mutant 314

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -769,7 +769,7 @@
         z = loc + _sample_proposal(scale, rng_key_z)
         pe = pe_fn(unravel_fn(z))
         pe = jnp.where(jnp.isnan(pe), jnp.inf, pe)
-        diverging = (pe - sa_state.potential_energy) > max_delta_energy
+        diverging = (pe - sa_state.potential_energy) >= max_delta_energy
 
         # NB: all terms having the pattern *s will have shape N x ...
         # and all terms having the pattern *s_ will have shape (N + 1) x ...

Mutant 333

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -792,7 +792,7 @@
         adapt_state = SAAdaptState(zs, pes, loc, scale)
 
         # NB: weights[-1] / sum(weights) is the probability of rejecting the new sample `z`.
-        accept_prob = 1 - jnp.exp(log_weights_[-1] - logsumexp(log_weights_))
+        accept_prob = 2 - jnp.exp(log_weights_[-1] - logsumexp(log_weights_))
         itr = sa_state.i + 1
         n = jnp.where(sa_state.i < wa_steps, itr, itr - wa_steps)
         mean_accept_prob = sa_state.mean_accept_prob + (accept_prob - sa_state.mean_accept_prob) / n

Mutant 334

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -792,7 +792,7 @@
         adapt_state = SAAdaptState(zs, pes, loc, scale)
 
         # NB: weights[-1] / sum(weights) is the probability of rejecting the new sample `z`.
-        accept_prob = 1 - jnp.exp(log_weights_[-1] - logsumexp(log_weights_))
+        accept_prob = 1 + jnp.exp(log_weights_[-1] - logsumexp(log_weights_))
         itr = sa_state.i + 1
         n = jnp.where(sa_state.i < wa_steps, itr, itr - wa_steps)
         mean_accept_prob = sa_state.mean_accept_prob + (accept_prob - sa_state.mean_accept_prob) / n

Mutant 335

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -792,7 +792,7 @@
         adapt_state = SAAdaptState(zs, pes, loc, scale)
 
         # NB: weights[-1] / sum(weights) is the probability of rejecting the new sample `z`.
-        accept_prob = 1 - jnp.exp(log_weights_[-1] - logsumexp(log_weights_))
+        accept_prob = 1 - jnp.exp(log_weights_[+1] - logsumexp(log_weights_))
         itr = sa_state.i + 1
         n = jnp.where(sa_state.i < wa_steps, itr, itr - wa_steps)
         mean_accept_prob = sa_state.mean_accept_prob + (accept_prob - sa_state.mean_accept_prob) / n

Mutant 336

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -792,7 +792,7 @@
         adapt_state = SAAdaptState(zs, pes, loc, scale)
 
         # NB: weights[-1] / sum(weights) is the probability of rejecting the new sample `z`.
-        accept_prob = 1 - jnp.exp(log_weights_[-1] - logsumexp(log_weights_))
+        accept_prob = 1 - jnp.exp(log_weights_[-2] - logsumexp(log_weights_))
         itr = sa_state.i + 1
         n = jnp.where(sa_state.i < wa_steps, itr, itr - wa_steps)
         mean_accept_prob = sa_state.mean_accept_prob + (accept_prob - sa_state.mean_accept_prob) / n

Mutant 337

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -792,7 +792,7 @@
         adapt_state = SAAdaptState(zs, pes, loc, scale)
 
         # NB: weights[-1] / sum(weights) is the probability of rejecting the new sample `z`.
-        accept_prob = 1 - jnp.exp(log_weights_[-1] - logsumexp(log_weights_))
+        accept_prob = 1 - jnp.exp(log_weights_[-1] + logsumexp(log_weights_))
         itr = sa_state.i + 1
         n = jnp.where(sa_state.i < wa_steps, itr, itr - wa_steps)
         mean_accept_prob = sa_state.mean_accept_prob + (accept_prob - sa_state.mean_accept_prob) / n

Mutant 339

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -793,7 +793,7 @@
 
         # NB: weights[-1] / sum(weights) is the probability of rejecting the new sample `z`.
         accept_prob = 1 - jnp.exp(log_weights_[-1] - logsumexp(log_weights_))
-        itr = sa_state.i + 1
+        itr = sa_state.i - 1
         n = jnp.where(sa_state.i < wa_steps, itr, itr - wa_steps)
         mean_accept_prob = sa_state.mean_accept_prob + (accept_prob - sa_state.mean_accept_prob) / n
 

Mutant 340

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -793,7 +793,7 @@
 
         # NB: weights[-1] / sum(weights) is the probability of rejecting the new sample `z`.
         accept_prob = 1 - jnp.exp(log_weights_[-1] - logsumexp(log_weights_))
-        itr = sa_state.i + 1
+        itr = sa_state.i + 2
         n = jnp.where(sa_state.i < wa_steps, itr, itr - wa_steps)
         mean_accept_prob = sa_state.mean_accept_prob + (accept_prob - sa_state.mean_accept_prob) / n
 

Mutant 342

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -794,7 +794,7 @@
         # NB: weights[-1] / sum(weights) is the probability of rejecting the new sample `z`.
         accept_prob = 1 - jnp.exp(log_weights_[-1] - logsumexp(log_weights_))
         itr = sa_state.i + 1
-        n = jnp.where(sa_state.i < wa_steps, itr, itr - wa_steps)
+        n = jnp.where(sa_state.i <= wa_steps, itr, itr - wa_steps)
         mean_accept_prob = sa_state.mean_accept_prob + (accept_prob - sa_state.mean_accept_prob) / n
 
         # XXX: we make a modification of SA sampler in [1]

Mutant 343

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -794,7 +794,7 @@
         # NB: weights[-1] / sum(weights) is the probability of rejecting the new sample `z`.
         accept_prob = 1 - jnp.exp(log_weights_[-1] - logsumexp(log_weights_))
         itr = sa_state.i + 1
-        n = jnp.where(sa_state.i < wa_steps, itr, itr - wa_steps)
+        n = jnp.where(sa_state.i < wa_steps, itr, itr + wa_steps)
         mean_accept_prob = sa_state.mean_accept_prob + (accept_prob - sa_state.mean_accept_prob) / n
 
         # XXX: we make a modification of SA sampler in [1]

Mutant 345

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -795,7 +795,7 @@
         accept_prob = 1 - jnp.exp(log_weights_[-1] - logsumexp(log_weights_))
         itr = sa_state.i + 1
         n = jnp.where(sa_state.i < wa_steps, itr, itr - wa_steps)
-        mean_accept_prob = sa_state.mean_accept_prob + (accept_prob - sa_state.mean_accept_prob) / n
+        mean_accept_prob = sa_state.mean_accept_prob - (accept_prob - sa_state.mean_accept_prob) / n
 
         # XXX: we make a modification of SA sampler in [1]
         # in [1], each MCMC state contains N points `zs`

Mutant 346

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -795,7 +795,7 @@
         accept_prob = 1 - jnp.exp(log_weights_[-1] - logsumexp(log_weights_))
         itr = sa_state.i + 1
         n = jnp.where(sa_state.i < wa_steps, itr, itr - wa_steps)
-        mean_accept_prob = sa_state.mean_accept_prob + (accept_prob - sa_state.mean_accept_prob) / n
+        mean_accept_prob = sa_state.mean_accept_prob + (accept_prob + sa_state.mean_accept_prob) / n
 
         # XXX: we make a modification of SA sampler in [1]
         # in [1], each MCMC state contains N points `zs`

Mutant 347

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -795,7 +795,7 @@
         accept_prob = 1 - jnp.exp(log_weights_[-1] - logsumexp(log_weights_))
         itr = sa_state.i + 1
         n = jnp.where(sa_state.i < wa_steps, itr, itr - wa_steps)
-        mean_accept_prob = sa_state.mean_accept_prob + (accept_prob - sa_state.mean_accept_prob) / n
+        mean_accept_prob = sa_state.mean_accept_prob + (accept_prob - sa_state.mean_accept_prob) * n
 
         # XXX: we make a modification of SA sampler in [1]
         # in [1], each MCMC state contains N points `zs`

Mutant 349

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -800,7 +800,7 @@
         # XXX: we make a modification of SA sampler in [1]
         # in [1], each MCMC state contains N points `zs`
         # here we do resampling to pick randomly a point from those N points
-        k = random.categorical(rng_key_accept, jnp.zeros(zs.shape[0]))
+        k = random.categorical(rng_key_accept, jnp.zeros(zs.shape[1]))
         z = unravel_fn(zs[k])
         pe = pes[k]
         return SAState(itr, z, pe, accept_prob, mean_accept_prob, diverging, adapt_state, rng_key)

Mutant 353

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -843,7 +843,7 @@
         See :ref:`init_strategy` section for available functions.
     """
     def __init__(self, model=None, potential_fn=None, adapt_state_size=None,
-                 dense_mass=True, init_strategy=init_to_uniform):
+                 dense_mass=False, init_strategy=init_to_uniform):
         if not (model is None) ^ (potential_fn is None):
             raise ValueError('Only one of `model` or `potential_fn` must be specified.')
         self._model = model

Mutant 360

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -848,7 +848,7 @@
             raise ValueError('Only one of `model` or `potential_fn` must be specified.')
         self._model = model
         self._potential_fn = potential_fn
-        self._adapt_state_size = adapt_state_size
+        self._adapt_state_size = None
         self._dense_mass = dense_mass
         self._init_strategy = init_strategy
         self._init_fn = None

Mutant 390

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -953,7 +953,7 @@
 
 
 def _collect_fn(collect_fields):
-    @cached_by(_collect_fn, collect_fields)
+
     def collect(x):
         return attrgetter(*collect_fields)(x[0])
 

Mutant 397

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1016,7 +1016,7 @@
         self.num_warmup = num_warmup
         self.num_samples = num_samples
         self.num_chains = num_chains
-        self.postprocess_fn = postprocess_fn
+        self.postprocess_fn = None
         self.chain_method = chain_method
         self.progress_bar = progress_bar
         # TODO: We should have progress bars (maybe without diagnostics) for num_chains > 1

Mutant 398

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1017,7 +1017,7 @@
         self.num_samples = num_samples
         self.num_chains = num_chains
         self.postprocess_fn = postprocess_fn
-        self.chain_method = chain_method
+        self.chain_method = None
         self.progress_bar = progress_bar
         # TODO: We should have progress bars (maybe without diagnostics) for num_chains > 1
         if (chain_method == 'parallel' and num_chains > 1) or (

Mutant 400

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1020,7 +1020,7 @@
         self.chain_method = chain_method
         self.progress_bar = progress_bar
         # TODO: We should have progress bars (maybe without diagnostics) for num_chains > 1
-        if (chain_method == 'parallel' and num_chains > 1) or (
+        if (chain_method != 'parallel' and num_chains > 1) or (
                 "CI" in os.environ or "PYTEST_XDIST_WORKER" in os.environ):
             self.progress_bar = False
         self._jit_model_args = jit_model_args

Mutant 401

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1020,7 +1020,7 @@
         self.chain_method = chain_method
         self.progress_bar = progress_bar
         # TODO: We should have progress bars (maybe without diagnostics) for num_chains > 1
-        if (chain_method == 'parallel' and num_chains > 1) or (
+        if (chain_method == 'XXparallelXX' and num_chains > 1) or (
                 "CI" in os.environ or "PYTEST_XDIST_WORKER" in os.environ):
             self.progress_bar = False
         self._jit_model_args = jit_model_args

Mutant 403

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1020,7 +1020,7 @@
         self.chain_method = chain_method
         self.progress_bar = progress_bar
         # TODO: We should have progress bars (maybe without diagnostics) for num_chains > 1
-        if (chain_method == 'parallel' and num_chains > 1) or (
+        if (chain_method == 'parallel' and num_chains > 2) or (
                 "CI" in os.environ or "PYTEST_XDIST_WORKER" in os.environ):
             self.progress_bar = False
         self._jit_model_args = jit_model_args

Mutant 405

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1021,7 +1021,7 @@
         self.progress_bar = progress_bar
         # TODO: We should have progress bars (maybe without diagnostics) for num_chains > 1
         if (chain_method == 'parallel' and num_chains > 1) or (
-                "CI" in os.environ or "PYTEST_XDIST_WORKER" in os.environ):
+                "XXCIXX" in os.environ or "PYTEST_XDIST_WORKER" in os.environ):
             self.progress_bar = False
         self._jit_model_args = jit_model_args
         self._states = None

Mutant 407

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1021,7 +1021,7 @@
         self.progress_bar = progress_bar
         # TODO: We should have progress bars (maybe without diagnostics) for num_chains > 1
         if (chain_method == 'parallel' and num_chains > 1) or (
-                "CI" in os.environ or "PYTEST_XDIST_WORKER" in os.environ):
+                "CI" in os.environ or "XXPYTEST_XDIST_WORKERXX" in os.environ):
             self.progress_bar = False
         self._jit_model_args = jit_model_args
         self._states = None

Mutant 409

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1021,7 +1021,7 @@
         self.progress_bar = progress_bar
         # TODO: We should have progress bars (maybe without diagnostics) for num_chains > 1
         if (chain_method == 'parallel' and num_chains > 1) or (
-                "CI" in os.environ or "PYTEST_XDIST_WORKER" in os.environ):
+                "CI" in os.environ and "PYTEST_XDIST_WORKER" in os.environ):
             self.progress_bar = False
         self._jit_model_args = jit_model_args
         self._states = None

Mutant 410

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1020,7 +1020,7 @@
         self.chain_method = chain_method
         self.progress_bar = progress_bar
         # TODO: We should have progress bars (maybe without diagnostics) for num_chains > 1
-        if (chain_method == 'parallel' and num_chains > 1) or (
+        if (chain_method == 'parallel' and num_chains > 1) and (
                 "CI" in os.environ or "PYTEST_XDIST_WORKER" in os.environ):
             self.progress_bar = False
         self._jit_model_args = jit_model_args

Mutant 411

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1023,7 +1023,7 @@
         if (chain_method == 'parallel' and num_chains > 1) or (
                 "CI" in os.environ or "PYTEST_XDIST_WORKER" in os.environ):
             self.progress_bar = False
-        self._jit_model_args = jit_model_args
+        self._jit_model_args = None
         self._states = None
         self._states_flat = None
         # HMCState returned by last run

Mutant 412

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1024,7 +1024,7 @@
                 "CI" in os.environ or "PYTEST_XDIST_WORKER" in os.environ):
             self.progress_bar = False
         self._jit_model_args = jit_model_args
-        self._states = None
+        self._states = ""
         self._states_flat = None
         # HMCState returned by last run
         self._last_state = None

Mutant 413

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1025,7 +1025,7 @@
             self.progress_bar = False
         self._jit_model_args = jit_model_args
         self._states = None
-        self._states_flat = None
+        self._states_flat = ""
         # HMCState returned by last run
         self._last_state = None
         # HMCState returned by last warmup

Mutant 414

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1027,7 +1027,7 @@
         self._states = None
         self._states_flat = None
         # HMCState returned by last run
-        self._last_state = None
+        self._last_state = ""
         # HMCState returned by last warmup
         self._warmup_state = None
         # HMCState returned by hmc.init_kernel

Mutant 422

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1041,7 +1041,7 @@
             args, kwargs = (None,), (None,)
         else:
             args = tree_map(lambda x: _hashable(x), self._args)
-            kwargs = tree_map(lambda x: _hashable(x), tuple(sorted(self._kwargs.items())))
+            kwargs = tree_map(lambda x: None, tuple(sorted(self._kwargs.items())))
         key = args + kwargs
         try:
             fn = self._cache.get(key, None)

Mutant 425

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1042,7 +1042,7 @@
         else:
             args = tree_map(lambda x: _hashable(x), self._args)
             kwargs = tree_map(lambda x: _hashable(x), tuple(sorted(self._kwargs.items())))
-        key = args + kwargs
+        key = None
         try:
             fn = self._cache.get(key, None)
         # If unhashable arguments are provided, proceed normally

Mutant 426

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1044,7 +1044,7 @@
             kwargs = tree_map(lambda x: _hashable(x), tuple(sorted(self._kwargs.items())))
         key = args + kwargs
         try:
-            fn = self._cache.get(key, None)
+            fn = None
         # If unhashable arguments are provided, proceed normally
         # without caching
         except TypeError:

Mutant 430

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1055,7 +1055,7 @@
             else:
                 fn = partial(_sample_fn_nojit_args, sampler=self.sampler,
                              args=self._args, kwargs=self._kwargs)
-            if key is not None:
+            if key is  None:
                 self._cache[key] = fn
         return fn
 

Mutant 431

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1056,7 +1056,7 @@
                 fn = partial(_sample_fn_nojit_args, sampler=self.sampler,
                              args=self._args, kwargs=self._kwargs)
             if key is not None:
-                self._cache[key] = fn
+                self._cache[key] = None
         return fn
 
     def _get_cached_init_state(self, rng_key, args, kwargs):

Mutant 433

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1061,7 +1061,7 @@
 
     def _get_cached_init_state(self, rng_key, args, kwargs):
         rng_key = (_hashable(rng_key),)
-        args = tree_map(lambda x: _hashable(x), args)
+        args = tree_map(lambda x: None, args)
         kwargs = tree_map(lambda x: _hashable(x), tuple(sorted(kwargs.items())))
         key = rng_key + args + kwargs
         try:

Mutant 435

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1062,7 +1062,7 @@
     def _get_cached_init_state(self, rng_key, args, kwargs):
         rng_key = (_hashable(rng_key),)
         args = tree_map(lambda x: _hashable(x), args)
-        kwargs = tree_map(lambda x: _hashable(x), tuple(sorted(kwargs.items())))
+        kwargs = tree_map(lambda x: None, tuple(sorted(kwargs.items())))
         key = rng_key + args + kwargs
         try:
             return self._init_state_cache.get(key, None)

Mutant 439

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1063,7 +1063,7 @@
         rng_key = (_hashable(rng_key),)
         args = tree_map(lambda x: _hashable(x), args)
         kwargs = tree_map(lambda x: _hashable(x), tuple(sorted(kwargs.items())))
-        key = rng_key + args + kwargs
+        key = None
         try:
             return self._init_state_cache.get(key, None)
         # If unhashable arguments are provided, return None

Mutant 440

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1070,7 +1070,7 @@
         except TypeError:
             return None
 
-    def _single_chain_mcmc(self, rng_key, init_state, init_params, args, kwargs, collect_fields=('z',)):
+    def _single_chain_mcmc(self, rng_key, init_state, init_params, args, kwargs, collect_fields=('XXzXX',)):
         if init_state is None:
             init_state = self.sampler.init(rng_key, self.num_warmup, init_params,
                                            model_args=args, model_kwargs=kwargs)

Mutant 446

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1078,7 +1078,7 @@
             postprocess_fn = self.sampler.postprocess_fn(args, kwargs)
         else:
             postprocess_fn = self.postprocess_fn
-        diagnostics = lambda x: get_diagnostics_str(x[0]) if rng_key.ndim == 1 else None   # noqa: E731
+        diagnostics = lambda x: get_diagnostics_str(x[0]) if rng_key.ndim != 1 else None   # noqa: E731
         init_val = (init_state, args, kwargs) if self._jit_model_args else (init_state,)
         lower_idx = self._collection_params["lower"]
         upper_idx = self._collection_params["upper"]

Mutant 447

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1078,7 +1078,7 @@
             postprocess_fn = self.sampler.postprocess_fn(args, kwargs)
         else:
             postprocess_fn = self.postprocess_fn
-        diagnostics = lambda x: get_diagnostics_str(x[0]) if rng_key.ndim == 1 else None   # noqa: E731
+        diagnostics = lambda x: get_diagnostics_str(x[0]) if rng_key.ndim == 2 else None   # noqa: E731
         init_val = (init_state, args, kwargs) if self._jit_model_args else (init_state,)
         lower_idx = self._collection_params["lower"]
         upper_idx = self._collection_params["upper"]

Mutant 448

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1078,7 +1078,7 @@
             postprocess_fn = self.sampler.postprocess_fn(args, kwargs)
         else:
             postprocess_fn = self.postprocess_fn
-        diagnostics = lambda x: get_diagnostics_str(x[0]) if rng_key.ndim == 1 else None   # noqa: E731
+        diagnostics = lambda x: None   # noqa: E731
         init_val = (init_state, args, kwargs) if self._jit_model_args else (init_state,)
         lower_idx = self._collection_params["lower"]
         upper_idx = self._collection_params["upper"]

Mutant 449

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1078,7 +1078,7 @@
             postprocess_fn = self.sampler.postprocess_fn(args, kwargs)
         else:
             postprocess_fn = self.postprocess_fn
-        diagnostics = lambda x: get_diagnostics_str(x[0]) if rng_key.ndim == 1 else None   # noqa: E731
+        diagnostics = None   # noqa: E731
         init_val = (init_state, args, kwargs) if self._jit_model_args else (init_state,)
         lower_idx = self._collection_params["lower"]
         upper_idx = self._collection_params["upper"]

Mutant 467

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1102,7 +1102,7 @@
         states = dict(zip(collect_fields, states))
         # Apply constraints if number of samples is non-zero
         site_values = tree_flatten(states['z'])[0]
-        if len(site_values) > 0 and site_values[0].size > 0:
+        if len(site_values) >= 0 and site_values[0].size > 0:
             states['z'] = lax.map(postprocess_fn, states['z'])
         return states, last_state
 

Mutant 470

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1102,7 +1102,7 @@
         states = dict(zip(collect_fields, states))
         # Apply constraints if number of samples is non-zero
         site_values = tree_flatten(states['z'])[0]
-        if len(site_values) > 0 and site_values[0].size > 0:
+        if len(site_values) > 0 and site_values[0].size >= 0:
             states['z'] = lax.map(postprocess_fn, states['z'])
         return states, last_state
 

Mutant 471

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1102,7 +1102,7 @@
         states = dict(zip(collect_fields, states))
         # Apply constraints if number of samples is non-zero
         site_values = tree_flatten(states['z'])[0]
-        if len(site_values) > 0 and site_values[0].size > 0:
+        if len(site_values) > 0 and site_values[0].size > 1:
             states['z'] = lax.map(postprocess_fn, states['z'])
         return states, last_state
 

Mutant 472

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1102,7 +1102,7 @@
         states = dict(zip(collect_fields, states))
         # Apply constraints if number of samples is non-zero
         site_values = tree_flatten(states['z'])[0]
-        if len(site_values) > 0 and site_values[0].size > 0:
+        if len(site_values) > 0 or site_values[0].size > 0:
             states['z'] = lax.map(postprocess_fn, states['z'])
         return states, last_state
 

Mutant 476

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1106,7 +1106,7 @@
             states['z'] = lax.map(postprocess_fn, states['z'])
         return states, last_state
 
-    def _single_chain_jit_args(self, init, collect_fields=('z',)):
+    def _single_chain_jit_args(self, init, collect_fields=('XXzXX',)):
         return self._single_chain_mcmc(*init, collect_fields=collect_fields)
 
     def _single_chain_nojit_args(self, init, model_args, model_kwargs, collect_fields=('z',)):

Mutant 477

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1109,7 +1109,7 @@
     def _single_chain_jit_args(self, init, collect_fields=('z',)):
         return self._single_chain_mcmc(*init, collect_fields=collect_fields)
 
-    def _single_chain_nojit_args(self, init, model_args, model_kwargs, collect_fields=('z',)):
+    def _single_chain_nojit_args(self, init, model_args, model_kwargs, collect_fields=('XXzXX',)):
         return self._single_chain_mcmc(*init, model_args, model_kwargs, collect_fields=collect_fields)
 
     def _set_collection_params(self, lower=None, upper=None, collection_size=None):

Mutant 486

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1115,7 +1115,7 @@
     def _set_collection_params(self, lower=None, upper=None, collection_size=None):
         self._collection_params["lower"] = self.num_warmup if lower is None else lower
         self._collection_params["upper"] = self.num_warmup + self.num_samples if upper is None else upper
-        self._collection_params["collection_size"] = collection_size
+        self._collection_params["collection_size"] = None
 
     def _compile(self, rng_key, *args, extra_fields=(), init_params=None, **kwargs):
         self._set_collection_params(0, 0, self.num_samples)

Mutant 487

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1130,7 +1130,7 @@
         except TypeError:
             pass
 
-    def warmup(self, rng_key, *args, extra_fields=(), collect_warmup=False, init_params=None, **kwargs):
+    def warmup(self, rng_key, *args, extra_fields=(), collect_warmup=True, init_params=None, **kwargs):
         """
         Run the MCMC warmup adaptation phase. After this call, the :meth:`run` method
         will skip the warmup adaptation phase. To run `warmup` again for the new data,

Mutant 490

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1181,7 +1181,7 @@
         """
         self._args = args
         self._kwargs = kwargs
-        init_state = self._get_cached_init_state(rng_key, args, kwargs)
+        init_state = None
         if self.num_chains > 1 and rng_key.ndim == 1:
             rng_key = random.split(rng_key, self.num_chains)
 

Mutant 492

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1182,7 +1182,7 @@
         self._args = args
         self._kwargs = kwargs
         init_state = self._get_cached_init_state(rng_key, args, kwargs)
-        if self.num_chains > 1 and rng_key.ndim == 1:
+        if self.num_chains > 2 and rng_key.ndim == 1:
             rng_key = random.split(rng_key, self.num_chains)
 
         if self._warmup_state is not None:

Mutant 493

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1182,7 +1182,7 @@
         self._args = args
         self._kwargs = kwargs
         init_state = self._get_cached_init_state(rng_key, args, kwargs)
-        if self.num_chains > 1 and rng_key.ndim == 1:
+        if self.num_chains > 1 and rng_key.ndim != 1:
             rng_key = random.split(rng_key, self.num_chains)
 
         if self._warmup_state is not None:

Mutant 494

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1182,7 +1182,7 @@
         self._args = args
         self._kwargs = kwargs
         init_state = self._get_cached_init_state(rng_key, args, kwargs)
-        if self.num_chains > 1 and rng_key.ndim == 1:
+        if self.num_chains > 1 and rng_key.ndim == 2:
             rng_key = random.split(rng_key, self.num_chains)
 
         if self._warmup_state is not None:

Mutant 497

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1189,7 +1189,7 @@
             self._set_collection_params(0, self.num_samples, self.num_samples)
             init_state = self._warmup_state._replace(rng_key=rng_key)
 
-        chain_method = self.chain_method
+        chain_method = None
         if chain_method == 'parallel' and xla_bridge.device_count() < self.num_chains:
             chain_method = 'sequential'
             warnings.warn('There are not enough devices to run parallel chains: expected {} but got {}.'

Mutant 498

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1190,7 +1190,7 @@
             init_state = self._warmup_state._replace(rng_key=rng_key)
 
         chain_method = self.chain_method
-        if chain_method == 'parallel' and xla_bridge.device_count() < self.num_chains:
+        if chain_method != 'parallel' and xla_bridge.device_count() < self.num_chains:
             chain_method = 'sequential'
             warnings.warn('There are not enough devices to run parallel chains: expected {} but got {}.'
                           ' Chains will be drawn sequentially. If you are running MCMC in CPU,'

Mutant 499

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1190,7 +1190,7 @@
             init_state = self._warmup_state._replace(rng_key=rng_key)
 
         chain_method = self.chain_method
-        if chain_method == 'parallel' and xla_bridge.device_count() < self.num_chains:
+        if chain_method == 'XXparallelXX' and xla_bridge.device_count() < self.num_chains:
             chain_method = 'sequential'
             warnings.warn('There are not enough devices to run parallel chains: expected {} but got {}.'
                           ' Chains will be drawn sequentially. If you are running MCMC in CPU,'

Mutant 502

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1198,7 +1198,7 @@
                           ' of your program.'
                           .format(self.num_chains, xla_bridge.device_count(), self.num_chains))
 
-        if init_params is not None and self.num_chains > 1:
+        if init_params is  None and self.num_chains > 1:
             prototype_init_val = tree_flatten(init_params)[0][0]
             if jnp.shape(prototype_init_val)[0] != self.num_chains:
                 raise ValueError('`init_params` must have the same leading dimension'

Mutant 503

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1198,7 +1198,7 @@
                           ' of your program.'
                           .format(self.num_chains, xla_bridge.device_count(), self.num_chains))
 
-        if init_params is not None and self.num_chains > 1:
+        if init_params is not None and self.num_chains >= 1:
             prototype_init_val = tree_flatten(init_params)[0][0]
             if jnp.shape(prototype_init_val)[0] != self.num_chains:
                 raise ValueError('`init_params` must have the same leading dimension'

Mutant 504

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1198,7 +1198,7 @@
                           ' of your program.'
                           .format(self.num_chains, xla_bridge.device_count(), self.num_chains))
 
-        if init_params is not None and self.num_chains > 1:
+        if init_params is not None and self.num_chains > 2:
             prototype_init_val = tree_flatten(init_params)[0][0]
             if jnp.shape(prototype_init_val)[0] != self.num_chains:
                 raise ValueError('`init_params` must have the same leading dimension'

Mutant 505

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1198,7 +1198,7 @@
                           ' of your program.'
                           .format(self.num_chains, xla_bridge.device_count(), self.num_chains))
 
-        if init_params is not None and self.num_chains > 1:
+        if init_params is not None or self.num_chains > 1:
             prototype_init_val = tree_flatten(init_params)[0][0]
             if jnp.shape(prototype_init_val)[0] != self.num_chains:
                 raise ValueError('`init_params` must have the same leading dimension'

Mutant 519

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1254,7 +1254,7 @@
             but can be any :func:`jaxlib.pytree`, more generally (e.g. when defining a
             `potential_fn` for HMC that takes `list` args).
         """
-        return self._states['z'] if group_by_chain else self._states_flat['z']
+        return self._states['XXzXX'] if group_by_chain else self._states_flat['z']
 
     def get_extra_fields(self, group_by_chain=False):
         """

Mutant 521

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1256,7 +1256,7 @@
         """
         return self._states['z'] if group_by_chain else self._states_flat['z']
 
-    def get_extra_fields(self, group_by_chain=False):
+    def get_extra_fields(self, group_by_chain=True):
         """
         Get extra fields from the MCMC run.
 

Mutant 523

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1266,7 +1266,7 @@
             `extra_fields` keyword of :meth:`run`.
         """
         states = self._states if group_by_chain else self._states_flat
-        return {k: v for k, v in states.items() if k != 'z'}
+        return {k: v for k, v in states.items() if k == 'z'}
 
     def print_summary(self, prob=0.9, exclude_deterministic=True):
         # Exclude deterministic sites by default

Mutant 524

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1266,7 +1266,7 @@
             `extra_fields` keyword of :meth:`run`.
         """
         states = self._states if group_by_chain else self._states_flat
-        return {k: v for k, v in states.items() if k != 'z'}
+        return {k: v for k, v in states.items() if k != 'XXzXX'}
 
     def print_summary(self, prob=0.9, exclude_deterministic=True):
         # Exclude deterministic sites by default

Mutant 529

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1271,7 +1271,7 @@
     def print_summary(self, prob=0.9, exclude_deterministic=True):
         # Exclude deterministic sites by default
         sites = self._states['z']
-        if isinstance(sites, dict) and exclude_deterministic:
+        if isinstance(sites, dict) or exclude_deterministic:
             sites = {k: v for k, v in self._states['z'].items() if k in self._last_state.z}
         print_summary(sites, prob=prob)
         extra_fields = self.get_extra_fields()

Mutant 534

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1275,6 +1275,6 @@
             sites = {k: v for k, v in self._states['z'].items() if k in self._last_state.z}
         print_summary(sites, prob=prob)
         extra_fields = self.get_extra_fields()
-        if 'diverging' in extra_fields:
+        if 'XXdivergingXX' in extra_fields:
             print("Number of divergences: {}".format(jnp.sum(extra_fields['diverging'])))
 

Mutant 535

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1275,6 +1275,6 @@
             sites = {k: v for k, v in self._states['z'].items() if k in self._last_state.z}
         print_summary(sites, prob=prob)
         extra_fields = self.get_extra_fields()
-        if 'diverging' in extra_fields:
+        if 'diverging' not in extra_fields:
             print("Number of divergences: {}".format(jnp.sum(extra_fields['diverging'])))
 

Mutant 536

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -1276,5 +1276,5 @@
         print_summary(sites, prob=prob)
         extra_fields = self.get_extra_fields()
         if 'diverging' in extra_fields:
-            print("Number of divergences: {}".format(jnp.sum(extra_fields['diverging'])))
-
+            print("XXNumber of divergences: {}XX".format(jnp.sum(extra_fields['diverging'])))
+

Suspicious

Mutants that made the test suite run noticeably longer but did not cause any test to fail (the suite otherwise appeared OK).

Mutant 44

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -187,7 +187,7 @@
                     adapt_step_size=True,
                     adapt_mass_matrix=True,
                     dense_mass=False,
-                    target_accept_prob=0.8,
+                    target_accept_prob=1.8,
                     trajectory_length=2*math.pi,
                     max_tree_depth=10,
                     find_heuristic_step_size=False,

Mutant 85

--- numpyro/infer/mcmc.py
+++ numpyro/infer/mcmc.py
@@ -283,7 +283,7 @@
                                  lambda i, val: vv_update(step_size, inverse_mass_matrix, val),
                                  vv_state)
         energy_old = vv_state.potential_energy + kinetic_fn(inverse_mass_matrix, vv_state.r)
-        energy_new = vv_state_new.potential_energy + kinetic_fn(inverse_mass_matrix, vv_state_new.r)
+        energy_new = vv_state_new.potential_energy - kinetic_fn(inverse_mass_matrix, vv_state_new.r)
         delta_energy = energy_new - energy_old
         delta_energy = jnp.where(jnp.isnan(delta_energy), jnp.inf, delta_energy)
         accept_prob = jnp.clip(jnp.exp(-delta_energy), a_max=1.0)