#!/usr/bin/env python

'''Build a set of templates for use with FastSpecFit.

time python bin/build-templates --version 2.0.0 --templatedir $DESI_ROOT/users/ioannis/fastspecfit/templates

* Chabrier IMF
* MIST isochrones
* C3K_a stellar library
* 5 variable-width age bins between 30 Myr and 13.7 Gyr with constant star
  formation within each age bin

Notes
-----

* Figure 3 in Leja et al. 2017 nicely shows the effect of various free
  parameters on the resulting SED.

* To use different stellar libraries (e.g., MILES vs C3K), open a clean terminal,
  make sure you've git-updated a local checkout of python-fsps and that
  $SPS_HOME/src/sps_vars.f90 has been edited correctly, and then run:

  conda create -y -n fsps python 'numpy<2'
  conda activate fsps
  cd $HOME/code
  git clone --recursive https://github.com/dfm/python-fsps.git
  export SPS_HOME=$HOME/code/python-fsps/src/fsps/libfsps
  cd python-fsps
  python -m pip install . --no-cache-dir --force-reinstall
  python -c "import fsps; sp = fsps.StellarPopulation(); print(sp.libraries)"

* The resolution of the C3K_a models is R(lambda/FWHM)=3000 or
  R(lambda/dlambda)=7065 (=42 km/s) between 2750-9100 A:
  https://github.com/cconroy20/fsps/tree/master/SPECTRA/C3K#readme

'''
import os, time
import numpy as np
import numpy.ma as ma
import argparse
import fitsio, fsps
from scipy.ndimage import gaussian_filter1d
from scipy.interpolate import interpn

from astropy.io import fits
from astropy.table import Table
import matplotlib.pyplot as plt

from desispec.interpolation import resample_flux
from fastspecfit.util import C_LIGHT, trapz
from fastspecfit.templates import Templates


