! ***********************************************************************
! Copyright (C) 2012  Bill Paxton
! This file is part of MESA.
! MESA is free software; you can redistribute it and/or modify
! it under the terms of the GNU General Library Public License as published
! by the Free Software Foundation; either version 2 of the License, or
! (at your option) any later version.
! MESA is distributed in the hope that it will be useful,
! but WITHOUT ANY WARRANTY; without even the implied warranty of
! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
! GNU Library General Public License for more details.
! You should have received a copy of the GNU Library General Public License
! along with this software; if not, write to the Free Software
! Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
! ***********************************************************************

! derived from BCYCLIC, written by Hirshman et al.
! S.P.Hirshman, K.S.Perumalla, V.E.Lynch, & R.Sanchez,
! BCYCLIC: A parallel block tridiagonal matrix cyclic solver,
! J. Computational Physics, 229 (2010) 6392-6404.


      module mod_star_bcyclic

      use star_private_def
      use const_def, only: dp
      use mtx_lib, only: num_klu_ipar_decsol, num_klu_rpar_decsol, &
         klu_dble_decsols_nrhs_0_based, dense_to_col_with_diag_0_based
      use utils_lib, only: fill_with_NaNs
      
      implicit none

      integer, parameter :: lid = num_klu_ipar_decsol
      integer, parameter :: lrd = num_klu_rpar_decsol
      
      logical, parameter :: dbg = .false.
      logical, parameter :: do_fill_with_NaNs = .false.
      
      
      contains


      ! Factor a block-tridiagonal matrix by cyclic reduction (BCYCLIC).
      ! Block row k (1..nz) occupies elements nvar*nvar*(k-1)+1 .. nvar*nvar*k
      ! of lblk1/dblk1/ublk1 and nvar*(k-1)+1 .. nvar*k of ipivot1.
      ! On exit those arrays hold the reduced/factored blocks used by
      ! bcyclic_solve; per-level copies of the odd rows' l,u blocks are kept in
      ! s% bcyclic_odd_storage, and (when sparse) per-block sparse factors in
      ! s% bcyclic_klu_storage.  ierr /= 0 on allocation or factor failure.
      subroutine bcyclic_factor ( &
            s, lblk1, dblk1, ublk1, ipivot1, brhs1, nvar, nz, sparse, iter, &
            lrd, rpar_decsol, lid, ipar_decsol, ierr)
         type (star_info), pointer :: s
         real(dp), pointer :: lblk1(:) ! row section of lower block
         real(dp), pointer :: dblk1(:) ! row section of diagonal block
         real(dp), pointer :: ublk1(:) ! row section of upper block
         integer, pointer :: ipivot1(:) ! row section of pivot array for block factorization
         real(dp), pointer :: brhs1(:) ! row section of rhs (not referenced here)
         integer, intent(in) :: nvar ! linear size of each block
         integer, intent(in) :: nz ! number of block rows
         logical, intent(in) :: sparse ! .true. => factor diagonal blocks with the sparse solver
         integer, intent(in) :: lrd, lid
         integer, intent(in) :: iter ! passed to sparse_factor (iter > 1 allows a refactor attempt)
         real(dp), pointer, intent(inout) :: rpar_decsol(:) ! (lrd)
         integer, pointer, intent(inout) :: ipar_decsol(:) ! (lid)
         integer, intent(out) :: ierr

         integer, pointer :: nslevel(:), ipivot(:)
         integer :: ncycle, nstemp, maxlevels, nlevel, k
         logical :: have_odd_storage, have_klu_storage
         real(dp), pointer, dimension(:,:) :: dmat
         real(dp) :: dlamch, sfmin

         include 'formats'

         ierr = 0

         if (dbg) write(*,*) 'start bcyclic_factor'

         ! compute number of cyclic reduction levels: ceil(log2(nz)), at least 1
         ncycle = 1
         maxlevels = 0
         do while (ncycle < nz)
            ncycle = 2*ncycle
            maxlevels = maxlevels+1
         end do
         maxlevels = max(1, maxlevels)

         ! reuse the per-level odd-row storage when it has enough levels
         have_odd_storage = associated(s% bcyclic_odd_storage)
         if (have_odd_storage) then
            if (size(s% bcyclic_odd_storage) < maxlevels) then
               call clear_storage(s)
               have_odd_storage = .false.
            end if
         end if

         if (.not. have_odd_storage) then
            allocate (s% bcyclic_odd_storage(maxlevels+3), stat=ierr)
            if (ierr /= 0) then
               write(*,*) 'alloc failed for odd_storage in bcyclic'
               return
            end if
            do nlevel = 1, size(s% bcyclic_odd_storage)
               s% bcyclic_odd_storage(nlevel)% ul_size = 0
            end do
         end if

         ! sparse: one storage slot per block row (indexed 1..nz)
         if (sparse) then
            have_klu_storage = associated(s% bcyclic_klu_storage)
            if (have_klu_storage) then
               if (size(s% bcyclic_klu_storage) < nz) then
                  call clear_klu_storage(s)
                  have_klu_storage = .false.
               end if
            end if
            if (.not. have_klu_storage) then
               allocate (s% bcyclic_klu_storage(nz*2 + 1000), stat=ierr)
               if (ierr /= 0) then
                  write(*,*) 'alloc failed for klu_storage in bcyclic'
                  return
               end if
               do k = 1, size(s% bcyclic_klu_storage)
                  s% bcyclic_klu_storage(k)% sprs_nonzeros = -1
                  s% bcyclic_klu_storage(k)% ia => null()
                  s% bcyclic_klu_storage(k)% ja => null()
                  s% bcyclic_klu_storage(k)% values => null()
               end do
            end if
         end if

         ! allocated only after the persistent storage is in place, so the
         ! early returns above cannot leak it
         allocate (nslevel(maxlevels), stat=ierr)
         if (ierr /= 0) return

         ncycle = 1
         nstemp = nz
         nlevel = 1

         if (dbg) write(*,*) 'start factor_cycle'

         factor_cycle: do ! perform cyclic-reduction factorization

            nslevel(nlevel) = nstemp

            if (dbg) write(*,2) 'call cycle_onestep', nstemp

            call cycle_onestep( &
               s, nvar, nz, nstemp, ncycle, nlevel, sparse, iter, &
               lblk1, dblk1, ublk1, ipivot1, ierr)
            if (ierr /= 0) then
               call dealloc
               return
            end if

            if (nstemp == 1) exit

            ! the odd rows form the next, half-size system
            nstemp = (nstemp+1)/2
            nlevel = nlevel+1
            ncycle = 2*ncycle

            if (nlevel > maxlevels) exit

         end do factor_cycle

         if (dbg) write(*,*) 'done factor_cycle'

         ! factor row 1 (the single block row left after the reduction)
         dmat(1:nvar,1:nvar) => dblk1(1:nvar*nvar)
         if (sparse) then
            call sparse_factor(s, 1, nvar, iter, dmat, ierr)
         else
            sfmin = dlamch('S')
            ipivot(1:nvar) => ipivot1(1:nvar)
            call my_getf2(nvar, dmat, nvar, ipivot, sfmin, ierr)
         end if
         if (ierr /= 0) then
            write(*,*) 'row 1 factor failed in bcyclic_factor'
            call dealloc
            return
         end if

         call dealloc

         if (dbg) write(*,*) 'done bcyclic_factor'

         contains

         subroutine dealloc
            deallocate (nslevel)
         end subroutine dealloc

      end subroutine bcyclic_factor


      ! Solve with the factorization produced by bcyclic_factor.
      ! brhs1 holds the right-hand side on entry and is overwritten in place:
      ! a forward sweep (cycle_rhs) reduces it level by level, the final
      ! single block row is solved directly, and a back sweep (cycle_solve)
      ! recovers the even rows of each level.  ierr /= 0 on failure.
      subroutine bcyclic_solve ( &
            s, lblk1, dblk1, ublk1, ipivot1, brhs1, nvar, nz, sparse, &
            lrd, rpar_decsol, lid, ipar_decsol, ierr)
         type (star_info), pointer :: s
         real(dp), pointer :: lblk1(:)   ! lower blocks, as left by bcyclic_factor
         real(dp), pointer :: dblk1(:)   ! diagonal blocks / factors
         real(dp), pointer :: ublk1(:)   ! upper blocks, as left by bcyclic_factor
         integer, pointer :: ipivot1(:)  ! dense-block pivots
         real(dp), pointer :: brhs1(:)   ! rhs in, solution out
         integer, intent(in) :: nvar     ! linear size of each block
         integer, intent(in) :: nz       ! number of block rows
         logical, intent(in) :: sparse
         integer, intent(in) :: lrd, lid
         real(dp), pointer, intent(inout) :: rpar_decsol(:) ! (lrd)
         integer, pointer, intent(inout) :: ipar_decsol(:) ! (lid)
         integer, intent(out) :: ierr

         integer, pointer :: level_size(:), piv(:)
         integer :: stride, nblk_now, max_levels, level, nsq
         real(dp), pointer, dimension(:,:) :: diag, rhs2d

         include 'formats'

         if (dbg) write(*,*) 'start bcyclic_solve'

         ierr = 0
         nsq = nvar*nvar

         ! number of halvings needed to get down to one block row (>= 1)
         max_levels = 0
         stride = 1
         do while (stride < nz)
            max_levels = max_levels + 1
            stride = stride*2
         end do
         if (max_levels < 1) max_levels = 1

         allocate (level_size(max_levels), stat=ierr)
         if (ierr /= 0) return

         level = 1
         nblk_now = nz
         stride = 1

         if (dbg) write(*,*) 'start forward_cycle'

         forward_cycle: do
            level_size(level) = nblk_now
            if (dbg) write(*,2) 'call cycle_rhs', nblk_now
            call cycle_rhs( &
               s, nblk_now, nvar, stride, level, sparse, dblk1, brhs1, ipivot1, ierr)
            if (ierr /= 0) then
               deallocate (level_size)
               return
            end if
            if (nblk_now == 1) exit forward_cycle
            nblk_now = (nblk_now + 1)/2
            level = level + 1
            stride = stride*2
            if (level > max_levels) exit forward_cycle
         end do forward_cycle

         if (dbg) write(*,*) 'done forward_cycle'

         ! block row 1 is all that is left: solve it directly
         if (.not. sparse) then
            piv(1:nvar) => ipivot1(1:nvar)
            diag(1:nvar,1:nvar) => dblk1(1:nsq)
            rhs2d(1:nvar,1:1) => brhs1(1:nvar)
            call my_getrs(nvar, 1, diag, nvar, piv, rhs2d, nvar, ierr)
         else
            call sparse_solve(s,1,1,nvar,brhs1,ierr)
         end if
         if (ierr /= 0) then
            write(*,*) 'failed in bcyclic_solve'
            deallocate (level_size)
            return
         end if

         ! unwind the levels, filling in the even rows at each one
         back_cycle: do while (stride > 1)
            stride = stride/2
            level = level - 1
            if (level < 1) then
               ierr = -1
               exit back_cycle
            end if
            nblk_now = level_size(level)
            call cycle_solve( &
               s, nvar, nz, stride, nblk_now, level, sparse, lblk1, ublk1, brhs1)
         end do back_cycle

         deallocate (level_size)

         if (dbg) write(*,*) 'done bcyclic_solve'

      end subroutine bcyclic_solve
      
      
      ! Free the per-level odd-row l/u copies (levels that ever received an
      ! allocation have ul_size > 0), then the level array itself.
      subroutine clear_storage(s)
         type (star_info), pointer :: s
         integer :: lev
         do lev = size(s% bcyclic_odd_storage), 1, -1
            if (s% bcyclic_odd_storage(lev)% ul_size <= 0) cycle
            deallocate(s% bcyclic_odd_storage(lev)% umat1)
            deallocate(s% bcyclic_odd_storage(lev)% lmat1)
         end do
         deallocate(s% bcyclic_odd_storage)
         nullify(s% bcyclic_odd_storage)
      end subroutine clear_storage
      
      
      ! Free the compressed-column arrays (ia, ja, values) of every sparse
      ! slot, then the slot array itself.  Does not call the solver's free job;
      ! see bcyclic_deallocate for that.
      subroutine clear_klu_storage(s)
         type (star_info), pointer :: s
         type(sparse_info), pointer :: ks(:)
         integer :: n
         ks => s% bcyclic_klu_storage
         do n = 1, size(ks)
            if (associated(ks(n)% ia)) deallocate(ks(n)% ia)
            if (associated(ks(n)% ja)) deallocate(ks(n)% ja)
            if (associated(ks(n)% values)) deallocate(ks(n)% values)
         end do
         deallocate(s% bcyclic_klu_storage)
         nullify(s% bcyclic_klu_storage)
      end subroutine clear_klu_storage


      ! One level of cyclic-reduction factorization.  At this level the active
      ! block rows are k = ncycle*(ns-1)+1 for ns = 1..nblk.  Steps:
      !   1. copy the odd rows' l and u blocks into s% bcyclic_odd_storage(nlevel)
      !      (cycle_rhs needs the unmodified odd blocks later);
      !   2. LU-factor the even diagonal blocks (my_getf2, or sparse_factor);
      !   3. overwrite even l,u with -d^{-1} l and -d^{-1} u;
      !   4. form the reduced odd rows in place in lblk1/dblk1/ublk1.
      ! ierr /= 0 on allocation failure or any failed even-block factor/solve.
      subroutine cycle_onestep( &
            s, nvar, nz, nblk, ncycle, nlevel, sparse, iter, &
            lblk1, dblk1, ublk1, ipivot1, ierr)
         type (star_info), pointer :: s
         integer, intent(in) :: nvar, nz, nblk, ncycle, nlevel, iter
         logical, intent(in) :: sparse
         real(dp), pointer, intent(inout) :: lblk1(:), dblk1(:), ublk1(:)
         ! intent(in) constrains only the pointer association; the pivot values
         ! in the target are still written.  (With intent(out) the association
         ! would be undefined on entry.)
         integer, pointer, intent(in) :: ipivot1(:)
         integer, intent(out) :: ierr

         integer, pointer :: ipivot(:)
         real(dp), pointer, dimension(:,:) :: dmat, umat, lmat, umat0, lmat0
         real(dp), pointer, dimension(:,:) :: lnext, unext, lprev, uprev
         real(dp), pointer :: mat1(:)
         integer :: i, shift, min_sz, new_sz, shift1, shift2, nvar2, &
            ns, ierr_loc, nmin, kcount, k
         real(dp) :: dlamch, sfmin

         include 'formats'

         ierr = 0
         sfmin = dlamch('S')
         nvar2 = nvar*nvar
         nmin = 1
         kcount = 1+(nblk-nmin)/2   ! number of odd block rows at this level
         min_sz = nvar2*kcount
         if (s% bcyclic_odd_storage(nlevel)% ul_size < min_sz) then
            if (s% bcyclic_odd_storage(nlevel)% ul_size > 0) &
               deallocate(s% bcyclic_odd_storage(nlevel)% umat1, s% bcyclic_odd_storage(nlevel)% lmat1)
            ! grow by ~10% plus slack (integer arithmetic, no real conversion)
            new_sz = min_sz + min_sz/10 + 100
            ! record the capacity only once the storage really exists
            s% bcyclic_odd_storage(nlevel)% ul_size = 0
            allocate (s% bcyclic_odd_storage(nlevel)% umat1(new_sz), &
                      s% bcyclic_odd_storage(nlevel)% lmat1(new_sz), stat=ierr)
            if (ierr /= 0) then
               write(*,*) 'allocation error in cycle_onestep'
               return
            end if
            s% bcyclic_odd_storage(nlevel)% ul_size = new_sz
         end if

         ! capacity check before the copy loop writes into the storage
         if (min_sz > s% bcyclic_odd_storage(nlevel)% ul_size) then
            write(*,*) 'nvar2*kcount > ul_size in cycle_onestep'
            ierr = -1
            return
         end if

