#ifndef LATTICE_H
#define LATTICE_H

#include <sstream>
#include <iostream>
#include <fstream>
#include <array>
#include <vector>

// SUBNODE_LAYOUT is now defined in main.mk
// #define SUBNODE_LAYOUT

#include "plumbing/defs.h"
#include "plumbing/coordinates.h"
#include "plumbing/timing.h"

#ifdef SUBNODE_LAYOUT
// VECTOR_SIZE is the vector width, in bytes, that the subnode layout is built for.
// A value defined before this point is respected; otherwise a default is
// selected below depending on the backend.
#ifndef VECTOR_SIZE
#if defined(CUDA) || defined(HIP)
#define VECTOR_SIZE 8 // Size of float, length 1 vectors
#else
#define VECTOR_SIZE (256 / 8) // this is for AVX2
#endif
#endif
// This is the vector size used to determine the layout:
// number_of_subnodes is the number of floats that fit in VECTOR_SIZE bytes
// (integer division, so any remainder is discarded).
constexpr unsigned number_of_subnodes = VECTOR_SIZE / sizeof(float);
#endif

namespace hila {
/// list of field boundary conditions - used only if SPECIAL_BOUNDARY_CONDITIONS defined
enum class bc {
    PERIODIC,
    ANTIPERIODIC,
    DIRICHLET
};

/// Tells whether a boundary condition needs communication:
/// false for DIRICHLET, true for every other boundary condition.
inline bool bc_need_communication(hila::bc bc) {
    return bc != hila::bc::DIRICHLET;
}

} // namespace hila

// Free-function declarations; the definitions are not in this header.
// NOTE(review): test_std_gathers() presumably checks the gathers built by
// lattice_struct::create_std_gathers() - confirm against the implementation.
void test_std_gathers();
void report_too_large_node(); // report on too large node

/// useful information about a node
struct node_info {
    // NOTE(review): presumably the lowest coordinate of the node and its
    // extent to each direction, as in lattice_struct::node_struct - confirm
    CoordinateVector min;
    CoordinateVector size;
    // site counts of each parity on the node
    unsigned evensites;
    unsigned oddsites;
};

/// Some backends need specialized lattice data
/// in loops. Forward declaration here and
/// implementations in backend headers
/// (included near the end of this file).
/// Loops generated by hilapp can access
/// this through lattice.backend_lattice
/// (that pointer member exists only when VANILLA is not defined).
struct backend_lattice_struct;

/// The lattice struct defines the lattice geometry and sets up MPI communication
/// patterns
class lattice_struct {
  private:
    CoordinateVector l_size; // lattice extent to each direction, returned by size()
    size_t l_volume = 0;     // use this to flag initialization

    // running number, identification of the lattice (TODO).
    // Initialized so that id() never reads an indeterminate value.
    int l_label = 0;

  public:
    /// Information about the node stored on this process
    struct node_struct {
        lattice_struct *parent = nullptr; // parent lattice, in order to access methods
        int rank;                         // rank of this node
        size_t sites, evensites, oddsites;
        size_t field_alloc_size;    // how many sites/node in allocations
        CoordinateVector min, size; // node local coordinate ranges
        unsigned nn[NDIRS];         // nn-node of node down/up to dirs
        bool first_site_even;       // is location min even or odd?

#ifdef EVEN_SITES_FIRST
        // coordinates of each site on this node, indexed by the site index
        std::vector<CoordinateVector> coordinates;
#endif

        Vector<NDIM, unsigned> size_factor; // components: 1, size[0], size[0]*size[1], ...

        void setup(node_info &ni, lattice_struct &lattice);

#ifdef SUBNODE_LAYOUT
        /// If we have vectorized-style layout, we introduce "subnodes"
        /// size is mynode.size/subnodes.divisions, which is not
        /// constant across nodes
        struct subnode_struct {
            CoordinateVector divisions, size; // div to subnodes to directions, size
            size_t sites, evensites, oddsites;
            Direction merged_subnodes_dir;

            void setup(const node_struct &tn);
        } subnodes;
#endif

        /// Number of sites on this node.
        /// Note: sites is size_t, the value is narrowed to unsigned on return.
        unsigned volume() const {
            return sites;
        }

        /// true if this node is on the edge of the lattice to dir d:
        /// for an up direction the node must reach the lattice size of the
        /// parent lattice, for a down direction the node must start at coordinate 0.
        /// Dereferences parent, which therefore must have been set.
        bool is_on_edge(Direction d) const {
            return (is_up_dir(d) && min[d] + size[d] == parent->size(d)) ||
                   (!is_up_dir(d) && min[-d] == 0);
        }

    } mynode;

