! ***********************************************************************
!
!   Copyright (C) 2012  Bill Paxton
!
!   MESA is free software; you can use it and/or modify
!   it under the combined terms and restrictions of the MESA MANIFESTO
!   and the GNU General Library Public License as published
!   by the Free Software Foundation; either version 2 of the License,
!   or (at your option) any later version.
!
!   You should have received a copy of the MESA MANIFESTO along with
!   this software; if not, it is available at the mesa website:
!   http://mesa.sourceforge.net/
!
!   MESA is distributed in the hope that it will be useful,
!   but WITHOUT ANY WARRANTY; without even the implied warranty of
!   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
!   See the GNU Library General Public License for more details.
!
!   You should have received a copy of the GNU Library General Public License
!   along with this software; if not, write to the Free Software
!   Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
!
! ***********************************************************************


#ifdef DBLE
      module d_and_c_block_dble
      use my_lapack95_dble
      use utils_lib, only: &
         set_pointer_1, set_pointer_3, &
         enlarge_if_needed_1, enlarge_if_needed_2, enlarge_if_needed_3
#else
      module d_and_c_block_quad
      use my_lapack95_quad
      use utils_lib, only: set_pointer_1, &
         set_quad_pointer_1, set_quad_pointer_3, &
         quad_enlarge_if_needed_1, quad_enlarge_if_needed_2, quad_enlarge_if_needed_3
#endif
      use utils_lib, only: set_int_pointer_1, enlarge_integer_if_needed_2
      use const_def, only: dp
      

#ifdef DBLE
#define set_ptr_1 set_pointer_1
#define set_ptr_3 set_pointer_3
#define enlarge_1 enlarge_if_needed_1
#define enlarge_2 enlarge_if_needed_2
#define enlarge_3 enlarge_if_needed_3
#else
#define set_ptr_1 set_quad_pointer_1
#define set_ptr_3 set_quad_pointer_3
#define enlarge_1 quad_enlarge_if_needed_1
#define enlarge_2 quad_enlarge_if_needed_2
#define enlarge_3 quad_enlarge_if_needed_3
#endif


            

      implicit none


      ! One part of the partitioned block tridiagonal matrix.
      ! The array components are pointers: partition_mtx aims the matrix pieces at
      ! sections of the caller's lblk, dblk, ublk, and partition_xy aims the vector
      ! pieces at sections of the caller's xy.
      type part
         integer :: part_num   ! index of this part, from 1 to num_parts
         integer :: kblk       ! number of block rows in first:last (the separator row after last is not counted)
         integer :: first      ! index of the first block row of this part
         integer :: last       ! index of the last block row of this part
         ! sub-, main- and super-diagonal blocks of this part, each (mblk,mblk,kblk)
         real(fltp), pointer :: lblk(:,:,:), dblk(:,:,:), ublk(:,:,:)
         ! results of the extra thomas solves done in d_and_c_factor, each (mblk,mblk,kblk)
         real(fltp), pointer :: u_bar(:,:,:), v_bar(:,:,:)
         ! blocks that tie this part to the separator row that follows it, each (mblk,mblk)
         real(fltp), pointer :: lambda(:,:), theta(:,:), gamma(:,:), phi(:,:), omega(:,:)
         ! x, y and y_bar all point at the same (mblk,kblk) section of data.
         ! the three names are kept so the code matches the description of the algorithm.
         real(fltp), pointer :: x(:,:), y(:,:), y_bar(:,:)
         ! psi and zeta both point at the same (mblk) vector of data
         real(fltp), pointer :: psi(:), zeta(:)
         integer, pointer, dimension(:,:) :: ipiv ! (mblk,kblk)
         integer :: thomas_id  ! handle obtained from thomas_handle for this part
      end type part
      
      ! Everything kept for one solver instance: the parts plus the coupling system.
      type d_and_c_info
         integer :: num_p  ! number of parts; d_and_c_factor sets 0 when it falls back to plain thomas
         integer :: nblk   ! number of zones
         integer :: mblk   ! number of variables per zone
         type (part), pointer, dimension(:) :: p ! (num_parts)
         ! the full matrix as given to d_and_c_factor, each (mblk,mblk,nblk)
         real(fltp), pointer :: lblk(:,:,:), dblk(:,:,:), ublk(:,:,:)
         ! blocks of the coupling system, each (mblk,mblk,num_p-1)
         real(fltp), pointer :: lambda_bar(:,:,:), gamma_bar(:,:,:), phi_bar(:,:,:)
         ! right hand side and solution of the coupling system, each (mblk,num_p-1).
         ! d_and_c_solve points zeta at psi_bar, so both names refer to the same data.
         real(fltp), pointer :: psi_bar(:,:), zeta(:,:)
         integer, pointer, dimension(:,:) :: ipiv ! (mblk,num_p-1)
         integer :: thomas_id ! handle obtained from thomas_handle for the coupling system
         ! bookkeeping
         integer :: handle
         logical :: in_use
      end type d_and_c_info
      
      
      ! fixed size pool of solver instances
      integer, parameter :: max_handles = 100
      type (d_and_c_info), dimension(max_handles), target :: handles
      
      
      ! optional timing of the stages of d_and_c_factor.
      ! the factor_subtime values are running sums in seconds (clock counts divided by clock rate).
      logical :: do_factor_subtiming = .false., show_factor_subtiming = .false.
      real(fltp) :: &
         factor_subtime_setup, &
         factor_subtime_factor_As, &
         factor_subtime_setup_C, &
         factor_subtime_factor_C
      
      ! optional timing of the stages of d_and_c_dealloc
      logical :: do_dealloc_subtiming = .false., show_dealloc_subtiming = .false.
      real(fltp) :: dealloc_subtime_thomas, dealloc_subtime_d_and_c
      
      
      ! slots at the start of ipar_decsol; the per-part kblk values are stored after slot num_ipars
      integer, parameter :: i_id = 1
      integer, parameter :: i_cnt = i_id + 1 ! set to 1 or incremented by partition_mtx
      integer, parameter :: num_ipars = i_cnt
      
      
      ! d_and_c_factor calls d_and_c_block_init while have_omp_num_threads is false
      logical :: have_omp_num_threads = .false.
      integer :: omp_num_threads = 0
      
      ! number of parts is parts_per_thread*omp_num_threads.
      ! you might think that 2 or 3 parts per thread would help with load balancing for factors
      ! but in my experience, it has made things slower, so it is set to 1
      integer, parameter :: parts_per_thread = 1
      
      logical, parameter :: do_refine_weights = .false.  ! haven't found anything that helps.
      
      
      logical, parameter :: dbg = .false. ! turns on the debugging output in partition_xy
      
      
      contains
      
      
      ! Factor a block tridiagonal matrix, splitting it into parts that are
      ! factored in parallel and tied together by a small coupling system.
      !
      ! id                  handle of the solver instance (looked up with get_ptr)
      ! mblk, nblk          variables per zone and number of zones
      ! use_given_weights   if true, part sizes are taken from the weights stored in
      !                     rpar_decsol; if false, a default split is used and the
      !                     resulting weights are written to rpar_decsol
      ! lblk, dblk, ublk    blocks of the matrix; the parts point into these arrays
      ! thomas_*            block thomas solver procedures used for each part and for
      !                     the coupling system
      ! rpar_decsol (lrd)   slot 1 is skipped here, then weights, factor_times and
      !                     solve_times, each of length num_parts
      ! ipar_decsol (lid)   slots 1 to num_ipars, then the kblk of each part
      ! ierr                0 for success
      !
      ! If nblk <= 2*omp_num_threads or there are fewer than 2 threads, the whole
      ! matrix is handed to thomas_factor and d% num_p is set to 0.
      !
      ! NOTE(review): ierr is shared by the threads of the first parallel loop and is
      ! written without synchronization when a part fails; every writer stores a
      ! nonzero value, so only which error code survives is undetermined.
      subroutine d_and_c_factor( &
            id, mblk, nblk, use_given_weights, lblk, dblk, ublk, &
            thomas_handle, thomas_factor, thomas_solve, thomas_dealloc, thomas_stats, &
            lrd, rpar_decsol, lid, ipar_decsol, &
            ierr)
         integer, intent(in) :: id, nblk, mblk
         logical, intent(in) :: use_given_weights
         real(fltp), pointer, intent(inout), dimension(:,:,:) :: lblk, dblk, ublk ! (mblk,mblk,nblk)
         ! row(i) of mtx has lblk(:,:,i), dblk(:,:,i), ublk(:,:,i)
         ! lblk(:,:,1) is not used; ublk(:,:,nblk) is not used.
         interface
