#!/usr/bin/env python
"""fastspecfit stacking code

fastspecfit-stack --specfitfile /global/cfs/cdirs/desi/users/ioannis/fastspecfit/blanc/tiles/specfit-80607-80613-20201214-20201223.fits --outdir /global/cfs/cdirs/desi/users/ioannis/fastspecfit/blanc/tiles/qa/stacks --tile 80608 --mp 32 --ntargets 5

"""
import pdb # for debugging
import os, sys, time
import numpy as np

from desiutil.log import get_logger
# Module-level logger used by parse() and main().
log = get_logger()

# Point MPLCONFIGDIR at a freshly created temporary directory. This happens
# at import time, before anything in this file imports matplotlib.
# NOTE(review): presumably a workaround for an unwritable or shared default
# matplotlib config directory -- confirm. The temporary directory is not
# removed anywhere in this file.
import tempfile
os.environ['MPLCONFIGDIR'] = tempfile.mkdtemp()

def parse(options=None):
    """Parse the command-line arguments of fastspecfit-stack.

    Parameters
    ----------
    options : list of str, optional
        Argument strings to parse. When None, the arguments are taken from
        ``sys.argv`` by argparse.

    Returns
    -------
    argparse.Namespace
        The parsed arguments. The full command line is also written to the
        module logger at INFO level.

    """
    import sys
    import argparse

    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Selection of the objects that go into the stack.
    parser.add_argument('--tile', default=None, type=str, nargs='*',
                        help='Only stack objects on this tile (one or more tile IDs).')
    parser.add_argument('--night', default=None, type=str, nargs='*',
                        help='Only stack objects observed on this night (one or more nights).')
    parser.add_argument('--targetids', type=str, default=None,
                        help='Comma-separated list of target IDs to process.')
    parser.add_argument('-n', '--ntargets', type=int,
                        help='Number of targets to process in each file.')
    parser.add_argument('--firsttarget', type=int, default=0,
                        help='Index of first object to process in each file (0-indexed).')
    parser.add_argument('--mp', type=int, default=1,
                        help='Number of multiprocessing processes per MPI rank or node.')

    # Output location.
    parser.add_argument('--suffix', default=None, type=str,
                        help='Optional suffix for output filename.')
    parser.add_argument('-o', '--outdir', default=None, type=str,
                        help='Full path to desired output directory.')

    # Input catalog.
    parser.add_argument('--specfitfile', default=None, type=str,
                        help='Full path to specfit fitting results.')

    if options is None:
        args = parser.parse_args()
        log.info(' '.join(sys.argv))
    else:
        args = parser.parse_args(options)
        log.info('fastspecfit-stack {}'.format(' '.join(options)))

    return args

