! ***********************************************************************
!
!   Copyright (C) 2010  Bill Paxton
!
!   MESA is free software; you can use it and/or modify
!   it under the combined terms and restrictions of the MESA MANIFESTO
!   and the GNU General Library Public License as published
!   by the Free Software Foundation; either version 2 of the License,
!   or (at your option) any later version.
!
!   You should have received a copy of the MESA MANIFESTO along with
!   this software; if not, it is available at the mesa website:
!   http://mesa.sourceforge.net/
!
!   MESA is distributed in the hope that it will be useful,
!   but WITHOUT ANY WARRANTY; without even the implied warranty of
!   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
!   See the GNU Library General Public License for more details.
!
!   You should have received a copy of the GNU Library General Public License
!   along with this software; if not, write to the Free Software
!   Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
!
!
! ***********************************************************************

      module mod_diffusion

      use const_def
      use chem_def
      use utils_lib, only:is_bad_num, has_bad_num, return_nan
      use alert_lib, only:alert
      use mod_diffusion_support
      use mod_diffusion_procs

      implicit none
      
      
      ! finite difference equation for element diffusion.
      
      ! X(i,k) = X_old(i,k) + (flow(i,k) - flow(i,k-1))*dt/cell_dm(k)
      
      ! flow(i,k) is the outward flow of class i across the boundary between points k and k+1 (g/s).
      ! flow(i,k) is the flow into cell k from k+1; flow(i,k-1) is the flow out of cell k to k-1.
      ! since the scheme is based on flows between cells, it is conservative.
      
      ! flow combines gravi-thermal settling and classical diffusion.
      ! the flow expression is linearized in X by evaluating the flow coefficients using X_old.
      ! hence, the scheme is semi-implicit. 
      
      ! the matrix form of the equation is M*X = X_old, i.e. for each class i and cell k
      ! X(i,k) - (flow(i,k) - flow(i,k-1))*dt/cell_dm(k) = X_old(i,k)
      

      contains
      
      
      ! Solve element diffusion (gravi-thermal settling plus classical diffusion)
      ! for nc isotope classes over the time interval total_time.
      !
      ! The active region nzlo..nzhi is first narrowed: nzlo moves to the first
      ! cell where T > T_full_off and X(he4) > Y_full_off, and nzhi moves to just
      ! above the first cell with gamma >= gamma_full_off. Cells 1..nzlo are then
      ! merged into a single mass-weighted cell. Each step evaluates the flow
      ! coefficients from the current composition and solves a banded linear
      ! system for the new class mass fractions. A step is retried with a smaller
      ! dt if the solve fails, the mass fractions cannot be cleaned up, the total
      ! class masses drift beyond atol/rtol, or the per-cell change is too large.
      !
      ! On success (ierr == 0) xa is updated in place and X_final holds the new
      ! class mass fractions. ierr = -1 if m /= nc+1, if h1 or he4 is missing
      ! from net_iso, or if total_time could not be completed within
      ! maxsteps_allowed steps. If the narrowed region has n <= 1 cells the
      ! routine returns with ierr = 0 without modifying xa or the output arrays.
      ! steps_used is the number of steps attempted (0 if none were).
      subroutine do_solve_diffusion( &
            nz, species, nc, m, class, class_chem_id, net_iso, &
            abar, ye, free_e, mstar, dm_in, cell_mass_in, &
            T, lnT, rho, lnd, r, dlnP_dm, dlnT_dm, dlnRho_dm, L, &
            total_time, maxsteps_allowed, &
            calculate_ionization, typical_charge, &
            atol, rtol, AD_factor, AD_velocity, vgt_max, &
            gamma, gamma_full_on, gamma_full_off, T_full_on, T_full_off, &
            X_full_on, X_full_off, Y_full_on, Y_full_off, diffusion_factor, &
            xa, steps_used, total_num_retries, nzlo, nzhi, X_init, X_final, &
            AP, AT, AX, dlnP_dr_mid, dlnT_dr_mid, dlnRho_dr_mid, dlnC_dr_mid, v, vgt, ierr )
            
         integer, intent(in) :: nz, species, nc, m 
            ! nc = number of classes of isotopes
            ! m = nc+1
         integer, intent(in) :: class(species), class_chem_id(nc), net_iso(:)
            ! class(i) = class number for species i. class numbers from 1 to nc
            ! class_chem_id(j) = isotope id number from chem_def for "typical" member of class j
         double precision, intent(in), dimension(nz) :: &
            abar, ye, free_e, gamma, dm_in, cell_mass_in, &
            T, lnT, rho, lnd, r, dlnP_dm, dlnT_dm, dlnRho_dm, L
         double precision, intent(in) :: &
            total_time, mstar, atol, rtol, AD_factor, AD_velocity, &
            vgt_max, gamma_full_on, gamma_full_off, &
            X_full_on, X_full_off, Y_full_on, Y_full_off, T_full_on, T_full_off,  diffusion_factor(nc)
         integer, intent(in) :: maxsteps_allowed
         logical, intent(in) :: calculate_ionization 
         double precision, intent(inout), dimension(nc,nz) :: typical_charge
         double precision, intent(inout) :: xa(species,nz) ! mass fractions
         double precision, intent(out), dimension(nc,nz) :: &
            X_init, X_final, v, vgt
         double precision, intent(out), dimension(nz) :: dlnP_dr_mid, dlnT_dr_mid, dlnRho_dr_mid
         double precision, intent(out), dimension(m,nz) :: AP, AT, dlnC_dr_mid
         double precision, intent(out), dimension(m,m,nz) :: AX
         integer, intent(out) :: steps_used, total_num_retries, ierr
         integer, intent(inout) ::  nzlo, nzhi !upper and lower bounds on region

         integer :: j, k, retry_count, idiag, n, neqs, ku, kl, ldab, &
            kmax, h1, he4, nbound, step_number
         double precision :: dt, min_dt, time, mtotal, xtotal_init(nc), xtotal(nc), &
            dx_max, dx_avg, dx_max_allowed
         logical :: converged
         
         integer, parameter :: max_retries = 25
         double precision, parameter :: &
            dt_retry_factor = 0.5d0, dt_max_factor = 2d0, dt_min_factor = 0.75d0, &
            tiny_X = 1d-50, tiny_C = 1d-50, max_sum_abs = 10, X_cleanup_tol = 1d-2, &
            dx_avg_target = 0.975d0
         double precision, dimension(:), pointer :: &
            dm, cell_dm, rho_mid, T_mid, four_pi_r2_rho_mid, &
            gamma_T_limit_coeffs, X_Y_limit_coeffs, A, E_field, AD
         double precision, dimension(:,:), pointer :: &
            Z, X, X_start_step, X_mid, C, C_mid, dC_dr_mid, dC_dX, GT, rhs, AB
         double precision, dimension(:,:,:), pointer :: &
            clg_mid, sigma_lnC, CD, em1, e00, ep1            

            
         1 format(a40,99(1pe26.16,1x))
         2 format(a40,i6,99(1pe26.16,1x))
                  
         if (m /= nc+1) then
            ierr = -1
            write(*,*) 'm /= nc+1'
            return
         end if
         
         ierr = 0
         steps_used = 0
         total_num_retries = 0
         
         h1 = net_iso(ih1)
         if (h1 == 0) then
            ierr = -1; write(*,*) 'isos must include h1 for diffusion'; return
         end if
         
         he4 = net_iso(ihe4)
         if (he4 == 0) then
            ierr = -1; write(*,*) 'isos must include he4 for diffusion'; return
         end if

         ! reset nzlo if necessary: first cell that is both hot enough and
         ! helium-rich enough for diffusion to be active
         nbound=nzlo
         do k=nzlo,nzhi
            if (T(k) > T_full_off .and. xa(he4,k)> Y_full_off) then
               nbound=k; exit
            endif
         enddo
         nzlo=nbound

         ! reset nzhi if necessary: stop just above the first cell whose
         ! gamma reaches gamma_full_off
         nbound=nzhi
         do k=nzlo,nzhi
            if (gamma(k) >= gamma_full_off) then
               nbound=k-1; exit
            endif
         enddo
         nzhi=nbound

         n = nzhi-nzlo+1
         
         if (dbg) write(*,*) '  nz', nz
         if (dbg) write(*,*) 'nzlo', nzlo
         if (dbg) write(*,*) 'nzhi', nzhi
         if (dbg) write(*,*) '   n', n
         if (n <= 1) return

         ! banded storage for the n*nc unknowns; each cell couples to its
         ! neighbors on both sides, giving ku = kl = 2*nc-1
         neqs = n*nc
         idiag = 2*nc
         ku = 2*nc - 1
         kl = ku
         ldab = ku + kl + 1

         
         call alloc(ierr)
         if (ierr /= 0) return
         
         dm(1:nz) = dm_in(1:nz)
         
         ! get info that is independent of composition         
         
         call get_struct_mid( &
            nz, nzlo, nzhi, r, rho, T, rho_mid, T_mid, four_pi_r2_rho_mid)
         
         call get_smooth_PTRho_gradients( &
            nz, nzlo, nzhi, dlnP_dm, dlnT_dm, dlnRho_dm, four_pi_r2_rho_mid, &
            dlnP_dr_mid, dlnT_dr_mid, dlnRho_dr_mid)
         
         call get_smooth_Z( &
            nz, nzlo, nzhi, nc, class_chem_id, m, abar, free_e, T, lnT, rho, lnd, &
            calculate_ionization, Z, typical_charge)
         
         call get_gamma_T_limit_coeffs( &
            nz, nzlo, nzhi, gamma, gamma_full_on, gamma_full_off, &
            T, T_full_on, T_full_off, gamma_T_limit_coeffs)

         call get_A_and_X( &
            nz, nzlo, nzhi, species, nc, m, class, class_chem_id, Z, xa, tiny_X, A, X)
         X_init(1:nc,1:nz) = X(1:nc,1:nz)

         if (.false.) then
            call test_new_diffusion
            stop 'done test_new_diffusion'
         end if
         
         ! dm(k) is the mass between points k and k+1
         ! cell_dm(k) is the mass associated with point k
         cell_dm(1:nz) = cell_mass_in(1:nz)

         ! combine cells 1 to nzlo into one mass-weighted cell at nzlo,
         ! then renormalize its class mass fractions to sum to 1
         forall (j=1:m) X(j,nzlo) = dot_product(cell_dm(1:nzlo),X(j,1:nzlo))
         cell_dm(nzlo) = sum(cell_dm(1:nzlo))
         X(1:m,nzlo) = X(1:m,nzlo)/cell_dm(nzlo)
         X(1:nc,nzlo) = X(1:nc,nzlo)/sum(X(1:nc,nzlo))
            
         ! mass-averaged class abundances, used to check conservation each step
         mtotal = sum(cell_dm(nzlo:nzhi))
         forall (j=1:nc) xtotal_init(j) = &
            dot_product(cell_dm(nzlo:nzhi),X(j,nzlo:nzhi))/mtotal
         
         dx_max_allowed = 0.1d0/atol
         ! real arithmetic avoids integer overflow in 1000*maxsteps_allowed, and
         ! max(1,...) avoids a division by zero when maxsteps_allowed <= 0
         min_dt = total_time/(1000d0*max(1,maxsteps_allowed))
         time = 0
         dt = total_time
         
         ! a private loop counter is used so that steps_used is exactly the
         ! number of steps attempted (a DO variable is left one past its end value)
         steps_loop: do step_number = 1, maxsteps_allowed
            steps_used = step_number
            
            call get_C(nz, nzlo, nzhi, nc, m, Z, A, X, C, dC_dX)
            
            call get_electron_mass_fractions(nz, nzlo, nzhi, nc, m, A, C, X)
         
            call get_dC_dr_mid(m, nz, nzlo, nzhi, C, four_pi_r2_rho_mid, dm, dC_dr_mid)
            
            call get_middle_C_and_X( &
               nz, m, nzlo, nzhi, C, X, dC_dr_mid, tiny_C, C_mid, X_mid, dlnC_dr_mid)
         
            call eval_coulomb_logs(nz, m, nzlo, nzhi, A, Z, X_mid, rho_mid, T_mid, clg_mid)
         
            call get_X_Y_limit_coeffs( &
               nz, m, nzlo, nzhi, class(h1), class(he4), X_mid, &
               X_full_on, X_full_off, Y_full_on, Y_full_off, X_Y_limit_coeffs)

            call eval_gradient_coeffs(nz, m, class(he4), nzlo, nzhi, Y_full_off, &
               A, Z, X_mid, C_mid, clg_mid, AP, AT, AX, E_field, ierr)
            if (failed('eval_gradient_coeffs')) return

            call eval_velocities( &
               nz, m, nc, nzlo, nzhi, AP, AT, AX, rho_mid, T_mid, &
               dlnP_dr_mid, dlnT_dr_mid, dlnRho_dr_mid, X_mid, dlnC_dr_mid, &
               gamma_T_limit_coeffs, X_Y_limit_coeffs, vgt_max, &
               diffusion_factor, v, vgt, sigma_lnC)
            
            ! flow coefficients are evaluated once per step from the start-of-step
            ! composition; retries below only change dt
            call get_flow_coeffs( &
               nz, nc, m, nzlo, nzhi, v, vgt, vgt_max, sigma_lnC, four_pi_r2_rho_mid, &
               total_time, dm, X, X_mid, C, C_mid, tiny_C, &
               gamma_T_limit_coeffs, X_Y_limit_coeffs, AD_factor, AD_velocity, GT, CD, AD)
            
            converged = .false.
            retry_loop: do retry_count = 1, max_retries

               if (dbg) then
                  3 format(10x,2i8,99(1pe26.16,1x))
                  4 format(10x,2a8,99(a26,1x))
                  if ((steps_used == 1 .or. mod(steps_used,25)==0) .and. retry_count == 1) then
                     write(*,*)
                     write(*,4) 'steps', 'retries', &
                        'dt/secyer', 'time/secyer', '(total_time-time)/secyer'
                  end if
                  write(*,3) steps_used, total_num_retries, &
                     dt/secyer, time/secyer, (total_time-time)/secyer
               end if

               if (dt < min_dt) then
                  if (dbg) write(*,1) 'dt < min_dt: min_dt', min_dt
                  exit steps_loop
               end if
            
               if (retry_count == 1) then ! save
                  X_start_step(:,nzlo:nzhi) = X(:,nzlo:nzhi)
               else ! restore
                  X(:,nzlo:nzhi) = X_start_step(:,nzlo:nzhi)
                  total_num_retries = total_num_retries+1
               end if
            
               call get_eqn_matrix_entries(nz, nzlo, nzhi, m, nc,  &
                  X, dC_dX, GT, CD, AD, cell_dm, dt, rhs, em1, e00, ep1)
               
               call copy_to_banded_matrix( &
                  nz, nzlo, nzhi, nc, n, neqs, em1, e00, ep1, idiag, ldab, AB)
               
               call solve_matrix_eqn( &
                  nz, nzlo, nzhi, m, nc, n, neqs, rhs, X, ldab, AB, ierr)
               if (ierr /= 0) then
                  if (dbg) write(*,*) 'retry because failed to solve matrix eqn'
                  ierr = 0; dt = dt*dt_retry_factor; cycle retry_loop
               end if
               
               ! check for excessive error in any sum of X's == 1.0
               call do_clean_up_fractions( &
                  nz, nzlo, nzhi, m, nc, X, tiny_X, max_sum_abs, X_cleanup_tol, ierr)
               if (ierr /= 0) then
                  if (dbg) write(*,*) 'retry because failed in do_clean_up_fractions'
                  ierr = 0; dt = dt*dt_retry_factor; cycle retry_loop
               end if
               
               ! check for excessive error in conservation of abundances
               forall (j=1:nc) &
                  xtotal(j) = dot_product(cell_dm(nzlo:nzhi),X(j,nzlo:nzhi))/mtotal
               if (.not. check_xtotals(nc, atol, rtol, xtotal_init, xtotal)) then
                  if (dbg) write(*,*) 'retry because failed in conservation of abundances'
                  dt = dt*dt_retry_factor; cycle retry_loop
               end if
            
               ! check for excessive change in any X in any cell
               call get_change_info( &
                  nz, nzlo, nzhi, m, nc, atol, rtol, X, X_start_step, dx_max, dx_avg, kmax)
               if (dx_max > dx_max_allowed .or. dx_avg > 1) then
                  dt = dt*dt_retry_factor; cycle retry_loop
               end if
               
               converged = .true. 
               exit retry_loop
            
            end do retry_loop
            
            if (.not. converged) exit steps_loop
            
            ! Advance time. When this step covered all of the remaining interval,
            ! snap to total_time exactly: time + (total_time - time) can round to
            ! just below total_time, leaving a sliver smaller than min_dt that
            ! would otherwise be reported below as a failure to finish.
            if (dt >= total_time - time) then
               time = total_time
            else
               time = time + dt
            end if
            if (time >= total_time) exit steps_loop
            
            ! grow/shrink the next step toward the target average change
            dt = dt*min(dt_max_factor, max(dt_min_factor, dx_avg_target/max(1d-50,dx_avg)))             
            if (time+dt > total_time) dt = total_time - time
         
         end do steps_loop
         
         if (time < total_time) ierr = -1 ! failed to finish
         
         if (ierr == 0) then
            call set_new_xa(nz, nzlo, nzhi, species, nc, m, class, X_init, X, xa)
            call do_smooth_where_h_rich(nz, nzlo, nzhi, species, net_iso, xa)
            X_final(1:nc,nzlo:nzhi) = X(1:nc,nzlo:nzhi)
            X_final(1:nc,nzhi+1:nz) = X_init(1:nc,nzhi+1:nz)
            ! cells 1..nzlo were merged into nzlo, so they share its result
            forall (j=1:nc) X_final(j,1:nzlo-1) = X_final(j,nzlo)
         end if
         
         call dealloc
         
         
         contains
         
         
         ! Debug dump of the structure and composition for an external test of
         ! the diffusion solver; requires the 9-class set checked below.
         subroutine test_new_diffusion
            real*8 :: dm_dk, dm_dr, d_logP_dr, d_logT_dr, d_logRho_dr
            integer :: k
            
            include 'formats.dek'

            if (nc /= 9) then
               write(*,*) 'need 9 classes for new diffusion'
               stop
            end if
            if (class_chem_id(1) /= ih1) then
               write(*,*) 'class_chem_id(1) /= ih1 for new diffusion'
               stop
            end if
            if (class_chem_id(2) /= ihe4) then
               write(*,*) 'class_chem_id(2) /= ihe4 for new diffusion'
               stop
            end if
            if (class_chem_id(3) /= ic12) then
               write(*,*) 'class_chem_id(3) /= ic12 for new diffusion'
               stop
            end if
            if (class_chem_id(4) /= in14) then
               write(*,*) 'class_chem_id(4) /= in14 for new diffusion'
               stop
            end if
            if (class_chem_id(5) /= io16) then
               write(*,*) 'class_chem_id(5) /= io16 for new diffusion'
               stop
            end if
            if (class_chem_id(6) /= ine20) then
               write(*,*) 'class_chem_id(6) /= ine20 for new diffusion'
               stop
            end if
            if (class_chem_id(7) /= img24) then
               write(*,*) 'class_chem_id(7) /= img24 for new diffusion'
               stop
            end if
            if (class_chem_id(8) /= ife52) then
               write(*,*) 'class_chem_id(8) /= ife52 for new diffusion'
               stop
            end if
            if (class_chem_id(9) /= ini56) then
               write(*,*) 'class_chem_id(9) /= ini56 for new diffusion'
               stop
            end if

            write(*,*) nz, nzlo, nzhi
            do k=1,nz
               dm_dk = cell_mass_in(k)
               dm_dr = 4*pi*r(k)**2*rho(k)
               d_logP_dr = dlnP_dm(k)*dm_dr/ln10
               d_logT_dr = dlnT_dm(k)*dm_dr/ln10
               d_logRho_dr = dlnRho_dm(k)*dm_dr/ln10
               write(*,*) k
               write(*,'(1pe26.16)') &
                  X(1:nc,k), T(k), rho(k), L(k), r(k), dm_dk, &
                  d_logP_dr, d_logT_dr, d_logRho_dr
            end do
            write(*,*) -1
            write(*,*)
         
         end subroutine test_new_diffusion
         
         
         ! Allocate all work arrays; ierr is the allocate stat.
         subroutine alloc(ierr)
            integer, intent(out) :: ierr
            ierr = 0
            allocate( &
               dm(nz), cell_dm(nz), rho_mid(nz), T_mid(nz), four_pi_r2_rho_mid(nz), &
               gamma_T_limit_coeffs(nz), X_Y_limit_coeffs(nz), A(m), E_field(nz), &
               Z(m,nz), X(m,nz), X_start_step(m,nz), X_mid(m,nz), &
               C(m,nz), C_mid(m,nz), dC_dr_mid(m,nz), dC_dX(nc,nz), &
               clg_mid(m,m,nz), sigma_lnC(nc,nc,nz), &
               AD(nz), GT(nc,nz), CD(nc,nc,nz), rhs(nc,nz), AB(ldab,neqs), &
               em1(nc,nc,nz), e00(nc,nc,nz), ep1(nc,nc,nz), &
               stat=ierr)
         end subroutine alloc
         
         ! Release every work array allocated by alloc.
         subroutine dealloc
            deallocate( &
               dm, cell_dm, rho_mid, T_mid, four_pi_r2_rho_mid, &
               gamma_T_limit_coeffs, X_Y_limit_coeffs, A, E_field, &
               Z, X, X_start_step, X_mid, C, C_mid, dC_dr_mid, dC_dX, &
               clg_mid, sigma_lnC, &
               AD, GT, CD, rhs, AB, em1, e00, ep1)
         end subroutine dealloc
         
         ! True (after freeing the work arrays) if the host's ierr is nonzero.
         logical function failed(str)
            character (len=*), intent(in) :: str
            failed = .false.
            if (ierr == 0) return
            if (dbg) write(*,*) 'failed in ' // trim(str), ierr
            call dealloc
            failed = .true.
         end function failed
      
      end subroutine do_solve_diffusion    
      
                        
      end module mod_diffusion

