#!/usr/bin/env python

from __future__ import print_function

import argparse
import glob
import io
import os
import re
import sys


# Fortran tokenization
#
# NOTE: `re_module` and `re_use` carry no IGNORECASE flag; the scanner applies
#   them to lower-cased lines.  The preprocessor patterns are case-sensitive,
#   as preprocessor directives themselves are.

re_module = re.compile(r"^ *module +([a-z_0-9]+)")
re_use = re.compile(r"^ *use +([a-z_0-9]+)")
# A macro name is one leading letter/underscore followed by any number of
# identifier characters, so single-character names such as `A` are valid.
re_cpp_define = re.compile(r"^ *# *define +[_a-zA-Z][_a-zA-Z0-9]*")
re_cpp_undef = re.compile(r"^ *# *undef +[_a-zA-Z][_a-zA-Z0-9]*")
re_cpp_ifdef = re.compile(r"^ *# *ifdef +[_a-zA-Z][_a-zA-Z0-9]*")
re_cpp_ifndef = re.compile(r"^ *# *ifndef +[_a-zA-Z][_a-zA-Z0-9]*")
re_cpp_if = re.compile(r"^ *# *if +")
re_cpp_else = re.compile(r"^ *# *else")
re_cpp_endif = re.compile(r"^ *# *endif")
re_cpp_include = re.compile(r"^ *# *include *[<\"']([a-zA-Z_0-9\.]+)[>\"']")
# Fortran keywords are case-insensitive, so `INCLUDE 'file'` must match too.
re_f90_include = re.compile(
    r"^ *include +[\"']([a-zA-Z_0-9\.]+)[\"']",
    re.IGNORECASE
)
re_program = re.compile(r"^ *program +([a-z_0-9]+)", re.IGNORECASE)
# The word boundary accepts an unnamed `end module` at the end of a line while
# still rejecting identifiers which merely start with the keyword.
re_end = re.compile(r"^ *end *(module|procedure)\b", re.IGNORECASE)
# NOTE: This excludes comments and tokens with substrings containing `function`
# or `subroutine`, but will fail if the keywords appear in other contexts.
re_procedure = re.compile(
    r"^[^!]*(?<![a-z_])(function|subroutine)(?![a-z_])",
    re.IGNORECASE
)


# Preprocessor expression tokenization
# NOTE: Labels and attributes could be assigned here, but for now we just use
#   the token string as the label.
def _cpp_keep_token(scanner, token):
    """Scanner action which emits the matched text itself as the token."""
    return token


# Patterns are tried in order, so multi-character operators must precede
# their single-character prefixes (`>>` and `>=` before `>`, `&&` before `&`).
cpp_scanner = re.Scanner([
    # The word boundary stops identifiers such as `defined_x` from being
    # split into the `defined` operator and a trailing name.
    (r'defined\b', _cpp_keep_token),
    (r'[_A-Za-z][_0-9a-zA-Z]*', _cpp_keep_token),
    (r'[0-9]+', _cpp_keep_token),
    (r'\(', _cpp_keep_token),
    (r'\)', _cpp_keep_token),
    (r'\*', _cpp_keep_token),
    (r'/', _cpp_keep_token),
    (r'\+', _cpp_keep_token),
    (r'-', _cpp_keep_token),
    (r'!', _cpp_keep_token),
    (r'>>', _cpp_keep_token),
    (r'>=', _cpp_keep_token),
    (r'>', _cpp_keep_token),
    (r'<<', _cpp_keep_token),
    (r'<=', _cpp_keep_token),
    (r'<', _cpp_keep_token),
    (r'==', _cpp_keep_token),
    (r'&&', _cpp_keep_token),
    (r'&', _cpp_keep_token),
    (r'\|\|', _cpp_keep_token),
    (r'\|', _cpp_keep_token),
    # Bitwise XOR: present in the operator tables, so it must be tokenized.
    (r'\^', _cpp_keep_token),
    # The leading `#if` directive and all whitespace are discarded.
    (r'^ *\# *if', None),
    (r'\s+', None),
])


