#!/usr/bin/env python
# __BEGIN_LICENSE__
#  Copyright (c) 2009-2013, United States Government as represented by the
#  Administrator of the National Aeronautics and Space Administration. All
#  rights reserved.
#
#  The NGT platform is licensed under the Apache License, Version 2.0 (the
#  "License"); you may not use this file except in compliance with the
#  License. You may obtain a copy of the License at
#  http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
# __END_LICENSE__

# TODO: The same behavior as bundle_adjust should be used when dealing
# with pre-existing match files. This tool recreates them if it cannot
# find them in the subdirectory for each subtask, even though they
# exist in the final location.

import sys, argparse, subprocess, re, os, math, time, glob, shutil, math, copy
import os.path as P

# Locations of the ASP Python modules, derived from sys.path[0]. All three
# candidate directories are prepended to the module search path, so the
# last one inserted (libexecpath) is searched first.
basepath    = os.path.abspath(sys.path[0])
pythonpath  = os.path.abspath(basepath + '/../Python')  # for dev ASP
libexecpath = os.path.abspath(basepath + '/../libexec') # for packaged ASP
sys.path.insert(0, basepath) # prepend to Python path
sys.path.insert(0, pythonpath)
sys.path.insert(0, libexecpath)

# These imports depend on the sys.path changes made above.
import asp_system_utils, asp_cmd_utils
asp_system_utils.verify_python_version_is_supported()

from asp_stereo_utils import * # must be after the path is altered above

# Prepend libexecpath to the system PATH
os.environ["PATH"] = libexecpath + os.pathsep + os.environ["PATH"]

# Before launching GNU Parallel, spawn_to_nodes() saves LD_LIBRARY_PATH in
# ASP_LIBRARY_PATH and empties LD_LIBRARY_PATH. When ASP_LIBRARY_PATH is
# set in the environment of this process, copy it back to LD_LIBRARY_PATH.
if 'ASP_LIBRARY_PATH' in os.environ:
    os.environ['LD_LIBRARY_PATH'] = os.environ['ASP_LIBRARY_PATH']

def get_output_prefix(args):
    '''Return the output prefix given on the command line.

    The short form -o is looked up first, and --output-prefix is used
    only when -o yields nothing. Raises an Exception when neither lookup
    returns the option together with its value.'''
    for flag in ('-o', '--output-prefix'):
        found = get_option(args, flag, 1)
        if found:
            break

    # A successful lookup is expected to hold the flag and then its value
    if len(found) < 2:
        raise Exception('Failed to parse the output prefix.')
    return found[1]

def get_num_nodes(nodes_list):
    '''Return the number of distinct computing nodes.

    nodes_list is the path to a text file with one node per line, or
    None when running on the local machine only (then 1 is returned).
    Only the first whitespace-separated field of each line is taken as
    the node name; blank lines are ignored.

    A failure to read the file is passed to die(). Raises an Exception
    if the file lists no nodes.'''

    if nodes_list is None:
        return 1 # local machine

    # Count the number of nodes without repetition (need this for
    # Pleiades).
    nodes = set()
    try:
        # The context manager closes the file even if reading fails
        with open(nodes_list, "r") as fh:
            for line in fh:
                fields = line.split()
                if fields: # skip empty lines
                    nodes.add(fields[0])
    except Exception as e:
        die(e)

    if not nodes:
        raise Exception('The list of computing nodes is empty')

    return len(nodes)

def get_best_procs_threads(step):
    '''Return (num_procs, num_threads) to use on each node for this step.

    Earlier, more elaborate logic did not work well. ASP mostly uses
    100% CPU per process the vast majority of the time, even when
    invoked with a lot of threads (the file system could be the
    bottleneck). As such, it is faster to use many processes and one
    thread per process, unless the user asked for something else.'''

    # We assume all machines have the same number of CPUs (cores)
    cpus_per_node = get_num_cpus()

    # The user's explicit choices take precedence over the defaults
    procs   = cpus_per_node if opt.processes is None else opt.processes
    threads = 1 if opt.threads is None else opt.threads

    if opt.verbose:
        print("For stage %d, using %d threads and %d processes." %
              (step, threads, procs))

    return (procs, threads)

