
try:
    import cPickle as pickle
except:
    import pickle
import gzip
from hippounit import plottools
import matplotlib.pyplot as plt
import os
import numpy
import math
# EDIT THIS PART
base_dir= '/home/tarluca/hippounit_for_HS_OLM/models/All_models_for_PSP/'
name = 'Chosen_OLM_somafeatures_optimized_csak_szamolashoz'
#
#
#-------------------------------------------------------------------------------
# adding the result data to dictionary
# data_dict maps every model directory name found under
# <base_dir>/<name>/results/somaticfeat/ to
#   {'features': <unpickled soma_features.p>, 'errors': <unpickled soma_errors.p>}
# Model directories that lack either file are skipped.
data_dict = {}

results_dir = os.path.join(base_dir, name, 'results', 'somaticfeat')


def _load_gzipped_pickle(path):
    """Return the object stored in the gzip-compressed pickle file at *path*.

    The file is closed again after reading. Unpickling can execute arbitrary
    code, so this must only be used on trusted (locally produced) result files.
    """
    with gzip.GzipFile(path, "rb") as pickle_file:
        return pickle.load(pickle_file)


# os.listdir() returns entries in arbitrary order; sorting makes the model
# order (and therefore the colour assigned to each model) reproducible.
for model_name in sorted(os.listdir(results_dir)):
    print(model_name)
    model_dir = os.path.join(results_dir, model_name)
    features_path = os.path.join(model_dir, 'soma_features.p')
    errors_path = os.path.join(model_dir, 'soma_errors.p')
    if os.path.exists(errors_path) and os.path.exists(features_path):
        data_dict[model_name] = {
            'features': _load_gzipped_pickle(features_path),
            'errors': _load_gzipped_pickle(errors_path),
        }
#
#
#-------------------------------------------------------------------------------
# PLOT ABSOLUTE FEATURES
# One figure with one row per feature: every model's mean +/- sd ('.') and the
# experimental mean +/- std ('*').
axs = plottools.tiled_figure("absolute features", figs={}, frames=1, columns=1, orientation='page',
                        height_ratios=None, top=0.97, bottom=0.05, left=0.25, right=0.97, hspace=0.1, wspace=0.2)
try:
    colormap = plt.cm.spectral  # only available in older matplotlib releases
except AttributeError:
    colormap = plt.cm.nipy_spectral
# one colour per model, +1 needed for the experiment
plt.gca().set_prop_cycle(plt.cycler('color', colormap(numpy.linspace(0, 0.9, len(data_dict) + 1))))

#model features
for key in data_dict.keys():
    model_features = data_dict[key]['features']
    results = model_features['feature_results_dict']
    feature_names = model_features['features_names']
    x = [results[f]['feature mean'] for f in feature_names]
    xerr = [results[f]['feature sd'] for f in feature_names]
    axs[0].errorbar(x, list(range(len(feature_names))), xerr=xerr, marker='.', linestyle='none',
                    label=key, clip_on=False)

#experimental features
# The observation (and the feature order / tick labels) is read from the last
# model visited by the loop above.
reference_features = data_dict[key]['features']
feature_names = reference_features['features_names']
observation = reference_features['observation']
x = [float(observation[f]['Mean']) for f in feature_names]
xerr = [float(observation[f]['Std']) for f in feature_names]
axs[0].errorbar(x, list(range(len(feature_names))), xerr=xerr, marker='*', linestyle='none',
                label="Experiment", clip_on=False)
#plot characteristics
lgd = axs[0].legend(bbox_to_anchor=(1.0, 1.0), loc='upper left', fontsize=12)
axs[0].yaxis.set_ticks(range(len(feature_names)))
axs[0].set_yticklabels(feature_names)
axs[0].set_ylim(-1, len(feature_names))
axs[0].set_title('Absolute Features', fontsize=14)

plt.savefig(base_dir + name + '/' + name +'_absolute_features', bbox_extra_artists=(lgd,), bbox_inches='tight')

#
#
#-------------------------------------------------------------------------------
# SUBPLOTS OF ABSOLUTE FEATURES
# One figure with five panels, each showing a group of related absolute
# features: every model as '.', the experiment as '*'.
try:
    colormap = plt.cm.spectral  # only available in older matplotlib releases
except AttributeError:
    colormap = plt.cm.nipy_spectral

import matplotlib.gridspec as gridspec

