! ***********************************************************************
!
!   Copyright (C) 2013  Bill Paxton
!
!   MESA is free software; you can use it and/or modify
!   it under the combined terms and restrictions of the MESA MANIFESTO
!   and the GNU General Library Public License as published
!   by the Free Software Foundation; either version 2 of the License,
!   or (at your option) any later version.
!
!   You should have received a copy of the MESA MANIFESTO along with
!   this software; if not, it is available at the mesa website:
!   http://mesa.sourceforge.net/
!
!   MESA is distributed in the hope that it will be useful,
!   but WITHOUT ANY WARRANTY; without even the implied warranty of
!   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
!   See the GNU Library General Public License for more details.
!
!   You should have received a copy of the GNU Library General Public License
!   along with this software; if not, write to the Free Software
!   Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
!
! ***********************************************************************


      module star_newton

      use star_private_def
      use const_def, only: dp, qp
      use utils_lib, only: is_bad_num
      use num_def
      use mtx_def
      use mtx_lib, only: &
         band_multiply_xa, lapack_work_sizes, &
         block_multiply_xa, quad_block_multiply_xa, &
         bcyclic_dble_work_sizes, bcyclic_klu_work_sizes, &
         bcyclic_dble_decsolblk, bcyclic_klu_decsolblk
         
      use hydro_newton_procs
               
      
      implicit none
      
      
      contains


      ! Entry point for the Newton solve.
      !
      ! Sizes the jacobian storage for the selected linear solver, checks that the
      ! caller-supplied decsol work arrays are large enough, (re)allocates the
      ! factored-jacobian buffers, and then hands everything to do_newton.
      !
      ! If work(r_test_time) is nonzero on entry, the elapsed wall-clock time of the
      ! call (in seconds) is returned in work(r_test_time); otherwise that slot is
      ! returned as 0.
      !
      ! On return:
      !    ierr /= 0                 a setup or solver routine reported an error.
      !    convergence_failure       true if the solver gave up or if ierr /= 0.
      !                              It is always defined, including on early returns.
      !    s% newton_iter            reset to 0.
      subroutine newton( &
            s, nz, nvar, dx, sparse, &
            lrd, rpar_decsol, lid, ipar_decsol, which_decsol, &
            tol_correction_norm, &
            xscale, equ, work, lwork, iwork, liwork, AF, AF_qp, B_qp, &
            lrpar, rpar, lipar, ipar, convergence_failure, ierr)
         use alloc, only: non_crit_get_quad_array, non_crit_return_quad_array
         use utils_lib, only: realloc_if_needed_1, quad_realloc_if_needed_1
         
         type (star_info), pointer :: s
         ! the primary variables
         integer, intent(in) :: nz ! number of zones
         integer, intent(in) :: nvar ! number of variables per zone
         real(dp), pointer, dimension(:) :: dx ! =(nvar,nz) 
         
         logical, intent(in) :: sparse ! if true, size the work arrays for the bcyclic_klu solver

         ! these arrays provide optional extra working storage for the matrix routines.
         ! the implementations in mesa/mtx include routines to determine the sizes.
         ! for example, the LAPACK version is called lapack_work_sizes.
         integer, intent(in) :: lrd, lid
         integer, intent(inout), pointer :: ipar_decsol(:) ! (lid)
         real(dp), intent(inout), pointer :: rpar_decsol(:) ! (lrd)
         integer, intent(in) :: which_decsol

         real(dp), pointer, dimension(:) :: xscale ! =(nvar,nz)
         real(dp), pointer, dimension(:) :: equ ! =(nvar,nz)
         ! equ(i) has the residual for equation i, i.e., the difference between
         ! the left and right hand sides of the equation.

         ! work arrays. required sizes provided by the routine newton_work_sizes.
         ! for standard use, set work and iwork to 0 before calling.
         ! NOTE: these arrays contain some optional parameter settings and outputs.
         ! see num_def for details.
         integer, intent(in) :: lwork, liwork
         real(dp), intent(inout), target :: work(:) ! (lwork)
         integer, intent(inout), target :: iwork(:) ! (liwork)
         real(dp), pointer, dimension(:) :: AF ! for factored jacobian
            ! will be allocated or reallocated as necessary.  
         real(qp), pointer, dimension(:) :: AF_qp, B_qp
            ! quad precision buffers; only (re)allocated when
            ! s% hydro_matrix_type == block_tridiag_quad_matrix_type.

         ! convergence criteria
         real(dp), intent(in) :: tol_correction_norm
            ! a trial solution is considered to have converged if
            ! max_correction <= tol_max_correction and
            !
            ! either
            !          (correction_norm <= tol_correction_norm)  
            !    .and. (residual_norm <= tol_residual_norm)
            ! or
            !          (correction_norm*residual_norm <= tol_corr_resid_product)
            !    .and. (abs(slope) <= tol_abs_slope_min)
            !
            ! where "slope" is slope of the line for line search in the newton solver,
            ! and is analogous to the slope of df/ddx in a 1D newton root finder.

         ! parameters for caller-supplied routines
         integer, intent(in) :: lrpar, lipar
         real(dp), intent(inout) :: rpar(:) ! (lrpar)
         integer, intent(inout) :: ipar(:) ! (lipar)
         
         ! output
         logical, intent(out) :: convergence_failure
         integer, intent(out) :: ierr ! 0 means okay.
         
         integer :: ldAF, neqns, mljac, mujac
         real(dp), pointer :: AF_copy(:) ! =(ldAF, neq)
         real(qp), pointer :: AF_qp_copy(:) ! =(ldAF, neq)
         real(qp), pointer :: B_qp_copy(:) ! =(neq)
         
         integer :: need_lrd, need_lid
               
         integer :: test_time0, test_time1, clock_rate
         logical :: do_test_timing
         
         include 'formats.dek'
         
         ! intent(out) arguments are undefined on entry; give both a definite
         ! value before any of the early returns below can be taken.
         ierr = 0
         convergence_failure = .false.

         ! a nonzero value in work(r_test_time) is the caller's request for timing.
         do_test_timing = (work(r_test_time) /= 0)
         if (do_test_timing) call system_clock(test_time0,clock_rate)
         work(r_test_time) = 0

         neqns = nvar*nz
         ldAF = 3*nvar ! default leading dimension; the lapack branch overrides it
         
         ! ask the selected linear solver how much work space it needs
         if (sparse) then
            call bcyclic_klu_work_sizes(nvar, nz, need_lrd, need_lid)
         else if (which_decsol == lapack) then
            call lapack_work_sizes(neqns, need_lrd, need_lid)
            mljac = 2*nvar-1
            mujac = mljac
            ldAF = 2*mljac+mujac+1
         else
            call bcyclic_dble_work_sizes(nvar, nz, need_lrd, need_lid) ! same for quad
         end if
         
         if (need_lrd > lrd .or. need_lid > lid) then
            write(*,*) 'bad lrd or lid for newton'
            write(*,2) 'need_lrd', need_lrd
            write(*,2) '     lrd', lrd
            write(*,2) 'need_lid', need_lid
            write(*,2) '     lid', lid
            ierr = -1
            convergence_failure = .true.
            return
         end if
         
         call realloc_if_needed_1(AF,ldAF*neqns,(ldAF+2)*200,ierr)
         if (ierr /= 0) then
            convergence_failure = .true.
            return
         end if
         AF_copy => AF
         
         if (s% hydro_matrix_type /= block_tridiag_quad_matrix_type) then         
            ! no quad precision storage needed; do_newton tests associated(AF1_qp)
            nullify(B_qp_copy, AF_qp_copy)            
         else         
            call quad_realloc_if_needed_1(AF_qp,ldAF*neqns,(ldAF+2)*200,ierr)
            if (ierr /= 0) then
               convergence_failure = .true.
               return
            end if
            AF_qp_copy => AF_qp            
            call quad_realloc_if_needed_1(B_qp,neqns,200,ierr)
            if (ierr /= 0) then
               convergence_failure = .true.
               return
            end if
            B_qp_copy => B_qp
         end if
         
         call do_newton( &
            s, nz, nvar, dx, AF_copy, AF_qp_copy, B_qp_copy, ldAF, &
            neqns, sparse, lrd, rpar_decsol, lid, ipar_decsol, which_decsol, &
            tol_correction_norm, xscale, equ, &
            work, lwork, iwork, liwork, &
            lrpar, rpar, lipar, ipar, convergence_failure, ierr)
         ! do_newton can return with ierr /= 0 before it has assigned
         ! convergence_failure; make sure an error is always reported as a failure.
         if (ierr /= 0) convergence_failure = .true.
         s% newton_iter = 0

         if (do_test_timing) then
            call system_clock(test_time1,clock_rate)
            ! system_clock reports a count rate of 0 when no clock is available;
            ! leave the timing slot at 0 in that case instead of dividing by zero.
            if (clock_rate > 0) then
               work(r_test_time) = &
                  work(r_test_time) + dble(test_time1 - test_time0) / clock_rate
            end if
         end if
         
      end subroutine newton


      subroutine do_newton( &
         s, nz, nvar, dx1, AF1, AF1_qp, B1_qp, ldAF, neq, sparse, &
         lrd, rpar_decsol, lid, ipar_decsol, which_decsol, &
         tol_correction_norm, xscale1, equ1, &
         work, lwork, iwork, liwork, &
         lrpar, rpar, lipar, ipar, convergence_failure, ierr)

         type (star_info), pointer :: s

         integer, intent(in) :: nz, nvar, ldAF, neq
         logical, intent(in) :: sparse

         real(dp), pointer, dimension(:) :: AF1 ! =(ldAF, neq)
         real(qp), pointer, dimension(:) :: AF1_qp ! =(ldAF, neq)
         real(qp), pointer, dimension(:) :: B1_qp ! =(neq)
         real(dp), pointer, dimension(:) :: dx1, equ1, xscale1 
                           
         integer, intent(in) :: lrd, lid, which_decsol
         integer, intent(inout), pointer :: ipar_decsol(:) ! (lid)
         real(dp), intent(inout), pointer :: rpar_decsol(:) ! (lrd)

         ! controls         
         real(dp), intent(in) :: tol_correction_norm

         ! parameters for caller-supplied routines
         integer, intent(in) :: lrpar, lipar
         real(dp), intent(inout) :: rpar(:) ! (lrpar)
         integer, intent(inout) :: ipar(:) ! (lipar)

         ! work arrays
         integer, intent(in) :: lwork, liwork
         real(dp), intent(inout), target :: work(:) ! (lwork)
         integer, intent(inout), target :: iwork(:) ! (liwork)

         ! output
         logical, intent(out) :: convergence_failure
         integer, intent(out) :: ierr

         ! info saved in work arrays       
           
         real(dp), dimension(:,:), pointer :: dxsave, ddxsave, B, grad_f, B_init
         real(dp), dimension(:), pointer :: dxsave1, ddxsave1, B1, B_init1, grad_f1
         real(dp), dimension(:,:), pointer ::  rhs
         integer, dimension(:), pointer :: ipiv1
         real(dp), dimension(:,:), pointer :: &
            ddx, xgg, ddxd, ddxdd, xder, equsave
         
         integer, dimension(:), pointer :: ipiv_blk1

         real(dp), dimension(:,:), pointer :: A, Acopy
         real(dp), dimension(:), pointer :: A1, Acopy1
         real(dp), dimension(:), pointer :: lblk1, dblk1, ublk1
         real(dp), dimension(:), pointer :: lblkF1, dblkF1, ublkF1

