#include "cppdefs.h"
!-----------------------------------------------------------------------
!BOP
!
! !MODULE:  Salinity
!
! !INTERFACE:
   module salinity
!
! !DESCRIPTION:
!
! In this module, the salinity equation is processed by
! reading in the namelist {\tt salt} and initialising the salinity field
! (this is done in {\tt init\_salinity}),
! and calculating the advection-diffusion-equation (see {\tt do\_salinity}).
!
! !USES:
   use exceptions
   use domain, only: imin,jmin,imax,jmax,kmax,H,az,dry_z
!KB   use get_field, only: get_3d_field
   use variables_2d, only: fwf_int
   use variables_3d, only: rk,S,hn,kmin,sss
   use variables_3d, only: saltbins,saltbinmin,saltbinmax,Si_s
   use halo_zones, only: update_3d_halo,wait_halo,D_TAG,H_TAG
   IMPLICIT NONE

   interface
!     Explicit interface for the subroutine tracer_diffusion, whose body is
!     not part of this module. An explicit interface is required by the
!     Fortran standard here because the routine has optional and pointer
!     dummy arguments and is called with keyword arguments (see the call
!     in do_salinity).
!
!     Arguments as declared below:
!       f              - 3D tracer field, modified in place (intent inout)
!       hn             - 3D field, input only (do_salinity passes the
!                        layer heights hn from variables_3d)
!       AH_method      - integer switch; do_salinity passes salt_AH_method
!       AH_const, AH_Prt, AH_stirr_const
!                      - real parameters; do_salinity passes the matching
!                        salt_AH_* namelist values
!       ffluxu, ffluxv - optional contiguous 3D pointers
!       phymix         - optional contiguous 3D pointer
!     NOTE(review): what the routine does with the optional pointer
!     arguments cannot be seen from this file - confirm in its source.
      subroutine tracer_diffusion(f,hn,AH_method,AH_const,AH_Prt,AH_stirr_const, &
                                  ffluxu,ffluxv,                                 &
                                  phymix)
         use domain, only: imin,imax,jmin,jmax,kmax
         IMPLICIT NONE
         REALTYPE,intent(in)           :: hn(I3DFIELD)
         integer,intent(in)            :: AH_method
         REALTYPE,intent(in)           :: AH_const,AH_Prt,AH_stirr_const
         REALTYPE,intent(inout)        :: f(I3DFIELD)
         REALTYPE,dimension(:,:,:),pointer,contiguous,intent(in),optional :: ffluxu,ffluxv
         REALTYPE,dimension(:,:,:),pointer,contiguous,intent(in),optional :: phymix
      end subroutine tracer_diffusion
   end interface
!
   private
!
! !PUBLIC DATA MEMBERS:
   public init_salinity, postinit_salinity, do_salinity, init_salinity_field
!
! !PRIVATE DATA MEMBERS:
!  All scalars below are default values that can be overridden through
!  the namelist salt, which is read in init_salinity.
!
!  salt_method selects how init_salinity_field builds the initial field:
!  0 = taken from hotstart, 1 = constant value salt_const,
!  2 = vertical profile read from salt_file, 3 = 3D field read from
!  salt_file. salt_format is part of the namelist but is not referenced
!  in the portion of this file that was reviewed.
   integer                   :: salt_method=1,salt_format=2
!  Input file used for salt_method 2 (profile) and 3 (3D field).
   character(len=PATH_MAX)   :: salt_file="t_and_s.nc"
!  Field number and variable name passed to get_3d_field (salt_method 3).
   integer                   :: salt_field_no=1
   character(len=32)         :: salt_name='salt'
!  Constant initial salinity used for salt_method 1.
   REALTYPE                  :: salt_const=35*_ONE_
!  Advection settings, handed unchanged to print_adv_settings_3d and
!  do_advection_3d. init_salinity stops if salt_adv_hor equals J7.
   integer                   :: salt_adv_split=0
   integer                   :: salt_adv_hor=1
   integer                   :: salt_adv_ver=1
!  Added to nuh when the vertical diffusion coefficients are built in
!  do_salinity; negative namelist values are reset to zero in
!  init_salinity.
   REALTYPE                  :: avmols = 1.1d-9
!  Horizontal diffusion switch, checked in init_salinity:
!  0 = disabled, 1 = constant diffusivity (salt_AH_const),
!  2 = LES parameterisation (uses salt_AH_Prt),
!  3 = SGS stirring parameterisation (uses salt_AH_stirr_const).
   integer                   :: salt_AH_method=0
   REALTYPE                  :: salt_AH_const=1.1d-9
   REALTYPE                  :: salt_AH_Prt=_TWO_
   REALTYPE                  :: salt_AH_stirr_const=_ONE_
!  Values greater than zero enable nudging of the surface salinity
!  towards sss (sss is then allocated in init_salinity).
   REALTYPE                  :: sss_nudging_time=-_ONE_
!  Sanity check of the salinity field against [min_salt,max_salt]:
!  0 = no check, >0 = out-of-bound values terminate the program,
!  <0 = out-of-bound values only produce a warning.
   integer                   :: salt_check=0
   REALTYPE                  :: min_salt=_ZERO_,max_salt=40*_ONE_
!  Public switch; forced to 1 in init_salinity when the obsolete
!  NONNEGSALT macro is defined. Its use is outside the reviewed code.
   integer,public            :: nonnegsalt_method=0
!  Salinity-bin index per layer and water column, stored with the
!  vertical index first. Allocated (and preset to -1) in init_salinity
!  only when runtype is 4 and saltbins is positive; filled by rebin.
   integer, allocatable      :: binidx_kij(:,:,:)
!
! !REVISION HISTORY:
!  Original author(s): Karsten Bolding & Hans Burchard
!
!EOP
!-----------------------------------------------------------------------

   contains

!-----------------------------------------------------------------------
!BOP
!
! !IROUTINE: init_salinity - initialisation of salinity
! \label{sec-init-salinity}
!
! !INTERFACE:
   subroutine init_salinity(runtype,hotstart)
