// Copyright 2024
//
// For a comprehensive list of the developers that contributed to these codes
// see the UK-AMOR website.
//
// This file is part of UKRmol-out (UKRmol+ suite).
//
//     UKRmol-out is free software: you can redistribute it and/or modify
//     it under the terms of the GNU General Public License as published by
//     the Free Software Foundation, either version 3 of the License, or
//     (at your option) any later version.
//
//     UKRmol-out is distributed in the hope that it will be useful,
//     but WITHOUT ANY WARRANTY; without even the implied warranty of
//     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
//     GNU General Public License for more details.
//
//     You should have received a copy of the GNU General Public License
//     along with  UKRmol-out (in source/COPYING). Alternatively, you can also visit
//     <https://www.gnu.org/licenses/>.
//

#define CL_TARGET_OPENCL_VERSION 120

#include <complex.h>
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>

#ifdef useclblast
#include <CL/cl.h>
#include <clblast_c.h>
#endif

// Capacity of the character buffers that receive OpenCL platform/device descriptions in initialize_cl.
// NOTE: a file-scope `const int` is not an integer constant expression in C, so the arrays sized by
// max_name, max_platforms and max_devices inside initialize_cl are variable-length arrays.
const int max_name = 1024;
// Maximal number of OpenCL platforms queried by initialize_cl.
const int max_platforms = 10;
// Maximal number of devices (of the selected platform) queried by initialize_cl.
const int max_devices = 10;
// Fill pattern used by residr_cl to zero device buffers via clEnqueueFillBuffer.
const int zero = 0;
// Granularity to which matrix dimensions are padded by pad(). Presumably chosen as a multiple of the
// preferred CLBlast work-group size on the target devices -- TODO confirm for new hardware.
const int vwidth = 32;

// Indices of the selected OpenCL platform and device; overwritten by initialize_cl from its
// arguments or from the OCL_PLATFORM / OCL_DEVICE environment variables.
int iplatform = 0;
int idevice = 0;
// Set to 1 by a successful initialize_cl, reset to 0 by finalize_cl.
int initialized = 0;

#ifdef useclblast
// OpenCL context and command queue of the selected device, created by initialize_cl and released by finalize_cl.
cl_context context;
cl_command_queue queue;
#endif

// These functions are defined in "linalg_cl.f90". We need to pass outputs back
// through Fortran, because C and Fortran have separate I/O buffers and writing
// from both results in unpredictable order of outputs.
// They receive a message buffer and a pointer to its length; presumably the Fortran side
// reads exactly *length characters (no terminator needed) -- confirm against linalg_cl.f90.
void f_print_info(const char* str, int* length);
void f_print_error(const char* str, int* length);

/**
 * \brief   Print info message
 * \authors J Benda
 * \date    2025
 *
 * This is a replacement for the standard `printf`, which composes the message string
 * and passes it to a Fortran function that respects current Fortran I/O units and
 * their buffers. Messages longer than 1023 characters are truncated; if formatting
 * fails, an empty message is passed on.
 *
 * \param[in] format  printf-style format string followed by its arguments.
 */
void print_info(const char* format, ...)
{
    char buffer[1024];
    va_list list;

    va_start(list, format);
    int length = vsnprintf(buffer, sizeof(buffer), format, list);
    va_end(list);

    // vsnprintf returns the length the full message WOULD have (or a negative value on failure),
    // not the number of characters actually stored; never let the callee read past the buffer
    if (length < 0)
    {
        buffer[0] = '\0';
        length = 0;
    }
    else if ((size_t)length >= sizeof(buffer))
    {
        length = (int)sizeof(buffer) - 1;
    }

    f_print_info(buffer, &length);
}

/**
 * \brief   Print error message
 * \authors J Benda
 * \date    2025
 *
 * This is a replacement for `fprintf(stderr, ...)`, which composes the message string
 * and passes it to a Fortran function that respects current Fortran I/O units and
 * their buffers. Messages longer than 1023 characters are truncated; if formatting
 * fails, an empty message is passed on.
 *
 * \param[in] format  printf-style format string followed by its arguments.
 */