def get_num_instances(args):
    '''Determine the number of total instances that the work will be split over.

    For now this equals the number of images. Without --image-list, it
    is the number of arguments ending in a known image extension, and an
    Exception is raised if there are none. With --image-list, it is the
    number of whitespace-separated entries in the listed file(s); the
    option value may hold several comma-separated list files. Raises an
    Exception if a list file does not exist.'''

    # TODO(oalexan1): This is very fragile logic.
    IMAGE_EXTENSIONS = ['.tif', '.tiff', '.ntf', '.png', '.jpeg', '.jpg',
                        '.jp2', '.img', '.cub', '.bip', '.bil', '.bsq']
    suffixes = tuple(IMAGE_EXTENSIONS)
    num_images = sum(1 for arg in args if arg.lower().endswith(suffixes))

    if '--image-list' not in args:
        if num_images == 0:
            raise Exception("No input images found, and neither was "
                            "--image-list specified. Supported image extensions: "
                            + " ".join(IMAGE_EXTENSIONS))
        return num_images

    # When --image-list is present, only the entries in the lists count.
    num_images = 0
    for flag, value in zip(args, args[1:]):
        if flag != '--image-list':
            continue

        # A value with commas is a list of lists
        for list_path in value.split(','):
            if not os.path.isfile(list_path):
                raise Exception("Could not find image list: " + list_path)
            with open(list_path, "r") as fh:
                num_images += len(fh.read().split())

    return num_images

def get_subfolder_prefix(output_folder, instance_index):
    '''Return the output prefix used inside the subdirectory of the
       given instance: <output_folder>/sub_idx_<instance_index>/run.'''
    subdir = 'sub_idx_{}'.format(instance_index)
    return os.path.join(output_folder, subdir, 'run')

# Launch GNU Parallel for all instances; it will take care of distributing
# the jobs across the nodes and load balancing. The way we accomplish
# this is by calling this same script but with --instance-index <num>.
def spawn_to_nodes(step, output_prefix, argsIn):
    '''Run one step on all instances by invoking this script via GNU Parallel.

    step          -- id of the step to run; each spawned copy is started
                     with --entry-point step and --stop-point step + 1.
    output_prefix -- output prefix of the run; the instance list file is
                     written to <output_prefix>-instance-list.txt.
    argsIn        -- command line used to re-invoke this script (script
                     path first). It is copied, not modified.

    Raises an Exception if GNU Parallel is not found or if the command
    it runs fails.'''

    args = copy.copy(argsIn)

    num_instances = get_num_instances(args)

    if opt.processes is None or opt.threads is None:
        # The user did not specify these. We will find the best
        # for their system.
        (procs, threads) = get_best_procs_threads(step)
    else:
        procs   = opt.processes
        threads = opt.threads

    asp_cmd_utils.wipe_option(args, '--processes', 1)
    asp_cmd_utils.wipe_option(args, '--threads', 1)
    args.extend(['--processes', str(procs)])
    args.extend(['--threads', str(threads)])
    args.extend(['--instance-count', str(num_instances)])

    # For convenience store the instance index list in a file that
    # will be passed to GNU parallel. Keep this in the run directory.
    asp_system_utils.mkdir_p(os.path.dirname(output_prefix))
    instance_list = output_prefix + '-instance-list.txt'
    with open(instance_list, 'w') as f:
        for i in range(num_instances):
            f.write("%d\n" % i)

    # Use GNU parallel with given number of processes.
    # TODO(oalexan1): Run 'parallel' using the runInGnuParallel() function call,
    # when the ASP_LIBRARY_PATH trick can be fully encapsulated in the
    # asp_system_utils.py code rather than being needed for each tool.
    cmd = ['parallel', '--will-cite', '--env', 'PATH', '--env', 'LD_LIBRARY_PATH', '--env', 'ASP_LIBRARY_PATH', '--env', 'ASP_DEPS_DIR', '-u', '--max-procs', str(procs), '-a', instance_list]
    if which(cmd[0]) is None:
        raise Exception('Need GNU Parallel to distribute the jobs.')

    if opt.nodes_list is not None:
        cmd += ['--sshloginfile', opt.nodes_list]

    if opt.parallel_options is not None:
        # Split on any run of whitespace, so that an empty value or
        # repeated spaces do not produce empty arguments for 'parallel'.
        cmd += opt.parallel_options.split()

    # Add the options which we want GNU parallel to not mess
    # with. Put them into a single string. Before that, put in quotes
    # any quantities having spaces, to avoid issues later.
    # Don't quote quantities already quoted.
    # TODO(oalexan1): Unify this across all tools.
    # Improve and use the function argListToString.
    args_copy = args[:] # shallow copy, the entries are replaced, not edited
    args_copy += ["--work-dir", opt.work_dir]
    if opt.isisroot is not None:
        args_copy += ["--isisroot", opt.isisroot]
    if opt.isisdata is not None:
        args_copy += ["--isisdata", opt.isisdata]

    for index, arg in enumerate(args_copy):
        if re.search(r"[ \t]", arg) and arg[0] != '\'':
            args_copy[index] = '\'' + arg + '\''
    python_path = sys.executable # children must use same Python as parent
    start    = step; stop = start + 1
    args_str = python_path + " "              + \
               " ".join(args_copy)            + \
               " --entry-point " + str(start) + \
               " --stop-point " + str(stop)
    # GNU Parallel replaces {} with a line from the instance list
    args_str += " --instance-index {}"
    cmd += [args_str]

    # This is a bugfix for RHEL 8. The 'parallel' program fails to start with ASP's
    # libs, so temporarily hide them.
    if 'LD_LIBRARY_PATH' in os.environ:
        os.environ['ASP_LIBRARY_PATH'] = os.environ['LD_LIBRARY_PATH']
        os.environ['LD_LIBRARY_PATH'] = ''

    try:
        asp_system_utils.generic_run(cmd, opt.verbose)
    except Exception as e:
        exception_len = len(str(e))
        # This may fail because the command line is too long. It does not seem
        # possible to catch this error, which comes from parallel. So we
        # have to guess.
        if exception_len > 50000:
            print("If the command line is too long, consider using --image-list, etc.")

        raise Exception(str(e))
    finally:
        # Undo the hiding of the libraries, also when the run failed, so
        # that the environment is not left with an empty LD_LIBRARY_PATH.
        if 'ASP_LIBRARY_PATH' in os.environ:
            os.environ['LD_LIBRARY_PATH'] = os.environ['ASP_LIBRARY_PATH']

