! ***********************************************************************
!
!   Copyright (C) 2010  Bill Paxton
!
!   MESA is free software; you can use it and/or modify
!   it under the combined terms and restrictions of the MESA MANIFESTO
!   and the GNU General Library Public License as published
!   by the Free Software Foundation; either version 2 of the License,
!   or (at your option) any later version.
!
!   You should have received a copy of the MESA MANIFESTO along with
!   this software; if not, it is available at the mesa website:
!   http://mesa.sourceforge.net/
!
!   MESA is distributed in the hope that it will be useful,
!   but WITHOUT ANY WARRANTY; without even the implied warranty of
!   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
!   See the GNU Library General Public License for more details.
!
!   You should have received a copy of the GNU Library General Public License
!   along with this software; if not, write to the Free Software
!   Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
!
! ***********************************************************************


      module adjust_mass

      use star_private_def
      use alert_lib
      use const_def
      use chem_def, only: ih1, ihe4, ic12, in14, io16

      implicit none
      
      
      ! Debug switch: when .true., do_adjust_mass and set_xa print extra
      ! diagnostics, and do_adjust_mass stops the run at its end
      ! ('debugging: do_adjust_mass').
      logical, parameter :: dbg_adjm = .false.

      ! When .true., do_adjust_mass checks s% xa and set1_xa checks its running
      ! mass sums for bad numbers (via utils_lib has_bad_num / is_bad_num),
      ! returning early when one is found.
      logical, parameter :: check_for_bad_nums = .true.
      
      
      contains
      
      
      subroutine do_adjust_mass(s, species, ierr)
         ! Change the stellar mass by delta_m for the current step and remap the
         ! model onto the revised mass grid.
         !
         ! Outline:
         !   - choose delta_m = dt*s% mstar_dot. For dt below
         !     s% mass_change_full_on_dt, dt is ramped by alfa**2, and for
         !     dt <= s% mass_change_full_off_dt no change is made. The ramp is
         !     skipped when a supersonic or super-Eddington wind rate is set;
         !     those rates override mstar_dot.
         !   - return early when there is nothing to do or the mass limits are
         !     already reached;
         !   - clamp the new mass to Msun*s% min_star_mass_for_loss /
         !     Msun*s% max_star_mass_for_gain, and derive new_xmstar from the
         !     (possibly clamped) delta_m so every later quantity is consistent;
         !   - revise q/dq, then rebuild s% xa, the lnR/v info, omega (with
         !     rotation) and the lnd/lnPgas/lnE/lnT info.
         !
         ! s       -- star, modified in place
         ! species -- number of species (first dimension of s% xa)
         ! ierr    -- 0 on success, nonzero on failure. Once the work arrays
         !            have been obtained, every failure path returns them.
         use adjust_xyz, only: get_xa_for_accretion
         use utils_lib, only:has_bad_num, is_bad_num
         use star_utils, only: report_xa_bad_nums
         use num_lib, only: safe_log10
         use chem_def
         
         type (star_info), pointer :: s
         integer, intent(in) :: species
         integer, intent(out) :: ierr
         
         real(dp) :: &
            dt, delta_m, old_mstar, new_mstar, &
            env_mass, mmax, alfa, new_xmstar, old_xmstar, &
            old_J_const_mass, old_J_tot, new_J_const_mass, new_J_tot
         
         real(dp), target, dimension(species) :: &
            xtot_old_array, xtot_new_array, dxtot_array, xaccrete_array
         real(dp), pointer, dimension(:) :: xtot_old, xtot_new, dxtot, xaccrete
         
         real(dp), dimension(:), pointer :: &
            xm_old, xm_new, old_cell_mass, new_cell_mass, &
            oldloc, newloc, oldval, newval
         real(dp), dimension(:,:), pointer :: xa_old, work
         
         integer :: k, k_const_mass, nz
         
         logical, parameter :: dbg = .false.
         
         include 'formats.dek'
         
         ierr = 0
         xtot_old => xtot_old_array
         xtot_new => xtot_new_array
         dxtot => dxtot_array
         xaccrete => xaccrete_array
         s% angular_momentum_removed = 0
         
         if (dbg) write(*,*) 'do_adjust_mass'

         ! NOTE: don't assume that vars are all set at this point.
         ! use values from s% xh(:,:) and s% xa(:,:) only.
         ! e.g., don't use s% lnT(:) -- use s% xh(s% i_lnT,:) instead.
         
         s% q_for_recently_added = 1d0 ! change later if are adding mass
         s% k_below_recently_added = 1
         
         nz = s% nz         
         dt = s% dt
         if (s% super_eddington_wind_mdot == 0 .and. s% supersonic_wind_mdot == 0 &
               .and. dt < s% mass_change_full_on_dt) then
            if (dt <= s% mass_change_full_off_dt) then
               s% mstar_dot = 0
               if (dbg) write(*,1) 'do_adjust_mass return 1', dt, s% mass_change_full_off_dt
               return
            end if
            ! ramp the effective dt between the "full off" and "full on" limits
            alfa = (dt - s% mass_change_full_off_dt)/ &
                     (s% mass_change_full_on_dt - s% mass_change_full_off_dt)
            dt = dt*alfa**2
         end if
         
         delta_m = dt*s% mstar_dot

         if (s% supersonic_wind_ejection_mass /= 0 .and. s% supersonic_wind_mdot /= 0) then
            delta_m = -min(s% supersonic_wind_ejection_mass, dt*s% supersonic_wind_mdot)
            s% mstar_dot = delta_m/dt
         else if (s% super_eddington_wind_mdot /= 0) then
            s% mstar_dot = -s% super_eddington_wind_mdot
            delta_m = dt*s% mstar_dot
         else if (delta_m == 0 &
            .or. (delta_m < 0 .and. s% star_mass <= s% min_star_mass_for_loss) &
            .or. (delta_m > 0 .and. s% max_star_mass_for_gain > 0 &
                  .and. s% star_mass >= s% max_star_mass_for_gain)) then
            if (dbg) write(*,*) 'do_adjust_mass return 2'
            return
         end if
         
         old_mstar = s% mstar
         old_xmstar = s% xmstar
         
         ! Apply the mass limits first, so that delta_m, new_mstar and
         ! new_xmstar all describe the same (possibly clamped) mass change.
         ! The grid layout (xm_old/xm_new) below is built from delta_m, and
         ! the new cell masses from new_xmstar.
         new_mstar = old_mstar + delta_m
         if (delta_m > 0 .and. s% max_star_mass_for_gain > 0 &
               .and. new_mstar > Msun*s% max_star_mass_for_gain) then
            new_mstar = Msun*s% max_star_mass_for_gain
            delta_m = new_mstar - old_mstar
         else if (delta_m < 0 .and. new_mstar < Msun*s% min_star_mass_for_loss) then
            new_mstar = Msun*s% min_star_mass_for_loss
            delta_m = new_mstar - old_mstar
         end if
         new_xmstar = old_xmstar + delta_m
         
         if (delta_m > 0) then
            ! first cell (from the surface) with q <= q_for_recently_added
            s% q_for_recently_added = old_xmstar/(old_xmstar+s% factor_for_recently_added*delta_m)
            s% k_below_recently_added = nz
            do k = 1, nz
               if (s% q(k) <= s% q_for_recently_added) then
                  s% k_below_recently_added = k; exit
               end if
            end do
         end if
         
         if (s% show_info_for_recently_added .and. s% k_below_recently_added > 1) &
            write(*,2) 'recently_added: k_below, log 1-q', &
               s% k_below_recently_added, safe_log10(1d0 - s% q_for_recently_added)
         
         if (new_xmstar <= 0) then
            ierr = -1
            return
         end if
         s% xmstar = new_xmstar
         s% mstar = s% xmstar + s% M_center

         if (check_for_bad_nums) then
            if (has_bad_num(species*nz, s% xa)) then
               write(*, *) 'bad num in xa at start of change_mass: model_number', s% model_number
               call report_xa_bad_nums(s, ierr)
               return
            end if
         end if
         
         if (dbg_adjm) then
            env_mass = old_mstar-s% h1_boundary_mass*Msun
            write(*,'(a40,f26.16)') 'env_mass/old_mstar', env_mass/old_mstar
            write(*,*)
            write(*,1) 'delta_m/old_mstar', delta_m/old_mstar
            write(*,1) 's% h1_boundary_mass*Msun', s% h1_boundary_mass*Msun
            write(*,1) 'env_mass', env_mass
            write(*,1) 'delta_m/env_mass', delta_m/env_mass
            write(*,1) 'log10(abs(delta_m/env_mass))', safe_log10(abs(delta_m/env_mass))
            write(*,*)
         end if

         call do_alloc(ierr)
         if (ierr /= 0) return
         
         old_cell_mass(1:nz) = old_xmstar*s% dq(1:nz)
         
         if (delta_m > 0) then
            s% angular_momentum_removed = 0
         else
            s% angular_momentum_removed = angular_momentum_removed(ierr)
            if (ierr /= 0) then
               call dealloc ! work arrays are already held at this point
               return
            end if
         end if
         
         ! NOTE(review): with s% adjust_mass_const_q, revise_q_and_dq returns
         ! k_const_mass = nz+1, and several loops below then run up to
         ! k_const_mass -- confirm those accesses stay in bounds.
         call revise_q_and_dq(s, nz, old_xmstar, new_xmstar, k_const_mass, ierr)
         if (ierr /= 0) then
            if (s% report_ierr) write(*, *) 'revise_q_and_dq failed in adjust mass'
            call dealloc
            return
         end if

         new_cell_mass(1:nz) = new_xmstar*s% dq(1:nz)
         xa_old(:,1:nz) = s% xa(:,1:nz) ! save the old abundances
         
         if (check_for_bad_nums) then
            if (has_bad_num(species*nz, xa_old)) then
               write(*, *) 'bad num in xa_old in adjust mass: model_number', s% model_number
               call report_xa_bad_nums(s, ierr)
               call dealloc
               return
            end if
         end if

         if (delta_m < 0) then
            xaccrete(1:species) = 0 ! xaccrete not used when removing mass
         else ! set xaccrete for composition of added material
            if (s% accrete_same_as_surface) then
               xaccrete(1:species) = xa_old(1:species,1)
            else
               call get_xa_for_accretion(s, xaccrete, ierr)
               if (ierr /= 0) then
                  if (s% report_ierr) write(*, *) 'get_xa_for_accretion failed in adjust mass'
                  call dealloc
                  return
               end if               
            end if
         end if
         
         if (check_for_bad_nums) then
            if (has_bad_num(species*nz, s% xa)) then
               write(*, *) 'bad num in xa before call set_xa: model_number', s% model_number
               call report_xa_bad_nums(s, ierr)
               call dealloc
               return
            end if
         end if
         
         ! Outer-boundary mass coordinates of the old and new cells. Both are
         ! measured from the outermost surface of the two models: the old one
         ! when losing mass, the new one when gaining mass.
         mmax = max(old_mstar, new_mstar)
         if (delta_m < 0) then
            xm_old(1) = 0
            xm_new(1) = -delta_m ! note that xm_new(1) > 0 since delta_m < 0
         else
            xm_old(1) = delta_m
            xm_new(1) = 0
         end if
         do k = 2, nz
            xm_old(k) = xm_old(k-1) + old_cell_mass(k-1)
            if (k >= k_const_mass) then
               xm_new(k) = xm_old(k)
               new_cell_mass(k) = old_cell_mass(k)
            else
               xm_new(k) = xm_new(k-1) + new_cell_mass(k-1)
            end if
         end do

         ! used only by the (disabled) angular momentum diagnostics below
         old_J_const_mass = eval_total_angular_momentum(s, old_cell_mass, k_const_mass)
         old_J_tot = eval_total_angular_momentum(s, old_cell_mass, nz)
         
         call set_xa(s, nz, k_const_mass, species, xa_old, xaccrete, &
            xm_old, xm_new, mmax, old_cell_mass, new_cell_mass, ierr)  
         if (ierr /= 0 .and. s% report_ierr) write(*, *) 'set_xa failed in adjust mass' 
            
         if (ierr == 0) then     
            call set_lnR_v_info( &
               s, nz, k_const_mass, xm_old, xm_new, delta_m, old_xmstar, new_xmstar, &
               oldloc, newloc, oldval, newval, work, ierr)   
            if (ierr /= 0 .and. s% report_ierr) write(*, *) 'set_lnR_v_info failed in adjust mass' 
         end if
            
         if (ierr == 0 .and. s% rotation_flag) then     
            call set_omega( &
               s, nz, k_const_mass, &
               xm_old, xm_new, mmax, old_cell_mass, new_cell_mass, ierr)
            if (ierr /= 0 .and. s% report_ierr) write(*, *) 'set_omega failed in adjust mass' 
         end if
            
         if (ierr == 0) then     
            call set_lnd_lnT_info(s, nz, k_const_mass, &
               xm_old, xm_new, old_cell_mass, new_cell_mass, delta_m, old_xmstar, new_xmstar, &
               oldloc, newloc, oldval, newval, work, ierr)   
            if (ierr /= 0 .and. s% report_ierr) write(*, *) 'set_lnd_lnT_info failed in adjust mass' 
         end if
         
         if (check_for_bad_nums) then
            if (has_bad_num(species*nz, s% xa)) then
               write(*, *) 'bad num in xa at end of change_mass: model_number', s% model_number
               call report_xa_bad_nums(s, ierr)
               call dealloc
               return
            end if
         end if
         
         if (s% rotation_flag) then         
            new_J_const_mass = eval_total_angular_momentum(s, new_cell_mass, k_const_mass)
            new_J_tot = eval_total_angular_momentum(s, new_cell_mass, nz)
            if (.false.) write(*,3) 'adjust_mass J_const_mass interp err', &
               s% model_number, k_const_mass, &
               (new_J_const_mass + s% angular_momentum_removed - old_J_const_mass)/new_J_const_mass, &
               new_J_const_mass, old_J_const_mass, &
               old_J_const_mass - s% angular_momentum_removed, s% angular_momentum_removed
            if (.false.) write(*,2) 'adjust_mass J_tot interp err', &
               s% model_number, (new_J_tot + s% angular_momentum_removed - old_J_tot)/new_J_tot, &
               new_J_tot, old_J_tot - s% angular_momentum_removed
            !stop 'do_adjust_mass'
         end if
         
         call dealloc
         
         if (dbg_adjm) stop 'debugging: do_adjust_mass'
         if (dbg) write(*,*) 'do_adjust_mass return'

         
         contains
         
         
         real(dp) function angular_momentum_removed(ierr) result(J)
            ! Sum of dm*s% j_rot(k) over the outermost -delta_m of mass (the
            ! mass being removed), using boundary weights 0.5*(dm(k-1)+dm(k)).
            ! The last contributing weight is trimmed so the weights total
            ! exactly -delta_m. Returns 0 when rotation is off.
            ! Must be called while s% j_rot still refers to the old mass grid.
            integer, intent(out) :: ierr
            integer :: k
            real(dp) :: dmm1, dm00, dm, dm_sum, dm_lost
            include 'formats.dek'
            ierr = 0
            J = 0
            if (.not. s% rotation_flag) return
            dm00 = 0
            dm_sum = 0
            dm_lost = -delta_m
            do k = 1, nz
               dmm1 = dm00
               dm00 = old_cell_mass(k)
               dm = 0.5d0*(dmm1+dm00)
               if (dm_sum + dm > dm_lost) then
                  dm = dm_lost - dm_sum
                  dm_sum = dm_lost
               else
                  dm_sum = dm_sum + dm
               end if
               J = J + dm*s% j_rot(k)
               if (dm_sum == dm_lost) exit
            end do
         end function angular_momentum_removed
         
         
         real(dp) function eval_total_angular_momentum(s,cell_mass,nz_last) result(J)
            ! Sum of dm*s% j_rot(k) for k = 1..nz_last, with boundary weights
            ! dm = 0.5*(cell_mass(k-1)+cell_mass(k)). At the innermost cell
            ! (k == s% nz) the weight is 0.5*cell_mass(k-1) + cell_mass(k); at
            ! k == nz_last < s% nz it is only 0.5*cell_mass(k-1).
            ! Returns 0 when rotation is off.
            type (star_info), pointer :: s
            real(dp) :: cell_mass(:)
            integer, intent(in) :: nz_last
            integer :: k
            real(dp) :: dmm1, dm00, dm
            include 'formats.dek'
            J = 0
            if (.not. s% rotation_flag) return
            dm00 = 0
            do k = 1, nz_last
               dmm1 = dm00
               dm00 = cell_mass(k)
               if (k == s% nz) then
                  dm = 0.5d0*dmm1+dm00
               else if (k == nz_last) then
                  dm = 0.5d0*dmm1
               else
                  dm = 0.5d0*(dmm1+dm00)
               end if
               J = J + dm*s% j_rot(k)
            end do
         end function eval_total_angular_momentum
         
         
         subroutine do_alloc(ierr)
            ! Obtain the work arrays used by do_adjust_mass; stops at the first failure.
            use interp_1d_def
            use alloc
            integer, intent(out) :: ierr
            ierr = 0            
            call get_work_array(s, xm_old, nz, nz_alloc_extra, 'adjust_mass xm_old', ierr)
            if (ierr /= 0) return            
            call get_work_array(s, xm_new, nz, nz_alloc_extra, 'adjust_mass xm_new', ierr)
            if (ierr /= 0) return            
            call get_work_array(s, old_cell_mass, nz, nz_alloc_extra, 'adjust_mass old_cell_mass', ierr)
            if (ierr /= 0) return            
            call get_work_array(s, new_cell_mass, nz, nz_alloc_extra, 'adjust_mass new_cell_mass', ierr)
            if (ierr /= 0) return            
            call get_2d_work_array(s, xa_old, species, nz, nz_alloc_extra, 'adjust_mass xa_old', ierr)
            if (ierr /= 0) return            
            call get_work_array(s, oldloc, nz, nz_alloc_extra, 'adjust_mass oldloc', ierr)
            if (ierr /= 0) return            
            call get_work_array(s, newloc, nz, nz_alloc_extra, 'adjust_mass newloc', ierr)
            if (ierr /= 0) return            
            call get_work_array(s, oldval, nz, nz_alloc_extra, 'adjust_mass oldval', ierr)
            if (ierr /= 0) return            
            call get_work_array(s, newval, nz, nz_alloc_extra, 'adjust_mass newval', ierr)
            if (ierr /= 0) return            
            call get_2d_work_array(s, work, nz, pm_work_size, nz_alloc_extra, 'adjust_mass work', ierr)
            if (ierr /= 0) return            
         end subroutine do_alloc
         
         
         subroutine dealloc
            ! Return every work array obtained by do_alloc.
            use alloc            
            call return_work_array(s, xm_old, 'adjust_mass xm_old')            
            call return_work_array(s, xm_new, 'adjust_mass xm_new')            
            call return_work_array(s, old_cell_mass, 'adjust_mass old_cell_mass')            
            call return_work_array(s, new_cell_mass, 'adjust_mass new_cell_mass')            
            call return_2d_work_array(s, xa_old, 'adjust_mass xa_old')            
            call return_work_array(s, oldloc, 'adjust_mass oldloc')            
            call return_work_array(s, newloc, 'adjust_mass newloc')            
            call return_work_array(s, oldval, 'adjust_mass oldval')            
            call return_work_array(s, newval, 'adjust_mass newval')            
            call return_2d_work_array(s, work, 'adjust_mass work')                  
         end subroutine dealloc
         
         
      end subroutine do_adjust_mass


      subroutine set_lnd_lnT_info(s, nz, k_const_mass, &
               xm_old, xm_new, old_cell_mass, new_cell_mass, &
               delta_m, old_xmstar, new_xmstar, &
               oldloc, newloc, oldval, newval, work, ierr)
         ! Remap the cell-centred rows of s% xh onto the new mass grid.
         ! The rows are xlnd, lnPgas and lnE (each only when its index in s is
         ! nonzero) and always lnT. Results go into s% lnd_for_d_dt (minus
         ! lnd_offset), s% lnPgas_for_d_dt, s% lnE_var_for_d_dt and
         ! s% lnT_for_d_dt.
         !
         ! Pass 1 covers cells 1..k_const_mass. It interpolates in the cell-centre
         ! coordinate xm + 0.5*cell_mass, with the outermost old point pinned
         ! to 0.
         ! Pass 2 runs only if s% k_below_recently_added > 1. It overwrites
         ! cells 1..k_below_recently_added-1 by interpolating in the fractional
         ! coordinate (xm + 0.5*cell_mass)/xmstar, with the old coordinate
         ! shifted by delta_m.
         ! ierr is nonzero, and the routine returns early, if an interpolation
         ! fails.
         use interp_1d_lib
         use interp_1d_def
         type (star_info), pointer :: s
         integer, intent(in) :: nz, k_const_mass
         real(dp), dimension(nz), intent(in) :: &
            xm_old, xm_new, old_cell_mass, new_cell_mass
         real(dp), intent(in) :: delta_m, old_xmstar, new_xmstar
         real(dp), pointer, dimension(:) :: oldloc, newloc, oldval, newval
         real(dp), pointer :: work(:,:)
         integer, intent(out) :: ierr         
         
         integer :: npts, nwork
         
         include 'formats.dek'
         
         ierr = 0
         nwork = pm_work_size
         
         ! ---- pass 1: absolute mass coordinate, cells 1..k_const_mass ----
         npts = k_const_mass
         oldloc(1) = 0
         oldloc(2:npts) = xm_old(2:npts) + 0.5d0*old_cell_mass(2:npts)
         newloc(1:npts) = xm_new(1:npts) + 0.5d0*new_cell_mass(1:npts)
         
         if (s% i_xlnd /= 0) then
            call remap_row(s% i_xlnd)
            if (ierr /= 0) return
            s% lnd_for_d_dt(1:npts) = newval(1:npts) - lnd_offset
         end if
         
         if (s% i_lnPgas /= 0) then
            call remap_row(s% i_lnPgas)
            if (ierr /= 0) return
            s% lnPgas_for_d_dt(1:npts) = newval(1:npts)
         end if
         
         if (s% i_lnE /= 0) then
            call remap_row(s% i_lnE)
            if (ierr /= 0) return
            s% lnE_var_for_d_dt(1:npts) = newval(1:npts)
         end if

         call remap_row(s% i_lnT)
         if (ierr /= 0) return
         s% lnT_for_d_dt(1:npts) = newval(1:npts)
         
         ! nothing more to do unless material was recently added at the surface
         if (s% k_below_recently_added <= 1) return 
         
         ! ---- pass 2: fractional mass coordinate for the new material ----
         npts = s% k_below_recently_added
         oldloc(1:npts) = (xm_old(1:npts) + 0.5d0*old_cell_mass(1:npts) - delta_m)/old_xmstar
         newloc(1:npts) = (xm_new(1:npts) + 0.5d0*new_cell_mass(1:npts))/new_xmstar
         
         ! only the cells above k_below_recently_added are overwritten
         if (s% i_xlnd /= 0) then
            call remap_row(s% i_xlnd)
            if (ierr /= 0) return
            s% lnd_for_d_dt(1:npts-1) = newval(1:npts-1) - lnd_offset
         end if
         
         if (s% i_lnPgas /= 0) then
            call remap_row(s% i_lnPgas)
            if (ierr /= 0) return
            s% lnPgas_for_d_dt(1:npts-1) = newval(1:npts-1)
         end if
         
         if (s% i_lnE /= 0) then
            call remap_row(s% i_lnE)
            if (ierr /= 0) return
            s% lnE_var_for_d_dt(1:npts-1) = newval(1:npts-1)
         end if
         
         call remap_row(s% i_lnT)
         if (ierr /= 0) return
         s% lnT_for_d_dt(1:npts-1) = newval(1:npts-1)

         contains

         subroutine remap_row(i_var)
            ! Copy row i_var of s% xh (cells 1..npts) into oldval, then
            ! interpolate from oldloc onto newloc. The result is left in
            ! newval and the status in ierr.
            integer, intent(in) :: i_var
            oldval(1:npts) = s% xh(i_var,1:npts)
            call interpolate_vector( &
               npts, oldloc, npts, newloc, oldval, newval, interp_pm, nwork, work, ierr)
         end subroutine remap_row

      end subroutine set_lnd_lnT_info
      
      
      subroutine set_lnR_v_info( &
            s, nz, k_const_mass, xm_old, xm_new, &
            delta_m, old_xmstar, new_xmstar, &
            oldloc, newloc, oldval, newval, work, ierr)
         ! Remap s% lnR_for_d_dt, and s% v_for_d_dt when s% v_flag is set,
         ! onto the new mass grid using the cell boundaries xm_old/xm_new.
         !
         ! Pass 1 covers cells 1..k_const_mass. It interpolates in the mass
         ! coordinate xm, with the outermost old point pinned to 0, taking the
         ! current lnR_for_d_dt / v_for_d_dt values as the source.
         ! Pass 2 runs only if s% k_below_recently_added > 1. It overwrites
         ! cells 1..k_below_recently_added-1 by interpolating in the fractional
         ! coordinate xm/xmstar, taking rows i_lnR / i_vel of s% xh as the
         ! source.
         ! NOTE(review): the two passes read from different sources
         ! (*_for_d_dt vs s% xh) -- confirm this is intended.
         ! ierr is nonzero, and the routine returns early, if an interpolation
         ! fails.
         use interp_1d_lib
         use interp_1d_def
         type (star_info), pointer :: s
         integer, intent(in) :: nz, k_const_mass
         real(dp), dimension(nz), intent(in) :: xm_old, xm_new
         real(dp), intent(in) :: delta_m, old_xmstar, new_xmstar
         real(dp), pointer, dimension(:) :: oldloc, newloc, oldval, newval
         real(dp), pointer :: work(:,:)
         integer, intent(out) :: ierr         
         
         integer :: npts, nwork
         
         include 'formats.dek'
         
         ierr = 0
         nwork = pm_work_size
         
         ! ---- pass 1: absolute mass coordinate, cells 1..k_const_mass ----
         npts = k_const_mass
         oldloc(1) = 0
         oldloc(2:npts) = xm_old(2:npts)
         newloc(1:npts) = xm_new(1:npts)
         
         call remap(s% lnR_for_d_dt(1:npts))
         if (ierr /= 0) return
         s% lnR_for_d_dt(1:npts) = newval(1:npts)
         
         if (s% v_flag) then
            call remap(s% v_for_d_dt(1:npts))
            if (ierr /= 0) return
            s% v_for_d_dt(1:npts) = newval(1:npts)
         end if
         
         ! nothing more to do unless material was recently added at the surface
         if (s% k_below_recently_added <= 1) return
         
         ! ---- pass 2: fractional mass coordinate for the new material ----
         npts = s% k_below_recently_added
         oldloc(1:npts) = (xm_old(1:npts) - delta_m)/old_xmstar
         newloc(1:npts) = xm_new(1:npts)/new_xmstar

         call remap(s% xh(s% i_lnR,1:npts))
         if (ierr /= 0) return
         s% lnR_for_d_dt(1:npts-1) = newval(1:npts-1)
         
         if (s% v_flag) then
            call remap(s% xh(s% i_vel,1:npts))
            if (ierr /= 0) return
            s% v_for_d_dt(1:npts-1) = newval(1:npts-1)
         end if

         contains

         subroutine remap(src)
            ! Copy src(1:npts) into oldval, then interpolate from oldloc onto
            ! newloc. The result is left in newval and the status in ierr.
            real(dp), intent(in) :: src(:)
            oldval(1:npts) = src(1:npts)
            call interpolate_vector( &
               npts, oldloc, npts, newloc, oldval, newval, interp_pm, nwork, work, ierr)
         end subroutine remap

      end subroutine set_lnR_v_info

      
      subroutine set_xa( &
            s, nz, k_const_mass, species, xa_old, xaccrete, &
            old_cell_xbdy, new_cell_xbdy, mmax, old_cell_mass, new_cell_mass, ierr)
         ! Set new values for s% xa(:,:) after the mass grid has been revised.
         ! Cells k > k_const_mass keep their old composition, since their mass
         ! is unchanged. Cells 1..k_const_mass are rebuilt one at a time by
         ! set1_xa, in parallel under OpenMP.
         ! ierr is 0 on success; otherwise it holds the nonzero status returned
         ! by one of the failing cells.
         type (star_info), pointer :: s
         integer, intent(in) :: nz, k_const_mass, species
         real(dp), intent(in) :: xa_old(species, nz), xaccrete(species), mmax
         real(dp), dimension(nz), intent(in) :: &
            old_cell_xbdy, new_cell_xbdy, old_cell_mass, new_cell_mass
         integer, intent(out) :: ierr         
         integer :: k, op_err
         include 'formats.dek'       
         ierr = 0
         if (dbg_adjm) write(*,2) 'set_xa: k_const_mass', k_const_mass
         if (k_const_mass < nz) then
            ! for k >= k_const_mass have m_new(k) = m_old(k),
            ! so no change in xa_new(:,k) for k > k_const_mass
            s% xa(:,k_const_mass+1:nz) = xa_old(:,k_const_mass+1:nz)
         end if
