#!/usr/bin/env python

import os, pdb
import numpy as np

def writeout(outfile, wave, flux, ivar, modelflux,
             emlineflux, fastspec, fastmeta):
    """Write spectra, models, and fitting tables to a multi-extension FITS file.

    The extensions are written in this order: FLUX (primary HDU), IVAR,
    WAVE, MODELFLUX, EMLINEFLUX, METADATA, and FASTSPEC. An existing
    `outfile` is overwritten and checksums are added to every HDU.

    Parameters
    ----------
    outfile : str
        Path of the FITS file to create.
    wave : array
        Wavelengths; stored as float64 and tagged with BUNIT='Angstrom'
        and AIRORVAC='vac'.
    flux : array
        Must be two-dimensional (its shape is unpacked into two values);
        the first dimension is reported as the number of spectra.
        Stored as float32.
    ivar, modelflux, emlineflux : array
        Stored as float32 image extensions.
    fastspec, fastmeta : table
        Tables convertible with astropy's ``table_to_hdu``.

    """
    from astropy.io import fits

    nspec, _ = flux.shape

    def _image(data, dtype, extname, hducls=fits.ImageHDU):
        # Cast the array, wrap it in an HDU, and label the extension.
        hdu = hducls(data.astype(dtype))
        hdu.header['EXTNAME'] = extname
        return hdu

    primary = _image(flux, 'f4', 'FLUX', hducls=fits.PrimaryHDU)
    ivar_hdu = _image(ivar, 'f4', 'IVAR')

    # The wavelength extension carries two extra header cards.
    wave_hdu = _image(wave, 'f8', 'WAVE')
    wave_hdu.header['BUNIT'] = 'Angstrom'
    wave_hdu.header['AIRORVAC'] = ('vac', 'vacuum wavelengths')

    model_hdu = _image(modelflux, 'f4', 'MODELFLUX')
    emline_hdu = _image(emlineflux, 'f4', 'EMLINEFLUX')

    # Binary tables: metadata first, then the fitting results.
    table_hdus = []
    for table, extname in ((fastmeta, 'METADATA'), (fastspec, 'FASTSPEC')):
        hdu = fits.convenience.table_to_hdu(table)
        hdu.header['EXTNAME'] = extname
        table_hdus.append(hdu)

    everything = fits.HDUList([primary, ivar_hdu, wave_hdu, model_hdu,
                               emline_hdu] + table_hdus)
    everything.writeto(outfile, overwrite=True, checksum=True)
    print('Writing {} spectra to {}'.format(nspec, outfile))

def main():
    """Rebuild the best-fitting models for the stacked spectra and write them out.

    Reads the stacked fastspec file for the hard-coded target class from
    ``$DESI_ROOT/users/ioannis/desi-templates``, rebuilds the continuum and
    emission-line model of every object, and passes the data, the models,
    and the input tables to `writeout`.

    Raises
    ------
    RuntimeError
        If the DESI_ROOT environment variable is not set.

    """
    targetclass = 'elg'

    # Fail early, before the heavy project imports, with a clear message;
    # otherwise os.path.join(None, ...) raises an opaque TypeError.
    desi_root = os.getenv('DESI_ROOT')
    if desi_root is None:
        raise RuntimeError('The DESI_ROOT environment variable must be set.')

    from fastspecfit.templates.templates import rebuild_fastspec_spectrum, read_stacked_fastspec
    from fastspecfit.continuum import ContinuumFit
    from fastspecfit.emlines import EMLineFit

    templatedir = os.path.join(desi_root, 'users', 'ioannis', 'desi-templates')
    fastspecfile = os.path.join(templatedir, '{}-fastspec.fits'.format(targetclass))

    CFit = ContinuumFit()
    EMFit = EMLineFit()

    fastwave, fastflux, fastivar, fastmeta, fastspec = read_stacked_fastspec(fastspecfile)

    # Pixels that are never filled in below stay at zero.
    modelflux = np.zeros_like(fastflux, dtype='f4')
    emlineflux = np.zeros_like(fastflux, dtype='f4')
    nobj = len(fastmeta)

    # rebuild the best-fitting spectrum
    for ii in range(nobj):
        # Only the continuum and emission-line models are used here; the
        # other three return values are discarded.
        _, continuum, _, emlinemodel, _ = rebuild_fastspec_spectrum(
            fastspec[ii], fastwave, fastflux[ii, :], fastivar[ii, :], CFit, EMFit)

        # NOTE(review): this assumes the rebuilt models are defined only on
        # the pixels with positive inverse variance -- confirm against
        # rebuild_fastspec_spectrum.
        good = np.where(fastivar[ii, :] > 0)[0]
        emlineflux[ii, good] = emlinemodel
        modelflux[ii, good] = continuum + emlinemodel

    outfile = '/global/cscratch1/sd/ioannis/{}-fastspec-models.fits'.format(targetclass)
    writeout(outfile, fastwave, fastflux, fastivar,
             modelflux, emlineflux, fastspec, fastmeta)

# Run the model rebuild only when executed as a script, not on import.
if __name__ == '__main__':
    main()
    