!$omp parallel do private(ns,kcount,shift,shift2,i)
         do ns = nmin, nblk, 2  ! copy umat and lmat
            kcount = (ns-nmin)/2 + 1
            shift = nvar2*(kcount-1)
            shift2 = nvar2*ncycle*(ns-1)
            do i=1,nvar2
               s% bcyclic_odd_storage(nlevel)% umat1(shift+i) = ublk1(shift2+i)
               s% bcyclic_odd_storage(nlevel)% lmat1(shift+i) = lblk1(shift2+i)
            end do
         end do
!$omp end parallel do

         if (dbg) write(*,*) 'start lu factorization'
         ! compute lu factorization of even diagonal blocks
         nmin = 2
!$omp parallel do schedule(static,3) &
!$omp private(ipivot,dmat,ns,ierr_loc,shift1,shift2,k)
         do ns = nmin, nblk, 2
            k = ncycle*(ns-1) + 1
            shift1 = nvar*(k-1)
            shift2 = nvar*shift1
            dmat(1:nvar,1:nvar) => dblk1(shift2+1:shift2+nvar2)
            ierr_loc = 0
            if (sparse) then
               call sparse_factor(s, k, nvar, iter, dmat, ierr_loc)
            else
               ipivot(1:nvar) => ipivot1(shift1+1:shift1+nvar)
               call my_getf2(nvar, dmat, nvar, ipivot, sfmin, ierr_loc)
            end if
            if (ierr_loc /= 0) then
               ierr = ierr_loc
            end if
         end do