!
! !DESCRIPTION:
!
! Here, the salinity equation is initialised. First, the namelist
! {\tt salt} is read from the namelist unit. Negative values of
! {\tt avmols} are reset to zero. Then the advection settings are
! printed (the horizontal scheme {\tt J7} is rejected) and the chosen
! horizontal diffusion method {\tt salt\_AH\_method} is validated:
! \begin{itemize}
! \item 0: horizontal diffusion disabled,
! \item 1: constant diffusivity {\tt salt\_AH\_const} (must not be negative),
! \item 2: LES parameterisation; the deformation-rate switches and
!          {\tt les\_mode} are set accordingly,
! \item 3: SGS stirring parameterisation ({\tt salt\_AH\_stirr\_const}
!          must not be negative); the deformation-rate switches and
!          {\tt calc\_stirr} are set.
! \end{itemize}
! If {\tt sss\_nudging\_time} is positive, the field {\tt sss} is
! allocated. If {\tt runtype} equals 4 and {\tt saltbins} is positive,
! the equidistant salinity bin interfaces {\tt Si\_s} between
! {\tt saltbinmin} and {\tt saltbinmax} and the bin index array
! {\tt binidx\_kij} are allocated.
! Unless a hotstart is done, the salinity field itself is set up by
! {\tt init\_salinity\_field} and, if {\tt salt\_check} is not zero,
! checked against {\tt min\_salt} and {\tt max\_salt}
! (positive {\tt salt\_check}: error, negative: warning only).
!
! Every failed allocation stops the program with a message naming this
! routine and the array concerned.
!
! !USES:
   use advection, only: J7
   use advection_3d, only: print_adv_settings_3d
   use variables_3d, only: deformC_3d,deformX_3d,deformUV_3d,calc_stirr
   use m2d, only: Am_method,AM_LES
   use les, only: les_mode,LES_TRACER,LES_BOTH
   IMPLICIT NONE
!
! !INPUT PARAMETERS:
!  runtype  - type of model run; here only the value 4 is tested
!             (enables the salinity bins)
!  hotstart - if true, the salinity field is neither initialised nor
!             checked here
   integer,intent(in)        :: runtype
   logical,intent(in)        :: hotstart
!
! !LOCAL VARIABLES:
!  l      - bin counter
!  rc     - status returned by allocate
!  status - number of out-of-bound values reported by check_3d_fields
   integer                   :: l,rc
   integer                   :: status
   NAMELIST /salt/                                            &
            salt_method,salt_const,salt_file,                 &
            salt_format,salt_name,salt_field_no,              &
            salt_adv_split,salt_adv_hor,salt_adv_ver,         &
            avmols,salt_AH_method,salt_AH_const,salt_AH_Prt,  &
            salt_AH_stirr_const,sss_nudging_time,             &
            saltbins,saltbinmin,saltbinmax,                   &
            salt_check,min_salt,max_salt,                     &
            nonnegsalt_method
!EOP
!-------------------------------------------------------------------------
!BOC
#ifdef DEBUG
!   integer, save :: Ncall = 0
!   Ncall = Ncall+1
!   write(debug,*) 'init_salinity() # ',Ncall
#endif

   LEVEL2 'init_salinity()'
   read(NAMLST,salt)

   if (avmols .lt. _ZERO_) then
      LEVEL3 'setting avmols to 0.'
      avmols = _ZERO_
   else
      LEVEL3 'avmols = ',real(avmols)
   end if

!  Sanity checks for advection specifications
   LEVEL3 'Advection of salinity'
   if (salt_adv_hor .eq. J7) stop 'init_salinity: J7 not implemented yet'
   call print_adv_settings_3d(salt_adv_split,salt_adv_hor,salt_adv_ver,_ZERO_)

!  Horizontal diffusion method
   select case (salt_AH_method)
      case(0)
         LEVEL3 'salt_AH_method=0 -> horizontal diffusion of salt disabled'
      case(1)
         LEVEL3 'salt_AH_method=1 -> Using constant horizontal diffusivity of salt'
         if (salt_AH_const .lt. _ZERO_) then
            call getm_error("init_salinity()", &
                            "Constant horizontal diffusivity of salt <0")
         end if
         LEVEL4 real(salt_AH_const)
      case(2)
         LEVEL3 'salt_AH_method=2 -> using LES parameterisation'
         LEVEL4 'turbulent Prandtl number: ',real(salt_AH_Prt)
         deformC_3d =.true.
         deformX_3d =.true.
         deformUV_3d=.true.
         if (Am_method .eq. AM_LES) then
            les_mode = LES_BOTH
         else
            les_mode = LES_TRACER
         end if
      case(3)
         LEVEL3 'salt_AH_method=3 -> SGS stirring parameterisation'
         if (salt_AH_stirr_const .lt. _ZERO_) then
            call getm_error("init_salinity()", &
                            "salt_AH_stirr_const <0")
         end if
         LEVEL4 'stirring constant: ',real(salt_AH_stirr_const)
         deformC_3d =.true.
         deformX_3d =.true.
         deformUV_3d=.true.
         calc_stirr=.true.
      case default
         call getm_error("init_salinity()", &
                         "A non valid salt_AH_method has been chosen")
   end select

!  Surface salinity nudging needs the target field sss
   if (sss_nudging_time.gt._ZERO_) then
      LEVEL3 'nudging of SSS enabled'
      LEVEL4 'sss_nudging_time=',real(sss_nudging_time)
      allocate(sss(E2DFIELD),stat=rc)
!     the allocation status used to be ignored here
      if (rc /= 0) stop 'init_salinity: Error allocating memory (sss)'
   end if

!  Salinity bins: saltbins equidistant intervals with interfaces
!  Si_s(0)=saltbinmin ... Si_s(saltbins)=saltbinmax
   if (runtype.eq.4 .and. saltbins.gt.0) then
      LEVEL3 'salinity space for TEF:',real(saltbinmin),real(saltbinmax),'(',saltbins,'bins)'
      allocate(Si_s(0:saltbins),stat=rc)
      if (rc /= 0) stop 'init_salinity: Error allocating memory (Si_s)'
      do l=0,saltbins-1
         Si_s(l) = saltbinmin + (saltbinmax-saltbinmin)/saltbins * l
      end do
      Si_s(saltbins) = saltbinmax ! direct assignment to avoid truncation errors
      allocate(binidx_kij(_KRANGE_,I2DFIELD),stat=rc)
      if (rc /= 0) stop 'init_salinity: Error allocating memory (binidx_kij)'
!     -1 marks "no bin assigned yet"
      binidx_kij = -1
   end if

   if (.not. hotstart) then
      call init_salinity_field()
   end if

   LEVEL3 'salt_check=',salt_check
   if (salt_check .ne. 0) then
      LEVEL4 'doing sanity check on salinity'
      LEVEL4 'min_salt=',min_salt
      LEVEL4 'max_salt=',max_salt
      if (salt_check .gt. 0) then
         LEVEL4 'out-of-bound values result in termination of program'
      else
         LEVEL4 'out-of-bound values result in warnings only'
      end if

