import imp
import numpy as np
import pandas as pd
import theano
import warnings
from sympy.printing.theanocode import theano_function
from essm.equations import Equation
from essm.variables import Variable
from essm.variables.units import (joule, kelvin, kilogram, meter, pascal,
                                  second, watt, mol, markdown)
from essm.variables.utils import extract_variables

# Reading absolute path of this file for importing local modules
import sys
from pathlib import Path # if you haven't already done so
# Absolute path of this file. `parent` is its directory and is used below to
# locate the two definition modules; `root` (one level up) is not referenced
# in the visible part of this file.
file = Path(__file__).resolve()
parent, root = file.parent, file.parents[1]

# Importing variable and equation definitions
# Each definition module is loaded directly from its file path in `parent`,
# with warnings silenced while it loads. Its public names -- `__all__` if the
# module defines one, otherwise every attribute not starting with '_' -- are
# then copied into this module's globals. That is presumably how names used
# below but not defined in this file (c_a, V_c, eq_sE_c_a, ...) come into
# scope; the default argument of fun_compfluxes needs them at definition time.
# NOTE(review): without `__all__`, dir(mod) also contains names the loaded
# module imported itself, and these silently rebind same-named globals here.
# The equations module is loaded second, so it wins on any name clash.
# NOTE(review): the `imp` module is deprecated (removed in Python 3.12);
# importlib.util.spec_from_file_location is the replacement.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    mod = imp.load_source('Mathematical_model_variable_definitions', 
               str(parent)+'/Mathematical_model_variable_definitions.py')
names = getattr(mod, '__all__', [n for n in dir(mod) if not n.startswith('_')])
globs = globals()
for name in names:
    globs[name] = getattr(mod, name)
    
# Same procedure for the equation definitions.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    mod = imp.load_source('Mathematical_model_equations_definitions', 
              str(parent)+'/Mathematical_model_equations_definitions.py')
names = getattr(mod, '__all__', [n for n in dir(mod) if not n.startswith('_')])
globs = globals()
for name in names:
    globs[name] = getattr(mod, name)

def calc_slope(x):
    """
    Return the linear-regression slope of a time-indexed Series, per second.

    The index is converted to seconds elapsed since its first entry, and a
    straight line is fitted with numpy.polyfit. Any failure (e.g. empty
    input, an index that is not datetime-like, an input without an index,
    or a fit that does not converge) yields 0 instead of an exception.

    Args:
        x (Pandas Series): Pandas Series with time index

    Returns:
        float: Slope of linear regression through x in units of x per second,
        or 0. if it could not be computed

    Examples:
        >>> slope = calc_slope(df['CO2_s'][0:10])
        >>> df_copy['dCO2_s'] = df_copy['CO2_s'].rolling('1h').apply(calc_slope)
    """
    try:
        seconds = (x.index - x.index[0]) / np.timedelta64(1, 's')
        coefficients = np.polyfit(seconds, x, 1)
    except Exception:
        # Deliberate best-effort fallback, as documented above.
        return 0.
    # polyfit orders coefficients from highest degree: [slope, intercept].
    return coefficients[0]

def calc_delta(x):
    """
    Return the last value of x minus its first value.

    If computation fails (e.g. x is empty), return nan. Values are taken by
    position, so this works for Series with any index (datetime, integer,
    float) as well as for plain arrays (e.g. from ``rolling(...,
    raw=True)``). Note that the pandas rolling() function in the example
    below uses backwards windows, i.e. the difference in values compared to
    the time interval before the current point.

    Args:
        x (Pandas Series or array-like): values within one window

    Returns:
        float: last value - first value, or nan if it cannot be computed

    Examples:
        >>> delta = calc_delta(df['CO2_s'][0:10])
        >>> df_copy['dCO2_s'] = df_copy['CO2_s'].rolling('1h').apply(calc_delta)
    """
    # On a Series, ``x[-1]`` is a *label* lookup whenever the index holds
    # integers (e.g. the RangeIndex slices passed by rolling().apply), which
    # raises KeyError and produced nan for every window. ``.iloc`` is always
    # positional; arrays/lists have no ``.iloc`` and are indexed directly.
    values = getattr(x, 'iloc', x)
    try:
        delta = values[-1] - values[0]
    except Exception:
        # Deliberate best-effort: documented to return nan on failure.
        delta = np.nan
    return delta
    
def _rolling_range_ok(series, window, maxrange):
    """Return a boolean Series: True where |(max - min) / mean| of `series`
    over the trailing rolling `window` is at most `maxrange`."""
    rolled = series.rolling(window)
    return abs((rolled.max() - rolled.min()) / rolled.mean()) <= maxrange

def _set_flag_where_all(df, column, checks):
    """Set df[column] to 1 on rows where every boolean Series in `checks`
    (a mapping name -> Series aligned with df) is True; no checks -> all rows."""
    # Build on df's own index so an empty set of checks still aligns with df
    # (an index-less empty frame would be an unalignable boolean indexer).
    valid = pd.DataFrame(index=df.index)
    for name, check in checks.items():
        valid[name] = check
    df.loc[valid.all(axis='columns'), column] = 1

