#!/usr/bin/env python3
import click
import logging
import pandas as pd
import numpy as np
import re
import yaml
from pathlib import Path
from typing import Dict, List, Optional, Union, Set
from datetime import datetime


# Configure logging
# basicConfig installs a root handler at import time; every record is prefixed
# with its timestamp, logger name and level.
logging.basicConfig(level=logging.INFO, 
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# Module-wide logger shared by the processing functions in this file.
logger = logging.getLogger('data_scrubber')


def load_config(config_path: Path) -> Dict:
    """
    Load configuration from a YAML file.

    Parameters
    ----------
    config_path : Path
        Path to the configuration YAML file

    Returns
    -------
    Dict
        Configuration dictionary. An empty YAML document yields an empty
        dictionary so callers can safely use ``config.get(...)``.

    Raises
    ------
    ValueError
        If the top level of the YAML document is not a mapping
    """
    # Decode as UTF-8 explicitly: the platform default encoding could mangle
    # non-ASCII text in the config (e.g. unit symbols in column names).
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)

    # safe_load returns None for an empty document.
    if config is None:
        return {}

    if not isinstance(config, dict):
        raise ValueError(
            f"Configuration file {config_path} must contain a mapping at the top level, "
            f"got {type(config).__name__}"
        )

    return config


def find_matching_columns(df: pd.DataFrame, column_names: List[str]) -> Dict[str, str]:
    """
    Find flexible matches for column names in a dataframe.

    Each configured name is tried against the dataframe's columns in a fixed
    order and the first variant that exists wins:

    1. the name exactly as written,
    2. the name with a literal backslash-n (two characters) replaced by a newline,
    3. the name with a literal backslash-n replaced by a Windows CRLF,
    4. the name with every line break normalised to a bare newline,
    5. the name with every line break normalised to CRLF.

    Parameters
    ----------
    df : pd.DataFrame
        The dataframe to search in
    column_names : List[str]
        List of column names to find matches for

    Returns
    -------
    Dict[str, str]
        Dictionary mapping from config column names to actual dataframe column names.
        Names with no matching column are omitted.
    """
    df_cols_set = set(df.columns)
    matches: Dict[str, str] = {}

    for config_col in column_names:
        # Try exact match first
        if config_col in df_cols_set:
            matches[config_col] = config_col
            continue

        # Newline variants only make sense for text; other key types
        # (e.g. integers parsed from YAML) simply stay unmatched.
        if not isinstance(config_col, str):
            continue

        unescaped = config_col.replace('\\n', '\n')
        unix_style = unescaped.replace('\r\n', '\n')
        candidates = (
            unescaped,
            config_col.replace('\\n', '\r\n'),
            # The two variants below cover a config written with one real
            # line-break style while the spreadsheet header uses the other.
            unix_style,
            unix_style.replace('\n', '\r\n'),
        )
        for candidate in candidates:
            if candidate in df_cols_set:
                matches[config_col] = candidate
                break

    return matches


def drop_incomplete_rows(df: pd.DataFrame, config: Dict) -> pd.DataFrame:
    """
    Drop rows where required non-null columns are missing.

    Parameters
    ----------
    df : pd.DataFrame
        The dataframe to process
    config : Dict
        Configuration dictionary; the column list is read from
        ``required_non_null_columns``. A missing, empty or null entry means
        no rows are dropped.

    Returns
    -------
    pd.DataFrame
        Dataframe with incomplete rows dropped. When rows are filtered the
        result is an independent copy, so later in-place edits do not act on
        a slice of the input.

    Raises
    ------
    KeyError
        If a configured required column is not present in the dataframe
    """
    # `or []` also covers a key that is present but set to null in the YAML.
    required_columns = config.get('required_non_null_columns') or []

    dropped_count = 0
    if required_columns:
        # True for every row that has at least one required value missing
        null_mask = df[required_columns].isna().any(axis=1)
        dropped_count = int(null_mask.sum())

        # Keep only the rows where the mask is False (nothing missing)
        df = df.loc[~null_mask].copy()

    logger.info(f"Dropped {dropped_count} rows that had values missing from required columns")

    return df
    

