
# build the linear system expressing the momentum conservation law
# -------------  Input:
# n              Int, size of the symmetric matrix (s_ij)
# zero_s_ij      collection of index pairs (nonedges of the graph); s_ij is set to zero
#                whenever (i,j) or (j,i) is contained in it
# HC             keyword (default false); if true the output consists of
#                HomotopyContinuation objects instead of Oscar ones
# -------------  Output:
# HC = true:     s_vars, eqns        (Vector{Variable}, broadcast of Expression over the sums)
# HC = false:    R, vars, eqns       (polynomial ring over QQ, its generators, row sums)
# The variables are ordered s_12, s_13, s_23, s_14, ..., s_{n-1,n}; eqns[i] is the sum of
# the i-th row of the symmetric matrix (s_ij), whose diagonal is zero.

function momentum_conservation(n::Int, zero_s_ij; HC = false)

    # true if the (i,j) entry of the matrix is forced to be zero
    vanishes(i, j) = ((i, j) in zero_s_ij) || ((j, i) in zero_s_ij) || (i == j)

    if HC
        # HC variables in the order s_12, s_13, s_23, s_14, ..., s_{n-1,n}
        s_vars = Variable[]
        for j in 2:n
            for i in 1:(j - 1)
                push!(s_vars, Variable(:s, i, j))
            end
        end

        # position of s_ab (a < b) inside s_vars for the ordering above
        slot(a, b) = div((b - 1) * (b - 2), 2) + a

        S = Matrix{Any}(undef, n, n)
        for i in 1:n, j in 1:n
            S[i, j] = vanishes(i, j) ? 0 : s_vars[slot(min(i, j), max(i, j))]
        end

        row_sums = map(i -> sum(S[i, :]), 1:n)

        return s_vars, Expression.(row_sums)
    end

    # Oscar variable names in the order s_12, s_13, s_23, s_14, ..., s_{n-1,n};
    # index_map sends both (i,j) and (j,i) to the position of s_ij
    varnames = String[]
    index_map = Dict{Tuple{Int,Int}, Int}()
    for j in 2:n, i in 1:(j - 1)
        push!(varnames, "s_$(i)$(j)")
        pos = length(varnames)
        index_map[(i, j)] = pos
        index_map[(j, i)] = pos
    end

    # polynomial ring in those variables
    R, vars = polynomial_ring(QQ, varnames)

    # symmetric matrix with zero diagonal and zeros at the given pairs
    S = Matrix{Any}(undef, n, n)
    for i in 1:n, j in 1:n
        S[i, j] = vanishes(i, j) ? zero(R) : vars[index_map[(i, j)]]
    end

    # one equation per row: the sum of the row
    scat_eqns = map(i -> sum(S[i, :]), 1:n)

    return R, vars, scat_eqns
end

# solve the momentum conservation system for the dependent s variables
# -------------  Input:
# n              Int
# zero_s_ij      collection of index pairs (nonedges of the graph)
# -------------  Output:
# vector of expressions in the order of the s variables; every variable that is a pivot of
# the row-reduced system is replaced by a linear combination of the later variables
# NOTE(review): get_coeffs_kernel is defined elsewhere; its first return value is used here
# as the coefficient matrix of the equations and is assumed to have at least n rows.

function solve_momentum_cons(n::Int, zero_s_ij)

    variables, relations = momentum_conservation(n, zero_s_ij, HC = true)
    coeff_matrix, _ = get_coeffs_kernel(variables, relations)

    reduced = Rational.(RowEchelon.rref(coeff_matrix))
    solved = Expression.(variables)
    last_col = size(reduced, 2)

    # walk through the rows from the bottom to the top
    for r in reverse(1:n)
        row = reduced[r, :]

        # the first entry equal to one is taken as the pivot of this row
        pivot = findfirst(a -> a == 1.0, row)
        isnothing(pivot) && continue

        if pivot == last_col
            # nothing to the right of the pivot: the variable is zero
            solved[pivot] = 0
        else
            # move the remaining entries of the row to the other side
            solved[pivot] = dot(-row[pivot+1:end], solved[pivot+1:end])
        end
    end

    return solved