void print_error(const char* format, ...)
{
    char buffer[1024];
    va_list list;

    va_start(list, format);
    int length = vsnprintf(buffer, sizeof(buffer), format, list);
    va_end(list);

    // vsnprintf returns the length the full message WOULD have (or a negative value on failure),
    // not the number of characters actually stored; never let the callee read past the buffer
    if (length < 0)
    {
        buffer[0] = '\0';
        length = 0;
    }
    else if ((size_t)length >= sizeof(buffer))
    {
        length = (int)sizeof(buffer) - 1;
    }

    f_print_error(buffer, &length);
}

/**
 * \brief   Check if OpenCL has been initialized
 * \authors J Benda
 * \date    2024
 *
 * Return 1 if OpenCL has been initialized by initialize_cl (and not released since by
 * finalize_cl), 0 otherwise.
 */
int is_initialized_cl (void)
{
    return initialized;
}

/**
 * \brief   Initialize OpenCL
 * \authors J Benda
 * \date    2024
 *
 * Initialize OpenCL, select platform and device, create queue and context. The platform and device
 * index can be given explicitly by the non-negative arguments `platform` and `device`. A negative
 * argument means that the index is taken from the environment variable OCL_PLATFORM or OCL_DEVICE,
 * respectively; if that is not set either, the previously selected index is kept (initially 0, i.e.
 * the first platform and its first device).
 *
 * Does nothing if OpenCL is already initialized. Aborts with a diagnostic when no platform/device is
 * available, when the selected index is out of range, or when the context or queue cannot be created.
 *
 * \param[in] platform  Index of the OpenCL platform, or negative to use OCL_PLATFORM / the current default.
 * \param[in] device    Index of the device within the platform, or negative to use OCL_DEVICE / the current default.
 */
void initialize_cl (int platform, int device)
{
    if (is_initialized_cl())
        return;

#ifdef useclblast
    char platform_name[max_name];
    char platform_vendor[max_name];
    char platform_version[max_name];
    char device_name[max_name];
    char device_vendor[max_name];

    const char* env;

    cl_platform_id platforms[max_platforms];
    cl_device_id devices[max_devices];
    cl_uint nplatforms = 0, ndevices = 0;
    cl_int err;

    // choose the OpenCL platform
    if (platform >= 0)
        iplatform = platform;
    else if ((env = getenv("OCL_PLATFORM")) != NULL)
        iplatform = atoi(env);

    // choose the OpenCL device
    if (device >= 0)
        idevice = device;
    else if ((env = getenv("OCL_DEVICE")) != NULL)
        idevice = atoi(env);

    // enumerate platforms; the reported count is the total number available, which may exceed the array capacity
    if ((err = clGetPlatformIDs(max_platforms, platforms, &nplatforms)) != CL_SUCCESS || nplatforms == 0)
    {
        print_error("initialize_cl: error: No OpenCL platform found (error %d)", err);
        abort();
    }
    if (nplatforms > (cl_uint)max_platforms)
        nplatforms = (cl_uint)max_platforms;
    if (iplatform < 0 || (cl_uint)iplatform >= nplatforms)
    {
        print_error("initialize_cl: error: Invalid OpenCL platform index %d (%u platform(s) available)", iplatform, (unsigned)nplatforms);
        abort();
    }

    // enumerate devices of the selected platform (same capacity caveat as above)
    if ((err = clGetDeviceIDs(platforms[iplatform], CL_DEVICE_TYPE_ALL, max_devices, devices, &ndevices)) != CL_SUCCESS || ndevices == 0)
    {
        print_error("initialize_cl: error: No OpenCL device found on platform %d (error %d)", iplatform, err);
        abort();
    }
    if (ndevices > (cl_uint)max_devices)
        ndevices = (cl_uint)max_devices;
    if (idevice < 0 || (cl_uint)idevice >= ndevices)
    {
        print_error("initialize_cl: error: Invalid OpenCL device index %d (%u device(s) available)", idevice, (unsigned)ndevices);
        abort();
    }

    // descriptive strings are informational only; leave them empty if a query fails
    platform_name[0] = platform_vendor[0] = platform_version[0] = device_name[0] = device_vendor[0] = '\0';
    clGetPlatformInfo(platforms[iplatform], CL_PLATFORM_NAME, sizeof(platform_name), platform_name, NULL);
    clGetPlatformInfo(platforms[iplatform], CL_PLATFORM_VENDOR, sizeof(platform_vendor), platform_vendor, NULL);
    clGetPlatformInfo(platforms[iplatform], CL_PLATFORM_VERSION, sizeof(platform_version), platform_version, NULL);
    clGetDeviceInfo(devices[idevice], CL_DEVICE_NAME, sizeof(device_name), device_name, NULL);
    clGetDeviceInfo(devices[idevice], CL_DEVICE_VENDOR, sizeof(device_vendor), device_vendor, NULL);

    print_info("");
    print_info("Using CL platform %d (%s; %s; %s)", iplatform, platform_name, platform_vendor, platform_version);
    print_info("Using CL device %d (%s; %s)", idevice, device_name, device_vendor);

    // initialize OpenCL
    if ((context = clCreateContext(NULL, 1, &devices[idevice], NULL, NULL, &err)) == NULL)
    {
        print_error("initialize_cl: error: Failed to create OpenCL context (error %d)", err);
        abort();
    }
    if ((queue = clCreateCommandQueue(context, devices[idevice], 0, &err)) == NULL)
    {
        print_error("initialize_cl: error: Failed to create OpenCL command queue (error %d)", err);
        abort();
    }

    initialized = 1;
#else
    print_error("initialize_cl: error: CL support not compiled in (missing -Duseclblast)");
    abort();
#endif
}


