import matplotlib.pyplot as plt
import matplotlib.colors as colors
import numpy as np

import cartopy.crs as ccrs
import cartopy.feature as feature
import cmocean as cm
import matplotlib.ticker as mticker

from scipy import stats
from xarrayMannKendall import Mann_Kendall_test

# Mean radius of the Earth, in metres.
earth_radius = 6.371e6
# Angular velocity of the Earth's rotation, in rad/s.
omega = 7.2921159e-5

def truncate_colormap(cmap, minval=0.0, maxval=1.0, n=100):
    """
    Build a new colormap from a sub-range of an existing one.

    ``n`` evenly spaced colours are sampled from ``cmap`` between the
    fractional positions ``minval`` and ``maxval`` and assembled into a
    ``LinearSegmentedColormap``. The new map is named
    ``trunc(<cmap name>,<minval>,<maxval>)``, with both bounds formatted
    to two decimals.
    """
    label = 'trunc({n},{a:.2f},{b:.2f})'.format(n=cmap.name, a=minval, b=maxval)
    sampled_colours = cmap(np.linspace(minval, maxval, n))
    return colors.LinearSegmentedColormap.from_list(label, sampled_colours)

def area(lat,lon):
    """
    Compute the cell areas of a rectilinear latitude/longitude grid.

    Parameters
    ----------
    lat, lon : array-like or xarray.DataArray
        Grid-point latitudes and longitudes, in degrees. In the usual case
        of 1-D DataArrays on separate dimensions, xarray broadcasts the
        product below into a 2-D field.

    Returns
    -------
    The product ``dx * dy`` of the zonal and meridional grid spacings.
    The units are square metres, given that ``earth_radius`` is in metres.
    """
    lat_r = np.radians(lat)
    lon_r = np.radians(lon)

    # np.gradient always returns a bare ndarray. For DataArray input, put the
    # values back onto the longitude coordinate so that the multiplication
    # with the latitude-dependent factor broadcasts by dimension name.
    # ``copy(data=...)`` avoids assigning ``.data`` in place, which is
    # deprecated and unsafe on plain NumPy arrays.
    grad_lon = np.gradient(np.asarray(lon_r))
    if hasattr(lon_r, "dims"):
        grad_lon = lon_r.copy(data=grad_lon)

    # Zonal spacing shrinks with cos(latitude); meridional spacing does not.
    dx = grad_lon * earth_radius * np.cos(lat_r)
    dy = np.gradient(lat_r) * earth_radius

    return dx * dy

import matplotlib.patches as mpatches

def add_patches(axis):
    """
    Mask the Black Sea and the Caspian Sea with opaque dark-grey rectangles.

    Each rectangle is given in geodetic lon/lat degrees as
    (lower-left corner, width, height). It is drawn at zorder 3, so it sits
    above anything plotted at a lower zorder.
    """
    boxes = (
        ([27, 38], 20, 10),   # Black Sea
        ([38, 36], 20, 11),   # Caspian Sea
    )
    for corner, box_width, box_height in boxes:
        rectangle = mpatches.Rectangle(xy=corner, width=box_width, height=box_height,
                                       facecolor='darkgrey',
                                       alpha=1,
                                       zorder=3,
                                       transform=ccrs.Geodetic())
        axis.add_patch(rectangle)

    
# Natural Earth 1:50m land polygons, filled and outlined in dark grey.
ccrs_land = feature.NaturalEarthFeature(
    category='physical',
    name='land',
    scale='50m',
    edgecolor='darkgrey',
    facecolor='darkgrey',
    linewidth=0.2,
)


def plot_bars(ax,x,slope,error,width=0.1,**kwargs):
    """
    Draw a vertical error bar with horizontal caps on ``ax``.

    The bar sits at horizontal position ``x`` and is centred on ``slope``.
    It spans ``slope - error/2`` to ``slope + error/2``, so ``error`` is the
    full bar length. Each cap extends ``width`` to either side of ``x``.
    Extra keyword arguments are forwarded to every ``ax.plot`` call.
    """
    low = slope - error / 2
    high = slope + error / 2
    # Offset the caps by ``width`` directly. Scaling it by ``x/x`` would
    # divide by zero when x == 0.
    left, right = x - width, x + width

    ax.plot([x, x], [low, high], **kwargs)
    ax.plot([left, right], [low, low], **kwargs)
    ax.plot([left, right], [high, high], **kwargs)
    
def plot_barhs(ax,x,slope,error,width=0.1,**kwargs):
    """
    Draw a horizontal error bar with vertical caps on ``ax``.

    The bar sits at vertical position ``x`` and is centred on ``slope``.
    It spans ``slope - error/2`` to ``slope + error/2`` along the horizontal
    axis, so ``error`` is the full bar length. Each cap extends ``width``
    above and below ``x``. Extra keyword arguments are forwarded to every
    ``ax.plot`` call.
    """
    low = slope - error / 2
    high = slope + error / 2
    # Offset the caps by ``width`` directly. Scaling it by ``x/x`` would
    # divide by zero when x == 0.
    bottom, top = x - width, x + width

    ax.plot([low, high], [x, x], **kwargs)
    ax.plot([low, low], [bottom, top], **kwargs)
    ax.plot([high, high], [bottom, top], **kwargs)
    