end

# get the list of all 2x2 minors, with the minors belonging to nonedges replaced by 1
# -------------  Input:
# n              Int
# zerosij        collection of index pairs (nonedges of the graph); a pair is recognised
#                in either orientation, (i,j) or (j,i), as in momentum_conservation
# -------------  Output:
# x              Vector{Variable} coordinates on M_0,n
# pluecker_HC    Vector{Expression} x_j - x_i minors, ordered like the s variables
#                (12, 13, 23, 14, ...); entries of nonedges are set to 1
# NOTE(review): create_point_of_M0m and oscar_to_HC_Q are defined elsewhere; the ordering
# of minors(sigma, 2) is assumed to match the ordering above, as in the original code.

function get_pluecker(n::Int, zerosij)

    sigma, _, _ = create_point_of_M0m(n)
    pluecker = minors(sigma, 2)

    @var x[1:n]
    pluecker_HC = [oscar_to_HC_Q(p, x) for p in pluecker]

    for j in 2:n, i in 1:(j - 1)
        # accept both orientations so that the result agrees with momentum_conservation,
        # which drops s_ij for (i,j) as well as for (j,i)
        if ((i, j) in zerosij) || ((j, i) in zerosij)
            # position of the pair (i,j), i < j, in the ordering 12, 13, 23, 14, ...
            pluecker_HC[div((j - 1) * (j - 2), 2) + i] = 1
        end
    end

    return (x, pluecker_HC)
end

# set up the scattering equations
# -------------  Input:
# n                  Int
# zeros              collection of index pairs (nonedges of the graph)
# no_momentum_cons   keyword (default false); if true the s variables are used as they are,
#                    otherwise the output of solve_momentum_cons is used as coefficients
# -------------  Output:
# x              Vector{Variable} coordinates on M_0,n
# s_vars         Vector{Variable} Mandelstam invariants
# eqns           Vector{Expression} momentum conservation law
# F              derivatives with respect to x of the logarithmic potential
#                sum_k c_k * log(p_k), p_k being the entries returned by get_pluecker

function get_scattering_eqns(n::Int, zeros; no_momentum_cons = false)

    x, minors_HC = get_pluecker(n, zeros)
    s_vars, eqns = momentum_conservation(n, zeros, HC = true)

    # coefficients of the logarithms
    weights = no_momentum_cons ? s_vars : solve_momentum_cons(n, zeros)

    potential = sum([weights[k] * log(minors_HC[k]) for k in 1:length(minors_HC)])
    F = differentiate(potential, x)

    return x, s_vars, eqns, F
end

# find a start pair for the monodromy solution of the scattering equations
# -------------      Input:
# n                  Int64
# nonedges           collection of index pairs (i,j), i < j (nonedges of the graph)
# no_momentum_cons   keyword (default false); if false the momentum conservation equations
#                    are added to the linear system that determines the parameters
# -------------      Output:
# x0                 random complex start solution (length n)
# result             start parameters: a random combination of the columns of the matrix M
#                    returned by get_coeffs_kernel, with the entries of nonedges set to 0.0
# NOTE(review): get_coeffs_kernel is defined elsewhere; M is assumed to have one row per
# s variable, in the order s_12, s_13, s_23, s_14, ... used by momentum_conservation.

