import quantities
import pkg_resources
try:
    import cPickle as pickle
except:
    import pickle
import gzip
from hippounit import plottools
import matplotlib.pyplot as plt
import os
import numpy

from quantities import mV, nA, ms, V, s
import json
import collections

# Results of every model, keyed by model (sub-directory) name; filled in by the loading loop.
data_dict = {}


# Directory holding one results sub-directory per model.
base_dir = '/home/tarluca/HS_OLM_story/hippounit_for_HS_OLM/models/New_models_SURF/PSP_in_HCS_3x_jav_soma/results/PSP_attenuation/'
# Directory the figures are written to.
save_dir = '/home/tarluca/HS_OLM_story/hippounit_for_HS_OLM/models/New_models_SURF/PSP_in_HCS_3x_jav_soma/'

# exist_ok=True replaces the os.path.exists()/os.makedirs() pair, which could fail
# if the directory appeared between the check and the creation.
os.makedirs(save_dir, exist_ok=True)


# Load the PSP attenuation test results of every model. Each entry of base_dir is
# treated as a model sub-directory holding the JSON files below; entries that are
# not directories (e.g. a stray file) are skipped instead of crashing open().
for models in os.listdir(base_dir):
    model_dir = os.path.join(base_dir, models)
    if not os.path.isdir(model_dir):
        print('Skipping non-directory entry: ' + models)
        continue
    print(models)

    data_dict[models] = {}
    # (data_dict key, result file name) pairs, loaded in this order for every model;
    # object_pairs_hook keeps the per-location ordering of the JSON files.
    for result_key, file_name in (('features', 'PSP_attenuation_model_features.json'),
                                  ('EPSP_amps', 'EPSP_amps.json'),
                                  ('mean_features', 'PSP_attenuation_mean_model_features.json'),
                                  ('errors', 'PSP_attenuation_errors.json')):
        with open(os.path.join(model_dir, file_name)) as f:
            data_dict[models][result_key] = json.load(f, object_pairs_hook=collections.OrderedDict)

# Re-insert the models in alphabetical order: os.listdir() returns entries in
# arbitrary order, and the plots assign colours by position in data_dict.
myKeys = sorted(data_dict)
data_dict = dict((name, data_dict[name]) for name in myKeys)

print(list(data_dict))


# Plot colours, assigned to models by their position in data_dict.
# Previously used palettes, kept for reference:
#   ['k', 'm', 'c', 'b', '#fdca35', '#b9e751', '#ff2000', '#ff9627']
#   ['#fdca35', '#b9e751', '#ff2000', '#ff9627', 'k', 'm', 'b', 'c']
colors = [
    'seagreen',
    'darkolivegreen',
    'green',
    'limegreen',
    'indigo',
    'darkmagenta',
    'darkviolet',
    'mediumvioletred',
]


def _load_ordered_json(path):
    """Return the parsed contents of the JSON file at *path*, preserving key order."""
    with open(path) as handle:
        return json.load(handle, object_pairs_hook=collections.OrderedDict)


# Experimental target features of the PSP attenuation test.
observation = _load_ordered_json('/home/tarluca/HS_OLM_story/hippounit_for_HS_OLM/target_features/feat_PSP_attenuation_target_data.json')

# Stimulus configuration of the PSP attenuation test (its 'target_distances' entry is read later).
config = _load_ordered_json("/home/tarluca/HS_OLM_story/hippounit_for_HS_OLM/hippounit/tests/stimuli/PSP_attenuation_stim/stim_PSP_attenuation_test.json")



# Convert the observation values in place: entries that float() accepts become
# floats; any other string is read as "<number> <unit>" and becomes a
# quantities.Quantity.
for key, val in observation.items():
    try:
        observation[key] = float(val)
    except ValueError:
        # split() with no separator ignores leading/trailing whitespace and runs of
        # spaces or tabs. split(" ") produced empty parts (" 1.5 mV" -> float('')
        # raised) and could not separate tab-delimited values.
        quantity_parts = val.split()
        number = float(quantity_parts[0])
        units = " ".join(quantity_parts[1:])
        observation[key] = quantities.Quantity(number, units)


# Target distances from the stimulus config. Feature names built from them carry an
# '_um' suffix, so presumably microns. Only referenced inside the disabled
# (string-literal) plotting code below.
dists = numpy.array(config['target_distances'])

