! ***********************************************************************
!
!   Copyright (C) 2010  Bill Paxton
!
!   MESA is free software; you can use it and/or modify
!   it under the combined terms and restrictions of the MESA MANIFESTO
!   and the GNU General Library Public License as published
!   by the Free Software Foundation; either version 2 of the License,
!   or (at your option) any later version.
!
!   You should have received a copy of the MESA MANIFESTO along with
!   this software; if not, it is available at the mesa website:
!   http://mesa.sourceforge.net/
!
!   MESA is distributed in the hope that it will be useful,
!   but WITHOUT ANY WARRANTY; without even the implied warranty of
!   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
!   See the GNU Library General Public License for more details.
!
!   You should have received a copy of the GNU Library General Public License
!   along with this software; if not, write to the Free Software
!   Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
!
!
! ***********************************************************************

      module mod_diffusion_support

      use const_def
      use chem_def
      use utils_lib, only:is_bad_num, has_bad_num, return_nan
      use alert_lib, only:alert

      implicit none

      logical, parameter :: dbg = .false.      

      contains
         
      ! concentration is number density of ion divided by number density of free electrons
      subroutine get_C(nz, nzlo, nzhi, nc, m, Z, A, X, C, dC_dX)
         ! For every zone in nzlo..nzhi, turn the mass fractions of the first nc
         ! species into concentrations:
         !    dC_dX(j,zone) = 1/(A(j)*sum_i(Z(i,zone)*X(i,zone)/A(i)))
         !    C(j,zone)     = X(j,zone)*dC_dX(j,zone)
         ! Entry m of C is set to 1 in each processed zone.
         ! Zones outside nzlo..nzhi are not written.
         integer, intent(in) :: nz, nzlo, nzhi, nc, m
         double precision, intent(in) :: Z(m,nz) ! typical charge
         double precision, intent(in) :: A(m) ! enters only as X/A (original note: atomic number)
         double precision, intent(in) :: X(m,nz) ! mass fractions
         double precision, intent(out) :: C(m,nz) ! concentration
         double precision, intent(out) :: dC_dX(nc,nz)
         integer :: zone
         double precision :: charge_sum
!$OMP PARALLEL DO PRIVATE(zone,charge_sum)
         do zone = nzlo, nzhi
            ! sum over the nc species of Z*X/A for this zone
            charge_sum = sum(Z(1:nc,zone)*X(1:nc,zone)/A(1:nc))
            dC_dX(1:nc,zone) = 1/(A(1:nc)*charge_sum)
            C(1:nc,zone) = X(1:nc,zone)*dC_dX(1:nc,zone)
            C(m,zone) = 1
         end do
!$OMP END PARALLEL DO
      end subroutine get_C
      
      
      subroutine get_electron_mass_fractions( &
               nz, nzlo, nzhi, nc, m, A, C, X)
         ! Overwrite entry m of X in zones nzlo..nzhi with
         !    A(m) / sum_{j=1..nc} A(j)*C(j,zone)
         ! All other entries of X are left as they were on entry.
         integer, intent(in) :: nz, nzlo, nzhi, nc, m
         double precision, intent(in) :: A(m)
         double precision, intent(in) :: C(m,nz) ! concentration
         double precision, intent(inout) :: X(m,nz) ! mass fractions
         integer :: zone
         double precision :: weighted_sum
!$OMP PARALLEL DO PRIVATE(zone, weighted_sum)
         do zone = nzlo, nzhi
            weighted_sum = dot_product(A(1:nc), C(1:nc,zone))
            X(m,zone) = A(m)/weighted_sum
         end do
!$OMP END PARALLEL DO
      end subroutine get_electron_mass_fractions
      
            
      subroutine get_dC_dr_mid(m, nz, nzlo, nzhi, C, four_pi_r2_rho_mid, dm, dC_dr_mid)
         ! Fill dC_dr_mid(:,nzlo:nzhi) from differences of C between adjacent zones,
         ! each difference being scaled by four_pi_r2_rho_mid/dm.
         !   zone nzlo      : the single slope between zones nzlo and nzlo+1
         !   nzlo+1..nzhi-1 : 2*s1*s2/(s1+s2) of the two neighboring slopes,
         !                    or 0 when s1*s2 <= 0
         !   zone nzhi      : copy of zone nzhi-1
         integer, intent(in) :: m, nz, nzlo, nzhi
         double precision, intent(in) :: C(m,nz)
         double precision, intent(in) :: four_pi_r2_rho_mid(nz), dm(nz)
         double precision, intent(out) :: dC_dr_mid(m,nz)
         integer :: zone, isp
         double precision :: slope_prev, slope_next

         dC_dr_mid(:,nzlo) = (C(:,nzlo) - C(:,nzlo+1))*four_pi_r2_rho_mid(nzlo)/dm(nzlo)

!$OMP PARALLEL DO PRIVATE(zone, isp, slope_prev, slope_next)
         do zone = nzlo+1, nzhi-1
            do isp = 1, m
               slope_prev = (C(isp,zone-1) - C(isp,zone))*four_pi_r2_rho_mid(zone-1)/dm(zone-1)
               slope_next = (C(isp,zone) - C(isp,zone+1))*four_pi_r2_rho_mid(zone)/dm(zone)
               if (slope_prev*slope_next <= 0) then
                  ! slopes disagree in sign (or one vanishes): use a flat gradient
                  dC_dr_mid(isp,zone) = 0
                  cycle
               end if
               dC_dr_mid(isp,zone) = 2*slope_prev*slope_next/(slope_prev+slope_next)
            end do
         end do
!$OMP END PARALLEL DO

         dC_dr_mid(:,nzhi) = dC_dr_mid(:,nzhi-1)
      end subroutine get_dC_dr_mid
      
            
      subroutine get_middle_C_and_X( &
            nz, m, nzlo, nzhi, C, X, dC_dr_mid, tiny_C, C_mid, X_mid, dlnC_dr_mid)
         ! For zones nzlo..nzhi-1 form the two-zone averages
         !    X_mid(:,k) = (X(:,k) + X(:,k+1))/2,  C_mid(:,k) = (C(:,k) + C(:,k+1))/2
         ! and the scaled gradient dlnC_dr_mid = dC_dr_mid/max(tiny_C, C_mid).
         ! Zone nzhi then receives a copy of the zone nzhi-1 results.
         integer, intent(in) :: m, nz, nzlo, nzhi
         double precision, intent(in) :: C(m,nz), X(m,nz), dC_dr_mid(m,nz)
         double precision, intent(in) :: tiny_C ! floor applied to C_mid in the division
         double precision, intent(out) :: C_mid(m,nz), X_mid(m,nz), dlnC_dr_mid(m,nz)
         integer :: zone, isp
!$OMP PARALLEL DO PRIVATE(zone, isp)
         do zone = nzlo, nzhi-1
            do isp = 1, m
               X_mid(isp,zone) = (X(isp,zone) + X(isp,zone+1))/2
               C_mid(isp,zone) = (C(isp,zone) + C(isp,zone+1))/2
               dlnC_dr_mid(isp,zone) = dC_dr_mid(isp,zone)/max(tiny_C,C_mid(isp,zone))
            end do
         end do