!$omp end parallel do
         if (ierr /= 0) then
            return
         end if

         if (dbg) write(*,*) 'done lu factorization; start solve'

!$omp parallel do schedule(static,3) &
!$omp private(ns,k,shift1,shift2,ipivot,dmat,umat,lmat,mat1,ierr_loc)
         do ns = nmin, nblk, 2
            ! compute new l=-d[-1]l, u=-d[-1]u for even blocks
            k = ncycle*(ns-1) + 1
            shift1 = nvar*(k-1)
            shift2 = nvar*shift1
            lmat(1:nvar,1:nvar) => lblk1(shift2+1:shift2+nvar2)
            if (sparse) then
               mat1(1:nvar2) => lblk1(shift2+1:shift2+nvar2)
               call sparse_solve(s,k,nvar,nvar,mat1,ierr_loc)
            else
               ipivot(1:nvar) => ipivot1(shift1+1:shift1+nvar)
               dmat(1:nvar,1:nvar) => dblk1(shift2+1:shift2+nvar2)
               call my_getrs(nvar, nvar, dmat, nvar, ipivot, lmat, nvar, ierr_loc)
            end if
            if (ierr_loc /= 0) ierr = ierr_loc
            lmat = -lmat
            umat(1:nvar,1:nvar) => ublk1(shift2+1:shift2+nvar2)
            if (sparse) then
               mat1(1:nvar2) => ublk1(shift2+1:shift2+nvar2)
               call sparse_solve(s,k,nvar,nvar,mat1,ierr_loc)
            else
               call my_getrs(nvar, nvar, dmat, nvar, ipivot, umat, nvar, ierr_loc)
            end if
            if (ierr_loc /= 0) ierr = ierr_loc
            umat = -umat
         end do
!$omp end parallel do
         if (dbg) write(*,*) 'done solve'

         if (ierr /= 0) return

         ! compute new odd blocks in terms of even block factors
         ! compute odd hatted matrix elements except at boundaries;
         ! each odd row is split into 3 independent tasks (new l, new u, new d)
         nmin = 1
