from dbfread import DBF
from waves_functions import *
from dbfread import DBF
from dbfread.field_parser import FieldParser
import pandas as pd

class CustomFieldParser(FieldParser):
    """dbfread field parser that tolerates malformed numeric ('N') fields.

    Any value the stock ``parseN`` rejects with ``ValueError`` is returned
    as ``None`` instead of propagating the error.
    """

    def parseN(self, field, data):
        """Parse a numeric field, falling back to ``None`` on unparsable data."""
        try:
            parsed = super().parseN(field, data)
        except ValueError:
            parsed = None
        return parsed
        
def build_df_raw_data(filename='../../data/Colombo_datasets/London/london.dbf'):
    """Load the DBF table and return a Celsius-converted subset plus the raw data.

    Parameters
    ----------
    filename : str
        Path to the ``.dbf`` file, read with latin-1 encoding. Numeric fields
        that fail to parse become ``None`` (see ``CustomFieldParser``).

    Returns
    -------
    df_in_C : pandas.DataFrame
        Rows with ``QUALIFI`` equal to 1 or 4. Every temperature column is
        converted to degrees Celsius according to ``TIPOTEMP``. The frame is
        re-indexed 0..n-1, and the previous index is kept in an ``'index'``
        column.
    df_raw_data : pandas.DataFrame
        All rows, with ``DATA`` and ``ANNO_NAS`` parsed to datetimes and an
        ``Age_at_Menses`` column added. Temperatures are left unconverted.
    """
    table = DBF(filename, encoding='latin1', parserclass=CustomFieldParser)
    df_raw_data = pd.DataFrame(iter(table))

    # Every column whose name contains 'TEMP' holds a reading, except the
    # TIPOTEMP flag that selects the scale used for that row.
    temperature_columns = [
        col for col in df_raw_data.columns if 'TEMP' in col and col != 'TIPOTEMP'
    ]

    # Parse DATA and ANNO_NAS as datetimes. ANNO_NAS is offset by 1900 before
    # parsing (it looks like a two-digit birth year). Unparsable years become NaT.
    df_raw_data['DATA'] = pd.to_datetime(df_raw_data['DATA'])
    df_raw_data['ANNO_NAS'] = pd.to_datetime(
        df_raw_data['ANNO_NAS'] + 1900, format='%Y', errors='coerce'
    )

    # Age in fractional years, rounded to one decimal. -1 marks rows where
    # DATA or ANNO_NAS is missing.
    age_years = (df_raw_data['DATA'] - df_raw_data['ANNO_NAS']).dt.days / 365.25
    df_raw_data['Age_at_Menses'] = age_years.round(1).fillna(-1)

    # NOTE: only QUALIFI 1 and 4 are kept. Use an explicit .copy() so the
    # conversions below write to an independent frame, not a filtered view.
    df_in_C = df_raw_data[df_raw_data['QUALIFI'].isin([1, 4])].copy()

    # The conversions produce fractional values. Cast first, because writing
    # floats into integer columns is deprecated in recent pandas.
    df_in_C[temperature_columns] = df_in_C[temperature_columns].astype(float)

    is_celsius = df_in_C['TIPOTEMP'] == 2
    is_fahrenheit = df_in_C['TIPOTEMP'] == 1

    # TIPOTEMP == 2: stored value / 10 + 30 gives degrees Celsius.
    df_in_C.loc[is_celsius, temperature_columns] = (
        df_in_C.loc[is_celsius, temperature_columns] / 10 + 30
    )
    # TIPOTEMP == 1: stored value / 10 + 90 gives degrees Fahrenheit, which
    # are then converted to Celsius.
    df_in_C.loc[is_fahrenheit, temperature_columns] = (
        (df_in_C.loc[is_fahrenheit, temperature_columns] / 10 + 90 - 32) * 5 / 9
    )

    df_in_C = df_in_C.reset_index()

    return df_in_C, df_raw_data