!$OMP END PARALLEL DO
         do isp = 1, m
            C_mid(isp,nzhi) = C_mid(isp,nzhi-1)
            X_mid(isp,nzhi) = X_mid(isp,nzhi-1)
            dlnC_dr_mid(isp,nzhi) = dlnC_dr_mid(isp,nzhi-1)
         end do
      end subroutine get_middle_C_and_X
      

      subroutine eval_coulomb_logs( &
            nz, m, nzlo, nzhi, A, Z, X_mid, rho_mid, T_mid, clg_mid)
         ! Call get_coulomb_logs once per zone for zones nzlo..nzhi-1, storing each
         ! m-by-m result in clg_mid(:,:,zone).  clg_mid(:,:,nzhi) is not written here.
         integer, intent(in) :: nz, m, nzlo, nzhi
         double precision, intent(in) :: A(m)
         double precision, intent(in) :: X_mid(m,nz), Z(m,nz)
         double precision, intent(in) :: rho_mid(nz), T_mid(nz)
         double precision, intent(out) :: clg_mid(m,m,nz)
         integer :: zone
!$OMP PARALLEL DO PRIVATE(zone)
         do zone = nzlo, nzhi-1
            call get_coulomb_logs( &
               m, A, Z(:,zone), X_mid(:,zone), &
               rho_mid(zone), T_mid(zone), clg_mid(:,:,zone))
         end do
!$OMP END PARALLEL DO
      end subroutine eval_coulomb_logs


      subroutine get_X_Y_limit_coeffs( &
            nz, m, nzlo, nzhi, ih, ihe, X_mid, X_full_on, X_full_off, &
            Y_full_on, Y_full_off, X_Y_limit_coeffs)
         ! For zones nzlo..nzhi-1 build a limiting coefficient in [0,1] from the
         ! mid-zone mass fractions X_mid(ih,:) and X_mid(ihe,:).
         ! Each fraction yields a term that is 1 at or above its *_full_on value,
         ! 0 at or below its *_full_off value, and linear in between.
         ! The product p of the two terms gives the coefficient:
         !    0 when p == 0, 1 when p == 1, otherwise 0.5*(1 - cos(pi*p)).
         ! X_Y_limit_coeffs(nzhi) is not written here.
         integer, intent(in) :: nz, nzlo, nzhi, m, ih, ihe
         double precision, intent(in) :: X_mid(m,nz)
         double precision, intent(in) :: X_full_on, X_full_off, Y_full_on, Y_full_off
         double precision, intent(out) :: X_Y_limit_coeffs(nz)
         integer :: zone
         double precision :: X_term, Y_term
!$OMP PARALLEL DO PRIVATE(zone,X_term,Y_term)
         do zone = nzlo, nzhi-1
            X_term = linear_ramp(X_mid(ih,zone), X_full_on, X_full_off)
            Y_term = linear_ramp(X_mid(ihe,zone), Y_full_on, Y_full_off)
            if (X_term*Y_term == 0) then
               X_Y_limit_coeffs(zone) = 0
            else if (X_term*Y_term == 1) then
               X_Y_limit_coeffs(zone) = 1
            else
               X_Y_limit_coeffs(zone) = 0.5d0*(1 - cos(pi*X_term*Y_term))
            end if
         end do
!$OMP END PARALLEL DO

         contains

         ! 1 at/above full_on, 0 at/below full_off, linear in between
         double precision function linear_ramp(val, full_on, full_off)
            double precision, intent(in) :: val, full_on, full_off
            if (val >= full_on) then
               linear_ramp = 1
            else if (val <= full_off) then
               linear_ramp = 0
            else
               linear_ramp = (val - full_off) / (full_on - full_off)
            end if
         end function linear_ramp

      end subroutine get_X_Y_limit_coeffs
      

      subroutine eval_gradient_coeffs( &
            nz, m, ihe, nzlo, nzhi, Y_full_off, &
            A, Z, X_mid, C_mid, clg_mid, aP, aT, aX, E_field, ierr)
         ! Evaluate the gradient coefficients AP, AT, AX and E_field zone by zone
         ! (zones nzlo..nzhi-1) by calling eval1_gradient_coeffs, then copy the
         ! zone nzhi-1 results into zone nzhi.
         !
         ! ierr is 0 on success; if any zone reports a nonzero error code, ierr
         ! returns one of the reported codes (which one is unspecified when several
         ! zones fail).
         !use mod_thoul_solve1, only: do1_solve_thoul
         integer, intent(in) :: nz, m, ihe, nzlo, nzhi
         double precision, intent(in) :: Y_full_off
         double precision, dimension(m), intent(in) :: A
         double precision, dimension(m, nz), intent(in) :: X_mid
         double precision, dimension(m, nz), intent(in) :: Z
         double precision, dimension(m, m, nz), intent(in) :: clg_mid
         double precision, dimension(m, nz), intent(in) :: C_mid
         double precision, dimension(m, nz), intent(out) :: AP, AT
         double precision, dimension(m, m, nz), intent(out) :: AX
         double precision, dimension(nz), intent(out) :: E_field
         integer, intent(out) :: ierr
         integer :: k, op_err
         ! ierr is intent(out): it must be defined even when every zone succeeds
         ierr = 0
!$OMP PARALLEL DO PRIVATE(k, op_err) SCHEDULE(DYNAMIC, 10)
         do k=nzlo, nzhi-1
            op_err = 0
            call eval1_gradient_coeffs( &
               k, nz, m, ihe, nzlo, nzhi, Y_full_off, &
               A, Z, X_mid, C_mid, clg_mid, aP, aT, aX, E_field, op_err)
            if (op_err /= 0) then
               ! serialize the rare write to the shared error flag
!$OMP CRITICAL (eval_gradient_coeffs_ierr)
               ierr = op_err
!$OMP END CRITICAL (eval_gradient_coeffs_ierr)
            end if
         end do
!$OMP END PARALLEL DO
         AP(:,nzhi) = AP(:,nzhi-1)
         AT(:,nzhi) = AT(:,nzhi-1)
         AX(:,:,nzhi) = AX(:,:,nzhi-1)
         ! keep E_field consistent with the other outputs at the last zone
         E_field(nzhi) = E_field(nzhi-1)
      end subroutine eval_gradient_coeffs
      

      subroutine eval1_gradient_coeffs( &
            k, nz, m, ihe, nzlo, nzhi, Y_full_off, &
            A, Z, X_mid, C_mid, clg_mid, aP, aT, aX, E_field, ierr)
         ! Fill AP(:,k), AT(:,k), AX(:,:,k) and E_field(k) for the single zone k.
         ! When X_mid(ihe,k) < Y_full_off all four are set to 0 and no solve is done;
         ! otherwise the zone data are copied into local arrays and handed to
         ! do1_solve_thoul, whose error code is returned in ierr.
         !use mod_thoul_solve1, only: do1_solve_thoul
         integer, intent(in) :: k, nz, m, ihe, nzlo, nzhi
         double precision, intent(in) :: Y_full_off
         double precision, intent(in) :: A(m)
         double precision, intent(in) :: X_mid(m, nz), Z(m, nz), C_mid(m, nz)
         double precision, intent(in) :: clg_mid(m, m, nz)
         double precision, intent(out) :: AP(m, nz), AT(m, nz)
         double precision, intent(out) :: AX(m, m, nz)
         double precision, intent(out) :: E_field(nz)
         integer, intent(out) :: ierr
         ! per-zone working copies passed to / returned from the solver
         double precision :: zone_X(m), zone_Z(m), zone_C(m), zone_AP(m), zone_AT(m)
         double precision :: zone_clg(m,m), zone_AX(m,m)
         double precision :: zone_E

         ierr = 0

         if (X_mid(ihe,k) < Y_full_off) then
            AP(:,k) = 0
            AT(:,k) = 0
            AX(:,:,k) = 0
            E_field(k) = 0
            return
         end if

         zone_Z(:) = Z(:,k)
         zone_X(:) = X_mid(:,k)
         zone_C(:) = C_mid(:,k)
         zone_clg(:,:) = clg_mid(:,:,k)

         call do1_solve_thoul(2*m+2, m, ihe, A, zone_Z, zone_X, zone_C, &
            zone_clg, zone_AP, zone_AT, zone_AX, zone_E, ierr)

         AP(:,k) = zone_AP
         AT(:,k) = zone_AT
         AX(:,:,k) = zone_AX
         E_field(k) = zone_E
      end subroutine eval1_gradient_coeffs


      subroutine eval_velocities( &
               nz, m, nc, nzlo, nzhi, AP, AT, AX, rho_mid, T_mid, &
               dlnP_dr_mid, dlnT_dr_mid, dlnRho_dr_mid, X_mid, dlnC_dr_mid, &
               eta_T_limit_coeffs, X_Y_limit_coeffs, vgt_max, diffusion_factor, &
               v, vgt, sigma_lnC)
         ! Evaluate v, vgt and sigma_lnC zone by zone (zones nzlo..nzhi-1) through
         ! eval1_diffusion_velocities, passing the product
         ! eta_T_limit_coeffs(zone)*X_Y_limit_coeffs(zone) as the limit coefficient.
         ! Zone nzhi then receives a copy of the zone nzhi-1 results.
         integer, intent(in) :: nz, m, nc, nzlo, nzhi
         double precision, intent(in) :: AP(m, nz), AT(m, nz), X_mid(m, nz), dlnC_dr_mid(m, nz)
         double precision, intent(in) :: AX(m, m, nz)
         double precision, intent(in) :: rho_mid(nz), T_mid(nz)
         double precision, intent(in) :: dlnP_dr_mid(nz), dlnT_dr_mid(nz), dlnRho_dr_mid(nz)
         double precision, intent(in) :: eta_T_limit_coeffs(nz), X_Y_limit_coeffs(nz)
         double precision, intent(in) :: vgt_max
         double precision, intent(in) :: diffusion_factor(nc)
         double precision, intent(out) :: v(nc,nz) ! v(i,k) is velocity of nc i at point k
         double precision, intent(out) :: vgt(nc,nz) ! vgt(i,k) is gravothermal part of v
         double precision, intent(out) :: sigma_lnC(nc,nc,nz)
         integer :: zone