def run_job(prog, args, instance_index, **kw):
    '''Wrapper to run a command.

    prog           -- name of the program; resolved with bin_path().
    args           -- list of arguments for the program. Not modified.
    instance_index -- index of this parallel job. Set instance_index=-1
                      if only one instance will be called.
    msg (keyword)  -- description of the step, used in the error message.
                      The program name is used if it is not given.

    With opt.dryrun the command is only printed. Raises an Exception if
    the program cannot be started or returns a non-zero exit code.'''

    binpath = bin_path(prog)
    call    = [binpath]
    call.extend(args)

    # The thread count chosen for this script overrides any value in args,
    # both for a single instance and for each of the parallel jobs.
    if opt.threads is not None:
        asp_cmd_utils.wipe_option(call, '--threads', 1)
        call.extend(['--threads', str(opt.threads)])

    is_parallel_job = (instance_index >= 0)
    if is_parallel_job and \
       ('--stop-after-statistics' not in args) and ('--save-vwip' in args):
        # Must create a custom subdirectory for .vwip files,
        # as those are different for each pair.
        output_prefix = get_output_prefix(args)
        output_folder = os.path.dirname(output_prefix)
        sub_prefix    = get_subfolder_prefix(output_folder, instance_index)
        set_option(call, '--vwip-prefix', [sub_prefix])

    if opt.dryrun or opt.verbose:
        print('%s' % ' '.join(call))
    if opt.dryrun:
        return

    try:
        code = subprocess.call(call)
    except OSError as e:
        raise Exception('%s: %s' % (binpath, e))
    if code != 0:
        # Do not let a missing 'msg' hide the failure behind a KeyError
        raise Exception('Bundle adjust step ' + str(kw.get('msg', prog)) + ' failed')

def distribute_or_gather_files(output_prefix, ext, num_instances):
    '''Move files between the main output folder and the per-instance
       subfolders.

       For ext equal to '.match', the files with this extension in each
       subfolder get symlinked to the main folder, replacing anything
       already there under that name. For any other ext (the .stats and
       camera footprint bbox files), the files having the output prefix
       and this extension get copied to all subfolders.

       No longer used as files are no longer saved in subdirectories.
       '''

    output_folder = os.path.dirname(output_prefix)
    sub_prefixes = [get_subfolder_prefix(output_folder, i) for i in range(num_instances)]

    if ext == '.match':

        for pre in sub_prefixes:

            # Sym link match files from subdirectories to the output directory.
            # This takes a lot of care to get right.
            for f in glob.glob(pre + '*' + ext):
                # The link lives in output_folder, so its target must be
                # relative to that folder.
                src_file = os.path.relpath(f, output_folder)
                dst_file = output_prefix + f[len(pre):]

                # First wipe any existing file. Use lexists() so that a
                # dangling symlink is also found and removed.
                if os.path.lexists(dst_file):
                    os.remove(dst_file)

                # Create sym link
                os.symlink(src_file, dst_file)

        return

    # Stats files. Find the ones belonging to this run. Using join() keeps
    # the search in the current directory when output_folder is empty.
    prefix_name = os.path.basename(output_prefix)
    stats_files = glob.glob(os.path.join(output_folder, '*' + ext))

    # Duplicate them in each of the subfolders, creating them if needed.
    # - We could link them but these are tiny files so don't bother.
    for f in stats_files:
        file_name = os.path.basename(f)
        if not file_name.startswith(prefix_name):
            continue # not produced with this output prefix
        base_name = file_name[len(prefix_name):]
        for s in sub_prefixes:
            new_path = s + base_name
            asp_system_utils.mkdir_p(os.path.dirname(new_path))
            shutil.copyfile(f, new_path)

