! ***********************************************************************
!
!   Copyright (C) 2011  Bill Paxton
!
!   This file is part of MESA.
!
!   MESA is free software; you can redistribute it and/or modify
!   it under the terms of the GNU Library General Public License as published
!   by the Free Software Foundation; either version 2 of the License, or
!   (at your option) any later version.
!
!   MESA is distributed in the hope that it will be useful,
!   but WITHOUT ANY WARRANTY; without even the implied warranty of
!   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
!   GNU Library General Public License for more details.
!
!   You should have received a copy of the GNU Library General Public License
!   along with this software; if not, write to the Free Software
!   Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
!
! ***********************************************************************


#ifdef DBLE
      module test_block_tri_dble
#else
      module test_block_tri_quad
#endif
      
      use mtx_lib
      use mtx_def

#ifdef DBLE
      use const_def, only: dp
      use utils_lib, only: is_bad_num
#define is_bad is_bad_num
#else
      use const_def, only: qp
      use utils_lib, only: is_bad_quad
#define is_bad is_bad_quad
#endif

      
      implicit none

      ! fltp is the real kind used for every matrix block, vector, timer,
      ! and error measure in this test: dp when the file is preprocessed
      ! with DBLE defined, qp (quad precision) otherwise.
#ifdef DBLE
      integer, parameter :: fltp = dp
#else
      integer, parameter :: fltp = qp
#endif
      
      ! Value passed as the caller_id argument of every decsolblk call
      ! (and of block_dble_refine1) made by this test.
      integer, parameter :: caller_id = 0
      
      
      contains
      
      
#ifdef DBLE

      !> Double-precision build entry point: run every double-precision
      !> block tridiagonal solver once with release settings and no timing.
      subroutine do_test_block_tri_dble
         call test_block_thomas_dble(do_timing=.false., for_release=.true.)
         call test_block_dc_mt_dble(do_timing=.false., for_release=.true.)
         call test_block_thomas_klu(do_timing=.false., for_release=.true.)
         call test_block_dc_mt_klu(do_timing=.false., for_release=.true.)
      end subroutine do_test_block_tri_dble
      
      
      !> Run test_block with the block_thomas_klu solver.
      subroutine test_block_thomas_klu(do_timing, for_release)
         logical, intent(in) :: do_timing   ! collect and print timings
         logical, intent(in) :: for_release ! use release data file/settings
         call test_block(which_decsol_option=block_thomas_klu, &
            do_timing_in=do_timing, for_release=for_release)
      end subroutine test_block_thomas_klu
      
      
      !> Run test_block with the block_thomas_dble solver.
      subroutine test_block_thomas_dble(do_timing, for_release)
         logical, intent(in) :: do_timing   ! collect and print timings
         logical, intent(in) :: for_release ! use release data file/settings
         call test_block(which_decsol_option=block_thomas_dble, &
            do_timing_in=do_timing, for_release=for_release)
      end subroutine test_block_thomas_dble
      
      
      !> Run test_block with the block_dc_mt_dble solver.
      subroutine test_block_dc_mt_dble(do_timing, for_release)
         logical, intent(in) :: do_timing   ! collect and print timings
         logical, intent(in) :: for_release ! use release data file/settings
         call test_block(which_decsol_option=block_dc_mt_dble, &
            do_timing_in=do_timing, for_release=for_release)
      end subroutine test_block_dc_mt_dble

      
      !> Run test_block with the block_dc_mt_klu solver.
      subroutine test_block_dc_mt_klu(do_timing, for_release)
         logical, intent(in) :: do_timing   ! collect and print timings
         logical, intent(in) :: for_release ! use release data file/settings
         call test_block(which_decsol_option=block_dc_mt_klu, &
            do_timing_in=do_timing, for_release=for_release)
      end subroutine test_block_dc_mt_klu

