! ***********************************************************************
!
!   Copyright (C) 2011  Bill Paxton
!
!   MESA is free software; you can use it and/or modify
!   it under the combined terms and restrictions of the MESA MANIFESTO
!   and the GNU Library General Public License as published
!   by the Free Software Foundation; either version 2 of the License,
!   or (at your option) any later version.
!
!   You should have received a copy of the MESA MANIFESTO along with
!   this software; if not, it is available at the mesa website:
!   http://mesa.sourceforge.net/
!
!   MESA is distributed in the hope that it will be useful,
!   but WITHOUT ANY WARRANTY; without even the implied warranty of
!   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
!   See the GNU Library General Public License for more details.
!
!   You should have received a copy of the GNU Library General Public License
!   along with this software; if not, write to the Free Software
!   Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
!
! ***********************************************************************

      module solve_mix

      use star_private_def
      use alert_lib
      use const_def

      implicit none
      
      integer, parameter :: XL_mix = qp

      contains
      

      !> Operator-split implicit mixing of the abundances s% xa over the
      !> interval dt = t_end - t_start.
      !>
      !> The interval is split into nsteps equal substeps, starting at
      !> s% op_split_mix_minsteps. Each substep does the following:
      !>   * solves one tridiagonal system per species (matrix from
      !>     create_matrix), adding dt_per_step*dxdt_source_term to the
      !>     right-hand side when a source term is supplied;
      !>   * checks dq-weighted conservation of each species against
      !>     op_split_mix_atol/rtol;
      !>   * rejects zones with an abundance below s% min_xa_hard_limit, or
      !>     with |sum(xa) - 1| > s% sum_xa_tolerance after clipping to [0,1];
      !>   * renormalises all other zones.
      !> If a substep is rejected, s% xa is restored from s% xa_pre_hydro and
      !> the whole interval is redone with twice as many substeps, up to
      !> op_split_mix_maxsteps*op_split_mix_minsteps.
      !>
      !> Returns:
      !>   keep_going - all substeps accepted;
      !>   retry      - a matrix solve failed, or the substep limit was
      !>                reached. In both cases s% xa is left as produced by
      !>                the failed attempt; it is NOT restored here;
      !>   terminate  - the work arrays could not be allocated.
      integer function do_solve_mix(s, t_start, t_end, dxdt_source_term)
         use star_utils, only: update_time, total_times
         use chem_def, only: chem_isos, ih1
         
         type (star_info), pointer :: s
         real(dp), intent(in) :: t_start, t_end 
         real(dp), pointer, intent(in) :: dxdt_source_term(:,:) 
            ! (species,nz)  or null if no source term.
         
         integer :: ierr, nz, species, i, j, k, op_err, k_bad
         logical :: bad_xsum, trace, conservation_error, bad_neg_x, have_source
         integer :: time0, clock_rate, max_nsteps, nsteps
         real(dp) :: dt, atol, rtol, total_all_before, xsum, dt_per_step, avg_x, &
            avg_source, init_avg, final_avg, avg_error
         real(XL_mix), pointer, dimension(:) :: du, d, dl
         real(dp), pointer :: x(:,:), b(:,:) ! (nz,species)
         real(dp), pointer :: init_avg_x(:) ! (species)

         include 'formats.dek'
         
         nz = s% nz
         species = s% species
         trace = s% op_split_mix_trace   
         dt = t_end - t_start 
         have_source = associated(dxdt_source_term)
         
         nsteps = s% op_split_mix_minsteps
         max_nsteps = s% op_split_mix_maxsteps*nsteps
         atol = s% op_split_mix_atol
         rtol = s% op_split_mix_rtol
         
         do_solve_mix = retry
         ierr = 0
         
         if (s% doing_timing) then
            total_all_before = total_times(s)
            call system_clock(time0,clock_rate)
         end if
         
         allocate(du(nz), d(nz), dl(nz), x(nz,species), &
            b(nz,species), init_avg_x(species), stat=ierr)
         if (ierr /= 0) then
            do_solve_mix = terminate
            write(*,*) 'allocate failed in do_solve_mix'
            return
         end if
         
         if (trace) write(*,1) 'start solve_mix', dt/secyer, &
            s% xa(s% net_iso(ih1),s% nz)
         
         dt_loop: do while (nsteps <= max_nsteps) ! try to do dt in nsteps
         
            dt_per_step = dt/nsteps
            ! each step has same dt_per_step so can use same matrix for each
            call create_matrix(du, d, dl)
         
            substep_loop: do i = 1, nsteps
         
               conservation_error = .false.
               bad_neg_x = .false.
               bad_xsum = .false.
               k_bad = 0 ! identity for the MAX reduction in check_loop

               ! Every scalar assigned inside the loop body must be PRIVATE;
               ! otherwise threads clobber each other's conservation check.
               ! conservation_error is OR-combined across threads.