    /// information about all nodes
    struct allnodes {
        int number;                   // number of nodes
        CoordinateVector n_divisions; // number of node divisions to dir
        // lattice division: div[d] will have num_dir[d]+1 elements, last size
        // TODO: is this needed at all?
        std::vector<unsigned> divisors[NDIM];
        CoordinateVector max_size; // size of largest node

        std::vector<node_info> nodelist;

        unsigned *RESTRICT map_array;   // mapping (optional)
        unsigned *RESTRICT map_inverse; // inv of it

        void create_remap();                      // create remap_node
        unsigned remap(unsigned i) const;         // use remap
        unsigned inverse_remap(unsigned i) const; // inverse remap

    } nodes;

    /// Information necessary to communicate with a node
    struct comm_node_struct {
        unsigned rank; // rank of communicated with node
        size_t sites, evensites, oddsites;
        size_t buffer; // offset from the start of field array
        unsigned *sitelist;

        // Get a vector containing the sites of parity par and number of elements.
        // The number of elements is written to size. For ALL and EVEN the start
        // of sitelist is returned, for the remaining parity the list is
        // offset by evensites.
        const unsigned *RESTRICT get_sitelist(Parity par, int &size) const {
            if (par == ALL) {
                size = sites;
                return sitelist;
            } else if (par == EVEN) {
                size = evensites;
                return sitelist;
            } else {
                size = oddsites;
                return sitelist + evensites;
            }
        }

        // The number of sites that need to be communicated
        unsigned n_sites(Parity par) const {
            if (par == ALL) {
                return sites;
            } else if (par == EVEN) {
                return evensites;
            } else {
                return oddsites;
            }
        }

        // The local index of a site that is sent to neighbour.
        // For parity ODD the argument site is counted from the entry
        // at position evensites of sitelist.
        unsigned site_index(unsigned site, Parity par) const {
            if (par == ODD) {
                return sitelist[evensites + site];
            } else {
                return sitelist[site];
            }
        }

        // The offset of the halo from the start of the field array.
        // Note: buffer is size_t, the result is narrowed to unsigned on return.
        unsigned offset(Parity par) const {
            if (par == ODD) {
                return buffer + evensites;
            } else {
                return buffer;
            }
        }
    };

    /// nn-communication has only 1 node to talk to
    struct nn_comminfo_struct {
        unsigned *index;
        comm_node_struct from_node, to_node;
        unsigned receive_buf_size; // only for general gathers
    };

    /// general communication
    struct gen_comminfo_struct {
        unsigned *index;
        std::vector<comm_node_struct> from_node;
        std::vector<comm_node_struct> to_node;
        size_t receive_buf_size;
    };

    /// nearest neighbour comminfo struct
    std::array<nn_comminfo_struct, NDIRS> nn_comminfo;

    /// Main neighbour index array
    unsigned *RESTRICT neighb[NDIRS];

    /// implement waiting using mask_t - unsigned char is good for up to 4 dim.
    /// Starts out as nullptr instead of an indeterminate pointer.
    dir_mask_t *RESTRICT wait_arr_ = nullptr;

#ifdef SPECIAL_BOUNDARY_CONDITIONS
    /// special boundary pointers are needed only in cases neighbour
    /// pointers must be modified (new halo elements). That is known only during
    /// runtime.
    struct special_boundary_struct {
        unsigned *neighbours;
        unsigned *move_index;
        size_t offset, n_even, n_odd, n_total;
        bool is_needed;
    };
    // holder for nb ptr info
    special_boundary_struct special_boundaries[NDIRS];
#endif

#ifndef VANILLA
    // backend specific lattice data; nullptr until it is assigned
    backend_lattice_struct *backend_lattice = nullptr;
#endif

    void setup(const CoordinateVector &siz);
    void setup_layout();
    void setup_nodes();

    // Std accessors:
    // volume
    int64_t volume() const {
        return l_volume;
    }

    // size routines
    int size(Direction d) const {
        return l_size[d];
    }
    int size(int d) const {
        return l_size[d];
    }
    CoordinateVector size() const {
        return l_size;
    }

    // rank of the node on this process
    int node_rank() const {
        return mynode.rank;
    }
    // number of nodes
    int n_nodes() const {
        return nodes.number;
    }
    // std::vector<node_info> nodelist() { return nodes.nodelist; }
    // CoordinateVector min_coordinate() const { return mynode.min; }
    // int min_coordinate(Direction d) const { return mynode.min[d]; }

