import os
import re
import ast
import json
import argparse
import nbformat
from nbconvert import PythonExporter
from typing import List, Tuple, Optional, Dict
import glob
import pandas as pd
import numpy as np
import pickle
from .artifact_utils import ExecutionMode, determine_execution_mode, find_datasets, find_models

def remove_comments(script_content):
    """
    Strip full-line comments and empty lines from a converted script.

    A line is dropped when it is exactly the empty string or starts at
    column 0 with ``#``. Lines containing any of the protective marker
    phrases below are always kept, even when they are comments, so the
    guidance injected for later editors survives this cleanup. Indented
    comments, whitespace-only lines and trailing comments are not touched.

    Args:
        script_content: Python source text exported from a notebook.

    Returns:
        The surviving lines re-joined with newlines.
    """
    protected_markers = (
        'CRITICAL MODEL ARCHITECTURE RULE',
        'FIXED: DO NOT CHANGE',
        'MODIFIABLE: Can be optimised',
        'CRITICAL: DO NOT MODIFY ARCHITECTURES',
        'MODIFIABLE SECTION: CreateOPEInput parameters',
        'VALID PARAMETERS:',
        'CRITICAL: DO NOT add non-existent parameters',
    )

    def _keep(line):
        # Protective markers win over the comment/blank filter.
        if any(marker in line for marker in protected_markers):
            return True
        return not (line == '' or line.startswith('#'))

    return '\n'.join(line for line in script_content.splitlines() if _keep(line))



def get_model_architecture_from_notebook(notebook_path: str, model_name: str) -> List[int]:
    """
    Look up the hidden-layer sizes a pre-trained model was built with.

    The notebook's base file name (without extension) selects a table of
    known model architectures. If the model is not listed there, the
    baseline suffix in ``model_name`` decides: ``_b1`` -> [30, 30],
    ``_b2`` -> [100], ``_b3`` -> [50, 10] (checked in that order).
    Anything else defaults to [30, 30].

    Args:
        notebook_path: Path to the notebook; only its stem is used.
        model_name: Artifact name of the model, e.g. ``cql_discrete_b2``.

    Returns:
        A freshly built list of hidden-unit counts, one per layer.
    """
    notebook_name = os.path.splitext(os.path.basename(notebook_path))[0]

    # Every discrete-action notebook uses the same set of models.
    discrete_models = {
        'cql_discrete_b1': [30, 30],
        'cql_discrete_b2': [100],
        'cql_discrete_b3': [50, 10],
        'ddqn': [30, 30],
    }
    basic_continuous_models = {
        'cql_continuous_b1': [30, 30],
        'cql_continuous_b2': [100],
        'cql_continuous_b3': [50, 10],
        'iql_continuous_b1': [30, 30],
        'iql_continuous_b2': [100],
        'iql_continuous_b3': [50, 10],
        'sac': [30, 30],
    }
    # The RTB continuous notebook additionally trains TD3+BC baselines.
    rtb_continuous_models = dict(
        basic_continuous_models,
        td3bc_continuous_b1=[30, 30],
        td3bc_continuous_b2=[100],
        td3bc_continuous_b3=[50, 10],
    )
    known_architectures = {
        'basic_synthetic_discrete_advanced': discrete_models,
        'basic_synthetic_continuous_advanced': basic_continuous_models,
        'rtb_synthetic_discrete_advanced': discrete_models,
        'rtb_synthetic_continuous_advanced': rtb_continuous_models,
        'rec_synthetic_discrete_advanced': discrete_models,
    }

    exact_match = known_architectures.get(notebook_name, {}).get(model_name)
    if exact_match is not None:
        return list(exact_match)

    # Suffix convention fallback; order matters when several suffixes appear.
    suffix_rules = (('_b1', [30, 30]), ('_b2', [100]), ('_b3', [50, 10]))
    for suffix, hidden_units in suffix_rules:
        if suffix in model_name:
            return hidden_units

    return [30, 30]

def detect_required_imports(notebook_path: str, anchor_idx: int, notebook_node) -> List[str]:
    """
    Detect which SCOPE-RL policy-head imports the OPE part of a notebook needs.

    Every code cell from ``anchor_idx`` onwards (the cell that imports
    CreateOPEInput) is scanned for references to the optional policy heads.
    ``EpsilonGreedyHead`` is always required.

    Args:
        notebook_path: Path of the notebook. Unused; kept for interface
            compatibility.
        anchor_idx: Index of the first cell to scan.
        notebook_node: Parsed notebook exposing ``.cells``, each with
            ``.cell_type`` and ``.source``.

    Returns:
        Import statements in a fixed order: EpsilonGreedyHead first, then
        whichever of SoftmaxHead, ContinuousEvalHead and
        TruncatedGaussianHead were referenced. The previous implementation
        returned ``list(set(...))``, whose order varied between interpreter
        runs (string hash randomization), so generated scripts were not
        reproducible.
    """
    optional_heads = ('SoftmaxHead', 'ContinuousEvalHead', 'TruncatedGaussianHead')

    code_sources = [
        cell.source
        for cell in notebook_node.cells[anchor_idx:]
        if cell.cell_type == 'code'
    ]

    required_imports = ['from scope_rl.policy import EpsilonGreedyHead']
    for head in optional_heads:
        if any(head in source for source in code_sources):
            required_imports.append(f'from scope_rl.policy import {head}')

    return required_imports

