# -*- coding: utf-8 -*-
'''
FILENAME:
    model_class.py

DESCRIPTION:
    The model class contains global options and spatial parameters attribute
    dictionaries as well as the model processes
    (vertical water balance and river routing).

AUTHOR:
    Tobias Stacke

Copyright (C):
    2020-2021 Helmholtz-Zentrum Geesthacht

LICENSE:
    This program is free software: you can redistribute it and/or modify it under the
    terms of the GNU General Public License as published by the Free Software Foundation,
    either version 3 of the License, or (at your option) any later version.

    This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
    without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
    See the GNU General Public License for more details.

    You should have received a copy of the GNU General Public License along with this program.
    If not, see http://www.gnu.org/licenses/.
'''

# Module info
__author__ = 'Tobias Stacke'
__copyright__ = 'Copyright (C) 2020-2021 Helmholtz-Zentrum Geesthacht'
__license__ = 'GPLv3'

# Load modules
import numpy as np
import utility_routines as utr
import xarray as xr
import math as m
import datetime as dt
import copy as cp
import sys
import os
import pdb
import subprocess as spr
from termcolor import colored


# ======================================================================================================
# HydroPy model class
# ======================================================================================================
class model:

    # ======================================================================================================
    # INITIALIZATION
    # ======================================================================================================

    def __init__(self):
        """Initialize the model class.

        Sets up the default global option dictionary, tries to derive a
        version string from the git repository, and prepares empty
        containers for spatial parameter fields and temporary states.
        """
        # General description
        self.description = "HydroPy model class containing flux and state attributes and hydrological processes"
        self.author = "Tobias Stacke, tobias.stacke@hzg.de, HZG, Germany"

        # Set default global model options
        self.opt = {
            # General information
            'expid': 'hydropy',
            'contact': 'unknown',
            'institute': 'unknown',
            # Path information for forcing, input and output directories
            'forcing': os.getcwd() + "/forcing",
            'input': os.getcwd() + "/input",
            'output': os.getcwd() + "/output",
            'para': os.getcwd() + "/input/hydropy_para.nc",
            'restart': None,
            'restdate': None,
            # Use optimization
            'use_fortran': True,
            # Output variables and temporal resolution
            'daily': 'qs,qsb',
            'monthly': 'rootmoist,swe,evap,qtot',
            # Enabled model processes
            'with_permafrost': True,
            'with_skin': True,
            'with_leakage': False,
            'with_rivers': True,
            # Global parameter values for snow processes
            'rainf_lower': 273.15 - 1.1,  # Lower threshold for liquid precipitation [K]
            'snowf_upper': 273.15 + 3.3,  # Upper threshold for solid  precipitation [K]
            'melt_crit': 273.15 - 0.0,  # Critical temperature for snow melt [K]
            't_refreeze': 273.15 - 0.0,  # Critical temperature for refreezing water [K]
            'frc_liquid': 0.06,  # Fraction of liquid water in snow cover [/]
            'meltscheme': 'temporal', # snow melt scheme: temporal, spatial, both
            # Global parameter values for soil and skin processes
            'skincap1': 0.2, # Skin reservoir capacity on one layer (ECHAM: 0.2) [kg m-2]
            'rm_crit': 0.75,  # Critical root zone soil moisture fraction [/]
            'qsb_min': 2.77778e-07,  # Minimum drainage parameter [kg m-2 s-1]
            'qsb_max': 2.77778e-05,  # Maximum drainage parameter [kg m-2 s-1]
            'qsb_exp': 1.5,  # ECHAM drainage exponent [/]
            'qsb_low': 0.05,  # Minimum soilmoisture content for drainage  [/]
            'qsb_hig': 0.90,  # Maximum soilmoisture content for drainage  [/]
            'sevap_low': 0.05,  # Minimum soilmoisture content for soil evap [/]
            'wcap_perma': 50.0,  # Maximum water holding capacity for permafrost [kg m-2]
            # Global parameter values for flow processes
            'rivsubtime': 4, # Sub-timesteps for river flow [#]
            'v_lake': 0.01, # Flow velocity for 100% lake fraction 0.1 [m s-1] or None to disable lake retention
            'v_wetl': 0.06, # Flow velocity for 100% wetland fraction 0.1 [m s-1] or None to disable wetland retention
            'cf_wela': 0.5, # Fraction where lake and wetland impact becomes dominant
            'fak_ovr': 1, # Lag modification factor for overland flow
            'fak_gw': 1, # Lag modification factor for groundwater flow
            'fak_riv': 1, # Lag modification factor for river flow
        }

        # Try to add git information if the code runs from a repository;
        # fall back to 'unknown' if git is unavailable, the directory is
        # no repository, or the output cannot be decoded.
        try:
            gitdir = sys.path[0] + '/.git'
            # sys.stdout.encoding is None when stdout is redirected
            encoding = sys.stdout.encoding or 'utf-8'
            self.opt['version'] = spr.check_output(
                ["git", "--git-dir=" + gitdir, "describe", "--always", "--long"]
            ).strip().decode(encoding)
        except Exception:
            # Deliberately best-effort: any failure only means no version info
            self.opt['version'] = 'unknown'
        self.opt['expid'] = self.opt['expid'] + '-' + self.opt['version'].replace('_','-')

        # Dataset for parameter fields
        self.param = xr.Dataset()

        # Dictionary for temporary states and information
        self.temporary = {}


    # ======================================================================================================
    # CONFIGURATION OF OPTIONS
    # ======================================================================================================

    def update_from_ini(self, setupfile=os.getcwd() + "/setup.ini"):
        """Replace default global model options with values from a setup file.

        Each non-comment line must have the form ``key: value``. The line is
        split at the FIRST colon only, so values themselves may contain
        colons (e.g. timestamps or Windows paths). Blank lines and lines
        starting with '#' (also when indented) are ignored. Unknown keys
        terminate the program.
        """
        with open(setupfile, 'r') as optfile:
            for line in optfile:
                stripped = line.strip()
                # Skip blank lines and comments (previously a blank line or an
                # indented comment triggered the syntax error message)
                if not stripped or stripped.startswith('#'):
                    continue
                if ':' not in stripped:  # Error if wrong syntax
                    print("Error reading setup.ini")
                    print("Wrong syntax for line:", line)
                    continue
                # Split on the first colon only so values may contain colons
                optkey, optval = stripped.split(':', 1)
                optkey = optkey.replace(' ', '') # remove all whitespace from keywords
                if optkey not in self.opt:
                    print("Found unknown key in setup.ini: ", optkey)
                    sys.exit(1)
                # utr.get_correct_type converts the string to its proper type
                self.opt[optkey] = utr.get_correct_type(optval.strip())

    # ======================================================================================================
    def update_from_cli(self, optkey, optval):
        """Override a single model option with a command line value.

        Unknown option keys terminate the program.
        """
        if optkey not in self.opt:
            print("Got unknown option from command line: ", optkey)
            sys.exit(1)
        # Convert the raw command line string to its proper type
        self.opt[optkey] = utr.get_correct_type(optval)

    # ======================================================================================================
    def print_options(self):
        """Print all global model options, sorted alphabetically by key."""
        for key in sorted(self.opt):
            print('{opt:<20} = {val:<50}'.format(opt=key, val=str(self.opt[key])))

    # ======================================================================================================
    def update_all(self, debug, spinup):
        '''Copy frequently used options into attributes and set runtime switches'''
        self.expid = self.opt['expid']
        self.with_permafrost = self.opt['with_permafrost']
        self.with_skin = self.opt['with_skin']
        self.with_rivers = self.opt['with_rivers']
        self.debug = debug

        # A non-zero spinup counter means the model is still spinning up
        self.spinup = spinup != 0

        # Verify sensible combination of restart file and restart date
        restart, restdate = self.opt['restart'], self.opt['restdate']
        if restart is not None and restdate is None:
            raise LookupError('External restart file provided but without defined restart date')
        if restdate is not None and restart is None:
            raise LookupError('Restart date provided without providing external restart file')

    # ======================================================================================================
    # READING MODEL PARAMETER FIELDS AND BUILD DERIVED FIELDS
    # ======================================================================================================

    def get_parameter(self, parafile):
        """Read parameter fields from *parafile* into self.param.

        Variable names are lower-cased on import. A variable that already
        exists must have identical values; otherwise a LookupError is raised.
        Routing index fields must not contain NaN. Remaining missing values
        are replaced with zero and the grid dictionary is initialized.
        """
        try:
            fileobj = utr.dataset2double(xdata=xr.open_dataset(parafile))
        except OSError:
            print("Cannot open file " + parafile)
            sys.exit(1)

        # Read parameter data
        for param in fileobj.data_vars:
            if param.lower() in self.param.data_vars:
                # Reject re-definitions with different values. The original
                # chained comparison (min != max != 0) missed a constant
                # non-zero offset; check min and max of the difference
                # separately instead (xarray min/max skip NaN).
                diff = fileobj[param] - self.param[param.lower()]
                if diff.min() != 0 or diff.max() != 0:
                    raise LookupError(
                        "Error: Parameter", param,
                        "already exists in parameter list but with different values"
                    )
            else:
                self.param[param.lower()] = fileobj[param]

        # Sanity checks for specific parameter fields: routing indices must
        # be fully defined
        for param in ['rout_lat', 'rout_lon']:
            if param in self.param.data_vars:
                if np.any(np.isnan(self.param[param])):
                    raise LookupError(
                        "Error: No nan or missing values allowed in", param)

        # Replace missing values with zero
        self.param = self.param.fillna(0)

        # Initialize grid dictionary
        self.grid = {}

    # ======================================================================================================
    def get_grid(self, griddata):
        '''Store grid information from forcing data

        Fills self.grid with coordinates, time axis information and a
        monthly time stamp vector used for climatology interpolation, and
        converts monthly climatological parameter fields to the actual
        simulation year.
        NOTE(review): griddata.stream is presumably an xarray Dataset with
        'lat', 'lon' and 'time' coordinates -- confirm against caller.
        '''
        g = self.grid
        # Get latitude and longitude
        g['coords'] = griddata.stream.coords
        g['lat'] = griddata.stream['lat']
        g['lon'] = griddata.stream['lon']
        g['nlat'] = len(griddata.stream['lat'].values)
        g['nlon'] = len(griddata.stream['lon'].values)

        # Get time vector and restart time (assuming a regular time interval)
        # The restart time is extrapolated one interval before the first step
        g['time'] = griddata.stream['time']
        g['ntime'] = len(g['time'])
        g['resttime'] = g['time'][0] - (g['time'][1] - g['time'][0])
        # Restart day as compact YYYYMMDD string (ISO date with dashes removed)
        g['restday'] = str(g['resttime'].values).split('T')[0].replace('-', '')
        g['duration'] = g['time'][-1] - g['resttime']

        # Create time series of monthly mean dates for climatology processing:
        # build all days of the first simulation year, then resample to
        # month-start means so each stamp sits mid-month
        year=int(str(g['time'][0].values).split('-')[0])
        start, end = dt.datetime(year=year, month=1, day=1), dt.datetime(year=year+1, month=1, day=1)
        alldays = np.array([dt.timedelta(days=x) for x in range((end-start).days)]) + start
        g['monstamp'] = xr.DataArray(
                alldays, coords={'time': alldays}, dims=('time',), name='time').resample(time="1MS").mean()
        g['monstamp'] = g['monstamp'].assign_coords(time=g['monstamp'].values)

        # Identify simulation duration and set chunksize for the output id:
        # YYYY for yearly, YYYYMM for monthly, YYYYMMDD for daily chunks.
        # NOTE(review): durations shorter than one day leave 'timeid' unset.
        if g['duration'] / np.timedelta64(1, 's') >= 365 * 24 * 3600:
            # Yearly chunk
            g['timeid'] = str(g['time'].values[0])[:4]
        elif g['duration'] / np.timedelta64(1, 's') >= 29 * 24 * 3600:
            # Monthly chunk
            g['timeid'] = str(g['time'].values[0]).replace('-', '')[:6]
        elif g['duration'] / np.timedelta64(1, 's') >= 24 * 3600:
            # Daily chunk
            g['timeid'] = str(g['time'].values[0]).replace('-', '')[:8]

        # Replace time vector in climatological data with actual year and
        # add 1 month before and after year for interpolation: the December
        # stamp is shifted one year back, the January stamp one year forward
        newcoords = {'time': g['monstamp']['time'], 'lat': self.param.lat, 'lon': self.param.lon}
        newdims = ['time', 'lat', 'lon']
        for var in [x for x in self.param.data_vars if 'month' in self.param[x].dims]:
            newdata = self.param[var].values
            ln, un = self.param[var].long_name, self.param[var].units
            field = xr.DataArray(newdata, coords=newcoords, dims=newdims,
                    attrs={'long_name': ln, 'units': un})
            f0, f1 = field.isel(time=-1), field.isel(time=0)
            f0['time'].values = f0['time'].values - np.timedelta64(365,'D')
            f1['time'].values = f1['time'].values + np.timedelta64(365,'D')
            self.param[var] = xr.concat([f0,field,f1], dim='time').transpose('time', 'lat', 'lon')

    # ======================================================================================================
    def get_lsm(self, forcdata):
        '''return minimal lsm based on parameter and actual forcing data'''
        if "lsm" not in self.param.data_vars:
            raise LookupError("Error: No LSM found in parameter files")
        grid = self.grid

        # Start from the parameter land sea mask and blank out cells that
        # have no valid surface temperature in the forcing data
        lsm = self.param['lsm'].where(forcdata['TSurf'] > 0, 0.0)

        if 'glacier' in self.param.keys():
            # Remove the absolute glacier area from the land fraction
            abs_glacier = lsm * self.param['glacier']
            lsm = (lsm - abs_glacier).where(lsm > abs_glacier, 0.0)
            print("\nSubstract glacier fraction from land fractions")

        grid['lsm'] = lsm
        grid['area'] = self.param['area']

        # Compute area and weights for land surface
        grid['landarea'] = (grid['area'] * grid['lsm']).to_masked_array()
        grid['landweights'] = grid['landarea'] / grid['landarea'].sum()

    # ======================================================================================================
    def set_permafrost(self):
        '''this function reduces maximum water capacity and water availability according
        to permafrost fraction
        '''
        perm = self.param.perm
        wcap = self.param.wcap

        # Cap the holding capacity at the permafrost limit
        wcap_perm = wcap.where(wcap <= self.opt['wcap_perma'],
                               wcap * 0 + self.opt['wcap_perma'])
        # Fraction-weighted mean of permafrost-limited and unlimited capacity,
        # expressed as a reduction factor relative to the full capacity
        cap_reduc = (wcap_perm * perm + wcap * (1 - perm)) / wcap

        # Apply the reduction to all water holding capacity parameters,
        # preserving their metadata
        for name in ('wcap', 'wava', 'wmin', 'wmax'):
            attrs = self.param[name].attrs
            self.param[name] = self.param[name] * cap_reduc.fillna(0)
            self.param[name].attrs = attrs


    # ======================================================================================================
    def set_soilparam(self):
        '''return derived soil parameter fields'''
        p = self.param

        def _store(name, field, long_name, units):
            # Assign a derived field and attach its metadata
            p[name] = field
            p[name].attrs = {'long_name': long_name, 'units': units}

        _store('wilt', p['wcap'] - p['wava'],
               'wilting point', 'kg m-2')
        _store('crit', p['wcap'] * self.opt['rm_crit'],
               'critical soil moisture', 'kg m-2')
        _store('wlow', p['wcap'] * self.opt['sevap_low'],
               'dry soil limit', 'kg m-2')
        # Orographic heterogeneity measure, clipped at zero
        boro = (p['topo_std'] - 100.0) / (p['topo_std'] + 1000.0)
        _store('boro', boro.where(boro > 0, 0),
               'Rescaled orographical standard deviation', '/')
        # The following fields depend on the freshly stored 'boro'
        _store('imax', p['wcap'] * (1.0 + p['boro']),
               'maximum infiltration capacity', 'kg m-2')
        _store('oexp', 1.0 / (1.0 + p['boro']),
               'beta parameter exponent', '/')
        _store('bmod', (p.beta + p.boro).where(p.boro >= 0.01, p.beta),
               'modified beta parameter', '/')


    # ======================================================================================================
    def get_flow_properties(self):
        '''Compute the static flow sequence and flow retention coefficients.

        Derives the river flow sequence from the routing index fields, then
        computes retention times (lag, unit days) and cascade numbers for
        overland flow, baseflow and river flow. Lakes and wetlands optionally
        increase the retention. All results are stored in self.temporary.
        '''
        import processes as prc

        # Static parameter fields
        area = self.param['area'].values
        topo = self.param['srftopo'].values

        # Get flow directions, cell distance and height difference.
        # np.int was removed in NumPy 1.24; the builtin int is the exact
        # replacement (np.int was a plain alias).
        rivfl, sinks, ic, dx_cell, dh_cell = prc.eval_flowfield(
            self.param['rout_lat'].values.astype(int),
            self.param['rout_lon'].values.astype(int),
            area, topo, np.pi)
        self.temporary['riverflow'] = rivfl[0:ic + 1]
        self.temporary['flowsinks'] = sinks * 1

        # The following computations are moved from the preproc scripts to the model to be
        # computed at runtime. All values are valid only for daily time steps and 0.5 deg
        # resolution -- TODO: check results before generalizing.
        #
        # Reference values calibrated for the Vindelaelven catchment and modified with
        # the Torneaelven experiment (all done by Stefan)
        ref_ovr = {'k0': 50.5566, 'n0': 1.11070, 'v0': 1.0885, 'dx': 171000.0}
        ref_riv = {'k0': 0.41120, 'n0': 5.47872, 'v0': 1.0039, 'dx': 228000.0}
        vmin = 0.1  # Minimum flow velocity [m s-1] --> 5.79 day for 50 km
        alpha, vfac = 0.1, 2  # Parameters for Sausen flow velocity computation

        # Compute slope and grid cell diameter (subgrid diameter from the
        # equal-area circle of the cell)
        slope_cell = np.ma.where(dx_cell > 0, dh_cell / dx_cell, 0)
        slope_subg = self.param['slope_avg'].values
        dx_subg = (area / np.pi)**(0.5) * 2

        # Compute flow velocities within and between grid cells
        vel_cell = np.ma.maximum(vmin, vfac * slope_cell**alpha)
        vel_subg = np.ma.maximum(vmin, vfac * slope_subg**alpha)

        # Compute retention coefficient for surface water bodies, preferably
        # from subgrid properties
        ovr_k = np.ma.where(slope_subg > 0,
            # Velocity based on subgrid slope
            ref_ovr['k0'] * dx_subg / ref_ovr['dx'] * ref_ovr['v0'] / vel_subg,
            # Velocity based on normal slope
            ref_ovr['k0'] * dx_cell / ref_ovr['dx'] * ref_ovr['v0'] / vel_cell
            )
        ovr_n = ref_ovr['n0'] * np.ones_like(ovr_k)

        # Compute retention coefficient for rivers using inter-cell properties
        riv_k = ref_riv['k0'] * dx_cell / ref_riv['dx'] * ref_riv['v0'] / vel_cell
        riv_n = ref_riv['n0'] * np.ones_like(riv_k)

        # Compute retention coefficients for baseflow based on fixed properties
        oroscale = np.ma.maximum(0.01, self.param['boro'].values)
        base_k = 300.0 / (1.0 - oroscale + 0.01)
        base_k *= (dx_subg / 50000.0)  # Scaling with normalized grid cell size
        base_n = np.ones_like(base_k)

        # Apply correction factors for sensitivity experiments
        ovr_k *= self.opt['fak_ovr']
        base_k *= self.opt['fak_gw']
        riv_k *= self.opt['fak_riv']

        # Add additional retention due to lakes and wetlands
        for f_wela, v_wela, n_wela in zip([self.param['flake'], self.param['fwetl']],
                                          [self.opt['v_lake'], self.opt['v_wetl']],
                                          ['lake', 'wetland']):
            if v_wela is not None:
                fract = np.ma.maximum(0, np.ma.minimum(1, f_wela.values))
                # Impact scaling ramps up around the dominance fraction cf_wela
                fract_scaling = 0.5 * (np.tanh(4.0 * np.pi * (fract - self.opt['cf_wela'])) + 1.0)
                # Compute and modify river flow velocity (only where the water
                # body actually slows the flow down)
                v_riv = np.ma.where(riv_k > 0, dx_cell / ( riv_k * riv_n * 86400.0), vmin)
                incr_lag = np.ma.logical_and(v_riv > v_wela, fract > 1.0e-3)
                v_red = np.ma.where(incr_lag,
                        1 - (1.0 - v_wela / v_riv ) * fract_scaling, 1)
                riv_n = np.ma.where(incr_lag, (riv_n - 1) * (1.0 - fract_scaling) + 1, riv_n)
                riv_k = np.ma.where(incr_lag, dx_cell / (riv_n * v_riv * v_red * 86400.0), riv_k)
                # Compute and modify surface flow velocity (wetland/lake impact
                # on overland flow uses a tenth of the reference velocity)
                v_ovr = np.ma.where(ovr_k > 0, dx_subg / ( ovr_k * ovr_n * 86400.0), vmin)
                incr_lag = np.ma.logical_and(v_ovr > 0.1 * v_wela, fract > 1.0e-3)
                v_red = np.ma.where(incr_lag,
                        1 - (1.0 - (0.1 * v_wela) / v_ovr) * fract_scaling, 1)
                ovr_n = np.ma.where(incr_lag, (ovr_n - 1) * (1.0 - fract_scaling) + 1, ovr_n)
                ovr_k = np.ma.where(incr_lag, dx_subg / (ovr_n * v_ovr * v_red * 86400.0), ovr_k)
                print('Flow velocity for 100%',n_wela,'cover set to',v_wela,' m s-1')
            else:
                print('No flow retention due to',n_wela+'s')

        # Store flow coefficient in temporary fields with unit days
        self.temporary['lag_land'] = ovr_k
        self.temporary['lag_base'] = base_k
        self.temporary['lag_river'] = riv_k
        # Store number of flow cascades
        self.temporary['ncasc_land'] = ovr_n
        self.temporary['ncasc_base'] = base_n
        self.temporary['ncasc_river'] = riv_n

        # Scale all lag values towards integer cascade numbers
        for lag_name in [x for x in self.temporary.keys() if 'lag_' in x]:
            casc_name = lag_name.replace('lag_', 'ncasc_')
            self.temporary[lag_name] *= (
                self.temporary[casc_name] / self.temporary[casc_name].astype(int))

        print(colored(
            "Compute retention times and cascade numbers for all flow storages",
            'green'))


    # ======================================================================================================
    # Landcover change processes
    # ======================================================================================================

    def get_daily_cover(self, fcover, ctypes, fluxes, states, date):
        '''compute daily land cover type fractions

        This routine interpolates the daily value from the monthly climatology
        for all 3D cover fields, but considers 2D cover to be constant.
        LCT priority follows the order of entries in ctypes: if no residual
        area remains, the subsequent LCTs go empty. The last entry in ctypes
        is the residual type and receives whatever area is left.
        Note: HydroPy considers the land (non-ocean) fraction only and thus
        all cover fractions are treated as being relative to the LSM,
        i.e. they range between 0-1.
        '''
        # Set residual field to 1 on land and identify residual land cover type
        resi_name = ctypes[-1]
        resi_field = np.where(self.param['lsm'].values > 0, 1, 0).astype(np.float64)

        for ct in ctypes[:-1]:
            if ct not in self.param.data_vars:
                raise LookupError("Cover type", ct,
                                  "not found in parameter data")

            # Check bounds for specific cover type
            if self.param[ct].min() < 0.0 or self.param[ct].max() > 1.0:
                raise ValueError("get_daily_cover: ERROR --> cover type", ct,
                                 "outside of 0-1 bounds")

            # Interpolate between months if climatology, else use the field
            # as constant 2D cover
            if 'time' in self.param[ct].dims:
                fcover[ct] = (utr.monthly_interpol(
                    field=self.param[ct], fdate=date)).to_masked_array()
            elif self.param[ct].dims == ('lat', 'lon'):
                fcover[ct] = self.param[ct].to_masked_array()
            else:
                raise LookupError('Unexpected dimensions for parameter field',
                                  ct, self.param[ct].dims)

            # Apply conditions for specific land cover types
            if ct == 'flake':
                # Check bounds for wetlands, too
                if self.param['fwetl'].min() < 0.0 or self.param['fwetl'].max() > 1.0:
                    raise ValueError("get_daily_cover: ERROR --> cover type fwetl",
                                     "outside of 0-1 bounds")
                # Merge lake and wetland fraction and provide a minimum fraction
                # as surface water reservoir (surface runoff sink)
                flake = np.ma.maximum(fcover['flake'], self.param['fwetl'].to_masked_array())
                fcover[ct] = np.ma.minimum(np.ma.maximum(flake, 1.0e-10), 1)

            # Reduce cover fraction in case it is already claimed by a higher
            # priority class
            fcover[ct] = np.ma.minimum(fcover[ct], resi_field)

            # Reduce the residual fraction accordingly
            resi_field -= fcover[ct]

        if resi_field.min() < 0.0 or resi_field.max() > 1.0:
            raise ValueError(resi_name,
                             'residual cover fraction out of bounds')

        fcover[resi_name] = resi_field

        # Write fractions to log (only when a logger has been attached)
        if 'log' in vars(self).keys():
            for ct in ctypes:
                self.log.add_value(fcover[ct], 'lcf_'+ct, ct+' cover fraction', unit='/', )

    # ======================================================================================================
    # Hydrological processes for snow cover
    # ======================================================================================================
    def diag_frozen_ground(self, states, fcover):
        '''diagnose frozen ground'''
        fcover['frozen'] = np.ma.where(states['tsurf'] < self.opt['melt_crit'], True, False)

    # ======================================================================================================
    def get_rain_and_snow(self, fluxes, states):
        '''This routine computes snowfall according to WIGMOSTA'''
        temp_range = self.opt['snowf_upper'] - self.opt['rainf_lower']
        snowfrct = np.ma.minimum(
            1.0,
            np.ma.maximum(0, (self.opt['snowf_upper'] - states['tsurf']) /
                          temp_range))

        fluxes['snowf'] = fluxes['precip'] * snowfrct
        fluxes['rainf'] = np.ma.maximum(0, fluxes['precip'] - fluxes['snowf'])

    # ======================================================================================================
    def get_potential_snowmelt(self, fluxes, states, time):
        '''Computes potential snowmelt using the daily degree approach (based on MPI-HM)'''
        scheme = self.opt['meltscheme'].lower()

        if scheme in ('temporal', 'both'):
            # Daylength fraction for every grid cell (latitude dependent)
            today = dt.datetime.strptime(str(time.values)[:10], '%Y-%m-%d')
            daylen = utr.daylength(today, self.grid['lat'].values)
            fdaylen = np.stack((daylen / 24.0,) * len(self.grid['lon']), axis=-1)

        # Select the melt factor depending on the chosen scheme
        if scheme == 'spatial':
            # Melt factor depends on orography alone
            melt_factor = self.param.ddfac.fillna(0).values
        elif scheme == 'temporal':
            # Melt factor depends on daylength alone
            melt_factor = fdaylen * 8.3 + 0.7
        elif scheme == 'both':
            # Melt factor depends on daylength and orography
            melt_factor = np.ma.maximum(fdaylen * self.param.ddfac.fillna(0).values * 1.33 , 0.7)
        else:
            raise LookupError(
                    'Invalid choice', scheme, 'for melt scheme')

        # Potential melt is proportional to the excess over the melt threshold
        fluxes['smelt'] = np.ma.maximum(
            0, melt_factor * (states['tsurf'] - self.opt['melt_crit']))

        if 'log' in vars(self).keys():
            self.log.add_value(fluxes['smelt'], 'smelt_pot', 'Potential snow melt')
            self.log.add_value(states['tsurf'], 'tsurf', 'Surface temperature', unit='K')

    # ======================================================================================================
    def update_snow(self, fluxes, states):
        '''This function computes the throughfall onto the canopy based on
           subroutine THROUGH from the old MPI-HM (IRAIME == 12)
        '''
        if 'log' in vars(self).keys():
            self.log.add_value(states['swe'], 'swe_old', 'Snow water equivalent from last time step')
            self.log.add_value(states['wliq'], 'wliq_old', 'Liquid SWE content from last time step')

        # # Refreezing of liquid water content within snow cover if below certain temperature
        with np.errstate(invalid='ignore'):
            freezing = states['tsurf'] < self.opt['t_refreeze']
        states['swe'] = np.ma.where(freezing, states['swe'] + states['wliq'],
                                    states['swe'])
        states['wliq'] = np.ma.where(freezing, states['wliq'] * 0,
                                     states['wliq'])

        # Compute new snow cover and reduce potential snowmelt if higher than snow height
        states['swe'] += fluxes['snowf']
        fluxes['smelt'] = np.ma.where(states['swe'] > fluxes['smelt'],
                                      fluxes['smelt'], states['swe'])
        states['swe'] -= fluxes['smelt']

        # Update liquid water content in snow
        wliq_max = states['swe'] * self.opt['frc_liquid']
        states['wliq'] += fluxes['smelt']
        overflow = states['wliq'] > wliq_max
        fluxes['smelt'] = np.ma.where(overflow, states['wliq'] - wliq_max, 0)
        states['wliq'] = np.ma.where(~overflow, states['wliq'], wliq_max)

        # Compute throughfall
        fluxes['rainmelt'] = fluxes['rainf'] + fluxes['smelt']

        if 'log' in vars(self).keys():
            self.log.add_value(states['swe'], 'swe_new', 'Snow water equivalent')
            self.log.add_value(states['wliq'], 'wliq_new', 'Liquid swe content')
            self.log.add_value(fluxes['rainf'], 'rainf', 'Rainfall')
            self.log.add_value(fluxes['snowf'], 'snowf', 'Snowfall')
            self.log.add_value(fluxes['rainmelt'], 'rainmelt', 'Rainfall+Snowmelt')
            self.log.add_value(fluxes['smelt'], 'smelt', 'Snowmelt')

    # ======================================================================================================
    # Hydrological processes for soil storage
    # ======================================================================================================
    def get_surface_runoff(self, fluxes, states, fcover):
        '''Separation of throughfall into surface runoff and infiltration'''
        # computed using the Improved ARNO Scheme (MPI-HM IEXC == 5)

        # Prepare temporary fields and shortcuts
        beta = np.ma.masked_where(
            self.grid['lsm'] == 0,
            self.param.bmod.values)  # Modified beta parameter
        rm_cap = np.ma.masked_where(
            self.grid['lsm'] == 0,
            self.param.wcap.values)  # Maximum water holding capacity
        rm_max = np.ma.masked_where(
            self.grid['lsm'] == 0,
            self.param.wmax.values)  # Maximum subgrid soil moisture
        rm_min = np.ma.masked_where(
            self.grid['lsm'] == 0,
            self.param.wmin.values)  # Minimum subgrid soil moisture

        # Compute subgrid root zone soil moisture rm_sub
        rm_sub = rm_max - (rm_max - rm_min) * (1 -
                                               (states['rootmoist'] - rm_min) /
                                               (rm_cap - rm_min))**(1 /
                                                                    (1 + beta))
        rm_sub = np.ma.where(states['rootmoist'] <= rm_min,
                             states['rootmoist'], rm_sub)

        # Compute single components of surface runoff equation and set bounds
        c1 = ((rm_max - rm_sub) / (rm_max - rm_min))**(1 + beta)
        c1 = np.minimum(c1, 1)
        c2 = ((rm_max - rm_sub - fluxes['throu']) / (rm_max - rm_min))**(1 + beta)
        c2 = np.maximum(c2, 0) 
        # Compute surface runoff regimes
        no_qs = fluxes['throu'] < 0
        too_dry = rm_sub + fluxes['throu'] <= rm_min
        overflow = rm_sub + fluxes['throu'] >= rm_max

        # Compute subgrid surface runoff and excess flow
        excess = np.ma.where(fluxes['throu'] > (rm_cap - states['rootmoist']),
                             fluxes['throu'] + (states['rootmoist'] - rm_cap),
                             0)
        rm_res = np.ma.where(rm_min - states['rootmoist'] > 0,
                             rm_min - states['rootmoist'], 0)
        qs = fluxes['throu'] - rm_res - ((rm_max - rm_min) /
                                         (1 + beta)) * (c1 - c2)

        # very small throughfall might cause negative surface runoff
        qs = np.maximum(qs, 0)

        # Combine different surface runoff fluxes based on regime and surface state
        qs = np.ma.where(overflow, excess, qs)
        qs = np.ma.where(too_dry, 0, qs)
        qs = np.ma.where(no_qs, 0, qs)
        qs = np.ma.where(fcover['frozen'], fluxes['throu'], qs)

        # Note: Surface runoff is implicitly scaled to non-lake fraction because
        # throughfall is scaled with the non-lake fraction
        fluxes['qs'] = qs

        # Check for negative surface runoff
        if fluxes['qs'].min() < 0:
            pdb.set_trace()
            raise ValueError("Negative surface runoff: ", fluxes['qs'].min())

    # ======================================================================================================
    def get_drainage(self, fluxes, states, fcover, dt):
        '''Leakage from soil storage'''
        # computed following MPI-HM ISUBFL == 1
        wcap = np.ma.masked_where(
            self.grid['lsm'] == 0,
            self.param.wcap.values)  # Maximum water holding capacity
        # Define drainage regimes
        no_qsb = np.ma.logical_or(
            self.param.wcap.values <= 1.0e-10, states['rootmoist'] <=
            self.param.wcap.values * self.opt['qsb_low'])
        full_qsb = np.logical_and(
            self.param.wcap.values > 1.0e-10, states['rootmoist'] >=
            self.param.wcap.values * self.opt['qsb_hig'])

        # Compute standard and maximum drainage
        qsb = self.opt['qsb_min'] * dt * (states['rootmoist'] / wcap)
        maxqsb = (qsb + dt * (self.opt['qsb_max'] - self.opt['qsb_min']) *
                  ((states['rootmoist'] - wcap * self.opt['qsb_hig']) /
                   (wcap - wcap * self.opt['qsb_hig']))**self.opt['qsb_exp'])

        # Attribute drainage based on regime
        qsb = np.ma.where(no_qsb, 0, qsb)
        qsb = np.ma.where(full_qsb, maxqsb, qsb)
        qsb = np.ma.where(qsb > states['rootmoist'], states['rootmoist'], qsb)
        qsb = np.ma.where(fcover['frozen'], 0, qsb)

        fluxes['qsb'] = qsb

        # Check for negative drainage
        if fluxes['qsb'].min() < 0:
            raise ValueError("Negative subsurface drainage: ",
                             fluxes['qsb'].min())

    # ======================================================================================================
    def get_skinevap(self, fluxes, states, fcover, date):
        '''compute evaporation from skin and canopy'''
        # based on the subroutine EVAPSKIN from the old MPI-HM (ISKIN == 1)
        # however, skin and canopy evaporation and storages are computed individually
        if self.opt['with_skin']:
            # Interpolate LAI to daily state
            lai_daily = (utr.monthly_interpol(
                field=self.param['lai'], fdate=date, bounds='zero')).to_masked_array()
            # Update maximum skin and canopy reservoir capacity (local)
            maxcap = {
                    'fbare': self.opt['skincap1'],
                    'fveg': self.opt['skincap1'] * lai_daily
                    }
            skinstor = {
                    'fbare': np.ma.where(fcover['fbare'] > 0, states['skinstor'] / fcover['fbare'], 0.0),
                    'fveg': np.ma.where(fcover['fveg'] > 0, states['canopystor'] / fcover['fveg'], 0.0),
                    }
            for sktype, skevap in zip(
                    ['fbare', 'fveg'], ['skinevap', 'canoevap']):
                # Compute wet skin fraction, PET fraction and local evaporation
                wetfract = np.ma.where(maxcap[sktype] > 0,
                    (skinstor[sktype] + fluxes['rainmelt']) / maxcap[sktype], 0.0)
                wetfract = np.ma.maximum(0.0, np.ma.minimum(1.0, wetfract))
                petfract = np.ma.where(wetfract * fluxes['potevap'] > 0,
                    (skinstor[sktype] + fluxes['rainmelt']) / (wetfract * fluxes['potevap']), 0.0)
                petfract = np.ma.maximum(0.0, np.ma.minimum(1.0, petfract))
                # Compute skin and canopy evaporation and scale to grid cell
                fluxes[skevap] = fluxes['potevap'] * wetfract * petfract * fcover[sktype]
            self.param['canocap'] = xr.DataArray(maxcap['fveg'] * fcover['fveg'],
                    coords=self.param.area.coords, dims=self.param.area.dims, name='canocap', attrs={
                        'long_name': 'Maximum canopy moisture storage', 'units': 'm2 m-2'})
            self.param['skincap'] = xr.DataArray(maxcap['fbare'] * fcover['fbare'],
                    coords=self.param.area.coords, dims=self.param.area.dims, name='skincap', attrs={
                        'long_name': 'Maximum skin moisture storage', 'units': 'm2 m-2'})

        else:
            fluxes['skinevap'] = fluxes['potevap'] * 0
            fluxes['canoevap'] = fluxes['potevap'] * 0

        # Check for negative bare soil evaporation
        if fluxes['skinevap'].min() < 0:
            raise ValueError("Negative skin evaporation from bare soil: ",
                             fluxes.skinevap.min())
        if fluxes['canoevap'].min() < 0:
            raise ValueError("Negative canopy evaporation from vegetation: ",
                             fluxes.canoevap.min())

    # ======================================================================================================
    def get_transpiration(self, fluxes, states, fcover):
        '''compute plant transpiration'''
        # based on subroutine EVAPACT from the old MPI-HM (IEVAP == 4)
        wcap = np.ma.masked_where(
            self.grid['lsm'] == 0,
            self.param.wcap.values)  # Maximum water holding capacity
        crit = np.ma.masked_where(
            self.grid['lsm'] == 0,
            self.param.crit.values)  # Maximum water holding capacity
        wilt = np.ma.masked_where(
            self.grid['lsm'] == 0,
            self.param.wilt.values)  # Maximum water holding capacity
        # Define wet and dry soil moisture regimes
        max_transp = np.ma.logical_and(wcap > 1.0e-10,
                                       states['rootmoist'] >= crit)
        no_transp = np.ma.logical_or(wcap <= 1.0e-10,
                                     states['rootmoist'] <= wilt)

        # Reduce potential evaporation by canopy evap
        if self.opt['with_skin']:
            potevap_fveg = np.ma.maximum(0, fluxes['potevap'] - np.ma.where(
                fcover['fveg'] > 0, fluxes['canoevap'] / fcover['fveg'], 0))
        else:
            potevap_fveg = fluxes['potevap']

        # Set transpiration to potential for wet soil moisture regime, zero for dry soil moisture regime
        # and scale linearly with available water for transitional regime
        transp = potevap_fveg * ((states['rootmoist'] - wilt) /
                                      (crit - wilt))
        transp = np.ma.where(max_transp, potevap_fveg, transp)
        transp = np.ma.where(no_transp, 0, transp)
        # Correct transpiration wherever root zone moisture would drop below wilting point
        rm_avail = np.ma.where(states['rootmoist'] - wilt < 0, 0,
                               states['rootmoist'] - wilt)
        transp = np.ma.where(transp > rm_avail, rm_avail, transp)

        fluxes['transp'] = transp * fcover['fveg']

        # Check for negative transpiration
        if fluxes['transp'].min() < 0:
            pdb.set_trace()
            raise ValueError("Negative transpiration: ", fluxes['transp'].min())

    # ======================================================================================================
    def get_soilevap(self, fluxes, states, fcover):
        '''compute bare soil evaporation'''
        # based on the subroutine EVAPACT from the old MPI-HM (IEVAP == 4)
        wcap = np.ma.masked_where(
            self.grid['lsm'] == 0,
            self.param.wcap.values)  # Maximum water holding capacity
        wlow = np.ma.masked_where(
            self.grid['lsm'] == 0,
            self.param.wlow.values)  # Maximum water holding capacity
        # Define wet and dry soil moisture regimes
        no_sevap = np.ma.logical_or(wcap <= 1.0e-10,
                                    states['rootmoist'] <= wlow)

        # Reduce potential evaporation by skin evap
        if self.opt['with_skin']:
            potevap_baresoil = np.ma.maximum(0, fluxes['potevap'] - np.ma.where(
                fcover['fbare'] > 0, fluxes['skinevap'] / fcover['fbare'], 0))
        else:
            potevap_baresoil = fluxes['potevap']

        # Compute bare soil evaporation
        sevap = potevap_baresoil * ((states['rootmoist'] - wlow) /
                                     (wcap - wlow))
        sevap = np.ma.where(no_sevap, 0, sevap)

        # Correct bare soil evaporation wherever root zone moisture would drop below wlow
        rm_avail = np.ma.where(states['rootmoist'] - wlow < 0, 0,
                               states['rootmoist'] - wlow)
        sevap = np.minimum(rm_avail, sevap)

        fluxes['sevap'] = sevap * fcover['fbare']

        # Check for negative bare soil evaporation
        if fluxes['sevap'].min() < 0:
            raise ValueError("Negative bare soil evaporation: ",
                             fluxes.sevap.min())

    # ======================================================================================================
    def update_skincanopy(self, fluxes, states, fcover):
        '''update skin and canopy storage and correct fluxes if necessary'''
        #
        if 'log' in vars(self).keys():
            self.log.add_value(states['skinstor'], 'skinstor_old', 'Skin storage from last time step')
            self.log.add_value(states['canopystor'], 'canopystor_old', 'Canopy storage from last time step')
            self.log.add_value(fluxes['rainmelt'] * (fcover['fbare'] + fcover['fveg']), 'rainmelt_land',
                    'Rainfall and Snowmelt scaled to land fraction')
        #
        if self.opt['with_skin']:
            # Update skin and canopy reservoir
            states['skinstor'] += (fluxes['rainmelt'] * fcover['fbare'] - fluxes['skinevap'])
            states['canopystor'] += (fluxes['rainmelt'] * fcover['fveg'] - fluxes['canoevap'])
            # Adapt evaporation in case is reduces storages below zero
            fluxes['skinevap'] = np.ma.where(states['skinstor'] < 0,
                    np.maximum(0, fluxes['skinevap'] + states['skinstor']), fluxes['skinevap'])
            states['skinstor'] = np.maximum(0, states['skinstor'])
            fluxes['canoevap'] = np.ma.where(states['canopystor'] < 0,
                    np.maximum(0, fluxes['canoevap'] + states['canopystor']), fluxes['canoevap'])
            states['canopystor'] = np.maximum(0, states['canopystor'])
            # Compute throughfall to the soil
            fluxes['throu'] = (np.maximum(0, states['skinstor'] - self.param['skincap'].values)
                    + np.maximum(0, states['canopystor'] - self.param['canocap'].values))
            # Reduce skin and canopy content accordingly
            states['skinstor'] = np.minimum(self.param['skincap'].values, states['skinstor'])
            states['canopystor'] = np.minimum(self.param['canocap'].values, states['canopystor'])

        else:
            fluxes['throu'] = fluxes['rainmelt'] * (fcover['fbare'] + fcover['fveg'])

        # Check for negative moisture states
        if states['skinstor'].min() < 0:
            raise ValueError("Negative bare soil skin storage after update_skincanopy: ",
                             states['skinstor'].min())
        if states['canopystor'].min() < 0:
            raise ValueError("Negative canopy storage after update_skincanopy: ",
                             states['canopystor'].min())

        if 'log' in vars(self).keys():
            self.log.add_value(states['skinstor'], 'skinstor_new', 'Skin reservoir')
            self.log.add_value(states['canopystor'], 'canopystor_new', 'Canopy reservoir')
            self.log.add_value(fluxes['throu'], 'throu', 'Throughfall')
            self.log.add_value(fluxes['skinevap'], 'skinevap', 'Skin evaporation')
            self.log.add_value(fluxes['canoevap'], 'canoevap', 'Canopy evaporation')
            self.log.add_value(self.param['skincap'].values, 'skinmax', 'Maximum skin reservoir capacity')
            self.log.add_value(self.param['canocap'].values, 'canomax', 'Maximum canopy reservoir capacity')


    # ======================================================================================================
    def update_soil(self, fluxes, states):
        '''update soil moisture state and correct fluxes if necessary'''
        # This water balance scheme is not yet flexible enough. Later on, it needs to account
        # for the different fluxes for different land cover types!
        wcap = self.param.wcap.values

        if 'log' in vars(self).keys():
            self.log.add_value(states['rootmoist'], 'rootmoist_old', 'Root zone soil moisture from last time step')

        # Update soil moisture state
        states['rootmoist'] += (fluxes['throu'] - fluxes['qs'])
        states['rootmoist'] -= (fluxes['transp'] + fluxes['sevap'] +
                                fluxes['qsb'])

        # Add soil moisture overflow to surface runoff
        overflow = states['rootmoist'] > wcap
        fluxes['qs'] = np.ma.where(overflow,
                                   fluxes['qs'] + states['rootmoist'] - wcap,
                                   fluxes['qs'])
        states['rootmoist'] = np.ma.where(overflow, wcap, states['rootmoist'])
        if fluxes['qs'].min() < 0:
            raise ValueError("Negative surface runoff after overflow: ",
                             fluxes['qs'].min())

        # Soil below zero --> reduce evaporation and drainage equally
        if states['rootmoist'].min() < 0:
            states['rootmoist'], corflx = utr.correct_neg_stor(states['rootmoist'],
                    [fluxes['transp'], fluxes['sevap'], fluxes['qsb']])
            fluxes['transp'], fluxes['sevap'], fluxes['qsb'] = corflx

        # Check for negative soil moisture
        if states['rootmoist'].min() < 0:
            raise ValueError("Negative soil moisture after update_soil: ",
                             states.rootmoist.min())

        if 'log' in vars(self).keys():
            self.log.add_value(states['rootmoist'], 'rootmoist_new', 'Root zone soil moisture')
            self.log.add_value(fluxes['qs'], 'qs', 'Surface runoff')
            self.log.add_value(fluxes['qsb'], 'qsb', 'Subsurface runoff')
            self.log.add_value(fluxes['transp'], 'transp', 'Transpiration')
            self.log.add_value(fluxes['sevap'], 'sevap', 'Bare soil evaporation')

    # ======================================================================================================
    # Hydrological processes for lake storage
    # ======================================================================================================
    def get_lakeevap(self, fluxes, states, fcover):
        '''compute open water evaporation over lakes'''
        levap = fluxes['potevap'] * fcover['flake']

        # Don't evaporation more than the lake contains
        fluxes['levap'] = np.minimum(levap, states['lakestor'])

        # Check for negative evaporation
        if fluxes['levap'].min() < 0:
            raise ValueError("Negative lake evaporation: ",
                             fluxes['levap'].min())

    # ======================================================================================================
    def get_lakeleak(self, fluxes, states, fcover):
        '''compute leakage from lakes into groundwater using soilmoisture deficit as proxy'''

        if self.opt['with_leakage']:
            # Compute cell average soil moisture deficit
            lleak = np.ma.maximum(0, self.param['wcap'].values - states['rootmoist'])

            # Don't infiltrate more than the lake contains
            lleak = np.ma.minimum(lleak, states['lakestor'])

            # Additionally scale with lake fraction to avoid small lakes leaking
            # all their water into the ground
            fluxes['lleak'] = np.ma.where(fcover['frozen'], 0,
                    lleak * fcover['flake']**2)

            # Check for negative evaporation
            if fluxes['lleak'].min() < 0:
                raise ValueError("Negative lake leakage: ",
                                 fluxes['lleak'].min())
        else:
            fluxes['lleak'] = fluxes['rainmelt'] * 0

    # ======================================================================================================
    def update_surface_water(self, fluxes, states, fcover):
        '''update surface water storage ( == old overland flow storage)
           and compute outflow based on retention times
        '''

        if 'log' in vars(self).keys():
            self.log.add_value(states['lakestor'], 'lakestor_old', 'Lake storage from last time step')

        # # Compute rainmelt input for lake fraction
        zeros = fluxes['rainmelt'] * 0
        rainmelt = fluxes['rainmelt'] * fcover['flake']

        # Update cell average lake storage and check for negative values
        states['lakestor'] += (rainmelt + fluxes['qs'] - fluxes['levap'] - fluxes['lleak'])

        # Lake below zero --> reduce evaporation and leakage equally
        if states['lakestor'].min() < 0:
            states['lakestor'], corflx = utr.correct_neg_stor(states['lakestor'],
                    [fluxes['levap'], fluxes['lleak']])
            fluxes['levap'], fluxes['lleak'] = corflx


        # If no lake is prescribed, use unscaled lake depth
        celllake = np.where(fcover['flake'] > 0, fcover['flake'], 1)
        # Compute outflow based on storage retention time and surface state
        lake_depth = states['lakestor'] / celllake
        flowcoeff = 1.0 / (self.temporary['lag_land'] + 1.0)
        fluxes['qsl'] = np.ma.maximum(0, np.ma.minimum(states['lakestor'],
            lake_depth * flowcoeff * celllake))
        states['lakestor'] -= fluxes['qsl']

        # Check for negative lake storage
        if states['lakestor'].min() < 0:
            raise ValueError(
                "Negative lake storage after lake outflow computation: ",
                states['lakestor'].min())

        if 'log' in vars(self).keys():
            self.log.add_value(states['lakestor'], 'lakestor_new', 'Lake storage')
            self.log.add_value(fluxes['qsl'], 'qsl', 'Lake runoff')
            self.log.add_value(rainmelt, 'rainmelt_lake', 'Rainfall + Snowmelt on lake fraction')
            self.log.add_value(fluxes['levap'], 'levap', 'Lake evaporation')
            self.log.add_value(fluxes['lleak'], 'lleak', 'Lake leakage into groundwater')

    # ======================================================================================================
    # Hydrological processes for groundwater storage
    # ======================================================================================================

    def update_groundwater(self, fluxes, states):
        '''update groundwater storage ( == old baseflow storage)
           and compute outflow based on retention times
        '''
        if 'log' in vars(self).keys():
            self.log.add_value(states['groundwstor'], 'groundwstor_old', 'Groundwater storage from last time step')

        # Add drainage and lake leakage to groundwater storage
        states['groundwstor'] += (fluxes['qsb'] + fluxes['lleak'])
        zeros = fluxes['qsb'] * 0

        # Compute outflow based on storage retention time
        flowcoeff = 1.0 / (self.temporary['lag_base'] + 1.0)
        fluxes['qg'] = np.ma.maximum(0, np.ma.minimum(states['groundwstor'],
            states['groundwstor'] * flowcoeff))
        states['groundwstor'] -= fluxes['qg']

        # Check for negative groundwater storage
        if states['groundwstor'].min() < 0:
            raise ValueError(
                "Negative groundwater storage after update_groundwater: ",
                states.groundwstor.min())

        if 'log' in vars(self).keys():
            self.log.add_value(states['groundwstor'], 'groundwstor_new', 'Groundwater storage')
            self.log.add_value(fluxes['qg'], 'qg', 'Groundwater runoff')

    # ======================================================================================================
    # Hydrological processes for river routing
    # ======================================================================================================
    def update_river(self, fluxes, states):
        '''update river storage using a linear reservoir cascade
        with subtimesteps and routing

        The inflow accumulated during the previous time step is fed into a
        linear reservoir cascade over several subtimesteps; each
        subtimestep's outflow plus local runoff (lake outflow and
        groundwater flow) is routed to the downstream grid cell. Writes
        'freshwater', 'rivdis' and 'dis' to fluxes and updates 'riverstor'
        and 'infl_subtime' in states.
        '''
        # Project-internal module providing linear_cascade and river_routing
        import processes as prc

        # Conversion from per-cell water volume to a water column for logging
        # (assumes landarea in m2 and storage in m3, giving mm -- TODO confirm)
        conv2col = 1.0 / self.grid['landarea'] * 1000
        if 'log' in vars(self).keys():
            self.log.add_value(states['riverstor'].sum(axis=0) * conv2col, 'riverstor_old',
                               'River storage from last time step')

        # Fraction of the full-step inflow attributed to one subtimestep
        f_subfl = 1.0 / self.opt['rivsubtime']
        riv_in = np.zeros_like(fluxes['qsl'])       # accumulated upstream inflow
        riv_out = np.zeros_like(fluxes['qsl'])      # accumulated river discharge
        freshwater = np.zeros_like(fluxes['qsl'])   # accumulated sink/ocean inflow

        # Inflow per subtimestep, taken from the previous time step's routing
        act_infl = states['infl_subtime'] * f_subfl
        # Walk through linear cascade for all subtimesteps
        for sub in range(self.opt['rivsubtime']):
            # Compute inflow for actual subtimestep
            riv_in += act_infl
            # Compute storage outflow from inflow and states
            storage, outflow = prc.linear_cascade(act_infl,
                    states['riverstor'],
                    self.temporary['ncasc_river'],
                    self.temporary['lag_river'],
                    substeps=self.opt['rivsubtime'])
            # Add local runoff contributions (lake outflow qsl and groundwater
            # flow qg); the factor landarea * 0.001 presumably converts a water
            # column in mm to a volume in m3 -- TODO confirm units
            upstream = outflow + np.nan_to_num(
                    (fluxes['qsl'] + fluxes['qg']) * self.grid['landarea'] * 0.001
                    ) * f_subfl
            riv_out += upstream
            # Multiplication by 1 creates a copy, detaching the state from
            # the cascade output array
            states['riverstor'] = storage * 1
            # Rout river discharge to downstream grid cell
            if np.any(np.isnan(upstream)):
                # Replace NaNs before routing to keep the budget finite
                upstream = np.nan_to_num(upstream)
            act_infl, err = prc.river_routing(upstream, self.temporary['flowsinks'],
                                         self.temporary['riverflow'], len(self.temporary['riverflow']))
            if err > 1:
                raise ValueError("Routing error exceeds threshold: ",err)
            # Collect ocean inflow and substract from cell inflow
            # (flowsinks is used as a 0/1 mask, hence the 0.5 thresholds)
            freshwater += np.where(self.temporary['flowsinks'] > 0.5, act_infl, 0)
            act_infl = np.where(self.temporary['flowsinks'] < 0.5, act_infl, 0)

        # Time step correction for global water balance due to using
        # inflow from the last time step
        self.temporary['riv_ts_corr'] = act_infl - states['infl_subtime'] * f_subfl

        # Check for negative river storage
        if states['riverstor'].min() < 0:
            raise ValueError("Negative riverflow storage after update_river: ",
                             states['riverstor'].min())

        # Get discharge from all ocean inflow and internal sink cells
        fluxes['freshwater'] = freshwater
        fluxes['rivdis'] = riv_in
        fluxes['dis'] = riv_out
        # Rescale the remaining subtimestep inflow back to a full time step
        # for use in the next call
        states['infl_subtime'] = act_infl * self.opt['rivsubtime']

        if 'log' in vars(self).keys():
            self.log.add_value(states['riverstor'].sum(axis=0) * conv2col, 'riverstor_new',
                               'River storage')
            self.log.add_value(fluxes['rivdis'] * conv2col, 'riv_in', 'Upstream inflow')
            self.log.add_value(fluxes['dis'] * conv2col, 'riv_out', 'River discharge')
