// Copyright 2020-2025 Consensys Software Inc.
// Licensed under the Apache License, Version 2.0. See the LICENSE file for details.

// Code generated by gnark DO NOT EDIT

package gkr

import (
	"fmt"
	"hash"
	"os"
	"path/filepath"
	"reflect"
	"strconv"
	"testing"
	"time"

	"github.com/consensys/gnark-crypto/ecc/bls12-381/fr"
	"github.com/consensys/gnark-crypto/ecc/bls12-381/fr/mimc"
	"github.com/consensys/gnark-crypto/ecc/bls12-381/fr/polynomial"
	fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir"
	gcUtils "github.com/consensys/gnark-crypto/utils"
	"github.com/consensys/gnark/frontend"
	"github.com/consensys/gnark/internal/gkr/gkrtesting"
	"github.com/consensys/gnark/internal/gkr/gkrtypes"
	"github.com/consensys/gnark/internal/utils"
	"github.com/consensys/gnark/std/gkrapi/gkr"
	"github.com/stretchr/testify/assert"
)

// TestNoGateTwoInstances proves and verifies a gate-free circuit (a single
// input wire) assigned exactly two values.
func TestNoGateTwoInstances(t *testing.T) {
	// A single instance cannot be tested: the sumcheck implementation does not
	// cover the trivial 0-variate case.
	inputs := []fr.Element{four, three}
	testNoGate(t, inputs)
}

// TestNoGate runs the randomized prove/verify round trip on a circuit made of
// one input wire and no gates.
func TestNoGate(t *testing.T) {
	circuit := gkrtypes.Circuit{{}}
	test(t, circuit)
}

// TestSingleAddGate tests a circuit with two input wires feeding one Add2 gate.
func TestSingleAddGate(t *testing.T) {
	sum := gkrtypes.Wire{Gate: gkrtypes.Add2(), Inputs: []int{0, 1}}
	test(t, gkrtypes.Circuit{{}, {}, sum})
}

// TestSingleMulGate tests a circuit with two input wires feeding one Mul2 gate.
func TestSingleMulGate(t *testing.T) {
	product := gkrtypes.Wire{Gate: gkrtypes.Mul2(), Inputs: []int{0, 1}}
	test(t, gkrtypes.Circuit{{}, {}, product})
}

// TestSingleInputTwoIdentityGates tests one input wire (0) feeding two
// independent identity gates.
func TestSingleInputTwoIdentityGates(t *testing.T) {
	identityOf := func(in int) gkrtypes.Wire {
		return gkrtypes.Wire{Gate: gkrtypes.Identity(), Inputs: []int{in}}
	}
	test(t, gkrtypes.Circuit{{}, identityOf(0), identityOf(0)})
}

// TestSingleInputTwoIdentityGatesComposed tests two identity gates chained one
// after the other: wire 1 copies input wire 0, and wire 2 copies wire 1.
func TestSingleInputTwoIdentityGatesComposed(t *testing.T) {
	identityOf := func(in int) gkrtypes.Wire {
		return gkrtypes.Wire{Gate: gkrtypes.Identity(), Inputs: []int{in}}
	}
	test(t, gkrtypes.Circuit{{}, identityOf(0), identityOf(1)})
}

// TestAPowNTimesBCircuit tests a chain of N Mul2 gates. Wire 0 and wire 1 are
// inputs, and each following gate combines the previous wire with wire 0.
func TestAPowNTimesBCircuit(t *testing.T) {
	const N = 10

	circuit := make(gkrtypes.Circuit, 0, N+2)
	circuit = append(circuit, gkrtypes.Wire{}, gkrtypes.Wire{}) // the two input wires
	for prev := 1; prev <= N; prev++ {
		circuit = append(circuit, gkrtypes.Wire{Gate: gkrtypes.Mul2(), Inputs: []int{prev, 0}})
	}

	test(t, circuit)
}

func TestSingleMimcCipherGate(t *testing.T) {
	test(t, gkrtypes.Circuit{
		{}, {},
		{
			Inputs: []int{0, 1},
			Gate:   cache.GetGate("mimc"),
		},
	})
}

