#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Dec 10 10:50:07 2018

@author: Maria Cruz Varona
"""

import time
import numpy as np
import scipy as sp
import amfe
import matplotlib.pyplot as plt

#%% Locations of the input mesh and of the result files (resolved via amfe.amfe_dir)
input_file = amfe.amfe_dir('meshes/gmsh/bar.msh')
# Common path prefix: the Paraview exports and np.load calls further down
# append their own suffix to it.
output_file = amfe.amfe_dir('cruz/results/beam/NLMM_results')

#%% define system
# Full-order model (FOM). It is passed to nlmm2nd and is the system that the
# reduced models further down are built from.
my_material = amfe.KirchhoffMaterial(E=210E9, nu=0.3, rho=1E4, plane_stress=True)
sys2nd = amfe.MechanicalSystem()
# NOTE(review): 7, 8 and 9 are presumably the gmsh physical-group ids of the
# domain, the clamped edge and the loaded edge -- confirm against bar.msh.
sys2nd.load_mesh_from_gmsh(input_file, 7, my_material)
sys2nd.apply_dirichlet_boundaries(8, 'xy') # fixation of the left side
# Test excitation: value 1E8 in direction (0,-1), modulated in time by sin(31*t).
sys2nd.apply_neumann_boundaries(key=9, val=1E8, direct=(0,-1),
                                   time_func=lambda t: np.sin(31*t))


#%% NLMM 2nd order function
def _nlmm2nd_residual(system, M, D, B_2nd, vik, qrik, dqrik, ddqrik, rik, t):
    """Return the NLMM residual and its Jacobian w.r.t. ``vik`` for one snapshot."""
    K, f_int = system.K_and_f(u=vik * qrik, t=t)
    Rhs = M @ vik * ddqrik + D @ vik * dqrik + f_int - B_2nd * rik
    JacRhs = M * ddqrik + D * dqrik + K * qrik
    return Rhs, JacRhs


def nlmm2nd(MechanicalSystem, B_2nd, qr, dqr, ddqr, r, v0ik, **options):
    """Compute a second-order NLMM projection basis by Newton-Raphson.

    For every signal generator ``i`` and time snapshot ``k`` the nonlinear
    system

        M v ddqr[k,i] + D v dqr[k,i] + f_int(v*qr[k,i]) - B_2nd*r[:,k,i] = 0

    is solved for the vector ``v``. The converged ``v`` becomes one column of
    the raw basis and is also used as the initial guess for snapshot ``k+1``
    of the same signal generator.

    Parameters
    ----------
    MechanicalSystem : object
        System providing ``M()``, ``D()`` and ``K_and_f(u=..., t=...)``. The
        returned matrices must support ``@``, scalar ``*`` and ``.toarray()``
        (i.e. behave like scipy sparse matrices).
    B_2nd : ndarray
        Input vector; it is multiplied element-wise with ``r[:, k, i]``.
    qr, dqr, ddqr : ndarray
        Signal-generator state and its first and second time derivative,
        indexed as ``[k, i]``.
    r : ndarray
        Signal-generator output, indexed as ``[:, k, i]``.
    v0ik : ndarray, shape (n, Ktotal, NumSG)
        Initial guesses. Only ``v0ik[:, 0, i]`` is actually used as a guess;
        later snapshots start from the previous solution. The array is
        copied and never modified.
    **options
        dispIter : bool, default False
            Print the residual norm after every Newton iteration.
        dispCond : bool, default False
            Print the condition number of the Jacobian after every iteration.
        orth : bool, default False
            Return the left singular vectors of the raw basis. Takes
            precedence over ``defl``.
        defl : bool, default False
            Return only the first ``r_defl`` left singular vectors.
        r_defl : int, default False
            Number of vectors kept by the deflation; a falsy value selects
            ``numpy.linalg.matrix_rank`` of the raw basis.
        AbsTol, RelTol : float, default 10E-2
            Tolerances of the Newton stopping test (see Notes).

    Returns
    -------
    ndarray
        ``V_orth`` if ``orth``; otherwise the deflated basis if ``defl``;
        otherwise the raw basis of shape ``(n, NumSG*Ktotal)`` whose column
        ``i*Ktotal + k`` belongs to signal generator ``i``, snapshot ``k``.

    Raises
    ------
    numpy.linalg.LinAlgError
        If a Jacobian passed to ``numpy.linalg.solve`` is singular.

    Notes
    -----
    The stopping test is ``abs_Rhs > RelTol*abs_Rhs + AbsTol``. The relative
    term uses the *current* residual norm, so for ``RelTol < 1`` the test is
    equivalent to ``abs_Rhs > AbsTol/(1 - RelTol)``.
    NOTE(review): a relative tolerance is normally taken w.r.t. a fixed
    reference norm -- left unchanged here because altering it would change
    the computed basis; confirm the intended criterion.

    The snapshot times passed to ``K_and_f`` are hard-coded as
    ``linspace(0, 1, Ktotal)``.
    """
    # Function-scope import: does not rely on `scipy.linalg` having been
    # loaded as a side effect of some other import.
    from scipy.linalg import svd

    # options and their defaults
    dispIter = options.get('dispIter', False)
    dispCond = options.get('dispCond', False)
    orth = options.get('orth', False)
    defl = options.get('defl', False)
    r_defl = options.get('r_defl', False)
    AbsTol = options.get('AbsTol', 10E-2)  # = 0.1
    RelTol = options.get('RelTol', 10E-2)  # = 0.1
    max_iterations = 1000

    # Private floating-point copy of the initial guesses: the Newton update
    # works in place on slices of it, and the caller's array stays untouched.
    guesses = np.array(v0ik, dtype=float)
    n, Ktotal, NumSG = guesses.shape
    Vraw = np.zeros((n, NumSG * Ktotal))  # raw projection matrix
    t = np.linspace(0, 1, Ktotal)  # snapshot times

    # M() and D() are called without arguments, so they do not depend on the
    # Newton iterate and are evaluated only once.
    M = MechanicalSystem.M()
    D = MechanicalSystem.D()

    # SO-NLMM algorithm
    for iSG in range(NumSG):
        print('Signal-Gerator: ', iSG+1)
        for kSnap in range(Ktotal):
            print('Time-Snapshot: ', kSnap+1)

            qrik = qr[kSnap, iSG]
            dqrik = dqr[kSnap, iSG]
            ddqrik = ddqr[kSnap, iSG]
            rik = r[:, kSnap, iSG]
            vik = guesses[:, kSnap, iSG]  # view into the private copy

            Rhs, JacRhs = _nlmm2nd_residual(
                MechanicalSystem, M, D, B_2nd, vik,
                qrik, dqrik, ddqrik, rik, t[kSnap])
            abs_Rhs = np.linalg.norm(Rhs, 2)

            # Newton-Raphson iteration
            iteration = 0
            while abs_Rhs > RelTol * abs_Rhs + AbsTol:
                iteration += 1

                vik -= np.linalg.solve(JacRhs.toarray(), Rhs)

                Rhs, JacRhs = _nlmm2nd_residual(
                    MechanicalSystem, M, D, B_2nd, vik,
                    qrik, dqrik, ddqrik, rik, t[kSnap])
                abs_Rhs = np.linalg.norm(Rhs, 2)

                if dispIter:
                    print('Newton-Iteration: ', iteration, 'Residual:', abs_Rhs)

                if dispCond:
                    print('Condition of Jacobian: ', np.linalg.cond(JacRhs.toarray()))

                if iteration > max_iterations:
                    print('Maximum number of iterations reached without converging')
                    break
            else:
                # reached only when the loop ended without `break`
                print('Newton-Raphson converged!')

            Vraw[:, iSG * Ktotal + kSnap] = vik
            if kSnap < Ktotal - 1:
                # Warm start for the next snapshot of the SAME signal
                # generator (the original read column kSnap, which belongs
                # to the first signal generator).
                guesses[:, kSnap + 1, iSG] = vik

        # Real subspace #TODO

    # Orthogonalization
    if orth:
        V_orth, _sigma, _vh = svd(Vraw, full_matrices=False)
        return V_orth

    # Deflation
    if defl:
        U, _sigma, _vh = svd(Vraw, full_matrices=False)
        if not r_defl:
            r_defl = np.linalg.matrix_rank(Vraw)
        return U[:, :r_defl]

    return Vraw

#%% Signal Generator TODO: Outsource into a different function (maybe even class)

# Preallocations
Ktotal = 20  # number of time snapshots
SG = 1       # number of signal generators (last axis of qr, dqr, ddqr, r)
m = 1        # leading dimension of r
t_snap = np.linspace(0,1,Ktotal)
dqr = np.zeros((Ktotal,SG))
qr = np.zeros((Ktotal,SG))
r = np.zeros((m,Ktotal,SG))
ddqr = np.zeros((Ktotal,SG))

amplitude = 1
omega = 10
t_end = 1.0

# Harmonic signal generator qr(t) = amplitude*sin(omega*t) and its first two
# time derivatives, evaluated on all snapshots at once (column 0 = the single
# signal generator).
#---------qr
qr[:,0] = amplitude*np.sin(omega*t_snap)

#---------dqr = sv(qr)
dqr[:,0] = np.cos(omega*t_snap)*amplitude*omega

#---------ddqr
ddqr[:,0] = -1*np.sin(omega*t_snap)*amplitude*omega*omega

#---------r
def u1(qr, A):
    """Return the signal-generator state ``qr`` scaled by the gain ``A``."""
    return A*qr

A = 1E8
r[0,:,:] = u1(qr, A)  # signal-generator output: r = A*qr


# for calculation of B2nd
# The external force of sys2nd at t_end is divided by the scalar load value
# at the same instant. Both evaluations use t_end (the force was previously
# evaluated at a hard-coded t = 1.0), so changing t_end cannot desynchronise
# them.
# NOTE(review): -1E8 and 31 repeat the Neumann load parameters given to
# sys2nd; keep them in sync by hand.
my_force_y = -1E8*np.sin(31*t_end)
B_2nd = sys2nd.f_ext(u = None, t = t_end) / my_force_y


n = B_2nd.shape[0]
#v0ik = np.ma.ones((n,Ktotal,SG))
v0ik = np.zeros((n,Ktotal,SG))  # zero initial guesses for nlmm2nd

# Training system for POD
# Built with the same mesh file, material parameters and boundary keys as
# sys2nd; only the time function of the load differs (sin(10*t) instead of
# sin(31*t)).
my_material = amfe.KirchhoffMaterial(E=210E9, nu=0.3, rho=1E4, plane_stress=True)
sys2nd_train = amfe.MechanicalSystem()
sys2nd_train.load_mesh_from_gmsh(input_file, 7, my_material)
sys2nd_train.apply_dirichlet_boundaries(8, 'xy') # fixation of the left side
sys2nd_train.apply_neumann_boundaries(key=9, val=1E8, direct=(0,-1),
                                   time_func=lambda t: np.sin(10*t))

#%% Execute the NLMM algorithm
# The wall-clock time of the basis computation is kept for the timing report.
start_time = time.time()
V_nlmm = nlmm2nd(sys2nd, B_2nd, qr, dqr, ddqr, r, v0ik,
                 dispIter=True, orth=True)
time_V_nlmm = time.time() - start_time

#%% Test, if V_nlmm is orthogonal
# Test_orth is the norm of (I - V^T V); it is zero for orthonormal columns.
redOrd_nlmm = V_nlmm.shape[1]
V_nlmmT = V_nlmm.T
Test_orth = np.linalg.norm(np.eye(redOrd_nlmm) - V_nlmmT @ V_nlmm)

print(np.linalg.cond(V_nlmm))

#%% solve FOM
#---------------using linear solver (M, D, K)
# Previously stored time steps are cleared first; only solve() is timed.
sys2nd.clear_timesteps()
solverlin = amfe.GeneralizedAlphaLinearDynamicsSolver(sys2nd, dt = 1e-3)
start_time = time.time()
solverlin.solve()
time_lin_FOM = time.time() - start_time
sys2nd.export_paraview(output_file + '_nonlin_bar_FOM_linear')

# Disabled nonlinear FOM reference run.
# NOTE: the "Error measure" cell further down np.load()s the four '.npy'
# files (_q_export_nonlin, _t_qr, _y_q_2nd, _time_nonlin_FOM) that are written
# by the np.save calls in this disabled section. Enable it at least once (or
# provide the files by other means) before running that cell.
##---------------using nonlinear solver (generalized-alpha)
#sys2nd.clear_timesteps()
#solvernl = amfe.GeneralizedAlphaNonlinearDynamicsSolver(sys2nd, dt = 1e-3)
#start_time = time.time()
#solvernl.solve()
##solvernl.solve_with_adaptive_time_stepping()
#time_nonlin_FOM = time.time() - start_time
#
## Export for Postprocessing with Paraview
#q_export_nonlin = np.array(sys2nd.u_output).T
#t_qr = np.array(sys2nd.T_output)
#q_export_nonlin = sys2nd.constrain_vec(q_export_nonlin)
#y_q_2nd = q_export_nonlin[1][:]
#sys2nd.export_paraview(output_file + '_nonlin_bar_FOM_nonlinear')
#
#np.save(output_file + '_q_export_nonlin', q_export_nonlin)
#np.save(output_file + '_t_qr', t_qr)
#np.save(output_file + '_y_q_2nd', y_q_2nd)
#np.save(output_file + '_time_nonlin_FOM', time_nonlin_FOM)

#%% Reduction (POD) and Simulation ROM-POD
redOrd_pod = V_nlmm.shape[1]  # POD basis gets the same order as the NLMM basis

# Training simulation of the nonlinear FOM. Only solve() is timed; the
# earlier, immediately overwritten start_time assignment was dead code and
# has been removed.
sys2nd_train.clear_timesteps()
solvernl = amfe.GeneralizedAlphaNonlinearDynamicsSolver(sys2nd_train, dt = 1e-3)
start_time = time.time()
solvernl.solve()
#solvernl.solve_with_adaptive_time_stepping()
time_nonlin_FOM_train_POD = time.time() - start_time

# Training trajectory: stored displacement history, constrained via
# constrain_vec; row 1 is used as the scalar output y.
q_export_nonlin_train = np.array(sys2nd_train.u_output).T
t_qr_train = np.array(sys2nd_train.T_output)
q_export_nonlin_train = sys2nd_train.constrain_vec(q_export_nonlin_train)
y_q_2nd_train = q_export_nonlin_train[1][:]

# POD basis from the training system; the reported reduction time is the
# training simulation plus the POD itself.
start_time = time.time()
__, V_pod = amfe.pod(sys2nd_train, redOrd_pod)
time_pod = time.time() - start_time
time_V_pod = time_nonlin_FOM_train_POD + time_pod
# The basis is applied to the test system sys2nd, not to the training system.
sysr2nd_V_pod = amfe.reduce_mechanical_system(sys2nd, V_pod)

# Reduced simulation starting from zero initial conditions
initial_conditions_pod = {'q0': np.zeros(redOrd_pod), 'dq0': np.zeros(redOrd_pod)}
solver = amfe.GeneralizedAlphaNonlinearDynamicsSolver(sysr2nd_V_pod, dt = 1e-3, initial_conditions=initial_conditions_pod)
sysr2nd_V_pod.clear_timesteps()

start_time = time.time()
solver.solve()
time_nonlin_ROM_pod = time.time() - start_time

# Export for Postprocessing with Paraview
qr_export_V_pod = sysr2nd_V_pod.constrain_vec(np.array(sysr2nd_V_pod.u_output).T)
t_qr_V_pod = np.array(sysr2nd_V_pod.T_output)
yr_V_pod = qr_export_V_pod[1][:]
sysr2nd_V_pod.export_paraview(output_file + '_nonlin_bar_ROM_pod')

#%% Simulation ROM-NLMM
# Reduce the test system sys2nd with the NLMM basis and simulate it from zero
# initial conditions; only solve() is timed.
sysr2nd_V_nlmm = amfe.reduce_mechanical_system(sys2nd, V_nlmm)
redOrd_nlmm = V_nlmm.shape[1]

initial_conditions_nlmm = {'q0': np.zeros(redOrd_nlmm), 'dq0': np.zeros(redOrd_nlmm)}
nonlin_solver_ROM = amfe.GeneralizedAlphaNonlinearDynamicsSolver(sysr2nd_V_nlmm, dt = 1e-3, initial_conditions = initial_conditions_nlmm)
sysr2nd_V_nlmm.clear_timesteps()

start_time = time.time()
nonlin_solver_ROM.solve()
time_nonlin_ROM_nlmm = time.time() - start_time

# Export for Postprocessing with Paraview
# The stored reduced coordinates are multiplied by the basis to obtain the
# displacement history; row 1 is used as the scalar output y.
qr_export_V_nlmm = np.array(sysr2nd_V_nlmm.u_red_output).T
t_qr_V_nlmm = np.array(sysr2nd_V_nlmm.T_output)
qr_V_nlmm = V_nlmm @ qr_export_V_nlmm
yr_V_nlmm = qr_V_nlmm[1][:]
sysr2nd_V_nlmm.export_paraview(output_file + '_nonlin_bar_ROM_nlmm')

#%% Modal basis, static modal derivatives and the extended basis
#---------------Linear basis
r_Phi = 5  # second argument of vibration_modes; presumably the number of modes

# time_V_extended covers the modes, the static derivatives and the augmentation.
start_time = time.time()
omega_, Phi = amfe.reduced_basis.vibration_modes(sys2nd, r_Phi)

#---------------Static modal derivatives
Theta = amfe.reduced_basis.static_derivatives(Phi, sys2nd.K)

#---------------Extended basis V_extended = [Phi Theta]
V_extended = amfe.augment_with_derivatives(Phi, Theta)

time_V_extended = time.time() - start_time

#---------------Write linear basis
# Each basis vector is written as one "time step" (index i) so that it can be
# inspected in Paraview.
sys2nd.clear_timesteps()
for i in np.arange(Phi.shape[1]):
    sys2nd.write_timestep(i, Phi[:,i])
sys2nd.export_paraview(output_file + '_lin_basis')

#---------------Write static modal derivatives
# Only the entries Theta[:,i,j] with i > j are written; both loops run over
# Theta.shape[1].
# NOTE(review): the diagonal entries Theta[:,i,i] are skipped by `i > j` --
# confirm whether `i >= j` was intended.
sys2nd.clear_timesteps()
counter = 0
for i in np.arange(Theta.shape[1]):
    for j in np.arange(Theta.shape[1]):
        if i > j:
            sys2nd.write_timestep(counter, Theta[:,i,j])
            counter = counter + 1
sys2nd.export_paraview(output_file + '_SMDs')

#---------------Write extended basis
sys2nd.clear_timesteps()
for i in np.arange(V_extended.shape[1]):
    sys2nd.write_timestep(i, V_extended[:,i])
sys2nd.export_paraview(output_file + '_extended_basis')

#---------------Test, if V_extended is orthogonal
# Norm of (I - V^T V). This overwrites the Test_orth value computed for V_nlmm.
redOrd_V_extended = V_extended.shape[1]
Test_orth = np.linalg.norm(np.eye(redOrd_V_extended) - np.matrix.transpose(V_extended) @ V_extended)

#%%
#---------------Simulation of ROMs
### V_extended
# Reduce sys2nd with the extended basis and simulate it from zero initial
# conditions; only solve() is timed.
sysr2nd_V_extended = amfe.reduce_mechanical_system(sys2nd, V_extended)

redOrd_V_extended = V_extended.shape[1]
initial_conditions_V_extended = {'q0': np.zeros(redOrd_V_extended), 'dq0': np.zeros(redOrd_V_extended)}
nonlin_solver_ROM_V_extended = amfe.GeneralizedAlphaNonlinearDynamicsSolver(sysr2nd_V_extended, dt = 1e-3, initial_conditions = initial_conditions_V_extended)
sysr2nd_V_extended.clear_timesteps()

start_time = time.time()
nonlin_solver_ROM_V_extended.solve()
time_nonlin_ROM_V_extended = time.time() - start_time

# Export for Postprocessing with Paraview
# The stored reduced coordinates are multiplied by the basis to obtain the
# displacement history; row 1 is used as the scalar output y.
qr_export_V_extended = np.array(sysr2nd_V_extended.u_red_output).T
t_qr_V_extended = np.array(sysr2nd_V_extended.T_output)
qr_V_extended = V_extended @ qr_export_V_extended
yr_V_extended = qr_V_extended[1][:]
sysr2nd_V_extended.export_paraview(output_file + '_nonlin_bar_ROM_V_extended')

### Phi
# Same procedure as for the other ROMs, using only the modal basis Phi.
sysr2nd_Phi = amfe.reduce_mechanical_system(sys2nd, Phi)

redOrd_Phi = Phi.shape[1]
initial_conditions_Phi = {'q0': np.zeros(redOrd_Phi), 'dq0': np.zeros(redOrd_Phi)}
nonlin_solver_ROM_Phi = amfe.GeneralizedAlphaNonlinearDynamicsSolver(sysr2nd_Phi, dt = 1e-3, initial_conditions = initial_conditions_Phi)
sysr2nd_Phi.clear_timesteps()

start_time = time.time()
nonlin_solver_ROM_Phi.solve()
time_nonlin_ROM_Phi = time.time() - start_time

# Export for Postprocessing with Paraview
# The stored reduced coordinates are multiplied by Phi to obtain the
# displacement history; row 1 is used as the scalar output y.
qr_export_V_Phi = np.array(sysr2nd_Phi.u_red_output).T
t_qr_V_Phi = np.array(sysr2nd_Phi.T_output)
qr_V_Phi = Phi @ qr_export_V_Phi
yr_Phi = qr_V_Phi[1][:]
sysr2nd_Phi.export_paraview(output_file + '_nonlin_bar_ROM_Phi')


#%% Error measure
#-------Load simulation data from FOM for test signal
# These files are written by the np.save calls in the (commented-out)
# nonlinear FOM section of the "solve FOM" cell; np.load raises
# FileNotFoundError if they are missing.
q_export_nonlin = np.load(output_file + '_q_export_nonlin.npy')
t_qr = np.load(output_file + '_t_qr.npy')
y_q_2nd = np.load(output_file + '_y_q_2nd.npy')
time_nonlin_FOM = np.load(output_file + '_time_nonlin_FOM.npy')


#------- q-errors for different reduction methods
# Relative error ||q_FOM - q_ROM||_2 / ||q_FOM||_2. The reference norm is the
# same for every method and is therefore computed only once.
# NOTE(review): for a 2-D array numpy's ord=2 is the spectral norm (largest
# singular value), not the Frobenius norm -- confirm that this is intended.
norm_q_ref = np.linalg.norm(q_export_nonlin, 2)
error_q_rel_Phi = np.linalg.norm(q_export_nonlin - qr_V_Phi, 2)/norm_q_ref
error_q_rel_V_ext = np.linalg.norm(q_export_nonlin - qr_V_extended, 2)/norm_q_ref
error_q_rel_V_pod = np.linalg.norm(q_export_nonlin - qr_export_V_pod, 2)/norm_q_ref
error_q_rel_V_nlmm = np.linalg.norm(q_export_nonlin - qr_V_nlmm, 2)/norm_q_ref

#------- y-errors for different reduction methods
norm_y_ref = np.linalg.norm(y_q_2nd, 2)
error_y_rel_Phi = np.linalg.norm(y_q_2nd - yr_Phi, 2)/norm_y_ref
error_y_rel_V_ext = np.linalg.norm(y_q_2nd - yr_V_extended, 2)/norm_y_ref
error_y_rel_V_pod = np.linalg.norm(y_q_2nd - yr_V_pod, 2)/norm_y_ref
error_y_rel_V_nlmm = np.linalg.norm(y_q_2nd - yr_V_nlmm, 2)/norm_y_ref

#------- Print
print('error_y_rel_Phi = %s' %error_y_rel_Phi)
print('error_y_rel_V_ext = %s' %error_y_rel_V_ext)
print('error_y_rel_V_pod = %s' %error_y_rel_V_pod)
print('error_y_rel_V_nlmm = %s' %error_y_rel_V_nlmm)

#%% Plot y-displacement for different reduction methods
# Figure 1: output y over time for the FOM and the four reduced models.
fig = plt.figure(1)
ax = plt.subplot(111)
plt.title('')
plt.xlabel('time (s)')
plt.ylabel('$y$-displacement (m)')

# RGB colour tuples shared by the curves of this figure
grey = (0.5, 0.5, 0.5)
brown = (0.67, 0.0, 0.0)
blue = (0.0, 0.67, 1.0)
red = (1.0, 0.0, 0.0)
green = (0.0, 0.67, 0.0)

plotFOM, = plt.plot(t_qr, y_q_2nd, color=grey, linestyle='-', linewidth=4.0, label = 'FOM')
plotPHI, = plt.plot(t_qr_V_Phi, yr_Phi, color=brown, linestyle='-', linewidth=2.0, label = 'Phi')
plotEXT, = plt.plot(t_qr_V_extended, yr_V_extended, color=blue, linestyle='-.', linewidth=2.0, label = 'Vaug')
plotPOD, = plt.plot(t_qr_V_pod, yr_V_pod, color=red, linestyle=':', linewidth=2.0, label = 'POD')
plotNLMM, = plt.plot(t_qr_V_nlmm, yr_V_nlmm, color=green, linestyle ='--', linewidth=2.0, label = 'NLMM')

# Fixed axis limits: x in [0, 1], y in [-3, 3]
plt.axis([0,1,-3,3])
plt.grid(True)

# Legend is placed above the axes, one row with five entries
ax.legend(loc='upper center', bbox_to_anchor=(0.5, 1.15), ncol=5, fancybox=True, shadow=True)
plt.savefig('y_displacements.png', format = 'png')

# matplotlib2tikz is imported here rather than at the top of the file; the
# error-plot cell below reuses this `save`, so this cell has to run first.
from matplotlib2tikz import save
save("y_displacements.tex")

#%% Plot y-errors for different reduction methods
# Figure 2: point-wise output error |y_FOM - y_ROM|, normalised by the 2-norm
# of the FOM output, on a logarithmic y-axis.
fig = plt.figure(2)
ax = plt.subplot(111)
plt.title('')
plt.xlabel('time (s)')
# Raw string: '\m' is not a valid escape sequence in a normal string literal.
# The resulting label text is unchanged.
plt.ylabel(r'rel. $\mathcal{L}_2$ output-error')

# The normalisation is the same for every method, so it is computed only once.
norm_y_2nd = np.linalg.norm(y_q_2nd, 2)
error_output_rel_Phi = np.abs(y_q_2nd - yr_Phi)/norm_y_2nd
error_output_rel_V_ext = np.abs(y_q_2nd - yr_V_extended)/norm_y_2nd
error_output_rel_V_pod = np.abs(y_q_2nd - yr_V_pod)/norm_y_2nd
error_output_rel_V_nlmm = np.abs(y_q_2nd - yr_V_nlmm)/norm_y_2nd

plt.axis([0,1,1e-9,1e-1])
plt.grid(True)

plotPHI, = plt.semilogy(t_qr_V_Phi, error_output_rel_Phi, color=brown, linestyle='-', linewidth=2.0, label = 'Phi')
plotEXT, = plt.semilogy(t_qr_V_extended, error_output_rel_V_ext, color=blue, linestyle='-.', linewidth=2.0, label = 'Vaug')
plotPOD, = plt.semilogy(t_qr_V_pod, error_output_rel_V_pod, color=red, linestyle=':', linewidth=2.0, label = 'POD')
plotNLMM, = plt.semilogy(t_qr_V_nlmm, error_output_rel_V_nlmm, color=green, linestyle ='--', linewidth=2.0, label = 'NLMM')

ax.legend(loc='upper center', bbox_to_anchor=(0.5, 1.15), ncol=5, fancybox=True, shadow=True)
plt.savefig('error_output_rel.png', format = 'png')
# `save` and the colour tuples come from the displacement-plot cell above.
save("error_output_rel.tex")

# https://matplotlib.org/users/usetex.html
# https://matplotlib.org/users/customizing.html#customizing-matplotlib

#%% Printing reduction and simulation times
# All values are wall-clock durations measured with time.time() in the cells
# above, except time_nonlin_FOM, which is loaded from file in the error cell.
#----- reduction times
print('time_V_extended = %s seconds' %time_V_extended)
# time_V_pod includes the nonlinear training simulation of the FOM
print('time_V_pod = %s seconds' %time_V_pod)
print('time_V_nlmm = %s seconds' %time_V_nlmm)

#----- simulation times
print('time_lin_FOM = %s seconds' %time_lin_FOM)
print('time_nonlin_FOM = %s seconds' %time_nonlin_FOM)

print('time_nonlin_ROM_Phi = %s seconds' %time_nonlin_ROM_Phi)
print('time_nonlin_ROM_V_extended = %s seconds' %time_nonlin_ROM_V_extended)
print('time_nonlin_ROM_pod = %s seconds' %time_nonlin_ROM_pod)
print('time_nonlin_ROM_nlmm = %s seconds' %time_nonlin_ROM_nlmm)

#%% Compare test output, training output and the signal-generator state
fig = plt.figure(3)
ax = plt.subplot(111)
plt.title('')
plt.xlabel('time (s)')
plt.ylabel('$y$-displacement (m)')

# Colours are redefined here so that the cell can be run on its own
grey = (0.5, 0.5, 0.5)
brown = (0.67, 0.0, 0.0)
blue = (0.0, 0.67, 1.0)
red = (1.0, 0.0, 0.0)
green = (0.0, 0.67, 0.0)

plotFOM_test, = plt.plot(t_qr, y_q_2nd, color=grey, linestyle='-', linewidth=4.0, label = 'FOM_test')
# The training output is plotted against the time vector of the training
# simulation (t_qr_train), not against the test-signal time vector t_qr.
plotFOM_train, = plt.plot(t_qr_train, y_q_2nd_train, color='k', linestyle='-', linewidth=4.0, label = 'FOM_train')
plotSG_qr_train, = plt.plot(t_snap, qr, color=green, linestyle='-', linewidth=4.0, label = 'SG_qr_train')

ax.legend(loc='upper center', bbox_to_anchor=(0.5, 1.15), ncol=3, fancybox=True, shadow=True)
