from sr2021 import SR2021
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import copy


# This script draws a 2D color plot for each
# term in the equation for the zonal drifts.

def get_terms(season, solar_activity, hwm_version):
    """Evaluate the model zonal drift and its three terms on the SR2021 grid.

    The model is evaluated once per (altitude, local time) pair taken from
    ``SR2021().alts`` and ``SR2021().lts``; ``season``, ``solar_activity``
    and ``hwm_version`` are forwarded unchanged to ``SR2021.predict``.

    Returns a 4-tuple ``(ui, term1, term2, term3)`` of arrays shaped
    ``(len(alts), len(lts))``:

    * ``ui``    -- model zonal drifts, as a masked array with NaN cells masked
    * ``term1`` -- Pedersen weighted zonal neutral winds
    * ``term2`` -- Hall weighted meridional winds times the
      Hall-to-Pedersen conductance ratio
    * ``term3`` -- negative JRO vertical drifts times the
      Hall-to-Pedersen conductance ratio

    The three terms are plain arrays; cells without valid JRO drifts hold NaN.
    """
    model = SR2021()
    grid_shape = (len(model.alts), len(model.lts))

    # Every cell starts as NaN; only cells with valid JRO zonal AND
    # vertical drifts get overwritten below.
    drift = np.full(grid_shape, np.nan)
    pedersen_wind = np.full(grid_shape, np.nan)
    hall_wind = np.full(grid_shape, np.nan)
    vertical = np.full(grid_shape, np.nan)

    for row, alt in enumerate(model.alts):
        for col, lt in enumerate(model.lts):
            model.predict(lt, alt, season, solar_activity, hwm_version)

            # Where the JRO zonal or vertical drift is NaN, leave all
            # four grids at NaN for this cell
            if np.isnan(model.ui_jro) or np.isnan(model.wi_jro):
                continue

            cell = (row, col)

            # Model zonal drifts
            drift[cell] = model.ui

            # Pedersen weighted zonal neutral winds
            pedersen_wind[cell] = model.u_phi_pedersen

            # Hall weighted meridional winds multiplied by the
            # ratio of Hall-to-Pedersen conductance
            hall_wind[cell] = (model.hall / model.pedersen) * model.u_p_hall

            # JRO vertical drifts multiplied by the ratio of
            # Hall-to-Pedersen conductance
            vertical[cell] = -(model.hall / model.pedersen) * model.wi_jro

    masked_drift = np.ma.array(drift, mask=np.isnan(drift))
    return masked_drift, pedersen_wind, hall_wind, vertical


def term_comparison_plots(terms, titles):
    """Show four stacked altitude vs. local-time color plots, one per term.

    Parameters
    ----------
    terms : sequence of four 2D arrays
        Indexed as ``terms[0]`` .. ``terms[3]``; each is drawn on the
        ``(SR2021().lts, SR2021().alts)`` mesh, so it must be shaped
        ``(len(alts), len(lts))``.
    titles : sequence of four str
        Title for each panel, in the same order as ``terms``.

    Each panel gets a diverging color mesh (invalid cells drawn black),
    labelled black contour lines and its own colorbar. The figure is
    displayed with ``plt.show()``; nothing is returned.
    """
    x = SR2021()

    fig, axs = plt.subplots(4, 1, figsize=(8, 10))
    axs = axs.flatten()
    last = len(axs) - 1

    fontsize = 16

    # ``mpl.cm.get_cmap`` was deprecated in Matplotlib 3.7 and removed in
    # 3.9; ``plt.get_cmap`` is available across versions. Copy before
    # ``set_bad`` so a shared colormap instance is never modified.
    cmap = copy.copy(plt.get_cmap("RdBu_r"))
    cmap.set_bad(color='k')
    X, Y = np.meshgrid(x.lts, x.alts)

    # Color-scale limits for each panel
    vmins = [-120, -120, -10, -20]
    vmaxs = [120, 120, 10, 20]

    # Level spacing for contour lines. Every range includes its upper
    # bound so the levels are symmetric about zero.
    levels = [np.arange(-140, 141, 20),
              np.arange(-140, 141, 20),
              np.arange(-10, 11, 2),
              np.arange(-20, 21, 5)]

    for i, ax in enumerate(axs):
        # 2D Color plot
        im = ax.pcolormesh(X, Y, terms[i], vmin=vmins[i], vmax=vmaxs[i],
                           cmap=cmap, shading='nearest')

        # Contour lines
        cim = ax.contour(X, Y, terms[i], colors='k', levels=levels[i])
        ax.clabel(cim, inline=True, fontsize=10, fmt='%.2d')

        # Colorbar with ticks only at the limits and at zero
        cbar = fig.colorbar(im, ax=ax, ticks=[vmins[i], 0, vmaxs[i]],
                            aspect=10, fraction=.1, pad=.01)
        cbar.ax.tick_params(labelsize=fontsize - 2)

        ax.set_xlim(0, 24)
        ax.set_xticks(np.arange(0, 25, 2))

        ax.set_ylim(240, 560)
        ax.set_yticks(x.alts[::2])

        ax.tick_params(which='major', axis='both', labelsize=fontsize - 2)

        # Only the bottom panel carries the shared x-axis label
        if i == last:
            ax.set_xlabel('Local Time', fontsize=fontsize)
        ax.set_ylabel('Altitude (km)', fontsize=fontsize)
        ax.set_title(titles[i], fontsize=fontsize)

    fig.tight_layout()
    plt.show()


if __name__ == '__main__':
    # Single source of truth for the wind-model version, so the plot titles
    # cannot disagree with the data actually computed. The titles used to
    # say "HWM14" while the terms were computed with hwm_version=2007.
    # NOTE(review): presumably 2007 selects HWM07 in SR2021.predict -- if
    # HWM14 was intended, change hwm_version here instead.
    hwm_version = 2007
    hwm_label = 'HWM%02d' % (hwm_version % 100)

    jun_hsf_terms = get_terms('jun', 'hsf', hwm_version)
    term_comparison_plots(jun_hsf_terms,
                          [r'$U_i$ Model Drifts for June Solstice HSF (' + hwm_label + ')',
                           r'$U_{\phi}^P$ (' + hwm_label + ')',
                           r'$\frac{\Sigma_H}{\Sigma_P}U_L^H$ (' + hwm_label + ')',
                           r'-$\frac{\Sigma_H}{\Sigma_P}W_i$'])