#else

      !> Quad-precision build entry point: run every quad-precision
      !> block tridiagonal solver once with release settings and no timing.
      subroutine do_test_block_tri_quad
         call test_block_thomas_quad(do_timing=.false., for_release=.true.)
         call test_block_dc_mt_quad(do_timing=.false., for_release=.true.)
      end subroutine do_test_block_tri_quad
      
      
      !> Run test_block with the block_thomas_quad solver.
      subroutine test_block_thomas_quad(do_timing, for_release)
         logical, intent(in) :: do_timing   ! collect and print timings
         logical, intent(in) :: for_release ! use release data file/settings
         call test_block(which_decsol_option=block_thomas_quad, &
            do_timing_in=do_timing, for_release=for_release)
      end subroutine test_block_thomas_quad
      
      
      !> Run test_block with the block_dc_mt_quad solver.
      subroutine test_block_dc_mt_quad(do_timing, for_release)
         logical, intent(in) :: do_timing   ! collect and print timings
         logical, intent(in) :: for_release ! use release data file/settings
         call test_block(which_decsol_option=block_dc_mt_quad, &
            do_timing_in=do_timing, for_release=for_release)
      end subroutine test_block_dc_mt_quad

#endif
      
      
      
      
      !> Run one block tridiagonal solver test.
      !>
      !> Reads a block tridiagonal matrix from a test data file and builds
      !> the right-hand side brhs = A*xcorrect for a known vector xcorrect.
      !> It then repeats the following nrep times: restore the original
      !> blocks, then factor and solve with the solver selected by
      !> which_decsol_option (see solve_blocks). Finally it checks x against
      !> xcorrect and, when timing, prints a timing summary.
      !>
      !> which_decsol_option : solver id constant (presumably from mtx_def).
      !>     Handled values: block_dc_mt_dble, block_thomas_dble,
      !>     block_dc_mt_klu, block_thomas_klu (DBLE build) or
      !>     block_dc_mt_quad, block_thomas_quad (quad build).
      !>     - If decsol_option_str reports an error, the routine returns
      !>       immediately.
      !>     - Any other unhandled value stops with exit code 1.
      !> do_timing_in        : collect and print timings. This also
      !>     disables iterative refinement.
      !> for_release         : read 'block_tri.data' instead of
      !>     'block_tri_12.data', use 50 repetitions when timing, and take
      !>     the release refinement path in solve_blocks.
      subroutine test_block(which_decsol_option, do_timing_in, for_release)

         use omp_lib

         integer, intent(in) :: which_decsol_option
         logical, intent(in) :: do_timing_in, for_release
         
         integer :: nrep
         real(fltp), pointer :: lblk(:,:,:), dblk(:,:,:), ublk(:,:,:) ! (nvar,nvar,nz)
         real(fltp), pointer :: l_init(:,:,:), d_init(:,:,:), u_init(:,:,:) ! (nvar,nvar,nz)
         real(fltp), pointer :: x(:,:), xcorrect(:,:), brhs(:,:), work(:,:) ! (nvar,nz)
         integer, pointer :: ipiv(:,:) ! (nvar,nz)
          
         real(dp), pointer :: rpar_decsol(:) ! (lrd)
         integer, pointer :: ipar_decsol(:) ! (lid)
         ! Wall-clock seconds per phase, accumulated over all repetitions.
         real(fltp) :: time_factor, time_solve, time_refine, time_dealloc, sum_times
         ! Factorization sub-times gathered from d_and_c_block_dble.
         ! NOTE(review): the factor_As_thm_* variables are never referenced.
         real(fltp) :: sum_factor, setup, factor_As, &
            factor_As_thm_fac_avg, factor_As_thm_slv_avg, &
            factor_As_thm_fac_max, factor_As_thm_slv_max, &
            setup_coupling, factor_coupling

         integer :: i, j, k, ierr, lid, lrd, nvar, nz, omp_num_threads, rep
         logical :: do_timing, do_refine, use_given_weights
         character (len=255) :: fname, which_decsol_str
         
         include 'formats.dek'
         do_timing = do_timing_in
         ! Iterative refinement exists only in the double-precision build,
         ! and it is skipped whenever timings are being collected.
#ifdef DBLE
         do_refine = .not. do_timing
#else
         do_refine = .false.
