"""
Author: Simon Kunze

Script for numerically solving the Smoluchowski equation
for free molecular flow in straight channels.

This uses the shapely library for easy handling of
arbitrary cross sections, see https://pypi.org/project/Shapely/

"""

import numpy as np
import matplotlib.pyplot as plt

from shapely.geometry import LineString, Polygon, Point, MultiPoint


def _sample_area(x_min, x_max, y_min, y_max, res_x, res_y, area):
    """
    Samples the cross-sectional area. This is dS in Smoluchowski's integral.

    A regular grid of res_x * res_y points is laid over (a slightly shrunk
    version of) the bounding box, and only grid points strictly inside the
    polygon are kept. The resolution is therefore given for the bounding
    box; the number of returned points is smaller when the polygon does
    not fill its bounding box.

    Parameters
    ----------
    x_min, x_max, y_min, y_max : float
        Extent of the bounding box.
    res_x, res_y : int
        Number of grid points along the x and y coordinate.
    area : shapely.geometry.Polygon
        The area of the channel cross section.

    Returns
    -------
    area_list : list of 2D numpy arrays
        A list with the coordinates of each sample point.

    """
    from shapely.prepared import prep

    # Shrink the grid by 1 % of the box size on each side so that no
    # sample point lands exactly on the edge of the bounding box.
    eps_x = (x_max - x_min) * 0.01
    eps_y = (y_max - y_min) * 0.01
    # np.linspace requires an integer sample count.
    mesh_x = np.linspace(x_min + eps_x, x_max - eps_x, int(res_x))
    mesh_y = np.linspace(y_min + eps_y, y_max - eps_y, int(res_y))
    mx, my = np.meshgrid(mesh_x, mesh_y)
    # A prepared geometry answers many repeated point-in-polygon queries
    # much faster. ``contains`` is true only for points in the polygon's
    # interior, so points lying exactly on any boundary ring (the exterior
    # or the rim of a hole) are rejected.
    prepared = prep(area)
    return [np.array([px, py])
            for px, py in zip(mx.ravel(), my.ravel())
            if prepared.contains(Point(px, py))]

def _sample_boundary(origin, boundary, surf_points=100):
    """
    Samples equi-angled points on the boundary as seen from the origin.
    This is dS' in Smoluchowski's expression.

    Rays are cast from ``origin`` in ``surf_points`` directions spaced
    2*pi/surf_points apart. For each ray, the part of its intersection with
    the boundary closest to the origin is kept, i.e. the boundary point
    visible from the origin in that direction (this matters for concave
    cross sections, where a ray can cross the boundary several times).

    Parameters
    ----------
    origin : 2D numpy array
        The position from which the boundary is to be sampled.
    boundary : shapely.geometry.Polygon.exterior
        The boundary of the channel cross section.
    surf_points : int, optional
        The number of boundary points to sample. The default is 100.

    Returns
    -------
    bound_list : list of 2D numpy arrays
        A list with the coordinates of each sample point.

    Raises
    ------
    ValueError
        If a ray does not hit the boundary, e.g. because ``origin`` does
        not lie inside the cross section.

    """
    from shapely.ops import nearest_points

    x_min, y_min, x_max, y_max = boundary.bounds
    # A ray as long as the bounding-box diagonal, started anywhere inside
    # the bounding box, always reaches outside of it.
    diag = np.sqrt((x_max - x_min)**2 + (y_max - y_min)**2)
    # endpoint=False: the angles 0 and 2*pi are the same direction.
    # Including both (np.linspace's default) samples that direction twice
    # and makes the real angular step 2*pi/(surf_points-1), whereas the
    # integration weights every sample with d_eps = 2*pi/len(bound_list).
    angles = np.linspace(0, 2*np.pi, surf_points, endpoint=False)

    origin = np.asarray(origin, dtype=float)
    origin_point = Point(origin[0], origin[1])
    bound_list = []
    for angle in angles:
        direction = np.array([np.cos(angle), np.sin(angle)])
        line = LineString((origin, origin + direction * diag))
        inters = line.intersection(boundary)
        if inters.is_empty:
            raise ValueError(f"No boundary point found from origin {origin}; "
                             "the origin must lie inside the cross section.")
        # The intersection may be a Point, a MultiPoint (concave geometry:
        # several crossings), or a LineString/GeometryCollection (ray running
        # along a straight edge). In every case the visible boundary point is
        # the part of the intersection nearest to the origin. nearest_points
        # also avoids iterating/indexing MultiPoint directly, which Shapely 2
        # no longer supports.
        closest = nearest_points(origin_point, inters)[1]
        bound_list.append(np.array([closest.x, closest.y]))
    return bound_list

