"""
    Author: Julia Maia
    Date: June 2023

    Set of functions used to compute the dynamic loading model
    first developed by Richards and Hager (1984) and Hager and Clayton (1989).

"""

import os
import sys

import numpy as np
import sympy as sym
from sympy.solvers.solveset import linsolve
import scipy.linalg as la


import pyshtools as pysh


# Setting constants (SI units throughout)
R = 6051877.359  # mean planetary radius (m)
M = 4.867305816025651e24  # mean planetary mass (kg)
E = 1e11  # Young's modulus of the elastic lithosphere (Pa), used in the flexural rigidity
nu = 0.25  # Poisson's ratio of the elastic lithosphere (dimensionless)
rhom = 3300  # uppermost mantle density (kg/m3)
rhom_cmb = 5350  # mantle density at cmb depth (kg/m3)
rhocore = 9400  # core density at cmb (kg/m3)
drho_core = rhocore - rhom_cmb  # core-mantle density contrast (kg/m3)
Rc = 3250e3  # core radius from Aitta 2012 (m)
mu0 = 1e20  # reference viscosity; layer viscosities are passed normalised as mu/mu0
G = 6.6743e-11  # gravitational constant (m3 kg-1 s-2)


def computemodel(modelparams, OBS_LM, LMAX):
    """
    Computes dynamic loading model of gravity and topography

    Parameters
    ==========
    modelparams: dict
        free parameters of dynamic model, read under the keys "Rms" (mass-sheet
        radius), "mun" (normalized viscosities), "Rvis" (radii of viscosity
        interfaces) and "Te" (elastic thickness)
    OBS_LM: dict
        Spherical harmonics of observed gravity and topography. "R" is the
        radius at which the kernels are evaluated; "topo" and "gravr" hold the
        observed coefficients ("clm") and spectra ("spec_loc").
    LMAX: int
        maximum spherical harmonic degree

    returns
    =======
    Dict:
        dictionary with SHCoeffs of topography (m), radial gravity (m/s2) and estimated mantle mass-sheet (kg/m2)
    """
    # Reference radius (global mean radius or local average radius). The geoid
    # kernel is normalised by the reference gravity evaluated at this radius,
    # so the geoid -> radial gravity conversion must use the same radius.
    Rref = OBS_LM["R"]

    # Kernels
    kernels = dynamickernels(
        Rref,
        modelparams["Rms"],
        modelparams["mun"],
        modelparams["Rvis"],
        modelparams["Te"],
        LMAX,
    )

    kernels = addradgravkernels(kernels, Rref)

    DYN_PRED = {}
    # Computing the mass-sheet following Pauer+ 2006
    DYN_PRED["mass-sheet"] = singlemspauer(kernels, OBS_LM, LMAX)

    # Computing global predicted gravity and topography
    DYN_PRED["glm_global"] = predgrav1ms(DYN_PRED["mass-sheet"], kernels["G_rad"])
    DYN_PRED["hlm_global"] = predtopo1ms(DYN_PRED["mass-sheet"], kernels["D"])

    return DYN_PRED


def _layer_matrix(L, mu):
    """Return the 4x4 coefficient matrix A whose exponential gives the propagator across one layer of normalized viscosity ``mu`` (``L = l(l+1)``)."""
    return np.array(
        [
            [-2, L, 0, 0],
            [-1, 1, 0, 1 / mu],
            [12 * mu, -6 * L * mu, 1, L],
            [-6 * mu, (4 * L - 2) * mu, -1, -2],
        ]
    )


def _propagator_matrix(radii, mun, L, R, dtype=None):
    """Return the product of layer propagators across consecutive interfaces in ``radii`` (top to bottom), optionally casting each layer propagator to ``dtype``."""
    P = np.identity(4)
    for i in range(len(radii) - 1):
        A = _layer_matrix(L, mun[i])
        Pr = la.expm(A * (np.log(radii[i] / R) - np.log(radii[i + 1] / R)))
        if dtype is not None:
            Pr = Pr.astype(dtype)
        P = np.matmul(P, Pr)
    return P


