import numpy as np
import cupy as cp

def make_cl(N):
    """Return the N + 1 Chebyshev-Lobatto nodes mapped from [-1, 1] onto [0, 1].

    Node k is (1 - cos(pi * k / N)) / 2 for k = 0, ..., N, so the nodes run
    from 0 to 1 in ascending order and cluster towards both ends.
    """
    k = np.asarray(range(N + 1))
    angles = k * np.pi / N
    return (1 - np.cos(angles)) / 2

def batched_lagrange(xs, alpha, interpolating_points, verbose=False):
    """Evaluate the ``alpha``-th Lagrange basis polynomial at every point of ``xs``.

    L_alpha(x) = prod_{j != alpha} (x - p_j) / (p_alpha - p_j)

    Args:
        xs: evaluation point(s); scalar or array of any shape.
        alpha: index of the basis node within ``interpolating_points``.
        interpolating_points: the interpolation nodes p_j (assumed distinct).
        verbose: unused; kept for backward compatibility.

    Returns:
        Array of basis values shaped like ``xs`` (shape ``(1,)`` for a scalar).
        The value is exactly 1 at ``p_alpha`` and 0 at the other nodes.
    """
    xs = np.asarray(xs)
    points = np.asarray(interpolating_points).reshape(-1)
    node = points[alpha]
    others = np.delete(points, alpha)
    # Take the product over j != alpha directly. Dividing the full product by
    # (x - p_alpha) would give 0/0 = nan when x coincides with p_alpha.
    numerator = np.prod(xs.reshape(1, -1) - others.reshape(-1, 1), axis=0)
    weight = 1 / np.prod(node - others)
    basis = numerator * weight
    return basis.reshape(xs.shape) if xs.ndim else basis

def _window_indices(x, points, flp, vp):
    """Return the (x.size, 2*flp+1) node indices of each point's local stencil.

    Each stencil is centred on ``searchsorted(points, x)``. At either end of
    ``points`` it is shifted rather than reflected, so it stays inside the node
    array and never repeats an index. ``points`` must be sorted ascending.
    """
    n_points = points.shape[0]
    width = 2 * flp + 1
    if flp < 0 or width > n_points:
        raise ValueError(
            f"flp={flp} requires 0 <= 2*flp+1 <= len(points) (= {n_points})"
        )
    centre = vp.searchsorted(points, x)
    start = vp.clip(centre - flp, 0, n_points - width)
    return start.reshape(-1, 1) + vp.arange(width).reshape(1, -1)


def _cheb_matrix_and_inds(x, points, flp=None, use_cupy=False, exact_tol=1e-14):
    """Build the interpolation matrix and return it with the stencil indices.

    Returns ``(matrix, inds)``. ``inds`` is None in the dense (``flp is None``)
    case. Otherwise ``inds[r, k]`` is the index into ``points`` of column ``k``
    of row ``r``.
    """
    vp = cp if use_cupy else np
    x = vp.asarray(x).reshape(-1)
    points = vp.asarray(points).reshape(-1)
    if flp is None:
        inds = None
        per_x_points = points.reshape(1, -1)  # one node row shared by every x
    else:
        inds = _window_indices(x, points, flp, vp)
        per_x_points = points[inds]  # (x.size, 2*flp+1): each x has its own nodes

    diffs = x.reshape(-1, 1) - per_x_points
    abs_diffs = vp.abs(diffs)
    closest = vp.argmin(abs_diffs, axis=1)
    is_exact = vp.min(abs_diffs, axis=1) < exact_tol
    not_exact = ~is_exact

    # Only rows with no near-zero (x - p) go through the product formula.
    diffs = diffs[not_exact]
    if inds is not None:
        # Keep the per-x node rows aligned with the filtered diff rows.
        per_x_points = per_x_points[not_exact]

    # prod_j (x - p_j) / (x - p_i) == prod_{j != i} (x - p_j)
    numerators = vp.prod(diffs, axis=-1, keepdims=True) / diffs
    width = per_x_points.shape[-1]
    # Adding the identity replaces the zero diagonal (p_i - p_i) with 1, so the
    # product over the last axis is prod_{j != i} (p_i - p_j).
    node_diffs = (
        per_x_points[:, :, None] - per_x_points[:, None, :] + vp.eye(width)[None]
    )
    denominators = vp.prod(node_diffs, axis=-1)

    final_matrix = vp.zeros((x.shape[0], width))
    final_matrix[not_exact] = numerators / denominators
    # A point that coincides with a node gets a one-hot row at that node.
    exact_rows = vp.nonzero(is_exact)[0]
    final_matrix[exact_rows, closest[exact_rows]] = 1
    return final_matrix, inds


def cheb_matrix(x, points, flp=None, use_cupy=False, exact_tol=1e-14):
    """Build the Lagrange interpolation matrix evaluating node values at ``x``.

    Row ``r`` holds the Lagrange basis polynomials of that row's nodes,
    evaluated at ``x.reshape(-1)[r]``.

    Args:
        x: evaluation points of any shape; flattened to one row per entry.
        points: 1-D interpolation nodes. They must be sorted ascending when
            ``flp`` is given.
        flp: if None, every row uses all nodes and the result has shape
            ``(x.size, len(points))``. Otherwise each row uses a local stencil
            of ``2*flp + 1`` consecutive nodes around its x. The result then
            has shape ``(x.size, 2*flp + 1)``, and column ``k`` is the ``k``-th
            node of that row's stencil.
        use_cupy: compute with CuPy instead of NumPy.
        exact_tol: an x closer than this to a node gets a one-hot row at that
            node. This avoids dividing by (x - p) ~ 0.

    Raises:
        ValueError: if ``flp`` is negative or ``2*flp + 1 > len(points)``.
    """
    final_matrix, _ = _cheb_matrix_and_inds(x, points, flp, use_cupy, exact_tol)
    return final_matrix


def apply_chebyshev(x, weights, points, flp=None):
    """Interpolate node values ``weights`` (given at ``points``) at ``x``.

    Args:
        x: evaluation points of any shape.
        weights: values at the nodes, with shape ``(len(points), ...)``.
        points: 1-D interpolation nodes (sorted ascending when ``flp`` is set).
        flp: None for global interpolation through all nodes, or the half-width
            of the local stencil, as in ``cheb_matrix``.

    Returns:
        Array of shape ``x.shape + weights.shape[1:]``.
    """
    x_shape = np.shape(x)
    matrix, inds = _cheb_matrix_and_inds(x, points, flp)
    weights = np.asarray(weights)
    if inds is None:
        result = np.tensordot(matrix, weights, axes=1)
    else:
        # Row r contracts its stencil coefficients with the weights of the
        # stencil's own nodes.
        result = np.einsum("xk,xk...->x...", matrix, weights[inds])
    return result.reshape(tuple(x_shape) + result.shape[1:])