/**********************************************************************************
#* This file is distributed under the terms of the GNU General Public License
#* as published by the Free Software Foundation, either version 3 of
#* the License, or (at your option) any later version.
#* See the file LICENSE in the root directory of this distribution
#* or <http://www.gnu.org/licenses/>.
#*
#* This file uses Libint library (https://github.com/evaleev/libint) for computing 
#* the atomic orbitals overlaps for two different geometries. This can be used to
#* compute the molecular orbital overlaps for non-adiabatic molecular dynamics.
#* The code uses Pybind11 to interface the code with Python.
#*
#* This code mainly uses the test files in 
#* (https://github.com/evaleev/libint/tree/master/tests/hartree-fock) with 
#* modifications to be used for computing the atomic orbitals overlap.
*********************************************************************************/

// standard C++ headers
#include <algorithm>
#include <chrono>
#include <cmath>
#include <cstddef>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <thread>
#include <vector>

// Eigen matrix algebra library
#include <Eigen/Dense>
#include <Eigen/Eigenvalues>

// Pybind11 for interface with Python
#include <pybind11.h>
#include <pybind11/eigen.h>
#include <numpy.h>
#include <stl.h>
#include <complex.h>
#include <functional.h>
#include <chrono.h>

// Libint Gaussian integrals library
#include <libint2.hpp>

#if !LIBINT2_CONSTEXPR_STATICS
#  include <libint2/statics_definition.h>
#endif

// OpenMP library
#if defined(_OPENMP)
#include <omp.h>
#endif


namespace py = pybind11;

// Scalar type exported by libint2 (libint2::scalar_type).
using real_t = libint2::scalar_type;

// Dynamically sized, row-major Eigen matrix of real_t. When such a matrix is
// returned through Pybind11, the result arrives in Python as a numpy array.
using Matrix =
    Eigen::Matrix<real_t, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
// Plain record describing one atom: its atomic number and the three
// Cartesian components of its position.
struct Atom {
    int atomic_number;
    double x;
    double y;
    double z;
};

// Forward declarations of helpers that are defined further down in this file.

// Total number of basis functions contained in a list of shells.
size_t nbasis(const std::vector<libint2::Shell>& shells);

// For every shell, the index of its first basis function.
std::vector<size_t> map_shell_to_basis_function(
    const std::vector<libint2::Shell>& shells);

// One-body integrals between two lists of shells (serial version).
Matrix compute_1body_ints(const std::vector<libint2::Shell>& shells_1,
                          const std::vector<libint2::Shell>& shells_2,
                          libint2::Operator obtype);

// One-body integrals between two lists of shells (threaded version; the work
// is dispatched through libint2::parallel_do, i.e. OpenMP when it is enabled
// and C++11 threads otherwise).
Matrix compute_1body_ints_parallel(
    const std::vector<libint2::Shell>& shells_1,
    const std::vector<libint2::Shell>& shells_2,
    libint2::Operator obtype);

// Helpers for performing parallel calculations (adopted from Libint test files).
// They are added to namespace libint2, next to the library's own names.
namespace libint2 {
   // Number of worker threads. It is read by parallel_do() on the C++11-thread
   // path and by callers that size per-thread resources (e.g. one engine per
   // thread). It defaults to 1 so that code running before compute_overlaps()
   // has set it still gets one valid worker; the previous implicit value was 0.
   int nthreads = 1;

   // Invokes lambda(thread_id) once per worker.
   //
   // With OpenMP the call is made inside an `omp parallel` region and the id
   // comes from omp_get_thread_num(); the number of invocations is therefore
   // the OpenMP team size, not necessarily `nthreads`.
   // Without OpenMP, ids 0 .. nthreads-2 run on newly started std::thread
   // objects (each receives its own copy of `lambda`), the last id runs on the
   // calling thread, and all started threads are joined before returning.
   // Nothing is invoked on that path when nthreads <= 0.
   template <typename Lambda>
   void parallel_do(Lambda& lambda) {
   #ifdef _OPENMP
   #pragma omp parallel
   {
     auto thread_id = omp_get_thread_num();
     lambda(thread_id);
   }
   #else  // use C++11 threads
   std::vector<std::thread> threads;
   // `<` rather than `!=` so that a negative count cannot run away.
   for (int thread_id = 0; thread_id < libint2::nthreads; ++thread_id) {
     if (thread_id != libint2::nthreads - 1)
       threads.emplace_back(lambda, thread_id);
     else
       lambda(thread_id);  // the calling thread takes the last share
   }
   for (auto& worker : threads)
     worker.join();
   #endif
   }
}

