import numpy as np
import matplotlib.pyplot as plt

# define some gray levels (RGB tuples with components in [0, 1])
black = (0.0, 0.0, 0.0)
darkgray = (0.25, 0.25, 0.25)
midgray = (0.5, 0.5, 0.5)
lightgray = (0.75, 0.75, 0.75)
white = (1.0, 1.0, 1.0)

# define some red levels (RGB tuples with components in [0, 1])
darkred = (0.3, 0., 0.)
midred = (0.5, 0., 0.)
lightred = (0.7, 0., 0.)
# Pure RGB red. It gets its own name because the module-level ``red`` is the
# hex color '#aa2222' defined in the custom colors section below.
red_rgb = (1., 0., 0.)

# INM colors (RGB tuples with components in [0, 1])
myblue = (0., 64./255., 192./255.)
myred = (192./255., 64./255., 0.)
mygreen = (0., 192./255., 64./255.)
# NOTE(review): despite the name, this RGB value is a muted red-brown
# (R=0.5, G=B=0.25), not an orange.
myorange = (0.5, 0.25, 0.25)
mypink = (0.75, 0.25, 0.75)
myblue2 = (0., 128./255., 192./255.)
myred2 = (245./255., 157./255., 115./255.)

# palette colors as hex strings (presumably taken from coolors.co -- confirm)
myred_hex = '#931621'
myyellow_hex = '#B67431'
myblue1_hex = '#2B4162'
myblue2_hex = '#2C8C99'
mygreen_hex = '#0B6E4F'

# define custom colors (hex strings)
blue_light = '#6d9eeb'
blue_dark = '#3c78d8'
pink_light = '#c27ba0'
pink_dark = '#a64d79'
gray_light = '#E8E8E8'
red = '#aa2222'

# 17 shades of blue. The literal list runs from light to dark; the [::-1]
# reversal makes col_blues[0] the darkest ('#0d2040') and col_blues[-1] the
# lightest ('#d4e2f7').
col_blues = np.array(['#d4e2f7', '#bfd3f2', '#aac4ee', '#95b5ea', '#7fa7e6', '#6a98e2', '#5589dd', '#3f7ad9',
                      '#3c78d8', '#2a6cd5', '#2661c0', '#2256aa', '#1d4b95', '#194180', '#15366a', '#112b55', '#0d2040'])[::-1]

# slides colors (same value as blue_light)
lightblue_slides = '#6d9eeb'

# default figure width/height ratio: the golden ratio (1 + sqrt(5)) / 2
panel_wh_ratio = (1. + np.sqrt(5)) / 2.

