# Scripts to reproduce the results from
#
# T. Richter, R. Ulrich, M. Janczyk:
#    "Diffusion models with time-dependent parameters:
#     Comparing the computation effort and accuracy
#     of different numerical methods"
#
# Thomas Richter
# Otto-von-Guericke University of Magdeburg
# 39106 Magdeburg, Germany
# thomas.richter@ovgu.de
#
# You can use this code under the terms of the
# Creative Commons Attribution 4.0 License



import numpy as np
from PythonTools import kfe
from PythonTools import randomwalk
import matplotlib.pyplot as plt


### randomwalk without the modification presented in the paper.
###
def randomwalk_variable_nofix(model, disc, params = None):
    """Unmodified random-walk scheme for time-dependent thresholds.

    Reference variant compared against ``randomwalk.randomwalk_variable``
    (the version with the modification presented in the paper). The
    probability mass lives on a uniform space grid with spacing dx. In every
    time step, mass in a cell moves one cell up with probability pu, one cell
    down with probability pl, and stays with probability pm. Mass leaving
    through the upper/lower threshold cell is recorded as first-passage
    density.

    Args:
        model: object providing ``sigma``, ``alpha``, ``mu(t, params)`` and
            ``b(t, params)``. The thresholds are placed at ``zeroline + b(t)``
            and ``zeroline - b(t)``.
        disc: discretization dict. Reads ``'T'`` (final time), ``'dt'`` and
            ``'dx'``. T must be a multiple of dt.
        params: passed unchanged to ``model.mu`` and ``model.b``.

    Returns:
        ``[pdf_u, pdf_l]``: two arrays of length nT+1 (nT = T/dt) holding the
        first-passage densities at the upper and lower threshold at the times
        ``linspace(0, T, nT+1)``. Entry 0 is always 0. Both arrays are scaled
        together so that ``dt * (sum(pdf_u) + sum(pdf_l)) == 1``, i.e. trials
        that never reach a threshold are ignored.

    Raises:
        AssertionError: if a discretization check fails. Checks include:
            T not a multiple of dt; sigma**2*dt/dx**2 > 1; thresholds outside
            the grid; a threshold moving more than one cell per step; a
            threshold moving outward between two steps (only one-cell
            contractions are handled).
    """
    T    = disc['T']
    dx   = disc['dx']
    dt   = disc['dt']
    
    nT   = int(T/dt+1.e-8)     # number of time steps (offset guards against round-off)

    assert abs(T-nT*dt)<1.e-10, 'T not a multiple of the time step'

    # define discrete range of the state space to cover the thresholds

    timesteps = np.linspace(0,T,nT+1)   # discrete points in time
    bmax = np.max(model.b(timesteps,params))
    nX   = int(2*bmax/dx)+6              # cells needed for [-bmax, bmax] plus a safety margin
    if nX%2 == 0:            # number of steps must be odd
        nX = nX + 1          # (nX-1)/2 is zero-line
    
    zeroline = dx*nX/2       # position of the zero-line (centre of cell (nX-1)/2)
    
    assert T>0,    'max time must be postive'
    assert bmax>0, 'threshold must be positive'
    assert nX>0,   'space steps must be positive'
    assert nX%2==1,'number of space steps must be odd'

    # sigma^2*dt/dx^2 <= 1 keeps pu+pl <= 1 (see pu/pl below)
    assert model.sigma*model.sigma*dt/dx/dx<=1,'parabolic time step condition!'
    
    # NOTE(review): the comments say "interval midpoints" (in time), but the
    # time shift used here and in upperdx/lowerdx is 0.5*dx, not 0.5*dt.
    # Confirm this is intended before reusing this scheme.
    upper = zeroline + model.b(timesteps+0.5*dx,params)   # position of upper boundary in the interval midpoints
    lower = zeroline - model.b(timesteps+0.5*dx,params)   # position of lower boundary in the interval midpoints
    
    
    upper = (upper / dx + 1.e-10).astype(int)  # Indices of the boundary in the space-mesh
    lower = (lower / dx - 1.e-10).astype(int)  # Indices of the boundary in the space-mesh

    # the update below reads x[lower-1] and x[upper+1]; both must stay inside the grid
    assert np.min(lower)>0
    assert np.max(upper)<nX-1
    

    assert(np.max(np.abs(upper[:-1]-upper[1:]))<=1), 'time step too large. Margin may not jump more than one element per time step'
    assert(np.max(np.abs(lower[1:]-lower[:-1]))<=1), 'time step too large. Margin may not jump more than one element per time step'
    
    # length (between 0 and dx) of the boundary cell that lies inside the thresholds
    upperdx = zeroline+model.b(timesteps+0.5*dx,params)-upper*dx  # boundary extends last element by this 
    lowerdx = (lower+1)*dx-(zeroline-model.b(timesteps+0.5*dx,params))  # here to the lower, but still, lowerdx is positive

    assert(np.max(upperdx) <= dx)
    assert(np.max(lowerdx) <= dx)

    assert(np.min(upperdx) >= 0)
    assert(np.min(lowerdx) >= 0)


    x = np.zeros(nX+1)                  # initial probability mass per cell
    xx = np.linspace(0,1.,upper[0]-lower[0]+1)   
    assert model.alpha>0
    # start distribution on the cells between the initial thresholds,
    # proportional to xx^(alpha-1) * (1-xx)^(alpha-1) (beta-like shape)
    x[lower[0]:upper[0]+1] = xx**(model.alpha-1)*(1.0-xx)**(model.alpha-1)
    x=x/np.sum(x)

    pdf_u = np.zeros(nT+1)
    pdf_l = np.zeros(nT+1)

    diffusion = model.sigma*model.sigma/2.0

    for n in range(nT):       # time loop
        t = timesteps[n+1]
        
        transport = model.mu(t-dt,params)    # drift evaluated at the start of the step (t_n)
        
        # one-cell step probabilities: pu = move up, pl = move down, pm = stay
        # (the update below adds pu*x[i-1] and pl*x[i+1] to cell i).
        # NOTE(review): pl becomes negative if transport*dx > sigma^2; this is not checked.
        pu = diffusion*dt/dx/dx + transport*dt/2./dx
        pl = diffusion*dt/dx/dx - transport*dt/2./dx
        pm = 1.0 - pu - pl

        # With diff = 1 every (1.0-diff) term below vanishes. When the threshold
        # stays in the same cell, no sub-cell correction is applied (presumably
        # this is what distinguishes this variant from the modified scheme).
        diff = 1.
            
        # compute probabilities for leaving
        # all from lower / upper
        # and parts of lower+1 and upper-1
        if upper[n+1] == upper[n]:
            pdf_u[n+1] = pu/dt * x[upper[n]] + (1.0-diff) * (pu/dt * x[upper[n]-1] + pm/dt * x[upper[n]])
            pdf_l[n+1] = pl/dt * x[lower[n]] + (1.0-diff) * (pl/dt * x[lower[n]+1] + pm/dt * x[lower[n]])
        
            # compute what will remain in the boundary layer cells
            up = diff * pm * x[upper[n]] + pu * x[upper[n]-1] * diff
            lo = diff * pm * x[lower[n]] + pl * x[lower[n]+1] * diff

            # regular update of all cells between the thresholds
            x[lower[n]:upper[n]+1] = pm * x[lower[n]:upper[n]+1] + pu * x[lower[n]-1:upper[n]] + pl * x[lower[n]+1:upper[n]+2]
        
            # set boundary cells
            x[upper[n+1]] = up
            x[lower[n+1]] = lo
        else:
            # thresholds contract by exactly one cell at both ends at once
            assert(upper[n+1] == upper[n] - 1 ) # upper threshold moved one cell down
            assert(lower[n+1] == lower[n] + 1 ) # lower threshold moved one cell up
            
            # Fraction of the new boundary cell that lies inside the thresholds.
            # It is reused for the lower threshold, which relies on
            # upperdx == lowerdx (true for the symmetric zeroline +/- b(t) placement).
            nextdiff = upperdx[n+1]/dx # fraction within the last cell
            assert(nextdiff>=0)
            assert(nextdiff<=1)
            
            
            # the old boundary cells are now outside: their staying/outgoing mass leaves,
            # plus the fraction (1-nextdiff) of what would land in the new boundary cell
            pdf_u[n+1] = (pm+pu)/dt * x[upper[n]] + pu/dt * x[upper[n]-1] + (1.0-nextdiff) * (pl/dt * x[upper[n]] + pm/dt*x[upper[n]-1] + pu/dt * x[upper[n]-2])
            pdf_l[n+1] = (pm+pl)/dt * x[lower[n]] + pl/dt * x[lower[n]+1] + (1.0-nextdiff) * (pu/dt * x[lower[n]] + pm/dt*x[lower[n]+1] + pl/dt * x[lower[n]+2])
 
            # set new upper / lower element           
            up = nextdiff * (pm * x[upper[n]-1] + pu * x[upper[n]-2] + pl * x[upper[n]])
            lo = nextdiff * (pm * x[lower[n]+1] + pl * x[lower[n]+2] + pu * x[lower[n]])

            x[lower[n]:upper[n]+1] = pm * x[lower[n]:upper[n]+1] + pu * x[lower[n]-1:upper[n]] + pl * x[lower[n]+1:upper[n]+2]

            x[upper[n+1]] = up
            x[lower[n+1]] = lo
 
        # set all outside to zero (outside both the old and the new thresholds)
        x[:lower[n+1]]   = 0  
        x[upper[n+1]+1:] = 0
        x[:lower[n]]   = 0  
        x[upper[n]+1:] = 0

    # We ignore all trials that do not reach the upper or lower
    # margin and rescale the probabilities to sum up to one
    scale = disc['dt'] * (np.sum(pdf_u) + np.sum(pdf_l))
    return [pdf_u/scale,pdf_l/scale]



