sbi/inference/snre/snre_base.py

Killed 68 out of 102 mutants

Timeouts

Mutants that made the test suite take much longer to run, so the test run was terminated before it finished.

Mutant 132

--- sbi/inference/snre/snre_base.py
+++ sbi/inference/snre/snre_base.py
@@ -270,7 +270,7 @@
 
         epoch, self._val_log_prob = 0, float("-Inf")
 
-        while epoch <= max_num_epochs and not self._converged(epoch, stop_after_epochs):
+        while epoch <= max_num_epochs or not self._converged(epoch, stop_after_epochs):
 
             # Train for a single epoch.
             self._posterior.net.train()

Mutant 186

--- sbi/inference/snre/snre_base.py
+++ sbi/inference/snre/snre_base.py
@@ -403,7 +403,7 @@
         log_ratio = self.classifier(torch.cat((theta, x), dim=1).reshape(1, -1))
 
         # Notice opposite sign to pyro potential.
-        return log_ratio + self.prior.log_prob(theta)
+        return log_ratio - self.prior.log_prob(theta)
 
     def pyro_potential(self, theta: Dict[str, Tensor]) -> Tensor:
         r"""Return potential for Pyro sampler.

Survived

These mutants survived mutation testing: the test suite still passed with the mutation applied. They show holes in your test suite.

Mutant 87

--- sbi/inference/snre/snre_base.py
+++ sbi/inference/snre/snre_base.py
@@ -89,7 +89,7 @@
             self._build_neural_net = utils.classifier_nn(model=classifier)
         else:
             self._build_neural_net = classifier
-        self._posterior = None
+        self._posterior = ""
         self._sample_with_mcmc = True
         self._mcmc_method = mcmc_method
 

Mutant 91

--- sbi/inference/snre/snre_base.py
+++ sbi/inference/snre/snre_base.py
@@ -130,7 +130,7 @@
             Posterior $p(\theta|x)$ that can be sampled and evaluated.
         """
 
-        max_num_epochs = 2 ** 31 - 1 if max_num_epochs is None else max_num_epochs
+        max_num_epochs = 3 ** 31 - 1 if max_num_epochs is None else max_num_epochs
 
         num_sims_per_round = self._ensure_list(num_simulations_per_round, num_rounds)
 

Mutant 92

--- sbi/inference/snre/snre_base.py
+++ sbi/inference/snre/snre_base.py
@@ -130,7 +130,7 @@
             Posterior $p(\theta|x)$ that can be sampled and evaluated.
         """
 
-        max_num_epochs = 2 ** 31 - 1 if max_num_epochs is None else max_num_epochs
+        max_num_epochs = 2 * 31 - 1 if max_num_epochs is None else max_num_epochs
 
         num_sims_per_round = self._ensure_list(num_simulations_per_round, num_rounds)
 

Mutant 93

--- sbi/inference/snre/snre_base.py
+++ sbi/inference/snre/snre_base.py
@@ -130,7 +130,7 @@
             Posterior $p(\theta|x)$ that can be sampled and evaluated.
         """
 
-        max_num_epochs = 2 ** 31 - 1 if max_num_epochs is None else max_num_epochs
+        max_num_epochs = 2 ** 32 - 1 if max_num_epochs is None else max_num_epochs
 
         num_sims_per_round = self._ensure_list(num_simulations_per_round, num_rounds)
 

Mutant 94

--- sbi/inference/snre/snre_base.py
+++ sbi/inference/snre/snre_base.py
@@ -130,7 +130,7 @@
             Posterior $p(\theta|x)$ that can be sampled and evaluated.
         """
 
-        max_num_epochs = 2 ** 31 - 1 if max_num_epochs is None else max_num_epochs
+        max_num_epochs = 2 ** 31 + 1 if max_num_epochs is None else max_num_epochs
 
         num_sims_per_round = self._ensure_list(num_simulations_per_round, num_rounds)
 

Mutant 95

--- sbi/inference/snre/snre_base.py
+++ sbi/inference/snre/snre_base.py
@@ -130,7 +130,7 @@
             Posterior $p(\theta|x)$ that can be sampled and evaluated.
         """
 