func TestShallowMimcTwoInstances(t *testing.T) {
	test(t, mimcCircuit(2))
}

func TestMimc(t *testing.T) {
	test(t, mimcCircuit(93))
}

// TestSumcheckFromSingleInputTwoIdentityGatesGateTwoInstances drives the sumcheck
// prover and verifier directly (without Prove/Verify) on two claims about a
// single wire assigned [2, 3].
func TestSumcheckFromSingleInputTwoIdentityGatesGateTwoInstances(t *testing.T) {
	circuit := gkrtypes.Circuit{gkrtypes.Wire{
		Gate:            gkrtypes.Identity(),
		NbUniqueOutputs: 2,
	}}

	assignment := WireAssignment{[]fr.Element{two, three}}
	var o settings
	pool := polynomial.NewPool(256, 1<<11)
	workers := gcUtils.NewWorkerPool()
	defer workers.Stop() // don't leak the pool's goroutines past the end of the test
	o.pool = &pool
	o.workers = workers

	// claimsManagerGen builds a separate manager for the prover and the verifier.
	// The add arguments are presumably (wire, evaluation point, claimed value).
	// The multilinear extension of [2, 3] is 2 + x, which equals 5 at x = 3 and
	// 6 at x = 4.
	claimsManagerGen := func() *claimsManager {
		manager := newClaimsManager(utils.References(circuit), assignment, o)
		manager.add(0, []fr.Element{three}, five)
		manager.add(0, []fr.Element{four}, six)
		return &manager
	}

	// Each side draws its own transcript from the same generator.
	transcriptGen := newMessageCounterGenerator(4, 1)

	proof, err := sumcheckProve(claimsManagerGen().getClaim(0), fiatshamir.WithHash(transcriptGen(), nil))
	if !assert.NoError(t, err) {
		return
	}
	err = sumcheckVerify(claimsManagerGen().getLazyClaim(0), proof, fiatshamir.WithHash(transcriptGen(), nil))
	assert.NoError(t, err)
}

// Small field constants (1 through 6) shared by the tests in this file.
var one, two, three, four, five, six fr.Element

func init() {
	for i, e := range []*fr.Element{&one, &two, &three, &four, &five, &six} {
		e.SetUint64(uint64(i + 1))
	}
}

// testManyInstancesLogMaxInstances caches the result of getLogMaxInstances;
// -1 means "not read yet".
var testManyInstancesLogMaxInstances = -1

// getLogMaxInstances returns log2 of the largest number of instances the
// randomized circuit tests use. The value comes from the GKR_LOG_INSTANCES
// environment variable, defaults to 5, and is cached after the first
// successful read.
//
// Callers always run with at least 2 instances, so values below 1 are rejected.
// A malformed or out-of-range value fails the calling test immediately.
// Falling back to 0 instead would make callers slice past the end of their
// input vectors.
func getLogMaxInstances(t *testing.T) int {
	if testManyInstancesLogMaxInstances != -1 {
		return testManyInstancesLogMaxInstances
	}

	logMax := 5
	if s := os.Getenv("GKR_LOG_INSTANCES"); s != "" {
		v, err := strconv.Atoi(s)
		if err != nil {
			t.Fatalf("parsing GKR_LOG_INSTANCES=%q: %v", s, err)
		}
		if v < 1 {
			t.Fatalf("GKR_LOG_INSTANCES=%d: must be at least 1", v)
		}
		logMax = v
	}

	testManyInstancesLogMaxInstances = logMax
	return logMax
}