def _qa_dustem_templates(dustwave, mduste, qpah, umin, gamma, png):
    """Write a QA figure checking our dust-emission calculation against FSPS.

    A young SSP (tage=0.01) is generated with FSPS both dust-free and with a
    power-law attenuation curve plus FSPS's own dust emission; the latter is
    taken as "truth". The same attenuation is then applied by hand to the
    dust-free spectrum, the input dust-emission template is scaled to the
    absorbed luminosity (with and without iterating on dust self-absorption),
    and both versions are compared to the FSPS spectrum.

    Parameters
    ----------
    dustwave : :class:`numpy.ndarray`
        Wavelength grid of the dust-emission template [Angstrom].
    mduste : :class:`numpy.ndarray`
        Dust-emission template sampled on `dustwave`.
    qpah, umin, gamma : float
        DL07 parameters used to build `mduste`; they are forwarded to FSPS
        so that the reference spectrum uses the same dust-emission model.
    png : str
        Output filename of the QA figure.

    """
    def attenuation(tauv, wave, alpha=-0.7):
        # power-law attenuation curve normalized at 5500 A
        return np.exp(-tauv * (wave / 5500.)**alpha)

    tauv = 3.

    # Simple (age-independent) power-law dust model. The DL07 parameters are
    # passed through so the FSPS "truth" matches the template being tested.
    sp = fsps.StellarPopulation(imf_type=1, dust_type=0,
                                dust_index=-0.7, sfh=0, # SSP parameters
                                duste_qpah=qpah, duste_umin=umin,
                                duste_gamma=gamma,
                                zcontinuous=1)

    # intrinsic (dust-free) spectrum with no dust emission
    sp.params["dust2"] = 0
    sp.params["add_dust_emission"] = False
    wave, flux = sp.get_spectrum(tage=0.01, peraa=True)

    # build the attenuated spectrum + dust emission using FSPS; take this
    # spectrum as "truth"
    sp.params["dust2"] = tauv
    sp.params["add_dust_emission"] = True
    _, fsps_dflux = sp.get_spectrum(tage=0.01, peraa=True)

    # For the purposes of the QA, interpolate the dust emission spectrum
    # onto the FSPS wavelength grid.
    mduste = np.interp(wave, dustwave, mduste, left=0.)

    # now do the dust emission calculation ourselves
    atten = attenuation(tauv, wave)
    dflux = flux * atten

    # normalize the dust emission to the luminosity absorbed by dust, i.e.,
    # keeping Lbol constant
    lbold = trapz(dflux, x=wave) # attenuated energy
    lboln = trapz(flux, x=wave)  # intrinsic energy

    labs = lboln - lbold
    norm = trapz(mduste, x=wave) # should already be 1.0
    duste = mduste * labs / norm

    # first-pass dust emission, i.e., without any self-absorption
    orig_duste = duste.copy()

    # Handle dust self-absorption (algorithm taken from fsps.add_dust): the
    # emitted dust spectrum is itself attenuated and the absorbed energy is
    # re-emitted, accumulating the escaping part in tduste. At least five
    # iterations are always performed.
    iiter = 0
    tduste = 0.
    while (lboln - lbold) > 1e-2 or iiter < 5:
        oduste = duste.copy()
        duste *= atten # attenuate
        tduste += duste

        lbold = trapz(duste, x=wave)  # after self-absorption
        lboln = trapz(oduste, x=wave) # before self-absorption
        duste = mduste * (lboln - lbold) / norm

        iiter += 1

    plotwave = wave / 1e4 # [Angstrom --> micron]

    xlim = (0.1, 500.)
    I = (plotwave > xlim[0]) * (plotwave < xlim[1])

    with_selfabs = dflux + tduste
    no_selfabs = dflux + orig_duste

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 7), sharex=True)

    ax1.plot(plotwave[I], flux[I], label='Unattenuated spectrum')
    ax1.plot(plotwave[I], with_selfabs[I], alpha=0.7, label='With self-absorption')
    ax1.plot(plotwave[I], no_selfabs[I], alpha=0.7, label='No self-absorption')
    ax1.plot(plotwave[I], fsps_dflux[I], alpha=0.7, label='FSPS')
    ax1.set_xscale('log')
    ax1.set_yscale('log')
    ax1.legend(fontsize=10)

    # fractional residuals with respect to FSPS; the curve order (and hence
    # the colors) matches the two hand-built models in the upper panel
    ax2.plot(plotwave[I], with_selfabs[I]/fsps_dflux[I]-1, alpha=0.7, label='With self-absorption')
    ax2.plot(plotwave[I], no_selfabs[I]/fsps_dflux[I]-1, alpha=0.7, label='No self-absorption')
    ax2.set_xscale('log')
    ax2.legend(fontsize=10)
    fig.tight_layout()
    fig.savefig(png)
    plt.close(fig)


