import numpy as np
import pandas as pd
import os
import sys
import copy
import pickle

from sklearn.pipeline import Pipeline
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import KFold
from sklearn.preprocessing import StandardScaler

from shrinkage_linear_regression import ShrinkageLinearRegression
from extract_features_percentile import FeatureExtraction


#%% configuration
# locations of the raw data and of the cache of the preprocessed data
data_dir = 'data'
cache_dir = 'cache'

# raw training data is split over ten CSV files, numbered 1 to 10
data_file_extension = '.csv'
data_files = ['ariel_train_data_%d%s' % (file_number, data_file_extension) for file_number in range(1, 10 + 1)]

# output files: cache of the preprocessed data and the trained models
cache_file = 'ariel_train_data_cache.pickle'
model_file = 'model.pickle'

# dimensions of the data set
n_spots = 10        # spot noise instances per planet
n_targets = 55      # target columns r1, ..., r55
n_wavelengths = 55  # wavelength channels w1, ..., w55
n_samples = 300     # time samples per wavelength channel


#%% load data from CSV files (or cache file when it exists)
# Builds 'data_avg_df' (one row per planet). The data frame is read from the
# pickle cache when it exists; otherwise it is computed from the raw CSV files
# and written to the cache afterwards.
data_cache = os.path.join(cache_dir, cache_file)
if os.path.isfile(data_cache):
    # load data from cache file
    print('cache file exists, loading from cache file')
    print('ATTENTION: make sure the cache file is up-to-date! if you change the data loading code, delete the cache file so that it is recreated!')

    # NOTE: unpickling can execute arbitrary code; only load cache files that were written by this script
    with open(data_cache, 'rb') as cache_fh:
        data_avg_df = pickle.load(cache_fh)
else:
    # create the cache directory up front so that a missing directory is
    # noticed before the expensive loading and not when the cache is written
    os.makedirs(cache_dir, exist_ok = True)

    # load data from CSV files
    print('loading data...')

    # the column layout is the same for every file, so the (large) column lists are built only once
    star_planet_columns = ['star_temp', 'star_logg', 'star_rad', 'star_mass', 'star_k_mag', 'period']
    r_columns = ['r' + str(r_idx) for r_idx in range(1, n_targets + 1)]
    w_columns = ['w' + str(wavelength_idx) + '-' + str(sample_idx) for wavelength_idx in range(1, n_wavelengths + 1) for sample_idx in range(1, n_samples + 1)]
    # new column hierarchy: spot->wavelength->sample
    reorder_columns_indices = [(spot_idx, 'w' + str(wavelength_idx) + '-' + str(sample_idx)) for spot_idx in range(1, n_spots + 1) for wavelength_idx in range(1, n_wavelengths + 1) for sample_idx in range(1, n_samples + 1)]
    reorder_columns_names = [str(spot_idx) + '-w' + str(wavelength_idx) + '-' + str(sample_idx) for spot_idx in range(1, n_spots + 1) for wavelength_idx in range(1, n_wavelengths + 1) for sample_idx in range(1, n_samples + 1)]

    # per-file results are collected and concatenated once at the end
    # (concatenating inside the loop copies the accumulated data again for every file)
    file_data_avg_dfs = []
    for data_file in data_files:
        print('loading file %s' % data_file)

        file_data_df = pd.read_csv(os.path.join(data_dir, data_file), index_col = 'file_idx')
        print('size of raw data: %.1fGB' % (sys.getsizeof(file_data_df)/1024/1024/1024))

        # preprocess input data
        # * calculate average over photon noise instances
        # * observations contain now all spot instances

        # average star/planet features and targets over photon noise instances and spot noise instances (they are the same for all photon/spot noise instances anyway! this is just cleaner code)
        file_data_avg_df = file_data_df.groupby(['planet_idx'])[star_planet_columns + r_columns].mean()

        # average wavelength time-series over photon noise instances and move the spot noise instances from rows to columns
        file_w_timeseries_df = file_data_df[['planet_idx', 'spot_noise_idx'] + w_columns].groupby(['planet_idx', 'spot_noise_idx']).mean().unstack()
        # after unstack() the column levels are (w column, spot); swap them to (spot, w column), sort and flatten the names
        file_w_timeseries_df = file_w_timeseries_df.swaplevel(0, 1, axis = 1)
        file_w_timeseries_df = file_w_timeseries_df[reorder_columns_indices]
        file_w_timeseries_df.columns = reorder_columns_names

        # join data frames
        file_data_avg_df = file_data_avg_df.join(file_w_timeseries_df) # columns: star_temp, star_logg, star_rad, star_mass, star_k_mag, period, r1, ..., r55, 1-w1-1, 1-w1-2, ..., 1-w2-1, 1-w2-2, ..., 2-w1-1, 2-w1-2, ..., 10-w55-300
        # spot/wavelength/sample are one-based indexed but column is a zero-based index!
        # spot: [1, 10], wavelength: [1, 55], sample: [1, 300]
        # column index = (spot-1)*16500 + (wavelength-1)*300 + sample + 60

        file_data_avg_dfs.append(file_data_avg_df)

    # release the per-file intermediates before the big concatenation
    file_data_df = None
    file_data_avg_df = None
    file_w_timeseries_df = None

    # verify_integrity makes the concatenation fail when a planet index occurs in more than one file
    data_avg_df = pd.concat(file_data_avg_dfs, verify_integrity = True)
    file_data_avg_dfs = None

    # write cache file
    print('writing cache file...')
    with open(data_cache, 'wb') as cache_fh:
        pickle.dump(data_avg_df, cache_fh)