#ifdef DBLE
            include 'thomas_block_procs_dble.dek'
#else
            include 'thomas_block_procs_quad.dek'
#endif
         end interface
         integer, intent(in) :: lrd, lid
         real(dp), pointer, intent(inout) :: rpar_decsol(:) ! (lrd)
         integer, pointer, intent(inout) :: ipar_decsol(:) ! (lid)
         integer, intent(out) :: ierr

         type (d_and_c_info), pointer :: d
         integer ::  num_parts, i, j, k, kblk, op_err, &
            factor, refactor, rejected_refactor, solve, sum_nz, max_nz
         integer, parameter :: extra = 200
         real(fltp) :: min_rcond
         type (part), pointer :: p, pp1
         real(dp), pointer, dimension(:) :: &
            time_thomas_solves, time_thomas_factors, &
            weights, factor_times, solve_times
         integer, pointer :: kblks(:)
            
         integer :: time0_thomas, time0_factor ! for thomas subtasks in factor As
         integer :: time0, time1 ! for load balancing information
         integer :: clock_rate
         
         include 'formats.dek'
         
         time0_factor = 0 ! for gfortran
         if (do_factor_subtiming) call system_clock(time0_factor)
         
         ierr = 0
         ! give the timing arrays a defined status so they can be released on any exit
         nullify(time_thomas_solves, time_thomas_factors)
         
         call get_ptr(id,d,ierr)
         if (ierr /= 0) return
         
         if (.not. have_omp_num_threads) call d_and_c_block_init
         num_parts = parts_per_thread*omp_num_threads
         
         ! check the work array sizes before taking any sections of them
         i = 1 + 3*num_parts
         if (lrd < i) then
            write(*,3) 'mtx d_and_c_factor: lrd too small', lrd, i
            ierr = -1
            return
         end if
         i = num_ipars + num_parts
         if (lid < i) then
            write(*,3) 'mtx d_and_c_factor: lid too small', lid, i
            ierr = -1
            return
         end if
         
         i = 1 ! for reporting max error in load balancing
         call set_pointer_1(weights, rpar_decsol(i+1:i+num_parts), num_parts)
         i = i+num_parts
         call set_pointer_1(factor_times, rpar_decsol(i+1:i+num_parts), num_parts)
         i = i+num_parts
         call set_pointer_1(solve_times, rpar_decsol(i+1:i+num_parts), num_parts)
         factor_times(:) = 0
         solve_times(:) = 0
         
         i = num_ipars
         call set_int_pointer_1(kblks, ipar_decsol(i+1:i+num_parts), num_parts)
         
         d% num_p = num_parts
         d% nblk = nblk
         d% mblk = mblk
         d% lblk => lblk
         d% dblk => dblk
         d% ublk => ublk
         d% thomas_id = thomas_handle(ierr)
         if (ierr /= 0) return

         if (nblk <= 2*omp_num_threads .or. omp_num_threads < 2) then ! just do thomas
            d% num_p = 0
            call enlarge_integer_if_needed_2(d% ipiv,mblk,nblk,extra,ierr)
            if (ierr /= 0) then
               write(*,*) 'enlarge_integer_if_needed_2 failed'
               return
            end if
            call thomas_factor( &
               d% thomas_id, mblk, nblk, d% lblk, d% dblk, d% ublk, d% ipiv, &
               lrd, rpar_decsol, lid, ipar_decsol, ierr)
            return
         end if

         if (do_factor_subtiming) then
            allocate(time_thomas_solves(num_parts), time_thomas_factors(num_parts))
            time_thomas_solves = 0
            time_thomas_factors = 0
         end if
         
         ! storage for the coupling system, then the split into parts.
         ! each step runs only if all earlier steps succeeded.
         call enlarge_3(d% lambda_bar,mblk,mblk,num_parts-1,extra,ierr)
         if (ierr == 0) call enlarge_3(d% gamma_bar,mblk,mblk,num_parts-1,extra,ierr)
         if (ierr == 0) call enlarge_3(d% phi_bar,mblk,mblk,num_parts-1,extra,ierr)
         if (ierr == 0) call enlarge_2(d% psi_bar,mblk,num_parts-1,extra,ierr)
         if (ierr == 0) call enlarge_integer_if_needed_2(d% ipiv,mblk,num_parts-1,extra,ierr)
         if (ierr == 0) call partition_mtx
         if (ierr /= 0) then
            call free_subtiming_arrays
            return
         end if

         if (do_factor_subtiming) then
            call system_clock(time1,clock_rate)
            factor_subtime_setup = &
               factor_subtime_setup + dble(time1-time0_factor)/clock_rate
            time0_factor = time1
         end if
         time0_thomas = 0
