#!/usr/bin/env python
"""
MPI wrapper for fastphot and fastspec.

"""
import pdb # for debugging

import os, time
import numpy as np

from desiutil.log import get_logger
# Module-level logger shared by every function in this file.
log = get_logger()

def _read_sample(samplefile, input_redshifts=False):
    """Read the columns of a sample catalog needed to plan the work.

    Parameters
    ----------
    samplefile : str
        Path to the sample (FITS) file.
    input_redshifts : bool
        If True, also require and read the 'Z' column.

    Returns
    -------
    astropy.table.Table or None
        The requested columns, or None (after logging a warning) if
        `samplefile` does not exist.

    Raises
    ------
    ValueError
        If the file exists but the requested columns cannot be read.

    """
    import fitsio
    from astropy.table import Table

    if not os.path.isfile(samplefile):
        log.warning(f'{samplefile} does not exist.')
        return None

    readcols = ['SURVEY', 'PROGRAM', 'HEALPIX', 'TARGETID']
    if input_redshifts:
        readcols += ['Z']

    try:
        return Table(fitsio.read(samplefile, columns=readcols))
    except Exception as err:
        cols = ','.join(readcols)
        # The doubled braces emit literal '{' and '}' around the column list.
        if input_redshifts:
            errmsg = f'Sample file {samplefile} with --input-redshifts set missing required columns {{{cols}}}'
        else:
            errmsg = f'Sample file {samplefile} missing required columns {{{cols}}}'
        log.critical(errmsg)
        raise ValueError(errmsg) from err


