// Copyright 2020-2025 Consensys Software Inc.
// Licensed under the Apache License, Version 2.0. See the LICENSE file for details.

// Code generated by gnark DO NOT EDIT

package gkr

import (
	"errors"
	"fmt"
	"math/big"
	"strconv"
	"sync"

	"github.com/consensys/gnark-crypto/ecc/bw6-633/fr"
	"github.com/consensys/gnark-crypto/ecc/bw6-633/fr/polynomial"
	fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir"
	"github.com/consensys/gnark-crypto/utils"
	"github.com/consensys/gnark/frontend"
	"github.com/consensys/gnark/internal/gkr/gkrtypes"
	"github.com/consensys/gnark/std/gkrapi/gkr"
)

// The goal is to prove/verify evaluations of many instances of the same circuit

// WireAssignment is the assignment of values to the wires of the circuit across
// many instances: it is indexed by wire, and the entry for a wire holds that
// wire's value in every instance.
type WireAssignment []polynomial.MultiLin

// Proof is a GKR proof. It contains one sumcheck proof per wire, stored at the
// wire's index; each sumcheck proof carries the partial sum polynomials (one
// per variable) and the final evaluation proof.
type Proof []sumcheckProof // for each layer, for each wire, a sumcheck (for each variable, a polynomial)

// eqTimesGateEvalSumcheckLazyClaims is a lazy claim for sumcheck (verifier side).
// eqTimesGateEval is a polynomial consisting of ∑ᵢ cⁱ eq(-, xᵢ) w(-).
// Its purpose is to batch the checking of multiple evaluations of the same wire.
// Claims are appended one at a time through claimsManager.add.
type eqTimesGateEvalSumcheckLazyClaims struct {
	wireI              int            // index of the wire for which we are making the claim, with value w
	evaluationPoints   [][]fr.Element // xᵢ: the points at which the prover has made claims about the evaluation of w
	claimedEvaluations []fr.Element   // yᵢ = w(xᵢ), allegedly; allocated up front from the manager's memory pool
	manager            *claimsManager // WARNING: Circular references (the manager also points to this claim)
}

// getWire returns the wire this claim is about, looked up by index in the
// manager's wire list.
func (e *eqTimesGateEvalSumcheckLazyClaims) getWire() *gkrtypes.Wire {
	wires := e.manager.wires
	return wires[e.wireI]
}

// claimsNum returns the number of claims registered so far, i.e. the number of
// evaluation points.
func (e *eqTimesGateEvalSumcheckLazyClaims) claimsNum() int {
	points := e.evaluationPoints
	return len(points)
}

// varsNum returns the number of variables, taken as the length of the first
// evaluation point. At least one claim must have been registered.
func (e *eqTimesGateEvalSumcheckLazyClaims) varsNum() int {
	firstPoint := e.evaluationPoints[0]
	return len(firstPoint)
}

// combinedSum returns ∑ᵢ aⁱ yᵢ, i.e. the claimed evaluations read as the
// coefficients of a polynomial evaluated at a.
func (e *eqTimesGateEvalSumcheckLazyClaims) combinedSum(a fr.Element) fr.Element {
	return polynomial.Polynomial(e.claimedEvaluations).Eval(&a)
}

// degree returns the degree bound of the sumchecked polynomial in any variable:
// the degree of the wire's gate, plus one for the eq factor.
// The variable index argument is ignored.
func (e *eqTimesGateEvalSumcheckLazyClaims) degree(int) int {
	wire := e.manager.wires[e.wireI]
	return wire.Gate.Degree() + 1
}