!$OMP PARALLEL DO PRIVATE(i,j,k,p,kblk,op_err,time0,time1,time0_thomas,clock_rate) schedule(dynamic,1)
         do i = 1, num_parts
            if (ierr /= 0) cycle ! some part has already failed; skip the remaining work
            if (do_refine_weights .or. do_factor_subtiming) then
               call system_clock(time0)
               time0_thomas = time0
            end if
            p => d% p(i)
            kblk = p% kblk
            op_err = 0
            call thomas_factor(&
               p% thomas_id, mblk, kblk, p% lblk, p% dblk, p% ublk, p% ipiv, &
               lrd, rpar_decsol, lid, ipar_decsol, op_err)
            if (op_err /= 0) then
               !write(*,2) 'i', i
               !write(*,2) 'size(p% lblk,3)', size(p% lblk,3)
               !write(*,2) 'mblk', mblk
               !write(*,2) 'kblk', kblk
               !stop 'failed in thomas_factor'
               ierr = op_err; cycle
            end if
            if (do_factor_subtiming) then
               call system_clock(time1,clock_rate)
               time_thomas_factors(i) = dble(time1-time0_thomas)/clock_rate
               time0_thomas = time1
            end if
            if (i < num_parts) then
               ! rhs is zero except for theta in the last block; the solve result is kept in v_bar
               p% v_bar = 0
               do j=1,mblk
                  do k=1,mblk
                     p% v_bar(k,j,kblk) = p% theta(k,j)
                  end do
               end do
               call thomas_solve( &
                  p% thomas_id, mblk, kblk, p% lblk, p% dblk, p% ublk, p% v_bar, mblk, p% ipiv, &
                  lrd, rpar_decsol, lid, ipar_decsol, op_err)
               if (op_err /= 0) then
                  ierr = op_err; cycle
               end if
            end if
            if (i > 1) then
               ! rhs is zero except for omega of the previous part in the first block;
               ! the solve result is kept in u_bar
               p% u_bar = 0
               do j=1,mblk
                  do k=1,mblk
                     p% u_bar(k,j,1) = d% p(i-1)% omega(k,j)
                  end do
               end do
               call thomas_solve( &
                  p% thomas_id, mblk, kblk, p% lblk, p% dblk, p% ublk, p% u_bar, mblk, p% ipiv, &
                  lrd, rpar_decsol, lid, ipar_decsol, op_err)
               if (op_err /= 0) then
                  ierr = op_err; cycle
               end if
            end if
            if (do_refine_weights .or. do_factor_subtiming) then
               call system_clock(time1,clock_rate)
               factor_times(i) = dble(time1 - time0)
            end if
            if (do_factor_subtiming) then
               time_thomas_solves(i) = dble(time1-time0_thomas)/clock_rate
            end if
         end do
!$OMP END PARALLEL DO

         if (ierr == 0 .and. do_factor_subtiming) then
            call system_clock(time1,clock_rate)
            factor_subtime_factor_As = factor_subtime_factor_As + dble(time1-time0_factor)/clock_rate
            time0_factor = time1
            if (show_factor_subtiming) then
               write(*,*)
               write(*,*)
               write(*,*) 'break down of times by thomas part in d_and_c_factor'
               write(*,'(4x,7a9,99a14)') &
                  'kblk', '#fact', '#refact', '#reject', '#solve', 'max_nz', 'sum_nz', &
                     'total time', 'thmas fact', 'thmas solv', 'mn_rcnd'
               do i = 1, num_parts
                  p => d% p(i)
                  call thomas_stats(p% thomas_id, &
                     j, factor, refactor, rejected_refactor, solve, sum_nz, max_nz, min_rcond, ierr)
                  if (ierr /= 0) exit
                  write(*,'(i4,7i9,99(1pe14.3))') &
                     i, p% kblk, factor, refactor, rejected_refactor, solve, max_nz, sum_nz, &
                     time_thomas_factors(i) + time_thomas_solves(i), &
                     time_thomas_factors(i), time_thomas_solves(i), min_rcond
               end do
            write(*,*)  
            end if
         end if

         if (ierr /= 0) then
            !write(*,*) 'd_and_c_factor failed in thomas_factor'
            call free_subtiming_arrays
            return
         end if
         
!$OMP PARALLEL DO PRIVATE(i,p,pp1,kblk,time0,time1,op_err)
         do i = 1, num_parts-1 ! calculate blocks of coupling system         
            p => d% p(i)
            pp1 => d% p(i+1)
            kblk = p% kblk         
            ! d% lambda_bar(:,:,i) = p% lambda - p% gamma*last(p% v_bar) - p% phi*first(pp1% u_bar)
            d% lambda_bar(:,:,i) = p% lambda
            call do_dgemm(p% gamma, p% v_bar(:,:,kblk), d% lambda_bar(:,:,i))
            call do_dgemm(p% phi, pp1% u_bar(:,:,1), d% lambda_bar(:,:,i))                 
            if (i > 1) then ! d% gamma_bar(:,:,i) = -p% gamma*last(p% u_bar)
               call do_dgemm0(p% gamma, p% u_bar(:,:,kblk), d% gamma_bar(:,:,i))
            else
               d% gamma_bar(:,:,i) = 0
            end if                  
            if (i < num_parts-1) then ! d% phi_bar(:,:,i) = -p% phi*first(pp1% v_bar)
               call do_dgemm0(p% phi, pp1% v_bar(:,:,1), d% phi_bar(:,:,i))
            else
               d% phi_bar(:,:,i) = 0
            end if
         end do
