sbi/inference/snre/snre_b.py

Killed 17 out of 20 mutants

Timeouts

Mutants that made the test suite run so much longer that the tests were killed.

Mutant 100

--- sbi/inference/snre/snre_b.py
+++ sbi/inference/snre/snre_b.py
@@ -147,7 +147,7 @@
 
         # Index 0 is the theta-x-pair sampled from the joint p(theta,x) and hence the
         # "correct" one for the 1-out-of-N classification.
-        log_prob = logits[:, 0] - torch.logsumexp(logits, dim=-1)
+        log_prob = logits[:, 0] + torch.logsumexp(logits, dim=-1)
 
         return -torch.mean(log_prob)
 

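In the original line, `log_prob` is the log-softmax probability that index 0 (the pair drawn from the joint) is the correct atom out of the N candidates. The returned loss is therefore never negative. Flipping the minus sign makes the loss unbounded below, because raising every logit lowers it without limit. One plausible reading, which the report does not confirm, is that training then keeps "improving" until some epoch limit is reached, and that would explain the timeout. The pure-Python sketch below works on plain lists instead of tensors and uses the hypothetical helpers `nce_loss` and `mutant_loss`, which are not sbi functions.

```python
import math


def logsumexp(row):
    # Numerically stable log(sum(exp(z_j))) over one row of logits.
    m = max(row)
    return m + math.log(sum(math.exp(z - m) for z in row))


def nce_loss(logits):
    # Original: index 0 is the joint pair, so take the log-softmax of
    # that "correct" atom and return the negative mean over the batch.
    log_probs = [row[0] - logsumexp(row) for row in logits]
    return -sum(log_probs) / len(log_probs)


def mutant_loss(logits):
    # Mutant 100: the minus sign is flipped to a plus.
    log_probs = [row[0] + logsumexp(row) for row in logits]
    return -sum(log_probs) / len(log_probs)


# Uniform logits over N = 4 atoms: the correct loss is log(4).
print(nce_loss([[0.0, 0.0, 0.0, 0.0]]), math.log(4))

# Scaling all logits up leaves the original loss at log(4),
# while the mutant loss keeps decreasing without bound.
for scale in (1.0, 10.0, 100.0):
    logits = [[scale] * 4]
    print(scale, nce_loss(logits), mutant_loss(logits))
```

A unit test that compares the loss with log(N) for uniform logits would likely kill this mutant directly, without relying on a timeout.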
Survived

Survived mutation testing. These mutants can indicate holes in your test suite.

Mutant 94

--- sbi/inference/snre/snre_b.py
+++ sbi/inference/snre/snre_b.py
@@ -136,7 +136,7 @@
         pairs was sampled from the joint $p(\theta,x)$.
         """
 
-        assert theta.shape[0] == x.shape[0], "Batch sizes for theta and x must match."
+        assert theta.shape[0] == x.shape[0], "XXBatch sizes for theta and x must match.XX"
         batch_size = theta.shape[0]
         logits = self._classifier_logits(theta, x, num_atoms)
 

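Mutant 94 changes only the assertion message, so it survives whenever no test inspects that text. The sketch below shows a check that would kill it. `check_batch_sizes` is a hypothetical stand-in that mirrors the assertion in snre_b.py and is not part of sbi. It takes plain shape tuples instead of tensors.

```python
def check_batch_sizes(theta_shape, x_shape):
    # Hypothetical stand-in mirroring the assertion in snre_b.py.
    assert theta_shape[0] == x_shape[0], "Batch sizes for theta and x must match."
    return theta_shape[0]


def assertion_message(theta_shape, x_shape):
    # Return the AssertionError message, or None if the check passes.
    try:
        check_batch_sizes(theta_shape, x_shape)
    except AssertionError as err:
        return str(err)
    return None


# Pinning the exact message makes the "XX...XX" mutant fail this check.
msg = assertion_message((10, 3), (8, 5))
print(msg)
```

With pytest, `pytest.raises(AssertionError, match=...)` expresses the same check. Whether pinning message text is worth the coupling is a judgment call.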
Mutant 101

--- sbi/inference/snre/snre_b.py
+++ sbi/inference/snre/snre_b.py
@@ -147,7 +147,7 @@
 
         # Index 0 is the theta-x-pair sampled from the joint p(theta,x) and hence the
         # "correct" one for the 1-out-of-N classification.
-        log_prob = logits[:, 0] - torch.logsumexp(logits, dim=-1)
+        log_prob = logits[:, 0] - torch.logsumexp(logits, dim=+1)
 
         return -torch.mean(log_prob)
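Mutant 101 is very likely an equivalent mutant rather than a test gap. `logits[:, 0]` implies `logits` is 2-D, assumed here to have shape (batch_size, num_atoms). For such a tensor, `dim=-1` counts from the last axis and resolves to axis 1, which is exactly what `dim=+1` selects. No test can distinguish the two. The sketch below uses illustrative sizes, not sbi defaults.

```python
import torch

torch.manual_seed(0)
batch_size, num_atoms = 5, 4  # illustrative sizes only
logits = torch.randn(batch_size, num_atoms)

# For a 2-D tensor, dim=-1 and dim=+1 both select axis 1 (the atoms).
original = logits[:, 0] - torch.logsumexp(logits, dim=-1)
mutant = logits[:, 0] - torch.logsumexp(logits, dim=+1)
print(torch.equal(original, mutant))

# Both equal the log-softmax of the "correct" atom at index 0.
print(torch.allclose(original, torch.log_softmax(logits, dim=-1)[:, 0]))
```

The two expressions would differ only if `logits` gained extra dimensions, which the surrounding code does not appear to do.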