!$OMP PARALLEL DO PRIVATE(zone)
         do zone = nzlo, nzhi-1
            call eval1_diffusion_velocities( &
               nc, m, AP(:,zone), AT(:,zone), AX(:,:,zone), &
               rho_mid(zone), T_mid(zone), &
               dlnP_dr_mid(zone), dlnT_dr_mid(zone), dlnRho_dr_mid(zone), &
               X_mid(:,zone), dlnC_dr_mid(:,zone), &
               eta_T_limit_coeffs(zone)*X_Y_limit_coeffs(zone), &
               vgt_max, diffusion_factor, &
               v(:,zone), vgt(:,zone), sigma_lnC(:,:,zone))
         end do
!$OMP END PARALLEL DO
         ! last zone duplicates its neighbor
         v(:,nzhi) = v(:,nzhi-1)
         vgt(:,nzhi) = vgt(:,nzhi-1)
         sigma_lnC(:,:,nzhi) = sigma_lnC(:,:,nzhi-1)
      end subroutine eval_velocities
      
      
      subroutine eval1_diffusion_velocities( &
            nc, m, AP, AT, AX, rho_mid, T_mid, dlnP_dr_mid, dlnT_dr_mid, dlnRho_dr_mid, &
            X_mid, dlnC_dr_mid, limit_coeff, vgt_max, diffusion_factor, &
            v, vgt, sigma_lnC)
         ! Velocities for one zone.
         !   vgt(i)         = diffusion_factor(i)*coef*(AP(i)*dlnP_dr_mid + AT(i)*dlnT_dr_mid),
         !                    clipped in magnitude to vgt_max
         !   sigma_lnC(i,:) = -diffusion_factor(i)*coef*AX(i,1:nc)
         !   v(i)           = vgt(i) - sum(sigma_lnC(i,1:nc)*dlnC_dr_mid(1:nc))
         ! where coef scales with limit_coeff, T_mid**2.5 and 1/rho_mid.
         ! If coef is exactly 0 all outputs are 0.
         ! Finally the entry with the largest X_mid is shifted so that
         ! dot_product(X_mid(1:nc), v(1:nc)) vanishes.
         ! dlnRho_dr_mid is accepted but not used.
         !use mod_thoul_solve1
         integer, intent(in) :: nc, m
         double precision, intent(in) :: AP(m), AT(m)
         double precision, intent(in) :: AX(m, m)
         double precision, intent(in) :: rho_mid, T_mid, limit_coeff, vgt_max
         double precision, intent(in) :: dlnP_dr_mid, dlnT_dr_mid, dlnRho_dr_mid
         double precision, intent(in) :: X_mid(nc), dlnC_dr_mid(nc), diffusion_factor(nc)
         double precision, intent(out) :: v(nc), vgt(nc), sigma_lnC(nc,nc)

         character (len=*), parameter :: fmt1 = '(a40,12x,1pe26.16)'
         double precision, parameter :: rho_unit = 1d2
         double precision, parameter :: T_unit = 1d7
         integer :: isp, i_max
         double precision :: coef, shift
         double precision :: tau0  ! = 6d13*secyer, characteristic solar diffusion time (seconds)

         tau0 = 6d13*secyer
         ! converts to cgs units
         coef = limit_coeff*Rsun*(T_mid/T_unit)**2.5d0/(rho_mid/rho_unit)*(Rsun/tau0)

         if (dbg) then
            if (is_bad_num(coef)) then
               write(*,*) 'bad num coef', coef
               write(*,fmt1) 'T_mid', T_mid
               write(*,fmt1) 'rho_mid', rho_mid
               write(*,fmt1) '(T_mid/T_unit)**2.5d0', (T_mid/T_unit)**2.5d0
               write(*,fmt1) 'limit_coeff', limit_coeff
               write(*,fmt1) 'rho_mid/rho_unit', rho_mid/rho_unit
               write(*,fmt1) 'Rsun/tau0', Rsun/tau0
               write(*,*)
               stop 'eval1_diffusion_velocities'
            end if
         end if

         if (coef == 0) then
            v(:) = 0
            vgt(:) = 0
            sigma_lnC(:,:) = 0
            return
         end if

         ! gravothermal part, limited in magnitude to vgt_max
         vgt(1:nc) = diffusion_factor(1:nc)*coef*(AP(1:nc)*dlnP_dr_mid + AT(1:nc)*dlnT_dr_mid)
         where (abs(vgt(1:nc)) > vgt_max) vgt(1:nc) = sign(1d0,vgt(1:nc))*vgt_max

         do isp = 1, nc
            sigma_lnC(isp,:) = -diffusion_factor(isp)*coef*AX(isp,1:nc)
            v(isp) = vgt(isp) - sum(sigma_lnC(isp,1:nc)*dlnC_dr_mid(1:nc))
         end do

         ! final fixup for velocity of most abundant to give exact local mass conservation.
         i_max = maxloc(X_mid(1:nc),dim=1)
         shift = -dot_product(X_mid(1:nc), v(1:nc))/X_mid(i_max)
         vgt(i_max) = vgt(i_max) + shift
         v(i_max) = v(i_max) + shift
      end subroutine eval1_diffusion_velocities

         
      subroutine get_flow_coeffs( &
            nz, nc, m, nzlo, nzhi, v, vgt, vgt_max, sigma_lnC, four_pi_r2_rho_mid, &
            total_time, dm, X, X_mid, C, C_mid, tiny_C, &
            eta_T_limit_coeffs, X_Y_limit_coeffs, AD_factor, AD_velocity, GT, CD, AD)
         ! Fill the flow coefficients GT, CD and AD for zones nzlo..nzhi-1 by calling
         ! get1_flow_coeffs per zone; zone nzhi then receives a copy of zone nzhi-1.
         integer, intent(in) :: nz, nc, m, nzlo, nzhi
         double precision, intent(inout) :: v(nc,nz)
         double precision, intent(in) :: vgt(nc,nz), vgt_max
         double precision, intent(in) :: sigma_lnC(nc,nc,nz)
         double precision, intent(in) :: four_pi_r2_rho_mid(nz), total_time, dm(nz)
         double precision, intent(in) :: X(m,nz), X_mid(m,nz), C(m,nz), C_mid(m,nz)
         double precision, intent(in) :: tiny_C
         double precision, intent(in) :: eta_T_limit_coeffs(nz), X_Y_limit_coeffs(nz)
         double precision, intent(in) :: AD_factor, AD_velocity
         double precision, intent(out) :: GT(nc,nz)
         double precision, intent(out) :: CD(nc,nc,nz)
         double precision, intent(out) :: AD(nz)
         integer :: zone
