/********************************************************
 *     _____________________
 *    / ____/  _/ ___/_  __/
 *   / / __ / / \__ \ / /   
 *  / /_/ // / ___/ // /    
 *  \____/___//____//_/ 
 * Geophysical Inversions using Spherical Tetrahedral meshes (GIST)
 *
 * Copyright (c) 2022  Yi Zhang (yizhang-geo@zju.edu.cn)
 *
 * GIST is distributed under a dual licensing scheme. You can redistribute 
 * it and/or modify it under the terms of the GNU Affero General Public 
 * License (AGPL) as published by the Free Software Foundation, either version 
 * 3 of the License, or (at your option) any later version. You should have 
 * received a copy of the GNU Affero General Public License along with this 
 * program. If not, see <http://www.gnu.org/licenses/>.
 * 
 * If the terms and conditions of the AGPL v.3. would prevent you from using 
 * the GIST, please consider the option to obtain a commercial license for a 
 * fee. These licenses are offered by the original author, Yi Zhang. As a rule, 
 * licenses are provided "as-is", unlimited in time for a one time fee. Please 
 * send corresponding requests to: yizhang-geo@zju.edu.cn. Please do not forget 
 * to include some description of your company and the realm of its activities. 
 * Also add information on how to contact you by electronic and paper mail.
 ******************************************************/

#ifndef _GIST_MODELSPACE_H
#define _GIST_MODELSPACE_H

#include "gctl/core.h"
#include "gctl/geometry.h"
#include "gctl/io.h"
#include "gctl/algorithm.h"

#include "../utility/enum.h"
#include "../utility/data_type.h"
#include "../earth_1d/earth_1d.h"

using namespace gctl;

namespace GIST
{
    /**
     * @brief Modeling space built on a tetrahedral mesh stored in the Gmsh (.msh) format.
     *
     * This class handles the IO operations of Gmsh (.msh) mesh files: nodes, element
     * nodes, triangular faces, tetrahedral elements, physical groups and model data.
     * It also provides common processing of the modeling space, including node/element
     * data conversion, profile extraction, model weights and model constraints.
     * More information on each member function can be found in the corresponding
     * source files.
     *
     * @note NOTE(review): the class stores a Gmsh IO object (fio_) by value and declares
     * no copy/move operations here. Whether copying a ModelSpace is safe depends on
     * gmshio's copy semantics; confirm before copying instances.
     */
    class ModelSpace
    {
    public:
        /**
         * @brief Default constructor.
         */
        ModelSpace();

        /**
         * @brief Destructor. Declared virtual so the class can be used as a polymorphic base.
         */
        virtual ~ModelSpace();
        
        /**
         * @brief Initialize the model space from mesh arrays supplied by the caller
         * (presumably an alternative to ReadMesh(); confirm in the implementation).
         * 
         * @note NOTE(review): all arguments are taken by non-const reference. Whether they
         * are copied, swapped or otherwise modified is not visible from this header.
         * 
         * @param nodes Mesh vertices.
         * @param faces Triangular face elements.
         * @param elems Tetrahedral body elements.
         * @param phys Gmsh physical groups.
         * @param etags Tags of the body elements (matrix layout: TODO confirm).
         * @param ftags Tags of the face elements (matrix layout: TODO confirm).
         */
        void InitMesh(array<vertex3dc> &nodes, array<triangle> &faces, array<tetrahedron> &elems, array<gmsh_physical_group> &phys, matrix<int> &etags, matrix<int> &ftags);
        
        /**
         * @brief Read the Gmsh (.msh) model file.
         * 
         * @param filename Name of the Gmsh model file.
         * @param filemode Whether the starting index of the vertices is zero (Packed) or one (NotPacked).
         * @param with_tags Whether to read element tags from the model file.
         */
        void ReadMesh(std::string filename, index_packed_e filemode = NotPacked, msh_tag_e with_tags = NoTag);
        
        /**
         * @brief Read triangular face elements from the Gmsh (.msh) model file.
         * 
         * @note ReadMesh() has to be called before using this function.
         * 
         * @param with_tags Whether to read face tags from the model file.
         */
        void ReadMeshFaces(msh_tag_e with_tags = NoTag);
        
        /**
         * @brief Read node elements (element nodes) from the Gmsh (.msh) model file.
         * 
         * @note ReadMesh() has to be called before using this function.
         * 
         * @param with_tags Whether to read node tags from the model file.
         * @return Status flag. NOTE(review): presumably true when node elements were
         *         found and read; confirm in the implementation.
         */
        bool ReadMeshEnodes(msh_tag_e with_tags = NoTag);
        
