#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Run the test suite, report failures
#
# To set command for running parallel jobs, define environment variable
# e.g. for bash
# export MPIRUN="mpirun -np"
# export MPIRUN="aprun -n"

import os
import sys
import time
import glob
from boututils.run_wrapper import shell
import threading
import re
import fcntl
import select

##################################################################

import argparse

# Command-line options, listed in the order they are registered with the
# parser.  Each entry pairs the option strings with the keyword arguments
# handed to ``add_argument``.
_cli_options = (
    (
        ("-g", "--get-list"),
        dict(
            action="store_true",
            help="Return a list of tests that would be run/build",
        ),
    ),
    (
        ("-m", "--make"),
        dict(
            action="store_true",
            help="Build the tests, rather then run them.",
        ),
    ),
    (
        ("-a", "--all"),
        dict(action="store_true", dest="all_tests", help="Run all tests"),
    ),
    (
        ("-b", "--set-bool"),
        dict(
            nargs="+",
            metavar=("value1=False", "value2=True"),
            help="Set a bool value for evaluating what scripts can be run.",
        ),
    ),
    (
        ("-l", "--set-list"),
        dict(
            nargs="+",
            metavar=("test1", "test2"),
            help="Set the tests that should be run.",
        ),
    ),
    (
        # nargs=1 makes ``args.jobs`` a one-element list (or None if absent)
        ("-j", "--jobs"),
        dict(
            nargs=1,
            type=int,
            dest="jobs",
            help="Set the number of cores to use in parallel.",
        ),
    ),
)

parser = argparse.ArgumentParser(description="Run or build some tests.")
for _option_strings, _option_kwargs in _cli_options:
    parser.add_argument(*_option_strings, **_option_kwargs)

# Parsed command line; read by the rest of the script
args = parser.parse_args()

##################################################################

# The ``requirements`` module is imported from the parent directory
sys.path.append("..")
from requirements import Requirements

# Registry of boolean facts used to decide which tests can be run
requirements = Requirements()

if args.set_bool is not None:
    # Accepted (case-insensitive) spellings of a boolean on the command line
    lookup = {"false": False, "no": False, "true": True, "yes": True}
    for setbool in args.set_bool:
        # Split on the first "=" only, and report malformed input as a
        # normal command-line error instead of an unpacking traceback.
        k, sep, v = setbool.partition("=")
        if not sep or not k:
            parser.error(
                "--set-bool expects NAME=VALUE pairs, got '{}'".format(setbool)
            )
        try:
            v = lookup[v.lower()]
        except KeyError:
            parser.error(
                "--set-bool value for '{}' must be one of {}, got '{}'".format(
                    k, ", ".join(sorted(lookup)), v
                )
            )
        requirements.add(k, v, override=True)

# "make" and "all_tests" are controlled by dedicated flags.  A RuntimeError
# from add() is taken to mean that the key was already set through
# --set-bool (NOTE(review): confirm against Requirements.add).
try:
    requirements.add("make", args.make)
except RuntimeError as err:
    raise RuntimeError(
        "The make flag needs to be set by passing --make rather then --set-bool make=True"
    ) from err

try:
    requirements.add("all_tests", args.all_tests)
except RuntimeError as err:
    raise RuntimeError(
        "The all-tests flag needs to be set by passing --all rather then --set-bool all_tests=True"
    ) from err

##################################################################
# Parallel stuff
#
# We can run tests in parallel. Some tests need several cores
# themselves, and we take this into account when scheduling. Given the
# number of threads we are told to use, we start as many tests in
# parallel as possible without the total core count exceeding that
# number. If a test needs more cores than we have threads, it is
# started on its own once nothing else is running. The number of cores
# a test needs is called the cost of that test.
#
# The beginning of MAKEFLAGS has some flags that tell us what we are
# supposed to be doing. Most of them we can ignore.
# Check what we are supposed to be doing:
if "MAKEFLAGS" in os.environ:
    makeflags = os.environ["MAKEFLAGS"]
    # The leading word of MAKEFLAGS is a run of single-letter flags, so it
    # is inspected one character at a time.
    for token in makeflags:
        if token == "n":
            # Print recipes:
            # We ignore this for now ...
            sys.exit(0)
        elif token == "q":
            # Question mode:
            # We are never up-to-date
            sys.exit(1)
        elif token in "ksBd":
            # k - keep going
            # s - silent
            # B - always build (we do)
            # d - debug
            pass
        elif token == " ":
            # End of the single-letter flags
            break
        elif token == "-":
            # This is not from make
            break
        else:
            # Not implemented
            print("mode '%s' not implemented" % token)
            # sys.exit(42)
else:
    makeflags = ""

# Number of threads we may use; -j/--jobs overrides the default of one
num_threads = 1
if args.jobs:
    num_threads = args.jobs[0]