def _matching_paren_index(text, start):
    """
    Return the index of the ``)`` closing a call whose ``(`` ends just before ``start``.

    Quoted string literals and ``#`` comments are skipped so brackets inside
    them do not disturb the depth count. Returns None if the call is not
    closed before the end of ``text``.
    """
    depth = 1
    idx = start
    length = len(text)
    while idx < length:
        char = text[idx]
        if char in ('"', "'"):
            # Jump to the matching quote, honouring backslash escapes.
            idx += 1
            while idx < length and text[idx] != char:
                if text[idx] == '\\':
                    idx += 1
                idx += 1
        elif char == '#':
            newline = text.find('\n', idx)
            if newline == -1:
                return None
            idx = newline
        elif char == '(':
            depth += 1
        elif char == ')':
            depth -= 1
            if depth == 0:
                return idx
        idx += 1
    return None


def _strip_invalid_create_ope_input_params(script_content):
    """
    Remove keyword arguments CreateOPEInput does not accept, only inside ``CreateOPEInput(...)``.

    ``n_trajectories_on_policy_evaluation`` and ``n_bootstrap_samples`` are
    legitimate arguments of other SCOPE-RL calls (``obtain_whole_inputs``,
    ``summarize_off_policy_estimates``). They must therefore not be stripped
    from the whole script.
    """
    invalid_params = ('n_trajectories_on_policy_evaluation', 'n_bootstrap_samples')
    pieces = []
    cursor = 0
    for match in re.finditer(r'\bCreateOPEInput\(', script_content):
        args_start = match.end()
        if args_start <= cursor:
            continue  # nested inside a call that was already rewritten
        args_end = _matching_paren_index(script_content, args_start)
        if args_end is None:
            continue  # unterminated call: leave untouched
        args = script_content[args_start:args_end]
        for param in invalid_params:
            # Argument preceded by a comma (middle/last position)...
            args = re.sub(rf',\s*\b{param}\s*=\s*\d+(?:\s*#[^\n]*)?', '', args)
            # ...or in first position followed by a comma.
            args = re.sub(rf'\b{param}\s*=\s*\d+(?:\s*#[^\n]*)?,\s*', '', args)
        pieces.append(script_content[cursor:args_start])
        pieces.append(args)
        cursor = args_end
    pieces.append(script_content[cursor:])
    return ''.join(pieces)


def _time_magic_body(statement):
    """
    Return the code wrapped by a one-line ``%time`` / ``%%time`` call, or None.

    Recognises ``get_ipython().run_line_magic('time', <code>)`` and
    ``get_ipython().run_cell_magic('time', <line>, <code>)``. The wrapped
    text is returned only when it is itself valid Python, since the body of
    a cell magic is raw, untransformed cell text.
    """
    try:
        tree = ast.parse(statement)
    except (SyntaxError, ValueError):
        return None
    if len(tree.body) != 1 or not isinstance(tree.body[0], ast.Expr):
        return None
    call = tree.body[0].value
    if not (isinstance(call, ast.Call) and isinstance(call.func, ast.Attribute)):
        return None
    receiver = call.func.value
    if not (isinstance(receiver, ast.Call)
            and isinstance(receiver.func, ast.Name)
            and receiver.func.id == 'get_ipython'):
        return None
    try:
        args = [ast.literal_eval(arg) for arg in call.args]
    except (ValueError, TypeError, SyntaxError):
        return None

    if call.func.attr == 'run_line_magic' and len(args) == 2 and args[0] == 'time':
        body = args[1]
    elif call.func.attr == 'run_cell_magic' and len(args) == 3 and args[0] == 'time':
        body = args[2]
    else:
        return None

    if not isinstance(body, str) or not body.strip():
        return None
    body = body.rstrip()
    try:
        ast.parse(body)
    except (SyntaxError, ValueError):
        return None
    return body


def _restore_time_magic_code(script_content):
    """Replace each ``%time`` / ``%%time`` call line with the code it wraps, keeping its indentation."""
    restored_lines = []
    for line in script_content.split('\n'):
        stripped = line.strip()
        body = _time_magic_body(stripped) if stripped.startswith('get_ipython()') else None
        if body is None:
            restored_lines.append(line)
            continue
        indent = line[:len(line) - len(line.lstrip())]
        restored_lines.extend(
            indent + part if part.strip() else part for part in body.split('\n')
        )
    return '\n'.join(restored_lines)


def _find_import_block_end(lines):
    """
    Return the index of the first line after the leading import statements.

    Parenthesised imports (``from x import (\\n a,\\n)``) and backslash
    continuations are followed to their end. A single-line
    ``from x import (a, b)`` is complete on its own line. Blank lines are
    skipped. Any other non-empty line, including a comment, ends the import
    block. If the script contains only imports, the end of the script is
    returned.
    """
    open_parens = 0
    continued = False
    for idx, line in enumerate(lines):
        stripped = line.strip()
        code = stripped.split('#', 1)[0].rstrip()
        if open_parens > 0 or continued:
            # Still inside a multi-line import statement.
            open_parens = max(open_parens + code.count('(') - code.count(')'), 0)
            continued = code.endswith('\\')
            continue
        if not stripped:
            continue
        if stripped.startswith(('import ', 'from ')):
            open_parens = max(code.count('(') - code.count(')'), 0)
            continued = code.endswith('\\')
            continue
        return idx
    return len(lines)


