! ***********************************************************************
!
!   copyright (c) 2012  bill paxton
!
!   mesa is free software; you can use it and/or modify
!   it under the combined terms and restrictions of the mesa manifesto
!   and the gnu general library public license as published
!   by the free software foundation; either version 2 of the license,
!   or (at your option) any later version.
!
!   you should have received a copy of the mesa manifesto along with
!   this software; if not, it is available at the mesa website:
!   http://mesa.sourceforge.net/
!
!   mesa is distributed in the hope that it will be useful,
!   but without any warranty; without even the implied warranty of
!   merchantability or fitness for a particular purpose.
!   see the gnu library general public license for more details.
!
!   you should have received a copy of the gnu library general public license
!   along with this software; if not, write to the free software
!   foundation, inc., 59 temple place, suite 330, boston, ma 02111-1307 usa
!
! ***********************************************************************


      module solve_split_mtx  

      use star_private_def
      use alert_lib
      use const_def
      !use mtx_lib, only: &
      !   d_and_c_alloc, d_and_c_factor, d_and_c_solve, d_and_c_dealloc, &
      !   d_and_c_dble_work_sizes
      use mtx_lib, only: block_dc_mt_dble_work_sizes, block_dc_mt_dble_decsolblk       
         
      
      implicit none
      
      ! when .true., split_mtx_solve prints its problem sizes, tolerances,
      ! and a convergence summary for each of the xh / xa iterations
      logical, parameter :: dbg = .false.
      
      ! set to .true. by split_mtx_init once the handle pool is prepared;
      ! read without synchronization as a fast path before the critical section
      logical :: have_initialized = .false.
      
      ! Per-handle state of the split solver.  Each block of the full
      ! Jacobian has nvarh hydro rows/columns followed by species abundance
      ! rows/columns (see split in split_mtx_factor); it is stored as four
      ! block-tridiagonal coupling matrices (_l lower, _d diagonal, _u upper).
      ! Storage is grown with enlarge_if_needed_* in split_mtx_factor and
      ! reused between calls, so the last dimension presumably may exceed the
      ! nz currently in use -- only the first nz blocks/columns are meaningful.
      type split_mtx_info
         integer :: nvarh, species, nz ! sizes recorded by the latest split_mtx_factor
         real(dp), dimension(:,:,:), pointer :: &
            Jhh_l, Jhh_d, Jhh_u, & ! (nvarh,nvarh,nz)  d(equ_xh)/d(xh)
            Jha_l, Jha_d, Jha_u, & ! (nvarh,species,nz) d(equ_xh)/d(xa)
            Jah_l, Jah_d, Jah_u, & ! (species,nvarh,nz) d(equ_xa)/d(xh)
            Jaa_l, Jaa_d, Jaa_u    ! (species,species,nz) d(equ_xa)/d(xa)
         ! copies of the diagonal-coupling blocks Jhh and Jaa that are handed to
         ! block_dc_mt_dble_decsolblk for factoring (iop=0) and later solves (iop=1)
         real(dp), dimension(:,:,:), pointer :: & ! factored jacobians
            Jhh_l_fac, Jhh_d_fac, Jhh_u_fac, & ! (nvarh,nvarh,nz)  d(equ_xh)/d(xh)
            Jaa_l_fac, Jaa_d_fac, Jaa_u_fac    ! (species,species,nz) d(equ_xa)/d(xa)
         ! x* = current solution part, dx* = latest series correction,
         ! r* = right-hand-side part, w* = scratch
         real(dp), dimension(:,:), pointer :: & ! working vectors of blocks
            xh, dxh, rh, wh, & ! (nvarh,nz)
            xa, dxa, ra, wa    ! (species,nz)
         !integer :: Jhh_id, Jaa_id
         ! workspace sizes from block_dc_mt_dble_work_sizes and the workspaces
         ! themselves, one set for the Jhh solver and one for the Jaa solver
         integer :: Jhh_lrd, Jaa_lrd
         integer :: Jhh_lid, Jaa_lid
         real(dp), dimension(:), pointer :: Jhh_rpar_decsol, Jaa_rpar_decsol
         integer, dimension(:), pointer :: Jhh_ipar_decsol, Jaa_ipar_decsol
         ! bookkeeping
         integer :: handle   ! this slot's own index in handles(:)
         logical :: in_use   ! .true. between split_mtx_alloc and do_free_handle
      end type split_mtx_info
      
      
      ! fixed-size pool of solver states; a "handle" is an index into it
      integer, parameter :: max_handles = 100
      type (split_mtx_info), target :: handles(max_handles)
      
      
      contains         
      
      
      subroutine split_mtx_factor( &
            id, caller_id, nvar, nz, lblk, dblk, ublk, &
            lrd, rpar_decsol, lid, ipar_decsol, &
            ierr)
         ! Split the block-tridiagonal Jacobian held in lblk/dblk/ublk into its
         ! hydro (h) and abundance (a) coupling sub-blocks, store them in the
         ! state of handle `id`, and factor the two diagonal sub-systems Jhh
         ! and Jaa with block_dc_mt_dble_decsolblk (iop=0) so that
         ! split_mtx_solve can later apply Jhh^-1 and Jaa^-1.
         !
         ! id         split_mtx handle (from split_mtx_alloc)
         ! caller_id  star id; supplies nvar_hydro, species, and the maximum nz
         ! nvar       block size of lblk/dblk/ublk; must equal nvar_hydro + species
         ! nz         number of blocks; must not exceed s% nz
         ! lblk, dblk, ublk   (nvar,nvar,nz) lower/diagonal/upper blocks (read only)
         ! lrd, rpar_decsol, lid, ipar_decsol   caller workspace (not referenced)
         ! ierr       0 on success; nonzero if any check, allocation, or
         !            factorization fails
         use utils_lib, only: set_pointer_1, enlarge_integer_if_needed_2, &
            enlarge_if_needed_1, enlarge_if_needed_2, enlarge_if_needed_3, &
            set_int_pointer_1
         integer, intent(in) :: id, caller_id, nvar, nz
         real(dp), pointer, intent(inout), dimension(:,:,:) :: lblk, dblk, ublk ! (mblk,mblk,nblk)
         ! row(i) of mtx has lblk(:,:,i), dblk(:,:,i), ublk(:,:,i)
         ! lblk(:,:,1) is not used; ublk(:,:,nblk) is not used.
         integer, intent(in) :: lrd, lid
         double precision, target, intent(inout) :: rpar_decsol(lrd)
         integer, target, intent(inout) :: ipar_decsol(lid)
         integer, intent(out) :: ierr
         
         type (split_mtx_info), pointer :: p
         integer :: nvarh, species
         type (star_info), pointer :: s
         
         include 'formats.dek'
         
         ierr = 0
         call get_ptr(id,p,ierr)
         if (ierr /= 0) return

         call get_star_ptr(caller_id, s, ierr)
         if (ierr /= 0) return

         nvarh = s% nvar_hydro
         species = s% species
         if (nz > s% nz) then
            ierr = -1
            return
         end if
         ! split() addresses rows/columns 1..nvarh+species of every block, so
         ! the caller's block size must match exactly: a smaller nvar would be
         ! read out of bounds, a larger one would silently drop variables.
         if (nvar /= nvarh + species) then
            ierr = -1
            write(*,*) 'split_mtx_factor: nvar /= nvar_hydro + species', &
               nvar, nvarh, species
            return
         end if
         
         p% nz = nz
         p% nvarh = nvarh
         p% species = species
         
         call alloc(ierr)
         if (ierr /= 0) return
         
         call split(ierr)
         if (ierr /= 0) return
         
         call factor(ierr)
         if (ierr /= 0) return
         
         
         contains
         
         
         subroutine alloc(ierr)
            ! Make sure every per-handle array is large enough for the current
            ! nvarh/species/nz; existing storage is grown, not re-created, when
            ! it already fits.
            use utils_lib, only: enlarge_if_needed_2, enlarge_if_needed_3
            integer, intent(out) :: ierr
            integer :: h, a, n
            integer, parameter :: extra = 200
            include 'formats.dek'
            
            h = nvarh; a = species; n = nz            
            ierr = 0
            ! enlarge if needed rather than reallocate
            call enlarge_if_needed_3(p% Jhh_l,h,h,n,extra,ierr)
            if (ierr /= 0) return
            call enlarge_if_needed_3(p% Jhh_d,h,h,n,extra,ierr)
            if (ierr /= 0) return
            call enlarge_if_needed_3(p% Jhh_u,h,h,n,extra,ierr)
            if (ierr /= 0) return
         
            call enlarge_if_needed_3(p% Jha_l,h,a,n,extra,ierr)
            if (ierr /= 0) return
            call enlarge_if_needed_3(p% Jha_d,h,a,n,extra,ierr)
            if (ierr /= 0) return
            call enlarge_if_needed_3(p% Jha_u,h,a,n,extra,ierr)
            if (ierr /= 0) return
         
            call enlarge_if_needed_3(p% Jah_l,a,h,n,extra,ierr)
            if (ierr /= 0) return
            call enlarge_if_needed_3(p% Jah_d,a,h,n,extra,ierr)
            if (ierr /= 0) return
            call enlarge_if_needed_3(p% Jah_u,a,h,n,extra,ierr)
            if (ierr /= 0) return
         
            call enlarge_if_needed_3(p% Jaa_l,a,a,n,extra,ierr)
            if (ierr /= 0) return
            call enlarge_if_needed_3(p% Jaa_d,a,a,n,extra,ierr)
            if (ierr /= 0) return
            call enlarge_if_needed_3(p% Jaa_u,a,a,n,extra,ierr)
            if (ierr /= 0) return
         
            call enlarge_if_needed_3(p% Jhh_l_fac,h,h,n,extra,ierr)
            if (ierr /= 0) return
            call enlarge_if_needed_3(p% Jhh_d_fac,h,h,n,extra,ierr)
            if (ierr /= 0) return
            call enlarge_if_needed_3(p% Jhh_u_fac,h,h,n,extra,ierr)
            if (ierr /= 0) return
                     
            call enlarge_if_needed_3(p% Jaa_l_fac,a,a,n,extra,ierr)
            if (ierr /= 0) return
            call enlarge_if_needed_3(p% Jaa_d_fac,a,a,n,extra,ierr)
            if (ierr /= 0) return
            call enlarge_if_needed_3(p% Jaa_u_fac,a,a,n,extra,ierr)
            if (ierr /= 0) return
         
            call enlarge_if_needed_2(p% xh,h,n,extra,ierr)
            if (ierr /= 0) return
            call enlarge_if_needed_2(p% dxh,h,n,extra,ierr)
            if (ierr /= 0) return
            call enlarge_if_needed_2(p% rh,h,n,extra,ierr)
            if (ierr /= 0) return
            call enlarge_if_needed_2(p% wh,h,n,extra,ierr)
            if (ierr /= 0) return
         
            call enlarge_if_needed_2(p% xa,a,n,extra,ierr)
            if (ierr /= 0) return
            call enlarge_if_needed_2(p% dxa,a,n,extra,ierr)
            if (ierr /= 0) return
            call enlarge_if_needed_2(p% ra,a,n,extra,ierr)
            if (ierr /= 0) return
            call enlarge_if_needed_2(p% wa,a,n,extra,ierr)
            if (ierr /= 0) return

         end subroutine alloc
         
         
         subroutine split(ierr)
            ! Copy each (nvar,nvar) block into its four sub-blocks:
            ! rows/columns 1..nvarh belong to h, nvarh+1..nvarh+species to a.
            integer, intent(out) :: ierr
            integer :: h, a, i, j, k
            h = nvarh; a = species
            ierr = 0
            do k=1,nz
               do j=1,h
                  do i=1,h
                     p% Jhh_l(i,j,k) = lblk(i,j,k)
                     p% Jhh_d(i,j,k) = dblk(i,j,k)
                     p% Jhh_u(i,j,k) = ublk(i,j,k)
                  end do
                  do i=1,a
                     p% Jah_l(i,j,k) = lblk(h+i,j,k)
                     p% Jah_d(i,j,k) = dblk(h+i,j,k)
                     p% Jah_u(i,j,k) = ublk(h+i,j,k)
                  end do
               end do
               do j=1,a
                  do i=1,h
                     p% Jha_l(i,j,k) = lblk(i,h+j,k)
                     p% Jha_d(i,j,k) = dblk(i,h+j,k)
                     p% Jha_u(i,j,k) = ublk(i,h+j,k)
                  end do
                  do i=1,a
                     p% Jaa_l(i,j,k) = lblk(h+i,h+j,k)
                     p% Jaa_d(i,j,k) = dblk(h+i,h+j,k)
                     p% Jaa_u(i,j,k) = ublk(h+i,h+j,k)
                  end do
               end do
            end do
         end subroutine split
         
         subroutine factor(ierr)
            ! Size the decsol workspaces, copy Jhh and Jaa into the *_fac
            ! arrays (the unfactored blocks are still needed by the solve's
            ! matrix-vector products), and factor both copies (iop=0).
            use utils_lib, only: enlarge_if_needed_1, enlarge_integer_if_needed_1
            integer, intent(out) :: ierr
            
            integer :: i, j, k
            integer, pointer :: ipiv(:,:)
            real(dp), pointer :: brhs(:,:)

            ierr = 0
            brhs => null()
            ipiv => null()

            ! factor Jhh
            call block_dc_mt_dble_work_sizes(nvarh, nz, p% Jhh_lrd, p% Jhh_lid)
            call enlarge_if_needed_1(p% Jhh_rpar_decsol, p% Jhh_lrd, 100, ierr)
            if (ierr /= 0) return
            call enlarge_integer_if_needed_1(p% Jhh_ipar_decsol, p% Jhh_lid, 100, ierr)
            if (ierr /= 0) return
            do k=1,nz
               do j=1,nvarh
                  do i=1,nvarh
                     p% Jhh_l_fac(i,j,k) = p% Jhh_l(i,j,k)
                     p% Jhh_d_fac(i,j,k) = p% Jhh_d(i,j,k)
                     p% Jhh_u_fac(i,j,k) = p% Jhh_u(i,j,k)
                  end do
               end do
            end do
            call block_dc_mt_dble_decsolblk( &
               0, caller_id, nvarh, nz, p% Jhh_l_fac, p% Jhh_d_fac, p% Jhh_u_fac, brhs, ipiv, &
               p% Jhh_lrd, p% Jhh_rpar_decsol, p% Jhh_lid, p% Jhh_ipar_decsol, &
               ierr)
            if (ierr /= 0) then
               write(*,*) 'd_and_c_factor Jhh failed'
               return
            end if
            
            ! factor Jaa
            call block_dc_mt_dble_work_sizes(species, nz, p% Jaa_lrd, p% Jaa_lid)
            call enlarge_if_needed_1(p% Jaa_rpar_decsol, p% Jaa_lrd, 100, ierr)
            if (ierr /= 0) return
            call enlarge_integer_if_needed_1(p% Jaa_ipar_decsol, p% Jaa_lid, 100, ierr)
            if (ierr /= 0) return
            do k=1,nz
               do j=1,species
                  do i=1,species
                     p% Jaa_l_fac(i,j,k) = p% Jaa_l(i,j,k)
                     p% Jaa_d_fac(i,j,k) = p% Jaa_d(i,j,k)
                     p% Jaa_u_fac(i,j,k) = p% Jaa_u(i,j,k)
                  end do
               end do
            end do
            call block_dc_mt_dble_decsolblk( &
               0, caller_id, species, nz, p% Jaa_l_fac, p% Jaa_d_fac, p% Jaa_u_fac, brhs, ipiv, &
               p% Jaa_lrd, p% Jaa_rpar_decsol, p% Jaa_lid, p% Jaa_ipar_decsol, &
               ierr)
            if (ierr /= 0) then
               write(*,*) 'd_and_c_factor Jaa failed'
               return
            end if

         end subroutine factor
         
         
      end subroutine split_mtx_factor
      
      
      subroutine split_mtx_solve( &
            id, caller_id, nvar, nz, xy, &
            lrd, rpar_decsol, lid, ipar_decsol, &
            ierr)
         ! Solve J*x = y with the sub-blocks and factorizations that
         ! split_mtx_factor stored under handle `id`.  Writing J as
         !
         !    [ Jhh  Jha ] [ xh ]   [ rh ]
         !    [ Jah  Jaa ] [ xa ] = [ ra ]
         !
         ! each half of the solution is accumulated as a Neumann series that
         ! needs only solves with the factored diagonal blocks Jhh and Jaa:
         !
         !    xh = sum_{k>=0} (Jhh^-1 Jha Jaa^-1 Jah)^k  Jhh^-1 (rh - Jha Jaa^-1 ra)
         !    xa = sum_{k>=0} (Jaa^-1 Jah Jhh^-1 Jha)^k  Jaa^-1 (ra - Jah Jhh^-1 rh)
         !
         ! Correction terms are added until every component of the latest one
         ! satisfies |dx| < atol + rtol*|x| (s% split_mtx_solve_atol/rtol).  If
         ! that does not happen within s% split_mtx_solve_max_iters terms,
         ! ierr = -1 (the series converges only when the spectral radius of the
         ! iteration matrix is below 1).  With max_iters <= 0 no correction
         ! terms are added and the leading term is returned.
         !
         ! id         split_mtx handle previously passed to split_mtx_factor
         ! caller_id  star id; supplies tolerances and the iteration limit
         ! nvar       block size of xy (not otherwise referenced)
         ! nz         number of blocks
         ! xy         (nvar,nz): right-hand side on entry, solution on exit
         ! lrd, rpar_decsol, lid, ipar_decsol   caller workspace (not referenced)
         ! ierr       0 on success, nonzero on failure
         use utils_lib, only: set_pointer_1, set_pointer_3
         integer, intent(in) :: id, caller_id, nvar, nz
         real(dp), pointer, intent(inout), dimension(:,:) :: xy ! (mblk,nblk)
            ! note: xy is rhs y on input and solution x on output
         integer, intent(in) :: lrd, lid
         double precision, target, intent(inout) :: rpar_decsol(lrd)
         integer, target, intent(inout) :: ipar_decsol(lid)
         integer, intent(out) :: ierr
         
         type (split_mtx_info), pointer :: p
         integer :: nvarh, species, nmax
         real(dp) :: atol, rtol, err, max_err, sum_err
         type (star_info), pointer :: s
         
         include 'formats.dek'
         
         ierr = 0
         call get_ptr(id, p, ierr)
         if (ierr /= 0) return

         call get_star_ptr(caller_id, s, ierr)
         if (ierr /= 0) return

         nvarh = p% nvarh
         species = p% species

         atol = s% split_mtx_solve_atol
         rtol = s% split_mtx_solve_rtol   
         nmax = s% split_mtx_solve_max_iters
         
         if (dbg) then
            write(*,2) 'solve nvarh', nvarh
            write(*,2) 'species', species
            write(*,2) 'nz', nz
            write(*,1) 'atol', atol
            write(*,1) 'rtol', rtol
            write(*,*)
         end if
         
         call split_xy(ierr)
         if (ierr /= 0) return

         call solve_xh(ierr)
         if (ierr /= 0) return
      
         call solve_xa(ierr)
         if (ierr /= 0) return
         
         call set_xy(ierr)
         if (ierr /= 0) return
            
         
         contains
         
         
         subroutine solve_xh(ierr)
            ! Hydro part xh of the solution (first series above).
            integer, intent(out) :: ierr
            integer :: i, k, n
            include 'formats.dek'
            ierr = 0
            ! leading term: xh = Jhh^-1 (rh - Jha Jaa^-1 ra)
            do k=1,nz
               do i=1,species
                  p% wa(i,k) = p% ra(i,k)
               end do
            end do
            call solve_Jaa(p% wa, ierr)
            if (ierr /= 0) return
            call mult_Jha(p% wa, p% wh) ! wh = Jha*wa
            do k=1,nz
               do i=1,nvarh
                  p% xh(i,k) = p% rh(i,k) - p% wh(i,k)
               end do
            end do
            call solve_Jhh(p% xh, ierr)
            if (ierr /= 0) return
            do k=1,nz
               do i=1,nvarh
                  p% dxh(i,k) = p% xh(i,k)
               end do
            end do
            ! define the error measures before the loop: if nmax <= 0 the loop
            ! body never runs and the leading term is accepted as-is
            max_err = 0; sum_err = 0
            do n = 1, nmax
               call mult_Jah(p% dxh, p% wa) ! wa = Jah*dxh
               call solve_Jaa(p% wa, ierr)
               if (ierr /= 0) return
               call mult_Jha(p% wa, p% dxh) ! dxh = Jha*wa
               call solve_Jhh(p% dxh, ierr)
               if (ierr /= 0) return
               max_err = 0; sum_err = 0
               do k=1,nz ! xh = xh + dxh
                  do i=1,nvarh
                     p% xh(i,k) = p% xh(i,k) + p% dxh(i,k)
                     err = abs(p% dxh(i,k)) / (atol + rtol*abs(p% xh(i,k)))
                     sum_err = sum_err + err
                     if (err > max_err) max_err = err
                  end do
               end do
               if (max_err < 1) exit
            end do
            if (max_err >= 1) then
               ierr = -1
               write(*,2) 'solve_xh failed', n, max_err, sum_err/(nz*nvarh)
            end if
            if (dbg) write(*,2) 'solve_xh', n, max_err, sum_err/(nz*nvarh)
         end subroutine solve_xh
         
         
         subroutine solve_xa(ierr)
            ! Abundance part xa of the solution (second series above).
            integer, intent(out) :: ierr
            integer :: i, k, n
            include 'formats.dek'
            ierr = 0
            ! leading term: xa = Jaa^-1 (ra - Jah Jhh^-1 rh)
            do k=1,nz
               do i=1,nvarh
                  p% wh(i,k) = p% rh(i,k)
               end do
            end do
            call solve_Jhh(p% wh, ierr)
            if (ierr /= 0) return
            call mult_Jah(p% wh, p% wa) ! wa = Jah*wh
            do k=1,nz ! xa = ra - wa
               do i=1,species
                  p% xa(i,k) = p% ra(i,k) - p% wa(i,k)
               end do
            end do
            call solve_Jaa(p% xa, ierr)
            if (ierr /= 0) return
            do k=1,nz ! dxa = xa
               do i=1,species
                  p% dxa(i,k) = p% xa(i,k)
               end do
            end do
            ! define the error measures before the loop: if nmax <= 0 the loop
            ! body never runs and the leading term is accepted as-is
            max_err = 0; sum_err = 0
            do n = 1, nmax
               call mult_Jha(p% dxa, p% wh) ! wh = Jha*dxa
               call solve_Jhh(p% wh, ierr)
               if (ierr /= 0) return
               call mult_Jah(p% wh, p% dxa) ! dxa = Jah*wh
               call solve_Jaa(p% dxa, ierr)
               if (ierr /= 0) return
               max_err = 0; sum_err = 0
               do k=1,nz ! xa = xa + dxa
                  do i=1,species
                     p% xa(i,k) = p% xa(i,k) + p% dxa(i,k)
                     err = abs(p% dxa(i,k)) / (atol + rtol*abs(p% xa(i,k)))
                     sum_err = sum_err + err
                     if (err > max_err) max_err = err
                  end do
               end do
               if (max_err < 1) exit
            end do
            if (max_err >= 1) then
               ierr = -1
               write(*,2) 'solve_xa failed', n, max_err, sum_err/(nz*species)
            end if
            if (dbg) write(*,2) 'solve_xa', n, max_err, sum_err/(nz*species)
         end subroutine solve_xa
         
         
         subroutine mult_Jha(wa,wh) ! wh = Jha*wa
            ! The output dummy carries no INTENT: for a pointer dummy,
            ! INTENT(OUT) would make its association undefined on entry, but
            ! here only the target's values are written.
            use mtx_lib, only: block_dble_mv
            real(dp), pointer, intent(in) :: wa(:,:) ! (species,nz)
            real(dp), pointer :: wh(:,:) ! (nvarh,nz), overwritten
            call block_dble_mv(p% Jha_l, p% Jha_d, p% Jha_u, wa, wh)
         end subroutine mult_Jha
         
         
         subroutine mult_Jah(wh,wa) ! wa = Jah*wh
            ! See mult_Jha for why the output pointer dummy has no INTENT.
            use mtx_lib, only: block_dble_mv
            real(dp), pointer, intent(in) :: wh(:,:) ! (nvarh,nz)
            real(dp), pointer :: wa(:,:) ! (species,nz), overwritten
            call block_dble_mv(p% Jah_l, p% Jah_d, p% Jah_u, wh, wa)
         end subroutine mult_Jah
         
         
         subroutine solve_Jaa(w,ierr)
            ! w := Jaa^-1 w, using the factorization from split_mtx_factor
            real(dp), pointer, intent(inout) :: w(:,:) ! (species,nz)
            integer, intent(out) :: ierr
            integer, pointer :: ipiv(:,:)
            ierr = 0
            ipiv => null()            
            call block_dc_mt_dble_decsolblk( &
               1, caller_id, species, nz, p% Jaa_l_fac, p% Jaa_d_fac, p% Jaa_u_fac, w, ipiv, &
               p% Jaa_lrd, p% Jaa_rpar_decsol, p% Jaa_lid, p% Jaa_ipar_decsol, &
               ierr)
         end subroutine solve_Jaa
         
         
         subroutine solve_Jhh(w,ierr)
            ! w := Jhh^-1 w, using the factorization from split_mtx_factor
            real(dp), pointer, intent(inout) :: w(:,:) ! (nvarh,nz)
            integer, intent(out) :: ierr
            integer, pointer :: ipiv(:,:)
            ierr = 0
            ipiv => null()            
            call block_dc_mt_dble_decsolblk( &
               1, caller_id, nvarh, nz, p% Jhh_l_fac, p% Jhh_d_fac, p% Jhh_u_fac, w, ipiv, &
               p% Jhh_lrd, p% Jhh_rpar_decsol, p% Jhh_lid, p% Jhh_ipar_decsol, &
               ierr)
         end subroutine solve_Jhh
         
         
         subroutine split_xy(ierr)
            ! rh = first nvarh rows of xy, ra = the following species rows
            integer, intent(out) :: ierr
            integer :: i, k
            ierr = 0
            do k=1,nz
               do i=1,nvarh
                  p% rh(i,k) = xy(i,k)
               end do
               do i=1,species
                  p% ra(i,k) = xy(nvarh+i,k)
               end do
            end do
         end subroutine split_xy
         
         
         subroutine set_xy(ierr)
            ! write xh and xa back into the matching rows of xy
            integer, intent(out) :: ierr
            integer :: i, k
            ierr = 0
            do k=1,nz
               do i=1,nvarh
                  xy(i,k) = p% xh(i,k)
               end do
               do i=1,species
                  xy(nvarh+i,k) = p% xa(i,k)
               end do
            end do
         end subroutine set_xy
         
         
      end subroutine split_mtx_solve
      
           
      subroutine split_mtx_dealloc( &
            id, caller_id, &
            lrd, rpar_decsol, lid, ipar_decsol, &
            ierr)
         ! Release the Jhh and Jaa factorizations (block_dc_mt_dble_decsolblk
         ! with iop=2) held by handle `id`, then return the handle to the pool.
         ! Both releases are always attempted and the handle is always freed,
         ! so one failing release cannot permanently consume a pool slot; the
         ! first nonzero status is reported in ierr.  The per-handle arrays are
         ! kept allocated for reuse by the next split_mtx_factor on this slot.
         !
         ! lrd, rpar_decsol, lid, ipar_decsol   caller workspace (not referenced)
         integer, intent(in) :: caller_id, id
         integer, intent(in) :: lrd, lid
         double precision, target, intent(inout) :: rpar_decsol(lrd)
         integer, target, intent(inout) :: ipar_decsol(lid)
         integer, intent(out) :: ierr         
         
         type (split_mtx_info), pointer :: p
         integer, pointer :: ipiv(:,:)
         real(dp), pointer :: brhs(:,:)
         integer :: ierr_jaa

         ierr = 0
         call get_ptr(id,p,ierr)
         if (ierr /= 0) return
         brhs => null()
         ipiv => null()                     
         call block_dc_mt_dble_decsolblk( &
            2, caller_id, p% nvarh, p% nz, p% Jhh_l_fac, p% Jhh_d_fac, p% Jhh_u_fac, brhs, ipiv, &
            p% Jhh_lrd, p% Jhh_rpar_decsol, p% Jhh_lid, p% Jhh_ipar_decsol, &
            ierr)
         ierr_jaa = 0
         call block_dc_mt_dble_decsolblk( &
            2, caller_id, p% species, p% nz, p% Jaa_l_fac, p% Jaa_d_fac, p% Jaa_u_fac, brhs, ipiv, &
            p% Jaa_lrd, p% Jaa_rpar_decsol, p% Jaa_lid, p% Jaa_ipar_decsol, &
            ierr_jaa)
         if (ierr == 0) ierr = ierr_jaa
         call do_free_handle(id)
         
      end subroutine split_mtx_dealloc
      
      
      subroutine split_mtx_init
         ! One-time preparation of the handle pool.  Every slot records its own
         ! index, is marked free, and has all of its array pointers
         ! disassociated -- presumably so the enlarge_if_needed_* calls in
         ! split_mtx_factor treat them as not yet allocated.
         ! The unprotected test of have_initialized is only a fast path; the
         ! flag is re-tested inside the named critical section so exactly one
         ! thread does the work.
         ! NOTE(review): that unsynchronized read is not formally race-free
         ! under the OpenMP memory model.
         integer :: slot
         type (split_mtx_info), pointer :: h
         if (have_initialized) return
