diff --git a/.github/workflows/develop.yml b/.github/workflows/develop.yml index 02b7be30e9..37a875a26a 100644 --- a/.github/workflows/develop.yml +++ b/.github/workflows/develop.yml @@ -81,11 +81,10 @@ jobs: - name: Test run: | go test -v -timeout=30m ./... - go test -v -timeout=30m -tags=noadx -short - name: Test (race) if: matrix.os == 'ubuntu-latest' run: | - go test -v -timeout=30m -race -short ./... + go test -v -timeout=50m -race -short ./... # - name: Test (32bits) # if: matrix.os == 'ubuntu-latest' # run: | diff --git a/debug_test.go b/debug_test.go index 58964290a7..8fd61d4738 100644 --- a/debug_test.go +++ b/debug_test.go @@ -137,7 +137,7 @@ func TestTraceNotEqual(t *testing.T) { // ------------------------------------------------------------------------------------------------- // Not boolean type notBooleanTrace struct { - A, B, C frontend.Variable + B, C frontend.Variable } func (circuit *notBooleanTrace) Define(curveID ecc.ID, api frontend.API) error { @@ -150,7 +150,7 @@ func TestTraceNotBoolean(t *testing.T) { assert := require.New(t) var circuit, witness notBooleanTrace - witness.A.Assign(1) + // witness.A.Assign(1) witness.B.Assign(24) witness.C.Assign(42) diff --git a/examples/rollup/circuit_test.go b/examples/rollup/circuit_test.go index 454d1692b7..26807943e1 100644 --- a/examples/rollup/circuit_test.go +++ b/examples/rollup/circuit_test.go @@ -79,7 +79,7 @@ func TestCircuitSignature(t *testing.T) { assert := test.NewAssert(t) var signatureCircuit circuitSignature - assert.ProverSucceeded(&signatureCircuit, &operator.witnesses, test.WithCurves(ecc.BN254)) + assert.ProverSucceeded(&signatureCircuit, &operator.witnesses, test.WithCurves(ecc.BN254), test.WithCompileOpts(frontend.IgnoreUnconstrainedInputs)) } @@ -145,7 +145,7 @@ func TestCircuitInclusionProof(t *testing.T) { var inclusionProofCircuit circuitInclusionProof - assert.ProverSucceeded(&inclusionProofCircuit, &operator.witnesses, test.WithCurves(ecc.BN254)) + assert.ProverSucceeded(&inclusionProofCircuit, &operator.witnesses, test.WithCurves(ecc.BN254), test.WithCompileOpts(frontend.IgnoreUnconstrainedInputs)) } @@ -202,7 +202,7 @@ func TestCircuitUpdateAccount(t *testing.T) { var updateAccountCircuit circuitUpdateAccount - assert.ProverSucceeded(&updateAccountCircuit, &operator.witnesses, test.WithCurves(ecc.BN254)) + assert.ProverSucceeded(&updateAccountCircuit, &operator.witnesses, test.WithCurves(ecc.BN254), test.WithCompileOpts(frontend.IgnoreUnconstrainedInputs)) } @@ -246,6 +246,7 @@ func TestCircuitFull(t *testing.T) { var rollupCircuit Circuit - assert.ProverSucceeded(&rollupCircuit, &operator.witnesses, test.WithCurves(ecc.BN254)) + // TODO full circuit has some unconstrained inputs, that's odd. + assert.ProverSucceeded(&rollupCircuit, &operator.witnesses, test.WithCurves(ecc.BN254), test.WithCompileOpts(frontend.IgnoreUnconstrainedInputs)) } diff --git a/frontend/cs.go b/frontend/cs.go index afde473e7d..26a5361a0f 100644 --- a/frontend/cs.go +++ b/frontend/cs.go @@ -17,9 +17,12 @@ limitations under the License. package frontend import ( + "errors" "io" "math/big" "sort" + "strconv" + "strings" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/backend/hint" @@ -36,7 +39,8 @@ type constraintSystem struct { // Variables (aka wires) // virtual variables do not result in a new circuit wire // they may only contain a linear expression - public, secret, internal, virtual variables + public, secret inputs + internal, virtual variables // list of constraints in the form a * b == c // a,b and c being linear expressions @@ -48,7 +52,8 @@ type constraintSystem struct { coeffsIDsInt64 map[int64]int // map to check existence of a coefficient (key = int64 value) // Hints - mHints map[int]compiled.Hint // solver hints + mHints map[int]compiled.Hint // solver hints + mHintsConstrained map[int]bool // marks hints variables constrained status logs []compiled.LogEntry // list of logs to be printed when solving a circuit. The logs are called with the method Println debugInfo []compiled.LogEntry // list of logs storing information about R1C @@ -63,6 +68,16 @@ type variables struct { booleans map[int]struct{} // keep track of boolean variables (we constrain them once) } +type inputs struct { + variables + names []string +} + +func (v *inputs) new(cs *constraintSystem, visibility compiled.Visibility, name string) Variable { + v.names = append(v.names, name) + return v.variables.new(cs, visibility) +} + func (v *variables) new(cs *constraintSystem, visibility compiled.Visibility) Variable { idx := len(v.variables) variable := Variable{visibility: visibility, id: idx, linExp: cs.LinearExpression(compiled.Pack(idx, compiled.CoeffIdOne, visibility))} @@ -96,12 +111,13 @@ func newConstraintSystem(curveID ecc.ID, initialCapacity ...int) constraintSyste capacity = initialCapacity[0] } cs := constraintSystem{ - coeffs: make([]big.Int, 4), - coeffsIDsLarge: make(map[string]int), - coeffsIDsInt64: make(map[int64]int, 4), - constraints: make([]compiled.R1C, 0, capacity), - mDebug: make(map[int]int), - mHints: make(map[int]compiled.Hint), + coeffs: make([]big.Int, 4), + coeffsIDsLarge: make(map[string]int), + coeffsIDsInt64: make(map[int64]int, 4), + constraints: make([]compiled.R1C, 0, capacity), + mDebug: make(map[int]int), + mHints: make(map[int]compiled.Hint), + mHintsConstrained: make(map[int]bool), } cs.coeffs[compiled.CoeffIdZero].SetInt64(0) @@ -114,10 +130,10 @@ func newConstraintSystem(curveID ecc.ID, initialCapacity ...int) constraintSyste cs.coeffsIDsInt64[2] = compiled.CoeffIdTwo cs.coeffsIDsInt64[-1] = compiled.CoeffIdMinusOne - cs.public.variables = make([]Variable, 0) + cs.public.variables.variables = make([]Variable, 0) cs.public.booleans = make(map[int]struct{}) - cs.secret.variables = make([]Variable, 0) + cs.secret.variables.variables = make([]Variable, 0) cs.secret.booleans = make(map[int]struct{}) cs.internal.variables = make([]Variable, 0, capacity) @@ -127,7 +143,7 @@ func newConstraintSystem(curveID ecc.ID, initialCapacity ...int) constraintSyste cs.virtual.booleans = make(map[int]struct{}) // by default the circuit is given on public wire equal to 1 - cs.public.variables[0] = cs.newPublicVariable() + cs.public.variables.variables[0] = cs.newPublicVariable("one") cs.curveID = curveID @@ -146,6 +162,9 @@ func (cs *constraintSystem) NewHint(hintID hint.ID, inputs ...interface{}) Varia // create resulting wire r := cs.newInternalVariable() + // mark hint as unconstrained, for now + cs.mHintsConstrained[r.id] = false + // now we need to store the linear expressions of the expected input // that will be resolved in the solver hintInputs := make([]compiled.LinearExpression, len(inputs)) @@ -168,7 +187,7 @@ func (cs *constraintSystem) bitLen() int { } func (cs *constraintSystem) one() Variable { - return cs.public.variables[0] + return cs.public.variables.variables[0] } // Term packs a variable and a coeff in a compiled.Term and returns it. @@ -287,13 +306,13 @@ func (cs *constraintSystem) newInternalVariable() Variable { } // newPublicVariable creates a new public variable -func (cs *constraintSystem) newPublicVariable() Variable { - return cs.public.new(cs, compiled.Public) +func (cs *constraintSystem) newPublicVariable(name string) Variable { + return cs.public.new(cs, compiled.Public, name) } // newSecretVariable creates a new secret variable -func (cs *constraintSystem) newSecretVariable() Variable { - return cs.secret.new(cs, compiled.Secret) +func (cs *constraintSystem) newSecretVariable(name string) Variable { + return cs.secret.new(cs, compiled.Secret, name) } // newVirtualVariable creates a new virtual variable @@ -333,3 +352,101 @@ func (cs *constraintSystem) markBoolean(v Variable) bool { } return true } + +// checkVariables perform post compilation checks on the variables +// +// 1. checks that all user inputs are referenced in at least one constraint +// 2. checks that all hints are constrained +func (cs *constraintSystem) checkVariables() error { + + // TODO @gbotrel add unit test for that. + + cptSecret := len(cs.secret.variables.variables) + cptPublic := len(cs.public.variables.variables) - 1 + cptHints := len(cs.mHintsConstrained) + + secretConstrained := make([]bool, cptSecret) + publicConstrained := make([]bool, cptPublic+1) + publicConstrained[0] = true + + // for each constraint, we check the linear expressions and mark our inputs / hints as constrained + processLinearExpression := func(l compiled.LinearExpression) { + for _, t := range l { + if t.CoeffID() == compiled.CoeffIdZero { + // ignore zero coefficient, as it does not constraint the variable + // though, we may want to flag that IF the variable doesn't appear else where + continue + } + visibility := t.VariableVisibility() + vID := t.VariableID() + + switch visibility { + case compiled.Public: + if vID != 0 && !publicConstrained[vID] { + publicConstrained[vID] = true + cptPublic-- + } + case compiled.Secret: + if !secretConstrained[vID] { + secretConstrained[vID] = true + cptSecret-- + } + case compiled.Internal: + if b, ok := cs.mHintsConstrained[vID]; ok && !b { + cs.mHintsConstrained[vID] = true + cptHints-- + } + } + } + } + for _, r1c := range cs.constraints { + processLinearExpression(r1c.L) + processLinearExpression(r1c.R) + processLinearExpression(r1c.O) + + if cptHints|cptSecret|cptPublic == 0 { + return nil // we can stop. + } + + } + + // something is a miss, we build the error string + var sbb strings.Builder + if cptSecret != 0 { + sbb.WriteString(strconv.Itoa(cptSecret)) + sbb.WriteString(" unconstrained secret input(s):") + sbb.WriteByte('\n') + for i := 0; i < len(secretConstrained) && cptSecret != 0; i++ { + if !secretConstrained[i] { + sbb.WriteString(cs.secret.names[i]) + sbb.WriteByte('\n') + cptSecret-- + } + } + sbb.WriteByte('\n') + } + + if cptPublic != 0 { + sbb.WriteString(strconv.Itoa(cptPublic)) + sbb.WriteString(" unconstrained public input(s):") + sbb.WriteByte('\n') + for i := 0; i < len(publicConstrained) && cptPublic != 0; i++ { + if !publicConstrained[i] { + sbb.WriteString(cs.public.names[i]) + sbb.WriteByte('\n') + cptPublic-- + } + } + sbb.WriteByte('\n') + } + + if cptHints != 0 { + sbb.WriteString(strconv.Itoa(cptHints)) + sbb.WriteString(" unconstrained hints") + sbb.WriteByte('\n') + // TODO we may add more debug info here --> idea, in NewHint, take the debug stack, and store in the hint map some + // debugInfo to find where a hint was declared (and not constrained) + } + return errors.New(sbb.String()) + +} diff --git a/frontend/cs_api.go b/frontend/cs_api.go index d97cb5f392..d9a9a576f7 100644 --- a/frontend/cs_api.go +++ b/frontend/cs_api.go @@ -441,13 +441,14 @@ func (cs *constraintSystem) Select(i0, i1, i2 interface{}) Variable { // ensures that b is boolean cs.AssertIsBoolean(b) - if b.isConstant() { - c := b.constantValue(cs) - if c.Uint64() == 0 { - return vars[2] - } - return vars[1] - } + // this doesn't work. + // if b.isConstant() { + // c := b.constantValue(cs) + // if c.Uint64() == 0 { + // return vars[2] + // } + // return vars[1] + // } if vars[1].isConstant() && vars[2].isConstant() { n1 := vars[1].constantValue(cs) diff --git a/frontend/cs_api_test.go b/frontend/cs_api_test.go index 551fc5b32b..724c489080 100644 --- a/frontend/cs_api_test.go +++ b/frontend/cs_api_test.go @@ -27,7 +27,7 @@ import ( func TestPrintln(t *testing.T) { // must not panic. cs := newConstraintSystem(ecc.BN254) - one := cs.newPublicVariable() + one := cs.newPublicVariable("one") cs.Println(nil) cs.Println(1) diff --git a/frontend/cs_to_r1cs.go b/frontend/cs_to_r1cs.go index 9c6c2ed4f1..e0f4653a0c 100644 --- a/frontend/cs_to_r1cs.go +++ b/frontend/cs_to_r1cs.go @@ -36,8 +36,8 @@ func (cs *constraintSystem) toR1CS(curveID ecc.ID) (CompiledConstraintSystem, er res := compiled.R1CS{ CS: compiled.CS{ NbInternalVariables: len(cs.internal.variables), - NbPublicVariables: len(cs.public.variables), - NbSecretVariables: len(cs.secret.variables), + NbPublicVariables: len(cs.public.variables.variables), + NbSecretVariables: len(cs.secret.variables.variables), DebugInfo: make([]compiled.LogEntry, len(cs.debugInfo)), Logs: make([]compiled.LogEntry, len(cs.logs)), MHints: make(map[int]compiled.Hint, len(cs.mHints)), @@ -68,11 +68,11 @@ func (cs *constraintSystem) toR1CS(curveID ecc.ID) (CompiledConstraintSystem, er shiftVID := func(oldID int, visibility compiled.Visibility) int { switch visibility { case compiled.Internal: - return oldID + len(cs.public.variables) + len(cs.secret.variables) + return oldID + len(cs.public.variables.variables) + len(cs.secret.variables.variables) case compiled.Public: return oldID case compiled.Secret: - return oldID + len(cs.public.variables) + return oldID + len(cs.public.variables.variables) } return oldID } diff --git a/frontend/cs_to_r1cs_sparse.go b/frontend/cs_to_r1cs_sparse.go index 38c7de5208..836d1eba08 100644 --- a/frontend/cs_to_r1cs_sparse.go +++ b/frontend/cs_to_r1cs_sparse.go @@ -18,7 +18,6 @@ package frontend import ( "math/big" - "math/bits" "sort" "sync" @@ -55,7 +54,14 @@ type sparseR1CS struct { // map LinearExpression -> Term. The goal is to not reduce // the same linear expression twice. - record map[uint64][]innerRecord + // key == hashCode(linearExpression) (with collisions) + // value == list of tuples {LinearExpression; reduced resulting Term} + reducedLE map[uint64][]innerRecord + + // similarly to reducedLE, excepts, the key is the hashCodeNC() which doesn't take + // into account the coefficient value of the terms + // this is used to detect if a "similar" linear expression was already recorded when splitting + reducedLE_ map[uint64]struct{} } type innerRecord struct { @@ -72,8 +78,8 @@ func (cs *constraintSystem) toSparseR1CS(curveID ecc.ID) (CompiledConstraintSyst ccs: compiled.SparseR1CS{ CS: compiled.CS{ NbInternalVariables: len(cs.internal.variables), - NbPublicVariables: len(cs.public.variables) - 1, // the ONE_WIRE is discarded in PlonK - NbSecretVariables: len(cs.secret.variables), + NbPublicVariables: len(cs.public.variables.variables) - 1, // the ONE_WIRE is discarded in PlonK + NbSecretVariables: len(cs.secret.variables.variables), DebugInfo: make([]compiled.LogEntry, len(cs.debugInfo)), Logs: make([]compiled.LogEntry, len(cs.logs)), MDebug: make(map[int]int), @@ -84,7 +90,8 @@ func (cs *constraintSystem) toSparseR1CS(curveID ecc.ID) (CompiledConstraintSyst solvedVariables: make([]bool, len(cs.internal.variables), len(cs.internal.variables)*2), scsInternalVariables: len(cs.internal.variables), currentR1CDebugID: -1, - record: make(map[uint64][]innerRecord, len(cs.internal.variables)), + reducedLE: make(map[uint64][]innerRecord, len(cs.internal.variables)), + reducedLE_: make(map[uint64]struct{}, len(cs.internal.variables)), } // logs, debugInfo and hints are copied, the only thing that will change @@ -261,54 +268,70 @@ func popInternalVariable(l compiled.LinearExpression, id int) (compiled.LinearEx return _l, t } -func gcdInt64(_u, _v int64) uint64 { - var u, v uint64 - if _u < 0 { - u = uint64(-_u) - } else { - u = uint64(_u) - } - if _v < 0 { - v = uint64(-_v) - } else { - v = uint64(_v) - } +// as computeGCD, except, it fills the intermediate values such that gcds[i] == gcd(l[:i]) +func (scs *sparseR1CS) computeGCDs(l compiled.LinearExpression, gcds []*big.Int) { + mustNeg := scs.coeffs[l[0].CoeffID()].Sign() == -1 - if u == 0 { - return v - } - if v == 0 { - return u - } - if u == v { - return u + gcds[0].Set(&scs.coeffs[l[0].CoeffID()]) + if mustNeg { + gcds[0].Neg(gcds[0]) } - tu := bits.TrailingZeros64(u) - tv := bits.TrailingZeros64(v) - u >>= tu - v >>= tv + for i := 1; i < len(l); i++ { + cID := l[i].CoeffID() - for { - if u > v { - v, u = u, v + if gcds[i-1].IsUint64() { + // can be 0 or 1 + prev := gcds[i-1].Uint64() + if prev == 0 { + gcds[i].Abs(&scs.coeffs[cID]) + continue + } else if prev == 1 { + // set the rest to 1. + for ; i < len(l); i++ { + gcds[i].SetUint64(1) + } + continue + } } - v = v - u - if v == 0 { - break + + // we check coeffID here for 1 or minus 1 + if cID == compiled.CoeffIdMinusOne || cID == compiled.CoeffIdOne { + gcds[i].SetUint64(1) + continue + } + + if cID == compiled.CoeffIdZero { + gcds[i].Set(gcds[i-1]) + continue } - v >>= bits.TrailingZeros64(v) - } - if tu < tv { - return u << tu + // we compute the gcd. + gcds[i].GCD(nil, nil, gcds[i-1], &scs.coeffs[cID]) } - return u << tv + if mustNeg { + for i := 1; i < len(l); i++ { + gcds[i].Neg(gcds[i]) + } + } + } // returns ( b/computeGCD(b...), computeGCD(b...) ) // if gcd is != 0 and gcd != 1, returns true func (scs *sparseR1CS) computeGCD(l compiled.LinearExpression, gcd *big.Int) { + mustNeg := scs.coeffs[l[0].CoeffID()].Sign() == -1 + + // fast path: if any of the coeffs is 1 or -1, no need to compute the GCD + if hasOnes(l) { + if mustNeg { + gcd.SetInt64(-1) + return + } + gcd.SetUint64(1) + return + } + gcd.SetUint64(0) var i int for i = 0; i < len(l); i++ { @@ -327,17 +350,17 @@ func (scs *sparseR1CS) computeGCD(l compiled.LinearExpression, gcd *big.Int) { } other := &scs.coeffs[cID] - if gcd.IsInt64() && other.IsInt64() { - gcd.SetUint64(gcdInt64(gcd.Int64(), other.Int64())) - } else { - gcd.GCD(nil, nil, gcd, other) - } - + gcd.GCD(nil, nil, gcd, other) if gcd.IsUint64() && gcd.Uint64() == 1 { break } } + if mustNeg { + // ensure the gcd doesn't depend on the sign + gcd.Neg(gcd) + } + } // return true if linear expression contains one or minusOne coefficients @@ -351,77 +374,31 @@ func hasOnes(l compiled.LinearExpression) bool { return false } -// reduce sets gcd = gcd(l.coefs) and returns l/gcd(l.coefs) -// if gcd == 1, this returns l -func (scs *sparseR1CS) reduce(l compiled.LinearExpression, gcd *big.Int) compiled.LinearExpression { - mustNeg := scs.coeffs[l[0].CoeffID()].Sign() == -1 - - // fast path: if any of the coeffs is 1 or -1, no need to compute the GCD - if hasOnes(l) { - if !mustNeg { - gcd.SetUint64(1) - return l - } - gcd.SetInt64(-1) - return scs.divideLinearExpression(l, gcd) - } - - // compute gcd - scs.computeGCD(l, gcd) - - if mustNeg { - // ensure the gcd doesn't depend on the sign - gcd.Neg(gcd) - } - - if gcd.IsUint64() && (gcd.Uint64() == 0 || gcd.Uint64() == 1) { - // no need to create a new linear expression +// divides all coefficients in l by divisor +// if divisor == 0 or divisor == 1, returns l +func (scs *sparseR1CS) divideLinearExpression(l, r compiled.LinearExpression, divisor *big.Int) compiled.LinearExpression { + if divisor.IsUint64() && (divisor.Uint64() == 0 || divisor.Uint64() == 1) { return l } - return scs.divideLinearExpression(l, gcd) - -} - -// pre-conditions: d != 0 && d != 1 && d divides all the coefficients in l -func (scs *sparseR1CS) divideLinearExpression(l compiled.LinearExpression, gcd *big.Int) compiled.LinearExpression { // copy linear expression - r := make(compiled.LinearExpression, len(l)) + if r == nil { + r = make(compiled.LinearExpression, len(l)) + } copy(r, l) // new coeff lambda := bigIntPool.Get().(*big.Int) - if gcd.IsInt64() { - if gcd.Int64() == -1 { - for i := 0; i < len(r); i++ { - cID := r[i].CoeffID() - if cID == compiled.CoeffIdZero { - continue - } - lambda.Neg(&scs.coeffs[cID]) - r[i].SetCoeffID(scs.coeffID(lambda)) - } - } else { - _gcd := gcd.Int64() - for i := 0; i < len(r); i++ { - cID := r[i].CoeffID() - if cID == compiled.CoeffIdZero { - continue - } - other := scs.coeffs[cID] - if other.IsInt64() { - // we do int64 division and avoid calling coeffID - l := other.Int64() / _gcd - r[i].SetCoeffID(scs.coeffID64(l)) - } else { - // we use Quo here instead of Div, as we know there is no remainder - lambda.Quo(&scs.coeffs[cID], gcd) - r[i].SetCoeffID(scs.coeffID(lambda)) - } + if divisor.IsInt64() && divisor.Int64() == -1 { + for i := 0; i < len(r); i++ { + cID := r[i].CoeffID() + if cID == compiled.CoeffIdZero { + continue } + lambda.Neg(&scs.coeffs[cID]) + r[i].SetCoeffID(scs.coeffID(lambda)) } - bigIntPool.Put(lambda) return r } @@ -432,7 +409,7 @@ func (scs *sparseR1CS) divideLinearExpression(l compiled.LinearExpression, gcd * continue } // we use Quo here instead of Div, as we know there is no remainder - lambda.Quo(&scs.coeffs[cID], gcd) + lambda.Quo(&scs.coeffs[cID], divisor) r[i].SetCoeffID(scs.coeffID(lambda)) } @@ -565,8 +542,10 @@ func (scs *sparseR1CS) multiply(t compiled.Term, c *big.Int) compiled.Term { return t } -func (scs *sparseR1CS) getRecord(l compiled.LinearExpression) (compiled.Term, bool) { - list, ok := scs.record[l.Hash()] +// l is primitive +// that is, it has been factorized and we can't divide the coefficients further +func (scs *sparseR1CS) wasReduced(l compiled.LinearExpression) (compiled.Term, bool) { + list, ok := scs.reducedLE[hashCode(l)] if !ok { return 0, false } @@ -580,56 +559,117 @@ func (scs *sparseR1CS) getRecord(l compiled.LinearExpression) (compiled.Term, bo return 0, false } -func (scs *sparseR1CS) putRecord(l compiled.LinearExpression, t compiled.Term) { - id := l.Hash() - list := scs.record[id] +// l is primitive +// that is, it has been factorized and we can't divide the coefficients further +func (scs *sparseR1CS) markReduced(l compiled.LinearExpression, t compiled.Term, ncHashCode uint64) { + id := hashCode(l) + list := scs.reducedLE[id] + // here we know l is not already in the list, since the call to wasReduced returned false list = append(list, innerRecord{t: t, l: l}) - scs.record[id] = list + scs.reducedLE[id] = list + scs.reducedLE_[ncHashCode] = struct{}{} } +// split decomposes the linear expression into a single term +// for example 2a + 3b + c will be decomposed in +// v0 := 2a + 3b +// v1 := v0 + c +// return v1 +// +// for optimal output, one need to check if we can't reuse previous decompositions to avoid duplicate constraints func (scs *sparseR1CS) split(l compiled.LinearExpression) compiled.Term { - // floor case if len(l) == 1 { return l[0] } - lGCD := bigIntPool.Get().(*big.Int) - // check if l is recorded, if so we get it from the record - lReduced := scs.reduce(l, lGCD) - if t, ok := scs.getRecord(lReduced); ok { - t.SetCoeffID(scs.coeffID(lGCD)) - bigIntPool.Put(lGCD) + gcd := bigIntPool.Get().(*big.Int) + + // lf = gcd * l + // compute the GCD + scs.computeGCD(l, gcd) + + // divide if needed l by gcd + lf := scs.divideLinearExpression(l, nil, gcd) + // if we already recorded lf, the resulting term is gcd * t + if t, ok := scs.wasReduced(lf); ok { + t.SetCoeffID(scs.coeffID(gcd)) + bigIntPool.Put(gcd) return t } - // find if in the left side the constraint is recorded - gcd := bigIntPool.Get().(*big.Int) + // we create a new resulting term for this linear expression + // o correspond to the factorized linear expression lf = l / gcd + // r correspond to the initial linear expression l + // we record the factorized linear expression for potential later use + o := scs.newTerm(bOne) + r := scs.multiply(o, gcd) + scs.markReduced(lf, o, hashCodeNC(lf)) + bigIntPool.Put(gcd) + + var gcds []*big.Int + var scratch compiled.LinearExpression + + // idea: find an existing reduction that partially matches l + + // we compute a hash code of the sub expression that takes into account variables id and visibility + // but not the coeffID. Since this is computed recursively, we store the result up for each lf[:i] + hcs := hashCodeNC_(lf) + + for i := len(lf) - 1; i > 0; i-- { + + // first, we probabilistically check if it's worth it to factorize the sub expression + if _, ok := scs.reducedLE_[hcs[i-1]]; !ok { + // no need to factorize, no linear expression with same variables exist. + continue + } + + // we need to factorize, so since gcd (a,b,c) == gcd ( gcd (a,b), c) + // we compute all gcds up to lf[:i] to use in future iterations + if gcds == nil { + gcds = make([]*big.Int, i) + for i := 0; i < len(gcds); i++ { + gcds[i] = bigIntPool.Get().(*big.Int) + } + scs.computeGCDs(lf[:i], gcds) + scratch = make(compiled.LinearExpression, i) + } + + // we divide the linear expression by the gcd, same idea as above + // note that lff here reuses scratch space, but we never store it, we just compute + // a hash code on it so we're fine + lff := scs.divideLinearExpression(lf[:i], scratch[:i], gcds[i-1]) + + if t, ok := scs.wasReduced(lff); ok { + // the lff was already reduced + // so we return r such that + // r = (gcd * lff) + reduce(lf[i:]) + scs.addConstraint(compiled.SparseR1C{ + L: scs.multiply(t, gcds[i-1]), + R: scs.split(lf[i:]), + O: scs.negate(o), + }) + + for i := 0; i < len(gcds); i++ { + bigIntPool.Put(gcds[i]) + } - for i := len(l) - 1; i > 0; i-- { - ll := scs.reduce(lReduced[:i], gcd) - if t, ok := scs.getRecord(ll); ok { - t = scs.multiply(t, gcd) - o := scs.newTerm(bOne) - b := scs.split(lReduced[i:]) - scs.addConstraint(compiled.SparseR1C{L: t, R: b, O: scs.negate(o)}) - scs.putRecord(lReduced, o) - r := scs.multiply(o, lGCD) - bigIntPool.Put(lGCD) - bigIntPool.Put(gcd) return r } } - bigIntPool.Put(gcd) + + for i := 0; i < len(gcds); i++ { + bigIntPool.Put(gcds[i]) + } // else we build the reduction starting from l[0] - o := scs.newTerm(bOne) - a := lReduced[0] - b := scs.split(lReduced[1:]) - scs.addConstraint(compiled.SparseR1C{L: a, R: b, O: scs.negate(o)}) - scs.putRecord(lReduced, o) - r := scs.multiply(o, lGCD) - bigIntPool.Put(lGCD) + // that is we return a term r such that + // r = l[0] + reduced(lf[1:]) + scs.addConstraint(compiled.SparseR1C{ + L: lf[0], + R: scs.split(lf[1:]), + O: scs.negate(o)}, + ) return r } @@ -1207,3 +1247,39 @@ var bigIntPool = sync.Pool{ return new(big.Int) }, } + +// hashCode returns a fast hash of the linear expression; this is not collision resistant +// but two SORTED equal linear expressions will have equal hashes. +// +// pre conditions: l is sorted +func hashCode(l compiled.LinearExpression) uint64 { + hashcode := uint64(1) + for i := 0; i < len(l); i++ { + hashcode = hashcode*31 + uint64(l[i]) + } + return hashcode +} + +// same as hashCode but ignore the coeffID +func hashCodeNC(l compiled.LinearExpression) uint64 { + hashcode := uint64(1) + for i := 0; i < len(l); i++ { + t := l[i] + t.SetCoeffID(0) + hashcode = hashcode*31 + uint64(t) + } + return hashcode +} + +// same as hashCodeNC but return all the intermediate hash codes +func hashCodeNC_(l compiled.LinearExpression) []uint64 { + r := make([]uint64, len(l)) + hashcode := uint64(1) + for i := 0; i < len(l); i++ { + t := l[i] + t.SetCoeffID(0) + hashcode = hashcode*31 + uint64(t) + r[i] = hashcode + } + return r +} diff --git a/frontend/frontend.go b/frontend/frontend.go index 484d0f9992..e774b3c419 100644 --- a/frontend/frontend.go +++ b/frontend/frontend.go @@ -49,14 +49,29 @@ var errInputNotSet = errors.New("variable is not allocated") // // initialCapacity is an optional parameter that reserves memory in slices // it should be set to the estimated number of constraints in the circuit, if known. -func Compile(curveID ecc.ID, zkpID backend.ID, circuit Circuit, initialCapacity ...int) (ccs CompiledConstraintSystem, err error) { +func Compile(curveID ecc.ID, zkpID backend.ID, circuit Circuit, opts ...func(opt *CompileOption) error) (ccs CompiledConstraintSystem, err error) { + + // setup option + opt := CompileOption{} + for _, o := range opts { + if err := o(&opt); err != nil { + return nil, err + } + } // build the constraint system (see Circuit.Define) - cs, err := buildCS(curveID, circuit, initialCapacity...) + cs, err := buildCS(curveID, circuit, opt.capacity) if err != nil { return nil, err } + // ensure all inputs and hints are constrained + if !opt.ignoreUnconstrainedInputs { + if err := cs.checkVariables(); err != nil { + return nil, err + } + } + switch zkpID { case backend.GROTH16: ccs, err = cs.toR1CS(curveID) @@ -101,9 +116,9 @@ func buildCS(curveID ecc.ID, circuit Circuit, initialCapacity ...int) (cs constr } switch visibility { case compiled.Secret: - tInput.Set(reflect.ValueOf(cs.newSecretVariable())) + tInput.Set(reflect.ValueOf(cs.newSecretVariable(name))) case compiled.Public: - tInput.Set(reflect.ValueOf(cs.newPublicVariable())) + tInput.Set(reflect.ValueOf(cs.newPublicVariable(name))) case compiled.Unset: return errors.New("can't set val " + name + " visibility is unset") } @@ -134,3 +149,23 @@ func buildCS(curveID ecc.ID, circuit Circuit, initialCapacity ...int) (cs constr func Value(value interface{}) Variable { return Variable{WitnessValue: value} } + +// CompileOption enables to set optional argument to call of frontend.Compile() +type CompileOption struct { + capacity int + ignoreUnconstrainedInputs bool +} + +// WithOutput is a Compile option that specifies the estimated capacity needed for internal variables and constraints +func WithCapacity(capacity int) func(opt *CompileOption) error { + return func(opt *CompileOption) error { + opt.capacity = capacity + return nil + } +} + +// IgnoreUnconstrainedInputs when set, the Compile function doesn't check for unconstrained inputs +func IgnoreUnconstrainedInputs(opt *CompileOption) error { + opt.ignoreUnconstrainedInputs = true + return nil +} diff --git a/frontend/frontend_test.go b/frontend/frontend_test.go index b2547fc617..1e9864a66e 100644 --- a/frontend/frontend_test.go +++ b/frontend/frontend_test.go @@ -29,7 +29,7 @@ func BenchmarkCompileReferenceGroth16(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - Compile(ecc.BN254, backend.GROTH16, &c, benchSize) + Compile(ecc.BN254, backend.GROTH16, &c, WithCapacity(benchSize)) } } @@ -38,7 +38,7 @@ func BenchmarkCompileReferencePlonk(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - Compile(ecc.BN254, backend.PLONK, &c, benchSize) + Compile(ecc.BN254, backend.PLONK, &c, WithCapacity(benchSize)) } } diff --git a/frontend/fuzz.go b/frontend/fuzz.go index 6b852bb4e2..41ab9b3b19 100644 --- a/frontend/fuzz.go +++ b/frontend/fuzz.go @@ -43,10 +43,10 @@ func CsFuzzed(data []byte, curveID ecc.ID) (ccs CompiledConstraintSystem) { panic(fmt.Sprintf("reading byte from reader errored: %v", err)) } if b&0b00000001 == 1 { - cs.newPublicVariable() + cs.newPublicVariable("x") } if b&0b00000010 == 0b00000010 { - cs.newSecretVariable() + cs.newSecretVariable("y") } if b&0b00000100 == 0b00000100 { // multiplication @@ -133,18 +133,18 @@ compile: func (cs *constraintSystem) shuffleVariables(seed int64, withConstant bool) []interface{} { var v []interface{} - n := len(cs.public.variables) + len(cs.secret.variables) + len(cs.internal.variables) + n := len(cs.public.variables.variables) + len(cs.secret.variables.variables) + len(cs.internal.variables) if withConstant { v = make([]interface{}, 0, n*2+4*3) } else { v = make([]interface{}, 0, n) } - for i := 0; i < len(cs.public.variables); i++ { - v = append(v, cs.public.variables[i]) + for i := 0; i < len(cs.public.variables.variables); i++ { + v = append(v, cs.public.variables.variables[i]) } - for i := 0; i < len(cs.secret.variables); i++ { - v = append(v, cs.secret.variables[i]) + for i := 0; i < len(cs.secret.variables.variables); i++ { + v = append(v, cs.secret.variables.variables[i]) } for i := 0; i < len(cs.internal.variables); i++ { v = append(v, cs.internal.variables[i]) diff --git a/internal/backend/compiled/r1c.go b/internal/backend/compiled/r1c.go index e67620b453..a96051478b 100644 --- a/internal/backend/compiled/r1c.go +++ b/internal/backend/compiled/r1c.go @@ -39,22 +39,6 @@ func (l LinearExpression) Len() int { return len(l) } -// Hash returns a fast hash of the linear expression; this is not collision resistant -// but two SORTED equal linear expressions will have equal hashes. -// -// pre conditions: l is sorted -func (l LinearExpression) Hash() uint64 { - if len(l) == 0 { - return 0 - } - - hashcode := uint64(1) - for i := 0; i < len(l); i++ { - hashcode = hashcode*31 + uint64(l[i]) - } - return hashcode -} - // Equals returns true if both SORTED expressions are the same // // pre conditions: l and o are sorted diff --git a/std/algebra/fields/e12_test.go b/std/algebra/fields/e12_test.go index 1d5315c2e4..b245adf7e9 100644 --- a/std/algebra/fields/e12_test.go +++ b/std/algebra/fields/e12_test.go @@ -131,7 +131,7 @@ type fp12Square struct { func (circuit *fp12Square) Define(curveID ecc.ID, api frontend.API) error { ext := GetBLS377ExtensionFp12(api) s := circuit.A.Square(api, circuit.A, ext) - s.MustBeEqual(api, *s) + s.MustBeEqual(api, circuit.B) return nil } diff --git a/std/algebra/sw/g1_test.go b/std/algebra/sw/g1_test.go index 3ccc3bd570..c32da40fe4 100644 --- a/std/algebra/sw/g1_test.go +++ b/std/algebra/sw/g1_test.go @@ -231,19 +231,14 @@ func TestScalarMulG1(t *testing.T) { var a, c bls12377.G1Affine a.FromJacobian(&_a) - // random scalar - var r fr.Element - r.SetRandom() - // create the cs var circuit, witness g1ScalarMul - circuit.r = r - + circuit.r.SetRandom() // assign the inputs witness.A.Assign(&a) // compute the result var br big.Int - _a.ScalarMultiplication(&_a, r.ToBigIntRegular(&br)) + _a.ScalarMultiplication(&_a, circuit.r.ToBigIntRegular(&br)) c.FromJacobian(&_a) witness.C.Assign(&c) @@ -269,13 +264,13 @@ func BenchmarkScalarMulG1(b *testing.B) { var c g1ScalarMul // this is q - 1 c.r.SetString("660539884262666720468348340822774968888139573360124440321458176") - // b.Run("groth16", func(b *testing.B) { - // for i := 0; i < b.N; i++ { - // ccsBench, _ = frontend.Compile(ecc.BN254, backend.GROTH16, &c) - // } + b.Run("groth16", func(b *testing.B) { + for i := 0; i < b.N; i++ { + ccsBench, _ = frontend.Compile(ecc.BN254, backend.GROTH16, &c) + } - // }) - // b.Log("groth16", ccsBench.GetNbConstraints()) + }) + b.Log("groth16", ccsBench.GetNbConstraints()) b.Run("plonk", func(b *testing.B) { var err error for i := 0; i < b.N; i++ { diff --git a/test/assert.go b/test/assert.go index 9e58369473..1522624505 100644 --- a/test/assert.go +++ b/test/assert.go @@ -75,7 +75,7 @@ func (assert *Assert) ProverSucceeded(circuit frontend.Circuit, validWitness fro checkError := func(err error) { assert.checkError(err, b, curve, validWitness) } // 1- compile the circuit - ccs, err := assert.compile(circuit, curve, b) + ccs, err := assert.compile(circuit, curve, b, opt.compileOpts) checkError(err) // must not error with big int test engine @@ -173,7 +173,7 @@ func (assert *Assert) ProverFailed(circuit frontend.Circuit, invalidWitness fron mustError := func(err error) { assert.mustError(err, b, curve, invalidWitness) } // 1- compile the circuit - ccs, err := assert.compile(circuit, curve, b) + ccs, err := assert.compile(circuit, curve, b, opt.compileOpts) checkError(err) // must error with big int test engine @@ -228,7 +228,7 @@ func (assert *Assert) solvingSucceeded(circuit frontend.Circuit, validWitness fr checkError := func(err error) { assert.checkError(err, b, curve, validWitness) } // 1- compile the circuit - ccs, err := assert.compile(circuit, curve, b) + ccs, err := assert.compile(circuit, curve, b, opt.compileOpts) checkError(err) // must not error with big int test engine @@ -264,7 +264,7 @@ func (assert *Assert) solvingFailed(circuit frontend.Circuit, invalidWitness fro mustError := func(err error) { assert.mustError(err, b, curve, invalidWitness) } // 1- compile the circuit - ccs, err := assert.compile(circuit, curve, b) + ccs, err := assert.compile(circuit, curve, b, opt.compileOpts) if err != nil { fmt.Println(reflect.TypeOf(circuit).String()) } @@ -307,7 +307,7 @@ func (assert *Assert) Fuzz(circuit frontend.Circuit, fuzzCount int, opts ...func // this puts the compiled circuit in the cache // we do this here in case our fuzzWitness method mutates some references in the circuit // (like []frontend.Variable) before cleaning up - _, err := assert.compile(circuit, curve, b) + _, err := assert.compile(circuit, curve, b, opt.compileOpts) assert.NoError(err) valid := 0 // "fuzz" with zeros @@ -348,20 +348,21 @@ func (assert *Assert) fuzzer(fuzzer filler, circuit, w frontend.Circuit, b backe } // compile the given circuit for given curve and backend, if not already present in cache -func (assert *Assert) compile(circuit frontend.Circuit, curveID ecc.ID, backendID backend.ID) (frontend.CompiledConstraintSystem, error) { +func (assert *Assert) compile(circuit frontend.Circuit, curveID ecc.ID, backendID backend.ID, compileOpts []func(opt *frontend.CompileOption) error) (frontend.CompiledConstraintSystem, error) { key := curveID.String() + backendID.String() + reflect.TypeOf(circuit).String() // check if we already compiled it if ccs, ok := assert.compiled[key]; ok { + // TODO we may want to check that it was compiled with the same compile options here return ccs, nil } // else compile it and ensure it is deterministic - ccs, err := frontend.Compile(curveID, backendID, circuit) + ccs, err := frontend.Compile(curveID, backendID, circuit, compileOpts...) if err != nil { return nil, err } - _ccs, err := frontend.Compile(curveID, backendID, circuit) + _ccs, err := frontend.Compile(curveID, backendID, circuit, compileOpts...) if err != nil { return nil, err } diff --git a/test/options.go b/test/options.go index 6a336c1d35..36aa05a3ca 100644 --- a/test/options.go +++ b/test/options.go @@ -19,6 +19,7 @@ package test import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/backend" + "github.com/consensys/gnark/frontend" ) // TestingOption enables calls to assert.ProverSucceeded and assert.ProverFailed to run with various features @@ -31,6 +32,7 @@ type TestingOption struct { curves []ecc.ID witnessSerialization bool proverOpts []func(opt *backend.ProverOption) error + compileOpts []func(opt *frontend.CompileOption) error } // WithBackends enables calls to assert.ProverSucceeded and assert.ProverFailed to run on specific backends only @@ -71,3 +73,12 @@ func WithProverOpts(proverOpts ...func(opt *backend.ProverOption) error) func(op return nil } } + +// WithCompileOpts enables calls to assert.ProverSucceeded and assert.ProverFailed to forward frontend.Compile option +// to frontend.Compile calls +func WithCompileOpts(compileOpts ...func(opt *frontend.CompileOption) error) func(opt *TestingOption) error { + return func(opt *TestingOption) error { + opt.compileOpts = compileOpts + return nil + } +}