#!/usr/bin/env python

"""MPI wrapper to get a large number of image cutouts. Only run ahead of a VAC
release.

shifterimg pull dstndstn/cutouts:dvsro
shifter --image dstndstn/cutouts:dvsro cutout --output cutout.jpg --ra 234.2915 --dec 16.7684 --size 256 --layer ls-dr9 --pixscale 0.262 --force

shifter --image dstndstn/cutouts:dvsro bash
$HOME/code/desihub/fastspecfit/bin/get-cutouts --survey cmx --program other --mp 1 --dry-run

"""
import pdb # for debugging

import os, sys, time
import numpy as np
import fitsio
import multiprocessing
from glob import glob

def _cutout_one(args):
    """Multiprocessing wrapper: unpack one argument tuple and call `cutout_one`.

    """
    result = cutout_one(*args)
    return result

def cutout_one(jpegfile, ra, dec, dry_run, rank, iobj):
    """Write one ls-dr9 JPEG cutout, or an all-black placeholder if that fails.

    The cutout size is hard-coded to 114 x 87 pixels, which was derived from

        pixscale = 0.262
        width = int(30 / pixscale)   # =114
        height = int(width / 1.3)    # =87

    Args:
        jpegfile: output filename.
        ra, dec: coordinates passed straight through to `cutout`.
        dry_run: if True, only print the equivalent command-line arguments.
        rank: MPI rank; only used in the log messages.
        iobj: object index; only used in the log messages.

    Returns:
        None. The result is the file written to `jpegfile`.

    """
    width = 114
    height = 87

    cmdargs = f'--ra={ra} --dec={dec} --output={jpegfile} --width={width} --height={height} --layer=ls-dr9 --pixscale=0.262 --force'
    if dry_run:
        print(f'Rank {rank}, object {iobj}: cutout {cmdargs}')
        return

    # Imported lazily so that a dry run does not need the cutout package.
    from cutout import cutout

    try:
        cutout(ra, dec, jpegfile, width=width, height=height, layer='ls-dr9', pixscale=0.262, force=True)
    except (Exception, SystemExit) as err:
        # Any failure is treated as "off the footprint", as before, but
        # KeyboardInterrupt is now allowed to propagate. SystemExit is still
        # caught in case `cutout` exits like a command-line tool (unverified).
        print(f'WARNING: Rank {rank}, object {iobj} off the footprint: cutout {cmdargs}')
        print(f'WARNING: Rank {rank}, object {iobj}: cutout failed with {err!r}')

        # Write an all-black image of the same shape as a placeholder.
        import matplotlib.pyplot as plt
        fig, ax = plt.subplots()
        try:
            ax.imshow(np.zeros((height, width, 3)), origin='lower')
            ax.axis('off')
            fig.savefig(jpegfile, bbox_inches='tight', pad_inches=0)
        finally:
            # Without this, one figure per failed object stays alive in pyplot.
            plt.close(fig)
    else:
        print(f'Rank {rank}, object {iobj}: cutout {cmdargs}')

    if not os.path.isfile(jpegfile):
        print(f'WARNING: Missing {jpegfile}')

def _get_jpegfiles_one(args):
    """Multiprocessing wrapper: unpack one argument tuple and call `get_jpegfiles_one`.

    """
    result = get_jpegfiles_one(*args)
    return result