!$OMP PARALLEL DO PRIVATE(zone)
         do zone = nzlo, nzhi-1
            call get1_flow_coeffs( &
               zone, nz, nc, m, nzlo, nzhi, v, vgt, vgt_max, sigma_lnC, &
               four_pi_r2_rho_mid, total_time, dm, X, X_mid, C, C_mid, tiny_C, &
               eta_T_limit_coeffs, X_Y_limit_coeffs, AD_factor, AD_velocity, &
               GT, CD, AD)
         end do
!$OMP END PARALLEL DO
         ! last zone duplicates its neighbor
         AD(nzhi) = AD(nzhi-1)
         GT(:,nzhi) = GT(:,nzhi-1)
         CD(:,:,nzhi) = CD(:,:,nzhi-1)
      end subroutine get_flow_coeffs
      
      
      subroutine get1_flow_coeffs( &
            k, nz, nc, m, nzlo, nzhi, v, vgt, vgt_max, sigma_lnC, four_pi_r2_rho_mid, &
            total_time, dm, X, X_mid, C, C_mid, tiny_C, &
            eta_T_limit_coeffs, X_Y_limit_coeffs, AD_factor, AD_velocity, GT, CD, AD)
         ! Flow coefficients for the single zone k:
         !    GT(i,k)   = four_pi_r2_rho_mid(k)*vgt(i,k)
         !    CD(i,j,k) = four_pi_r2_rho_mid(k)**2/dm(k)*X_mid(i,k)*sigma_lnC(i,j,k)
         !                  / max(tiny_C, C_mid(j,k))
         !    AD(k)     = extra coefficient built from the normalized second difference
         !                of v across zones k-1, k, k+1 (only for nzlo < k < nzhi-1),
         !                raised where limit_coeff < 1 and capped by AD_limit.
         ! v is declared intent(inout) but is only read here.
         ! vgt_max, X, C and AD_velocity are accepted but not used.
         integer, intent(in) :: k, nz, nc, m, nzlo, nzhi
         double precision, intent(inout) :: v(nc,nz)
         double precision, intent(in) :: vgt(nc,nz), vgt_max
         double precision, intent(in) :: sigma_lnC(nc,nc,nz)
         double precision, intent(in) :: four_pi_r2_rho_mid(nz), total_time, dm(nz)
         double precision, intent(in) :: &
            X(m,nz), X_mid(m,nz), C(m,nz), C_mid(m,nz), tiny_C, &
            eta_T_limit_coeffs(nz), X_Y_limit_coeffs(nz), AD_factor, AD_velocity
         double precision, intent(out) :: GT(nc,nz)
         double precision, intent(out) :: CD(nc,nc,nz)
         double precision, intent(out) :: AD(nz)

         integer :: i, j
         logical :: interior
         double precision :: c1, c2, AD_limit, limit_coeff, AD_v
         double precision, parameter :: AD_term_limit = 1d12
         ! flow(i,k) = (4 pi r^2 rho)*X_mid(i,k)*v(i,k)
         ! flow(i,k) = (4 pi r^2 rho)*X_mid(i,k)*
         !        (vgt(i,k) - sum(sigma_lnC(i,1:nc)*dlnC_dr_mid(1:nc)))
         ! flow(i,k) = GT(i,k)*X_mid(i,k)
         !              - sum{j}{CD(i,j,k)*(C(j,k)-C(j,k+1))}

         c1 = four_pi_r2_rho_mid(k)**2/dm(k)
         ! the second difference of v needs both neighbors k-1 and k+1
         interior = (k > nzlo .and. k < nzhi-1)
         AD(k) = 0
         do i = 1, nc
            GT(i,k) = four_pi_r2_rho_mid(k)*vgt(i,k)
            c2 = c1*X_mid(i,k)
            do j = 1, nc
               CD(i,j,k) = c2*sigma_lnC(i,j,k)/max(tiny_C,C_mid(j,k))
            end do
            if (interior) then
               ! Normalize by the sum of magnitudes.  Every term needs abs():
               ! with signed terms, velocities of mixed sign can cancel the
               ! denominator down to the 1d-50 guard or make it negative.
               AD_v = 100*abs(v(i,k))* &
                  abs(v(i,k-1) - 2*v(i,k) + v(i,k+1)) / &
                     (1d-50 + abs(v(i,k-1)) + 2*abs(v(i,k)) + abs(v(i,k+1)))
               AD(k) = max(AD(k), AD_factor*four_pi_r2_rho_mid(k)*AD_v)
            end if
         end do
         limit_coeff = eta_T_limit_coeffs(k)*X_Y_limit_coeffs(k)
         ! turn on AD in regions where limit_coeff < 1
         if (limit_coeff < 1) &
            AD(k) = max(AD(k), (1-limit_coeff)*four_pi_r2_rho_mid(k)*1d-5)
         AD_limit = AD_term_limit*min(dm(k),dm(max(nzlo,k-1)))/max(1d0,total_time)
         if (AD_limit < AD(k)) AD(k) = AD_limit
      end subroutine get1_flow_coeffs


      logical function check_xtotals(nc, atol, rtol, xtotal_init, xtotal)
         ! Returns .true. when, for every j in 1..nc,
         !    abs(xtotal(j) - xtotal_init(j)) / (atol + rtol*max(|xtotal(j)|, |xtotal_init(j)|)) <= 1.
         ! Each violating entry is reported through alert (and echoed to stdout when dbg
         ! is set) and makes the result .false.; all entries are always examined.
         integer, intent(in) :: nc
         double precision, intent(in) :: atol, rtol
         double precision, intent(in) :: xtotal_init(nc), xtotal(nc)
         integer :: isp
         double precision :: scaled_err
         character (len=256) :: message

         check_xtotals = .true.
         do isp = 1, nc
            scaled_err = abs(xtotal(isp) - xtotal_init(isp)) / &
               (atol + rtol*max(abs(xtotal(isp)), abs(xtotal_init(isp))))
            if (.not. (scaled_err > 1)) cycle
            check_xtotals = .false.
            write(message,'(a,i4,e20.6)') &
               'excessive non-conservation error for nc', isp, scaled_err
            call alert(-1, message)
            if (dbg) then
               write(*,'(a)') trim(message)
               write(*,*)
            end if
         end do
      end function check_xtotals


      subroutine get_change_info( &
            nz, nzlo, nzhi, m, nc, atol, rtol, X, X_start_step, dx_max, dx_avg, kmax)
         ! Summarize how much X changed relative to X_start_step over zones nzlo..nzhi.
         ! The scaled change of species j in zone k is
         !    abs(X(j,k) - X_start_step(j,k)) / (atol + rtol*max(X(j,k), X_start_step(j,k)))
         ! Outputs:
         !    dx_max  largest scaled change found (0 if none is positive)
         !    dx_avg  mean scaled change over all nc*(nzhi-nzlo+1) entries
         !    kmax    zone holding dx_max; nzlo when no scaled change exceeds 0
         integer, intent(in) :: nz, nzlo, nzhi, m, nc
         double precision, intent(in) :: atol, rtol
         double precision, intent(in), dimension(m,nz) :: X, X_start_step
         double precision, intent(out) :: dx_max, dx_avg
         integer, intent(out) :: kmax
         integer :: k, nterms
         double precision :: dx_sum, rel_dx(nc), zone_max
         nterms = 0
         dx_sum = 0
         dx_max = 0
         ! kmax is intent(out): give it a valid zone index even if nothing changed
         kmax = nzlo
         do k=nzlo, nzhi
            rel_dx(1:nc) = abs(X(1:nc,k) - X_start_step(1:nc,k)) / &
                           (atol + rtol*max(X(1:nc,k), X_start_step(1:nc,k)))
            zone_max = maxval(rel_dx(:))
            if (zone_max > dx_max) then
               dx_max = zone_max
               kmax = k
            end if
            dx_sum = dx_sum + sum(rel_dx(:))
            nterms = nterms + nc
         end do
         dx_avg = dx_sum/max(1,nterms)
      end subroutine get_change_info
      
      
      ! need a special version of clean up that knows about electrons
      subroutine do_clean_up_fractions( &
            nz, nzlo, nzhi, m, nc, X, tiny_X, max_sum_abs, X_cleanup_tol, ierr)
         ! Apply clean1 to the column X(:,zone) of every zone in nzlo..nzhi.
         ! ierr is 0 when all zones succeed and -1 if clean1 reports a
         ! nonzero error code for any zone.
         integer, intent(in) :: nz, nzlo, nzhi, m, nc
         double precision, intent(inout) :: X(m, nz)
         double precision, intent(in) :: tiny_X, max_sum_abs, X_cleanup_tol
         integer, intent(out) :: ierr
         integer :: zone, zone_err
         ierr = 0
