#!/usr/bin/env python3

import sys
import os
import subprocess
import argparse

""" Main entry point to use the stencil transformation

This will look for pragma involving brick and replace the code in the place of the pragma.

This uses C-processor to determine configurations passed to the code generator
"""

if __name__ == "__main__":
    # First run the C preprocessor of your choice; the VSCPP environment
    # variable overrides the default `cpp`.
    CPP = os.environ.get('VSCPP', 'cpp')

    parser = argparse.ArgumentParser(description='Inject vector scatter to your code!')
    parser.add_argument('input', help='file name for the input source', type=str)
    parser.add_argument('output', help='file name for the output', type=str)
    parser.add_argument('--thres', '-t', metavar='Thres', default=1.5,
                        help='threshold to determine whether to use scatter', type=float)
    parser.add_argument('--msize', '-m', metavar='MSize', default=5,
                        help='maximum (suggested) number of buffers', type=float)
    parser.add_argument('--K', '-k', metavar='K', default=2,
                        help='weight for buffers', type=float)
    parser.add_argument('--limit', '-l', metavar='Limit', default=20,
                        help='limit for input vectors and buffers', type=float)
    parser.add_argument('--unroll', '-u', help='fully unroll during scatter', action='store_true')
    parser.add_argument('--dsplit', '-d', help='perform dimensional split on the input stencils', action='store_true')

    # Arguments before a literal '--' are parsed by this tool; arguments after
    # it are only scanned for options to forward to the preprocessor.
    try:
        idx = sys.argv.index('--')
    except ValueError:
        idx = len(sys.argv)

    opts = sys.argv[1:idx]
    arg = parser.parse_args(opts)
    src = arg.input
    dst = arg.output

    # Clean compiler specific options: forward only options of the form
    # '-X...' where X is an uppercase letter or 'm'.  The length guard keeps
    # an empty argument or a bare '-' from raising IndexError.
    cppargv = [opt for opt in sys.argv[(idx + 1):]
               if len(opt) > 1 and opt[0] == '-' and (opt[1].isupper() or opt[1] == 'm')]

    args = [CPP, '-x', 'c++', src] + cppargv

    # Make the directory containing the input file importable.
    path = os.path.dirname(os.path.abspath(src))
    sys.path.append(path)
    print(path)
    # The preprocessor's exit status is not checked; its diagnostics are
    # relayed on stderr and whatever it wrote to stdout is scanned below.
    proc = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    print(proc.stderr.decode('utf-8'), file=sys.stderr)
    ppsrc = proc.stdout.decode('utf-8')
    lines = ppsrc.split('\n')

    # Preprocessed lines that carry a vecscatter directive.
    loc = [line for line in lines if '#pragma vecscatter' in line]

    # The un-preprocessed input, one list entry per line; directive handlers
    # overwrite individual entries with generated code.
    with open(src, 'rt') as f:
        source = f.read()
    source = source.split('\n')

    abssrc = os.path.abspath(src)

    # Work from the input file's directory until the output is written, so
    # relative paths opened in the meantime resolve against it.
    # os.path.join discards the cwd prefix when `src` is given as an absolute
    # path; plain string concatenation produced a bogus '<cwd>//<abs dir>/'.
    orig_dir = os.getcwd()
    reldir = os.path.dirname(src)
    absdir = os.path.join(orig_dir, reldir)
    os.chdir(absdir)


    def get_backend(vec):
        """Return a new code generation backend for the target named ``vec``.

        Recognised names are 'AVX512', 'AVX2', 'SSE', 'CUDA', 'FLEX', 'CUFLEX',
        'OPENCL', 'SYCL' and 'HIP'; the last three are configured variants of
        ``BackendCUDA``.  Any other value yields ``BackendScalar()``.
        """
        from st.codegen.backend import BackendAVX512, BackendAVX2, BackendScalar, BackendCUDA, BackendFlex, BackendSSE, \
            BackendCuFlex
        # Ordered (name, factory) pairs; a factory is only called on a match,
        # so exactly one backend object is ever constructed.
        factories = (
            ('AVX512', BackendAVX512),
            ('AVX2', BackendAVX2),
            ('SSE', BackendSSE),
            ('CUDA', BackendCUDA),
            ('FLEX', BackendFlex),
            ('CUFLEX', BackendCuFlex),
            ('OPENCL', lambda: BackendCUDA(16, 'sglid', ocl=True)),
            ('SYCL', lambda: BackendCUDA(16, 'sglid', ocl=False)),
            ('HIP', lambda: BackendCUDA(64, 'hipThreadIdx_x')),
        )
        for name, make in factories:
            if vec == name:
                return make()
        return BackendScalar()


    def get_prec(prec):
        """Map a precision name to its numeric code: 2 for 'double', 1 for anything else."""
        return 2 if prec == 'double' else 1


    def Tile(file, line, script_name, prec, vec, *, dim=None, tile_iter=None, stride=None):
        """Generate code for a ``Tiled`` layout and store it in ``source[line - 1]``.

        Nothing happens unless ``file`` equals the input file name ``src``.

        Args:
            file: file name the directive is attributed to.
            line: 1-based line number of the directive in the input file.
            script_name: path of a Python script that, when executed, defines
                a name ``STENCIL``; that value is handed to the code generator.
            prec: precision name, translated by ``get_prec``.
            vec: backend name, translated by ``get_backend``.
            dim: sequence of dimensions; reversed before use.
            tile_iter: sequence passed to the layout; reversed before use.
            stride: optional sequence of strides; reversed, then padded with
                1s up to the length of ``dim``.
        """
        if file != src:
            return

        print(file, line, prec, dim)
        prec_code = get_prec(prec)
        backend = get_backend(vec)
        print("BACKEND", vec)

        rdim = list(reversed(dim))
        riter = list(reversed(tile_iter))
        rstride = [] if stride is None else list(reversed(stride))
        rstride = rstride + [1] * (len(rdim) - len(rstride))

        from st.codegen.backend import Tiled
        layout = Tiled(dim=rdim, prec=prec_code, tile_iter=riter, aligned=False)
        from st.codegen.base import CodeGen
        cg = CodeGen(backend, layout, dag_msize=arg.msize, scatter_thres=arg.thres, unroll=arg.unroll,
                     dimsplit=arg.dsplit, stride=rstride, K=arg.K, klimit=arg.limit)

        # Run the stencil script in a fresh namespace and pick up its STENCIL.
        with open(script_name, 'rt') as script_file:
            script = script_file.read()
        namespace = {}
        exec(script, {}, namespace)

        from io import StringIO
        generated = StringIO()
        cg.gencode(namespace['STENCIL'], generated)

        # Bracket the generated text with '# <n> "<name>" <flag>' lines; the
        # trailing one points back at the directive's line in the input file.
        # NOTE(review): these look like preprocessor linemarkers - confirm.
        dim_label = "x".join(str(d) for d in reversed(rdim))
        header = '# 1 "VSTile-{}-{}-{}" 1\n'.format(os.path.basename(script_name), vec, dim_label)
        footer = '# {} "{}" 2\n'.format(line, abssrc)
        source[line - 1] = header + generated.getvalue() + footer


    def Brick(file, line, script_name, prec, vec, *, bidx='b', dim=None, fold=None, stride=None):
        """Generate code for a ``Brick`` layout and store it in ``source[line - 1]``.

        Nothing happens unless ``file`` equals the input file name ``src``.

        Args:
            file: file name the directive is attributed to.
            line: 1-based line number of the directive in the input file.
            script_name: path of a Python script that, when executed, defines
                a name ``STENCIL``; that value is handed to the code generator.
            prec: precision name, translated by ``get_prec``.
            vec: backend name, translated by ``get_backend``; 'OPENCL' and
                'SYCL' additionally switch the layout's ``cstruct`` flag on.
            bidx: value passed to the layout as ``brick_idx``.
            dim: sequence of dimensions; reversed before use.
            fold: a tuple, or a single value that is wrapped in a list;
                reversed before use.
            stride: optional sequence of strides; reversed, then padded with
                1s up to the length of ``dim``.
        """
        if file != src:
            return

        print(file, line)
        prec_code = get_prec(prec)
        backend = get_backend(vec)

        rdim = list(reversed(dim))
        rfold = list(reversed(fold if isinstance(fold, tuple) else [fold]))
        rstride = [] if stride is None else list(reversed(stride))
        rstride = rstride + [1] * (len(rdim) - len(rstride))

        # Imported under an alias so the layout class is not confused with
        # this function of the same name.
        from st.codegen.backend import Brick as BrickLayout
        use_cstruct = vec == 'OPENCL' or vec == 'SYCL'
        layout = BrickLayout(dim=rdim, fold=rfold, prec=prec_code, brick_idx=bidx,
                             cstruct=use_cstruct)
        from st.codegen.base import CodeGen
        cg = CodeGen(backend, layout, dag_msize=arg.msize, scatter_thres=arg.thres, unroll=arg.unroll,
                     dimsplit=arg.dsplit, stride=rstride, K=arg.K, klimit=arg.limit)

        # Run the stencil script in a fresh namespace and pick up its STENCIL.
        with open(script_name, 'rt') as script_file:
            script = script_file.read()
        namespace = {}
        exec(script, {}, namespace)

        from io import StringIO
        generated = StringIO()
        cg.gencode(namespace['STENCIL'], generated)

        # Bracket the generated text with '# <n> "<name>" <flag>' lines; the
        # trailing one points back at the directive's line in the input file.
        # NOTE(review): these look like preprocessor linemarkers - confirm.
        dim_label = "x".join(str(d) for d in reversed(rdim))
        fold_label = "x".join(str(f) for f in reversed(rfold))
        header = '# 1 "VSBrick-{}-{}-{}-{}" 1\n'.format(os.path.basename(script_name), vec,
                                                         dim_label, fold_label)
        footer = '# {} "{}" 2\n'.format(line, abssrc)
        source[line - 1] = header + generated.getvalue() + footer


    # Handle the collected directives in reverse order of appearance.  Each
    # one is expected to read
    #     #pragma vecscatter <op> <call>
    # Parsing is anchored on the marker and splits on any whitespace, so
    # indentation or repeated spaces do not shift the fields.
    for directive in reversed(loc):
        _, _, rest = directive.partition('#pragma vecscatter')
        fields = rest.split(None, 1)
        if len(fields) < 2:
            # Missing <op> or <call>: report it instead of failing with an
            # IndexError/SyntaxError that hides the offending line.
            print('Ignoring malformed vecscatter pragma: {}'.format(directive.strip()), file=sys.stderr)
            continue
        op, call = fields
        if op == "Scatter":
            # NOTE(review): `call` comes straight from the preprocessed input
            # and is evaluated as a Python expression with this module's
            # globals - only run this tool on trusted sources.
            eval(call)

    # Return to the starting directory before opening `dst`, so a relative
    # output path resolves against it.
    os.chdir(orig_dir)
    with open(dst, 'wt') as f:
        f.write('# 1 "{}"\n'.format(abssrc) + '\n'.join(source))