// verifyFinalEval finalizes the verification of w.
// The prover's claims w(xᵢ) = yᵢ have already been reduced to verifying
// ∑ᵢ cⁱ eq(xᵢ, r) w(r) = purportedValue, where c is combinationCoeff.
// Both purportedValue and the vector r have been randomized during the sumcheck protocol.
// Factoring w(r) out of the sum, the claim becomes E · w(r) = purportedValue
// with E := ∑ᵢ cⁱ eq(xᵢ, r).
//
// If w is an input wire, w(r) is computed directly from the assignment.
// Otherwise uniqueInputEvaluations must contain the prover's claimed values of
// w's distinct input wires at r. Each of them is registered with the manager as
// a new claim to be verified later, and w(r) is obtained by evaluating the gate
// on them.
//
// An error is returned if the number of provided input evaluations is not the
// expected one, or if E · w(r) differs from purportedValue.
func (e *eqTimesGateEvalSumcheckLazyClaims) verifyFinalEval(r []fr.Element, combinationCoeff, purportedValue fr.Element, uniqueInputEvaluations []fr.Element) error {
	// E = ∑ᵢ cⁱ eq(xᵢ, r), by Horner's rule starting from the last claim
	last := len(e.evaluationPoints) - 1
	eqSum := polynomial.EvalEq(e.evaluationPoints[last], r)
	for i := last - 1; i >= 0; i-- {
		eqSum.Mul(&eqSum, &combinationCoeff)
		term := polynomial.EvalEq(e.evaluationPoints[i], r)
		eqSum.Add(&eqSum, &term)
	}

	mgr := e.manager
	wire := mgr.wires[e.wireI]

	// w(r)
	var wAtR fr.Element
	if !wire.IsInput() {
		// the proof contains one evaluation per distinct input wire, so that
		// a wire feeding the gate several times is only claimed once
		uniqueToAll, allToUnique := mgr.wires.ClaimPropagationInfo(e.wireI)

		if len(uniqueToAll) != len(uniqueInputEvaluations) {
			return fmt.Errorf("%d input wire evaluations given, %d expected", len(uniqueInputEvaluations), len(uniqueToAll))
		}

		// every distinct input wire gets a new claim at r
		for uniqueI, inputI := range uniqueToAll {
			mgr.add(wire.Inputs[inputI], r, uniqueInputEvaluations[uniqueI])
		}

		// expand the distinct evaluations back to the gate's full input list
		gateInputs := make([]frontend.Variable, len(wire.Inputs))
		for inputI, uniqueI := range allToUnique {
			gateInputs[inputI] = &uniqueInputEvaluations[uniqueI]
		}

		wAtR.Set(wire.Gate.Evaluate(api, gateInputs...).(*fr.Element))
	} else {
		// input wire: just compute w(r)
		wAtR = mgr.assignment[e.wireI].Evaluate(r, mgr.memPool)
	}

	eqSum.Mul(&eqSum, &wAtR)

	if !eqSum.Equal(&purportedValue) {
		return errors.New("incompatible evaluations")
	}
	return nil
}

// eqTimesGateEvalSumcheckClaims is a claim for sumcheck (prover side).
// eqTimesGateEval is a polynomial consisting of ∑ᵢ cⁱ eq(-, xᵢ) w(-).
// Its purpose is to batch the proving of multiple evaluations of the same wire.
// Instances are built by claimsManager.getClaim from the corresponding lazy claim.
type eqTimesGateEvalSumcheckClaims struct {
	wireI              int            // index of the wire for which we are making the claim, with value w
	evaluationPoints   [][]fr.Element // xᵢ: the points at which the prover has made claims about the evaluation of w
	claimedEvaluations []fr.Element   // yᵢ = w(xᵢ)
	manager            *claimsManager // gives access to the wires, the memory pool and the worker pool

	// bookkeeping tables of the gate inputs; they are folded in place at every round
	input []polynomial.MultiLin // input[i](h₁, ..., hₘ₋ⱼ) = wᵢ(r₁, r₂, ..., rⱼ₋₁, h₁, ..., hₘ₋ⱼ)

	// bookkeeping table of the eq term; built by combine and folded at every round
	eq polynomial.MultiLin // E := ∑ᵢ cⁱ eq(xᵢ, -)
}

// getWire returns the wire this claim is about, looked up by index in the
// manager's wire list.
func (c *eqTimesGateEvalSumcheckClaims) getWire() *gkrtypes.Wire {
	wires := c.manager.wires
	return wires[c.wireI]
}

// combine merges the multiple claims into a single one using a random linear
// combination with coefficient combinationCoeff (c below).
// From the original claims w(xᵢ) = yᵢ we get the single claim
// ∑ᵢ,ₕ cⁱ eq(xᵢ, h) w(h) = ∑ᵢ cⁱ yᵢ, where h iterates over the hypercube
// (circuit instances) and i iterates over the claims.
// Reordering the sums, ∑ᵢ cⁱ yᵢ = ∑ₕ w(h) ∑ᵢ cⁱ eq(xᵢ, h), so once
// E := ∑ᵢ cⁱ eq(xᵢ, -) has been tabulated the claim takes the simpler form
// ∑ᵢ cⁱ yᵢ = ∑ₕ w(h) E(h). The sum-checked polynomial has degree deg(g) + 1,
// where deg(g) is the total degree of the gate g of which w is the output.
//
// combine stores the table of E in c.eq and returns the first sumcheck
// polynomial, i.e. ∑₍ₕ₁,ₕ₂,...₎ w(X, h₁, h₂, ...) E(X, h₁, h₂, ...).
func (c *eqTimesGateEvalSumcheckClaims) combine(combinationCoeff fr.Element) polynomial.Polynomial {
	pool := c.manager.memPool
	tableSize := 1 << c.varsNum()
	nbClaims := c.claimsNum()

	// E := eq(x₀, -)
	c.eq = pool.Make(tableSize)
	c.eq[0].SetOne()
	c.eq.Eq(c.evaluationPoints[0])

	// scratch receives cᵏ eq(xₖ, -) before it is accumulated into E
	scratch := polynomial.MultiLin(pool.Make(tableSize))
	coeffPow := combinationCoeff // cᵏ

	// E += cᵏ eq(xₖ, -)
	for k := 1; k < nbClaims; k++ {
		scratch[0].Set(&coeffPow)

		c.eqAcc(c.eq, scratch, c.evaluationPoints[k])

		if k != nbClaims-1 { // the next power is not needed after the last claim
			coeffPow.Mul(&coeffPow, &combinationCoeff)
		}
	}

	pool.Dump(scratch)

	return c.computeGJ()
}

