#!/usr/bin/env python3

"""
Script containing python test suite for SCREAM's CIME
namelist-related infrastructure.
"""

from utils import check_minimum_python_version, expect, ensure_pylint, get_timestamp, \
                  run_cmd_assert_result, run_cmd_no_fail, run_cmd

check_minimum_python_version(3, 6)

from machines_specs import is_machine_supported

import unittest, argparse, sys, os, shutil
from pathlib import Path

# Directory one level above the directory that contains this script.
EAMXX_DIR         = Path(__file__).resolve().parent.parent
# Location of the cime_config python files exercised by the doctest/pylint tests.
EAMXX_CIME_DIR    = EAMXX_DIR / "cime_config"
# Location of the scripts directory (holds atm_manip.py, checked by doctest/pylint tests).
EAMXX_SCRIPTS_DIR = EAMXX_DIR / "scripts"
# cime/scripts directory, two levels above EAMXX_DIR; create_test is invoked from here.
CIME_SCRIPTS_DIR  = EAMXX_DIR.parent.parent / "cime" / "scripts"

# Suite-wide settings. "machine" stays None unless scripts_tests() is given a machine name.
CONFIG = {
    "machine"  : None,
}

###############################################################################
class TestBuildnml(unittest.TestCase):
###############################################################################
    """
    Tests of the atmchange/atmquery/xmlchange machinery, run on cases
    created via create_test.

    setUp creates a fresh, unbuilt SMS case for every test (available as
    self._common_case); tearDown removes every case dir created by the test.
    """

    ###########################################################################
    def _create_test(self, extra_args, env_changes=""):
    ###########################################################################
        """
        Convenience wrapper around create_test.

        extra_args is a whitespace-separated string of create_test arguments.
        env_changes is text placed in front of the create_test command.
        A unique test id (-t) is always added; --wait is added unless one of
        the flags that stops create_test short of a full run was requested.

        Returns the full path to the created case if exactly one case was
        created, otherwise the list of full paths to created cases. If multiple
        cases, the order of the returned list is not guaranteed to match the
        order of the arguments. All case dirs are registered for removal in
        tearDown.
        """
        create_args = extra_args.split()

        test_id = f"cmd_nml_tests-{get_timestamp()}"
        create_args.extend(["-t", test_id])

        # Only wait for completion if none of these "partial run" flags was given
        partial_run_flags = {"-n", "--namelist-only", "--no-setup", "--no-build", "--no-run"}
        if partial_run_flags.isdisjoint(create_args):
            create_args.append("--wait")

        output = run_cmd_assert_result(
            self, f"{env_changes} {CIME_SCRIPTS_DIR}/create_test {' '.join(create_args)}")

        # create_test reports each case as a line containing "Case dir:", with the path as last token
        cases = []
        for line in output.splitlines():
            if "Case dir:" in line:
                casedir = line.split()[-1]
                self.assertTrue(os.path.isdir(casedir), msg=f"Missing casedir {casedir}")
                cases.append(casedir)

        self.assertGreater(len(cases), 0, "create_test made no cases")

        self._dirs_to_cleanup.extend(cases)

        return cases[0] if len(cases) == 1 else cases

    ###########################################################################
    def _check_value(self, actual, expected, expect_equal, msg):
    ###########################################################################
        """
        Assert that actual equals (expect_equal=True) or differs from
        (expect_equal=False) expected. Does nothing if expected is None.
        """
        if expected is None:
            return

        if expect_equal:
            self.assertEqual(actual, expected, msg=msg)
        else:
            self.assertNotEqual(actual, expected, msg=msg)

    ###########################################################################
    def _get_values(self, case, name, value=None, expect_equal=True):
    ###########################################################################
        """
        Queries a name with atmquery (run from the case dir), optionally checking
        if the value matches or does not match the argument.

        If name contains 'ANY', atmquery is run without --value and every line
        of its output is parsed as "<name>: <value>"; the check is then applied
        to each value found.

        The check is skipped only when value is None, so an empty string is
        a legitimate expected value.

        Returns the list of names that were queried: [name] for a plain query,
        the parsed names for an ANY query.
        """
        if 'ANY' not in name:
            orig = run_cmd_assert_result(self, f"./atmquery {name} --value", from_dir=case)
            self._check_value(orig, value, expect_equal, name)
            return [name]

        output = run_cmd_assert_result(self, f"./atmquery {name}", from_dir=case)

        names = []
        for line in output.splitlines():
            # Split on the last ": " since the name part may itself contain "::"
            fields = line.rsplit(": ", maxsplit=1)
            self.assertEqual(len(fields), 2,
                             msg=f"Could not parse atmquery output line '{line}' while querying {name}")
            matched_name, orig_value = fields[0].strip(), fields[1].strip()

            self._check_value(orig_value, value, expect_equal, f"{name} (matched {matched_name})")
            names.append(matched_name)

        return names

    ###########################################################################
    def _chg_atmconfig(self, changes, case, reset=False, expect_lost=None):
    ###########################################################################
        """
        Apply a sequence of atmchanges to a case and verify the outcome.

        changes is an iterable of (name, value) pairs. For each pair, the
        current value is first checked to differ from the new one (so that the
        change is meaningful), then "./atmchange name='value'" is run.

        If reset is True, "./atmchange --reset" is run after all changes.

        Finally, every affected name (ANY patterns are expanded to the names
        they matched) is queried again: the new value must be present, unless
        expect_lost is True, in which case it must NOT be present.
        expect_lost defaults to the value of reset.
        """
        expect_lost = reset if expect_lost is None else expect_lost

        # Maps each concrete name to the last value requested for it, so that a
        # later change overrides an earlier one that matched the same name.
        changes_unpacked = {}
        for name, value in changes:
            names = self._get_values(case, name, value=value, expect_equal=False)

            run_cmd_assert_result(self, f"./atmchange {name}='{value}'", from_dir=case)

            for item in names:
                changes_unpacked[item] = value

        if reset:
            run_cmd_assert_result(self, "./atmchange --reset", from_dir=case)

        # NOTE: case.setup is not run here; tests that need it call it themselves.

        for name, value in changes_unpacked.items():
            self._get_values(case, name, value=value, expect_equal=not expect_lost)

    ###########################################################################
    def setUp(self):
    ###########################################################################
        """
        Create the case shared by most tests (one new case per test method)
        """
        self._dirs_to_cleanup = []
        self._common_case = self._create_test("SMS.ne30_ne30.F2010-SCREAMv1 --no-build")

    ###########################################################################
    def tearDown(self):
    ###########################################################################
        """
        Remove all case dirs created during the test
        """
        for item in self._dirs_to_cleanup:
            shutil.rmtree(item)

    ###########################################################################
    def test_doctests(self):
    ###########################################################################
        """
        Run doctests for all eamxx/cime_config python files and nml-related files in scripts
        """
        run_cmd_assert_result(self, "python3 -m doctest *.py", from_dir=EAMXX_CIME_DIR)
        run_cmd_assert_result(self, "python3 -m doctest atm_manip.py", from_dir=EAMXX_SCRIPTS_DIR)

    ###########################################################################
    def test_pylint(self):
    ###########################################################################
        """
        Run pylint on all eamxx/cime_config python files and nml-related files in scripts
        """
        ensure_pylint()
        run_cmd_assert_result(self, "python3 -m pylint --disable C,R,E0401,W1514 *.py", from_dir=EAMXX_CIME_DIR)
        run_cmd_assert_result(self, "python3 -m pylint --disable C,R,E0401,W1514 atm_manip.py", from_dir=EAMXX_SCRIPTS_DIR)

    ###########################################################################
    def test_xmlchange_propagates_to_atmconfig(self):
    ###########################################################################
        """
        Test that xmlchanges impact atm config files
        """
        case = self._create_test("ERS_Ln22.ne30_ne30.F2010-SCREAMv1 --no-build")

        # atm config should match case test opts
        case_rest_n = run_cmd_assert_result(self, "./xmlquery REST_N --value", from_dir=case)
        atm_freq    = run_cmd_assert_result(self, "./atmquery Frequency --value", from_dir=case)
        self.assertEqual(case_rest_n, atm_freq)

        # Change XML and check that atmquery reflects this change
        new_rest_n = "6"
        self.assertNotEqual(new_rest_n, case_rest_n)
        run_cmd_assert_result(self, f"./xmlchange REST_N={new_rest_n}", from_dir=case)
        run_cmd_assert_result(self, "./case.setup", from_dir=case)
        new_case_rest_n = run_cmd_assert_result(self, "./xmlquery REST_N --value", from_dir=case)
        new_atm_freq    = run_cmd_assert_result(self, "./atmquery Frequency --value", from_dir=case)
        self.assertEqual(new_case_rest_n, new_rest_n)
        self.assertEqual(new_atm_freq, new_rest_n)

    ###########################################################################
    def test_atmchanges_are_preserved(self):
    ###########################################################################
        """
        Test that atmchanges are not lost when eamxx setup is called
        """
        case = self._common_case

        self._chg_atmconfig([("atm_log_level", "trace"), ("output_to_screen", "true")], case)

        run_cmd_no_fail("./case.setup", from_dir=case)

        out1 = run_cmd_no_fail("./atmquery --value atm_log_level", from_dir=case)
        out2 = run_cmd_no_fail("./atmquery --value output_to_screen", from_dir=case)

        lost_msg = "An atm change appears to have been lost during case.setup"
        self.assertEqual(out1, "trace", msg=lost_msg)
        self.assertEqual(out2, "true",  msg=lost_msg)

    ###########################################################################
    def test_nml_defaults_append(self):
    ###########################################################################
        """
        Test that the append attribute for array-type params in namelist defaults works as expected
        """
        case = self._common_case

        # Add testOnly proc
        self._chg_atmconfig([("mac_aero_mic::atm_procs_list", "shoc,cldFraction,spa,p3,testOnly")], case)

        # Test case 1: append to base, then to last. Should give all 3 entries stacked
        run_cmd_no_fail("./xmlchange --append SCREAM_CMAKE_OPTIONS='SCREAM_NUM_VERTICAL_LEV 1'", from_dir=case)
        run_cmd_no_fail("./case.setup", from_dir=case)
        self._get_values(case, "my_param", "1,2,3,4,5,6")

        # Test case 2: append to last, then to base. Should give 1st and 3rd entry
        run_cmd_no_fail("./xmlchange --append SCREAM_CMAKE_OPTIONS='SCREAM_NUM_VERTICAL_LEV 2'", from_dir=case)
        run_cmd_no_fail("./case.setup", from_dir=case)
        self._get_values(case, "my_param", "1,2,5,6")

    ###########################################################################
    def test_append(self):
    ###########################################################################
        """
        Test that var+=value syntax behaves as expected
        """
        case = self._common_case

        # Append to an existing entry
        name = 'output_yaml_files'
        run_cmd_no_fail(f"./atmchange {name}+=a.yaml", from_dir=case)
        run_cmd_no_fail(f"./atmchange {name}+=b.yaml", from_dir=case)

        # Get the yaml files
        expected = 'a.yaml, b.yaml'
        self._get_values(case, name, value=expected, expect_equal=True)

    ###########################################################################
    def test_reset_atmchanges_are_lost(self):
    ###########################################################################
        """
        Test that atmchanges are lost when resetting
        """
        case = self._common_case

        self._chg_atmconfig([("atm_log_level", "trace")], case, reset=True)

    ###########################################################################
    def test_atmchanges_are_lost_with_hack_xml(self):
    ###########################################################################
        """
        Test that atmchanges are lost if SCREAM_HACK_XML=TRUE
        """
        case = self._common_case

        run_cmd_assert_result(self, "./xmlchange SCREAM_HACK_XML=TRUE", from_dir=case)

        self._chg_atmconfig([("atm_log_level", "trace")], case, expect_lost=True)

    ###########################################################################
    def test_atmchanges_are_preserved_testmod(self):
    ###########################################################################
        """
        Test that atmchanges via testmod are preserved
        """
        # Take the last dot-separated field of the second-to-last line of the
        # listing (presumably the default machine_compiler suffix - confirm
        # against list_e3sm_tests output format).
        listing = run_cmd_assert_result(self, "../CIME/Tools/list_e3sm_tests cime_tiny", from_dir=CIME_SCRIPTS_DIR)
        def_mach_comp = listing.splitlines()[-2].split(".")[-1]

        case = self._create_test(f"SMS.ne30_ne30.F2010-SCREAMv1.{def_mach_comp}.scream-scream_example_testmod_atmchange --no-build")

        # Check that the value match what's in the testmod
        out = run_cmd_no_fail("./atmquery --value cubed_sphere_map", from_dir=case)
        self.assertEqual(out, "42", msg="An atm change appears to have been lost during case.setup")

    ###########################################################################
    def test_atmchanges_with_namespace(self):
    ###########################################################################
        """
        Test that atmchange works when using 'namespace' syntax foo::bar
        """
        case = self._common_case

        self._chg_atmconfig([("p3::enable_precondition_checks", "false")], case)

    ###########################################################################
    def test_atmchanges_on_arrays(self):
    ###########################################################################
        """
        Test that atmchange works for array data
        """
        case = self._common_case

        self._chg_atmconfig([("surf_mom_flux", "40.0,2.0")], case)

    ###########################################################################
    def test_atmchanges_on_all_matches(self):
    ###########################################################################
        """
        Test that atmchange works with ANY
        """
        case = self._common_case

        self._chg_atmconfig([("ANY::enable_precondition_checks", "false")], case)

    ###########################################################################
    def test_atmchanges_on_all_matches_plus_spec(self):
    ###########################################################################
        """
        Test atmchange with ANY followed by an atmchange of one of them
        """
        case = self._common_case

        self._chg_atmconfig([("ANY::enable_precondition_checks", "false"),
                             ("p3::enable_precondition_checks", "true")], case)

    ###########################################################################
    def test_atmchanges_for_atm_procs_add(self):
    ###########################################################################
        """
        Test atmchanges that add atm procs
        """
        case = self._common_case

        self._chg_atmconfig([("mac_aero_mic::atm_procs_list", "shoc,cldFraction,spa,p3,testOnly")], case)

        # If we are able to change subcycles of testOnly then we know the atmchange
        # above added the necessary atm proc XML block.
        self._chg_atmconfig([("testOnly::number_of_subcycles", "42")], case)

    ###########################################################################
    def test_atmchanges_for_atm_procs_add_invalid(self):
    ###########################################################################
        """
        Test that atmchanges adding an unknown atm proc are rejected
        """
        case = self._common_case

        # modifying a procs list requires known processes or "_" pre/post suffixes.
        # Every entry except "spiderman" is spelled exactly as in the other tests
        # of this class, so that "spiderman" is the only candidate for rejection.
        cmd = "./atmchange mac_aero_mic::atm_procs_list=shoc,cldFraction,spa,p3,spiderman"
        stat, _, _ = run_cmd(cmd, from_dir=case)

        self.assertNotEqual(stat, 0, msg=f"Command '{cmd}' should have failed")

    ###########################################################################
    def test_buffer_unchanged_with_bad_change_syntax(self):
    ###########################################################################
        """
        Test atmchange does not change buffer if syntax was wrong
        """
        case = self._common_case

        # Attempting a bad change should not alter the content of SCREAM_ATMCHANGE_BUFFER
        old = run_cmd_no_fail("./xmlquery --value SCREAM_ATMCHANGE_BUFFER", from_dir=case)

        stat, _, _ = run_cmd("./atmchange foo", from_dir=case)
        self.assertNotEqual(stat, 0, msg="Command './atmchange foo' should have failed")

        new = run_cmd_no_fail("./xmlquery --value SCREAM_ATMCHANGE_BUFFER", from_dir=case)

        self.assertEqual(new, old, msg="A bad atmchange should have not modified SCREAM_ATMCHANGE_BUFFER")

    ###########################################################################
    def test_invalid_xml_option(self):
    ###########################################################################
        """
        Test atmchange errors out with invalid param names
        """
        case = self._common_case

        stat, _, _ = run_cmd("./atmchange p3::non_existent=3", from_dir=case)
        self.assertNotEqual(stat, 0, msg="Command './atmchange p3::non_existent=3' should have failed")

    ###########################################################################
    def test_atmchanges_for_atm_procs_add_group(self):
    ###########################################################################
        """
        Test atmchanges that add atm proc groups
        """
        case = self._common_case

        run_cmd_no_fail("./atmchange mac_aero_mic::atm_procs_list=shoc,_my_group_", from_dir=case)

        self._chg_atmconfig([("_my_group_::atm_procs_list", "testOnly")], case)

        # If we are able to change subcycles of testOnly then we know the atmchange
        # above added the necessary atm proc XML block.
        self._chg_atmconfig([("testOnly::number_of_subcycles", "42")], case)

    ###########################################################################
    def test_atmchanges_for_atm_procs_remove(self):
    ###########################################################################
        """
        Test atmchanges that remove atm procs
        """
        case = self._common_case

        self._chg_atmconfig([("mac_aero_mic::atm_procs_list", "shoc,cldFraction,spa")], case)

        # Only the stdout of the grep query is inspected; its exit status is not checked
        _, output, _ = run_cmd("./atmquery --grep p3", from_dir=case)
        self.assertEqual(output, "", msg="There is still a trace of the removed process")

