"""Agent implementations for the two-agent framework.

Provides the analyse, pcoder (code-writing) and summariser agents used for ML
optimisation, backed by Gemini or OpenAI chat models.
"""
import sys
import os
import re
from typing import Dict, Any
from pydantic import BaseModel

# Add scripts to path for utilities
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../')))
from scripts.utils.get_api_key import get_api_key
from scripts.utils.load_prompts import load_prompts

# Import different model APIs
try:
    import google.generativeai as genai
except ImportError:
    genai = None

try:
    from openai import OpenAI
    import openai
except ImportError:
    OpenAI = None
    openai = None


class AnalysisOutput(BaseModel):
    """Result returned by the Analyse agent's ``analyse`` callable."""

    # Stripped model reply holding the improvement suggestions (the prompt requests Markdown).
    analysis: str

class CodeOutput(BaseModel):
    """Result returned by the PCoder agent's ``generate_code`` callable."""

    # Improved source code; taken from the first fenced code block of the reply when one exists.
    code: str

class SummaryOutput(BaseModel):
    """Result returned by the summariser agent's ``summarise`` callable."""

    # Stripped model reply holding the final summary report.
    summary: str

def detect_model_type(model: str) -> str:
    """
    Classify a model name by provider using case-insensitive substring matching.

    Providers are checked in order, so a name containing 'gemini' is classed as
    Gemini even if it also contains 'gpt'.

    Args:
        model: The model name (e.g., 'gemini-1.5-flash', 'gpt-4o')

    Returns:
        'gemini', 'openai', or 'unknown' when no marker matches.
    """
    lowered = model.lower()
    # (provider, substrings that identify it) in priority order
    provider_markers = (
        ('gemini', ('gemini',)),
        ('openai', ('gpt', 'chatgpt')),
    )
    for provider, markers in provider_markers:
        if any(marker in lowered for marker in markers):
            return provider
    return 'unknown'

def call_model_api(model: str, system_message: str, user_message: str) -> str:
    """
    Call the appropriate model API based on model type.

    The provider is chosen by ``detect_model_type``: 'gemini' models go through
    ``google.generativeai`` and 'openai' models through the OpenAI chat
    completions API.

    Args:
        model: The model name
        system_message: System message for the model
        user_message: User message/prompt for the model

    Returns:
        The model's response text (never ``None``).

    Raises:
        ImportError: If the client library for the detected provider is not installed.
        ValueError: If the model type is neither 'gemini' nor 'openai'.
        RuntimeError: If the provider returns no text content (e.g. a blocked
            prompt with no candidates, or a refusal with ``content`` unset).
    """
    model_type = detect_model_type(model)
    api_key = get_api_key(model)

    if model_type == 'gemini':
        if genai is None:
            raise ImportError("google.generativeai not available. Install with: pip install google-generativeai")

        genai.configure(api_key=api_key)
        gemini_model = genai.GenerativeModel(model)

        # Combine system and user messages for Gemini
        full_prompt = f"{system_message}\n\n{user_message}"
        response = gemini_model.generate_content(full_prompt)

        # A blocked prompt can yield no candidates (or a candidate without parts);
        # report that clearly instead of failing with a bare IndexError.
        candidates = response.candidates
        parts = candidates[0].content.parts if candidates else []
        if not parts:
            feedback = getattr(response, 'prompt_feedback', None)
            raise RuntimeError(f"Gemini model {model} returned no content (prompt feedback: {feedback})")

        # A candidate may hold several parts; join every text part (as
        # ``response.text`` does) instead of keeping only the first.
        return "".join(getattr(part, 'text', '') or '' for part in parts)

    elif model_type == 'openai':
        if OpenAI is None:
            raise ImportError("openai not available. Install with: pip install openai")

        client = OpenAI(api_key=api_key)
        response = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": system_message},
                {"role": "user", "content": user_message}
            ]
        )

        message = response.choices[0].message
        # ``content`` is None for refusals or tool-call-only replies; callers
        # require a string, so fail here with the refusal text if available.
        if message.content is None:
            refusal = getattr(message, 'refusal', None)
            raise RuntimeError(f"OpenAI model {model} returned no text content (refusal: {refusal})")
        return message.content

    else:
        raise ValueError(f"Unsupported model type for model: {model}")

def extract_content_from_response(response_text: str, content_type: str) -> str:
    """
    Normalise a raw model reply before it is wrapped in an output model.

    No type-specific parsing is attempted because different providers format
    their replies differently; only surrounding whitespace is removed.

    Args:
        response_text: The full response text from the model
        content_type: Kind of content requested ('analysis', 'code', 'summary');
            accepted for call-site symmetry but not currently used.

    Returns:
        ``response_text`` with leading and trailing whitespace stripped.
    """
    cleaned = response_text.strip()
    return cleaned

