import matplotlib.pyplot as plt
import numpy
from importlib import reload
import sys
import os
import importlib
import types
# print(sys.modules.keys())
import multiprocessing
import multiprocessing.pool
original_packs = sys.modules.copy()
#import currents_visualization
import math

from neuron import h
import neuron

import json
import collections


#Modellke

def _section_for_draw(rand, norm_cum_lengths):
    """Map a draw in [0, 1] onto the section whose share of the total length contains it.

    ``norm_cum_lengths[i]`` is the cumulative length of sections ``0..i``
    divided by the total length (non-decreasing, last entry 1).  Section ``i``
    owns the interval ``(norm_cum_lengths[i-1], norm_cum_lengths[i]]``;
    section 0 owns ``[0, norm_cum_lengths[0]]``.

    Returns:
        ``(i, frac)`` -- the section index and the relative position (0..1) of
        ``rand`` inside that section's interval -- or ``None`` if ``rand`` lies
        in no interval.
    """
    # Section 0's interval starts at 0; indexing [i-1] with i == 0 would
    # wrongly wrap around to the last cumulative value (1.0).
    lower = 0.0
    for i, upper in enumerate(norm_cum_lengths):
        if rand <= upper and (rand > lower or i == 0):
            span = upper - lower
            return i, ((rand - lower) / span if span > 0 else 0.0)
        lower = upper
    return None


