! ***********************************************************************
!
!   Copyright (C) 2010  Bill Paxton
!
!   MESA is free software; you can use it and/or modify
!   it under the combined terms and restrictions of the MESA MANIFESTO
!   and the GNU General Library Public License as published
!   by the Free Software Foundation; either version 2 of the License,
!   or (at your option) any later version.
!
!   You should have received a copy of the MESA MANIFESTO along with
!   this software; if not, it is available at the mesa website:
!   http://mesa.sourceforge.net/
!
!   MESA is distributed in the hope that it will be useful,
!   but WITHOUT ANY WARRANTY; without even the implied warranty of
!   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
!   See the GNU Library General Public License for more details.
!
!   You should have received a copy of the GNU Library General Public License
!   along with this software; if not, write to the Free Software
!   Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
!
! ***********************************************************************

      module solve_hydro

      ! Drives one implicit (Newton) solve of the stellar structure equations:
      ! chooses the matrix solver, sizes and allocates the solver work arrays,
      ! fills in the Newton control parameters, runs the solver, and sanity-checks
      ! the converged model before it is accepted.

      use star_private_def
      use const_def

      implicit none

      integer, parameter :: stencil_neighbors = 1 
            ! number of neighbors on each side (e.g., =1 for 3 point stencil)


      contains
      

      integer function do_hydro_converge( &
            s, itermin, skip_global_corr_coeff_limit, dt)
         ! Solve the structure equations for one step of length dt.
         ! Return keep_going, retry, backup, or terminate.
         !
         ! Functions need not have been evaluated for the starting configuration;
         ! when finished, they will have been evaluated for the final set of
         ! primary variables.
         !
         !   s                             star being evolved
         !   itermin                       passed to the Newton solver as i_itermin
         !                                 (presumably the minimum iteration count)
         !   skip_global_corr_coeff_limit  if true, the solver gets r_min_corr_coeff = 1
         !                                 instead of s% corr_coeff_limit
         !   dt                            timestep; dt <= 0 returns keep_going at once

         use mtx_lib
         use mtx_def
         use num_def
         use hydro_vars, only: eval_cell_collapse_timescale, eval_chem_timescale
         use hydro_rotation, only: get_rotation_sigmas
         use mix_info, only: get_convection_sigmas
         use star_utils, only: update_time, total_times

         type (star_info), pointer :: s
         integer, intent(in) :: itermin
         logical, intent(in) :: skip_global_corr_coeff_limit
         real(dp), intent(in) :: dt 
         
         character (len=64) :: hydro_decsol
         integer :: lid, lrd, ierr, nvar, nz, time0, clock_rate, &
            hydro_lwork, hydro_liwork
         real(dp) :: total_all_before, cell_collapse_timescale, chem_timescale
         logical :: report, dumping, numerical_jacobian
                     
         include 'formats'
         
         if (dt <= 0d0) then
            do_hydro_converge = keep_going
            return
         end if
         
         do_hydro_converge = terminate
      
         if (s% doing_timing) then
            total_all_before = total_times(s)
            call system_clock(time0,clock_rate)
         end if

         ierr = 0
         
         ! burn/mix variables only take part in the solve when burn or mix is on
         if (s% do_burn .or. s% do_mix) then
            nvar = s% nvar
         else
            nvar = s% nvar_hydro
         end if
         nz = s% nz
         
         ! small systems, and every run with operator coupling, use the small-matrix solver
         if (nvar <= s% hydro_decsol_switch .or. s% operator_coupling_choice /= 0) then
            hydro_decsol = s% small_mtx_decsol
         else
            hydro_decsol = s% large_mtx_decsol
         end if
         
         s% hydro_decsol_option = decsol_option(hydro_decsol, ierr)         
         if (ierr /= 0) then
            write(*, *) 'bad value for hydro_decsol ' // trim(hydro_decsol)
            do_hydro_converge = terminate
            s% termination_code = t_solve_hydro
            return
         end if
         
         ! only the block-tridiagonal bcyclic solvers are supported here
         s% hydro_matrix_type = banded_matrix_type
         select case(s% hydro_decsol_option)
            case (bcyclic_dble)
               call bcyclic_dble_work_sizes(nvar, nz, lrd, lid)
               s% hydro_matrix_type = block_tridiag_dble_matrix_type
            case (bcyclic_klu)
               call bcyclic_klu_work_sizes(nvar, nz, lrd, lid)
               s% hydro_matrix_type = block_tridiag_dble_matrix_type
            case default
               write(*,*) &
                  'do_hydro_converge: invalid setting for hydro_decsol_option', &
                  s% hydro_decsol_option
               do_hydro_converge = terminate
               s% termination_code = t_solve_hydro
               s% result_reason = nonzero_ierr 
               return
         end select
               
         if (s% doing_first_model_of_run) &
            s% L_phot_old = s% xh(s% i_lum,1)/Lsun
         
         call get_convection_sigmas(s, dt, ierr)
         s% termination_code = t_solve_hydro
         if (ierr /= 0) return
         
         if (s% rotation_flag) then
            call get_rotation_sigmas(s, 1, nz, dt, ierr)
            s% termination_code = t_solve_hydro
            if (ierr /= 0) return
         end if
         
         if (s% v_flag) s% csound_init(1:nz) = s% csound(1:nz)
         
         call alloc_for_decsol(ierr)
         if (ierr /= 0) then
            if (s% report_ierr) write(*, *) 'do_hydro_converge: alloc_for_decsol failed'
            do_hydro_converge = retry
            s% result_reason = nonzero_ierr 
            s% termination_code = t_solve_hydro
            return
         end if
            
         call work_sizes_for_newton(ierr)
         if (ierr /= 0) then
            if (s% report_ierr) write(*, *) 'do_hydro_converge: work_sizes_for_newton failed'
            do_hydro_converge = retry
            s% result_reason = nonzero_ierr 
            s% termination_code = t_solve_hydro
            return
         end if
      
         call alloc_for_newton(ierr)
         if (ierr /= 0) then
            if (s% report_ierr) write(*, *) 'do_hydro_converge: alloc_for_newton failed'
            do_hydro_converge = retry
            s% result_reason = nonzero_ierr 
            s% termination_code = t_solve_hydro
            return
         end if

         ! evaluate both timescales before storing either, as the original code did
         chem_timescale = eval_chem_timescale(s)
         cell_collapse_timescale = eval_cell_collapse_timescale(s)
         s% chem_timescale = chem_timescale
         s% cell_collapse_timescale = cell_collapse_timescale

         report = (s% report_hydro_solver_progress .or. s% report_ierr)
         s% hydro_call_number = s% hydro_call_number + 1
         dumping = (s% hydro_call_number == s% hydro_dump_call_number)         
         if (s% report_hydro_solver_progress) then
            write(*,*)
            write(*,2) 'hydro_call_number, s% dt, dt, dt/secyer, log dt/yr', &
               s% hydro_call_number, s% dt, dt, dt/secyer, log10(dt/secyer)
         end if
   
         ! numerical jacobian only on the dump call, or on every call when no
         ! dump call number is set (hydro_dump_call_number < 0)
         numerical_jacobian = s% hydro_numerical_jacobian .and. &
            (dumping .or. s% hydro_dump_call_number < 0)
                  
         do_hydro_converge = do_hydro_newton( &
            s, itermin, skip_global_corr_coeff_limit, dt, &
            report, dumping, numerical_jacobian, &
            nz, nvar, lrd, s% rpar_decsol, lid, s% ipar_decsol, &
            s% hydro_work, hydro_lwork, &
            s% hydro_iwork, hydro_liwork)
         
         if (s% doing_timing) &
            call update_time(s, time0, total_all_before, s% time_struct_burn_mix)
            
         
         contains
         
         
         subroutine alloc_for_decsol(ierr)
            ! Make sure s% ipar_decsol / s% rpar_decsol hold at least lid / lrd entries.
            integer, intent(out) :: ierr
            call ensure_int_ptr_size(s% ipar_decsol, lid, ierr)
            if (ierr /= 0) return
            call ensure_dble_ptr_size(s% rpar_decsol, lrd, ierr)
         end subroutine alloc_for_decsol
         
         
         subroutine alloc_for_newton(ierr)
            ! Make sure s% hydro_iwork / s% hydro_work hold at least
            ! hydro_liwork / hydro_lwork entries.
            integer, intent(out) :: ierr
            call ensure_int_ptr_size(s% hydro_iwork, hydro_liwork, ierr)
            if (ierr /= 0) return
            call ensure_dble_ptr_size(s% hydro_work, hydro_lwork, ierr)
         end subroutine alloc_for_newton
      
   
         subroutine work_sizes_for_newton(ierr)
            ! Get the Newton work-array sizes hydro_lwork / hydro_liwork for nvar x nz.
            use mod_star_newton, only: get_newton_work_sizes
            integer, intent(out) :: ierr
            call get_newton_work_sizes(nvar, nz, hydro_lwork, hydro_liwork, ierr)
         end subroutine work_sizes_for_newton
         
         
         subroutine ensure_int_ptr_size(arr, min_size, ierr)
            ! Allocate arr if it is not associated.
            ! If it holds fewer than min_size entries, replace it with one that has
            ! 30% + 100 entries of headroom, presumably so that small growth in nz
            ! does not force a reallocation every step.
            ! Existing contents are NOT preserved. ierr /= 0 if allocation fails.
            integer, pointer :: arr(:)
            integer, intent(in) :: min_size
            integer, intent(out) :: ierr
            ierr = 0
            if (.not. associated(arr)) then
               allocate(arr(min_size), stat=ierr)
            else if (size(arr, dim=1) < min_size) then
               deallocate(arr)
               allocate(arr(int(1.3d0*min_size)+100), stat=ierr)
            end if
         end subroutine ensure_int_ptr_size
         
         
         subroutine ensure_dble_ptr_size(arr, min_size, ierr)
            ! real(dp) counterpart of ensure_int_ptr_size (same growth policy).
            real(dp), pointer :: arr(:)
            integer, intent(in) :: min_size
            integer, intent(out) :: ierr
            ierr = 0
            if (.not. associated(arr)) then
               allocate(arr(min_size), stat=ierr)
            else if (size(arr, dim=1) < min_size) then
               deallocate(arr)
               allocate(arr(int(1.3d0*min_size)+100), stat=ierr)
            end if
         end subroutine ensure_dble_ptr_size


      end function do_hydro_converge
      

      integer function do_hydro_newton( &
            s, itermin, skip_global_corr_coeff_limit, dt, &
            report, dumping, numerical_jacobian, &
            nz, nvar, lrd, rpar_decsol, lid, ipar_decsol, &
            newton_work, newton_lwork, &
            newton_iwork, newton_liwork)
         ! Set up and run one Newton solve of the structure equations.
         ! Return keep_going, retry, backup, or terminate.
         !
         ! When using newton for the hydro step, functions need not have been
         ! evaluated for the starting configuration. When finished, they will have
         ! been evaluated for the final set of primary variables. For example, the
         ! reaction rates will have been computed, so they can be used as initial
         ! values in the following burn and mix.
         !
         ! Steps:
         !   1. Choose correction tolerances.
         !   2. Fill the Newton integer/real control arrays.
         !   3. Build the initial guess: xh from the current structure, and
         !      dx = current - start-of-step values.
         !   4. Call hydro_newton_step.
         !   5. Sanity-check a converged result with check_after_converge.
         
         use num_def
         !use num_lib
         use utils_lib, only: is_bad_num, has_bad_num
         use alloc

         type (star_info), pointer :: s
         integer, intent(in) :: itermin, nz, nvar
         logical, intent(in) :: skip_global_corr_coeff_limit, &
            report, dumping, numerical_jacobian
         real(dp), intent(in) :: dt
         integer, intent(in) :: lrd, lid
         integer, pointer :: ipar_decsol(:) ! (lid)
         real(dp), pointer :: rpar_decsol(:) ! (lrd)
         integer, intent(in) :: newton_lwork, newton_liwork
         real(dp), pointer :: newton_work(:) ! (newton_lwork)
         integer, pointer :: newton_iwork(:) ! (newton_liwork)         

         real(dp), pointer :: dx(:,:), dx1(:) ! dx => dx1
         logical :: converged
         integer :: ierr, j1, n_hydro_vars, first_chem_var, n_chem_vars
         real(dp) :: tol_correction_norm, tol_max_correction

         include 'formats'

         do_hydro_newton = keep_going
         
         ! tolerances: selected by innermost-cell temperature, then by repeated backups
         if (s% T(s% nz) >= s% tol_correction_extreme_T_limit) then
            tol_correction_norm = s% tol_correction_norm_extreme_T
            tol_max_correction = s% tol_max_correction_extreme_T
         else if (s% T(s% nz) >= s% tol_correction_high_T_limit) then
            tol_correction_norm = s% tol_correction_norm_high_T
            tol_max_correction = s% tol_max_correction_high_T
         else if (s% number_of_backups_in_a_row >= 3) then
            tol_correction_norm = s% tol_correction_norm_alt
            tol_max_correction = s% tol_max_correction_alt
         else
            tol_correction_norm = s% tol_correction_norm
            tol_max_correction = s% tol_max_correction
         end if
         
         ! ---- integer parameters for newton ----
         newton_iwork(1:num_iwork_params) = 0
         newton_work(1:num_work_params) = 0
         
         if ((s% doing_first_model_of_run) .or. s% model_number <= s% last_backup) &
            newton_iwork(i_try_really_hard) = 1 ! try_really_hard for 1st model or after a backup
         newton_iwork(i_itermin) = itermin
         
         newton_iwork(i_max_iter_for_enforce_resid_tol) = s% max_iter_for_resid_tol1
         newton_iwork(i_max_iter_for_resid_tol2) = s% max_iter_for_resid_tol2
         newton_iwork(i_max_iter_for_resid_tol3) = s% max_iter_for_resid_tol3
         
         if (s% refine_solution) then
            newton_iwork(i_refine_solution) = 1
         else
            newton_iwork(i_refine_solution) = 0
         end if
         
         if (s% refine_mtx_solution) then
            newton_iwork(i_refine_mtx_solution) = 1
         else
            newton_iwork(i_refine_mtx_solution) = 0
         end if
         
         ! right after a backup, re-evaluate the jacobian every iteration
         newton_iwork(i_max_iterations_for_jacobian) = s% max_iterations_for_jacobian
         if (s% model_number < s% last_backup-1) then
            newton_iwork(i_max_iterations_for_jacobian) = 1
            newton_iwork(i_max_tries) = s% max_tries_after_backup2
         else if (s% model_number < s% last_backup) then
            newton_iwork(i_max_iterations_for_jacobian) = 1
            newton_iwork(i_max_tries) = s% max_tries_after_backup
         else if (s% retry_cnt > 0) then
            newton_iwork(i_max_tries) = s% max_tries_for_retry
         else if (s% doing_first_model_of_run) then
            newton_iwork(i_max_tries) = s% max_tries1
         else
            newton_iwork(i_max_tries) = s% max_tries
         end if
         
         newton_iwork(i_tiny_min_corr_coeff) = s% tiny_corr_coeff_limit
         if (dumping .or. s% report_hydro_solver_progress) then
            newton_iwork(i_debug) = 1
         else
            newton_iwork(i_debug) = 0
         end if
         newton_iwork(i_model_number) = s% model_number
         ! counters restart for a new model, otherwise continue from the last solve
         if (s% model_number > s% model_number_for_last_jacobian) then
            newton_iwork(i_num_solves) = 0
            newton_iwork(i_num_jacobians) = 0
         else
            newton_iwork(i_num_solves) = s% num_solves
            newton_iwork(i_num_jacobians) = s% num_jacobians         
         end if

         ! ---- real parameters for newton ----
         newton_work(r_tol_residual_norm) = s% tol_residual_norm1
         newton_work(r_tol_max_residual) = s% tol_max_residual1
         newton_work(r_tol_residual_norm2) = s% tol_residual_norm2
         newton_work(r_tol_max_residual2) = s% tol_max_residual2
         newton_work(r_tol_residual_norm3) = s% tol_residual_norm3
         newton_work(r_tol_max_residual3) = s% tol_max_residual3

         newton_work(r_tol_max_correction) = tol_max_correction
         
         newton_work(r_target_corr_factor) = s% target_corr_factor
         newton_work(r_tol_abs_slope_min) = -1 ! unused
         newton_work(r_tol_corr_resid_product) = -1 ! unused
         
         newton_work(r_scale_correction_norm) = s% scale_correction_norm
         newton_work(r_corr_param_factor) = s% corr_param_factor
         newton_work(r_scale_max_correction) = s% scale_max_correction
         newton_work(r_corr_norm_jump_limit) = s% corr_norm_jump_limit
         newton_work(r_max_corr_jump_limit) = s% max_corr_jump_limit
         newton_work(r_resid_norm_jump_limit) = s% resid_norm_jump_limit
         newton_work(r_max_resid_jump_limit) = s% max_resid_jump_limit
         newton_work(r_slope_alert_level) = s% slope_alert_level
         newton_work(r_slope_crisis_level) = s% slope_crisis_level
         newton_work(r_tiny_corr_factor) = s% tiny_corr_factor
         newton_work(r_dt) = dt

         if (skip_global_corr_coeff_limit) then
            newton_work(r_min_corr_coeff) = 1       
         else
            newton_work(r_min_corr_coeff) = s% corr_coeff_limit         
         end if

         newton_work(r_sparse_non_zero_max_factor) = s% sparse_non_zero_max_factor
         
         ! ---- initial guess ----
         call non_crit_get_work_array(s, dx1, nvar*nz, nvar*nz_alloc_extra, 'newton', ierr)
         if (ierr /= 0) then
            ! must not report keep_going when the solve never ran
            if (report) write(*, *) 'do_hydro_newton: failed to get work array for dx'
            do_hydro_newton = retry
            s% result_reason = nonzero_ierr
            return
         end if
         dx(1:nvar,1:nz) => dx1(1:nvar*nz)
         s% newton_dx(1:nvar,1:nz) => dx1(1:nvar*nz)
         
         ! set xh (and the surface reference values) from current structure info
         do j1 = 1, s% nvar_hydro
            if (j1 == s% i_xlnd) then
               s% surf_lnd = s% lnd(1)
               s% xh(j1,1:nz) = s% lnd(1:nz)
            else if (j1 == s% i_lnPgas) then
               s% surf_lnPgas = s% lnPgas(1)
               s% xh(j1,1:nz) = s% lnPgas(1:nz)
            else if (j1 == s% i_lnT) then
               s% surf_lnT = s% lnT(1)
               s% xh(j1,1:nz) = s% lnT(1:nz)
            else if (j1 == s% i_lnR) then
               s% surf_lnR = s% lnR(1)
               s% xh(j1,1:nz) = s% lnR(1:nz)
            else if (j1 == s% i_lum) then
               s% xh(j1,1:nz) = s% L(1:nz)
            else if (j1 == s% i_vel) then
               s% surf_v = s% v(1)
               s% xh(j1,1:nz) = s% v(1:nz)
            end if
         end do
         s% surf_lnS = s% lnS(1)
         s% num_surf_revisions = 0
         
         ! dx = current values minus start-of-step values
         n_hydro_vars = s% nvar_hydro
         dx(1:n_hydro_vars,1:nz) = &
            s% xh(1:n_hydro_vars,1:nz) - s% xh_pre(1:n_hydro_vars,1:nz)
         
         ! abundances occupy variables i_chem1..nvar when they are part of the solve
         if (nvar >= s% i_chem1) then
            first_chem_var = s% i_chem1
            n_chem_vars = nvar - first_chem_var + 1
            dx(first_chem_var:nvar,1:nz) = &
               s% xa(1:n_chem_vars,1:nz) - s% xa_pre(1:n_chem_vars,1:nz)
         end if  
         
         call hydro_newton_step( &
            s, nz, s% nvar_hydro, nvar, dx1, dt, &
            tol_correction_norm, numerical_jacobian, &
            lrd, rpar_decsol, lid, ipar_decsol, &
            newton_work, newton_lwork, &
            newton_iwork, newton_liwork, &
            converged, ierr)
         
         call non_crit_return_work_array(s, dx1, 'newton')
         nullify(s% newton_dx)

         if (dumping) stop 'debug: dumping hydro_newton' 
         
         if (ierr /= 0) then
            if (report) then
               write(*, *) 'hydro_newton_step returned ierr', ierr
               write(*, *) 's% model_number', s% model_number
               write(*, *) 'nz', nz
               write(*, *) 's% num_retries', s% num_retries
               write(*, *) 's% num_backups', s% num_backups
               write(*, *) 
            end if
            do_hydro_newton = retry
            s% result_reason = nonzero_ierr 
            return
         end if         
            
         s% num_solves = newton_iwork(i_num_solves)
         s% num_jacobians = newton_iwork(i_num_jacobians)
         s% total_num_jacobians = s% total_num_jacobians + s% num_jacobians
         
         if (converged) then ! sanity checks before accept it
            converged = check_after_converge(s, report, ierr)
         end if

         if (.not. converged) then
            do_hydro_newton = retry
            s% result_reason = hydro_failed_to_converge
            if (report) then
               write(*, *) 'hydro_newton_step failed to converge'
               write(*,2) 's% model_number', s% model_number
               write(*,2) 's% hydro_call_number', s% hydro_call_number
               write(*,2) 'nz', nz
               write(*,2) 's% num_retries', s% num_retries
               write(*,2) 's% num_backups', s% num_backups
               write(*,2) 's% number_of_backups_in_a_row', s% number_of_backups_in_a_row
               write(*,1) 'log dt/secyer', log10(dt/secyer)
               write(*, *) 
            end if
            return
         end if

      end function do_hydro_newton
      
      
      logical function check_after_converge(s, report, ierr) result(converged)
         ! Sanity checks on a model the Newton solver reports as converged.
         ! Returns .false. (reject) if any of these hold:
         !   - the surface luminosity L(1) is not positive;
         !   - R_center exceeds the radius of the innermost cell (negative volume);
         !   - lnR does not strictly decrease from cell k to cell k+1;
         !   - any cell, including the innermost cell nz, has
         !     logT > 12, logT < 1, logRho > 12, or logRho < -20.
         ! Bad T/rho cells are all reported (no early exit) when report is set.
         ! ierr is always returned as 0.
         type (star_info), pointer :: s
         logical, intent(in) :: report
         integer, intent(out) :: ierr
         integer :: k, nz
         include 'formats'
         ierr = 0
         nz = s% nz
         converged = .true.
         if (s% L(1) <= 0d0) then
            if (report) write(*,*) 'after hydro, negative surface luminosity'
            converged = .false.
            return
         end if
         if (s% R_center > 0d0) then
            if (s% R_center > exp(s% lnR(nz))) then
               if (report) &
                  write(*,2) 'volume < 0 in cell nz', nz, &
                     s% R_center - exp(s% lnR(nz)), s% R_center, exp(s% lnR(nz)), &
                     s% dm(nz), s% rho(nz), s% dq(nz)
               converged = .false.
               return
            end if
         end if
         do k = 1, nz
            ! nested ifs: Fortran .and. does not short-circuit, and lnR(nz+1) is out of range
            if (k < nz) then
               if (s% lnR(k) <= s% lnR(k+1)) then
                  if (report) write(*,2) 'after hydro, negative cell volume in cell k', &
                        k, s% lnR(k) - s% lnR(k+1), s% lnR(k), s% lnR(k+1), & 
                        s% lnR_start(k) - s% lnR_start(k+1), s% lnR_start(k), s% lnR_start(k+1)
                  converged = .false.
                  exit
               end if
            end if
            if (s% lnT(k) > ln10*12) then
               if (report) write(*,2) 'after hydro, logT > 12 in cell k', k, s% lnT(k)
               converged = .false.
            else if (s% lnT(k) < ln10) then
               if (report) write(*,*) 'after hydro, logT < 1 in cell k', k
               converged = .false.
            else if (s% lnd(k) > ln10*12) then
               if (report) write(*,*) 'after hydro, logRho > 12 in cell k', k
               converged = .false.
            else if (s% lnd(k) < -ln10*20) then
               if (report) write(*,*) 'after hydro, logRho < -20 in cell k', k
               converged = .false.
            end if
         end do
      end function check_after_converge
      
      
      subroutine hydro_newton_step( &
            s, nz, nvar_hydro, nvar, dx1, dt, &
            tol_correction_norm, numerical_jacobian, &
            lrd, rpar_decsol, lid, ipar_decsol, &
            newton_work, newton_lwork, &
            newton_iwork, newton_liwork, &
            converged, ierr)
         ! Run the Newton solver on the nvar x nz system whose initial correction
         ! is in dx1 (viewed as dx(nvar,nz)).
         ! On convergence, s% xh(1:nvar_hydro,:) = s% xh_pre + dx.
         ! converged is .false. whenever ierr /= 0 or the solver reports failure.
         ! Both scratch work arrays obtained here are returned on every exit path
         ! taken after they were acquired.
         use num_def
         !use num_lib
         use chem_def
         use mtx_lib
         use mtx_def
         use alloc
         use hydro_mtx, only: ipar_id, ipar_first_call, hydro_lipar, &
            rpar_dt, hydro_lrpar

         type (star_info), pointer :: s         
         integer, intent(in) :: nz, nvar_hydro, nvar
         real(dp), pointer :: dx1(:)
         real(dp), intent(in) :: dt
         real(dp), intent(in) :: tol_correction_norm
         logical, intent(in) :: numerical_jacobian
         integer, intent(in) :: lrd, lid
         integer, intent(inout), pointer :: ipar_decsol(:) ! (lid)
         real(dp), intent(inout), pointer :: rpar_decsol(:) ! (lrd)
         integer, intent(in) :: newton_lwork, newton_liwork
         real(dp), intent(inout), pointer :: newton_work(:) ! (newton_lwork)
         integer, intent(inout), pointer :: newton_iwork(:) ! (newton_liwork)
         logical, intent(out) :: converged
         integer, intent(out) :: ierr

         integer, parameter :: lipar=hydro_lipar, lrpar=hydro_lrpar
         integer, target :: ipar_target(lipar)
         real(dp), target :: rpar_target(lrpar)
         integer, pointer :: ipar(:)
         real(dp), pointer :: rpar(:)

         integer :: i, matrix_type, neq
         logical :: failure
         real(dp) :: varscale
         real(dp), parameter :: xscale_min = 1
         real(dp), pointer :: dx(:,:)
         
         real(dp), pointer, dimension(:,:) :: x_scale
         real(dp), pointer, dimension(:) :: dx_init1, x_scale1
         
         logical, parameter :: dbg = .false.

         include 'formats'

         ierr = 0
         converged = .false.
         failure = .true. ! only trusted once newton has actually run
         
         neq = nvar*nz
         dx(1:nvar,1:nz) => dx1(1:neq)
         
         if (dbg) write(*, *) 'enter hydro_newton_step'
         
         s% numerical_jacobian = numerical_jacobian
         
         ipar => ipar_target
         ipar(ipar_id) = s% id
         ipar(ipar_first_call) = 1

         rpar => rpar_target         
         rpar(rpar_dt) = dt
         
         call check_sizes(s, ierr)
         if (ierr /= 0) then
            write(*,*) 'check_sizes failed'
            return
         end if

         call non_crit_get_work_array(s, x_scale1, neq, nvar*nz_alloc_extra, 'hydro_newton_step', ierr)
         if (ierr /= 0) return
         x_scale(1:nvar,1:nz) => x_scale1(1:neq)
         
         call non_crit_get_work_array(s, dx_init1, neq, nvar*nz_alloc_extra, 'hydro_newton_step', ierr)
         if (ierr /= 0) then
            ! give back the array already acquired before bailing out
            call non_crit_return_work_array(s, x_scale1, 'hydro_newton_step')
            return
         end if

         ! hydro variables are scaled by their largest magnitude (at least xscale_min);
         ! the remaining variables use scale 1
         do i = 1, nvar
            if (i <= s% nvar_hydro) then
               varscale = max(xscale_min, maxval(abs(s% xh(i,1:nz))))
            else
               varscale = 1
            end if
            x_scale(i,1:nz) = varscale
         end do
         
         dx_init1(1:neq) = dx1(1:neq)
         
         if (s% matrix_type /= 0) then ! s% matrix_type is a control parameter
            matrix_type = s% matrix_type
         else ! s% hydro_matrix_type is the matrix_type currently in use.
            if (s% hydro_matrix_type <= 0) then
               if (s% hydro_decsol_option == bcyclic_dble .or. &
                   s% hydro_decsol_option == bcyclic_klu .or. &
                   s% hydro_decsol_option == block_thomas_dble) then
                  s% hydro_matrix_type = block_tridiag_dble_matrix_type
               else if (s% hydro_decsol_option == block_thomas_quad) then
                  s% hydro_matrix_type = block_tridiag_quad_matrix_type
               else
                  s% hydro_matrix_type = banded_matrix_type
               end if
               write(*,2) 'solve_hydro: set hydro_matrix_type', s% hydro_matrix_type
            end if
            matrix_type = s% hydro_matrix_type
         end if
         
         ! no early returns below: fall through so the work arrays are always released
         if (dbg) write(*, *) 'call newton'
         select case(s% hydro_decsol_option)
               
            case (bcyclic_dble)
               if (matrix_type == block_tridiag_dble_matrix_type) then
                  call newt(.false., ierr)
               else
                  write(*,'(a)') 'matrix_type must be block_tridiag_dble_matrix_type for bcyclic_dble'
                  ierr = -1
               end if
               
            case (bcyclic_klu)
               if (matrix_type == block_tridiag_dble_matrix_type) then
                  call newt(.true., ierr)
               else
                  write(*,'(a)') 'matrix_type must be block_tridiag_dble_matrix_type for bcyclic_klu'
                  ierr = -1
               end if
               
            case default
               write(*,*) 'invalid hydro_decsol_option', s% hydro_decsol_option
               ierr = -1
               
         end select

         if (ierr /= 0 .and. s% report_ierr) then
            write(*,*) 'newton failed for hydro'
         end if
         
         ! guard: failure is only meaningful when newton ran without error
         if (ierr == 0) converged = .not. failure
         if (converged) then
            s% xh(1:nvar_hydro,1:nz) = &
               s% xh_pre(1:nvar_hydro,1:nz) + dx(1:nvar_hydro,1:nz)
            ! s% xa has already been updated by final call to set_newton_vars from newton solver
         end if

         call non_crit_return_work_array(s, x_scale1, 'hydro_newton_step')            
         call non_crit_return_work_array(s, dx_init1, 'hydro_newton_step')            
         
         
         contains
         
         
         subroutine newt(sparse, ierr)
            ! Call the star Newton solver (sparse selects the KLU path) and, when
            ! timing is on, accumulate solver self / matrix / test times.
            use chem_def
            use hydro_newton_procs
            use star_utils, only: total_times
            use mod_star_newton
            use mod_star_sparse
            use num_lib, only: default_failed_in_setmatrix, &
               default_set_primaries, default_set_secondaries
            logical, intent(in) :: sparse
            integer, intent(out) :: ierr
            integer :: time0, time1, clock_rate
            real(dp) :: total_other_time, total_mtx_time, total_all_before, &
               total_all_after, time_callbacks, elapsed_time, time_self
            include 'formats'
         
            if (sparse) then
               call setup_for_star_sparse(ierr)
               if (ierr /= 0) return
            end if            
            
            if (s% doing_timing) then
               call system_clock(time0,clock_rate)
               newton_work(r_mtx_time) = 1
               newton_work(r_test_time) = 1
               total_all_before = total_times(s)
            else
               newton_work(r_mtx_time) = 0
               newton_work(r_test_time) = 0
               total_all_before = 0
            end if
            
            newton_iwork(i_caller_id) = s% id
            call newton( &
               s, nz, nvar, dx1, dx_init1, &
               sparse, lrd, rpar_decsol, lid, ipar_decsol, &
               s% hydro_decsol_option, tol_correction_norm, &
               x_scale1, s% equ1, &
               newton_work, newton_lwork, &
               newton_iwork, newton_liwork, &
               s% AF1, lrpar, rpar, lipar, ipar, failure, ierr)
            
            if (s% doing_timing) then ! subtract time_newton_mtx
               call system_clock(time1,clock_rate)
               total_all_after = total_times(s)
               time_callbacks = total_all_after - total_all_before
               elapsed_time = dble(time1-time0)/clock_rate
               total_other_time = elapsed_time - time_callbacks
               total_mtx_time = newton_work(r_mtx_time)
               time_self = total_other_time - total_mtx_time
               s% time_newton_self = s% time_newton_self + time_self
               s% time_newton_mtx = s% time_newton_mtx + total_mtx_time
               s% time_newton_test = s% time_newton_test + &
                  newton_work(r_test_time) - (total_mtx_time + time_callbacks)
            end if

            if (sparse .and. keep_sprs_statistics) then
               write(*,*)
               write(*,2) 'sprs_num_alloc_klu_storage', sprs_num_alloc_klu_storage
               write(*,2) 'sprs_num_clear_klu_storage', sprs_num_clear_klu_storage
               write(*,*)
               write(*,2) 'sprs_num_analyze', sprs_num_analyze
               write(*,2) 'sprs_num_free_symbolic', sprs_num_free_symbolic
               write(*,*)
               write(*,2) 'sprs_num_factor', sprs_num_factor
               write(*,2) 'sprs_num_free_numeric', sprs_num_free_numeric
               write(*,*)
               write(*,2) 'sprs_num_alloc_klu_factors', sprs_num_alloc_klu_factors
               write(*,2) 'sprs_num_free_klu_factors', sprs_num_free_klu_factors
               write(*,*)
               write(*,2) 'sprs_num_refactor', sprs_num_refactor
               write(*,2) 'sprs_num_solve', sprs_num_solve
               write(*,2) 'factor + refactor - solve', &
                  sprs_num_factor + sprs_num_refactor - sprs_num_solve
               write(*,*)
            end if

         end subroutine newt
         
         
         subroutine setup_for_star_sparse(ierr)
            ! Make sure KLU storage exists and that the shared sparsity analysis
            ! matches the current nvar_hydro and net.
            ! When they already match this is a no-op; otherwise the old
            ! factorization data is freed and the pattern is re-analyzed from block k = 1.
            use mod_star_sparse
            use net_lib, only: net_work_size
            
            integer, intent(out) :: ierr
         
            real(dp), target :: mtx_array(nvar*nvar)
            real(dp), pointer :: mtx(:,:)
            integer :: k, sprs_nonzeros, net_lwork
            type(sparse_info), pointer :: ks(:)
            
            include 'formats'
            
            ierr = 0
            mtx(1:nvar,1:nvar) => mtx_array(1:nvar*nvar)
            net_lwork = net_work_size(s% net_handle, ierr)
            if (ierr /= 0) return

            call star_alloc_klu_storage(s, ierr)
            if (ierr /= 0) return
            
            if (s% bcyclic_nvar_hydro == s% nvar_hydro .and. &
                trim(s% bcyclic_sprs_shared_net_name) == trim(s% net_name)) return
            
            if (s% bcyclic_nvar_hydro /= 0) then ! free old before get new
               call star_sparse_free_symbolic(s, 1, nvar, ierr)
               if (ierr /= 0) return
               call star_sparse_free_numeric(s, 1, nvar, ierr)
               if (ierr /= 0) return
               call star_sparse_free_all(s, nvar, 2, ierr)
               if (ierr /= 0) return
            end if
            
            k = 1
            call star_sparse_matrix_info( &
               s, k, nvar, s% species, net_lwork, mtx, sprs_nonzeros, ierr)   
            if (ierr /= 0) then
               write(*,2) 'sparse_get_matrix failed', s% model_number
               stop 'hydro_newton_step'
            end if
            call star_sparse_analyze(s, k, nvar, mtx, ierr)   
            if (ierr /= 0) then
               write(*,2) 'sparse_analyze failed', s% model_number
               stop 'hydro_newton_step'
            end if
            
            ks => s% bcyclic_klu_storage
            s% bcyclic_sprs_shared_net_name = s% net_name
            s% bcyclic_nvar_hydro = s% nvar_hydro
            s% bcyclic_shared_sprs_nonzeros = sprs_nonzeros
            s% bcyclic_sprs_shared_ia => ks(k)% ia
            s% bcyclic_sprs_shared_ja => ks(k)% ja
            s% bcyclic_sprs_shared_ipar8_decsol = ks(k)% ipar8_decsol
         
         end subroutine setup_for_star_sparse
         
      
      end subroutine hydro_newton_step
      
      
      subroutine set_L_burn_by_category(s)
         ! Fill s% luminosity_by_category(j,k) with the running sum of
         ! dm * eps_nuc_categories(i_rate, j, :) over cells nz down to k,
         ! i.e. accumulated from the innermost cell (k = nz) outward to cell k.
         use rates_def, only: i_rate
         type (star_info), pointer :: s
         integer :: k
         real(dp) :: running_L(num_categories)
         running_L(:) = 0
         do k = s% nz, 1, -1
            running_L(:) = running_L(:) + &
               s% dm(k)*s% eps_nuc_categories(i_rate, 1:num_categories, k)
            s% luminosity_by_category(1:num_categories, k) = running_L(:)
         end do
      end subroutine set_L_burn_by_category

      

      end module solve_hydro


