"""
Author: Dabih Isidori
Changes: Deborah Morgenstern and Isabell Stucke and Fiona Fix

This script is part of the preprocessing of ERA5 data to obtain the required
shape for LAGRANTO.
"""
from optparse import OptionParser
import warnings
import numpy as np
import xarray as xa
import metpy
import metpy.calc as mpcalc
import functions as func

# Get input data from command-line arguments.
# -Y/-M/-D are parsed and validated here; the visible remainder of this script
# does not read YEAR/MONTH/DAY again.
parser = OptionParser()
parser.add_option("-Y", dest="YEAR", type="str", help="year (integer)")
parser.add_option("-M", dest="MONTH", type="str", help="month (integer)")
parser.add_option("-D", dest="DAY", type="str", help="day (integer)")
parser.add_option("--INFILE", dest="INFILE", type="str",
                  help="path of the input dataset to open")
parser.add_option("--OUTFILE", dest="OUTFILE", type="str",
                  help="path of the netCDF file to write")
# The coefficient file used to be hard-coded; the default keeps that behaviour.
parser.add_option("--COEFFS", dest="COEFFS", type="str",
                  default="level_coeffs.nc",
                  help="path of the level coefficient file [default: %default]")
(options, args) = parser.parse_args()

# Fail early with a usage message instead of a TypeError from int(None) or,
# for a missing --OUTFILE, a full run that never writes a file.
required = (
    ("-Y", options.YEAR),
    ("-M", options.MONTH),
    ("-D", options.DAY),
    ("--INFILE", options.INFILE),
    ("--OUTFILE", options.OUTFILE),
)
missing = [flag for flag, value in required if not value]
if missing:
    parser.error("missing required option(s): " + ", ".join(missing))

try:
    YEAR = int(options.YEAR)
    MONTH = int(options.MONTH)
    DAY = int(options.DAY)
except ValueError:
    # parser.error prints the usage text and exits with status 2
    parser.error("-Y, -M and -D must be integers")


# Load data
filename = options.INFILE
ds = xa.open_dataset(filename)
# SPECIFY: level_coeffs file via --COEFFS (default: level_coeffs.nc)
coeffs = xa.open_dataset(options.COEFFS)

# Level_2 dummy is necessary for further preprocessing!
# 'pres' and 'geoh' use the dimension 'level_2'; swap it for 'lev' so both can
# be merged back onto the renamed dataset below.
pres = ds.pres.swap_dims({'level_2': 'lev'})
# 'level_2' -> 'level_2' renames nothing: pres keeps a 'level_2' entry, which
# the merge below presumably carries back into ds so that the second
# drop_vars(['level_2', ...]) call still finds it -- confirm before removing.
pres = pres.rename({'longitude': 'lon', 'latitude': 'lat', 'level_2': 'level_2'})

geoh = ds.geoh.swap_dims({'level_2': 'lev'})
# Here 'level_2' becomes 'level'; a 'level' entry is dropped again after the
# merge further down.
geoh = geoh.rename({'longitude': 'lon', 'latitude': 'lat', 'level_2': 'level'})

# Convert variable names to fit LAGRANTO (at most 5 characters)
ds = ds.rename_vars({
    'u': 'U',
    'v': 'V',
    'w': 'OMEGA',
    'blh': 'BLH_m',
    'blh_interp': 'BLH_i',
    'sp': 'PS',
    't': 'T',
    'q': 'Q',
})
ds = ds.rename_dims({'longitude': 'lon', 'latitude': 'lat', 'level': 'lev'})
ds = ds.rename({'longitude': 'lon', 'latitude': 'lat', 'level': 'lev'})
ds = ds.set_index({'lon': 'lon', 'lat': 'lat', 'lev': 'lev'})

# Union with the renamed pressure variable.
# drop_vars replaces the deprecated Dataset.drop for removing variables.
ds = ds.drop_vars(['level_2', 'pres'])
ds = ds.merge(pres, join='left')

# Same procedure for the renamed geoh variable
ds = ds.drop_vars(['level_2', 'geoh'])
ds = ds.merge(geoh, join='left')
ds = ds.drop_vars(['level'])

# Metadata for the vertical coordinate 'lev' (replaces any existing attributes)
lev_attrs = dict(
    standard_name='hybrid_sigma_pressure',
    long_name='hybrid level at layer midpoints',
    formula='hyam hybm (mlev=hyam+hybm*aps)',
    formula_terms='ap: hyam b: hybm ps: aps',
    units='level',
    positive='down',
)
ds['lev'].attrs = lev_attrs

