! ***********************************************************************
!
!   Copyright (C) 2008  Bill Paxton
!
!   This file is part of MESA.
!
!   MESA is free software; you can redistribute it and/or modify
!   it under the terms of the GNU General Library Public License as published
!   by the Free Software Foundation; either version 2 of the License, or
!   (at your option) any later version.
!
!   MESA is distributed in the hope that it will be useful,
!   but WITHOUT ANY WARRANTY; without even the implied warranty of
!   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
!   GNU Library General Public License for more details.
!
!   You should have received a copy of the GNU Library General Public License
!   along with this software; if not, write to the Free Software
!   Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
!
! ***********************************************************************

      module test_mix
      ! Test driver for the implicit diffusive mixing of one species over a
      ! fixed range of cells.  read_mix_data loads cell data and initial mass
      ! fractions from 'mix_test.data'; do_test_mix integrates the mixing
      ! equation for one species over one time step with num_lib's isolve,
      ! using an analytic tridiagonal jacobian, and write_mix_results writes
      ! the outcome to 'plot_data/mix.data'.

      use const_def, only: dp

      implicit none

      ! when true, internal consistency failures stop the program instead of
      ! returning an error flag to the solver
      logical, parameter :: stop_for_bugs = .true.


      ! number of mass-fraction columns in the input file
      integer, parameter :: species = 7

      ! index into ipar: the species being mixed
      integer, parameter :: ipar_j = 1

      integer, parameter :: mix_lipar = 3

      ! indices into rpar
      integer, parameter :: rpar_ytotal = 1 ! dm-weighted total of y at the start (held fixed)
      integer, parameter :: rpar_x = 2      ! latest x seen by mix_solout
      integer, parameter :: rpar_xend = 3   ! end of the integration interval (del_t)

      integer, parameter :: mix_lrpar = 3


      ! Module-level test state.  read_mix_data allocates all arrays and fills
      ! T, rho, lnT, lnRho, dm, r, vc, cdc and initial_xa (abar, zbar, z2bar,
      ! ye and dxa_dt are allocated but not read); free_mix_data releases them.
      ! nfcn/njac count right-hand-side and jacobian evaluations.
      integer :: nz, nzlo, nzhi, nfcn, njac
      real(dp), dimension(:), pointer :: abar, zbar, z2bar, ye, &
         T, rho, lnT, lnRho, dm, r, vc, cdc
      real(dp), dimension(:,:), pointer :: initial_xa, xa, dxa_dt

      contains


      subroutine do_test_mix(mix_solver)
         ! Run the mixing test: read the data, mix species 1 over cells
         ! nzlo..nzhi for one time step with the given isolve solver, report
         ! the evaluation counts, write the results and free the data.
         ! Stops with exit code 1 on any failure.
         integer, intent(in) :: mix_solver ! solver id passed through to isolve

         integer :: n, nsteps, j, lrd, lid, liwork, lwork, ierr, maxsteps_allowed
         real(dp) :: del_t, mix_atol, mix_rtol

         ierr = 0
         call read_mix_data(del_t, ierr)
         if (ierr /= 0) then
            write(*,*) 'read_mix_data ierr', ierr
            stop 1
         end if

         maxsteps_allowed = 100
         ! the test is run with tight tolerances
         mix_atol = 1d-8
         mix_rtol = 1d-8

         xa = initial_xa

         nzlo = 164
         nzhi = 226
         ! the mixing range is hard-coded, so make sure the data covers it
         if (nzlo < 1 .or. nzhi > nz .or. nzlo > nzhi) then
            write(*,*) 'mixing range outside the data', nzlo, nzhi, nz
            stop 1
         end if

         nfcn = 0
         njac = 0

         write(*,*)
         write(*,*) '          nzlo', nzlo
         write(*,*) '          nzhi', nzhi
         write(*,*)

         n = nzhi - nzlo + 1
         call get_mix_work_sizes(n, lrd, lid, liwork, lwork)

         j = 1 ! the species to mix
         if (.not. do_mix1_species(n, nzlo, nzhi, j, maxsteps_allowed, nsteps, &
                     del_t, liwork, lwork, lrd, lid, &
                     mix_solver, mix_atol, mix_rtol)) then
            write(*,*) 'do_mix1_species returned false'
            stop 1
         end if

         write(*,*)
         write(*,*) '          nfcn', nfcn
         write(*,*) '          njac', njac
         write(*,*) '   nfcn + njac', nfcn + njac
         write(*,*)

         call write_mix_results
         call free_mix_data

      end subroutine do_test_mix


      logical function do_mix1_species( &
                        n,nzlo,nzhi,j,maxsteps_allowed,nsteps, &
                        del_t,liwork,lwork,lrd,lid, &
                        mix_solver,mix_atol,mix_rtol)
         ! Mix species j over cells nzlo..nzhi (n = nzhi-nzlo+1 cells) for a
         ! time del_t by integrating mix_fcn with isolve.
         ! Returns .true. and stores the mixed abundances into xa(j,nzlo:nzhi)
         ! when isolve reports success (idid == 1); returns .false. (leaving xa
         ! unchanged) if the accuracy requirements cannot be met.
         ! nsteps is the step count isolve reports in iwork(16).
         ! liwork, lwork, lrd, lid are the work sizes from get_mix_work_sizes.
         use mtx_lib, only: tridiag_decsol, null_decsols
         use num_lib, only: null_mas, isolve, null_sjac

         integer, intent(in) :: n, nzlo, nzhi, j, maxsteps_allowed, liwork, lwork, lrd, lid, mix_solver
         integer, intent(out) :: nsteps
         real(dp), intent(in) :: del_t, mix_atol, mix_rtol

         integer, parameter :: lipar = mix_lipar, lrpar = mix_lrpar
         character (len=*), parameter :: fmt_real = '(a30,f26.16)'
         character (len=*), parameter :: fmt_int = '(a30,i26)'
         integer, target :: ipar_ary(lipar)
         real(dp), target :: rpar_ary(lrpar)
         integer, pointer :: ipar(:), ipar_decsol(:) ! ipar_decsol(lid)
         real(dp), pointer :: rpar(:), rpar_decsol(:) ! rpar_decsol(lrd)
         real(dp), pointer :: work(:) ! (lwork)
         integer, pointer :: iwork(:) ! (liwork)

         integer :: lout, mljac, mujac, mlmas, mumas, itol, iout, imas, ijac, &
            nzmax, isparse, maxsteps, solver, idid
         real(dp) :: x1, x2, h, atol(1), rtol(1), max_step_size, ytotal, ytotal_end
         real(dp) :: y(n)

         ipar => ipar_ary
         rpar => rpar_ary
         ipar = 0
         rpar = 0

         allocate(work(lwork), iwork(liwork), ipar_decsol(lid), rpar_decsol(lrd))
         iwork = 0
         work = 0
         ipar_decsol = 0
         rpar_decsol = 0

         y(1:n) = xa(j,nzlo:nzhi)
         ytotal = dot_product(dm(nzlo:nzhi), y(1:n))

         ! state shared with the callbacks through ipar/rpar
         ipar(ipar_j) = j
         rpar(rpar_ytotal) = ytotal
         rpar(rpar_xend) = del_t

         x1 = 0
         x2 = del_t
         rpar(rpar_x) = x1 ! so the final report is defined even if solout never runs

         itol = 0 ! scalar tolerances
         atol(1) = mix_atol
         rtol(1) = mix_rtol

         ijac = 1 ! analytical jacobian (ijac = 0 would be numerical)
         iout = 1 ! call solout
         nzmax = 0 ! not sparse
         isparse = 0
         imas = 0
         mlmas = 0
         mumas = 0
         mljac = 1 ! tridiagonal band jacobian
         mujac = 1
         lout = 6 ! unit for isolve diagnostics (6 is standard output on common compilers)

         h = del_t ! initial step size guess: the whole interval
         max_step_size = 0 ! NOTE(review): presumably 0 means no limit in isolve - confirm
         maxsteps = maxsteps_allowed

         solver = mix_solver

         write(*,fmt_real) 'initial max y', maxval(y(1:n))
         write(*,fmt_real) 'initial avg y', ytotal/sum(dm(nzlo:nzhi))
         write(*,fmt_real) 'initial min y', minval(y(1:n))
         write(*,*)

         call isolve( &
               solver, &
               n,mix_fcn,x1,y,x2, &
               h,max_step_size,maxsteps,  &
               rtol,atol,itol, &
               mix_jac,ijac,null_sjac,nzmax,isparse,mljac,mujac, &
               null_mas,imas,mlmas,mumas, &
               mix_solout,iout, &
               tridiag_decsol,null_decsols,lrd,rpar_decsol,lid,ipar_decsol, &
               work,lwork,iwork,liwork, &
               lrpar,rpar,lipar,ipar, &
               lout,idid)

         nsteps = iwork(16)
         ytotal_end = dot_product(y(1:n), dm(nzlo:nzhi))

         deallocate(work, iwork, ipar_decsol, rpar_decsol)

         write(*,*)
         write(*,fmt_int) 'back from isolve', idid
         write(*,fmt_int) 'solver', solver
         write(*,fmt_int) 'nsteps', nsteps
         write(*,fmt_real) 'x/xend', rpar(rpar_x)/rpar(rpar_xend)
         write(*,*)
         write(*,fmt_real) 'final max y', maxval(y(1:n))
         write(*,fmt_real) 'final avg y', ytotal_end/sum(dm(nzlo:nzhi))
         write(*,fmt_real) 'final min y', minval(y(1:n))
         write(*,fmt_real) 'final/initial ytotal', ytotal_end/ytotal
         write(*,fmt_real) 'final max-min y', maxval(y(1:n))-minval(y(1:n))
         write(*,*)

         do_mix1_species = (idid == 1)
         if (do_mix1_species) xa(j,nzlo:nzhi) = y(1:n)

      end function do_mix1_species


      subroutine mix_op(n,x,y,f,dfdy,ld_dfdy,lrpar,rpar,lipar,ipar,ierr)
         ! Right-hand side (and optionally the jacobian) of the conservative
         ! diffusion equation for one species on cells nzlo..nzhi:
         !    sig(k)   = 2*cdc(k)/(dm(k-1)+dm(k))   face between cells k-1 and k
         !    dy_dt(k) = (sig(k)*(y(k-1)-y(k)) - sig(k+1)*(y(k)-y(k+1)))/dm(k)
         ! with no flux through the outer faces of the range.
         ! y(i) is the mass fraction in cell k = nzlo+i-1.
         ! If ld_dfdy == 0 only f is computed.  Otherwise dfdy is filled in band
         ! storage with one sub- and one super-diagonal:
         !    dfdy(1,i+1) = df(i)/dy(i+1)
         !    dfdy(2,i)   = df(i)/dy(i)
         !    dfdy(3,i-1) = df(i)/dy(i-1)
         ! x is unused (the system is autonomous).  ierr = -1 on bad arguments.
         integer, intent(in) :: n, ld_dfdy, lrpar, lipar
         real(dp), intent(in) :: x
         real(dp), intent(inout) :: y(n)
         real(dp), intent(out) :: f(n), dfdy(ld_dfdy,n)
         integer, intent(inout), pointer :: ipar(:) ! (lipar)
         real(dp), intent(inout), pointer :: rpar(:) ! (lrpar)
         integer, intent(out) :: ierr

         ! limit on sig/min(dm) to prevent matrix ill-conditioning
         real(dp), parameter :: sigdm_max = 1d10
         integer :: i, k
         real(dp) :: dy1, dy2, dm1, dm2, sig1, sig2, sig1dm1, sig2dm1, sig2_max

         ierr = 0
         if (lipar < mix_lipar) then
            if (stop_for_bugs) stop 'bad lipar for mix_op'
            ierr = -1
            return
         end if

         if (n /= nzhi-nzlo+1) then
            if (stop_for_bugs) stop 'bad n for mix_op'
            ierr = -1
            return
         end if

         ! the band layout below needs rows 1..3
         if (ld_dfdy /= 0 .and. ld_dfdy < 3) then
            if (stop_for_bugs) stop 'bad ld_dfdy for mix_op'
            ierr = -1
            return
         end if

         f = 0
         ! also clears the unused corners of the band storage
         if (ld_dfdy > 0) dfdy = 0
         if (n == 1) return ! a single cell has nothing to mix with

         ! walk the cells, carrying the outer face of cell i-1 over as the
         ! inner face of cell i
         sig2 = 0
         dy2 = 0
         dm2 = dm(nzlo)
         do i = 1, n
            k = nzlo + i - 1
            dm1 = dm2
            dy1 = dy2
            sig1 = sig2
            if (i < n) then
               dm2 = dm(k+1)
               dy2 = y(i) - y(i+1)
               sig2 = 2*cdc(k+1)/(dm1 + dm2)
               ! limit sig to prevent numerical ill-conditioning
               sig2_max = sigdm_max*min(dm1, dm2)
               if (sig2 > sig2_max) then
                  write(*,*) 'sig2', k, sig2, sig2_max, sig2/sig2_max
                  if (stop_for_bugs) stop 'mix_op'
                  sig2 = sig2_max
               end if
            else ! no flux out of the last cell
               dm2 = 0
               dy2 = 0
               sig2 = 0
            end if
            sig1dm1 = sig1/dm1
            sig2dm1 = sig2/dm1
            f(i) = dy1*sig1dm1 - dy2*sig2dm1
            if (ld_dfdy == 0) cycle ! not doing partials
            if (i > 1) dfdy(3,i-1) = sig1dm1 ! df(i)/dy(i-1)
            dfdy(2,i) = -(sig1dm1 + sig2dm1) ! df(i)/dy(i)
            if (i < n) dfdy(1,i+1) = sig2dm1 ! df(i)/dy(i+1)
         end do

      end subroutine mix_op


      subroutine mix_fcn(n,x,y,f,lrpar,rpar,lipar,ipar,ierr)
         ! isolve right-hand-side callback: f = dy/dx from mix_op without
         ! partials; counts calls in nfcn.
         integer, intent(in) :: n, lrpar, lipar
         real(dp), intent(in) :: x
         real(dp), intent(inout) :: y(n)
         real(dp), intent(out) :: f(n)
         integer, intent(inout), pointer :: ipar(:) ! (lipar)
         real(dp), intent(inout), pointer :: rpar(:) ! (lrpar)
         integer, intent(out) :: ierr
         real(dp) :: dfdy(0,n) ! zero-size: no partials requested
         ierr = 0
         nfcn = nfcn + 1
         call mix_op(n,x,y,f,dfdy,0,lrpar,rpar,lipar,ipar,ierr)
      end subroutine mix_fcn


      subroutine mix_jac(n,x,y,f,dfdy,ldfy,lrpar,rpar,lipar,ipar,ierr)
         ! isolve jacobian callback: f and the banded dfdy from mix_op;
         ! counts calls in njac.
         integer, intent(in) :: n, ldfy, lrpar, lipar
         real(dp), intent(in) :: x
         real(dp), intent(inout) :: y(n)
         real(dp), intent(out) :: f(n), dfdy(ldfy,n)
         integer, intent(inout), pointer :: ipar(:) ! (lipar)
         real(dp), intent(inout), pointer :: rpar(:) ! (lrpar)
         integer, intent(out) :: ierr
         ierr = 0
         njac = njac + 1
         call mix_op(n,x,y,f,dfdy,ldfy,lrpar,rpar,lipar,ipar,ierr)
      end subroutine mix_jac


      subroutine mix_solout(nr,xold,x,n,y,work,iwork,interp_y,lrpar,rpar,lipar,ipar,irtrn)
         ! isolve per-step callback.  Records x in rpar(rpar_x), clips y to
         ! [0,1] and rescales it so the dm-weighted total over nzlo..nzhi
         ! equals rpar(rpar_ytotal), i.e. enforces conservation of the species.
         ! nr is the step number.
         ! x is the current x value; xold is the previous x value.
         ! y is the current y value.
         ! irtrn negative means terminate integration.
         integer, intent(in) :: nr, n, lrpar, lipar
         real(dp), intent(in) :: xold, x
         real(dp), intent(inout) :: y(n)
         ! y can be modified if necessary to keep it in valid range of possible solutions.
         real(dp), intent(inout), target :: work(*)
         integer, intent(inout), target :: iwork(*)
         integer, intent(inout), pointer :: ipar(:) ! (lipar)
         real(dp), intent(inout), pointer :: rpar(:) ! (lrpar)
         interface
            real(dp) function interp_y(i,s,work,iwork,ierr)
               use const_def, only: dp
               integer, intent(in) :: i ! result is interpolated approximation of y(i) at x=s.
               real(dp), intent(in) :: s ! interpolation x value (between xold and x).
               real(dp), intent(inout), target :: work(*)
               integer, intent(inout), target :: iwork(*)
               integer, intent(out) :: ierr
            end function interp_y
         end interface
         integer, intent(out) :: irtrn ! < 0 causes solver to return to calling program.

         integer :: j
         real(dp) :: ytotal, ytotal_expected

         irtrn = 0

         if (lipar < mix_lipar) then
            if (stop_for_bugs) stop 'bad lipar for mix_solout'
            irtrn = -1
            return
         end if
         j = ipar(ipar_j)

         if (n /= nzhi-nzlo+1) then
            if (stop_for_bugs) stop 'bad n for mix_solout'
            irtrn = -1 ! tell the solver to stop
            return
         end if

         rpar(rpar_x) = x

         ytotal_expected = rpar(rpar_ytotal)
         if (ytotal_expected == 0) then ! nothing of this species in the range
            y(1:n) = 0
            return
         end if

         y(1:n) = max(0.0_dp, min(1.0_dp, y(1:n)))
         ytotal = dot_product(dm(nzlo:nzhi), y(1:n))
         if (ytotal < 1d-10*ytotal_expected .or. ytotal <= 0) then
            write(*,*) 'bad ytotal', j, ytotal, ytotal_expected
            if (stop_for_bugs) stop 'mix_solout'
            ! fall back to a uniform profile with the expected total
            y(1:n) = ytotal_expected / sum(dm(nzlo:nzhi))
         else
            y(1:n) = y(1:n)*(ytotal_expected/ytotal)
         end if

      end subroutine mix_solout


      subroutine get_mix_work_sizes(n,lrd,lid,liwork,lwork)
         ! Work-array sizes for an n-cell mixing problem: liwork/lwork for
         ! isolve with a band jacobian (one sub- and one super-diagonal, not
         ! sparse, imas = 0), and lrd/lid for tridiag_decsol.
         use num_lib
         use mtx_lib
         integer, intent(in) :: n
         integer, intent(out) :: lrd,lid,liwork,lwork
         integer :: nzmax, imas, mljac, mujac, mlmas, mumas
         nzmax = 0
         mljac = 1
         mujac = 1
         imas = 0
         mlmas = 0
         mumas = 0
         call isolve_work_sizes(n,nzmax,imas,mljac,mujac,mlmas,mumas,liwork,lwork)
         call tridiag_work_sizes(n,lrd,lid)
      end subroutine get_mix_work_sizes


      subroutine read_mix_data(dt,ierr)
         ! Load the test problem from 'mix_test.data' into the module arrays
         ! (allocating them) and return the time step in dt.
         ! ierr /= 0 on failure.  The file is closed and the i/o unit released
         ! on every path after allocation of the unit.
         use utils_lib, only: alloc_iounit, free_iounit
         integer, intent(out) :: ierr
         real(dp), intent(out) :: dt

         integer :: iounit
         character (len=64) :: filename

         ierr = 0
         filename = 'mix_test.data'
         iounit = alloc_iounit(ierr); if (ierr /= 0) return
         open(unit=iounit, file=trim(filename), action='read', status='old', iostat=ierr)
         if (ierr /= 0) then
            write(*,*) 'failed to open ', trim(filename)
            call free_iounit(iounit)
            return
         end if

         call read_mix_data_from_unit(iounit, dt, ierr)

         close(iounit)
         call free_iounit(iounit)

      end subroutine read_mix_data


      subroutine read_mix_data_from_unit(iounit, dt, ierr)
         ! Parse the open test data file on iounit (list-directed):
         !   a line with nz and dt; one skipped line;
         !   nz lines of  index, T, Rho, dm, r, vc, cdc;
         !   two skipped lines;
         !   nz lines of  index, xa(1:species).
         ! Allocates the module arrays.  Returns at the first read or
         ! allocation failure with ierr /= 0; stops if a row of initial mass
         ! fractions does not sum to 1 within 1d-4.
         integer, intent(in) :: iounit
         real(dp), intent(out) :: dt
         integer, intent(out) :: ierr

         integer :: i, k, s

         ierr = 0
         read(unit=iounit,fmt=*,iostat=ierr) nz, dt
         if (ierr /= 0) return
         read(unit=iounit,fmt=*,iostat=ierr)
         if (ierr /= 0) return

         allocate(abar(nz),zbar(nz),z2bar(nz),ye(nz),T(nz),Rho(nz),lnT(nz),lnRho(nz),dm(nz), &
            r(nz),vc(nz),cdc(nz),stat=ierr)
         if (ierr /= 0) return
         allocate(xa(species,nz),dxa_dt(species,nz),initial_xa(species,nz),stat=ierr)
         if (ierr /= 0) return

         do i = 1, nz
            read(unit=iounit,fmt=*,iostat=ierr) k, T(i), Rho(i), dm(i), r(i), vc(i), cdc(i)
            if (ierr /= 0) return
            lnT(i) = log(T(i))
            lnRho(i) = log(Rho(i))
         end do

         read(unit=iounit,fmt=*,iostat=ierr)
         if (ierr /= 0) return
         read(unit=iounit,fmt=*,iostat=ierr)
         if (ierr /= 0) return

         do i = 1, nz
            read(unit=iounit,fmt=*,iostat=ierr) k, initial_xa(1:species,i)
            if (ierr /= 0) return
            if (abs(sum(initial_xa(1:species,i)) - 1d0) > 1d-4) then
               write(*,*) 'k', k, sum(initial_xa(1:species,i))
               do s = 1, species
                  write(*,*) s, initial_xa(s,i)
               end do
               stop 'bad sum for initial_xa input'
            end if
         end do

      end subroutine read_mix_data_from_unit


      subroutine free_mix_data
         ! Release every module array allocated by read_mix_data.
         deallocate(abar,zbar,z2bar,ye,T,Rho,lnT,lnRho,dm,r,vc, &
               cdc,initial_xa,xa,dxa_dt)
      end subroutine free_mix_data


      subroutine write_mix_results
         ! Write, for every cell, species 1 after mixing, its initial value,
         ! their difference, cdc and cdc/dm**2 to 'plot_data/mix.data'.
         ! Reports (but does not fail on) an open error.
         use utils_lib, only: alloc_iounit, free_iounit
         character (len=100) :: filename
         integer :: k, ierr, iounit
         filename = 'plot_data/mix.data'
         ierr = 0
         iounit = alloc_iounit(ierr); if (ierr /= 0) return
         open(iounit, file=trim(filename), action='write', status='replace', iostat=ierr)
         if (ierr == 0) then
            write(*,*) 'write mix results to ' // trim(filename)
            write(iounit,'(a)') 'mix log'
            write(iounit,'(99(a,1x))') 'y', 'y0', 'dy', 'cdc', 'cdcdm2'
            do k = 1, nz
               write(iounit,'(99e24.10)') xa(1,k), initial_xa(1,k), xa(1,k)-initial_xa(1,k), &
                  cdc(k), cdc(k)/dm(k)**2
            end do
            close(iounit)
         else
            write(*,*) 'failed to open internals file ' // trim(filename)
         end if
         call free_iounit(iounit)
      end subroutine write_mix_results


      end module test_mix