# One entry per panel, in GridSpec cell order:
#   (substrings selecting the panel's features, x-axis label or None)
# A feature is drawn in every panel for which its name contains one of the
# substrings. Models and experiment are filtered with the SAME substrings so
# that their markers end up on the same rows.
panel_specs = [
    (('steady_state_voltage', 'voltage_base', 'AP_begin_voltage', 'AHP_depth_abs', 'AHP_depth_abs_slow'), 'mV'),
    (('voltage_deflection', 'voltage_deflection_begin'), 'mV'),
    (('sag_ratio',), None),
    (('sag_amplitude',), 'mV'),
    (('sag_time_constant',), 'ms'),
]


def _select_features(feature_names, substrings):
    """Return the names in *feature_names* containing any of *substrings*, in order."""
    return [f for f in feature_names if any(s in f for s in substrings)]


# A plain figure is used (not plt.subplots(2, 3)) so that no extra, empty axes
# lie underneath or next to the GridSpec panels.
fig = plt.figure(figsize=(18, 16))
t = fig.suptitle('Absolute features', fontsize=12)

# 2 x 3 grid; the first five cells hold the panels described in panel_specs
gs = gridspec.GridSpec(2, 3, wspace=0.8, bottom = 0.05, top=0.95, left= 0.15, right = 0.95)

panel_axes = [fig.add_subplot(gs[i]) for i in range(len(panel_specs))]
for ax in panel_axes:
    # one colour per model, +1 needed for the experiment
    ax.set_prop_cycle(plt.cycler('color', colormap(numpy.linspace(0, 0.9, len(data_dict) + 1))))
ax0 = panel_axes[0]

#model features
for key in data_dict.keys():
    model_features = data_dict[key]['features']
    results = model_features['feature_results_dict']
    for ax, (substrings, _) in zip(panel_axes, panel_specs):
        selected = _select_features(model_features['features_names'], substrings)
        ax.errorbar([results[f]['feature mean'] for f in selected], list(range(len(selected))),
                    xerr=[results[f]['feature sd'] for f in selected],
                    marker='.', linestyle='none', label=key, clip_on=False)

#experimental features
# The observation (and the tick labels) is read from the last model visited by
# the loop above.
reference_features = data_dict[key]['features']
observation = reference_features['observation']
for ax, (substrings, xlabel) in zip(panel_axes, panel_specs):
    selected = _select_features(reference_features['features_names'], substrings)
    ax.errorbar([float(observation[f]['Mean']) for f in selected], list(range(len(selected))),
                xerr=[float(observation[f]['Std']) for f in selected],
                marker='*', linestyle='none', label='Experiment', clip_on=False)
    ax.set_yticks(list(range(len(selected))))
    ax.set_yticklabels(selected, fontsize=8)
    ax.tick_params(axis='x', labelsize=8)
    if xlabel is not None:
        ax.set_xlabel(xlabel, fontsize=8)

# A single legend for the whole figure, built from the first panel's entries.
handles, labels = ax0.get_legend_handles_labels()
lgd = fig.legend(handles, labels, loc="upper right", bbox_to_anchor=(1.1, 0.95))
fig.savefig(base_dir + name + '/' + name + '_absolute_features_subplots.svg', bbox_extra_artists=(lgd, t), bbox_inches='tight')


#
#
#-------------------------------------------------------------------------------
# plot subplots separately
# Each group of absolute features gets its own figure; every model is drawn
# with round markers and the experiment with a star.


def _plot_feature_group(figure_name, title, substrings, xlabel, file_suffix):
    """Plot one group of absolute features in its own figure and save it.

    figure_name -- name passed to plottools.tiled_figure
    title       -- axes title
    substrings  -- a feature is plotted if its name contains any of these
    xlabel      -- x-axis label
    file_suffix -- appended to base_dir + name + '/' + name to form the output path
    """
    def is_selected(feature_name):
        return any(s in feature_name for s in substrings)

    sub = plottools.tiled_figure(figure_name, figs={}, frames=1, columns=1, orientation='page',
                                 height_ratios=None, top=0.97, bottom=0.05, left=0.25, right=0.97,
                                 hspace=0.1, wspace=0.2)
    try:
        colormap = plt.cm.spectral  # only available in older matplotlib releases
    except AttributeError:
        colormap = plt.cm.nipy_spectral
    # one colour per model, +1 needed for the experiment
    plt.gca().set_prop_cycle(plt.cycler('color', colormap(numpy.linspace(0, 0.9, len(data_dict) + 1))))

    # model features; feature names are used directly as (categorical) y values
    for model_name, model_data in data_dict.items():
        features = model_data['features']
        results = features['feature_results_dict']
        y = [f for f in features['features_names'] if is_selected(f)]
        x = [results[f]['feature mean'] for f in y]
        xerr = [results[f]['feature sd'] for f in y]
        sub[0].errorbar(x, y, xerr=xerr, marker='o', linestyle='none', label=model_name, clip_on=False)

    # experimental features, read from the last model in data_dict's iteration order
    reference = data_dict[list(data_dict)[-1]]['features']
    observation = reference['observation']
    y = [f for f in reference['features_names'] if is_selected(f)]
    x = [float(observation[f]['Mean']) for f in y]
    xerr = [float(observation[f]['Std']) for f in y]
    sub[0].errorbar(x, y, xerr=xerr, marker='*', linestyle='none', label="Experiment", clip_on=False)

    lgd = sub[0].legend(bbox_to_anchor=(1.0, 1.0), loc='upper left', fontsize=12)
    sub[0].yaxis.set_ticks(range(len(y)))
    sub[0].set_yticklabels(y)
    sub[0].set_ylim(-1, len(y))
    sub[0].set_title(title, fontsize=14)
    sub[0].set_xlabel(xlabel)

    plt.savefig(base_dir + name + '/' + name + file_suffix, bbox_extra_artists=(lgd,), bbox_inches='tight')