class TestCase2:
    """Test case 2: time-dependent drift and collapsing thresholds.

    The model is queried through ``mu(t, params)`` (drift), ``b(t)``
    (threshold) and ``dt_b(t)`` (time derivative of the threshold). All three
    evaluate elementwise, so ``t`` may be a scalar or a NumPy array.
    """

    sigma = 4   # noise amplitude; randomwalk_variable_nofix uses sigma**2/2 as diffusion
    alpha = 5   # exponent of the initial distribution in randomwalk_variable_nofix

    def mu(self,t,params):
        """Return the drift at time(s) ``t``.

        mu(t) = mu0 + Adrift/tau * exp(1 - t/tau) * (1 - t/tau)

        The second term is the time derivative of the pulse
        ``Adrift * (t/tau) * exp(1 - t/tau)``.

        Args:
            t: time or array of times.
            params: mapping with the keys ``'Adrift'`` and ``'tau'``. The
                constant drift offset is read from ``'mu0'`` and defaults to
                0.5 (the value previously hard-coded here) when absent.
        """
        # previously the offset was hard-coded to 0.5 although params['mu0'] exists
        mu0 = params.get('mu0', 0.5)
        tau = params['tau']
        return mu0+params['Adrift']/tau*np.exp(1.0-t/tau)*(1-t/tau)

    def b(self,t,parameters = None):
        """Return the threshold b(t) = 75 * (1 - 0.6 * t/(t+150)).

        b(0) = 75, and for t >= 0 it decreases monotonically towards
        75 * 0.4 = 30. ``parameters`` is accepted for interface compatibility
        and ignored.
        """
        return 75.0 * (1.0-0.6*(t)/(t+150.0))

    def dt_b(self,t,parameters = None):
        """Return the time derivative of ``b`` at time(s) ``t``.

        d/dt [t/(t+150)] = 1/(t+150) - t/(t+150)**2, hence
        b'(t) = 75 * (-0.6) * (1/(t+150) - t/(t+150)**2).
        ``parameters`` is ignored.
        """
        return 75.0 * (-0.6) * (1.0/(t+150.0) - t/(t+150.0)/(t+150.0))

