/*
 * MicroHH
 * Copyright (c) 2011-2024 Chiel van Heerwaarden
 * Copyright (c) 2011-2024 Thijs Heus
 * Copyright (c) 2014-2024 Bart van Stratum
 * Copyright (c) 2022-2022 Stijn Heldens
 *
 * This file is part of MicroHH
 *
 * MicroHH is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.

 * MicroHH is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.

 * You should have received a copy of the GNU General Public License
 * along with MicroHH.  If not, see <http://www.gnu.org/licenses/>.
 */

#include <cstdlib>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>
#include <unistd.h>

#ifdef ENABLE_KERNEL_LAUNCHER
#include "kernel_launcher.h"
#include "cuda_launcher.h"

namespace kl = kernel_launcher;

const std::string& home_directory();

// Construct a kernel descriptor from the functor metadata, the functor's
// type, the types of the kernel parameters and the grid layout. The metadata
// and the parameter type list are moved into the members; the functor type
// and the grid layout are copied.
GridKernel::GridKernel(
        GridFunctor meta,
        kl::TypeInfo functor_type,
        std::vector<kl::TypeInfo> param_types,
        Grid_layout grid
)
    : meta(std::move(meta))
    , functor_type(functor_type)
    , param_types(std::move(param_types))
    , grid(grid)
{
}

bool GridKernel::equals(const IKernelDescriptor& that) const {
    if (const GridKernel* g = dynamic_cast<const GridKernel*>(&that)) {
        return g->meta.name == meta.name &&
                g->functor_type == functor_type &&
                g->param_types == param_types &&
                g->grid == grid;
    }

    return false;
}

// Hash of this descriptor, derived from the kernel name and the functor type
// only; the parameter types and grid layout do not contribute to the hash.
size_t GridKernel::hash() const {
    const size_t digest = kl::hash_fields(meta.name, functor_type);
    return digest;
}