def build_dustem_templates(qpah=3.5, umin=1., gamma=0.01, png=None):
    """Build the DL07 dust emission template for one (qpah, Umin, gamma).

    The original ``DL07_MW3.1_*.dat`` files are read from
    ``<templatedir>/original``, linearly interpolated in qpah and Umin, and
    the two resulting spectra are combined with weights ``1-gamma`` and
    ``gamma``. The result is converted from F_nu to F_lambda and normalized
    to unit integral.

    NOTE: `templatedir` is not an argument; it is read from the module
    namespace, where it is set by the ``__main__`` block of this script.

    Parameters
    ----------
    qpah : float
        PAH fraction; must lie within the tabulated qpah grid.
    umin : float
        Minimum radiation field strength; must lie within the Umin grid.
    gamma : float
        Weight given to the second ("Umax") half of the tabulated models.
    png : str, optional
        If given, write a QA figure to this filename.

    Returns
    -------
    dustwave : :class:`numpy.ndarray`
        Wavelength grid [Angstrom].
    mduste : :class:`numpy.ndarray`
        Normalized dust-emission spectrum on `dustwave`.

    Raises
    ------
    ValueError
        If `qpah` or `umin` falls outside the tabulated grid.

    """
    # Umin and qpah parameter grids
    umin_grid = np.array([0.10, 0.15, 0.20, 0.30, 0.40, 0.50, 0.70, 0.80,
                          1.00, 1.20, 1.50, 2.00, 2.50, 3.00, 4.00, 5.00,
                          7.00, 8.00, 12.0, 15.0, 20.0, 25.0])
    numin = len(umin_grid)

    # from Table 3 in https://arxiv.org/pdf/astro-ph/0608003
    d_qpah = {
        'MW3.1_00': 0.47,
        'MW3.1_10': 1.12,
        'MW3.1_20': 1.77,
        'MW3.1_30': 2.50,
        'MW3.1_40': 3.19,
        'MW3.1_50': 3.90,
        'MW3.1_60': 4.58,
        }
    models = list(d_qpah)
    qpah_grid = np.array([d_qpah[model] for model in models])
    nqpah = len(qpah_grid)

    # Fail early with an explicit message rather than relying on the generic
    # out-of-bounds error from interpn.
    if not (qpah_grid[0] <= qpah <= qpah_grid[-1]):
        raise ValueError(f'qpah={qpah} is outside the tabulated range '
                         f'[{qpah_grid[0]}, {qpah_grid[-1]}]')
    if not (umin_grid[0] <= umin <= umin_grid[-1]):
        raise ValueError(f'umin={umin} is outside the tabulated range '
                         f'[{umin_grid[0]}, {umin_grid[-1]}]')

    # Each file holds the wavelength in the first column followed by
    # 2*numin model spectra; the wavelength grid (and hence the size of the
    # model cube) is taken from the first file.
    dustwave, dustem = None, None
    for imodel, model in enumerate(models):
        dustfile = os.path.join(templatedir, 'original', f'DL07_{model}.dat')
        dustem1 = np.loadtxt(dustfile, skiprows=2)
        if dustem is None:
            dustwave = dustem1[:, 0] * 1e4 # [micron --> Angstrom]
            dustem = np.zeros((nqpah, 2 * numin, len(dustwave)))
        dustem[imodel, :, :] = dustem1[:, 1:].T

    # Linearly interpolate over qpah and Umin; the first numin spectra feed
    # dumin and the remaining numin spectra feed dumax.
    grid = (qpah_grid, umin_grid, dustwave)
    xi = (qpah, umin, dustwave)
    dumin = interpn(grid, dustem[:, :numin, :], xi)
    dumax = interpn(grid, dustem[:, numin:, :], xi)

    # construct P(U)dU as a weighted average of dumin and dumax
    mduste = (1. - gamma) * dumin + gamma * dumax

    # convert to F_lambda (from F_nu) and normalize the spectrum to unity
    mduste /= dustwave**2
    norm = trapz(mduste, x=dustwave)
    if norm > 0.:
        mduste /= norm

    if png:
        _qa_dustem_templates(dustwave, mduste, qpah, umin, gamma, png)

    return dustwave, mduste


def build_fe_templates(version='1.0', png=None):
    """Read the Fe template and resample it to constant-velocity pixels.

    The ``VW01ALT`` extension of
    ``<templatedir>/original/fetemplates-<version>.fits`` is read, rebinned
    onto a constant log-lambda grid spanning ``Templates.AGN_PIXKMS_BOUNDS``
    with ``Templates.AGN_PIXKMS`` km/s pixels, and divided by its median
    (when the median is positive).

    NOTE: `templatedir` is read from the module namespace (it is set by the
    ``__main__`` block of this script), not passed as an argument.

    Parameters
    ----------
    version : str
        Version string of the input Fe template file.
    png : str, optional
        If given, write a QA figure showing the template smoothed to a range
        of velocity dispersions.

    Returns
    -------
    newwave : :class:`numpy.ndarray`
        Resampled wavelength grid.
    newflux : :class:`numpy.ndarray`
        Median-normalized template on `newwave`.

    """
    # original template
    fe_table = fitsio.read(
        os.path.join(templatedir, 'original', f'fetemplates-{version}.fits'),
        'VW01ALT')
    wave_in = fe_table['WAVE'].astype('f8')
    flux_in = fe_table['FLUX'].astype('f8')

    # constant log10-lambda step, i.e., constant velocity pixels
    wavelo = Templates.AGN_PIXKMS_BOUNDS[0]
    wavehi = Templates.AGN_PIXKMS_BOUNDS[1]
    step = Templates.AGN_PIXKMS / C_LIGHT / np.log(10)
    newwave = 10.**np.arange(np.log10(wavelo), np.log10(wavehi), step)
    newflux = resample_flux(newwave, wave_in, flux_in)

    # normalize by the median rather than by the integral
    scale = np.median(newflux)
    if scale > 0.:
        newflux /= scale

    if png:
        # QA: overplot the template smoothed to five logarithmically spaced
        # velocity dispersions between 150 and 1e4 km/s
        fig, ax = plt.subplots()
        ax.plot(newwave, newflux, color='k', lw=1.5, label='Original', alpha=0.5)
        for vdisp in np.logspace(np.log10(150.), np.log10(1e4), 5):
            npixsigma = vdisp / Templates.AGN_PIXKMS # [pixels]
            smoothed = gaussian_filter1d(newflux, sigma=npixsigma)
            ax.plot(newwave, smoothed, lw=2,
                    label=r'$\sigma=$'+f'{vdisp:.0f} km/s', alpha=0.75)
        ax.set_xlabel(r'Wavelength ($\AA$)')
        ax.set_ylabel('Flux (arbitrary units)')
        ax.legend(loc='upper left', ncol=2)
        ax.set_title(f'Vestergaard & Wilkes (2001) - v{version}')
        fig.tight_layout()
        fig.savefig(png)

    return newwave, newflux