// Stub entry point: it ignores its arguments, does nothing and returns 0.
// The original author was unsure whether it is necessary and kept it.
// NOTE(review): this file defines a Pybind11 module, which is presumably built
// as a shared library and would then not need main() -- confirm against the
// build setup before removing it.

int main(int argc, char *argv[]) {
    return 0;
}

/**
 * The main function for computing the AO overlaps; this is the entry point
 * used from Python.
 *
 * Thread count: the value starts as number_of_threads. If the environment
 * variable LIBINT_NUM_THREADS is set to a non-empty string, the value parsed
 * from it replaces the argument. Whatever the source, a count outside
 * [1, 65536] (including an environment value that fails to parse) falls back
 * to 1. The final count is stored in libint2::nthreads and, when compiled
 * with OpenMP, handed to omp_set_num_threads().
 *
 * Libint is initialized before the integrals are evaluated and finalized
 * when the function exits, including when an exception propagates.
 *
 * @param shells_1           first list of shells
 * @param shells_2           second list of shells
 * @param number_of_threads  requested number of threads (only read)
 * @return the matrix produced by compute_1body_ints_parallel() for
 *         libint2::Operator::overlap
 */
Matrix compute_overlaps(const std::vector<libint2::Shell>& shells_1, const std::vector<libint2::Shell>& shells_2, int& number_of_threads) {

    using libint2::Operator;
    using libint2::nthreads;

    // Start from the caller's request; a non-empty LIBINT_NUM_THREADS wins.
    nthreads = number_of_threads;
    const char* nthreads_cstr = std::getenv("LIBINT_NUM_THREADS");
    if (nthreads_cstr != nullptr && std::strcmp(nthreads_cstr, "") != 0) {
        std::istringstream iss(nthreads_cstr);
        iss >> nthreads;
    }
    // Validate the count regardless of where it came from. Previously only
    // the environment value was checked, so a non-positive argument slipped
    // through to the code that allocates one engine per thread.
    if (nthreads > (1 << 16) || nthreads <= 0) nthreads = 1;

#if defined(_OPENMP)
    omp_set_num_threads(nthreads);
#endif
    std::cout << "Will scale over " << nthreads
#if defined(_OPENMP)
              << " OpenMP"
#else
              << " C++11"
#endif
              << " threads" << std::endl;

    // Scope guard: pairs libint2::initialize() with libint2::finalize() on
    // every exit path of this function.
    struct LibintSession {
        LibintSession() { libint2::initialize(); }
        ~LibintSession() { libint2::finalize(); }
    } session;

    // Compute the AO overlap matrix
    Matrix S = compute_1body_ints_parallel(shells_1, shells_2, Operator::overlap);
    std::cout << "\n\tFinished computing overlap integral\n";

    return S;
}


/**
 * Appends one shell with a single contraction to a list of shells and returns
 * the extended list.
 *
 * @param shells     existing shells; taken by value, the returned vector is
 *                   this copy with the new shell appended
 * @param l_val      angular momentum stored in the contraction
 * @param spherical  1 sets the contraction's spherical flag to true, any other
 *                   value sets it to false
 * @param exp        exponents of the primitive Gaussians
 * @param coeff      contraction coefficients, one per exponent; entries beyond
 *                   exp.size() are ignored
 * @param atom       atom whose x, y, z become the origin of the shell
 * @return the list of shells including the new one
 * @throws std::invalid_argument if coeff has fewer entries than exp
 */
std::vector<libint2::Shell> add_to_shell(std::vector<libint2::Shell> shells, int& l_val, int& spherical , const std::vector<float>& exp, const std::vector<float>& coeff, const Atom &atom){

    // Every exponent needs a coefficient; without this check the loop below
    // would read past the end of coeff.
    if (coeff.size() < exp.size()) {
        throw std::invalid_argument("add_to_shell: fewer contraction coefficients than exponents");
    }

    // libint2::Shell takes its data as libint2::svector (see shell.h in the
    // libint library folder).
    libint2::svector<double> exponents;
    // Per the original note, libint converts the contraction coefficients to
    // refer to normalization-free primitives before computing integrals
    // (Shell::renorm in shell.h).
    libint2::svector<double> coefficients;

    // Copy the float data coming from Python into the double-valued containers.
    for (size_t i = 0; i < exp.size(); ++i) {
        exponents.push_back(exp[i]);
        coefficients.push_back(coeff[i]);
    }

    // The spherical/cartesian flag for the shell
    const bool is_spherical = (spherical == 1);

    shells.push_back(
                     {
                       exponents, // exponents of primitive Gaussians
                       {  // the single contraction: angular momentum, spherical flag, coefficients
                         {l_val, is_spherical, coefficients}
                       },
                       {{atom.x,atom.y,atom.z}}   // origin coordinates
                     }
                    );
    // Return the shells with the new data
    return shells;

}

