import torch
import numpy as np
import matplotlib.pyplot as plt
import copy
import math


def _run_epoch(model, loader, criterion, spatial_weights, device, epoch,
               optimizer=None, phase='loader'):
    """Run one pass over ``loader`` and return the mean per-batch loss.

    If ``optimizer`` is given, the model is put in train mode and updated after
    every batch. Otherwise it is put in eval mode and gradients are disabled.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is what model.eval() does
    total_loss = 0.0
    num_batches = 0
    with torch.set_grad_enabled(training):
        for inputs, reflectivity in loader:
            inputs, reflectivity = inputs.to(device), reflectivity.to(device)
            if training:
                optimizer.zero_grad()
            outputs = model(inputs, epoch)
            # Add a dim at position 1 so the target lines up with `outputs`.
            reflectivity = reflectivity.unsqueeze(1)
            batch_weights = spatial_weights.expand_as(outputs)
            loss = criterion(outputs, reflectivity, batch_weights)
            if training:
                loss.backward()
                optimizer.step()
            total_loss += loss.item()
            num_batches += 1
    # Count batches actually seen instead of len(loader): avoids an opaque
    # ZeroDivisionError and also works for loaders without __len__.
    if num_batches == 0:
        raise ValueError(f"{phase} yielded no batches")
    return total_loss / num_batches


def _plot_history(train_losses, val_losses, lrs, out_path):
    """Save train/val loss (left axis) and learning rate (right axis) per epoch."""
    epochs = range(1, len(train_losses) + 1)  # 1-based, matching the printed log
    fig, ax1 = plt.subplots()
    try:
        ax1.plot(epochs, train_losses, 'r', label='Training Loss')
        ax1.plot(epochs, val_losses, 'b', label='Validation Loss')
        ax1.set_xlabel('Epoch')
        ax1.set_ylabel('Loss')
        ax1.set_title('Training and Validation Loss & Learning Rate Over Epochs')
        ax1.grid(True)

        ax2 = ax1.twinx()
        ax2.plot(epochs, lrs, 'g', label='Learning Rate')
        ax2.set_ylabel('Learning Rate')

        lines1, labels1 = ax1.get_legend_handles_labels()
        lines2, labels2 = ax2.get_legend_handles_labels()
        ax1.legend(lines1 + lines2, labels1 + labels2, loc='upper right')
        fig.savefig(out_path)
    finally:
        # Close this specific figure even if saving fails.
        plt.close(fig)


def train_model(device, model, train_loader, test_loader, criterion, optimizer,
                fig_dir, fig_pred_dir, spatial_weights, num_epochs=10,
                early_stopping_patience=20, min_delta=0.0, ckpt_path=None):
    """Train ``model`` with LR-on-plateau scheduling and early stopping.

    Each epoch runs one training pass over ``train_loader`` and one
    gradient-free pass over ``test_loader``. The mean validation loss drives
    two things:

    - a ``ReduceLROnPlateau`` scheduler (patience=3, factor=0.1);
    - an early-stopping counter.

    The weights from the epoch with the lowest validation loss are loaded back
    into ``model`` at the end. A loss / learning-rate plot is saved to
    ``{fig_dir}/{num_epochs}_train_val_lr.png``.

    Args:
        device: device the model, batches and weights are moved to.
        model: module called as ``model(inputs, epoch)``. It is moved to
            ``device`` and updated in place.
        train_loader: iterable of ``(inputs, reflectivity)`` training batches.
        test_loader: iterable of ``(inputs, reflectivity)`` validation batches.
        criterion: callable ``criterion(outputs, target, weights)`` returning
            a scalar loss tensor.
        optimizer: optimizer over the model's parameters. The LR from its
            first param group is logged.
        fig_dir: directory the loss/LR plot is written to.
        fig_pred_dir: not used by this function; kept for interface
            compatibility.
        spatial_weights: weight tensor given two leading singleton dims and
            expanded to each batch's output shape. Presumably (H, W) ->
            (1, 1, H, W); confirm against the caller.
        num_epochs: maximum number of epochs.
        early_stopping_patience: consecutive epochs without improvement
            before training stops.
        min_delta: minimum decrease in validation loss counted as an
            improvement.
        ckpt_path: if truthy, a checkpoint dict is saved here every time the
            validation loss improves.

    Raises:
        ValueError: if either loader yields no batches in an epoch.
    """
    model.to(device)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', patience=3, factor=0.1
    )

    train_loss_values, val_loss_values, epoch_lr_values = [], [], []
    best_val_loss = math.inf
    best_epoch = -1
    best_state = copy.deepcopy(model.state_dict())
    early_stopping_counter = 0

    # Add two leading singleton dims so expand_as() can broadcast the weights
    # to each batch's output shape.
    spatial_weights = spatial_weights.to(device).unsqueeze(0).unsqueeze(0)

    for epoch in range(num_epochs):
        # Read the LR used for this epoch's updates BEFORE the scheduler may
        # lower it at the end of the epoch. Otherwise the log/plot shows the
        # next epoch's LR.
        epoch_lr = optimizer.param_groups[0]['lr']

        average_train_loss = _run_epoch(
            model, train_loader, criterion, spatial_weights, device, epoch,
            optimizer=optimizer, phase='train_loader')
        average_val_loss = _run_epoch(
            model, test_loader, criterion, spatial_weights, device, epoch,
            phase='test_loader')

        train_loss_values.append(average_train_loss)
        val_loss_values.append(average_val_loss)
        epoch_lr_values.append(epoch_lr)

        scheduler.step(average_val_loss)

        print(f"Epoch [{epoch+1}/{num_epochs}] "
              f"LR: {epoch_lr_values[-1]:.3e} "
              f"Train: {average_train_loss:.4f} "
              f"Val: {average_val_loss:.4f}",
              flush=True)

        # ---------- EARLY STOPPING ----------
        if average_val_loss < best_val_loss - float(min_delta):
            best_val_loss = average_val_loss
            best_epoch = epoch
            best_state = copy.deepcopy(model.state_dict())
            early_stopping_counter = 0
            if ckpt_path:
                torch.save({'epoch': epoch,
                            'model_state_dict': best_state,
                            'optimizer_state_dict': optimizer.state_dict(),
                            'val_loss': best_val_loss},
                           ckpt_path)
        else:
            early_stopping_counter += 1

        if early_stopping_counter >= early_stopping_patience:
            print(f"Early stopping at epoch {epoch+1} "
                  f"(best epoch was {best_epoch+1}, val_loss={best_val_loss:.4f})")
            break

    # Restore the best weights. If validation never improved, these are the
    # weights the model had when training started.
    model.load_state_dict(best_state)

    _plot_history(train_loss_values, val_loss_values, epoch_lr_values,
                  f'{fig_dir}/{num_epochs}_train_val_lr.png')
    print('Training DONE!!!')

