volume_segmantics.model

from volume_segmantics.model.operations.vol_seg_2d_trainer import \
    VolSeg2dTrainer
from volume_segmantics.model.operations.vol_seg_2d_predictor import \
    VolSeg2dPredictor
from volume_segmantics.model.operations.vol_seg_prediction_manager import \
    VolSeg2DPredictionManager

__all__ = [VolSeg2dTrainer, VolSeg2dPredictor, VolSeg2DPredictionManager]
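
The three exported classes cover the typical workflow: train a 2d model from directories of image and label slices, wrap the saved weights in a predictor, then use a prediction manager to segment a volume and write the result to disk. Below is a rough sketch of that pipeline; the paths and label codes are placeholders, and the settings objects are assumptions that list only the attributes read directly by the class sources on this page (the dataloader and model-construction helpers will expect further fields, so in practice settings are loaded from the package's settings files).

from pathlib import Path
from types import SimpleNamespace

from volume_segmantics.model import (
    VolSeg2DPredictionManager,
    VolSeg2dPredictor,
    VolSeg2dTrainer,
)

# Placeholder paths and label codes, for illustration only.
image_dir = Path("data/train/images")
label_dir = Path("data/train/labels")
labels = {0: "background", 1: "feature"}
model_out_path = Path("output/2d_model.pytorch")

# Assumed, incomplete settings: only the attributes read directly by the
# sources shown on this page are listed here.
train_settings = SimpleNamespace(
    cuda_device=0,
    starting_lr=1e-6, end_lr=0.1, lr_find_epochs=1, lr_reduce_factor=10,
    patience=3, loss_criterion="DiceLoss", eval_metric="MeanIoU",
    pct_lr_inc=0.3,
    model={},  # model architecture dict; required keys depend on the wider package
)
predict_settings = SimpleNamespace(
    cuda_device=0,
    quality="low",  # field name assumed; read via utils.get_prediction_quality
    one_hot=False,
    output_probs=False,
)

trainer = VolSeg2dTrainer(image_dir, label_dir, labels, train_settings)
trainer.train_model(model_out_path, num_epochs=10, patience=3)

predictor = VolSeg2dPredictor(model_out_path, predict_settings)
pred_manager = VolSeg2DPredictionManager(
    predictor, "data/volume_to_segment.h5", predict_settings
)
segmentation = pred_manager.predict_volume_to_path(Path("output/segmentation.h5"))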
class VolSeg2dTrainer:
    """Class that utilises 2d dataloaders to train a 2d deep learning model.

    Args:
        image_dir_path: Path to the directory of 2d training images.
        label_dir_path: Path to the directory of corresponding 2d label images.
        labels (Union[int, dict]): Either the number of segmentation classes or a
            dictionary of label codes (the number of classes is taken from its length).
        settings: Namespace-like object of training settings.
    """

    def __init__(
        self, image_dir_path, label_dir_path, labels: Union[int, dict], settings
    ):
        self.training_loader, self.validation_loader = get_2d_training_dataloaders(
            image_dir_path, label_dir_path, settings
        )
        self.label_no = labels if isinstance(labels, int) else len(labels)
        self.codes = labels if isinstance(labels, dict) else {}
        self.settings = settings
        # Params for learning rate finder
        self.starting_lr = float(settings.starting_lr)
        self.end_lr = float(settings.end_lr)
        self.log_lr_ratio = self.calculate_log_lr_ratio()
        self.lr_find_epochs = settings.lr_find_epochs
        self.lr_reduce_factor = settings.lr_reduce_factor
        # Params for model training
        self.model_device_num = int(settings.cuda_device)
        self.patience = settings.patience
        self.loss_criterion = self.get_loss_criterion()
        self.eval_metric = self.get_eval_metric()
        self.model_struc_dict = self.get_model_struc_dict(settings)
        self.avg_train_losses = []  # per epoch training loss
        self.avg_valid_losses = []  # per epoch validation loss
        self.avg_eval_scores = []  # per epoch evaluation score

    def get_model_struc_dict(self, settings):
        model_struc_dict = settings.model
        model_type = utils.get_model_type(settings)
        model_struc_dict["type"] = model_type
        model_struc_dict["in_channels"] = cfg.MODEL_INPUT_CHANNELS
        model_struc_dict["classes"] = self.label_no
        return model_struc_dict

    def calculate_log_lr_ratio(self):
        return math.log(self.end_lr / self.starting_lr)

    def create_model_and_optimiser(self, learning_rate, frozen=False):
        logging.info(f"Setting up the model on device {self.settings.cuda_device}.")
        self.model = create_model_on_device(
            self.model_device_num, self.model_struc_dict
        )
        if frozen:
            self.freeze_model()
        logging.info(
            f"Model has {self.count_trainable_parameters()} trainable parameters, {self.count_parameters()} total parameters."
        )
        self.optimizer = self.create_optimizer(learning_rate)
        logging.info("Trainer created.")

    def freeze_model(self):
        logging.info(
            f"Freezing model with {self.count_trainable_parameters()} trainable parameters, {self.count_parameters()} total parameters."
        )
        for name, param in self.model.named_parameters():
            if all(["encoder" in name, "conv" in name]) and param.requires_grad:
                param.requires_grad = False

    def unfreeze_model(self):
        logging.info(
            f"Unfreezing model with {self.count_trainable_parameters()} trainable parameters, {self.count_parameters()} total parameters."
        )
        for name, param in self.model.named_parameters():
            if all(["encoder" in name, "conv" in name]) and not param.requires_grad:
                param.requires_grad = True

    def count_trainable_parameters(self) -> int:
        return sum(p.numel() for p in self.model.parameters() if p.requires_grad)

    def count_parameters(self) -> int:
        return sum(p.numel() for p in self.model.parameters())

    def get_loss_criterion(self):
        if self.settings.loss_criterion == "BCEDiceLoss":
            alpha = self.settings.alpha
            beta = self.settings.beta
            logging.info(
                f"Using combined BCE and Dice loss with weighting of {alpha}*BCE "
                f"and {beta}*Dice"
            )
            loss_criterion = BCEDiceLoss(alpha, beta)
        elif self.settings.loss_criterion == "DiceLoss":
            logging.info("Using DiceLoss")
            loss_criterion = DiceLoss(normalization="none")
        elif self.settings.loss_criterion == "BCELoss":
            logging.info("Using BCELoss")
            loss_criterion = nn.BCEWithLogitsLoss()
        elif self.settings.loss_criterion == "CrossEntropyLoss":
            logging.info("Using CrossEntropyLoss")
            loss_criterion = nn.CrossEntropyLoss()
        elif self.settings.loss_criterion == "GeneralizedDiceLoss":
            logging.info("Using GeneralizedDiceLoss")
            loss_criterion = GeneralizedDiceLoss()
        else:
            logging.error("No loss criterion specified, exiting")
            sys.exit(1)
        return loss_criterion

    def get_eval_metric(self):
        # Get evaluation metric
        if self.settings.eval_metric == "MeanIoU":
            logging.info("Using MeanIoU")
            eval_metric = MeanIoU()
        elif self.settings.eval_metric == "GenericAveragePrecision":
            logging.info("Using GenericAveragePrecision")
            eval_metric = GenericAveragePrecision()
        else:
            logging.error("No evaluation metric specified, exiting")
            sys.exit(1)
        return eval_metric

    def train_model(self, output_path, num_epochs, patience, create=True, frozen=False):
        """Performs training of model for a number of cycles
        with a learning rate that is determined automatically.
        """
        train_losses = []
        valid_losses = []
        eval_scores = []

        if create:
            self.create_model_and_optimiser(self.starting_lr, frozen=frozen)
            lr_to_use = self.run_lr_finder()
            # Recreate model and start training
            self.create_model_and_optimiser(lr_to_use, frozen=frozen)
            early_stopping = self.create_early_stopping(output_path, patience)
        else:
            # Reduce starting LR, since model already partially trained
            self.starting_lr /= self.lr_reduce_factor
            self.end_lr /= self.lr_reduce_factor
            self.log_lr_ratio = self.calculate_log_lr_ratio()
            self.load_in_model_and_optimizer(
                self.starting_lr, output_path, frozen=frozen, optimizer=False
            )
            lr_to_use = self.run_lr_finder()
            min_loss = self.load_in_model_and_optimizer(
                self.starting_lr, output_path, frozen=frozen, optimizer=False
            )
            early_stopping = self.create_early_stopping(
                output_path, patience, best_score=-min_loss
            )

        # Initialise the One Cycle learning rate scheduler
        lr_scheduler = self.create_oc_lr_scheduler(num_epochs, lr_to_use)

        for epoch in range(1, num_epochs + 1):
            self.model.train()
            tic = time.perf_counter()
            logging.info(f"Epoch {epoch} of {num_epochs}")
            for batch in tqdm(
                self.training_loader,
                desc="Training batch",
                bar_format=cfg.TQDM_BAR_FORMAT,
            ):
                loss = self.train_one_batch(lr_scheduler, batch)
                train_losses.append(loss.item())  # record training loss

            self.model.eval()  # prep model for evaluation
            with torch.no_grad():
                for batch in tqdm(
                    self.validation_loader,
                    desc="Validation batch",
                    bar_format=cfg.TQDM_BAR_FORMAT,
                ):
                    inputs, targets = utils.prepare_training_batch(
                        batch, self.model_device_num, self.label_no
                    )
                    output = self.model(inputs)  # Forward pass
                    # calculate the loss
                    if self.settings.loss_criterion == "CrossEntropyLoss":
                        loss = self.loss_criterion(output, torch.argmax(targets, dim=1))
                    else:
                        loss = self.loss_criterion(output, targets.float())
                    valid_losses.append(loss.item())  # record validation loss
                    s_max = nn.Softmax(dim=1)
                    probs = s_max(output)  # Convert the logits to probs
                    probs = torch.unsqueeze(probs, 2)
                    targets = torch.unsqueeze(targets, 2)
                    eval_score = self.eval_metric(probs, targets)
                    eval_scores.append(eval_score)  # record eval metric

            toc = time.perf_counter()
            # calculate average loss/metric over an epoch
            self.avg_train_losses.append(np.average(train_losses))
            self.avg_valid_losses.append(np.average(valid_losses))
            self.avg_eval_scores.append(np.average(eval_scores))
            logging.info(
                f"Epoch {epoch}. Training loss: {self.avg_train_losses[-1]}, Validation Loss: "
                f"{self.avg_valid_losses[-1]}. {self.settings.eval_metric}: {self.avg_eval_scores[-1]}"
            )
            logging.info(f"Time taken for epoch {epoch}: {toc - tic:0.2f} seconds")
            # clear lists to track next epoch
            train_losses = []
            valid_losses = []
            eval_scores = []

            # early_stopping needs the validation loss to check if it has decreased,
            # and if it has, it will make a checkpoint of the current model
            early_stopping(
                self.avg_valid_losses[-1], self.model, self.optimizer, self.codes
            )

            if early_stopping.early_stop:
                logging.info("Early stopping")
                break

        # load the last checkpoint with the best model
        self.load_in_weights(output_path)

    def load_in_model_and_optimizer(
        self, learning_rate, output_path, frozen=False, optimizer=False
    ):
        self.create_model_and_optimiser(learning_rate, frozen=frozen)
        logging.info("Loading in weights from saved checkpoint.")
        loss_val = self.load_in_weights(output_path, optimizer=optimizer)
        return loss_val

    def load_in_weights(self, output_path, optimizer=False, gpu=True):
        # load the last checkpoint with the best model
        if gpu:
            map_location = f"cuda:{self.model_device_num}"
        else:
            map_location = "cpu"
        model_dict = torch.load(output_path, map_location=map_location)
        logging.info("Loading model weights.")
        self.model.load_state_dict(model_dict["model_state_dict"])
        if optimizer:
            logging.info("Loading optimizer weights.")
            self.optimizer.load_state_dict(model_dict["optimizer_state_dict"])
        return model_dict.get("loss_val", np.inf)

    def run_lr_finder(self):
        logging.info("Finding learning rate for model.")
        lr_scheduler = self.create_exponential_lr_scheduler()
        lr_find_loss, lr_find_lr = self.lr_finder(lr_scheduler)
        lr_to_use = self.find_lr_from_graph(lr_find_loss, lr_find_lr)
        logging.info(f"LR to use {lr_to_use}")
        return lr_to_use

    def lr_finder(self, lr_scheduler, smoothing=0.05, plt_fig=True):
        lr_find_loss = []
        lr_find_lr = []
        iters = 0

        self.model.train()
        logging.info(
            f"Training for {self.lr_find_epochs} epochs to create a learning "
            "rate plot."
        )
        for i in range(self.lr_find_epochs):
            for batch in tqdm(
                self.training_loader,
                desc=f"Epoch {i + 1}, batch number",
                bar_format=cfg.TQDM_BAR_FORMAT,
            ):
                loss = self.train_one_batch(lr_scheduler, batch)
                lr_step = self.optimizer.state_dict()["param_groups"][0]["lr"]
                lr_find_lr.append(lr_step)
                if iters == 0:
                    lr_find_loss.append(loss)
                else:
                    loss = smoothing * loss + (1 - smoothing) * lr_find_loss[-1]
                    lr_find_loss.append(loss)
                if loss > 1 and iters > len(self.training_loader) // 1.333:
                    break
                iters += 1

        if plt_fig:
            fig = tpl.figure()
            fig.plot(
                np.log10(lr_find_lr),
                lr_find_loss,
                width=50,
                height=30,
                xlabel="Log10 Learning Rate",
            )
            fig.show()

        return lr_find_loss, lr_find_lr

    @staticmethod
    def find_lr_from_graph(
        lr_find_loss: torch.Tensor, lr_find_lr: torch.Tensor
    ) -> float:
        """Calculates the learning rate corresponding to the minimum gradient in the
        graph of loss vs learning rate.

        Args:
            lr_find_loss (torch.Tensor): Loss values accumulated during training
            lr_find_lr (torch.Tensor): Learning rate used for each mini-batch

        Returns:
            float: The learning rate at the point where the loss was falling most
            steeply, divided by a fudge factor.
        """
        default_min_lr = cfg.DEFAULT_MIN_LR  # Add as default value to fix bug
        # Get loss values and their corresponding gradients, and get lr values
        for i in range(0, len(lr_find_loss)):
            if lr_find_loss[i].is_cuda:
                lr_find_loss[i] = lr_find_loss[i].cpu()
            lr_find_loss[i] = lr_find_loss[i].detach().numpy()
        losses = np.array(lr_find_loss)
        try:
            gradients = np.gradient(losses)
            min_gradient = gradients.min()
            if min_gradient < 0:
                min_loss_grad_idx = gradients.argmin()
            else:
                logging.info(
                    f"Minimum gradient: {min_gradient} was positive, returning default value instead."
                )
                return default_min_lr
        except Exception as e:
            logging.info(f"Failed to compute gradients, returning default value. {e}")
            return default_min_lr
        min_lr = lr_find_lr[min_loss_grad_idx]
        return min_lr / cfg.LR_DIVISOR

    def lr_exp_stepper(self, x):
        """Exponentially increase learning rate as part of strategy to find the
        optimum.
        Taken from
        https://towardsdatascience.com/adaptive-and-cyclical-learning-rates-using-pytorch-2bf904d18dee
        """
        return math.exp(
            x * self.log_lr_ratio / (self.lr_find_epochs * len(self.training_loader))
        )

    def create_optimizer(self, learning_rate):
        return torch.optim.AdamW(self.model.parameters(), lr=learning_rate)

    def create_exponential_lr_scheduler(self):
        return torch.optim.lr_scheduler.LambdaLR(self.optimizer, self.lr_exp_stepper)

    def create_oc_lr_scheduler(self, num_epochs, lr_to_use):
        return torch.optim.lr_scheduler.OneCycleLR(
            self.optimizer,
            max_lr=lr_to_use,
            steps_per_epoch=len(self.training_loader),
            epochs=num_epochs,
            pct_start=self.settings.pct_lr_inc,
        )

    def create_early_stopping(self, output_path, patience, best_score=None):
        return EarlyStopping(
            patience=patience,
            verbose=True,
            path=output_path,
            model_dict=self.model_struc_dict,
            best_score=best_score,
        )

    def train_one_batch(self, lr_scheduler, batch):
        inputs, targets = utils.prepare_training_batch(
            batch, self.model_device_num, self.label_no
        )
        self.optimizer.zero_grad()
        output = self.model(inputs)  # Forward pass
        if self.settings.loss_criterion == "CrossEntropyLoss":
            loss = self.loss_criterion(output, torch.argmax(targets, dim=1))
        else:
            loss = self.loss_criterion(output, targets.float())
        loss.backward()  # Backward pass
        self.optimizer.step()
        lr_scheduler.step()  # update the learning rate
        return loss

    def output_loss_fig(self, model_out_path):

        fig = plt.figure(figsize=(10, 8))
        plt.plot(
            range(1, len(self.avg_train_losses) + 1),
            self.avg_train_losses,
            label="Training Loss",
        )
        plt.plot(
            range(1, len(self.avg_valid_losses) + 1),
            self.avg_valid_losses,
            label="Validation Loss",
        )

        minposs = (
            self.avg_valid_losses.index(min(self.avg_valid_losses)) + 1
        )  # find position of lowest validation loss
        plt.axvline(
            minposs, linestyle="--", color="r", label="Early Stopping Checkpoint"
        )

        plt.xlabel("epochs")
        plt.ylabel("loss")
        plt.xlim(0, len(self.avg_train_losses) + 1)  # consistent scale
        plt.grid(True)
        plt.legend()
        plt.tight_layout()
        output_dir = model_out_path.parent
        fig_out_pth = output_dir / f"{model_out_path.stem}_loss_plot.png"
        logging.info(f"Saving figure of training/validation losses to {fig_out_pth}")
        fig.savefig(fig_out_pth, bbox_inches="tight")
        # Output a list of training stats
        epoch_lst = range(len(self.avg_train_losses))
        rows = zip(
            epoch_lst,
            self.avg_train_losses,
            self.avg_valid_losses,
            self.avg_eval_scores,
        )
        with open(output_dir / f"{model_out_path.stem}_train_stats.csv", "w") as f:
            writer = csv.writer(f)
            writer.writerow(("Epoch", "Train Loss", "Valid Loss", "Eval Score"))
            for row in rows:
                writer.writerow(row)

    def output_prediction_figure(self, model_path):
        """Saves a figure containing image slice data for up to four images
        from the validation dataset along with the corresponding ground truth
        label image and corresponding prediction output from the model attached
        to this class instance. The image is saved to the same directory as the
        model weights.

        Args:
            model_path (pathlib.Path): Full path to the model weights file.
            This is only used to derive the output directory and file name;
            the weights are not loaded for prediction.
        """
        self.model.eval()  # prep model for evaluation
        batch = next(iter(self.validation_loader))  # Get first batch
        with torch.no_grad():
            inputs, targets = utils.prepare_training_batch(
                batch, self.model_device_num, self.label_no
            )
            output = self.model(inputs)  # Forward pass
            s_max = nn.Softmax(dim=1)
            probs = s_max(output)  # Convert the logits to probs
            labels = torch.argmax(probs, dim=1)  # flatten channels

        # Create the plot
        bs = self.validation_loader.batch_size
        if bs < 4:
            rows = bs
        else:
            rows = 4
        fig = plt.figure(figsize=(12, 16))
        columns = 3
        j = 0
        for i in range(columns * rows)[::3]:
            img = inputs[j].squeeze().cpu()
            gt = torch.argmax(targets[j], dim=0).cpu()
            pred = labels[j].cpu()
            col1 = fig.add_subplot(rows, columns, i + 1)
            plt.imshow(img, cmap="gray")
            col2 = fig.add_subplot(rows, columns, i + 2)
            plt.imshow(gt, cmap="gray")
            col3 = fig.add_subplot(rows, columns, i + 3)
            plt.imshow(pred, cmap="gray")
            j += 1
            if i == 0:
                col1.title.set_text("Data")
                col2.title.set_text("Ground Truth")
                col3.title.set_text("Prediction")
        plt.suptitle(f"Predictions for {model_path.name}", fontsize=16)
        plt_out_pth = model_path.parent / f"{model_path.stem}_prediction_image.png"
        logging.info(f"Saving example image predictions to {plt_out_pth}")
        plt.savefig(plt_out_pth, dpi=300)
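
As a usage note, `train_model` drives the whole cycle above: it runs the learning-rate finder, trains with a one-cycle schedule and early stopping, and finally reloads the best checkpoint. One plausible two-stage pattern (an assumption about typical use, not prescribed by the source) is an initial run with the encoder partially frozen followed by fine-tuning from the saved weights, after which the figure helpers write diagnostics next to the weights file. This sketch assumes the `trainer` and `model_out_path` from the module-level example above.

# Initial run: build the model (freezing encoder conv parameters), find a
# learning rate, train with a one-cycle schedule and early stopping, then
# reload the best saved weights.
trainer.train_model(model_out_path, num_epochs=10, patience=3, create=True, frozen=True)

# Optional fine-tuning round: create=False reloads the saved checkpoint and
# lowers the LR-finder search range by lr_reduce_factor before training again.
trainer.train_model(model_out_path, num_epochs=8, patience=3, create=False, frozen=False)

# Diagnostics saved alongside the weights: loss curves plus a *_train_stats.csv,
# and a grid of example validation predictions.
trainer.output_loss_fig(model_out_path)
trainer.output_prediction_figure(model_out_path)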

class VolSeg2dPredictor:
    """Class that performs U-Net prediction operations. Does not interact with disk."""

    def __init__(self, model_file_path: str, settings: SimpleNamespace) -> None:
        self.model_file_path = Path(model_file_path)
        self.settings = settings
        self.model_device_num = int(settings.cuda_device)
        model_tuple = create_model_from_file(
            self.model_file_path, self.model_device_num
        )
        self.model, self.num_labels, self.label_codes = model_tuple

    def get_model_from_trainer(self, trainer):
        self.model = trainer.model

    def predict_single_axis(self, data_vol, output_probs=False, axis=Axis.Z):
        output_vol_list = []
        output_prob_list = []
        data_vol = utils.rotate_array_to_axis(data_vol, axis)
        yx_dims = list(data_vol.shape[1:])
        s_max = nn.Softmax(dim=1)
        data_loader = get_2d_prediction_dataloader(data_vol, self.settings)
        self.model.eval()
        logging.info(f"Predicting segmentation for volume of shape {data_vol.shape}.")
        with torch.no_grad():
            for batch in tqdm(
                data_loader, desc="Prediction batch", bar_format=cfg.TQDM_BAR_FORMAT
            ):
                output = self.model(batch.to(self.model_device_num))  # Forward pass
                probs = s_max(output)  # Convert the logits to probs
                # TODO: Don't flatten channels if one-hot output is needed
                labels = torch.argmax(probs, dim=1)  # flatten channels
                labels = utils.crop_tensor_to_array(labels, yx_dims)
                output_vol_list.append(labels.astype(np.uint8))
                if output_probs:
                    # Get indices of max probs
                    max_prob_idx = torch.argmax(probs, dim=1, keepdim=True)
                    # Extract along axis from outputs
                    probs = torch.gather(probs, 1, max_prob_idx)
                    # Remove the label dimension
                    probs = torch.squeeze(probs, dim=1)
                    probs = utils.crop_tensor_to_array(probs, yx_dims)
                    output_prob_list.append(probs.astype(np.float16))

        labels = np.concatenate(output_vol_list)
        labels = utils.rotate_array_to_axis(labels, axis)
        probs = np.concatenate(output_prob_list) if output_prob_list else None
        if probs is not None:
            probs = utils.rotate_array_to_axis(probs, axis)
        return labels, probs

    def predict_3_ways_max_probs(self, data_vol):
        shape_tup = data_vol.shape
        logging.info("Creating empty data volumes in RAM to combine 3 axis prediction.")
        label_container = np.empty((2, *shape_tup), dtype=np.uint8)
        prob_container = np.empty((2, *shape_tup), dtype=np.float16)
        logging.info("Predicting YX slices:")
        label_container[0], prob_container[0] = self.predict_single_axis(
            data_vol, output_probs=True
        )
        logging.info("Predicting ZX slices:")
        label_container[1], prob_container[1] = self.predict_single_axis(
            data_vol, output_probs=True, axis=Axis.Y
        )
        logging.info("Merging XY and ZX volumes.")
        self.merge_vols_in_mem(prob_container, label_container)
        logging.info("Predicting ZY slices:")
        label_container[1], prob_container[1] = self.predict_single_axis(
            data_vol, output_probs=True, axis=Axis.X
        )
        logging.info("Merging max of XY and ZX volumes with ZY volume.")
        self.merge_vols_in_mem(prob_container, label_container)
        return label_container[0], prob_container[0]

    def merge_vols_in_mem(self, prob_container, label_container):
        max_prob_idx = np.argmax(prob_container, axis=0)
        max_prob_idx = max_prob_idx[np.newaxis, :, :, :]
        prob_container[0] = np.squeeze(
            np.take_along_axis(prob_container, max_prob_idx, axis=0)
        )
        label_container[0] = np.squeeze(
            np.take_along_axis(label_container, max_prob_idx, axis=0)
        )

    def predict_12_ways_max_probs(self, data_vol):
        shape_tup = data_vol.shape
        logging.info("Creating empty data volumes in RAM to combine 12 way prediction.")
        label_container = np.empty((2, *shape_tup), dtype=np.uint8)
        prob_container = np.empty((2, *shape_tup), dtype=np.float16)
        label_container[0], prob_container[0] = self.predict_3_ways_max_probs(data_vol)
        for k in range(1, 4):
            logging.info(f"Rotating volume {k * 90} degrees")
            data_vol = np.rot90(data_vol)
            labels, probs = self.predict_3_ways_max_probs(data_vol)
            label_container[1] = np.rot90(labels, -k)
            prob_container[1] = np.rot90(probs, -k)
            logging.info(
                f"Merging rot {k * 90} deg volume with rot {(k-1) * 90} deg volume."
            )
            self.merge_vols_in_mem(prob_container, label_container)
        return label_container[0], prob_container[0]

    def predict_single_axis_to_one_hot(self, data_vol, axis=Axis.Z):
        prediction, _ = self.predict_single_axis(data_vol, axis=axis)
        return utils.one_hot_encode_array(prediction, self.num_labels)

    def predict_3_ways_one_hot(self, data_vol):
        one_hot_out = self.predict_single_axis_to_one_hot(data_vol)
        one_hot_out += self.predict_single_axis_to_one_hot(data_vol, Axis.Y)
        one_hot_out += self.predict_single_axis_to_one_hot(data_vol, Axis.X)
        return one_hot_out

    def predict_12_ways_one_hot(self, data_vol):
        one_hot_out = self.predict_3_ways_one_hot(data_vol)
        for k in range(1, 4):
            logging.info(f"Rotating volume {k * 90} degrees")
            data_vol = np.rot90(data_vol)
            one_hot_out += np.rot90(
                self.predict_3_ways_one_hot(data_vol), -k, axes=(-3, -2)
            )
        return one_hot_out
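
A sketch of driving the predictor directly, assuming the `predictor` from the module-level example and a 3d NumPy volume. The single-axis call slices the volume along one axis; the 3-way and 12-way variants repeat the prediction along all three orthogonal axes (and, for 12-way, three additional 90-degree rotations of the volume), keeping the label with the highest class probability at each voxel, while the one-hot variants sum per-class votes instead.

import numpy as np

# Placeholder volume; in practice this is loaded from file elsewhere.
data_vol = np.random.rand(64, 128, 128).astype(np.float32)

# Single-axis prediction (slices along the default Z axis). probs is None
# unless output_probs=True is passed.
labels, probs = predictor.predict_single_axis(data_vol, output_probs=True)

# Multi-axis variants, merged by maximum class probability per voxel.
labels_3, probs_3 = predictor.predict_3_ways_max_probs(data_vol)
labels_12, probs_12 = predictor.predict_12_ways_max_probs(data_vol)

# One-hot variants return summed per-class votes rather than merged labels.
votes = predictor.predict_3_ways_one_hot(data_vol)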

class VolSeg2DPredictionManager(BaseDataManager):
    def __init__(
        self,
        predictor: VolSeg2dPredictor,
        data_vol: Union[str, np.ndarray],
        settings: SimpleNamespace,
    ) -> None:
        super().__init__(data_vol, settings)
        self.predictor = predictor
        self.settings = settings

    def predict_volume_to_path(
        self, output_path: Union[Path, None], quality: Union[utils.Quality, None] = None
    ) -> np.ndarray:
        probs = None
        one_hot = self.settings.one_hot
        if quality is None:
            quality = utils.get_prediction_quality(self.settings)
        if quality == utils.Quality.LOW:
            if one_hot:
                prediction = self.predictor.predict_single_axis_to_one_hot(
                    self.data_vol
                )
            else:
                prediction, probs = self.predictor.predict_single_axis(self.data_vol)
        if quality == utils.Quality.MEDIUM:
            if one_hot:
                prediction = self.predictor.predict_3_ways_one_hot(self.data_vol)
            else:
                prediction, probs = self.predictor.predict_3_ways_max_probs(
                    self.data_vol
                )
        if quality == utils.Quality.HIGH:
            if one_hot:
                prediction = self.predictor.predict_12_ways_one_hot(self.data_vol)
            else:
                prediction, probs = self.predictor.predict_12_ways_max_probs(
                    self.data_vol
                )
        if output_path is not None:
            utils.save_data_to_hdf5(
                prediction, output_path, chunking=self.input_data_chunking
            )
            if probs is not None and self.settings.output_probs:
                utils.save_data_to_hdf5(
                    probs,
                    f"{output_path.parent / output_path.stem}_probs.h5",
                    chunking=self.input_data_chunking,
                )
        return prediction
Inherited Members: volume_segmantics.data.base_data_manager.BaseDataManager.preprocess_data
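
A sketch of the prediction manager in use, assuming the `pred_manager` from the module-level example. The `utils` alias in the source above refers to `volume_segmantics.utilities.base_data_utils`, whose `Quality` enum can be passed to override the quality setting per call; with `output_path=None` the prediction is returned without being written to HDF5.

from pathlib import Path

import volume_segmantics.utilities.base_data_utils as utils

# Segment the managed volume at high quality (12-way prediction) and write the
# labels (and, if settings.output_probs is set, the probabilities) to HDF5.
segmentation = pred_manager.predict_volume_to_path(
    Path("output/segmentation_high.h5"), quality=utils.Quality.HIGH
)

# Quick single-axis prediction, returned in memory only.
quick_segmentation = pred_manager.predict_volume_to_path(None, quality=utils.Quality.LOW)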