import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
from collections import defaultdict
import pywt
from scipy.interpolate import interp1d
from scipy.signal import find_peaks, detrend, spectrogram
from scipy.fft import fft, fftfreq
from scipy.stats import skew, kurtosis
from cosinor import cosinor_analysis 
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score


def extract_features_full(signal, t_smooth, sampling_rate=1):
    """
    Extract summary features from one temperature signal.

    The returned dictionary contains basic statistics (mean, standard
    deviation, skewness, kurtosis), shape descriptors (average slope,
    amplitude, average peak-to-peak distance), the rank-based features
    M5, L5 and Relative Amplitude, the Intra-monthly Variability and, for
    every entry of the module-level ``wave_functions``, the R2, period and
    phase of the best fit returned by ``fit_best_wave``.

    Args:
        signal (array-like or pd.Series): Temperature signal. Entries that
            cannot be parsed as numbers, and NaNs, are discarded.
        t_smooth: Smoothed temperature curve exposing ``.max()`` and
            ``.min()``; only its range is used (reported as "Amplitude").
        sampling_rate (float): Divides the peak spacing (in samples) to
            express the peak-to-peak distance in time units.

    Returns:
        dict: Feature name -> value.
    """
    # Take the cycle length from the last index label (e.g. 'T28' -> 28).
    # Anything that prevents this (no index, non-string label, ...) falls
    # back to 28.
    try:
        period = pd.to_numeric(signal.index[-1].replace('T', ''))
    except Exception:
        period = 28

    signal = pd.to_numeric(signal, errors='coerce')
    signal = signal[~np.isnan(signal)]  # Remove NaNs

    # Basic statistical features
    mean_value = np.mean(signal)
    std_dev = np.std(signal)
    skewness = skew(signal)

    # Kurtosis values of 2 or more are reported as NaN.
    raw_kurtosis = kurtosis(signal)
    kurtosis_val = raw_kurtosis if raw_kurtosis < 2 else np.nan

    detrended_signal = detrend(signal)

    # Amplitude is the range of the smoothed curve, not of the raw signal.
    amplitude = t_smooth.max() - t_smooth.min()

    # Mean absolute sample-to-sample gradient of the detrended signal.
    slope = np.mean(np.abs(np.gradient(detrended_signal)))

    # Average distance between consecutive local maxima; values of 11 or
    # more (and the "fewer than two peaks" case) are reported as NaN.
    peaks, _ = find_peaks(signal)
    if len(peaks) > 1:
        mean_peak_gap = np.mean(np.diff(peaks) / sampling_rate)
    else:
        mean_peak_gap = np.nan
    avg_peak_to_peak = mean_peak_gap if mean_peak_gap < 11 else np.nan

    # Rank-based features: mean of the five largest / five smallest values.
    sorted_signal = pd.Series(signal).sort_values()
    M5 = sorted_signal.iloc[-5:].mean()
    L5 = sorted_signal.iloc[:5].mean()
    RA = (M5 - L5) / (M5 + L5)  # Relative Amplitude

    # Variance of the first differences relative to the signal variance.
    # (The former "IS" value was the variance of a scalar, i.e. always 0,
    # and was never returned, so it is no longer computed.)
    IV = np.var(np.diff(signal)) / np.var(signal)

    features = {
        "Mean_Temperature": mean_value,
        "Standard_Deviation": std_dev,
        "Skewness": skewness,
        "Kurtosis": kurtosis_val,
        "Average_Slope": slope,
        "Period": period,
        "Amplitude": amplitude,
        "Avg_Peak-to-Peak_Distance": avg_peak_to_peak,
        "M5": M5,
        "L5": L5,
        "Relative_Amplitude": RA,
        "Intra-monthly_Variability": IV,
    }

    # Best-fitting waveform of each registered shape over periods 22..35.
    x_vals = np.arange(1, len(signal) + 1)
    period_range = range(22, 36)

    # Loop variables are named so they do not overwrite `period` above.
    for name_func, func in wave_functions.items():
        _, wave_r2, wave_period, wave_phase = fit_best_wave(x_vals, signal, func, period_range)

        features[f'{name_func}_r2'] = wave_r2
        features[f'{name_func}_period'] = wave_period
        features[f'{name_func}_phase'] = wave_phase

    return features

