import numpy as np
import pandas as pd
import os
import sys
import pickle


#%% file configuration
# input: the data set is split over ten CSV files numbered 1..10
data_dir = 'data'
data_file_extension = '.csv'
data_files = ['ariel_test_data_%d%s' % (file_no, data_file_extension) for file_no in range(1, 11)]

# cache of the preprocessed data and the pickled regression models
cache_dir = 'cache'
cache_file = 'ariel_test_data_cache.pickle'
model_file = 'model.pickle'

# output
result_dir = './'
result_file = 'results.txt'

# data set dimensions
n_spots = 10        # spot noise instances per planet
n_targets = 55      # number of targets to estimate (one regression model each)
n_wavelengths = 55  # wavelength channels per observation
n_samples = 300     # time-series samples per wavelength channel


#%% load data from CSV files (or cache file when it exists)
# Result: data_avg_df with one row per planet_idx, holding the six star/planet
# features followed by the photon noise-averaged light curves of all spot noise
# instances, wavelengths and samples.
data_cache = os.path.join(cache_dir, cache_file)
if os.path.isfile(data_cache):
    # load data from cache file
    print('data cache file exists, loading from cache file')
    print('ATTENTION: make sure the cache file is up-to-date! if you change the data loading code, delete the cache file!')

    # load data from cache (the context manager closes the file handle again)
    # NOTE: unpickling can execute arbitrary code - only use cache files written by this script
    with open(data_cache, 'rb') as cache_fh:
        data_avg_df = pickle.load(cache_fh)
else:
    # load data from CSV files
    print('loading data...')

    # the column bookkeeping is identical for every file, so build it once up front
    star_planet_columns = ['star_temp', 'star_logg', 'star_rad', 'star_mass', 'star_k_mag', 'period']
    w_columns = ['w' + str(wavelength_idx) + '-' + str(sample_idx) for wavelength_idx in range(1, n_wavelengths + 1) for sample_idx in range(1, n_samples + 1)]
    # new column hierarchy: spot->wavelength->sample
    reorder_columns_indices = [(spot_idx, w_column) for spot_idx in range(1, n_spots + 1) for w_column in w_columns]
    reorder_columns_names = [str(spot_idx) + '-' + w_column for spot_idx in range(1, n_spots + 1) for w_column in w_columns]

    file_data_avg_dfs = []
    for data_file in data_files:
        print('loading file %s' % data_file)

        file_data_df = pd.read_csv(os.path.join(data_dir, data_file), index_col = 'file_idx')
        print('size of raw data: %.1fGB' % (sys.getsizeof(file_data_df)/1024/1024/1024))

        # preprocess input data
        # * calculate average over photon noise instances
        # * observations contain now all spot instances

        # average star/planet features over photon noise instances and spot noise instances (they are the same for all photon/spot noise instances anyway! this is just cleaner code)
        # the columns are selected with a list: selecting with a tuple of keys raises in pandas >= 2.0
        file_data_avg_df = file_data_df.groupby(['planet_idx'])[star_planet_columns].mean()

        # average wavelength time-series over photon noise instances
        file_w_timeseries_df = file_data_df[['planet_idx', 'spot_noise_idx'] + w_columns].groupby(['planet_idx', 'spot_noise_idx']).mean().unstack()
        # unstack() puts the spot level last; move it to the front, then order and flatten the columns
        file_w_timeseries_df = file_w_timeseries_df.swaplevel(0, 1, axis = 1)
        file_w_timeseries_df = file_w_timeseries_df[reorder_columns_indices]
        file_w_timeseries_df.columns = reorder_columns_names

        # join data frames
        file_data_avg_dfs.append(file_data_avg_df.join(file_w_timeseries_df)) # columns: star_temp, star_logg, star_rad, star_mass, star_k_mag, period, 1-w1-1, 1-w1-2, ..., 1-w2-1, 1-w2-2, ..., 2-w1-1, 2-w1-2, ..., 10-w55-300
        # spot/wavelength/sample are one-based indexed but column is a zero-based index!
        # spot: [1, 10], wavelength: [1, 55], sample: [1, 300]
        # column index = (spot-1)*16500 + (wavelength-1)*300 + sample + 5
        # (the six star/planet feature columns occupy the column indices 0..5)

    # concatenate once at the end instead of re-copying the growing frame for every file
    # (data_avg_df stays None when there are no data files, as before)
    data_avg_df = pd.concat(file_data_avg_dfs, verify_integrity = True) if file_data_avg_dfs else None

    # release the intermediate frames
    file_data_avg_dfs = None
    file_data_df = None
    file_data_avg_df = None

    # write cache file (create the cache directory first so that the expensive load above is not lost)
    print('writing cache file...')
    if not os.path.isdir(cache_dir):
        os.makedirs(cache_dir)
    with open(data_cache, 'wb') as cache_fh:
        pickle.dump(data_avg_df, cache_fh)