def get_jpegfiles_one(fastspecfile, coadd_type, htmldir_root, overwrite, verbose):
    """Work out which cutouts still need to be made for one fastspec file.

    Args:
        fastspecfile: FITS file with a METADATA extension.
        coadd_type: passed through to `get_cutout_filename`.
        htmldir_root: root of the output tree; the per-file subdirectory is
          built from SURVEY, PROGRAM and HEALPIX of the first row and is
          created if it does not exist.
        overwrite: if False (exactly), drop targets whose cutout already exists.
        verbose: if True, report how many existing files were skipped.

    Returns:
        (jpegfiles, allra, alldec, ntargets) where the first three are arrays
        restricted to the cutouts still to be made and `ntargets` is their
        length.

    """
    columns = ['SURVEY', 'PROGRAM', 'HEALPIX', 'TARGETID', 'RA', 'DEC']
    meta = fitsio.read(fastspecfile, 'METADATA', columns=columns)

    # The output directory is taken from the first row only.
    survey = meta['SURVEY'][0]
    program = meta['PROGRAM'][0]
    pixgroup = str(meta['HEALPIX'][0]//100)
    pixel = str(meta['HEALPIX'][0])
    htmldir = os.path.join(htmldir_root, survey, program, pixgroup, pixel)
    if not os.path.isdir(htmldir):
        os.makedirs(htmldir, exist_ok=True)

    jpegfiles = get_cutout_filename(meta, coadd_type, outprefix='fastspec',
                                    htmldir=htmldir)
    allra = meta['RA']
    alldec = meta['DEC']

    if overwrite is False:
        missing = np.array([not os.path.isfile(onefile) for onefile in jpegfiles])
        existing = ~missing
        nexisting = np.sum(existing)
        if nexisting > 0:
            if verbose:
                print(f'Skipping {nexisting} existing QA file(s) from fastspecfile {fastspecfile}.')
            jpegfiles, allra, alldec = jpegfiles[missing], allra[missing], alldec[missing]

    ntargets = len(jpegfiles)
    return jpegfiles, allra, alldec, ntargets

def get_cutout_filename(metadata, coadd_type, outprefix=None, htmldir=None):
    """Build the cutout filename.

    """
    import astropy.table

    if htmldir is None:
        htmldir = '.'
        
    if outprefix is None:
        outprefix = 'fastspec'

    def _one_filename(_metadata):
        if coadd_type == 'healpix':
            jpegfile = os.path.join(htmldir, 'tmp.{}-{}-{}-{}-{}.jpeg'.format(
                outprefix, _metadata['SURVEY'], _metadata['PROGRAM'],
                _metadata['HEALPIX'], _metadata['TARGETID']))
        elif coadd_type == 'cumulative':
            jpegfile = os.path.join(htmldir, 'tmp.{}-{}-{}-{}.jpeg'.format(
                outprefix, _metadata['TILEID'], coadd_type, _metadata['TARGETID']))
        elif coadd_type == 'pernight':
            jpegfile = os.path.join(htmldir, 'tmp.{}-{}-{}-{}.jpeg'.format(
                outprefix, _metadata['TILEID'], _metadata['NIGHT'], _metadata['TARGETID']))
        elif coadd_type == 'perexp':
            jpegfile = os.path.join(htmldir, 'tmp.{}-{}-{}-{}-{}.jpeg'.format(
                outprefix, _metadata['TILEID'], _metadata['NIGHT'],
                _metadata['EXPID'], _metadata['TARGETID']))
        elif coadd_type == 'custom':
            jpegfile = os.path.join(htmldir, 'tmp.{}-{}-{}-{}-{}.jpeg'.format(
                outprefix, _metadata['SURVEY'], _metadata['PROGRAM'],
                _metadata['HEALPIX'], _metadata['TARGETID']))
        else:
            errmsg = 'Unrecognized coadd_type {}!'.format(coadd_type)
            print(errmsg)
            raise ValueError(errmsg)
        return jpegfile

    if type(metadata) is astropy.table.row.Row:
        jpegfile = _one_filename(metadata)
    else:
        jpegfile = [_one_filename(_metadata) for _metadata in metadata]
    
    return np.array(jpegfile)

def weighted_partition(weights, n, groups_per_node=None):
    '''
    Split the indices of `weights` into `n` groups of roughly equal total weight.

    Args:
        weights: array-like weights
        n: number of groups

    Options:
        groups_per_node: if given, reorder the groups so that consecutive
          (heaviest-first) groups land on different nodes, assuming group `j`
          runs on node `j // groups_per_node`.

    Returns list of lists of indices of weights for each group

    Notes:
        Items are placed greedily, heaviest first, so the members of a group
        need not be contiguous in `weights`.
    '''
    #- running total of the weight assigned to each group
    totals = np.zeros(n, dtype=float)

    #- one (initially empty) list of indices per group
    groups = [[] for _ in range(n)]

    #- Heaviest item first, always into the group that is currently lightest
    weights = np.asarray(weights)
    for iweight in np.argsort(-weights):
        ilightest = np.argmin(totals)
        groups[ilightest].append(iweight)
        totals[ilightest] += weights[iweight]

    assert len(groups) == n

    if groups_per_node is None:
        return groups

    #- Reorder groups to spread out large items across different nodes.
    #- NOTE: this isn't perfect, e.g. study
    #-   weighted_partition(np.arange(12), 6, groups_per_node=2)
    #- zigzagging back and forth across the nodes would balance better than
    #- always looping over them in the same direction.
    distributed = [None] * len(groups)
    num_nodes = (n + groups_per_node - 1) // groups_per_node
    inext = 0
    for noderank in range(groups_per_node):
        for inode in range(num_nodes):
            islot = inode*groups_per_node + noderank
            if inext < n and islot < n:
                distributed[islot] = groups[inext]
                inext += 1

    #- do a final check that all groups were assigned
    for islot, group in enumerate(distributed):
        assert group is not None, 'group {} not set'.format(islot)

    return distributed


def plan(comm=None, specprod=None, coadd_type='healpix',
         survey=None, program=None, healpix=None, tile=None, night=None,
         outdir_data='.', outdir_html='.', mp=1, overwrite=False):
    """Find the fastspec catalogs and distribute the missing cutouts over ranks.

    Args:
        comm: MPI communicator or None; only its `rank` and `size` are used.
        specprod: spectroscopic production name (a path component).
        coadd_type: 'healpix', 'cumulative', 'pernight' or 'perexp'. Any other
          value finds no files.
        survey, program: survey(s) and program(s) to scan (healpix coadds).
        healpix: comma-separated string of healpixels, or None for all.
        tile: tile(s) to scan (tile-based coadds). If None, the SV tiles are
          read from the production's tiles-{specprod}.csv file.
        night: night(s) to scan (pernight coadds), or None for all.
        outdir_data: root of the fastspec output tree.
        outdir_html: currently unused.
          NOTE(review): the HTML tree is rooted at `outdir_data`, not here -
          confirm this is intended.
        mp: number of multiprocessing workers used to inspect the files.
        overwrite: if False, cutouts which already exist are skipped.

    Returns:
        (jpegfiles, allra, alldec, groups). The first three hold one entry per
        fastspec file which still has cutouts to make, and `groups` holds, for
        each rank, the list of indices into them. Four empty lists are
        returned if no fastspec files are found.

    """
    if comm is None:
        rank, size = 0, 1
    else:
        rank, size = comm.rank, comm.size

    desi_root = os.environ.get('DESI_ROOT', '/dvs_ro/cfs/cdirs/desi')

    # look for data in the standard location
    healpixels = None
    if coadd_type == 'healpix':
        subdir = 'healpix'
        if healpix is not None:
            healpixels = healpix.split(',')
    else:
        subdir = 'tiles'

        # figure out which tiles belong to the SV programs
        if tile is None:
            # astropy is only needed in this branch
            from astropy.table import Table

            tilefile = os.path.join(desi_root, 'spectro', 'redux', specprod, f'tiles-{specprod}.csv')
            alltileinfo = Table.read(tilefile)
            tileinfo = alltileinfo[['sv' in onesurvey for onesurvey in alltileinfo['SURVEY']]]

            print('Retrieved a list of {} {} tiles from {}'.format(
                len(tileinfo), ','.join(sorted(set(tileinfo['SURVEY']))), tilefile))

            tile = np.array(list(set(tileinfo['TILEID'])))

    outdir = os.path.join(outdir_data, specprod, subdir)
    htmldir = os.path.join(outdir_data, specprod, 'html', subdir)

    def _sorted_unique(filelists):
        """Flatten a list of glob results into a sorted array of unique files."""
        return np.array(sorted({onefile for filelist in filelists for onefile in filelist}))

    def _findfiles(filedir, prefix='fastspec', survey=None, program=None, healpix=None,
                   tile=None, night=None, gzip=False):
        """Glob for the catalogs of the requested `coadd_type` under `filedir`."""
        fitssuffix = 'fits.gz' if gzip else 'fits'

        # Result for an unsupported coadd_type or when there is nothing to scan.
        thesefiles = []

        if coadd_type == 'healpix':
            filelists = []
            for onesurvey in np.atleast_1d(survey):
                for oneprogram in np.atleast_1d(program):
                    print(f'Building file list for survey={onesurvey} and program={oneprogram}')
                    basename = '{}-{}-{}-*.{}'.format(prefix, onesurvey, oneprogram, fitssuffix)
                    if healpix is not None:
                        for onepix in healpixels:
                            filelists.append(glob(os.path.join(
                                filedir, onesurvey, oneprogram, str(int(onepix)//100), onepix, basename)))
                    else:
                        for pixgroup in glob(os.path.join(filedir, onesurvey, oneprogram, '*')):
                            filelists.append(glob(os.path.join(pixgroup, '*', basename)))
            thesefiles = _sorted_unique(filelists)
        elif coadd_type == 'cumulative':
            # Scrape the disk to get the tiles, but since we read the csv file
            # this is not expected to happen.
            if tile is None:
                tiledirs = sorted(set(glob(os.path.join(filedir, 'cumulative', '?????'))))
                if len(tiledirs) > 0:
                    tile = [int(os.path.basename(tiledir)) for tiledir in tiledirs]
            if tile is not None:
                filelists = []
                for onetile in tile:
                    nightdirs = sorted(set(glob(os.path.join(filedir, 'cumulative', str(onetile), '????????'))))
                    if len(nightdirs) > 0:
                        # for a given tile, take just the most recent night
                        filelists.append(glob(os.path.join(
                            nightdirs[-1], '{}-[0-9]-{}-thru????????.{}'.format(
                                prefix, onetile, fitssuffix))))
                thesefiles = _sorted_unique(filelists)
        elif coadd_type == 'pernight':
            # A missing tile and/or night selection becomes a wildcard. All
            # four combinations search below 'pernight'.
            tilepatterns = ['?????'] if tile is None else [str(onetile) for onetile in tile]
            nightpatterns = ['????????'] if night is None else [str(onenight) for onenight in night]
            filelists = [glob(os.path.join(
                filedir, 'pernight', onetile, onenight, '{}-[0-9]-{}-{}.{}'.format(
                    prefix, onetile, onenight, fitssuffix)))
                         for onetile in tilepatterns for onenight in nightpatterns]
            thesefiles = _sorted_unique(filelists)
        elif coadd_type == 'perexp':
            tilepatterns = ['?????'] if tile is None else [str(onetile) for onetile in tile]
            filelists = [glob(os.path.join(
                filedir, 'perexp', onetile, '????????', '{}-[0-9]-{}-exp????????.{}'.format(
                    prefix, onetile, fitssuffix)))
                         for onetile in tilepatterns]
            thesefiles = _sorted_unique(filelists)

        return thesefiles

    fastspecfiles = _findfiles(outdir, prefix='fastspec', survey=survey, program=program,
                               healpix=healpix, tile=tile, night=night, gzip=True)
    nfile = len(fastspecfiles)

    if nfile == 0:
        if rank == 0:
            print('No fastspecfiles found!')
        return list(), list(), list(), list()

    print(f'Found {nfile} fastspecfiles.')

    # Work out, file by file, which cutouts are still missing.
    verbose = False
    mpargs = [(fastspecfile, coadd_type, htmldir, overwrite, verbose)
              for fastspecfile in fastspecfiles]
    if mp > 1:
        with multiprocessing.Pool(mp) as P:
            out = P.map(_get_jpegfiles_one, mpargs)
    else:
        out = [get_jpegfiles_one(*mparg) for mparg in mpargs]
    out = list(zip(*out))

    jpegfiles = np.array(out[0], dtype=object)
    allra = np.array(out[1], dtype=object)
    alldec = np.array(out[2], dtype=object)
    ntargets = np.array(out[3])

    iempty = np.where(ntargets == 0)[0]
    if len(iempty) > 0:
        print(f'Skipping {len(iempty)} fastspecfiles with no jpeg files to make.')

    itodo = np.where(ntargets > 0)[0]
    if len(itodo) > 0:
        jpegfiles = jpegfiles[itodo]
        allra = allra[itodo]
        alldec = alldec[itodo]
        ntargets = ntargets[itodo]

        print(f'Missing cutouts for {np.sum(ntargets)} targets.')

        # Balance the number of targets (not the number of files) per rank.
        groups = weighted_partition(ntargets, size)
    else:
        groups = [np.array([])]

    return jpegfiles, allra, alldec, groups

def do_cutouts(args, comm=None, outdir_data='.', outdir_html='.'):
    """Plan on rank 0, share the plan with all ranks, and make the cutouts.

    Args:
        args: parsed command-line arguments (see `main`).
        comm: MPI communicator, or None to run on a single rank.
        outdir_data, outdir_html: output roots, passed through to `plan`.

    """
    rank, size = (0, 1) if comm is None else (comm.rank, comm.size)

    # Only the root rank touches the disk to build the plan.
    t0 = time.time()
    if rank != 0:
        jpegfiles, allra, alldec, groups = [], [], [], []
    else:
        jpegfiles, allra, alldec, groups = plan(
                comm=comm, specprod=args.specprod,
                coadd_type=args.coadd_type,
                survey=args.survey, program=args.program,
                healpix=args.healpix, tile=args.tile, night=args.night,
                outdir_data=outdir_data, outdir_html=outdir_html,
                overwrite=args.overwrite, mp=args.mp)
        print(f'Planning took {(time.time() - t0):.2f} sec')

    if comm:
        jpegfiles, allra, alldec, groups = [
            comm.bcast(item, root=0) for item in (jpegfiles, allra, alldec, groups)]
    sys.stdout.flush()

    # all done
    if len(jpegfiles) == 0 or len(np.hstack(jpegfiles)) == 0:
        return

    assert len(groups) == size

    # Each index in this rank's group is one fastspec file's worth of targets.
    for ifile in groups[rank]:
        print(f'Rank {rank} started at {time.asctime()}')
        sys.stdout.flush()

        mpargs = []
        for iobj, (jpegfile, ra, dec) in enumerate(zip(jpegfiles[ifile], allra[ifile], alldec[ifile])):
            mpargs.append((jpegfile, ra, dec, args.dry_run, rank, iobj))

        if args.mp > 1:
            with multiprocessing.Pool(args.mp) as pool:
                pool.map(_cutout_one, mpargs)
        else:
            for mparg in mpargs:
                cutout_one(*mparg)

    print(f'  rank {rank} is done')
    sys.stdout.flush()

    if comm is not None:
        comm.barrier()

    if rank == 0 and not args.dry_run:
        print(f'All done at {time.asctime()}')
        
def main():
    """Parse the command line, set up MPI, and either plan or make the cutouts.

    """
    import argparse

    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    add = parser.add_argument

    # what to process
    add('--coadd-type', type=str, default='healpix', choices=['healpix', 'cumulative', 'pernight', 'perexp'],
        help='Specify which type of spectra/zbest files to process.')
    add('--specprod', type=str, default='iron',
        help='Spectroscopic production to process.')

    add('--healpix', type=str, default=None, help='Comma-separated list of healpixels to process.')
    add('--survey', type=str, default='main,special,cmx,sv1,sv2,sv3', help='Survey to process.')
    add('--program', type=str, default='bright,dark,other,backup', help='Program to process.')
    add('--tile', default=None, type=str, nargs='*', help='Tile(s) to process.')
    add('--night', default=None, type=str, nargs='*', help='Night(s) to process (ignored if coadd-type is cumulative).')
    add('--mp', type=int, default=1, help='Number of multiprocessing processes per MPI rank or node.')

    # how to process it
    add('--overwrite', action='store_true', help='Overwrite any existing output files.')
    add('--plan', action='store_true', help='Plan how many nodes to use and how to distribute the targets.')
    add('--nompi', action='store_true', help='Do not use MPI parallelism.')
    add('--dry-run', action='store_true', help='Generate but do not run commands.')

    # where to put it
    add('--outdir-html', default='$PSCRATCH/fastspecfit/html', type=str, help='Base output HTML directory.')
    add('--outdir-data', default='$PSCRATCH/fastspecfit/data', type=str, help='Base output data directory.')

    args = parser.parse_args()

    outdir_data = os.path.expandvars(args.outdir_data)
    outdir_html = os.path.expandvars(args.outdir_html)

    # Fall back to a single rank if MPI is disabled or mpi4py is missing.
    comm = None
    if not args.nompi:
        try:
            from mpi4py import MPI
        except ImportError:
            pass
        else:
            comm = MPI.COMM_WORLD

    # https://docs.nersc.gov/development/languages/python/parallel-python/#use-the-spawn-start-method
    if args.mp > 1 and 'NERSC_HOST' in os.environ:
        import multiprocessing
        multiprocessing.set_start_method('spawn')

    # Only the healpix coadds take lists of surveys and programs.
    if args.coadd_type == 'healpix':
        args.survey = args.survey.split(',')
        args.program = args.program.split(',')

    if not args.plan:
        do_cutouts(args, comm=comm, outdir_data=outdir_data, outdir_html=outdir_html)
        return

    # Planning only: a single rank does the work.
    rank = 0 if comm is None else comm.rank
    if rank == 0:
        plan(comm=comm, specprod=args.specprod, coadd_type=args.coadd_type,
             survey=args.survey, program=args.program,
             healpix=args.healpix, tile=args.tile, night=args.night,
             outdir_data=outdir_data, outdir_html=outdir_html,
             overwrite=args.overwrite, mp=args.mp)

# Run the command-line entry point only when this file is executed as a script.
if __name__ == '__main__':
    main()