#endif
         
         ierr = 0
         
         ! Solver name string; it is used in the timing report header below.
         call decsol_option_str(which_decsol_option, which_decsol_str, ierr)
         if (ierr /= 0) return
         
         ! Choose the matrix data file.
         if (for_release) then
            fname = 'block_tri.data'
         else
            !fname = 'block_tri.data'
            fname = 'block_tri_12.data'
            !fname = 'block_tri_72.data'
         end if
         ! Timing runs repeat the factor/solve nrep times and turn on the
         ! d_and_c sub-timers. Non-timing runs make a single pass.
         if (do_timing) then
            if (for_release) then
               nrep = 50
            else
               nrep = 1 !20 !50
            end if
            omp_num_threads = omp_get_max_threads()
            call turn_on_d_and_c_timing
         else
            nrep = 1
         end if
         !write(*,2) 'nrep', nrep

         time_factor=0; time_solve=0; time_refine=0; time_dealloc=0
         
         ! read_testfile is expected to set nvar, nz and lblk/dblk/ublk.
         ! None of these is assigned anywhere else in this routine.
         if (do_timing) write(*,*) 'read test file'
         call read_testfile(fname)
         !call xread_testfile(fname)
         if (do_timing) write(*,3) trim(fname), nvar, nz

         ! Announce the selected solver and query the sizes of its real
         ! (lrd) and integer (lid) work arrays.
#ifdef DBLE
         if (which_decsol_option == block_dc_mt_dble) then
            write(*,*) 'test_block_dc_mt_dble'
            call block_dc_mt_dble_work_sizes(nvar,nz,lrd,lid)
            
         else if (which_decsol_option == block_thomas_dble) then
            write(*,*) 'block_thomas_dble'
            call block_thomas_dble_work_sizes(nvar,nz,lrd,lid)
            
         else if (which_decsol_option == block_dc_mt_klu) then
            write(*,*) 'test_block_dc_mt_klu'
            call block_dc_mt_klu_work_sizes(nvar,nz,lrd,lid)
            
         else if (which_decsol_option == block_thomas_klu) then
            write(*,*) 'block_thomas_klu'
            call block_thomas_klu_work_sizes(nvar,nz,lrd,lid)
            
#else
         if (which_decsol_option == block_dc_mt_quad) then
            write(*,*) 'test_block_dc_mt_quad'
            call block_dc_mt_quad_work_sizes(nvar,nz,lrd,lid)
            
         else if (which_decsol_option == block_thomas_quad) then
            write(*,*) 'block_thomas_quad'
            call block_thomas_quad_work_sizes(nvar,nz,lrd,lid)
#endif
         else
            write(*,*) 'bad value for which_decsol_option in test_block'
            stop 1
         end if
         
         ! Allocate the solver work arrays, the vectors, and the pristine
         ! copies of the blocks.
         allocate( &
            rpar_decsol(lrd), ipar_decsol(lid), x(nvar,nz), xcorrect(nvar,nz), &
            brhs(nvar,nz), ipiv(nvar,nz), work(nvar,nz), &
            l_init(nvar,nvar,nz), d_init(nvar,nvar,nz), u_init(nvar,nvar,nz), stat=ierr)
         if (ierr /= 0) then
            write(*,*) 'failed in alloc'
            stop 1
         end if
         ! Save the blocks as read. Every repetition below restores
         ! lblk/dblk/ublk from these copies before calling the solver.
         do k=1,nz
            do i=1,nvar
               do j=1,nvar
                  l_init(j,i,k) = lblk(j,i,k)
                  d_init(j,i,k) = dblk(j,i,k)
                  u_init(j,i,k) = ublk(j,i,k)
               end do
            end do
         end do
         
         ! Set the known solution xcorrect and the matching right-hand side
         ! brhs = A*xcorrect.
         call set_xcorrect
         call set_brhs(lblk, dblk, ublk)
         
         if (do_timing) write(*,*) 'start timing'

         do rep = 1, nrep
         
            ! Restore the original blocks; the previous pass may have
            ! overwritten them.
            do k=1,nz
               do i=1,nvar
                  do j=1,nvar
                     lblk(j,i,k) = l_init(j,i,k)
                     dblk(j,i,k) = d_init(j,i,k)
                     ublk(j,i,k) = u_init(j,i,k)
                  end do
               end do
            end do

            ! Show the d_and_c sub-timing details only on the final repetition.
            if (do_timing .and. rep == nrep) call turn_on_d_and_c_show_timing
            
            ! After the first pass, let block_dc_mt_klu factor with the given
            ! weights (iop=3 in solve_blocks).
            use_given_weights = (rep > 1)

            ! Each build passes its own decsolblk routine plus a null
            ! placeholder for the other precision. solve_blocks calls only
            ! the routine that matches the build.