def _get_int_R_d_eps(area_point, bound_list):
    """
    Evaluate the inner integral over R d_eps for a single area sample point.

    The boundary samples are treated as equi-angled as seen from
    ``area_point``, so the angular step d_eps is 2*pi divided by the number
    of boundary samples. The integral then reduces to d_eps times the sum
    of the distances from ``area_point`` to every boundary sample, i.e. it
    is proportional to their mean distance.

    Parameters
    ----------
    area_point : 2D numpy array
        Area sampling point.
    bound_list : list of 2D numpy arrays
        List of boundary sampling points.

    Returns
    -------
    int_R_d_eps : float
        Result of the integral for the area around area_point.

    """
    n_bound = len(bound_list)
    # angular spacing between two neighbouring boundary samples
    angle_step = 2*np.pi / n_bound
    # Euclidean distance R from the area point to each boundary sample
    distances = np.linalg.norm(area_point - bound_list, axis=1)
    return distances.sum() * angle_step

def get_smol_int(area, mesh_points=500, surf_points=100, plot=False):
    """
    Calculate the geometry-dependent part of the Smoluchowski integral.
    This is A in the Smoluchowski paper and Lambda in the manuscript.

    The cross section is sampled on a regular grid over its bounding box.
    For every grid point inside the polygon, the boundary is sampled in
    ``surf_points`` equi-angled directions and the inner integral over
    R d_eps is evaluated. The outer integral over dS then weights every
    area sample equally with ``area.area / n_samples``.

    Parameters
    ----------
    area : shapely.geometry.Polygon
        The channel cross section. Only its exterior ring is used as the
        boundary.
    mesh_points : int, optional
        Approximate number of grid points over the bounding box of the
        cross section; fewer points are used when the polygon does not
        fill its bounding box. The default is 500.
    surf_points : int, optional
        The number of boundary points to sample per area point.
        The default is 100.
    plot : bool, optional
        Plot graphical results. The default is False.

    Returns
    -------
    float
        Result of the geometry-dependent part of the integral.

    Raises
    ------
    ValueError
        If the cross section has no positive area, or if no grid point
        falls inside it.

    """
    if area.is_empty or area.area <= 0:
        raise ValueError("The cross section must have a positive area.")

    boundary = area.exterior
    x_min, y_min, x_max, y_max = area.bounds
    ratio = (x_max - x_min) / (y_max - y_min)
    # Split mesh_points into a grid matching the bounding box's aspect
    # ratio: res_x * res_y ~ mesh_points and res_x / res_y ~ ratio. Keep
    # at least one grid line per axis so very elongated cross sections or
    # a small mesh_points cannot cause a division by zero.
    res_y = max(1, int(np.sqrt(mesh_points / ratio)))
    res_x = max(1, int(mesh_points / res_y))

    # sample the area
    area_list = _sample_area(x_min, x_max, y_min, y_max,
                             res_x, res_y, area)
    if not area_list:
        raise ValueError("No area sample point lies inside the cross "
                         "section; increase mesh_points.")

    # sample the boundary for each area point
    area_bound_list = [_sample_boundary(area_point, boundary,
                                        surf_points=surf_points)
                       for area_point in area_list]

    # inner (d_eps) integral for each area point
    int_R_d_eps_list = [_get_int_R_d_eps(area_point, bound_list)
                        for area_point, bound_list
                        in zip(area_list, area_bound_list)]

    # outer (dS) integral: each area sample represents an equal share
    dS = area.area / len(area_list)
    int_dS = np.sum(int_R_d_eps_list) * dS

    if plot:
        _plot_smol_int(boundary, ratio, area_list, area_bound_list,
                       int_R_d_eps_list)

    return int_dS


