! ***********************************************************************
!
!   Copyright (C) 2012  Bill Paxton
!
!   MESA is free software; you can use it and/or modify
!   it under the combined terms and restrictions of the MESA MANIFESTO
!   and the GNU General Library Public License as published
!   by the Free Software Foundation; either version 2 of the License,
!   or (at your option) any later version.
!
!   You should have received a copy of the MESA MANIFESTO along with
!   this software; if not, it is available at the mesa website:
!   http://mesa.sourceforge.net/
!
!   MESA is distributed in the hope that it will be useful,
!   but WITHOUT ANY WARRANTY; without even the implied warranty of
!   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
!   See the GNU Library General Public License for more details.
!
!   You should have received a copy of the GNU Library General Public License
!   along with this software; if not, write to the Free Software
!   Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
!
! ***********************************************************************

      module solve_coupled_burn_mix

      use star_private_def
      

      implicit none

      
      ! Layout of the integer parameter array (ipar) that do_isolve_burn_mix
      ! fills in and passes to isolve, and that the callbacks in this module read.
      ! ipar(i_id): the star id (s% id); callbacks use it with get_star_ptr.
      integer, parameter :: i_id = 1
      ! ipar(i_cnt): set to 0 before the integration and incremented by
      ! burn_mix_jac on each evaluation; 0 means the net rates still need
      ! to be calculated, > 0 means the given rates are reused.
      integer, parameter :: i_cnt = 2
      ! ipar(i_mujac): upper bandwidth (mujac) used to index the banded jacobian.
      integer, parameter :: i_mujac = 3
      ! number of entries in ipar
      integer, parameter :: burn_mix_lipar = 3

      
      contains

      
      
      ! Advance the coupled burn + mix equations for all cells with one of the
      ! isolve implicit solvers.
      !
      ! The abundances s% xa are handed to isolve as a single solution vector of
      ! length species*nz and integrated from t = 0 to t = dt.  The right hand
      ! side and the analytic jacobian come from burn_mix_fcn / burn_mix_jac;
      ! the jacobian is banded with mljac = mujac = 2*species - 1.
      !
      ! s    : the star being evolved.
      ! dt   : length of the integration interval.
      ! ierr : 0 on success.  Nonzero if the solver name is rejected by
      !        solver_option, the work arrays cannot be allocated, isolve returns
      !        idid < 0, fixup_xa rejects the final abundances, or
      !        calc_eps_nuc_info fails for some cell.
      subroutine do_isolve_burn_mix(s, dt, ierr)
         use num_lib
         use const_def
         use rates_def
         use chem_def
         use mtx_def
         use mtx_lib
         
         type (star_info), pointer :: s
         real(dp), intent(in) :: dt
         integer, intent(out) :: ierr

         
         integer, pointer :: iwork(:), ipar_decsol(:)
         real(dp), pointer :: work(:), rpar_decsol(:)
         integer :: liwork, lwork, max_steps, idid, lrd, lid, itol, iout, lout, &
            species, nz, num_vars, op_err, k
         integer :: ijac, mljac, mujac, imas, mlmas, mumas, nzmax, isparse, which_solver
         integer, parameter :: lrpar = 0, lipar = burn_mix_lipar
         integer :: ipar(lipar)
         real(dp) :: atol(1), rtol(1), rpar(lrpar), &
            t_start, t_stop, max_step_size, init_step_size
         logical :: trace
         
         include 'formats.dek'
         
         ierr = 0
         species = s% species
         nz = s% nz
         num_vars = species*nz
         trace = s% op_split_burn_mix_trace
         
         ! start with the work pointers in a defined (disassociated) state so
         ! that dealloc is safe to call after a partially failed allocate.
         nullify(iwork, work, ipar_decsol, rpar_decsol)
         
         if (trace) &
            write(*,2) 'do_isolve_burn_mix ' // trim(s% op_split_burn_solver1), &
               s% model_number, dt

         which_solver = solver_option(s% op_split_burn_solver1, ierr)
            ! ros2_solver
            ! rose2_solver
            ! ros3p_solver
            ! ros3pl_solver
            ! rodas3_solver
            ! rodas4_solver
            ! rodasp_solver
            ! seulex_solver
            ! sodex_solver
         if (ierr /= 0) then
            ! do not run isolve with a solver code that solver_option flagged as bad.
            if (s% report_ierr) write(*,*) &
               'do_isolve_burn_mix: bad op_split_burn_solver1 ' // &
               trim(s% op_split_burn_solver1)
            return
         end if
         
         ijac = 1 ! analytic jacobian
         mljac = 2*species - 1
         mujac = mljac
         imas = 0
         mlmas = 0
         mumas = 0        
         nzmax = 0
         isparse = 0
         iout = 1
         lout = 0
         
         call lapack_work_sizes(num_vars, lrd, lid)
         call isolve_work_sizes( &
            num_vars, nzmax, imas, mljac, mujac, mlmas, mumas, liwork, lwork)
            
         allocate(iwork(liwork), work(lwork), ipar_decsol(lid), rpar_decsol(lrd), &
            stat=ierr)
         if (ierr /= 0) then
            if (s% report_ierr) write(*,*) 'allocate failed in do_isolve_burn_mix'
            call dealloc ! release whichever arrays did get allocated
            return
         end if

         iwork(1:liwork) = 0
         work(1:lwork) = 0
         
         itol = 0
         rtol(1) = s% op_split_burn_rtol
         atol(1) = s% op_split_burn_atol
         max_steps = s% op_split_burn_maxsteps
         max_step_size = dt
         init_step_size = dt/20
         
         ipar(i_id) = s% id
         ipar(i_cnt) = 0 ! means 1st time so need to calculate net rates
         ipar(i_mujac) = mujac
         
         t_start = 0
         t_stop = dt

         call isolve( &
            which_solver, num_vars, burn_mix_fcn, t_start, s% xa, t_stop, &
            init_step_size, max_step_size, max_steps, &
            rtol, atol, itol, &
            burn_mix_jac, ijac, null_sjac, nzmax, isparse, mljac, mujac, &
            null_mas, imas, mlmas, mumas, &
            burn_mix_solout, iout, &
            lapack_decsol, null_decsols, lrd, rpar_decsol, lid, ipar_decsol, &
            work, lwork, iwork, liwork, &
            lrpar, rpar, lipar, ipar, &
            lout, idid)

         call dealloc
         
         if (idid < 0) then
            if (s% report_ierr) write(*,*) 'failed in do_isolve_burn_mix'
            ierr = -1
         end if
         
         if (ierr == 0) call fixup_xa(s, s% xa, ierr)
         
         if (ierr == 0) then ! update eps_nuc info with final abundances
            op_err = 0
            do k=1,nz
               call calc_eps_nuc_info(s, k, species, op_err)
               if (op_err /= 0) ierr = op_err
               ! after a failure the remaining cells are still updated,
               ! but the sanity check on X is skipped.
               if (ierr /= 0) cycle
               if (s% X(k) < 0 .or. s% X(k) > 1) then
                  write(*,2) 's% X(k)', k, s% X(k)
                  stop 'do_isolve_burn_mix'
               end if
            end do
         end if
         
         !if (trace) stop 'do_isolve_burn_mix'
         
         
         contains
      
      
         ! free the solver work arrays; skips any that are not associated.
         subroutine dealloc
            if (associated(iwork)) deallocate(iwork)
            if (associated(work)) deallocate(work)
            if (associated(ipar_decsol)) deallocate(ipar_decsol)
            if (associated(rpar_decsol)) deallocate(rpar_decsol)
         end subroutine dealloc
         

      end subroutine do_isolve_burn_mix
      
      
      ! Refresh the composition summary and the net results for cell k from
      ! the current abundances s% xa(:,k).
      !
      ! s       : the star.
      ! k       : cell index.
      ! species : number of species (sizes the local derivative arrays and the
      !           abundance slice passed to composition_info).
      ! ierr    : error flag as returned by do1_net.
      subroutine calc_eps_nuc_info(s, k, species, ierr)
         use chem_lib, only: composition_info
         use micro, only: do1_net
         type (star_info), pointer :: s
         integer, intent(in) :: k, species
         integer, intent(out) :: ierr

         ! T and Rho are unchanged, only the composition is new,
         ! so the rates already in hand can be reused by the net.
         logical, parameter :: keep_rates = .true.

         ! outputs of composition_info that are not needed here
         real(dp) :: total_x
         real(dp) :: d_abar(species), d_zbar(species)

         ! composition summary for the cell: X, Y, abar, zbar, z2bar, ye, ...
         call composition_info( &
            species, s% chem_id, s% xa(1:species,k), &
            s% X(k), s% Y(k), &
            s% abar(k), s% zbar(k), s% z2bar(k), s% ye(k), &
            s% approx_abar(k), s% approx_zbar(k), &
            total_x, d_abar, d_zbar)

         ! update eps_nuc for the new composition
         call do1_net( &
            s, k, s% species, s% num_reactions, s% net_lwork, &
            keep_rates, ierr)

      end subroutine calc_eps_nuc_info




      
      
      ! Right hand side callback for isolve: returns f = dy/dx only.
      ! Implemented by calling burn_mix_jac with a jacobian array whose leading
      ! dimension is 0, which tells it to skip the partials and just set f.
      subroutine burn_mix_fcn(n, x, y, f, lrpar, rpar, lipar, ipar, ierr)
         use const_def, only: dp

         integer, intent(in) :: n, lrpar, lipar
         real(dp), intent(in) :: x
         ! y may be edited by the callee (e.g., negative values replaced by zeros)
         real(dp), intent(inout) :: y(n)
         real(dp), intent(out) :: f(n) ! dy/dx
         real(dp), intent(inout), target :: rpar(lrpar)
         integer, intent(inout), target :: ipar(lipar)
         integer, intent(out) :: ierr ! nonzero means retry with smaller timestep.

         ! zero-size placeholder for the jacobian argument
         integer, parameter :: ld_none = 0
         real(dp) :: no_partials(ld_none, n)

         call burn_mix_jac( &
            n, x, y, f, no_partials, ld_none, lrpar, rpar, lipar, ipar, ierr)

      end subroutine burn_mix_fcn


      ! Evaluate the burn + mix right hand side and, when ldfy > 0, its banded
      ! analytic jacobian.
      !
      ! y and f are viewed as (species, nz) arrays xa and dxdt.  For cell k and
      ! species j,
      !    dxdt(j,k) = s% dxdt_nuc(j,k)
      !              + (xa(j,k-1) - xa(j,k))*s% sig(k)/s% dm(k)      [if k > 1]
      !              + (xa(j,k+1) - xa(j,k))*s% sig(k+1)/s% dm(k)    [if k < nz]
      !
      ! n            : number of unknowns; must equal species*nz.
      ! x            : independent variable (not used in the evaluation).
      ! y            : abundances; cleaned up in place by fixup_xa.
      ! f            : dy/dx.
      ! dfdy, ldfy   : jacobian storage and its leading dimension (see below).
      ! rpar, ipar   : ipar holds the star id, the call counter, and mujac.
      ! ierr         : nonzero if fixup_xa or the net evaluation fails.
      subroutine burn_mix_jac(n, x, y, f, dfdy, ldfy, lrpar, rpar, lipar, ipar, ierr)
         use const_def, only: dp
         use utils_lib, only: set_pointer_2
         integer, intent(in) :: n, ldfy, lrpar, lipar
         real(dp), intent(in) :: x
         real(dp), intent(inout) :: y(n)
         real(dp), intent(out) :: f(n) ! dy/dx
         real(dp), intent(out) :: dfdy(ldfy, n)
            ! ldfy=0 means skip partials, just set f
            ! dense: dfdy(i, j) = partial f(i) / partial y(j)
            ! banded: dfdy(i-j+mujac+1, j) = partial f(i) / partial y(j)
               ! uses rows 1 to mljac+mujac+1 of dfdy.
               ! The j-th column of the square matrix is stored in the j-th column of the
               ! array dfdy as follows:
               ! dfdy(mujac+1+i-j, j) = partial f(i) / partial y(j)
               ! for max(1, j-mujac)<=i<=min(N, j+mljac)
         real(dp), intent(inout), target :: rpar(lrpar)
         integer, intent(inout), target :: ipar(lipar)
         integer, intent(out) :: ierr ! nonzero means terminate integration
         
         integer :: id, species, nz, k, j, jj, op_err, mujac
         logical :: reuse_given_rates
         type (star_info), pointer :: s
         real(dp), pointer :: xa(:,:), dxdt(:,:)
         real(dp) :: dm, dx00, dxp1, sig00_dm, sigp1_dm
         
         include 'formats.dek'
         
         ierr = 0       
         id = ipar(i_id)
         call get_star_ptr(id, s, ierr)
         if (ierr /= 0) then
            write(*,*) 'bad id for burn_mix_jac', id
            stop
         end if
         nz = s% nz
         species = s% species
         if (n /= nz*species) then
            write(*,*) 'bad n for burn_mix_jac', n, species*nz, species, nz
            stop
         end if
         
         ! 2D (species, nz) views of the solver's 1D vectors
         call set_pointer_2(xa, y, species, nz)
         call set_pointer_2(dxdt, f, species, nz)

         call fixup_xa(s, xa, ierr)
         if (ierr /= 0) return

         ! rates are calculated on the first evaluation (counter == 0) and
         ! reused on all later ones.
         reuse_given_rates = (ipar(i_cnt) > 0)
         ipar(i_cnt) = ipar(i_cnt) + 1
         
         mujac = ipar(i_mujac)
         
         if (ldfy /= 0) dfdy = 0
         
         do k=1,nz
            call set1_net(k, op_err)
            if (op_err /= 0) ierr = op_err
            ! after an error keep looping (not exit) so that set1_net is still
            ! called for every remaining cell; only the rhs/jacobian work is skipped.
            ! NOTE(review): presumably this matters because the call counter has
            ! already been incremented above -- confirm before changing to exit.
            if (ierr /= 0) cycle
            dm = s% dm(k)
            if (k > 1) then
               sig00_dm = s% sig(k)/dm
            else
               sig00_dm = 0
            end if
            if (k < nz) then
               sigp1_dm = s% sig(k+1)/dm
            else
               sigp1_dm = 0
            end if
            do j=1,species
               dxdt(j,k) = s% dxdt_nuc(j,k)
               if (k > 1) then ! change from above
                  dx00 = xa(j,k-1) - xa(j,k)
                  dxdt(j,k) = dxdt(j,k) + dx00*sig00_dm
                  if (ldfy > 0) call setm1(j,j,k,sig00_dm)
               end if
               if (k < nz) then ! change from below
                  dxp1 = xa(j,k+1) - xa(j,k)
                  dxdt(j,k) = dxdt(j,k) + dxp1*sigp1_dm
                  if (ldfy > 0) call setp1(j,j,k,sigp1_dm)
               end if
               if (ldfy > 0) call set00(j,j,k,-(sig00_dm + sigp1_dm))
            end do
            if (ldfy > 0) then
               ! burn partials for this cell.  First index innermost; each
               ! jacobian entry still gets its mixing term (above) before its
               ! burn term, so the accumulated values are unchanged.
               do jj=1,species
                  do j=1,species
                     call set00(j,jj,k,s% d_dxdt_dx(j,jj,k))
                  end do
               end do
            end if
         end do
         
         
         contains
         
         
         ! The three helpers below add val to the banded jacobian entry
         ! dfdy(mujac+1+i-ii, ii) = partial f(i) / partial y(ii)
         ! where i is species j of cell k and ii is species jj of
         ! cell k (set00), cell k+1 (setp1), or cell k-1 (setm1).
         subroutine set00(j,jj,k,val)
            integer, intent(in) :: j, jj, k
            real(dp), intent(in) :: val
            integer :: kk, i, ii
            if (val == 0) return
            kk = (k-1)*species
            i = kk + j
            ii = kk + jj
            dfdy(mujac+1+i-ii, ii) = dfdy(mujac+1+i-ii, ii) + val
         end subroutine set00
         
         
         subroutine setp1(j,jj,k,val)
            integer, intent(in) :: j, jj, k
            real(dp), intent(in) :: val
            integer :: kk, i, ii
            if (val == 0) return
            kk = (k-1)*species
            i = kk + j
            ii = kk + species + jj
            dfdy(mujac+1+i-ii, ii) = dfdy(mujac+1+i-ii, ii) + val
         end subroutine setp1
         
         
         subroutine setm1(j,jj,k,val)
            integer, intent(in) :: j, jj, k
            real(dp), intent(in) :: val
            integer :: kk, i, ii
            if (val == 0) return
            kk = (k-1)*species
            i = kk + j
            ii = kk - species + jj
            dfdy(mujac+1+i-ii, ii) = dfdy(mujac+1+i-ii, ii) + val
         end subroutine setm1
      
      
         ! update the composition summary and run the net for cell k using
         ! the solver's current abundances xa(:,k).
         subroutine set1_net(k, ierr)
            use chem_lib, only: composition_info
            use micro, only: do1_net
            integer, intent(in) :: k
            integer, intent(out) :: ierr
            real(dp) :: sumx, dabar_dx(species), dzbar_dx(species)
            call composition_info( &
               species, s% chem_id, xa(1:species,k), s% X(k), s% Y(k), &
               s% abar(k), s% zbar(k), s% z2bar(k), s% ye(k), &
               s% approx_abar(k), s% approx_zbar(k), sumx, dabar_dx, dzbar_dx)  
            call do1_net( &
               s, k, s% species, s% num_reactions, s% net_lwork, reuse_given_rates, ierr)
         end subroutine set1_net
         
         
      end subroutine burn_mix_jac
 

      ! Solution-output callback for isolve.  Cleans up the current abundances
      ! with fixup_xa and asks the solver to stop (irtrn = -1) if they are
      ! rejected.  Optionally writes a trace line with the step number and x.
      subroutine burn_mix_solout( &
            nr, xold, x, n, y, rwork_y, iwork_y, interp_y, lrpar, rpar, lipar, ipar, irtrn)
         ! nr is the step number.
         ! x is the current x value; xold is the previous x value.
         ! y is the current y value.
         ! irtrn negative means terminate integration.
         ! rwork_y and iwork_y hold info for interp_y
         ! note that these are not the same as the rwork and iwork arrays for the solver.
         use const_def, only: dp
         use utils_lib, only: set_pointer_2
         integer, intent(in) :: nr, n, lrpar, lipar
         real(dp), intent(in) :: xold, x
         real(dp), intent(inout) :: y(n)
         ! y can be modified if necessary to keep it in valid range of possible solutions.
         real(dp), intent(inout), target :: rpar(lrpar), rwork_y(*)
         integer, intent(inout), target :: ipar(lipar), iwork_y(*)
         interface
            include 'num_interp_y.dek'
         end interface
         integer, intent(out) :: irtrn ! < 0 causes solver to return to calling program.
         
         integer :: ierr, id, species, nz
         type (star_info), pointer :: s
         logical :: trace
         real(dp), pointer :: xa(:,:)
         
         include 'formats.dek'
         
         irtrn = 0
         ierr = 0       
         id = ipar(i_id)
         call get_star_ptr(id, s, ierr)
         if (ierr /= 0) then
            write(*,*) 'bad id for burn_mix_solout', id
            stop
         end if
         
         trace = s% op_split_burn_mix_trace         
         if (trace) write(*,2) 'burn_mix_solout', nr, x
         nz = s% nz
         species = s% species
         ! view y as xa(species, nz) and fix it up in place
         call set_pointer_2(xa, y, species, nz)
         call fixup_xa(s, xa, ierr)
         if (ierr /= 0) irtrn = -1

      end subroutine burn_mix_solout
      
      
      
      ! Clean up the abundances xa(1:species,1:nz) in place, cell by cell:
      !   1. reject the cell if any abundance is below s% min_xa_hard_limit,
      !   2. clip each abundance to the range [0, 1],
      !   3. reject the cell if the clipped sum differs from 1 by more than
      !      s% sum_xa_tolerance,
      !   4. renormalize so the abundances sum to 1.
      !
      ! s    : the star (supplies nz, species, and the two limits).
      ! xa   : abundances, modified in place.
      ! ierr : 0 on success; -1 as soon as a cell is rejected (a message is
      !        written and cells after it are left untouched).
      !
      ! Both tests are written as .not.(a >= b) / .not.(a <= b) rather than
      ! (a < b) / (a > b).  For finite values the two forms are equivalent, but
      ! every comparison with a NaN is false, so the forms used here also reject
      ! NaN abundances instead of letting them through.
      subroutine fixup_xa(s, xa, ierr)
         type (star_info), pointer :: s
         real(dp), pointer :: xa(:,:)
         integer, intent(out) :: ierr
         integer :: k, species, nz
         real(dp) :: xsum
         include 'formats.dek'
         ierr = 0
         nz = s% nz
         species = s% species
         do k=1,nz ! fixup small errors -- complain if find big ones
            if (.not. all(xa(1:species,k) >= s% min_xa_hard_limit)) then
               write(*,2) 'burn_solout bad negX', k, minval(xa(1:species,k))
               ierr = -1
               return
            end if
            xa(1:species,k) = max(0d0, min(1d0, xa(1:species,k)))
            xsum = sum(xa(1:species,k))
            if (.not. (abs(xsum - 1d0) <= s% sum_xa_tolerance)) then
               write(*,2) 'burn_solout bad xsum-1', k, xsum - 1d0
               ierr = -1
               return
            end if
            xa(1:species,k) = xa(1:species,k)/xsum
         end do
      end subroutine fixup_xa




      end module solve_coupled_burn_mix




