#!/usr/bin/env python3

import os
import pkgutil
import readline
import importlib
import sys
import inspect

readline.parse_and_bind("tab: complete")


def debug(*args, **kwargs):
    """Diagnostic hook that is intentionally a no-op.

    Accepts arbitrary positional and keyword arguments and discards them.
    Re-enable the ``print`` call below to trace the module scan.
    """
    # print(*args, **kwargs)
    return None


class completer(object):
    """Prefix completer over a fixed list of candidate strings.

    ``complete`` takes ``(text, state)``, the calling convention of the
    function passed to ``readline.set_completer`` further down in this file.
    """

    def __init__(self, lst):
        # Candidates, offered in the order given.
        self.lst = lst

    def complete(self, txt, state):
        """Return the ``state``-th candidate starting with ``txt``.

        Returns ``None`` when there are fewer than ``state + 1`` matches.
        """
        hits = (candidate for candidate in self.lst if candidate.startswith(txt))
        for position, candidate in enumerate(hits):
            if position == state:
                return candidate
        return None


class boutmodules(object):
    """Discover importable modules that define functions.

    On construction every name in ``lst`` is imported and, for packages,
    scanned recursively.  The dotted names of all modules that contain at
    least one function end up in ``self.modules``.
    """

    def __init__(self, lst):
        # Dotted names of modules that define at least one function.
        self.modules = []
        # Initialised up front so returnCached() does not fail with an
        # AttributeError; nothing in this class fills it.
        self.cache = []
        # Filled by scanFunctions(): function names, and name -> function.
        self.functions = []
        self.funcs = {}
        for a in lst:
            self.parseModule(a)
        debug(self.modules)

    def parseModule(self, name):
        """Import ``name``, recurse into its submodules and record it.

        Modules that cannot be imported are skipped.  ``name`` is appended
        to ``self.modules`` (once) if the module has any function member.
        """
        try:
            mod = importlib.import_module(name)
        except (Exception, SystemExit):
            # Best effort: a broken module (or one calling sys.exit() on
            # import) must not abort the scan.  KeyboardInterrupt is
            # deliberately not caught so Ctrl-C still stops the program.
            debug("Failed to import %s" % name)
            return
        try:
            search_path = mod.__path__
        except AttributeError:
            # Plain modules have no __path__ and hence no submodules.
            debug("Failed to scan %s for submodues" % name)
        else:
            # iter_modules only lists direct children and imports nothing;
            # the recursion below takes care of nested packages.
            for _finder, module_name, _is_pkg in pkgutil.iter_modules(search_path):
                self.parseModule(name + "." + module_name)
        has_function = any(
            inspect.isfunction(member) for _name, member in inspect.getmembers(mod)
        )
        if has_function and name not in self.modules:
            self.modules.append(name)

    def returnCached(self, state):
        """Return ``self.cache[state]``, or ``None`` if out of range."""
        if 0 <= state < len(self.cache):
            return self.cache[state]
        return None

    def scanFunctions(self, module):
        """Collect the functions of ``module``.

        Resets and fills ``self.functions`` (names) and ``self.funcs``
        (name -> function object).
        """
        self.functions = []
        self.funcs = {}
        mod = importlib.import_module(module)
        for member_name, member in inspect.getmembers(mod):
            if inspect.isfunction(member):
                self.functions.append(member_name)
                self.funcs[member_name] = member


def esc(astr):
    """Return ``str(astr)`` as Python string-literal source text.

    The result is written into generated source code, so evaluating it must
    give back the original string.  Double quotes are the delimiter unless
    the text contains only double quotes, in which case single quotes are
    used.  Backslashes, the chosen delimiter and line breaks are escaped.
    """
    astr = str(astr)
    quote = "'" if '"' in astr and "'" not in astr else '"'
    # Backslashes first, so the escapes added below are not doubled.
    escaped = astr.replace("\\", "\\\\")
    escaped = escaped.replace(quote, "\\" + quote)
    # A raw line break would end a single-line literal.
    escaped = escaped.replace("\n", "\\n").replace("\r", "\\r")
    return quote + escaped + quote