class visualization():
    '''Configure matplotlib for compact publication-style figures.

    Instantiating the class overwrites entries of the global ``plt.rcParams``.
    The methods adjust the default figure size and tidy individual Axes.
    '''

    def __init__(self):
        '''Apply the default figure style to the global ``plt.rcParams``.

        Side effect: every rcParam listed below is overwritten process-wide.
        The default figure is one column (``SCIwidth1Col``) wide with a
        width/height ratio of ``panel_wh_ratio``. All font sizes and line
        widths are multiplied by a common ``scale`` factor.
        '''
        self.SCIwidth1Col = 3.25  # one-column figure width, in inches
        self.SCIwidth2Col = 6.75  # two-column figure width, in inches
        # NOTE: despite its name this is centimetres per inch (1 in = 2.54 cm);
        # divide a length in cm by it to obtain inches.
        self.inchpercm = 2.54

        width = self.SCIwidth1Col
        height = width / panel_wh_ratio

        # common factor applied to all font sizes and line widths below
        scale = 0.8

        plt.rcParams.update({
            'figure.figsize': (width, height),

            # resolution of figures in dpi; does not influence eps output
            'figure.dpi': 600,

            # font
            'font.size': scale * 9,
            'axes.titlesize': scale * 9,
            'axes.labelsize': scale * 9,
            'legend.fontsize': scale * 9,
            'font.family': ['serif'],
            'pdf.fonttype': 42,  # Type 42 (TrueType) fonts in pdf output

            'lines.linewidth': scale * 1.0,

            # size of markers (points in point plots)
            'lines.markersize': scale * 2.5,
            'lines.markeredgewidth': scale * 0.5,  # marker edge line width
            'patch.linewidth': scale * 1.0,
            'axes.linewidth': scale * 0.5,  # axes edge (spine) line width
            'grid.linewidth': scale * 0.5,

            # ticks, all in points: length (size), line width (width) and
            # distance to the tick label (pad)
            'xtick.major.size': scale * 1.5,
            'xtick.minor.size': scale * 1.5,
            'xtick.major.pad': scale * 2,
            'xtick.minor.pad': scale * 2,
            'ytick.major.size': scale * 1.5,
            'ytick.minor.size': scale * 1.5,
            'ytick.major.width': scale * 0.5,
            'ytick.minor.width': scale * 0.2,
            'xtick.major.width': scale * 0.5,
            'xtick.minor.width': scale * 0.2,
            'ytick.major.pad': scale * 2,
            'ytick.minor.pad': scale * 2,

            # tick label text size
            'xtick.labelsize': scale * 7,
            'ytick.labelsize': scale * 7,

            # 'text.usetex' is deliberately left unset: using it made font
            # detection fail in Adobe Illustrator (noted 2010-07-20).
            'ps.useafm': False,  # AFM fonts would give smaller files
            # Output Type 3 (Type3); Type 42 would embed TrueType.
            # NOTE(review): pdf.fonttype above is 42 -- confirm the
            # difference for ps output is intended.
            'ps.fonttype': 3,
            'mathtext.fontset': 'cm',  # Computer Modern for math text
        })

    ##################
    ### DIMENSIONS ###
    ##################

    def _set_figsize(self, width, ratio):
        '''Set the default figure size to ``width`` x ``width / ratio`` inches.'''
        plt.rcParams.update({
            'figure.figsize': [width, width / ratio],
        })

    def set_SCI_1column_fig_style(self, ratio=panel_wh_ratio):
        '''Set the default figure size to one column width (``SCIwidth1Col``).

        ratio : width/height ratio of the figure (default: golden ratio).
        '''
        self._set_figsize(self.SCIwidth1Col, ratio)

    def set_SCI_2column_fig_style(self, ratio=panel_wh_ratio):
        '''Set the default figure size to two-column width (``SCIwidth2Col``).

        ratio : width/height ratio of the figure (default: golden ratio).
        '''
        self._set_figsize(self.SCIwidth2Col, ratio)

    ############
    ### MISC ###
    ############

    def remove_axis_junk(self, ax, which=('right', 'top')):
        '''Hide the spines named in ``which`` and keep ticks on bottom/left.

        ax    : matplotlib Axes to modify in place.
        which : collection of spine names to hide (default: right and top).
        '''
        # The spine loop was previously commented out because it used the
        # Python-2-only ``dict.iteritems``, so no spine was ever hidden.
        for loc, spine in ax.spines.items():
            if loc in which:
                spine.set_color('none')
        ax.xaxis.set_ticks_position('bottom')
        ax.yaxis.set_ticks_position('left')

    def make_axis_cross(self, ax):
        '''Move the left spine to the horizontal centre of ``ax`` and hide
        the top and right spines.

        The bottom spine keeps its default position. Returns ``ax``.
        '''
        # 'center' places the spine at the centre of the Axes (axes
        # coordinates), not at data x = 0.
        ax.spines['left'].set_position('center')

        # Eliminate upper and right axes
        ax.spines['right'].set_color('none')
        ax.spines['top'].set_color('none')

        # Show ticks in the left and lower axes only
        ax.xaxis.set_ticks_position('bottom')
        ax.yaxis.set_ticks_position('left')
        return ax

    def legend(self, ax, on=True, loc=1):
        '''Make ``ax`` the current Axes and, if ``on``, draw a legend at ``loc``.

        ``loc`` is passed straight to ``plt.legend``. Returns ``ax``.
        '''
        plt.sca(ax)
        if on:
            plt.legend(loc=loc)
        return ax

    def title(self, ax, title=''):
        '''Make ``ax`` current and set ``title`` as the figure-level suptitle.

        Note that ``plt.suptitle`` titles the whole figure containing ``ax``,
        not the Axes itself. Returns ``ax``.
        '''
        plt.sca(ax)
        plt.suptitle(title)
        return ax