def dynamickernels(R, Rms, mun, Rvis, Te, LMAX):
    """
    Compute dynamic kernels allowing for lithospheric flexure and viscosity variations in the mantle, following equation C6 from James et al. (2013).
    Free-slip at core-mantle boundary and no-slip at the surface

    For every degree l = 1..LMAX a unit mass-sheet anomaly at radius Rms is
    imposed. The 4x4 linear system built from the core->surface and
    mass-sheet->surface propagator matrices is then solved for the
    core-mantle boundary relief, the core poloidal velocity, the surface
    poloidal shear stress and the surface relief. Degree 0 is left at zero.

    Parameters
    ==========
    R: float
        mean planetary radius or local average radius of investigated region (m)
    Rms: float
        mass-sheet radius (m)
    mun: array like
        normalized viscosities [layers from surface to core] (mu/mu0)
    Rvis: array like
        Radii of viscosity variations [interfaces from surface to core] (m)
    Te: int, float
        Elastic thickness of the lithosphere (m)
    LMAX: int
        maximum spherical harmonic degree

    Return
    ======
    dict
        "D": displacement (surface relief) kernel, array of length LMAX + 1
        "G": geoid kernel, array of length LMAX + 1

    Raises
    ======
    ValueError
        if the linear system has no solution at some degree.
    """
    G = pysh.constants.G.value  # Gravitational constant (pyshtools value)
    g0 = G * M / (R ** 2)  # grav acceleration at MPR (local or global)
    gc = g0_x(Rc, Rc)  # gravitational acceleration at CMB
    gms = g0_x(Rms, Rc)  # gravitational acceleration at mass load radius

    # Unknowns of the linear system
    CMlm = sym.Symbol("CMlm")  # Core-mantle boundary relief
    Hlm = sym.Symbol("Hlm")  # Surface relief produced by dynamic flow
    Tlm = sym.Symbol("Tlm")  # Surface poloidal shear stress
    Vlm = sym.Symbol("Vlm")  # liquid core poloidal velocity

    # factors used to compute geoid kernel
    larr = np.arange(0, LMAX + 1)
    geoid_prefactor = 4 * np.pi * G * R / (g0 * (2 * larr + 1))
    geoid_corefactor = drho_core * (Rc / R) ** (larr + 2)
    geoid_msfactor = (Rms / R) ** (larr + 2)

    # Factors of the linear system
    fact1 = (4 * np.pi * G) / (2 * larr + 1) * R * rhom

    # Unitary mass-sheet anomaly
    MS = np.ones(LMAX + 1)

    # Arrays where linear eqs. results will be stored
    Nlm_arr = np.zeros(np.shape(MS))
    Hlm_arr = np.zeros(np.shape(MS))

    # Interfaces for the core->surface and mass-sheet->surface propagators
    Rvis = np.insert(Rvis, 0, R)
    Rvis_c = np.append(Rvis, Rc)
    Rvis_ms = Rvis[Rvis > Rms]
    Rvis_ms = np.append(Rvis_ms, Rms)

    # Lithospheric flexure factors (thin elastic shell model)
    D = (E * Te ** 3) / (12 * (1 - nu ** 2))  # Shell flexural rigidity
    Re = R - 1 / 2 * Te  # Radius of elastic shell

    for l in range(1, LMAX + 1):
        MSlm = MS[l]
        L = l * (l + 1)
        cl1 = (-(l ** 3) * (l + 1) ** 3 + 4 * l ** 2 * (l + 1) ** 2) / (
            -l * (l + 1) + 1 - nu
        )
        cl2 = (-l * (l + 1) + 2) / (-l * (l + 1) + 1 - nu)
        el = (D / Re ** 4) * cl1 + (E * Te / Re ** 2) * cl2

        # Propagator matrices. The mass-sheet propagator is accumulated in
        # extended precision; np.longdouble is float128 where that dtype
        # exists and degrades gracefully where it does not (Windows, arm64).
        Prc = _propagator_matrix(Rvis_c, mun, L, R)
        Prm = _propagator_matrix(Rvis_ms, mun, L, R, dtype=np.longdouble)

        eqsys = []
        for j in range(4):
            termCM_1 = Prc[j, 2] * gc * (Rc / R)
            termmass_1 = Prm[j, 2] * (Rms * gms) / R
            if j == 2:
                # Row 2 also carries the potential terms (fact1) and the
                # surface relief term including the flexural factor el.
                termCM_2 = -fact1[l] * (Rc / R) ** (l + 2)
                termH = (rhom * g0 - fact1[l] * rhom + el) * Hlm
                termmass_2 = -fact1[l] * (Rms / R) ** (l + 2)
            else:
                termCM_2 = 0
                termH = 0
                termmass_2 = 0
            # Row 3 is the only row involving the surface shear stress.
            termT = -Tlm if j == 3 else 0
            termCM = (termCM_1 + termCM_2) * drho_core * CMlm
            termmass = termmass_1 + termmass_2
            termv = Prc[j, 1] * (mu0 / R) * Vlm

            eqsys.append(termv + termCM + termH + termT + termmass)

        solution = linsolve(sym.Matrix(eqsys), (CMlm, Vlm, Tlm, Hlm))
        if not solution.args:
            raise ValueError(f"dynamic kernel system has no solution at degree l={l}")
        cm_sol, _v_sol, _t_sol, h_sol = solution.args[0]

        Nlm_arr[l] = geoid_prefactor[l] * (
            rhom * h_sol
            + geoid_msfactor[l] * MSlm
            + geoid_corefactor[l] * cm_sol
        )
        Hlm_arr[l] = h_sol

    Gk = Nlm_arr / MS  # Geoid kernel
    Hk = Hlm_arr / MS  # Displacement kernel
    kernels = {"D": Hk, "G": Gk}

    return kernels