def run_fastspecfit(args, comm=None, fastphot=False, specprod_dir=None, makeqa=False,
                    samplefile=None, input_redshifts=False, outdir_data='.', templates=None,
                    templateversion=None, fphotodir=None, fphotofile=None):
    """Plan the work on rank 0, share the plan, and run fastspec/fastphot/fastqa.

    Rank 0 calls `fastspecfit.mpi.plan` to get the input files, output files,
    per-rank groups of file indices and target counts. The plan is broadcast
    to all ranks and each rank then processes the files in `groups[rank]`,
    by default redirecting stdout/stderr to a per-file log.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line options. Most options (`samplefile`, `makeqa`,
        `mp`, `templates`, `templateversion`, `fphotodir`, `fphotofile`,
        `ntargets`, `firsttarget`, `targetids`, `plan`, `dry_run`, `nolog`,
        `overwrite`, ...) are read from this object.
    comm : MPI communicator or None
        If None, run serially as a single rank.
    fastphot : bool
        Run `fastphot` instead of `fastspec` (ignored when `args.makeqa` is set).
    specprod_dir : str or None
        Passed through to `plan`.
    input_redshifts : bool
        When a sample file is used, also read its 'Z' column and pass the
        values on via `--input-redshifts`.
    outdir_data : str
        Passed through to `plan`.
    makeqa, samplefile, templates, templateversion, fphotodir, fphotofile
        Accepted for interface compatibility; the body reads the
        corresponding values from `args` instead.

    Raises
    ------
    ValueError
        On rank 0 if the sample file lacks the required columns (raised only
        after the empty plan has been broadcast so the other ranks can exit),
        or on any rank if a derived log-file name equals the output file name.

    """
    import sys
    import traceback
    from desispec.parallel import stdouterr_redirected
    from fastspecfit.mpi import plan

    if comm is None:
        rank, size = 0, 1
    else:
        rank, size = comm.rank, comm.size

    # Defaults used by non-root ranks, and by rank 0 if planning is skipped.
    sample = None
    sample_error = None
    zbestfiles, outfiles, groups, ntargets = [], [], [], []

    t0 = time.time()
    if rank == 0:
        plankw = dict(comm=comm, specprod=args.specprod, specprod_dir=specprod_dir,
                      makeqa=args.makeqa, mp=args.mp, fastphot=fastphot,
                      outdir_data=outdir_data, overwrite=args.overwrite)
        planned = None
        if args.samplefile is None:
            planned = plan(coadd_type=args.coadd_type, survey=args.survey,
                           program=args.program, healpix=args.healpix,
                           tile=args.tile, night=args.night, **plankw)
        else:
            # Do not return or raise here: the other ranks are about to enter
            # the broadcasts below and would block forever without rank 0.
            try:
                sample = _read_sample(args.samplefile, input_redshifts)
            except ValueError as err:
                sample_error = err
            if sample is not None:
                planned = plan(sample=sample, coadd_type='healpix', **plankw)

        if planned is not None:
            _, zbestfiles, outfiles, groups, ntargets = planned
            log.info('Planning took {:.2f} sec'.format(time.time() - t0))

    if comm is not None:
        zbestfiles = comm.bcast(zbestfiles, root=0)
        outfiles = comm.bcast(outfiles, root=0)
        groups = comm.bcast(groups, root=0)
        ntargets = comm.bcast(ntargets, root=0)
        sample = comm.bcast(sample, root=0)

    sys.stdout.flush()

    # Every rank now holds the (empty) plan, so it is safe to fail.
    if sample_error is not None:
        raise sample_error

    # all done
    if len(zbestfiles) == 0:
        return

    assert len(groups) == size
    assert len(np.concatenate(groups)) == len(zbestfiles)

    # With --makeqa --plan just report the outstanding work. This is done
    # before the loop so that every rank returns together, including ranks
    # with no files assigned, rather than some ranks waiting at the barrier.
    if args.makeqa and args.plan:
        if rank == 0:
            log.info(f'Still need to make {np.sum(ntargets)} QA figure(s) across {len(groups)} input file(s)')
        return

    # Pick the code to run once; it is the same for every file.
    if args.makeqa:
        from fastspecfit.qa import fastqa as fast
        cmd = 'fastqa'
    elif fastphot:
        from fastspecfit.fastspecfit import fastphot as fast
        cmd = 'fastphot'
    else:
        from fastspecfit.fastspecfit import fastspec as fast
        cmd = 'fastspec'

    # Options which do not depend on the file being processed.
    shared_opts = ''
    if not args.makeqa:
        if args.ignore_quasarnet:
            shared_opts += ' --ignore-quasarnet'
        if args.no_smooth_continuum:
            shared_opts += ' --no-smooth-continuum'
        if args.templates:
            shared_opts += f' --templates={args.templates}'
        if args.templateversion:
            shared_opts += f' --templateversion={args.templateversion}'
        if args.fphotodir:
            shared_opts += f' --fphotodir={args.fphotodir}'
        if args.fphotofile:
            shared_opts += f' --fphotofile={args.fphotofile}'
    if args.ntargets is not None:
        shared_opts += f' --ntargets={args.ntargets}'
    if args.firsttarget is not None:
        shared_opts += f' --firsttarget={args.firsttarget}'

    for ii in groups[rank]:
        log.debug(f'Rank {rank} started at {time.asctime()}')
        sys.stdout.flush()

        # With --makeqa the desired output directories are in the 'zbestfiles'.
        if args.makeqa:
            cmdargs = f'{outfiles[ii]} -o={zbestfiles[ii]} --mp={args.mp}'
        else:
            cmdargs = f'{zbestfiles[ii]} -o={outfiles[ii]} --mp={args.mp}'
        cmdargs += shared_opts

        if sample is not None:
            # assume healpix coadds; find the targetids to process
            _, survey, program, healpix = os.path.basename(zbestfiles[ii]).split('-')
            healpix = int(healpix.split('.')[0])
            I = ((sample['SURVEY'] == survey) & (sample['PROGRAM'] == program) &
                 (sample['HEALPIX'] == healpix))
            targetids = ','.join(sample[I]['TARGETID'].astype(str))
            cmdargs += f' --targetids={targetids}'
            if input_redshifts:
                inputz = ','.join(sample[I]['Z'].astype(str))
                cmdargs += f' --input-redshifts={inputz}'
        elif args.targetids is not None:
            cmdargs += f' --targetids={args.targetids}'

        if args.makeqa:
            logfile = os.path.join(zbestfiles[ii], os.path.basename(outfiles[ii]).replace('.gz', '').replace('.fits', '.log'))
        else:
            logfile = outfiles[ii].replace('.gz', '').replace('.fits', '.log')

        # Guard against redirecting the log on top of the output catalog.
        if logfile == outfiles[ii]:
            errmsg = f'Log file {logfile} cannot be the same as outfile!'
            log.critical(errmsg)
            raise ValueError(errmsg)

        log.info(f'Rank {rank}, ntargets={ntargets[ii]}: {cmd} {cmdargs}')
        sys.stdout.flush()

        if args.dry_run:
            continue

        t1 = time.time()
        try:
            # dirname is '' for a bare filename, and os.makedirs('') raises.
            outdir = os.path.dirname(logfile)
            if outdir:
                os.makedirs(outdir, exist_ok=True)
            if args.nolog:
                fast(args=cmdargs.split())
            else:
                with stdouterr_redirected(to=logfile, overwrite=args.overwrite):
                    fast(args=cmdargs.split())
        except (Exception, SystemExit):
            # Best effort: a failure on one file must not stop this rank from
            # processing its remaining files or from reaching the barrier.
            # SystemExit is included on the assumption that the wrapped
            # command-line entry point may exit on bad arguments -- TODO confirm.
            log.warning(f'  rank {rank} raised an exception')
            traceback.print_exc()
        else:
            log.info(f'  rank {rank} done in {time.time() - t1:.2f} sec')
            if not os.path.exists(outfiles[ii]):
                log.warning(f'  rank {rank} missing {outfiles[ii]}')

    log.debug(f'  rank {rank} is done')
    sys.stdout.flush()

    if comm is not None:
        comm.barrier()

    if rank == 0 and not args.dry_run:
        for outfile in outfiles:
            if not os.path.exists(outfile):
                log.warning(f'Missing {outfile}')

        log.info(f'All done at {time.asctime()}')
        