// eqAcc expands m into a scaled eq table at q and then adds it to e.
// On entry only m[0] is read: it holds the scaling factor.
// m <- m[0] · eq(q, -).
// e <- e + m
// Large expansion steps and the final addition are run on the manager's worker pool.
func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []fr.Element) {
	const parallelThreshold = 1 << 6

	nbVars := len(q)
	workers := c.manager.workers

	// At the end of iteration i, m(b₁, ..., bᵢ₊₁, 0, ..., 0) holds the factor times eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ₊₁).
	// In the comments we use a 1-based index so q[i] = qᵢ₊₁
	for i := range q {
		prefixShift := nbVars - i       // position of (b₁, ..., bᵢ) within the table index
		oneOffset := 1 << (nbVars - 1 - i) // index offset that sets bᵢ₊₁ = 1

		// split handles the assignments (b₁, ..., bᵢ) numbered start to end-1
		split := func(start, end int) {
			for j := start; j < end; j++ {
				at0 := j << prefixShift // bᵢ₊₁ = 0
				at1 := at0 + oneOffset  // bᵢ₊₁ = 1

				m[at1].Mul(&q[i], &m[at0])  // eq(..., qᵢ₊₁, ..., 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁
				m[at0].Sub(&m[at0], &m[at1]) // eq(..., qᵢ₊₁, ..., 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁)
			}
		}

		// go through all assignments of (b₁, ..., bᵢ) ∈ {0,1}ⁱ
		if nbPrefixes := 1 << i; nbPrefixes < parallelThreshold {
			split(0, nbPrefixes)
		} else {
			workers.Submit(nbPrefixes, split, 1024).Wait()
		}
	}

	workers.Submit(len(e), func(start, end int) {
		for k := start; k < end; k++ {
			e[k].Add(&e[k], &m[k])
		}
	}, 512).Wait()
}

// computeGJ computes the polynomial of the current sumcheck round from the
// current state of the bookkeeping tables c.eq and c.input:
// gⱼ = ∑_{0≤h<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, h...) = ∑_{0≤h<2ⁿ⁻ʲ} E(r₁, ..., Xⱼ, h...) g( w₀(r₁, ..., Xⱼ, h...), ... ).
// The polynomial is represented by the evaluations gⱼ(1), gⱼ(2), ..., gⱼ(deg(gⱼ)).
// The value gⱼ(0) is inferred from the equation gⱼ(0) + gⱼ(1) = gⱼ₋₁(rⱼ₋₁). By convention, g₀ is a constant polynomial equal to the claimed sum.
// When the sum is large enough, the work is split across the manager's worker pool.
func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial {

	wire := c.getWire()
	degGJ := 1 + wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(gⱼ)
	nbGateIn := len(c.input)

	// Both E and wᵢ (the input wires and the eq table) are multilinear, thus
	// they are linear in Xⱼ.
	// So for f ∈ { E(r₁, ..., Xⱼ, h...) } ∪ {wᵢ(r₁, ..., Xⱼ, h...) }, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables.
	// ml are such multilinear polynomials the evaluations of which over different values of Xⱼ are computed in this stepwise manner.
	ml := make([]polynomial.MultiLin, nbGateIn+1) // shortcut to the evaluations of the multilinear polynomials over the hypercube
	ml[0] = c.eq
	copy(ml[1:], c.input)

	sumSize := len(c.eq) / 2 // the range of h, over which we sum

	// Perf-TODO: Collate once at claim "combination" time and not again. then, even folding can be done in one operation every time "next" is called

	gJ := make([]fr.Element, degGJ)
	var mu sync.Mutex // protects gJ, which every worker adds its partial sums to
	computeAll := func(start, end int) { // compute method to allow parallelization across instances
		var step fr.Element

		res := make([]fr.Element, degGJ) // partial sums over h in [start, end)

		// evaluations of ml, laid out as:
		// ml[0](1, h...), ml[1](1, h...), ..., ml[len(ml)-1](1, h...),
		// ml[0](2, h...), ml[1](2, h...), ..., ml[len(ml)-1](2, h...),
		// ...
		// ml[0](degGJ, h...), ml[1](degGJ, h...), ..., ml[len(ml)-1](degGJ, h...)
		mlEvals := make([]fr.Element, degGJ*len(ml))
		gateInput := make([]frontend.Variable, nbGateIn)

		for h := start; h < end; h++ { // h counts across instances

			// the first half of each table holds the values at Xⱼ = 0, the second half those at Xⱼ = 1
			evalAt1Index := sumSize + h
			for k := range ml {
				// d = 0
				mlEvals[k].Set(&ml[k][evalAt1Index]) // evaluation at Xⱼ = 1. Can be taken directly from the table.
				step.Sub(&mlEvals[k], &ml[k][h])     // step = ml[k](1) - ml[k](0)
				for d := 1; d < degGJ; d++ {
					// ml[k](d+1) = ml[k](d) + step, by linearity in Xⱼ
					mlEvals[d*len(ml)+k].Add(&mlEvals[(d-1)*len(ml)+k], &step)
				}
			}

			eIndex := 0 // index for where the current eq term is
			nextEIndex := len(ml)
			for d := range degGJ {
				// the gate inputs for this value of Xⱼ directly follow the eq term
				for i := range gateInput {
					gateInput[i] = &mlEvals[eIndex+1+i]
				}
				// NOTE(review): the multiplication below writes into the element returned by the gate;
				// if a gate returns one of its input pointers, the corresponding mlEvals entry is
				// overwritten (it is recomputed for the next h) — confirm this is intended.
				summand := wire.Gate.Evaluate(api, gateInput...).(*fr.Element)
				summand.Mul(summand, &mlEvals[eIndex])
				res[d].Add(&res[d], summand) // collect contributions into the sum from start to end
				eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml)
			}
		}
		mu.Lock()
		for i := range gJ {
			gJ[i].Add(&gJ[i], &res[i]) // collect into the complete sum
		}
		mu.Unlock()
	}

	const minBlockSize = 64

	if sumSize < minBlockSize {
		// no parallelization
		computeAll(0, sumSize)
	} else {
		c.manager.workers.Submit(sumSize, computeAll, minBlockSize).Wait()
	}

	return gJ
}

