#!/usr/bin/env python
"""
MPI wrapper for fastphot and fastspec.

"""
import os, time
import numpy as np
from astropy.table import Table

from fastspecfit.logger import log


def build_cmdargs(args, redrockfile, outfile, sample=None, fastphot=False,
                  input_redshifts=False):
    """Build the command name, argument string, and log-file path for one file.

    Parameters
    ----------
    args : :class:`argparse.Namespace`
        Parsed options of this wrapper. The attributes read here are
        `makeqa`, `mp`, `ignore_quasarnet`, `ignore_photometry`,
        `no_smooth_continuum`, `templates`, `templateversion`, `fphotodir`,
        `fphotofile`, `nmonte`, `seed`, `verbose`, `ntargets`, `firsttarget`,
        and `targetids`.
    redrockfile : :class:`str`
        Input file to fit. When `args.makeqa` is set this is instead the
        location passed to `-o` and the directory holding the log file.
    outfile : :class:`str`
        Output file. When `args.makeqa` is set this is instead the positional
        input handed to `fastqa`.
    sample : table-like, optional
        If given, the TARGETIDs to process are the rows whose SURVEY, PROGRAM,
        and HEALPIX match the values parsed from the basename of
        `redrockfile`; this takes precedence over `args.targetids`.
    fastphot : :class:`bool`, optional
        Select `fastphot` rather than `fastspec` (ignored with `args.makeqa`).
    input_redshifts : :class:`bool`, optional
        Only used with `sample`; also forward the matching `Z` values.

    Returns
    -------
    cmd : :class:`str`
        One of `fastqa`, `fastphot`, or `fastspec`.
    cmdargs : :class:`str`
        Space-separated argument string for that command.
    logfile : :class:`str`
        Path of the log file associated with this call.

    """
    # With --makeqa the desired output directories are in the 'redrockfiles'.
    if args.makeqa:
        cmd = 'fastqa'
        cmdargs = f'{outfile} -o={redrockfile} --mp={args.mp}'
    else:
        cmd = 'fastphot' if fastphot else 'fastspec'
        cmdargs = f'{redrockfile} -o={outfile} --mp={args.mp}'

        if args.ignore_quasarnet:
            cmdargs += ' --ignore-quasarnet'
        # Each flag is controlled by its own option (this one used to be
        # tied to --ignore-quasarnet by mistake).
        if args.ignore_photometry:
            cmdargs += ' --ignore-photometry'
        if args.no_smooth_continuum:
            cmdargs += ' --no-smooth-continuum'
        if args.templates:
            cmdargs += f' --templates={args.templates}'
        if args.templateversion:
            cmdargs += f' --templateversion={args.templateversion}'
        if args.fphotodir:
            cmdargs += f' --fphotodir={args.fphotodir}'
        if args.fphotofile:
            cmdargs += f' --fphotofile={args.fphotofile}'
        # Compare against None so that an explicit 0 is forwarded rather
        # than silently replaced by the downstream default.
        if args.nmonte is not None:
            cmdargs += f' --nmonte={args.nmonte}'
        if args.seed is not None:
            cmdargs += f' --seed={args.seed}'

    if args.verbose:
        cmdargs += ' --verbose'
    if args.ntargets is not None:
        cmdargs += f' --ntargets={args.ntargets}'
    if args.firsttarget is not None:
        cmdargs += f' --firsttarget={args.firsttarget}'

    if sample is not None:
        # assume healpix coadds; find the targetids to process -- fragile:
        # the basename must split into exactly four '-'-separated pieces,
        # the last one starting with the healpix number.
        _, survey, program, healpix = os.path.basename(redrockfile).split('-')
        healpix = int(healpix.split('.')[0])
        I = (sample['SURVEY'] == survey) & (sample['PROGRAM'] == program) & (sample['HEALPIX'] == healpix)
        targetids = ','.join(sample[I]['TARGETID'].astype(str))
        cmdargs += f' --targetids={targetids}'
        if input_redshifts:
            inputz = ','.join(sample[I]['Z'].astype(str))
            cmdargs += f' --input-redshifts={inputz}'
    elif args.targetids is not None:
        cmdargs += f' --targetids={args.targetids}'

    # The log file is named after the output file: '.gz' is dropped and
    # '.fits' becomes '.log'.
    if args.makeqa:
        logfile = os.path.join(redrockfile, os.path.basename(outfile).replace('.gz', '').replace('.fits', '.log'))
    else:
        logfile = outfile.replace('.gz', '').replace('.fits', '.log')

    return cmd, cmdargs, logfile


