#! /usr/bin/env python
# ==========================================================================
# Shows TS maps
#
# Copyright (C) 2024-2025 Juergen Knoedlseder
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# ==========================================================================
import os
import sys
import math
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.gridspec as gridspec
from matplotlib.patches import Circle
from matplotlib.offsetbox import AuxTransformBox
from matplotlib.offsetbox import AnchoredOffsetbox
import gammalib
import cscripts


# ============ #
# Get filename #
# ============ #
def get_filename(band, brems, ic, prefix='tsmap', ia='ia-g-plaw', dge='plaws', bkg='_bgdlixf',
                 method='', suffix='_prefitted_fixptsrc_200x50', path='tsmaps/'):
    """
    Get TS map filename

    The filename encodes the energy band, the Galactic ridge emission templates,
    the in-flight annihilation component, the DGE spectral fitting string, the
    background model and the background method.

    Parameters
    ----------
    band : str
        Energy band ('low', 'high', 'full')
    brems : str
        Bremsstrahlungs component ('none','map','gmap')
    ic : str
        Inverse Compton component ('none','map','gmap',[2D name])
    prefix : str, optional
        Filename prefix
    ia : str, optional
        In-flight annihilation component string
    dge : str, optional
        DGE spectral fitting string
    bkg : str, optional
        Background model suffix; an empty string also drops the '_drwnorm2' tag
    method : str, optional
        Background method prefix (e.g. '', 'drw-const_')
    suffix : str, optional
        Suffix
    path : str, optional
        Optional path in which file may reside

    Returns
    -------
    filename : str
        FITS filename. The filename is prefixed by ``path`` only if it does
        not exist in the current directory but exists after prefixing.

    Raises
    ------
    ValueError
        If ``band`` is not one of 'low', 'high' or 'full'
    """
    # Energy band tags used in the filenames
    energy_tags = {'low':  '00750-03000keV_4bins_wo26Al',
                   'high': '03000-30000keV_4bins',
                   'full': '16bins_wo26Al'}
    if band not in energy_tags:
        raise ValueError('Unknown energy band "%s". Specify one of low/high/full.' % band)
    energies = energy_tags[band]

    # Convolution tag: 'map' takes precedence over 'gmap'; otherwise the IC
    # component name is used as is
    if brems == 'map' or ic == 'map':
        conv = 'map'
    elif brems == 'gmap' or ic == 'gmap':
        conv = 'gmap'
    else:
        conv = ic

    # Template string, extended by each component that is present
    templates = conv
    if brems != 'none':
        templates += '-brems-map'
    if ic != 'none':
        templates += '-ic-map'

    # The DRW normalisation tag only appears together with a background model
    drwnorm = '' if bkg == '' else '_drwnorm2'

    # Set filename
    filename = '%s_com_240x140%s_%s_%sconv-%s_%s-' \
               'g14-3c9-3c3-p18-p05-cr-ve-x1-ca-ls-%s-' \
               'isofix_%s%s_ni11%s.fits' % (prefix, drwnorm, energies, method, conv, ia, templates,
                                            dge, bkg, suffix)

    # If file does not exist then try prepending path
    if not os.path.isfile(filename):
        alt_filename = '%s%s' % (path, filename)
        if os.path.isfile(alt_filename):
            filename = alt_filename

    # Return filename
    return filename


# ================== #
# Draw filled circle #
# ================== #
def draw_circle(ax, radius=4.0):
    """
    Draw a filled black circle in the lower-left corner of a frame

    The circle is drawn through the data transform of the frame, hence its
    radius is expressed in data units; the container holding it is anchored
    in the lower-left corner without padding or frame.

    Parameters
    ----------
    ax : pyplot
        Frame into which the circle is drawn
    radius : float, optional
        Circle radius in data units
    """
    disk      = Circle((10, 10), radius, fc='black')
    container = AuxTransformBox(ax.transData)
    container.add_artist(disk)
    ax.add_artist(AnchoredOffsetbox(child=container, loc='lower left', pad=0, frameon=False))