!     With a hotstart the field is not yet available at this point,
!     so the initial check is only done for a cold start.
      if (.not. hotstart) then
         call check_3d_fields(imin,jmin,imax,jmax,kmin,kmax,az, &
                              S,min_salt,max_salt,status)
         if (status .gt. 0) then
            if (salt_check .gt. 0) then
               call getm_error("init_salinity()", &
                               "out-of-bound values encountered")
            else
               LEVEL1 'init_salinity(): ',status, &
                      ' out-of-bound values encountered'
            end if
         end if
      end if

   end if

#ifdef NONNEGSALT
   if (nonnegsalt_method .ne. 1) then
      LEVEL3 "reset nonnegsalt_method=1 due to obsolete"
      LEVEL3 "NONNEGSALT macro. Note that this behaviour"
      LEVEL3 "will be removed in the future!"
      nonnegsalt_method = 1
   end if
#endif

   LEVEL3 'nonnegsalt_method = ',nonnegsalt_method

#ifdef DEBUG
   write(debug,*) 'Leaving init_salinity()'
   write(debug,*)
#endif
   return
   end subroutine init_salinity
!EOC

!-----------------------------------------------------------------------
!BOP
!
! !IROUTINE: init_salinity_field - initialisation of the salinity field
! \label{sec-init-salinity-field}
!
! !INTERFACE:
   subroutine init_salinity_field()
!
! !DESCRIPTION:
! Initialisation of the salinity field as specified by salt\_method
! and exchange of the HALO zones:
! \begin{itemize}
! \item 0: nothing is set here (field is expected from the hotstart),
! \item 1: constant value {\tt salt\_const} in all non-land columns,
! \item 2: profile read from {\tt salt\_file} by {\tt read\_profile}
!          and passed to {\tt ver\_interpol},
! \item 3: 3D field {\tt salt\_name} obtained from {\tt salt\_file}
!          by {\tt get\_3d\_field}.
! \end{itemize}
! Any other value stops the program. Afterwards the level 0 and all
! columns with a zero mask {\tt az} are set to a missing value, the
! halo zones are updated and {\tt mirror\_bdy\_3d} is called.
!
! !USES:
   IMPLICIT NONE
!
! !INPUT PARAMETERS:
!
! !INPUT/OUTPUT PARAMETERS:
!
! !OUTPUT PARAMETERS:
!
! !LOCAL VARIABLES:
!  i,j          - horizontal loop indices
!  n            - number of profile levels returned by read_profile
!  nmax         - size of the profile buffers
!  zlev,prof    - profile buffers filled by read_profile
!  salt_missing - value written to level 0 and to masked-out columns
   integer                   :: i,j,n
   integer, parameter        :: nmax=10000
   REALTYPE                  :: zlev(nmax),prof(nmax)
   REALTYPE, parameter       :: salt_missing=-9999._rk
!EOP
!-------------------------------------------------------------------------
!BOC
#ifdef DEBUG
   integer, save :: Ncall = 0
   Ncall = Ncall+1
   write(debug,*) 'init_salinity_field() # ',Ncall
#endif

   select case (salt_method)
      case(0)
         LEVEL3 'getting initial fields from hotstart'
      case(1)
         LEVEL3 'setting to constant value ',real(salt_const)
!        explicit loops replace the obsolescent forall construct
         do j=jmin,jmax
            do i=imin,imax
               if (az(i,j) .ne. 0) S(i,j,:) = salt_const
            end do
         end do
      case(2)
         LEVEL3 'using profile'
         LEVEL4 trim(salt_file)
         call read_profile(salt_file,nmax,zlev,prof,n)
         call ver_interpol(n,zlev,prof,imin,jmin,imax,jmax,kmax, &
                           az,H,hn,S)
      case(3)
         LEVEL3 'interpolating from 3D field'
         LEVEL4 trim(salt_file)
         call get_3d_field(salt_file,salt_name,salt_field_no,.true.,S)
      case default
         FATAL 'Not valid salt_method specified'
         stop 'init_salinity'
   end select

!  Mark level 0 and all columns with az equal to 0 as missing
   S(:,:,0) = salt_missing
   do j=jmin,jmax
      do i=imin,imax
         if (az(i,j) .eq. 0) S(i,j,:) = salt_missing
      end do
   end do

   call update_3d_halo(S,S,az,imin,jmin,imax,jmax,kmax,D_TAG)
   call wait_halo(D_TAG)
   call mirror_bdy_3d(S,D_TAG)


#ifdef DEBUG
   write(debug,*) 'Leaving init_salinity_field()'
   write(debug,*)
#endif
   return
   end subroutine init_salinity_field
!EOC
!-----------------------------------------------------------------------
!BOP
!
! !IROUTINE: postinit_salinity -
! \label{sec-postinit-salinity}
!
! !INTERFACE:
   subroutine postinit_salinity()
!
! !DESCRIPTION:
! Initial sorting of the salinity field into salinity bins. The routine
! only acts if {\tt binidx\_kij} has been allocated (done in
! {\tt init\_salinity}). For every column with {\tt az} equal to 1,
! including the halo zones, {\tt rebin} is called for the layers
! 1 to {\tt kmax}. For every layer with a non-negative bin index the
! layer height {\tt hn}, {\tt hn*S} and {\tt hn*S*S} are added to the
! corresponding bin of {\tt h\_s}, {\tt hS\_s} and {\tt hS2\_s},
! respectively. Each of these three arrays is only touched if it is
! allocated, and is reset to zero before the accumulation.
! NOTE(review): the exact meaning of the index returned by
! {\tt rebin} is not visible here - confirm in its source.
!
! !USES:
   use variables_3d, only: h_s, hS_s, hS2_s
   use getm_timers, only: tic, toc, TIM_TEF
   IMPLICIT NONE
!
! !INPUT PARAMETERS:
!
! !INPUT/OUTPUT PARAMETERS:
!
! !OUTPUT PARAMETERS:
!
! !LOCAL VARIABLES:
!  i,j,k    - loop indices
!  bin      - bin index of the current layer
!  have_... - whether the respective binned array is allocated
   integer :: i,j,k,bin
   logical :: have_h,have_hS,have_hS2
!EOP
!-------------------------------------------------------------------------
!BOC
#ifdef DEBUG
   integer, save :: Ncall = 0
   Ncall = Ncall+1
   write(debug,*) 'postinit_salinity() # ',Ncall
#endif

!  Note (KK): This MUST be called after arrays have been allocated in
!             postinit_variables_3d()!
   if (allocated(binidx_kij)) then
      call tic(TIM_TEF)

