/*
 * This file defines a template for functions that require only one subspace. It should be included
 * multiple times in bpetsc_impl.c, with SUBSPACE defined as the desired value.
 */

#include "bpetsc_template_1.h"
#if PETSC_HAVE_CUDA
#include "bcuda_template_1.h"
#endif

/*
 * Gather the bits of `state` located at the positions listed in `keep` into a
 * compact integer: bit k of the result equals bit keep[k] of `state`, for
 * k = 0 .. keep_size-1.
 *
 * The body does not depend on the subspace; it is defined in this template
 * next to the functions that use the same conventions, for organizational
 * purposes.
 */
static inline PetscInt C(reduce_state,SUBSPACE)(
  PetscInt state,
  const PetscInt* keep,
  PetscInt keep_size
){
  PetscInt reduced = 0;
  PetscInt k;
  /* drop each selected bit straight into its destination position k */
  for (k=0; k<keep_size; ++k) {
    reduced |= ((state >> keep[k]) & 1) << k;
  }
  return reduced;
}

/*
 * Interleave two partial bit strings into one state of L bits.
 *
 * Walking over the positions 0 .. L-1 of the result, a position that appears
 * in `keep` receives the next unused low bit of `keep_state`; every other
 * position receives the next unused low bit of `tr_state`.
 *
 * `keep` must be sorted in increasing order for the matching below to work;
 * rdm verifies this before calling down to here.
 */
static inline PetscInt C(combine_states,SUBSPACE)(
  PetscInt keep_state,
  PetscInt tr_state,
  const PetscInt* keep,
  PetscInt keep_size,
  PetscInt L
){
  PetscInt full = 0;
  PetscInt pos;
  PetscInt n_used = 0;  /* number of entries of keep consumed so far */

  for (pos=0; pos<L; ++pos) {
    if (n_used < keep_size && pos == keep[n_used]) {
      /* kept position: take the lowest remaining bit of keep_state */
      full |= (keep_state & 1) << pos;
      keep_state >>= 1;
      ++n_used;
    }
    else {
      /* traced-out position: take the lowest remaining bit of tr_state */
      full |= (tr_state & 1) << pos;
      tr_state >>= 1;
    }
  }
  return full;
}

/*
 * For one fixed configuration `tr_state` of the traced-out bits, enumerate all
 * 2^keep_size configurations of the kept bits, build the corresponding full
 * state, and look up its index with S2I.
 *
 * Every configuration whose index is not -1 is recorded, in increasing order
 * of the kept configuration:
 *   state_array[n]   = kept configuration
 *   combine_array[n] = x_array[index of the full state]
 * Configurations for which S2I returns -1 are skipped (presumably states that
 * are not part of the subspace).
 *
 * On return *n_filled_p holds the number of recorded entries. Both output
 * arrays need room for up to 2^keep_size entries.
 */
void C(fill_combine_array,SUBSPACE)(
  PetscInt tr_state,
  PetscInt keep_size,
  const C(data,SUBSPACE)* sub_data_p,
  const PetscInt* keep,
  const PetscScalar *x_array,
  PetscInt *state_array,
  PetscScalar *combine_array,
  PetscInt* n_filled_p)
{
  PetscInt n_keep_states = 1<<keep_size;
  PetscInt count = 0;
  PetscInt ks, combined, x_idx;

  *n_filled_p = 0;
  for (ks=0; ks<n_keep_states; ++ks) {
    combined = C(combine_states,SUBSPACE)(ks, tr_state, keep, keep_size, sub_data_p->L);
    x_idx = C(S2I,SUBSPACE)(combined, sub_data_p);

    /* -1 from S2I: nothing to record for this configuration */
    if (x_idx == -1) continue;

    state_array[count] = ks;
    combine_array[count] = x_array[x_idx];
    ++count;
  }
  *n_filled_p = count;
}