# =========== #
# Plot TS map #
# =========== #
def plot_tsmap(ax, filename, smooth=0.0, tsmin=None, tsmax=None, symmetric=False, cmap='coolwarm',
               xlabel=True, ylabel=True, linewidths=1, fontsize=7, labelsize=6):
    """
    Plot TS map together with TS contours

    The TS map is read from the first HDU of the FITS file and the prefactor
    map from the second HDU. The prefactor map is used to give pixels with a
    negative prefactor a negative TS value.

    Parameters
    ----------
    ax : pyplot
        Frame for map
    filename : str
        Input FITS filename
    smooth : float, optional
        Map smoothing parameter (degrees), applied to both image and contours
    tsmin : float, optional
        Minimum TS value
    tsmax : float, optional
        Maximum TS value
    symmetric : boolean, optional
        Symmetric color bar
    cmap : str, optional
        Color map
    xlabel : boolean, optional
        Show x-axis label
    ylabel : boolean, optional
        Show y-axis label
    linewidths : float, optional
        Linewidth for contours
    fontsize : float, optional
        Fontsize for labels
    labelsize : float, optional
        Label size for colorbar
    """
    # Extract TS map and prefactor map, then release the FITS file
    fits  = gammalib.GFits(filename)
    tsmap = gammalib.GSkyMap(fits[0])
    pmap  = gammalib.GSkyMap(fits[1])
    fits.close()

    # Plot TS map (smooth and cmap were previously accepted but ignored)
    plot_map(ax, tsmap, premap=pmap, smooth=smooth, tsmin=tsmin, tsmax=tsmax,
             symmetric=symmetric, cmap=cmap, xlabel=xlabel, ylabel=ylabel,
             fontsize=fontsize, labelsize=labelsize)

    # Plot contours
    plot_contours(ax, tsmap, premap=pmap, smooth=smooth, linewidths=linewidths)

    # Return
    return


# ======== #
# Plot map #
# ======== #
def plot_map(ax, inmap, premap=None, smooth=0.0, tsmin=None, tsmax=None, symmetric=False,
             cmap='coolwarm', xlabel=True, ylabel=True, fontsize=7, labelsize=6):
    """
    Plot map

    Parameters
    ----------
    ax : pyplot
        Frame for map
    inmap : `~gammalib.GSkyMap()`
        Input sky map
    premap : `~gammalib.GSkyMap()`, optional
        Prefactor map; for pixels with a negative prefactor the sign of the
        map value is inverted
    smooth : float, optional
        Map smoothing parameter (degrees); a disk kernel is used
    tsmin : float, optional
        Minimum TS value of the colour scale
    tsmax : float, optional
        Maximum TS value of the colour scale
    symmetric : boolean, optional
        Symmetric colour scale around zero; only used if neither tsmin nor
        tsmax is specified
    cmap : str, optional
        Color map
    xlabel : boolean, optional
        Show x-axis label
    ylabel : boolean, optional
        Show y-axis label
    fontsize : float, optional
        Font size for label
    labelsize : float, optional
        Label size for colorbar
    """
    # Optionally smooth map (on a copy, so the input map is left untouched)
    if smooth > 0.0:
        skymap = inmap.copy()
        skymap.smooth('DISK', smooth)
    else:
        skymap = inmap

    # Create array from skymap. The first row is the highest y pixel, which is
    # the top row for imshow with its default (upper) origin.
    nx      = skymap.nx()
    array   = []
    val_max = 0.0
    val_min = 0.0
    for iy in range(skymap.ny()-1, -1, -1):
        row = []
        for ix in range(nx):
            index = ix + iy*nx
            value = skymap[index]
            if premap is not None and premap[index] < 0.0:
                value = -value
            val_max = max(val_max, value)
            val_min = min(val_min, value)
            row.append(value)
        array.append(row)

    # Set minimum and maximum of the colour scale. Explicit limits take
    # precedence over the symmetric option.
    vmin = tsmin
    vmax = tsmax
    if tsmin is None and tsmax is None and symmetric:
        v_ext = max(-val_min, val_max)
        vmin  = -v_ext
        vmax  =  v_ext

    # Get skymap boundaries
    nx_max = skymap.nx() - 0.5
    ny_max = skymap.ny() - 0.5
    if inmap.projection().coordsys() == 'GAL':
        l_min  = skymap.pix2dir(gammalib.GSkyPixel(-0.5,ny_max/2)).l_deg()
        l_max  = skymap.pix2dir(gammalib.GSkyPixel(nx_max,ny_max/2)).l_deg()
        b_min  = skymap.pix2dir(gammalib.GSkyPixel(nx_max/2,-0.5)).b_deg()
        b_max  = skymap.pix2dir(gammalib.GSkyPixel(nx_max/2,ny_max)).b_deg()
        xtext  = 'Galactic longitude (deg)'
        ytext  = 'Galactic latitude (deg)'
    else:
        l_min  = skymap.pix2dir(gammalib.GSkyPixel(-0.5,ny_max/2)).ra_deg()
        l_max  = skymap.pix2dir(gammalib.GSkyPixel(nx_max,ny_max/2)).ra_deg()
        b_min  = skymap.pix2dir(gammalib.GSkyPixel(nx_max/2,-0.5)).dec_deg()
        b_max  = skymap.pix2dir(gammalib.GSkyPixel(nx_max/2,ny_max)).dec_deg()
        xtext  = 'Right Ascension (deg)'
        ytext  = 'Declination (deg)'

    # Unwrap longitudes so that l_max < l_min, i.e. longitude increases to the left
    if l_max > 180.0:
        l_min -= 360.0
        l_max -= 360.0
    if l_min < 0 and l_max < 0:
        l_min += 360.0
        l_max += 360.0
    if l_max >= l_min:
        l_max -= 360.0

    # Normalise to midpoint of zero. The colour scale limits are handed to the
    # norm itself, since matplotlib >= 3.5 rejects vmin/vmax passed to imshow
    # together with a norm instance.
    norm = MidpointNormalize(vmin=vmin, vmax=vmax, midpoint=0)

    # Show map
    c = ax.imshow(array, extent=(l_min,l_max,b_min,b_max), cmap=cmap, norm=norm)
    cbar = plt.colorbar(c, orientation='vertical', shrink=0.7, ax=ax, pad=0.02)
    cbar.ax.tick_params(labelsize=labelsize)
    ax.set_xlim([l_min,l_max])
    ax.set_ylim([b_min,b_max])
    if xlabel:
        ax.set_xlabel(xtext, fontsize=fontsize)
    if ylabel:
        ax.set_ylabel(ytext, fontsize=fontsize)

    # Return
    return


