volume_segmantics.model
```python
from volume_segmantics.model.operations.vol_seg_2d_trainer import \
    VolSeg2dTrainer
from volume_segmantics.model.operations.vol_seg_2d_predictor import \
    VolSeg2dPredictor
from volume_segmantics.model.operations.vol_seg_prediction_manager import \
    VolSeg2DPredictionManager

__all__ = ["VolSeg2dTrainer", "VolSeg2dPredictor", "VolSeg2DPredictionManager"]
```
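Because the sub-package `__init__` re-exports these classes, callers can import them directly from `volume_segmantics.model` without referencing the `operations` modules:

```python
from volume_segmantics.model import (
    VolSeg2dTrainer,
    VolSeg2dPredictor,
    VolSeg2DPredictionManager,
)
```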
```python
class VolSeg2dTrainer:
    """Class that utilises 2d dataloaders to train a 2d deep learning model.

    Args:
        image_dir_path: Path to the directory containing the 2d training images.
        label_dir_path: Path to the directory containing the corresponding 2d labels.
        labels (Union[int, dict]): Number of segmentation labels, or a dictionary of label codes.
        settings: Object containing the training settings.
    """

    def __init__(
        self, image_dir_path, label_dir_path, labels: Union[int, dict], settings
    ):
        self.training_loader, self.validation_loader = get_2d_training_dataloaders(
            image_dir_path, label_dir_path, settings
        )
        self.label_no = labels if isinstance(labels, int) else len(labels)
        self.codes = labels if isinstance(labels, dict) else {}
        self.settings = settings
        # Params for learning rate finder
        self.starting_lr = float(settings.starting_lr)
        self.end_lr = float(settings.end_lr)
        self.log_lr_ratio = self.calculate_log_lr_ratio()
        self.lr_find_epochs = settings.lr_find_epochs
        self.lr_reduce_factor = settings.lr_reduce_factor
        # Params for model training
        self.model_device_num = int(settings.cuda_device)
        self.patience = settings.patience
        self.loss_criterion = self.get_loss_criterion()
        self.eval_metric = self.get_eval_metric()
        self.model_struc_dict = self.get_model_struc_dict(settings)
        self.avg_train_losses = []  # per epoch training loss
        self.avg_valid_losses = []  # per epoch validation loss
        self.avg_eval_scores = []  # per epoch evaluation score

    def get_model_struc_dict(self, settings):
        model_struc_dict = settings.model
        model_type = utils.get_model_type(settings)
        model_struc_dict["type"] = model_type
        model_struc_dict["in_channels"] = cfg.MODEL_INPUT_CHANNELS
        model_struc_dict["classes"] = self.label_no
        return model_struc_dict

    def calculate_log_lr_ratio(self):
        return math.log(self.end_lr / self.starting_lr)

    def create_model_and_optimiser(self, learning_rate, frozen=False):
        logging.info(f"Setting up the model on device {self.settings.cuda_device}.")
        self.model = create_model_on_device(
            self.model_device_num, self.model_struc_dict
        )
        if frozen:
            self.freeze_model()
        logging.info(
            f"Model has {self.count_trainable_parameters()} trainable parameters, {self.count_parameters()} total parameters."
        )
        self.optimizer = self.create_optimizer(learning_rate)
        logging.info("Trainer created.")

    def freeze_model(self):
        logging.info(
            f"Freezing model with {self.count_trainable_parameters()} trainable parameters, {self.count_parameters()} total parameters."
        )
        for name, param in self.model.named_parameters():
            if all(["encoder" in name, "conv" in name]) and param.requires_grad:
                param.requires_grad = False

    def unfreeze_model(self):
        logging.info(
            f"Unfreezing model with {self.count_trainable_parameters()} trainable parameters, {self.count_parameters()} total parameters."
        )
        for name, param in self.model.named_parameters():
            if all(["encoder" in name, "conv" in name]) and not param.requires_grad:
                param.requires_grad = True

    def count_trainable_parameters(self) -> int:
        return sum(p.numel() for p in self.model.parameters() if p.requires_grad)

    def count_parameters(self) -> int:
        return sum(p.numel() for p in self.model.parameters())

    def get_loss_criterion(self):
        if self.settings.loss_criterion == "BCEDiceLoss":
            alpha = self.settings.alpha
            beta = self.settings.beta
            logging.info(
                f"Using combined BCE and Dice loss with weighting of {alpha}*BCE "
                f"and {beta}*Dice"
            )
            loss_criterion = BCEDiceLoss(alpha, beta)
        elif self.settings.loss_criterion == "DiceLoss":
            logging.info("Using DiceLoss")
            loss_criterion = DiceLoss(normalization="none")
        elif self.settings.loss_criterion == "BCELoss":
            logging.info("Using BCELoss")
            loss_criterion = nn.BCEWithLogitsLoss()
        elif self.settings.loss_criterion == "CrossEntropyLoss":
            logging.info("Using CrossEntropyLoss")
            loss_criterion = nn.CrossEntropyLoss()
        elif self.settings.loss_criterion == "GeneralizedDiceLoss":
            logging.info("Using GeneralizedDiceLoss")
            loss_criterion = GeneralizedDiceLoss()
        else:
            logging.error("No loss criterion specified, exiting")
            sys.exit(1)
        return loss_criterion

    def get_eval_metric(self):
        # Get evaluation metric
        if self.settings.eval_metric == "MeanIoU":
            logging.info("Using MeanIoU")
            eval_metric = MeanIoU()
        elif self.settings.eval_metric == "GenericAveragePrecision":
            logging.info("Using GenericAveragePrecision")
            eval_metric = GenericAveragePrecision()
        else:
            logging.error("No evaluation metric specified, exiting")
            sys.exit(1)
        return eval_metric

    def train_model(self, output_path, num_epochs, patience, create=True, frozen=False):
        """Performs training of the model for a number of cycles
        with a learning rate that is determined automatically.
        """
        train_losses = []
        valid_losses = []
        eval_scores = []

        if create:
            self.create_model_and_optimiser(self.starting_lr, frozen=frozen)
            lr_to_use = self.run_lr_finder()
            # Recreate model and start training
            self.create_model_and_optimiser(lr_to_use, frozen=frozen)
            early_stopping = self.create_early_stopping(output_path, patience)
        else:
            # Reduce starting LR, since model already partially trained
            self.starting_lr /= self.lr_reduce_factor
            self.end_lr /= self.lr_reduce_factor
            self.log_lr_ratio = self.calculate_log_lr_ratio()
            self.load_in_model_and_optimizer(
                self.starting_lr, output_path, frozen=frozen, optimizer=False
            )
            lr_to_use = self.run_lr_finder()
            min_loss = self.load_in_model_and_optimizer(
                self.starting_lr, output_path, frozen=frozen, optimizer=False
            )
            early_stopping = self.create_early_stopping(
                output_path, patience, best_score=-min_loss
            )

        # Initialise the One Cycle learning rate scheduler
        lr_scheduler = self.create_oc_lr_scheduler(num_epochs, lr_to_use)

        for epoch in range(1, num_epochs + 1):
            self.model.train()
            tic = time.perf_counter()
            logging.info(f"Epoch {epoch} of {num_epochs}")
            for batch in tqdm(
                self.training_loader,
                desc="Training batch",
                bar_format=cfg.TQDM_BAR_FORMAT,
            ):
                loss = self.train_one_batch(lr_scheduler, batch)
                train_losses.append(loss.item())  # record training loss

            self.model.eval()  # prep model for evaluation
            with torch.no_grad():
                for batch in tqdm(
                    self.validation_loader,
                    desc="Validation batch",
                    bar_format=cfg.TQDM_BAR_FORMAT,
                ):
                    inputs, targets = utils.prepare_training_batch(
                        batch, self.model_device_num, self.label_no
                    )
                    output = self.model(inputs)  # Forward pass
                    # calculate the loss
                    if self.settings.loss_criterion == "CrossEntropyLoss":
                        loss = self.loss_criterion(output, torch.argmax(targets, dim=1))
                    else:
                        loss = self.loss_criterion(output, targets.float())
                    valid_losses.append(loss.item())  # record validation loss
                    s_max = nn.Softmax(dim=1)
                    probs = s_max(output)  # Convert the logits to probs
                    probs = torch.unsqueeze(probs, 2)
                    targets = torch.unsqueeze(targets, 2)
                    eval_score = self.eval_metric(probs, targets)
                    eval_scores.append(eval_score)  # record eval metric

            toc = time.perf_counter()
            # calculate average loss/metric over an epoch
            self.avg_train_losses.append(np.average(train_losses))
            self.avg_valid_losses.append(np.average(valid_losses))
            self.avg_eval_scores.append(np.average(eval_scores))
            logging.info(
                f"Epoch {epoch}. Training loss: {self.avg_train_losses[-1]}, Validation Loss: "
                f"{self.avg_valid_losses[-1]}. {self.settings.eval_metric}: {self.avg_eval_scores[-1]}"
            )
            logging.info(f"Time taken for epoch {epoch}: {toc - tic:0.2f} seconds")
            # clear lists to track next epoch
            train_losses = []
            valid_losses = []
            eval_scores = []

            # early_stopping needs the validation loss to check if it has decreased,
            # and if it has, it will make a checkpoint of the current model
            early_stopping(
                self.avg_valid_losses[-1], self.model, self.optimizer, self.codes
            )

            if early_stopping.early_stop:
                logging.info("Early stopping")
                break

        # load the last checkpoint with the best model
        self.load_in_weights(output_path)

    def load_in_model_and_optimizer(
        self, learning_rate, output_path, frozen=False, optimizer=False
    ):
        self.create_model_and_optimiser(learning_rate, frozen=frozen)
        logging.info("Loading in weights from saved checkpoint.")
        loss_val = self.load_in_weights(output_path, optimizer=optimizer)
        return loss_val

    def load_in_weights(self, output_path, optimizer=False, gpu=True):
        # load the last checkpoint with the best model
        if gpu:
            map_location = f"cuda:{self.model_device_num}"
        else:
            map_location = "cpu"
        model_dict = torch.load(output_path, map_location=map_location)
        logging.info("Loading model weights.")
        self.model.load_state_dict(model_dict["model_state_dict"])
        if optimizer:
            logging.info("Loading optimizer weights.")
            self.optimizer.load_state_dict(model_dict["optimizer_state_dict"])
        return model_dict.get("loss_val", np.inf)

    def run_lr_finder(self):
        logging.info("Finding learning rate for model.")
        lr_scheduler = self.create_exponential_lr_scheduler()
        lr_find_loss, lr_find_lr = self.lr_finder(lr_scheduler)
        lr_to_use = self.find_lr_from_graph(lr_find_loss, lr_find_lr)
        logging.info(f"LR to use {lr_to_use}")
        return lr_to_use

    def lr_finder(self, lr_scheduler, smoothing=0.05, plt_fig=True):
        lr_find_loss = []
        lr_find_lr = []
        iters = 0

        self.model.train()
        logging.info(
            f"Training for {self.lr_find_epochs} epochs to create a learning "
            "rate plot."
        )
        for i in range(self.lr_find_epochs):
            for batch in tqdm(
                self.training_loader,
                desc=f"Epoch {i + 1}, batch number",
                bar_format=cfg.TQDM_BAR_FORMAT,
            ):
                loss = self.train_one_batch(lr_scheduler, batch)
                lr_step = self.optimizer.state_dict()["param_groups"][0]["lr"]
                lr_find_lr.append(lr_step)
                if iters == 0:
                    lr_find_loss.append(loss)
                else:
                    loss = smoothing * loss + (1 - smoothing) * lr_find_loss[-1]
                    lr_find_loss.append(loss)
                if loss > 1 and iters > len(self.training_loader) // 1.333:
                    break
                iters += 1

        if plt_fig:
            fig = tpl.figure()
            fig.plot(
                np.log10(lr_find_lr),
                lr_find_loss,
                width=50,
                height=30,
                xlabel="Log10 Learning Rate",
            )
            fig.show()

        return lr_find_loss, lr_find_lr

    @staticmethod
    def find_lr_from_graph(
        lr_find_loss: torch.Tensor, lr_find_lr: torch.Tensor
    ) -> float:
        """Calculates learning rate corresponding to minimum gradient in graph
        of loss vs learning rate.

        Args:
            lr_find_loss (torch.Tensor): Loss values accumulated during training
            lr_find_lr (torch.Tensor): Learning rate used for each mini-batch

        Returns:
            float: The learning rate at the point when loss was falling most steeply
            divided by a fudge factor.
        """
        default_min_lr = cfg.DEFAULT_MIN_LR  # Add as default value to fix bug
        # Get loss values and their corresponding gradients, and get lr values
        for i in range(0, len(lr_find_loss)):
            if lr_find_loss[i].is_cuda:
                lr_find_loss[i] = lr_find_loss[i].cpu()
            lr_find_loss[i] = lr_find_loss[i].detach().numpy()
        losses = np.array(lr_find_loss)
        try:
            gradients = np.gradient(losses)
            min_gradient = gradients.min()
            if min_gradient < 0:
                min_loss_grad_idx = gradients.argmin()
            else:
                logging.info(
                    f"Minimum gradient: {min_gradient} was positive, returning default value instead."
                )
                return default_min_lr
        except Exception as e:
            logging.info(f"Failed to compute gradients, returning default value. {e}")
            return default_min_lr
        min_lr = lr_find_lr[min_loss_grad_idx]
        return min_lr / cfg.LR_DIVISOR

    def lr_exp_stepper(self, x):
        """Exponentially increase learning rate as part of a strategy to find the
        optimum.
        Taken from
        https://towardsdatascience.com/adaptive-and-cyclical-learning-rates-using-pytorch-2bf904d18dee
        """
        return math.exp(
            x * self.log_lr_ratio / (self.lr_find_epochs * len(self.training_loader))
        )

    def create_optimizer(self, learning_rate):
        return torch.optim.AdamW(self.model.parameters(), lr=learning_rate)

    def create_exponential_lr_scheduler(self):
        return torch.optim.lr_scheduler.LambdaLR(self.optimizer, self.lr_exp_stepper)

    def create_oc_lr_scheduler(self, num_epochs, lr_to_use):
        return torch.optim.lr_scheduler.OneCycleLR(
            self.optimizer,
            max_lr=lr_to_use,
            steps_per_epoch=len(self.training_loader),
            epochs=num_epochs,
            pct_start=self.settings.pct_lr_inc,
        )

    def create_early_stopping(self, output_path, patience, best_score=None):
        return EarlyStopping(
            patience=patience,
            verbose=True,
            path=output_path,
            model_dict=self.model_struc_dict,
            best_score=best_score,
        )

    def train_one_batch(self, lr_scheduler, batch):
        inputs, targets = utils.prepare_training_batch(
            batch, self.model_device_num, self.label_no
        )
        self.optimizer.zero_grad()
        output = self.model(inputs)  # Forward pass
        if self.settings.loss_criterion == "CrossEntropyLoss":
            loss = self.loss_criterion(output, torch.argmax(targets, dim=1))
        else:
            loss = self.loss_criterion(output, targets.float())
        loss.backward()  # Backward pass
        self.optimizer.step()
        lr_scheduler.step()  # update the learning rate
        return loss

    def output_loss_fig(self, model_out_path):
        fig = plt.figure(figsize=(10, 8))
        plt.plot(
            range(1, len(self.avg_train_losses) + 1),
            self.avg_train_losses,
            label="Training Loss",
        )
        plt.plot(
            range(1, len(self.avg_valid_losses) + 1),
            self.avg_valid_losses,
            label="Validation Loss",
        )

        minposs = (
            self.avg_valid_losses.index(min(self.avg_valid_losses)) + 1
        )  # find position of lowest validation loss
        plt.axvline(
            minposs, linestyle="--", color="r", label="Early Stopping Checkpoint"
        )

        plt.xlabel("epochs")
        plt.ylabel("loss")
        plt.xlim(0, len(self.avg_train_losses) + 1)  # consistent scale
        plt.grid(True)
        plt.legend()
        plt.tight_layout()
        output_dir = model_out_path.parent
        fig_out_pth = output_dir / f"{model_out_path.stem}_loss_plot.png"
        logging.info(f"Saving figure of training/validation losses to {fig_out_pth}")
        fig.savefig(fig_out_pth, bbox_inches="tight")
        # Output a list of training stats
        epoch_lst = range(len(self.avg_train_losses))
        rows = zip(
            epoch_lst,
            self.avg_train_losses,
            self.avg_valid_losses,
            self.avg_eval_scores,
        )
        with open(output_dir / f"{model_out_path.stem}_train_stats.csv", "w") as f:
            writer = csv.writer(f)
            writer.writerow(("Epoch", "Train Loss", "Valid Loss", "Eval Score"))
            for row in rows:
                writer.writerow(row)

    def output_prediction_figure(self, model_path):
        """Saves a figure containing image slice data for up to four images from
        the first batch of the validation dataset along with the corresponding
        ground truth label images and the prediction output from the model attached
        to this class instance. The image is saved to the same directory as the
        model weights.

        Args:
            model_path (pathlib.Path): Full path to the model weights file,
                this is used to get the directory and name of the model, not to
                load it for prediction.
        """
        self.model.eval()  # prep model for evaluation
        batch = next(iter(self.validation_loader))  # Get first batch
        with torch.no_grad():
            inputs, targets = utils.prepare_training_batch(
                batch, self.model_device_num, self.label_no
            )
            output = self.model(inputs)  # Forward pass
            s_max = nn.Softmax(dim=1)
            probs = s_max(output)  # Convert the logits to probs
            labels = torch.argmax(probs, dim=1)  # flatten channels

        # Create the plot
        bs = self.validation_loader.batch_size
        if bs < 4:
            rows = bs
        else:
            rows = 4
        fig = plt.figure(figsize=(12, 16))
        columns = 3
        j = 0
        for i in range(columns * rows)[::3]:
            img = inputs[j].squeeze().cpu()
            gt = torch.argmax(targets[j], dim=0).cpu()
            pred = labels[j].cpu()
            col1 = fig.add_subplot(rows, columns, i + 1)
            plt.imshow(img, cmap="gray")
            col2 = fig.add_subplot(rows, columns, i + 2)
            plt.imshow(gt, cmap="gray")
            col3 = fig.add_subplot(rows, columns, i + 3)
            plt.imshow(pred, cmap="gray")
            j += 1
            if i == 0:
                col1.title.set_text("Data")
                col2.title.set_text("Ground Truth")
                col3.title.set_text("Prediction")
        plt.suptitle(f"Predictions for {model_path.name}", fontsize=16)
        plt_out_pth = model_path.parent / f"{model_path.stem}_prediction_image.png"
        logging.info(f"Saving example image predictions to {plt_out_pth}")
        plt.savefig(plt_out_pth, dpi=300)
```
Class that utilises 2d dataloaders to train a 2d deep learning model.

Args

- image_dir_path: Path to the directory containing the 2d training images.
- label_dir_path: Path to the directory containing the corresponding 2d labels.
- labels (Union[int, dict]): Number of segmentation labels, or a dictionary of label codes.
- settings: Object containing the training settings.
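The snippet below is a minimal usage sketch, not a recipe taken verbatim from the package: the attribute names on `settings` come from what the constructor and `train_model` read above, but in practice the settings object is normally built from the package's settings file and contains further dataloader and augmentation fields that are omitted here; the `model` dictionary keys, file paths and all numeric values are illustrative assumptions.

```python
from pathlib import Path
from types import SimpleNamespace

from volume_segmantics.model import VolSeg2dTrainer

# Illustrative values only; the real settings object carries additional fields
# used by get_2d_training_dataloaders that are not shown in this sketch.
settings = SimpleNamespace(
    starting_lr=1e-6,
    end_lr=50,
    lr_find_epochs=1,
    lr_reduce_factor=500,
    cuda_device=0,
    patience=3,
    loss_criterion="BCEDiceLoss",
    alpha=0.75,
    beta=0.25,
    eval_metric="MeanIoU",
    pct_lr_inc=0.3,
    model={"type": "U_Net"},  # hypothetical keys; the real schema comes from the settings file
)

model_path = Path("trained_2d_model.pytorch")
trainer = VolSeg2dTrainer("data/train/images", "data/train/labels", labels=2, settings=settings)
trainer.train_model(model_path, num_epochs=10, patience=settings.patience)
trainer.output_loss_fig(model_path)  # writes a loss plot and a CSV of training stats
```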
def train_model(self, output_path, num_epochs, patience, create=True, frozen=False):
Performs training of the model for a number of cycles with a learning rate that is determined automatically.
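Continuing the sketch above: passing `create=False` makes `train_model` reload the checkpoint saved at `output_path`, divide the learning-rate search range by `lr_reduce_factor`, and resume training. One two-stage pattern this supports (shown as an illustration, not a workflow prescribed by the package) is to train first with the encoder convolutions frozen and then fine-tune the whole network:

```python
# Stage 1: build a fresh model with the encoder conv layers frozen (see freeze_model).
trainer.train_model(model_path, num_epochs=5, patience=3, create=True, frozen=True)

# Stage 2: reload the best checkpoint from model_path with all layers trainable and
# continue training over a reduced learning-rate search range.
trainer.train_model(model_path, num_epochs=10, patience=3, create=False, frozen=False)
```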
@staticmethod
def find_lr_from_graph(lr_find_loss: torch.Tensor, lr_find_lr: torch.Tensor) -> float:
Calculates learning rate corresponding to minimum gradient in graph of loss vs learning rate.
Args
- lr_find_loss (torch.Tensor): Loss values accumulated during training
- lr_find_lr (torch.Tensor): Learning rate used for each mini-batch
Returns
float: The learning rate at the point when loss was falling most steeply divided by a fudge factor.
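A small self-contained sketch of the selection rule, using synthetic loss values and a stand-in for `cfg.LR_DIVISOR` (the real constant lives in the package configuration): the gradient of the loss trace is computed with `np.gradient`, the most negative gradient marks the steepest descent, and the learning rate at that step is divided by the divisor.

```python
import numpy as np

# Synthetic LR-finder trace: loss falls, bottoms out, then diverges.
lr_find_lr = np.logspace(-6, 0, 13)  # learning rates tried at each step
lr_find_loss = np.array(
    [1.0, 0.98, 0.93, 0.85, 0.70, 0.50, 0.32, 0.25, 0.28, 0.45, 0.90, 1.8, 4.0]
)

gradients = np.gradient(lr_find_loss)  # slope of the loss w.r.t. step index
steepest = gradients.argmin()          # most negative slope = fastest descent
LR_DIVISOR = 3                         # assumption: stand-in for cfg.LR_DIVISOR
suggested_lr = lr_find_lr[steepest] / LR_DIVISOR
print(f"Steepest descent at step {steepest}, suggested LR ~ {suggested_lr:.2e}")
```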
def lr_exp_stepper(self, x):
Exponentially increase learning rate as part of a strategy to find the optimum. Taken from https://towardsdatascience.com/adaptive-and-cyclical-learning-rates-using-pytorch-2bf904d18dee
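Because `LambdaLR` multiplies the optimiser's base learning rate (`starting_lr`) by the returned factor, this stepper sweeps the learning rate exponentially from `starting_lr` to `end_lr` over `lr_find_epochs * len(training_loader)` steps. A standalone sketch with assumed values for the range and step count:

```python
import math

starting_lr, end_lr = 1e-6, 50.0  # illustrative search range
total_steps = 1 * 400             # lr_find_epochs * len(training_loader), assumed
log_lr_ratio = math.log(end_lr / starting_lr)

def lr_exp_stepper(x):
    # Multiplier applied to the base LR (starting_lr) by LambdaLR at step x.
    return math.exp(x * log_lr_ratio / total_steps)

print(starting_lr * lr_exp_stepper(0))            # ~1e-6 at the first step
print(starting_lr * lr_exp_stepper(total_steps))  # ~50.0 at the final step
```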
def output_prediction_figure(self, model_path):
Saves a figure containing image slice data for up to four images from the first batch of the validation dataset, along with the corresponding ground truth label images and the prediction output from the model attached to this class instance. The image is saved to the same directory as the model weights.

Args

- model_path (pathlib.Path): Full path to the model weights file; this is used to get the directory and name of the model, not to load it for prediction.
```python
class VolSeg2dPredictor:
    """Class that performs U-Net prediction operations. Does not interact with disk."""

    def __init__(self, model_file_path: str, settings: SimpleNamespace) -> None:
        self.model_file_path = Path(model_file_path)
        self.settings = settings
        self.model_device_num = int(settings.cuda_device)
        model_tuple = create_model_from_file(
            self.model_file_path, self.model_device_num
        )
        self.model, self.num_labels, self.label_codes = model_tuple

    def get_model_from_trainer(self, trainer):
        self.model = trainer.model

    def predict_single_axis(self, data_vol, output_probs=False, axis=Axis.Z):
        output_vol_list = []
        output_prob_list = []
        data_vol = utils.rotate_array_to_axis(data_vol, axis)
        yx_dims = list(data_vol.shape[1:])
        s_max = nn.Softmax(dim=1)
        data_loader = get_2d_prediction_dataloader(data_vol, self.settings)
        self.model.eval()
        logging.info(f"Predicting segmentation for volume of shape {data_vol.shape}.")
        with torch.no_grad():
            for batch in tqdm(
                data_loader, desc="Prediction batch", bar_format=cfg.TQDM_BAR_FORMAT
            ):
                output = self.model(batch.to(self.model_device_num))  # Forward pass
                probs = s_max(output)  # Convert the logits to probs
                # TODO: Don't flatten channels if one-hot output is needed
                labels = torch.argmax(probs, dim=1)  # flatten channels
                labels = utils.crop_tensor_to_array(labels, yx_dims)
                output_vol_list.append(labels.astype(np.uint8))
                if output_probs:
                    # Get indices of max probs
                    max_prob_idx = torch.argmax(probs, dim=1, keepdim=True)
                    # Extract along axis from outputs
                    probs = torch.gather(probs, 1, max_prob_idx)
                    # Remove the label dimension
                    probs = torch.squeeze(probs, dim=1)
                    probs = utils.crop_tensor_to_array(probs, yx_dims)
                    output_prob_list.append(probs.astype(np.float16))

        labels = np.concatenate(output_vol_list)
        labels = utils.rotate_array_to_axis(labels, axis)
        probs = np.concatenate(output_prob_list) if output_prob_list else None
        if probs is not None:
            probs = utils.rotate_array_to_axis(probs, axis)
        return labels, probs

    def predict_3_ways_max_probs(self, data_vol):
        shape_tup = data_vol.shape
        logging.info("Creating empty data volumes in RAM to combine 3 axis prediction.")
        label_container = np.empty((2, *shape_tup), dtype=np.uint8)
        prob_container = np.empty((2, *shape_tup), dtype=np.float16)
        logging.info("Predicting YX slices:")
        label_container[0], prob_container[0] = self.predict_single_axis(
            data_vol, output_probs=True
        )
        logging.info("Predicting ZX slices:")
        label_container[1], prob_container[1] = self.predict_single_axis(
            data_vol, output_probs=True, axis=Axis.Y
        )
        logging.info("Merging XY and ZX volumes.")
        self.merge_vols_in_mem(prob_container, label_container)
        logging.info("Predicting ZY slices:")
        label_container[1], prob_container[1] = self.predict_single_axis(
            data_vol, output_probs=True, axis=Axis.X
        )
        logging.info("Merging max of XY and ZX volumes with ZY volume.")
        self.merge_vols_in_mem(prob_container, label_container)
        return label_container[0], prob_container[0]

    def merge_vols_in_mem(self, prob_container, label_container):
        max_prob_idx = np.argmax(prob_container, axis=0)
        max_prob_idx = max_prob_idx[np.newaxis, :, :, :]
        prob_container[0] = np.squeeze(
            np.take_along_axis(prob_container, max_prob_idx, axis=0)
        )
        label_container[0] = np.squeeze(
            np.take_along_axis(label_container, max_prob_idx, axis=0)
        )

    def predict_12_ways_max_probs(self, data_vol):
        shape_tup = data_vol.shape
        logging.info("Creating empty data volumes in RAM to combine 12 way prediction.")
        label_container = np.empty((2, *shape_tup), dtype=np.uint8)
        prob_container = np.empty((2, *shape_tup), dtype=np.float16)
        label_container[0], prob_container[0] = self.predict_3_ways_max_probs(data_vol)
        for k in range(1, 4):
            logging.info(f"Rotating volume {k * 90} degrees")
            data_vol = np.rot90(data_vol)
            labels, probs = self.predict_3_ways_max_probs(data_vol)
            label_container[1] = np.rot90(labels, -k)
            prob_container[1] = np.rot90(probs, -k)
            logging.info(
                f"Merging rot {k * 90} deg volume with rot {(k-1) * 90} deg volume."
            )
            self.merge_vols_in_mem(prob_container, label_container)
        return label_container[0], prob_container[0]

    def predict_single_axis_to_one_hot(self, data_vol, axis=Axis.Z):
        prediction, _ = self.predict_single_axis(data_vol, axis=axis)
        return utils.one_hot_encode_array(prediction, self.num_labels)

    def predict_3_ways_one_hot(self, data_vol):
        one_hot_out = self.predict_single_axis_to_one_hot(data_vol)
        one_hot_out += self.predict_single_axis_to_one_hot(data_vol, Axis.Y)
        one_hot_out += self.predict_single_axis_to_one_hot(data_vol, Axis.X)
        return one_hot_out

    def predict_12_ways_one_hot(self, data_vol):
        one_hot_out = self.predict_3_ways_one_hot(data_vol)
        for k in range(1, 4):
            logging.info(f"Rotating volume {k * 90} degrees")
            data_vol = np.rot90(data_vol)
            one_hot_out += np.rot90(
                self.predict_3_ways_one_hot(data_vol), -k, axes=(-3, -2)
            )
        return one_hot_out
```
Class that performs U-Net prediction operations. Does not interact with disk.
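A minimal sketch of loading a trained model and predicting on an in-memory volume. The settings object here is deliberately incomplete: `cuda_device` is read directly by the predictor, while the prediction dataloader reads further fields (batch size and related options) that come from the package's prediction settings and are omitted; the model file path and the random volume are illustrative.

```python
import numpy as np
from types import SimpleNamespace

from volume_segmantics.model import VolSeg2dPredictor

settings = SimpleNamespace(cuda_device=0)  # plus dataloader fields in real use

predictor = VolSeg2dPredictor("trained_2d_model.pytorch", settings)
data_vol = np.random.rand(64, 256, 256).astype(np.float32)  # illustrative (z, y, x) volume

labels, _ = predictor.predict_single_axis(data_vol)                 # YX slices only
labels3, probs3 = predictor.predict_3_ways_max_probs(data_vol)      # 3 axes, merged by max probability
```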
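To illustrate the merging rule used by `merge_vols_in_mem`, the standalone NumPy sketch below applies the same `argmax`/`take_along_axis` logic to a tiny 2d example (the real containers are 4d, with shape `(2, z, y, x)`): at every position, the label whose prediction carried the higher probability is kept.

```python
import numpy as np

# Two candidate predictions for the same 2x2 image: index 0 is the running
# "best so far", index 1 is the newest axis/rotation prediction.
prob_container = np.array(
    [[[0.9, 0.4],
      [0.6, 0.2]],
     [[0.3, 0.8],
      [0.5, 0.7]]], dtype=np.float16)
label_container = np.array(
    [[[1, 1],
      [2, 0]],
     [[0, 2],
      [1, 3]]], dtype=np.uint8)

# Select, per pixel, whichever candidate (axis 0) has the higher probability.
max_prob_idx = np.argmax(prob_container, axis=0)[np.newaxis]
merged_probs = np.take_along_axis(prob_container, max_prob_idx, axis=0).squeeze(0)
merged_labels = np.take_along_axis(label_container, max_prob_idx, axis=0).squeeze(0)
print(merged_labels)  # [[1 2]
                      #  [2 3]]
```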
```python
class VolSeg2DPredictionManager(BaseDataManager):
    def __init__(
        self,
        predictor: VolSeg2dPredictor,
        data_vol: Union[str, np.ndarray],
        settings: SimpleNamespace,
    ) -> None:
        super().__init__(data_vol, settings)
        self.predictor = predictor
        self.settings = settings

    def predict_volume_to_path(
        self, output_path: Union[Path, None], quality: Union[utils.Quality, None] = None
    ) -> np.ndarray:
        probs = None
        one_hot = self.settings.one_hot
        if quality is None:
            quality = utils.get_prediction_quality(self.settings)
        if quality == utils.Quality.LOW:
            if one_hot:
                prediction = self.predictor.predict_single_axis_to_one_hot(
                    self.data_vol
                )
            else:
                prediction, probs = self.predictor.predict_single_axis(self.data_vol)
        if quality == utils.Quality.MEDIUM:
            if one_hot:
                prediction = self.predictor.predict_3_ways_one_hot(self.data_vol)
            else:
                prediction, probs = self.predictor.predict_3_ways_max_probs(
                    self.data_vol
                )
        if quality == utils.Quality.HIGH:
            if one_hot:
                prediction = self.predictor.predict_12_ways_one_hot(self.data_vol)
            else:
                prediction, probs = self.predictor.predict_12_ways_max_probs(
                    self.data_vol
                )
        if output_path is not None:
            utils.save_data_to_hdf5(
                prediction, output_path, chunking=self.input_data_chunking
            )
            if probs is not None and self.settings.output_probs:
                utils.save_data_to_hdf5(
                    probs,
                    f"{output_path.parent / output_path.stem}_probs.h5",
                    chunking=self.input_data_chunking,
                )
        return prediction
```
Inherited Members

- volume_segmantics.data.base_data_manager.BaseDataManager
  - preprocess_data
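An end-to-end sketch combining the predictor and the prediction manager. The settings object is again incomplete and its field values are assumptions: `one_hot` and `output_probs` are read by `predict_volume_to_path` as shown above, while the field consumed by `utils.get_prediction_quality` (named `quality` here) and the file paths are hypothetical stand-ins; `BaseDataManager` also reads additional data-loading fields that are not shown.

```python
from pathlib import Path
from types import SimpleNamespace

from volume_segmantics.model import VolSeg2dPredictor, VolSeg2DPredictionManager

settings = SimpleNamespace(
    cuda_device=0,
    one_hot=False,
    output_probs=True,
    quality="high",  # assumption: consumed by utils.get_prediction_quality
)

predictor = VolSeg2dPredictor("trained_2d_model.pytorch", settings)
manager = VolSeg2DPredictionManager(predictor, "data/volume_to_segment.h5", settings)

# With HIGH quality this runs 12-way prediction and writes the segmentation to
# the given HDF5 path, plus a "*_probs.h5" volume of per-voxel max probabilities.
segmentation = manager.predict_volume_to_path(Path("output/segmentation.h5"))
```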