import logging
from logging import info as linfo

import numpy as np
import scipy.linalg as spla

from numpy.linalg import norm

from plotting_tools import plotting_SVD_decay

__all__ = ['solve_leastsquares_redsvd',
           'compressed_vector_kronecker',
           'infer_quad_correction',
           'get_domega_dq',
           ]


def get_domega_dq(omega):
    '''return a function that evaluates the Jacobian of `q -> omega @ (qxq)`

    `qxq` is the compressed Kronecker square computed by
    `compressed_vector_kronecker`. Since
    `d/dq_k (q x q) = e_k x q + q x e_k`, column `k` of the Jacobian is
    `omega @ red(q x e_k + e_k x q)`, where `red` drops the redundant entries.

    Parameters
    ----------
    omega : matrix
        coefficients of the quadratic term, acting on compressed `qxq`

    Returns
    -------
    callable
        `domega_dq(qvec)` -> Jacobian at `qvec`, one column per entry of
        `qvec` (`qvec` given as a vector of length `N` or an `(N, 1)` column)
    '''
    def domega_dq(qvec):
        qflat = np.asarray(qvec).reshape(-1)
        nq = qflat.shape[0]
        # column k is the unit vector e_k
        unitvecs = np.eye(nq)
        # every column equals q, so column k below is q x e_k + e_k x q
        qcols = np.repeat(qflat[:, None], nq, axis=1)
        dqq = (compressed_vector_kronecker(qcols, w=unitvecs)
               + compressed_vector_kronecker(unitvecs, w=qcols))
        return omega @ dqq
    return domega_dq


def solve_leastsquares_redsvd(X=None, F=None, tol_lstsq=1e-8,
                              plotplease=False, subdims=None):
    ''' solve a least square problem by a truncated SVD

    `min_A || F - AX ||^2`

    The minimal-norm solution `A = F X^+` is computed with the pseudo-inverse
    of `X` built only from the singular values `s` with
    `s/s_max > tol_lstsq`.

    Parameters
    ----------
    X : (n, k) ndarray
        data matrix
    F : (m, k) ndarray
        right-hand side
    tol_lstsq : float
        relative cutoff for the singular values of `X`
    plotplease : bool
        plot the singular value decay of `X`
    subdims : sequence of int, optional
        if given, the columns of the solution are split into consecutive
        blocks of these widths and returned as a tuple (columns beyond
        `sum(subdims)` are not returned)

    Returns
    -------
    ndarray or tuple of ndarray
        the `(m, n)` solution `A`, or its column blocks if `subdims` is given
    '''
    Ux, Sx, VxT = spla.svd(X, full_matrices=False)
    if plotplease:
        plotting_SVD_decay(Sx, 'Optinf - lstsquares singular values')
    # numerical rank; a zero (or empty) X has rank 0 -- avoids 0/0 below
    if Sx.size > 0 and Sx[0] > 0:
        rx = int(np.count_nonzero(Sx/Sx[0] > tol_lstsq))
    else:
        rx = 0

    # computing minimal norm SVD solution F V_r S_r^{-1} U_r^T;
    # scaling the columns by 1/S_r via broadcasting avoids forming diag(1/S_r)
    Ysvd = ((F@VxT[:rx, :].T) / Sx[:rx]) @ Ux[:, :rx].T
    logging.info(f'svd[{rx}]: error |Ysvd@X-F|/|F|: {norm(Ysvd@X-F)/norm(F)}')
    logging.info(f'norm of the least squares solution Ysvd: {norm(Ysvd)}')
    logging.info(f'(relative) SVD cutoff tolerance: {tol_lstsq:.2e}')

    if subdims is None:
        return Ysvd

    blocks = []
    start = 0
    for width in subdims:
        stop = start + width
        blocks.append(Ysvd[:, start:stop])
        start = stop
    return tuple(blocks)


def compressed_vector_kronecker(v, w=None, stage=2, mask=None, ret_mask=False):
    ''' compute kronecker powers and remove the redundant parts

    e.g. from `[a, b] x [a, b] = [aa, ab, ba, bb]`,
    the second occurrence of `ab=ba` is removed

    Parameters
    ----------
    v : (N,) or (N, K) ndarray
        a vector, or K vectors stored as columns
    w : ndarray, optional
        second factor, same shape as `v`; defaults to `v`
    stage : int
        power of the product; only `2` is implemented
    mask : ndarray, optional
        selector applied to the `N*N` entries of each Kronecker product;
        defaults to the upper-triangular pattern (entries `v_i*w_j`, `i<=j`)
    ret_mask : bool
        also return the mask that was used

    Returns
    -------
    redkron : ndarray
        column `k` holds the selected entries of `kron(v[:, k], w[:, k])`;
        a 1D input gives a single column. With the default mask there are
        `N*(N+1)/2` rows.
    mask : ndarray
        only if `ret_mask` is set

    Raises
    ------
    NotImplementedError
        if `stage != 2`
    RuntimeError
        if `v` and `w` have different shapes
    '''
    if stage != 2:
        raise NotImplementedError('only `qxq` covered till now')
    if w is None:
        w = v
    elif v.shape != w.shape:
        raise RuntimeError('vectors need to be the same size')

    N = v.shape[0]
    if mask is None:
        # entry i*N + j of kron(v, w) is v_i*w_j; keep only i <= j
        mask = np.triu(np.ones((N, N), dtype=bool)).reshape((N*N, ))

    # treat a single vector as one column
    vcols = v[:, None] if v.ndim == 1 else v
    wcols = w[:, None] if w.ndim == 1 else w
    ncols = vcols.shape[1]
    # all column-wise Kronecker products at once:
    # fullkron[i*N + j, k] = v[i, k]*w[j, k], i.e. np.kron(v[:, k], w[:, k])
    fullkron = np.einsum('ik,jk->ijk', vcols, wcols).reshape((N*N, ncols))
    redkron = fullkron[mask, :]

    if ret_mask:
        return redkron, mask
    return redkron