# Implementation of each preprocessor operator, keyed by its token string.
# `(` and `!` take one operand; `defined` takes a macro name and the mapping
# of active macros; every other entry is a binary operator.
cpp_operate = {
    '(': lambda x: x,
    '!': lambda x: not x,
    'defined': lambda x, y: x in y,
    '*': lambda x, y: x * y,
    '/': lambda x, y: x // y,
    '+': lambda x, y: x + y,
    '-': lambda x, y: x - y,
    '>>': lambda x, y: x >> y,
    '<<': lambda x, y: x << y,
    '==': lambda x, y: x == y,
    '>': lambda x, y: x > y,
    '>=': lambda x, y: x >= y,
    '<': lambda x, y: x < y,
    '<=': lambda x, y: x <= y,
    '&': lambda x, y: x & y,
    '^': lambda x, y: x ^ y,
    '|': lambda x, y: x | y,
    # Logical operators yield a truth value (C's 0 or 1) rather than one of
    # their operands, so that e.g. `(2 && 3) == 1` holds as it does in cpp.
    '&&': lambda x, y: bool(x) and bool(y),
    '||': lambda x, y: bool(x) or bool(y),
}


# Operator precedence; a larger rank binds more tightly.  `)` and the
# end-of-line marker `$` rank below every real operator so that they force
# evaluation of whatever is pending, and `None` (no prior operator) ranks
# lowest of all.
cpp_op_rank = {
    '(': 13,
    '!': 12,
    'defined': 12,
    '*': 11,
    '/': 11,
    '+': 10,
    '-': 10,
    '>>': 9,
    '<<': 9,
    '>': 8,
    '>=': 8,
    '<': 8,
    '<=': 8,
    '==': 7,
    '&': 6,
    '^': 5,
    '|': 4,
    # As in C, logical AND binds more tightly than logical OR, so that
    # `a || b && c` is evaluated as `a || (b && c)`.
    '&&': 3,
    '||': 2,
    ')': 1,
    '$': 1,
    None: 0,
}


