! ***********************************************************************
!
!   Copyright (C) 2011  Bill Paxton
!
!   MESA is free software; you can use it and/or modify
!   it under the combined terms and restrictions of the MESA MANIFESTO
!   and the GNU General Library Public License as published
!   by the Free Software Foundation; either version 2 of the License,
!   or (at your option) any later version.
!
!   You should have received a copy of the MESA MANIFESTO along with
!   this software; if not, it is available at the mesa website:
!   http://mesa.sourceforge.net/
!
!   MESA is distributed in the hope that it will be useful,
!   but WITHOUT ANY WARRANTY; without even the implied warranty of
!   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
!   See the GNU Library General Public License for more details.
!
!   You should have received a copy of the GNU Library General Public License
!   along with this software; if not, write to the Free Software
!   Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
!
! ***********************************************************************

      module solve_burn

      use star_private_def
      use alert_lib
      use const_def
      use num_def

      implicit none

      
      contains
      
      
      subroutine setup_interpolants_for_burn( &
            s, dt_total, num_times_max, &
            num_times, times, log10Ts_f, log10Rhos_f, etas_f, ierr)
         ! Allocate and fill the per-cell time tables of log10(T), log10(Rho) and eta
         ! that are later handed to the 1-zone burn.
         !
         ! s              star whose current lnT, lnd and eta supply the table values
         ! dt_total       stored as the (only) table time, times(1)
         ! num_times_max  currently ignored: the table always has a single time point
         ! num_times      returned number of table times (always 1 at present)
         ! times          allocated here with size num_times
         ! log10Ts_f, log10Rhos_f, etas_f
         !                allocated here as (4,num_times,nz); only the (1,1,k) values
         !                are filled, the remaining slots are reserved for interp_pm
         ! ierr           0 on success, otherwise the allocate or interp_pm error code
         !
         ! The caller owns times and the three tables and must deallocate them; they
         ! stay allocated even when an interp_pm failure is reported through ierr.
         use interp_1d_def, only: pm_work_size
         use interp_1d_lib, only: interp_pm
         type (star_info), pointer :: s
         real(dp), intent(in) :: dt_total
         integer, intent(in) :: num_times_max
         integer, intent(out) :: num_times
         real(dp), pointer, dimension(:) :: times
         real(dp), pointer, dimension(:,:,:) :: log10Ts_f, log10Rhos_f, etas_f
         integer, intent(out) :: ierr

         real(dp), dimension(:,:,:), pointer :: work
         integer, parameter :: nwork = pm_work_size
         integer :: k, nz, op_err
         
         include 'formats.dek'
         
         ierr = 0
         nz = s% nz
         
         ! multi-time tables are disabled; restoring the next line re-enables them
         !num_times = num_times_max
         num_times = 1
         
         allocate( &
            times(num_times), &
            log10Ts_f(4,num_times,nz), log10Rhos_f(4,num_times,nz), &
            etas_f(4,num_times,nz), work(num_times,nwork,nz), stat=ierr)
         if (ierr /= 0) return
         
         times(1) = dt_total
         ! whole-array assignments (the obsolescent forall is not needed here)
         log10Ts_f(1,1,1:nz) = s% lnT(1:nz)/ln10
         log10Rhos_f(1,1,1:nz) = s% lnd(1:nz)/ln10
         etas_f(1,1,1:nz) = s% eta(1:nz)
         
         if (num_times > 1) then
            ! Build piecewise-monotonic interpolants for all three tables. etas_f has
            ! the same layout as the T and Rho tables and is passed to the net with
            ! them, so it needs its coefficients filled as well.
!$OMP PARALLEL DO PRIVATE(k,op_err)
            do k=1,nz
               call interp_pm(times, num_times, log10Ts_f(:,:,k), nwork, work(:,:,k), op_err)
               if (op_err /= 0) ierr = op_err
               call interp_pm(times, num_times, log10Rhos_f(:,:,k), nwork, work(:,:,k), op_err)
               if (op_err /= 0) ierr = op_err
               call interp_pm(times, num_times, etas_f(:,:,k), nwork, work(:,:,k), op_err)
               if (op_err /= 0) ierr = op_err
            end do
