"""
    Author: Julia Maia
    Date: June 2023

    Set of functions used to perform the multitaper localization procedure
    first developed by Wieczorek and Simons (2005, 2007).

"""

import numpy as np

import pyshtools as pysh


def localized_observations(nwin, lwin, LMAX=40):
    """
    Loads Venus topography and radial gravity and computes both their global
    spectra and their multitaper-localized spectra. The localization masks
    out Ishtar Terra and western Aphrodite Terra (Ovda and Thetis).

    Parameters
    ==========
    nwin: int
        number of tapers used for the localization
    lwin: int
        spectral bandwidth of the localization windows
    LMAX: int (optional)
        maximum spherical harmonic degree
        default=40

    Returns
    =======
    dict:
        "topo" and "gravr": sub-dictionaries holding "spec_loc" (localized
        spectrum, uniform taper weights), "spec" (global spectrum) and "clm"
        (the spherical harmonic coefficients used);
        "R": degree-0 topography coefficient (mean planetary radius)
    """
    obs = {"topo": {}, "gravr": {}}

    # ---- Topography ------------------------------------------------------
    topo = pysh.datasets.Venus.VenusTopo719(lmax=LMAX)
    obs["R"] = topo.coeffs[0, 0, 0]  # mean planetary radius
    topo.coeffs[0, 0, 0] = 0  # keep only relief about the mean radius

    # ---- Localization windows --------------------------------------------
    # Binary mask excluding Ishtar Terra, Ovda and Thetis, turned into tapers.
    topo_grid = topo.expand()
    mask = ishtaraphro_mask(topo_grid.lats(), topo_grid.lons(), topo_grid)
    taper_grids = grid_tapers_from_mask(mask, nwin, lwin, LMAX)["clmcap"]

    # Every taper contributes equally to the multitaper estimate.
    weights = np.ones(nwin) / nwin

    topo_out = obs["topo"]
    topo_out["spec_loc"] = localized_spectrum(
        localized_spectrum_pertaper(topo, taper_grids), weights
    )
    topo_out["spec"] = topo.spectrum()
    topo_out["clm"] = topo

    # ---- Radial gravity --------------------------------------------------
    grav = pysh.datasets.Venus.MGNP180U(lmax=LMAX)
    grav.coeffs[0, 0, 0] = 0.0
    grav.set_omega(0)
    grav_grid = grav.expand(f=0, a=obs["R"])
    # Flip the sign of the radial component, then return to SH coefficients.
    gravr = (grav_grid.rad * (-1)).expand()

    gravr_out = obs["gravr"]
    gravr_out["spec_loc"] = localized_spectrum(
        localized_spectrum_pertaper(gravr, taper_grids), weights
    )
    gravr_out["spec"] = gravr.spectrum()
    gravr_out["clm"] = gravr

    return obs


def grid_tapers_from_mask(binmask, nwins, lwin, lmax):
    """
    Creates gridded localization windows (tapers) from a binary mask
    (=gridwindow_frommask).

    Parameters
    ==========
    binmask: numpy array
        2-dimensional binary mask passed to pysh.SHWindow.from_mask
    nwins: int
        number of windows to construct (for a multitaper localization)
    lwin: int
        spherical harmonic bandwidth of the windows
    lmax: int
        maximum spherical harmonic degree used when expanding each window
        onto a grid

    Returns
    =======
    dict:
        {"clmcap": {"w1": SHGrid, ...}, "eigenv": {"w1": eigenvalue, ...}}
        with keys "w1" ... "w<nwins>" in taper order; eigenvalues are taken
        from SHWindow.eigenvalues
    """
    taper_set = pysh.SHWindow.from_mask(binmask, lwin, nwin=nwins)
    names = [f"w{i + 1}" for i in range(nwins)]

    grids = {
        name: taper_set.to_shcoeffs(itaper=i).expand(lmax=lmax)
        for i, name in enumerate(names)
    }
    eigenvalues = {name: taper_set.eigenvalues[i] for i, name in enumerate(names)}

    return {"clmcap": grids, "eigenv": eigenvalues}


