#!/usr/bin/env python3
# -*- coding: utf-8 -*-


# Scripts to reproduce the results from
#
# T. Richter, R. Ulrich, M. Janczyk:
#    "Diffusion models with time-dependent parameters:
#     Comparing the computation effort and accuracy
#     of different numerical methods"
#
# Thomas Richter
# Otto-von-Guericke University of Magdeburg
# 39106 Magdeburg, Germany
# thomas.richter@ovgu.de
#
# You can use this code under the terms of the
# Creative Commons Attribution 4.0 License

import numpy as np


def f(model,params,ai,t,aj,ta, muINT_t, muINT_ta, SQRT):
    """Gaussian factor of the drifted process between (aj, ta) and (ai, t).

    Returns ``exp(-d**2 / (2*sigma**2*(t - ta))) / SQRT`` where the
    drift-corrected displacement is ``d = ai - aj - (muINT_t - muINT_ta)``.
    ``SQRT`` is the normalization supplied by the caller (the solvers in this
    file pass ``sqrt(2*pi*sigma**2*t)`` together with ``ta = 0``).
    ``params`` is accepted for a uniform signature but is not used.
    Works element-wise when ``aj`` (or any other argument) is an array.
    """
    drift_shift = muINT_t - muINT_ta
    displacement = ai - aj - drift_shift
    spread = 2.0 * model.sigma * model.sigma * (t - ta)
    exponent = -displacement**2.0 / spread
    normalizer = 1. / SQRT
    return normalizer * np.exp(exponent)
def psi(model,params, ai,t, aj, ta, dai, mut, muINT_t, muINT_ta, SQRT):
    """Kernel term psi between (aj, ta) and the threshold point (ai, t).

    Computes ``f(...) / 2 * (dai - mut - d / (t - ta))`` with the
    drift-corrected displacement ``d = ai - aj - (muINT_t - muINT_ta)``.
    ``dai`` is the threshold slope and ``mut`` the drift rate at time ``t``,
    as passed in by the solvers. Broadcasts element-wise over array inputs.
    """
    density = f(model, params, ai, t, aj, ta, muINT_t, muINT_ta, SQRT)
    slope = (ai - aj - (muINT_t - muINT_ta)) / (t - ta)
    return density / 2.0 * (dai - mut - slope)
        

def ff(TT,B1,B2,dt_B1,muINT,SQRT,SIGMA,dak,muk,k,jj):
    """Vectorised kernel psi from the nodes ``jj`` to node ``k``.

    For each earlier index ``j`` in ``jj`` this evaluates, with
    ``gap = B1[k] - B2[j] - (muINT[k] - muINT[j])`` and
    ``elapsed = TT[k] - TT[j]``::

        exp(-gap**2 / (SIGMA*elapsed)) / SQRT[k-j] / 2 * (dak - muk - gap/elapsed)

    ``SIGMA`` is the factor ``2*sigma**2`` supplied by the caller.
    ``SQRT[k-j]`` is used as the normalizer for the time lag, which matches
    ``elapsed`` only on a uniform grid. ``dt_B1`` is accepted but unused.
    """
    elapsed = TT[k] - TT[jj]
    gap = B1[k] - B2[jj] - (muINT[k] - muINT[jj])
    gauss = 1. / SQRT[k - jj] * np.exp(-gap**2.0 / (SIGMA * elapsed))
    return gauss / 2.0 * (dak - muk - gap / elapsed)



