#! /usr/bin/env python
# ==========================================================================
# Create SEDs figure for paper
#
# Copyright (C) 2024-2025 Juergen Knoedlseder
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# ==========================================================================
import os
import sys
import math
import gammalib
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D


# ============ #
# Get filename #
# ============ #
def get_filename(band, brems, ic, prefix='butterfly', ia='ia-g-plaw', dge='logps',
                 prefitted=False, source='ia', filetype='.fits'):
    """
    Get result filename

    Builds the name of a result file (butterfly, spectrum, results, ...
    depending on ``prefix``) from the analysis configuration.

    Parameters
    ----------
    band : str
        Energy band ('low', 'high', 'full')
    brems : str
        Bremsstrahlungs component ('none','map','gmap')
    ic : str
        Inverse Compton component ('none','map','gmap',[2D name])
    prefix : str, optional
        Filename prefix
    ia : str, optional
        In-flight annihilation component string
    dge : str, optional
        DGE spectral fitting string
    prefitted : boolean, optional
        Use prefitted results?
    source : str, optional
        Source string appended to the filename (nothing is appended if None)
    filetype : str, optional
        File type (filename extension including the dot)

    Returns
    -------
    filename : str
        Result filename. If the file does not exist but a file with the same
        name exists in the results/ subdirectory, that path is returned.

    Raises
    ------
    ValueError
        If band is not one of 'low', 'high' or 'full'
    """
    # Set energy band string
    band_strings = {'low':  '00750-03000keV_4bins_wo26Al',
                    'high': '03000-30000keV_4bins',
                    'full': '16bins_wo26Al'}
    if band not in band_strings:
        raise ValueError('Unknown energy band "%s", expected "low", "high" '
                         'or "full"' % (band,))
    energies = band_strings[band]

    # Set DGE map string: 'map' or 'gmap' if any component uses it, otherwise
    # the IC component string itself
    if brems == 'map' or ic == 'map':
        dge_map = 'map'
    elif brems == 'gmap' or ic == 'gmap':
        dge_map = 'gmap'
    else:
        dge_map = ic

    # Set component suffix
    suffix = dge_map
    if brems != 'none':
        suffix += '-brems-map'
    if ic != 'none':
        suffix += '-ic-map'

    # Set fit mode string
    fitmode = '_prefitted' if prefitted else ''
    if source is not None:
        fitmode += '_%s' % (source)

    # Build filename
    filename = '%s_com_240x140_drwnorm2_%s_conv-%s_%s-' \
               'g14-3c9-3c3-p18-p05-cr-ve-x1-ca-ls-%s-' \
               'isofix_%s_bgdlixf_ni11%s%s' % (prefix, energies, dge_map, ia, suffix, dge, fitmode, filetype)

    # If file does not exist then try prepending results/
    if not os.path.isfile(filename):
        alt_filename = 'results/%s' % (filename)
        if os.path.isfile(alt_filename):
            filename = alt_filename

    # Return filename
    return filename


# ===================================== #
# Extract the spectrum info from a file #
# ===================================== #
def get_spectrum_file(filename, source='IA'):
    """
    Extract the spectrum info of a source from a file for plotting

    A FITS file is read with get_spectrum_fits() if the 'OBJECT' keyword of
    its extension 1 equals the source name. Any other file is loaded with
    gammalib.GModels and the spectral component of the named source is
    passed to get_spectrum_model().

    Parameters
    ----------
    filename : str
        Name of spectrum FITS file or model file
    source : str, optional
        Source name

    Returns
    -------
    spec : dict or None
        Python dictionary defining spectral plot parameters, or None if the
        file does not hold the requested source
    """
    # FITS file: only accept the table if it belongs to the source
    if gammalib.GFilename(filename).is_fits():
        fits = gammalib.GFits(filename)
        if fits.table(1).string('OBJECT') != source:
            return None
        return get_spectrum_fits(fits)

    # Model file: only accept it if it contains the source
    models = gammalib.GModels(filename)
    if not models.contains(source):
        return None
    return get_spectrum_model(models[source].spectral())


