import h5py
import pandas as pd
import numpy as np

from scipy.spatial import Delaunay
from scipy.interpolate import LinearNDInterpolator

def geometry(file_name):
    """Read the node coordinates of the first mesh in ``file_name``.

    The first key under the ``meshes`` group is used, and its ``geometry``
    dataset is read at index 0 of the leading axis (``geometry[0, :, :]``).
    The slice is read into memory inside the ``with`` block, so the returned
    array remains usable after the file is closed.

    NOTE(review): the leading axis is presumably a snapshot/time axis and the
    last axis holds x/y/z (``search_points`` compares columns 0-2) — confirm.
    """
    with h5py.File(file_name) as handle:
        meshes = handle['meshes']
        first_mesh = meshes[next(iter(meshes))]
        coordinates = first_mesh['geometry'][0, :, :]
    return coordinates


def search_points(file_name, points, tol=0.1):
    """Map each requested coordinate to the index of a matching mesh node.

    Parameters
    ----------
    file_name : str
        HDF5 result file whose node coordinates (``geometry(file_name)``) are
        searched.
    points : iterable of coordinate sequences
        Coordinates to locate. Only the first three components are compared,
        against columns 0, 1 and 2 of the geometry. Iterated exactly once, so
        one-shot iterators (e.g. generators) are fine.
    tol : float, optional
        Absolute tolerance per coordinate passed to ``np.isclose``. The default
        of 0.1 is the value that was previously hard-coded.

    Returns
    -------
    list of int
        One node index per input point: the first node whose three
        coordinates all lie within ``tol`` of the point.

    Raises
    ------
    ValueError
        If no node matches a point. Previously ``np.argmax`` silently returned
        0 in that case, which selected the wrong node.
    """
    g = geometry(file_name)
    coords = g[:, :3]
    indices = []
    for p in points:
        target = np.asarray(p, dtype=float)[:3]
        # Row-wise match: all three axes within tolerance. The argument order
        # (point, node) matches the original per-axis np.isclose calls.
        matches = np.all(np.isclose(target, coords, atol=tol), axis=1)
        if not matches.any():
            raise ValueError(
                f"no mesh node within tol={tol} of point {tuple(target)} "
                f"in {file_name!r}"
            )
        indices.append(int(np.argmax(matches)))
    return indices

def times(file_name, t_start, t_end):
    """Return the slice of time-step indices between ``t_start`` and ``t_end``.

    Both values are looked up by exact equality (``==``) in the file's
    ``times`` dataset. The result is a half-open Python slice: the step equal
    to ``t_start`` is included and the step equal to ``t_end`` is not.
    NOTE(review): use ``slice(n1, n2 + 1)`` if ``t_end`` should be included.

    Raises
    ------
    ValueError
        If either value is missing from ``times`` or occurs more than once.
        The same exception type was raised before, but with an opaque message
        from ``.item()``.
    """
    with h5py.File(file_name) as f:
        time = f['times'][:]

    bounds = []
    for label, value in (('t_start', t_start), ('t_end', t_end)):
        hits = np.flatnonzero(time == value)
        if hits.size != 1:
            raise ValueError(
                f"{label}={value!r} matches {hits.size} entries of 'times' "
                f"in {file_name!r}; expected exactly one"
            )
        bounds.append(hits.item())
    return slice(*bounds)

def property(mesh, name, points_mask, time_slice):
    """Return dataset ``name`` for the chosen time steps, one row per node.

    ``mesh[name]`` is sliced along its first axis with ``time_slice``. The
    result is then transposed so that rows follow the second axis, and
    ``points_mask`` (integer indices or a boolean mask) selects the rows.
    """
    return np.transpose(mesh[name][time_slice])[points_mask]


def property2(mesh, name, points_mask, time_slice, r):
    """Return component ``r`` of a 3-D dataset, one row per selected node.

    Reads ``mesh[name][time_slice, :, r]`` (time steps x nodes for a single
    last-axis component), transposes it to nodes x time steps, and keeps only
    the rows chosen by ``points_mask``.
    """
    return np.transpose(mesh[name][time_slice, :, r])[points_mask]


