#!/usr/bin/env python
"""Convert the FSPS SSPs from .spec to .FITS format.

source /global/cfs/cdirs/desi/software/desi_environment.sh master
./CKC2fits

Note that the SSPs are resampled to constant-velocity pixels (see the version
list below; the code currently uses 10 km/s) but, to save space, just every
fourth pixel longward of 1 micron is written out.  With 25 km/s pixels (roughly
the smallest pixel size in DESI at ~1 micron) this gives an effective pixel
spacing of ~100 km/s in the infrared and ~83,000 pixels per SSP.

version 1.0: 25 km/s pixels
version 1.1: 10 km/s pixels

"""
import os, pdb
import numpy as np

from scipy import constants
# Speed of light in km/s (scipy.constants.c is in m/s); used to turn a
# velocity pixel size into a logarithmic wavelength step.
C_LIGHT = constants.c / 1000.0 # [km/s]

# Top-level directory of the SSP templates: the input .spec files are read
# from its 'spec' subdirectory and the FITS files are written to a
# subdirectory named after `version`.
ssppath = '/global/cfs/cdirs/desi/external/templates/SSP-CKC14z'
# Output version tag; also stored in the VERSION keyword of the FLUX HDU.
version = 'v1.1'

def _read_row(fh, dtype=float):
    """Parse the next line of an open text file as a 1-D array.

    Parameters
    ----------
    fh : file object
        Text file positioned at the start of the row to read.
    dtype : data-type, optional
        Type the whitespace-separated fields are converted to.

    Returns
    -------
    numpy.ndarray
        The values found on the line.

    Raises
    ------
    ValueError
        If the file has ended, the line is blank, or a field cannot be
        converted to `dtype`.

    """
    fields = fh.readline().split()
    if not fields:
        raise ValueError('Unexpected end of file or blank row in {}'.format(
            getattr(fh, 'name', 'input file')))
    return np.array(fields, dtype=dtype)