// test runs an end-to-end prove/verify round trip on circuit.
//
// It draws random values for every input wire. Then, for two instance
// counts — 2 and 2^getLogMaxInstances(t) — it completes the wire assignment,
// proves it, and checks that:
//   - the proof verifies under an identically constructed transcript, and
//   - unless the proof carries no data, verification fails under a transcript
//     built with different newMessageCounter parameters.
func test(t *testing.T, circuit gkrtypes.Circuit) {
	logMaxSize := getLogMaxInstances(t)
	if logMaxSize < 1 {
		// The loop below always runs with 2 instances, so the maximum must be at
		// least 2. This also avoids a panic from a negative shift.
		t.Fatalf("log2 of the maximum instance count must be at least 1, got %d", logMaxSize)
	}
	maxSize := 1 << logMaxSize

	wireRefs := utils.References(circuit)
	ins := circuit.Inputs()
	insAssignment := make(WireAssignment, len(ins))
	for i := range ins {
		insAssignment[i] = make([]fr.Element, maxSize)
		fr.Vector(insAssignment[i]).MustSetRandom()
	}

	fullAssignment := make(WireAssignment, len(circuit))
	for _, numEvals := range []int{2, maxSize} {
		// Use a prefix of the random inputs; Complete derives the remaining wires.
		for i := range ins {
			fullAssignment[ins[i]] = insAssignment[i][:numEvals]
		}

		fullAssignment.Complete(wireRefs)

		t.Log("Selected inputs for test")

		proof, err := Prove(circuit, fullAssignment, fiatshamir.WithHash(newMessageCounter(1, 1)))
		assert.NoError(t, err)

		err = Verify(circuit, fullAssignment, proof, fiatshamir.WithHash(newMessageCounter(1, 1)))
		assert.NoError(t, err, "proof rejected")

		if proof.isEmpty() { // special case for TestNoGate:
			continue // there's no way to make a trivial proof fail
		}

		err = Verify(circuit, fullAssignment, proof, fiatshamir.WithHash(newMessageCounter(0, 1)))
		assert.Error(t, err, "bad proof accepted")
	}
}

// isEmpty reports whether p carries no data at all. That is the case when every
// sumcheck proof in it has an empty final evaluation proof and only empty
// partial-sum polynomials.
func (p Proof) isEmpty() bool {
	for _, sc := range p {
		if len(sc.finalEvalProof) > 0 {
			return false
		}
		for _, poly := range sc.partialSumPolys {
			if len(poly) > 0 {
				return false
			}
		}
	}
	return true
}

// testNoGate proves and verifies a circuit consisting of a single input wire,
// assigned inputAssignments[0]. Any further arguments are ignored.
func testNoGate(t *testing.T, inputAssignments ...[]fr.Element) {
	circuit := gkrtypes.Circuit{{}}
	assignment := WireAssignment{inputAssignments[0]}

	proof, err := Prove(circuit, assignment, fiatshamir.WithHash(newMessageCounter(1, 1)))
	assert.NoError(t, err)

	// Even though a hash is called here, the proof is empty.
	err = Verify(circuit, assignment, proof, fiatshamir.WithHash(newMessageCounter(1, 1)))
	assert.NoError(t, err, "proof rejected")
}

// mimcCircuit returns a circuit of numRounds+2 wires. Wires 0 and 1 are
// inputs. Each later wire is a "mimc" gate whose inputs are the preceding wire
// followed by wire 0.
func mimcCircuit(numRounds int) gkrtypes.Circuit {
	circuit := make(gkrtypes.Circuit, numRounds+2)

	for prev := 1; prev+1 < len(circuit); prev++ {
		circuit[prev+1] = gkrtypes.Wire{
			Inputs: []int{prev, 0},
			Gate:   cache.GetGate("mimc"),
		}
	}
	return circuit
}

// TestIsAdditive checks IsGateFunctionAdditive against gates whose additivity in
// each input is known. Judging by the assertions,
// IsGateFunctionAdditive(fn, i, nbIn) reports whether the nbIn-ary gate fn is
// additive in its i-th input.
func TestIsAdditive(t *testing.T) {

	// f: x,y -> x² + xy, computed as (x+y)·x. It is additive in neither input.
	f := func(api gkr.GateAPI, x ...frontend.Variable) frontend.Variable {
		if len(x) != 2 {
			panic("bivariate input needed")
		}
		res := api.Add(x[0], x[1])
		return api.Mul(res, x[0])
	}

	// g: x,y -> x² + 3y. It is additive in y but not in x.
	g := func(api gkr.GateAPI, x ...frontend.Variable) frontend.Variable {
		res := api.Mul(x[0], x[0])
		y3 := api.Mul(x[1], 3)
		return api.Add(res, y3)
	}

	// h: x -> 2x, computed as x + x. It is additive in its only input.
	h := func(api gkr.GateAPI, x ...frontend.Variable) frontend.Variable {
		return api.Add(x[0], x[0])
	}

	assert.False(t, IsGateFunctionAdditive(f, 1, 2))
	assert.False(t, IsGateFunctionAdditive(f, 0, 2))

	assert.False(t, IsGateFunctionAdditive(g, 0, 2))
	assert.True(t, IsGateFunctionAdditive(g, 1, 2))

	assert.True(t, IsGateFunctionAdditive(h, 0, 1))
}