def build_agn_templates(agntau=10., png=None):
    """Build the AGN torus template for a single optical depth.

    The Nenkova+08 grid is read from
    ``<templatedir>/original/Nenkova08_y010_torusg_n10_q2.0.dat``, converted
    from F_nu to F_lambda, linearly interpolated to `agntau`, and normalized
    to unit integral (when the integral is positive).

    NOTE: `templatedir` is read from the module namespace (it is set by the
    ``__main__`` block of this script), not passed as an argument.

    Parameters
    ----------
    agntau : float
        Optical depth at which to evaluate the template; must lie within the
        tabulated grid (5 to 150).
    png : str, optional
        If given, write a QA figure comparing the interpolated template to
        every model in the grid.

    Returns
    -------
    agnwave : :class:`numpy.ndarray`
        Wavelength grid [Angstrom].
    agnflux : :class:`numpy.ndarray`
        Normalized template on `agnwave`.

    """
    table = np.loadtxt(
        os.path.join(templatedir, 'original', 'Nenkova08_y010_torusg_n10_q2.0.dat'),
        skiprows=4)
    agnwave = table[:, 0]                    # [nwave, Angstrom]
    agngrid = table[:, 1:].T / agnwave**2    # [nagn, nwave], F_nu --> F_lam

    tau_grid = np.array([5, 10, 20, 30, 40, 60, 80, 100, 150], dtype='f4')

    agnflux = interpn((tau_grid, agnwave), agngrid, (agntau, agnwave))

    total = trapz(agnflux, x=agnwave)
    if total > 0.:
        agnflux /= total

    if png:
        micron = agnwave / 1e4
        xlim = (np.min(micron), 100.)
        keep = (micron > xlim[0]) * (micron < xlim[1])

        fig, ax = plt.subplots()
        ax.plot(micron[keep], agnflux[keep], label=f'tau={agntau:.1f}', lw=2)
        # overplot every tabulated model, each normalized to unit integral
        for itau, row in enumerate(agngrid):
            ax.plot(micron[keep], row[keep] / trapz(row, x=agnwave),
                    label=f'{tau_grid[itau]:.1f}', lw=0.5)
        ax.set_xlabel('Wavelength (micron)')
        ax.set_ylabel('Flux (arbitrary units)')
        ax.set_xscale('log')
        ax.set_yscale('log')
        ax.set_xlim(xlim)
        ax.legend()
        fig.tight_layout()
        fig.savefig(png)

    return agnwave, agnflux


