#!/usr/bin/env python
# __BEGIN_LICENSE__
#  Copyright (c) 2009-2025, United States Government as represented by the
#  Administrator of the National Aeronautics and Space Administration. All
#  rights reserved.
#
#  The NGT platform is licensed under the Apache License, Version 2.0 (the
#  "License"); you may not use this file except in compliance with the
#  License. You may obtain a copy of the License at
#  http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
# __END_LICENSE__

import sys, argparse, os, subprocess, shlex, itertools

# Set up the path to Python modules about to load
basepath    = os.path.abspath(sys.path[0])
pythonpath  = os.path.abspath(basepath + '/../Python')  # for dev ASP
libexecpath = os.path.abspath(basepath + '/../libexec') # for packaged ASP
sys.path.insert(0, basepath) # prepend to Python path
sys.path.insert(0, pythonpath)
sys.path.insert(0, libexecpath)

import asp_system_utils

def parseParamSweep(sweep_str):
  """Parse a --param-sweep string into a dict of {param_name: [values]}.

  Input format: "--param1 val1 val2 --param2 val3 val4"
  Returns: {"--param1": ["val1", "val2"], "--param2": ["val3", "val4"]}

  Tokens may be separated by any whitespace (spaces, tabs, newlines),
  colons, or semicolons. A comma is used if a value itself has multiple
  parts (e.g., "7,7"); such values are returned intact here.

  If a parameter name appears more than once, its values are accumulated
  in order (the later mention does not discard the earlier values). Values
  that appear before any parameter name are ignored with a warning. A
  parameter followed by no values maps to an empty list.
  """

  # Colons and semicolons behave like whitespace; str.split() with no
  # argument then splits on spaces, tabs and newlines alike.
  normalized = sweep_str.translate(str.maketrans(':;', '  '))

  params = {}
  current_param = None

  for token in normalized.split():
    if token.startswith('--'):
      # Start of a (possibly repeated) parameter. setdefault keeps values
      # collected from an earlier mention of the same option.
      current_param = token
      params.setdefault(current_param, [])
    elif current_param is None:
      # A value with no owning parameter cannot be used; say so instead of
      # silently dropping it.
      print(f"WARNING: Ignoring value '{token}' that has no preceding "
            f"parameter name in --param-sweep")
    else:
      params[current_param].append(token)

  return params

def generateCombinations(parsed_params):
  """Expand parsed sweep parameters into every combination of their values.

  Args:
    parsed_params: dict mapping each parameter name to its candidate values,
      e.g. {"--param1": ["val1", "val2"], "--param2": ["val3"]}

  Returns:
    A list with one dict per element of the Cartesian product of the value
    lists, keys in the same order as parsed_params, e.g.
    [{"--param1": "val1", "--param2": "val3"},
     {"--param1": "val2", "--param2": "val3"}]
    An empty (or falsy) parsed_params yields [{}], i.e. one run with no
    extra parameters. A parameter whose value list is empty makes the
    product empty, so the result is [].
  """
  if not parsed_params:
    return [{}]

  names = tuple(parsed_params)
  value_lists = (parsed_params[name] for name in names)

  # Pair each name with the value picked for it in this combination.
  return [dict(zip(names, picked))
          for picked in itertools.product(*value_lists)]

def _runOrPrint(cmd, name, opt):
  """Print cmd if opt.dry_run is set, otherwise execute it.

  Args:
    cmd: Command as a list of arguments.
    name: Program name used in the status messages.
    opt: Parsed options object; only opt.dry_run is read.

  Returns:
    0 in dry-run mode or on success, otherwise the command's exit status.
  """
  if opt.dry_run:
    # Quote each argument so the printed line can be pasted into a shell
    # even when a path or value contains spaces or shell metacharacters.
    print(' '.join(shlex.quote(arg) for arg in cmd))
    return 0

  (_, _, status) = asp_system_utils.executeCommand(cmd,
                                                   realTimeOutput = True)
  if status != 0:
    print(f"ERROR: {name} failed with status {status}")
    return status

  print(f"\n{name} completed successfully.")
  return 0