def add_age_group(df_in_C, age_bins=None):
    """Label each row with a left-closed age bin derived from ``Age_at_Menses``.

    Parameters
    ----------
    df_in_C : pandas.DataFrame
        Must contain an ``'Age_at_Menses'`` column. It is modified in place.
    age_bins : sequence of numbers, optional
        Increasing bin edges passed to ``pd.cut``. The default is
        ``range(18, 51, 3)``, i.e. edges 18, 21, ..., 48. Ages of 48 or more
        therefore fall outside every bin.

    Returns
    -------
    pandas.DataFrame
        The same ``df_in_C`` object with a categorical ``'Age_Group'`` column.
        Labels look like ``"18-21"`` and mean the interval [18, 21). Ages
        outside the bins (e.g. negative values) get NaN.
    """
    # Use a None sentinel instead of a list default, so no mutable object is
    # shared between calls.
    if age_bins is None:
        age_bins = list(range(18, 51, 3))

    # One "lo-hi" label per consecutive pair of edges.
    age_labels = [f"{age_bins[i]}-{age_bins[i+1]}" for i in range(len(age_bins) - 1)]

    # right=False makes bins left-closed: [lo, hi).
    df_in_C['Age_Group'] = pd.cut(
        df_in_C['Age_at_Menses'], bins=age_bins, labels=age_labels, right=False
    )
    return df_in_C
    
def build_feature_df(df_in_C, rolling_window=False):
    """Extract per-cycle features for every participant and stack them.

    Parameters
    ----------
    df_in_C : pandas.DataFrame
        Needs ``'DONNA'``, ``'DATA'``, ``'Age_at_Menses'`` and ``'Age_Group'``
        columns, plus whatever ``concatenate_temperatures_w_dates`` reads.
    rolling_window : bool, default False
        If true, replace the temperature series with a trailing 3-sample
        rolling mean computed over its non-missing values.

    Returns
    -------
    pandas.DataFrame
        One row per extracted cycle window, merged with the participant's
        ``Age_at_Menses``/``Age_Group``. Rows without an ``Age_Group`` are
        dropped, and spaces in column names become underscores.

    Raises
    ------
    ValueError
        From ``pd.concat`` if every participant is skipped, leaving nothing
        to concatenate.
    """
    all_features = []

    # groupby visits each participant's rows once, instead of rescanning the
    # whole frame per participant. sort=False keeps first-appearance order,
    # matching .unique(). NaN IDs are dropped, as they were effectively skipped
    # before.
    for pat_id, df_pat in df_in_C.groupby('DONNA', sort=False):
        df_pat = df_pat.reset_index()

        # Skip participants with no rows or whose age could not be computed
        # (the -1 sentinel).
        if df_pat['Age_at_Menses'].empty or df_pat['Age_at_Menses'].iloc[0] == -1:
            continue

        concat_data = concatenate_temperatures_w_dates(df_pat)

        if rolling_window:
            concat_data['Temperature'] = (
                concat_data['Temperature'].dropna().rolling(window=3, min_periods=1).mean()
            )
        else:
            # NOTE(review): a no-op if concat_data is a DataFrame, because
            # assignment re-aligns on the index. Kept in case the helper returns
            # a different container.
            concat_data['Temperature'] = concat_data['Temperature'].dropna()

        # Cycle starts: sorted, non-missing dates. Keep the first one and any
        # that follow the previous date by at least 7 days.
        cycle_starts = df_pat['DATA'].sort_values().dropna()
        gaps = cycle_starts.diff().dt.days
        cycle_starts = cycle_starts[gaps.isna() | (gaps >= 7)]

        dynamic_features = extract_features_per_cycle(concat_data, cycle_starts, detrend=False)
        if dynamic_features.empty:
            # This frame has zero rows, so the scalar assignments below only
            # create empty columns. The participant contributes no rows.
            dynamic_features = pd.DataFrame(columns=['Window Start', 'Window End'] + feature_columns_all)
            dynamic_features["DONNA"] = pat_id
            dynamic_features["DATA"] = cycle_starts.iloc[0]
        else:
            dynamic_features['DONNA'] = pat_id
            # NOTE(review): this stores the participant ID under a column named
            # 'TemperatureValues', which looks like a copy-paste slip. Kept
            # unchanged because downstream code may reference the column.
            dynamic_features['TemperatureValues'] = pat_id

        # Attach age information by matching each window start to a DATA row.
        dynamic_features = dynamic_features.merge(
            df_pat[['DATA', 'DONNA', 'Age_at_Menses', 'Age_Group']],
            left_on=['DONNA', 'Window Start'],
            right_on=['DONNA', 'DATA'],
            how='left',
        )

        all_features.append(dynamic_features)

    feature_df = pd.concat(all_features, ignore_index=True)

    # Windows that did not match a dated row, or whose age is outside the
    # bins, have no Age_Group and are dropped.
    feature_df = feature_df.dropna(subset=['Age_Group'])
    feature_df = feature_df.rename(columns=lambda x: x.strip().replace(" ", "_"))
    feature_df = feature_df.rename(columns={
        'Relative_Amplitude_(RA)': 'Relative_Amplitude',
        'Intra-monthly_Variability_(IV)': 'Intra-monthly_Variability',
    })

    return feature_df


