from matplotlib import pyplot as plt
import os
import numpy as np
from opt_settings_multistart import *
from opt_DE_multistart import *
from matplotlib.pyplot import cm
from mpl_toolkits.axes_grid1.inset_locator import mark_inset
plt.rcParams['axes.linewidth']=0.4

#choose the animal and parameter estimation type (run_name). Make sure they coincide with opt_settings_multistart
# NOTE(review): ``param_est`` (used by the plotting code) is not set in this script; presumably it
# comes from opt_settings_multistart, so run_name must describe the same estimation.
animal = 1
popsize=30
run_name = "2"
number_starts = 1
figsize = (5,3)

run_dir_plotting =f'experiments_differential_evolution/Est_{run_name}/a{animal}'

#some plotting settings
# ``cycler`` is not imported explicitly anywhere in this file (it could only arrive through the
# star imports), so import matplotlib's validating version here.
from matplotlib.rcsetup import cycler

plt.rcParams.update({
    "text.usetex": True,
    "font.family": "serif",
    "font.serif": "cm",
    # these hex colours are matplotlib's default (tab10) colour cycle, set explicitly
    "axes.prop_cycle": cycler('color', ['#1f77b4', '#ff7f0e', '#2ca02c',
            '#d62728', '#9467bd', '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf']),
    })

# LaTeX axis labels for the model parameter keys (raw strings: the labels contain backslashes).
_PARAMETER_LABELS = {
    'gV': r'$g_V$',
    'gP': r'$g_P$',
    'kR': r'$k_R$',
    'nves': r'$n_{\mathit{ves}}$',
    'N_var': r'$N$',
    'k_base': r'$m_1$',
    'L': r'$L$',
    't0': r'$m_2$',
}


def labelling(key):
    """Return the LaTeX label used in plots for the parameter ``key``.

    For an unknown key a message is printed and ``None`` is returned.
    """
    try:
        return _PARAMETER_LABELS[key]
    except KeyError:
        print('No corresponding parameter found')
        return None

# the 10 tab10 colours as RGBA rows; color[k-1] is the colour of animal k
color = cm.tab10(np.linspace(0, 1, 10))

# Plot the measured currents of all animals, one subplot per animal (produces Fig. 3).
# The per-animal colour scheme color[animal-1] is used for the animals throughout the paper.
animals = [1,2,3,4,5]
current_data_strings = [f'data/Current_data_animal{i}.npy' for i in animals]
fig, ax = plt.subplots(figsize=(5,6), nrows=5, sharex=True, sharey=True, constrained_layout=True)
for i, path in zip(animals, current_data_strings):
    current_data_plot = np.load(path)
    # scale by 1e9 so the current is shown in nA
    ax[i-1].plot(times, 10**9 * current_data_plot, label=f'{i}', c=color[i-1])
fig.supxlabel('$t$ in $s$', fontsize=20)
# the y-axis is shared between the subplots, so these ticks apply to all of them
ax[-1].set_yticks([-100, 0])
fig.legend(title='animal', title_fontsize=15, loc='upper right', ncols=5, fontsize=11, bbox_to_anchor=(1.,1.1))
fig.supylabel(r'Current $C^{\mathrm{data}}$ in $n$A', fontsize=20)
fig.savefig("analysis/data_plot.png", bbox_inches='tight')

#%%
###################################################################################################################

# Load the optimisation results of the run. Presumably one row per optimisation run:
# the parameter values sit in columns 1:-1 and the loss of the run in the last column.
optimized_params = np.load(f'{run_dir_plotting}/result.npy')
obj_values = optimized_params[:, -1]

# Rank the runs from lowest to highest loss and keep the losses in that order.
sorted_args = np.argsort(obj_values)
sorted_obj = obj_values[sorted_args]




###########################
######## Fit plots ########
###########################