!$OMP END PARALLEL DO

         if (do_factor_subtiming) then
            call system_clock(time1,clock_rate)
            factor_subtime_setup_C = factor_subtime_setup_C + dble(time1-time0_factor)/clock_rate
            time0_factor = time1
         end if

         ! use thomas to factor the coupling system
         call thomas_factor( &
            d% thomas_id, mblk, num_parts-1, d% gamma_bar, d% lambda_bar, d% phi_bar, d% ipiv, &
            lrd, rpar_decsol, lid, ipar_decsol, ierr)
         if (ierr /= 0) then
            !write(*,*) 'd_and_c_factor failed in thomas_factor for coupling system'
            call free_subtiming_arrays
            return
         end if

         if (do_factor_subtiming) then
            call system_clock(time1,clock_rate)
            factor_subtime_factor_C = factor_subtime_factor_C + dble(time1-time0_factor)/clock_rate
         end if
         call free_subtiming_arrays
         
         
         contains
            
            
         ! release the per-part timing arrays if they were allocated
         subroutine free_subtiming_arrays
            if (associated(time_thomas_solves)) deallocate(time_thomas_solves)
            if (associated(time_thomas_factors)) deallocate(time_thomas_factors)
         end subroutine free_subtiming_arrays
            
            
         ! update c using the mblk by mblk blocks a and b via my_gemm.
         ! the comments at the call sites describe the update as c = c - a*b (my_gemm is not visible here).
         subroutine do_dgemm(a,b,c)
            real(fltp), dimension(:,:) :: a,b,c
            call my_gemm(mblk, mblk, mblk, a, mblk, b, mblk, c, mblk)
         end subroutine do_dgemm
            
            
         ! set c from the mblk by mblk blocks a and b via my_gemm0.
         ! the comments at the call sites describe the result as c = -a*b (my_gemm0 is not visible here).
         subroutine do_dgemm0(a,b,c)
            real(fltp), dimension(:,:) :: a,b,c
            call my_gemm0(mblk, mblk, mblk, a, mblk, b, mblk, c, mblk)
         end subroutine do_dgemm0
         
         
         ! Choose the number of block rows for each part and point the components
         ! of each part at the matching sections of lblk, dblk, ublk.
         ! One separator row is left between consecutive parts, so the parts share
         ! kblk_total = nblk - num_parts + 1 rows.
         ! Sets ierr and returns early if a handle or a buffer cannot be obtained.
         subroutine partition_mtx
            ! clip_limit and do_clip_blocks are not referenced in this procedure
            use thomas_block_sparse, only: clip_limit
            use mtx_support, only: do_clip_blocks
            integer :: kblk_for_part(num_parts)
            real(fltp) :: targets(num_parts)
            integer :: n, r, i, kblk, first, last, kblk_total, kblk_sum

            include 'formats.dek'
            
            kblk_total = nblk - num_parts + 1
            if (use_given_weights) then
               ipar_decsol(i_cnt) = ipar_decsol(i_cnt) + 1
               ! targets(i) is the cumulative fraction of rows wanted in parts 1 to i
               targets(1) = weights(1)
               do i=2,num_parts-1
                  targets(i) = targets(i-1) + weights(i)
               end do
               targets(num_parts) = 1d0
               kblk_for_part(1) = int(kblk_total*targets(1) + 0.5) ! round
               kblk_sum = kblk_for_part(1)
               do i=2,num_parts-1
                  kblk_for_part(i) = int(kblk_total*targets(i) - kblk_sum + 0.5) ! round
                  kblk_sum = kblk_sum + kblk_for_part(i)
               end do
               ! the last part takes whatever is left
               kblk_for_part(num_parts) = kblk_total - kblk_sum
               
               !do i=1,num_parts
               !   write(*,3) 'i kblk weight', i, kblk_for_part(i), weights(i)
               !end do
               
               !stop
               
            else
               ipar_decsol(i_cnt) = 1
               ! 1st and last parts get more zones than others since they do less work.
               n = (nblk - num_parts + 1)/(num_parts + 1)
               kblk_for_part(1) = (n*3)/2
               kblk_for_part(2:num_parts-1) = n
               kblk_for_part(num_parts) = kblk_for_part(1)
               ! hand out the r leftover rows one at a time: first part, last part,
               ! then parts num_parts-1 down to 1
               r = (nblk - num_parts + 1) - sum(kblk_for_part)
               if (r > 0) then
                  r = r-1; kblk_for_part(1) = kblk_for_part(1) + 1
               end if
               if (r > 0) then
                  r = r-1; kblk_for_part(num_parts) = kblk_for_part(num_parts) + 1
               end if
               i = num_parts-1
               do while (r > 0 .and. i > 0)
                  kblk_for_part(i) = kblk_for_part(i) + 1
                  r = r-1; i = i-1
               end do
               ! record the split as weights
               do i = 1, num_parts
                  weights(i) = dble(kblk_for_part(i))/dble(kblk_total)
               end do
            end if
            
            if (sum(kblk_for_part(1:num_parts)) /= kblk_total) then
               write(*,2) 'oops', sum(kblk_for_part(1:num_parts)) + num_parts - 1
               write(*,2) 'nblk', nblk
               write(*,2) 'num_parts', num_parts
               stop
            end if            
            last = 0
            do i = 1, num_parts
               p => d% p(i)
               p% thomas_id = thomas_handle(ierr)
               if (ierr /= 0) return
               p% part_num = i
               kblk = kblk_for_part(i)
               kblks(i) = kblk
               call enlarge_3(p% u_bar,mblk,mblk,kblk,extra,ierr)
               if (ierr /= 0) return
               call enlarge_3(p% v_bar,mblk,mblk,kblk,extra,ierr)
               if (ierr /= 0) return
               call enlarge_integer_if_needed_2(p% ipiv,mblk,kblk,extra,ierr)
               if (ierr /= 0) return
               if (i == 1) then
                  first = 1
                  last = kblk
               else
                  first = last + 2 ! skip the separator row after the previous part
                  last = first + kblk - 1
               end if
               p% kblk = kblk
               p% first = first
               p% last = last
               p% lblk => lblk(:,:,first:last)
               p% dblk => dblk(:,:,first:last)
               p% ublk => ublk(:,:,first:last)
               if (i < num_parts) then
                  ! row last+1 is the separator row between part i and part i+1
                  p% theta => ublk(:,:,last)
                  p% gamma => lblk(:,:,last+1)
                  p% lambda => dblk(:,:,last+1)
                  p% phi => ublk(:,:,last+1)
                  p% omega => lblk(:,:,last+2)
               end if               
               !write(*,3) 'kblk', i, kblk
            end do
         end subroutine partition_mtx
         
           
      end subroutine d_and_c_factor
      
      
      ! Solve the system using the factorization stored by d_and_c_factor.
      !
      ! id                  handle of the solver instance (looked up with get_ptr)
      ! mblk, nblk          must match the values given to d_and_c_factor
      ! xy                  right hand side on input, solution on output
      ! thomas_*            block thomas solver procedures
      ! rpar_decsol (lrd)   same layout as in d_and_c_factor; solve_times is updated
      !                     only when do_refine_weights is true
      ! ierr                0 for success
      !
      ! If d_and_c_factor fell back to plain thomas (d% num_p <= 0), a single
      ! thomas_solve is done on the whole system.  Otherwise each part is solved in
      ! parallel, the coupling system is solved for zeta, and the part solutions
      ! are corrected using u_bar, v_bar and zeta.
      !
      ! NOTE(review): ierr is shared by the threads of the first parallel loop and is
      ! written without synchronization when a part fails; every writer stores a
      ! nonzero value, so only which error code survives is undetermined.
      subroutine d_and_c_solve( &
            id, mblk, nblk, xy, &
            thomas_handle, thomas_factor, thomas_solve, thomas_dealloc, thomas_stats, &
            lrd, rpar_decsol, lid, ipar_decsol, &
            ierr)
         integer, intent(in) :: id, mblk,nblk
         real(fltp), pointer, intent(inout), dimension(:,:) :: xy ! (mblk,nblk)
            ! note: xy is rhs y on input and solution x on output
         interface
