from sr2021 import SR2021
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import copy

# This script plots 2D color maps (local time vs. altitude) of the model
# zonal drifts for each HWM version, together with the Jicamarca
# climatological zonal drifts, for a given season and solar flux condition.


def get_zonal_drifts(season, solar_activity):
    """
    Evaluate JRO and model zonal drifts on the SR2021 altitude/local-time grid.

    At every (altitude, local time) grid point the SR2021 model is run once per
    HWM version (2014, 2007, 1993, in that order), and the JRO climatological
    value is read afterwards.

    :param season: season identifier forwarded to ``SR2021.predict``
        (the script's entry point uses ``'equ'``).
    :param solar_activity: solar-flux identifier forwarded to
        ``SR2021.predict`` (the script's entry point uses ``'hsf'``).
    :return: tuple ``(ui_jro, ui_hwm14, ui_hwm07, ui_hwm93)`` of 2D arrays of
        shape ``(len(alts), len(lts))``; rows follow the model's ``alts`` and
        columns follow its ``lts``.
    """
    model = SR2021()
    grid_shape = (len(model.alts), len(model.lts))

    # HWM versions in the order they are evaluated at each grid point.
    hwm_versions = (2014, 2007, 1993)
    model_drifts = {version: np.empty(grid_shape) for version in hwm_versions}
    jro_drifts = np.empty(grid_shape)

    for row, altitude in enumerate(model.alts):
        for col, local_time in enumerate(model.lts):
            for version in hwm_versions:
                model.predict(local_time, altitude, season, solar_activity,
                              hwm_version=version)
                model_drifts[version][row, col] = model.ui

            # ui_jro is read only after predict() has run for this grid point;
            # assumes it does not depend on hwm_version -- TODO confirm.
            jro_drifts[row, col] = model.ui_jro

    return (jro_drifts,
            model_drifts[2014],
            model_drifts[2007],
            model_drifts[1993])


def drift_plots(drifts, titles):
    """
    Plot each drift field as a local-time / altitude color map with contours.

    One subplot row is created per entry in ``drifts``. The local-time and
    altitude coordinates come from an ``SR2021`` instance, so each array in
    ``drifts`` is expected to have shape ``(len(alts), len(lts))``, as
    returned by ``get_zonal_drifts``.

    :param drifts: non-empty sequence of 2D arrays indexed
        ``[altitude, local_time]``.
    :param titles: one title per array in ``drifts``; extra titles are ignored.
    :raises ValueError: if ``drifts`` is empty or there are fewer titles than
        drift arrays.
    """
    n_panels = len(drifts)
    if n_panels == 0:
        raise ValueError("drifts must contain at least one array")
    if len(titles) < n_panels:
        raise ValueError(
            f"need one title per drift array: got {len(titles)} titles "
            f"for {n_panels} arrays")

    x = SR2021()

    # 2.5 inches per panel reproduces the 8x10 figure for four panels;
    # squeeze=False keeps ``axs`` a 2D array even when there is one panel.
    fig, axs = plt.subplots(n_panels, 1, figsize=(8, 2.5 * n_panels),
                            squeeze=False)
    axs = axs.ravel()

    fontsize = 16
    # ``matplotlib.cm.get_cmap`` was deprecated in 3.7 and removed in 3.9, so
    # prefer the ``matplotlib.colormaps`` registry (3.5+) and fall back only
    # on older releases. Copy so set_bad() does not modify a shared colormap.
    try:
        cmap = copy.copy(mpl.colormaps["RdBu_r"])
    except AttributeError:
        cmap = copy.copy(mpl.cm.get_cmap("RdBu_r"))
    # Bad (masked) values are drawn in black.
    cmap.set_bad(color='k')
    X, Y = np.meshgrid(x.lts, x.alts)

    for i, (ax, data) in enumerate(zip(axs, drifts)):
        # 2D colorplot
        im = ax.pcolormesh(X, Y, data, vmin=-150, vmax=150, cmap=cmap,
                           shading='nearest')

        # Contour lines every 20 units between -140 and 140
        cim = ax.contour(X, Y, data, colors='k',
                         levels=np.arange(-140, 141, 20))
        ax.clabel(cim, inline=True, fontsize=10, fmt='%.2d')

        # Colorbar
        cbar = fig.colorbar(im, ax=ax, ticks=[-150, 0, 150], aspect=10,
                            fraction=.1, pad=.01)
        cbar.ax.tick_params(labelsize=fontsize - 2)

        ax.set_xlim(0, 24)
        ax.set_xticks(np.arange(0, 25, 2))

        ax.set_ylim(240, 560)
        ax.set_yticks(x.alts[::2])

        ax.tick_params(which='major', axis='both', labelsize=fontsize - 2)

        # Only the bottom panel carries the shared x-axis label.
        if i == n_panels - 1:
            ax.set_xlabel('Local Time', fontsize=fontsize)
        ax.set_ylabel('Altitude (km)', fontsize=fontsize)
        ax.set_title(titles[i], fontsize=fontsize)

    fig.tight_layout()
    plt.show()


if __name__ == '__main__':
    # Equinox, high solar flux: JRO climatology followed by the three HWM runs.
    season, solar_flux = 'equ', 'hsf'
    drift_fields = get_zonal_drifts(season, solar_flux)

    panel_titles = [
        'Equinox HSF JRO Zonal Drifts',
        'Equinox HSF Model Zonal Drifts (HWM14)',
        'Equinox HSF Model Zonal Drifts (HWM07)',
        'Equinox HSF Model Zonal Drifts (HWM93)',
    ]
    drift_plots(drift_fields, panel_titles)