def add_validity_cols(df1, mw='30s', maxrange_H=0.01, maxrange_C=0.001, 
                      maxrange_Hr=0.01, maxrange_Cr=0.001, minflow=150., 
                      vars_min=('Flow_r', 'Flow_s'), matchboth=False, 
                      mwmatch='10s',
                      vars_C_stability=('CO2_s',), vars_Cr_stability=('CO2_r',),
                      vars_H_stability=('H2O_s',), vars_Hr_stability=('H2O_r',),
                      vars_C_match=('MatchCO2',), vars_H_match=('MatchH2O',)):
    """
    Check minimum flow requirements, stability and matching criteria.

    Adds the columns 'Flow_validity', 'C_validity', 'H_validity',
    'MatchH2O_validity' and 'MatchCO2_validity' (each inserted at position 1,
    so they end up in reverse order right after the first column), holding 1
    for data considered valid and 0 for invalid data. df1 is not modified.

    Args:
        df1 (DataFrame): Pandas dataframe with imported and calculated LI-data
        mw (int or time interval as str): Trailing moving-window size for the
            stability calculation
        maxrange_H (float): Permitted maximum relative variation
            |(max - min) / mean| within the window for vars_H_stability
        maxrange_C (float): Same, for vars_C_stability
        maxrange_Hr (float): Same, for vars_Hr_stability
        maxrange_Cr (float): Same, for vars_Cr_stability
        minflow (float): Permitted minimum flow through IRGA
        vars_min (sequence): Variables for which the minflow condition is applied
        matchboth (bool): If True, the match window is centred
            (rolling center=True) so that points before a match are excluded
            as well as points after it; requires an integer mwmatch
        mwmatch (int or time interval as str): Window for the match check; a
            row is valid only where the last and first value of each match
            variable within the window are equal
        vars_C_stability (sequence): C-variables checked against maxrange_C
        vars_H_stability (sequence): H2O-variables checked against maxrange_H
        vars_Cr_stability (sequence): C-variables checked against maxrange_Cr
        vars_Hr_stability (sequence): H2O-variables checked against maxrange_Hr
        vars_C_match (sequence): C-match variables to check for changing values
        vars_H_match (sequence): H-match variables to check for changing values

    Returns:
        DataFrame: copy of df1 with the columns 'Flow_validity', 'C_validity',
        'H_validity', 'MatchH2O_validity' and 'MatchCO2_validity'. A flag is 1
        where all of its checks pass; an empty variable list passes all rows.

    Raises:
        ValueError: if matchboth is True and mwmatch is not an integer.
    """
    # Validate up front instead of returning None after partial work.
    if matchboth and not isinstance(mwmatch, (int, np.integer)):
        raise ValueError('If matchboth=True, mwmatch must be an integer.')

    df = df1.copy()
    # All flags start at 0 (invalid) and are set to 1 where checks pass.
    for column in ('Flow_validity', 'C_validity', 'H_validity',
                   'MatchH2O_validity', 'MatchCO2_validity'):
        df.insert(1, column, 0)

    # Air flow through the IRGAs must be sufficiently large.
    flow_checks = {var1: df[var1] >= minflow for var1 in vars_min}
    _set_flag_where_all(df, 'Flow_validity', flow_checks)

    # Range of water variables must be sufficiently small. A variable listed
    # in both lists is judged by the later (reference) threshold.
    h_checks = {var1: _rolling_range_ok(df[var1], mw, maxrange_H)
                for var1 in vars_H_stability}
    h_checks.update({var1: _rolling_range_ok(df[var1], mw, maxrange_Hr)
                     for var1 in vars_Hr_stability})
    _set_flag_where_all(df, 'H_validity', h_checks)

    # Range of CO2 variables must be sufficiently small.
    c_checks = {var1: _rolling_range_ok(df[var1], mw, maxrange_C)
                for var1 in vars_C_stability}
    c_checks.update({var1: _rolling_range_ok(df[var1], mw, maxrange_Cr)
                     for var1 in vars_Cr_stability})
    _set_flag_where_all(df, 'C_validity', c_checks)

    # Matching events: match variables must be unchanged across the window.
    def _unchanged(var1):
        return df[var1].rolling(mwmatch, center=matchboth)\
                       .apply(calc_delta) == 0.0

    _set_flag_where_all(df, 'MatchCO2_validity',
                        {var1: _unchanged(var1) for var1 in vars_C_match})
    _set_flag_where_all(df, 'MatchH2O_validity',
                        {var1: _unchanged(var1) for var1 in vars_H_match})

    return df