#undef  __FUNCT__
#define __FUNCT__ "rdm"
/*
 * Compute the reduced density matrix of the state stored in `vec`, keeping the
 * bit positions listed in `keep` and summing over all configurations of the
 * remaining sub_data_p->L - keep_size positions.
 *
 * Parameters:
 *   vec         state vector; when running on more than one process it is
 *               gathered onto rank 0 first
 *   sub_data_p  subspace description (provides L and the state <-> index map)
 *   keep_size   number of entries in `keep`
 *   keep        positions to keep; must be strictly increasing and lie in
 *               [0, L)
 *   triang      not used by this implementation
 *   rtn_dim     row length of `rtn`
 *   rtn         output, rtn_dim*rtn_dim scalars stored row by row. It is zeroed
 *               and then entry (s, s') accumulates x[s,t]*conj(x[s',t]) over
 *               the traced-out configurations t. Only rank 0 writes to it.
 *               NOTE(review): rtn_dim is not checked against 2^keep_size; row
 *               and column indices up to 2^keep_size - 1 are used, so the
 *               caller is expected to pass rtn_dim >= 2^keep_size.
 *
 * Returns 0 on success and a PETSc error code otherwise. Ranks other than 0
 * return right after taking part in the scatter, without touching `rtn`.
 */
PetscErrorCode C(rdm,SUBSPACE)(
  Vec vec,
  const C(data,SUBSPACE)* sub_data_p,
  PetscInt keep_size,
  const PetscInt* keep,
  PetscBool triang,
  PetscInt rtn_dim,
  PetscScalar* rtn
){

  const PetscScalar *v0_array;
  PetscInt i, j, n_filled, offset;
  PetscInt tr_state, tr_dim, keep_dim;
  int mpi_size, mpi_rank;
  PetscScalar a;
  Vec v0;
  VecScatter scat;

  PetscInt *state_array;
  PetscScalar *combine_array;

  (void)triang;

  /*
   * Validate the arguments on every rank before any collective call, so that
   * all ranks fail together instead of some of them waiting in the scatter.
   */
  if (keep_size < 0) {
    SETERRQ(PETSC_COMM_WORLD, PETSC_ERR_ARG_OUTOFRANGE, "keep_size must be non-negative");
  }

  /* combine_states relies on keep being sorted */
  for (i=1; i<keep_size; ++i) {
    if (keep[i] <= keep[i-1]) {
      SETERRQ(PETSC_COMM_WORLD, PETSC_ERR_ARG_WRONG, "keep array must be strictly increasing");
    }
  }

  /*
   * keep is sorted, so checking its two ends bounds every entry; together with
   * strict ordering this also guarantees keep_size <= L, which keeps the shift
   * used for tr_dim below non-negative.
   */
  if (keep_size > 0 && (keep[0] < 0 || keep[keep_size-1] >= sub_data_p->L)) {
    SETERRQ(PETSC_COMM_WORLD, PETSC_ERR_ARG_OUTOFRANGE, "keep indices must be in the range [0, L)");
  }

  PetscCallMPI(MPI_Comm_size(PETSC_COMM_WORLD, &mpi_size));
  PetscCallMPI(MPI_Comm_rank(PETSC_COMM_WORLD, &mpi_rank));

  /* scatter to process 0 */
  /* in the future, perhaps will do this in parallel */
  /* could be achieved with a round-robin type deal, like we do with MatMult */
  if (mpi_size > 1) {
    PetscCall(VecScatterCreateToZero(vec, &scat, &v0));
    PetscCall(VecScatterBegin(scat, vec, v0, INSERT_VALUES, SCATTER_FORWARD));
    PetscCall(VecScatterEnd(scat, vec, v0, INSERT_VALUES, SCATTER_FORWARD));
    PetscCall(VecScatterDestroy(&scat));
  }
  else {
    /* single process: work on the caller's vector directly (not owned here) */
    v0 = vec;
  }

  /* we're done if we're rank > 0 */
  if (mpi_rank > 0) {
    /* rank > 0 implies mpi_size > 1, so v0 is the vector created by the scatter */
    PetscCall(VecDestroy(&v0));
    return 0;
  }

  PetscCall(VecGetArrayRead(v0, &v0_array));

  /* do the shifts in PetscInt so they are not limited to the width of int */
  keep_dim = ((PetscInt)1) << keep_size;
  tr_dim = ((PetscInt)1) << (sub_data_p->L - keep_size);

  PetscCall(PetscMalloc1(keep_dim, &state_array));
  PetscCall(PetscMalloc1(keep_dim, &combine_array));

  PetscCall(PetscMemzero(rtn, sizeof(PetscScalar)*rtn_dim*rtn_dim));

  for (tr_state = 0; tr_state < tr_dim; ++tr_state) {
    /* collect the amplitudes sharing this traced-out configuration */
    C(fill_combine_array,SUBSPACE)(tr_state, keep_size, sub_data_p, keep,
      v0_array, state_array, combine_array, &n_filled);

    /* accumulate their outer product into the corresponding entries of rtn */
    for (i=0; i<n_filled; ++i) {
      offset = state_array[i]*rtn_dim;
      a = combine_array[i];
      for (j=0; j<n_filled; ++j) {
        rtn[offset + state_array[j]] += a*PetscConj(combine_array[j]);
      }
    }
  }

  PetscCall(PetscFree(state_array));
  PetscCall(PetscFree(combine_array));

  PetscCall(VecRestoreArrayRead(v0, &v0_array));
  if (mpi_size > 1) {
    /* only destroy the gathered copy, never the caller's vec */
    PetscCall(VecDestroy(&v0));
  }

  return 0;
}

