#!/usr/bin/env python3
# -*- coding: utf-8 -*-


# Scripts to reproduce the results from
#
# T. Richter, R Ulrich, M. Janczyk:
#    "Diffusion models with time-dependent parameters:
#     Comparing the computation effort and accuracy
#     of different numerical methods"
#
# Thomas Richter
# Otto-von-Guericke University of Magdeburg
# 39106 Magdeburg, Germany
# thomas.richter@ovgu.de
#
# You can use this code under the terms of the
# Creative Commons Attribution 4.0 License



import numpy as np
import matplotlib.pyplot as plt


def randomwalk(model, disc):
    """Random-walk approximation of first-passage time densities for a
    diffusion with constant drift and constant thresholds +/- model.b.

    The interval [-b, b] is covered by nX + 1 grid cells.  Cells 0 and nX
    are the absorbing threshold cells: the time loop never updates them, it
    only records the probability flux that enters them.

    Parameters
    ----------
    model : object
        Must provide ``b`` (threshold, > 0), ``sigma`` (diffusion
        coefficient), ``mu`` (constant drift) and ``alpha``.  ``alpha > 0``
        selects a start distribution proportional to
        ``xx**(alpha-1) * (1-xx)**(alpha-1)`` (Beta(alpha, alpha) shape) on
        the interior cells; otherwise the walk starts in the middle cell.
    disc : dict
        Discretisation with keys ``'T'`` (final time), ``'dx'`` (spatial
        step) and ``'dt'`` (time step).

    Returns
    -------
    list of numpy.ndarray
        ``[pdf_u, pdf_l]``, each of length ``nT + 1``: densities of
        absorption at the upper / lower threshold at times
        ``0, dt, ..., nT*dt`` (entry 0 is always 0).  They are rescaled so
        that ``dt * (sum(pdf_u) + sum(pdf_l)) == 1``, i.e. trials not
        absorbed by time T are ignored.

    Raises
    ------
    AssertionError
        If the discretisation parameters are invalid or violate the
        checked time-step conditions.
    """
    T = disc['T']
    b = model.b

    dx = disc['dx']
    assert dx>0, 'Spatial step size must be positive'

    dt = disc['dt']
    assert dt>0, 'Time step size must be positive'

    nX = int(2*b/dx+1.e-10)   # small offset guards against round-off in 2*b/dx
    nT = int(T/dt + 1.e-10)

    assert T>0,    'max time must be postive'
    assert b>0,    'threshold must be positive'
    assert nX>0,   'space steps must be positive'
    assert nX%2==0,'number of space steps must be even'

    assert model.sigma*model.sigma*dt/dx/dx<=1,'parabolic time step condition!'
    assert 0.5*model.mu*dt/dx<=1           ,'hyperbolic time step condition!'

    # Initial distribution.  The absorbing cells x[0] and x[-1] must start
    # empty: the loop below never drains them, so any mass there would be
    # re-injected into the interior every step (alpha == 1 makes the Beta
    # profile uniform), and for alpha < 1 the profile is infinite at the
    # end points.  Only the interior cells are filled.
    x = np.zeros(nX+1)
    if model.alpha>0:
        inner = np.linspace(0,1.,nX+1)[1:-1]
        x[1:-1] = inner**(model.alpha-1)*(1.0-inner)**(model.alpha-1)
        x = x/np.sum(x)               # normalize
    else:
        x[nX//2] = 1                  # point mass in the middle cell

    # Transition probabilities (up / down / stay).  Drift and diffusion are
    # constant here, so they are computed once.
    pu = model.sigma*model.sigma*dt/2./dx/dx+model.mu*dt/2./dx
    pl = model.sigma*model.sigma*dt/2./dx/dx-model.mu*dt/2./dx
    pm = 1.0-pu-pl

    pdf_u = np.zeros(nT+1)
    pdf_l = np.zeros(nT+1)

    for n in range(nT):       # time loop
        # flux into the absorbing cells during step n -> n+1,
        # evaluated from the state before the update
        pdf_u[n+1] = pu * x[-2] / dt
        pdf_l[n+1] = pl * x[1]  / dt
        # the right-hand side is evaluated completely before assignment,
        # so all terms use the old state
        x[1:-1] = pm * x[1:-1] + pu * x[0:-2] + pl * x[2:]

    # We ignore all trials that do not reach the upper or lower
    # margin and rescale the probabilities to sum up to one
    scale = np.sum(pdf_u)*dt + np.sum(pdf_l) * dt

    return [pdf_u/scale,pdf_l/scale]


## Allows for variable drift rate
## Thresholds are fixed
def randomwalk_variabledrift(model, disc, params):
    """Random-walk approximation of first-passage time densities with a
    time-dependent drift and fixed thresholds +/- disc['bmax'].

    The interval [-bmax, bmax] is covered by nX + 1 grid cells.  Cells 0 and
    nX are the absorbing threshold cells; the loop only records the flux
    into them.

    Parameters
    ----------
    model : object
        Must provide ``sigma``, ``alpha`` and a callable ``mu(t, params)``
        returning the drift at time ``t``.  ``alpha > 0`` selects a start
        distribution proportional to ``xx**(alpha-1) * (1-xx)**(alpha-1)``
        on the interior cells; otherwise the walk starts in the middle cell.
    disc : dict
        Keys ``'T'`` (final time), ``'dx'``, ``'dt'`` and ``'bmax'``
        (threshold).
    params : object
        Passed through unchanged to ``model.mu``.

    Returns
    -------
    list of numpy.ndarray
        ``[pdf_u, pdf_l]`` of length ``nT + 1`` (entry 0 is always 0),
        rescaled so that ``dt * (sum(pdf_u) + sum(pdf_l)) == 1``.

    Raises
    ------
    AssertionError
        If the discretisation parameters are invalid or the parabolic
        time-step condition is violated.
    """
    T = disc['T']
    b = disc['bmax']

    dx = disc['dx']
    assert dx>0, 'Spatial step size must be positive'

    dt = disc['dt']
    assert dt>0, 'Time step size must be positive'

    nX = int(2*b/dx+1.e-10)
    nT = int(T/dt + 1.e-10)

    assert T>0,    'max time must be postive'
    assert b>0,    'threshold must be positive'
    assert nX>0,   'space steps must be positive'
    assert nX%2==0,'number of space steps must be even'

    assert model.sigma*model.sigma*dt/dx/dx<=1,'parabolic time step condition!'
    # NOTE: no drift (hyperbolic) condition is checked because mu varies in
    # time.  pl below becomes negative if mu(t) > sigma**2 / dx, and pu if
    # mu(t) < -sigma**2 / dx.

    # Initial distribution.  The absorbing cells x[0] and x[-1] must start
    # empty, otherwise they act as a permanent source (alpha == 1) or
    # contain inf (alpha < 1).  Only the interior cells are filled.
    x = np.zeros(nX+1)
    if model.alpha>0:
        inner = np.linspace(0,1.,nX+1)[1:-1]
        x[1:-1] = inner**(model.alpha-1)*(1.0-inner)**(model.alpha-1)
        x = x/np.sum(x)               # normalize
    else:
        x[nX//2] = 1                  # point mass in the middle cell

    pdf_u = np.zeros(nT+1)
    pdf_l = np.zeros(nT+1)

    diffusion = model.sigma*model.sigma/2.0

    for n in range(nT):       # time loop
        t = (n+1)*dt

        # drift frozen at the start of the step (time n*dt)
        transport = model.mu(t-dt,params)

        pu = diffusion*dt/dx/dx + transport*dt/2./dx
        pl = diffusion*dt/dx/dx - transport*dt/2./dx
        pm = 1.0 - pu - pl

        # flux into the absorbing cells, from the state before the update
        pdf_u[n+1] = pu * x[-2] / dt
        pdf_l[n+1] = pl * x[1]  / dt
        x[1:-1] = pm * x[1:-1] + pu * x[0:-2] + pl * x[2:]

    # We ignore all trials that do not reach the upper or lower
    # margin and rescale the probabilities to sum up to one
    scale = np.sum(pdf_u)*dt + np.sum(pdf_l) * dt

    return [pdf_u/scale,pdf_l/scale]



def randomwalk_variable(model, disc, params = None):
    """Random-walk approximation of first-passage time densities with
    time-dependent drift and time-dependent symmetric thresholds.

    The state space is a padded uniform grid of ``nX + 1`` cells around a
    zero-line.  At every time step the thresholds ``zeroline +/- b`` are
    mapped to boundary-cell indices ``upper[n]`` / ``lower[n]``.  The part
    of the boundary cell that still lies inside the thresholds
    (``upperdx / dx``) decides which fraction of the probability landing in
    that cell is kept; the remainder is counted as absorbed.  Only constant
    or collapsing thresholds are supported: a threshold may stay in its cell
    or move one cell inwards per step (asserted).

    Parameters
    ----------
    model : object
        Must provide ``sigma``, ``alpha`` (used as the Beta exponent) and
        callables ``mu(t, params)`` (drift) and ``b(t, params)``
        (threshold).  ``b`` is called with a numpy array of times and must
        return an array of the same length.
    disc : dict
        Keys ``'T'`` (final time, must be a multiple of ``dt``), ``'dx'``
        and ``'dt'``.
    params : dict or None, optional
        Passed through to ``model.mu`` and ``model.b``.  If given,
        ``params['alpha'] > 0`` selects the Beta-shaped start distribution.
        If ``None``, ``model.alpha`` is used for that decision instead.
        Previously ``None`` always raised a TypeError.

    Returns
    -------
    list of numpy.ndarray
        ``[pdf_u, pdf_l]`` of length ``nT + 1`` (entry 0 is always 0),
        rescaled so that ``dt * (sum(pdf_u) + sum(pdf_l)) == 1``.

    Raises
    ------
    AssertionError
        If the discretisation is inconsistent, e.g. T is not a multiple of
        dt, a threshold leaves the grid, or a threshold jumps by more than
        one cell per time step.
    """
    T    = disc['T']
    dx   = disc['dx']
    dt   = disc['dt']

    nT   = int(T/dt+1.e-8)

    assert abs(T-nT*dt)<1.e-10, 'T not a multiple of the time step'

    # define discrete range of the state space to cover the thresholds
    timesteps = np.linspace(0,T,nT+1)   # discrete points in time
    bmax = np.max(model.b(timesteps,params))
    nX   = int(2*bmax/dx)+6              # padding cells beyond the largest threshold
    if nX%2 == 0:            # number of steps must be odd
        nX = nX + 1          # (nX-1)/2 is zero-line

    zeroline = dx*nX/2       # position of the zero-line

    assert T>0,    'max time must be postive'
    assert bmax>0, 'threshold must be positive'
    assert nX>0,   'space steps must be positive'
    assert nX%2==1,'number of space steps must be odd'

    assert model.sigma*model.sigma*dt/dx/dx<=1,'parabolic time step condition!'
    # NOTE: the drift (hyperbolic) condition is not checked; the drift is
    # time-dependent.

    # Threshold values used for all index/offset computations, evaluated once.
    # NOTE(review): the original comment speaks of "interval midpoints", which
    # in time would be timesteps + 0.5*dt, but the shift is half a *spatial*
    # step (0.5*dx).  Kept unchanged to reproduce existing results; confirm
    # the intent.
    bmid = model.b(timesteps+0.5*dx,params)

    upper = zeroline + bmid   # position of upper boundary
    lower = zeroline - bmid   # position of lower boundary

    upper = (upper / dx + 1.e-10).astype(int)  # Indices of the boundary in the space-mesh
    lower = (lower / dx - 1.e-10).astype(int)  # Indices of the boundary in the space-mesh

    assert np.min(lower)>0
    assert np.max(upper)<nX-1

    assert(np.max(np.abs(upper[:-1]-upper[1:]))<=1), 'time step too large. Margin may not jump more than one element per time step'
    assert(np.max(np.abs(lower[1:]-lower[:-1]))<=1), 'time step too large. Margin may not jump more than one element per time step'

    upperdx = zeroline+bmid-upper*dx          # boundary extends last element by this
    lowerdx = (lower+1)*dx-(zeroline-bmid)    # here to the lower, but still, lowerdx is positive

    # The loop uses upperdx for both thresholds.  This relies on the
    # thresholds being symmetric about the zero-line, so that lowerdx
    # matches upperdx; lowerdx itself is only range-checked here.
    assert(np.max(upperdx) <= dx)
    assert(np.max(lowerdx) <= dx)

    assert(np.min(upperdx) >= 0)
    assert(np.min(lowerdx) >= 0)

    # Initial distribution.
    # NOTE(review): the Beta profile spans all nX+1 cells of the padded grid,
    # not only the interval between the thresholds.  Cells outside the
    # thresholds are only cleared at the end of the first time step.  Confirm
    # this is intended.
    x = np.zeros( (nX+1) )    # solution vector
    alpha_switch = model.alpha if params is None else params['alpha']
    if alpha_switch>0:
        xx = np.linspace(0,1.,nX+1) # initial value Beta-Distribution
        x = xx**(model.alpha-1)*(1.0-xx)**(model.alpha-1)
        x=x/np.sum(x)               # normalize
    else:
        x[nX//2] = 1

    pdf_u = np.zeros(nT+1)
    pdf_l = np.zeros(nT+1)

    diffusion = model.sigma*model.sigma/2.0

    for n in range(nT):       # time loop
        t = timesteps[n+1]

        transport = model.mu(t-dt,params)   # drift frozen at the start of the step

        pu = diffusion*dt/dx/dx + transport*dt/2./dx
        pl = diffusion*dt/dx/dx - transport*dt/2./dx
        pm = 1.0 - pu - pl

        diff = upperdx[n]/dx # fraction of the boundary cell inside the domain
        assert(diff>=0)
        assert(diff<=1)

        if upper[n+1] == upper[n]:
            # Threshold stays in its cell.  Mass leaving the boundary cell
            # outwards is absorbed completely.  Of the mass that ends up in the
            # boundary cell, the fraction (1-diff) is absorbed and the
            # fraction diff is kept.
            pdf_u[n+1] = pu/dt * x[upper[n]] + (1.0-diff) * (pu/dt * x[upper[n]-1] + pm/dt * x[upper[n]])
            pdf_l[n+1] = pl/dt * x[lower[n]] + (1.0-diff) * (pl/dt * x[lower[n]+1] + pm/dt * x[lower[n]])

            # compute what will remain in the boundary layer cells
            up = diff * pm * x[upper[n]] + pu * x[upper[n]-1] * diff
            lo = diff * pm * x[lower[n]] + pl * x[lower[n]+1] * diff

            # transport in the domain including the boundary cells
            x[lower[n]:upper[n]+1] = pm * x[lower[n]:upper[n]+1] + pu * x[lower[n]-1:upper[n]] + pl * x[lower[n]+1:upper[n]+2]

            # set boundary cells
            x[upper[n+1]] = up
            x[lower[n+1]] = lo
        else:
            # Threshold moves one cell inwards (collapsing bounds).
            assert(upper[n+1] == upper[n] - 1 )
            assert(lower[n+1] == lower[n] + 1 )

            nextdiff = upperdx[n+1]/dx # fraction of the new boundary cell inside the domain
            assert(nextdiff>=0)
            assert(nextdiff<=1)

            # Everything that stays in or moves beyond the old boundary cell
            # is absorbed.  Of the mass landing in the new boundary cell, the
            # fraction (1-nextdiff) is absorbed.
            pdf_u[n+1] = (pm+pu)/dt * x[upper[n]] + pu/dt * x[upper[n]-1] + (1.0-nextdiff) * (pl/dt * x[upper[n]] + pm/dt*x[upper[n]-1] + pu/dt * x[upper[n]-2])
            pdf_l[n+1] = (pm+pl)/dt * x[lower[n]] + pl/dt * x[lower[n]+1] + (1.0-nextdiff) * (pu/dt * x[lower[n]] + pm/dt*x[lower[n]+1] + pl/dt * x[lower[n]+2])

            # set new upper / lower element
            up = nextdiff * (pm * x[upper[n]-1] + pu * x[upper[n]-2] + pl * x[upper[n]])
            lo = nextdiff * (pm * x[lower[n]+1] + pl * x[lower[n]+2] + pu * x[lower[n]])

            x[lower[n]:upper[n]+1] = pm * x[lower[n]:upper[n]+1] + pu * x[lower[n]-1:upper[n]] + pl * x[lower[n]+1:upper[n]+2]

            x[upper[n+1]] = up
            x[lower[n+1]] = lo

        # set all outside to zero
        x[:lower[n+1]]   = 0
        x[upper[n+1]+1:] = 0
        x[:lower[n]]   = 0
        x[upper[n]+1:] = 0

    # We ignore all trials that do not reach the upper or lower
    # margin and rescale the probabilities to sum up to one
    scale = dt * (np.sum(pdf_u) + np.sum(pdf_l))
    return [pdf_u/scale,pdf_l/scale]
