bead.src.trainers package
Submodules
bead.src.trainers.inference module
- bead.src.trainers.inference.infer(events_bkg, jets_bkg, constituents_bkg, events_sig, jets_sig, constituents_sig, model_path, output_path, config, verbose: bool = False)[source]
- Runs the full inference loop on the background and signal datasets using a previously trained model loaded from model_path. As in train(), the data is converted to the correct type via torch.Tensor() and batched by torch.utils.data.DataLoader based on config.batch_size, and the seeds can be fixed for reproducibility. Results are written to output_path (a hypothetical usage sketch follows this entry).
- Parameters:
events_bkg, jets_bkg, constituents_bkg (array-like) – Background data at the event, jet, and constituent level
events_sig, jets_sig, constituents_sig (array-like) – Signal data at the event, jet, and constituent level
model_path (string) – Path to the trained model
output_path (string) – Path where the inference outputs are saved
config (dataClass) – Base class selecting user inputs
verbose (bool) – Whether to print detailed progress information
- Returns:
the results of running the trained model over the background and signal datasets (saved under output_path)
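The call below is a hypothetical usage sketch of infer(); the array shapes, file paths, and the way the config object is built are illustrative assumptions, not part of the documented API.

```python
# Hypothetical usage of infer(); array shapes, paths, and the config
# construction are assumptions for illustration only.
import numpy as np

from bead.src.trainers.inference import infer

# Event-, jet-, and constituent-level inputs for background and signal,
# stubbed here with random arrays of assumed shapes.
events_bkg = np.random.rand(1000, 10)
jets_bkg = np.random.rand(1000, 3, 8)
constituents_bkg = np.random.rand(1000, 50, 4)
events_sig = np.random.rand(200, 10)
jets_sig = np.random.rand(200, 3, 8)
constituents_sig = np.random.rand(200, 50, 4)

config = ...  # dataClass instance selecting user inputs (batch_size, model, ...)

infer(
    events_bkg, jets_bkg, constituents_bkg,
    events_sig, jets_sig, constituents_sig,
    model_path="workspace/models/model.pt",  # assumed layout
    output_path="workspace/output",          # assumed layout
    config=config,
    verbose=True,
)
```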
bead.src.trainers.training module
- bead.src.trainers.training.fit(config, model, dataloader, loss_fn, reg_param, optimizer)[source]
This function trains the model on the train set: it computes the losses, performs backpropagation, and updates the optimizer. A minimal sketch of this pass follows this entry.
- Parameters:
config (dataClass) – Base class selecting user inputs
model (modelObject) – The model you wish to train
dataloader (torch.utils.data.DataLoader) – Defines the batched data on which the model is trained
loss_fn (lossObject) – Defines the loss function used to train the model
reg_param (float) – Proportionality constant balancing the different components of the loss
optimizer (torch.optim.Optimizer) – Optimizer used for gradient descent
- Returns:
the training losses, the epoch loss, and the trained model
- Return type:
list, modelObject
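As a reading aid, here is a minimal sketch of the per-batch pass that fit() performs, assuming a reconstruction-style model; the real loss signature and the exact role of reg_param in bead may differ.

```python
# Minimal sketch of a fit()-style training pass; the loss signature and the
# role of reg_param are assumptions about bead's internals.
import torch


def fit_sketch(model, dataloader, loss_fn, reg_param, optimizer):
    model.train()
    epoch_loss = 0.0
    for (batch,) in dataloader:
        optimizer.zero_grad()
        recon = model(batch)          # forward pass
        loss = loss_fn(recon, batch)  # base (e.g. reconstruction) loss
        # Assumed: reg_param weights an L1-style penalty into the total loss.
        loss = loss + reg_param * sum(p.abs().sum() for p in model.parameters())
        loss.backward()               # backpropagation
        optimizer.step()              # optimizer update
        epoch_loss += loss.item()
    return epoch_loss / len(dataloader)
```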
- bead.src.trainers.training.seed_worker(worker_id)[source]
PyTorch worker-seeding helper used to fix the random seeds of DataLoader worker processes for reproducibility (see the wiring sketch below).
- Parameters:
worker_id (int) – ID of the DataLoader worker process being seeded
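seed_worker() presumably follows the standard reproducibility recipe from the PyTorch documentation; the sketch below shows that recipe and how such a function is wired into a DataLoader via worker_init_fn together with a fixed torch.Generator.

```python
# Standard PyTorch worker-seeding recipe and its DataLoader wiring; this is
# the documented PyTorch pattern, not necessarily bead's exact code.
import random

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset


def seed_worker(worker_id):
    # Derive NumPy and stdlib random seeds from the worker's torch seed.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


g = torch.Generator()
g.manual_seed(0)  # fixes the shuffling order across runs

dataset = TensorDataset(torch.randn(128, 4))
loader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=2,
                    worker_init_fn=seed_worker, generator=g)
```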
- bead.src.trainers.training.train(events_train, jets_train, constituents_train, events_val, jets_val, constituents_val, output_path, config, verbose: bool = False)[source]
- Does the entire training loop by calling fit() and validate(). Apart from this, it is the main function where the data is converted to the correct type for training via torch.Tensor(). Batching is also done here, based on config.batch_size, with torch.utils.data.DataLoader doing the splitting (a data-preparation sketch follows this entry). EarlyStopping and the LR scheduler are applied here as well, based on their respective config arguments, and for reproducibility the seeds can also be fixed in this function.
- Parameters:
events_train, jets_train, constituents_train (array-like) – Training data at the event, jet, and constituent level
events_val, jets_val, constituents_val (array-like) – Validation data at the event, jet, and constituent level
output_path (string) – Path where the trained model and outputs are saved
config (dataClass) – Base class selecting user inputs
verbose (bool) – Whether to print detailed progress information
- Returns:
fully trained model ready to perform compression and decompression
- Return type:
modelObject
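The following is a minimal sketch of the data-preparation step described above: converting an input array with torch.Tensor() and letting torch.utils.data.DataLoader split it into batches of config.batch_size. The array shape and batch size are illustrative, not bead's defaults.

```python
# Sketch of the tensor conversion and batch splitting that train() describes;
# the array shape and batch size are illustrative stand-ins.
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

constituents_train = np.random.rand(1000, 50, 4)  # stand-in training array
batch_size = 512                                  # would come from config.batch_size

train_ds = TensorDataset(torch.Tensor(constituents_train))
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)

for (batch,) in train_dl:
    print(batch.shape)  # torch.Size([512, 50, 4]) for the full batches
    break
```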
- bead.src.trainers.training.validate(config, model, dataloader, loss_fn, reg_param)[source]
Function used to validate the training. Not necessary for doing compression, but gives a good indication of whether the selected model is a good fit (a minimal sketch follows this entry).
- Parameters:
config (dataClass) – Base class selecting user inputs
model (modelObject) – Defines the model one wants to validate; the model used here is passed directly from fit()
dataloader (torch.utils.data.DataLoader) – Defines the batched data on which the model is validated
loss_fn (lossObject) – Defines the loss function used to evaluate the model
reg_param (float) – Proportionality constant balancing the different components of the loss
- Returns:
Validation loss
- Return type:
float
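For contrast with fit(), here is a minimal sketch of a validate()-style pass: the same loss computation, but under torch.no_grad() and without optimizer updates. The loss signature and the handling of reg_param are assumptions about bead's internals.

```python
# Minimal sketch of a validate()-style pass; loss signature and reg_param
# handling are assumptions about bead's internals.
import torch


def validate_sketch(model, dataloader, loss_fn, reg_param):
    model.eval()
    val_loss = 0.0
    with torch.no_grad():  # no gradients are needed for validation
        for (batch,) in dataloader:
            recon = model(batch)
            loss = loss_fn(recon, batch)
            # Assumed: the same reg_param-weighted penalty as in fit().
            loss = loss + reg_param * sum(p.abs().sum() for p in model.parameters())
            val_loss += loss.item()
    return val_loss / len(dataloader)  # mean validation loss (a float)
```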