/********************************************************
 *     _____________________
 *    / ____/  _/ ___/_  __/
 *   / / __ / / \__ \ / /   
 *  / /_/ // / ___/ // /    
 *  \____/___//____//_/ 
 * Geophysical Inversions using Spherical Tetrahedral meshes (GIST)
 *
 * Copyright (c) 2022  Yi Zhang (yizhang-geo@zju.edu.cn)
 *
 * GIST is distributed under a dual licensing scheme. You can redistribute 
 * it and/or modify it under the terms of the GNU Affero General Public 
 * License (AGPL) as published by the Free Software Foundation, either version 
 * 3 of the License, or (at your option) any later version. You should have 
 * received a copy of the GNU Affero General Public License along with this 
 * program. If not, see <http://www.gnu.org/licenses/>.
 * 
 * If the terms and conditions of the AGPL v.3. would prevent you from using 
 * the GIST, please consider the option to obtain a commercial license for a 
 * fee. These licenses are offered by the original author, Yi Zhang. As a rule, 
 * licenses are provided "as-is", unlimited in time for a one time fee. Please 
 * send corresponding requests to: yizhang-geo@zju.edu.cn. Please do not forget 
 * to include some description of your company and the realm of its activities. 
 * Also add information on how to contact you by electronic and paper mail.
 ******************************************************/

#include "thermal.h"

/**
 * @brief Solve the plain finite-element thermal system with preconditioned CG.
 *
 * @param t  Output node temperatures. Resized to node_num_ if needed and
 *           zero-initialised before being used as the CG starting guess.
 */
void GIST::Thermal::Solve(array<double> &t)
{
    // Select the unconstrained FEM mode before the CG solver is invoked.
    ts_type_ = SolveFEM;

    // Zero initial guess with one entry per mesh node.
    if (t.size() == node_num_)
    {
        t.assign_all(0.0);
    }
    else
    {
        t.resize(node_num_, 0.0);
    }

    // Diagonal (Jacobi-style) preconditioner taken from the kernel matrix.
    thermal_kernel_.get_diagonal(precndt_);

    // Run the conjugate-gradient base-class solver against the FEM target vector.
    LCG_Minimize(t, tar_, gctl::LCG_PCG);
}

/**
 * @brief Solve the thermal system with a reference-temperature constraint
 *        weighted by reg_strength_, using preconditioned CG.
 *
 * @param ref_t  Reference temperature model, one value per element
 *               (must have exactly elem_num_ entries).
 * @param t      Output node temperatures. Resized to node_num_ if needed and
 *               zero-initialised before being used as the CG starting guess.
 * @throws std::runtime_error if ref_t does not have elem_num_ entries.
 */
void GIST::Thermal::Solve(const array<double> &ref_t, array<double> &t)
{
    // Select the constrained FEM mode (set before validation, as before).
    ts_type_ = SolveCFEM;

    if (ref_t.size() != elem_num_)
        throw std::runtime_error("[GIST::Thermal::Solve] Invalid reference temperature model size.");

    // Working buffers for the constrained solve. elem_temper_ is not read in
    // this function; presumably it is used by the CG operator callback (confirm).
    cnst_tar_.resize(node_num_);
    node_temper_.resize(node_num_);
    elem_temper_.resize(elem_num_);

    // Constrained right-hand side:
    //   cnst_tar_ = tar_ + reg_strength_ * Node2Element^T(ref_t)
    m_space_->Node2Element(ref_t, node_temper_, gctl::Trans);
    scale(node_temper_, reg_strength_);
    cnst_tar_ = tar_ + node_temper_;

    // Zero initial guess with one entry per mesh node.
    if (t.size() == node_num_) t.assign_all(0.0);
    else t.resize(node_num_, 0.0);

    // Preconditioner: kernel diagonal plus a reg_strength_-scaled term produced
    // by the single-argument Node2Element overload into node_temper_ (presumably
    // the diagonal of the regularisation operator; verify against m_space_).
    thermal_kernel_.get_diagonal(precndt_);
    m_space_->Node2Element(node_temper_);
    scale(node_temper_, reg_strength_);
    precndt_ += node_temper_;

    // Run the conjugate-gradient base-class solver against the constrained target.
    LCG_Minimize(t, cnst_tar_, gctl::LCG_PCG);
}