#!/usr/bin/env python3

"""
See https://acme-climate.atlassian.net/wiki/spaces/NGDNA/pages/3831923056/EAMxx+BFB+hashing
for full explanation.

This script is used by the scream-internal_diagnostics_level testmod to check
hash output after a test has run.
"""

import sys, re, glob, pathlib, argparse, gzip

from utils import run_cmd_no_fail, expect, GoodFormatter

###############################################################################
def parse_command_line(args, description):
###############################################################################
    parser = argparse.ArgumentParser(
        usage="""\n{0} <CASE_DIR> [<param>=<val>] ...
OR
{0} --help

\033[1mEXAMPLES:\033[0m
    \033[1;32m# Run hash checker on /my/case/dir \033[0m
    > {0} /my/case/dir
""".format(pathlib.Path(args[0]).name),
        description=description,
        formatter_class=GoodFormatter
    )

    parser.add_argument(
        "case_dir",
        help="The test case you want to check"
    )

    return parser.parse_args(args[1:])

###############################################################################
def readall(fn):
###############################################################################
    with open(fn,'r') as f:
        txt = f.read()
    return txt

###############################################################################
def greptxt(pattern, txt):
###############################################################################
    return re.findall('(?:' + pattern + ').*', txt, flags=re.MULTILINE)

###############################################################################
def grep(pattern, fn):
###############################################################################
    txt = readall(fn)
    return greptxt(pattern, txt)

###############################################################################
def get_log_glob_from_atm_modelio(case_dir):
###############################################################################
    filename = case_dir / 'CaseDocs' / 'atm_modelio.nml'
    ln = grep('diro = ', filename)[0]
    run_dir = pathlib.Path(ln.split()[2].split('"')[1])
    ln = grep('logfile = ', filename)[0]
    atm_log_fn = ln.split()[2].split('"')[1]
    return str(run_dir / '**' / f'atm.log.*')

###############################################################################
def get_hash_lines(fn,start_from_line):
###############################################################################
    times = []
    hash_lines = []

    lines = []
    with gzip.open(fn,'rt') as file:
        start_line_found = False
        for line in file:
            if start_line_found:
                lines.append(line)
            elif start_from_line in line:
                start_line_found = True

    i = 0
    while i < len(lines):
        line = lines[i]
        i = i+1
        # eamxx hash line has the form "eamxx hash> date=YYYY-MM-DD-XXXXX (STRING), naccum=INT
        # The INT at the end says how many of the following line contain hashes for this proc-step
        if "eamxx hash>" in line:
            times.append(parse_time(line))
            naccum_index = line.index("naccum=") + len("naccum=")
            N = int(line[naccum_index:])
            hashes = []
            for j in range(N):
                hashes.append(lines[i].strip('\n'))
                i += 1
            hash_lines.append(hashes)

    return times, hash_lines

###############################################################################
def parse_time(hash_ln):
###############################################################################
    # hash_ln has the form "eamxx hash> date=YYYY-MM-DD-XXXXX (STRING), naccum=INT
    return hash_ln.split()[2].split('=')[1]

###############################################################################
def all_equal(t1, t2):
###############################################################################
    if len(t1) != len(t2): return False
    for i in range(len(t1)):
        if t1[i] != t2[i]: return False
    return True

###############################################################################
def find_first_index_at_time(times, time):
###############################################################################
    for i, t in enumerate(times):
        if t==time: return i
    return None

###############################################################################
def diff(times, l1, l2):
###############################################################################
    diffs = []
    for i in range(len(l1)):
        if l1[i] != l2[i]:
            diffs.append((times[i], l1[i], l2[i]))
    return diffs

###############################################################################
def get_model_start_of_step_lines (atm_log):
###############################################################################
    lines = []
    with gzip.open(atm_log,'rt') as file:
        for line in file:
            if "model start-of-step time" in line:
                lines.append(line)
    return lines

###############################################################################
def check_hashes_ers(case_dir):
###############################################################################
    case_dir_p = pathlib.Path(case_dir)
    expect(case_dir_p.is_dir(), f"{case_dir} is not a dir")

    # Look for the two atm.log files.
    glob_pat = get_log_glob_from_atm_modelio(case_dir_p)
    atm_fns = glob.glob(glob_pat, recursive=True)
    if len(atm_fns) == 0:
        print('Could not find atm.log files with glob string {}'.format(glob_pat))
        return False
    if len(atm_fns) == 1:
        # This is the first run. Exit and wait for the second
        # run. (POSTRUN_SCRIPT is called after each of the two runs.)
        print('Exiting on first run.')
        return True
    else:
        expect(len(atm_fns)==2,
               "Error! Found more than 2 atm log files. Not sure what to do here."
               " NOTE: if you run ERS test twice, clear the run folder from old logs first.")

    atm_fns.sort()
    print('Diffing base {} and restart {}'.format(atm_fns[0], atm_fns[1]))

    start_line = get_model_start_of_step_lines(atm_fns[1])[0]

    # Extract hash lines, along with their timestamps, but ignore anything
    # before the line $start_line
    hash_lines = []
    times = []
    for f in atm_fns:
        t,h = get_hash_lines(f,start_line)
        hash_lines.append(h)
        times.append(t)

    run1_hashes = hash_lines[0]
    run2_hashes = hash_lines[1]
    if len(run1_hashes) != len(run2_hashes):
        print('Number of hash lines starting at restart time do not agree.')
        print(f' run1 number of hash lines (after rest time): {len(run1_hashes)}')
        print(f' run2 number of hash lines: {len(run2_hashes)}')
        print(run1_hashes)
        return False

    diffs = diff(times[1],run1_hashes, run2_hashes)

    if run1_hashes==run2_hashes:
        print('OK')
        return True
    else:
        # Nicely print the diffing hashes, grouped by their timestamp
        print(f'FOUND DIFFS')
        for d in diffs[-10:]:
            print (f' timestamp: {d[0]}')
            slen = len(d[1][0])
            print("run1".center(slen) + " | " + "run2".center(slen))
            for r1,r2 in zip(d[1],d[2]):
                # Print only the hash entry(ies) that diff in this timestamp
                if r1!=r2:
                    print (f'{r1} | {r2}')
        return False

###############################################################################
def _main_func(description):
###############################################################################
    success = check_hashes_ers(**vars(parse_command_line(sys.argv, description)))
    sys.exit(0 if success else 1)

###############################################################################

if (__name__ == "__main__"):
    _main_func(__doc__)