def normalise_notebook_code(script_content):
    """
    Convert IPython-specific constructs in an exported notebook into standalone Python.

    Steps, in order:
      1. Unwrap ``%time`` / ``%%time`` magics so the timed code still runs.
         Previously the wrapped statements were deleted.
      2. Replace ``%matplotlib inline`` with a non-interactive Agg setup.
      3. Strip keyword arguments that CreateOPEInput does not accept, but
         only inside ``CreateOPEInput(...)`` calls.
      4. Reset ``MeanQFunctionFactory(...)`` calls carrying unsupported
         parameters to ``MeanQFunctionFactory()``.
      5. Insert protective comments before CreateOPEInput and
         ``MeanQFunctionFactory()`` usages.
      6. Replace the remaining IPython magics and shell escapes.
      7. Insert a compatibility preamble (warnings filter, Agg backend,
         no-op ``plt.show``) after the leading import statements.

    Args:
        script_content: Python source produced by nbconvert's PythonExporter.

    Returns:
        The transformed source.
    """
    script_content = _restore_time_magic_code(script_content)

    # Replace matplotlib inline magic command
    script_content = re.sub(
        r"get_ipython\(\)\.run_line_magic\('matplotlib',\s*'inline'\)",
        "# Matplotlib non-interactive mode\nimport matplotlib\nmatplotlib.use('Agg')  # Non-interactive backend\nimport matplotlib.pyplot as plt\nplt.ioff()  # Turn off interactive mode",
        script_content
    )

    # Valid CreateOPEInput parameters: env, model_args, gamma, bandwidth, state_scaler, action_scaler, device
    script_content = _strip_invalid_create_ope_input_params(script_content)

    # Valid MeanQFunctionFactory parameters: share_encoder (bool, default=False).
    # n_critics / n_quantiles / bootstrap do not exist on it.
    for invalid_arg in (r'n_critics\s*=\s*\d+', r'n_quantiles\s*=\s*\d+', r'bootstrap\s*=\s*\w+'):
        script_content = re.sub(
            rf'MeanQFunctionFactory\([^)]*{invalid_arg}[^)]*\)',
            'MeanQFunctionFactory()',
            script_content
        )

    # Add protective comment before CreateOPEInput calls
    script_content = re.sub(
        r'(prep = CreateOPEInput\()',
        r'# MODIFIABLE SECTION: CreateOPEInput parameters can be optimised\n# VALID PARAMETERS: env, model_args, gamma, bandwidth, state_scaler, action_scaler, device\n# CRITICAL: DO NOT add non-existent parameters like n_bootstrap_samples, n_trajectories_on_policy_evaluation\n# Note: n_trajectories_on_policy_evaluation belongs to obtain_whole_inputs() method, not CreateOPEInput constructor\n\1',
        script_content
    )

    # Add protective comments before MeanQFunctionFactory calls
    script_content = re.sub(
        r'("q_func_factory":\s*MeanQFunctionFactory\(\))',
        r'# CRITICAL: DO NOT add parameters to MeanQFunctionFactory() - only share_encoder (bool) is valid\n            # FIXED: Parameters like n_critics, n_quantiles do NOT exist in this API\n            \1',
        script_content
    )
    script_content = re.sub(
        r'(q_func_factory=MeanQFunctionFactory\(\))',
        r'# CRITICAL: DO NOT add parameters to MeanQFunctionFactory() - only share_encoder (bool) is valid\n    # FIXED: Parameters like n_critics, n_quantiles do NOT exist in this API\n    \1',
        script_content
    )

    # Replace remaining IPython magic commands; order matters (specific before generic).
    replacements = {
        r"get_ipython\(\)\.run_line_magic\('time',.*?\)": "# Time magic removed",
        r"get_ipython\(\)\.run_line_magic\('timeit',.*?\)": "# Timeit magic removed",
        r"get_ipython\(\)\.system\('(.+?)'\)": r'import subprocess; subprocess.run("\1", shell=True)',
        r"get_ipython\(\)\.run_line_magic\('load',.*?\)": "# Load magic removed",
        r"get_ipython\(\)\.run_line_magic\('.*?',.*?\)": "# IPython magic removed",
        r"get_ipython\(\)\.run_cell_magic\('.*?',.*?\)": "# IPython cell magic removed",
    }
    for pattern, replacement in replacements.items():
        script_content = re.sub(pattern, replacement, script_content)

    compatibility_fixes = """# COMPATIBILITY FIXES
import warnings
warnings.filterwarnings('ignore', category=DeprecationWarning)

# Fix matplotlib to prevent plots from displaying
import matplotlib
matplotlib.use('Agg')  # Non-interactive backend
import matplotlib.pyplot as plt
plt.ioff()  # Turn off interactive mode

# Override plt.show to prevent any plots from displaying
original_show = plt.show
def no_show(*args, **kwargs):
    pass  # Do nothing - plots are saved but not displayed
plt.show = no_show

"""

    # Insert at a statement boundary so the column-0 preamble never lands
    # inside an indented body or a multi-line import.
    lines = script_content.split('\n')
    lines.insert(_find_import_block_end(lines), compatibility_fixes)
    return '\n'.join(lines)

def _model_loader_snippet(var_name: str, model_name: str, model_path, hidden_units: List[int],
                          config_name: str, config_args: str) -> str:
    """
    Render generated code that rebuilds one pre-trained d3rlpy model and loads its weights.

    ``model_path`` and the log message are emitted through ``repr`` so that
    any path (Windows backslashes, quotes, trailing separators) or model
    name becomes a valid Python string literal.
    """
    path_literal = repr(str(model_path))
    message_literal = repr(f"Loaded model: {model_name}")
    return f"""
# Load {model_name} with architecture: {hidden_units}
# CRITICAL: DO NOT MODIFY ARCHITECTURES - Models MUST be loaded with their trained architectures
# Changing these values will cause RuntimeError: size mismatch
{var_name} = {config_name}(
{config_args}
).create(device=device)
{var_name}.build_with_env(env)
{var_name}.load_model({path_literal})
print({message_literal})
"""