// next folds the input tables and the E table at the given verifier challenge,
// then computes the polynomial of the following round.
// Thus, j <- j+1 and rⱼ = challenge.
// Large tables are folded in parallel on the manager's worker pool.
func (c *eqTimesGateEvalSumcheckClaims) next(challenge fr.Element) polynomial.Polynomial {
	const minBlockSize = 512

	if half := len(c.eq) / 2; half >= minBlockSize {
		workers := c.manager.workers

		// start folding every input table, then fold E, then wait for everything
		pending := make([]*sync.WaitGroup, 0, len(c.input))
		for i := range c.input {
			pending = append(pending, workers.Submit(half, c.input[i].FoldParallel(challenge), minBlockSize))
		}
		workers.Submit(half, c.eq.FoldParallel(challenge), minBlockSize).Wait()
		for _, wg := range pending {
			wg.Wait()
		}
	} else {
		// small tables: no parallelization
		for i := range c.input {
			c.input[i].Fold(challenge)
		}
		c.eq.Fold(challenge)
	}

	return c.computeGJ()
}

// varsNum returns the number of variables, taken as the length of the first
// evaluation point. At least one claim must have been registered.
func (c *eqTimesGateEvalSumcheckClaims) varsNum() int {
	firstPoint := c.evaluationPoints[0]
	return len(firstPoint)
}

// claimsNum returns the number of claims, measured as the length of the
// claimed evaluations slice.
func (c *eqTimesGateEvalSumcheckClaims) claimsNum() int {
	evals := c.claimedEvaluations
	return len(evals)
}

// proveFinalEval provides the values wᵢ(r₁, ..., rₙ) of the distinct input
// wires of the gate, in the order given by the claim propagation info.
// The proof of these values is deferred: each one is registered with the
// manager as a new claim on the corresponding input wire.
// The claimed evaluations and the E table are handed back to the memory pool.
func (c *eqTimesGateEvalSumcheckClaims) proveFinalEval(r []fr.Element) []fr.Element {
	mgr := c.manager

	// TODO @Tabaie: Instead of doing this last, we could just have fewer input in the first place; not that likely to happen with single gates, but more so with layers.
	uniqueInputs, _ := mgr.wires.ClaimPropagationInfo(c.wireI)

	res := make([]fr.Element, len(uniqueInputs))
	for k, inputPos := range uniqueInputs {
		table := c.input[inputPos]
		// We already have wᵢ(r₁, ..., rₙ₋₁, hₙ) in a table. Only one more fold required.
		table.Fold(r[len(r)-1])
		mgr.add(c.getWire().Inputs[inputPos], r, table[0])
		res[k] = table[0]
	}

	mgr.memPool.Dump(c.claimedEvaluations, c.eq)

	return res
}

// claimsManager keeps track of the evaluation claims made about every wire
// while a proof is produced or verified, together with the shared resources
// the claims need.
type claimsManager struct {
	claims     []*eqTimesGateEvalSumcheckLazyClaims // claims[i] collects the claims about wire i; set to nil by deleteClaim
	assignment WireAssignment                       // values of the wires across instances
	memPool    *polynomial.Pool                     // pool the scratch tables are taken from
	workers    *utils.WorkerPool                    // pool the parallel tasks are submitted to
	wires      gkrtypes.Wires                       // wires of the circuit, in the order used for indexing claims
}

