#!/usr/bin/env python3
# Thanks @skchronicles
# source: https://raw.githubusercontent.com/OpenOmics/mr-seek/2ecbbb2628b7102bf2cc23bc946858de2e09929f/workflow/scripts/jobby
# -*- coding: UTF-8 -*-

"""
ABOUT:
    `jobby` will take your past jobs and display their job information.
    Why? We have pipelines running on several different clusters and
    job schedulers. `jobby` is an attempt to centralize and abstract
    the process of querying different job schedulers. On each supported
    target system, `jobby` will attempt to determine the best method for
    getting job information to return to the user in a standardized
    format and unified cli.

REQUIRES:
  - python>=3.5

DISCLAIMER:
                        PUBLIC DOMAIN NOTICE
            NIAID Collaborative Bioinformatics Resource (NCBR)

       National Institute of Allergy and Infectious Diseases (NIAID)
    This software/database is a "United  States Government Work" under
    the terms of the United  States Copyright Act.  It was written as
    part of the author's official duties as a United States Government
    employee and thus cannot be copyrighted. This software is freely
    available to the public for use.

    Although all  reasonable  efforts have been taken  to ensure  the
    accuracy and reliability of the software and data, NCBR do not and
    cannot warrant the performance or results that may  be obtained by
    using this software or data. NCBR and NIH disclaim all warranties,
    express  or  implied,  including   warranties   of   performance,
    merchantability or fitness for any particular purpose.

    Please cite the author and NIH resources like the "Biowulf Cluster"
    in any work or product based on this material.

USAGE:
  $ jobby [OPTIONS] JOB_ID [JOB_ID ...]

EXAMPLE:
  $ jobby 18627545 15627516 58627597
"""

# Python standard library
from __future__ import print_function, division
import sys, os, subprocess, math, re
from subprocess import PIPE
import argparse  # added in python/3.5
import textwrap  # added in python/3.5
import tempfile  # added in python/3.5

# Jobby metadata
# Version string shown by --version, in the help epilog, and on startup
__version__ = "v0.2.0"
__authors__ = "Skyler Kuhn"
__email__ = "skyler.kuhn@nih.gov"
# Absolute path of the directory that contains this script
__home__ = os.path.dirname(os.path.abspath(__file__))
# Base name the tool was invoked with; used in usage and help messages
_name = os.path.basename(sys.argv[0])
# Short description displayed at the top of the help message
_description = "Will take your job(s)... and display their information!"


# Classes
class Colors:
    """ANSI escape sequences used to style text written to a terminal.

    Each attribute is a complete escape sequence. Text styled with any
    of these sequences should be followed by `Colors.end`, which resets
    the terminal back to its default style.
    """

    # Reset all styling
    end = "\x1b[0m"

    # Text attributes
    bold = "\x1b[1m"
    italic = "\x1b[3m"
    url = "\x1b[4m"
    blink = "\x1b[5m"
    highlighted = "\x1b[7m"

    # Foreground colors (note: cyan uses code 96 rather than 36)
    black = "\x1b[30m"
    red = "\x1b[31m"
    green = "\x1b[32m"
    yellow = "\x1b[33m"
    blue = "\x1b[34m"
    pink = "\x1b[35m"
    cyan = "\x1b[96m"
    white = "\x1b[37m"

    # Background colors
    bg_black = "\x1b[40m"
    bg_red = "\x1b[41m"
    bg_green = "\x1b[42m"
    bg_yellow = "\x1b[43m"
    bg_blue = "\x1b[44m"
    bg_pink = "\x1b[45m"
    bg_cyan = "\x1b[46m"
    bg_white = "\x1b[47m"


# Helper Functions
def which(cmd, path=None):
    """Checks if an executable is in $PATH
    @param cmd <str>:
        Name of executable to check
    @param path <list>:
        Optional list of PATHs to check [default: $PATH]
    @return <boolean>:
        True if exe in PATH, False if not in PATH
    """
    if path is None:
        # $PATH may be unset (e.g. in a stripped-down environment);
        # fall back to the platform default search path instead of
        # failing with a KeyError
        path = os.environ.get("PATH", os.defpath).split(os.pathsep)

    for prefix in path:
        candidate = os.path.join(prefix, cmd)
        # Must be a regular file (not a directory) that the
        # current user is allowed to execute
        if os.path.isfile(candidate) and os.access(candidate, os.X_OK):
            return True

    return False


