import time
from matplotlib import pyplot as plt
from multiprocessing import Pool
from scipy.optimize import differential_evolution
from numpy import exp

import importlib
import os
import sys
import libsbml

import amici
import amici.plotting
import numpy as np

import pypesto
import pypesto.optimize as optimize
import pypesto.visualize as visualize
from scipy.signal import find_peaks
from scipy import signal
from numpy import log10
from cycler import cycler
from mpi4py import MPI

COMM = MPI.COMM_WORLD

from opt_settings_fixedparams import *
from Parallel import *

###############################################################################################################################################################################################
"""
This version parallelizes the optimization runs with @parFor over MPI. Set the number of starts in opt_settings (opt_settings_fixedparams); when submitting the job file, choose an appropriate number of nodes to distribute the tasks.
This code generates Figures 9 and 10 of the paper.
"""
###############################################################################################################################################################################################




# Name of the model; it is also the name of the generated Python module.
model_name = "recovery_model_60_peaks_qssa_rates"
# SBML file the model is imported from.
sbml_file = "recovery_model_60_peaks_qssa_rates.xml"
# Directory the generated model code is written to (same as the model name).
model_output_dir = "recovery_model_60_peaks_qssa_rates"

# First run only: if the compiled `recovery_model_60_peaks_qssa_rates` module
# does not exist yet, uncomment the two statements below to generate it with
# AMICI (without sensitivity code).
#
# sbml_importer = amici.SbmlImporter(sbml_file)
# sbml_importer.sbml2amici(model_name, model_output_dir, verbose=True,
#                          generate_sensitivity_code=False)

# Load the compiled AMICI module and instantiate the model from it.
model_module = amici.import_model_module(model_name, model_output_dir)
model = model_module.getModel()

""" 
Load the time points and set up the solver.
"""


times = np.load('times.npy')
# Sampling interval, taken from the first two time points.
timestep = times[1] - times[0]
model.setTimepoints(times)
# Parameters are exchanged with the model on a log10 scale.
model.setParameterScale(amici.ParameterScaling.log10)
print('tstep=', timestep)

solver = model.getSolver()
solver.setNewtonMaxSteps(100)
solver.setMaxSteps(1599999)

"""
Run a nominal simulation and collect the parameter names (with LaTeX labels) and values.
"""
# Nominal simulation with the parameter values stored in the model; plot its
# state trajectories.
rdata = amici.runAmiciSimulation(model, solver, None)
amici.plotting.plotStateTrajectories(rdata)

# Parameter ids, and a copy in which selected ids are replaced by LaTeX labels.
old_param_names = list(model.getParameterIds())
# The entries are immutable strings, so a shallow copy is sufficient.
new_param_names = list(old_param_names)
old_not_ = ['N_var', 'gV', 'gP', 'kR', 't0', 'k_base', 'fac', 'L', 'nves', 'nsites']
# Raw strings keep the LaTeX backslashes literal ('\g' would be an invalid
# escape sequence in a normal string literal).
new_not_ = ['$N$', '$g_V$', '$g_P$', '$k_R$', '$t_0$', '$k_0$', r'$\gamma$', '$L$',
            r'$n_{\rm{ves}}$', r'$n_{\rm{sites}}$']
for old_name, new_name in zip(old_not_, new_not_):
    new_param_names[old_param_names.index(old_name)] = new_name

# Parameter values on the linear scale (the model stores log10 values).
variables = {
    name: 10**value
    for name, value in zip(model.getParameterIds(), np.array(model.getParameters()))
}

# Time grid and first state of the nominal simulation. The parameters have not
# changed since `rdata` was computed, so it is reused instead of re-simulating.
t = rdata.t
R = rdata.x[:, 0]

# Shift the waiting time (correction constant; its origin is not documented here).
variables['t_wait'] = variables['t_wait'] - 0.0014*1.9
# `number_of_sites` is expected to come from the star imports at the top of the file.
nsites = number_of_sites
# Write the values back on the log10 scale, explicitly in the model's parameter order.
model.setParameters(np.log10(np.array([variables[pid] for pid in model.getParameterIds()])))

# print model information
print("Parameter values", model.getParameters())
print("Model name:", model.getName())
print("Model parameters:", model.getParameterIds())
print("Model outputs:   ", model.getObservableIds())
print("Model states:    ", model.getStateIds())



