#!/usr/bin/env python
# __BEGIN_LICENSE__
#  Copyright (c) 2009-2013, United States Government as represented by the
#  Administrator of the National Aeronautics and Space Administration. All
#  rights reserved.
#
#  The NGT platform is licensed under the Apache License, Version 2.0 (the
#  "License"); you may not use this file except in compliance with the
#  License. You may obtain a copy of the License at
#  http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
# __END_LICENSE__

import sys, argparse, subprocess, re, os, math, time, glob, shutil, math, copy
import os.path as P

# Locate the ASP Python modules relative to this script. sys.path[0] is the
# first entry of the module search path; its sibling directories 'Python'
# and 'libexec' are also added, and the three are tried in the order
# libexec, Python, basepath, since each insert goes to the front.
basepath    = os.path.abspath(sys.path[0])
pythonpath  = os.path.abspath(basepath + '/../Python')  # for dev ASP
libexecpath = os.path.abspath(basepath + '/../libexec') # for packaged ASP
sys.path.insert(0, basepath) # prepend to Python path
sys.path.insert(0, pythonpath)
sys.path.insert(0, libexecpath)

# These project modules can be imported only after the path setup above
import asp_system_utils, asp_cmd_utils
asp_system_utils.verify_python_version_is_supported()

# Prepend to system PATH
os.environ["PATH"] = libexecpath + os.pathsep + os.environ["PATH"]

# spawn_to_nodes() stashes LD_LIBRARY_PATH in ASP_LIBRARY_PATH and blanks
# LD_LIBRARY_PATH before it invokes GNU Parallel, and it asks GNU Parallel
# to forward both variables. When ASP_LIBRARY_PATH is set, restore
# LD_LIBRARY_PATH from it.
if 'ASP_LIBRARY_PATH' in os.environ:
    os.environ['LD_LIBRARY_PATH'] = os.environ['ASP_LIBRARY_PATH']

def get_output_prefix(args):
    '''Return the output prefix given with -o or --output-prefix.

    The short option is looked up first; the long one is consulted only if
    the short lookup comes back empty. Raises an Exception if the lookup
    result has fewer than two entries (the option and its value).
    '''
    found = asp_cmd_utils.get_option(args, '-o', 1) or \
            asp_cmd_utils.get_option(args, '--output-prefix', 1)

    if len(found) < 2:
        raise Exception('Failed to parse the output prefix.')

    # Entry 0 is the option itself, entry 1 is its value
    return found[1]

def get_num_nodes(nodes_list):
    '''Return the number of distinct computing nodes listed in a file.

    nodes_list -- path to a text file with one node per line, or None to
                  run on the local machine only.

    Only the first whitespace-delimited token of each line is used. Blank
    lines are skipped and a node listed several times is counted once
    (need this for Pleiades).

    Returns 1 if nodes_list is None, else the number of distinct nodes.
    An error while opening or reading the file is handed to
    asp_system_utils.die(). Raises an Exception if no nodes were found.
    '''

    if nodes_list is None:
        return 1 # local machine

    nodes = set()
    num_nodes = 0
    try:
        # The 'with' statement closes the file even if reading fails
        with open(nodes_list, "r") as fh:
            for line in fh:
                tokens = line.split()
                if tokens: # skip empty lines
                    nodes.add(tokens[0])
        num_nodes = len(nodes)
    except Exception as e:
        asp_system_utils.die(e)

    if num_nodes == 0:
        raise Exception('The list of computing nodes is empty')

    return num_nodes

def get_best_procs_threads():
    '''Choose the number of processes per node and of threads per process.

    Values the user gave in opt.processes and opt.threads are kept. A
    missing process count defaults to a quarter of the CPU count, and a
    missing thread count to the CPU count divided by the process count,
    rounded up. Too many processes may result in too much memory usage.

    Returns the tuple (num_procs, num_threads); both are at least 1.
    '''

    # We assume all machines have the same number of CPUs (cores)
    cpu_count = asp_system_utils.get_num_cpus()

    # Processes: honor the user's choice, else auto-compute
    procs = int(round(cpu_count/4.0)) if opt.processes is None else opt.processes
    procs = max(procs, 1)

    # Threads: same approach, spreading the cores over the processes
    threads = opt.threads
    if threads is None:
        threads = int(math.ceil(float(cpu_count)/procs))
    threads = max(threads, 1)

    print("Using %d threads and %d processes." % (threads, procs))

    return (procs, threads)