!$omp parallel do schedule(static,3) &
!$omp private(i,ns,shift2,dmat,umat,lmat,lnext,unext,lprev,uprev,kcount,shift,umat0,lmat0,k)
         do i= 1, 3*(1+(nblk-nmin)/2)

            ns = 2*((i-1)/3) + nmin
            k = ncycle*(ns-1) + 1
            shift2 = nvar2*(k-1)
            dmat(1:nvar,1:nvar) => dblk1(shift2+1:shift2+nvar2)
            umat(1:nvar,1:nvar) => ublk1(shift2+1:shift2+nvar2)
            lmat(1:nvar,1:nvar) => lblk1(shift2+1:shift2+nvar2)

            if (ns < nblk) then
               shift2 = nvar2*ncycle*ns
               lnext(1:nvar,1:nvar) => lblk1(shift2+1:shift2+nvar2)
               unext(1:nvar,1:nvar) => ublk1(shift2+1:shift2+nvar2)
            end if

            if (ns > 1) then
               shift2 = nvar2*ncycle*(ns-2)
               lprev(1:nvar,1:nvar) => lblk1(shift2+1:shift2+nvar2)
               uprev(1:nvar,1:nvar) => ublk1(shift2+1:shift2+nvar2)
            end if

            kcount = 1+(ns-nmin)/2
            shift = nvar2*(kcount-1)
            lmat0(1:nvar,1:nvar) => s% bcyclic_odd_storage(nlevel)% lmat1(shift+1:shift+nvar2)
            umat0(1:nvar,1:nvar) => s% bcyclic_odd_storage(nlevel)% umat1(shift+1:shift+nvar2)

            select case(mod(i-1,3))
            case (0)
               if (ns > 1) then
                  ! lmat = matmul(lmat0, lprev)
                  call my_gemm0_p1(nvar,nvar,nvar,lmat0,nvar,lprev,nvar,lmat,nvar)
               end if
            case (1)
               if (ns < nblk) then
                  ! umat = matmul(umat0, unext)
                  call my_gemm0_p1(nvar,nvar,nvar,umat0,nvar,unext,nvar,umat,nvar)
               end if
            case (2)
               if (ns < nblk) then
                  if (ns > 1) then
                     ! dmat = dmat + matmul(umat0, lnext) + matmul(lmat0,uprev)
                     call my_gemm_plus_mm(nvar,nvar,nvar,umat0,lnext,lmat0,uprev,dmat)
                  else
                     ! dmat = dmat + matmul(umat0, lnext)
                     call my_gemm_p1(nvar,nvar,nvar,umat0,nvar,lnext,nvar,dmat,nvar)
                  end if
               else if (ns > 1) then
                  ! dmat = dmat + matmul(lmat0,uprev)
                  call my_gemm_p1(nvar,nvar,nvar,lmat0,nvar,uprev,nvar,dmat,nvar)
               end if
            end select

         end do
!$omp end parallel do
         if (dbg) write(*,*) 'done cycle_onestep'

      end subroutine cycle_onestep


      ! Forward sweep of the right-hand side for one cyclic-reduction level
      ! (the rhs counterpart of cycle_onestep):
      !   even rows: brhs <- d^{-1} brhs, using the even-block factors;
      !   odd rows:  brhs <- brhs - u_odd*brhs(next even) - l_odd*brhs(prev even),
      ! where u_odd, l_odd are the copies saved in s% bcyclic_odd_storage(nlevel).
      ! ierr /= 0 if the odd storage is too small or any even-block solve fails.
      subroutine cycle_rhs( &
            s, nblk, nvar, ncycle, nlevel, sparse, dblk1, brhs1, ipivot1, ierr)
         type (star_info), pointer :: s
         integer, intent(in) :: nblk, nvar, ncycle, nlevel
         logical, intent(in) :: sparse
         real(dp), pointer, intent(in) :: dblk1(:)
         real(dp), pointer, intent(inout) :: brhs1(:)
         integer, pointer, intent(in) :: ipivot1(:)
         integer, intent(out) :: ierr

         integer :: k, ns, ierr_loc, nmin, kcount, shift, shift1, shift2, nvar2
         integer, pointer :: ipivot(:)
         real(dp), pointer, dimension(:,:) :: dmat, umat, lmat, bptr2
         real(dp), pointer, dimension(:) :: bprev, bnext, bptr

         include 'formats'

         ierr = 0
         nvar2 = nvar*nvar

         ! The odd-row sweep reads one nvar2 slot of odd storage per odd row.
         ! Check capacity up front; kcount is thread-private inside the
         ! parallel loops, so it cannot be used for this check afterwards.
         kcount = 1 + (nblk-1)/2
         if (nvar2*kcount > s% bcyclic_odd_storage(nlevel)% ul_size) then
            write(*,*) 'nvar2*kcount > ul_size in cycle_rhs'
            ierr = -1
            return
         end if

         ! compute dblk[-1]*brhs for even indices and store in brhs(even)
         nmin = 2
!$omp parallel do schedule(static,3) &
!$omp private(ns,shift1,ipivot,shift2,k,dmat,bptr2,bptr,ierr_loc)
         do ns = nmin, nblk, 2
            k = ncycle*(ns-1) + 1
            shift1 = nvar*(k-1)
            ierr_loc = 0  ! private copy is undefined on entry to the region
            if (sparse) then
               bptr(1:nvar) => brhs1(shift1+1:shift1+nvar)
               call sparse_solve(s,k,1,nvar,bptr,ierr_loc)
            else
               shift2 = nvar*shift1
               ipivot(1:nvar) => ipivot1(shift1+1:shift1+nvar)
               dmat(1:nvar,1:nvar) => dblk1(shift2+1:shift2+nvar2)
               bptr2(1:nvar,1:1) => brhs1(shift1+1:shift1+nvar)
               call my_getrs(nvar, 1, dmat, nvar, ipivot, bptr2, nvar, ierr_loc)
            end if
            if (ierr_loc /= 0) ierr = ierr_loc
         end do
!$omp end parallel do
         if (ierr /= 0) return

         ! compute odd (hatted) sources (b-hats) for interior rows
         nmin = 1
!$omp parallel do schedule(static,3) &
!$omp private(ns,shift1,bptr,kcount,shift,umat,lmat,bnext,bprev)
         do ns = nmin, nblk, 2
            shift1 = nvar*ncycle*(ns-1)
            bptr(1:nvar) => brhs1(shift1+1:shift1+nvar)
            kcount = 1+(ns-nmin)/2
            shift = nvar2*(kcount-1)
            umat(1:nvar,1:nvar) => s% bcyclic_odd_storage(nlevel)% umat1(shift+1:shift+nvar2)
            lmat(1:nvar,1:nvar) => s% bcyclic_odd_storage(nlevel)% lmat1(shift+1:shift+nvar2)
            if (ns > 1) then
               shift1 = nvar*ncycle*(ns-2)
               bprev => brhs1(shift1+1:shift1+nvar)
            end if
            if (ns < nblk) then
               shift1 = nvar*ncycle*ns
               bnext => brhs1(shift1+1:shift1+nvar)
               if (ns > 1) then
                  ! bptr = bptr - matmul(umat,bnext) - matmul(lmat,bprev)
                  call my_gemv_mv(nvar,nvar,umat,bnext,lmat,bprev,bptr)
               else
                  ! bptr = bptr - matmul(umat,bnext)
                  call my_gemv(nvar,nvar,umat,nvar,bnext,bptr)
               end if
            else if (ns > 1) then
               ! bptr = bptr - matmul(lmat,bprev)
               call my_gemv(nvar,nvar,lmat,nvar,bprev,bptr)
            end if
         end do