#best fit plot of the lowest-loss run against the data (hand-tuned layouts exist for animals 1 and 4);
#for param est 2 and 3 an inset additionally zooms in on t in [0.2, 0.3]
if param_est in (1, 2, 3):
    parameter_ids = model.getParameterIds()
    param_Id = [parameter_ids[i] for i in param_indices]
    model.setParameterById(dict(zip(param_Id, optimized_params[sorted_args[0],1:-1]))) #set the parameters using the optimized ones
    all_params = model.getParameters() #update the model with the parameters
    fig, ax = plt.subplots(figsize=figsize,dpi=200,constrained_layout=True)
    # colour of the data curve per animal; for other animals the figure stays empty and is not saved
    data_colors = {1: 'tab:blue', 4: 'tab:red'}
    if animal in data_colors:
        data_color = data_colors[animal]
        with_zoom = param_est != 1
        sim_label = r'$C^{\mathrm{sim}}_{\hat{\theta}}$'
        data_label = r'$C^{\mathrm{data}}$'
        # run the simulation once and reuse it for the main axes and the inset; scaled to nA
        t_plot = times[0:2500]
        sim_nA = current(NF(np.array(all_params)))[0:2500]*10**9
        data_nA = current_data[0:2500]*10**9

        ax.set_ylim(-120, 30 if with_zoom else 5)  # extra headroom for the inset
        ax.set_xlim(-0.001,1)
        #the plot of data and simulation
        ax.plot(t_plot, sim_nA, '-', linewidth=1., label=sim_label, c='grey')
        ax.plot(t_plot, data_nA, linewidth=1., label=data_label, c=data_color)

        if with_zoom:
            #zoomed-in window; position (axes fraction) and y-range are hand-tuned per animal
            inset_bounds, inset_ylim = {1: ([0.3,0.68,0.5,0.3], (-100,-25)),
                                        4: ([0.4,0.75,0.5,0.23], (-65,-10))}[animal]
            axin = ax.inset_axes(inset_bounds)
            axin.set_xlim(0.2,0.3)
            axin.set_ylim(*inset_ylim)
            #same colours as the main axes, so the main legend also describes the inset
            axin.plot(t_plot, sim_nA, linewidth=1, c='grey', label=sim_label)
            axin.plot(t_plot, data_nA, linewidth=1, c=data_color, label=data_label)
            ax.indicate_inset_zoom(axin, edgecolor="black", lw=1)
            axin.set_yticklabels([])
            axin.set_xticklabels([])

        ax.xaxis.set_tick_params(labelsize=15)
        ax.yaxis.set_tick_params(labelsize=15)
        ax.set_yticks(np.array([-100,-75,-50,-25,0]))
        ax.set_xlabel('$t$ in $s$', fontsize=16)
        ax.set_ylabel('Current in $n$A', fontsize=16)
        ax.legend(loc='lower right',ncols=2, fontsize=13.7 if with_zoom else 14)
        # NOTE(review): only animal 4 with zoom is saved under the '_w_zoom' name; kept as-is so
        # existing references to the figure files keep working
        zoom_tag = '_w_zoom' if (with_zoom and animal == 4) else ''
        fig.savefig(f'{run_dir_plotting}/best_fit_plot_current{zoom_tag}_est{run_name}_a{animal}.png')
#%%
#################################################################################################################
#Variability plots of the multistart (Fig A2, A4, A6)
def _fade_alpha(rank):
    """Return the line opacity for the run at loss rank ``rank`` (0 = best).

    Opacity drops by 1/40 per rank and is clipped to [0, 1]; matplotlib raises
    ValueError for alpha outside that range, which happened for ranks above 40.
    """
    return min(1.0, max(0.0, 1 - rank/40))


def _plot_parameter_variability(ax, bounds, ranked_params, labels, best_markersize):
    """Draw the parameter sets of all multistart runs on ``ax`` with ``labels`` on the y-axis.

    bounds: one row per parameter (in ``labels`` order), drawn as dashed black lines.
    ranked_params: one parameter set per row, sorted from lowest to highest loss.
    Drawing order (bottom to top): bounds, runs ranked 6+ in grey with increasing
    transparency, runs ranked 1-5 in orange, and the best run in red.
    """
    ax.plot(bounds, labels, linestyle='dashed', c='black', lw=0.8)
    n_runs = len(ranked_params)
    for rank in range(6, n_runs):
        ax.plot(ranked_params[rank], labels, linestyle='solid', marker='o', c='grey',
                lw=1, markersize=0.5, alpha=_fade_alpha(rank))
    # guard against fewer than 6 runs
    for rank in range(1, min(6, n_runs)):
        ax.plot(ranked_params[rank], labels, lw=1, c='tab:orange', linestyle='solid',
                marker='o', markersize=0.5, alpha=_fade_alpha(rank))
    ax.plot(ranked_params[0], labels, lw=1, c='red', linestyle='solid',
            marker='o', markersize=best_markersize)
    ax.set_yticks(labels)


if param_est in (1, 2, 3):
    fig, ax = plt.subplots(figsize=(4,2.5), dpi=200, constrained_layout=True)
    # new names are used so that optimized_params and param_bounds are not overwritten
    ranked_params = optimized_params[sorted_args, 1:-1]
    bounds = np.array(param_bounds)
    if param_est == 2:
        #swap L and nves (columns 4 and 5) and t0 and k_base (columns 6 and 7) for correct display
        display_order = [0, 1, 2, 3, 5, 4, 7, 6]
        ranked_params = ranked_params[:, display_order]
        bounds = bounds[display_order]
        variability_labels = ['N','$g_V$','$g_P$','$k_R$',r'$n_{\mathrm{ves}}$','L','$m_1$','$m_2$']
        xticks, best_markersize = [-8,0,2], 0.3
    else:
        variability_labels = ['N','$g_V$','$g_P$','$k_R$',r'$n_{\mathrm{ves}}$']
        xticks, best_markersize = [0,2], 0.5
    _plot_parameter_variability(ax, bounds, ranked_params, variability_labels, best_markersize)
    ax.set_xticks(xticks)
    ax.xaxis.set_tick_params(labelsize=15)
    ax.yaxis.set_tick_params(labelsize=15)
    ax.set_xlabel(r'$\log_{10}$ of parameter values', fontsize=15)
    ax.set_ylabel('parameters', fontsize=15)
    fig.savefig(f'{run_dir_plotting}/variability_of_parameters{run_name}_a{animal}.png')
