### The following contains algorithms to compute the combinatorial types of slices of cubes.
### While most of the code works more generally, some implementation details are specific to cubes.

### Certified numerical algorithm for the full-dimensional cells (open chambers) of a central
### hyperplane arrangement, found as critical points of a master function
function crit_pts(A; G = nothing, num = nothing, cert = false)
    ## INPUT:
    ## A: matrix whose columns are the vertices of a polytope; column j gives the linear form x ⋅ A[:, j]
    ## G: group action on the solution set, forwarded to monodromy_solve as `group_action`.
    ##    `nothing` (default) means no group action. HomotopyContinuation expects an action to return
    ##    a collection of solutions, so the identity `x -> x` is not a valid "no action".
    ## num: number of solutions to be found (any Integer), or nothing to let monodromy decide when to stop
    ## cert: true or false, decides whether to certify the solutions
    ## OUTPUT:
    ## real parts of the computed solutions; if cert is true, the pair (solutions, certification result)

    @var x[1:size(A, 1)] # these are the variables
    @var u[1:size(A, 2)] v # these are the parameters for the exponents
    # define the linear forms from the columns of A
    linforms = transpose(x) * A
    psi = sum(u[i] * log(linforms[i]) for i in 1:length(linforms)) - v * log(sum(x .^ 2) + 1)
    # the solutions to the following system are the critical points of psi, in standard bijection
    # with open chambers of the hyperplane arrangement
    CritPts = System(differentiate(psi, x), parameters = [u; v])

    # parameters at which solutions are reported: all u_i = 1 and v = number of hyperplanes
    target = [ones(length(u)); size(A, 2)]

    # With a requested count, the first monodromy run (at parameters chosen by monodromy_solve)
    # stops after roughly a tenth of the solutions; the run at the target parameters completes it.
    # Without a count, no target_solutions_count is passed at all.
    warmup_kw = num isa Integer ? (target_solutions_count = div(num, 10) + 1,) : NamedTuple()
    final_kw = num isa Integer ? (target_solutions_count = num,) : NamedTuple()

    mon = monodromy_solve(CritPts; warmup_kw...)
    # move the solutions found at the monodromy parameters to the target parameters
    CritOnes = solve(CritPts, solutions(mon); start_parameters = parameters(mon), target_parameters = target)
    sols = monodromy_solve(CritPts, solutions(CritOnes), target; group_action = G, final_kw...)

    # optional certification
    if cert
        cert_pts = HomotopyContinuation.certify(CritPts, solutions(sols), target_parameters = target)
        return real(solutions(sols)), cert_pts
    end
    return real(solutions(sols))
end

### Scale a vector so that its first nonzero entry becomes one
function normalize_column(column)
    ## INPUT:
    ## column: vector
    ## OUTPUT:
    ## column divided by its first nonzero entry; a Float64 zero vector of the same size
    ## when every entry is zero

    pivot = findfirst(!iszero, column)
    pivot === nothing && return zeros(size(column))
    return column / column[pivot]
end

### Delete the columns of a matrix that are scalar multiples of an earlier column
function remove_duplicate_cols(matrix; digits = 8)
    ## INPUT:
    ## matrix: matrix
    ## digits: number of decimal digits kept when comparing normalized columns (default 8)
    ## OUTPUT:
    ## the columns of matrix that are not a scalar multiple of an earlier column, in their original
    ## order. All-zero columns compare equal to each other, so at most one of them is kept.

    seen = Set{Vector{Float64}}()
    keep = Int[]
    for (j, col) in enumerate(eachcol(matrix))
        # Rounding lets numerically equal columns compare equal. Adding 0.0 maps -0.0 to 0.0:
        # Set/Dict keys are compared with isequal, under which -0.0 and 0.0 differ, so e.g.
        # [1, 0] and [-1, 0] (normalized to [1.0, 0.0] and [1.0, -0.0]) would both be kept.
        key = round.(normalize_column(col); digits = digits) .+ 0.0
        if key ∉ seen
            push!(seen, key)
            push!(keep, j)
        end
    end

    # indices were collected in increasing order, so the original column order is preserved
    return matrix[:, keep]
end