"""
Observed fusion rate multiplied by the factor N, i.e. N*dF/dt. AMICI first solves the ODE for the parameters p; then k_F(t) is evaluated from its analytic representation.
Input: parameters p (np array, log10 scale)
Output: N*dF/dt = N*k_F(t)*R(t)
"""
def NF(p):
    """Return the observed fusion rate multiplied by N, i.e. N*dF/dt.

    AMICI first solves the ODE for the parameters ``p``; then the fusion rate
    constant k_F(t) is evaluated from its analytic representation: an
    exponential baseline plus 60 Gaussian-shaped peaks (``a1..a60``,
    ``mu1..mu60``, common width ``sigma``).

    Parameters
    ----------
    p : array_like
        Full parameter vector on the log10 scale, in the order of
        ``model.getParameterIds()``. It is written into the global ``model``.

    Returns
    -------
    numpy.ndarray
        N*k_F(t)*R(t) at the model time points, where R(t) is the first
        state of the simulation.
    """
    model.setParameters(p)
    # A single simulation provides both the time grid and R(t).
    rdata = amici.runAmiciSimulation(model, solver, None)
    t = rdata.t
    R = rdata.x[:, 0]
    variables = {key: 10**p[i] for i, key in enumerate(model.getParameterIds())}

    # Time axis rescaled by `fac` and shifted by `t_wait`.
    t_shift = t*variables['fac'] - variables['t_wait']
    baseline = 10**(variables['L']*(1 - exp(-variables['k_base']*(t_shift - variables['t0']))))
    peaks = np.sum(
        [variables[f'a{i}']*exp(-0.5*(t_shift - variables[f'mu{i}'])**2/variables['sigma']**2)
         for i in range(1, 61)],
        axis=0,
    )
    kF = baseline + peaks
    return variables['N_var']*kF*R

"""
Define the mEPSC function.
Input: tstep (sampling interval)
Output: the current of a single mEPSC (miniature current), which is convolved with the fusion rate N*dF/dt to produce the measured current
"""

def mEPSC_fun(tstep):
    """Return one mEPSC (miniature current), sampled every ``tstep``.

    The kernel is a rising exponential multiplied by a bi-exponential decay,
    sampled on ``np.arange(0, length_of_mini, tstep)`` and normalised so that
    its most negative value is exactly ``-size_of_mini``. It is used as the
    impulse response in the convolution that turns the fusions into the
    measured current.

    Parameters
    ----------
    tstep : float
        Sampling interval, in the same time unit as the constants below
        (presumably seconds).

    Returns
    -------
    numpy.ndarray
        The sampled mEPSC, with values in ``[-size_of_mini, 0]``.
    """
    # Parameters, don't change!
    size_of_mini = 0.6e-9  # A, amplitude of mEJC, estimated from variance-mean of data (see Fig 2F)
    A = -7.209251536449789e-06
    B = 2.709256850482493e-09
    t_0 = 0
    tau_rf = 10.692783377261414
    tau_df = 0.001500129264510
    tau_ds = 0.002823055510748
    length_of_mini = 34*1e-3

    t = np.arange(0, length_of_mini, tstep)
    s = t - t_0  # time since onset
    # Rise term times bi-exponential decay; the mask zeroes samples before t_0.
    mEPSC = (t >= t_0)*(A*(1 - np.exp(-s/tau_rf))*(B*np.exp(-s/tau_df) + (1 - B)*np.exp(-s/tau_ds)))
    # Normalise by the most negative value (ndarray.min, no Python-level loop)
    # and scale so that the trough equals -size_of_mini.
    mEPSC = -(mEPSC/mEPSC.min()*size_of_mini)
    return mEPSC

"""
Calculate the resulting current by convolving the fusions with the impulse response mEPSC.
Input: fusion rate NF = N*dF/dt
Output: current = convolution of (N*dF/dt * timestep) with mEPSC
"""
def current(NF):
    """Return the current obtained by convolving the fusions with one mEPSC.

    Parameters
    ----------
    NF : numpy.ndarray
        Fusion rate (N times dF/dt) sampled on the global time grid with
        spacing ``timestep``.

    Returns
    -------
    numpy.ndarray
        Full discrete convolution (the ``scipy.signal.convolve`` default mode)
        of ``NF * timestep`` with the kernel ``mEPSC_fun(timestep)``.
    """
    fusions_per_step = NF*timestep
    kernel = mEPSC_fun(timestep)
    return signal.convolve(fusions_per_step, kernel)

"""
Objective function for the optimization: the simulated current is compared with the measured current at the peaks, which serve as the reference.
"""

# Output time grid of a simulation with the current model parameters.
t = amici.runAmiciSimulation(model, solver, None).t

# Measured current of the selected animal; `animal` is expected to come from
# the star imports at the top of the file.
current_data_path = f'data/Current_data_animal{animal}.npy'

# Per animal (1..5): negative index at which the list of top data peaks is
# cut, i.e. the trailing peaks that are excluded from the fit.
stop_ind = [-9, -4, -5, -7, -9]
# Guard the lookup below: animal = 0 would otherwise silently pick the entry
# of animal 5 through negative indexing.
if not 1 <= animal <= len(stop_ind):
    raise ValueError(f"animal must be between 1 and {len(stop_ind)}, got {animal!r}")

current_data = np.load(current_data_path)

# Reference peaks of the measured current, used by the objective function:
# local maxima with heights in [-3.9e-8, -0.4e-8] at least 10 samples apart
# ("top" peaks, cut per animal) and local minima at or below -3e-8
# ("bottom" peaks, found as maxima of -current_data).
top_peaks_data = find_peaks(current_data, distance=10, height=(-3.9e-8, -0.4e-8))[0][:stop_ind[animal - 1]]
bottom_peaks_data = find_peaks(-current_data, height=3e-8)[0]


#definition of the objective function
                              

