import logging
# from rich.logging import RichHandler

from load_or_comp import load_nsevel_data

from optinf_tools import deriv_approx_data, optinf_quad_svd, optinf_linear
import optinf_tools as oit

from dmd_tools import dmd_model, dmdc_model, sim_dmd, sim_dmdc

import quadmf_optinf_tools as qot

import numpy as np
from scipy.linalg import norm
from scipy.linalg import svd
import matplotlib.pyplot as plt
from scipy.integrate import odeint, solve_ivp


plt.style.use('Solarize_Light2')
plt.style.use('ggplot')
# logging.basicConfig(level=logging.INFO, handlers=[RichHandler()],
#                     format='%(message)s', datefmt="[%X]",
#                     )

# ##########################################################################
# ##### System parameters ##################################################
# ##########################################################################

# Ratio between training and test data
ratio = 1/2
Nprob = 2
Re = 30
t0 = 0.
tE = 6  # 4.
nsnapshots = 2000

# ## PAMM
tE = 12  # 4.
rv = 10

# ## the data

dtfile = f'sc_{Re}_0-{tE}_{nsnapshots}.json'
vpc = 'simulation_nse/'

fullV, trange, cmat, M = load_nsevel_data(vpc+dtfile,
                                          velfile_path_correction=vpc)
logging.info('loaded data from: ' + vpc+dtfile)
vmean = fullV.mean(axis=1).reshape((fullV.shape[0], 1))
nrmvmean = np.sqrt(vmean.T @ M @ vmean)
fullT = trange.tolist()
logging.info(f'norm of v_mean: {nrmvmean.item():.4f}')

plt.figure(10101, figsize=(6, 2))
plt.plot(trange, (cmat @ fullV).T)
plt.xlabel('$t$')
plt.ylabel('$y(t) = Cv(t)$')
plt.tight_layout()
plt.savefig("Figure2.png")

for simucase in ['startup', 'periodic']:
    if simucase == 'startup':
        # ## t: (6, 12)
        rminiratio = 1/2  # which part to consider
        omglstol = 2e-2
        qmfoilsqtol = 1e-2
    elif simucase == 'periodic':
        # ## t: (0, 6)
        rminiratio = -1/2
        omglstol = 8e-3
        qmfoilsqtol = 1e-3

    vzero = np.zeros((fullV.shape[0], 1))
    yzero = np.zeros((cmat.shape[0], 1))

    tol_lstsq = 1e-4

    tol_lstsq_dmdc = 1e-8
