pyro/infer/util.py

Killed 25 out of 41 mutants

Survived

The mutants below survived mutation testing: the test suite still passed with each change applied, so each one shows a hole in your test suite.

Mutant 255

--- pyro/infer/util.py
+++ pyro/infer/util.py
@@ -16,7 +16,7 @@
 from pyro.ops.rings import MarginalRing
 from pyro.poutine.util import site_is_subsample
 
-_VALIDATION_ENABLED = False
+_VALIDATION_ENABLED = True
 LAST_CACHE_SIZE = [Counter()]  # for profiling
 
 

Mutant 256

--- pyro/infer/util.py
+++ pyro/infer/util.py
@@ -16,7 +16,7 @@
 from pyro.ops.rings import MarginalRing
 from pyro.poutine.util import site_is_subsample
 
-_VALIDATION_ENABLED = False
+_VALIDATION_ENABLED = None
 LAST_CACHE_SIZE = [Counter()]  # for profiling
 
 

Mutant 257

--- pyro/infer/util.py
+++ pyro/infer/util.py
@@ -17,7 +17,7 @@
 from pyro.poutine.util import site_is_subsample
 
 _VALIDATION_ENABLED = False
-LAST_CACHE_SIZE = [Counter()]  # for profiling
+LAST_CACHE_SIZE = None  # for profiling
 
 
 def enable_validation(is_validate):

Mutant 258

--- pyro/infer/util.py
+++ pyro/infer/util.py
@@ -22,7 +22,7 @@
 
 def enable_validation(is_validate):
     global _VALIDATION_ENABLED
-    _VALIDATION_ENABLED = is_validate
+    _VALIDATION_ENABLED = None
 
 
 def is_validation_enabled():

Mutant 259

--- pyro/infer/util.py
+++ pyro/infer/util.py
@@ -28,8 +28,6 @@
 def is_validation_enabled():
     return _VALIDATION_ENABLED
 
-
-@contextmanager
 def validation_enabled(is_validate=True):
     old = is_validation_enabled()
     try:

Mutant 260

--- pyro/infer/util.py
+++ pyro/infer/util.py
@@ -30,7 +30,7 @@
 
 
 @contextmanager
-def validation_enabled(is_validate=True):
+def validation_enabled(is_validate=False):
     old = is_validation_enabled()
     try:
         enable_validation(is_validate)

Mutant 261

--- pyro/infer/util.py
+++ pyro/infer/util.py
@@ -51,7 +51,7 @@
     Like ``x.backward()`` for a :class:`~torch.Tensor`, but also accepts
     numbers and tensors without grad_fn (resulting in a no-op)
     """
-    if torch.is_tensor(x) and x.grad_fn:
+    if torch.is_tensor(x) or x.grad_fn:
         x.backward(retain_graph=retain_graph)
 
 

Mutant 263

--- pyro/infer/util.py
+++ pyro/infer/util.py
@@ -79,7 +79,7 @@
     """
     for p in tensors:
         if p.grad is not None:
-            p.grad = torch.zeros_like(p.grad)
+            p.grad = None
 
 
 def get_plate_stacks(trace):

Mutant 271

--- pyro/infer/util.py
+++ pyro/infer/util.py
@@ -121,7 +121,7 @@
         """
         for cond_indep_stack, value in items:
             frames = frozenset(f for f in cond_indep_stack if f.vectorized)
-            assert all(f.dim < 0 and -value.dim() <= f.dim for f in frames)
+            assert all(f.dim <= 0 and -value.dim() <= f.dim for f in frames)
             if frames in self:
                 self[frames] = self[frames] + value
             else:

Mutant 272

--- pyro/infer/util.py
+++ pyro/infer/util.py
@@ -121,7 +121,7 @@
         """
         for cond_indep_stack, value in items:
             frames = frozenset(f for f in cond_indep_stack if f.vectorized)
-            assert all(f.dim < 0 and -value.dim() <= f.dim for f in frames)
+            assert all(f.dim < 1 and -value.dim() <= f.dim for f in frames)
             if frames in self:
                 self[frames] = self[frames] + value
             else:

Mutant 275

--- pyro/infer/util.py
+++ pyro/infer/util.py
@@ -121,7 +121,7 @@
         """
         for cond_indep_stack, value in items:
             frames = frozenset(f for f in cond_indep_stack if f.vectorized)
-            assert all(f.dim < 0 and -value.dim() <= f.dim for f in frames)
+            assert all(f.dim < 0 or -value.dim() <= f.dim for f in frames)
             if frames in self:
                 self[frames] = self[frames] + value
             else:

Mutant 283

--- pyro/infer/util.py
+++ pyro/infer/util.py
@@ -131,7 +131,7 @@
         total = None
         for frames, value in self.items():
             for f in frames:
-                if f not in target_frames and value.shape[f.dim] != 1:
+                if f not in target_frames and value.shape[f.dim] != 2:
                     value = value.sum(f.dim, True)
             while value.shape and value.shape[0] == 1:
                 value = value.squeeze(0)

Mutant 284

--- pyro/infer/util.py
+++ pyro/infer/util.py
@@ -131,7 +131,7 @@
         total = None
         for frames, value in self.items():
             for f in frames:
-                if f not in target_frames and value.shape[f.dim] != 1:
+                if f not in target_frames or value.shape[f.dim] != 1:
                     value = value.sum(f.dim, True)
             while value.shape and value.shape[0] == 1:
                 value = value.squeeze(0)

Mutant 285

--- pyro/infer/util.py
+++ pyro/infer/util.py
@@ -132,7 +132,7 @@
         for frames, value in self.items():
             for f in frames:
                 if f not in target_frames and value.shape[f.dim] != 1:
-                    value = value.sum(f.dim, True)
+                    value = value.sum(f.dim, False)
             while value.shape and value.shape[0] == 1:
                 value = value.squeeze(0)
             total = value if total is None else total + value

Mutant 288

--- pyro/infer/util.py
+++ pyro/infer/util.py
@@ -133,7 +133,7 @@
             for f in frames:
                 if f not in target_frames and value.shape[f.dim] != 1:
                     value = value.sum(f.dim, True)
-            while value.shape and value.shape[0] == 1:
+            while value.shape and value.shape[0] != 1:
                 value = value.squeeze(0)
             total = value if total is None else total + value
         return total

Mutant 289

--- pyro/infer/util.py
+++ pyro/infer/util.py
@@ -133,7 +133,7 @@
             for f in frames:
                 if f not in target_frames and value.shape[f.dim] != 1:
                     value = value.sum(f.dim, True)
-            while value.shape and value.shape[0] == 1:
+            while value.shape and value.shape[0] == 2:
                 value = value.squeeze(0)
             total = value if total is None else total + value
         return total