#include "extra_compile_macros.h"

#if USE_EXP == 1

#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <filesystem>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>

namespace fs = std::filesystem;

// EXP headers
#include <Eigen/Eigen>
#include <EXP/Coefficients.H>
#include <EXP/BiorthBasis.H>
#include <EXP/FieldGenerator.H>

#include "exp_fields.h"
#include "src/vectorization.h"

namespace gala_exp {

// Build a State from basis/coefficient objects that were created elsewhere
// (the error messages refer to them as pyEXP objects).
//
// basis_ptr / coefs_ptr    pointers to shared handles; neither the pointer
//                          nor the handle it points to may be null
// snapshot_time_factor     forwarded unchanged to the State constructor
//
// The snapshot index handed to State is always -1.
//
// Throws std::runtime_error if any pointer or handle is null, or if the basis
// is not a BasisClasses::BiorthBasis.
State pyexp_init(
    BasisClasses::BasisPtr *basis_ptr,
    CoefClasses::CoefsPtr *coefs_ptr,
    double snapshot_time_factor
) {
    const auto require = [](bool ok, const char *message) {
        if (!ok) {
            throw std::runtime_error(message);
        }
    };

    // The raw pointers are validated before they are dereferenced below.
    require(basis_ptr != nullptr, "pyexp_init: basis pointer is null");
    require(coefs_ptr != nullptr, "pyexp_init: coefs pointer is null");
    require(*basis_ptr != nullptr, "pyexp_init: basis is null");
    require(*coefs_ptr != nullptr, "pyexp_init: coefs is null");

    // The downcast yields an empty handle when the dynamic type does not match.
    auto biorth = std::dynamic_pointer_cast<BasisClasses::BiorthBasis>(*basis_ptr);
    if (biorth == nullptr) {
      throw std::runtime_error("pyEXP Basis must be a BiorthBasis.");
    }

    constexpr int no_snapshot_index = -1;
    return State(biorth, *coefs_ptr, snapshot_time_factor, no_snapshot_index);
}

// Build a State from files on disk.
//
// config_fn              YAML file describing the basis
// coeffile               coefficient file handed to CoefClasses::Coefs::factory
// stride, tmin, tmax     forwarded unchanged to CoefClasses::Coefs::factory
//                        (tmin/tmax are in raw EXP snapshot time units)
// snapshot_index         forwarded unchanged to the State constructor
// snapshot_time_factor   forwarded unchanged to the State constructor
//
// Throws std::runtime_error if the basis cannot be created, is not a
// BiorthBasis, or if the coefficients cannot be loaded or contain no times.
State exp_init(
  const std::string &config_fn, const std::string &coeffile,
  int stride, double tmin, double tmax, int snapshot_index, double snapshot_time_factor)
{
  YAML::Node yaml = YAML::LoadFile(config_fn);

  BasisClasses::BasisPtr generic_basis;
  {
    // The basis is created with the cwd set to the config file's directory
    // so that relative paths inside the config file resolve.
    // TODO: this is not thread-safe, threads share a cwd
    ScopedChdir cwd_guard(fs::path(config_fn).parent_path());

    generic_basis = BasisClasses::Basis::factory(yaml);
  }

  if (!generic_basis) {
    throw std::runtime_error("Failed to load basis from config file: " + config_fn);
  }

  auto biorth = std::dynamic_pointer_cast<BasisClasses::BiorthBasis>(generic_basis);
  if (!biorth) {
    throw std::runtime_error(
      "Basis in config file " + config_fn + " must be a BiorthBasis.");
  }

  auto coefs = CoefClasses::Coefs::factory(coeffile, stride, tmin, tmax);
  if (!coefs) {
    throw std::runtime_error("Failed to load coefficients from file: " + coeffile);
  }

  // Probe Times() once so that a "pure virtual" failure is turned into a
  // more informative message that names the offending file.
  // TODO: is there a better way to "validate" the Coefs object?
  try {
    coefs->Times();
  } catch (const std::runtime_error& err) {
    throw std::runtime_error(
      "Failed to load coefficients from file: " + coeffile
      + ". Error: " + err.what());
  }

  if (coefs->Times().empty()) {
    // Stream formatting is kept here because tmin/tmax are doubles.
    std::ostringstream msg;
    msg << "No times in coeffile=" << coeffile
        << " within tmin=" << tmin
        << " and tmax=" << tmax
        << " (raw EXP snapshot time units).";
    throw std::runtime_error(msg.str());
  }

  return State(biorth, coefs, snapshot_time_factor, snapshot_index);
}

// Construct the evaluation state from a basis and a coefficient series.
//
// snapshot_index >= 0 selects one snapshot: the basis is loaded with that
// snapshot's coefficients and the state is marked static.  A negative index
// spans the whole series, except that a series with exactly one snapshot is
// treated as if index 0 had been requested.
//
// On success tmin/tmax hold the covered time range (equal for a static
// state) in raw EXP snapshot time units.
//
// Throws std::runtime_error if Times() fails, if there are no times, or if
// snapshot_index is past the last snapshot.
State::State(
  BiorthBasisPtr basis_,
  CoefClasses::CoefsPtr coefs_,
  double snapshot_time_factor_,
  int snapshot_index)
    : basis(basis_),
      coefs(coefs_),
      snapshot_time_factor(snapshot_time_factor_) {

  // Probe Times() once so that a "pure virtual" failure is turned into a
  // more informative message.
  // TODO: is there a better way to "validate" the Coefs object?
  try {
    coefs->Times();
  } catch (const std::runtime_error& err) {
    std::ostringstream msg;
    msg << "Failed to fetch Times from Coefs object. "
        << "Is this a valid, non-empty Coefs instance? "
        << "Error: " << err.what();
    throw std::runtime_error(msg.str());
  }

  if (coefs->Times().empty()) {
    throw std::runtime_error("No times in coefficients.");
  }

  // A lone snapshot cannot evolve in time, so pin it.
  const bool single_snapshot = (coefs->Times().size() == 1);
  if (single_snapshot && snapshot_index < 0) {
    snapshot_index = 0;
  }

  if (snapshot_index < 0) {
    // Whole series: the range runs from the first to the last stored time.
    auto all_times = coefs->Times();
    const double first = all_times.front();
    const double last = all_times.back();

    const bool frozen = (last == first);
    if (frozen) {
      basis->set_coefs(gala_exp::interpolator(first, coefs));
    }

    this->is_static = frozen;
    this->tmin = first;
    this->tmax = last;
  } else {
    const auto& all_times = coefs->Times();
    if (snapshot_index >= all_times.size()) {
      std::ostringstream msg;
      msg << "Invalid snapshot_index: " << snapshot_index
          << ". Valid indices are in [0," << (all_times.size() - 1) << "]"
          << " (times [" << all_times.front() << ", " << all_times.back() << "])"
          << " (raw EXP snapshot time units).";
      throw std::runtime_error(msg.str());
    }

    // Single snapshot: load its coefficients once, here.
    const double snapshot_time = all_times[snapshot_index];
    basis->set_coefs(coefs->getCoefStruct(snapshot_time));

    this->is_static = true;
    this->tmin = snapshot_time;
    this->tmax = snapshot_time;
  }
}

// Linear interpolator on coefficients.  Higher order interpolation
// could be implemented similarly.  This probably belongs in
// CoefClasses . . .
//
// Returns a coefficient set for time `t` (raw EXP snapshot time units),
// blended linearly from the two stored snapshots whose times bracket `t`:
//
//     c(t) = a * c(t1) + b * c(t2),   t1 <= t <= t2,   a + b = 1
//
// The center (`ctr`) is blended the same way when both snapshots carry one.
// If only one snapshot is stored, or the two bracketing snapshots share the
// same time, the stored snapshot is returned as-is.
//
// Throws std::runtime_error if `coefs` is null, holds no times, or `t` lies
// outside [first time, last time].
//
// NOTE: Times() is assumed to be sorted ascending; the binary search below
// depends on it.
//
CoefClasses::CoefStrPtr interpolator(double t, CoefClasses::CoefsPtr coefs)
{
  if (!coefs) {
    throw std::runtime_error("FieldWrapper::interpolator: coefs is null");
  }

  const auto times = coefs->Times();

  if (times.empty()) {
    throw std::runtime_error("FieldWrapper::interpolator: no times in coefficients");
  }

  if (t<times.front() or t>times.back()) {
    std::ostringstream sout;
    sout << "FieldWrapper::interpolator: time t=" << t << " is out of bounds: ["
         << times.front() << ", " << times.back() << "] (raw EXP snapshot time units)";
    throw std::runtime_error(sout.str());
  }

  // Nothing to blend with a single snapshot; the bounds check above already
  // guarantees t equals its time.
  if (times.size() == 1) {
    return coefs->getCoefStruct(times.front());
  }

  // upper_bound gives the first time strictly greater than t, i.e. the upper
  // edge of the bracketing interval.  (lower_bound would give the first time
  // >= t, which is the upper edge too, so using it as the lower edge selects
  // the *next* interval and extrapolates.)
  auto it2 = std::upper_bound(times.begin(), times.end(), t);
  if (it2 == times.end()) {
    // t equals the last time: use the final interval.
    --it2;
  }
  // Safe: t >= times.front() puts it2 past begin(), and size() >= 2 here.
  const auto it1 = it2 - 1;

  // Repeated time stamps would make the weights below divide by zero.
  if (*it1 == *it2) {
    return coefs->getCoefStruct(*it1);
  }

  const double span = *it2 - *it1;
  const double a = (*it2 - t)/span;
  const double b = (t - *it1)/span;

  auto coefsA = coefs->getCoefStruct(*it1);
  auto coefsB = coefs->getCoefStruct(*it2);

  // Duplicate a coefficient instance.  Shared pointer for proper
  // garbage collection.
  //
  auto newcoef = coefsA->deepcopy();

  // Now interpolate the matrix
  //
  newcoef->time = t;

  auto & cN = newcoef->store;
  const auto & cA = coefsA->store;
  const auto & cB = coefsB->store;

  for (Eigen::Index i = 0; i < cN.size(); ++i)
    cN(i) = a * cA(i) + b * cB(i);

  // Interpolate the center data
  //
  if (coefsA->ctr.size() and coefsB->ctr.size()) {
    newcoef->ctr.resize(3);
    for (int k=0; k<3; k++)
      newcoef->ctr[k] = a * coefsA->ctr[k] + b * coefsB->ctr[k];
  }

  return newcoef;
}

}

