#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Apr  6 15:24:35 2023

@author: shenzheqi
"""
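# Assimilation starts in 2005; the first year is dropped when evaluating the analysis fields.
# EN4 latitudes start at -83, so its latitude index is offset by 7; TEMP lacks the last 3 months of data.
# Compute an 11-year RMSE.
# NOTE(review): the time slices [36:144] below select 108 records (9 years if monthly);
# confirm this matches the intended evaluation window.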
#%% Read temperature / salinity data
from netCDF4 import Dataset
import numpy as np

# Every field below keeps time records 36:144 (108 records; presumably monthly,
# given the 05-17 file names -- confirm).
# SE and PE are the two experiments compared throughout (labelled 'SE'/'PE' in the figures).
SEdata = Dataset("./Temp_05-17.nc")
TempSE = SEdata.variables['temp'][36:144]
SESdata = Dataset("./Salt_05-17.nc")
SaltSE = SESdata.variables['salt'][36:144]
# The path used to be ".PE_Temp_05-17.nc" (missing "/"), unlike every other file here.
PEdata = Dataset("./PE_Temp_05-17.nc")
TempPE = PEdata.variables['temp'][36:144]
PESdata = Dataset("./PE_Salt_05-17.nc")
SaltPE = PESdata.variables['salt'][36:144]
# Grid coordinates, taken from the SE temperature file.
lat = SEdata.variables['lat'][:]
lon = SEdata.variables['lon'][:]
z_t = SEdata.variables['z_t'][:]
# Reference datasets: EN4, ORAS5, ECDA and GODAS.
ENdata = Dataset("./EN4_05-20.nc")
# Only the first 31 levels are kept. NOTE(review): EN4 temperature is read from
# the variable named 'sst'; presumably full-depth temperature -- confirm.
TempEN = ENdata.variables['sst'][36:144,0:31,:,:]
SaltEN = ENdata.variables['salt'][36:144,0:31,:,:]
ORAdata = Dataset("./OtherReanalysis/oras5_t05-18_v.nc")
TempORA = ORAdata.variables['temp'][36:144]
ORASdata = Dataset("./OtherReanalysis/oras5_s05-18_v.nc")
SaltORA = ORASdata.variables['salt'][36:144]
ECdata = Dataset("./OtherReanalysis/ECDA_TS05-20_grided.nc")
TempEC = ECdata.variables['temp'][36:144]
SaltEC = ECdata.variables['salt'][36:144]
GODdata = Dataset("./OtherReanalysis/GODAS_TS05-20_grided.nc")
TempGOD = GODdata.variables['temp'][36:144]
# Presumably converts GODAS salinity from kg/kg to psu -- confirm the file units.
SaltGOD = GODdata.variables['salt'][36:144]*1000
# NOTE(review): clears the mask wherever the raw temperature is below 10000, i.e.
# every value that is not a large fill value is treated as valid -- confirm intent.
TempGOD.mask[TempGOD.data<10000] = False
#%%  range for reference
# Region bounds on the full model grid as [lat_start, lat_end, lon_start, lon_end]
# index ranges (Troprange gives latitude bounds only).
POrange =  [23,156,150,285]
IOrange = [23,120,30,135]
AOrange = [23,156,80,200]  # with roll 180
Troprange = [60,120]
#%% tempRMSE / saltRMSE
# Time-mean RMSE of SE and PE against each reference dataset, indexed
# [reference, level, lat, lon] with references ordered EN4, ORAS5, ECDA, GODAS.
# Only full-grid latitude rows 23:156 are kept (lat: -66.5:66.5, lon: 0:360).
# EN4's latitude axis starts further south (-83), so the same rows are 16:149 there.
LAT_WIN = slice(23, 156)
LAT_WIN_EN = slice(16, 149)


def _time_rmse(model, ref):
    """Return the RMSE between two fields along axis 0 (time), ignoring NaNs."""
    return np.sqrt(np.nanmean((model - ref) ** 2, axis=0))


# Plain slices for every reference (EN4 was previously fancy-indexed with
# range(16,149) for SE but sliced for PE; values are the same, slices avoid a copy).
tempRefs = [TempEN[:, :, LAT_WIN_EN], TempORA[:, :, LAT_WIN],
            TempEC[:, :, LAT_WIN], TempGOD[:, :, LAT_WIN]]
saltRefs = [SaltEN[:, :, LAT_WIN_EN], SaltORA[:, :, LAT_WIN],
            SaltEC[:, :, LAT_WIN], SaltGOD[:, :, LAT_WIN]]

tempRMSE = np.zeros([4,31,133,360])
tempRMSEP = np.zeros([4,31,133,360])
saltRMSE = np.zeros([4,31,133,360])
saltRMSEP = np.zeros([4,31,133,360])
for k, (tref, sref) in enumerate(zip(tempRefs, saltRefs)):
    tempRMSE[k] = _time_rmse(TempSE[:, :, LAT_WIN], tref)
    tempRMSEP[k] = _time_rmse(TempPE[:, :, LAT_WIN], tref)
    saltRMSE[k] = _time_rmse(SaltSE[:, :, LAT_WIN], sref)
    saltRMSEP[k] = _time_rmse(SaltPE[:, :, LAT_WIN], sref)
#%% spatial RMSE
def comp_meanRMSE(tempRMSE):
    """Reduce an RMSE field to one spatially averaged RMSE per leading index.

    Parameters
    ----------
    tempRMSE : array_like, shape (N, ...)
        RMSE values. Axis 0 is kept (each caller here passes [level, lat, lon]);
        every remaining axis is averaged over. NaNs are ignored.

    Returns
    -------
    numpy.ndarray, shape (N,)
        ``sqrt(nanmean(tempRMSE[n] ** 2))`` for each ``n``, as float64.

    Raises
    ------
    ValueError
        If ``tempRMSE`` has fewer than two dimensions.
    """
    field = np.asanyarray(tempRMSE)
    if field.ndim < 2:
        raise ValueError("comp_meanRMSE expects an array with at least 2 dimensions")
    # Pool the squared values over every axis except the first, then take the
    # root: the RMS of the gridded RMSE values at each level.
    spatial_axes = tuple(range(1, field.ndim))
    return np.sqrt(np.nanmean(field ** 2, axis=spatial_axes)).astype(np.float64, copy=False)
#% Regional, horizontally averaged RMSE profiles (temperature and salinity)
def _regional_profiles(rmse_fields):
    """Average a set of RMSE fields over 5 regions, level by level.

    ``rmse_fields`` is indexed [reference, level, lat, lon] like ``tempRMSE``,
    whose latitude rows start at full-grid row 23. The result is indexed
    [region, reference, level]; regions are, in order: the whole RMSE domain,
    POrange, IOrange, AOrange (on the longitude axis rolled by 180 points) and
    Troprange (all longitudes).
    """
    lat0 = 23  # full-grid row of the first latitude stored in the RMSE fields

    def rows(rng):
        # Convert full-grid [lat_start, lat_end, ...] bounds to cropped-grid rows.
        return slice(rng[0] - lat0, rng[1] - lat0)

    n_ref, n_lev = rmse_fields.shape[:2]
    profiles = np.zeros([5, n_ref, n_lev])
    for j, field in enumerate(rmse_fields):
        # AOrange longitudes refer to the grid rolled by 180 points.
        rolled = np.roll(field, 180, axis=2)
        profiles[0, j] = comp_meanRMSE(field)
        profiles[1, j] = comp_meanRMSE(field[:, rows(POrange), POrange[2]:POrange[3]])
        # IOrange ends at full-grid row 120, i.e. cropped row 97. This used to be
        # the step slice ``::97``, which kept only rows 0 and 97.
        profiles[2, j] = comp_meanRMSE(field[:, rows(IOrange), IOrange[2]:IOrange[3]])
        profiles[3, j] = comp_meanRMSE(rolled[:, rows(AOrange), AOrange[2]:AOrange[3]])
        profiles[4, j] = comp_meanRMSE(field[:, rows(Troprange)])
    return profiles


# [region, reference, level] profiles: SE/PE temperature, then SE/PE salinity.
RegMRMSE = _regional_profiles(tempRMSE)
RegMRMSEP = _regional_profiles(tempRMSEP)
RegMRMSE_s = _regional_profiles(saltRMSE)
RegMRMSEP_s = _regional_profiles(saltRMSEP)

#%% Regional RMSE profiles: temperature (figure 6) and salinity (figure 7)
import matplotlib.pyplot as plt


def _plot_regional_profiles(reg_se, reg_pe, pe_style, xmax, xticks):
    """Plot SE vs PE RMSE-against-depth profiles on a 3x5 grid of panels.

    ``reg_se``/``reg_pe`` are indexed [region, reference, level]. Columns show
    the 5 regions; rows show the first 3 references (EN4, ORAS5, ECDA).
    ``pe_style`` is the matplotlib format string for the PE curve; every panel
    uses x-limits (0, xmax) with ticks at ``xticks``. Returns the figure.
    """
    depth_ticks = [100, 500, 1000, 1500]
    fig = plt.figure(figsize=(12, 8))
    for j in range(5):          # region -> column
        for i in range(3):      # reference -> row
            plt.subplot(3, 5, 1 + j + 5 * i)
            plt.plot(reg_se[j, i, :], z_t, 'k--', lw=2, label='SE')
            plt.plot(reg_pe[j, i, :], z_t, pe_style, lw=2, label='PE')
            # NOTE(review): setting ticks outside the current limits expands the
            # view, so the 500-1500 m ticks below widen this 0-200 m range --
            # confirm which depth range is intended.
            plt.ylim(0, 200)
            # Depth labels only on the left column.
            if j == 0:
                plt.yticks(depth_ticks, fontsize=14)
            else:
                plt.yticks(depth_ticks, [])
            if i == 0 and j == 4:
                plt.legend(fontsize=12)
            plt.xlim(0, xmax)
            # RMSE labels only on the bottom row.
            if i == 2:
                plt.xticks(xticks, fontsize=14)
            else:
                plt.xticks(xticks, [])
            plt.grid(axis='y', linestyle='dotted')
            plt.gca().invert_yaxis()
    # tight_layout only adjusts axes that already exist, so run it after drawing.
    plt.tight_layout()
    return fig


_plot_regional_profiles(RegMRMSE, RegMRMSEP, 'r', 1, [0.4, 0.8])
# plt.savefig('figure6.eps')
#%% Regional salinity RMSE profiles
# Same ticks in every column (column 4's unlabeled ticks were the duplicate [0.4, 0.4]).
_plot_regional_profiles(RegMRMSE_s, RegMRMSEP_s, 'b', 0.5, [0.2, 0.4])
# plt.savefig('figure7.eps')
#%% Tropical temperature/salinity RMSE and SE-PE differences (figure 8)
import pandas as pd
from matplotlib.colors import ListedColormap 
from cartopy.mpl.ticker import LongitudeFormatter, LatitudeFormatter
# Colour tables: whitespace-separated r g b columns (0-255) after 2 header lines.
# Raw strings avoid the invalid "\s" escape warning.
rgb1 = pd.read_csv('WhiteBlueGreenYellowRedgood.rgb',sep=r'\s+',skiprows=2,names=['r','g','b']).values/255
# Map the rgb table to a colormap; masked cells are drawn grey.
colormap1 = ListedColormap(rgb1)
colormap1.set_bad(color='grey')
rgb2 = pd.read_csv('cmocean_deep.rgb',sep=r'\s+',skiprows=2,names=['r','g','b']).values/255
colormap2 = ListedColormap(rgb2)
colormap2.set_bad(color='grey')

# RMS over rows 61:72 of the cropped RMSE fields (full-grid rows 84:95; about
# 5.5S-4.5N if the grid is 1-degree -- confirm), against EN4 (reference 0).
# Result is indexed [level, lon].
tropRMSE = np.sqrt(np.nanmean(tempRMSE[0,:,61:72]**2,axis=1))
tropRMSEP = np.sqrt(np.nanmean(tempRMSEP[0,:,61:72]**2,axis=1))
# Mask cells whose SE temperature RMSE is exactly 0 (presumably no data -- confirm);
# the same mask is applied to salinity.
mask0 = np.zeros_like(tropRMSE)
mask0[tropRMSE==0]=1
tropRMSE = np.ma.array(tropRMSE,mask=mask0)
tropRMSEP = np.ma.array(tropRMSEP,mask=mask0)

tropRMSE_s = np.sqrt(np.nanmean(saltRMSE[0,:,61:72]**2,axis=1))
tropRMSEP_s = np.sqrt(np.nanmean(saltRMSEP[0,:,61:72]**2,axis=1))
tropRMSE_s = np.ma.array(tropRMSE_s,mask=mask0)
tropRMSEP_s = np.ma.array(tropRMSEP_s,mask=mask0)

LLon, Dd = np.meshgrid(lon,z_t)
# (field, contour levels, colormap, title, colorbar units) for each panel, top to bottom.
trop_panels = [
    (tropRMSE, np.linspace(0,3,16), colormap1, 'RMSE of temperature: SE', 'degC'),
    (tropRMSE-tropRMSEP, np.linspace(0,0.15,16), plt.cm.Reds, 'RMSE of temperature: SE-PE', 'degC'),
    (tropRMSE_s, np.linspace(0,1.5,16), colormap2, 'RMSE of salinity: SE', 'psu'),
    (tropRMSE_s-tropRMSEP_s, np.linspace(0,0.10,11), plt.cm.Reds, 'RMSE of salinity: SE-PE', 'psu'),
]
plt.figure(figsize=(12,8))
for k, (field, levels, cmap, title, units) in enumerate(trop_panels):
    ax = plt.subplot(4, 1, k + 1)
    ax.set_facecolor("gainsboro")
    plt.contourf(LLon,Dd,field,levels=levels,extend='both',cmap=cmap)
    plt.yscale('log')
    plt.yticks([100,1000],[100,1000],fontsize=14)
    plt.ylabel('Depth',fontsize=14)
    plt.title(title,fontsize=14)
    if k < len(trop_panels) - 1:
        # Only the bottom panel carries longitude labels.
        plt.xticks(np.arange(0,361,60),[])
    else:
        ax.set_xticks([0, 60, 120, 180, 240, 300, 360])
        # set_xticks(fontsize=...) without labels is ignored/rejected by Matplotlib.
        ax.tick_params(axis='x', labelsize=14)
        lon_formatter = LongitudeFormatter(zero_direction_label=False)
        ax.xaxis.set_major_formatter(lon_formatter)
        plt.xlabel('Longitude',fontsize=14)
    plt.grid(axis='y')
    plt.colorbar(label=units)
    ax.invert_yaxis()
plt.tight_layout(h_pad=1.5)
plt.savefig('figure8.eps')
#%% RMSE sections vs. latitude over the AOrange longitudes (figure 9)
LLat, Dd2 = np.meshgrid(lat[23:156],z_t)
# Zonal RMS over longitudes 80:200 of the grid rolled by 180 points (the AOrange
# window; presumably the Atlantic -- confirm), against EN4 (reference 0).
# Results are indexed [level, lat].
tempRMSE1r = np.roll(tempRMSE[0],180,axis=2)
tempRMSEP1r = np.roll(tempRMSEP[0],180,axis=2)
atRMSE = np.sqrt(np.nanmean(tempRMSE1r[:,:,80:200]**2,axis=2))
atRMSEP = np.sqrt(np.nanmean(tempRMSEP1r[:,:,80:200]**2,axis=2))
saltRMSE1r = np.roll(saltRMSE[0],180,axis=2)
saltRMSEP1r = np.roll(saltRMSEP[0],180,axis=2)
atRMSE_s = np.sqrt(np.nanmean(saltRMSE1r[:,:,80:200]**2,axis=2))
atRMSEP_s = np.sqrt(np.nanmean(saltRMSEP1r[:,:,80:200]**2,axis=2))

# (field, contour levels, colormap, title, colorbar units) for each panel, top to bottom.
at_panels = [
    (atRMSE, np.linspace(0,2,21), colormap1, 'RMSE of temperature: SE', 'degC'),
    (atRMSE-atRMSEP, np.linspace(0,0.3,16), plt.cm.Reds, 'RMSE of temperature: SE-PE', 'degC'),
    (atRMSE_s, np.linspace(0,1,11), colormap2, 'RMSE of salinity: SE', 'psu'),
    (atRMSE_s-atRMSEP_s, np.linspace(0,0.1,11), plt.cm.Reds, 'RMSE of salinity: SE-PE', 'psu'),
]
# Shared by every panel (the bottom panel's old -60 tick lay outside xlim anyway).
lat_ticks = np.arange(-30,61,30)
plt.figure(figsize=(12,8))
for k, (field, levels, cmap, title, units) in enumerate(at_panels):
    ax = plt.subplot(4, 1, k + 1)
    plt.contourf(LLat,Dd2,field,levels=levels,extend='both',cmap=cmap)
    plt.yticks([100,1000,2000],[100,1000,2000],fontsize=14)
    plt.ylabel('Depth',fontsize=14)
    plt.title(title,fontsize=14)
    if k < len(at_panels) - 1:
        # Only the bottom panel carries latitude labels.
        plt.xticks(lat_ticks,[])
    else:
        lat_formatter = LatitudeFormatter()
        ax.xaxis.set_major_formatter(lat_formatter)
        ax.set_xticks(lat_ticks)
        # set_xticks(fontsize=...) without labels is ignored/rejected by Matplotlib.
        ax.tick_params(axis='x', labelsize=16)
        plt.xlabel('Latitude',fontsize=14)
    plt.xlim(-31,60)
    plt.grid(axis='y')
    plt.colorbar(label=units)
    ax.invert_yaxis()
plt.tight_layout(h_pad=1.5)
# plt.savefig('figure9.eps')