#!/usr/bin/env python
# __BEGIN_LICENSE__
#  Copyright (c) 2009-2013, United States Government as represented by the
#  Administrator of the National Aeronautics and Space Administration. All
#  rights reserved.
#
#  The NGT platform is licensed under the Apache License, Version 2.0 (the
#  "License"); you may not use this file except in compliance with the
#  License. You may obtain a copy of the License at
#  http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
# __END_LICENSE__

# TODO: The same behavior as bundle_adjust should be used when dealing
# with pre-existing match files. This tool recreates them if it cannot
# find them in the subdirectory for each subtask, even though they
# exist in the final location.

import sys, argparse, subprocess, re, os, math, time, tempfile, glob,\
       shutil, math, copy
import os.path as P

# The path to the ASP python files
basepath    = os.path.abspath(sys.path[0])
pythonpath  = os.path.abspath(basepath + '/../Python')  # for dev ASP
libexecpath = os.path.abspath(basepath + '/../libexec') # for packaged ASP
sys.path.insert(0, basepath) # prepend to Python path
sys.path.insert(0, pythonpath)
sys.path.insert(0, libexecpath)

import asp_system_utils
asp_system_utils.verify_python_version_is_supported()

from stereo_utils import * # must be after the path is altered above

# Prepend to system PATH
os.environ["PATH"] = libexecpath + os.pathsep + os.environ["PATH"]

# This is explained below
if 'ASP_LIBRARY_PATH' in os.environ:
    os.environ['LD_LIBRARY_PATH'] = os.environ['ASP_LIBRARY_PATH']

def get_output_prefix(args):
    '''Return the output prefix given on the command line.

    The short form '-o' is looked up first and '--output-prefix' is used
    only when the short form yields nothing. The lookup result is expected
    to hold the option name followed by its value; the value is returned.

    Raises Exception if neither lookup yields an option/value pair.
    '''
    found = get_option(args, '-o', 1) or get_option(args, '--output-prefix', 1)
    if len(found) >= 2:
        return found[1]
    raise Exception('Failed to parse the output prefix!')


def get_num_nodes(nodes_list):
    '''Return the number of distinct computing nodes listed in a file.

    nodes_list -- path to a text file with one node per line, or None when
                  running on the local machine only (1 is returned then).

    Only the first whitespace-delimited token on each line is taken as the
    node name. Blank lines are skipped and a name that appears several
    times is counted once (this is needed for Pleiades).

    If the file cannot be read, the error is handed to die().
    Raises Exception if the list contains no nodes.
    '''

    if nodes_list is None:
        return 1 # local machine

    # Count the number of nodes without repetition.
    nodes = set()
    num_nodes = 0
    try:
        # The context manager closes the file even if reading fails.
        with open(nodes_list, "r") as fh:
            for line in fh:
                # A blank line has no non-whitespace token and is skipped.
                matches = re.match(r'\s*(\S+)', line)
                if matches:
                    nodes.add(matches.group(1))
        num_nodes = len(nodes)
    except Exception as e:
        die(e)
    if num_nodes == 0:
        raise Exception('The list of computing nodes is empty')

    return num_nodes

def get_best_procs_threads(step):
    '''Pick the number of processes per node and threads per process.

    step -- index of the processing stage; used only in the verbose message.

    Returns the tuple (num_procs, num_threads). Values the user supplied in
    opt.processes and opt.threads are honored. Otherwise one process per
    CPU and a single thread per process are used.

    Author's rationale: ASP mostly uses 100% CPU per process even when
    invoked with many threads (the file system could be the bottleneck),
    so many single-threaded processes are faster than fewer multi-threaded
    ones. Earlier, more elaborate logic did not work well.
    '''

    # We assume all machines have the same number of CPUs (cores)
    cpu_count = get_num_cpus()

    # Respect the user's choices when given
    procs = cpu_count if opt.processes is None else opt.processes
    threads = 1 if opt.threads is None else opt.threads

    if opt.verbose:
        print("For stage %d, using %d threads and %d processes." %
              (step, threads, procs))

    return (procs, threads)