print('size of total photon noise-averaged data: %.1fGB' % (sys.getsizeof(data_avg_df)/1024/1024/1024))


#%% get input data and define column indices
# Column layout of data_avg_df: the star/planet features come first, then the
# light curve samples ordered spot -> wavelength -> sample. The strides are
# derived from the configured dimensions instead of being hard-coded.
n_planet_star_features = 6  # star_temp, star_logg, star_rad, star_mass, star_k_mag, period
n_columns_per_spot = n_wavelengths * n_samples

planet_star_features_idcs = range(0, n_planet_star_features)

# lightcurve_idcs[spot][wavelength] is the range of column indices holding the
# samples of that light curve (all three levels are zero-based here)
lightcurve_idcs = []
for spot_offset in range(n_spots):
    spot_start = n_planet_star_features + spot_offset * n_columns_per_spot
    lightcurve_idcs.append([range(spot_start + wavelength_offset * n_samples, spot_start + (wavelength_offset + 1) * n_samples) for wavelength_offset in range(n_wavelengths)])

# sanity checks: the computed indices must hit the expected column names
# (catches e.g. a stale cache file with a different column layout)
assert data_avg_df.columns[planet_star_features_idcs[0]] == 'star_temp', 'unexpected first feature column'
assert data_avg_df.columns[planet_star_features_idcs[-1]] == 'period', 'unexpected last feature column'
# probes are one-based (spot, wavelength, sample) triples: first column, one in the middle, last column
for probe_spot, probe_wavelength, probe_sample in [(1, 1, 1), (5, 21, 42), (n_spots, n_wavelengths, n_samples)]:
    probe_name = '%d-w%d-%d' % (probe_spot, probe_wavelength, probe_sample)
    probe_column = lightcurve_idcs[probe_spot - 1][probe_wavelength - 1][probe_sample - 1]
    assert data_avg_df.columns[probe_column] == probe_name, 'unexpected light curve column layout at %s' % probe_name

# keep only the raw values and release the data frame
input_data = data_avg_df.values
data_avg_df = None


#%% ensure that input data are ordered by planet idx
# TODO: not implemented yet (so far the rows have always come out ordered by planet idx anyway)


#%% load pipeline from file
# The model file contains two consecutive pickles: the regression models
# (indexed by target) followed by their scores.
# NOTE: unpickling can execute arbitrary code - only load model files from a trusted source
print('loading model from file...')
with open(model_file, 'rb') as model_fh:
    reg_models = pickle.load(model_fh)
    target_scores = pickle.load(model_fh)

mean_target_score = np.mean(target_scores)
print('average R2 score: %f' % mean_target_score)

#%% estimate targets
# One prediction is made per planet and target; every prediction is then
# written to 100 consecutive result rows.
# n_targets comes from the file configuration and is not redefined here.
n_rows_per_planet = 100  # presumably spot noise x photon noise instances per planet - TODO confirm
n_observations = input_data.shape[0]
estimated_radii = np.zeros((n_observations * n_rows_per_planet, n_targets))
print('estimating target radii...')
for target_idx in range(n_targets):
    print('wavelength %d' % target_idx)
    prediction_per_planet = reg_models[target_idx].predict(input_data)
    # repeat each planet's prediction n_rows_per_planet times in a row: p0, p0, ..., p1, p1, ...
    prediction_per_planet_spotnoise_photon_noise = np.repeat(prediction_per_planet, n_rows_per_planet)
    estimated_radii[:, target_idx] = prediction_per_planet_spotnoise_photon_noise


#%% write estimated targets to harddisk
# Writes the estimates as tab-separated values without header and index,
# one line per result row, with '\n' line endings.
print('saving results to CSV file...')
result_df = pd.DataFrame(estimated_radii)
result_path = os.path.join(result_dir, result_file)
# add float_format = '%.13f' here to limit the number of printed decimals
result_csv_options = dict(sep = '\t', header = False, index = False, encoding = 'ascii')
try:
    # pandas >= 1.5 spells the keyword 'lineterminator' (the old spelling was removed in 2.0)
    result_df.to_csv(result_path, lineterminator = '\n', **result_csv_options)
except TypeError:
    # older pandas only knows 'line_terminator' and rejects the new keyword before writing anything
    result_df.to_csv(result_path, line_terminator = '\n', **result_csv_options)
print('done')