#ifdef DBLE
            if (which_decsol_option == block_dc_mt_dble) then
               call solve_blocks( &
                  use_given_weights, lblk, dblk, ublk, block_dc_mt_dble_decsolblk, null_decsolblk_quad)               
            else if (which_decsol_option == block_dc_mt_klu) then
               call solve_blocks( &
                  use_given_weights, lblk, dblk, ublk, block_dc_mt_klu_decsolblk, null_decsolblk_quad)               
            else if (which_decsol_option == block_thomas_dble) then
               call solve_blocks( &
                  use_given_weights, lblk, dblk, ublk, block_thomas_dble_decsolblk, null_decsolblk_quad)               
            else if (which_decsol_option == block_thomas_klu) then
               call solve_blocks( &
                  use_given_weights, lblk, dblk, ublk, block_thomas_klu_decsolblk, null_decsolblk_quad)
#else
            if (which_decsol_option == block_dc_mt_quad) then
               call solve_blocks( &
                  use_given_weights, lblk, dblk, ublk, null_decsolblk, block_dc_mt_quad_decsolblk)              
            else if (which_decsol_option == block_thomas_quad) then
               call solve_blocks( &
                  use_given_weights, lblk, dblk, ublk, null_decsolblk, block_thomas_quad_decsolblk)
#endif            
            
            else
            
               write(*,*) 'missing case for which_decsol_option', which_decsol_option
               stop 1
               
            end if
            
         end do
         
         ! Compare x from the last repetition against xcorrect.
         call check_x

#ifdef DBLE
         if (which_decsol_option == block_dc_mt_dble) then
               write(*,*) 'done test_block_dc_mt_dble'
         else if (which_decsol_option == block_thomas_dble) then
               write(*,*) 'done block_thomas_dble'
         else if (which_decsol_option == block_dc_mt_klu) then
               write(*,*) 'done test_block_dc_mt_klu'
         else if (which_decsol_option == block_thomas_klu) then
               write(*,*) 'done block_thomas_klu'
         end if
#else
         if (which_decsol_option == block_dc_mt_quad) then
               write(*,*) 'done test_block_dc_mt_quad'
         else if (which_decsol_option == block_thomas_quad) then
               write(*,*) 'done block_thomas_quad'
         end if
#endif

         ! Timing report: top-level phase times, then d_and_c factor sub-times.
         ! NOTE(review): time_refine is included in sum_times but has no
         ! column of its own.
         if (do_timing) then
            call get_d_and_c_timing
            sum_times = time_factor + time_solve + time_refine + time_dealloc
            sum_factor = setup + factor_As + setup_coupling + factor_coupling
            write(*,*)
            write(*,*)
            write(*,*) 'top level operation times in ' // trim(which_decsol_str)
            write(*,'(99a9)') 'threads', 'factor', 'solve', 'dealloc', 'total'
            write(*,'(i9,99f9.4)') omp_num_threads, time_factor, time_solve, time_dealloc, sum_times
            write(*,*)
            write(*,*)
            write(*,*) 'overall summary of times in d_and_c_factor'
            write(*,'(5a9,15x,3a9,15x,99a9)') 'sum', 'setup', 'fac As', 'setup C', 'fac C'
            write(*,'(5f9.4,15x,3f9.4,15x,99f9.4)') &
               sum_factor, setup, factor_As, setup_coupling, factor_coupling
            write(*,*)
         end if
         
         write(*,*)

         ! lblk/dblk/ublk are not allocated in this file. Presumably the
         ! mtx_read_*block_tridiagonal call in read_testfile allocates them.
         deallocate(rpar_decsol, ipar_decsol, x, xcorrect, work, &
            brhs, ipiv, lblk, dblk, ublk, l_init, d_init, u_init)
         
         contains
         
         
         !> Factor, solve, optionally refine, and release one block
         !> tridiagonal system using the decsolblk routine for this build.
         !>
         !> use_given_weights : when true and which_decsol_option is
         !>                     block_dc_mt_klu, factor with iop=3 (use the
         !>                     given weights to partition blocks). Otherwise
         !>                     factor with iop=0.
         !> lblk, dblk, ublk  : (nvar,nvar,nz) blocks passed to every
         !>                     decsolblk call.
         !> decsolblk         : double-precision routine, called in the DBLE build.
         !> decsolblk_quad    : quad-precision routine, called otherwise.
         !>
         !> On return the host array x holds the solution of A*x = brhs.
         !> When do_timing is set, elapsed wall-clock seconds are added to the
         !> host totals time_factor, time_solve, time_refine and time_dealloc.
         !> A nonzero ierr from the solver or from the refinement step stops
         !> the program with exit code 1.
         subroutine solve_blocks(use_given_weights, lblk, dblk, ublk, decsolblk, decsolblk_quad)
            logical, intent(in) :: use_given_weights
            real(fltp), pointer :: lblk(:,:,:), dblk(:,:,:), ublk(:,:,:) ! (nvar,nvar,nz)
            interface
               include 'mtx_decsolblk_dble.dek'
               include 'mtx_decsolblk_quad.dek'
            end interface
            
            integer :: iop, time0, time1, clock_rate
