import numpy as np


def problem_definitions(problem='2x2'):
    """Return coefficients and simulation parameters of a predefined problem.

    Parameters
    ----------
    problem : str, optional
        Name of the predefined problem: ``'2x2'`` (default) or ``'10x10'``.

    Returns
    -------
    dict
        ``'fcffs'``, ``'gcffs'``, ``'pastvals'``
            the corresponding lists, returned as plain lists;
        ``'A'``, ``'B'``, ``'C'``, ``'D'``, ``'K'``, ``'L'``
            the coefficient matrices, converted with ``np.array``;
        ``'R'``, ``'hZ'``, ``'ots'``
            dicts of closed-loop simulation parameters with keys
            ``'Omega_one'``, ``'tau_one'``, ``'Omega_two'``, ``'tau_two'``
            (the Omegas converted with ``np.array``).  ``'hZ'`` additionally
            holds the time step ``'ddt'``; ``'ots'`` additionally holds
            ``'intlngth'`` and ``'intdstnc'``.

    Raises
    ------
    ValueError
        If ``problem`` is not one of the known problem names.
    """
    if problem == '2x2':
        a = [[.5, .4], [.4, .3]]
        b = [[.4, .1], [.2, .3]]
        c = [[.7, .0], [.0, .4]]
        d = [[.5, .0], [.0, .3]]
        k = [[.3, .2], [.2, .4]]
        ell = [[.4, .2], [.1, .3]]
        fcffs = [.4, .5]
        gcffs = [.5, .6]
        pastvals = [1., .5, .5, 1., -1., -2., -2., -1.]

        # # Parameters for the Closed Loop Simulation
        # ## Continuous Case
        comgone = [[2., 1.], [0., 2.]]
        comgtwo = [[1.5, 0.], [1., 2.]]
        ctauone = 1.2
        ctautwo = 1.2
        csimpars = [comgone, comgtwo, ctauone, ctautwo]
        # ## Z-Case
        dtauone = 1.
        dtautwo = 1.
        domgone = [[.29, 0.], [.05, .6]]
        domgtwo = [[.5, .03], [0., .7]]
        dsimpars = [domgone, domgtwo, dtauone, dtautwo]
        ddt = 1.
        # ## Other TS
        intlngth = .4
        intdstnc = .8
        atauone = .8
        atautwo = .8
        aomgone = [[1.8, 1.], [.0, 1.5]]
        aomgtwo = [[1.4, 0.], [1., 1.7]]
        asimpars = [aomgone, aomgtwo, atauone, atautwo]

    elif problem == '10x10':
        a = [[1.0, 0.5, 0.1, 0.2, 0.5, 0.1, 0.2, 0.4, 0.7, 0.1],
             [0.1, 0.3, 1.0, 0.4, 0.3, 0.2, 0.3, 0.7, 0.4, 0.5],
             [0.5, 0.9, 0.4, 0.8, 0.7, 1.0, 0.4, 0.3, 0.2, 0.3],
             [0.7, 1.0, 0.4, 0.3, 0.0, 0.3, 1.0, 0.4, 0.3, 0.2],
             [0.5, 0.1, 0.7, 0.0, 0.0, 1.0, 0.4, 0.3, 0.0, 0.3],
             [1.0, 0.4, 0.3, 0.2, 0.3, 0.1, 0.3, 1.0, 0.4, 0.3],
             [0.3, 0.2, 0.7, 1.0, 0.4, 0.3, 0.0, 0.3, 0.3, 0.0],
             [0.7, 0.0, 0.1, 1.0, 0.4, 0.8, 0.7, 1.0, 0.4, 0.3],
             [0.9, 0.4, 0.8, 0.7, 1.0, 0.4, 1.0, 0.5, 0.1, 0.2],
             [0.3, 0.0, 0.3, 0.3, 0.5, 0.1, 0.7, 0.0, 0.0, 1.0]]
        b = [[1.0, 0.3, 0.2, 0.3, 0.5, 0.3, 0.7, 1.0, 0.0, 0.5],
             [0.0, 1.0, 0.1, 0.6, 0.2, 0.3, 0.2, 0.3, 0.5, 0.3],
             [0.3, 0.7, 1.0, 0.2, 0.3, 0.2, 0.3, 0.7, 0.4, 0.5],
             [0.4, 0.0, 0.5, 1.0, 0.7, 0.3, 0.5, 0.3, 0.7, 1.0],
             [0.6, 1.0, 0.3, 0.0, 1.0, 0.0, 0.5, 1.0, 0.7, 0.3],
             [0.5, 0.9, 0.4, 0.8, 0.7, 1.0, 0.4, 0.3, 0.2, 0.3],
             [0.7, 1.0, 0.4, 0.3, 0.0, 0.3, 1.0, 0.4, 0.3, 0.2],
             [0.5, 0.1, 0.7, 0.0, 0.0, 1.0, 0.4, 0.3, 0.0, 0.3],
             [1.0, 0.4, 0.3, 0.2, 0.3, 0.1, 0.3, 1.0, 0.4, 0.3],
             [0.3, 0.2, 0.7, 1.0, 0.4, 0.3, 0.0, 0.3, 0.3, 0.]]
        c = np.diag([0.3, 1.0, 0.7, 0.4, 1.3, 1.2, 0.6, 0.4, 1.1, 2.0])
        d = np.diag([0.2, 0.4, 0.3, 0.4, 0.7, 0.6, 0.7, 1.0, 1.7, 0.6])
        k = [[0.4, 0.3, 0.2, 0.3, 0.1, 0.3, 0.7, 0.3, 0.0, 0.1],
             [0.0, 1.0, 0.1, 0.6, 0.2, 0.1, 0.3, 0.3, 0.2, 0.0],
             [0.3, 0.7, 1.0, 0.2, 0.3, 0.0, 1.0, 0.1, 0.3, 0.2],
             [0.4, 0.0, 0.5, 1.0, 0.7, 0.3, 0.2, 0.0, 1.0, 0.1],
             [0.1, 0.3, 0.3, 0.2, 1.0, 0.3, 0.7, 1.0, 0.2, 0.3],
             [0.3, 0.7, 1.0, 0.2, 0.3, 0.2, 1.0, 0.3, 0.7, 1.0],
             [0.4, 1.0, 0.5, 1.0, 0.7, 0.4, 0.0, 0.5, 1.0, 0.7],
             [1.0, 0.3, 0.2, 0.3, 0.5, 0.3, 0.7, 1.0, 0.0, 0.5],
             [0.0, 1.0, 0.1, 0.6, 0.2, 0.3, 0.2, 0.3, 0.5, 0.3],
             [0.3, 0.7, 1.0, 0.2, 0.3, 0.2, 0.3, 0.7, 0.4, 0.5]]
        ell = [[1.0, 0.4, 0.3, 0.2, 0.0, 0.4, 0.0, 0.5, 1.0, 0.7],
               [0.0, 1.0, 0.1, 0.1, 0.2, 0.3, 0.1, 0.3, 0.7, 0.3],
               [0.3, 0.7, 1.0, 0.2, 0.3, 0.2, 1.0, 0.3, 0.7, 1.0],
               [0.4, 1.0, 0.5, 1.0, 0.7, 0.4, 0.0, 0.5, 1.0, 0.7],
               [1.0, 0.4, 0.3, 0.1, 1.0, 0.3, 0.3, 0.2, 1.0, 0.3],
               [0.5, 0.9, 0.4, 0.8, 0.7, 1.0, 0.4, 0.3, 0.2, 0.3],
               [0.7, 1.0, 0.4, 0.3, 0.0, 0.3, 1.0, 0.4, 0.3, 0.2],
               [0.5, 0.1, 0.7, 0.0, 0.0, 1.0, 0.4, 0.3, 0.0, 0.3],
               [1.0, 0.4, 0.3, 0.2, 0.3, 0.1, 0.3, 1.0, 0.4, 0.3],
               [0.3, 0.2, 0.7, 1.0, 0.4, 0.3, 0.0, 0.3, 0.3, 0.0]]

        fcffs = [0.25, 0.1, 0.15, 0.2, 0.22, 0.21, 0.2, 0.25, 0.2, 0.2]
        gcffs = [0.30, 0.15, 0.25, 0.2, 0.23, 0.28, 0.30, 0.25, 0.25, 0.2]
        pastvals = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0,
                    2.0, 1.9, 1.8, 1.7, 1.6, 1.5, 1.4, 1.3, 1.2, 1.1,
                    -0.1, 0.2, -0.3, -0.4, 0.5, -0.6, 0.7, 0.8, 0.9, 1.0,
                    -2.0, -1.9, -1.8, -1.7, -1.6, -1.5, -1.4, -1.3, -1.2, -1.1]

        # ## R-Case
        comgone = np.diag([3.73, 3.03, 3.33, 3.63, 2.73,
                           2.83, 3.43, 3.63, 2.93, 2.03])
        comgtwo = np.diag([3.53, 3.33, 3.43, 3.33, 3.03,
                           3.13, 3.03, 2.73, 2.03, 3.13])
        ctauone = .8
        ctautwo = .8
        csimpars = [comgone, comgtwo, ctauone, ctautwo]
        # ## Z-Case
        domgone = np.diag([3.73, 3.03, 3.33, 3.63, 2.73,
                           2.83, 3.43, 3.63, 2.93, 2.03])
        domgtwo = np.diag([3.53, 3.33, 3.43, 3.33, 3.03,
                           3.13, 3.03, 2.73, 2.03, 3.13])
        dtauone = .5
        dtautwo = .5
        ddt = .25  # time step
        dsimpars = [domgone, domgtwo, dtauone, dtautwo]
        # ## [k, k+0.8] case
        intlngth = .8
        intdstnc = 1.
        atauone = 1.
        atautwo = 1.
        aomgone = np.diag([4.2, 3.5, 3.8, 4.1, 3.2,
                           3.3, 3.9, 4.1, 3.4, 2.5])
        aomgtwo = np.diag([3.8, 3.6, 3.7, 3.6, 3.3,
                           3.4, 3.3, 3.0, 2.3, 3.4])
        asimpars = [aomgone, aomgtwo, atauone, atautwo]

    else:
        # Previously this fell through and crashed with an opaque
        # UnboundLocalError on ``a``; fail with an explicit message instead.
        raise ValueError("unknown problem {!r}; expected '2x2' or '10x10'"
                         .format(problem))

    prbdct = dict(fcffs=fcffs, gcffs=gcffs, pastvals=pastvals)

    coeffnames = ['A', 'B', 'C', 'D', 'K', 'L']
    for cnm, coeff in zip(coeffnames, [a, b, c, d, k, ell]):
        prbdct[cnm] = np.array(coeff)

    # Each simpars list is [Omega_one, Omega_two, tau_one, tau_two]; the
    # resulting dict interleaves them as Omega_one, tau_one, Omega_two,
    # tau_two (the key order produced by the original loop).
    for tsn, (omgone, omgtwo, tauone, tautwo) in zip(
            ['R', 'hZ', 'ots'], [csimpars, dsimpars, asimpars]):
        prbdct[tsn] = {'Omega_one': np.array(omgone), 'tau_one': tauone,
                       'Omega_two': np.array(omgtwo), 'tau_two': tautwo}

    # Case-specific extras, added once after all three dicts exist.
    prbdct['hZ']['ddt'] = ddt
    prbdct['ots'].update(dict(intlngth=intlngth, intdstnc=intdstnc))

    return prbdct


def get_delaysol(trackdelaytrange, pastvals):
    '''Build a closure that looks up delayed solution values.

    The returned function ``delaysol(t_minus_tau, soldict)`` works as follows:

    * for a negative query time it returns ``pastvals`` (the same object,
      not a copy);
    * otherwise it pops the next entry from the front of
      ``trackdelaytrange`` (mutating the caller's list) and uses that exact
      stored time as the key into ``soldict``.

    A direct ``soldict[t_minus_tau]`` query may fail because the floating
    point times need not match bit-for-bit.  The popped time is therefore
    only accepted if ``np.allclose`` deems it equal to the query;
    otherwise ``UserWarning`` is raised.
    '''
    def delaysol(t_minus_tau, soldict):
        if t_minus_tau < 0:
            return pastvals
        # The next tracked time is consumed even if the check below fails.
        stored_t = trackdelaytrange.pop(0)
        if not np.allclose(stored_t, t_minus_tau):
            raise UserWarning('somethings wrong with the delay')
        return soldict[stored_t]

    return delaysol