# ============================================= #
# Extract the spectrum info from a GFits object #
# ============================================= #
def get_spectrum_fits(fits):
    """
    Extract the spectrum info from a GFits object

    A row of the table in extension 1 becomes a flux point if its Test
    Statistic exceeds 1 and its flux error is smaller than its flux, and an
    upper limit otherwise. Energy columns are multiplied by 1e6 (TeV -> MeV).

    Parameters
    ----------
    fits : `~gammalib.GFits`
        Spectral GFits object

    Returns
    -------
    spec : dict
        Python dictionary defining spectral plot parameters
    """
    # Read spectrum columns
    table    = fits.table(1)
    c_energy = table['e_ref']
    c_ed     = table['e_min']
    c_eu     = table['e_max']
    c_flux   = table['ref_e2dnde']
    c_norm   = table['norm']
    c_eflux  = table['norm_err']
    c_upper  = table['norm_ul']
    c_ts     = table['ts']

    # Initialise arrays to be filled
    spec = {
        'energies'    : [],
        'flux'        : [],
        'ed_engs'     : [],
        'eu_engs'     : [],
        'e_flux'      : [],
        'ul_energies' : [],
        'ul_ed_engs'  : [],
        'ul_eu_engs'  : [],
        'ul_flux'     : [],
        'yerr'        : [],
    }

    # Loop over rows of the table
    for row in range(table.nrows()):

        # Get Test Statistic and flux (reference E^2 dN/dE scaled by the
        # normalisation); the error column is applied as factor on the flux
        ts    = c_ts.real(row)
        flx   = c_norm.real(row) * c_flux.real(row)
        e_flx = flx * c_eflux.real(row)

        # Get energy values multiplied by 1e6
        energy = c_energy.real(row)*1.0e6
        ed_eng = c_ed.real(row)*1.0e6
        eu_eng = c_eu.real(row)*1.0e6

        # If Test Statistic is larger than 1 and flux error is smaller than
        # flux then append flux point ...
        if ts > 1.0 and e_flx < flx:
            spec['energies'].append(energy)
            spec['flux'].append(flx)
            spec['ed_engs'].append(ed_eng)
            spec['eu_engs'].append(eu_eng)
            spec['e_flux'].append(e_flx)

        # ... otherwise append upper limit (upper limit column applied as
        # factor on the flux)
        else:
            spec['ul_energies'].append(energy)
            spec['ul_flux'].append(flx*c_upper.real(row))
            spec['ul_ed_engs'].append(ed_eng)
            spec['ul_eu_engs'].append(eu_eng)

    # Set upper limit errors (60% of the upper limit flux)
    spec['yerr'] = [0.6 * x for x in spec['ul_flux']]

    # Return dictionary
    return spec


# ============================================ #
# Extract the spectrum from GModelSpectralBins #
# ============================================ #
def get_spectrum_model(model):
    """
    Extract the spectrum info from a spectral bins model

    For each bin, the intensity and its error are multiplied by
    emean^2 * gammalib.MeV2erg, where emean is the geometric mean of the bin
    edges in MeV. The energy extents below/above emean are stored as the
    lower/upper energy errors. No upper limits are produced.

    Parameters
    ----------
    model : `~gammalib.GModelSpectralBins`
        Spectral bins model

    Returns
    -------
    spec : dict
        Python dictionary defining spectral plot parameters
    """
    # Initialise arrays to be filled
    spec = {
        'energies'    : [],
        'flux'        : [],
        'ed_engs'     : [],
        'eu_engs'     : [],
        'e_flux'      : [],
        'ul_energies' : [],
        'ul_ed_engs'  : [],
        'ul_eu_engs'  : [],
        'ul_flux'     : [],
        'yerr'        : [],
    }

    # Loop over bins
    for ibin in range(model.bins()):
        emin  = model.emin(ibin).MeV()
        emax  = model.emax(ibin).MeV()
        emean = math.sqrt(emin*emax)
        norm  = emean * emean * gammalib.MeV2erg
        spec['energies'].append(emean)
        spec['flux'].append(model.intensity(ibin) * norm)
        spec['ed_engs'].append(emean - emin)
        spec['eu_engs'].append(emax - emean)
        spec['e_flux'].append(model.error(ibin) * norm)

    # Return dictionary
    return spec


