from data_plot_utils import plotit
from problem_definitions import problem_definitions, get_delaysol

import numpy as np

from jitcdde import jitcdde, y, t
from symengine import tanh

# ## switches for the three time domains
# Make a switch False, if you don't need the plot for this time domain
continuotime = True
discretetime = True
mixedtime = True


# ## define the problem coefficients
problem = '10x10'
Nsteps = 20  # number of time steps in the discrete time simulation
cfiltert = 5  # passed to `plotit` as `filterdata` in the continuous case

pbdct = problem_definitions(problem=problem)
# coefficient matrices of the right hand side (see `get_bamrhs`)
A, B, C, D, K, L = (pbdct[key] for key in ('A', 'B', 'C', 'D', 'K', 'L'))
# scalar factors in front of the activation functions
fcffs, gcffs = pbdct['fcffs'], pbdct['gcffs']
# initial/past values of all states
pastvals = pbdct['pastvals']

# ## problem rhs
n = A.shape[0]  # size of one group of states
nall, xidx = 4*n, range(n)  # overall number of states, index range of a group


def get_bamrhs(omgone=None, omgtwo=None,
               taui=None, tauii=None, coupling=False):
    """Set up the generator of the right hand side of the `4*n` state system.

    The states `0, ..., 2*n-1` form one copy of the two-layer system and the
    states `2*n, ..., 4*n-1` a second copy, which is optionally coupled to
    the first one.

    Parameters
    ----------
    omgone, omgtwo
        coupling gains, indexed as `[i, j]`; only accessed if `coupling`
    taui
        delay in the terms weighted by `B`
    tauii
        delay in the terms weighted by `L`
    coupling : bool
        whether to add the coupling terms to the equations of the second copy

    Returns
    -------
    bamrhs
        generator function that yields the `4*n` right hand side expressions
    """
    def bamrhs(myy=None):
        """Yield the right hand side, one expression per state.

        Without `myy`, the symbolic jitcdde state `y` and the symengine
        `tanh` are used. Otherwise `myy(idx)` has to return the current and
        `myy(idx, t-tau)` the delayed value of state `idx`, and `np.tanh`
        is used as the activation.
        """
        if myy is None:
            state, activation = y, tanh
        else:
            state, activation = myy, np.tanh
        ffncs = [activation]*n
        gfncs = [activation]*n

        def first_layer_row(row, off):
            # equation `row` of the `C, A, B` layer; `off` selects the copy
            return sum(- C[row, j]*state(j+off)
                       + A[row, j]*fcffs[j]*ffncs[j](state(j+off+n))
                       + B[row, j]*fcffs[j]*ffncs[j](state(j+off+n, t-taui))
                       for j in xidx)

        def second_layer_row(row, off):
            # equation `row` of the `D, K, L` layer; `off` selects the copy
            return sum(- D[row, j]*state(j+off+n)
                       + K[row, j]*gcffs[j]*gfncs[j](state(j+off))
                       + L[row, j]*gcffs[j]*gfncs[j](state(j+off, t-tauii))
                       for j in xidx)

        # first copy: states 0, ..., 2*n-1
        for row in range(n):
            yield first_layer_row(row, 0)
        for row in range(n):
            yield second_layer_row(row, 0)

        # second copy: states 2*n, ..., 4*n-1 -- same dynamics plus,
        # if requested, the coupling to the corresponding first copy states
        for row in range(n):
            uncoupled = first_layer_row(row, 2*n)
            if not coupling:
                yield uncoupled
                continue
            forcing = sum(- omgone[row, j]*(state(j+2*n)-state(j))
                          for j in xidx)
            yield uncoupled + forcing
        for row in range(n):
            uncoupled = second_layer_row(row, 2*n)
            if not coupling:
                yield uncoupled
                continue
            forcing = sum(- omgtwo[row, j]*(state(j+3*n)-state(j+n))
                          for j in xidx)
            yield uncoupled + forcing
    return bamrhs


