# Scripts to reproduce the results from
#
# T. Richter, R. Ulrich, M. Janczyk:
#    "Diffusion models with time-dependent parameters:
#     Comparing the computation effort and accuracy
#     of different numerical methods"
#
# Thomas Richter
# Otto-von-Guericke University of Magdeburg
# 39106 Magdeburg, Germany
# thomas.richter@ovgu.de
#
# You can use this code under the terms of the
# Creative Commons Attribution 4.0 License


import numpy as np
from scipy import optimize
import matplotlib.pyplot as plt
import time

from PythonTools import kfe
from PythonTools import tools

# Render all figure text through LaTeX (matplotlib's 'text.usetex' rcParam);
# this needs a working LaTeX installation on the machine running the script.
plt.rcParams['text.usetex'] = True


### Define the test case: fixed (constant) boundary and time-dependent drift
class TestCase4:
    """Model with a constant boundary and a time-dependent drift.

    The drift is ``muc`` plus the time derivative of the pulse
    ``A * (t*e/((a-1)*tau))**(a-1) * exp(-t/tau)``, which peaks with
    value ``A`` at ``t = (a-1)*tau``.
    """

    # Class-level diffusion constant; presumably read by the KFE solver.
    # TODO confirm -- the script also sets params['sigma'].
    sigma = 4

    def mu(self, t, p):
        """Return the drift at time(s) ``t`` for the parameter dict ``p``.

        Reads p['muc'], p['A'], p['tau'] and p['a'].
        ``t`` may be a scalar or a numpy array.

        Raises:
            ValueError: if p['a'] is not >= 2. For a < 2 the factor
                t**(a-2) is singular at t = 0.
        """
        a = p['a']
        # 'not a >= 2' (rather than 'a < 2') also rejects NaN, as the
        # original assert did.  An explicit raise survives 'python -O'.
        if not a >= 2:
            raise ValueError('Parameter a must be larger than or equal to 2')
        tau = p['tau']
        # Normalization so that the underlying pulse has peak value A.
        scale = (np.exp(1.0) / (a - 1) / tau) ** (a - 1)
        return p['muc'] + p['A'] * np.exp(-t / tau) * scale * t ** (a - 2) * ((a - 1) - t / tau)

    def b(self, t, p):
        """Return the constant boundary location p['b'] at time(s) ``t``."""
        # we must add the term '0*t' such that the function
        # returns a vector of size len(t) if t is a vector
        return p['b'] + 0 * t

    def dt_b(self, t, params=None):
        """Return the time derivative of the boundary (zero: it is constant)."""
        return 0 * t


### Initialize the model. This single TestCase4 instance is used both by the
### goal functional during fitting and by the re-simulation after fitting.
model = TestCase4()

### Discretization: final time T, time step dt (= T/200), spatial step dx.
disc = dict(
    T=1000,
    dt=1000.0 / 200,
    dx=1.0 / 10,
)


############################
### Model parameters (initial values; some are overwritten by the fit)
params = dict(
    # Time-dependent drift: constant part muc plus a pulse of
    # amplitude A, time scale tau and shape a
    muc=0.5,
    A=20,
    tau=50,
    a=2,
    # Variability of the diffusion process
    sigma=4.0,
    # Location of the upper boundary
    b=75,
    # Variability (sigmaR) and mean (muR) of the offset
    sigmaR=30,
    muR=300,
    # Initial condition: alpha > 0 gives a Beta distribution B(alpha, alpha);
    # alpha = 0 starts in the center
    alpha=2,
)


### Convert a sampled cdf into the times at which it crosses equally spaced levels
def cdfbin(cdf, nbins, disc):
    """Return the nbins-1 interior 'quantile times' of a sampled cdf.

    The cdf values are histogrammed into ``nbins`` equal-width bins over
    [min(cdf), max(cdf)].  The running bin counts, times ``disc['dt']``,
    give the time at which the cdf passes each interior bin edge.

    Assumes ``cdf`` is nondecreasing and sampled every ``disc['dt']``
    from t = 0.  TODO confirm against tools.pdf2cdf.
    """
    counts, _edges = np.histogram(cdf, bins=nbins)
    # The final cumulative count is simply len(cdf), so it is left out.
    return np.cumsum(counts[:-1]) * disc['dt']
    

## Solves the KFE for a given set of model parameters
# 1st: the congruent case, 2nd: switch sign of A and solve the incongruent case
def solve_kfe(model, disc, params):
    [pdf_u_cong,pdf_l_cong,fs]   = kfe.kfe_ale(model, disc, params)     
    p_u_cong   = tools.add_residual(pdf_u_cong,   disc, params)
    p_l_cong   = tools.add_residual(pdf_l_cong,   disc, params)
    params['A'] = - params['A']                                  # switch to incongruent

    [pdf_u_incong,pdf_l_incong,fs] = kfe.kfe_ale(model, disc, params)
    p_u_incong = tools.add_residual(pdf_u_incong, disc, params)
    p_l_incong = tools.add_residual(pdf_l_incong, disc, params)
    params['A'] = - params['A']                                  # switch to congruent
    return p_u_cong,p_l_cong,p_u_incong,p_l_incong