// Assemble the kl::KernelBuilder for this grid kernel. This
// - generates the CUDA source of a `kernel` wrapper that includes the file of
//   the functor and forwards all kernel arguments to
//   `cta_execute_tiling_with_edges`,
// - declares the tunable parameters (block size, tile factors, unroll depth,
//   block index permutation, tile contiguity) and their restrictions,
// - passes the grid layout and the derived number of thread blocks to the
//   kernel as preprocessor definitions.
kl::KernelBuilder GridKernel::build() const {
    // `params` is the formal parameter list of the generated kernel
    // ("T0 a0, T1 a1, ..."). `args` is the matching list of forwarded
    // arguments; every entry carries its own leading ", " so the generated
    // call stays well-formed even when there are no parameters at all.
    std::stringstream args;
    std::stringstream params;

    for (size_t i = 0; i < param_types.size(); i++) {
        if (i != 0) {
            params << ", ";
        }

        params << param_types[i].name() << " ";
        if (param_types[i].is_pointer()) {
            params << "__restrict__ ";
        }

        params << "a" << i;
        args << ", a" << i;
    }

    // Source of the generated kernel. All upper-case identifiers are
    // preprocessor symbols that are defined on the builder further below.
    std::string source = R"(
        #include ")" + std::string(meta.file) + R"("
        #include "cuda_tiling.h"

        template <typename F>
        __global__
        __launch_bounds__(BLOCK_SIZE_X * BLOCK_SIZE_Y * BLOCK_SIZE_Z, BLOCKS_PER_SM)
        void kernel()" + params.str() + R"() {
            dim3 num_blocks = {NUM_BLOCKS_X, NUM_BLOCKS_Y, NUM_BLOCKS_Z};
            dim3 block_index = unravel_dim3(blockIdx.x, num_blocks, AXES_PERMUTATION);

            Grid_layout gd = {
                GRID_START_I,
                GRID_END_I,
                GRID_START_J,
                GRID_END_J,
                GRID_START_K,
                GRID_END_K,
                GRID_STRIDE_I,
                GRID_STRIDE_J,
                GRID_STRIDE_K
            };

            using Tiling = StaticTilingStrategy<
                BLOCK_SIZE_X,
                BLOCK_SIZE_Y,
                BLOCK_SIZE_Z,
                TILE_FACTOR_X,
                TILE_FACTOR_Y,
                TILE_FACTOR_Z,
                UNROLL_FACTOR_X,
                UNROLL_FACTOR_Y,
                UNROLL_FACTOR_Z,
                TILE_CONTIGUOUS_X,
                TILE_CONTIGUOUS_Y,
                TILE_CONTIGUOUS_Z
            >;

            cta_execute_tiling_with_edges<EDGE_LEVELS, Tiling>(gd, block_index, F{})" + args.str() + R"();
        }
    )";

    // Candidate thread block sizes; the defaults come from the functor metadata.
    std::vector<uint32_t> block_size_x = {16, 32, 64, 128, 256};
    std::vector<uint32_t> block_size_y = {1, 2, 4, 8, 16};
    std::vector<uint32_t> block_size_z = {1, 2, 4, 8, 16};

    kl::KernelBuilder builder("kernel", kl::KernelSource("kernel.cu", source));
    auto bx = builder.tune("BLOCK_SIZE_X", block_size_x, meta.block_size.x);
    auto by = builder.tune("BLOCK_SIZE_Y", block_size_y, meta.block_size.y);
    auto bz = builder.tune("BLOCK_SIZE_Z", block_size_z, meta.block_size.z);
    auto blocks_per_sm = builder.tune_define("BLOCKS_PER_SM", {1, 2, 3, 4, 5, 6});

    // Number of grid points handled per thread along each axis.
    auto tx = builder.tune_define("TILE_FACTOR_X", {1, 2, 3, 4});
    auto ty = builder.tune_define("TILE_FACTOR_Y", {1, 2, 3, 4});
    auto tz = builder.tune_define("TILE_FACTOR_Z", {1, 2, 3, 4, 8, 16, 32});

    // How many loops to unroll
    // - 0: no unroll
    // - 1: only inner loop
    // - 2: two inner loops
    // - 3: all loops
    auto unroll_depth = builder.tune_define("LOOP_UNROLL_DEPTH", {3, 2, 1, 0});

    // What order to unravel the block index.
    builder.tune_define("AXES_PERMUTATION", {0, 1, 2, 3, 4, 5});

    // Tiling is contiguous or block strided
    auto tile_cont = builder.tune_define("TILE_CONTIGUOUS_YZ", {0, 1});

    // Extent of the iteration space along each axis.
    dim3 problem_size = {
            uint32_t(grid.iend - grid.istart),
            uint32_t(grid.jend - grid.jstart),
            uint32_t(grid.kend - grid.kstart)
    };

    // Number of thread blocks per axis: every block covers
    // (block size * tile factor) points along that axis.
    auto nx = kl::div_ceil(grid.iend - grid.istart, bx * tx);
    auto ny = kl::div_ceil(grid.jend - grid.jstart, by * ty);
    auto nz = kl::div_ceil(grid.kend - grid.kstart, bz * tz);

    // The tuning key is the kernel name, suffixed with the floating point
    // precision of the first float/double (pointer) parameter, if any.
    std::string tuning_key = meta.name;

    for (auto p: param_types) {
        auto base = p.remove_pointer().remove_const();

        if (base == kl::type_of<float>()) {
            tuning_key += "@float";
            break;
        }

        if (base == kl::type_of<double>()) {
            tuning_key += "@double";
            break;
        }
    }

    // The three-dimensional block grid is flattened into a one-dimensional
    // launch; the generated kernel unravels blockIdx.x again.
    builder
        .tuning_key(tuning_key)
        .template_arg(functor_type)
        .problem_size(problem_size)
        .block_size(bx, by, bz)
        .grid_size(nx * ny * nz);

    builder
        .define(bx)
        .define(by)
        .define(bz)
        .define("GRID_START_I", std::to_string(grid.istart))
        .define("GRID_START_J", std::to_string(grid.jstart))
        .define("GRID_START_K", std::to_string(grid.kstart))
        .define("GRID_END_I", std::to_string(grid.iend))
        .define("GRID_END_J", std::to_string(grid.jend))
        .define("GRID_END_K", std::to_string(grid.kend))
        .define("GRID_STRIDE_I", std::to_string(grid.istride))
        .define("GRID_STRIDE_J", std::to_string(grid.jstride))
        .define("GRID_STRIDE_K", std::to_string(grid.kstride))
        .define("NUM_BLOCKS_X", nx)
        .define("NUM_BLOCKS_Y", ny)
        .define("NUM_BLOCKS_Z", nz)
        .define("UNROLL_FACTOR_X", kl::ifelse(unroll_depth >= 1, tx, 1))
        .define("UNROLL_FACTOR_Y", kl::ifelse(unroll_depth >= 2, ty, 1))
        .define("UNROLL_FACTOR_Z", kl::ifelse(unroll_depth >= 3, tz, 1))
        .define("TILE_CONTIGUOUS_X", "0")
        .define("TILE_CONTIGUOUS_Y", tile_cont)
        .define("TILE_CONTIGUOUS_Z", tile_cont)
        .define("EDGE_LEVELS", std::to_string(meta.edge_levels));

    // home_directory() already ends with a '/', so no separator is added.
    builder.compiler_flags(
            "--restrict",
            "-std=c++17",
            "-I" + home_directory() + "include");

    // Restrictions:
    // - Threads per block should not be too small (>= 64) or too big (<= 1024)
    // - Threads per SM should not be too big (<= 2048)
    // - Items per thread should not be too many (<= 32)
    auto threads_per_block = bx * by * bz;
    builder.restriction(threads_per_block >= 64 && threads_per_block <= 1024);
    builder.restriction(threads_per_block * blocks_per_sm <= 2048);
    builder.restriction(tx * ty * tz <= 32);

    return builder;
}