# ================================================ #
# Mid point normalisation for asymmetric TS values #
# ================================================ #
class MidpointNormalize(mcolors.Normalize):
    """
    Colour normalisation that maps ``midpoint`` onto the centre of the colour map

    The normalisation spans the symmetric interval [-v_ext, v_ext] with
    v_ext = max(|vmin|, |vmax|). Values are mapped piecewise linearly so that
    -v_ext -> 0, midpoint -> 0.5 and v_ext -> 1. Values outside the interval are
    always clipped to 0 or 1 (the ``clip`` argument of ``__call__`` is not used).
    """
    def __init__(self, vmin=None, vmax=None, midpoint=None, clip=False):
        """
        Parameters
        ----------
        vmin : float, optional
            Minimum value; autoscaled from the data if None
        vmax : float, optional
            Maximum value; autoscaled from the data if None
        midpoint : float, optional
            Value mapped onto the centre of the colour map; zero if None
        clip : boolean, optional
            Clipping flag forwarded to the base class
        """
        self.midpoint = midpoint
        super().__init__(vmin, vmax, clip)

    def _nodes(self):
        """
        Return the data values that are mapped onto 0, 0.5 and 1
        """
        v_ext    = np.max([np.abs(self.vmin), np.abs(self.vmax)])
        midpoint = 0.0 if self.midpoint is None else self.midpoint
        return [-v_ext, midpoint, v_ext]

    def __call__(self, value, clip=None):
        """
        Normalise value(s) to the interval [0, 1]
        """
        # Autoscale missing limits from the data, like the base class does
        if self.vmin is None or self.vmax is None:
            self.autoscale_None(value)

        # Interpolate and keep any mask of the input values
        result = np.interp(value, self._nodes(), [0.0, 0.5, 1.0])
        return np.ma.masked_array(result, mask=np.ma.getmask(value))

    def inverse(self, value):
        """
        Map normalised value(s) in [0, 1] back onto data values

        Colour bars use the inverse of the norm; without this method the linear
        inverse of the base class would be used, which is inconsistent with the
        forward mapping whenever vmin != -vmax.
        """
        return np.interp(value, [0.0, 0.5, 1.0], self._nodes())