        /**
         * @brief Read physical groups from the Gmsh (.msh) model file.
         * 
         * @note ReadMesh() has to be called before using this function.
         */
        void ReadPhysGroups();
        
        /**
         * @brief Read Gmsh (.msh) model data.
         * 
         * @note ReadMesh() has to be called before using this function.
         * 
         * @param dataname Name of the data.
         * @param data Output data array.
         */
        void ReadData(std::string dataname, array<double> &data);
        
        /**
         * @brief Read Gmsh (.msh) model data together with the nodes' or elements' index.
         * 
         * @note ReadMesh() has to be called before using this function.
         * 
         * @param dataname Name of the data.
         * @param d_type Type of the data.
         * @param index Output index of the data array. The index is ordered with respect to the
         *              stored node or element array, not the input model file.
         * @param data Output data array.
         */
        void ReadData(std::string dataname, mesh_data_type_e d_type, array<size_t> &index, array<double> &data);
        
        /**
         * @brief Create a new model space that is a part of the current model space.
         * 
         * @param out_ms Returned sub model space.
         * @param ele_names Element (physical group) names used to construct the new model space.
         * @param face_names Face (physical group) names used to construct the new model space.
         * @param sub2whole_ele_idx Optional output: index that maps the new model space's elements to the originating model space.
         * @param sub2whole_fac_idx Optional output: index that maps the new model space's facets to the originating model space.
         * @param sub2whole_node_idx Optional output: index that maps the new model space's nodes to the originating model space.
         */
        void CrtSubModelSpace(ModelSpace &out_ms, std::initializer_list<std::string> ele_names, std::initializer_list<std::string> face_names, 
            array<size_t> *sub2whole_ele_idx = nullptr, array<size_t> *sub2whole_fac_idx = nullptr, array<size_t> *sub2whole_node_idx = nullptr);
       
        /**
         * @brief Create element data according to the elements' physical tags and assigned values.
         * 
         * @param tag_str A string containing tag names and their values. A tag and its value are joined
         *                by '/', and different groups are separated by commas. Example: name1/2.5,name2/3.2
         * @param data Output data array. A zero value is assigned to all elements by default.
         */
        void CrtElemData(std::string tag_str, array<double> &data);
        
        /**
         * @brief Create element data using the 1D reference earth model.
         * 
         * @param data Output data array.
         * @param ref_earth The 1D reference earth model.
         * @param mod_type Targeted physical model type of the earth model.
         */
        void CrtElemData1DEarth(array<double> &data, const Earth1D &ref_earth, model_type_e mod_type);
        
        /**
         * @brief Create element data using a 1D linear trend.
         * 
         * @param data Output data array.
         * @param ref_earth The 1D reference earth model (presumably used to convert element
         *                  positions to depths; confirm in the implementation).
         * @param top_depth Top depth of the linear profile.
         * @param btm_depth Bottom depth of the linear profile.
         * @param top_val Top value of the linear profile.
         * @param btm_val Bottom value of the linear profile.
         */
        void CrtElemData1DLinear(array<double> &data, const Earth1D &ref_earth, double top_depth, double btm_depth, double top_val, double btm_val);
        
        /**
         * @brief Compute a layered model using linear interpolation.
         * 
         * @param elem_tags Names of the element groups inside the layered model (there may be several layers).
         * @param upfac_tag Name of the top-surface (upper face) group of the layered model.
         * @param dnfac_tag Name of the bottom-surface (lower face) group of the layered model.
         * @param up_val Model value on the top surface.
         * @param dn_val Model value on the bottom surface.
         * @param out_data Returned node data of the layered model.
         * @param out_idx Returned node index of the layered model.
         */
        void CrtNodeDataLayer(std::vector<std::string> elem_tags, std::string upfac_tag, std::string dnfac_tag, 
            double up_val, double dn_val, array<double> &out_data, array<size_t> &out_idx);
        
        /**
         * @brief Convert nodal data to elemental data.
         * 
         * @param in_data Input node data array.
         * @param out_data Output element data array.
         * @param trans Matrix layout of the conversion. NOTE(review): presumably Trans applies
         *              the transposed conversion operator; confirm in the implementation.
         */
        void Node2Element(const array<double> &in_data, array<double> &out_data, matrix_layout_e trans = NoTrans);

        /**
         * @brief Extract the diagonal elements of Q^T*Q.
         * 
         * NOTE(review): Q is presumably the node-to-element conversion matrix; confirm.
         * 
         * @param out_data Output array.
         */
        void Node2Element(array<double> &out_data);
        
