#!/usr/bin/env python

# $HOME/code/desihub/fastspecfit/sandbox/spectra-for-samyak

def _model_wavelengths(hdr):
    """Rebuild the wavelength grid of the MODELS image from its header.

    The grid is linear: ``CRVAL1 + k * CDELT1`` for ``k`` in ``[0, NAXIS1)``.

    """
    import numpy as np

    return hdr['CRVAL1'] + np.arange(hdr['NAXIS1']) * hdr['CDELT1']


def _require_matching_grid(modelwave, wave):
    """Raise ValueError unless the model and coadd wavelength grids agree.

    The data (flux, ivar) live on the coadd grid while the wavelength column
    that gets written is the model grid, so the two must be the same for the
    output rows to be meaningful.

    """
    import numpy as np

    if modelwave.shape != wave.shape or not np.allclose(modelwave, wave):
        raise ValueError('MODELS wavelength grid does not match the coadded '
                         'spectrum wavelength grid.')


def _line_spectrum(modelwave, flux, ivar, transmission, model):
    """Build the (npix, 4) output array for one object.

    Columns are: wavelength, dust-corrected flux minus continuum,
    dust-corrected inverse variance, and the line model.

    `model` is the (3, npix) MODELS slice of one object; planes 0 and 1 are
    summed and treated as the continuum and plane 2 is written as the line
    model.  NOTE(review): the physical meaning of the three planes is taken
    from how the original script used them -- confirm against the fastspecfit
    data model.

    """
    import numpy as np

    flux = flux / transmission
    ivar = ivar * transmission**2  # ivar scales as 1 / flux**2

    continuum = model[0, :] + model[1, :]
    linemodel = model[2, :]

    return np.array([modelwave, flux - continuum, ivar, linemodel]).T