# vol
_plot_feature_group("voltage", 'Voltage subplots',
                    ('steady_state_voltage', 'voltage_base', 'AP_begin_voltage', 'AHP_depth_abs', 'AHP_depth_abs_slow'),
                    "mV", ' voltage')
# rel vol
_plot_feature_group("relative voltage", 'Relative voltage subplots',
                    ('voltage_deflection', 'voltage_deflection_begin'),
                    "mV", ' relative voltage')
# sag ratio (dimensionless)
_plot_feature_group("sag ratio", 'Sag ratio subplots', ('sag_ratio',),
                    "Ratio", ' sag ratio')
# sag amp
_plot_feature_group("sag amp", 'Sag amplitudes subplots', ('sag_amplitude',),
                    "mV", ' sag amplitudes')
# sag time constant
_plot_feature_group("sag time constant", 'Sag time constants subplots', ('sag_time_constant',),
                    "ms", ' sag time constants')

#
#
#-------------------------------------------------------------------------------
#Errors
# One figure with one row per feature error: one marker per model. Features
# whose name contains "Apic" are not plotted, but they keep their row so the
# markers stay aligned with the y tick labels.

axs2 = plottools.tiled_figure("features", figs={}, frames=1, columns=1, orientation='page',
                              height_ratios=None, top=0.97, bottom=0.05, left=0.25, right=0.97, hspace=0.1, wspace=0.2)
try:
    colormap = plt.cm.spectral  # only available in older matplotlib releases
except AttributeError:
    colormap = plt.cm.nipy_spectral
# one colour per model (no experiment in this plot)
plt.gca().set_prop_cycle(plt.cycler('color', colormap(numpy.linspace(0, 0.9, len(data_dict)))))

for key in data_dict.keys():
    errors = data_dict[key]['errors']
    rows = [(i, feature_name) for i, feature_name in enumerate(errors['features_names'])
            if "Apic" not in feature_name]
    x = [errors['feature_results_dict'][feature_name] for _, feature_name in rows]
    y = [i for i, _ in rows]
    axs2[0].plot(x, y, marker='o', linestyle='none', label=key, clip_on=False)

# tick labels are taken from the last model visited by the loop above
error_feature_names = data_dict[key]['errors']['features_names']
lgd = axs2[0].legend(bbox_to_anchor=(1.0, 1.0), loc='upper left', fontsize=12)
axs2[0].yaxis.set_ticks(range(len(error_feature_names)))
axs2[0].set_yticklabels(error_feature_names)
axs2[0].set_ylim(-1, len(error_feature_names))
axs2[0].set_title('Feature errors', fontsize=14)
plt.savefig(base_dir + name + '/' + name + '_feature_errors', bbox_extra_artists=(lgd,), bbox_inches='tight')