def localized_spectrum_pertaper(clm, grid_caps):
    """
    Computes localized spherical harmonic coefficients for each
    localization taper.

    The field described by ``clm`` is expanded onto a grid once. For every
    taper, that grid is multiplied point by point by the taper grid, and the
    product is expanded back into spherical harmonic coefficients.

    Parameters
    ==========
    clm: SHCoeffs
        Spherical harmonic coefficients of the field to localize
    grid_caps: dict
        dictionary containing the grid of each localization taper; each grid
        is expected to be sampled like ``clm.expand()`` so the pointwise
        product is defined (TODO confirm for callers outside this module)

    Returns
    =======
    dict:
        dictionary with the same keys as grid_caps, mapping each taper to the
        SHCoeffs of the localized field
    """
    # Expand once; the same grid is reused for every taper.
    grid = clm.expand()
    return {key: (grid * cap).expand() for key, cap in grid_caps.items()}


def localized_spectrum(mtclm, weights):
    """
    Computes the final multitaper localized spectrum as a weighted
    combination of the per-taper localized spectra.

    Parameters
    ==========
    mtclm: dict
        dictionary containing localized SHCoeffs for each localization taper
        (e.g. the output of localized_spectrum_pertaper). Any keys are
        accepted; spectra are stacked in the dictionary's iteration order.
        All entries must share the same lmax.
    weights: array like
        weight applied to each localized spectrum, in the same order as the
        entries of mtclm.

    Returns
    =======
    numpy array:
        np.dot(weights, S), where row i of S is the spectrum() of the i-th
        entry of mtclm (one value per degree 0..lmax)

    Raises
    ======
    ValueError
        if mtclm is empty; np.dot also raises ValueError when the number of
        weights does not match the number of tapers.
    """
    if not mtclm:
        raise ValueError("mtclm must contain at least one localized SHCoeffs")

    # Row i holds the power spectrum of the i-th localized field; the first
    # entry is no longer required to be keyed "w1".
    locspec = np.vstack([clm.spectrum() for clm in mtclm.values()])

    return np.dot(weights, locspec)


def ishtaraphro_mask(lats, lons, topogrid):
    """
    Creates a binary mask that masks out Ishtar Terra and western Aphrodite
    Terra. This mask is useful for exploring dynamic support on Venus at a
    global scale.

    A grid point is masked out (False) when it lies inside one of the
    latitude/longitude boxes below AND its topography value is positive:

    - Ishtar Terra: 50 < lat < 90 and (280 < lon <= 360 or 0 <= lon < 90)
    - western Aphrodite Terra: -20 < lat < 8 and 50 < lon < 140

    Every other grid point is True.

    Parameters
    ==========
    lats: numpy array
        grid latitudes (degrees)
    lons: numpy array
        grid longitudes (degrees, 0-360)
    topogrid: SHGrid
        topography grid; its ``data`` array is compared against 0 and must
        broadcast with a (len(lats), len(lons)) array

    Returns
    =======
    numpy array:
        2d boolean mask, True where data are kept
    """
    latgrid, longrid = np.meshgrid(lats, lons, indexing="ij")

    # Ishtar Terra box (straddles the 0/360 longitude seam)
    ishtar = (
        (50 < latgrid)
        & (latgrid < 90)
        & (((280 < longrid) & (longrid <= 360)) | ((0 <= longrid) & (longrid < 90)))
    )

    # western Aphrodite Terra box
    aphrodite = (-20 < latgrid) & (latgrid < 8) & (50 < longrid) & (longrid < 140)

    # Only highland points (topography > 0) inside the boxes are excluded.
    excluded = (ishtar | aphrodite) & (np.asarray(topogrid.data) > 0)

    return ~excluded