        /**
         * @brief Convert elemental data to nodal data.
         * 
         * @param in_data Input element data array.
         * @param out_data Output node data array.
         * @param trans Matrix layout of the conversion. NOTE(review): presumably Trans applies
         *              the transposed conversion operator; confirm in the implementation.
         */
        void Element2Node(const array<double> &in_data, array<double> &out_data, matrix_layout_e trans = NoTrans);
        
        /**
         * @brief Extract element or node data at a group of given locations.
         * 
         * @param data Input element or node data array. Its size must equal the number of elements or nodes.
         * @param p_locs Locations of the given points.
         * @param p_data Output data array.
         * @param srad Search radius used for sampling around each point.
         */
        void ExtractLocsProfile(const array<double> &data, const array<point3dc> &p_locs, array<double> &p_data, double srad);
        
        /**
         * @brief Extract a vertical profile of the averaged element or node data along the radial direction.
         * 
         * @param data Input element or node data array. Its size must equal the number of elements or nodes.
         * @param ref_earth The 1D reference earth model (presumably used to convert depths to radii; confirm).
         * @param top_dep Top depth of the profile.
         * @param btm_dep Bottom depth of the profile.
         * @param ddep Spacing of the profile.
         * @param srad Search radius used for sampling around each point on the profile.
         * @param pro_data Output profile data.
         */
        void ExtractVerticProfile(const array<double> &data, const Earth1D &ref_earth, double top_dep, double btm_dep, double ddep, double srad, array<double> &pro_data);
        
        /**
         * @brief Save the model to a Gmsh (.msh) file.
         * 
         * @param filename Name of the output model file.
         * @param filemode Whether the node index starts from zero (Packed) or one (NotPacked).
         * @param fdata_idx Optional output: data index of the face elements, if there are any. Used when saving the faces' element data.
         * @param edata_idx Optional output: data index of the body elements. Used when saving the tetrahedra's element data.
         */
		void SaveMesh(std::string filename, index_packed_e filemode = NotPacked, array<size_t> *fdata_idx = nullptr, array<size_t> *edata_idx = nullptr);
        
        /**
         * @brief Save model data to the Gmsh (.msh) file.
         * 
         * @note SaveMesh() has to be called before using this function.
         * 
         * @param dataname Name of the data.
         * @param d_type Type of the data.
         * @param data Data array.
         * @param data_index Optional index of the data array.
         */
		void SaveData(std::string dataname, mesh_data_type_e d_type, const array<double> &data, array<size_t> *data_index = nullptr);
        
        /**
         * @brief Save a model data array to a text file.
         * 
         * @param filename Output file name.
         * @param dataname Data name.
         * @param ref_earth Object of the 1D reference earth model.
         * @param d_type Model data type.
         * @param data Input data array.
         * @param index Optional index of the input data array.
         */
        void SaveData2Text(std::string filename, std::string dataname, const Earth1D &ref_earth, gctl::mesh_data_type_e d_type, const array<double> &data, array<size_t> *index = nullptr);
        
        /**
         * @brief Initialize the smoothness model constraint with one preferred direction or varying directions.
         * 
         * @param smth_x x component of the fixed preferred direction applied to all elements.
         * @param smth_y y component of the fixed preferred direction applied to all elements.
         * @param smth_z z component of the fixed preferred direction applied to all elements.
         * @param smth_w Weight of the fixed-direction constraint.
         * @param smth_wgts Optional varying directional weights for different elements. Each 3D point gives a preferred
         *                  direction, and its weight is taken as the vector's length (modulus).
         */
        void InitMatRoughness(double smth_x, double smth_y, double smth_z, double smth_w, array<point3dc> *smth_wgts = nullptr);
        
        /**
         * @brief Initialize depth weights for gravity and geoid inversions (depth-based approach).
         * 
         * @param obsp Observation sites array.
         * @param beta Depth-weighting parameter.
         */
        void InitDepthWeight_DepthApproach(const array<point3ds> &obsp, double beta);
        
        /**
         * @brief Initialize depth weights for gravity and geoid inversions (distance-based approach).
         * 
         * @param obsp Observation sites array.
         * @param beta Depth-weighting parameter.
         */
        void InitDepthWeight_DistanceApproach(const array<point3ds> &obsp, double beta);

        /**
         * @brief Initialize model weights along a depth profile.
         * 
         * @param rads Radii of the profile points.
         * @param wgts Weights of the profile points, paired with rads.
         */
        void InitModelWeight_DepthProfile(const array<double> &rads, const array<double> &wgts);