# ================================== #
# Read butterfly data from FITS file #
# ================================== #
def read_butterfly_fits(filename):
    """
    Read butterfly data from FITS file

    The 'BUTTERFLY' table is converted into a closed polygon: the upper edge
    of the confidence band in row order, followed by the lower edge in
    reverse row order (lower values floored at 1e-26). Energies are
    multiplied by 1e6 (TeV -> MeV).

    Parameters
    ----------
    filename : str
        Name of FITS file

    Returns
    -------
    butterfly : dict
        Python dictionary defining butterfly plot and best fit spectrum
    """
    # Open FITS file and get butterfly table (keep the GFits object alive
    # while the table is in use)
    fits  = gammalib.GFits(filename)
    table = fits.table('BUTTERFLY')

    # Get relevant columns
    c_energy        = table['ENERGY']
    c_intensity     = table['INTENSITY']
    c_intensity_min = table['INTENSITY_MIN']
    c_intensity_max = table['INTENSITY_MAX']
    rows            = range(table.nrows())

    # Energies and per-row conversion factors (TeV -> erg)
    x_values = [c_energy[r]*1.0e6 for r in rows]
    factors  = [c_energy[r] * c_energy[r] * 1.0e6 * gammalib.MeV2erg for r in rows]

    # Best fit line and edges of the confidence band
    line_y  = [c_intensity[r] * factors[r] for r in rows]
    upper_y = [c_intensity_max[r] * factors[r] for r in rows]
    lower_y = [max(c_intensity_min[r] * factors[r], 1e-26) for r in rows]

    # Assemble polygon (upper edge forward, lower edge backward)
    return {'butterfly_x' : x_values + x_values[::-1],
            'butterfly_y' : upper_y + lower_y[::-1],
            'line_x'      : list(x_values),
            'line_y'      : line_y}


# ================================= #
# Read butterfly data from XML file #
# ================================= #
def read_butterfly_xml(filename, emin=0.75, emax=30.0, n=100, source='IA'):
    """
    Read butterfly data from XML file

    The spectral model of the source is evaluated at n energies between emin
    and emax. The confidence band is the envelope of all models obtained by
    shifting every free parameter by plus or minus its error
    (2**nfree combinations).

    Parameters
    ----------
    filename : str
        Name of model file loaded with gammalib.GModels
    emin : float, optional
        Minimum energy (MeV)
    emax : float, optional
        Maximum energy (MeV)
    n : int, optional
        Number of energies used for plotting
    source : str, optional
        Source name

    Returns
    -------
    butterfly : dict
        Python dictionary defining butterfly plot and best fit spectrum
    """
    # Set flux conversion factor
    conv = gammalib.MeV2erg

    # Set energies
    energies = gammalib.GEnergies(n, gammalib.GEnergy(emin, 'MeV'),
                                     gammalib.GEnergy(emax, 'MeV'))

    # Get a copy of the spectral model
    container = gammalib.GModels(filename)
    model     = container[source].spectral().copy()

    # Compute E^2 dN/dE multiplied by MeV2erg for a spectral model
    def e2dnde(spec_model, energy):
        return spec_model.eval(energy) * energy.MeV() * energy.MeV() * conv

    # Set best fit line
    line_x = [energy.MeV() for energy in energies]
    line_y = [e2dnde(model, energy) for energy in energies]

    # Extract names of all free parameters in model
    pars = [par.name() for par in model if par.is_free()]

    # Build one model per parameter change combination; bit j of the
    # combination index selects +error (set) or -error (unset) for par j
    variants = []
    for i in range(1 << len(pars)):
        mod = model.copy()
        for j, par in enumerate(pars):
            mod[par].remove_range()
            if i & (1 << j):
                mod[par].value(model[par].value()+model[par].error())
            else:
                mod[par].value(model[par].value()-model[par].error())
        variants.append(mod)

    # Determine envelope of all model variants at each energy
    y_min = []
    y_max = []
    for energy in energies:
        values = [e2dnde(mod, energy) for mod in variants]
        y_min.append(min(values))
        y_max.append(max(values))

    # Build butterfly polygon (upper edge forward, lower edge backward)
    butterfly = {'butterfly_x' : line_x + line_x[::-1],
                 'butterfly_y' : y_max + y_min[::-1],
                 'line_x'      : line_x,
                 'line_y'      : line_y}

    # Return butterfly dictionary
    return butterfly