def runCombination(combo_dict, args, opt, out_prefix):
  """Run parallel_stereo and point2dem for a single parameter combination.

  Args:
    combo_dict: Dict of parameters like {"--stereo-algorithm": "asp_mgm", etc}
    args: Pass-through arguments (images, cameras, etc.)
    opt: Parsed options object (reads dem, dry_run, point2dem_params)
    out_prefix: Output prefix for this run

  Returns:
    0 on success (or in dry-run mode), non-zero on failure. point2dem is
    not run if parallel_stereo fails.
  """

  # Build the parameter list from combo_dict. A comma separates the parts of
  # a multi-value parameter (e.g., "--corr-kernel 7,7" -> "--corr-kernel 7 7").
  combo_args = []
  for param, value in combo_dict.items():
    combo_args.append(param)
    if ',' in value:
      combo_args.extend(value.replace(',', ' ').split())
    else:
      combo_args.append(value)

  # Assemble parallel_stereo command with combination parameters
  stereo_cmd = ['parallel_stereo'] + args + combo_args + [out_prefix]
  if opt.dem:
    stereo_cmd.extend(['--dem', opt.dem])

  status = _runOrPrint(stereo_cmd, 'parallel_stereo', opt)
  if status != 0:
    return status

  # Now run point2dem on the output point cloud
  point2dem_cmd = ['point2dem', out_prefix + "-PC.tif"]
  if opt.point2dem_params:
    point2dem_cmd.extend(shlex.split(opt.point2dem_params))

  # If --orthoimage is requested, insert this run's L.tif right after it.
  # Match the whole token: a substring test on the raw string would also
  # match options such as --orthoimage-hole-fill-len.
  if '--orthoimage' in point2dem_cmd:
    ortho_idx = point2dem_cmd.index('--orthoimage')
    point2dem_cmd.insert(ortho_idx + 1, out_prefix + "-L.tif")

  return _runOrPrint(point2dem_cmd, 'point2dem', opt)

def generateRunIndex(parsed_sweeps, output_dir):
  """Write <output_dir>/run_index.csv, one row per run, naming its parameters.

  Args:
    parsed_sweeps: List of sweep dicts, each with a 'combinations' list of
      {param_name: value} dicts. Runs are numbered consecutively across all
      sweeps in list order, matching the run_NNNN directories.
    output_dir: Output directory for the CSV file (also used as the prefix
      of each run_dir cell).
  """

  # One flat, ordered list of combinations; a combination's position in it
  # is its run number.
  all_combos = [combo
                for sweep in parsed_sweeps
                for combo in sweep['combinations']]

  # Union of every parameter name, sorted so the columns are deterministic.
  columns = sorted({name for combo in all_combos for name in combo})

  lines = [', '.join(['run_dir'] + columns)]
  for idx, combo in enumerate(all_combos):
    cells = [f"{output_dir}/run_{idx:04d}"]
    # Parameters absent from this run get an empty cell. Commas in
    # multi-part values (e.g. "7,7") become spaces so they don't split the row.
    cells.extend(combo.get(name, '').replace(',', ' ') for name in columns)
    lines.append(', '.join(cells))

  index_path = os.path.join(output_dir, 'run_index.csv')
  with open(index_path, 'w') as handle:
    handle.write('\n'.join(lines) + '\n')

  print(f"Writing run index: {index_path}")