// Debugging aid (also interfaced with Python using Pybind11): streams every
// shell of the list to std::cout, then the number of shells in the list.
void print_shell(std::vector<libint2::Shell> shells){

    std::cout << "\n\tShells are:\n";
    for (const auto& shell : shells) {
        std::cout << shell << std::endl;
    }
    // How many shells were printed
    std::cout << "\n The shell size is:\n" << shells.size() << std::endl;
}

/**
 * Builds an Atom from a coordinate list so that it can be created from Python.
 * The Atom is later used for creating shells.
 *
 * The first three entries of coords are taken as x, y, z in Angstrom and are
 * converted to Bohr (as in the libint test files); atomic_number is set to 0.
 *
 * @param coords  at least three values (x, y, z); further entries are ignored
 * @return the Atom holding the converted position
 * @throws std::invalid_argument if coords has fewer than three entries
 */
Atom initialize_atom(const std::vector<float>& coords) {

    // Guard the three indexed reads below against short input.
    if (coords.size() < 3) {
        throw std::invalid_argument("initialize_atom: expected at least 3 coordinates (x, y, z)");
    }

    // As was in the libint test files, we NEED to convert angstrom to Bohr.
    const double angstrom_to_bohr = 1 / 0.52917721092;

    Atom atom;
    atom.atomic_number = 0;
    atom.x = coords[0] * angstrom_to_bohr;
    atom.y = coords[1] * angstrom_to_bohr;
    atom.z = coords[2] * angstrom_to_bohr;
    // Return the atom position in Atom format
    return atom;
}

// We first need to initailize a libint2::Shell. Here is how we do it.
std::vector<libint2::Shell> initialize_shell(int& l_val, int& spherical , const std::vector<float>& exp, const std::vector<float>& coeff, const Atom &atom) {


  using libint2::Shell;
  using std::cout;
  using std::endl;
  std::vector<Shell> shells;

  // The same as above explanation for exponents and contraction coefficients
  libint2::svector<double> q;
  libint2::svector<double> p;
  for (int i = 0; i < exp.size(); i++) {
       q.push_back(exp[i]);
       p.push_back(coeff[i]);
      }

  bool is_spherical;
  if (spherical==1) {
      is_spherical = true;
  }
  else {
      is_spherical = false;
  };
  // Append to shell
  shells.push_back(
                       {
                         q, // exponents of primitive Gaussians
                         {  // contraction 0: s shell (l=0), spherical=false, contraction coefficients
                           {l_val, is_spherical, p}
                         },
                         {{atom.x,atom.y,atom.z}}   // origin coordinates
                       }
                   );

   return shells;
}

// Below are the functions that we use for starting the engine for computing the overlap integrals.

// Returns the sum of size() over all shells, i.e. the total number of basis
// functions they span; 0 for an empty list.
size_t nbasis(const std::vector<libint2::Shell>& shells) {
  size_t total = 0;
  for (size_t i = 0; i < shells.size(); ++i) {
    total += shells[i].size();
  }
  return total;
}

// Returns the largest nprim() found among the shells; 0 for an empty list.
size_t max_nprim(const std::vector<libint2::Shell>& shells) {
  size_t n = 0;
  // Iterate by const reference: a by-value loop variable copies every Shell.
  for (const auto& shell: shells)
    n = std::max(shell.nprim(), n);
  return n;
}

// Returns the largest angular momentum `l` found in any contraction of any
// shell; 0 for an empty list.
int max_l(const std::vector<libint2::Shell>& shells) {
  int l = 0;
  // Iterate by const reference: by-value loop variables copy every Shell and
  // every contraction.
  for (const auto& shell: shells)
    for (const auto& c: shell.contr)
      l = std::max(c.l, l);
  return l;
}

// Maps shells to basis functions: element i of the result is the index of the
// first basis function of shell i, i.e. the running sum of size() over the
// shells that precede it.
std::vector<size_t> map_shell_to_basis_function(const std::vector<libint2::Shell>& shells) {
  std::vector<size_t> result;
  result.reserve(shells.size());

  size_t n = 0;
  // Iterate by const reference: a by-value loop variable copies every Shell.
  for (const auto& shell: shells) {
    result.push_back(n);
    n += shell.size();
  }

  return result;
}

// -----------

