/*********************************************************************************
* Copyright (C) 2015-2017 Alexey V. Akimov
*
* This file is distributed under the terms of the GNU General Public License
* as published by the Free Software Foundation, either version 2 of
* the License, or (at your option) any later version.
* See the file LICENSE in the root directory of this distribution
* or <http://www.gnu.org/licenses/>.
*
*********************************************************************************/

#include "../State.h"

/// liblibra namespace
namespace liblibra{

namespace libscripts{
namespace libstate{


void State::run_md_respa(){

  if(md==NULL) { std::cout<<"Error: MD parameters have not been defined\n"; exit(1);}
  if(!is_md_initialized){    std::cout<<"Error: Need to call init_md() first. MD is not initialized\n"; exit(2);   }

  int is_thermostat, is_barostat;
  is_thermostat = ((thermostat!=NULL) && ((md->ensemble=="NVT")||(md->ensemble=="NPT")||(md->ensemble=="NPT_FLEX")));
  is_barostat   = ((barostat!=NULL) && ((md->ensemble=="NPT")||(md->ensemble=="NPT_FLEX")||(md->ensemble=="NPH")||(md->ensemble=="NPH_FLEX")));

  double dt = md->dt;
  double dt_m = dt/((double)md->n_medium); // medium time step
  double dt_f = dt_m/((double)md->n_fast); // fast time step
  double dt_half = 0.5*dt;
  double dt_m_half = 0.5*dt_m;
  double dt_f_half = 0.5*dt_f;
  double Nf = system->Nf_t + system->Nf_r;
  int Nf_b = 0;
  if(is_barostat) {Nf_b = barostat->get_Nf_b();}
  double scl,sc3,sc4,ksi_r;
  MATRIX3x3 S,I,sc1,sc2;

//  cout<<"dt(slow) = "<<dt<<endl;
//  cout<<"dt_m = "<<dt_m<<endl;
//  cout<<"dt_f = "<<dt_f<<endl;

  double dt_b_half = dt_half/((double)md->n_outer);


  while(md->curr_step<md->max_step){

    // Operator NHCB(dt/2)
    for(int n_b=0;n_b<md->n_outer;n_b++){

    if(is_thermostat){
      double ekin_baro = 0.0;
      if(is_barostat){  ekin_baro = barostat->ekin_baro(); }
      thermostat->update_thermostat_forces(system->ekin_tr(),system->ekin_rot(),ekin_baro);
      thermostat->propagate_nhc(dt_b_half,system->ekin_tr(),system->ekin_rot(),ekin_baro);
    }


    if(is_thermostat){  thermostat->propagate_sPs(dt_half);    }

    //bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb
    // Operator B(dt/2)
    if(is_barostat){
      if(md->ensemble=="NPT"||md->ensemble=="NPH"){ barostat->update_barostat_forces(system->ekin_tr(),system->ekin_rot(),curr_V,curr_P);   }
      else if(md->ensemble=="NPT_FLEX"||md->ensemble=="NPH_FLEX"){ barostat->update_barostat_forces(system->ekin_tr(),system->ekin_rot(),curr_V,curr_P_tens);   }
      scl = 0.0; if(is_thermostat){ scl = thermostat->get_ksi_b();  }
      barostat->propagate_velocity(dt_b_half,scl);
    }
    //bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb

   }// for n_b


    double s_var,Ps,dt_half_s,dt_over_s,dt_over_s2,dt_f_over_s,dt_f_half_s,dt_f_over_s2;
    s_var = 1.0; Ps = 0.0;
    dt_half_s = dt_half;
    dt_f_half_s = dt_f_half;
    dt_over_s = dt;
    dt_over_s2 = dt;
    dt_f_over_s = dt_f;
    dt_f_over_s2 = dt_f;

    if(md->ensemble=="NVT"||md->ensemble=="NPT"||md->ensemble=="NPT_FLEX"||md->ensemble=="NPH"||md->ensemble=="NPH_FLEX"){
      if(is_thermostat){
        s_var = thermostat->get_s_var(); 
        dt_half_s = dt_half*s_var;
        dt_f_half_s = dt_f_half*s_var;
        dt_over_s = (dt/s_var);
        dt_f_over_s = (dt_f/s_var);
        dt_over_s2 = (dt_over_s/s_var);
        dt_f_over_s2 = (dt_f_over_s/s_var);
      }
    }

//    vector<VECTOR> f_fast,f_medium;
//    vector<VECTOR> t_fast,t_medium;
//    MATRIX3x3 s_fast,s_medium;
    double E_fast,E_medium;
//    cout<<"slow step\n";
    //aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa    
    // Operator A_slow(dt/2)
    //-------------------- Linear momentum propagation --------------------
    S = 0.0; I.identity();
    if(is_barostat){
      if(Nf_b==9){ S = barostat->ksi_eps + (barostat->ksi_eps.tr()/(barostat->get_Nf_t()/*+barostat->get_Nf_r()*/))*I; }
      else if(Nf_b==1){S = barostat->ksi_eps_iso * I + (3.0*barostat->ksi_eps_iso/(barostat->get_Nf_t()/*+barostat->get_Nf_r()*/))*I; }
    }
    if(is_thermostat){   S = S + thermostat->get_ksi_t() * I;      }
    sc1 = (exp_(S,-dt_half));//.symmetrized();
    sc2 = dt_half*(exp1_(S,-dt_half*0.5));//.symmetrized()*dt_half;

    //------------------- Angular momentum propagation -----------------------    
    if(is_thermostat){ ksi_r = thermostat->get_ksi_r();}else{ ksi_r = 0.0;}
    sc3 = exp(-dt_half*ksi_r);
    sc4 = dt_half*exp(-0.5*dt_half*ksi_r)*sinh_(0.5*dt_half*ksi_r);

    for(int i=0;i<system->Number_of_fragments;i++){
      RigidBody& top = system->Fragments[i].Group_RB;
      //-------------------- Linear momentum propagation --------------------
      top.scale_linear_(sc1);
      top.apply_force(sc2);
      //------------------- Angular momentum propagation -----------------------
      top.scale_angular_(sc3);
      top.apply_torque(sc4);       
    }// for i


    for(int n_m=0;n_m<md->n_medium;n_m++){  // medium size integration
//      cout<<"  medium step\n";
      // Operator A_medium(dt_medium/2)
      if(n_m==0){ system->load_respa_state("medium");  }

      for(i=0;i<system->Number_of_fragments;i++){
        RigidBody& top = system->Fragments[i].Group_RB;
        //-------------------- Linear momentum propagation --------------------
        top.apply_force(dt_m_half);
        //------------------- Angular momentum propagation -----------------------
        top.apply_torque(dt_m_half);
      }// for i


      for(int n_f=0;n_f<md->n_fast;n_f++){  // fast size integration
//        cout<<"    fast step\n";
        // Operator A_fast(dt_fast/2)
        if(n_f==0){ system->load_respa_state("fast"); }
        for(i=0;i<system->Number_of_fragments;i++){
          RigidBody& top = system->Fragments[i].Group_RB;
          //-------------------- Linear momentum propagation --------------------
          top.apply_force(dt_f_half);
          //------------------- Angular momentum propagation -----------------------
          top.apply_torque(dt_f_half);
        }// for i
   
        //if(is_thermostat){  thermostat->propagate_Ps(-dt_half*E_pot);    }
    //aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa

        //ccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc
        //ccccccccccccccccccccccccccccccc Core part ccccccccccccccccccccccccccccccccccccc
        // Operator core(dt_fast)
        sc1.identity();
        sc2.identity();
        sc2 = sc2 * dt_f;
        if(is_barostat){
          sc1 = (barostat->pos_scale(dt_f));
          sc2 = dt_f*barostat->vpos_scale(dt_f);
        }

        for(i=0;i<system->Number_of_fragments;i++){
          RigidBody& top = system->Fragments[i].Group_RB;    
          if(is_thermostat){  thermostat->propagate_Ps( 0.5*dt_f_over_s2*(top.ekin_rot()+top.ekin_tr()) ); }
          double Ps,s_var;
          s_var = 1.0;  if(is_thermostat){ s_var = thermostat->s_var; }
          Ps = 0.0;
          if(md->integrator=="Jacobi")    { top.propagate_exact_rb(dt_f_over_s); }
          else if(md->integrator=="DLML")  { top.propagate_dlml(dt_f_over_s,Ps); }
          else if(md->integrator=="Terec") { top.propagate_terec(dt_f_over_s);}
          else if(md->integrator=="qTerec") { top.propagate_qterec(dt_f_over_s);}
          else if(md->integrator=="NO_SQUISH"){ top.propagate_no_squish(dt_f_over_s);}
          else if(md->integrator=="KLN")   { top.propagate_kln(dt_f_over_s);}
          else if(md->integrator=="Omelyan"){ top.propagate_omelyan(dt_f_over_s);}

          if(is_thermostat){  thermostat->propagate_Ps( 0.5*dt_f_over_s2*(top.ekin_rot()+top.ekin_tr()) ); } 
          if(is_barostat) {  
            top.scale_position(sc1); 
            top.shift_position(sc2*top.rb_p*top.rb_iM);
          }
          else{
            top.shift_position(dt_f_over_s*top.rb_p*top.rb_iM);
          }        
        }// for i - all fragments

        if(is_thermostat){  thermostat->propagate_Ps(dt_f*( H0 - Nf*boltzmann*thermostat->Temperature*(log(thermostat->s_var)+1.0) ) ); }

        // Update cell shape
        if(is_barostat){ if(system->is_Box) {    system->Box  =  sc1 * system->Box;    }   }

        // Update atomic positions and calculate interactions
        for(i=0;i<system->Number_of_fragments;i++){ system->update_atoms_for_fragment(i);  }


        system->zero_forces_and_torques();
        E_pot = system->energy_respa("fast"); // only fast 
        // Update rigid-body variables
        system->update_fragment_forces_and_torques();

        if(n_f==(md->n_fast-1)){
          system->save_respa_state("fast");
          E_fast = E_pot;
        }
//        cout<<"    E_fast = "<<E_pot<<endl;

        //cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc

    //aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
        // Operator A_fast(dt_fast/2)
        for(i=0;i<system->Number_of_fragments;i++){
          RigidBody& top = system->Fragments[i].Group_RB;
          //-------------------- Linear momentum propagation --------------------
          top.apply_force(dt_f_half);
          //------------------- Angular momentum propagation -----------------------
          top.apply_torque(dt_f_half);
        }// for i

      }// for n_f

      // Operator A_medium(dt_medium/2)
      system->zero_forces_and_torques();
      E_pot = system->energy_respa("medium"); // only medium
      // Update rigid-body variables
      system->update_fragment_forces_and_torques();

      if(n_m==(md->n_medium-1)){
        system->save_respa_state("medium");
        E_medium = E_pot;
      }
//      cout<<"  E_medium = "<<E_pot<<endl;



      for(i=0;i<system->Number_of_fragments;i++){
        RigidBody& top = system->Fragments[i].Group_RB;
        //-------------------- Linear momentum propagation --------------------
        top.apply_force(dt_m_half);
        //------------------- Angular momentum propagation -----------------------
        top.apply_torque(dt_m_half);
      }// for i

    }// for n_m

    // Operator A_slow(dt_slow/2)
    system->zero_forces_and_torques();
    E_pot = system->energy_respa("slow"); // only slow
    // Update rigid-body variables
    system->update_fragment_forces_and_torques();

//    cout<<"E_slow = "<<E_pot<<endl;


   
    E_kin = 0.0;
    //-------------------- Linear momentum propagation --------------------
    S = 0.0; I.identity();
    if(is_barostat){
      if(Nf_b==9){ S = barostat->ksi_eps + (barostat->ksi_eps.tr()/(barostat->get_Nf_t()/*+barostat->get_Nf_r()*/))*I; }
      else if(Nf_b==1){S = barostat->ksi_eps_iso * I + (3.0*barostat->ksi_eps_iso/(barostat->get_Nf_t()/*+barostat->get_Nf_r()*/))*I; }
    }
    if(is_thermostat){   S = S + thermostat->get_ksi_t() * I;      }
    sc1 = (exp_(S,-dt_half));//.symmetrized();
    sc2 = dt_half*(exp1_(S,-dt_half*0.5));//.symmetrized()*dt_half;

    //------------------- Angular momentum propagation -----------------------
    if(is_thermostat){ ksi_r = thermostat->get_ksi_r();}else{ ksi_r = 0.0;}
    sc3 = exp(-dt_half*ksi_r);
    sc4 = dt_half*exp(-0.5*dt_half*ksi_r)*sinh_(0.5*dt_half*ksi_r);


    for(i=0;i<system->Number_of_fragments;i++){
      RigidBody& top = system->Fragments[i].Group_RB;
      //-------------------- Linear momentum propagation --------------------
      top.scale_linear_(sc1);
      top.apply_force(sc2);
      //------------------- Angular momentum propagation -----------------------
      top.scale_angular_(sc3);
      top.apply_torque(sc4);
      E_kin += (top.ekin_rot() + top.ekin_tr());
    }// for i

    //aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa

    // Here we restore total stress and energy
    E_pot += E_fast + E_medium;
//    system->increment_stress(s_fast);
//    system->increment_stress(s_medium);
//    for(i=0;i<system->Number_of_fragments;i++){
//        RigidBody& top = system->Fragments[i].Group_RB;
//        top.rb_force += (f_fast[i] + f_medium[i]);
//        top.rb_torque_e += (t_fast[i] + t_medium[i]);
//    }// for i


    if(is_thermostat){ thermostat->propagate_Ps( -dt_half*E_pot); }
   //------- Update state variables ------------
      curr_P_tens = system->pressure_tensor();
      curr_P = (curr_P_tens.tr()/3.0);
      curr_V = system->volume();
//   curr_P_tens = 0.0;
//   curr_P = 0.0;
//   curr_V = 1e+10;
   //-------------------------------------------


    for(n_b=0;n_b<md->n_outer;n_b++){

    //bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb
    // Operator B(dt/2)
    if(is_barostat){
      if(md->ensemble=="NPT"||md->ensemble=="NPH"){ barostat->update_barostat_forces(system->ekin_tr(),system->ekin_rot(),curr_V,curr_P);   }
      else if(md->ensemble=="NPT_FLEX"||md->ensemble=="NPH_FLEX"){ barostat->update_barostat_forces(system->ekin_tr(),system->ekin_rot(),curr_V,curr_P_tens);   }
      scl = 0.0; if(is_thermostat){ scl = thermostat->get_ksi_b();  }
      barostat->propagate_velocity(dt_b_half,scl);
    }

    //bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb

    if(is_thermostat){thermostat->propagate_sPs(dt_half); }

    // Operator NHCB(dt/2)
    if(is_thermostat){
      double ekin_baro = 0.0;
      if(is_barostat){  ekin_baro = barostat->ekin_baro(); }
      thermostat->update_thermostat_forces(system->ekin_tr(),system->ekin_rot(),ekin_baro);
      thermostat->propagate_nhc(dt_b_half,system->ekin_tr(),system->ekin_rot(),ekin_baro);
    }

    }// for n_b;


    if(md->ensemble=="NVE"){ s_var = 1.0;  Ps = 0.0;}
    if(is_thermostat){   E_kin/=(thermostat->s_var*thermostat->s_var); }

    E_kin_tr = system->ekin_tr();
    E_kin_rot = system->ekin_rot();
    E_tot = E_kin + E_pot;
//    if(!is_H0){ H0 = E_tot + thermostat->energy(); is_H0 = 1;}

    if(md->ensemble=="NVE"){  H_NP = E_tot; }
    else if(md->ensemble=="NVT"){
      if(is_thermostat){
        if(!is_H0){ H0 = E_tot + thermostat->energy(); is_H0 = 1;}
        if(thermostat->thermostat_type=="Nose-Poincare"){    H_NP = thermostat->s_var*(E_tot + thermostat->energy() - H0);   }
        else if(thermostat->thermostat_type=="Nose-Hoover"){ H_NP = E_tot + thermostat->energy();    }
      }
    }
    else if(md->ensemble=="NPH"||md->ensemble=="NPH_FLEX"){
      if(is_barostat){  H_NP = E_tot + barostat->ekin_baro() + curr_V * barostat->Pressure;   }
    }
    else if(md->ensemble=="NPT" || md->ensemble=="NPT_FLEX"){
      if(is_barostat){   H_NP = E_tot + barostat->ekin_baro() + curr_V * barostat->Pressure;  }
      if(is_thermostat){ H_NP += thermostat->energy();   }
    }

    curr_T = 2.0*E_kin/(Nf*boltzmann);

    //------------- Angular velocity --------------
    L_tot = 0.0; 
    P_tot = 0.0;
    for(i=0;i<system->Number_of_fragments;i++){
      RigidBody& top = system->Fragments[i].Group_RB;
      VECTOR tmp; tmp.cross(top.rb_cm,top.rb_p);
      L_tot += top.rb_A_I_to_e_T * top.rb_l_e + tmp;
      P_tot += top.rb_p;
    }

    md->curr_step++;
    md->curr_time+=dt;
      
  }// for s
  md->curr_step = 0;
  md->curr_time = 0.0;
}


}// namespace libstate
}// namespace libscripts
}// liblibra