!$omp critical (split_mtx_initialize)
         if (.not. have_initialized) then
            do slot = 1, max_handles
               h => handles(slot)
               h% in_use = .false.
               h% handle = slot
               ! coupling blocks
               nullify(h% Jhh_l, h% Jhh_d, h% Jhh_u)
               nullify(h% Jha_l, h% Jha_d, h% Jha_u)
               nullify(h% Jah_l, h% Jah_d, h% Jah_u)
               nullify(h% Jaa_l, h% Jaa_d, h% Jaa_u)
               ! factored diagonal blocks
               nullify(h% Jhh_l_fac, h% Jhh_d_fac, h% Jhh_u_fac)
               nullify(h% Jaa_l_fac, h% Jaa_d_fac, h% Jaa_u_fac)
               ! working vectors
               nullify(h% xh, h% dxh, h% rh, h% wh)
               nullify(h% xa, h% dxa, h% ra, h% wa)
               ! decsol workspaces
               nullify(h% Jhh_rpar_decsol, h% Jhh_ipar_decsol)
               nullify(h% Jaa_rpar_decsol, h% Jaa_ipar_decsol)
            end do
            have_initialized = .true.
         end if
!$omp end critical (split_mtx_initialize)
      end subroutine split_mtx_init

      
      integer function split_mtx_alloc(ierr)
         ! Reserve the lowest-numbered free slot of the handle pool and return
         ! its index, initializing the pool on first use.  The search and the
         ! marking happen inside the split_mtx_handle critical section.
         ! Failure modes (ierr = -1, reported through alert):
         !   * every slot is in use          -> result is -1
         !   * the slot's stored index does not match its position
         !     (the slot stays marked in use)
         use alert_lib,only:alert
         integer, intent(out) :: ierr
         integer :: slot
         ierr = 0
         if (.not. have_initialized) call split_mtx_init
         split_mtx_alloc = -1
