! ***********************************************************************
!
!   Copyright (C) 2012  Bill Paxton
!
!   MESA is free software; you can use it and/or modify
!   it under the combined terms and restrictions of the MESA MANIFESTO
!   and the GNU General Library Public License as published
!   by the Free Software Foundation; either version 2 of the License,
!   or (at your option) any later version.
!
!   You should have received a copy of the MESA MANIFESTO along with
!   this software; if not, it is available at the mesa website:
!   http://mesa.sourceforge.net/
!
!   MESA is distributed in the hope that it will be useful,
!   but WITHOUT ANY WARRANTY; without even the implied warranty of
!   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
!   See the GNU Library General Public License for more details.
!
!   You should have received a copy of the GNU Library General Public License
!   along with this software; if not, write to the Free Software
!   Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
!
! ***********************************************************************

! the following is the copyright for seulex

! copyright (c) 2004, ernst hairer

! redistribution and use in source and binary forms, with or without 
! modification, are permitted provided that the following conditions are 
! met:

! - redistributions of source code must retain the above copyright 
! notice, this list of conditions and the following disclaimer.

! - redistributions in binary form must reproduce the above copyright 
! notice, this list of conditions and the following disclaimer in the 
! documentation and/or other materials provided with the distribution.

! this software is provided by the copyright holders and contributors “as 
! is” and any express or implied warranties, including, but not limited 
! to, the implied warranties of merchantability and fitness for a 
! particular purpose are disclaimed. in no event shall the regents or 
! contributors be liable for any direct, indirect, incidental, special, 
! exemplary, or consequential damages (including, but not limited to, 
! procurement of substitute goods or services; loss of use, data, or 
! profits; or business interruption) however caused and on any theory of 
! liability, whether in contract, strict liability, or tort (including 
! negligence or otherwise) arising in any way out of the use of this 
! software, even if advised of the possibility of such damage.



      module hydro_seulex
      
      use star_private_def
      use const_def
      
      
      implicit none


      integer, parameter :: caller_id = 0


      contains


      

      integer function do_hydro_seulex( &
            s, dt, report, decsolblk, lrd, rpar_decsol, lid, ipar_decsol)
         ! take one hydro step of size dt by semi-implicit euler extrapolation
         ! (the row/order/stepsize logic follows E. Hairer's SEULEX code).
         !
         ! rows of the extrapolation table are built with get_row until the
         ! scaled error estimate err is <= 1 or the attempt is abandoned.
         !
         ! result:
         !    keep_going   the step converged and check_after_converge accepted it.
         !    retry        anything else.  setup presets the result to retry, so
         !                 every early error return also reports retry.
         !
         ! arguments:
         !    s            star_info for the model being advanced.
         !    dt           size of the full step.
         !    report       if true, write diagnostics when the step fails.
         !    decsolblk, lrd, rpar_decsol, lid, ipar_decsol
         !                 not referenced in the part of this routine before
         !                 "contains"; presumably used by the matrix helper
         !                 routines further down -- TODO confirm.
         !
         ! when no error occurs, the following are stored in s before returning:
         !    seulex_kopt, seulex_rows, hydro_seulex_dt_limit,
         !    err_ratio_max_hydro, err_ratio_avg_hydro, and (if converged) xh.
         
         ! do not require that functions have been evaluated for starting configuration.
         ! when finish, will have functions evaluated for the final set of primary variables.
         
         use hydro_newton, only: check_after_converge, set_xscale_info
         use hydro_mtx, only: set_vars_for_solver
         use hydro_eqns, only: eval_equ_for_solver
         use star_utils, only: total_times
         use utils_lib, only: is_bad_num, set_pointer_2

         type (star_info), pointer :: s
         real(dp), intent(in) :: dt 
         logical, intent(in) :: report 
         interface
            include "mtx_decsolblk_dble.dek"
         end interface
         integer, intent(in) :: lrd, lid
         integer, pointer :: ipar_decsol(:) ! (lid)
         real(dp), pointer :: rpar_decsol(:) ! (lrd)
      
         real(dp) :: wk_set_vars, wk_eval_equ, wk_solve_mtx, wk_prep_mtx, wk_substep
         
         integer :: nj(max_seulex_kmax) ! nj(j) is number of substeps for jth row
         real(dp) :: a(max_seulex_kmax) ! a(j) is estimated work to create j rows of solution
         real(dp) :: hh(max_seulex_kmax) ! hh(j) is optimal next timestep for solution using j rows
         real(dp) :: w(max_seulex_kmax) ! w(j) is a(j)/hh(j) as estimate of cost benefit ratio for j rows
          
         integer :: time0, time1, clock_rate
         real(dp) :: total_all_before, total_all_after
         
         integer :: ierr, kmax, km2, nz, nvar, nvar_hydro, species, &
            i, j, k, kopt, kopt_prev, kc, iout
         real(dp) :: h, hj, dt_fac1, dt_fac2, dt_fac3, dt_fac4, dt_safe1, dt_safe2, atol, rtol
         logical :: trace, skip_partials, dt_by_err_ratio_avg, &
            have_a_factored_mtx, converged
         real(dp) :: err, err_ratio_avg_lim, err_ratio_max_lim, err_ratio_avg, err_ratio_max
         logical :: reject ! reject means finished row but error too large
         integer :: i_vel, i_FL, i_lnddot, i_lnTdot, j_max_err, k_max_err, &
            skip1, skip2, skip3, skip4, prev_nzlo, prev_nzhi
         
         real(dp), pointer :: error_vectors(:,:,:) ! (nvar,nz,kmax)

         ! temp storage to be allocated and deallocated
         real(dp), pointer, dimension(:,:) :: &
            equ_init, del, err_scale_inv, dx, xscale ! (nvar,nz)
         real(dp), pointer, dimension(:,:,:) :: &
            t, & ! (kmax,nvar,nz)
            dx_save_new, dx_save_prev, & ! (nvar,nz,*)
            lblk_init, dblk_init, ublk_init, & ! (nvar,nvar,nz)
            lhs_lblk, lhs_dblk, lhs_ublk, & ! (nvar,nvar,nz)
            lhs_lblkF, lhs_dblkF, lhs_ublkF ! (nvar,nvar,nz)
         integer, pointer :: ipiv_blk(:,:) ! (nvar,nz)
        
         logical, parameter :: dbg = .false.

         include 'formats.dek'
         
         ierr = 0         
         
         iout = 0 ! set to 2 for dense
         !iout = 2 ! set to 2 for dense    <<<<<< not debugged. 
            ! cannot converge with interpolated BCs.
         
         ! copy controls from s into local variables; also presets the result to retry.
         ! nothing has been allocated yet, so a plain return is enough on failure.
         call setup(ierr)
         if (ierr /= 0) return
         
         if (dbg .or. trace) then
            write(*,*)
            write(*,3) 'enter hydro_seulex: model, kopt, logdt', &
               s% model_number, kopt, log10(dt/secyer)
         end if
         
         if (s% doing_timing) then
            call system_clock(time0,clock_rate)
            total_all_before = total_times(s)
         else
            total_all_before = 0
         end if
         
         call do_alloc(ierr)
         if (ierr /= 0) return

         ! from here on every error exit must release the work arrays with dealloc

         if (dbg) write(*,*) 'call set_xscale_info'
         call set_xscale_info(s, nvar, nz, xscale, ierr)
         if (ierr /= 0) then
            ! report the failure and let the caller retry, like the other error
            ! exits below, rather than stopping the whole program.
            write(*,2) 'hydro_seulex set_xscale_info ierr', ierr
            call dealloc
            return
         end if 
         
         if (dbg) write(*,*) 'call set_err_scale_inv'
         call set_err_scale_inv

         ! evaluate equations and jacobian for the starting configuration
         skip_partials = .false.
         call eval_equ_for_solver(s, nvar, 1, nz, dt, skip_partials, xscale, ierr)         
         if (ierr /= 0) then
            write(*, *) 'hydro_seulex: eval_equ returned ierr', ierr
            call dealloc
            return
         end if
         
         ! save starting equ and jacobian
         do k=1,nz
            do j=1,nvar
               equ_init(j,k) = s% equ(j,k)
               do i=1,nvar
                  lblk_init(i,j,k) = s% lblk(i,j,k)
                  dblk_init(i,j,k) = s% dblk(i,j,k)
                  ublk_init(i,j,k) = s% ublk(i,j,k)
               end do
            end do
         end do
         
         ! kopt is estimated optimal number of rows
         ! kc is number of rows we've actually done so far
         if (kopt >= 2 .and. kopt <= kmax) then
            ! warm start: do rows 1..kopt-1 unconditionally
            do j=1,kopt-1
               kc = j
               call get_row(kc, reject, err, ierr)
               if (ierr /= 0) exit
            end do
            if (ierr == 0) then ! consider doing row kc = kopt
               if (kc == 1 .or. reject) then
                  kc = kopt
                  call get_row(kc, reject, err, ierr)
               else if (err > 1 .and. kopt < kmax) then
                  ! only worth doing row kopt if the error is not hopelessly large
                  if (err < dble(4*nj(kopt+1)*nj(kopt))) then
                     kc = kopt
                     call get_row(kc, reject, err, ierr)
                  end if
               end if
            end if
         else ! if (invalid kopt value) then cold start
            kopt = max(2, min(kmax-2, int(-log10(rtol+atol)*0.6d0+1.5d0)))
            do j=1,kopt
               kc = j
               call get_row(kc, reject, err, ierr)
               if (ierr /= 0) exit
               if (j == 1) cycle ! always do at least 2 rows
               if (err <= 1) exit ! stop as soon as err okay
            end do
         end if
         
         if (ierr == 0 .and. err > 1 .and. kc == kopt .and. kopt < kmax) then
            if (err < dble(2*nj(kopt+1))) then 
               ! hope for convergence in line kopt+1
               kc = kopt+1
               call get_row(kc, reject, err, ierr)
            end if
         end if

         if (ierr /= 0) then
            call dealloc
            return
         end if
         
         if (trace) write(*,1) 'err', err
         if (err <= 1) then ! step is accepted
            
            ! t(1,-) is the fully extrapolated result
            do k=1,nz
               do j=1,nvar
                  dx(j,k) = t(1,j,k)
               end do
            end do
            ! set vars to match final result
            call set_vars_for_solver(s, 1, nz, 0, dx, xscale, dt, ierr) 
            if (ierr /= 0) then
               call dealloc
               return
            end if
            ! compute optimal order for next step
            kopt_prev = kopt
            if (dbg) write(*,*) 'compute optimal order for next step'
            if (dbg) write(*,2) 'kc', kc
            if (dbg) write(*,2) 'kopt', kopt
            if (kc == 2) then
               kopt = min(3, kmax)
            else if (kc <= kopt) then
               kopt = kc
               ! consider changing kopt based on cost-benefit estimates
               if (dbg) write(*,2) 'w(kc-1)', kc-1, w(kc-1)
               if (dbg) write(*,2) 'w(kc)', kc, w(kc)
               if (dbg) write(*,2) 'w(kc)*dt_fac3', kc, w(kc)*dt_fac3
               if (dbg) write(*,*) 'w(kc-1) < w(kc)*dt_fac3', w(kc-1) < w(kc)*dt_fac3
               if (w(kc-1) < w(kc)*dt_fac3) kopt = kc-1
               if (w(kc) < w(kc-1)*dt_fac4) kopt = min(kc+1, kmax-1)
            else ! kc > kopt
               kopt = kc-1
               ! consider changing kopt based on cost-benefit estimates
               if (kc > 3) then
                  if (w(kc-2) < w(kc-1)*dt_fac3) kopt = kc-2
               end if
               if (dbg) write(*,2) 'w(kc-1)', kc-1, w(kc-1)
               if (dbg) write(*,2) 'w(kc)', kc, w(kc)
               if (dbg) write(*,1) 'dt_fac4', dt_fac4
               if (dbg) write(*,2) 'w(kc)*dt_fac4', kc, w(kc)*dt_fac4
               if (dbg) write(*,*) 'w(kc) < w(kc-1)*dt_fac4', w(kc) < w(kc-1)*dt_fac4
               if (w(kc) < w(kc-1)*dt_fac4) kopt = min(kc, kmax-1) 
            end if
            ! compute next step size based on next optimal order
            if (kopt <= kc) then ! for this case, we have hh value for kopt
               h = hh(kopt)
            ! else we are increasing kopt beyond kc, so increase hh(kc) to match
            else if (kc < kopt_prev .and. w(kc) < w(kc-1)*dt_fac4 .and. kopt < kmax) then
               h = hh(kc)*a(kopt+1)/a(kc)
            else
               h = hh(kc)*a(kopt)/a(kc)
            end if
         else ! step is rejected because err too large
            ! compute optimal order for retry
            kopt = min(kopt, kc, kmax-1)
            if (kopt > 2) then
               if (w(kopt-1) < w(kopt)*dt_fac3) kopt = kopt-1
            end if
            ! set next step size based on kopt
            h = hh(kopt)
         end if

         ! remember order, rows used, suggested next step and error ratios
         s% seulex_kopt = kopt
         s% seulex_rows = kc
         s% hydro_seulex_dt_limit = h
         s% err_ratio_max_hydro = err_ratio_max
         s% err_ratio_avg_hydro = err_ratio_avg

         if (dbg) write(*,1) 's% err_ratio_max_hydro', s% err_ratio_max_hydro
         if (dbg) write(*,1) 's% err_ratio_avg_hydro', s% err_ratio_avg_hydro
         
         converged = (err <= 1d0)
                  
         if (converged) then ! set final result and evaluate vars
            ! s% xa has already been updated by final call to set_vars_for_solver         
            do k=1,nz
               do j=1,nvar_hydro
                  s% xh(j,k) = s% xh_pre_hydro(j,k) + dx(j,k)
               end do
            end do
            ! a few more sanity checks before accept it
            converged = check_after_converge(s, report, ierr)
            if (.not. converged) then
               if (trace .or. report) write(*,2) 'check_after_converge rejected: model', s% model_number
            end if
         end if

         if (s% doing_timing) then
            call system_clock(time1,clock_rate)
            total_all_after = total_times(s)
            ! see hydro_newton subroutine newt
         end if
         
         if (trace) then
            if (converged) then
               write(*,2) 'converged: rows, err, lg err max, lg err avg, lg next dt/yr', kc, &
                  err, log10(s% err_ratio_max_hydro), log10(s% err_ratio_avg_hydro), &
                  log10(s% hydro_seulex_dt_limit/secyer)
            else
               write(*,2) 'failed'
            end if
            write(*,*)
         end if

         if (dbg) write(*,2) 'done hydro_seulex'
         if (dbg) write(*,2)

         call dealloc
         
         if (converged) then
            do_hydro_seulex = keep_going
         else
            do_hydro_seulex = retry
            s% result_reason = hydro_failed_to_converge
            if (report) then
               write(*, *) 'hydro_seulex failed to converge'
               write(*,2) 's% model_number', s% model_number
               write(*,2) 's% hydro_call_number', s% hydro_call_number
               write(*,2) 'nz', nz
               write(*,2) 's% num_retries', s% num_retries
               write(*,2) 's% num_backups', s% num_backups
               write(*,2) 's% number_of_backups_in_a_row', s% number_of_backups_in_a_row
               write(*,1) 'log dt/secyer', log10(dt/secyer)
               write(*, *) 
            end if
            return
         end if
         
         
         contains  
         
         
         subroutine setup(ierr)
            ! copy sizes and seulex controls from s into the host routine's
            ! variables, reset per-call counters in s, preset the function result
            ! to retry, and build the substep sequence nj and work estimates a.
            !
            ! ierr is 0 on success and -1 if seulex_kmax or
            ! seulex_step_size_sequence has an unusable value.
            use num_def, only: block_tridiag_dble_matrix_type
            integer, intent(out) :: ierr
            integer :: nsequ
            
            ierr = 0
         
            species = s% species
            nvar_hydro = s% nvar_hydro
            nvar = s% nvar
            nz = s% nz
            prev_nzlo = 1
            prev_nzhi = nz
            h = dt
            do_hydro_seulex = retry ! result for every early error return
            have_a_factored_mtx = .false.
            
            ! skipN is the index of a variable to leave out of the error
            ! estimate, or 0 (matches no variable) when it is to be included
            if (s% include_L_in_error_est) then
               skip1 = 0
            else
               skip1 = s% i_FL
            end if
         
            if (s% include_v_in_error_est) then
               skip2 = 0
            else
               skip2 = s% i_vel
            end if
         
            if (s% include_lnTdot_in_error_est) then
               skip3 = 0
            else
               skip3 = s% i_lnTdot
            end if
         
            if (s% include_lnddot_in_error_est) then
               skip4 = 0
            else
               skip4 = s% i_lnddot
            end if
            
            err_ratio_avg_lim = 0
            err_ratio_max_lim = 0
            
            s% num_jacobians = 1
            s% total_num_jacobians = s% total_num_jacobians + s% num_jacobians
            s% num_solves = 0
            s% hydro_matrix_type = block_tridiag_dble_matrix_type
            s% seulex_rows = 0
            
            i_vel = s% i_vel
            i_FL = s% i_FL
            i_lnddot = s% i_lnddot
            i_lnTdot = s% i_lnTdot

            kopt = s% seulex_kopt ! saved optimal order
            
            ! controls for seulex
            trace = s% trace_seulex
            dt_fac1 = s% seulex_dt_fac1
            dt_fac2 = s% seulex_dt_fac2
            dt_fac3 = s% seulex_dt_fac3
            dt_fac4 = s% seulex_dt_fac4
            dt_safe1 = s% seulex_dt_safe1
            dt_safe2 = s% seulex_dt_safe2
            dt_by_err_ratio_avg = s% seulex_dt_by_err_ratio_avg
            kmax = s% seulex_kmax
            km2 = (kmax*(kmax+1))/2
            wk_set_vars = s% seulex_wk_set_vars
            wk_eval_equ = s% seulex_wk_eval_equ
            wk_prep_mtx = s% seulex_wk_prep_mtx
            wk_solve_mtx = s% seulex_wk_solve_mtx
            
            atol = s% hydro_err_ratio_atol
            rtol = s% hydro_err_ratio_rtol
            
            ! nj, a, hh and w are dimensioned max_seulex_kmax and are indexed up
            ! to kmax below; the caller always builds at least rows 1 and 2.
            ! reject values of kmax that would index outside those arrays.
            if (kmax < 2 .or. kmax > max_seulex_kmax) then
               ierr = -1
               write(*,*) 'bad value for seulex_kmax', kmax
               return
            end if
            
            ! nj(i) is the number of substeps used for row i
            nsequ = s% seulex_step_size_sequence
            select case (nsequ)
            case (1) ! 1, 2, 3, 4, 6, 8, 12, ...
               nj(1) = 1
               nj(2) = 2
               nj(3) = 3
               do i=4,kmax
                  nj(i) = 2*nj(i-2)
               end do
            case (2) ! 2, 3, 4, 6, 8, 12, ...
               nj(1) = 2
               nj(2) = 3
               do i=3,kmax
                  nj(i) = 2*nj(i-2)
               end do
            case (3) ! 1, 2, 3, 4, 5, ...
               do i=1,kmax
                  nj(i) = i
               end do
            case (4) ! 2, 3, 4, 5, 6, ...
               do i=1,kmax
                  nj(i) = i+1
               end do
            case (5) ! 3, 4, 6, 8, 12, ...
               nj(1) = 3
               nj(2) = 4
               do i=3,kmax
                  nj(i) = 2*nj(i-2)
               end do
            case (6) ! 4, 5, 8, 10, 16, ...
               nj(1) = 4
               nj(2) = 5
               do i=3,kmax
                  nj(i) = 2*nj(i-2)
               end do
            case default
               ierr = -1
               write(*,*) 'bad value for seulex_step_size_sequence'
               return
            end select
            
            ! a(i) is the accumulated work estimate for building rows 1..i
            wk_substep = wk_set_vars + wk_eval_equ + wk_solve_mtx ! for each substep
            a(1) = wk_eval_equ + wk_prep_mtx + wk_solve_mtx + (nj(1)-1)*wk_substep + wk_set_vars
            do i=2,kmax
               a(i) = a(i-1) + wk_prep_mtx + wk_solve_mtx + (nj(i)-1)*wk_substep
            end do
            ! row 1 alone never gets an hh estimate, so make it look very costly
            w(1) = 1d30
            
         end subroutine setup
         
            
         subroutine set_err_scale_inv ! for calculation of error of scaled variables
            ! fill err_scale_inv(j,k) with 1/(xscale(j,k)*(atol + rtol*|x|)), where x
            ! is the pre-hydro value of variable j in cell k: taken from
            ! xh_pre_hydro for the first nvar_hydro variables and from
            ! xa_pre_hydro for the remaining ones.
            integer :: ivar, kz, n_from_xh
            real(dp) :: x_start
            ! number of leading variables whose starting value lives in xh_pre_hydro
            n_from_xh = max(0, min(s% nvar_hydro, nvar))
            do kz = 1, nz
               do ivar = 1, n_from_xh
                  x_start = s% xh_pre_hydro(ivar,kz)
                  err_scale_inv(ivar,kz) = 1d0/(xscale(ivar,kz)*(atol + rtol*abs(x_start)))
               end do
               do ivar = n_from_xh+1, nvar
                  x_start = s% xa_pre_hydro(ivar-nvar_hydro,kz)
                  err_scale_inv(ivar,kz) = 1d0/(xscale(ivar,kz)*(atol + rtol*abs(x_start)))
               end do
            end do
         end subroutine set_err_scale_inv      
         
      
         subroutine get_row(jj, reject, err, ierr) ! get row jj of extrapolation table
            ! do nj(jj) semi-implicit euler substeps of size dt/nj(jj) starting from
            ! the saved initial equations and jacobian, store the result in t(jj,-),
            ! and (for jj > 1) extrapolate so that t(1,-) holds the best estimate.
            ! for jj > 1 also sets err_ratio_avg, err_ratio_max, hh(jj) and w(jj).
            !
            ! jj       row number to do.
            ! reject   set false on entry; see NOTE below for how it becomes true.
            ! err      set to 1d99 on entry; see NOTE below for the final value.
            ! ierr     0 on success, nonzero if any stage of the row failed.
            !
            ! NOTE(review): estimate_local_error has no err/reject arguments; it
            ! assigns the host function's variables of those names.  every visible
            ! call of get_row passes those same host variables as the actual
            ! arguments, which is how its results show up in err/reject here.
            use alloc, only: get_3d_work_array
            integer, intent(in) :: jj ! row number to do
            logical, intent(out) :: reject ! reject means finished row but error too large
            real(dp), intent(out) :: err
            integer, intent(out) :: ierr
            
            real(dp) :: hj, fac, expo, facmin, limtr
            integer :: m, j, k, mm, l, nzlo, nzhi, ierr_dealloc
            logical :: skip_partials
            real(dp), pointer, dimension(:,:,:) :: dx_save_tmp

            include 'formats.dek'
            
            ierr = 0
            err = 1d99
            reject = .false.
      
            m = nj(jj) ! number of substeps for this row
            hj = dt/m ! size of each substep for this row
            
            if (iout == 2) then
               call set_bounds(jj, nzlo, nzhi)
            else
               nzlo = 1
               nzhi = nz
            end if
            
            ! the active range of cells must never grow from one row to the next
            if (nzlo < prev_nzlo .or. nzhi > prev_nzhi) then
               if (trace) write(*,*) 'nzlo < prev_nzlo .or. nzhi > prev_nzhi'
               write(*,2) 'nzlo', nzlo
               write(*,2) 'prev_nzlo', prev_nzlo
               write(*,2) 'nzhi', nzhi
               write(*,2) 'prev_nzhi', prev_nzhi
               write(*,2) 'jj', jj
               ! report the inconsistency to the caller instead of stopping the program
               ierr = -1
               return
            end if
            prev_nzlo = nzlo
            prev_nzhi = nzhi
   
            if (dbg) write(*,2) 'call prepare_lhs_matrix'
            call prepare_lhs_matrix(nzlo, nzhi, hj, lblk_init, dblk_init, ublk_init, ierr)
            if (ierr /= 0) then
               if (trace) write(*,*) 'failed in prepare_lhs_matrix'
               return
            end if
            
            ! first substep: right hand side is the saved initial equ
            do k=nzlo,nzhi
               do j=1,nvar
                  del(j,k) = equ_init(j,k)
               end do
            end do
            call solve_mtx(nzlo, nzhi, ierr)
            if (ierr /= 0) then
               if (trace) write(*,*) 'failed in solve_mtx'
               return
            end if
            
            if (iout==2) then
               ! swap the two dx_save buffers; grow the new one if it is too small
               dx_save_tmp => dx_save_prev
               dx_save_prev => dx_save_new
               dx_save_new => dx_save_tmp
               if (size(dx_save_new,dim=3) < m) then
                  call return_3d(dx_save_new)   
                  call get_3d_work_array( &
                     s, dx_save_new, nvar, nz, max(m,nj(min(kmax-1,jj+1))), 2, 'hydro_seulex', ierr)
                  if (ierr /= 0) return
               end if
            end if
         
            do mm = 1, m ! semi-implicit euler substeps
               
               ! at this point, del holds increment to solution for substep mm
               if (mm == 1) then
                  do k=nzlo,nzhi
                     do j=1,nvar
                        dx(j,k) = del(j,k)
                     end do
                  end do
               else
                  do k=nzlo,nzhi
                     do j=1,nvar
                        dx(j,k) = dx(j,k) + del(j,k)
                     end do
                  end do
               end if
               
               call set_dx_for_boundary_cells(mm, m, nzlo, nzhi, jj, hj*mm, ierr)
               if (ierr /= 0) then
                  if (trace) write(*,*) 'failed in set_dx_for_boundary_cells'
                  return
               end if

               ! now dx holds increment to initial soln after mm substeps
               if (iout==2) then
                  do k=max(1,nzlo-1),min(nz,nzhi+1)
                     do j=1,nvar
                        dx_save_new(j,k,mm) = dx(j,k)
                        if (.false. .and. j == 1 .and. k == 1063) &
                           write(*,4) 'dx_save(j,k,mm)', j, k, mm, dx_save_new(j,k,mm)
                     end do
                  end do
               end if


               if (mm == m) exit
               
               ! set up to do substep mm+1
               ! a failure from here on leaves the loop with ierr /= 0; it is
               ! reported after the matrix storage has been released below.

               call set_vars_for_solver( & ! including boundary cells nzlo-1 and nzhi+1
                  s, max(1,nzlo-1), min(nz,nzhi+1), mm, dx, xscale, hj*(mm+1), ierr)
               if (ierr /= 0) then
                  if (trace) write(*,2) 'failed in set_vars_for_solver', mm
                  exit
               end if

               skip_partials = .true. ! don't need to calculate jacobian
               ! this uses boundary cells, but doesn't calculate equ for them.
               call eval_equ_for_solver(s, nvar, nzlo, nzhi, hj*(mm+1), skip_partials, xscale, ierr)         
               if (ierr /= 0) then
                  if (trace) write(*,2) 'failed in eval_equ', mm
                  exit
               end if
               
               if (.false.) then
                  j = s% i_lnT
                  k = 1072
                  write(*,4) 'mm, dx(j,k), equ', mm, j, k, dx(j,k), s% equ(j,k)
               end if
            
               if (mm == 1 .and. jj <= 2) then
                  call do_stability_check(nzlo, nzhi, hj, ierr)
                  if (ierr /= 0) then
                     if (trace) write(*,2) 'failed in do_stability_check', mm
                     exit
                  end if
               end if
               
               do k=nzlo,nzhi
                  do j=1,nvar
                     del(j,k) = s% equ(j,k)
                  end do
               end do
               call solve_mtx(nzlo, nzhi, ierr)
               if (ierr /= 0) then
                  if (trace) write(*,2) 'failed in solve_mtx', mm
                  exit
               end if
            
            end do
            
            ! release the matrix storage using a separate status variable so that
            ! a failure recorded in ierr by the substep loop cannot be overwritten.
            ierr_dealloc = 0
            call dealloc_mtx(ierr_dealloc)
            if (ierr == 0) ierr = ierr_dealloc
            if (ierr /= 0) return
            
            ! t holds the current row of the extrapolation table
            ! t(jj,-) is the base result; t(1,-) is the fully extrapolated result
            ! first set t(jj,-)
            ! for this row, we are only changing nzlo:nzhi
            do k=nzlo,nzhi
               do j=1,nvar
                  t(jj,j,k) = dx(j,k)
                  if (.false. .and. j == s% i_lnT .and. k == 1072) then
                     write(*,4) 't(jj,j,k)', jj, j, k, t(jj,j,k)
                  end if
               end do
            end do
            if (nzlo > 1) then ! use previous row extrapolated solution for 1:nzlo-1
               do k=1,nzlo-1
                  do j=1,nvar
                     t(jj,j,k) = t(1,j,k)
                  end do
               end do
            end if
            if (nzhi < nz) then ! use previous row extrapolated solution for nzhi+1:nz
               do k=nzhi+1,nz
                  do j=1,nvar
                     t(jj,j,k) = t(1,j,k)
                  end do
               end do
            end if
         
            ! a single row gives nothing to extrapolate or compare against;
            ! err keeps the value 1d99 set on entry
            if (jj == 1) return
         
            ! polynomial extrapolation
            do l=jj,2,-1
               fac = 1d0/((dble(nj(jj))/dble(nj(l-1)))-1d0)
               do k=nzlo,nzhi
                  do j=1,nvar
                     t(l-1,j,k) = t(l,j,k) + (t(l,j,k)-t(l-1,j,k))*fac
                     if (.false. .and. j == s% i_lnT .and. k == 1072) then
                        write(*,4) 't(l-1,j,k)', l-1, j, k, t(l-1,j,k)
                     end if
                  end do
               end do
            end do
            
            call estimate_local_error(jj, nzlo, nzhi, ierr)
            if (ierr /= 0) then
               if (trace) write(*,2) 'failed in estimate_local_error', mm
               return
            end if
            
            ! compute optimal step size for this row (from E.Hairer's SEULEX code)
            expo = 1d0/dble(jj)
            facmin = dt_fac1**expo
            fac = min(dt_fac2/facmin, max(facmin, (err/dt_safe1)**expo/dt_safe2))
            limtr = limiter(1d0/fac)
            hh(jj) = h*limtr
            
            ! expected cost-benefit ratio for this row in next step
            w(jj) = a(jj)/hh(jj)
            
            if (trace) then
               if (jj == 2) write(*,'(/,a8,5a8,99a12)') &
                  'seulex', 'row', 'nzlo', 'nzhi', 'nz', 'n', 'n/nz', &
                  'var max err', 'k max err', 'err', 'lg max err', &
                  'lg err avg', 'nxt lg dt', 'lg cost/dt'
               write(*,'(8x,5i8,f12.6,a12,i12,99f12.6)') &
                  jj, nzlo, nzhi, nz, nzhi-nzlo+1, dble(nzhi-nzlo+1)/dble(nz), &
                  trim(s% nameofvar(j_max_err)), k_max_err, err, &
                  log10(err_ratio_max), log10(err_ratio_avg), log10(hh(jj)/secyer), log10(w(jj))
            end if

         end subroutine get_row

            
         subroutine estimate_local_error(jj, nzlo, nzhi, ierr)
            ! estimate local error for this row (only for nzlo:nzhi)
            !
            ! the error ratio of variable j in cell k is |t(1,j,k) - t(2,j,k)|
            ! times err_scale_inv(j,k); it is stored in error_vectors(j,k,jj-1).
            ! variables whose index equals skip1..skip4 get 0 and are left out.
            !
            ! sets the host variables err_ratio_max, j_max_err, k_max_err,
            ! err_ratio_avg, err, and the running limits err_ratio_avg_lim and
            ! err_ratio_max_lim; may set the host variable reject to true.
            ! ierr is -1 if the accumulated error is larger than 1d30 or is a bad number.
            integer, intent(in) :: jj, nzlo, nzhi
            integer, intent(out) :: ierr
            
            integer :: ivar, kz, num_terms
            integer :: skip_list(4)
            real(dp) :: ratio, ratio_total
            logical :: no_prior_limits
            include 'formats.dek'
            
            ierr = 0
            k_max_err = -1
            j_max_err = -1
            err_ratio_max = -1
            ratio_total = 0
            num_terms = 0
            skip_list = (/ skip1, skip2, skip3, skip4 /)
            
            do kz = nzlo, nzhi
               do ivar = 1, nvar
                  if (any(skip_list == ivar)) then
                     error_vectors(ivar,kz,jj-1) = 0
                     cycle
                  end if
                  ratio = abs(t(1,ivar,kz) - t(2,ivar,kz))*err_scale_inv(ivar,kz)
                  if (ratio > err_ratio_max) then ! new largest error so far
                     err_ratio_max = ratio
                     j_max_err = ivar
                     k_max_err = kz
                  end if
                  if (.false. .and. ivar == s% i_lnT .and. nzlo > 1) then
                     write(*,2) 'lnT err_ratio', kz, ratio, &
                        t(1,ivar,kz), t(2,ivar,kz), err_scale_inv(ivar,kz) 
                  end if
                  error_vectors(ivar,kz,jj-1) = ratio
                  ! cap each term so a single huge ratio cannot overflow the sum
                  ratio_total = ratio_total + min(ratio, 1d30)
                  if (ratio_total > 1d30) then
                     ierr = -1
                     return
                  end if
                  num_terms = num_terms + 1
               end do
            end do
            
            if (is_bad_num(ratio_total)) then
               ierr = -1
               return
            end if
            err_ratio_avg = ratio_total/num_terms
            
            ! the controlling error is either the average or the maximum ratio
            err = merge(err_ratio_avg, err_ratio_max, dt_by_err_ratio_avg)
            
            ! rows 1 and 2 just establish the limits; from row 3 on, a ratio
            ! that is not below its limit marks the row as rejected, otherwise
            ! the limit is reset from the new ratio.
            no_prior_limits = (jj < 3)
            
            if (no_prior_limits .or. .not. (err_ratio_avg >= err_ratio_avg_lim)) then
               err_ratio_avg_lim = max(4*err_ratio_avg,1d0)
            else
               reject = .true.
            end if
            
            if (no_prior_limits .or. .not. (err_ratio_max >= err_ratio_max_lim)) then
               err_ratio_max_lim = max(4*err_ratio_max,1d0)
            else
               reject = .true.
            end if
         end subroutine estimate_local_error
      

         subroutine do_stability_check(nzlo, nzhi, hj, ierr) 
            ! simplified Newton iteration stability check for 1st substep of 1st 2 rows
            !
            ! compares the size of the first substep dx with the size of the
            ! correction obtained by solving once more with the residual of that
            ! substep.  sizes are sums over nzlo:nzhi of |value|*err_scale_inv.
            !
            ! nzlo, nzhi   range of cells to check.
            ! hj           substep size.
            ! ierr         0 if okay; nonzero if solve_mtx fails or the
            !              correction is larger than max(1, size of dx).
            ! overwrites del with the correction.
            integer, intent(in) :: nzlo, nzhi
            real(dp), intent(in) :: hj
            integer, intent(out) :: ierr
            
            integer :: j, k
            real(dp) :: del1, del2
            
            include 'formats.dek'
         
            ierr = 0
            
            del1 = 0
            do k=nzlo,nzhi
               do j=1,nvar
                  del1 = del1 + abs(dx(j,k))*err_scale_inv(j,k)
               end do
            end do
            ! del1 is the sum of scaled magnitudes of the 1st substep
            
            ! right hand side for the correction: for ode variables subtract
            ! dx/hj from equ; other variables use equ as it is
            do k=nzlo,nzhi
               do j=1,nvar
                  if (s% ode_var(j)) then
                     del(j,k) = s% equ(j,k) - dx(j,k)/hj
                  else
                     del(j,k) = s% equ(j,k)
                  end if
               end do
            end do
            call solve_mtx(nzlo, nzhi, ierr)
            if (ierr /= 0) then
               if (trace) write(*,2) 'solve_mtx failed in do_stability_check'
               return
            end if
            ! del is the correction to the 1st substep dx
         
            del2 = 0
            do k=nzlo,nzhi
               do j=1,nvar
                  del2 = del2 + abs(del(j,k))*err_scale_inv(j,k)
               end do
            end do
            ! del2 is the sum of scaled magnitudes of the correction
            
            ! it is bad news if the correction is larger than the initial result
            if (del2 > max(1d0,del1)) then ! diverging
               if (trace) write(*,1) 'diverging', del2, del1
               ierr = -1
               return
            end if

         end subroutine do_stability_check
      
         
         subroutine prepare_lhs_matrix(nzlo, nzhi, hj, lblk, dblk, ublk, ierr)
            ! Build the block tridiagonal left-hand-side matrix (M/hj - J) for
            ! cells nzlo..nzhi from the Jacobian blocks lblk/dblk/ublk, keep an
            ! unfactored copy in lhs_[ldu]blk, and pass a second copy
            ! (lhs_[ldu]blkF) to decsolblk to be factored.
            !
            ! 1/hj is added only on the diagonal entries of variables flagged
            ! in s% ode_var.  Any previously factored matrix is released first.
            ! A factorization failure stops the program.
            !
            ! NOTE(review): the *_sub pointers are set up with
            ! ncells = nzhi-nzlo+1 blocks while decsolblk is given nz and the
            ! full del array; confirm this is intended when nzlo > 1 or nzhi < nz.
            use utils_lib, only: set_pointer_3
            integer, intent(in) :: nzlo, nzhi
            real(dp), intent(in) :: hj
            real(dp), pointer, dimension(:,:,:) :: lblk, dblk, ublk
            integer, intent(out) :: ierr
            
            integer, parameter :: op_factor = 0 ! decsolblk operation code
            integer :: row, col, kz, ncells
            real(dp), dimension(:,:,:), pointer :: lower_sub, diag_sub, upper_sub
            real(dp) :: inv_hj
            include 'formats.dek'

            ierr = 0
            inv_hj = 1d0/hj
            
!$OMP PARALLEL DO PRIVATE(kz,row,col)
            do kz = nzlo, nzhi
               ! start from -J
               do col = 1, nvar
                  do row = 1, nvar
                     lhs_lblk(row,col,kz) = -lblk(row,col,kz)
                     lhs_dblk(row,col,kz) = -dblk(row,col,kz)
                     lhs_ublk(row,col,kz) = -ublk(row,col,kz)
                  end do
               end do
               ! add M/hj: only the ODE variables get the 1/hj diagonal term
               do row = 1, nvar
                  if (s% ode_var(row)) &
                     lhs_dblk(row,row,kz) = inv_hj + lhs_dblk(row,row,kz)
               end do
               ! duplicate into the arrays that will be overwritten by factoring
               do col = 1, nvar
                  do row = 1, nvar
                     lhs_lblkF(row,col,kz) = lhs_lblk(row,col,kz)
                     lhs_dblkF(row,col,kz) = lhs_dblk(row,col,kz)
                     lhs_ublkF(row,col,kz) = lhs_ublk(row,col,kz)
                  end do
               end do
            end do
!$OMP END PARALLEL DO

            ! release the old factorization before creating a new one
            if (have_a_factored_mtx) then
               call dealloc_mtx(ierr)
               if (ierr /= 0) return
            end if
                 
            if (dbg) write(*,*) 'call decsolblk to factor'
            ncells = nzhi - nzlo + 1
            call set_pointer_3(lower_sub, lhs_lblkF, nvar, nvar, ncells)
            call set_pointer_3(diag_sub, lhs_dblkF, nvar, nvar, ncells)
            call set_pointer_3(upper_sub, lhs_ublkF, nvar, nvar, ncells)
            call decsolblk( &
               op_factor, caller_id, nvar, nz, lower_sub, diag_sub, upper_sub, &
               del, ipiv_blk, lrd, rpar_decsol, lid, ipar_decsol, ierr)
            if (ierr /= 0) then
               write(*,*) 'failed in decsolblk factor'
               stop 'hydro_seulex prepare_lhs_matrix'
            end if
            have_a_factored_mtx = .true.
         
         end subroutine prepare_lhs_matrix
         
         
         subroutine solve_mtx(nzlo, nzhi, ierr)
            ! Solve the previously factored system with del as the right-hand
            ! side; on success del(:,nzlo:nzhi) is overwritten by the solution
            ! multiplied by xscale.
            !
            ! nzlo, nzhi  range of cells being solved
            ! ierr        0 on success, otherwise the value returned by decsolblk
            use utils_lib, only: set_pointer_2, set_pointer_3
            integer, intent(in) :: nzlo, nzhi
            integer, intent(out) :: ierr
            integer, parameter :: op_solve = 1 ! decsolblk operation code
            integer :: kz, ncells
            real(dp), pointer :: rhs_sub(:,:)
            real(dp), dimension(:,:,:), pointer :: lower_sub, diag_sub, upper_sub
            include 'formats.dek'

            ierr = 0
            ncells = nzhi - nzlo + 1
            call set_pointer_2(rhs_sub, del, nvar, ncells)
            call set_pointer_3(lower_sub, lhs_lblkF, nvar, nvar, ncells)
            call set_pointer_3(diag_sub, lhs_dblkF, nvar, nvar, ncells)
            call set_pointer_3(upper_sub, lhs_ublkF, nvar, nvar, ncells)

            call decsolblk( &
               op_solve, caller_id, nvar, nz, &
               lower_sub, diag_sub, upper_sub, rhs_sub, ipiv_blk, &
               lrd, rpar_decsol, lid, ipar_decsol, ierr)
            if (ierr /= 0) return

            ! the partials are wrt scaled variables, so the raw solution must
            ! be multiplied by xscale to give the actual (unscaled) result
            do kz = nzlo, nzhi
               del(1:nvar,kz) = del(1:nvar,kz)*xscale(1:nvar,kz)
            end do
         end subroutine solve_mtx
         
         
         subroutine dealloc_mtx(ierr)
            ! Ask decsolblk to release what it holds for the factored matrix.
            ! have_a_factored_mtx is cleared whether or not decsolblk
            ! reports an error.
            !
            ! ierr  0 on success, otherwise the value returned by decsolblk
            integer, intent(out) :: ierr            
            integer, parameter :: op_dealloc = 2 ! decsolblk operation code
            ierr = 0
            call decsolblk( &
               op_dealloc, caller_id, nvar, nz, &
               lhs_lblkF, lhs_dblkF, lhs_ublkF, del, ipiv_blk, &
               lrd, rpar_decsol, lid, ipar_decsol, ierr)
            have_a_factored_mtx = .false.              
         end subroutine dealloc_mtx

         
         real(dp) function limiter(x)
            ! Smooth bounded map: 1 + kappa*atan((x-1)/kappa) with kappa = 2.
            ! limiter(1) = 1; for x >= 0 the value lies between
            ! about 0.07 (x = 0) and about 4.14 (x -> infinity).
            real(dp), intent(in) :: x
            real(dp), parameter :: kappa = 2d0
            real(dp) :: arg
            arg = (x - 1d0)/kappa
            limiter = 1d0 + kappa*atan(arg)
         end function limiter
         
      
      
      
      
      

 
         
         
         subroutine alloc_nvar_nz(p, ierr)
            ! Obtain a 2d work array for p, requested with sizes nvar and nz
            ! and nz_alloc_extra as the extra-allocation argument.
            !
            ! p     pointer that receives the work array
            ! ierr  status returned by get_2d_work_array
            use alloc, only: get_2d_work_array
            integer, intent(out) :: ierr
            real(dp), pointer :: p(:,:)
            call get_2d_work_array(s, p, nvar, nz, nz_alloc_extra, 'hydro_seulex', ierr)
         end subroutine alloc_nvar_nz
         
         
         subroutine alloc_nvar_nvar_nz(p, ierr)
            ! Obtain a 3d work array for p, requested with sizes nvar, nvar and
            ! nz and nz_alloc_extra as the extra-allocation argument.
            !
            ! p     pointer that receives the work array
            ! ierr  status returned by get_3d_work_array
            use alloc, only: get_3d_work_array
            integer, intent(out) :: ierr
            real(dp), pointer :: p(:,:,:)
            call get_3d_work_array(s, p, nvar, nvar, nz, nz_alloc_extra, 'hydro_seulex', ierr)
         end subroutine alloc_nvar_nvar_nz
      
      
         subroutine do_alloc(ierr)
            ! Obtain every work array used by the seulex step.
            !
            ! Requests are made one after another and stop at the first
            ! failure; in that case ierr is nonzero on return and the arrays
            ! obtained so far are not released here.
            !
            ! The error-vector array lives in s% seulex_error_vectors: any
            ! array left there from a previous step is returned first, then a
            ! fresh one is obtained and error_vectors is pointed at it.
            use alloc, only: &
               get_2d_work_array, get_3d_work_array, get_integer_2d_work_array
            use mtx_lib, only: lapack_work_sizes
            integer, intent(out) :: ierr
            character (len=*), parameter :: tag = 'hydro_seulex'
            integer, parameter :: extra = 2

            ierr = 0            
            
            ! (nvar, nz) arrays
            call alloc_nvar_nz(equ_init, ierr)
            if (ierr == 0) call alloc_nvar_nz(del, ierr)
            if (ierr == 0) call alloc_nvar_nz(err_scale_inv, ierr)
            if (ierr == 0) call alloc_nvar_nz(dx, ierr)
            if (ierr == 0) call alloc_nvar_nz(xscale, ierr)
            
            ! extrapolation tableau and saved substep results
            if (ierr == 0) call get_3d_work_array(s, t, kmax, nvar, nz, extra, tag, ierr)
            if (ierr == 0) call get_3d_work_array(s, dx_save_new, nvar, nz, nj(3), extra, tag, ierr)
            if (ierr == 0) call get_3d_work_array(s, dx_save_prev, nvar, nz, nj(3), extra, tag, ierr)
                        
            ! (nvar, nvar, nz) block arrays: jacobian, its initial copy,
            ! the lhs matrix and the copy that gets factored
            if (ierr == 0) call alloc_nvar_nvar_nz(s% lblk, ierr)
            if (ierr == 0) call alloc_nvar_nvar_nz(s% dblk, ierr)
            if (ierr == 0) call alloc_nvar_nvar_nz(s% ublk, ierr)
            if (ierr == 0) call alloc_nvar_nvar_nz(lblk_init, ierr)
            if (ierr == 0) call alloc_nvar_nvar_nz(dblk_init, ierr)
            if (ierr == 0) call alloc_nvar_nvar_nz(ublk_init, ierr)
            if (ierr == 0) call alloc_nvar_nvar_nz(lhs_lblk, ierr)
            if (ierr == 0) call alloc_nvar_nvar_nz(lhs_dblk, ierr)
            if (ierr == 0) call alloc_nvar_nvar_nz(lhs_ublk, ierr)
            if (ierr == 0) call alloc_nvar_nvar_nz(lhs_lblkF, ierr)
            if (ierr == 0) call alloc_nvar_nvar_nz(lhs_dblkF, ierr)
            if (ierr == 0) call alloc_nvar_nvar_nz(lhs_ublkF, ierr)
            
            ! pivot indices for the block solver
            if (ierr == 0) call get_integer_2d_work_array( &
               s, ipiv_blk, nvar, nz, nz_alloc_extra, ierr)
            if (ierr /= 0) return         
            
            ! return previous step info before getting a new array
            if (associated(s% seulex_error_vectors)) &
               call return_3d(s% seulex_error_vectors)
            call get_3d_work_array( &
               s, s% seulex_error_vectors, nvar, nz, kmax-1, extra, tag, ierr)
            if (ierr /= 0) return
            error_vectors => s% seulex_error_vectors

         end subroutine do_alloc
         
         
         subroutine dealloc
            ! Release the factored matrix (if any) and hand back the work
            ! arrays obtained by do_alloc.
            !
            ! ierr is not declared in this routine, so the status of
            ! dealloc_mtx is reported through the enclosing scope's ierr.
            ! A failure there no longer skips the rest of the cleanup:
            ! returning early left every work array below checked out.
            !
            ! s% seulex_error_vectors is deliberately not returned here;
            ! do_alloc returns the previous one at the start of the next call.
            use alloc
            include 'formats.dek'
            
            if (have_a_factored_mtx) call dealloc_mtx(ierr)
            
            call return_2d(equ_init) 
            call return_2d(del) 
            call return_2d(err_scale_inv) 
            call return_2d(dx) 
            call return_2d(xscale)
            
            call return_3d(t)   
            call return_3d(dx_save_new)   
            call return_3d(dx_save_prev)   
            call return_3d(s% lblk)   
            call return_3d(s% dblk)   
            call return_3d(s% ublk)   
            call return_3d(lblk_init)   
            call return_3d(dblk_init)   
            call return_3d(ublk_init)   
            call return_3d(lhs_lblk)   
            call return_3d(lhs_dblk)   
            call return_3d(lhs_ublk)   
            call return_3d(lhs_lblkF)   
            call return_3d(lhs_dblkF)   
            call return_3d(lhs_ublkF)
             
            call return_integer_2d(ipiv_blk)         
            
         end subroutine dealloc

         
         subroutine return_1d(p)
            ! Hand the 1d work array p back via return_work_array,
            ! tagged 'hydro_seulex'.
            use alloc, only: return_work_array
            real(dp), pointer :: p(:)
            call return_work_array( &
               s, p, 'hydro_seulex')
         end subroutine return_1d
         
         
         subroutine return_2d(p)
            ! Hand the 2d work array p back via return_2d_work_array,
            ! tagged 'hydro_seulex'.
            use alloc, only: return_2d_work_array
            real(dp), pointer :: p(:,:)
            call return_2d_work_array( &
               s, p, 'hydro_seulex')
         end subroutine return_2d
         
         
         subroutine return_3d(p)
            ! Hand the 3d work array p back via return_3d_work_array,
            ! tagged 'hydro_seulex'.
            use alloc, only: return_3d_work_array
            real(dp), pointer :: p(:,:,:)
            call return_3d_work_array( &
               s, p, 'hydro_seulex')
         end subroutine return_3d
         
         
         subroutine return_integer_2d(p)
            ! Hand the integer 2d work array p back via
            ! return_integer_2d_work_array.
            use alloc, only: return_integer_2d_work_array
            integer, pointer :: p(:,:)
            call return_integer_2d_work_array( &
               s, p)
         end subroutine return_integer_2d


         subroutine set_bounds(jj, nzlo, nzhi)
            ! Choose the range of cells nzlo..nzhi to work on for row jj.
            !
            ! The full range 1..nz is used for rows 1 and 2 and whenever the
            ! previous row had fewer than 3 substeps.  Otherwise the error
            ! estimates in error_vectors(:,k,jj-2) are scanned:
            !   nzlo: first cell at or after prev_nzlo whose largest estimate
            !         exceeds max_err_lim, moved back by cell_margin cells
            !         but not below prev_nzlo
            !   nzhi: first cell at or before prev_nzhi (scanning toward 1)
            !         whose largest estimate exceeds max_err_lim, moved up
            !         by cell_margin cells but not above prev_nzhi
            ! A bound keeps its full-range value when its scan finds no cell.
            ! A line is printed whenever the range is narrower than 1..nz.
            integer, intent(in) :: jj
            integer, intent(out) :: nzlo, nzhi
            
            integer, parameter :: cell_margin = 5
            real(dp), parameter :: max_err_lim = 5d-2

            integer :: k, err_row, n_active
            real(dp) :: worst
            
            include 'formats.dek'

            nzlo = 1
            nzhi = nz
            
            if (jj <= 2) return
            if (nj(jj-1) < 3) return
            
            err_row = jj - 2 ! slot holding the estimates we need
            
            scan_from_low: do k = prev_nzlo, nz
               worst = maxval(error_vectors(:,k,err_row))
               if (worst > max_err_lim) then
                  nzlo = max(prev_nzlo, k - cell_margin)
                  exit scan_from_low
               end if
            end do scan_from_low
            
            scan_from_high: do k = prev_nzhi, 1, -1
               worst = maxval(error_vectors(:,k,err_row))
               if (worst > max_err_lim) then
                  nzhi = min(prev_nzhi, k + cell_margin)
                  exit scan_from_high
               end if
            end do scan_from_high
            
            if (nzlo > 1 .or. nzhi < nz) then
               n_active = nzhi - nzlo + 1
               write(*,'(a50,6i6,99f12.4)') 'set_bounds: model row nzlo nzhi n nz n/nz', &
                  s% model_number, jj, nzlo, nzhi, n_active, nz, dble(n_active)/dble(nz)
            end if

         end subroutine set_bounds
         
         
         subroutine set_dx_for_boundary_cells(mm, m, nzlo, nzhi, jj, partial_dt, ierr)
            ! Fill dx for the boundary cells nzlo-1 and nzhi+1 (those that
            ! exist) from the results saved for the previous row.
            !
            ! Nothing is done for row 1 or when the range is the full 1..nz.
            ! After the last substep (mm == m) the final saved values of the
            ! previous row are copied; otherwise the values are interpolated
            ! at the fraction partial_dt/dt of the step.
            !
            ! mm          substeps taken so far in this row
            ! m           total substeps in this row
            ! nzlo, nzhi  range of cells currently being solved
            ! jj          current row number
            ! partial_dt  time covered so far within the step
            ! ierr        0 on success, otherwise from interpolate_dx
            integer, intent(in) :: mm, m, nzlo, nzhi, jj
            real(dp), intent(in) :: partial_dt
            integer, intent(out) :: ierr
            
            real(dp) :: frac
            integer :: k_inner, k_outer, n_prev
            
            include 'formats.dek'

            ierr = 0
            
            if (jj == 1) return
            if (nzlo == 1 .and. nzhi == nz) return
            
            n_prev = nj(jj-1) ! number of substeps in the previous row
            k_inner = nzlo - 1
            k_outer = nzhi + 1
            
            if (mm == m) then
               ! all substeps done: use the final values from the previous row
               if (k_inner >= 1) &
                  dx(1:nvar,k_inner) = dx_save_prev(1:nvar,k_inner,n_prev)
               if (k_outer <= nz) &
                  dx(1:nvar,k_outer) = dx_save_prev(1:nvar,k_outer,n_prev)
               return
            end if
            
            frac = partial_dt/dt ! fraction of dt for interpolation
            if (nzlo > 1) then
               call interpolate_dx(jj, n_prev, frac, k_inner, ierr)
               if (ierr /= 0) return            
            end if
            if (nzhi < nz) call interpolate_dx(jj, n_prev, frac, k_outer, ierr)

         end subroutine set_dx_for_boundary_cells
         
         
         subroutine interpolate_dx(jj, mprev, theta, k, ierr)
            ! Set dx(1:nvar,k) for cell k at fraction theta of the full step by
            ! interpolating the substep results saved from the previous row.
            !
            ! The previous row took mprev equal substeps, so sample point i
            ! (i = 1..mprev+1) sits at fraction pts(i) = (i-1)/mprev.  The
            ! sample at pts(1) = 0 is dx = 0 and the sample at pts(i), i > 1,
            ! is dx_save_prev(:,k,i-1).  Four consecutive sample points
            ! ilo..ihi that bracket theta are passed to interpolate_vector.
            !
            ! Changes relative to the earlier version:
            !  - the samples are now taken at the absolute indices ilo..ihi;
            !    before, the window-relative index was used, so whenever
            !    ilo > 1 the values 0, dx_save_prev(:,k,1:3) were paired with
            !    the abscissae pts(ilo:ihi)
            !  - the loop index i is a local variable instead of being picked
            !    up from the enclosing scope
            !  - i0 has a defined starting value
            !
            ! jj     current row number (not referenced in the body)
            ! mprev  number of substeps in the previous row, nj(jj-1)
            ! theta  fraction of dt at which to interpolate, 0 <= theta <= 1
            ! k      cell for interpolation
            ! ierr   0 on success.  The consistency checks below end in a hard
            !        stop, so their ierr/return lines are only reached if the
            !        stop statements are removed.
            use interp_1d_lib
            use interp_1d_def
            integer, intent(in) :: jj ! jj is current row number
            integer, intent(in) :: mprev ! mprev = nj(jj-1)
            integer, intent(in) :: k ! cell for interpolation
            real(dp), intent(in) :: theta ! fraction of dt for interpolation
            integer, intent(out) :: ierr
            
            integer, parameter :: nwork = pm_work_size, n_new = 1, n_old = 4
            integer :: i, i0, ilo, ihi, j, mm, ii
            real(dp) :: pts(mprev+1), delta, &
               x_old(n_old), v_old(n_old), x_new(n_new), v_new(n_new)
            real(dp), target :: work(4,nwork)
            
            include 'formats.dek'

            ierr = 0
            if (theta == 0) then ! start of the step: nothing has changed yet
               dx(1:nvar,k) = 0d0
               return
            end if
            if (theta < 0d0 .or. theta > 1d0) then
               write(*,1) 'interpolate_dx bad theta', theta
               
               stop
               
               ierr = -1
               return
            end if
            
            if (mprev < 3) then ! we only do interpolation with 3 or more substeps
               write(*,2) 'interpolate_dx mprev < 3', mprev
               
               stop
               
               ierr = -1
               return
            end if
            
            ! build the sample fractions and find the interval
            ! (pts(i0-1), pts(i0)] that contains theta
            delta = 1d0/dble(mprev) ! fractional size of each substep in prev row
            i0 = 1
            do i=1,mprev+1
               if (i == mprev+1) then
                  pts(i) = 1d0 ! avoid roundoff at the end of the step
               else
                  pts(i) = (i-1)*delta
               end if
               if (i == 1) cycle
               if (pts(i-1) < theta .and. theta <= pts(i)) i0 = i
            end do
            ! pick a window of 4 sample points containing that interval
            if (i0 == 1) then
               ilo = 1; ihi = 4
            else
               ihi = min(i0+2,mprev+1)
               ilo = ihi - 3
            end if

            if (.false.) then ! debugging output
               write(*,*) 'interpolate_dx'
               j = 1
               do mm=1,3
                  write(*,4) 'dx_save_prev(j,k,mm)', j, k, mm, dx_save_prev(j,k,mm)
               end do
               write(*,2) 'k', k
               write(*,2) 'mprev', mprev
               write(*,2) 'ilo', ilo, pts(ilo)
               write(*,2) 'ihi', ihi, pts(ihi)
               write(*,1) 'delta', delta
               write(*,1) 'theta', theta
               do i=1,mprev+1
                  write(*,2) 'pts(i)', i, pts(i)
               end do
               write(*,*)
            end if

            
            if (.not. (pts(ilo) < theta .and. theta <= pts(ihi))) then
               write(*,*) 'bad ilo iho in interpolate_dx'
               stop
               
            
               ierr = -1
               return
            end if
            
            if (ihi - ilo + 1 /= n_old) then
               write(*,*) 'ihi - ilo + 1 /= n_old in interpolate_dx'
               
               stop
               
            
               ierr = -1
               return
            end if
            
            x_new(1) = theta
            x_old(1:n_old) = pts(ilo:ihi)
            do j=1,nvar
               do i=ilo,ihi
                  ii = i+1-ilo ! position of sample point i inside the window
                  if (i == 1) then
                     v_old(ii) = 0 ! dx is zero at the start of the step
                  else
                     v_old(ii) = dx_save_prev(j,k,i-1)
                  end if
               end do
               call interpolate_vector( &
                  n_old, x_old, n_new, x_new, v_old, v_new, interp_pm, nwork, work, ierr)
               if (ierr /= 0) then
                  write(*,2) 'interpolate_vector failed in interpolate_dx n_old', n_old
                  do i=1,n_old
                     write(*,2) 'x_old', i, x_old(i), pts(i)
                  end do
                  stop
                  return
               end if
               dx(j,k) = v_new(1)
            end do
            
            if (.false.) then ! debugging output
               write(*,4) 'interpolate_dx k ilo ihi', k, ilo, ihi
               do j=1,1
                  write(*,2) 'theta', k, theta
                  write(*,2) 'dx(j,k)', j, dx(j,k)
                  do i=ilo,ihi
                     if (i > 1) write(*,4) 'dx_save_prev(j,k,i-1)', j, k, i-1, dx_save_prev(j,k,i-1)
                  end do
               end do
               write(*,*)
               !stop 'interpolate_dx'
            end if
            
         end subroutine interpolate_dx

         
      end function do_hydro_seulex
      

      

      end module hydro_seulex

