sbi/inference/snre/snre_b.py
Killed 17 out of 20 mutants

Timeouts
Mutants that made the test suite take much longer, so the tests were killed.

Mutant 204
--- sbi/inference/snre/snre_b.py
+++ sbi/inference/snre/snre_b.py
@@ -147,7 +147,7 @@
# Index 0 is the theta-x-pair sampled from the joint p(theta,x) and hence the
# "correct" one for the 1-out-of-N classification.
- log_prob = logits[:, 0] - torch.logsumexp(logits, dim=-1)
+ log_prob = logits[:, 0] + torch.logsumexp(logits, dim=-1)
return -torch.mean(log_prob)
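Mutant 204 was caught only indirectly, by a timeout. A likely reason is that with the sign flipped the objective `-(logits[:, 0] + logsumexp(logits))` is unbounded below. Raising every logit by the same amount keeps lowering it, so a convergence-based stopping rule may never fire. A cheap unit test on the loss value itself would kill the mutant without any training. The sketch below re-implements the two loss lines in plain Python. `logsumexp`, `loss_original` and `loss_mutant_204` are hypothetical helpers for illustration, not part of sbi.

```python
import math

def logsumexp(row):
    # Numerically stable log(sum(exp(v) for v in row)).
    m = max(row)
    return m + math.log(sum(math.exp(v - m) for v in row))

def loss_original(logits):
    # Mirrors: log_prob = logits[:, 0] - torch.logsumexp(logits, dim=-1)
    #          return -torch.mean(log_prob)
    log_probs = [row[0] - logsumexp(row) for row in logits]
    return -sum(log_probs) / len(log_probs)

def loss_mutant_204(logits):
    # Same computation, with the sign flip introduced by Mutant 204.
    log_probs = [row[0] + logsumexp(row) for row in logits]
    return -sum(log_probs) / len(log_probs)

# With uniform logits over N atoms the classifier is at chance level,
# so the 1-out-of-N cross-entropy should equal log(N).
num_atoms = 4
uniform = [[0.0] * num_atoms for _ in range(3)]
print(loss_original(uniform), math.log(num_atoms))  # approximately equal
print(loss_mutant_204(uniform))                     # negative, about -log(N)
```

A real test would pass fixed logits through sbi's own loss path instead of these stand-ins. Either a known closed-form value (log N at chance level) or the sign constraint (a cross-entropy is never negative) is enough to catch the mutation.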
Survived
These mutants survived mutation testing and point to holes in the test suite.

Mutant 198
--- sbi/inference/snre/snre_b.py
+++ sbi/inference/snre/snre_b.py
@@ -136,7 +136,7 @@
pairs was sampled from the joint $p(\theta,x)$.
"""
- assert theta.shape[0] == x.shape[0], "Batch sizes for theta and x must match."
+ assert theta.shape[0] == x.shape[0], "XXBatch sizes for theta and x must match.XX"
batch_size = theta.shape[0]
logits = self._classifier_logits(theta, x, num_atoms)
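Mutant 198 changes only the assertion message, so it survives unless a test pins the exact text. There is a subtlety here. `pytest.raises(AssertionError, match=...)` applies `re.search`, so an unanchored pattern still matches the mutated `"XXBatch ...XX"` string. The check has to compare the whole message. The following is a minimal standard-library sketch. `check_batch_sizes`, `check_batch_sizes_mutant_198` and `raised_message` are hypothetical stand-ins for the assert in `snre_b.py`, not sbi functions.

```python
import re

EXPECTED = "Batch sizes for theta and x must match."

def check_batch_sizes(theta_batch, x_batch):
    # Stand-in for the original assert in snre_b.py.
    assert theta_batch == x_batch, "Batch sizes for theta and x must match."

def check_batch_sizes_mutant_198(theta_batch, x_batch):
    # Stand-in for the mutated assert (Mutant 198).
    assert theta_batch == x_batch, "XXBatch sizes for theta and x must match.XX"

def raised_message(fn, *args):
    # Return the AssertionError message, or None if nothing was raised.
    try:
        fn(*args)
    except AssertionError as err:
        return str(err)
    return None

mutated = raised_message(check_batch_sizes_mutant_198, 3, 4)
# An unanchored search still finds the original text inside the mutated message...
assert re.search(re.escape(EXPECTED), mutated) is not None
# ...but a full match fails, so only the stricter check kills the mutant.
assert re.fullmatch(re.escape(EXPECTED), mutated) is None
```

Keep in mind that `assert` statements are removed under `python -O`, so a test of this kind only makes sense when the suite runs without optimizations.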
Mutant 205
--- sbi/inference/snre/snre_b.py
+++ sbi/inference/snre/snre_b.py
@@ -147,7 +147,7 @@
# Index 0 is the theta-x-pair sampled from the joint p(theta,x) and hence the
# "correct" one for the 1-out-of-N classification.
- log_prob = logits[:, 0] - torch.logsumexp(logits, dim=-1)
+ log_prob = logits[:, 0] - torch.logsumexp(logits, dim=+1)
return -torch.mean(log_prob)
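Mutant 205 is most likely an equivalent mutant rather than a real gap in the tests. Assume `logits` is 2-D with shape `(batch_size, num_atoms)`, as the `logits[:, 0]` indexing suggests. PyTorch resolves `dim=-1` to the last axis, which for a 2-D tensor is `dim=1`, so no test can tell the two versions apart. The short check below illustrates this with random logits standing in for the classifier output. The two only diverge for tensors with more than two dimensions.

```python
import torch

# Assumption: logits has shape (batch_size, num_atoms).
torch.manual_seed(0)
logits = torch.randn(5, 3)

# Original (dim=-1) and Mutant 205 (dim=+1) reduce over the same axis of a 2-D tensor.
original = logits[:, 0] - torch.logsumexp(logits, dim=-1)
mutant_205 = logits[:, 0] - torch.logsumexp(logits, dim=+1)
print(torch.equal(original, mutant_205))  # True

# With a third dimension the two reductions differ.
logits_3d = torch.randn(2, 3, 4)
print(torch.logsumexp(logits_3d, dim=-1).shape)  # torch.Size([2, 3])
print(torch.logsumexp(logits_3d, dim=+1).shape)  # torch.Size([2, 4])
```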