#ifdef DBLE
            include 'thomas_block_procs_dble.dek'
#else
            include 'thomas_block_procs_quad.dek'
#endif
         end interface
         integer, intent(in) :: lrd, lid
         real(dp), pointer, intent(inout) :: rpar_decsol(:) ! (lrd)
         integer, pointer, intent(inout) :: ipar_decsol(:) ! (lid)
         integer, intent(out) :: ierr

         real(dp), pointer, dimension(:) :: &
            weights, factor_times, solve_times
         integer :: time0, time1
         type (d_and_c_info), pointer :: d
         integer :: num_parts, i, k, kblk, op_err
         type (part), pointer :: p, pp1
         real(fltp), pointer :: p3(:,:,:)
         integer, parameter :: nrhs = 1
         
         include 'formats.dek'
         
         ierr = 0
         if (.not. have_omp_num_threads) then
            write(*,*) 'error: have called d_and_c_solve before calling d_and_c_factor'
            ierr = -1
            return
         end if

         call get_ptr(id,d,ierr)
         if (ierr /= 0) return
         
         if (nblk /= d% nblk .or. mblk /= d% mblk) then
            write(*,*) 'bad nblk or mblk in arg for d_and_c_solve', nblk, d% nblk, mblk, d% mblk
            ierr = -1
            return
         end if
         
         num_parts = d% num_p
         if (num_parts <= 0) then
            !write(*,*) 'just do thomas solve', nblk, omp_num_threads
            call set_ptr_3(p3, xy, mblk, nrhs, nblk)
            call thomas_solve( &
               d% thomas_id, mblk, nblk, d% lblk, d% dblk, d% ublk, p3, nrhs, d% ipiv, &
               lrd, rpar_decsol, lid, ipar_decsol, ierr)
            return
         end if

         ! check the work array size before taking any sections of it
         i = 1 + 3*num_parts
         if (lrd < i) then
            write(*,3) 'mtx d_and_c_solve: lrd too small', lrd, i
            ierr = -1
            return
         end if

         i = 1 ! for reporting max error in load balancing
         call set_pointer_1(weights, rpar_decsol(i+1:i+num_parts), num_parts)
         i = i+num_parts
         call set_pointer_1(factor_times, rpar_decsol(i+1:i+num_parts), num_parts)
         i = i+num_parts
         call set_pointer_1(solve_times, rpar_decsol(i+1:i+num_parts), num_parts)

         call partition_xy
         
!$OMP PARALLEL DO PRIVATE(i,p,p3,time0,time1,kblk,op_err)
         do i = 1, num_parts ! solve for y_bar
            if (do_refine_weights) call system_clock(time0)
            p => d% p(i)
            ! solve a(i)*y_bar(i) = x(i) ! x holds rhs at this point
            ! note: x, y, and y_bar just different names for the same vector
            kblk = p% kblk
            call set_ptr_3(p3, p% y_bar, mblk, nrhs, kblk)
            op_err = 0
            call thomas_solve( &
               p% thomas_id, mblk, kblk, p% lblk, p% dblk, p% ublk, p3, nrhs, p% ipiv, &
               lrd, rpar_decsol, lid, ipar_decsol, op_err)
            if (op_err /= 0) then
               !write(*,*) 'd_and_c_solve failed in thomas_solve for y_bar', i
               ierr = op_err
               cycle
            end if
            if (do_refine_weights) then
               call system_clock(time1)
               solve_times(i) = solve_times(i) + dble(time1 - time0)
            end if
         end do
!$OMP END PARALLEL DO            

         ! stop before building the coupling system from y_bar data that may not be valid
         if (ierr /= 0) return
         
!$OMP PARALLEL DO PRIVATE(i,p,pp1,kblk,op_err)
         do i = 1, num_parts-1 ! calculate rhs of coupling system
            p => d% p(i)
            pp1 => d% p(i+1)
            kblk = p% kblk
            ! d% psi_bar(:,i) = p% psi - p% gamma*last(p% y_bar) - p% phi*first(pp1% y_bar)
            d% psi_bar(:,i) = p% psi
            call do_dgemv(p% gamma, p% y_bar(:,kblk), d% psi_bar(:,i))
            call do_dgemv(p% phi, pp1% y_bar(:,1), d% psi_bar(:,i))
         end do
!$OMP END PARALLEL DO
               
         ! use thomas to solve the coupling system
         d% zeta => d% psi_bar ! psi_bar is the rhs, zeta is the solution.  nrhs = 1
         call set_ptr_3(p3, d% zeta, mblk, nrhs, num_parts-1)
         call thomas_solve( &
            d% thomas_id, mblk, num_parts-1, d% gamma_bar, &
            d% lambda_bar, d% phi_bar, p3, nrhs, d% ipiv, &
            lrd, rpar_decsol, lid, ipar_decsol, ierr)
         if (ierr /= 0) then
            !write(*,*) 'd_and_c_solve failed in thomas_solve for coupling system'
            return
         end if
         
!$OMP PARALLEL DO PRIVATE(i,k,p)
         do i = 1, num_parts
            p => d% p(i)
            do k = 1, p% kblk
               ! x(i) = y_bar(i) - u_bar(i)*zeta(i-1) - v_bar(i)*zeta(i)
               ! x and y_bar are synonyms, so no need to copy y_bar(i) to x(i)
               if (i < num_parts) &
                  call do_dgemv(p% v_bar(:,:,k), d% zeta(:,i), p% x(:,k))
               if (i > 1) &
                  call do_dgemv(p% u_bar(:,:,k), d% zeta(:,i-1), p% x(:,k))
            end do
            ! copy the coupling solution into the separator row of xy that follows this part
            if (i < num_parts) p% zeta = d% zeta(:,i)
         end do