def generate_training_bypass_injection(notebook_path: str, datasets: dict, models: dict, notebook_node=None) -> str:
    """
    Generate injection code for TRAINING_BYPASS mode that loads datasets and pre-trained models.

    The returned source creates the environment, unpickles the train and test
    logged datasets, rebuilds each model with its trained architecture and
    loads its weights, and creates a random baseline policy. The expensive
    training cells of the notebook can then be skipped.

    Args:
        notebook_path: Notebook path. A directory named ``basic``, ``rec`` or
            ``rtb`` selects the environment; ``continuous`` in the file name
            selects the action type.
        datasets: Mapping with optional ``'train'`` / ``'test'`` dataset paths.
        models: Mapping of model artifact name -> weights path.
        notebook_node: Parsed notebook. If given and it contains the
            CreateOPEInput import cell, only the policy-head imports used
            after that cell are emitted; otherwise all of them are.

    Returns:
        Python source text for the injection cell.
    """
    notebook_name = os.path.splitext(os.path.basename(notebook_path))[0]

    # Match whole directory names, accepting both '/' and '\' separators so
    # Windows paths and relative paths like 'rec/x.ipynb' are classified
    # correctly. Precedence (basic > rec > rtb) is unchanged.
    directories = re.split(r'[\\/]+', notebook_path)[:-1]
    if 'basic' in directories:
        env_type = 'basic'
    elif 'rec' in directories:
        env_type = 'rec'
    elif 'rtb' in directories:
        env_type = 'rtb'
    else:
        env_type = 'basic'  # fallback

    action_type = 'continuous' if 'continuous' in notebook_name.lower() else 'discrete'
    is_continuous = action_type == 'continuous'

    if env_type == 'basic':
        env_code = f'env = gym.make("BasicEnv-{action_type}-v0")'
    elif env_type == 'rec':
        env_code = 'env = gym.make("RECEnv-v0")'
    else:  # rtb
        env_code = f'env = gym.make("RTBEnv-{action_type}-v0")'

    # repr() always yields a valid literal; r'...' broke on quotes or a trailing backslash.
    train_literal = repr(str(datasets.get('train', '')))
    test_literal = repr(str(datasets.get('test', '')))
    dataset_code = f"""
# Load pre-computed datasets
print("Loading datasets from artifacts...")
with open({train_literal}, 'rb') as f:
    train_logged_dataset = pickle.load(f)
with open({test_literal}, 'rb') as f:
    test_logged_dataset = pickle.load(f)
print("Datasets loaded successfully.")
"""

    # Detect which policy-head imports the OPE cells need.
    anchor_idx = -1
    if notebook_node is not None:
        anchor_idx = next(
            (i for i, cell in enumerate(notebook_node.cells)
             if cell.cell_type == 'code' and 'from scope_rl.ope import CreateOPEInput' in cell.source),
            -1,
        )
    if anchor_idx != -1:
        extra_imports = detect_required_imports(notebook_path, anchor_idx, notebook_node)
    else:
        # Fallback: include all possible imports
        extra_imports = [
            'from scope_rl.policy import SoftmaxHead',
            'from scope_rl.policy import ContinuousEvalHead',
            'from scope_rl.policy import TruncatedGaussianHead'
        ]

    action_scaler_args = (
        "    action_scaler=MinMaxActionScaler(\n"
        "        minimum=env.action_space.low,\n"
        "        maximum=env.action_space.high,\n"
        "    )"
    )

    model_imports = set()
    loaded_models = set()  # variable names already used, to keep them unique
    loader_snippets = []

    for model_name, model_path in models.items():
        hidden_units = get_model_architecture_from_notebook(notebook_path, model_name)

        clean_name = model_name.replace('_continuous', '').replace('_discrete', '')
        if clean_name in loaded_models:
            clean_name = model_name  # Use full name to make it unique
        loaded_models.add(clean_name)

        encoder = f"VectorEncoderFactory(hidden_units={hidden_units}),  # FIXED: DO NOT CHANGE"
        actor_critic_args = (
            f"    actor_encoder_factory={encoder}\n"
            f"    critic_encoder_factory={encoder}"
        )
        discrete_args = f"    encoder_factory={encoder}\n    q_func_factory=MeanQFunctionFactory()"
        continuous_q_args = f"{actor_critic_args}\n    q_func_factory=MeanQFunctionFactory(),\n{action_scaler_args}"
        continuous_args = f"{actor_critic_args}\n{action_scaler_args}"

        # Same precedence as before: sac, cql, ddqn, td3bc, iql (substring match).
        lowered = model_name.lower()
        config_name = None
        config_args = None
        if 'sac' in lowered:
            if is_continuous:
                model_imports.add('from d3rlpy.algos import SACConfig')
                config_name, config_args = 'SACConfig', continuous_q_args
        elif 'cql' in lowered:
            if is_continuous:
                model_imports.add('from d3rlpy.algos import CQLConfig')
                config_name, config_args = 'CQLConfig', continuous_q_args
            else:
                model_imports.add('from d3rlpy.algos import DiscreteCQLConfig')
                config_name, config_args = 'DiscreteCQLConfig', discrete_args
        elif 'ddqn' in lowered:
            model_imports.add('from d3rlpy.algos import DoubleDQNConfig')
            config_name, config_args = 'DoubleDQNConfig', discrete_args
        elif 'td3bc' in lowered:
            model_imports.add('from d3rlpy.algos import TD3PlusBCConfig')
            if is_continuous:
                config_name, config_args = 'TD3PlusBCConfig', continuous_args
        elif 'iql' in lowered:
            model_imports.add('from d3rlpy.algos import IQLConfig')
            config_name, config_args = 'IQLConfig', (continuous_args if is_continuous else discrete_args)

        if config_name is None:
            # Previously skipped silently, which surfaced later as a NameError.
            print(f"Warning: no loader generated for model '{model_name}' ({action_type} action space); skipping it.")
            continue

        loader_snippets.append(
            _model_loader_snippet(clean_name, model_name, model_path, hidden_units, config_name, config_args)
        )

    model_loading_code = ''.join(loader_snippets)

    if is_continuous:
        model_imports.add('from d3rlpy.algos import RandomPolicyConfig as ContinuousRandomPolicyConfig')
        random_policy_code = """
# Create random policy for continuous action space
random = ContinuousRandomPolicyConfig().create(device=device)
random.build_with_env(env)
"""
    else:
        model_imports.add('from d3rlpy.algos import DiscreteRandomPolicyConfig')
        random_policy_code = """
# Create random policy for discrete action space
random = DiscreteRandomPolicyConfig().create(device=device)
random.build_with_env(env)
"""

    # Sorted so the generated script is identical from run to run.
    extra_import_block = '\n'.join(sorted(extra_imports))
    model_import_block = '\n'.join(sorted(model_imports))

    injection_code = f"""
# ==== TRAINING BYPASS MODE: Smart Artifact Loading ====
# Loading essential artifacts (models and datasets) for OPE evaluation
# 
# CRITICAL MODEL ARCHITECTURE RULE:
# - Pre-trained models below MUST use their exact trained architectures (DO NOT MODIFY)
# - Only FQE model and other hyperparameters in CreateOPEInput can be optimised
# - Changing pre-trained model architectures causes RuntimeError: size mismatch
print("TRAINING BYPASS MODE: Loading pre-computed datasets and models...")

import pickle
import gym
import torch
import numpy as np
import pandas as pd
import scope_rl
import warnings
warnings.filterwarnings('ignore')

# Environment imports
from basicgym import BasicEnv
from recgym import RECEnv
from rtbgym import RTBEnv

# Core imports
from d3rlpy.models.encoders import VectorEncoderFactory
from d3rlpy.models.q_functions import MeanQFunctionFactory
from d3rlpy.preprocessing import MinMaxObservationScaler, MinMaxActionScaler

# SCOPE-RL policy imports
from scope_rl.policy import EpsilonGreedyHead
{extra_import_block}

# Model imports
{model_import_block}

# Fixed configuration
device = "cuda:0" if torch.cuda.is_available() else "cpu"
random_state = 12345

# Create logs directory
import os
from pathlib import Path
Path("logs/").mkdir(exist_ok=True)

# Create environment
{env_code}

{dataset_code}

{model_loading_code}

{random_policy_code}

print("TRAINING BYPASS MODE: All artifacts loaded successfully")
print("Ready for OPE evaluation - hyperparameters can be optimised in CreateOPEInput")
"""

    return injection_code