def drop_columns(df: pd.DataFrame, config: Dict) -> pd.DataFrame:
    """
    Remove the columns listed under ``drop_columns`` in the configuration.

    Configured names are resolved against the dataframe with
    ``find_matching_columns``; names that cannot be resolved are reported
    with a warning and otherwise ignored.

    Parameters
    ----------
    df : pd.DataFrame
        The dataframe to process
    config : Dict
        Configuration dictionary

    Returns
    -------
    pd.DataFrame
        Dataframe without the resolved columns, or the input dataframe
        itself when nothing could be resolved
    """
    requested = config.get('drop_columns', [])
    resolved = find_matching_columns(df, requested)

    # Report every configured name that has no counterpart in the dataframe
    for col in requested:
        if col not in resolved:
            logger.warning(f"Column '{col}' specified in drop_columns not found in dataframe")

    targets = list(resolved.values())
    if not targets:
        logger.warning("No columns found to drop")
        return df

    df = df.drop(columns=targets)
    logger.info(f"Dropped {len(targets)} columns")
    return df


def print_column_diagnostics(df: pd.DataFrame):
    """
    Log every column name in its ``repr`` form at DEBUG level.

    Using ``repr`` makes newlines and other special characters inside the
    column names visible in the log output.

    Parameters
    ----------
    df : pd.DataFrame
        The dataframe to analyze
    """
    logger.debug("Column name diagnostics:")
    for position, shown in enumerate(map(repr, df.columns)):
        logger.debug(f"Column {position}: {shown}")


def rename_columns(df: pd.DataFrame, config: Dict) -> pd.DataFrame:
    """
    Apply the ``rename_columns`` mapping from the configuration.

    The mapping keys are resolved against the dataframe with
    ``find_matching_columns``; keys without a matching column are listed in
    a single warning.

    Parameters
    ----------
    df : pd.DataFrame
        The dataframe to process
    config : Dict
        Configuration dictionary; must contain ``rename_columns``

    Returns
    -------
    pd.DataFrame
        Dataframe with renamed columns, or the input dataframe itself when
        no configured column was found
    """
    # Emit the raw column names first so mismatches can be debugged
    print_column_diagnostics(df)

    wanted = config['rename_columns']
    resolved = find_matching_columns(df, list(wanted.keys()))

    # Translate "config name -> new name" into "actual column -> new name"
    rename_map = {}
    for config_name, actual_name in resolved.items():
        rename_map[actual_name] = wanted[config_name]

    not_found = [name for name in wanted if name not in resolved]
    if not_found:
        logger.warning(f"Some columns specified in config were not found in the dataframe: {', '.join(not_found)}")

    if not rename_map:
        logger.warning("No columns found to rename")
        return df

    logger.info("Renaming columns with map from config")
    df = df.rename(columns=rename_map)
    logger.info(f"Renamed {len(rename_map)} columns")
    return df