def run_fastspecfit(args, comm=None, fastphot=False, specprod_dir=None, makeqa=False,
                    sample=None, input_redshifts=False, outdir_data='.', templates=None,
                    templateversion=None, fphotodir=None, fphotofile=None):
    """Plan the work, distribute it over the ranks, and run fastspec,
    fastphot, or fastqa on every file assigned to this rank.

    Parameters
    ----------
    args : :class:`argparse.Namespace`
        Parsed options of this wrapper; forwarded to `plan` and
        `build_cmdargs`, and consulted for `makeqa`, `dry_run`, `nolog`, and
        `overwrite`.
    comm : MPI communicator, optional
        If given, rank 0 partitions the files (weighted by their number of
        targets) and sends each other rank its share. Without it everything
        runs in this process.
    fastphot : :class:`bool`, optional
        Run `fastphot` rather than `fastspec` (ignored with `args.makeqa`).
    specprod_dir : :class:`str`, optional
        Forwarded to `plan`.
    sample : table-like, optional
        Forwarded to `plan` and `build_cmdargs`.
    input_redshifts : :class:`bool`, optional
        Forwarded to `build_cmdargs`.
    outdir_data : :class:`str`, optional
        Forwarded to `plan`.
    makeqa, templates, templateversion, fphotodir, fphotofile : optional
        Accepted for backward compatibility but not used here; the
        corresponding values are read from `args`.

    Returns
    -------
    None

    Notes
    -----
    A failure while processing one file is logged and the loop continues
    with the next file, so that every rank still reaches the final barrier.

    """
    import traceback
    from desispec.parallel import stdouterr_redirected, weighted_partition
    from fastspecfit.mpi import plan

    if comm:
        rank, size = comm.rank, comm.size
    else:
        rank, size = 0, 1

    if rank == 0:
        t0 = time.time()
    _, all_redrockfiles, all_outfiles, all_ntargets = plan(
        comm=comm, specprod=args.specprod, specprod_dir=specprod_dir,
        coadd_type=args.coadd_type, survey=args.survey, program=args.program,
        healpix=args.healpix, tile=args.tile, night=args.night,
        makeqa=args.makeqa, mp=args.mp, fastphot=fastphot,
        outdir_data=outdir_data, overwrite=args.overwrite, sample=sample)

    # If no work is left to do, let the other ranks know so they can return
    # politely. Only rank 0 inspects the plan.
    alldone = False
    if rank == 0:
        log.info(f'Planning took {time.time() - t0:.2f} sec')
        alldone = len(all_redrockfiles) == 0

    if comm:
        alldone = comm.bcast(alldone, root=0)

    if alldone:
        return

    if comm:
        if rank == 0:
            groups = weighted_partition(all_ntargets, size)
            for irank in range(1, size):
                log.debug(f'Rank {rank} sending work to rank {irank}')
                comm.send(all_redrockfiles[groups[irank]], dest=irank, tag=1)
                comm.send(all_outfiles[groups[irank]], dest=irank, tag=2)
                comm.send(all_ntargets[groups[irank]], dest=irank, tag=3)
            # rank 0 gets work, too
            redrockfiles = all_redrockfiles[groups[0]]
            outfiles = all_outfiles[groups[0]]
            ntargets = all_ntargets[groups[0]]
        else:
            log.debug(f'Rank {rank}: received work from rank 0')
            redrockfiles = comm.recv(source=0, tag=1)
            outfiles = comm.recv(source=0, tag=2)
            ntargets = comm.recv(source=0, tag=3)
    else:
        redrockfiles = all_redrockfiles
        outfiles = all_outfiles
        ntargets = all_ntargets

    # loop on each file
    for redrockfile, outfile, ntarget in zip(redrockfiles, outfiles, ntargets):
        if rank == 0:
            log.debug(f'Rank {rank} started at {time.asctime()}')

        # Imported lazily so that a rank without any work never loads the
        # fitting code; repeated imports are served from the module cache.
        if args.makeqa:
            from fastspecfit.qa import fastqa as fast
        elif fastphot:
            from fastspecfit.fastspecfit import fastphot as fast
        else:
            from fastspecfit.fastspecfit import fastspec as fast

        cmd, cmdargs, logfile = build_cmdargs(args, redrockfile, outfile, sample=sample,
                                              fastphot=fastphot, input_redshifts=input_redshifts)

        if rank == 0:
            log.info(f'Rank {rank}: ntargets={ntarget}: {cmd} {cmdargs}')

        if args.dry_run:
            continue

        try:
            t1 = time.time()
            # A log file without a directory component lives in the current
            # directory; os.makedirs('') would raise.
            outdir = os.path.dirname(logfile)
            if outdir:
                os.makedirs(outdir, exist_ok=True)

            if args.nolog:
                err = fast(args=cmdargs.split(), comm=None)
            else:
                with stdouterr_redirected(to=logfile, overwrite=args.overwrite, comm=None):
                    err = fast(args=cmdargs.split(), comm=None)

            log.info(f'Rank {rank} done in {(time.time() - t1)/60.:.2f} min')
            if err != 0 and not os.path.exists(outfile):
                log.warning(f'Rank {rank} missing {outfile}')
                raise IOError(f'Missing output file {outfile}')
        except (Exception, SystemExit):
            # Best effort: a failure (or an exit request raised by the wrapped
            # entry point) on one file must not stop this rank, otherwise the
            # other ranks would wait forever at the barrier below.
            # KeyboardInterrupt is deliberately not caught.
            log.warning(f'Rank {rank} raised an exception')
            traceback.print_exc()

    if rank == 0:
        log.debug(f'Rank {rank} is done')

    if comm:
        comm.barrier()

    if rank == 0 and not args.dry_run:
        # Every rank has finished, so report on all planned outputs and not
        # just the subset rank 0 processed itself.
        for outfile in all_outfiles:
            if not os.path.exists(outfile):
                log.warning(f'Missing {outfile}')

        log.info(f'All done at {time.asctime()}')