        /**
         * @brief Initialize volume weights.
         * 
         * @param alpha Volume-weighting parameter.
         */
        void InitVolumeWeight(double alpha);
        
        /**
         * @brief Initialize the reference model (structural similarity) constraint.
         * 
         * @param ref_model The reference model.
         */
        void InitMatStructuralSimilarity(const array<double> &ref_model);

        /**
         * @brief Initialize a differential matrix that implements a monotonic constraint.
         * By default, bigger values are assumed to lie at deeper depths (the normal status).
         * If bigger values lie at shallower depths, set reversed to true.
         * 
         * @param reversed Whether to initialize in reverse mode.
         */
        void InitMatMonotonic(bool reversed = false);

        /**
         * @brief Set the weights for the minimal model constraint.
         * 
         * @param ele_wgts Optional individual element weights. The behavior when nullptr is
         *                 passed is defined in the implementation (TODO confirm).
         */
        void SetMinimalModelWeight(array<double> *ele_wgts = nullptr);
        
        /**
         * @brief Set the weights for the smoothness model constraint.
         * 
         * @param ele_wgts Optional individual element weights. The behavior when nullptr is
         *                 passed is defined in the implementation (TODO confirm).
         */
        void SetSmoothModelWeight(array<double> *ele_wgts = nullptr);
        
        /**
         * @brief Set the weights for the reference model constraint.
         * 
         * @param wgt Global weight.
         * @param ele_wgts Optional individual element weights.
         */
        void SetReferModelWeight(double wgt, array<double> *ele_wgts = nullptr);

        /**
         * @brief Set the weight of the monotonic model constraint.
         * 
         * @param wgt Global weight.
         */
        void SetMonotonicModelWeight(double wgt);
        
        /**
         * @brief Calculate the minimal model constraint value and its model gradients.
         * 
         * @param wgt Weight of the constraint.
         * @param model Model value array.
         * @param bkg_model Background (zero) model value array.
         * @param out_grad Output model gradient array.
         * @return Constraint value.
         */
        double MinimalModelConstraint(double wgt, const array<double> &model, const array<double> &bkg_model, array<double> &out_grad);
        
        /**
         * @brief Calculate the smoothness constraint value and its model gradients.
         * 
         * @param wgt Weight of the constraint.
         * @param model Input model array.
         * @param out_grad Output model gradient array.
         * @param bkg_model Optional background model that is subtracted from the model array.
         * @return Constraint value.
         */
        double SmoothModelConstraint(double wgt, const array<double> &model, array<double> &out_grad, array<double> *bkg_model = nullptr);
        
        /**
         * @brief Calculate the reference model constraint value and its model gradients.
         * 
         * @param model Model value array.
         * @param out_grad Output model gradient array.
         * @return Objective function's value.
         */
        double SimilarModelConstraint(const array<double> &model, array<double> &out_grad);

        /**
         * @brief Calculate the monotonic model constraint value and its model gradients.
         * 
         * @param model Model value array.
         * @param out_grad Output model gradient array.
         * @return Objective function's value.
         */
        double MonotonicModelConstraint(const array<double> &model, array<double> &out_grad);

        /**
         * @brief Export the node-element index of the given tag name.
         * 
         * @param node_name Tag of the targeted node elements.
         * @param node_idx Exported node-element index.
         */
        void ExportEnodeIndex(std::string node_name, array<size_t> &node_idx);
       
        /**
         * @brief Export the node index of the given face name.
         * 
         * @param fac_name Tag of the targeted faces.
         * @param node_idx Exported node index.
         * @param unpacked Export duplicated node indices organized facet-wise.
         */
        void ExportFaceNodeIndex(std::string fac_name, array<size_t> &node_idx, bool unpacked = false);
        
        /**
         * @brief Export the node index of the given element name.
         * 
         * @param ele_name Tag of the targeted elements.
         * @param node_idx Exported node index.
         */
        void ExportElemNodeIndex(std::string ele_name, array<size_t> &node_idx);

        /**
         * @brief Export the face index of the given tag name.
         * 
         * @param fac_name Tag of the targeted faces.
         * @param fac_idx Exported face index.
         */
        void ExportFaceIndex(std::string fac_name, array<size_t> &fac_idx);

        /**
         * @brief Export the elements' index of the given tag name(s).
         * 
         * @param ele_name Tag name of the targeted elements. NOTE(review): whether a single string
         *                 may list several names is not visible from this header; confirm.
         * @param ele_idx Exported elements' index.
         */
        void ExportElementIndex(std::string ele_name, array<size_t> &ele_idx);