def temps_by_phase(temperatures_raw, temperatures_smooth, mucus_peak):
    """
    Split raw and smoothed temperatures into the two phases of a cycle.

    The day number is parsed from index labels of the form 'T<number>'.
    Days up to and including `mucus_peak` form the first (follicular)
    part, later days the second (luteal) part.

    Args:
        temperatures_raw (pd.Series): Raw temperatures indexed by 'T<day>'.
        temperatures_smooth (pd.Series): Smoothed temperatures; filtered
            with the mask built from `temperatures_raw`'s index.
        mucus_peak (int): Last day belonging to the follicular part.

    Returns:
        tuple: (temp_fol, temp_lut, smoothed_temp_fol, smoothed_temp_lut)
    """
    # squeeze(axis=1) collapses only the single capture-group column, so a
    # one-row input still yields a Series (a bare squeeze() would return a
    # scalar and the index assignment below would fail).
    phase_days_numeric = temperatures_raw.index.str.extract(r'T(\d+)').astype(int).squeeze(axis=1)
    phase_days_numeric.index = temperatures_raw.index  # align indexes

    in_follicular = phase_days_numeric <= mucus_peak
    in_luteal = phase_days_numeric > mucus_peak

    temp_fol = temperatures_raw[in_follicular]
    temp_lut = temperatures_raw[in_luteal]
    smoothed_temp_fol = temperatures_smooth[in_follicular]
    smoothed_temp_lut = temperatures_smooth[in_luteal]

    return temp_fol, temp_lut, smoothed_temp_fol, smoothed_temp_lut
    
    
def extract_features_per_phase(temp_follicular, temp_luteal, smoothed_T_lut, smoothed_T_fol, features, mucus_peak, temperatures_raw):
    """
    Add per-phase descriptive, timing and slope features to `features`.

    Note the parameter order: the smoothed luteal series comes BEFORE the
    smoothed follicular series.

    Args:
        temp_follicular (pd.Series): Raw temperatures of the follicular part.
        temp_luteal (pd.Series): Raw temperatures of the luteal part.
        smoothed_T_lut (pd.Series): Smoothed luteal temperatures.
        smoothed_T_fol (pd.Series): Smoothed follicular temperatures.
        features (dict): Updated in place and also returned.
        mucus_peak (int): Day separating the two parts.
        temperatures_raw (pd.Series): Raw cycle; its last index label gives
            the last day of the cycle.

    Returns:
        dict: The same `features` object with the new keys added.
    """

    def _label_to_day(label):
        """Turn an index label such as 'T14' (or 14) into the integer 14."""
        return int(str(label).lstrip('T'))

    # Follicular part first, then luteal.
    full_smooth = pd.concat([smoothed_T_fol, smoothed_T_lut])

    # Both extrema are located on the smoothed LUTEAL series.
    # NOTE(review): the minimum is later used as a follicular quantity
    # ('nadir_angle_fol') -- confirm it should not come from smoothed_T_fol.
    day_of_max = _label_to_day(smoothed_T_lut.idxmax())
    day_of_min = _label_to_day(smoothed_T_lut.idxmin())

    first_day = _label_to_day(full_smooth.index[0])
    last_day = _label_to_day(temperatures_raw.index[-1])

    # Follicular descriptors
    features.update({
        'dur_fol': mucus_peak,
        'num_T_vals_fol': len(temp_follicular),
        'mean_fol': temp_follicular.mean(),
        'std_fol': temp_follicular.std(),
        'amp_raw_fol': temp_follicular.max() - temp_follicular.min(),
        'amp_smooth_fol': smoothed_T_fol.max() - smoothed_T_fol.min(),
        'smoothness_fol': temp_follicular.diff().abs().mean(),
    })

    # Luteal descriptors
    features.update({
        'dur_lut': last_day - mucus_peak,
        'num_T_vals_lut': len(temp_luteal),
        'mean_lut': temp_luteal.mean(),
        'std_lut': temp_luteal.std(),
        'amp_raw_lut': temp_luteal.max() - temp_luteal.min(),
        'amp_smooth_lut': smoothed_T_lut.max() - smoothed_T_lut.min(),
        'smoothness_lut': temp_luteal.diff().abs().mean(),
    })

    # Timing of the extrema, as absolute cycle days
    features['acrophase_day'] = day_of_max
    features['nadir_day'] = day_of_min

    # Position of the maximum inside the luteal part, scaled to 0..180.
    luteal_length = features['dur_lut']
    if luteal_length == 0:
        features['acrophase_angle_lut'] = np.nan
    else:
        features['acrophase_angle_lut'] = ((day_of_max - mucus_peak) / luteal_length) * 180

    # Position of the minimum relative to the follicular length, same scale.
    follicular_length = features['dur_fol']
    if follicular_length == 0:
        features['nadir_angle_fol'] = np.nan
    else:
        features['nadir_angle_fol'] = (day_of_min / follicular_length) * 180

    def _value_on(day):
        """Smoothed temperature on `day`; tries the 'T<day>' label first."""
        key = f'T{day}'
        if key not in full_smooth.index:
            key = day
        return full_smooth.loc[key]

    def _rate(day_a, day_b):
        """Temperature change per day between two days; NaN if identical."""
        if day_b == day_a:
            return np.nan
        return (_value_on(day_b) - _value_on(day_a)) / (day_b - day_a)

    features['slope_m_to_nadir'] = _rate(first_day, day_of_min)
    features['slope_nadir_to_acrophase'] = _rate(day_of_min, day_of_max)
    features['slope_acrophase_to_m'] = _rate(day_of_max, last_day)

    return features