### Goal functional for the data fit (squared l2 misfit of the quantiles)
def goal(p, *args):
    """Return the squared l2 error between simulated and observed quantiles.

    Args:
        p: trial values of the free parameters, in the order of args[0].
        *args: (freeparameters, cdfdata).
            freeparameters: the keys of the module-level ``params`` that
                are being fitted.
            cdfdata: observed quantile times, column 0 congruent and
                column 1 incongruent.  Expected to have 19 rows, matching
                the 19 interior quantiles of 20 bins.

    Side effect: writes the values of ``p`` into the module-level
    ``params`` dict.  The script reads ``params`` after the optimization.
    """
    freeparameters, cdfdata = args[0], args[1]   # unpack arguments

    # Install the trial point into the global parameter set.
    for name, value in zip(freeparameters, p):
        params[name] = value

    # Solve the Kolmogorov forward equation (both conditions).
    p_u_cong, _, p_u_incong, _ = solve_kfe(model, disc, params)

    # Upper-boundary cdfs of the congruent and incongruent condition.
    c_u_cong = tools.pdf2cdf(p_u_cong, disc)
    c_u_incong = tools.pdf2cdf(p_u_incong, disc)

    # 20 equal bins give 19 interior quantiles to compare with the data.
    nbins = 20
    err_cong = cdfbin(c_u_cong, nbins, disc) - cdfdata[:, 0]
    err_incong = cdfbin(c_u_incong, nbins, disc) - cdfdata[:, 1]

    return np.linalg.norm(err_cong)**2.0 + np.linalg.norm(err_incong)**2.0



### MAIN


### Observed data (Ulrich et al. 2015, Flanker task): skip the header line
### and keep columns 2 and 3 (congruent / incongruent quantile times).
cdf_data = np.loadtxt('Data/ulrich-cdf-flanker.txt', skiprows=1)[:, 2:4]

### Keys of `params` that are adjusted by the fit
freeparameters = ['A', 'tau', 'muc', 'alpha', 'b', 'muR', 'sigmaR']

### sigma is not a free parameter: it cannot be identified simultaneously
### with the other parameters, so we use the value reported in the paper.
params['sigma'] = 3.98

### Initial guess for the optimizer: the current values of the free
### parameters, as a float array (the variable used in the optimization)
p = np.array([params[name] for name in freeparameters], dtype=float)


### Run the parameter identification for a sequence of successively refined
### discretizations (dt and dx halved each round) to analyze the impact of
### the discretization accuracy on the estimated parameters.  Every round
### starts from the same initial guess p.

# Probability levels of the 19 interior quantiles (loop invariant).
yr = np.linspace(0.05, 0.95, 19)

for di in range(1, 6):

    ### Start the minimization
    t1 = time.perf_counter()
    optres = optimize.minimize(goal, p, args=(freeparameters, cdf_data),
                               method='Nelder-Mead',
                               options={'disp': True, 'adaptive': True, 'maxiter': 5000},
                               tol=1e-8)
    t2 = time.perf_counter()

    ### goal() writes every trial point into params.  Nelder-Mead's last
    ### evaluated point is not necessarily its best one, so install the
    ### optimum explicitly before re-simulating and reporting.
    for name, value in zip(freeparameters, optres.x):
        params[name] = value

    ### Recompute the model with the identified parameters
    p_u_cong, _, p_u_incong, _ = solve_kfe(model, disc, params)
    c_u_cong = tools.pdf2cdf(p_u_cong, disc)
    c_u_incong = tools.pdf2cdf(p_u_incong, disc)

    bin_cong = cdfbin(c_u_cong, 20, disc)
    bin_incong = cdfbin(c_u_incong, 20, disc)

    plt.figure(figsize=(5, 4))

    # Raw string: the backslashes must reach LaTeX verbatim ('\d' and '\D'
    # are invalid Python escape sequences in a normal string literal).
    plt.title(r'Eriksen-Flanker task $\displaystyle \Delta t={0}$'.format(disc['dt']))
    plt.xlabel('Reaction time (ms)')
    plt.ylabel('Probability')
    plt.plot(bin_cong,      yr, '-o', markersize=3, color='blue',label='simulated congruent')
    plt.plot(bin_incong,    yr, '-o', markersize=3,color='orange',label='simulated incongruent')
    plt.plot(cdf_data[:,0], yr, 'o',color='darkblue' ,label='observed congruent')
    plt.plot(cdf_data[:,1], yr, 'o',color='darkorange',label='observed incongruent')
    plt.legend()
    plt.savefig('pics/eriksen-flanker_{0}.png'.format(disc['dt']), dpi=300)

    plt.show()

    ### One LaTeX table row: dt & fitted parameters & run time (s) \\
    print(disc['dt'], end='&')
    for n in ['sigma','b','muR','sigmaR','A','muc','tau','alpha']:
        print('{0:4.2f}'.format(params[n]), end='&')
    print('{0:4.2f}'.format(t2 - t1), '\\\\')

    ### Refine the discretization for the next round
    disc['dt'] /= 2
    disc['dx'] /= 2


print(params)