// generateTestProver returns a subtest that loads the test vector at path,
// proves its circuit on the stored full assignment using the vector's hash, and
// checks that the result equals the stored proof.
func generateTestProver(path string) func(t *testing.T) {
	return func(t *testing.T) {
		testCase, err := newTestCase(path)
		if !assert.NoError(t, err) {
			return // testCase is nil on error; dereferencing it would panic
		}
		proof, err := Prove(testCase.Circuit, testCase.FullAssignment, fiatshamir.WithHash(testCase.Hash))
		if !assert.NoError(t, err) {
			return
		}
		assert.NoError(t, proofEquals(testCase.Proof, proof))
	}
}

// generateTestVerifier returns a subtest that loads the test vector at path and
// checks two things. First, its stored proof verifies against the vector's
// inputs/outputs with the vector's hash. Second, the same proof is rejected
// under a newMessageCounter(2, 0) transcript.
func generateTestVerifier(path string) func(t *testing.T) {
	return func(t *testing.T) {
		testCase, err := newTestCase(path)
		if !assert.NoError(t, err) {
			return // testCase is nil on error; dereferencing it would panic
		}
		err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(testCase.Hash))
		assert.NoError(t, err, "proof rejected")

		// newTestCase memoizes by path, so reloading here would return this same
		// *TestCase. Reuse it directly.
		err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(newMessageCounter(2, 0)))
		assert.Error(t, err, "bad proof accepted")
	}
}

// TestGkrVectors registers a prover subtest and a verifier subtest for every
// .json file directly inside ../test_vectors/. Each pair is named after the
// file without its extension.
func TestGkrVectors(t *testing.T) {
	const testDirPath = "../test_vectors/"

	entries, err := os.ReadDir(testDirPath)
	assert.NoError(t, err)

	for _, entry := range entries {
		name := entry.Name()
		if entry.IsDir() || filepath.Ext(name) != ".json" {
			continue
		}

		path := filepath.Join(testDirPath, name)
		base := name[:len(name)-len(".json")]

		t.Run(base+"_prover", generateTestProver(path))
		t.Run(base+"_verifier", generateTestVerifier(path))
	}
}

// proofEquals compares seen against expected. It returns nil when they match,
// and otherwise an error describing the first difference found.
//
// A nil final evaluation proof in seen is accepted when the expected one is
// empty.
func proofEquals(expected Proof, seen Proof) error {
	if len(expected) != len(seen) {
		return fmt.Errorf("length mismatch %d ≠ %d", len(expected), len(seen))
	}
	for i, x := range expected {
		xSeen := seen[i]

		if xSeen.finalEvalProof == nil {
			if expectedFinalEval := x.finalEvalProof; len(expectedFinalEval) != 0 {
				// Same order as above: expected length first, then seen length.
				return fmt.Errorf("sumcheck proof %d: final evaluation proof length mismatch %d ≠ %d", i, len(expectedFinalEval), 0)
			}
		} else if err := sliceEquals(x.finalEvalProof, xSeen.finalEvalProof); err != nil {
			return fmt.Errorf("sumcheck proof %d: final evaluation proof mismatch: %w", i, err)
		}

		if err := polynomialSliceEquals(x.partialSumPolys, xSeen.partialSumPolys); err != nil {
			return fmt.Errorf("sumcheck proof %d: partial sum polynomials mismatch: %w", i, err)
		}
	}
	return nil
}