def filter_redundant_cells(cells: list, anchor_idx: int) -> list:
    """
    Drop post-anchor code cells that only re-create, re-load or re-train models.

    Only cells from ``anchor_idx`` onwards are considered; earlier cells are
    never returned. Each remaining cell is kept when:
      * it is not a code cell (markdown, raw, ...), or
      * its stripped source is empty, or
      * it contains any OPE-evaluation marker. OPE logic wins even when the
        cell also holds model/training code, so that OPE setup is not broken.
    Otherwise the cell is dropped if it contains any redundant-code marker
    (model construction/loading, training, dataset generation).

    Args:
        cells: Notebook cells exposing ``.cell_type`` and ``.source``.
        anchor_idx: Index of the first cell to consider.

    Returns:
        The kept cells, in their original order.
    """
    redundant_markers = (
        '.create(device=device)',
        '.build_with_env(env)',
        '.load_model(',
        'CQLConfig(',
        'SACConfig(',
        'DoubleDQNConfig(',
        'TD3PlusBCConfig(',
        'IQLConfig(',
        'RandomPolicyConfig(',
        'DiscreteRandomPolicyConfig(',
        'VectorEncoderFactory(',
        'MeanQFunctionFactory()',
        'fit(',
        'fit_online(',
        'train(',
        # Dataset generation patterns
        'collect_episodes',
        'ConstantEpsilonGreedy',
        'LinearDecayEpsilonGreedy',
    )
    ope_markers = (
        'CreateOPEInput',
        'prep.obtain_whole_inputs',
        'ope.evaluate_performance_of_ope_estimators',
        'ope.summarize_off_policy_estimates',
        'ope.visualize',
        'plt.',
        'print(',
        'display(',
        'policy_value_df_dict',
        'policy_value_interval_df_dict',
    )

    def _should_keep(cell):
        if cell.cell_type != 'code':
            return True
        code = cell.source.strip()
        if not code:
            return True
        if any(marker in code for marker in ope_markers):
            return True
        return not any(marker in code for marker in redundant_markers)

    return [cell for cell in cells[anchor_idx:] if _should_keep(cell)]

def validate_python_syntax(script_content: str, notebook_path: str) -> bool:
    """
    Check whether ``script_content`` parses as Python source.

    When ``ast.parse`` raises SyntaxError, the error's line number, message
    and offending text are printed, prefixed by ``notebook_path`` for context.

    Args:
        script_content: Generated Python source to check.
        notebook_path: Originating notebook, used only in the report.

    Returns:
        True if the source parses, False on a SyntaxError.
    """
    try:
        ast.parse(script_content)
    except SyntaxError as err:
        report = (
            f"Syntax error in converted script from {notebook_path}:",
            f"  Line {err.lineno}: {err.msg}",
            f"  Text: {err.text}",
        )
        for message in report:
            print(message)
        return False
    return True