def create_deps(src_dirs, skip_dirs, makefile, debug, exec_target, fc_rule,
                link_externals, defines):
    """Scan "src_dirs" for source files and write make rules to "makefile".

    Args:
        src_dirs: Directories which are searched recursively for sources.
        skip_dirs: Directories excluded from the search, or None.
        makefile: Path of the makefile to write (overwritten if it exists).
        debug: If true, annotate each object rule with comments describing
            what the scan found for its source file.
        exec_target: Name of the executable to build or, if it ends in
            ".a", of the library to build.  If empty, every program found
            becomes a target.
        fc_rule: Command prefix used in the Fortran compilation rules.
        link_externals: If true, link every object flagged as an "external"
            (all C objects, plus Fortran objects flagged by the scan) into
            each executable.
        defines: Preprocessor macro definitions passed on to the file
            scanner, or None.

    Raises:
        ValueError: If "exec_target" names an executable but the number of
            programs found is not exactly one.
    """

    # Scan everything Fortran related
    all_files = find_files(src_dirs, skip_dirs)

    # Lists of things
    #  ... all F90 source
    F90_files = [
        f for f in all_files if f.endswith(('.f90', '.F90', '.f', '.F'))
    ]
    #  ... all C source
    c_files = [f for f in all_files if f.endswith('.c')]

    # maps basename of file to full path to file; when a basename occurs more
    # than once, the last path in "all_files" wins.
    f2F = {os.path.basename(f): f for f in all_files}

    # Check for duplicate files in search path
    if len(f2F) != len(all_files):
        seen = set()
        for f in all_files:
            name = os.path.basename(f)
            if name in seen:
                print('Warning: File {} was found twice! One is being ignored '
                      'but which is undefined.'.format(name))
            seen.add(name)

    # maps object file to F90 source
    o2F90 = {object_file(f): f for f in F90_files}
    # maps object file to C source
    o2c = {object_file(f): f for f in c_files}

    o2mods, o2uses, o2h, o2inc, o2prg, prg2o, mod2o = {}, {}, {}, {}, {}, {}, {}
    externals, all_modules = [], []
    for f in F90_files:
        obj = object_file(f)
        mods, used, cpp, inc, prg, has_externals = scan_fortran_file(f, defines)
        # maps object file to modules produced
        o2mods[obj] = mods
        # maps module produced to object file
        for m in mods:
            mod2o[m] = obj
        # maps object file to modules used
        o2uses[obj] = used
        # maps object file to .h files included
        o2h[obj] = cpp
        # maps object file to .inc files included
        o2inc[obj] = inc
        # maps object file to executables produced
        o2prg[obj] = prg
        for p in prg:
            if p in prg2o:
                print("Warning: Files {} and {} both create the same "
                      "program '{}'".format(f, o2F90[prg2o[p]], p))
                # Neither file gets to build the program.  The entry of the
                # earlier object is kept, but modified to record the clash.
                o = prg2o.pop(p)
                o2prg[o] = ['[ignored %s]' % (p)]
            else:
                prg2o[p] = obj
        if has_externals:
            externals.append(obj)
        all_modules += mods

    for f in c_files:
        _, _, cpp, _, _, _ = scan_fortran_file(f, defines)
        # maps object file to .h files included
        o2h[object_file(f)] = cpp
        externals.append(object_file(f))

    # Set form of "all_modules" for fast membership tests below
    module_set = set(all_modules)

    # Are we building a library, single or multiple executables?
    targ_libs = []
    if exec_target:
        if exec_target.endswith('.a'):
            targ_libs.append(exec_target)
        else:
            if len(prg2o) != 1:
                raise ValueError("Option -x specified an executable name but "
                                 "none or multiple programs were found")
            # Rename the only program to the requested executable name.
            (prog, o), = prg2o.items()
            del prg2o[prog]
            prg2o[exec_target] = o
            # Must stay a list: it is joined with spaces in the debug header.
            o2prg[o] = [exec_target]
        targets = [exec_target]
    else:
        if not prg2o:
            print("Warning: No programs were found and -x did not specify a "
                  "library to build")
        targets = list(prg2o)

    # Create new makefile
    with open(makefile, 'w') as out:
        def emit(*items):
            """Write one line to the makefile, like print()."""
            print(*items, file=out)

        emit("# %s created by makedep" % (makefile))
        emit()
        emit("# Invoked as")
        emit('#   '+' '.join(sys.argv))
        emit()
        emit("all:", " ".join(targets))

        # Write rule for each object from Fortran
        for obj in sorted(o2F90):
            found_mods = [m for m in o2uses[obj] if m in module_set]
            found_objs = [mod2o[m] for m in found_mods]
            # Each used module is followed by the object which produces it.
            found_deps = [
                dep for pair in zip(found_mods, found_objs) for dep in pair
            ]
            missing_mods = [m for m in o2uses[obj] if m not in module_set]

            incs, inc_used = nested_inc(o2h[obj] + o2inc[obj], f2F, defines)
            # "inc_used" is a set; sort it so that the makefile content does
            # not depend on string hash randomization.
            inc_mods = [
                u for u in sorted(inc_used)
                if u not in found_mods and u in module_set
            ]

            incdeps = sorted({f2F[f] for f in incs if f in f2F})
            incargs = sorted({'-I' + os.path.dirname(f) for f in incdeps})

            # Header
            emit()
            if debug:
                emit("# Source file {} produces:".format(o2F90[obj]))
                emit("#   object:", obj)
                emit("#   modules:", ' '.join(o2mods[obj]))
                emit("#   uses:", ' '.join(o2uses[obj]))
                emit("#   found mods:", ' '.join(found_mods))
                emit("#   found objs:", ' '.join(found_objs))
                emit("#   missing:", ' '.join(missing_mods))
                emit("#   includes_all:", ' '.join(incs))
                emit("#   includes_pth:", ' '.join(incdeps))
                emit("#   incargs:", ' '.join(incargs))
                emit("#   program:", ' '.join(o2prg[obj]))

            # Fortran Module dependencies
            if o2mods[obj]:
                emit(' '.join(o2mods[obj]) + ':', obj)

            # Fortran object dependencies
            obj_incs = ' '.join(inc_mods + incdeps + found_deps)
            emit(obj + ':', o2F90[obj], obj_incs)

            # Fortran object build rule
            obj_rule = ' '.join([fc_rule] + incargs + ['-c', '$<'])
            emit('\t' + obj_rule)

        # Write rule for each object from C
        for obj in sorted(o2c):
            incdeps = sorted({f2F[h] for h in o2h[obj] if h in f2F})
            incargs = sorted({'-I' + os.path.dirname(f) for f in incdeps})

            # Header
            emit()
            if debug:
                emit("# Source file %s produces:" % (o2c[obj]))
                emit("#   object:", obj)
                emit("#   includes_all:", ' '.join(o2h[obj]))
                emit("#   includes_pth:", ' '.join(incdeps))
                emit("#   incargs:", ' '.join(incargs))

            # C object dependencies
            emit(obj + ':', o2c[obj], ' '.join(incdeps))

            # C object build rule
            c_rule = ' '.join(
                ['$(CC) $(DEFS) $(CPPFLAGS) $(CFLAGS)'] + incargs + ['-c', '$<']
            )
            emit('\t' + c_rule)

        # Externals (so called)
        if link_externals:
            emit()
            emit("# Note: The following object files are not associated with "
                 "modules so we assume we should link with them:")
            emit("# ", ' '.join(externals))
        else:
            externals = []

        # Write rules for linking executables
        for p in sorted(prg2o):
            o = prg2o[p]
            emit()
            emit(p+':', ' '.join(link_obj(o, o2uses, mod2o, all_modules) + externals))
            emit('\t$(LD) $(LDFLAGS) -o $@ $^ $(LIBS)')

        # Write rules for building libraries
        for lb in sorted(targ_libs):
            emit()
            emit(lb+':', ' '.join(list(o2F90) + list(o2c)))
            emit('\t$(AR) $(ARFLAGS) $@ $^')

        # Write cleanup rules
        emit()
        emit("clean:")
        emit('\trm -f *.mod *.o', ' '.join(list(prg2o) + targ_libs))

        # Write re-generation rules
        emit()
        emit("remakedep:")
        emit('\t'+' '.join(sys.argv))