/* ---------------------------------------------------------------------------
    EXP potential

    Calls the EXP code (https://github.com/exp-code/exp).
    Only available if EXP is available at build time.
*/

// Evaluate the EXP field at a single position and return entry 5 of the
// result, which this wrapper treats as the potential.
//
// t      time; multiplied by State::snapshot_time_factor before interpolating
// pars   unused
// q      position; q[0], q[1], q[2] are passed to getFields
// n_dim  unused
// state  opaque pointer to a gala_exp::State
//
// For a non-static state the basis coefficients are re-interpolated on
// every call.
double exp_value(double t, double *pars, double *q, int n_dim, void* state) {
  auto *ctx = static_cast<gala_exp::State *>(state);

  if (!ctx->is_static) {
    // TODO: how expensive is this, actually?
    const double t_exp = t * ctx->snapshot_time_factor;
    ctx->basis->set_coefs(gala_exp::interpolator(t_exp, ctx->coefs));
  }

  // TODO: ask Martin/Mike for a way to compute only the potential - we're wasting
  // computation time here by computing all fields
  constexpr int potential_slot = 5;
  auto all_fields = ctx->basis->getFields(q[0], q[1], q[2]);

  return all_fields[potential_slot];
}

void exp_gradient(double t, double *__restrict__ pars, double *__restrict__ q_in, int n_dim, size_t N, double *__restrict__ grad_in, void *__restrict__ state){
  gala_exp::State *exp_state = static_cast<gala_exp::State *>(state);

  if (!exp_state->is_static) {
    exp_state->basis->set_coefs(
      gala_exp::interpolator(t * exp_state->snapshot_time_factor, exp_state->coefs)
    );
  }

  double6ptr q = double6ptr{q_in, N};
  double6ptr grad = double6ptr{grad_in, N};

  Eigen::Map<Eigen::VectorXd> eigen_x(q.x, N);
  Eigen::Map<Eigen::VectorXd> eigen_y(q.y, N);
  Eigen::Map<Eigen::VectorXd> eigen_z(q.z, N);

  auto& allaccel = exp_state->basis->getAccel(eigen_x, eigen_y, eigen_z);

  for(size_t i = 0; i < N; i++) {
    grad.x[i] -= allaccel(i, 0);
    grad.y[i] -= allaccel(i, 1);
    grad.z[i] -= allaccel(i, 2);
  }

}