def convert_notebook_to_script(notebook_path, output_script_path=None, inject_csv_export=True):
    """
    Convert Jupyter Notebook to Python script with enhanced bypass workflow and automatic CSV export injection.
    """
    # File validation
    if not os.path.isfile(notebook_path):
        raise FileNotFoundError(f"The file {notebook_path} does not exist.")
    
    if not notebook_path.endswith('.ipynb'):
        raise ValueError("The file is not a Jupyter Notebook.")
    
    # Load notebook
    with open(notebook_path, 'r', encoding='utf-8') as f:
        notebook_content = f.read()
    
    # Parse the notebook content into a notebook node object
    notebook_node = nbformat.reads(notebook_content, as_version=4)
    
    # Determine execution mode based on available artifacts
    execution_mode = determine_execution_mode(notebook_path)
    
    if execution_mode == ExecutionMode.TRAINING_BYPASS:
        print("TRAINING BYPASS MODE: Detected datasets and models, bypassing expensive training...")
        
        # Find the CreateOPEInput anchor cell
        anchor_idx = -1
        for i, cell in enumerate(notebook_node.cells):
            if cell.cell_type == 'code' and 'from scope_rl.ope import CreateOPEInput' in cell.source:
                anchor_idx = i
                break
        
        if anchor_idx == -1:
            print("Warning: CreateOPEInput import not found. Falling back to full execution.")
            execution_mode = ExecutionMode.FULL_RUN
        else:
            # Load artifacts
            datasets = find_datasets(notebook_path)
            models = find_models(notebook_path)
            
            # Generate injection code with notebook_node for import detection
            injection_code = generate_training_bypass_injection(notebook_path, datasets, models, notebook_node)
            
            # Create injection cell
            injection_cell = nbformat.v4.new_code_cell(source=injection_code)
            
            # Reconstruct notebook: injection cell + OPE cells from anchor onwards
            # notebook_node.cells = [injection_cell] + notebook_node.cells[anchor_idx:]
            
            # Temporarily disable cell filtering to fix indentation issues
            # filtered_cells = filter_redundant_cells(notebook_node.cells, anchor_idx)
            notebook_node.cells = [injection_cell] + notebook_node.cells[anchor_idx:]
            
            print(f"TRAINING BYPASS MODE: Injected smart bypass code, starting from CreateOPEInput cell")
    
    elif execution_mode == ExecutionMode.FULL_RUN:
        print("FULL RUN MODE: No artifacts detected, executing complete notebook...")
    
    # Convert to Python script
    exporter = PythonExporter()
    script_content, _ = exporter.from_notebook_node(notebook_node)
    
    # Apply post-processing
    script_content = remove_comments(script_content)
    script_content = normalise_notebook_code(script_content)
    
    # Validate syntax before CSV injection
    if not validate_python_syntax(script_content, notebook_path):
        print("Warning: Syntax errors detected before CSV injection. Attempting to proceed...")
    
    # CSV Export Injection (if enabled)
    if inject_csv_export:
        # Check if CSV export already exists (using regex to handle quote variations)
        if not re.search(r"\.to_csv\(['\"]out\.csv['\"]", script_content) and "# AUTO-CSV-EXPORT" not in script_content:
            
            # First, detect notebook type regardless of execution mode
            # Detect variables that need to be exported (OBP pattern)
            obp_pattern = r"(\w*relative_ee\w*)\s*=\s*|(\w+)\s*=\s*ope\.summarize_estimators_comparison"
            obp_matches = re.findall(obp_pattern, script_content)
            obp_variables = {match[0] or match[1] for match in obp_matches if match[0] or match[1]}
            
            # Also detect OBP confidence interval variables (estimated_interval)
            obp_interval_pattern = r"(\w*estimated_interval\w*)\s*=\s*ope\.summarize_off_policy_estimates"
            obp_interval_matches = re.findall(obp_interval_pattern, script_content)
            obp_interval_variables = {match for match in obp_interval_matches if match}
            
            # Combine OBP variables
            all_obp_variables = obp_variables | obp_interval_variables
            
            # Detect SCOPE-RL patterns
            scope_patterns = [
                r"(\w+)\s*=\s*ope\.summarize_off_policy_estimates",
                r"(\w+)\s*=\s*ope\.evaluate_performance_of_ope_estimators"
            ]
            scope_variables = set()
            for pattern in scope_patterns:
                scope_matches = re.findall(pattern, script_content)
                scope_variables.update(scope_matches)
            
            # Also check for tuple unpacking patterns (multiline)
            if "policy_value_df_dict" in script_content:
                scope_variables.add("policy_value_df_dict")
            if "policy_value_interval_df_dict" in script_content:
                scope_variables.add("policy_value_interval_df_dict")
            
            # Determine notebook type
            is_obp_notebook = bool(all_obp_variables)
            is_scope_rl_notebook = bool(scope_variables)
            
            # SCOPE-RL notebooks (with multi-policy support)
            if is_scope_rl_notebook and execution_mode == ExecutionMode.TRAINING_BYPASS:
                print("TRAINING BYPASS MODE: Forcing consistent multi-policy OPE CSV export")
                # Simplified export logic to handle ALL policies in policy_value_df_dict
                export_logic = """# Simplified SCOPE-RL OPE CSV export
exported = False
if 'policy_value_df_dict' in locals() and policy_value_df_dict:
    print("Exporting ALL policies from policy_value_df_dict")
    all_policy_data = []
    
    for policy_name, policy_df in policy_value_df_dict.items():
        # Reset index to get estimator names as a column
        policy_df_reset = policy_df.reset_index()
        policy_df_reset.rename(columns={'index': 'estimator'}, inplace=True)
        
        # Keep all metrics (including relative_policy_value)
        metric_columns = [col for col in policy_df_reset.columns if col != 'estimator']
        
        # Create clean data for each estimator and metric
        for _, row in policy_df_reset.iterrows():
            estimator = row['estimator']
            for metric in metric_columns:
                result_value = row[metric]
                if pd.notna(result_value):  # Only include non-null values
                    all_policy_data.append({
                        'iteration': 0,
                        'policy': policy_name,
                        'estimator': estimator,
                        'metric': metric,
                        'result': result_value,
                        'percentage_change': 0
                    })
    
    # Create DataFrame and remove any duplicates
    combined_df = pd.DataFrame(all_policy_data)
    combined_df = combined_df.drop_duplicates(subset=['policy', 'estimator', 'metric'], keep='first')
    
    # Sort by metric first to group all policy_value together, then all relative_policy_value together
    metric_order = ['policy_value', 'relative_policy_value']  # Custom order: policy_value first, then relative_policy_value
    metric_order_dict = {metric: i for i, metric in enumerate(metric_order)}
    combined_df['metric_order'] = combined_df['metric'].map(metric_order_dict).fillna(999)
    combined_df = combined_df.sort_values(['metric_order', 'policy', 'estimator'], ascending=[True, True, True])
    combined_df = combined_df.drop('metric_order', axis=1)  # Remove temporary sorting column
    
    combined_df = combined_df[['iteration', 'policy', 'estimator', 'metric', 'result', 'percentage_change']]
    combined_df.to_csv('out.csv', index=False)
    print(f"Exported {len(combined_df)} unique policy/estimator/metric combinations from {len(policy_value_df_dict)} policies")
    exported = True

else:
    print("Warning: No policy_value_df_dict found")
    pd.DataFrame({'iteration': [0], 'policy': ['baseline'], 'estimator': ['dm'], 'metric': ['policy_value'], 'result': [1.0], 'percentage_change': [0]}).to_csv('out.csv', index=False)"""
                
                # Indent the export_logic for proper try block
                indented_logic = '\n'.join('    ' + line for line in export_logic.strip().split('\n'))
                
                export_code = f"""
# AUTO-CSV-EXPORT: Consistent OPE Export for Training Bypass Mode
import pandas as pd
try:
{indented_logic}
    print("CSV export completed successfully")
except Exception as e:
    print(f"CSV export failed: {{e}}")
    # Create minimal output to prevent downstream errors
    pd.DataFrame({{'estimator': ['baseline'], 'result': [1.0]}}).set_index('estimator').to_csv('out.csv')
    print("Created fallback CSV output")
"""
            
            elif is_obp_notebook:
                # OBP pattern export logic (legacy format without policy column)
                print("Detected OBP notebook - using legacy CSV format")
                variables = sorted(list(all_obp_variables))
                
                if len(variables) == 1:
                    var_name = variables[0]
                    export_logic = f"""    # Convert OBP dataframe to standard format
    all_data = []
    
    # Handle relative_ee type variables (single dataframe)
    if 'relative_ee' in '{var_name}' and hasattr({var_name}, 'reset_index'):
        df_converted = {var_name}.reset_index()
        estimator_col = 'index' if 'index' in df_converted.columns else df_converted.columns[0]
        
        for _, row in df_converted.iterrows():
            estimator = row[estimator_col]
            for col in df_converted.columns:
                if col != estimator_col:
                    metric = col
                    result = row[col]
                    if pd.notna(result):
                        all_data.append({{
                            'iteration': 0,
                            'metric': metric,
                            'estimator': estimator,
                            'result': result,
                            'percentage change': 0
                        }})
    
    # Handle estimated_interval type variables (confidence intervals)
    elif 'estimated_interval' in '{var_name}' and hasattr({var_name}, 'reset_index'):
        df_converted = {var_name}.reset_index()
        estimator_col = 'index' if 'index' in df_converted.columns else df_converted.columns[0]
        
        for _, row in df_converted.iterrows():
            estimator = row[estimator_col]
            for col in df_converted.columns:
                if col != estimator_col:
                    metric = col
                    result = row[col]
                    if pd.notna(result):
                        all_data.append({{
                            'iteration': 0,
                            'metric': metric,
                            'estimator': estimator,
                            'result': result,
                            'percentage change': 0
                        }})
    
    # Fallback for other variable types
    elif hasattr({var_name}, 'reset_index'):
        df_converted = {var_name}.reset_index()
        estimator_col = 'index' if 'index' in df_converted.columns else df_converted.columns[0]
        
        for _, row in df_converted.iterrows():
            estimator = row[estimator_col]
            for col in df_converted.columns:
                if col != estimator_col:
                    metric = col
                    result = row[col]
                    if pd.notna(result):
                        all_data.append({{
                            'iteration': 0,
                            'metric': metric,
                            'estimator': estimator,
                            'result': result,
                            'percentage change': 0
                        }})
    
    if all_data:
        result_df = pd.DataFrame(all_data)
        # Sort by metric first, then by estimator to group metrics together
        # Custom order: mean first, CIs in middle, relative-ee last
        metric_order = ['mean', '95.0% CI (lower)', '95.0% CI (upper)', 'relative-ee']
        metric_order_dict = {metric: i for i, metric in enumerate(metric_order)}
        result_df['metric_order'] = result_df['metric'].map(metric_order_dict).fillna(999)
        result_df = result_df.sort_values(['metric_order', 'estimator'], ascending=[True, True])
        result_df = result_df.drop('metric_order', axis=1)
        result_df.to_csv('out.csv', index=False)
    else:
        # Fallback for empty data
        pd.DataFrame({{'iteration': [0], 'metric': ['relative-ee'], 'estimator': ['baseline'], 'result': [1.0], 'percentage change': [0]}}).to_csv('out.csv', index=False)"""
                else:
                    # Multiple variables - need to merge them
                    export_logic = """    # Convert multiple OBP dataframes to standard format
    all_data = []
    dataframes_to_merge = []"""
                    
                    for i, var in enumerate(variables):
                        export_logic += f"""
    
    # Process {var}
    if hasattr({var}, 'reset_index'):
        df_{i} = {var}.reset_index()
        estimator_col = 'index' if 'index' in df_{i}.columns else df_{i}.columns[0]
        
        for _, row in df_{i}.iterrows():
            estimator = row[estimator_col]
            for col in df_{i}.columns:
                if col != estimator_col:
                    metric = col
                    result = row[col]
                    if pd.notna(result):
                        all_data.append({{
                            'iteration': 0,
                            'metric': metric,
                            'estimator': estimator,
                            'result': result,
                            'percentage change': 0
                        }})"""
                    
                    export_logic += """
    
    if all_data:
        result_df = pd.DataFrame(all_data)
        # Remove duplicates if any
        result_df = result_df.drop_duplicates(subset=['estimator', 'metric'], keep='first')
        # Sort by metric first, then by estimator to group metrics together
        # Custom order: mean first, CIs in middle, relative-ee last
        metric_order = ['mean', '95.0% CI (lower)', '95.0% CI (upper)', 'relative-ee']
        metric_order_dict = {metric: i for i, metric in enumerate(metric_order)}
        result_df['metric_order'] = result_df['metric'].map(metric_order_dict).fillna(999)
        result_df = result_df.sort_values(['metric_order', 'estimator'], ascending=[True, True])
        result_df = result_df.drop('metric_order', axis=1)
        result_df.to_csv('out.csv', index=False)
    else:
        # Fallback for empty data
        pd.DataFrame({{'iteration': [0], 'metric': ['relative-ee'], 'estimator': ['baseline'], 'result': [1.0], 'percentage change': [0]}}).to_csv('out.csv', index=False)"""
                
                export_code = f"""
# AUTO-CSV-EXPORT: Export results for analysis (OBP format)
import pandas as pd
try:
{export_logic}
    print("CSV export completed successfully")
except Exception as e:
    print(f"CSV export failed: {{e}}")
    # Create minimal output to prevent downstream errors
    pd.DataFrame({{'estimator': ['baseline'], 'result': [1.0]}}).set_index('estimator').to_csv('out.csv')
    print("Created fallback CSV output")
"""
            
            elif is_scope_rl_notebook:
                # SCOPE-RL pattern export logic (full run mode)
                print("Detected SCOPE-RL notebook - using new CSV format")
                var_list = sorted(list(scope_variables))
                
                if len(var_list) == 1:
                    export_logic = f"""
    # Try to export SCOPE-RL results
    if isinstance({var_list[0]}, dict):
        # Handle dictionary output from summarize_off_policy_estimates
        if {var_list[0]}:
            first_key = list({var_list[0]}.keys())[0]
            {var_list[0]}[first_key].rename_axis('estimator').to_csv('out.csv')
        else:
            print("No SCOPE-RL results to export")
    elif hasattr({var_list[0]}, 'to_csv'):
        # Handle DataFrame output
        {var_list[0]}.rename_axis('estimator').to_csv('out.csv')
    else:
        print(f"Cannot export {var_list[0]} - unsupported type: {{type({var_list[0]})}}")"""
                else:
                    # Multiple SCOPE-RL variables - try the first one
                    export_logic = f"""
    # Try to export SCOPE-RL results (first available)
    for var_name, var_obj in [('{ var_list[0]}', {var_list[0]}), ('{ var_list[1] if len(var_list) > 1 else var_list[0]}', {var_list[1] if len(var_list) > 1 else var_list[0]})]:
        if isinstance(var_obj, dict) and var_obj:
            first_key = list(var_obj.keys())[0]
            var_obj[first_key].rename_axis('estimator').to_csv('out.csv')
            print(f"Exported {{var_name}} results")
            break
        elif hasattr(var_obj, 'to_csv'):
            var_obj.rename_axis('estimator').to_csv('out.csv')
            print(f"Exported {{var_name}} results")
            break
    else:
        print("No exportable SCOPE-RL results found")"""
                
                export_code = f"""
# AUTO-CSV-EXPORT: Export results for analysis (SCOPE-RL format)
import pandas as pd
try:
{export_logic}
    print("CSV export completed successfully")
except Exception as e:
    print(f"CSV export failed: {{e}}")
    # Create minimal output to prevent downstream errors
    pd.DataFrame({{'estimator': ['baseline'], 'result': [1.0]}}).set_index('estimator').to_csv('out.csv')
    print("Created fallback CSV output")
"""
            else:
                # No variables detected - skip CSV injection
                print("No OBP or SCOPE-RL patterns detected - skipping CSV export")
                export_code = ""
            
            if export_code:
                # Add to script
                script_content = script_content.rstrip() + '\n' + export_code
                
                # Validate final syntax
                if validate_python_syntax(script_content, notebook_path):
                    print(f"CSV export injection successful for {notebook_path}")
                else:
                    print(f"Warning: CSV export injection may have introduced syntax errors")
    
    # Save script
    if output_script_path is None:
        output_script_path = notebook_path.replace('.ipynb', '.py')
    
    with open(output_script_path, 'w', encoding='utf-8') as f:
        f.write(script_content)
    
    print(f"Notebook converted to script: {output_script_path}")
    print(f"Execution mode: {execution_mode.value}")

if __name__ == "__main__":
    # Command-line entry point: parse the notebook path, the optional output
    # path and the CSV-injection opt-out flag, then run the conversion.
    cli = argparse.ArgumentParser(description="Convert Jupyter Notebook to Python script.")
    cli.add_argument("notebook_path", help="Path to the Jupyter Notebook file.")
    cli.add_argument(
        "output_script_path",
        nargs='?',
        help="Path to save the converted Python script. If not provided, the script will be saved in the same directory as the input file with the same name.",
    )
    cli.add_argument(
        "--no-csv-injection",
        action="store_true",
        help="Disable automatic CSV export injection.",
    )

    cli_args = cli.parse_args()

    # The flag is an opt-out, so CSV injection is enabled unless it was passed.
    convert_notebook_to_script(
        cli_args.notebook_path,
        cli_args.output_script_path,
        inject_csv_export=not cli_args.no_csv_injection,
    )