!     The allocation status cannot change inside the loops below,
!     so it is queried only once.
      have_h   = allocated(h_s  )
      have_hS  = allocated(hS_s )
      have_hS2 = allocated(hS2_s)

      if (have_h  ) h_s   = _ZERO_
      if (have_hS ) hS_s  = _ZERO_
      if (have_hS2) hS2_s = _ZERO_

      do j = jmin-HALO,jmax+HALO
         do i = imin-HALO,imax+HALO
            if ( az(i,j) .ne. 1 ) cycle
            call rebin(kmax,S(i,j,1:),saltbins,Si_s,binidx_kij(1:,i,j))
            do k=1,kmax
               bin = binidx_kij(k,i,j)
!              negative index: layer does not contribute to any bin
               if (bin .ge. 0) then
                  if (have_h  ) h_s  (i,j,bin) = h_s  (i,j,bin) + hn(i,j,k)
                  if (have_hS ) hS_s (i,j,bin) = hS_s (i,j,bin) + hn(i,j,k)*S(i,j,k)
                  if (have_hS2) hS2_s(i,j,bin) = hS2_s(i,j,bin) + hn(i,j,k)*S(i,j,k)*S(i,j,k)
               end if
            end do
         end do
      end do

      call toc(TIM_TEF)
   end if

#ifdef DEBUG
   write(debug,*) 'Leaving postinit_salinity()'
   write(debug,*)
#endif
   return
   end subroutine postinit_salinity
!EOC
!-----------------------------------------------------------------------
!BOP
!
! !IROUTINE:  do_salinity - salinity equation \label{sec-do-salinity}
!
! !INTERFACE:
   subroutine do_salinity(loop)
!
! !DESCRIPTION:
!
! Here, one time step for the salinity equation is performed.
! First, preparations for the call to the advection schemes are
! made, i.e.\ calculating the necessary metric coefficients.
! After the call to the advection schemes, which actually perform
! the advection (and horizontal diffusion) step as an operational
! split step, the tri-diagonal matrix for calculating the
! new salinity by means of a semi-implicit central scheme for the
! vertical diffusion is set up.
! There are no source terms on the right hand sides.
! The subroutine is completed by solving the tri-diagonal linear
! equation by means of a tri-diagonal solver.
!
! Also here, there are some specific options for single test cases
! selected by compiler options.
!
! !USES:
   use advection_3d, only: do_advection_3d
   use advection_3d, only: ffluxu2_3d,ffluxv2_3d
   use variables_3d, only: dt,cnpar,hn,ho,nuh,uu,vv,ww,hun,hvn
   use domain,       only: imin,imax,jmin,jmax,kmax,az,au,av
   use domain,       only: dxc,dyc,dyu,dxv,arcd1
   use domain,       only: sdom
   use ice,          only: have_ice, ssf_ice, svf_ice
   use getm_timers, only: tic,toc,TIM_SALT,TIM_SALTH,TIM_MIXANALYSIS,TIM_TEF
   use variables_3d, only: Sfluxu,Sfluxv,Sfluxw
   use variables_3d, only: Sfluxu2,Sfluxv2
   use variables_3d, only: counts_s,flags_s,h_s,hS_s,hS2_s,hpmS_s,hnmS_s
   use variables_3d, only: hu_s,uu_s,Sfluxu_s,S2fluxu_s
   use variables_3d, only: hv_s,vv_s,Sfluxv_s,S2fluxv_s
   use variables_3d, only: wdia_s, f3dia_s
   use variables_3d, only: fwf_s,fwfS2_s, phymix_S_fwf_s
   use variables_3d, only: phymix_S_riv_s
   use variables_3d, only: rheight_int_s, rsflux_s
   use variables_3d, only: nummix_S, phymix_S, phymix_S_fwf
   use rivers,       only: nriverl, ri, rj, rnl, rheight_int, river_salt
!$ use omp_lib
   IMPLICIT NONE
!
! !INPUT PARAMETERS:
   integer, intent(in) :: loop
!
! !LOCAL VARIABLES:
   integer                   :: i,j,k,l,n,nn,rc
   REALTYPE, POINTER         :: Res(:)
   REALTYPE, POINTER         :: auxn(:),auxo(:)
   REALTYPE, POINTER         :: a1(:),a2(:),a3(:),a4(:)
   REALTYPE, pointer         :: fluxw(:)
   REALTYPE                  :: dS_ssf(I2DFIELD)
   REALTYPE                  :: sss_old(I2DFIELD)
   REALTYPE, dimension(I3DFIELD) :: Sold,Sfluxu_tmp,Sfluxv_tmp
   REALTYPE, dimension(I2DFIELD) :: wrk2d
   REALTYPE                  :: rivmix(0:kmax,1:nriverl)
   REALTYPE                  :: deltaS,hu_k,uu_k
   REALTYPE                  :: S_k(0:kmax)
   integer                   :: binidx(0:kmax)
  integer                    :: status
   logical                   :: calc_phymix_fwf
!EOP
!-----------------------------------------------------------------------
!BOC
#ifdef DEBUG
   integer, save :: Ncall = 0
   Ncall = Ncall+1
   write(debug,*) 'do_salinity() # ',Ncall
#endif
   call tic(TIM_SALT)

   if (allocated(phymix_S_riv_s)) then
      do nn=1,nriverl
         rivmix(:,nn) = phymix_S(ri(nn),rj(nn),:)
      end do
   end if

   calc_phymix_fwf = ( associated(phymix_S) .or. allocated(phymix_S_fwf) )
   if (calc_phymix_fwf) sss_old = S(:,:,kmax)

   dS_ssf = _ZERO_

   where( sdom .eq. 1 ) ! open ocean
!     Note (KK): fwf_int was already included into ho.
!                Need to remove corresponding salt input!
      dS_ssf = - fwf_int * S(:,:,kmax) / ho(:,:,kmax)
   end where
   if (have_ice) then
      where( sdom .eq. 3 ) ! glacial ice
!        Note (KK): fwf_method>0: fwf_int=dt*svf_ice was already included into ho.
!                                 Need to remove corresponding salt input!
!                   fwf_method=0: fwf_int=0, but ssf_ice still contains svf_ice*Smelt part.
!                                 Need to remove dt*svf_ice*S[kmax]!
         dS_ssf = dt * ( ssf_ice - svf_ice * S(:,:,kmax) ) / ho(:,:,kmax) ! positive incoming
      end where
   end if

   where( az .eq. 1 )
      S(:,:,kmax) = S(:,:,kmax) + dS_ssf
   end where

   if (calc_phymix_fwf) then
