#!/usr/bin/env python

# Copyright (c) 2021, United States Government, as represented by the
# Administrator of the National Aeronautics and Space Administration.
#
# All rights reserved.
#
# The "ISAAC - Integrated System for Autonomous and Adaptive Caretaking
# platform" software is licensed under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with the
# License. You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

"""
Prepare for and then run stereo on many image pairs, and produce a fused mesh.
"""

import argparse, os, re, shutil, subprocess, sys, glob, shlex
import numpy as np

# Set up the path to Python modules about to load. Three directories are
# prepended to the module search path: the one in sys.path[0] (normally the
# directory holding this script) and its siblings ../Python and ../libexec.
# Each insert goes at position 0, so the resulting search order is
# libexecpath, then pythonpath, then basepath. This must happen before
# asp_rig_utils is imported below.
basepath    = os.path.abspath(sys.path[0])
pythonpath  = os.path.abspath(basepath + '/../Python')  # for dev ASP
libexecpath = os.path.abspath(basepath + '/../libexec') # for packaged ASP
sys.path.insert(0, basepath) # prepend to Python path
sys.path.insert(0, pythonpath)
sys.path.insert(0, libexecpath)

import asp_rig_utils

def process_args(args):
    """
    Set up the parser and parse the args.

    Args:
        args: The full command line as a list of strings, with the program
            name as the first entry (normally sys.argv). If None, argparse
            falls back to reading sys.argv on its own.

    Returns:
        The argparse.Namespace holding the parsed option values.
    """

    parser = argparse.ArgumentParser(description = "Parameters for multi_stereo.")

    parser.add_argument("--rig_config", dest="rig_config", default="",
                        help="Rig configuration file.")

    parser.add_argument("--rig_sensor", dest="rig_sensor", default="",
                        help="Which rig sensors to use. Must be among the "
                        "sensors specified via --rig_config. To use images from "
                        "several sensors, pass in a quoted list of them, "
                        "separated by a space.")

    parser.add_argument("--camera_poses", dest="camera_poses", default="",
                        help="Read images and camera poses for this "
                        "sensor from this list.")

    parser.add_argument("--out_dir", dest="out_dir", default="",
                        help="The directory where to write the stereo "
                        "output, textured mesh, and other data.")

    parser.add_argument("--stereo_options", dest="stereo_options", default="",
                        help="Options to pass to parallel_stereo. Use "
                        "double quotes around the full list and simple quotes "
                        "if needed by an individual option, or vice-versa.")

    parser.add_argument("--pc_filter_options", dest="pc_filter_options", default="",
                        help="Options to pass to pc_filter.")

    parser.add_argument("--mesh_gen_options", dest="mesh_gen_options", default="",
                        help="Options to pass to voxblox for mesh generation.")

    # Note how the percent sign below is escaped, by writing: %%
    parser.add_argument("--undistorted_crop_win", dest="undistorted_crop_win", default="",
                        help="The dimensions of the central image region to keep "
                        "after the internal undistortion step and before using "
                        "it in stereo. Normally 85%% - 90%% of distorted (actual) "
                        "image dimensions would do. Suggested the Astrobee images:"
                        " sci_cam: '1250 1000' nav_cam: '1100 776'. haz_cam: '250 200'.")

    # NOTE(review): the help text below says the stereo subdirectories are
    # deleted, but run_pair() currently reuses an existing directory. Confirm
    # which behavior is intended before changing either.
    parser.add_argument("--first_step", dest="first_step", default="stereo",
                        help="Let the first step run by this tool be one of: "
                        "'stereo', 'pc_filter', or 'mesh_gen'. "
                        "This allows resuming a run at a desired step. "
                        "The stereo subdirectories are deleted before that "
                        "step takes place.")

    parser.add_argument("--left", dest="left", default="",
                        help="Instead of running pairwise stereo between every "
                        "image and the next one given in --camera_poses, use every "
                        "image from this list and corresponding one from the list "
                        "given by the --right option.")

    parser.add_argument("--right", dest="right", default="",
                        help="To be used with --left.")

    parser.add_argument("--last_step", dest="last_step", default="mesh_gen",
                        help="The last step run by this tool. See ``--first_step`` "
                        "for allowed values.")

    # Both spellings are accepted. The dashed form is the original one; the
    # underscore form matches all the other options of this tool and the
    # wording of the error messages.
    parser.add_argument("--first-image-index", "--first_image_index",
                        dest="first_image_index", default=None, type=int,
                        help="The index of the first image to use for stereo, in the "
                        "list of images. Indices start from 1. By default, use "
                        "all the images.")

    parser.add_argument("--last-image-index", "--last_image_index",
                        dest="last_image_index", default=None, type=int,
                        help="The index of the last image to use for stereo, in the "
                        "list of images. Indices start from 1. By default, "
                        "use all the images.")

    # Parse the list that was passed in (minus the program name), rather than
    # silently ignoring it and always reading sys.argv.
    argv = None if args is None else list(args)[1:]

    return parser.parse_args(argv)

