#!/usr/bin/env python3
"""
Metadata Preprocessing Script for ALS Clinical Data

This script preprocesses ALS clinical metadata using Multiple Imputation by Chained Equations (MICE).
It handles missing data, performs data cleaning, and prepares the dataset for downstream analysis.

Key features:
- Processes multiple clinical data sources
- Handles missing data using MICE imputation
- Performs sensitivity analysis
- One-hot encodes categorical variables
- Provides detailed logging and quality control
"""

import logging
import pandas as pd
import numpy as np
import os
import sklearn
from sklearn.preprocessing import OneHotEncoder, MinMaxScaler
from sklearn.experimental import enable_iterative_imputer  # noqa
from sklearn.impute import IterativeImputer, SimpleImputer
from packaging import version
from datetime import datetime
from sklearn.feature_selection import VarianceThreshold
import argparse

# OneHotEncoder's "return a dense array" flag is called `sparse_output` from
# scikit-learn 1.2 onwards (it replaced the older `sparse` keyword), so pick
# the keyword name that matches the installed version.
sklearn_version = version.parse(sklearn.__version__)
_dense_output_kwarg = (
    'sparse_output' if sklearn_version >= version.parse("1.2") else 'sparse'
)
# Shared dense encoder; categories unseen at fit time are ignored on transform.
encoder = OneHotEncoder(handle_unknown='ignore', **{_dense_output_kwarg: False})

def setup_logging():
    """Configure root logging at INFO level and return this module's logger.

    Records are formatted as ``<asctime> - <levelname> - <message>``.
    ``logging.basicConfig`` does nothing when the root logger already has
    handlers, so calling this more than once only returns the logger again.

    Returns:
        logging.Logger: The logger named after this module.
    """
    log_format = '%(asctime)s - %(levelname)s - %(message)s'
    logging.basicConfig(format=log_format, level=logging.INFO)
    module_logger = logging.getLogger(__name__)
    return module_logger

def read_csv_strip_columns(file_path, dtype=None, **kwargs):
    """
    Load a CSV, trim its header names and normalise participant IDs.

    Besides stripping whitespace from every column name, the
    ``Participant_ID`` values are stripped and upper-cased so IDs from
    different files can be joined reliably.

    Args:
        file_path (str): Path (or buffer) passed to ``pd.read_csv``
        dtype (dict, optional): Column data types forwarded to ``pd.read_csv``
        **kwargs: Any further ``pd.read_csv`` keyword arguments

    Returns:
        pd.DataFrame | None: The cleaned dataframe, or ``None`` when reading
        fails or the file has no usable ``Participant_ID`` column (the error
        is logged).
    """
    try:
        frame = pd.read_csv(file_path, dtype=dtype, **kwargs)
        frame.columns = frame.columns.str.strip()
        participant_ids = frame['Participant_ID'].str.strip()
        frame['Participant_ID'] = participant_ids.str.upper()
    except Exception as e:
        logging.error(f"Error reading {file_path}: {e}")
        return None
    return frame

def get_imputable_columns(df):
    """
    Identifies columns that are eligible for MICE imputation.

    Only ``float64`` / ``int64`` columns are considered. The following are
    always excluded from the returned list:

    - identifiers;
    - ALSFRS-R outcome values and slope;
    - categorical El Escorial / onset fields;
    - dates;
    - the derived ``missing_percentage`` QC column.

    Args:
        df (pd.DataFrame): Input dataframe

    Returns:
        list: Column names to impute, in the dataframe's column order
    """
    # Columns that must never be imputed (or fed to the imputer as predictors).
    exclude_from_imputation = {
        'Participant_ID',
        'ALSFRS_R_Baseline_Value',
        'ALSFRS_R_Latest_Value',
        'ALSFRS_R_PROGRESSION_SLOPE',
        'Site_of_Onset',
        'El_Escorial_Criteria_Combined',
        'REVISED_EL_ESCORIAL_CRITERIA',
        'elescrlr',
        'elescrlr_mapped',
        'dieddt',
        'Visit_Date',
        # Per-row fraction of missing cells: a QC artifact derived from the
        # very missingness being imputed, not a clinical variable.
        'missing_percentage',
    }

    numeric_cols = df.select_dtypes(include=['float64', 'int64']).columns
    imputable_cols = [col for col in numeric_cols if col not in exclude_from_imputation]

    logging.info(f"Columns selected for MICE imputation: {imputable_cols}")
    return imputable_cols

