#
# Analyzing F and state variables in two-scale Lorenz96
# Mimicking Pathiraja and van Leeuwen 2022 JAMES
# Y.Sawada 20230303
#

import os
from pylab import *
import numpy as np
import matplotlib.pyplot as plt
import numpy.ma as ma
import struct
from mpl_toolkits.mplot3d import Axes3D
from scipy.io import FortranFile
from sklearn.metrics import mean_squared_error

#
# configurations
#
# Record lengths and run sizes
xx = 18        # values per analysis record (columns 0:9 are scored against the nature run, 9:18 against the forcing)
xx_nat = 9     # values per nature-run / forcing record
tt = 14399     # number of time steps read from every file (it should be 14400)
nn = 1         # number of settings to read; 961 in the full sweep

# Time-series buffers, filled only when NeedOriginalData is True
dataxa = np.zeros(shape=(tt, xx, nn))
datanat = np.zeros(shape=(tt, xx_nat))
datafor = np.zeros_like(datanat)

# One score per setting
RmseState = np.zeros(nn)
RmsePara = np.zeros_like(RmseState)
CorrState = np.zeros_like(RmseState)
CorrPara = np.zeros_like(RmseState)

# Path prefix of the per-setting files; a 4-digit setting number and '.dat'
# are appended when reading.  Keep exactly one line active.
#BaseFileName = '../DATA/regular4/mpi_letkf2/guesmean'
#BaseFileName = '../DATA/regular4_hoope_AR_obserr10/mpi_hoope2_ens40/guesmean'
#BaseFileName = '../DATA/regular4/adaptive_letkf2_ens10/guesmean'
#BaseFileName = '../DATA/regular4_pseudo_obs/hoope_pso_ens10/analmean'
#BaseFileName = '../DATA/regular4_pseudo_obs/hoope_pso_adaptive_ens40/analmean'
BaseFileName = '../DATA/regular4/hoope_rtc_ens40/analmean'
#BaseFileName = '../DATA/regular4/letkf_ens10/analmean'
#BaseFileName = '../DATA/regular4/letkf_adaptive_ens40/analmean'
#BaseFileName = '../DATA/regular4/hoope_rtc_adaptive_ens40/analmean'

# True : re-read the raw binary output and rewrite the cached score files
# False: skip the raw data and go straight to the cached score files
NeedOriginalData = False


#
# Reading files
#
if NeedOriginalData:
    # Read every per-setting analysis file plus the nature run and the true
    # forcing, score each setting against the truth, and cache the scores as
    # text files for the plotting stage.

    # Size in bytes of a complete analysis file: tt records, each holding xx
    # float32 values (4 bytes each) framed by a leading and a trailing 4-byte
    # record marker (scipy's default uint32 header).  With tt=14399 and xx=18
    # this is the 1151920 that used to be hard-coded here.
    ExpectedFileSize = tt * (4 * xx + 8)

    for i in range(nn):
        print('reading file no. ', i)
        # setting number is zero-padded to (at least) four digits
        FileName = BaseFileName + str(i).zfill(4) + '.dat'
        FileSize = os.path.getsize(FileName)
        if FileSize == ExpectedFileSize:
            # context manager closes the file even if a record fails to read
            with FortranFile(FileName, 'r') as f:
                for j in range(tt):
                    dataxa[j, :, i] = f.read_reals(np.float32)
        else:
            # Unexpected file size: fill the whole setting with a sentinel so
            # that its RMSE becomes huge.
            # NOTE(review): the plotting stage masks RMSE > 1e+6, which is
            # only marginally below/above what this sentinel produces —
            # confirm the threshold is always exceeded.
            dataxa[:, :, i] = -999999

    with FortranFile('../DATA/naturex.dat', 'r') as f:
        for j in range(tt):
            datanat[j, :] = f.read_reals(np.float32)
    with FortranFile('../DATA/force.dat', 'r') as f:
        for j in range(tt):
            datafor[j, :] = f.read_reals(np.float32)

    #
    # Calculating mean RMSE
    #
    SpinUp = 5000  # the first SpinUp time steps are excluded from every score
    TrueState = datanat[SpinUp:, :]
    TruePara = datafor[SpinUp:, :]
    for i in range(nn):
        # first xx_nat columns are compared with the nature run,
        # the remaining columns with the forcing
        EstState = dataxa[SpinUp:, 0:xx_nat, i]
        EstPara = dataxa[SpinUp:, xx_nat:xx, i]
        RmseState[i] = np.sqrt(mean_squared_error(EstState, TrueState))
        RmsePara[i] = np.sqrt(mean_squared_error(EstPara, TruePara))
        # Pearson correlation over all (time, variable) pairs pooled together
        CorrState[i] = np.corrcoef(EstState.flatten(), TrueState.flatten())[0, 1]
        CorrPara[i] = np.corrcoef(EstPara.flatten(), TruePara.flatten())[0, 1]
    # Disabled alternative kept from the original script: average of the
    # per-time-step RMSE instead of the RMSE over the whole period.
    #    for t in range(5000,tt):
    #        RmseState[i] = RmseState[i] + np.sqrt(mean_squared_error(dataxa[t,0:9,i],datanat[t,:]))
    #        RmsePara[i] = RmsePara[i] + np.sqrt(mean_squared_error(dataxa[t,9:18,i],datafor[t,:]))
    #    RmseState[i] = RmseState[i]/len(dataxa[5000:tt,0,0])
    #    RmsePara[i] = RmsePara[i]/len(dataxa[5000:tt,0,0])

    # NOTE(review): the experiment tag in these output names is maintained by
    # hand and has to be kept in step with BaseFileName; as committed the two
    # disagree (hoope_rtc_ens40 vs nohoope_adaptive_ens40) — confirm before
    # running.
    #np.savetxt('RmseState_hoope_rtc_adaptive_ens40',RmseState)
    #np.savetxt('RmsePara_hoope_rtc_adaptive_ens40',RmsePara)
    np.savetxt('RmseState_nohoope_adaptive_ens40', RmseState)
    np.savetxt('RmsePara_nohoope_adaptive_ens40', RmsePara)
    #np.savetxt('CorrState_hoope_rtc_adaptive_ens40',CorrState)
    #np.savetxt('CorrPara_hoope_rtc_adaptive_ens40',CorrPara)
    np.savetxt('CorrState_nohoope_adaptive_ens40', CorrState)
    np.savetxt('CorrPara_nohoope_adaptive_ens40', CorrPara)

