pyro/infer/mcmc/util.py
Killed 0 out of 15 mutants.
Survived
Survived mutation testing. These mutants show holes in your test suite: no test failed when they were applied.
Mutant 429
--- pyro/infer/mcmc/util.py
+++ pyro/infer/mcmc/util.py
@@ -106,7 +106,7 @@
for ordinal, log_prob in self._log_probs.items():
self._log_prob_shapes[ordinal] = broadcast_shape(*(t.shape for t in self._log_probs[ordinal]))
- def _reduce(self, ordinal, agg_log_prob=torch.tensor(0.)):
+ def _reduce(self, ordinal, agg_log_prob=torch.tensor(1.0)):
"""
Reduce the log prob terms for the given ordinal:
- taking log_sum_exp of factors in enum dims (i.e.
Mutant 430
--- pyro/infer/mcmc/util.py
+++ pyro/infer/mcmc/util.py
@@ -283,7 +283,7 @@
self._compiled_fn = torch.jit.trace(_pe_jit, vals, **jit_options)
return self._compiled_fn(*vals)
- def get_potential_fn(self, jit_compile=False, skip_jit_warnings=True, jit_options=None):
+ def get_potential_fn(self, jit_compile=True, skip_jit_warnings=True, jit_options=None):
if jit_compile:
jit_options = {"check_trace": False} if jit_options is None else jit_options
return partial(self._potential_fn_jit, skip_jit_warnings, jit_options)
Mutant 431
--- pyro/infer/mcmc/util.py
+++ pyro/infer/mcmc/util.py
@@ -283,7 +283,7 @@
self._compiled_fn = torch.jit.trace(_pe_jit, vals, **jit_options)
return self._compiled_fn(*vals)
- def get_potential_fn(self, jit_compile=False, skip_jit_warnings=True, jit_options=None):
+ def get_potential_fn(self, jit_compile=False, skip_jit_warnings=False, jit_options=None):
if jit_compile:
jit_options = {"check_trace": False} if jit_options is None else jit_options
return partial(self._potential_fn_jit, skip_jit_warnings, jit_options)
Mutant 432
--- pyro/infer/mcmc/util.py
+++ pyro/infer/mcmc/util.py
@@ -292,7 +292,7 @@
# TODO: expose init_strategy using separate functions.
def _get_init_params(model, model_args, model_kwargs, transforms, potential_fn, prototype_params,
- max_tries_initial_params=100, num_chains=1, strategy="uniform"):
+ max_tries_initial_params=101, num_chains=1, strategy="uniform"):
params = prototype_params
params_per_chain = defaultdict(list)
n = 0
Mutant 433
--- pyro/infer/mcmc/util.py
+++ pyro/infer/mcmc/util.py
@@ -292,7 +292,7 @@
# TODO: expose init_strategy using separate functions.
def _get_init_params(model, model_args, model_kwargs, transforms, potential_fn, prototype_params,
- max_tries_initial_params=100, num_chains=1, strategy="uniform"):
+ max_tries_initial_params=100, num_chains=2, strategy="uniform"):
params = prototype_params
params_per_chain = defaultdict(list)
n = 0
Mutant 434
--- pyro/infer/mcmc/util.py
+++ pyro/infer/mcmc/util.py
@@ -292,7 +292,7 @@
# TODO: expose init_strategy using separate functions.
def _get_init_params(model, model_args, model_kwargs, transforms, potential_fn, prototype_params,
- max_tries_initial_params=100, num_chains=1, strategy="uniform"):
+ max_tries_initial_params=100, num_chains=1, strategy="XXuniformXX"):
params = prototype_params
params_per_chain = defaultdict(list)
n = 0
Mutant 435
--- pyro/infer/mcmc/util.py
+++ pyro/infer/mcmc/util.py
@@ -324,7 +324,7 @@
def initialize_model(model, model_args=(), model_kwargs={}, transforms=None, max_plate_nesting=None,
- jit_compile=False, jit_options=None, skip_jit_warnings=False, num_chains=1):
+ jit_compile=True, jit_options=None, skip_jit_warnings=False, num_chains=1):
"""
Given a Python callable with Pyro primitives, generates the following model-specific
properties needed for inference using HMC/NUTS kernels:
Mutant 436
--- pyro/infer/mcmc/util.py
+++ pyro/infer/mcmc/util.py
@@ -324,7 +324,7 @@
def initialize_model(model, model_args=(), model_kwargs={}, transforms=None, max_plate_nesting=None,
- jit_compile=False, jit_options=None, skip_jit_warnings=False, num_chains=1):
+ jit_compile=False, jit_options=None, skip_jit_warnings=True, num_chains=1):
"""
Given a Python callable with Pyro primitives, generates the following model-specific
properties needed for inference using HMC/NUTS kernels:
Mutant 437
--- pyro/infer/mcmc/util.py
+++ pyro/infer/mcmc/util.py
@@ -324,7 +324,7 @@
def initialize_model(model, model_args=(), model_kwargs={}, transforms=None, max_plate_nesting=None,
- jit_compile=False, jit_options=None, skip_jit_warnings=False, num_chains=1):
+ jit_compile=False, jit_options=None, skip_jit_warnings=False, num_chains=2):
"""
Given a Python callable with Pyro primitives, generates the following model-specific
properties needed for inference using HMC/NUTS kernels:
Mutant 438
--- pyro/infer/mcmc/util.py
+++ pyro/infer/mcmc/util.py
@@ -421,7 +421,7 @@
return wrapped
-def diagnostics(samples, group_by_chain=True):
+def diagnostics(samples, group_by_chain=False):
"""
Gets diagnostics statistics such as effective sample size and
split Gelman-Rubin using the samples drawn from the posterior
Mutant 439
--- pyro/infer/mcmc/util.py
+++ pyro/infer/mcmc/util.py
@@ -445,7 +445,7 @@
return diagnostics
-def summary(samples, prob=0.9, group_by_chain=True):
+def summary(samples, prob=1.9, group_by_chain=True):
"""
Returns a summary table displaying diagnostics of ``samples`` from the
posterior. The diagnostics displayed are mean, standard deviation, median,
Mutant 440
--- pyro/infer/mcmc/util.py
+++ pyro/infer/mcmc/util.py
@@ -445,7 +445,7 @@
return diagnostics
-def summary(samples, prob=0.9, group_by_chain=True):
+def summary(samples, prob=0.9, group_by_chain=False):
"""
Returns a summary table displaying diagnostics of ``samples`` from the
posterior. The diagnostics displayed are mean, standard deviation, median,
Mutant 441
--- pyro/infer/mcmc/util.py
+++ pyro/infer/mcmc/util.py
@@ -479,7 +479,7 @@
return summary_dict
-def print_summary(samples, prob=0.9, group_by_chain=True):
+def print_summary(samples, prob=1.9, group_by_chain=True):
"""
Prints a summary table displaying diagnostics of ``samples`` from the
posterior. The diagnostics displayed are mean, standard deviation, median,
Mutant 442
--- pyro/infer/mcmc/util.py
+++ pyro/infer/mcmc/util.py
@@ -479,7 +479,7 @@
return summary_dict
-def print_summary(samples, prob=0.9, group_by_chain=True):
+def print_summary(samples, prob=0.9, group_by_chain=False):
"""
Prints a summary table displaying diagnostics of ``samples`` from the
posterior. The diagnostics displayed are mean, standard deviation, median,
Mutant 443
--- pyro/infer/mcmc/util.py
+++ pyro/infer/mcmc/util.py
@@ -518,7 +518,7 @@
def _predictive_sequential(model, posterior_samples, model_args, model_kwargs,
- num_samples, sample_sites, return_trace=False):
+ num_samples, sample_sites, return_trace=True):
collected = []
samples = [{k: v[i] for k, v in posterior_samples.items()} for i in range(num_samples)]
for i in range(num_samples):