// newClaimsManager builds a claims manager for the given wires and assignment,
// using the memory pool and worker pool from o.
// An empty lazy claim is created for every wire, with room for as many claims
// as the wire reports through NbClaims.
func newClaimsManager(wires []*gkrtypes.Wire, assignment WireAssignment, o settings) (manager claimsManager) {
	manager = claimsManager{
		claims:     make([]*eqTimesGateEvalSumcheckLazyClaims, len(wires)),
		assignment: assignment,
		memPool:    o.pool,
		workers:    o.workers,
		wires:      wires,
	}

	for i := range wires {
		w := wires[i]
		lazy := new(eqTimesGateEvalSumcheckLazyClaims)
		lazy.wireI = i
		lazy.evaluationPoints = make([][]fr.Element, 0, w.NbClaims())
		lazy.claimedEvaluations = manager.memPool.Make(w.NbClaims())
		lazy.manager = &manager // points at the named result, as the claims need to reach back to the manager
		manager.claims[i] = lazy
	}

	return manager
}

// add registers a new claim about the given wire: that it evaluates to
// evaluation at evaluationPoint. The claim is stored in the next free slot.
func (m *claimsManager) add(wire int, evaluationPoint []fr.Element, evaluation fr.Element) {
	target := m.claims[wire]
	slot := len(target.evaluationPoints) // claims registered so far
	target.claimedEvaluations[slot] = evaluation
	target.evaluationPoints = append(target.evaluationPoints, evaluationPoint)
}

// getLazyClaim returns the verifier-side claim collected for the given wire.
func (m *claimsManager) getLazyClaim(wire int) *eqTimesGateEvalSumcheckLazyClaims {
	lazy := m.claims[wire]
	return lazy
}

// getClaim builds the prover-side claim for the given wire from the claims
// collected so far. The evaluation points and claimed evaluations are shared
// with the lazy claim. The input tables are copies taken from the memory pool,
// since they will be folded in place: for an input wire, the single table is
// the wire's own assignment; otherwise there is one table per gate input.
func (m *claimsManager) getClaim(wireI int) *eqTimesGateEvalSumcheckClaims {
	lazy, wire := m.claims[wireI], m.wires[wireI]

	res := &eqTimesGateEvalSumcheckClaims{
		wireI:              wireI,
		evaluationPoints:   lazy.evaluationPoints,
		claimedEvaluations: lazy.claimedEvaluations,
		manager:            m,
	}

	if wire.IsInput() {
		res.input = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wireI])}
		return res
	}

	res.input = make([]polynomial.MultiLin, 0, len(wire.Inputs))
	for _, in := range wire.Inputs {
		// will be edited later, so must be deep copied
		res.input = append(res.input, m.memPool.Clone(m.assignment[in]))
	}
	return res
}

// deleteClaim drops the claim of the given wire, after breaking its reference
// back to the manager.
func (m *claimsManager) deleteClaim(wire int) {
	claim := m.claims[wire]
	claim.manager = nil
	m.claims[wire] = nil
}

// settings gathers everything Prove and Verify need besides the circuit, the
// assignment and the proof. It is filled in by setup from the options and the
// Fiat-Shamir settings.
type settings struct {
	pool             *polynomial.Pool       // memory pool; created by setup when not provided
	sorted           []*gkrtypes.Wire       // wires in topological order; computed by setup when not provided
	transcript       *fiatshamir.Transcript // transcript the challenges are derived from
	transcriptPrefix string                 // prefix of the challenge names; only set when the transcript is provided by the caller
	nbVars           int                    // number of variables of the assignment (log of the number of instances)
	workers          *utils.WorkerPool      // worker pool; created by setup when not provided
}

// Option customizes the settings used by Prove and Verify.
type Option func(*settings)

// WithPool makes Prove and Verify take their scratch memory from the given
// pool instead of creating a new one.
func WithPool(pool *polynomial.Pool) Option {
	return func(s *settings) { s.pool = pool }
}

// WithSortedCircuit provides the wires in topological order, so that the order
// does not have to be computed again.
func WithSortedCircuit(sorted []*gkrtypes.Wire) Option {
	return func(s *settings) { s.sorted = sorted }
}

// WithWorkers makes Prove and Verify submit their parallel tasks to the given
// worker pool instead of creating a new one.
// Note that Prove and Verify stop the pool they used before returning.
func WithWorkers(workers *utils.WorkerPool) Option {
	return func(s *settings) { s.workers = workers }
}

