import os
import glob
import numpy as np
from netCDF4 import Dataset
import matplotlib.pyplot as plt
from uncertainties import unumpy as unp
from uncertainties import ufloat as ufl


def print_netcdf_structure(ds):
    """Print a netCDF dataset summary followed by its variables, group by group.

    Each variable line has the form ``. name [unit]`` or, when the variable
    carries a ``comments`` attribute, ``. name [unit], comments``. Groups and
    variables are listed in sorted name order.

    Parameters
    ----------
    ds : netCDF4.Dataset
        Open dataset. Only its top-level groups are walked; nested
        sub-groups are not visited.
    """
    print('----> General structure:')
    print(ds)
    print('----> Groups and variables:')
    for group_name in sorted(ds.groups.keys()):
        print('--> {}:'.format(group_name))
        group = ds.groups[group_name]
        for var_name in sorted(group.variables.keys()):
            variable = group.variables[var_name]
            # A variable without a 'unit' attribute is shown as '?' rather
            # than aborting the whole listing with AttributeError; 'comments'
            # is likewise treated as optional.
            unit = getattr(variable, 'unit', '?')
            if hasattr(variable, 'comments'):
                print('. {} [{}], {}'.format(var_name, unit, variable.comments))
            else:
                print('. {} [{}]'.format(var_name, unit))


def load_var(var):
    """Read a netCDF variable, attaching its uncertainty when one is stored.

    Parameters
    ----------
    var : netCDF4.Variable
        Variable to read. If it carries a ``std`` attribute, that value is
        used as the standard deviation of the returned quantity.

    Returns
    -------
    uncertainties ufloat
        When ``var`` has ``std`` and holds exactly one value (any shape).
    numpy.ndarray of uncertainties objects
        When ``var`` has ``std`` and holds several values.
    array-like
        ``var[:]`` unchanged when ``var`` has no ``std`` attribute.
    """
    # Read from the file only once; every branch below reuses this result.
    values = var[:]
    if not hasattr(var, 'std'):
        return values
    # np.ma.getdata returns the underlying values for masked arrays and also
    # accepts plain ndarrays / numpy scalars, whereas `.data` on a plain
    # ndarray is a memoryview without `.size`.
    data = np.ma.getdata(values)
    if data.size == 1:
        # .item() extracts the scalar for any size-1 shape such as () or (1,);
        # data[()] would return the whole array for shape (1,).
        return ufl(data.item(), var.std)
    return unp.uarray(data, var.std)


# Listing runs
list_run_paths = sorted(glob.glob(os.path.join('runs', '*.nc')))
if not list_run_paths:
    # Fail with an explicit message instead of an IndexError further down.
    raise SystemExit('No run files found matching '
                     + os.path.join('runs', '*.nc'))

# ##### Printing netcdf file structure
# Only the last run (in sorted order) is inspected; presumably all runs share
# the same layout - TODO confirm.
# The context manager closes the file once the structure has been printed.
with Dataset(list_run_paths[-1]) as ds:
    print_netcdf_structure(ds)

# #### Loading data
# Variables to extract from each run; each one is searched for in every group.
wanted_keys = ['Bottom slope', 'Geometrical Froude number', 'h0 (lock height)']
print('loading datasets ...')
DATA = {}
for run_path in list_run_paths:
    # Run name: file name without its directory and final '.nc' extension.
    run = os.path.splitext(os.path.basename(run_path))[0]
    # The context manager closes each file once its values are in memory;
    # load_var reads the data eagerly, so nothing refers to the open file.
    with Dataset(run_path) as ds:
        # Keep silica-sand runs only; files without the attribute are skipped.
        if getattr(ds, 'particle_type', None) != 'silica sand':
            continue
        DATA[run] = {'set-up': ds.set_up}
        for key in wanted_keys:
            for grp in sorted(ds.groups.keys()):
                variables = ds.groups[grp].variables
                if key in variables:
                    # If several groups define the same key, the last group
                    # in sorted order wins.
                    DATA[run][key] = load_var(variables[key])

# ### Reproducing curve Fr = f(\theta)
# organizing data
RUNS = sorted(DATA.keys())

# Group runs by (nominal bottom slope, set-up) in a single pass. Runs keep
# their RUNS order inside each group.
RUNS_BY_SLOPE_SETUP = {}
for run in RUNS:
    slope_setup = (DATA[run]['Bottom slope'].n, DATA[run]['set-up'])
    RUNS_BY_SLOPE_SETUP.setdefault(slope_setup, []).append(run)
possible_slopes_setups = sorted(RUNS_BY_SLOPE_SETUP)

# One averaged value per (slope, set-up) group, in possible_slopes_setups order.
FROUDES = np.array([
    np.array([DATA[run]['Geometrical Froude number']
              for run in RUNS_BY_SLOPE_SETUP[slope_setup]]).mean()
    for slope_setup in possible_slopes_setups
])
SLOPES = np.array([
    np.array([DATA[run]['Bottom slope']
              for run in RUNS_BY_SLOPE_SETUP[slope_setup]]).mean()
    for slope_setup in possible_slopes_setups
])
SETUPS = np.array([slope_setup[1] for slope_setup in possible_slopes_setups])

# plotting figure
fig, ax = plt.subplots(1, 1, constrained_layout=True)

x = unp.nominal_values(SLOPES)
xerr = unp.std_devs(SLOPES)
y = unp.nominal_values(FROUDES)
yerr = unp.std_devs(FROUDES)

# sorted() gives a deterministic drawing/legend order; iteration order of a
# set of strings depends on hash randomization and varies between runs.
for setup in sorted(set(SETUPS)):
    mask = (SETUPS == setup)
    ax.errorbar(x[mask], y[mask], xerr=xerr[mask], yerr=yerr[mask],
                fmt='.' if setup == 'set-up 1' else 's',
                mfc=None if setup == 'set-up 1' else 'white',
                label=setup)
# '^\circ' renders a superscript degree sign; plain '\circ' is a baseline ring.
ax.set_xlabel(r'Bottom slope, $\theta$ [$^\circ$]')
ax.set_ylabel(r'Froude number, $\langle \mathcal{F}r\rangle$')
ax.legend()
plt.show()