# Mean absolute feature values plot
# NOTE: disabled -- everything up to the closing triple quote is a bare string
# literal, so it is never executed. If re-enabled it would plot, for each model,
# the mean +/- std soma/dendrite attenuation at every distance in `dists` and save
# PSP_attenuation_feature_means.svg (the experimental errorbar line inside is
# itself commented out).
"""
plt.figure()


for z, models in enumerate(data_dict):
    exp_mean_attenuations = numpy.array([])
    exp_std_attenuations = numpy.array([])
    model_mean_attenuations = numpy.array([])
    model_std_attenuations = numpy.array([])

    for i in range(len(dists)):
        model_mean_attenuations = numpy.append(model_mean_attenuations, data_dict[models]['mean_features']['mean_attenuation_soma/dend_'+str(dists[i])+'_um']['mean'])
        model_std_attenuations = numpy.append(model_std_attenuations, data_dict[models]['mean_features']['mean_attenuation_soma/dend_'+str(dists[i])+'_um']['std'])
        exp_mean_attenuations = numpy.append(exp_mean_attenuations, observation['mean_attenuation_soma/dend_'+str(dists[i])+'_um'])
        exp_std_attenuations = numpy.append(exp_std_attenuations, observation['std_attenuation_soma/dend_'+str(dists[i])+'_um'])

    plt.errorbar(dists, model_mean_attenuations, yerr = model_std_attenuations, marker ='o', linestyle='none', label = models, color= colors[z])
#plt.errorbar(dists, exp_mean_attenuations, yerr = exp_std_attenuations, marker='x', markersize=10, color = 'black', linestyle='none', label = 'experiment')
plt.xlabel('Distance from soma (um)', fontsize = 16)
plt.ylabel('soma/dendrite attenuation', fontsize = 16)
plt.grid(True)
plt.xticks(fontsize = 16)
plt.yticks(fontsize = 16)
lgd = plt.legend(bbox_to_anchor=(1.0, 1.0), loc = 'upper left', fontsize = "18")
plt.savefig(save_dir + '/PSP_attenuation_feature_means.svg', dpi=600, bbox_extra_artists=(lgd,), bbox_inches='tight')

"""
###Plot features
# NOTE: disabled -- the triple-quoted string below is a bare string literal and is
# never executed. If re-enabled it would plot each model's per-location
# soma/dendrite attenuation against distance and save PSP_attenuation_features.svg.
# Before enabling it, restore the plt.figure() call and the initialisation of
# exp_mean_attenuations / exp_std_attenuations commented out just below: the
# string's first loop appends to those arrays, which are not initialised here.
#plt.figure()


#exp_mean_attenuations = numpy.array([])
#exp_std_attenuations = numpy.array([])
"""

for i in range(len(dists)):

    exp_mean_attenuations = numpy.append(exp_mean_attenuations, observation['mean_attenuation_soma/dend_'+str(dists[i])+'_um'])
    exp_std_attenuations = numpy.append(exp_std_attenuations, observation['std_attenuation_soma/dend_'+str(dists[i])+'_um'])

for z, models in enumerate(data_dict):
    model_attenuations = numpy.array([])

    distances = []
    #location_labels = []

    for k, v in data_dict[models]['features'].items() :
        distances.append(data_dict[models]['features'][k]['distance'])
        model_attenuations = numpy.append(model_attenuations, data_dict[models]['features'][k]['attenuation_soma/dendrite'])
        #location_labels.append(k[0]+'('+str(k[1])+')')

    #for i in range(len(distances)):
    plt.plot(distances, model_attenuations, marker ='o', linestyle='none', label = models, color=colors[z])
#plt.errorbar(dists, exp_mean_attenuations, yerr = exp_std_attenuations, marker='x', markersize=10, color = 'black', linestyle='none', label = 'experiment')
plt.xlabel('Distance from soma (um)', fontsize = 16)
plt.ylabel('soma/dendrite attenuation', fontsize = 16)
plt.grid(True)
plt.xticks(fontsize = 16)
plt.yticks(fontsize = 16)
lgd = plt.legend(bbox_to_anchor=(1.0, 1.0), loc = 'upper left', fontsize = "18")
plt.savefig(save_dir + '/PSP_attenuation_features.svg', dpi=600, bbox_extra_artists=(lgd,), bbox_inches='tight')
"""


"""plot EPSPs"""