!$OMP PARALLEL DO PRIVATE(j, op_err, avg_source, init_avg, final_avg, avg_error) &
!$OMP REDUCTION(.or.: conservation_error)
               solve_loop: do j = 1, species
                  ! best-effort early out once any species' solve has failed
                  if (ierr /= 0) cycle solve_loop
                  op_err = 0
                  ! set b to rhs for matrix equation
                  b(1:nz,j) = s% xa(j,1:nz)
                  if (have_source) &
                     b(1:nz,j) = b(1:nz,j) + dt_per_step*dxdt_source_term(j,1:nz)
                  init_avg_x(j) = dot_product(s% dq(1:nz), s% xa(j,1:nz))
                  call solve_high_precision_tridiag(dl, d, du, b(1:nz,j), x(1:nz,j), nz, op_err)
                  if (op_err /= 0) then
!$OMP CRITICAL (crit_mix_err)
                     ierr = op_err
                     if (trace) write(*,2) 'do_solve_mix: solve_high_precision_tridiag', j
!$OMP END CRITICAL (crit_mix_err)
                     ! x(:,j) is undefined after a failed solve: keep the old xa(j,:)
                     cycle solve_loop
                  end if
                  s% xa(j,1:nz) = x(1:nz,j)
                  if (have_source) then
                     avg_source = dot_product(s% dq(1:nz), dt_per_step*dxdt_source_term(j,1:nz))
                  else
                     avg_source = 0
                  end if
                  init_avg = init_avg_x(j)
                  final_avg = dot_product(s% dq(1:nz), s% xa(j,1:nz))
                  avg_error = init_avg + avg_source - final_avg
                  if (abs(avg_error)/(atol + rtol*init_avg) > 1d0) &
                     conservation_error = .true.
               end do solve_loop
!$OMP END PARALLEL DO

               if (ierr /= 0) then
                  write(*,*) 'matrix solve failed in solve mix'
                  do_solve_mix = retry
                  exit dt_loop
               end if
                        
               if (.not. conservation_error) then
                  do_solve_mix = keep_going
               else
                  do_solve_mix = retry
                  if (trace .or. s% report_ierr) then
                     ! Report the species that fail the same test used in
                     ! solve_loop, including the source-term contribution.
                     do j = 1, species
                        avg_x = dot_product(s% dq(1:nz), s% xa(j,1:nz))
                        if (have_source) then
                           avg_source = dot_product(s% dq(1:nz), dt_per_step*dxdt_source_term(j,1:nz))
                        else
                           avg_source = 0
                        end if
                        avg_error = init_avg_x(j) + avg_source - avg_x
                        if (abs(avg_error)/(atol + rtol*init_avg_x(j)) > 1d0) then
                           write(*,2) 'mix: conservation error ' // trim(chem_isos% name(s% chem_id(j))), &
                              j, -avg_error, avg_x, init_avg_x(j)
                        end if
                     end do
                  end if
               end if
         
               ! Check abundances and renormalise each zone. The failure flags
               ! are OR-reduced. k_bad is MAX-reduced, which gives the same
               ! "last bad zone" a serial ascending sweep would record.