def addradgravkernels(kernels, R):
    """
    Add the radial-gravity kernel to a dictionary of dynamic kernels.

    The geoid kernel ``kernels["G"]`` is scaled by the reference gravity
    G*M/R**2 to give a potential kernel. That kernel is multiplied by
    (l + 1)/R and stored under ``kernels["G_rad"]``.

    Parameters
    ==========
    kernels: dict
        dynamic kernels; must contain the geoid kernel under "G"
    R: float
        mean planetary radius (local or global) (m)

    Return
    ======
    dict
        the same ``kernels`` dictionary, modified in place, with "G_rad" added
    """
    degrees = np.arange(len(kernels["G"]))
    reference_gravity = (G * M) / (R ** 2)
    potential_kernel = kernels["G"] * reference_gravity
    kernels["G_rad"] = potential_kernel * (degrees + 1) / R
    return kernels


def singlemspauer(kernels, OBS_LM, lmax):
    """
    Computes the mass-sheet coefficients that minimize the misfit between
    predictions and observations for both gravity and topography.
    Based on eq. 15 from Pauer et al. 2006

    Parameters
    ==========
    kernels: dict
        dynamic kernels for a single mass-sheet model ("D" and "G_rad")
    OBS_LM: dict
        gravity and topography spherical harmonics observations; reads
        OBS_LM["topo"] and OBS_LM["gravr"], each with "clm" (SHCoeffs) and
        "spec_loc" (per-degree spectrum)
    lmax: int
        maximum spherical harmonic degree

    Return
    ======
    pyshtools SHCoeffs
        Spherical harmonic coefficients of mass anomalies in the mantle, in
        the same normalization and phase convention as the observed
        topography. Degrees 0 and 1 are set to zero.
    """
    Hl = kernels["D"]
    Gl = kernels["G_rad"]
    hobs = OBS_LM["topo"]["clm"]
    gobs = OBS_LM["gravr"]["clm"]

    # Per-degree weight of the topography misfit relative to the gravity misfit
    weightl = OBS_LM["gravr"]["spec_loc"] / OBS_LM["topo"]["spec_loc"]

    Gl2 = Gl ** 2
    Hl2 = weightl * Hl ** 2

    # Work on arrays truncated / zero-padded to lmax, with gravity expressed in
    # the topography's convention so the two combine coefficient by coefficient.
    hcoeffs = hobs.to_array(lmax=lmax)
    gcoeffs = gobs.to_array(
        normalization=hobs.normalization, csphase=hobs.csphase, lmax=lmax
    )

    psilm = np.zeros(np.shape(hcoeffs))

    for l in range(2, lmax + 1):
        topoft = hcoeffs[:, l, : l + 1] * Hl[l] * weightl[l]
        geoidft = gcoeffs[:, l, : l + 1] * Gl[l]
        psilm[:, l, : l + 1] = (topoft + geoidft) / (Hl2[l] + Gl2[l])

    return pysh.SHCoeffs.from_array(
        psilm, normalization=hobs.normalization, csphase=hobs.csphase
    )


