/**
 * \copyright
 * Copyright (c) 2012-2019, OpenGeoSys Community (http://www.opengeosys.org)
 *            Distributed under a Modified BSD License.
 *              See accompanying file LICENSE.txt or
 *              http://www.opengeosys.org/project/license
 *
 */

#pragma once

#include <vector>

#include "GroundwaterFlowProcessData.h"
#include "MathLib/LinAlg/Eigen/EigenMapTools.h"
#include "NumLib/DOF/DOFTableUtil.h"
#include "NumLib/Extrapolation/ExtrapolatableElement.h"
#include "NumLib/Fem/FiniteElement/TemplateIsoparametric.h"
#include "NumLib/Fem/ShapeMatrixPolicy.h"
#include "ParameterLib/Parameter.h"
#include "ProcessLib/LocalAssemblerInterface.h"
#include "ProcessLib/LocalAssemblerTraits.h"
#include "ProcessLib/Utils/InitShapeMatrices.h"

namespace ProcessLib
{
namespace GroundwaterFlow
{
const unsigned NUM_NODAL_DOF = 1;

template <int GlobalDim>
Eigen::Matrix<double, GlobalDim, GlobalDim> hydraulicConductivity(
    std::vector<double> const& values)
{
    auto const size{values.size()};
    if (size == 1)  // This is a duplicate but preferred case for GlobalDim==1.
    {
        return Eigen::Matrix<double, GlobalDim, GlobalDim>::Identity() *
               values[0];
    }
    if (size == GlobalDim)
    {
        return Eigen::Map<Eigen::Matrix<double, GlobalDim, 1> const>(
                   values.data(), GlobalDim, 1)
            .asDiagonal();
    }
    if (size == GlobalDim * GlobalDim)
    {
        return Eigen::Map<Eigen::Matrix<double, GlobalDim, GlobalDim> const>(
            values.data(), GlobalDim, GlobalDim);
    }

    OGS_FATAL(
        "Hydraulic conductivity parameter values size is neither one nor %d "
        "nor %d squared, but %d.",
        GlobalDim, GlobalDim, values.size());
}

class GroundwaterFlowLocalAssemblerInterface
    : public ProcessLib::LocalAssemblerInterface,
      public NumLib::ExtrapolatableElement
{
public:
    virtual std::vector<double> const& getIntPtDarcyVelocity(
        const double /*t*/,
        GlobalVector const& /*current_solution*/,
        NumLib::LocalToGlobalIndexMap const& /*dof_table*/,
        std::vector<double>& /*cache*/) const = 0;
};

template <typename ShapeFunction, typename IntegrationMethod,
          unsigned GlobalDim>
class LocalAssemblerData : public GroundwaterFlowLocalAssemblerInterface
{
    using ShapeMatricesType = ShapeMatrixPolicyType<ShapeFunction, GlobalDim>;
    using ShapeMatrices = typename ShapeMatricesType::ShapeMatrices;

    using LocalAssemblerTraits = ProcessLib::LocalAssemblerTraits<
        ShapeMatricesType, ShapeFunction::NPOINTS, NUM_NODAL_DOF, GlobalDim>;

    using NodalMatrixType = typename LocalAssemblerTraits::LocalMatrix;
    using NodalVectorType = typename LocalAssemblerTraits::LocalVector;
    using GlobalDimVectorType = typename ShapeMatricesType::GlobalDimVectorType;

public:
    /// The hydraulic_conductivity factor is directly integrated into the local
    /// element matrix.
    LocalAssemblerData(MeshLib::Element const& element,
                       std::size_t const /*local_matrix_size*/,
                       bool is_axially_symmetric,
                       unsigned const integration_order,
                       GroundwaterFlowProcessData const& process_data)
        : _element(element),
          _process_data(process_data),
          _integration_method(integration_order),
          _shape_matrices(initShapeMatrices<ShapeFunction, ShapeMatricesType,
                                            IntegrationMethod, GlobalDim>(
              element, is_axially_symmetric, _integration_method))
    {
    }