def parse_type(typestr):
    """Map a docstring type description to a converter name.

    Parameters
    ----------
    typestr : str
        Type text such as ``"int"``, ``"str, optional"`` or
        ``"slice or None, optional"``.  Alternatives are separated by
        ``","`` or ``" or "``.

    Returns
    -------
    tuple of (str, bool)
        The converter name for the first recognised alternative, and
        whether the last alternative contains ``"optional"``.

    Raises
    ------
    ValueError
        If none of the alternatives is a known type.
    """
    knowns = {
        "str": "str",
        "int": "int",
        "slice": "str_to_slice",
        "bool": "str_to_bool",
    }
    candidates = []
    for part in typestr.split(","):
        for alternative in part.split(" or "):
            # Strip so that "float, int" yields "int" rather than " int".
            candidates.append(alternative.strip())
    opt = "optional" in candidates[-1]
    if opt:
        candidates = candidates[:-1]
    for candidate in candidates:
        if candidate in knowns:
            return knowns[candidate], opt
    raise ValueError("unknown type '%s'" % typestr)


# Headings that end the "Parameters" section of a numpy-style docstring.
# "Return" is kept for backward compatibility with existing docstrings.
_SECTION_END = frozenset(
    (
        "Return",
        "Returns",
        "Yields",
        "Receives",
        "Raises",
        "Warns",
        "Warnings",
        "See Also",
        "Notes",
        "References",
        "Examples",
        "Attributes",
        "Methods",
    )
)


def _parse_docstring(doc):
    """Split a numpy-style docstring into description and parameter info.

    Returns ``(description, arg_type, arg_help)``: the text before the
    "Parameters" heading joined into one line, a dict mapping parameter
    name to converter name (from ``parse_type``), and a dict mapping
    parameter name to a list of already escaped help-string literals.
    """
    description = []
    arg_type = {}
    arg_help = {}
    indent = 0
    current = None
    lines = iter((doc or "").splitlines())

    # Everything up to the "Parameters" heading is the description.
    for line in lines:
        stripped = line.strip()
        if stripped == "Parameters":
            # Parameter names start in the same column as the heading.
            indent = len(line) - len(line.lstrip(" "))
            break
        if stripped:
            description.append(stripped)

    # The same iterator continues after the heading.
    for line in lines:
        stripped = line.strip()
        if not stripped or set(stripped) == {"-"}:
            continue
        if stripped in ("Parameters", "Other Parameters"):
            continue
        if len(line) > indent and line[indent] != " ":
            if stripped in _SECTION_END:
                break
            parts = stripped.split(" : ")
            if len(parts) < 2:
                raise ValueError(
                    "cannot parse parameter line '%s': expected 'name : type'"
                    % stripped
                )
            current = parts[0]
            arg_type[current], _optional = parse_type(parts[1])
        elif current is not None:
            # argparse %-formats help strings, so a literal % must be doubled.
            text = stripped.replace("%", "%%")
            if current in arg_help:
                arg_help[current].append(esc(" " + text))
            else:
                arg_help[current] = [esc(text)]
    return " ".join(description), arg_type, arg_help


