PROGRAM letkf
!=======================================================================
! 4D-LETKF with Lorenz-96
! Hybrid Offline Online Parameter Estimation with EnKF (HOOPE-EnKF)
! by Y.Sawada 20230113
! by Y.Sawada 20230119
! parameter search by Y.Sawada 20230302
!
! What this program does (as implemented below):
!  * Every MPI rank runs an independent LETKF experiment. The ranks form
!    a sqpetot x sqpetot grid (petot must be a perfect square); each grid
!    point supplies one multiplicative inflation factor for the state part
!    and one for the parameter part of the augmented state vector.
!  * The state vector is augmented with np parameters xf(nx+1:nx+np,:,:).
!    They are kept constant during the forecast (persistence model) and
!    passed to tinteg_rk4_varyingF together with the state.
!  * ny+np observations per window are read from 'obs_pseudo_obs.dat';
!    the linear observation operator is the array h (from a USEd module)
!    filled by set_h.
!  * Per-rank output (rank number as ####): guesmean####.dat and
!    analmean####.dat hold the forecast / analysis ensemble means (single
!    precision, one record per window member). Rank irank_ensout also
!    writes every analysis member to anal####.dat.
!  * RMSE, spread and mean state inflation averaged over the cycles after
!    nspinup are printed at the end.
!=======================================================================
  USE common
  USE common_letkf
  USE lorenz96
!  USE lorenz96_oro
  USE h_ope


  IMPLICIT NONE

  INCLUDE 'mpif.h'  ! YSaw 20230302
  LOGICAL,PARAMETER :: msw_detailout=.TRUE. ! YSaw 20221206
  INTEGER,PARAMETER :: ndays=360*5
  INTEGER,PARAMETER :: nt=ndays*8-1 !8
  INTEGER,PARAMETER :: nwindow=1 ! time window for 4D-LETKF
  INTEGER,PARAMETER :: nspinup=360*3*4 ! time steps for spin-up
  INTEGER,PARAMETER :: msw_local=0 ! localization mode switch
  INTEGER,PARAMETER :: irank_ensout=133 ! the only rank that writes all analysis members (unit 93)
  !INTEGER,PARAMETER :: np = nx     ! the number of parameters by YSaw/ np is defined in obs_hoope/h_ope.f90
! msw_local : localization mode switch
!  0 : fixed localization
!  1 : adaptive localization
!  2 : combination of both
  REAL(r_size) :: xlocal=3.0d0 !3.0d0 ! localization scale
  REAL(r_size),PARAMETER :: tlocal=2.0d0 ! time localization scale
! negative value for no time localization
  REAL(r_size) :: sa=1.0d0 ! adaptive localization parameter
  REAL(r_size) :: sb=1.0d0 ! adaptive localization parameter
  ! Bounds of the inflation search grid (state / parameter part).
  ! Grid spacing is (max-min)/sqpetot, so the upper bound itself is not sampled.
  REAL(r_size),PARAMETER :: msw_infl_min=1.05d0
  REAL(r_size),PARAMETER :: msw_infl_max=2.55d0
  REAL(r_size),PARAMETER :: msw_infl_para_min=1.05d0
  REAL(r_size),PARAMETER :: msw_infl_para_max=7.05d0
  REAL(r_size) :: msw_infl, msw_infl_para
  REAL(r_size), ALLOCATABLE :: parabox(:,:,:) ! (sqpetot,sqpetot,2): 1=state, 2=parameter inflation
  LOGICAL,PARAMETER :: AdaptiveInfl = .FALSE. !YSaw 20230306