###########################################################################
# ##### Loading system data ################################################
###########################################################################

    # Training and test data
    if rminiratio > 0:
        V = fullV[:, int(len(fullT)*rminiratio):]
        T = fullT[int(len(fullT)*rminiratio):]
    elif rminiratio < 0:
        V = fullV[:, :int(len(fullT)*rminiratio)]
        T = fullT[:int(len(fullT)*rminiratio)]
    else:
        V = fullV
        T = fullT

    Vf = V                      # Vf correponds to the test velocity data
    Tf = T                      # Tf correponds to the time interval for Tf
    V = Vf[:, :int(len(T)*ratio)]  # V correponds to the training velocity data
    T = T[:int(len(T)*ratio)]    # T correponds to the time interval for T

    ###########################################################################
    # ##### Computing reduced basis ###########################################
    ###########################################################################

    Uv, Sv, VvT = svd(V, full_matrices=False)

    # plotting decay of singular values
    # plotting_SVD_decay(Sv)

    # order of reduced models
    Uvr = Uv[:, :rv]

    ###########################################################################
    # ##### Computing reduced trajectories#####################################
    ###########################################################################

    dt = T[1]-T[0]
    V_red = Uvr.T@V
    N_sims = 0

    # Computing reduced derivatives
    Vd_red = deriv_approx_data(V_red, dt, N_sims)

    ###########################################################################
    # ##### Operator inference quadratic model ################################
    ###########################################################################

    print('Computing operator inference model... \n')

    Aoptinf, Hoptinf, Boptinf = optinf_quad_svd(V_red, Vd_red, tol_lstsq)

    ###########################################################################
    # ##### Operator inference linear model ###################################
    ###########################################################################

    Aoptinf_lin, Boptinf_lin = optinf_linear(V_red, Vd_red)

    ###########################################################################
    # ##### DMD  model ########################################################
    ###########################################################################

    print('Computing DMD models... \n')
    Admd = dmd_model(Uvr, V, rv)

    ###########################################################################
    # ##### DMD model with control ############################################
    ###########################################################################

    Admdc, Bdmdc = dmdc_model(Uvr, V, rv, tol_lstsq_dmdc)

    ###########################################################################
    # ##### DMD quadratic model with control ##################################
    ###########################################################################

    # Admd_quad, Hdmd_quad, Bdmd_quad = dmdquad_model(Uvr, V, rv)

    ###########################################################################
    # ##### Quadratic MF operator inference ###################################
    ###########################################################################
    logging.info('Inferring the QMF model...')
    podbas, prjbas = Uvr, Uvr
    qsnsh = prjbas.T.dot(V)
    omega = qot.infer_quad_correction(V=podbas, v=V, q=qsnsh,
                                      tol_lstsq=omglstol)
    dodq = qot.get_domega_dq(omega)

    gmgdql = []
    # dqsnsh = prjbas.T @ Vd
    dqsnsh = Vd_red
    Nsnsh = V.shape[1]

    def gqtMgqqd(qstate, qdot=None):
        '''XXX: here with M=I for the moment'''
        gq = podbas + dodq(qstate)
        if qdot is not None:
            return gq.T.dot(gq.dot(qdot))
        else:
            return gq.T.dot(gq)

    for kkk in range(Nsnsh):
        gmgdql.append(gqtMgqqd(qsnsh[:, kkk:kkk+1],
                               qdot=dqsnsh[:, kkk:kkk+1]))

    dqarray = np.hstack(gmgdql)
    cqq = qot.compressed_vector_kronecker(qsnsh)
    bigX = np.vstack([np.ones((1, Nsnsh)), qsnsh, cqq])

    (Az, Ao, At) = qot.\
        solve_leastsquares_redsvd(X=bigX, F=dqarray, tol_lstsq=qmfoilsqtol,
                                  subdims=(1, rv, np.int(rv*(rv+1)/2)))

    def quadmf_opinf_rhs(t, qstate):
        gtmg = gqtMgqqd(qstate)
        cqq = qot.compressed_vector_kronecker(qstate)
        rhs = Az.flatten() + Ao.dot(qstate).flatten() + At.dot(cqq).flatten()
        return np.linalg.solve(gtmg, rhs)

    ###########################################################################
    # ##### Simulating systems ################################################
    ###########################################################################

    logging.info('Simulating reduced order systems...')

    # projected initial condition
    x0 = Uvr.T@V[:, 0]

    # simulating OptInf linear model
    logging.info('integrating lin opinf')
    xsol_optinf_lin = odeint(oit.lin_model, x0, Tf, (Aoptinf_lin, Boptinf_lin))
    Voptinf_lin = Uvr @ xsol_optinf_lin.T

    # simulating quad MF OptInf model
    method = 'BDF'
    method = 'LSODA'
    logging.info('integrating qmf opinf')
    qmfopinfq = solve_ivp(quadmf_opinf_rhs, (Tf[0], tE), x0, method=method,
                          t_eval=Tf)
    Voptinf_qmf = Uvr @ qmfopinfq.y

    # simulating DMD model
    Vrdmd = sim_dmd(Admd, x0, len(Tf))
    Vdmd = Uvr@Vrdmd

    # Simulating DMD model with control
    Vrdmdc = sim_dmdc(Admdc, Bdmdc, x0, len(Tf))
    Vdmdc = Uvr@Vrdmdc

    # Simulating DMD quadratic model with control
    # Vrdmd_quad = sim_dmdquad(Admd_quad, Hdmd_quad, Bdmd_quad, x0, len(T))
    # Vdmd_quad  = Uvr@Vrdmd_quad

    ###########################################################################
    # ##### Plotting results ##################################################
    ###########################################################################

    fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True, tight_layout=True,
                                   figsize=(9, 6))
    flngth = 24
    tfilter = np.arange(0, len(Tf), flngth)
    fskip = 4
    trange = np.array(Tf)

    def incrmntfilter(ctf):
        ctf = np.r_[0, ctf+fskip]
        try:
            ctr = trange[ctf]
        except IndexError:
            ctf = ctf[:-1]
            ctr = trange[ctf]
        return ctf, ctr

    markerlst = ['v:', '^:', '<:', '>:', 'd:']
    markerlst = ['o-', 's-', 'd-', 'D-', 'p-']
    msize = 3
    lw = .5
    # ax = plt.subplot(212)

    ctf = tfilter
    logging.info(f'filter for plotting the errors: ctf={flngth}')
    gtr = trange
    ctr = gtr[tfilter]
    prop_cycle = plt.rcParams['axes.prop_cycle']
    colors = prop_cycle.by_key()['color']
    # datalist = [Voptinf, Voptinf_lin, Vpod, Vdmd, Vdmdc]
    # datalist = [Voptinf, Voptinf_qmf, Vpod, Vdmd, Vdmdc]
    # labellist = ['OpInf', 'OpInf qmf', 'POD', 'DMD', 'DMDc']
    datalist = [Vdmdc, Voptinf_qmf]
    labellist = ['DMDc', 'OpInf qmf']

    yf = cmat@Vf + yzero
    ax1.plot(Tf, (yf).T[:, 0], 'k', linewidth=lw, label='FOM')
    ax1.plot(Tf, (yf).T[:, 1:], 'k', linewidth=lw)

    for kkk in range(len(datalist)):
        cmkr, ccl = markerlst[kkk], colors[kkk]
        try:
            cdf = Vf - datalist[kkk]
            cmnrmlst = [np.sqrt(cdf[:, ckc].T @ M @ cdf[:, ckc])
                        for ckc in ctf]
            ax2.semilogy(ctr, cmnrmlst,
                         cmkr, color=ccl, label=labellist[kkk],
                         linewidth=lw, markersize=msize)
            ax1.plot(ctr, (cmat@datalist[kkk]+yzero).T[ctf, 0],
                     cmkr, color=ccl, label=labellist[kkk],
                     linewidth=lw, markersize=msize)
            ax1.plot(ctr, (cmat@datalist[kkk]+yzero).T[ctf, 1:],
                     cmkr, color=ccl,
                     linewidth=lw, markersize=msize)
            if kkk == 0:
                ctf, ctr = incrmntfilter(ctf)
            else:
                ctf, ctr = incrmntfilter(ctf[1:])
        except ValueError:
            pass

    # ax1.plot(T, (Cv@Voptinf).T, '--b')
    ax1.axvline(x=T[-1], color='k', linestyle='--')
    ax2.axvline(x=T[-1], color='k', linestyle='--')
    ax2.set_xlabel('time $t$')
    ax2.set_ylabel('$\\|v(t)-v_r(t)\\|_L^2$')
    # ax2.set_ylabel('$L_{\\infty}$ error of $v(t)$')
    # ax2.legend(loc='upper right')
    ax2.set_title(f"Approximation error: $r={rv}$")
    # ax2.subplots_adjust(wspace=0.5)
    ax1.set_ylabel('$y(t)=Cv(t)$')
    ax1.legend(loc='upper right')
    ax1.set_title(f"Time-domain simulation: $r={rv}$")

    fig.savefig(f"sc_{Re}_{Tf[0]}-{Tf[-1]}_{nsnapshots}_rv{rv}.pdf")
    fig.savefig(f"sc_{Re}_{Tf[0]}-{Tf[-1]}_{nsnapshots}_rv{rv}.png")

    # print('Optinf error: ', norm(Voptinf-Vf))
    print(simucase + ': Optinf lin error: ', norm(Voptinf_lin-Vf))
    print(simucase + ': DMD error: ', norm(Vdmd-Vf))
    print(simucase + ': DMDc error: ', norm(Vdmdc-Vf))
    print(simucase + ': Optinf qmf error: ', norm(Voptinf_qmf-Vf))

plt.show(block=True)