        /**
         * @brief Get the model's vertices.
         * 
         * @return Vertex array (mutable reference).
         */
        array<vertex3dc> &get_node();
        
        /**
         * @brief Get the model's element nodes.
         * 
         * @return Element-node array (mutable reference).
         */
        array<enode> &get_enode();
        
        /**
         * @brief Return the face elements, if there are any.
         * 
         * @return Face array (mutable reference).
         */
        array<triangle> &get_face();
        
        /**
         * @brief Return the body elements.
         * 
         * @return Element array (mutable reference).
         */
		array<tetrahedron> &get_element();
        
        /**
         * @brief Get the physical groups.
         * 
         * @return Physical group array (mutable reference).
         */
        array<gmsh_physical_group> &get_physics();
        
        /**
         * @brief Get the volumes of the model elements.
         * 
         * @return Volume array (mutable reference).
         */
        array<double> &get_volume();
        
        /**
         * @brief Get elements' tags by type.
         * 
         * @param tag_t Tag type.
         * @return Tags of the targeted elements (mutable reference).
         */
        array<int> &get_tag(tag_type_e tag_t);
        
        /**
         * @brief Get the model's node number.
         * 
         * @return Node number.
         */
		size_t get_node_number() const;
        
        /**
         * @brief Get the model's face number.
         * 
         * @return Face number.
         */
        size_t get_face_number() const;
        
        /**
         * @brief Get the model's element number.
         * 
         * @return Element number.
         */
		size_t get_element_number() const;
        
        /**
         * @brief Get the center location of the inquired element.
         * 
         * @param id Element index.
         * @return Center location.
         */
        point3dc get_element_center(size_t id);
        
        /**
         * @brief Get the depth weight array.
         * 
         * @return Weight array (mutable reference).
         */
        array<double> &get_depth_weight();
    
        /**
         * @brief Get the volume weight array.
         * 
         * @return Weight array (mutable reference).
         */
        array<double> &get_volume_weight();

        /**
         * @brief Get the model weight array.
         * 
         * @return Weight array (mutable reference).
         */
        array<double> &get_model_weight();

    private:
        gmshio fio_; // Gmsh (.msh) file IO object
        // Start model variables
		size_t node_num_, enode_num_, face_num_, elem_num_;
		array<vertex3dc> nodes_;
        array<enode> enodes_;
        array<triangle> faces_;
		array<tetrahedron> elems_;
        array<gmsh_physical_group> phys_;
        array<double> elem_vols_;
        // Readiness flags. NOTE(review): presumably set by the corresponding Read*/Save* calls; confirm.
        bool in_ready_, out_ready_, face_ready_, phys_ready_;
        bool face_tag_ready_, elem_tag_ready_, node_tag_ready_;

        array<int> node_ftags_, node_etags_, face_tags_, elem_tags_, enode_tags_;
        matrix<int> etags_, ftags_, ntags_;
        
        gctl::linear_sf lsf_;
        spmat<double> n2e_kernel_; // presumably the operator used by Node2Element(); confirm
        spmat<double> e2n_kernel_; // presumably the operator used by Element2Node(); confirm
        std::vector<std::vector<gctl::tetrahedron *>> node_neigh_; // per-node lists of non-owning tetrahedron pointers
        // End model variables

        // Start constraint variables
        size_t cmn_face_num_;
        array<cmn_triangle> cmn_faces_;

        spmat<double> mono_mat_; // Monotonic model difference matrix
        spmat<double> smooth_mat_; // Smoothness matrix
        array<double> face_diff_;
        array<double> grad_dm_;
        //array<double> smth_mdl_, smth_face_diff_; // Variables used for testing (disabled)

        spmat<double> ref1_mat_; // Reference model matrices
        spmat<double> ref2_mat_;
        spmat<double> ref3_mat_;
        spmat<double> ref4_mat_;
        spmat<double> ref5_mat_;
        spmat<double> ref6_mat_;
        array<double> ref_diff_;

        bool mwgt_ready_, swgt_ready_, rwgt_ready_, modwgt_ready_, monowgt_ready_;
        double rwgt_, mono_wgt_;
        array<double> mini_wgts_, smth_wgts_, refm_wgts_, model_wgts_;

        bool vwgt_ready_, dwgt_ready_;
        array<double> vol_wgts_, dep_wgts_;
        // End constraint variables
    };
}

#endif // _GIST_MODELSPACE_H