# ============== #
# Plot butterfly #
# ============== #
def plot_butterfly(ax, linecolor='red', color='red', hatch=None, source='IA', linestyle='-', alpha=0.5):
    """
    Plot butterfly diagram of the fitted spectral model of a source

    The butterfly file for the full energy band, no bremsstrahlung and the
    'icmaps358003-507006' IC component is used. If it does not exist, its
    name is printed and nothing is plotted.

    Parameters
    ----------
    ax : pyplot
        Plotting frame
    linecolor : str, optional
        Color of best fit line
    color : str, optional
        Fill color of butterfly
    hatch : str, optional
        Hatching of butterfly
    source : str, optional
        Source name
    linestyle : str, optional
        Linestyle of best fit line
    alpha : float, optional
        Alpha for butterfly
    """
    # Model string and butterfly filename
    ia_string = '%s-%s-%s-%s' % ('gc-l-0.12+b0.66', 'plaw', 'ia-g-l1.4-b0.2r3.3', 'logp')
    filename  = get_filename('full', 'none', 'icmaps358003-507006', prefix='butterfly',
                             ia=ia_string, prefitted=False, source=gammalib.tolower(source))

    # Report missing file and give up
    if not os.path.isfile(filename):
        print('--- %s' % filename)
        return

    # Read butterfly from FITS or XML file
    is_fits = '.fits' in filename
    data    = read_butterfly_fits(filename) if is_fits else read_butterfly_xml(filename, source=source)

    # Draw confidence band, then best fit line behind other artists
    ax.fill(data['butterfly_x'], data['butterfly_y'], color=color, alpha=alpha, hatch=hatch)
    ax.plot(data['line_x'], data['line_y'], linestyle=linestyle, color=linecolor, zorder=-3)


# ======== #
# Plot SED #
# ======== #
def plot_sed(ax, color='red', source='IA', marker='None', markersize=None, linewidth=None, label=None):
    """
    Plot SED

    The SED is read from the spectrum_bins file for the full energy band, no
    bremsstrahlung and the 'icmaps358003-507006' IC component. Flux points
    are drawn as error bars, upper limits as error bars with uplims=True.
    If the file is missing or does not contain the source, a message is
    printed and nothing is plotted.

    Parameters
    ----------
    ax : pyplot
        Plotting frame
    color : str, optional
        Color of SED
    source : str, optional
        Source name (the spectrum of 'IA' is used if the name contains 'IA',
        otherwise the spectrum of 'GC')
    marker : str, optional
        Marker style
    markersize : float, optional
        Marker size
    linewidth : float, optional
        Line width
    label : str, optional
        Label
    """
    # Set SED string
    sed = '%s-%s-%s-%s' % ('gc-l-0.12+b0.66', 'bins3', 'ia-g-l1.4-b0.2r3.3', 'bins8')

    # Get SED filename
    filename = get_filename('full', 'none', 'icmaps358003-507006', prefix='spectrum_bins',
                            ia=sed, prefitted=False, source=gammalib.tolower(source))

    # Indicate that file is missing and return
    if not os.path.isfile(filename):
        print('--- %s' % filename)
        return

    # Read SED
    name = 'IA' if 'IA' in source else 'GC'
    spec = get_spectrum_file(filename, source=name)

    # get_spectrum_file() returns None if the file does not hold the source
    if spec is None:
        print('--- No spectrum found for source "%s" in file "%s"' % (name, filename))
        return

    # Plot error bars
    ax.errorbar(spec['energies'], spec['flux'],
                yerr=spec['e_flux'], xerr=[spec['ed_engs'], spec['eu_engs']],
                marker=marker, color=color, linestyle='None',
                markersize=markersize, linewidth=linewidth, label=label)

    # Plot upper limits
    if spec['ul_energies']:
        ax.errorbar(spec['ul_energies'], spec['ul_flux'], yerr=spec['yerr'],
                    xerr=[spec['ul_ed_engs'], spec['ul_eu_engs']],
                    uplims=True, marker=marker, color=color, linestyle='None',
                    markersize=markersize, linewidth=linewidth)


