! ***********************************************************************
!
!   Copyright (C) 2011  Bill Paxton
!
!   MESA is free software; you can use it and/or modify
!   it under the combined terms and restrictions of the MESA MANIFESTO
!   and the GNU General Library Public License as published
!   by the Free Software Foundation; either version 2 of the License,
!   or (at your option) any later version.
!
!   You should have received a copy of the MESA MANIFESTO along with
!   this software; if not, it is available at the mesa website:
!   http://mesa.sourceforge.net/
!
!   MESA is distributed in the hope that it will be useful,
!   but WITHOUT ANY WARRANTY; without even the implied warranty of
!   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
!   See the GNU Library General Public License for more details.
!
!   You should have received a copy of the GNU Library General Public License
!   along with this software; if not, write to the Free Software
!   Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
!
! ***********************************************************************

      module solve_mix

      use star_private_def
      use const_def

      implicit none

      
      ! Thresholds shared by the post-mix fixup routines of this module.
      !   tiny_mass - cell/species masses below this are treated as empty
      !               (see get1_dX_dm and fix_negative_masses)
      !   tinyX     - fix_species_conservation skips species whose target
      !               average abundance is below this
      !   Xlim, smallX - not referenced in the part of the module shown here
      real(dp), parameter :: &
         Xlim = 1d-14, &
         tiny_mass = 1d3, & ! a kilogram
         tinyX = 1d-50, &
         smallX = 1d-20

      ! when .true., fixup_after_mix does not compute dX_dm
      logical, parameter :: skip_dX_dm = .false.


      contains
      

      !> Implicit operator-split mixing of all species over dt_total.
      !!
      !! For each species j, the right-hand side
      !!    b(:,j) = s% xa_pre(j,:) + dt_total*avg_burn_dxdt(j,:)
      !! is solved in quad precision against the tridiagonal matrix built by
      !! create_matrix from s% sig and s% dm. The solution is copied to s% xa.
      !! fixup_after_mix then post-processes s% xa. Finally avg_mix_dxdt(j,k)
      !! is set to the change of s% xa over the step that is not accounted for
      !! by avg_burn_dxdt, and each cell is passed to revise_avg_mix_dxdt.
      !!
      !! Returns:
      !!   keep_going - on success
      !!   retry      - if the tridiagonal solve, fixup_after_mix, or
      !!                revise_avg_mix_dxdt reports an error
      !!   terminate  - if work arrays cannot be obtained; also sets
      !!                s% termination_code = t_solve_mix
      !! pass and num_passes are part of the interface but are not used here.
      integer function do_solve_mix( &
            s, dt_total, avg_burn_dxdt, avg_mix_dxdt, &
            species, pass, num_passes)
         use star_utils, only: update_time, total_times
         use chem_def, only: chem_isos, ih1, ine20, img24
         use eos_def
         
         type (star_info), pointer :: s
         real(dp), intent(in) :: dt_total 
         ! intent(in) on a pointer dummy fixes the association only;
         ! the target of avg_mix_dxdt is written below
         real(dp), pointer, intent(in), dimension(:,:) :: &
            avg_burn_dxdt, avg_mix_dxdt
         integer, intent(in) :: species, pass, num_passes
         
         ! set to .true. to sanity-check s% xa (bounds and per-cell sums)
         ! before mixing; a failure stops the run
         logical, parameter :: check_input_xa = .false.
         
         integer :: ierr, nz, j, k, op_err
         logical :: trace
         integer :: time0, clock_rate
         real(dp) :: total_all_before
            
         real(dp), pointer :: dX_dm1(:), dX_dm(:,:)
         real(dp), pointer :: mass1(:), sum_mass(:)
         real(dp), pointer :: mass(:,:)
         
         real(qp), pointer, dimension(:) :: du, d, dl, x1, b1, xp1, bp1, vp1
         real(qp), pointer, dimension(:,:) :: x, b, xp, bp, vp ! (nz,species)
         
         ! s% dq-weighted sum of the rhs for each species; fixup_after_mix
         ! uses it as the conservation target
         real(dp), dimension(species) :: target_avg_x
         
         include 'formats'
         
         do_solve_mix = keep_going
         ierr = 0
         nz = s% nz
         trace = s% op_split_mix_trace
         
         if (s% doing_timing) then
            total_all_before = total_times(s)
            call system_clock(time0,clock_rate)
         end if
                  
         call do_alloc(ierr)
         if (ierr /= 0) then
            ! hand back whichever work arrays were obtained before the failure
            call dealloc
            do_solve_mix = terminate
            s% termination_code = t_solve_mix
            write(*,*) 'allocate failed in do_solve_mix'
            return
         end if
         
         if (trace) write(*,2) 'start solve_mix', s% model_number
         
         if (check_input_xa) then
            do k=1,nz
               do j=1,species
                  if (s% xa(j,k) > 1d0 .or. s% xa(j,k) < 0d0) then   
                     write(*,3) 'bad xa', j, k, s% xa(j,k)
                     ierr = -1
                  end if
               end do
               if (abs(sum(s% xa(1:species,k)) - 1d0) > 1d-10) then
                  write(*,2) 'bad sum xa', k, sum(s% xa(1:species,k))
                  ierr = -1
               end if
            end do
            if (ierr /= 0) stop 'solve mix'
         end if
               
         call create_matrix(dt_total, du, d, dl)

         ! each iteration touches only column j of b, x, xp, bp, vp,
         ! element j of target_avg_x, and row j of s% xa
!$OMP PARALLEL DO PRIVATE(j, k, op_err)
         solve_loop: do j = 1, species

            op_err = 0
            ! rhs for the matrix equation: start-of-step abundance plus burning
            target_avg_x(j) = 0d0
            do k=1,nz
               b(k,j) = s% xa_pre(j,k) + dt_total*avg_burn_dxdt(j,k)
               target_avg_x(j) = target_avg_x(j) + s% dq(k)*b(k,j)
            end do
         
            ! xp/bp are pure workspace; pass them in the dummy-argument order
            call solve_tridiag( &
               dl, d, du, b(1:nz,j), x(1:nz,j), &
               xp(1:nz,j), bp(1:nz,j), vp(1:nz,j), nz, op_err)
            if (op_err /= 0) then
               ! ierr is shared between threads
!$OMP CRITICAL (do_solve_mix_ierr)
               ierr = op_err
!$OMP END CRITICAL (do_solve_mix_ierr)
               if (trace) write(*,2) 'do_solve_mix: solve_tridiag', j
            else
               do k=1,nz
                  s% xa(j,k) = x(k,j) ! note: allow x < 0; leave it for fixup later
               end do
            end if
            
         end do solve_loop