def build_fsps_templates(models, logages, agebins=None, imf='chabrier',
                         include_nebular=True):
    """Generate one FSPS spectrum per model using a tabular, constant SFH.

    For every entry in `models` the metallicity is set, the age bin matching
    the model's `logage` is looked up, and a constant star-formation rate of
    1/(bin width in yr) is tabulated over the bin. The SFH is shifted to
    start at zero and the spectrum is evaluated at ``tage = agebin[1]``, so
    the stars span the lookback times covered by the bin. Spectra are scaled
    by Lsun / (4 pi (10 pc)^2) and resampled onto a common wavelength grid.

    Parameters
    ----------
    models : structured :class:`numpy.ndarray`
        Must provide the fields ``logage`` and ``logmet``.
    logages : array-like
        Mean log-age of each age bin; cast to float32 and compared for exact
        equality against ``model['logage']`` to find the bin index.
    agebins : :class:`numpy.ndarray`
        Two-column array of (start, end) times of each bin [Gyr]. Effectively
        required: the default of `None` cannot be indexed.
    imf : str
        One of ``salpeter``, ``chabrier`` or ``kroupa``.
    include_nebular : bool
        Passed to FSPS as ``nebemlineinspec``. Note that
        ``add_neb_emission`` is always `True`.

    Returns
    -------
    meta : :class:`astropy.table.Table`
        Columns ``age`` (10**logage), ``zzsun``, ``mstar`` and ``sfr``.
    newwave : :class:`numpy.ndarray`
        Output wavelength grid, shape [npix].
    fluxes : :class:`numpy.ndarray`
        Model spectra, shape [npix, nsed].
    linewaves : :class:`numpy.ndarray`
        ``sp.emline_wavelengths`` from FSPS.
    linefluxes : :class:`numpy.ndarray`
        Scaled ``sp.emline_luminosity`` of each model, shape [nline, nsed].

    Raises
    ------
    KeyError
        If `imf` is not one of the supported names.
    ValueError
        If FSPS returns a negative stellar mass.

    """

    nsed = len(models)

    meta = Table()
    meta['age'] = 10**models['logage']
    meta['zzsun'] = models['logmet']
    meta['mstar'] = np.zeros(nsed, 'f4')
    meta['sfr'] = np.zeros(nsed, 'f4')

    # https://dfm.io/python-fsps/current/stellarpop_api/
    imfdict = {'salpeter': 0, 'chabrier': 1, 'kroupa': 2}

    print('Instantiating the StellarPopulation object...', end='')
    t0 = time.time()
    # tabular SFH
    sp = fsps.StellarPopulation(
        compute_vega_mags=False,
        add_dust_emission=False, # note!
        add_neb_emission=True,
        nebemlineinspec=include_nebular,
        imf_type=imfdict[imf],
        #dust_type=0,
        #dust_index=-0.7,
        sfh=3,  # tabular SFH parameters
        zcontinuous=1,
    )
    print('...took {:.3f} sec'.format((time.time()-t0)))
    print(sp.libraries)

    if include_nebular:
        print('Creating {} model spectra with nebular emission...'.format(nsed), end='')
    else:
        print('Creating {} model spectra without nebular emission...'.format(nsed), end='')

    t0 = time.time()
    for imodel, model in enumerate(models):
        sp.params['logzsol'] = model['logmet']

        # lookback time of constant SFR; the bin is identified by an exact
        # float32 match between the model's logage and logages
        agebin_indx = np.where(model['logage'] == np.float32(logages))[0]
        agebin = agebins[agebin_indx, :][0] # Gyr
        fspstime = agebin - agebin[0]       # Gyr; SFH shifted to start at zero
        tage = agebin[1] # time of observation [Gyr]
        #print(tage, model['logage'])

        # constant SFR such that SFR x (bin width) = 1
        dt = np.diff(agebin) * 1e9          # [yr]
        sfh = np.zeros_like(fspstime) + 1. / dt #/ 2 # [Msun/yr]

        # force the SFR to go to zero at the edge
        fspstime = np.hstack((fspstime, fspstime[-1]+1e-8))
        sfh = np.hstack((sfh, 0.))

        sp.set_tabular_sfh(fspstime, sfh)
        #print(tage, sp.sfr)

        wave, flux = sp.get_spectrum(tage=tage, peraa=True) # tage in Gyr

        lodot = 3.828e33 # [erg/s]
        tenpc2 = (10. * 3.085678e18)**2 # [cm^2]

        flux = flux * lodot / (4. * np.pi * tenpc2) # [erg/s/cm2/A/Msun at 10pc]

        # Build the output wavelength grid once, from the first model:
        # constant log-lambda (constant velocity) pixels of size
        # Templates.PIXKMS between the Templates.PIXKMS_BOUNDS wavelengths,
        # and the native FSPS sampling outside that range.
        # NOTE(review): an earlier comment mentioned keeping only every fourth
        # IR sample, but no such subsampling is done here. Also note that
        # wave[hi] itself is not part of newwave (np.arange excludes its
        # endpoint and the red side starts at hi+1).
        if imodel == 0:
            lo = np.searchsorted(wave, Templates.PIXKMS_BOUNDS[0], 'left')
            hi = np.searchsorted(wave, Templates.PIXKMS_BOUNDS[1], 'left')

            dlogwave = Templates.PIXKMS / C_LIGHT / np.log(10) # pixel size [log-lambda]
            optwave = 10.**np.arange(np.log10(wave[lo]), np.log10(wave[hi]), dlogwave)
            newwave = np.hstack((wave[:lo], optwave, wave[hi+1:]))
            npix = len(newwave)

            fluxes = np.zeros((npix, nsed), dtype=np.float64)

            # emission lines
            linewaves = sp.emline_wavelengths
            linefluxes = np.zeros((len(sp.emline_wavelengths), nsed), dtype=np.float64)

        newflux = resample_flux(newwave, wave, flux)

        fluxes[:, imodel] = newflux
        # same Lsun --> flux-at-10pc scaling as applied to the spectrum
        linefluxes[:, imodel] = sp.emline_luminosity * lodot / (4.0 * np.pi * tenpc2)

        # these properties are read after get_spectrum so that they refer to
        # the model just computed
        meta['mstar'][imodel] = sp.stellar_mass
        meta['sfr'][imodel] = sp.sfr
        #print(tage, sp.formed_mass, sp.stellar_mass)
        if sp.stellar_mass < 0:
            raise ValueError('Stellar mass is negative!')

        #plt.clf()
        #I = np.where((wave > 3500) * (wave < 5600))[0]
        #J = np.where((newwave > 3500) * (newwave < 3600))[0]
        #plt.plot(wave[I], flux[I])
        #plt.plot(newwave[J], fluxes[J, imodel])
        #plt.savefig('junk.png')

    print('...took {:.3f} min'.format((time.time()-t0)/60.))

    return meta, newwave, fluxes, linewaves, linefluxes


