import argparse
import logging
import os
from typing import Any, Dict, Tuple

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import CosineAnnealingLR

# Route INFO-and-above records through the root logger, prefixing each message
# with its timestamp and severity.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

class TrainingConfig:
    """Plain container bundling the hyperparameters and settings of a training run.

    Each constructor argument is stored unchanged on the instance under the
    same name; no validation or type conversion is performed.

    Attributes:
        epochs: Number of training epochs (default ``100``).
        batch_size: Batch size (default ``32``).
        learning_rate: Learning rate (default ``0.001``).
        weight_decay: Weight decay (default ``1e-4``).
        momentum: Momentum (default ``0.9``).
        checkpoint_dir: Checkpoint directory (default ``'./checkpoints'``).
    """

    def __init__(
        self,
        epochs: int = 100,
        batch_size: int = 32,
        learning_rate: float = 0.001,
        weight_decay: float = 1e-4,
        momentum: float = 0.9,
        checkpoint_dir: str = './checkpoints',
    ):
        # Run length and batching.
        self.epochs, self.batch_size = epochs, batch_size

        # Optimizer hyperparameters.
        self.learning_rate, self.weight_decay, self.momentum = (
            learning_rate,
            weight_decay,
            momentum,
        )

        # Output location.
        self.checkpoint_dir = checkpoint_dir

class SimpleDataset(Dataset):
    """Map-style dataset pairing ``data[i]`` with ``labels[i]``.

    Both containers are kept by reference and indexed directly. The dataset
    length is taken from ``data`` alone; ``labels`` is not checked against it.
    """

    def __init__(self, data: Any, labels: Any):
        self.labels = labels
        self.data = data

    def __len__(self) -> int:
        """Return the number of entries in ``data``."""
        n_items = len(self.data)
        return n_items

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the ``(sample, label)`` pair stored at position ``idx``."""
        sample = self.data[idx]
        label = self.labels[idx]
        return sample, label

class SimpleModel(nn.Module):
    """Single linear layer mapping ``input_size`` features to ``num_classes`` outputs."""

    def __init__(self, input_size: int, num_classes: int):
        super().__init__()
        # The attribute name ``fc`` determines the state-dict keys
        # (``fc.weight`` / ``fc.bias``), so it must stay stable.
        self.fc = nn.Linear(in_features=input_size, out_features=num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the linear layer to ``x`` and return its output with no activation."""
        logits = self.fc(x)
        return logits

def train_one_epoch(model: nn.Module, dataloader: DataLoader, criterion: nn.Module,
                    optimizer: optim.Optimizer, device: torch.device) -> float:
    """Run one optimisation pass over ``dataloader`` and return the mean batch loss.

    The model is put into training mode, and for every ``(inputs, targets)``
    batch the gradients are reset, the loss is back-propagated and the
    optimizer takes one step.

    Args:
        model: Network to optimise; modified in place.
        dataloader: Iterable yielding ``(inputs, targets)`` batches.
        criterion: Loss module called as ``criterion(outputs, targets)``.
        optimizer: Optimizer stepped once per batch.
        device: Device each batch is moved to before the forward pass.

    Returns:
        The arithmetic mean of the per-batch loss values. Every batch counts
        equally, whatever its size. ``float('nan')`` is returned when the
        loader yields no batches, because the mean of nothing is undefined.
    """
    model.train()
    total_loss = 0.0
    # Count batches while iterating instead of calling len(dataloader):
    # loaders over iterable-style datasets may not define a length, and the
    # count also makes the empty case detectable without a ZeroDivisionError.
    num_batches = 0
    for inputs, targets in dataloader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
    if num_batches == 0:
        return float('nan')
    return total_loss / num_batches

def validate(model: nn.Module, dataloader: DataLoader, criterion: nn.Module,
             device: torch.device) -> float:
    """Evaluate ``model`` on ``dataloader`` and return the mean batch loss.

    The model is switched to evaluation mode and left in that mode. All
    forward passes run under ``torch.no_grad()``, so no gradients are
    recorded and no parameters change.

    Args:
        model: Network to evaluate.
        dataloader: Iterable yielding ``(inputs, targets)`` batches.
        criterion: Loss module called as ``criterion(outputs, targets)``.
        device: Device each batch is moved to before the forward pass.

    Returns:
        The arithmetic mean of the per-batch loss values. Every batch counts
        equally, whatever its size. ``float('nan')`` is returned when the
        loader yields no batches, because the mean of nothing is undefined.
    """
    model.eval()
    total_loss = 0.0
    # Count batches while iterating instead of calling len(dataloader):
    # loaders over iterable-style datasets may not define a length, and the
    # count also makes the empty case detectable without a ZeroDivisionError.
    num_batches = 0
    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            total_loss += loss.item()
            num_batches += 1
    if num_batches == 0:
        return float('nan')
    return total_loss / num_batches

