#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Fri Feb 23 10:51:54 2018

@author: shihweif
"""

import numpy as np 
from matplotlib import pyplot as plt
from mpl_toolkits.basemap import Basemap, shiftgrid
#from MEOF import MEOF
#from EOF import EOF
#from ModelResults import ModelResults

class DrawFutureProjections:
    """Draw multi-experiment comparison figures for the future-projection runs.

    Constructor arguments (positional, kept for backward compatibility):
        args[0] -- sequence of per-experiment model-result objects. Each one
                   must provide ``getSSTTimeseries``, ``getENSOToSFRelations``
                   and ``getSSTPropagationsFromSF`` plus ``years`` and
                   ``experiment`` attributes (as used by the methods below).
        args[1] -- list of experiment names; its length sets how many entries
                   of args[0] are plotted, and the names label the bars in
                   plotNaturePaperFigure3.

    Keyword arguments:
        Moderate -- stored as ``self.Moderate`` (default False); the plotting
                    methods in this class do not read it.

    Both plotting methods save PNG files under the relative directory
    'Figure/' and load the font 'fonts/Georgia.ttf', both resolved against
    the current working directory.
    """
    # Analysis-domain bounds (presumably degrees latitude/longitude -- confirm)
    # and the matching grids at 1-unit latitude and 1.5-unit longitude
    # spacing. These name-mangled attributes are not used by the methods below.
    __lat_bnd = [-20, 20]
    __lon_bnd = [122, 290]
    __lat  = np.arange(__lat_bnd[1]-(__lat_bnd[0])+1)+__lat_bnd[0]
    __lon  = np.arange((__lon_bnd[1]-__lon_bnd[0])/1.5+1)*1.5+__lon_bnd[0]

    # Offset (in x-axis years) added to each experiment's time axis in
    # plotNaturePaperFigure3; x = 100 carries the '1861' tick label there.
    # Experiments not listed here are drawn without an offset.
    _EXPERIMENT_X_OFFSET = {'historical': 100,
                            'RCP45': 100 + 145,
                            'RCP85': 100 + 145}

    # Shared x ticks and tick labels of the two SST time-series panels.
    _SST_XTICKS = [0, 50, 100, (100+39), (100+89), (100+145-5), (100+145+50-5), (100+145+94)]
    _SST_XLABELS = ['-100', '-50', '1861', '1900', '1950', '2000', '2050', '2100']

    def __init__(self, *args, **kwargs):
        self.allmodelPlots = args[0]
        self.experiments = args[1]
        self.numOfExperiments = len(args[1])
        self.Moderate = kwargs.get('Moderate', False)
        # Non-interactive backend: figures are only written to disk.
        plt.switch_backend('agg')

    @classmethod
    def _time_axis(cls, years, experiment):
        """Return monthly x coordinates (in years) for one experiment's series.

        ``years * 12`` points spaced 1/12 apart, shifted by the experiment's
        entry in _EXPERIMENT_X_OFFSET (0 for unknown experiments).
        """
        X = np.arange(years*12.)/12.
        return X + cls._EXPERIMENT_X_OFFSET.get(experiment, 0)

    @staticmethod
    def _experiment_color(i):
        """Return the Matplotlib 'CN' colour for experiment i ('C0'..'C9', wrapping)."""
        return 'C{}'.format(i % 10)

    @staticmethod
    def _set_legend_linewidth(legend, linewidth):
        """Set the line width of every handle shown in a legend."""
        # Matplotlib 3.7 renamed Legend.legendHandles to legend_handles and
        # 3.9 removed the old name; support both spellings.
        handles = getattr(legend, 'legend_handles', None)
        if handles is None:
            handles = legend.legendHandles
        for legobj in handles:
            legobj.set_linewidth(linewidth)

    @staticmethod
    def _plot_mean_std(ax, X, data, label, color):
        """Plot the NaN-aware mean of ``data`` over axis 0 with a +/-1 std band.

        Assumes ``data`` is 2-D with the time axis last, matching len(X)
        (presumably members x months -- confirm against getSSTTimeseries).
        """
        yMeans = np.nanmean(data, axis=0)
        yStds = np.nanstd(data, axis=0)
        ax.plot(X, yMeans, label=label, color=color, linewidth=0.6)
        ax.fill_between(X, yMeans-yStds, yMeans+yStds, alpha=0.5, facecolor=color)

    def _style_sst_axis(self, ax, panel, title, X_obs, sst_obs, prop, yticks=None):
        """Decorate one SST time-series panel and overlay the reanalysis series.

        Args:
            ax: target axes.
            panel: panel tag drawn in the upper-left corner, e.g. '(a)'.
            title: axes title.
            X_obs, sst_obs: reanalysis x coordinates and values (drawn in black).
            prop: FontProperties used for titles, labels and tick labels.
            yticks: explicit y ticks labelled with one decimal; when None the
                automatically chosen ticks are frozen and labelled instead.
        """
        ax.text(.01, 1.12, panel, horizontalalignment='left', verticalalignment='top', fontsize=8, transform=ax.transAxes)
        ax.set_title(title, fontsize=12, fontproperties=prop)
        ax.set_xlabel('Year', fontsize=10, fontproperties=prop)
        ax.set_ylabel(r'SST ($^\circ$C)', fontsize=10, fontproperties=prop)
        ax.plot(X_obs, sst_obs, label='Reanalysis', color='k', linewidth=0.4)
        ax.axhline(y=28, linestyle=':', color='k', linewidth=1)
        ax.set_xticks(self._SST_XTICKS)
        ax.set_xticklabels(self._SST_XLABELS, fontproperties=prop)
        if yticks is None:
            # Fixed labels on an automatic locator can drift away from their
            # ticks (Matplotlib >= 3.3 warns), so pin the current ticks first.
            # The locator may return ticks just outside the view and
            # set_yticks expands the view to show them, so restore the limits.
            ylim = ax.get_ylim()
            autoTicks = ax.get_yticks()
            ax.set_yticks(autoTicks)
            ax.set_ylim(ylim)
            ax.set_yticklabels(autoTicks, fontproperties=prop)
        else:
            ax.set_yticks(yticks)
            ax.set_yticklabels(["{:.1f}".format(ytick) for ytick in yticks], fontproperties=prop)
        ax.tick_params(labelsize=8)
        legend = ax.legend(loc=2, fontsize=6, ncol=3)
        self._set_legend_linewidth(legend, 2.0)

    def plotNaturePaperFigure3(self):
        """Draw the SST / ENSO-tendency comparison and save 'Figure/Figure2.png'.

        Panels (a) and (b): for each experiment, the mean over axis 0 of the
        'warm-box' / 'cold-box' output of getSSTTimeseries with a +/-1 std
        band, plus the reanalysis series in black. Panel (c): one bar per
        experiment with the NaN-aware mean +/- std of getENSOToSFRelations(),
        followed by a black bar for the reanalysis value.

        The reanalysis series/value come from the last experiment's getters
        (presumably identical for every experiment -- confirm).

        Raises:
            ValueError: if there are no experiments to plot.
        """
        from matplotlib import gridspec
        from matplotlib import font_manager as fm, rcParams

        nExp = self.numOfExperiments
        if nExp < 1:
            raise ValueError('plotNaturePaperFigure3 needs at least one experiment')

        warmSST_all = []
        coldSST_all = []
        ensoToSFCorr_all = []
        for i in range(nExp):
            allModelP = self.allmodelPlots[i]
            warmSST, warmSST_obs = allModelP.getSSTTimeseries(box='warm-box')
            coldSST, coldSST_obs = allModelP.getSSTTimeseries(box='cold-box')
            ensoToSFCorr, ensoToSFCorr_obs = allModelP.getENSOToSFRelations()
            ensoToSFCorr_all.append(ensoToSFCorr)
            warmSST_all.append(warmSST)
            coldSST_all.append(coldSST)

        rcParams.update({'errorbar.capsize': 2})
        prop = fm.FontProperties(fname="fonts/Georgia.ttf")
        plt.close(0)
        fig = plt.figure(0, figsize=(8,4))
        gs = gridspec.GridSpec(2,5)
        axInds = plt.subplot(gs[:,3:])
        axWarm = plt.subplot(gs[0,:3])
        axCold = plt.subplot(gs[1,:3])
        width = 0.5
        # Reanalysis months start 97 years after x = 100 on the shared axis.
        X_obs = np.arange(warmSST_obs.shape[0])/12.+100+97
        for i in range(nExp):
            asymmetry = ensoToSFCorr_all[i]
            meanAsySF = np.nanmean(asymmetry)
            stdAsySF = np.nanstd(asymmetry)
            axInds.bar(0.5+i, meanAsySF, width)
            axInds.errorbar(0.5+i, meanAsySF, yerr=stdAsySF, color='k', elinewidth=0.8)
            modelP = self.allmodelPlots[i]
            X = self._time_axis(modelP.years, modelP.experiment)
            color = self._experiment_color(i)
            self._plot_mean_std(axWarm, X, warmSST_all[i], modelP.experiment, color)
            self._plot_mean_std(axCold, X, coldSST_all[i], modelP.experiment, color)

        # The reanalysis bar goes right after the experiment bars so the tick
        # labels (experiments + 'Reanalysis') line up for any experiment count.
        axInds.bar(nExp+0.5, ensoToSFCorr_obs, width, color='k')
        axInds.set_title('Tendency to Multi-year or Cyclic ENSO', fontsize=10, fontproperties=prop)
        axInds.set_ylabel('Correlations', fontsize=10, fontproperties=prop)
        axInds.axhline(y=0, linestyle='-', color='k')
        axInds.set_xticks([0.5+k for k in range(nExp+1)])
        axInds.set_xticklabels(self.experiments+['Reanalysis'], fontsize=7, fontproperties=prop)
        yticks_c = [-0.6, -0.4, -0.2, 0, 0.2, 0.4, 0.6]
        axInds.set_yticks(yticks_c)
        axInds.set_yticklabels(["{:.1f}".format(ytick) for ytick in yticks_c], fontsize=8, fontproperties=prop)
        axInds.set_xlim(0, nExp+1)
        axInds.set_ylim(-0.6, 0.6)
        text_prop = fm.FontProperties(fname="fonts/Georgia.ttf", size=10)
        axInds.text(.25, 0.85, 'Multi-year ENSO', horizontalalignment='center', verticalalignment='top', fontproperties=text_prop, transform=axInds.transAxes)
        axInds.text(.25, 0.15, 'ENSO cycle', horizontalalignment='center', verticalalignment='bottom', fontproperties=text_prop, transform=axInds.transAxes)
        axInds.text(-0.1, 1.06, '(c)', horizontalalignment='left', verticalalignment='top', fontsize=8, transform=axInds.transAxes)

        self._style_sst_axis(axWarm, '(a)', 'SST in the Equatorial Eastern Pacific', X_obs, warmSST_obs, prop, yticks=[24, 26, 28, 30])
        self._style_sst_axis(axCold, '(b)', 'SST in the Equatorial Central Pacific', X_obs, coldSST_obs, prop)
        plt.tight_layout()
        plt.savefig('Figure/Figure2.png', bbox_inches='tight', dpi=800)
        # Release the figure once saved (plotNaturePaperFigure4_SSTPropagation does the same).
        plt.close(fig)

    def plotNaturePaperFigure4_SSTPropagation(self):
        """Draw composite equatorial SST evolutions and save 'Figure/Figure3.png'.

        Column 0 shows the reanalysis composite (with contour levels twice as
        wide); columns 1..N show, per experiment, the NaN-aware mean over
        axis 0 of getSSTPropagationsFromSF()'s model output. Drawing is
        delegated to UnclassifiedComputations, and the colorbar uses the
        mappable returned for the last experiment.

        Raises:
            ValueError: if there are no experiments to plot.
        """
        from UnclassifiedComputations import UnclassifiedComputations
        from matplotlib import font_manager as fm

        nExp = self.numOfExperiments
        if nExp < 1:
            raise ValueError('plotNaturePaperFigure4_SSTPropagation needs at least one experiment')

        unCom = UnclassifiedComputations()
        var_means_all = []
        for i in range(nExp):
            var_means, var_means_obs = self.allmodelPlots[i].getSSTPropagationsFromSF()
            var_means_all.append(var_means)

        prop = fm.FontProperties(fname="fonts/Georgia.ttf")

        # Symmetric contour levels: -0.55, -0.45, ..., -0.05, 0.05, ..., 0.55.
        levels = np.arange(6)/5.+0.1
        levels = (np.concatenate((-1*levels[::-1],levels)))/2.
        # At least the original 2x5 grid; add columns when there are more than
        # four experiments so each one gets its own column.
        ncols = max(5, nExp+1)
        f, axes = plt.subplots(2, ncols, figsize=(2*ncols, 6))
        title_prop = fm.FontProperties(fname="fonts/Georgia.ttf", size=14)
        f.suptitle('Equatorial SST Evolutions for Strong SF Months', fontproperties=title_prop)
        for i in range(nExp):
            var_means = var_means_all[i]
            plot = unCom.plotVarSeasonalCompositeWithHaveHullerENLNOnsetOnly(np.nanmean(var_means, axis=0), axes[:,i+1], levels=levels, ifSST=True)
            axes[0,i+1].set_title(self.allmodelPlots[i].experiment, fontsize=10, fontproperties=prop)
        unCom.plotVarSeasonalCompositeWithHaveHullerENLNOnsetOnly(var_means_obs, axes[:,0], levels=levels*2, ifSST=True)
        axes[0,0].set_title('Reanalysis', fontsize=10, fontproperties=prop)
        plt.tight_layout()
        f.subplots_adjust(top=0.9)
        f.subplots_adjust(right=0.925)
        cbar_ax = f.add_axes([0.95, 0.05, 0.03, 0.83])
        f.colorbar(plot, cax=cbar_ax, spacing='proportional', ticks=[-0.5,-0.4,-0.3,-0.2,-0.1,0,0.1,0.2,0.3,0.4,0.5])
        cbar_ax.tick_params(labelsize=8)
        plt.savefig('Figure/Figure3.png', bbox_inches='tight', dpi=800)
        plt.close(f)
        