! msw_infl : inflation mode switch
!  < 0 : adaptive inflation
!  > 0 : fixed inflation value
  REAL(r_size) :: parm_infl(nx+np,nt) ! inflation used at each cycle; nx+np: augmented state vector Y.Saw
  REAL(r_size) :: parm
  REAL(r_size) :: xmaxloc ! cut-off distance applied to the current observation
  REAL(r_size) :: xmaxloc_state ! cut-off distance for state observations
  REAL(r_size),PARAMETER :: xmaxloc_para=0.1d0 ! cut-off distance for parameter observations
  REAL(r_size) :: obserr=0.1d0 !0.1d0 !1.0d0
  REAL(r_size) :: obserr_para=3.15d0 ! Y.Saw 
  REAL(r_sngl) :: y4(ny+np) !ny+np YSaw
  REAL(r_sngl) :: x4(nx)
  REAL(r_sngl) :: x4_long(nx+np) ! YSaw
  REAL(r_size) :: xnature(nx,nt)
  REAL(r_size) :: xa(nx+np,nbv,nwindow) !nx+np: augmented state vector Y.Saw
  REAL(r_size) :: xf(nx+np,nbv,nwindow) !nx+np: augmented state vector Y.Saw
  REAL(r_size) :: dxf(nx+np,nbv,nwindow)!nx+np: augmented state vector Y.Saw
  REAL(r_size) :: xm(nx+np,nwindow)     !nx+np: augmented state vector Y.Saw
  REAL(r_size) :: y(ny+np,nwindow) !ny+np YSaw
  REAL(r_size) :: d(ny+np,nwindow) !ny+np YSaw
  REAL(r_size) :: h4d(ny+np,nx+np,nwindow) !nx+np: augmented state vector & ny+np YSaw
  REAL(r_size) :: hxf(ny+np,nbv,nwindow) !ny+np YSaw
  REAL(r_size) :: hxfm(ny+np,nwindow) !ny+np YSaw
  REAL(r_size) :: hdxf(ny+np,nbv,nwindow) !ny+np YSaw
  REAL(r_size) :: rdiag_loc((ny+np)*nwindow) !ny+np YSaw
  REAL(r_size) :: rloc_loc((ny+np)*nwindow) !ny+np YSaw
  REAL(r_size) :: d_loc((ny+np)*nwindow) !ny+np YSaw
  REAL(r_size) :: hdxf_loc((ny+np)*nwindow,nbv) !ny+np YSaw
  REAL(r_size) :: trans(nbv,nbv)
  REAL(r_size) :: dist,tdif
  REAL(r_size) :: rmse_t(nt),sprd_t(nt),infl_t(nt)
  REAL(r_size) :: rmse_x(nx),sprd_x(nx),infl_x(nx)
  REAL(r_size) :: rmseave,sprdave,inflave
  REAL(r_size) :: obsloc(3),wa,wb
  INTEGER :: jobs(ny+np,nwindow) ! element of the augmented vector each observation maps to
  INTEGER :: irmse
  INTEGER :: ktoneday
  INTEGER :: ktcyc
  INTEGER :: i,j,n,nn,it
  INTEGER :: ix
  INTEGER :: ixloc,jloc ! YSaw
  INTEGER :: ny_loc
  INTEGER :: nbv2
  INTEGER :: ierr, petot, my_rank, sqpetot !YSaw 20230302
  INTEGER :: irow, icol ! this rank's position in the inflation search grid
  CHARACTER(10) :: initfile='init00.dat'
  CHARACTER(10) :: my_rank_name ! YSaw 20230302
  LOGICAL,PARAMETER :: valloc = .TRUE.

!-----------------------------------------------------------------------
! call mpi
!-----------------------------------------------------------------------
  CALL MPI_Init(ierr)
  CALL MPI_COMM_size(MPI_COMM_WORLD,petot,ierr)
  CALL MPI_COMM_rank(MPI_COMM_WORLD,my_rank,ierr)
