// Copyright 2020-2025 Consensys Software Inc.
// Licensed under the Apache License, Version 2.0. See the LICENSE file for details.

// Code generated by gnark DO NOT EDIT

package gkr

import (
	"fmt"
	"hash"

	"github.com/consensys/gnark-crypto/ecc/bw6-761/fr/polynomial"
	fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir"
	"github.com/stretchr/testify/assert"

	"math/bits"

	"github.com/consensys/gnark-crypto/ecc/bw6-761/fr"

	"strings"
	"testing"
)

// testSumcheckSingleClaimMultilin runs a sumcheck round trip for the
// multilinear polynomial whose evaluations are polyInt.
//
// It returns an error if the honest proof is rejected, or if the verifier
// accepts the same proof after its first partial-sum coefficient has been
// incremented by one. The prover and each verification get their own fresh
// hash from hashGenerator.
func testSumcheckSingleClaimMultilin(polyInt []uint64, hashGenerator func() hash.Hash) error {
	poly := make(polynomial.MultiLin, len(polyInt))
	for i, n := range polyInt {
		poly[i].SetUint64(n)
	}

	// The prover's claim is given a clone so that whatever the prover does to
	// its copy of g cannot affect poly, which the verifier uses below.
	claim := singleMultilinClaim{g: poly.Clone()}

	proof, err := sumcheckProve(&claim, fiatshamir.WithHash(hashGenerator()))
	if err != nil {
		return fmt.Errorf("proving: %w", err)
	}

	// Render the honest proof now, before it is tampered with below, so that
	// failures can report what the prover actually sent.
	var sb strings.Builder
	for _, p := range proof.partialSumPolys {
		sb.WriteString("\t{")
		for i := range p {
			if i > 0 {
				sb.WriteString(", ")
			}
			sb.WriteString(p[i].String())
		}
		sb.WriteString("}\n")
	}
	proofStr := sb.String()

	lazyClaim := singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()}
	if err = sumcheckVerify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())); err != nil {
		return fmt.Errorf("honest proof rejected: %w\npartial sum polynomials:\n%s", err, proofStr)
	}

	// If the proof carries no partial-sum coefficients there is nothing to
	// corrupt; skip the negative check instead of panicking on the index.
	if len(proof.partialSumPolys) == 0 || len(proof.partialSumPolys[0]) == 0 {
		return nil
	}

	// Corrupt the first coefficient of the first round polynomial; a sound
	// verifier must now reject.
	proof.partialSumPolys[0][0].Add(&proof.partialSumPolys[0][0], toElement(1))
	lazyClaim = singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()}
	if sumcheckVerify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())) == nil {
		return fmt.Errorf("bad proof accepted\noriginal partial sum polynomials:\n%s", proofStr)
	}
	return nil
}