def gen_pairs(args, all_distorted_images):
    """
    Return the list of (left, right) index pairs on which to run stereo.

    Unless both args.left and args.right are non-empty, each image is
    paired with the one following it in all_distorted_images. Otherwise
    the pairs are read from the two lists, and each image name is turned
    into an index by looking it up in the dictionary that
    asp_rig_utils.create_index_dict() builds from all_distorted_images.
    """

    use_lists = not (args.left == "" or args.right == "")

    if not use_lists:
        # Consecutive images: (0, 1), (1, 2), ...
        num_images = len(all_distorted_images)
        return list(zip(range(num_images - 1), range(1, num_images)))

    index_of = asp_rig_utils.create_index_dict(all_distorted_images)
    listed_pairs = asp_rig_utils.read_image_pairs(args.left, args.right)

    return [(index_of[entry[0]], index_of[entry[1]]) for entry in listed_pairs]

# The processing steps, numbered in the order in which they are run
step_dict = {"stereo": 0, "pc_filter": 1, "mesh_gen": 2}

def sanity_checks(args):
    """
    Validate the parsed options. An Exception is raised for the first
    problem found; nothing is returned if all the checks pass.
    """

    # Options which must not be left empty, with the error to report for each.
    # They are checked in this order.
    mandatory = (
        ("camera_poses",
         "The path to the list having input images and poses was not specified."),
        ("rig_config",
         "The path to the rig configuration was not specified."),
        ("rig_sensor",
         "The rig sensor to use for texturing was not specified."),
        ("out_dir",
         "The path to the output directory was not specified."),
        ("undistorted_crop_win",
         "The undistorted crop win was not specified."),
    )
    for option, message in mandatory:
        if getattr(args, option) == "":
            raise Exception(message)

    # Both the first and the last step must be known step names
    if any(step not in step_dict for step in (args.first_step, args.last_step)):
        raise Exception("Invalid values specified for --first_step or --last_step.")

    have_index_range = not (args.first_image_index is None and
                            args.last_image_index is None)
    have_left = args.left != ""
    have_right = args.right != ""

    # An index range and explicit left/right lists are mutually exclusive
    if have_index_range and (have_left or have_right):
        raise Exception("Cannot use --first_image_index, --last_image_index " +
                        "together with --left, --right.")

    # The left and right lists only make sense together
    if have_left != have_right:
        raise Exception("Must specify either both --left and --right, or none of them.")
        