def link_obj(obj, o2uses, mod2o, all_modules):
    """Return the sorted list of all objects needed to link "obj".

    Args:
        obj: Name of the object file to start from.
        o2uses: Maps each object file to the modules it uses.
        mod2o: Maps each module to the object file which produces it.
        all_modules: Modules which are produced by some scanned source file;
            used modules outside of this collection are ignored.

    Returns:
        A sorted list containing "obj" and every object reachable from it by
        following module usage, each listed once.
    """
    known_modules = set(all_modules)
    needed = set()

    # Walk the usage graph with an explicit worklist, so that long dependency
    # chains cannot exhaust the recursion limit.  Objects which were already
    # visited are skipped, which also terminates circular usage.
    pending = [obj]
    while pending:
        current = pending.pop()
        if current in needed:
            continue
        needed.add(current)
        pending.extend(
            mod2o[m] for m in o2uses[current] if m in known_modules
        )

    return sorted(needed)


def nested_inc(inc_files, f2F, defines):
    """List of all files included by "inc_files", either by #include or F90
    include.

    Args:
        inc_files: List of names of the directly included files.
        f2F: Maps a file name to its full path.  Names without an entry are
            not scanned and are never reported as nested includes.
        defines: Preprocessor macro definitions passed on to the file
            scanner, or None.

    Returns:
        A tuple of the list "inc_files" followed by the sorted names of all
        files which they include, directly or indirectly, and the set of
        modules used inside of any scanned include file.
    """
    nested = set()
    used_mods = set()

    # Every file is scanned at most once.  Without this record, two headers
    # which include each other would be followed forever.
    scanned = set()
    pending = list(inc_files)
    while pending:
        hfile = pending.pop()
        if hfile in scanned or hfile not in f2F:
            continue
        scanned.add(hfile)

        _, used, cpp, inc, _, _ = scan_fortran_file(f2F[hfile], defines)

        # Record any module updates inside of include files
        used_mods.update(used)

        for h in cpp + inc:
            if h in f2F:
                nested.add(h)
                pending.append(h)

    return inc_files + sorted(nested), used_mods


