! ***********************************************************************
!
!   Copyright (C) 2012  Bill Paxton
!
!   MESA is free software; you can use it and/or modify
!   it under the combined terms and restrictions of the MESA MANIFESTO
!   and the GNU General Library Public License as published
!   by the Free Software Foundation; either version 2 of the License,
!   or (at your option) any later version.
!
!   You should have received a copy of the MESA MANIFESTO along with
!   this software; if not, it is available at the mesa website:
!   http://mesa.sourceforge.net/
!
!   MESA is distributed in the hope that it will be useful,
!   but WITHOUT ANY WARRANTY; without even the implied warranty of
!   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
!   See the GNU Library General Public License for more details.
!
!   You should have received a copy of the GNU Library General Public License
!   along with this software; if not, write to the Free Software
!   Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
!
! ***********************************************************************



! L.Brugnano, C.Magherini, F.Mugnai.
! Blended Implicit Methods for the Numerical Solution of DAE Problems,
! Jour. Comput. Appl. Mathematics  189 (2006) 34-50.


! our version is restricted to 4th order, diagonal mass matrix,
! block tridiagonal, analytic jacobian.


      module hydro_bimd
      
      
      use star_private_def
      use const_def
      use utils_lib, only: is_bad_num, set_pointer_2, set_pointer_3
      use mtx_lib
      use mtx_def
      
      
      implicit none


      ! caller tag (its consumer is not visible in this section)
      integer, parameter :: caller_id = 0

      ! ---- blended implicit method (BiMD) constants ----

      ! number of substeps in one blended step; the step dt is split into
      ! ns pieces of size h = dt/ns
      integer, parameter :: ns = 3

      ! once past the first indexd+1 Newton iterations, iterating stops if the
      ! estimated contraction rate rho exceeds this value
      real(dp), parameter :: rhobad = 0.99d0

      ! scale used when tightening the error weights of index-2/3 variables
      real(dp), parameter :: cerr_scal = 16d0

      ! blending parameter; iteration matrix is M0 - h*gamma*J (for optimal
      ! convergence).  NOTE: this name shadows the F2008 intrinsic gamma().
      real(dp), parameter :: gamma = 0.7387d0

      ! coefficients for the local error estimate
      real(dp), parameter :: vmax4_1 = 1d0/15d0
      real(dp), parameter :: vmax4_2 = 1d0/4d0
      real(dp), parameter :: vmax4_2_2 = 2d0/3d0
      real(dp), parameter :: vmax(1:3) = &
         [ vmax4_1, gamma*vmax4_2, gamma*gamma*vmax4_2_2 ]

      ! enable diagnostic output throughout this module
      logical, parameter :: trace = .false.
      

      contains
      

      !> Take one hydro step of size dt with the blended implicit method (BiMD).
      !>
      !> dt is split into ns substeps of size h = dt/ns and a single blended
      !> step over them is attempted by do1_step.  On success the increments
      !> in dx are added to s% xh_pre_hydro to give s% xh.
      !>
      !> Arguments:
      !>   s          star being evolved (model state and work-array pool)
      !>   dt         size of the step
      !>   report     request failure diagnostics (currently they are printed
      !>              regardless, see the failure branch below)
      !>   decsolblk  block-tridiagonal factor/solve routine
      !>   lrd, lid   lengths of the real/integer work arrays for decsolblk
      !>
      !> Returns keep_going when the step is accepted, retry otherwise
      !> (including setup, allocation and function-evaluation failures).
      integer function do_hydro_bimd( &
            s, dt, report, decsolblk, lrd, lid)
            
            ! alternative/former argument list, kept for reference:
            !naccept, niter, nfailerr, &
            !nfailnewt, nferrcons, &
            !maxit, index_vec_in, &
            !h, facr, facl, &
            !sfty, rtol, atol, &
            !, f0_in, y0_in, maxf0, mas, fcn, jac)
            
         ! return keep_going, retry, backup, or terminate
         
         ! do not require that functions have been evaluated for starting configuration.
         ! when finish, will have functions evaluated for the final set of primary variables.
         
         use hydro_newton_procs_dble, only: set_xscale_info
         use hydro_mtx_dble, only: set_vars_for_solver
         use hydro_eqns_dble, only: eval_equ_for_solver
         use star_utils, only: total_times
         use utils_lib, only: is_bad_num, set_pointer_2

         type (star_info), pointer :: s
         real(dp), intent(in) :: dt 
         logical, intent(in) :: report 
         interface
            include "mtx_decsolblk_dble.dek"
         end interface
         integer, intent(in) :: lrd, lid

         integer :: &
            indexd, nvar, nvar_hydro, species, nz, minit, ierr
         integer, dimension(:), pointer :: &
            index_vec, ipar_decsol1, ipar_decsol2, ipar_decsol3
         integer, dimension(:,:), pointer :: &
            ipvt1, ipvt2, ipvt3
            
         real(dp) :: t(ns)
         real(dp), dimension(:), pointer :: &
            m0, rtol, atol, rpar_decsol1, rpar_decsol2, rpar_decsol3
         real(dp), dimension(:,:), pointer :: &
            f0, err_scal, dx, xscale
         real(dp), dimension(:,:,:), pointer :: &  
            y, f, err, temp, dd, &
            theta_lblk1, theta_dblk1, theta_ublk1, &
            theta_lblk2, theta_dblk2, theta_ublk2, &
            theta_lblk3, theta_dblk3, theta_ublk3
         
         integer :: niter, maxit
         real(dp) :: facr, facl, sfty, h
         integer :: time0, time1, clock_rate, i, j, k, j1, j2
         real(dp) :: total_all_before, total_all_after
         logical :: skip_partials, use_j0_for_all, converged

         logical, parameter :: dbg = trace !.false.

         include 'formats.dek'

         ierr = 0     
         do_hydro_bimd = retry
             
         if (dbg .or. trace) then
            write(*,*)
            write(*,2) 'enter hydro_bimd: model, logdt', &
               s% model_number, log10(dt/secyer)
         end if
         
         if (s% doing_timing) then
            call system_clock(time0,clock_rate)
            total_all_before = total_times(s)
         else
            total_all_before = 0
         end if
              
         call setup(ierr)
         if (ierr /= 0) return

         call do_alloc(ierr)
         ! NOTE(review): arrays obtained before a failing request are not
         ! returned here; the pointers are never nullified, so calling dealloc
         ! on a partial allocation would be unsafe.  Confirm the pool tolerates this.
         if (ierr /= 0) return
            
         ! index_vec(j): DAE index of variable j (the lnddot/lnTdot variables
         ! are index 2, all others index 1); m0: diagonal of the mass matrix
         ! (1 for ode variables, 0 for algebraic ones)
         do j=1,nvar
            if (j == s% i_lnddot .or. j == s% i_lnTdot) then
               index_vec(j) = 2
            else
               index_vec(j) = 1
            end if
            if (s% ode_var(j)) then
               m0(j) = 1d0
            else
               m0(j) = 0d0
            end if
         end do
         indexd = maxval(index_vec(1:nvar))
         minit = indexd

         ! excluded variables get huge tolerances, i.e. negligible error weight
         atol(1:nvar) = s% hydro_err_ratio_atol
         rtol(1:nvar) = s% hydro_err_ratio_rtol
         if (s% i_FL /= 0 .and. .not. s% include_L_in_error_est) call reset_tol(s% i_FL)
         if (s% i_vel /= 0 .and. .not. s% include_v_in_error_est) call reset_tol(s% i_vel)
         if (s% i_lnddot /= 0 .and. .not. s% include_lnddot_in_error_est) call reset_tol(s% i_lnddot)
         if (s% i_lnTdot /= 0 .and. .not. s% include_lnTdot_in_error_est) call reset_tol(s% i_lnTdot)

         if (dbg) write(*,*) 'call set_xscale_info'
         call set_xscale_info(s, nvar, nz, xscale, ierr)
         if (ierr /= 0) then
            ! release the work arrays and let the caller retry
            ! (previously a leftover debugging STOP aborted the run here)
            write(*,2) 'hydro_bimd set_xscale_info ierr', ierr
            call dealloc
            return
         end if 

         skip_partials = .false.
         call eval_equ_for_solver( &
            s, nvar, 1, nz, dt, skip_partials, xscale, ierr)         
         if (ierr /= 0) then
            ! release the work arrays and let the caller retry
            ! (previously a leftover debugging STOP aborted the run here)
            write(*, *) 'hydro_bimd: eval_equ returned ierr', ierr
            call dealloc
            return
         end if
         
         do k=1,nz
            do j=1,nvar
               f0(j,k) = s% equ_dble(j,k)
            end do
         end do

         call set_err_scaling( &
            s, nvar, nz, h, atol, rtol, err_scal, index_vec)

         ! internal time points of the substeps
         do i = 1, ns
            t(i) = i*h
         end do
         
         ! starting increments: zero for hydro variables, the change already
         ! made to the abundances for the remaining variables
         dx(1:nvar_hydro,1:nz) = 0 
         if (nvar > nvar_hydro) then
            do k = 1, nz
               j2 = 1
               do j1 = nvar_hydro+1, nvar
                  dx(j1,k) = s% xa(j2,k) - s% xa_pre_hydro(j2,k)
                  j2 = j2+1
               end do
            end do
         end if

         call do1_step( &
            s, nvar, nz, use_j0_for_all, &
            niter, maxit, minit, indexd, index_vec, ipvt1, ipvt2, ipvt3, &
            h, t, dx, xscale, y, f0, f, err_scal, &
            theta_lblk1, theta_dblk1, theta_ublk1, &
            theta_lblk2, theta_dblk2, theta_ublk2, &
            theta_lblk3, theta_dblk3, theta_ublk3, &
            m0, err, temp, dd,  facr, facl, &
            sfty, rtol, atol, decsolblk, &
            lrd, rpar_decsol1, rpar_decsol2, rpar_decsol3, &
            lid, ipar_decsol1, ipar_decsol2, ipar_decsol3, &
            ierr)
         
         write(*,2) 'do1_step ierr', ierr
         converged = (ierr == 0)
         write(*,*) 'converged', converged
         if (.true. .and. .not. converged) write(*,*) 'do1_step failed to converge'
         
         ! do1_step updated h to the proposed new substep size
         s% hydro_seulex_dt_limit = h*ns
         s% seulex_rows = niter
         
         write(*,2) 'niter, log dt/yr bimd', niter, log10(h*ns/secyer)

                  
         if (converged) then ! set final result and evaluate vars
            ! s% xa has already been updated by final call to set_vars_for_solver         
            do k=1,nz
               do j=1,nvar_hydro
                  s% xh(j,k) = s% xh_pre_hydro(j,k) + dx(j,k)
               end do
            end do
            ! a few more sanity checks before accept it
            !converged = check_after_converge(s, report, ierr)
            !if (.not. converged) then
            !   if (trace .or. report) &
            !      write(*,2) 'check_after_converge rejected: model', s% model_number
            !end if
         end if

         if (dbg) write(*,2) 'done hydro_bimd'
         if (dbg) write(*,2)
                  
         call dealloc

         if (s% doing_timing) then
            call system_clock(time1,clock_rate)
            total_all_after = total_times(s)
            ! see hydro_newton subroutine newt
         end if
         
         if (converged) then
         
            do_hydro_bimd = keep_going

            write(*, *) 'hydro_bimd converged'
            !write(*,*)

         else
         
            do_hydro_bimd = retry
            s% result_reason = hydro_failed_to_converge
            if (report .or. trace .or. .true.) then
               write(*, *) 'hydro_bimd failed to converge'
               write(*,2) 's% model_number', s% model_number
               write(*,2) 's% hydro_call_number', s% hydro_call_number
               write(*,2) 'nz', nz
               write(*,2) 's% num_retries', s% num_retries
               write(*,2) 's% num_backups', s% num_backups
               write(*,2) 's% number_of_backups_in_a_row', s% number_of_backups_in_a_row
               write(*,1) 'log dt/secyer', log10(dt/secyer)
               write(*, *) 
            end if

            return
         end if

         
         
         contains  
         
         
         ! Copy sizes from s, select the block-tridiagonal matrix type and set
         ! the BiMD control parameters (substep size h = dt/ns).
         subroutine setup(ierr)
            use num_def, only: block_tridiag_dble_matrix_type
            integer, intent(out) :: ierr
            ierr = 0
            species = s% species
            nvar_hydro = s% nvar_hydro
            nvar = s% nvar
            nz = s% nz
            s% hydro_matrix_type = block_tridiag_dble_matrix_type
            use_j0_for_all = .true.
            converged = .false.         
            niter = 0
            maxit = 10
            ! step-size change limits: facl <= hnew/h <= facr
            ! (double-precision literals; a default-real 0.12 is not exact in dp)
            facr = 10d0
            facl = 0.12d0
            sfty = 5d-2
            h = dt/ns            
         end subroutine setup

         
         ! Effectively exclude variable i from the error estimate.
         subroutine reset_tol(i)
            integer, intent(in) :: i
            atol(i) = 1d99
            rtol(i) = 1d99
         end subroutine reset_tol
      

         ! Obtain all work arrays from the star's work-array pool.
         ! Returns at the first failing request with ierr /= 0.
         subroutine do_alloc(ierr)
            use alloc
            integer, intent(out) :: ierr
            ierr = 0            

            call alloc_nvar_nz(dx, ierr); if (ierr /= 0) return            
            call alloc_nvar_nz(xscale, ierr); if (ierr /= 0) return
            call alloc_nvar_nz(f0, ierr); if (ierr /= 0) return
            call alloc_nvar_nz(err_scal, ierr); if (ierr /= 0) return
            
            call alloc_nvar_nz_ns(y, ierr); if (ierr /= 0) return
            call alloc_nvar_nz_ns(f, ierr); if (ierr /= 0) return
            call alloc_nvar_nz_ns(err, ierr); if (ierr /= 0) return
            call alloc_nvar_nz_ns(temp, ierr); if (ierr /= 0) return
            call alloc_nvar_nz_ns(dd, ierr); if (ierr /= 0) return
                        
            call alloc_nvar_nvar_nz(s% lblk_dble, ierr); if (ierr /= 0) return           
            call alloc_nvar_nvar_nz(s% dblk_dble, ierr); if (ierr /= 0) return            
            call alloc_nvar_nvar_nz(s% ublk_dble, ierr); if (ierr /= 0) return            
            call alloc_nvar_nvar_nz(theta_lblk1, ierr); if (ierr /= 0) return
            call alloc_nvar_nvar_nz(theta_dblk1, ierr); if (ierr /= 0) return
            call alloc_nvar_nvar_nz(theta_ublk1, ierr); if (ierr /= 0) return
            call alloc_nvar_nvar_nz(theta_lblk2, ierr); if (ierr /= 0) return
            call alloc_nvar_nvar_nz(theta_dblk2, ierr); if (ierr /= 0) return
            call alloc_nvar_nvar_nz(theta_ublk2, ierr); if (ierr /= 0) return
            call alloc_nvar_nvar_nz(theta_lblk3, ierr); if (ierr /= 0) return
            call alloc_nvar_nvar_nz(theta_dblk3, ierr); if (ierr /= 0) return
            call alloc_nvar_nvar_nz(theta_ublk3, ierr); if (ierr /= 0) return
            
            call get_work_array(s, m0, nvar, 0, 'bimd', ierr)
            if (ierr /= 0) return
            call get_work_array(s, atol, nvar, 0, 'bimd', ierr)
            if (ierr /= 0) return
            call get_work_array(s, rtol, nvar, 0, 'bimd', ierr)
            if (ierr /= 0) return
            call get_work_array(s, rpar_decsol1, lrd, 0, 'bimd', ierr)
            if (ierr /= 0) return
            call get_work_array(s, rpar_decsol2, lrd, 0, 'bimd', ierr)
            if (ierr /= 0) return
            call get_work_array(s, rpar_decsol3, lrd, 0, 'bimd', ierr)
            if (ierr /= 0) return
                        
            call get_integer_work_array(s, index_vec, nvar, 0, ierr)
            if (ierr /= 0) return            
            call get_integer_work_array(s, ipar_decsol1, lid, 0, ierr)
            if (ierr /= 0) return            
            call get_integer_work_array(s, ipar_decsol2, lid, 0, ierr)
            if (ierr /= 0) return            
            call get_integer_work_array(s, ipar_decsol3, lid, 0, ierr)
            if (ierr /= 0) return
            
            call get_integer_2d_work_array(s, ipvt1, nvar, nz, 0, ierr)
            if (ierr /= 0) return
            call get_integer_2d_work_array(s, ipvt2, nvar, nz, 0, ierr)
            if (ierr /= 0) return
            call get_integer_2d_work_array(s, ipvt3, nvar, nz, 0, ierr)
            if (ierr /= 0) return

            ! NOTE(review): the arrays are not zero-filled here.  A zeroing
            ! pass existed but was unreachable behind early RETURNs, and has
            ! been removed.  Each array must be fully assigned before it is read.

         end subroutine do_alloc
         
         
         subroutine alloc_nvar_nz(p, ierr)
            use alloc, only: get_2d_work_array
            real(dp), pointer :: p(:,:)
            integer, intent(out) :: ierr
            call get_2d_work_array( &
               s, p, nvar, nz, nz_alloc_extra, 'hydro_bimd', ierr)
         end subroutine alloc_nvar_nz
         
         
         subroutine alloc_nvar_nvar_nz(p, ierr)
            use alloc, only: get_3d_work_array
            real(dp), pointer :: p(:,:,:)
            integer, intent(out) :: ierr
            call get_3d_work_array( &
               s, p, nvar, nvar, nz, nz_alloc_extra, 'hydro_bimd', ierr)
         end subroutine alloc_nvar_nvar_nz
         
         
         subroutine alloc_nvar_nz_ns(p, ierr)
            use alloc, only: get_3d_work_array
            real(dp), pointer :: p(:,:,:)
            integer, intent(out) :: ierr
            call get_3d_work_array( &
               s, p, nvar, nz, ns, 0, 'hydro_bimd', ierr)
         end subroutine alloc_nvar_nz_ns
         
         
         ! Return every work array obtained by do_alloc to the pool.
         subroutine dealloc

            call return_2d(dx)
            call return_2d(xscale)
            call return_2d(f0)
            call return_2d(err_scal)

            call return_3d(y)
            call return_3d(f)
            call return_3d(err)
            call return_3d(temp)
            call return_3d(dd)

            call return_3d(s% lblk_dble)
            call return_3d(s% dblk_dble)
            call return_3d(s% ublk_dble)
            call return_3d(theta_lblk1)
            call return_3d(theta_dblk1)
            call return_3d(theta_ublk1)
            call return_3d(theta_lblk2)
            call return_3d(theta_dblk2)
            call return_3d(theta_ublk2)
            call return_3d(theta_lblk3)
            call return_3d(theta_dblk3)
            call return_3d(theta_ublk3)
            
            call return_1d(m0)
            call return_1d(atol)
            call return_1d(rtol)
            call return_1d(rpar_decsol1)
            call return_1d(rpar_decsol2)
            call return_1d(rpar_decsol3)
            
            call return_integer_1d(index_vec)
            call return_integer_1d(ipar_decsol1)
            call return_integer_1d(ipar_decsol2)
            call return_integer_1d(ipar_decsol3)
            
            call return_integer_2d(ipvt1)
            call return_integer_2d(ipvt2)
            call return_integer_2d(ipvt3)

         end subroutine dealloc

         
         subroutine return_1d(p)
            use alloc
            real(dp), pointer :: p(:)
            call return_work_array(s, p, 'hydro_bimd')
         end subroutine return_1d
         
         
         subroutine return_2d(p)
            use alloc
            real(dp), pointer :: p(:,:)
            call return_2d_work_array(s, p, 'hydro_bimd')
         end subroutine return_2d
         
         
         subroutine return_3d(p)
            use alloc
            real(dp), pointer :: p(:,:,:)
            call return_3d_work_array(s, p, 'hydro_bimd')
         end subroutine return_3d
         
         
         subroutine return_integer_1d(p)
            use alloc
            integer, pointer :: p(:)
            call return_integer_work_array(s, p)
         end subroutine return_integer_1d
         
         
         subroutine return_integer_2d(p)
            use alloc
            integer, pointer :: p(:,:)
            call return_integer_2d_work_array(s, p)
         end subroutine return_integer_2d




      end function do_hydro_bimd


      
      