!$OMP END PARALLEL DO

         if (ierr /= 0) then
            if (s% report_ierr) write(*,*) 'solve failed in solve mix'
            do_solve_mix = retry
            call dealloc
            return
         end if
         
         call fixup_after_mix( &
            s, species, nz, dt_total, mass, sum_mass, target_avg_x, &
            s% op_split_mix_atol, s% op_split_mix_rtol, &
            s% xa, dX_dm, ierr)
         if (ierr /= 0) then
            if (s% report_ierr) write(*,*) 'fixup failed in solve mix'
            do_solve_mix = retry
            call dealloc
            return
         end if
         
         ! mixing rate = total change over the step minus the burn contribution
         do k=1,nz
            do j=1,species
               avg_mix_dxdt(j,k) = &
                  (s% xa(j,k) - s% xa_pre(j,k))/dt_total - avg_burn_dxdt(j,k)
            end do
            op_err = 0
            call revise_avg_mix_dxdt(s, k, species, avg_mix_dxdt, dt_total, op_err)
            if (op_err /= 0) then
               ierr = op_err
               if (trace) write(*,2) 'do_solve_mix: revise_avg_mix_dxdt', k
            end if
         end do

         if (ierr /= 0) then
            if (s% report_ierr) write(*,*) 'solve failed in solve mix'
            do_solve_mix = retry
            call dealloc
            return
         end if

         call dealloc

         if (s% doing_timing) &
            call update_time(s, time0, total_all_before, s% time_solve_mix)
            
         
         contains

         
         !> Obtain the quad-precision and dp work arrays from the alloc module
         !! and set up the rank-2 views onto them.
         !! All pointers are nullified first so that, if a request fails
         !! partway, dealloc can return exactly the arrays that were obtained.
         !! This assumes a failing non_crit_get_* call leaves its pointer
         !! unassociated (NOTE(review): confirm in the alloc module).
         subroutine do_alloc(ierr)
            use utils_lib, only: fill_with_NaNs
            use alloc
            integer, intent(out) :: ierr
            integer :: sz, sz_extra
            
            nullify(du, d, dl, x1, b1, bp1, vp1, xp1, sum_mass, mass1, dX_dm1)
            nullify(x, b, bp, vp, xp, mass, dX_dm)
            
            sz = nz*species
            sz_extra = nz_alloc_extra*species
            
            call non_crit_get_quad_array(s, du, nz, nz_alloc_extra, 'solve_mix', ierr)
            if (ierr /= 0) return  
                      
            call non_crit_get_quad_array(s, d, nz, nz_alloc_extra, 'solve_mix', ierr)
            if (ierr /= 0) return    
                    
            call non_crit_get_quad_array(s, dl, nz, nz_alloc_extra, 'solve_mix', ierr)
            if (ierr /= 0) return
                       
            call non_crit_get_quad_array(s, x1, sz, sz_extra, 'solve_mix', ierr)
            if (ierr /= 0) return            
            x(1:nz,1:species) => x1(1:sz)
            
            call non_crit_get_quad_array(s, b1, sz, sz_extra, 'solve_mix', ierr)
            if (ierr /= 0) return            
            b(1:nz,1:species) => b1(1:sz)
            
            call non_crit_get_quad_array(s, bp1, sz, sz_extra, 'solve_mix', ierr)
            if (ierr /= 0) return            
            bp(1:nz,1:species) => bp1(1:sz)
            
            call non_crit_get_quad_array(s, vp1, sz, sz_extra, 'solve_mix', ierr)
            if (ierr /= 0) return            
            vp(1:nz,1:species) => vp1(1:sz)
            
            call non_crit_get_quad_array(s, xp1, sz, sz_extra, 'solve_mix', ierr)
            if (ierr /= 0) return            
            xp(1:nz,1:species) => xp1(1:sz)

            call non_crit_get_work_array( &
               s, sum_mass, nz, nz_alloc_extra, 'solve_mix', ierr)
            if (ierr /= 0) return
            
            call non_crit_get_work_array( &
               s, mass1, sz, sz_extra, 'solve_mix', ierr)
            if (ierr /= 0) return
            mass(1:species,1:nz) => mass1(1:species*nz)
            
            call non_crit_get_work_array( &
               s, dX_dm1, sz, sz_extra, 'solve_mix', ierr)
            if (ierr /= 0) return
            dX_dm(1:species,1:nz) => dX_dm1(1:species*nz)

            !call fill_with_NaNs(dX_dm1)
            
         end subroutine do_alloc
            
            
         !> Return every work array that do_alloc obtained. Each is guarded with
         !! associated() so this is safe after a partial do_alloc failure.
         subroutine dealloc
            use alloc
            if (associated(du)) call non_crit_return_quad_array(s, du, 'solve_mix')
            if (associated(d)) call non_crit_return_quad_array(s, d, 'solve_mix')
            if (associated(dl)) call non_crit_return_quad_array(s, dl, 'solve_mix')
            if (associated(x1)) call non_crit_return_quad_array(s, x1, 'solve_mix')
            if (associated(b1)) call non_crit_return_quad_array(s, b1, 'solve_mix')
            if (associated(bp1)) call non_crit_return_quad_array(s, bp1, 'solve_mix')
            if (associated(vp1)) call non_crit_return_quad_array(s, vp1, 'solve_mix')
            if (associated(xp1)) call non_crit_return_quad_array(s, xp1, 'solve_mix')
            if (associated(mass1)) call non_crit_return_work_array(s, mass1, 'solve mix')
            if (associated(sum_mass)) call non_crit_return_work_array(s, sum_mass, 'solve mix')
            if (associated(dX_dm1)) call non_crit_return_work_array(s, dX_dm1, 'solve mix')
         end subroutine dealloc
         
         
         !> Build the tridiagonal backward-Euler mixing matrix (quad precision).
         !! For cell k (one species):
         !!    x(k) - xprev(k) = -(dt/dm(k))*(sig(k+1)*(x(k)-x(k+1)) - sig(k)*(x(k-1)-x(k)))
         !! => x(k-1)*(-sig(k)*dt/dm(k)) +
         !!    x(k)*(1 + (sig(k)+sig(k+1))*dt/dm(k)) +
         !!    x(k+1)*(-sig(k+1)*dt/dm(k))
         !!  = xprev(k)
         !! Here xprev is the rhs assembled by the caller. dl(k-1) is the
         !! (k,k-1) entry, d(k) the (k,k) entry, and du(k) the (k,k+1) entry.
         !! For k = nz there is no sig(nz+1) term. For k = 1 there is no x(0)
         !! entry, but d(1) still includes sig(1)*dt/dm(1).
         !! NOTE(review): that is flux-conserving only if s% sig(1) is 0 --
         !! confirm.
         subroutine create_matrix(dt_for_step, du, d, dl)
            real(dp), intent(in) :: dt_for_step
            real(qp), dimension(:) :: du, d, dl
            integer :: k
            real(qp) :: dtdm, dtsig00dm, dtsigp1dm, dt_ps, dm, sig
            real(qp), parameter :: xl0 = 0, xl1 = 1
            include 'formats'
            do k = 1, nz
               dt_ps = dt_for_step
               dm = s% dm(k)
               dtdm = dt_ps/dm
               sig = s% sig(k)
               dtsig00dm = dtdm*sig
               if (k > 1) then
                  dl(k-1) = -dtsig00dm
               end if
               if (k < nz) then
                  sig = s% sig(k+1)
                  dtsigp1dm = dtdm*sig
                  du(k) = -dtsigp1dm
               else
                  dtsigp1dm = xl0
               end if
               d(k) = xl1 + dtsig00dm + dtsigp1dm
            end do
         end subroutine create_matrix         


         !> Solve a tridiagonal system by forward elimination followed by
         !! back substitution, without pivoting.
         !!   sub(i)  - entry (i+1,i), i = 1..n-1
         !!   diag(i) - entry (i,i),   i = 1..n
         !!   sup(i)  - entry (i,i+1), i = 1..n-1
         !!   rhs     - right hand side
         !!   x       - the solution
         !!   xp, bp, vp - work arrays of length >= n
         !!   n       - number of equations
         !! ierr = -1 if a zero pivot is met; x is then not fully defined.
         subroutine solve_tridiag(sub, diag, sup, rhs, x, xp, bp, vp, n, ierr)
            integer, intent(in) :: n
            real(qp), dimension(:), intent(in) :: sup, diag, sub
            real(qp), dimension(:), intent(in) :: rhs
            real(qp), dimension(:), intent(out) :: x
            real(qp), dimension(:), intent(out) :: xp, bp, vp ! work arrays
            integer, intent(out) :: ierr

            real(qp) :: m
            integer :: i

            ierr = 0
            if (n < 1) return

            bp(1) = diag(1)
            if (bp(1) == 0) then
               ierr = -1
               return
            end if
            vp(1) = rhs(1)

            do i = 2,n
               m = sub(i-1)/bp(i-1)
               bp(i) = diag(i) - m*sup(i-1)
               if (bp(i) == 0) then
                  ierr = -1
                  return
               end if
               vp(i) = rhs(i) - m*vp(i-1)
            end do

            xp(n) = vp(n)/bp(n)
            x(n) = xp(n)
            do i = n-1, 1, -1
               xp(i) = (vp(i) - sup(i)*xp(i+1))/bp(i)
               x(i) = xp(i)
            end do

         end subroutine solve_tridiag
         
         
      end function do_solve_mix

      
      !> Post-process abundances after the mixing solve.
      !!
      !! Steps, each returning early with ierr /= 0 on failure (or stopping
      !! when dbg is on):
      !!   1. get_masses: mass(j,k) = s% dm(k)*s% xa(j,k)
      !!   2. fix_negative_masses: remove negative species masses
      !!   3. fix_species_conservation: rescale each species toward
      !!      target_avg_x (tolerances X_total_atol, X_total_rtol) and set
      !!      sum_mass(k)
      !!   4. get1_dX_dm for every cell, unless skip_dX_dm
      !!   5. redistribute_mass, which writes X_new
      !! dt_total is only forwarded to get_masses.
      subroutine fixup_after_mix( &
            s, species, nz, dt_total, mass, sum_mass, target_avg_x, &
            X_total_atol, X_total_rtol, X_new, dX_dm, ierr)
         
         type (star_info), pointer :: s
         integer, intent(in) :: species, nz
         real(dp), intent(in) :: dt_total, X_total_atol, X_total_rtol
         real(dp), dimension(:), intent(in) :: target_avg_x ! (species)
         real(dp), dimension(:) :: sum_mass ! (nz)
         real(dp), dimension(:,:) :: mass ! (species,nz)
         real(dp), dimension(:,:) :: X_new, dX_dm ! (species,nz)
         integer, intent(out) :: ierr
         
         logical, parameter :: dbg = .false.
         integer :: k, bad_k, op_err
         
         real(dp) :: mtotal
         
         include 'formats'
         
         ierr = 0
         bad_k = 0
         
         call get_masses(s, species, nz, dt_total, mass, ierr)
         if (ierr /= 0) then
            if (dbg) then
               stop 'failed in get_masses'
            end if
            return
         end if
      
         call fix_negative_masses(s, species, nz, mass, ierr)
         if (ierr /= 0) then
            if (dbg) then
               stop 'failed in fix_negative_masses'
            end if
            return
         end if
      
         mtotal = sum(s% dm(1:nz))
         call fix_species_conservation( &
            s, species, nz, mass, sum_mass, mtotal, target_avg_x, &
            X_total_atol, X_total_rtol, ierr)
         if (ierr /= 0) then
            if (dbg) then
               stop 'failed in fix_species_conservation'
            end if
            return
         end if
         
         if (.not. skip_dX_dm) then
            ! each call writes only column k of dX_dm; bad_k and ierr are
            ! shared, so failures are recorded inside a critical section