def experiment_from_hdf(file_name, points, time_slice, interests):
    """Load the selected node values of one HDF5 result file as a DataFrame.

    Parameters
    ----------
    file_name : str
        HDF5 result file; the first mesh under ``meshes`` is used.
    points : sequence of int or boolean mask
        Node selection (e.g. from ``search_points``) applied to every dataset.
    time_slice : slice
        Time-step slice (e.g. from ``times``), applied to ``times`` and to
        every dataset's first axis.
    interests : iterable of str
        Dataset names in the mesh. 2-D datasets give one table each. 3-D
        datasets give one table per last-axis component, keyed
        ``"<name>_<component>"``. Datasets of rank 0/1 are skipped.

    Returns
    -------
    pandas.DataFrame
        Rows indexed by (attribute key, position within the selection); the
        columns are the selected values of the file's ``times`` dataset.
        2-D attributes come first, then the 3-D components.

    Raises
    ------
    ValueError
        If a requested dataset has more than 3 dimensions. Previously this
        surfaced as an unclear "too many values to unpack".
    """
    # Iterated twice below; materialise so one-shot iterables work.
    interests = list(interests)
    with h5py.File(file_name) as f:
        mesh = f['meshes'][next(iter(f['meshes']))]
        # Named to avoid shadowing the module-level ``times`` function.
        time_values = f['times'][time_slice]

        point_scalars = {}
        for interest in interests:
            if len(mesh[interest].shape) == 2:
                df = pd.DataFrame(property(mesh, interest, points, time_slice))
                df.columns = time_values
                point_scalars[interest] = df

        for attribute_name in interests:
            shape = mesh[attribute_name].shape
            if len(shape) < 3:
                continue
            if len(shape) > 3:
                raise ValueError(
                    f"{attribute_name!r} has shape {shape}; only 2-D and "
                    "3-D datasets are supported"
                )
            for component in range(shape[2]):
                data = property2(mesh, attribute_name, points, time_slice, component)
                df = pd.DataFrame(data)
                df.columns = time_values
                point_scalars[f"{attribute_name}_{component}"] = df

    return pd.concat(point_scalars.values(), keys=point_scalars.keys())


def study(cases, case_names, points, t_start, t_end, interests):
    """Combine selected node results of several cases into one table.

    Parameters
    ----------
    cases : mapping
        Case key -> HDF5 file name. The keys label the columns. With several
        ``case_names`` levels they are presumably tuples — confirm with callers.
    case_names : sequence of str
        Names of the column levels formed from the case keys.
    points : mapping
        Point label -> coordinates. Only the coordinates (``points.values()``)
        are used, and they are located separately in every file with
        ``search_points``.
    t_start, t_end : float
        Time window, resolved once on the FIRST file with ``times``
        (half-open: ``t_end`` excluded). The same index slice is applied to
        every case. NOTE(review): this assumes all files share the same time
        steps, which is not checked here.
    interests : iterable of str
        Dataset names forwarded to ``experiment_from_hdf``.

    Returns
    -------
    pandas.DataFrame
        Index ``(attribute, point, time)``, where ``point`` is the position in
        ``points.values()`` order. There is one column per case.

    Raises
    ------
    ValueError
        If ``cases`` is empty (previously a bare ``StopIteration`` leaked out).
    """
    if not cases:
        raise ValueError("study() needs at least one case")

    first_file = next(iter(cases.values()))
    time_slice = times(first_file, t_start, t_end)
    interests = list(interests)  # reused for every case

    dfs = [
        experiment_from_hdf(file, search_points(file, points.values()), time_slice, interests)
        for file in cases.values()
    ]
    df = pd.concat(dfs, keys=cases.keys(), names=case_names, axis=1)

    # Column levels are (*case levels, time). Stacking the level just after
    # the case levels moves time into the index, leaving one column per case.
    df = df.stack(len(case_names))
    df.columns.names = case_names
    df.index.names = ('attribute', 'point', 'time')
    return df