def scan_fortran_file(src_file, defines=None):
    """Scan the source file "src_file" for modules, includes and programs.

    Lines inside of preprocessor condition groups (#ifdef, #ifndef, #if,
    #else, #endif) which evaluate as false are ignored.  #define and #undef
    directives update the set of active macros while scanning.

    NOTE: #elif is not recognized; such a line does not change whether the
    following lines are excluded.

    Args:
        src_file: Path of the file to scan.
        defines: Macro definitions of the form "NAME" or "NAME=value", or
            None.  As with the C preprocessor's -D option, a bare "NAME" is
            defined as 1.

    Returns:
        A tuple (modules, used, cpp_includes, f90_includes, programs,
        has_externals) where "modules" are the modules declared in the file,
        "used" are the sorted modules which are used but not declared in the
        file (both with a ".mod" suffix), "cpp_includes" and "f90_includes"
        are the file names named by #include and Fortran include statements,
        "programs" are the names of the programs declared, and
        "has_externals" is True if a function or subroutine keyword was found
        outside of a module or program.
    """
    module_decl, used_modules, cpp_includes, f90_includes, programs = [], [], [], [], []

    cpp_macros = {}
    for define in (defines if defines is not None else []):
        # Split at the first "=" only, so values may themselves contain "=".
        name, has_value, value = define.partition('=')
        cpp_macros[name] = value if has_value else '1'

    with io.open(src_file, 'r', errors='replace') as file:
        lines = file.readlines()

    external_namespace = True
    # True if we are in the external (i.e. global) namespace

    file_has_externals = False
    # True if the file contains any external objects

    cpp_exclude = False
    # True if the parser excludes the subsequent lines

    cpp_group_stack = []
    # Stack of condition group exclusion states

    for line in lines:
        if re_cpp_ifdef.match(line):
            # Start of #ifdef condition group
            cpp_group_stack.append(cpp_exclude)

            # If outer group is excluding or macro is missing, then exclude
            macro = line.lstrip()[1:].split()[1]
            cpp_exclude = cpp_exclude or macro not in cpp_macros

        elif re_cpp_ifndef.match(line):
            # Start of #ifndef condition group
            cpp_group_stack.append(cpp_exclude)

            # If outer group is excluding or macro is present, then exclude
            macro = line.lstrip()[1:].split()[1]
            cpp_exclude = cpp_exclude or macro in cpp_macros

        elif re_cpp_if.match(line):
            # Start of #if condition group
            cpp_group_stack.append(cpp_exclude)

            # A group nested inside of an excluded group stays excluded,
            # whatever its condition; its expression is not evaluated.
            if not cpp_exclude:
                cpp_exclude = not cpp_expr_eval(line, cpp_macros)

        elif re_cpp_else.match(line):
            # Complement #else condition group
            # Reverse the exclude state, if there is no outer exclude state
            outer_grp_exclude = bool(cpp_group_stack) and cpp_group_stack[-1]
            cpp_exclude = not cpp_exclude or outer_grp_exclude

        elif re_cpp_endif.match(line):
            # Restore exclude state when exiting conditional block.  An
            # unmatched #endif leaves the top level, where nothing is excluded.
            cpp_exclude = cpp_group_stack.pop() if cpp_group_stack else False

        # Skip lines inside of false condition blocks
        if cpp_exclude:
            continue

        # Activate a new macro
        if re_cpp_define.match(line):
            tokens = line.strip()[1:].split(maxsplit=2)
            macro = tokens[1]
            value = tokens[2] if tokens[2:] else None
            if '(' in macro:
                # TODO: Actual handling of function macros
                macro, arg = macro.split('(', maxsplit=1)
                # The macro body may be absent, e.g. "#define F(x)"
                value = '(' + arg + (value or '')
            cpp_macros[macro] = value

        # Deactivate a macro
        if re_cpp_undef.match(line):
            # C99: "[A macro] is ignored if the specified identifier is
            # not currently defined as a macro name."
            cpp_macros.pop(line.lstrip()[1:].split()[1], None)

        match = re_module.match(line.lower())
        if match:
            # Avoid "module procedure" statements
            if match.group(1) != 'procedure':
                module_decl.append(match.group(1))
                external_namespace = False

        match = re_use.match(line.lower())
        if match:
            used_modules.append(match.group(1))

        match = re_cpp_include.match(line)
        if match:
            cpp_includes.append(match.group(1))

        match = re_f90_include.match(line)
        if match:
            f90_includes.append(match.group(1))

        match = re_program.match(line)
        if match:
            programs.append(match.group(1))
            external_namespace = False

        if re_end.match(line):
            external_namespace = True

        # Check for any external procedures; if present, flag the file
        # as a potential source of external objects.
        # NOTE: This a very weak test that needs further modification
        if external_namespace and not file_has_externals:
            if re_procedure.match(line):
                file_has_externals = True

    used_modules = [m for m in sorted(set(used_modules)) if m not in module_decl]
    return (
        add_suff(module_decl, '.mod'),
        add_suff(used_modules, '.mod'),
        cpp_includes,
        f90_includes,
        programs,
        file_has_externals,
    )