def get_num_instances(args):
    '''Return how many instances the work is split over.

    For now this equals the number of images: the count of arguments whose
    lower-cased text ends with one of the known image extensions.
    '''
    image_suffixes = ('.tif', '.tiff', '.ntf', '.png', '.jpeg', '.jpg',
                      '.jp2', '.img', '.cub', '.bip', '.bil', '.bsq')
    # No suffix above is the tail of another, so an argument matches at
    # most one of them and is counted once.
    return sum(1 for arg in args if arg.lower().endswith(image_suffixes))

def get_subfolder_prefix(output_folder, instance_index):
    '''Return the output prefix of one instance:
       <output_folder>/sub_idx_<instance_index>/run'''
    instance_dir = 'sub_idx_' + str(instance_index)
    return os.path.join(output_folder, instance_dir, 'run')

def _quote_if_has_spaces(arg):
    '''Wrap arg in single quotes if it contains a space and is not
       already quoted.'''
    if ' ' in arg and not arg.startswith('\''):
        return '\'' + arg + '\''
    return arg

# Launch GNU Parallel for all instances, it will take care of distributing
# the jobs across the nodes and load balancing. The way we accomplish
# this is by calling this same script but with --instance-index <num>.
def spawn_to_nodes(step, argsIn):
    '''Run one step of this script once per instance through GNU Parallel.

    step   -- the entry point handed to the spawned copies; each copy is
              told to stop at step + 1.
    argsIn -- the command line used to re-invoke this script. It is not
              modified; a copy is edited.

    Reads the global opt (processes, threads, nodes_list, work_dir,
    isisroot, isisdata, verbose).

    Raises Exception if the 'parallel' program cannot be found.
    '''

    args = copy.copy(argsIn)

    num_instances = get_num_instances(args)

    if opt.processes is None or opt.threads is None:
        # The user did not specify these. We will find the best
        # for their system.
        (procs, threads) = get_best_procs_threads(step)
    else:
        procs   = opt.processes
        threads = opt.threads

    wipe_option(args, '--processes', 1)
    wipe_option(args, '--threads', 1)
    args.extend(['--processes', str(procs)])
    args.extend(['--threads', str(threads)])
    args.extend(['--instance-count', str(num_instances)])

    # For convenience store the instance index list in a file that
    # will be passed to GNU parallel. The temporary file object is kept
    # alive in tmpFile so that the file exists while parallel runs.
    tmpFile = tempfile.NamedTemporaryFile(delete=True, dir='.')
    with open(tmpFile.name, 'w') as f:
        for i in range(num_instances):
            f.write("%d\n" % i)

    # Use GNU parallel with given number of processes.
    # TODO(oalexan1): Run 'parallel' using the runInGnuParallel() function call,
    # when the ASP_LIBRARY_PATH trick can be fully encapsulated in the
    # asp_system_utils.py code rather than being needed for each tool.
    cmd = ['parallel', '--will-cite', '--env', 'PATH', '--env', 'LD_LIBRARY_PATH',
           '--env', 'ASP_LIBRARY_PATH', '--env', 'ASP_DEPS_DIR', '-u',
           '--max-procs', str(procs), '-a', tmpFile.name]
    if which(cmd[0]) is None:
        raise Exception('Need GNU Parallel to distribute the jobs.')

    if opt.nodes_list is not None:
        cmd += ['--sshloginfile', opt.nodes_list]

    # Add the options which we want GNU parallel to not mess
    # with. Put them into a single string. Before that, put in quotes
    # any quantities having spaces, to avoid issues later.
    # Don't quote quantities already quoted. The same rule is applied to
    # the interpreter path, the work directory and the ISIS settings, so
    # that a path with spaces stays a single argument.
    quoted_args = [_quote_if_has_spaces(arg) for arg in args]
    python_path = sys.executable # children must use same Python as parent
    start = step
    stop  = start + 1
    args_str = _quote_if_has_spaces(python_path) + " " + \
               " ".join(quoted_args) + " --entry-point " + str(start) + \
               " --stop-point " + str(stop) + \
               " --work-dir " + _quote_if_has_spaces(opt.work_dir)
    if opt.isisroot is not None:
        args_str += " --isisroot " + _quote_if_has_spaces(opt.isisroot)
    if opt.isisdata is not None:
        args_str += " --isisdata " + _quote_if_has_spaces(opt.isisdata)
    # GNU parallel replaces {} with one line of the index file.
    args_str += " --instance-index {}"
    cmd += [args_str]

    # This is a bugfix for RHEL 8. The 'parallel' program fails to start with ASP's
    # libs, so temporarily hide them.
    if 'LD_LIBRARY_PATH' in os.environ:
        os.environ['ASP_LIBRARY_PATH'] = os.environ['LD_LIBRARY_PATH']
        os.environ['LD_LIBRARY_PATH'] = ''

    try:
        asp_system_utils.generic_run(cmd, opt.verbose)
    finally:
        # Undo the above, also when the run fails, so that later steps
        # in this process see the library path again.
        if 'ASP_LIBRARY_PATH' in os.environ:
            os.environ['LD_LIBRARY_PATH'] = os.environ['ASP_LIBRARY_PATH']

