! ***********************************************************************
!
!   Copyright (C) 2012  Bill Paxton
!
!   MESA is free software; you can use it and/or modify
!   it under the combined terms and restrictions of the MESA MANIFESTO
!   and the GNU General Library Public License as published
!   by the Free Software Foundation; either version 2 of the License,
!   or (at your option) any later version.
!
!   You should have received a copy of the MESA MANIFESTO along with
!   this software; if not, it is available at the mesa website:
!   http://mesa.sourceforge.net/
!
!   MESA is distributed in the hope that it will be useful,
!   but WITHOUT ANY WARRANTY; without even the implied warranty of
!   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
!   See the GNU Library General Public License for more details.
!
!   You should have received a copy of the GNU Library General Public License
!   along with this software; if not, write to the Free Software
!   Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
!
! ***********************************************************************

      module solve_omega_mix

      use star_private_def
      use const_def

      implicit none

      contains
      

      integer function do_solve_omega_mix(s, dt_total)
         ! Mix angular velocity (omega) over the interval dt_total by solving the
         ! implicit transport equation assembled in create_matrix_and_rhs
         ! (Heger, Langer & Woosley 2000, eqn 46, plus the extra_jdot and
         ! extra_omegadot source terms).
         !
         ! The interval is split into at most max_steps substeps, each about half
         ! of min_mixing_timescale (clamped against the remaining time); every
         ! substep is solved by Newton iterations on a tridiagonal system in
         ! quad precision.
         !
         ! Returns
         !   keep_going - success; s% omega and s% j_rot (= i_rot*omega) updated.
         !   retry      - adjust_J_lost / other_torque / other_torque_implicit
         !                reported an error, a matrix solve failed, a substep did
         !                not converge, the total iteration budget was exceeded,
         !                or (without user torques) total angular momentum changed
         !                by more than 1d-6 relative.
         !   terminate  - the work arrays could not be obtained.
         use star_utils, only: set_omega, update_time, total_times, &
            total_angular_momentum
         use mix_info, only: update_rotation_mixing_info
         use hydro_rotation, only: get_rotation_sigmas

         type (star_info), pointer :: s
         real(dp), intent(in) :: dt_total

         integer :: ierr, nz, k, max_iters_per_substep, &
            max_iters_total, total_num_iters, num_iters
         integer :: time0, clock_rate, steps_used, max_steps
         real(qp) :: total_all_before, remaining_time, total_time, time, dt, &
            J_tot0, J_tot1, max_del, avg_del, tol_correction_max, tol_correction_norm
         ! Newton/tridiagonal work arrays; bp and vp are scratch for solve_tridiag
         real(qp), pointer, dimension(:) :: &
            du, d, dl, bp, vp, dX, X_0, X_1, rhs, del

         logical, parameter :: dbg = .false.

         include 'formats'

         do_solve_omega_mix = keep_going

         ierr = 0
         if (dt_total <= 0d0) return

         nz = s% nz
         total_time = dt_total
         time = 0
         steps_used = 0
         max_steps = 20
         max_iters_per_substep = 4
         max_iters_total = 40
         total_num_iters = 0
         tol_correction_max = 1d-4
         tol_correction_norm = 1d-7

         if (s% doing_timing) then
            total_all_before = total_times(s)
            call system_clock(time0,clock_rate)
         end if

         ! update j_rot if have extra angular momentum change
         if (s% do_adjust_J_lost) then
            call adjust_J_lost(s, ierr)
            if (ierr /= 0) then
               ! report the failure instead of silently skipping the mixing
               if (s% report_ierr .or. dbg) &
                  write(*, *) 'solve_omega_mix: adjust_J_lost returned ierr', ierr
               do_solve_omega_mix = retry
               return
            end if
         end if

         ! update omega for new i_rot and previous j_rot to conserve angular momentum
         call set_omega(s, 'solve_omega_mix')

         s% extra_jdot(1:nz) = 0
         s% extra_omegadot(1:nz) = 0
         if (s% use_other_torque) then
            call s% other_torque(s% id, ierr)
            if (ierr /= 0) then
               if (s% report_ierr .or. dbg) &
                  write(*, *) 'solve_omega_mix: other_torque returned ierr', ierr
               ! same handling as an other_torque_implicit failure below
               do_solve_omega_mix = retry
               return
            end if
         end if

         call do_alloc(ierr)
         if (ierr /= 0) then
            ! NOTE(review): arrays obtained before the failing request are not
            ! handed back here.
            do_solve_omega_mix = terminate
            s% termination_code = t_solve_omega_mix
            if (s% report_ierr) write(*,*) 'allocate failed in do_solve_omega_mix'
            return
         end if

         J_tot0 = total_angular_momentum(s)

      step_loop: do while &
               (total_time - time > 1d-10*total_time .and. &
                  steps_used < max_steps)

            steps_used = steps_used + 1

            dt = 0.5d0*min_mixing_timescale()
            remaining_time = total_time - time
            dt = max(dt, 1d-6*remaining_time)
            if (dt >= remaining_time) then
               dt = remaining_time
            else
               dt = min(dt, 0.5d0*remaining_time)
            end if
            if (steps_used >= max_steps) dt = remaining_time ! just go for it
            if (dbg) write(*,3) 'mix dt', &
                  s% model_number, steps_used, dt, dt/remaining_time

            ! X_0 is omega at start of substep
            ! X_1 is current candidate for omega at end of substep
            ! dX = X_1 - X_0
            do k=1,nz
               X_0(k) = s% omega(k)
               X_1(k) = X_0(k)
               dX(k) = 0d0
            end do

         solve_loop: do num_iters = 1, max_iters_per_substep

               if (total_num_iters >= max_iters_total) then
                  do_solve_omega_mix = retry
                  exit step_loop
               end if

               total_num_iters = total_num_iters+1

               if (s% use_other_torque_implicit) then
                  call s% other_torque_implicit(s% id, ierr)
                  if (ierr /= 0) then
                     if (s% report_ierr .or. dbg) &
                        write(*, *) 'other_torque_implicit returned ierr', ierr
                     do_solve_omega_mix = retry
                     exit step_loop
                  end if
               end if

               call create_matrix_and_rhs(dt)

               ! solve for del
               call solve_tridiag(dl, d, du, rhs(1:nz), del(1:nz), nz, ierr)
               if (ierr /= 0) then
                  if (s% report_ierr) &
                     write(*,*) 'matrix solve failed in solve mix'
                  do_solve_omega_mix = retry
                  exit step_loop
               end if

               ! apply the correction dX = dX + del and set X_1 = X_0 + dX
               do k=2,nz
                  dX(k) = dX(k) + del(k)
                  X_1(k) = X_0(k) + dX(k)
                  s% omega(k) = X_1(k)
               end do
               ! NOTE(review): cell 1 is not updated from del(1); it is pinned to
               ! omega(2), yet del(1) still enters the convergence test below.
               s% omega(1) = s% omega(2)

               ! if correction small enough, exit solve_loop
               ! (a NaN in del makes avg_del NaN, so it never counts as converged)
               max_del = maxval(abs(del(1:nz)))
               avg_del = sum(abs(del(1:nz)))/nz
               if (max_del <= tol_correction_max .and. avg_del <= tol_correction_norm) then
                  if (dbg) &
                     write(*,3) 'substep converged: iters max_del avg_del dt/total', &
                        steps_used, num_iters, max_del, avg_del, dt/total_time
                  exit solve_loop ! this substep is done
               end if

               if (num_iters == max_iters_per_substep) then
                  if (s% report_ierr) &
                     write(*,*) 'num_iters == max_iters_per_substep in solve mix'
                  do_solve_omega_mix = retry
                  exit step_loop
               end if

            end do solve_loop

            time = time + dt

         end do step_loop

         if (dbg) write(*,2) 'omega mix steps_used', steps_used

         if (total_time - time > 1d-10*total_time) then
            do_solve_omega_mix = retry
            if (s% report_ierr) &
                  write(*,*) 'failed in mixing angular momentum'
         end if

         if (do_solve_omega_mix == keep_going) then
            do k=1,nz
               s% j_rot(k) = s% i_rot(k)*s% omega(k)
            end do
            if (.not. (s% use_other_torque .or. s% use_other_torque_implicit)) then
               ! check conservation
               J_tot1 = total_angular_momentum(s) ! what we have
               ! scale by |J_tot0| so the test also works for net negative J, and
               ! phrase it so that a NaN J_tot1 is treated as a failure
               if (.not. (abs(J_tot0 - J_tot1) <= 1d-6*abs(J_tot0))) then
                  if (s% report_ierr) &
                     write(*,*) 'retry: failed to conserve angular momentum in mixing'
                  do_solve_omega_mix = retry
               end if
               if (dbg) then
                  write(*,2) 'final J_tot1', s% model_number, J_tot1
                  write(*,2) '(J_tot1 - J_tot0)/J_tot0', &
                     steps_used, (J_tot1 - J_tot0)/J_tot0, J_tot0, J_tot1
               end if
            end if
         end if

         if (dbg) write(*,*)

         call dealloc

         !if (s% doing_timing) &
         !   call update_time(s, time0, total_all_before, s% time_solve_omega_mix)


         contains


         subroutine do_alloc(ierr)
            ! Obtain every work array from the non-critical quad-precision pool;
            ! stops at the first failure and returns its ierr.
            use alloc
            integer, intent(out) :: ierr
            call non_crit_get_quad_array(s, du, nz, nz_alloc_extra, 'solve_omega_mix', ierr)
            if (ierr /= 0) return
            call non_crit_get_quad_array(s, d, nz, nz_alloc_extra, 'solve_omega_mix', ierr)
            if (ierr /= 0) return
            call non_crit_get_quad_array(s, dl, nz, nz_alloc_extra, 'solve_omega_mix', ierr)
            if (ierr /= 0) return
            call non_crit_get_quad_array(s, bp, nz, nz_alloc_extra, 'solve_omega_mix', ierr)
            if (ierr /= 0) return
            call non_crit_get_quad_array(s, vp, nz, nz_alloc_extra, 'solve_omega_mix', ierr)
            if (ierr /= 0) return
            call non_crit_get_quad_array(s, dX, nz, nz_alloc_extra, 'solve_omega_mix', ierr)
            if (ierr /= 0) return
            call non_crit_get_quad_array(s, X_0, nz, nz_alloc_extra, 'solve_omega_mix', ierr)
            if (ierr /= 0) return
            call non_crit_get_quad_array(s, X_1, nz, nz_alloc_extra, 'solve_omega_mix', ierr)
            if (ierr /= 0) return
            call non_crit_get_quad_array(s, rhs, nz, nz_alloc_extra, 'solve_omega_mix', ierr)
            if (ierr /= 0) return
            call non_crit_get_quad_array(s, del, nz, nz_alloc_extra, 'solve_omega_mix', ierr)
         end subroutine do_alloc


         subroutine dealloc
            ! Hand every work array obtained by do_alloc back to the pool.
            use alloc
            call non_crit_return_quad_array(s, du, 'solve_omega_mix')
            call non_crit_return_quad_array(s, d, 'solve_omega_mix')
            call non_crit_return_quad_array(s, dl, 'solve_omega_mix')
            call non_crit_return_quad_array(s, bp, 'solve_omega_mix')
            call non_crit_return_quad_array(s, vp, 'solve_omega_mix')
            call non_crit_return_quad_array(s, dX, 'solve_omega_mix')
            call non_crit_return_quad_array(s, X_0, 'solve_omega_mix')
            call non_crit_return_quad_array(s, X_1, 'solve_omega_mix')
            call non_crit_return_quad_array(s, rhs, 'solve_omega_mix')
            call non_crit_return_quad_array(s, del, 'solve_omega_mix')
         end subroutine dealloc


         subroutine get_cell_flux_terms(k, c00, cm1, dmbar, d2omega)
            ! Transport terms for cell k from the current s% omega:
            !   c00     = am_sig(k)   * mean(i_rot(k), i_rot(k+1))   (0 at k = nz)
            !   cm1     = am_sig(k-1) * mean(i_rot(k-1), i_rot(k))   (0 at k = 1)
            !   dmbar   = mass normalisation of the transport term
            !   d2omega = cm1*(omega(k-1) - omega(k)) - c00*(omega(k) - omega(k+1)),
            !             dropping the missing neighbour term at k = 1 and k = nz.
            ! The face/omega intermediates are held in qp before being combined, so
            ! differences are formed in quad precision.
            integer, intent(in) :: k
            real(qp), intent(out) :: c00, cm1, dmbar, d2omega
            real(qp) :: omega, irot00, irotm1, del00_omega, delm1_omega

            omega = s% omega(k)

            if (k < nz) then
               irot00 = 0.5d0*(s% i_rot(k) + s% i_rot(k+1))
               c00 = s% am_sig(k)*irot00
               del00_omega = omega - s% omega(k+1)
            else
               c00 = 0
               del00_omega = 0
            end if

            if (k > 1) then
               irotm1 = 0.5d0*(s% i_rot(k-1) + s% i_rot(k))
               cm1 = s% am_sig(k-1)*irotm1
               if (k < nz) then
                  dmbar = 0.5d0*(s% dm(k-1) + s% dm(k))
               else
                  dmbar = 0.5d0*s% dm(k-1) + s% dm(k)
               end if
               delm1_omega = s% omega(k-1) - omega
            else
               cm1 = 0
               dmbar = 0.5d0*s% dm(k)
               delm1_omega = 0
            end if

            if (k == 1) then
               d2omega = -c00*del00_omega
            else if (k == nz) then
               d2omega = cm1*delm1_omega
            else
               d2omega = cm1*delm1_omega - c00*del00_omega
            end if
         end subroutine get_cell_flux_terms


         real(dp) function min_mixing_timescale() result(dt_min)
            ! Smallest over all cells of |omega|*i_rot / |d(omega*i_rot)/dt|, with
            ! |omega| floored at 1d-12 and the rate floored at 1d-50.
            ! abs(omega) makes counter-rotating cells behave like co-rotating
            ! ones instead of collapsing to the 1d-12 floor.
            integer :: k
            real(qp) :: dt00, omega, i_rot, c00, cm1, dmbar, d2omega

            dt_min = 1d99

            do k = 1, nz
               call get_cell_flux_terms(k, c00, cm1, dmbar, d2omega)
               omega = s% omega(k)
               i_rot = s% i_rot(k)

               ! (omega*i_rot - omega_prev*i_rot_prev)/dt =
               !    d2omega/dmbar + extra_omegadot*i_rot + extra_jdot
               ! (MAX arguments share kind qp, as the standard requires)
               dt00 = max(real(1d-12, qp), abs(omega))*i_rot/ &
                  max(real(1d-50, qp), abs(d2omega/dmbar + &
                                 s% extra_omegadot(k)*i_rot + s% extra_jdot(k)))
               if (dt00 < dt_min) dt_min = dt00
            end do

         end function min_mixing_timescale


         subroutine create_matrix_and_rhs(dt)
            ! Assemble the Newton system J*del = rhs for one substep of length dt.
            ! basic equation from Heger, Langer, & Woosley, 2000, eqn 46.
            ! with source term added.
            !
            ! residual = dX - dt*((d2omega/dmbar + extra_jdot)/i_rot + extra_omegadot)
            ! J = d(residual)/d(omega); rhs = -residual
            ! dl holds the sub-diagonal, d the diagonal, du the super-diagonal.
            real(qp), intent(in) :: dt
            integer :: k
            real(qp) :: i_rot, c00, cm1, dmbar, d2omega, &
               d_d2omega_domega_p1, d_d2omega_domega_m1, d_d2omega_domega_00

            do k = 1, nz

               call get_cell_flux_terms(k, c00, cm1, dmbar, d2omega)
               i_rot = s% i_rot(k)
               d_d2omega_domega_00 = -(cm1 + c00)

               rhs(k) = -dX(k) + &
                  dt*((d2omega/dmbar + &
                        s% extra_jdot(k))/i_rot + s% extra_omegadot(k))

               d(k) = 1d0 - d_d2omega_domega_00*dt/(dmbar*i_rot)

               if (k < nz) then
                  d_d2omega_domega_p1 = c00
                  du(k) = -d_d2omega_domega_p1*dt/(dmbar*i_rot)
               end if

               if (k > 1) then
                  d_d2omega_domega_m1 = cm1
                  dl(k-1) = -d_d2omega_domega_m1*dt/(dmbar*i_rot)
               end if

               ! add the Jacobian of the user-supplied implicit torque
               if (s% use_other_torque_implicit) then
                  d(k) = d(k) - &
                     dt*(s% d_extra_jdot_domega_00(k)/i_rot + &
                           s% d_extra_omegadot_domega_00(k))
                  if (k < nz) &
                     du(k) = du(k) - &
                        dt*(s% d_extra_jdot_domega_p1(k)/i_rot + &
                              s% d_extra_omegadot_domega_p1(k))
                  if (k > 1) &
                     dl(k-1) = dl(k-1) - &
                        dt*(s% d_extra_jdot_domega_m1(k)/i_rot + &
                              s% d_extra_omegadot_domega_m1(k))
               end if

            end do

         end subroutine create_matrix_and_rhs


         subroutine solve_tridiag(sub, diag, sup, rhs, x, n, ierr)
            ! Solve a tridiagonal system with the Thomas algorithm (no pivoting),
            ! using the host arrays bp and vp as scratch.
            !      sub  - sub-diagonal (sub(i) couples row i+1 to column i)
            !      diag - the main diagonal
            !      sup  - super-diagonal (sup(i) couples row i to column i+1)
            !      rhs  - right hand side
            !      x    - the answer
            !      n    - number of equations
            !      ierr - set to -1 if a pivot is zero or NaN (x is then undefined)
            implicit none
            integer, intent(in) :: n
            real(qp), dimension(:), intent(in) :: sup, diag, sub
            real(qp), dimension(:), intent(in) :: rhs
            real(qp), dimension(:), intent(out) :: x
            integer, intent(out) :: ierr

            real(qp) :: m
            integer :: i

            ierr = 0

            bp(1) = diag(1)
            vp(1) = rhs(1)
            ! ".not. (|p| > 0)" is true for both a zero and a NaN pivot
            if (.not. (abs(bp(1)) > 0)) then
               ierr = -1
               return
            end if

            do i = 2, n
               m = sub(i-1)/bp(i-1)
               bp(i) = diag(i) - m*sup(i-1)
               vp(i) = rhs(i) - m*vp(i-1)
               if (.not. (abs(bp(i)) > 0)) then
                  ierr = -1
                  return
               end if
            end do

            x(n) = vp(n)/bp(n)
            do i = n-1, 1, -1
               x(i) = (vp(i) - sup(i)*x(i+1))/bp(i)
            end do

         end subroutine solve_tridiag


      end function do_solve_omega_mix


      ! Before mixing, remove the extra angular momentum
      ! (actual_J_lost - s% angular_momentum_removed) from j_rot,
      ! then set s% angular_momentum_removed to actual_J_lost.
         
      subroutine adjust_J_lost(s, ierr)
         ! Remove delta_J = actual_J_lost - s% angular_momentum_removed from j_rot
         ! by scaling j_rot(1:last_k) with a common factor. Then record
         ! actual_J_lost in s% angular_momentum_removed.
         !
         ! actual_J_lost blends mass_lost*j_rot_avg_surf with the previously
         ! recorded s% angular_momentum_removed, weighted by s% adjust_J_fraction.
         ! last_k is the first cell with tau >= 3*min_tau at which the running
         ! sum J = sum(dm_face*j_rot) exceeds 1.5*|delta_J|, or nz if none does.
         !
         ! Does nothing if no mass was lost or if delta_J is zero. Sets ierr = -1
         ! and leaves j_rot untouched if the scanned J is zero; the scaling
         ! factor (J - delta_J)/J is undefined there and would write NaN into
         ! j_rot.
         type (star_info), pointer :: s
         integer, intent(out) :: ierr

         real(dp) :: dmm1, dm00, dm, J, mass_lost, &
            actual_J_lost, delta_J, frac
         real(dp), parameter :: min_tau = 300

         integer :: k, last_k

         include 'formats'

         ierr = 0
         mass_lost = s% mstar_old - s% mstar
         if (mass_lost <= 0) return

         actual_J_lost = &
            s% adjust_J_fraction*mass_lost*s% j_rot_avg_surf + &
            (1d0 - s% adjust_J_fraction)*s% angular_momentum_removed
         delta_J = actual_J_lost - s% angular_momentum_removed

         if (delta_J == 0d0) then
            ! nothing extra to remove (frac would be exactly 1)
            s% angular_momentum_removed = actual_J_lost
            return
         end if

         dm00 = 0
         J = 0
         last_k = s% nz
         do k = 1, s% nz
            dmm1 = dm00
            dm00 = s% dm(k)
            dm = 0.5d0*(dmm1+dm00)   ! mass associated with j_rot(k)
            J = J + dm*s% j_rot(k)
            if (s% tau(k) < 3*min_tau) cycle
            if (J > 1.5d0*abs(delta_J)) then
               last_k = k
               exit
            end if
         end do

         if (J == 0d0) then
            ! no angular momentum to scale; avoid the division by zero below
            ierr = -1
            return
         end if

         frac = (J - delta_J)/J
         do k = 1, last_k
            s% j_rot(k) = frac*s% j_rot(k)
         end do

         s% angular_momentum_removed = actual_J_lost

      end subroutine adjust_J_lost
         
         


      end module solve_omega_mix