# ============= #
# Plot contours #
# ============= #
def plot_contours(ax, inmap, premap=None, smooth=0.0, linestyle='solid', linewidths=1, color='black',
                  levels=None):
    """
    Plot contours

    Parameters
    ----------
    ax : pyplot
        Frame for map
    inmap : `~gammalib.GSkyMap()`
        Input sky map
    premap : `~gammalib.GSkyMap()`, optional
        Prefactor map; for pixels with a negative prefactor the sign of the
        map value is inverted
    smooth : float, optional
        Map smoothing parameter (degrees); a Gaussian kernel is used
    linestyle : str, optional
        Linestyle for contours
    linewidths : float, optional
        Linewidth for contours
    color : str, optional
        Colour for contours
    levels : list, optional
        Increasing contour levels. Defaults to 9, 16, 25, ..., 1089, i.e. the
        squares of 3 to 33 (sqrt(TS) = 3, 4, ..., 33).
    """
    # Default contour levels (fresh list per call instead of a shared mutable default)
    if levels is None:
        levels = [float(n*n) for n in range(3, 34)]

    # Nothing to draw without contour levels
    if len(levels) == 0:
        return

    # Optionally smooth map (on a copy, so the input map is left untouched)
    if smooth > 0.0:
        skymap = inmap.copy()
        skymap.smooth('GAUSSIAN', smooth)
    else:
        skymap = inmap

    # Create array from skymap. The first row is y pixel 0, which lies at the
    # lower boundary b_min of the extent used below.
    nx      = skymap.nx()
    array   = []
    val_max = 0.0
    for iy in range(skymap.ny()):
        row = []
        for ix in range(nx):
            index = ix + iy*nx
            value = skymap[index]
            if premap is not None and premap[index] < 0.0:
                value = -value
            val_max = max(val_max, value)
            row.append(value)
        array.append(row)

    # Get skymap boundaries
    nx_max = skymap.nx() - 0.5
    ny_max = skymap.ny() - 0.5
    if inmap.projection().coordsys() == 'GAL':
        l_min  = skymap.pix2dir(gammalib.GSkyPixel(-0.5,ny_max/2)).l_deg()
        l_max  = skymap.pix2dir(gammalib.GSkyPixel(nx_max,ny_max/2)).l_deg()
        b_min  = skymap.pix2dir(gammalib.GSkyPixel(nx_max/2,-0.5)).b_deg()
        b_max  = skymap.pix2dir(gammalib.GSkyPixel(nx_max/2,ny_max)).b_deg()
    else:
        l_min  = skymap.pix2dir(gammalib.GSkyPixel(-0.5,ny_max/2)).ra_deg()
        l_max  = skymap.pix2dir(gammalib.GSkyPixel(nx_max,ny_max/2)).ra_deg()
        b_min  = skymap.pix2dir(gammalib.GSkyPixel(nx_max/2,-0.5)).dec_deg()
        b_max  = skymap.pix2dir(gammalib.GSkyPixel(nx_max/2,ny_max)).dec_deg()

    # Unwrap longitudes so that l_max < l_min, i.e. longitude increases to the left
    if l_max > 180.0:
        l_min -= 360.0
        l_max -= 360.0
    if l_min < 0 and l_max < 0:
        l_min += 360.0
        l_max += 360.0
    if l_max >= l_min:
        l_max -= 360.0

    # Only draw contours if the map reaches the lowest contour level
    if val_max >= min(levels):
        ax.contour(array, levels, extent=(l_min,l_max,b_min,b_max),
                   linewidths=linewidths, linestyles=linestyle, colors=color)

    # Return
    return


# ========================= #
# Show main figure of paper #
# ========================= #
def show_tsmaps_main():
    """
    Show main figure of paper

    The 0.75-3 MeV TS map (panel a) is drawn above the 3-30 MeV TS map
    (panel b); TS map files that do not exist are skipped silently. The
    figure is displayed and then saved into 'fig1.pdf'.
    """
    # Create figure (183 mm x 95 mm, converted to inches)
    fig = plt.figure(figsize=(183.0 / 25.4, 95.0 / 25.4))
    fig.subplots_adjust(left=0.07, right=1.1, top=1.02, bottom=0.07, wspace=0.0, hspace=0.01)

    # Upper frame for the low band, lower frame for the high band
    ax_low  = fig.add_subplot(211)
    ax_high = fig.add_subplot(212)

    # Panel settings: (frame, TS map filename, show x label, panel label, PSF radius)
    panels = [(ax_low,  get_filename('low',  'none', 'icmap35.0-80.0-03.0'), False, 'a', 2.0),
              (ax_high, get_filename('high', 'none', 'icmap50.0-70.0-06.0'), True,  'b', 1.4)]

    # Draw the TS maps that are available
    for ax, tsmap_file, show_xlabel, _, _ in panels:
        if os.path.isfile(tsmap_file):
            plot_tsmap(ax, tsmap_file, xlabel=show_xlabel)

    # Add panel labels, PSF circles (68% containment radius) and tick label sizes
    for ax, _, _, panel_label, psf_radius in panels:
        ax.text(0.01, 0.97, panel_label, color='black', fontsize=8, weight='bold',
                horizontalalignment='left', verticalalignment='top',
                transform=ax.transAxes)
        draw_circle(ax, radius=psf_radius)
        ax.tick_params(axis='both', which='major', labelsize=6)
        ax.tick_params(axis='both', which='minor', labelsize=5)

    # Display the figure, then write it to disk
    plt.show()
    fig.savefig('fig1.pdf', dpi=300)