!         real(qp), dimension(:,:), pointer :: A_qp, Acopy_qp
!         real(qp), dimension(:), pointer :: A1_qp, Acopy1_qp
!         real(qp), dimension(:), pointer :: lblk1_qp, dblk1_qp, ublk1_qp
         real(qp), dimension(:), pointer :: lblkF1_qp, dblkF1_qp, ublkF1_qp
!         real(qp), dimension(:), pointer :: B1_qp
         
         ! locals
         real(dp)  ::  &
            coeff, f, slope, residual_norm, max_residual, &
            corr_norm_min, resid_norm_min, correction_factor, &
            correction_norm, corr_norm_initial, max_correction, slope_extra, &
            tol_max_correction, tol_residual_norm, tol_abs_slope_min, tol_corr_resid_product, &
            min_corr_coeff, tol_max_residual, max_corr_min, max_resid_min
         integer :: iter, max_tries, ndiag, zone, idiag, tiny_corr_cnt, ldA, i, j, k, info, &
            last_jac_iter, max_iterations_for_jacobian, force_iter_value, mljac, mujac, &
            test_time0, test_time1, time0, time1, clock_rate, caller_id
         character (len=strlen) :: err_msg
         logical :: first_try, dbg_msg, passed_tol_tests, &
            overlay_AF, do_mtx_timing, do_test_timing, doing_extra
         integer, parameter :: num_tol_msgs = 15
         character (len=32) :: tol_msg(num_tol_msgs)
         character (len=64) :: message
         real(dp), pointer, dimension(:) :: p1_1, p1_2

         ! set pointers to 1D data
         real(dp), pointer, dimension(:,:) :: dx, equ, xscale ! (nvar,nz)       
         real(dp), pointer, dimension(:,:) :: AF ! (ldAF,neq)
         real(dp), pointer, dimension(:,:,:) :: ublk, dblk, lblk ! (nvar,nvar,nz)
         real(dp), dimension(:,:,:), pointer :: lblkF, dblkF, ublkF ! (nvar,nvar,nz)

         !real(qp), pointer, dimension(:,:) :: AF_qp ! (ldAF,neq)
         real(qp), pointer, dimension(:,:,:) :: ublk_qp, dblk_qp, lblk_qp ! (nvar,nvar,nz)
         !real(qp), dimension(:,:,:), pointer :: lblkF_qp, dblkF_qp, ublkF_qp ! (nvar,nvar,nz)
         
         include 'formats.dek'
                  
         dx(1:nvar,1:nz) => dx1(1:neq)
         equ(1:nvar,1:nz) => equ1(1:neq)
         xscale(1:nvar,1:nz) => xscale1(1:neq)
         AF(1:ldAF,1:neq) => AF1(1:ldAF*neq)
         if (associated(AF1_qp)) then
            !AF_qp(1:ldAF,1:neq) => AF1_qp(1:ldAF*neq)
            ublkF1_qp(1:nvar*neq) => AF1_qp(1:nvar*neq)
            dblkF1_qp(1:nvar*neq) => AF1_qp(1+nvar*neq:2*nvar*neq)
            lblkF1_qp(1:nvar*neq) => AF1_qp(1+2*nvar*neq:3*nvar*neq)
         end if
         
         do_mtx_timing = (work(r_mtx_time) /= 0)
         work(r_mtx_time) = 0

         tol_msg(1)  = 'avg corr'
         tol_msg(2)  = 'max corr '
         tol_msg(3)  = 'avg+max corr'
         tol_msg(4)  = 'avg resid'
         tol_msg(5)  = 'avg corr+resid'
         tol_msg(6)  = 'max corr, avg resid'
         tol_msg(7)  = 'avg+max corr, avg resid'
         tol_msg(8)  = 'max resid'
         tol_msg(9)  = 'avg corr, max resid'
         tol_msg(10) = 'max corr+resid'
         tol_msg(11) = 'avg+max corr, max resid'
         tol_msg(12) = 'avg+max resid'
         tol_msg(13) = 'avg corr, avg+max resid'
         tol_msg(14) = 'max corr, avg+max resid'
         tol_msg(15) = 'avg+max corr+resid'
 
         ierr = 0
         iter = 0
         s% newton_iter = iter

         call set_param_defaults
         dbg_msg = (iwork(i_debug) /= 0)
         tol_residual_norm = work(r_tol_residual_norm)
         tol_max_residual = work(r_tol_max_residual)
         tol_max_correction = work(r_tol_max_correction)
         tol_abs_slope_min = work(r_tol_abs_slope_min)
         tol_corr_resid_product = work(r_tol_corr_resid_product)
         min_corr_coeff = work(r_min_corr_coeff)
         
         caller_id = iwork(i_caller_id)
         
         mljac = 2*nvar-1
         mujac = mljac
         idiag = mljac+mujac+1
         ndiag = 3*nvar
         ldA = ndiag
         call pointers(ierr)
         if (ierr /= 0) return
      
         doing_extra = .false.
         passed_tol_tests = .false. ! goes true when pass the tests
         convergence_failure = .false. ! goes true when time to give up
         coeff = 1.
         xscale = 1.
  
         residual_norm=0
         max_residual=0
         corr_norm_min=1d99
         max_corr_min=1d99
         max_resid_min=1d99
         resid_norm_min=1d99
         correction_factor=0
         
         call set_xscale_info(s, nvar, nz, xscale, ierr)
         if (ierr /= 0) then
            if (dbg_msg) &
               write(*, *) 'newton failure: set_xscale_info returned ierr', ierr
            convergence_failure = .true.
            return
         end if
         
         call eval_equations( &
            iter, nvar, nz, dx, xscale, equ, lrpar, rpar, lipar, ipar, ierr)
         if (ierr /= 0) then
            if (dbg_msg) &
               write(*, *) 'newton failure: eval_equations returned ierr', ierr
            convergence_failure = .true.
            return
         end if
         
         call sizequ( &
            iter, nvar, nz, equ, residual_norm, max_residual, lrpar, rpar, lipar, ipar, ierr)
         if (ierr /= 0) then
            if (dbg_msg) &
               write(*, *) 'newton failure: sizequ returned ierr', ierr
            convergence_failure = .true.
            return
         end if

         first_try = .true.
         iter = 1
         s% newton_iter = iter
         max_tries = abs(iwork(i_max_tries))
         last_jac_iter = 0
         tiny_corr_cnt = 0
         
         if (iwork(i_max_iterations_for_jacobian) == 0) then
            max_iterations_for_jacobian = 1000000
         else
            max_iterations_for_jacobian = iwork(i_max_iterations_for_jacobian)
         end if

         do while (.not. passed_tol_tests)
            
            if (dbg_msg .and. first_try) write(*, *)
                  
            if (iter >= s% iter_for_resid_tol2) then
               if (iter < s% iter_for_resid_tol3) then
                  tol_residual_norm = s% tol_residual_norm2
                  tol_max_residual = s% tol_max_residual2
               else
                  tol_residual_norm = s% tol_residual_norm3
                  tol_max_residual = s% tol_max_residual3
               end if
            end if

            overlay_AF = (min_corr_coeff == 1)
            if (overlay_AF) then
               A1 => AF1
               A => AF
               ldA = ldAF
               ublk1 => ublkF1
               dblk1 => dblkF1
               lblk1 => lblkF1
               lblk(1:nvar,1:nvar,1:nz) => lblk1(1:nvar*neq)
               dblk(1:nvar,1:nvar,1:nz) => dblk1(1:nvar*neq)
               ublk(1:nvar,1:nvar,1:nz) => ublk1(1:nvar*neq)
            end if

            call setmatrix( &
               neq, dx, xscale, dxsave, ddxsave, lrpar, rpar, lipar, ipar, ierr)
            if (ierr /= 0) then
               call write_msg('setmatrix returned ierr /= 0')
               convergence_failure = .true.
               exit
            end if
            iwork(i_num_jacobians) = iwork(i_num_jacobians) + 1
            last_jac_iter = iter
            
            if (.not. solve_equ()) then ! either singular or horribly ill-conditioned
               write(err_msg, '(a, i5, 3x, a)') 'info', ierr, 'bad_matrix'
               call oops(err_msg)
               exit
            end if
            iwork(i_num_solves) = iwork(i_num_solves) + 1

            ! inform caller about the correction
            call inspectB(iter, nvar, nz, dx, B, xscale, lrpar, rpar, lipar, ipar, ierr)
            if (ierr /= 0) then
               call oops('inspectB returned ierr')
               exit
            end if

            ! compute size of scaled correction B
            call sizeB(iter, nvar, nz, B, xscale, max_correction, correction_norm,  &
                     lrpar, rpar, lipar, ipar, ierr)
            if (ierr /= 0) then
               call oops('correction rejected by sizeB')
               exit
            end if
            
            correction_norm = abs(correction_norm)
            max_correction = abs(max_correction)
            corr_norm_min = min(correction_norm, corr_norm_min)
            max_corr_min = min(max_correction, max_corr_min)

            if (is_bad_num(correction_norm) .or. is_bad_num(max_correction)) then 
               ! bad news -- bogus correction
               call oops('bad result from sizeB -- correction info either NaN or Inf')
               exit
            end if

            if ((correction_norm > work(r_corr_param_factor)*work(r_scale_correction_norm)) .and. &
                  (iwork(i_try_really_hard) == 0)) then
               call oops('avg corr too large')
               exit
            endif
         
            ! shrink the correction if it is too large
            correction_factor = 1d0
            
            if (correction_norm*correction_factor > work(r_scale_correction_norm)) then
               correction_factor = work(r_scale_correction_norm)/correction_norm
               !write(*,2) 'correction_norm force reduce correction_factor', iter, correction_factor
            end if
            
            if (max_correction*correction_factor > work(r_scale_max_correction)) then
               correction_factor = work(r_scale_max_correction)/max_correction
               !write(*,2) 'max_correction force reduce correction_factor', iter, correction_factor
            end if
            
            ! fix B if out of definition domain
            call Bdomain( &
               iter, nvar, nz, B, dx, xscale, correction_factor, &
               lrpar, rpar, lipar, ipar, ierr)
            if (ierr /= 0) then ! correction cannot be fixed
               call oops('correction rejected by Bdomain')
               exit
            end if
            
            !if (correction_factor < 1d0) write(*,2) 'Bdomain correction_factor', iter, correction_factor
            
            if (iter > s% newton_itermin_until_reduce_min_corr_coeff .and. &
                  min_corr_coeff == 1d0 .and. &
                  s% newton_reduced_min_corr_coeff < 1d0) then
               min_corr_coeff = s% newton_reduced_min_corr_coeff
               !write(*,2) 'reduce min_corr_coeff', iter, min_corr_coeff
            end if

            if (min_corr_coeff < 1d0) then
            
               ! compute gradient of f = equ<dot>jacobian
               ! NOTE: NOT jacobian<dot>equ
               if (which_decsol == lapack) then
                  call band_multiply_xa(neq, mljac, mujac, A1, ldA, equ1, grad_f1)
               else
                  call block_multiply_xa(nvar, nz, lblk1, dblk1, ublk1, equ1, grad_f1)      
               end if
                           
               slope = eval_slope(nvar, nz, grad_f, B)
               !if (is_bad_num(slope)) then
               !   call oops('bad slope value')
               !   exit
               !end if
               if (is_bad_num(slope) .or. slope > 0) then ! a very bad sign
                  !write(*,3) 'NEWTON JACOBIAN ERROR: slope must be < 0', &
                  !   s% model_number, iter, slope
                  !write(*,*)
                  slope = 0
                  min_corr_coeff = 1d0
                  !call oops('bad jacobian')
                  !exit
               end if
               
            else
            
               slope = 0

            end if
            
            call adjust_correction( &
               min_corr_coeff, correction_factor, grad_f1, f, slope, coeff, &
               err_msg, lrpar, rpar, lipar, ipar, ierr)
            if (ierr /= 0) then
               call oops(err_msg)
               exit
            end if
            s% newton_adjust_iter = 0
            
            ! coeff is factor by which adjust_correction rescaled the correction vector
            if (coeff > work(r_tiny_corr_factor)*min_corr_coeff .or. min_corr_coeff >= 1d0) then
               tiny_corr_cnt = 0
            else
               tiny_corr_cnt = tiny_corr_cnt + 1
            end if

            ! check the residuals for the equations
            call sizequ( &
               iter, nvar, nz, equ, residual_norm, max_residual, &
               lrpar, rpar, lipar, ipar, ierr)
            if (ierr /= 0) then
               call oops('sizequ returned ierr')
               exit
            end if
            if (is_bad_num(residual_norm)) then
               call oops('residual_norm is a a bad number (NaN or Infinity)')
               exit
            end if
            if (is_bad_num(max_residual)) then
               call oops('max_residual is a a bad number (NaN or Infinity)')
               exit
            end if
            
            residual_norm = abs(residual_norm)
            max_residual = abs(max_residual)
            resid_norm_min = min(residual_norm, resid_norm_min)
            max_resid_min = min(max_residual, max_resid_min)
            
            if (max_correction > tol_max_correction*coeff .or. &
                  max_residual > tol_max_residual*coeff) then
               passed_tol_tests = .false.
            else
               passed_tol_tests = &
                     (correction_norm <= tol_correction_norm*coeff .and.  &
                      residual_norm <= tol_residual_norm*coeff) &
                .or.       &
                     (abs(slope) <= tol_abs_slope_min .and.  &
                      correction_norm*residual_norm <= tol_corr_resid_product*coeff*coeff)
            end if
            
            if (.not. passed_tol_tests) then
               if (iter >= max_tries) then
                  if (dbg_msg) then
                     call get_message
                     message = trim(message) // ' -- give up'
                     call write_msg(message)
                  end if
                  convergence_failure = .true.; exit
               else if (iwork(i_try_really_hard) == 0) then
                  if (coeff < min(min_corr_coeff,correction_factor)) then
                     call oops('coeff too small')
                     exit
                  else if (correction_norm > tol_correction_norm*coeff &
                        .and. (correction_norm > work(r_corr_norm_jump_limit)*corr_norm_min) &
                        .and. (.not. first_try)) then
                     call oops('avg corrrection jumped')
                     exit
                  else if (residual_norm > tol_residual_norm*coeff &
                        .and. (residual_norm > work(r_resid_norm_jump_limit)*resid_norm_min) &
                        .and. (.not. first_try)) then
                     call oops('avg residual jumped')
                     exit
                  else if (max_correction > tol_max_correction*coeff &
                        .and. (max_correction > work(r_max_corr_jump_limit)*max_corr_min) &
                        .and. (.not. first_try)) then
                     call oops('max corrrection jumped')
                     exit
                  else if (residual_norm > tol_residual_norm*coeff &
                        .and. (max_residual > work(r_max_resid_jump_limit)*max_resid_min) &
                        .and. (.not. first_try)) then
                     call oops('max residual jumped')
                     exit
                  else if (tiny_corr_cnt >= iwork(i_tiny_min_corr_coeff) &
                        .and. min_corr_coeff < 1) then
                     call oops('tiny corrections')
                     exit
                  end if
               end if
            end if
            
            if (dbg_msg) then
               if (.not. passed_tol_tests) then
                  call get_message
                  call write_msg(message)
               else if (iter < iwork(i_itermin)) then     
                  call write_msg('iter < itermin')
               else
                  call write_msg('okay!')
                  !write(*,1) 'tol_max_correction, max_correction', tol_max_correction, max_correction
               end if
            end if
            
            if (passed_tol_tests .and. (iter+1 < max_tries)) then 
               ! about to declare victory... but may want to do another iteration
               force_iter_value = force_another_iteration( &
                                    iter, iwork(i_itermin), lrpar, rpar, lipar, ipar)
               if (force_iter_value > 0) then
                  passed_tol_tests = .false. ! force another
                  tiny_corr_cnt = 0 ! reset the counter
                  corr_norm_min = 1d99
                  resid_norm_min = 1d99
                  max_corr_min = 1d99
                  max_resid_min = 1d99
               else if (force_iter_value < 0) then ! failure
                  call oops('force iter')
                  exit
               end if
            end if
            
            if (s% use_other_newton_monitor .and. &
                  associated(s% other_newton_monitor)) then
               call s% other_newton_monitor( &
                  s% id, iter, passed_tol_tests, &
                  correction_norm, max_correction, &
                  residual_norm, max_residual, ierr)
               if (ierr /= 0) then
                  call oops('other_newton_monitor')
                  exit
               end if
            end if

            iter=iter+1
            s% newton_iter = iter
            first_try = .false.

         end do
         

         contains
         
         
         
         ! Set the host's `message` to the label describing which tolerance tests
         ! are currently failing.  The host's integer `i` is used as a bit mask:
         !    1 = avg correction, 2 = max correction, 4 = avg residual, 8 = max residual
         ! and indexes tol_msg directly.  If no test is failing, the label is
         ! 'out of tries'.
         subroutine get_message
            include 'formats.dek'
            i = merge(1, 0, correction_norm > tol_correction_norm*coeff) &
              + merge(2, 0, max_correction > tol_max_correction*coeff) &
              + merge(4, 0, residual_norm > tol_residual_norm*coeff) &
              + merge(8, 0, max_residual > tol_max_residual*coeff)
            if (i > 0) then
               message = tol_msg(i)
            else
               message = 'out of tries'
            end if
         end subroutine get_message

         
         ! Fill in defaults for solver controls that the caller left at 0 in the
         ! host's iwork and work arrays.  Entries that are already nonzero are kept.
         subroutine set_param_defaults
            integer, parameter :: num_int = 3, num_real = 12
            integer :: int_slot(num_int), int_default(num_int)
            integer :: real_slot(num_real)
            real(dp) :: real_default(num_real)
            integer :: m

            ! integer controls: slot in iwork and the value used when it is 0
            int_slot = [ i_itermin, i_max_tries, i_tiny_min_corr_coeff ]
            int_default = [ 2, 50, 25 ]

            ! real controls: slot in work and the value used when it is 0
            ! (the two tables are listed in the same order)
            real_slot = [ &
               r_tol_residual_norm, r_tol_max_residual, r_tol_max_correction, &
               r_scale_correction_norm, r_corr_param_factor, r_scale_max_correction, &
               r_corr_norm_jump_limit, r_max_corr_jump_limit, r_resid_norm_jump_limit, &
               r_max_resid_jump_limit, r_min_corr_coeff, r_tiny_corr_factor ]
            real_default = [ &
               1d99, 1d99, 1d99, &
               2d0, 10d0, 1d99, &
               1d99, 1d99, 1d99, &
               1d99, 1d-3, 2d0 ]

            do m = 1, num_int
               if (iwork(int_slot(m)) == 0) iwork(int_slot(m)) = int_default(m)
            end do

            do m = 1, num_real
               if (work(real_slot(m)) == 0) work(real_slot(m)) = real_default(m)
            end do

         end subroutine set_param_defaults
         
         
         ! Report that the solver is giving up: append ' -- give up' to msg, pass
         ! the result to write_msg, and set the host's convergence_failure flag.
         subroutine oops(msg)
            character (len=*), intent(in) :: msg
            ! fixed-length buffer, so the text handed to write_msg is blank padded
            ! (or truncated) to strlen characters
            character (len=strlen) :: report
            report = msg(1:len_trim(msg)) // ' -- give up'
            call write_msg(report)
            convergence_failure = .true.
         end subroutine oops


         subroutine adjust_correction( &
               min_corr_coeff_in, max_corr_coeff, grad_f, f, slope, coeff,  &
               err_msg, lrpar, rpar, lipar, ipar, ierr)
            ! Backtracking search for the scale factor applied to the Newton correction.
            !
            ! Starting from coeff = max_corr_coeff, the trial solution is built by
            ! apply_coeff (dx = dxsave + coeff*xscale*B, or dx + coeff*xscale*B when
            ! min_corr_coeff_in == 1), the equations are re-evaluated, and f = eval_f(...)
            ! is compared with f_target.  The step is shrunk until f <= f_target or
            ! the step has reached min_corr_coeff.
            !
            ! Exits:
            !   - return inside search_loop: dx holds the trial solution for coeff.
            !   - exit from search_loop (evaluation failure, bad f, or the 1000-iteration
            !     cap): dx and ddx are copied back from dxsave and ddxsave.
            !
            ! NOTE(review): when min_corr_coeff_in == 1 this routine does not refresh
            ! dxsave/ddxsave, yet the restore after search_loop still copies from them;
            ! confirm that callers keep those arrays current in that case.
            ! NOTE(review): iter is not declared here, so the loop counter is the
            ! host's iter variable -- confirm that is intended.
            real(dp), intent(in) :: min_corr_coeff_in
            real(dp), intent(in) :: max_corr_coeff
            ! grad_f is not referenced anywhere in this routine.
            real(dp), intent(in) :: grad_f(:) ! (neq) ! gradient df/ddx at xold
            real(dp), intent(out) :: f ! 1/2 fvec^2. minimize this.
            real(dp), intent(in) :: slope 
            real(dp), intent(out) :: coeff 

            ! the new correction is coeff*xscale*B
            ! with min_corr_coeff <= coeff <= max_corr_coeff
            ! if all goes well, the new x will give an improvement in f
            
            ! err_msg is only written on the failure paths below.
            character (len=*), intent(out) :: err_msg
            integer, intent(in) :: lrpar, lipar
            real(dp), intent(inout) :: rpar(:) ! (lrpar)
            integer, intent(inout) :: ipar(:) ! (lipar)
            integer, intent(out) :: ierr
      
            integer :: i, j, k, iter, k_max_corr, i_max_corr
            character (len=strlen) :: message
            logical :: first_time
            real(dp) :: a1, alam, alam2, alamin, a2, disc, f2, &
               rhs1, rhs2, temp, test, tmplam, max_corr, fold, min_corr_coeff
            real(dp) :: frac, f_target
            logical :: skip_eval_f, dbg_adjust
     
            real(dp), parameter :: alf = 1d-2 ! ensures sufficient decrease in f

            ! each backtracking step shrinks alam by at least this factor
            real(dp), parameter :: alam_factor = 0.2d0
            
            include 'formats.dek'
         
            ierr = 0                  
            coeff = 0
            dbg_adjust = .false.  !  (s% trace_k > 0 .and. s% trace_k <= nz)
            
            ! a minimum coefficient of exactly 1 means "take the full step":
            ! no baseline f is computed and no snapshot of dx/ddx is taken here.
            skip_eval_f = (min_corr_coeff_in == 1)
            if (skip_eval_f) then
               f = 0
            else
               ! snapshot the current solution so trial steps can restart from it
               do k=1,nz
                  do i=1,nvar
                     dxsave(i,k) = dx(i,k)
                     ddxsave(i,k) = ddx(i,k)
                  end do
               end do
               ! baseline value of f from the current contents of equ
               f = eval_f(nvar,nz,equ)
               if (is_bad_num(f)) then
                  ierr = -1
                  write(err_msg,*) 'adjust_correction failed in eval_f'
                  if (dbg_msg) write(*,*) &
                     'adjust_correction: eval_f(nvar,nz,equ)', eval_f(nvar,nz,equ)
                  return
               end if
            end if
            fold = f
            
            min_corr_coeff = min(min_corr_coeff_in, max_corr_coeff) ! make sure min <= max            
            alam = max_corr_coeff
            first_time = .true.
            f2 = 0
            alam2 = 0
            ! NOTE(review): k is printed below but may not have been assigned yet
            ! (it is only set by the snapshot loop above); debug output only.
            if (dbg_adjust) then
               write(*,4) 'max_corr_coeff', k, s% newton_iter, &
                  s% model_number, max_corr_coeff
               write(*,4) 'slope', k, s% newton_iter, &
                  s% model_number, slope
               write(*,4) 'f', k, s% newton_iter, &
                  s% model_number, f
            end if

         search_loop: do iter = 1, 1000
            
               ! trial coefficient, never below the allowed minimum
               coeff = max(min_corr_coeff, alam) 
               s% newton_adjust_iter = iter
               
               call apply_coeff(nvar, nz, dx, dxsave, B, xscale, coeff, skip_eval_f)               
               call eval_equations(iter, nvar, nz, dx, xscale, equ, lrpar, rpar, lipar, ipar, ierr)
               if (ierr /= 0) then
                  if (alam > min_corr_coeff .and. s% model_number == 1) then 
                     ! try again with smaller correction vector.
                     ! need this to rescue create pre-main-sequence model in some nasty cases.
                     alam = max(alam/10, min_corr_coeff)
                     ierr = 0
                     cycle
                  end if
                  write(err_msg,*) 'adjust_correction failed in eval_equations'
                  if (dbg_msg .or. dbg_adjust) &
                     write(*,2) 'adjust_correction: eval_equations returned ierr', ierr
                  exit search_loop
               end if
               
               ! full step was requested: accept it without testing f
               if (min_corr_coeff == 1) return

               if (dbg_adjust) then
                  do k=1,nz
                     do i=1,nvar
                        write(*,5) trim(s% nameofequ(i)), k, iter, s% newton_iter, &
                           s% model_number, equ(i,k)
                     end do
                  end do
               end if
            
               f = eval_f(nvar,nz,equ)
               if (is_bad_num(f)) then
                  ! bad f: retry with a 10x smaller step while there is room to shrink
                  if (alam > min_corr_coeff) then
                     alam = max(alam/10, min_corr_coeff)
                     ierr = 0
                     cycle
                  end if
                  err_msg = 'equ norm is NaN or other bad num'
                  ierr = -1
                  exit search_loop
               end if
               
               ! accept when f meets the looser of the two goals:
               ! half of the starting value, or the alf-scaled decrease predicted by slope
               f_target = max(fold/2, fold + alf*coeff*slope)
               if (f <= f_target) then
                  return ! sufficient decrease in f
               end if

               ! already at the smallest allowed step: keep this trial and return (ierr stays 0)
               if (alam <= min_corr_coeff) then
                  return ! time to give up
               end if

               ! reduce alam and try again
               if (first_time) then
                  ! minimizer of the quadratic fold + slope*lam + (f-fold-slope)*lam**2.
                  ! NOTE(review): this form matches a first trial at lam = 1; confirm it is
                  ! intended when max_corr_coeff /= 1.
                  tmplam = -slope/(2*(f-fold-slope))
                  first_time = .false.
                  if (dbg_adjust) then
                     write(*,5) 'slope', k, iter, s% newton_iter, &
                        s% model_number, slope
                     write(*,5) 'f', k, iter, s% newton_iter, &
                        s% model_number, f
                     write(*,5) 'fold', k, iter, s% newton_iter, &
                        s% model_number, fold
                     write(*,5) '2*(f-fold-slope)', k, iter, s% newton_iter, &
                        s% model_number, 2*(f-fold-slope)
                  end if
               else ! have two prior f values to work with
                  ! cubic model a1*lam**3 + a2*lam**2 + slope*lam + fold fitted through
                  ! (alam, f) and (alam2, f2)
                  rhs1 = f - fold - alam*slope
                  rhs2 = f2 - fold - alam2*slope
                  a1 = (rhs1/(alam*alam) - rhs2/(alam2*alam2))/(alam - alam2)
                  a2 = (-alam2*rhs1/(alam*alam) + alam*rhs2/(alam2*alam2))/(alam - alam2)
                  if (dbg_adjust) then
                     write(*,5) 'slope', k, iter, s% newton_iter, &
                        s% model_number, slope
                     write(*,5) 'f', k, iter, s% newton_iter, &
                        s% model_number, f
                     write(*,5) 'f2', k, iter, s% newton_iter, &
                        s% model_number, f2
                     write(*,5) 'fold', k, iter, s% newton_iter, &
                        s% model_number, fold
                     write(*,5) 'alam', k, iter, s% newton_iter, &
                        s% model_number, alam
                     write(*,5) 'alam2', k, iter, s% newton_iter, &
                        s% model_number, alam2
                     write(*,5) 'rhs1', k, iter, s% newton_iter, &
                        s% model_number, rhs1
                     write(*,5) 'rhs2', k, iter, s% newton_iter, &
                        s% model_number, rhs2
                     write(*,5) 'a1', k, iter, s% newton_iter, &
                        s% model_number, a1
                     write(*,5) 'a2', k, iter, s% newton_iter, &
                        s% model_number, a2
                  end if
                  if (a1 == 0) then
                     ! model degenerates to a quadratic
                     tmplam = -slope/(2*a2)
                  else
                     ! disc is a quarter of the discriminant of the model's derivative
                     ! 3*a1*lam**2 + 2*a2*lam + slope
                     disc = a2*a2-3*a1*slope
                     if (disc < 0) then
                        ! no real stationary point: fall back to the fixed shrink factor
                        tmplam = alam*alam_factor
                     else if (a2 <= 0) then
                        tmplam = (-a2+sqrt(disc))/(3*a1)
                     else
                        ! algebraically the same root, written without the -a2+sqrt(disc) subtraction
                        tmplam = -slope/(a2+sqrt(disc))
                     end if
                     if (dbg_adjust) then
                        write(*,5) 'disc', k, iter, s% newton_iter, &
                           s% model_number, disc
                     end if
                  end if
                  ! cap the new step at alam*alam_factor; together with the max() below
                  ! this branch therefore yields max(alam*alam_factor, min_corr_coeff)
                  if (tmplam > alam*alam_factor) tmplam = alam*alam_factor
               end if

               ! remember this trial as the "previous" point, then pick the next step
               alam2 = alam
               f2 = f
               alam = max(tmplam, alam*alam_factor, min_corr_coeff)
            
               if (dbg_adjust) then
                  write(*,5) 'tmplam', k, iter, s% newton_iter, &
                     s% model_number, tmplam
                  write(*,5) 'min_corr_coeff', k, iter, s% newton_iter, &
                     s% model_number, min_corr_coeff
                  write(*,5) 'alam_factor', k, iter, s% newton_iter, &
                     s% model_number, alam_factor
               end if
     
            end do search_loop

            ! reached only by leaving search_loop without a return:
            ! put the saved solution back
            do k=1,nz
               do i=1,nvar
                  dx(i,k) = dxsave(i,k)
                  ddx(i,k) = ddxsave(i,k)
               end do
            end do
         
         end subroutine adjust_correction
         
         
         subroutine apply_coeff(nvar, nz, dx, dxsave, B, xscale, coeff, just_use_dx)
            ! Add the scaled correction coeff*xscale*B to the solution offsets.
            !
            !   just_use_dx = .true.  : dx(i,k) = dx(i,k)     + coeff*xscale(i,k)*B(i,k)
            !   just_use_dx = .false. : dx(i,k) = dxsave(i,k) + coeff*xscale(i,k)*B(i,k)
            !
            ! Only elements (1:nvar, 1:nz) are read or written.
            integer, intent(in) :: nvar, nz
            ! dx is read on the in-place path, so it has to be intent(inout);
            ! with intent(out) its incoming value would be undefined on entry.
            real(dp), intent(inout), dimension(:,:) :: dx
            real(dp), intent(in), dimension(:,:) :: dxsave, B, xscale
            real(dp), intent(in) :: coeff
            logical, intent(in) :: just_use_dx
            integer :: i, k

            ! No special case for coeff == 1 is needed: multiplying by exactly 1
            ! reproduces xscale*B bit for bit.
            if (just_use_dx) then
               do k=1,nz
                  do i=1,nvar
                     dx(i,k) = dx(i,k) + coeff*xscale(i,k)*B(i,k)
                  end do
               end do
            else
               do k=1,nz
                  do i=1,nvar
                     dx(i,k) = dxsave(i,k) + coeff*xscale(i,k)*B(i,k)
                  end do
               end do
            end if
         end subroutine apply_coeff


         logical function solve_equ()    
            ! Solve the linear system for the Newton correction.
            ! B is loaded with -equ, the matrix is factored and (if factoring reported
            ! info == 0) solved, using the quad-precision routines when
            ! s% hydro_matrix_type is block_tridiag_quad_matrix_type.
            ! Returns .true. when info == 0 afterwards; otherwise B is zeroed and the
            ! result is .false.
            use star_utils, only: start_time, update_time
            real(dp) :: total_time
            logical :: use_quad
            
            include 'formats.dek'

            ! right-hand side is minus the residuals
            b(1:nvar,1:nz) = -equ(1:nvar,1:nz)
            
            info = 0
            
            if (s% doing_timing) then
               call start_time(s, time0, total_time)
            else if (do_mtx_timing) then
               call system_clock(time0, clock_rate)            
            end if
            
            use_quad = (s% hydro_matrix_type == block_tridiag_quad_matrix_type)
            if (use_quad) then
               call factor_mtx_qp
            else
               call factor_mtx
            end if
            ! only attempt the solve when the factorization succeeded
            if (info == 0) then
               if (use_quad) then
                  call solve_mtx_qp
               else
                  call solve_mtx
               end if
            end if
            
            if (s% doing_timing) then
               call update_time(s, time0, total_time, s% time_newton_matrix)
            else if (do_mtx_timing) then
               call system_clock(time1, clock_rate)
               work(r_mtx_time) = work(r_mtx_time) + dble(time1 - time0) / clock_rate
            end if

            solve_equ = (info == 0)
            if (.not. solve_equ) b(1:nvar,1:nz)=0
         
         end function solve_equ
         
         
         subroutine factor_mtx_qp
            ! Factor the block matrix held in lblkF1_qp/dblkF1_qp/ublkF1_qp with
            ! bcyclic_factor_qp; the status is left in the host variable info.
            ! Stops the program if the lapack solver was selected.
            use star_bcyclic_qp, only: bcyclic_factor_qp
            integer :: n_blk
            include 'formats.dek'
            
            if (which_decsol == lapack) stop 'no support for quad banded_matrix_type'
            
            if (.not. overlay_AF) then ! copy for use in adjust_correction
               n_blk = nvar*neq
               lblk1(1:n_blk) = lblkF1_qp(1:n_blk)
               dblk1(1:n_blk) = dblkF1_qp(1:n_blk)
               ublk1(1:n_blk) = ublkF1_qp(1:n_blk)
            end if

            info = 0
            call bcyclic_factor_qp( &
               s, lblkF1_qp, dblkF1_qp, ublkF1_qp, ipiv_blk1, B1_qp, &
               nvar, nz, sparse, iter, &
               lrd, rpar_decsol, lid, ipar_decsol, info)     
                      
         end subroutine factor_mtx_qp
         
         
         subroutine factor_mtx
            ! Factor the double-precision Jacobian; the status is left in the host
            ! variable info.
            !   which_decsol == lapack : copy A into AF (unless overlay_AF) and call
            !                            lapack_decsol with op code 0.
            !   otherwise              : copy lblk1/dblk1/ublk1 into the *F1 work
            !                            arrays (unless overlay_AF) and call bcyclic_factor.
            ! Optional trace output of the block matrix is controlled by the
            ! s% trace_newton_bcyclic_* settings.
            use star_bcyclic, only: bcyclic_factor
            use mtx_lib, only: lapack_decsol
            ! i and j are declared here so the copy loop below cannot overwrite
            ! the host's variables of the same name.
            integer :: i, j, k, ldafb
            logical :: in_trace_window
            include 'formats.dek'
            
            ! model-number and iteration window shared by both trace outputs
            in_trace_window = &
                s% model_number >= s% trace_newton_bcyclic_steplo .and. &
                s% model_number <= s% trace_newton_bcyclic_stephi .and. &
                iter >= s% trace_newton_bcyclic_iterlo .and. &
                iter <= s% trace_newton_bcyclic_iterhi
            
            ! NOTE(review): output_mtx prints the *F arrays, which are only filled
            ! from lblk1/dblk1/ublk1 further down when overlay_AF is false -- confirm
            ! this "input" dump shows the intended data in that case.
            if (s% trace_newton_bcyclic_matrix_input .and. in_trace_window) then
               write(*,3) 'newton call bcyclic_factor', iter, s% model_number
               call output_mtx
            end if
            
            if (which_decsol == lapack) then
               if (.not. overlay_AF) then
                  ! rows of A are stored mljac rows lower in AF
                  ! (presumably the extra rows LAPACK band factorization needs -- confirm)
                  do j=1,neq
                     do i=1,ldA
                        AF(mljac+i,j) = A(i,j)
                     end do
                  end do
               end if                  
               ldafb=2*mljac+mujac+1
               call lapack_decsol(0, neq, ldafb, AF1, mljac, mujac, B1, ipiv_blk1, &
                     lrd, rpar_decsol, lid, ipar_decsol, info)
               return
            end if
            
            if (.not. overlay_AF) then
               do k = 1,nvar*neq
                  lblkF1(k) = lblk1(k)
                  dblkF1(k) = dblk1(k)
                  ublkF1(k) = ublk1(k)
               end do
            end if
            
            info = 0
            
            call bcyclic_factor( &
               s, lblkF1, dblkF1, AF1_qp,  ublkF1, ipiv_blk1, B1, &
               nvar, nz, sparse, iter, &
               lrd, rpar_decsol, lid, ipar_decsol, info)
               
            if (s% trace_newton_bcyclic_matrix_output .and. in_trace_window) then
               write(*,3) 'newton after bcyclic_factor', iter, s% model_number
               call output_mtx
            end if
              
         end subroutine factor_mtx
         
         
         subroutine output_mtx
            ! Print every nonzero entry of lblkF, dblkF and ublkF inside the zone
            ! range and index range selected by the s% trace_newton_bcyclic_* limits.
            ! Each line is tagged 'em1', 'e00' or 'ep1' followed by the equation and
            ! variable names.  The same index range is used for rows and columns.
            integer :: row, col, zone, zone_lo, zone_hi, idx_lo, idx_hi
            include 'formats.dek'

            ! an upper limit that is <= 0 or not below the maximum means "no limit"
            zone_lo = max(1,s% trace_newton_bcyclic_nzlo)
            if (s% trace_newton_bcyclic_nzhi > 0 .and. &
                s% trace_newton_bcyclic_nzhi < nz) then
               zone_hi = s% trace_newton_bcyclic_nzhi
            else
               zone_hi = nz
            end if

            idx_lo = max(1, s% trace_newton_bcyclic_jlo)
            if (s% trace_newton_bcyclic_jhi > 0 .and. &
                s% trace_newton_bcyclic_jhi < nvar) then
               idx_hi = s% trace_newton_bcyclic_jhi
            else
               idx_hi = nvar
            end if

            do zone = zone_lo, zone_hi
               do col = idx_lo, idx_hi
                  do row = idx_lo, idx_hi
                     if (lblkF(row,col,zone) /= 0d0) then
                        write(*,6) 'em1 ' // &
                           trim(s% nameofequ(row)) // ' ' // trim(s% nameofvar(col)), &
                           row, col, zone, s% newton_iter, s% model_number, &
                           lblkF(row,col,zone)
                     end if
                     if (dblkF(row,col,zone) /= 0d0) then
                        write(*,6) 'e00 ' // &
                           trim(s% nameofequ(row)) // ' ' // trim(s% nameofvar(col)), &
                           row, col, zone, s% newton_iter, s% model_number, &
                           dblkF(row,col,zone)
                     end if
                     if (ublkF(row,col,zone) /= 0d0) then
                        write(*,6) 'ep1 ' // &
                           trim(s% nameofequ(row)) // ' ' // trim(s% nameofvar(col)), &
                           row, col, zone, s% newton_iter, s% model_number, &
                           ublkF(row,col,zone)
                     end if
                  end do
               end do
            end do
         end subroutine output_mtx
         
         
         subroutine solve_mtx_qp
            ! Solve with the factors produced by factor_mtx_qp.
            ! The right-hand side B1 is copied into B1_qp, bcyclic_solve_qp is called
            ! on it, and the result is copied back into B1.  Status is left in info.
            use star_bcyclic_qp, only: bcyclic_solve_qp
            include 'formats.dek'

            B1_qp(1:neq) = B1(1:neq)

            call bcyclic_solve_qp( &
               s, lblkF1_qp, dblkF1_qp, ublkF1_qp, ipiv_blk1, B1_qp, &
               nvar, nz, sparse, &
               lrd, rpar_decsol, lid, ipar_decsol, info)

            B1(1:neq) = B1_qp(1:neq)
         end subroutine solve_mtx_qp
         
         
         subroutine solve_mtx
            ! Solve the factored double-precision system; B1 (viewed as B) holds the
            ! right-hand side on entry and is passed to the solver for the solution.
            !   which_decsol == lapack : lapack_decsol is called with op code 1 and
            !                            then op code 2 (solve and release, judging by
            !                            the info_solve/info_dealloc names); info is set
            !                            to -1 if either call reports a problem.
            !   otherwise              : bcyclic_solve, with its status left in info.
            ! Optional trace output of B before and after the bcyclic solve is
            ! controlled by the s% trace_newton_bcyclic_* settings.
            use star_bcyclic, only: bcyclic_solve
            use mtx_lib, only: lapack_decsol
            ! j and k are declared here so the trace loops cannot overwrite the
            ! host's variables of the same name.
            integer :: j, k, ldafb, info_solve, info_dealloc
            logical :: in_trace_window
            include 'formats.dek'
            
            ! model-number and iteration window shared by both trace outputs
            in_trace_window = &
                s% model_number >= s% trace_newton_bcyclic_steplo .and. &
                s% model_number <= s% trace_newton_bcyclic_stephi .and. &
                iter >= s% trace_newton_bcyclic_iterlo .and. &
                iter <= s% trace_newton_bcyclic_iterhi
            
            if (s% trace_newton_bcyclic_solve_input .and. in_trace_window) then
               write(*,3) 'newton call bcyclic_solve', iter, s% model_number
               do k=max(1,s% trace_newton_bcyclic_nzlo),min(nz,s% trace_newton_bcyclic_nzhi)
                  do j=max(1,s% trace_newton_bcyclic_jlo),min(nvar, s% trace_newton_bcyclic_jhi)
                     ! every entry is printed, zeros included
                     write(*,5) 'B ' // trim(s% nameofvar(j)), &
                        j, k, iter, s% model_number, B(j,k)
                  end do
               end do
            end if
            
            if (which_decsol == lapack) then
               ldafb=2*mljac+mujac+1
               call lapack_decsol( &
                  1, neq, ldafb, AF1, mljac, mujac, B1, ipiv_blk1,  &
                  lrd, rpar_decsol, lid, ipar_decsol, info_solve)     
               call lapack_decsol( &
                  2, neq, ldafb, AF1, mljac, mujac, B1, ipiv_blk1,  &
                  lrd, rpar_decsol, lid, ipar_decsol, info_dealloc)
               if (info_solve /= 0 .or. info_dealloc /= 0) info = -1
               return
            end if

            call bcyclic_solve( &
               s, lblkF1, dblkF1, AF1_qp, ublkF1, ipiv_blk1, B1, &
               nvar, nz, sparse, &
               lrd, rpar_decsol, lid, ipar_decsol, info)
               
            if (s% trace_newton_bcyclic_solve_output .and. in_trace_window) then
               write(*,3) 'newton after bcyclic_solve', iter, s% model_number
               do k=max(1,s% trace_newton_bcyclic_nzlo),min(nz,s% trace_newton_bcyclic_nzhi)
                  do j=max(1,s% trace_newton_bcyclic_jlo),min(nvar, s% trace_newton_bcyclic_jhi)
                     ! every entry is printed, zeros included
                     write(*,5) 'X ' // trim(s% nameofvar(j)), &
                        j, k, iter, s% model_number, B(j,k)
                  end do
               end do
            end if
            
         end subroutine solve_mtx
         
         
         logical function do_enter_setmatrix( &
                  neq, dx, xscale, lrpar, rpar, lipar, ipar, ierr)
            ! Wrapper around enter_setmatrix.
            ! The flag handed to enter_setmatrix starts out .true.; whatever value it
            ! holds after the call is the function result, i.e. whether the solver
            ! itself still has to evaluate the jacobian.
            implicit none
            integer, intent(in) :: neq
            real(dp), pointer, dimension(:,:) :: dx, xscale
            integer, intent(in) :: lrpar, lipar
            real(dp), intent(inout) :: rpar(:) ! (lrpar)
            integer, intent(inout) :: ipar(:) ! (lipar)
            integer, intent(out) :: ierr
            include 'formats.dek'

            do_enter_setmatrix = .true.
            call enter_setmatrix(iter,  &
                  nvar, nz, neq, dx, xscale, xder, do_enter_setmatrix,  &
                  size(A,dim=1), A1, AF1_qp, idiag, lrpar, rpar, lipar, ipar, ierr)
         end function do_enter_setmatrix


         subroutine setmatrix( &
               neq, dx, xscale, dxsave, ddxsave, lrpar, rpar, lipar, ipar, ierr)
            ! create jacobian by using numerical differences for partial derivatives
            !
            ! Outline:
            !   1. do_enter_setmatrix decides whether this routine has to build the
            !      jacobian at all; if not (or on error) return immediately.
            !   2. Evaluate the equations once and save dx, equ and the "xtra"
            !      quantities (opacity, lnP, lnT) as the unperturbed reference.
            !   3. For each variable, perturb it by xder in every third zone
            !      (offsets 0, 1, 2), re-evaluate, and turn the changes in equ at
            !      zones k-1, k, k+1 into ep1, e00 and em1 entries.  The same
            !      differences are stored for the xtra quantities.
            !   4. Hand the numeric xtra partials to exit_setmatrix.
            ! Any evaluation failure, a perturbation too small to change dx, or a
            ! bad number stops the program.
            use star_utils, only: e00, em1, ep1
            integer, intent(in) :: neq
            real(dp), pointer, dimension(:,:) :: dx, xscale, dxsave, ddxsave
            integer, intent(in) :: lrpar, lipar
            real(dp), intent(inout) :: rpar(:) ! (lrpar)
            integer, intent(inout) :: ipar(:) ! (lipar)
            integer, intent(out) :: ierr

            ! i and j are declared here so this routine cannot overwrite the
            ! host's variables of the same name.
            integer :: i, j
            integer :: k, i_var, i_equ, k_off, cnt_00, cnt_m1, cnt_p1
            real(dp), dimension(:,:), pointer :: &
               save_equ, save_dx, save_xtra, xtra
            real(dp) :: dvar, dequ, dxtra
            logical :: need_solver_to_eval_jacobian
            
            integer, parameter :: num_xtra = 3
               ! opacity, lnP, lnT
            !integer, parameter :: num_xtra = 7 
               ! opacity, lnP, lnT, Qvisc, eps_visc, dvdt_visc, eta_visc
            character (len=32) :: xtra_names(7) ! allocate for max num
            real(dp), pointer, dimension(:,:,:) :: & ! (num_xtra,nvar,nz)
               xtra_lblk_analytic, xtra_dblk_analytic, xtra_ublk_analytic, &
               xtra_lblk_numeric, xtra_dblk_numeric, xtra_ublk_numeric
            
            include 'formats.dek'

            ierr = 0
            
            need_solver_to_eval_jacobian = do_enter_setmatrix( &
                  neq, dx, xscale, lrpar, rpar, lipar, ipar, ierr)     
            if (ierr /= 0) return
            if (.not. need_solver_to_eval_jacobian) return

            ! xder has been set by enter_setmatrix
            ! and pointers have been set for e00, em1, and ep1.
            
            ! mimic call on eval_partials by making calls like the following:

               ! call e00(s, xscale, i_equ, i_var, k, nvar, dequ_dvar)
               ! call em1(s, xscale, i_equ, i_var, k, nvar, dequ_dvar)
               ! call ep1(s, xscale, i_equ, i_var, k, nvar, dequ_dvar)
            
            write(*,3) '1st call eval_equations'       
            call eval_equations( &
               iter, nvar, nz, dx, xscale, equ, lrpar, rpar, lipar, ipar, ierr)
            if (ierr /= 0) then
               write(*,3) '1st call eval_equations failed'       
               stop 'setmatrix'
            end if
            
            ! the *_analytic arrays are only allocated here; they are passed
            ! to exit_setmatrix without being assigned in this routine
            allocate( &
               save_dx(nvar,nz), &
               save_equ(nvar,nz), &
               xtra(num_xtra,nz), save_xtra(num_xtra,nz), &
               xtra_lblk_analytic(num_xtra,nvar,nz), &
               xtra_dblk_analytic(num_xtra,nvar,nz), &
               xtra_ublk_analytic(num_xtra,nvar,nz), &
               xtra_lblk_numeric(num_xtra,nvar,nz), &
               xtra_dblk_numeric(num_xtra,nvar,nz), &
               xtra_ublk_numeric(num_xtra,nvar,nz))
            
            ! unperturbed reference state
            do k=1,nz
               do j=1,nvar
                  save_dx(j,k) = dx(j,k)
                  save_equ(j,k) = equ(j,k)
               end do
            end do
            
            xtra_names(1) = 'opacity'
            xtra_names(2) = 'lnP'
            xtra_names(3) = 'lnT'
            if (num_xtra == 7) then
               xtra_names(4) = 'Qvisc'
               xtra_names(5) = 'eps_visc'
               xtra_names(6) = 'dvdt_visc'
               xtra_names(7) = 'eta_visc'
            end if
            call set_xtras(save_xtra,num_xtra)

            xtra_lblk_numeric = 0
            xtra_dblk_numeric = 0
            xtra_ublk_numeric = 0
            
            cnt_00 = 0
            cnt_m1 = 0
            cnt_p1 = 0
            do i_var = 1, nvar               
               do k_off = 0, 2    
                          
                  ! perturb variable i_var in zones 1+k_off, 4+k_off, ...
                  ! Zones three apart share no neighbour, so the changes seen at
                  ! k-1, k, k+1 below are attributed to zone k alone.
                  do k = 1+k_off, nz, 3
                     dx(i_var,k) = save_dx(i_var,k) + xder(i_var,k)
                     if (dx(i_var,k) == save_dx(i_var,k)) then
                        write(*,3) 'xder too small', i_var, k, xder(i_var,k)
                        stop 'setmatrix'
                     end if
                  end do
                  
                  call eval_equations( &
                     iter, nvar, nz, dx, xscale, equ, lrpar, rpar, lipar, ipar, ierr)
                  if (ierr /= 0) then
                     !exit               
                     write(*,3) 'call eval_equations failed'       
                     stop 'setmatrix'
                  end if
                  
                  call set_xtras(xtra,num_xtra)
                  
                  do k = 1+k_off, nz, 3
                     dvar = xder(i_var,k) 
                                         
                     do i_equ = 1, nvar                     
                        ! e00(i,j,k) is partial of equ(i,k) wrt var(j,k)
                        dequ = equ(i_equ,k) - save_equ(i_equ,k)
                        if (dequ /= 0) then
                           call e00(s, xscale, i_equ, i_var, k, nvar, dequ/dvar)  
                           cnt_00 = cnt_00+1
                        end if                   
                        if (k > 1) then
                           ! ep1(i,j,k) is partial of equ(i,k) wrt var(j,k+1)
                           ! ep1(i,j,k-1) is partial of equ(i,k-1) wrt var(j,k)
                           dequ = equ(i_equ,k-1) - save_equ(i_equ,k-1)
                           if (dequ /= 0) then
                              call ep1(s, xscale, i_equ, i_var, k-1, nvar, dequ/dvar)
                              cnt_p1 = cnt_p1+1
                           end if
                        end if                     
                        if (k < nz) then
                           ! em1(i,j,k) is partial of equ(i,k) wrt var(j,k-1)
                           ! em1(i,j,k+1) is partial of equ(i,k+1) wrt var(j,k)
                           dequ = equ(i_equ,k+1) - save_equ(i_equ,k+1)
                           if (dequ /= 0) then
                              call em1(s, xscale, i_equ, i_var, k+1, nvar, dequ/dvar)
                              cnt_m1 = cnt_m1+1
                           end if
                        end if                     
                     end do                     

                     ! same finite differences for the xtra quantities
                     do i = 1, num_xtra
                        ! dblk(i,j,k) = df(i,k)/dx(j,k)
                        j = i_var
                        dxtra = xtra(i,k) - save_xtra(i,k)
                        if (is_bad_num(dxtra)) then
                           write(*,2) 'dblk dxtra', k, dxtra
                           write(*,2) 'xtra(i,k)', k, xtra(i,k)
                           write(*,2) 'save_xtra(i,k)', k, save_xtra(i,k)
                           stop
                        end if
                        xtra_dblk_numeric(i,j,k) = dxtra/dvar                        
                        if (k < nz) then
                           ! lblk(i,j,k) = df(i,k)/dx(j,k-1)
                           ! lblk(i,j,k+1) = df(i,k+1)/dx(j,k)
                           dxtra = xtra(i,k+1) - save_xtra(i,k+1)
                           if (is_bad_num(dxtra)) then
                              write(*,2) 'lblk dxtra', k, dxtra
                              stop
                           end if
                           xtra_lblk_numeric(i,j,k+1) = dxtra/dvar
                        else
                           ! last zone: there is no k+1, so clear the lblk entry of zone 1
                           xtra_lblk_numeric(i,j,1) = 0
                        end if                                             
                        if (k > 1) then
                           ! ublk(i,j,k) = df(i,k)/dx(j,k+1)
                           ! ublk(i,j,k-1) = df(i,k-1)/dx(j,k)
                           dxtra = xtra(i,k-1) - save_xtra(i,k-1)
                           if (is_bad_num(dxtra)) then
                              write(*,2) 'ublk dxtra', k, dxtra
                              stop
                           end if
                           xtra_ublk_numeric(i,j,k-1) = dxtra/dvar
                        else
                           ! first zone: there is no k-1, so clear the ublk entry of zone nz
                           xtra_ublk_numeric(i,j,nz) = 0
                        end if                                             
                     end do                     

                  end do
                  
                  ! undo the perturbation before moving to the next offset
                  do k = 1+k_off, nz, 3
                     dx(i_var,k) = save_dx(i_var,k)
                  end do    
                                   
               end do            
            end do
            !write(*,2) 'cnt_00', cnt_00
            !write(*,2) 'cnt_m1', cnt_m1
            !write(*,2) 'cnt_p1', cnt_p1
            !write(*,2) 'total', cnt_00 + cnt_m1 + cnt_p1
            
            if (ierr == 0) then
               write(*,*) 'call exit_setmatrix'
               call exit_setmatrix(iter, nvar, nz, neq,  &
                  dx, xscale, lrpar, rpar, lipar, ipar, num_xtra, xtra_names, &
                  xtra_lblk_analytic, xtra_dblk_analytic, xtra_ublk_analytic, &
                  xtra_lblk_numeric, xtra_dblk_numeric, xtra_ublk_numeric, &
                  ierr)
            end if
            
            deallocate(save_dx, save_equ, save_xtra, xtra, &
               xtra_lblk_analytic, xtra_dblk_analytic, xtra_ublk_analytic, &
               xtra_lblk_numeric, xtra_dblk_numeric, xtra_ublk_numeric)
               
         end subroutine setmatrix

            
         subroutine set_xtras(x,num_xtra)
            ! Fill x(row,1:nz) from the star model s, one row per quantity:
            !   row 1 opacity, row 2 lnP, row 3 lnT,
            !   and, only when num_xtra == 7,
            !   row 4 Qvisc, row 5 eps_visc, row 6 dvdt_visc (zone 1 set to 0), row 7 eta_visc.
            ! Rows are handled in order; the first bad number found is reported
            ! and the program stops.
            real(dp) :: x(:,:)
            integer, intent(in) :: num_xtra
            integer :: k
            include 'formats'

            x(1,1:nz) = s% opacity(1:nz)
            do k=1,nz
               if (.not. is_bad_num(x(1,k))) cycle
               write(*,2) 'exit_setmatrix x(1,k)', k, x(1,k)
               stop
            end do

            x(2,1:nz) = s% lnP(1:nz)
            do k=1,nz
               if (.not. is_bad_num(x(2,k))) cycle
               write(*,2) 'exit_setmatrix x(2,k)', k, x(2,k)
               stop
            end do

            x(3,1:nz) = s% lnT(1:nz)
            do k=1,nz
               if (.not. is_bad_num(x(3,k))) cycle
               write(*,2) 'exit_setmatrix x(3,k)', k, x(3,k)
               stop
            end do

            if (num_xtra /= 7) return

            x(4,1:nz) = s% Qvisc(1:nz)
            do k=1,nz
               if (.not. is_bad_num(x(4,k))) cycle
               write(*,2) 'exit_setmatrix x(4,k)', k, x(4,k)
               stop
            end do

            x(5,1:nz) = s% eps_visc(1:nz)
            do k=1,nz
               if (.not. is_bad_num(x(5,k))) cycle
               write(*,2) 'exit_setmatrix x(5,k)', k, x(5,k)
               stop
            end do

            x(6,1) = 0
            x(6,2:nz) = s% dvdt_visc(2:nz)
            do k=2,nz ! no values for dvdt_visc(1)
               if (.not. is_bad_num(x(6,k))) cycle
               write(*,2) 'exit_setmatrix x(6,k)', k, x(6,k)
               stop
            end do

            x(7,1:nz) = s% eta_visc(1:nz)
            do k=1,nz
               if (.not. is_bad_num(x(7,k))) cycle
               write(*,2) 'exit_setmatrix x(7,k)', k, x(7,k)
               stop
            end do
         end subroutine set_xtras
      
      
         subroutine write_msg(msg)
            ! Print a one-line solver status report followed by msg.
            ! Does nothing unless the host flag dbg_msg is set.
            ! The line shows the model number, iteration, coeff, slope, f, the
            ! average and maximum residual and correction, and log10 of the
            ! timestep in years (the timestep is floored at 1d-99 before the log).
            !
            ! The energy-sum diagnostics that used to follow the unconditional
            ! return could never execute and have been removed, together with
            ! the import and locals only they used.
            real(dp), parameter :: secyer = 3.1558149984d7 ! seconds per year
            character(*), intent(in) :: msg
            
            include 'formats'
            if (.not. dbg_msg) return
            
  111       format(i6, 2x, i3, 2x, a, f8.4, 6(2x, a, 1x, e10.3), 2x, a, f6.2, 2x, a)            
            write(*,111) &
               iwork(i_model_number), iter, &
               'coeff', coeff,  &
               'slope', slope,  &
               'f', f, &
               'avg resid', residual_norm,  &
               'max resid', max_residual,  &
               'avg corr', correction_norm,  &
               'max corr', max_correction,  &
               'lg dt/yr', log10_cr(max(1d-99,work(r_dt)/secyer)),  &
               trim(msg)            

         end subroutine write_msg


         subroutine pointers(ierr)
            ! Carve the caller-supplied `work` and `iwork` arrays into the
            ! named pointer arrays used by the enclosing solver, and set up
            ! the block views of the matrix storage (A1 and AF1).
            !
            ! ierr is 0 on success and -1 when `work` or `iwork` is too small.
            ! The sizes are verified BEFORE any pointer is associated, so no
            ! out-of-bounds section of work/iwork is ever formed; on failure
            ! the routine returns without touching any pointer.
            !
            ! All names other than ierr and the locals below (work, iwork,
            ! lwork, liwork, neq, ndiag, nvar, nz, ldAF, AF1 and the pointer
            ! arrays themselves) come from the enclosing scope.
            integer, intent(out) :: ierr
      
            integer :: i, lwork_needed, liwork_needed

            ierr = 0         

            ! work layout after the parameter slots: A1 (ndiag*neq) followed
            ! by eight neq-length vectors (dxsave, ddxsave, B, B_init,
            ! grad_f, rhs, xder, ddx).
            lwork_needed = num_work_params + ndiag*neq + 8*neq
            if (lwork_needed > lwork) then
               ierr = -1
               write(*,  &
                  '(a, i6, a, 99i6)') 'newton: lwork is too small.  must be at least', lwork_needed, &
                  '   but is only ', lwork, neq, ndiag, ldAF
               return
            end if

            ! iwork layout after the parameter slots: ipiv1 (neq).
            liwork_needed = num_iwork_params + neq
            if (liwork_needed > liwork) then
               ierr = -1
               write(*, '(a, i6, a, i6)')  &
                        'newton: liwork is too small.  must be at least', liwork_needed,  &
                        '   but is only ', liwork
               return
            end if

            ! i is the running offset of the next free element of work.
            i = num_work_params+1
            
            A1(1:ndiag*neq) => work(i:i+ndiag*neq-1); i = i+ndiag*neq
            
            dxsave1(1:neq) => work(i:i+neq-1); i = i+neq
            dxsave(1:nvar,1:nz) => dxsave1(1:neq)
            
            ddxsave1(1:neq) => work(i:i+neq-1); i = i+neq
            ddxsave(1:nvar,1:nz) => ddxsave1(1:neq)
            
            B1 => work(i:i+neq-1); i = i+neq
            B(1:nvar,1:nz) => B1(1:neq)
            
            B_init1 => work(i:i+neq-1); i = i+neq
            B_init(1:nvar,1:nz) => B_init1(1:neq)
            
            grad_f1(1:neq) => work(i:i+neq-1); i = i+neq
            grad_f(1:nvar,1:nz) => grad_f1(1:neq)
            
            rhs(1:nvar,1:nz) => work(i:i+neq-1); i = i+neq
            
            xder(1:nvar,1:nz) => work(i:i+neq-1); i = i+neq
            
            ddx(1:nvar,1:nz) => work(i:i+neq-1); i = i+neq
         
            ! Integer workspace: pivot indices; the block solver view shares
            ! the same storage.
            i = num_iwork_params+1
            ipiv1(1:neq) => iwork(i:i+neq-1); i = i+neq
     
            ipiv_blk1(1:neq) => ipiv1(1:neq)

            ! Banded view of the matrix storage; Acopy aliases A.
            A(1:ndiag,1:neq) => A1(1:ndiag*neq)
            Acopy1 => A1
            Acopy => A
            
            ! Block-tridiagonal view of the same storage: upper, diagonal and
            ! lower blocks occupy three consecutive nvar*neq segments of A1.
            ! NOTE(review): this relies on ndiag*neq >= 3*nvar*neq; ndiag is
            ! set outside this routine -- confirm.
            ublk1(1:nvar*neq) => A1(1:nvar*neq)
            dblk1(1:nvar*neq) => A1(1+nvar*neq:2*nvar*neq)
            lblk1(1:nvar*neq) => A1(1+2*nvar*neq:3*nvar*neq)
            
            lblk(1:nvar,1:nvar,1:nz) => lblk1(1:nvar*neq)
            dblk(1:nvar,1:nvar,1:nz) => dblk1(1:nvar*neq)
            ublk(1:nvar,1:nvar,1:nz) => ublk1(1:nvar*neq)

            ! Same three-segment block layout applied to AF1.
            ublkF1(1:nvar*neq) => AF1(1:nvar*neq)
            dblkF1(1:nvar*neq) => AF1(1+nvar*neq:2*nvar*neq)
            lblkF1(1:nvar*neq) => AF1(1+2*nvar*neq:3*nvar*neq)

            lblkF(1:nvar,1:nvar,1:nz) => lblkF1(1:nvar*neq)
            dblkF(1:nvar,1:nvar,1:nz) => dblkF1(1:nvar*neq)
            ublkF(1:nvar,1:nvar,1:nz) => ublkF1(1:nvar*neq)

         end subroutine pointers
         
         
         real(dp) function eval_slope(nvar, nz, grad_f, B)
            ! Return the sum over variables 1..nvar and zones 1..nz of
            ! grad_f(i,k)*B(i,k), accumulated one variable row at a time.
            !   nvar, nz  - number of rows and columns to include
            !   grad_f, B - rank-2 arrays indexed (variable, zone)
            integer, intent(in) :: nvar, nz
            real(dp), intent(in), dimension(:,:) :: grad_f, B
            integer :: ivar
            real(dp) :: total
            total = 0d0
            ! Row-by-row dot products; the order is kept so the rounding of
            ! the accumulated result is unchanged.
            do ivar = 1, nvar
               total = total + dot_product(grad_f(ivar,1:nz), B(ivar,1:nz))
            end do
            eval_slope = total
         end function eval_slope
         
         
         real(dp) function eval_f(nvar, nz, equ)
            ! Return half the sum of squares of equ(1:nvar,1:nz).
            !   nvar, nz - number of rows and columns to include
            !   equ      - rank-2 array indexed (variable, zone)
            integer, intent(in) :: nvar, nz
            real(dp), intent(in), dimension(:,:) :: equ
            integer :: izone, ivar
            real(dp) :: sum_sq
            include 'formats.dek'
            sum_sq = 0d0
            ! Zone-major traversal with the variable index innermost,
            ! matching the storage order of equ.
            do izone = 1, nz
               do ivar = 1, nvar
                  sum_sq = sum_sq + equ(ivar,izone)*equ(ivar,izone)
               end do
            end do
            eval_f = sum_sq/2
            ! debugging aid, left disabled:
            !write(*,1) 'do_newton: eval_f', eval_f
         end function eval_f


      end subroutine do_newton
      
   
      subroutine get_newton_work_sizes(nvar, nz, lwork, liwork, ierr)
         ! Report the lengths of the real and integer workspaces for a
         ! problem with nvar variables in each of nz zones.
         !   lwork  = num_work_params  + (nvar*nz)*(3*nvar + 9)
         !   liwork = num_iwork_params + nvar*nz
         ! ierr is always returned as 0.
         integer, intent(in) :: nvar, nz
         integer, intent(out) :: lwork, liwork, ierr
         
         integer :: n_eqs, n_diags
         
         include 'formats.dek'

         n_diags = 3*nvar   ! three block diagonals, each nvar wide
         n_eqs = nvar*nz    ! total number of unknowns

         ! real workspace: matrix storage (n_diags per equation) plus nine
         ! vectors of length n_eqs
         lwork = num_work_params + n_eqs*(n_diags + 9)
         ! integer workspace: one entry per equation
         liwork = num_iwork_params + n_eqs

         ierr = 0
         
      end subroutine get_newton_work_sizes


      end module star_newton