def write_asp_and_voxblox_cameras(undist_intrinsics_file, distorted_images,
                                  undistorted_images,
                                  world_to_cam):
    """
    Write, for each distorted image, the camera files going with its
    undistorted counterpart.

    For every entry of distorted_images, the undistorted image having the
    same file name (without directory and extension) is looked up. Next to
    that undistorted image two files are written, both via asp_rig_utils:
    a .tsai camera file and a _cam2world.txt camera-to-world matrix (the
    latter for use with voxblox).

    Args:
        undist_intrinsics_file: File from which the intrinsics are read
            with asp_rig_utils.read_intrinsics(). Only the focal length
            and optical center it returns are used here.
        distorted_images: List of distorted image paths. Entry i goes
            with world_to_cam[i].
        undistorted_images: List of undistorted image paths, in any order.
        world_to_cam: List of world-to-camera matrices, invertible with
            np.linalg.inv().

    Returns:
        A tuple (undistorted_images_out, cameras, cam_to_world_files) of
        three lists, all in the order of distorted_images.

    Raises:
        Exception: If the number of distorted or undistorted images differs
            from the number of cameras, or if a distorted image has no
            undistorted counterpart.
    """

    if len(undistorted_images) != len(world_to_cam):
        raise Exception("Expecting as many undistorted images as cameras.")

    # The loop below pairs distorted images with cameras by position, so
    # these must match too. Otherwise images would silently get the wrong
    # camera, or the lookup would run out of bounds.
    if len(distorted_images) != len(world_to_cam):
        raise Exception("Expecting as many distorted images as cameras.")

    # The first two returned values are not needed here
    (_, _, f, cx, cy) = asp_rig_utils.read_intrinsics(undist_intrinsics_file)

    # Use the image name without dir and extension as a key
    undist_image_dict = {os.path.splitext(os.path.basename(image))[0]: image
                         for image in undistorted_images}

    # Collect the undistorted images for which camera info is present
    undistorted_images_out = []
    cameras = []
    cam_to_world_files = []
    for dist_image, world_to_cam_mat in zip(distorted_images, world_to_cam):

        base = os.path.splitext(os.path.basename(dist_image))[0]
        if base not in undist_image_dict:
            raise Exception("Cannot find undistorted image for: "  + dist_image)
        undist_image = undist_image_dict[base]
        undist_prefix = os.path.splitext(undist_image)[0]

        # Find the camera-to-world transform
        cam_to_world = np.linalg.inv(world_to_cam_mat)

        tsai_file = undist_prefix + ".tsai"
        asp_rig_utils.write_tsai_camera_file(tsai_file, f, cx, cy, cam_to_world)

        # For use with voxblox later
        cam_to_world_file = undist_prefix + "_cam2world.txt"
        asp_rig_utils.write_cam_to_world_matrix(cam_to_world_file, cam_to_world)

        undistorted_images_out.append(undist_image)
        cameras.append(tsai_file)
        cam_to_world_files.append(cam_to_world_file)

    print("Found " + str(len(undistorted_images)) + " undistorted image(s).")
    print("Wrote: " + str(len(cameras)) + " camera(s).")

    return (undistorted_images_out, cameras, cam_to_world_files)

def _split_cli_options(options):
    """
    Split a string of command-line options into a list of tokens, keeping
    quoted parts together. Backslashes (stray line continuations) are
    replaced with spaces first.
    """
    return shlex.split(options.replace('\\', ' '))