!$OMP PARALLEL DO PRIVATE(k, op_err)
         do k = 1, k_const_mass
            op_err = 0
            call set1_xa(s, k, nz, species, xa_old, xaccrete, &
               old_cell_xbdy, new_cell_xbdy, mmax, old_cell_mass, new_cell_mass, op_err)
            if (op_err /= 0) then
               ! ierr is shared between threads; serialize the (rare) update
!$OMP CRITICAL (adjust_mass_set_xa_ierr)
               ierr = op_err
!$OMP END CRITICAL (adjust_mass_set_xa_ierr)
            end if
         end do
!$OMP END PARALLEL DO
      end subroutine set_xa

      
      subroutine set1_xa(s, k, nz, species, xa_old, xaccrete, &
            old_cell_xbdy, new_cell_xbdy, mmax, old_cell_mass, new_cell_mass, ierr)
         ! Set new values for s% xa(:,k) as the mass-weighted mean composition
         ! of the material inside new cell k. That material is:
         !   - accreted material, with composition xaccrete, lying outside
         !     old_cell_xbdy(1);
         !   - the overlapping parts of old cells kk, with composition
         !     xa_old(:,kk).
         !
         ! old_cell_xbdy / new_cell_xbdy are the cumulative masses from the
         ! surface to the outer boundary of each old / new cell. old_cell_mass /
         ! new_cell_mass are the cell masses. The innermost cell (nz) extends
         ! down to mmax - s% M_center.
         ! The result is renormalized so that sum(s% xa(:,k)) = 1.
         ! ierr is set to -1 if the boundary search or the mass bookkeeping
         ! fails, or if a bad number appears in the running sums.
         use num_lib, only: binary_search
         use utils_lib, only: is_bad_num
         use chem_def, only: chem_isos
         type (star_info), pointer :: s
         integer, intent(in) :: k, nz, species
         real(dp), intent(in) :: xa_old(species, nz), xaccrete(species), mmax
         real(dp), dimension(nz), intent(in) :: &
            old_cell_xbdy, new_cell_xbdy, old_cell_mass, new_cell_mass
         integer, intent(out) :: ierr
         
         real(dp) :: xm_outer, xm_inner, msum(species), xm0, xm1, new_cell_dm, dm_sum, dm
         integer :: kk, k_outer, j
         
         integer, parameter :: k_dbg = -1
         logical, parameter :: xa_dbg = .false.
         
         logical, parameter :: do_not_mix_accretion = .false.
         
         include 'formats.dek'
         
         ierr = 0
         msum(:) = -1 ! for testing
                  
         xm_outer = new_cell_xbdy(k)
         if (k == nz) then
            new_cell_dm = mmax - xm_outer - s% M_center
         else
            new_cell_dm = new_cell_mass(k)
         end if
         xm_inner = xm_outer + new_cell_dm
         
         ! dm_sum = mass of new cell k accounted for so far
         dm_sum = 0d0
         
         if (xm_outer < old_cell_xbdy(1)) then ! there is some accreted material in new cell
            if (do_not_mix_accretion .or. xm_inner <= old_cell_xbdy(1)) then 
               ! new cell is entirely accreted material
               s% xa(:,k) = xaccrete(:)
               return
            end if
            dm = min(new_cell_dm, old_cell_xbdy(1) - xm_outer)
            dm_sum = dm
            msum(:) = xaccrete(:)*dm
            xm_outer = old_cell_xbdy(1)
            k_outer = 1
         else ! new cell entirely composed of old material
            msum(:) = 0
            if (xm_outer >= old_cell_xbdy(nz)) then
               ! new cell contained entirely in old cell nz
               k_outer = nz
            else
               ! binary search for k_outer such that
               ! xm_outer >= old_cell_xbdy(k_outer)
               ! and old_cell_xbdy(k_outer+1) > xm_outer
               k_outer = binary_search(nz, 0, old_cell_xbdy, xm_outer)
               
               ! check
               if (k_outer <= 0 .or. k_outer > nz) then

                  ierr = -1
                  if (.not. xa_dbg) return

                  write(*,2) 'k', k
                  write(*,2) 'k_outer', k_outer
                  write(*,1) 'xm_outer', xm_outer
                  write(*,2) 'old_cell_xbdy(1)', 1, old_cell_xbdy(1)
                  write(*,2) 'old_cell_xbdy(nz)', nz, old_cell_xbdy(nz)
                  stop 'debugging: set1_xa'
               end if
               
               if (xm_outer < old_cell_xbdy(k_outer)) then

                  ierr = -1
                  if (.not. xa_dbg) return

                  write(*,*) 'k', k
                  write(*,*) 'k_outer', k_outer
                  write(*,1) 'xm_outer', xm_outer
                  write(*,1) 'old_cell_xbdy(k_outer)', old_cell_xbdy(k_outer)
                  write(*,*) '(xm_outer < old_cell_xbdy(k_outer))'
                  stop 'debugging: set1_xa'
               end if
               
               if (k_outer < nz) then
                  if (old_cell_xbdy(k_outer+1) <= xm_outer) then

                     ierr = -1
                     if (.not. xa_dbg) return

                     write(*,*) 'k', k
                     write(*,*) 'k_outer', k_outer
                     write(*,1) 'xm_outer', xm_outer
                     write(*,1) 'old_cell_xbdy(k_outer+1)', old_cell_xbdy(k_outer+1)
                     write(*,*) '(old_cell_xbdy(k_outer+1) <= xm_outer)'
                     stop 'debugging: set1_xa'
                  end if
               end if
               
            end if
         end if
         
         if (xa_dbg .and. k == k_dbg) then ! optional trace for a single cell
            write(*,2) 'nz', nz
            write(*,2) 'k_outer', k_outer
            write(*,1) 'xm_outer', xm_outer
            write(*,1) 'xm_inner', xm_inner
         end if
         
         if (check_for_bad_nums) then
            do j=1,species
               if (is_bad_num(msum(j))) then
                  write(*,*) 'set1_xa', k, j, msum(j)
                  ierr = -1
                  return
               end if
            end do
         end if

         do kk = k_outer, nz ! loop until reach m_inner
            xm0 = old_cell_xbdy(kk)
            
            if (xm0 >= xm_inner) then
               if (dm_sum < new_cell_dm .and. kk > 1) then 
                  ! need to add a bit more from the previous source cell
                  dm = new_cell_dm - dm_sum
                  dm_sum = new_cell_dm
                  msum(:) = msum(:) + xa_old(:,kk-1)*dm
               end if
               exit
            end if
            
            if (kk == nz) then
               xm1 = mmax - s% M_center
            else
               xm1 = old_cell_xbdy(kk+1)
            end if
            
            if (xm1 < xm_outer) then
               ierr = -1
               if (.not. xa_dbg) return
               write(*,*)
               write(*,*) 'k', k
               write(*,*) 'kk', kk
               write(*,1) 'xm1', xm1
               write(*,1) 'xm_outer', xm_outer
               write(*,*) 'xm1 < xm_outer'
               stop 'debugging: set1_xa'
            end if
            
            if (xm0 >= xm_outer .and. xm1 <= xm_inner) then ! entire old cell kk is in new cell k
               
               dm = old_cell_mass(kk)
               dm_sum = dm_sum + dm
               
               if (dm_sum > new_cell_dm) then 
                  ! dm too large -- numerical roundoff problems:
                  ! remove the excess (dm_sum - new_cell_dm) from this contribution
                  dm = dm - (dm_sum - new_cell_dm)
                  dm_sum = new_cell_dm
               end if
               
               msum(:) = msum(:) + xa_old(:,kk)*dm
               
            else if (xm0 <= xm_outer .and. xm1 >= xm_inner) then ! entire new cell k is in old cell kk
            
               ! whatever part of new cell k is still unaccounted for comes from kk:
               ! all of it, unless accreted material was already added above.
               ! new_cell_dm (not new_cell_mass(k)) also handles k == nz.
               dm = new_cell_dm - dm_sum
               dm_sum = new_cell_dm
               msum(:) = msum(:) + xa_old(:,kk)*dm
               
            else ! only use the part of old cell kk that is in new cell k
            
               if (xm_inner <= xm1) then ! this is the last part of new cell k
               
                  dm = new_cell_dm - dm_sum
                  dm_sum = new_cell_dm

               else ! notice that we avoid this case if possible because of numerical roundoff
               
                  dm = max(0d0, xm1 - xm_outer)
                  if (dm_sum + dm > new_cell_dm) dm = new_cell_dm - dm_sum
                  dm_sum = dm_sum + dm

               end if
               
               msum(:) = msum(:) + xa_old(:,kk)*dm
               
               if (dm <= 0) then
                  ierr = -1
                  if (.not. xa_dbg) return
                  write(*,*) 'dm <= 0', dm
                  stop 'debugging: set1_xa'
               end if
               
            end if
            
            if (check_for_bad_nums) then
               do j=1,species
                  if (is_bad_num(msum(j))) then
                     ierr = -1
                     if (.not. xa_dbg) return
                     
                     write(*,*) 'set1_xa', k, j, msum(j)
                     write(*,*) 'kk', kk
                     write(*,*) 'xa_old(j,kk)', xa_old(j,kk)
                     stop 'debugging: failed in adjust mass'
                  end if
               end do
            end if
            
            if (dm_sum >= new_cell_dm) then
               exit
            end if
            
         end do

         ! revise and renormalize
         s% xa(:,k) = msum(:) / new_cell_mass(k)
         s% xa(:,k) = s% xa(:,k)/sum(s% xa(:,k))
               
      end subroutine set1_xa
            
      
      subroutine revise_q_and_dq(s, nz, old_xmstar, new_xmstar, k_const_mass, ierr)
         ! Update s% q and s% dq after the mass scale changes from old_xmstar
         ! to new_xmstar.
         !
         ! Cells kB..nz keep their absolute mass coordinates: their q's are
         ! multiplied by frac = old_xmstar/new_xmstar (so q*xmstar is unchanged)
         ! and their dq's are rebuilt from the new q's.  The change is absorbed
         ! by cells 1..kB-1: dq(kA:kB-1) is scaled by qfrac, then dq(1:kB-1) is
         ! renormalized so that sum(dq(1:nz)) = 1, and q(1:kB-1) is rebuilt
         ! downward from q(1) = 1.
         !
         ! kA = first cell (from the surface) with lnT >= min(lnTmax, lnTlim_A)
         !      or q < qlim_A;
         ! kB = first cell at or below kA with lnT >= min(lnTmax, lnTlim_B)
         !      or q < qlim_B, then adjusted by limit_dqAB_old and the qfrac loop.
         !
         ! k_const_mass (out): kB, the outermost cell whose mass coordinate is
         !                     preserved; nz+1 when s% adjust_mass_const_q is set
         !                     (in which case q and dq are left untouched).
         ! ierr (out): always 0 in this routine.
         type (star_info), pointer :: s
         integer, intent(in) :: nz
         real(dp), intent(in) :: old_xmstar, new_xmstar
         integer, intent(out) :: k_const_mass, ierr
         
         integer :: k, kA, kB, i_lnT, j00, jp1
         real(dp) :: lnTlim_A, lnTlim_B, qlim_A, qlim_B
         real(dp) :: frac, lnTmax, lnT_A, lnT_B, qA, qB_old, qB_new, dqAB_old, qfrac
         real(dp) :: dq_sum
         
         logical :: dbg
         logical :: okay_to_move_kB_inward

         include 'formats.dek'
         
         ierr = 0
         dbg = .false.
         
         if (s% adjust_mass_const_q) then
            k_const_mass = nz+1
            return
         end if
         
         okay_to_move_kB_inward = .false.
         
         lnTlim_A = s% adjust_mass_lnTlim_A  ! e.g., log(1d5)
         lnTlim_B = s% adjust_mass_lnTlim_B  ! e.g., log(1d6)
         
         qlim_A = s% adjust_mass_qlim_A ! 0.99
         qlim_B = s% adjust_mass_qlim_B ! 0.95
         
         frac = old_xmstar / new_xmstar
         i_lnT = s% i_lnT
         
         lnTmax = maxval(s% xh(i_lnT,1:nz))
         
         ! locate kA: outer edge of the region whose dq's get scaled by qfrac
         kA = 0
         lnT_A = min(lnTmax, lnTlim_A)
         do k = 1, nz
            if (s% xh(i_lnT,k) >= lnT_A .or. s% q(k) < qlim_A) then
               kA = k; exit
            end if
         end do
         if (kA == 0) kA = 1
         qA = s% q(kA)
         
         ! locate kB: first cell whose absolute mass coordinate is preserved
         kB = 0
         lnT_B = min(lnTmax, lnTlim_B)
         do k = kA, nz
            if (s% xh(i_lnT,k) >= lnT_B .or. s% q(k) < qlim_B) then
               kB = k; exit
            end if
         end do
         if (kB == 0) kB = nz
         
         qB_old = s% q(kB)
         
         if (dbg_adjm) then
            write(*,*) 'before limit_dqAB_old'
            write(*,*) 'kA', kA
            write(*,1) 'qA', qA
            write(*,*) 'kB', kB
            write(*,1) 'qB_old', qB_old
            write(*,*)
            write(*,1) 'qA-qB_old', qA-qB_old
            write(*,*)
         end if
         
         call limit_dqAB_old(0.1d0,50)
         
         qB_new = qB_old * frac ! in order to keep m(kB) constant
         do ! make sure qfrac is not too far from 1
            qfrac = (qA - qB_new) / dqAB_old
            if (kB == nz) exit
            if (qfrac > 0.9d0 .and. qfrac < 1.05d0) exit
            if (qfrac > 0.5d0 .and. qfrac < 2d0) then
               j00 = maxloc(s% xa(:,kB),dim=1) ! most abundant species at kB
               jp1 = maxloc(s% xa(:,kB+1),dim=1) ! most abundant species at kB+1
               if (j00 /= jp1) then ! change in composition.
                  if (dbg) write(*,*) 'change in composition.  back up kB.'
                  ! NOTE(review): qfrac is not recomputed for the backed-up kB;
                  ! the renormalization of dq(1:kB-1) below absorbs the mismatch.
                  kB = max(1,kB-5)
                  exit
               end if
            end if
            kB = kB+1
            qB_old = s% q(kB)
            dqAB_old = qA - qB_old
            qB_new = qB_old * frac
         end do
         
         k_const_mass = kB
         
         s% dq(kA:kB-1) = s% dq(kA:kB-1) * qfrac
         
         if (dbg) then
            write(*,1) 'revise_q_and_dq sum dqs', sum(s% dq(1:nz))
            write(*,2) 'qfrac region', kB, qfrac, s% q(kB), s% lnT(kB)/ln10
            write(*,2) 'frac region', kA, frac, s% q(kA), s% lnT(kA)/ln10
            write(*,2) 'nz', nz
            write(*,*)
         end if
         
         ! NOTE: it is critical to keep constant mass coords for k >= kB
         ! for applications such as nova burst where want fine details of mass changes
         
         ! set q's so that retain constant mass coords for k >= kB
         s% q(kB:nz) = s% q(kB:nz)*frac
         ! set dq's for k >= kB to match the new q's; the loop must run through
         ! k = nz so that dq(nz-1) is rebuilt as well (sum dq(kB:nz) == q(kB))
         do k = kB+1, nz
            s% dq(k-1) = s% q(k-1) - s% q(k)
         end do
         s% dq(nz) = s% q(nz)
         
         ! adjust dq's for k < kB so that sum(dq(1:nz)) == 1;
         ! skip when kB == 1 (empty range, its sum would be a zero divisor)
         if (kB > 1) then
            dq_sum = sum(s% dq(1:kB-1))
            s% dq(1:kB-1) = s% dq(1:kB-1)*(1 - s% q(kB))/dq_sum
         end if
         if (dbg) write(*,1) 'new sum dqs', sum(s% dq(1:nz))
         ! set q's for k < kB
         s% q(1) = 1d0
         do k = 2, kB-1
            s% q(k) = s% q(k-1) - s% dq(k-1)
         end do

               
         contains
         
         
         subroutine limit_dqAB_old(limit,dk_limit)
            ! revise kA and kB until qA - qB_old >= limit or kB-kA >= dk_limit;
            ! always leaves dqAB_old = qA - qB_old for the final kA, kB.
            real(dp), intent(in) :: limit
            integer, intent(in) :: dk_limit
            ! if needed, first move kB inward
            do while (qA - qB_old < limit .and. kB-kA < dk_limit)
               if (okay_to_move_kB_inward .and. kB < nz) then ! move kB to center
                  ! if cannot have kB in envelope,
                  ! better to move it all the way to center.
                  kB = nz
                  qB_old = s% q(kB)
               else ! move kA outward
                  do while (qA - qB_old < limit)
                     if (kA == 1) then
                        ! NOTE(review): with kA == 1 and kB >= kA this max() is a
                        ! no-op (kA - dk_limit < 1).  Possibly kA + dk_limit was
                        ! intended; left unchanged because the qfrac loop in the
                        ! caller already moves kB inward cell by cell while
                        ! checking for composition changes.
                        kB = max(kB,kA - dk_limit)
                        qB_old = s% q(kB)
                        exit
                     end if
                     kA = kA - 1
                     qA = s% q(kA)
                  end do
                  exit
               end if
            end do
            dqAB_old = qA - qB_old            
         end subroutine limit_dqAB_old
         
      
      end subroutine revise_q_and_dq

      
      subroutine set_omega( &
            s, nz, k_const_mass, &
            old_cell_xbdy, new_cell_xbdy, mmax, old_cell_mass, new_cell_mass, ierr)
         ! Remap rotation onto the new mesh: for edges k = 1..min(k_const_mass, nz)
         ! call set1_omega, which sets s% j_rot(k), s% i_rot(k), s% omega(k) by
         ! mass-weighted averaging of the old s% j_rot over overlapping intervals.
         !
         ! Each edge k is given a mass interval of size dmbar(k) starting at
         ! xout(k): edge 1 spans from the outer boundary to the middle of cell 1,
         ! edge k (1 < k < nz) from the middle of cell k-1 to the middle of cell k,
         ! and edge nz from the middle of cell nz-1 through all of cell nz.
         !
         ! k_const_mass may exceed nz (revise_q_and_dq returns nz+1 when
         ! s% adjust_mass_const_q is set); then every edge 1..nz is remapped.
         ! ierr (out): nonzero if any set1_omega call reported an error.
         type (star_info), pointer :: s
         integer, intent(in) :: nz, k_const_mass
         real(dp), intent(in) :: mmax
         real(dp), dimension(nz), intent(in) :: &
            old_cell_xbdy, new_cell_xbdy, old_cell_mass, new_cell_mass
         integer, intent(out) :: ierr         
         integer :: k, op_err, k_last
         real(dp), allocatable, dimension(:) :: &
            old_xout, new_xout, old_dmbar, new_dmbar, old_j_rot
         include 'formats.dek'       
         ierr = 0
         
         ! never index past the last cell, even when k_const_mass == nz+1
         k_last = min(k_const_mass, nz)
         
         ! testing: the boundary of cell k_const_mass must not have moved
         ! (only meaningful when k_const_mass is an actual cell index)
         if (k_const_mass <= nz) then
            if (old_cell_xbdy(k_const_mass) /= new_cell_xbdy(k_const_mass)) then
               write(*,2) 'old_cell_xbdy(k_const_mass)', k_const_mass, old_cell_xbdy(k_const_mass)
               write(*,2) 'new_cell_xbdy(k_const_mass)', k_const_mass, new_cell_xbdy(k_const_mass)
               stop 'set_omega'
            end if
         end if
         
         allocate( &
            old_xout(nz), new_xout(nz), old_dmbar(nz), new_dmbar(nz), old_j_rot(nz))
         
         ! snapshot of the old specific angular momentum; set1_omega overwrites
         ! s% j_rot while reading from this copy
         old_j_rot(:) = s% j_rot(1:nz)
         
         old_xout(1) = old_cell_xbdy(1)
         new_xout(1) = new_cell_xbdy(1)
         old_dmbar(1) = old_cell_mass(1)/2
         new_dmbar(1) = new_cell_mass(1)/2
         do k = 2, nz
            old_xout(k) = old_xout(k-1) + old_dmbar(k-1)
            new_xout(k) = new_xout(k-1) + new_dmbar(k-1)
            old_dmbar(k) = (old_cell_mass(k-1) + old_cell_mass(k))/2
            new_dmbar(k) = (new_cell_mass(k-1) + new_cell_mass(k))/2
         end do
         ! innermost edge also carries all of the innermost cell
         if (nz > 1) then
            old_dmbar(nz) = old_cell_mass(nz-1)/2 + old_cell_mass(nz)
            new_dmbar(nz) = new_cell_mass(nz-1)/2 + new_cell_mass(nz)
         else ! single cell: its only edge carries the whole cell
            old_dmbar(1) = old_cell_mass(1)
            new_dmbar(1) = new_cell_mass(1)
         end if
         