// benchmarkGkrMiMC builds a MiMC circuit of depth mimcDepth and gives its two
// inputs random values over nbInstances instances. It then solves the circuit
// and proves it once, printing how long solving and proving took in
// microseconds.
//
// NOTE(review): the body does not loop over b.N, and the timer is never reset.
// The framework-reported ns/op therefore covers one full run, including setup
// and solving. TODO confirm whether that is intended.
func benchmarkGkrMiMC(b *testing.B, nbInstances, mimcDepth int) {
	fmt.Println("creating circuit structure")
	circuit := mimcCircuit(mimcDepth)

	x := make([]fr.Element, nbInstances)
	y := make([]fr.Element, nbInstances)
	fr.Vector(x).MustSetRandom()
	fr.Vector(y).MustSetRandom()

	fmt.Println("evaluating circuit")
	phaseStart := time.Now()
	assignment := WireAssignment{x, y}.Complete(utils.References(circuit))
	fmt.Println("solved in", time.Since(phaseStart).Microseconds(), "μs")

	fmt.Println("constructing proof")
	phaseStart = time.Now()
	_, err := Prove(circuit, assignment, fiatshamir.WithHash(mimc.NewMiMC()))
	fmt.Println("proved in", time.Since(phaseStart).Microseconds(), "μs")
	assert.NoError(b, err)
}

// gkrMimcBenchmarkDepth is the number of MiMC rounds in the benchmarked circuit.
const gkrMimcBenchmarkDepth = 91

// BenchmarkGkrMimc19 benchmarks the MiMC circuit over 2^19 instances.
func BenchmarkGkrMimc19(b *testing.B) {
	const nbInstances = 1 << 19
	benchmarkGkrMiMC(b, nbInstances, gkrMimcBenchmarkDepth)
}

// BenchmarkGkrMimc17 benchmarks the MiMC circuit over 2^17 instances.
func BenchmarkGkrMimc17(b *testing.B) {
	const nbInstances = 1 << 17
	benchmarkGkrMiMC(b, nbInstances, gkrMimcBenchmarkDepth)
}

// unmarshalProof converts a proof decoded from a JSON test vector into a Proof.
//
// For each sumcheck proof it converts the optional final evaluation proof and
// every partial-sum polynomial into field elements. The final evaluation proof
// may be any slice or array value and is read via reflection. It returns an
// error if a value cannot be converted, or if the final evaluation proof is not
// a list.
func unmarshalProof(printable gkrtesting.PrintableProof) (Proof, error) {
	proof := make(Proof, len(printable))
	for i := range printable {
		var finalEvalProof []fr.Element

		if printable[i].FinalEvalProof != nil {
			finalEvalSlice := reflect.ValueOf(printable[i].FinalEvalProof)
			// reflect.Value.Len panics on non-list kinds; report malformed vectors
			// as errors instead of crashing the test binary.
			if kind := finalEvalSlice.Kind(); kind != reflect.Slice && kind != reflect.Array {
				return nil, fmt.Errorf("sumcheck proof %d: final evaluation proof must be a list, got %T", i, printable[i].FinalEvalProof)
			}
			finalEvalProof = make([]fr.Element, finalEvalSlice.Len())
			for k := range finalEvalProof {
				if _, err := setElement(&finalEvalProof[k], finalEvalSlice.Index(k).Interface()); err != nil {
					return nil, fmt.Errorf("sumcheck proof %d, final evaluation %d: %w", i, k, err)
				}
			}
		}

		proof[i] = sumcheckProof{
			partialSumPolys: make([]polynomial.Polynomial, len(printable[i].PartialSumPolys)),
			finalEvalProof:  finalEvalProof,
		}
		for k := range printable[i].PartialSumPolys {
			var err error
			if proof[i].partialSumPolys[k], err = sliceToElementSlice(printable[i].PartialSumPolys[k]); err != nil {
				return nil, fmt.Errorf("sumcheck proof %d, partial sum polynomial %d: %w", i, k, err)
			}
		}
	}
	return proof, nil
}

