#!/usr/bin/env python

'''Build a set of FSPS templates for use with FastSpecFit.

* Chabrier, Salpeter, or Kroupa IMF
* MIST isochrones
* C3K_a stellar library
* DL07 dust emission spectra
* 7 variable-width age bins between 30 Myr and 13.7 Gyr with constant star
  formation within each age bin
* 3 stellar metallicity values between -1 and 0.3 (relative to solar)
* 8 dust attenuation values between zero and roughly 3 mag
* Power-law dust law with fixed -0.7 slope

Note that Figure 3 in Leja et al. 2017 nicely shows the effect of various
free parameters on the resulting SED.

conda activate fastspecfit
time python $HOME/code/desihub/fastspecfit/bin/build-fsps-templates --version 1.3.0 --imf salpeter
time python $HOME/code/desihub/fastspecfit/bin/build-fsps-templates --version 1.3.0 --imf chabrier
time python $HOME/code/desihub/fastspecfit/bin/build-fsps-templates --version 1.3.0 --imf kroupa

Note: to use a different stellar library (e.g., MILES vs C3K), open a clean
terminal, make sure you have git-updated your local checkout of python-fsps
and edited $SPS_HOME/src/sps_vars.f90 accordingly, and then run:

% echo $SPS_HOME
  /Users/ioannis/code/python-fsps/src/fsps/libfsps
% python -m pip install . --no-cache-dir --force-reinstall
% python -c "import fsps; sp = fsps.StellarPopulation(); print(sp.libraries)"

'''
import os, time, pdb
import numpy as np
import fitsio, fsps
from scipy.ndimage import gaussian_filter1d
import argparse

from astropy.io import fits
from astropy.table import Table
import matplotlib.pyplot as plt

from desispec.interpolation import resample_flux
from fastspecfit.util import C_LIGHT
from fastspecfit.continuum import PIXKMS_BLU, PIXKMS_RED, PIXKMS_WAVESPLIT

# Integer ratio of the red to the blue pixel size [km/s]. build_templates keeps
# only every irfactor-th pixel of the wavelength grid redward of the split.
irfactor = int(PIXKMS_RED / PIXKMS_BLU)