def infer_quad_correction(V=None, v=None, q=None, tol_lstsq=1e-8):
    '''fit the quadratic correction `Omega` of the reduced-to-full map

    Determines `Omega` in the least-squares sense (truncated SVD) such that

    `v = Vq + Omega qxq`

    maps the reduced coordinates `q` to the full model coordinates `v`;
    `qxq` is the compressed Kronecker square of `q`, taken column by column.

    Parameters
    ----------
    V : ndarray
        linear part of the map
    v : ndarray
        full model coordinates, one vector per column
    q : ndarray
        reduced coordinates, one vector per column
    tol_lstsq : float
        relative singular value cutoff of the least-squares solve

    Returns
    -------
    ndarray
        the coefficient matrix `Omega`
    '''
    linear_residual = v - V@q
    qq_compressed = compressed_vector_kronecker(q)
    return solve_leastsquares_redsvd(X=qq_compressed, F=linear_residual,
                                     tol_lstsq=tol_lstsq)


if __name__ == '__main__':
    import matplotlib.pyplot as plt
    from rich.logging import RichHandler
    from nse_data_simu.load_data import get_matrices, load_snapshots
    from numpy.random import default_rng

    logging.basicConfig(level=logging.INFO, handlers=[RichHandler()],
                        format='%(message)s',
                        datefmt="[%X]",
                        )
    problem = 'drivencavity'
    Nprob = 2
    nseodedata = False
    # nseodedata = True
    tE = 6  # 4.
    # Nts = 2**12
    Nts = 2**9
    nsnapshots = 2**9

    # ## Parameters for the ROM
    qdim = 7
    # fraction of the snapshots used for training; the rest is for testing
    trnprcntg = .8

    # problem size NV per refinement level `Nprob`, and Reynolds number
    if problem == 'cylinderwake':
        NVdict = {1: 5812, 2: 9356, 3: 19468}
        Re = 40
    else:
        NVdict = {1: 722, 2: 3042, 3: 6962}
        Re = 500
    NV = NVdict[Nprob]

    # #########################################################################
    # ##### Loading system data ###############################################
    # #########################################################################

    logging.info(f'Loading data for {problem} with NV={NV} and Re={Re}')

    # getting system matrices
    M, A11, A12, H, B1, B2, Cv, Cp = \
        get_matrices(problem, NV, dataprefix='../scripts/')

    # the snapshots have to belong to the selected `problem`
    vsnsh, Vd, MVd, P, T = \
        load_snapshots(NV=NV, problem=problem,
                       Re=Re, tE=tE, Nts=Nts,
                       dataprefix='../scripts/', nsnapshots=nsnapshots,
                       odesolve=nseodedata)

    # random train/test split of the snapshot columns
    rng = default_rng()
    nsnsh = vsnsh.shape[1]
    ntrn = int(trnprcntg*nsnsh)
    ntst = nsnsh - ntrn
    trnselec = np.zeros(nsnsh, dtype=bool)
    trnselec[rng.choice(nsnsh, ntrn, replace=False)] = True
    tstselec = ~trnselec
    trnvs = vsnsh[:, trnselec]
    tstvs = vsnsh[:, tstselec]

    # POD basis from the training snapshots
    Uv, _, _ = spla.svd(trnvs, full_matrices=False)
    podbas = Uv[:, :qdim]
    trnqsnapshots = podbas.T @ trnvs

    omega = infer_quad_correction(V=podbas, v=trnvs, q=trnqsnapshots,
                                  tol_lstsq=1e-5)
    cqq = compressed_vector_kronecker(trnqsnapshots)

    # the reported errors are divided by the number of snapshots
    poderr = norm(trnvs-podbas@trnqsnapshots)/ntrn
    linfo(f'qdim={qdim}: (train, avrg) POD-error={poderr}')

    podomerr = norm(trnvs-podbas@trnqsnapshots-omega@cqq)/ntrn
    linfo(f'qdim={qdim}: (train, avrg) POD+Omega-error={podomerr}')

    tstqs = podbas.T @ tstvs
    tstpoderr = norm(tstvs-podbas@tstqs)/ntst
    linfo(f'qdim={qdim}: (test, avrg) POD-error={tstpoderr}')

    tstcqq = compressed_vector_kronecker(tstqs)
    podomerr = norm(tstvs-podbas@tstqs-omega@tstcqq)/ntst
    linfo(f'qdim={qdim}: (test, avrg) POD+Omega-error={podomerr}')

    # Jacobian sanity check: for the quadratic form f(q) = omega@(qxq)
    # it holds that f(q) = .5 * df/dq(q) @ q
    dodq = get_domega_dq(omega)
    chckq = tstqs[:, 0]
    dodqcq = dodq(chckq)
    chckit = np.allclose((omega@compressed_vector_kronecker(chckq)).flatten(),
                         .5*dodqcq@chckq)
    linfo(f'2*x**2 - 2x*x? {chckit}')

    plt.show()