def main():
    """Command-line entry point of the MPI wrapper on fastphot and fastspec.

    Parses the command line, sets up MPI (unless `--nompi` is given or
    mpi4py cannot be imported), optionally reads a sample file on rank 0 and
    shares it with the other ranks, and then does exactly one of:

    * merge existing catalogs (`--merge` and the `--mergeall*` options);
    * print the processing plan (`--plan`);
    * run the fits or the QA through `run_fastspecfit`, optionally writing
      per-rank profiling files (`--profile`).

    Raises
    ------
    NotImplementedError
        If `--samplefile` is combined with a `--coadd-type` other than
        `healpix`.
    ValueError
        If the sample file cannot be read with the required columns.

    """
    import argparse
    from fastspecfit.mpi import plan

    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--coadd-type', type=str, default='healpix', choices=['healpix', 'cumulative', 'pernight', 'perexp'],
                        help='Specify which type of spectra/zbest files to process.')
    parser.add_argument('--specprod', type=str, default='loa', help='Spectroscopic production to process.')

    parser.add_argument('--healpix', type=str, default=None, help='Comma-separated list of healpixels to process.')
    parser.add_argument('--survey', type=str, default='main,special,cmx,sv1,sv2,sv3', help='Survey to process.')
    parser.add_argument('--program', type=str, default='bright,dark,other,backup', help='Program to process.') # backup not supported
    parser.add_argument('--tile', default=None, type=str, nargs='*', help='Tile(s) to process.')
    parser.add_argument('--night', default=None, type=str, nargs='*', help='Night(s) to process (ignored if coadd-type is cumulative).')

    parser.add_argument('--samplefile', default=None, type=str, help='Full path to sample (FITS) file with {survey,program,healpix,targetid}.')
    parser.add_argument('--input-redshifts', action='store_true', help='Only used with --samplefile; if set, fit with redshift "Z" values.')
    parser.add_argument('--input-seeds', type=str, default=None, help='Comma-separated list of input random-number seeds corresponding to the (required) --targetids input.')
    parser.add_argument('--nmonte', type=int, default=10, help='Number of Monte Carlo realizations.')
    parser.add_argument('--seed', type=int, default=1, help='Random seed for Monte Carlo reproducibility; ignored if --input-seeds is passed.')

    parser.add_argument('--mp', type=int, default=1, help='Number of multiprocessing processes per MPI rank or node.')
    parser.add_argument('-n', '--ntargets', type=int, help='Number of targets to process in each file.')
    parser.add_argument('--firsttarget', type=int, default=0, help='Index of first object to to process in each file, zero-indexed.')
    parser.add_argument('--targetids', type=str, default=None, help='Comma-separated list of TARGETIDs to process.')

    parser.add_argument('--fastphot', action='store_true', help='Fit the broadband photometry.')

    parser.add_argument('--templateversion', type=str, default=None, help='Template version number.')
    parser.add_argument('--templates', type=str, default=None, help='Optional full path and filename to the templates.')
    parser.add_argument('--fphotodir', type=str, default=None, help='Top-level location of the source photometry.')
    parser.add_argument('--fphotofile', type=str, default=None, help='Photometric information file.')

    parser.add_argument('--merge', action='store_true', help='Merge all individual catalogs (for a given survey and program) into one large file.')
    parser.add_argument('--mergedir', type=str, help='Output directory for merged catalogs.')
    parser.add_argument('--merge-suffix', type=str, help='Filename suffix for merged catalogs.')
    parser.add_argument('--mergeall', action='store_true', help='Merge all the individual merged catalogs into a single merged catalog.')
    parser.add_argument('--mergeall-main', action='store_true', help='Merge all the main catalogs.')
    parser.add_argument('--mergeall-sv', action='store_true', help='Merge all the SV catalogs.')
    parser.add_argument('--mergeall-special', action='store_true', help='Merge all the special catalogs.')
    parser.add_argument('--makeqa', action='store_true', help='Build QA in parallel.')

    parser.add_argument('--ignore-quasarnet', default=False, action='store_true', help='Do not use QuasarNet to improve QSO redshifts.')
    parser.add_argument('--ignore-photometry', default=False, action='store_true', help='Ignore the broadband photometry during model fitting.')
    parser.add_argument('--no-smooth-continuum', default=False, action='store_true', help='Do not fit the smooth continuum.')

    parser.add_argument('--verbose', action='store_true', help='More verbose output.')
    parser.add_argument('--overwrite', action='store_true', help='Overwrite any existing output files.')
    parser.add_argument('--plan', action='store_true', help='Plan how many nodes to use and how to distribute the targets.')
    parser.add_argument('--profile', action='store_true', help='Write out profiling / timing files..')
    parser.add_argument('--nompi', action='store_true', help='Do not use MPI parallelism.')
    parser.add_argument('--nolog', action='store_true', help='Do not write to the log file.')
    parser.add_argument('--dry-run', action='store_true', help='Generate but do not run commands.')

    parser.add_argument('--outdir-data', default='$PSCRATCH/fastspecfit/data', type=str, help='Base output data directory.')

    args = parser.parse_args()

    specprod_dir = None
    outdir_data = os.path.expanduser(os.path.expandvars(args.outdir_data))

    if args.nompi:
        comm = None
    else:
        try:
            from mpi4py import MPI
            # needed when profiling; no effect otherwise
            # https://docs.linaroforge.com/24.0.6/html/forge/map/python_profiling/profile_python_script.html
            #MPI.Init_thread(MPI.THREAD_SINGLE)
            comm = MPI.COMM_WORLD
        except ImportError:
            # Fall back to serial processing when mpi4py is not installed.
            comm = None

    if comm:
        rank = comm.rank
    else:
        rank = 0
        # https://docs.nersc.gov/development/languages/python/parallel-python/#use-the-spawn-start-method
        if args.mp > 1 and 'NERSC_HOST' in os.environ:
            import multiprocessing
            multiprocessing.set_start_method('spawn')

    # If an input samplefile is provided, read and broadcast it.
    if args.samplefile is None:
        sample = None
    else:
        if args.coadd_type != 'healpix':
            errmsg = 'Input --samplefile is only currently compatible with --coadd-type="healpix"'
            log.critical(errmsg)
            raise NotImplementedError(errmsg)

        # The redshift column is only needed (and only required to exist)
        # when the fits are to use the input redshifts.
        columns = ['SURVEY', 'PROGRAM', 'HEALPIX', 'TARGETID']
        if args.input_redshifts:
            columns.append('Z')

        # Rank 0 reads the file. The outcome -- including a failure -- is
        # broadcast so that all ranks take the same exit path and none is
        # left waiting in the broadcast.
        sample, errmsg, cause = None, None, None
        if rank == 0:
            import fitsio
            if not os.path.isfile(args.samplefile):
                log.warning(f'{args.samplefile} does not exist.')
            else:
                try:
                    sample = Table(fitsio.read(args.samplefile, columns=columns))
                except Exception as err:
                    cause = err
                    if args.input_redshifts:
                        errmsg = f'Sample file {args.samplefile} with --input-redshifts missing required columns ' + \
                            '{SURVEY,PROGRAM,HEALPIX,TARGETID,Z}'
                    else:
                        errmsg = f'Sample file {args.samplefile} missing required columns ' + \
                            '{SURVEY,PROGRAM,HEALPIX,TARGETID}'
                    log.critical(errmsg)
                else:
                    log.info(f'Read {len(sample)} rows from {args.samplefile}')

        # Without a communicator (--nompi or no mpi4py) there is nothing to
        # broadcast to.
        if comm:
            sample, errmsg = comm.bcast((sample, errmsg), root=0)

        if errmsg is not None:
            raise ValueError(errmsg) from cause
        if sample is None:
            # The sample file does not exist; already reported by rank 0.
            return

    # Parse some of the inputs.
    if args.samplefile is None and args.coadd_type == 'healpix':
        args.survey = args.survey.split(',')
        args.program = args.program.split(',')
        if args.healpix is not None:
            args.healpix = args.healpix.split(',')

    # Any of the specialized merge-all flags implies --mergeall.
    mergeall_subset = args.mergeall_main or args.mergeall_sv or args.mergeall_special
    if mergeall_subset:
        args.mergeall = True

    if args.merge or args.mergeall:
        from fastspecfit.mpi import merge_fastspecfit

        # convenience code to make the super-merge catalogs, e.g., fastspec-iron-{main,special,sv}.fits
        if args.fastphot:
            fastprefix = 'fastphot'
        else:
            fastprefix = 'fastspec'

        fastfiles_to_merge = None
        if mergeall_subset:
            from glob import glob
            if args.mergedir is None:
                mergedir = os.path.join(outdir_data, args.specprod, 'catalogs')
            else:
                mergedir = args.mergedir

            def _catalogs(suffix):
                # Set of catalogs in mergedir named <prefix>-<specprod>-<suffix>.fits
                return set(glob(os.path.join(mergedir, f'{fastprefix}-{args.specprod}-{suffix}.fits')))

            # The flags are checked in the order main, special, sv; the
            # first one that is set wins.
            if args.mergeall_main:
                args.merge_suffix = f'{args.specprod}-main'
                fastfiles_to_merge = sorted(_catalogs('main-*'))
            elif args.mergeall_special:
                args.merge_suffix = f'{args.specprod}-special'
                fastfiles_to_merge = sorted(_catalogs('special-*'))
            else:
                # SV: every catalog of this production except the main and
                # special ones and the previously written super-merge files.
                args.merge_suffix = f'{args.specprod}-sv'
                fastfiles_to_merge = sorted(
                    _catalogs('*') - _catalogs('main') - _catalogs('special') -
                    _catalogs('special-*') - _catalogs('main-*') - _catalogs('sv'))

        if args.samplefile is not None:
            merge_fastspecfit(specprod=args.specprod, specprod_dir=specprod_dir, coadd_type='healpix',
                              sample=sample, merge_suffix=args.merge_suffix,
                              outdir_data=outdir_data, fastfiles_to_merge=fastfiles_to_merge,
                              outsuffix=args.merge_suffix, mergedir=args.mergedir, overwrite=args.overwrite,
                              fastphot=args.fastphot, supermerge=args.mergeall, mp=args.mp)
        else:
            merge_fastspecfit(specprod=args.specprod, specprod_dir=specprod_dir, coadd_type=args.coadd_type,
                              survey=args.survey, program=args.program, healpix=args.healpix,
                              tile=args.tile, night=args.night, outdir_data=outdir_data,
                              fastfiles_to_merge=fastfiles_to_merge, outsuffix=args.merge_suffix,
                              mergedir=args.mergedir, overwrite=args.overwrite,
                              fastphot=args.fastphot, supermerge=args.mergeall, mp=args.mp,
                              split_hdu=True)
        return


    if args.plan:
        plan(comm=comm, specprod=args.specprod, specprod_dir=specprod_dir,
             coadd_type=args.coadd_type, survey=args.survey, program=args.program,
             healpix=args.healpix, tile=args.tile, night=args.night,
             makeqa=args.makeqa, mp=args.mp, fastphot=args.fastphot,
             outdir_data=outdir_data, overwrite=args.overwrite,
             sample=sample)
    else:
        if args.profile:
            import cProfile
            import pstats

            profiler = cProfile.Profile()
            profiler.enable()

        run_fastspecfit(args, comm=comm, fastphot=args.fastphot, specprod_dir=specprod_dir,
                        makeqa=args.makeqa, outdir_data=outdir_data, sample=sample,
                        input_redshifts=args.input_redshifts, templates=args.templates,
                        templateversion=args.templateversion, fphotodir=args.fphotodir,
                        fphotofile=args.fphotofile)

        if args.profile:
            profiler.disable()

            # Make sure the profile can be written even if nothing else has
            # created the output directory yet (e.g. with --dry-run).
            os.makedirs(outdir_data, exist_ok=True)

            outfile = os.path.join(outdir_data, f'profile_rank{rank}.prof')
            log.info(f'Writing {outfile}')
            profiler.dump_stats(outfile)
            with open(os.path.join(outdir_data, f'profile_rank{rank}.txt'), 'w') as F:
                stats = pstats.Stats(profiler, stream=F)
                stats.strip_dirs()
                stats.sort_stats('cumtime')
                stats.print_stats()
                stats.print_callees()

    # comm is only set when the mpi4py import above succeeded, so MPI is bound.
    if comm:
        MPI.Finalize()


# Run the command-line wrapper only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