"""

plt.figure()


for z, models in enumerate(data_dict):
    #print(data_dict[models]['EPSP_amps'])
    EPSPs_dend = numpy.array([])
    EPSPs_soma = numpy.array([])

    distances = []
    #location_labels = []
    i=0  #for label  

    for k, v in data_dict[models]['EPSP_amps'].items() :
        distances.append(data_dict[models]['EPSP_amps'][k]['distance'])
        EPSPs_dend = numpy.append(EPSPs_dend, data_dict[models]['EPSP_amps'][k]['EPSP_amp_dendrite'])
        EPSPs_soma = numpy.append(EPSPs_soma, data_dict[models]['EPSP_amps'][k]['EPSP_amp_soma'])
 
        #location_labels.append(k[0]+'('+str(k[1])+')')

        if i ==0 :  # label only for the first one 
            plt.plot(data_dict[models]['EPSP_amps'][k]['distance'], data_dict[models]['EPSP_amps'][k]['EPSP_amp_dendrite'], marker ='^', color= colors[z], linestyle='none', label = models)
            plt.plot(data_dict[models]['EPSP_amps'][k]['distance'], data_dict[models]['EPSP_amps'][k]['EPSP_amp_soma'], marker ='o', color= colors[z], linestyle='none')
       
        else:
            plt.plot(data_dict[models]['EPSP_amps'][k]['distance'], data_dict[models]['EPSP_amps'][k]['EPSP_amp_dendrite'], marker ='^', color= colors[z], linestyle='none')
            plt.plot(data_dict[models]['EPSP_amps'][k]['distance'], data_dict[models]['EPSP_amps'][k]['EPSP_amp_soma'], marker ='o', color= colors[z], linestyle='none')

        i+=1
    # break

plt.xlabel('Synapse distance from soma (um)', fontsize = 16)
plt.ylabel('Peak amplitude (mV)', fontsize = 16)
plt.title('EPSPs')
plt.grid(True)
plt.xticks(fontsize = 16)
plt.yticks(fontsize = 16)
lgd = plt.legend(bbox_to_anchor=(1.0, 1.0), loc = 'upper left', fontsize = "18")
plt.savefig(save_dir + '/PSP_attenuation_EPSP_amps.svg', dpi=600, bbox_extra_artists=(lgd,), bbox_inches='tight')
"""

"""plot EPSPs soma only"""

fig = plt.figure()

ax=plt.axes()

lines = []

for z, models in enumerate(data_dict):
    #print(data_dict[models]['EPSP_amps'])
    #EPSPs_dend = numpy.array([])
    EPSPs_soma = numpy.array([])

    distances = []
    #location_labels = []
    i=0  #for label  

    for k, v in data_dict[models]['EPSP_amps'].items() :
        distances.append(data_dict[models]['EPSP_amps'][k]['distance'])
        #EPSPs_dend = numpy.append(EPSPs_dend, data_dict[models]['EPSP_amps'][k]['EPSP_amp_dendrite'])
        EPSPs_soma = numpy.append(EPSPs_soma, data_dict[models]['EPSP_amps'][k]['EPSP_amp_soma'])
 
        #location_labels.append(k[0]+'('+str(k[1])+')')

        if i ==0 :  # label only for the first one 
            #plt.plot(data_dict[models]['EPSP_amps'][k]['distance'], data_dict[models]['EPSP_amps'][k]['EPSP_amp_dendrite'], marker ='^', color= colors[j], linestyle='none', label = models)
            lines.append(ax.plot(data_dict[models]['EPSP_amps'][k]['distance'], data_dict[models]['EPSP_amps'][k]['EPSP_amp_soma'], marker ='o', color= colors[z], linestyle='none', label = models))
            #plt.scatter(data_dict[models]['EPSP_amps'][k]['distance'], data_dict[models]['EPSP_amps'][k]['EPSP_amp_soma'], marker ='.')
       
        else:
            #plt.plot(data_dict[models]['EPSP_amps'][k]['distance'], data_dict[models]['EPSP_amps'][k]['EPSP_amp_dendrite'], marker ='^', color= colors[j], linestyle='none')
            lines.append(ax.plot(data_dict[models]['EPSP_amps'][k]['distance'], data_dict[models]['EPSP_amps'][k]['EPSP_amp_soma'], marker ='o', color= colors[z], linestyle='none'))
            #plt.scatter(data_dict[models]['EPSP_amps'][k]['distance'], data_dict[models]['EPSP_amps'][k]['EPSP_amp_soma'], marker ='.')

        i+=1

ax.set_ylim(0)
ax.set_xlabel('Synapse distance from soma (um)', fontsize = 16)
ax.set_ylabel('Peak amplitude (mV)', fontsize = 16)
ax.set_title('EPSPs on soma')
ax.grid(True)
ax.tick_params(labelsize=12)

#ax.lines.clear()

#lgd = ax.legend(bbox_to_anchor=(1.0, 1.0), loc = 'upper left', fontsize = "18")
fig.set_figwidth(5)

plt.savefig(save_dir + '/PSP_attenuation_EPSP_amps_on_soma_only_grid_HCS_3x.svg', dpi=300, bbox_inches='tight')

"""plot EPSPs soma only"""

fig = plt.figure()

ax2=plt.axes()

lines2 = []