print('size of total photon noise-averaged data: %.1fGB' % (sys.getsizeof(data_avg_df)/1024/1024/1024))

# data_avg_df:
# * rows correspond to planets, columns correspond to features
# * lightcurves are averaged over photon noise instances
# * columns are: star_temp, star_logg, star_rad, star_mass, star_k_mag, period, r1, ..., r55, 1-w1-1, 1-w1-2, ..., 1-w2-1, 1-w2-2, ..., 2-w1-1, 2-w1-2, ..., 10-w55-300
# * to access a specific spot noise instance, wavelength and sample, calculate the respective column index as:
#     column index = (spot-1)*16500 + (wavelength-1)*300 + sample + 60
#     where spot ranges from 1 to 10, wavelength from 1 to 55, sample from 1 to 300 (spot/wavelength/sample are one-based indexed but column is a zero-based index!)


#%% get targets
# the regression targets are the columns r1, ..., r<n_targets> of 'data_avg_df';
# 'targets' has one row per planet and one column per target
r_columns = ['r%d' % target_number for target_number in range(1, n_targets + 1)]
targets = data_avg_df[r_columns].values
assert targets.shape[1] == n_targets


#%% get input data and define column indices
# remove the target columns; the remaining columns are the planet/star features followed by all lightcurve samples
data_avg_df = data_avg_df.drop(r_columns, axis = 1)

# the column layout is derived from the configuration so that it cannot get out of sync with it
n_planet_star_features = 6 # star_temp, star_logg, star_rad, star_mass, star_k_mag, period
n_columns_per_spot = n_wavelengths * n_samples

# zero-based column indices of the planet/star features
planet_star_features_idcs = range(0, n_planet_star_features)

# lightcurve_idcs[spot][wavelength] is the range of zero-based column indices holding
# the samples of that lightcurve (spot and wavelength are zero-based here)
lightcurve_idcs = [[range(n_planet_star_features + spot_idx*n_columns_per_spot + wavelength_idx*n_samples,
                          n_planet_star_features + spot_idx*n_columns_per_spot + (wavelength_idx + 1)*n_samples)
                    for wavelength_idx in range(n_wavelengths)]
                   for spot_idx in range(n_spots)]

# sanity checks: the computed indices must point to the expected columns (first, middle and last lightcurve sample)
assert data_avg_df.columns[planet_star_features_idcs[0]] == 'star_temp'
assert data_avg_df.columns[planet_star_features_idcs[-1]] == 'period'
assert data_avg_df.columns[lightcurve_idcs[0][0][0]] == '1-w1-1'
mid_spot_idx, mid_wavelength_idx, mid_sample_idx = n_spots // 2, n_wavelengths // 2, n_samples // 2
assert data_avg_df.columns[lightcurve_idcs[mid_spot_idx][mid_wavelength_idx][mid_sample_idx]] == '%d-w%d-%d' % (mid_spot_idx + 1, mid_wavelength_idx + 1, mid_sample_idx + 1)
assert data_avg_df.columns[lightcurve_idcs[-1][-1][-1]] == '%d-w%d-%d' % (n_spots, n_wavelengths, n_samples)

input_data = data_avg_df.values
# the data frame is not needed anymore, drop the reference to it
data_avg_df = None

# input_data:
# * like 'data_avg_df' but without target columns r1-r55
# * columns are: star_temp, star_logg, star_rad, star_mass, star_k_mag, period, 1-w1-1, 1-w1-2, ..., 1-w2-1, 1-w2-2, ..., 2-w1-1, 2-w1-2, ..., 10-w55-300
# * to access a specific spot noise instance, wavelength and sample, calculate the respective column index as:
#   column index = (spot-1)*16500 + (wavelength-1)*300 + sample + 5
#   where spot ranges from 1 to 10, wavelength from 1 to 55, sample from 1 to 300 (spot/wavelength/sample are one-based indexed but column index is a zero-based index!)


#%% build signal processing pipeline
# the pipeline chains three stages: feature extraction -> standardization -> regression
feature_extractor = FeatureExtraction(lightcurve_idcs, planet_star_features_idcs)
preprocessor = StandardScaler()
reg_model = ShrinkageLinearRegression()

pipeline_steps = [
    ('fextraction', feature_extractor),
    ('preprocessor', preprocessor),
    ('reg', reg_model),
]
reg_pipe = Pipeline(pipeline_steps)


#%% train and evaluate pipeline
# one model is trained per target; each target is additionally scored with 5-fold cross-validation
cv = KFold(n_splits = 5)
reg_models = []
target_scores = []

# iterate over targets (one target per wavelength)
for target_idx in range(n_targets):
    print('training model for wavelength %d' % target_idx)
    target_values = targets[:, target_idx]

    # train model on all data and keep an independent copy of the fitted pipeline
    # ('reg_pipe' itself is fitted again for the next target)
    reg_pipe.fit(input_data, target_values)
    reg_models.append(copy.deepcopy(reg_pipe))

    # calculate R2 score, averaged over the cross-validation folds
    fold_r2_scores = cross_val_score(reg_pipe, input_data, target_values, scoring = 'r2', cv = cv, n_jobs = 1)
    target_scores.append(fold_r2_scores.mean())

# summary over all targets
mean_r2 = np.mean(target_scores)
std_r2 = np.std(target_scores)
print('average R2 score: %.2f +/- %.2f' % (mean_r2, std_r2))


#%% save model to disk
# the file contains two consecutive pickles: first the list of fitted
# pipelines, then the list of their cross-validated R2 scores
print('saving model to file...')
with open(model_file, 'wb') as model_fh:
    for pickled_object in (reg_models, target_scores):
        pickle.dump(pickled_object, model_fh)