if continuotime:
    # ## continuous time -- integrate the delay system with jitcdde
    stoptime = 25.0
    numpoints = 250
    ctprms = pbdct['R']
    Omega_one = ctprms['Omega_one']
    Omega_two = ctprms['Omega_two']
    tau_one = ctprms['tau_one']
    # read the dedicated 'tau_two' entry (as in the discrete time section)
    # instead of reusing 'tau_one'; fall back to `tau_one` for problem
    # definitions that do not provide a separate second delay
    tau_two = ctprms['tau_two'] if 'tau_two' in ctprms else tau_one

    # one figure per run: first the uncoupled, then the coupled system
    fignums = [100, 200]
    for fignum, coupling in zip(fignums, [False, True]):
        bamrhs = get_bamrhs(omgone=Omega_one, omgtwo=Omega_two,
                            coupling=coupling, tauii=tau_two, taui=tau_one)
        DDE = jitcdde(bamrhs)
        DDE.constant_past(pastvals)
        # sampling times, relative to the integrator's current time
        times = DDE.t + np.linspace(0, stoptime, numpoints)

        # NOTE: initial discontinuities are not treated explicitly; a short
        # pre-integration via `DDE.step_on_discontinuities()` would go here

        # integrating
        data = [DDE.integrate(time) for time in times]
        plotit(data, timerange=times, coupling=coupling, pname=problem,
               filterdata=cfiltert, fignum=fignum, nstates=n)


if discretetime:
    # ## discrete time -- explicit time stepping with the fixed step `ddt`
    dtprms = pbdct['hZ']
    Omega_one = dtprms['Omega_one']
    Omega_two = dtprms['Omega_two']
    tau_one = dtprms['tau_one']
    tau_two = dtprms['tau_two']
    ddt = dtprms['ddt']

    def get_dsol(dsoldct=None):
        """Wrap the solution dict into a state function for `bamrhs`.

        The returned function reads `dsoldct` at call time, so that updates
        of the dict are seen without creating a new wrapper.

        Parameters
        ----------
        dsoldct : dict
            with the current solution `'csol'`, the delayed solution
            `'dsol'`, and the delay `'tau'` that `'dsol'` belongs to

        Returns
        -------
        function
            `(idx)` gives the current and `(idx, t-tau)` the delayed value
            of the state `idx`; raises `UserWarning` for any other delay
        """
        def statefun(idx, time=None):
            if time is None:
                return dsoldct['csol'][idx]
            # `time` is an expression like `t - tau`; the coefficient stored
            # under the key `1` is its constant part, i.e. minus the delay
            tdct = time.as_coefficients_dict()
            if tdct[1] == -dsoldct['tau']:
                return dsoldct['dsol'][idx]
            # only the one delay that is tracked in `dsoldct` is available
            raise UserWarning('Wrong delay value')
        return statefun

    # the time grid is the same for the uncoupled and the coupled run
    timerange = [k*ddt for k in range(Nsteps+1)]

    fignum1 = [1000, 2000]
    for fignum, coupling in zip(fignum1, [False, True]):
        bamrhs = get_bamrhs(omgone=Omega_one, omgtwo=Omega_two,
                            coupling=coupling, taui=tau_one, tauii=tau_two)

        csol = pastvals.copy()  # solution at current time instance
        dsol = pastvals.copy()  # solution at delay time instance
        dsoldct = dict(t=0., tau=tau_one, csol=csol, dsol=dsol)
        sol = get_dsol(dsoldct)
        rhsinc = list(bamrhs(myy=sol))

        # fresh copy per run, in case `get_delaysol` keeps or alters the list
        trackdelaytrange = timerange.copy()
        delaysol = get_delaysol(trackdelaytrange, pastvals)

        datadict = {0.: csol}
        for ctime in timerange[1:]:
            # delayed solution for the new time instance, taken from the
            # values computed so far
            dsol = delaysol(ctime-tau_one, datadict)
            # step forward with the rhs evaluated at the previous instance
            csol = [csol[k] + ddt*rhsinc[k] for k in range(nall)]
            dsoldct.update(dsol=dsol, csol=csol, t=ctime)
            # `sol` reads the updated `dsoldct`: rhs at the new time instance
            rhsinc = list(bamrhs(myy=sol))
            datadict[ctime] = csol

        data = [datadict[ctime] for ctime in timerange]
        plotit(data, timerange=np.array(timerange),
               fignum=fignum, pname=problem,
               timescale='discrete', coupling=coupling, nstates=n)