def compute_cosine_r_square(temp_signal):
    """
    Fit a fixed-period cosine to a signal and report the goodness of fit.

    The period is taken from the largest-magnitude bin in the first half of
    the FFT of `temp_signal`; when that bin is the zero-frequency bin, the
    signal length is used as the period. MESOR, amplitude and phase are
    then fitted with `curve_fit`.

    Args:
        temp_signal (array-like): Samples, assumed equally spaced.

    Returns:
        tuple: ``(r2, fitted_signal)``. ``(np.nan, None)`` is returned when
        there are fewer than two samples or when the fit fails; ``r2`` is
        NaN when the signal has zero variance.
    """
    n_samples = len(temp_signal)

    # With fewer than two samples the half-spectrum is empty (argmax would
    # raise). Always return a pair so callers can unpack the result.
    if n_samples < 2:
        return np.nan, None

    # Time vector for the measurements (assuming equal spacing)
    t = np.arange(n_samples)

    # Dominant frequency among the non-negative FFT bins.
    # NOTE(review): bin 0 (the mean) takes part in the search, so for a
    # signal whose mean dominates its oscillation the period falls back to
    # the signal length.
    yf = fft(temp_signal)
    xf = fftfreq(n_samples, d=1)[:n_samples // 2]
    dominant_freq = xf[np.argmax(np.abs(yf[:n_samples // 2]))]
    period_est = 1 / dominant_freq if dominant_freq != 0 else n_samples

    # Cosine model with the period held fixed at period_est
    def cosine_model(t, MESOR, A, phi):
        return MESOR + A * np.cos(2 * np.pi * t / period_est + phi)

    try:
        p0 = [np.mean(temp_signal), (np.max(temp_signal) - np.min(temp_signal)) / 2, 0]
        popt, _ = curve_fit(cosine_model, t, temp_signal, p0=p0)
        fitted_signal = cosine_model(t, *popt)
        ss_res = np.sum((temp_signal - fitted_signal) ** 2)
        ss_tot = np.sum((temp_signal - np.mean(temp_signal)) ** 2)
        r2_manual = 1 - ss_res / ss_tot if ss_tot != 0 else np.nan
    except Exception:
        # Best effort: any failure of the fit is reported as "no fit".
        print("Cosine fit failed")
        r2_manual = np.nan
        fitted_signal = None

    return r2_manual, fitted_signal


def cosine_r_square(temp_signal, rolling_window=False):
    """
    Fit a fixed-period cosine to a temperature series and report the R2.

    NaNs are dropped first. With `rolling_window=True` the remaining values
    are smoothed with a 3-sample rolling mean (``min_periods=1``) before
    fitting. The period is taken from the largest-magnitude bin in the
    first half of the FFT; when that bin is the zero-frequency bin, the
    number of samples is used as the period.

    Args:
        temp_signal (pd.Series): Temperature values, assumed equally spaced.
        rolling_window (bool): Smooth with a 3-sample rolling mean first.

    Returns:
        tuple: ``(r2, fitted_signal)``. ``(np.nan, None)`` is returned when
        fewer than two valid samples remain or when the fit fails; ``r2``
        is NaN when the signal has zero variance.
    """
    valid = temp_signal.dropna()
    if rolling_window:
        valid = valid.rolling(window=3, min_periods=1).mean()
    values = valid.values

    n_samples = len(values)

    # With fewer than two samples the half-spectrum is empty (argmax would
    # raise). Always return a pair so callers can unpack the result.
    if n_samples < 2:
        return np.nan, None

    # Time vector for the measurements (assuming equal spacing)
    t = np.arange(n_samples)

    # Dominant frequency among the non-negative FFT bins.
    # NOTE(review): bin 0 (the mean) takes part in the search, so for a
    # signal whose mean dominates its oscillation the period falls back to
    # the number of samples.
    yf = fft(values)
    xf = fftfreq(n_samples, d=1)[:n_samples // 2]
    dominant_freq = xf[np.argmax(np.abs(yf[:n_samples // 2]))]
    period_est = 1 / dominant_freq if dominant_freq != 0 else n_samples

    # Cosine model with the period held fixed at period_est
    def cosine_model(t, MESOR, A, phi):
        return MESOR + A * np.cos(2 * np.pi * t / period_est + phi)

    try:
        p0 = [np.mean(values), (np.max(values) - np.min(values)) / 2, 0]
        popt, _ = curve_fit(cosine_model, t, values, p0=p0)
        fitted_signal = cosine_model(t, *popt)
        ss_res = np.sum((values - fitted_signal) ** 2)
        ss_tot = np.sum((values - np.mean(values)) ** 2)
        r2_manual = 1 - ss_res / ss_tot if ss_tot != 0 else np.nan
    except Exception:
        # Best effort: any failure of the fit is reported as "no fit".
        print("Cosine fit failed")
        r2_manual = np.nan
        fitted_signal = None

    return r2_manual, fitted_signal
        

def extract_dynamic_features(signal, sampling_rate=1, window_size=50, overlap=0.5):
    """
    Extract features over a moving window.

    Each full window is passed to ``extract_features`` (not defined in the
    visible part of this file). Samples left after the last full window are
    processed as one extra, shorter window when they amount to at least
    half a window; a signal shorter than one window is handled by the same
    rule.

    Args:
        signal (array-like): Input time series signal.
        sampling_rate (int): Forwarded to ``extract_features``.
        window_size (int): Size of the moving window (in samples).
        overlap (float): Overlap between consecutive windows (0 to 1).

    Returns:
        pd.DataFrame: One row per processed window, including
        "Window Start" and "Window End" sample positions.

    Raises:
        ValueError: If `window_size` and `overlap` give a step below one
            sample (e.g. ``overlap=1``).
    """
    step_size = int(window_size * (1 - overlap))
    if step_size < 1:
        # Previously this surfaced as a ZeroDivisionError or as windows
        # that never advanced.
        raise ValueError("window_size and overlap must give a step of at least one sample")

    n_samples = len(signal)
    # Only count windows that fit completely inside the signal.
    if n_samples >= window_size:
        num_windows = (n_samples - window_size) // step_size + 1
    else:
        num_windows = 0

    dynamic_features = []
    end = 0  # end of the last processed window; 0 when none fitted

    for i in range(num_windows):
        start = i * step_size
        end = start + window_size
        window = signal[start:end]

        # Extract features for the current window
        features = extract_features(window, sampling_rate)
        features["Window Start"] = start
        features["Window End"] = end
        dynamic_features.append(features)

    # Handle the remainder of the signal
    if end < n_samples:
        remainder = signal[end:]
        if len(remainder) >= window_size / 2:  # Process if remainder is significant
            features = extract_features(remainder, sampling_rate)
            features["Window Start"] = end
            features["Window End"] = n_samples
            dynamic_features.append(features)

    return pd.DataFrame(dynamic_features)

    
def concatenate_temperatures_w_dates(data):
    """
    Turn per-cycle temperature columns into one daily time series.

    Every column whose name starts with 'TEMP' is treated as one cycle day,
    in column order, counted from the cycle start date in column 'DATA'.
    Readings falling on the same date are averaged, and the result is
    reindexed to a gap-free daily timeline (missing days hold NaN).

    Args:
        data (pd.DataFrame): One row per cycle with a 'DATA' start date and
            'TEMP*' columns. NOTE: the 'DATA' column is converted to
            datetime in place on the caller's frame.

    Returns:
        pd.DataFrame with columns 'Date' and 'Temperature', or an empty
        list when there is no usable reading.
    """
    temp_columns = [col for col in data.columns if col.startswith('TEMP')]
    date_column = 'DATA'

    # Ensure the date column is parsed as datetime
    data[date_column] = pd.to_datetime(data[date_column])

    # Offset of each TEMP column from the cycle start. The offset follows
    # the column position, so a missing reading in the middle of a cycle
    # leaves a gap instead of shifting later readings to earlier days.
    day_offsets = [pd.Timedelta(days=i) for i in range(len(temp_columns))]

    # Collect one frame per cycle and concatenate once.
    frames = []
    for _, row in data.iterrows():
        cycle_start = row[date_column]
        # Rows from iterrows() are object-typed; make the readings numeric
        # so that later aggregation works (unparseable entries become NaN).
        temps = pd.to_numeric(row[temp_columns], errors='coerce').to_numpy()
        frames.append(pd.DataFrame({
            "Date": [cycle_start + offset for offset in day_offsets],
            "Temperature": temps,
        }))

    if not frames:
        return []

    # Drop missing readings and rows without a valid start date.
    temp_data = pd.concat(frames, ignore_index=True).dropna()
    if len(temp_data) == 0:
        return []

    if temp_data['Date'].duplicated().any():
        print("Duplicates found in the Date column. Resolving by aggregating duplicates.")
        temp_data = temp_data.groupby('Date', as_index=False).mean()

    # Set the full timeline with NaN for missing dates. Errors propagate to
    # the caller (the previous version opened an interactive IPython shell).
    full_timeline = pd.date_range(start=temp_data['Date'].min(), end=temp_data['Date'].max())
    return temp_data.set_index('Date').reindex(full_timeline).rename_axis('Date').reset_index()


def extract_features_per_cycle(data, cycle_starts, signal_col='Temperature', date_col='Date', sampling_rate=1, detrend=False):
    """
    Extract features for each cycle delimited by consecutive start dates.

    A cycle covers the dates from one entry of `cycle_starts` (inclusive)
    up to the next entry (exclusive). Cycles without any valid measurement
    are skipped.

    Args:
        data (pd.DataFrame): Input data containing the signal and dates.
        cycle_starts (pd.Series): Sorted dates marking the start of each cycle.
        signal_col (str): Column name of the signal to extract features from.
        date_col (str): Column name of the date to use for segmentation.
        sampling_rate (int): Forwarded to ``extract_features_full``.
        detrend (bool): Currently unused; kept for interface compatibility.

    Returns:
        pd.DataFrame: One row of features per cycle, including
        "Window Start", "Window End" and "Period" (cycle length in days,
        NaN when 45 days or longer).
    """

    def _r2_only(result):
        # cosine_r_square yields (r2, fitted_signal); tolerate a bare
        # scalar as well so an empty input cannot break the unpacking.
        return result[0] if isinstance(result, tuple) else result

    # Ensure dates are sorted
    data = data.sort_values(by=date_col)

    all_features = []

    for start_date, end_date in zip(cycle_starts.iloc[:-1], cycle_starts.iloc[1:]):
        in_cycle = (data[date_col] >= start_date) & (data[date_col] < end_date)
        cycle_data = pd.to_numeric(data.loc[in_cycle, signal_col], errors='coerce')

        # Days without a measurement are NaN; skip cycles with no reading.
        if not cycle_data.notna().any():
            continue

        cycle_days = (end_date - start_date).days
        period = cycle_days if cycle_days < 45 else np.nan

        # extract_features_full requires a smoothed curve as its second
        # argument (the previous call omitted it and raised TypeError).
        # NOTE(review): the 3-sample rolling mean mirrors the smoothing of
        # cosine_r_square(rolling_window=True) -- confirm this is the
        # intended smoother for the "Amplitude" feature.
        smoothed = cycle_data.dropna().rolling(window=3, min_periods=1).mean()

        features = extract_features_full(cycle_data.values, smoothed, sampling_rate)
        features['R2_raw'] = _r2_only(cosine_r_square(cycle_data, rolling_window=False))
        features['R2_avg'] = _r2_only(cosine_r_square(cycle_data, rolling_window=True))

        features["Window Start"] = start_date
        features["Window End"] = end_date
        # Overrides the default period guessed by extract_features_full.
        features["Period"] = period

        all_features.append(features)

    # Convert the list of features into a single DataFrame
    return pd.DataFrame(all_features)


def create_proportional_heatmap(dynamic_features, cycle_starts, concat_data):
    """
    Build heatmap data in which every cycle occupies one row per day.

    Each feature column is min-max scaled across cycles, then the scaled
    row of every cycle is repeated as many times as the cycle lasts in days
    ("Window End" minus "Window Start").

    Args:
        dynamic_features (pd.DataFrame): Per-cycle features with datetime
            columns "Window Start" and "Window End".
        cycle_starts (pd.Series): Datetime start of each cycle; only used
            to build the tick labels.
        concat_data: Not used by this function.

    Returns:
        tuple: ``(heatmap, labels)`` where `heatmap` is a 2-D array of
        shape (total days, number of features) and `labels` is the list of
        cycle start dates formatted as 'YYYY-MM-DD'.
    """

    def _min_max(column):
        # Scale one feature column to the 0..1 range.
        return (column - column.min()) / (column.max() - column.min())

    # Length of every cycle in whole days
    window_span = dynamic_features["Window End"] - dynamic_features["Window Start"]
    days_per_cycle = window_span.dt.days

    # Everything except the window bounds is a feature to display
    feature_table = dynamic_features.drop(columns=["Window Start", "Window End"])
    scaled = feature_table.apply(_min_max, axis=0)

    # One block of identical rows per cycle, as tall as the cycle is long
    blocks = [
        np.tile(scaled.iloc[position].values, (n_days, 1))
        for position, n_days in enumerate(days_per_cycle)
    ]
    heatmap = np.vstack(blocks)

    # Tick labels come from the cycle start dates only
    labels = cycle_starts.dt.strftime('%Y-%m-%d').tolist()

    return heatmap, labels

# Function to extract non-NaN values for earliest and latest ages
def get_valid_timepoints_with_ages(patient_series, age_series, num_points=3):
    """
    Return the first and last `num_points` non-NaN values of a series,
    together with their index labels (used as the ages).

    Args:
        patient_series (pd.Series): Values indexed by age.
        age_series: Not used; ages are read from `patient_series.index`.
        num_points (int): Number of points to take at each end.

    Returns:
        tuple: ``(early_values, late_values, early_ages, late_ages)`` as
        arrays, or four ``None`` values when fewer than `num_points`
        non-NaN entries exist.
    """
    observed = patient_series.dropna()

    # Not enough data to fill even one end
    if len(observed) < num_points:
        return (None,) * 4

    leading = slice(None, num_points)
    trailing = slice(-num_points, None)
    ages = observed.index

    return (
        observed.iloc[leading].values,
        observed.iloc[trailing].values,
        ages[leading].values,
        ages[trailing].values,
    )

    
def convert_mm_yy(x):
    """
    Convert an 'mm-yy' value into the Timestamp of the first day of that month.

    The value is converted with ``str()`` and split on '-'. A month that is
    not an integer in 1..12 is replaced by January; the two-digit year is
    always placed in the 1900s.

    Args:
        x: Value whose string form is expected to look like 'mm-yy'.

    Returns:
        pd.Timestamp, or ``pd.NaT`` when the text is empty, does not have
        exactly two '-'-separated parts, has a non-integer year, or cannot
        be turned into a Timestamp.
    """
    text = str(x).strip()
    if not text:
        return pd.NaT

    parts = text.split('-')
    if len(parts) != 2:
        return pd.NaT
    month_str, year_str = (part.strip() for part in parts)

    # Invalid months (non-numeric, "00", > 12) default to January.
    try:
        month_int = int(month_str)
    except ValueError:
        month_int = 1
    if not 1 <= month_int <= 12:
        month_int = 1

    # Always assume the year is in the 1900s.
    try:
        year_int = int(year_str)
    except ValueError:
        return pd.NaT
    full_year = 1900 + year_int

    date_str = f"{full_year}-{month_int:02d}-01"
    try:
        return pd.Timestamp(date_str)
    except Exception:
        # Best effort: unparseable or out-of-range dates map to NaT.
        return pd.NaT
    
# fitting multiple functions

def cosine_wave(t, period, phase):
    """
    Unit-amplitude sinusoid with the given period, shifted by `phase`.

    NOTE(review): despite the name this evaluates a sine, i.e. a cosine
    shifted by a quarter period; kept unchanged so fitted phases stay
    comparable.
    """
    angle = 2 * np.pi * (t + phase) / period
    return np.sin(angle)

def square_wave(t, period, phase, duty=0.5):
    """
    Square wave taking the value 1 while the position inside the period
    (as a fraction in [0, 1)) is below `duty`, and -1 otherwise.
    """
    cycle_fraction = (t + phase) % period / period
    is_high = cycle_fraction < duty
    return np.where(is_high, 1, -1)

def zigzag_wave(t, period, phase):
    """
    Triangle wave in [-1, 1]: equal to 1 at the start of each period and
    -1 at its midpoint.
    """
    cycle_fraction = (t + phase) % period / period
    distance_from_mid = np.abs(cycle_fraction - 0.5)
    return 4 * distance_from_mid - 1

def sawtooth_up_wave(t, period, phase):
    """Rising sawtooth: ramps linearly from -1 towards +1 over each period."""
    ramp = (t + phase) / period % 1
    return 2 * ramp - 1

def half_sine_wave(t, period, phase):
    """Sine wave with its negative half-cycles replaced by zero."""
    full_sine = np.sin(2 * np.pi * (t + phase) / period)
    return np.maximum(0, full_sine)

def gaussian_wave(t, period, phase):
    """
    Periodic bump: exp(-(pos - 0.5)^2 / 0.01), where `pos` is the position
    inside the period as a fraction in [0, 1); it peaks at the midpoint.
    """
    cycle_fraction = ((t + phase) % period) / period
    offset_sq = (cycle_fraction - 0.5) ** 2
    return np.exp(-offset_sq / 0.01)


# Registry of candidate waveforms: display name -> callable(t, period, phase).
# extract_features_full uses each display name as the prefix of the
# '<name>_r2', '<name>_period' and '<name>_phase' feature keys.
wave_functions = {
    "Cosine": cosine_wave,
    "Square": square_wave,
    "Zigzag": zigzag_wave,
    "Sawtooth Up": sawtooth_up_wave,
    "Half-Sine": half_sine_wave,
    "Gaussian": gaussian_wave,
}


def fit_best_wave(x, y, wave_func, period_range, phase_steps=50):
    """
    Grid-search the period and phase of a waveform that best explains `y`.

    For every period in `period_range` and every one of `phase_steps`
    phases evenly spaced from 0 to that period (both ends included), the
    waveform is evaluated on `x` and used as the single regressor of a
    linear regression on `y`. The combination with the highest R2 wins;
    on ties the first one encountered is kept.

    Args:
        x (np.ndarray): Sample positions passed to `wave_func`.
        y (array-like): Target values, same length as `x`.
        wave_func (callable): ``wave_func(x, period, phase)`` -> array.
        period_range (iterable): Candidate periods.
        phase_steps (int): Number of candidate phases per period.

    Returns:
        tuple: ``(best_pred, best_r2, best_period, best_phase)``. If no
        candidate scores above -inf the result is
        ``(None, -np.inf, None, None)``.
    """
    # Running winner as (r2, period, phase, prediction)
    winner = (-np.inf, None, None, None)

    for candidate_period in period_range:
        for candidate_phase in np.linspace(0, candidate_period, phase_steps):
            regressor = wave_func(x, candidate_period, candidate_phase).reshape(-1, 1)
            fitted_model = LinearRegression().fit(regressor, y)
            estimate = fitted_model.predict(regressor)
            score = r2_score(y, estimate)
            # Strict comparison keeps the earliest of equally good fits.
            if score > winner[0]:
                winner = (score, candidate_period, candidate_phase, estimate)

    best_r2, best_period, best_phase, best_pred = winner
    return best_pred, best_r2, best_period, best_phase