-        max_num_epochs = 2 ** 31 - 1 if max_num_epochs is None else max_num_epochs
+        max_num_epochs = 2 ** 31 - 2 if max_num_epochs is None else max_num_epochs
 
         num_sims_per_round = self._ensure_list(num_simulations_per_round, num_rounds)
 

Mutant 103

--- sbi/inference/snre/snre_base.py
+++ sbi/inference/snre/snre_base.py
@@ -146,7 +146,7 @@
                 )
 
             x = self._batched_simulator(theta)
-            x_shape = x_shape_from_simulation(x)
+            x_shape = None
 
             # First round or if retraining from scratch:
             # Call the `self._build_neural_net` with the rounds' thetas and xs as

Mutant 109

--- sbi/inference/snre/snre_base.py
+++ sbi/inference/snre/snre_base.py
@@ -226,7 +226,7 @@
         """
 
         # Starting index for the training set (1 = discard round-0 samples).
-        start_idx = int(discard_prior_samples and round_ > 0)
+        start_idx = int(discard_prior_samples and round_ >= 0)
         # Get total number of training examples.
         num_examples = sum(len(theta) for theta in self._theta_bank)
 

Mutant 110

--- sbi/inference/snre/snre_base.py
+++ sbi/inference/snre/snre_base.py
@@ -226,7 +226,7 @@
         """
 
         # Starting index for the training set (1 = discard round-0 samples).
-        start_idx = int(discard_prior_samples and round_ > 0)
+        start_idx = int(discard_prior_samples and round_ > 1)
         # Get total number of training examples.
         num_examples = sum(len(theta) for theta in self._theta_bank)
 

Mutant 111

--- sbi/inference/snre/snre_base.py
+++ sbi/inference/snre/snre_base.py
@@ -226,7 +226,7 @@
         """
 
         # Starting index for the training set (1 = discard round-0 samples).
-        start_idx = int(discard_prior_samples and round_ > 0)
+        start_idx = int(discard_prior_samples or round_ > 0)
         # Get total number of training examples.
         num_examples = sum(len(theta) for theta in self._theta_bank)
 

Mutant 112

--- sbi/inference/snre/snre_base.py
+++ sbi/inference/snre/snre_base.py
@@ -226,7 +226,7 @@
         """
 
         # Starting index for the training set (1 = discard round-0 samples).
-        start_idx = int(discard_prior_samples and round_ > 0)
+        start_idx = None
         # Get total number of training examples.
         num_examples = sum(len(theta) for theta in self._theta_bank)
 

Mutant 119

--- sbi/inference/snre/snre_base.py
+++ sbi/inference/snre/snre_base.py
@@ -233,7 +233,7 @@
         # Select random train and validation splits from (theta, x) pairs.
         permuted_indices = torch.randperm(num_examples)
         num_training_examples = int((1 - validation_fraction) * num_examples)