!$OMP PARALLEL DO PRIVATE(k,j,xsum) REDUCTION(.or.: bad_neg_x, bad_xsum) &
!$OMP REDUCTION(max: k_bad)
               check_loop: do k = 1, nz
                  if (minval(s% xa(:,k)) < s% min_xa_hard_limit) then
                     ! only report this thread's first failure, and only if
                     ! the substep was not already rejected
                     if (do_solve_mix /= retry .and. trace .and. &
                           .not. (bad_neg_x .or. bad_xsum)) then
!$OMP CRITICAL (crit_mix_err)
                        j = minloc(s% xa(:,k), dim=1)
                        write(*,4) 'retry mix: k, i, nsteps, x ' // &
                           trim(chem_isos% name(s% chem_id(j))), &
                           k, i, nsteps, s% xa(j,k)
!$OMP END CRITICAL (crit_mix_err)
                     end if
                     k_bad = max(k_bad, k)
                     bad_neg_x = .true.
                     cycle check_loop
                  end if
                  s% xa(:,k) = max(0d0, min(1d0, s% xa(:,k)))
                  xsum = sum(s% xa(:,k))
                  if (abs(xsum - 1d0) > s% sum_xa_tolerance) then
                     if (do_solve_mix /= retry .and. trace .and. &
                           .not. (bad_neg_x .or. bad_xsum)) then
!$OMP CRITICAL (crit_mix_err)
                        write(*,4) 'retry mix: k, i, nsteps, 1-xsum', &
                           k, i, nsteps, 1d0-sum(s% xa(:,k))
!$OMP END CRITICAL (crit_mix_err)
                     end if
                     k_bad = max(k_bad, k)
                     bad_xsum = .true.
                     cycle check_loop
                  end if
                  s% xa(:,k) = s% xa(:,k)/xsum
               end do check_loop