#ifdef DBLE
            ! Used only by the diagnostic error checks around refinement.
            real(fltp) :: avg_err, max_err, atol, rtol
            integer :: i_max, j_max
#endif
         
            include 'formats.dek'
         
            if (do_timing) call system_clock(time0,clock_rate)
            
            ! ---- factor ----
            if (which_decsol_option == block_dc_mt_klu .and. use_given_weights) then
               iop = 3 ! factor using given weights to partition blocks
            else
               iop = 0 ! factor A
            end if
#ifdef DBLE
            call decsolblk( &
               iop,caller_id,nvar,nz,lblk,dblk,ublk,x,ipiv,lrd,rpar_decsol,lid,ipar_decsol,ierr)
#else
            call decsolblk_quad( &
               iop,caller_id,nvar,nz,lblk,dblk,ublk,x,ipiv,lrd,rpar_decsol,lid,ipar_decsol,ierr)
#endif
            if (ierr /= 0) then
               write(*,*) 'decsolblk failed for factor'
               stop 1
            end if

            if (do_timing) then
               call system_clock(time1,clock_rate)
               time_factor = time_factor + dble(time1-time0)/clock_rate
               time0 = time1
            end if
         
            ! ---- solve A*x = b ----
            ! Load the right-hand side into x. The solve call below then
            ! replaces it with the solution.
            iop = 1
            x = brhs
#ifdef DBLE
            call decsolblk( &
               iop,caller_id,nvar,nz,lblk,dblk,ublk,x,ipiv,lrd,rpar_decsol,lid,ipar_decsol,ierr)
#else
            call decsolblk_quad( &
               iop,caller_id,nvar,nz,lblk,dblk,ublk,x,ipiv,lrd,rpar_decsol,lid,ipar_decsol,ierr)
#endif
            if (ierr /= 0) then
               write(*,*) 'decsolblk failed for solve'
               stop 1
            end if

            if (do_timing) then
               call system_clock(time1,clock_rate)
               time_solve = time_solve + dble(time1-time0)/clock_rate
               time0 = time1
            end if
            
#ifdef DBLE
            ! ---- optional iterative refinement (double-precision build only) ----
            if (do_refine) then
               ! Outside release runs, measure the error before refining.
               if (.not. for_release) then
                  atol = 1d-6
                  rtol = 1d-6
                  call check1_x(avg_err, max_err, atol, rtol, i_max, j_max)
                  if (do_timing) write(*,fmt='(a30,e18.8)') 'before refine: max_err', max_err
               end if
               call block_dble_refine1( &
                  l_init, d_init, u_init, lblk, dblk, ublk, ipiv, decsolblk, &
                  brhs, x, work, caller_id, lrd, rpar_decsol, lid, ipar_decsol, ierr)
               ! Previously this status was silently overwritten by the
               ! deallocate call below, so refinement failures went unreported.
               if (ierr /= 0) then
                  write(*,*) 'block_dble_refine1 failed'
                  stop 1
               end if
               ! Outside release runs, measure the error again after refining.
               if (.not. for_release) then
                  call check1_x(avg_err, max_err, atol, rtol, i_max, j_max)
                  if (do_timing) write(*,fmt='(a30,e18.8)') 'after refine: max_err', max_err
               end if
               if (do_timing) then
                  call system_clock(time1,clock_rate)
                  time_refine = time_refine + dble(time1-time0)/clock_rate
                  time0 = time1
               end if
            end if
#endif
         
            ! ---- release solver storage ----
            iop = 2 ! deallocate
#ifdef DBLE
            call decsolblk( &
               iop,caller_id,nvar,nz,lblk,dblk,ublk,x,ipiv,lrd,rpar_decsol,lid,ipar_decsol,ierr)