!-----------------------------------------------------------------------
! model parameters
!-----------------------------------------------------------------------
  ! One grid point of the sqpetot x sqpetot inflation search per rank:
  ! the rank count must be a perfect square or some ranks have no cell.
  sqpetot = INT(SQRT(REAL(petot)))
  IF(sqpetot*sqpetot /= petot) THEN
    IF(my_rank == 0) PRINT '(A,I8,A)', 'ERROR: number of MPI ranks (', petot, &
      & ') must be a perfect square for the inflation parameter search'
    CALL MPI_Abort(MPI_COMM_WORLD,1,ierr)
    STOP 1
  END IF
  ALLOCATE(parabox(sqpetot,sqpetot,2))
  DO j = 1, sqpetot
    DO i = 1, sqpetot
      parabox(i,j,1) = msw_infl_min + (i-1) * (msw_infl_max - msw_infl_min)/sqpetot !state inflation
      parabox(i,j,2) = msw_infl_para_min + (j-1) * (msw_infl_para_max - msw_infl_para_min)/sqpetot ! para inflation
    END DO
  END DO
  ! Rank -> grid point, state-inflation index varying fastest:
  ! rank 0 -> (1,1), rank 1 -> (2,1), ..., rank sqpetot -> (1,2), ...
  ! (the former mod(my_rank+1,sqpetot) form produced index 0 and sqpetot+1)
  irow = MOD(my_rank,sqpetot) + 1
  icol = my_rank/sqpetot + 1
  msw_infl = parabox(irow,icol,1)
  msw_infl_para = parabox(irow,icol,2)
  PRINT*, 'inflation factors',  msw_infl, msw_infl_para, my_rank
  IF (AdaptiveInfl) THEN
          msw_infl = -1.0d0 ! <0 means adaptive inflation
  ENDIF
  dt=0.0005d0
  force=13.09d0 !9.527d0  (double literal; bare 13.09 is a default-real constant)
  oneday=0.2d0
  ktoneday = INT(oneday/dt)
  ktcyc = ktoneday/4
  ! Localization cut-off for state observations: 2*sqrt(10/3) scale lengths
  ! (presumably the Gaspari-Cohn zero-distance convention -- confirm).
  xmaxloc_state = xlocal * 2.0d0 * SQRT(10.0d0/3.0d0)
  xmaxloc = xmaxloc_state
  nbv2 = CEILING(REAL(nbv)/2.0)
  PRINT '(A)'     ,'==========LETKF settings=========='
  PRINT '(A,I8)'  ,' nbv       : ',nbv
  PRINT '(A,I8)'  ,' ny        : ',ny
  PRINT '(A,I8)'  ,' nwindow   : ',nwindow
  PRINT '(A,I8)'  ,' msw_local : ',msw_local
  PRINT '(A,F8.1)',' xlocal    : ',xlocal
  PRINT '(A,F8.1)',' tlocal    : ',tlocal
  PRINT '(A,F8.1)',' sa        : ',sa
  PRINT '(A,F8.1)',' sb        : ',sb
  PRINT '(A,F8.2)',' msw_infl  : ',msw_infl
  PRINT '(A)'     ,'=================================='
!-----------------------------------------------------------------------
! nature
!-----------------------------------------------------------------------
  OPEN(10,FILE='naturex.dat',FORM='unformatted')
  DO i=1,nt
    READ(10) x4
    xnature(:,i) = REAL(x4,r_size)
  END DO
  CLOSE(10)
!-----------------------------------------------------------------------
! initial conditions 'initXX.dat'
!-----------------------------------------------------------------------
! State variables
  DO i=1,nbv
    WRITE(initfile(5:6),'(I2.2)') i-1
    OPEN(10,FILE=initfile,FORM='unformatted')
    READ(10) xf(1:nx,i,1) ! YSaw
    CLOSE(10)
  END DO
! Parameters: each member drawn as 1 + 29*u, with u from com_rand
  DO i=1,nbv
    !CALL com_randn(np,xf(nx+1:nx+np,i,1)) ! fixed parameters
    CALL com_rand(np,xf(nx+1:nx+np,i,1)) ! fixed parameters
    !xf(nx+1:nx+np,i,1) = force + xf(nx+1:nx+np,i,1) * 2.87d0
    xf(nx+1:nx+np,i,1) = 1.0d0 + xf(nx+1:nx+np,i,1) * 29.0d0 !drawn from uniform distribution
    !PRINT*, 'at ', i, 'para = ', xf(nx+1:nx+np,i,1)
  ENDDO 