!$OMP END PARALLEL DO
         
         
         contains

            
         ! update vector c using the mblk by mblk block a and vector b via my_gemv.
         ! the comments at the call sites describe the update as c = c - a*b (my_gemv is not visible here).
         subroutine do_dgemv(a,b,c)
            real(fltp), dimension(:,:) :: a
            real(fltp), dimension(:) :: b,c
            call my_gemv(mblk, mblk, a, mblk, b, c)
         end subroutine do_dgemv
         
         
         ! Point the vector components of each part at its section of xy:
         ! x, y and y_bar at columns first:last, and psi and zeta at the separator
         ! column last+1 (for every part except the last).
         subroutine partition_xy 
            integer :: i, first, last
            do i = 1, num_parts
               p => d% p(i)
               first = p% first
               last = p% last
               p% y => xy(:,first:last)
               p% x => p% y
               p% y_bar => p% y
               if (i < num_parts) then
                  p% psi => xy(:,last+1)
                  p% zeta => p% psi
               end if
            end do

            if (dbg) then ! print the pieces of the rhs for each part
               do i=1, num_parts
                  p => d% p(i)
                  write(*,*) 'y', i
                  call set_ptr_3(p3, p% y, mblk, 1, p% kblk)
                  call show_vector(p3)
                  write(*,*)
                  if (i < num_parts) then
                     write(*,*) 'psi', i
                     call set_ptr_3(p3, p% psi, mblk, 1, 1)
                     call show_vector(p3)
                     write(*,*)
                  end if
               end do
               write(*,*)
               !stop
            end if

         end subroutine partition_xy
         
           
      end subroutine d_and_c_solve
      
      
      ! Release the resources held by the divide-and-conquer handle `id`.
      !
      ! Steps:
      !   1. optionally (do_refine_weights) update the load-balancing weights
      !      stored in rpar_decsol / ipar_decsol from the recorded timings,
      !   2. call thomas_dealloc for every part and for the top-level solver,
      !   3. mark the handle as free again.
      !
      ! Arguments:
      !   id            handle previously obtained for this solver
      !   thomas_*      block Thomas solver procedures; only thomas_dealloc
      !                 is called by this routine
      !   lrd, lid      declared sizes of rpar_decsol and ipar_decsol
      !   rpar_decsol   real work array holding weights and timings
      !   ipar_decsol   integer work array holding the per-part block counts
      !   ierr          0 on success, nonzero if any step reported an error
      !
      ! Cleanup always runs to completion once the handle has been resolved:
      ! an error from refine_weights or from one thomas_dealloc call is
      ! remembered in ierr but does not prevent the remaining releases, and
      ! a later successful call never resets ierr back to zero.
      subroutine d_and_c_dealloc( &
            id, &
            thomas_handle, thomas_factor, thomas_solve, thomas_dealloc, thomas_stats, &
            lrd, rpar_decsol, lid, ipar_decsol, &
            ierr)
         integer, intent(in) :: id
         interface
#ifdef DBLE
            include 'thomas_block_procs_dble.dek'
#else
            include 'thomas_block_procs_quad.dek'
#endif
         end interface
         integer, intent(in) :: lrd, lid
         real(dp), pointer, intent(inout) :: rpar_decsol(:) ! (lrd)
         integer, pointer, intent(inout) :: ipar_decsol(:) ! (lid)
         integer, intent(out) :: ierr
         
         type (d_and_c_info), pointer :: d
         integer :: i, op_err, num_parts
         integer :: time0, time1, clock_rate
         type (part), pointer :: p         
         include 'formats.dek'  
         
         !write(*,*) 'enter d_and_c_dealloc'
                
         if (do_dealloc_subtiming) call system_clock(time0)
         if (.not. have_omp_num_threads) then
            write(*,*) 'error: have called d_and_c_dealloc before calling d_and_c_factor'
            ierr = -1
            return
         end if
         ierr = 0
         call get_ptr(id,d,ierr)
         if (ierr /= 0) return
         num_parts = d% num_p
         if (num_parts <= 0) then
            !write(*,2) 'd_and_c_dealloc num_parts', num_parts
         else
            if (do_refine_weights) then
               ! a failure here is reported through ierr, but the solvers and
               ! the handle below are still released so nothing is leaked
               op_err = 0
               call refine_weights(op_err)
               if (op_err /= 0) ierr = op_err
            end if
            
            do i = 1, num_parts
               p => d% p(i)
               op_err = 0
               call thomas_dealloc(p% thomas_id, op_err)
               if (op_err /= 0) ierr = op_err
            end do         
!            do i = 1, num_parts
!               p => d% p(i)
!               deallocate(p% ipiv, p% u_bar, p% v_bar)
!            end do         
!            deallocate(d% p, d% lambda_bar, d% gamma_bar, d% phi_bar, d% psi_bar)
         end if
         