def get_num_images(args):
    '''Count the input images given on the command line or in image lists.

    Without --image-list, the result is the number of arguments ending in
    one of the known image extensions (case-insensitive). With --image-list,
    only the entries of the list files are counted: the option value may
    hold several comma-separated list files, and each file holds
    whitespace-separated entries.

    Raises an Exception if there are no images and no --image-list, or if
    a list file does not exist.
    '''

    # TODO(oalexan1): This is very fragile logic.
    IMAGE_EXTENSIONS = ['.tif', '.tiff', '.ntf', '.nitf', '.png', '.jpeg', '.jpg',
                        '.jp2', '.img', '.cub', '.bip', '.bil', '.bsq']
    suffixes = tuple(IMAGE_EXTENSIONS)

    # No extension is a suffix of another, so each argument counts at most once
    num_on_cmd_line = sum(1 for arg in args if arg.lower().endswith(suffixes))

    if '--image-list' not in args:
        if num_on_cmd_line == 0:
            raise Exception("No input images found, and neither was " +
                            "--image-list specified. Supported image extensions: " +
                            " ".join(IMAGE_EXTENSIONS))
        return num_on_cmd_line

    # With --image-list, the images on the command line are not counted
    num_listed = 0
    # Pair each argument with its successor; a trailing --image-list
    # without a value is thus never paired and contributes nothing.
    for option, value in zip(args, args[1:]):
        if option != '--image-list':
            continue

        # A value with commas is a list of lists
        for list_file in value.split(','):
            if not os.path.isfile(list_file):
                raise Exception("Could not find image list: " + list_file)
            with open(list_file, "r") as fh:
                num_listed += len(fh.read().split())

    return num_listed