#else
            call decsolblk_quad( &
               iop,caller_id,nvar,nz,lblk,dblk,ublk,x,ipiv,lrd,rpar_decsol,lid,ipar_decsol,ierr)
#endif
            if (ierr /= 0) then
               write(*,*) 'decsolblk failed for deallocate'
               stop 1
            end if

            if (do_timing) then
               call system_clock(time1,clock_rate)
               time_dealloc = time_dealloc + dble(time1-time0)/clock_rate
            end if
                     
         end subroutine solve_blocks
          
         
         !> Turn on printing of factorization sub-timings in d_and_c_block_dble.
         !> The dealloc sub-timing display (show_dealloc_subtiming) is imported
         !> but left unchanged.
         subroutine turn_on_d_and_c_show_timing
            use d_and_c_block_dble, only: &
               show_factor_subtiming, &
               show_dealloc_subtiming
            show_factor_subtiming = .true.
         end subroutine turn_on_d_and_c_show_timing
        
         
         !> Turn off factorization sub-timing in d_and_c_block_dble.
         !> do_dealloc_subtiming is imported but left unchanged.
         !> NOTE(review): this routine is not called anywhere in this module.
         subroutine turn_off_d_and_c_timing
            use d_and_c_block_dble, only: &
               do_factor_subtiming, &
               do_dealloc_subtiming
            do_factor_subtiming = .false.
         end subroutine turn_off_d_and_c_timing
         
         !> Turn on factorization sub-timing in d_and_c_block_dble and zero
         !> both that module's sub-timing counters and this test's
         !> accumulated sub-time totals.
         !> Dealloc sub-timing (do_dealloc_subtiming) is not enabled here.
         subroutine turn_on_d_and_c_timing
            use d_and_c_block_dble, only: &
               do_factor_subtiming, do_dealloc_subtiming, &
               factor_subtime_setup, factor_subtime_factor_As, &
               factor_subtime_setup_C, factor_subtime_factor_C, &
               dealloc_subtime_d_and_c, dealloc_subtime_thomas

            ! Zero the module-side counters.
            factor_subtime_setup = 0
            factor_subtime_factor_As = 0
            factor_subtime_setup_C = 0
            factor_subtime_factor_C = 0
            dealloc_subtime_d_and_c = 0
            dealloc_subtime_thomas = 0

            ! Zero the host-side totals.
            setup = 0
            factor_As = 0
            setup_coupling = 0
            factor_coupling = 0

            do_factor_subtiming = .true.

         end subroutine turn_on_d_and_c_timing


         !> Add the factorization sub-times recorded by d_and_c_block_dble
         !> to the host totals setup, factor_As, setup_coupling and
         !> factor_coupling.
         subroutine get_d_and_c_timing
            use d_and_c_block_dble, only: &
               t_setup => factor_subtime_setup, &
               t_fac_As => factor_subtime_factor_As, &
               t_setup_C => factor_subtime_setup_C, &
               t_fac_C => factor_subtime_factor_C

            factor_coupling = factor_coupling + t_fac_C
            setup_coupling = setup_coupling + t_setup_C
            factor_As = factor_As + t_fac_As
            setup = setup + t_setup

         end subroutine get_d_and_c_timing
         
         
         !> Open fname and read a block tridiagonal test matrix from it.
         !> The mtx_read call receives the host variables nvar and nz and the
         !> host pointers lblk, dblk and ublk. Presumably it sets the sizes
         !> and allocates/fills the blocks; they are not set elsewhere in
         !> this file.
         !> Stops with exit code 1 if the file cannot be opened or read.
         subroutine read_testfile(fname)
            character (len=*), intent(in) :: fname
            integer :: unit_no, read_status

            unit_no = 33
            read_status = 0

            open(unit=unit_no, file=trim(fname), status='old', action='read', iostat=read_status)
            if (read_status /= 0) then
               write(*,*) 'failed to open ' // trim(fname)
               stop 1
            end if

#ifdef DBLE
            call mtx_read_block_tridiagonal(unit_no,nvar,nz,lblk,dblk,ublk,read_status)
#else
            call mtx_read_quad_block_tridiagonal(unit_no,nvar,nz,lblk,dblk,ublk,read_status)
