from sr2021 import SR2021
from scipy.interpolate import griddata
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import copy

# This script plots SR2021 field-line values of the Pedersen
# conductivity, the Hall conductivity, and the HWM14 geographic
# northward and zonal winds as functions of magnetic latitude
# and altitude, for a given local time (LT).


def _sample_field_lines(keys, lt, season, solar_activity):
    """Run SR2021 once per model altitude and collect the requested fields.

    Parameters
    ----------
    keys : sequence of str
        ``fl_data`` fields to collect (e.g. 'SIG_PED', 'VN14').
    lt, season, solar_activity
        Passed straight through to ``SR2021.predict``.

    Returns
    -------
    points : ndarray, shape (n, 2)
        Columns are ``fl_data['MAG_LAT']`` and ``fl_data['ALT']`` for every
        sample, concatenated over all model altitudes.
    columns : dict
        Maps each key to a flat array of its values, aligned with ``points``.
    """
    model = SR2021()

    # Seed each list with an empty float array so the concatenated dtype is
    # the same as incrementally np.append-ing onto np.array([]).
    mag_lat_chunks = [np.array([])]
    alt_chunks = [np.array([])]
    value_chunks = {key: [np.array([])] for key in keys}

    for alt in model.alts:
        model.predict(lt, alt, season, solar_activity)
        fl = model.fl_data
        mag_lat_chunks.append(np.ravel(fl['MAG_LAT']))
        alt_chunks.append(np.ravel(fl['ALT']))
        for key, chunks in value_chunks.items():
            chunks.append(np.ravel(fl[key]))

    # One concatenate per column instead of a quadratic np.append loop.
    points = np.vstack((np.concatenate(mag_lat_chunks),
                        np.concatenate(alt_chunks))).T
    columns = {key: np.concatenate(chunks) for key, chunks in value_chunks.items()}
    return points, columns


def get_points_values(key, lt, season, solar_activity):
    """Return the scattered SR2021 samples of one field-line quantity.

    Parameters
    ----------
    key : str
        ``fl_data`` field to extract.
    lt, season, solar_activity
        Passed straight through to ``SR2021.predict``.

    Returns
    -------
    points : ndarray, shape (n, 2)
        (magnetic latitude, altitude) of each sample.
    values : ndarray, shape (n,)
        Values of ``key`` at ``points``.
    """
    points, columns = _sample_field_lines([key], lt, season, solar_activity)
    return points, columns[key]


_EARTH_RADIUS_KM = 6371


def get_data(keys, lt, season, solar_activity, x_interp, y_interp,
             min_apex_km=240):
    """Interpolate SR2021 field-line quantities onto a (lat, alt) grid.

    The model is evaluated only once for all ``keys``, and each field is
    cubic-interpolated with ``scipy.interpolate.griddata``.

    Parameters
    ----------
    keys : iterable of str
        ``fl_data`` fields to grid.
    lt, season, solar_activity
        Passed straight through to ``SR2021.predict``.
    x_interp, y_interp : ndarray
        Magnetic latitude (degrees) and altitude (km) grids of equal shape.
    min_apex_km : float, optional
        Grid points whose field-line apex height is below this value are set
        to NaN. The default of 240 is the previously hard-coded cutoff,
        presumably the apex of the lowest modeled field line.

    Returns
    -------
    list of ndarray
        One array per key, with the shape of ``x_interp``. Points outside
        the data's convex hull or below the cutoff are NaN.
    """
    keys = list(keys)  # iterated twice below, so accept any iterable
    if not keys:
        return []

    points, columns = _sample_field_lines(keys, lt, season, solar_activity)

    # Apex height of the field line through each grid point, using the dipole
    # relation r = L cos^2(lat): apex = (R_E + alt) / cos^2(lat) - R_E.
    # This does not depend on the key, so it is computed once.
    apex_height = ((_EARTH_RADIUS_KM + y_interp)
                   / np.cos(np.radians(x_interp))**2) - _EARTH_RADIUS_KM
    inside_lowest_fl = apex_height < min_apex_km

    data = []
    for key in keys:
        grid = griddata(points, columns[key],
                        (x_interp, y_interp),
                        method='cubic')
        # Remove interpolated values "inside" the lowest field line.
        grid[inside_lowest_fl] = np.nan
        data.append(grid)
    return data