# Surface pressure: divide by 100 (Pa -> hPa) and describe the result
ds = ds.assign(PS=ds['PS'] / 100)
ds['PS'].attrs = dict(
    standard_name='surface_air_pressure',
    long_name='Surface pressure',
    units='hPa',
)

####
# Move the coefficient dataset from dimension 'lev_2' to an indexed 'lev'
coeffs = (
    coeffs.swap_dims({'lev_2': 'lev'})
    .rename({'lev_2': 'lev'})
    .set_index({'lev': 'lev'})
)

# Add model level coefficients
for coeff_name in ('hyai', 'hybi', 'hyam', 'hybm'):
    ds[coeff_name] = coeffs[coeff_name]

####

# Calculate potential temperature field
temperature = ds['T'].metpy.quantify()  # quantify is used for unit aware calcs in metpy
pressure = ds['pres'].metpy.quantify()
th = mpcalc.potential_temperature(pressure, temperature)
ds['TH'] = (['time', 'lev', 'lat', 'lon'], th.values)
ds['TH'].attrs = dict(long_name="Potential temperature", units="K")

# Calculate equivalent potential temperature field
specific_humidity = ds['Q'].metpy.quantify()
dewpoint = mpcalc.dewpoint_from_specific_humidity(pressure, temperature, specific_humidity)
th_e = mpcalc.equivalent_potential_temperature(pressure, temperature, dewpoint)
ds['TH_E'] = (['time', 'lev', 'lat', 'lon'], th_e.values)
ds['TH_E'].attrs = dict(long_name="Equivalent potential temperature", units="K")

# Scale Q by 1000 to store it in g/kg. This has to stay after the dewpoint
# calculation above, which takes Q as read from the input.
ds['Q'] = ds['Q'] * 1000
# Reset the attributes so the stored units match the scaled values (same
# pattern as for PS); otherwise they are either dropped or left stale.
ds['Q'].attrs = dict(standard_name="specific_humidity",
                     long_name="Specific humidity",
                     units="g kg**-1")

def _store_field(dataset, name, dims, values, long_name, units):
    """Assign ``values`` to ``dataset[name]`` on ``dims`` and replace its attrs."""
    dataset[name] = (dims, values)
    dataset[name].attrs = dict(long_name=long_name, units=units)


DIMS_3D = ['time', 'lat', 'lon']
DIMS_4D = ['time', 'lev', 'lat', 'lon']

# Topography from surface geopotential, expanded along a leading 'time' axis
surf_height = func.geopot_to_hamsl(ds['z'])
surf_height = surf_height.expand_dims(dim={'time': ds.sizes['time']}, axis=0)
_store_field(ds, 'topo', DIMS_3D, surf_height.values, "topography", "m")

# Height above mean sea level from geoh
hamsl = func.geoh_to_hamsl(ds['geoh'])
_store_field(ds, 'hamsl', DIMS_4D, hamsl.values,
             "height_above_mean_sealevel", "m")

# BLH above mean sea level (BLH_m combined with the topography)
BLH_amsl = func.hag_to_hamsl(ds['BLH_m'], ds['topo'])
_store_field(ds, 'BLH_a', DIMS_3D, BLH_amsl.values,
             "boundary_layer_height_above_mean_sealevel", "m")

# Difference of hamsl and BLH_a, divided by 1000 (stored with units "km")
DZ = (ds['hamsl'] - ds['BLH_a']) / 1000
_store_field(ds, 'DZ', DIMS_4D, DZ.values, "difference_hamsl_BLH_amsl", "km")

# Difference of hamsl and BLH_i (interpolated), divided by 1000
DZ_i = (ds['hamsl'] - ds['BLH_i']) / 1000
_store_field(ds, 'DZi', DIMS_4D, DZ_i.values,
             "difference_hamsl_BLH_interpolated", "km")

# Warn about every data variable whose name is longer than 5 characters
overlong_names = [name for name in ds.data_vars if len(name) > 5]
for var_name in overlong_names:
    warnings.warn(f"The variable name '{var_name}' is longer than 5 characters and might cause problems if you want to trace it with LAGRANTO")

# Write the finished dataset to the requested netCDF file
ofilename = options.OUTFILE
ds.to_netcdf(path=ofilename)