def create_analyse_agent(model: str) -> Dict[str, Any]:
    """
    Create the Analyse agent that identifies improvements and optimisations.

    Args:
        model: The language model to use (e.g., 'gpt-4o', 'gemini-1.5-flash')

    Returns:
        Dict with keys 'agent_type', 'model', 'api_key' and 'analyse' (the
        callable that performs the analysis).
    """
    api_key = get_api_key(model)

    def analyse_code(code: str, framework: str = "unknown") -> AnalysisOutput:
        """
        Analyse code and suggest improvements for hyperparameter optimisation.

        The framework context is looked up in the prompts config under
        ``two_agent`` -> ``analysis``: the entry for ``framework`` if present,
        otherwise the ``general`` entry, otherwise a built-in default.

        Args:
            code: The Python code to analyse
            framework: The detected framework ('obp', 'scope_rl', or 'unknown')

        Returns:
            AnalysisOutput with improvement suggestions
        """
        default_context = "Analyse the code for potential improvements."

        # Missing or empty config sections (e.g. a key present with no value)
        # fall back to {} rather than raising TypeError on the lookups below.
        prompts = load_prompts() or {}
        framework_prompts = (prompts.get('two_agent') or {}).get('analysis') or {}
        context = framework_prompts.get(framework, framework_prompts.get('general', default_context))

        prompt = f"""
        Analyse the following code and provide improvement suggestions in Markdown format.
        
        {context}
        
        Guidelines:
        - **CRITICAL: Only suggest parameter VALUE changes, NOT structural code changes**
        - **DO NOT add new frameworks like RandomizedSearchCV, StratifiedKFold, or hyperparameter grids**
        - **DO NOT add new imports or complex optimisation frameworks**
        
        **Allowed Changes:**
        - Change parameter values: `random_state=12345` → `random_state=42`
        - Adjust hyperparameters: `max_iter=1000` → `max_iter=2000`
        - Add simple parameters to existing functions: `LogisticRegression()` → `LogisticRegression(C=10)`
        - Change fold counts: `n_folds=3` → `n_folds=5`
        - Modify existing numeric values for better performance
        
        **FORBIDDEN Changes:**
        - Adding RandomizedSearchCV or GridSearchCV
        - Adding StratifiedKFold or complex cross-validation
        - Adding new imports or libraries
        - Restructuring the code architecture
        - Adding multi-metric scoring arrays
        
        **Example Good Suggestions:**
        - "Change random_state from 12345 to 42 for potentially better results"
        - "Increase max_iter from 1000 to 2000 to improve convergence"
        - "Add C=10 parameter to LogisticRegression for stronger regularisation"
        - "Change n_folds from 3 to 5 for better cross-validation"
        
        Focus on 1-2 simple parameter adjustments per iteration.
        
        **Code to analyse:**
        ```python
        {code}
        ```
        
        Provide your analysis as a clear, structured markdown document suggesting ONLY parameter value changes.
        """

        system_message = "You are an expert at analysing machine learning code for optimisation opportunities. Use British English throughout."

        response_text = call_model_api(model, system_message, prompt)
        analysis_content = extract_content_from_response(response_text, 'analysis')

        return AnalysisOutput(analysis=analysis_content)

    return {
        'agent_type': 'analyse',
        'model': model,
        'api_key': api_key,
        'analyse': analyse_code
    }

# Matches a fenced Markdown code block. Everything after the opening fence up to
# the end of that line (the info string: 'python', 'py', 'Python3', ...) is
# skipped, so the language tag never ends up inside the extracted code.
_CODE_FENCE_RE = re.compile(r"```[^\n`]*\n(.*?)```", re.DOTALL)


def _extract_first_code_block(text: str) -> str:
    """Return the stripped body of the first fenced code block in ``text``, or ``text`` unchanged if none."""
    match = _CODE_FENCE_RE.search(text)
    return match.group(1).strip() if match else text