# =============================== #
# Plot systematic uncertainty box #
# =============================== #
def plot_systematics(ax, source='IA', color='red', alpha=0.4, width=0.025, display='bars', spectrum=True):
    """
    Plot systematic uncertainty box

    The spectrum of the source is read for a set of alternative diffuse
    Galactic emission (DGE) components. For each energy, the range of fluxes
    spanned by all DGE variants is plotted.

    Parameters
    ----------
    ax : pyplot
        Plotting frame
    source : string, optional
        Source (a leading '68_' prefix is removed for the spectrum lookup)
    color : str, optional
        Colour
    alpha : float, optional
        Alpha
    width : float, optional
        Relative half-width of the uncertainty boxes (display other than 'bars')
    display : str, optional
        Display style ('box' or 'bars')
    spectrum : boolean, optional
        Use spectrum instead of results
    """
    # Set SED string
    sed = '%s-%s-%s-%s' % ('gc-l-0.12+b0.66', 'bins3', 'ia-g-l1.4-b0.2r3.3', 'bins8')

    # Set different DGEs
    dges = [{'brems': 'none', 'ic': 'icmaps358003-507006'},
            {'brems': 'none', 'ic': 'icmaps60-06_60-02'},
            {'brems': 'none', 'ic': 'map'},
            {'brems': 'none', 'ic': 'gmap'},
            {'brems': 'map',  'ic': 'map'},
            {'brems': 'gmap', 'ic': 'gmap'}]

    # Set source name for spectrum lookup. Only the literal '68_' prefix is
    # removed; str.lstrip('68_') would strip any leading '6', '8' or '_'.
    name = gammalib.toupper(source)
    if name.startswith('68_'):
        name = name[len('68_'):]

    # Initialise energies and minimum/maximum fluxes
    energies    = None
    flux_min    = None
    flux_max    = None
    ul_energies = []
    ul_ed_engs  = []
    ul_eu_engs  = []

    # Loop over DGEs
    for dge in dges:

        # Get result filename
        if spectrum:
            filename = get_filename('full', dge['brems'], dge['ic'], prefix='spectrum_bins', ia=sed,
                                    source=gammalib.tolower(source), filetype='.fits')
        else:
            filename = get_filename('full', dge['brems'], dge['ic'], prefix='results', ia=sed,
                                    source=None, filetype='.xml')

        # Skip DGE if file is missing
        if not os.path.isfile(filename):
            print('--- %s' % filename)
            continue

        # Get spectrum and skip DGE if no spectrum was found
        spec = get_spectrum_file(filename, source=name)
        if spec is None:
            print('--- No spectrum found for source "%s" in file "%s"' % (source,filename))
            continue

        # Build vectors from results
        if spectrum:
            if display == 'bars':
                spec_flux_min = [f - e for f, e in zip(spec['flux'], spec['e_flux'])]
                spec_flux_max = [f + e for f, e in zip(spec['flux'], spec['e_flux'])]
            else:
                spec_flux_min = list(spec['flux'])
                spec_flux_max = list(spec['flux'])
            spec_energies  = list(spec['energies']) + list(spec['ul_energies'])
            spec_flux_min += spec['ul_flux']
            spec_flux_max += spec['ul_flux']
            ul_energies.extend(spec['ul_energies'])
            ul_ed_engs.extend(spec['ul_ed_engs'])
            ul_eu_engs.extend(spec['ul_eu_engs'])

            # Flux points come before upper limits, so sort by energy to make
            # index i refer to the same energy bin for all DGEs
            points        = sorted(zip(spec_energies, spec_flux_min, spec_flux_max),
                                   key=lambda point: point[0])
            spec_energies = [point[0] for point in points]
            spec_flux_min = [point[1] for point in points]
            spec_flux_max = [point[2] for point in points]
        else:
            spec_energies = spec['energies']
            spec_flux_min = spec['flux']
            spec_flux_max = spec['flux']

        # Determine minimum and maximum flux values
        if flux_min is None:
            energies = spec_energies
            flux_min = list(spec_flux_min)
            flux_max = list(spec_flux_max)
        else:
            for i, flux in enumerate(spec_flux_min):
                if flux < flux_min[i]:
                    flux_min[i] = flux
            for i, flux in enumerate(spec_flux_max):
                if flux > flux_max[i]:
                    flux_max[i] = flux

    # Nothing to plot if no spectrum was found
    if energies is None:
        return

    # Plot uncertainty bars
    if display == 'bars':
        for i, energy in enumerate(energies):

            # For IA, bins above the third are drawn as upper limits
            # (presumably the IA upper-limit bins - TODO confirm). Fall back
            # to a plain bar if the energy is not a recorded upper limit.
            if 'IA' in source and i > 2 and energy in ul_energies:
                k          = ul_energies.index(energy)
                energy_min = ul_ed_engs[k]
                energy_max = ul_eu_engs[k]
                yerr_min   = 0.6 * flux_min[i]
                yerr_max   = 0.6 * flux_max[i]
                eb1        = ax.errorbar([energy], [flux_min[i]], xerr=[[energy_min], [energy_max]],
                                         yerr=[yerr_min], uplims=True, color=color, alpha=alpha)
                eb2        = ax.errorbar([energy], [flux_max[i]], xerr=[[energy_min], [energy_max]],
                                         yerr=[yerr_max], uplims=True, color=color, alpha=alpha)
                eb1[-1][0].set_linestyle('--')
                eb1[-1][1].set_linestyle('--')
                eb2[-1][0].set_linestyle('--')
                eb2[-1][1].set_linestyle('--')
            else:
                ax.plot([energy, energy], [flux_min[i], flux_max[i]], linestyle='dashed', color=color, alpha=0.5)

    # ... otherwise plot uncertainty boxes
    else:
        for i, energy in enumerate(energies):
            x = [(1.0-width)*energy, (1.0+width)*energy, (1.0+width)*energy, (1.0-width)*energy, (1.0-width)*energy]
            y = [flux_min[i], flux_min[i], flux_max[i], flux_max[i], flux_min[i]]
            if alpha > 0.0:
                ax.fill(x, y, color=color, alpha=alpha, linewidth=0)
            else:
                ax.plot(x, y, color=color, linewidth=1)