def main(imf='chabrier', qpah=3.5, umin=1., gamma=0.01, agntau=10.,
         test=False, version='1.0.0', templatedir=None):
    """Build all the templates and write them to one multi-extension FITS file.

    The output file is ``<templatedir>/<version>/ftemplates-<imf>-<version>.fits``
    with the extensions (in order) FLUX, LINEFLUX, DUSTFLUX, WAVE, METADATA,
    AGNFLUX, AGNWAVE, FEFLUX, FEWAVE, LINEFLUXES and LINEWAVES.

    NOTE: the ``build_*_templates`` helpers locate their input files through
    the module-level name `templatedir` (set by the ``__main__`` block), not
    through the `templatedir` argument of this function.

    Parameters
    ----------
    imf : str
        Initial mass function (``chabrier``, ``salpeter`` or ``kroupa``).
    qpah, umin, gamma : float
        DL07 dust-emission parameters.
    agntau : float
        Optical depth of the AGN torus template.
    test : bool
        Select the test metallicity grid; it is currently identical to the
        default grid (solar metallicity only).
    version : str
        Version string used for the output directory and filename.
    templatedir : str
        Top-level template directory. Required despite the `None` default.

    """
    # AGN + Fe template(s). The Fe template file has its own version number,
    # independent of `version`. To also write QA figures for the Fe and dust
    # templates, pass png=os.path.join(templatedir, 'original', 'qa-fe.png')
    # (or 'qa-dustem.png') to the corresponding builder.
    fewave, feflux = build_fe_templates(version='1.0')
    agnwave, agnflux = build_agn_templates(agntau=agntau, png=os.path.join(templatedir, 'original', 'qa-agn.png'))

    # dust emission template(s)
    dustwave, dustflux = build_dustem_templates(qpah=qpah, umin=umin, gamma=gamma)

    # Choose lookback time bins.

    # from prospect.templates.adjust_continuity_agebins
    nbins = 5
    tuniv = 13.7
    tbinmax = (tuniv * 0.85) * 1e9
    lim1, lim2 = 7.4772, 8.0
    agelims = np.array([0, lim1] + np.linspace(lim2, np.log10(tbinmax), nbins-2).tolist() + [np.log10(tuniv*1e9)]) # log10(yr)
    agelims = 10.**agelims / 1e9 # [Gyr]

    agebins = np.array([agelims[:-1], agelims[1:]]).T # [Gyr]
    logages = np.log10(1e9*np.sum(agebins, axis=1) / 2) # mean age [log10(yr)] in each bin
    print('<Ages>: '+' '.join(['{:.4f} Gyr'.format(10.**logage/1e9) for logage in logages]))

    nages = len(logages)

    logmets = np.array([0.0]) # [-1.0, -0.3, 0.0, 0.3]
    nmets = len(logmets)

    # for testing
    if test:
        logmets = [0.0]
        nmets = 1

    dims = (nages, nmets)

    models_dtype = np.dtype(
        [('logmet', np.float32),
         ('logage', np.float32)])

    # Let's be pedantic about the procedure so we don't mess up the indexing...
    models = np.zeros(dims, dtype=models_dtype)

    for iage, logage in enumerate(logages):
        for imet, logmet in enumerate(logmets):
            models[iage, imet]['logmet'] = logmet
            models[iage, imet]['logage'] = logage

    models = models.flatten()

    # Build models with and without line-emission; their difference isolates
    # the emission-line spectrum.
    meta, wave, flux, linewaves, linefluxes = build_fsps_templates(
        models, logages, agebins=agebins, include_nebular=True, imf=imf)

    _, _, fluxnolines, _, _ = build_fsps_templates(
        models, logages, agebins=agebins, include_nebular=False, imf=imf)

    lineflux = flux - fluxnolines

    # Interpolate the dust emission model(s) to the nominal wavelength array.
    # The AGN and Fe templates are written on their own wavelength grids.
    dustflux = np.interp(wave, dustwave, dustflux, left=0.)

    # Write out.
    outdir = os.path.join(templatedir, version)
    os.makedirs(outdir, exist_ok=True)
    outfile = os.path.join(outdir, f'ftemplates-{imf}-{version}.fits')

    hduflux1 = fits.PrimaryHDU(flux)
    hduflux1.header['EXTNAME'] = 'FLUX'
    hduflux1.header['VERSION'] = version
    hduflux1.header['BUNIT'] = 'erg/(s cm2 Angstrom)'

    hduflux2 = fits.ImageHDU(lineflux)
    hduflux2.header['EXTNAME'] = 'LINEFLUX'
    hduflux2.header['BUNIT'] = 'erg/(s cm2 Angstrom)'

    # dust emission and AGN spectra
    hduflux3 = fits.ImageHDU(dustflux)
    hduflux3.header['EXTNAME'] = 'DUSTFLUX'
    hduflux3.header['QPAH'] = (qpah, 'PAH fraction')
    hduflux3.header['UMIN'] = (umin, 'minimum radiation field strength')
    hduflux3.header['GAMMA'] = (gamma, 'gamma parameter')
    hduflux3.header['BUNIT'] = 'erg/(s cm2 Angstrom)'

    hduflux4 = fits.ImageHDU(agnflux)
    hduflux4.header['EXTNAME'] = 'AGNFLUX'
    hduflux4.header['AGNTAU'] = (agntau, 'AGN optical depth')
    hduflux4.header['BUNIT'] = 'erg/(s cm2 Angstrom)'

    hduflux5 = fits.ImageHDU(feflux)
    hduflux5.header['EXTNAME'] = 'FEFLUX'
    hduflux5.header['BUNIT'] = 'erg/(s cm2 Angstrom)'

    # wavelength arrays
    hduwave1 = fits.ImageHDU(wave)
    hduwave1.header['EXTNAME'] = 'WAVE'
    hduwave1.header['BUNIT'] = 'Angstrom'
    hduwave1.header['AIRORVAC'] = ('vac', 'vacuum wavelengths')
    hduwave1.header['PIXWAVLO'] = (Templates.PIXKMS_BOUNDS[0], 'min(wave) where pixel size is PIXKMS [Angstrom]')
    hduwave1.header['PIXWAVHI'] = (Templates.PIXKMS_BOUNDS[1], 'max(wave) where pixel size is PIXKMS [Angstrom]')
    # the constant-velocity range is delimited by PIXWAVLO/PIXWAVHI
    hduwave1.header['PIXKMS'] = (Templates.PIXKMS, 'pixel size from PIXWAVLO to PIXWAVHI [km/s]')

    hduwave2 = fits.ImageHDU(agnwave)
    hduwave2.header['EXTNAME'] = 'AGNWAVE'
    hduwave2.header['BUNIT'] = 'Angstrom'
    hduwave2.header['AIRORVAC'] = ('vac', 'vacuum wavelengths')

    hduwave3 = fits.ImageHDU(fewave)
    hduwave3.header['EXTNAME'] = 'FEWAVE'
    hduwave3.header['BUNIT'] = 'Angstrom'
    hduwave3.header['AIRORVAC'] = ('vac', 'vacuum wavelengths')
    hduwave3.header['PIXSZ'] = (Templates.AGN_PIXKMS, 'pixel size [km/s]')

    # metadata table
    hdutable = fits.convenience.table_to_hdu(meta)
    hdutable.header['EXTNAME'] = 'METADATA'
    hdutable.header['imf'] = imf

    # Emission lines: these are integrated line fluxes (scaled line
    # luminosities), not flux densities, so the unit carries no per-Angstrom.
    hduflux6 = fits.ImageHDU(linefluxes)
    hduflux6.header['EXTNAME'] = 'LINEFLUXES'
    hduflux6.header['BUNIT'] = 'erg/(s cm2)'

    hduwave4 = fits.ImageHDU(linewaves)
    hduwave4.header['EXTNAME'] = 'LINEWAVES'
    hduwave4.header['BUNIT'] = 'Angstrom'
    hduwave4.header['AIRORVAC'] = ('vac', 'vacuum wavelengths')

    # assemble the extensions; the order here is the order on disk
    hx = fits.HDUList([hduflux1, hduflux2, hduflux3, hduwave1, hdutable,
                       hduflux4, hduwave2,
                       hduflux5, hduwave3,
                       hduflux6, hduwave4])

    print(f'Writing {len(models)} model spectra to {outfile}')
    hx.writeto(outfile, overwrite=True)


