"""Two-agent orchestrator.

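Main orchestrator class. Starting from a baseline script, it runs independent analysis -> coding -> execution iterations and then summarises the results. It uses the modular components (I/O manager, code executor, CSV utilities, metrics).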
"""
import sys
import os
import argparse
import shutil
from pathlib import Path
from typing import Dict, Any, List, Optional
from datetime import datetime

# Add scripts to path for utilities  
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../')))

# Import project utilities
from scripts.utils.file_editor import read_file
from scripts.utils.parse_agent_output import parse
from scripts.utils.config_loader import load_config
from scripts.utils.save_run_results import save_results

# Import two-agent components
from agents import initialise_agents
from tasks import TwoAgentTasks, TaskResult

# Import modular components
from io_manager import IOManager
from executor import CodeExecutor, pick_interpreter, detect_notebook_framework
from csv_utils import parse_scope_rl_csv, save_csv_results
from metrics import find_best_iteration, detect_project_optimisation_goal, generate_comprehensive_results_table


class TwoAgentCriticalFailure(Exception):
    """Signal a failure that must end the framework with a non-zero exit code.

    Raised, for example, when not a single optimisation iteration succeeds;
    the CLI entry point turns it into exit status 1.
    """


class TwoAgentOrchestrator:
    """Main orchestrator for the two-agent optimisation framework.

    Manages the complete workflow: baseline preparation, a configurable number
    of iterations (7 by default) of analysis, coding, and evaluation, and a
    final summary of results.

    Every iteration starts from the original baseline code, so iterations are
    independent explorations rather than a cumulative chain of edits.
    Outcomes are tracked in:

    - ``successful_iterations`` / ``failed_iterations``: 1-based iteration numbers.
    - ``critical_failures``: human-readable error messages.
    - ``all_iteration_metrics``: one dict per recorded outcome (baseline,
      success or failure), each carrying ``'iteration'`` and ``'status'`` keys.
    """

    # Metrics shown in console previews and per-iteration summary files.
    # SCOPE-RL runs report a different estimator set than the default one.
    _SCOPE_RL_KEY_METRICS = ('on_policy_mean', 'dm_mean', 'sntis_mean', 'snpdis_mean', 'sndr_mean')
    _DEFAULT_KEY_METRICS = ('ipw_mean', 'dm_mean', 'dr_mean', 'ipw_relative_ee', 'dm_relative_ee', 'dr_relative_ee')

    def __init__(self, input_path: str, model: str, option: str = "whole_code"):
        """Initialise the two-agent orchestrator.

        Loads configuration, creates the run's working directories, detects the
        notebook framework, picks an interpreter, and initialises the agents.

        Args:
            input_path: Path to the input notebook or Python file
            model: Language model to use (e.g., 'gemini-1.5-flash')
            option: Output option ('whole_code', 'manual_patch', 'agent_applies').
                Stored on the instance; no method of this class currently reads it.
        """
        self.input_path = input_path
        self.model = model
        self.option = option

        # Load configuration
        self.config = load_config()

        # Set up modular components
        self.io_manager = IOManager(input_path)
        self.io_manager.setup_working_directories()

        # Detect framework type
        self.framework = detect_notebook_framework(input_path)
        print(f"Detected framework: {self.framework}")

        # Select appropriate interpreter
        self.interpreter = pick_interpreter(input_path, self.config['settings'])
        print(f"Using interpreter: {self.interpreter}")

        # Set up executor
        self.executor = CodeExecutor(self.interpreter, self.io_manager.run_dir)

        # Initialise agents
        print(f"Initialising agents with model: {self.model}")

        # NOTE(review): 'agents' is already imported at module level, so this
        # local import is not what breaks an import cycle; kept as-is.
        from agents import detect_model_type
        model_type = detect_model_type(model)
        print(f"Detected model type: {model_type}")

        self.agents = initialise_agents(model)
        self.task_orchestrator = TwoAgentTasks(self.agents)

        # Results tracking using structured data approach
        self.all_iteration_metrics: List[Dict[str, Any]] = []
        self.iteration_analyses: List[str] = []
        self.iteration_codes: List[str] = []

        # Failure tracking for error handling
        self.failed_iterations: List[int] = []
        self.critical_failures: List[str] = []
        self.successful_iterations: List[int] = []

    def _key_metrics(self) -> tuple:
        """Return the metric names to preview/summarise for the detected framework."""
        if self.framework == 'scope_rl':
            return self._SCOPE_RL_KEY_METRICS
        return self._DEFAULT_KEY_METRICS

    def _iteration_code_path(self, iteration_number: int) -> Path:
        """Return the path of the generated code file for a 1-based iteration number."""
        return self.io_manager.results_dir / f"iteration_{iteration_number}_code.py"

    def _record_failure(self, iteration_number: int, error_msg: str, reason: str) -> None:
        """Record a failed iteration in every failure tracker and in the metrics list.

        Args:
            iteration_number: 1-based iteration number that failed.
            error_msg: Full message appended to ``critical_failures``.
            reason: Short reason stored under ``'error'`` in the failure record.
        """
        self.failed_iterations.append(iteration_number)
        self.critical_failures.append(error_msg)
        self.all_iteration_metrics.append({
            'iteration': iteration_number,
            'status': 'failed',
            'error': reason,
        })

    def prepare_initial_code(self) -> str:
        """Prepare the initial code file and run it for baseline results.

        A ``.ipynb`` input is converted to a Python script; any other input is
        copied verbatim into ``io_manager.code_file``. The script is then run
        once. If a CSV is produced and saved, its parsed metrics are recorded
        as iteration 0 with status ``'baseline'``.

        Returns:
            The initial code content (read back from the prepared code file).
        """
        print("Preparing initial code and generating baseline results...")

        code_file = str(self.io_manager.code_file)
        if self.input_path.endswith('.ipynb'):
            # Convert notebook to Python script
            from scripts.utils.notebook_converter import convert_notebook_to_script
            convert_notebook_to_script(self.input_path, code_file)
        else:
            # Copy Python file directly
            shutil.copy(self.input_path, self.io_manager.code_file)

        initial_code = read_file(code_file)

        # Run initial code for baseline
        baseline_csv_path = self.executor.run_code(code_file)
        if not baseline_csv_path:
            print("Warning: Baseline execution failed or no CSV generated")
            return initial_code

        saved_csv = save_csv_results(baseline_csv_path, 0, self.io_manager.results_dir)
        if not saved_csv:
            # Without a saved CSV there are no metrics to parse or print.
            print("Warning: Baseline CSV could not be saved; no baseline metrics recorded")
            return initial_code

        baseline_metrics = parse_scope_rl_csv(saved_csv, 0)
        baseline_metrics['iteration'] = 0  # Baseline is iteration 0
        baseline_metrics['status'] = 'baseline'
        self.all_iteration_metrics.append(baseline_metrics)

        print("Baseline Results:")
        print("=" * 60)
        for metric in self._key_metrics():
            if metric in baseline_metrics:
                print(f"{metric}: {baseline_metrics[metric]}")
        print("=" * 60)

        return initial_code

    def run_optimisation_iterations(self, initial_code: str, num_iterations: int = 7):
        """Run the main optimisation loop with analysis, coding, and evaluation.

        Each iteration starts from the original baseline code to ensure
        independent exploration. The following are recorded as failures, and
        the loop moves on to the next iteration:

        - an analysis failure;
        - a code-generation failure;
        - an execution or CSV-saving failure;
        - an unexpected exception.

        Args:
            initial_code: The initial code to optimise
            num_iterations: Number of iterations to run (default: 7)

        Raises:
            TwoAgentCriticalFailure: If no iteration completed successfully.
        """
        print(f"\nStarting {num_iterations} iterations of optimisation...")
        print("=" * 80)

        # Every iteration restarts from this code, never from a previous iteration's output.
        original_baseline_code = initial_code

        for iteration in range(num_iterations):
            iteration_number = iteration + 1
            print(f"\n{'='*60}")
            print(f"ITERATION {iteration_number}/{num_iterations}")
            print(f"{'='*60}")
            print("Starting from original baseline (independent iteration)")

            try:
                self._run_single_iteration(original_baseline_code, iteration)
            except Exception as e:
                # Per-iteration boundary: one broken iteration must not abort the whole run.
                error_msg = f"Critical error in iteration {iteration_number}: {e}"
                print(f"Error: {error_msg}")
                already_recorded = (iteration_number in self.successful_iterations
                                    or iteration_number in self.failed_iterations)
                if already_recorded:
                    # The outcome was recorded before the exception (e.g. the summary
                    # write failed); don't count the same iteration twice.
                    self.critical_failures.append(error_msg)
                else:
                    self._record_failure(iteration_number, error_msg, f"Unexpected error: {e}")

        # Validation: Check if we have any successful iterations
        print(f"\n{'='*60}")
        print("ITERATION SUMMARY")
        print(f"{'='*60}")
        print(f"Successful iterations: {len(self.successful_iterations)}")
        print(f"Failed iterations: {len(self.failed_iterations)}")
        print(f"Critical failures: {len(self.critical_failures)}")

        if not self.successful_iterations:
            critical_failure_msg = "CRITICAL FAILURE: No iterations completed successfully"
            print(f"Error: {critical_failure_msg}")
            self.critical_failures.append(critical_failure_msg)
            raise TwoAgentCriticalFailure(critical_failure_msg)

        # Strictly more failures than successes, matching the message wording.
        if len(self.failed_iterations) > len(self.successful_iterations):
            warning_msg = f"WARNING: More failures ({len(self.failed_iterations)}) than successes ({len(self.successful_iterations)})"
            print(f"Error: {warning_msg}")

        print(f"Framework completed with {len(self.successful_iterations)} successful iterations")

    def _run_single_iteration(self, baseline_code: str, iteration: int) -> None:
        """Run analysis, code generation, execution and summary for one iteration.

        Args:
            baseline_code: The original baseline code this iteration starts from.
            iteration: 0-based loop index. File names such as ``Instructions{n}.md``
                use it directly, while user-facing numbering is ``iteration + 1``.
        """
        analysis = self._run_analysis_step(baseline_code, iteration)
        if analysis is None:
            return

        improved_code = self._run_coding_step(baseline_code, analysis, iteration)
        if improved_code is None:
            return

        iteration_metrics = self._run_execution_step(iteration + 1)
        self._write_iteration_summary(iteration, baseline_code, improved_code, iteration_metrics)

    def _run_analysis_step(self, baseline_code: str, iteration: int) -> Optional[str]:
        """Step 1: ask the analysis agent for improvement instructions.

        Returns:
            The analysis text, or ``None`` (after recording the failure) if the agent failed.
        """
        iteration_number = iteration + 1
        print("\nSTEP 1: ANALYSIS AGENT")
        print("Analysing original baseline code")
        print(f"Using model: {self.model}")
        print("Calling analysis agent...")

        analysis_result = self.task_orchestrator.execute_analysis_task(
            baseline_code, self.framework, iteration_number
        )
        if not analysis_result.success:
            print(f"Analysis failed: {analysis_result.error_message}")
            self._record_failure(
                iteration_number,
                f"Analysis failed for iteration {iteration_number}: {analysis_result.error_message}",
                'Analysis failed',
            )
            return None

        analysis = analysis_result.output
        self.iteration_analyses.append(analysis)

        # Saved twice: 0-based Instructions{n}.md, plus 1-based analysis_{n}.md for compatibility.
        analysis_file = self.io_manager.docs_dir / f"Instructions{iteration}.md"
        self.io_manager.write_file(str(analysis_file), analysis)
        original_analysis_file = self.io_manager.docs_dir / f"analysis_{iteration_number}.md"
        self.io_manager.write_file(str(original_analysis_file), analysis)

        print("Analysis completed successfully!")
        print(f"Results saved to: {analysis_file.name}")
        print("Analysis preview:")
        print(f"    {analysis[:200]}..." if len(analysis) > 200 else f"    {analysis}")

        # Show up to 3 non-blank, non-heading lines taken from the first 10 lines.
        key_lines = [line for line in analysis.split('\n')[:10] if line.strip() and not line.startswith('#')]
        if key_lines:
            print("Key findings:")
            for line in key_lines[:3]:
                print(f"    • {line.strip()}")

        return analysis

    def _run_coding_step(self, baseline_code: str, analysis: str, iteration: int) -> Optional[str]:
        """Step 2: ask the coding agent to apply the analysis; save code and diff.

        Returns:
            The generated code, or ``None`` (after recording the failure) if generation failed.
        """
        iteration_number = iteration + 1
        print("\nSTEP 2: PCODER AGENT")
        print("Implementing suggested improvements...")
        print(f"Using model: {self.model}")
        print("Generating improved code...")

        coding_result = self.task_orchestrator.execute_coding_task(
            baseline_code, analysis, self.framework, iteration_number
        )
        if not coding_result.success:
            print(f"Code generation failed: {coding_result.error_message}")
            self._record_failure(
                iteration_number,
                f"Code generation failed for iteration {iteration_number}: {coding_result.error_message}",
                'Code generation failed',
            )
            return None

        improved_code = coding_result.output
        self.iteration_codes.append(improved_code)

        # Save improved code (results copy is what gets executed; docs copy is for reference)
        iteration_code_file = self._iteration_code_path(iteration_number)
        newcode_file = self.io_manager.docs_dir / f"newcode{iteration}.py"
        self.io_manager.write_file(str(iteration_code_file), improved_code)
        self.io_manager.write_file(str(newcode_file), improved_code)

        print("Code generation completed successfully!")
        print(f"Code saved to: {iteration_code_file.name}")

        # Generate and save diff (comparing against original baseline)
        diff_text = self.io_manager.generate_code_diff(baseline_code, improved_code, iteration_number)
        diff_file = self.io_manager.results_dir / f"diff_iteration_{iteration_number}.txt"
        self.io_manager.write_file(str(diff_file), diff_text)

        before, after = baseline_code.count('\n'), improved_code.count('\n')
        print(f"Code changes: {before} -> {after} lines ({after - before:+d})")
        print(f"Diff saved to: {diff_file.name}")

        return improved_code

    def _run_execution_step(self, iteration_number: int) -> Optional[Dict[str, Any]]:
        """Step 3: execute the saved iteration code and record its CSV metrics.

        Returns:
            The parsed metrics dict on success, otherwise ``None`` after recording the failure.
        """
        print("\nSTEP 3: CODE EXECUTION")
        print(f"Environment: {self.interpreter}")
        print(f"Working directory: {self.io_manager.run_dir}")
        print("Executing improved code...")

        csv_path = self.executor.run_code(str(self._iteration_code_path(iteration_number)))
        if not csv_path:
            error_msg = f"Code execution failed for iteration {iteration_number}"
            print(f"Error: {error_msg}")
            self._record_failure(iteration_number, error_msg, 'Code execution failed or no CSV generated')
            return None

        saved_csv = save_csv_results(csv_path, iteration_number, self.io_manager.results_dir)
        if not saved_csv:
            error_msg = f"CSV saving failed for iteration {iteration_number}"
            print(f"Error: {error_msg}")
            self._record_failure(iteration_number, error_msg, 'CSV saving failed')
            return None

        iteration_metrics = parse_scope_rl_csv(saved_csv, iteration_number)
        iteration_metrics['iteration'] = iteration_number
        iteration_metrics['status'] = 'success'
        self.all_iteration_metrics.append(iteration_metrics)
        self.successful_iterations.append(iteration_number)

        print("Execution completed successfully!")
        print("Results preview:")
        for metric in self._key_metrics():
            if metric in iteration_metrics:
                print(f"    {metric}: {iteration_metrics[metric]}")
        print(f"CSV results saved to: iteration_{iteration_number}_results.csv")

        return iteration_metrics

    def _write_iteration_summary(
        self,
        iteration: int,
        baseline_code: str,
        improved_code: str,
        iteration_metrics: Optional[Dict[str, Any]],
    ) -> None:
        """Write ``iteration_{n}_summary.md`` for an iteration that produced code.

        Args:
            iteration: 0-based loop index.
            baseline_code: Code the iteration started from.
            improved_code: Code generated in this iteration (used for the line count).
            iteration_metrics: Parsed metrics, or ``None`` if execution/CSV saving failed.
        """
        iteration_number = iteration + 1
        if iteration_metrics is not None:
            status = "SUCCESS"
            key_metrics = self._key_metrics()
            key_metrics_text = "\n".join(
                f"- {k}: {v}" for k, v in iteration_metrics.items() if k in key_metrics
            )
        else:
            status = "FAILED"
            key_metrics_text = "Execution failed - no metrics available"

        # Use this iteration's own generated code; indexing iteration_codes by the
        # loop index breaks as soon as an earlier iteration skipped code generation.
        before, after = baseline_code.count('\n'), improved_code.count('\n')
        line_changes = f"{before} -> {after} ({after - before:+d})"

        iteration_summary = f"""# Iteration {iteration_number} Summary

## Files Generated:
- **Analysis**: R-doc/Instructions{iteration}.md
- **Generated Code**: R-doc/newcode{iteration}.py  
- **Code Diff**: Results/diff_iteration_{iteration_number}.txt
- **Results**: Results/iteration_{iteration_number}_results.csv

## Status: {status}

## Code Changes:
- Lines: {line_changes}
- See diff_iteration_{iteration_number}.txt for detailed changes

## Key Metrics:
{key_metrics_text}
"""

        summary_file = self.io_manager.results_dir / f"iteration_{iteration_number}_summary.md"
        self.io_manager.write_file(str(summary_file), iteration_summary)

        print(f"\n{'='*60}")
        print(f"ITERATION {iteration_number} COMPLETE")
        print(f"Summary saved to: {summary_file.name}")
        print(f"{'='*60}")

    def generate_final_summary(self):
        """Generate final summary of all iterations using the Summariser agent.

        Steps:

        1. Compute the quantitative results (optimisation goal, best iteration,
           comparison table) and write the table to ``io_manager.result_file``.
        2. Send those results as context to the summariser agent.
        3. Overwrite the input file with the best iteration's code, after
           backing the original up. This happens whether or not the summariser
           succeeded.
        4. On summariser success, save and print ``final_summary.md``.
        """
        print(f"\n{'='*80}")
        print("FINAL SUMMARY GENERATION")
        print(f"{'='*80}")

        baseline_metrics = next((m for m in self.all_iteration_metrics if m.get('status') == 'baseline'), {})
        iteration_metrics = [m for m in self.all_iteration_metrics if m.get('status') == 'success']

        print("\nSUMMARISER AGENT")
        print(f"Using model: {self.model}")
        print(f"Analysing results from {len(iteration_metrics)} iterations...")

        # Perform all quantitative analysis FIRST to establish a source of truth.
        detected_goal = detect_project_optimisation_goal(self.input_path, self.io_manager.docs_dir)
        best_iteration_num = find_best_iteration(
            self.all_iteration_metrics,
            self.input_path,
            self.io_manager.docs_dir,
            optimisation_goal_override=detected_goal
        )

        comparison_table = generate_comprehensive_results_table(
            self.all_iteration_metrics,
            self.framework,
            self.model
        )

        # Save the definitive quantitative results to result.txt
        self.io_manager.write_file(str(self.io_manager.result_file), comparison_table)
        print(f"Comprehensive results table saved to: {self.io_manager.result_file.name}")

        # The pre-calculated analysis is given to the LLM so it does not draw
        # its own, possibly incorrect, conclusions.
        summary_context = f"""
# Quantitative Analysis Results

This section contains the definitive, pre-calculated results of the optimisation run. Use this as the source of truth for the summary.

**Project Optimisation Goal (Primary Metric):** {detected_goal or "Not detected, default to policy_value"}
**Best Performing Iteration (Identified by Orchestrator):** Iteration {best_iteration_num}

**Scoring Methodology:** The framework uses SCOPE-RL Enhanced Robust Scoring, which takes into account:
- Estimator agreement levels (how consistent different estimators are)
- Confidence intervals and outlier handling
- Framework-specific weighting (SCOPE-RL vs OBP considerations)
- Bias-variance trade-offs between estimators

This robust approach may select a different "best" iteration than simple averaging would suggest, as it prioritises statistical reliability and estimator consensus over raw score maximisation.

**Performance Summary Table:**
{comparison_table}

# Raw Iteration Data

This section contains the raw, unprocessed metrics for each iteration for additional context.

**Baseline Metrics:**
{baseline_metrics}

**Successful Iteration Metrics:**
"""
        summary_context += "".join(
            f"\n- Iteration {metrics.get('iteration', 'Unknown')}: {metrics}\n"
            for metrics in iteration_metrics
        )

        print("Generating comprehensive summary based on quantitative analysis...")
        summary_result = self.task_orchestrator.execute_summary_task(summary_context)

        # The best iteration was chosen quantitatively above. Apply it even if the
        # summariser failed; otherwise start.py (which reads the input file) would
        # rerun the unoptimised code.
        self._apply_best_iteration_code(best_iteration_num)

        if not summary_result.success:
            print(f"Summary generation failed: {summary_result.error_message}")
            return

        summary = summary_result.output
        summary_file = self.io_manager.results_dir / "final_summary.md"
        self.io_manager.write_file(str(summary_file), summary)

        print("Summary generation completed successfully!")
        print(f"Summary saved to: {summary_file.name}")

        print(f"\n{'='*80}")
        print("FINAL SUMMARY")
        print(f"{'='*80}")
        print(summary)
        print(f"{'='*80}")

        print(f"\nComplete results available in: {self.io_manager.run_dir}")

    def _apply_best_iteration_code(self, best_iteration_num) -> None:
        """Overwrite the input file with the best iteration's generated code.

        The original input is first copied into the results directory as
        ``original_input<suffix>``, so the destructive overwrite can be undone.

        Args:
            best_iteration_num: Value returned by ``find_best_iteration``. Any
                falsy value leaves the input file unchanged.
        """
        # NOTE(review): if find_best_iteration can return 0 (baseline best), this
        # falsy check correctly leaves the input unchanged but prints the
        # "Could not identify" message; confirm the helper's return contract.
        if not best_iteration_num:
            print("Could not identify best iteration, input file unchanged")
            return

        best_code_path = self._iteration_code_path(best_iteration_num)
        if not best_code_path.exists():
            print(f"Could not find code for best iteration ({best_iteration_num}), input file unchanged")
            return

        best_code = read_file(str(best_code_path))
        print(f"\nBest iteration identified: Iteration {best_iteration_num}")
        print("Replacing input file with best iteration code...")

        backup_path = self.io_manager.results_dir / f"original_input{Path(self.input_path).suffix}"
        shutil.copy(self.input_path, backup_path)
        print(f"Original input backed up to: {backup_path}")

        # NOTE(review): for .ipynb inputs this writes plain Python source to the
        # notebook path; confirm start.py expects that.
        self.io_manager.write_file(self.input_path, best_code)
        print(f"Input file updated: {self.input_path}")
        print("start.py will now use the optimised version for final CSV generation")

    def run_complete_optimisation(self, iterations: int = 7):
        """Run the complete optimisation workflow.

        The workflow is: baseline preparation, then the optimisation
        iterations, then the final summary.

        Args:
            iterations: Number of optimisation iterations to perform

        Raises:
            TwoAgentCriticalFailure: Re-raised after listing all recorded critical failures.
            Exception: Any other error is printed and re-raised.
        """
        try:
            # Prepare initial code and baseline
            initial_code = self.prepare_initial_code()

            # Run optimisation iterations
            self.run_optimisation_iterations(initial_code, iterations)

            # Generate final summary
            self.generate_final_summary()

            print("\nTwo-agent optimisation completed!")
            print(f"Results available in: {self.io_manager.run_dir}")

        except TwoAgentCriticalFailure as e:
            print(f"\nError: CRITICAL FAILURE: {e}")
            print(f"Framework failed with {len(self.critical_failures)} critical errors:")
            for i, failure in enumerate(self.critical_failures, 1):
                print(f"  {i}. {failure}")
            raise  # Re-raise to trigger sys.exit(1) in main
        except Exception as e:
            print(f"Optimisation failed: {e}")
            raise


def main(input_path: str, model: str, option: str = "whole_code", iterations: int = 7):
    """Main entry point for the two-agent framework.

    This function is called by start.py when -fw two_agent is specified.
    It always ends the process via ``sys.exit``:

    - 0 on success;
    - 1 on a ``TwoAgentCriticalFailure``;
    - 1 on any other exception.

    Args:
        input_path: Path to the input notebook or Python file
        model: Language model to use
        option: Output option (currently supports 'whole_code', 'manual_patch', 'agent_applies')
        iterations: Number of optimisation iterations to run (default: 7)
    """
    print("=" * 80)
    print("TWO-AGENT OPTIMISATION FRAMEWORK")
    print("=" * 80)
    print(f"Input: {input_path}")
    print(f"Model: {model}")
    print(f"Option: {option}")
    print("=" * 80)

    try:
        orchestrator = TwoAgentOrchestrator(input_path, model, option)
        orchestrator.run_complete_optimisation(iterations=iterations)

        print("\nTwo-agent optimisation completed successfully!")
        sys.exit(0)  # Explicit success exit (SystemExit is not caught below)

    except TwoAgentCriticalFailure as e:
        print("\nTWO-AGENT FRAMEWORK FAILED")
        print(f"Reason: {e}")
        print("Check the logs above for detailed error information.")
        sys.exit(1)  # Non-zero exit code for critical failures

    except Exception as e:
        import traceback  # local: only needed on this error path
        print(f"\nUNEXPECTED ERROR: {e}")
        print("This indicates a bug in the two-agent framework.")
        # The message alone hides where the bug is; print the full stack trace.
        traceback.print_exc()
        sys.exit(1)  # Non-zero exit code for unexpected errors


if __name__ == "__main__":
    # Command-line wrapper: positional input path and model, optional output mode.
    cli = argparse.ArgumentParser(
        description="Two-agent framework for iterative hyperparameter optimisation."
    )
    cli.add_argument("input_path", type=str, help="Path to the input notebook or Python file")
    cli.add_argument("model", type=str, help="Language model to use (e.g., gemini-1.5-flash)")
    cli.add_argument(
        "-opt",
        "--option",
        type=str,
        default="whole_code",
        choices=['manual_patch', 'whole_code', 'agent_applies'],
        help="Output option for code modifications",
    )

    cli_args = cli.parse_args()
    main(cli_args.input_path, cli_args.model, cli_args.option)