def main(args=None, comm=None):
    """Resample ELG spectra onto a common rest-frame grid and write them out.

    The steps are:

    1. Read the ``RESULTS`` table (and the ``SPECPROD`` header keyword) from
       ``args.specfitfile``.
    2. Optionally restrict the table to the requested tiles, nights and/or
       target IDs.
    3. Keep objects with the ``ELG`` bit set in ``SV1_DESI_TARGET`` and
       positive ``OII_3726_FLUX`` and ``OII_3729_FLUX``.
    4. Read the corresponding spectra and resample each coadded spectrum,
       shifted to the rest frame with ``zredrock``, onto a fixed wavelength
       grid. Objects that cannot be resampled are dropped.
    5. Write ``rest-elgs[-TILE][-SUFFIX].fits`` in ``args.outdir`` (current
       directory if not given) with the extensions FLUX, WAVE, SPECFIT,
       ZBEST, FIBERMAP and IVAR.

    Parameters
    ----------
    args : list, tuple, None or argparse.Namespace
        A list/tuple/None is passed to :func:`parse`; anything else is used
        as the already-parsed arguments.
    comm : optional
        Not used by this function.

    Raises
    ------
    IOError
        If ``args.specfitfile`` is not an existing file.

    """
    import fitsio
    from astropy.table import Table
    from astropy.io import fits

    from desispec.interpolation import resample_flux
    from desitarget.sv1.sv1_targetmask import desi_mask

    from fastspecfit.continuum import ContinuumFit
    from fastspecfit.emlines import EMLineFit
    from fastspecfit.external.desi import DESISpectra

    if isinstance(args, (list, tuple, type(None))):
        args = parse(args)

    # Read the fitting results.
    if args.specfitfile is None:
        log.warning('Must provide --specfitfile.')
        return

    def _check_and_read(filename):
        """Return the RESULTS table of `filename` and its SPECPROD keyword."""
        if not os.path.isfile(filename):
            msg = 'File {} not found.'.format(filename)
            log.warning(msg)
            raise IOError(msg)
        fastfit, hdr = fitsio.read(filename, header=True, ext='RESULTS')
        specprod = hdr['SPECPROD']
        fastfit = Table(fastfit)
        log.info('Read {} objects from {}'.format(len(fastfit), filename))
        return fastfit, specprod

    specfit, specprod = _check_and_read(args.specfitfile)

    # parse the targetids optional input
    if args.targetids:
        targetids = [int(x) for x in args.targetids.split(',')]
    else:
        targetids = args.targetids

    # optionally trim to a particular tile and/or night
    def _select_tiles_nights_targets(fastfit, tiles=None, nights=None, targetids=None):
        """Return the rows of `fastfit` matching every selection that is given.

        Within one selection (e.g. several tiles) a row matching any value is
        kept; the different selections are combined with a logical AND.
        """
        if len(fastfit) == 0 or not (tiles or nights or targetids):
            return fastfit
        keep = np.ones(len(fastfit), bool)
        if targetids:
            keep &= np.isin(np.asarray(fastfit['TARGETID']), targetids)
        if tiles:
            # Tiles and nights arrive from the command line as strings.
            keep &= np.isin(np.asarray(fastfit['TILEID']).astype(str), tiles)
        if nights:
            keep &= np.isin(np.asarray(fastfit['NIGHT']).astype(str), nights)
        return fastfit[keep]

    specfit = _select_tiles_nights_targets(specfit, tiles=args.tile, nights=args.night, targetids=targetids)

    # Initialize the continuum- and emission-line fitting classes.
    # NOTE(review): EMFit is not referenced again in this function.
    t0 = time.time()
    CFit = ContinuumFit()
    EMFit = EMLineFit()
    log.info('Initializing the classes took: {:.2f} sec'.format(time.time()-t0))

    Spec = DESISpectra()

    # select ELGs
    ielg = np.where(((specfit['SV1_DESI_TARGET'] & desi_mask.mask('ELG')) != 0) *
                    (specfit['OII_3726_FLUX'] > 0) *
                    (specfit['OII_3729_FLUX'] > 0))[0]
    specfit = specfit[ielg]

    Spec.find_specfiles(fastfit=specfit, specprod=specprod, targetids=targetids,
                        firsttarget=args.firsttarget, ntargets=args.ntargets)
    if len(Spec.specfiles) == 0:
        return
    data = Spec.read_and_unpack(CFit, synthphot=True, remember_coadd=True)

    # sort specfit so it matches 'data', 'Spec.fibermap', and 'Spec.zbest'
    refindx = ['{}-{}-{}'.format(night, tile, tid) for night, tile, tid in
               zip(Spec.fibermap['FIRST_NIGHT'], Spec.fibermap['FIRST_TILEID'], Spec.fibermap['TARGETID'])]
    # Map each night-tile-targetid key to the specfit row(s) carrying it, so
    # the reordering is a dictionary lookup rather than an all-pairs compare.
    rows_by_key = {}
    for irow, (night, tile, tid) in enumerate(zip(specfit['NIGHT'], specfit['TILEID'], specfit['TARGETID'])):
        rows_by_key.setdefault('{}-{}-{}'.format(night, tile, tid), []).append(irow)
    # An explicit integer array is used so that an empty result still selects
    # rows (an empty list would be interpreted as a list of column names).
    order = np.array([irow for key in refindx for irow in rows_by_key.get(key, ())], dtype=int)
    specfit = specfit[order]

    # Common rest-frame wavelength grid.
    wave = np.arange(1200.0, 8000.0, 0.5)
    npix = len(wave)

    # Build the stacks!
    ngal = len(specfit)
    flux = np.zeros((ngal, npix))
    ivar = np.zeros((ngal, npix))
    keep = np.ones(ngal, bool)
    t0 = time.time()
    pad = 10.0 # stay this far inside the rest-frame coverage of each spectrum
    for igal, onegal in enumerate(data):
        cwave = onegal['coadd_wave'] / (1 + onegal['zredrock'])
        I = np.where((wave > (cwave.min()+pad)) * (wave < (cwave.max()-pad)))[0]
        if len(I) <= 5:
            log.info('Rejecting object {}'.format(igal))
            keep[igal] = False
            continue
        try:
            _flux, _ivar = resample_flux(wave[I], cwave, onegal['coadd_flux'], ivar=onegal['coadd_ivar'])
        except Exception as err:
            # Best effort: drop the object and carry on with the rest.
            log.warning('Rejecting object {}: {}'.format(igal, err))
            keep[igal] = False
        else:
            # Pixels outside I keep flux=0 and ivar=0. No per-object flux
            # normalization is applied.
            flux[igal, I] = _flux
            ivar[igal, I] = _ivar
    log.info('Done in {:.2f} sec'.format(time.time()-t0))

    reject = np.where(np.logical_not(keep))[0]
    keep = np.where(keep)[0]
    if len(reject) > 0:
        log.info('Rejected {} of {} objects.'.format(len(reject), ngal))

    flux = flux[keep, :]
    ivar = ivar[keep, :]
    specfit = specfit[keep]
    zbest = Spec.zbest[keep]
    fibermap = Spec.fibermap[keep]

    # write out
    t0 = time.time()
    outdir = args.outdir or '.'
    os.makedirs(outdir, exist_ok=True)

    # The file name is tagged with the first requested tile (if any) and the
    # optional --suffix.
    outname = 'rest-elgs'
    if args.tile:
        outname += '-{}'.format(args.tile[0])
    suffix = getattr(args, 'suffix', None)
    if suffix:
        outname += '-{}'.format(suffix)
    outfile = os.path.join(outdir, outname + '.fits')

    hduflux = fits.PrimaryHDU(flux.astype('f4'))
    hduflux.header['EXTNAME'] = 'FLUX'

    hduivar = fits.ImageHDU(ivar.astype('f4'))
    hduivar.header['EXTNAME'] = 'IVAR'

    hduwave = fits.ImageHDU(wave.astype('f8'))
    hduwave.header['EXTNAME'] = 'WAVE'
    hduwave.header['BUNIT'] = 'Angstrom'
    hduwave.header['AIRORVAC'] = ('vac', 'vacuum wavelengths')

    hduspecfit = fits.convenience.table_to_hdu(specfit)
    hduspecfit.header['EXTNAME'] = 'SPECFIT'

    hduzbest = fits.convenience.table_to_hdu(zbest)
    hduzbest.header['EXTNAME'] = 'ZBEST'

    hdufmap = fits.convenience.table_to_hdu(fibermap)
    hdufmap.header['EXTNAME'] = 'FIBERMAP'

    # IVAR is appended last so that the positional indices of the other
    # extensions stay the same for readers that address them by number.
    hx = fits.HDUList([hduflux, hduwave, hduspecfit, hduzbest, hdufmap, hduivar])

    log.info('Writing {}'.format(outfile))
    hx.writeto(outfile, overwrite=True)
    log.info('Time to write is {:.2f} sec'.format(time.time()-t0))

# Script entry point: main() parses the command line itself when called
# without arguments.
if __name__ == '__main__':
    main()