def build_three_way_tensor(feature_df, all_features, col_age='Age_at_Menses', participant_id='DONNA'):
    """Reshape long-format features into a (participants x ages x features) tensor.

    Parameters
    ----------
    feature_df : pandas.DataFrame
        Long-format data with ``participant_id``, ``col_age`` and every
        column named in ``all_features``.
    all_features : list of str
        Feature columns to include. Their order defines the last tensor axis.
    col_age : str
        Column whose distinct values form the second tensor axis.
    participant_id : str
        Column whose distinct values form the first tensor axis.

    Returns
    -------
    three_way_tensor : numpy.ndarray
        Shape ``(n_participants, n_ages, n_features)``. Duplicate
        (participant, age) rows are averaged, and missing combinations are NaN.
    three_way_table : pandas.DataFrame
        Wide table indexed by participant, with columns
        ``MultiIndex(feature, age)``.

    The tensor shape is also printed.
    """
    # Average duplicates so each (participant, age) pair is unique. Then move
    # the group keys back into columns: DataFrame.pivot looks `index` and
    # `columns` up as column labels and raises KeyError for index levels.
    feature_df_grouped = (
        feature_df.groupby([participant_id, col_age])[all_features].mean().reset_index()
    )

    three_way_table = feature_df_grouped.pivot(
        index=participant_id, columns=col_age, values=all_features
    )

    # Stack one (participants x ages) slab per feature: (features, participants, ages).
    three_way_tensor = np.array([three_way_table[feature].values for feature in all_features])

    # Move the feature axis last: (participants, ages, features).
    three_way_tensor = np.moveaxis(three_way_tensor, 0, -1)

    print(three_way_tensor.shape)

    return three_way_tensor, three_way_table

def create_fertility_groups(x, y):
    """Map a pair of boolean-like flags to a fertility-group label.

    The pairs ``(x, y)`` are compared with ``==`` against each known
    combination in turn:

    * ``(True, True)``   -> ``'Conceptive'``
    * ``(False, True)``  -> ``'Fertile_not_Conceptive'``
    * ``(True, False)``  -> ``'Error'``
    * ``(False, False)`` -> ``'NoneConceptive'``

    Anything else (e.g. ``None`` or NaN) yields ``'Unknown'``.
    """
    # Insertion order is the order in which combinations are tested.
    labels = {
        (True, True): 'Conceptive',
        (False, True): 'Fertile_not_Conceptive',
        (True, False): 'Error',
        (False, False): 'NoneConceptive',
    }
    for (want_x, want_y), label in labels.items():
        # Use == rather than `is` so that 1/0 and numpy booleans also match.
        if x == want_x and y == want_y:
            return label
    return 'Unknown'