!     Note (KK): division by hn because we already refer to new volume
!                where diffusion will be added
      where (az .eq. 1) wrk2d = sss_old * S(:,:,kmax) * fwf_int / dt
      if (associated(phymix_S)) then
         where (az .eq. 1) phymix_S(:,:,kmax) = phymix_S(:,:,kmax) + wrk2d
      end if
      if (allocated(phymix_S_fwf)) then
         where (az .eq. 1) phymix_S_fwf = wrk2d
      end if
   end if

   if (allocated(fwf_s).or.allocated(fwfS2_s)) then
      call tic(TIM_TEF)
      if (allocated(fwf_s  )) fwf_s   = _ZERO_
      if (allocated(fwfS2_s)) fwfS2_s = _ZERO_
      do j = jmin,jmax
         do i = imin,imax
            if (az(i,j) .eq. 1) then
               l = binidx_kij(kmax,i,j) ! binning based on old salinity
               if (l .ge. 0) then
                  if (allocated(fwf_s  )) fwf_s  (i,j,l) = fwf_int(i,j)
                  if (allocated(fwfS2_s)) fwfS2_s(i,j,l) = fwf_int(i,j)*S(i,j,kmax)*S(i,j,kmax)
               end if
            end if
         end do
      end do
      call toc(TIM_TEF)
   end if


   call do_advection_3d(dt,S,uu,vv,ww,hun,hvn,ho,hn,                           &
                        salt_adv_split,salt_adv_hor,salt_adv_ver,_ZERO_,H_TAG, &
                        ffluxu=Sfluxu,ffluxv=Sfluxv,ffluxw=Sfluxw,nvd=nummix_S)
   if (allocated(Sfluxu2)) Sfluxu2 = ffluxu2_3d
   if (allocated(Sfluxv2)) Sfluxv2 = ffluxv2_3d

!  KK-TODO: do we have to treat salt fluxes from different
!           directional-split steps separately?
   if (allocated(hu_s).or.allocated(uu_s).or.allocated(Sfluxu_s).or.allocated(S2fluxu_s)) then
      call tic(TIM_TEF)
      if (allocated(hu_s     )) hu_s      = _ZERO_
      if (allocated(uu_s     )) uu_s      = _ZERO_
      if (allocated(Sfluxu_s )) Sfluxu_s  = _ZERO_
      if (allocated(S2fluxu_s)) S2fluxu_s = _ZERO_
      do j = jmin,jmax
         do i = imin-1,imax
            if (au(i,j).eq.1 .or. au(i,j).eq.2) then
               where (uu(i,j,1:kmax) .ne. _ZERO_)
                  S_k(1:kmax) = Sfluxu(i,j,1:kmax) / uu(i,j,1:kmax) / DYU
               else where
                  S_k(1:kmax) = saltbinmax + 9999.0_rk
               end where
               call rebin(kmax,S_k(1:),saltbins,Si_s,binidx(1:))
               do k=1,kmax
                  l = binidx(k)
                  if (l .lt. 0) cycle
                  if (allocated(hu_s     )) hu_s     (i,j,l) = hu_s     (i,j,l) + hun(i,j,k)
                  if (allocated(uu_s     )) uu_s     (i,j,l) = uu_s     (i,j,l) + uu(i,j,k)
                  if (allocated(Sfluxu_s )) Sfluxu_s (i,j,l) = Sfluxu_s (i,j,l) + Sfluxu(i,j,k)
                  if (allocated(S2fluxu_s)) S2fluxu_s(i,j,l) = S2fluxu_s(i,j,l) + Sfluxu2(i,j,k)
               end do
            end if
         end do
      end do
      call toc(TIM_TEF)
   end if
   if (allocated(hv_s).or.allocated(vv_s).or.allocated(Sfluxv_s).or.allocated(S2fluxv_s)) then
      call tic(TIM_TEF)
      if (allocated(hv_s     )) hv_s      = _ZERO_
      if (allocated(vv_s     )) vv_s      = _ZERO_
      if (allocated(Sfluxv_s )) Sfluxv_s  = _ZERO_
      if (allocated(S2fluxv_s)) S2fluxv_s = _ZERO_
      do j = jmin-1,jmax
         do i = imin,imax
            if (av(i,j).eq.1 .or. av(i,j).eq.2) then
               where (vv(i,j,1:kmax) .ne. _ZERO_)
                  S_k(1:kmax) = Sfluxv(i,j,1:kmax) / vv(i,j,1:kmax) / DXV
               else where
                  S_k(1:kmax) = saltbinmax + 9999.0_rk
               end where
               call rebin(kmax,S_k(1:),saltbins,Si_s,binidx(1:))
               do k=1,kmax
                  l = binidx(k)
                  if (l .lt. 0) cycle
                  if (allocated(hv_s     )) hv_s     (i,j,l) = hv_s     (i,j,l) + hvn(i,j,k)
                  if (allocated(vv_s     )) vv_s     (i,j,l) = vv_s     (i,j,l) + vv(i,j,k)
                  if (allocated(Sfluxv_s )) Sfluxv_s (i,j,l) = Sfluxv_s (i,j,l) + Sfluxv(i,j,k)
                  if (allocated(S2fluxv_s)) S2fluxv_s(i,j,l) = S2fluxv_s(i,j,l) + Sfluxv2(i,j,k)
               end do
            end if
         end do
      end do
      call toc(TIM_TEF)
   end if


   if (salt_AH_method .gt. 0) then

!     S is not halo updated after advection
      call tic(TIM_SALTH)
      call update_3d_halo(S,S,az,imin,jmin,imax,jmax,kmax,D_TAG)
      call wait_halo(D_TAG)
      call toc(TIM_SALTH)

      if (    allocated(hu_s).or.allocated(uu_s).or.allocated(Sfluxu_s).or.allocated(S2fluxu_s)       &
          .or.allocated(hv_s).or.allocated(vv_s).or.allocated(Sfluxv_s).or.allocated(S2fluxv_s)) then
         call tic(TIM_TEF)
         Sold = S
         do j = jmin-1,jmax+1
            do i = imin-1,imax+1