def run_job(prog, args, instance_index, **kw):
    '''Wrapper to run a command.

    prog           -- name of the program; its path is resolved with bin_path().
    args           -- list of arguments for the program (not modified).
    instance_index -- index of this parallel instance; set it to -1 if
                      only one instance will be called.
    kw             -- 'msg' is the step description used in the error
                      message; the program name is used if it is absent.

    When opt.threads is set it replaces any --threads value in the call.
    For a parallel instance that is not stopping after statistics, the
    output prefix is redirected to that instance's subfolder.

    With opt.dryrun the command is printed and not executed.

    Raises Exception if the program cannot be started or returns a
    non-zero exit code.
    '''

    binpath = bin_path(prog)
    call    = [binpath]
    call.extend(args)

    # The thread count is handled the same way for a single instance and
    # for one of the parallel jobs.
    if opt.threads is not None:
        wipe_option(call, '--threads', 1)
        call.extend(['--threads', str(opt.threads)])

    if instance_index >= 0 and '--stop-after-statistics' not in args:
        # One of the parallel jobs: update the output prefix for this instance
        output_prefix = get_output_prefix(args)
        output_folder = os.path.dirname(output_prefix)
        sub_prefix    = get_subfolder_prefix(output_folder, instance_index)

        # Replace the output prefix in the call
        set_option(call, '-o', [sub_prefix])
        wipe_option(call, '--output-prefix', 1)

    if opt.dryrun or opt.verbose:
        print('%s' % ' '.join(call))
    if opt.dryrun:
        return

    try:
        code = subprocess.call(call)
    except OSError as e:
        raise Exception('%s: %s' % (binpath, e))
    if code != 0:
        # Fall back to the program name so that a missing 'msg' does not
        # hide the real failure behind a KeyError.
        raise Exception('Bundle adjust step ' + kw.get('msg', prog) + ' failed')

def distribute_files(output_prefix, ext, num_instances):
    '''Share files between the main output location and the per-instance
       subfolders.

    output_prefix -- the main output prefix (folder plus file name prefix).
    ext           -- file name ending that selects the files to handle.
    num_instances -- number of per-instance subfolders.

    If ext is '.match', the matching files found under each instance
    prefix are symlinked to the main output prefix, replacing any
    existing entry there. For any other ext (.stats and camera footprint
    bbox files), the files that start with output_prefix and end with ext
    are copied to every instance prefix, creating the subfolders if needed.
    '''

    output_folder = os.path.dirname(output_prefix)
    sub_prefixes = [get_subfolder_prefix(output_folder, i) for i in range(num_instances)]

    if ext == '.match':

        for pre in sub_prefixes:

            # Sym link match files from subdirectories to the output directory.
            # This takes a lot of care to get right.
            for f in glob.glob(pre + '*' + ext):
                # The link target is stored relative to the folder that
                # holds the link.
                src_file = os.path.relpath(f, output_folder)
                dst_file = output_prefix + f[len(pre):]

                # First wipe any existing file. lexists() is also true for
                # a dangling symlink, which exists() would miss and which
                # would make os.symlink() fail.
                if os.path.lexists(dst_file):
                    os.remove(dst_file)

                # Create sym link
                os.symlink(src_file, dst_file)

    else: # Stats files
        # Find all the statistics files. os.path.join() keeps the search in
        # the current directory when the prefix has no folder part, rather
        # than turning the pattern into a search of the root directory.
        stats_files = glob.glob(os.path.join(output_folder, '*' + ext))

        # Duplicate them in each of the subfolders, creating them if needed.
        # - We could link them but these are tiny files so don't bother.
        for f in stats_files:
            # Files which do not belong to this output prefix are skipped.
            if not f.startswith(output_prefix):
                continue
            base_name = f[len(output_prefix):]
            for s in sub_prefixes:
                new_path = s + base_name
                asp_system_utils.mkdir_p(os.path.dirname(new_path))
                shutil.copyfile(f, new_path)