!$OMP PARALLEL DO PRIVATE(k, op_err)
         do k = 1, k_last
            op_err = 0
            call set1_omega( &
               s, k, nz, old_xout, new_xout, mmax, old_dmbar, new_dmbar, old_j_rot, op_err)
            if (op_err /= 0) ierr = op_err
         end do
!$OMP END PARALLEL DO

         deallocate(old_xout, new_xout, old_dmbar, new_dmbar, old_j_rot)
      end subroutine set_omega

      
		! this works like set1_xa except shifted to cell edge instead of cell center
      subroutine set1_omega(s, k, nz, &
            old_xout, new_xout, mmax, old_dmbar, new_dmbar, old_j_rot, ierr)
         ! Set s% j_rot(k), s% i_rot(k) and s% omega(k) for new edge k.
         !
         ! The new edge owns the mass interval [new_xout(k), new_xout(k)+dmbar),
         ! where dmbar = new_dmbar(k), or for k == nz everything down to the
         ! center (mmax - xm_outer - s% M_center).  Mass coordinates increase
         ! inward.  j_rot(k) is the mass-weighted average of old_j_rot over the
         ! old edge intervals overlapping that range; any part lying exterior to
         ! old_xout(1) is accreted material and contributes mass but no
         ! angular momentum.
         ! On an inconsistency ierr is set to -1 and execution stops (debugging).
         use num_lib, only: binary_search
         type (star_info), pointer :: s
         integer, intent(in) :: k, nz
         real(dp), intent(in) :: mmax
         real(dp), dimension(:), intent(in) :: &
            old_xout, new_xout, old_dmbar, new_dmbar, old_j_rot
         integer, intent(out) :: ierr
         
         real(dp) :: xm_outer, xm_inner, j_tot, xm0, xm1, new_point_dmbar, dm_sum, dm
         integer :: kk, k_outer
         
         integer, parameter :: k_dbg = -1  ! set to an edge index to trace it
         
         include 'formats.dek'
         
         ierr = 0                  
         xm_outer = new_xout(k)
         if (k == nz) then
            new_point_dmbar = mmax - xm_outer - s% M_center
         else
            new_point_dmbar = new_dmbar(k)
         end if
         xm_inner = xm_outer + new_point_dmbar
         
         if (k == k_dbg) then
            write(*,2) 'xm_outer', k, xm_outer
            write(*,2) 'xm_inner', k, xm_inner
            write(*,2) 'new_point_dmbar', k, new_point_dmbar
         end if
         
         dm_sum = 0d0
         j_tot = 0d0
         
         if (xm_outer < old_xout(1)) then ! there is some accreted material in new
            if (xm_inner <= old_xout(1)) then 
               ! new is entirely accreted material: no angular momentum.
               ! Keep j_rot and i_rot consistent with omega = 0 instead of
               ! leaving the stale values that belonged to old edge k.
               s% j_rot(k) = 0d0
               s% i_rot(k) = (2d0/3d0)*exp(2*s% lnR_for_d_dt(k))
               s% omega(k) = 0d0
               return
            end if
            ! the accreted part counts toward the mass but adds nothing to j_tot
            dm = min(new_point_dmbar, old_xout(1) - xm_outer)
            dm_sum = dm
            xm_outer = old_xout(1)
            k_outer = 1
         else ! new entirely composed of old material
            if (k == k_dbg) write(*,*) 'new entirely composed of old material'
            if (xm_outer >= old_xout(nz)) then
               ! new contained entirely in old nz
               k_outer = nz
            else
               ! binary search for k_outer such that
               ! xm_outer >= old_xout(k_outer)
               ! and old_xout(k_outer+1) > xm_outer
               k_outer = binary_search(nz, 0, old_xout, xm_outer)
               
               if (k == k_dbg) write(*,2) 'k_outer', k_outer, old_xout(k_outer), old_xout(k_outer+1)
               
               ! check
               if (k_outer <= 0 .or. k_outer > nz) then
                  ierr = -1
                  write(*,2) 'k', k
                  write(*,2) 'k_outer', k_outer
                  write(*,1) 'xm_outer', xm_outer
                  write(*,2) 'old_xout(1)', 1, old_xout(1)
                  write(*,2) 'old_xout(nz)', nz, old_xout(nz)
                  stop 'debugging: set1_omega'
               end if
               
               if (xm_outer < old_xout(k_outer)) then
                  ierr = -1
                  write(*,*) 'k', k
                  write(*,*) 'k_outer', k_outer
                  write(*,1) 'xm_outer', xm_outer
                  write(*,1) 'old_xout(k_outer)', old_xout(k_outer)
                  write(*,*) '(xm_outer < old_xout(k_outer))'
                  stop 'debugging: set1_omega'
               end if
               
               if (k_outer < nz) then
                  if (old_xout(k_outer+1) <= xm_outer) then
                     ierr = -1
                     write(*,*) 'k', k
                     write(*,*) 'k_outer', k_outer
                     write(*,1) 'xm_outer', xm_outer
                     write(*,1) 'old_xout(k_outer+1)', old_xout(k_outer+1)
                     write(*,*) '(old_xout(k_outer+1) <= xm_outer)'
                     stop 'debugging: set1_omega'
                  end if
               end if
               
            end if
         end if

         do kk = k_outer, nz ! loop until reach xm_inner
            xm0 = old_xout(kk)
               
            if (k == k_dbg) write(*,2) 'kk', kk, old_xout(kk)
            
            if (xm0 >= xm_inner) then
               if (dm_sum < new_point_dmbar .and. kk > 1) then 
                  ! need to add a bit more from the previous source
                  dm = new_point_dmbar - dm_sum
                  dm_sum = new_point_dmbar
                  j_tot = j_tot + old_j_rot(kk-1)*dm
                  if (k == k_dbg) &
                     write(*,3) 'new k contains some of old kk-1', &
                        k, kk, old_j_rot(kk-1)*dm, old_j_rot(kk-1), dm, j_tot/dm_sum, j_tot, dm_sum
               end if
               exit
            end if
            
            if (kk == nz) then
               xm1 = mmax - s% M_center
            else
               xm1 = old_xout(kk+1)
            end if
            
            if (xm1 < xm_outer) then
               ierr = -1
               write(*,*)
               write(*,*) 'k', k
               write(*,*) 'kk', kk
               write(*,1) 'xm1', xm1
               write(*,1) 'xm_outer', xm_outer
               write(*,*) 'xm1 < xm_outer'
               stop 'debugging: set1_omega'
            end if
            
            if (xm0 >= xm_outer .and. xm1 <= xm_inner) then ! entire old kk is in new k
               
               dm = old_dmbar(kk)
               dm_sum = dm_sum + dm
               
               if (dm_sum > new_point_dmbar) then 
                  ! dm too large -- numerical roundoff problems;
                  ! trim dm by the overshoot so dm_sum lands exactly on target
                  dm = dm - (dm_sum - new_point_dmbar)
                  dm_sum = new_point_dmbar
               end if
               
               j_tot = j_tot + old_j_rot(kk)*dm
               
               if (k == k_dbg) &
                  write(*,3) 'new k contains all of old kk', &
                     k, kk, old_j_rot(kk)*dm, old_j_rot(kk), dm, j_tot/dm_sum, j_tot, dm_sum
               
            else if (xm0 <= xm_outer .and. xm1 >= xm_inner) then ! rest of new k is in old kk
            
               ! take exactly what new k still needs: dm_sum may already hold an
               ! accreted part, and for k == nz new_point_dmbar /= new_dmbar(k)
               dm = new_point_dmbar - dm_sum
               dm_sum = new_point_dmbar
               j_tot = j_tot + old_j_rot(kk)*dm
               
               if (k == k_dbg) &
                  write(*,3) 'all new k is in old kk', &
                     k, kk, old_j_rot(kk)*dm, old_j_rot(kk), dm, j_tot/dm_sum, j_tot, dm_sum
               
            else ! only use the part of old kk that is in new k
            
               if (k == k_dbg) then
                  write(*,*) 'only use the part of old kk that is in new k', xm_inner <= xm1
                  write(*,1) 'xm_outer', xm_outer
                  write(*,1) 'xm_inner', xm_inner
                  write(*,1) 'xm0', xm0
                  write(*,1) 'xm1', xm1
                  write(*,1) 'dm_sum', dm_sum
                  write(*,1) 'new_point_dmbar', new_point_dmbar
                  write(*,1) 'new_point_dmbar - dm_sum', new_point_dmbar - dm_sum
               end if
            
               if (xm_inner <= xm1) then ! this is the last part of new k
               
                  if (k == k_dbg) write(*,3) 'this is the last part of new k', k, kk

                  dm = new_point_dmbar - dm_sum
                  dm_sum = new_point_dmbar

               else ! notice that we avoid this case if possible because of numerical roundoff
               
                  if (k == k_dbg) write(*,3) 'we avoid this case if possible', k, kk
               
                  dm = max(0d0, xm1 - xm_outer)
                  if (dm_sum + dm > new_point_dmbar) then
                     ! clip, and set dm_sum exactly to avoid roundoff residue
                     dm = new_point_dmbar - dm_sum
                     dm_sum = new_point_dmbar
                  else
                     dm_sum = dm_sum + dm
                  end if

               end if
               
               j_tot = j_tot + old_j_rot(kk)*dm
               
               if (k == k_dbg) &
                  write(*,3) 'new k use only part of old kk', &
                     k, kk, old_j_rot(kk)*dm, old_j_rot(kk), dm, j_tot/dm_sum, j_tot, dm_sum
               
               if (dm <= 0) then
                  ierr = -1
                  write(*,*) 'dm <= 0', dm
                  stop 'debugging: set1_omega'
               end if
               
            end if
            
            if (dm_sum >= new_point_dmbar) then
               if (k == k_dbg) then
                  write(*,2) 'exit for k', k
                  write(*,2) 'dm_sum', kk, dm_sum
                  write(*,2) 'new_point_dmbar', kk, new_point_dmbar
               end if
               exit
            end if
            
         end do
         
         ! every path that completes new k sets dm_sum exactly to new_point_dmbar
         if (dm_sum /= new_point_dmbar) then
            write(*,2) 'dm_sum', k, dm_sum
            write(*,2) 'new_point_dmbar', k, new_point_dmbar
            stop 'debugging: set1_omega'
         end if
         
         s% j_rot(k) = j_tot/new_point_dmbar
         s% i_rot(k) = (2d0/3d0)*exp(2*s% lnR_for_d_dt(k))
         s% omega(k) = s% j_rot(k)/s% i_rot(k)
         
         if (k_dbg == k) then
            write(*,2) 's% omega(k)', k, s% omega(k)
            write(*,2) 's% j_rot(k)', k, s% j_rot(k)
            write(*,2) 's% i_rot(k)', k, s% i_rot(k)
            stop 'debugging: set1_omega'
         end if
               
      end subroutine set1_omega


      end module adjust_mass







         
         