def smooth_continuum(wave, flux, medbin=1000, smooth_window=200, 
                     smooth_step=50, png=None):
    """Build a smooth, nonparametric continuum spectrum.

    The flux is normalized by its median and then summarized in sliding
    windows: in each window the pixels are 2-sigma clipped and the *maximum* of
    the surviving fluxes is recorded at the mean wavelength of the surviving
    pixels. Those nodes are linearly interpolated back onto `wave` and the
    result is median-filtered.

    Parameters
    ----------
    wave : array-like
        1D wavelength vector; must be increasing for the interpolation.
    flux : array-like
        1D flux vector with the same length as `wave`. The input is *not*
        modified; a normalized copy is used internally.
    medbin : int
        Size (in pixels) of the final median filter.
    smooth_window : int
        Width (in pixels) of each sliding window.
    smooth_step : int
        Stride (in pixels) between consecutive windows.
    png : str, optional
        If given, write a two-panel QA figure (spectrum plus continuum, and
        their ratio) to this path.

    Returns
    -------
    numpy.ndarray
        Smooth continuum of the median-normalized flux, sampled on `wave`.

    Raises
    ------
    ValueError
        If `wave`/`flux` are shorter than `smooth_window` (raised by
        `sliding_window_view`) or if no window retains at least 10 pixels
        after sigma clipping.

    """
    from numpy.lib.stride_tricks import sliding_window_view
    from scipy.ndimage import median_filter
    from scipy.stats import sigmaclip

    wave = np.asarray(wave)
    # Out-of-place division: never rescale the caller's array, and allow
    # integer input (in-place true division fails on integer arrays).
    flux = np.asarray(flux)
    flux = flux / np.median(flux)

    # Build the smooth (line-free) continuum by computing statistics in a
    # sliding window, accounting for masked pixels and trying to be smart
    # about broad lines. See:
    #   https://stackoverflow.com/questions/41851044/python-median-filter-for-1d-numpy-array
    #   https://numpy.org/devdocs/reference/generated/numpy.lib.stride_tricks.sliding_window_view.html

    wave_win = sliding_window_view(wave, window_shape=smooth_window)
    flux_win = sliding_window_view(flux, window_shape=smooth_window)

    smooth_wave, smooth_flux = [], []
    for swave, sflux in zip(wave_win[::smooth_step], flux_win[::smooth_step]):

        cflux, _, _ = sigmaclip(sflux, low=2.0, high=2.0)
        if len(cflux) < 10:
            # too few pixels survive the clipping; skip this window
            continue

        # Sigma clipping rejects pixels purely by value, so equal fluxes are
        # always kept or rejected together and a membership test recovers
        # exactly the surviving pixels.
        keep = np.isin(sflux, cflux)
        smooth_wave.append(np.mean(swave[keep]))

        # upper envelope of the clipped fluxes (not the median)
        smooth_flux.append(np.max(cflux))

    if len(smooth_wave) == 0:
        raise ValueError('No sliding window retained enough pixels to estimate the continuum.')

    smooth_wave = np.array(smooth_wave)
    smooth_flux = np.array(smooth_flux)

    # interpolate onto the original wavelength vector
    _smooth_flux = np.interp(wave, smooth_wave, smooth_flux)
    smooth = median_filter(_smooth_flux, medbin, mode='nearest')

    # Optional QA.
    if png:
        import matplotlib.pyplot as plt
        fig, ax = plt.subplots(2, 1, figsize=(8, 10), sharex=True)
        ax[0].plot(wave, flux, alpha=0.7)
        ax[0].plot(smooth_wave, smooth_flux, color='green')
        ax[0].scatter(smooth_wave, smooth_flux, color='k', marker='s', s=20)
        ax[0].plot(wave, smooth, color='red')

        ax[1].plot(wave, flux / smooth)
        ax[1].axhline(y=1, color='k', alpha=0.8, ls='--')

        fig.savefig(png)
        # close exactly the figure created above so nothing is leaked
        plt.close(fig)

    return smooth