def anonymize_labelers(df: pd.DataFrame, config: Dict) -> pd.DataFrame:
    """
    Anonymize the Assignments/labeler columns by mapping names to codes.

    The columns 'Test Entry Assignment' and 'Test Review Assignment' are
    processed when present. Names are stripped of surrounding whitespace,
    replaced via the ``labeler_codes`` mapping and stored as nullable
    integers. Names without a mapping, and cells that were empty, end up as
    missing values, so a real name is never written to the result.

    Parameters
    ----------
    df : pd.DataFrame
        The dataframe to process (the assignment columns are overwritten in place)
    config : Dict
        Configuration dictionary; the name-to-code mapping is read from
        ``labeler_codes``

    Returns
    -------
    pd.DataFrame
        Dataframe with anonymized labeler information
    """
    # A missing or null mapping is treated as empty: every name is then
    # unmapped and blanked rather than raising or leaking through.
    mapping = config.get('labeler_codes') or {}
    if not mapping:
        logger.warning("No labeler_codes mapping found in config; assignment columns will be blanked")

    assignment_columns = ['Test Entry Assignment', 'Test Review Assignment']

    for col_name in assignment_columns:
        if col_name not in df.columns:
            logger.warning(f"`{col_name}` column not found")
            continue

        original = df[col_name]

        # Strip surrounding whitespace. astype(str) turns missing cells into
        # placeholder text ('nan' / 'None'), so mask those cells back to
        # missing; otherwise they would be counted as an assignee below.
        cleaned = original.astype(str).str.strip().where(original.notna())
        df[col_name] = cleaned
        present = cleaned.dropna()

        # Count occurrences of each assignee name before anonymization
        logger.info(f"Work distribution for {col_name}:")
        for assignee, count in present.value_counts().items():
            logger.info(f"  {assignee}: {count} assignments")

        # Check if all unique values in the column have a mapping
        unique_assignees = present.unique()
        unmapped_assignees = [name for name in unique_assignees if name not in mapping]

        if unmapped_assignees:
            logger.warning(f"Found assignees without mapping in config for {col_name}: {unmapped_assignees}")
        else:
            logger.info(f"All {len(unique_assignees)} unique assignees have mappings in config for {col_name}")

        # Apply the mapping; anything unmapped becomes missing
        coded = cleaned.map(mapping)
        df[col_name] = pd.to_numeric(coded, errors='coerce').astype("Int64")
        logger.info(f"Anonymized {col_name} using provided mapping")

    return df


def clean_lead_level(df: pd.DataFrame, config: Dict) -> pd.DataFrame:
    """
    Clean the lead level column and add a numeric version.

    A ``<lead column>_numeric`` column is inserted directly after the lead
    column and filled as follows:

    - "Not Tested" / "Not tested": left as NaN
    - "Not Detected": 0
    - a number followed by the unit (micro sign, Greek mu or plain "u"
      spelling of ug/ft²): the number, with thousands separators removed
    - anything else: left as NaN and logged

    Parameters
    ----------
    df : pd.DataFrame
        The dataframe to process (modified in place); must contain an
        'Original ID' column, which is used in log messages
    config : Dict
        Configuration dictionary; the lead column name is read from
        ``renamed_lead_column``

    Returns
    -------
    pd.DataFrame
        Dataframe with cleaned lead level values and numeric column, or the
        unchanged dataframe when the lead column is not found
    """
    def id_label(raw_id) -> str:
        # Render an ID for log messages; a missing or non-numeric ID must not
        # abort the run from inside a logging call.
        try:
            return str(int(raw_id))
        except (TypeError, ValueError, OverflowError):
            return str(raw_id)

    lead_col = config['renamed_lead_column']

    # Find matching column using flexible matching
    lead_matches = find_matching_columns(df, [lead_col])
    if not lead_matches:
        logger.warning(f"{lead_col} column not found")
        return df

    actual_lead_col = lead_matches[lead_col]

    # Create new column for numeric values and insert it right after the lead column
    numeric_col = f"{actual_lead_col}_numeric"
    lead_col_idx = df.columns.get_loc(actual_lead_col)
    df.insert(lead_col_idx + 1, numeric_col, np.nan)

    # Unit spellings: as originally written, micro sign (U+00B5),
    # Greek small mu (U+03BC) and plain ASCII "u".
    unit_suffixes = ('µg/ft²', '\u00b5g/ft²', '\u03bcg/ft²', 'ug/ft²')
    # Same spellings the metal-column cleaner accepts
    not_tested_values = ('Not Tested', 'Not tested')

    processed_count = 0

    for idx, orig_id, value in zip(df.index, df['Original ID'], df[actual_lead_col]):
        if pd.isna(value):
            continue

        value = str(value).strip()

        if value in not_tested_values:
            # Nothing to record; the numeric cell stays NaN
            processed_count += 1
        elif value == "Not Detected":
            df.at[idx, numeric_col] = 0
            processed_count += 1
        elif value.endswith(unit_suffixes):
            clean_value = value
            for suffix in unit_suffixes:
                clean_value = clean_value.replace(suffix, '')
            clean_value = clean_value.replace(',', '').strip()
            try:
                numeric_value = float(clean_value)
            except ValueError:
                logger.warning(f"ID # {id_label(orig_id)}: Unable to convert lead level '{value}' to numeric value")
            else:
                df.at[idx, numeric_col] = numeric_value
                processed_count += 1
        else:
            logger.error(f"ID # {id_label(orig_id)}: Unable to process lead level value: {value}")

    logger.info(f"Processed {processed_count} lead level values to numeric format")
    valid_lead_count = df[numeric_col].notna().sum()
    logger.info(f"Processed {processed_count} values in {actual_lead_col}. {valid_lead_count} valid values.")

    return df