def save_checkpoint(state: Dict[str, Any], filename: str) -> None:
    """Serialise ``state`` to ``filename`` with ``torch.save``.

    Missing parent directories of ``filename`` are created first, so saving
    works even when nothing has prepared the checkpoint directory beforehand.

    Args:
        state: Mapping to serialise.
        filename: Destination path of the checkpoint file.
    """
    # Only path-like targets have a parent directory to create; anything else
    # (e.g. an already-open file object) is handed to torch.save untouched.
    if isinstance(filename, (str, os.PathLike)):
        parent_dir = os.path.dirname(os.fspath(filename))
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
    torch.save(state, filename)

def load_checkpoint(filename: str, model: nn.Module, optimizer: optim.Optimizer,
                    map_location: Any = 'cpu') -> Tuple[nn.Module, optim.Optimizer]:
    """Restore ``model`` and ``optimizer`` in place from a checkpoint file.

    Args:
        filename: Path of a checkpoint holding the keys ``'model_state_dict'``
            and ``'optimizer_state_dict'``.
        model: Module whose ``load_state_dict`` receives the stored model state.
        optimizer: Optimizer whose ``load_state_dict`` receives the stored
            optimizer state.
        map_location: Forwarded to ``torch.load``. The default ``'cpu'`` lets a
            checkpoint written from a CUDA device be read on a host without
            that device.

    Returns:
        The same ``model`` and ``optimizer`` objects that were passed in.

    Raises:
        KeyError: If either expected key is missing from the checkpoint.
    """
    # NOTE(security): torch.load deserialises with pickle, so only load
    # checkpoints from trusted sources. Where the installed PyTorch supports
    # it, consider passing weights_only=True.
    checkpoint = torch.load(filename, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return model, optimizer

def main(args: argparse.Namespace) -> None:
    """Train ``SimpleModel`` on random synthetic data and keep the best checkpoint.

    Each epoch runs one training pass and one validation pass, steps the
    cosine learning-rate schedule and logs both losses. Whenever the
    validation loss improves on the best seen so far, a checkpoint is written
    to ``<checkpoint_dir>/best_checkpoint.pth``.

    Args:
        args: Parsed options. The attributes read are ``batch_size``,
            ``learning_rate``, ``momentum``, ``weight_decay``, ``epochs`` and
            ``checkpoint_dir``.
    """
    # Prefer the GPU when one is available.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Create the checkpoint directory here as well, so that calling main()
    # programmatically (not via the CLI entry point) cannot fail on the first save.
    os.makedirs(args.checkpoint_dir, exist_ok=True)
    checkpoint_path = os.path.join(args.checkpoint_dir, 'best_checkpoint.pth')

    # Synthetic data: 10 random features per sample, binary labels.
    train_dataset = SimpleDataset(data=torch.randn(1000, 10), labels=torch.randint(0, 2, (1000,)))
    val_dataset = SimpleDataset(data=torch.randn(200, 10), labels=torch.randint(0, 2, (200,)))
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False)

    # Model, loss, optimizer and learning-rate schedule.
    model = SimpleModel(input_size=10, num_classes=2).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=args.learning_rate, momentum=args.momentum,
                          weight_decay=args.weight_decay)
    scheduler = CosineAnnealingLR(optimizer, T_max=args.epochs)

    best_val_loss = float('inf')
    for epoch in range(args.epochs):
        train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
        val_loss = validate(model, val_loader, criterion, device)
        scheduler.step()

        # Lazy %-formatting: the message is only built if the record is emitted.
        logging.info('Epoch [%d/%d], Train Loss: %.4f, Val Loss: %.4f',
                     epoch + 1, args.epochs, train_loss, val_loss)

        # Keep only the checkpoint with the lowest validation loss so far.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            save_checkpoint({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                # Extra entries a resume needs to continue the schedule and
                # the best-loss tracking; existing readers can ignore them.
                'scheduler_state_dict': scheduler.state_dict(),
                'best_val_loss': best_val_loss,
                'epoch': epoch,  # zero-based index of the epoch just finished
            }, checkpoint_path)

if __name__ == '__main__':
    # One row per command-line option: (flag, type, default, help text).
    cli_options = (
        ('--epochs', int, 100, 'Number of training epochs.'),
        ('--batch_size', int, 32, 'Batch size for training.'),
        ('--learning_rate', float, 0.001, 'Learning rate for optimizer.'),
        ('--weight_decay', float, 1e-4, 'Weight decay for optimizer.'),
        ('--momentum', float, 0.9, 'Momentum for SGD optimizer.'),
        ('--checkpoint_dir', str, './checkpoints', 'Directory to save checkpoints.'),
    )

    parser = argparse.ArgumentParser(description='Train a simple model.')
    for flag, value_type, default_value, help_text in cli_options:
        parser.add_argument(flag, type=value_type, default=default_value, help=help_text)
    args = parser.parse_args()

    # The checkpoint directory must exist before training starts writing to it.
    os.makedirs(args.checkpoint_dir, exist_ok=True)
    main(args)