def f_obj(param):
    """Return the peak-matching loss of the model for the free parameters ``param``.

    The free parameters (the model parameters at ``param_indices``) are set in
    the model, and the current is simulated via ``NF`` and ``current``. The
    squared differences between the measured and simulated current at the
    peaks (local maxima "top" and minima "bottom") are then summed. Both
    currents are multiplied by 1e9 before taking the differences.

    Parameters
    ----------
    param : array_like
        Values of the free parameters (log10 scale, as used by the model),
        in the order of ``param_indices``.

    Returns
    -------
    float
        The loss, or ``np.inf`` if the simulated current has fewer top or
        bottom peaks than the reference data.
    """
    # Write the free parameters into the model and read back the full vector.
    parameter_ids = model.getParameterIds()
    param_Id = [parameter_ids[i] for i in param_indices]
    model.setParameterById(dict(zip(param_Id, param)))
    all_params = np.array(model.getParameters())

    # Simulated current; NF runs the ODE simulation that is needed.
    current_model = current(NF(all_params))

    # Peaks of the simulated current: local maxima ("top") and minima ("bottom").
    top_peaks_model = find_peaks(current_model)[0]
    bottom_peaks_model = find_peaks(-current_model)[0]

    # Compare the first n simulated peaks with the n reference peaks of the
    # data, so the comparison always matches the selected data peaks.
    n_top = len(top_peaks_data)
    n_bottom = len(bottom_peaks_data)
    if len(top_peaks_model) < n_top or len(bottom_peaks_model) < n_bottom:
        # Too few peaks in the simulation: reject this parameter set.
        return np.inf

    # Squared differences at the peaks, with both currents scaled by 1e9.
    top_err = ((current_data[top_peaks_data]*10**9 - current_model[top_peaks_model[:n_top]]*10**9)**2).sum()
    bottom_err = ((current_data[bottom_peaks_data]*10**9 - current_model[bottom_peaks_model[:n_bottom]]*10**9)**2).sum()
    return top_err + bottom_err


"""
The optimization function. Input: the value of the fixed parameter (with parallelization it comes from fixed_params_grid) together with its multistart number.
Returns an array of the form:
[fixed_parameter_value, corresponding multistart number, optimized parameters..., loss value]
"""
def optimize_params_single_run(fixed_param_value):
    """Run one differential-evolution fit with one parameter held fixed.

    Parameters
    ----------
    fixed_param_value : sequence
        ``fixed_param_value[0]`` is the value of the fixed parameter
        ``c_fixed[0]`` on the log10 scale (with parallelization it comes from
        ``fixed_params_grid``). ``fixed_param_value[1]`` is the number of the
        corresponding multistart.

    Returns
    -------
    numpy.ndarray
        ``[fixed_parameter_value, multistart_number, *optimized_params, loss]``,
        where ``loss`` is ``f_obj`` evaluated at the optimized parameters.
    """
    fixed_name = c_fixed[0]
    fixed_log_value = fixed_param_value[0]

    # Keep the linear-scale bookkeeping in `variables` in sync and set the
    # log10-scale value in the model.
    variables[fixed_name] = 10**fixed_log_value
    model.setParameterById({fixed_name: fixed_log_value})

    free_parameter_ids = [model.getParameterIds()[i] for i in param_indices]
    print("SETTING", free_parameter_ids)
    print("with bounds", param_bounds)

    # workers=1: the population is evaluated serially in this process; the
    # parallelism comes from distributing whole runs with @parFor.
    de_result = differential_evolution(
        f_obj,
        popsize=10,
        bounds=param_bounds,
        tol=1e-3,
        workers=1,
        polish=False,
        **opt_options
    )
    print(de_result)

    optimized_params = np.array(de_result.x)
    return np.concatenate((
        [fixed_log_value],
        [fixed_param_value[1]],
        optimized_params,
        [f_obj(optimized_params)],
    ))

#%%
# Parallel entry point. `parFor` and `lst` come from the star imports at the
# top of the file; presumably parFor distributes the elements of `lst` over
# the ranks of COMM and calls the decorated function once per element
# (assumption -- verify in the Parallel module).
@parFor(lst,COMM)
def optimize_wrapper(fixed_param_value,COMM):
    """Run one optimization for ``fixed_param_value`` (presumably an element of ``lst``).

    ``COMM`` is not used in the body; presumably it is part of the calling
    convention expected by ``parFor`` (TODO confirm).
    Returns whatever ``optimize_params_single_run`` returns for this value.
    """
    return optimize_params_single_run(fixed_param_value)

#parallel optimization for all initial values.
if __name__ == "__main__":

    # Embarrassingly parallel loop: optimize_wrapper runs all optimizations,
    # distributed over the MPI ranks in COMM via @parFor.
    start_time = time.perf_counter()
    result = optimize_wrapper(None, COMM)
    # Write the results from a single rank only; otherwise every rank would
    # write the same file concurrently. Assumes parFor makes the collected
    # results available on rank 0 -- verify in the Parallel module.
    if COMM.Get_rank() == 0:
        np.save(os.path.join(run_dir, "result.npy"), result)
    finish_time = time.perf_counter()

    print("Program finished in {} seconds".format(finish_time - start_time))




# %%