def property3(mesh, name, tri, all_points, selected_points):
    """Linearly interpolate dataset ``name`` at arbitrary ``selected_points``.

    Parameters
    ----------
    mesh : mapping
        Container (h5py group or dict) with a 2-D ``mesh[name]``. It is read as
        ``[:, :]`` and transposed, so axis 1 must run over the nodes in
        ``all_points`` and axis 0 over time steps.
    name : str
        Dataset to interpolate.
    tri : scipy.spatial.Delaunay or None
        Precomputed triangulation of ``all_points``. When given, it is reused
        directly, so the mesh is not re-triangulated for every dataset. When
        ``None``, ``all_points`` is triangulated here.
    all_points : array-like, shape (n_nodes, ndim)
        Node coordinates.
    selected_points : array-like, shape (n_selected, ndim)
        Coordinates to interpolate at.

    Returns
    -------
    numpy.ndarray, shape (n_selected, n_times)
        Interpolated values; NaN for points outside the convex hull
        (``LinearNDInterpolator``'s default ``fill_value``).
    """
    values = np.transpose(mesh[name][:, :])
    # LinearNDInterpolator accepts a Delaunay object in place of raw
    # coordinates, which skips a full re-triangulation per call.
    source = all_points if tri is None else tri
    interp = LinearNDInterpolator(source, values)
    return interp(selected_points)


def experiment_interpolate(file_name, points_selected, interests, time_unit_seconds=31557600):
    """Interpolate the 2-D mesh datasets of one HDF5 file at given points.

    Parameters
    ----------
    file_name : str
        HDF5 result file; the first mesh under ``meshes`` is used, with node
        coordinates taken from ``geometry[0, :, :]``.
    points_selected : array-like, shape (n_selected, ndim)
        Coordinates to interpolate at.
    interests : iterable of str
        Dataset names; only 2-D datasets are interpolated, others are skipped.
    time_unit_seconds : float, optional
        Divisor applied to the file's ``times`` to build the column labels.
        The default 31557600 = 365.25 * 86400 converts seconds to Julian
        years, assuming ``times`` is stored in seconds (confirm).

    Returns
    -------
    pandas.DataFrame
        Rows indexed by (attribute, position in ``points_selected``); the
        columns are the scaled time values.
    """
    interests = list(interests)
    # ``with`` guarantees the file is closed; previously it was left open.
    with h5py.File(file_name) as f:
        mesh = f['meshes'][next(iter(f['meshes']))]
        node_coords = mesh['geometry'][0, :, :]
        # Triangulate once and share it across all interpolated datasets.
        tri = Delaunay(node_coords)
        time_values = f['times'][:] / time_unit_seconds

        point_scalars = {}
        for interest in interests:
            if len(mesh[interest].shape) == 2:
                df = pd.DataFrame(
                    property3(mesh, interest, tri, node_coords, points_selected)
                )
                df.columns = time_values
                point_scalars[interest] = df

    return pd.concat(point_scalars.values(), keys=point_scalars.keys())



def study_interpolate(cases, case_names, selected_points, interests):
    """Combine interpolated results of several cases into one table.

    Parameters
    ----------
    cases : mapping
        Case key -> HDF5 file name; the keys label the columns.
    case_names : sequence of str
        Names of the column levels formed from the case keys.
    selected_points : array-like, shape (n_selected, ndim)
        Coordinates interpolated in every case via ``experiment_interpolate``.
    interests : iterable of str
        Dataset names forwarded to ``experiment_interpolate``.

    Returns
    -------
    pandas.DataFrame
        Index ``(attribute, point, time)``; one column per case.
    """
    interests = list(interests)  # reused for every case
    dfs = [experiment_interpolate(file, selected_points, interests) for file in cases.values()]
    df = pd.concat(dfs, keys=cases.keys(), names=case_names, axis=1)

    # Column levels are (*case levels, time). Stack the time level, whose
    # position is len(case_names). The previous hard-coded 2 only worked
    # for exactly two case levels.
    df = df.stack(len(case_names))
    df.columns.names = case_names
    df.index.names = ('attribute', 'point', 'time')
    return df