model = TestCase2()


### Parameters of the time-dependent drift.
# For test case 2 they are fixed; no parameter identification is performed.
params = {
    'mu0'    : 0.5,
    'Adrift' : 20,
    'tau'    : 150.0,
    
    # alpha=0 for initial value at x=0
    'alpha'  : 5.0,
    }

# 'dt' and 'dx' are placeholders; they are set before each solver call below.
disc = {
        'T' : 1000,
        'dt' : 0,
        'dx' : 0
        }


# Space discretization parameters used in Table 1.
NX_kfe = 160
# Space step size for the KFE (based on the ALE reference interval [-1, 1]).
DX_kfe = 2.0/NX_kfe


## NT_kfe / DT_kfe are the number of time steps / the step size used for the KFE simulations.
NT_kfe     = 10*NX_kfe
DT_kfe     = disc['T'] / NT_kfe


## Time steps for the random walks (same final time as the KFE runs).
NT_rw      = NX_kfe*NX_kfe//4*3
DT_rw      = disc['T'] / NT_rw

# Start from dx = sigma*sqrt(dt), i.e. sigma**2*dt/dx**2 == 1, the limit of
# the parabolic time-step condition. Then round dx up so that 150 is an
# integer multiple of it. 150 presumably equals 2*b(0), the initial distance
# between the thresholds -- NOTE(review): confirm.
DX_rw = np.sqrt(DT_rw) * model.sigma
NX_rw = int(150.0/DX_rw)
DX_rw = 150.0/NX_rw


