diff --git a/go.mod b/go.mod index 872c9f0565..8ae200dc61 100644 --- a/go.mod +++ b/go.mod @@ -5,13 +5,16 @@ go 1.17 require ( github.com/consensys/bavard v0.1.8-0.20210915155054-088da2f7f54a github.com/consensys/gnark-crypto v0.6.1-0.20220203133229-a70fdc7da969 - github.com/davecgh/go-spew v1.1.1 // indirect github.com/fxamacker/cbor/v2 v2.2.0 - github.com/kr/pretty v0.2.0 // indirect github.com/leanovate/gopter v0.2.9 + github.com/stretchr/testify v1.7.0 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/kr/pretty v0.2.0 // indirect github.com/mmcloughlin/addchain v0.4.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/stretchr/testify v1.7.0 github.com/x448/float16 v0.8.4 // indirect golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2 // indirect golang.org/x/sys v0.0.0-20210420205809-ac73e9fd8988 // indirect diff --git a/go.sum b/go.sum index c436a373c3..6b59c42b75 100644 --- a/go.sum +++ b/go.sum @@ -1,15 +1,5 @@ github.com/consensys/bavard v0.1.8-0.20210915155054-088da2f7f54a h1:AEpwbXTjBGKoqxuQ6QAcBMEuK0+PtajQj0wJkhTnSd0= github.com/consensys/bavard v0.1.8-0.20210915155054-088da2f7f54a/go.mod h1:9ItSMtA/dXMAiL7BG6bqW2m3NdSEObYWoH223nGHukI= -github.com/consensys/gnark-crypto v0.5.4-0.20211222202820-aee0c136fb9f h1:HT4hl58/L66zdhJi8wEbdoXceHv9AnIJij5lP1iOuQw= -github.com/consensys/gnark-crypto v0.5.4-0.20211222202820-aee0c136fb9f/go.mod h1:PicAZJP763+7N9LZFfj+MquTXq98pwjD6l8Ry8WdHSU= -github.com/consensys/gnark-crypto v0.6.1-0.20220202084753-13a5294ab744 h1:P9zNirCiOG0QJfnicNSe+MTRZpRzE6OMwkjiNtm9D4c= -github.com/consensys/gnark-crypto v0.6.1-0.20220202084753-13a5294ab744/go.mod h1:PicAZJP763+7N9LZFfj+MquTXq98pwjD6l8Ry8WdHSU= -github.com/consensys/gnark-crypto v0.6.1-0.20220202090557-876cf1c9d281 h1:GIuvjPX6WL3NRDAEtw3/9puWLxUKIB652ODqtMUL5NM= -github.com/consensys/gnark-crypto v0.6.1-0.20220202090557-876cf1c9d281/go.mod h1:PicAZJP763+7N9LZFfj+MquTXq98pwjD6l8Ry8WdHSU= -github.com/consensys/gnark-crypto v0.6.1-0.20220202101857-f9e35f654649 h1:xsWITMlfdkO4EBgyvns9iSi1qM3KFh6xpVZSiqOzPsc= -github.com/consensys/gnark-crypto v0.6.1-0.20220202101857-f9e35f654649/go.mod h1:PicAZJP763+7N9LZFfj+MquTXq98pwjD6l8Ry8WdHSU= -github.com/consensys/gnark-crypto v0.6.1-0.20220202113516-032351c35ce0 h1:XPqm3EsyhGaKasVKmjDM0qry0oEXbxryAZxTfaeURro= -github.com/consensys/gnark-crypto v0.6.1-0.20220202113516-032351c35ce0/go.mod h1:PicAZJP763+7N9LZFfj+MquTXq98pwjD6l8Ry8WdHSU= github.com/consensys/gnark-crypto v0.6.1-0.20220203133229-a70fdc7da969 h1:SPKwScbSTdl2p+QvJulHumGa2+5FO6RPh857TCPxda0= github.com/consensys/gnark-crypto v0.6.1-0.20220203133229-a70fdc7da969/go.mod h1:PicAZJP763+7N9LZFfj+MquTXq98pwjD6l8Ry8WdHSU= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= diff --git a/internal/backend/bls12-377/cs/r1cs.go b/internal/backend/bls12-377/cs/r1cs.go index 35c2a18447..faf416dce3 100644 --- a/internal/backend/bls12-377/cs/r1cs.go +++ b/internal/backend/bls12-377/cs/r1cs.go @@ -96,6 +96,7 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverConfig) ( // check if there is an inconsistant constraint var check fr.Element + var solved bool // for each constraint // we are guaranteed that each R1C contains at most one unsolved wire @@ -104,7 +105,8 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverConfig) ( // if a[i] * b[i] != c[i]; it means the constraint is not satisfied for i := 0; i < len(cs.Constraints); i++ { // solve the constraint, this will compute the missing wire of the gate - if err := cs.solveConstraint(cs.Constraints[i], &solution); err != nil { + solved, a[i], b[i], c[i], err = cs.solveConstraint(cs.Constraints[i], &solution) + if err != nil { if dID, ok := cs.MDebug[i]; ok { debugInfoStr := solution.logValue(cs.DebugInfo[dID]) return solution.values, fmt.Errorf("%w: %s", err, debugInfoStr) @@ -112,8 +114,10 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverConfig) ( return solution.values, err } - // compute values for the R1C (ie value * coeff) - a[i], b[i], c[i] = cs.instantiateR1C(cs.Constraints[i], &solution) + if solved { + // a[i] * b[i] == c[i], since we just computed it. + continue + } // ensure a[i] * b[i] == c[i] check.Mul(&a[i], &b[i]) @@ -150,8 +154,8 @@ func (cs *R1CS) IsSolved(witness *witness.Witness, opts ...backend.ProverOption) return err } -// mulByCoeff sets res = res * t.Coeff -func (cs *R1CS) mulByCoeff(res *fr.Element, t compiled.Term) { +// divByCoeff sets res = res / t.Coeff +func (cs *R1CS) divByCoeff(res *fr.Element, t compiled.Term) { cID := t.CoeffID() switch cID { case compiled.CoeffIdOne: @@ -159,106 +163,79 @@ func (cs *R1CS) mulByCoeff(res *fr.Element, t compiled.Term) { case compiled.CoeffIdMinusOne: res.Neg(res) case compiled.CoeffIdZero: - res.SetZero() - case compiled.CoeffIdTwo: - res.Double(res) + panic("division by 0") default: - res.Mul(res, &cs.Coefficients[cID]) - } -} - -// compute left, right, o part of a cs constraint -// this function is called when all the wires have been computed -// it instantiates the l, r o part of a R1C -func (cs *R1CS) instantiateR1C(r compiled.R1C, solution *solution) (a, b, c fr.Element) { - var v fr.Element - for _, t := range r.L.LinExp { - v = solution.computeTerm(t) - a.Add(&a, &v) - } - for _, t := range r.R.LinExp { - v = solution.computeTerm(t) - b.Add(&b, &v) - } - for _, t := range r.O.LinExp { - v = solution.computeTerm(t) - c.Add(&c, &v) + // this is slow, but shouldn't happen as divByCoeff is called to + // remove the coeff of an unsolved wire + // but unsolved wires are (in gnark frontend) systematically set with a coeff == 1 or -1 + res.Div(res, &cs.Coefficients[cID]) } - return } -// solveR1c computes a wire by solving a cs -// the function searches for the unset wire (either the unset wire is -// alone, or it can be computed without ambiguity using the other computed wires -// , eg when doing a binary decomposition: either way the missing wire can -// be computed without ambiguity because the cs is correctly ordered) +// solveConstraint compute unsolved wires in the constraint, if any and set the solution accordingly // -// It returns the 1 if the the position to solve is in the quadratic part (it -// means that there is a division and serves to navigate in the log info for the -// computational constraints), and 0 otherwise. -func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) error { +// returns an error if the solver called a hint function that errored +// returns false, nil if there was no wire to solve +// returns true, nil if exactly one wire was solved. In that case, it is redundant to check that +// the constraint is satisfied later. +func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (solved bool, a, b, c fr.Element, err error) { // the index of the non zero entry shows if L, R or O has an uninstantiated wire // the content is the ID of the wire non instantiated var loc uint8 - var a, b, c fr.Element var termToCompute compiled.Term - processTerm := func(t compiled.Term, val *fr.Element, locValue uint8) error { - vID := t.WireID() + processLExp := func(l compiled.LinearExpression, val *fr.Element, locValue uint8) error { + for _, t := range l { + vID := t.WireID() - // wire is already computed, we just accumulate in val - if solution.solved[vID] { - v := solution.computeTerm(t) - val.Add(val, &v) - return nil - } + // wire is already computed, we just accumulate in val + if solution.solved[vID] { + solution.accumulateInto(t, val) + continue + } - // first we check if this is a hint wire - if hint, ok := cs.MHints[vID]; ok { - if err := solution.solveWithHint(vID, hint); err != nil { - return err + // first we check if this is a hint wire + if hint, ok := cs.MHints[vID]; ok { + if err := solution.solveWithHint(vID, hint); err != nil { + return err + } + // now that the wire is saved, accumulate it into a, b or c + solution.accumulateInto(t, val) + continue } - v := solution.computeTerm(t) - val.Add(val, &v) - return nil - } - if loc != 0 { - panic("found more than one wire to instantiate") + if loc != 0 { + panic("found more than one wire to instantiate") + } + termToCompute = t + loc = locValue } - termToCompute = t - loc = locValue return nil } - for _, t := range r.L.LinExp { - if err := processTerm(t, &a, 1); err != nil { - return err - } + if err = processLExp(r.L.LinExp, &a, 1); err != nil { + return } - for _, t := range r.R.LinExp { - if err := processTerm(t, &b, 2); err != nil { - return err - } + if err = processLExp(r.R.LinExp, &b, 2); err != nil { + return } - for _, t := range r.O.LinExp { - if err := processTerm(t, &c, 3); err != nil { - return err - } + if err = processLExp(r.O.LinExp, &c, 3); err != nil { + return } if loc == 0 { // there is nothing to solve, may happen if we have an assertion // (ie a constraints that doesn't yield any output) // or if we solved the unsolved wires with hint functions - return nil + return } // we compute the wire value and instantiate it + solved = true vID := termToCompute.WireID() // solver result @@ -269,23 +246,34 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) error { if !b.IsZero() { wire.Div(&c, &b). Sub(&wire, &a) - cs.mulByCoeff(&wire, termToCompute) + a.Add(&a, &wire) + } else { + // we didn't actually ensure that a * b == c + solved = false } case 2: if !a.IsZero() { wire.Div(&c, &a). Sub(&wire, &b) - cs.mulByCoeff(&wire, termToCompute) + b.Add(&b, &wire) + } else { + // we didn't actually ensure that a * b == c + solved = false } case 3: wire.Mul(&a, &b). Sub(&wire, &c) - cs.mulByCoeff(&wire, termToCompute) + + c.Add(&c, &wire) } + // wire is the term (coeff * value) + // but in the solution we want to store the value only + // note that in gnark frontend, coeff here is always 1 or -1 + cs.divByCoeff(&wire, termToCompute) solution.set(vID, wire) - return nil + return } // GetConstraints return a list of constraint formatted as L⋅R == O diff --git a/internal/backend/bls12-377/cs/r1cs_test.go b/internal/backend/bls12-377/cs/r1cs_test.go index ef7aa47275..e63b16c7d6 100644 --- a/internal/backend/bls12-377/cs/r1cs_test.go +++ b/internal/backend/bls12-377/cs/r1cs_test.go @@ -116,3 +116,40 @@ func TestSerialization(t *testing.T) { } } + +const n = 10000 + +type circuit struct { + X frontend.Variable + Y frontend.Variable `gnark:",public"` +} + +func (circuit *circuit) Define(api frontend.API) error { + for i := 0; i < n; i++ { + circuit.X = api.Add(api.Mul(circuit.X, circuit.X), circuit.X, 42) + } + api.AssertIsEqual(circuit.X, circuit.Y) + return nil +} + +func BenchmarkSolve(b *testing.B) { + + var c circuit + ccs, err := frontend.Compile(ecc.BLS12_377, backend.UNKNOWN, &c, frontend.WithBuilder(r1cs.NewBuilder)) + if err != nil { + b.Fatal(err) + } + + var w circuit + w.X = 1 + w.Y = 1 + witness, err := frontend.NewWitness(&w, ecc.BLS12_377) + if err != nil { + b.Fatal(err) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = ccs.IsSolved(witness) + } +} diff --git a/internal/backend/bls12-377/cs/solution.go b/internal/backend/bls12-377/cs/solution.go index 99ddb85aaf..58a903a999 100644 --- a/internal/backend/bls12-377/cs/solution.go +++ b/internal/backend/bls12-377/cs/solution.go @@ -100,11 +100,32 @@ func (s *solution) computeTerm(t compiled.Term) fr.Element { } } +// r += (t.coeff*t.value) +func (s *solution) accumulateInto(t compiled.Term, r *fr.Element) { + cID := t.CoeffID() + vID := t.WireID() + switch cID { + case compiled.CoeffIdZero: + return + case compiled.CoeffIdOne: + r.Add(r, &s.values[vID]) + case compiled.CoeffIdTwo: + var res fr.Element + res.Double(&s.values[vID]) + r.Add(r, &res) + case compiled.CoeffIdMinusOne: + r.Sub(r, &s.values[vID]) + default: + var res fr.Element + res.Mul(&s.coefficients[cID], &s.values[vID]) + r.Add(r, &res) + } +} + func (s *solution) computeLinearExpression(l compiled.LinearExpression) fr.Element { var res fr.Element - for i := 0; i < len(l); i++ { - v := s.computeTerm(l[i]) - res.Add(&res, &v) + for _, t := range l { + s.accumulateInto(t, &res) } return res } diff --git a/internal/backend/bls12-381/cs/r1cs.go b/internal/backend/bls12-381/cs/r1cs.go index d18c6f5ca0..0892b140fd 100644 --- a/internal/backend/bls12-381/cs/r1cs.go +++ b/internal/backend/bls12-381/cs/r1cs.go @@ -96,6 +96,7 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverConfig) ( // check if there is an inconsistant constraint var check fr.Element + var solved bool // for each constraint // we are guaranteed that each R1C contains at most one unsolved wire @@ -104,7 +105,8 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverConfig) ( // if a[i] * b[i] != c[i]; it means the constraint is not satisfied for i := 0; i < len(cs.Constraints); i++ { // solve the constraint, this will compute the missing wire of the gate - if err := cs.solveConstraint(cs.Constraints[i], &solution); err != nil { + solved, a[i], b[i], c[i], err = cs.solveConstraint(cs.Constraints[i], &solution) + if err != nil { if dID, ok := cs.MDebug[i]; ok { debugInfoStr := solution.logValue(cs.DebugInfo[dID]) return solution.values, fmt.Errorf("%w: %s", err, debugInfoStr) @@ -112,8 +114,10 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverConfig) ( return solution.values, err } - // compute values for the R1C (ie value * coeff) - a[i], b[i], c[i] = cs.instantiateR1C(cs.Constraints[i], &solution) + if solved { + // a[i] * b[i] == c[i], since we just computed it. + continue + } // ensure a[i] * b[i] == c[i] check.Mul(&a[i], &b[i]) @@ -150,8 +154,8 @@ func (cs *R1CS) IsSolved(witness *witness.Witness, opts ...backend.ProverOption) return err } -// mulByCoeff sets res = res * t.Coeff -func (cs *R1CS) mulByCoeff(res *fr.Element, t compiled.Term) { +// divByCoeff sets res = res / t.Coeff +func (cs *R1CS) divByCoeff(res *fr.Element, t compiled.Term) { cID := t.CoeffID() switch cID { case compiled.CoeffIdOne: @@ -159,106 +163,79 @@ func (cs *R1CS) mulByCoeff(res *fr.Element, t compiled.Term) { case compiled.CoeffIdMinusOne: res.Neg(res) case compiled.CoeffIdZero: - res.SetZero() - case compiled.CoeffIdTwo: - res.Double(res) + panic("division by 0") default: - res.Mul(res, &cs.Coefficients[cID]) - } -} - -// compute left, right, o part of a cs constraint -// this function is called when all the wires have been computed -// it instantiates the l, r o part of a R1C -func (cs *R1CS) instantiateR1C(r compiled.R1C, solution *solution) (a, b, c fr.Element) { - var v fr.Element - for _, t := range r.L.LinExp { - v = solution.computeTerm(t) - a.Add(&a, &v) - } - for _, t := range r.R.LinExp { - v = solution.computeTerm(t) - b.Add(&b, &v) - } - for _, t := range r.O.LinExp { - v = solution.computeTerm(t) - c.Add(&c, &v) + // this is slow, but shouldn't happen as divByCoeff is called to + // remove the coeff of an unsolved wire + // but unsolved wires are (in gnark frontend) systematically set with a coeff == 1 or -1 + res.Div(res, &cs.Coefficients[cID]) } - return } -// solveR1c computes a wire by solving a cs -// the function searches for the unset wire (either the unset wire is -// alone, or it can be computed without ambiguity using the other computed wires -// , eg when doing a binary decomposition: either way the missing wire can -// be computed without ambiguity because the cs is correctly ordered) +// solveConstraint compute unsolved wires in the constraint, if any and set the solution accordingly // -// It returns the 1 if the the position to solve is in the quadratic part (it -// means that there is a division and serves to navigate in the log info for the -// computational constraints), and 0 otherwise. -func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) error { +// returns an error if the solver called a hint function that errored +// returns false, nil if there was no wire to solve +// returns true, nil if exactly one wire was solved. In that case, it is redundant to check that +// the constraint is satisfied later. +func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (solved bool, a, b, c fr.Element, err error) { // the index of the non zero entry shows if L, R or O has an uninstantiated wire // the content is the ID of the wire non instantiated var loc uint8 - var a, b, c fr.Element var termToCompute compiled.Term - processTerm := func(t compiled.Term, val *fr.Element, locValue uint8) error { - vID := t.WireID() + processLExp := func(l compiled.LinearExpression, val *fr.Element, locValue uint8) error { + for _, t := range l { + vID := t.WireID() - // wire is already computed, we just accumulate in val - if solution.solved[vID] { - v := solution.computeTerm(t) - val.Add(val, &v) - return nil - } + // wire is already computed, we just accumulate in val + if solution.solved[vID] { + solution.accumulateInto(t, val) + continue + } - // first we check if this is a hint wire - if hint, ok := cs.MHints[vID]; ok { - if err := solution.solveWithHint(vID, hint); err != nil { - return err + // first we check if this is a hint wire + if hint, ok := cs.MHints[vID]; ok { + if err := solution.solveWithHint(vID, hint); err != nil { + return err + } + // now that the wire is saved, accumulate it into a, b or c + solution.accumulateInto(t, val) + continue } - v := solution.computeTerm(t) - val.Add(val, &v) - return nil - } - if loc != 0 { - panic("found more than one wire to instantiate") + if loc != 0 { + panic("found more than one wire to instantiate") + } + termToCompute = t + loc = locValue } - termToCompute = t - loc = locValue return nil } - for _, t := range r.L.LinExp { - if err := processTerm(t, &a, 1); err != nil { - return err - } + if err = processLExp(r.L.LinExp, &a, 1); err != nil { + return } - for _, t := range r.R.LinExp { - if err := processTerm(t, &b, 2); err != nil { - return err - } + if err = processLExp(r.R.LinExp, &b, 2); err != nil { + return } - for _, t := range r.O.LinExp { - if err := processTerm(t, &c, 3); err != nil { - return err - } + if err = processLExp(r.O.LinExp, &c, 3); err != nil { + return } if loc == 0 { // there is nothing to solve, may happen if we have an assertion // (ie a constraints that doesn't yield any output) // or if we solved the unsolved wires with hint functions - return nil + return } // we compute the wire value and instantiate it + solved = true vID := termToCompute.WireID() // solver result @@ -269,23 +246,34 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) error { if !b.IsZero() { wire.Div(&c, &b). Sub(&wire, &a) - cs.mulByCoeff(&wire, termToCompute) + a.Add(&a, &wire) + } else { + // we didn't actually ensure that a * b == c + solved = false } case 2: if !a.IsZero() { wire.Div(&c, &a). Sub(&wire, &b) - cs.mulByCoeff(&wire, termToCompute) + b.Add(&b, &wire) + } else { + // we didn't actually ensure that a * b == c + solved = false } case 3: wire.Mul(&a, &b). Sub(&wire, &c) - cs.mulByCoeff(&wire, termToCompute) + + c.Add(&c, &wire) } + // wire is the term (coeff * value) + // but in the solution we want to store the value only + // note that in gnark frontend, coeff here is always 1 or -1 + cs.divByCoeff(&wire, termToCompute) solution.set(vID, wire) - return nil + return } // GetConstraints return a list of constraint formatted as L⋅R == O diff --git a/internal/backend/bls12-381/cs/r1cs_test.go b/internal/backend/bls12-381/cs/r1cs_test.go index b6992f0f2c..a96a65d21c 100644 --- a/internal/backend/bls12-381/cs/r1cs_test.go +++ b/internal/backend/bls12-381/cs/r1cs_test.go @@ -116,3 +116,40 @@ func TestSerialization(t *testing.T) { } } + +const n = 10000 + +type circuit struct { + X frontend.Variable + Y frontend.Variable `gnark:",public"` +} + +func (circuit *circuit) Define(api frontend.API) error { + for i := 0; i < n; i++ { + circuit.X = api.Add(api.Mul(circuit.X, circuit.X), circuit.X, 42) + } + api.AssertIsEqual(circuit.X, circuit.Y) + return nil +} + +func BenchmarkSolve(b *testing.B) { + + var c circuit + ccs, err := frontend.Compile(ecc.BLS12_381, backend.UNKNOWN, &c, frontend.WithBuilder(r1cs.NewBuilder)) + if err != nil { + b.Fatal(err) + } + + var w circuit + w.X = 1 + w.Y = 1 + witness, err := frontend.NewWitness(&w, ecc.BLS12_381) + if err != nil { + b.Fatal(err) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = ccs.IsSolved(witness) + } +} diff --git a/internal/backend/bls12-381/cs/solution.go b/internal/backend/bls12-381/cs/solution.go index 255c7721d9..d01da1ed65 100644 --- a/internal/backend/bls12-381/cs/solution.go +++ b/internal/backend/bls12-381/cs/solution.go @@ -100,11 +100,32 @@ func (s *solution) computeTerm(t compiled.Term) fr.Element { } } +// r += (t.coeff*t.value) +func (s *solution) accumulateInto(t compiled.Term, r *fr.Element) { + cID := t.CoeffID() + vID := t.WireID() + switch cID { + case compiled.CoeffIdZero: + return + case compiled.CoeffIdOne: + r.Add(r, &s.values[vID]) + case compiled.CoeffIdTwo: + var res fr.Element + res.Double(&s.values[vID]) + r.Add(r, &res) + case compiled.CoeffIdMinusOne: + r.Sub(r, &s.values[vID]) + default: + var res fr.Element + res.Mul(&s.coefficients[cID], &s.values[vID]) + r.Add(r, &res) + } +} + func (s *solution) computeLinearExpression(l compiled.LinearExpression) fr.Element { var res fr.Element - for i := 0; i < len(l); i++ { - v := s.computeTerm(l[i]) - res.Add(&res, &v) + for _, t := range l { + s.accumulateInto(t, &res) } return res } diff --git a/internal/backend/bls24-315/cs/r1cs.go b/internal/backend/bls24-315/cs/r1cs.go index 92a01bb696..9bd39b4b8f 100644 --- a/internal/backend/bls24-315/cs/r1cs.go +++ b/internal/backend/bls24-315/cs/r1cs.go @@ -96,6 +96,7 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverConfig) ( // check if there is an inconsistant constraint var check fr.Element + var solved bool // for each constraint // we are guaranteed that each R1C contains at most one unsolved wire @@ -104,7 +105,8 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverConfig) ( // if a[i] * b[i] != c[i]; it means the constraint is not satisfied for i := 0; i < len(cs.Constraints); i++ { // solve the constraint, this will compute the missing wire of the gate - if err := cs.solveConstraint(cs.Constraints[i], &solution); err != nil { + solved, a[i], b[i], c[i], err = cs.solveConstraint(cs.Constraints[i], &solution) + if err != nil { if dID, ok := cs.MDebug[i]; ok { debugInfoStr := solution.logValue(cs.DebugInfo[dID]) return solution.values, fmt.Errorf("%w: %s", err, debugInfoStr) @@ -112,8 +114,10 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverConfig) ( return solution.values, err } - // compute values for the R1C (ie value * coeff) - a[i], b[i], c[i] = cs.instantiateR1C(cs.Constraints[i], &solution) + if solved { + // a[i] * b[i] == c[i], since we just computed it. + continue + } // ensure a[i] * b[i] == c[i] check.Mul(&a[i], &b[i]) @@ -150,8 +154,8 @@ func (cs *R1CS) IsSolved(witness *witness.Witness, opts ...backend.ProverOption) return err } -// mulByCoeff sets res = res * t.Coeff -func (cs *R1CS) mulByCoeff(res *fr.Element, t compiled.Term) { +// divByCoeff sets res = res / t.Coeff +func (cs *R1CS) divByCoeff(res *fr.Element, t compiled.Term) { cID := t.CoeffID() switch cID { case compiled.CoeffIdOne: @@ -159,106 +163,79 @@ func (cs *R1CS) mulByCoeff(res *fr.Element, t compiled.Term) { case compiled.CoeffIdMinusOne: res.Neg(res) case compiled.CoeffIdZero: - res.SetZero() - case compiled.CoeffIdTwo: - res.Double(res) + panic("division by 0") default: - res.Mul(res, &cs.Coefficients[cID]) - } -} - -// compute left, right, o part of a cs constraint -// this function is called when all the wires have been computed -// it instantiates the l, r o part of a R1C -func (cs *R1CS) instantiateR1C(r compiled.R1C, solution *solution) (a, b, c fr.Element) { - var v fr.Element - for _, t := range r.L.LinExp { - v = solution.computeTerm(t) - a.Add(&a, &v) - } - for _, t := range r.R.LinExp { - v = solution.computeTerm(t) - b.Add(&b, &v) - } - for _, t := range r.O.LinExp { - v = solution.computeTerm(t) - c.Add(&c, &v) + // this is slow, but shouldn't happen as divByCoeff is called to + // remove the coeff of an unsolved wire + // but unsolved wires are (in gnark frontend) systematically set with a coeff == 1 or -1 + res.Div(res, &cs.Coefficients[cID]) } - return } -// solveR1c computes a wire by solving a cs -// the function searches for the unset wire (either the unset wire is -// alone, or it can be computed without ambiguity using the other computed wires -// , eg when doing a binary decomposition: either way the missing wire can -// be computed without ambiguity because the cs is correctly ordered) +// solveConstraint compute unsolved wires in the constraint, if any and set the solution accordingly // -// It returns the 1 if the the position to solve is in the quadratic part (it -// means that there is a division and serves to navigate in the log info for the -// computational constraints), and 0 otherwise. -func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) error { +// returns an error if the solver called a hint function that errored +// returns false, nil if there was no wire to solve +// returns true, nil if exactly one wire was solved. In that case, it is redundant to check that +// the constraint is satisfied later. +func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (solved bool, a, b, c fr.Element, err error) { // the index of the non zero entry shows if L, R or O has an uninstantiated wire // the content is the ID of the wire non instantiated var loc uint8 - var a, b, c fr.Element var termToCompute compiled.Term - processTerm := func(t compiled.Term, val *fr.Element, locValue uint8) error { - vID := t.WireID() + processLExp := func(l compiled.LinearExpression, val *fr.Element, locValue uint8) error { + for _, t := range l { + vID := t.WireID() - // wire is already computed, we just accumulate in val - if solution.solved[vID] { - v := solution.computeTerm(t) - val.Add(val, &v) - return nil - } + // wire is already computed, we just accumulate in val + if solution.solved[vID] { + solution.accumulateInto(t, val) + continue + } - // first we check if this is a hint wire - if hint, ok := cs.MHints[vID]; ok { - if err := solution.solveWithHint(vID, hint); err != nil { - return err + // first we check if this is a hint wire + if hint, ok := cs.MHints[vID]; ok { + if err := solution.solveWithHint(vID, hint); err != nil { + return err + } + // now that the wire is saved, accumulate it into a, b or c + solution.accumulateInto(t, val) + continue } - v := solution.computeTerm(t) - val.Add(val, &v) - return nil - } - if loc != 0 { - panic("found more than one wire to instantiate") + if loc != 0 { + panic("found more than one wire to instantiate") + } + termToCompute = t + loc = locValue } - termToCompute = t - loc = locValue return nil } - for _, t := range r.L.LinExp { - if err := processTerm(t, &a, 1); err != nil { - return err - } + if err = processLExp(r.L.LinExp, &a, 1); err != nil { + return } - for _, t := range r.R.LinExp { - if err := processTerm(t, &b, 2); err != nil { - return err - } + if err = processLExp(r.R.LinExp, &b, 2); err != nil { + return } - for _, t := range r.O.LinExp { - if err := processTerm(t, &c, 3); err != nil { - return err - } + if err = processLExp(r.O.LinExp, &c, 3); err != nil { + return } if loc == 0 { // there is nothing to solve, may happen if we have an assertion // (ie a constraints that doesn't yield any output) // or if we solved the unsolved wires with hint functions - return nil + return } // we compute the wire value and instantiate it + solved = true vID := termToCompute.WireID() // solver result @@ -269,23 +246,34 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) error { if !b.IsZero() { wire.Div(&c, &b). Sub(&wire, &a) - cs.mulByCoeff(&wire, termToCompute) + a.Add(&a, &wire) + } else { + // we didn't actually ensure that a * b == c + solved = false } case 2: if !a.IsZero() { wire.Div(&c, &a). Sub(&wire, &b) - cs.mulByCoeff(&wire, termToCompute) + b.Add(&b, &wire) + } else { + // we didn't actually ensure that a * b == c + solved = false } case 3: wire.Mul(&a, &b). Sub(&wire, &c) - cs.mulByCoeff(&wire, termToCompute) + + c.Add(&c, &wire) } + // wire is the term (coeff * value) + // but in the solution we want to store the value only + // note that in gnark frontend, coeff here is always 1 or -1 + cs.divByCoeff(&wire, termToCompute) solution.set(vID, wire) - return nil + return } // GetConstraints return a list of constraint formatted as L⋅R == O diff --git a/internal/backend/bls24-315/cs/r1cs_test.go b/internal/backend/bls24-315/cs/r1cs_test.go index 3ad098aaae..931504b2f7 100644 --- a/internal/backend/bls24-315/cs/r1cs_test.go +++ b/internal/backend/bls24-315/cs/r1cs_test.go @@ -116,3 +116,40 @@ func TestSerialization(t *testing.T) { } } + +const n = 10000 + +type circuit struct { + X frontend.Variable + Y frontend.Variable `gnark:",public"` +} + +func (circuit *circuit) Define(api frontend.API) error { + for i := 0; i < n; i++ { + circuit.X = api.Add(api.Mul(circuit.X, circuit.X), circuit.X, 42) + } + api.AssertIsEqual(circuit.X, circuit.Y) + return nil +} + +func BenchmarkSolve(b *testing.B) { + + var c circuit + ccs, err := frontend.Compile(ecc.BLS24_315, backend.UNKNOWN, &c, frontend.WithBuilder(r1cs.NewBuilder)) + if err != nil { + b.Fatal(err) + } + + var w circuit + w.X = 1 + w.Y = 1 + witness, err := frontend.NewWitness(&w, ecc.BLS24_315) + if err != nil { + b.Fatal(err) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = ccs.IsSolved(witness) + } +} diff --git a/internal/backend/bls24-315/cs/solution.go b/internal/backend/bls24-315/cs/solution.go index b69f343696..2773465dfe 100644 --- a/internal/backend/bls24-315/cs/solution.go +++ b/internal/backend/bls24-315/cs/solution.go @@ -100,11 +100,32 @@ func (s *solution) computeTerm(t compiled.Term) fr.Element { } } +// r += (t.coeff*t.value) +func (s *solution) accumulateInto(t compiled.Term, r *fr.Element) { + cID := t.CoeffID() + vID := t.WireID() + switch cID { + case compiled.CoeffIdZero: + return + case compiled.CoeffIdOne: + r.Add(r, &s.values[vID]) + case compiled.CoeffIdTwo: + var res fr.Element + res.Double(&s.values[vID]) + r.Add(r, &res) + case compiled.CoeffIdMinusOne: + r.Sub(r, &s.values[vID]) + default: + var res fr.Element + res.Mul(&s.coefficients[cID], &s.values[vID]) + r.Add(r, &res) + } +} + func (s *solution) computeLinearExpression(l compiled.LinearExpression) fr.Element { var res fr.Element - for i := 0; i < len(l); i++ { - v := s.computeTerm(l[i]) - res.Add(&res, &v) + for _, t := range l { + s.accumulateInto(t, &res) } return res } diff --git a/internal/backend/bn254/cs/r1cs.go b/internal/backend/bn254/cs/r1cs.go index 990e9e8f9f..c86e8ddb7d 100644 --- a/internal/backend/bn254/cs/r1cs.go +++ b/internal/backend/bn254/cs/r1cs.go @@ -96,6 +96,7 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverConfig) ( // check if there is an inconsistant constraint var check fr.Element + var solved bool // for each constraint // we are guaranteed that each R1C contains at most one unsolved wire @@ -104,7 +105,8 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverConfig) ( // if a[i] * b[i] != c[i]; it means the constraint is not satisfied for i := 0; i < len(cs.Constraints); i++ { // solve the constraint, this will compute the missing wire of the gate - if err := cs.solveConstraint(cs.Constraints[i], &solution); err != nil { + solved, a[i], b[i], c[i], err = cs.solveConstraint(cs.Constraints[i], &solution) + if err != nil { if dID, ok := cs.MDebug[i]; ok { debugInfoStr := solution.logValue(cs.DebugInfo[dID]) return solution.values, fmt.Errorf("%w: %s", err, debugInfoStr) @@ -112,8 +114,10 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverConfig) ( return solution.values, err } - // compute values for the R1C (ie value * coeff) - a[i], b[i], c[i] = cs.instantiateR1C(cs.Constraints[i], &solution) + if solved { + // a[i] * b[i] == c[i], since we just computed it. + continue + } // ensure a[i] * b[i] == c[i] check.Mul(&a[i], &b[i]) @@ -150,8 +154,8 @@ func (cs *R1CS) IsSolved(witness *witness.Witness, opts ...backend.ProverOption) return err } -// mulByCoeff sets res = res * t.Coeff -func (cs *R1CS) mulByCoeff(res *fr.Element, t compiled.Term) { +// divByCoeff sets res = res / t.Coeff +func (cs *R1CS) divByCoeff(res *fr.Element, t compiled.Term) { cID := t.CoeffID() switch cID { case compiled.CoeffIdOne: @@ -159,106 +163,79 @@ func (cs *R1CS) mulByCoeff(res *fr.Element, t compiled.Term) { case compiled.CoeffIdMinusOne: res.Neg(res) case compiled.CoeffIdZero: - res.SetZero() - case compiled.CoeffIdTwo: - res.Double(res) + panic("division by 0") default: - res.Mul(res, &cs.Coefficients[cID]) - } -} - -// compute left, right, o part of a cs constraint -// this function is called when all the wires have been computed -// it instantiates the l, r o part of a R1C -func (cs *R1CS) instantiateR1C(r compiled.R1C, solution *solution) (a, b, c fr.Element) { - var v fr.Element - for _, t := range r.L.LinExp { - v = solution.computeTerm(t) - a.Add(&a, &v) - } - for _, t := range r.R.LinExp { - v = solution.computeTerm(t) - b.Add(&b, &v) - } - for _, t := range r.O.LinExp { - v = solution.computeTerm(t) - c.Add(&c, &v) + // this is slow, but shouldn't happen as divByCoeff is called to + // remove the coeff of an unsolved wire + // but unsolved wires are (in gnark frontend) systematically set with a coeff == 1 or -1 + res.Div(res, &cs.Coefficients[cID]) } - return } -// solveR1c computes a wire by solving a cs -// the function searches for the unset wire (either the unset wire is -// alone, or it can be computed without ambiguity using the other computed wires -// , eg when doing a binary decomposition: either way the missing wire can -// be computed without ambiguity because the cs is correctly ordered) +// solveConstraint compute unsolved wires in the constraint, if any and set the solution accordingly // -// It returns the 1 if the the position to solve is in the quadratic part (it -// means that there is a division and serves to navigate in the log info for the -// computational constraints), and 0 otherwise. -func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) error { +// returns an error if the solver called a hint function that errored +// returns false, nil if there was no wire to solve +// returns true, nil if exactly one wire was solved. In that case, it is redundant to check that +// the constraint is satisfied later. +func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (solved bool, a, b, c fr.Element, err error) { // the index of the non zero entry shows if L, R or O has an uninstantiated wire // the content is the ID of the wire non instantiated var loc uint8 - var a, b, c fr.Element var termToCompute compiled.Term - processTerm := func(t compiled.Term, val *fr.Element, locValue uint8) error { - vID := t.WireID() + processLExp := func(l compiled.LinearExpression, val *fr.Element, locValue uint8) error { + for _, t := range l { + vID := t.WireID() - // wire is already computed, we just accumulate in val - if solution.solved[vID] { - v := solution.computeTerm(t) - val.Add(val, &v) - return nil - } + // wire is already computed, we just accumulate in val + if solution.solved[vID] { + solution.accumulateInto(t, val) + continue + } - // first we check if this is a hint wire - if hint, ok := cs.MHints[vID]; ok { - if err := solution.solveWithHint(vID, hint); err != nil { - return err + // first we check if this is a hint wire + if hint, ok := cs.MHints[vID]; ok { + if err := solution.solveWithHint(vID, hint); err != nil { + return err + } + // now that the wire is saved, accumulate it into a, b or c + solution.accumulateInto(t, val) + continue } - v := solution.computeTerm(t) - val.Add(val, &v) - return nil - } - if loc != 0 { - panic("found more than one wire to instantiate") + if loc != 0 { + panic("found more than one wire to instantiate") + } + termToCompute = t + loc = locValue } - termToCompute = t - loc = locValue return nil } - for _, t := range r.L.LinExp { - if err := processTerm(t, &a, 1); err != nil { - return err - } + if err = processLExp(r.L.LinExp, &a, 1); err != nil { + return } - for _, t := range r.R.LinExp { - if err := processTerm(t, &b, 2); err != nil { - return err - } + if err = processLExp(r.R.LinExp, &b, 2); err != nil { + return } - for _, t := range r.O.LinExp { - if err := processTerm(t, &c, 3); err != nil { - return err - } + if err = processLExp(r.O.LinExp, &c, 3); err != nil { + return } if loc == 0 { // there is nothing to solve, may happen if we have an assertion // (ie a constraints that doesn't yield any output) // or if we solved the unsolved wires with hint functions - return nil + return } // we compute the wire value and instantiate it + solved = true vID := termToCompute.WireID() // solver result @@ -269,23 +246,34 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) error { if !b.IsZero() { wire.Div(&c, &b). Sub(&wire, &a) - cs.mulByCoeff(&wire, termToCompute) + a.Add(&a, &wire) + } else { + // we didn't actually ensure that a * b == c + solved = false } case 2: if !a.IsZero() { wire.Div(&c, &a). Sub(&wire, &b) - cs.mulByCoeff(&wire, termToCompute) + b.Add(&b, &wire) + } else { + // we didn't actually ensure that a * b == c + solved = false } case 3: wire.Mul(&a, &b). Sub(&wire, &c) - cs.mulByCoeff(&wire, termToCompute) + + c.Add(&c, &wire) } + // wire is the term (coeff * value) + // but in the solution we want to store the value only + // note that in gnark frontend, coeff here is always 1 or -1 + cs.divByCoeff(&wire, termToCompute) solution.set(vID, wire) - return nil + return } // GetConstraints return a list of constraint formatted as L⋅R == O diff --git a/internal/backend/bn254/cs/r1cs_test.go b/internal/backend/bn254/cs/r1cs_test.go index 7246132a23..c2eb0a7663 100644 --- a/internal/backend/bn254/cs/r1cs_test.go +++ b/internal/backend/bn254/cs/r1cs_test.go @@ -116,3 +116,40 @@ func TestSerialization(t *testing.T) { } } + +const n = 10000 + +type circuit struct { + X frontend.Variable + Y frontend.Variable `gnark:",public"` +} + +func (circuit *circuit) Define(api frontend.API) error { + for i := 0; i < n; i++ { + circuit.X = api.Add(api.Mul(circuit.X, circuit.X), circuit.X, 42) + } + api.AssertIsEqual(circuit.X, circuit.Y) + return nil +} + +func BenchmarkSolve(b *testing.B) { + + var c circuit + ccs, err := frontend.Compile(ecc.BN254, backend.UNKNOWN, &c, frontend.WithBuilder(r1cs.NewBuilder)) + if err != nil { + b.Fatal(err) + } + + var w circuit + w.X = 1 + w.Y = 1 + witness, err := frontend.NewWitness(&w, ecc.BN254) + if err != nil { + b.Fatal(err) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = ccs.IsSolved(witness) + } +} diff --git a/internal/backend/bn254/cs/solution.go b/internal/backend/bn254/cs/solution.go index f4dff4d854..19f8e19653 100644 --- a/internal/backend/bn254/cs/solution.go +++ b/internal/backend/bn254/cs/solution.go @@ -100,11 +100,32 @@ func (s *solution) computeTerm(t compiled.Term) fr.Element { } } +// r += (t.coeff*t.value) +func (s *solution) accumulateInto(t compiled.Term, r *fr.Element) { + cID := t.CoeffID() + vID := t.WireID() + switch cID { + case compiled.CoeffIdZero: + return + case compiled.CoeffIdOne: + r.Add(r, &s.values[vID]) + case compiled.CoeffIdTwo: + var res fr.Element + res.Double(&s.values[vID]) + r.Add(r, &res) + case compiled.CoeffIdMinusOne: + r.Sub(r, &s.values[vID]) + default: + var res fr.Element + res.Mul(&s.coefficients[cID], &s.values[vID]) + r.Add(r, &res) + } +} + func (s *solution) computeLinearExpression(l compiled.LinearExpression) fr.Element { var res fr.Element - for i := 0; i < len(l); i++ { - v := s.computeTerm(l[i]) - res.Add(&res, &v) + for _, t := range l { + s.accumulateInto(t, &res) } return res } diff --git a/internal/backend/bw6-633/cs/r1cs.go b/internal/backend/bw6-633/cs/r1cs.go index 6f5ceb5bbe..0a9b76ff49 100644 --- a/internal/backend/bw6-633/cs/r1cs.go +++ b/internal/backend/bw6-633/cs/r1cs.go @@ -96,6 +96,7 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverConfig) ( // check if there is an inconsistant constraint var check fr.Element + var solved bool // for each constraint // we are guaranteed that each R1C contains at most one unsolved wire @@ -104,7 +105,8 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverConfig) ( // if a[i] * b[i] != c[i]; it means the constraint is not satisfied for i := 0; i < len(cs.Constraints); i++ { // solve the constraint, this will compute the missing wire of the gate - if err := cs.solveConstraint(cs.Constraints[i], &solution); err != nil { + solved, a[i], b[i], c[i], err = cs.solveConstraint(cs.Constraints[i], &solution) + if err != nil { if dID, ok := cs.MDebug[i]; ok { debugInfoStr := solution.logValue(cs.DebugInfo[dID]) return solution.values, fmt.Errorf("%w: %s", err, debugInfoStr) @@ -112,8 +114,10 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverConfig) ( return solution.values, err } - // compute values for the R1C (ie value * coeff) - a[i], b[i], c[i] = cs.instantiateR1C(cs.Constraints[i], &solution) + if solved { + // a[i] * b[i] == c[i], since we just computed it. + continue + } // ensure a[i] * b[i] == c[i] check.Mul(&a[i], &b[i]) @@ -150,8 +154,8 @@ func (cs *R1CS) IsSolved(witness *witness.Witness, opts ...backend.ProverOption) return err } -// mulByCoeff sets res = res * t.Coeff -func (cs *R1CS) mulByCoeff(res *fr.Element, t compiled.Term) { +// divByCoeff sets res = res / t.Coeff +func (cs *R1CS) divByCoeff(res *fr.Element, t compiled.Term) { cID := t.CoeffID() switch cID { case compiled.CoeffIdOne: @@ -159,106 +163,79 @@ func (cs *R1CS) mulByCoeff(res *fr.Element, t compiled.Term) { case compiled.CoeffIdMinusOne: res.Neg(res) case compiled.CoeffIdZero: - res.SetZero() - case compiled.CoeffIdTwo: - res.Double(res) + panic("division by 0") default: - res.Mul(res, &cs.Coefficients[cID]) - } -} - -// compute left, right, o part of a cs constraint -// this function is called when all the wires have been computed -// it instantiates the l, r o part of a R1C -func (cs *R1CS) instantiateR1C(r compiled.R1C, solution *solution) (a, b, c fr.Element) { - var v fr.Element - for _, t := range r.L.LinExp { - v = solution.computeTerm(t) - a.Add(&a, &v) - } - for _, t := range r.R.LinExp { - v = solution.computeTerm(t) - b.Add(&b, &v) - } - for _, t := range r.O.LinExp { - v = solution.computeTerm(t) - c.Add(&c, &v) + // this is slow, but shouldn't happen as divByCoeff is called to + // remove the coeff of an unsolved wire + // but unsolved wires are (in gnark frontend) systematically set with a coeff == 1 or -1 + res.Div(res, &cs.Coefficients[cID]) } - return } -// solveR1c computes a wire by solving a cs -// the function searches for the unset wire (either the unset wire is -// alone, or it can be computed without ambiguity using the other computed wires -// , eg when doing a binary decomposition: either way the missing wire can -// be computed without ambiguity because the cs is correctly ordered) +// solveConstraint compute unsolved wires in the constraint, if any and set the solution accordingly // -// It returns the 1 if the the position to solve is in the quadratic part (it -// means that there is a division and serves to navigate in the log info for the -// computational constraints), and 0 otherwise. -func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) error { +// returns an error if the solver called a hint function that errored +// returns false, nil if there was no wire to solve +// returns true, nil if exactly one wire was solved. In that case, it is redundant to check that +// the constraint is satisfied later. +func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (solved bool, a, b, c fr.Element, err error) { // the index of the non zero entry shows if L, R or O has an uninstantiated wire // the content is the ID of the wire non instantiated var loc uint8 - var a, b, c fr.Element var termToCompute compiled.Term - processTerm := func(t compiled.Term, val *fr.Element, locValue uint8) error { - vID := t.WireID() + processLExp := func(l compiled.LinearExpression, val *fr.Element, locValue uint8) error { + for _, t := range l { + vID := t.WireID() - // wire is already computed, we just accumulate in val - if solution.solved[vID] { - v := solution.computeTerm(t) - val.Add(val, &v) - return nil - } + // wire is already computed, we just accumulate in val + if solution.solved[vID] { + solution.accumulateInto(t, val) + continue + } - // first we check if this is a hint wire - if hint, ok := cs.MHints[vID]; ok { - if err := solution.solveWithHint(vID, hint); err != nil { - return err + // first we check if this is a hint wire + if hint, ok := cs.MHints[vID]; ok { + if err := solution.solveWithHint(vID, hint); err != nil { + return err + } + // now that the wire is saved, accumulate it into a, b or c + solution.accumulateInto(t, val) + continue } - v := solution.computeTerm(t) - val.Add(val, &v) - return nil - } - if loc != 0 { - panic("found more than one wire to instantiate") + if loc != 0 { + panic("found more than one wire to instantiate") + } + termToCompute = t + loc = locValue } - termToCompute = t - loc = locValue return nil } - for _, t := range r.L.LinExp { - if err := processTerm(t, &a, 1); err != nil { - return err - } + if err = processLExp(r.L.LinExp, &a, 1); err != nil { + return } - for _, t := range r.R.LinExp { - if err := processTerm(t, &b, 2); err != nil { - return err - } + if err = processLExp(r.R.LinExp, &b, 2); err != nil { + return } - for _, t := range r.O.LinExp { - if err := processTerm(t, &c, 3); err != nil { - return err - } + if err = processLExp(r.O.LinExp, &c, 3); err != nil { + return } if loc == 0 { // there is nothing to solve, may happen if we have an assertion // (ie a constraints that doesn't yield any output) // or if we solved the unsolved wires with hint functions - return nil + return } // we compute the wire value and instantiate it + solved = true vID := termToCompute.WireID() // solver result @@ -269,23 +246,34 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) error { if !b.IsZero() { wire.Div(&c, &b). Sub(&wire, &a) - cs.mulByCoeff(&wire, termToCompute) + a.Add(&a, &wire) + } else { + // we didn't actually ensure that a * b == c + solved = false } case 2: if !a.IsZero() { wire.Div(&c, &a). Sub(&wire, &b) - cs.mulByCoeff(&wire, termToCompute) + b.Add(&b, &wire) + } else { + // we didn't actually ensure that a * b == c + solved = false } case 3: wire.Mul(&a, &b). Sub(&wire, &c) - cs.mulByCoeff(&wire, termToCompute) + + c.Add(&c, &wire) } + // wire is the term (coeff * value) + // but in the solution we want to store the value only + // note that in gnark frontend, coeff here is always 1 or -1 + cs.divByCoeff(&wire, termToCompute) solution.set(vID, wire) - return nil + return } // GetConstraints return a list of constraint formatted as L⋅R == O diff --git a/internal/backend/bw6-633/cs/r1cs_test.go b/internal/backend/bw6-633/cs/r1cs_test.go index 5f98742be5..5900e2b650 100644 --- a/internal/backend/bw6-633/cs/r1cs_test.go +++ b/internal/backend/bw6-633/cs/r1cs_test.go @@ -116,3 +116,40 @@ func TestSerialization(t *testing.T) { } } + +const n = 10000 + +type circuit struct { + X frontend.Variable + Y frontend.Variable `gnark:",public"` +} + +func (circuit *circuit) Define(api frontend.API) error { + for i := 0; i < n; i++ { + circuit.X = api.Add(api.Mul(circuit.X, circuit.X), circuit.X, 42) + } + api.AssertIsEqual(circuit.X, circuit.Y) + return nil +} + +func BenchmarkSolve(b *testing.B) { + + var c circuit + ccs, err := frontend.Compile(ecc.BW6_633, backend.UNKNOWN, &c, frontend.WithBuilder(r1cs.NewBuilder)) + if err != nil { + b.Fatal(err) + } + + var w circuit + w.X = 1 + w.Y = 1 + witness, err := frontend.NewWitness(&w, ecc.BW6_633) + if err != nil { + b.Fatal(err) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = ccs.IsSolved(witness) + } +} diff --git a/internal/backend/bw6-633/cs/solution.go b/internal/backend/bw6-633/cs/solution.go index d8c2e5c824..f8ecbc3ba0 100644 --- a/internal/backend/bw6-633/cs/solution.go +++ b/internal/backend/bw6-633/cs/solution.go @@ -100,11 +100,32 @@ func (s *solution) computeTerm(t compiled.Term) fr.Element { } } +// r += (t.coeff*t.value) +func (s *solution) accumulateInto(t compiled.Term, r *fr.Element) { + cID := t.CoeffID() + vID := t.WireID() + switch cID { + case compiled.CoeffIdZero: + return + case compiled.CoeffIdOne: + r.Add(r, &s.values[vID]) + case compiled.CoeffIdTwo: + var res fr.Element + res.Double(&s.values[vID]) + r.Add(r, &res) + case compiled.CoeffIdMinusOne: + r.Sub(r, &s.values[vID]) + default: + var res fr.Element + res.Mul(&s.coefficients[cID], &s.values[vID]) + r.Add(r, &res) + } +} + func (s *solution) computeLinearExpression(l compiled.LinearExpression) fr.Element { var res fr.Element - for i := 0; i < len(l); i++ { - v := s.computeTerm(l[i]) - res.Add(&res, &v) + for _, t := range l { + s.accumulateInto(t, &res) } return res } diff --git a/internal/backend/bw6-761/cs/r1cs.go b/internal/backend/bw6-761/cs/r1cs.go index ade1ca4ea4..80558dfb1b 100644 --- a/internal/backend/bw6-761/cs/r1cs.go +++ b/internal/backend/bw6-761/cs/r1cs.go @@ -96,6 +96,7 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverConfig) ( // check if there is an inconsistant constraint var check fr.Element + var solved bool // for each constraint // we are guaranteed that each R1C contains at most one unsolved wire @@ -104,7 +105,8 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverConfig) ( // if a[i] * b[i] != c[i]; it means the constraint is not satisfied for i := 0; i < len(cs.Constraints); i++ { // solve the constraint, this will compute the missing wire of the gate - if err := cs.solveConstraint(cs.Constraints[i], &solution); err != nil { + solved, a[i], b[i], c[i], err = cs.solveConstraint(cs.Constraints[i], &solution) + if err != nil { if dID, ok := cs.MDebug[i]; ok { debugInfoStr := solution.logValue(cs.DebugInfo[dID]) return solution.values, fmt.Errorf("%w: %s", err, debugInfoStr) @@ -112,8 +114,10 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverConfig) ( return solution.values, err } - // compute values for the R1C (ie value * coeff) - a[i], b[i], c[i] = cs.instantiateR1C(cs.Constraints[i], &solution) + if solved { + // a[i] * b[i] == c[i], since we just computed it. + continue + } // ensure a[i] * b[i] == c[i] check.Mul(&a[i], &b[i]) @@ -150,8 +154,8 @@ func (cs *R1CS) IsSolved(witness *witness.Witness, opts ...backend.ProverOption) return err } -// mulByCoeff sets res = res * t.Coeff -func (cs *R1CS) mulByCoeff(res *fr.Element, t compiled.Term) { +// divByCoeff sets res = res / t.Coeff +func (cs *R1CS) divByCoeff(res *fr.Element, t compiled.Term) { cID := t.CoeffID() switch cID { case compiled.CoeffIdOne: @@ -159,106 +163,79 @@ func (cs *R1CS) mulByCoeff(res *fr.Element, t compiled.Term) { case compiled.CoeffIdMinusOne: res.Neg(res) case compiled.CoeffIdZero: - res.SetZero() - case compiled.CoeffIdTwo: - res.Double(res) + panic("division by 0") default: - res.Mul(res, &cs.Coefficients[cID]) - } -} - -// compute left, right, o part of a cs constraint -// this function is called when all the wires have been computed -// it instantiates the l, r o part of a R1C -func (cs *R1CS) instantiateR1C(r compiled.R1C, solution *solution) (a, b, c fr.Element) { - var v fr.Element - for _, t := range r.L.LinExp { - v = solution.computeTerm(t) - a.Add(&a, &v) - } - for _, t := range r.R.LinExp { - v = solution.computeTerm(t) - b.Add(&b, &v) - } - for _, t := range r.O.LinExp { - v = solution.computeTerm(t) - c.Add(&c, &v) + // this is slow, but shouldn't happen as divByCoeff is called to + // remove the coeff of an unsolved wire + // but unsolved wires are (in gnark frontend) systematically set with a coeff == 1 or -1 + res.Div(res, &cs.Coefficients[cID]) } - return } -// solveR1c computes a wire by solving a cs -// the function searches for the unset wire (either the unset wire is -// alone, or it can be computed without ambiguity using the other computed wires -// , eg when doing a binary decomposition: either way the missing wire can -// be computed without ambiguity because the cs is correctly ordered) +// solveConstraint compute unsolved wires in the constraint, if any and set the solution accordingly // -// It returns the 1 if the the position to solve is in the quadratic part (it -// means that there is a division and serves to navigate in the log info for the -// computational constraints), and 0 otherwise. -func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) error { +// returns an error if the solver called a hint function that errored +// returns false, nil if there was no wire to solve +// returns true, nil if exactly one wire was solved. In that case, it is redundant to check that +// the constraint is satisfied later. +func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (solved bool, a, b, c fr.Element, err error) { // the index of the non zero entry shows if L, R or O has an uninstantiated wire // the content is the ID of the wire non instantiated var loc uint8 - var a, b, c fr.Element var termToCompute compiled.Term - processTerm := func(t compiled.Term, val *fr.Element, locValue uint8) error { - vID := t.WireID() + processLExp := func(l compiled.LinearExpression, val *fr.Element, locValue uint8) error { + for _, t := range l { + vID := t.WireID() - // wire is already computed, we just accumulate in val - if solution.solved[vID] { - v := solution.computeTerm(t) - val.Add(val, &v) - return nil - } + // wire is already computed, we just accumulate in val + if solution.solved[vID] { + solution.accumulateInto(t, val) + continue + } - // first we check if this is a hint wire - if hint, ok := cs.MHints[vID]; ok { - if err := solution.solveWithHint(vID, hint); err != nil { - return err + // first we check if this is a hint wire + if hint, ok := cs.MHints[vID]; ok { + if err := solution.solveWithHint(vID, hint); err != nil { + return err + } + // now that the wire is saved, accumulate it into a, b or c + solution.accumulateInto(t, val) + continue } - v := solution.computeTerm(t) - val.Add(val, &v) - return nil - } - if loc != 0 { - panic("found more than one wire to instantiate") + if loc != 0 { + panic("found more than one wire to instantiate") + } + termToCompute = t + loc = locValue } - termToCompute = t - loc = locValue return nil } - for _, t := range r.L.LinExp { - if err := processTerm(t, &a, 1); err != nil { - return err - } + if err = processLExp(r.L.LinExp, &a, 1); err != nil { + return } - for _, t := range r.R.LinExp { - if err := processTerm(t, &b, 2); err != nil { - return err - } + if err = processLExp(r.R.LinExp, &b, 2); err != nil { + return } - for _, t := range r.O.LinExp { - if err := processTerm(t, &c, 3); err != nil { - return err - } + if err = processLExp(r.O.LinExp, &c, 3); err != nil { + return } if loc == 0 { // there is nothing to solve, may happen if we have an assertion // (ie a constraints that doesn't yield any output) // or if we solved the unsolved wires with hint functions - return nil + return } // we compute the wire value and instantiate it + solved = true vID := termToCompute.WireID() // solver result @@ -269,23 +246,34 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) error { if !b.IsZero() { wire.Div(&c, &b). Sub(&wire, &a) - cs.mulByCoeff(&wire, termToCompute) + a.Add(&a, &wire) + } else { + // we didn't actually ensure that a * b == c + solved = false } case 2: if !a.IsZero() { wire.Div(&c, &a). Sub(&wire, &b) - cs.mulByCoeff(&wire, termToCompute) + b.Add(&b, &wire) + } else { + // we didn't actually ensure that a * b == c + solved = false } case 3: wire.Mul(&a, &b). Sub(&wire, &c) - cs.mulByCoeff(&wire, termToCompute) + + c.Add(&c, &wire) } + // wire is the term (coeff * value) + // but in the solution we want to store the value only + // note that in gnark frontend, coeff here is always 1 or -1 + cs.divByCoeff(&wire, termToCompute) solution.set(vID, wire) - return nil + return } // GetConstraints return a list of constraint formatted as L⋅R == O diff --git a/internal/backend/bw6-761/cs/r1cs_test.go b/internal/backend/bw6-761/cs/r1cs_test.go index 6461531004..94eec2486c 100644 --- a/internal/backend/bw6-761/cs/r1cs_test.go +++ b/internal/backend/bw6-761/cs/r1cs_test.go @@ -120,3 +120,40 @@ func TestSerialization(t *testing.T) { } } + +const n = 10000 + +type circuit struct { + X frontend.Variable + Y frontend.Variable `gnark:",public"` +} + +func (circuit *circuit) Define(api frontend.API) error { + for i := 0; i < n; i++ { + circuit.X = api.Add(api.Mul(circuit.X, circuit.X), circuit.X, 42) + } + api.AssertIsEqual(circuit.X, circuit.Y) + return nil +} + +func BenchmarkSolve(b *testing.B) { + + var c circuit + ccs, err := frontend.Compile(ecc.BW6_761, backend.UNKNOWN, &c, frontend.WithBuilder(r1cs.NewBuilder)) + if err != nil { + b.Fatal(err) + } + + var w circuit + w.X = 1 + w.Y = 1 + witness, err := frontend.NewWitness(&w, ecc.BW6_761) + if err != nil { + b.Fatal(err) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = ccs.IsSolved(witness) + } +} diff --git a/internal/backend/bw6-761/cs/solution.go b/internal/backend/bw6-761/cs/solution.go index ba1a4b0688..5080524e1e 100644 --- a/internal/backend/bw6-761/cs/solution.go +++ b/internal/backend/bw6-761/cs/solution.go @@ -100,11 +100,32 @@ func (s *solution) computeTerm(t compiled.Term) fr.Element { } } +// r += (t.coeff*t.value) +func (s *solution) accumulateInto(t compiled.Term, r *fr.Element) { + cID := t.CoeffID() + vID := t.WireID() + switch cID { + case compiled.CoeffIdZero: + return + case compiled.CoeffIdOne: + r.Add(r, &s.values[vID]) + case compiled.CoeffIdTwo: + var res fr.Element + res.Double(&s.values[vID]) + r.Add(r, &res) + case compiled.CoeffIdMinusOne: + r.Sub(r, &s.values[vID]) + default: + var res fr.Element + res.Mul(&s.coefficients[cID], &s.values[vID]) + r.Add(r, &res) + } +} + func (s *solution) computeLinearExpression(l compiled.LinearExpression) fr.Element { var res fr.Element - for i := 0; i < len(l); i++ { - v := s.computeTerm(l[i]) - res.Add(&res, &v) + for _, t := range l { + s.accumulateInto(t, &res) } return res } diff --git a/internal/generator/backend/template/representations/r1cs.go.tmpl b/internal/generator/backend/template/representations/r1cs.go.tmpl index 60cb8df4b7..9a1de5980a 100644 --- a/internal/generator/backend/template/representations/r1cs.go.tmpl +++ b/internal/generator/backend/template/representations/r1cs.go.tmpl @@ -78,6 +78,7 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverConfig) ( // check if there is an inconsistant constraint var check fr.Element + var solved bool // for each constraint // we are guaranteed that each R1C contains at most one unsolved wire @@ -86,7 +87,8 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverConfig) ( // if a[i] * b[i] != c[i]; it means the constraint is not satisfied for i := 0; i < len(cs.Constraints); i++ { // solve the constraint, this will compute the missing wire of the gate - if err := cs.solveConstraint(cs.Constraints[i], &solution); err != nil { + solved, a[i], b[i], c[i], err = cs.solveConstraint(cs.Constraints[i], &solution) + if err != nil { if dID, ok := cs.MDebug[i]; ok { debugInfoStr := solution.logValue(cs.DebugInfo[dID]) return solution.values, fmt.Errorf("%w: %s", err, debugInfoStr) @@ -94,8 +96,10 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverConfig) ( return solution.values, err } - // compute values for the R1C (ie value * coeff) - a[i], b[i], c[i] = cs.instantiateR1C(cs.Constraints[i], &solution) + if solved { + // a[i] * b[i] == c[i], since we just computed it. + continue + } // ensure a[i] * b[i] == c[i] check.Mul(&a[i], &b[i]) @@ -132,8 +136,8 @@ func (cs *R1CS) IsSolved(witness *witness.Witness, opts ...backend.ProverOption) return err } -// mulByCoeff sets res = res * t.Coeff -func (cs *R1CS) mulByCoeff(res *fr.Element, t compiled.Term) { +// divByCoeff sets res = res / t.Coeff +func (cs *R1CS) divByCoeff(res *fr.Element, t compiled.Term) { cID := t.CoeffID() switch cID { case compiled.CoeffIdOne: @@ -141,133 +145,120 @@ func (cs *R1CS) mulByCoeff(res *fr.Element, t compiled.Term) { case compiled.CoeffIdMinusOne: res.Neg(res) case compiled.CoeffIdZero: - res.SetZero() - case compiled.CoeffIdTwo: - res.Double(res) + panic("division by 0") default: - res.Mul(res, &cs.Coefficients[cID]) + // this is slow, but shouldn't happen as divByCoeff is called to + // remove the coeff of an unsolved wire + // but unsolved wires are (in gnark frontend) systematically set with a coeff == 1 or -1 + res.Div(res, &cs.Coefficients[cID]) } } -// compute left, right, o part of a cs constraint -// this function is called when all the wires have been computed -// it instantiates the l, r o part of a R1C -func (cs *R1CS) instantiateR1C(r compiled.R1C, solution *solution) (a, b, c fr.Element) { - var v fr.Element - for _, t := range r.L.LinExp { - v = solution.computeTerm(t) - a.Add(&a, &v) - } - for _, t := range r.R.LinExp { - v = solution.computeTerm(t) - b.Add(&b, &v) - } - for _, t := range r.O.LinExp { - v = solution.computeTerm(t) - c.Add(&c, &v) - } - return -} -// solveR1c computes a wire by solving a cs -// the function searches for the unset wire (either the unset wire is -// alone, or it can be computed without ambiguity using the other computed wires -// , eg when doing a binary decomposition: either way the missing wire can -// be computed without ambiguity because the cs is correctly ordered) -// -// It returns the 1 if the the position to solve is in the quadratic part (it -// means that there is a division and serves to navigate in the log info for the -// computational constraints), and 0 otherwise. -func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) error { + +// solveConstraint compute unsolved wires in the constraint, if any and set the solution accordingly +// +// returns an error if the solver called a hint function that errored +// returns false, nil if there was no wire to solve +// returns true, nil if exactly one wire was solved. In that case, it is redundant to check that +// the constraint is satisfied later. +func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (solved bool, a,b,c fr.Element, err error) { // the index of the non zero entry shows if L, R or O has an uninstantiated wire // the content is the ID of the wire non instantiated var loc uint8 - var a, b, c fr.Element var termToCompute compiled.Term - processTerm := func(t compiled.Term, val *fr.Element, locValue uint8) error { - vID := t.WireID() + processLExp := func(l compiled.LinearExpression, val *fr.Element, locValue uint8) error { + for _, t := range l { + vID := t.WireID() - // wire is already computed, we just accumulate in val - if solution.solved[vID] { - v := solution.computeTerm(t) - val.Add(val, &v) - return nil - } + // wire is already computed, we just accumulate in val + if solution.solved[vID] { + solution.accumulateInto(t, val) + continue + } - // first we check if this is a hint wire - if hint, ok := cs.MHints[vID]; ok { - if err := solution.solveWithHint(vID, hint); err != nil { - return err + // first we check if this is a hint wire + if hint, ok := cs.MHints[vID]; ok { + if err := solution.solveWithHint(vID, hint); err != nil { + return err + } + // now that the wire is saved, accumulate it into a, b or c + solution.accumulateInto(t, val) + continue } - v := solution.computeTerm(t) - val.Add(val, &v) - return nil - } - if loc != 0 { - panic("found more than one wire to instantiate") + if loc != 0 { + panic("found more than one wire to instantiate") + } + termToCompute = t + loc = locValue } - termToCompute = t - loc = locValue return nil } - for _, t := range r.L.LinExp { - if err := processTerm(t, &a, 1); err != nil { - return err - } + if err = processLExp(r.L.LinExp, &a, 1); err != nil { + return } - for _, t := range r.R.LinExp { - if err := processTerm(t, &b, 2); err != nil { - return err - } + if err = processLExp(r.R.LinExp, &b, 2); err != nil { + return } - for _, t := range r.O.LinExp { - if err := processTerm(t, &c, 3); err != nil { - return err - } + if err = processLExp(r.O.LinExp, &c, 3); err != nil { + return } if loc == 0 { // there is nothing to solve, may happen if we have an assertion // (ie a constraints that doesn't yield any output) // or if we solved the unsolved wires with hint functions - return nil + return } // we compute the wire value and instantiate it + solved = true vID := termToCompute.WireID() // solver result var wire fr.Element + switch loc { case 1: if !b.IsZero() { wire.Div(&c, &b). Sub(&wire, &a) - cs.mulByCoeff(&wire, termToCompute) + a.Add(&a, &wire) + } else { + // we didn't actually ensure that a * b == c + solved = false } case 2: if !a.IsZero() { wire.Div(&c, &a). Sub(&wire, &b) - cs.mulByCoeff(&wire, termToCompute) + b.Add(&b, &wire) + } else { + // we didn't actually ensure that a * b == c + solved = false } case 3: wire.Mul(&a, &b). Sub(&wire, &c) - cs.mulByCoeff(&wire, termToCompute) + + c.Add(&c, &wire) } + // wire is the term (coeff * value) + // but in the solution we want to store the value only + // note that in gnark frontend, coeff here is always 1 or -1 + cs.divByCoeff(&wire, termToCompute) solution.set(vID, wire) - return nil + return } // GetConstraints return a list of constraint formatted as L⋅R == O diff --git a/internal/generator/backend/template/representations/solution.go.tmpl b/internal/generator/backend/template/representations/solution.go.tmpl index ab64434800..701fd3a234 100644 --- a/internal/generator/backend/template/representations/solution.go.tmpl +++ b/internal/generator/backend/template/representations/solution.go.tmpl @@ -81,11 +81,32 @@ func (s *solution) computeTerm(t compiled.Term) fr.Element { } } +// r += (t.coeff*t.value) +func (s *solution) accumulateInto(t compiled.Term, r *fr.Element) { + cID := t.CoeffID() + vID := t.WireID() + switch cID { + case compiled.CoeffIdZero: + return + case compiled.CoeffIdOne: + r.Add(r, &s.values[vID]) + case compiled.CoeffIdTwo: + var res fr.Element + res.Double(&s.values[vID]) + r.Add(r, &res) + case compiled.CoeffIdMinusOne: + r.Sub(r, &s.values[vID]) + default: + var res fr.Element + res.Mul(&s.coefficients[cID], &s.values[vID]) + r.Add(r, &res) + } +} + func (s *solution) computeLinearExpression(l compiled.LinearExpression) fr.Element { var res fr.Element - for i := 0; i < len(l); i++ { - v := s.computeTerm(l[i]) - res.Add(&res, &v) + for _, t := range l { + s.accumulateInto(t, &res) } return res } diff --git a/internal/generator/backend/template/representations/tests/r1cs.go.tmpl b/internal/generator/backend/template/representations/tests/r1cs.go.tmpl index 61d3b28776..852a1583fe 100644 --- a/internal/generator/backend/template/representations/tests/r1cs.go.tmpl +++ b/internal/generator/backend/template/representations/tests/r1cs.go.tmpl @@ -103,4 +103,43 @@ func TestSerialization(t *testing.T) { }) } +} + + +const n = 10000 + +type circuit struct { + X frontend.Variable + Y frontend.Variable `gnark:",public"` +} + +func (circuit *circuit) Define(api frontend.API) error { + for i := 0; i < n; i++ { + circuit.X = api.Add(api.Mul(circuit.X, circuit.X), circuit.X, 42) + } + api.AssertIsEqual(circuit.X, circuit.Y) + return nil +} + +func BenchmarkSolve(b *testing.B) { + + var c circuit + ccs, err := frontend.Compile(ecc.{{ .CurveID }}, backend.UNKNOWN, &c, frontend.WithBuilder(r1cs.NewBuilder)) + if err != nil { + b.Fatal(err) + } + + var w circuit + w.X = 1 + w.Y = 1 + witness, err := frontend.NewWitness(&w, ecc.{{ .CurveID }}) + if err != nil { + b.Fatal(err) + } + + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = ccs.IsSolved(witness) + } } \ No newline at end of file