!              Note (KK): here we need the bins also in az=2 !!!
               if (az(i,j).ne.0) call rebin(kmax,S(i,j,1:),saltbins,Si_s,binidx_kij(1:,i,j))
            end do
         end do
         call toc(TIM_TEF)
      end if
      if (allocated(hu_s).or.allocated(uu_s).or.allocated(Sfluxu_s)) Sfluxu_tmp = Sfluxu
      if (allocated(hv_s).or.allocated(vv_s).or.allocated(Sfluxv_s)) Sfluxv_tmp = Sfluxv

      call tracer_diffusion(S,hn,salt_AH_method,salt_AH_const,salt_AH_Prt,salt_AH_stirr_const, &
                            ffluxu=Sfluxu,ffluxv=Sfluxv,               &
                            phymix=phymix_S)
      if (allocated(Sfluxu2)) Sfluxu2 = Sfluxu2 + ffluxu2_3d
      if (allocated(Sfluxv2)) Sfluxv2 = Sfluxv2 + ffluxv2_3d

      if (allocated(hu_s).or.allocated(uu_s).or.allocated(Sfluxu_s).or.allocated(S2fluxu_s)) then
         call tic(TIM_TEF)
         Sfluxu_tmp = Sfluxu - Sfluxu_tmp
         do j = jmin,jmax
            do i = imin-1,imax
               if (au(i,j).eq.1 .or. au(i,j).eq.2) then
                  do k=1,kmax
                     if (Sfluxu_tmp(i,j,k) .ne. _ZERO_) then
                     deltaS = Sold(i+1,j,k) - Sold(i,j,k)
                     !if (deltaS .ne. _ZERO_) then
                        hu_k = _HALF_ * ( hn(i,j,k) + hn(i+1,j,k) )
                        uu_k = Sfluxu_tmp(i,j,k) / deltaS / DYU
!                       fluxes with S(i)
                        l = binidx_kij(k,i,j)
                        if (l .ge. 0) then
                           if (allocated(hu_s     )) hu_s     (i,j,l) = hu_s     (i,j,l) + hu_k
                           if (allocated(uu_s     )) uu_s     (i,j,l) = uu_s     (i,j,l) - uu_k
                           if (allocated(Sfluxu_s )) Sfluxu_s (i,j,l) = Sfluxu_s (i,j,l) - uu_k*DYU*Sold(i,j,k)
                           if (allocated(S2fluxu_s)) S2fluxu_s(i,j,l) = S2fluxu_s(i,j,l) - uu_k*DYU*Sold(i,j,k)*Sold(i,j,k)
                        end if
!                       fluxes with S(i+1)
                        l = binidx_kij(k,i+1,j)
                        if (l .ge. 0) then
                           if (allocated(hu_s     )) hu_s     (i,j,l) = hu_s     (i,j,l) + hu_k
                           if (allocated(uu_s     )) uu_s     (i,j,l) = uu_s     (i,j,l) + uu_k
                           if (allocated(Sfluxu_s )) Sfluxu_s (i,j,l) = Sfluxu_s (i,j,l) + uu_k*DYU*Sold(i+1,j,k)
                           if (allocated(S2fluxu_s)) S2fluxu_s(i,j,l) = S2fluxu_s(i,j,l) + uu_k*DYU*Sold(i+1,j,k)*Sold(i+1,j,k)
                        end if
                     end if
                  end do
               end if
            end do
         end do
         call toc(TIM_TEF)
      end if
      if (allocated(hv_s).or.allocated(vv_s).or.allocated(Sfluxv_s).or.allocated(S2fluxv_s)) then
         call tic(TIM_TEF)
         Sfluxv_tmp = Sfluxv - Sfluxv_tmp
         do j = jmin-1,jmax
            do i = imin,imax
               if (av(i,j).eq.1 .or. av(i,j).eq.2) then
                  do k=1,kmax
                     if (Sfluxv_tmp(i,j,k) .ne. _ZERO_) then
                     deltaS = Sold(i,j+1,k) - Sold(i,j,k)
                     !if (deltaS .ne. _ZERO_) then
                        hu_k = _HALF_ * ( hn(i,j,k) + hn(i,j+1,k) )
                        uu_k = Sfluxv_tmp(i,j,k) / deltaS / DXV
!                       fluxes with S(j)
                        l = binidx_kij(k,i,j)
                        if (l .ge. 0) then
                           if (allocated(hv_s     )) hv_s     (i,j,l) = hv_s     (i,j,l) + hu_k
                           if (allocated(vv_s     )) vv_s     (i,j,l) = vv_s     (i,j,l) - uu_k
                           if (allocated(Sfluxv_s )) Sfluxv_s (i,j,l) = Sfluxv_s (i,j,l) - uu_k*DXV*Sold(i,j,k)
                           if (allocated(S2fluxv_s)) S2fluxv_s(i,j,l) = S2fluxv_s(i,j,l) - uu_k*DXV*Sold(i,j,k)*Sold(i,j,k)
                        end if
!                       fluxes with S(j+1)
                        l = binidx_kij(k,i,j+1)
                        if (l .ge. 0) then
                           if (allocated(hv_s     )) hv_s     (i,j,l) = hv_s     (i,j,l) + hu_k
                           if (allocated(vv_s     )) vv_s     (i,j,l) = vv_s     (i,j,l) + uu_k
                           if (allocated(Sfluxv_s )) Sfluxv_s (i,j,l) = Sfluxv_s (i,j,l) + uu_k*DXV*Sold(i,j+1,k)
                           if (allocated(S2fluxv_s)) S2fluxv_s(i,j,l) = S2fluxv_s(i,j,l) + uu_k*DXV*Sold(i,j+1,k)*Sold(i,j+1,k)
                        end if
                     end if
                  end do
               end if
            end do
         end do
         call toc(TIM_TEF)
      end if

   end if


! OMP-NOTE: Pointers are used so that each thread has its
!           own work storage.
!$OMP PARALLEL DEFAULT(SHARED)                                         &
!$OMP    PRIVATE(i,j,k,rc)                                             &
!$OMP    PRIVATE(Res,auxn,auxo,a1,a2,a3,a4,fluxw)

! Each thread allocates its own HEAP storage:
   allocate(Res(0:kmax),stat=rc)    ! work array
   if (rc /= 0) stop 'do_salinity: Error allocating memory (Res)'
   allocate(auxn(1:kmax-1),stat=rc)    ! work array
   if (rc /= 0) stop 'do_salinity: Error allocating memory (auxn)'
   allocate(auxo(1:kmax-1),stat=rc)    ! work array
   if (rc /= 0) stop 'do_salinity: Error allocating memory (auxo)'
   allocate(a1(0:kmax),stat=rc)    ! work array
   if (rc /= 0) stop 'do_salinity: Error allocating memory (a1)'
   allocate(a2(0:kmax),stat=rc)    ! work array
   if (rc /= 0) stop 'do_salinity: Error allocating memory (a2)'
   allocate(a3(0:kmax),stat=rc)    ! work array
   if (rc /= 0) stop 'do_salinity: Error allocating memory (a3)'
   allocate(a4(0:kmax),stat=rc)    ! work array
   if (rc /= 0) stop 'do_salinity: Error allocating memory (auxo)'

   fluxw => null()
   if (associated(Sfluxw) .or. associated(phymix_S)) then
      allocate(fluxw(0:kmax),stat=rc)    ! work array
      if (rc /= 0) stop 'do_salinity: Error allocating memory (fluxw)'
   end if