def main():
  """Parse the command line and run every requested parameter combination.

  Options not recognized by this script are forwarded unchanged to every
  parallel_stereo invocation. Each --param-sweep is expanded into the
  Cartesian product of its values. Runs are numbered consecutively across
  all sweeps, and each writes to <output-dir>/run_NNNN/run.

  Returns:
    0 on success (or when only help was printed), 1 on a usage error or
    when a run fails.
  """
  parser = argparse.ArgumentParser(
    description='Run parallel_stereo on small patches with different ' + \
                'parameter combinations.')

  parser.add_argument('--param-sweep', action='append', default=[],
      help='Parameter sweep. Must be in quotes. Multiple such options can be ' + \
      'specified. Each defines a different set of parameter combinations to test.')
  parser.add_argument('--dem', default='',
    help='Input DEM for mapprojection. Required if using --proj-win.')
  parser.add_argument('--output-dir', default='', 
    help='Output directory.')
  parser.add_argument('--point2dem-params', default='',
    help='Parameters to pass to point2dem. If --orthoimage (with no ' + \
    'argument) is passed in, the needed L.tif will be passed to each ' + \
    'created point2dem command.')
  parser.add_argument('--dry-run', dest='dry_run', action='store_true',
    help='Print commands without executing them.')

  if len(sys.argv) == 1:
    parser.print_help()
    return 0

  # Anything not recognized here (images, cameras, stereo options) ends up
  # in args and is passed through to parallel_stereo.
  (opt, args) = parser.parse_known_args()

  # Check that output directory is provided
  if not opt.output_dir:
    print("ERROR: --output-dir is required")
    return 1

  # --proj-win needs --dem whether it is swept or passed through directly.
  has_proj_win = ('--proj-win' in args or
                  any('--proj-win' in sweep_str for sweep_str in opt.param_sweep))
  if has_proj_win and not opt.dem:
    print("ERROR: --dem is required when using --proj-win")
    return 1

  # Parse parameter sweeps - each --param-sweep creates a separate experiment
  parsed_sweeps = []
  run_count = 0  # Running total of runs across all sweeps
  for sweep_num, sweep_str in enumerate(opt.param_sweep, start=1):
    parsed = parseParamSweep(sweep_str)

    # A parameter with no values makes the Cartesian product empty, which
    # would silently skip the whole sweep. Report it instead.
    missing = [name for name, values in parsed.items() if not values]
    if missing:
      print(f"ERROR: No values given for {', '.join(missing)} "
            f"in parameter sweep {sweep_num}")
      return 1

    combinations = generateCombinations(parsed)
    parsed_sweeps.append({
        'sweep_num': sweep_num,
        'run_start': run_count,
        'combinations': combinations
    })
    run_count += len(combinations)

  # If no sweeps specified, run once with no extra parameters
  if not parsed_sweeps:
    print("No parameter sweeps specified, running single instance")
    out_prefix = opt.output_dir + "/run"
    os.makedirs(opt.output_dir, exist_ok=True)
    return runCombination({}, args, opt, out_prefix)

  # Create output directory and generate run index before starting runs
  os.makedirs(opt.output_dir, exist_ok=True)
  generateRunIndex(parsed_sweeps, opt.output_dir)

  # Iterate over all sweeps and combinations. The numbering must match the
  # run_NNNN rows written by generateRunIndex.
  run_num = 0
  for sweep in parsed_sweeps:
    print(f"\nProcessing sweep {sweep['sweep_num']}")

    for combo in sweep['combinations']:
      # Create run directory with zero-padded number
      run_dir = f"{opt.output_dir}/run_{run_num:04d}"
      os.makedirs(run_dir, exist_ok=True)
      out_prefix = f"{run_dir}/run"

      print(f"\nRun {run_num:04d}")
      print(f"Parameters: {combo}")

      status = runCombination(combo, args, opt, out_prefix)
      if status != 0:
        print(f"ERROR: Run {run_num:04d} failed, stopping")
        return 1

      run_num += 1

  print(f"\nAll {run_num} runs completed successfully")
  return 0

if __name__ == '__main__':
  # main() returns 0 or a failure status; hand it to the shell as the exit code.
  exit_code = main()
  sys.exit(exit_code)