!$OMP PARALLEL DO PRIVATE(zone, zone_err)
         do zone = nzlo, nzhi
            call clean1( &
               zone, m, nc, X(:,zone), tiny_X, max_sum_abs, X_cleanup_tol, zone_err)
            if (zone_err /= 0) ierr = -1
         end do
!$OMP END PARALLEL DO
      end subroutine do_clean_up_fractions
      

      subroutine clean1(k, m, nc, X, tiny_X, max_sum_abs, X_cleanup_tol, ierr)
         ! Tidy the first nc entries of X for one zone (k is used only in debug output):
         !   1. if max_sum_abs > 1, fail when sum(abs(X(1:nc))) is a bad number
         !      or exceeds max_sum_abs
         !   2. clip each entry into [tiny_X, 1]
         !   3. fail when the clipped sum differs from 1 by more than X_cleanup_tol
         !   4. otherwise divide the entries by their sum
         ! ierr is 0 on success and -1 on failure.  On a step-1 failure X is unchanged;
         ! on a step-3 failure X is left clipped but not renormalized.
         use utils_lib, only: is_bad_num
         integer, intent(in) :: k, m, nc
         double precision, intent(inout) :: X(m)
         double precision, intent(in) :: tiny_X, max_sum_abs, X_cleanup_tol
         integer, intent(out) :: ierr
         double precision :: total

         ierr = 0

         if (max_sum_abs > 1) then ! check for crazy values
            total = sum(abs(X(1:nc)))
            if (is_bad_num(total) .or. total > max_sum_abs) then
               if (dbg) write(*,*) 'clean1: xsum > max_sum_abs: k, xsum', k, total
               ierr = -1
               return
            end if
         end if

         where (X(1:nc) < tiny_X) X(1:nc) = tiny_X
         where (X(1:nc) > 1) X(1:nc) = 1

         total = sum(X(1:nc))
         if (abs(total-1) > X_cleanup_tol) then
            if (dbg) write(*,*) 'clean1: abs(xsum-1) > X_cleanup_tol: k, xsum', k, total
            ierr = -1
            return
         end if

         X(1:nc) = X(1:nc)/total
      end subroutine clean1
      

      subroutine get_eqn_matrix_entries( &
            nz, nzlo, nzhi, m, nc, X, dC_dX, GT, CD, AD, cell_dm, dt, &
            rhs, em1, e00, ep1)
         ! For every zone in nzlo..nzhi, fill the three nc-by-nc blocks em1, e00, ep1
         ! through get1_eqn_matrix_entries and copy X(1:nc,zone) into rhs(1:nc,zone).
         !
         ! lhs(i,k) := X(i,k) - (flow(i,k) - flow(i,k-1))*dt/cell_dm(k)
         ! em1(i,j,k) = d(lhs(i,k))/d(X(j,k-1))
         ! e00(i,j,k) = d(lhs(i,k))/d(X(j,k))
         ! ep1(i,j,k) = d(lhs(i,k))/d(X(j,k+1))
         integer, intent(in) :: nz, nzlo, nzhi, m, nc
         double precision, intent(in) :: X(m,nz), dC_dX(nc,nz)
         double precision, intent(in) :: GT(nc,nz), CD(nc,nc,nz), AD(nz), cell_dm(nz), dt
         double precision, intent(out) :: rhs(nc,nz)
         double precision, intent(out) :: em1(nc,nc,nz), e00(nc,nc,nz), ep1(nc,nc,nz)
         integer :: zone
!$OMP PARALLEL DO PRIVATE(zone)
         do zone = nzlo, nzhi
            call get1_eqn_matrix_entries( &
               zone, nz, nzlo, nzhi, m, nc, X, dC_dX, GT, CD, AD, cell_dm, dt, &
               em1, e00, ep1)
            rhs(1:nc,zone) = X(1:nc,zone)
         end do