/**
 * \brief   Finalize OpenCL
 * \authors J Benda
 * \date    2024
 *
 * Release the OpenCL queue and context created by initialize_cl and mark OpenCL as not initialized,
 * so that initialize_cl may be called again. Does nothing if OpenCL is not initialized.
 */
void finalize_cl (void)
{
    if (!is_initialized_cl())
        return;

#ifdef useclblast
    // release the queue, then the context it was created on
    clReleaseCommandQueue(queue);
    clReleaseContext(context);

    // do not leave dangling handles in the globals
    queue = NULL;
    context = NULL;

    initialized = 0;
#else
    print_error("finalize_cl: error: CL support not compiled in (missing -Duseclblast)");
    abort();
#endif
}


/**
 * \brief   Round up to nearest multiple of vwidth
 * \authors J Benda
 * \date    2024
 *
 * Return the smallest multiple of `vwidth` (the padding granularity used for device matrices)
 * that is not smaller than `m`; with vwidth = 32 this maps 1 -> 32, 32 -> 32 and 33 -> 64.
 * Intended for non-negative `m`.
 */
int pad (int m)
{
    const int nchunks = (m + vwidth - 1) / vwidth;   // ceil(m / vwidth) for m >= 0
    return vwidth * nchunks;
}


#ifdef useclblast
// Resources of residr_cl that persist between its Initialize, Execute and Finalize stages.
static struct
{
    cl_mem *wamp;       // device buffers with the zero-padded row-blocks of wamp, `nblocks` entries
    cl_mem ampae;       // device scratch buffer: one block of wamp with columns scaled by the poles
    cl_mem rmat;        // device scratch buffer: one nrows2×nrows2 block of the R-matrix
    double *rmat2;      // host copy of one R-matrix block
    double *poles;      // host array of 1/(epole - etotr), `nstat` entries
    size_t *offsets;    // column offsets within a padded block, `nstat` entries
    size_t nblocks;     // number of allocated wamp blocks; 0 when nothing is allocated
    int nchan;          // number of channels the buffers were created for
    int nstat;          // number of poles the buffers were created for
} residr_state;

// Allocate `size` bytes of host memory or abort with a diagnostic; zero-size requests still yield a valid pointer.
static void* residr_cl_alloc (size_t size, const char* what)
{
    void* ptr = malloc(size > 0 ? size : 1);

    if (ptr == NULL)
    {
        print_error("residr_cl: error: Failed to allocate %zu bytes for %s", size, what);
        abort();
    }

    return ptr;
}