import os  # used below to make sure the output directory exists

print('Random Walk')
print('dx\t\tdt\t\ttime\t\tErr')

# Random walk with the modification presented in the paper.
disc['bmax'] = 75
disc['dt']   = DT_rw
disc['dx']   = DX_rw

[rw_u,rw_l]=randomwalk.randomwalk_variable(model,disc,params)

# Same discretization, unmodified random walk (defined in this file).
[rw_u_nf,rw_l_nf]=randomwalk_variable_nofix(model,disc,params)

# KFE solution.
disc['bmax'] = 0
disc['dt']   = DT_kfe
disc['dx']   = DX_kfe
disc['theta'] = 0.5+0.05*disc['dt']

[kfe_u,kfe_l,fs]   = kfe.kfe_ale(model, disc, params)


def _window(pdf, t_start, t_end):
    """Return ``(times, values)`` of ``pdf`` restricted to ``t_start <= t <= t_end``.

    Assumes ``pdf`` is sampled uniformly on ``[0, disc['T']]`` including both
    end points. This holds for ``randomwalk_variable_nofix``, which returns
    nT+1 values. NOTE(review): confirm for the PythonTools solvers.

    The real sample times are used instead of a synthetic ``linspace`` over
    the window. The old approach stretched the data onto the window and could
    produce x/y arrays of different length.
    """
    pdf = np.asarray(pdf)
    times = np.linspace(0.0, disc['T'], len(pdf))
    inside = (times >= t_start) & (times <= t_end)
    return times[inside], pdf[inside]


plt.figure(figsize=(12, 3))

# left panel: the first 200 time units
plt.subplot(121)
plt.xlabel('time')
plt.ylabel('probability (upper threshold)')
plt.plot(*_window(kfe_u, 0, 200), label='KFE', linewidth=1)
plt.plot(*_window(rw_u, 0, 200), label='Random walks (modified)', linewidth=2)
plt.plot(*_window(rw_u_nf, 0, 200), label='Random walks')
plt.legend()

# right panel: zoom into 25 <= t <= 50
plt.subplot(122)
plt.xlabel('time')
plt.plot(*_window(kfe_u, 25, 50), label='KFE')
plt.plot(*_window(rw_u, 25, 50), label='Random walks (modified)', linewidth=2)
plt.plot(*_window(rw_u_nf, 25, 50), label='Random walks')
plt.ylim(0.006,0.014)
plt.legend()

os.makedirs('pics', exist_ok=True)  # savefig does not create missing directories
plt.savefig('pics/testcase2-adjustrandomwalk.png', dpi=300, bbox_inches = 'tight')
plt.show()