!$OMP END PARALLEL DO
         end if
         
         ! the scratch space is released on every path past the allocate
         deallocate(work)
         if (ierr /= 0) write(*,*) 'interp_pm failed for solve_burn'
      
      end subroutine setup_interpolants_for_burn
      

      integer function do_solve_burn( &
            s, t_start, t_end, dt_total, dxdt_source_term, need_final_eps_nuc, &
            num_times, times, log10Ts_f, log10Rhos_f, etas_f)
         ! Operator-split burn of every cell of s from t_start to t_end.
         ! Resolves the op_split decsol and solver option strings, then burns each
         ! cell independently (OpenMP loop over k) through do1_burn.
         ! Returns keep_going when every cell succeeds, retry when any cell fails
         ! (a diagnostic for one failing cell is printed), and terminate when an
         ! option string is not recognized. dt_total is not referenced here.
         ! With s% doing_timing, the loop's wall time (less whatever total_times
         ! already accounts for) is split between s% time_solve_burn_in_net and
         ! s% time_solve_burn_non_net in proportion to the per-cell timings.
         use star_utils, only: update_time, total_times
         use mtx_lib, only: decsol_option
         use num_lib, only: solver_option
         use chem_def, only: chem_isos
         
         type (star_info), pointer :: s
         real(dp), intent(in) :: t_start, t_end, dt_total
         real(dp), pointer, intent(in) :: dxdt_source_term(:,:) 
            ! (species,nz)  or null if no source term.
         logical, intent(in) :: need_final_eps_nuc
         integer, intent(in) :: num_times ! for interpolation of lnT and lnRho
         real(dp), pointer, dimension(:) :: times
         real(dp), pointer, dimension(:,:,:) :: log10Ts_f, log10Rhos_f, etas_f

         integer :: ierr, op_err, k, j, max_num_steps_used, nz, species, num_steps
         integer :: time0, time1, clock_rate, k_bad
         logical :: xsum_bad, converged, xsum_bad_at_k, converged_at_k
         real(dp) :: total_all_before, total_all_after, fraction_non_net, total_burn_time
         real(dp) :: time_solve_burn_in_net, time_solve_burn_non_net

         include 'formats.dek'
         
         ierr = 0

         if (s% doing_timing) then
            total_all_before = total_times(s)
            call system_clock(time0,clock_rate)
         else
            total_all_before = 0
         end if
         time_solve_burn_in_net = 0
         time_solve_burn_non_net = 0

         nz = s% nz
         species = s% species

         do_solve_burn = retry
         if (species >= s% op_split_decsol_switch) then
            s% burn_decsol_option = decsol_option(s% op_split_large_mtx_decsol, ierr)
         else
            s% burn_decsol_option = decsol_option(s% op_split_small_mtx_decsol, ierr)
         end if
         if (ierr /= 0) then
            write(*,*) 'bad string for op_split_burn_decsol'
            do_solve_burn = terminate
            return
         end if
         if (len_trim(s% op_split_burn_solver1) == 0) then
            s% burn_solver_option1 = -1 ! no first solver: do1_burn goes straight to solver 2
         else
            s% burn_solver_option1 = solver_option(s% op_split_burn_solver1, ierr)
            if (ierr /= 0) then
               write(*,*) 'bad string for op_split_burn_solver1 ' // trim(s% op_split_burn_solver1)
               do_solve_burn = terminate
               return
            end if
         end if
         s% burn_solver_option2 = solver_option(s% op_split_burn_solver2, ierr)
         if (ierr /= 0) then
            write(*,*) 'bad string for op_split_burn_solver2 ' // trim(s% op_split_burn_solver2)
            do_solve_burn = terminate
            return
         end if

         max_num_steps_used = 0
         
         ! max_num_steps_used is a max-reduction so threads do not race on it, and the
         ! failure record is written in a critical section so that k_bad,
         ! xsum_bad_at_k and converged_at_k always describe the same failing cell.
!$OMP PARALLEL DO PRIVATE(k,op_err,xsum_bad,num_steps,converged) REDUCTION(max:max_num_steps_used) SCHEDULE(DYNAMIC,2)
         do k = 1, nz
            if (ierr /= 0) cycle ! some cell already failed; skip the remaining work
            op_err = 0
            call do1_burn( &
               s, k, species, num_times, times, t_start, t_end, dxdt_source_term, need_final_eps_nuc, &
               log10Ts_f, log10Rhos_f, etas_f, time_solve_burn_non_net, time_solve_burn_in_net, &
               xsum_bad, converged, num_steps, op_err)
            if (op_err /= 0) then