!$OMP END PARALLEL DO
      end subroutine get_eqn_matrix_entries
      
      
      subroutine get1_eqn_matrix_entries( &
            k, nz, nzlo, nzhi, m, nc, X, dC_dX, GT, CD, AD, cell_dm, dt, &
            em1, e00, ep1)
         ! Build the three nc-by-nc blocks em1(:,:,k), e00(:,:,k), ep1(:,:,k) for zone k.
         ! Contributions tied to face k-1 are added when k > nzlo, those tied to face k
         ! when k < nzhi.  All three blocks are then scaled by dt/cell_dm(k), and 1 is
         ! added to the diagonal of e00.
         ! X and m are accepted but not used.
         integer, intent(in) :: k, nz, nzlo, nzhi, m, nc
         double precision, intent(in) :: X(m,nz), dC_dX(nc,nz)
         double precision, intent(in) :: &
            GT(nc,nz), CD(nc,nc,nz), AD(nz), cell_dm(nz), dt
         double precision, intent(out), dimension(nc,nc,nz) :: em1, e00, ep1
         integer :: row, col

         em1(:,:,k) = 0
         e00(:,:,k) = 0
         ep1(:,:,k) = 0

         if (k > nzlo) then ! do flow(:,k-1)
            do row = 1, nc
               em1(row,row,k) = em1(row,row,k) + GT(row,k-1)/2 - AD(k-1)
               e00(row,row,k) = e00(row,row,k) + GT(row,k-1)/2 + AD(k-1)
            end do
            do col = 1, nc
               do row = 1, nc
                  em1(row,col,k) = em1(row,col,k) - CD(row,col,k-1)*dC_dX(col,k-1)
                  e00(row,col,k) = e00(row,col,k) + CD(row,col,k-1)*dC_dX(col,k)
               end do
            end do
         end if

         if (k < nzhi) then ! do -flow(:,k)
            do row = 1, nc
               e00(row,row,k) = e00(row,row,k) - GT(row,k)/2 + AD(k)
               ep1(row,row,k) = ep1(row,row,k) - GT(row,k)/2 - AD(k)
            end do
            do col = 1, nc
               do row = 1, nc
                  e00(row,col,k) = e00(row,col,k) + CD(row,col,k)*dC_dX(col,k)
                  ep1(row,col,k) = ep1(row,col,k) - CD(row,col,k)*dC_dX(col,k+1)
               end do
            end do
         end if

         em1(:,:,k) = em1(:,:,k)*dt/cell_dm(k)
         e00(:,:,k) = e00(:,:,k)*dt/cell_dm(k)
         ep1(:,:,k) = ep1(:,:,k)*dt/cell_dm(k)

         do row = 1, nc
            e00(row,row,k) = e00(row,row,k) + 1
         end do
      end subroutine get1_eqn_matrix_entries

      
      subroutine copy_to_banded_matrix( &
            nz, nzlo, nzhi, nc, n, neqs, em1, e00, ep1, idiag, ldA, AB)
         ! Zero the banded matrix AB, then let copy1_to_banded_matrix place the
         ! em1/e00/ep1 entries for each block row index 1..nc.
         ! n = nzhi-nzlo+1; neqs = nc*n
         ! AB(idiag+q-v, v) = partial of lhs(q) wrt X(v)
         integer, intent(in) :: nz, nzlo, nzhi, nc, n, neqs, idiag, ldA
         double precision, intent(in) :: em1(nc, nc, nz), e00(nc, nc, nz), ep1(nc, nc, nz)
         double precision, intent(out) :: AB(ldA, neqs) ! the banded matrix
         integer :: irow
         AB(:,:) = 0
!$OMP PARALLEL DO PRIVATE(irow)
         do irow = 1, nc
            call copy1_to_banded_matrix( &
               irow, nz, nzlo, nzhi, nc, n, neqs, em1, e00, ep1, idiag, ldA, AB)
         end do
!$OMP END PARALLEL DO
      end subroutine copy_to_banded_matrix


      subroutine copy1_to_banded_matrix( &
            i, nz, nzlo, nzhi, nc, n, neqs, em1, e00, ep1, idiag, ldA, AB)
         ! Store row i of every em1, e00 and ep1 block into the banded matrix AB.
         ! A(idiag+q-v, v) = partial of lhs(q) wrt X(v)
         ! For block number blk (1..n, i.e. zone blk+nzlo-1) and column index j:
         !    em1 -> AB(i-j+nc+idiag, j+nc*(blk-2))   for blk = 2..n
         !    e00 -> AB(i-j+idiag,    j+nc*(blk-1))   for blk = 1..n
         !    ep1 -> AB(i-j-nc+idiag, j+nc*blk)       for blk = 1..n-1
         ! Entries of AB not listed above are left as they were on entry.
         integer, intent(in) :: i, nz, nzlo, nzhi, nc, n, neqs, idiag, ldA
         double precision, intent(in) :: em1(nc, nc, nz), e00(nc, nc, nz), ep1(nc, nc, nz)
         double precision, intent(out) :: AB(ldA, neqs) ! the banded matrix
         integer :: j, blk, band_row

         do j = 1, nc
            ! blocks below the block diagonal
            band_row = i - j + nc + idiag
            do blk = 2, n
               AB(band_row, j + nc*(blk-2)) = em1(i,j,blk+nzlo-1)
            end do
            ! blocks on the block diagonal
            band_row = i - j + idiag
            do blk = 1, n
               AB(band_row, j + nc*(blk-1)) = e00(i,j,blk+nzlo-1)
            end do
            ! blocks above the block diagonal
            band_row = i - j - nc + idiag
            do blk = 1, n-1
               AB(band_row, j + nc*blk) = ep1(i,j,blk+nzlo-1)
            end do
         end do
      end subroutine copy1_to_banded_matrix

      
      subroutine solve_matrix_eqn( &
            nz, nzlo, nzhi, m, nc, n, neqs, rhs, X, ldab, AB, ierr)
         ! Solve the banded system held in AB for right-hand side rhs, returning the
         ! solution in X.  This simply forwards every argument to
         ! dbgsv_solve_matrix_eqn, which is the solver currently selected.
         integer, intent(in) :: nz, nzlo, nzhi
         integer, intent(in) :: m, nc
         integer, intent(in) :: n, neqs, ldab
         double precision, intent(in) :: rhs(nc,nz)
         double precision, intent(inout) :: X(m,nz)
         double precision, intent(inout) :: AB(ldab,neqs)
         integer, intent(out) :: ierr

         call dbgsv_solve_matrix_eqn( &
            nz, nzlo, nzhi, m, nc, n, neqs, &
            rhs, X, ldab, AB, ierr)
      end subroutine solve_matrix_eqn

      
      subroutine dbgsv_solve_matrix_eqn( &
            nz, nzlo, nzhi, m, nc, n, neqs, rhs, X, lda, A, ierr)
         ! Solve the banded system A*x = rhs with the LAPACK routine dgbsv.
         !   - both bandwidths are taken as 2*nc-1; the program stops if lda does not
         !     equal ku+kl+1
         !   - A is copied into rows ku+1..ldab of a work array with kl extra leading
         !     rows (those rows are not assigned here), so A itself is not modified
         !   - rhs(1:nc, nzlo:nzlo+n-1) is packed zone by zone into a single column
         !   - on success the solution is unpacked into X(1:nc, nzlo:nzlo+n-1)
         ! ierr carries the allocate status or the dgbsv info code; X is not touched
         ! when it is nonzero.
         integer, intent(in) :: nz, nzlo, nzhi, m, nc, n, neqs, lda
         double precision, intent(in) :: rhs(nc,nz)
         double precision, intent(inout) :: X(m,nz), A(lda,neqs)
         integer, intent(out) :: ierr

         integer, parameter :: nrhs = 1
         integer :: zone, isp, row, ku, kl, ldab, ldb
         double precision, pointer, dimension(:,:) :: ab, bb
         integer, pointer :: ipiv(:)

         ierr = 0
         ku = 2*nc - 1
         kl = ku
         if (lda /= ku + kl + 1) stop 'dbgsv_solve_matrix_eqn'

         ldab = kl + lda
         ldb = neqs
         allocate(ab(ldab,neqs), bb(ldb,nrhs), ipiv(neqs), stat=ierr)
         if (ierr /= 0) return

         ab(ku+1:ldab,:) = A(1:lda,:)

         ! pack the right-hand side: species index varies fastest
         row = 0
         do zone = nzlo, nzlo+n-1
            do isp = 1, nc
               row = row + 1
               bb(row,1) = rhs(isp,zone)
            end do
         end do

         call dgbsv( neqs, kl, ku, nrhs, ab, ldab, ipiv, bb, ldb, ierr )

         if (ierr == 0) then
            ! unpack the solution in the same ordering
            row = 0
            do zone = nzlo, nzlo+n-1
               do isp = 1, nc
                  row = row + 1
                  X(isp,zone) = bb(row,1)
               end do
            end do
         end if

         deallocate(ab, bb, ipiv)
      end subroutine dbgsv_solve_matrix_eqn

      
      subroutine dbgsvx_solve_matrix_eqn( &
            nz, nzlo, nzhi, m, nc, n, neqs, rhs, X, ldab, AB, ierr)
         ! Solve the banded system AB*x = rhs with the LAPACK routine dgbsvx, called
         ! with fact = 'N' and trans = 'N'.
         !   - both bandwidths are taken as 2*nc-1; the program stops if ldab does not
         !     equal ku+kl+1
         !   - rhs(1:nc, nzlo:nzlo+n-1) is packed zone by zone into a single column
         !   - on success the solution is unpacked into X(1:nc, nzlo:nzlo+n-1)
         ! ierr carries the allocate status or the dgbsvx info code; X is not touched
         ! when it is nonzero.  rcond, ferr and berr returned by dgbsvx are discarded.
         integer, intent(in) :: nz, nzlo, nzhi, m, nc, n, neqs, ldab
         double precision, intent(in) :: rhs(nc,nz)
         double precision, intent(inout) :: X(m,nz), AB(ldab,neqs)
         integer, intent(out) :: ierr

         integer, parameter :: nrhs = 1
         integer :: zone, isp, row, ku, kl, ldafb, ldx, ldb
         character :: fact, trans, equed
         double precision :: rcond
         double precision, pointer, dimension(:) :: r, c, ferr, berr, dgbsvx_work
         double precision, pointer, dimension(:,:) :: afb, bb, dgbsvx_x
         integer, pointer :: ipiv(:), iwork(:)

         ierr = 0
         ku = 2*nc - 1
         kl = ku
         if (ldab /= ku + kl + 1) stop 'dbgsvx_solve_matrix_eqn'

         ldafb = kl + ldab
         ldb = neqs
         ldx = neqs
         allocate( &
            afb(ldafb,neqs), r(neqs), c(neqs), bb(ldb,nrhs), dgbsvx_x(ldx,nrhs), &
            ipiv(neqs), iwork(neqs), dgbsvx_work(3*neqs), ferr(nrhs), berr(nrhs), &
            stat=ierr)
         if (ierr /= 0) return

         fact = 'N'
         trans = 'N'

         ! pack the right-hand side: species index varies fastest
         row = 0
         do zone = nzlo, nzlo+n-1
            do isp = 1, nc
               row = row + 1
               bb(row,1) = rhs(isp,zone)
            end do
         end do

         call dgbsvx( fact, trans, neqs, kl, ku, nrhs, AB, ldab, afb, &
                      ldafb, ipiv, equed, r, c, bb, ldb, dgbsvx_x, ldx, &
                      rcond, ferr, berr, dgbsvx_work, iwork, ierr )

         if (ierr == 0) then
            ! unpack the solution in the same ordering
            row = 0
            do zone = nzlo, nzlo+n-1
               do isp = 1, nc
                  row = row + 1
                  X(isp,zone) = dgbsvx_x(row,1)
               end do
            end do
         end if

         deallocate( &
            afb, r, c, bb, dgbsvx_x, &
            ipiv, iwork, dgbsvx_work, ferr, berr)
      end subroutine dbgsvx_solve_matrix_eqn


      subroutine report_bad_block( &
            bad_zone, info, nvar, bad_mtx, lrpar, rpar, lipar, ipar)
         ! Hook for reporting a bad matrix block.  It currently does nothing:
         ! no argument is read or modified.  The diagnostic calls below are
         ! kept commented out.
         integer, intent(in) :: bad_zone, info, nvar
         integer, intent(in) :: lrpar, lipar
         double precision, intent(in) :: bad_mtx(nvar, nvar)
         double precision, intent(inout) :: rpar(lrpar)
         integer, intent(inout) :: ipar(lipar)
         !write(*,*) 'bad block', bad_zone
         !call show_bad_block( &
         !   bad_zone, info, nvar, bad_mtx, lrpar, rpar, lipar, ipar)
         !stop 'report_bad_block'
      end subroutine report_bad_block
                  
                  