def clean_metal_columns(df: pd.DataFrame, config: Dict) -> pd.DataFrame:
    """
    Clean metal contaminant columns and extract numeric values and locations.

    For every column listed under ``metal_columns_to_analyze`` that is found
    in the dataframe, a ``<column>_numeric`` and a ``<column>_location``
    column are appended and the original column is dropped:

    - "DETECTED: <amount> ug/ft² (<location>)": amount and location extracted
    - "Not Detected": 0 and "Not Recorded"
    - "Not Tested" / "Not tested": NaN and an empty location
    - anything else, including DETECTED text that does not follow the
      expected layout: NaN, empty location, and an error in the log

    Parameters
    ----------
    df : pd.DataFrame
        The dataframe to process (modified in place); must contain an
        'Original ID' column, which is used in log messages
    config : Dict
        Configuration dictionary; must contain ``metal_columns_to_analyze``

    Returns
    -------
    pd.DataFrame
        Dataframe with additional numeric and location columns for metals
    """
    def id_label(raw_id) -> str:
        # Render an ID for log messages; a missing or non-numeric ID must not
        # abort the run from inside a logging call.
        try:
            return str(int(raw_id))
        except (TypeError, ValueError, OverflowError):
            return str(raw_id)

    # Compiled once; it is applied to every DETECTED cell of every metal column
    detected_pattern = re.compile(r"DETECTED: (\d+(?:\.\d+)?) ug/ft² \((.*?)\)")
    processed_count = 0

    # Find matching columns using flexible matching
    metal_matches = find_matching_columns(df, config['metal_columns_to_analyze'])

    # Log missing columns
    for col in config['metal_columns_to_analyze']:
        if col not in metal_matches:
            logger.warning(f"Metal column '{col}' not found in dataframe")

    for actual_col in metal_matches.values():
        numeric_col = f"{actual_col}_numeric"
        location_col = f"{actual_col}_location"
        df[numeric_col] = np.nan
        df[location_col] = ""

        column_processed = 0

        for idx, orig_id, value in zip(df.index, df['Original ID'], df[actual_col]):
            if pd.isna(value):
                continue

            value = str(value).strip()

            if value.startswith("DETECTED"):
                match = detected_pattern.search(value)
                if match:
                    amount, location = match.groups()
                    df.at[idx, numeric_col] = float(amount)
                    df.at[idx, location_col] = location
                    column_processed += 1
                else:
                    # Report it: the source column is dropped below, so a
                    # silent skip would lose the measurement without a trace.
                    logger.error(f"ID # {id_label(orig_id)} for {actual_col}: Unable to process metal value: {value}")
            elif value == 'Not Detected':
                df.at[idx, numeric_col] = 0
                df.at[idx, location_col] = "Not Recorded"
                column_processed += 1
            elif value in ['Not Tested', 'Not tested']:
                # Simply leave the columns as nans
                column_processed += 1
            else:
                logger.error(f"ID # {id_label(orig_id)} for {actual_col}: Unable to process metal value: {value}")

        df.drop(columns=actual_col, inplace=True)
        valid_metal_count = df[numeric_col].notna().sum()
        processed_count += column_processed
        logger.info(f"Processed {column_processed} values in {actual_col}. {valid_metal_count} valid values.")

    logger.info(f"Total metal values processed: {processed_count}")

    return df