function find_start_solution(n::Int64, nonedges; no_momentum_cons = false)

    x0 = randn(ComplexF64, n)

    # In both cases F is built from the unreduced s variables, so that after plugging in
    # x0 the system is linear in all of them.
    x, s, eqns, F = get_scattering_eqns(n, nonedges, no_momentum_cons = true)
    F_at_x0 = HomotopyContinuation.evaluate(F, x => x0)
    S = no_momentum_cons ? F_at_x0 : vcat(eqns, F_at_x0)

    _, M = get_coeffs_kernel(s, S)

    coeffs = randn(Float64, size(M, 2))
    result = M * coeffs

    # Walk through the pairs in the order of the s variables and set the parameters of the
    # nonedges to zero. The pairs are generated directly instead of being parsed back from
    # variable names, which was wrong for two-digit indices and undefined for n >= 100.
    k = 0
    for j in 2:n, i in 1:(j - 1)
        k += 1
        if (i, j) in nonedges
            result[k] = 0.0
        end
    end

    return x0, result
end

# solve the scattering equations with monodromy
# -------------      Input:
# n                  Int64
# nonedges           collection of index pairs (nonedges of the graph)
# no_momentum_cons   keyword (default false), passed on to get_scattering_eqns and
#                    find_start_solution
# -------------      Output:
# (true, solutions found by monodromy_solve)   or   (false, "non-copious") if all start
# parameters are zero up to an absolute tolerance of 10^(-8)

function solve_monodromy(n::Int64, nonedges; no_momentum_cons = false)

    x, s, _, F = get_scattering_eqns(n, nonedges; no_momentum_cons = no_momentum_cons)
    sol, param = find_start_solution(n, nonedges; no_momentum_cons = no_momentum_cons)

    # all start parameters vanish numerically
    if all(isapprox(0, atol = 10^(-8)), param)
        return false, "non-copious"
    end

    # fix the first three coordinates to the values of the start solution and
    # solve for the remaining ones, with the s variables as parameters
    F_fixed = HomotopyContinuation.evaluate(F, x[1:3] => sol[1:3])
    system = System(F_fixed, variables = x[4:end], parameters = s)
    res = solutions(monodromy_solve(system, sol[4:end], param))

    return true, res
end

# get the ML degree by solving the scattering equations with monodromy
# -------------      Input:
# n                  Int64
# nonedges           collection of index pairs (nonedges of the graph)
# no_momentum_cons   keyword (default false), passed on to solve_monodromy
# -------------      Output:
# number of solutions returned by solve_monodromy, or its second return value
# ("non-copious") if solve_monodromy reports failure

function ML_deg_monodromy(n::Int64, nonedges; no_momentum_cons = false)

    success, payload = solve_monodromy(n, nonedges, no_momentum_cons = no_momentum_cons)

    return success ? length(payload) : payload
end

# get random kinematic data sij in K_G
# -------------      Input:
# n                  Int64
# zerosij            collection of index pairs (nonedges of the graph); a pair is
#                    recognised in either orientation, (i,j) or (j,i), as in
#                    momentum_conservation
# drop_zeros         keyword (default false); if true the entries of the nonedges are
#                    removed from the output instead of being set to zero
# -------------      Output:
# vector of numbers in the order s_12, s_13, s_23, s_14, ...: a random (Float64
# coefficients) combination of the columns of the matrix M returned by get_coeffs_kernel,
# with the entries of the nonedges set to 0.0 (or dropped)
# NOTE(review): get_coeffs_kernel is defined elsewhere; M is assumed to have one row per
# s variable.

function get_data(n::Int64, zerosij; drop_zeros = false)

    var, eqns = momentum_conservation(n, zerosij, HC = true)
    _, M = get_coeffs_kernel(var, eqns)

    coeffs = randn(Float64, size(M, 2))
    result = M * coeffs

    # Walk through the pairs in the order of the s variables. The pairs are generated
    # directly, so no Oscar ring is built and no variable name has to be parsed; this
    # works for every n.
    kept = Int[]
    k = 0
    for j in 2:n, i in 1:(j - 1)
        k += 1
        if ((i, j) in zerosij) || ((j, i) in zerosij)
            result[k] = 0.0
        else
            push!(kept, k)
        end
    end

    return drop_zeros ? result[kept] : result
end