###############################################################################
def parse_command_line(args, desc):
###############################################################################
    """
    Parse custom args for this test suite. Will delete our custom args from
    sys.argv so that only args meant for unittest remain.

    args is the full argument vector (program name first, as in sys.argv);
    args[0] is used for the usage text and args[1:] is what gets parsed.
    desc is the description shown in the help message.

    Returns the argparse namespace holding the options known to this suite
    (currently only "machine"). Options not recognized here are written to
    sys.argv[1:].
    """
    help_str = \
"""
{0} [TEST] [TEST]
OR
{0} --help

\033[1mEXAMPLES:\033[0m
    \033[1;32m# Run basic pylint and doctests for everything \033[0m
    > {0}

""".format(Path(args[0]).name)

    parser = argparse.ArgumentParser(
        usage=help_str,
        description=desc,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument("-m", "--machine",
                        help="Provide machine name. This is required for full (not dry) runs")

    # Parse the arguments we were given, rather than implicitly reading sys.argv
    parsed_args, py_ut_args = parser.parse_known_args(args[1:])

    # Whatever we did not recognize is handed over to unittest
    sys.argv[1:] = py_ut_args

    return parsed_args

###############################################################################
def scripts_tests(machine=None):
###############################################################################
    """
    Run the unittest suite.

    If a machine name is given, it must be a supported machine; it is then
    recorded in the module-level CONFIG dict before the tests are started.
    """
    if machine:
        supported = is_machine_supported(machine)
        expect(supported, "Machine {} is not supported".format(machine))

        # Make the machine name available to the tests
        CONFIG["machine"] = machine

    unittest.main(verbosity=2)

###############################################################################
def _main_func(desc):
###############################################################################
    """
    Script entry point: parse the command line, then run the tests with the
    parsed options passed as keyword arguments.
    """
    options = parse_command_line(sys.argv, desc)
    scripts_tests(**vars(options))

# Run the suite only when executed as a script; the module docstring doubles as the help description
if __name__ == "__main__":
    _main_func(__doc__)