def err(*message, **kwargs):
    """Writes the provided values to standard error.
    Works like the built-in print function, except the
    output stream is always standard error.
    @param message <any>:
        Values to write to standard error
    @params kwargs <print()>
        Keyword arguments forwarded to print(), such as
        `sep` or `end`
    """
    stream = sys.stderr
    print(*message, file=stream, **kwargs)


def fatal(*message, **kwargs):
    """Reports an error on standard error and terminates
    the program with an exit code of 1.
    @param message <any>:
        Values to write to standard error
    @params kwargs <print()>
        Keyword arguments forwarded to print() via err()
    """
    err(*message, **kwargs)
    # Equivalent to sys.exit(1)
    raise SystemExit(1)


def get_toolkit(tool_list):
    """Selects the first available tool from an ordered list
    of candidates. The list is assumed to be sorted from the
    most to the least preferred choice, so the first candidate
    found in the user's $PATH wins. If none of the candidates
    can be found, an error is reported and the program exits.
    @param tool_list list[<str>]:
        Ordered list of tool names to look for
    @returns best_choice <str>:
        First tool of tool_list that is available in $PATH
    """
    # Lazily probe each candidate, stop at the first hit
    found = next((tool for tool in tool_list if which(tool)), None)

    if not found:
        # None of the candidate tools are available
        err("Error: Did not find any tools to get job information!")
        fatal(
            "Expected one of the following tools to be in $PATH:\t{0}".format(
                tool_list
            )
        )

    return found


def add_missing(linelist, insertion_dict):
    """Adds missing information to a list. This can be used
    to add missing job information fields to the results of
    job querying tool. The provided list is NOT modified, a
    new list is returned.
    @param linelist list[<str>]:
        List containing job information for each field of interest
    @param insertion_dict dict[<int>] = str
        Dictionary used to insert missing information to a given
        index, where the keys are indices of the `linelist` and the
        values are information to add. Please note that the indices
        should be zero based. Note that multiple consecutive values
        should be inserted at once as a list, see example below:
    @return list:
        New list containing the original and the inserted values
    Example:
        add_missing([0,1,2,3,4], {3:['+','++'], 1:'-', 4:'@'})
        >> [0, '-', 1, 2, '+', '++', 3, '@', 4]
    """
    # Work on a copy so the caller's
    # list is left untouched
    result = list(linelist)
    # Insert from the largest index to
    # the smallest, this way the earlier
    # insertions do not shift the position
    # of the indices still to be processed
    for index in sorted(insertion_dict, reverse=True):
        value = insertion_dict[index]
        if isinstance(value, list):
            # Insert all values at once, slice
            # assignment keeps their order even
            # if the index is past the end of
            # the list
            result[index:index] = value
        else:
            result.insert(index, value)
    return result


def convert_size(size_bytes):
    """Converts bytes to a human readable format.
    @param size_bytes <int|float>:
        Non-negative size in bytes
    @return <str>:
        Size scaled to the largest binary unit (B to YiB) for
        which the value is at least 1, rounded to two decimal
        places, i.e. 1536 -> "1.5KiB". Zero is returned as "0B".
        Sizes larger than the biggest unit are reported in YiB.
    @raises ValueError:
        If size_bytes is negative
    """
    size_name = ("B", "KiB", "MiB", "GiB", "TiB", "PiB", "EiB", "ZiB", "YiB")
    if size_bytes == 0:
        return "0B"
    if size_bytes < 0:
        raise ValueError("size_bytes must be non-negative")

    # Scale down by 1024 until the value fits the unit or
    # we run out of units. Dividing by a power of two is
    # exact, this avoids the rounding issues of computing
    # the unit index with a floating point logarithm, and
    # the index can never leave the bounds of size_name.
    value = float(size_bytes)
    index = 0
    while value >= 1024 and index < len(size_name) - 1:
        value /= 1024
        index += 1

    return "{0}{1}".format(round(value, 2), size_name[index])