!         deallocate(d% ipiv)
         
         ! use op_err so that an error recorded above is not overwritten
         op_err = 0
         call thomas_dealloc(d% thomas_id, op_err)
         if (op_err /= 0) ierr = op_err
         if (do_dealloc_subtiming) then
            call system_clock(time1,clock_rate)
            dealloc_subtime_thomas = dealloc_subtime_thomas + dble(time1-time0)/clock_rate
            time0 = time1
         end if
         
         call do_free_handle(id)         

         if (do_dealloc_subtiming) then
            call system_clock(time1,clock_rate)
            dealloc_subtime_d_and_c = dealloc_subtime_d_and_c + dble(time1-time0)/clock_rate
         end if

         if (show_dealloc_subtiming) then
            write(*,*)
            write(*,'(a40,f12.4)') 'dealloc_subtime_thomas', dealloc_subtime_thomas
            write(*,'(a40,f12.4)') 'dealloc_subtime_d_and_c', dealloc_subtime_d_and_c
            write(*,'(a40,f12.4)') 'sum', dealloc_subtime_thomas + dealloc_subtime_d_and_c
            write(*,*)
         end if
         
         !write(*,*) 'exit d_and_c_dealloc'


         contains
         
         ! Adjust the load-balancing data from the timings of the last
         ! factor and solve.
         !
         ! Layout used in the work arrays (num_parts entries per section):
         !   rpar_decsol(1)                 reported max relative deviation
         !   rpar_decsol(2 : ...)           weights, factor_times, solve_times
         !   ipar_decsol(num_ipars+1 : ...) kblks (blocks per part)
         !
         ! The active strategy moves dk blocks (at least one) from the part
         ! with the largest total time to the part with the smallest total
         ! time and then sets the weights proportional to the block counts.
         !
         ! ierr is 0 on success and -1 if lrd or lid is too small.
         ! If no positive timings are available, the weights and block counts
         ! are left untouched and rpar_decsol(1) is set to 0.
         subroutine refine_weights(ierr)
            integer, intent(out) :: ierr
            real(fltp) :: sum_times, avg_time, err_max, err, &
               corr_factor, times(num_parts), err_min, lim_low_time, lim_high_time
            real(dp), pointer, dimension(:) :: &
               weights, factor_times, solve_times
            integer :: i, i_max, i_min, sum_kblks, knew, dk
            integer :: lrd_needed, lid_needed
            integer, pointer :: kblks(:)
            
            logical :: adjust_low_times, adjust_high_times
            
            include 'formats.dek'

            ! intent(out): must be defined on every path, including success
            ierr = 0

            ! check the declared sizes before creating any views into the
            ! work arrays
            lrd_needed = 1 + 3*num_parts
            if (lrd < lrd_needed) then
               write(*,3) 'mtx d_and_c_dealloc: lrd too small', lrd, lrd_needed
               ierr = -1
               return
            end if
            lid_needed = num_ipars + num_parts
            if (lid < lid_needed) then
               write(*,3) 'mtx d_and_c_dealloc: lid too small', lid, lid_needed
               ierr = -1
               return
            end if

            i = 1 ! for reporting max error in load balancing
            call set_pointer_1(weights, rpar_decsol(i+1:i+num_parts), num_parts)
            i = i+num_parts
            call set_pointer_1(factor_times, rpar_decsol(i+1:i+num_parts), num_parts)
            i = i+num_parts
            call set_pointer_1(solve_times, rpar_decsol(i+1:i+num_parts), num_parts)
            
            i = num_ipars
            call set_int_pointer_1(kblks, ipar_decsol(i+1:i+num_parts), num_parts)
            
            sum_times = 0
            do i = 1, num_parts
               times(i) = factor_times(i) + solve_times(i)
               sum_times = sum_times + times(i)
            end do
            avg_time = sum_times/num_parts
            
            ! start from part 1 so that i_max and i_min are always defined,
            ! even when no part is strictly slower than the average
            ! (for example when num_parts is 1 or all times are equal)
            i_max = 1; i_min = 1
            err_max = 0; err_min = 1d99
            do i = 1, num_parts
               err = times(i) - avg_time
               if (err > err_max) then
                  i_max = i; err_max = err
               end if
               if (err < err_min) then
                  i_min = i; err_min = err
               end if
            end do

            ! without positive timings the ratios below would divide by zero;
            ! leave the current partition as it is
            if (avg_time <= 0 .or. times(i_max) <= 0) then
               rpar_decsol(1) = 0
               return
            end if


            ! Only the first branch is active.  The remaining branches are
            ! alternative strategies that are currently switched off.
            ! NOTE(review): in the disabled neighbor-averaging branches the
            ! one-sided index ranges look swapped (an index at num_parts
            ! reaches i+1, an index at 1 reaches i-1); check before enabling.
            if (.true.) then

            !else if (.false.) then
            
               corr_factor = 0.1
               knew = int(0.5d0 + dble(kblks(i_max))*(1d0 + corr_factor*(avg_time/times(i_max) - 1d0)))
               dk = max(1,kblks(i_max) - knew)
               kblks(i_max) = kblks(i_max) - dk
               kblks(i_min) = kblks(i_min) + dk
               sum_kblks = sum(kblks(:))
               weights(:) = dble(kblks(:))/sum_kblks

            else if (.false.) then
            
               if (i_max > 1 .and. i_max < num_parts) then
                  weights(i_max-1:i_max+1) = sum(weights(i_max-1:i_max+1))/3d0
               else if (i_max > 1) then
                  weights(i_max:i_max+1) = sum(weights(i_max:i_max+1))/2d0
               else if (i_max < num_parts) then
                  weights(i_max-1:i_max) = sum(weights(i_max-1:i_max))/2d0
               end if
               
            else if (.false.) then
            
               if (i_min > 1 .and. i_min < num_parts) then
                  weights(i_min-1:i_min+1) = sum(weights(i_min-1:i_min+1))/3d0
               else if (i_min > 1) then
                  weights(i_min:i_min+1) = sum(weights(i_min:i_min+1))/2d0
               else if (i_min < num_parts) then
                  weights(i_min-1:i_min) = sum(weights(i_min-1:i_min))/2d0
               end if
               
            else
            
               corr_factor = 0.5
               adjust_low_times = .true.
               adjust_high_times = .true.
               lim_low_time = times(i_min)
               lim_high_time = times(i_max)
               do i = 1, num_parts
                  if ((times(i) <= lim_low_time .and. adjust_low_times) .or. &
                      (times(i) >= lim_high_time .and. adjust_high_times)) then
                     weights(i) = weights(i)*(1d0 + corr_factor*(avg_time/times(i) - 1d0))
                     !write(*,2) 'adjust weight fraction', i, (1d0 + corr_factor*(avg_time/times(i) - 1d0))
                  end if
               end do
            
            end if
            
            
            weights(:) = weights(:)/sum(weights(:))
            rpar_decsol(1) = err_max/avg_time ! report max abs relative deviation
            
            write(*,*) 'err_max/avg_time', i_max, err_max/avg_time

            ! the diagnostic output below is switched off by this return
            return


            if (ipar_decsol(i_cnt) == 1) &
               write(*,'(8x,2(2a8,2a16),99a16)') &
                  'max', 'k4max', 'max time', 'max/avg', &
                  'min', 'k4min', 'min time', 'min/avg', &
                  'avg time', 'rel max', 'rel min'
            if (ipar_decsol(i_cnt) > 1) & ! don't even output the 1st time since it included transients
               write(*,'(i8,2(2i8,2f16.6),99f16.6)') ipar_decsol(i_cnt), &
                  i_max, kblks(i_max), times(i_max), times(i_max)/avg_time, &
                  i_min, kblks(i_min), times(i_min), times(i_min)/avg_time, &
                  avg_time, (times(i_max) - avg_time)/avg_time, (avg_time - times(i_min))/avg_time
            
            
            write(*,'(30x,99a12)') 'weight', 'time total', 'factor', 'solve'
            do i=1, num_parts
               if (i > 1 .and. i < num_parts) cycle
               p => d% p(i)
               write(*,'(a20,i4,i6,99f12.4)') 'weight', i, &
                  p% kblk, weights(i), times(i), factor_times(i), solve_times(i)
            end do
            write(*,*)
            
            !if (rpar_decsol(1) > 0.2) stop
            
         end subroutine refine_weights

      end subroutine d_and_c_dealloc

      
      ! One-time initialization of the module handle table.
      !
      ! Reads the thread count from the OMP_NUM_THREADS environment variable
      ! into omp_num_threads, then for every entry of handles: records its
      ! index, marks it unused, nullifies its work pointers and allocates
      ! parts_per_thread*omp_num_threads parts with nullified pointers.
      !
      ! Calls after the first successful one return immediately.  If
      ! OMP_NUM_THREADS cannot be read as an integer, a hint is printed and
      ! the program stops with code 1.
      subroutine d_and_c_block_init
         integer :: i, j, ierr, num_parts
         character (len=255) :: omp_num_threads_str
         type (d_and_c_info), pointer :: d
         type (part), pointer :: p
         if (have_omp_num_threads) return
         ! ierr is tested after the critical section on every path, so it
         ! must be defined even when another thread already did the setup
         ierr = 0