def _CKC2fits(metallicity='Z0.0190'):
    """Convert one SSP .spec file to a multi-extension FITS file.

    The input file is
    ``<ssppath>/spec/SSP_Padova_CKC14z_Kroupa_<metallicity>.out.spec`` and the
    output is ``<ssppath>/<version>/SSP_Padova_CKC14z_Kroupa_<metallicity>.fits``
    (overwritten if it exists; the output directory is created if needed).

    The spectra are converted from Lsun/Hz to erg/s/cm2/A at 10 pc and
    resampled onto a grid with constant velocity pixels, keeping only every
    fourth pixel redward of ~1 micron.  The output file holds three HDUs:

    * FLUX (primary): float32 array filled as ``flux[pixel, age]``
    * WAVE: the resampled wavelength grid [Angstrom]
    * METADATA: table with one row per age (columns age, mstar, lbol)

    Parameters
    ----------
    metallicity : str
        Metallicity tag appearing in the input and output file names.

    Raises
    ------
    ValueError
        If the input wavelengths are not sorted or the file is truncated.

    """
    from astropy.io import fits
    import astropy.units as u
    from astropy.table import QTable, Column
    from desispec.interpolation import resample_flux

    library = 'CKC14z'
    isochrone = 'Padova' # would be nice to get MIST in here
    imf = 'Kroupa'

    sspfile = os.path.join(ssppath, 'spec', 'SSP_{}_{}_{}_{}.out.spec'.format(
        isochrone, library, imf, metallicity))
    outfile = os.path.join(ssppath, version, 'SSP_{}_{}_{}_{}.fits'.format(
        isochrone, library, imf, metallicity))

    nheader = 8 # number of informational header lines to skip

    # Read the file in a single sequential pass; re-opening it and skipping an
    # ever-growing number of rows for every record scales quadratically with
    # the number of ages.
    with open(sspfile) as fh:
        for _ in range(nheader):
            fh.readline()

        nage, _npix = (int(val) for val in _read_row(fh, dtype=int))

        wave = _read_row(fh) # vacuum wavelengths
        if not np.all(np.diff(wave) >= 0): # make sure wavelengths are sorted
            raise ValueError('Wavelengths in {} are not sorted'.format(sspfile))

        # Resample to constant log-lambda / velocity. In the IR (starting at ~1
        # micron), take every fourth sampling, to save space.
        pixkms = 10.0                            # pixel size [km/s]
        irfactor = 4
        wavesplit = 1e4
        dlogwave = pixkms / C_LIGHT / np.log(10) # pixel size [log-lambda]
        newwave = 10**np.arange(np.log10(np.min(wave)), np.log10(np.max(wave)), dlogwave)

        # The pixel closest to wavesplit is the last one kept at full sampling;
        # newwave[isplit] is therefore the first coarsely sampled pixel, both
        # before and after the decimation below.
        isplit = np.argmin(np.abs(newwave-wavesplit)) + 1
        newwave = np.hstack((newwave[:isplit], newwave[isplit:][::irfactor]))
        npix = len(newwave)

        ssp = QTable()
        ssp.add_column(Column(name='age', dtype='f4', length=nage))
        ssp.add_column(Column(name='mstar', dtype='f4', length=nage))
        ssp.add_column(Column(name='lbol', dtype='f4', length=nage))
        flux = np.zeros((npix, nage), dtype='f4')

        # Unit-conversion factors, which do not depend on the age.
        fnu2flam = 3.826e33 * C_LIGHT*1e13 / wave**2          # [Lsun/Hz] --> [erg/s/A]
        area10pc = 4.0 * np.pi * (10.0 * 3.085678e18)**2      # [cm2], sphere of radius 10 pc

        # Each age contributes two rows: a short row of properties followed by
        # the spectrum itself.
        for iage in range(nage):
            # the fourth property is read but not used
            logage, logmstar, lbol, _unused = _read_row(fh)
            _flux = _read_row(fh)

            _flux *= fnu2flam
            _flux /= area10pc # [erg/s/cm2/A] at 10 pc
            flux[:, iage] = resample_flux(newwave, wave, _flux)

            ssp['age'][iage] = 10**logage # [yr]
            ssp['mstar'][iage] = 10**logmstar # [Msun]
            # NOTE(review): age and mstar are exponentiated but lbol is stored
            # exactly as read -- confirm whether this column of the .spec file
            # is linear or log10(Lbol/Lsun).
            ssp['lbol'][iage] = lbol # [Lsun]

    ssp['age'].unit = u.yr
    ssp['mstar'].unit = u.solMass
    ssp['lbol'].unit = u.solLum

    # The image HDUs carry their units in BUNIT, so the plain arrays are
    # written directly.
    hduflux = fits.PrimaryHDU(flux)
    hduflux.header['EXTNAME'] = 'FLUX'
    hduflux.header['VERSION'] = version
    hduflux.header['BUNIT'] = 'erg/(s cm2 Angstrom)'

    hduwave = fits.ImageHDU(newwave)
    hduwave.header['EXTNAME'] = 'WAVE'
    hduwave.header['BUNIT'] = 'Angstrom'
    hduwave.header['AIRORVAC'] = ('vac', 'vacuum wavelengths')
    hduwave.header['PIXSZBLU'] = (pixkms, 'pixel size blueward of PIXSZSPT [km/s]')
    hduwave.header['PIXSZRED'] = (irfactor*pixkms, 'pixel size redward of PIXSZSPT [km/s]')
    hduwave.header['PIXSZSPT'] = (float(newwave[isplit]), 'wavelength where pixel size changes [Angstrom]')

    hdutable = fits.convenience.table_to_hdu(ssp)
    hdutable.header['EXTNAME'] = 'METADATA'

    hx = fits.HDUList([hduflux, hduwave, hdutable])

    # A new version tag means a new output directory; create it rather than
    # failing after all the work has been done.
    os.makedirs(os.path.dirname(outfile), exist_ok=True)

    print('Writing {}'.format(outfile))
    hx.writeto(outfile, overwrite=True)
    
if __name__ == '__main__':
    # Metallicity tags of all the SSP files to convert; to process a single
    # file, loop over a one-element slice such as metallicities[6:7] instead.
    metallicities = (
        'Z0.0003', 'Z0.0006', 'Z0.0012', 'Z0.0025',
        'Z0.0049', 'Z0.0096', 'Z0.0190', 'Z0.0300',
    )
    for ztag in metallicities:
        _CKC2fits(ztag)