def to_bytes(size):
    """Convert a human readable size unit into bytes.
    Returns None if cannot convert/parse provided size.
    @param size <str>:
        Size with units, i.e. "4G", "1.5 MiB", "100kb". Units are
        case-insensitive. Single letter and `*ib` units are binary
        (powers of 1024), `*b` units are decimal (powers of 1000).
    @return <int|None>:
        Size in bytes (rounded up to the next whole byte), or None
        if the value has no units, unknown units, or a malformed
        number.
    """
    size2bytes = {
        "b": 1,
        "bytes": 1,
        "byte": 1,
        "k": 1024,
        "kib": 1024,
        "kb": 1000,
        "m": 1024**2,
        "mib": 1024**2,
        "mb": 1000**2,
        "g": 1024**3,
        "gib": 1024**3,
        "gb": 1000**3,
        "t": 1024**4,
        "tib": 1024**4,
        "tb": 1000**4,
        "p": 1024**5,
        "pib": 1024**5,
        "pb": 1000**5,
        "e": 1024**6,
        "eib": 1024**6,
        "eb": 1000**6,
        "z": 1024**7,
        "zib": 1024**7,
        "zb": 1000**7,
        "y": 1024**8,
        "yib": 1024**8,
        "yb": 1000**8,
    }

    size = size.replace(" ", "")
    match = re.search(r"(?P<size>[0-9.]+)(?P<units>[a-zA-Z]+)$", size)
    if not match:
        # Cannot parse units,
        # cannot convert value
        # into bytes
        return None

    scaling_factor = size2bytes.get(match.group("units").lower())
    if scaling_factor is None:
        # Unknown units, i.e. "10xyz"
        return None

    try:
        magnitude = float(match.group("size"))
    except ValueError:
        # Malformed number, i.e. "1.2.3K"
        return None

    return int(math.ceil(scaling_factor * magnitude))


# Core logic for getting
# job information
def sge(jobs, threads, tmp_dir):
    """Placeholder handler for the SGE job scheduler.
    Not implemented yet: all arguments are accepted for
    interface parity with the other handlers but ignored,
    and nothing is displayed.
    @param jobs list[<str>]:
        Job identifiers to query (unused)
    @param threads <int>:
        Number of threads (unused)
    @param tmp_dir <str>:
        Temporary directory (unused)
    @return None
    """
    # NOTE: add later for SGE cluster
    return None


def uge(jobs, threads, tmp_dir):
    """Placeholder handler for the UGE job scheduler.
    Not implemented yet: all arguments are accepted for
    interface parity with the other handlers but ignored,
    and nothing is displayed.
    @param jobs list[<str>]:
        Job identifiers to query (unused)
    @param threads <int>:
        Number of threads (unused)
    @param tmp_dir <str>:
        Temporary directory (unused)
    @return None
    """
    # NOTE: add later for LOCUS cluster
    return None


def dashboard_cli(jobs, threads=1, tmp_dir=None):
    """Biowulf-specific tool to get SLURM job information.
    HPC staff recommend using this over the default slurm
    `sacct` command for performance reasons. By default,
    the `dashboard_cli` returns information for the following
    fields:
      jobid         state         submit_time   partition     nodes
      cpus          mem           timelimit     gres          dependency
      queued_time   state_reason  start_time    elapsed_time  end_time
      cpu_max       mem_max       eval
    Runs command:
        $  dashboard_cli jobs \\
            --joblist 12345679,12345680 \\
            --fields FIELD,FIELD,FIELD \\
            --tab --archive
    A tab-delimited header followed by the output of the command is
    written to standard output. If the command fails, its standard
    error is reported and the program exits with an exit code of 1.
    @param jobs list[<str>]:
        Job identifiers to query
    @param threads <int>:
        Accepted for interface parity with other handlers (unused)
    @param tmp_dir <str>:
        Accepted for interface parity with other handlers (unused)
    @return None
    """
    fields = [
        "jobid",
        "jobname",
        "state",
        "partition",
        "gres",
        "cpus",
        "mem",
        "cpu_max",
        "mem_max",
        "timelimit",
        "queued_time",
        "start_time",
        "end_time",
        "elapsed_time",
        "nodelist",
        "user",
        "std_out",
        "std_err",
        "work_dir",
    ]

    # Display header information,
    # --tab option does not print
    # the header
    print("\t".join(fields))
    # Display job information. The command is passed as
    # an argument list (no shell), so user-provided job
    # ids can never be interpreted as shell syntax.
    cmd = subprocess.run(
        [
            "dashboard_cli",
            "jobs",
            "--archive",
            "--tab",
            "--joblist",
            ",".join(jobs),
            "--fields",
            ",".join(fields),
        ],
        stdout=PIPE,
        stderr=PIPE,
        universal_newlines=True,
    )

    # Check for failure
    # of the last command
    if cmd.returncode != 0:
        err("\nError: Failed to get job information with 'dashboard_cli'!")
        err("Please see error message below:")
        fatal("  └── ", cmd.stderr)

    print(cmd.stdout.rstrip("\n"))