/**
 * Computes the one-body integrals between two lists of shells on several
 * threads (adopted from the libint test files with modifications). Only the
 * overlap operator is needed by this file, but the operator is a parameter.
 *
 * Block (s1, s2) of the result holds the integrals between shell s1 of
 * shells_1 and shell s2 of shells_2, so the result has
 * nbasis(shells_1) rows and nbasis(shells_2) columns. Every shell pair is
 * evaluated explicitly: the two lists may describe different geometries, in
 * which case the matrix is not symmetric and must not be filled by mirroring
 * one triangle. For identical lists the result equals the mirrored one.
 *
 * Threading: one engine per worker is created, `libint2::nthreads` workers in
 * total (a value below 1 is raised to 1), and shell pairs are dealt out
 * round-robin. Distinct pairs write distinct blocks of the result.
 * NOTE(review): complete coverage assumes libint2::parallel_do runs at least
 * `nthreads` workers; with OpenMP that depends on the team size actually
 * granted -- confirm for the target build.
 *
 * @param shells_1  shells indexing the rows
 * @param shells_2  shells indexing the columns
 * @param obtype    operator passed to the libint2 engine
 * @return matrix of integrals; blocks for which libint returns no buffer
 *         stay zero
 */
Matrix compute_1body_ints_parallel(const std::vector<libint2::Shell>& shells_1, const std::vector<libint2::Shell>& shells_2,libint2::Operator obtype)
{
  using libint2::nthreads;

  const size_t n_1 = nbasis(shells_1);
  const size_t n_2 = nbasis(shells_2);
  // Zero-initialized so that skipped blocks have a defined value.
  Matrix result = Matrix::Zero(n_1, n_2);
  if (shells_1.empty() || shells_2.empty()) return result;

  // parallel_do reads the shared thread count, so normalize it in place; a
  // non-positive value would otherwise leave the engine pool empty.
  if (nthreads < 1) nthreads = 1;
  const size_t nworkers = static_cast<size_t>(nthreads);

  // One engine per worker. The first one is sized for the largest primitive
  // count and angular momentum found in EITHER list; the rest are copies.
  std::vector<libint2::Engine> engines(nworkers);
  engines[0] = libint2::Engine(obtype,
                               std::max(max_nprim(shells_1), max_nprim(shells_2)),
                               std::max(max_l(shells_1), max_l(shells_2)),
                               0);
  // This part is for other operators, I'll keep it but not necessary.
  //  engines[0].set_params(oparams);
  for (size_t i = 1; i < nworkers; ++i) {
    engines[i] = engines[0];
  }

  // First basis-function index of every shell, for rows and for columns
  const auto shell2bf1 = map_shell_to_basis_function(shells_1);
  const auto shell2bf2 = map_shell_to_basis_function(shells_2);
  const size_t nshells_1 = shells_1.size();
  const size_t nshells_2 = shells_2.size();

  // Work done by the worker with the given id
  auto compute = [&](int thread_id) {

    // A worker id without an engine (e.g. an OpenMP team larger than the
    // pool) has nothing to do.
    if (thread_id < 0 || static_cast<size_t>(thread_id) >= nworkers) return;
    const size_t worker = static_cast<size_t>(thread_id);

    auto& engine = engines[worker];
    // buf[0] points to the target shell set after every call to compute()
    const auto& buf = engine.results();

    for (size_t s1 = 0; s1 != nshells_1; ++s1) {
      const auto bf1 = shell2bf1[s1];     // first basis function in this shell
      const auto n1 = shells_1[s1].size();

      for (size_t s2 = 0; s2 != nshells_2; ++s2) {
        // Round-robin distribution of the shell pairs over the workers
        if ((s1 * nshells_2 + s2) % nworkers != worker) continue;

        const auto bf2 = shell2bf2[s2];   // first basis function in this shell
        const auto n2 = shells_2[s2].size();

        engine.compute(shells_1[s1], shells_2[s2]);
        // A null buffer means libint produced no integrals for this pair
        // (screened out); the block keeps its zero value.
        if (buf[0] == nullptr) continue;

        // View the row-major buffer as an n1 x n2 matrix and copy it in place
        Eigen::Map<const Matrix> buf_mat(buf[0], n1, n2);
        result.block(bf1, bf2, n1, n2) = buf_mat;
      }
    }
  };
  // Now for compute, do parallel
  libint2::parallel_do(compute);
  // Return the integral values
  return result;
}