# ========================= #
# Show SED figure for paper #
# ========================= #
def show_seds():
    """
    Show SED figure for paper

    Draws the SEDs, the systematic uncertainties and the butterflies of the
    fitted models for the IA ('Bulge', red) and GC ('Core', blue)
    components on log-log axes, displays the figure and saves it as
    'fig3.pdf'.
    """
    # Figure with a single plotting frame
    fig = plt.figure(figsize=(6,4))
    fig.subplots_adjust(left=0.13, bottom=0.12, right=0.99, top=0.98)
    ax = fig.add_subplot(111)

    # SEDs, then systematic uncertainties (drawing order matters)
    for name, colour in (('68_IA', 'red'), ('68_GC', 'blue')):
        plot_sed(ax, source=name, color=colour)
    for name, colour in (('68_IA', 'red'), ('68_GC', 'blue')):
        plot_systematics(ax, source=name, color=colour)

    # Butterflies of the fitted models
    plot_butterfly(ax, source='IA', linecolor='red',  color='tomato',     linestyle='-')
    plot_butterfly(ax, source='GC', linecolor='blue', color='dodgerblue', linestyle='--', hatch='xxxxx')

    # Legend with proxy line artists
    legend_lines = [Line2D([0], [0], marker='None', color='red',  linestyle='-',  label='Bulge'),
                    Line2D([0], [0], marker='None', color='blue', linestyle='--', label='Core')]
    ax.legend(handles=legend_lines, loc='lower right', fontsize=9)

    # Axis scales, ranges and labels
    ax.set_xscale('log')
    ax.set_yscale('log')
    ax.set_xlim([0.75,30.0])
    ax.set_ylim([1e-11,1e-9])
    ax.set_xlabel('Energy (MeV)')
    ax.set_ylabel(r'E$^2$ $\times$ dN/dE (erg cm$^{-2}$ s$^{-1}$)')

    # Display, then write the retained figure object to file
    plt.show()
    fig.savefig('fig3.pdf', dpi=300)


# ======================== #
# Main routine entry point #
# ======================== #
if __name__ == '__main__':

    # Create and show the SED figure for the paper when run as a script
    show_seds()