def build_templates(models, logages, agebins=None, imf='chabrier',
                    sfh_ssp=False, include_nebular=True):
    """Generate FSPS spectra and emission-line fluxes for a grid of models.

    Parameters
    ----------
    models : numpy.ndarray
        1D structured array with fields 'logage' (log10 yr), 'logmet' and
        'dust'; one entry per output spectrum.
    logages : array-like
        Log10 ages [yr] of the age bins. Only used when `sfh_ssp` is False,
        to look up the row of `agebins` belonging to each model.
    agebins : array-like, optional
        Array of shape (nbins, 2) with the lower and upper edge [Gyr] of each
        age bin, in the same order as `logages`. Required when `sfh_ssp` is
        False.
    imf : str
        One of 'salpeter', 'chabrier' or 'kroupa'.
    sfh_ssp : bool
        If True, compute simple stellar populations (FSPS `sfh=0`) at each
        model age; otherwise use a tabular SFH (FSPS `sfh=3`) with constant
        star formation inside the model's age bin.
    include_nebular : bool
        Passed to FSPS as `nebemlineinspec`, i.e., whether the emission lines
        are included in the returned spectra. `add_neb_emission` is always
        True.

    Returns
    -------
    meta : astropy.table.Table
        One row per model with columns age, zzsun, av, mstar and sfr.
    newwave : numpy.ndarray
        Output wavelength vector, log-spaced with a coarser sampling redward
        of PIXKMS_WAVESPLIT.
    fluxes : numpy.ndarray
        float32 array of shape (npix, nsed) [erg/s/cm2/A/Msun at 10pc].
    linewaves : numpy.ndarray
        Emission-line wavelengths reported by FSPS.
    linefluxes : numpy.ndarray
        float32 array of shape (nline, nsed) of emission-line luminosities
        scaled to a distance of 10 pc.

    Raises
    ------
    KeyError
        If `imf` is not a recognized IMF name.
    ValueError
        If `models` is empty, if `agebins` is missing for a tabular SFH, if a
        model age matches none of `logages`, or if FSPS returns a negative
        stellar mass.

    """
    nsed = len(models)
    if nsed == 0:
        raise ValueError('At least one model is required to build templates.')

    if not sfh_ssp:
        if agebins is None:
            raise ValueError('agebins is required when sfh_ssp=False.')
        agebins = np.asarray(agebins)
        # The model ages are stored as float32, so cast the reference ages
        # once to the same precision to allow an exact match below.
        logages32 = np.asarray(logages, dtype=np.float32)

    meta = Table()
    meta['age'] = 10**models['logage']
    meta['zzsun'] = models['logmet']
    # presumably converts the dust optical depth to magnitudes
    # (2.5*log10(e) ~ 1.086) -- TODO confirm
    meta['av'] = models['dust'] * 1.086
    meta['mstar'] = np.zeros(nsed, 'f4')
    meta['sfr'] = np.zeros(nsed, 'f4')

    # https://dfm.io/python-fsps/current/stellarpop_api/
    imfdict = {'salpeter': 0, 'chabrier': 1, 'kroupa': 2}

    print('Instantiating the StellarPopulation object...', end='')
    t0 = time.time()
    # The SSP and tabular-SFH configurations differ only in the `sfh` switch.
    sp = fsps.StellarPopulation(compute_vega_mags=False,
                                add_dust_emission=True,
                                add_neb_emission=True,
                                nebemlineinspec=include_nebular,
                                imf_type=imfdict[imf], # 0=Salpeter, 1=Chabrier, 2=Kroupa
                                dust_type=0,
                                dust_index=-0.7,
                                sfh=0 if sfh_ssp else 3, # 0=SSP, 3=tabular SFH
                                zcontinuous=1,
                                )
    print('...took {:.3f} sec'.format((time.time()-t0)))
    print(sp.libraries)

    if include_nebular:
        print('Creating {} model spectra with nebular emission...'.format(nsed), end='')
    else:
        print('Creating {} model spectra without nebular emission...'.format(nsed), end='')

    # Constants used to place the FSPS output (solar luminosities) at 10 pc.
    lodot = 3.828e33 # [erg/s]
    tenpc2 = (10. * 3.085678e18)**2 # [cm^2]
    sphere = 4. * np.pi * tenpc2 # [cm^2]

    t0 = time.time()
    for imodel, model in enumerate(models):
        sp.params['dust1'] = model['dust']
        sp.params['dust2'] = model['dust']
        sp.params['logzsol'] = model['logmet']

        if sfh_ssp:
            # SSP
            wave, flux = sp.get_spectrum(tage=10.**(model['logage'] - 9.), peraa=True)
        else:
            # lookback time of constant SFR
            match = np.flatnonzero(model['logage'] == logages32)
            if match.size == 0:
                raise ValueError('No age bin matches logage={}'.format(model['logage']))
            agebin = agebins[match[0], :]       # Gyr
            fspstime = agebin - agebin[0]       # Gyr
            tage = agebin[1] # time of observation [Gyr]

            # constant SFR normalized to form 1 Msun over the bin
            dt = np.diff(agebin) * 1e9          # [yr]
            sfh = np.zeros_like(fspstime) + 1. / dt # [Msun/yr]

            # force the SFR to go to zero at the edge
            fspstime = np.hstack((fspstime, fspstime[-1]+1e-8))
            sfh = np.hstack((sfh, 0.))

            sp.set_tabular_sfh(fspstime, sfh)

            wave, flux = sp.get_spectrum(tage=tage, peraa=True) # tage in Gyr

        flux = flux * lodot / sphere # [erg/s/cm2/A/Msun at 10pc]

        # Resample to constant log-lambda / velocity. Redward of
        # PIXKMS_WAVESPLIT keep only every irfactor-th sample, to save space.
        # The output grid and arrays are defined once, from the first model.
        if imodel == 0:
            dlogwave = PIXKMS_BLU / C_LIGHT / np.log(10) # pixel size [log-lambda]
            newwave = 10**np.arange(np.log10(np.min(wave)), np.log10(np.max(wave)), dlogwave)

            isplit = np.argmin(np.abs(newwave-PIXKMS_WAVESPLIT)) + 1
            newwave = np.hstack((newwave[:isplit], newwave[isplit:][::irfactor]))
            npix = len(newwave)

            fluxes = np.zeros((npix, nsed), dtype=np.float32)

            # emission lines
            linewaves = sp.emline_wavelengths
            linefluxes = np.zeros((len(linewaves), nsed), dtype=np.float32)

        fluxes[:, imodel] = resample_flux(newwave, wave, flux)
        linefluxes[:, imodel] = sp.emline_luminosity * lodot / sphere

        meta['mstar'][imodel] = sp.stellar_mass
        meta['sfr'][imodel] = sp.sfr
        if sp.stellar_mass < 0:
            raise ValueError('Stellar mass is negative!')

    print('...took {:.3f} min'.format((time.time()-t0)/60))

    return meta, newwave, fluxes, linewaves, linefluxes