! intent(in)   
   ! integer  
      ! m = the size of the problem
      ! maxit = max number of blended iterations per integration step
      ! index_vec(:)   index_vec(i) gives index of variable i (1, 2, or 3)
      
   ! real(dp)
      ! facr, facl: the new stepsize must satisfy facl <= hnew/hold <= facr.
         !           (default: facl=1.2d-1, facr=1d1)
      ! sfty = safety factor for predicting the new stepsize  (default = 1d0/2d1)
      ! rtol = relative tolerance
      ! atol = absolute tolerances
      ! m0(:) = mass matrix (diagonal)

! intent(inout)  
   ! integer  
      ! niter = number of blended iterations

      ! nfailconv
      ! nferrcons

   ! real(dp)
      ! t0 = (input) the starting time point of the step
         ! (output) value of t up to which the solution has been computed
         !           (if the integration has been successful, then t0=tend)
      ! h = current timestep
      ! y0(:,:)
      ! f0(:,:)

! intent(out)  
   ! integer  
      ! ierr = 0 iff all okay
   ! real(dp)
      ! t(:) = internal times points for the substeps
      ! y(:,:,:) = the current numerical solution at each substep
      ! f(:,:,:) = contains the values of fcn(t,y) at each substep

            
            
! WORK ARRAYS
   ! ipvt
   ! fj0(:), err_scal(:,:), theta(:,:,:), err(:,:,:), temp(:,:,:), dd(:,:,:)


      !> Perform one blended (BiMD) step of ns substeps starting from dx/f0.
      !>
      !> Forms and factors the first iteration matrix, runs the blended Newton
      !> iteration, then applies the local error test.  If the step is
      !> accepted, dx and f0 are overwritten with the last-substep values and
      !> h becomes the proposed new substep size, limited to [facl*h, facr*h].
      !> Iteration matrices factored here are released on every return path.
      !>
      !> ierr = 0 when the step is accepted.  It is nonzero if factorisation or
      !> the Newton iteration fails; it is -1 if the local error test rejects the step.
      !> niter is the iteration count set by newton_iterations (0 if the first
      !> factorisation fails).
      subroutine do1_step( &
            s, nvar, nz, use_j0_for_all, &
            niter, maxit, minit, indexd, index_vec, ipvt1, ipvt2, ipvt3, &
            h, t, dx, xscale, y, f0, f, err_scal, &
            theta_lblk1, theta_dblk1, theta_ublk1, &
            theta_lblk2, theta_dblk2, theta_ublk2, &
            theta_lblk3, theta_dblk3, theta_ublk3, &
            m0, err, temp, dd,  facr, facl, &
            sfty, rtol, atol, decsolblk, &
            lrd, rpar_decsol1, rpar_decsol2, rpar_decsol3, &
            lid, ipar_decsol1, ipar_decsol2, ipar_decsol3, &
            ierr)

         type (star_info), pointer :: s
         
         integer, intent(in) :: &
            nvar, nz, minit, maxit, indexd, index_vec(:)
         logical, intent(in) :: use_j0_for_all
         integer, intent(out) :: niter, ierr

         ! For pointer dummies INTENT refers to the pointer association, not to
         ! the values of the target.  INTENT(OUT) would make the association
         ! undefined on entry, so the work-array pointers below are
         ! INTENT(INOUT); their targets are read and/or written here.
         integer, dimension(:,:), pointer, intent(inout) :: ipvt1, ipvt2, ipvt3
            
         real(dp), intent(in) :: &
            facr, facl, sfty, rtol(:), atol(:), m0(:)
         real(dp), intent(inout) :: h, t(:)
         real(dp), dimension(:,:), pointer, intent(inout) :: &
            dx, xscale, f0
         real(dp), pointer, intent(inout) :: &
            y(:,:,:), f(:,:,:)
         
         ! err_scal is an input (the caller in this file fills it via set_err_scaling)
         real(dp), pointer, intent(inout) :: &
            err_scal(:,:), err(:,:,:), temp(:,:,:), dd(:,:,:)
            
         real(dp), dimension(:,:,:), pointer, intent(inout) :: &     
            theta_lblk1, theta_dblk1, theta_ublk1, &
            theta_lblk2, theta_dblk2, theta_ublk2, &
            theta_lblk3, theta_dblk3, theta_ublk3

         interface
            include "mtx_decsolblk_dble.dek"
         end interface

         integer, intent(in) :: lrd, lid
         integer, dimension(:), pointer :: &
            ipar_decsol1, ipar_decsol2, ipar_decsol3
         real(dp), dimension(:), pointer :: &
            rpar_decsol1, rpar_decsol2, rpar_decsol3

         logical :: local_err_okay, &
            have_factored1, have_factored2, have_factored3
         integer :: i, j, k
         real(dp) :: esp, hnew, nerr, nerr0, nerrstop
            
         include 'formats.dek'

         ierr = 0
         niter = 0   ! defined even if we return before the Newton iteration
         esp = 1d0/dble(ns+1)
         have_factored1 = .false.
         have_factored2 = .false.
         have_factored3 = .false.

         ! initialize every substep with the starting values
         do j = 1, ns
            do k = 1, nz
               do i = 1, nvar
                  y(i,k,j) = dx(i,k)
                  f(i,k,j) = f0(i,k)
               end do
            end do
         end do

         call create1_theta( & 
            ! using current values for s% lblk_dble, dblk_dble, ublk_dble
            s, nvar, nz, h, m0, theta_lblk1, theta_dblk1, theta_ublk1, &
            have_factored1, ipvt1, &
            decsolblk, lrd, rpar_decsol1, lid, ipar_decsol1, ierr)                     
         if (ierr /= 0) then
            if (trace) write(*,*) 'create1_theta failed at start of step'
            ! safe: dealloc only releases matrices whose have_factored flag is set
            call dealloc
            return
         end if
   
         call newton_iterations( &
            s, nvar, nz, use_j0_for_all, ipvt1, ipvt2, ipvt3, &
            t, dx, xscale, f0, &
            theta_lblk1, theta_dblk1, theta_ublk1, &
            theta_lblk2, theta_dblk2, theta_ublk2, &
            theta_lblk3, theta_dblk3, theta_ublk3, &
            have_factored1, have_factored2, have_factored3, &
            h, m0, err_scal, err, y, f, &
            nerr0, nerrstop, nerr, niter, indexd, &
            maxit, minit, temp, &
            decsolblk, lrd, rpar_decsol1, rpar_decsol2, rpar_decsol3, &
            lid, ipar_decsol1, ipar_decsol2, ipar_decsol3, ierr)
         if (ierr /= 0) then
            if (trace) write(*,*) 'newton_iterations returned ierr'
            call dealloc
            return
         end if
         
         
         if (nerr > nerrstop) write(*,*) 'newton_iterations nerr > nerrstop'
         
         local_err_okay = check_local_error( &
            s, nvar, nz, use_j0_for_all, index_vec, &
            ipvt1, ipvt2, ipvt3, indexd, maxit, &
            sfty, esp, facr, facl, &
            f0, f, err, err_scal, nerr, t, y, xscale, &
            theta_lblk1, theta_dblk1, theta_ublk1, &
            theta_lblk2, theta_dblk2, theta_ublk2, &
            theta_lblk3, theta_dblk3, theta_ublk3, &
            have_factored1, have_factored2, have_factored3, &
            m0, nerrstop, h, hnew, decsolblk, &
            lrd, rpar_decsol1, rpar_decsol2, rpar_decsol3, &
            lid, ipar_decsol1, ipar_decsol2, ipar_decsol3)
         if (.not. local_err_okay) then
            if (.true.) write(*,*) 'check_local_error rejected step'
            call dealloc
            ierr = -1
            return
         end if
         
         write(*,1) 'step accepted nerr',nerr
         
         ! step accepted: keep the last-substep values as the new start
         do k=1,nz
            do i=1,nvar
               dx(i,k) = y(i,k,ns)
               f0(i,k) = f(i,k,ns)
            end do
         end do
         ! limit the step-size change to [facl*h, facr*h]
         hnew = max(hnew, facl*h)
         h = min(hnew, facr*h)
         
         call dealloc
         
         if (trace) write(*,*) 'do1_step returning okay', ierr
            
         contains
         
         
         ! Release the factorisations of the iteration matrices built so far.
         ! Uses its own ierr so the caller's ierr is left untouched.
         subroutine dealloc
            integer :: ierr
            if (have_factored1) &
               call dealloc_blk( &
                  nvar, nz, theta_lblk1, theta_dblk1, theta_ublk1, ipvt1, &
                  decsolblk, lrd, rpar_decsol1, lid, ipar_decsol1, ierr)
            if (have_factored2) &
               call dealloc_blk( &
                  nvar, nz, theta_lblk2, theta_dblk2, theta_ublk2, ipvt2, &
                  decsolblk, lrd, rpar_decsol2, lid, ipar_decsol2, ierr)
            if (have_factored3) &
               call dealloc_blk( &
                  nvar, nz, theta_lblk3, theta_dblk3, theta_ublk3, ipvt3, &
                  decsolblk, lrd, rpar_decsol3, lid, ipar_decsol3, ierr)
         end subroutine dealloc

      end subroutine do1_step


      !> Form the block-tridiagonal iteration matrix
      !>    theta = diag(m0) - h*gamma*B
      !> from the blocks B = s% lblk_dble / dblk_dble / ublk_dble (presumably
      !> the Jacobian blocks filled by the equation evaluation — confirm), then
      !> factor it with factor_blk.  have_factored, ipvt and the decsol work
      !> arrays are updated by factor_blk; ierr /= 0 if factorisation fails.
      subroutine create1_theta( &
            s, nvar, nz, h, m0, theta_lblk, theta_dblk, theta_ublk, &
            have_factored, ipvt, &
            decsolblk, lrd, rpar_decsol, lid, ipar_decsol, ierr)       
         
         type (star_info), pointer :: s
         integer, intent(in) :: nvar, nz
         real(dp), intent(in) :: h
         real(dp), intent(in) :: m0(:)
         ! INTENT(INOUT), not INTENT(OUT): for pointer dummies INTENT(OUT)
         ! makes the pointer association undefined on entry, but the targets
         ! of these pointers are written below.
         real(dp), dimension(:,:,:), pointer, intent(inout) :: &
            theta_lblk, theta_dblk, theta_ublk
         logical, intent(inout) :: have_factored
         integer, intent(in) :: lrd, lid
         interface
            include "mtx_decsolblk_dble.dek"
         end interface
         integer, dimension(:), pointer :: ipar_decsol
         real(dp), dimension(:), pointer :: rpar_decsol
         integer, dimension(:,:), pointer, intent(inout) :: ipvt
         integer, intent(out) :: ierr
         
         integer :: i, j, k
         real(dp) :: hgamma        
         include 'formats.dek'
         ierr = 0
         hgamma = h*gamma
         do k=1,nz
            do j=1,nvar
               ! innermost loop over the first index (column-major friendly)
               do i=1,nvar
                  theta_ublk(i,j,k) = -hgamma*s% ublk_dble(i,j,k)
                  theta_dblk(i,j,k) = -hgamma*s% dblk_dble(i,j,k)
                  theta_lblk(i,j,k) = -hgamma*s% lblk_dble(i,j,k)
               end do
               ! add the diagonal mass matrix to the diagonal block
               theta_dblk(j,j,k) = theta_dblk(j,j,k) + m0(j)
            end do
         end do         
         call factor_blk( &
            nvar, nz, have_factored, theta_lblk, theta_dblk, theta_ublk, ipvt, &
            decsolblk, lrd, rpar_decsol, lid, ipar_decsol, ierr)
         
      end subroutine create1_theta
      
      
      !> Fill err_scal(i,k) with the inverse error weights used by the BiMD
      !> error norm.  The weight is 1/(atol(i) + rtol(i)*|x|), where x is the
      !> start-of-step value of variable i in cell k: s% xh_pre_hydro for
      !> i <= nvar_hydro, otherwise s% xa_pre_hydro(i-nvar_hydro, k).
      !> Index-2 variables are further scaled by min(1, c2*h) and index-3
      !> variables by min(1, c3*h*h), with c2 and c3 derived from rtol(i).
      !> Stops with an error for any index_vec value other than 1, 2 or 3.
      subroutine set_err_scaling( &
            s, nvar, nz, h, atol, rtol, err_scal, index_vec)
         type (star_info), pointer :: s
         integer, intent(in) :: nvar, nz, index_vec(:)
         real(dp), intent(in) :: h, rtol(:), atol(:)
         ! INTENT(INOUT), not INTENT(OUT): for pointer dummies INTENT(OUT) makes
         ! the pointer association undefined on entry, but its target is written here.
         real(dp), pointer, intent(inout) :: err_scal(:,:)
         integer :: i, k, nvar_hydro
         real(dp) :: cerr_scal2, cerr_scal3, uround, atl, rtl
         include 'formats.dek' 
         uround = 1d-16
         nvar_hydro = s% nvar_hydro      
         do i=1,nvar
            atl = atol(i)
            rtl = rtol(i)
            select case (index_vec(i))
            case (1)
               do k=1,nz
                  err_scal(i,k) = 1d0/(atl+rtl*abs(start_value(i,k)))
               end do
            case (2)
               cerr_scal2 = min(1d0, rtl/(uround*1d2*cerr_scal))
               do k=1,nz
                  err_scal(i,k) = &
                     min(1d0, cerr_scal2*h)/(atl+rtl*abs(start_value(i,k)))
               end do
            case (3)
               cerr_scal3 = min(1d0, rtl/(uround*1d2*cerr_scal*cerr_scal))   
               do k=1,nz
                  err_scal(i,k) = &
                     min(1d0, cerr_scal3*h*h)/(atl+rtl*abs(start_value(i,k)))
               end do
            case default
               stop 'bad val index_vec in set_err_scaling'
            end select
         end do         

         contains

         ! Start-of-step value of variable ivar in cell kz: hydro variables
         ! come from s% xh_pre_hydro, the rest from s% xa_pre_hydro.
         real(dp) function start_value(ivar, kz)
            integer, intent(in) :: ivar, kz
            if (ivar <= nvar_hydro) then
               start_value = s% xh_pre_hydro(ivar,kz)
            else
               start_value = s% xa_pre_hydro(ivar-nvar_hydro,kz)
            end if
         end function start_value

      end subroutine set_err_scaling
      
      
      subroutine newton_iterations( &
            s, nvar, nz, use_j0_for_all, ipvt1, ipvt2, ipvt3, &
            t, dx, xscale, f0, &
            theta_lblk1, theta_dblk1, theta_ublk1, &
            theta_lblk2, theta_dblk2, theta_ublk2, &
            theta_lblk3, theta_dblk3, theta_ublk3, &
            have_factored1, have_factored2, have_factored3, &
            h, m0, err_scal, err, y, f, &
            nerr0, nerrstop, nerr, niter, indexd, &
            maxit, minit, temp, &
            decsolblk, lrd, rpar_decsol1, rpar_decsol2, rpar_decsol3, &
            lid, ipar_decsol1, ipar_decsol2, ipar_decsol3, ierr)

         type (star_info), pointer :: s
         integer, intent(in) :: nvar, nz, indexd, maxit, minit
         logical, intent(in) :: use_j0_for_all
         integer, dimension(:,:), pointer, intent(inout) :: &
            ipvt1, ipvt2, ipvt3
         integer, intent(out) :: niter
         interface
            include "mtx_decsolblk_dble.dek"
         end interface
         integer :: lrd, lid
         integer, dimension(:), pointer :: &
            ipar_decsol1, ipar_decsol2, ipar_decsol3
         real(dp), dimension(:), pointer :: &
            rpar_decsol1, rpar_decsol2, rpar_decsol3
         integer, intent(out) :: ierr

         real(dp), pointer, intent(in) :: &
            f0(:,:), err_scal(:,:)
         real(dp), intent(in) :: t(:), h, m0(:)
         real(dp), pointer, intent(inout) :: dx(:,:), xscale(:,:)
         real(dp), dimension(:,:,:), pointer, intent(inout) :: &
            y, f, temp, err, &
            theta_lblk1, theta_dblk1, theta_ublk1, &
            theta_lblk2, theta_dblk2, theta_ublk2, &
            theta_lblk3, theta_dblk3, theta_ublk3
         logical, intent(inout) :: &
            have_factored1, have_factored2, have_factored3
         real(dp), intent(out) :: nerr, nerr0, nerrstop

         integer :: i, it
         logical :: do_another_iteration
         real(dp) :: rho, rho0, nerrold

         include 'formats.dek'

         it = 0
         rho = 0d0
         nerr0 = 1d0
         nerrstop = 1d0
         ierr = 0

         do it = 1, maxit ! newton loop

            !write(*,2) 'call blendstep4: it', it
            call blendstep4( &
               s, nvar, nz, use_j0_for_all, dx, xscale, f0, y, f, h, &
               theta_lblk1, theta_dblk1, theta_ublk1, &
               theta_lblk2, theta_dblk2, theta_ublk2, &
               theta_lblk3, theta_dblk3, theta_ublk3, &
               have_factored1, have_factored2, have_factored3, &
               ipvt1, ipvt2, ipvt3, err, m0, temp, &
               decsolblk, lrd, rpar_decsol1, rpar_decsol2, rpar_decsol3, &
               lid, ipar_decsol1, ipar_decsol2, ipar_decsol3, ierr)
            if (ierr /= 0) then
               if (trace) write(*,2) 'blendstep4 failed on it', it
               return
            end if

            call norm(s, nvar, nz, ns, err_scal, err, nerr)
            ! check for nans
            if (is_bad_num(nerr)) then
               ierr = -1
               if (trace) write(*,2) 'norm gave bad num on it', it
               return
            end if

            ! decide whether to accept, give up, or try again.
            nerrold = nerr0
            nerr0 = nerr
            rho0 = rho
            rho = nerr0/nerrold
            if (it > 1) rho = sqrt(rho0*rho)
            if (it < minit) then
               do_another_iteration = .true.
            else
               do_another_iteration = &
                  (nerr > nerrstop) .and. &
                  (it <= maxit) .and. &
                  (it <= indexd+1 .or. rho <= rhobad)
            end if
            
            if (.true. .or. trace) write(*,2) 'nerr', it, nerr
            
            if (.not. do_another_iteration) exit
            
            !write(*,2) 'call eval_for_newton_iteration: it', it
            call eval_for_newton_iteration( &
               s, nvar, nz, use_j0_for_all, h, f0, dx, xscale, m0, &
               theta_lblk1, theta_dblk1, theta_ublk1, &
               theta_lblk2, theta_dblk2, theta_ublk2, &
               theta_lblk3, theta_dblk3, theta_ublk3, &
               have_factored1, have_factored2, have_factored3, &
               ipvt1, ipvt2, ipvt3, &
               decsolblk, lrd, rpar_decsol1, rpar_decsol2, rpar_decsol3, &
               lid, ipar_decsol1, ipar_decsol2, ipar_decsol3, &   
               t, y, f, ierr)
            if (ierr /= 0) then
               if (trace) &
                  write(*,2) 'eval_for_newton_iteration failed on it', it
               return
            end if

         end do ! newton loop

         niter = niter + it

      end subroutine newton_iterations


      subroutine blendstep4( & ! blended iteration for the 4th order method
            s, nvar, nz, use_j0_for_all, dx, xscale, f0, y, f, h, &
            theta_lblk1, theta_dblk1, theta_ublk1, &
            theta_lblk2, theta_dblk2, theta_ublk2, &
            theta_lblk3, theta_dblk3, theta_ublk3, &
            have_factored1, have_factored2, have_factored3, &
            ipvt1, ipvt2, ipvt3, z, m0, mz, &
            decsolblk, lrd, rpar_decsol1, rpar_decsol2, rpar_decsol3, &
            lid, ipar_decsol1, ipar_decsol2, ipar_decsol3, ierr)
         ! One pass of the blended iteration for the 4th order method.
         !
         ! Builds the correction z(:,:,1:ns) for the current substep values
         ! y(:,:,1:ns) and then applies y = y - z.
         ! - dx and f0 supply the first column of [dx y] and [f0 f] in the
         !   coefficient products below.
         ! - m0 holds the diagonal of the mass matrix.
         ! - mz is scratch storage.
         ! Each "solve theta*w = w" step uses substep i's factored matrix.
         ! It uses the substep-1 matrix instead when use_j0_for_all is set or
         ! substep i's matrix is not factored.
         ! ierr /= 0 if any block solve fails; y is then left unchanged.

         type (star_info), pointer :: s
         integer, intent(in) :: nvar, nz
         logical, intent(in) :: use_j0_for_all
         integer, dimension(:,:), pointer, intent(in) :: ipvt1, ipvt2, ipvt3
         real(dp), intent(in) :: h, m0(:)
         real(dp), pointer, intent(in) :: dx(:,:), xscale(:,:), f0(:,:)
         real(dp), dimension(:,:,:), pointer, intent(in) :: &
            theta_lblk1, theta_dblk1, theta_ublk1, &
            theta_lblk2, theta_dblk2, theta_ublk2, &
            theta_lblk3, theta_dblk3, theta_ublk3
         logical, intent(inout) :: &
            have_factored1, have_factored2, have_factored3
         real(dp), pointer, intent(inout) :: y(:,:,:), f(:,:,:), mz(:,:,:)
         ! INTENT on a pointer dummy refers to its association, not the target
         ! data. INTENT(OUT) would leave z's association undefined on entry,
         ! although only its target is written here, so z is INOUT like mz.
         real(dp), pointer, intent(inout) :: z(:,:,:)
         interface
            include "mtx_decsolblk_dble.dek"
         end interface
         integer :: lrd, lid
         integer, dimension(:), pointer :: &
            ipar_decsol1, ipar_decsol2, ipar_decsol3
         real(dp), dimension(:), pointer :: &
            rpar_decsol1, rpar_decsol2, rpar_decsol3
         integer, intent(out) :: ierr

         integer :: i, j, k

         real(dp), parameter :: &
            da4_1_1 = -102133d0/405d3, &
            da4_1_2 = 98743d0/18d4, &
            da4_1_3 = -7387d0/225d2, &
            da4_1_4 = +51709d0/162d4, &
            da4_2_1 = -950353d0/81d4, &
            da4_2_2 = +7387d0/9d3, &
            da4_2_3 = 10613d0/18d3, &
            da4_2_4 = -96031d0/405d3, &
            da4_3_1 = -22613d0/3d4, &
            da4_3_2 = -22161d0/2d4, &
            da4_3_3 = +22161d0/1d4, &
            da4_3_4 = -21257d0/6d4, &

            a24_1_1 = -302867d0/405d3, &
            a24_1_2 = 81257d0/18d4, &
            a24_1_3 = 7387d0/225d2, &
            a24_1_4 = -51709d0/162d4, &
            a24_2_1 = 140353d0/81d4, &
            a24_2_2 = -7387d0/9d3, &
            a24_2_3 = 7387d0/18d3, &
            a24_2_4 = 96031d0/405d3, &
            a24_3_1 = -7387d0/3d4, &
            a24_3_2 = 22161d0/2d4, &
            a24_3_3 = -22161d0/1d4, &
            a24_3_4 = 81257d0/6d4, &

            db4_1_1 = 919d0/135d2, &
            db4_1_2 = 4589d0/3d4, &
            db4_1_3 = -37d0/12d1, &
            db4_1_4 = 3d0/4d1, &
            db4_2_1 = 115387d0/27d4, &
            db4_2_2 = 17d0/15d0, &
            db4_2_3 = -6161d0/3d4, &
            db4_2_4 = -1d0/15d0, &
            db4_3_1 = 3d0/8d0, &
            db4_3_2 = 9d0/8d0, &
            db4_3_3 = 9d0/8d0, &
            db4_3_4 = -3637d0/1d4, &
            
            b24_1_1 = 7387d0/27d3, &
            b24_2_1 = -7387d0/27d4, &
            b24_3_1 = 0d0
         
         ierr = 0
         
         ! z = [dx y]*(da)'
         do k=1,nz
            do i=1,nvar
               z(i,k,1) = dx(i,k)*da4_1_1 + y(i,k,1)*da4_1_2 + y(i,k,2)*da4_1_3 + y(i,k,3)*da4_1_4
               z(i,k,2) = dx(i,k)*da4_2_1 + y(i,k,1)*da4_2_2 + y(i,k,2)*da4_2_3 + y(i,k,3)*da4_2_4
               z(i,k,3) = dx(i,k)*da4_3_1 + y(i,k,1)*da4_3_2 + y(i,k,2)*da4_3_3 + y(i,k,3)*da4_3_4
            end do
         end do

         ! mz_i = m0*z_i, i = 1, 2, 3
         ! (k outer, i inner so the first array index varies fastest)
         do k=1,nz
            do i=1,nvar
               mz(i,k,1) = m0(i)*z(i,k,1)
               mz(i,k,2) = m0(i)*z(i,k,2)
               mz(i,k,3) = m0(i)*z(i,k,3)
            end do
         end do

         ! mz = mz-h*[f0 f]*(db)'
         do k=1,nz
            do i=1,nvar
               mz(i,k,1) = mz(i,k,1) - &
                  h*(f0(i,k)*db4_1_1 + f(i,k,1)*db4_1_2 + f(i,k,2)*db4_1_3 + f(i,k,3)*db4_1_4)
               mz(i,k,2) = mz(i,k,2) - &
                  h*(f0(i,k)*db4_2_1 + f(i,k,1)*db4_2_2 + f(i,k,2)*db4_2_3 + f(i,k,3)*db4_2_4)
               mz(i,k,3) = mz(i,k,3) - &
                  h*(f0(i,k)*db4_3_1 + f(i,k,1)*db4_3_2 + f(i,k,2)*db4_3_3 + f(i,k,3)*db4_3_4)
            end do
         end do
         
         ! solve theta*mz = mz
         call do1_solve_blk(mz, 1, have_factored1, &
            theta_lblk1, theta_dblk1, theta_ublk1, &
            ipvt1, rpar_decsol1, ipar_decsol1, ierr)
         if (ierr /= 0) return
         
         call do1_solve_blk(mz, 2, have_factored2, &
            theta_lblk2, theta_dblk2, theta_ublk2, &
            ipvt2, rpar_decsol2, ipar_decsol2, ierr)
         if (ierr /= 0) return

         call do1_solve_blk(mz, 3, have_factored3, &
            theta_lblk3, theta_dblk3, theta_ublk3, &
            ipvt3, rpar_decsol3, ipar_decsol3, ierr)
         if (ierr /= 0) return

         ! mz = mz+[dx y]*(a2)'
         do k=1,nz
            do i=1,nvar
               mz(i,k,1) = mz(i,k,1) + &
                  dx(i,k)*a24_1_1 + y(i,k,1)*a24_1_2 + y(i,k,2)*a24_1_3 + y(i,k,3)*a24_1_4
               mz(i,k,2) = mz(i,k,2) + &
                  dx(i,k)*a24_2_1 + y(i,k,1)*a24_2_2 + y(i,k,2)*a24_2_3 + y(i,k,3)*a24_2_4
               mz(i,k,3) = mz(i,k,3) + &
                  dx(i,k)*a24_3_1 + y(i,k,1)*a24_3_2 + y(i,k,2)*a24_3_3 + y(i,k,3)*a24_3_4
            end do
         end do

         ! z_i = m0*mz_i, i = 1, 2, 3
         ! (k outer, i inner so the first array index varies fastest)
         do k=1,nz
            do i=1,nvar
               z(i,k,1) = m0(i)*mz(i,k,1)
               z(i,k,2) = m0(i)*mz(i,k,2)
               z(i,k,3) = m0(i)*mz(i,k,3)
            end do
         end do

         ! z = z-h*[f0 f]*(b2)'
         ! (b2 has gamma on its diagonal and only f0 off the diagonal)
         do k=1,nz
            do i=1,nvar
               z(i,k,1) = z(i,k,1) - h*(f0(i,k)*b24_1_1 + f(i,k,1)*gamma)
               z(i,k,2) = z(i,k,2) - h*(f0(i,k)*b24_2_1 + f(i,k,2)*gamma)
               z(i,k,3) = z(i,k,3) - h*(f0(i,k)*b24_3_1 + f(i,k,3)*gamma)
            end do
         end do
         
         ! solve theta*z = z
         call do1_solve_blk(z, 1, have_factored1, &
            theta_lblk1, theta_dblk1, theta_ublk1, &
            ipvt1, rpar_decsol1, ipar_decsol1, ierr)
         if (ierr /= 0) return

         call do1_solve_blk(z, 2, have_factored2, &
            theta_lblk2, theta_dblk2, theta_ublk2, &
            ipvt2, rpar_decsol2, ipar_decsol2, ierr)
         if (ierr /= 0) return

         call do1_solve_blk(z, 3, have_factored3, &
            theta_lblk3, theta_dblk3, theta_ublk3, &
            ipvt3, rpar_decsol3, ipar_decsol3, ierr)
         if (ierr /= 0) return

         ! y = y-z
         do j=1,ns
            do k=1,nz
               do i=1,nvar
                  y(i,k,j) = y(i,k,j) - z(i,k,j)
               end do
            end do
         end do
         
         
         contains
         
         
         ! Solve theta*w(:,:,isub) = w(:,:,isub) in place.
         ! Uses the given substep matrix, or falls back to the substep-1
         ! factorization when use_j0_for_all is set or have_factored is .false.
         subroutine do1_solve_blk( w, isub, have_factored, &
               theta_lblk, theta_dblk, theta_ublk, ipvt, &
               rpar_decsol, ipar_decsol, ierr)
            ! w is a pointer so that its sections carry the TARGET attribute.
            ! p2 presumably aliases that storage via set_pointer_2, and that
            ! association must stay valid for the solve below.
            real(dp), pointer :: w(:,:,:)
            integer, intent(in) :: isub
            real(dp), dimension(:,:,:), pointer, intent(in) :: &
               theta_lblk, theta_dblk, theta_ublk
            logical, intent(inout) :: have_factored
            integer, dimension(:,:), pointer, intent(in) :: ipvt
            integer, dimension(:), pointer :: ipar_decsol
            real(dp), dimension(:), pointer :: rpar_decsol
            integer, intent(out) :: ierr
            real(dp), pointer :: p2(:,:)
            call set_pointer_2(p2, w(:,:,isub), nvar, nz)
            if (use_j0_for_all .or. .not. have_factored) then
               call solve_blk( &
                  nvar, nz, have_factored1, theta_lblk1, theta_dblk1, theta_ublk1, &
                  ipvt1, p2, decsolblk, lrd, rpar_decsol1, lid, ipar_decsol1, ierr)
            else
               call solve_blk( &
                  nvar, nz, have_factored, theta_lblk, theta_dblk, theta_ublk, &
                  ipvt, p2, decsolblk, lrd, rpar_decsol, lid, ipar_decsol, ierr)
            end if
         end subroutine do1_solve_blk
         

      end subroutine blendstep4
      
      
      subroutine eval_for_newton_iteration( &
            s, nvar, nz, use_j0_for_all, h, f0, dx, xscale, m0, &
            theta_lblk1, theta_dblk1, theta_ublk1, &
            theta_lblk2, theta_dblk2, theta_ublk2, &
            theta_lblk3, theta_dblk3, theta_ublk3, &
            have_factored1, have_factored2, have_factored3, &
            ipvt1, ipvt2, ipvt3, &
            decsolblk, lrd, rpar_decsol1, rpar_decsol2, rpar_decsol3, &
            lid, ipar_decsol1, ipar_decsol2, ipar_decsol3, &
            t, y, f, ierr)
         ! Prepare the next Newton (blended) iteration.
         !
         ! For each substep i = 1..ns:
         ! - load y(:,:,i) at time t(i) into the solver variables;
         ! - evaluate the equations and copy s% equ_dble into f(:,:,i);
         ! - unless use_j0_for_all is set, also evaluate partials and call
         !   create1_theta to rebuild (and presumably refactor) substep i's
         !   theta matrix.
         ! Stops at the first failure with ierr /= 0; later substeps are left
         ! untouched. f0 and dx are not referenced here.

         use hydro_mtx_dble, only: set_vars_for_solver
         use hydro_eqns_dble, only: eval_equ_for_solver

         type (star_info), pointer :: s
         integer, intent(in) :: nvar, nz
         logical, intent(in) :: use_j0_for_all

         real(dp), intent(in) :: h
         real(dp), pointer, intent(in) :: f0(:,:)
         real(dp), pointer, intent(inout) :: dx(:,:), xscale(:,:)

         real(dp), intent(in) :: m0(:)
         ! INTENT on a pointer dummy refers to its association. INTENT(OUT)
         ! would make the association undefined on entry. It would also stay
         ! undefined on return whenever create1_theta is not called
         ! (use_j0_for_all, or an early exit), yet the caller goes on to use
         ! these pointers, so they are INOUT.
         real(dp), dimension(:,:,:), pointer, intent(inout) :: &
            theta_lblk1, theta_dblk1, theta_ublk1, &
            theta_lblk2, theta_dblk2, theta_ublk2, &
            theta_lblk3, theta_dblk3, theta_ublk3
         logical, intent(inout) :: &
            have_factored1, have_factored2, have_factored3
         integer, intent(in) :: lrd, lid
         interface
            include "mtx_decsolblk_dble.dek"
         end interface
         integer, dimension(:), pointer :: &
            ipar_decsol1, ipar_decsol2, ipar_decsol3
         real(dp), dimension(:), pointer :: &
            rpar_decsol1, rpar_decsol2, rpar_decsol3
         integer, dimension(:,:), pointer, intent(inout) :: &
            ipvt1, ipvt2, ipvt3
            
         real(dp), intent(in) :: t(:)
         real(dp), pointer, intent(inout) :: y(:,:,:), f(:,:,:)
         integer, intent(out) :: ierr
         
         integer :: i, j, k
         logical :: skip_partials
         
         include 'formats.dek'
         
         ierr = 0
         
         ! partials are only needed when the theta matrices are rebuilt
         skip_partials = use_j0_for_all
         
         do i = 1, ns

            call set_vars_for_solver( &
               s, 1, nz, i, y(:,:,i), xscale, t(i), ierr)
            if (ierr /= 0) then
               if (trace) write(*,2) 'failed in set_vars_for_solver', i
               exit
            end if

            call eval_equ_for_solver( &
               s, nvar, 1, nz, t(i), skip_partials, xscale, ierr)
            if (ierr /= 0) then
               if (trace) write(*,2) 'failed in eval_equ', i
               exit
            end if

            do k=1,nz
               do j=1,nvar
                  f(j,k,i) = s% equ_dble(j,k)
               end do
            end do
         
            if (use_j0_for_all) cycle
            
            select case(i)
            case (1)
               call create1_theta( &
                  s, nvar, nz, h, m0, theta_lblk1, theta_dblk1, theta_ublk1, &
                  have_factored1, ipvt1, decsolblk, &
                  lrd, rpar_decsol1, lid, ipar_decsol1, ierr)
            case (2)
               call create1_theta( &
                  s, nvar, nz, h, m0, theta_lblk2, theta_dblk2, theta_ublk2, &
                  have_factored2, ipvt2, decsolblk, &
                  lrd, rpar_decsol2, lid, ipar_decsol2, ierr)
            case (3)
               call create1_theta( &
                  s, nvar, nz, h, m0, theta_lblk3, theta_dblk3, theta_ublk3, &
                  have_factored3, ipvt3, decsolblk, &
                  lrd, rpar_decsol3, lid, ipar_decsol3, ierr)
            end select
            if (ierr /= 0) exit
            
         end do
           
      end subroutine eval_for_newton_iteration


      subroutine norm(s, nvar, nz, nss, err_scal, err, nerr)
         ! Scaled RMS error norm over several substeps.
         ! For each substep jsub = 1..nss, form the sum over ivar = 1..nvar
         ! and iz = 1..nz of (err(ivar,iz,jsub)*err_scal(ivar,iz))**2.
         ! Keep the largest such sum and return sqrt(largest/(nvar*nz)).
         type (star_info), pointer :: s ! not referenced
         integer, intent(in) :: nvar, nz, nss
         real(dp), pointer, intent(in) :: err_scal(:,:), err(:,:,:)
         real(dp), intent(out) :: nerr

         real(dp) :: worst, sumsq, term
         integer :: ivar, iz, jsub
         include 'formats.dek'

         worst = 0d0
         substeps: do jsub = 1, nss
            sumsq = 0d0
            do iz = 1, nz
               do ivar = 1, nvar
                  term = (err(ivar,iz,jsub)*err_scal(ivar,iz))**2
                  sumsq = sumsq + term
               end do
            end do
            worst = max(worst, sumsq)
         end do substeps

         nerr = sqrt(worst/dble(nvar*nz))

      end subroutine norm


      logical function check_local_error( &
            s, nvar, nz, use_j0_for_all, &
            index_vec, ipvt1, ipvt2, ipvt3, indexd, maxit, &
            sfty, esp, facr, facl, &
            f0, f, err, err_scal, nerr, t, y, xscale, &
            theta_lblk1, theta_dblk1, theta_ublk1, &
            theta_lblk2, theta_dblk2, theta_ublk2, &
            theta_lblk3, theta_dblk3, theta_ublk3, &
            have_factored1, have_factored2, have_factored3, &
            m0, nerrstop, h, hnew, decsolblk, &
            lrd, rpar_decsol1, rpar_decsol2, rpar_decsol3, &
            lid, ipar_decsol1, ipar_decsol2, ipar_decsol3)
         ! Accept/reject test for a step, plus a proposed next step size.
         !
         ! Returns .false. (step rejected) when any of these holds:
         ! - the incoming nerr (presumably the Newton error) exceeds nerrstop;
         ! - evaluating the substep functions fails;
         ! - the local error estimate fails or is bad (NaN/Inf);
         ! - the scaled local error estimate exceeds 1. In this case h is also
         !   reduced, to no less than facl*h.
         ! Otherwise returns .true.
         !
         ! Once the estimate is computed, nerr holds the localerr4 estimate.
         ! hnew is then the proposed next step size: facr*h when the estimate
         ! is not positive, otherwise a limited controller step. On earlier
         ! exits hnew is left equal to h.
         ! sfty, indexd and maxit are not referenced.

         type (star_info), pointer :: s

         integer, intent(in) :: nvar, nz, index_vec(:), indexd, maxit
         logical, intent(in) :: use_j0_for_all
         integer, dimension(:,:), pointer, intent(in) :: ipvt1, ipvt2, ipvt3
         real(dp), intent(in) :: sfty, esp, facr, facl, &
            t(:), m0(:), nerrstop
         real(dp), pointer :: f0(:,:), err_scal(:,:)
         real(dp), dimension(:,:,:), pointer, intent(in) :: &
            theta_lblk1, theta_dblk1, theta_ublk1, &
            theta_lblk2, theta_dblk2, theta_ublk2, &
            theta_lblk3, theta_dblk3, theta_ublk3
         logical, intent(inout) :: &
            have_factored1, have_factored2, have_factored3
         real(dp), intent(inout) :: h, nerr
         real(dp), dimension(:,:), pointer, intent(in) :: xscale
         real(dp), dimension(:,:,:), pointer, intent(inout) :: err, y, f
         real(dp), intent(out) :: hnew
         interface
            include "mtx_decsolblk_dble.dek"
         end interface
         integer, intent(in) :: lrd, lid
         integer, dimension(:), pointer :: &
            ipar_decsol1, ipar_decsol2, ipar_decsol3
         real(dp), dimension(:), pointer :: &
            rpar_decsol1, rpar_decsol2, rpar_decsol3

         integer :: ierr
         real(dp) :: facmin, fac, limtr
         ! Step-size controller constants, written as real(dp) literals.
         ! A default-real literal such as 0.6 is only accurate to single
         ! precision when assigned to a real(dp) variable.
         real(dp), parameter :: &
            dt_safe1 = 0.6d0, dt_safe2 = 0.93d0, &
            dt_fac1 = 0.1d0, dt_fac2 = 4d0
         
         include 'formats.dek'

         hnew = h ! keep the intent(out) result defined on every exit path

         if (nerr > nerrstop) then
            write(*,1) 'nerr > nerrstop', nerr, nerrstop
            check_local_error = .false.
            return
         end if

         check_local_error = .true.
         ierr = 0
         
         call eval_fcns( &
            s, nvar, nz, t, y, xscale, f, ierr)
         if (ierr /= 0)  then
            write(*,*) 'check_local_error failed in eval_fcns'
            check_local_error = .false.
            return
         end if
         
         ! estimate with the substep-3 matrix when it is available,
         ! otherwise with the substep-1 matrix
         if (use_j0_for_all .or. .not. have_factored3) then
            call localerr4( &
               s, nvar, nz, f0, f, h, err, err_scal, nerr, &
               theta_lblk1, theta_dblk1, theta_ublk1, have_factored1, ipvt1, m0, index_vec, &
               decsolblk, lrd, rpar_decsol1, lid, ipar_decsol1, ierr)
         else
            call localerr4( &
               s, nvar, nz, f0, f, h, err, err_scal, nerr, &
               theta_lblk3, theta_dblk3, theta_ublk3, have_factored3, ipvt3, m0, index_vec, &
               decsolblk, lrd, rpar_decsol3, lid, ipar_decsol3, ierr)
         end if
         ! test ierr before nerr: localerr4 returns without setting nerr when
         ! one of its solves fails
         if (ierr == 0) then
            if (is_bad_num(nerr)) ierr = -1
         end if
         if (ierr /= 0) then
            write(*,*) 'check_local_error failed in localerr4'
            check_local_error = .false.
            return
         end if

         if (nerr > 0d0) then
            ! controller factor clipped to [facmin, dt_fac2/facmin], then its
            ! reciprocal smoothed by the atan limiter
            facmin = dt_fac1**esp
            fac = min(dt_fac2/facmin, max(facmin, (nerr/dt_safe1)**esp/dt_safe2))
            limtr = limiter(1d0/fac)
            hnew = h*limtr
         else
            hnew = facr*h
         end if

         if (nerr > 1d0) then ! failure due to local error test
            hnew = h*(1d-1/nerr)**esp
            h = max(hnew, facl*h)
            check_local_error = .false.
         end if
         
         
         contains

         
         real(dp) function limiter(x)
            real(dp), intent(in) :: x
            real(dp), parameter :: kappa = 2d0
            ! for x >= 0 and kappa = 2, limiter value is between 0.07 and 4.14
            ! for x = 1, limiter = 1
            limiter = 1d0 + kappa*atan((x - 1d0)/kappa)
         end function limiter
         

      end function check_local_error
      
      
      subroutine eval_fcns( & ! for check_local_error
            s, nvar, nz, t, y, xscale, f, ierr)
         ! Evaluate the equations, without partials, at every substep.
         ! For isub = 1..ns, y(:,:,isub) at time t(isub) is loaded into the
         ! solver variables and s% equ_dble is copied into f(:,:,isub).
         ! Returns at the first failure with ierr /= 0.
         ! Substeps are visited in increasing order on purpose: the final
         ! call (substep ns) leaves its values in the solver variables.
         use hydro_mtx_dble, only: set_vars_for_solver
         use hydro_eqns_dble, only: eval_equ_for_solver

         type (star_info), pointer :: s
         integer, intent(in) :: nvar, nz
         real(dp), intent(in) :: t(:)
         real(dp), pointer, intent(in) :: xscale(:,:)
         real(dp), pointer, intent(inout) :: y(:,:,:), f(:,:,:)
         integer, intent(out) :: ierr
         integer :: isub, ivar, iz
         logical, parameter :: skip_partials = .true.
         include 'formats.dek'

         ierr = 0
         substeps: do isub = 1, ns
            call set_vars_for_solver( &
               s, 1, nz, isub, y(:,:,isub), xscale, t(isub), ierr)
            if (ierr /= 0) then
               if (trace) write(*,2) 'eval_fcns failed in set_vars_for_solver', isub
               return
            end if

            call eval_equ_for_solver( &
               s, nvar, 1, nz, t(isub), skip_partials, xscale, ierr)
            if (ierr /= 0) then
               if (trace) write(*,2) 'eval_fcns failed in eval_equ', isub
               return
            end if

            ! copy the freshly evaluated equations for this substep
            do iz = 1, nz
               do ivar = 1, nvar
                  f(ivar,iz,isub) = s% equ_dble(ivar,iz)
               end do
            end do
         end do substeps
      end subroutine eval_fcns


      subroutine localerr4( &
            s, nvar, nz, f0, f, h, z, err_scal, nerr, &
            theta_lblk, theta_dblk, theta_ublk, &
            have_factored, ipvt, m0, index_vec, &
            decsolblk, lrd, rpar_decsol, lid, ipar_decsol, ierr)
         ! Local truncation error estimate for the 4th order method.
         !
         ! - z(:,:,1) = h*(f3 - 3*f2 + 3*f1 - f0): the third difference of
         !   f0 and the substep values f(:,:,1:3), times h.
         ! - z(:,:,2) = theta^-1 * z(:,:,1).
         ! - z(:,:,3) = theta^-1 * (z(:,:,1) - m0*z(:,:,2)).
         ! Both solves use the single factored matrix passed in.
         ! Rows i of z(:,:,2:3) are then weighted with the module-level vmax
         ! according to index_vec(i), presumably the DAE index of variable i.
         ! Values 1, 2 or 3 are accepted; anything else stops the run.
         ! nerr is the norm of z(:,:,2:3) scaled by err_scal.
         ! ierr /= 0 if a block solve fails; nerr is not set in that case.
         type (star_info), pointer :: s
         integer, intent(in) :: nvar, nz, index_vec(:)
         integer, pointer, intent(in) :: ipvt(:,:)
         real(dp), intent(in) :: h, m0(:)
         real(dp), pointer, intent(in) :: f0(:,:), f(:,:,:), err_scal(:,:)
         real(dp), dimension(:,:,:), intent(in), pointer :: &
            theta_lblk, theta_dblk, theta_ublk
         logical, intent(inout) :: have_factored
         real(dp), intent(out) :: nerr
         ! INTENT on a pointer dummy refers to its association. INTENT(OUT)
         ! would leave z's association undefined on entry, although only its
         ! target data is written here, so z is INOUT.
         real(dp), pointer, intent(inout) :: z(:,:,:)
         interface
            include "mtx_decsolblk_dble.dek"
         end interface
         integer :: lrd, lid
         integer, pointer :: ipar_decsol(:) ! (lid)
         real(dp), pointer :: rpar_decsol(:) ! (lrd)
         integer, intent(out) :: ierr
         
         real(dp), pointer :: p2(:,:), p3(:,:,:)

         integer :: i, k
         real(dp), parameter :: psi4_1 = -1d0, psi4_2 = 3d0
         
         include 'formats.dek'
         
         ierr = 0
         
         ! truncation error estimate
         do k=1,nz
            do i=1,nvar
               z(i,k,1) = h*(psi4_1*(f0(i,k)-f(i,k,3)) + psi4_2*(f(i,k,1)-f(i,k,2)))
               z(i,k,2) = z(i,k,1)
            end do
         end do
         
         ! solve theta*z_2 = z_2
         call set_pointer_2(p2, z(:,:,2), nvar, nz)
         call solve_blk( &
            nvar, nz, have_factored, theta_lblk, theta_dblk, theta_ublk, &
            ipvt, p2, decsolblk, lrd, rpar_decsol, lid, ipar_decsol, ierr)
         if (ierr /= 0) return
         
         do k=1,nz
            do i=1,nvar
               z(i,k,3) = z(i,k,1) - m0(i)*z(i,k,2)
            end do
         end do

         ! solve theta*z_3 = z_3
         call set_pointer_2(p2, z(:,:,3), nvar, nz)
         call solve_blk( &
            nvar, nz, have_factored, theta_lblk, theta_dblk, theta_ublk, &
            ipvt, p2, decsolblk, lrd, rpar_decsol, lid, ipar_decsol, ierr)
         if (ierr /= 0) return
         
         ! per-variable weighting of the two estimates
         do i=1,nvar
            select case(index_vec(i))
            case (1)
               do k=1,nz
                  z(i,k,2) = vmax(1)*z(i,k,2)
                  z(i,k,3) = vmax(2)*z(i,k,3)
               end do
            case (2)
               do k=1,nz
                  z(i,k,2) = vmax(2)*z(i,k,2)
                  z(i,k,3) = vmax(2)*z(i,k,3)
               end do
            case (3)
               do k=1,nz
                  z(i,k,2) = vmax(3)*z(i,k,2)
                  z(i,k,3) = vmax(3)*z(i,k,3)/2d0
               end do
            case default
               write(*,*) 'localerr4: bad index_vec(i)', i, index_vec(i)
               stop 1
            end select
         end do

         ! norm over the two weighted estimates z(:,:,2:3)
         call set_pointer_3(p3, z(:,:,2:3), nvar, nz, 2)
         call norm(s, nvar, nz, 2, err_scal, p3, nerr)

      end subroutine localerr4


      subroutine factor_blk( &
            nvar, nz, have_factored, lblk, dblk, ublk, ipiv_blk, &
            decsolblk, lrd, rpar_decsol, lid, ipar_decsol, ierr)
         ! Factor the block tridiagonal matrix (lblk, dblk, ublk) with
         ! decsolblk operation 0. Any factorization already held for it is
         ! released first (operation 2) when have_factored is .true.
         !
         ! have_factored is .true. on return only if a usable factorization
         ! exists. Failures are reported through ierr instead of stopping the
         ! run, so the caller can recover, e.g. by retrying with a smaller
         ! step. A failed factorization presumably means a singular matrix;
         ! confirm against mtx_lib's decsolblk.
         integer, intent(in) :: nvar, nz
         logical, intent(inout) :: have_factored
         real(dp), dimension(:,:,:), pointer :: lblk, dblk, ublk
         integer, pointer :: ipiv_blk(:,:) ! (nvar,nz)
         interface
            include "mtx_decsolblk_dble.dek"
         end interface
         integer, intent(in) :: lrd, lid
         integer, pointer :: ipar_decsol(:) ! (lid)
         real(dp), pointer :: rpar_decsol(:) ! (lrd)
         integer, intent(out) :: ierr
         ! decsolblk always takes a right-hand side argument; the factor and
         ! deallocate calls below are given an empty one.
         real(dp), target :: a(0,0)
         real(dp), pointer :: del(:,:)
         
         include 'formats.dek'
         
         ierr = 0
         del => a
         
         if (have_factored) then
            call decsolblk( & ! deallocate
               2, caller_id, nvar, nz, lblk, dblk, ublk, &
               del, ipiv_blk, lrd, rpar_decsol, lid, ipar_decsol, ierr)
            if (ierr /= 0) return
            ! The previous factorization is gone. Clear the flag so it stays
            ! truthful if the new factorization below fails.
            have_factored = .false.
         end if
         
         call decsolblk( & ! factor
            0, caller_id, nvar, nz, lblk, dblk, ublk, &
            del, ipiv_blk, lrd, rpar_decsol, lid, ipar_decsol, ierr)
         if (ierr /= 0) then
            write(*,*) 'failed in decsolblk factor'
            return
         end if
         
         have_factored = .true.
         
      end subroutine factor_blk


      subroutine solve_blk( &
            nvar, nz, have_factored, lblk, dblk, ublk, ipiv_blk, del, &
            decsolblk, lrd, rpar_decsol, lid, ipar_decsol, ierr)
         ! Solve with an existing block tridiagonal factorization
         ! (decsolblk operation 1). del holds the right-hand side on entry
         ! and the solution on return; callers use it as "solve theta*w = w".
         ! Calling this before the matrix is factored stops the run.
         ! A failure inside decsolblk also stops the run.
         integer, intent(in) :: nvar, nz
         logical, intent(inout) :: have_factored
         real(dp), dimension(:,:,:), pointer :: lblk, dblk, ublk
         integer, pointer :: ipiv_blk(:,:) ! (nvar,nz)
         real(dp), pointer :: del(:,:) ! (nvar,nz)
         interface
            include "mtx_decsolblk_dble.dek"
         end interface
         integer, intent(in) :: lrd, lid
         integer, pointer :: ipar_decsol(:) ! (lid)
         real(dp), pointer :: rpar_decsol(:) ! (lrd)
         integer, intent(out) :: ierr

         ierr = 0

         if (have_factored) then
            call decsolblk( & ! solve
               1, caller_id, nvar, nz, lblk, dblk, ublk, &
               del, ipiv_blk, lrd, rpar_decsol, lid, ipar_decsol, ierr)
            if (ierr == 0) return
            write(*,*) 'failed in decsolblk solve'
            stop 'solve_blk'
         end if

         ! no factorization available to solve with: treated as fatal
         ierr = -1
         write(*,*) 'called solve_blk when .not. have_factored'
         stop 1

      end subroutine solve_blk


      subroutine dealloc_blk( &
            nvar, nz, lblk, dblk, ublk, ipiv_blk, &
            decsolblk, lrd, rpar_decsol, lid, ipar_decsol, ierr)
         ! Release the factorization held by decsolblk for this block
         ! tridiagonal matrix (decsolblk operation 2). Any decsolblk error
         ! is returned in ierr.
         integer, intent(in) :: nvar, nz
         real(dp), dimension(:,:,:), pointer :: lblk, dblk, ublk
         integer, pointer :: ipiv_blk(:,:) ! (nvar,nz)
         interface
            include "mtx_decsolblk_dble.dek"
         end interface
         integer, intent(in) :: lrd, lid
         integer, pointer :: ipar_decsol(:) ! (lid)
         real(dp), pointer :: rpar_decsol(:) ! (lrd)
         integer, intent(out) :: ierr
         ! decsolblk always takes a right-hand side argument; deallocation
         ! is handed an empty one (presumably unused for operation 2).
         real(dp), target :: no_rhs(0,0)
         real(dp), pointer :: rhs_ptr(:,:)

         rhs_ptr => no_rhs
         ierr = 0
         call decsolblk( & ! deallocate
            2, caller_id, nvar, nz, lblk, dblk, ublk, &
            rhs_ptr, ipiv_blk, lrd, rpar_decsol, lid, ipar_decsol, ierr)

      end subroutine dealloc_blk

      
      end module hydro_bimd
      
      
      
      
      
      
      

! -----------------------------------------------------------------------------------
!     the code bimd numerically solves (stiff) differential ode 
!     problems or linearly implicit dae problems of index up to 3 
!     with constant mass matrix
!
!     copyright (c)2005-2007   
!
!     authors: cecilia magherini (cecilia.magherini@ing.unipi.it)
!              luigi   brugnano  (brugnano@math.unifi.it) 
!
!
!     this program is free software; you can redistribute it and/or
!     modify it under the terms of the gnu general public license
!     as published by the free software foundation; either version 2
!     of the license, or (at your option) any later version.
!
!     this program is distributed in the hope that it will be useful,
!     but without any warranty; without even the implied warranty of
!     merchantability or fitness for a particular purpose.  see the
!     gnu general public license for more details.
!
!     licensed under the gnu general public license, version 2 or later.
!       http://www.gnu.org/licenses/info/gplv2orlater.html
!
!     you should have received a copy of the gnu general public license
!     along with this program; if not, write to the free software
!     foundation, inc., 51 franklin street, fifth floor, boston, ma  02110-1301,
!     usa.
! -----------------------------------------------------------------------------------