// TestCase is a parsed GKR test vector, as built by newTestCase.
type TestCase struct {
	// Circuit is the circuit the vector's proof is about.
	Circuit gkrtypes.Circuit
	// Hash is used to build the Fiat-Shamir transcript for proving and verifying.
	Hash hash.Hash
	// Proof is the expected proof stored in the vector.
	Proof Proof
	// FullAssignment holds values for every wire, obtained by completing the
	// vector's inputs.
	FullAssignment WireAssignment
	// InOutAssignment holds only the input and output wire values read from the
	// vector; entries for all other wires are nil.
	InOutAssignment WireAssignment
}

// testCases memoizes successfully parsed test vectors, keyed by absolute path.
var testCases = make(map[string]*TestCase)

// cache provides the gates, circuits and test-vector metadata shared by the
// tests in this file.
var cache = gkrtesting.NewCache()

// newTestCase loads and validates the JSON test vector at path.
//
// The path is made absolute and used as the memoization key in testCases, so
// repeated calls for the same file return the same *TestCase. Files the vector
// references (its circuit) are resolved relative to the vector's directory.
//
// The vector's inputs and outputs are assigned to the circuit's input and
// output wires in topological order. The full assignment is then completed
// from the inputs, and its outputs are checked against the vector's. On any
// error, a nil *TestCase is returned and nothing is cached. Extra inputs or
// outputs in the vector beyond what the circuit needs are not detected.
func newTestCase(path string) (*TestCase, error) {
	path, err := filepath.Abs(path)
	if err != nil {
		return nil, err
	}
	dir := filepath.Dir(path)

	if tCase, ok := testCases[path]; ok {
		return tCase, nil
	}

	info, err := cache.ReadTestCaseInfo(path)
	if err != nil {
		return nil, err
	}

	circuit := cache.GetCircuit(filepath.Join(dir, info.Circuit))
	var _hash hash.Hash
	if _hash, err = hashFromDescription(info.Hash); err != nil {
		return nil, err
	}
	var proof Proof
	if proof, err = unmarshalProof(info.Proof); err != nil {
		return nil, err
	}

	fullAssignment := make(WireAssignment, len(circuit))
	inOutAssignment := make(WireAssignment, len(circuit))

	sorted := circuit.TopologicalSort()

	// NOTE(review): i is a position in the topologically sorted wire list, but it
	// indexes the assignments, which elsewhere are indexed by wire. This is only
	// correct if TopologicalSort preserves circuit order. TODO confirm.
	inI, outI := 0, 0
	for i, w := range sorted {
		var assignmentRaw []interface{}
		if w.IsInput() {
			if inI == len(info.Input) {
				return nil, fmt.Errorf("fewer input in vector than in circuit")
			}
			assignmentRaw = info.Input[inI]
			inI++
		} else if w.IsOutput() {
			if outI == len(info.Output) {
				return nil, fmt.Errorf("fewer output in vector than in circuit")
			}
			assignmentRaw = info.Output[outI]
			outI++
		}
		if assignmentRaw != nil {
			var wireAssignment []fr.Element
			if wireAssignment, err = sliceToElementSlice(assignmentRaw); err != nil {
				return nil, err
			}

			inOutAssignment[i] = wireAssignment
			// Give fullAssignment its own copy. If Complete writes output wires in
			// place, a shared slice would overwrite the vector's values in
			// inOutAssignment and make the output check below compare a slice with
			// itself.
			fullAssignment[i] = make([]fr.Element, len(wireAssignment))
			copy(fullAssignment[i], wireAssignment)
		}
	}

	fullAssignment.Complete(utils.References(circuit))

	// The outputs recorded in the vector must match those computed from its inputs.
	for i, w := range sorted {
		if w.IsOutput() {
			if err = sliceEquals(inOutAssignment[i], fullAssignment[i]); err != nil {
				return nil, fmt.Errorf("assignment mismatch: %w", err)
			}
		}
	}

	tCase := &TestCase{
		FullAssignment:  fullAssignment,
		InOutAssignment: inOutAssignment,
		Proof:           proof,
		Hash:            _hash,
		Circuit:         circuit,
	}

	testCases[path] = tCase

	return tCase, nil
}
