pyro/infer/util.py
Killed 25 out of 41 mutants; 16 survived.
The following mutants survived mutation testing. These mutants show holes in your test suite.
Mutant 178
--- pyro/infer/util.py
+++ pyro/infer/util.py
@@ -16,7 +16,7 @@
from pyro.ops.rings import MarginalRing
from pyro.poutine.util import site_is_subsample
-_VALIDATION_ENABLED = False
+_VALIDATION_ENABLED = True
LAST_CACHE_SIZE = [Counter()] # for profiling
Mutant 179
--- pyro/infer/util.py
+++ pyro/infer/util.py
@@ -16,7 +16,7 @@
from pyro.ops.rings import MarginalRing
from pyro.poutine.util import site_is_subsample
-_VALIDATION_ENABLED = False
+_VALIDATION_ENABLED = None
LAST_CACHE_SIZE = [Counter()] # for profiling
Mutant 180
--- pyro/infer/util.py
+++ pyro/infer/util.py
@@ -17,7 +17,7 @@
from pyro.poutine.util import site_is_subsample
_VALIDATION_ENABLED = False
-LAST_CACHE_SIZE = [Counter()] # for profiling
+LAST_CACHE_SIZE = None # for profiling
def enable_validation(is_validate):
Mutant 181
--- pyro/infer/util.py
+++ pyro/infer/util.py
@@ -22,7 +22,7 @@
def enable_validation(is_validate):
global _VALIDATION_ENABLED
- _VALIDATION_ENABLED = is_validate
+ _VALIDATION_ENABLED = None
def is_validation_enabled():
Mutant 182
--- pyro/infer/util.py
+++ pyro/infer/util.py
@@ -28,8 +28,6 @@
def is_validation_enabled():
return _VALIDATION_ENABLED
-
-@contextmanager
def validation_enabled(is_validate=True):
old = is_validation_enabled()
try:
Mutant 183
--- pyro/infer/util.py
+++ pyro/infer/util.py
@@ -30,7 +30,7 @@
@contextmanager
-def validation_enabled(is_validate=True):
+def validation_enabled(is_validate=False):
old = is_validation_enabled()
try:
enable_validation(is_validate)
Mutant 184
--- pyro/infer/util.py
+++ pyro/infer/util.py
@@ -51,7 +51,7 @@
Like ``x.backward()`` for a :class:`~torch.Tensor`, but also accepts
numbers and tensors without grad_fn (resulting in a no-op)
"""
- if torch.is_tensor(x) and x.grad_fn:
+ if torch.is_tensor(x) or x.grad_fn:
x.backward(retain_graph=retain_graph)
Mutant 186
--- pyro/infer/util.py
+++ pyro/infer/util.py
@@ -79,7 +79,7 @@
"""
for p in tensors:
if p.grad is not None:
- p.grad = torch.zeros_like(p.grad)
+ p.grad = None
def get_plate_stacks(trace):
Mutant 194
--- pyro/infer/util.py
+++ pyro/infer/util.py
@@ -121,7 +121,7 @@
"""
for cond_indep_stack, value in items:
frames = frozenset(f for f in cond_indep_stack if f.vectorized)
- assert all(f.dim < 0 and -value.dim() <= f.dim for f in frames)
+ assert all(f.dim <= 0 and -value.dim() <= f.dim for f in frames)
if frames in self:
self[frames] = self[frames] + value
else:
Mutant 195
--- pyro/infer/util.py
+++ pyro/infer/util.py
@@ -121,7 +121,7 @@
"""
for cond_indep_stack, value in items:
frames = frozenset(f for f in cond_indep_stack if f.vectorized)
- assert all(f.dim < 0 and -value.dim() <= f.dim for f in frames)
+ assert all(f.dim < 1 and -value.dim() <= f.dim for f in frames)
if frames in self:
self[frames] = self[frames] + value
else:
Mutant 198
--- pyro/infer/util.py
+++ pyro/infer/util.py
@@ -121,7 +121,7 @@
"""
for cond_indep_stack, value in items:
frames = frozenset(f for f in cond_indep_stack if f.vectorized)
- assert all(f.dim < 0 and -value.dim() <= f.dim for f in frames)
+ assert all(f.dim < 0 or -value.dim() <= f.dim for f in frames)
if frames in self:
self[frames] = self[frames] + value
else:
Mutant 206
--- pyro/infer/util.py
+++ pyro/infer/util.py
@@ -131,7 +131,7 @@
total = None
for frames, value in self.items():
for f in frames:
- if f not in target_frames and value.shape[f.dim] != 1:
+ if f not in target_frames and value.shape[f.dim] != 2:
value = value.sum(f.dim, True)
while value.shape and value.shape[0] == 1:
value = value.squeeze(0)
Mutant 207
--- pyro/infer/util.py
+++ pyro/infer/util.py
@@ -131,7 +131,7 @@
total = None
for frames, value in self.items():
for f in frames:
- if f not in target_frames and value.shape[f.dim] != 1:
+ if f not in target_frames or value.shape[f.dim] != 1:
value = value.sum(f.dim, True)
while value.shape and value.shape[0] == 1:
value = value.squeeze(0)
Mutant 208
--- pyro/infer/util.py
+++ pyro/infer/util.py
@@ -132,7 +132,7 @@
for frames, value in self.items():
for f in frames:
if f not in target_frames and value.shape[f.dim] != 1:
- value = value.sum(f.dim, True)
+ value = value.sum(f.dim, False)
while value.shape and value.shape[0] == 1:
value = value.squeeze(0)
total = value if total is None else total + value
Mutant 211
--- pyro/infer/util.py
+++ pyro/infer/util.py
@@ -133,7 +133,7 @@
for f in frames:
if f not in target_frames and value.shape[f.dim] != 1:
value = value.sum(f.dim, True)
- while value.shape and value.shape[0] == 1:
+ while value.shape and value.shape[0] != 1:
value = value.squeeze(0)
total = value if total is None else total + value
return total
Mutant 212
--- pyro/infer/util.py
+++ pyro/infer/util.py
@@ -133,7 +133,7 @@
for f in frames:
if f not in target_frames and value.shape[f.dim] != 1:
value = value.sum(f.dim, True)
- while value.shape and value.shape[0] == 1:
+ while value.shape and value.shape[0] == 2:
value = value.squeeze(0)
total = value if total is None else total + value
return total