!$omp critical (d_and_c_init)
         if (.not. have_omp_num_threads) then
            call get_environment_variable("OMP_NUM_THREADS", omp_num_threads_str)
            read(omp_num_threads_str,*,iostat=ierr) omp_num_threads
            if (ierr /= 0) then
               omp_num_threads = 1
            end if
            num_parts = parts_per_thread*omp_num_threads
            do i = 1, max_handles
               d => handles(i)
               d% handle = i
               d% in_use = .false.
               nullify(d% lambda_bar)
               nullify(d% gamma_bar)
               nullify(d% phi_bar)
               nullify(d% psi_bar)
               nullify(d% ipiv)
               allocate(d% p(num_parts))
               do j = 1, num_parts
                  p => d% p(j)
                  nullify(p% ipiv)
                  nullify(p% u_bar)
                  nullify(p% v_bar)
               end do
            end do
            ! set the flag last: the early return at the top tests it
            ! without entering the critical section, so it should only
            ! become true once the handle table has been set up
            have_omp_num_threads = .true.
         end if
!$omp end critical (d_and_c_init)
         if (ierr /= 0) then
            write(*,*)
            write(*,*)
            write(*,*)
            write(*,*)
            write(*,'(a)') 'Please set the OMP_NUM_THREADS environment variable.'
            write(*,'(a)') 'e.g., on a 2 core machine, you might do'
            write(*,'(a)') '  setenv OMP_NUM_THREADS 2'
            write(*,'(a)') 'or'
            write(*,'(a)') '  export OMP_NUM_THREADS=2'
            write(*,*)
            write(*,*)
            write(*,*)
            write(*,*)
            stop 1
         end if
      end subroutine d_and_c_block_init

      
      ! Reserve the first unused entry of the module handle table.
      !
      ! Returns the index of the reserved entry, or -1 if every entry is
      ! already in use.  ierr is 0 on success and -1 (after a call to alert)
      ! when no entry is free or when the stored handle number of the chosen
      ! entry does not match its index.
      ! The search and the in_use update happen inside a named critical
      ! section.
      integer function d_and_c_alloc(ierr)
         use alert_lib,only:alert
         integer, intent(out) :: ierr
         integer :: slot, found

         if (.not. have_omp_num_threads) call d_and_c_block_init

         ierr = 0
         found = -1
!$omp critical (d_and_c_handle)
         do slot = 1, max_handles
            if (handles(slot)% in_use) cycle
            handles(slot)% in_use = .true.
            found = slot
            exit
         end do
!$omp end critical (d_and_c_handle)
         d_and_c_alloc = found

         if (found == -1) then
            ierr = -1
            call alert(ierr, 'no available d_and_c handle')
            return
         end if

         ! sanity check: the entry must carry its own index as handle number
         if (handles(found)% handle /= found) then
            ierr = -1
            call alert(ierr, 'broken handle for d_and_c')
            return
         end if
      end function d_and_c_alloc
            
      
      ! Mark the given entry of the handle table as no longer in use.
      ! A handle outside 1..max_handles is silently ignored.
      subroutine do_free_handle(handle)
         integer, intent(in) :: handle
         if (handle < 1 .or. handle > max_handles) return
         handles(handle)% in_use = .false.
      end subroutine do_free_handle
      

      ! Look up the table entry that belongs to a handle number.
      !
      !   handle   index into the module handle table
      !   d        on success, associated with handles(handle);
      !            left unchanged if the handle is out of range
      !   ierr     0 on success, -1 (after a call to alert) if handle is
      !            outside 1..max_handles
      subroutine get_ptr(handle,d,ierr)
         use alert_lib,only:alert
         integer, intent(in) :: handle
         type (d_and_c_info), pointer :: d
         integer, intent(out):: ierr         
         logical :: in_range

         in_range = (handle >= 1 .and. handle <= max_handles)
         if (in_range) then
            ierr = 0
            d => handles(handle)
         else
            ierr = -1
            call alert(ierr,'invalid d_and_c handle')
         end if
      end subroutine get_ptr
      
      
      ! Debug helper: print the elements of a vector, one value per line,
      ! in 1pe12.4 format.
      subroutine show_column(b)
         real(fltp), intent(in) :: b(:)
         integer :: row, nrows
         nrows = size(b,dim=1)
         do row = 1, nrows
            write(*,fmt='(1pe12.4)') b(row)
         end do
      end subroutine show_column
      
      
      ! Debug helper: print a matrix row by row.  The values of one row are
      ! written without advancing in 1pe12.4 format, and the row is then
      ! closed by an empty list-directed write.
      subroutine show_block(b)
         real(fltp), intent(in) :: b(:,:)
         integer :: row, col, nrows, ncols
         nrows = size(b,dim=1)
         ncols = size(b,dim=2)
         do row = 1, nrows
            do col = 1, ncols
               write(*,fmt='(1pe12.4)',advance='no') b(row,col)
            end do
            write(*,*)
         end do
      end subroutine show_block
      
      
      ! Debug helper: print a rank-3 array as a sequence of matrices, one
      ! show_block call per index of the third dimension.
      subroutine show_vector(v)
         real(fltp), intent(in) :: v(:,:,:) ! (mblk,mblk,kblk)
         integer :: iblk
         do iblk = 1, size(v,dim=3)
            call show_block(v(:,:,iblk))
         end do
      end subroutine show_vector

      
      ! Report the work array sizes required by this solver.
      !
      !   nvar, nz   accepted for interface compatibility; not used here
      !   lrd        required length of the real work array
      !              (three entries per part plus one)
      !   lid        required length of the integer work array
      !              (one entry per part plus two)
      !
      ! The number of parts is parts_per_thread*omp_num_threads, so the
      ! module is initialized first to make sure omp_num_threads is set.
      subroutine d_and_c_work_sizes(nvar,nz,lrd,lid)
         integer, intent(in) :: nvar,nz
         integer, intent(out) :: lrd,lid
         integer :: n_parts
         call d_and_c_block_init
         n_parts = parts_per_thread*omp_num_threads
         lrd = 1 + 3*n_parts
         lid = 2 + n_parts
      end subroutine d_and_c_work_sizes


#ifdef DBLE
      end module d_and_c_block_dble
#else
      end module d_and_c_block_quad
#endif