# ========================================================= #
# Show extended data figure for different background models #
# ========================================================= #
def show_tsmaps_bkg():
    """
    Show extended data figure for different background models

    Three background variants are shown as rows, each for the 0.75-3 MeV
    (left column) and 3-30 MeV (right column) band (panels a-f). Missing
    TS map files are reported by printing their name. The figure is displayed
    and then saved into 'extended-fig2.pdf'.
    """
    # Create figure (183 mm x 78 mm, converted to inches)
    fig = plt.figure(figsize=(183.0 / 25.4, 78.0 / 25.4))
    fig.subplots_adjust(left=0.055, right=1.03, top=1.02, bottom=0.06, wspace=0.01, hspace=0.01)

    # IC template per band and (background model, method) per row
    ic_maps  = {'low': 'icmap35.0-80.0-03.0', 'high': 'icmap50.0-70.0-06.0'}
    variants = [('_bgdlixf', ''), ('_bgdlixf', 'drw-const_'), ('', '')]

    # One panel per row and band, low band first
    panels  = [(band, bkg, method) for bkg, method in variants for band in ('low', 'high')]
    npanels = len(panels)

    for index, (band, bkg, method) in enumerate(panels):

        # Resolve TS map and report it if it is missing
        tsmap_file = get_filename(band, 'none', ic_maps[band], bkg=bkg, method=method)
        if not os.path.isfile(tsmap_file):
            print(tsmap_file)
            continue

        # Axis labels only on the left column and the bottom row
        left_column = (index % 2 == 0)
        bottom_row  = (index > npanels - 3)

        # Draw TS map
        ax = fig.add_subplot(3, 2, index + 1)
        plot_tsmap(ax, tsmap_file, xlabel=bottom_row, ylabel=left_column, linewidths=0.5,
                   fontsize=5, labelsize=5)

        # Panel label
        ax.text(0.01, 0.97, 'abcdef'[index], color='black', fontsize=8, weight='bold',
                horizontalalignment='left', verticalalignment='top',
                transform=ax.transAxes)

        # PSF circle (68% containment radius), larger in the left column
        draw_circle(ax, radius=2.0 if left_column else 1.4)

        # Tick label sizes
        ax.tick_params(axis='both', which='major', labelsize=5)
        ax.tick_params(axis='both', which='minor', labelsize=5)

    # Display the figure, then write it to disk
    plt.show()
    fig.savefig('extended-fig2.pdf', dpi=300)