!$OMP CRITICAL (solve_burn_record_failure)
               ierr = op_err
               k_bad = k
               xsum_bad_at_k = xsum_bad
               converged_at_k = converged
!$OMP END CRITICAL (solve_burn_record_failure)
            else
               max_num_steps_used = max(num_steps, max_num_steps_used)
            end if
         end do
!$OMP END PARALLEL DO

         if (s% doing_timing) then ! substract time_solve_burn_in_net
            call system_clock(time1,clock_rate)
            total_all_after = total_times(s)
            total_burn_time = dble(time1-time0)/clock_rate - (total_all_after - total_all_before)
            if (time_solve_burn_non_net + time_solve_burn_in_net <= 0) then
               ! No per-cell time was recorded (do1_burn only records successful
               ! cells, and the clock may not have advanced): charge it all to the net
               ! rather than aborting the run.
               fraction_non_net = 0
            else
               fraction_non_net = &
                  time_solve_burn_non_net/(time_solve_burn_non_net + time_solve_burn_in_net)
            end if
            s% time_solve_burn_non_net = &
               s% time_solve_burn_non_net + fraction_non_net*total_burn_time
            s% time_solve_burn_in_net = &
               s% time_solve_burn_in_net + (1-fraction_non_net)*total_burn_time               
         end if
         
         if (ierr /= 0) then
            do_solve_burn = retry               
            k = k_bad
!$OMP CRITICAL (solve_burn_error)
            if (.not. converged_at_k) then
               write(*,'(a50,i5,99f20.12)') 'retry: burn failed to converge for cell', k
            else if (.not. xsum_bad_at_k) then
               j = minloc(s% xa(:,k), dim=1)
               write(*,'(a50,i5,99f20.12)') 'retry: burn failed: negative abundance for ' // &
                  trim(chem_isos% name(s% chem_id(j))), k, s% xa(j,k)
            else
               write(*,'(a50,i5,99f20.12)') 'retry: burn failed: bad abundances 1-xsum',&
                  k, 1d0-sum(s% xa(:,k)), 1d0-sum(s% xa_pre_hydro(:,k))
            end if
            !do j=1,num_times
            !   write(*,2) 'lgT, lgRho', j, log10Ts_f(1,j,k), log10Rhos_f(1,j,k)
            !end do