// Evaluate the EXP field at a single position and return entry 2 of the
// result, which this wrapper treats as the density.
//
// t      time; multiplied by State::snapshot_time_factor before interpolating
// pars   unused
// q      position; q[0], q[1], q[2] are passed to getFields
// n_dim  unused
// state  opaque pointer to a gala_exp::State
//
// For a non-static state the basis coefficients are re-interpolated on
// every call.
double exp_density(double t, double *pars, double *q, int n_dim, void* state) {
  auto *ctx = static_cast<gala_exp::State *>(state);

  if (!ctx->is_static) {
    const double t_exp = t * ctx->snapshot_time_factor;
    ctx->basis->set_coefs(gala_exp::interpolator(t_exp, ctx->coefs));
  }

  // TODO: ask Martin/Mike for a way to compute only the density - we're wasting
  // computation time here by computing all fields
  constexpr int density_slot = 2;
  auto all_fields = ctx->basis->getFields(q[0], q[1], q[2]);

  return all_fields[density_slot];
}

// TODO: No hessian available in EXP yet
// void exp_hessian(double t, double *pars, double *q, int n_dim, double *hess, void* state) {
//   gala_exp::State *exp_state = static_cast<gala_exp::State *>(state);

//   if (!exp_state->is_static) {
//     exp_state->basis->set_coefs(
//       gala_exp::interpolator(t * exp_state->snapshot_time_factor, exp_state->coefs)
//     );
//   }

//   auto field = exp_state->basis->getFields(q[0], q[1], q[2]);

//   for(int i=0; i<9; i++) {
//     hess[i] += NAN;  // TODO: get hessian from EXP
//   }
// }

#endif  // USE_EXP
