! ***********************************************************************
!
!   Copyright (C) 2012  Bill Paxton
!
!   MESA is free software; you can use it and/or modify
!   it under the combined terms and restrictions of the MESA MANIFESTO
!   and the GNU General Library Public License as published
!   by the Free Software Foundation; either version 2 of the License,
!   or (at your option) any later version.
!
!   You should have received a copy of the MESA MANIFESTO along with
!   this software; if not, it is available at the mesa website:
!   http://mesa.sourceforge.net/
!
!   MESA is distributed in the hope that it will be useful,
!   but WITHOUT ANY WARRANTY; without even the implied warranty of
!   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
!   See the GNU Library General Public License for more details.
!
!   You should have received a copy of the GNU Library General Public License
!   along with this software; if not, write to the Free Software
!   Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
!
! ***********************************************************************

      module solve_diffusion

      use star_private_def
      use const_def

      implicit none



      contains
      

      !> Advance the discretized 1-D diffusion problem
      !>    dx(k)/dt = source(k) + (sig(k)*(x(k-1)-x(k)) - sig(k+1)*(x(k)-x(k+1)))/dm(k)
      !> from t = 0 to t = dt with the isolve ODE driver (ros3pl solver, analytic
      !> tridiagonal Jacobian from jac, per-step callback solout).
      !>
      !> s                star; its id is forwarded to the callbacks in ipar(1)
      !> dt               integration interval; also used as initial and maximum step size
      !> nz               number of cells (= number of unknowns)
      !> x                state vector handed to isolve
      !>                  (NOTE(review): assumes isolve returns the solution at t = dt in x)
      !> sig, dm, source  per-cell coefficients, copied into rpar for fcn/jac
      !> atol1, rtol1     absolute/relative tolerances (single-element arrays, itol = 0)
      !> max_steps        step limit passed to isolve
      !> nsteps           on return iwork(16): number of computed steps (accepted + rejected)
      !>
      !> Returns keep_going on success and retry when isolve reports failure (idid < 0);
      !> the decision what to do on failure is left to the caller.
      integer function do_solve_diffusion( &
            s,dt,nz,x,sig,source,dm,atol1,rtol1,max_steps,nsteps)
         
         use mtx_lib
         use num_lib
         
         type (star_info), pointer :: s
         real(dp), intent(in) :: dt, atol1, rtol1
         real(dp), pointer, dimension(:) :: x, sig, source, dm
         integer, intent(in) :: nz, max_steps
         integer, intent(out) :: nsteps

         integer :: i
         integer :: which_solver, n, itol, ijac, &
            nzmax, isparse, mljac, mujac, imas, mlmas, mumas, iout, lout, &
            lid, lrd, liwork, lwork, lrpar, lipar, idid
         real(dp) :: rtol(1), atol(1), t_start, t_end, dt_init, max_step_size
         real(dp), pointer :: work(:), rpar_decsol(:), rpar(:)
         integer, pointer :: iwork(:), ipar_decsol(:), ipar(:)
         
         include 'formats.dek'
         
         do_solve_diffusion = keep_going
         
         !which_solver = rodasp_solver
         !which_solver = seulex_solver
         !which_solver = sodex_solver
         which_solver = ros3pl_solver
         
         n = nz
         t_start = 0
         t_end = dt
         dt_init = dt !*1d-4
         max_step_size = dt
         rtol(1) = rtol1
         atol(1) = atol1
         itol = 0
         ijac = 1 ! analytic
         nzmax = 0 ! not sparse
         isparse = 0
         mljac = 1 ! banded Jacobian: one sub-diagonal ...
         mujac = 1 ! ... and one super-diagonal (tridiagonal, see jac)
         imas = 0 ! no mass matrix (null_mas is passed to isolve)
         mlmas = 0
         mumas = 0
         iout = 1
         lout = 6
         
         !call tridiag_work_sizes(n,lrd,lid)
         call lapack_work_sizes(n,lrd,lid)
         allocate(rpar_decsol(lrd), ipar_decsol(lid))
         
         call isolve_work_sizes( &
            n, nzmax, imas, mljac, mujac, mlmas, mumas, liwork, lwork)
         allocate(work(lwork), iwork(liwork))
         
         work(:) = 0
         iwork(:) = 0
         
         ! rpar layout: [t_end, sig(1:nz), dm(1:nz), source(1:nz)]
         ! (unpacked by jac; solout reads rpar(1))
         lrpar = 1 + 3*nz
         ! ipar layout: [star id, nz]
         lipar = 2
         allocate(rpar(lrpar), ipar(lipar))
         
         rpar(1) = t_end
         i = 1
         rpar(i+1:i+nz) = sig(1:nz); i = i+nz
         rpar(i+1:i+nz) = dm(1:nz); i = i+nz
         rpar(i+1:i+nz) = source(1:nz)
         
         ipar(1) = s% id
         ipar(2) = nz
         
         call isolve( &
            which_solver, n, fcn, t_start, x, t_end, &
            dt_init, max_step_size, max_steps, &
            rtol, atol, itol, &
            jac, ijac, null_sjac, nzmax, isparse, mljac, mujac, &
            null_mas, imas, mlmas, mumas, &
            solout, iout, &
            lapack_decsol, &
            !tridiag_decsol, &
            null_decsols, lrd, rpar_decsol, lid, ipar_decsol,  &
            work, lwork, iwork, liwork, &
            lrpar, rpar, lipar, ipar, &
            lout, idid)
         
         nsteps = iwork(16) ! number of computed steps (accepted + rejected)
         
         if (idid < 0) then
            ! Report the failure to the caller (retry) instead of aborting the
            ! whole run; the work arrays are still released below.
            if (s% report_ierr) then
               write(*,2) 'idid', idid
               write(*,*) 'isolve failed for solve_diffusion'
            end if
            do_solve_diffusion = retry
         end if

         deallocate(rpar_decsol, ipar_decsol, work, iwork, rpar, ipar)

      end function do_solve_diffusion


      
      !> Evaluate the right-hand side f = dx/dt of the discretized diffusion equation
      !>    f(k) = source(k) + (sig(k)*(x(k-1)-x(k)) - sig(k+1)*(x(k)-x(k+1)))/dm(k)
      !> and, when ld_dfdx > 0, its tridiagonal Jacobian in band storage
      !>    dfdx(i-j+2,j) = partial f(i) wrt x(j)
      !> i.e. row 1 = super-diagonal, row 2 = diagonal, row 3 = sub-diagonal.
      !> With ld_dfdx = 0 (as passed by fcn) only f is computed.
      !>
      !> rpar = [t_end, sig(1:nz), dm(1:nz), source(1:nz)], ipar = [star id, nz].
      !> t is unused: the right-hand side has no explicit time dependence.
      subroutine jac(n, t, x, f, dfdx, ld_dfdx, lrpar, rpar, lipar, ipar, ierr)
         integer, intent(in) :: n, ld_dfdx, lrpar, lipar
         real(dp), intent(in) :: t
         real(dp), intent(inout) :: x(n)
         real(dp), intent(out) :: f(n) ! dx/dt
         real(dp), intent(out) :: dfdx(ld_dfdx, n)
         real(dp), intent(inout), target :: rpar(lrpar)
         integer, intent(inout), target :: ipar(lipar)
         integer, intent(out) :: ierr ! nonzero means terminate integration
         
         type (star_info), pointer :: s
         real(dp), pointer, dimension(:) :: sig, dm, source
         real(dp) :: sig00_dm, sigp1_dm
         integer :: i, k, id, nz
         
         ierr = 0

         ! both ipar(1) and ipar(2) are read below
         if (lipar < 2) stop 'bad lipar for jac in solve_diffusion'
         id = ipar(1)
         ! s is not used further; the lookup only validates the star id
         call get_star_ptr(id, s, ierr)
         if (ierr /= 0) return
         
         nz = ipar(2)
         if (nz /= n) stop 'bad nz for jac in solve_diffusion'
         ! validate the size before forming any slice of rpar
         if (1 + 3*nz > lrpar) stop 'bad lrpar for jac in solve_diffusion'
         
         i = 1 ! skip rpar(1) = t_end, which is not needed here
         sig => rpar(i+1:i+nz); i = i+nz
         dm => rpar(i+1:i+nz); i = i+nz
         source => rpar(i+1:i+nz)
         
         if (ld_dfdx > 0) then
            ! band storage needs rows 1..3
            if (ld_dfdx < 3) stop 'bad ld_dfdx for jac in solve_diffusion'
            ! dfdx is intent(out): define every entry, including the unused
            ! band corners dfdx(1,1) and dfdx(3,nz)
            dfdx(:,:) = 0
         end if
         
         do k = 1, nz
            f(k) = source(k)
            if (k > 1) then
               sig00_dm = sig(k)/dm(k)
               f(k) = f(k) + sig00_dm*(x(k-1)-x(k))
            else
               sig00_dm = 0
            end if
            if (k < nz) then
               sigp1_dm = sig(k+1)/dm(k)
               f(k) = f(k) - sigp1_dm*(x(k)-x(k+1))
            else
               sigp1_dm = 0
            end if
            if (ld_dfdx > 0) then
               ! partial f(k) wrt x(k): row k-k+2 = 2
               dfdx(2,k) = -(sig00_dm + sigp1_dm)
               ! partial f(k) wrt x(k-1): row k-(k-1)+2 = 3
               if (k > 1) dfdx(3,k-1) = sig00_dm
               ! partial f(k) wrt x(k+1): row k-(k+1)+2 = 1
               if (k < nz) dfdx(1,k+1) = sigp1_dm
            end if
         end do        
         
      end subroutine jac
      
      
      !> Right-hand-side callback for isolve: evaluates f = dx/dt by calling jac
      !> with ld_dfdx = 0, so no Jacobian entries are formed.
      subroutine fcn(n, t, x, f, lrpar, rpar, lipar, ipar, ierr)
         integer, intent(in) :: n, lrpar, lipar
         real(dp), intent(in) :: t
         real(dp), intent(inout) :: x(n) ! okay to edit x if necessary
         real(dp), intent(out) :: f(n) ! dx/dt
         real(dp), intent(inout), target :: rpar(lrpar)
         integer, intent(inout), target :: ipar(lipar)
         integer, intent(out) :: ierr ! nonzero means retry with smaller timestep.
         ! ld_dfdx = 0 tells jac to skip the Jacobian; dfdx is then a zero-size
         ! placeholder, declared with the same kind as jac's dfdx dummy.
         integer, parameter :: ld_dfdx = 0
         real(dp) :: dfdx(ld_dfdx,n)
         ierr = 0
         call jac(n, t, x, f, dfdx, ld_dfdx, lrpar, rpar, lipar, ipar, ierr)
      end subroutine fcn


      subroutine solout( &
            nr, told, t, n, x, rwork_y, iwork_y, interp_y, lrpar, rpar, lipar, ipar, irtrn)
         ! Per-step output callback handed to isolve.
         !   nr      : step number
         !   told, t : previous and current time
         !   x       : current solution; modified here (negative entries set to 0)
         !   rwork_y, iwork_y : data for interp_y (not the solver's own rwork/iwork)
         !   irtrn   : set < 0 to make the solver return to the caller; always 0 here
         ! A progress line with the fraction of the interval completed is printed.
         use const_def, only: dp
         integer, intent(in) :: nr, n, lrpar, lipar
         integer, intent(out) :: irtrn
         real(dp), intent(in) :: told, t
         real(dp), intent(inout) :: x(n)
         real(dp), intent(inout), target :: rpar(lrpar), rwork_y(*)
         integer, intent(inout), target :: ipar(lipar), iwork_y(*)
         interface
            include 'num_interp_y.dek'
         end interface
         real(dp) :: t_final
         include 'formats.dek'

         t_final = rpar(1)   ! end time of the integration, stored first in rpar
         x = max(0d0, x)     ! clip negative components to zero
         irtrn = 0           ! non-negative: let the solver carry on
         write(*,2) 'diff step', nr, t/t_final, t, t_final
      end subroutine solout


      end module solve_diffusion