def create_pcoder_agent(model: str) -> Dict[str, Any]:
    """Create PCoder agent for code improvement implementation.
    
    Args:
        model: Language model to use
        
    Returns:
        Dict with keys 'agent_type', 'model', 'api_key' and 'generate_code'
        (the callable that produces improved code).
    """
    api_key = get_api_key(model)

    def generate_improved_code(original_code: str, analysis: str, framework: str = "unknown") -> CodeOutput:
        """Generate improved code based on analysis suggestions.

        The framework context is looked up in the prompts config under
        ``two_agent`` -> ``pcoder``: the entry for ``framework`` if present,
        otherwise the ``general`` entry, otherwise built-in guidance.

        Args:
            original_code: Original Python code
            analysis: Analysis from the Analyse agent
            framework: Detected framework type for context

        Returns:
            CodeOutput with improved code (the first fenced code block of the
            reply, or the whole stripped reply if it contains none)
        """
        # Missing or empty config sections fall back to {} rather than raising
        # TypeError on the membership tests below.
        prompts = load_prompts() or {}
        two_agent_prompts = prompts.get('two_agent') or {}

        if 'pcoder' in two_agent_prompts:
            framework_prompts = two_agent_prompts['pcoder'] or {}
            if framework in framework_prompts:
                context = framework_prompts[framework]
            elif 'general' in framework_prompts:
                # A configured 'general' prompt takes precedence over the built-in guidance
                context = framework_prompts['general']
            else:
                # Built-in general guidance when the config has no matching entry
                context = """Apply the suggested parameter VALUE explorations for maximum performance.
                
                **CRITICAL CONSTRAINTS:**
                - Change ONLY parameter values, NO structural code changes
                - NO new frameworks or complex optimisation code
                - Keep all existing imports and code structure
                - EXPLORE different values each iteration (avoid repeating same values)"""
        else:
            # Fallback context if config prompts not found
            context = "Apply the suggested improvements precisely, focusing on parameter value changes only."

        prompt = f"""
        Given the code below and its analysis, apply the suggested improvements and return the complete, functional code.
        
        {context}
        
        **Original code:**
        ```python
        {original_code}
        ```
        
        **Analysis and improvement instructions:**
        {analysis}
        
        Return the complete improved code as a single Python code block implementing the suggested parameter explorations.
        """

        system_message = "You are an expert programmer who implements precise code improvements. Use British English in comments."

        response_text = call_model_api(model, system_message, prompt)

        # Normalise the raw reply, then pull the code out of a Markdown fence if present
        code_content = extract_content_from_response(response_text, 'code')
        code_content = _extract_first_code_block(code_content)

        return CodeOutput(code=code_content)

    return {
        'agent_type': 'pcoder',
        'model': model,
        'api_key': api_key,
        'generate_code': generate_improved_code
    }

def create_summariser_agent(model: str) -> Dict[str, Any]:
    """Create summariser agent for results evaluation.
    
    Args:
        model: Language model to use
        
    Returns:
        Dict with keys 'agent_type', 'model', 'api_key' and 'summarise' (the
        callable that writes the final report).
    """
    api_key = get_api_key(model)

    def summarise_results(results_text: str) -> SummaryOutput:
        """Write the final report from the orchestrator's pre-computed analysis.

        The model is instructed to treat the supplied quantitative analysis as
        definitive and not to re-derive the best iteration itself.

        Args:
            results_text: Combined results from all iterations, including a pre-computed analysis

        Returns:
            SummaryOutput with analysis of best iterations
        """
        prompt = f"""
        You are an expert AI analysis agent. Your task is to write a final summary report for a multi-iteration hyperparameter optimisation run.

        You have been provided with a "source of truth" quantitative analysis performed by the orchestrator. Do NOT perform your own analysis or attempt to reinterpret the raw data. Your summary MUST reflect the findings from the 'Quantitative Analysis Results' section.

        **Source of Truth & Data:**
        {results_text}

        **Your Task:**
        Based *only* on the 'Quantitative Analysis Results' provided above, write a comprehensive summary in British English that includes the following sections:

        1.  **Executive Summary:** Start with a brief, high-level overview. State the primary optimisation goal and clearly identify the best-performing iteration number as determined by the orchestrator.
        2.  **Performance Analysis:**
            *   Explain *why* the winning iteration was chosen, referencing the primary metric (e.g., `relative_policy_value`).
            *   Use the 'Performance Summary Table' to compare the winning iteration against the baseline and other top performers.
            *   Highlight specific numerical improvements for key estimators.
        3.  **Conclusion:** Briefly conclude the report, reiterating the success of the optimisation and confirming the best iteration found.

        **CRITICAL INSTRUCTIONS:**
        - The "Best Performing Iteration" identified in the quantitative analysis is definitive. Do not contradict it.
        - Structure your response clearly with markdown headings.
        """

        system_message = "You are an expert at evaluating machine learning results and creating clear, data-driven summary reports. Your response must strictly adhere to the provided quantitative analysis."

        response_text = call_model_api(model, system_message, prompt)
        summary_content = extract_content_from_response(response_text, 'summary')

        return SummaryOutput(summary=summary_content)

    return {
        'agent_type': 'summariser',
        'model': model,
        'api_key': api_key,
        'summarise': summarise_results
    }

def initialise_agents(model: str) -> Dict[str, Dict[str, Any]]:
    """Initialise all three agents (analyse, pcoder, summariser) for the two-agent framework.

    Args:
        model: Language model to use for all agents

    Returns:
        Dict mapping 'analyse', 'pcoder' and 'summariser' to their agent dicts.

    Raises:
        ValueError: If no API key can be obtained for ``model``.
    """
    # Check for an API key up front so misconfiguration fails fast with a clear
    # message; chain the original error so its traceback is preserved.
    try:
        get_api_key(model)
    except Exception as e:
        raise ValueError(f"Failed to get API key for model {model}: {e}") from e

    # call_model_api only routes 'gemini' and 'openai' models and raises
    # ValueError for anything else, so the warning must not promise a fallback.
    model_type = detect_model_type(model)
    if model_type == 'unknown':
        print(f"Warning: Unknown model type for {model}; only Gemini and OpenAI (GPT) models are supported, so agent calls will fail with 'Unsupported model type'")

    return {
        'analyse': create_analyse_agent(model),
        'pcoder': create_pcoder_agent(model),
        'summariser': create_summariser_agent(model)
    }