def compute_trends(data):
    """
    Estimate a linear trend with the Theil-Sen estimator.

    ``data`` is regressed against its index positions ``0 .. len(data)-1``
    using ``scipy.stats.mstats.theilslopes``, which ignores masked values.

    Returns
    -------
    slope, intercept
        The median slope (per index step) and the matching intercept.
    """
    positions = range(len(data))
    fit = stats.mstats.theilslopes(data, positions, alpha=0.95)
    # The fit also carries the slope's confidence bounds; only the line is returned.
    return fit[0], fit[1]

def significance_mk(data):
    """
    Run a Mann-Kendall trend test on each series along the first axis of ``data``.

    A single ``Mann_Kendall_test`` object is built, with dimension
    ``'time'``, ``alpha=0.05``, ``MK_modified=True`` and
    ``method="theilslopes"``. Its ``_calc_slope_MK`` routine is then applied
    to the ``.values`` of every element obtained by iterating over ``data``.
    Assumes ``data`` is an xarray DataArray whose first dimension indexes
    the series (TODO confirm against callers).

    Returns
    -------
    h : numpy.ndarray
        Element 1 of each ``_calc_slope_MK`` result. Presumably this is the
        significance flag; confirm against xarrayMannKendall.
    n : numpy.ndarray
        The last element of each result. Presumably this is the effective
        sample size, since ``effective_n=True`` is requested; confirm.
    """
    n_series = np.shape(data)[0]
    h = np.zeros(n_series)
    n = np.zeros(n_series)
    mk_object = Mann_Kendall_test(data,'time',alpha=0.05,MK_modified=True,method="theilslopes")
    # NOTE(review): this calls a private method of xarrayMannKendall, which
    # may change without notice when the package is upgraded.
    for idx, item in enumerate(data):
        result = mk_object._calc_slope_MK(item.values,effective_n=True)
        h[idx] = result[1]
        n[idx] = result[-1]
    return h, n


def _plot_front_regions(ax, regions, key, cmap):
    """
    Shade, on ``ax``, every data variable of ``regions`` whose name contains ``key``.

    This is the shared implementation of the ``plot_*_Front_regions``
    functions below.

    Parameters
    ----------
    ax : cartopy GeoAxes
        The target axes. It must have a ``'geo'`` spine.
    regions : xarray.Dataset
        Region masks, presumably boolean or 0/1 (confirm against callers).
    key : str
        Substring that selects which data variables are drawn.
    cmap : str
        Forwarded to xarray's ``contourf``. The callers pass single
        matplotlib colour codes, which xarray expands into a one-colour
        colormap.

    Returns
    -------
    ``ax``, after it has been given a global extent, the ``ccrs_land``
    overlay and a 1-pt map frame.
    """
    for name in regions.data_vars:
        if key not in name:
            continue
        mask = regions[name]
        # Keep only the cells flagged True; all other cells become NaN and
        # stay unshaded. ``==`` is required here because the comparison is
        # element-wise on an xarray object.
        mask.where(mask == True).plot.contourf(  # noqa: E712
            ax=ax,  # draw on the requested axes, not whatever plt.gca() returns
            transform=ccrs.PlateCarree(),
            cmap=cmap, alpha=0.2,
            add_colorbar=False,
            vmin=0, vmax=1)
    ax.set_global()

    ax.add_feature(ccrs_land, zorder=1)
    ax.spines['geo'].set_linewidth(1)

    return ax


def plot_Important_Front_regions(ax,regions):
    """
    Shade, in black, every ``regions`` variable whose name contains 'Important_Front'.

    Returns ``ax`` with a global extent, the land overlay and the map frame set.
    """
    return _plot_front_regions(ax, regions, 'Important_Front', 'k')


def plot_Increasing_Front_regions(ax,regions):
    """
    Shade, in red, every ``regions`` variable whose name contains 'Increasing_Front'.

    Returns ``ax`` with a global extent, the land overlay and the map frame set.
    """
    return _plot_front_regions(ax, regions, 'Increasing_Front', 'r')


def plot_Decreasing_Front_regions(ax,regions):
    """
    Shade, in blue, every ``regions`` variable whose name contains 'Decreasing_Front'.

    Returns ``ax`` with a global extent, the land overlay and the map frame set.
    """
    return _plot_front_regions(ax, regions, 'Decreasing_Front', 'b')

# Natural Earth 1:50m land outlines: a thin black edge and no fill.
ccrs_land_pop = feature.NaturalEarthFeature(
    category='physical',
    name='land',
    scale='50m',
    edgecolor='black',
    facecolor='none',
    linewidth=0.2,
)