!$omp critical (split_mtx_handle)
         do slot = 1, max_handles
            if (handles(slot)% in_use) cycle
            handles(slot)% in_use = .true.
            split_mtx_alloc = slot
            exit
         end do
!$omp end critical (split_mtx_handle)
         if (split_mtx_alloc == -1) then
            ierr = -1
            call alert(ierr, 'no available split_mtx handle')
         else if (handles(split_mtx_alloc)% handle /= split_mtx_alloc) then
            ierr = -1
            call alert(ierr, 'broken handle for split_mtx')
         end if
      end function split_mtx_alloc
            
      
      subroutine do_free_handle(handle)
         ! Mark pool slot `handle` as free.  Out-of-range values are ignored.
         ! Only the in_use flag changes; the slot's arrays stay allocated for
         ! reuse.  This is done outside the split_mtx_handle critical section.
         integer, intent(in) :: handle
         if (handle < 1 .or. handle > max_handles) return
         handles(handle)% in_use = .false.
      end subroutine do_free_handle
      

      subroutine get_ptr(handle,p,ierr)
         ! Point p at pool slot `handle`.  Only the index range is validated
         ! (the in_use flag is not checked).  On an out-of-range handle, ierr
         ! is -1, an alert is raised, and p is left unchanged.
         use alert_lib,only:alert
         integer, intent(in) :: handle
         type (split_mtx_info), pointer :: p
         integer, intent(out):: ierr         
         if (handle >= 1 .and. handle <= max_handles) then
            ierr = 0
            p => handles(handle)
         else
            ierr = -1
            call alert(ierr,'invalid split_mtx handle')
         end if
      end subroutine get_ptr

      
      subroutine split_mtx_work_sizes(nvar,nz,lrd,lid)
         ! Workspace required from the caller by star_split_decsolblk.
         ! It keeps all of its storage in the handle pool, so it needs no real
         ! workspace and a single integer slot (ipar_decsol(1) holds the
         ! handle).  nvar and nz do not affect the result.
         integer, intent(in) :: nvar,nz
         integer, intent(out) :: lrd,lid
         lrd = 0
         lid = 1
      end subroutine split_mtx_work_sizes

   
      subroutine star_split_decsolblk( &
            iop,caller_id,nvar,nz,lblk,dblk,ublk,&
            brhs,ipiv,lrd,rpar_decsol,lid,ipar_decsol,ierr)
         ! decsolblk-style entry point for the split solver.
         !   iop = 0: reserve a split_mtx handle, store it in ipar_decsol(1),
         !            then split and factor lblk/dblk/ublk (split_mtx_factor)
         !   iop = 1: solve in place for brhs (rhs in, solution out) with the
         !            handle from ipar_decsol(1) (split_mtx_solve)
         !   iop = 2: release the factorizations and the handle stored in
         !            ipar_decsol(1) (split_mtx_dealloc)
         ! lid < 1 (no room for the handle) or any other iop gives ierr = -1.
         ! ipiv is accepted for interface compatibility but is not used.
         ! NOTE(review): if split_mtx_factor fails, the handle reserved here is
         ! not released; presumably callers still issue iop = 2 -- confirm.
         use alert_lib, only : alert
         integer, intent(in) :: iop, caller_id, nvar, nz, lrd, lid
         integer, pointer, intent(inout) :: ipiv(:,:) ! (nvar,nz)
         real(dp), dimension(:,:,:), pointer, intent(inout) :: lblk, dblk, ublk
         ! row(i) of mtx has lblk(:,:,i), dblk(:,:,i), ublk(:,:,i)
         ! lblk(:,:,1) is not used; ublk(:,:,nz) is not used.
         real(dp), pointer, intent(inout)  :: brhs(:,:) ! (nvar,nz)     
         real(dp), target, intent(inout) :: rpar_decsol(lrd)
         integer, target, intent(inout) :: ipar_decsol(lid)
         integer, intent(out) :: ierr
         integer :: id
         include 'formats.dek'
         ierr = 0
         if (lid < 1) then
            ierr = -1
            return
         end if
         select case (iop)
         case (0) ! factor
            id = split_mtx_alloc(ierr)
            if (ierr /= 0) return
            ipar_decsol(1) = id
            call split_mtx_factor( &
               id, caller_id, nvar, nz, lblk, dblk, ublk, &
               lrd, rpar_decsol, lid, ipar_decsol, ierr)
         case (1) ! solve
            id = ipar_decsol(1)
            call split_mtx_solve( &
               id, caller_id, nvar, nz, brhs, &
               lrd, rpar_decsol, lid, ipar_decsol, ierr)
         case (2) ! deallocate
            id = ipar_decsol(1)
            call split_mtx_dealloc( &
               id, caller_id, lrd, rpar_decsol, lid, ipar_decsol, ierr)
         case default
            ierr = -1
            call alert(ierr,'star_split_decsolblk: iop bad')
         end select
      end subroutine star_split_decsolblk





      end module solve_split_mtx