// Release all resources recorded in residr_state and reset it; does nothing when nothing is allocated.
static void residr_cl_release (void)
{
    for (size_t iblock = 0; iblock < residr_state.nblocks; iblock++)
        clReleaseMemObject(residr_state.wamp[iblock]);

    if (residr_state.ampae != NULL)
        clReleaseMemObject(residr_state.ampae);
    if (residr_state.rmat != NULL)
        clReleaseMemObject(residr_state.rmat);

    free(residr_state.wamp);
    free(residr_state.rmat2);
    free(residr_state.poles);
    free(residr_state.offsets);

    residr_state.wamp = NULL;
    residr_state.ampae = NULL;
    residr_state.rmat = NULL;
    residr_state.rmat2 = NULL;
    residr_state.poles = NULL;
    residr_state.offsets = NULL;
    residr_state.nblocks = 0;
}
#endif

/**
 * \brief   Calculate R-matrix
 * \authors J Benda
 * \date    2024 - 2025
 *
 * Evaluate R-matrix from the defining formula w [E - e]^{-1} wT. See mpi_rsolve/residr.
 *
 * CLBlast matrix multiplication kernel prefers the matrices A, B and C to be stored in a specific way to avoid
 * transposing them into an intermediate buffer. Luckily, the UKRmol+ layout of wamp is compatible with this, and
 * the layout of rmat is irrelevant, because it is a symmetric matrix. However, dimensions of all three matrices
 * should also be multiples of the work-group size, which is a device-dependent constant. To conform to this
 * storage requirement we extend and pad (see \ref pad) all participating matrices by zeros. This is done internally
 * in the initialization step and does not affect inputs or outputs of this function.
 *
 * Additionally, some GPU drivers disallow use of arbitrarily large buffers. Frequently the buffer size is limited
 * to 2 GiB. This subroutine divides the boundary amplitude matrix into multiple buffers, so that the usability of this
 * routine is only restricted by the total available memory and not by the maximal buffer size.
 *
 * The buffers created by the Initialize stage persist until the Finalize stage. The Execute stage may be called any
 * number of times in between; it must be given the same `nchan` and `nstat` as Initialize, and it uses only `compress`,
 * `alpha`, `epole`, `etotr` and `rmat` of the remaining arguments. Calling Initialize again without Finalize releases
 * the previous buffers first. Calling Execute before Initialize, or passing an unknown stage, aborts.
 *
 * \param[in] stage     Which stage to perform (0 = initialize OpenCL buffers, 1 = execute, 2 = release buffers)
 * \param[in] nchan     Number of outer region partial wave channels.
 * \param[in] nstat     Number of R-matrix poles.
 * \param[in] compress  Whether to return only triangular part of the R-matrix.
 * \param[in] alpha     Scalar multiplication factor to apply to the result.
 * \param[in] epole     Array of R-matrix poles.
 * \param[in] etotr     Total energy of the system (same units as epole).
 * \param[in] wamp      Boundary amplitudes as column-major matrix of dimension nchan×nstat.
 * \param[in] ld        Leading dimension of wamp.
 * \param[out] rmat     Buffer for R-matrix output. When `compress = 0`, it is expected to be a column-major matrix
 *                      of dimension nchan×nchan. When `compress = 1`, it is a linear storage of length nchan*(nchan + 1)/2.
 */