def main(sample=True):
    """Write continuum-subtracted spectra and line models to text files.

    Parameters
    ----------
    sample : bool, optional
        If True (the default), select up to 100 objects with
        ``HALPHA_EW > 5`` and ``HALPHA_EW * sqrt(HALPHA_EW_IVAR) > 5``, write
        one four-column text file per object, and write a summary file
        listing ``TARGETID Z filename`` for each of them.  If False, write
        (and plot) the spectrum of a single hard-coded example target.

    Raises
    ------
    ValueError
        If the model and coadd wavelength grids differ, or if a requested
        target is missing from the fastspec file or from the coadded spectra.

    """
    import os
    import numpy as np
    import fitsio
    from astropy.table import Table

    from desiutil.dust import dust_transmission
    from desispec.io import read_spectra
    from desispec.coaddition import coadd_cameras

    specfile = '/global/cfs/cdirs/desi/spectro/redux/iron/healpix/sv1/bright/70/7020/coadd-sv1-bright-7020.fits'
    fastfile = '/global/cfs/cdirs/desi/spectro/fastspecfit/iron/healpix/sv1/bright/70/7020/fastspec-sv1-bright-7020.fits.gz'

    # Both modes need the model image and its wavelength grid.
    models, hdr = fitsio.read(fastfile, 'MODELS', header=True)
    modelwave = _model_wavelengths(hdr)

    if sample:
        outdir = '/global/u2/i/ioannis/for-samyak/'
        nmax = 100

        fast = Table(fitsio.read(fastfile, 'FASTSPEC', columns=['TARGETID', 'Z', 'HALPHA_EW', 'HALPHA_EW_IVAR']))

        ew = np.asarray(fast['HALPHA_EW'])
        snr = ew * np.sqrt(np.asarray(fast['HALPHA_EW_IVAR']))
        rows = np.where((ew > 5.) & (snr > 5.))[0][:nmax]
        if rows.size == 0:
            print('No objects pass the H-alpha selection; nothing to write.')
            return

        meta = Table(fitsio.read(fastfile, 'METADATA', rows=rows))

        # Keep the object axis even for a single match (no squeeze), so that
        # models[ii] is always the (3, npix) slice of object ii.
        models = models[rows, :, :]

        targetids = meta['TARGETID'].data

        spec = read_spectra(specfile).select(targets=targetids)
        coadd_spec = coadd_cameras(spec)
        band = coadd_spec.bands[0]
        wave = coadd_spec.wave[band]
        _require_matching_grid(modelwave, wave)

        # Match spectra to fastspec rows on TARGETID instead of assuming the
        # two files list the objects in the same order.
        # NOTE(review): assumes TARGETID is unique in the coadd fibermap.
        spec_row = {tid: row for row, tid in
                    enumerate(np.asarray(coadd_spec.fibermap['TARGETID']).tolist())}

        os.makedirs(outdir, exist_ok=True)

        outfiles = []
        for ii, targetid in enumerate(targetids.tolist()):
            jj = spec_row.get(targetid)
            if jj is None:
                raise ValueError(f'TARGETID {targetid} not found in {specfile}')

            transmission = dust_transmission(wave, meta['EBV'][ii])
            out = _line_spectrum(modelwave, coadd_spec.flux[band][jj, :],
                                 coadd_spec.ivar[band][jj, :], transmission,
                                 models[ii])

            outfile = os.path.join(outdir, f'fastspec-sv1-bright-{targetid}.txt')
            outfiles.append(os.path.basename(outfile))

            print(f'Writing {outfile}')
            np.savetxt(outfile, out)

        # write out the summary file
        summaryfile = os.path.join(outdir, 'fastspec-sample.txt')
        print(f'Writing {summaryfile}')
        with open(summaryfile, 'w') as F:
            for targetid, zobj, outfile in zip(targetids, meta['Z'].data, outfiles):
                F.write(f'{targetid} {zobj} {outfile}\n')
    else:
        # Only the plotting mode needs matplotlib.
        import matplotlib.pyplot as plt

        # https://fastspecfit.desi.lbl.gov/target/sv1-bright-7020-39633324653677602?index=38
        # z=0.1282432
        targetid = 39633324653677602
        txtfile = 'ioannis/tmp/fastspec-example.txt'
        pngfile = 'ioannis/tmp/fastspec-example.png'

        targetids = fitsio.read(fastfile, 'METADATA', columns='TARGETID')
        rows = np.where(targetid == targetids)[0]
        if rows.size == 0:
            raise ValueError(f'TARGETID {targetid} not found in {fastfile}')

        meta = Table(fitsio.read(fastfile, 'METADATA', rows=rows))

        spec = read_spectra(specfile).select(targets=targetid)
        coadd_spec = coadd_cameras(spec)
        band = coadd_spec.bands[0]
        wave = coadd_spec.wave[band]
        _require_matching_grid(modelwave, wave)

        transmission = dust_transmission(wave, meta['EBV'][0])
        out = _line_spectrum(modelwave, coadd_spec.flux[band][0, :],
                             coadd_spec.ivar[band][0, :], transmission,
                             models[rows[0], :, :])

        os.makedirs(os.path.dirname(txtfile), exist_ok=True)
        np.savetxt(txtfile, out)

        # quick plot, read back from the text file that was just written
        wave, flux, ivar, model = np.loadtxt(txtfile, unpack=True)

        fig, ax = plt.subplots(figsize=(10, 6))
        ax.plot(wave, flux, color='gray', alpha=0.7, label='Flux')
        ax.plot(wave, model, label='Model', ls='-', color='blue')
        ax.set_ylabel(r'Flux Density ($10^{-17}~{\rm erg}~{\rm s}^{-1}~{\rm cm}^{-2}~\AA^{-1}$)')
        ax.set_xlabel(r'Observed-frame Wavelength ($\AA$)')
        ax.legend(fontsize=8, loc='upper right')
        fig.tight_layout()
        fig.savefig(pngfile)
# Run the export only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()