# NOTE(review): The three triple-quoted strings below are disabled code. They
# are bare string-literal statements that are evaluated and discarded, so
# nothing inside them runs. They contain:
#   1. combined and per-model plots of the stored voltage traces,
#   2. a stub that opens figures '1'-'4',
#   3. a scatter plot of every feature against 'AHP_depth.Step0.8'.
# If re-enabled, note that they call dict.iteritems() (Python 2 only), use
# plt.cm.spectral without a fallback, and save to hard-coded paths that belong
# to a different project.
'''
# plot of traces
plt.figure(3)
j=0
for key in data_dict.keys():

    plt.subplot(round(len(data_dict.keys())/2.0),2,j+1)
    for i in range (0, len(data_dict[key]['features']['traces_results'])):
        for key_trace, value in data_dict[key]['features']['traces_results'][i].iteritems():
            plt.plot(data_dict[key]['features']['traces_results'][i][key_trace][0], data_dict[key]['features']['traces_results'][i][key_trace][1], label=key_trace)
    plt.legend(loc=2)
    plt.title(key)
    plt.xlabel("ms")
    plt.ylabel("mV")
    j+=1


# plot of traces separately
j=0
for key in data_dict.keys():
    plt.figure(j+4)
    for i in range (0, len(data_dict[key]['features']['traces_results'])):
        for key_trace, value in data_dict[key]['features']['traces_results'][i].iteritems():
            plt.subplot(round(len(data_dict[key]['features']['traces_results'])/2.0),2,i+1)
            plt.tight_layout()
            plt.plot(data_dict[key]['features']['traces_results'][i][key_trace][0], data_dict[key]['features']['traces_results'][i][key_trace][1])
            plt.title(key_trace)
            plt.xlabel("ms")
            plt.ylabel("mV")
    plt.suptitle(key, y=1.0)
    j+=1
'''
'''
plt.figure('1')
plt.figure('2')
plt.figure('3')
plt.figure('4')
'''

'''
# scatter plot - 'correllation'
fig, axes = plt.subplots(nrows=2, ncols=1)
fig.tight_layout(pad=1.08, h_pad=10, w_pad=14)

colormap = plt.cm.spectral      #http://matplotlib.org/1.2.1/examples/pylab_examples/show_colormaps.html
plt.gca().set_prop_cycle(plt.cycler('color', colormap(numpy.linspace(0, 0.9, len(data_dict.keys())))))

#data_dict_copy = dict(data_dict)
import copy
data_dict_copy = copy.deepcopy(data_dict)

for key in data_dict_copy.keys():

    del data_dict_copy[key]['features']['feature_results_dict']['AHP_depth.Step0.8']    #We don't want to plot the same feature against itself, so I remove it from the copied dictionary
    data_dict_copy[key]['features']['features_names'].remove('AHP_depth.Step0.8')

    num_of_subplots = len(data_dict_copy[key]['features']['features_names'])

    for i in range (len(data_dict_copy[key]['features']['features_names'])):
        feature_name=data_dict_copy[key]['features']['features_names'][i]
        x = data_dict_copy[key]['features']['feature_results_dict'][feature_name]['feature mean']
        y = data_dict[key]['features']['feature_results_dict']['AHP_depth.Step0.8']['feature mean'] # the original dictionary is used, which contains the feature of interest
        xerr = data_dict_copy[key]['features']['feature_results_dict'][feature_name]['feature sd']
        yerr = data_dict[key]['features']['feature_results_dict']['AHP_depth.Step0.8']['feature sd']

        x_exp = float(data_dict[key]['features']['observation'][feature_name]['Mean'])
        y_exp = float(data_dict[key]['features']['observation']['AHP_depth.Step0.8']['Mean'])
        xerr_exp = float(data_dict[key]['features']['observation'][feature_name]['Std'])
        yerr_exp = float(data_dict[key]['features']['observation']['AHP_depth.Step0.8']['Std'])

        plt.subplot(int(numpy.ceil(numpy.sqrt(num_of_subplots))), int(numpy.ceil(numpy.sqrt(num_of_subplots))), i+1)

        plt.errorbar(x, y, xerr=xerr, yerr=yerr, marker='o', linestyle='none', label=key,  clip_on=False)
        plt.errorbar(x_exp, y_exp, xerr=xerr_exp, yerr=yerr_exp, marker='D', linestyle='none', color = 'black', label='experiment',  clip_on=False)
        plt.xlabel(feature_name)
        plt.ylabel('AHP_depth.Step0.8')
fig.set_size_inches(30, 30)
plt.legend(bbox_to_anchor=(1.05, 0), loc='lower left', borderaxespad=0., ncol=4)
#plt.savefig('/mnt/extra31/Modellezo_csapat/CA1_pyramidal/multi_opt_validation_results/last_ca1_pc_morefeatures_slowerNa/last_ca1_pc_morefeatures_slowerNa_scatter_plot_features_with_experiment')
plt.savefig('/mnt/extra31/Modellezo_csapat/CA1_pyramidal/multi_opt_validation_results/ca1_pc_morefeatures_ih/ca1_pc_morefeatures_ih_scatter_plot_features.pdf')
'''
# Show all figures created by the script.
plt.show()