def predgrav1ms(psilm, Gl):
    """
    Computes predicted gravity sh coefficients from a single mass-sheet model.

    Each degree-l coefficient of ``psilm`` is multiplied by ``Gl[l]``. The
    kernel can be a geoid kernel or a radial-gravity kernel (as passed by
    ``computemodel``); the output units follow the kernel.

    Parameters
    ==========
    psilm: SHCoeffs
        Spherical harmonic coefficients of mass anomalies in the mantle
    Gl: numpy array
        gravity (geoid or radial gravity) dynamic kernel, indexed by degree

    Return
    ======
    pyshtools SHCoeffs
        Spherical harmonic coefficients of the predicted gravity, in the same
        normalization and phase convention as ``psilm``. Degrees not covered
        by both ``psilm`` and ``Gl`` are zero.
    """
    coeffs = psilm.to_array()
    Gl = np.asarray(Gl)
    # Only degrees present in both the coefficients and the kernel are predicted
    nl = min(len(Gl), coeffs.shape[1])

    geopred_lm = np.zeros(np.shape(coeffs))
    geopred_lm[:, :nl, :] = coeffs[:, :nl, :] * Gl[None, :nl, None]

    return pysh.SHCoeffs.from_array(
        geopred_lm, normalization=psilm.normalization, csphase=psilm.csphase
    )


def predtopo1ms(psilm, Hl):
    """
    Computes predicted topography sh coefficients from a single mass-sheet model.

    Each degree-l coefficient of ``psilm`` is multiplied by ``Hl[l]``.

    Parameters
    ==========
    psilm: SHCoeffs
        Spherical harmonic coefficients of mass anomalies in the mantle
    Hl: numpy array
        displacement dynamic kernel, indexed by degree

    Return
    ======
    pyshtools SHCoeffs
        Spherical harmonic coefficients of the predicted topography, in the
        same normalization and phase convention as ``psilm``. Degrees not
        covered by both ``psilm`` and ``Hl`` are zero.
    """
    coeffs = psilm.to_array()
    Hl = np.asarray(Hl)
    # Only degrees present in both the coefficients and the kernel are predicted
    nl = min(len(Hl), coeffs.shape[1])

    topopred_lm = np.zeros(np.shape(coeffs))
    topopred_lm[:, :nl, :] = coeffs[:, :nl, :] * Hl[None, :nl, None]

    return pysh.SHCoeffs.from_array(
        topopred_lm, normalization=psilm.normalization, csphase=psilm.csphase
    )


def g0_x(Rx, Rc):
    """
    Computes the gravitational acceleration at radius Rx (Rcore <= Rx <= Rplanet).

    The interior is a uniform core of density ``rhocore`` inside a mantle
    whose density varies linearly with radius. The mantle density goes from
    ``rhom_cmb`` at the core radius to ``rhom`` at the planetary radius ``R``,
    following the values used in Herrick and Phillips (1992).

    Parameters
    ==========
    Rx: float or numpy.ndarray
        radius (or radii) of interest (m)
    Rc: float
        Core radius (m)

    Return
    ======
    float or numpy.ndarray
        gravitational acceleration (m/s2), same shape as ``Rx``
    """
    # Core mass
    Mcore = np.pi * 4 / 3 * (Rc ** 3) * rhocore

    # Linear mantle density profile rho(r) = a*r + b
    a = (rhom - rhom_cmb) / (R - Rc)
    b = rhom_cmb - Rc * a

    # Mantle mass between Rc and Rx: integral of 4*pi*r**2*(a*r + b) dr.
    # It is exactly zero at Rx == Rc, so no special case is needed there,
    # and the expression works element-wise for array-valued Rx.
    Mmantle = np.pi * a * (Rx ** 4 - Rc ** 4) + np.pi * 4 / 3 * b * (
        Rx ** 3 - Rc ** 3
    )

    # gravitational acceleration for the enclosed mass at radius Rx
    return G * (Mmantle + Mcore) / (Rx ** 2)