// Determine the MicroHH home directory. The candidates are tried in order:
//   1. the environment variable MICROHH_HOME (when set and non-empty),
//   2. the part of this source file's path (__FILE__) before the last "/src/",
//   3. the current working directory,
//   4. ".".
// The returned path always ends with a '/'. A warning is written to std::cerr
// whenever the path did not come from MICROHH_HOME.
std::string find_home_directory() {
    const char* env = std::getenv("MICROHH_HOME");

    // An empty MICROHH_HOME gives no usable path, so it is handled like an
    // unset variable: fall back to a guess and warn about it.
    const bool from_env = (env != nullptr && env[0] != '\0');

    std::string result;

    // Check environment KEY
    if (from_env) {
        result = env;
    }

    // No success, try from __FILE__
    if (result.empty()) {
        std::string file = __FILE__;
        size_t index = file.rfind("/src/");

        if (index != std::string::npos) {
            result = file.substr(0, index);
        }
    }

    // No success, try "cwd"?
    if (result.empty()) {
        char buffer[1024] = {0};

        if (getcwd(buffer, sizeof(buffer)) != nullptr) {
            result = buffer;
        }
    }

    // No success, try "."
    if (result.empty()) {
        result = ".";
    }

    // Add trailing slash (result is never empty at this point)
    if (result.back() != '/') {
        result += "/";
    }

    if (!from_env) {
        std::cerr << "WARNING: environment variable MICROHH_HOME is not set, best guess: " << result << "\n";
    }

    return result;
}

// Returns the MicroHH home directory. The lookup through
// find_home_directory() runs only the first time this function is called;
// its result is kept in a function-local static and returned by reference.
const std::string& home_directory() {
    static const std::string cached_home = find_home_directory();
    return cached_home;
}

// Look up the given grid kernel in the default Kernel Launcher registry and
// launch it on `stream` with the given arguments.
//
// Returns true if the launch call completed without throwing. If a
// std::exception is thrown, a warning is logged, false is returned and every
// later call returns false immediately (fallback mode).
//
// On the first call, the "wisdom" and "captures" directories below
// home_directory() are registered with Kernel Launcher.
//
// NOTE(review): the two static flags are plain bools with no synchronization
// visible here - confirm this is only called from a single thread.
bool launch_kernel(
        cudaStream_t stream,
        GridKernel grid_kernel,
        const std::vector<kl::KernelArg>& args
) {
    static bool initialized = false;
    static bool error_state = false;

    // If an exception was thrown at some point, just return false immediately.
    if (error_state) {
        return false;
    }

    // Remember the name for the error message before `grid_kernel` is moved
    // from below; a moved-from object is not guaranteed to still hold it.
    const auto kernel_name = grid_kernel.meta.name;

    kl::KernelDescriptor kernel = std::move(grid_kernel);

    try {
        if (!initialized) {
            // Set the flag first so the directories are registered at most once,
            // even when one of the calls below throws.
            initialized = true;
            kernel_launcher::append_global_wisdom_directory(home_directory() + "wisdom");
            kernel_launcher::set_global_capture_directory(home_directory() + "captures");
        }

        kernel_launcher::default_registry()
                .lookup(kernel)
                .launch_args(stream, args);
        return true;
    } catch (const std::exception &e) {
        kl::log_warning() << "error occurred while compiling the following kernel: " <<
                kernel_name << ":" << "\n"  << e.what() << std::endl;
        kl::log_warning() << "CUDA dynamic kernel compilation is now disabled and the application is in FALLBACK mode. "
                             "No more kernels will be executed using Kernel Launcher." << std::endl;

        error_state = true;
        return false;
    }
}
#endif