if __name__ == '__main__':

    # Command-line interface; defaults are displayed in --help.
    cli = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    cli.add_argument('--templatedir', required=True, help='Location of original and output templates')
    cli.add_argument('--version', required=True, help='Version number (e.g., 2.0.0)')

    # DL07 (dust emission) and Nenkova+08 (AGN torus) parameters; see
    # https://python-fsps.readthedocs.io/en/latest/stellarpop_api
    for flag, default, helptxt in (
            ('--qpah', 3.5, 'DL07 parameter'),
            ('--umin', 1., 'DL07 parameter'),
            ('--gamma', 0.01, 'DL07 parameter'),
            ('--agntau', 10., 'Nenkova+08 parameter'),
            ):
        cli.add_argument(flag, type=float, default=default, help=helptxt)

    # FSPS parameters
    cli.add_argument('--imf', type=str, default='chabrier', choices=['chabrier', 'salpeter', 'kroupa'],
                     help='Initial mass function')
    cli.add_argument('--test', action='store_true', help='Generate a test set of SPS models.')
    opts = cli.parse_args()

    # NOTE: `templatedir` must remain a module-level name; the build_*
    # functions above read it from the module namespace.
    templatedir = os.path.expandvars(opts.templatedir)

    # the original (input) templates must already be in place
    original_dir = os.path.join(templatedir, 'original')
    if not os.path.isdir(original_dir):
        raise IOError(f'Missing directory containing original templates {original_dir}')

    main(templatedir=templatedir, version=opts.version, test=opts.test,
         imf=opts.imf, qpah=opts.qpah, umin=opts.umin, gamma=opts.gamma,
         agntau=opts.agntau)
