import numpy as np
import pandas as pd
import h5py


class SR2021:
    """Lookup interface for the SR2021 zonal drift model tables in HDF5.

    The model file is opened read-only on construction. Release the handle
    with :meth:`close`, or use the instance as a context manager
    (``with SR2021() as model: ...``).

    :meth:`predict` reads one (altitude, local time) cell from a
    season / solar-activity group and stores the results as instance
    attributes; it does not return anything.
    """

    # HWM version -> suffix used in the per-version HDF5 dataset names
    # ('Ui_model_XX', 'EPUPSXX', 'EHUHLXX').
    _HWM_SUFFIX = {1993: '93', 2007: '07', 2014: '14'}

    def __init__(self, filename='sr2021.hdf5'):
        """Open the model file and initialise the result attributes.

        Parameters
        ----------
        filename : str, optional
            Path to the SR2021 HDF5 file. Defaults to ``'sr2021.hdf5'``
            (resolved relative to the current working directory).
        """
        # File handle
        self.f = h5py.File(filename, 'r')

        # Altitude bins [km]
        self.alts = np.array([240, 280, 320, 360, 400, 440, 480, 520, 560])

        # Local time bins (0, 0.5, ..., 24); predict() wraps 24 onto 0
        self.lts = np.linspace(0, 24, 49, endpoint=True)

        # Seasonal Bin
        # December Solstice ('dec') [Jan, Feb, Nov, Dec]
        # Equinox ('equ') [Mar, Apr, Sep, Oct]
        # June Solstice ('jun') [May, Jun, Jul, Aug]
        self.season = None

        # Solar Activity Bin
        # High Solar Flux - HSF
        # Low Solar Flux - LSF
        self.solar_activity = None

        # Average Kp of Jicamarca Observations
        self.Kp = 2

        # Average F10.7 of Jicamarca Observations
        self.f107 = np.nan

        # HWM Version used to drive neutral wind dynamo
        self.hwm_version = np.nan

        # SR2021 Zonal Drift Model Prediction
        self.ui = np.nan

        # Jicamarca Climatological Drifts
        self.ui_jro = np.nan
        self.wi_jro = np.nan

        # Field line integrated quantities
        self.pedersen = np.nan
        self.hall = np.nan
        self.u_phi_pedersen = np.nan
        self.u_p_hall = np.nan

        # Field line data (pandas dataframe)
        self.fl_data = None

    def __enter__(self):
        """Return the instance itself for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the HDF5 file on leaving the ``with`` block.

        Exceptions raised inside the block are not suppressed.
        """
        self.close()
        return False

    def predict(self, lt, alt, season, solar_activity, hwm_version=2014):
        """Load model and observed values for one altitude / local-time cell.

        Parameters
        ----------
        lt : float
            Local time; must be a member of ``self.lts`` (0 to 24 in 0.5
            steps). 24 is wrapped onto 0 via ``np.mod(lt, 24)``.
        alt : int
            Altitude in km; must be a member of ``self.alts``.
        season : str
            ``'dec'``, ``'equ'`` or ``'jun'`` (case-insensitive).
        solar_activity : str
            ``'lsf'`` or ``'hsf'`` (case-insensitive).
        hwm_version : int, optional
            1993, 2007 or 2014 (default); selects which model datasets are read.

        Raises
        ------
        ValueError
            If any argument is outside its allowed set.

        Notes
        -----
        Results are written to ``ui``, ``ui_jro``, ``wi_jro``, ``pedersen``,
        ``hall``, ``u_phi_pedersen``, ``u_p_hall``, ``fl_data`` and ``f107``.
        ``season`` and ``solar_activity`` are stored exactly as passed.
        """
        # Check that input variables are allowable
        self._check_inputs(lt, alt, season, solar_activity, hwm_version)
        self.season = season
        self.solar_activity = solar_activity
        self.f107 = self._average_f107(season, solar_activity)
        self.hwm_version = hwm_version

        # Row = altitude bin, column = local-time bin.
        i = np.where(self.alts == alt)[0][0]
        j = np.where(self.lts == np.mod(lt, 24))[0][0]
        gset = '%s_%s' % (season.lower(), solar_activity.lower())
        grp = self.f[gset]
        suffix = self._HWM_SUFFIX[hwm_version]

        self.ui_jro = grp['Ui_jro'][i, j]
        self.wi_jro = grp['Wi_jro'][i, j]
        self.pedersen = grp['EP'][i, j]
        self.hall = grp['EH'][i, j]

        # Per-HWM-version model output; the wind terms are stored divided by
        # the corresponding field-line-integrated EP / EH values.
        self.ui = grp['Ui_model_' + suffix][i, j]
        self.u_phi_pedersen = grp['EPUPS' + suffix][i, j] / self.pedersen
        self.u_p_hall = grp['EHUHL' + suffix][i, j] / self.hall

        self.fl_data = pd.DataFrame(grp['fl_data']['%d_%d' % (i, j)][:],
                                    columns=self.f.attrs['column_names'])

    def _average_f107(self, season, solar_activity):
        """Return the average F10.7 of the Jicamarca observations for a bin.

        Both arguments are compared case-insensitively, consistent with
        ``_check_inputs`` and the HDF5 group lookup in ``predict``.
        Returns None for a combination outside the supported bins.
        """
        season = season.lower()
        solar_activity = solar_activity.lower()
        if solar_activity == 'lsf':
            return 85
        if solar_activity == 'hsf':
            if season in ('dec', 'equ'):
                return 150
            if season == 'jun':
                return 130
        return None

    def _check_inputs(self, lt, alt, season, solar_activity, hwm_version):
        """Raise ValueError if any ``predict`` argument is not an allowed bin."""
        if lt not in self.lts:
            raise ValueError('lt is type float in '
                             '[0.0, 0.5, 1.0, ..., 23.0, 23.5]')

        if alt not in self.alts:
            raise ValueError('alt is type int in '
                             '[240, 280, 320, 360, 400, 440, 480, 520, 560]')

        if season.lower() not in ['dec', 'equ', 'jun']:
            raise ValueError('season is type str \n'
                             '\'dec\' - [Jan, Feb, Nov, Dec]\n'
                             '\'equ\' - [Mar, Apr, Sep, Oct]\n'
                             '\'jun\' - [May, Jun, Jul, Aug]')

        if solar_activity.lower() not in ['lsf', 'hsf']:
            raise ValueError('solar_activity is type str\n'
                             '\'lsf\' - Low Solar Flux\n'
                             '\'hsf\' - High Solar Flux')

        if hwm_version not in self._HWM_SUFFIX:
            raise ValueError('hwm_version is type int in [1993, 2007, 2014]')

    def print_fl_key_description(self):
        """Print the file's 'field_line_key_descriptions' attribute."""
        print(self.f.attrs['field_line_key_descriptions'])

    def close(self):
        """Close the HDF5 file handle; calling it again is a no-op."""
        if self.f is not None:
            self.f.close()
            self.f = None


if __name__ == '__main__':
    # No command-line behaviour is implemented; running this module directly
    # does nothing.
    ...