def im(model,params,disc):
    """Integral-equation solver with time-dependent drift and thresholds.

    On the time grid ``TT = np.linspace(0, T, nT+1)`` this solves, node by
    node, a pair of coupled integral equations for ``g1`` (threshold
    ``+b(t)``) and ``g2`` (threshold ``-b(t)``). These are presumably the
    first-passage-time densities at the upper and lower threshold. They are
    solved for every starting point of ``Z = np.linspace(-b(0), b(0), nX+1)``
    and then averaged over the starting points. The integral over earlier
    times uses the nodes ``1 .. k-1`` with weight ``2*dt``.

    Parameters
    ----------
    model : object
        Provides the attribute ``sigma`` and the callables ``b``, ``dt_b``,
        ``mu`` and ``muINT``, each called as ``(t, params)``. ``b``, ``mu``
        and ``muINT`` must return arrays when ``t`` is an array.
    params : dict
        Forwarded to the model callables. ``params['alpha']`` selects the
        starting-point distribution (see Returns).
    disc : dict
        ``'T'`` (final time, > 0), ``'dt'`` (time step, > 0) and ``'nX'``
        (number of intervals of the starting-point grid).

    Returns
    -------
    tuple of numpy.ndarray
        ``(g1 @ x, g2 @ x)``, each of length ``nT+1`` with
        ``nT = int(T/dt + 1e-10)``. Entry 0 is zero. ``x`` holds the
        starting-point weights:
        - ``alpha > 0``: normalized ``xx**(alpha-1) * (1-xx)**(alpha-1)`` on
          ``xx = linspace(0, 1, nX+1)``, i.e. a symmetric Beta shape.
        - otherwise: a point mass at ``Z[nX//2]``.

    Raises
    ------
    AssertionError
        If ``T``, ``b(0)`` or ``dt`` is not positive.
    """
    T = disc['T']
    assert T>0,    'max time must be postive'

    b = model.b(0,params)
    assert b>0,    'threshold must be positive'

    dt = disc['dt']
    assert dt>0,   'time steps must be postive'

    nX = disc['nX']
    Z = np.linspace(-b,b,nX+1)    # grid of starting points

    nT = int(T/dt+1.e-10)         # offset guards against T/dt rounding down
    g1 = np.zeros( (nT+1,nX+1) )  # row k: value at +b(TT[k]) for each start Z[j]
    g2 = np.zeros( (nT+1,nX+1) )  # row k: value at -b(TT[k]) for each start Z[j]

    # NOTE(review): TT has spacing T/nT while the quadrature below weights
    # with dt; the two agree only when T is a multiple of dt.
    TT = np.linspace(0,T,nT+1)

    # precompute drift rate, integral of drift rate and thresholds for efficiency
    MU    = model.mu(TT, params)
    MUINT = model.muINT(TT, params)
    SQRT  = np.sqrt(2.0*np.pi*model.sigma*model.sigma*TT)
    B     = model.b(TT, params)
    # broadcast so that dtB[k] also works if dt_b returns a scalar
    dtB   = np.broadcast_to(model.dt_b(TT, params), TT.shape)
    SIGMA = 2.0*model.sigma*model.sigma
    mB    = -B                    # lower threshold
    mdtB  = -dtB                  # its time derivative

    # k = nT is included so that the last grid point is computed as well
    for k in range(1,nT+1):
        jj = np.arange(1,k)       # earlier nodes entering the quadrature

        # Source terms. They use the precomputed threshold values at the same
        # node TT[k] as MU, MUINT, SQRT and the kernels below.
        g1[k,:] = - 2. * psi(model,params, B[k],TT[k],Z,0, dtB[k], MU[k], MUINT[k], MUINT[0], SQRT[k])
        g2[k,:] = + 2. * psi(model,params,mB[k],TT[k],Z,0,mdtB[k], MU[k], MUINT[k], MUINT[0], SQRT[k])

        # kernels from the thresholds at nodes jj to the thresholds at node k
        F11 = ff(TT, B, B, dtB,MUINT, SQRT, SIGMA, dtB[k], MU[k], k, jj)
        F12 = ff(TT, B,mB, dtB,MUINT, SQRT, SIGMA, dtB[k], MU[k], k, jj)
        F21 = ff(TT,mB, B,mdtB,MUINT, SQRT, SIGMA,mdtB[k], MU[k], k, jj)
        F22 = ff(TT,mB,mB,mdtB,MUINT, SQRT, SIGMA,mdtB[k], MU[k], k, jj)

        G1 = g1[1:k,:].T
        G2 = g2[1:k,:].T
        g1[k,:] += 2.*dt*(np.dot(G1,F11)+np.dot(G2,F12))
        g2[k,:] -= 2.*dt*(np.dot(G1,F21)+np.dot(G2,F22))

    x = np.zeros(nX+1)
    if params['alpha']>0:
        # initial value: Beta-distribution shape on the starting-point grid
        # NOTE(review): for 0 < alpha < 1 the end points evaluate to inf, and
        # the normalized weights become NaN.
        xx = np.linspace(0,1.,nX+1)
        x = xx**(params['alpha']-1)*(1.0-xx)**(params['alpha']-1)
        x=x/np.sum(x)               # normalize
    else:
        # point mass at Z[nX//2]; this is the starting point 0 only for even nX
        x[nX//2] = 1

    return np.dot(g1,x),np.dot(g2,x)


## Fixed starting point x=0 (no distribution of starting points)
def imzero(model,params,disc):
    """Integral-equation solver for a process started exactly at x = 0.

    Same scheme as the general solver, but with the single starting point
    0. Time-dependent drift and thresholds ``+b(t)`` / ``-b(t)`` are
    evaluated on ``TT = np.linspace(0, T, nT+1)``, and the integral over
    earlier times uses the nodes ``1 .. k-1`` with weight ``2*dt``.

    Parameters
    ----------
    model : object
        Provides ``sigma`` and the callables ``b``, ``dt_b``, ``mu`` and
        ``muINT``, each called as ``(t, params)``. ``b``, ``mu`` and
        ``muINT`` must return arrays when ``t`` is an array.
    params : dict
        Forwarded to the model callables.
    disc : dict
        ``'T'`` (final time, > 0) and ``'dt'`` (time step, > 0).

    Returns
    -------
    tuple of numpy.ndarray
        ``(g1, g2)``, each of length ``nT+1`` with
        ``nT = int(T/dt + 1e-10)``. These are the values at the upper (+b)
        and lower (-b) threshold, presumably the first-passage-time
        densities. Entry 0 is zero.

    Raises
    ------
    AssertionError
        If ``T``, ``b(0)`` or ``dt`` is not positive.
    """
    T = disc['T']
    assert T>0,    'max time must be postive'

    b = model.b(0,params)
    assert b>0,    'threshold must be positive'

    dt = disc['dt']
    assert dt>0,   'time steps must be postive'

    # Scalar starting point, so that psi returns a scalar for each node. A
    # one-element list would produce a size-1 array that then has to be
    # squeezed into g1[k].
    Z = 0.0

    nT = int(T/dt+1.e-10)         # offset guards against T/dt rounding down
    g1 = np.zeros( (nT+1) )
    g2 = np.zeros( (nT+1) )

    # NOTE(review): TT has spacing T/nT while the quadrature below weights
    # with dt; the two agree only when T is a multiple of dt.
    TT = np.linspace(0,T,nT+1)

    # precompute drift rate, integral of drift rate and thresholds for efficiency
    MU    = model.mu(TT, params)
    MUINT = model.muINT(TT, params)
    SQRT  = np.sqrt(2.0*np.pi*model.sigma*model.sigma*TT)
    B     = model.b(TT, params)
    # broadcast so that dtB[k] also works if dt_b returns a scalar
    dtB   = np.broadcast_to(model.dt_b(TT, params), TT.shape)
    SIGMA = 2.0*model.sigma*model.sigma
    mB    = -B                    # lower threshold
    mdtB  = -dtB                  # its time derivative

    # k = nT is included so that the last grid point is computed as well
    for k in range(1,nT+1):
        jj = np.arange(1,k)       # earlier nodes entering the quadrature

        # source terms at node TT[k], consistent with MU, MUINT, SQRT and the kernels
        g1[k] = - 2. * psi(model,params, B[k],TT[k],Z,0, dtB[k], MU[k], MUINT[k], MUINT[0], SQRT[k])
        g2[k] = + 2. * psi(model,params,mB[k],TT[k],Z,0,mdtB[k], MU[k], MUINT[k], MUINT[0], SQRT[k])

        # kernels from the thresholds at nodes jj to the thresholds at node k
        F11 = ff(TT, B, B, dtB,MUINT, SQRT, SIGMA, dtB[k], MU[k], k, jj)
        F12 = ff(TT, B,mB, dtB,MUINT, SQRT, SIGMA, dtB[k], MU[k], k, jj)
        F21 = ff(TT,mB, B,mdtB,MUINT, SQRT, SIGMA,mdtB[k], MU[k], k, jj)
        F22 = ff(TT,mB,mB,mdtB,MUINT, SQRT, SIGMA,mdtB[k], MU[k], k, jj)

        g1[k] += 2.*dt*(np.dot(g1[1:k],F11)+np.dot(g2[1:k],F12))
        g2[k] -= 2.*dt*(np.dot(g1[1:k],F21)+np.dot(g2[1:k],F22))

    return g1,g2



### Simplified version with constant threshold and drift.
### Not needed in general; it only provides a simpler
### implementation of test case 1.

def ffixed(model,ai,t,aj,ta, muINT_t, muINT_ta, SQRT):
    """Gaussian factor for the constant-threshold/constant-drift variant.

    Returns ``exp(-d**2 / (2*sigma**2*(t - ta))) / SQRT`` with
    ``d = ai - aj - (muINT_t - muINT_ta)``. ``SQRT`` is the caller-supplied
    normalization. Broadcasts element-wise over array inputs.
    """
    spread = 2.0 * model.sigma * model.sigma * (t - ta)
    displacement = ai - aj - (muINT_t - muINT_ta)
    return 1. / SQRT * np.exp(-displacement**2.0 / spread)
def psifixed(model, ai,t, aj, ta, mut, muINT_t, muINT_ta, SQRT):
    """Kernel psi for a constant threshold (its time derivative is zero).

    Computes ``ffixed(...) / 2 * (-mut - d / (t - ta))`` with
    ``d = ai - aj - (muINT_t - muINT_ta)``.
    """
    gauss = ffixed(model, ai, t, aj, ta, muINT_t, muINT_ta, SQRT)
    slope = (ai - aj - (muINT_t - muINT_ta)) / (t - ta)
    return gauss / 2.0 * (-mut - slope)
        

def fffixed(TT,B1,B2,muINT,SQRT,SIGMA,muk,k,jj):
    """Vectorised kernel psi from the nodes ``jj`` to node ``k`` (constant thresholds).

    ``B1`` and ``B2`` are scalar threshold positions. With
    ``gap = B1 - B2 - (muINT[k] - muINT[j])`` and
    ``elapsed = TT[k] - TT[j]``, this returns for each ``j`` in ``jj``::

        exp(-gap**2 / (SIGMA*elapsed)) / SQRT[k-j] / 2 * (-muk - gap/elapsed)

    ``SIGMA`` is the caller-supplied factor ``2*sigma**2``.
    """
    elapsed = TT[k] - TT[jj]
    gap = B1 - B2 - (muINT[k] - muINT[jj])
    gauss = 1. / SQRT[k - jj] * np.exp(-gap**2.0 / (SIGMA * elapsed))
    return gauss / 2.0 * (-muk - gap / elapsed)



def imfixed(model,disc):
    """Integral-equation solver for constant thresholds +b/-b and constant drift.

    Simplified counterpart of the time-dependent solver. It works on
    ``TT = np.linspace(0, T, nT+1)`` for every starting point
    ``Z = np.linspace(-b, b, nX+1)`` and then averages over the starting
    points.

    Parameters
    ----------
    model : object
        Provides the scalar attributes ``b`` (threshold), ``mu`` (drift
        rate), ``sigma`` and ``alpha`` (starting-point distribution).
    disc : dict
        ``'T'`` (final time, > 0), ``'dt'`` (time step, > 0) and ``'nX'``
        (number of intervals of the starting-point grid).

    Returns
    -------
    tuple of numpy.ndarray
        ``(g1 @ x, g2 @ x)``, each of length ``nT+1`` with
        ``nT = int(T/dt + 1e-10)``. These are the values at the upper (+b)
        and lower (-b) threshold, presumably the first-passage-time
        densities. Entry 0 is zero. ``x`` holds the starting-point weights:
        - ``alpha > 0``: a normalized symmetric Beta shape.
        - otherwise: a point mass at ``Z[nX//2]``.

    Raises
    ------
    AssertionError
        If ``T``, ``b`` or ``dt`` is not positive.
    """
    T = disc['T']
    assert T>0,    'max time must be postive'

    b = model.b
    assert b>0,    'threshold must be positive'

    dt = disc['dt']
    assert dt>0,   'time steps must be postive'

    nX = disc['nX']
    Z = np.linspace(-b,b,nX+1)    # grid of starting points

    nT = int(T/dt+1.e-10)         # offset guards against T/dt rounding down
    g1 = np.zeros( (nT+1,nX+1) )
    g2 = np.zeros( (nT+1,nX+1) )

    # NOTE(review): TT has spacing T/nT while the quadrature below weights
    # with dt; the two agree only when T is a multiple of dt.
    TT = np.linspace(0,T,nT+1)

    # precompute drift rate and integral of drift rate for efficiency
    MU    = model.mu
    MUINT = MU*TT
    SQRT  = np.sqrt(2.0*np.pi*model.sigma*model.sigma*TT)
    SIGMA = 2.0*model.sigma*model.sigma

    # k = nT is included so that the last grid point is computed as well
    for k in range(1,nT+1):
        jj = np.arange(1,k)       # earlier nodes entering the quadrature

        # source terms at node TT[k], consistent with SQRT[k] and the kernels
        g1[k,:] = - 2. * psifixed(model, b,TT[k],Z,0, MU, MUINT[k], 0, SQRT[k])
        g2[k,:] = + 2. * psifixed(model,-b,TT[k],Z,0, MU, MUINT[k], 0, SQRT[k])

        F11 = fffixed(TT, b, b, MUINT, SQRT, SIGMA, MU, k, jj)
        F12 = fffixed(TT, b,-b, MUINT, SQRT, SIGMA, MU, k, jj)
        F21 = fffixed(TT,-b, b, MUINT, SQRT, SIGMA, MU, k, jj)
        F22 = fffixed(TT,-b,-b, MUINT, SQRT, SIGMA, MU, k, jj)

        G1 = g1[1:k,:].T
        G2 = g2[1:k,:].T
        g1[k,:] += 2.*dt*(np.dot(G1,F11)+np.dot(G2,F12))
        g2[k,:] -= 2.*dt*(np.dot(G1,F21)+np.dot(G2,F22))

    x = np.zeros(nX+1)
    if model.alpha>0:
        # initial value: Beta-distribution shape on the starting-point grid
        # NOTE(review): for 0 < alpha < 1 the end points evaluate to inf, and
        # the normalized weights become NaN.
        xx = np.linspace(0,1.,nX+1)
        x = xx**(model.alpha-1)*(1.0-xx)**(model.alpha-1)
        x=x/np.sum(x)               # normalize
    else:
        # point mass at Z[nX//2]; this is the starting point 0 only for even nX
        x[nX//2] = 1

    return np.dot(g1,x),np.dot(g2,x)