### Restrict a matrix to the subspace associated to a list of vertices.
### Vertices are processed in order. For each vertex `vert`, let k be its first nonzero coordinate.
### Every column is projected along `vert` so that its k-th entry vanishes. Then row k and the
### (now zero) column equal to `vert` are deleted. The remaining vertices are projected the same way,
### and those that become exactly zero (multiples of `vert` in the current coordinates) are dropped.
function matrix_remove_verts(A, V)
    ## INPUT:
    ## A: matrix whose columns are the vertices of a polytope
    ## V: vector of columns of A (i.e., vertices of the polytope)
    ## OUTPUT:
    ## newA: the restricted matrix (one row and one column fewer per processed vertex). Its element
    ##       type is promoted by the division, so integer input yields a Float64 matrix.
    ## vert_ind: list of pairs [vert, k] in processing order: the vertex, in the coordinates current
    ##           at that step, and the index of the coordinate it eliminated

    newA = copy(A)
    newV = copy(V)
    vert_ind = [] # records, for each processed vertex, the vertex and the coordinate it removed
    while !isempty(newV)
        vert = newV[1]
        i = findfirst(col -> col == vert, eachcol(newA))
        k = findfirst(x -> x != 0, vert)
        if isnothing(i) || isnothing(k)
            throw(ArgumentError("vertex $vert is zero or not a column of the restricted matrix"))
        end

        # Subtract vert * newA[k, j] / vert[k] from every column j. Broadcasting promotes the element
        # type instead of throwing InexactError when an integer matrix gets non-integer entries.
        B = newA .- vert .* transpose(newA[k, :]) ./ vert[k]
        # remove row k and column i, which are now zero
        newA = B[axes(B, 1) .!= k, axes(B, 2) .!= i]

        # apply the same projection to the vertices and drop those that vanish
        projected = map(newV) do w
            wp = w - vert * w[k] / vert[k]
            wp[axes(wp, 1) .!= k]
        end
        newV = filter(w -> any(x -> x != 0, w), projected)
        push!(vert_ind, [vert, k])
    end
    return newA, vert_ind
end

### Lift points from a restricted subspace back to the full-dimensional space
function back_to_full_space(pts, v_ind)
    ## INPUT:
    ## pts: list of points in a lower dimensional space
    ## v_ind: list of pairs [vector, index]. The vector was used to restrict to a lower dimensional
    ##        subspace, and the index denotes the variable that the vector removed.
    ## OUTPUT:
    ## the lifted points. Pairs are undone last-to-first. Each step reinserts coordinate `index`,
    ## chosen so that the lifted point is orthogonal to `vector`.

    lifted = copy(pts)
    for step in Iterators.reverse(v_ind)
        normal, idx = step
        others = eachindex(normal) .!= idx # every coordinate except the removed one
        lifted = [
            [p[1:idx-1]..., -normal[others] ⋅ p / normal[idx], p[idx:end]...]
            for p in lifted
        ]
    end
    return lifted
end

### Certified numerical algorithm restricted to lists of vertices. For every list V in VV, A is
### restricted with matrix_remove_verts. Duplicate columns are removed, crit_pts is solved on the
### result, and the solutions are lifted back to the full space with back_to_full_space.
function crit_pts_lower(A, VV; G = nothing, N = nothing, cert = false)
    ## INPUT:
    ## A: matrix whose columns are the vertices of a polytope
    ## VV: vector of lists of columns of A (i.e., vertices of the polytope); each list is one restriction
    ## G: group action on the solution set of each restricted problem, forwarded to crit_pts
    ##    (`nothing`, the default, means no group action)
    ## N: nothing, or a list of numbers of solutions to be found for each list in VV.
    ##    Lists are paired with VV via zip, so extra entries of the longer one are ignored.
    ## cert: true or false, decides whether to certify the solutions
    ## OUTPUT:
    ## tot_pts: all lifted solutions; if cert is true, the pair (tot_pts, certificates)
    ##          with one certification result per processed list

    # without counts, every list is solved with num = nothing (crit_pts' default)
    counts = N === nothing ? Iterators.repeated(nothing) : N
    tot_pts = []
    certificates = []
    for (V, n) in zip(VV, counts)
        B, vert_ind = matrix_remove_verts(A, V)
        Bred = remove_duplicate_cols(B)
        if cert
            sols, c = crit_pts(Bred; G = G, num = n, cert = true)
            push!(certificates, c)
        else
            sols = crit_pts(Bred; G = G, num = n)
        end
        append!(tot_pts, back_to_full_space(sols, vert_ind))
    end
    return cert ? (tot_pts, certificates) : tot_pts