# Jobserver state: read/write descriptors of the token pipe and the tokens
# we currently hold.  js_read stays None when no usable jobserver exists.
js_read = None
js_write = None
js_tokens = b""
# if we are running under make, we only have one thread (called job
# in make terms) for sure, but might be able to run more threads
# (depending what other jobs are currently running)
# Additional jobs come in form of `tokens` that we need to give
# back, if we are finished.
# https://www.gnu.org/software/make/manual/html_node/POSIX-Jobserver.html
js_option = re.search(r"--jobserver-auth=(\S*)", makeflags) or re.search(
    r"--jobserver-fds=(\S*)", makeflags
)
if js_option is not None:
    # Only the "R,W" form with two non-negative descriptors is supported.
    # Anything else (e.g. a "fifo:PATH" value, or negative descriptors) must
    # not be scraped for digits: that would yield unrelated descriptors.
    js_pipe = re.fullmatch(r"(\d+),(\d+)", js_option.group(1))
    if js_pipe is not None:
        js_read, js_write = (int(fd) for fd in js_pipe.groups())
    else:
        print(
            "Warning - cannot use jobserver '{}' from MAKEFLAGS, "
            "not requesting additional jobs".format(js_option.group(1))
        )


class Test(threading.Thread):
    """A single test directory, run (or built) in its own thread.

    Attributes set in ``__init__``:
        name: directory of the test; also used as the thread name
        cmd: command executed inside the directory ("./runtest" or "make")
        width: column width used to align the test name in the report
        timeout: duration string handed to the ``timeout`` utility
        cost: number of cores the test needs (see ``_cost``)

    Attributes set by ``run``:
        status: exit status of the command, 0 if the test was skipped
        output: captured output of the command (not set for skipped tests)
    """

    def __init__(self, name, cmd, width, timeout):
        super().__init__()
        self.name = name
        self.cmd = cmd
        self.width = width
        self.local = threading.local()
        self.cost = self._cost()
        self.timeout = timeout

    def run(self):
        """Run the test, print a one-line result and store the status."""
        # Check requirements
        self.local.req_met, self.local.req_expr = requirements.check(
            self.name + "/runtest"
        )
        self.local.start_time = time.time()
        if not self.local.req_met:
            # Skipped tests are reported with "S" and count as passed
            print(
                "{:{width}}".format(self.name, width=self.width)
                + "S - {0} => False".format(self.local.req_expr)
            )
            sys.stdout.flush()
            self.status = 0
        else:
            # Run test, piping stdout so it is not sent to console.
            # "&&" makes sure the command is not run in the current
            # directory if changing into the test directory fails.
            command = "cd {} && timeout {} {}".format(
                self.name, self.timeout, self.cmd
            )
            self.status, self.output = shell(command, pipe=True)
            print("{:{width}}".format(self.name, width=self.width), end="")
            if self.status == 0:
                # ✓ Passed
                print("\u2713", end="")  # No newline
            elif self.status == 124:
                # 💤 timeout
                print("\U0001f4a4", end="")  # No newline
                self.output += "\n(It is likely that a timeout occured)"
            else:
                # ❌ Failed
                print("\u274C", end="")  # No newline
            print(" %7.3f s" % (time.time() - self.local.start_time), flush=True)

    def _cost(self):
        """Return the number of cores this test needs.

        Building ("make") always costs one core.  Otherwise the command
        file is searched for a single ``#cores: N`` / ``# Cores: N`` line;
        without such a line the cost is one.

        Raises:
            RuntimeError: if more than one such line is found, or if the
                given core count is not positive.
        """
        if self.cmd == "make":
            return 1
        with open(self.name + "/" + self.cmd, "r", encoding="utf8") as filein:
            contents = filein.read()
        # Find all lines starting with '#cores:' or '#Cores:'
        match = re.findall(r"^\s*\#\s?[Cc]ores:\s*?(\d+)", contents, re.MULTILINE)
        if len(match) > 1:
            raise RuntimeError(
                "Found more then one match for core-count in " + self.name
            )
        if len(match) == 0:
            # default; no mpi
            return 1
        c = int(match[0])
        # The pattern only matches digits, so c is never negative; reject
        # zero, which is the only non-positive value that can occur.
        if c < 1:
            raise RuntimeError(
                "Core-count is %d and thus not positive in " % c + self.name
            )
        return c


##################################################################

if args.set_list is not None:
    # Take the user provided list, without trailing slashes
    tests = [name.rstrip("/") for name in args.set_list]

else:
    # Collect every directory that contains a "runtest" file
    try:
        # Requires python >= 3.5
        found = glob.iglob("**/runtest", recursive=True)
    except TypeError:
        # Fall back - check only the first four directory levels ...
        found = []
        for depth in range(1, 5):
            found += glob.glob("/".join(["*"] * depth + ["runtest"]))

    # Strip the trailing "/runtest" to get the test directory
    tests = [path.rsplit("/", 1)[0] for path in found]


if args.get_list:
    # Only print the tests whose requirements are met, then stop
    for test in tests:
        if requirements.check(test + "/runtest")[0]:
            print(test)
    sys.exit(0)