def process_sample_dates(df: pd.DataFrame) -> pd.DataFrame:
    """
    Process the Sample Date column to handle multiple dates separated by newlines.

    Each cell may hold several dates, one per line, written either as
    YYYY-MM-DD (optionally followed by a time) or as MM/DD/YYYY. The earliest
    and latest parsed dates are stored in the new 'First_Sample_Date' and
    'Last_Sample_Date' columns and the original column is dropped. Empty
    cells and the placeholders 'Not Provided' / 'Unclear' leave both new
    columns empty; lines that cannot be parsed are logged and skipped.

    Parameters
    ----------
    df : pd.DataFrame
        The dataframe to process (modified in place); must contain an
        'Original ID' column, which is used in log messages

    Returns
    -------
    pd.DataFrame
        Dataframe with processed sample dates, or the unchanged dataframe
        when no 'Sample Date' column is found
    """
    def id_label(raw_id) -> str:
        # Render an ID for log messages; a missing or non-numeric ID must not
        # abort the run from inside a logging call.
        try:
            return str(int(raw_id))
        except (TypeError, ValueError, OverflowError):
            return str(raw_id)

    date_col = 'Sample Date'

    # Find matching column using flexible matching
    date_matches = find_matching_columns(df, [date_col])
    if not date_matches:
        logger.warning(f"{date_col} column not found")
        return df

    actual_date_col = date_matches[date_col]

    # Create new columns at positions 3 and 4, clamped to the current width
    # so a narrow dataframe does not make insert() raise.
    df.insert(min(3, len(df.columns)), 'First_Sample_Date', pd.NaT)
    df.insert(min(4, len(df.columns)), 'Last_Sample_Date', pd.NaT)

    line_break = re.compile(r'\r?\n')
    iso_prefix = re.compile(r'\d{4}-\d{2}-\d{2}')
    placeholders = ('Not Provided', 'Unclear')

    processed_count = 0

    for idx, orig_id, value in zip(df.index, df['Original ID'], df[actual_date_col]):
        if pd.isna(value):
            continue

        # Strip before comparing so padded placeholders are recognised too
        value = str(value).strip()
        if value in placeholders:
            continue

        # Split by newline characters to get multiple dates if present
        date_strings = [d.strip() for d in line_break.split(value) if d.strip()]
        if not date_strings:
            continue

        # Try to parse dates
        parsed_dates = []
        for date_str in date_strings:
            try:
                if iso_prefix.match(date_str):
                    # ISO format (YYYY-MM-DD HH:MM:SS): keep only the date part
                    date_obj = datetime.strptime(date_str.split(' ')[0], '%Y-%m-%d')
                else:
                    # MM/DD/YYYY format
                    date_obj = datetime.strptime(date_str, '%m/%d/%Y')
            except ValueError:
                logger.error(f"ID # {id_label(orig_id)}: Could not parse date: {date_str}")
            else:
                parsed_dates.append(date_obj)

        if parsed_dates:
            df.at[idx, 'First_Sample_Date'] = min(parsed_dates)
            df.at[idx, 'Last_Sample_Date'] = max(parsed_dates)
            processed_count += 1

    logger.info(f"Processed {processed_count} sample date entries")

    df.drop(columns=[actual_date_col], inplace=True)

    return df