#undef  __FUNCT__
#define __FUNCT__ "PrecomputeDiagonal_CPU"
/*
 * Compute the diagonal entries of the shell matrix A for the locally owned
 * rows and store them in a newly allocated ctx->diag, where ctx is the shell
 * context of A.
 *
 * If ctx->masks[0] is nonzero the matrix has no diagonal; nothing is allocated
 * and ctx->diag is left as it was.
 *
 * Otherwise, for each owned row the entry is the sum, over the terms with
 * index below ctx->mask_offsets[1], of +/- ctx->real_coeffs[term]. The sign is
 * negative when builtin_parity(state & ctx->signs[term]) is 1, where `state`
 * is the state belonging to that row in the right subspace.
 *
 * Returns 0 on success and a PETSc error code otherwise.
 */
PetscErrorCode C(PrecomputeDiagonal_CPU,SUBSPACE)(Mat A){
  shell_context *ctx;
  PetscInt first_row, end_row;
  PetscInt row, term;
  PetscInt state = 0;
  PetscInt pm;
  PetscReal acc;

  PetscCall(MatShellGetContext(A, &ctx));

  /* there is no diagonal! leave diag as PETSC_NULLPTR */
  if (ctx->masks[0] != 0) return 0;

  PetscCall(MatGetOwnershipRange(A, &first_row, &end_row));
  PetscCall(PetscMalloc1(end_row-first_row, &(ctx->diag)));

  for (row=first_row; row<end_row; ++row) {
    /* look the first state up from its index, then step from state to state */
    state = (row == first_row)
      ? C(I2S,SUBSPACE)(row, ctx->right_subspace_data)
      : C(NextState,SUBSPACE)(state, row, ctx->right_subspace_data);

    acc = 0;
    for (term=0; term<ctx->mask_offsets[1]; ++term) {
      /* parity 0 -> +1, parity 1 -> -1 */
      pm = 1 - 2*(builtin_parity(state & ctx->signs[term]));
      acc += pm * ctx->real_coeffs[term];
    }
    ctx->diag[row-first_row] = acc;
  }

  return 0;
}

#undef  __FUNCT__
#define __FUNCT__ "PrecomputeDiagonal"
/*
 * Precompute the diagonal of the shell matrix A, dispatching on the `gpu` flag
 * of its shell context: the CPU implementation when the flag is unset, the GPU
 * implementation otherwise.
 *
 * When the flag is set but the build has no CUDA support, a
 * PETSC_ERR_ARG_UNKNOWN_TYPE error is raised.
 *
 * Returns the error code of the implementation that was called.
 */
PetscErrorCode C(PrecomputeDiagonal,SUBSPACE)(Mat A)
{
  shell_context *ctx;
  PetscCall(MatShellGetContext(A, &ctx));

  /* CPU shell: hand off directly */
  if (!(ctx->gpu)) {
    return C(PrecomputeDiagonal_CPU,SUBSPACE)(A);
  }

  /* GPU shell: only available when PETSc was built with CUDA */
#if PETSC_HAVE_CUDA
  return C(PrecomputeDiagonal_GPU,SUBSPACE)(A);
#else
  SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_UNKNOWN_TYPE, "GPU not enabled for this build");
#endif
}