// setup builds the settings shared by Prove and Verify.
// The options are applied first; a memory pool, a worker pool and the
// topological order of the wires are then created if no option provided them.
//
// If transcriptSettings carries a transcript, it is used as is, along with the
// given prefix. Otherwise a new transcript is created for the challenge names
// of the circuit, and the base challenges are bound to its first challenge.
// NOTE(review): in that second case the prefix is used to build the challenge
// names but is not stored in the returned settings — confirm this is intended.
//
// An error is returned if the number of instances is not a power of two, or if
// binding a base challenge fails.
func setup(c gkrtypes.Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (settings, error) {
	var o settings
	for _, apply := range options {
		apply(&o)
	}

	o.nbVars = assignment.NumVars()
	nbInstances := assignment.NumInstances()
	if nbInstances != 1<<o.nbVars {
		return o, errors.New("number of instances must be power of 2")
	}

	// defaults for whatever the options left unset
	if o.pool == nil {
		pool := polynomial.NewPool(c.MemoryRequirements(nbInstances)...)
		o.pool = &pool
	}
	if o.workers == nil {
		o.workers = utils.NewWorkerPool()
	}
	if o.sorted == nil {
		o.sorted = c.TopologicalSort()
	}

	// transcript provided by the caller
	if transcriptSettings.Transcript != nil {
		o.transcript = transcriptSettings.Transcript
		o.transcriptPrefix = transcriptSettings.Prefix
		return o, nil
	}

	// fresh transcript
	challengeNames := ChallengeNames(o.sorted, o.nbVars, transcriptSettings.Prefix)
	o.transcript = fiatshamir.NewTranscript(transcriptSettings.Hash, challengeNames...)
	for _, base := range transcriptSettings.BaseChallenges {
		if err := o.transcript.Bind(challengeNames[0], base); err != nil {
			return o, err
		}
	}

	return o, nil
}

// ChallengeNames returns the names of all the Fiat-Shamir challenges used for
// the given wires, in the order in which they are computed:
//   - prefix + "fC.k" for each of the logNbInstances coordinates of the first challenge;
//   - then, for every wire that requires a proof, from the last wire to the first:
//     prefix + "wI.comb" if the wire has more than one claim, followed by
//     prefix + "wI.pSP.k" for each of the logNbInstances sumcheck rounds,
//     where I is the index of the wire in sorted.
func ChallengeNames(sorted []*gkrtypes.Wire, logNbInstances int, prefix string) []string {

	// Pre-compute the size TODO: Consider not doing this and just grow the list by appending
	total := logNbInstances // first challenge
	for _, w := range sorted {
		if w.NoProof() { // no proof, no challenge
			continue
		}
		total += logNbInstances // full run of sumcheck on logNbInstances variables
		if w.NbClaims() > 1 {   // combine the claims
			total++
		}
	}

	// decimal representations of all the indexes needed below
	decimals := make([]string, max(len(sorted), logNbInstances))
	for i := range decimals {
		decimals[i] = strconv.Itoa(i)
	}

	names := make([]string, total)

	// output wire claims
	for k := 0; k < logNbInstances; k++ {
		names[k] = prefix + "fC." + decimals[k]
	}

	next := logNbInstances // position of the next name to write
	for wI := len(sorted) - 1; wI >= 0; wI-- {
		w := sorted[wI]
		if w.NoProof() {
			continue
		}
		wirePrefix := prefix + "w" + decimals[wI] + "."

		if w.NbClaims() > 1 {
			names[next] = wirePrefix + "comb"
			next++
		}

		for k := 0; k < logNbInstances; k++ {
			names[next] = wirePrefix + "pSP." + decimals[k]
			next++
		}
	}
	return names
}

// getFirstChallengeNames returns the names of the logNbInstances coordinates
// of the first challenge: prefix + "fC.0", prefix + "fC.1", ...
func getFirstChallengeNames(logNbInstances int, prefix string) []string {
	names := make([]string, logNbInstances)
	for i := range names {
		names[i] = prefix + "fC." + strconv.Itoa(i)
	}
	return names
}

// getChallenges computes the challenges with the given names, in order, and
// returns them as field elements.
// It stops at the first challenge that cannot be computed and returns the error.
func getChallenges(transcript *fiatshamir.Transcript, names []string) ([]fr.Element, error) {
	challenges := make([]fr.Element, len(names))
	for i := range names {
		raw, err := transcript.ComputeChallenge(names[i])
		if err != nil {
			return nil, err
		}
		challenges[i].SetBytes(raw)
	}
	return challenges, nil
}

// Prove produces a proof of consistency of the claimed assignment.
// The wires are processed from the last to the first in topological order.
// Output wires receive a claim at the first challenge; each wire that requires
// a proof then gets a sumcheck proof of all the claims made about it, which in
// turn produces claims about its input wires.
// The final evaluations of each sumcheck are bound to the transcript of the
// next one.
// The worker pool is stopped before returning.
func Prove(c gkrtypes.Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (Proof, error) {
	o, err := setup(c, assignment, transcriptSettings, options...)
	if err != nil {
		return nil, err
	}
	defer o.workers.Stop()

	claims := newClaimsManager(o.sorted, assignment, o)
	proof := make(Proof, len(c))

	// firstChallenge called rho in the paper
	firstChallenge, err := getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix))
	if err != nil {
		return nil, err
	}

	var (
		wirePrefix    = o.transcriptPrefix + "w"
		baseChallenge [][]byte // final evaluations of the previous sumcheck
	)
	for i := len(c) - 1; i >= 0; i-- {
		wire := o.sorted[i]

		if wire.IsOutput() {
			atRho := assignment[i].Evaluate(firstChallenge, claims.memPool)
			claims.add(i, firstChallenge, atRho)
		}

		claim := claims.getClaim(i)
		if wire.NoProof() { // input wires with one claim only
			proof[i] = sumcheckProof{
				partialSumPolys: []polynomial.Polynomial{},
				finalEvalProof:  []fr.Element{},
			}
		} else {
			fsSettings := fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...)
			if proof[i], err = sumcheckProve(claim, fsSettings); err != nil {
				return proof, err
			}

			finalEvals := proof[i].finalEvalProof
			baseChallenge = make([][]byte, len(finalEvals))
			for j := range finalEvals {
				baseChallenge[j] = finalEvals[j].Marshal()
			}
		}
		// the verifier checks a single claim about input wires itself
		claims.deleteClaim(i)
	}

	return proof, nil
}