// TestSumcheckDeterministicHashSingleClaimMultilin runs the single-claim
// multilinear sumcheck round trip for a few small polynomials.
//
// Each polynomial is tried against every message-counter hash configuration
// with start state and step in [0, 4), except one known-unlucky configuration.
// The round trip checks that the honest proof is accepted and the tampered
// proof is rejected.
func TestSumcheckDeterministicHashSingleClaimMultilin(t *testing.T) {
	const (
		maxStep  = 4
		maxStart = 4
	)

	// Evaluation tables; X₁ indexes the upper/lower half of each table.
	polys := [][]uint64{
		{1, 2, 3, 4},             // 1 + 2X₁ + X₂
		{1, 2, 3, 4, 5, 6, 7, 8}, // 1 + 4X₁ + 2X₂ + X₃
		{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, // 1 + 8X₁ + 4X₂ + 2X₃ + X₄
	}

	hashGens := make([]func() hash.Hash, 0, maxStart*maxStep)
	for step := 0; step < maxStep; step++ {
		for startState := 0; startState < maxStart; startState++ {
			// With step 0 and start state 1, the tampered proof happens to be accepted.
			unlucky := step == 0 && startState == 1
			if !unlucky {
				hashGens = append(hashGens, newMessageCounterGenerator(startState, step))
			}
		}
	}

	for i := range polys {
		for j := range hashGens {
			err := testSumcheckSingleClaimMultilin(polys[i], hashGens[j])
			assert.NoError(t, err,
				"failed with poly %v and hashGen %v", polys[i], hashGens[j]())
		}
	}
}

// singleMultilinClaim is the prover-side view of a sumcheck claim about one
// multilinear polynomial g, stored as its table of evaluations.
//
// next calls Fold on g through a pointer receiver, so the claim's polynomial
// is presumably consumed as the protocol advances. Callers should hand it a
// copy.
type singleMultilinClaim struct {
	g polynomial.MultiLin
}

// proveFinalEval sends no extra data for the final evaluation.
func (c singleMultilinClaim) proveFinalEval(r []fr.Element) []fr.Element {
	// The verifier holds g and can evaluate it at r itself.
	return nil
}

// varsNum returns the number of variables, i.e. log2(len(g)); this assumes
// len(g) is a power of two.
func (c singleMultilinClaim) varsNum() int {
	size := uint(len(c.g))
	return bits.TrailingZeros(size)
}

// claimsNum reports that this claim concerns exactly one polynomial.
func (c singleMultilinClaim) claimsNum() int {
	const single = 1
	return single
}

// sumForX1One returns, as a one-coefficient polynomial, the sum of the upper
// half of g's evaluation table, i.e. the entries where the first variable X₁
// is 1.
func sumForX1One(g polynomial.MultiLin) polynomial.Polynomial {
	upper := g[len(g)/2:]
	acc := upper[0]
	for _, e := range upper[1:] {
		acc.Add(&acc, &e)
	}
	return polynomial.Polynomial{acc}
}

// combine returns the first round message; with a single claim the
// combination coefficient is irrelevant and ignored.
func (c singleMultilinClaim) combine(fr.Element) polynomial.Polynomial {
	return sumForX1One(c.g)
}

// next binds the current first variable to the challenge r by folding g,
// then returns the message for the following round.
func (c *singleMultilinClaim) next(r fr.Element) polynomial.Polynomial {
	g := &c.g
	g.Fold(r)
	return sumForX1One(*g)
}

// singleMultilinLazyClaim is the verifier-side view of a single-polynomial
// sumcheck claim.
//
// The verifier holds the multilinear polynomial g in full, together with the
// sum claimedSum that the proof is checked against.
type singleMultilinLazyClaim struct {
	g          polynomial.MultiLin
	claimedSum fr.Element
}

// verifyFinalEval checks that purportedValue equals g evaluated at the
// challenge point r.
//
// combinationCoeff and proof are not used: there is a single claim, and the
// verifier evaluates g directly rather than relying on auxiliary proof data.
// On failure, the error reports both values to make a failing case
// diagnosable.
func (c singleMultilinLazyClaim) verifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof []fr.Element) error {
	val := c.g.Evaluate(r, nil)
	if val.Equal(&purportedValue) {
		return nil
	}
	return fmt.Errorf("mismatch: g(r) = %s, purported value %s", val.String(), purportedValue.String())
}

// combinedSum returns the claimed sum; with a single claim there is nothing
// to combine, so the coefficient is ignored.
func (c singleMultilinLazyClaim) combinedSum(combinationCoeffs fr.Element) fr.Element {
	return c.claimedSum
}

// degree returns 1 for every variable, since g is multilinear.
func (c singleMultilinLazyClaim) degree(i int) int {
	return 1
}

// claimsNum reports that this claim concerns exactly one polynomial.
func (c singleMultilinLazyClaim) claimsNum() int {
	return 1
}

// varsNum returns log2(len(g)); this assumes len(g) is a power of two.
func (c singleMultilinLazyClaim) varsNum() int {
	return bits.TrailingZeros(uint(len(c.g)))
}