def run_pair(left_image, right_image, left_cam, right_cam, args, sensor_dir, tool_base_dir,
             first_step, last_step):
    """
    Run stereo and/or point cloud filtering for one pair of images.

    All outputs go to <args.out_dir>/<sensor_dir>/stereo/<left>_<right>/
    with the prefix "run", where <left> and <right> are the image names
    without directory and extension. Each external tool is allowed to fail,
    as perhaps not all stereo pairs are good.

    Args:
        left_image, right_image: Paths to the two images.
        left_cam, right_cam: Paths to the camera files of the two images.
        args: Parsed options. Reads out_dir, stereo_options and
            pc_filter_options. The options are not modified.
        sensor_dir: Name of the subdirectory of args.out_dir to use.
        tool_base_dir: Passed to asp_rig_utils.find_tool() to locate tools.
        first_step, last_step: Numeric range of steps to run, using the
            numbering in step_dict (0 is stereo, 1 is pc_filter). Steps
            outside of this range are skipped.

    Returns:
        The path to the filtered point cloud in PCD format. It is returned
        even when the step producing it is skipped, since the caller needs
        the name for mesh generation.
    """

    out_dir = args.out_dir + "/" + sensor_dir

    left_prefix = os.path.splitext(os.path.basename(left_image))[0]
    right_prefix = os.path.splitext(os.path.basename(right_image))[0]

    stereo_dir = out_dir + "/stereo/" + left_prefix + "_" + right_prefix
    stereo_prefix = stereo_dir + "/run"

    cloud_file = stereo_prefix + '-PC.tif'
    texture_file = stereo_prefix + '-L.tif'
    filtered_cloud_file = stereo_prefix + '-PC-filter.tif'
    pcd_file = stereo_prefix + '-PC-filter.pcd'

    if first_step <= 0 and last_step >= 0:
        # An existing directory is reused, it is not wiped
        if os.path.isdir(stereo_dir):
            print("Will run stereo in existing directory: " + stereo_dir)

        stereo_options = _split_cli_options(args.stereo_options)

        # Run stereo
        tool_path = asp_rig_utils.find_tool(tool_base_dir, 'parallel_stereo')
        cmd = [tool_path] + stereo_options + \
              [left_image, right_image, left_cam, right_cam, stereo_prefix]

        # Allow this to fail, as perhaps not all stereo pairs are good
        asp_rig_utils.run_cmd(cmd, quit_on_failure = False)

    if first_step <= 1 and last_step >= 1:

        pc_filter_options = _split_cli_options(args.pc_filter_options)

        # Run pc_filter
        pc_filter_path = asp_rig_utils.find_tool(tool_base_dir, 'pc_filter')
        cmd = [pc_filter_path,
               '--input-cloud',   cloud_file,
               '--input-texture', texture_file,
               '--output-cloud',  filtered_cloud_file,
               '--camera', left_cam] + \
               pc_filter_options

        # Allow this to fail, as perhaps not all stereo pairs are good
        asp_rig_utils.run_cmd(cmd, quit_on_failure = False)

        # Run point2mesh. This is for debugging purposes, to be able
        # to inspect each cloud.
        # If we use -s 1 the .obj file is too big. But -s 4 is too coarse.
        # TODO(oalexan1): Make creating these debug meshes optional.
        tool_path = asp_rig_utils.find_tool(tool_base_dir, 'point2mesh')
        cmd = [tool_path, "-s", "4",
               "--texture-step-size", "1",
               filtered_cloud_file,
               texture_file]
        asp_rig_utils.run_cmd(cmd, quit_on_failure = False)

        # Write this in PCD format and in left camera's coordinates,
        # so it can be merged later with voxblox. The weight file name
        # has the same "-" separator after the prefix as all other outputs.
        cmd = [pc_filter_path,
               '--input-cloud',   cloud_file,
               '--input-texture', texture_file,
               '--output-cloud',  pcd_file,
               '--camera', left_cam,
               '--transform-to-camera-coordinates',
               '--output-weight', stereo_prefix + '-PC-weight.tif'] + \
               pc_filter_options
        asp_rig_utils.run_cmd(cmd, quit_on_failure = False)

    return pcd_file

def fuse_clouds(args, pairs, sensor_dir, tool_base_dir, undistorted_images,
                cam_to_world_files, pcd_files = None):
    """
    Fuse the results in a mesh with voxblox.

    An index file is written to <args.out_dir>/<sensor_dir>/voxblox_index.txt.
    For each stereo pair it has two lines: the camera-to-world file of the
    left image of the pair, then the point cloud of that pair. Then
    voxblox_mesh is run on this index, with the output mesh set to
    <args.out_dir>/<sensor_dir>/fused_mesh.ply. A failure of that tool
    is fatal.

    Args:
        args: Parsed options. Reads out_dir and mesh_gen_options. The
            options are not modified.
        pairs: List of (left, right) index pairs. Only the left index is
            used, to look up the entry in cam_to_world_files.
        sensor_dir: Name of the subdirectory of args.out_dir to use.
        tool_base_dir: Passed to asp_rig_utils.find_tool() to locate tools.
        undistorted_images: Not used. Kept so that existing calls still work.
        cam_to_world_files: List of camera-to-world matrix files.
        pcd_files: List of point cloud files, one per entry in pairs. If
            None, the module-level variable of the same name is used, which
            is how this function got this list before it had this parameter.

    Raises:
        Exception: If no list of point clouds is available, or if it does
            not have as many entries as there are pairs.
    """

    if pcd_files is None:
        # Backward compatibility with calls which do not pass in this list
        pcd_files = globals().get("pcd_files")
        if pcd_files is None:
            raise Exception("The list of point cloud files was not provided.")

    if len(pcd_files) != len(pairs):
        raise Exception("Expecting as many point clouds as stereo pairs.")

    voxblox_index = args.out_dir + "/" + sensor_dir + "/voxblox_index.txt"
    fused_mesh = args.out_dir + "/" + sensor_dir + "/fused_mesh.ply"
    print("Writing: " + voxblox_index)
    with open(voxblox_index, "w") as handle:
        for pair, pcd_file in zip(pairs, pcd_files):
            # The clouds are in the coordinates of the left camera
            handle.write(cam_to_world_files[pair[0]] + "\n")
            handle.write(pcd_file + "\n")

    # Split on spaces but keep quoted parts together. Wipe stray
    # continuation lines.
    mesh_gen_options = shlex.split(args.mesh_gen_options.replace('\\', ' '))

    tool_path = asp_rig_utils.find_tool(tool_base_dir, 'voxblox_mesh')
    cmd = [tool_path, '--index', voxblox_index,
           '--output_mesh', fused_mesh] + \
           mesh_gen_options
    asp_rig_utils.run_cmd(cmd, quit_on_failure = True)