def modellke(model_name, num_sim, synapse_number, noise, parameters, distance, tolerance):
    """Run one synaptic-stimulation simulation of a detailed model and save the somatic trace.

    Loads the model's mechanisms and HOC file, applies the conductances in
    ``parameters``, draws ``synapse_number`` basal dendritic locations whose
    distance from the soma lies in
    ``[min(distance) - tolerance, max(distance) + tolerance)``, puts an
    Exp2Syn driven by NetStim trains at each, runs the simulation, writes the
    somatic ``(t, v)`` trace to
    ``./<model_name>/<model_name>_<num_sim>_<distance>_noise_<noise>_soma_rec_silent_broad.dat``
    and plots ``v(t)`` on the current matplotlib figure (the figure is not
    saved here).

    Args:
        model_name: Model name; ``./HOC/<model_name>_skinner.hoc`` is loaded.
        num_sim: Label used only in the output filename.
        synapse_number: Number of synapse locations to draw.
        noise: Value assigned to every NetStim's ``noise`` (0..1 fractional
            randomness of spike timing, per NEURON's NetStim).
        parameters: Conductance values; indices 0-13 are used (see below).
        distance: List of target distances (um); also embedded in the filename.
        tolerance: Width (um) added below the minimum and above the maximum
            of ``distance``.

    Raises:
        ValueError: if no basal location within the distance range was found.
    """
    import random

    # Models and modfiles can be found in the Models/Detailed_models folder, add the path accordingly
    neuron.load_mechanisms('./modfiles/')
    h.load_file('./HOC/' + model_name + '_skinner.hoc')

    v = parameters

    for sec in h.somatic:
        sec.gbar_Ikdrf = v[0]
        sec.gbar_Ikdrs = v[1]
        sec.gbar_Ika = v[2]
        sec.gkhbar_Ih = v[3]
        sec.g_passsd = v[6]
        sec.gkl_Kleaksd = v[8]
        sec.gna_Nasoma = v[10]
        sec.gbar_IM = v[13]
    for sec in h.basal:
        sec.gkhbar_Ih = v[3]
        sec.gbar_Ikdrf = v[0]
        sec.gbar_Ikdrs = v[1]
        sec.gbar_Ika = v[2]
        sec.gbar_cat = v[4]
        sec.gcalbar_cal = 10 * v[4]
        sec.gkbar_kca = v[5]
        sec.g_passsd = v[6]
        sec.gkl_Kleaksd = v[8]
        sec.gna_Nadend = v[11]
        sec.gbar_IM = v[13]
    for sec in h.axonal:
        sec.gbar_Ikdrfaxon = v[0]
        sec.gbar_Ikdrsaxon = v[1]
        sec.g_passaxon = v[7]
        sec.gkl_Kleakaxon = v[9]
        sec.gna_Naaxon = v[12]

    h.finitialize(-80)
    h.fcurrent()

    # --- Setting up synapses ---

    # Shape of the EPSC originates from HippoUnit's PSP attenuation test.
    stim_file = "stim_PSP_attenuation_test_hs.json"
    with open(stim_file, 'r') as f:
        config = json.load(f, object_pairs_hook=collections.OrderedDict)

    distances = distance
    dist_range = [min(distances) - tolerance, max(distances) + tolerance]

    tau1 = config['tau_rise']
    tau2 = config['tau_decay']

    basal = list(h.basal)
    # Sections are looked up by name instead of exec("dend=h." + name): in
    # Python 3 exec cannot rebind a function's local variables, so that form
    # silently kept using whatever section the enclosing variable held.
    sec_by_name = {sec.name(): sec for sec in basal}

    num = synapse_number
    seed = 1
    AMPA_weight = 0.0003  # an approximate value that originates from HippoUnit's PSP attenuation test

    # True: catalogue every basal segment instead of sampling (typically not used).
    we_want_all = False

    interval_bw_trains = 300
    interval_bw_stimuli_in_train = 50
    num_trains = 3
    num_stimuli_in_train = 1

    all_dend_loc_0_100 = []
    all_dend_loc_100_200 = []
    all_dend_loc_distal = []
    all_location_distances = {}

    locations = []            # [section name, segment x] per synapse (duplicates allowed)
    location_distances = {}   # (section name, segment x) -> distance from the origin

    # Make the soma the origin for all h.distance() measurements below.
    h.distance(sec=h.soma)

    if we_want_all:
        # NOTE(review): this branch only catalogues segments by distance band;
        # it does not fill `locations`, so no synapses can be placed with it.
        for sec in basal:
            for seg in sec:
                seg_dist = h.distance(seg.x, sec=sec)
                if seg_dist <= 100:
                    all_dend_loc_0_100.append([sec.name(), seg.x])
                elif seg_dist <= 200:
                    all_dend_loc_100_200.append([sec.name(), seg.x])
                else:
                    all_dend_loc_distal.append([sec.name(), seg.x])
                all_location_distances[sec.name(), seg.x] = seg_dist

    else:
        # The chance of a dendritic section to get selected is proportional to its length.
        cum_lengths = []
        running_length = 0
        for sec in basal:
            running_length += sec.L
            cum_lengths.append(running_length)
        total_length = cum_lengths[-1] if cum_lengths else 0
        norm_cum_lengths = [c / total_length for c in cum_lengths] if total_length > 0 else []

        still_needed = num
        num_iterations = 0
        while len(locations) < num and num_iterations < 50:
            # Private generator: same sequence as random.seed(seed) without
            # touching the global random state.
            rng = random.Random(seed)
            draws = [rng.random() for _ in range(still_needed)]

            for rand in draws:
                hit = _section_for_draw(rand, norm_cum_lengths)
                if hit is None:
                    continue
                i, seg_loc = hit
                sec = basal[i]
                # Snap to the nearest segment centre (first one on ties).
                segment = min((seg.x for seg in sec), key=lambda x: abs(x - seg_loc))
                seg_dist = h.distance(segment, sec=sec)
                if dist_range[0] <= seg_dist < dist_range[1]:
                    locations.append([sec.name(), segment])
                    location_distances[sec.name(), segment] = seg_dist

            still_needed = num - len(locations)
            seed += 10
            num_iterations += 1

    if len(locations) < num:
        print('Warning: only %d of %d requested synapse locations found in range %s' % (len(locations), num, dist_range))
    if not location_distances:
        raise ValueError('No basal dendritic location found in the distance range ' + str(dist_range))

    synapse_lists = {}

    def set_ampa_nmda_multiple_loc_theta(dend_loc, AMPA_weight, interval_bw_trains, interval_bw_stimuli_in_train, num_trains, num_stimuli_in_train):
        """Create one Exp2Syn per location and, per train, one NetStim + NetCon per synapse."""
        start = 200  # onset of the first train
        for i, (ndend, xloc) in enumerate(dend_loc):
            syn = h.Exp2Syn(xloc, sec=sec_by_name[ndend])
            syn.tau1 = tau1
            syn.tau2 = tau2
            synapse_lists['ampa_list'][i] = syn

        for j in range(num_trains):
            for i in range(len(dend_loc)):
                stim = h.NetStim()
                stim.number = num_stimuli_in_train
                stim.interval = interval_bw_stimuli_in_train
                stim.start = start + j * interval_bw_trains
                stim.noise = noise
                synapse_lists['ns_list'][j][i] = stim

                nc = h.NetCon(stim, synapse_lists['ampa_list'][i], 0, 0, 0)
                nc.weight[0] = AMPA_weight
                synapse_lists['ampa_nc_list'][j][i] = nc

        return synapse_lists

    def activate_theta_stimuli(dend_loc, AMPA_weight, interval_bw_trains, interval_bw_stimuli_in_train, num_trains, num_stimuli_in_train):
        """Allocate the synapse/stimulus tables and fill them; the objects stay alive in `synapse_lists`."""
        synapse_lists.update({'ampa_list': [None] * len(dend_loc),
                              'ampa_nc_list': [[None] * len(dend_loc) for _ in range(num_trains)],
                              'ns_list': [[None] * len(dend_loc) for _ in range(num_trains)],
                              })

        set_ampa_nmda_multiple_loc_theta(dend_loc, AMPA_weight, interval_bw_trains, interval_bw_stimuli_in_train, num_trains, num_stimuli_in_train)
        print(synapse_lists)

    def run_simulation(dend_loc, recording_loc):
        """Simulate and return (t, v_soma, v_recording_site, {(name, x): v at each synapse site})."""
        (rec_ndend, xloc), _rec_distance = recording_loc
        dendrite = sec_by_name[rec_ndend]
        sect_loc = h.soma(0.5)

        rec_t = h.Vector()
        rec_t.record(h._ref_t)

        rec_v = h.Vector()
        rec_v.record(sect_loc._ref_v)

        rec_v_dend = h.Vector()
        rec_v_dend.record(dendrite(xloc)._ref_v)

        v_stim = []
        for ndend, x in dend_loc:
            vec = h.Vector()
            vec.record(sec_by_name[ndend](x)._ref_v)
            v_stim.append(vec)

        # NOTE: the init/run sequence is kept exactly as before; presumably each
        # initialization lets NetStims draw random numbers when noise > 0, so
        # changing it would change the (seeded) results.
        h.stdinit()
        dt = 0.025
        h.dt = dt
        h.steps_per_ms = 1 / dt
        h.v_init = -60
        h.celsius = 24
        h.init()
        h.tstop = 1600
        h.run()

        t = numpy.array(rec_t)
        v_soma = numpy.array(rec_v)
        v_dend = numpy.array(rec_v_dend)

        v_stim_locs = collections.OrderedDict()
        for (ndend, x), vec in zip(dend_loc, v_stim):
            v_stim_locs[(ndend, x)] = numpy.array(vec)  # a list cannot be a dict key, a tuple can

        return t, v_soma, v_dend, v_stim_locs

    dend_loc = locations

    # Record from the drawn location whose distance is closest to the first target distance.
    recording_loc = min(location_distances.items(), key=lambda kv: abs(kv[1] - distances[0]))
    activate_theta_stimuli(dend_loc, AMPA_weight, interval_bw_trains, interval_bw_stimuli_in_train, num_trains, num_stimuli_in_train)
    t, v_soma, _v_dend, _v_stim_locs = run_simulation(dend_loc, recording_loc)

    # --- Saving the data for further evaluations ---

    somatic_trace = numpy.column_stack([t, v_soma])
    out_dir = './' + model_name + '/'
    os.makedirs(out_dir, exist_ok=True)

    numpy.savetxt(out_dir + model_name + '_' + str(num_sim) + '_' + str(distance) + '_noise_' + str(noise) + '_soma_rec_silent_broad.dat', somatic_trace)

    # --- Plotting (drawn on the current figure; saving is left to the caller) ---

    plt.plot(t, v_soma)