void residr_cl (int stage, int nchan, int nstat, int compress, double alpha, double *epole, double etotr, double *wamp, int ld, double *rmat)
{
#ifdef useclblast
    enum ResidrClStage
    {
        Initialize = 0,
        Execute    = 1,
        Finalize   = 2
    };

    // divide wamp into blocks in the row dimension (nchan), so that none of the blocks is significantly larger than 1 GiB
    size_t nblocks = 1 + (((size_t)nstat*sizeof(double)*(size_t)nchan) >> 30);
    size_t nrows = ((size_t)nchan + nblocks - 1)/nblocks;

    // dimensions rounded up to nearest multiple of vwidth
    size_t nrows2 = pad(nrows);
    size_t nstat2 = pad(nstat);

    if (!is_initialized_cl())
    {
        print_error("residr_cl: error: call initialize_cl first!");
        abort();
    }

    switch (stage)
    {
        case Initialize:
        {
            cl_int err;

            // a repeated initialization replaces (rather than leaks) the buffers of the previous one
            residr_cl_release();

            double* wamp2 = residr_cl_alloc(nrows2*nstat2*sizeof(double), "WAMP2");

            residr_state.wamp = residr_cl_alloc(nblocks*sizeof(cl_mem), "WAMP");

            for (size_t iblock = 0; iblock < nblocks; iblock++)
            {
                // copy one row-block of wamp to the zero-padded column-major storage of dimension nrows2×nstat2
                #pragma omp parallel for
                for (size_t i = 0; i < nstat2; i++)
                    for (size_t j = 0; j < nrows2; j++)
                        wamp2[i*nrows2 + j] = (i < (size_t)nstat && j < nrows && iblock*nrows + j < (size_t)nchan) ? wamp[i*(size_t)ld + iblock*nrows + j] : 0;

                // create device buffers
                if ((residr_state.wamp[iblock] = clCreateBuffer(context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, nrows2*nstat2*sizeof(double), wamp2, &err)) == NULL)
                {
                    print_error("residr_cl: clCreateBuffer(WAMP) failed (error %d)", err);
                    abort();
                }
            }
            residr_state.nblocks = nblocks;

            if ((residr_state.ampae = clCreateBuffer(context, CL_MEM_READ_WRITE, nrows2*nstat2*sizeof(double), NULL, &err)) == NULL)
            {
                print_error("residr_cl: clCreateBuffer(AMPAE) failed (error %d)", err);
                abort();
            }
            if ((residr_state.rmat = clCreateBuffer(context, CL_MEM_READ_WRITE, nrows2*nrows2*sizeof(double), NULL, &err)) == NULL)
            {
                print_error("residr_cl: clCreateBuffer(RMAT) failed (error %d)", err);
                abort();
            }

            // host buffers: downloaded R-matrix block, and per-pole scaling factors and column offsets
            // (heap rather than stack, because nstat can be large)
            residr_state.rmat2 = residr_cl_alloc(nrows2*nrows2*sizeof(double), "RMAT2");
            residr_state.poles = residr_cl_alloc((size_t)nstat*sizeof(double), "POLES");
            residr_state.offsets = residr_cl_alloc((size_t)nstat*sizeof(size_t), "OFFSETS");

            for (int i = 0; i < nstat; i++)
                residr_state.offsets[i] = i*nrows2;

            residr_state.nchan = nchan;
            residr_state.nstat = nstat;

            // release host copy of wamp2
            free(wamp2);

            break;
        }

        case Execute:
        {
            CLBlastStatusCode status;
            cl_int err;

            if (residr_state.nblocks == 0)
            {
                print_error("residr_cl: error: call residr_cl with stage = 0 (initialize) first!");
                abort();
            }
            if (nchan != residr_state.nchan || nstat != residr_state.nstat)
            {
                print_error("residr_cl: error: dimensions nchan = %d, nstat = %d differ from the initialized ones (%d, %d)",
                            nchan, nstat, residr_state.nchan, residr_state.nstat);
                abort();
            }

            cl_mem* clWamp = residr_state.wamp;
            double* poles = residr_state.poles;
            double* rmat2 = residr_state.rmat2;

            // calculate R-matrix pole factors
            #pragma omp parallel for
            for (int i = 0; i < nstat; i++)
                poles[i] = 1/(epole[i] - etotr);

            for (size_t jblock = 0; jblock < nblocks; jblock++)
            {
                // fill ampae with zeros (an all-zero-bytes pattern)
                if ((err = clEnqueueFillBuffer(queue, residr_state.ampae, &zero, sizeof(zero), 0, nrows2*nstat2*sizeof(double), 0, NULL, NULL)) != CL_SUCCESS)
                {
                    print_error("residr_cl: clEnqueueFillBuffer failed (error %d)", err);
                    abort();
                }

                // add poles*wamp to ampae, column by column
                if ((status = CLBlastDaxpyBatched(nrows2, poles, clWamp[jblock], residr_state.offsets, 1, residr_state.ampae, residr_state.offsets, 1, nstat, &queue, NULL)) != CLBlastSuccess)
                {
                    print_error("residr_cl: CLBlastDaxpyBatched failed (error %d)", status);
                    abort();
                }

                for (size_t iblock = 0; iblock <= jblock; iblock++)
                {
                    // multiply wamp(iblock) * ampae(jblock)^T -> rmat
                    if ((status = CLBlastDgemm(CLBlastLayoutColMajor, CLBlastTransposeNo, CLBlastTransposeYes, nrows2, nrows2, nstat2, alpha, clWamp[iblock], 0, nrows2, residr_state.ampae, 0, nrows2, 0., residr_state.rmat, 0, nrows2, &queue, NULL)) != CLBlastSuccess)
                    {
                        print_error("residr_cl: CLBlastDgemm failed (error %d)", status);
                        abort();
                    }

                    // download the resulting R-matrix block
                    if ((err = clEnqueueReadBuffer(queue, residr_state.rmat, CL_TRUE, 0, nrows2*nrows2*sizeof(double), rmat2, 0, NULL, NULL)) != CL_SUCCESS)
                    {
                        print_error("residr_cl: clEnqueueReadBuffer failed (error %d)", err);
                        abort();
                    }

                    // wait for completion of the pipeline
                    if ((err = clFinish(queue)) != CL_SUCCESS)
                    {
                        print_error("residr_cl: clFinish failed (error %d)", err);
                        abort();
                    }

                    // global channel ranges of this pair of blocks; clamped to nchan, because with many blocks
                    // even a non-final block may extend past the last channel
                    size_t imin = iblock*nrows, imax = (iblock + 1)*nrows < (size_t)nchan ? (iblock + 1)*nrows : (size_t)nchan;
                    size_t jmin = jblock*nrows, jmax = (jblock + 1)*nrows < (size_t)nchan ? (jblock + 1)*nrows : (size_t)nchan;

                    // extract triangle from the symmetric R-matrix (or copy full)
                    if (compress)
                    {
                        #pragma omp parallel for schedule(dynamic,1)
                        for (size_t j = jmin; j < jmax; j++)
                            for (size_t i = imin; i < imax && i <= j; i++)
                                rmat[j*(j+1)/2 + i] = rmat2[(j - jmin)*nrows2 + (i - imin)];
                    }
                    else
                    {
                        #pragma omp parallel for
                        for (size_t j = jmin; j < jmax; j++)
                            for (size_t i = imin; i < imax && i <= j; i++)
                                rmat[i*(size_t)nchan + j] = rmat[j*(size_t)nchan + i] = rmat2[(j - jmin)*nrows2 + (i - imin)];
                    }
                }
            }

            break;
        }

        case Finalize:
        {
            // uses the recorded block count, so it is safe even if nothing was initialized
            residr_cl_release();
            break;
        }

        default:
        {
            print_error("residr_cl: error: unknown stage %d", stage);
            abort();
        }
    }
#else
    print_error("residr_cl: error: CL support not compiled in (missing -Duseclblast)");
    abort();
#endif
}