if __name__ == "__main__":

    args = process_args(sys.argv)
    sanity_checks(args)
    asp_rig_utils.mkdir_p(args.out_dir)

    sensors = args.rig_sensor.split()
    sensor_dir = "_".join(sensors) # for when there are multiple sensors

    # The tools are looked up relative to the parent of this script's directory
    script_path = os.path.abspath(sys.argv[0])
    tool_base_dir = os.path.dirname(os.path.dirname(script_path))

    # Read the images and camera poses, accumulating them over all sensors
    # TODO(oalexan1): Make the block below a function
    all_distorted_images = []
    all_undistorted_images = []
    all_cams = []
    all_cam_to_world_files = []
    for sensor in sensors:

        # When only a subset of images is used, and those are set via
        # lists, parse the cameras once per list. Otherwise parse them
        # once, with an empty list name.
        # TODO(oalexan1): Make this a function called parse_all_cameras()
        if args.left != "" and args.right != "":
            lists = [args.left, args.right]
        else:
            lists = [""]

        distorted_images = []
        world_to_cam = []
        for subset_list in lists:
            parsed = asp_rig_utils.parse_cameras(args.camera_poses,
                                                 subset_list,
                                                 sensor,
                                                 args.first_image_index,
                                                 args.last_image_index)
            (local_distorted_images, local_world_to_cam) = parsed
            distorted_images += local_distorted_images
            world_to_cam += local_world_to_cam

        # Write undistorted images. Use the lossless .tif extension for
        # them, with no extra options and no suffix.
        undistort_result = asp_rig_utils.undistort_images(args, sensor,
                                                          distorted_images,
                                                          tool_base_dir,
                                                          '.tif', [], "")
        (undist_intrinsics_file, undistorted_images, undist_dir) = undistort_result

        # Write camera poses in the format of ASP and voxblox
        camera_result = write_asp_and_voxblox_cameras(undist_intrinsics_file,
                                                      distorted_images,
                                                      undistorted_images,
                                                      world_to_cam)
        (undistorted_images, cameras, cam_to_world_files) = camera_result

        all_distorted_images += distorted_images
        all_undistorted_images += undistorted_images
        all_cams += cameras
        all_cam_to_world_files += cam_to_world_files

    # TODO(oalexan1): This must be parallelized like done for icebridge
    # Run for each pair parallel_stereo and/or filtering. Even if desired
    # to skip these steps, must go through the motions to produce
    # some lists of files needed for mesh generation later.
    first_step_num = step_dict[args.first_step]
    last_step_num = step_dict[args.last_step]

    # Keep the name pcd_files: fuse_clouds() reads this module-level list
    pcd_files = []
    pairs = gen_pairs(args, all_distorted_images)
    for pair in pairs:
        left_index = pair[0]
        right_index = pair[1]
        pcd_file = run_pair(all_undistorted_images[left_index],
                            all_undistorted_images[right_index],
                            all_cams[left_index], all_cams[right_index],
                            args, sensor_dir, tool_base_dir,
                            first_step_num, last_step_num)
        print("got pcd file: " + pcd_file)
        pcd_files.append(pcd_file)

    # Fuse the clouds only if mesh generation is within the range of steps
    if first_step_num <= step_dict["mesh_gen"] <= last_step_num:
        fuse_clouds(args, pairs, sensor_dir, tool_base_dir,
                    all_undistorted_images, all_cam_to_world_files)