!$OMP END PARALLEL DO

               if (bad_neg_x .or. bad_xsum) do_solve_mix = retry
               
               if (do_solve_mix /= keep_going) exit substep_loop

            end do substep_loop

            if (do_solve_mix == keep_going) exit dt_loop ! all substeps converged
            
            if (2*nsteps > max_nsteps .or. do_solve_mix /= retry) then ! time to quit
               if (trace) then
                  k = k_bad
                  if (bad_xsum) then
                     write(*,4) 'failed mix: k, i, nsteps, 1-xsum', &
                        k, i, nsteps, 1d0-sum(s% xa(:,k))
                  else if (bad_neg_x) then
                     j = minloc(s% xa(:,k), dim=1)
                     write(*,4) 'failed mix: k, i, nsteps, negative abundance ' // &
                        trim(chem_isos% name(s% chem_id(j))), k, i, nsteps, s% xa(j,k)
                  end if
               end if
               exit dt_loop
            end if
            
            ! restore xa for retry
            s% xa(1:species,1:nz) = s% xa_pre_hydro(1:species,1:nz)

            ! try it again with more substeps
            do_solve_mix = keep_going 
            nsteps = 2*nsteps
            if (trace) write(*,2) 'retry mix with nsteps increased to', nsteps
         
         end do dt_loop
         
         call dealloc
         
         if (trace) write(*,1) 'done solve_mix', dt/secyer, &
            s% xa(s% net_iso(ih1),s% nz)

         if (s% doing_timing) &
            call update_time(s, time0, total_all_before, s% time_solve_mix)

         
         contains

         
         !> Release the work arrays allocated at the top of do_solve_mix.
         subroutine dealloc
            deallocate(du, d, dl, x, b, init_avg_x)
         end subroutine dealloc
         
         !> Build the implicit (backward-Euler) diffusion matrix for one substep:
         !>
         !>    x(k) - xprev(k) = -(dt/dm(k))*(sig(k+1)*(x(k)-x(k+1)) - sig(k)*(x(k-1)-x(k)))
         !> => x(k-1)*(-sig(k)*dt/dm(k)) +
         !>    x(k)*(1 + (sig(k) + sig(k+1))*dt/dm(k)) +
         !>    x(k+1)*(-sig(k+1)*dt/dm(k))
         !>  = xprev(k)
         !>
         !> with dt = dt_per_step. Row k therefore has dl(k-1) as the coefficient
         !> of x(k-1), d(k) for x(k) and du(k) for x(k+1). sig(nz+1) is taken as 0.
         !> NOTE(review): sig(1) is still added to d(1) even though there is no
         !> x(0) term. This is presumably fine because s% sig(1) is zero;
         !> confirm, since otherwise row 1 is not conservative.
         !> If dxdt_source is present, the rhs gets an additional
         !> dt*dxdt_source(k), but the lhs matrix is unchanged.
         subroutine create_matrix(du, d, dl)
            real (XL_mix), dimension(:) :: du, d, dl
            integer :: k
            real (XL_mix) :: dtdm, dtsig00dm, dtsigp1dm, dt_ps, dm, sig
            real (XL_mix), parameter :: xl0 = 0, xl1 = 1
            include 'formats.dek'
            do k = 1, nz
               dt_ps = dt_per_step
               dm = s% dm(k)
               dtdm = dt_ps/dm
               sig = s% sig(k)
               dtsig00dm = dtdm*sig
               if (k > 1) then
                  dl(k-1) = -dtsig00dm
               end if
               if (k < nz) then
                  sig = s% sig(k+1)
                  dtsigp1dm = dtdm*sig
                  du(k) = -dtsigp1dm
               else
                  dtsigp1dm = xl0
               end if
               d(k) = xl1 + dtsig00dm + dtsigp1dm
            end do
         end subroutine create_matrix         
         
         
      end function do_solve_mix



      !> Solve the tridiagonal system A*x = rhs with the Thomas algorithm:
      !> forward elimination followed by back substitution, without pivoting.
      !> All intermediate values are carried in XL_mix precision and the
      !> result is rounded to dp.
      !>
      !> Row i of A is  sub(i-1)*x(i-1) + diag(i)*x(i) + sup(i)*x(i+1).
      !>
      !>   sub  - sub-diagonal; sub(i) is the coefficient of x(i) in row i+1
      !>          (sub(n) is not referenced)
      !>   diag - main diagonal
      !>   sup  - super-diagonal; sup(i) is the coefficient of x(i+1) in row i
      !>          (sup(n) is not referenced)
      !>   rhs  - right-hand side, at least n elements
      !>   x    - the answer, at least n elements
      !>   n    - number of equations; n <= 0 is a no-op
      !>   ierr - 0 on success; the allocate stat if the work arrays cannot be
      !>          allocated, or -1 if a zero pivot is met. On error, x is not
      !>          defined.
      subroutine solve_high_precision_tridiag(sub, diag, sup, rhs, x, n, ierr)
         integer, intent(in) :: n
         real(XL_mix), dimension(n), intent(in) :: sup, diag, sub
         real(dp), intent(in) :: rhs(:)
         real(dp), intent(out) :: x(:)
         integer, intent(out) :: ierr

         ! allocatable work arrays are released automatically on every return
         real(XL_mix), allocatable :: bp(:), vp(:), xp(:)
         real(XL_mix) :: m
         integer :: i

         ierr = 0
         if (n <= 0) return

         allocate(bp(n), vp(n), xp(n), stat=ierr)
         if (ierr /= 0) return

         ! forward elimination: bp = modified diagonal, vp = modified rhs
         bp(1) = diag(1)
         vp(1) = rhs(1)
         do i = 2, n
            ! no pivoting, so report a zero pivot instead of producing Inf/NaN
            if (bp(i-1) == 0.0_XL_mix) then
               ierr = -1
               return
            end if
            m = sub(i-1)/bp(i-1)
            bp(i) = diag(i) - m*sup(i-1)
            vp(i) = rhs(i) - m*vp(i-1)
         end do
         if (bp(n) == 0.0_XL_mix) then
            ierr = -1
            return
         end if

         ! back substitution, kept in XL_mix precision until the final copy
         xp(n) = vp(n)/bp(n)
         do i = n-1, 1, -1
            xp(i) = (vp(i) - sup(i)*xp(i+1))/bp(i)
         end do
         x(1:n) = real(xp(1:n), dp)

      end subroutine solve_high_precision_tridiag


      end module solve_mix