/**
 * Computes the one-body integrals between two lists of shells on the calling
 * thread (serial counterpart of compute_1body_ints_parallel()).
 *
 * Block (s1, s2) of the result holds the integrals between shell s1 of
 * shells_1 and shell s2 of shells_2, so the result has
 * nbasis(shells_1) rows and nbasis(shells_2) columns. Every shell pair is
 * evaluated explicitly: the permutational symmetry (1|2) = (2|1) only holds
 * when both lists are the same, so mirroring one triangle is not valid for
 * two different geometries. For identical lists the result equals the
 * mirrored one.
 *
 * @param shells_1  shells indexing the rows
 * @param shells_2  shells indexing the columns
 * @param obtype    operator passed to the libint2 engine
 * @return matrix of integrals; blocks for which libint returns no buffer
 *         stay zero
 */
Matrix compute_1body_ints(const std::vector<libint2::Shell>& shells_1, const std::vector<libint2::Shell>& shells_2,libint2::Operator obtype)
{
  using libint2::Engine;

  const size_t n_1 = nbasis(shells_1);
  const size_t n_2 = nbasis(shells_2);
  // Zero-initialized so that skipped blocks have a defined value.
  Matrix result = Matrix::Zero(n_1, n_2);
  if (shells_1.empty() || shells_2.empty()) return result;

  // Size the engine for the largest primitive count and angular momentum
  // found in EITHER list.
  Engine engine(obtype,
                std::max(max_nprim(shells_1), max_nprim(shells_2)),
                std::max(max_l(shells_1), max_l(shells_2)),
                0);

  // Row offsets come from shells_1, column offsets from shells_2.
  const auto shell2bf1 = map_shell_to_basis_function(shells_1);
  const auto shell2bf2 = map_shell_to_basis_function(shells_2);

  // buf[0] points to the target shell set after every call to engine.compute()
  const auto& buf = engine.results();

  for (size_t s1 = 0; s1 != shells_1.size(); ++s1) {

    const auto bf1 = shell2bf1[s1]; // first basis function in this shell
    const auto n1 = shells_1[s1].size();

    for (size_t s2 = 0; s2 != shells_2.size(); ++s2) {

      const auto bf2 = shell2bf2[s2];
      const auto n2 = shells_2[s2].size();

      // compute shell pair
      engine.compute(shells_1[s1], shells_2[s2]);
      // A null buffer means libint produced no integrals for this pair
      // (screened out); the block keeps its zero value.
      if (buf[0] == nullptr) continue;

      // "map" buffer to a const Eigen Matrix, and copy it to the corresponding block of the result
      Eigen::Map<const Matrix> buf_mat(buf[0], n1, n2);
      result.block(bf1, bf2, n1, n2) = buf_mat;
    }
  }

  return result;
}

// Test helper that returns the sum of its two arguments. Nothing else in this
// file depends on it.
int add(int i, int j) {
    const int sum = i + j;
    return sum;
}

// Making the Pybind11 module for importing into Python.
// The module is named aux_funcs (auxiliary functions); its functions are used
// for computing the molecular orbital overlaps by computing the AO overlaps.
//
// All exported functions return by value, so no return-value policy is given:
// pybind11 moves/copies such results into Python objects regardless of the
// requested policy, and reference_internal only applies to methods that hand
// out references tied to `self`.
PYBIND11_MODULE(aux_funcs, m) {
    m.doc() = "Auxiliary functions to use Libint for calculating the AO overlap matrix"; // optional module docstring

    // The Atom class
    py::class_<Atom>(m, "Atom")
        .def_readwrite("atomic_number", &Atom::atomic_number)
        .def_readwrite("x", &Atom::x)
        .def_readwrite("y", &Atom::y)
        .def_readwrite("z", &Atom::z);
    // The libint2::Shell class (opaque to Python; instances are produced by
    // initialize_shell / add_to_shell)
    py::class_<libint2::Shell>(m, "Shell");
    // The libint2::Operator enumeration. It is registered as an enum so that
    // Python can create the value required by the compute_1body_ints*
    // functions; only the overlap operator is used by this module.
    py::enum_<libint2::Operator>(m, "Operator")
        .value("overlap", libint2::Operator::overlap);

    m.def("nbasis", &nbasis, "A function for number of basis sets");
    m.def("compute_1body_ints_parallel", &compute_1body_ints_parallel, "A function for computing AO overlaps parallel");
    m.def("compute_1body_ints", &compute_1body_ints, "A function for computing integrals serial");
    m.def("initialize_shell", &initialize_shell, "A function for creating shells");
    m.def("initialize_atom", &initialize_atom, "A function for initializing Atom");
    m.def("add_to_shell", &add_to_shell, "A function to add the exponents, contraction coefficients, and Atom structure to a predefined shell");
    m.def("print_shell", &print_shell, "A function for printing a shell");
    m.def("compute_overlaps", &compute_overlaps, "A function that computes overlaps between pairs of shells");
}