def sacct(jobs, threads=1, tmp_dir=None):
    """Generic tool to get SLURM job information.
    `sacct` should be available on all SLURM clusters.
    The `dashboard_cli` is prioritized over using `sacct`
    due to performance reasons; however, this method will be
    portable across different SLURM clusters. To get maximum
    memory usage for a job, we will need to parse the MaxRSS
    field from the `$SLURM_JOBID.batch` lines.
    Queries job information for the following fields:
      jobid        jobname    state      partition    reqtres
      alloccpus    reqmem     maxrss     timelimit    reserved
      start        end        elapsed    nodelist     user
      workdir
    Runs command (delimiter is a literal tab character):
        $  sacct -j 12345679,12345680 \\
            --format=FIELD,FIELD,FIELD \\
            -P --delimiter '<TAB>'
    A tab-delimited header and one row per job (job steps are not
    displayed) are written to standard output using the same columns
    as `dashboard_cli`. Columns `sacct` cannot provide (std_out and
    std_err) and empty values are displayed as "-". If the command
    fails, its standard error is reported and the program exits with
    an exit code of 1.
    @param jobs list[<str>]:
        Job identifiers to query
    @param threads <int>:
        Accepted for interface parity with other handlers (unused)
    @param tmp_dir <str>:
        Accepted for interface parity with other handlers (unused)
    @return None
    """
    header = [
        "jobid",
        "jobname",
        "state",
        "partition",
        "gres",
        "cpus",
        "mem",
        "cpu_max",
        "mem_max",
        "timelimit",
        "queued_time",
        "start_time",
        "end_time",
        "elapsed_time",
        "nodelist",
        "user",
        "std_out",
        "std_err",
        "work_dir",
    ]
    fields = [
        "jobid",
        "jobname",
        "state",
        "partition",
        "reqtres",
        "alloccpus",
        "reqmem",
        "maxrss",
        "timelimit",
        "reserved",
        "start",
        "end",
        "elapsed",
        "nodelist",
        "user",
        "workdir",
    ]
    maxrss_index = fields.index("maxrss")

    # Display header information,
    print("\t".join(header))
    # Display job information. The command is passed as an
    # argument list (no shell): the tab delimiter does not
    # depend on bash-only $'\t' quoting being supported by
    # /bin/sh, and user-provided job ids can never be
    # interpreted as shell syntax.
    cmd = subprocess.run(
        [
            "sacct",
            "-j",
            ",".join(jobs),
            "-P",
            "--delimiter",
            "\t",
            "--format={0}".format(",".join(fields)),
        ],
        stdout=PIPE,
        stderr=PIPE,
        universal_newlines=True,
    )

    # Check for failure
    # of the last command
    if cmd.returncode != 0:
        err("\nError: Failed to get job information with 'sacct'!")
        err("Please see error message below:")
        fatal("  └── ", cmd.stderr)

    # Parse each line of output into its columns,
    # the first line is sacct's own header and is
    # skipped. Lines are split without stripping
    # tabs, so empty leading/trailing columns keep
    # their position; short lines are padded so the
    # indices used below always exist.
    records = []
    for line in cmd.stdout.rstrip("\n").split("\n")[1:]:
        if not line.strip():
            # skip over blank lines
            continue
        columns = line.rstrip("\r").split("\t")
        columns.extend([""] * (len(fields) - len(columns)))
        records.append(columns)

    # Get max memory information,
    # Stored as $SLURM_JOBID.batch
    # in the MaxRSS field
    j2m = {}
    for columns in records:
        step_id = columns[0].strip()
        if not step_id.endswith(".batch"):
            continue
        jobid = step_id.split(".")[0]
        mem_bytes = to_bytes(columns[maxrss_index].replace(" ", ""))
        # Display "-" if the max_mem value
        # could not be converted into bytes
        j2m[jobid] = convert_size(mem_bytes) if mem_bytes else "-"

    # Display the results
    for columns in records:
        jobid = columns[0].strip()
        if "." in jobid:
            # Job step, i.e. 123.batch or
            # 123.extern, only report jobs
            continue
        # Only keep the first word of the
        # state, i.e. "CANCELLED by 1234"
        columns[2] = columns[2].split(" ")[0]
        # Indices refer to positions in `fields`:
        # mem_max goes right after maxrss, and the
        # missing std_out and std_err right before
        # workdir. The job may still be running or
        # in a non-completed state, in which case
        # there is no max memory information.
        missing_fields = {8: j2m.get(jobid, "-"), 15: ["-", "-"]}
        row = add_missing(columns, missing_fields)
        print("\t".join(info if info else "-" for info in row))