!$OMP PARALLEL DO PRIVATE(k, op_err)
            do k = 1, nz
               call get1_dX_dm( &
                  k, nz, species, mass, sum_mass, &
                  dX_dm, .false., op_err)
               if (op_err /= 0) then
!$OMP CRITICAL (fixup_after_mix_err)
                  bad_k = k
                  ierr = op_err
!$OMP END CRITICAL (fixup_after_mix_err)
               end if
            end do
!$OMP END PARALLEL DO
         end if
            
         if (ierr /= 0) then
            if (dbg) write(*,2) 'failed in get1_dX_dm', bad_k
            if (dbg) stop 'fixup'
            return
         end if
         
         call redistribute_mass( &
            s, species, nz, X_total_atol, X_total_rtol, &
            mass, sum_mass, target_avg_x, X_new, dX_dm, ierr)
         if (ierr /= 0) then
            if (dbg) then
               stop 'failed in redistribute_mass'
            end if
            return
         end if
      
      end subroutine fixup_after_mix


      !> Limited abundance slope dX/dm for every species in cell k.
      !!
      !! Abundances are taken as x = mass(j,:)/sum_mass(:). The sign
      !! convention is that the value at face k is x00 + slope*sum_mass(k)/2
      !! and the value at face k+1 is x00 - slope*sum_mass(k)/2.
      !!
      !! Interior cells (1 < k < nz):
      !!   - slope = average of the one-sided differences to k-1 and k+1,
      !!     each divided by the mean mass of the two cells involved;
      !!   - slope = 0 when the two differences do not have the same strict
      !!     sign, or when |slope| < tiny_slope;
      !!   - the smaller-magnitude one-sided slope is used instead if the
      !!     implied face values leave [0,1] or are not between the adjacent
      !!     cell values.
      !! Cells 1 and nz use their single one-sided difference, with 0 as the
      !! fallback.
      !!
      !! Afterwards one species' slope is replaced so the slopes of cell k sum
      !! to zero. If any resulting face value leaves [0,1], all slopes of
      !! cell k are zeroed. Slopes are also zero when cell k or a needed
      !! neighbor has mass below tiny_mass.
      !!
      !! Only column k of dX_dm is written, hence intent(inout). With
      !! intent(out) the whole array, including columns set by earlier calls,
      !! would become undefined on entry. ierr is always 0 on return; an
      !! invalid k stops the run.
      !!
      !! NOTE(review): tiny_slope is compared against a slope per unit mass
      !! (the units of sum_mass), so how often it triggers depends on the
      !! cell masses -- confirm this is the intended scaling.
      subroutine get1_dX_dm( &
            k, nz, species, mass, sum_mass, &
            dX_dm, dbg, ierr)
         integer, intent(in) :: k, nz, species
         real(dp), intent(in) :: sum_mass(:), mass(:,:)
         real(dp), intent(inout) :: dX_dm(:,:)
         logical, intent(in) :: dbg
         integer, intent(out) :: ierr
         
         real(dp) :: slope, sm1, s00, xface_00, xface_p1, &
            dm_half, dm_00, dm_p1, dm_m1, dmbar_p1, dmbar_00, &
            x00, xm1, xp1
         integer :: j
         real(dp), parameter :: tiny_slope = 1d-10
            
         include 'formats'
         
         ierr = 0
         
         dm_00 = sum_mass(k)
         if (dm_00 < tiny_mass) then
            dX_dm(1:species,k) = 0d0
            return
         end if
         dm_half = 0.5d0*dm_00
      
         if (k > 1 .and. k < nz) then
            
            dm_m1 = sum_mass(k-1)
            dm_p1 = sum_mass(k+1)
            if (dm_m1 < tiny_mass .or. dm_p1 < tiny_mass) then
               dX_dm(1:species,k) = 0d0
               return
            end if
            
            dmbar_00 = 0.5d0*(dm_00 + dm_m1)
            dmbar_p1 = 0.5d0*(dm_00 + dm_p1)
            
            do j=1,species
               xm1 = mass(j,k-1)/dm_m1
               x00 = mass(j,k)/dm_00
               xp1 = mass(j,k+1)/dm_p1
               sm1 = (xm1 - x00)/dmbar_00
               s00 = (x00 - xp1)/dmbar_p1
               slope = 0.5d0*(sm1 + s00)
               if (sm1*s00 <= 0 .or. abs(slope) < tiny_slope) then
                  dX_dm(j,k) = 0d0
               else
                  dX_dm(j,k) = slope
                  xface_00 = x00 + slope*dm_half ! value at face(k)
                  xface_p1 = x00 - slope*dm_half ! value at face(k+1)
                  ! face values must stay in [0,1] and between the adjacent cell values
                  if (xface_00 > 1d0 .or. xface_00 < 0d0 .or. &
                        (xm1 - xface_00)*(xface_00 - x00) < 0 .or. &
                      xface_p1 > 1d0 .or. xface_p1 < 0d0 .or. &
                        (x00 - xface_p1)*(xface_p1 - xp1) < 0) then
                     if (abs(sm1) <= abs(s00)) then
                        dX_dm(j,k) = sm1
                     else
                        dX_dm(j,k) = s00
                     end if
                  end if
               end if
            end do
            
         else if (k == 1) then
            
            dm_p1 = sum_mass(k+1)
            if (dm_p1 < tiny_mass) then
               dX_dm(1:species,k) = 0d0
               return
            end if
            
            dmbar_p1 = 0.5d0*(dm_00 + dm_p1)
            
            do j=1,species
               x00 = mass(j,k)/dm_00
               xp1 = mass(j,k+1)/dm_p1
               slope = (x00 - xp1)/dmbar_p1
               if (abs(slope) < tiny_slope) then
                  dX_dm(j,k) = 0d0
               else
                  dX_dm(j,k) = slope
                  xface_00 = x00 + slope*dm_half ! value at face(k)
                  xface_p1 = x00 - slope*dm_half ! value at face(k+1)
                  if (xface_00 > 1d0 .or. xface_00 < 0d0 .or. &
                        (x00 - xface_p1)*(xface_p1 - xp1) < 0) then
                     dX_dm(j,k) = 0d0
                  end if
               end if
            end do
         
         else if (k == nz) then
         
            dm_m1 = sum_mass(k-1)
            if (dm_m1 < tiny_mass) then
               dX_dm(1:species,k) = 0d0
               return
            end if
            
            dmbar_00 = 0.5d0*(dm_00 + dm_m1)
            
            do j=1,species
               xm1 = mass(j,k-1)/dm_m1
               x00 = mass(j,k)/dm_00
               slope = (xm1 - x00)/dmbar_00
               if (abs(slope) < tiny_slope) then
                  dX_dm(j,k) = 0d0
               else
                  dX_dm(j,k) = slope
                  xface_00 = x00 + slope*dm_half ! value at face(k)
                  xface_p1 = x00 - slope*dm_half ! value at face(k+1)
                  if (xface_p1 > 1d0 .or. xface_p1 < 0d0 .or. &
                        (xm1 - xface_00)*(xface_00 - x00) < 0) then
                     dX_dm(j,k) = 0d0
                  end if
               end if
            end do
            
         else
            
            write(*,2) 'k bad', k
            stop 'get1_dX_dm'
         
         end if
         
         ! adjust so that sum(dX_dm) = 0: replace the entry with the largest
         ! magnitude in the direction of the current sum
         if (sum(dX_dm(1:species,k)) > 0d0) then
            j = maxloc(dX_dm(1:species,k),dim=1)
         else
            j = minloc(dX_dm(1:species,k),dim=1)
         end if
         dX_dm(j,k) = 0d0 ! remove from sum
         dX_dm(j,k) = -sum(dX_dm(1:species,k))
         
         ! recheck for valid values at faces; give up on the whole cell if any fail
         do j=1,species
            x00 = mass(j,k)/dm_00
            slope = dX_dm(j,k)
            xface_00 = x00 + slope*dm_half ! value at face(k)
            xface_p1 = x00 - slope*dm_half ! value at face(k+1)
            if (xface_00 > 1d0 .or. xface_00 < 0d0 .or. &
                xface_p1 > 1d0 .or. xface_p1 < 0d0) then
               if (dbg) then ! .and. abs(slope) > 1d-10) then
                  write(*,3) 'give up on dX_dm', j, k
                  write(*,1) 'slope', slope
                  write(*,1) 'dm_half', dm_half
                  write(*,1) 'xface_00', xface_00
                  write(*,1) 'xface_p1', xface_p1
                  if (k > 1) then
                     dm_m1 = sum_mass(k-1)
                     write(*,1) 'xm1', mass(j,k-1)/dm_m1
                  end if
                  write(*,1) 'x00', x00
                  if (k < nz) then
                     dm_p1 = sum_mass(k+1)
                     write(*,1) 'xp1', mass(j,k+1)/dm_p1
                  end if
                  write(*,*)
               end if
               dX_dm(1:species,k) = 0d0
               exit
            end if
         end do
      
      end subroutine get1_dX_dm

         
      !> Verify that no abundance s% xa(1:species,1:nz) is below the
      !! current hard lower limit, current_min_xa_hard_limit(s).
      !! ierr = 0 if every entry passes, -1 otherwise. With dbg off the scan
      !! stops at the first offending entry; with dbg on every offender is
      !! reported.
      subroutine check_xa(s, species, nz, ierr)
         use star_utils, only: current_min_xa_hard_limit
         use chem_def, only: chem_isos
         type (star_info), pointer :: s
         integer, intent(in) :: species, nz
         integer, intent(out) :: ierr
         logical, parameter :: dbg = .false.
         real(dp) :: xa_floor
         integer :: iz, isp
         include 'formats'

         xa_floor = current_min_xa_hard_limit(s)
         ierr = 0
         zone_loop: do iz = 1, nz
            do isp = 1, species
               ! written as .not.(<) so that NaN entries are not flagged
               if (.not. (s% xa(isp,iz) < xa_floor)) cycle
               ierr = -1
               if (.not. dbg) exit zone_loop
               write(*,3) trim(chem_isos% name(s% chem_id(isp))), isp, iz, s% xa(isp,iz)
            end do
         end do zone_loop
      end subroutine check_xa

         
      !> Fill mass(j,k) = s% dm(k)*s% xa(j,k) for all species and cells.
      !!
      !! An abundance below the loose limit 1d3*current_min_xa_hard_limit(s)
      !! is an error (ierr = -1, immediate return unless dbg) only when
      !! k < nz and s% sig(k) and s% sig(k+1) are within a factor of 1d4 of
      !! each other. Otherwise it is tolerated. On an early return, mass is
      !! only partially filled. dt_total is used only in the debug output.
      subroutine get_masses(s, species, nz, dt_total, mass, ierr)
         use star_utils, only: current_min_xa_hard_limit
         use chem_def, only: chem_isos
         type (star_info), pointer :: s
         integer, intent(in) :: species, nz         
         real(dp), intent(in) :: dt_total
         real(dp), intent(inout) :: mass(:,:)
         integer, intent(out) :: ierr
         integer :: j, k
         real(dp) :: limit, min_xa_hard_limit
         logical, parameter :: dbg = .false.
         include 'formats'
         ierr = 0
         min_xa_hard_limit = current_min_xa_hard_limit(s)
         limit = 1d3*min_xa_hard_limit
         do k=1,nz
            do j=1,species
               if (s% xa(j,k) < limit) then ! use loose limit. 
                  ! let errors pass if at a jump in mixing coef
                  ! extreme errors will be caught in fix_species_conservation
                  if (k < nz) then
                     if (s% sig(k) <= 1d4*s% sig(k+1) .and. &
                           s% sig(k+1) <= 1d4*s% sig(k)) then
                        if (dbg) write(*,3) 'get_masses ' // &
                           trim(chem_isos% name(s% chem_id(j))), &
                           j, k, s% xa(j,k), s% xa_pre(j,k), &
                           s% xa_pre(j,k) + dt_total*s% avg_burn_dxdt(j,k), &
                           s% sig(k), s% sig(k+1) !, s% sig_term_limit*s% dm(k)/s% dt
                        ierr = -1
                        if (.not. dbg) return
                     end if
                  end if
               end if
               mass(j,k) = s% dm(k)*s% xa(j,k)
            end do
         end do
      end subroutine get_masses


      !> Remove negative species masses in place.
      !!
      !! For each cell k and species j with mass(j,k) < 0:
      !!  - a deficit no larger than 1d-13*s% dm(k) is simply set to zero;
      !!  - otherwise a window k_lo:k_hi is used, first k-1:k+1 and then
      !!    k-2:k+2 (clipped to 1:nz). sum_m is the window's net mass,
      !!    including the negative mass(j,k). Once sum_m >= tiny_mass, mass(j,k)
      !!    is zeroed and the window is scaled by sum_m/(sum_m + deficit), so
      !!    the window total stays equal to sum_m;
      !!  - if no window reaches tiny_mass, or the first window's edge cell is
      !!    negative, mass(j,k) is just set to zero. That does not conserve
      !!    mass; in fixup_after_mix this routine is followed by
      !!    fix_species_conservation.
      !! The run is stopped if a window rescale changes the window total by
      !! more than 1d-12 relative, or if any negative mass remains at the end.
      !! ierr is always 0 on return.
      subroutine fix_negative_masses(s, species, nz, mass, ierr)
         type (star_info), pointer :: s
         integer, intent(in) :: species, nz         
         real(dp) :: mass(:,:)
         integer, intent(out) :: ierr         
      
         integer :: k, j, cnt, maxcnt, k_hi, k_lo, kk, jj
         real(dp) :: dm, source_mass, frac, sum_m
      
         include 'formats'
      
         ierr = 0
   
         do k = 1,nz
      
            fix1: do j=1,species
         
               if (mass(j,k) >= 0d0) cycle
               ! negligible deficit: zero it without touching neighbors
               if (mass(j,k) >= -1d-13*s% dm(k)) then
                  mass(j,k) = 0d0
                  cycle fix1
               end if
            
               ! grow the window until its net mass is at least tiny_mass
               k_hi = min(k+1,nz)
               k_lo = max(k-1,1)
               maxcnt = 2
               do cnt = 1, maxcnt
                  sum_m = sum(mass(j,k_lo:k_hi)) ! includes the negative mass(j,k)
                  if (sum_m >= tiny_mass) exit
                  if (cnt == maxcnt .or. mass(j,k_lo) < 0d0 .or. mass(j,k_hi) < 0d0) then
                     ! give up on a conservative fix for this entry
                     mass(j,k) = 0d0
                     cycle fix1
                  end if
                  k_hi = min(k_hi+1,nz)
                  k_lo = max(k_lo-1,1)
               end do
            
               dm = -mass(j,k) ! dm > 0
               mass(j,k) = 0d0
               ! remove dm from neighbors
               ! source_mass = window total excluding the (now zeroed) deficit;
               ! scaling by frac brings the window total back to sum_m
               source_mass = sum_m + dm
               frac = sum_m/source_mass
               do kk = k_lo, k_hi
                  mass(j,kk) = mass(j,kk)*frac
               end do
               ! consistency check: the window total must be unchanged
               if (abs(sum_m - sum(mass(j,k_lo:k_hi))) > 1d-12*sum_m) then

                  write(*,5) 'bad (sum_m - sum mass(j,:))/sum_m', j, k, k_lo, k_hi, &
                     (sum_m - sum(mass(j,k_lo:k_hi)))/sum_m
                  write(*,1) 'sum_m - sum(mass(j,k_lo:k_hi))', sum_m - sum(mass(j,k_lo:k_hi))
                  write(*,1) 'sum(mass(j,k_lo:k_hi))', sum(mass(j,k_lo:k_hi))
                  write(*,1) 'sum_m', sum_m
                  write(*,1) 'dm', dm

                  write(*,1) 'sum_m + dm', sum_m + dm
                  write(*,1) 'frac', frac

                  write(*,1) 'dm/s% dm(k)', dm/s% dm(k)
                  write(*,2) 'k_lo', k_lo
                  write(*,2) 'k', k
                  write(*,2) 'k_hi', k_hi
                  do jj=k_lo,k_hi
                     write(*,3) 'mass', j, jj, mass(j,jj), mass(j,jj)/s% dm(jj)
                  end do
                  stop 'fixup'
               end if

            end do fix1
         
         end do
      
         ! final guard: no negative masses may remain
         do k=1,nz
            do j=1,species
               if (mass(j,k) < 0d0) then
                  write(*,3) 'mass(j,k)', j, k, mass(j,k)
                  stop 'fix_negative_masses'
               end if
            end do
         end do
   
      end subroutine fix_negative_masses
   
      
      !> Rescale each species' masses so that its average abundance matches
      !! target_avg_x, then recompute the per-cell total mass sum_mass.
      !!
      !! For each species j with target_avg_x(j) >= tinyX:
      !!   xtotal_new = sum(mass(j,1:nz))/mtotal
      !!   err = |xtotal_new - target| / (atol + rtol*max(xtotal_new, target))
      !! If err > 1, ierr = -1 and the routine returns immediately; species
      !! already processed stay rescaled. Otherwise mass(j,:) is divided by
      !! frac = xtotal_new/target. A species whose masses are all zero
      !! (frac = 0) is left unchanged, because dividing by zero would turn
      !! the zeros into NaN. Species with target below tinyX are not touched.
      !! The run is stopped if any resulting sum_mass(k) is negative.
      subroutine fix_species_conservation( &
            s, species, nz, mass, sum_mass, mtotal, target_avg_x, &
            atol, rtol, ierr)
         type (star_info), pointer :: s
         integer, intent(in) :: species, nz 
         real(dp), intent(in) :: atol, rtol        
         real(dp), intent(in) :: mtotal, target_avg_x(:)
         real(dp), intent(inout) :: mass(:,:), sum_mass(:)
         integer, intent(out) :: ierr     
      
         integer :: k, j
         real(dp) :: xtotal_new, frac, err
      
         logical, parameter :: dbg = .false.
      
         include 'formats'
      
         ierr = 0
      
         do j=1,species
            if (target_avg_x(j) < tinyX) cycle
            xtotal_new = sum(mass(j,1:nz))/mtotal
            frac = xtotal_new/target_avg_x(j)
            err = abs(xtotal_new - target_avg_x(j)) / (atol + &
               rtol*max(xtotal_new, target_avg_x(j)))
            if (err > 1d0) then
               if (dbg) write(*,2) 'fixup err', j, err
               if (dbg) write(*,2) 'xtotal_new', j, xtotal_new
               if (dbg) write(*,2) 'target_avg_x(j)', j, target_avg_x(j)
               if (dbg) write(*,2) 'X_total_atol', j, atol
               if (dbg) write(*,2) 'X_total_rtol', j, rtol
               if (dbg) write(*,2) 'frac', j, frac
               if (dbg) stop 'fixup 1'
               ierr = -1
               return
            end if
            ! all-zero species within tolerance: nothing to rescale
            if (frac <= 0d0) cycle
            do k=1,nz
               mass(j,k) = mass(j,k)/frac
            end do
         end do

         do k=1,nz
            sum_mass(k) = sum(mass(1:species,k))
            if (sum_mass(k) < 0d0) then
               write(*,2) 'sum_mass(k)', k, sum_mass(k)
               stop 'fixup 2'
            end if
         end do
      
      end subroutine fix_species_conservation

      
      subroutine redistribute_mass( &
            s, species, nz, X_total_atol, X_total_rtol, &
            mass, sum_mass, target_avg_x, X_new, dX_dm, ierr)
         use utils_lib, only: is_bad_num
         use chem_def, only: chem_isos
         
         type (star_info), pointer :: s
         integer, intent(in) :: species, nz         
         real(dp) :: X_total_atol, X_total_rtol, &
            mass(:,:), sum_mass(:), target_avg_x(:)
         real(dp) :: X_new(:,:), dX_dm(:,:)
         integer, intent(out) :: ierr
      
         integer :: k_source, max_iters, k, i, j
         real(dp) :: source_cell_mass, remaining_source_mass, cell_dm_k, &
            remaining_needed_mass, frac, sumX, total_source, total_moved, &
            dm0, dm1, old_sum, new_sum, mtotal, err, dm, diff_dm, total_source_0
         logical :: okay, dbg
         
         include 'formats'
      
         ierr = 0
         total_moved = 0d0
         dbg = .false.
         total_source_0 = sum(sum_mass(1:nz))
      
         ! redistribute mass to make sum(X_new) = 1 for all cells
         ! this is done serially from 1 to nz
         k_source = 1
         source_cell_mass = sum_mass(k_source)
         remaining_source_mass = source_cell_mass
         max_iters = 100
      
         if (dbg) write(*,2) 'nz', nz
      
      cell_loop: do k=1,nz
   
            !dbg = total_num_iters == 26 .and. k == 535
   
            cell_dm_k = s% dm(k)
            remaining_needed_mass = cell_dm_k
            X_new(1:species,k) = 0d0
            if (dbg) write(*,*)
         
         fill_loop: do i=1,max_iters
      
               if (dbg) write(*,4) 'remaining_needed_mass/cell_dm_k', &
                  i, k, k_source, remaining_needed_mass/cell_dm_k, &
                  remaining_needed_mass, cell_dm_k, remaining_source_mass
               if (is_bad_num(remaining_needed_mass)) stop 'redistribute_mass'
               if (is_bad_num(remaining_source_mass)) then
                  write(*,4) 'remaining_source_mass', &
                     i, k, k_source, remaining_source_mass, &
                     remaining_needed_mass, cell_dm_k
                  stop 'redistribute_mass'
               end if
               
               !if (k == nz .and. k_source == nz) &
               !   remaining_source_mass = remaining_needed_mass
               
               if (remaining_needed_mass <= remaining_source_mass) then
                  if (dbg) write(*,1) 'remaining_needed_mass <= remaining_source_mass', &
                     remaining_needed_mass, remaining_source_mass
                  diff_dm = remaining_needed_mass
                  dm0 = remaining_source_mass - diff_dm
                  dm1 = remaining_source_mass
                  if (dm0 < 0 .or. dm1 < 0) then
                     write(*,2) 'dm0', k, dm0
                     write(*,2) 'dm1', k, dm1
                     write(*,2) 'diff_dm', k, diff_dm
                     write(*,2) 'remaining_source_mass', k, remaining_source_mass
                     write(*,2) 'remaining_needed_mass', k, remaining_needed_mass
                     stop 'redistribute_mass'
                  end if
                  total_moved = total_moved + remaining_needed_mass

                  if (dbg) then
                     write(*,2) 'cell_dm_k', k, cell_dm_k
                     write(*,2) 'remaining_needed_mass', k, remaining_needed_mass
                     write(*,2) &
                        'cell_dm_k - remaining_needed_mass', k, cell_dm_k - remaining_needed_mass
                     write(*,2) 'sum(X_new)*cell_dm_k', k, &
                        sum(X_new(1:species,k))*cell_dm_k
                  end if
                  do j=1,species
                     if (dbg) write(*,3) 'init X_new(j,k)', j, k, X_new(j,k)
                     dm = integrate_mass(j,dm0,dm1,diff_dm)
                     if (dm < 0d0) then
                        write(*,3) 'dm', j, k, dm
                        write(*,*)
                        stop 'redistribute_mass'
                     end if
                     X_new(j,k) = X_new(j,k) + dm/cell_dm_k
                     if (dbg .and. (X_new(j,k) > 1d0 .or. X_new(j,k) < 0d0)) then
                  
                        write(*,3) 'bad X_new(j,k)', j, k, X_new(j,k)
                        write(*,3) 'dm/cell_dm_k', j, k, dm/cell_dm_k
                        write(*,3) 'dm', j, k, dm
                        write(*,3) 'cell_dm_k', j, k, cell_dm_k
                        write(*,3) 'remaining_needed_mass', j, k, remaining_needed_mass
                        write(*,3) 'diff_dm', j, k, diff_dm
                        write(*,3) 'dm0', j, k, dm0
                        write(*,3) 'dm1', j, k, dm1
                        write(*,3) 'remaining_source_mass', j, k, remaining_source_mass
                        write(*,*)
                        stop 'redistribute_mass'
                     end if
                     remaining_needed_mass = remaining_needed_mass - dm
                  end do
                  if (dbg) write(*,1) 'final remaining_needed_mass/cell_dm_k', &
                     remaining_needed_mass/cell_dm_k, remaining_needed_mass, cell_dm_k
                  remaining_source_mass = dm0
                  remaining_needed_mass = 0d0
                  exit fill_loop
               end if
            
               if (dbg) write(*,1) 'use all remaining source cell mass'
               ! use all remaining source cell mass
               diff_dm = remaining_source_mass
               dm0 = 0d0
               dm1 = diff_dm
               do j=1,species
                  dm = integrate_mass(j,dm0,dm1,diff_dm)
                  X_new(j,k) = X_new(j,k) + dm/cell_dm_k
                  if (dbg .and. X_new(j,k) > 1d0 .or. X_new(j,k) < 0d0) then
                     write(*,3) 'bad X_new(j,k)', j, k, X_new(j,k)
                     write(*,1) 'dm', dm
                     write(*,1) 'dm0', dm0
                     write(*,1) 'dm1', dm1
                     write(*,1) 'remaining_source_mass', remaining_source_mass
                     write(*,1) 'remaining_needed_mass', remaining_needed_mass
                     write(*,1) 'cell_dm_k', cell_dm_k
                     stop 'redistribute_mass'
                  end if
               end do
               if (is_bad_num(remaining_needed_mass)) then
                  write(*,4) 'remaining_needed_mass', i, k, k_source, remaining_needed_mass
                  stop 'redistribute_mass'
               end if
               total_moved = total_moved + remaining_source_mass
               
               remaining_needed_mass = remaining_needed_mass - remaining_source_mass
               if (remaining_needed_mass > cell_dm_k) then
                  write(*,2) 'remaining_source_mass', k, remaining_source_mass
                  write(*,2) 'remaining_needed_mass', k, remaining_needed_mass
                  write(*,2) 'cell_dm_k', k, cell_dm_k
                  stop 'redistribute_mass'
               end if
            
               ! go to next source cell
               k_source = k_source + 1 ! okay to allow k_source > nz; see integrate_mass
               source_cell_mass = sum_mass(min(nz,k_source))
               remaining_source_mass = source_cell_mass
            
               if (source_cell_mass < 0d0 .or. is_bad_num(source_cell_mass)) then
                  ierr = -1
                  return
               
                  write(*,4) 'source_cell_mass', &
                     i, k, k_source, remaining_source_mass, source_cell_mass
                  stop 'redistribute_mass'
               end if
            
            end do fill_loop
            if (dbg) write(*,1) 'finished fill_loop'
         
            if (remaining_needed_mass > 0d0) then
               ierr = -1
               return
            
               write(*,5) 'remaining_needed_mass > 0d0', k, k_source, nz, nz, &
                  remaining_needed_mass
               write(*,1) 'source_cell_mass', source_cell_mass
               write(*,1) 'remaining_source_mass', remaining_source_mass
               stop 'redistribute_mass'
            end if
         
            sumX = sum(X_new(1:species,k))
            if (abs(sumX - 1d0) > 1d-10) then
               write(*,1) 'sum(X(k)) - 1d0', sumX - 1d0
               write(*,1) 'sum mass/source_cell_mass', sum(mass(1:species,k_source))/source_cell_mass
               !write(*,1) 'sum dX_dm k_source', sum(dX_dm(1:species,k_source))
               write(*,2) 'k', k
               write(*,2) 'k_source', k_source
               write(*,2) 'nz', nz
               stop 'redistribute_mass'
            end if
         
            do j=1,species
               X_new(j,k) = X_new(j,k)/sumX
            end do
            !sum_mass(k) = sum_mass(k)/sumX
         
         end do cell_loop
      
         if (dbg) write(*,1) 'finished cell_loop'
      
         total_source = sum(sum_mass(1:nz))
         if (abs(total_moved/total_source - 1d0) > 1d-6) then
            write(*,1) 'total_moved/total_source - 1', total_moved/total_source - 1d0
            write(*,1) 'total_source_0', total_source_0
            write(*,1) 'total_source', total_source
            write(*,1) 'total_moved', total_moved
            write(*,1) 'total_moved - total_source', total_moved - total_source
            write(*,1) 'total_moved - total_source_0', total_moved - total_source_0
            write(*,1) 'total_source - total_source_0', total_source - total_source_0
            write(*,2) 'k_source', k_source
            write(*,2) 'nz', nz
            write(*,2) 'nz', nz
            write(*,*)
            stop 'redistribute_mass'
         end if
      
         ! check cell sums
         okay = .true.
         do k=1,nz
            new_sum = sum(X_new(1:species,k))
            if (abs(new_sum - 1d0) > 1d-14 .or. is_bad_num(new_sum)) then
               write(*,2) 'redistribute_mass: new_sum-1', k, new_sum-1d0
               okay = .false.
            end if
         end do
         if (.not. okay) stop 'redistribute_mass'
      
         ! recheck conservation
         mtotal = sum(s% dm(1:nz))
         okay = .true.         
         do j=1,species
            if (target_avg_x(j) < 1d-20) cycle
            old_sum = mtotal*target_avg_x(j)
            new_sum = dot_product(s% dm(1:nz),X_new(j,1:nz))
            err = (new_sum - old_sum)/max(old_sum,new_sum)
            !write(*,2) 'err', j, err, old_sum, new_sum
            if (abs(err) > X_total_atol .or. is_bad_num(err)) then
               write(*,2) 'redistribute_mass err > atol ' // &
                  trim(chem_isos% name(s% chem_id(j))), &
                  j, err, old_sum, new_sum, target_avg_x(j)
               okay = .false.
            end if
         end do
         if (.not. okay) then
            ierr = -1
            !stop 'redistribute_mass'
         end if

           
         contains
      
      
         real(dp) function integrate_mass(j,dm0,dm1,diff) result(res)
            ! Mass of species j drawn from the current source cell over the
            ! sub-interval [dm0,dm1] of that cell's mass coordinate.
            !
            ! Within the source cell the abundance is treated as linear:
            !    X(m) = X_cell + dX_dm(j,cell)*(m - cell_mass/2)
            ! The result is diff times the mean of X at m=dm0 and m=dm1,
            ! clipped to [0,1]. The callers in this file pass diff = dm1 - dm0.
            !
            ! The source cell is k_source (host variable). When k_source has run
            ! past nz, cell nz is reused with a flat profile (zero slope).
            ! A source cell lighter than tiny_mass contributes nothing.
            ! Stops with 'integrate_mass' if dm0 > dm1.
            integer, intent(in) :: j
            real(dp), intent(in) :: dm0, dm1, diff

            integer :: ks
            real(dp) :: cell_mass, dxdm, xc, mid, x_lo, x_hi, x_mean

            include 'formats'

            if (k_source <= nz) then
               ks = k_source
               dxdm = 0d0
               if (.not. skip_dX_dm) dxdm = dX_dm(j,ks)
            else
               ! past the last cell: reuse cell nz, no gradient
               ks = nz
               dxdm = 0d0
            end if

            cell_mass = sum_mass(ks)
            if (cell_mass < tiny_mass) then
               res = 0d0
               return
            end if

            ! abundance at the two ends of [dm0,dm1], measured from the cell center
            xc = mass(j,ks)/cell_mass
            mid = 0.5d0*cell_mass
            x_lo = xc + dxdm*(dm0 - mid)
            x_hi = xc + dxdm*(dm1 - mid)

            if (dm1 < dm0) then
               write(*,1) 'x0', x_lo
               write(*,1) 'x1', x_hi
               write(*,1) 'dm0', dm0
               write(*,1) 'dm1', dm1
               stop 'integrate_mass'
            end if

            x_mean = 0.5d0*(x_lo + x_hi)
            res = min(1d0, max(0d0, x_mean))*diff

         end function integrate_mass

      
      end subroutine redistribute_mass

      
      subroutine revise_avg_mix_dxdt(s, k, species, avg_mix_dxdt, dt_total, ierr)
         ! Rescale the mixing rates avg_mix_dxdt(1:species,k) of cell k so that
         ! they sum to zero.
         !
         ! Only the entries on the side of the imbalance are touched:
         !  - sum > 0: every positive rate is multiplied by
         !    (sum_pos - sum)/sum_pos, so the positives now total -sum_neg;
         !  - sum < 0: every negative rate is multiplied by
         !    (sum_neg - sum)/sum_neg, so the negatives now total -sum_pos.
         ! Entries on the other side, and zero entries, are left unchanged.
         !
         ! s            star pointer; read only for the failure diagnostics
         ! k            cell index (column of avg_mix_dxdt)
         ! species      number of leading rows of avg_mix_dxdt to use
         ! avg_mix_dxdt rates; column k is modified in place
         ! dt_total     used only in the failure diagnostics
         ! ierr         always set to 0; failures stop with 'solve mix'
         type (star_info), pointer :: s
         integer, intent(in) :: k, species
         real(dp), pointer :: avg_mix_dxdt(:,:)
         real(dp), intent(in) :: dt_total
         integer, intent(out) :: ierr

         ! absolute tolerance on |sum_j avg_mix_dxdt(j,k)| after rescaling
         real(dp), parameter :: sum_tol = 1d-13

         real(dp) :: sum_dxdt, sum_rates

         include 'formats'

         ierr = 0
         ! revise avg_mix_dxdt so sums to 0 for cell k
         sum_dxdt = sum(avg_mix_dxdt(1:species,k))

         if (sum_dxdt > 0d0) then ! scale down the rates that are > 0
            sum_rates = sum(avg_mix_dxdt(1:species,k), &
               mask = avg_mix_dxdt(1:species,k) > 0d0)
            if (sum_rates <= 0d0) then
               write(*,2) 'sum_rates should be > 0', k, sum_rates
               stop 'solve mix'
            end if
            where (avg_mix_dxdt(1:species,k) > 0d0) &
               avg_mix_dxdt(1:species,k) = &
                  avg_mix_dxdt(1:species,k)*(sum_rates - sum_dxdt)/sum_rates
         else if (sum_dxdt < 0d0) then ! scale down the rates that are < 0
            sum_rates = sum(avg_mix_dxdt(1:species,k), &
               mask = avg_mix_dxdt(1:species,k) < 0d0)
            if (sum_rates >= 0d0) then
               write(*,2) 'sum_rates should be < 0', k, sum_rates
               stop 'solve mix'
            end if
            where (avg_mix_dxdt(1:species,k) < 0d0) &
               avg_mix_dxdt(1:species,k) = &
                  avg_mix_dxdt(1:species,k)*(sum_rates - sum_dxdt)/sum_rates
         end if

         sum_dxdt = sum(avg_mix_dxdt(1:species,k))
         ! Written as .not.(<=) so that a NaN sum also fails the check: a plain
         ! "abs(sum_dxdt) > sum_tol" is false for NaN, which used to let NaN
         ! rates (or an Inf rate turned into NaN by the rescale) pass silently.
         if (.not. (abs(sum_dxdt) <= sum_tol)) then
            write(*,2) 'bad sum avg_mix_dxdt', k, sum_dxdt
            write(*,2) 'sum xa', k, sum(s% xa(1:species,k))
            write(*,2) 'sum xa_pre', k, sum(s% xa_pre(1:species,k))
            write(*,2) 'dt_total', k, dt_total
            stop 'solve mix'
         end if

      end subroutine revise_avg_mix_dxdt
            




      end module solve_mix


