bead.src.trainers package

Submodules

bead.src.trainers.inference module

bead.src.trainers.inference.infer(events_bkg, jets_bkg, constituents_bkg, events_sig, jets_sig, constituents_sig, model_path, output_path, config, verbose: bool = False)[source]
Runs the entire training loop by calling fit() and validate(). Apart from this, this is the main function where the data is converted to the correct type for training, via torch.Tensor(). Batching is also done here, based on config.batch_size, with torch.utils.data.DataLoader doing the splitting. EarlyStopping or an LR scheduler is applied here as well, depending on the respective config arguments. For reproducibility, the seeds can also be fixed in this function.

Parameters:
  • model (modelObject) – The model you wish to train

  • data (Tuple) – Tuple containing the training and validation data

  • project_path (string) – Path to the project directory

  • config (dataClass) – Base class selecting user inputs

Returns:

fully trained model ready to perform compression and decompression

Return type:

modelObject
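
The conversion and batching steps described above can be sketched as follows. This is a minimal illustration, not the package's code: `SketchConfig`, the toy `events` array and its shape are assumptions; only `config.batch_size` and the use of `torch.utils.data.DataLoader` come from the description.

```python
from dataclasses import dataclass

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset


@dataclass
class SketchConfig:
    # Hypothetical stand-in for the real config object; only batch_size is used.
    batch_size: int = 4


# Toy stand-in for one of the input arrays (10 events, 3 features each).
events = np.random.default_rng(0).normal(size=(10, 3))

config = SketchConfig()

# Convert to a float32 tensor (the description names torch.Tensor(); torch.tensor
# with an explicit dtype is used here as an equivalent, assumed choice).
events_tensor = torch.tensor(events, dtype=torch.float32)

# The DataLoader splits the tensor into batches of config.batch_size.
loader = DataLoader(TensorDataset(events_tensor), batch_size=config.batch_size, shuffle=False)

batch_shapes = [tuple(batch[0].shape) for batch in loader]
```

With 10 rows and a batch size of 4, the loader yields two full batches and one partial batch.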

bead.src.trainers.inference.seed_worker(worker_id)[source]

PyTorch implementation to fix the seeds

Parameters:

worker_id

bead.src.trainers.training module

bead.src.trainers.training.fit(config, model, dataloader, loss_fn, reg_param, optimizer)[source]

This function trains the model on the training set. It computes the losses, runs backpropagation, and steps the optimizer.

Parameters:
  • config (dataClass) – Base class selecting user inputs

  • model (modelObject) – The model you wish to train

  • dataloader (torch.utils.data.DataLoader) – Defines the batched data which the model is trained on

  • loss_fn (lossObject) – Defines the loss function used to train the model

  • reg_param (float) – Determines proportionality constant to balance different components of the loss.

  • optimizer (torch.optim) – Chooses optimizer for gradient descent.

Returns:

Training losses, Epoch_loss and trained model

Return type:

list, model object
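
A single training pass of the kind described above might look like the sketch below. It is an assumption-laden illustration: `fit_sketch` is a hypothetical name, the `config` argument is omitted, and using `reg_param` to weight an L1 penalty on the parameters is only one plausible reading of "balance different components of the loss".

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def fit_sketch(model, dataloader, loss_fn, reg_param, optimizer):
    """Hypothetical one-epoch training pass; returns (epoch_loss, model)."""
    model.train()
    running_loss = 0.0
    n_batches = 0
    for (inputs,) in dataloader:
        optimizer.zero_grad()
        reconstruction = model(inputs)
        # Assumed loss: reconstruction term plus reg_param times an L1 penalty.
        l1_penalty = sum(p.abs().sum() for p in model.parameters())
        loss = loss_fn(reconstruction, inputs) + reg_param * l1_penalty
        loss.backward()   # backpropagation
        optimizer.step()  # parameter update
        running_loss += loss.item()
        n_batches += 1
    return running_loss / n_batches, model


torch.manual_seed(0)
data = torch.randn(32, 4)
loader = DataLoader(TensorDataset(data), batch_size=8)
model = nn.Sequential(nn.Linear(4, 2), nn.Linear(2, 4))  # toy autoencoder
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

weights_before = model[0].weight.detach().clone()
epoch_loss, model = fit_sketch(model, loader, nn.MSELoss(), 1e-4, optimizer)
weights_after = model[0].weight.detach().clone()
```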

bead.src.trainers.training.seed_worker(worker_id)[source]

PyTorch implementation to fix the seeds

Parameters:

worker_id
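
For orientation, a worker-seeding function of this kind usually follows the recipe in PyTorch's reproducibility notes. The sketch below shows that recipe; the actual body in bead may differ, and `make_loader` and the toy dataset are hypothetical.

```python
import random

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset


def seed_worker(worker_id):
    # Derive a NumPy/random seed from the seed PyTorch assigned to this process.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


def make_loader(seed):
    # A seeded generator makes the shuffling order reproducible.
    generator = torch.Generator()
    generator.manual_seed(seed)
    dataset = TensorDataset(torch.arange(10, dtype=torch.float32).unsqueeze(1))
    return DataLoader(
        dataset,
        batch_size=4,
        shuffle=True,
        worker_init_fn=seed_worker,  # only invoked when num_workers > 0
        generator=generator,
    )


first_run = [batch[0].flatten().tolist() for batch in make_loader(0)]
second_run = [batch[0].flatten().tolist() for batch in make_loader(0)]
```

Two loaders built with the same seed yield their batches in the same shuffled order.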

bead.src.trainers.training.train(events_train, jets_train, constituents_train, events_val, jets_val, constituents_val, output_path, config, verbose: bool = False)[source]
Runs the entire training loop by calling fit() and validate(). Apart from this, this is the main function where the data is converted to the correct type for training, via torch.Tensor(). Batching is also done here, based on config.batch_size, with torch.utils.data.DataLoader doing the splitting. EarlyStopping or an LR scheduler is applied here as well, depending on the respective config arguments. For reproducibility, the seeds can also be fixed in this function.

Parameters:
  • model (modelObject) – The model you wish to train

  • data (Tuple) – Tuple containing the training and validation data

  • project_path (string) – Path to the project directory

  • config (dataClass) – Base class selecting user inputs

Returns:

fully trained model ready to perform compression and decompression

Return type:

modelObject
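
The early-stopping behaviour mentioned above can be illustrated with a small patience-based sketch. `EarlyStoppingSketch`, its `patience` and `min_delta` parameters, and the hard-coded validation losses are all assumptions made for illustration; the package's own EarlyStopping helper may use different criteria.

```python
class EarlyStoppingSketch:
    """Hypothetical patience-based early stopping on the validation loss."""

    def __init__(self, patience, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float("inf")
        self.bad_epochs = 0

    def should_stop(self, val_loss):
        if val_loss < self.best_loss - self.min_delta:
            # Improvement: remember it and reset the counter.
            self.best_loss = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Stand-in for the per-epoch values a validation pass would return.
val_losses = [1.0, 0.8, 0.79, 0.81, 0.80, 0.82, 0.5]

stopper = EarlyStoppingSketch(patience=3)
stopped_at = None
for epoch, val_loss in enumerate(val_losses):
    if stopper.should_stop(val_loss):
        stopped_at = epoch
        break
```

Here training stops after three consecutive epochs without improvement over the best loss of 0.79, so the final value in the list is never reached.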

bead.src.trainers.training.validate(config, model, dataloader, loss_fn, reg_param)[source]

Function used to validate the training. Not necessary for doing compression, but gives a good indication of whether the selected model is a good fit or not.

Parameters:
  • model (modelObject) – Defines the model one wants to validate. The model used here is passed directly from fit().

  • dataloader (torch.utils.data.DataLoader) – Defines the batched data which the model is validated on

  • model_children (list) – List of model parameters

  • reg_param (float) – Determines proportionality constant to balance different components of the loss.

Returns:

Validation loss

Return type:

float
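
A validation pass typically mirrors the training pass with gradients disabled and no optimizer step. The following is a minimal sketch under the same assumptions as any such illustration: `validate_sketch` is a hypothetical name, `config` is omitted, and the L1 use of `reg_param` is assumed rather than taken from the package.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def validate_sketch(model, dataloader, loss_fn, reg_param):
    """Hypothetical validation pass; returns the mean loss over all batches."""
    model.eval()
    running_loss = 0.0
    n_batches = 0
    with torch.no_grad():  # no gradients are tracked, no parameters change
        for (inputs,) in dataloader:
            reconstruction = model(inputs)
            # Assumed loss: reconstruction term plus reg_param times an L1 penalty.
            l1_penalty = sum(p.abs().sum() for p in model.parameters())
            loss = loss_fn(reconstruction, inputs) + reg_param * l1_penalty
            running_loss += loss.item()
            n_batches += 1
    return running_loss / n_batches


torch.manual_seed(0)
val_data = torch.randn(16, 4)
val_loader = DataLoader(TensorDataset(val_data), batch_size=8)
model = nn.Sequential(nn.Linear(4, 2), nn.Linear(2, 4))  # toy autoencoder

params_before = model[0].weight.detach().clone()
val_loss = validate_sketch(model, val_loader, nn.MSELoss(), 1e-4)
params_after = model[0].weight.detach().clone()
```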

Module contents