!********************************************************************

      subroutine get_coulomb_logs(m, a, z, x, rho, t, cl)
      ! Coulomb logarithms cl(i,j) for every pair of the m species at one point.
      !   m      number of species; entry m is treated specially (its concentration
      !          is fixed to 1 and it is left out of the ion sums)
      !   a, z   per-species values entering as x/a and z**2 (see the steps below)
      !   x      mass fractions (only the first m-1 are read)
      !   rho    mass density
      !   t      temperature
      !   cl     output, symmetric in (i,j) because only abs(z(i)*z(j)) enters
      ! All numeric constants are written as double precision literals so that no
      ! single-precision rounding leaks into the double precision arithmetic.
      ! NOTE(review): z(i) == 0 for any species gives a division by zero in xij.
      implicit none
      integer, intent(in) :: m
      double precision, intent(in) :: a(m), z(m), x(m), rho, t
      double precision, intent(out) :: cl(m, m)

      integer :: i, j
      double precision :: zxa, ac, ni, cz, xij, ne, ao, lambdad, lambda, c(m)
! calculate concentrations from mass fractions:
      zxa = sum(z(1:m-1)*x(1:m-1)/a(1:m-1))
      c(1:m-1) = x(1:m-1)/(a(1:m-1)*zxa)
      c(m) = 1d0
! calculate density of electrons (ne) from mass density (rho):
      ac = sum(a(1:m)*c(1:m))
      ne = rho/(1.6726d-24*ac)
! calculate interionic distance (ao):
      ni = sum(c(1:m-1)*ne)
      ao = (0.23873d0/ni)**(1d0/3d0)
! calculate debye length (lambdad):
      cz = sum(c(1:m)*z(1:m)**2)
      lambdad = 6.9010d0*sqrt(t/(ne*cz))
! calculate lambda to use in coulomb logarithm:
      lambda = max(lambdad, ao)
! calculate coulomb logarithms:
!$OMP PARALLEL DO PRIVATE(i, j, xij)
      do j=1, m
         do i=1, m
            xij = 2.3939d3*t*lambda/abs(z(i)*z(j))
            cl(i, j) = 0.81245d0*log(1d0 + 0.18769d0*xij**1.2d0)
         end do
      end do
!$OMP END PARALLEL DO
      end subroutine get_coulomb_logs
      
!********************************************************************



!*************************************************************
! Original of this routine was written by Anne A. Thoul, at the Institute
! for Advanced Study, Princeton, NJ 08540.
! See Thoul et al., Ap.J. 421, p. 828 (1994)
!*************************************************************
! This routine inverts the Burgers equations.
!
! The system contains N equations with N unknowns. 
! The equations are: the M momentum equations, 
!                    the M energy equations, 
!                    two constraints: the current neutrality 
!                                     the zero fluid velocity.
! The unknowns are: the M diffusion velocities,
!                   the M heat fluxes,
!                   the electric field E
!                   the gravitational force g.
!
!**************************************************
      subroutine do1_solve_thoul(n,m,ihe,a,z,x,c,cl,ap,at,ax,e_field,ierr)