end

### Normals of the hyperplanes through the origin spanned by lists of points
### (e.g. d-1 points in d-dimensional space)
function pts_max_nvert(VV)
    ## INPUT:
    ## VV: vector of vectors of points (i.e., vertices of a polytope), all of the same dimension D
    ## OUTPUT:
    ## one normal vector (the nullspace basis vector) for every list of D-1 linearly independent points.
    ## Other lists do not span a hyperplane and are skipped: empty lists, dependent points,
    ## fewer than D-1 points, or D points.

    tot_pts = []
    for V in VV
        isempty(V) && continue
        M = transpose(reduce(hcat, V)) # rows are the points of V
        rank(M) == length(V) || continue # the points must be linearly independent
        K = nullspace(M)
        # A hyperplane has a one-dimensional normal space. With fewer points the nullspace has
        # several columns, and with D points it is empty; neither gives a single normal vector.
        size(K, 2) == 1 && push!(tot_pts, vec(K))
    end
    return tot_pts
end


### Compute and compare combinatorial types of the slices of P by the hyperplanes normal to pts
function comb_types(P, pts)
    ## INPUT:
    ## P: polytope that you want to slice: for central slices, use P in d-space; for affine slices,
    ##    put P in (d+1)-dim space at height 1
    ## pts: type vector of vectors. Each vector is the normal to a hyperplane used to slice P
    ## OUTPUT:
    ## fvec, slcs, nrls: f-vectors, slices and normals of one representative per combinatorial type.
    ## Only slices of dimension dim(P)-1 are considered.

    d = dim(P)
    n_cones = length(pts) # number of slices to compare
    fvec = []
    slcs = []
    nrls = []
    @show n_cones

    for i in 1:n_cones
        if rem(i, 100) == 0 # display which slice is being compared, every 100
            @show i
        end
        u_sample = pts[i]
        u_perp = polyhedron(transpose([u_sample -u_sample]), [0; 0]) # hyperplane orthogonal to the i-th point
        slice = intersect(P, u_perp)
        dim(slice) == d - 1 || continue

        fv = f_vector(slice) # compute the f-vector
        # Only slices with the same f-vector can be combinatorially isomorphic. `any` stops at the
        # first isomorphism found instead of running polymake on every candidate; with no candidate
        # the slice is new.
        same_f_ind = findall(==(fv), fvec)
        is_new = !any(sl -> Polymake.polytope.isomorphic(slice.pm_polytope, sl.pm_polytope),
                      view(slcs, same_f_ind))
        if is_new
            push!(fvec, fv)
            push!(slcs, slice)
            push!(nrls, u_sample)
        end
    end
    return fvec, slcs, nrls
end

### Compare combinatorial types from several lists (e.g. outputs of comb_types) and merge them
function comb_types_lists(F, S, U)
    ## INPUT:
    ## F: list of lists of f-vectors
    ## S: list of lists of corresponding polytopes
    ## U: list of lists of corresponding normals (normals to the defining hyperplane of the slice)
    ## OUTPUT:
    ## fvec, slcs, nrls: merged f-vectors, polytopes and normals
    ## NOTE: an entry of F[i] is compared only against entries kept from F[1], ..., F[i-1],
    ## not against other entries of F[i]. This presumably assumes each list already contains
    ## distinct types (as comb_types returns); confirm for other inputs.

    fvec = []
    slcs = []
    nrls = []
    total = length(F)
    for i in 1:length(F)
        @show i, total
        n1 = length(fvec) # entries 1:n1 come from the earlier lists
        f2, s2, u2 = F[i], S[i], U[i]
        for j in 1:length(f2)
            # only earlier entries with the same f-vector can be isomorphic; `any` short-circuits
            # on the first isomorphism, and with no candidate the entry is new
            same_f_ind = findall(==(f2[j]), view(fvec, 1:n1))
            is_new = !any(sl -> Polymake.polytope.isomorphic(s2[j].pm_polytope, sl.pm_polytope),
                          view(slcs, same_f_ind))
            if is_new
                push!(fvec, f2[j])
                push!(slcs, s2[j])
                push!(nrls, u2[j])
            end
        end
        @show length(fvec)
    end
    return fvec, slcs, nrls
end