# ====================================================================== #
# Show extended data figure for different Galactic ridge emission models #
# ====================================================================== #
def show_tsmaps_gre():
    """
    Show extended data figure for different Galactic ridge emission models

    Four template combinations (COMPASS/GALPROP, IC only or IC plus
    Bremsstrahlung) are shown as rows, each for the 0.75-3 MeV (left column)
    and 3-30 MeV (right column) band. Missing TS map files are reported by
    printing their name. The figure is displayed and then saved into 'fig2.pdf'.
    """
    # Create figure (183 mm x 104 mm, converted to inches)
    fig = plt.figure(figsize=(183.0 / 25.4, 104.0 / 25.4))
    fig.subplots_adjust(left=0.055, right=1.03, top=1.015, bottom=0.04, wspace=0.01, hspace=0.01)

    # Template combinations per row: (Bremsstrahlung, IC, panel label)
    models = [('none', 'map',  'COMPASS (IC)'),
              ('none', 'gmap', 'GALPROP (IC)'),
              ('map',  'map',  'COMPASS (IC & Bremsstrahlung)'),
              ('gmap', 'gmap', 'GALPROP (IC & Bremsstrahlung)')]

    # One panel per row and band, low band first
    panels  = [(band, brems, ic, text) for brems, ic, text in models for band in ('low', 'high')]
    npanels = len(panels)

    for index, (band, brems, ic, text) in enumerate(panels):

        # Resolve TS map and report it if it is missing
        tsmap_file = get_filename(band, brems, ic)
        if not os.path.isfile(tsmap_file):
            print(tsmap_file)
            continue

        # Axis labels only on the left column and the bottom row
        left_column = (index % 2 == 0)
        bottom_row  = (index > npanels - 3)

        # Draw TS map
        ax = fig.add_subplot(4, 2, index + 1)
        plot_tsmap(ax, tsmap_file, xlabel=bottom_row, ylabel=left_column, linewidths=0.5,
                   fontsize=5, labelsize=5)

        # Panel label
        ax.text(0.01, 0.97, text, color='black', fontsize=6, weight='bold',
                horizontalalignment='left', verticalalignment='top',
                transform=ax.transAxes)

        # PSF circle (68% containment radius), larger in the left column
        draw_circle(ax, radius=2.0 if left_column else 1.4)

        # Tick label sizes
        ax.tick_params(axis='both', which='major', labelsize=5)
        ax.tick_params(axis='both', which='minor', labelsize=5)

    # Display the figure, then write it to disk
    plt.show()
    fig.savefig('fig2.pdf', dpi=300)


