# This script creates and submits job files to train the YOLOv5x models starting from the pre-trained weights
import os

# Names of the models that have already finished training or are currently
# training; any job whose name appears here is skipped so it is not submitted
# a second time. Entries are grouped by training mode and freeze strategy.
trained_models = {
    # Un-augmented data, nothing frozen
    'unaug_full_fold_1', 'unaug_full_fold_2', 'unaug_full_fold_3',
    'unaug_full_fold_4', 'unaug_full_fold_5', 'unaug_full_fold_6',
    'unaug_full_fold_7', 'unaug_full_fold_8', 'unaug_full_fold_9',
    'unaug_full_fold_10',
    # Un-augmented data, backbone frozen
    'unaug_backbone_fold_1', 'unaug_backbone_fold_2', 'unaug_backbone_fold_3',
    'unaug_backbone_fold_4', 'unaug_backbone_fold_5', 'unaug_backbone_fold_6',
    'unaug_backbone_fold_7', 'unaug_backbone_fold_8', 'unaug_backbone_fold_9',
    'unaug_backbone_fold_10',
    # Un-augmented data, all but the last layer frozen
    'unaug_last_layer_fold_1', 'unaug_last_layer_fold_2', 'unaug_last_layer_fold_3',
    'unaug_last_layer_fold_4', 'unaug_last_layer_fold_5', 'unaug_last_layer_fold_6',
    'unaug_last_layer_fold_7', 'unaug_last_layer_fold_8', 'unaug_last_layer_fold_9',
    'unaug_last_layer_fold_10',
    # Augmented data, nothing frozen (fold 8 is not listed)
    'aug_full_fold_1', 'aug_full_fold_2', 'aug_full_fold_3',
    'aug_full_fold_4', 'aug_full_fold_5', 'aug_full_fold_6',
    'aug_full_fold_7', 'aug_full_fold_9', 'aug_full_fold_10',
    # Augmented data, backbone frozen
    'aug_backbone_fold_1', 'aug_backbone_fold_2', 'aug_backbone_fold_3',
    'aug_backbone_fold_4', 'aug_backbone_fold_5', 'aug_backbone_fold_6',
    'aug_backbone_fold_7', 'aug_backbone_fold_8', 'aug_backbone_fold_9',
    'aug_backbone_fold_10',
    # Augmented data, all but the last layer frozen
    'aug_last_layer_fold_1', 'aug_last_layer_fold_2', 'aug_last_layer_fold_3',
    'aug_last_layer_fold_4', 'aug_last_layer_fold_5', 'aug_last_layer_fold_6',
    'aug_last_layer_fold_7', 'aug_last_layer_fold_8', 'aug_last_layer_fold_9',
    'aug_last_layer_fold_10',
}

# Map each value passed to train.py's --freeze option to the label used in job names:
# 0 freezes nothing ("full"), 10 freezes the backbone, 24 freezes all but the last layer
freeze_labels = {0: 'full', 10: 'backbone', 24: 'last_layer'}

# For each of the 10 folds
for fold in range(1, 11):
    # Value passed to train.py's --device option, which YOLOv5 interprets as a CUDA (GPU) index.
    # Even folds use device 0 and odd folds use device 1, so jobs are spread over both devices
    # (the original author's note says each node has two devices -- confirm for your cluster)
    device = 0 if fold % 2 == 0 else 1

    # Freeze none of the model, backbone, and all but last layer
    for freeze, label in freeze_labels.items():
        # Train model on un-augmented and augmented data
        for train_mode in ['unaug', 'aug']:
            # The job name doubles as the training run name (--name) and the job file stem
            job_name = f"{train_mode}_{label}_fold_{fold}"

            # Skip already trained models
            if job_name in trained_models:
                continue

            file_name = f"{job_name}.job"
            # No --freeze option is emitted when the whole model is trained
            freeze_arg = f'--freeze {freeze}' if freeze else ''

            # Determine wall time (in hours) and dataset to run on. Augmented requires more time
            if train_mode == 'unaug':
                wall_time = 8
                dataset = f'fold_{fold}.yaml'
            else:
                wall_time = 48
                dataset = f'aug_fold_{fold}.yaml'

            script_lines = [
                # Specify resources and other arguments for cluster
                "#!/bin/bash",
                f"#SBATCH --time={wall_time}:00:00",
                "#SBATCH --nodes=1 --ntasks-per-node=1",
                "#SBATCH --partition=gpu",
                "#SBATCH --mail-type=BEGIN,END",
                "#SBATCH --job-name=train_models",
                "",
                # Load python
                "module load anaconda-python3",
                # Switch to model directory
                "cd $HOME/yolov5",
                # Source conda's shell setup script (note: no `conda activate` is issued here)
                "source /software/python/anaconda3/etc/profile.d/conda.sh",
                # Train model. --workers is spelled out in full rather than relying on
                # argparse accepting the unambiguous prefix --worker
                f'python3 train.py {freeze_arg} --img 640 --batch 8 --exist-ok --epochs 300 --data {dataset} --workers 1 --device {device} --weights yolov5x.pt --name {job_name}',
            ]

            # Create job file, terminating the last command with a newline
            with open(file_name, 'w') as output:
                output.write("\n".join(script_lines) + "\n")

            # Submit job file and report (without aborting the remaining submissions)
            # when sbatch does not exit cleanly, so a failed submission is not mistaken
            # for a running job
            status = os.system(f"sbatch {file_name}")
            if status != 0:
                print(f"WARNING: 'sbatch {file_name}' failed with status {status}")