def slurm(jobs, threads, tmp_dir):
    """Displays SLURM job information to standard output.
    The best available tool on the target system is used
    to query the scheduler. If no supported tool can be
    found in $PATH, an error is reported and the program
    exits.
    @param jobs list[<str>]:
        Job identifiers to query
    @param threads <int>:
        Number of threads, forwarded to the selected tool
    @param tmp_dir <str>:
        Temporary directory, forwarded to the selected tool
    @return None
    """
    # Try to use the following tools in this
    # order to get job information!
    # [1] `dashboard_cli` is Biowulf-specific
    # [2] `sacct` should always be there
    tool_handlers = [
        ("dashboard_cli", dashboard_cli),
        ("sacct", sacct),
    ]
    tool_priority = [tool for tool, _ in tool_handlers]
    job_tool = get_toolkit(tool_priority)
    # Get information about each job, an
    # explicit lookup table maps the name
    # of the tool to its handler function
    # (no need to eval() a string)
    handler = dict(tool_handlers)[job_tool]
    handler(jobs=jobs, threads=threads, tmp_dir=tmp_dir)


def jobby(args):
    """
    Dispatches to the handler of the requested job scheduler.
    Each supported scheduler (currently only slurm) has a custom
    handler to most effectively get and parse job information.
    An unsupported scheduler is reported as an error and the
    program exits.
    @param args <parser.parse_args() object>:
        Parsed command-line arguments, providing the JOB_ID,
        scheduler, threads, and tmp_dir attributes
    @return None
    """
    # Unpack the command line options
    job_ids = args.JOB_ID
    scheduler = args.scheduler
    threads = args.threads
    tmp_dir = args.tmp_dir

    # Resolve the handler of the
    # requested scheduler
    abstract_handler = slurm if scheduler == "slurm" else None
    if abstract_handler is None:
        # Unsupported job scheduler,
        # needs to be implemented
        fatal('Error: "{0}" is an unsupported job scheduler!'.format(scheduler))

    # Write the information of each
    # job to standard output
    abstract_handler(jobs=job_ids, threads=threads, tmp_dir=tmp_dir)