def _plot_smol_int(boundary, ratio, area_list, area_bound_list,
                   int_R_d_eps_list):
    """
    Draw diagnostic figures for get_smol_int.

    Three figures show, for randomly chosen area points, the boundary
    samples shaded by their distance to that point. One figure shows every
    area point shaded by its inner-integral value. Grey levels are given as
    ``1 - normalised value``, so darker means larger.

    Parameters
    ----------
    boundary : shapely.geometry.Polygon.exterior
        Boundary of the cross section, drawn for reference.
    ratio : float
        Width / height of the bounding box, used for the figure aspect.
    area_list : list of 2D numpy arrays
        Area sample points.
    area_bound_list : list of lists of 2D numpy arrays
        Boundary samples for each area point.
    int_R_d_eps_list : list of float
        Inner-integral value for each area point.

    """
    figx = 9
    figy = 9 / ratio

    # show the varying influence of the boundary mesh
    # for different cross sectional points
    for ii in np.random.randint(0, high=len(area_list), size=3):
        _, ax = plt.subplots(figsize=(figx, figy))
        ax.scatter(area_list[ii][0], area_list[ii][1], color="k")
        norm_R_list = np.linalg.norm(area_list[ii]
                                     - area_bound_list[ii], axis=1)
        max_R = np.max(norm_R_list)
        for bound, norm_R in zip(area_bound_list[ii], norm_R_list):
            ax.scatter(bound[0], bound[1], color=f"{1-norm_R/max_R}")
        ax.plot(*boundary.xy, color='k', alpha=0.5)

    # show the influence of the cross sectional points on the dS integral
    c_max = np.max(int_R_d_eps_list)
    c_min = np.min(int_R_d_eps_list)
    # Avoid 0/0 (a NaN grey level) when all values are equal, e.g. when
    # only a single area point was sampled.
    c_span = c_max - c_min if c_max > c_min else 1.0
    _, ax = plt.subplots(figsize=(figx, figy))
    for area_point, int_R_d_eps in zip(area_list, int_R_d_eps_list):
        ax.scatter(area_point[0], area_point[1],
                   color=f"{1-(int_R_d_eps-c_min)/c_span}")
    ax.plot(*boundary.xy, color='k', alpha=0.5)

def analytical_smol_int_circle(radius):
    """
    Analytical Smoluchowski integral (Lambda) for a circular cross section.

    Parameters
    ----------
    radius : float
        Radius of the circle.

    Returns
    -------
    float
        16 * pi * radius**3 / 3.

    """
    return 16 * radius**3 * np.pi / 3


def analytical_smol_int_rect(h, w):
    """
    Analytical Smoluchowski integral (Lambda) for a rectangular cross section.

    Parameters
    ----------
    h, w : float
        Height and width of the rectangle. The expression is symmetric in
        h and w.

    Returns
    -------
    float
        Value of the closed-form expression for the rectangle.

    """
    return 2*(h**2*w*np.log(w/h + np.sqrt(1 + (w/h)**2))
              + h*w**2*np.log(h/w + np.sqrt(1 + (h/w)**2))
              - (h**2 + w**2)**(3/2) / 3
              + (h**3 + w**3) / 3)


if __name__ == "__main__":

    # Benchmark for circular and rectangular cross sections, comparing the
    # numerical result with the known analytical solution.

    # circular:
    # radius = 0.909 gives the same hydraulic diameter (4*A/P = 2*radius
    # ~ 1.818) as a rectangle with h=1 and w=10 (4*10/22 ~ 1.818).
    radius = 0.909
    area = Point([radius, radius]).buffer(radius, resolution=40)

    int_dS_circ = get_smol_int(area, plot=False)
    smolu_circ = analytical_smol_int_circle(radius)

    print(f"error for circular: {(int_dS_circ/smolu_circ - 1)*100:.2f} %")
    print(f"Lambda for circ: {int_dS_circ}")

    # rectangular:
    h = 1
    w = 10
    area = Polygon(shell=([0, 0], [0, h], [w, h], [w, 0]))

    int_dS_rect = get_smol_int(area, plot=False)
    smolu_rect = analytical_smol_int_rect(h, w)

    print(f"error for rectangular: {(int_dS_rect/smolu_rect - 1)*100:.2f} %")
    print(f"Lambda for rect: {int_dS_rect}")