import numpy as np
from sklearn.base import BaseEstimator
from sklearn.base import TransformerMixin


#%% feature extraction class
class FeatureExtraction(BaseEstimator, TransformerMixin):
    """Turn raw light-curve columns into per-curve depth features.

    For every (spot, wavelength) light curve the transformer takes the 1st,
    5th and 10th percentile of its samples, subtracts each from a fixed
    baseline of 1.0, clips the difference at 0 and takes the square root.
    Those 3 columns per light curve are appended, in spot-major /
    wavelength-minor order, after the planet/star feature columns copied
    from ``X``.

    NOTE(review): the fixed 1.0 baseline presumably assumes light curves are
    normalized so that out-of-transit flux is 1.0 -- confirm with the data
    preparation code.

    Parameters
    ----------
    lightcurve_idcs : nested sequence
        ``lightcurve_idcs[spot][wavelength]`` holds the column indices of one
        light curve in ``X``; e.g. spot 5, wavelength 20, sample 100 is
        ``lightcurve_idcs[4][19][99]``. The number of spots and wavelengths
        is taken from this structure (the original layout was 10 x 55).
    planet_star_features_idcs : index
        Column index range of the planet/star features in ``X``
        (``X[:, planet_star_features_idcs]``); they lead the output.
    """

    # Percentiles used as the "lower levels" of each light curve; the output
    # column order within one light curve follows this tuple.
    _PERCENTILES = (1, 5, 10)

    def __init__(self, lightcurve_idcs, planet_star_features_idcs):
        # Stored unchanged under the parameter names so that sklearn's
        # get_params()/clone() can read them back.
        self.lightcurve_idcs = lightcurve_idcs
        self.planet_star_features_idcs = planet_star_features_idcs

    def fit(self, X, y=None):
        """Do nothing: the extraction has no learned state. Return ``self``."""
        return self

    def transform(self, X):
        """Build the feature matrix from raw data.

        Parameters
        ----------
        X : numpy.ndarray
            Array of shape [observations x raw-features].

        Returns
        -------
        numpy.ndarray
            The planet/star feature columns followed by 3 depth columns
            (1st, 5th, 10th percentile) for every light curve, iterated
            spot by spot and, within a spot, wavelength by wavelength.
        """
        # float64 baseline array (same dtype as the former np.tile(1.0, n)),
        # built once instead of once per light curve.
        upper_level = np.ones(X.shape[0])

        # Collect column blocks and concatenate once at the end; concatenating
        # inside the loop re-copied the growing matrix for every light curve.
        blocks = [X[:, self.planet_star_features_idcs]]
        for spot_curves in self.lightcurve_idcs:
            for curve_idcs in spot_curves:
                blocks.append(self._depth_features(X[:, curve_idcs], upper_level))
        return np.concatenate(blocks, axis=1)

    def _depth_features(self, curves, upper_level):
        """Return sqrt(max(upper_level - percentile, 0)) per percentile, shape (n_obs, 3)."""
        # One percentile call for all three levels -> shape (3, n_obs).
        lower_levels = np.percentile(curves, q=self._PERCENTILES, axis=1)
        # A percentile above the baseline would make sqrt return NaN (plus a
        # RuntimeWarning), which downstream estimators reject; treat it as
        # zero depth instead. NaN inputs still propagate through np.maximum.
        depths = np.maximum(upper_level - lower_levels, 0.0)
        return np.sqrt(depths).T