// Verify the consistency of the claimed output with the claimed input.
// Unlike in Prove, the assignment argument need not be complete.
//
// The wires are processed from the last to the first in topological order.
// Output wires receive a claim at the first challenge. For a wire that
// requires no proof, the proof entry must be empty and a single claim, if any,
// is checked directly against the assignment. For every other wire the
// sumcheck proof is verified, which produces claims about its input wires; the
// final evaluations are bound to the transcript of the next sumcheck.
//
// An error is returned if the proof does not have exactly one entry per wire
// of the circuit, if the settings cannot be set up, or if any check fails.
// A rejected sumcheck is reported with the underlying error wrapped.
// The worker pool is stopped before returning, unless the proof length check or
// setup fails.
func Verify(c gkrtypes.Circuit, assignment WireAssignment, proof Proof, transcriptSettings fiatshamir.Settings, options ...Option) error {
	// the proof is indexed by wire below; a malformed proof must be rejected, not cause a panic
	if len(proof) != len(c) {
		return fmt.Errorf("proof has %d entries, circuit has %d wires", len(proof), len(c))
	}

	o, err := setup(c, assignment, transcriptSettings, options...)
	if err != nil {
		return err
	}
	defer o.workers.Stop()

	claims := newClaimsManager(o.sorted, assignment, o)

	var firstChallenge []fr.Element
	firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix))
	if err != nil {
		return err
	}

	wirePrefix := o.transcriptPrefix + "w"
	var baseChallenge [][]byte // final evaluations of the previous sumcheck
	for i := len(c) - 1; i >= 0; i-- {
		wire := o.sorted[i]

		if wire.IsOutput() {
			claims.add(i, firstChallenge, assignment[i].Evaluate(firstChallenge, claims.memPool))
		}

		proofW := proof[i]
		claim := claims.getLazyClaim(i)
		if wire.NoProof() { // input wires with one claim only
			// make sure the proof is empty
			if len(proofW.finalEvalProof) != 0 || len(proofW.partialSumPolys) != 0 {
				return errors.New("no proof allowed for input wire with a single claim")
			}

			if wire.NbClaims() == 1 { // input wire
				// simply evaluate and see if it matches
				evaluation := assignment[i].Evaluate(claim.evaluationPoints[0], claims.memPool)
				if !claim.claimedEvaluations[0].Equal(&evaluation) {
					return errors.New("incorrect input wire claim")
				}
			}
		} else {
			err = sumcheckVerify(
				claim, proofW, fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...),
			)
			if err != nil {
				return fmt.Errorf("sumcheck proof rejected: %w", err) //TODO: Any polynomials to dump?
			}

			// incorporate prover claims about w's input into the transcript
			baseChallenge = make([][]byte, len(proofW.finalEvalProof))
			for j := range baseChallenge {
				baseChallenge[j] = proofW.finalEvalProof[j].Marshal()
			}
		}
		claims.deleteClaim(i)
	}
	return nil
}

// Complete the circuit evaluation from input values.
// The number of instances is taken from the assignment; every wire whose entry
// does not have that length gets a freshly allocated one. Then, for every
// instance, the value of each non-input wire is computed by evaluating its gate
// on the values of its input wires, following the order of wires.
// The assignment is modified in place and returned.
func (a WireAssignment) Complete(wires gkrtypes.Wires) WireAssignment {

	nbInstances := a.NumInstances()

	maxFanIn := 0
	for wI, w := range wires {
		maxFanIn = max(maxFanIn, len(w.Inputs))
		if len(a[wI]) != nbInstances {
			a[wI] = make([]fr.Element, nbInstances)
		}
	}

	// buffer for the gate inputs, large enough for any gate
	buf := make([]fr.Element, maxFanIn)
	for inst := 0; inst < nbInstances; inst++ {
		for wI, w := range wires {
			if w.IsInput() {
				continue
			}
			gateIns := buf[:len(w.Inputs)]
			for k, src := range w.Inputs {
				gateIns[k] = a[src][inst]
			}
			a[wI][inst].Set(api.evaluate(w.Gate.Evaluate, gateIns...))
		}
	}

	return a
}

// NumInstances returns the number of instances, measured on the first wire of
// the assignment. It panics if the assignment is empty.
func (a WireAssignment) NumInstances() int {
	if len(a) == 0 {
		panic("empty assignment")
	}
	return len(a[0])
}

// NumVars returns the number of variables of the first wire of the assignment.
// It panics if the assignment is empty.
func (a WireAssignment) NumVars() int {
	if len(a) == 0 {
		panic("empty assignment")
	}
	return a[0].NumVars()
}