!$OMP END CRITICAL (solve_burn_error)
         else
            do_solve_burn = keep_going
         end if
         
      end function do_solve_burn
      
      
      subroutine do1_burn( &
            s, k, species, num_times, times, t_start, t_end, dxdt_source_term, need_final_eps_nuc, &
            log10Ts_f, log10Rhos_f, etas_f, time_solve_burn_non_net, time_solve_burn_in_net, &
            xsum_bad, converged, num_steps, ierr)
         ! Burn one cell k: try s% burn_solver_option1 with a single step (skipped when
         ! that option is <= 0), and fall back to s% burn_solver_option2 with up to
         ! s% op_split_burn_maxsteps steps if the first attempt was skipped, failed or
         ! did not converge. On success with need_final_eps_nuc, eps_nuc is recomputed
         ! for the new composition. With s% doing_timing, the cell's wall time is
         ! added to the shared accumulators, split into in-net and non-net parts.
         type (star_info), pointer :: s
         integer, intent(in) :: k, species, num_times
         real(dp), intent(in) :: t_start, t_end
         real(dp), pointer, intent(in) :: dxdt_source_term(:,:) 
         logical, intent(in) :: need_final_eps_nuc
         real(dp), pointer, dimension(:) :: times
         real(dp), pointer, dimension(:,:,:) :: log10Ts_f, log10Rhos_f, etas_f
         real(dp), intent(inout) :: time_solve_burn_non_net, time_solve_burn_in_net
         logical, intent(out) :: xsum_bad, converged
         integer, intent(out) :: num_steps, ierr

         integer :: clock_rate, time_burn0, time_burn1
         integer :: which_solver
         integer :: max_steps
         logical :: tried_solver1
         real(dp) :: total_burn_time, time_doing_net
         real(dp) :: xa_start(species) ! abundances of cell k before any attempt
         
         include 'formats.dek'
      
         ierr = 0
         if (s% doing_timing) then
            time_doing_net = 0
            call system_clock(time_burn0,clock_rate)
         else
            time_doing_net = -1 ! sentinel for "not timing" -- TODO confirm meaning in the net
         end if

         ! do1_burn_for_cell writes its output into s% xa(:,k) before deciding to reject
         ! it, so keep the starting state to give the fallback solver a clean start.
         xa_start(1:species) = s% xa(1:species,k)
         
         max_steps = 1
         converged = .true.
         xsum_bad = .false.
         which_solver = s% burn_solver_option1
         tried_solver1 = which_solver > 0
         if (tried_solver1) then
            call do1_burn_for_cell( &
               s, k, t_start, t_end, dxdt_source_term, species, which_solver, max_steps, &
               num_times, times, log10Ts_f(:,:,k), log10Rhos_f(:,:,k), etas_f(:,:,k), &
               time_doing_net, converged, xsum_bad, num_steps, ierr)  
         end if
         
         if (which_solver <= 0 .or. ierr /= 0 .or. .not. converged) then ! try 2nd solver
            ! undo any rejected output left behind by the first attempt
            if (tried_solver1) s% xa(1:species,k) = xa_start(1:species)
            which_solver = s% burn_solver_option2
            max_steps = s% op_split_burn_maxsteps
            converged = .true.
            call do1_burn_for_cell( &
               s, k, t_start, t_end, dxdt_source_term, species, which_solver, max_steps, &
               num_times, times, log10Ts_f(:,:,k), log10Rhos_f(:,:,k), etas_f(:,:,k), &
               time_doing_net, converged, xsum_bad, num_steps, ierr)  
         end if
         
         if (need_final_eps_nuc .and. ierr == 0) then
            call calc_eps_nuc_info(s, k, species, ierr)
         end if
         
         if (s% doing_timing .and. ierr == 0) then
            ! read the clock before entering the critical section so waiting on other
            ! threads is not counted as this cell's burn time
            call system_clock(time_burn1,clock_rate)
            total_burn_time = dble(time_burn1 - time_burn0)/clock_rate
!$OMP CRITICAL (crit_time_for_burn)
            time_solve_burn_non_net = time_solve_burn_non_net + (total_burn_time - time_doing_net)
            time_solve_burn_in_net = time_solve_burn_in_net + time_doing_net
