! ***********************************************************************
!
!   Copyright (C) 2012  Bill Paxton
!
!   MESA is free software; you can use it and/or modify
!   it under the combined terms and restrictions of the MESA MANIFESTO
!   and the GNU General Library Public License as published
!   by the Free Software Foundation; either version 2 of the License,
!   or (at your option) any later version.
!
!   You should have received a copy of the MESA MANIFESTO along with
!   this software; if not, it is available at the mesa website:
!   http://mesa.sourceforge.net/
!
!   MESA is distributed in the hope that it will be useful,
!   but WITHOUT ANY WARRANTY; without even the implied warranty of
!   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
!   See the GNU Library General Public License for more details.
!
!   You should have received a copy of the GNU Library General Public License
!   along with this software; if not, write to the Free Software
!   Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
!
! ***********************************************************************

      module solve_omega_mix

      ! Implicit (backward-Euler) mixing of angular momentum over one
      ! timestep, solved as a tridiagonal linear system.

      use star_private_def
      use const_def

      implicit none

      contains
      

      ! Mixes angular momentum over a step of length s% dt.
      !
      ! Sets s% omega = s% j_rot/s% i_rot_pre_hydro, solves the implicit
      ! transport system assembled by create_matrix (plus extra_omegadot /
      ! extra_jdot source terms when total_extra_angular_momentum is nonzero),
      ! and on success stores s% j_rot = s% i_rot*s% omega.
      !
      ! Returns
      !   keep_going - success
      !   retry      - the linear solve failed, or total angular momentum
      !                (including the extra source) changed by more than a
      !                relative 1d-6
      !   terminate  - the work arrays could not be allocated
      ! Stops the run if substepping ends early without a solver failure
      ! (an internal logic error).
      integer function do_solve_omega_mix(s)
         use star_utils, only: update_time, total_times, &
            total_extra_angular_momentum, total_angular_momentum
         
         type (star_info), pointer :: s
         
         integer :: ierr, nz, k
         integer :: time0, clock_rate, nsteps, max_nsteps
         real(qp) :: total_all_before, J_tot1, J_tot2, J_tot_extra, &
            time_end, time, dt
         ! du, d, dl: super-, main and sub-diagonal of the matrix;
         ! x: solution; b: right-hand side; bp, vp, xp: scratch for
         ! solve_tridiag. Allocatable, so they are released on any exit.
         real(qp), allocatable, dimension(:) :: &
            du, d, dl, x, b, bp, vp, xp

         include 'formats.dek'
         
         do_solve_omega_mix = keep_going
         ierr = 0
         nz = s% nz
         time_end = s% dt
         time = 0
         nsteps = 0
         max_nsteps = 1
         
         if (s% doing_timing) then
            total_all_before = total_times(s)
            call system_clock(time0,clock_rate)
         end if
         
         allocate(du(nz), d(nz), dl(nz), x(nz), &
            b(nz), bp(nz), vp(nz), xp(nz), stat=ierr)
         if (ierr /= 0) then
            do_solve_omega_mix = terminate
            write(*,*) 'allocate failed in do_solve_omega_mix'
            return
         end if
         
         ! angular velocity implied by j_rot with the pre-hydro moment of inertia
         s% omega(1:nz) = s% j_rot(1:nz)/s% i_rot_pre_hydro(1:nz)
         
         J_tot_extra = total_extra_angular_momentum(s, s% dt)
         if (J_tot_extra /= 0) &
            write(*,2) 'J_tot_extra', s% model_number, J_tot_extra
         
         ! we want to end up with total J = J_tot1
         J_tot1 = total_angular_momentum(s) + J_tot_extra
      
         substep_loop: do while &
               (time_end - time > 1d-10 .and. nsteps <= max_nsteps)
            nsteps = nsteps + 1
            dt = min(time_end/max_nsteps, time_end - time)
            call create_matrix(dt)
            ! rhs = omega_prev*i_rot_prev/i_rot +
            !        dt*(extra_omegadot + extra_jdot/i_rot)
            if (nsteps == 1) then
               ! omega still refers to i_rot_pre_hydro
               b(1:nz) = s% omega(1:nz)*s% i_rot_pre_hydro(1:nz)/s% i_rot(1:nz)
            else
               ! omega came from the previous solve, which already used i_rot,
               ! so rescaling again would be wrong
               b(1:nz) = s% omega(1:nz)
            end if
            if (J_tot_extra /= 0) then
               b(1:nz) = b(1:nz) + &
                  dt*(s% extra_omegadot(1:nz) + s% extra_jdot(1:nz)/s% i_rot(1:nz))
            end if
            call solve_tridiag(dl, d, du, b(1:nz), x(1:nz), nz, ierr)
            if (ierr /= 0) then
               write(*,*) 'matrix solve failed in solve mix'
               do_solve_omega_mix = retry
               exit substep_loop
            end if
            s% omega(1:nz) = x(1:nz)
            ! the outermost cell is given the angular velocity of cell 2
            if (nz > 1) s% omega(1) = s% omega(2)
            time = time + dt
         end do substep_loop
         
         ! A solver failure has already requested a retry and left the loop
         ! early on purpose; only an unfinished step without such a failure
         ! is treated as fatal.
         if (do_solve_omega_mix == keep_going .and. time_end - time > 1d-10) then
            do_solve_omega_mix = retry
            write(*,*) 'failed in mixing angular momentum'
            write(*,3) 'nsteps max_nsteps frac time', nsteps, max_nsteps, time/time_end
            stop 'do_solve_omega_mix'
         end if
         
         if (do_solve_omega_mix == keep_going) then
            s% j_rot(1:nz) = s% i_rot(1:nz)*s% omega(1:nz)
            J_tot2 = total_angular_momentum(s)
            if (.false.) write(*,2) 'solve omega transport err', &
               s% model_number, (J_tot1 - J_tot2)/J_tot1, J_tot1, J_tot2
            ! relative tolerance against |J_tot1| so it stays positive when
            ! the net angular momentum is negative
            if (abs(J_tot1 - J_tot2) > 1d-6*abs(J_tot1)) then
               write(*,*) 'failure to conserve angular momentum in mixing'
               do_solve_omega_mix = retry
            end if
         end if
         
         call dealloc

         !if (s% doing_timing) &
         !   call update_time(s, time0, total_all_before, s% time_solve_omega_mix)

         
         contains

         
         ! Releases the work arrays allocated by do_solve_omega_mix.
         subroutine dealloc
            deallocate(du, d, dl, x, b, bp, vp, xp)
         end subroutine dealloc
         
         
         ! Fills dl, d, du for implicit angular momentum transport
         ! (Heger, Langer, & Woosley, 2000, eqn 46) with a source term:
         !
         !   (omega*i_rot - omega_prev*i_rot_prev)/dt =
         !      flux/dmbar + extra_omegadot*i_rot + extra_jdot
         !
         ! flux(k) = cm1*(omega(k-1) - omega(k)) - c00*(omega(k) - omega(k+1))
         !   c00 = am_sig(k)*(i_rot(k) + i_rot(k+1))/2     (face k|k+1)
         !   cm1 = am_sig(k-1)*(i_rot(k-1) + i_rot(k))/2   (face k-1|k)
         ! There is no flux above cell 1 (cm1 = 0) or below cell nz (c00 = 0).
         !
         ! Dividing by i_rot gives the row stored here:
         !   lhs = omega - flux*dt/(dmbar*i_rot)
         ! dl(k-1) is the coefficient of omega(k-1) in row k; du(k) is the
         ! coefficient of omega(k+1) in row k.
         subroutine create_matrix(dt)
            real(qp), intent(in) :: dt
            integer :: k
            real(qp) :: i_rot, irot00, irotm1, c00, cm1, dmbar
            
            do k = 1, nz
            
               i_rot = s% i_rot(k)

               ! coupling through the face between cells k and k+1
               if (k < nz) then
                  irot00 = 0.5d0*(s% i_rot(k) + s% i_rot(k+1))
                  c00 = s% am_sig(k)*irot00
               else
                  c00 = 0
               end if

               ! coupling through the face between cells k-1 and k
               if (k > 1) then
                  irotm1 = 0.5d0*(s% i_rot(k-1) + s% i_rot(k))
                  cm1 = s% am_sig(k-1)*irotm1
                  if (k < nz) then
                     dmbar = 0.5d0*(s% dm(k-1) + s% dm(k))
                  else
                     ! innermost cell: includes the whole of dm(nz)
                     dmbar = 0.5d0*s% dm(k-1) + s% dm(k)
                  end if
               else
                  cm1 = 0
                  dmbar = 0.5d0*s% dm(k)
               end if

               d(k) = 1d0 + (cm1 + c00)*dt/(dmbar*i_rot)
               if (k < nz) du(k) = -c00*dt/(dmbar*i_rot)
               if (k > 1) dl(k-1) = -cm1*dt/(dmbar*i_rot)

            end do
            
         end subroutine create_matrix


         ! Solves a tridiagonal system with the Thomas algorithm (no pivoting).
         !   sub  - sub-diagonal:   sub(i) multiplies x(i) in row i+1
         !   diag - main diagonal
         !   sup  - super-diagonal: sup(i) multiplies x(i+1) in row i
         !   rhs  - right-hand side
         !   x    - the answer (left undefined when ierr /= 0)
         !   n    - number of equations
         !   ierr - 0 on success, -1 if a zero pivot is met
         ! Uses the host arrays bp, vp, xp as scratch space.
         subroutine solve_tridiag(sub, diag, sup, rhs, x, n, ierr)
            implicit none
            integer, intent(in) :: n
            real(qp), dimension(:), intent(in) :: sup, diag, sub
            real(qp), dimension(:), intent(in) :: rhs
            real(qp), dimension(:), intent(out) :: x
            integer, intent(out) :: ierr

            real(qp) :: m
            integer :: i

            ierr = 0
            if (n < 1) return

            bp(1) = diag(1)
            vp(1) = rhs(1)
            if (bp(1) == 0) then
               ierr = -1
               return
            end if

            ! forward elimination
            do i = 2, n
               m = sub(i-1)/bp(i-1)
               bp(i) = diag(i) - m*sup(i-1)
               if (bp(i) == 0) then
                  ierr = -1
                  return
               end if
               vp(i) = rhs(i) - m*vp(i-1)
            end do

            ! back substitution
            xp(n) = vp(n)/bp(n)
            x(n) = xp(n)
            do i = n-1, 1, -1
               xp(i) = (vp(i) - sup(i)*xp(i+1))/bp(i)
               x(i) = xp(i)
            end do

         end subroutine solve_tridiag
         
         
      end function do_solve_omega_mix


      end module solve_omega_mix