! Note: We do not need to initialize these work arrays.
!   Tested BJB 2009-09-25.

!  Advection and vertical diffusion of salinity
!$OMP DO SCHEDULE(RUNTIME)
   do j=jmin,jmax
      do i=imin,imax
         if (az(i,j) .eq. 1) then
            if (kmax.gt.1) then
!     Auxiliary terms, old and new time level
               do k=1,kmax-1
                  auxo(k)=_TWO_*(1-cnpar)*dt*(nuh(i,j,k)+avmols)/ &
                             (hn(i,j,k+1)+hn(i,j,k))
                  auxn(k)=_TWO_*   cnpar *dt*(nuh(i,j,k)+avmols)/ &
                             (hn(i,j,k+1)+hn(i,j,k))
               end do

!        Matrix elements for surface layer
               k=kmax
               a1(k)=-auxn(k-1)
               a2(k)=hn(i,j,k)+auxn(k-1)
               a4(k)=S(i,j,k)*(hn(i,j,k)-auxo(k-1))+S(i,j,k-1)*auxo(k-1)
               if (allocated(sss)) then
!                 implicit nudging
                  a2(kmax) = a2(kmax) + dry_z(i,j)*hn(i,j,kmax)*dt/sss_nudging_time
                  a4(kmax) = a4(kmax) + dry_z(i,j)*hn(i,j,kmax)*dt/sss_nudging_time*sss(i,j)
!                 explicit nudging
                  !a4(kmax) = a4(kmax) - dry_z(i,j)*dt*hn(i,j,kmax)*(S(i,j,kmax)-sss(i,j))/sss_nudging_time
               end if

!        Matrix elements for inner layers
               do k=2,kmax-1
                  a3(k)=-auxn(k  )
                  a1(k)=-auxn(k-1)
                  a2(k)=hn(i,j,k)+auxn(k)+auxn(k-1)
                  a4(k)=S(i,j,k+1)*auxo(k)                           &
                       +S(i,j,k  )*(hn(i,j,k)-auxo(k)-auxo(k-1))     &
                       +S(i,j,k-1)*auxo(k-1)
               end do

!        Matrix elements for bottom layer
               k=1
               a3(k)=-auxn(k  )
               a2(k)=hn(i,j,k)+auxn(k)
               a4(k)=S(i,j,k+1)*auxo(k)                              &
                    +S(i,j,k  )*(hn(i,j,k)-auxo(k))

               call getm_tridiagonal(kmax,1,kmax,a1,a2,a3,a4,Res)

               if (associated(fluxw)) then
                  fluxw(   0) = _ZERO_
                  fluxw(kmax) = _ZERO_
                  do k=1,kmax-1
                     fluxw(k) = - (                                      &
                                     auxo(k) * ( S(i,j,k+1) - S(i,j,k) ) &
                                   + auxn(k) * ( Res  (k+1) - Res  (k) ) &
                                  ) / dt
                  end do
               end if
               if (associated(Sfluxw)) then
                  Sfluxw(i,j,1:kmax-1) = Sfluxw(i,j,1:kmax-1) + fluxw(1:kmax-1)
               end if
               if (associated(phymix_S)) then
!                 KK-TODO: only consider changes of S2 due to diffusion!
!                          (no sss_nudging or radiation for temperature)
                  do k=1,kmax-1
#if 0
                     fluxw(k) = fluxw(k) *                             &
                               (         cnpar *(Res  (k)+Res  (k+1))  &
                                + (_ONE_-cnpar)*(S(i,j,k)+S(i,j,k+1)))
#else
                     fluxw(k) = - (                                                              &
                                     auxo(k) * ( S(i,j,k+1) - S(i,j,k) ) * (S(i,j,k)+S(i,j,k+1)) &
                                   + auxn(k) * ( Res  (k+1) - Res  (k) ) * (Res  (k)+Res  (k+1)) &
                                  ) / dt
#endif

                  end do
                  do k=1,kmax
                     phymix_S(i,j,k) = phymix_S(i,j,k)                 &
                           - (  hn(i,j,k)*(Res(k)*Res(k) - S(i,j,k)*S(i,j,k))/dt &
                              + (fluxw(k)  - fluxw(k-1) ) )
                  end do
               end if

               do k=1,kmax
                  S(i,j,k)=Res(k)
               end do

            end if
         end if
      end do
   end do
!$OMP END DO

! Each thread must deallocate its own HEAP storage:
   deallocate(Res,stat=rc)
   if (rc /= 0) stop 'do_salinity: Error deallocating memory (Res)'
   deallocate(auxn,stat=rc)
   if (rc /= 0) stop 'do_salinity: Error deallocating memory (auxn)'
   deallocate(auxo,stat=rc)
   if (rc /= 0) stop 'do_salinity: Error deallocating memory (auxo)'
   deallocate(a1,stat=rc)
   if (rc /= 0) stop 'do_salinity: Error deallocating memory (a1)'
   deallocate(a2,stat=rc)
   if (rc /= 0) stop 'do_salinity: Error deallocating memory (a2)'
   deallocate(a3,stat=rc)
   if (rc /= 0) stop 'do_salinity: Error deallocating memory (a3)'
   deallocate(a4,stat=rc)
   if (rc /= 0) stop 'do_salinity: Error deallocating memory (a4)'

   if (associated(fluxw)) then
      deallocate(fluxw,stat=rc)
      if (rc /= 0) stop 'do_salinity: Error deallocating memory (fluxw)'
   end if