# ============================================== #
# Show extended data figure for temporal TS maps #
# ============================================== #
def show_tsmaps_temporal():
    """
    Show extended data figure for temporal TS maps

    For four consecutive MJD periods the 0.75-3 MeV (left column) and 3-30 MeV
    (right column) TS maps are shown as panels a-h. Below them, the photon
    fluxes of the 'IA' model component, taken from the fit results of each
    period, are plotted versus MJD for both bands (panels i and j). Missing
    files are reported by printing their name. The figure is displayed and then
    saved into 'extended-fig4.pdf'.
    """
    # Create figure (183 mm x 130 mm, converted to inches)
    fig = plt.figure(figsize=(183.0 / 25.4, 130.0 / 25.4))
    fig.subplots_adjust(left=0.055, right=1.03, top=1.02, bottom=0.04, wspace=0.01, hspace=0.01)

    # MJD periods and per-band settings; 'erange' is the flux energy range in MeV
    periods = [(48393, 48926), (48926, 49459), (49459, 49992), (49992, 50526)]
    setups  = {'low':  {'ia': 'ia-g-l1.4-b0.2r3.3-plaw', 'ic': 'icmap35.0-80.0-03.0',
                        'erange': (0.75, 3.0)},
               'high': {'ia': 'ia-l-0.12+b0.66-plaw',    'ic': 'icmap50.0-70.0-06.0',
                        'erange': (3.0, 30.0)}}

    # One panel per period and band, low band first
    panels  = [{'band': band, 'start': start, 'stop': stop}
               for start, stop in periods for band in ('low', 'high')]
    npanels = len(panels)

    for index, panel in enumerate(panels):

        # Panel settings
        band        = panel['band']
        ia          = setups[band]['ia']
        ic          = setups[band]['ic']
        folder      = 'temporal/mjd%d-%d' % (panel['start'], panel['stop'])
        left_column = (index % 2 == 0)
        bottom_row  = (index > npanels - 3)

        # Draw TS map, or report it if it is missing
        tsmap_file = '%s/%s' % (folder, get_filename(band, 'none', ic, ia=ia, path=''))
        if os.path.isfile(tsmap_file):
            ax = fig.add_subplot(5, 2, index + 1)
            plot_tsmap(ax, tsmap_file, xlabel=bottom_row, ylabel=left_column, linewidths=0.5,
                       fontsize=5, labelsize=5)
            ax.text(0.01, 0.97, 'abcdefgh'[index], color='black', fontsize=8, weight='bold',
                    horizontalalignment='left', verticalalignment='top',
                    transform=ax.transAxes)
            draw_circle(ax, radius=2.0 if left_column else 1.4)
            ax.tick_params(axis='both', which='major', labelsize=5)
            ax.tick_params(axis='both', which='minor', labelsize=5)
        else:
            print(tsmap_file)

        # Locate the XML fit results of the same period and band
        results_file = '%s/%s' % (folder, get_filename(band, 'none', ic, ia=ia, path='',
                                                       prefix='results', suffix='_prefitted'))
        results_file = results_file.replace('.fits', '.xml')
        if not os.path.isfile(results_file):
            print('*** File "%s" not found.' % (results_file))
            continue

        # Attach fluxes of the IA component (if fitted) to the panel
        models = gammalib.GModels(results_file)
        if models.contains('IA'):
            model      = models['IA']
            e_lo, e_hi = setups[band]['erange']
            emin       = gammalib.GEnergy(e_lo, 'MeV')
            emax       = gammalib.GEnergy(e_hi, 'MeV')
            panel['pflux']       = model.flux(emin, emax)
            panel['pflux_error'] = model.flux_error(emin, emax)
            panel['eflux']       = model.eflux(emin, emax)
            panel['eflux_error'] = model.eflux_error(emin, emax)

    # Light curve frames in the bottom row of a 6x2 grid
    gs      = gridspec.GridSpec(6, 2, left=0.06, right=0.95, top=0.965, bottom=0.06,
                                wspace=0.2, hspace=0.2)
    ax_low  = fig.add_subplot(gs[5, 0])
    ax_high = fig.add_subplot(gs[5, 1])

    # Collect light curve points per band; fluxes are scaled by the power of
    # ten quoted in the y-axis label
    curves = {'low':  {'scale': 1.0e4, 'mjd': [], 'emjd': [], 'flux': [], 'eflux': []},
              'high': {'scale': 1.0e5, 'mjd': [], 'emjd': [], 'flux': [], 'eflux': []}}
    for panel in panels:
        if 'pflux' not in panel:
            continue
        curve = curves[panel['band']]
        curve['mjd'].append(0.5 * (panel['start'] + panel['stop']))
        curve['emjd'].append(0.5 * (panel['stop']  - panel['start']))
        curve['flux'].append(panel['pflux'] * curve['scale'])
        curve['eflux'].append(panel['pflux_error'] * curve['scale'])

    # Draw both light curves: (frame, curve, colour, panel label, y title, y range)
    frames = [(ax_low,  curves['low'],  'red',  'i',
               r'Flux (10$^{-4}$ cm$^{-2}$ s$^{-1}$)', [0.0, 4.0]),
              (ax_high, curves['high'], 'blue', 'j',
               r'Flux (10$^{-5}$ cm$^{-2}$ s$^{-1}$)', [-1.6, 7.0])]
    for ax, curve, colour, panel_label, ytitle, yrange in frames:
        ax.errorbar(curve['mjd'], curve['flux'], xerr=curve['emjd'], yerr=curve['eflux'],
                    marker='o', color=colour, linestyle='None', markersize=1.0, linewidth=1.0)
        ax.plot([48393,50526], [0.0,0.0], '-', color='black', linewidth=0.5)
        ax.text(0.01, 0.97, panel_label, color='black', fontsize=8, weight='bold',
                horizontalalignment='left', verticalalignment='top',
                transform=ax.transAxes)
        ax.set_xlabel('MJD (days)', fontsize=5)
        ax.set_ylabel(ytitle, fontsize=5)
        ax.set_xlim([48393,50526])
        ax.set_ylim(yrange)
        ax.tick_params(axis='both', which='major', labelsize=5)
        ax.tick_params(axis='both', which='minor', labelsize=5)

    # Display the figure, then write it to disk
    plt.show()
    fig.savefig('extended-fig4.pdf', dpi=300)


# ======================== #
# Main routine entry point #
# ======================== #
if __name__ == '__main__':

    # Set usage string (the figure options must match the dispatch table below)
    usage = 'show_paper_tmaps.py [-figure main/bkg/gre/temporal]'

    # Set default options
    options = [{'option': '-figure', 'value': 'main'}]

    # Get arguments and options from command line arguments
    args, options = cscripts.ioutils.get_args_options(options, usage)

    # Extract script parameters from options
    figure = options[0]['value']

    # Map figure option onto the function creating the figure
    figures = {'main':     show_tsmaps_main,
               'bkg':      show_tsmaps_bkg,
               'gre':      show_tsmaps_gre,
               'temporal': show_tsmaps_temporal}

    # Create requested figure, or fail with a non-zero exit status
    if figure in figures:
        figures[figure]()
    else:
        print('*** ERROR: Unknown figure option "%s". Specify one of main/bkg/gre/temporal.' % figure)
        sys.exit(1)