def main():
    """Command-line driver for fastspec, fastphot, their QA and catalog merging.

    Parses the command line, sets up MPI unless merging or `--nompi` was
    requested (falling back to serial mode if mpi4py cannot be imported),
    and then does exactly one of the following:

    * merges existing catalogs (`--merge`, `--mergeall` and the
      `--mergeall-{main,sv,special}` shortcuts);
    * prints the plan on rank 0 (`--plan` without `--makeqa`); or
    * calls `run_fastspecfit` to fit the data or build the QA.

    Raises
    ------
    ValueError
        On rank 0 if `--samplefile` exists but lacks the required columns.

    """
    import argparse
    from fastspecfit.mpi import plan

    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--coadd-type', type=str, default='healpix', choices=['healpix', 'cumulative', 'pernight', 'perexp'],
                        help='Specify which type of spectra/zbest files to process.')
    parser.add_argument('--specprod', type=str, default='iron', #choices=['fuji', 'guadalupe', 'iron'],
                        help='Spectroscopic production to process.')

    parser.add_argument('--healpix', type=str, default=None, help='Comma-separated list of healpixels to process.')
    parser.add_argument('--survey', type=str, default='main,special,cmx,sv1,sv2,sv3', help='Survey to process.')
    parser.add_argument('--program', type=str, default='bright,dark,other,backup', help='Program to process.') # backup not supported
    parser.add_argument('--tile', default=None, type=str, nargs='*', help='Tile(s) to process.')
    parser.add_argument('--night', default=None, type=str, nargs='*', help='Night(s) to process (ignored if coadd-type is cumulative).')

    parser.add_argument('--samplefile', default=None, type=str, help='Full path to sample (FITS) file with {survey,program,healpix,targetid}.')
    parser.add_argument('--input-redshifts', action='store_true', help='Only used with --samplefile; if set, fit with redshift "Z" values.')

    parser.add_argument('--mp', type=int, default=1, help='Number of multiprocessing processes per MPI rank or node.')
    parser.add_argument('-n', '--ntargets', type=int, help='Number of targets to process in each file.')
    parser.add_argument('--firsttarget', type=int, default=0, help='Index of first object to to process in each file, zero-indexed.')
    parser.add_argument('--targetids', type=str, default=None, help='Comma-separated list of TARGETIDs to process.')

    parser.add_argument('--fastphot', action='store_true', help='Fit the broadband photometry.')

    parser.add_argument('--templateversion', type=str, default=None, help='Template version number.')
    parser.add_argument('--templates', type=str, default=None, help='Optional full path and filename to the templates.')
    parser.add_argument('--fphotodir', type=str, default=None, help='Top-level location of the source photometry.')
    parser.add_argument('--fphotofile', type=str, default=None, help='Photometric information file.')

    parser.add_argument('--merge', action='store_true', help='Merge all individual catalogs (for a given survey and program) into one large file.')
    parser.add_argument('--mergedir', type=str, help='Output directory for merged catalogs.')
    parser.add_argument('--merge-suffix', type=str, help='Filename suffix for merged catalogs.')
    parser.add_argument('--mergeall', action='store_true', help='Merge all the individual merged catalogs into a single merged catalog.')
    parser.add_argument('--mergeall-main', action='store_true', help='Merge all the main catalogs.')
    parser.add_argument('--mergeall-sv', action='store_true', help='Merge all the SV catalogs.')
    parser.add_argument('--mergeall-special', action='store_true', help='Merge all the special catalogs.')
    parser.add_argument('--makeqa', action='store_true', help='Build QA in parallel.')
    parser.add_argument('--ntest', type=int, default=None, help='Select ntest healpixels as a test sample drawn randomly from the full specprod (only for coadd-type==healpix).')

    parser.add_argument('--ignore-quasarnet', default=False, action='store_true', help='Do not use QuasarNet to improve QSO redshifts.')
    parser.add_argument('--no-smooth-continuum', default=False, action='store_true', help='Do not fit the smooth continuum.')

    parser.add_argument('--overwrite', action='store_true', help='Overwrite any existing output files.')
    parser.add_argument('--plan', action='store_true', help='Plan how many nodes to use and how to distribute the targets.')
    parser.add_argument('--nompi', action='store_true', help='Do not use MPI parallelism.')
    parser.add_argument('--nolog', action='store_true', help='Do not write to the log file.')
    parser.add_argument('--dry-run', action='store_true', help='Generate but do not run commands.')

    parser.add_argument('--outdir-data', default='$PSCRATCH/fastspecfit/data', type=str, help='Base output data directory.')

    specprod_dir = None

    args = parser.parse_args()

    outdir_data = os.path.expandvars(args.outdir_data)

    # Any of the per-survey shortcuts implies --mergeall. Resolve this before
    # choosing the communicator so that these modes also run without MPI.
    mergeall_subset = args.mergeall_main or args.mergeall_sv or args.mergeall_special
    if mergeall_subset:
        args.mergeall = True

    if args.merge or args.mergeall or args.nompi:
        comm = None
    else:
        try:
            from mpi4py import MPI
            comm = MPI.COMM_WORLD
        except ImportError:
            comm = None

    if comm is None:
        rank = 0
    else:
        rank = comm.rank

    # https://docs.nersc.gov/development/languages/python/parallel-python/#use-the-spawn-start-method
    if args.mp > 1 and 'NERSC_HOST' in os.environ:
        import multiprocessing
        multiprocessing.set_start_method('spawn')

    # check the input samplefile
    sample = None
    if args.samplefile is not None:
        # Every rank tests for the file so that all ranks return together;
        # only rank 0 reports the problem and reads the table.
        if not os.path.isfile(args.samplefile):
            if rank == 0:
                log.warning(f'{args.samplefile} does not exist.')
            return
        if rank == 0:
            import fitsio
            from astropy.table import Table
            readcols = ['SURVEY', 'PROGRAM', 'HEALPIX', 'TARGETID']
            try:
                sample = Table(fitsio.read(args.samplefile, columns=readcols))
            except Exception as err:
                cols = ','.join(readcols)
                # The doubled braces emit literal '{' and '}' around the list.
                errmsg = f'Sample file {args.samplefile} missing required columns {{{cols}}}'
                log.critical(errmsg)
                raise ValueError(errmsg) from err

    if args.samplefile is None and args.coadd_type == 'healpix':
        args.survey = args.survey.split(',')
        args.program = args.program.split(',')

        # For a test sample, use a range of tiles
        if args.ntest:
            from astropy.table import Table
            rand = np.random.RandomState(seed=1)
            # Expand $DESI_ROOT explicitly rather than passing it on verbatim.
            tilepixfile = os.path.expandvars(f'$DESI_ROOT/spectro/redux/{args.specprod}/healpix/tilepix.fits')
            tilepix = Table.read(tilepixfile)

            # trim to survey/program
            tilepix = tilepix[np.unique(np.hstack([np.where(tilepix['SURVEY'] == survey)[0] for survey in args.survey]))]
            if len(tilepix) == 0:
                log.warning('No matching surveys; nothing to do.')
                return
            tilepix = tilepix[np.unique(np.hstack([np.where(tilepix['PROGRAM'] == program)[0] for program in args.program]))]
            if len(tilepix) == 0:
                log.warning('No matching programs; nothing to do.')
                return

            I = rand.choice(len(tilepix), args.ntest, replace=False)
            args.healpix = ','.join(tilepix['HEALPIX'][I].astype(str))
            log.info(f'Selecting {len(I)} test healpixels: {args.healpix}')
            print(tilepix[I])

    if args.merge or args.mergeall:
        from fastspecfit.mpi import merge_fastspecfit

        # convenience code to make the super-merge catalogs, e.g., fastspec-iron-{main,special,sv}.fits
        if args.fastphot:
            fastprefix = 'fastphot'
        else:
            fastprefix = 'fastspec'

        fastfiles_to_merge = None
        if mergeall_subset:
            from glob import glob
            if args.mergedir is None:
                mergedir = os.path.join(outdir_data, args.specprod, 'catalogs')
            else:
                mergedir = args.mergedir

            if args.mergeall_main:
                args.merge_suffix = f'{args.specprod}-main'
                fastfiles_to_merge = sorted(glob(os.path.join(mergedir, f'{fastprefix}-{args.specprod}-main-*.fits')))
            elif args.mergeall_special:
                args.merge_suffix = f'{args.specprod}-special'
                fastfiles_to_merge = sorted(glob(os.path.join(mergedir, f'{fastprefix}-{args.specprod}-special-*.fits')))
            else:
                # --mergeall-sv: everything for this specprod except the main
                # and special catalogs and the existing super-merge catalogs.
                args.merge_suffix = f'{args.specprod}-sv'
                fastfiles_to_merge = sorted(list(set(glob(os.path.join(mergedir, f'{fastprefix}-{args.specprod}-*.fits'))) -
                                                 set(glob(os.path.join(mergedir, f'{fastprefix}-{args.specprod}-main.fits'))) -
                                                 set(glob(os.path.join(mergedir, f'{fastprefix}-{args.specprod}-special.fits'))) -
                                                 set(glob(os.path.join(mergedir, f'{fastprefix}-{args.specprod}-special-*.fits'))) -
                                                 set(glob(os.path.join(mergedir, f'{fastprefix}-{args.specprod}-main-*.fits'))) -
                                                 set(glob(os.path.join(mergedir, f'{fastprefix}-{args.specprod}-sv.fits')))))

        if args.samplefile is not None:
            merge_fastspecfit(specprod=args.specprod, specprod_dir=specprod_dir, coadd_type='healpix',
                              sample=sample, merge_suffix=args.merge_suffix,
                              outdir_data=outdir_data, fastfiles_to_merge=fastfiles_to_merge,
                              outsuffix=args.merge_suffix, mergedir=args.mergedir, overwrite=args.overwrite,
                              fastphot=args.fastphot, supermerge=args.mergeall, mp=args.mp)
        else:
            merge_fastspecfit(specprod=args.specprod, specprod_dir=specprod_dir, coadd_type=args.coadd_type,
                              survey=args.survey, program=args.program, healpix=args.healpix,
                              tile=args.tile, night=args.night, outdir_data=outdir_data,
                              fastfiles_to_merge=fastfiles_to_merge,
                              outsuffix=args.merge_suffix, mergedir=args.mergedir, overwrite=args.overwrite,
                              fastphot=args.fastphot, supermerge=args.mergeall, mp=args.mp)
        return

    if args.plan and not args.makeqa:
        # Planning only: rank 0 calls plan() and nothing is run.
        if rank == 0:
            if args.samplefile is not None:
                plan(comm=comm, specprod=args.specprod, specprod_dir=specprod_dir,
                     sample=sample, coadd_type='healpix', makeqa=args.makeqa,
                     mp=args.mp, fastphot=args.fastphot,
                     outdir_data=outdir_data, overwrite=args.overwrite)
            else:
                plan(comm=comm, specprod=args.specprod, specprod_dir=specprod_dir,
                     coadd_type=args.coadd_type, survey=args.survey, program=args.program,
                     healpix=args.healpix, tile=args.tile, night=args.night,
                     makeqa=args.makeqa, mp=args.mp, fastphot=args.fastphot,
                     outdir_data=outdir_data, overwrite=args.overwrite)
    else:
        run_fastspecfit(args, comm=comm, fastphot=args.fastphot, specprod_dir=specprod_dir,
                        makeqa=args.makeqa, outdir_data=outdir_data, samplefile=args.samplefile,
                        input_redshifts=args.input_redshifts, templates=args.templates,
                        templateversion=args.templateversion, fphotodir=args.fphotodir,
                        fphotofile=args.fphotofile)

# Run the command-line driver only when this file is executed as a script.
if __name__ == '__main__':
    main()