def process_geospatial_data(df: pd.DataFrame, config: Dict) -> pd.DataFrame:
    """
    Process geospatial location column to split latitude and longitude.

    The column named by ``geospatial_location_column`` is expected to hold
    "<latitude>,<longitude>" text. 'Latitude' and 'Longitude' columns are
    inserted at positions 1 and 2 and the original column is dropped. Cells
    that are empty, do not have exactly two comma-separated parts, or whose
    parts are not numbers are logged and left as NaN.

    Parameters
    ----------
    df : pd.DataFrame
        The dataframe to process (modified in place)
    config : Dict
        Configuration dictionary

    Returns
    -------
    pd.DataFrame
        Dataframe with processed geospatial data, or the unchanged dataframe
        when no location column is configured or found
    """
    geo_col = config.get('geospatial_location_column')
    if not geo_col:
        logger.warning("No geospatial location column specified in config")
        return df

    # Find matching column using flexible matching
    geo_matches = find_matching_columns(df, [geo_col])

    if geo_matches:
        actual_geo_col = geo_matches[geo_col]
    else:
        # If direct match not found, the column may have been renamed by an
        # earlier step; look it up under its new name.
        for old_col, new_col in config.get('rename_columns', {}).items():
            if old_col == geo_col and new_col in df.columns:
                actual_geo_col = new_col
                break
        else:
            logger.warning(f"Geospatial location column '{geo_col}' not found in dataframe")
            return df

    # Create new columns for latitude and longitude
    df.insert(1, 'Latitude', np.nan)
    df.insert(2, 'Longitude', np.nan)

    processed_count = 0

    for idx, value in zip(df.index, df[actual_geo_col]):
        if pd.isna(value):
            logger.error(f"Could not parse coordinates: {value} at row {idx}")
            continue

        value = str(value).strip()

        # Split by comma to separate lat and long
        coords = value.split(',')
        if len(coords) != 2:
            logger.error(f"Could not parse coordinates: {value} at row {idx}")
            continue

        try:
            lat = float(coords[0].strip())
            lon = float(coords[1].strip())
        except ValueError:
            # One malformed cell must not abort the whole run
            logger.error(f"Could not parse coordinates: {value} at row {idx}")
            continue

        df.at[idx, 'Latitude'] = lat
        df.at[idx, 'Longitude'] = lon
        processed_count += 1

    df.drop(columns=[actual_geo_col], inplace=True)

    logger.info(f"Processed {processed_count} geospatial coordinates")
    return df


def drop_rows_with_missing_required_fields(df: pd.DataFrame, config: Dict) -> pd.DataFrame:
    """
    Remove every row that lacks a value in any required column.

    Each removed row is reported with a warning naming the required columns
    that were empty in it.

    Parameters
    ----------
    df : pd.DataFrame
        The dataframe to process
    config : Dict
        Configuration dictionary containing required_non_null_columns

    Returns
    -------
    pd.DataFrame
        Dataframe with rows containing missing required fields dropped, or
        the input dataframe itself when no required columns are configured
    """
    required = config.get('required_non_null_columns', [])
    if not required:
        logger.warning("No required non-null columns specified in config")
        return df

    # Cell-level "is missing" flags for the required columns, and the
    # row-level summary derived from them
    missing_flags = df[required].isna()
    incomplete = missing_flags.any(axis=1)

    # Explain each removal before it happens
    for row_label, flags in missing_flags[incomplete].iterrows():
        blanks = [name for name in required if flags[name]]
        logger.warning(f"Row {row_label} dropped due to missing values in columns: {', '.join(blanks)}")

    # Keep only the complete rows
    kept = df[~incomplete]

    logger.info(f"Dropped {sum(incomplete)} rows that had values missing from required columns")

    return kept


