#!/bin/bash

# Copyright (C) 2025 Barcelona Supercomputing Center
#
# This file is part of DMR.
#
# DMR is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License
# as published by the Free Software Foundation; version 2 only.
#
# DMR is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty
# of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with DMR; if not, see <https://www.gnu.org/licenses/>.

# Print the full usage/help text for the wrapper on stdout.
# Takes no arguments. The delimiter is quoted so the text is emitted
# literally, without any parameter or command expansion.
show_help() {
    cat <<'HELP_TEXT'
Usage: dmr_wrapper COMMAND [OPTIONS] ./your_program [PROGRAM_ARGS...]

Wrapper script to run a DMR program and manage its state.

Supported COMMAND options:
  prterun
  mpirun
  mpiexec

Options:
  --help, -h
        Show this help message and exit.

Example:
  dmr_wrapper mpirun -np 4 --host my_host:123 ./my_program arg1 arg2

Description:
  This script wraps the specified MPI launcher command to inject
  runtime options that:
    - Forward environment variables to the launched program.
    - Disable Slurm resource and process launch management (RAS, PLM).
    - Ensure absolute COMMAND path so it acts like the --prefix OpenMPI flag
    - Manage state and provide restart mechanism for checkpoint-restart, if needed

Notes:
  The wrapper expects the MPI command as the first argument, followed
  by its options, and finally the executable and its arguments.

HELP_TEXT
}

# Debugging helper: write a sample state file to "$DMR_STATE_FILE".
# The file is a sequence of NUL-terminated fields, in the same order the
# restart loop reads them: hosts, three counters/timing values, the number
# of program words, then the program name and its arguments.
create_test_output() {
    local -a sample_fields=(
        'my_host:123,my_host1:456'
        1
        5
        123.456
        4
        'my_program'
        1
        2
        3
    )

    # printf reuses the format for every element, so each field ends in NUL
    printf '%s\0' "${sample_fields[@]}" > "$DMR_STATE_FILE"
}

# Find the program the launcher has been asked to run.
#
# Arguments: the full launcher command line; $1 is the launcher itself
#            (prterun/mpirun/mpiexec) and is skipped.
# Outputs:   the first remaining word that is either an existing executable
#            file or the name of an executable found on PATH. The word is
#            printed exactly as it appears on the command line, so the result
#            can be compared directly against the command's words.
# Returns:   0 if such a word was found, 1 otherwise.
find_executable() {
    shift  # Ignore first prterun/mpirun/mpiexec

    local arg
    for arg in "$@"; do
        # Either a path to an executable file, or a name resolvable on PATH.
        # `type -P --` only searches PATH for files (no builtins, keywords or
        # shell functions) and `--` stops option-like words such as "-v" from
        # being parsed as options of the lookup itself.
        if [[ -f "$arg" && -x "$arg" ]] || type -P -- "$arg" >/dev/null 2>&1; then
            printf '%s\n' "$arg"
            return 0
        fi
    done

    return 1
}

# Insert extra words into a command line, directly before the executable.
#
# Arguments: $1   - indirect reference to the array of words to insert,
#                   written as NAME[@] (e.g. INSERTS_ARR[@])
#            $2   - the executable word to look for
#            $3.. - the command line to process
# Globals:   INSERTED_CMD (written) - the resulting command as an array.
#            The words are inserted before the first word equal to $2 only;
#            if no word matches, INSERTED_CMD is a plain copy of the command.
insert_into_cmd() {
    local -a insert_strings=("${!1}")
    local executable="$2"
    shift 2

    INSERTED_CMD=()
    local inserted=false
    local arg  # keep the loop variable out of the caller's scope

    for arg in "$@"; do
        if [[ "$inserted" == false && "$arg" == "$executable" ]]; then
            inserted=true
            INSERTED_CMD+=("${insert_strings[@]}")
        fi
        INSERTED_CMD+=("$arg")
    done
}

# Directory this script lives in
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"

# Launchers the wrapper accepts as its first argument
ALLOWED_COMMANDS=("prterun" "mpirun" "mpiexec")

# Name the wrapper was invoked as, for usage messages
WRAPPER_CMD=$(basename "$0")