!$omp end parallel do

      end subroutine cycle_rhs


      ! Back-substitution for one cyclic-reduction level: recover the even-row
      ! solutions from the already-known odd-row solutions at this level.
      ! On entry the even entries of brhs1 hold d^{-1} b (written by cycle_rhs)
      ! and lblk1/ublk1 hold -d^{-1} l and -d^{-1} u for the even rows (written
      ! by cycle_onestep), so x_even = brhs + u*x_next + l*x_prev.
      ! The odd entries already hold solutions, computed at the higher level.
      subroutine cycle_solve( &
            s, nvar, nz, ncycle, nblk, nlevel, sparse, lblk1, ublk1, brhs1)
         type (star_info), pointer :: s
         integer, intent(in) :: nvar, nz, ncycle, nblk, nlevel
         logical, intent(in) :: sparse
         real(dp), pointer, intent(in) :: lblk1(:), ublk1(:)
         real(dp), pointer, intent(inout) :: brhs1(:)

         real(dp), pointer :: lower(:,:), upper(:,:)
         real(dp), pointer :: x_here(:), x_before(:), x_after(:)
         integer :: row, off_vec, off_mat, blk_stride, nsq

         nsq = nvar*nvar
         blk_stride = ncycle*nvar   ! distance in brhs1 between neighbouring active rows

!$omp parallel do schedule(static,3) &
!$omp private(row,off_vec,off_mat,x_here,x_before,x_after,lower,upper)
         do row = 2, nblk, 2
            off_vec = blk_stride*(row-1)
            off_mat = nvar*off_vec
            x_here(1:nvar) => brhs1(off_vec+1:off_vec+nvar)
            lower(1:nvar,1:nvar) => lblk1(off_mat+1:off_mat+nsq)
            ! row starts at 2, so an odd predecessor always exists
            x_before(1:nvar) => brhs1(off_vec-blk_stride+1:off_vec-blk_stride+nvar)
            if (row == nblk) then
               ! last active row: only the lower coupling contributes
               call my_gemv_p1(nvar,nvar,lower,nvar,x_before,x_here)
            else
               upper(1:nvar,1:nvar) => ublk1(off_mat+1:off_mat+nsq)
               x_after(1:nvar) => brhs1(off_vec+blk_stride+1:off_vec+blk_stride+nvar)
               call my_gemv_p_mv(nvar,nvar,upper,x_after,lower,x_before,x_here)
            end if
         end do