def mice_imputation_sensitivity(df, num_imputations=5, impute_cols=None, sample_posterior=True):
    """
    Perform MICE imputation ``num_imputations`` times for sensitivity analysis.

    Each run fits a fresh ``IterativeImputer`` with ``random_state=seed`` for
    ``seed`` in ``range(num_imputations)``. IterativeImputer only uses its
    random state when sampling from the posterior, for a random imputation
    order, or with ``n_nearest_features``. Without any of those, every run
    produces identical values. With ``sample_posterior=True`` (the default),
    imputed values are drawn from the estimator's predictive posterior, so
    the runs genuinely differ. Pass ``sample_posterior=False`` for
    deterministic, identical runs.

    Columns in ``impute_cols`` with no observed values are skipped and left
    missing. By default IterativeImputer drops such features from its output,
    so keeping them would misalign the result with ``impute_cols``.

    Args:
        df (pd.DataFrame): Input dataframe; it is not modified.
        num_imputations (int): Number of imputation runs.
        impute_cols (list): Columns to impute; they are also the only
            predictors the imputer sees.
        sample_posterior (bool): Forwarded to ``IterativeImputer``.

    Returns:
        list: ``num_imputations`` independent dataframes. If there is nothing
        to impute, or imputation raises (the error is logged), each entry is
        an unimputed copy of ``df``.
    """
    def _unimputed_copies():
        # Separate copies so mutating one result never changes the others.
        return [df.copy() for _ in range(num_imputations)]

    if impute_cols is None or len(impute_cols) == 0:
        logging.warning("No columns specified for imputation")
        return _unimputed_copies()

    impute_cols = list(impute_cols)
    # Unknown column names are not filtered here; they surface as an error
    # inside the try block below and trigger the logged fallback.
    empty_cols = [col for col in impute_cols if col in df.columns and df[col].isna().all()]
    if empty_cols:
        logging.warning(
            f"Skipping {len(empty_cols)} column(s) with no observed values: {empty_cols}"
        )
        impute_cols = [col for col in impute_cols if col not in empty_cols]
        if not impute_cols:
            return _unimputed_copies()

    logging.info(f"Starting MICE imputation on {len(impute_cols)} columns")
    imputed_dfs = []

    try:
        for seed in range(num_imputations):
            imputer = IterativeImputer(
                random_state=seed,
                max_iter=10,
                initial_strategy='mean',
                skip_complete=True,
                sample_posterior=sample_posterior,
            )
            df_imputed = df.copy()
            df_imputed[impute_cols] = imputer.fit_transform(df[impute_cols])
            imputed_dfs.append(df_imputed)
            logging.info(f"Completed imputation iteration {seed + 1}/{num_imputations}")
    except Exception as e:
        logging.error(f"Error during MICE imputation: {e}")
        return _unimputed_copies()

    return imputed_dfs

def encode_categorical_features(df):
    """
    One-hot encode ``Site_of_Onset`` and ``El_Escorial_Criteria_Combined``.

    Each column that is present is replaced by its indicator columns, which
    are appended at the end of the frame and named by
    ``encoder.get_feature_names_out``. Absent columns are skipped. If neither
    column exists, the input dataframe itself is returned.

    The module-level ``encoder`` is refitted for every column. After the call
    it therefore holds the fitted state of the last column encoded.

    Args:
        df (pd.DataFrame): Input dataframe

    Returns:
        pd.DataFrame: Dataframe with the categorical columns one-hot encoded
    """
    for column in ('Site_of_Onset', 'El_Escorial_Criteria_Combined'):
        if column not in df.columns:
            continue
        indicators = pd.DataFrame(
            encoder.fit_transform(df[[column]]),
            columns=encoder.get_feature_names_out([column]),
            index=df.index,
        )
        df = pd.concat([df.drop([column], axis=1), indicators], axis=1)
    return df