# Called with no arguments at all: print the short usage line and give up
if (( $# < 1 )); then
    echo "Usage: $WRAPPER_CMD mpirun [--host my_host:123] [-np 1] [optional other flags] ./your_program [args...]" >&2
    exit 1
fi

case "$1" in
    --help|-h)
        show_help
        exit 0
        ;;
esac

# Launcher as typed by the user, and the same with any leading path removed
LAUNCHER_CMD="$1"
LAUNCHER_BASE=$(basename "$1")

# The launcher's base name has to appear in the space-separated allow-list
if [[ " ${ALLOWED_COMMANDS[*]} " != *" $LAUNCHER_BASE "* ]]; then
    echo "Error: First argument must be one of: ${ALLOWED_COMMANDS[*]}" >&2
    exit 1
fi

# A launcher on its own is not enough; something must follow it
if (( $# < 2 )); then
    echo "Usage: $WRAPPER_CMD $LAUNCHER_CMD [--host my_host:123] [-np 1] [optional other flags] ./your_program [args...]" >&2
    exit 1
fi

# Work out which word is the program to run, so flags can be placed before it
EXECUTABLE=$(find_executable "$@")
if [[ -z "$EXECUTABLE" ]]; then
    echo "Error: dmr_wrapper was not able to find an executable in the provided command" >&2
    exit 1
fi

# Resolve the launcher (mpirun/mpiexec/prterun) through the shell's lookup
LAUNCHER_WITH_PATH=$(command -v "$LAUNCHER_CMD")
if [[ -z "$LAUNCHER_WITH_PATH" ]]; then
    echo "Error: Command '$LAUNCHER_CMD' not found on your PATH." >&2
    exit 1
fi

# Ensure TMPDIR is set and exported (DMR code + wrapper needs it)
export TMPDIR="${TMPDIR:-/tmp}"

# Unique name for this run's state file.
DMR_STATE_ID=$(uuidgen 2>/dev/null)
if [[ -z "$DMR_STATE_ID" ]]; then
    # uuidgen is missing or failed. An empty ID would make DMR_STATE_FILE
    # point at the TMPDIR directory itself, so fall back to a name built
    # from the wrapper's PID and random numbers.
    DMR_STATE_ID="dmr-state-$$-${RANDOM}${RANDOM}"
fi

# Initial values exported to the launched program; the restart loop
# overwrites the counters and time from the state file on every re-launch.
export DMR_STATE_FILE="$TMPDIR/$DMR_STATE_ID"
export DMR_RECONFIG_COUNT=0
export DMR_EXPANSION_COUNT=0
export DMR_RECONFIG_TIME=-1.0

# Flags to insert into the command
INSERTS_ARR=()
INSERTS_ARR+=(--runtime-options fwd-environment)  # Propagate environment
INSERTS_ARR+=(--prtemca ras "^slurm")  # Ensure we are not running with Slurm integration
INSERTS_ARR+=(--prtemca plm "^slurm")

# "MCAST" seems to fail a lot causing noisy errors, so we disable it
export HCOLL_ENABLE_MCAST=0

# Make sure the mpirun/mpiexec/prterun command specifies an absolute path
# This is like providing the --prefix flag, but works better with mpirun and mpiexec
CMD_FROM_USER=("$@")
CMD_FROM_USER[0]=$LAUNCHER_WITH_PATH

# Tells the launch loop that the program has not been started yet
IS_FIRST_EXEC=true

# Exit status of the most recent launch; becomes the wrapper's exit status
prrte_status=0

# Launch loop: start the program once, then keep re-launching it for as long
# as a state file requesting a restart is found after the launcher exits.
while true; do

    # First run, just launch the program
    if [[ "$IS_FIRST_EXEC" == "true" ]]; then
        insert_into_cmd INSERTS_ARR[@] "$EXECUTABLE" "${CMD_FROM_USER[@]}"
        "${INSERTED_CMD[@]}"
        prrte_status=$?
        IS_FIRST_EXEC=false
        continue
    fi

    # A state file will have been created by DMR if we need to restart the DVM (to shrink/grow if this mechanism is used)
    # If not, this is a genuine termination and we should break out of the loop
    if [[ ! -f "$DMR_STATE_FILE" ]]; then
        break
    fi

    # NUL-separated fields: hosts, expansion count, reconfig count,
    # reconfig time, argc, followed by argc words (program + arguments)
    state_info=()
    mapfile -d '' -t state_info < "$DMR_STATE_FILE"

    # Remove the state file; if somehow DMR fails to launch, it will not cause us to infinitely retry
    rm -f -- "$DMR_STATE_FILE"

    # The state file contains hosts and slots like my_host:123,my_host2:123
    HOSTS="${state_info[0]}"
    ARGC="${state_info[4]}"

    # Refuse to re-launch from a truncated or corrupt state file: argc must be
    # a positive decimal number and all argc words must actually be present
    if [[ ! "$ARGC" =~ ^[0-9]+$ ]] \
        || (( 10#$ARGC < 1 || ${#state_info[@]} < 5 + 10#$ARGC )); then
        echo "Error: dmr_wrapper read a malformed state file, cannot restart" >&2
        prrte_status=1
        break
    fi
    ARGC=$((10#$ARGC))  # force base 10 so leading zeros are not read as octal

    insert_into_cmd INSERTS_ARR[@] "$EXECUTABLE" "${CMD_FROM_USER[@]}"
    NEW_CMD_BUILD=()
    HOST_FOUND=false

    # Rebuild the launcher part of the original command with the new hosts
    for ((i = 0; i < ${#INSERTED_CMD[@]}; i++)); do

        current="${INSERTED_CMD[i]}"

        # Everything from the executable onwards is replaced by the saved state
        if [[ "$current" == "$EXECUTABLE" ]]; then
            break
        fi

        case "$current" in
            -np|--np|-n|--n|-c)
                # Override the old number of processes. This will lead OpenMPI
                # to just use the maximum as determined by the slots
                ((i++))  # Skip the process count which comes after
                ;;
            -host|--host|-hostfile|--hostfile|-rankfile|--rankfile)
                # Need to replace the hosts with new hosts read from file.
                # Only one --host is emitted, however many of these options
                # the original command carried
                ((i++))  # Skip the old hosts
                if [[ "$HOST_FOUND" != true ]]; then
                    NEW_CMD_BUILD+=("--host" "$HOSTS")
                    HOST_FOUND=true
                fi
                ;;
            *)
                NEW_CMD_BUILD+=("$current")
                ;;
        esac
    done

    # The original command named no hosts at all: add them before the program
    if [[ "$HOST_FOUND" != true ]]; then
        NEW_CMD_BUILD+=("--host" "$HOSTS")
        HOST_FOUND=true
    fi

    export DMR_EXPANSION_COUNT=${state_info[1]}
    export DMR_RECONFIG_COUNT=${state_info[2]}
    export DMR_RECONFIG_TIME=${state_info[3]}

    # Fully rebuild the executable + arguments from the saved state.
    NEW_CMD_BUILD+=("${state_info[@]:5:$ARGC}")

    # DVM re-launch
    "${NEW_CMD_BUILD[@]}"
    prrte_status=$?
done

exit "$prrte_status"