    bool is_on_mynode(const CoordinateVector &c) const;
    int node_rank(const CoordinateVector &c) const;
    unsigned site_index(const CoordinateVector &c) const;
    unsigned site_index(const CoordinateVector &c, const unsigned node) const;

    // how many sites/node in allocations.
    // Note: mynode.field_alloc_size is size_t, narrowed to unsigned on return.
    unsigned field_alloc_size() const {
        return mynode.field_alloc_size;
    }

    void create_std_gathers();
    gen_comminfo_struct create_general_gather(const CoordinateVector &r);
    std::vector<comm_node_struct> create_comm_node_vector(CoordinateVector offset, unsigned *index,
                                                          bool receive);

    // is location mynode.min even or odd?
    bool first_site_even() const {
        return mynode.first_site_even;
    }

#ifdef SPECIAL_BOUNDARY_CONDITIONS
    void init_special_boundaries();
    void setup_special_boundary_array(Direction d);

    const unsigned *get_neighbour_array(Direction d, hila::bc bc);
#else
    // Without special boundary conditions the argument bc is ignored and
    // the main neighbour index array of direction d is returned
    const unsigned *get_neighbour_array(Direction d, hila::bc bc) {
        return neighb[d];
    }
#endif

    unsigned remap_node(const unsigned i);

#ifdef EVEN_SITES_FIRST
    // Site index range [loop_begin(P), loop_end(P)) of parity P:
    // indices below mynode.evensites belong to EVEN, the rest to ODD
    unsigned loop_begin(Parity P) const {
        if (P == ODD) {
            return mynode.evensites;
        } else {
            return 0;
        }
    }
    unsigned loop_end(Parity P) const {
        if (P == EVEN) {
            return mynode.evensites;
        } else {
            return mynode.sites;
        }
    }

    // coordinates are read from the stored table mynode.coordinates
    inline const CoordinateVector &coordinates(unsigned idx) const {
        return mynode.coordinates[idx];
    }

    inline int coordinate(unsigned idx, Direction d) const {
        return mynode.coordinates[idx][d];
    }

    inline Parity site_parity(unsigned idx) const {
        if (idx < mynode.evensites)
            return EVEN;
        else
            return ODD;
    }

#else // Now not EVEN_SITES_FIRST

    unsigned loop_begin(Parity P) const {
        assert(P == ALL && "Only parity ALL when EVEN_SITES_FIRST is off");
        return 0;
    }
    // Note: P is not inspected here, the end of all sites is always returned
    unsigned loop_end(Parity P) const {
        return mynode.sites;
    }

    // Use computation to get coordinates: from fastest
    // to lowest, dir = 0, 1, 2, ...
    // each coordinate is c[d] = (idx/size_factor[d]) % size[d] + min[d], but
    // do it like below to avoid the mod

    inline const CoordinateVector coordinates(unsigned idx) const {
        CoordinateVector c;
        unsigned vdiv, ndiv;

        vdiv = idx;
        for (int d = 0; d < NDIM - 1; ++d) {
            ndiv = vdiv / mynode.size[d];
            c[d] = vdiv - ndiv * mynode.size[d] + mynode.min[d];
            vdiv = ndiv;
        }
        c[NDIM - 1] = vdiv + mynode.min[NDIM - 1];

        return c;
    }

    inline int coordinate(unsigned idx, Direction d) const {
        return (idx / mynode.size_factor[d]) % mynode.size[d] + mynode.min[d];
    }

    inline Parity site_parity(unsigned idx) const {
        return coordinates(idx).parity();
    }

#endif

    // coordinates of site idx relative to the node origin mynode.min
    CoordinateVector local_coordinates(unsigned idx) const {
        return coordinates(idx) - mynode.min;
    }

    // Returns a copy of the nearest neighbour comminfo of direction index d.
    // Only reads the object, hence const.
    lattice_struct::nn_comminfo_struct get_comminfo(int d) const {
        return nn_comminfo[d];
    }

    /* MPI functions and variables. Define here in lattice? */
    void initialize_wait_arrays();

    MPI_Comm mpi_comm_lat;

    // Guarantee 64 bits for these - 32 can overflow!
    int64_t n_gather_done = 0, n_gather_avoided = 0;

    /// Return the coordinates of a site, where 1st dim (x) runs fastest etc.
    /// Useful in
    ///   for (int64_t i=0; i<lattice.volume(); i++) {
    ///      auto c = lattice.global_coordinates(i);

