import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
import os
import numpy as np

def plot_results(csv_file_path: str) -> None:
    """
    Plot one line graph per metric showing how each result changes over iterations.

    The CSV is expected to provide 'metric', 'estimator', 'iteration' and 'result'
    columns. When a 'policy' column is also present (SCOPE-RL format), one line is
    drawn per (policy, estimator) pair; otherwise (OBP format) one line is drawn
    per estimator. Results that cannot be parsed as numbers are skipped.

    Each figure is saved as '<csv name>_<metric>_plot.png' in the same directory
    as the CSV file.

    Parameters:
    csv_file_path (str): The path to the CSV file containing the results.

    Raises:
    KeyError: If a required column is missing from the CSV.
    """
    df = pd.read_csv(csv_file_path)

    has_policy = 'policy' in df.columns
    if has_policy:
        print("Detected SCOPE-RL format - plotting with policy dimension")
    else:
        print("Detected OBP format - plotting without policy dimension")

    metrics = df['metric'].unique()
    estimators = df['estimator'].unique()
    policies = df['policy'].unique() if has_policy else None

    # Coerce once for the whole frame instead of once per plotted series;
    # unparseable values become NaN and are dropped per series when plotting.
    df['result'] = pd.to_numeric(df['result'], errors='coerce')

    # dirname() is '' for a bare filename; fall back to '.' so the output
    # location is always a real, printable directory.
    csv_dir = os.path.dirname(csv_file_path) or '.'
    csv_basename = os.path.splitext(os.path.basename(csv_file_path))[0]

    for metric in metrics:
        plot_path = os.path.join(csv_dir, f'{csv_basename}_{metric}_plot.png')
        _plot_metric(df[df['metric'] == metric], metric, estimators, policies, plot_path)
        print(f"Plot saved: {plot_path}")

    print(f"All plots saved in directory: {csv_dir}")


def _iter_series(metric_df, estimators, policies):
    """Yield (label, rows) for every line to draw, in policy-then-estimator order."""
    if policies is None:
        for estimator in estimators:
            yield estimator, metric_df[metric_df['estimator'] == estimator]
        return

    for policy in policies:
        policy_df = metric_df[metric_df['policy'] == policy]
        for estimator in estimators:
            yield f'{policy}_{estimator}', policy_df[policy_df['estimator'] == estimator]


def _plot_metric(metric_df, metric, estimators, policies, plot_path):
    """Draw and save the figure for a single metric; policies=None selects OBP styling."""
    if policies is None:
        figsize = (10, 6)
        line_kwargs = {}
        legend_kwargs = {}
        grid_kwargs = {}
    else:
        # More lines per figure in this format: larger canvas, thicker lines,
        # and the legend placed outside the axes so it does not cover the data.
        figsize = (12, 8)
        line_kwargs = {'linewidth': 2}
        legend_kwargs = {'bbox_to_anchor': (1.05, 1), 'loc': 'upper left'}
        grid_kwargs = {'alpha': 0.3}

    fig, ax = plt.subplots(figsize=figsize)
    try:
        for label, rows in _iter_series(metric_df, estimators, policies):
            rows = rows.dropna(subset=['result'])
            if rows.empty:
                continue
            ax.plot(rows['iteration'], rows['result'],
                    label=label, marker='o', **line_kwargs)

        ax.set_title(f'Results for {metric}')
        ax.set_xlabel('Iteration')
        ax.set_ylabel('Result')
        # Only build a legend when something was drawn; matplotlib warns
        # when asked for a legend with no labelled artists.
        handles, _ = ax.get_legend_handles_labels()
        if handles:
            ax.legend(**legend_kwargs)
        ax.grid(True, **grid_kwargs)
        # Iterations are counts, so keep the x ticks on whole numbers.
        ax.xaxis.set_major_locator(MaxNLocator(integer=True))

        fig.savefig(plot_path, dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)