! ***********************************************************************
!
!   Copyright (C) 2010  Bill Paxton
!
!   MESA is free software; you can use it and/or modify
!   it under the combined terms and restrictions of the MESA MANIFESTO
!   and the GNU Library General Public License as published
!   by the Free Software Foundation; either version 2 of the License,
!   or (at your option) any later version.
!
!   You should have received a copy of the MESA MANIFESTO along with
!   this software; if not, it is available at the mesa website:
!   http://mesa.sourceforge.net/
!
!   MESA is distributed in the hope that it will be useful,
!   but WITHOUT ANY WARRANTY; without even the implied warranty of
!   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
!   See the GNU Library General Public License for more details.
!
!   You should have received a copy of the GNU Library General Public License
!   along with this software; if not, write to the Free Software
!   Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
!
! ***********************************************************************

      module net_burn_const_P
      use const_def
      use chem_def
      use net_def
      use rates_def, only: num_rvs
      use mtx_def
      
      implicit none


      ! Slot numbers in the integer parameter array (ipar) that
      ! burn_1_zone_const_P fills in and the solver callbacks read back.
      integer, parameter ::
     >      i_burn_caller_id = 1, ! caller_id as given to burn_1_zone_const_P
     >      i_net_handle = 2,     ! handle passed to get_net_ptr
     >      i_screening_mode = 3, ! screening mode forwarded to eval_net
     >      i_reuse_rates = 4,    ! 0 before the first eval_net call, 1 afterwards
     >      i_net_lwork = 5,      ! length of the eval_net work section inside rpar
     >      i_eos_handle = 6,     ! handle passed to eosPT_get
     >      i_sparse_format = 7,  ! format code given to dense_to_sparse_with_diag
     >      i_clip = 8            ! 1 if the clip argument was true, else 0

      ! number of ipar slots listed above
      integer, parameter :: burn_lipar = 8


      ! Slot numbers at the front of the real parameter array (rpar); the
      ! rate factors, Q values, net work space and rate tables are stored
      ! after these fixed slots.
      integer, parameter ::
     >      r_rho = 1,      ! density from the most recent eosPT_get call
     >      r_pressure = 2, ! total pressure used by burn_jacob
     >      r_init_rho = 3, ! first density found; negative means not yet set
     >      r_time_net = 4, ! accumulated eval_net time; negative disables timing
     >      r_time_eos = 5  ! accumulated eos time; negative disables timing

      ! number of fixed rpar slots listed above
      integer, parameter :: burn_lrpar = 5

      contains

      ! Integrate the composition and temperature of a single zone at
      ! constant pressure from time 0 to times(ntimes) by calling isolve.
      !
      ! The state vector has nvar = num_isos + 1 entries: the mass
      ! fractions followed by ln(T).  The pressure is taken from
      ! log10Ps_f(1,1) and is not interpolated in time.  The ipar and rpar
      ! arrays handed to the callbacks (burn_derivs, burn_jacob, burn_sjac)
      ! are laid out with the i_* and r_* index constants of this module.
      !
      ! On return ierr is 0 for success.  It is nonzero when the net handle
      ! is bad, num_isos / num_reactions do not match the net, ntimes < 1,
      ! an allocation fails, or isolve reports failure (negative idid).
      ! An unknown which_decsol, or MKL pardiso not being available, still
      ! stops the program.
      subroutine burn_1_zone_const_P(
     >         net_handle, eos_handle, num_isos, num_reactions, 
     >         which_solver, starting_temp, starting_x, clip,
     >         ntimes, times, log10Ps_f,
     >         rate_factors, category_factors, reaction_Qs, reaction_neuQs, screening_mode, 
     >         h, max_step_size, max_steps, rtol, atol, itol, which_decsol, 
     >         caller_id, solout, iout, ending_x, ending_temp, ending_rho, initial_rho,
     >         nfcn, njac, nstep, naccpt, nrejct, time_doing_net, time_doing_eos, ierr)
         use num_def
         use num_lib 
         use mtx_lib
         use alert_lib
         use rates_def, only: rates_reaction_id_max
         use net_initialize, only: work_size
         
         integer, intent(in) :: net_handle, eos_handle
         integer, intent(in) :: num_isos
         integer, intent(in) :: num_reactions
         double precision, pointer, intent(in) :: starting_x(:) ! (num_isos)
         double precision, intent(in) :: starting_temp
         logical, intent(in) :: clip ! if true, set negative x's to zero during burn.
         
         integer, intent(in) :: which_solver ! as defined in num_def.f
         integer, intent(in) :: ntimes ! ending time is times(ntimes); starting time is 0
         double precision, pointer, intent(in) :: times(:) ! (ntimes) 
         double precision, pointer, intent(in) :: log10Ps_f(:,:) ! (4,ntimes) interpolant for log10P(time)

         double precision, intent(in) :: rate_factors(num_reactions)
         double precision, intent(in) :: category_factors(num_categories)
         double precision, pointer, intent(in) :: reaction_Qs(:) ! (rates_reaction_id_max)
         double precision, pointer, intent(in) :: reaction_neuQs(:) ! (rates_reaction_id_max)
         integer, intent(in) :: screening_mode
         
         ! args to control the solver -- see num/public/num_isolve.dek
         double precision, intent(inout) :: h 
         double precision, intent(in) :: max_step_size ! maximal step size.
         integer, intent(in) :: max_steps ! maximal number of allowed steps.
         ! absolute and relative error tolerances
         double precision, intent(inout) :: rtol(*) ! relative error tolerance(s)
         double precision, intent(inout) :: atol(*) ! absolute error tolerance(s)
         integer, intent(in) :: itol ! switch for rtol and atol
         integer, intent(in) :: which_decsol ! from mtx_def
         integer, intent(in) :: caller_id
         interface ! subroutine called after each successful step
            include "num_solout.dek"
         end interface
         integer, intent(in)  :: iout
         double precision, intent(out) :: ending_x(num_isos), ending_temp, ending_rho, initial_rho
         integer, intent(out) :: nfcn    ! number of function evaluations
         integer, intent(out) :: njac    ! number of jacobian evaluations
         integer, intent(out) :: nstep   ! number of computed steps
         integer, intent(out) :: naccpt  ! number of accepted steps
         integer, intent(out) :: nrejct  ! number of rejected steps
         double precision, intent(inout) :: time_doing_net
            ! if < 0, then ignore
            ! else on return has input value plus time spent doing eval_net
         double precision, intent(inout) :: time_doing_eos
            ! if < 0, then ignore
            ! else on return has input value plus time spent doing eos
         integer, intent(out) :: ierr
         
         type (Net_General_Info), pointer :: g
         integer :: ijac, nzmax, isparse, mljac, mujac, imas, mlmas, mumas, lrd, lid,
     >         lout, liwork, lwork, i, lrpar, lipar, idid, net_lwork, nvar
         integer, pointer :: ipar(:), iwork(:), ipar_decsol(:)
         double precision, pointer :: rpar(:), work(:), v(:), rpar_decsol(:)
         double precision :: t, tend
         
         include 'formats.dek'
         
         ! give every intent(out) argument a defined value up front so that
         ! early error returns never hand back garbage
         ending_x = 0
         ending_temp = 0
         ending_rho = 0
         initial_rho = 0
         nfcn = 0
         njac = 0
         nstep = 0
         naccpt = 0
         nrejct = 0
         ierr = 0
         
         nvar = num_isos + 1 ! mass fractions plus ln(T)

         call get_net_ptr(net_handle, g, ierr)
         if (ierr /= 0) then
            return
         end if
         
         if (g% num_isos /= num_isos) then
            write(*,*) 'invalid num_isos', num_isos
            ierr = -1
            return
         end if
         
         if (g% num_reactions /= num_reactions) then
            write(*,*) 'invalid num_reactions', num_reactions
            ierr = -1
            return
         end if
         
         if (ntimes < 1) then
            write(*,*) 'invalid ntimes', ntimes
            ierr = -1
            return
         end if
         
         ! the linear systems have nvar unknowns, so every solver is sized
         ! with nvar (not num_isos)
         if (which_decsol == mkl_pardiso) then
            if (.not. okay_to_use_mkl_pardiso()) then
               write(*,*) 'MKL pardiso not loaded'
               write(*,*) 'please set use_MKL_pardiso = .false. in your inlist'
               write(*,*) 'or edit makefile_header and rebuild'
               stop 1
            end if
            lrd = 0; lid = 0
            nzmax = nvar**2 ! max number of non-zero entries
            isparse = mkl_pardiso_compressed_format
         else if (which_decsol == klu) then
            nzmax = nvar**2 ! max number of non-zero entries
            isparse = klu_compressed_format
            call klu_work_sizes(nvar, nzmax, lrd, lid)
         else if (which_decsol == lapack) then
            nzmax = 0
            isparse = 0
            call lapack_work_sizes(nvar, lrd, lid)
         else
            write(*,1) 'net 1 zone burn const P: unknown value for which_decsol', which_decsol
            stop 1
         end if
         
         ijac = 1
         mljac = nvar ! square matrix
         mujac = nvar

         imas = 0
         mlmas = 0
         mumas = 0        
         
         lout = 0
         
         net_lwork = work_size(g, num_isos, num_reactions)
         
         call isolve_work_sizes(nvar, nzmax, imas, mljac, mujac, mlmas, mumas, liwork, lwork)

         ! rpar layout: fixed slots, rate factors, category factors, Qs,
         ! neutrino Qs, eval_net work space, then three (num_rvs,num_reactions)
         ! tables; burn_jacob unpacks it in the same order
         lipar = burn_lipar
         lrpar = burn_lrpar + num_reactions + num_categories + 
     >         2*rates_reaction_id_max + net_lwork + 3*num_rvs*num_reactions
         
         allocate(v(nvar), iwork(liwork), work(lwork), rpar(lrpar), ipar(lipar), 
     >         ipar_decsol(lid), rpar_decsol(lrd), stat=ierr)
         if (ierr /= 0) then
            write(*, *) 'allocate ierr', ierr
            return
         end if
         
         ! start from a fully defined parameter array
         rpar = 0
         ipar = 0
         
         i = burn_lrpar
         rpar(i+1:i+num_reactions) = rate_factors(1:num_reactions)
         i = i+num_reactions
         rpar(i+1:i+num_categories) = category_factors(1:num_categories)
         i = i+num_categories
         rpar(i+1:i+rates_reaction_id_max) = reaction_Qs(1:rates_reaction_id_max)
         i = i+rates_reaction_id_max
         rpar(i+1:i+rates_reaction_id_max) = reaction_neuQs(1:rates_reaction_id_max)
         i = i+rates_reaction_id_max
            
         ipar(i_burn_caller_id) = caller_id
         ipar(i_net_handle) = net_handle
         ipar(i_eos_handle) = eos_handle
         ipar(i_screening_mode) = screening_mode
         ipar(i_net_lwork) = net_lwork
         ipar(i_sparse_format) = isparse
         if (clip) then
            ipar(i_clip) = 1
         else
            ipar(i_clip) = 0
         end if
         ipar(i_reuse_rates) = 0

         iwork = 0
         work = 0
         
         t = 0
         tend = times(ntimes)
         
         rpar(r_pressure) = 10**log10Ps_f(1,1) ! no interpolation yet
         rpar(r_rho) = 0
         rpar(r_init_rho) = -1 ! negative marks "not yet set" for burn_jacob
         ! both timers must be seeded from the caller; burn_jacob tests
         ! their sign to decide whether to time anything
         rpar(r_time_net) = time_doing_net
         rpar(r_time_eos) = time_doing_eos
         
         v(1:num_isos) = starting_x(1:num_isos)
         v(nvar) = log(starting_temp)
         
         idid = 0
                     
         if (which_decsol == mkl_pardiso) then
            call do_isolve(null_decsol, mkl_pardiso_decsols)
         else if (which_decsol == klu) then
            call do_isolve(null_decsol, klu_decsols)
         else if (which_decsol == lapack) then
            call do_isolve(lapack_decsol, null_decsols)
         else
            write(*,*) 'unknown value for which_decsol', which_decsol
            stop 1
         end if
         
         ! isolve reports failure through a negative idid
         if (idid < 0) ierr = -1

         nfcn = iwork(14)
         njac = iwork(15)
         nstep = iwork(16)
         naccpt = iwork(17)
         nrejct = iwork(18)            
         time_doing_net = rpar(r_time_net)
         time_doing_eos = rpar(r_time_eos)

         ending_x(1:num_isos) = v(1:num_isos)
         ending_temp = exp(v(nvar))
         ending_rho = rpar(r_rho)
         initial_rho = rpar(r_init_rho)
         
         if (ierr /= 0) then
            ! report and hand the error back to the caller through ierr
            write(*, *) 'alert message: ' // trim(alert_message)
            write(*, '(a30,i10)') 'idid', idid
            write(*, '(a30,i10)') 'nfcn', nfcn
            write(*, '(a30,i10)') 'njac', njac
            write(*, '(a30,i10)') 'nstep', nstep
            write(*, '(a30,i10)') 'naccpt', naccpt
            write(*, '(a30,i10)') 'nrejct', nrejct
         end if
         
         deallocate(v, iwork, work, rpar, ipar, ipar_decsol, rpar_decsol)
      
         
         contains
         
         
         ! Call isolve with the given dense (decsol) and sparse (decsols)
         ! linear solvers; all other arguments come from the host routine.
         subroutine do_isolve(decsol, decsols)
            interface
               include "mtx_decsol.dek"
               include "mtx_decsols.dek"
            end interface
            call isolve(
     >         which_solver, nvar, burn_derivs, t, v, tend,  
     >         h, max_step_size, max_steps, 
     >         rtol, atol, itol, 
     >         burn_jacob, ijac, burn_sjac, nzmax, isparse, mljac, mujac, 
     >         null_mas, imas, mlmas, mumas, 
     >         solout, iout, 
     >         decsol, decsols, lrd, rpar_decsol, lid, ipar_decsol, 
     >         work, lwork, iwork, liwork, 
     >         lrpar, rpar, lipar, ipar, 
     >         lout, idid)
         end subroutine do_isolve
         
      end subroutine burn_1_zone_const_P


      ! Derivative-only callback: returns f = dv/dt by delegating to
      ! burn_jacob with a zero leading dimension, which makes burn_jacob
      ! skip filling the jacobian.
      subroutine burn_derivs(nvar, t, v, f, lrpar, rpar, lipar, ipar, ierr)
         integer, intent(in) :: nvar, lrpar, lipar
         double precision, intent(in) :: t
         double precision, intent(inout) :: v(nvar)
         double precision, intent(out) :: f(nvar) ! dvdt
         double precision, intent(inout), target :: rpar(lrpar)
         integer, intent(inout), target :: ipar(lipar)
         integer, intent(out) :: ierr
         ! zero-sized stand-in for the jacobian argument
         integer, parameter :: no_jac_ld = 0
         double precision :: empty_jac(no_jac_ld,nvar)
         ierr = 0
         call burn_jacob(
     >         nvar, t, v, f, empty_jac, no_jac_ld,
     >         lrpar, rpar, lipar, ipar, ierr)
      end subroutine burn_derivs


      ! Evaluate the right-hand side, and optionally the dense jacobian, of
      ! the constant-pressure burn equations.
      !
      ! v(1:nvar-1) are mass fractions and v(nvar) is ln(T).  The mass
      ! fractions in v are clamped in place to [1d-30, 1].  The total
      ! pressure is read from rpar(r_pressure); the gas pressure passed to
      ! eosPT_get is that value minus the radiation pressure at T.
      !
      ! On return f(1:nvar-1) = dxdt from eval_net and
      ! f(nvar) = eps_nuc/(Cp*T).  dfdv is filled only when ld_dfdv > 0.
      ! rpar(r_rho) holds the density of the latest successful eos call and
      ! rpar(r_init_rho) the first one.  ierr is nonzero when the net
      ! handle is bad, the rpar layout does not add up to lrpar, the gas
      ! pressure is not positive, or the eos / net evaluation fails.
      subroutine burn_jacob(nvar, time, v, f, dfdv, ld_dfdv, lrpar, rpar, lipar, ipar, ierr)
         use chem_lib, only: composition_info
         use net_eval, only: eval_net
         use eos_def
         use eos_lib, only: Radiation_Pressure, eosPT_get, eos_theta_e
         use screen_def, only: classic_screening
         use rates_def, only: rates_reaction_id_max

         integer, intent(in) :: nvar, ld_dfdv, lrpar, lipar
         double precision, intent(in) :: time
         double precision, intent(inout) :: v(nvar)
         double precision, intent(out) :: f(nvar), dfdv(ld_dfdv, nvar)
         double precision, intent(inout), target :: rpar(lrpar)
         integer, intent(inout), target :: ipar(lipar)
         integer, intent(out) :: ierr
         
         integer :: net_handle, num_reactions, eos_handle
         double precision :: 
     >         abar, zbar, z2bar, ye, approx_abar, approx_zbar, sumx,
     >         T, logT, rho, logRho, pressure, Pgas, Prad, lgPgas,
     >         eta, dlnT_dt, x(nvar-1), dabar_dx(nvar-1), dzbar_dx(nvar-1)
         double precision, pointer :: category_factors(:)
         double precision :: eps_neu_total, eps_nuc
         double precision :: d_eps_nuc_dT
         double precision :: d_eps_nuc_dRho
         double precision :: d_eps_nuc_dx(nvar-1) 
         double precision :: dxdt(nvar-1)
         double precision :: d_dxdt_dRho(nvar-1)
         double precision :: d_dxdt_dT(nvar-1)
         double precision :: d_dxdt_dx(nvar-1, nvar-1)
         double precision, pointer :: reaction_eps_nuc(:,:) ! (num_rvs, num_reactions)
         double precision, target :: eps_nuc_categories(num_rvs, num_categories)
         double precision, pointer :: rate_screened(:,:) ! (num_rvs, num_reactions)
         double precision, pointer :: rate_raw(:,:) ! (num_rvs, num_reactions)
         double precision, pointer :: rate_factors(:) ! (num_reactions)
         double precision, pointer :: reaction_Qs(:) ! (rates_reaction_id_max)
         double precision, pointer :: reaction_neuQs(:) ! (rates_reaction_id_max)
         integer :: screening_mode, lwork, i, num_isos, time0, time1, clock_rate
         double precision, pointer :: work(:) ! (lwork)

         double precision :: xh, Y, Cp, theta_e, d_theta_e_deta
         double precision :: dlnRho_dlnPgas_const_T, dlnRho_dlnT_const_Pgas
         double precision :: dlnRho_dlnT_const_P, d_epsnuc_dlnT_const_P, d_Cp_dlnT
         double precision :: res(num_eos_basic_results)
         double precision :: d_dlnRho_const_T(num_eos_basic_results) 
         double precision :: d_dlnT_const_Rho(num_eos_basic_results) 
         integer, pointer :: net_iso(:), chem_id(:)

         type (Net_General_Info), pointer :: g
         
         include 'formats.dek'
         
         num_isos = nvar-1
         
         ierr = 0
         f = 0
         dfdv = 0
         
         eos_handle = ipar(i_eos_handle)
         
         net_handle = ipar(i_net_handle)
         call get_net_ptr(net_handle, g, ierr)
         if (ierr /= 0) then
            write(*,*) 'invalid handle for eval_net -- did you call alloc_net_handle?'
            return
         end if
         
         v(1:num_isos) = max(1d-30, min(1d0, v(1:num_isos))) ! positive definite mass fractions
         x(1:num_isos) = v(1:num_isos)
         
         num_reactions = g% num_reactions

         ! unpack rpar in the order burn_1_zone_const_P packed it
         i = burn_lrpar
         
         rate_factors => rpar(i+1:i+num_reactions)
         i = i+num_reactions
         category_factors => rpar(i+1:i+num_categories)
         i = i+num_categories
         reaction_Qs => rpar(i+1:i+rates_reaction_id_max)
         i = i+rates_reaction_id_max
         reaction_neuQs => rpar(i+1:i+rates_reaction_id_max)
         i = i+rates_reaction_id_max

         lwork = ipar(i_net_lwork)
         work => rpar(i+1:i+lwork)
         
         i = i+lwork
         call set_Aptr(reaction_eps_nuc, rpar(i+1:i+num_rvs*num_reactions), num_rvs, num_reactions)
         i = i+num_rvs*num_reactions
         call set_Aptr(rate_raw, rpar(i+1:i+num_rvs*num_reactions), num_rvs, num_reactions)
         i = i+num_rvs*num_reactions
         call set_Aptr(rate_screened, rpar(i+1:i+num_rvs*num_reactions), num_rvs, num_reactions)
         i = i+num_rvs*num_reactions
         ! the sections must account for every element of rpar
         if (i /= lrpar) then
            write(*,2) 'burn_jacob i', i
            write(*,2) 'lrpar', lrpar
            ierr = -1
            return
         end if
         
         if (ipar(i_clip) /= 0) then
            x(1:num_isos) = max(0d0, min(1d0, x(1:num_isos)))
         end if

         call composition_info(
     >         num_isos, g% chem_id, x, xh, Y, abar, zbar, z2bar, ye, approx_abar, approx_zbar, 
     >         sumx, dabar_dx, dzbar_dx)
     
         logT = v(nvar)/ln10 ! v(nvar) is ln(T)
         T = 10**logT
         pressure = rpar(r_pressure)         
         Prad = Radiation_Pressure(T)
         Pgas = pressure - Prad
         ! log10 below needs a positive gas pressure; the negated test also
         ! rejects a NaN
         if (.not. (Pgas > 0)) then
            write(*,*) 'burn_jacob: gas pressure is not positive'
            write(*,1) 'pressure', pressure
            write(*,1) 'Prad', Prad
            write(*,1) 'Pgas', Pgas
            write(*,1) 'T', T
            write(*,1) 'logT', logT
            ierr = -1
            return
         end if
         lgPgas = log10(Pgas)

         chem_id => g% chem_id
         net_iso => g% net_iso
                  
         if (rpar(r_time_eos) >= 0) call system_clock(time0,clock_rate)
         
         call eosPT_get(
     >         eos_handle, 1 - (xh + Y), xh, abar, zbar, 
     >         num_isos, chem_id, net_iso, x,
     >         Pgas, lgPgas, T, logT, 
     >         Rho, logRho, dlnRho_dlnPgas_const_T, dlnRho_dlnT_const_Pgas, 
     >         res, d_dlnRho_const_T, d_dlnT_const_Rho, ierr)
         Cp = res(i_Cp)
         ! check for failure before any eos output is stored or reused
         if (ierr /= 0 .or. Cp <= 0) then
            write(*,*) 'eosPT_get failed'
            write(*,1) 'xh', xh
            write(*,1) 'Y', Y
            write(*,1) 'Z', 1 - (xh + Y)
            write(*,1) 'abar', abar
            write(*,1) 'zbar', zbar
            write(*,1) 'pressure', pressure
            write(*,1) 'Prad', Prad
            write(*,1) 'Pgas', Pgas
            write(*,1) 'lgPgas', lgPgas
            write(*,1) 'T', T
            write(*,1) 'logT', logT
            write(*,1) 'Rho', Rho
            write(*,1) 'logRho', logRho
            write(*,1) 'Cp', Cp
            ierr = -1
            return
         end if

         eta = res(i_eta)
         theta_e = 0
         screening_mode = ipar(i_screening_mode)
         if (screening_mode == classic_screening)
     >         theta_e = eos_theta_e(eta, d_theta_e_deta)
         rpar(r_rho) = Rho
         ! a negative slot means no density has been recorded yet
         if (rpar(r_init_rho) < 0) rpar(r_init_rho) = Rho

         ! close the eos timer; when both timers are active the same clock
         ! reading also starts the net timer
         if (rpar(r_time_eos) >= 0) then
            call system_clock(time1,clock_rate)
            rpar(r_time_eos) = rpar(r_time_eos) + dble(time1 - time0) / clock_rate
            if (rpar(r_time_net) >= 0) time0 = time1
         else if (rpar(r_time_net) >= 0) then
            call system_clock(time0,clock_rate)
         end if

         call eval_net(
     >         g, num_isos, num_reactions, g% num_weaklib_rates,
     >         x, T, logT, rho, logRho, 
     >         abar, zbar, z2bar, ye, eta, rate_factors, category_factors,
     >         reaction_Qs, reaction_neuQs,
     >         eps_nuc, d_eps_nuc_dRho, d_eps_nuc_dT, d_eps_nuc_dx, 
     >         dxdt, d_dxdt_dRho, d_dxdt_dT, d_dxdt_dx, 
     >         screening_mode, theta_e, 
     >         rate_screened, rate_raw, (ipar(i_reuse_rates) /= 0),
     >         reaction_eps_nuc, eps_nuc_categories, eps_neu_total,
     >         lwork, work, ierr)

         if (rpar(r_time_net) >= 0) then
            call system_clock(time1,clock_rate)
            rpar(r_time_net) = rpar(r_time_net) + dble(time1 - time0) / clock_rate
         end if

         if (ierr /= 0) then
            write(*,*) 'eval_net failed'
            write(*,1) 'xh', xh
            write(*,1) 'Y', Y
            write(*,1) 'Z', 1 - (xh + Y)
            write(*,1) 'abar', abar
            write(*,1) 'zbar', zbar
            write(*,1) 'pressure', pressure
            write(*,1) 'Prad', Prad
            write(*,1) 'Pgas', Pgas
            write(*,1) 'lgPgas', lgPgas
            write(*,1) 'T', T
            write(*,1) 'logT', logT
            write(*,1) 'Rho', Rho
            write(*,1) 'logRho', logRho
            write(*,1) 'Cp', Cp
            ierr = -1
            return
         end if
         
         ! NOTE(review): T and Rho differ from call to call in this
         ! constant-pressure burn; confirm that eval_net may reuse the rates
         ! cached in rpar once this flag is set.
         ipar(i_reuse_rates) = 1 ! okay to reuse rates after 1st time
         
         f(1:num_isos) = dxdt
         dlnT_dt = eps_nuc/(Cp*T)
         f(nvar) = dlnT_dt
         
         if (ld_dfdv > 0) then

            ! derivatives with respect to lnT are taken at constant pressure,
            ! so each one picks up a density term through dlnRho/dlnT
            dlnRho_dlnT_const_P = -res(i_chiT)/res(i_chiRho)
            d_epsnuc_dlnT_const_P = d_eps_nuc_dT*T + d_eps_nuc_dRho*Rho*dlnRho_dlnT_const_P
            d_Cp_dlnT = d_dlnT_const_Rho(i_Cp) + d_dlnRho_const_T(i_Cp)*dlnRho_dlnT_const_P
            
            dfdv(1:num_isos,1:num_isos) = d_dxdt_dx

            dfdv(nvar,nvar) = d_epsnuc_dlnT_const_P/(Cp*T) - dlnT_dt*(1 + d_Cp_dlnT/Cp)
            
            ! d_dxdt_dlnT
            dfdv(1:num_isos,nvar) = 
     >         d_dxdt_dT(1:num_isos)*T + d_dxdt_dRho(1:num_isos)*Rho*dlnRho_dlnT_const_P
            
            ! d_dlnTdt_dx
            dfdv(nvar,1:num_isos) = d_eps_nuc_dx(1:num_isos)/(Cp*T)

         end if
         
         
         contains
         
   
         ! Point Aptr at dest, viewed as an (n1, n2) array.
         subroutine set_Aptr(Aptr, dest, n1, n2)
            double precision, pointer :: Aptr(:, :)
            double precision, target :: dest(n1, n2) ! reshape work section
            integer, intent(in) :: n1, n2
            Aptr => dest
         end subroutine set_Aptr
         
         
      end subroutine burn_jacob


      ! Sparse jacobian callback: builds the dense (n,n) jacobian with
      ! burn_jacob, zeroes entries smaller in magnitude than 1d-16, and
      ! converts the result to the sparse format named by
      ! ipar(i_sparse_format).  f receives the derivatives as a by-product.
      ! A nonzero ierr means the allocation, burn_jacob or the conversion
      ! failed.
      subroutine burn_sjac(n,time,y,f,nzmax,ia,ja,values,lrpar,rpar,lipar,ipar,ierr)  
         use mtx_lib, only: dense_to_sparse_with_diag
         integer, intent(in) :: n, nzmax, lrpar, lipar
         double precision, intent(in) :: time
         double precision, intent(inout) :: y(n)
         integer, intent(out) :: ia(n+1), ja(nzmax)
         double precision, intent(out) :: f(n), values(nzmax)
         double precision, intent(inout), target :: rpar(lrpar)
         integer, intent(inout), target :: ipar(lipar)
         integer, intent(out) :: ierr ! nonzero means terminate integration
         double precision, parameter :: tiny_entry = 1d-16
         double precision, pointer :: dense_jac(:,:) ! (n,n)
         integer :: nz
         include 'formats.dek'
         ierr = 0
         allocate(dense_jac(n,n),stat=ierr)
         if (ierr /= 0) return
         call burn_jacob(n,time,y,f,dense_jac,n,lrpar,rpar,lipar,ipar,ierr)
         if (ierr == 0) then
            ! remove entries with abs(value) < 1d-16
            where (dense_jac /= 0 .and. abs(dense_jac) < tiny_entry)
               dense_jac = 0
            end where
            call dense_to_sparse_with_diag(
     >         ipar(i_sparse_format),n,n,dense_jac,nzmax,nz,ia,ja,values,ierr)
         end if
         deallocate(dense_jac)
      end subroutine burn_sjac
      

      end module net_burn_const_P