-        num_validation_examples = num_examples - num_training_examples
+        num_validation_examples = num_examples + num_training_examples
         train_indices, val_indices = (
             permuted_indices[:num_training_examples],
             permuted_indices[num_training_examples:],

Mutant 122

--- sbi/inference/snre/snre_base.py
+++ sbi/inference/snre/snre_base.py
@@ -242,7 +242,7 @@
         clipped_batch_size = min(training_batch_size, num_validation_examples)
 
         # num_atoms = theta.shape[0]
-        clamp_and_warn("num_atoms", num_atoms, min_val=2, max_val=clipped_batch_size)
+        clamp_and_warn("XXnum_atomsXX", num_atoms, min_val=2, max_val=clipped_batch_size)
 
         # Dataset is shared for training and validation loaders.
         dataset = data.TensorDataset(

Mutant 123

--- sbi/inference/snre/snre_base.py
+++ sbi/inference/snre/snre_base.py
@@ -242,7 +242,7 @@
         clipped_batch_size = min(training_batch_size, num_validation_examples)
 
         # num_atoms = theta.shape[0]
-        clamp_and_warn("num_atoms", num_atoms, min_val=2, max_val=clipped_batch_size)
+        clamp_and_warn("num_atoms", num_atoms, min_val=3, max_val=clipped_batch_size)
 
         # Dataset is shared for training and validation loaders.
         dataset = data.TensorDataset(

Mutant 124

--- sbi/inference/snre/snre_base.py
+++ sbi/inference/snre/snre_base.py
@@ -253,7 +253,7 @@
         train_loader = data.DataLoader(
             dataset,
             batch_size=clipped_batch_size,
-            drop_last=True,
+            drop_last=False,
             sampler=SubsetRandomSampler(train_indices),
         )
         val_loader = data.DataLoader(

Mutant 126

--- sbi/inference/snre/snre_base.py
+++ sbi/inference/snre/snre_base.py
@@ -260,7 +260,7 @@
             dataset,
             batch_size=clipped_batch_size,
             shuffle=False,
-            drop_last=False,
+            drop_last=True,
             sampler=SubsetRandomSampler(val_indices),
         )
 

Mutant 130

--- sbi/inference/snre/snre_base.py
+++ sbi/inference/snre/snre_base.py
@@ -270,7 +270,7 @@
 
         epoch, self._val_log_prob = 0, float("-Inf")
 
-        while epoch <= max_num_epochs and not self._converged(epoch, stop_after_epochs):
+        while epoch < max_num_epochs and not self._converged(epoch, stop_after_epochs):
 
             # Train for a single epoch.
             self._posterior.net.train()

Mutant 137

--- sbi/inference/snre/snre_base.py
+++ sbi/inference/snre/snre_base.py
@@ -288,7 +288,7 @@
                     )
                 optimizer.step()
 
-            epoch += 1
+            epoch = 1
 
             # Calculate validation performance.
             self._posterior.net.eval()

Mutant 138

--- sbi/inference/snre/snre_base.py
+++ sbi/inference/snre/snre_base.py
@@ -288,7 +288,7 @@
                     )
                 optimizer.step()
 
-            epoch += 1
+            epoch -= 1
 
             # Calculate validation performance.
             self._posterior.net.eval()

Mutant 139

--- sbi/inference/snre/snre_base.py
+++ sbi/inference/snre/snre_base.py
@@ -288,7 +288,7 @@
                     )
                 optimizer.step()
 
-            epoch += 1
+            epoch += 2
 
             # Calculate validation performance.
             self._posterior.net.eval()

Mutant 140

--- sbi/inference/snre/snre_base.py
+++ sbi/inference/snre/snre_base.py
@@ -292,7 +292,7 @@
 
             # Calculate validation performance.
             self._posterior.net.eval()
-            log_prob_sum = 0
+            log_prob_sum = 1
             with torch.no_grad():
                 for batch in val_loader:
                     theta_batch, x_batch = (

Mutant 147

--- sbi/inference/snre/snre_base.py
+++ sbi/inference/snre/snre_base.py
@@ -301,7 +301,7 @@
                     )
                     log_prob = self._loss(theta_batch, x_batch, num_atoms)
                     log_prob_sum -= log_prob.sum().item()
-                self._val_log_prob = log_prob_sum / num_validation_examples
+                self._val_log_prob = log_prob_sum * num_validation_examples
 
             self._maybe_show_progress(self._show_progress_bars, epoch)
 

Mutant 153

--- sbi/inference/snre/snre_base.py
+++ sbi/inference/snre/snre_base.py
@@ -320,7 +320,7 @@
         repeated_x = utils.repeat_rows(x, num_atoms)
 
         # Choose `1` or `num_atoms - 1` thetas from the rest of the batch for each x.
-        probs = ones(batch_size, batch_size) * (1 - eye(batch_size)) / (batch_size - 1)
+        probs = ones(batch_size, batch_size) * (2 - eye(batch_size)) / (batch_size - 1)
 
         choices = torch.multinomial(probs, num_samples=num_atoms - 1, replacement=False)
 

Mutant 154

--- sbi/inference/snre/snre_base.py
+++ sbi/inference/snre/snre_base.py
@@ -320,7 +320,7 @@
         repeated_x = utils.repeat_rows(x, num_atoms)
 
         # Choose `1` or `num_atoms - 1` thetas from the rest of the batch for each x.
-        probs = ones(batch_size, batch_size) * (1 - eye(batch_size)) / (batch_size - 1)
+        probs = ones(batch_size, batch_size) * (1 + eye(batch_size)) / (batch_size - 1)
 
         choices = torch.multinomial(probs, num_samples=num_atoms - 1, replacement=False)
 