def object_file(src_file):
    """Name of the object file produced by compiling "src_file".

    The directory part and the final extension of "src_file" are dropped, and
    ".o" is appended to what remains.
    """
    filename = os.path.basename(src_file)
    stem, _extension = os.path.splitext(filename)
    return stem + '.o'


def find_files(src_dirs, skip_dirs):
    """Collect the source files found below each directory of "src_dirs".

    Directories whose normalized path matches an entry of "skip_dirs" (which
    may be None) are not descended into.  Symbolic links to directories are
    followed.  The result is a sorted list of paths without duplicates.

    Raises ValueError if an entry of "src_dirs" is not a directory.
    """
    skipped = []
    if skip_dirs is not None:
        skipped = [os.path.normpath(entry) for entry in skip_dirs]

    # TODO: Make this a user-defined argument
    suffixes = ('.f90', '.f', '.c', '.inc', '.h',  '.fh')

    found = set()
    for top in src_dirs:
        if not os.path.isdir(top):
            raise ValueError("Directory '{}' was not found".format(top))

        walker = os.walk(os.path.normpath(top), followlinks=True)
        for root, subdirs, names in walker:
            # Prune in place, so that os.walk never enters skipped directories
            subdirs[:] = [
                sub for sub in subdirs
                if os.path.join(root, sub) not in skipped
            ]

            # Extensions are compared without regard to case
            found.update(
                root + '/' + name
                for name in names
                if name.lower().endswith(suffixes)
            )

    return sorted(found)


def add_suff(lst, suff):
    """Return a new list with "suff" appended to every item of "lst"."""
    suffixed = []
    for item in lst:
        suffixed.append(item + suff)
    return suffixed