if mixedtime:
    # ## mixed time domain -- a sequence of intervals, each resolved by a
    # ## local time grid, with jumps in between
    mtprms = pbdct['ots']
    Omega_one = mtprms['Omega_one']
    Omega_two = mtprms['Omega_two']
    tau_one = mtprms['tau_one']
    # NOTE(review): the second delay is read from 'tau_one', too. The time
    # stepping below only tracks the solution delayed by `tau_one` (any other
    # delay raises in `get_dsol`), so this is left as is -- confirm that a
    # single delay is intended here.
    tau_two = mtprms['tau_one']
    intlngth = mtprms['intlngth']  # length of one interval
    intdstnc = mtprms['intdstnc']  # offset between the starts of intervals

    intvals = 10  # number of intervals
    locnts = 20  # number of time steps within one interval
    flocnts = 8  # nts in the reduced timerange for plotting
    # stride of the reduced grid; builtin integer division replaces
    # `np.int(np.floor(...))` since `np.int` was removed from NumPy
    flctds = locnts // flocnts
    lctrng = np.linspace(0., intlngth, locnts+1)
    fltrt = np.arange(0, locnts+1, flctds)
    fltrt[-1] = locnts  # hard set the last val to the last val
    flctrng = lctrng[fltrt]  # reduced time for plotting
    timerange = []  # all time instances of the simulation
    trl = []  # per interval: the time instances used for plotting
    for kint in range(intvals):
        trl.append(kint*intdstnc+flctrng)
        timerange.extend((kint*intdstnc+lctrng).tolist())

    def get_dsol(dsoldct=None):
        """Wrap the solution dict into a state function for `bamrhs`.

        The returned function reads `dsoldct` at call time, so that updates
        of the dict are seen without creating a new wrapper.

        Parameters
        ----------
        dsoldct : dict
            with the current solution `'csol'`, the delayed solution
            `'dsol'`, and the delay `'tau'` that `'dsol'` belongs to

        Returns
        -------
        function
            `(idx)` gives the current and `(idx, t-tau)` the delayed value
            of the state `idx`; raises `UserWarning` for any other delay
        """
        def statefun(idx, time=None):
            if time is None:
                return dsoldct['csol'][idx]
            # `time` is an expression like `t - tau`; the coefficient stored
            # under the key `1` is its constant part, i.e. minus the delay
            tdct = time.as_coefficients_dict()
            if tdct[1] == -dsoldct['tau']:
                return dsoldct['dsol'][idx]
            # only the one delay that is tracked in `dsoldct` is available
            raise UserWarning('Wrong delay value')
        return statefun

    fignum1 = [3000, 4000]
    for fignum, coupling in zip(fignum1, [False, True]):
        # fresh copy per run, in case `get_delaysol` keeps or alters the list
        trackdelaytrange = timerange.copy()
        delaysol = get_delaysol(trackdelaytrange, pastvals)

        bamrhs = get_bamrhs(omgone=Omega_one, omgtwo=Omega_two,
                            coupling=coupling, taui=tau_one, tauii=tau_two)

        csol = pastvals.copy()  # solution at current time instance
        ctime = 0.
        dsol = delaysol(ctime-tau_one, None)  # solution at delay time instance
        dsoldct = dict(t=ctime, tau=tau_one, csol=csol, dsol=dsol)
        sol = get_dsol(dsoldct)
        rhsinc = list(bamrhs(myy=sol))
        datadict = {0.: csol}

        for kkk, ctime in enumerate(timerange[1:]):
            # delayed solution for the new time instance, taken from the
            # values computed so far
            dsol = delaysol(ctime-tau_one, datadict)
            # current step size: `timerange[kkk]` is the previous instance
            cdt = ctime - timerange[kkk]
            # step forward with the rhs evaluated at the previous instance
            csol = [csol[k] + cdt*rhsinc[k] for k in range(nall)]
            dsoldct.update(dsol=dsol, csol=csol, t=ctime)
            # `sol` reads the updated `dsoldct`: rhs at the new time instance
            rhsinc = list(bamrhs(myy=sol))
            datadict[ctime] = csol

        # collect the solution on the reduced grid, one array per interval
        data = [np.array([datadict[ctime] for ctime in ctr]) for ctr in trl]
        plotit(data, timerange=trl, fignum=fignum,
               pname=problem, filterdata=None,
               timescale='timescale', coupling=coupling, nstates=n)