#sys.exit()
# Load the cached per-setting scores and arrange each one on the 31 x 31 grid
# of settings.  To plot another experiment, swap the four file names, e.g.
# 'RmseState_nohoope_ens40', 'RmsePara_nohoope_ens40',
# 'CorrState_nohoope_ens40', 'CorrPara_nohoope_ens40'.
RmseState, RmsePara, CorrState, CorrPara = [
    np.loadtxt(ScoreFile).reshape((31, 31))
    for ScoreFile in (
        'RmseState_hoope_rtc_ens10',
        'RmsePara_hoope_rtc_ens10',
        'CorrState_hoope_rtc_ens10',
        'CorrPara_hoope_rtc_ens10',
    )
]

# Hide every setting whose RMSE exceeds 1e+6 ...
RmseState = np.ma.masked_greater(RmseState, 1e+6)
RmsePara = np.ma.masked_greater(RmsePara, 1e+6)
# ... and hide the same grid cells in the matching correlation map.
CorrState = np.ma.masked_where(np.ma.getmaskarray(RmseState), CorrState)
CorrPara = np.ma.masked_where(np.ma.getmaskarray(RmsePara), CorrPara)

print(RmseState[0,0:30])

# Tick positions and labels shared by all four maps: every third grid index
# (0, 3, ..., 30) is labelled.
xtick = np.array(["1.05","1.20","1.35","1.50","1.65","1.80","1.95","2.10","2.25","2.40","2.55"])
xlocs = np.linspace(0,30,11)
ytick = np.array(["1.05","1.65","2.25","2.85","3.45","4.05","4.65","5.25","5.85","6.45","7.05"])
ylocs = np.linspace(0,30,11)


def _plot_score_map(score, vmin, vmax, outfile):
    """Draw one score map as an image, save it to *outfile* and show it.

    score   : 2-D (masked) array; columns run along the 'state inflation'
              axis and rows along the 'parameter inflation' axis, with row 0
              drawn at the bottom (origin='lower').
    vmin, vmax : limits of the colour scale.  To auto-scale instead, pass
              np.min(score) and np.max(score).
    outfile : name of the image file to write.

    Returns the (figure, axes, image) triple that was created.
    """
    fig, ax = plt.subplots()
    aximg = ax.imshow(score, cmap='rainbow', interpolation='nearest',
                      vmax=vmax, vmin=vmin, origin='lower')
    # Explicit Axes calls instead of pyplot/pylab "current axes" helpers, so
    # the function does not depend on global figure state or on the
    # `from pylab import *` names.
    ax.set_xticks(xlocs)
    ax.set_xticklabels(xtick, fontsize=10)
    ax.set_yticks(ylocs)
    ax.set_yticklabels(ytick, fontsize=10)
    fig.colorbar(aximg, ax=ax, shrink=0.6)
    ax.set_xlabel('state inflation', fontsize=10)
    ax.set_ylabel('parameter inflation', fontsize=10)
    fig.savefig(outfile)
    plt.show()
    return fig, ax, aximg


# Output names carry the experiment tag by hand; other tags used with this
# script were 'hoope_pso_<state|parameter>_ens40_<RMSE|R>.png' and
# 'nohoope_<state|parameter>_ens40_<RMSE|R>.png'.

# RMSE of the state variables
fig, ax, aximg = _plot_score_map(RmseState, 0.40, 1.00, 'hoope_rtc_state_ens10_RMSE.png')

# RMSE of the parameters
fig, ax, aximg = _plot_score_map(RmsePara, 2.2, 3.5, 'hoope_rtc_parameter_ens10_RMSE.png')

# Correlation of the state variables
fig, ax, aximg = _plot_score_map(CorrState, 0.8, 1.0, 'hoope_rtc_state_ens10_R.png')

# Correlation of the parameters
fig, ax, aximg = _plot_score_map(CorrPara, 0.0, 0.5, 'hoope_rtc_parameter_ens10_R.png')
#plt.gca().clear()