def fun_compfluxes(df, df_units, eqs=((eq_sE_c_a, 'sE'), (eq_sA_c_a, 'sA'), 
     (eq_sE_dt1, 'sE_dt'), (eq_sA_dt1, 'sA_dt'),
     (eq_sA_avg1, 'sA_avg'), (eq_sE_avg1, 'sE_avg')),
                   volume=0.03*0.03*0.02, window=2, vdict=None):
    """
    Compute fluxes based on eqs and return new df and units with the results.

    The right-hand side of each equation has the values in ``vdict`` (plus
    ``V_c = volume``) substituted, is compiled with sympy's theano_function
    and is evaluated on the columns of ``df``. If non-steady-state equations
    are provided in eqs, the rates of change, differences and means of the
    chamber variables are computed over rolling windows of size ``window``.

    Args:
        df (DataFrame): measurement data with columns 'Flow' (umol/s),
            'H2O_s' and 'H2O_r' (mmol/mol), 'CO2_s' and 'CO2_r' (umol/mol),
            'Tair' (deg C) and 'Pa' (kPa); 'time1' is also required when an
            equation depends on Delta_t.
        df_units: dict-like container of units; it is copied and the
            markdown unit of each equation's left-hand side is stored under
            that equation's name.
        eqs (sequence of (Equation, str)): equations to evaluate and the
            column name under which each result is stored.
        volume (float): chamber air volume in m ** 3, substituted for V_c.
        window (int or str): rolling window size, e.g. 2 (rows) or '30s'.
        vdict (dict, optional): substitutions applied before evaluation;
            defaults to ``Variable.__defaults__``. It is not modified.

    Returns:
        tuple: (copy of df with one new column per equation,
                copy of df_units with the unit of each new column)

    Raises:
        KeyError: if an equation depends on a variable for which no data is
            derived in this function.
    """
    df_copy = df.copy()
    df_copy_units = df_units.copy()
    # Copy so that adding V_c never modifies a caller-supplied dict.
    vdict = dict(Variable.__defaults__ if vdict is None else vdict)
    vdict[V_c] = volume

    # Substitute known values and collect the free variables of all eqs.
    # A single ordered list is used both to compile the function and to
    # build its argument list, so inputs are guaranteed to line up.
    eqs1 = [eq1.rhs.subs(vdict) for eq1, _ in eqs]
    vars1 = sorted(set().union(*(extract_variables(expr) for expr in eqs1)),
                   key=str)

    # Type conversions
    for column in ('Flow', 'H2O_r', 'H2O_s', 'CO2_r', 'CO2_s'):
        df_copy[column] = df_copy[column].astype(float)

    # Instantaneous variables, converted to SI units.
    varsdict = {
        c_a: df_copy['CO2_s'] * 1e-6,      # umol/mol -> mol/mol
        c_o: df_copy['CO2_r'] * 1e-6,      # umol/mol -> mol/mol
        mu_o: df_copy['Flow'] * 1e-6,      # umol/s -> mol/s
        P_a: df_copy['Pa'] * 1000,         # kPa -> Pa
        T_a: df_copy['Tair'] + 273.15,     # deg C -> K
        w_a: df_copy['H2O_s'] * 1e-3,      # mmol/mol -> mol/mol
        w_o: df_copy['H2O_r'] * 1e-3}      # mmol/mol -> mol/mol

    # Rolling-window variables, computed only if some equation needs them:
    # variable -> (source column, per-window reducer, scale factor to SI).
    rolling_vars = {
        Delta_t: ('time1', calc_delta, 1),
        dca_dt: ('CO2_s', calc_slope, 1e-6),     # mol/mol per s
        dwa_dt: ('H2O_s', calc_slope, 1e-3),     # mol/mol per s
        c_abar: ('CO2_s', 'mean', 1e-6),         # mol/mol
        c_obar: ('CO2_r', 'mean', 1e-6),         # mol/mol
        w_abar: ('H2O_s', 'mean', 1e-3),         # mol/mol
        w_obar: ('H2O_r', 'mean', 1e-3),         # mol/mol
        mu_obar: ('Flow', 'mean', 1e-6),         # umol/s -> mol/s
        Delta_ca: ('CO2_s', calc_delta, 1e-6),   # mol/mol
        Delta_wa: ('H2O_s', calc_delta, 1e-3)}   # mol/mol
    for var1, (column, reducer, scale) in rolling_vars.items():
        if var1 not in vars1:
            continue
        rolled = df_copy[column].rolling(window)
        values = rolled.mean() if reducer == 'mean' else rolled.apply(reducer)
        varsdict[var1] = values * scale

    missing = [var1 for var1 in vars1 if var1 not in varsdict]
    if missing:
        raise KeyError('No input data available for variables: {}'
                       .format(missing))

    # Define theano function for all equations with all vars and run it.
    f1 = theano_function(vars1, eqs1, dim=1)
    resf1 = f1(*[varsdict[var1] for var1 in vars1])
    # With a single output the compiled function may return the array itself
    # rather than a list; wrap it so results can be paired with eqs below.
    if len(eqs1) == 1 and not isinstance(resf1, (list, tuple)):
        resf1 = [resf1]

    # Append results to df and units
    for (eq1, key), values in zip(eqs, resf1):
        df_copy[key] = values
        df_copy_units[key] = markdown(eq1.lhs.definition.unit)
    return df_copy, df_copy_units