def createwrapper(mod, func_name, func, name):
    """Write an executable argparse command-line wrapper for ``func``.

    The parameter types and help texts are read from the numpy-style
    docstring of ``func``; defaults come from its signature.  Parameters
    without a default become positional arguments, the others become
    ``--name`` options.  The script text is generated completely before
    the file is opened, so a parsing error does not leave a partial file.

    mod : str
        the module name
    func_name : str
        the function name
    func : function
        The function itself
    name : str
        The filename

    Raises ``ValueError`` for an unparsable parameter line or unknown type
    and ``KeyError`` for a parameter that is not documented.
    """
    description, arg_type, arg_help = _parse_docstring(func.__doc__)
    out = []

    def emit(*args, end="\n"):
        out.append(" ".join(args) + end)

    emit(
        """#!/usr/bin/env python3
# PYTHON_ARGCOMPLETE_OK

import argparse
try:
    import argcomplete
except ImportError:
    argcomplete=None

"""
    )
    # Emit only the converter helpers that are actually needed
    if "str_to_slice" in arg_type.values():
        emit(
            """
def str_to_slice(sstr):
    args=[]
    for s in sstr.split(','):
        args.append(int(s))
    print(args)
    return slice(*args)
"""
        )
    if "str_to_bool" in arg_type.values():
        emit(
            """
def str_to_bool(sstr):
    low=sstr.lower()
    # no or false
    if low.startswith("n") or low.startswith("f"):
        return False
    # yes or true
    elif low.startswith("y") or low.startswith("t"):
        return True
    else:
        raise argparse.ArgumentTypeError("Cannot parse %s to bool type"%sstr)
"""
        )
    # Create the parser; the first positional argument of ArgumentParser is
    # ``prog``, so the docstring has to be passed as ``description``.
    emit("parser = argparse.ArgumentParser(description=%s)" % (esc(description)))
    spec = inspect.signature(func)
    for pname, item in spec.parameters.items():
        try:
            t = arg_type[pname]
        except KeyError:
            raise KeyError(
                "parameter '%s' of %s.%s is not documented in the docstring"
                % (pname, mod, func_name)
            ) from None
        if item.default is inspect.Parameter.empty:
            # No default value present. Make positional argument
            emit('parser.add_argument("%s", type=%s' % (pname, t), end="")
        elif t == "str_to_bool" and item.default == False:
            emit(
                'parser.add_argument("--%s", action="store_true", default=False'
                % (pname),
                end="",
            )
        else:
            emit('parser.add_argument("--%s", type=%s' % (pname, t), end="")
            emit(", default=", end="")
            if isinstance(item.default, str):
                emit(esc(item.default), end="")
            else:
                emit(str(item.default), end="")
            if t == "str_to_slice":
                helps = arg_help.setdefault(pname, [])
                # The literals are concatenated, so separate with a space.
                lead = " " if helps else ""
                helps.append(
                    esc(
                        lead
                        + "Pass slice object as comma separated list without spaces."
                    )
                )
        if pname in arg_help:
            pre = "\n    ,help="
            for helps in arg_help[pname]:
                emit(pre, helps, end="")
                pre = "\n    "
        emit(")")

    emit(
        """
if argcomplete:
    argcomplete.autocomplete(parser)

# late import for faster auto-complete"""
    )
    emit("from %s import %s" % (mod, func_name))
    emit(
        """
args = parser.parse_args()

# Call the function %s, using command line arguments
%s(**args.__dict__)

"""
        % (func_name, func_name)
    )
    with open(name, "w") as f:
        f.write("".join(out))
    # rwxr-xr-x: make the generated script executable
    os.chmod(name, 0o755)


if __name__ == "__main__":
    # Command line: [module function [filename [-f]]]
    # Anything not given on the command line is asked for interactively.
    realfun = None
    if len(sys.argv) < 3:
        print("Please wait, scanning modules ...")
        x = boutmodules(["boutpp", "boutdata", "boututils", "post_bout", "zoidberg"])
        readline.set_completer(completer(x.modules).complete)
        mod = input("Module: ")
        x.scanFunctions(mod)
        readline.set_completer(completer(x.functions).complete)
        fun = input("Function: ")
        realfun = x.funcs[fun]
    else:
        mod = sys.argv[1]
        fun = sys.argv[2]
    if len(sys.argv) < 4:
        name = "bin/bout-%s-%s" % (mod, fun)
        name = name.replace(".", "-")
        # Remove the module/function completer before asking for a filename
        readline.set_completer()
    else:
        name = sys.argv[3]
    # "-f" as fourth argument skips the filename prompt and overwrites
    force = len(sys.argv) > 4 and sys.argv[4] == "-f"
    while not force:
        # Pre-fill the prompt with the current filename so it can be edited
        readline.set_startup_hook(lambda: readline.insert_text(name))
        name = input("Filename: ")
        readline.set_startup_hook()
        if not os.path.exists(name):
            break
        print("File %s already exists. If you continue it will be overwritten." % name)
        doit = input("Do you want to continue? [y/N] ")
        # Only an answer starting with "y" counts as consent ("nay" does not)
        if doit.strip().lower().startswith("y"):
            break
    # print("Creating wrapper; rerun with %s %s %s %s"%(sys.argv[0],mod,fun,name))
    if realfun is None:
        # Look the function up directly instead of exec()-ing a string
        # built from command-line input.
        realfun = getattr(importlib.import_module(mod), fun)
    try:
        createwrapper(mod, fun, realfun, name)
    except BaseException:
        print("Creating failed. To rerun and overwrite the file without asking run:")
        print("%s %s %s %s -f" % (sys.argv[0], mod, fun, name))
        raise