# A function to get more threads from the job server
def get_threads():
    """Try to obtain additional job tokens from the make jobserver.

    Does nothing if no jobserver is in use.  Otherwise reads, without
    blocking, at most as many tokens as are still needed to run all
    remaining tests (``cost_remain - avail_threads``) and updates the
    globals ``js_tokens``, ``num_threads`` and ``avail_threads``.

    If the jobserver descriptor turns out to be unusable, a warning is
    printed and ``js_read`` is reset to None so no further attempts are made.
    """
    global js_read, cost_remain, js_tokens, avail_threads, num_threads
    if not js_read:
        return
    try:
        fl = fcntl.fcntl(js_read, fcntl.F_GETFL)
    except OSError:
        print("Warning - we are not part of the job pool 😭, running in serial")
        js_read = None
        return
    # Never ask for more than we can use; a non-positive count would also
    # make os.read() fail (negative) or be pointless (zero).
    wanted = cost_remain - avail_threads
    if wanted <= 0:
        return
    isreadable, _, _ = select.select([js_read], [], [], 0)
    if not isreadable:
        return
    # Only set shortly to non-blocking, and only if we
    # expect we can read. Old make uses blocking pipes, and
    # assumes if it is non-blocking, a job is available.
    fcntl.fcntl(js_read, fcntl.F_SETFL, fl | os.O_NONBLOCK)
    try:
        new_token = os.read(js_read, wanted)
    except BlockingIOError:
        # Somebody else took the token between select() and read()
        new_token = b""
    finally:
        # Always restore the original flags, even if reading failed
        fcntl.fcntl(js_read, fcntl.F_SETFL, fl)
    num_threads += len(new_token)
    js_tokens += new_token
    avail_threads += len(new_token)


##################################################################
# Run the actual test
# Kill tests that take longer than 10 minutes (default)
# Kill tests that take longer than 10 minutes (default), or 30 minutes when
# all tests are run; BOUT_TEST_TIMEOUT overrides both.
timeout = os.environ.get("BOUT_TEST_TIMEOUT", "10m" if not args.all_tests else "30m")
command = "./runtest" if not args.make else "make"

savepath = os.getcwd()  # Save current working directory
failed = []  # [name, output] of every test with a non-zero status

start_time = time.time()

test_type = "Making" if args.make else "Running"

print(
    "======= {} {} {} tests ========".format(
        test_type, len(tests), savepath.split("/")[-1]
    )
)

# Width of the name column; default=0 keeps an empty test list from raising
longest = max((len(s) for s in tests), default=0)
avail_threads = num_threads
# Most expensive tests first
tester = sorted(
    [Test(t, command, longest + 1, timeout) for t in tests],
    key=lambda x: x.cost,
    reverse=True,
)
# Total cost of all tests that have not been started yet
cost_remain = sum(x.cost for x in tester)
get_threads()

torun = list(range(len(tests)))  # indices into tester, not yet started
running = []  # started jobs that have not been collected yet

while torun or running:
    if avail_threads:
        # See whether we can start one
        for i in torun[:]:
            if tester[i].cost <= avail_threads:
                torun.remove(i)
                job = tester[i]
                job.start()
                running.append(job)
                avail_threads -= job.cost
                cost_remain -= job.cost
    # If nothing is running, start the first regardless of its cost.
    # This is checked even when no thread is available (e.g. "-j 0"),
    # otherwise the loop would spin forever without starting anything.
    if not running and torun:
        tostart = torun.pop(0)
        job = tester[tostart]
        job.start()
        running.append(job)
        avail_threads -= job.cost
        cost_remain -= job.cost

    # Are any jobs finished
    # (iterate over a copy, as finished jobs are removed from the list)
    for job in running[:]:
        job.join(0.1)
        if not job.is_alive():
            if job.status:
                failed.append([job.name, job.output])
            avail_threads += job.cost
            running.remove(job)

    # Try to get more jobs
    if torun:
        get_threads()
    elif avail_threads > 0 and js_read is not None:
        # Free not-needed jobs; only a positive count is really spare
        old_tokens = js_tokens[0:avail_threads]
        js_tokens = js_tokens[avail_threads:]
        os.write(js_write, old_tokens)
        num_threads -= avail_threads
        avail_threads = 0

elapsed_time = time.time() - start_time

print("\n")

# Return remaining tokens
if js_read is not None and js_tokens:
    os.write(js_write, js_tokens)

if failed:
    print("======= FAILURES ========")
    for test, output in failed:
        # Note: need Unicode string in case output contains unicode
        print("\n----- {0} -----\n{1}".format(test, output))

    print(
        "======= {0} failed in {1:.2f} seconds ========".format(
            len(failed), elapsed_time
        )
    )

    sys.exit(1)

else:
    print("======= All tests passed in {0:.2f} seconds =======".format(elapsed_time))
