import torch
from datasets import load_dataset
from transformers import (
    AutoProcessor,
    GitForCausalLM,
    TrainingArguments,
    Trainer
)

# --- 1. CONFIGURAZIONE DELL'ADDESTRAMENTO ---
DATASET_DIR = "master_dataset_v6"
MODEL_CHECKPOINT = "microsoft/git-base-coco"
OUTPUT_DIR = "tripix-ai-model" # Dove verrà salvato il nostro modello addestrato

# Parametri di addestramento
NUM_EPOCHS = 3 # Quante volte il modello vedrà l'intero dataset
BATCH_SIZE = 4 # Quante immagini processare alla volta (abbassare se si esaurisce la memoria)
LEARNING_RATE = 5e-5

# --- 2. PREPARAZIONE DATI ---
# Carica il dataset dal nostro file metadata.csv
# 'imagefolder' è un caricatore speciale di Hugging Face
dataset = load_dataset("imagefolder", data_dir=DATASET_DIR, split="train")

# Carica il processore del modello (gestisce l'input di immagini e testo)
processor = AutoProcessor.from_pretrained(MODEL_CHECKPOINT)

def transform(example_batch):
    """Pre-processa un batch di dati nel formato richiesto dal modello."""
    images = [img.convert("RGB") for img in example_batch["image"]]
    texts = example_batch["text"]
    
    inputs = processor(images=images, text=texts, padding="max_length", return_tensors="pt")
    
    # L'input del modello si aspetta 'pixel_values' e 'input_ids'
    # L'output (ciò che deve imparare) sono gli 'input_ids' stessi
    inputs.update({"labels": inputs["input_ids"]})
    
    return inputs

# Applica la trasformazione a tutto il dataset
prepared_dataset = dataset.map(function=transform, batched=True, remove_columns=dataset.column_names)

# --- 3. ADDESTRAMENTO ---
# Controlla se è disponibile una GPU
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Dispositivo di addestramento: {device}")

# Carica il modello pre-addestrato
model = GitForCausalLM.from_pretrained(MODEL_CHECKPOINT).to(device)

# Definisci gli argomenti per il training
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    learning_rate=LEARNING_RATE,
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    logging_steps=50,
    save_steps=500,
    fp16=torch.cuda.is_available(), # Usa calcoli più veloci se c'è una GPU
    push_to_hub=False, # Non pubblicare il modello online
)

# Crea l'oggetto Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=prepared_dataset,
)

# Lancia l'addestramento!
print("Inizio del processo di fine-tuning...")
trainer.train()

# Salva il modello finale e il processore
print("Addestramento completato. Salvataggio del modello finale.")
model.save_pretrained(OUTPUT_DIR)
processor.save_pretrained(OUTPUT_DIR)

print(f"\nModello affinato e salvato in '{OUTPUT_DIR}'. Test di inferenza pronto.")