#endif
            if (read_status /= 0) then
               write(*,*) 'failed to read ' // trim(fname)
               stop 1
            end if

            close(unit_no)
         end subroutine read_testfile
         
         
         !> Set the host right-hand side brhs = A*xcorrect. A is the block
         !> tridiagonal matrix given by lblk (lower), dblk (diagonal) and
         !> ublk (upper), each (nvar,nvar,nz).
         !> The unreachable debug dump that used to follow a bare `return`
         !> has been removed.
         subroutine set_brhs(lblk, dblk, ublk)
            real(fltp), pointer :: lblk(:,:,:), dblk(:,:,:), ublk(:,:,:) ! (nvar,nvar,nz)
#ifdef DBLE
            call block_dble_mv(lblk, dblk, ublk, xcorrect, brhs)
#else
            call block_quad_mv(lblk, dblk, ublk, xcorrect, brhs)
#endif
         end subroutine set_brhs
         
         
         !> Report how closely the computed solution x matches xcorrect.
         !>
         !> The largest scaled error comes from check1_x with atol = rtol = 1d-4.
         !> - If it exceeds 1, print a 'BAD' line with its location and values.
         !> - Otherwise print it only during timing runs without refinement.
         !> NOTE(review): a 'BAD' result is only printed; execution continues.
         subroutine check_x
            real(fltp) :: worst, mean, abs_tol, rel_tol
            integer :: iw, kw
            include 'formats.dek'

            abs_tol = 1d-4
            rel_tol = 1d-4
            call check1_x(mean, worst, abs_tol, rel_tol, iw, kw)

            if (worst > 1) then
               write(*,3) 'BAD: err, x, xcorrect', iw, kw, worst, x(iw,kw), xcorrect(iw,kw)
               return
            end if

            if (do_timing .and. .not. do_refine) write(*,3) 'err', iw, kw, worst
         end subroutine check_x
         
         
         !> Compute scaled errors between the host arrays x and xcorrect,
         !> both (nvar,nz).
         !>
         !> For each entry:
         !>    err = |x - xcorrect| / (atol + rtol*max(|x|, |xcorrect|))
         !>
         !> avg_err      : mean err over all nvar*nz entries (0 if there are none).
         !> max_err      : largest err, or 0 if no entry has err > 0.
         !> i_max, j_max : first (1..nvar) and second (1..nz) index of max_err,
         !>                or 0 if no entry has err > 0.
         !> Stops with a diagnostic if any entry of x fails the is_bad check.
         subroutine check1_x(avg_err, max_err, atol, rtol, i_max, j_max)
            real(fltp), intent(out) :: avg_err, max_err
            real(fltp), intent(in) ::  atol, rtol
            integer, intent(out) :: i_max, j_max
            integer :: i, j, n_entries
            real(fltp) :: err_sum
            real(fltp) :: err
            include 'formats.dek'
            max_err = 0; i_max = 0; j_max = 0; err_sum = 0
            do j = 1, nz
               do i = 1, nvar
                  if (is_bad(x(i,j))) then
                     write(*,3) 'x xcorrect', i, j, x(i,j), xcorrect(i,j)
                     stop 'check1_x'
                  end if
                  err = abs(x(i,j) - xcorrect(i,j))/(atol + rtol*max(abs(x(i,j)),abs(xcorrect(i,j))))
                  err_sum = err_sum + err
                  if (err > max_err) then
                     max_err = err; i_max = i; j_max = j
                  end if
               end do
            end do
            ! Guard the mean against an empty system (previously a division by zero).
            n_entries = nvar*nz
            if (n_entries > 0) then
               avg_err = err_sum/n_entries
            else
               avg_err = 0
            end if
         end subroutine check1_x
         
         
         !> Fill xcorrect with a known, strictly increasing test solution.
         !> A running value starts at 1d0 and grows by 1d-3 before each store,
         !> filling entries in column-major order (first index fastest).
         !> So xcorrect(1,1) is about 1.001.
         subroutine set_xcorrect
            real(fltp) :: ramp
            integer :: icol, irow
            ramp = 1d0
            outer: do icol = 1, nz
               inner: do irow = 1, nvar
                  ramp = ramp + 1d-3
                  xcorrect(irow,icol) = ramp
               end do inner
            end do outer
         end subroutine set_xcorrect
           
      end subroutine test_block




#ifdef DBLE
      end module test_block_tri_dble
#else
      end module test_block_tri_quad
#endif