def _combine_el_escorial_criteria(final_df, logger):
    """Build 'El_Escorial_Criteria_Combined' and drop its source columns.

    'REVISED_EL_ESCORIAL_CRITERIA' takes priority; empty strings and the
    literal 'None' count as missing. Gaps are filled from 'elescrlr_mapped',
    and anything still missing gets the most frequent combined value. Either
    source column may be absent; a warning is logged instead of failing.
    """
    if 'REVISED_EL_ESCORIAL_CRITERIA' in final_df.columns:
        combined = final_df['REVISED_EL_ESCORIAL_CRITERIA'].replace(['', 'None'], np.nan)
    else:
        logger.warning("'REVISED_EL_ESCORIAL_CRITERIA' column not found")
        combined = pd.Series(np.nan, index=final_df.index, dtype=object)

    if 'elescrlr_mapped' in final_df.columns:
        combined = combined.fillna(final_df['elescrlr_mapped'])
    else:
        logger.warning("'elescrlr_mapped' column not found")

    # mode() is empty when every value is missing; indexing [0] would raise.
    modes = combined.mode()
    if modes.empty:
        logger.warning("No El Escorial criteria available to fill missing values")
    else:
        combined = combined.fillna(modes.iloc[0])

    final_df = final_df.assign(El_Escorial_Criteria_Combined=combined)
    return final_df.drop(
        columns=['REVISED_EL_ESCORIAL_CRITERIA', 'elescrlr', 'elescrlr_mapped'],
        errors='ignore',
    )


def _build_imputation_summary(before_df, after_df, impute_cols):
    """Tabulate per-column missing counts before and after imputation."""
    summary = pd.DataFrame({
        'Column': impute_cols,
        'Missing_Before': [before_df[col].isnull().sum() for col in impute_cols],
        'Total_Records': len(before_df),
    })
    summary['Missing_Percentage'] = (
        summary['Missing_Before'] / summary['Total_Records'] * 100
    )
    # Non-zero values here reveal columns the imputation did not fill.
    summary['Missing_After'] = [after_df[col].isnull().sum() for col in impute_cols]
    return summary