def process_asbestos_test_types(df: pd.DataFrame, config: Dict) -> pd.DataFrame:
    """
    Process the asbestos column to extract test types used.

    Creates binary columns for each asbestos test type indicating whether that
    test was mentioned in the asbestos field. The new columns are named
    ``asbestos_<test type>_used`` and are inserted directly after the
    asbestos column, in the order the test types are configured.

    Parameters
    ----------
    df : pd.DataFrame
        The dataframe to process (modified in place)
    config : Dict
        Configuration dictionary; the column name is read from
        ``renamed_asbestos_column`` and the test types from
        ``asbestos_test_types``

    Returns
    -------
    pd.DataFrame
        Dataframe with added binary test type columns, or the unchanged
        dataframe when the asbestos column is not configured or not present
    """
    asbestos_col = config.get('renamed_asbestos_column')
    test_types = config.get('asbestos_test_types') or []

    # Warn and skip when the column is unavailable, as the other optional
    # processing steps do, instead of failing on the column lookup.
    if not asbestos_col or asbestos_col not in df.columns:
        logger.warning(f"Asbestos column '{asbestos_col}' not found in dataframe")
        return df

    logger.info(f"Processing asbestos test types from column: {asbestos_col}")

    # Find the position of the asbestos column
    asbestos_col_idx = df.columns.get_loc(asbestos_col)

    # Create binary columns for each test type and insert them after the asbestos column
    flag_columns = {test_type: f'asbestos_{test_type}_used' for test_type in test_types}
    for offset, col_name in enumerate(flag_columns.values(), start=1):
        # The offset accounts for the flag columns inserted before this one
        df.insert(asbestos_col_idx + offset, col_name, 0)

    rows_with_asbestos_test = 0

    # Process each row to identify test types used
    for idx, value in df[asbestos_col].items():
        if pd.isna(value):
            continue

        value = str(value).strip()

        # Skip standard non-test values
        if value == 'Not Tested':
            continue

        # Case-sensitive substring check for each configured test type
        for test_type, col_name in flag_columns.items():
            if test_type in value:
                df.at[idx, col_name] = 1

        # Counted even when no configured test type was mentioned in the text
        rows_with_asbestos_test += 1

    logger.info(f"Processed {rows_with_asbestos_test} entries with asbestos testing")

    # Log summary statistics
    for test_type, col_name in flag_columns.items():
        count = df[col_name].sum()
        logger.info(f"Found {count} samples tested with {test_type}")

    return df


def _split_company_names(series: pd.Series) -> List[str]:
    """Split '&'-joined cells into individual names with whitespace collapsed."""
    names: List[str] = []
    for value in series.dropna():
        text = str(value).strip()
        if not text:
            continue
        for part in text.split('&'):
            name = re.sub(r'\s+', ' ', part).strip()
            # Skip empty fragments and the text form of a missing value
            if name and name.lower() != 'nan':
                names.append(name)
    return names


def _log_name_counts(names: List[str], label: str) -> None:
    """Log how often each name occurs, most frequent first, ties ordered by name."""
    if not names:
        logger.info(f"No entries found for {label}")
        return

    counts: Dict[str, int] = {}
    for name in names:
        counts[name] = counts.get(name, 0) + 1

    logger.info(f"{label} counts (descending):")
    for name, count in sorted(counts.items(), key=lambda item: (-item[1], item[0])):
        logger.info(f"  {int(count):4d}  {name}")


def count_per_company_tests(df: pd.DataFrame, config: Dict) -> None:
    """
    Count and log occurrences of testing laboratories and main/testing companies.

    Cells may name several organisations joined by '&'; each one is counted
    separately. Results are only written to the log; the dataframe is not
    modified.

    Parameters
    ----------
    df : pd.DataFrame
        Input dataframe containing the raw columns prior to any renaming/dropping
    config : Dict
        Configuration dictionary containing the original column names under
        `original_testing_laboratory_column` and `original_testing_company_column`
    """
    lab_config_col = config.get('original_testing_laboratory_column')
    comp_config_col = config.get('original_testing_company_column')

    columns_to_match: List[str] = [c for c in [lab_config_col, comp_config_col] if c]
    if not columns_to_match:
        logger.warning("No testing laboratory/company columns specified in config")
        return

    column_matches = find_matching_columns(df, columns_to_match)

    # (configured column, label used in the report, prefix of the not-found warning)
    targets = (
        (lab_config_col, "Testing Laboratories", "Testing laboratory column not found"),
        (comp_config_col, "Main/Testing Companies", "Testing company column not found"),
    )
    for config_col, label, missing_prefix in targets:
        if config_col not in column_matches:
            logger.warning(f"{missing_prefix}: {config_col}")
            continue
        _log_name_counts(_split_company_names(df[column_matches[config_col]]), label)