!$OMP END PARALLEL

   if (allocated(binidx_kij)) then
      do j = jmin,jmax
         do i = imin,imax
            if (az(i,j).eq.1) call rebin(kmax,S(i,j,1:),saltbins,Si_s,binidx_kij(1:,i,j))
         end do
      end do
   end if

   if (allocated(counts_s).or.allocated(flags_s).or.allocated(h_s).or.allocated(hS_s).or.allocated(hS2_s).or.allocated(hpmS_s).or.allocated(hnmS_s).or.allocated(phymix_S_fwf_s)) then
      call tic(TIM_TEF)
      if (allocated(counts_s)) counts_s = 0
      if (allocated(flags_s )) flags_s  = 0
      if (allocated(wdia_s  )) wdia_s   = h_s  ! need old h_s
      if (allocated(f3dia_s )) f3dia_s  = hS_s ! need old hS_s
      if (allocated(h_s     )) h_s      = _ZERO_
      if (allocated(hS_s    )) hS_s     = _ZERO_
      if (allocated(hS2_s   )) hS2_s    = _ZERO_
      if (allocated(hpmS_s  )) hpmS_s   = _ZERO_
      if (allocated(hnmS_s  )) hnmS_s   = _ZERO_
      if (allocated(phymix_S_fwf_s)) phymix_S_fwf_s = _ZERO_
      do j = jmin,jmax
         do i = imin,imax
            if (az(i,j) .eq. 1) then
               do k=1,kmax
                  l = binidx_kij(k,i,j)
                  if (l .lt. 0) cycle
                  if (allocated(counts_s)) counts_s(i,j,l) = counts_s(i,j,l) + 1
                  if (allocated(flags_s )) flags_s (i,j,l) = 1
                  if (allocated(h_s   )) h_s   (i,j,l) = h_s   (i,j,l) + hn(i,j,k)
                  if (allocated(hS_s  )) hS_s  (i,j,l) = hS_s  (i,j,l) + hn(i,j,k)*S(i,j,k)
                  if (allocated(hS2_s )) hS2_s (i,j,l) = hS2_s (i,j,l) + hn(i,j,k)*S(i,j,k)*S(i,j,k)
                  if (allocated(hpmS_s)) hpmS_s(i,j,l) = hpmS_s(i,j,l) + phymix_S(i,j,k)
                  if (allocated(hnmS_s)) hnmS_s(i,j,l) = hnmS_s(i,j,l) + nummix_S(i,j,k)
               end do
               l = binidx_kij(kmax,i,j)
               if (l .ge. 0) then
                  if (allocated(phymix_S_fwf_s)) phymix_S_fwf_s(i,j,l) = phymix_S_fwf(i,j)
               end if
            end if
         end do
      end do
      call toc(TIM_TEF)
   end if

   if (allocated(phymix_S_riv_s)) then
      call tic(TIM_TEF)
      phymix_S_riv_s = _ZERO_
      do nn=1,nriverl
         j = rj(nn)
         i = ri(nn)
         do k=1,kmax
            l = binidx_kij(k,i,j)
            if (l .lt. 0) cycle
            phymix_S_riv_s(i,j,l) = phymix_S_riv_s(i,j,l) + rivmix(k,nn)
         end do
      end do
      call toc(TIM_TEF)
   end if

#if 0
!  Note (KK): this does not work in case of zero-gradient river outflow
   do nn=1,nriverl
      n = rnl(nn)
      call rebin( 1 , river_salt(n:n) , saltbins , Si_s , binidx(nn:nn) )
   end do
#endif

   if (allocated(wdia_s)) then
      call tic(TIM_TEF)
      do l = 0,saltbins
         do j=jmin,jmax
            do i=imin,imax
               if (az(i,j) .eq. 1) then
!                 wdia_s contains old h_s
                  wdia_s(i,j,l) = - ( h_s (i,j,l) - wdia_s(i,j,l) ) / dt                        &
                                  - (   ( uu_s(i,j,l)*dyu(i,j) - uu_s(i-1,j  ,l)*dyu(i-1,j  ) ) &
                                      + ( vv_s(i,j,l)*dxv(i,j) - vv_s(i  ,j-1,l)*dxv(i  ,j-1) ) &
                                    ) * arcd1(i,j)
               end if
            end do
         end do
      end do
!     Note (KK): fwf_int contains bad data on land
      where (az .eq. 1) wdia_s(:,:,0) = wdia_s(:,:,0) + fwf_int / dt
      do nn=1,nriverl
         !l = binidx(nn)
         !if (l .lt. 0) cycle
         j = rj(nn)
         i = ri(nn)
         !wdia_s(i,j,l) = wdia_s(i,j,l) + rheight_int(nn) / dt
         wdia_s(i,j,:) = wdia_s(i,j,:) + rheight_int_s(:,nn) / dt
      end do
      do l = 1,saltbins
         wdia_s(:,:,l) = wdia_s(:,:,l) + wdia_s(:,:,l-1)
      end do
      call toc(TIM_TEF)
   end if

   if (allocated(f3dia_s)) then
      call tic(TIM_TEF)
      do l = 0,saltbins
         do j=jmin,jmax
            do i=imin,imax
               if (az(i,j) .eq. 1) then
!                 f3dia_s contains old hS_s
                  f3dia_s(i,j,l) = - ( hS_s (i,j,l) - f3dia_s(i,j,l) ) / dt        &
                                   - (   ( Sfluxu_s(i,j,l) - Sfluxu_s(i-1,j  ,l) ) &
                                       + ( Sfluxv_s(i,j,l) - Sfluxv_s(i  ,j-1,l) ) &
                                     ) * arcd1(i,j)
               end if
            end do
         end do
      end do
      do nn=1,nriverl
         !l = binidx(nn)
         !if (l .lt. 0) cycle
         !n = rnl(nn)
         j = rj(nn)
         i = ri(nn)
         !f3dia_s(i,j,l) = f3dia_s(i,j,l) + river_salt(n) * rheight_int(nn) / dt
         f3dia_s(i,j,:) = f3dia_s(i,j,:) + rsflux_s(:,nn)
      end do
      do l = 1,saltbins
         f3dia_s(:,:,l) = f3dia_s(:,:,l) + f3dia_s(:,:,l-1)
      end do
      call toc(TIM_TEF)
   end if

   if (salt_check .ne. 0 .and. mod(loop,abs(salt_check)) .eq. 0) then
      call check_3d_fields(imin,jmin,imax,jmax,kmin,kmax,az, &
                           S,min_salt,max_salt,status)
      if (status .gt. 0) then
         if (salt_check .gt. 0) then
            call getm_error("do_salinity()", &
                            "out-of-bound values encountered")
         end if
         if (salt_check .lt. 0) then
            LEVEL1 'do_salinity(): ',status, &
                   ' out-of-bound values encountered'
         end if
      end if
   end if

   call toc(TIM_SALT)
#ifdef DEBUG
   write(debug,*) 'Leaving do_salinity()'
   write(debug,*)
#endif
   return
   end subroutine do_salinity
!EOC

!-----------------------------------------------------------------------

   end module salinity

!-----------------------------------------------------------------------
! Copyright (C) 2001 - Hans Burchard and Karsten Bolding               !
!-----------------------------------------------------------------------