/**
 * \brief   Real matrix-matrix multiplication (OpenCL)
 * \authors J Benda
 * \date    2024
 *
 * Use CLBlast to multiply C = alpha*op(A)*op(B) + beta*C with column-major matrices. The signature is of the standard
 * dgemm except that `transa` and `transb` are integers: 0 corresponds to 'N', any non-zero value to 'T'.
 *
 * Leading dimensions larger than the number of stored rows are supported: only the elements a column-major BLAS
 * routine would reference are transferred to and from the device. Empty results (m = 0 or n = 0) return immediately
 * and k = 0 reduces to C = beta*C on the host. Requires a prior call to initialize_cl; aborts on invalid dimensions
 * or on any OpenCL/CLBlast failure.
 */
void dgemm_cl (int transa, int transb, int m, int n, int k, double alpha, double *A, int lda, double *B, int ldb, double beta, double *C, int ldc)
{
#ifdef useclblast
    cl_mem clA, clB, clC;
    cl_int err;

    CLBlastStatusCode status;

    if (!is_initialized_cl())
    {
        print_error("dgemm_cl: error: call initialize_cl first!");
        abort();
    }

    // number of rows of A and B as stored in memory (i.e. before applying op)
    int rowsA = transa ? k : m;
    int rowsB = transb ? n : k;

    // the same argument requirements as the reference BLAS (leading dimensions at least max(1, rows))
    if (m < 0 || n < 0 || k < 0 || lda < 1 || lda < rowsA || ldb < 1 || ldb < rowsB || ldc < 1 || ldc < m)
    {
        print_error("dgemm_cl: error: invalid arguments (m = %d, n = %d, k = %d, lda = %d, ldb = %d, ldc = %d)", m, n, k, lda, ldb, ldc);
        abort();
    }

    // nothing to compute for an empty result
    if (m == 0 || n == 0)
        return;

    // empty inner dimension: the product vanishes and C = beta*C (zero-size device buffers are not allowed)
    if (k == 0)
    {
        for (int j = 0; j < n; j++)
            for (int i = 0; i < m; i++)
                C[(size_t)j*ldc + i] = (beta == 0) ? 0 : beta*C[(size_t)j*ldc + i];
        return;
    }

    // number of elements spanned by each column-major matrix: all columns but the last one are `ld` elements long
    int colsA = transa ? m : k;
    int colsB = transb ? k : n;
    size_t sizeA = (size_t)lda*(size_t)(colsA - 1) + (size_t)rowsA;
    size_t sizeB = (size_t)ldb*(size_t)(colsB - 1) + (size_t)rowsB;
    size_t sizeC = (size_t)ldc*(size_t)(n - 1) + (size_t)m;

    // create device buffers
    if ((clA = clCreateBuffer(context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, sizeA*sizeof(double), A, &err)) == NULL)
    {
        print_error("dgemm_cl: clCreateBuffer(A) failed (error %d)", err);
        abort();
    }
    if ((clB = clCreateBuffer(context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, sizeB*sizeof(double), B, &err)) == NULL)
    {
        print_error("dgemm_cl: clCreateBuffer(B) failed (error %d)", err);
        abort();
    }
    if ((clC = clCreateBuffer(context, CL_MEM_READ_WRITE | CL_MEM_COPY_HOST_PTR, sizeC*sizeof(double), C, &err)) == NULL)
    {
        print_error("dgemm_cl: clCreateBuffer(C) failed (error %d)", err);
        abort();
    }

    CLBlastTranspose ta = transa ? CLBlastTransposeYes : CLBlastTransposeNo;
    CLBlastTranspose tb = transb ? CLBlastTransposeYes : CLBlastTransposeNo;

    // multiply
    if ((status = CLBlastDgemm(CLBlastLayoutColMajor, ta, tb, m, n, k, alpha, clA, 0, lda, clB, 0, ldb, beta, clC, 0, ldc, &queue, NULL)) != CLBlastSuccess)
    {
        print_error("dgemm_cl: CLBlastDgemm failed (error %d)", status);
        abort();
    }

    // download the resulting matrix (padding rows between m and ldc come back unchanged)
    if ((err = clEnqueueReadBuffer(queue, clC, CL_TRUE, 0, sizeC*sizeof(double), C, 0, NULL, NULL)) != CL_SUCCESS)
    {
        print_error("dgemm_cl: clEnqueueReadBuffer failed (error %d)", err);
        abort();
    }

    // wait for completion of the pipeline
    if ((err = clFinish(queue)) != CL_SUCCESS)
    {
        print_error("dgemm_cl: clFinish failed (error %d)", err);
        abort();
    }

    // finalize
    clReleaseMemObject(clA);
    clReleaseMemObject(clB);
    clReleaseMemObject(clC);
#else
    print_error("dgemm_cl: error: CL support not compiled in (missing -Duseclblast)");
    abort();
#endif
}