Mutant 155

--- sbi/inference/snre/snre_base.py
+++ sbi/inference/snre/snre_base.py
@@ -320,7 +320,7 @@
         repeated_x = utils.repeat_rows(x, num_atoms)
 
         # Choose `1` or `num_atoms - 1` thetas from the rest of the batch for each x.
-        probs = ones(batch_size, batch_size) * (1 - eye(batch_size)) / (batch_size - 1)
+        probs = ones(batch_size, batch_size) * (1 - eye(batch_size)) * (batch_size - 1)
 
         choices = torch.multinomial(probs, num_samples=num_atoms - 1, replacement=False)
 

Mutant 156

--- sbi/inference/snre/snre_base.py
+++ sbi/inference/snre/snre_base.py
@@ -320,7 +320,7 @@
         repeated_x = utils.repeat_rows(x, num_atoms)
 
         # Choose `1` or `num_atoms - 1` thetas from the rest of the batch for each x.
-        probs = ones(batch_size, batch_size) * (1 - eye(batch_size)) / (batch_size - 1)
+        probs = ones(batch_size, batch_size) * (1 - eye(batch_size)) / (batch_size + 1)
 
         choices = torch.multinomial(probs, num_samples=num_atoms - 1, replacement=False)
 

Mutant 157

--- sbi/inference/snre/snre_base.py
+++ sbi/inference/snre/snre_base.py
@@ -320,7 +320,7 @@
         repeated_x = utils.repeat_rows(x, num_atoms)
 
         # Choose `1` or `num_atoms - 1` thetas from the rest of the batch for each x.
-        probs = ones(batch_size, batch_size) * (1 - eye(batch_size)) / (batch_size - 1)
+        probs = ones(batch_size, batch_size) * (1 - eye(batch_size)) / (batch_size - 2)
 
         choices = torch.multinomial(probs, num_samples=num_atoms - 1, replacement=False)
 

Mutant 161

--- sbi/inference/snre/snre_base.py
+++ sbi/inference/snre/snre_base.py
@@ -322,7 +322,7 @@
         # Choose `1` or `num_atoms - 1` thetas from the rest of the batch for each x.
         probs = ones(batch_size, batch_size) * (1 - eye(batch_size)) / (batch_size - 1)
 
-        choices = torch.multinomial(probs, num_samples=num_atoms - 1, replacement=False)
+        choices = torch.multinomial(probs, num_samples=num_atoms - 1, replacement=True)
 
         contrasting_theta = theta[choices]
 

Mutant 170

--- sbi/inference/snre/snre_base.py
+++ sbi/inference/snre/snre_base.py
@@ -334,7 +334,6 @@
 
         return self._posterior.net(theta_and_x)
 
-    @abstractmethod
     def _loss(self, theta: Tensor, x: Tensor, num_atoms: int) -> Tensor:
         raise NotImplementedError
 

Mutant 175

--- sbi/inference/snre/snre_base.py
+++ sbi/inference/snre/snre_base.py
@@ -380,7 +380,7 @@
         self.prior = prior
         self.x = x
 
-        if mcmc_method in ("slice", "hmc", "nuts"):
+        if mcmc_method in ("XXsliceXX", "hmc", "nuts"):
             return self.pyro_potential
         else:
             return self.np_potential

Mutant 176

--- sbi/inference/snre/snre_base.py
+++ sbi/inference/snre/snre_base.py
@@ -380,7 +380,7 @@
         self.prior = prior
         self.x = x
 
-        if mcmc_method in ("slice", "hmc", "nuts"):
+        if mcmc_method in ("slice", "XXhmcXX", "nuts"):
             return self.pyro_potential
         else:
             return self.np_potential

Mutant 177

--- sbi/inference/snre/snre_base.py
+++ sbi/inference/snre/snre_base.py
@@ -380,7 +380,7 @@
         self.prior = prior
         self.x = x
 
-        if mcmc_method in ("slice", "hmc", "nuts"):
+        if mcmc_method in ("slice", "hmc", "XXnutsXX"):
             return self.pyro_potential
         else:
             return self.np_potential