#!/usr/bin/env python

"""Make the spectrum underlying the FastSpecFit logo.

fastspec /global/cfs/cdirs/desi/spectro/redux/guadalupe/healpix/main/bright/104/10460/redrock-main-bright-10460.fits --targetids 39633005802685491 --outfile fastspec-example.fits
make-logo

"""
import os, argparse, pdb
import numpy as np
from fastspecfit.io import DESISpectra, read_fastspecfit, DESI_ROOT_NERSC
from fastspecfit.fastspecfit import FastFit
from fastspecfit.util import C_LIGHT

def makelogo(data, fastspec, metaspec, Fit, outdir=None):
    """Plot the observed spectrum and best-fitting model used for the logo.

    For every camera the observed flux and the total model (emission lines +
    aperture-uncorrected continuum + smooth continuum) are drawn on a single
    panel with the axes switched off, and the figure is written to
    ``<outdir>/fastspecfit-spectrum-logo.png``.

    Parameters
    ----------
    data : dict
        Spectroscopic data for one object. The keys read here are 'wave',
        'flux', 'ivar', 'res', 'mask', 'linemask', 'cameras' and
        'camerapix'; all but 'camerapix' are indexed per camera.
    fastspec : table row or mapping
        Fitting results for the same object. 'APERCORR', 'Z', 'VDISP' and
        'COEFF' are read directly; the whole row is also handed to
        ``Fit.emlinemodel_bestfit``.
    metaspec : table row or mapping
        Metadata for the object. Currently unused; kept so that existing
        callers do not need to change.
    Fit : object
        Fitting object providing ``templates2data``, ``smooth_continuum``,
        ``emlinemodel_bestfit``, ``templateflux_nolines`` and
        ``templatewave``.
    outdir : str, optional
        Directory in which the PNG is written (created if needed).
        Defaults to the current directory.

    """
    import matplotlib.pyplot as plt
    from matplotlib import colors
    import seaborn as sns

    from fastspecfit.util import ivar2var

    sns.set(context='talk', style='ticks', font_scale=1.5)

    # One color per camera: lighter shades for the data, darker for the model.
    # NOTE(review): only three entries, so at most three cameras are supported.
    datacolors = [colors.to_hex(col) for col in ['dodgerblue', 'darkseagreen', 'orangered']]
    modelcolors = [colors.to_hex(col) for col in ['darkblue', 'darkgreen', 'darkred']]

    if outdir is None:
        outdir = '.'
    os.makedirs(outdir, exist_ok=True)
    pngfile = os.path.join(outdir, 'fastspecfit-spectrum-logo.png')

    apercorr = fastspec['APERCORR']
    redshift = fastspec['Z']
    ncam = len(data['cameras'])

    # Rebuild the best-fitting spectroscopic model; prefix "desi" means
    # "per-camera" and prefix "full" has the cameras h-stacked.
    fullwave = np.hstack(data['wave'])

    desicontinuum, _ = Fit.templates2data(Fit.templateflux_nolines, Fit.templatewave,
                                          redshift=redshift, synthphot=False,
                                          specwave=data['wave'], specres=data['res'],
                                          specmask=data['mask'], cameras=data['cameras'],
                                          vdisp=fastspec['VDISP'],
                                          coeff=fastspec['COEFF'])

    # Remove the aperture correction.
    desicontinuum = [camcontinuum / apercorr for camcontinuum in desicontinuum]

    # Need to be careful we don't pass a large negative residual where
    # there are gaps in the data, so zero the residual wherever flux == 0.
    # NOTE(review): the original mask tested flux == 0 twice; the second test
    # may have been meant for the inverse variance -- confirm.
    desiresiduals = []
    for icam in range(ncam):
        camflux = data['flux'][icam]
        resid = camflux - desicontinuum[icam]
        resid[camflux == 0.0] = 0.0
        desiresiduals.append(resid)

    # With all-zero coefficients there is no continuum model to correct.
    if np.all(fastspec['COEFF'] == 0):
        fullsmoothcontinuum = np.zeros_like(fullwave)
    else:
        fullsmoothcontinuum, _ = Fit.smooth_continuum(
            fullwave, np.hstack(desiresiduals), np.hstack(data['ivar']),
            redshift=redshift, linemask=np.hstack(data['linemask']))

    # Split the stacked smooth continuum back into per-camera pieces.
    desismoothcontinuum = [fullsmoothcontinuum[campix[0]:campix[1]]
                           for campix in data['camerapix']]

    # Per-camera emission-line model.
    desiemlines = Fit.emlinemodel_bestfit(data['wave'], data['res'], fastspec)

    fig, specax = plt.subplots(figsize=(14, 6))

    for icam in range(ncam):
        wave = data['wave'][icam]
        flux = data['flux'][icam]
        modelflux = desiemlines[icam] + desicontinuum[icam] + desismoothcontinuum[icam]

        # Only the good-pixel selection is needed; the uncertainties are not plotted.
        _, camgood = ivar2var(data['ivar'][icam], sigma=True, allmasked_ok=True, clip=0)

        # Trim the last 35 good pixels of the second camera (assumed to be
        # camera 'r' -- this relies on the camera ordering in `data`).
        if icam == 1:
            camgood = np.where(camgood)[0][:-35]

        # Wavelengths are divided by 1e4 before plotting.
        specax.plot(wave[camgood] / 1e4, flux[camgood],
                    color=datacolors[icam], alpha=1.0)
        specax.plot(wave[camgood] / 1e4, modelflux[camgood],
                    color=modelcolors[icam], lw=3, alpha=0.8)

    # The logo is just the curves: no axes, ticks, or labels.
    specax.axis('off')

    fig.subplots_adjust(left=0.01, right=0.99, top=0.99, bottom=0.01)

    print('Writing {}'.format(pngfile))
    fig.savefig(pngfile)
    plt.close(fig)

if __name__ == '__main__':

    # Command-line interface: only the output directory is configurable.
    parser = argparse.ArgumentParser()
    parser.add_argument('-o', '--outdir', type=str, default='.', help='output directory')
    args = parser.parse_args()

    # Fitting results produced by the `fastspec` call quoted in the module
    # docstring; the file is expected in the current working directory.
    fastspec, metaspec, coadd_type, _ = read_fastspecfit('fastspec-example.fits')

    # Identify the (single) object from the first row of the results table.
    specprod = 'guadalupe'
    survey = fastspec['SURVEY'][0]
    program = fastspec['PROGRAM'][0]
    healpix = fastspec['HEALPIX'][0]
    targetid = fastspec['TARGETID'][0]

    # Build the path to the redrock file for this healpixel, rooted at
    # $DESI_ROOT when set and at the NERSC default otherwise.
    desi_root = os.environ.get('DESI_ROOT', DESI_ROOT_NERSC)
    redux_dir = os.path.join(desi_root, 'spectro', 'redux')
    redrockname = 'redrock-{}-{}-{}.fits'.format(survey, program, healpix)
    redrockfile = os.path.join(
        redux_dir, specprod, 'healpix', str(survey), str(program),
        str(healpix // 100), str(healpix), redrockname)

    # Module-level fitting object; it is also passed explicitly to makelogo.
    FFit = FastFit()

    # Read and unpack the spectra for just this target.
    Spec = DESISpectra(redux_dir=redux_dir)
    Spec.select(redrockfile, targetids=[targetid])
    data = Spec.read_and_unpack(FFit, fastphot=False, synthphot=False)

    makelogo(data[0], fastspec[0], metaspec[0], FFit, outdir=args.outdir)