! Assemble the burgers equations plus the two constraints as one n x n
! linear system, LU-factor it with LAPACK dgetrf, and back-substitute
! (dgetrs) for each right-hand side to obtain the diffusion coefficients.
!
! the parameter m is the number of fluids considered (ions+electrons)
! the parameter n is the number of equations (2*m+2).
! the parameter ihe is the index of the species that is eliminated from
! the concentration-gradient terms (see the gamma coefficients below).
!
! the vectors a,z and x contain the atomic mass numbers, 
! the charges (ionization), and the mass fractions, of the elements.
! note: since m is the electron fluid, its mass and charge must be
!      a(m)=m_e/m_u
!      z(m)=-1.
! note: x is accepted for interface compatibility but is not referenced
!      in this routine; the concentrations are taken from c.
!
! the vector c contains the concentrations.
! the array cl contains the values of the coulomb logarithms.
! the vectors ap, at, and the array ax contain the results for the
! diffusion coefficients.
! e_field is set to element (n-1,1) of the solved gamma array
! (NOTE(review): row n-1 is the electric-field unknown according to the
! header comment above this routine; confirm that column 1 is the one the
! caller intends).
! ierr is the LAPACK info code: 0 on success.  on a nonzero ierr the
! routine returns immediately and ap, at, ax, e_field are NOT defined.

      implicit none

      integer, intent(in) :: m,n,ihe
      double precision, intent(in) :: a(m),z(m),c(m),cl(m,m)
      double precision, intent(in) :: x(m)
      double precision, intent(out) :: ap(m),at(m),ax(m,m),e_field
      integer, intent(out) :: ierr

      double precision :: cc,ac,xx(m,m),y(m,m),yy(m,m),k(m,m)
      double precision :: alpha(n),nu(n),gamma(n,n),delta(n,n)
      double precision :: ko
      integer :: i,j,l,indx(n)

! cc is the total concentration: cc=sum(c_s)
! ac is proportional to the mass density: ac=sum(a_s c_s)
! the arrays xx,y,yy and k are various parameters which appear in 
! burgers equations.
! the vectors and arrays alpha, nu, gamma, and delta represent
! the "right- and left-hand-sides" of burgers equations, and later 
! the diffusion coefficients.
! indx holds the pivot indices produced by dgetrf.
      
! initialize:

      ierr = 0
      ko = 2d0
      indx(:) = 0

! calculate cc and ac:
      
      cc=sum(c(:))
      ac=dot_product(a(:),c(:))

! calculate the coefficients of the burgers equations

      do i=1,m
         do j=1,m
            xx(i,j)=a(j)/(a(i)+a(j))
            y(i,j)=a(i)/(a(i)+a(j))
            yy(i,j)=3d0*y(i,j)+1.3d0*xx(i,j)*a(j)/a(i)
            k(i,j)=cl(i,j)*sqrt(a(i)*a(j)/(a(i)+a(j)))*c(i)*c(j)*z(i)**2*z(j)**2
         end do
      end do

! write the burgers equations and the two constraints as
! alpha_s dp + nu_s dt + sum_t(not ihe or m) gamma_st dc_t 
!                     = sum_t delta_st w_t

! rows 1..m: momentum equations (right-hand sides)
      do i=1,m
         alpha(i)=c(i)/cc
         nu(i)=0d0
         gamma(i,1:n)=0d0
         do j=1,m
            ! columns ihe and m stay zero: those two species carry no
            ! independent concentration-gradient term
            if ((j /= ihe).and.(j /= m)) then
               gamma(i,j)=-c(j)/cc+c(ihe)/cc*z(j)*c(j)/z(ihe)/c(ihe)
               if (j == i) then
                  gamma(i,j)=gamma(i,j)+1d0
               end if
               if (i == ihe) then
                  gamma(i,j)=gamma(i,j)-z(j)*c(j)/z(ihe)/c(ihe)
               end if
               gamma(i,j)=gamma(i,j)*c(i)/cc
            end if
         end do
      end do
      
! rows m+1..n-2: energy equations (right-hand sides)
      do i=m+1,n-2
         alpha(i)=0d0
         nu(i)=2.5d0*c(i-m)/cc
         gamma(i,1:n)=0d0
      end do
      
! rows n-1 and n: the two constraints have zero right-hand sides
      alpha(n-1)=0d0
      nu(n-1)=0d0
      gamma(n-1,1:n)=0d0
      
      alpha(n)=0d0
      nu(n)=0d0
      gamma(n,1:n)=0d0
      
      delta(:,:) = 0d0
      
! left-hand side, rows 1..m (momentum equations)
      do i=1,m
         ! columns 1..m multiply the diffusion velocities
         do j=1,m
            if (j == i) then
               do l=1,m
                  if (l /= i) then
                     delta(i,j)=delta(i,j)-k(i,l)
                  end if
               end do
            else
               delta(i,j)=k(i,j)
            end if
         end do
         
         ! columns m+1..n-2 multiply the heat fluxes
         do j=m+1,n-2
            if (j-m == i) then
               do l=1,m
                  if (l /= i) then
                     delta(i,j)=delta(i,j)+0.6d0*xx(i,l)*k(i,l)
                  end if
               end do
            else
               delta(i,j)=-0.6d0*y(i,j-m)*k(i,j-m)
            end if
         end do
         
         ! column n-1 multiplies the electric field
         delta(i,n-1)=c(i)*z(i)
         
         ! column n multiplies the gravitational force
         delta(i,n)=-c(i)*a(i)
      end do
      
! left-hand side, rows m+1..n-2 (energy equations)
      do i=m+1,n-2
         do j=1,m
            if (j == i-m) then
               do l=1,m
                  if (l /= i-m) then
                     delta(i,j)=delta(i,j)+1.5d0*xx(i-m,l)*k(i-m,l)
                  end if
               end do
            else
               delta(i,j)=-1.5d0*xx(i-m,j)*k(i-m,j)
            end if
         end do
         
         do j=m+1,n-2
            ! j == i selects the diagonal of the heat-flux block
            if (j == i) then
               do l=1,m
                  if (l /= i-m) then
                     delta(i,j)=delta(i,j)-y(i-m,l)*k(i-m,l)*(1.6d0*xx(i-m,l)+yy(i-m,l))
                  end if
               end do
               delta(i,j)=delta(i,j)-0.8d0*k(i-m,i-m)
            else
               delta(i,j)=2.7d0*k(i-m,j-m)*xx(i-m,j-m)*y(i-m,j-m)
            end if
         end do
         
         delta(i,n-1:n)=0d0
      end do
      
! left-hand side, row n-1: current neutrality
      delta(n-1,1:m)=c(1:m)*z(1:m)
      delta(n-1,m+1:n)=0d0
      
! left-hand side, row n: zero fluid velocity
      delta(n,1:m)=c(1:m)*a(1:m)
      delta(n,m+1:n)=0d0
      
! This critical section should not be necessary, but as of 2009/04/17 it is.  ;-(
! Someday try it again. The error was with ifort 11.0 20090131
! ----  as of July 17, 2010, seems to be okay.  ifort 11.1.088 on mac (BP)
!x$OMP CRITICAL (diffusion)
      call dgetrf(n, n, delta, n, indx, ierr)
!x$OMP END CRITICAL (diffusion)
      if (ierr /= 0) return
      
      call dgetrs( 'n', n, 1, delta, n, indx, alpha, n, ierr )
      if (ierr /= 0) return
      
      call dgetrs( 'n', n, 1, delta, n, indx, nu, n, ierr )
      if (ierr /= 0) return
      
! only columns 1..m of gamma were ever given nonzero entries, and only
! gamma(1:m,1:m) and gamma(n-1,1) are read below, so the remaining
! all-zero right-hand sides need not be solved.  each column is a
! contiguous section and is overwritten in place with its solution.
      do j=1,m
         call dgetrs( 'n', n, 1, delta, n, indx, gamma(:,j), n, ierr )
         if (ierr /= 0) return
      end do

      ap(1:m)=alpha(1:m)*ko*ac*cc
      at(1:m)=nu(1:m)*ko*ac*cc
      ax(1:m,1:m)=gamma(1:m,1:m)*ko*ac*cc
      e_field = gamma(n-1,1)

      end subroutine do1_solve_thoul


      end module mod_diffusion_support