!-----------------------------------------------------------------------
! main
!-----------------------------------------------------------------------
  irmse = 0
  rmse_t = 0.0d0
  rmse_x = 0.0d0
  sprd_t = 0.0d0
  sprd_x = 0.0d0
  infl_t = 0.0d0
  infl_x = 0.0d0
  parm_infl(1:nx,1) = ABS(msw_infl)
  parm_infl(nx+1:nx+np,1) = ABS(msw_infl_para) ! different inflation factor for parameter Y.Saw
  !
  ! input files
  !
  OPEN(11,FILE='obs_pseudo_obs.dat',FORM='unformatted')
  !
  ! output files
  !
  WRITE(my_rank_name,'(i4.4)') my_rank
  OPEN(90,FILE='guesmean'//TRIM(my_rank_name)//'.dat',FORM='unformatted')
  OPEN(91,FILE='analmean'//TRIM(my_rank_name)//'.dat',FORM='unformatted')
  OPEN(92,FILE='gues'//TRIM(my_rank_name)//'.dat',FORM='unformatted')
  OPEN(93,FILE='anal'//TRIM(my_rank_name)//'.dat',FORM='unformatted')
  !>>>
  !>>> LOOP START
  !>>>
  it=1
  DO
    !PRINT *, 'now at ', it, my_rank
    !
    ! read obs
    !
    DO i=1,nwindow
      READ(11) y4
      y(:,i) = REAL(y4,r_size)
    END DO
    !
    ! 4d first guess
    !
    IF(nwindow > 1) THEN
      DO i=2,nwindow
        DO j=1,nbv
          CALL tinteg_rk4_varyingF(ktcyc,xf(1:nx,j,i-1),xf(nx+1:nx+np,j,i-1),xf(1:nx,j,i))
          xf(nx+1:nx+np,j,i) = xf(nx+1:nx+np,j,i-1) ! persistent model YSaw
        END DO
      END DO
    END IF
    !
    ! ensemble mean -> xm
    !
    DO j=1,nwindow
      DO i=1,nx + np ! YSaw
        CALL com_mean(nbv,xf(i,:,j),xm(i,j))
      END DO
    END DO
    !
    ! ensemble ptb -> dxf
    !
    DO j=1,nwindow
      DO i=1,nbv
        dxf(:,i,j) = xf(:,i,j) - xm(:,j)
        !PRINT*, "dxf = ", dxf(:,i,j), i
      END DO
    END DO
    !
    ! output first guess
    !
    IF(msw_detailout) THEN
      DO j=1,nwindow
        x4_long = xm(:,j)
        WRITE(90) x4_long
    !    DO i=1,nbv
    !      x4_long = xf(:,i,j)
    !      WRITE(92) x4_long
    !    END DO
      END DO
    END IF
    !---------------
    ! analysis step
    !---------------
    !
    ! hxf = H xf
    !
    ! State
    DO n=1,nwindow
      DO j=1,nbv
        CALL set_h(xf(:,j,n)) ! this xf is actually not used in set_h
        h4d(:,:,n) = h ! h has the dimension of nx+np in obs_hoope
        !h4d(:,nx+1:nx+np,n) = 0.0d0 ! Parameters Y.Saw
        hxf(:,j,n) = h4d(:,1,n) * xf(1,j,n)
        DO i=2,nx+np !YSaw
          hxf(:,j,n) = hxf(:,j,n) +  h4d(:,i,n) * xf(i,j,n)
        END DO
      END DO
    END DO
    !
    ! Location of each observation: first column holding the row maximum
    ! of H (state element if <= nx, parameter element otherwise).
    ! Computed once per window instead of once per analysed element.
    !
    DO n=1,nwindow
      DO i=1,ny+np
        jobs(i,n) = MAXLOC(h4d(i,:,n),DIM=1)
      END DO
    END DO
    !
    ! hxfm = mean(H xf)
    !
    DO n=1,nwindow
      DO i=1,ny+np !ny+np YSaw
        CALL com_mean(nbv,hxf(i,:,n),hxfm(i,n))
      END DO
    END DO
    !
    ! d = y - hxfm
    !
    d = y - hxfm
    !
    ! hdxf
    !
    DO n=1,nwindow
      DO i=1,nbv
        hdxf(:,i,n) = hxf(:,i,n) - hxfm(:,n)
      END DO
    END DO
    !
    ! LETKF
    !
    DO nn=1,nwindow
      DO ix=1,nx+np ! YSaw
        ny_loc = 0
        parm = parm_infl(ix,it+nn-1)
        ! position of the analysed element inside its own (state or parameter) space
        IF(ix > nx) THEN
           ixloc = ix-nx
        ELSE
           ixloc = ix
        ENDIF
        DO n=1,nwindow
          tdif = REAL(ABS(nn-n),r_size)
          IF(tlocal < 0.0d0) tdif = 0.0d0
          DO i=1,ny+np
            j = jobs(i,n) ! Y.Saw state & parameters are observed
            ! model vs parameter spaces
            IF(j > nx) THEN
               jloc = j-nx
            ELSE
               jloc = j
            ENDIF
            !dist = REAL(MIN(ABS(j-ix),nx-ABS(j-ix)),r_size)
            ! cyclic distance on the nx-point ring
            dist = REAL(MIN(ABS(jloc-ixloc),nx-ABS(jloc-ixloc)),r_size)
            IF (j <= nx) THEN
              xmaxloc = xmaxloc_state ! State
            ELSE
              xmaxloc = xmaxloc_para  ! parameter
            END IF
            IF (valloc) THEN ! if true, parameter observation is not assimilated into state space
                    IF(ix <= nx .and. j > nx) dist = xmaxloc + 1.0d0
            ENDIF
            !IF (j > nx .and. mod(it,100) .ne.0) dist = xmaxloc + 1.0d0
            !IF (mod(it,2) .eq. 0) dist = xmaxloc + 1.0d0 ! no da
            !PRINT*, 'observation used ', dist, jloc, ixloc, j, ix, xmaxloc, parm
            IF(dist < xmaxloc) THEN
              ny_loc = ny_loc+1
              d_loc(ny_loc) = d(i,n)
              rdiag_loc(ny_loc) = obserr**2
              IF (j > nx) rdiag_loc(ny_loc) = obserr_para**2 ! parameter observation
              IF(msw_local == 0) THEN ! fixed localization
                rloc_loc(ny_loc) = EXP(-0.5d0 * (dist/xlocal)**2) &! space
                  & * EXP(-0.5d0 * (tdif/tlocal)**2) ! time
              ELSE ! adaptive localization
                CALL com_correl(nbv2,hdxf(i,1:nbv2,n),dxf(ix,1:nbv2,nn),obsloc(1))
                CALL com_correl(nbv-nbv2,hdxf(i,nbv2+1:nbv,n),dxf(ix,nbv2+1:nbv,nn),obsloc(2))
                CALL com_correl(nbv,hdxf(i,:,n),dxf(ix,:,nn),obsloc(3))
                wb = ABS(obsloc(3))
                wa = 1.0d0 - (0.5d0*ABS(obsloc(1)-obsloc(2)))
                rloc_loc(ny_loc) = wa**sa * wb**sb
                IF(msw_local == 2) THEN
                  rloc_loc(ny_loc) = rloc_loc(ny_loc) &
                  & * EXP(-0.5d0 * (dist/xlocal)**2) &! space
                  & * EXP(-0.5d0 * (tdif/tlocal)**2)  ! time
                END IF
              END IF
              hdxf_loc(ny_loc,:) = hdxf(i,:,n)
              ! drop observations whose localization weight is negligible
              IF(rloc_loc(ny_loc) < 0.0001d0) ny_loc = ny_loc-1
            END IF
          END DO
        END DO
        CALL letkf_core(nbv,(ny+np)*nwindow,ny_loc,hdxf_loc,rdiag_loc,rloc_loc,d_loc,parm,trans) !ny-->ny+np YSaw
        IF(msw_infl > 0.0d0) THEN
                IF (ix > nx) THEN ! Different inflation factor for parameter and state
                        parm = msw_infl_para
                ELSE
                        parm = msw_infl
                ENDIF
        ENDIF  
        DO j=1,nbv
          xa(ix,j,nn) = xm(ix,nn)
          DO i=1,nbv
            xa(ix,j,nn) = xa(ix,j,nn) + dxf(ix,i,nn) * trans(i,j)
          END DO
        END DO
        !IF (it < 100) THEN ! spinup No data assimilation
        !  DO j = 1, nbv
        !    xa(ix,j,nn) = xf(ix,j,nn)
        !  ENDDO
        !ENDIF
        !PRINT*, 'at ', ix, 'trans = ', trans
        ! Carry the inflation to the next cycle; on the final cycle there is
        ! no column it+nn inside parm_infl(:,1:nt).
        IF(it+nn <= nt) parm_infl(ix,it+nn) = parm
      END DO
    END DO
    !
    ! ensemble mean
    !
    DO n=1,nwindow
      DO i=1,nx+np
        CALL com_mean(nbv,xa(i,:,n),xm(i,n))
      END DO
    END DO
    !
    ! output analysis
    !
    IF(msw_detailout) THEN
      DO n=1,nwindow
        x4_long = xm(:,n)
        WRITE(91) x4_long
        IF(my_rank == irank_ensout) THEN
          DO i=1,nbv
            x4_long = xa(:,i,n)
            WRITE(93) x4_long
          END DO
        ENDIF
      END DO
    END IF
    !
    ! RMSE,SPRD (state part only)
    !
    DO n=1,nwindow
      rmse_t(it+n-1) = SQRT(SUM((xm(1:nx,n)-xnature(:,it+n-1))**2)/REAL(nx,r_size))
      DO i=1,nx
        sprd_t(it+n-1) = sprd_t(it+n-1) + SUM((xa(i,:,n)-xm(i,n))**2)
      END DO
      sprd_t(it+n-1) = SQRT(sprd_t(it+n-1)/REAL(nx*nbv,r_size))
      ! mean inflation over the nx state elements (parameter rows excluded)
      infl_t(it+n-1) = SUM(parm_infl(1:nx,it+n-1))/REAL(nx,r_size)
      IF(it > nspinup) THEN
        DO i=1,nx
          rmse_x(i) = rmse_x(i) + (xm(i,n)-xnature(i,it+n-1))**2
          sprd_x(i) = sprd_x(i) + SUM((xa(i,:,n)-xm(i,n))**2)/REAL(nbv,r_size)
          infl_x(i) = infl_x(i) + parm_infl(i,it+n-1)
        END DO
        irmse = irmse + 1
      END IF
    END DO
    !---------------
    ! forecast step
    !---------------
    DO i=1,nbv
      CALL tinteg_rk4_varyingF(ktcyc,xa(1:nx,i,nwindow),xa(nx+1:nx+np,i,nwindow),xf(1:nx,i,1))
      xf(nx+1:nx+np,i,1) = xa(nx+1:nx+np,i,nwindow) ! persistent for parameter Y.Saw
    END DO
    it = it+nwindow
    IF(it > nt) EXIT
  END DO
  !<<<
  !<<< LOOP END
  !<<<
  CLOSE(11)
  CLOSE(90)
  CLOSE(91)
  CLOSE(92)
  CLOSE(93)
  !OPEN(10,FILE='infl.dat',FORM='unformatted')
  !DO i=1,nt
  !  x4_long = REAL(parm_infl(:,i),r_sngl) ! Y.Saw
  !  WRITE(10) x4_long
  !END DO
  !CLOSE(10)

  !OPEN(10,FILE='rmse_t.dat',FORM='formatted')
  !DO i=1,nt
  !  WRITE(10,'(3F12.4)') REAL(i-1)/4.0,rmse_t(i),sprd_t(i)
  !END DO
  !CLOSE(10)

  ! guard against no post-spin-up cycles (irmse == 0)
  IF(irmse > 0) THEN
    rmse_x = SQRT(rmse_x / REAL(irmse,r_size))
    sprd_x = SQRT(sprd_x / REAL(irmse,r_size))
  END IF
  !OPEN(10,FILE='rmse_x.dat',FORM='formatted')
  !DO i=1,nx
  !  WRITE(10,'(I4,2F12.4)') i,rmse_x(i),sprd_x(i)
  !END DO
  !CLOSE(10)

  rmseave = SUM(rmse_t(nspinup+1:nt))/REAL(nt-nspinup,r_size)
  sprdave = SUM(sprd_t(nspinup+1:nt))/REAL(nt-nspinup,r_size)
  inflave = SUM(infl_t(nspinup+1:nt))/REAL(nt-nspinup,r_size)
  PRINT '(A,F12.5)','RMSE = ',rmseave
  PRINT '(A,F12.5)','SPRD = ',sprdave
  PRINT '(A,F12.5)','INFL = ',inflave

  PRINT *,'end at ', my_rank
  CALL MPI_Finalize(ierr)
  STOP
END PROGRAM letkf