!$omp end parallel do

      end subroutine cycle_solve
      

      ! Factor diagonal block k (nvar x nvar, dense in mtx) with the sparse
      ! solver from mod_star_sparse; the factors live in slot k of
      ! s% bcyclic_klu_storage.  When iter > 1 and the slot already has a
      ! numeric factorization, a refactor is tried first and kept only if
      ! star_sparse_rgrowth reports at least min_rgrowth; otherwise a full
      ! factorization is done.  Setup/store failures stop the run; a failed
      ! full factorization returns with ierr /= 0.
      subroutine sparse_factor(s, k, nvar, iter, mtx, ierr)
         use mod_star_sparse

         type (star_info), pointer :: s
         integer, intent(in) :: k, nvar, iter
         real(dp), pointer, intent(inout) :: mtx(:,:)
         integer, intent(out) :: ierr

         logical, parameter :: use_pivoting = .true.
         logical, parameter :: report_condest = .false. ! debug: print condition estimate
         real(dp), parameter :: min_rgrowth = 1d-4 ! refactor rejected below this value
         logical :: did_refactor
         real(dp) :: rgrowth, condest
         type(sparse_info), pointer :: ks(:)
         include 'formats'

         ks => s% bcyclic_klu_storage

         call star_sparse_setup_shared(s, k, 1, nvar, ierr)
         if (ierr /= 0) stop 'sparse_factor'

         if (.not. use_pivoting) then
            call star_sparse_no_pivot(s, k, nvar, ierr)
            if (ierr /= 0) stop 'sparse_factor'
         end if

         call star_sparse_store_new_values(s, k, nvar, mtx, .false., ierr)
         if (ierr /= 0) then
            write(*,3) 'sparse_store_new_values failed', k, s% model_number
            stop 'sparse_factor'
         end if

         did_refactor = .false.
         if (iter > 1 .and. ks(k)% have_Numeric) then ! try refactor
            rgrowth = 0
            call star_sparse_refactor(s, k, nvar, mtx, ierr)
            if (ierr == 0) rgrowth = star_sparse_rgrowth(s, k, nvar, ierr)
            if (ierr /= 0 .or. rgrowth < min_rgrowth) then
               ierr = 0 ! fall back to a full factorization below
            else
               did_refactor = .true.
            end if
         end if

         if (.not. did_refactor) then
            call star_sparse_factor(s, k, nvar, mtx, ierr)
            if (ierr /= 0) then
               if (dbg) then
                  write(*,3) 'sparse_factor failed', k, s% model_number
               end if
               return
            end if

            if (report_condest) then
               condest = star_sparse_condest(s, k, nvar, ierr)
               write(*,3) 'sparse_factor condest', &
                  k, s% model_number, condest
            end if

         end if

      end subroutine sparse_factor


      subroutine sparse_solve(s, k, nrhs, nvar, b, ierr)
         ! Forwarding wrapper: solve in place with the sparse factors of block
         ! slot k (as set up by sparse_factor) for nrhs right-hand sides in b.
         use mod_star_sparse, only: star_sparse_solve
         type (star_info), pointer :: s
         integer, intent(in) :: k     ! block-row slot in s% bcyclic_klu_storage
         integer, intent(in) :: nrhs  ! number of right-hand sides packed in b
         integer, intent(in) :: nvar  ! block dimension
         real(dp), pointer :: b(:)    ! rhs on entry, overwritten with the solution
         integer, intent(out) :: ierr

         call star_sparse_solve(s, k, nrhs, nvar, b, ierr)

      end subroutine sparse_solve
      

      ! Legacy dense-to-KLU path: convert block dmat (mblk x mblk) to 0-based
      ! compressed-column storage in slot k of s% bcyclic_klu_storage, keeping
      ! every diagonal entry even if zero, then factor it (job code 0).
      ! Allocation failure stops the run; conversion or factor failures return
      ! with ierr /= 0.
      subroutine old_sparse_factor(s, k, mblk, dmat, ierr)

         type (star_info), pointer :: s
         integer, intent(in) :: k, mblk
         real(dp), pointer, intent(inout) :: dmat(:,:)
         integer, intent(out) :: ierr

         type(sparse_info), pointer :: ks(:)
         real(dp), target :: b_ary(1)
         real(dp), pointer :: values(:), b(:), rpar(:)
         integer, pointer :: ipar(:)
         integer :: nonzero_cnt, j, i, sprs_nonzeros

         include 'formats'

         ierr = 0
         ks => s% bcyclic_klu_storage

         ! count retained entries: every diagonal plus each nonzero off-diagonal
         nonzero_cnt = 0
         do j=1,mblk
            do i=1,mblk
               if (i == j) then ! don't clip diagonals
                  nonzero_cnt = nonzero_cnt + 1
               else if (dmat(i,j) /= 0) then
                  nonzero_cnt = nonzero_cnt + 1
               end if
            end do
         end do

         ! grow ia/ja/values as needed; the optional debug fill is done only
         ! after the allocation status has been checked
         if (associated(ks(k)% ia)) then
            if (size(ks(k)% ia,dim=1) < mblk+1) then
               deallocate(ks(k)% ia)
               allocate(ks(k)% ia(mblk + 10), stat=ierr)
            end if
         else
            allocate(ks(k)% ia(mblk + 10), stat=ierr)
         end if
         if (ierr /= 0) then
            write(*,2) 'allocate failed for ia', mblk+1
            stop
         end if
         if (do_fill_with_NaNs) ks(k)% ia = -9999999

         if (associated(ks(k)% ja)) then
            if (size(ks(k)% ja,dim=1) < nonzero_cnt) then
               deallocate(ks(k)% ja)
               allocate(ks(k)% ja(nonzero_cnt + 100), stat=ierr)
            end if
         else
            allocate(ks(k)% ja(nonzero_cnt + 100), stat=ierr)
         end if
         if (ierr /= 0) then
            write(*,2) 'allocate failed for ja', nonzero_cnt
            stop
         end if
         if (do_fill_with_NaNs) ks(k)% ja = -9999999

         if (associated(ks(k)% values)) then
            if (size(ks(k)% values,dim=1) < nonzero_cnt) then
               deallocate(ks(k)% values)
               allocate(ks(k)% values(nonzero_cnt + 100), stat=ierr)
            end if
         else
            allocate(ks(k)% values(nonzero_cnt + 100), stat=ierr)
         end if
         if (ierr /= 0) then
            write(*,2) 'allocate failed for values', nonzero_cnt
            stop
         end if
         if (do_fill_with_NaNs) call fill_with_NaNs(ks(k)% values)

         values => ks(k)% values

         ! compressed_format is compressed_col_sparse_0_based for KLU
         call dense_to_col_with_diag_0_based( &
            mblk, mblk, dmat, nonzero_cnt, sprs_nonzeros, &
            ks(k)% ia, ks(k)% ja, values, ierr)
         if (ierr /= 0) then
            write(*,*) 'bcyclic failed in converting from dense to sparse'
            return
         end if
         if (sprs_nonzeros /= nonzero_cnt) then
            write(*,*) &
               'bcyclic failed in converting from dense to sparse: bad sprs_nonzeros'
            ierr = -1
            return
         end if
         ks(k)% sprs_nonzeros = sprs_nonzeros

         ipar => ks(k)% ipar8_decsol
         rpar => ks(k)% rpar_decsol
         b(1:1) => b_ary(1:1)   ! dummy rhs; not used by the factor job
         if (do_fill_with_NaNs) then
            ipar(:) = -9999999
            call fill_with_NaNs(rpar)
         end if

         call klu_dble_decsols_nrhs_0_based( & ! factor
            0, mblk, mblk, sprs_nonzeros, &
            ks(k)% ia, ks(k)% ja, ks(k)% values, b, &
            lrd, rpar, lid, ipar, ierr)

         if (ierr /= 0) then
            write(*,2) 'klu_dble_decsols_nrhs_0_based failed factor', ierr
            write(*,2) 'k', k
            write(*,2) 'sprs_nonzeros', sprs_nonzeros
            write(*,*)
         end if

      end subroutine old_sparse_factor


      subroutine old_sparse_solve(s, k, nrhs, mblk, b, ierr)
         ! Legacy path: solve with the KLU factors held in slot k of
         ! s% bcyclic_klu_storage (job code 1); b carries the nrhs right-hand sides.
         type (star_info), pointer :: s
         integer, intent(in) :: k, nrhs, mblk
         real(dp), pointer :: b(:)
         integer, intent(out) :: ierr

         type(sparse_info), pointer :: klu(:)
         real(dp), pointer :: rwork(:)
         integer, pointer :: iwork(:)

         include 'formats'

         klu => s% bcyclic_klu_storage
         rwork => klu(k)% rpar_decsol
         iwork => klu(k)% ipar8_decsol
         call klu_dble_decsols_nrhs_0_based( & ! solve
            1, nrhs, mblk, klu(k)% sprs_nonzeros, &
            klu(k)% ia, klu(k)% ja, klu(k)% values, b, &
            lrd, rwork, lid, iwork, ierr)
      end subroutine old_sparse_solve


      ! Sparse runs only: for every slot of s% bcyclic_klu_storage that has been
      ! factored (sprs_nonzeros > 0), call the KLU wrapper's free job (code 2)
      ! and mark the slot unfactored.  Returns immediately for dense runs or
      ! when no sparse storage exists.  ierr reports the first failing free;
      ! the remaining slots are still processed.
      subroutine bcyclic_deallocate ( &
            s, lblk1, dblk1, ublk1, ipivot1, brhs1, nvar, nz, sparse, &
            lrd, rpar_decsol, lid, ipar_decsol, ierr)
         type (star_info), pointer :: s
         real(dp), pointer :: lblk1(:) ! row section of lower block
         real(dp), pointer :: dblk1(:) ! row section of diagonal block
         real(dp), pointer :: ublk1(:) ! row section of upper block
         integer, pointer :: ipivot1(:) ! row section of pivot array for block factorization
         real(dp), pointer :: brhs1(:) ! row section of rhs
         integer, intent(in) :: nvar ! linear size of each block
         integer, intent(in) :: nz ! number of block rows
         logical, intent(in) :: sparse
         integer, intent(in) :: lrd, lid
         real(dp), pointer, intent(inout) :: rpar_decsol(:) ! (lrd)
         integer, pointer, intent(inout) :: ipar_decsol(:) ! (lid)
         integer, intent(out) :: ierr

         integer :: k, ierr_free
         real(dp), pointer :: rpar(:), b(:)
         type(sparse_info), pointer :: ks(:)
         integer, pointer :: ipar(:)
         real(dp), target :: b_ary(1)

         ierr = 0
         if (.not. sparse) return
         ! size() of a disassociated pointer is invalid (e.g. after clear_klu_storage)
         if (.not. associated(s% bcyclic_klu_storage)) return

         ks => s% bcyclic_klu_storage
         b(1:1) => b_ary(1:1)   ! dummy rhs; not used by the free job
         do k = 1, size(ks)
            if (ks(k)% sprs_nonzeros <= 0) cycle
            ipar => ks(k)% ipar8_decsol
            rpar => ks(k)% rpar_decsol
            ierr_free = 0
            call klu_dble_decsols_nrhs_0_based( & ! free
               2, 0, nvar, ks(k)% sprs_nonzeros, &
               ks(k)% ia, ks(k)% ja, ks(k)% values, b, &
               lrd, rpar, lid, ipar, ierr_free)
            ks(k)% sprs_nonzeros = -1
            if (ierr == 0) ierr = ierr_free   ! remember the first failure
         end do

      end subroutine bcyclic_deallocate

      
      subroutine my_getf2(m, a, lda, ipiv, sfmin, info)
         ! Unblocked LU factorization with partial pivoting of the m-by-m
         ! matrix held in a(1:m,1:m), following LAPACK dgetf2 for the square
         ! case. On return, the strict lower triangle holds L (its unit
         ! diagonal is not stored) and the upper triangle holds U. ipiv(j) is
         ! the row that was interchanged with row j.
         ! info = 0 on success, otherwise the index of the FIRST column whose
         ! pivot is exactly zero. As in LAPACK, the factorization is still
         ! completed in that case.
         ! sfmin: smallest |pivot| for which scaling by 1/pivot is used. Below
         !        it, the column is divided by the pivot instead.
         ! lda is kept for a LAPACK-like call; it must be >= m.
         integer, intent(in) :: m
         real(dp), intent(inout) :: a(:,:)
         integer, intent(in) :: lda
         integer, intent(inout) :: ipiv(:)
         real(dp), intent(in) :: sfmin
         integer, intent(out) :: info
         real(dp), parameter :: one=1, zero=0
         integer :: i, j, jp, ii, jj
         real(dp) :: tmp, da

         ! set once: a zero pivot in any column must survive the later columns
         info = 0
         do j = 1, m
            ! pivot search restricted to rows j..m of the matrix
            jp = j - 1 + maxloc(abs(a(j:m, j)), dim=1)
            ipiv(j) = jp
            if (a(jp, j) /= zero) then
               if (jp /= j) then ! swap rows j and jp across all m columns
                  do i = 1, m
                     tmp = a(j, i)
                     a(j, i) = a(jp, i)
                     a(jp, i) = tmp
                  end do
               end if
               if (j < m) then ! form column j of L
                  if (abs(a(j, j)) >= sfmin) then
                     da = one / a(j, j)
                     a(j+1:m, j) = da*a(j+1:m, j)
                  else ! 1/pivot could overflow: divide instead
                     a(j+1:m, j) = a(j+1:m, j) / a(j, j)
                  end if
               end if
            else if (info == 0) then
               info = j ! remember only the first singular column
            end if
            if (j < m) then
               ! rank-1 update of the trailing submatrix (dger equivalent)
               do jj = j+1, m
                  do ii = j+1, m
                     a(ii, jj) = a(ii, jj) - a(ii, j)*a(j, jj)
                  end do
               end do
            end if
         end do
      end subroutine my_getf2
            
      
      subroutine my_getrs( n, nrhs, a, lda, ipiv, b, ldb, info )
         ! Solve a*x = b for nrhs right-hand sides and overwrite b with x.
         ! a holds LU factors in the layout my_getf2 produces: unit-lower L
         ! below the diagonal, U on and above it, and row interchanges in
         ! ipiv. This is dgetrs with trans = 'n'.
         ! The sequence is: row interchanges on b, then a forward sweep with
         ! unit-lower L, then a backward sweep with U. info is always 0.
         ! Entries of b that are exactly zero are skipped in both sweeps.
         integer :: info, lda, ldb, n, nrhs
         integer, pointer :: ipiv(:)
         real(dp), pointer :: a(:,:), b(:,:) ! a( lda, * ), b( ldb, * )
         real(dp), parameter :: zero=0
         integer :: col, piv_row, r

         info = 0
         call my_laswp(nrhs, b, ldb, 1, n, ipiv, 1 )

         ! forward sweep: L has an implicit unit diagonal
         do col = 1, nrhs
            do piv_row = 1, n
               if (b(piv_row,col) == zero) cycle
               do r = piv_row + 1, n
                  b(r,col) = b(r,col) - b(piv_row,col)*a(r,piv_row)
               end do
            end do
         end do

         ! backward sweep: divide by the diagonal of U, then eliminate upward
         do col = 1, nrhs
            do piv_row = n, 1, -1
               if (b(piv_row,col) == zero) cycle
               b(piv_row,col) = b(piv_row,col)/a(piv_row,piv_row)
               do r = 1, piv_row - 1
                  b(r,col) = b(r,col) - b(piv_row,col)*a(r,piv_row)
               end do
            end do
         end do

      end subroutine my_getrs
      
      
      subroutine my_laswp( n, a, lda, k1, k2, ipiv, incx )
         ! Apply row interchanges to columns 1..n of a (LAPACK dlaswp).
         ! For each row i in k1..k2, row i is swapped with row
         ! ipiv(k1 + (i-k1)*abs(incx)).
         ! incx > 0: rows are processed in increasing order.
         ! incx < 0: rows are processed in decreasing order.
         ! incx = 0: nothing is done.
         ! Columns are handled in panels of 32, as in the reference routine.
         ! lda is kept for a LAPACK-like call and is not referenced.
         integer, intent(in) :: n, lda, k1, k2, incx
         real(dp), intent(inout) :: a(:,:) ! a( lda, * )
         integer, intent(in) :: ipiv(:)
         integer, parameter :: panel = 32
         integer :: i, i1, i2, inc, ip, ix, ix0, jb, jend, k
         real(dp) :: temp

         if (incx > 0) then
            ix0 = k1
            i1 = k1
            i2 = k2
            inc = 1
         else if (incx < 0) then
            ! Start at the pivot entry for row k2 (current reference LAPACK).
            ! The old form 1 + (1-k2)*incx is wrong when k1 /= 1 and |incx| > 1.
            ix0 = k1 + (k1 - k2)*incx
            i1 = k2
            i2 = k1
            inc = -1
         else
            return
         end if

         do jb = 1, n, panel
            jend = min(jb + panel - 1, n) ! last panel may be narrower
            ix = ix0
            do i = i1, i2, inc
               ip = ipiv(ix)
               if (ip /= i) then
                  do k = jb, jend
                     temp = a(i, k)
                     a(i, k) = a(ip, k)
                     a(ip, k) = temp
                  end do
               end if
               ix = ix + incx
            end do
         end do
      end subroutine my_laswp
      
      
      subroutine my_gemm0_p1(m,n,k,a,lda,b,ldb,c,ldc)
         ! c(1:m,1:n) := a(1:m,1:k)*b(1:k,1:n)
         ! This is dgemm with transa = transb = 'n', alpha = +1, beta = 0.
         ! It clears c and then accumulates with my_gemm_p1 (c := c + a*b).
         ! The previous header claimed c := -a*b; the sign was never negated.
         ! lda, ldb and ldc are passed through unchanged; other args are
         ! assumed valid.
         integer, intent(in) :: k,lda,ldb,ldc,m,n
         real(dp), dimension(:,:) :: a, b, c ! a(lda,*),b(ldb,*),c(ldc,*)
         real(dp), parameter :: zero=0
         c(1:m,1:n) = zero
         call my_gemm_p1(m,n,k,a,lda,b,ldb,c,ldc)
      end subroutine my_gemm0_p1
      
      
      subroutine my_gemm_plus_mm(m,n,k,a,b,d,e,c) ! c := c + a*b + d*e
         ! Accumulate two products into c:
         !   c(1:m,1:n) := c + a(1:m,1:k)*b(1:k,1:n) + d(1:m,1:k)*e(1:k,1:n)
         ! For each (l, j), a column of a (or d) is only touched when its
         ! multiplier b(l,j) (or e(l,j)) is exactly nonzero.
         integer, intent(in) :: k,m,n
         real(dp), dimension(:,:) :: a, b, c, d, e
         real(dp), parameter :: zero=0
         real(dp) :: blj, elj
         integer :: col, inner

         do col = 1, n
            do inner = 1, k
               blj = b(inner,col)
               elj = e(inner,col)
               if (blj /= zero .and. elj /= zero) then
                  c(1:m,col) = c(1:m,col) + blj*a(1:m,inner) + elj*d(1:m,inner)
               else if (blj /= zero) then
                  c(1:m,col) = c(1:m,col) + blj*a(1:m,inner)
               else if (elj /= zero) then
                  c(1:m,col) = c(1:m,col) + elj*d(1:m,inner)
               end if
            end do
         end do
      end subroutine my_gemm_plus_mm
      
      
      subroutine my_gemm_p1(m,n,k,a,lda,b,ldb,c,ldc) ! c := c + a*b
         ! Accumulate a(1:m,1:k)*b(1:k,1:n) into c(1:m,1:n).
         ! This is dgemm with transa = transb = 'n', alpha = 1, beta = 1.
         ! A column of a is skipped when its multiplier b(l,j) is exactly zero.
         ! lda, ldb and ldc are accepted for a BLAS-like call but unused.
         ! Other args are assumed valid.
         integer, intent(in) :: k,lda,ldb,ldc,m,n
         real(dp), dimension(:,:) :: a, b, c ! a(lda,*),b(ldb,*),c(ldc,*)
         real(dp), parameter :: zero=0
         real(dp) :: blj
         integer :: col, inner

         do col = 1, n
            do inner = 1, k
               blj = b(inner,col)
               if (blj == zero) cycle
               c(1:m,col) = c(1:m,col) + blj*a(1:m,inner)
            end do
         end do
      end subroutine my_gemm_p1


      subroutine my_gemv_mv(m,n,a,x,b,z,y) ! y = y - a*x - b*z
         ! Subtract two matrix-vector products from y:
         !   y(1:m) := y(1:m) - a(1:m,1:n)*x(1:n) - b(1:m,1:n)*z(1:n)
         ! A column is skipped when its coefficient x(j) or z(j) is exactly zero.
         integer :: m, n
         real(dp) :: a(:,:), b(:,:)
         real(dp) :: x(:), z(:), y(:)
         real(dp), parameter :: zero=0
         real(dp) :: xj, zj
         integer :: col

         do col = 1, n
            xj = x(col)
            zj = z(col)
            if (xj /= zero .and. zj /= zero) then
               y(1:m) = y(1:m) - xj*a(1:m,col) - zj*b(1:m,col)
            else if (xj /= zero) then
               y(1:m) = y(1:m) - xj*a(1:m,col)
            else if (zj /= zero) then
               y(1:m) = y(1:m) - zj*b(1:m,col)
            end if
         end do
      end subroutine my_gemv_mv


      subroutine my_gemv(m,n,a,lda,x,y) ! y = y - a*x
         ! y(1:m) := y(1:m) - a(1:m,1:n)*x(1:n)
         ! This is dgemv with trans = 'n', alpha = -1, beta = 1 and unit strides.
         ! A column of a is skipped when x(j) is exactly zero.
         ! lda is accepted for a BLAS-like call but unused.
         integer :: m, n, lda
         real(dp) :: a(:,:) ! (lda,*)
         real(dp) :: x(:), y(:)
         real(dp), parameter :: zero=0
         real(dp) :: xj
         integer :: col

         do col = 1, n
            xj = x(col)
            if (xj == zero) cycle
            y(1:m) = y(1:m) - xj*a(1:m,col)
         end do
      end subroutine my_gemv


      subroutine my_gemv_p_mv(m,n,a,x,b,z,y) ! y = y + a*x + b*z
         ! Add two matrix-vector products to y:
         !   y(1:m) := y(1:m) + a(1:m,1:n)*x(1:n) + b(1:m,1:n)*z(1:n)
         ! A column is skipped when its coefficient x(j) or z(j) is exactly zero.
         integer :: m, n
         real(dp) :: a(:,:), b(:,:)
         real(dp) :: x(:), z(:), y(:)
         real(dp), parameter :: zero=0
         real(dp) :: xj, zj
         integer :: col

         do col = 1, n
            xj = x(col)
            zj = z(col)
            if (xj /= zero .and. zj /= zero) then
               y(1:m) = y(1:m) + xj*a(1:m,col) + zj*b(1:m,col)
            else if (xj /= zero) then
               y(1:m) = y(1:m) + xj*a(1:m,col)
            else if (zj /= zero) then
               y(1:m) = y(1:m) + zj*b(1:m,col)
            end if
         end do
      end subroutine my_gemv_p_mv


      subroutine my_gemv_p1(m,n,a,lda,x,y) ! y = y + a*x
         ! y(1:m) := y(1:m) + a(1:m,1:n)*x(1:n)
         ! This is dgemv with trans = 'n', alpha = +1, beta = 1 and unit strides.
         ! The old header said alpha = -1, but the code has always added.
         ! A column of a is skipped when x(j) is exactly zero.
         ! lda is accepted for a BLAS-like call but unused.
         integer, intent(in) :: lda, m, n
         real(dp), intent(in) :: a(:,:) ! (lda,*)
         real(dp), intent(in) :: x(:)
         real(dp), intent(inout) :: y(:)
         real(dp), parameter :: zero=0
         integer :: j

         do j = 1, n
            if (x(j) /= zero) y(1:m) = y(1:m) + x(j)*a(1:m,j)
         end do
      end subroutine my_gemv_p1
      
      
      end module mod_star_bcyclic