/**
 * \brief   Complex matrix-matrix multiplication (OpenCL)
 * \authors J Benda
 * \date    2024
 *
 * Use CLBlast to multiply C = alpha*op(A)*op(B) + beta*C with column-major matrices. The signature is of the standard
 * zgemm except that `transa` and `transb` are integers: 0 corresponds to 'N', any non-zero value to 'T'.
 *
 * Leading dimensions larger than the number of stored rows are supported: only the elements a column-major BLAS
 * routine would reference are transferred to and from the device. Empty results (m = 0 or n = 0) return immediately
 * and k = 0 reduces to C = beta*C on the host. Requires a prior call to initialize_cl; aborts on invalid dimensions
 * or on any OpenCL/CLBlast failure.
 */
void zgemm_cl (int transa, int transb, int m, int n, int k, double complex alpha, double complex *A, int lda, double complex *B, int ldb, double complex beta, double complex *C, int ldc)
{
#ifdef useclblast
    cl_mem clA, clB, clC;
    cl_int err;
    cl_double2 calpha, cbeta;

    CLBlastStatusCode status;

    if (!is_initialized_cl())
    {
        print_error("zgemm_cl: error: call initialize_cl first!");
        abort();
    }

    // number of rows of A and B as stored in memory (i.e. before applying op)
    int rowsA = transa ? k : m;
    int rowsB = transb ? n : k;

    // the same argument requirements as the reference BLAS (leading dimensions at least max(1, rows))
    if (m < 0 || n < 0 || k < 0 || lda < 1 || lda < rowsA || ldb < 1 || ldb < rowsB || ldc < 1 || ldc < m)
    {
        print_error("zgemm_cl: error: invalid arguments (m = %d, n = %d, k = %d, lda = %d, ldb = %d, ldc = %d)", m, n, k, lda, ldb, ldc);
        abort();
    }

    // nothing to compute for an empty result
    if (m == 0 || n == 0)
        return;

    // empty inner dimension: the product vanishes and C = beta*C (zero-size device buffers are not allowed)
    if (k == 0)
    {
        for (int j = 0; j < n; j++)
            for (int i = 0; i < m; i++)
                C[(size_t)j*ldc + i] = (beta == 0) ? 0 : beta*C[(size_t)j*ldc + i];
        return;
    }

    // number of elements spanned by each column-major matrix: all columns but the last one are `ld` elements long
    int colsA = transa ? m : k;
    int colsB = transb ? k : n;
    size_t sizeA = (size_t)lda*(size_t)(colsA - 1) + (size_t)rowsA;
    size_t sizeB = (size_t)ldb*(size_t)(colsB - 1) + (size_t)rowsB;
    size_t sizeC = (size_t)ldc*(size_t)(n - 1) + (size_t)m;

    // create device buffers
    if ((clA = clCreateBuffer(context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, sizeA*sizeof(double complex), A, &err)) == NULL)
    {
        print_error("zgemm_cl: clCreateBuffer(A) failed (error %d)", err);
        abort();
    }
    if ((clB = clCreateBuffer(context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, sizeB*sizeof(double complex), B, &err)) == NULL)
    {
        print_error("zgemm_cl: clCreateBuffer(B) failed (error %d)", err);
        abort();
    }
    if ((clC = clCreateBuffer(context, CL_MEM_READ_WRITE | CL_MEM_COPY_HOST_PTR, sizeC*sizeof(double complex), C, &err)) == NULL)
    {
        print_error("zgemm_cl: clCreateBuffer(C) failed (error %d)", err);
        abort();
    }

    CLBlastTranspose ta = transa ? CLBlastTransposeYes : CLBlastTransposeNo;
    CLBlastTranspose tb = transb ? CLBlastTransposeYes : CLBlastTransposeNo;

    // CLBlast takes complex scalars as (real, imaginary) pairs
    calpha.x = creal(alpha);  cbeta.x = creal(beta);
    calpha.y = cimag(alpha);  cbeta.y = cimag(beta);

    // multiply
    if ((status = CLBlastZgemm(CLBlastLayoutColMajor, ta, tb, m, n, k, calpha, clA, 0, lda, clB, 0, ldb, cbeta, clC, 0, ldc, &queue, NULL)) != CLBlastSuccess)
    {
        print_error("zgemm_cl: CLBlastZgemm failed (error %d)", status);
        abort();
    }

    // download the resulting matrix (padding rows between m and ldc come back unchanged)
    if ((err = clEnqueueReadBuffer(queue, clC, CL_TRUE, 0, sizeC*sizeof(double complex), C, 0, NULL, NULL)) != CL_SUCCESS)
    {
        print_error("zgemm_cl: clEnqueueReadBuffer failed (error %d)", err);
        abort();
    }

    // wait for completion of the pipeline
    if ((err = clFinish(queue)) != CL_SUCCESS)
    {
        print_error("zgemm_cl: clFinish failed (error %d)", err);
        abort();
    }

    // finalize
    clReleaseMemObject(clA);
    clReleaseMemObject(clB);
    clReleaseMemObject(clC);
#else
    print_error("zgemm_cl: error: CL support not compiled in (missing -Duseclblast)");
    abort();
#endif
}
