import os
import pandas as pd
import numpy as np
from scipy.special import erfcinv
import argparse

# Example invocation:
#   python sequential_threshold_processing.py \
#       results/constrained_search_results_step1 \
#       results/constrained_search_results_step2 \
#       results/constrained_search_results_step3 \
#       results/constrained_search_results_step4 \
#       --output results/constrained_search_results_combined

def gaussian_noise_zscore_cutoff(num_ccg: int, false_positives: float = 0.005) -> float:
    """Determines the z-score cutoff based on Gaussian noise model and number of pixels.

    NOTE: This procedure assumes that the z-scores (normalized maximum intensity
    projections) follow a standard normal distribution. Under this model, the
    cutoff is chosen so that the expected number of false positives across all
    cross-correlograms is at most 'false_positives'.

    Parameters
    ----------
    num_ccg : int
        Total number of cross-correlograms calculated during template matching. Product
        of the number of pixels, number of defocus values, and number of orientations.
    false_positives : float, optional
        Expected number of false positives to allow over all cross-correlograms.
        The default of 0.005 corresponds to roughly a 0.5% chance of a single
        false positive in the image.

    Returns
    -------
    float
        Z-score cutoff.
    """
    # One-tailed tail probability per correlogram: p = false_positives / num_ccg.
    # Invert the complementary error function: z = sqrt(2) * erfcinv(2 * p).
    return float(np.sqrt(2.0) * erfcinv(2.0 * false_positives / num_ccg))
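
# Worked example with hypothetical numbers: for 1e9 total cross-correlograms and
# the default 0.005 expected false positives, the per-correlogram tail
# probability is p = 0.005 / 1e9 = 5e-12, which puts the cutoff near 6.8:
#
#   gaussian_noise_zscore_cutoff(1_000_000_000)  # ~6.81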

def process_files_sequentially(input_files, output_basename, false_positive_rate=0.005):
    """
    Process files sequentially, calculating thresholds based on cumulative correlations
    and updating particle parameters.
    
    Parameters
    ----------
    input_files : list of str
        Ordered list of input CSV files to process.
    output_basename : str
        Base name for output files (without extension).
    false_positive_rate : float
        Expected number of false positives used for the threshold calculation.

    Returns
    -------
    pd.DataFrame
        All retained particles, with accumulated orientation offsets and a
        'step' column recording the last step that added or updated each particle.
    """
    # DataFrame accumulating particles retained across all steps
    all_particles = pd.DataFrame()

    # Running total of cross-correlations over the processed steps
    total_correlations = 0
    
    # Process each file in order
    for step_idx, input_file in enumerate(input_files):
        step_num = step_idx + 1
        print(f"\nProcessing Step {step_num}: {input_file}")
        
        try:
            # Read the results file
            results_df = pd.read_csv(input_file)
            
            if results_df.empty:
                print(f"  Warning: Empty results file {input_file}")
                continue
            
            # Companion parameters file: <results basename>_parameters.csv
            base, _ = os.path.splitext(input_file)
            params_file = f"{base}_parameters.csv"
            if os.path.exists(params_file):
                try:
                    params_df = pd.read_csv(params_file)
                    if not params_df.empty and 'num_correlations' in params_df.columns:
                        correlations = int(params_df.iloc[0]['num_correlations'])
                        total_correlations += correlations
                        print(f"  Added {correlations} correlations (total: {total_correlations})")
                except Exception as e:
                    print(f"  Error reading parameters file {params_file}: {e}")
            
            # Calculate threshold from the cumulative correlation count
            if total_correlations <= 0:
                print(f"  Warning: no correlation count available yet, skipping {input_file}")
                continue
            threshold = gaussian_noise_zscore_cutoff(total_correlations, false_positive_rate)
            print(f"  Threshold for step {step_num}: {threshold:.4f} (based on {total_correlations} total correlations)")
            
            # Fall back to scaled_mip when the refined column is absent
            if 'refined_scaled_mip' not in results_df.columns:
                print(f"  Warning: refined_scaled_mip column not found in {input_file}, using scaled_mip instead")
                compare_col = 'scaled_mip'
            else:
                compare_col = 'refined_scaled_mip'
            
            # Filter particles above threshold
            above_threshold_df = results_df[results_df[compare_col] > threshold].copy()
            
            if above_threshold_df.empty:
                print(f"  No particles above threshold in {input_file}")
                continue
            
            # Print stats
            print(f"  {len(above_threshold_df)} of {len(results_df)} particles above threshold (using {compare_col})")
            
            # Add a step column to track which step this is from
            above_threshold_df['step'] = step_num
            
            # First batch of retained particles: take them all as-is (this also
            # covers the case where earlier steps yielded no particles)
            if all_particles.empty:
                all_particles = above_threshold_df
            else:
                # For each particle in the new results
                for _, particle in above_threshold_df.iterrows():
                    particle_idx = particle['particle_index']
                    
                    # Check if this particle exists in our previous results
                    existing_particle = all_particles[all_particles['particle_index'] == particle_idx]
                    
                    if len(existing_particle) > 0:
                        # Particle exists, update its parameters in place
                        idx_to_update = existing_particle.index[0]
                        
                        # Orientation offsets accumulate across steps
                        offset_cols = ['original_offset_phi', 'original_offset_theta', 'original_offset_psi']

                        # Ensure the cumulative offset columns exist (initialized to 0.0)
                        for col in offset_cols:
                            if col not in all_particles.columns:
                                all_particles[col] = 0.0
                        
                        # Add offset values from current step to existing values
                        for col in offset_cols:
                            if col in particle and pd.notna(particle[col]):
                                all_particles.at[idx_to_update, col] += particle[col]
                        
                        # Update all other parameters
                        for col in particle.index:
                            if col not in offset_cols and pd.notna(particle[col]):
                                all_particles.at[idx_to_update, col] = particle[col]
                        
                        # Update step
                        all_particles.at[idx_to_update, 'step'] = step_num
                    else:
                        # New particle, add it to the DataFrame
                        all_particles = pd.concat([all_particles, pd.DataFrame([particle])], ignore_index=True)
            
            # Save intermediate results: particles added or updated in this step
            step_particles = all_particles[all_particles['step'] == step_num]
            if not step_particles.empty:
                output_file = f"{output_basename}_step{step_num}.csv"
                step_particles.to_csv(output_file, index=False)
                print(f"  Saved {len(step_particles)} particles for step {step_num}")
            
        except Exception as e:
            print(f"  Error processing file {input_file}: {e}")
    
    # Save final results
    if not all_particles.empty:
        output_file = f"{output_basename}_final.csv"
        all_particles.to_csv(output_file, index=False)
        
        # Save summary (final_threshold is the cutoff from the last processed step)
        summary_data = {
            'total_particles': len(all_particles),
            'total_correlations': total_correlations,
            'final_threshold': threshold
        }
        summary_df = pd.DataFrame([summary_data])
        summary_df.to_csv(f"{output_basename}_summary.csv", index=False)
        
        print(f"\nProcessing complete. Final results saved to {output_file}")
        print(f"Total particles: {len(all_particles)}")
    
    return all_particles
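
# Expected file layout, inferred from the reads above (these are the column
# names the script looks for, not a formal spec): each results CSV carries a
# 'particle_index' column and a 'refined_scaled_mip' (or fallback 'scaled_mip')
# column, plus optionally the cumulative offset columns; a sibling file named
# '<results basename>_parameters.csv' holds a 'num_correlations' count for the
# step. A minimal, hypothetical step-1 pair could be written like this:
#
#   pd.DataFrame({
#       "particle_index": [0, 1],
#       "refined_scaled_mip": [8.2, 7.1],
#       "original_offset_phi": [0.0, 0.0],
#       "original_offset_theta": [0.0, 0.0],
#       "original_offset_psi": [0.0, 0.0],
#   }).to_csv("step1.csv", index=False)
#   pd.DataFrame({"num_correlations": [1_000_000_000]}).to_csv(
#       "step1_parameters.csv", index=False)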

def main():
    parser = argparse.ArgumentParser(description='Process results files sequentially')
    parser.add_argument('input_files', nargs='+', help='Ordered list of input CSV files to process')
    parser.add_argument('--output', '-o', required=True, help='Output file basename (without extension)')
    parser.add_argument('--false-positive-rate', '-f', type=float, default=0.005, 
                        help='False positive rate for threshold calculation (default: 0.005)')
    
    args = parser.parse_args()
    
    # Check if all files exist
    for input_file in args.input_files:
        if not os.path.exists(input_file):
            print(f"Error: File {input_file} does not exist!")
            return
    
    # Process files
    process_files_sequentially(args.input_files, args.output, args.false_positive_rate)

if __name__ == "__main__":
    main()