if __name__ == '__main__':
    usage = '''parallel_bundle_adjust <images> <cameras> <optional ground control points> -o <output prefix> [options]
        Camera model arguments may be optional for some stereo
        session types (e.g. isis). See bundle_adjust for all options.\n''' + get_asp_version()

    # A wrapper for bundle_adjust which computes image statistics and IP matches
    #  in parallel across multiple machines.  The final bundle_adjust step is
    #  performed on a single machine.

    # Algorithm: When the script is started, it uses GNU Parallel to start
    # copies of itself, one per instance index, for the statistics and
    # matching steps (entry points 0 and 1). Those copies in turn start the
    # actual bundle_adjust jobs. For the optimization step (entry point 2),
    # the script runs bundle_adjust itself.

    p = argparse.ArgumentParser(usage=usage)
    p.add_argument('--nodes-list',           dest='nodes_list', default=None,
                 help='The list of computing nodes, one per line. ' + \
                 'If not provided, run on the local machine.')
    p.add_argument('--processes',            dest='processes', default=None,
                 type=int, help='The number of processes to use per node.')
    p.add_argument('--threads', dest='threads', default=None,
                 type=int, help='The number of threads to use.')
    p.add_argument('-e', '--entry-point',    dest='entry_point', default=0,
                 help='Step to start with (statistics=0, matching=1, optimization=2).',
                 type=int)
    p.add_argument('--stop-point',           dest='stop_point',  default=3,
                 help='Step to stop after (statistics=1, matching=2, optimization=3).',
                 type=int)
    p.add_argument('-v', '--version',        dest='version', default=False,
                 action='store_true', help='Display the version of software.')
    p.add_argument('--verbose', dest='verbose', default=False, action='store_true',
                 help='Display the commands being executed.')

    # Internal variables below.
    # The index of the spawned process, 0 <= instance_index < number of
    # instances (see spawn_to_nodes()).
    p.add_argument('--instance-index', dest='instance_index', default=None, type=int,
                 help=argparse.SUPPRESS)
    # Directory where the job is running
    p.add_argument('--work-dir', dest='work_dir', default=None,
                 help=argparse.SUPPRESS)
    # ISIS settings
    p.add_argument('--isisroot', dest='isisroot', default=None,
                 help=argparse.SUPPRESS)
    p.add_argument('--isisdata', dest='isisdata', default=None,
                 help=argparse.SUPPRESS)
    # Debug options
    p.add_argument('--dry-run', dest='dryrun', default=False, action='store_true',
                 help=argparse.SUPPRESS)

    # opt is a module-level name which the functions above read. The
    # 'global' statement has no effect at module level.
    global opt
    # Options not declared above stay in args and are later passed on
    # to bundle_adjust.
    (opt, args) = p.parse_known_args()
    args = clean_args(args)

    if opt.version:
        asp_system_utils.print_version_and_exit()

    # NOTE(review): print_version_and_exit() presumably exits, which would
    # make the 'not opt.version' test redundant - confirm.
    if not args and not opt.version:
        p.print_help()
        sys.exit(1)

    # Ensure our 'parallel' is not out of date
    check_parallel_version()

    # From here on opt.threads is never None.
    if opt.threads is None:
        opt.threads = get_num_cpus()

    if opt.instance_index is None:
        # When the script is started, set some options from the
        # environment which we will pass to the scripts we spawn
        # 1. Set the work directory
        opt.work_dir = os.getcwd()
        # 2. Set the ISIS settings if any
        if 'ISISROOT'  in os.environ: opt.isisroot  = os.environ['ISISROOT']
        if 'ISISDATA' in os.environ: opt.isisdata = os.environ['ISISDATA']
        # 3. Fix for Pleiades, copy the nodes_list to current directory
        if opt.nodes_list is not None:
            if not os.path.isfile(opt.nodes_list):
                die('\nERROR: No such nodes-list file: ' + opt.nodes_list, code=2)
            # The copy is a temporary file created with delete=True, so
            # tmpFile must stay referenced for as long as the copy is needed.
            tmpFile = tempfile.NamedTemporaryFile(delete=True, dir='.')
            shutil.copy2(opt.nodes_list, tmpFile.name)
            opt.nodes_list = tmpFile.name
            # Make the command line handed to the spawned copies point to
            # the copied list as well.
            wipe_option(sys.argv, '--nodes-list', 1)
            sys.argv.extend(['--nodes-list', tmpFile.name])
    else:
        # After the script spawns itself to nodes, it starts in the
        # home dir. Make it go to the right place.
        os.chdir(opt.work_dir)
        # Set the ISIS settings
        if opt.isisroot  is not None: os.environ['ISISROOT' ] = opt.isisroot
        if opt.isisdata is not None: os.environ['ISISDATA'] = opt.isisdata

    # NOTE(review): num_nodes is not used below. The call still raises if
    # the nodes list has no entries.
    num_nodes     = get_num_nodes(opt.nodes_list)
    num_instances = get_num_instances(args)

    # NOTE(review): likely unreachable if print_version_and_exit() above
    # exits - confirm.
    if opt.version:
        args.append('-v')

    if opt.instance_index is None:

        # We get here when the script is started. The current running
        # process has become the management process that spawns other
        # copies of itself on other machines. This block will only do
        # actual work during the optimization step.

        # Wipe options which we will override.
        # self_args is an alias of sys.argv, not a copy; the wipe_option()
        # calls below act on this same list.
        self_args = sys.argv
        wipe_option(self_args, '-e', 1)
        wipe_option(self_args, '--entry-point', 1)
        wipe_option(self_args, '--stop-point',  1)

        output_prefix = get_output_prefix(args)
        output_folder = os.path.dirname(output_prefix)

        # Create the main output folder
        if (not os.path.exists(output_folder)) and (output_folder != ""):
            os.makedirs(output_folder)

        # Statistics.
        step = Step.statistics
        if ( opt.entry_point <= step ):
            if ( opt.stop_point <= step ):
                sys.exit()

            # Spawn statistics processes to nodes.
            spawn_to_nodes(step, self_args)

            # Copy all statistics files to each working folder
            # Same for camera footprint bbox files if present.
            distribute_files(output_prefix, '-stats.tif', num_instances)
            distribute_files(output_prefix, '-bbox.txt', num_instances)

        # Matching.
        step = Step.matching
        if ( opt.entry_point <= step ):
            if ( opt.stop_point <= step ):
                sys.exit()

            # Spawn matching processes to nodes.
            spawn_to_nodes(step, self_args)

            # Symlink the match files from the working folders to the
            # main output folder.
            distribute_files(output_prefix, '.match', num_instances)

        # Optimization
        step = Step.optimization
        if ( opt.entry_point <= step ):
            if ( opt.stop_point <= step ):
                sys.exit()
            # The match files were produced in the matching step.
            args.extend(['--skip-matching'])
            run_job('bundle_adjust', args, instance_index=-1, msg='%d: Optimizing' % step)

            # End main process case
    else:

        # This process was spawned by GNU Parallel with a given
        # value of opt.instance_index. Launch the job for that instance.
        if opt.verbose:
            print("Running on machine: ", os.uname())

        try:
            args.extend(['--instance-index', str(opt.instance_index)])

            if ( opt.entry_point == Step.statistics ):
                args.extend(['--stop-after-statistics'])
                run_job('bundle_adjust', args, opt.instance_index,
                        msg='%d: Statistics' % opt.entry_point)

            if ( opt.entry_point == Step.matching ):
                args.extend(['--stop-after-matching'])
                run_job('bundle_adjust', args, opt.instance_index,
                        msg='%d: Matching' % opt.entry_point)

        except Exception as e:
            die(e)
            # Reached only if die() returns; presumably it exits - confirm.
            raise