class ParallelBaStep:
    '''The ids of individual parallel_bundle_adjust steps.'''
    # statistics = 0, matching = 1, optimization = 2
    statistics, matching, optimization = range(3)

if __name__ == '__main__':
    usage = '''parallel_bundle_adjust <images> <cameras> <optional ground control points> -o <output prefix> [options]
        Camera model arguments may be optional for some stereo
        session types (e.g. isis). See bundle_adjust for all options.\n''' + get_asp_version()

    # A wrapper for bundle_adjust which computes image statistics and IP matches
    #  in parallel across multiple machines.  The final bundle_adjust step is
    #  performed on a single machine.

    # Algorithm: When the script is started, it uses GNU Parallel to start
    # copies of itself, one per instance, for the statistics and matching
    # steps. Those scripts in turn start the actual bundle_adjust jobs.
    # For the optimization step, the script does the work itself.

    p = argparse.ArgumentParser(usage=usage)
    p.add_argument('--nodes-list',           dest='nodes_list', default=None,
                 help='The list of computing nodes, one per line. ' + \
                 'If not provided, run on the local machine.')
    p.add_argument('--processes',            dest='processes', default=None,
                 type=int, help='The number of processes to use per node.')
    p.add_argument('--threads', dest='threads', default=None,
                 type=int, help='The number of threads to use.')
    p.add_argument('-e', '--entry-point',    dest='entry_point', default=0,
                   help = "Bundle adjustment entry point (start at this stage). " + \
                   "Options: statistics = 0, matching = 1, optimization = 2.",
                   type=int)
    p.add_argument('--stop-point',           dest='stop_point',  default=3,
                   help = "Bundle adjustment stop point (stop *before* this stage). " + \
                   "Options: statistics = 0, matching = 1, optimization = 2, all = 3.",
                   type=int)
    p.add_argument('--parallel-options', dest='parallel_options', default='--sshdelay 0.2',
                   help='Options to pass directly to GNU Parallel.')
    p.add_argument('-v', '--version',        dest='version', default=False,
                 action='store_true', help='Display the version of software.')
    p.add_argument('--verbose', dest='verbose', default=False, action='store_true',
                 help='Display the commands being executed.')

    # Internal variables below.
    # The index of the spawned process, set when GNU Parallel runs this script.
    p.add_argument('--instance-index', dest='instance_index', default=None, type=int,
                 help=argparse.SUPPRESS)
    # Directory where the job is running
    p.add_argument('--work-dir', dest='work_dir', default=None,
                 help=argparse.SUPPRESS)
    # ISIS settings
    p.add_argument('--isisroot', dest='isisroot', default=None,
                 help=argparse.SUPPRESS)
    p.add_argument('--isisdata', dest='isisdata', default=None,
                 help=argparse.SUPPRESS)
    # Debug options
    p.add_argument('--dry-run', dest='dryrun', default=False, action='store_true',
                 help=argparse.SUPPRESS)

    # 'opt' is a module-level variable which the functions above read.
    # Options not known to this script are left in 'args' for bundle_adjust.
    (opt, args) = p.parse_known_args()
    args = clean_args(args)

    if opt.version:
        asp_system_utils.print_version_and_exit()

    if not args and not opt.version:
        p.print_help()
        sys.exit(1)

    # Parse the output prefix only after the checks above, as it raises an
    # exception when the prefix is missing, which would otherwise prevent
    # the version or the help from being displayed.
    output_prefix = get_output_prefix(args)

    # Ensure our 'parallel' is not out of date
    check_parallel_version()

    # NOTE(review): with this default, opt.threads is never None by the
    # time get_best_procs_threads() runs, so that function's default of one
    # thread per process is not used - confirm this is intended.
    if opt.threads is None:
        opt.threads = get_num_cpus()

    if opt.instance_index is None:
        # When the script is started, set some options from the
        # environment which we will pass to the scripts we spawn
        # 1. Set the work directory
        opt.work_dir = os.getcwd()
        # 2. Set the ISIS settings if any
        if 'ISISROOT' in os.environ: opt.isisroot = os.environ['ISISROOT']
        if 'ISISDATA' in os.environ: opt.isisdata = os.environ['ISISDATA']
        # 3. Fix for Pleiades, copy the nodes_list to the run directory
        # to ensure it can be seen on any node.
        if opt.nodes_list is not None:
            if not os.path.isfile(opt.nodes_list):
                die('\nERROR: No such nodes-list file: ' + opt.nodes_list, code=2)
            local_nodes_list = output_prefix + "-nodes-list.txt"
            if opt.nodes_list != local_nodes_list:
                asp_system_utils.mkdir_p(os.path.dirname(output_prefix))
                shutil.copy2(opt.nodes_list, local_nodes_list)
                opt.nodes_list = local_nodes_list
            # Make the spawned copies use the local nodes list as well
            asp_cmd_utils.wipe_option(sys.argv, '--nodes-list', 1)
            sys.argv.extend(['--nodes-list', opt.nodes_list])
            print("Nodes list: " + opt.nodes_list)

    else:
        # After the script spawns itself to nodes, it starts in the
        # home dir. Make it go to the right place.
        os.chdir(opt.work_dir)
        # Set the ISIS settings
        if opt.isisroot is not None: os.environ['ISISROOT'] = opt.isisroot
        if opt.isisdata is not None: os.environ['ISISDATA'] = opt.isisdata

    # The values are not used below, but these calls validate the nodes
    # list and the input images, raising an exception on failure.
    num_nodes     = get_num_nodes(opt.nodes_list)
    num_instances = get_num_instances(args)

    if opt.version:
        args.append('-v')

    if opt.instance_index is None:

        # We get here when the script is started. The current running
        # process has become the management process that spawns other
        # copies of itself on other machines. This block will only do
        # actual work during the optimization step.

        # Wipe options which we will override. Work on a copy of the
        # command line, which includes the path to this script.
        self_args = list(sys.argv)
        asp_cmd_utils.wipe_option(self_args, '-e', 1)
        asp_cmd_utils.wipe_option(self_args, '--entry-point', 1)
        asp_cmd_utils.wipe_option(self_args, '--stop-point',  1)

        output_folder = os.path.dirname(output_prefix)

        # Create the main output folder
        if (not os.path.exists(output_folder)) and (output_folder != ""):
            os.makedirs(output_folder)

        skip_flags = ('--isis-cnet', '--match-files-prefix',
                      '--clean-match-files-prefix', '--skip-matching', '--nvm')
        if any(flag in args for flag in skip_flags):
            print("Skipping statistics and matching based on input options.")
            if opt.entry_point < ParallelBaStep.optimization:
                opt.entry_point = ParallelBaStep.optimization

        # Statistics, then matching. Each of these steps is run by
        # spawning copies of this script via GNU Parallel.
        for step in (ParallelBaStep.statistics, ParallelBaStep.matching):
            if opt.entry_point <= step:
                if opt.stop_point <= step:
                    sys.exit()
                spawn_to_nodes(step, output_prefix, self_args)

        # Optimization. This runs in the current process.
        step = ParallelBaStep.optimization
        if opt.entry_point <= step:
            if opt.stop_point <= step:
                sys.exit()
            args.extend(['--skip-matching'])
            run_job('bundle_adjust', args, instance_index=-1, msg='%d: Optimizing' % step)

            # End main process case
    else:

        # This process was spawned by GNU Parallel with a given
        # value of opt.instance_index. Launch the job for that instance.
        if opt.verbose:
            print("Running on machine: ", os.uname())

        try:
            args.extend(['--instance-index', str(opt.instance_index)])

            if opt.entry_point == ParallelBaStep.statistics:
                args.extend(['--stop-after-statistics'])
                run_job('bundle_adjust', args, opt.instance_index,
                        msg='%d: Statistics' % opt.entry_point)

            elif opt.entry_point == ParallelBaStep.matching:
                args.extend(['--stop-after-matching'])
                run_job('bundle_adjust', args, opt.instance_index,
                        msg='%d: Matching' % opt.entry_point)

        except Exception as e:
            die(e)
            raise