    void assemble(double const t, std::vector<double> const& local_x,
                  std::vector<double>& /*local_M_data*/,
                  std::vector<double>& local_K_data,
                  std::vector<double>& /*local_b_data*/) override
    {
        auto const local_matrix_size = local_x.size();
        // This assertion is valid only if all nodal d.o.f. use the same shape
        // matrices.
        assert(local_matrix_size == ShapeFunction::NPOINTS * NUM_NODAL_DOF);

        auto local_K = MathLib::createZeroedMatrix<NodalMatrixType>(
            local_K_data, local_matrix_size, local_matrix_size);

        unsigned const n_integration_points =
            _integration_method.getNumberOfPoints();

        ParameterLib::SpatialPosition pos;
        pos.setElementID(_element.getID());

        for (unsigned ip = 0; ip < n_integration_points; ip++)
        {
            pos.setIntegrationPoint(ip);
            auto const& sm = _shape_matrices[ip];
            auto const& wp = _integration_method.getWeightedPoint(ip);
            auto const k = hydraulicConductivity<GlobalDim>(
                _process_data.hydraulic_conductivity(t, pos));

            local_K.noalias() += sm.dNdx.transpose() * k * sm.dNdx * sm.detJ *
                                 sm.integralMeasure * wp.getWeight();
        }
    }

    /// Computes the flux in the point \c p_local_coords that is given in local
    /// coordinates using the values from \c local_x.
    Eigen::Vector3d getFlux(MathLib::Point3d const& p_local_coords,
                            double const t,
                            std::vector<double> const& local_x) const override
    {
        // eval dNdx and invJ at p
        auto const fe = NumLib::createIsoparametricFiniteElement<
            ShapeFunction, ShapeMatricesType>(_element);

        typename ShapeMatricesType::ShapeMatrices shape_matrices(
            ShapeFunction::DIM, GlobalDim, ShapeFunction::NPOINTS);

        // Note: Axial symmetry is set to false here, because we only need dNdx
        // here, which is not affected by axial symmetry.
        fe.computeShapeFunctions(p_local_coords.getCoords(), shape_matrices,
                                 GlobalDim, false);

        // fetch hydraulic conductivity
        ParameterLib::SpatialPosition pos;
        pos.setElementID(_element.getID());
        auto const k = _process_data.hydraulic_conductivity(t, pos)[0];

        Eigen::Vector3d flux;
        flux.head<GlobalDim>() =
            -k * shape_matrices.dNdx *
            Eigen::Map<const NodalVectorType>(local_x.data(), local_x.size());

        return flux;
    }

    Eigen::Map<const Eigen::RowVectorXd> getShapeMatrix(
        const unsigned integration_point) const override
    {
        auto const& N = _shape_matrices[integration_point].N;

        // assumes N is stored contiguously in memory
        return Eigen::Map<const Eigen::RowVectorXd>(N.data(), N.size());
    }

    std::vector<double> const& getIntPtDarcyVelocity(
        const double t,
        GlobalVector const& current_solution,
        NumLib::LocalToGlobalIndexMap const& dof_table,
        std::vector<double>& cache) const override
    {
        auto const n_integration_points =
            _integration_method.getNumberOfPoints();

        auto const indices = NumLib::getIndices(_element.getID(), dof_table);
        assert(!indices.empty());
        auto const local_x = current_solution.get(indices);
        auto const local_x_vec =
            MathLib::toVector<Eigen::Matrix<double, ShapeFunction::NPOINTS, 1>>(
                local_x, ShapeFunction::NPOINTS);

        cache.clear();
        auto cache_mat = MathLib::createZeroedMatrix<
            Eigen::Matrix<double, GlobalDim, Eigen::Dynamic, Eigen::RowMajor>>(
            cache, GlobalDim, n_integration_points);

        ParameterLib::SpatialPosition pos;
        pos.setElementID(_element.getID());

        for (unsigned i = 0; i < n_integration_points; ++i)
        {
            pos.setIntegrationPoint(i);
            auto const k = _process_data.hydraulic_conductivity(t, pos)[0];
            // dimensions: (d x 1) = (d x n) * (n x 1)
            cache_mat.col(i).noalias() =
                -k * _shape_matrices[i].dNdx * local_x_vec;
        }

        return cache;
    }

private:
    MeshLib::Element const& _element;
    GroundwaterFlowProcessData const& _process_data;

    IntegrationMethod const _integration_method;
    std::vector<ShapeMatrices, Eigen::aligned_allocator<ShapeMatrices>>
        _shape_matrices;
};

}  // namespace GroundwaterFlow
}  // namespace ProcessLib