// SerializeToBigInts flattens a proof object into the given slice of big.Ints;
// useful in gnark hints.
// For each entry of the proof, in order, the partial sum polynomials are
// written first, followed by the final evaluation proof.
//
// outs must contain exactly as many elements as the proof, and every one of
// them must be non-nil. The size is checked before anything is written, so an
// error is returned, and outs is left untouched, whenever the sizes differ.
func (p Proof) SerializeToBigInts(outs []*big.Int) error {
	// number of field elements in the proof
	expected := 0
	for i := range p {
		for _, poly := range p[i].partialSumPolys {
			expected += len(poly)
		}
		expected += len(p[i].finalEvalProof)
	}
	if expected != len(outs) {
		return fmt.Errorf("expected %d elements, got %d", expected, len(outs))
	}

	offset := 0
	for i := range p {
		for _, poly := range p[i].partialSumPolys {
			frToBigInts(outs[offset:], poly)
			offset += len(poly)
		}
		// a nil final evaluation proof has length zero, so nothing is written for it
		frToBigInts(outs[offset:], p[i].finalEvalProof)
		offset += len(p[i].finalEvalProof)
	}
	return nil
}

// frToBigInts writes src[i] into dst[i] for every element of src.
// dst must have at least as many entries as src.
func frToBigInts(dst []*big.Int, src []fr.Element) {
	for i := 0; i < len(src); i++ {
		src[i].BigInt(dst[i])
	}
}

// gateAPI implements gkr.GateAPI.
// It evaluates gates directly over field elements: its operations take
// frontend.Variable arguments, convert them with cast, and return a pointer to
// a field element.
type gateAPI struct{}

// api is the gateAPI instance passed to the gates evaluated in this package.
var api gateAPI

// Add returns i1 + i2 + in[0] + ... as a pointer to a new field element.
func (gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable {
	sum := new(fr.Element) // TODO Heap allocated. Keep an eye on perf
	sum.Add(cast(i1), cast(i2))
	for k := 0; k < len(in); k++ {
		sum.Add(sum, cast(in[k]))
	}
	return sum
}

// MulAcc returns a + b·c as a pointer to a new field element.
// The arguments are left untouched: gate inputs may point into data that is
// used again after the gate is evaluated (for instance the evaluations taken
// from a proof), so the result is never accumulated into a.
func (gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable {
	var res fr.Element
	res.Mul(cast(b), cast(c)) // b·c
	res.Add(&res, cast(a))    // a + b·c
	return &res
}

// Neg returns -i1 as a pointer to a new field element.
func (gateAPI) Neg(i1 frontend.Variable) frontend.Variable {
	negated := new(fr.Element)
	negated.Neg(cast(i1))
	return negated
}

// Sub returns i1 - i2 - in[0] - ... as a pointer to a new field element.
func (gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable {
	diff := new(fr.Element)
	diff.Sub(cast(i1), cast(i2))
	for k := 0; k < len(in); k++ {
		diff.Sub(diff, cast(in[k]))
	}
	return diff
}

// Mul returns i1 · i2 · in[0] · ... as a pointer to a new field element.
func (gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable {
	prod := new(fr.Element)
	prod.Mul(cast(i1), cast(i2))
	for k := 0; k < len(in); k++ {
		prod.Mul(prod, cast(in[k]))
	}
	return prod
}

// Println prints its arguments on one line, as fmt.Println does.
// Every argument that can be converted to a field element is printed as that
// element; any other string is printed as is.
// It panics if an argument is neither convertible to a field element nor a string.
func (gateAPI) Println(a ...frontend.Variable) {
	toPrint := make([]any, len(a))
	var x fr.Element

	for i, v := range a {
		_, err := x.SetInterface(v)
		if err == nil { // numeric value
			toPrint[i] = x.String()
			continue
		}
		s, ok := v.(string)
		if !ok {
			panic(fmt.Errorf("not numeric or string: %w", err))
		}
		toPrint[i] = s
	}
	fmt.Println(toPrint...)
}

// evaluate applies the gate function f to the given field elements and returns
// the result as a field element. The inputs are passed to f by pointer.
func (api gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element {
	vars := make([]frontend.Variable, 0, len(in))
	for i := range in {
		vars = append(vars, &in[i])
	}
	res := f(api, vars...)
	return res.(*fr.Element)
}

// gateFunctionFr is a gate function that operates directly on field elements.
type gateFunctionFr func(...fr.Element) *fr.Element

// convertFunc wraps f into a function that accepts and returns fr.Element,
// by evaluating f through api.evaluate.
func (api gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr {
	var wrapped gateFunctionFr = func(in ...fr.Element) *fr.Element {
		return api.evaluate(f, in...)
	}
	return wrapped
}

// cast converts v to a field element.
// A *fr.Element is returned as is, without copying. Any other value is
// converted into a new element through SetInterface; cast panics if that
// conversion fails.
func cast(v frontend.Variable) *fr.Element {
	// fast path, no extra heap allocation
	if e, isElement := v.(*fr.Element); isElement {
		return e
	}

	converted := new(fr.Element)
	_, err := converted.SetInterface(v)
	if err != nil {
		panic(err)
	}
	return converted
}