!$OMP END CRITICAL (crit_time_for_burn)
         end if

      end subroutine do1_burn
      
      
      subroutine calc_eps_nuc_info(s, k, species, ierr)
         ! After a successful burn of cell k, pass the new abundances s% xa(:,k) to
         ! composition_info (output slots X, Y, abar, zbar, z2bar, ye, approx_abar,
         ! approx_zbar of cell k), then re-run do1_net for cell k to get eps_nuc for
         ! the new composition. The sum and d(abar)/dx, d(zbar)/dx outputs of
         ! composition_info are not kept.
         use chem_lib, only: composition_info
         use micro, only: do1_net
         type (star_info), pointer :: s         
         integer, intent(in) :: k, species
         integer, intent(out) :: ierr  

         ! T and Rho are not changed by the burn, so previously computed rates are reused
         logical, parameter :: reuse_given_rates = .true.
         real(dp) :: sum_x
         real(dp) :: d_abar_dx(species), d_zbar_dx(species)

         call composition_info( &
            species, s% chem_id, s% xa(1:species,k), &
            s% X(k), s% Y(k), s% abar(k), s% zbar(k), s% z2bar(k), s% ye(k), &
            s% approx_abar(k), s% approx_zbar(k), &
            sum_x, d_abar_dx, d_zbar_dx)

         call do1_net(s, k, s% species, s% num_reactions, s% net_lwork, &
            reuse_given_rates, ierr)
      end subroutine calc_eps_nuc_info
            
      
      subroutine do1_burn_for_cell( &
            s, k, t_start, t_end, dxdt_source_term, species, which_solver, max_steps, &
            num_times, times, log10Ts_f, log10Rhos_f, etas_f, &
            time_doing_net, converged, xsum_bad, num_steps, ierr)  
         ! Burn cell k from t_start to t_end with solver which_solver (at most
         ! max_steps steps) by calling net_1_zone_burn.
         ! On success (ierr == 0), s% xa(:,k) holds the result, clipped to [0,1] and
         ! renormalized to sum to 1. ierr /= 0 is returned when:
         !   - the screening mode string is unknown;
         !   - the input abundances of cell k fail the sum/min-abundance checks;
         !   - net_1_zone_burn fails (converged = .false.);
         !   - an output abundance is below s% min_xa_hard_limit; or
         !   - the clipped output sum differs from 1 by more than s% sum_xa_tolerance
         !     (xsum_bad = .true.).
         ! In the last two cases s% xa(:,k) has already been overwritten with the
         ! rejected output, which the caller's diagnostics read.
         use net_lib, only: net_1_zone_burn
         use micro, only: get_screening_mode
         use rates_def, only: std_reaction_neuQs, std_reaction_Qs
         use mtx_def
         use num_def
         use num_lib
         use chem_def
         
         type (star_info), pointer :: s         
         integer, intent(in) :: k, num_times, species, which_solver, max_steps
         real(dp), intent(in) :: t_start, t_end 
         real(dp), pointer, intent(in) :: dxdt_source_term(:,:) 
            ! (species,nz)  or null if no source term.
         real(dp), pointer, dimension(:) :: times
         real(dp), dimension(:,:) :: log10Ts_f, log10Rhos_f, etas_f
         real(dp), intent(inout) :: time_doing_net 
         logical,intent(out) :: converged, xsum_bad
         integer, intent(out) :: num_steps, ierr  
         
         integer :: j
         real(dp), dimension(species), target :: new_source ! (species)
         real(dp), dimension(:), pointer :: new_source_term ! (species)
         integer :: screening_mode ! see screen_def
         real(dp) :: xsum, h 
         real(dp) :: max_step_size ! maximal step size.
         real(dp) :: rtol(1) ! relative error tolerance(s)
         real(dp) :: atol(1) ! absolute error tolerance(s)
         integer :: itol ! switch for rtol and atol
         integer :: which_decsol ! from mtx_def
         integer :: caller_id ! only provided for use by caller's solout routine
         interface ! subroutine called after each successful step
            include "num_solout.dek"
         end interface
         integer  :: iout ! switch for calling the subroutine solout:
            ! iout=0: subroutine is never called
            ! iout=1: subroutine is available for output.
            ! iout=2: want dense output available from solout.         
         integer :: nfcn    ! number of function evaluations
         integer :: njac    ! number of jacobian evaluations
         integer :: nstep   ! number of computed steps
         integer :: naccpt  ! number of accepted steps
         integer :: nrejct  ! number of rejected steps
         
         real(dp) :: sum_xa, x_in(species), x_out(species), xsum_init, x_min
         logical :: trace, have_source
         
         include 'formats.dek'
                
         ierr = 0
         ! give every intent(out) result a defined value before any early return;
         ! the caller inspects converged and num_steps even after a failure
         converged = .true.
         xsum_bad = .false.
         num_steps = 0
         trace = s% op_split_burn_trace .and. k == s% nz

         ! validate k before it indexes dxdt_source_term, s% xa or s% theta_e
         if (k <= 0 .or. k > s% nz) then
            write(*,2) 'bad k for do1_burn_for_cell', k
            stop 'solve_burn'
         end if

         screening_mode = get_screening_mode(s,ierr)         
         if (ierr /= 0) then
            write(*,*) 'unknown string for screening_mode: ' // trim(s% screening_mode)
            return
         end if
         
         have_source = associated(dxdt_source_term)
         if (.not. have_source) then
            new_source_term => null()
         else
            new_source_term => new_source
            new_source_term(1:species) = dxdt_source_term(1:species,k)
         end if
               
         which_decsol = s% burn_decsol_option
         rtol(:) = s% op_split_burn_rtol
         atol(:) = s% op_split_burn_atol
         
         if (trace) write(*,1) 'start solve_burn', (t_end - t_start)/secyer, &
            s% xa(s% net_iso(ih1),s% nz)

         h = t_end - t_start ! try to do it in 1 step
         max_step_size = 0
         itol = 0
         
         ! burn_solout decodes this back into star id and cell index
         ! NOTE(review): the encoding assumes s% id < 100
         caller_id = s% id + 100*k
         !iout = 0 ! don't call solout
         !iout = 1 ! call solout so can edit intermediate results
         iout = 2 ! allow for interpolated results with solout
         
         xsum_init = 0
         x_min = 1d99
         do j=1,species
            x_in(j) = s% xa(j,k)
            if (x_in(j) < x_min) x_min = x_in(j)
            xsum_init = xsum_init + x_in(j)
         end do
         
         ! reject an out-of-tolerance starting composition without burning
         if (abs(xsum_init-1) > s% sum_xa_tolerance .or. &
               x_min < s% min_xa_hard_limit) then
            !write(*,2) 'bad input for solve_burn', k, abs(xsum_init-1), x_min
            ierr = -1
            return
         end if

         call net_1_zone_burn( &
            s% eos_handle, which_solver, species, s% num_reactions, &
            t_start, t_end, x_in, s% op_split_burn_clip, &
            num_times, times, log10Ts_f, log10Rhos_f, etas_f, &
            new_source_term, s% rate_factors, s% category_factors(:), &
            std_reaction_Qs, std_reaction_neuQs, &
            screening_mode, s% theta_e(k), & 
            h, max_step_size, max_steps, rtol, atol, itol, & 
            which_decsol, caller_id, burn_solout, iout, x_out, & 
            nfcn, njac, nstep, naccpt, nrejct, time_doing_net, ierr)
            
         num_steps = nstep
         
         if (ierr /= 0) then
            converged = .false.
            if (trace) write(*,2) 'failed in 1 zone burn', k
            
            if (.false.) then ! debug
               write(*,2) 'failed in 1 zone burn', k
               write(*,2) 'which_solver', which_solver
               write(*,2) 's% num_reactions', s% num_reactions
               write(*,2) 'num_times', num_times
               write(*,2) 'max_steps', max_steps
               write(*,2) 'nstep', nstep
               write(*,2) 'which_decsol', which_decsol
               write(*,1) 'log10Ts_f(1,1)', log10Ts_f(1,1)
               write(*,1) 'log10Rhos_f(1,1)', log10Rhos_f(1,1)
               write(*,1) 'times(1)', times(1)
               write(*,1) 'xsum_init', xsum_init
               write(*,1) 'sum(xa)', sum(s% xa(1:species,k))
               write(*,2) 'species', species
               do j = 1, species
                  write(*,1) trim(chem_isos% name(s% chem_id(j))), s% xa(j,k)
               end do
               write(*,*)
               stop 'do1_burn_for_cell'
            end if
            
         else
            converged = .true.     
            x_min = 1d99       
            do j=1,species
               s% xa(j,k) = x_out(j)
               if (s% xa(j,k) < x_min) x_min = s% xa(j,k)
            end do
            if (x_min < s% min_xa_hard_limit) then
               ierr = -1
               if (trace) then
                  j = minloc(s% xa(:,k), dim=1)
                  write(*,2) 'burn: negative abundance: x ' // &
                     trim(chem_isos% name(s% chem_id(j))), k, s% xa(j,k)
               end if
               ! only the fallback solver's failure is reported as the step limiter
               if (which_solver /= s% burn_solver_option1) s% why_Tlim = Tlim_neg_X
            else
               s% xa(:,k) = max(0d0, min(1d0, s% xa(:,k)))
               xsum = sum(s% xa(:,k))
               if (abs(xsum - 1d0) <= s% sum_xa_tolerance) then
                  do j=1,species
                     s% xa(j,k) = s% xa(j,k)/xsum
                  end do
               else
                  if (trace) &
                     write(*,3) 'do1_burn_for_cell: bad abundances: 1-xsum', k, which_solver, &
                        1d0-xsum, s% sum_xa_tolerance
                  if (which_solver /= s% burn_solver_option1) s% why_Tlim = Tlim_bad_Xsum
                  xsum_bad = .true.
                  ierr = -1

                  if (.false.) then ! debug
                     write(*,2) 'failed in 1 zone burn: bad xsum', k
                     write(*,2) 'which_solver', which_solver
                     write(*,2) 's% num_reactions', s% num_reactions
                     write(*,2) 'num_times', num_times
                     write(*,2) 'max_steps', max_steps
                     write(*,2) 'nstep', nstep
                     write(*,2) 'which_decsol', which_decsol
                     write(*,1) 'log10Ts_f(1,1)', log10Ts_f(1,1)
                     write(*,1) 'log10Rhos_f(1,1)', log10Rhos_f(1,1)
                     write(*,1) 'times(1)', times(1)
                     write(*,1) 'xsum_init', xsum_init
                     write(*,1) 'sum(xa)', sum(s% xa(1:species,k))
                     write(*,1) 'sum(x_out)', sum(x_out(1:species))
                     write(*,2) 'species', species
                     do j = 1, species
                        write(*,1) trim(chem_isos% name(s% chem_id(j))), s% xa(j,k), &
                           x_in(j), x_out(j), x_out(j) - x_in(j), x_out(j) - s% xa(j,k)
                     end do
                     write(*,*)
                     stop 'do1_burn_for_cell'
                  end if

               end if
            end if
         end if
         
         if (trace) write(*,1) 'done solve_burn', (t_end - t_start)/secyer, &
            s% xa(s% net_iso(ih1),s% nz)
               
      end subroutine do1_burn_for_cell

      
      subroutine burn_solout( &
            step, t_prev, time, n, xa, rwork_y, iwork_y, interp_y, lrpar, rpar, lipar, ipar, irtrn)
         ! Per-step callback handed to net_1_zone_burn by do1_burn_for_cell.
         ! ipar(1) encodes the star id and cell index; the current abundances xa are
         ! checked and, when acceptable, optionally clipped to [0,1]
         ! (s% op_split_burn_clip) and renormalized to sum to 1.
         ! irtrn is set to -1 (asking the solver to return) when the cell index is out
         ! of range, an abundance is below s% min_xa_hard_limit, or the sum differs
         ! from 1 by more than s% sum_xa_tolerance. An unknown star id stops the run.
         use star_private_def, only: star_info, get_star_ptr
         use num_lib, only: safe_log10
         use chem_def, only: chem_isos
         integer, intent(in) :: step, n, lrpar, lipar
         real(dp), intent(in) :: t_prev, time
         real(dp), intent(inout) :: xa(n)
         real(dp), intent(inout), target :: rpar(lrpar), rwork_y(*)
         integer, intent(inout), target :: ipar(lipar), iwork_y(*)
         interface
            real(dp) function interp_y(i, s, rwork_y, iwork_y, ierr)
               use const_def, only: dp
               integer, intent(in) :: i ! result is interpolated approximation of xa(i) at time=s.
               real(dp), intent(in) :: s ! interpolation time value (between t_prev and time).
               real(dp), intent(inout), target :: rwork_y(*)
               integer, intent(inout), target :: iwork_y(*)
               integer, intent(out) :: ierr
            end function interp_y
         end interface
         integer, intent(out) :: irtrn ! < 0 causes solver to return to calling program.

         real(dp) :: total
         integer :: packed_id, star_id, cell, status, nspec
         type (star_info), pointer :: s

         include 'formats.dek'
         
         irtrn = 0
         status = 0

         ! ipar(1) carries caller_id = s% id + 100*k; split it back into id and k
         packed_id = ipar(1)
         cell = packed_id/100
         star_id = mod(packed_id,100)

         call get_star_ptr(star_id, s, status)
         if (status /= 0) then
            write(*,*) 'bad id for burn solout', packed_id, star_id, cell, lipar
            stop 'solve_burn'
         end if

         if (cell < 1 .or. cell > s% nz) then
            irtrn = -1
            return
         end if

         nspec = s% species

         ! small errors are repaired here; large ones are left for the caller to reject
         if (minval(xa(1:nspec)) < s% min_xa_hard_limit) then
            irtrn = -1
         else
            if (s% op_split_burn_clip) xa(1:nspec) = max(0d0, min(1d0, xa(1:nspec)))
            total = sum(xa(1:nspec))
            if (abs(total - 1d0) > s% sum_xa_tolerance) then
               irtrn = -1
            else
               xa(1:nspec) = xa(1:nspec)/total
            end if
         end if

      end subroutine burn_solout


      end module solve_burn