model_name = 'HS_5091'
# Alternative parameter set for the same model:
#parameters = [0.132949097283729, 0.000240160632407, 0.00447006700677, 1.81474244381339E-05, 0.007654001148095, 0.000231459226207, 7.25680265357357E-06, 0.00033549468495, 1.81310425780781E-05, 0.000131238332447, 0.021268570396291, 0.034795036625969, 0.041175850157676, 2.0907442977788E-05]
parameters = [0.164282138440483, 0.000259431698957, 0.003902611944949, 1.90338429505871E-05, 0.008694201572264, 0.000236222653756, 5.85922818126927E-06, 0.000458756101385, 2.08272108697804E-05, 8.95830420293865E-05, 0.017177093294903, 0.043859873006251, 0.095349775221684, 4.38128611779177E-05]

noise = [0, 1]  # NetStim noise values to sweep over

# Synapse counts to sweep: range(start, end, step), so end_num_syn itself is excluded.
start_num_syn = 100
end_num_syn = 1000
increase_syn_by = 100

# Target distances (um) of the stimulated dendritic locations. Each inner list is
# run as a separate set of simulations, with synapses drawn from
# [min - width, max + width). For example, [[50], [350]] with width = 50 runs one
# sweep for the first 100 um and another for 300-400 um. The current setup
# ([[200]] with width = 200) selects dendritic locations in the first 400 um of
# the dendritic arbor.
stim_distances = [[200]]
width = 200

# Guarded so that importing this module does not launch the whole simulation sweep.
if __name__ == "__main__":
    for noise_level in noise:
        for stim_distance in stim_distances:
            for n_syn in range(start_num_syn, end_num_syn, increase_syn_by):

                modellke(model_name, n_syn, n_syn, noise_level, parameters, stim_distance, width)

                plt.savefig('./' + model_name + '/' + model_name + '_' + str(n_syn) + '_' + str(stim_distance) + '_noise_' + str(noise_level) + '_soma_rec_silent_broad.svg')

                plt.close()