    CoordinateVector global_coordinates(size_t index) const {
        CoordinateVector site;
        foralldir(dir) {
            site[dir] = index % size(dir);
            index /= size(dir);
        }
        return site;
    }

    /// identification number of the lattice (l_label)
    int id() const {
        return l_label;
    }
};

/// global handle to lattice (declaration only, the object is defined elsewhere)
extern lattice_struct lattice;

// Keep track of defined lattices (pointers to lattice_struct objects;
// declaration only, the vector is defined elsewhere)
extern std::vector<lattice_struct *> lattices;


#if !defined(CUDA) && !defined(HIP)
/// Size of the global lattice to direction d, read from the global lattice handle
inline int loop_lattice_size(Direction d) {
    const int extent = lattice.size(d);
    return extent;
}
#else
// With CUDA or HIP only the declaration is given here
__device__ __host__ int loop_lattice_size(Direction d);
#endif


#ifdef VANILLA
#include "plumbing/backend_cpu/lattice.h"
#elif defined(CUDA) || defined(HIP)
#include "plumbing/backend_gpu/lattice.h"
#elif defined(VECTORIZED)
#include "plumbing/backend_vector/lattice_vector.h"
#endif


//////////////////////////////////////////////////////////////////////
// Define looping utilities
// forallcoordinates(cv)  - loops over coordinates in "natural" order
// forcoordinaterange(cv, min, max) - loops over a box subvolume in natural order
// Note - not meant for regular lattice traversal.
//
// In both macros direction 0 is the innermost (fastest running) loop.
// forcoordinaterange includes both end points: cmin[d] <= cv[d] <= cmax[d].
// The macros are defined only for NDIM == 1, 2, 3 or 4.
// The arguments are parenthesized in the expansion, so that expressions
// (e.g. "base + offset") can be used as arguments. The arguments are
// evaluated several times - avoid arguments with side effects.

// clang-format off
#if NDIM == 4

#define forallcoordinates(cv) \
for ((cv)[3] = 0; (cv)[3] < lattice.size(3); (cv)[3]++) \
for ((cv)[2] = 0; (cv)[2] < lattice.size(2); (cv)[2]++) \
for ((cv)[1] = 0; (cv)[1] < lattice.size(1); (cv)[1]++) \
for ((cv)[0] = 0; (cv)[0] < lattice.size(0); (cv)[0]++)

#define forcoordinaterange(cv,cmin,cmax) \
for ((cv)[3] = (cmin)[3]; (cv)[3] <= (cmax)[3]; (cv)[3]++) \
for ((cv)[2] = (cmin)[2]; (cv)[2] <= (cmax)[2]; (cv)[2]++) \
for ((cv)[1] = (cmin)[1]; (cv)[1] <= (cmax)[1]; (cv)[1]++) \
for ((cv)[0] = (cmin)[0]; (cv)[0] <= (cmax)[0]; (cv)[0]++)

#elif NDIM == 3

#define forallcoordinates(cv) \
for ((cv)[2] = 0; (cv)[2] < lattice.size(2); (cv)[2]++) \
for ((cv)[1] = 0; (cv)[1] < lattice.size(1); (cv)[1]++) \
for ((cv)[0] = 0; (cv)[0] < lattice.size(0); (cv)[0]++)

#define forcoordinaterange(cv,cmin,cmax) \
for ((cv)[2] = (cmin)[2]; (cv)[2] <= (cmax)[2]; (cv)[2]++) \
for ((cv)[1] = (cmin)[1]; (cv)[1] <= (cmax)[1]; (cv)[1]++) \
for ((cv)[0] = (cmin)[0]; (cv)[0] <= (cmax)[0]; (cv)[0]++)

#elif NDIM == 2

#define forallcoordinates(cv) \
for ((cv)[1] = 0; (cv)[1] < lattice.size(1); (cv)[1]++) \
for ((cv)[0] = 0; (cv)[0] < lattice.size(0); (cv)[0]++)

#define forcoordinaterange(cv,cmin,cmax) \
for ((cv)[1] = (cmin)[1]; (cv)[1] <= (cmax)[1]; (cv)[1]++) \
for ((cv)[0] = (cmin)[0]; (cv)[0] <= (cmax)[0]; (cv)[0]++)

#elif NDIM == 1

#define forallcoordinates(cv) \
for ((cv)[0] = 0; (cv)[0] < lattice.size(0); (cv)[0]++)

#define forcoordinaterange(cv,cmin,cmax) \
for ((cv)[0] = (cmin)[0]; (cv)[0] <= (cmax)[0]; (cv)[0]++)

#endif
// clang-format on

#endif