# Parse command-line arguments
def parsed_arguments(name, description):
    """Parses user-provided command-line arguments. This requires
    argparse and textwrap packages. To create custom help formatting
    a text wrapped docstring is used to create the help message for
    required options. As so, the help message for require options
    must be suppressed. If a new required argument is added to the
    cli, it must be updated in the usage statement docstring below.
    @param name <str>:
        Name of the pipeline or command-line tool
    @param description <str>:
        Short description of pipeline or command-line tool
    @return args <parser.parse_args() object>:
        Parsed command-line arguments with the following attributes:
        JOB_ID (list of str), scheduler (str), threads (int),
        tmp_dir (str), and func (handler to call with the args)
    """
    # Add styled name and description
    c = Colors
    styled_name = "{0}{1}{2}{3}{4}".format(c.bold, c.bg_black, c.cyan, name, c.end)
    description = "{0}{1}{2}".format(c.bold, description, c.end)
    temp = tempfile.gettempdir()

    # Please note: update the usage statement
    # below if a new option is added!
    usage_statement = textwrap.dedent(
        """\
        {0}: {1}

        {3}{4}Synopsis:{5}
          $ {2} [--version] [--help] \\
                [--scheduler {{slurm | ...}}] \\
                [--threads THREADS] [--tmp-dir TMP_DIR] \\
                <JOB_ID [JOB_ID ...]>

        {3}{4}Description:{5}
          {2} will take your past jobs and display their job information
        in a standardized format. Why???! We have pipelines running on several
        different clusters (using different job schedulers). {2} centralizes
        and abstracts the process of querying different job schedulers within
        a unified command-line interface.

          For each supported scheduler, jobby will determine the best method
        on a given target system for getting job information to return to the
        user in a common output format.

        {3}{4}Required Positional Arguments:{5}
          <JOB_ID [JOB_ID ...]>
                               Identifiers of past jobs. One or more JOB_IDs
                               can be provided. Multiple JOB_IDs should be
                               separated by a space. Information for each
                               of the JOB_IDs will be displayed to standard
                               output. Please see example section below for
                               more information.

        {3}{4}Options:{5}
          -s,--scheduler {{slurm | ...}}
                                @Default: slurm
                                Job scheduler. Defines the job scheduler
                                of the target system. Additional support
                                for more schedulers coming soon!
                                  @Example: --scheduler slurm
          -n, --threads THREADS
                                @Default: 1
                                Number of threads to query the scheduler
                                in parallel.
                                  @Example: --threads 8
          -t, --tmp-dir TMP_DIR
                                @Default: {7}/
                                Temporary directory. Path on the filesystem
                                for writing temporary output files. Ideally,
                                this path should point to a dedicated space
                                on the filesystem for writing tmp files. If
                                you need to inject a variable into this path
                                that should NOT be expanded, please quote the
                                options value in single quotes. The default
                                location of this option is set to the system
                                default via the $TMPDIR environment variable.
                                  @Example: --tmp-dir '/scratch/$USER/'

          -h, --help            Shows help and usage information and exits.
                                  @Example: --help

          -v, --version         Displays version information and exits.
                                  @Example: --version
        """.format(
            styled_name, description, name, c.bold, c.url, c.end, c.italic, temp
        )
    )

    # Display example usage in epilog
    run_epilog = textwrap.dedent(
        """\
        {2}{3}Example:{4}
        # Please avoid running jobby
        # on a cluster's head node!
        ./jobby -s slurm  -n 4 18627542 13627516 58627597 48627666

        {2}{3}Version:{4}
          {1}
        """.format(
            name, __version__, c.bold, c.url, c.end
        )
    )

    # Create a top-level parser, the built-in
    # usage and help messages are suppressed in
    # favor of the custom usage statement above
    parser = argparse.ArgumentParser(
        usage=argparse.SUPPRESS,
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description=usage_statement,
        epilog=run_epilog,
        add_help=False,
    )

    # Required Positional Arguments
    # List of JOB_IDs, 1 ... N_JOB_IDS
    parser.add_argument("JOB_ID", nargs="+", help=argparse.SUPPRESS)

    # Options
    # Adding version information
    parser.add_argument(
        "-v",
        "--version",
        action="version",
        version="%(prog)s {}".format(__version__),
        help=argparse.SUPPRESS,
    )

    # Add custom help message
    parser.add_argument("-h", "--help", action="help", help=argparse.SUPPRESS)

    # Base directory to write
    # temporary/intermediate files
    parser.add_argument(
        "-t",
        "--tmp-dir",
        type=str,
        required=False,
        default=temp,
        help=argparse.SUPPRESS,
    )

    # Number of threads to query
    # the scheduler in parallel
    parser.add_argument(
        "-n", "--threads", type=int, required=False, default=1, help=argparse.SUPPRESS
    )

    # Job scheduler to query,
    # available: SLURM, ...
    # More coming soon! The value
    # is lower-cased before it is
    # checked against the choices.
    parser.add_argument(
        "-s",
        "--scheduler",
        type=lambda s: str(s).lower(),
        required=False,
        default="slurm",
        choices=["slurm"],
        help=argparse.SUPPRESS,
    )

    # Define handlers for each sub-parser
    parser.set_defaults(func=jobby)
    # Parse command-line args
    args = parser.parse_args()

    return args


def main():
    """Entry point of the command-line tool.
    Checks that at least one argument was provided, parses
    the command-line arguments, writes the name and version
    of the tool to standard error, and then calls the handler
    attached to the parsed arguments.
    """
    # Sanity check for usage,
    # nothing was provided
    argc = len(sys.argv)
    if argc == 1:
        fatal("Invalid usage: {} [-h] [--version] ...".format(_name))

    # Collect the user-provided options
    parsed = parsed_arguments(name=_name, description=_description)

    # Version banner goes to standard error, this
    # keeps standard output clean for job information
    err("{} ({})".format(_name, __version__))

    # Mediator method to call the
    # default handler function
    handler = parsed.func
    handler(parsed)


# Only run the command-line interface when this file is
# executed as a script, not when it is imported as a module
if __name__ == "__main__":
    main()