def cpp_expr_eval(expr, macros=None):
    """Evaluate the expression of a preprocessor #if directive.

    The expression is tokenized by "cpp_scanner" and then evaluated with an
    operator-precedence stack, using the ranks in "cpp_op_rank" and the
    implementations in "cpp_operate".

    Args:
        expr: The expression, optionally with its leading "#if".
        macros: Maps the names of the active macros to their values (strings,
            or None for a macro without a value).  Defaults to no macros.

    Returns:
        The value of the expression.  Identifiers are replaced by the value
        of the macro with that name, or by 0 if there is no such macro or it
        has no value; strings of digits are converted to integers.

    Raises:
        ValueError: If the expression contains characters which cannot be
            tokenized, an unsupported token, or no operand at all.
    """
    if macros is None:
        macros = {}

    def operand_value(operand):
        """Replace a macro name by its value and a digit string by an int."""
        if isinstance(operand, str):
            if operand.isidentifier():
                operand = macros.get(operand, '0')
            if operand is None:
                # A macro defined without a value has nothing to evaluate;
                # treat it like an undefined macro.
                return 0
            if operand.isdigit():
                operand = int(operand)
        return operand

    results, remainder = cpp_scanner.scan(expr.strip())

    # Abort if any characters are not tokenized
    if remainder:
        raise ValueError(
            'There are untokenized characters!\n'
            'Expression: {!r}\n'
            'Tokens: {}\n'
            'Unscanned: {}'.format(expr, results, remainder)
        )

    # Add an "end of line" character to force evaluation of the final tokens.
    results.append('$')

    # Unary operators are "look ahead": they apply to what follows them, so
    # meeting one never triggers the evaluation of earlier tokens.
    unary_ops = ('!', 'defined', '(')

    stack = []
    prior_op = None

    for tok in results:
        if tok in cpp_op_rank:
            if tok not in unary_ops:
                # Evaluate every pending operator which binds at least as
                # tightly as the incoming one.
                while cpp_op_rank[tok] <= cpp_op_rank[prior_op]:
                    second = stack.pop()
                    op = stack.pop()

                    if op == '(':
                        # Keep the token as it is; it may be the name which
                        # is tested by a preceding `defined`.
                        value = second

                    elif op == '!':
                        value = cpp_operate[op](operand_value(second))

                    elif op == 'defined':
                        value = cpp_operate[op](second, macros)

                    else:
                        first = operand_value(stack.pop())
                        value = cpp_operate[op](first, operand_value(second))

                    prior_op = stack[-1] if stack else None
                    stack.append(value)

                    # A `)` closes exactly one group.  Operators in front of
                    # the group wait for the next operator, which may bind
                    # more tightly than they do, as in `1 + (2) * 3`.
                    if op == '(' and tok == ')':
                        break

            # The ) "operator" has already been applied, so it can be dropped.
            if tok != ')':
                stack.append(tok)
                prior_op = tok

        elif tok.isdigit() or tok.isidentifier():
            stack.append(tok)

        else:
            raise ValueError("Unsupported token: {}".format(tok))

    # Remove the tail value
    eol = stack.pop()
    assert eol == '$'
    if not stack:
        raise ValueError('Expression {!r} has no value'.format(expr))

    # An expression without operators, such as `#if 0`, leaves its only token
    # unevaluated on the stack, so the result is converted here as well.
    return operand_value(stack.pop())


# Command line interface
parser = argparse.ArgumentParser(
    description="Generate make dependencies for F90 source code."
)
parser.add_argument(
    'path',
    nargs='+',
    help="Directories to search for source code."
)
parser.add_argument(
    '-o', '--makefile',
    default='Makefile.dep',
    help="Name of Makefile to put dependencies in to. Default is Makefile.dep."
)
parser.add_argument(
    '-f', '--fc_rule',
    default="$(FC) $(DEFS) $(CPPFLAGS) $(FCFLAGS)",
    help="String to use in the compilation rule. Default is: "
         "'$(FC) $(DEFS) $(CPPFLAGS) $(FCFLAGS)'"
)
parser.add_argument(
    '-x', '--exec_target',
    help="Name of executable to build. Fails if more than one program is "
         "found. If EXEC ends in .a then a library is built."
)
parser.add_argument(
    '-e', '--link_externals',
    action='store_true',
    help="Always compile and link any files that do not produce modules "
         "(externals)."
)
parser.add_argument(
    '-d', '--debug',
    action='store_true',
    help="Annotate the makefile with extra information."
)
parser.add_argument(
    '-s', '--skip',
    action='append',
    help="Skip directory in source code search."
)
parser.add_argument(
    '-D', '--define',
    action='append',
    help="Apply preprocessor define macros (of the form -DMACRO[=value])",
)

# Only parse the command line and write the makefile when run as a script, so
# that importing this file has no side effects.
if __name__ == '__main__':
    args = parser.parse_args()

    # Do the thing
    create_deps(args.path, args.skip, args.makefile, args.debug,
                args.exec_target, args.fc_rule, args.link_externals,
                args.define)