for z, models in enumerate(data_dict):
    #print(data_dict[models]['EPSP_amps'])
    #EPSPs_dend = numpy.array([])
    EPSPs_soma = numpy.array([])

    distances = []
    #location_labels = []
    i=0  #for label  

    for k, v in data_dict[models]['EPSP_amps'].items() :
        distances.append(data_dict[models]['EPSP_amps'][k]['distance'])
        #EPSPs_dend = numpy.append(EPSPs_dend, data_dict[models]['EPSP_amps'][k]['EPSP_amp_dendrite'])
        EPSPs_soma = numpy.append(EPSPs_soma, data_dict[models]['EPSP_amps'][k]['EPSP_amp_soma'])
 
        #location_labels.append(k[0]+'('+str(k[1])+')')

        if i ==0 :  # label only for the first one 
            #plt.plot(data_dict[models]['EPSP_amps'][k]['distance'], data_dict[models]['EPSP_amps'][k]['EPSP_amp_dendrite'], marker ='^', color= colors[j], linestyle='none', label = models)
            lines2.append(ax2.plot(data_dict[models]['EPSP_amps'][k]['distance'], data_dict[models]['EPSP_amps'][k]['EPSP_amp_soma'], marker ='o', markeredgewidth = 0.0, color= colors[z], linestyle='none', label = models, alpha = 0.5))
            #ax2.scatter(data_dict[models]['EPSP_amps'][k]['distance'], data_dict[models]['EPSP_amps'][k]['EPSP_amp_soma'], marker ='o', alpha = 0.5)
       
        else:
            #plt.plot(data_dict[models]['EPSP_amps'][k]['distance'], data_dict[models]['EPSP_amps'][k]['EPSP_amp_dendrite'], marker ='^', color= colors[j], linestyle='none')
            lines2.append(ax2.plot(data_dict[models]['EPSP_amps'][k]['distance'], data_dict[models]['EPSP_amps'][k]['EPSP_amp_soma'], marker ='o', markeredgewidth = 0.0, color= colors[z], linestyle='none', alpha = 0.5))
            #ax2.scatter(data_dict[models]['EPSP_amps'][k]['distance'], data_dict[models]['EPSP_amps'][k]['EPSP_amp_soma'], marker ='o', alpha = 0.5)

        i+=1

ax2.set_ylim(0)
#ax.set_xlabel('Synapse distance from soma (um)', fontsize = 16)
#ax.set_ylabel('Peak amplitude (mV)', fontsize = 16)
#ax.set_title('EPSPs on soma')
#ax.grid(True)

ax.tick_params(labelsize=16)

#ax2.set_xticklabels(" ")
#ax2.set_yticklabels(" ")
#ax.lines.clear()

#lgd = ax2.legend(bbox_to_anchor=(1.0, 1.0), loc = 'upper left', fontsize = "18")
fig.set_figwidth(5)

plt.savefig(save_dir + '/PSP_attenuation_EPSP_amps_on_soma_only_dot_50_HCS_3x.png', transparent=True, dpi=300, bbox_inches='tight')


"""Plot EPSP soma means """
"""
plt.figure()
tolerance = 25
mean_locations = [25, 75, 125, 175, 225, 275, 325, 375, 425, 475, 525, 575]

for z, models in enumerate(data_dict):
    #print(data_dict[models]['EPSP_amps'])
    #EPSPs_dend = numpy.array([])

    #location_labels = []
    EPSPs_soma = numpy.array([])
    distances = numpy.array([])
    j = 0
    for k, v in data_dict[models]['EPSP_amps'].items() :

        distances = numpy.append(distances, data_dict[models]['EPSP_amps'][k]['distance'])
        distances = numpy.asarray(distances, dtype = 'float64')
        EPSPs_soma = numpy.append(EPSPs_soma, data_dict[models]['EPSP_amps'][k]['EPSP_amp_soma'])
        EPSPs_soma = numpy.asarray(EPSPs_soma, dtype = 'float64')

    for loc in mean_locations:

        EPSP_means = numpy.array([])
        for i, dist in enumerate(distances):
            if dist >= loc - tolerance and dist < loc + tolerance:
                EPSP_means = numpy.append(EPSP_means, EPSPs_soma[i])
        mean_EPSP = numpy.mean(EPSP_means)
        std_EPSP = numpy.std(EPSP_means)

        if j ==0 :  # label only for the first one 
            plt.errorbar(loc, mean_EPSP, yerr = std_EPSP, marker='o', linestyle='none', label = models, color=colors[z])

        else:
            plt.errorbar(loc, mean_EPSP, yerr = std_EPSP, marker='o', linestyle='none', color=colors[z])

        j+=1

plt.xlabel('Synapse distance from soma (um)', fontsize = 16)
plt.ylabel('Peak amplitude on soma (mV)', fontsize = 16)
plt.title('EPSPs on soma means')
plt.grid(True)
plt.xticks(fontsize = 16)
plt.yticks(fontsize = 16)
lgd = plt.legend(bbox_to_anchor=(1.0, 1.0), loc = 'upper left', fontsize = "18")
plt.savefig(save_dir + '/PSP_attenuation_EPSP_amps_on_soma_means.svg', dpi=600, bbox_extra_artists=(lgd,), bbox_inches='tight')
"""
plt.show()