def main(args):
    """Build the template grid and write it to a multi-extension FITS file.

    The grid is the outer product of ages (or age bins), metallicities and
    dust values. Spectra are built twice, with and without emission lines in
    the spectrum, so that the line-only spectrum can be stored as their
    difference. The line-free spectra blueward of PIXKMS_WAVESPLIT are also
    convolved with a grid of velocity dispersions.

    The output file is
    $DESI_ROOT/science/gqp/templates/fastspecfit/<version>/ftemplates-<imf>-<version>.fits
    with extensions FLUX, LINEFLUX, WAVE, VDISPFLUX, VDISPWAVE, METADATA,
    LINEFLUXES and LINEWAVES.

    Parameters
    ----------
    args : argparse.Namespace
        Must provide `version` (str), `imf` (str), `sfh_ssp` (bool) and
        `test` (bool).

    Raises
    ------
    OSError
        If the DESI_ROOT environment variable is not set.

    """
    # Resolve the output location first, so that a missing environment
    # variable is reported before the (slow) models are built, not after.
    desi_root = os.environ.get('DESI_ROOT')
    if desi_root is None:
        raise OSError('The DESI_ROOT environment variable must be set.')
    outdir = os.path.join(desi_root, 'science', 'gqp', 'templates', 'fastspecfit', args.version)
    outfile = os.path.join(outdir, 'ftemplates-{}-{}.fits'.format(args.imf, args.version))

    # velocity dispersion grid
    vdisp_nominal = 125.
    vdispmin = 50.
    vdispmax = 500.
    dvdisp = 25.
    nvdisp = int(np.ceil((vdispmax - vdispmin) / dvdisp)) + 1
    vdisp = np.linspace(vdispmin, vdispmax, nvdisp)

    if args.sfh_ssp:
        # SSP ages
        # NOTE(review): the list is not monotonic (0.05 precedes 0.02); the
        # first entry may have been meant to be 0.005 -- confirm.
        ages = np.array([0.05, 0.02, 0.065, 0.215, 0.715, 2.4, 7.8, 13.5]) # [Gyr]
        logages = np.log10(1e9 * ages) # [yr]
        agebins = None
    else:
        # Choose lookback time bins.
        #agelims = np.array([0., 0.03, 0.1, 0.6, 3.6, 13.]) # [Gyr]
        #agelims = np.array([0., 0.01, 0.03, 0.1, 0.33, 1.1, 3.6, 12., 14.]) # [Gyr]

        # from prospect.templates.adjust_continuity_agebins
        nbins = 5
        tuniv = 13.7
        tbinmax = (tuniv * 0.85) * 1e9
        lim1, lim2 = 7.4772, 8.0
        agelims = np.array([0, lim1] + np.linspace(lim2, np.log10(tbinmax), nbins-2).tolist() + [np.log10(tuniv*1e9)]) # log10(yr)
        agelims = 10.**agelims / 1e9 # [Gyr]

        agebins = np.array([agelims[:-1], agelims[1:]]).T # [Gyr]
        logages = np.log10(1e9*np.sum(agebins, axis=1) / 2) # mean age [log10(yr)] in each bin

        print('<Ages>: '+' '.join(['{:.4f} Gyr'.format(10.**logage/1e9) for logage in logages]))

    logmets = np.array([0.0]) # [-1.0, -0.3, 0.0, 0.3]
    #logmets = np.array([-1.0, 0.0, 0.3]) # [-1.0, -0.3, 0.0, 0.3]

    ndusts = 8
    mindust = 0.0
    minlogdust = -1.7
    maxlogdust = 0.477
    dusts = np.hstack((mindust, np.logspace(minlogdust, maxlogdust, ndusts-1)))

    # for testing
    if args.test:
        dusts = np.array([0.0])
        logmets = np.array([0.0])

    models_dtype = np.dtype(
        [('logmet', np.float32),
         ('dust', np.float32),
         ('logage', np.float32)])

    # Build the (age, metallicity, dust) grid by broadcasting each parameter
    # along its own axis; the grid shape follows from the parameter vectors so
    # the counts cannot get out of sync. Flattening makes dust the
    # fastest-varying index and age the slowest.
    models = np.zeros((len(logages), len(logmets), len(dusts)), dtype=models_dtype)
    models['logage'] = np.asarray(logages)[:, np.newaxis, np.newaxis]
    models['logmet'] = np.asarray(logmets)[np.newaxis, :, np.newaxis]
    models['dust'] = np.asarray(dusts)[np.newaxis, np.newaxis, :]
    models = models.flatten()

    # Build models with and without line-emission.
    meta, wave, flux, linewaves, linefluxes = build_templates(
        models, logages, agebins=agebins, include_nebular=True, 
        imf=args.imf, sfh_ssp=args.sfh_ssp)
    _, _, fluxnolines, _, _ = build_templates(
        models, logages, agebins=agebins, include_nebular=False, 
        imf=args.imf, sfh_ssp=args.sfh_ssp)

    lineflux = flux - fluxnolines

    # Build the velocity dispersion templates.

    # Select just the line-free models trimmed to the wavelength range
    # between 1200 A and PIXKMS_WAVESPLIT, where the pixel size is constant.
    I = np.where((wave > 1200) * (wave < PIXKMS_WAVESPLIT))[0]
    vdispwave = wave[I]

    # Gaussian-smooth along the wavelength axis; the kernel width is the
    # velocity dispersion expressed in (constant-velocity) pixels.
    trimflux = fluxnolines[I, :]
    vdispflux = [gaussian_filter1d(trimflux, sigma=sigma, axis=0)
                 for sigma in vdisp / PIXKMS_BLU]
    vdispflux = np.stack(vdispflux, axis=-1) # [npix,nvdispmodel,nvdisp]

    # Write out.
    os.makedirs(outdir, exist_ok=True)

    hduflux1 = fits.PrimaryHDU(flux)
    hduflux1.header['EXTNAME'] = 'FLUX'
    hduflux1.header['VERSION'] = args.version
    hduflux1.header['BUNIT'] = 'erg/(s cm2 Angstrom)'

    hduflux2 = fits.ImageHDU(lineflux)
    hduflux2.header['EXTNAME'] = 'LINEFLUX'
    hduflux2.header['VERSION'] = args.version
    hduflux2.header['BUNIT'] = 'erg/(s cm2 Angstrom)'

    hduflux3 = fits.ImageHDU(vdispflux)
    hduflux3.header['EXTNAME'] = 'VDISPFLUX'
    hduflux3.header['VERSION'] = args.version
    hduflux3.header['VDISPMIN'] = (vdispmin, 'minimum velocity dispersion [km/s]')
    hduflux3.header['VDISPMAX'] = (vdispmax, 'maximum velocity dispersion [km/s]')
    hduflux3.header['VDISPRES'] = (dvdisp, 'velocity dispersion spacing [km/s]')
    hduflux3.header['VDISPNOM'] = (vdisp_nominal, 'nominal (default) velocity dispersion [km/s]')
    hduflux3.header['BUNIT'] = 'erg/(s cm2 Angstrom)'

    isplit = np.argmin(np.abs(wave-PIXKMS_WAVESPLIT)) + 1

    hduwave1 = fits.ImageHDU(wave)
    hduwave1.header['EXTNAME'] = 'WAVE'
    hduwave1.header['BUNIT'] = 'Angstrom'
    hduwave1.header['AIRORVAC'] = ('vac', 'vacuum wavelengths')
    hduwave1.header['PIXSZBLU'] = (PIXKMS_BLU, 'pixel size blueward of PIXSZSPT [km/s]')
    hduwave1.header['PIXSZRED'] = (PIXKMS_RED, 'pixel size redward of PIXSZSPT [km/s]')
    hduwave1.header['PIXSZSPT'] = (wave[isplit], 'wavelength where pixel size changes [Angstrom]')

    hduwave2 = fits.ImageHDU(vdispwave)
    hduwave2.header['EXTNAME'] = 'VDISPWAVE'
    hduwave2.header['BUNIT'] = 'Angstrom'
    hduwave2.header['AIRORVAC'] = ('vac', 'vacuum wavelengths')
    hduwave2.header['PIXSZ'] = (PIXKMS_BLU, 'pixel size [km/s]')

    hdutable = fits.convenience.table_to_hdu(meta)
    hdutable.header['EXTNAME'] = 'METADATA'
    hdutable.header['imf'] = args.imf

    # emission lines
    # NOTE(review): these look like integrated line fluxes, in which case the
    # per-Angstrom BUNIT below would not apply -- confirm before changing.
    hduflux4 = fits.ImageHDU(linefluxes)
    hduflux4.header['EXTNAME'] = 'LINEFLUXES'
    hduflux4.header['VERSION'] = args.version
    hduflux4.header['BUNIT'] = 'erg/(s cm2 Angstrom)'

    hduwave3 = fits.ImageHDU(linewaves)
    hduwave3.header['EXTNAME'] = 'LINEWAVES'
    hduwave3.header['BUNIT'] = 'Angstrom'
    hduwave3.header['AIRORVAC'] = ('vac', 'vacuum wavelengths')

    hx = fits.HDUList([hduflux1, hduflux2, hduwave1, hduflux3, hduwave2, hdutable, hduflux4, hduwave3])

    print('Writing {} model spectra to {}'.format(len(models), outfile))
    hx.writeto(outfile, overwrite=True)

if __name__ == '__main__':

    # Command-line entry point: parse the options and hand them to main().
    cli = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Required tag; it names both the output directory and the output file.
    cli.add_argument(
        '--version',
        required=True,
        help='Version number (e.g., 1.3.0)')

    # Star-formation history: SSPs when set, tabular SFH otherwise.
    cli.add_argument(
        '--sfh-ssp',
        action='store_true',
        help='Use a SSP star formation history (SFH); otherwise a tabular SFH.')

    cli.add_argument(
        '--imf',
        type=str,
        default='chabrier',
        choices=['chabrier', 'salpeter', 'kroupa'],
        help='Initial mass function')

    # Reduced grid (single dust and metallicity value) for quick checks.
    cli.add_argument(
        '--test',
        action='store_true',
        help='Generate a test set of SPS models.')

    args = cli.parse_args()

    main(args)