# Launch the jobs with GNU Parallel. It will take care of distributing the jobs
# across the nodes and load balancing. The way we accomplish this is by calling
# this same script but with --job-id <num>.
def spawn_to_nodes(step, num_nodes, output_prefix, argsIn):
    '''Run one step of this script on all nodes via GNU Parallel.

    step          -- value passed as --entry-point to the spawned copies;
                     they are given --stop-point step + 1.
    num_nodes     -- number of computing nodes.
    output_prefix -- output prefix; the list of job ids is written to
                     <output_prefix>-job-list.txt.
    argsIn        -- command line (script path followed by its arguments)
                     used to start each copy. A copy is edited, so the
                     caller's list is left unchanged.

    Reads the global 'opt' for the process/thread counts, nodes list, extra
    GNU Parallel options, work directory, ISIS settings and verbosity.

    Raises an Exception if GNU Parallel is not found or if the run fails.
    '''

    args = copy.copy(argsIn)

    if opt.processes is None or opt.threads is None:
        # The user did not specify these. We will find the best
        # for their system.
        (num_procs, num_threads) = get_best_procs_threads()
    else:
        num_procs   = opt.processes
        num_threads = opt.threads

    # Number of jobs to run in parallel to ensure all nodes are busy.
    num_parallel_jobs = num_nodes * num_procs

    asp_cmd_utils.wipe_option(args, '--processes', 1)
    asp_cmd_utils.wipe_option(args, '--threads', 1)
    args.extend(['--processes', str(num_procs)])
    args.extend(['--threads', str(num_threads)])
    args.extend(['--num-parallel-jobs', str(num_parallel_jobs)])

    # For convenience store the list of job ids in a file that
    # will be passed to GNU parallel. Keep this in the run directory.
    asp_system_utils.mkdir_p(os.path.dirname(output_prefix))
    job_list = output_prefix + '-job-list.txt'
    with open(job_list, 'w') as f: # closed even if a write fails
        for i in range(num_parallel_jobs):
            f.write("%d\n" % i)

    # Use GNU parallel with given number of processes.
    # TODO(oalexan1): Run 'parallel' using the runInGnuParallel() function call,
    # when the ASP_LIBRARY_PATH trick can be fully encapsulated in the
    # asp_system_utils.py code rather than being needed for each tool.
    cmd = ['parallel', '--will-cite', '--env', 'PATH', '--env', 'LD_LIBRARY_PATH',
           '--env', 'ASP_LIBRARY_PATH', '--env', 'ASP_DEPS_DIR', '--env', 'ISISROOT',
           '-u', '--max-procs', str(num_procs), '-a', job_list]
    if asp_system_utils.which(cmd[0]) is None:
        raise Exception('Need GNU Parallel to distribute the jobs.')

    if opt.nodes_list is not None:
        cmd += ['--sshloginfile', opt.nodes_list]

    if opt.parallel_options is not None:
        # Split on any run of whitespace, so that repeated spaces do not
        # produce empty arguments.
        cmd += opt.parallel_options.split()

    # Add the options which we want GNU parallel to not mess
    # with. Put them into a single string. Before that, put in quotes
    # any quantities having spaces, to avoid issues later.
    # Don't quote quantities already quoted.
    # TODO(oalexan1): Unify this across all tools.
    # Improve and use the function argListToString.
    args_copy = args[:] # a new list, so that 'args' is not changed
    args_copy += ["--work-dir", opt.work_dir]
    if opt.isisroot is not None:
        args_copy += ["--isisroot", opt.isisroot]
    if opt.isisdata is not None:
        args_copy += ["--isisdata", opt.isisdata]

    for index, arg in enumerate(args_copy):
        if re.search(r"[ \t]", arg) and arg[0] != '\'':
            args_copy[index] = '\'' + arg + '\''
    python_path = sys.executable # children must use same Python as parent
    start    = step; stop = start + 1
    args_str = python_path + " "              + \
               " ".join(args_copy)            + \
               " --entry-point " + str(start) + \
               " --stop-point " + str(stop)
    # GNU Parallel replaces {} with a line of the job list
    args_str += " --job-id {}"
    cmd += [args_str]

    # This is a bugfix for RHEL 8. The 'parallel' program fails to start with ASP's
    # libs, so temporarily hide them.
    if 'LD_LIBRARY_PATH' in os.environ:
        os.environ['ASP_LIBRARY_PATH'] = os.environ['LD_LIBRARY_PATH']
        os.environ['LD_LIBRARY_PATH'] = ''

    try:
        asp_system_utils.generic_run(cmd, opt.verbose)
    except Exception as e:
        exception_len = len(str(e))
        # This may fail because the command line is too long. It does not seem
        # possible to catch this error, which comes from parallel. So we
        # have to guess.
        if exception_len > 50000:
            print("If the command line is too long, consider using --image-list, etc.")

        raise Exception(str(e))
    finally:
        # Undo the above, also when the run failed, so that the library
        # path does not stay blanked for whatever runs next.
        if 'ASP_LIBRARY_PATH' in os.environ:
            os.environ['LD_LIBRARY_PATH'] = os.environ['ASP_LIBRARY_PATH']

def run_job(prog, args, job_id, **kw):
    '''Wrapper to run a command. Set job_id=-1 if only one job will be called.

    prog   -- name of the program, resolved with asp_system_utils.bin_path().
    args   -- list of arguments for the program; it is not modified.
    job_id -- id of the parallel job, or negative for a single job. The
              command is assembled the same way in both cases.
    kw     -- optional keyword 'msg': the description of the step used in
              the error message. The program name is used if it is absent.

    If opt.threads is set, any --threads option among the arguments is
    replaced with that value. With opt.dryrun the command is printed and
    not run.

    Raises an Exception if the program cannot be started or if it returns
    a non-zero exit code.
    '''

    binpath = asp_system_utils.bin_path(prog)
    call    = [binpath]
    call.extend(args)

    # The thread count is overridden identically for single and parallel jobs
    if opt.threads is not None:
        asp_cmd_utils.wipe_option(call, '--threads', 1)
        call.extend(['--threads', str(opt.threads)])

    if opt.dryrun or opt.verbose:
        print('%s' % ' '.join(call))
    if opt.dryrun:
        return

    try:
        code = subprocess.call(call)
    except OSError as e:
        raise Exception('%s: %s' % (binpath, e))
    if code != 0:
        # Fall back to the program name, so that a missing 'msg' does not
        # hide the failure behind a KeyError.
        raise Exception('Bundle adjust step ' + kw.get('msg', prog) + ' failed')