def fl_plots(fl_data, titles, lt, x_interp, y_interp, label='Equinox HSF'):
    """Plot four gridded field-line quantities in a 2x2 figure and show it.

    Parameters
    ----------
    fl_data : sequence of 2D arrays
        Panel data in the order log10(Pedersen), log10(Hall), northward wind,
        zonal wind. The colour limits and contour levels below are hard-coded
        for that order. NaN cells are drawn in black.
    titles : sequence of str
        Panel titles, in the same order as ``fl_data``.
    lt : float
        Local time, shown in the figure title.
    x_interp, y_interp : 2D arrays
        Magnetic latitude and altitude (km) grids matching ``fl_data``.
    label : str, optional
        Prefix of the figure title. Defaults to the previously hard-coded
        'Equinox HSF'.
    """
    fontsize = 16

    # Copy the colormap so set_bad does not modify the shared registered one.
    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap
    # is still available.
    cmap = copy.copy(plt.get_cmap("RdBu_r"))
    cmap.set_bad(color='k')

    # Colour-scale limits for each panel
    vmins = [-6, -6, -40, -120]
    vmaxs = [-3, -3, 40, 120]

    # Level spacing for contour lines
    levels = [np.arange(-6, -3.1, .5),
              np.arange(-6, -3.1, .5),
              np.arange(-40, 41, 10),
              np.arange(-120, 121, 20)]

    # Contour-label formats: fractional for the log10 panels, integer for winds
    label_fmts = ['%.2f', '%.2f', '%d', '%d']

    fig, axs = plt.subplots(2, 2, figsize=(16, 10))

    for i, ax in enumerate(axs.flatten()):
        vmin, vmax = vmins[i], vmaxs[i]

        # 2D colour plot
        im = ax.pcolormesh(x_interp, y_interp, fl_data[i], vmin=vmin, vmax=vmax,
                           cmap=cmap, shading='nearest')

        # Contour lines
        cim = ax.contour(x_interp, y_interp, fl_data[i], colors='k', levels=levels[i])
        ax.clabel(cim, inline=True, fontsize=8, fmt=label_fmts[i])

        # Colorbar with ticks at both ends and the centre. A fixed 0 tick
        # would lie outside the log10 ranges [-6, -3].
        cbar = fig.colorbar(im, ax=ax, ticks=[vmin, (vmin + vmax) / 2, vmax],
                            aspect=10, fraction=.1, pad=.01)
        cbar.ax.tick_params(labelsize=fontsize - 2)

        ax.set_ylim(100, 500)
        ax.set_yticks(np.arange(100, 501, 50))

        ax.tick_params(which='major', axis='both', labelsize=fontsize - 2)

        ax.set_xlabel('Magnetic Latitude', fontsize=fontsize)
        ax.set_ylabel('Altitude (km)', fontsize=fontsize)
        ax.set_title(titles[i], fontsize=fontsize)

    fig.suptitle('%s at %.1f LT' % (label, lt), fontsize=fontsize, y=.99)

    fig.tight_layout()
    plt.show()


if __name__ == '__main__':
    LT = 18
    # Regular grid: magnetic latitude -15..15 deg (200 points) x altitude 100..500 km (2 km step)
    mag_lat_interp, alt_interp = np.meshgrid(np.linspace(-15, 15, 200, endpoint=True),
                                             np.arange(100, 501, 2))

    equ_hsf_fl_data = get_data(['SIG_PED', 'SIG_HAL', 'VN14', 'UN14'],
                               lt=LT,
                               season='equ',
                               solar_activity='hsf',
                               x_interp=mag_lat_interp,
                               y_interp=alt_interp)

    # Conductivities span several orders of magnitude, so plot their log10.
    # Non-positive interpolated values (if any) become NaN/-inf here.
    for idx in (0, 1):
        equ_hsf_fl_data[idx] = np.log10(equ_hsf_fl_data[idx])

    # NOTE(review): conductivity labels previously said cm^-3, which is a
    # density unit. S/m matches the plotted log10 range of -6..-3 but should
    # be confirmed against the SR2021 output units.
    fl_plots(equ_hsf_fl_data,
             [r'$\log_{10}$($\sigma_P$ [S/m])',
              r'$\log_{10}$($\sigma_H$ [S/m])',
              r'Geographic Northward Winds [m/s] (HWM14)',
              r'Geographic Zonal Winds [m/s] (HWM14)'],
             LT,
             mag_lat_interp,
             alt_interp)