def process_complete_data_with_MICE(input_dir, output_dir, missingness_threshold=0.3, num_imputations=5):
    """
    Main function to process, merge, impute and encode all data sources.

    Steps:

    1. Load the AALS Data Portal, which defines the eligible participants,
       and the other per-domain tables.
    2. Collapse each table to one row per participant and left-join them onto
       the first non-empty table.
    3. Keep only ``CASE-`` participants and drop ``Visit_Date*`` columns.
    4. Combine the El Escorial criteria.
    5. Drop participants whose fraction of missing cells exceeds
       ``missingness_threshold``.
    6. MICE-impute the eligible numeric columns, then one-hot encode the
       categorical ones.
    7. Write an imputation summary and the final CSV to ``output_dir``.

    Args:
        input_dir (str): Directory containing input data files
        output_dir (str): Directory for output files (must already exist;
            ``to_csv`` does not create it)
        missingness_threshold (float): Maximum fraction of missing cells
            allowed per participant (default 0.3).
        num_imputations (int): Number of MICE runs; the first one is kept.

    Returns:
        pd.DataFrame | None: Processed and imputed dataframe, or ``None`` when
        no eligible data is found or any step raises (the error is logged).
    """
    logger = setup_logging()
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

    try:
        logger.info("Starting complete data processing")

        # Process AALS Data Portal and get eligible participants.
        # NOTE(review): the process_* loaders are not defined in the source
        # visible here - confirm they are defined or imported elsewhere.
        dataportal_df, eligible_participants = process_aals_dataportal(input_dir)
        if dataportal_df is None or not eligible_participants:
            logger.error("No eligible participants found in AALS Data Portal")
            return None

        # Process other datasets and filter eligible participants
        datasets = [
            dataportal_df,
            process_aals_dx_hx(input_dir, eligible_participants),
            process_demographics(input_dir, eligible_participants),
            process_niv_log(input_dir, eligible_participants),
            process_mortality(input_dir, eligible_participants),
            process_reflexes(input_dir, eligible_participants),
            process_cns_lability(input_dir, eligible_participants),
            process_family_history(input_dir, eligible_participants),
            process_auxiliary_chemistry(input_dir, eligible_participants),
            process_ashworth_spasticity(input_dir, eligible_participants),
            process_grip_strength(input_dir, eligible_participants),
            process_vital_capacity(input_dir, eligible_participants),
            process_vital_signs(input_dir, eligible_participants),
        ]

        # Drop loaders that failed (None) or returned no rows
        datasets = [df for df in datasets if df is not None and not df.empty]
        if not datasets:
            logger.error("No non-empty datasets available to merge")
            return None

        # One row per participant. GroupBy.first takes the first non-null value
        # of each column, so a row may combine values from different records.
        datasets = [df.groupby('Participant_ID').first().reset_index() for df in datasets]

        # Merge datasets on Participant_ID
        final_df = datasets[0]
        for df in datasets[1:]:
            final_df = pd.merge(final_df, df, on='Participant_ID', how='left')

        # Keep cases only. na=False defensively treats a missing ID as
        # "not a case". .copy() makes later column writes act on an
        # independent frame rather than a slice of the merged one.
        is_case = final_df['Participant_ID'].str.startswith('CASE-', na=False)
        final_df = final_df[is_case].drop_duplicates(subset='Participant_ID').copy()

        # Remove 'Visit_Date' columns
        visit_date_columns = [col for col in final_df.columns if col.startswith('Visit_Date')]
        if visit_date_columns:
            logger.info(f"Removing {len(visit_date_columns)} 'Visit_Date' columns")
            final_df = final_df.drop(columns=visit_date_columns)

        final_df = _combine_el_escorial_criteria(final_df, logger)

        # Per-participant fraction (0-1) of missing cells; the right-hand side
        # is evaluated before the new column exists, so it is not counted.
        final_df['missing_percentage'] = final_df.isnull().mean(axis=1)
        before_missingness = len(final_df)
        final_df = final_df[final_df['missing_percentage'] <= missingness_threshold]
        after_missingness = len(final_df)
        logger.info(
            f"Removed {before_missingness - after_missingness} participants "
            f"due to >{missingness_threshold:.0%} missing data"
        )

        # Get columns eligible for imputation
        impute_cols = get_imputable_columns(final_df)

        # Keep the pre-imputation frame so the summary reports real
        # "before" counts rather than counts taken after imputation.
        pre_imputation_df = final_df
        imputed_datasets = mice_imputation_sensitivity(
            final_df, num_imputations=num_imputations, impute_cols=impute_cols
        )

        # Choose first imputed dataset as final
        final_df = imputed_datasets[0]

        # Save imputation summary
        imputation_summary = _build_imputation_summary(pre_imputation_df, final_df, impute_cols)
        imputation_summary_file = os.path.join(output_dir, f'imputation_summary_{timestamp}.csv')
        imputation_summary.to_csv(imputation_summary_file, index=False)
        logger.info(f"Imputation summary saved to: {imputation_summary_file}")

        # Continue with categorical encoding and final processing
        final_df = encode_categorical_features(final_df)

        # Save final processed data
        output_file = os.path.join(output_dir, f'processed_als_metadata_cases_only_{timestamp}.csv')
        final_df.to_csv(output_file, index=False)
        logger.info(f"Final processed data saved to: {output_file}")

        return final_df
    except Exception as e:
        logger.error(f"Error in process_complete_data: {e}")
        logger.exception("Detailed exception information:")
        return None

def main():
    """Command-line entry point for the metadata preprocessing pipeline.

    Parses ``--input-dir`` / ``--output-dir``, creates the output directory if
    needed, and runs :func:`process_complete_data_with_MICE`.

    Raises:
        SystemExit: With status 1 when processing fails or yields no rows, so
            shells and workflow managers can detect the failure. ``argparse``
            also raises SystemExit on invalid arguments.
    """
    parser = argparse.ArgumentParser(description='Process ALS clinical metadata using MICE imputation')
    parser.add_argument('--input-dir', required=True, help='Directory containing input data files')
    parser.add_argument('--output-dir', required=True, help='Directory for output files')
    args = parser.parse_args()

    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)

    logger = setup_logging()
    logger.info("Starting ALS metadata preprocessing with MICE Imputation")

    final_df = process_complete_data_with_MICE(args.input_dir, args.output_dir)

    if final_df is None or final_df.empty:
        logger.error("Processing failed or no data available")
        # Previously the script still exited with status 0 here.
        raise SystemExit(1)

    logger.info(f"Successfully processed {len(final_df)} participants")
    logger.info(f"Final dataset shape: {final_df.shape}")

if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()