@click.command()
@click.option('--input_file', 
              type=click.Path(exists=True),
              required=True,
              help='Path to the input XLSX file')
@click.option('--output_file',
              type=click.Path(),
              required=True,
              help='Path for the output CSV file')
@click.option('--config_file',
              type=click.Path(exists=True),
              default='config.yaml',
              help='Path to the configuration YAML file')
@click.option('--sheet_name',
              default='DATA_PRE_INT',
              help='Sheet name or index to process')
@click.option('--last_row',
              type=int,
              default=None,
              help='Last row to ingest if the whole sheet should not be processed')
@click.option('--log_level',
              type=click.Choice(['DEBUG', 'INFO', 'WARNING', 'ERROR']),
              default='INFO',
              help='Logging level')
def main(input_file: str, output_file: str, config_file: str, sheet_name: Union[str, int], last_row: Optional[int], log_level: str):
    """
    Process an XLSX file to clean and anonymize data for public release.

    Reads one sheet of INPUT_FILE, runs the processing steps in a fixed
    order and writes the result as CSV to OUTPUT_FILE. The process exits
    with status 1 when the input sheet cannot be read.
    """
    # Set logging level
    logger.setLevel(getattr(logging, log_level))

    logger.info(f"Processing {input_file}, sheet {sheet_name}")

    def read_sheet(sheet: Union[str, int]) -> pd.DataFrame:
        # last_row is handed to pandas as nrows; None means "read everything"
        return pd.read_excel(input_file, sheet_name=sheet, parse_dates=False, nrows=last_row)

    # Read the input file
    try:
        try:
            df = read_sheet(sheet_name)
        except ValueError:
            # click delivers --sheet_name as text, so "0" is first looked up
            # as a sheet *name*. If pandas rejects it and the text is purely
            # numeric, retry it as a zero-based sheet index, as the option's
            # help text promises.
            if not (isinstance(sheet_name, str) and sheet_name.isdigit()):
                raise
            df = read_sheet(int(sheet_name))
        logger.info(f"Loaded data with {len(df)} rows and {len(df.columns)} columns")
    except Exception as e:
        logger.error(f"Error reading input file: {e}")
        # Signal the failure to the calling shell instead of exiting with 0
        raise SystemExit(1)

    # Create output directory if it doesn't exist
    output_path = Path(output_file)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Load config and process data
    config = load_config(Path(config_file))

    # Company/laboratory statistics need the raw columns, so they run before
    # anything is dropped or renamed.
    count_per_company_tests(df, config)

    # Apply each processing step in sequence
    processed_df = drop_columns(df, config)
    processed_df = rename_columns(processed_df, config)
    processed_df = process_geospatial_data(processed_df, config)
    processed_df = anonymize_labelers(processed_df, config)
    processed_df = drop_incomplete_rows(processed_df, config)
    processed_df = clean_lead_level(processed_df, config)
    processed_df = clean_metal_columns(processed_df, config)
    processed_df = process_sample_dates(processed_df)
    processed_df = process_asbestos_test_types(processed_df, config)
    processed_df = drop_rows_with_missing_required_fields(processed_df, config)

    # Convert the 'Original ID' column to a nullable integer
    processed_df['Original ID'] = pd.to_numeric(processed_df['Original ID'], errors='coerce').astype("Int64")

    # Save the processed data
    processed_df.to_csv(output_path, index=False)
    logger.info(f"Saved processed data to {output_path}")
    logger.info(f"Original shape: {df.shape}, Processed shape: {processed_df.shape}")

# Run the click command only when the file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()