class ParallelBaStep:
    '''The ids of individual parallel_bundle_adjust steps.

    What the user sets as the statistics step has, internally, three
    sub-steps, which get ids between 0 and 1:
      statistics     - compute image statistics
      bounds         - compute normalization bounds
      interest_point - compute interest points per image
    '''
    # Sub-steps of the user-facing statistics step
    statistics, bounds, interest_point = 0, 0.1, 0.2
    # The remaining steps
    matching, optimization, all = 1, 2, 3

# Script entry point. Without --job-id this process is the manager: it
# spawns copies of itself through GNU Parallel for the statistics, interest
# point and matching steps, and runs the normalization bounds and the
# optimization itself. With --job-id it is one of the spawned copies and runs
# bundle_adjust for the step given by --entry-point.
if __name__ == '__main__':

    # Fix for this failing in a subprocess. There, need not know the ASP version.
    asp_version = ''
    if 'ISISROOT' in os.environ:
       asp_version = asp_system_utils.get_asp_version()
       
    usage = '''parallel_bundle_adjust <images> <cameras> <optional ground control points> -o <output prefix> [options]
        Camera model arguments may be optional for some stereo
        session types (e.g. isis). See bundle_adjust for all options.\n''' + asp_version
    
    # A wrapper for bundle_adjust which computes image statistics and IP matches
    #  in parallel across multiple machines.  The final bundle_adjust step is
    #  performed on a single machine.

    # Algorithm: When the script is started, it starts one copy of
    # itself on each node during steps before optimization. 
    # Those scripts in turn start actual jobs on those nodes.
    # For the optimization step, the script does the work itself.

    p = argparse.ArgumentParser(usage=usage)
    p.add_argument('--nodes-list', dest='nodes_list', default=None,
                 help='The list of computing nodes, one per line. ' + \
                 'If not provided, run on the local machine.')
    p.add_argument('--processes', dest='processes', default=None, type=int, 
                   help='The number of processes to use per node. The default is ' + \
                   'a quarter of the number of cores on the head node.')
    p.add_argument('--threads', dest='threads', default=None,
                 type=int, help='The number of threads  per process. The default is ' + \
                 'the number of cores on the head node over the number of processes.')
    # The entry and stop points are floats because the internal sub-steps in
    # ParallelBaStep have fractional ids.
    p.add_argument('-e', '--entry-point', 
                   dest='entry_point', default=0, type = float,
                   help = "Bundle adjustment entry point (start at this stage). " + \
                   "Options: statistics and interest points per image = 0, " + \
                   "interest point matching = 1, optimization = 2.")
    p.add_argument('--stop-point', dest='stop_point',  default=3, type = float,
                   help = "Bundle adjustment stop point (stop *before* this stage). " + \
                   "Options: statistics = 0, matching = 1, optimization = 2, all = 3.")
    p.add_argument('--parallel-options', dest='parallel_options', default='--sshdelay 0.2',
                   help='Options to pass directly to GNU Parallel.')
    p.add_argument('-v', '--version', dest='version', default=False,
                 action='store_true', help='Display the version of software.')
    p.add_argument('--verbose', dest='verbose', default=False, action='store_true',
                 help='Display the commands being executed.')

    # Internal variables below.
    # The index of the spawned process. spawn_to_nodes() writes the ids
    # 0 <= job_id < num_nodes * processes to the job list.
    p.add_argument('--job-id', dest='job_id', default=None, type=int,
                 help=argparse.SUPPRESS)
    # Directory where the job is running
    p.add_argument('--work-dir', dest='work_dir', default=None,
                 help=argparse.SUPPRESS)
    # ISIS settings
    p.add_argument('--isisroot', dest='isisroot', default=None,
                 help=argparse.SUPPRESS)
    p.add_argument('--isisdata', dest='isisdata', default=None,
                 help=argparse.SUPPRESS)
    # Debug options
    p.add_argument('--dry-run', dest='dryrun', default=False, action='store_true',
                 help=argparse.SUPPRESS)

    # 'opt' is a module-level name; the functions above read it as a global.
    # The options not declared above are collected in 'args' and are later
    # passed on to bundle_adjust.
    global opt
    (opt, args) = p.parse_known_args()

    if opt.version:
        asp_system_utils.print_version_and_exit()

    if not args and not opt.version:
        p.print_help()
        sys.exit(1)

    try:
        args = asp_cmd_utils.clean_args(args)
        output_prefix = get_output_prefix(args)
    except Exception as e:
        # Print the error message and exit.
        print('\nERROR: ' + str(e) + '\n')
        p.print_help()
        sys.exit(1)
    
    # Setting ISISROOT and ISISDATA must happen in both parent and 
    # child processes, and before checking parallel version.
    if opt.isisroot is not None: os.environ['ISISROOT'] = opt.isisroot
    if opt.isisdata is not None: os.environ['ISISDATA'] = opt.isisdata
    
    # Ensure our 'parallel' is not out of date
    asp_system_utils.check_parallel_version()

    if opt.job_id is None:
        # When the script is started, set some options from the
        # environment which we will pass to the scripts we spawn
        # 1. Set the work directory
        opt.work_dir = os.getcwd()
        # 2. Set the ISIS settings if any
        if 'ISISROOT' in os.environ: opt.isisroot  = os.environ['ISISROOT']
        if 'ISISDATA' in os.environ: opt.isisdata = os.environ['ISISDATA']
        # 3. Fix for Pleiades, copy the nodes_list to the run directory
        # to ensure it can be seen on any node.
        if opt.nodes_list is not None:
            if not os.path.isfile(opt.nodes_list):
                asp_system_utils.die('\nERROR: No such nodes-list file: ' + opt.nodes_list, code=2)
            local_nodes_list = output_prefix + "-nodes-list.txt"
            if opt.nodes_list != local_nodes_list:
                asp_system_utils.mkdir_p(os.path.dirname(output_prefix))
                shutil.copy2(opt.nodes_list, local_nodes_list)
                opt.nodes_list = local_nodes_list
            # Make the command line used for the spawned copies refer to
            # the local copy of the list
            asp_cmd_utils.wipe_option(sys.argv, '--nodes-list', 1)
            sys.argv.extend(['--nodes-list', opt.nodes_list])
            print("Nodes list: " + opt.nodes_list)
            
    else:
        # After the script spawns itself to nodes, it starts in the
        # home dir. Make it go to the right place.
        os.chdir(opt.work_dir)

    # NOTE(review): presumably never reached, if print_version_and_exit()
    # above exits as its name suggests - confirm.
    if opt.version:
        args.append('-v')

    if opt.job_id is None:

        # We get here when the script is started. The current running
        # process has become the management process that spawns other
        # copies of itself on other machines. This block will only do
        # actual work during the optimization step.

        # Wipe options which we will override.
        self_args = sys.argv # an alias, not a copy: the wipes below edit sys.argv
        asp_cmd_utils.wipe_option(self_args, '-e', 1)
        asp_cmd_utils.wipe_option(self_args, '--entry-point', 1)
        asp_cmd_utils.wipe_option(self_args, '--stop-point',  1)

        output_folder = os.path.dirname(output_prefix)

        # Create the main output folder
        if (not os.path.exists(output_folder)) and (output_folder != ""):
            os.makedirs(output_folder)

        # With any of these options, start no earlier than the optimization
        if '--isis-cnet' in args or '--match-files-prefix' in args or          \
          '--clean-match-files-prefix' in args or '--skip-matching' in args or \
          '--nvm' in args:
           print("Skipping statistics and matching based on input options.")
           if opt.entry_point < ParallelBaStep.optimization:
              opt.entry_point = ParallelBaStep.optimization
        
        # Sanity check
        if (opt.entry_point < ParallelBaStep.statistics or 
            opt.stop_point  < ParallelBaStep.statistics):
          raise Exception("Invalid entry or stop point.")

        num_nodes = get_num_nodes(opt.nodes_list)
        
        # Statistics, bounds per image, interest points per image.
        # All three sub-steps are run if the entry point is at most the
        # id of the last of them.
        step = ParallelBaStep.interest_point
        if (opt.entry_point <= step):
            if (opt.stop_point <= step):
                sys.exit()

            # Spawn statistics processes to nodes
            # num_images = get_num_images(self_args)
            spawn_to_nodes(ParallelBaStep.statistics, num_nodes, output_prefix, self_args)
            
            # Run the bounds step as a single process
            if '--skip-matching' not in args:
                args.extend(['--skip-matching']) # this will also skip stats
            # Note that %d prints the fractional step id as 0
            run_job('bundle_adjust', args + ['--calc-normalization-bounds'],
                    job_id=-1, 
                    msg='%d: Normalization bounds' % step)
            
            # Spawn interest point processes to nodes
            spawn_to_nodes(ParallelBaStep.interest_point, num_nodes, output_prefix,
                           self_args)
            
        # Matching
        step = ParallelBaStep.matching
        if (opt.entry_point <= step):
            if (opt.stop_point <= step):
                sys.exit()

            # Will run as many processes as we have nodes times processes per node.
            # This seems to be the best approach at load-balancing.
            
            # The commented-out logic below will be deleted at some point.
            # To do load-balancing properly, need to know how many pairs of
            # images we expect to match. This depends on --auto-overlap-params,
            # etc.
            #sep = ","
            #settings = run_and_parse_output("bundle_adjust", 
            #                                args + ['--query-num-image-pairs'], 
            #                                sep, opt.verbose)
            #num_image_pairs = int(settings["num_image_pairs"][0])
       
            # Spawn matching to nodes
            spawn_to_nodes(step, num_nodes, output_prefix, self_args)

        # Optimization
        step = ParallelBaStep.optimization
        if (opt.entry_point <= step):
            if (opt.stop_point <= step):
                sys.exit()
            if '--skip-matching' not in args:
                args.extend(['--skip-matching'])
            run_job('bundle_adjust', args, job_id=-1, msg='%d: Optimizing' % step)

            # End main process case
    else:

        # This process was spawned by GNU Parallel with a given
        # value of job_id. It will do the actual work.
        if opt.verbose:
            print("Running on machine: ", os.uname())

        try:
            args.extend(['--job-id', str(opt.job_id)])

            # The entry point was set by spawn_to_nodes() to the id of the
            # step to run, so at most one of the cases below applies.
            if (opt.entry_point == ParallelBaStep.statistics):
                args.extend(['--stop-after-statistics'])
                run_job('bundle_adjust', args, opt.job_id,
                        msg='%d: Statistics' % opt.entry_point)
                
            # The bounds step will not be done in parallel so is not here

            if (opt.entry_point == ParallelBaStep.interest_point):
                args.extend(['--calc-ip'])
                run_job('bundle_adjust', args, opt.job_id,
                        msg='%d: Interest point' % opt.entry_point)

            if (opt.entry_point == ParallelBaStep.matching):
                args.extend(['--stop-after-matching'])
                run_job('bundle_adjust', args, opt.job_id,
                        msg='%d: Matching' % opt.entry_point)
            
            # The optimization step will not be done in parallel so is not here

        except Exception as e:
            asp_system_utils.die(e)
            # NOTE(review): reached only if die() returns - confirm
            raise
