From d9f31c592cda0a5e49f229a37f1e503f61d5a3cb Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Thu, 16 Sep 2021 14:08:39 -0500 Subject: [PATCH 01/15] feat: cs.Println and debugInfo supported in groth16 and plonk --- frontend/cs.go | 150 ++++-------------- frontend/cs_api.go | 71 ++------- frontend/cs_api_test.go | 38 ++--- frontend/cs_assertions.go | 129 ++++----------- frontend/cs_debug.go | 42 +++++ frontend/cs_to_r1cs.go | 66 +++----- frontend/cs_to_r1cs_sparse.go | 38 +++-- frontend/frontend.go | 3 +- internal/backend/bls12-377/cs/cs.go | 53 ++++++- internal/backend/bls12-377/cs/r1cs.go | 47 +++--- internal/backend/bls12-377/cs/r1cs_sparse.go | 14 +- internal/backend/bls12-381/cs/cs.go | 53 ++++++- internal/backend/bls12-381/cs/r1cs.go | 47 +++--- internal/backend/bls12-381/cs/r1cs_sparse.go | 14 +- internal/backend/bls24-315/cs/cs.go | 53 ++++++- internal/backend/bls24-315/cs/r1cs.go | 47 +++--- internal/backend/bls24-315/cs/r1cs_sparse.go | 14 +- internal/backend/bn254/cs/cs.go | 53 ++++++- internal/backend/bn254/cs/r1cs.go | 47 +++--- internal/backend/bn254/cs/r1cs_sparse.go | 14 +- internal/backend/bw6-761/cs/cs.go | 53 ++++++- internal/backend/bw6-761/cs/r1cs.go | 47 +++--- internal/backend/bw6-761/cs/r1cs_sparse.go | 14 +- internal/backend/compiled/log.go | 87 ++++++++++ internal/backend/compiled/r1c.go | 12 +- internal/backend/compiled/r1cs.go | 13 +- internal/backend/compiled/r1cs_sparse.go | 6 +- internal/backend/compiled/term.go | 3 + .../template/representations/cs.go.tmpl | 53 ++++++- .../template/representations/r1cs.go.tmpl | 47 +++--- .../representations/r1cs.sparse.go.tmpl | 14 +- 31 files changed, 771 insertions(+), 571 deletions(-) create mode 100644 frontend/cs_debug.go create mode 100644 internal/backend/compiled/log.go diff --git a/frontend/cs.go b/frontend/cs.go index b7f0ad6a89..f789a4eb55 100644 --- a/frontend/cs.go +++ b/frontend/cs.go @@ -44,22 +44,24 @@ type ConstraintSystem struct { // they may only contain a linear expression public, secret, internal, virtual variables - // Constraints - constraints []compiled.R1C // list of R1C that yield an output (for example v3 == v1 * v2, return v3) - assertions []compiled.R1C // list of R1C that yield no output (for example ensuring v1 == v2) + // list of constraints in the form a * b == c + // a,b and c being linear expressions + constraints []compiled.R1C // Coefficients in the constraints coeffs []big.Int // list of unique coefficients. coeffsIDs map[string]int // map to fast check existence of a coefficient (key = coeff.Text(16)) // Hints + // TODO @gbotrel let's make it a map directly here. hints []compiled.Hint // solver hints + // TODO @gbotrel we may want to make that optional through build tags // debug info - logs []logEntry // list of logs to be printed when solving a circuit. The logs are called with the method Println - debugInfoComputation []logEntry // list of logs storing information about computations (e.g. division by 0).If an computation fails, it prints it in a friendly format - debugInfoAssertion []logEntry // list of logs storing information about assertions. If an assertion fails, it prints it in a friendly format + logs []compiled.LogEntry // list of logs to be printed when solving a circuit. The logs are called with the method Println + debugInfo []compiled.LogEntry // list of logs storing information about R1C + mDebug map[int]int // maps constraint ID to debugInfo id } type variables struct { @@ -107,7 +109,7 @@ func newConstraintSystem(initialCapacity ...int) ConstraintSystem { coeffs: make([]big.Int, 4), coeffsIDs: make(map[string]int), constraints: make([]compiled.R1C, 0, capacity), - assertions: make([]compiled.R1C, 0), + mDebug: make(map[int]int), } cs.coeffs[compiled.CoeffIdZero].SetInt64(0) @@ -179,68 +181,37 @@ func (cs *ConstraintSystem) Println(a ...interface{}) { sbb.WriteByte(' ') } - // for each argument, if it is a circuit structure and contains variable - // we add the variables in the logEntry.toResolve part, and add %s to the format string in the log entry - // if it doesn't contain variable, call fmt.Sprint(arg) instead - entry := logEntry{} - - // this is call recursively on the arguments using reflection on each argument - foundVariable := false - - var handler logValueHandler = func(name string, tInput reflect.Value) { - - v := tInput.Interface().(Variable) - - // if the variable is only in linExp form, we allocate it - _v := cs.allocate(v) - - entry.toResolve = append(entry.toResolve, compiled.Pack(_v.id, 0, _v.visibility)) - - if name == "" { - sbb.WriteString("%s") - } else { - sbb.WriteString(fmt.Sprintf("%s: %%s ", name)) - } - - foundVariable = true - } + var log compiled.LogEntry for i, arg := range a { if i > 0 { sbb.WriteByte(' ') } - foundVariable = false - parseLogValue(arg, "", handler) - if !foundVariable { + if v, ok := arg.(Variable); ok { + v.assertIsSet() + + sbb.WriteString("%s") + // we set limits to the linear expression, so that the log printer + // can evaluate it before printing it + log.ToResolve = append(log.ToResolve, compiled.TermDelimitor) + log.ToResolve = append(log.ToResolve, v.linExp...) + log.ToResolve = append(log.ToResolve, compiled.TermDelimitor) + } else { sbb.WriteString(fmt.Sprint(arg)) } } sbb.WriteByte('\n') // set format string to be used with fmt.Sprintf, once the variables are solved in the R1CS.Solve() method - entry.format = sbb.String() + log.Format = sbb.String() - cs.logs = append(cs.logs, entry) -} - -type logEntry struct { - format string - toResolve []compiled.Term + cs.logs = append(cs.logs, log) } var ( bOne = new(big.Int).SetInt64(1) ) -// debug info in case a variable is not set -// func debugInfoUnsetVariable(term compiled.Term) logEntry { -// entry := logEntry{} -// stack := getCallStack() -// entry.format = stack[len(stack)-1] -// entry.toResolve = append(entry.toResolve, term) -// return entry -// } - func (cs *ConstraintSystem) one() Variable { return cs.public.variables[0] } @@ -268,11 +239,8 @@ func newR1C(l, r, o Variable) compiled.R1C { // NbConstraints enables circuit profiling and helps debugging // It returns the number of constraints created at the current stage of the circuit construction. -// -// The number returns included both the assertions and the non-assertion constraints -// (eg: the constraints which creates a new variable) func (cs *ConstraintSystem) NbConstraints() int { - return len(cs.constraints) + len(cs.assertions) + return len(cs.constraints) } // LinearExpression packs a list of compiled.Term in a compiled.LinearExpression and returns it. @@ -310,11 +278,6 @@ func (cs *ConstraintSystem) reduce(l compiled.LinearExpression) compiled.LinearE return l } -func (cs *ConstraintSystem) addAssertion(constraint compiled.R1C, debugInfo logEntry) { - cs.assertions = append(cs.assertions, constraint) - cs.debugInfoAssertion = append(cs.debugInfoAssertion, debugInfo) -} - // coeffID tries to fetch the entry where b is if it exits, otherwise appends b to // the list of coeffs and returns the corresponding entry func (cs *ConstraintSystem) coeffID(b *big.Int) int { @@ -349,17 +312,11 @@ func (cs *ConstraintSystem) coeffID(b *big.Int) int { return resID } -// if v is unset and linExp is non empty, the variable is allocated -// resulting in one more constraint in the system. If v is set OR v is -// unset and linexp is emppty, it does nothing. -func (cs *ConstraintSystem) allocate(v Variable) Variable { - if v.visibility == compiled.Unset && len(v.linExp) > 0 { - iv := cs.newInternalVariable() - one := cs.one() - cs.constraints = append(cs.constraints, newR1C(v, one, iv)) - return iv +func (cs *ConstraintSystem) addConstraint(r1c compiled.R1C, debugID ...int) { + cs.constraints = append(cs.constraints, r1c) + if len(debugID) > 0 { + cs.mDebug[len(cs.constraints)-1] = debugID[0] } - return v } // newInternalVariable creates a new wire, appends it on the list of wires of the circuit, sets @@ -430,59 +387,8 @@ func parseLogValue(input interface{}, name string, handler logValueHandler) { } } -// derived from: https://golang.org/pkg/runtime/#example_Frames -// we stop when func name == Define as it is where the gnark circuit code should start -func getCallStack() []string { - // Ask runtime.Callers for up to 10 pcs - pc := make([]uintptr, 10) - n := runtime.Callers(3, pc) - if n == 0 { - // No pcs available. Stop now. - // This can happen if the first argument to runtime.Callers is large. - return nil - } - pc = pc[:n] // pass only valid pcs to runtime.CallersFrames - frames := runtime.CallersFrames(pc) - // Loop to get frames. - // A fixed number of pcs can expand to an indefinite number of Frames. - var toReturn []string - for { - frame, more := frames.Next() - fe := strings.Split(frame.Function, "/") - function := fe[len(fe)-1] - toReturn = append(toReturn, fmt.Sprintf("%s\n\t%s:%d", function, frame.File, frame.Line)) - if !more { - break - } - if strings.HasSuffix(function, "Define") { - break - } - } - return toReturn -} - func (cs *ConstraintSystem) buildVarFromWire(pv Wire) Variable { - return Variable{pv, cs.LinearExpression(cs.makeTerm(pv, bOne))} -} - -// creates a string formatted to display correctly a variable, from its linear expression representation -// (i.e. the linear expression leading to it) -func (cs *ConstraintSystem) buildLogEntryFromVariable(v Variable) logEntry { - - var res logEntry - var sbb strings.Builder - sbb.Grow(len(v.linExp) * len(" + (xx + xxxxxxxxxxxx")) - - for i := 0; i < len(v.linExp); i++ { - if i > 0 { - sbb.WriteString(" + ") - } - c := cs.coeffs[v.linExp[i].CoeffID()] - sbb.WriteString(fmt.Sprintf("(%%s * %s)", c.String())) - } - res.format = sbb.String() - res.toResolve = v.linExp.Clone() - return res + return Variable{pv, cs.LinearExpression(compiled.Pack(pv.id, compiled.CoeffIdOne, pv.visibility))} } // markBoolean marks the variable as boolean and return true diff --git a/frontend/cs_api.go b/frontend/cs_api.go index 6deab0f4c9..f5199e2547 100644 --- a/frontend/cs_api.go +++ b/frontend/cs_api.go @@ -18,7 +18,6 @@ package frontend import ( "math/big" - "strings" "github.com/consensys/gnark/backend/hint" "github.com/consensys/gnark/internal/backend/compiled" @@ -185,79 +184,28 @@ func (cs *ConstraintSystem) Mul(i1, i2 interface{}, in ...interface{}) Variable // Inverse returns res = inverse(v) func (cs *ConstraintSystem) Inverse(v Variable) Variable { - v.assertIsSet() + debug := cs.addDebugInfo("inverse", v) // allocate resulting variable res := cs.newInternalVariable() - cs.constraints = append(cs.constraints, newR1C(v, res, cs.one())) - - // prepare debug info to be displayed in case the constraint is not solved - debugInfo := logEntry{ - toResolve: nil, - } - var sbb strings.Builder - sbb.WriteString("couldn't solve computational constraint (inversion by zero ?)") - stack := getCallStack() - for i := 0; i < len(stack); i++ { - sbb.WriteByte('\n') - sbb.WriteString(stack[i]) - } - debugInfo.format = sbb.String() - - // add it to the logs record - cs.debugInfoComputation = append(cs.debugInfoComputation, debugInfo) + cs.addConstraint(newR1C(v, res, cs.one()), debug) return res } // Div returns res = i1 / i2 func (cs *ConstraintSystem) Div(i1, i2 interface{}) Variable { - // allocate resulting variable res := cs.newInternalVariable() - // O - switch t1 := i1.(type) { - case Variable: - t1.assertIsSet() - switch t2 := i2.(type) { - case Variable: - t2.assertIsSet() - cs.constraints = append(cs.constraints, newR1C(t2, res, t1)) - default: - tmp := cs.Constant(t2) - cs.constraints = append(cs.constraints, newR1C(res, tmp, t1)) - } - default: - switch t2 := i2.(type) { - case Variable: - t2.assertIsSet() - tmp := cs.Constant(t1) - cs.constraints = append(cs.constraints, newR1C(t2, res, tmp)) - default: - tmp1 := cs.Constant(t1) - tmp2 := cs.Constant(t2) - cs.constraints = append(cs.constraints, newR1C(res, tmp2, tmp1)) - } - } + v1 := cs.Constant(i1) + v2 := cs.Constant(i2) - // prepare debug info to be displayed in case the constraint is not solved - debugInfo := logEntry{ - toResolve: nil, - } - var sbb strings.Builder - sbb.WriteString("couldn't solve computational constraint (inversion by zero ?)") - stack := getCallStack() - for i := 0; i < len(stack); i++ { - sbb.WriteByte('\n') - sbb.WriteString(stack[i]) - } - debugInfo.format = sbb.String() + debug := cs.addDebugInfo("div", v1, " / ", v2, " != ", res) - // add it to the logs record - cs.debugInfoComputation = append(cs.debugInfoComputation, debugInfo) + cs.addConstraint(newR1C(v2, res, v1), debug) return res } @@ -316,6 +264,7 @@ func (cs *ConstraintSystem) And(a, b Variable) Variable { // IsZero returns 1 if a is zero, 0 otherwise func (cs *ConstraintSystem) IsZero(a Variable) Variable { a.assertIsSet() + debug := cs.addDebugInfo("isZero", a) //m * (1 - m) = 0 // constrain m to be 0 or 1 // a * m = 0 // constrain m to be 0 if a != 0 @@ -323,7 +272,7 @@ func (cs *ConstraintSystem) IsZero(a Variable) Variable { // m is computed by the solver such that m = 1 - a^(modulus - 1) m := cs.NewHint(hint.IsZero, a) - cs.constraints = append(cs.constraints, newR1C(a, m, cs.Constant(0))) + cs.addConstraint(newR1C(a, m, cs.Constant(0)), debug) cs.AssertIsBoolean(m) ma := cs.Add(m, a) @@ -349,7 +298,7 @@ func (cs *ConstraintSystem) ToBinary(a Variable, nbBits int) []Variable { // here what we do is we add a single constraint where // Σ (2**i * b[i]) == a var c big.Int - c.Set(bOne) + c.SetUint64(1) var Σbi Variable Σbi.linExp = make(compiled.LinearExpression, nbBits) @@ -378,7 +327,7 @@ func (cs *ConstraintSystem) FromBinary(b ...Variable) Variable { res = cs.Constant(0) // no constraint is recorded var c big.Int - c.Set(bOne) + c.SetUint64(1) L := make(compiled.LinearExpression, len(b)) for i := 0; i < len(L); i++ { diff --git a/frontend/cs_api_test.go b/frontend/cs_api_test.go index 69b5497b6d..ddd23060f7 100644 --- a/frontend/cs_api_test.go +++ b/frontend/cs_api_test.go @@ -19,7 +19,6 @@ type csState struct { nbSecretVariables int nbInternalVariables int nbConstraints int - nbAssertions int } // deltaState holds the difference between the next state (after calling a function from the API) and the previous one @@ -38,7 +37,7 @@ type nextstatefunc func(state commands.State) commands.State // the names of the public/secret inputs are variableName.String() func incVariableName() { - variableName.Add(&variableName, bOne) + variableName.Add(&variableName, new(big.Int).SetUint64(1)) } // ------------------------------------------------------------------------------ @@ -103,8 +102,7 @@ func postConditionAPI(state commands.State, result commands.Result) *gopter.Prop if len(csRes.cs.public.variables) != st.nbPublicVariables || len(csRes.cs.secret.variables) != st.nbSecretVariables || len(csRes.cs.internal.variables) != st.nbInternalVariables || - len(csRes.cs.constraints) != st.nbConstraints || - len(csRes.cs.assertions) != st.nbAssertions { + len(csRes.cs.constraints) != st.nbConstraints { return &gopter.PropResult{Status: gopter.PropFalse} } return &gopter.PropResult{Status: gopter.PropTrue} @@ -150,7 +148,7 @@ func rfAddSub() runfunc { return res } -var nsAddSub = deltaState{1, 2, 0, 0, 0} // ex: after calling add, we should have 1 public variable, 3 secret variables, 0 internal variable, 0 constraint more in the cs +var nsAddSub = deltaState{1, 2, 0, 0} // ex: after calling add, we should have 1 public variable, 3 secret variables, 0 internal variable, 0 constraint more in the cs // mul variables func rfMul() runfunc { @@ -182,7 +180,7 @@ func rfMul() runfunc { return res } -var nsMul = csState{1, 1, 1, 1, 0} +var nsMul = csState{1, 1, 1, 1} // inverse a variable func rfInverse() runfunc { @@ -210,7 +208,7 @@ func rfInverse() runfunc { return res } -var nsInverse = deltaState{1, 0, 1, 1, 0} +var nsInverse = deltaState{1, 0, 1, 1} // div 2 variables func rfDiv() runfunc { @@ -251,7 +249,7 @@ func rfDiv() runfunc { return res } -var nsDiv = deltaState{1, 1, 4, 4, 0} +var nsDiv = deltaState{1, 1, 4, 4} // xor between two variables func rfXor() runfunc { @@ -283,7 +281,7 @@ func rfXor() runfunc { return res } -var nsXor = deltaState{1, 1, 1, 1, 2} +var nsXor = deltaState{1, 1, 1, 3} // binary decomposition of a variable func rfToBinary() runfunc { @@ -309,7 +307,7 @@ func rfToBinary() runfunc { return res } -var nsToBinary = deltaState{1, 0, 256, 1, 256} +var nsToBinary = deltaState{1, 0, 256, 257} // select constraint betwwen variableq func rfSelect() runfunc { @@ -353,7 +351,7 @@ func rfSelect() runfunc { return res } -var nsSelect = deltaState{1, 2, 3, 3, 1} +var nsSelect = deltaState{1, 2, 3, 4} // copy of variable func rfConstant() runfunc { @@ -381,7 +379,7 @@ func rfConstant() runfunc { return res } -var nsConstant = deltaState{1, 0, 0, 0, 0} +var nsConstant = deltaState{1, 0, 0, 0} // equality between 2 variables func rfIsEqual() runfunc { @@ -416,7 +414,7 @@ func rfIsEqual() runfunc { return res } -var nsIsEqual = deltaState{1, 1, 0, 0, 2} +var nsIsEqual = deltaState{1, 1, 0, 2} // packing from binary variables func rfFromBinary() runfunc { @@ -444,7 +442,7 @@ func rfFromBinary() runfunc { return res } -var nsFromBinary = deltaState{256, 0, 0, 0, 256} +var nsFromBinary = deltaState{256, 0, 0, 256} // boolean constrain a variable func rfIsBoolean() runfunc { @@ -481,11 +479,11 @@ func rfIsBoolean() runfunc { return res } -var nsIsBoolean = deltaState{1, 1, 0, 0, 2} +var nsIsBoolean = deltaState{1, 1, 0, 2} -var nsMustBeLessOrEqVar = deltaState{1, 1, 1281, 771, 768} +var nsMustBeLessOrEqVar = deltaState{1, 1, 1281, 1539} -var nsMustBeLessOrEqConst = csState{1, 0, 257, 2, 511} // nb internal variables: 256+HW(bound), nb constraints: 1+HW(bound), nb assertions: 256+HW(^bound) +var nsMustBeLessOrEqConst = csState{1, 0, 257, 513} // nb internal variables: 256+HW(bound), nb constraints: 1+HW(bound), nb assertions: 256+HW(^bound) // ------------------------------------------------------------------------------ // build the next state function using the delta state @@ -496,7 +494,6 @@ func nextStateFunc(ds deltaState) nextstatefunc { state.(*csState).nbSecretVariables += ds.nbSecretVariables state.(*csState).nbInternalVariables += ds.nbInternalVariables state.(*csState).nbConstraints += ds.nbConstraints - state.(*csState).nbAssertions += ds.nbAssertions return state } return res @@ -679,11 +676,6 @@ func (c *isLessOrEq) Define(curveID ecc.ID, cs *ConstraintSystem) error { } func TestUnsetVariables(t *testing.T) { - // TODO unset variables with markBoolean will panic. - // doing - // var a Variable - // cs.AssertIsBoolean(a) - // will panic. mapFuncs := map[string]Circuit{ "add": &addCircuit{}, "sub": &subCircuit{}, diff --git a/frontend/cs_assertions.go b/frontend/cs_assertions.go index 9af35cd853..6bed794c4e 100644 --- a/frontend/cs_assertions.go +++ b/frontend/cs_assertions.go @@ -2,48 +2,25 @@ package frontend import ( "math/big" - "strings" "github.com/consensys/gnark/internal/backend/compiled" ) // AssertIsEqual adds an assertion in the constraint system (i1 == i2) func (cs *ConstraintSystem) AssertIsEqual(i1, i2 interface{}) { + // encoded i1 * 1 == i2 + // TODO do cs.Sub(i1,i2) == 0 ? - // encoded as L * R == O - // set L = i1 - // set R = 1 - // set O = i2 - - // we don't do just "cs.Sub(i1,i2)" to allow proper logging - debugInfo := logEntry{} - - l := cs.Constant(i1) // no constraint is recorded - r := cs.Constant(1) // no constraint is recorded - o := cs.Constant(i2) // no constraint is recorded - - // build log - var sbb strings.Builder - sbb.WriteString("[") - lhs := cs.buildLogEntryFromVariable(l) - sbb.WriteString(lhs.format) - debugInfo.toResolve = lhs.toResolve - sbb.WriteString(" != ") - rhs := cs.buildLogEntryFromVariable(o) - sbb.WriteString(rhs.format) - debugInfo.toResolve = append(debugInfo.toResolve, rhs.toResolve...) - sbb.WriteString("]") - - // get call stack - sbb.WriteString("error AssertIsEqual") - stack := getCallStack() - for i := 0; i < len(stack); i++ { - sbb.WriteByte('\n') - sbb.WriteString(stack[i]) + l := cs.Constant(i1) + o := cs.Constant(i2) + + if len(l.linExp) > len(o.linExp) { + l, o = o, l // maximize number of zeroes in r1cs.A } - debugInfo.format = sbb.String() - cs.addAssertion(newR1C(l, r, o), debugInfo) + debug := cs.addDebugInfo("assertIsEqual", l, " == ", o) + + cs.addConstraint(newR1C(l, cs.one(), o), debug) } // AssertIsDifferent constrain i1 and i2 to be different @@ -53,9 +30,7 @@ func (cs *ConstraintSystem) AssertIsDifferent(i1, i2 interface{}) { // AssertIsBoolean adds an assertion in the constraint system (v == 0 || v == 1) func (cs *ConstraintSystem) AssertIsBoolean(v Variable) { - v.assertIsSet() - if v.visibility == compiled.Unset { // we need to create a new wire here. vv := cs.newVirtualVariable() @@ -66,26 +41,12 @@ func (cs *ConstraintSystem) AssertIsBoolean(v Variable) { if !cs.markBoolean(v) { return // variable is already constrained } + debug := cs.addDebugInfo("assertIsBoolean", v) // ensure v * (1 - v) == 0 - - _v := cs.Sub(1, v) // no variable is recorded in the cs - o := cs.Constant(0) // no variable is recorded in the cs - - // prepare debug info to be displayed in case the constraint is not solved - debugInfo := logEntry{ - toResolve: nil, - } - var sbb strings.Builder - sbb.WriteString("error AssertIsBoolean") - stack := getCallStack() - for i := 0; i < len(stack); i++ { - sbb.WriteByte('\n') - sbb.WriteString(stack[i]) - } - debugInfo.format = sbb.String() - - cs.addAssertion(newR1C(v, _v, o), debugInfo) + _v := cs.Sub(1, v) + o := cs.Constant(0) + cs.addConstraint(newR1C(v, _v, o), debug) } // AssertIsLessOrEqual adds assertion in constraint system (v <= bound) @@ -108,31 +69,14 @@ func (cs *ConstraintSystem) AssertIsLessOrEqual(v Variable, bound interface{}) { } -func (cs *ConstraintSystem) mustBeLessOrEqVar(w, bound Variable) { - - // prepare debug info to be displayed in case the constraint is not solved - dbgInfoW := cs.buildLogEntryFromVariable(w) - dbgInfoBound := cs.buildLogEntryFromVariable(bound) - var sbb strings.Builder - var debugInfo logEntry - sbb.WriteString(dbgInfoW.format) - sbb.WriteString(" <= ") - sbb.WriteString(dbgInfoBound.format) - debugInfo.toResolve = make([]compiled.Term, len(dbgInfoW.toResolve)+len(dbgInfoBound.toResolve)) - copy(debugInfo.toResolve[:], dbgInfoW.toResolve) - copy(debugInfo.toResolve[len(dbgInfoW.toResolve):], dbgInfoBound.toResolve) - - stack := getCallStack() - for i := 0; i < len(stack); i++ { - sbb.WriteByte('\n') - sbb.WriteString(stack[i]) - } - debugInfo.format = sbb.String() +func (cs *ConstraintSystem) mustBeLessOrEqVar(v, bound Variable) { + debug := cs.addDebugInfo("mustBeLessOrEq", v, " <= ", bound) + // TODO nbBits shouldn't be here. const nbBits = 256 - binw := cs.ToBinary(w, nbBits) - binbound := cs.ToBinary(bound, nbBits) + wBits := cs.ToBinary(v, nbBits) + boundBits := cs.ToBinary(bound, nbBits) p := make([]Variable, nbBits+1) p[nbBits] = cs.Constant(1) @@ -141,41 +85,25 @@ func (cs *ConstraintSystem) mustBeLessOrEqVar(w, bound Variable) { for i := nbBits - 1; i >= 0; i-- { - p1 := cs.Mul(p[i+1], binw[i]) - p[i] = cs.Select(binbound[i], p1, p[i+1]) - t := cs.Select(binbound[i], zero, p[i+1]) + p1 := cs.Mul(p[i+1], wBits[i]) + p[i] = cs.Select(boundBits[i], p1, p[i+1]) + t := cs.Select(boundBits[i], zero, p[i+1]) l := cs.one() - l = cs.Sub(l, t) // no constraint is recorded - l = cs.Sub(l, binw[i]) // no constraint is recorded + l = cs.Sub(l, t) // no constraint is recorded + l = cs.Sub(l, wBits[i]) // no constraint is recorded - r := binw[i] + r := wBits[i] o := cs.Constant(0) // no constraint is recorded - cs.addAssertion(newR1C(l, r, o), debugInfo) + cs.addConstraint(newR1C(l, r, o), debug) } } func (cs *ConstraintSystem) mustBeLessOrEqCst(v Variable, bound big.Int) { - - // prepare debug info to be displayed in case the constraint is not solved - dbgInfoW := cs.buildLogEntryFromVariable(v) - var sbb strings.Builder - var debugInfo logEntry - sbb.WriteString(dbgInfoW.format) - sbb.WriteString(" <= ") - sbb.WriteString(bound.String()) - - debugInfo.toResolve = dbgInfoW.toResolve - - stack := getCallStack() - for i := 0; i < len(stack); i++ { - sbb.WriteByte('\n') - sbb.WriteString(stack[i]) - } - debugInfo.format = sbb.String() + debug := cs.addDebugInfo("mustBeLessOrEq", v, " <= ", cs.Constant(bound)) // TODO store those constant elsewhere (for the moment they don't depend on the base curve, but that might change) const nbBits = 256 @@ -206,7 +134,8 @@ func (cs *ConstraintSystem) mustBeLessOrEqCst(v Variable, bound big.Int) { r := vBits[(i+1)*wordSize-1-j] o := cs.Constant(0) - cs.addAssertion(newR1C(l, r, o), debugInfo) + + cs.addConstraint(newR1C(l, r, o), debug) } else { p[(i+1)*wordSize-1-j] = cs.Mul(p[(i+1)*wordSize-j], vBits[(i+1)*wordSize-1-j]) diff --git a/frontend/cs_debug.go b/frontend/cs_debug.go new file mode 100644 index 0000000000..d737ec9027 --- /dev/null +++ b/frontend/cs_debug.go @@ -0,0 +1,42 @@ +package frontend + +import ( + "strconv" + "strings" + + "github.com/consensys/gnark/internal/backend/compiled" +) + +// TODO @gbotrel maybe rename to newLog if common with cs.Println +func (cs *ConstraintSystem) addDebugInfo(errName string, i ...interface{}) int { + var debug compiled.LogEntry + + // TODO @gbotrel reserve capacity for the string builder + const minLogSize = 500 + var sbb strings.Builder + sbb.Grow(minLogSize) + sbb.WriteString("[") + sbb.WriteString(errName) + sbb.WriteString("] ") + + for _, _i := range i { + switch v := _i.(type) { + case Variable: + debug.WriteLinearExpression(v.linExp, &sbb) + case string: + sbb.WriteString(v) + case int: + sbb.WriteString(strconv.Itoa(v)) + case compiled.Term: + debug.WriteTerm(v, &sbb) + default: + panic("unsupported log type") + } + } + sbb.WriteByte('\n') + debug.WriteStack(&sbb) + debug.Format = sbb.String() + + cs.debugInfo = append(cs.debugInfo, debug) + return len(cs.debugInfo) - 1 +} diff --git a/frontend/cs_to_r1cs.go b/frontend/cs_to_r1cs.go index 37978458f2..2c0bb2bf64 100644 --- a/frontend/cs_to_r1cs.go +++ b/frontend/cs_to_r1cs.go @@ -18,19 +18,26 @@ func (cs *ConstraintSystem) toR1CS(curveID ecc.ID) (CompiledConstraintSystem, er // setting up the result res := compiled.R1CS{ - NbInternalVariables: len(cs.internal.variables), - NbPublicVariables: len(cs.public.variables), - NbSecretVariables: len(cs.secret.variables), - NbConstraints: len(cs.constraints) + len(cs.assertions), - Constraints: make([]compiled.R1C, len(cs.constraints)+len(cs.assertions)), - Logs: make([]compiled.LogEntry, len(cs.logs)), - DebugInfoComputation: make([]compiled.LogEntry, len(cs.debugInfoComputation)+len(cs.debugInfoAssertion)), - Hints: make([]compiled.Hint, len(cs.hints)), + NbInternalVariables: len(cs.internal.variables), + NbPublicVariables: len(cs.public.variables), + NbSecretVariables: len(cs.secret.variables), + NbConstraints: len(cs.constraints), + Constraints: make([]compiled.R1C, len(cs.constraints)), + Logs: make([]compiled.LogEntry, len(cs.logs)), + DebugInfo: make([]compiled.LogEntry, len(cs.debugInfo)), + Hints: make([]compiled.Hint, len(cs.hints)), + MDebug: make(map[int]int), } // computational constraints (= gates) copy(res.Constraints, cs.constraints) - copy(res.Constraints[len(cs.constraints):], cs.assertions) + + copy(res.Logs, cs.logs) + copy(res.DebugInfo, cs.debugInfo) + + for k, v := range cs.mDebug { + res.MDebug[k] = v + } // note: verbose, but we offset the IDs of the wires where they appear, that is, // in the logs, debug info, constraints and hints @@ -74,41 +81,18 @@ func (cs *ConstraintSystem) toR1CS(curveID ecc.ID) (CompiledConstraintSystem, er } - // we need to offset the ids in logs - for i := 0; i < len(cs.logs); i++ { - entry := compiled.LogEntry{ - Format: cs.logs[i].format, - } - for j := 0; j < len(cs.logs[i].toResolve); j++ { - _, vID, visibility := cs.logs[i].toResolve[j].Unpack() - entry.ToResolve = append(entry.ToResolve, shiftVID(vID, visibility)) - } - - res.Logs[i] = entry - } - - // offset ids in the debugInfoComputation - for i := 0; i < len(cs.debugInfoComputation); i++ { - entry := compiled.LogEntry{ - Format: cs.debugInfoComputation[i].format, - } - for j := 0; j < len(cs.debugInfoComputation[i].toResolve); j++ { - _, vID, visibility := cs.debugInfoComputation[i].toResolve[j].Unpack() - entry.ToResolve = append(entry.ToResolve, shiftVID(vID, visibility)) + // we need to offset the ids in logs & debugInfo + for i := 0; i < len(res.Logs); i++ { + for j := 0; j < len(res.Logs[i].ToResolve); j++ { + _, vID, visibility := res.Logs[i].ToResolve[j].Unpack() + res.Logs[i].ToResolve[j].SetVariableID(shiftVID(vID, visibility)) } - - res.DebugInfoComputation[i] = entry } - for i := 0; i < len(cs.debugInfoAssertion); i++ { - entry := compiled.LogEntry{ - Format: cs.debugInfoAssertion[i].format, - } - for j := 0; j < len(cs.debugInfoAssertion[i].toResolve); j++ { - _, vID, visibility := cs.debugInfoAssertion[i].toResolve[j].Unpack() - entry.ToResolve = append(entry.ToResolve, shiftVID(vID, visibility)) + for i := 0; i < len(res.DebugInfo); i++ { + for j := 0; j < len(res.DebugInfo[i].ToResolve); j++ { + _, vID, visibility := res.DebugInfo[i].ToResolve[j].Unpack() + res.DebugInfo[i].ToResolve[j].SetVariableID(shiftVID(vID, visibility)) } - - res.DebugInfoComputation[i+len(cs.debugInfoComputation)] = entry } switch curveID { diff --git a/frontend/cs_to_r1cs_sparse.go b/frontend/cs_to_r1cs_sparse.go index 1658b2a709..d039bf426d 100644 --- a/frontend/cs_to_r1cs_sparse.go +++ b/frontend/cs_to_r1cs_sparse.go @@ -48,6 +48,8 @@ type sparseR1CS struct { // and guarantee that the solver will encounter at most one unsolved wire // per SparseR1C solvedVariables []bool + + currentR1CDebugID int // mark the current R1C debugID } func (cs *ConstraintSystem) toSparseR1CS(curveID ecc.ID) (CompiledConstraintSystem, error) { @@ -58,12 +60,15 @@ func (cs *ConstraintSystem) toSparseR1CS(curveID ecc.ID) (CompiledConstraintSyst NbPublicVariables: len(cs.public.variables) - 1, // the ONE_WIRE is discarded as it is not used in PLONK NbSecretVariables: len(cs.secret.variables), NbInternalVariables: len(cs.internal.variables), - Constraints: make([]compiled.SparseR1C, 0, len(cs.constraints)+len(cs.assertions)), + Constraints: make([]compiled.SparseR1C, 0, len(cs.constraints)), Logs: make([]compiled.LogEntry, len(cs.logs)), + DebugInfo: make([]compiled.LogEntry, len(cs.debugInfo)), Hints: make([]compiled.Hint, len(cs.hints)), + MDebug: make(map[int]int), }, solvedVariables: make([]bool, len(cs.internal.variables), len(cs.internal.variables)*2), scsInternalVariables: len(cs.internal.variables), + currentR1CDebugID: -1, } // note: verbose, but we offset the IDs of the wires where they appear, that is, @@ -72,6 +77,9 @@ func (cs *ConstraintSystem) toSparseR1CS(curveID ecc.ID) (CompiledConstraintSyst // the same wireID multiple times. copy(res.ccs.Hints, cs.hints) + copy(res.ccs.Logs, cs.logs) + copy(res.ccs.DebugInfo, cs.debugInfo) + // TODO @gbotrel we may not want to do that as it may hide some bugs // if there is a R1C with several unsolved wires, wether they are hint wires or not // will be problematic at solving time @@ -83,11 +91,13 @@ func (cs *ConstraintSystem) toSparseR1CS(curveID ecc.ID) (CompiledConstraintSyst // in particular, all linear expressions that appear in the R1C // will be split in multiple constraints in the SparseR1C for i := 0; i < len(cs.constraints); i++ { + if dID, ok := cs.mDebug[i]; ok { + res.currentR1CDebugID = dID + } else { + res.currentR1CDebugID = -1 + } res.r1cToSparseR1C(cs.constraints[i]) } - for i := 0; i < len(cs.assertions); i++ { - res.r1cToSparseR1C(cs.assertions[i]) - } // shift variable ID // we want publicWires | privateWires | internalWires @@ -136,16 +146,16 @@ func (cs *ConstraintSystem) toSparseR1CS(curveID ecc.ID) (CompiledConstraintSyst } // offset IDs in the logs - for i := 0; i < len(cs.logs); i++ { - entry := compiled.LogEntry{ - Format: cs.logs[i].format, - ToResolve: make([]int, len(cs.logs[i].toResolve)), + // we need to offset the ids in logs & debugInfo + for i := 0; i < len(res.ccs.Logs); i++ { + for j := 0; j < len(res.ccs.Logs[i].ToResolve); j++ { + offsetTermID(&res.ccs.Logs[i].ToResolve[j]) } - for j := 0; j < len(cs.logs[i].toResolve); j++ { - _, cID, cVisibility := cs.logs[i].toResolve[j].Unpack() - entry.ToResolve[j] = shiftVID(cID, cVisibility) + } + for i := 0; i < len(res.ccs.DebugInfo); i++ { + for j := 0; j < len(res.ccs.DebugInfo[i].ToResolve); j++ { + offsetTermID(&res.ccs.DebugInfo[i].ToResolve[j]) } - res.ccs.Logs[i] = entry } // we need to offset the ids in the hints @@ -295,6 +305,9 @@ func (scs *sparseR1CS) addConstraint(c compiled.SparseR1C) { if c.M[1] == 0 { c.M[1].SetVariableID(c.R.VariableID()) } + if scs.currentR1CDebugID != -1 { + scs.ccs.MDebug[len(scs.ccs.Constraints)] = scs.currentR1CDebugID + } scs.ccs.Constraints = append(scs.ccs.Constraints, c) } @@ -391,6 +404,7 @@ func (scs *sparseR1CS) split(a compiled.Term, l compiled.LinearExpression) compi // r1cToSparseR1C splits a r1c constraint func (scs *sparseR1CS) r1cToSparseR1C(r1c compiled.R1C) { + // find if the variable to solve is in the left, right, or o linear expression lro, idCS := findUnsolvedVariable(r1c, scs.solvedVariables) if lro == -1 { diff --git a/frontend/frontend.go b/frontend/frontend.go index 5ab626e6e2..fa63e0d9fc 100644 --- a/frontend/frontend.go +++ b/frontend/frontend.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "reflect" + "runtime/debug" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/backend" @@ -112,7 +113,7 @@ func buildCS(curveID ecc.ID, circuit Circuit, initialCapacity ...int) (cs Constr // recover from panics to print user-friendlier messages defer func() { if r := recover(); r != nil { - err = fmt.Errorf("%v", r) + err = fmt.Errorf("%v\n%s", r, string(debug.Stack())) } }() // instantiate our constraint system diff --git a/internal/backend/bls12-377/cs/cs.go b/internal/backend/bls12-377/cs/cs.go index 6a018df255..783b712b2b 100644 --- a/internal/backend/bls12-377/cs/cs.go +++ b/internal/backend/bls12-377/cs/cs.go @@ -151,8 +151,59 @@ func (s *solution) printLogs(w io.Writer, logs []compiled.LogEntry) { func (s *solution) logValue(log compiled.LogEntry) string { var toResolve []interface{} + var ( + isEval bool + eval fr.Element + missingValue bool + ) for j := 0; j < len(log.ToResolve); j++ { - vID := log.ToResolve[j] + if log.ToResolve[j] == compiled.TermDelimitor { + // this is a special case where we want to evaluate the following terms until the next delimitor. + if !isEval { + isEval = true + missingValue = false + eval.SetZero() + continue + } + isEval = false + if missingValue { + toResolve = append(toResolve, "???") + } else { + // we have to append our accumulator + toResolve = append(toResolve, eval.String()) + } + continue + } + cID, vID, visibility := log.ToResolve[j].Unpack() + + if isEval { + // we are evaluating + if visibility == compiled.Virtual { + // just add the constant + eval.Add(&eval, &s.coefficients[cID]) + continue + } + if !s.solved[vID] { + missingValue = true + continue + } + tv := s.computeTerm(log.ToResolve[j]) + eval.Add(&eval, &tv) + continue + } + + if visibility == compiled.Virtual { + // it's just a constant + if cID == compiled.CoeffIdMinusOne { + toResolve = append(toResolve, "-1") + } else { + toResolve = append(toResolve, s.coefficients[cID].String()) + } + continue + } + if !(cID == compiled.CoeffIdMinusOne || cID == compiled.CoeffIdOne) { + toResolve = append(toResolve, s.coefficients[cID].String()) + } if !s.solved[vID] { toResolve = append(toResolve, "???") } else { diff --git a/internal/backend/bls12-377/cs/r1cs.go b/internal/backend/bls12-377/cs/r1cs.go index ef03182f09..dfe73ace12 100644 --- a/internal/backend/bls12-377/cs/r1cs.go +++ b/internal/backend/bls12-377/cs/r1cs.go @@ -100,11 +100,6 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi // check if there is an inconsistant constraint var check fr.Element - // TODO @gbotrel clean this - // this variable is used to navigate in the debugInfoComputation slice. - // It is incremented by one each time a division happens for solving a constraint. - var debugInfoComputationOffset uint - // for each constraint // we are guaranteed that each R1C contains at most one unsolved wire // first we solve the unsolved wire (if any) @@ -112,11 +107,9 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi // if a[i] * b[i] != c[i]; it means the constraint is not satisfied for i := 0; i < len(cs.Constraints); i++ { // solve the constraint, this will compute the missing wire of the gate - offset, err := cs.solveConstraint(cs.Constraints[i], &solution) - if err != nil { + if err := cs.solveConstraint(cs.Constraints[i], &solution); err != nil { return solution.values, err } - debugInfoComputationOffset += offset // compute values for the R1C (ie value * coeff) a[i], b[i], c[i] = cs.instantiateR1C(cs.Constraints[i], &solution) @@ -124,9 +117,11 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi // ensure a[i] * b[i] == c[i] check.Mul(&a[i], &b[i]) if !check.Equal(&c[i]) { - debugInfo := cs.DebugInfoComputation[debugInfoComputationOffset] - debugInfoStr := solution.logValue(debugInfo) - return solution.values, fmt.Errorf("%w: %s", ErrUnsatisfiedConstraint, debugInfoStr) + if dID, ok := cs.MDebug[i]; ok { + debugInfoStr := solution.logValue(cs.DebugInfo[dID]) + return solution.values, fmt.Errorf("%w: %s", ErrUnsatisfiedConstraint, debugInfoStr) + } + return solution.values, ErrUnsatisfiedConstraint } } @@ -203,10 +198,7 @@ func (cs *R1CS) instantiateR1C(r compiled.R1C, solution *solution) (a, b, c fr.E // It returns the 1 if the the position to solve is in the quadratic part (it // means that there is a division and serves to navigate in the log info for the // computational constraints), and 0 otherwise. -func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (uint, error) { - - // value to return: 1 if the wire to solve is in the quadratic term, 0 otherwise - var offset uint +func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) error { // the index of the non zero entry shows if L, R or O has an uninstantiated wire // the content is the ID of the wire non instantiated @@ -241,19 +233,19 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (uint, error for _, t := range r.L { if err := processTerm(t, &a, 1); err != nil { - return 0, err + return err } } for _, t := range r.R { if err := processTerm(t, &b, 2); err != nil { - return 0, err + return err } } for _, t := range r.O { if err := processTerm(t, &c, 3); err != nil { - return 0, err + return err } } @@ -261,7 +253,7 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (uint, error // there is nothing to solve, may happen if we have an assertion // (ie a constraints that doesn't yield any output) // or if we solved the unsolved wires with hint functions - return 0, nil + return nil } // we compute the wire value and instantiate it @@ -276,14 +268,12 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (uint, error wire.Div(&c, &b). Sub(&wire, &a) cs.mulByCoeff(&wire, termToCompute) - offset = 1 } case 2: if !a.IsZero() { wire.Div(&c, &a). Sub(&wire, &b) cs.mulByCoeff(&wire, termToCompute) - offset = 1 } case 3: wire.Mul(&a, &b). @@ -293,7 +283,7 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (uint, error solution.set(vID, wire) - return offset, nil + return nil } // TODO @gbotrel clean logs and html @@ -406,14 +396,15 @@ func (cs *R1CS) SetLoggerOutput(w io.Writer) { // WriteTo encodes R1CS into provided io.Writer using cbor func (cs *R1CS) WriteTo(w io.Writer) (int64, error) { _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written - encoder := cbor.NewEncoder(&_w) - - // encode our object - if err := encoder.Encode(cs); err != nil { - return _w.N, err + enc, err := cbor.CoreDetEncOptions().EncMode() + if err != nil { + return 0, err } + encoder := enc.NewEncoder(&_w) - return _w.N, nil + // encode our object + err = encoder.Encode(cs) + return _w.N, err } // ReadFrom attempts to decode R1CS from io.Reader using cbor diff --git a/internal/backend/bls12-377/cs/r1cs_sparse.go b/internal/backend/bls12-377/cs/r1cs_sparse.go index e2e5435a99..5e63fa29ed 100644 --- a/internal/backend/bls12-377/cs/r1cs_sparse.go +++ b/internal/backend/bls12-377/cs/r1cs_sparse.go @@ -109,7 +109,11 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, hintFunctions []hint.Function) return solution.values, fmt.Errorf("constraint %d: %w", i, err) } if err := cs.checkConstraint(cs.Constraints[i], &solution); err != nil { - return solution.values, fmt.Errorf("constraint %d: %w", i, err) + if dID, ok := cs.MDebug[i]; ok { + debugInfoStr := solution.logValue(cs.DebugInfo[dID]) + return solution.values, fmt.Errorf("%w: %s", ErrUnsatisfiedConstraint, debugInfoStr) + } + return solution.values, ErrUnsatisfiedConstraint } } @@ -322,10 +326,14 @@ func (cs *SparseR1CS) CurveID() ecc.ID { // WriteTo encodes SparseR1CS into provided io.Writer using cbor func (cs *SparseR1CS) WriteTo(w io.Writer) (int64, error) { _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written - encoder := cbor.NewEncoder(&_w) + enc, err := cbor.CoreDetEncOptions().EncMode() + if err != nil { + return 0, err + } + encoder := enc.NewEncoder(&_w) // encode our object - err := encoder.Encode(cs) + err = encoder.Encode(cs) return _w.N, err } diff --git a/internal/backend/bls12-381/cs/cs.go b/internal/backend/bls12-381/cs/cs.go index bd5292d077..321de6f9d3 100644 --- a/internal/backend/bls12-381/cs/cs.go +++ b/internal/backend/bls12-381/cs/cs.go @@ -151,8 +151,59 @@ func (s *solution) printLogs(w io.Writer, logs []compiled.LogEntry) { func (s *solution) logValue(log compiled.LogEntry) string { var toResolve []interface{} + var ( + isEval bool + eval fr.Element + missingValue bool + ) for j := 0; j < len(log.ToResolve); j++ { - vID := log.ToResolve[j] + if log.ToResolve[j] == compiled.TermDelimitor { + // this is a special case where we want to evaluate the following terms until the next delimitor. + if !isEval { + isEval = true + missingValue = false + eval.SetZero() + continue + } + isEval = false + if missingValue { + toResolve = append(toResolve, "???") + } else { + // we have to append our accumulator + toResolve = append(toResolve, eval.String()) + } + continue + } + cID, vID, visibility := log.ToResolve[j].Unpack() + + if isEval { + // we are evaluating + if visibility == compiled.Virtual { + // just add the constant + eval.Add(&eval, &s.coefficients[cID]) + continue + } + if !s.solved[vID] { + missingValue = true + continue + } + tv := s.computeTerm(log.ToResolve[j]) + eval.Add(&eval, &tv) + continue + } + + if visibility == compiled.Virtual { + // it's just a constant + if cID == compiled.CoeffIdMinusOne { + toResolve = append(toResolve, "-1") + } else { + toResolve = append(toResolve, s.coefficients[cID].String()) + } + continue + } + if !(cID == compiled.CoeffIdMinusOne || cID == compiled.CoeffIdOne) { + toResolve = append(toResolve, s.coefficients[cID].String()) + } if !s.solved[vID] { toResolve = append(toResolve, "???") } else { diff --git a/internal/backend/bls12-381/cs/r1cs.go b/internal/backend/bls12-381/cs/r1cs.go index b035e3e6cd..d302b6f871 100644 --- a/internal/backend/bls12-381/cs/r1cs.go +++ b/internal/backend/bls12-381/cs/r1cs.go @@ -100,11 +100,6 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi // check if there is an inconsistant constraint var check fr.Element - // TODO @gbotrel clean this - // this variable is used to navigate in the debugInfoComputation slice. - // It is incremented by one each time a division happens for solving a constraint. - var debugInfoComputationOffset uint - // for each constraint // we are guaranteed that each R1C contains at most one unsolved wire // first we solve the unsolved wire (if any) @@ -112,11 +107,9 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi // if a[i] * b[i] != c[i]; it means the constraint is not satisfied for i := 0; i < len(cs.Constraints); i++ { // solve the constraint, this will compute the missing wire of the gate - offset, err := cs.solveConstraint(cs.Constraints[i], &solution) - if err != nil { + if err := cs.solveConstraint(cs.Constraints[i], &solution); err != nil { return solution.values, err } - debugInfoComputationOffset += offset // compute values for the R1C (ie value * coeff) a[i], b[i], c[i] = cs.instantiateR1C(cs.Constraints[i], &solution) @@ -124,9 +117,11 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi // ensure a[i] * b[i] == c[i] check.Mul(&a[i], &b[i]) if !check.Equal(&c[i]) { - debugInfo := cs.DebugInfoComputation[debugInfoComputationOffset] - debugInfoStr := solution.logValue(debugInfo) - return solution.values, fmt.Errorf("%w: %s", ErrUnsatisfiedConstraint, debugInfoStr) + if dID, ok := cs.MDebug[i]; ok { + debugInfoStr := solution.logValue(cs.DebugInfo[dID]) + return solution.values, fmt.Errorf("%w: %s", ErrUnsatisfiedConstraint, debugInfoStr) + } + return solution.values, ErrUnsatisfiedConstraint } } @@ -203,10 +198,7 @@ func (cs *R1CS) instantiateR1C(r compiled.R1C, solution *solution) (a, b, c fr.E // It returns the 1 if the the position to solve is in the quadratic part (it // means that there is a division and serves to navigate in the log info for the // computational constraints), and 0 otherwise. -func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (uint, error) { - - // value to return: 1 if the wire to solve is in the quadratic term, 0 otherwise - var offset uint +func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) error { // the index of the non zero entry shows if L, R or O has an uninstantiated wire // the content is the ID of the wire non instantiated @@ -241,19 +233,19 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (uint, error for _, t := range r.L { if err := processTerm(t, &a, 1); err != nil { - return 0, err + return err } } for _, t := range r.R { if err := processTerm(t, &b, 2); err != nil { - return 0, err + return err } } for _, t := range r.O { if err := processTerm(t, &c, 3); err != nil { - return 0, err + return err } } @@ -261,7 +253,7 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (uint, error // there is nothing to solve, may happen if we have an assertion // (ie a constraints that doesn't yield any output) // or if we solved the unsolved wires with hint functions - return 0, nil + return nil } // we compute the wire value and instantiate it @@ -276,14 +268,12 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (uint, error wire.Div(&c, &b). Sub(&wire, &a) cs.mulByCoeff(&wire, termToCompute) - offset = 1 } case 2: if !a.IsZero() { wire.Div(&c, &a). Sub(&wire, &b) cs.mulByCoeff(&wire, termToCompute) - offset = 1 } case 3: wire.Mul(&a, &b). @@ -293,7 +283,7 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (uint, error solution.set(vID, wire) - return offset, nil + return nil } // TODO @gbotrel clean logs and html @@ -406,14 +396,15 @@ func (cs *R1CS) SetLoggerOutput(w io.Writer) { // WriteTo encodes R1CS into provided io.Writer using cbor func (cs *R1CS) WriteTo(w io.Writer) (int64, error) { _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written - encoder := cbor.NewEncoder(&_w) - - // encode our object - if err := encoder.Encode(cs); err != nil { - return _w.N, err + enc, err := cbor.CoreDetEncOptions().EncMode() + if err != nil { + return 0, err } + encoder := enc.NewEncoder(&_w) - return _w.N, nil + // encode our object + err = encoder.Encode(cs) + return _w.N, err } // ReadFrom attempts to decode R1CS from io.Reader using cbor diff --git a/internal/backend/bls12-381/cs/r1cs_sparse.go b/internal/backend/bls12-381/cs/r1cs_sparse.go index b72a43faf7..14def8092e 100644 --- a/internal/backend/bls12-381/cs/r1cs_sparse.go +++ b/internal/backend/bls12-381/cs/r1cs_sparse.go @@ -109,7 +109,11 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, hintFunctions []hint.Function) return solution.values, fmt.Errorf("constraint %d: %w", i, err) } if err := cs.checkConstraint(cs.Constraints[i], &solution); err != nil { - return solution.values, fmt.Errorf("constraint %d: %w", i, err) + if dID, ok := cs.MDebug[i]; ok { + debugInfoStr := solution.logValue(cs.DebugInfo[dID]) + return solution.values, fmt.Errorf("%w: %s", ErrUnsatisfiedConstraint, debugInfoStr) + } + return solution.values, ErrUnsatisfiedConstraint } } @@ -322,10 +326,14 @@ func (cs *SparseR1CS) CurveID() ecc.ID { // WriteTo encodes SparseR1CS into provided io.Writer using cbor func (cs *SparseR1CS) WriteTo(w io.Writer) (int64, error) { _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written - encoder := cbor.NewEncoder(&_w) + enc, err := cbor.CoreDetEncOptions().EncMode() + if err != nil { + return 0, err + } + encoder := enc.NewEncoder(&_w) // encode our object - err := encoder.Encode(cs) + err = encoder.Encode(cs) return _w.N, err } diff --git a/internal/backend/bls24-315/cs/cs.go b/internal/backend/bls24-315/cs/cs.go index d260069445..cc0bcbb8cd 100644 --- a/internal/backend/bls24-315/cs/cs.go +++ b/internal/backend/bls24-315/cs/cs.go @@ -151,8 +151,59 @@ func (s *solution) printLogs(w io.Writer, logs []compiled.LogEntry) { func (s *solution) logValue(log compiled.LogEntry) string { var toResolve []interface{} + var ( + isEval bool + eval fr.Element + missingValue bool + ) for j := 0; j < len(log.ToResolve); j++ { - vID := log.ToResolve[j] + if log.ToResolve[j] == compiled.TermDelimitor { + // this is a special case where we want to evaluate the following terms until the next delimitor. + if !isEval { + isEval = true + missingValue = false + eval.SetZero() + continue + } + isEval = false + if missingValue { + toResolve = append(toResolve, "???") + } else { + // we have to append our accumulator + toResolve = append(toResolve, eval.String()) + } + continue + } + cID, vID, visibility := log.ToResolve[j].Unpack() + + if isEval { + // we are evaluating + if visibility == compiled.Virtual { + // just add the constant + eval.Add(&eval, &s.coefficients[cID]) + continue + } + if !s.solved[vID] { + missingValue = true + continue + } + tv := s.computeTerm(log.ToResolve[j]) + eval.Add(&eval, &tv) + continue + } + + if visibility == compiled.Virtual { + // it's just a constant + if cID == compiled.CoeffIdMinusOne { + toResolve = append(toResolve, "-1") + } else { + toResolve = append(toResolve, s.coefficients[cID].String()) + } + continue + } + if !(cID == compiled.CoeffIdMinusOne || cID == compiled.CoeffIdOne) { + toResolve = append(toResolve, s.coefficients[cID].String()) + } if !s.solved[vID] { toResolve = append(toResolve, "???") } else { diff --git a/internal/backend/bls24-315/cs/r1cs.go b/internal/backend/bls24-315/cs/r1cs.go index 69b737db99..3484459626 100644 --- a/internal/backend/bls24-315/cs/r1cs.go +++ b/internal/backend/bls24-315/cs/r1cs.go @@ -100,11 +100,6 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi // check if there is an inconsistant constraint var check fr.Element - // TODO @gbotrel clean this - // this variable is used to navigate in the debugInfoComputation slice. - // It is incremented by one each time a division happens for solving a constraint. - var debugInfoComputationOffset uint - // for each constraint // we are guaranteed that each R1C contains at most one unsolved wire // first we solve the unsolved wire (if any) @@ -112,11 +107,9 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi // if a[i] * b[i] != c[i]; it means the constraint is not satisfied for i := 0; i < len(cs.Constraints); i++ { // solve the constraint, this will compute the missing wire of the gate - offset, err := cs.solveConstraint(cs.Constraints[i], &solution) - if err != nil { + if err := cs.solveConstraint(cs.Constraints[i], &solution); err != nil { return solution.values, err } - debugInfoComputationOffset += offset // compute values for the R1C (ie value * coeff) a[i], b[i], c[i] = cs.instantiateR1C(cs.Constraints[i], &solution) @@ -124,9 +117,11 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi // ensure a[i] * b[i] == c[i] check.Mul(&a[i], &b[i]) if !check.Equal(&c[i]) { - debugInfo := cs.DebugInfoComputation[debugInfoComputationOffset] - debugInfoStr := solution.logValue(debugInfo) - return solution.values, fmt.Errorf("%w: %s", ErrUnsatisfiedConstraint, debugInfoStr) + if dID, ok := cs.MDebug[i]; ok { + debugInfoStr := solution.logValue(cs.DebugInfo[dID]) + return solution.values, fmt.Errorf("%w: %s", ErrUnsatisfiedConstraint, debugInfoStr) + } + return solution.values, ErrUnsatisfiedConstraint } } @@ -203,10 +198,7 @@ func (cs *R1CS) instantiateR1C(r compiled.R1C, solution *solution) (a, b, c fr.E // It returns the 1 if the the position to solve is in the quadratic part (it // means that there is a division and serves to navigate in the log info for the // computational constraints), and 0 otherwise. -func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (uint, error) { - - // value to return: 1 if the wire to solve is in the quadratic term, 0 otherwise - var offset uint +func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) error { // the index of the non zero entry shows if L, R or O has an uninstantiated wire // the content is the ID of the wire non instantiated @@ -241,19 +233,19 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (uint, error for _, t := range r.L { if err := processTerm(t, &a, 1); err != nil { - return 0, err + return err } } for _, t := range r.R { if err := processTerm(t, &b, 2); err != nil { - return 0, err + return err } } for _, t := range r.O { if err := processTerm(t, &c, 3); err != nil { - return 0, err + return err } } @@ -261,7 +253,7 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (uint, error // there is nothing to solve, may happen if we have an assertion // (ie a constraints that doesn't yield any output) // or if we solved the unsolved wires with hint functions - return 0, nil + return nil } // we compute the wire value and instantiate it @@ -276,14 +268,12 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (uint, error wire.Div(&c, &b). Sub(&wire, &a) cs.mulByCoeff(&wire, termToCompute) - offset = 1 } case 2: if !a.IsZero() { wire.Div(&c, &a). Sub(&wire, &b) cs.mulByCoeff(&wire, termToCompute) - offset = 1 } case 3: wire.Mul(&a, &b). @@ -293,7 +283,7 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (uint, error solution.set(vID, wire) - return offset, nil + return nil } // TODO @gbotrel clean logs and html @@ -406,14 +396,15 @@ func (cs *R1CS) SetLoggerOutput(w io.Writer) { // WriteTo encodes R1CS into provided io.Writer using cbor func (cs *R1CS) WriteTo(w io.Writer) (int64, error) { _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written - encoder := cbor.NewEncoder(&_w) - - // encode our object - if err := encoder.Encode(cs); err != nil { - return _w.N, err + enc, err := cbor.CoreDetEncOptions().EncMode() + if err != nil { + return 0, err } + encoder := enc.NewEncoder(&_w) - return _w.N, nil + // encode our object + err = encoder.Encode(cs) + return _w.N, err } // ReadFrom attempts to decode R1CS from io.Reader using cbor diff --git a/internal/backend/bls24-315/cs/r1cs_sparse.go b/internal/backend/bls24-315/cs/r1cs_sparse.go index 36b81c1042..11f4438c50 100644 --- a/internal/backend/bls24-315/cs/r1cs_sparse.go +++ b/internal/backend/bls24-315/cs/r1cs_sparse.go @@ -109,7 +109,11 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, hintFunctions []hint.Function) return solution.values, fmt.Errorf("constraint %d: %w", i, err) } if err := cs.checkConstraint(cs.Constraints[i], &solution); err != nil { - return solution.values, fmt.Errorf("constraint %d: %w", i, err) + if dID, ok := cs.MDebug[i]; ok { + debugInfoStr := solution.logValue(cs.DebugInfo[dID]) + return solution.values, fmt.Errorf("%w: %s", ErrUnsatisfiedConstraint, debugInfoStr) + } + return solution.values, ErrUnsatisfiedConstraint } } @@ -322,10 +326,14 @@ func (cs *SparseR1CS) CurveID() ecc.ID { // WriteTo encodes SparseR1CS into provided io.Writer using cbor func (cs *SparseR1CS) WriteTo(w io.Writer) (int64, error) { _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written - encoder := cbor.NewEncoder(&_w) + enc, err := cbor.CoreDetEncOptions().EncMode() + if err != nil { + return 0, err + } + encoder := enc.NewEncoder(&_w) // encode our object - err := encoder.Encode(cs) + err = encoder.Encode(cs) return _w.N, err } diff --git a/internal/backend/bn254/cs/cs.go b/internal/backend/bn254/cs/cs.go index 00da483db7..1a25e1881f 100644 --- a/internal/backend/bn254/cs/cs.go +++ b/internal/backend/bn254/cs/cs.go @@ -151,8 +151,59 @@ func (s *solution) printLogs(w io.Writer, logs []compiled.LogEntry) { func (s *solution) logValue(log compiled.LogEntry) string { var toResolve []interface{} + var ( + isEval bool + eval fr.Element + missingValue bool + ) for j := 0; j < len(log.ToResolve); j++ { - vID := log.ToResolve[j] + if log.ToResolve[j] == compiled.TermDelimitor { + // this is a special case where we want to evaluate the following terms until the next delimitor. + if !isEval { + isEval = true + missingValue = false + eval.SetZero() + continue + } + isEval = false + if missingValue { + toResolve = append(toResolve, "???") + } else { + // we have to append our accumulator + toResolve = append(toResolve, eval.String()) + } + continue + } + cID, vID, visibility := log.ToResolve[j].Unpack() + + if isEval { + // we are evaluating + if visibility == compiled.Virtual { + // just add the constant + eval.Add(&eval, &s.coefficients[cID]) + continue + } + if !s.solved[vID] { + missingValue = true + continue + } + tv := s.computeTerm(log.ToResolve[j]) + eval.Add(&eval, &tv) + continue + } + + if visibility == compiled.Virtual { + // it's just a constant + if cID == compiled.CoeffIdMinusOne { + toResolve = append(toResolve, "-1") + } else { + toResolve = append(toResolve, s.coefficients[cID].String()) + } + continue + } + if !(cID == compiled.CoeffIdMinusOne || cID == compiled.CoeffIdOne) { + toResolve = append(toResolve, s.coefficients[cID].String()) + } if !s.solved[vID] { toResolve = append(toResolve, "???") } else { diff --git a/internal/backend/bn254/cs/r1cs.go b/internal/backend/bn254/cs/r1cs.go index 906d480e87..8d96387a3d 100644 --- a/internal/backend/bn254/cs/r1cs.go +++ b/internal/backend/bn254/cs/r1cs.go @@ -100,11 +100,6 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi // check if there is an inconsistant constraint var check fr.Element - // TODO @gbotrel clean this - // this variable is used to navigate in the debugInfoComputation slice. - // It is incremented by one each time a division happens for solving a constraint. - var debugInfoComputationOffset uint - // for each constraint // we are guaranteed that each R1C contains at most one unsolved wire // first we solve the unsolved wire (if any) @@ -112,11 +107,9 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi // if a[i] * b[i] != c[i]; it means the constraint is not satisfied for i := 0; i < len(cs.Constraints); i++ { // solve the constraint, this will compute the missing wire of the gate - offset, err := cs.solveConstraint(cs.Constraints[i], &solution) - if err != nil { + if err := cs.solveConstraint(cs.Constraints[i], &solution); err != nil { return solution.values, err } - debugInfoComputationOffset += offset // compute values for the R1C (ie value * coeff) a[i], b[i], c[i] = cs.instantiateR1C(cs.Constraints[i], &solution) @@ -124,9 +117,11 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi // ensure a[i] * b[i] == c[i] check.Mul(&a[i], &b[i]) if !check.Equal(&c[i]) { - debugInfo := cs.DebugInfoComputation[debugInfoComputationOffset] - debugInfoStr := solution.logValue(debugInfo) - return solution.values, fmt.Errorf("%w: %s", ErrUnsatisfiedConstraint, debugInfoStr) + if dID, ok := cs.MDebug[i]; ok { + debugInfoStr := solution.logValue(cs.DebugInfo[dID]) + return solution.values, fmt.Errorf("%w: %s", ErrUnsatisfiedConstraint, debugInfoStr) + } + return solution.values, ErrUnsatisfiedConstraint } } @@ -203,10 +198,7 @@ func (cs *R1CS) instantiateR1C(r compiled.R1C, solution *solution) (a, b, c fr.E // It returns the 1 if the the position to solve is in the quadratic part (it // means that there is a division and serves to navigate in the log info for the // computational constraints), and 0 otherwise. -func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (uint, error) { - - // value to return: 1 if the wire to solve is in the quadratic term, 0 otherwise - var offset uint +func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) error { // the index of the non zero entry shows if L, R or O has an uninstantiated wire // the content is the ID of the wire non instantiated @@ -241,19 +233,19 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (uint, error for _, t := range r.L { if err := processTerm(t, &a, 1); err != nil { - return 0, err + return err } } for _, t := range r.R { if err := processTerm(t, &b, 2); err != nil { - return 0, err + return err } } for _, t := range r.O { if err := processTerm(t, &c, 3); err != nil { - return 0, err + return err } } @@ -261,7 +253,7 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (uint, error // there is nothing to solve, may happen if we have an assertion // (ie a constraints that doesn't yield any output) // or if we solved the unsolved wires with hint functions - return 0, nil + return nil } // we compute the wire value and instantiate it @@ -276,14 +268,12 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (uint, error wire.Div(&c, &b). Sub(&wire, &a) cs.mulByCoeff(&wire, termToCompute) - offset = 1 } case 2: if !a.IsZero() { wire.Div(&c, &a). Sub(&wire, &b) cs.mulByCoeff(&wire, termToCompute) - offset = 1 } case 3: wire.Mul(&a, &b). @@ -293,7 +283,7 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (uint, error solution.set(vID, wire) - return offset, nil + return nil } // TODO @gbotrel clean logs and html @@ -406,14 +396,15 @@ func (cs *R1CS) SetLoggerOutput(w io.Writer) { // WriteTo encodes R1CS into provided io.Writer using cbor func (cs *R1CS) WriteTo(w io.Writer) (int64, error) { _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written - encoder := cbor.NewEncoder(&_w) - - // encode our object - if err := encoder.Encode(cs); err != nil { - return _w.N, err + enc, err := cbor.CoreDetEncOptions().EncMode() + if err != nil { + return 0, err } + encoder := enc.NewEncoder(&_w) - return _w.N, nil + // encode our object + err = encoder.Encode(cs) + return _w.N, err } // ReadFrom attempts to decode R1CS from io.Reader using cbor diff --git a/internal/backend/bn254/cs/r1cs_sparse.go b/internal/backend/bn254/cs/r1cs_sparse.go index 4fedb6f24d..0a7ff0eaec 100644 --- a/internal/backend/bn254/cs/r1cs_sparse.go +++ b/internal/backend/bn254/cs/r1cs_sparse.go @@ -109,7 +109,11 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, hintFunctions []hint.Function) return solution.values, fmt.Errorf("constraint %d: %w", i, err) } if err := cs.checkConstraint(cs.Constraints[i], &solution); err != nil { - return solution.values, fmt.Errorf("constraint %d: %w", i, err) + if dID, ok := cs.MDebug[i]; ok { + debugInfoStr := solution.logValue(cs.DebugInfo[dID]) + return solution.values, fmt.Errorf("%w: %s", ErrUnsatisfiedConstraint, debugInfoStr) + } + return solution.values, ErrUnsatisfiedConstraint } } @@ -322,10 +326,14 @@ func (cs *SparseR1CS) CurveID() ecc.ID { // WriteTo encodes SparseR1CS into provided io.Writer using cbor func (cs *SparseR1CS) WriteTo(w io.Writer) (int64, error) { _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written - encoder := cbor.NewEncoder(&_w) + enc, err := cbor.CoreDetEncOptions().EncMode() + if err != nil { + return 0, err + } + encoder := enc.NewEncoder(&_w) // encode our object - err := encoder.Encode(cs) + err = encoder.Encode(cs) return _w.N, err } diff --git a/internal/backend/bw6-761/cs/cs.go b/internal/backend/bw6-761/cs/cs.go index 6ea550d88f..c545fa7910 100644 --- a/internal/backend/bw6-761/cs/cs.go +++ b/internal/backend/bw6-761/cs/cs.go @@ -151,8 +151,59 @@ func (s *solution) printLogs(w io.Writer, logs []compiled.LogEntry) { func (s *solution) logValue(log compiled.LogEntry) string { var toResolve []interface{} + var ( + isEval bool + eval fr.Element + missingValue bool + ) for j := 0; j < len(log.ToResolve); j++ { - vID := log.ToResolve[j] + if log.ToResolve[j] == compiled.TermDelimitor { + // this is a special case where we want to evaluate the following terms until the next delimitor. + if !isEval { + isEval = true + missingValue = false + eval.SetZero() + continue + } + isEval = false + if missingValue { + toResolve = append(toResolve, "???") + } else { + // we have to append our accumulator + toResolve = append(toResolve, eval.String()) + } + continue + } + cID, vID, visibility := log.ToResolve[j].Unpack() + + if isEval { + // we are evaluating + if visibility == compiled.Virtual { + // just add the constant + eval.Add(&eval, &s.coefficients[cID]) + continue + } + if !s.solved[vID] { + missingValue = true + continue + } + tv := s.computeTerm(log.ToResolve[j]) + eval.Add(&eval, &tv) + continue + } + + if visibility == compiled.Virtual { + // it's just a constant + if cID == compiled.CoeffIdMinusOne { + toResolve = append(toResolve, "-1") + } else { + toResolve = append(toResolve, s.coefficients[cID].String()) + } + continue + } + if !(cID == compiled.CoeffIdMinusOne || cID == compiled.CoeffIdOne) { + toResolve = append(toResolve, s.coefficients[cID].String()) + } if !s.solved[vID] { toResolve = append(toResolve, "???") } else { diff --git a/internal/backend/bw6-761/cs/r1cs.go b/internal/backend/bw6-761/cs/r1cs.go index 669b5537d6..7e78134464 100644 --- a/internal/backend/bw6-761/cs/r1cs.go +++ b/internal/backend/bw6-761/cs/r1cs.go @@ -100,11 +100,6 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi // check if there is an inconsistant constraint var check fr.Element - // TODO @gbotrel clean this - // this variable is used to navigate in the debugInfoComputation slice. - // It is incremented by one each time a division happens for solving a constraint. - var debugInfoComputationOffset uint - // for each constraint // we are guaranteed that each R1C contains at most one unsolved wire // first we solve the unsolved wire (if any) @@ -112,11 +107,9 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi // if a[i] * b[i] != c[i]; it means the constraint is not satisfied for i := 0; i < len(cs.Constraints); i++ { // solve the constraint, this will compute the missing wire of the gate - offset, err := cs.solveConstraint(cs.Constraints[i], &solution) - if err != nil { + if err := cs.solveConstraint(cs.Constraints[i], &solution); err != nil { return solution.values, err } - debugInfoComputationOffset += offset // compute values for the R1C (ie value * coeff) a[i], b[i], c[i] = cs.instantiateR1C(cs.Constraints[i], &solution) @@ -124,9 +117,11 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi // ensure a[i] * b[i] == c[i] check.Mul(&a[i], &b[i]) if !check.Equal(&c[i]) { - debugInfo := cs.DebugInfoComputation[debugInfoComputationOffset] - debugInfoStr := solution.logValue(debugInfo) - return solution.values, fmt.Errorf("%w: %s", ErrUnsatisfiedConstraint, debugInfoStr) + if dID, ok := cs.MDebug[i]; ok { + debugInfoStr := solution.logValue(cs.DebugInfo[dID]) + return solution.values, fmt.Errorf("%w: %s", ErrUnsatisfiedConstraint, debugInfoStr) + } + return solution.values, ErrUnsatisfiedConstraint } } @@ -203,10 +198,7 @@ func (cs *R1CS) instantiateR1C(r compiled.R1C, solution *solution) (a, b, c fr.E // It returns the 1 if the the position to solve is in the quadratic part (it // means that there is a division and serves to navigate in the log info for the // computational constraints), and 0 otherwise. -func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (uint, error) { - - // value to return: 1 if the wire to solve is in the quadratic term, 0 otherwise - var offset uint +func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) error { // the index of the non zero entry shows if L, R or O has an uninstantiated wire // the content is the ID of the wire non instantiated @@ -241,19 +233,19 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (uint, error for _, t := range r.L { if err := processTerm(t, &a, 1); err != nil { - return 0, err + return err } } for _, t := range r.R { if err := processTerm(t, &b, 2); err != nil { - return 0, err + return err } } for _, t := range r.O { if err := processTerm(t, &c, 3); err != nil { - return 0, err + return err } } @@ -261,7 +253,7 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (uint, error // there is nothing to solve, may happen if we have an assertion // (ie a constraints that doesn't yield any output) // or if we solved the unsolved wires with hint functions - return 0, nil + return nil } // we compute the wire value and instantiate it @@ -276,14 +268,12 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (uint, error wire.Div(&c, &b). Sub(&wire, &a) cs.mulByCoeff(&wire, termToCompute) - offset = 1 } case 2: if !a.IsZero() { wire.Div(&c, &a). Sub(&wire, &b) cs.mulByCoeff(&wire, termToCompute) - offset = 1 } case 3: wire.Mul(&a, &b). @@ -293,7 +283,7 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (uint, error solution.set(vID, wire) - return offset, nil + return nil } // TODO @gbotrel clean logs and html @@ -406,14 +396,15 @@ func (cs *R1CS) SetLoggerOutput(w io.Writer) { // WriteTo encodes R1CS into provided io.Writer using cbor func (cs *R1CS) WriteTo(w io.Writer) (int64, error) { _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written - encoder := cbor.NewEncoder(&_w) - - // encode our object - if err := encoder.Encode(cs); err != nil { - return _w.N, err + enc, err := cbor.CoreDetEncOptions().EncMode() + if err != nil { + return 0, err } + encoder := enc.NewEncoder(&_w) - return _w.N, nil + // encode our object + err = encoder.Encode(cs) + return _w.N, err } // ReadFrom attempts to decode R1CS from io.Reader using cbor diff --git a/internal/backend/bw6-761/cs/r1cs_sparse.go b/internal/backend/bw6-761/cs/r1cs_sparse.go index dc334eadae..71a0dd36d8 100644 --- a/internal/backend/bw6-761/cs/r1cs_sparse.go +++ b/internal/backend/bw6-761/cs/r1cs_sparse.go @@ -109,7 +109,11 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, hintFunctions []hint.Function) return solution.values, fmt.Errorf("constraint %d: %w", i, err) } if err := cs.checkConstraint(cs.Constraints[i], &solution); err != nil { - return solution.values, fmt.Errorf("constraint %d: %w", i, err) + if dID, ok := cs.MDebug[i]; ok { + debugInfoStr := solution.logValue(cs.DebugInfo[dID]) + return solution.values, fmt.Errorf("%w: %s", ErrUnsatisfiedConstraint, debugInfoStr) + } + return solution.values, ErrUnsatisfiedConstraint } } @@ -322,10 +326,14 @@ func (cs *SparseR1CS) CurveID() ecc.ID { // WriteTo encodes SparseR1CS into provided io.Writer using cbor func (cs *SparseR1CS) WriteTo(w io.Writer) (int64, error) { _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written - encoder := cbor.NewEncoder(&_w) + enc, err := cbor.CoreDetEncOptions().EncMode() + if err != nil { + return 0, err + } + encoder := enc.NewEncoder(&_w) // encode our object - err := encoder.Encode(cs) + err = encoder.Encode(cs) return _w.N, err } diff --git a/internal/backend/compiled/log.go b/internal/backend/compiled/log.go new file mode 100644 index 0000000000..e65a84d419 --- /dev/null +++ b/internal/backend/compiled/log.go @@ -0,0 +1,87 @@ +package compiled + +import ( + "runtime" + "strconv" + "strings" +) + +// LogEntry is used as a shared data structure between the frontend and the backend +// to represent string values (in logs or debug info) where a value is not known at compile time +// (which is the case for variables that need to be resolved in the R1CS) +type LogEntry struct { + Format string + ToResolve []Term +} + +func (l *LogEntry) WriteLinearExpression(le LinearExpression, sbb *strings.Builder) { + sbb.Grow(len(le) * len(" + (xx + xxxxxxxxxxxx")) + + for i := 0; i < len(le); i++ { + if i > 0 { + sbb.WriteString(" + ") + } + l.WriteTerm(le[i], sbb) + } +} + +func (l *LogEntry) WriteTerm(t Term, sbb *strings.Builder) { + // virtual == only a coeff, we discard the wire + if t.VariableVisibility() == Public && t.VariableID() == 0 { + sbb.WriteString("%s") + t.SetVariableVisibility(Virtual) + l.ToResolve = append(l.ToResolve, t) + return + } + + cID := t.CoeffID() + if cID == CoeffIdMinusOne { + sbb.WriteString("-%s") + } else if cID == CoeffIdOne { + sbb.WriteString("%s") + } else { + sbb.WriteString("%s*%s") + } + + l.ToResolve = append(l.ToResolve, t) +} + +func (l *LogEntry) WriteStack(sbb *strings.Builder) { + // derived from: https://golang.org/pkg/runtime/#example_Frames + // we stop when func name == Define as it is where the gnark circuit code should start + + // Ask runtime.Callers for up to 10 pcs + pc := make([]uintptr, 10) + n := runtime.Callers(3, pc) + if n == 0 { + // No pcs available. Stop now. + // This can happen if the first argument to runtime.Callers is large. + return + } + pc = pc[:n] // pass only valid pcs to runtime.CallersFrames + frames := runtime.CallersFrames(pc) + // Loop to get frames. + // A fixed number of pcs can expand to an indefinite number of Frames. + for { + frame, more := frames.Next() + fe := strings.Split(frame.Function, "/") + function := fe[len(fe)-1] + if strings.Contains(function, "frontend.(*ConstraintSystem)") { + continue + } + + sbb.WriteString(function) + sbb.WriteByte('\n') + sbb.WriteByte('\t') + sbb.WriteString(frame.File) + sbb.WriteByte(':') + sbb.WriteString(strconv.Itoa(frame.Line)) + sbb.WriteByte('\n') + if !more { + break + } + if strings.HasSuffix(function, "Define") { + break + } + } +} diff --git a/internal/backend/compiled/r1c.go b/internal/backend/compiled/r1c.go index 1e08ecec36..f25c39335a 100644 --- a/internal/backend/compiled/r1c.go +++ b/internal/backend/compiled/r1c.go @@ -54,17 +54,7 @@ func (l LinearExpression) Less(i, j int) bool { // R1C used to compute the wires type R1C struct { - L LinearExpression - R LinearExpression - O LinearExpression -} - -// LogEntry is used as a shared data structure between the frontend and the backend -// to represent string values (in logs or debug info) where a value is not known at compile time -// (which is the case for variables that need to be resolved in the R1CS) -type LogEntry struct { - Format string - ToResolve []int + L, R, O LinearExpression } // Visibility encodes a Variable (or wire) visibility diff --git a/internal/backend/compiled/r1cs.go b/internal/backend/compiled/r1cs.go index 4410d75636..a9aa52c5d0 100644 --- a/internal/backend/compiled/r1cs.go +++ b/internal/backend/compiled/r1cs.go @@ -23,18 +23,19 @@ import ( // R1CS decsribes a set of R1CS constraint type R1CS struct { // Wires - NbInternalVariables int - NbPublicVariables int // includes ONE wire - NbSecretVariables int - Logs []LogEntry - DebugInfoComputation []LogEntry + NbInternalVariables int + NbPublicVariables int // includes ONE wire + NbSecretVariables int + Logs []LogEntry + DebugInfo []LogEntry // Constraints NbConstraints int // total number of constraints Constraints []R1C // Hints - Hints []Hint + Hints []Hint + MDebug map[int]int // maps constraint id to debugInfo id } // GetNbConstraints returns the number of constraints diff --git a/internal/backend/compiled/r1cs_sparse.go b/internal/backend/compiled/r1cs_sparse.go index 0ade1f341e..f58b77bbc9 100644 --- a/internal/backend/compiled/r1cs_sparse.go +++ b/internal/backend/compiled/r1cs_sparse.go @@ -34,10 +34,12 @@ type SparseR1CS struct { Constraints []SparseR1C // Logs (e.g. variables that have been printed using cs.Println) - Logs []LogEntry + Logs []LogEntry + DebugInfo []LogEntry // Hints - Hints []Hint + Hints []Hint + MDebug map[int]int // maps constraint id to debugInfo id } // GetNbVariables return number of internal, secret and public variables diff --git a/internal/backend/compiled/term.go b/internal/backend/compiled/term.go index 60d420b0e2..766770f31b 100644 --- a/internal/backend/compiled/term.go +++ b/internal/backend/compiled/term.go @@ -44,6 +44,9 @@ const ( nbBitsVariableVisibility = 3 ) +// TermDelimitor is reserved for internal use +const TermDelimitor Term = Term(maskFutureUse) + const ( shiftVariableID = 0 shiftCoeffID = nbBitsVariableID diff --git a/internal/generator/backend/template/representations/cs.go.tmpl b/internal/generator/backend/template/representations/cs.go.tmpl index 90df368f88..c94fd17839 100644 --- a/internal/generator/backend/template/representations/cs.go.tmpl +++ b/internal/generator/backend/template/representations/cs.go.tmpl @@ -140,8 +140,59 @@ func (s *solution) printLogs(w io.Writer, logs []compiled.LogEntry) { func (s *solution) logValue(log compiled.LogEntry) string { var toResolve []interface{} + var ( + isEval bool + eval fr.Element + missingValue bool + ) for j := 0; j < len(log.ToResolve); j++ { - vID := log.ToResolve[j] + if log.ToResolve[j] == compiled.TermDelimitor { + // this is a special case where we want to evaluate the following terms until the next delimitor. + if !isEval { + isEval = true + missingValue = false + eval.SetZero() + continue + } + isEval = false + if missingValue { + toResolve = append(toResolve, "???") + } else { + // we have to append our accumulator + toResolve = append(toResolve, eval.String()) + } + continue + } + cID, vID, visibility := log.ToResolve[j].Unpack() + + if isEval { + // we are evaluating + if visibility == compiled.Virtual { + // just add the constant + eval.Add(&eval, &s.coefficients[cID]) + continue + } + if !s.solved[vID] { + missingValue = true + continue + } + tv := s.computeTerm(log.ToResolve[j]) + eval.Add(&eval, &tv) + continue + } + + if visibility == compiled.Virtual { + // it's just a constant + if cID == compiled.CoeffIdMinusOne { + toResolve = append(toResolve, "-1") + } else { + toResolve = append(toResolve, s.coefficients[cID].String()) + } + continue + } + if !(cID == compiled.CoeffIdMinusOne || cID == compiled.CoeffIdOne) { + toResolve = append(toResolve, s.coefficients[cID].String()) + } if !s.solved[vID] { toResolve = append(toResolve, "???") } else { diff --git a/internal/generator/backend/template/representations/r1cs.go.tmpl b/internal/generator/backend/template/representations/r1cs.go.tmpl index fa61f08062..33f66dfdcc 100644 --- a/internal/generator/backend/template/representations/r1cs.go.tmpl +++ b/internal/generator/backend/template/representations/r1cs.go.tmpl @@ -89,11 +89,6 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi // check if there is an inconsistant constraint var check fr.Element - // TODO @gbotrel clean this - // this variable is used to navigate in the debugInfoComputation slice. - // It is incremented by one each time a division happens for solving a constraint. - var debugInfoComputationOffset uint - // for each constraint // we are guaranteed that each R1C contains at most one unsolved wire @@ -102,11 +97,9 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi // if a[i] * b[i] != c[i]; it means the constraint is not satisfied for i := 0; i < len(cs.Constraints); i++ { // solve the constraint, this will compute the missing wire of the gate - offset, err := cs.solveConstraint(cs.Constraints[i], &solution) - if err != nil { + if err := cs.solveConstraint(cs.Constraints[i], &solution); err != nil { return solution.values, err } - debugInfoComputationOffset += offset // compute values for the R1C (ie value * coeff) a[i], b[i], c[i] = cs.instantiateR1C(cs.Constraints[i], &solution) @@ -114,9 +107,11 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi // ensure a[i] * b[i] == c[i] check.Mul(&a[i], &b[i]) if !check.Equal(&c[i]) { - debugInfo := cs.DebugInfoComputation[debugInfoComputationOffset] - debugInfoStr := solution.logValue(debugInfo) - return solution.values, fmt.Errorf("%w: %s", ErrUnsatisfiedConstraint, debugInfoStr) + if dID, ok := cs.MDebug[i]; ok { + debugInfoStr := solution.logValue(cs.DebugInfo[dID]) + return solution.values, fmt.Errorf("%w: %s", ErrUnsatisfiedConstraint, debugInfoStr) + } + return solution.values, ErrUnsatisfiedConstraint } } @@ -199,10 +194,7 @@ func (cs *R1CS) instantiateR1C(r compiled.R1C, solution *solution) (a, b, c fr.E // It returns the 1 if the the position to solve is in the quadratic part (it // means that there is a division and serves to navigate in the log info for the // computational constraints), and 0 otherwise. -func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (uint, error) { - - // value to return: 1 if the wire to solve is in the quadratic term, 0 otherwise - var offset uint +func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) error { // the index of the non zero entry shows if L, R or O has an uninstantiated wire // the content is the ID of the wire non instantiated @@ -237,19 +229,19 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (uint, error for _, t := range r.L { if err := processTerm(t, &a, 1); err != nil { - return 0, err + return err } } for _, t := range r.R { if err := processTerm(t, &b, 2); err != nil { - return 0, err + return err } } for _, t := range r.O { if err := processTerm(t, &c, 3); err != nil { - return 0, err + return err } } @@ -257,7 +249,7 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (uint, error // there is nothing to solve, may happen if we have an assertion // (ie a constraints that doesn't yield any output) // or if we solved the unsolved wires with hint functions - return 0, nil + return nil } // we compute the wire value and instantiate it @@ -272,14 +264,12 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (uint, error wire.Div(&c, &b). Sub(&wire, &a) cs.mulByCoeff(&wire, termToCompute) - offset = 1 } case 2: if !a.IsZero() { wire.Div(&c, &a). Sub(&wire, &b) cs.mulByCoeff(&wire, termToCompute) - offset = 1 } case 3: wire.Mul(&a, &b). @@ -289,7 +279,7 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (uint, error solution.set(vID, wire) - return offset, nil + return nil } @@ -407,14 +397,15 @@ func (cs *R1CS) SetLoggerOutput(w io.Writer) { // WriteTo encodes R1CS into provided io.Writer using cbor func (cs *R1CS) WriteTo(w io.Writer) (int64, error) { _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written - encoder := cbor.NewEncoder(&_w) - - // encode our object - if err := encoder.Encode(cs); err != nil { - return _w.N, err + enc, err := cbor.CoreDetEncOptions().EncMode() + if err != nil { + return 0, err } + encoder := enc.NewEncoder(&_w) - return _w.N, nil + // encode our object + err = encoder.Encode(cs) + return _w.N, err } // ReadFrom attempts to decode R1CS from io.Reader using cbor diff --git a/internal/generator/backend/template/representations/r1cs.sparse.go.tmpl b/internal/generator/backend/template/representations/r1cs.sparse.go.tmpl index 143f549042..47e8a8d07a 100644 --- a/internal/generator/backend/template/representations/r1cs.sparse.go.tmpl +++ b/internal/generator/backend/template/representations/r1cs.sparse.go.tmpl @@ -97,7 +97,11 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, hintFunctions []hint.Function) return solution.values, fmt.Errorf("constraint %d: %w", i, err) } if err := cs.checkConstraint(cs.Constraints[i], &solution); err != nil { - return solution.values, fmt.Errorf("constraint %d: %w", i, err) + if dID, ok := cs.MDebug[i]; ok { + debugInfoStr := solution.logValue(cs.DebugInfo[dID]) + return solution.values, fmt.Errorf("%w: %s", ErrUnsatisfiedConstraint, debugInfoStr) + } + return solution.values, ErrUnsatisfiedConstraint } } @@ -318,10 +322,14 @@ func (cs *SparseR1CS) CurveID() ecc.ID { // WriteTo encodes SparseR1CS into provided io.Writer using cbor func (cs *SparseR1CS) WriteTo(w io.Writer) (int64, error) { _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written - encoder := cbor.NewEncoder(&_w) + enc, err := cbor.CoreDetEncOptions().EncMode() + if err != nil { + return 0, err + } + encoder := enc.NewEncoder(&_w) // encode our object - err := encoder.Encode(cs) + err = encoder.Encode(cs) return _w.N, err } From 46442f9761a0bb132b950bc80198cb9cd529ad71 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Thu, 16 Sep 2021 14:17:19 -0500 Subject: [PATCH 02/15] refactor: move mHints to compiled R1CS and SparseR1CS --- frontend/cs_to_r1cs.go | 4 ++ frontend/cs_to_r1cs_sparse.go | 5 +++ internal/backend/bls12-377/cs/r1cs.go | 37 ++++------------- internal/backend/bls12-377/cs/r1cs_sparse.go | 34 +++------------- internal/backend/bls12-381/cs/r1cs.go | 37 ++++------------- internal/backend/bls12-381/cs/r1cs_sparse.go | 34 +++------------- internal/backend/bls24-315/cs/r1cs.go | 37 ++++------------- internal/backend/bls24-315/cs/r1cs_sparse.go | 34 +++------------- internal/backend/bn254/cs/r1cs.go | 37 ++++------------- internal/backend/bn254/cs/r1cs_sparse.go | 34 +++------------- internal/backend/bw6-761/cs/r1cs.go | 37 ++++------------- internal/backend/bw6-761/cs/r1cs_sparse.go | 34 +++------------- internal/backend/compiled/r1cs.go | 4 +- internal/backend/compiled/r1cs_sparse.go | 4 +- .../template/representations/r1cs.go.tmpl | 40 ++++--------------- .../representations/r1cs.sparse.go.tmpl | 37 +++-------------- 16 files changed, 94 insertions(+), 355 deletions(-) diff --git a/frontend/cs_to_r1cs.go b/frontend/cs_to_r1cs.go index 2c0bb2bf64..228476fad2 100644 --- a/frontend/cs_to_r1cs.go +++ b/frontend/cs_to_r1cs.go @@ -78,7 +78,11 @@ func (cs *ConstraintSystem) toR1CS(curveID ecc.ID) (CompiledConstraintSystem, er for j := 0; j < len(res.Hints[i].Inputs); j++ { offsetIDs(res.Hints[i].Inputs[j]) } + } + res.MHints = make(map[int]int, len(res.Hints)) + for i := 0; i < len(res.Hints); i++ { + res.MHints[res.Hints[i].WireID] = i } // we need to offset the ids in logs & debugInfo diff --git a/frontend/cs_to_r1cs_sparse.go b/frontend/cs_to_r1cs_sparse.go index d039bf426d..6207a0ace6 100644 --- a/frontend/cs_to_r1cs_sparse.go +++ b/frontend/cs_to_r1cs_sparse.go @@ -169,6 +169,11 @@ func (cs *ConstraintSystem) toSparseR1CS(curveID ecc.ID) (CompiledConstraintSyst } } + res.ccs.MHints = make(map[int]int, len(res.ccs.Hints)) + for i := 0; i < len(res.ccs.Hints); i++ { + res.ccs.MHints[res.ccs.Hints[i].WireID] = i + } + // update number of internal variables with new wires created // while processing R1C -> SparseR1C res.ccs.NbInternalVariables = res.scsInternalVariables diff --git a/internal/backend/bls12-377/cs/r1cs.go b/internal/backend/bls12-377/cs/r1cs.go index dfe73ace12..d7a2b29200 100644 --- a/internal/backend/bls12-377/cs/r1cs.go +++ b/internal/backend/bls12-377/cs/r1cs.go @@ -41,7 +41,6 @@ type R1CS struct { compiled.R1CS Coefficients []fr.Element // R1C coefficients indexes point here loggerOut io.Writer - mHints map[int]int // correspondance between hint wire ID and hint data struct } // NewR1CS returns a new R1CS and sets cs.Coefficient (fr.Element) from provided big.Int values @@ -55,8 +54,6 @@ func NewR1CS(cs compiled.R1CS, coefficients []big.Int) *R1CS { r.Coefficients[i].SetBigInt(&coefficients[i]) } - r.initHints() - return &r } @@ -108,6 +105,7 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi for i := 0; i < len(cs.Constraints); i++ { // solve the constraint, this will compute the missing wire of the gate if err := cs.solveConstraint(cs.Constraints[i], &solution); err != nil { + // TODO should return debug info, if any. return solution.values, err } @@ -143,15 +141,6 @@ func (cs *R1CS) IsSolved(witness []fr.Element, hintFunctions []hint.Function) er return err } -func (cs *R1CS) initHints() { - // we may do that sooner to save time in the solver, but we want the serialized data structures to be - // deterministic, hence avoid maps in there. - cs.mHints = make(map[int]int, len(cs.Hints)) - for i := 0; i < len(cs.Hints); i++ { - cs.mHints[cs.Hints[i].WireID] = i - } -} - // mulByCoeff sets res = res * t.Coeff func (cs *R1CS) mulByCoeff(res *fr.Element, t compiled.Term) { cID := t.CoeffID() @@ -218,8 +207,7 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) error { } // first we check if this is a hint wire - if hID, ok := cs.mHints[vID]; ok { - // TODO handle error + if hID, ok := cs.MHints[vID]; ok { return solution.solveHint(cs.Hints[hID], vID) } @@ -299,15 +287,7 @@ func (cs *R1CS) ToHTML(w io.Writer) error { return err } - type data struct { - *R1CS - MHints map[int]int - } - d := data{ - cs, - cs.mHints, - } - return t.Execute(w, &d) + return t.Execute(w, cs) } func add(a, b int) int { @@ -318,10 +298,10 @@ func sub(a, b int) int { return a - b } -func toHTML(l compiled.LinearExpression, coeffs []fr.Element, mHints map[int]int) string { +func toHTML(l compiled.LinearExpression, coeffs []fr.Element, MHints map[int]int) string { var sbb strings.Builder for i := 0; i < len(l); i++ { - termToHTML(l[i], &sbb, coeffs, mHints, false) + termToHTML(l[i], &sbb, coeffs, MHints, false) if i+1 < len(l) { sbb.WriteString(" + ") } @@ -329,7 +309,7 @@ func toHTML(l compiled.LinearExpression, coeffs []fr.Element, mHints map[int]int return sbb.String() } -func termToHTML(t compiled.Term, sbb *strings.Builder, coeffs []fr.Element, mHints map[int]int, offset bool) { +func termToHTML(t compiled.Term, sbb *strings.Builder, coeffs []fr.Element, MHints map[int]int, offset bool) { tID := t.CoeffID() if tID == compiled.CoeffIdOne { // do nothing, just print the variable @@ -350,7 +330,7 @@ func termToHTML(t compiled.Term, sbb *strings.Builder, coeffs []fr.Element, mHin switch t.VariableVisibility() { case compiled.Internal: class = "internal" - if _, ok := mHints[vID]; ok { + if _, ok := MHints[vID]; ok { class = "hint" } case compiled.Public: @@ -418,8 +398,5 @@ func (cs *R1CS) ReadFrom(r io.Reader) (int64, error) { return int64(decoder.NumBytesRead()), err } - // init the hint map - cs.initHints() - return int64(decoder.NumBytesRead()), nil } diff --git a/internal/backend/bls12-377/cs/r1cs_sparse.go b/internal/backend/bls12-377/cs/r1cs_sparse.go index 5e63fa29ed..56a12482f7 100644 --- a/internal/backend/bls12-377/cs/r1cs_sparse.go +++ b/internal/backend/bls12-377/cs/r1cs_sparse.go @@ -39,7 +39,6 @@ type SparseR1CS struct { // Coefficients in the constraints Coefficients []fr.Element // list of unique coefficients. - mHints map[int]int // correspondance between hint wire ID and hint data struct } // NewSparseR1CS returns a new SparseR1CS and sets r1cs.Coefficient (fr.Element) from provided big.Int values @@ -52,8 +51,6 @@ func NewSparseR1CS(ccs compiled.SparseR1CS, coefficients []big.Int) *SparseR1CS cs.Coefficients[i].SetBigInt(&coefficients[i]) } - cs.initHints() - return &cs } @@ -135,7 +132,7 @@ func (cs *SparseR1CS) computeHints(c compiled.SparseR1C, solution *solution) (in if (c.L.CoeffID() != 0 || c.M[0].CoeffID() != 0) && !solution.solved[lID] { // check if it's a hint - if hID, ok := cs.mHints[lID]; ok { + if hID, ok := cs.MHints[lID]; ok { if err := solution.solveHint(cs.Hints[hID], lID); err != nil { return -1, err } @@ -147,7 +144,7 @@ func (cs *SparseR1CS) computeHints(c compiled.SparseR1C, solution *solution) (in if (c.R.CoeffID() != 0 || c.M[1].CoeffID() != 0) && !solution.solved[rID] { // check if it's a hint - if hID, ok := cs.mHints[rID]; ok { + if hID, ok := cs.MHints[rID]; ok { if err := solution.solveHint(cs.Hints[hID], rID); err != nil { return -1, err } @@ -158,7 +155,7 @@ func (cs *SparseR1CS) computeHints(c compiled.SparseR1C, solution *solution) (in if (c.O.CoeffID() != 0) && !solution.solved[oID] { // check if it's a hint - if hID, ok := cs.mHints[oID]; ok { + if hID, ok := cs.MHints[oID]; ok { if err := solution.solveHint(cs.Hints[hID], oID); err != nil { return -1, err } @@ -269,21 +266,12 @@ func (cs *SparseR1CS) ToHTML(w io.Writer) error { return err } - type data struct { - *SparseR1CS - MHints map[int]int - } - d := data{ - cs, - cs.mHints, - } - - return t.Execute(w, &d) + return t.Execute(w, cs) } -func toHTMLTerm(t compiled.Term, coeffs []fr.Element, mHints map[int]int) string { +func toHTMLTerm(t compiled.Term, coeffs []fr.Element, MHints map[int]int) string { var sbb strings.Builder - termToHTML(t, &sbb, coeffs, mHints, true) + termToHTML(t, &sbb, coeffs, MHints, true) return sbb.String() } @@ -299,15 +287,6 @@ func toHTMLCoeff(cID int, coeffs []fr.Element) string { return sbb.String() } -func (cs *SparseR1CS) initHints() { - // we may do that sooner to save time in the solver, but we want the serialized data structures to be - // deterministic, hence avoid maps in there. - cs.mHints = make(map[int]int, len(cs.Hints)) - for i := 0; i < len(cs.Hints); i++ { - cs.mHints[cs.Hints[i].WireID] = i - } -} - // FrSize return fr.Limbs * 8, size in byte of a fr element func (cs *SparseR1CS) FrSize() int { return fr.Limbs * 8 @@ -345,6 +324,5 @@ func (cs *SparseR1CS) ReadFrom(r io.Reader) (int64, error) { } decoder := dm.NewDecoder(r) err = decoder.Decode(cs) - cs.initHints() return int64(decoder.NumBytesRead()), err } diff --git a/internal/backend/bls12-381/cs/r1cs.go b/internal/backend/bls12-381/cs/r1cs.go index d302b6f871..fcb68d8edc 100644 --- a/internal/backend/bls12-381/cs/r1cs.go +++ b/internal/backend/bls12-381/cs/r1cs.go @@ -41,7 +41,6 @@ type R1CS struct { compiled.R1CS Coefficients []fr.Element // R1C coefficients indexes point here loggerOut io.Writer - mHints map[int]int // correspondance between hint wire ID and hint data struct } // NewR1CS returns a new R1CS and sets cs.Coefficient (fr.Element) from provided big.Int values @@ -55,8 +54,6 @@ func NewR1CS(cs compiled.R1CS, coefficients []big.Int) *R1CS { r.Coefficients[i].SetBigInt(&coefficients[i]) } - r.initHints() - return &r } @@ -108,6 +105,7 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi for i := 0; i < len(cs.Constraints); i++ { // solve the constraint, this will compute the missing wire of the gate if err := cs.solveConstraint(cs.Constraints[i], &solution); err != nil { + // TODO should return debug info, if any. return solution.values, err } @@ -143,15 +141,6 @@ func (cs *R1CS) IsSolved(witness []fr.Element, hintFunctions []hint.Function) er return err } -func (cs *R1CS) initHints() { - // we may do that sooner to save time in the solver, but we want the serialized data structures to be - // deterministic, hence avoid maps in there. - cs.mHints = make(map[int]int, len(cs.Hints)) - for i := 0; i < len(cs.Hints); i++ { - cs.mHints[cs.Hints[i].WireID] = i - } -} - // mulByCoeff sets res = res * t.Coeff func (cs *R1CS) mulByCoeff(res *fr.Element, t compiled.Term) { cID := t.CoeffID() @@ -218,8 +207,7 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) error { } // first we check if this is a hint wire - if hID, ok := cs.mHints[vID]; ok { - // TODO handle error + if hID, ok := cs.MHints[vID]; ok { return solution.solveHint(cs.Hints[hID], vID) } @@ -299,15 +287,7 @@ func (cs *R1CS) ToHTML(w io.Writer) error { return err } - type data struct { - *R1CS - MHints map[int]int - } - d := data{ - cs, - cs.mHints, - } - return t.Execute(w, &d) + return t.Execute(w, cs) } func add(a, b int) int { @@ -318,10 +298,10 @@ func sub(a, b int) int { return a - b } -func toHTML(l compiled.LinearExpression, coeffs []fr.Element, mHints map[int]int) string { +func toHTML(l compiled.LinearExpression, coeffs []fr.Element, MHints map[int]int) string { var sbb strings.Builder for i := 0; i < len(l); i++ { - termToHTML(l[i], &sbb, coeffs, mHints, false) + termToHTML(l[i], &sbb, coeffs, MHints, false) if i+1 < len(l) { sbb.WriteString(" + ") } @@ -329,7 +309,7 @@ func toHTML(l compiled.LinearExpression, coeffs []fr.Element, mHints map[int]int return sbb.String() } -func termToHTML(t compiled.Term, sbb *strings.Builder, coeffs []fr.Element, mHints map[int]int, offset bool) { +func termToHTML(t compiled.Term, sbb *strings.Builder, coeffs []fr.Element, MHints map[int]int, offset bool) { tID := t.CoeffID() if tID == compiled.CoeffIdOne { // do nothing, just print the variable @@ -350,7 +330,7 @@ func termToHTML(t compiled.Term, sbb *strings.Builder, coeffs []fr.Element, mHin switch t.VariableVisibility() { case compiled.Internal: class = "internal" - if _, ok := mHints[vID]; ok { + if _, ok := MHints[vID]; ok { class = "hint" } case compiled.Public: @@ -418,8 +398,5 @@ func (cs *R1CS) ReadFrom(r io.Reader) (int64, error) { return int64(decoder.NumBytesRead()), err } - // init the hint map - cs.initHints() - return int64(decoder.NumBytesRead()), nil } diff --git a/internal/backend/bls12-381/cs/r1cs_sparse.go b/internal/backend/bls12-381/cs/r1cs_sparse.go index 14def8092e..2367017887 100644 --- a/internal/backend/bls12-381/cs/r1cs_sparse.go +++ b/internal/backend/bls12-381/cs/r1cs_sparse.go @@ -39,7 +39,6 @@ type SparseR1CS struct { // Coefficients in the constraints Coefficients []fr.Element // list of unique coefficients. - mHints map[int]int // correspondance between hint wire ID and hint data struct } // NewSparseR1CS returns a new SparseR1CS and sets r1cs.Coefficient (fr.Element) from provided big.Int values @@ -52,8 +51,6 @@ func NewSparseR1CS(ccs compiled.SparseR1CS, coefficients []big.Int) *SparseR1CS cs.Coefficients[i].SetBigInt(&coefficients[i]) } - cs.initHints() - return &cs } @@ -135,7 +132,7 @@ func (cs *SparseR1CS) computeHints(c compiled.SparseR1C, solution *solution) (in if (c.L.CoeffID() != 0 || c.M[0].CoeffID() != 0) && !solution.solved[lID] { // check if it's a hint - if hID, ok := cs.mHints[lID]; ok { + if hID, ok := cs.MHints[lID]; ok { if err := solution.solveHint(cs.Hints[hID], lID); err != nil { return -1, err } @@ -147,7 +144,7 @@ func (cs *SparseR1CS) computeHints(c compiled.SparseR1C, solution *solution) (in if (c.R.CoeffID() != 0 || c.M[1].CoeffID() != 0) && !solution.solved[rID] { // check if it's a hint - if hID, ok := cs.mHints[rID]; ok { + if hID, ok := cs.MHints[rID]; ok { if err := solution.solveHint(cs.Hints[hID], rID); err != nil { return -1, err } @@ -158,7 +155,7 @@ func (cs *SparseR1CS) computeHints(c compiled.SparseR1C, solution *solution) (in if (c.O.CoeffID() != 0) && !solution.solved[oID] { // check if it's a hint - if hID, ok := cs.mHints[oID]; ok { + if hID, ok := cs.MHints[oID]; ok { if err := solution.solveHint(cs.Hints[hID], oID); err != nil { return -1, err } @@ -269,21 +266,12 @@ func (cs *SparseR1CS) ToHTML(w io.Writer) error { return err } - type data struct { - *SparseR1CS - MHints map[int]int - } - d := data{ - cs, - cs.mHints, - } - - return t.Execute(w, &d) + return t.Execute(w, cs) } -func toHTMLTerm(t compiled.Term, coeffs []fr.Element, mHints map[int]int) string { +func toHTMLTerm(t compiled.Term, coeffs []fr.Element, MHints map[int]int) string { var sbb strings.Builder - termToHTML(t, &sbb, coeffs, mHints, true) + termToHTML(t, &sbb, coeffs, MHints, true) return sbb.String() } @@ -299,15 +287,6 @@ func toHTMLCoeff(cID int, coeffs []fr.Element) string { return sbb.String() } -func (cs *SparseR1CS) initHints() { - // we may do that sooner to save time in the solver, but we want the serialized data structures to be - // deterministic, hence avoid maps in there. - cs.mHints = make(map[int]int, len(cs.Hints)) - for i := 0; i < len(cs.Hints); i++ { - cs.mHints[cs.Hints[i].WireID] = i - } -} - // FrSize return fr.Limbs * 8, size in byte of a fr element func (cs *SparseR1CS) FrSize() int { return fr.Limbs * 8 @@ -345,6 +324,5 @@ func (cs *SparseR1CS) ReadFrom(r io.Reader) (int64, error) { } decoder := dm.NewDecoder(r) err = decoder.Decode(cs) - cs.initHints() return int64(decoder.NumBytesRead()), err } diff --git a/internal/backend/bls24-315/cs/r1cs.go b/internal/backend/bls24-315/cs/r1cs.go index 3484459626..b732cdd332 100644 --- a/internal/backend/bls24-315/cs/r1cs.go +++ b/internal/backend/bls24-315/cs/r1cs.go @@ -41,7 +41,6 @@ type R1CS struct { compiled.R1CS Coefficients []fr.Element // R1C coefficients indexes point here loggerOut io.Writer - mHints map[int]int // correspondance between hint wire ID and hint data struct } // NewR1CS returns a new R1CS and sets cs.Coefficient (fr.Element) from provided big.Int values @@ -55,8 +54,6 @@ func NewR1CS(cs compiled.R1CS, coefficients []big.Int) *R1CS { r.Coefficients[i].SetBigInt(&coefficients[i]) } - r.initHints() - return &r } @@ -108,6 +105,7 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi for i := 0; i < len(cs.Constraints); i++ { // solve the constraint, this will compute the missing wire of the gate if err := cs.solveConstraint(cs.Constraints[i], &solution); err != nil { + // TODO should return debug info, if any. return solution.values, err } @@ -143,15 +141,6 @@ func (cs *R1CS) IsSolved(witness []fr.Element, hintFunctions []hint.Function) er return err } -func (cs *R1CS) initHints() { - // we may do that sooner to save time in the solver, but we want the serialized data structures to be - // deterministic, hence avoid maps in there. - cs.mHints = make(map[int]int, len(cs.Hints)) - for i := 0; i < len(cs.Hints); i++ { - cs.mHints[cs.Hints[i].WireID] = i - } -} - // mulByCoeff sets res = res * t.Coeff func (cs *R1CS) mulByCoeff(res *fr.Element, t compiled.Term) { cID := t.CoeffID() @@ -218,8 +207,7 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) error { } // first we check if this is a hint wire - if hID, ok := cs.mHints[vID]; ok { - // TODO handle error + if hID, ok := cs.MHints[vID]; ok { return solution.solveHint(cs.Hints[hID], vID) } @@ -299,15 +287,7 @@ func (cs *R1CS) ToHTML(w io.Writer) error { return err } - type data struct { - *R1CS - MHints map[int]int - } - d := data{ - cs, - cs.mHints, - } - return t.Execute(w, &d) + return t.Execute(w, cs) } func add(a, b int) int { @@ -318,10 +298,10 @@ func sub(a, b int) int { return a - b } -func toHTML(l compiled.LinearExpression, coeffs []fr.Element, mHints map[int]int) string { +func toHTML(l compiled.LinearExpression, coeffs []fr.Element, MHints map[int]int) string { var sbb strings.Builder for i := 0; i < len(l); i++ { - termToHTML(l[i], &sbb, coeffs, mHints, false) + termToHTML(l[i], &sbb, coeffs, MHints, false) if i+1 < len(l) { sbb.WriteString(" + ") } @@ -329,7 +309,7 @@ func toHTML(l compiled.LinearExpression, coeffs []fr.Element, mHints map[int]int return sbb.String() } -func termToHTML(t compiled.Term, sbb *strings.Builder, coeffs []fr.Element, mHints map[int]int, offset bool) { +func termToHTML(t compiled.Term, sbb *strings.Builder, coeffs []fr.Element, MHints map[int]int, offset bool) { tID := t.CoeffID() if tID == compiled.CoeffIdOne { // do nothing, just print the variable @@ -350,7 +330,7 @@ func termToHTML(t compiled.Term, sbb *strings.Builder, coeffs []fr.Element, mHin switch t.VariableVisibility() { case compiled.Internal: class = "internal" - if _, ok := mHints[vID]; ok { + if _, ok := MHints[vID]; ok { class = "hint" } case compiled.Public: @@ -418,8 +398,5 @@ func (cs *R1CS) ReadFrom(r io.Reader) (int64, error) { return int64(decoder.NumBytesRead()), err } - // init the hint map - cs.initHints() - return int64(decoder.NumBytesRead()), nil } diff --git a/internal/backend/bls24-315/cs/r1cs_sparse.go b/internal/backend/bls24-315/cs/r1cs_sparse.go index 11f4438c50..e262132e71 100644 --- a/internal/backend/bls24-315/cs/r1cs_sparse.go +++ b/internal/backend/bls24-315/cs/r1cs_sparse.go @@ -39,7 +39,6 @@ type SparseR1CS struct { // Coefficients in the constraints Coefficients []fr.Element // list of unique coefficients. - mHints map[int]int // correspondance between hint wire ID and hint data struct } // NewSparseR1CS returns a new SparseR1CS and sets r1cs.Coefficient (fr.Element) from provided big.Int values @@ -52,8 +51,6 @@ func NewSparseR1CS(ccs compiled.SparseR1CS, coefficients []big.Int) *SparseR1CS cs.Coefficients[i].SetBigInt(&coefficients[i]) } - cs.initHints() - return &cs } @@ -135,7 +132,7 @@ func (cs *SparseR1CS) computeHints(c compiled.SparseR1C, solution *solution) (in if (c.L.CoeffID() != 0 || c.M[0].CoeffID() != 0) && !solution.solved[lID] { // check if it's a hint - if hID, ok := cs.mHints[lID]; ok { + if hID, ok := cs.MHints[lID]; ok { if err := solution.solveHint(cs.Hints[hID], lID); err != nil { return -1, err } @@ -147,7 +144,7 @@ func (cs *SparseR1CS) computeHints(c compiled.SparseR1C, solution *solution) (in if (c.R.CoeffID() != 0 || c.M[1].CoeffID() != 0) && !solution.solved[rID] { // check if it's a hint - if hID, ok := cs.mHints[rID]; ok { + if hID, ok := cs.MHints[rID]; ok { if err := solution.solveHint(cs.Hints[hID], rID); err != nil { return -1, err } @@ -158,7 +155,7 @@ func (cs *SparseR1CS) computeHints(c compiled.SparseR1C, solution *solution) (in if (c.O.CoeffID() != 0) && !solution.solved[oID] { // check if it's a hint - if hID, ok := cs.mHints[oID]; ok { + if hID, ok := cs.MHints[oID]; ok { if err := solution.solveHint(cs.Hints[hID], oID); err != nil { return -1, err } @@ -269,21 +266,12 @@ func (cs *SparseR1CS) ToHTML(w io.Writer) error { return err } - type data struct { - *SparseR1CS - MHints map[int]int - } - d := data{ - cs, - cs.mHints, - } - - return t.Execute(w, &d) + return t.Execute(w, cs) } -func toHTMLTerm(t compiled.Term, coeffs []fr.Element, mHints map[int]int) string { +func toHTMLTerm(t compiled.Term, coeffs []fr.Element, MHints map[int]int) string { var sbb strings.Builder - termToHTML(t, &sbb, coeffs, mHints, true) + termToHTML(t, &sbb, coeffs, MHints, true) return sbb.String() } @@ -299,15 +287,6 @@ func toHTMLCoeff(cID int, coeffs []fr.Element) string { return sbb.String() } -func (cs *SparseR1CS) initHints() { - // we may do that sooner to save time in the solver, but we want the serialized data structures to be - // deterministic, hence avoid maps in there. - cs.mHints = make(map[int]int, len(cs.Hints)) - for i := 0; i < len(cs.Hints); i++ { - cs.mHints[cs.Hints[i].WireID] = i - } -} - // FrSize return fr.Limbs * 8, size in byte of a fr element func (cs *SparseR1CS) FrSize() int { return fr.Limbs * 8 @@ -345,6 +324,5 @@ func (cs *SparseR1CS) ReadFrom(r io.Reader) (int64, error) { } decoder := dm.NewDecoder(r) err = decoder.Decode(cs) - cs.initHints() return int64(decoder.NumBytesRead()), err } diff --git a/internal/backend/bn254/cs/r1cs.go b/internal/backend/bn254/cs/r1cs.go index 8d96387a3d..71f422fff3 100644 --- a/internal/backend/bn254/cs/r1cs.go +++ b/internal/backend/bn254/cs/r1cs.go @@ -41,7 +41,6 @@ type R1CS struct { compiled.R1CS Coefficients []fr.Element // R1C coefficients indexes point here loggerOut io.Writer - mHints map[int]int // correspondance between hint wire ID and hint data struct } // NewR1CS returns a new R1CS and sets cs.Coefficient (fr.Element) from provided big.Int values @@ -55,8 +54,6 @@ func NewR1CS(cs compiled.R1CS, coefficients []big.Int) *R1CS { r.Coefficients[i].SetBigInt(&coefficients[i]) } - r.initHints() - return &r } @@ -108,6 +105,7 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi for i := 0; i < len(cs.Constraints); i++ { // solve the constraint, this will compute the missing wire of the gate if err := cs.solveConstraint(cs.Constraints[i], &solution); err != nil { + // TODO should return debug info, if any. return solution.values, err } @@ -143,15 +141,6 @@ func (cs *R1CS) IsSolved(witness []fr.Element, hintFunctions []hint.Function) er return err } -func (cs *R1CS) initHints() { - // we may do that sooner to save time in the solver, but we want the serialized data structures to be - // deterministic, hence avoid maps in there. - cs.mHints = make(map[int]int, len(cs.Hints)) - for i := 0; i < len(cs.Hints); i++ { - cs.mHints[cs.Hints[i].WireID] = i - } -} - // mulByCoeff sets res = res * t.Coeff func (cs *R1CS) mulByCoeff(res *fr.Element, t compiled.Term) { cID := t.CoeffID() @@ -218,8 +207,7 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) error { } // first we check if this is a hint wire - if hID, ok := cs.mHints[vID]; ok { - // TODO handle error + if hID, ok := cs.MHints[vID]; ok { return solution.solveHint(cs.Hints[hID], vID) } @@ -299,15 +287,7 @@ func (cs *R1CS) ToHTML(w io.Writer) error { return err } - type data struct { - *R1CS - MHints map[int]int - } - d := data{ - cs, - cs.mHints, - } - return t.Execute(w, &d) + return t.Execute(w, cs) } func add(a, b int) int { @@ -318,10 +298,10 @@ func sub(a, b int) int { return a - b } -func toHTML(l compiled.LinearExpression, coeffs []fr.Element, mHints map[int]int) string { +func toHTML(l compiled.LinearExpression, coeffs []fr.Element, MHints map[int]int) string { var sbb strings.Builder for i := 0; i < len(l); i++ { - termToHTML(l[i], &sbb, coeffs, mHints, false) + termToHTML(l[i], &sbb, coeffs, MHints, false) if i+1 < len(l) { sbb.WriteString(" + ") } @@ -329,7 +309,7 @@ func toHTML(l compiled.LinearExpression, coeffs []fr.Element, mHints map[int]int return sbb.String() } -func termToHTML(t compiled.Term, sbb *strings.Builder, coeffs []fr.Element, mHints map[int]int, offset bool) { +func termToHTML(t compiled.Term, sbb *strings.Builder, coeffs []fr.Element, MHints map[int]int, offset bool) { tID := t.CoeffID() if tID == compiled.CoeffIdOne { // do nothing, just print the variable @@ -350,7 +330,7 @@ func termToHTML(t compiled.Term, sbb *strings.Builder, coeffs []fr.Element, mHin switch t.VariableVisibility() { case compiled.Internal: class = "internal" - if _, ok := mHints[vID]; ok { + if _, ok := MHints[vID]; ok { class = "hint" } case compiled.Public: @@ -418,8 +398,5 @@ func (cs *R1CS) ReadFrom(r io.Reader) (int64, error) { return int64(decoder.NumBytesRead()), err } - // init the hint map - cs.initHints() - return int64(decoder.NumBytesRead()), nil } diff --git a/internal/backend/bn254/cs/r1cs_sparse.go b/internal/backend/bn254/cs/r1cs_sparse.go index 0a7ff0eaec..af6a6aea5e 100644 --- a/internal/backend/bn254/cs/r1cs_sparse.go +++ b/internal/backend/bn254/cs/r1cs_sparse.go @@ -39,7 +39,6 @@ type SparseR1CS struct { // Coefficients in the constraints Coefficients []fr.Element // list of unique coefficients. - mHints map[int]int // correspondance between hint wire ID and hint data struct } // NewSparseR1CS returns a new SparseR1CS and sets r1cs.Coefficient (fr.Element) from provided big.Int values @@ -52,8 +51,6 @@ func NewSparseR1CS(ccs compiled.SparseR1CS, coefficients []big.Int) *SparseR1CS cs.Coefficients[i].SetBigInt(&coefficients[i]) } - cs.initHints() - return &cs } @@ -135,7 +132,7 @@ func (cs *SparseR1CS) computeHints(c compiled.SparseR1C, solution *solution) (in if (c.L.CoeffID() != 0 || c.M[0].CoeffID() != 0) && !solution.solved[lID] { // check if it's a hint - if hID, ok := cs.mHints[lID]; ok { + if hID, ok := cs.MHints[lID]; ok { if err := solution.solveHint(cs.Hints[hID], lID); err != nil { return -1, err } @@ -147,7 +144,7 @@ func (cs *SparseR1CS) computeHints(c compiled.SparseR1C, solution *solution) (in if (c.R.CoeffID() != 0 || c.M[1].CoeffID() != 0) && !solution.solved[rID] { // check if it's a hint - if hID, ok := cs.mHints[rID]; ok { + if hID, ok := cs.MHints[rID]; ok { if err := solution.solveHint(cs.Hints[hID], rID); err != nil { return -1, err } @@ -158,7 +155,7 @@ func (cs *SparseR1CS) computeHints(c compiled.SparseR1C, solution *solution) (in if (c.O.CoeffID() != 0) && !solution.solved[oID] { // check if it's a hint - if hID, ok := cs.mHints[oID]; ok { + if hID, ok := cs.MHints[oID]; ok { if err := solution.solveHint(cs.Hints[hID], oID); err != nil { return -1, err } @@ -269,21 +266,12 @@ func (cs *SparseR1CS) ToHTML(w io.Writer) error { return err } - type data struct { - *SparseR1CS - MHints map[int]int - } - d := data{ - cs, - cs.mHints, - } - - return t.Execute(w, &d) + return t.Execute(w, cs) } -func toHTMLTerm(t compiled.Term, coeffs []fr.Element, mHints map[int]int) string { +func toHTMLTerm(t compiled.Term, coeffs []fr.Element, MHints map[int]int) string { var sbb strings.Builder - termToHTML(t, &sbb, coeffs, mHints, true) + termToHTML(t, &sbb, coeffs, MHints, true) return sbb.String() } @@ -299,15 +287,6 @@ func toHTMLCoeff(cID int, coeffs []fr.Element) string { return sbb.String() } -func (cs *SparseR1CS) initHints() { - // we may do that sooner to save time in the solver, but we want the serialized data structures to be - // deterministic, hence avoid maps in there. - cs.mHints = make(map[int]int, len(cs.Hints)) - for i := 0; i < len(cs.Hints); i++ { - cs.mHints[cs.Hints[i].WireID] = i - } -} - // FrSize return fr.Limbs * 8, size in byte of a fr element func (cs *SparseR1CS) FrSize() int { return fr.Limbs * 8 @@ -345,6 +324,5 @@ func (cs *SparseR1CS) ReadFrom(r io.Reader) (int64, error) { } decoder := dm.NewDecoder(r) err = decoder.Decode(cs) - cs.initHints() return int64(decoder.NumBytesRead()), err } diff --git a/internal/backend/bw6-761/cs/r1cs.go b/internal/backend/bw6-761/cs/r1cs.go index 7e78134464..229eeb2931 100644 --- a/internal/backend/bw6-761/cs/r1cs.go +++ b/internal/backend/bw6-761/cs/r1cs.go @@ -41,7 +41,6 @@ type R1CS struct { compiled.R1CS Coefficients []fr.Element // R1C coefficients indexes point here loggerOut io.Writer - mHints map[int]int // correspondance between hint wire ID and hint data struct } // NewR1CS returns a new R1CS and sets cs.Coefficient (fr.Element) from provided big.Int values @@ -55,8 +54,6 @@ func NewR1CS(cs compiled.R1CS, coefficients []big.Int) *R1CS { r.Coefficients[i].SetBigInt(&coefficients[i]) } - r.initHints() - return &r } @@ -108,6 +105,7 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi for i := 0; i < len(cs.Constraints); i++ { // solve the constraint, this will compute the missing wire of the gate if err := cs.solveConstraint(cs.Constraints[i], &solution); err != nil { + // TODO should return debug info, if any. return solution.values, err } @@ -143,15 +141,6 @@ func (cs *R1CS) IsSolved(witness []fr.Element, hintFunctions []hint.Function) er return err } -func (cs *R1CS) initHints() { - // we may do that sooner to save time in the solver, but we want the serialized data structures to be - // deterministic, hence avoid maps in there. - cs.mHints = make(map[int]int, len(cs.Hints)) - for i := 0; i < len(cs.Hints); i++ { - cs.mHints[cs.Hints[i].WireID] = i - } -} - // mulByCoeff sets res = res * t.Coeff func (cs *R1CS) mulByCoeff(res *fr.Element, t compiled.Term) { cID := t.CoeffID() @@ -218,8 +207,7 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) error { } // first we check if this is a hint wire - if hID, ok := cs.mHints[vID]; ok { - // TODO handle error + if hID, ok := cs.MHints[vID]; ok { return solution.solveHint(cs.Hints[hID], vID) } @@ -299,15 +287,7 @@ func (cs *R1CS) ToHTML(w io.Writer) error { return err } - type data struct { - *R1CS - MHints map[int]int - } - d := data{ - cs, - cs.mHints, - } - return t.Execute(w, &d) + return t.Execute(w, cs) } func add(a, b int) int { @@ -318,10 +298,10 @@ func sub(a, b int) int { return a - b } -func toHTML(l compiled.LinearExpression, coeffs []fr.Element, mHints map[int]int) string { +func toHTML(l compiled.LinearExpression, coeffs []fr.Element, MHints map[int]int) string { var sbb strings.Builder for i := 0; i < len(l); i++ { - termToHTML(l[i], &sbb, coeffs, mHints, false) + termToHTML(l[i], &sbb, coeffs, MHints, false) if i+1 < len(l) { sbb.WriteString(" + ") } @@ -329,7 +309,7 @@ func toHTML(l compiled.LinearExpression, coeffs []fr.Element, mHints map[int]int return sbb.String() } -func termToHTML(t compiled.Term, sbb *strings.Builder, coeffs []fr.Element, mHints map[int]int, offset bool) { +func termToHTML(t compiled.Term, sbb *strings.Builder, coeffs []fr.Element, MHints map[int]int, offset bool) { tID := t.CoeffID() if tID == compiled.CoeffIdOne { // do nothing, just print the variable @@ -350,7 +330,7 @@ func termToHTML(t compiled.Term, sbb *strings.Builder, coeffs []fr.Element, mHin switch t.VariableVisibility() { case compiled.Internal: class = "internal" - if _, ok := mHints[vID]; ok { + if _, ok := MHints[vID]; ok { class = "hint" } case compiled.Public: @@ -418,8 +398,5 @@ func (cs *R1CS) ReadFrom(r io.Reader) (int64, error) { return int64(decoder.NumBytesRead()), err } - // init the hint map - cs.initHints() - return int64(decoder.NumBytesRead()), nil } diff --git a/internal/backend/bw6-761/cs/r1cs_sparse.go b/internal/backend/bw6-761/cs/r1cs_sparse.go index 71a0dd36d8..4514777d3b 100644 --- a/internal/backend/bw6-761/cs/r1cs_sparse.go +++ b/internal/backend/bw6-761/cs/r1cs_sparse.go @@ -39,7 +39,6 @@ type SparseR1CS struct { // Coefficients in the constraints Coefficients []fr.Element // list of unique coefficients. - mHints map[int]int // correspondance between hint wire ID and hint data struct } // NewSparseR1CS returns a new SparseR1CS and sets r1cs.Coefficient (fr.Element) from provided big.Int values @@ -52,8 +51,6 @@ func NewSparseR1CS(ccs compiled.SparseR1CS, coefficients []big.Int) *SparseR1CS cs.Coefficients[i].SetBigInt(&coefficients[i]) } - cs.initHints() - return &cs } @@ -135,7 +132,7 @@ func (cs *SparseR1CS) computeHints(c compiled.SparseR1C, solution *solution) (in if (c.L.CoeffID() != 0 || c.M[0].CoeffID() != 0) && !solution.solved[lID] { // check if it's a hint - if hID, ok := cs.mHints[lID]; ok { + if hID, ok := cs.MHints[lID]; ok { if err := solution.solveHint(cs.Hints[hID], lID); err != nil { return -1, err } @@ -147,7 +144,7 @@ func (cs *SparseR1CS) computeHints(c compiled.SparseR1C, solution *solution) (in if (c.R.CoeffID() != 0 || c.M[1].CoeffID() != 0) && !solution.solved[rID] { // check if it's a hint - if hID, ok := cs.mHints[rID]; ok { + if hID, ok := cs.MHints[rID]; ok { if err := solution.solveHint(cs.Hints[hID], rID); err != nil { return -1, err } @@ -158,7 +155,7 @@ func (cs *SparseR1CS) computeHints(c compiled.SparseR1C, solution *solution) (in if (c.O.CoeffID() != 0) && !solution.solved[oID] { // check if it's a hint - if hID, ok := cs.mHints[oID]; ok { + if hID, ok := cs.MHints[oID]; ok { if err := solution.solveHint(cs.Hints[hID], oID); err != nil { return -1, err } @@ -269,21 +266,12 @@ func (cs *SparseR1CS) ToHTML(w io.Writer) error { return err } - type data struct { - *SparseR1CS - MHints map[int]int - } - d := data{ - cs, - cs.mHints, - } - - return t.Execute(w, &d) + return t.Execute(w, cs) } -func toHTMLTerm(t compiled.Term, coeffs []fr.Element, mHints map[int]int) string { +func toHTMLTerm(t compiled.Term, coeffs []fr.Element, MHints map[int]int) string { var sbb strings.Builder - termToHTML(t, &sbb, coeffs, mHints, true) + termToHTML(t, &sbb, coeffs, MHints, true) return sbb.String() } @@ -299,15 +287,6 @@ func toHTMLCoeff(cID int, coeffs []fr.Element) string { return sbb.String() } -func (cs *SparseR1CS) initHints() { - // we may do that sooner to save time in the solver, but we want the serialized data structures to be - // deterministic, hence avoid maps in there. - cs.mHints = make(map[int]int, len(cs.Hints)) - for i := 0; i < len(cs.Hints); i++ { - cs.mHints[cs.Hints[i].WireID] = i - } -} - // FrSize return fr.Limbs * 8, size in byte of a fr element func (cs *SparseR1CS) FrSize() int { return fr.Limbs * 8 @@ -345,6 +324,5 @@ func (cs *SparseR1CS) ReadFrom(r io.Reader) (int64, error) { } decoder := dm.NewDecoder(r) err = decoder.Decode(cs) - cs.initHints() return int64(decoder.NumBytesRead()), err } diff --git a/internal/backend/compiled/r1cs.go b/internal/backend/compiled/r1cs.go index a9aa52c5d0..30e0a7d755 100644 --- a/internal/backend/compiled/r1cs.go +++ b/internal/backend/compiled/r1cs.go @@ -34,7 +34,9 @@ type R1CS struct { Constraints []R1C // Hints - Hints []Hint + Hints []Hint + + MHints map[int]int // maps wire id to hint id MDebug map[int]int // maps constraint id to debugInfo id } diff --git a/internal/backend/compiled/r1cs_sparse.go b/internal/backend/compiled/r1cs_sparse.go index f58b77bbc9..840adb6ded 100644 --- a/internal/backend/compiled/r1cs_sparse.go +++ b/internal/backend/compiled/r1cs_sparse.go @@ -38,7 +38,9 @@ type SparseR1CS struct { DebugInfo []LogEntry // Hints - Hints []Hint + Hints []Hint + + MHints map[int]int // maps wire id to hint id MDebug map[int]int // maps constraint id to debugInfo id } diff --git a/internal/generator/backend/template/representations/r1cs.go.tmpl b/internal/generator/backend/template/representations/r1cs.go.tmpl index 33f66dfdcc..e7782436cc 100644 --- a/internal/generator/backend/template/representations/r1cs.go.tmpl +++ b/internal/generator/backend/template/representations/r1cs.go.tmpl @@ -24,7 +24,6 @@ type R1CS struct { compiled.R1CS Coefficients []fr.Element // R1C coefficients indexes point here loggerOut io.Writer - mHints map[int]int // correspondance between hint wire ID and hint data struct } // NewR1CS returns a new R1CS and sets cs.Coefficient (fr.Element) from provided big.Int values @@ -38,8 +37,6 @@ func NewR1CS(cs compiled.R1CS, coefficients []big.Int) *R1CS { r.Coefficients[i].SetBigInt(&coefficients[i]) } - r.initHints() - return &r } @@ -98,6 +95,7 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi for i := 0; i < len(cs.Constraints); i++ { // solve the constraint, this will compute the missing wire of the gate if err := cs.solveConstraint(cs.Constraints[i], &solution); err != nil { + // TODO should return debug info, if any. return solution.values, err } @@ -136,17 +134,6 @@ func (cs *R1CS) IsSolved(witness []fr.Element, hintFunctions []hint.Function) er -func (cs *R1CS) initHints() { - // we may do that sooner to save time in the solver, but we want the serialized data structures to be - // deterministic, hence avoid maps in there. - cs.mHints = make(map[int]int, len(cs.Hints)) - for i:=0; i Date: Thu, 16 Sep 2021 23:12:51 -0500 Subject: [PATCH 03/15] fix: remove debug stack trace from frontend error --- frontend/frontend.go | 54 +++----------------------------------------- 1 file changed, 3 insertions(+), 51 deletions(-) diff --git a/frontend/frontend.go b/frontend/frontend.go index fa63e0d9fc..678af9f067 100644 --- a/frontend/frontend.go +++ b/frontend/frontend.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" "reflect" - "runtime/debug" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/backend" @@ -42,11 +41,6 @@ func Compile(curveID ecc.ID, zkpID backend.ID, circuit Circuit, initialCapacity return nil, err } - // sanity checks - if err := cs.sanityCheck(); err != nil { - return nil, err - } - switch zkpID { case backend.GROTH16: ccs, err = cs.toR1CS(curveID) @@ -62,50 +56,6 @@ func Compile(curveID ecc.ID, zkpID backend.ID, circuit Circuit, initialCapacity return } -// sanityCheck ensures: -// * all constraints must have at most one unsolved wire, excluding hint wires -func (cs *ConstraintSystem) sanityCheck() error { - - solved := make([]bool, len(cs.internal.variables)) - for i := 0; i < len(cs.hints); i++ { - solved[cs.hints[i].WireID] = true - } - - countUnsolved := func(r1c compiled.R1C) int { - c := 0 - for i := 0; i < len(r1c.L); i++ { - _, vID, visibility := r1c.L[i].Unpack() - if visibility == compiled.Internal && !solved[vID] { - c++ - solved[vID] = true - } - } - for i := 0; i < len(r1c.R); i++ { - _, vID, visibility := r1c.R[i].Unpack() - if visibility == compiled.Internal && !solved[vID] { - c++ - solved[vID] = true - } - } - for i := 0; i < len(r1c.O); i++ { - _, vID, visibility := r1c.O[i].Unpack() - if visibility == compiled.Internal && !solved[vID] { - c++ - solved[vID] = true - } - } - return c - } - - for _, r1c := range cs.constraints { - if countUnsolved(r1c) > 1 { - return errors.New("constraint system has invalid constraints with multiple unsolved wire") - } - } - - return nil -} - // buildCS builds the constraint system. It bootstraps the inputs // allocations by parsing the circuit's underlying structure, then // it builds the constraint system using the Define method. @@ -113,7 +63,9 @@ func buildCS(curveID ecc.ID, circuit Circuit, initialCapacity ...int) (cs Constr // recover from panics to print user-friendlier messages defer func() { if r := recover(); r != nil { - err = fmt.Errorf("%v\n%s", r, string(debug.Stack())) + err = fmt.Errorf("%v", r) + // TODO @gbotrel with debug buiild tag + // fmt.Println(string(debug.Stack())) } }() // instantiate our constraint system From c97ccb1708f74ca4d2b1f65afe261194d86641fd Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Thu, 16 Sep 2021 23:17:53 -0500 Subject: [PATCH 04/15] build: comment fuzz test part that depends on assertions --- backend/groth16/fuzz.go | 22 +++++++++++----------- frontend/fuzz.go | 6 +++--- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/backend/groth16/fuzz.go b/backend/groth16/fuzz.go index 998f58bc69..093131059b 100644 --- a/backend/groth16/fuzz.go +++ b/backend/groth16/fuzz.go @@ -4,8 +4,6 @@ package groth16 import ( - "strings" - "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/frontend" backend_bls12381 "github.com/consensys/gnark/internal/backend/bls12-381/cs" @@ -17,7 +15,7 @@ import ( func Fuzz(data []byte) int { curves := []ecc.ID{ecc.BN254, ecc.BLS12_381} for _, curveID := range curves { - ccs, nbAssertions := frontend.CsFuzzed(data, curveID) + ccs := frontend.CsFuzzed(data, curveID) _, s, p := ccs.GetNbVariables() wSize := s + p - 1 ccs.SetLoggerOutput(nil) @@ -25,17 +23,19 @@ func Fuzz(data []byte) int { case *backend_bls12381.R1CS: w := make(witness_bls12381.Witness, wSize) // make w random - err := _r1cs.IsSolved(w, nil) - if nbAssertions == 0 && err != nil && !strings.Contains(err.Error(), "couldn't solve computational constraint") { - panic("no assertions, yet solving resulted in an error.") - } + _ = _r1cs.IsSolved(w, nil) + // TODO FIXME @gbotrel + // if nbAssertions == 0 && err != nil && !strings.Contains(err.Error(), "couldn't solve computational constraint") { + // panic("no assertions, yet solving resulted in an error.") + // } case *backend_bn254.R1CS: w := make(witness_bn254.Witness, wSize) // make w random - err := _r1cs.IsSolved(w, nil) - if nbAssertions == 0 && err != nil && !strings.Contains(err.Error(), "couldn't solve computational constraint") { - panic("no assertions, yet solving resulted in an error.") - } + _ = _r1cs.IsSolved(w, nil) + // TODO FIXME @gbotrel + // if nbAssertions == 0 && err != nil && !strings.Contains(err.Error(), "couldn't solve computational constraint") { + // panic("no assertions, yet solving resulted in an error.") + // } default: panic("unrecognized R1CS curve type") } diff --git a/frontend/fuzz.go b/frontend/fuzz.go index ab498886e2..a312bf5d44 100644 --- a/frontend/fuzz.go +++ b/frontend/fuzz.go @@ -22,13 +22,13 @@ func Fuzz(data []byte) int { curves := []ecc.ID{ecc.BN254, ecc.BLS12_381} for _, curveID := range curves { - _, _ = CsFuzzed(data, curveID) + _ = CsFuzzed(data, curveID) } return 1 } -func CsFuzzed(data []byte, curveID ecc.ID) (ccs CompiledConstraintSystem, nbAssertions int) { +func CsFuzzed(data []byte, curveID ecc.ID) (ccs CompiledConstraintSystem) { cs := newConstraintSystem() reader := bytes.NewReader(data) @@ -126,7 +126,7 @@ compile: if err != nil { panic(fmt.Sprintf("compiling (curve %s) failed: %v", curveID.String(), err)) } - return ccs, len(cs.assertions) + return ccs } func (cs *ConstraintSystem) shuffleVariables(seed int64, withConstant bool) []interface{} { From 9a49587867aa21b28cdba4371dace2a51ab27778 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Fri, 17 Sep 2021 08:43:42 -0500 Subject: [PATCH 05/15] refactor: remove NbConstraints from R1CS --- frontend/cs_to_r1cs.go | 1 - internal/backend/bls12-377/cs/r1cs.go | 10 +++++----- internal/backend/bls12-377/cs/r1cs_sparse.go | 7 ++++--- internal/backend/bls12-377/groth16/prove.go | 6 +++--- internal/backend/bls12-377/groth16/setup.go | 4 ++-- internal/backend/bls12-381/cs/r1cs.go | 10 +++++----- internal/backend/bls12-381/cs/r1cs_sparse.go | 7 ++++--- internal/backend/bls12-381/groth16/prove.go | 6 +++--- internal/backend/bls12-381/groth16/setup.go | 4 ++-- internal/backend/bls24-315/cs/r1cs.go | 10 +++++----- internal/backend/bls24-315/cs/r1cs_sparse.go | 7 ++++--- internal/backend/bls24-315/groth16/prove.go | 6 +++--- internal/backend/bls24-315/groth16/setup.go | 4 ++-- internal/backend/bn254/cs/r1cs.go | 10 +++++----- internal/backend/bn254/cs/r1cs_sparse.go | 7 ++++--- internal/backend/bn254/groth16/prove.go | 6 +++--- internal/backend/bn254/groth16/setup.go | 4 ++-- internal/backend/bw6-761/cs/r1cs.go | 10 +++++----- internal/backend/bw6-761/cs/r1cs_sparse.go | 7 ++++--- internal/backend/bw6-761/groth16/prove.go | 6 +++--- internal/backend/bw6-761/groth16/setup.go | 4 ++-- internal/backend/compiled/r1cs.go | 5 ++--- .../backend/template/representations/r1cs.go.tmpl | 10 +++++----- .../template/representations/r1cs.sparse.go.tmpl | 7 ++++--- .../template/zkpschemes/groth16/groth16.prove.go.tmpl | 6 +++--- .../template/zkpschemes/groth16/groth16.setup.go.tmpl | 4 ++-- 26 files changed, 86 insertions(+), 82 deletions(-) diff --git a/frontend/cs_to_r1cs.go b/frontend/cs_to_r1cs.go index 228476fad2..fd77c53910 100644 --- a/frontend/cs_to_r1cs.go +++ b/frontend/cs_to_r1cs.go @@ -21,7 +21,6 @@ func (cs *ConstraintSystem) toR1CS(curveID ecc.ID) (CompiledConstraintSystem, er NbInternalVariables: len(cs.internal.variables), NbPublicVariables: len(cs.public.variables), NbSecretVariables: len(cs.secret.variables), - NbConstraints: len(cs.constraints), Constraints: make([]compiled.R1C, len(cs.constraints)), Logs: make([]compiled.LogEntry, len(cs.logs)), DebugInfo: make([]compiled.LogEntry, len(cs.debugInfo)), diff --git a/internal/backend/bls12-377/cs/r1cs.go b/internal/backend/bls12-377/cs/r1cs.go index d7a2b29200..c2a7ad86c2 100644 --- a/internal/backend/bls12-377/cs/r1cs.go +++ b/internal/backend/bls12-377/cs/r1cs.go @@ -75,8 +75,8 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi } // compute the wires and the a, b, c polynomials - if len(a) != int(cs.NbConstraints) || len(b) != int(cs.NbConstraints) || len(c) != int(cs.NbConstraints) { - return solution.values, errors.New("invalid input size: len(a, b, c) == cs.NbConstraints") + if len(a) != len(cs.Constraints) || len(b) != len(cs.Constraints) || len(c) != len(cs.Constraints) { + return solution.values, errors.New("invalid input size: len(a, b, c) == len(Constraints)") } solution.solved[0] = true // ONE_WIRE @@ -134,9 +134,9 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi // IsSolved returns nil if given witness solves the R1CS and error otherwise // this method wraps cs.Solve() and allocates cs.Solve() inputs func (cs *R1CS) IsSolved(witness []fr.Element, hintFunctions []hint.Function) error { - a := make([]fr.Element, cs.NbConstraints) - b := make([]fr.Element, cs.NbConstraints) - c := make([]fr.Element, cs.NbConstraints) + a := make([]fr.Element, len(cs.Constraints)) + b := make([]fr.Element, len(cs.Constraints)) + c := make([]fr.Element, len(cs.Constraints)) _, err := cs.Solve(witness, a, b, c, hintFunctions) return err } diff --git a/internal/backend/bls12-377/cs/r1cs_sparse.go b/internal/backend/bls12-377/cs/r1cs_sparse.go index 56a12482f7..53e2633a05 100644 --- a/internal/backend/bls12-377/cs/r1cs_sparse.go +++ b/internal/backend/bls12-377/cs/r1cs_sparse.go @@ -37,8 +37,8 @@ import ( type SparseR1CS struct { compiled.SparseR1CS - // Coefficients in the constraints - Coefficients []fr.Element // list of unique coefficients. + Coefficients []fr.Element // coefficients in the constraints + loggerOut io.Writer } // NewSparseR1CS returns a new SparseR1CS and sets r1cs.Coefficient (fr.Element) from provided big.Int values @@ -46,6 +46,7 @@ func NewSparseR1CS(ccs compiled.SparseR1CS, coefficients []big.Int) *SparseR1CS cs := SparseR1CS{ SparseR1CS: ccs, Coefficients: make([]fr.Element, len(coefficients)), + loggerOut: os.Stdout, } for i := 0; i < len(coefficients); i++ { cs.Coefficients[i].SetBigInt(&coefficients[i]) @@ -58,7 +59,7 @@ func NewSparseR1CS(ccs compiled.SparseR1CS, coefficients []big.Int) *SparseR1CS // solution.values = [publicInputs | secretInputs | internalVariables ] // witness: contains the input variables // it returns the full slice of wires -func (cs *SparseR1CS) Solve(witness []fr.Element, hintFunctions []hint.Function) (values []fr.Element, err error) { +func (cs *SparseR1CS) Solve(witness []fr.Element, hintFunctions []hint.Function) ([]fr.Element, error) { // set the slices holding the solution.values and monitoring which variables have been solved nbVariables := cs.NbInternalVariables + cs.NbSecretVariables + cs.NbPublicVariables diff --git a/internal/backend/bls12-377/groth16/prove.go b/internal/backend/bls12-377/groth16/prove.go index bb12874e91..37b963a409 100644 --- a/internal/backend/bls12-377/groth16/prove.go +++ b/internal/backend/bls12-377/groth16/prove.go @@ -61,9 +61,9 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness bls12_377witness.Witness, hint } // solve the R1CS and compute the a, b, c vectors - a := make([]fr.Element, r1cs.NbConstraints, pk.Domain.Cardinality) - b := make([]fr.Element, r1cs.NbConstraints, pk.Domain.Cardinality) - c := make([]fr.Element, r1cs.NbConstraints, pk.Domain.Cardinality) + a := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) + b := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) + c := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) var wireValues []fr.Element var err error if wireValues, err = r1cs.Solve(witness, a, b, c, hintFunctions); err != nil { diff --git a/internal/backend/bls12-377/groth16/setup.go b/internal/backend/bls12-377/groth16/setup.go index 971ee42f8d..3e53f9b514 100644 --- a/internal/backend/bls12-377/groth16/setup.go +++ b/internal/backend/bls12-377/groth16/setup.go @@ -95,7 +95,7 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { nbPrivateWires := r1cs.NbSecretVariables + r1cs.NbInternalVariables // Setting group for fft - domain := fft.NewDomain(uint64(r1cs.NbConstraints), 1, true) + domain := fft.NewDomain(uint64(len(r1cs.Constraints)), 1, true) // samples toxic waste toxicWaste, err := sampleToxicWaste() @@ -412,7 +412,7 @@ func sampleToxicWaste() (toxicWaste, error) { func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { // get R1CS nb constraints, wires and public/private inputs nbWires := r1cs.NbInternalVariables + r1cs.NbPublicVariables + r1cs.NbSecretVariables - nbConstraints := r1cs.NbConstraints + nbConstraints := len(r1cs.Constraints) // Setting group for fft domain := fft.NewDomain(uint64(nbConstraints), 1, true) diff --git a/internal/backend/bls12-381/cs/r1cs.go b/internal/backend/bls12-381/cs/r1cs.go index fcb68d8edc..4d74554a33 100644 --- a/internal/backend/bls12-381/cs/r1cs.go +++ b/internal/backend/bls12-381/cs/r1cs.go @@ -75,8 +75,8 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi } // compute the wires and the a, b, c polynomials - if len(a) != int(cs.NbConstraints) || len(b) != int(cs.NbConstraints) || len(c) != int(cs.NbConstraints) { - return solution.values, errors.New("invalid input size: len(a, b, c) == cs.NbConstraints") + if len(a) != len(cs.Constraints) || len(b) != len(cs.Constraints) || len(c) != len(cs.Constraints) { + return solution.values, errors.New("invalid input size: len(a, b, c) == len(Constraints)") } solution.solved[0] = true // ONE_WIRE @@ -134,9 +134,9 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi // IsSolved returns nil if given witness solves the R1CS and error otherwise // this method wraps cs.Solve() and allocates cs.Solve() inputs func (cs *R1CS) IsSolved(witness []fr.Element, hintFunctions []hint.Function) error { - a := make([]fr.Element, cs.NbConstraints) - b := make([]fr.Element, cs.NbConstraints) - c := make([]fr.Element, cs.NbConstraints) + a := make([]fr.Element, len(cs.Constraints)) + b := make([]fr.Element, len(cs.Constraints)) + c := make([]fr.Element, len(cs.Constraints)) _, err := cs.Solve(witness, a, b, c, hintFunctions) return err } diff --git a/internal/backend/bls12-381/cs/r1cs_sparse.go b/internal/backend/bls12-381/cs/r1cs_sparse.go index 2367017887..dc56b51dda 100644 --- a/internal/backend/bls12-381/cs/r1cs_sparse.go +++ b/internal/backend/bls12-381/cs/r1cs_sparse.go @@ -37,8 +37,8 @@ import ( type SparseR1CS struct { compiled.SparseR1CS - // Coefficients in the constraints - Coefficients []fr.Element // list of unique coefficients. + Coefficients []fr.Element // coefficients in the constraints + loggerOut io.Writer } // NewSparseR1CS returns a new SparseR1CS and sets r1cs.Coefficient (fr.Element) from provided big.Int values @@ -46,6 +46,7 @@ func NewSparseR1CS(ccs compiled.SparseR1CS, coefficients []big.Int) *SparseR1CS cs := SparseR1CS{ SparseR1CS: ccs, Coefficients: make([]fr.Element, len(coefficients)), + loggerOut: os.Stdout, } for i := 0; i < len(coefficients); i++ { cs.Coefficients[i].SetBigInt(&coefficients[i]) @@ -58,7 +59,7 @@ func NewSparseR1CS(ccs compiled.SparseR1CS, coefficients []big.Int) *SparseR1CS // solution.values = [publicInputs | secretInputs | internalVariables ] // witness: contains the input variables // it returns the full slice of wires -func (cs *SparseR1CS) Solve(witness []fr.Element, hintFunctions []hint.Function) (values []fr.Element, err error) { +func (cs *SparseR1CS) Solve(witness []fr.Element, hintFunctions []hint.Function) ([]fr.Element, error) { // set the slices holding the solution.values and monitoring which variables have been solved nbVariables := cs.NbInternalVariables + cs.NbSecretVariables + cs.NbPublicVariables diff --git a/internal/backend/bls12-381/groth16/prove.go b/internal/backend/bls12-381/groth16/prove.go index c2ad6ec166..7e60e7d984 100644 --- a/internal/backend/bls12-381/groth16/prove.go +++ b/internal/backend/bls12-381/groth16/prove.go @@ -61,9 +61,9 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness bls12_381witness.Witness, hint } // solve the R1CS and compute the a, b, c vectors - a := make([]fr.Element, r1cs.NbConstraints, pk.Domain.Cardinality) - b := make([]fr.Element, r1cs.NbConstraints, pk.Domain.Cardinality) - c := make([]fr.Element, r1cs.NbConstraints, pk.Domain.Cardinality) + a := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) + b := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) + c := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) var wireValues []fr.Element var err error if wireValues, err = r1cs.Solve(witness, a, b, c, hintFunctions); err != nil { diff --git a/internal/backend/bls12-381/groth16/setup.go b/internal/backend/bls12-381/groth16/setup.go index 71228b06a7..4cc996279c 100644 --- a/internal/backend/bls12-381/groth16/setup.go +++ b/internal/backend/bls12-381/groth16/setup.go @@ -95,7 +95,7 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { nbPrivateWires := r1cs.NbSecretVariables + r1cs.NbInternalVariables // Setting group for fft - domain := fft.NewDomain(uint64(r1cs.NbConstraints), 1, true) + domain := fft.NewDomain(uint64(len(r1cs.Constraints)), 1, true) // samples toxic waste toxicWaste, err := sampleToxicWaste() @@ -412,7 +412,7 @@ func sampleToxicWaste() (toxicWaste, error) { func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { // get R1CS nb constraints, wires and public/private inputs nbWires := r1cs.NbInternalVariables + r1cs.NbPublicVariables + r1cs.NbSecretVariables - nbConstraints := r1cs.NbConstraints + nbConstraints := len(r1cs.Constraints) // Setting group for fft domain := fft.NewDomain(uint64(nbConstraints), 1, true) diff --git a/internal/backend/bls24-315/cs/r1cs.go b/internal/backend/bls24-315/cs/r1cs.go index b732cdd332..b4ed832598 100644 --- a/internal/backend/bls24-315/cs/r1cs.go +++ b/internal/backend/bls24-315/cs/r1cs.go @@ -75,8 +75,8 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi } // compute the wires and the a, b, c polynomials - if len(a) != int(cs.NbConstraints) || len(b) != int(cs.NbConstraints) || len(c) != int(cs.NbConstraints) { - return solution.values, errors.New("invalid input size: len(a, b, c) == cs.NbConstraints") + if len(a) != len(cs.Constraints) || len(b) != len(cs.Constraints) || len(c) != len(cs.Constraints) { + return solution.values, errors.New("invalid input size: len(a, b, c) == len(Constraints)") } solution.solved[0] = true // ONE_WIRE @@ -134,9 +134,9 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi // IsSolved returns nil if given witness solves the R1CS and error otherwise // this method wraps cs.Solve() and allocates cs.Solve() inputs func (cs *R1CS) IsSolved(witness []fr.Element, hintFunctions []hint.Function) error { - a := make([]fr.Element, cs.NbConstraints) - b := make([]fr.Element, cs.NbConstraints) - c := make([]fr.Element, cs.NbConstraints) + a := make([]fr.Element, len(cs.Constraints)) + b := make([]fr.Element, len(cs.Constraints)) + c := make([]fr.Element, len(cs.Constraints)) _, err := cs.Solve(witness, a, b, c, hintFunctions) return err } diff --git a/internal/backend/bls24-315/cs/r1cs_sparse.go b/internal/backend/bls24-315/cs/r1cs_sparse.go index e262132e71..b496e450ce 100644 --- a/internal/backend/bls24-315/cs/r1cs_sparse.go +++ b/internal/backend/bls24-315/cs/r1cs_sparse.go @@ -37,8 +37,8 @@ import ( type SparseR1CS struct { compiled.SparseR1CS - // Coefficients in the constraints - Coefficients []fr.Element // list of unique coefficients. + Coefficients []fr.Element // coefficients in the constraints + loggerOut io.Writer } // NewSparseR1CS returns a new SparseR1CS and sets r1cs.Coefficient (fr.Element) from provided big.Int values @@ -46,6 +46,7 @@ func NewSparseR1CS(ccs compiled.SparseR1CS, coefficients []big.Int) *SparseR1CS cs := SparseR1CS{ SparseR1CS: ccs, Coefficients: make([]fr.Element, len(coefficients)), + loggerOut: os.Stdout, } for i := 0; i < len(coefficients); i++ { cs.Coefficients[i].SetBigInt(&coefficients[i]) @@ -58,7 +59,7 @@ func NewSparseR1CS(ccs compiled.SparseR1CS, coefficients []big.Int) *SparseR1CS // solution.values = [publicInputs | secretInputs | internalVariables ] // witness: contains the input variables // it returns the full slice of wires -func (cs *SparseR1CS) Solve(witness []fr.Element, hintFunctions []hint.Function) (values []fr.Element, err error) { +func (cs *SparseR1CS) Solve(witness []fr.Element, hintFunctions []hint.Function) ([]fr.Element, error) { // set the slices holding the solution.values and monitoring which variables have been solved nbVariables := cs.NbInternalVariables + cs.NbSecretVariables + cs.NbPublicVariables diff --git a/internal/backend/bls24-315/groth16/prove.go b/internal/backend/bls24-315/groth16/prove.go index deaad3f871..23da0b0fdc 100644 --- a/internal/backend/bls24-315/groth16/prove.go +++ b/internal/backend/bls24-315/groth16/prove.go @@ -61,9 +61,9 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness bls24_315witness.Witness, hint } // solve the R1CS and compute the a, b, c vectors - a := make([]fr.Element, r1cs.NbConstraints, pk.Domain.Cardinality) - b := make([]fr.Element, r1cs.NbConstraints, pk.Domain.Cardinality) - c := make([]fr.Element, r1cs.NbConstraints, pk.Domain.Cardinality) + a := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) + b := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) + c := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) var wireValues []fr.Element var err error if wireValues, err = r1cs.Solve(witness, a, b, c, hintFunctions); err != nil { diff --git a/internal/backend/bls24-315/groth16/setup.go b/internal/backend/bls24-315/groth16/setup.go index 9c5a225cb6..ec5ebda7ec 100644 --- a/internal/backend/bls24-315/groth16/setup.go +++ b/internal/backend/bls24-315/groth16/setup.go @@ -95,7 +95,7 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { nbPrivateWires := r1cs.NbSecretVariables + r1cs.NbInternalVariables // Setting group for fft - domain := fft.NewDomain(uint64(r1cs.NbConstraints), 1, true) + domain := fft.NewDomain(uint64(len(r1cs.Constraints)), 1, true) // samples toxic waste toxicWaste, err := sampleToxicWaste() @@ -412,7 +412,7 @@ func sampleToxicWaste() (toxicWaste, error) { func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { // get R1CS nb constraints, wires and public/private inputs nbWires := r1cs.NbInternalVariables + r1cs.NbPublicVariables + r1cs.NbSecretVariables - nbConstraints := r1cs.NbConstraints + nbConstraints := len(r1cs.Constraints) // Setting group for fft domain := fft.NewDomain(uint64(nbConstraints), 1, true) diff --git a/internal/backend/bn254/cs/r1cs.go b/internal/backend/bn254/cs/r1cs.go index 71f422fff3..f40f5c1382 100644 --- a/internal/backend/bn254/cs/r1cs.go +++ b/internal/backend/bn254/cs/r1cs.go @@ -75,8 +75,8 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi } // compute the wires and the a, b, c polynomials - if len(a) != int(cs.NbConstraints) || len(b) != int(cs.NbConstraints) || len(c) != int(cs.NbConstraints) { - return solution.values, errors.New("invalid input size: len(a, b, c) == cs.NbConstraints") + if len(a) != len(cs.Constraints) || len(b) != len(cs.Constraints) || len(c) != len(cs.Constraints) { + return solution.values, errors.New("invalid input size: len(a, b, c) == len(Constraints)") } solution.solved[0] = true // ONE_WIRE @@ -134,9 +134,9 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi // IsSolved returns nil if given witness solves the R1CS and error otherwise // this method wraps cs.Solve() and allocates cs.Solve() inputs func (cs *R1CS) IsSolved(witness []fr.Element, hintFunctions []hint.Function) error { - a := make([]fr.Element, cs.NbConstraints) - b := make([]fr.Element, cs.NbConstraints) - c := make([]fr.Element, cs.NbConstraints) + a := make([]fr.Element, len(cs.Constraints)) + b := make([]fr.Element, len(cs.Constraints)) + c := make([]fr.Element, len(cs.Constraints)) _, err := cs.Solve(witness, a, b, c, hintFunctions) return err } diff --git a/internal/backend/bn254/cs/r1cs_sparse.go b/internal/backend/bn254/cs/r1cs_sparse.go index af6a6aea5e..9c11079791 100644 --- a/internal/backend/bn254/cs/r1cs_sparse.go +++ b/internal/backend/bn254/cs/r1cs_sparse.go @@ -37,8 +37,8 @@ import ( type SparseR1CS struct { compiled.SparseR1CS - // Coefficients in the constraints - Coefficients []fr.Element // list of unique coefficients. + Coefficients []fr.Element // coefficients in the constraints + loggerOut io.Writer } // NewSparseR1CS returns a new SparseR1CS and sets r1cs.Coefficient (fr.Element) from provided big.Int values @@ -46,6 +46,7 @@ func NewSparseR1CS(ccs compiled.SparseR1CS, coefficients []big.Int) *SparseR1CS cs := SparseR1CS{ SparseR1CS: ccs, Coefficients: make([]fr.Element, len(coefficients)), + loggerOut: os.Stdout, } for i := 0; i < len(coefficients); i++ { cs.Coefficients[i].SetBigInt(&coefficients[i]) @@ -58,7 +59,7 @@ func NewSparseR1CS(ccs compiled.SparseR1CS, coefficients []big.Int) *SparseR1CS // solution.values = [publicInputs | secretInputs | internalVariables ] // witness: contains the input variables // it returns the full slice of wires -func (cs *SparseR1CS) Solve(witness []fr.Element, hintFunctions []hint.Function) (values []fr.Element, err error) { +func (cs *SparseR1CS) Solve(witness []fr.Element, hintFunctions []hint.Function) ([]fr.Element, error) { // set the slices holding the solution.values and monitoring which variables have been solved nbVariables := cs.NbInternalVariables + cs.NbSecretVariables + cs.NbPublicVariables diff --git a/internal/backend/bn254/groth16/prove.go b/internal/backend/bn254/groth16/prove.go index aac9da3b2a..906ee2280d 100644 --- a/internal/backend/bn254/groth16/prove.go +++ b/internal/backend/bn254/groth16/prove.go @@ -61,9 +61,9 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness bn254witness.Witness, hintFunc } // solve the R1CS and compute the a, b, c vectors - a := make([]fr.Element, r1cs.NbConstraints, pk.Domain.Cardinality) - b := make([]fr.Element, r1cs.NbConstraints, pk.Domain.Cardinality) - c := make([]fr.Element, r1cs.NbConstraints, pk.Domain.Cardinality) + a := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) + b := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) + c := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) var wireValues []fr.Element var err error if wireValues, err = r1cs.Solve(witness, a, b, c, hintFunctions); err != nil { diff --git a/internal/backend/bn254/groth16/setup.go b/internal/backend/bn254/groth16/setup.go index 97e00dcf86..dc710839d3 100644 --- a/internal/backend/bn254/groth16/setup.go +++ b/internal/backend/bn254/groth16/setup.go @@ -95,7 +95,7 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { nbPrivateWires := r1cs.NbSecretVariables + r1cs.NbInternalVariables // Setting group for fft - domain := fft.NewDomain(uint64(r1cs.NbConstraints), 1, true) + domain := fft.NewDomain(uint64(len(r1cs.Constraints)), 1, true) // samples toxic waste toxicWaste, err := sampleToxicWaste() @@ -412,7 +412,7 @@ func sampleToxicWaste() (toxicWaste, error) { func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { // get R1CS nb constraints, wires and public/private inputs nbWires := r1cs.NbInternalVariables + r1cs.NbPublicVariables + r1cs.NbSecretVariables - nbConstraints := r1cs.NbConstraints + nbConstraints := len(r1cs.Constraints) // Setting group for fft domain := fft.NewDomain(uint64(nbConstraints), 1, true) diff --git a/internal/backend/bw6-761/cs/r1cs.go b/internal/backend/bw6-761/cs/r1cs.go index 229eeb2931..9cdd80224c 100644 --- a/internal/backend/bw6-761/cs/r1cs.go +++ b/internal/backend/bw6-761/cs/r1cs.go @@ -75,8 +75,8 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi } // compute the wires and the a, b, c polynomials - if len(a) != int(cs.NbConstraints) || len(b) != int(cs.NbConstraints) || len(c) != int(cs.NbConstraints) { - return solution.values, errors.New("invalid input size: len(a, b, c) == cs.NbConstraints") + if len(a) != len(cs.Constraints) || len(b) != len(cs.Constraints) || len(c) != len(cs.Constraints) { + return solution.values, errors.New("invalid input size: len(a, b, c) == len(Constraints)") } solution.solved[0] = true // ONE_WIRE @@ -134,9 +134,9 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi // IsSolved returns nil if given witness solves the R1CS and error otherwise // this method wraps cs.Solve() and allocates cs.Solve() inputs func (cs *R1CS) IsSolved(witness []fr.Element, hintFunctions []hint.Function) error { - a := make([]fr.Element, cs.NbConstraints) - b := make([]fr.Element, cs.NbConstraints) - c := make([]fr.Element, cs.NbConstraints) + a := make([]fr.Element, len(cs.Constraints)) + b := make([]fr.Element, len(cs.Constraints)) + c := make([]fr.Element, len(cs.Constraints)) _, err := cs.Solve(witness, a, b, c, hintFunctions) return err } diff --git a/internal/backend/bw6-761/cs/r1cs_sparse.go b/internal/backend/bw6-761/cs/r1cs_sparse.go index 4514777d3b..96494976e8 100644 --- a/internal/backend/bw6-761/cs/r1cs_sparse.go +++ b/internal/backend/bw6-761/cs/r1cs_sparse.go @@ -37,8 +37,8 @@ import ( type SparseR1CS struct { compiled.SparseR1CS - // Coefficients in the constraints - Coefficients []fr.Element // list of unique coefficients. + Coefficients []fr.Element // coefficients in the constraints + loggerOut io.Writer } // NewSparseR1CS returns a new SparseR1CS and sets r1cs.Coefficient (fr.Element) from provided big.Int values @@ -46,6 +46,7 @@ func NewSparseR1CS(ccs compiled.SparseR1CS, coefficients []big.Int) *SparseR1CS cs := SparseR1CS{ SparseR1CS: ccs, Coefficients: make([]fr.Element, len(coefficients)), + loggerOut: os.Stdout, } for i := 0; i < len(coefficients); i++ { cs.Coefficients[i].SetBigInt(&coefficients[i]) @@ -58,7 +59,7 @@ func NewSparseR1CS(ccs compiled.SparseR1CS, coefficients []big.Int) *SparseR1CS // solution.values = [publicInputs | secretInputs | internalVariables ] // witness: contains the input variables // it returns the full slice of wires -func (cs *SparseR1CS) Solve(witness []fr.Element, hintFunctions []hint.Function) (values []fr.Element, err error) { +func (cs *SparseR1CS) Solve(witness []fr.Element, hintFunctions []hint.Function) ([]fr.Element, error) { // set the slices holding the solution.values and monitoring which variables have been solved nbVariables := cs.NbInternalVariables + cs.NbSecretVariables + cs.NbPublicVariables diff --git a/internal/backend/bw6-761/groth16/prove.go b/internal/backend/bw6-761/groth16/prove.go index 6d81352e53..19cd63abea 100644 --- a/internal/backend/bw6-761/groth16/prove.go +++ b/internal/backend/bw6-761/groth16/prove.go @@ -61,9 +61,9 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness bw6_761witness.Witness, hintFu } // solve the R1CS and compute the a, b, c vectors - a := make([]fr.Element, r1cs.NbConstraints, pk.Domain.Cardinality) - b := make([]fr.Element, r1cs.NbConstraints, pk.Domain.Cardinality) - c := make([]fr.Element, r1cs.NbConstraints, pk.Domain.Cardinality) + a := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) + b := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) + c := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) var wireValues []fr.Element var err error if wireValues, err = r1cs.Solve(witness, a, b, c, hintFunctions); err != nil { diff --git a/internal/backend/bw6-761/groth16/setup.go b/internal/backend/bw6-761/groth16/setup.go index 68745e3335..b9cacb8339 100644 --- a/internal/backend/bw6-761/groth16/setup.go +++ b/internal/backend/bw6-761/groth16/setup.go @@ -95,7 +95,7 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { nbPrivateWires := r1cs.NbSecretVariables + r1cs.NbInternalVariables // Setting group for fft - domain := fft.NewDomain(uint64(r1cs.NbConstraints), 1, true) + domain := fft.NewDomain(uint64(len(r1cs.Constraints)), 1, true) // samples toxic waste toxicWaste, err := sampleToxicWaste() @@ -412,7 +412,7 @@ func sampleToxicWaste() (toxicWaste, error) { func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { // get R1CS nb constraints, wires and public/private inputs nbWires := r1cs.NbInternalVariables + r1cs.NbPublicVariables + r1cs.NbSecretVariables - nbConstraints := r1cs.NbConstraints + nbConstraints := len(r1cs.Constraints) // Setting group for fft domain := fft.NewDomain(uint64(nbConstraints), 1, true) diff --git a/internal/backend/compiled/r1cs.go b/internal/backend/compiled/r1cs.go index 30e0a7d755..5e06de9241 100644 --- a/internal/backend/compiled/r1cs.go +++ b/internal/backend/compiled/r1cs.go @@ -30,8 +30,7 @@ type R1CS struct { DebugInfo []LogEntry // Constraints - NbConstraints int // total number of constraints - Constraints []R1C + Constraints []R1C // Hints Hints []Hint @@ -42,7 +41,7 @@ type R1CS struct { // GetNbConstraints returns the number of constraints func (r1cs *R1CS) GetNbConstraints() int { - return r1cs.NbConstraints + return len(r1cs.Constraints) } // GetNbVariables return number of internal, secret and public variables diff --git a/internal/generator/backend/template/representations/r1cs.go.tmpl b/internal/generator/backend/template/representations/r1cs.go.tmpl index e7782436cc..32f335387c 100644 --- a/internal/generator/backend/template/representations/r1cs.go.tmpl +++ b/internal/generator/backend/template/representations/r1cs.go.tmpl @@ -61,8 +61,8 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi // compute the wires and the a, b, c polynomials - if len(a) != int(cs.NbConstraints) || len(b) != int(cs.NbConstraints) || len(c) != int(cs.NbConstraints){ - return solution.values, errors.New("invalid input size: len(a, b, c) == cs.NbConstraints") + if len(a) != len(cs.Constraints) || len(b) != len(cs.Constraints) || len(c) != len(cs.Constraints){ + return solution.values, errors.New("invalid input size: len(a, b, c) == len(Constraints)") } @@ -125,9 +125,9 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi // IsSolved returns nil if given witness solves the R1CS and error otherwise // this method wraps cs.Solve() and allocates cs.Solve() inputs func (cs *R1CS) IsSolved(witness []fr.Element, hintFunctions []hint.Function) error { - a := make([]fr.Element, cs.NbConstraints) - b := make([]fr.Element, cs.NbConstraints) - c := make([]fr.Element, cs.NbConstraints) + a := make([]fr.Element, len(cs.Constraints)) + b := make([]fr.Element, len(cs.Constraints)) + c := make([]fr.Element, len(cs.Constraints)) _, err := cs.Solve(witness, a, b, c, hintFunctions) return err } diff --git a/internal/generator/backend/template/representations/r1cs.sparse.go.tmpl b/internal/generator/backend/template/representations/r1cs.sparse.go.tmpl index 9ff3d29cfe..821273057d 100644 --- a/internal/generator/backend/template/representations/r1cs.sparse.go.tmpl +++ b/internal/generator/backend/template/representations/r1cs.sparse.go.tmpl @@ -19,8 +19,8 @@ import ( type SparseR1CS struct { compiled.SparseR1CS - // Coefficients in the constraints - Coefficients []fr.Element // list of unique coefficients. + Coefficients []fr.Element // coefficients in the constraints + loggerOut io.Writer } // NewSparseR1CS returns a new SparseR1CS and sets r1cs.Coefficient (fr.Element) from provided big.Int values @@ -28,6 +28,7 @@ func NewSparseR1CS(ccs compiled.SparseR1CS, coefficients []big.Int) *SparseR1CS cs := SparseR1CS{ SparseR1CS: ccs, Coefficients: make([]fr.Element, len(coefficients)), + loggerOut: os.Stdout, } for i := 0; i < len(coefficients); i++ { cs.Coefficients[i].SetBigInt(&coefficients[i]) @@ -41,7 +42,7 @@ func NewSparseR1CS(ccs compiled.SparseR1CS, coefficients []big.Int) *SparseR1CS // solution.values = [publicInputs | secretInputs | internalVariables ] // witness: contains the input variables // it returns the full slice of wires -func (cs *SparseR1CS) Solve(witness []fr.Element, hintFunctions []hint.Function) (values []fr.Element, err error) { +func (cs *SparseR1CS) Solve(witness []fr.Element, hintFunctions []hint.Function) ([]fr.Element, error) { // set the slices holding the solution.values and monitoring which variables have been solved diff --git a/internal/generator/backend/template/zkpschemes/groth16/groth16.prove.go.tmpl b/internal/generator/backend/template/zkpschemes/groth16/groth16.prove.go.tmpl index 5d258050c7..051c7e2706 100644 --- a/internal/generator/backend/template/zkpschemes/groth16/groth16.prove.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/groth16/groth16.prove.go.tmpl @@ -39,9 +39,9 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness {{ toLower .CurveID }}witness. } // solve the R1CS and compute the a, b, c vectors - a := make([]fr.Element, r1cs.NbConstraints, pk.Domain.Cardinality) - b := make([]fr.Element, r1cs.NbConstraints, pk.Domain.Cardinality) - c := make([]fr.Element, r1cs.NbConstraints, pk.Domain.Cardinality) + a := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) + b := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) + c := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) var wireValues []fr.Element var err error if wireValues, err = r1cs.Solve(witness, a, b, c, hintFunctions ); err != nil { diff --git a/internal/generator/backend/template/zkpschemes/groth16/groth16.setup.go.tmpl b/internal/generator/backend/template/zkpschemes/groth16/groth16.setup.go.tmpl index 456784ede0..cd69cb7f06 100644 --- a/internal/generator/backend/template/zkpschemes/groth16/groth16.setup.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/groth16/groth16.setup.go.tmpl @@ -74,7 +74,7 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { nbPrivateWires := r1cs.NbSecretVariables + r1cs.NbInternalVariables // Setting group for fft - domain := fft.NewDomain(uint64(r1cs.NbConstraints), 1, true) + domain := fft.NewDomain(uint64(len(r1cs.Constraints)), 1, true) // samples toxic waste toxicWaste, err := sampleToxicWaste() @@ -401,7 +401,7 @@ func sampleToxicWaste() (toxicWaste, error) { func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { // get R1CS nb constraints, wires and public/private inputs nbWires := r1cs.NbInternalVariables + r1cs.NbPublicVariables + r1cs.NbSecretVariables - nbConstraints := r1cs.NbConstraints + nbConstraints := len(r1cs.Constraints) // Setting group for fft domain := fft.NewDomain(uint64(nbConstraints), 1, true) From 388366c14fd59a234bf9087237bb7efb81ba53b0 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Fri, 17 Sep 2021 09:11:46 -0500 Subject: [PATCH 06/15] refactor: factorized structs between compiled.SparseR1Cs and compiled.R1CS --- frontend/cs_to_r1cs.go | 18 +++--- frontend/cs_to_r1cs_sparse.go | 18 +++--- internal/backend/compiled/cs.go | 79 ++++++++++++++++++++++++ internal/backend/compiled/r1c.go | 54 ++-------------- internal/backend/compiled/r1cs.go | 68 +------------------- internal/backend/compiled/r1cs_sparse.go | 77 +---------------------- internal/backend/compiled/term.go | 36 ++++++++++- 7 files changed, 141 insertions(+), 209 deletions(-) create mode 100644 internal/backend/compiled/cs.go diff --git a/frontend/cs_to_r1cs.go b/frontend/cs_to_r1cs.go index fd77c53910..454fd5715f 100644 --- a/frontend/cs_to_r1cs.go +++ b/frontend/cs_to_r1cs.go @@ -18,14 +18,16 @@ func (cs *ConstraintSystem) toR1CS(curveID ecc.ID) (CompiledConstraintSystem, er // setting up the result res := compiled.R1CS{ - NbInternalVariables: len(cs.internal.variables), - NbPublicVariables: len(cs.public.variables), - NbSecretVariables: len(cs.secret.variables), - Constraints: make([]compiled.R1C, len(cs.constraints)), - Logs: make([]compiled.LogEntry, len(cs.logs)), - DebugInfo: make([]compiled.LogEntry, len(cs.debugInfo)), - Hints: make([]compiled.Hint, len(cs.hints)), - MDebug: make(map[int]int), + CS: compiled.CS{ + NbInternalVariables: len(cs.internal.variables), + NbPublicVariables: len(cs.public.variables), + NbSecretVariables: len(cs.secret.variables), + DebugInfo: make([]compiled.LogEntry, len(cs.debugInfo)), + Logs: make([]compiled.LogEntry, len(cs.logs)), + Hints: make([]compiled.Hint, len(cs.hints)), + MDebug: make(map[int]int), + }, + Constraints: make([]compiled.R1C, len(cs.constraints)), } // computational constraints (= gates) diff --git a/frontend/cs_to_r1cs_sparse.go b/frontend/cs_to_r1cs_sparse.go index 6207a0ace6..b3082daf23 100644 --- a/frontend/cs_to_r1cs_sparse.go +++ b/frontend/cs_to_r1cs_sparse.go @@ -57,14 +57,16 @@ func (cs *ConstraintSystem) toSparseR1CS(curveID ecc.ID) (CompiledConstraintSyst res := sparseR1CS{ ConstraintSystem: cs, ccs: compiled.SparseR1CS{ - NbPublicVariables: len(cs.public.variables) - 1, // the ONE_WIRE is discarded as it is not used in PLONK - NbSecretVariables: len(cs.secret.variables), - NbInternalVariables: len(cs.internal.variables), - Constraints: make([]compiled.SparseR1C, 0, len(cs.constraints)), - Logs: make([]compiled.LogEntry, len(cs.logs)), - DebugInfo: make([]compiled.LogEntry, len(cs.debugInfo)), - Hints: make([]compiled.Hint, len(cs.hints)), - MDebug: make(map[int]int), + CS: compiled.CS{ + NbInternalVariables: len(cs.internal.variables), + NbPublicVariables: len(cs.public.variables) - 1, // the ONE_WIRE is discarded in PlonK + NbSecretVariables: len(cs.secret.variables), + DebugInfo: make([]compiled.LogEntry, len(cs.debugInfo)), + Logs: make([]compiled.LogEntry, len(cs.logs)), + Hints: make([]compiled.Hint, len(cs.hints)), + MDebug: make(map[int]int), + }, + Constraints: make([]compiled.SparseR1C, 0, len(cs.constraints)), }, solvedVariables: make([]bool, len(cs.internal.variables), len(cs.internal.variables)*2), scsInternalVariables: len(cs.internal.variables), diff --git a/internal/backend/compiled/cs.go b/internal/backend/compiled/cs.go new file mode 100644 index 0000000000..89688c1642 --- /dev/null +++ b/internal/backend/compiled/cs.go @@ -0,0 +1,79 @@ +package compiled + +import ( + "io" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/backend/hint" +) + +// CS contains common element between R1CS and CS +type CS struct { + // number of wires + NbInternalVariables int + NbPublicVariables int + NbSecretVariables int + + // logs (added with cs.Println, resolved when solver sets a value to a wire) + Logs []LogEntry + + // debug info contains stack trace (including line number) of a call to a cs.API that + // results in an unsolved constraint + DebugInfo []LogEntry + + // hints + Hints []Hint + + // maps wire id to hint id + MHints map[int]int + + // maps constraint id to debugInfo id + MDebug map[int]int +} + +// Visibility encodes a Variable (or wire) visibility +// Possible values are Unset, Internal, Secret or Public +type Visibility uint8 + +const ( + Unset Visibility = iota + Internal + Secret + Public + Virtual +) + +// Hint represents a solver hint +// it enables the solver to compute a Wire with a function provided at solving time +// using pre-defined inputs +type Hint struct { + WireID int // resulting wire ID to compute + ID hint.ID // hint function id + Inputs []LinearExpression // terms to inject in the hint function +} + +// GetNbVariables return number of internal, secret and public variables +func (cs *CS) GetNbVariables() (internal, secret, public int) { + return cs.NbInternalVariables, cs.NbSecretVariables, cs.NbPublicVariables +} + +// FrSize panics +func (cs *CS) FrSize() int { panic("not implemented") } + +// GetNbCoefficients panics +func (cs *CS) GetNbCoefficients() int { panic("not implemented") } + +// CurveID returns ecc.UNKNOWN +func (cs *CS) CurveID() ecc.ID { return ecc.UNKNOWN } + +// WriteTo panics +func (cs *CS) WriteTo(w io.Writer) (n int64, err error) { panic("not implemented") } + +// ReadFrom panics +func (cs *CS) ReadFrom(r io.Reader) (n int64, err error) { panic("not implemented") } + +// SetLoggerOutput panics +func (cs *CS) SetLoggerOutput(w io.Writer) { panic("not implemented") } + +// ToHTML panics +func (cs *CS) ToHTML(w io.Writer) error { panic("not implemtened") } diff --git a/internal/backend/compiled/r1c.go b/internal/backend/compiled/r1c.go index f25c39335a..783b80e9a8 100644 --- a/internal/backend/compiled/r1c.go +++ b/internal/backend/compiled/r1c.go @@ -16,12 +16,14 @@ package compiled import ( "math/big" - "strconv" "strings" - - "github.com/consensys/gnark/backend/hint" ) +// R1C used to compute the wires +type R1C struct { + L, R, O LinearExpression +} + // LinearExpression represent a linear expression of variables type LinearExpression []Term @@ -52,32 +54,6 @@ func (l LinearExpression) Less(i, j int) bool { return iVis > jVis } -// R1C used to compute the wires -type R1C struct { - L, R, O LinearExpression -} - -// Visibility encodes a Variable (or wire) visibility -// Possible values are Unset, Internal, Secret or Public -type Visibility uint8 - -const ( - Unset Visibility = iota - Internal - Secret - Public - Virtual -) - -// Hint represents a solver hint -// it enables the solver to compute a Wire with a function provided at solving time -// using pre-defined inputs -type Hint struct { - WireID int // resulting wire ID to compute - ID hint.ID // hint function id - Inputs []LinearExpression // terms to inject in the hint function -} - func (r1c *R1C) String(coeffs []big.Int) string { var sbb strings.Builder sbb.WriteString("L[") @@ -99,23 +75,3 @@ func (l LinearExpression) string(sbb *strings.Builder, coeffs []big.Int) { } } } - -func (t Term) string(sbb *strings.Builder, coeffs []big.Int) { - sbb.WriteString(coeffs[t.CoeffID()].String()) - sbb.WriteString("*") - switch t.VariableVisibility() { - case Internal: - sbb.WriteString("i") - case Public: - sbb.WriteString("p") - case Secret: - sbb.WriteString("s") - case Virtual: - sbb.WriteString("v") - case Unset: - sbb.WriteString("u") - default: - panic("not implemented") - } - sbb.WriteString(strconv.Itoa(t.VariableID())) -} diff --git a/internal/backend/compiled/r1cs.go b/internal/backend/compiled/r1cs.go index 5e06de9241..8831f6da01 100644 --- a/internal/backend/compiled/r1cs.go +++ b/internal/backend/compiled/r1cs.go @@ -14,77 +14,13 @@ package compiled -import ( - "io" - - "github.com/consensys/gnark-crypto/ecc" -) - -// R1CS decsribes a set of R1CS constraint +// R1CS decsribes a set of R1C constraint type R1CS struct { - // Wires - NbInternalVariables int - NbPublicVariables int // includes ONE wire - NbSecretVariables int - Logs []LogEntry - DebugInfo []LogEntry - - // Constraints + CS Constraints []R1C - - // Hints - Hints []Hint - - MHints map[int]int // maps wire id to hint id - MDebug map[int]int // maps constraint id to debugInfo id } // GetNbConstraints returns the number of constraints func (r1cs *R1CS) GetNbConstraints() int { return len(r1cs.Constraints) } - -// GetNbVariables return number of internal, secret and public variables -func (r1cs *R1CS) GetNbVariables() (internal, secret, public int) { - internal = r1cs.NbInternalVariables - secret = r1cs.NbSecretVariables - public = r1cs.NbPublicVariables - return -} - -// GetNbCoefficients return the number of unique coefficients needed in the R1CS -func (r1cs *R1CS) GetNbCoefficients() int { - panic("not implemented") -} - -// CurveID returns ecc.UNKNOWN as this is a untyped R1CS using big.Int -func (r1cs *R1CS) CurveID() ecc.ID { - return ecc.UNKNOWN -} - -// FrSize panics -func (r1cs *R1CS) FrSize() int { - panic("not implemented") -} - -// WriteTo panics (can't serialize untyped R1CS) -func (r1cs *R1CS) WriteTo(w io.Writer) (n int64, err error) { - panic("not implemented") -} - -// ReadFrom panics (can't deserialize untyped R1CS) -func (r1cs *R1CS) ReadFrom(r io.Reader) (n int64, err error) { - panic("not implemented") -} - -// SetLoggerOutput replace existing logger output with provided one -// default uses os.Stdout -// if nil is provided, logs are not printed -func (r1cs *R1CS) SetLoggerOutput(w io.Writer) { - panic("not implemented") -} - -// ToHTML returns an HTML human-readable representation of the constraint system -func (r1cs *R1CS) ToHTML(w io.Writer) error { - panic("not implemented") -} diff --git a/internal/backend/compiled/r1cs_sparse.go b/internal/backend/compiled/r1cs_sparse.go index 840adb6ded..bcd59d5994 100644 --- a/internal/backend/compiled/r1cs_sparse.go +++ b/internal/backend/compiled/r1cs_sparse.go @@ -14,86 +14,13 @@ package compiled -import ( - "io" - - "github.com/consensys/gnark-crypto/ecc" -) - -// SparseR1CS represents a Plonk like circuit +// R1CS decsribes a set of SparseR1C constraint type SparseR1CS struct { - - // Variables [publicVariables| secretVariables | internalVariables ] - NbInternalVariables int - NbPublicVariables int - NbSecretVariables int - - // Constraints contains an ordered list of SparseR1C - // the solver will iterate through them and is guaranteed that there will be at most one - // unsolved wire per constraint + CS Constraints []SparseR1C - - // Logs (e.g. variables that have been printed using cs.Println) - Logs []LogEntry - DebugInfo []LogEntry - - // Hints - Hints []Hint - - MHints map[int]int // maps wire id to hint id - MDebug map[int]int // maps constraint id to debugInfo id -} - -// GetNbVariables return number of internal, secret and public variables -func (cs *SparseR1CS) GetNbVariables() (internal, secret, public int) { - internal = cs.NbInternalVariables - secret = cs.NbSecretVariables - public = cs.NbPublicVariables - return } // GetNbConstraints returns the number of constraints func (cs *SparseR1CS) GetNbConstraints() int { return len(cs.Constraints) } - -// GetNbWires returns the number of wires (internal) -func (cs *SparseR1CS) GetNbWires() int { - return cs.NbInternalVariables -} - -// FrSize panics -func (cs *SparseR1CS) FrSize() int { - panic("not implemented") -} - -// GetNbCoefficients return the number of unique coefficients needed in the R1CS -func (cs *SparseR1CS) GetNbCoefficients() int { - panic("not implemented") -} - -// CurveID returns ecc.UNKNOWN as this is a untyped R1CS using big.Int -func (cs *SparseR1CS) CurveID() ecc.ID { - return ecc.UNKNOWN -} - -// WriteTo panics -func (cs *SparseR1CS) WriteTo(w io.Writer) (n int64, err error) { - panic("not implemented") -} - -// ReadFrom panics -func (cs *SparseR1CS) ReadFrom(r io.Reader) (n int64, err error) { - panic("not implemented") -} - -// SetLoggerOutput replace existing logger output with provided one -// default uses os.Stdout -// if nil is provided, logs are not printed -func (cs *SparseR1CS) SetLoggerOutput(w io.Writer) { - panic("not implemented") -} - -func (cs *SparseR1CS) ToHTML(w io.Writer) error { - panic("not implemtened") -} diff --git a/internal/backend/compiled/term.go b/internal/backend/compiled/term.go index 766770f31b..6800544d0b 100644 --- a/internal/backend/compiled/term.go +++ b/internal/backend/compiled/term.go @@ -14,6 +14,12 @@ package compiled +import ( + "math/big" + "strconv" + "strings" +) + // Term lightweight version of a term, no pointers // first 4 bits are reserved // next 30 bits represented the coefficient idx (in r1cs.Coefficients) by which the wire is multiplied @@ -40,23 +46,27 @@ const ( const ( nbBitsVariableID = 29 nbBitsCoeffID = 30 - nbBitsFutureUse = 2 + nbBitsDelimitor = 1 + nbBitsFutureUse = 1 nbBitsVariableVisibility = 3 ) // TermDelimitor is reserved for internal use -const TermDelimitor Term = Term(maskFutureUse) +// the constraint solver will evaluate the sum of all terms appearing between two TermDelimitor +const TermDelimitor Term = Term(maskDelimitor) const ( shiftVariableID = 0 shiftCoeffID = nbBitsVariableID - shiftFutureUse = shiftCoeffID + nbBitsCoeffID + shiftDelimitor = shiftCoeffID + nbBitsCoeffID + shiftFutureUse = shiftDelimitor + nbBitsDelimitor shiftVariableVisibility = shiftFutureUse + nbBitsFutureUse ) const ( maskVariableID = uint64((1 << nbBitsVariableID) - 1) maskCoeffID = uint64((1<> shiftCoeffID) } + +func (t Term) string(sbb *strings.Builder, coeffs []big.Int) { + sbb.WriteString(coeffs[t.CoeffID()].String()) + sbb.WriteString("*") + switch t.VariableVisibility() { + case Internal: + sbb.WriteString("i") + case Public: + sbb.WriteString("p") + case Secret: + sbb.WriteString("s") + case Virtual: + sbb.WriteString("v") + case Unset: + sbb.WriteString("u") + default: + panic("not implemented") + } + sbb.WriteString(strconv.Itoa(t.VariableID())) +} From bddbbb6ea1e46f3403c09c67b46163e97180c86a Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Fri, 17 Sep 2021 10:58:39 -0500 Subject: [PATCH 07/15] refactor: save hints in a map in ConstraintSystem instead of slice --- frontend/cs.go | 10 ++--- frontend/cs_to_r1cs.go | 34 +++++++-------- frontend/cs_to_r1cs_sparse.go | 43 +++++++++---------- internal/backend/bls12-377/cs/r1cs.go | 4 +- internal/backend/bls12-377/cs/r1cs_sparse.go | 12 +++--- .../bls12-377/cs/{cs.go => solution.go} | 2 +- internal/backend/bls12-381/cs/r1cs.go | 4 +- internal/backend/bls12-381/cs/r1cs_sparse.go | 12 +++--- .../bls12-381/cs/{cs.go => solution.go} | 2 +- internal/backend/bls24-315/cs/r1cs.go | 4 +- internal/backend/bls24-315/cs/r1cs_sparse.go | 12 +++--- .../bls24-315/cs/{cs.go => solution.go} | 2 +- internal/backend/bn254/cs/r1cs.go | 4 +- internal/backend/bn254/cs/r1cs_sparse.go | 12 +++--- .../backend/bn254/cs/{cs.go => solution.go} | 2 +- internal/backend/bw6-761/cs/r1cs.go | 4 +- internal/backend/bw6-761/cs/r1cs_sparse.go | 12 +++--- .../backend/bw6-761/cs/{cs.go => solution.go} | 2 +- internal/backend/compiled/cs.go | 10 ++--- internal/generator/backend/main.go | 2 +- .../template/representations/r1cs.go.tmpl | 4 +- .../representations/r1cs.sparse.go.tmpl | 12 +++--- .../{cs.go.tmpl => solution.go.tmpl} | 2 +- 23 files changed, 97 insertions(+), 110 deletions(-) rename internal/backend/bls12-377/cs/{cs.go => solution.go} (98%) rename internal/backend/bls12-381/cs/{cs.go => solution.go} (98%) rename internal/backend/bls24-315/cs/{cs.go => solution.go} (98%) rename internal/backend/bn254/cs/{cs.go => solution.go} (98%) rename internal/backend/bw6-761/cs/{cs.go => solution.go} (98%) rename internal/generator/backend/template/representations/{cs.go.tmpl => solution.go.tmpl} (98%) diff --git a/frontend/cs.go b/frontend/cs.go index f789a4eb55..a72161819f 100644 --- a/frontend/cs.go +++ b/frontend/cs.go @@ -53,11 +53,8 @@ type ConstraintSystem struct { coeffsIDs map[string]int // map to fast check existence of a coefficient (key = coeff.Text(16)) // Hints - // TODO @gbotrel let's make it a map directly here. - hints []compiled.Hint // solver hints + mHints map[int]compiled.Hint // solver hints - // TODO @gbotrel we may want to make that optional through build tags - // debug info logs []compiled.LogEntry // list of logs to be printed when solving a circuit. The logs are called with the method Println debugInfo []compiled.LogEntry // list of logs storing information about R1C @@ -110,6 +107,7 @@ func newConstraintSystem(initialCapacity ...int) ConstraintSystem { coeffsIDs: make(map[string]int), constraints: make([]compiled.R1C, 0, capacity), mDebug: make(map[int]int), + mHints: make(map[int]compiled.Hint), } cs.coeffs[compiled.CoeffIdZero].SetInt64(0) @@ -132,8 +130,6 @@ func newConstraintSystem(initialCapacity ...int) ConstraintSystem { // by default the circuit is given on public wire equal to 1 cs.public.variables[0] = cs.newPublicVariable() - cs.hints = make([]compiled.Hint, 0) - return cs } @@ -160,7 +156,7 @@ func (cs *ConstraintSystem) NewHint(hintID hint.ID, inputs ...interface{}) Varia } // add the hint to the constraint system - cs.hints = append(cs.hints, compiled.Hint{WireID: r.id, ID: hintID, Inputs: hintInputs}) + cs.mHints[r.id] = compiled.Hint{ID: hintID, Inputs: hintInputs} return r } diff --git a/frontend/cs_to_r1cs.go b/frontend/cs_to_r1cs.go index 454fd5715f..722ecbbfdf 100644 --- a/frontend/cs_to_r1cs.go +++ b/frontend/cs_to_r1cs.go @@ -24,28 +24,26 @@ func (cs *ConstraintSystem) toR1CS(curveID ecc.ID) (CompiledConstraintSystem, er NbSecretVariables: len(cs.secret.variables), DebugInfo: make([]compiled.LogEntry, len(cs.debugInfo)), Logs: make([]compiled.LogEntry, len(cs.logs)), - Hints: make([]compiled.Hint, len(cs.hints)), + MHints: make(map[int]compiled.Hint, len(cs.mHints)), MDebug: make(map[int]int), }, Constraints: make([]compiled.R1C, len(cs.constraints)), } - // computational constraints (= gates) - copy(res.Constraints, cs.constraints) - + // for logs, debugInfo and hints the only thing that will change + // is that ID of the wires will be offseted to take into account the final wire vector ordering + // that is: public wires | secret wires | internal wires copy(res.Logs, cs.logs) copy(res.DebugInfo, cs.debugInfo) + // computational constraints (= gates) + copy(res.Constraints, cs.constraints) + + // for a R1CS, the correspondance between constraint and debug info won't change, we just copy for k, v := range cs.mDebug { res.MDebug[k] = v } - // note: verbose, but we offset the IDs of the wires where they appear, that is, - // in the logs, debug info, constraints and hints - // since we don't use pointers but Terms (uint64), we need to potentially offset - // the same wireID multiple times. - copy(res.Hints, cs.hints) - // offset variable ID depeneding on visibility shiftVID := func(oldID int, visibility compiled.Visibility) int { switch visibility { @@ -74,16 +72,14 @@ func (cs *ConstraintSystem) toR1CS(curveID ecc.ID) (CompiledConstraintSystem, er } // we need to offset the ids in the hints - for i := 0; i < len(res.Hints); i++ { - res.Hints[i].WireID = shiftVID(res.Hints[i].WireID, compiled.Internal) - for j := 0; j < len(res.Hints[i].Inputs); j++ { - offsetIDs(res.Hints[i].Inputs[j]) + for vID, hint := range cs.mHints { + k := shiftVID(vID, compiled.Internal) + inputs := make([]compiled.LinearExpression, len(hint.Inputs)) + copy(inputs, hint.Inputs) + for j := 0; j < len(inputs); j++ { + offsetIDs(inputs[j]) } - } - - res.MHints = make(map[int]int, len(res.Hints)) - for i := 0; i < len(res.Hints); i++ { - res.MHints[res.Hints[i].WireID] = i + res.MHints[k] = compiled.Hint{ID: hint.ID, Inputs: inputs} } // we need to offset the ids in logs & debugInfo diff --git a/frontend/cs_to_r1cs_sparse.go b/frontend/cs_to_r1cs_sparse.go index b3082daf23..a923679904 100644 --- a/frontend/cs_to_r1cs_sparse.go +++ b/frontend/cs_to_r1cs_sparse.go @@ -63,8 +63,8 @@ func (cs *ConstraintSystem) toSparseR1CS(curveID ecc.ID) (CompiledConstraintSyst NbSecretVariables: len(cs.secret.variables), DebugInfo: make([]compiled.LogEntry, len(cs.debugInfo)), Logs: make([]compiled.LogEntry, len(cs.logs)), - Hints: make([]compiled.Hint, len(cs.hints)), MDebug: make(map[int]int), + MHints: make(map[int]compiled.Hint), }, Constraints: make([]compiled.SparseR1C, 0, len(cs.constraints)), }, @@ -73,26 +73,26 @@ func (cs *ConstraintSystem) toSparseR1CS(curveID ecc.ID) (CompiledConstraintSyst currentR1CDebugID: -1, } - // note: verbose, but we offset the IDs of the wires where they appear, that is, - // in the logs, debug info, constraints and hints - // since we don't use pointers but Terms (uint64), we need to potentially offset - // the same wireID multiple times. - copy(res.ccs.Hints, cs.hints) - + // logs, debugInfo and hints are copied, the only thing that will change + // is that ID of the wires will be offseted to take into account the final wire vector ordering + // that is: public wires | secret wires | internal wires copy(res.ccs.Logs, cs.logs) copy(res.ccs.DebugInfo, cs.debugInfo) - // TODO @gbotrel we may not want to do that as it may hide some bugs - // if there is a R1C with several unsolved wires, wether they are hint wires or not - // will be problematic at solving time - for i := 0; i < len(cs.hints); i++ { - res.solvedVariables[cs.hints[i].WireID] = true + // we mark hint wires are solved + // each R1C from the frontend.ConstraintSystem is allowed to have at most one unsolved wire + // excluding hints. We mark hint wires as "solved" to ensure spliting R1C to SparseR1C + // won't create invalid SparseR1C constraint with more than one wire to solve for the solver + for vID := range cs.mHints { + res.solvedVariables[vID] = true } // convert the R1C to SparseR1C // in particular, all linear expressions that appear in the R1C // will be split in multiple constraints in the SparseR1C for i := 0; i < len(cs.constraints); i++ { + // we set currentR1CDebugID to the debugInfo ID corresponding to the R1C we're processing + // if present. All constraints created throuh addConstraint will add a new mapping if dID, ok := cs.mDebug[i]; ok { res.currentR1CDebugID = dID } else { @@ -161,19 +161,16 @@ func (cs *ConstraintSystem) toSparseR1CS(curveID ecc.ID) (CompiledConstraintSyst } // we need to offset the ids in the hints - for i := 0; i < len(res.ccs.Hints); i++ { - res.ccs.Hints[i].WireID = shiftVID(res.ccs.Hints[i].WireID, compiled.Internal) - for j := 0; j < len(res.ccs.Hints[i].Inputs); j++ { - l := res.ccs.Hints[i].Inputs[j] - for k := 0; k < len(l); k++ { - offsetTermID(&l[k]) + for vID, hint := range cs.mHints { + k := shiftVID(vID, compiled.Internal) + inputs := make([]compiled.LinearExpression, len(hint.Inputs)) + copy(inputs, hint.Inputs) + for j := 0; j < len(inputs); j++ { + for k := 0; k < len(inputs[j]); k++ { + offsetTermID(&inputs[j][k]) } } - } - - res.ccs.MHints = make(map[int]int, len(res.ccs.Hints)) - for i := 0; i < len(res.ccs.Hints); i++ { - res.ccs.MHints[res.ccs.Hints[i].WireID] = i + res.ccs.MHints[k] = compiled.Hint{ID: hint.ID, Inputs: inputs} } // update number of internal variables with new wires created diff --git a/internal/backend/bls12-377/cs/r1cs.go b/internal/backend/bls12-377/cs/r1cs.go index c2a7ad86c2..0db969ba4e 100644 --- a/internal/backend/bls12-377/cs/r1cs.go +++ b/internal/backend/bls12-377/cs/r1cs.go @@ -207,8 +207,8 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) error { } // first we check if this is a hint wire - if hID, ok := cs.MHints[vID]; ok { - return solution.solveHint(cs.Hints[hID], vID) + if hint, ok := cs.MHints[vID]; ok { + return solution.solveWithHint(vID, hint) } if loc != 0 { diff --git a/internal/backend/bls12-377/cs/r1cs_sparse.go b/internal/backend/bls12-377/cs/r1cs_sparse.go index 53e2633a05..59f8c4bf43 100644 --- a/internal/backend/bls12-377/cs/r1cs_sparse.go +++ b/internal/backend/bls12-377/cs/r1cs_sparse.go @@ -133,8 +133,8 @@ func (cs *SparseR1CS) computeHints(c compiled.SparseR1C, solution *solution) (in if (c.L.CoeffID() != 0 || c.M[0].CoeffID() != 0) && !solution.solved[lID] { // check if it's a hint - if hID, ok := cs.MHints[lID]; ok { - if err := solution.solveHint(cs.Hints[hID], lID); err != nil { + if hint, ok := cs.MHints[lID]; ok { + if err := solution.solveWithHint(lID, hint); err != nil { return -1, err } } else { @@ -145,8 +145,8 @@ func (cs *SparseR1CS) computeHints(c compiled.SparseR1C, solution *solution) (in if (c.R.CoeffID() != 0 || c.M[1].CoeffID() != 0) && !solution.solved[rID] { // check if it's a hint - if hID, ok := cs.MHints[rID]; ok { - if err := solution.solveHint(cs.Hints[hID], rID); err != nil { + if hint, ok := cs.MHints[rID]; ok { + if err := solution.solveWithHint(rID, hint); err != nil { return -1, err } } else { @@ -156,8 +156,8 @@ func (cs *SparseR1CS) computeHints(c compiled.SparseR1C, solution *solution) (in if (c.O.CoeffID() != 0) && !solution.solved[oID] { // check if it's a hint - if hID, ok := cs.MHints[oID]; ok { - if err := solution.solveHint(cs.Hints[hID], oID); err != nil { + if hint, ok := cs.MHints[oID]; ok { + if err := solution.solveWithHint(oID, hint); err != nil { return -1, err } } else { diff --git a/internal/backend/bls12-377/cs/cs.go b/internal/backend/bls12-377/cs/solution.go similarity index 98% rename from internal/backend/bls12-377/cs/cs.go rename to internal/backend/bls12-377/cs/solution.go index 783b712b2b..6c34e817f1 100644 --- a/internal/backend/bls12-377/cs/cs.go +++ b/internal/backend/bls12-377/cs/solution.go @@ -105,7 +105,7 @@ func (s *solution) computeTerm(t compiled.Term) fr.Element { } // solveHint compute solution.values[vID] using provided solver hint -func (s *solution) solveHint(h compiled.Hint, vID int) error { +func (s *solution) solveWithHint(vID int, h compiled.Hint) error { // compute values for all inputs. inputs := make([]fr.Element, len(h.Inputs)) diff --git a/internal/backend/bls12-381/cs/r1cs.go b/internal/backend/bls12-381/cs/r1cs.go index 4d74554a33..d49ccc53d3 100644 --- a/internal/backend/bls12-381/cs/r1cs.go +++ b/internal/backend/bls12-381/cs/r1cs.go @@ -207,8 +207,8 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) error { } // first we check if this is a hint wire - if hID, ok := cs.MHints[vID]; ok { - return solution.solveHint(cs.Hints[hID], vID) + if hint, ok := cs.MHints[vID]; ok { + return solution.solveWithHint(vID, hint) } if loc != 0 { diff --git a/internal/backend/bls12-381/cs/r1cs_sparse.go b/internal/backend/bls12-381/cs/r1cs_sparse.go index dc56b51dda..6ef64f3ef9 100644 --- a/internal/backend/bls12-381/cs/r1cs_sparse.go +++ b/internal/backend/bls12-381/cs/r1cs_sparse.go @@ -133,8 +133,8 @@ func (cs *SparseR1CS) computeHints(c compiled.SparseR1C, solution *solution) (in if (c.L.CoeffID() != 0 || c.M[0].CoeffID() != 0) && !solution.solved[lID] { // check if it's a hint - if hID, ok := cs.MHints[lID]; ok { - if err := solution.solveHint(cs.Hints[hID], lID); err != nil { + if hint, ok := cs.MHints[lID]; ok { + if err := solution.solveWithHint(lID, hint); err != nil { return -1, err } } else { @@ -145,8 +145,8 @@ func (cs *SparseR1CS) computeHints(c compiled.SparseR1C, solution *solution) (in if (c.R.CoeffID() != 0 || c.M[1].CoeffID() != 0) && !solution.solved[rID] { // check if it's a hint - if hID, ok := cs.MHints[rID]; ok { - if err := solution.solveHint(cs.Hints[hID], rID); err != nil { + if hint, ok := cs.MHints[rID]; ok { + if err := solution.solveWithHint(rID, hint); err != nil { return -1, err } } else { @@ -156,8 +156,8 @@ func (cs *SparseR1CS) computeHints(c compiled.SparseR1C, solution *solution) (in if (c.O.CoeffID() != 0) && !solution.solved[oID] { // check if it's a hint - if hID, ok := cs.MHints[oID]; ok { - if err := solution.solveHint(cs.Hints[hID], oID); err != nil { + if hint, ok := cs.MHints[oID]; ok { + if err := solution.solveWithHint(oID, hint); err != nil { return -1, err } } else { diff --git a/internal/backend/bls12-381/cs/cs.go b/internal/backend/bls12-381/cs/solution.go similarity index 98% rename from internal/backend/bls12-381/cs/cs.go rename to internal/backend/bls12-381/cs/solution.go index 321de6f9d3..7ca64a9afd 100644 --- a/internal/backend/bls12-381/cs/cs.go +++ b/internal/backend/bls12-381/cs/solution.go @@ -105,7 +105,7 @@ func (s *solution) computeTerm(t compiled.Term) fr.Element { } // solveHint compute solution.values[vID] using provided solver hint -func (s *solution) solveHint(h compiled.Hint, vID int) error { +func (s *solution) solveWithHint(vID int, h compiled.Hint) error { // compute values for all inputs. inputs := make([]fr.Element, len(h.Inputs)) diff --git a/internal/backend/bls24-315/cs/r1cs.go b/internal/backend/bls24-315/cs/r1cs.go index b4ed832598..8f5f6814f5 100644 --- a/internal/backend/bls24-315/cs/r1cs.go +++ b/internal/backend/bls24-315/cs/r1cs.go @@ -207,8 +207,8 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) error { } // first we check if this is a hint wire - if hID, ok := cs.MHints[vID]; ok { - return solution.solveHint(cs.Hints[hID], vID) + if hint, ok := cs.MHints[vID]; ok { + return solution.solveWithHint(vID, hint) } if loc != 0 { diff --git a/internal/backend/bls24-315/cs/r1cs_sparse.go b/internal/backend/bls24-315/cs/r1cs_sparse.go index b496e450ce..dea13c2836 100644 --- a/internal/backend/bls24-315/cs/r1cs_sparse.go +++ b/internal/backend/bls24-315/cs/r1cs_sparse.go @@ -133,8 +133,8 @@ func (cs *SparseR1CS) computeHints(c compiled.SparseR1C, solution *solution) (in if (c.L.CoeffID() != 0 || c.M[0].CoeffID() != 0) && !solution.solved[lID] { // check if it's a hint - if hID, ok := cs.MHints[lID]; ok { - if err := solution.solveHint(cs.Hints[hID], lID); err != nil { + if hint, ok := cs.MHints[lID]; ok { + if err := solution.solveWithHint(lID, hint); err != nil { return -1, err } } else { @@ -145,8 +145,8 @@ func (cs *SparseR1CS) computeHints(c compiled.SparseR1C, solution *solution) (in if (c.R.CoeffID() != 0 || c.M[1].CoeffID() != 0) && !solution.solved[rID] { // check if it's a hint - if hID, ok := cs.MHints[rID]; ok { - if err := solution.solveHint(cs.Hints[hID], rID); err != nil { + if hint, ok := cs.MHints[rID]; ok { + if err := solution.solveWithHint(rID, hint); err != nil { return -1, err } } else { @@ -156,8 +156,8 @@ func (cs *SparseR1CS) computeHints(c compiled.SparseR1C, solution *solution) (in if (c.O.CoeffID() != 0) && !solution.solved[oID] { // check if it's a hint - if hID, ok := cs.MHints[oID]; ok { - if err := solution.solveHint(cs.Hints[hID], oID); err != nil { + if hint, ok := cs.MHints[oID]; ok { + if err := solution.solveWithHint(oID, hint); err != nil { return -1, err } } else { diff --git a/internal/backend/bls24-315/cs/cs.go b/internal/backend/bls24-315/cs/solution.go similarity index 98% rename from internal/backend/bls24-315/cs/cs.go rename to internal/backend/bls24-315/cs/solution.go index cc0bcbb8cd..70515ea6ba 100644 --- a/internal/backend/bls24-315/cs/cs.go +++ b/internal/backend/bls24-315/cs/solution.go @@ -105,7 +105,7 @@ func (s *solution) computeTerm(t compiled.Term) fr.Element { } // solveHint compute solution.values[vID] using provided solver hint -func (s *solution) solveHint(h compiled.Hint, vID int) error { +func (s *solution) solveWithHint(vID int, h compiled.Hint) error { // compute values for all inputs. inputs := make([]fr.Element, len(h.Inputs)) diff --git a/internal/backend/bn254/cs/r1cs.go b/internal/backend/bn254/cs/r1cs.go index f40f5c1382..c1b0a0a4c3 100644 --- a/internal/backend/bn254/cs/r1cs.go +++ b/internal/backend/bn254/cs/r1cs.go @@ -207,8 +207,8 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) error { } // first we check if this is a hint wire - if hID, ok := cs.MHints[vID]; ok { - return solution.solveHint(cs.Hints[hID], vID) + if hint, ok := cs.MHints[vID]; ok { + return solution.solveWithHint(vID, hint) } if loc != 0 { diff --git a/internal/backend/bn254/cs/r1cs_sparse.go b/internal/backend/bn254/cs/r1cs_sparse.go index 9c11079791..2123014785 100644 --- a/internal/backend/bn254/cs/r1cs_sparse.go +++ b/internal/backend/bn254/cs/r1cs_sparse.go @@ -133,8 +133,8 @@ func (cs *SparseR1CS) computeHints(c compiled.SparseR1C, solution *solution) (in if (c.L.CoeffID() != 0 || c.M[0].CoeffID() != 0) && !solution.solved[lID] { // check if it's a hint - if hID, ok := cs.MHints[lID]; ok { - if err := solution.solveHint(cs.Hints[hID], lID); err != nil { + if hint, ok := cs.MHints[lID]; ok { + if err := solution.solveWithHint(lID, hint); err != nil { return -1, err } } else { @@ -145,8 +145,8 @@ func (cs *SparseR1CS) computeHints(c compiled.SparseR1C, solution *solution) (in if (c.R.CoeffID() != 0 || c.M[1].CoeffID() != 0) && !solution.solved[rID] { // check if it's a hint - if hID, ok := cs.MHints[rID]; ok { - if err := solution.solveHint(cs.Hints[hID], rID); err != nil { + if hint, ok := cs.MHints[rID]; ok { + if err := solution.solveWithHint(rID, hint); err != nil { return -1, err } } else { @@ -156,8 +156,8 @@ func (cs *SparseR1CS) computeHints(c compiled.SparseR1C, solution *solution) (in if (c.O.CoeffID() != 0) && !solution.solved[oID] { // check if it's a hint - if hID, ok := cs.MHints[oID]; ok { - if err := solution.solveHint(cs.Hints[hID], oID); err != nil { + if hint, ok := cs.MHints[oID]; ok { + if err := solution.solveWithHint(oID, hint); err != nil { return -1, err } } else { diff --git a/internal/backend/bn254/cs/cs.go b/internal/backend/bn254/cs/solution.go similarity index 98% rename from internal/backend/bn254/cs/cs.go rename to internal/backend/bn254/cs/solution.go index 1a25e1881f..07bc4b921a 100644 --- a/internal/backend/bn254/cs/cs.go +++ b/internal/backend/bn254/cs/solution.go @@ -105,7 +105,7 @@ func (s *solution) computeTerm(t compiled.Term) fr.Element { } // solveHint compute solution.values[vID] using provided solver hint -func (s *solution) solveHint(h compiled.Hint, vID int) error { +func (s *solution) solveWithHint(vID int, h compiled.Hint) error { // compute values for all inputs. inputs := make([]fr.Element, len(h.Inputs)) diff --git a/internal/backend/bw6-761/cs/r1cs.go b/internal/backend/bw6-761/cs/r1cs.go index 9cdd80224c..96c3650bd6 100644 --- a/internal/backend/bw6-761/cs/r1cs.go +++ b/internal/backend/bw6-761/cs/r1cs.go @@ -207,8 +207,8 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) error { } // first we check if this is a hint wire - if hID, ok := cs.MHints[vID]; ok { - return solution.solveHint(cs.Hints[hID], vID) + if hint, ok := cs.MHints[vID]; ok { + return solution.solveWithHint(vID, hint) } if loc != 0 { diff --git a/internal/backend/bw6-761/cs/r1cs_sparse.go b/internal/backend/bw6-761/cs/r1cs_sparse.go index 96494976e8..9137ee9f5d 100644 --- a/internal/backend/bw6-761/cs/r1cs_sparse.go +++ b/internal/backend/bw6-761/cs/r1cs_sparse.go @@ -133,8 +133,8 @@ func (cs *SparseR1CS) computeHints(c compiled.SparseR1C, solution *solution) (in if (c.L.CoeffID() != 0 || c.M[0].CoeffID() != 0) && !solution.solved[lID] { // check if it's a hint - if hID, ok := cs.MHints[lID]; ok { - if err := solution.solveHint(cs.Hints[hID], lID); err != nil { + if hint, ok := cs.MHints[lID]; ok { + if err := solution.solveWithHint(lID, hint); err != nil { return -1, err } } else { @@ -145,8 +145,8 @@ func (cs *SparseR1CS) computeHints(c compiled.SparseR1C, solution *solution) (in if (c.R.CoeffID() != 0 || c.M[1].CoeffID() != 0) && !solution.solved[rID] { // check if it's a hint - if hID, ok := cs.MHints[rID]; ok { - if err := solution.solveHint(cs.Hints[hID], rID); err != nil { + if hint, ok := cs.MHints[rID]; ok { + if err := solution.solveWithHint(rID, hint); err != nil { return -1, err } } else { @@ -156,8 +156,8 @@ func (cs *SparseR1CS) computeHints(c compiled.SparseR1C, solution *solution) (in if (c.O.CoeffID() != 0) && !solution.solved[oID] { // check if it's a hint - if hID, ok := cs.MHints[oID]; ok { - if err := solution.solveHint(cs.Hints[hID], oID); err != nil { + if hint, ok := cs.MHints[oID]; ok { + if err := solution.solveWithHint(oID, hint); err != nil { return -1, err } } else { diff --git a/internal/backend/bw6-761/cs/cs.go b/internal/backend/bw6-761/cs/solution.go similarity index 98% rename from internal/backend/bw6-761/cs/cs.go rename to internal/backend/bw6-761/cs/solution.go index c545fa7910..ebb1d6a669 100644 --- a/internal/backend/bw6-761/cs/cs.go +++ b/internal/backend/bw6-761/cs/solution.go @@ -105,7 +105,7 @@ func (s *solution) computeTerm(t compiled.Term) fr.Element { } // solveHint compute solution.values[vID] using provided solver hint -func (s *solution) solveHint(h compiled.Hint, vID int) error { +func (s *solution) solveWithHint(vID int, h compiled.Hint) error { // compute values for all inputs. inputs := make([]fr.Element, len(h.Inputs)) diff --git a/internal/backend/compiled/cs.go b/internal/backend/compiled/cs.go index 89688c1642..92e5cbb94a 100644 --- a/internal/backend/compiled/cs.go +++ b/internal/backend/compiled/cs.go @@ -21,13 +21,12 @@ type CS struct { // results in an unsolved constraint DebugInfo []LogEntry - // hints - Hints []Hint - - // maps wire id to hint id - MHints map[int]int + // maps wire id to hint + // a wire may point to at most one hint + MHints map[int]Hint // maps constraint id to debugInfo id + // several constraints may point to the same debug info MDebug map[int]int } @@ -47,7 +46,6 @@ const ( // it enables the solver to compute a Wire with a function provided at solving time // using pre-defined inputs type Hint struct { - WireID int // resulting wire ID to compute ID hint.ID // hint function id Inputs []LinearExpression // terms to inject in the hint function } diff --git a/internal/generator/backend/main.go b/internal/generator/backend/main.go index 5b406b972c..90e3562a8d 100644 --- a/internal/generator/backend/main.go +++ b/internal/generator/backend/main.go @@ -76,7 +76,7 @@ func main() { entries := []bavard.Entry{ {File: filepath.Join(backendCSDir, "r1cs.go"), Templates: []string{"r1cs.go.tmpl", importCurve}}, {File: filepath.Join(backendCSDir, "r1cs_sparse.go"), Templates: []string{"r1cs.sparse.go.tmpl", importCurve}}, - {File: filepath.Join(backendCSDir, "cs.go"), Templates: []string{"cs.go.tmpl", importCurve}}, + {File: filepath.Join(backendCSDir, "solution.go"), Templates: []string{"solution.go.tmpl", importCurve}}, {File: filepath.Join(backendCSDir, "hints.go"), Templates: []string{"hints.go.tmpl", importCurve}}, } if err := bgen.Generate(d, "cs", "./template/representations/", entries...); err != nil { diff --git a/internal/generator/backend/template/representations/r1cs.go.tmpl b/internal/generator/backend/template/representations/r1cs.go.tmpl index 32f335387c..5b7761f857 100644 --- a/internal/generator/backend/template/representations/r1cs.go.tmpl +++ b/internal/generator/backend/template/representations/r1cs.go.tmpl @@ -201,8 +201,8 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) error { } // first we check if this is a hint wire - if hID, ok := cs.MHints[vID];ok { - return solution.solveHint(cs.Hints[hID], vID) + if hint, ok := cs.MHints[vID]; ok { + return solution.solveWithHint(vID, hint) } if loc != 0 { diff --git a/internal/generator/backend/template/representations/r1cs.sparse.go.tmpl b/internal/generator/backend/template/representations/r1cs.sparse.go.tmpl index 821273057d..5421d62b3f 100644 --- a/internal/generator/backend/template/representations/r1cs.sparse.go.tmpl +++ b/internal/generator/backend/template/representations/r1cs.sparse.go.tmpl @@ -123,8 +123,8 @@ func (cs *SparseR1CS) computeHints(c compiled.SparseR1C, solution *solution) ( i if (c.L.CoeffID() != 0 || c.M[0].CoeffID() != 0) && !solution.solved[lID] { // check if it's a hint - if hID, ok := cs.MHints[lID]; ok { - if err := solution.solveHint(cs.Hints[hID],lID); err != nil { + if hint, ok := cs.MHints[lID]; ok { + if err := solution.solveWithHint(lID, hint); err != nil { return -1, err } } else { @@ -135,8 +135,8 @@ func (cs *SparseR1CS) computeHints(c compiled.SparseR1C, solution *solution) ( i if (c.R.CoeffID() != 0 || c.M[1].CoeffID() != 0) && !solution.solved[rID] { // check if it's a hint - if hID, ok := cs.MHints[rID]; ok { - if err := solution.solveHint(cs.Hints[hID],rID); err != nil { + if hint, ok := cs.MHints[rID]; ok { + if err := solution.solveWithHint(rID, hint); err != nil { return -1, err } } else { @@ -146,8 +146,8 @@ func (cs *SparseR1CS) computeHints(c compiled.SparseR1C, solution *solution) ( i if (c.O.CoeffID() != 0) && !solution.solved[oID] { // check if it's a hint - if hID, ok := cs.MHints[oID]; ok { - if err := solution.solveHint(cs.Hints[hID],oID); err != nil { + if hint, ok := cs.MHints[oID]; ok { + if err := solution.solveWithHint(oID, hint); err != nil { return -1, err } } else { diff --git a/internal/generator/backend/template/representations/cs.go.tmpl b/internal/generator/backend/template/representations/solution.go.tmpl similarity index 98% rename from internal/generator/backend/template/representations/cs.go.tmpl rename to internal/generator/backend/template/representations/solution.go.tmpl index c94fd17839..46a1630fac 100644 --- a/internal/generator/backend/template/representations/cs.go.tmpl +++ b/internal/generator/backend/template/representations/solution.go.tmpl @@ -91,7 +91,7 @@ func (s *solution) computeTerm(t compiled.Term) fr.Element { // solveHint compute solution.values[vID] using provided solver hint -func (s *solution) solveHint(h compiled.Hint, vID int) error { +func (s *solution) solveWithHint(vID int, h compiled.Hint) error { // compute values for all inputs. inputs := make([]fr.Element, len(h.Inputs)) From a20efb93aab4a113afe813ad4e6e88e0b3f17e4c Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Fri, 17 Sep 2021 11:53:25 -0500 Subject: [PATCH 08/15] refactor: groth16.Prove and plonk.Prove takes backend.ProverOption as parameter --- backend/backend.go | 52 +++++++++++++++++++ backend/groth16/assert.go | 41 +++++++++------ backend/groth16/groth16.go | 41 ++++++++------- backend/plonk/assert.go | 28 ++++++---- backend/plonk/plonk.go | 40 +++++++------- frontend/cs.go | 3 -- integration_test.go | 12 ++--- internal/backend/bls12-377/cs/r1cs.go | 22 +++----- internal/backend/bls12-377/cs/r1cs_sparse.go | 20 ++++--- internal/backend/bls12-377/cs/r1cs_test.go | 3 -- .../backend/bls12-377/groth16/groth16_test.go | 6 +-- internal/backend/bls12-377/groth16/prove.go | 10 ++-- .../backend/bls12-377/plonk/plonk_test.go | 6 +-- internal/backend/bls12-377/plonk/prove.go | 8 +-- internal/backend/bls12-381/cs/r1cs.go | 22 +++----- internal/backend/bls12-381/cs/r1cs_sparse.go | 20 ++++--- internal/backend/bls12-381/cs/r1cs_test.go | 3 -- .../backend/bls12-381/groth16/groth16_test.go | 6 +-- internal/backend/bls12-381/groth16/prove.go | 10 ++-- .../backend/bls12-381/plonk/plonk_test.go | 6 +-- internal/backend/bls12-381/plonk/prove.go | 8 +-- internal/backend/bls24-315/cs/r1cs.go | 22 +++----- internal/backend/bls24-315/cs/r1cs_sparse.go | 20 ++++--- internal/backend/bls24-315/cs/r1cs_test.go | 3 -- .../backend/bls24-315/groth16/groth16_test.go | 6 +-- internal/backend/bls24-315/groth16/prove.go | 10 ++-- .../backend/bls24-315/plonk/plonk_test.go | 6 +-- internal/backend/bls24-315/plonk/prove.go | 8 +-- internal/backend/bn254/cs/r1cs.go | 22 +++----- internal/backend/bn254/cs/r1cs_sparse.go | 20 ++++--- internal/backend/bn254/cs/r1cs_test.go | 3 -- .../backend/bn254/groth16/groth16_test.go | 6 +-- internal/backend/bn254/groth16/prove.go | 10 ++-- internal/backend/bn254/plonk/plonk_test.go | 6 +-- internal/backend/bn254/plonk/prove.go | 8 +-- internal/backend/bw6-761/cs/r1cs.go | 22 +++----- internal/backend/bw6-761/cs/r1cs_sparse.go | 20 ++++--- internal/backend/bw6-761/cs/r1cs_test.go | 3 -- .../backend/bw6-761/groth16/groth16_test.go | 6 +-- internal/backend/bw6-761/groth16/prove.go | 10 ++-- internal/backend/bw6-761/plonk/plonk_test.go | 6 +-- internal/backend/bw6-761/plonk/prove.go | 8 +-- internal/backend/compiled/cs.go | 3 -- .../template/representations/r1cs.go.tmpl | 21 +++----- .../representations/r1cs.sparse.go.tmpl | 20 ++++--- .../representations/tests/r1cs.go.tmpl | 3 -- .../zkpschemes/groth16/groth16.prove.go.tmpl | 10 ++-- .../zkpschemes/groth16/tests/groth16.go.tmpl | 6 +-- .../zkpschemes/plonk/plonk.prove.go.tmpl | 8 +-- .../zkpschemes/plonk/tests/plonk.go.tmpl | 6 +-- std/groth16/verifier_test.go | 2 +- 51 files changed, 341 insertions(+), 330 deletions(-) diff --git a/backend/backend.go b/backend/backend.go index 0873e69c24..a9f3cc15f6 100644 --- a/backend/backend.go +++ b/backend/backend.go @@ -15,6 +15,13 @@ // Package backend implements Zero Knowledge Proof systems: it consumes circuit compiled with gnark/frontend. package backend +import ( + "io" + "os" + + "github.com/consensys/gnark/backend/hint" +) + // ID represent a unique ID for a proving scheme type ID uint16 @@ -40,3 +47,48 @@ func (id ID) String() string { return "unknown" } } + +// NewProverOption returns a default ProverOption with given options applied +func NewProverOption(opts ...func(opt *ProverOption) error) (ProverOption, error) { + opt := ProverOption{LoggerOut: os.Stdout} + for _, option := range opts { + if err := option(&opt); err != nil { + return ProverOption{}, err + } + } + return opt, nil +} + +// ProverOption is shared accross backends to parametrize calls to xxx.Prove(...) +type ProverOption struct { + Force bool // default to false + HintFunctions []hint.Function // default to nil (use only solver std hints) + LoggerOut io.Writer // default to os.Stdout +} + +// IgnoreSolverError is a ProverOption that indicates that the Prove algorithm +// should complete, even if constraint system is not solved. +// In that case, Prove will output an invalid Proof, but will execute all algorithms +// which is useful for test and benchmarking purposes +func IgnoreSolverError(opt *ProverOption) error { + opt.Force = true + return nil +} + +// WithHints is a Prover option that specifies additional hint functions to be used +// by the constraint solver +func WithHints(hintFunctions ...hint.Function) func(opt *ProverOption) error { + return func(opt *ProverOption) error { + opt.HintFunctions = append(opt.HintFunctions, hintFunctions...) + return nil + } +} + +// WithOutput is a Prover option that specifies an io.Writer as destination for logs printed by +// cs.Println(). If set to nil, no logs are printed. +func WithOutput(w io.Writer) func(opt *ProverOption) error { + return func(opt *ProverOption) error { + opt.LoggerOut = w + return nil + } +} diff --git a/backend/groth16/assert.go b/backend/groth16/assert.go index 068f22311d..edda493ccf 100644 --- a/backend/groth16/assert.go +++ b/backend/groth16/assert.go @@ -20,7 +20,7 @@ import ( "reflect" "testing" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" backend_bls12377 "github.com/consensys/gnark/internal/backend/bls12-377/cs" witness_bls12377 "github.com/consensys/gnark/internal/backend/bls12-377/witness" @@ -47,11 +47,11 @@ func NewAssert(t *testing.T) *Assert { } // ProverFailed check that a witness does NOT solve a circuit -func (assert *Assert) ProverFailed(r1cs frontend.CompiledConstraintSystem, witness frontend.Circuit, hintFunctions ...hint.Function) { +func (assert *Assert) ProverFailed(r1cs frontend.CompiledConstraintSystem, witness frontend.Circuit, opts ...func(opt *backend.ProverOption) error) { // setup pk, err := DummySetup(r1cs) assert.NoError(err) - _, err = Prove(r1cs, pk, witness, hintFunctions) + _, err = Prove(r1cs, pk, witness, opts...) assert.Error(err, "proving with bad witness should output an error") } @@ -68,7 +68,7 @@ func (assert *Assert) ProverFailed(r1cs frontend.CompiledConstraintSystem, witne // 5. Ensure deserialization(serialization) of generated objects is correct // // ensure result vectors a*b=c, and check other properties like random sampling -func (assert *Assert) ProverSucceeded(r1cs frontend.CompiledConstraintSystem, witness frontend.Circuit, hintFunctions ...hint.Function) { +func (assert *Assert) ProverSucceeded(r1cs frontend.CompiledConstraintSystem, witness frontend.Circuit, opts ...func(opt *backend.ProverOption) error) { // setup pk, vk, err := Setup(r1cs) assert.NoError(err) @@ -84,17 +84,17 @@ func (assert *Assert) ProverSucceeded(r1cs frontend.CompiledConstraintSystem, wi } // ensure expected Values are computed correctly - assert.SolvingSucceeded(r1cs, witness, hintFunctions...) + assert.SolvingSucceeded(r1cs, witness, opts...) // extract full witness & public witness // prover - proof, err := Prove(r1cs, pk, witness, hintFunctions) + proof, err := Prove(r1cs, pk, witness, opts...) assert.NoError(err, "proving with good witness should not output an error") // ensure random sampling; calling prove twice with same witness should produce different proof { - proof2, err := Prove(r1cs, pk, witness, hintFunctions) + proof2, err := Prove(r1cs, pk, witness, opts...) assert.NoError(err, "proving with good witness should not output an error") assert.False(reflect.DeepEqual(proof, proof2), "calling prove twice with same input should produce different proof") } @@ -130,49 +130,56 @@ func (assert *Assert) SerializationRawSucceeded(from gnarkio.WriterRawTo, to io. } // SolvingSucceeded Verifies that the R1CS is solved with the given witness, without executing groth16 workflow -func (assert *Assert) SolvingSucceeded(r1cs frontend.CompiledConstraintSystem, witness frontend.Circuit, hintFunctions ...hint.Function) { - assert.NoError(IsSolved(r1cs, witness, hintFunctions)) +func (assert *Assert) SolvingSucceeded(r1cs frontend.CompiledConstraintSystem, witness frontend.Circuit, opts ...func(opt *backend.ProverOption) error) { + assert.NoError(IsSolved(r1cs, witness, opts...)) } // SolvingFailed Verifies that the R1CS is not solved with the given witness, without executing groth16 workflow -func (assert *Assert) SolvingFailed(r1cs frontend.CompiledConstraintSystem, witness frontend.Circuit, hintFunctions ...hint.Function) { - assert.Error(IsSolved(r1cs, witness, hintFunctions)) +func (assert *Assert) SolvingFailed(r1cs frontend.CompiledConstraintSystem, witness frontend.Circuit, opts ...func(opt *backend.ProverOption) error) { + assert.Error(IsSolved(r1cs, witness, opts...)) } // IsSolved attempts to solve the constraint system with provided witness // returns nil if it succeeds, error otherwise. -func IsSolved(r1cs frontend.CompiledConstraintSystem, witness frontend.Circuit, hintFunctions []hint.Function) error { +func IsSolved(r1cs frontend.CompiledConstraintSystem, witness frontend.Circuit, opts ...func(opt *backend.ProverOption) error) error { + + // apply options + opt, err := backend.NewProverOption(opts...) + if err != nil { + return err + } + switch _r1cs := r1cs.(type) { case *backend_bls12377.R1CS: w := witness_bls12377.Witness{} if err := w.FromFullAssignment(witness); err != nil { return err } - return _r1cs.IsSolved(w, hintFunctions) + return _r1cs.IsSolved(w, opt) case *backend_bls12381.R1CS: w := witness_bls12381.Witness{} if err := w.FromFullAssignment(witness); err != nil { return err } - return _r1cs.IsSolved(w, hintFunctions) + return _r1cs.IsSolved(w, opt) case *backend_bn254.R1CS: w := witness_bn254.Witness{} if err := w.FromFullAssignment(witness); err != nil { return err } - return _r1cs.IsSolved(w, hintFunctions) + return _r1cs.IsSolved(w, opt) case *backend_bw6761.R1CS: w := witness_bw6761.Witness{} if err := w.FromFullAssignment(witness); err != nil { return err } - return _r1cs.IsSolved(w, hintFunctions) + return _r1cs.IsSolved(w, opt) case *backend_bls24315.R1CS: w := witness_bls24315.Witness{} if err := w.FromFullAssignment(witness); err != nil { return err } - return _r1cs.IsSolved(w, hintFunctions) + return _r1cs.IsSolved(w, opt) default: panic("unrecognized R1CS curve type") } diff --git a/backend/groth16/groth16.go b/backend/groth16/groth16.go index 8985e10eff..d159d523c3 100644 --- a/backend/groth16/groth16.go +++ b/backend/groth16/groth16.go @@ -24,7 +24,7 @@ import ( "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" backend_bls12377 "github.com/consensys/gnark/internal/backend/bls12-377/cs" backend_bls12381 "github.com/consensys/gnark/internal/backend/bls12-381/cs" @@ -188,11 +188,12 @@ func ReadAndVerify(proof Proof, vk VerifyingKey, publicWitness io.Reader) error // will executes all the prover computations, even if the witness is invalid // will produce an invalid proof // internally, the solution vector to the R1CS will be filled with random values which may impact benchmarking -func Prove(r1cs frontend.CompiledConstraintSystem, pk ProvingKey, witness frontend.Circuit, hintFunctions []hint.Function, force ...bool) (Proof, error) { +func Prove(r1cs frontend.CompiledConstraintSystem, pk ProvingKey, witness frontend.Circuit, opts ...func(opt *backend.ProverOption) error) (Proof, error) { - _force := false - if len(force) > 0 { - _force = force[0] + // apply options + opt, err := backend.NewProverOption(opts...) + if err != nil { + return nil, err } switch _r1cs := r1cs.(type) { @@ -201,31 +202,31 @@ func Prove(r1cs frontend.CompiledConstraintSystem, pk ProvingKey, witness fronte if err := w.FromFullAssignment(witness); err != nil { return nil, err } - return groth16_bls12377.Prove(_r1cs, pk.(*groth16_bls12377.ProvingKey), w, hintFunctions, _force) + return groth16_bls12377.Prove(_r1cs, pk.(*groth16_bls12377.ProvingKey), w, opt) case *backend_bls12381.R1CS: w := witness_bls12381.Witness{} if err := w.FromFullAssignment(witness); err != nil { return nil, err } - return groth16_bls12381.Prove(_r1cs, pk.(*groth16_bls12381.ProvingKey), w, hintFunctions, _force) + return groth16_bls12381.Prove(_r1cs, pk.(*groth16_bls12381.ProvingKey), w, opt) case *backend_bn254.R1CS: w := witness_bn254.Witness{} if err := w.FromFullAssignment(witness); err != nil { return nil, err } - return groth16_bn254.Prove(_r1cs, pk.(*groth16_bn254.ProvingKey), w, hintFunctions, _force) + return groth16_bn254.Prove(_r1cs, pk.(*groth16_bn254.ProvingKey), w, opt) case *backend_bw6761.R1CS: w := witness_bw6761.Witness{} if err := w.FromFullAssignment(witness); err != nil { return nil, err } - return groth16_bw6761.Prove(_r1cs, pk.(*groth16_bw6761.ProvingKey), w, hintFunctions, _force) + return groth16_bw6761.Prove(_r1cs, pk.(*groth16_bw6761.ProvingKey), w, opt) case *backend_bls24315.R1CS: w := witness_bls24315.Witness{} if err := w.FromFullAssignment(witness); err != nil { return nil, err } - return groth16_bls24315.Prove(_r1cs, pk.(*groth16_bls24315.ProvingKey), w, hintFunctions, _force) + return groth16_bls24315.Prove(_r1cs, pk.(*groth16_bls24315.ProvingKey), w, opt) default: panic("unrecognized R1CS curve type") } @@ -234,10 +235,12 @@ func Prove(r1cs frontend.CompiledConstraintSystem, pk ProvingKey, witness fronte // ReadAndProve behaves like Prove, , except witness is read from a io.Reader // witness must be encoded following the binary serialization protocol described in // gnark/backend/witness package -func ReadAndProve(r1cs frontend.CompiledConstraintSystem, pk ProvingKey, witness io.Reader, hintFunctions []hint.Function, force ...bool) (Proof, error) { - _force := false - if len(force) > 0 { - _force = force[0] +func ReadAndProve(r1cs frontend.CompiledConstraintSystem, pk ProvingKey, witness io.Reader, opts ...func(opt *backend.ProverOption) error) (Proof, error) { + + // apply options + opt, err := backend.NewProverOption(opts...) + if err != nil { + return nil, err } _, nbSecret, nbPublic := r1cs.GetNbVariables() @@ -249,31 +252,31 @@ func ReadAndProve(r1cs frontend.CompiledConstraintSystem, pk ProvingKey, witness if _, err := w.LimitReadFrom(witness, expectedSize); err != nil { return nil, err } - return groth16_bls12377.Prove(_r1cs, pk.(*groth16_bls12377.ProvingKey), w, hintFunctions, _force) + return groth16_bls12377.Prove(_r1cs, pk.(*groth16_bls12377.ProvingKey), w, opt) case *backend_bls12381.R1CS: w := witness_bls12381.Witness{} if _, err := w.LimitReadFrom(witness, expectedSize); err != nil { return nil, err } - return groth16_bls12381.Prove(_r1cs, pk.(*groth16_bls12381.ProvingKey), w, hintFunctions, _force) + return groth16_bls12381.Prove(_r1cs, pk.(*groth16_bls12381.ProvingKey), w, opt) case *backend_bn254.R1CS: w := witness_bn254.Witness{} if _, err := w.LimitReadFrom(witness, expectedSize); err != nil { return nil, err } - return groth16_bn254.Prove(_r1cs, pk.(*groth16_bn254.ProvingKey), w, hintFunctions, _force) + return groth16_bn254.Prove(_r1cs, pk.(*groth16_bn254.ProvingKey), w, opt) case *backend_bw6761.R1CS: w := witness_bw6761.Witness{} if _, err := w.LimitReadFrom(witness, expectedSize); err != nil { return nil, err } - return groth16_bw6761.Prove(_r1cs, pk.(*groth16_bw6761.ProvingKey), w, hintFunctions, _force) + return groth16_bw6761.Prove(_r1cs, pk.(*groth16_bw6761.ProvingKey), w, opt) case *backend_bls24315.R1CS: w := witness_bls24315.Witness{} if _, err := w.LimitReadFrom(witness, expectedSize); err != nil { return nil, err } - return groth16_bls24315.Prove(_r1cs, pk.(*groth16_bls24315.ProvingKey), w, hintFunctions, _force) + return groth16_bls24315.Prove(_r1cs, pk.(*groth16_bls24315.ProvingKey), w, opt) default: panic("unrecognized R1CS curve type") } diff --git a/backend/plonk/assert.go b/backend/plonk/assert.go index cd60a7d583..6a3419bfd6 100644 --- a/backend/plonk/assert.go +++ b/backend/plonk/assert.go @@ -20,7 +20,7 @@ import ( "github.com/stretchr/testify/require" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" cs_bls12377 "github.com/consensys/gnark/internal/backend/bls12-377/cs" @@ -54,7 +54,7 @@ func NewAssert(t *testing.T) *Assert { return &Assert{require.New(t)} } -func (assert *Assert) ProverSucceeded(ccs frontend.CompiledConstraintSystem, witness frontend.Circuit, hintFunctions ...hint.Function) { +func (assert *Assert) ProverSucceeded(ccs frontend.CompiledConstraintSystem, witness frontend.Circuit, opts ...func(opt *backend.ProverOption) error) { // checks if the system is solvable assert.SolvingSucceeded(ccs, witness) @@ -66,7 +66,7 @@ func (assert *Assert) ProverSucceeded(ccs frontend.CompiledConstraintSystem, wit assert.NoError(err, "Generating public data should not have failed") // generates the proof - proof, err := Prove(ccs, pk, witness, hintFunctions) + proof, err := Prove(ccs, pk, witness, opts...) assert.NoError(err, "Proving with good witness should not output an error") // verifies the proof @@ -75,7 +75,7 @@ func (assert *Assert) ProverSucceeded(ccs frontend.CompiledConstraintSystem, wit } -func (assert *Assert) ProverFailed(ccs frontend.CompiledConstraintSystem, witness frontend.Circuit, hintFunctions ...hint.Function) { +func (assert *Assert) ProverFailed(ccs frontend.CompiledConstraintSystem, witness frontend.Circuit, opts ...func(opt *backend.ProverOption) error) { // generates public data srs, err := newKZGSrs(ccs) @@ -84,7 +84,7 @@ func (assert *Assert) ProverFailed(ccs frontend.CompiledConstraintSystem, witnes assert.NoError(err, "Generating public data should not have failed") // generates the proof - _, err = Prove(ccs, pk, witness, hintFunctions) + _, err = Prove(ccs, pk, witness, opts...) assert.Error(err, "generating an incorrect proof should output an error") } @@ -100,38 +100,44 @@ func (assert *Assert) SolvingFailed(ccs frontend.CompiledConstraintSystem, witne // IsSolved attempts to solve the constraint system with provided witness // returns nil if it succeeds, error otherwise. -func IsSolved(ccs frontend.CompiledConstraintSystem, witness frontend.Circuit, hintFunctions ...hint.Function) error { +func IsSolved(ccs frontend.CompiledConstraintSystem, witness frontend.Circuit, opts ...func(opt *backend.ProverOption) error) error { + + opt, err := backend.NewProverOption(opts...) + if err != nil { + return err + } + switch tccs := ccs.(type) { case *cs_bn254.SparseR1CS: w := witness_bn254.Witness{} if err := w.FromFullAssignment(witness); err != nil { return err } - return tccs.IsSolved(w, hintFunctions) + return tccs.IsSolved(w, opt) case *cs_bls12381.SparseR1CS: w := witness_bls12381.Witness{} if err := w.FromFullAssignment(witness); err != nil { return err } - return tccs.IsSolved(w, hintFunctions) + return tccs.IsSolved(w, opt) case *cs_bls12377.SparseR1CS: w := witness_bls12377.Witness{} if err := w.FromFullAssignment(witness); err != nil { return err } - return tccs.IsSolved(w, hintFunctions) + return tccs.IsSolved(w, opt) case *cs_bw6761.SparseR1CS: w := witness_bw6761.Witness{} if err := w.FromFullAssignment(witness); err != nil { return err } - return tccs.IsSolved(w, hintFunctions) + return tccs.IsSolved(w, opt) case *cs_bls24315.SparseR1CS: w := witness_bls24315.Witness{} if err := w.FromFullAssignment(witness); err != nil { return err } - return tccs.IsSolved(w, hintFunctions) + return tccs.IsSolved(w, opt) default: panic("unknown constraint system type") } diff --git a/backend/plonk/plonk.go b/backend/plonk/plonk.go index 0437075fc8..2f456b9d24 100644 --- a/backend/plonk/plonk.go +++ b/backend/plonk/plonk.go @@ -25,7 +25,7 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/kzg" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" cs_bls12377 "github.com/consensys/gnark/internal/backend/bls12-377/cs" @@ -158,11 +158,12 @@ func Setup(ccs frontend.CompiledConstraintSystem, kzgSRS kzg.SRS) (ProvingKey, V // will executes all the prover computations, even if the witness is invalid // will produce an invalid proof // internally, the solution vector to the SparseR1CS will be filled with random values which may impact benchmarking -func Prove(ccs frontend.CompiledConstraintSystem, pk ProvingKey, fullWitness frontend.Circuit, hintFunctions []hint.Function, force ...bool) (Proof, error) { +func Prove(ccs frontend.CompiledConstraintSystem, pk ProvingKey, fullWitness frontend.Circuit, opts ...func(opt *backend.ProverOption) error) (Proof, error) { - _force := false - if len(force) > 0 { - _force = force[0] + // apply options + opt, err := backend.NewProverOption(opts...) + if err != nil { + return nil, err } switch tccs := ccs.(type) { @@ -171,35 +172,35 @@ func Prove(ccs frontend.CompiledConstraintSystem, pk ProvingKey, fullWitness fro if err := w.FromFullAssignment(fullWitness); err != nil { return nil, err } - return plonk_bn254.Prove(tccs, pk.(*plonk_bn254.ProvingKey), w, hintFunctions, _force) + return plonk_bn254.Prove(tccs, pk.(*plonk_bn254.ProvingKey), w, opt) case *cs_bls12381.SparseR1CS: w := witness_bls12381.Witness{} if err := w.FromFullAssignment(fullWitness); err != nil { return nil, err } - return plonk_bls12381.Prove(tccs, pk.(*plonk_bls12381.ProvingKey), w, hintFunctions, _force) + return plonk_bls12381.Prove(tccs, pk.(*plonk_bls12381.ProvingKey), w, opt) case *cs_bls12377.SparseR1CS: w := witness_bls12377.Witness{} if err := w.FromFullAssignment(fullWitness); err != nil { return nil, err } - return plonk_bls12377.Prove(tccs, pk.(*plonk_bls12377.ProvingKey), w, hintFunctions, _force) + return plonk_bls12377.Prove(tccs, pk.(*plonk_bls12377.ProvingKey), w, opt) case *cs_bw6761.SparseR1CS: w := witness_bw6761.Witness{} if err := w.FromFullAssignment(fullWitness); err != nil { return nil, err } - return plonk_bw6761.Prove(tccs, pk.(*plonk_bw6761.ProvingKey), w, hintFunctions, _force) + return plonk_bw6761.Prove(tccs, pk.(*plonk_bw6761.ProvingKey), w, opt) case *cs_bls24315.SparseR1CS: w := witness_bls24315.Witness{} if err := w.FromFullAssignment(fullWitness); err != nil { return nil, err } - return plonk_bls24315.Prove(tccs, pk.(*plonk_bls24315.ProvingKey), w, hintFunctions, _force) + return plonk_bls24315.Prove(tccs, pk.(*plonk_bls24315.ProvingKey), w, opt) default: panic("unrecognized SparseR1CS curve type") @@ -339,11 +340,12 @@ func NewVerifyingKey(curveID ecc.ID) VerifyingKey { } // ReadAndProve generates PLONK proof from a circuit, associated proving key, and the full witness -func ReadAndProve(ccs frontend.CompiledConstraintSystem, pk ProvingKey, witness io.Reader, hintFunctions []hint.Function, force ...bool) (Proof, error) { +func ReadAndProve(ccs frontend.CompiledConstraintSystem, pk ProvingKey, witness io.Reader, opts ...func(opt *backend.ProverOption) error) (Proof, error) { - _force := false - if len(force) > 0 { - _force = force[0] + // apply options + opt, err := backend.NewProverOption(opts...) + if err != nil { + return nil, err } _, nbSecret, nbPublic := ccs.GetNbVariables() @@ -356,7 +358,7 @@ func ReadAndProve(ccs frontend.CompiledConstraintSystem, pk ProvingKey, witness if _, err := w.LimitReadFrom(witness, expectedSize); err != nil { return nil, err } - proof, err := plonk_bn254.Prove(tccs, _pk, w, hintFunctions, _force) + proof, err := plonk_bn254.Prove(tccs, _pk, w, opt) if err != nil { return proof, err } @@ -368,7 +370,7 @@ func ReadAndProve(ccs frontend.CompiledConstraintSystem, pk ProvingKey, witness if _, err := w.LimitReadFrom(witness, expectedSize); err != nil { return nil, err } - proof, err := plonk_bls12381.Prove(tccs, _pk, w, hintFunctions, _force) + proof, err := plonk_bls12381.Prove(tccs, _pk, w, opt) if err != nil { return proof, err } @@ -380,7 +382,7 @@ func ReadAndProve(ccs frontend.CompiledConstraintSystem, pk ProvingKey, witness if _, err := w.LimitReadFrom(witness, expectedSize); err != nil { return nil, err } - proof, err := plonk_bls12377.Prove(tccs, _pk, w, hintFunctions, _force) + proof, err := plonk_bls12377.Prove(tccs, _pk, w, opt) if err != nil { return proof, err } @@ -392,7 +394,7 @@ func ReadAndProve(ccs frontend.CompiledConstraintSystem, pk ProvingKey, witness if _, err := w.LimitReadFrom(witness, expectedSize); err != nil { return nil, err } - proof, err := plonk_bw6761.Prove(tccs, _pk, w, hintFunctions, _force) + proof, err := plonk_bw6761.Prove(tccs, _pk, w, opt) if err != nil { return proof, err } @@ -404,7 +406,7 @@ func ReadAndProve(ccs frontend.CompiledConstraintSystem, pk ProvingKey, witness if _, err := w.LimitReadFrom(witness, expectedSize); err != nil { return nil, err } - proof, err := plonk_bls24315.Prove(tccs, _pk, w, hintFunctions, _force) + proof, err := plonk_bls24315.Prove(tccs, _pk, w, opt) if err != nil { return proof, err } diff --git a/frontend/cs.go b/frontend/cs.go index a72161819f..4649c025aa 100644 --- a/frontend/cs.go +++ b/frontend/cs.go @@ -85,9 +85,6 @@ type CompiledConstraintSystem interface { GetNbConstraints() int GetNbCoefficients() int - // SetLoggerOutput replace existing logger output with provided one - SetLoggerOutput(w io.Writer) - CurveID() ecc.ID FrSize() int diff --git a/integration_test.go b/integration_test.go index 777f1f4b2c..36a09b0500 100644 --- a/integration_test.go +++ b/integration_test.go @@ -68,10 +68,10 @@ func TestIntegrationAPI(t *testing.T) { pk, vk, err := groth16.Setup(ccs) assert.NoError(err) - correctProof, err := groth16.Prove(ccs, pk, circuit.Good, nil) + correctProof, err := groth16.Prove(ccs, pk, circuit.Good) assert.NoError(err) - wrongProof, err := groth16.Prove(ccs, pk, circuit.Bad, nil, true) + wrongProof, err := groth16.Prove(ccs, pk, circuit.Bad, backend.IgnoreSolverError) assert.NoError(err) assert.NoError(groth16.Verify(correctProof, vk, circuit.Good)) @@ -84,7 +84,7 @@ func TestIntegrationAPI(t *testing.T) { _, err := witness.WriteFullTo(&buf, curve, circuit.Good) assert.NoError(err) - correctProof, err := groth16.ReadAndProve(ccs, pk, &buf, nil) + correctProof, err := groth16.ReadAndProve(ccs, pk, &buf) assert.NoError(err) buf.Reset() @@ -114,10 +114,10 @@ func TestIntegrationAPI(t *testing.T) { pk, vk, err := plonk.Setup(ccs, srs) assert.NoError(err) - correctProof, err := plonk.Prove(ccs, pk, circuit.Good, nil) + correctProof, err := plonk.Prove(ccs, pk, circuit.Good) assert.NoError(err) - wrongProof, err := plonk.Prove(ccs, pk, circuit.Bad, nil, true) + wrongProof, err := plonk.Prove(ccs, pk, circuit.Bad, backend.IgnoreSolverError) assert.NoError(err) assert.NoError(plonk.Verify(correctProof, vk, circuit.Good)) @@ -130,7 +130,7 @@ func TestIntegrationAPI(t *testing.T) { _, err := witness.WriteFullTo(&buf, curve, circuit.Good) assert.NoError(err) - correctProof, err := plonk.ReadAndProve(ccs, pk, &buf, nil) + correctProof, err := plonk.ReadAndProve(ccs, pk, &buf) assert.NoError(err) buf.Reset() diff --git a/internal/backend/bls12-377/cs/r1cs.go b/internal/backend/bls12-377/cs/r1cs.go index 0db969ba4e..e9b60cec33 100644 --- a/internal/backend/bls12-377/cs/r1cs.go +++ b/internal/backend/bls12-377/cs/r1cs.go @@ -21,12 +21,11 @@ import ( "fmt" "io" "math/big" - "os" "strings" "github.com/fxamacker/cbor/v2" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/backend" "github.com/consensys/gnark/internal/backend/compiled" "github.com/consensys/gnark/internal/backend/ioutils" @@ -40,7 +39,6 @@ import ( type R1CS struct { compiled.R1CS Coefficients []fr.Element // R1C coefficients indexes point here - loggerOut io.Writer } // NewR1CS returns a new R1CS and sets cs.Coefficient (fr.Element) from provided big.Int values @@ -48,7 +46,6 @@ func NewR1CS(cs compiled.R1CS, coefficients []big.Int) *R1CS { r := R1CS{ R1CS: cs, Coefficients: make([]fr.Element, len(coefficients)), - loggerOut: os.Stdout, } for i := 0; i < len(coefficients); i++ { r.Coefficients[i].SetBigInt(&coefficients[i]) @@ -62,10 +59,10 @@ func NewR1CS(cs compiled.R1CS, coefficients []big.Int) *R1CS { // a, b, c vectors: ab-c = hz // witness = [publicWires | secretWires] (without the ONE_WIRE !) // returns [publicWires | secretWires | internalWires ] -func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Function) ([]fr.Element, error) { +func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverOption) ([]fr.Element, error) { nbWires := cs.NbPublicVariables + cs.NbSecretVariables + cs.NbInternalVariables - solution, err := newSolution(nbWires, hintFunctions, cs.Coefficients) + solution, err := newSolution(nbWires, opt.HintFunctions, cs.Coefficients) if err != nil { return make([]fr.Element, nbWires), err } @@ -92,7 +89,7 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi // now that we know all inputs are set, defer log printing once all solution.values are computed // (or sooner, if a constraint is not satisfied) - defer solution.printLogs(cs.loggerOut, cs.Logs) + defer solution.printLogs(opt.LoggerOut, cs.Logs) // check if there is an inconsistant constraint var check fr.Element @@ -133,11 +130,11 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi // IsSolved returns nil if given witness solves the R1CS and error otherwise // this method wraps cs.Solve() and allocates cs.Solve() inputs -func (cs *R1CS) IsSolved(witness []fr.Element, hintFunctions []hint.Function) error { +func (cs *R1CS) IsSolved(witness []fr.Element, opt backend.ProverOption) error { a := make([]fr.Element, len(cs.Constraints)) b := make([]fr.Element, len(cs.Constraints)) c := make([]fr.Element, len(cs.Constraints)) - _, err := cs.Solve(witness, a, b, c, hintFunctions) + _, err := cs.Solve(witness, a, b, c, opt) return err } @@ -366,13 +363,6 @@ func (cs *R1CS) FrSize() int { return fr.Limbs * 8 } -// SetLoggerOutput replace existing logger output with provided one -// default uses os.Stdout -// if nil is provided, logs are not printed -func (cs *R1CS) SetLoggerOutput(w io.Writer) { - cs.loggerOut = w -} - // WriteTo encodes R1CS into provided io.Writer using cbor func (cs *R1CS) WriteTo(w io.Writer) (int64, error) { _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written diff --git a/internal/backend/bls12-377/cs/r1cs_sparse.go b/internal/backend/bls12-377/cs/r1cs_sparse.go index 59f8c4bf43..22421afdde 100644 --- a/internal/backend/bls12-377/cs/r1cs_sparse.go +++ b/internal/backend/bls12-377/cs/r1cs_sparse.go @@ -26,7 +26,7 @@ import ( "strings" "text/template" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/backend" "github.com/consensys/gnark/internal/backend/compiled" "github.com/consensys/gnark/internal/backend/ioutils" @@ -59,7 +59,7 @@ func NewSparseR1CS(ccs compiled.SparseR1CS, coefficients []big.Int) *SparseR1CS // solution.values = [publicInputs | secretInputs | internalVariables ] // witness: contains the input variables // it returns the full slice of wires -func (cs *SparseR1CS) Solve(witness []fr.Element, hintFunctions []hint.Function) ([]fr.Element, error) { +func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverOption) ([]fr.Element, error) { // set the slices holding the solution.values and monitoring which variables have been solved nbVariables := cs.NbInternalVariables + cs.NbSecretVariables + cs.NbPublicVariables @@ -76,7 +76,7 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, hintFunctions []hint.Function) } // keep track of wire that have a value - solution, err := newSolution(nbVariables, hintFunctions, cs.Coefficients) + solution, err := newSolution(nbVariables, opt.HintFunctions, cs.Coefficients) if err != nil { return solution.values, err } @@ -92,8 +92,7 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, hintFunctions []hint.Function) solution.nbSolved += len(witness) // defer log printing once all solution.values are computed - // TODO @gbotrel replace stdout by writer set by user, same as in R1CS - defer solution.printLogs(os.Stdout, cs.Logs) + defer solution.printLogs(opt.LoggerOut, cs.Logs) // batch invert the coefficients to avoid many divisions in the solver coefficientsNegInv := fr.BatchInvert(cs.Coefficients) @@ -225,8 +224,8 @@ func (cs *SparseR1CS) solveConstraint(c compiled.SparseR1C, solution *solution, // IsSolved returns nil if given witness solves the R1CS and error otherwise // this method wraps r1cs.Solve() and allocates r1cs.Solve() inputs -func (cs *SparseR1CS) IsSolved(witness []fr.Element, hintFunctions []hint.Function) error { - _, err := cs.Solve(witness, hintFunctions) +func (cs *SparseR1CS) IsSolved(witness []fr.Element, opt backend.ProverOption) error { + _, err := cs.Solve(witness, opt) return err } @@ -327,3 +326,10 @@ func (cs *SparseR1CS) ReadFrom(r io.Reader) (int64, error) { err = decoder.Decode(cs) return int64(decoder.NumBytesRead()), err } + +// SetLoggerOutput replace existing logger output with provided one +// default uses os.Stdout +// if nil is provided, logs are not printed +func (cs *SparseR1CS) SetLoggerOutput(w io.Writer) { + cs.loggerOut = w +} diff --git a/internal/backend/bls12-377/cs/r1cs_test.go b/internal/backend/bls12-377/cs/r1cs_test.go index f4149a87e9..1409e76e89 100644 --- a/internal/backend/bls12-377/cs/r1cs_test.go +++ b/internal/backend/bls12-377/cs/r1cs_test.go @@ -48,9 +48,6 @@ func TestSerialization(t *testing.T) { t.Fatal(err) } - // no need to serialize. - r1cs.SetLoggerOutput(nil) - r1cs2.SetLoggerOutput(nil) { buffer.Reset() t.Log(name) diff --git a/internal/backend/bls12-377/groth16/groth16_test.go b/internal/backend/bls12-377/groth16/groth16_test.go index a85b69f9df..124493eb33 100644 --- a/internal/backend/bls12-377/groth16/groth16_test.go +++ b/internal/backend/bls12-377/groth16/groth16_test.go @@ -130,7 +130,7 @@ func BenchmarkProver(b *testing.B) { b.ResetTimer() b.Run("prover", func(b *testing.B) { for i := 0; i < b.N; i++ { - _, _ = bls12_377groth16.Prove(r1cs.(*cs.R1CS), &pk, fullWitness, nil, false) + _, _ = bls12_377groth16.Prove(r1cs.(*cs.R1CS), &pk, fullWitness, backend.ProverOption{}) } }) } @@ -151,7 +151,7 @@ func BenchmarkVerifier(b *testing.B) { var pk bls12_377groth16.ProvingKey var vk bls12_377groth16.VerifyingKey bls12_377groth16.Setup(r1cs.(*cs.R1CS), &pk, &vk) - proof, err := bls12_377groth16.Prove(r1cs.(*cs.R1CS), &pk, fullWitness, nil, false) + proof, err := bls12_377groth16.Prove(r1cs.(*cs.R1CS), &pk, fullWitness, backend.ProverOption{}) if err != nil { panic(err) } @@ -175,7 +175,7 @@ func BenchmarkSerialization(b *testing.B) { var pk bls12_377groth16.ProvingKey var vk bls12_377groth16.VerifyingKey bls12_377groth16.Setup(r1cs.(*cs.R1CS), &pk, &vk) - proof, err := bls12_377groth16.Prove(r1cs.(*cs.R1CS), &pk, fullWitness, nil, false) + proof, err := bls12_377groth16.Prove(r1cs.(*cs.R1CS), &pk, fullWitness, backend.ProverOption{}) if err != nil { panic(err) } diff --git a/internal/backend/bls12-377/groth16/prove.go b/internal/backend/bls12-377/groth16/prove.go index 37b963a409..2bde0d9440 100644 --- a/internal/backend/bls12-377/groth16/prove.go +++ b/internal/backend/bls12-377/groth16/prove.go @@ -27,7 +27,7 @@ import ( "fmt" "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/backend" bls12_377witness "github.com/consensys/gnark/internal/backend/bls12-377/witness" "github.com/consensys/gnark/internal/utils" "math/big" @@ -53,9 +53,7 @@ func (proof *Proof) CurveID() ecc.ID { } // Prove generates the proof of knoweldge of a r1cs with full witness (secret + public part). -// if force flag is set, Prove ignores R1CS solving error (ie invalid witness) and executes -// the FFTs and MultiExponentiations to compute an (invalid) Proof object -func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness bls12_377witness.Witness, hintFunctions []hint.Function, force bool) (*Proof, error) { +func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness bls12_377witness.Witness, opt backend.ProverOption) (*Proof, error) { if len(witness) != int(r1cs.NbPublicVariables-1+r1cs.NbSecretVariables) { return nil, fmt.Errorf("invalid witness size, got %d, expected %d = %d (public - ONE_WIRE) + %d (secret)", len(witness), int(r1cs.NbPublicVariables-1+r1cs.NbSecretVariables), r1cs.NbPublicVariables, r1cs.NbSecretVariables) } @@ -66,8 +64,8 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness bls12_377witness.Witness, hint c := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) var wireValues []fr.Element var err error - if wireValues, err = r1cs.Solve(witness, a, b, c, hintFunctions); err != nil { - if !force { + if wireValues, err = r1cs.Solve(witness, a, b, c, opt); err != nil { + if !opt.Force { return nil, err } else { // we need to fill wireValues with random values else multi exps don't do much diff --git a/internal/backend/bls12-377/plonk/plonk_test.go b/internal/backend/bls12-377/plonk/plonk_test.go index 9d1e8218c3..ae5e8fb660 100644 --- a/internal/backend/bls12-377/plonk/plonk_test.go +++ b/internal/backend/bls12-377/plonk/plonk_test.go @@ -135,7 +135,7 @@ func BenchmarkProver(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - _, err = bls12_377plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness, nil, false) + _, err = bls12_377plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness, backend.ProverOption{}) if err != nil { b.Fatal(err) } @@ -160,7 +160,7 @@ func BenchmarkVerifier(b *testing.B) { b.Fatal(err) } - proof, err := bls12_377plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness, nil, false) + proof, err := bls12_377plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness, backend.ProverOption{}) if err != nil { panic(err) } @@ -184,7 +184,7 @@ func BenchmarkSerialization(b *testing.B) { b.Fatal(err) } - proof, err := bls12_377plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness, nil, false) + proof, err := bls12_377plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness, backend.ProverOption{}) if err != nil { b.Fatal(err) } diff --git a/internal/backend/bls12-377/plonk/prove.go b/internal/backend/bls12-377/plonk/prove.go index 2239a27ad6..c1efe643ef 100644 --- a/internal/backend/bls12-377/plonk/prove.go +++ b/internal/backend/bls12-377/plonk/prove.go @@ -38,7 +38,7 @@ import ( "github.com/consensys/gnark/internal/backend/bls12-377/cs" "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/backend" "github.com/consensys/gnark/internal/utils" ) @@ -60,7 +60,7 @@ type Proof struct { } // Prove from the public data -func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_377witness.Witness, hintFunctions []hint.Function, force bool) (*Proof, error) { +func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_377witness.Witness, opt backend.ProverOption) (*Proof, error) { // pick a hash function that will be used to derive the challenges hFunc := sha256.New() @@ -74,8 +74,8 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_377witness.Witn // compute the constraint system solution var solution []fr.Element var err error - if solution, err = spr.Solve(fullWitness, hintFunctions); err != nil { - if !force { + if solution, err = spr.Solve(fullWitness, opt); err != nil { + if !opt.Force { return nil, err } else { // we need to fill solution with random values diff --git a/internal/backend/bls12-381/cs/r1cs.go b/internal/backend/bls12-381/cs/r1cs.go index d49ccc53d3..d4fc1a806d 100644 --- a/internal/backend/bls12-381/cs/r1cs.go +++ b/internal/backend/bls12-381/cs/r1cs.go @@ -21,12 +21,11 @@ import ( "fmt" "io" "math/big" - "os" "strings" "github.com/fxamacker/cbor/v2" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/backend" "github.com/consensys/gnark/internal/backend/compiled" "github.com/consensys/gnark/internal/backend/ioutils" @@ -40,7 +39,6 @@ import ( type R1CS struct { compiled.R1CS Coefficients []fr.Element // R1C coefficients indexes point here - loggerOut io.Writer } // NewR1CS returns a new R1CS and sets cs.Coefficient (fr.Element) from provided big.Int values @@ -48,7 +46,6 @@ func NewR1CS(cs compiled.R1CS, coefficients []big.Int) *R1CS { r := R1CS{ R1CS: cs, Coefficients: make([]fr.Element, len(coefficients)), - loggerOut: os.Stdout, } for i := 0; i < len(coefficients); i++ { r.Coefficients[i].SetBigInt(&coefficients[i]) @@ -62,10 +59,10 @@ func NewR1CS(cs compiled.R1CS, coefficients []big.Int) *R1CS { // a, b, c vectors: ab-c = hz // witness = [publicWires | secretWires] (without the ONE_WIRE !) // returns [publicWires | secretWires | internalWires ] -func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Function) ([]fr.Element, error) { +func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverOption) ([]fr.Element, error) { nbWires := cs.NbPublicVariables + cs.NbSecretVariables + cs.NbInternalVariables - solution, err := newSolution(nbWires, hintFunctions, cs.Coefficients) + solution, err := newSolution(nbWires, opt.HintFunctions, cs.Coefficients) if err != nil { return make([]fr.Element, nbWires), err } @@ -92,7 +89,7 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi // now that we know all inputs are set, defer log printing once all solution.values are computed // (or sooner, if a constraint is not satisfied) - defer solution.printLogs(cs.loggerOut, cs.Logs) + defer solution.printLogs(opt.LoggerOut, cs.Logs) // check if there is an inconsistant constraint var check fr.Element @@ -133,11 +130,11 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi // IsSolved returns nil if given witness solves the R1CS and error otherwise // this method wraps cs.Solve() and allocates cs.Solve() inputs -func (cs *R1CS) IsSolved(witness []fr.Element, hintFunctions []hint.Function) error { +func (cs *R1CS) IsSolved(witness []fr.Element, opt backend.ProverOption) error { a := make([]fr.Element, len(cs.Constraints)) b := make([]fr.Element, len(cs.Constraints)) c := make([]fr.Element, len(cs.Constraints)) - _, err := cs.Solve(witness, a, b, c, hintFunctions) + _, err := cs.Solve(witness, a, b, c, opt) return err } @@ -366,13 +363,6 @@ func (cs *R1CS) FrSize() int { return fr.Limbs * 8 } -// SetLoggerOutput replace existing logger output with provided one -// default uses os.Stdout -// if nil is provided, logs are not printed -func (cs *R1CS) SetLoggerOutput(w io.Writer) { - cs.loggerOut = w -} - // WriteTo encodes R1CS into provided io.Writer using cbor func (cs *R1CS) WriteTo(w io.Writer) (int64, error) { _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written diff --git a/internal/backend/bls12-381/cs/r1cs_sparse.go b/internal/backend/bls12-381/cs/r1cs_sparse.go index 6ef64f3ef9..04909b7220 100644 --- a/internal/backend/bls12-381/cs/r1cs_sparse.go +++ b/internal/backend/bls12-381/cs/r1cs_sparse.go @@ -26,7 +26,7 @@ import ( "strings" "text/template" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/backend" "github.com/consensys/gnark/internal/backend/compiled" "github.com/consensys/gnark/internal/backend/ioutils" @@ -59,7 +59,7 @@ func NewSparseR1CS(ccs compiled.SparseR1CS, coefficients []big.Int) *SparseR1CS // solution.values = [publicInputs | secretInputs | internalVariables ] // witness: contains the input variables // it returns the full slice of wires -func (cs *SparseR1CS) Solve(witness []fr.Element, hintFunctions []hint.Function) ([]fr.Element, error) { +func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverOption) ([]fr.Element, error) { // set the slices holding the solution.values and monitoring which variables have been solved nbVariables := cs.NbInternalVariables + cs.NbSecretVariables + cs.NbPublicVariables @@ -76,7 +76,7 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, hintFunctions []hint.Function) } // keep track of wire that have a value - solution, err := newSolution(nbVariables, hintFunctions, cs.Coefficients) + solution, err := newSolution(nbVariables, opt.HintFunctions, cs.Coefficients) if err != nil { return solution.values, err } @@ -92,8 +92,7 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, hintFunctions []hint.Function) solution.nbSolved += len(witness) // defer log printing once all solution.values are computed - // TODO @gbotrel replace stdout by writer set by user, same as in R1CS - defer solution.printLogs(os.Stdout, cs.Logs) + defer solution.printLogs(opt.LoggerOut, cs.Logs) // batch invert the coefficients to avoid many divisions in the solver coefficientsNegInv := fr.BatchInvert(cs.Coefficients) @@ -225,8 +224,8 @@ func (cs *SparseR1CS) solveConstraint(c compiled.SparseR1C, solution *solution, // IsSolved returns nil if given witness solves the R1CS and error otherwise // this method wraps r1cs.Solve() and allocates r1cs.Solve() inputs -func (cs *SparseR1CS) IsSolved(witness []fr.Element, hintFunctions []hint.Function) error { - _, err := cs.Solve(witness, hintFunctions) +func (cs *SparseR1CS) IsSolved(witness []fr.Element, opt backend.ProverOption) error { + _, err := cs.Solve(witness, opt) return err } @@ -327,3 +326,10 @@ func (cs *SparseR1CS) ReadFrom(r io.Reader) (int64, error) { err = decoder.Decode(cs) return int64(decoder.NumBytesRead()), err } + +// SetLoggerOutput replace existing logger output with provided one +// default uses os.Stdout +// if nil is provided, logs are not printed +func (cs *SparseR1CS) SetLoggerOutput(w io.Writer) { + cs.loggerOut = w +} diff --git a/internal/backend/bls12-381/cs/r1cs_test.go b/internal/backend/bls12-381/cs/r1cs_test.go index cf34792a3e..a270c89096 100644 --- a/internal/backend/bls12-381/cs/r1cs_test.go +++ b/internal/backend/bls12-381/cs/r1cs_test.go @@ -48,9 +48,6 @@ func TestSerialization(t *testing.T) { t.Fatal(err) } - // no need to serialize. - r1cs.SetLoggerOutput(nil) - r1cs2.SetLoggerOutput(nil) { buffer.Reset() t.Log(name) diff --git a/internal/backend/bls12-381/groth16/groth16_test.go b/internal/backend/bls12-381/groth16/groth16_test.go index bbdf0274c8..3c948011ed 100644 --- a/internal/backend/bls12-381/groth16/groth16_test.go +++ b/internal/backend/bls12-381/groth16/groth16_test.go @@ -130,7 +130,7 @@ func BenchmarkProver(b *testing.B) { b.ResetTimer() b.Run("prover", func(b *testing.B) { for i := 0; i < b.N; i++ { - _, _ = bls12_381groth16.Prove(r1cs.(*cs.R1CS), &pk, fullWitness, nil, false) + _, _ = bls12_381groth16.Prove(r1cs.(*cs.R1CS), &pk, fullWitness, backend.ProverOption{}) } }) } @@ -151,7 +151,7 @@ func BenchmarkVerifier(b *testing.B) { var pk bls12_381groth16.ProvingKey var vk bls12_381groth16.VerifyingKey bls12_381groth16.Setup(r1cs.(*cs.R1CS), &pk, &vk) - proof, err := bls12_381groth16.Prove(r1cs.(*cs.R1CS), &pk, fullWitness, nil, false) + proof, err := bls12_381groth16.Prove(r1cs.(*cs.R1CS), &pk, fullWitness, backend.ProverOption{}) if err != nil { panic(err) } @@ -175,7 +175,7 @@ func BenchmarkSerialization(b *testing.B) { var pk bls12_381groth16.ProvingKey var vk bls12_381groth16.VerifyingKey bls12_381groth16.Setup(r1cs.(*cs.R1CS), &pk, &vk) - proof, err := bls12_381groth16.Prove(r1cs.(*cs.R1CS), &pk, fullWitness, nil, false) + proof, err := bls12_381groth16.Prove(r1cs.(*cs.R1CS), &pk, fullWitness, backend.ProverOption{}) if err != nil { panic(err) } diff --git a/internal/backend/bls12-381/groth16/prove.go b/internal/backend/bls12-381/groth16/prove.go index 7e60e7d984..eb053866f6 100644 --- a/internal/backend/bls12-381/groth16/prove.go +++ b/internal/backend/bls12-381/groth16/prove.go @@ -27,7 +27,7 @@ import ( "fmt" "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/backend" bls12_381witness "github.com/consensys/gnark/internal/backend/bls12-381/witness" "github.com/consensys/gnark/internal/utils" "math/big" @@ -53,9 +53,7 @@ func (proof *Proof) CurveID() ecc.ID { } // Prove generates the proof of knoweldge of a r1cs with full witness (secret + public part). -// if force flag is set, Prove ignores R1CS solving error (ie invalid witness) and executes -// the FFTs and MultiExponentiations to compute an (invalid) Proof object -func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness bls12_381witness.Witness, hintFunctions []hint.Function, force bool) (*Proof, error) { +func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness bls12_381witness.Witness, opt backend.ProverOption) (*Proof, error) { if len(witness) != int(r1cs.NbPublicVariables-1+r1cs.NbSecretVariables) { return nil, fmt.Errorf("invalid witness size, got %d, expected %d = %d (public - ONE_WIRE) + %d (secret)", len(witness), int(r1cs.NbPublicVariables-1+r1cs.NbSecretVariables), r1cs.NbPublicVariables, r1cs.NbSecretVariables) } @@ -66,8 +64,8 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness bls12_381witness.Witness, hint c := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) var wireValues []fr.Element var err error - if wireValues, err = r1cs.Solve(witness, a, b, c, hintFunctions); err != nil { - if !force { + if wireValues, err = r1cs.Solve(witness, a, b, c, opt); err != nil { + if !opt.Force { return nil, err } else { // we need to fill wireValues with random values else multi exps don't do much diff --git a/internal/backend/bls12-381/plonk/plonk_test.go b/internal/backend/bls12-381/plonk/plonk_test.go index 502e624e51..2dd84f8001 100644 --- a/internal/backend/bls12-381/plonk/plonk_test.go +++ b/internal/backend/bls12-381/plonk/plonk_test.go @@ -135,7 +135,7 @@ func BenchmarkProver(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - _, err = bls12_381plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness, nil, false) + _, err = bls12_381plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness, backend.ProverOption{}) if err != nil { b.Fatal(err) } @@ -160,7 +160,7 @@ func BenchmarkVerifier(b *testing.B) { b.Fatal(err) } - proof, err := bls12_381plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness, nil, false) + proof, err := bls12_381plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness, backend.ProverOption{}) if err != nil { panic(err) } @@ -184,7 +184,7 @@ func BenchmarkSerialization(b *testing.B) { b.Fatal(err) } - proof, err := bls12_381plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness, nil, false) + proof, err := bls12_381plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness, backend.ProverOption{}) if err != nil { b.Fatal(err) } diff --git a/internal/backend/bls12-381/plonk/prove.go b/internal/backend/bls12-381/plonk/prove.go index d84d5cc20a..9ab6a7bdca 100644 --- a/internal/backend/bls12-381/plonk/prove.go +++ b/internal/backend/bls12-381/plonk/prove.go @@ -38,7 +38,7 @@ import ( "github.com/consensys/gnark/internal/backend/bls12-381/cs" "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/backend" "github.com/consensys/gnark/internal/utils" ) @@ -60,7 +60,7 @@ type Proof struct { } // Prove from the public data -func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_381witness.Witness, hintFunctions []hint.Function, force bool) (*Proof, error) { +func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_381witness.Witness, opt backend.ProverOption) (*Proof, error) { // pick a hash function that will be used to derive the challenges hFunc := sha256.New() @@ -74,8 +74,8 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_381witness.Witn // compute the constraint system solution var solution []fr.Element var err error - if solution, err = spr.Solve(fullWitness, hintFunctions); err != nil { - if !force { + if solution, err = spr.Solve(fullWitness, opt); err != nil { + if !opt.Force { return nil, err } else { // we need to fill solution with random values diff --git a/internal/backend/bls24-315/cs/r1cs.go b/internal/backend/bls24-315/cs/r1cs.go index 8f5f6814f5..5943d58eb3 100644 --- a/internal/backend/bls24-315/cs/r1cs.go +++ b/internal/backend/bls24-315/cs/r1cs.go @@ -21,12 +21,11 @@ import ( "fmt" "io" "math/big" - "os" "strings" "github.com/fxamacker/cbor/v2" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/backend" "github.com/consensys/gnark/internal/backend/compiled" "github.com/consensys/gnark/internal/backend/ioutils" @@ -40,7 +39,6 @@ import ( type R1CS struct { compiled.R1CS Coefficients []fr.Element // R1C coefficients indexes point here - loggerOut io.Writer } // NewR1CS returns a new R1CS and sets cs.Coefficient (fr.Element) from provided big.Int values @@ -48,7 +46,6 @@ func NewR1CS(cs compiled.R1CS, coefficients []big.Int) *R1CS { r := R1CS{ R1CS: cs, Coefficients: make([]fr.Element, len(coefficients)), - loggerOut: os.Stdout, } for i := 0; i < len(coefficients); i++ { r.Coefficients[i].SetBigInt(&coefficients[i]) @@ -62,10 +59,10 @@ func NewR1CS(cs compiled.R1CS, coefficients []big.Int) *R1CS { // a, b, c vectors: ab-c = hz // witness = [publicWires | secretWires] (without the ONE_WIRE !) // returns [publicWires | secretWires | internalWires ] -func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Function) ([]fr.Element, error) { +func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverOption) ([]fr.Element, error) { nbWires := cs.NbPublicVariables + cs.NbSecretVariables + cs.NbInternalVariables - solution, err := newSolution(nbWires, hintFunctions, cs.Coefficients) + solution, err := newSolution(nbWires, opt.HintFunctions, cs.Coefficients) if err != nil { return make([]fr.Element, nbWires), err } @@ -92,7 +89,7 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi // now that we know all inputs are set, defer log printing once all solution.values are computed // (or sooner, if a constraint is not satisfied) - defer solution.printLogs(cs.loggerOut, cs.Logs) + defer solution.printLogs(opt.LoggerOut, cs.Logs) // check if there is an inconsistant constraint var check fr.Element @@ -133,11 +130,11 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi // IsSolved returns nil if given witness solves the R1CS and error otherwise // this method wraps cs.Solve() and allocates cs.Solve() inputs -func (cs *R1CS) IsSolved(witness []fr.Element, hintFunctions []hint.Function) error { +func (cs *R1CS) IsSolved(witness []fr.Element, opt backend.ProverOption) error { a := make([]fr.Element, len(cs.Constraints)) b := make([]fr.Element, len(cs.Constraints)) c := make([]fr.Element, len(cs.Constraints)) - _, err := cs.Solve(witness, a, b, c, hintFunctions) + _, err := cs.Solve(witness, a, b, c, opt) return err } @@ -366,13 +363,6 @@ func (cs *R1CS) FrSize() int { return fr.Limbs * 8 } -// SetLoggerOutput replace existing logger output with provided one -// default uses os.Stdout -// if nil is provided, logs are not printed -func (cs *R1CS) SetLoggerOutput(w io.Writer) { - cs.loggerOut = w -} - // WriteTo encodes R1CS into provided io.Writer using cbor func (cs *R1CS) WriteTo(w io.Writer) (int64, error) { _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written diff --git a/internal/backend/bls24-315/cs/r1cs_sparse.go b/internal/backend/bls24-315/cs/r1cs_sparse.go index dea13c2836..a3def050cb 100644 --- a/internal/backend/bls24-315/cs/r1cs_sparse.go +++ b/internal/backend/bls24-315/cs/r1cs_sparse.go @@ -26,7 +26,7 @@ import ( "strings" "text/template" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/backend" "github.com/consensys/gnark/internal/backend/compiled" "github.com/consensys/gnark/internal/backend/ioutils" @@ -59,7 +59,7 @@ func NewSparseR1CS(ccs compiled.SparseR1CS, coefficients []big.Int) *SparseR1CS // solution.values = [publicInputs | secretInputs | internalVariables ] // witness: contains the input variables // it returns the full slice of wires -func (cs *SparseR1CS) Solve(witness []fr.Element, hintFunctions []hint.Function) ([]fr.Element, error) { +func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverOption) ([]fr.Element, error) { // set the slices holding the solution.values and monitoring which variables have been solved nbVariables := cs.NbInternalVariables + cs.NbSecretVariables + cs.NbPublicVariables @@ -76,7 +76,7 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, hintFunctions []hint.Function) } // keep track of wire that have a value - solution, err := newSolution(nbVariables, hintFunctions, cs.Coefficients) + solution, err := newSolution(nbVariables, opt.HintFunctions, cs.Coefficients) if err != nil { return solution.values, err } @@ -92,8 +92,7 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, hintFunctions []hint.Function) solution.nbSolved += len(witness) // defer log printing once all solution.values are computed - // TODO @gbotrel replace stdout by writer set by user, same as in R1CS - defer solution.printLogs(os.Stdout, cs.Logs) + defer solution.printLogs(opt.LoggerOut, cs.Logs) // batch invert the coefficients to avoid many divisions in the solver coefficientsNegInv := fr.BatchInvert(cs.Coefficients) @@ -225,8 +224,8 @@ func (cs *SparseR1CS) solveConstraint(c compiled.SparseR1C, solution *solution, // IsSolved returns nil if given witness solves the R1CS and error otherwise // this method wraps r1cs.Solve() and allocates r1cs.Solve() inputs -func (cs *SparseR1CS) IsSolved(witness []fr.Element, hintFunctions []hint.Function) error { - _, err := cs.Solve(witness, hintFunctions) +func (cs *SparseR1CS) IsSolved(witness []fr.Element, opt backend.ProverOption) error { + _, err := cs.Solve(witness, opt) return err } @@ -327,3 +326,10 @@ func (cs *SparseR1CS) ReadFrom(r io.Reader) (int64, error) { err = decoder.Decode(cs) return int64(decoder.NumBytesRead()), err } + +// SetLoggerOutput replace existing logger output with provided one +// default uses os.Stdout +// if nil is provided, logs are not printed +func (cs *SparseR1CS) SetLoggerOutput(w io.Writer) { + cs.loggerOut = w +} diff --git a/internal/backend/bls24-315/cs/r1cs_test.go b/internal/backend/bls24-315/cs/r1cs_test.go index 51dc21d66c..6608db8102 100644 --- a/internal/backend/bls24-315/cs/r1cs_test.go +++ b/internal/backend/bls24-315/cs/r1cs_test.go @@ -48,9 +48,6 @@ func TestSerialization(t *testing.T) { t.Fatal(err) } - // no need to serialize. - r1cs.SetLoggerOutput(nil) - r1cs2.SetLoggerOutput(nil) { buffer.Reset() t.Log(name) diff --git a/internal/backend/bls24-315/groth16/groth16_test.go b/internal/backend/bls24-315/groth16/groth16_test.go index 9ed77d4fc1..5053d1d88e 100644 --- a/internal/backend/bls24-315/groth16/groth16_test.go +++ b/internal/backend/bls24-315/groth16/groth16_test.go @@ -130,7 +130,7 @@ func BenchmarkProver(b *testing.B) { b.ResetTimer() b.Run("prover", func(b *testing.B) { for i := 0; i < b.N; i++ { - _, _ = bls24_315groth16.Prove(r1cs.(*cs.R1CS), &pk, fullWitness, nil, false) + _, _ = bls24_315groth16.Prove(r1cs.(*cs.R1CS), &pk, fullWitness, backend.ProverOption{}) } }) } @@ -151,7 +151,7 @@ func BenchmarkVerifier(b *testing.B) { var pk bls24_315groth16.ProvingKey var vk bls24_315groth16.VerifyingKey bls24_315groth16.Setup(r1cs.(*cs.R1CS), &pk, &vk) - proof, err := bls24_315groth16.Prove(r1cs.(*cs.R1CS), &pk, fullWitness, nil, false) + proof, err := bls24_315groth16.Prove(r1cs.(*cs.R1CS), &pk, fullWitness, backend.ProverOption{}) if err != nil { panic(err) } @@ -175,7 +175,7 @@ func BenchmarkSerialization(b *testing.B) { var pk bls24_315groth16.ProvingKey var vk bls24_315groth16.VerifyingKey bls24_315groth16.Setup(r1cs.(*cs.R1CS), &pk, &vk) - proof, err := bls24_315groth16.Prove(r1cs.(*cs.R1CS), &pk, fullWitness, nil, false) + proof, err := bls24_315groth16.Prove(r1cs.(*cs.R1CS), &pk, fullWitness, backend.ProverOption{}) if err != nil { panic(err) } diff --git a/internal/backend/bls24-315/groth16/prove.go b/internal/backend/bls24-315/groth16/prove.go index 23da0b0fdc..916be2bf8e 100644 --- a/internal/backend/bls24-315/groth16/prove.go +++ b/internal/backend/bls24-315/groth16/prove.go @@ -27,7 +27,7 @@ import ( "fmt" "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/backend" bls24_315witness "github.com/consensys/gnark/internal/backend/bls24-315/witness" "github.com/consensys/gnark/internal/utils" "math/big" @@ -53,9 +53,7 @@ func (proof *Proof) CurveID() ecc.ID { } // Prove generates the proof of knoweldge of a r1cs with full witness (secret + public part). -// if force flag is set, Prove ignores R1CS solving error (ie invalid witness) and executes -// the FFTs and MultiExponentiations to compute an (invalid) Proof object -func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness bls24_315witness.Witness, hintFunctions []hint.Function, force bool) (*Proof, error) { +func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness bls24_315witness.Witness, opt backend.ProverOption) (*Proof, error) { if len(witness) != int(r1cs.NbPublicVariables-1+r1cs.NbSecretVariables) { return nil, fmt.Errorf("invalid witness size, got %d, expected %d = %d (public - ONE_WIRE) + %d (secret)", len(witness), int(r1cs.NbPublicVariables-1+r1cs.NbSecretVariables), r1cs.NbPublicVariables, r1cs.NbSecretVariables) } @@ -66,8 +64,8 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness bls24_315witness.Witness, hint c := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) var wireValues []fr.Element var err error - if wireValues, err = r1cs.Solve(witness, a, b, c, hintFunctions); err != nil { - if !force { + if wireValues, err = r1cs.Solve(witness, a, b, c, opt); err != nil { + if !opt.Force { return nil, err } else { // we need to fill wireValues with random values else multi exps don't do much diff --git a/internal/backend/bls24-315/plonk/plonk_test.go b/internal/backend/bls24-315/plonk/plonk_test.go index f823c8cb18..94497ead20 100644 --- a/internal/backend/bls24-315/plonk/plonk_test.go +++ b/internal/backend/bls24-315/plonk/plonk_test.go @@ -135,7 +135,7 @@ func BenchmarkProver(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - _, err = bls24_315plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness, nil, false) + _, err = bls24_315plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness, backend.ProverOption{}) if err != nil { b.Fatal(err) } @@ -160,7 +160,7 @@ func BenchmarkVerifier(b *testing.B) { b.Fatal(err) } - proof, err := bls24_315plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness, nil, false) + proof, err := bls24_315plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness, backend.ProverOption{}) if err != nil { panic(err) } @@ -184,7 +184,7 @@ func BenchmarkSerialization(b *testing.B) { b.Fatal(err) } - proof, err := bls24_315plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness, nil, false) + proof, err := bls24_315plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness, backend.ProverOption{}) if err != nil { b.Fatal(err) } diff --git a/internal/backend/bls24-315/plonk/prove.go b/internal/backend/bls24-315/plonk/prove.go index 9ad9b3741e..a772beec84 100644 --- a/internal/backend/bls24-315/plonk/prove.go +++ b/internal/backend/bls24-315/plonk/prove.go @@ -38,7 +38,7 @@ import ( "github.com/consensys/gnark/internal/backend/bls24-315/cs" "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/backend" "github.com/consensys/gnark/internal/utils" ) @@ -60,7 +60,7 @@ type Proof struct { } // Prove from the public data -func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls24_315witness.Witness, hintFunctions []hint.Function, force bool) (*Proof, error) { +func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls24_315witness.Witness, opt backend.ProverOption) (*Proof, error) { // pick a hash function that will be used to derive the challenges hFunc := sha256.New() @@ -74,8 +74,8 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls24_315witness.Witn // compute the constraint system solution var solution []fr.Element var err error - if solution, err = spr.Solve(fullWitness, hintFunctions); err != nil { - if !force { + if solution, err = spr.Solve(fullWitness, opt); err != nil { + if !opt.Force { return nil, err } else { // we need to fill solution with random values diff --git a/internal/backend/bn254/cs/r1cs.go b/internal/backend/bn254/cs/r1cs.go index c1b0a0a4c3..36a9a3dd23 100644 --- a/internal/backend/bn254/cs/r1cs.go +++ b/internal/backend/bn254/cs/r1cs.go @@ -21,12 +21,11 @@ import ( "fmt" "io" "math/big" - "os" "strings" "github.com/fxamacker/cbor/v2" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/backend" "github.com/consensys/gnark/internal/backend/compiled" "github.com/consensys/gnark/internal/backend/ioutils" @@ -40,7 +39,6 @@ import ( type R1CS struct { compiled.R1CS Coefficients []fr.Element // R1C coefficients indexes point here - loggerOut io.Writer } // NewR1CS returns a new R1CS and sets cs.Coefficient (fr.Element) from provided big.Int values @@ -48,7 +46,6 @@ func NewR1CS(cs compiled.R1CS, coefficients []big.Int) *R1CS { r := R1CS{ R1CS: cs, Coefficients: make([]fr.Element, len(coefficients)), - loggerOut: os.Stdout, } for i := 0; i < len(coefficients); i++ { r.Coefficients[i].SetBigInt(&coefficients[i]) @@ -62,10 +59,10 @@ func NewR1CS(cs compiled.R1CS, coefficients []big.Int) *R1CS { // a, b, c vectors: ab-c = hz // witness = [publicWires | secretWires] (without the ONE_WIRE !) // returns [publicWires | secretWires | internalWires ] -func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Function) ([]fr.Element, error) { +func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverOption) ([]fr.Element, error) { nbWires := cs.NbPublicVariables + cs.NbSecretVariables + cs.NbInternalVariables - solution, err := newSolution(nbWires, hintFunctions, cs.Coefficients) + solution, err := newSolution(nbWires, opt.HintFunctions, cs.Coefficients) if err != nil { return make([]fr.Element, nbWires), err } @@ -92,7 +89,7 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi // now that we know all inputs are set, defer log printing once all solution.values are computed // (or sooner, if a constraint is not satisfied) - defer solution.printLogs(cs.loggerOut, cs.Logs) + defer solution.printLogs(opt.LoggerOut, cs.Logs) // check if there is an inconsistant constraint var check fr.Element @@ -133,11 +130,11 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi // IsSolved returns nil if given witness solves the R1CS and error otherwise // this method wraps cs.Solve() and allocates cs.Solve() inputs -func (cs *R1CS) IsSolved(witness []fr.Element, hintFunctions []hint.Function) error { +func (cs *R1CS) IsSolved(witness []fr.Element, opt backend.ProverOption) error { a := make([]fr.Element, len(cs.Constraints)) b := make([]fr.Element, len(cs.Constraints)) c := make([]fr.Element, len(cs.Constraints)) - _, err := cs.Solve(witness, a, b, c, hintFunctions) + _, err := cs.Solve(witness, a, b, c, opt) return err } @@ -366,13 +363,6 @@ func (cs *R1CS) FrSize() int { return fr.Limbs * 8 } -// SetLoggerOutput replace existing logger output with provided one -// default uses os.Stdout -// if nil is provided, logs are not printed -func (cs *R1CS) SetLoggerOutput(w io.Writer) { - cs.loggerOut = w -} - // WriteTo encodes R1CS into provided io.Writer using cbor func (cs *R1CS) WriteTo(w io.Writer) (int64, error) { _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written diff --git a/internal/backend/bn254/cs/r1cs_sparse.go b/internal/backend/bn254/cs/r1cs_sparse.go index 2123014785..4fdc2d163c 100644 --- a/internal/backend/bn254/cs/r1cs_sparse.go +++ b/internal/backend/bn254/cs/r1cs_sparse.go @@ -26,7 +26,7 @@ import ( "strings" "text/template" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/backend" "github.com/consensys/gnark/internal/backend/compiled" "github.com/consensys/gnark/internal/backend/ioutils" @@ -59,7 +59,7 @@ func NewSparseR1CS(ccs compiled.SparseR1CS, coefficients []big.Int) *SparseR1CS // solution.values = [publicInputs | secretInputs | internalVariables ] // witness: contains the input variables // it returns the full slice of wires -func (cs *SparseR1CS) Solve(witness []fr.Element, hintFunctions []hint.Function) ([]fr.Element, error) { +func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverOption) ([]fr.Element, error) { // set the slices holding the solution.values and monitoring which variables have been solved nbVariables := cs.NbInternalVariables + cs.NbSecretVariables + cs.NbPublicVariables @@ -76,7 +76,7 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, hintFunctions []hint.Function) } // keep track of wire that have a value - solution, err := newSolution(nbVariables, hintFunctions, cs.Coefficients) + solution, err := newSolution(nbVariables, opt.HintFunctions, cs.Coefficients) if err != nil { return solution.values, err } @@ -92,8 +92,7 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, hintFunctions []hint.Function) solution.nbSolved += len(witness) // defer log printing once all solution.values are computed - // TODO @gbotrel replace stdout by writer set by user, same as in R1CS - defer solution.printLogs(os.Stdout, cs.Logs) + defer solution.printLogs(opt.LoggerOut, cs.Logs) // batch invert the coefficients to avoid many divisions in the solver coefficientsNegInv := fr.BatchInvert(cs.Coefficients) @@ -225,8 +224,8 @@ func (cs *SparseR1CS) solveConstraint(c compiled.SparseR1C, solution *solution, // IsSolved returns nil if given witness solves the R1CS and error otherwise // this method wraps r1cs.Solve() and allocates r1cs.Solve() inputs -func (cs *SparseR1CS) IsSolved(witness []fr.Element, hintFunctions []hint.Function) error { - _, err := cs.Solve(witness, hintFunctions) +func (cs *SparseR1CS) IsSolved(witness []fr.Element, opt backend.ProverOption) error { + _, err := cs.Solve(witness, opt) return err } @@ -327,3 +326,10 @@ func (cs *SparseR1CS) ReadFrom(r io.Reader) (int64, error) { err = decoder.Decode(cs) return int64(decoder.NumBytesRead()), err } + +// SetLoggerOutput replace existing logger output with provided one +// default uses os.Stdout +// if nil is provided, logs are not printed +func (cs *SparseR1CS) SetLoggerOutput(w io.Writer) { + cs.loggerOut = w +} diff --git a/internal/backend/bn254/cs/r1cs_test.go b/internal/backend/bn254/cs/r1cs_test.go index bdda6cea67..146faec6f6 100644 --- a/internal/backend/bn254/cs/r1cs_test.go +++ b/internal/backend/bn254/cs/r1cs_test.go @@ -48,9 +48,6 @@ func TestSerialization(t *testing.T) { t.Fatal(err) } - // no need to serialize. - r1cs.SetLoggerOutput(nil) - r1cs2.SetLoggerOutput(nil) { buffer.Reset() t.Log(name) diff --git a/internal/backend/bn254/groth16/groth16_test.go b/internal/backend/bn254/groth16/groth16_test.go index 884cd99dfb..a59471e6c0 100644 --- a/internal/backend/bn254/groth16/groth16_test.go +++ b/internal/backend/bn254/groth16/groth16_test.go @@ -130,7 +130,7 @@ func BenchmarkProver(b *testing.B) { b.ResetTimer() b.Run("prover", func(b *testing.B) { for i := 0; i < b.N; i++ { - _, _ = bn254groth16.Prove(r1cs.(*cs.R1CS), &pk, fullWitness, nil, false) + _, _ = bn254groth16.Prove(r1cs.(*cs.R1CS), &pk, fullWitness, backend.ProverOption{}) } }) } @@ -151,7 +151,7 @@ func BenchmarkVerifier(b *testing.B) { var pk bn254groth16.ProvingKey var vk bn254groth16.VerifyingKey bn254groth16.Setup(r1cs.(*cs.R1CS), &pk, &vk) - proof, err := bn254groth16.Prove(r1cs.(*cs.R1CS), &pk, fullWitness, nil, false) + proof, err := bn254groth16.Prove(r1cs.(*cs.R1CS), &pk, fullWitness, backend.ProverOption{}) if err != nil { panic(err) } @@ -175,7 +175,7 @@ func BenchmarkSerialization(b *testing.B) { var pk bn254groth16.ProvingKey var vk bn254groth16.VerifyingKey bn254groth16.Setup(r1cs.(*cs.R1CS), &pk, &vk) - proof, err := bn254groth16.Prove(r1cs.(*cs.R1CS), &pk, fullWitness, nil, false) + proof, err := bn254groth16.Prove(r1cs.(*cs.R1CS), &pk, fullWitness, backend.ProverOption{}) if err != nil { panic(err) } diff --git a/internal/backend/bn254/groth16/prove.go b/internal/backend/bn254/groth16/prove.go index 906ee2280d..5fcbef9f6a 100644 --- a/internal/backend/bn254/groth16/prove.go +++ b/internal/backend/bn254/groth16/prove.go @@ -27,7 +27,7 @@ import ( "fmt" "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/backend" bn254witness "github.com/consensys/gnark/internal/backend/bn254/witness" "github.com/consensys/gnark/internal/utils" "math/big" @@ -53,9 +53,7 @@ func (proof *Proof) CurveID() ecc.ID { } // Prove generates the proof of knoweldge of a r1cs with full witness (secret + public part). -// if force flag is set, Prove ignores R1CS solving error (ie invalid witness) and executes -// the FFTs and MultiExponentiations to compute an (invalid) Proof object -func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness bn254witness.Witness, hintFunctions []hint.Function, force bool) (*Proof, error) { +func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness bn254witness.Witness, opt backend.ProverOption) (*Proof, error) { if len(witness) != int(r1cs.NbPublicVariables-1+r1cs.NbSecretVariables) { return nil, fmt.Errorf("invalid witness size, got %d, expected %d = %d (public - ONE_WIRE) + %d (secret)", len(witness), int(r1cs.NbPublicVariables-1+r1cs.NbSecretVariables), r1cs.NbPublicVariables, r1cs.NbSecretVariables) } @@ -66,8 +64,8 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness bn254witness.Witness, hintFunc c := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) var wireValues []fr.Element var err error - if wireValues, err = r1cs.Solve(witness, a, b, c, hintFunctions); err != nil { - if !force { + if wireValues, err = r1cs.Solve(witness, a, b, c, opt); err != nil { + if !opt.Force { return nil, err } else { // we need to fill wireValues with random values else multi exps don't do much diff --git a/internal/backend/bn254/plonk/plonk_test.go b/internal/backend/bn254/plonk/plonk_test.go index 58a3e3d0cd..6e26a1d314 100644 --- a/internal/backend/bn254/plonk/plonk_test.go +++ b/internal/backend/bn254/plonk/plonk_test.go @@ -135,7 +135,7 @@ func BenchmarkProver(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - _, err = bn254plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness, nil, false) + _, err = bn254plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness, backend.ProverOption{}) if err != nil { b.Fatal(err) } @@ -160,7 +160,7 @@ func BenchmarkVerifier(b *testing.B) { b.Fatal(err) } - proof, err := bn254plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness, nil, false) + proof, err := bn254plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness, backend.ProverOption{}) if err != nil { panic(err) } @@ -184,7 +184,7 @@ func BenchmarkSerialization(b *testing.B) { b.Fatal(err) } - proof, err := bn254plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness, nil, false) + proof, err := bn254plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness, backend.ProverOption{}) if err != nil { b.Fatal(err) } diff --git a/internal/backend/bn254/plonk/prove.go b/internal/backend/bn254/plonk/prove.go index 1c3ccc4e7d..0421b01dee 100644 --- a/internal/backend/bn254/plonk/prove.go +++ b/internal/backend/bn254/plonk/prove.go @@ -38,7 +38,7 @@ import ( "github.com/consensys/gnark/internal/backend/bn254/cs" "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/backend" "github.com/consensys/gnark/internal/utils" ) @@ -60,7 +60,7 @@ type Proof struct { } // Prove from the public data -func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness, hintFunctions []hint.Function, force bool) (*Proof, error) { +func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness, opt backend.ProverOption) (*Proof, error) { // pick a hash function that will be used to derive the challenges hFunc := sha256.New() @@ -74,8 +74,8 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness, // compute the constraint system solution var solution []fr.Element var err error - if solution, err = spr.Solve(fullWitness, hintFunctions); err != nil { - if !force { + if solution, err = spr.Solve(fullWitness, opt); err != nil { + if !opt.Force { return nil, err } else { // we need to fill solution with random values diff --git a/internal/backend/bw6-761/cs/r1cs.go b/internal/backend/bw6-761/cs/r1cs.go index 96c3650bd6..962e7f970e 100644 --- a/internal/backend/bw6-761/cs/r1cs.go +++ b/internal/backend/bw6-761/cs/r1cs.go @@ -21,12 +21,11 @@ import ( "fmt" "io" "math/big" - "os" "strings" "github.com/fxamacker/cbor/v2" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/backend" "github.com/consensys/gnark/internal/backend/compiled" "github.com/consensys/gnark/internal/backend/ioutils" @@ -40,7 +39,6 @@ import ( type R1CS struct { compiled.R1CS Coefficients []fr.Element // R1C coefficients indexes point here - loggerOut io.Writer } // NewR1CS returns a new R1CS and sets cs.Coefficient (fr.Element) from provided big.Int values @@ -48,7 +46,6 @@ func NewR1CS(cs compiled.R1CS, coefficients []big.Int) *R1CS { r := R1CS{ R1CS: cs, Coefficients: make([]fr.Element, len(coefficients)), - loggerOut: os.Stdout, } for i := 0; i < len(coefficients); i++ { r.Coefficients[i].SetBigInt(&coefficients[i]) @@ -62,10 +59,10 @@ func NewR1CS(cs compiled.R1CS, coefficients []big.Int) *R1CS { // a, b, c vectors: ab-c = hz // witness = [publicWires | secretWires] (without the ONE_WIRE !) // returns [publicWires | secretWires | internalWires ] -func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Function) ([]fr.Element, error) { +func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverOption) ([]fr.Element, error) { nbWires := cs.NbPublicVariables + cs.NbSecretVariables + cs.NbInternalVariables - solution, err := newSolution(nbWires, hintFunctions, cs.Coefficients) + solution, err := newSolution(nbWires, opt.HintFunctions, cs.Coefficients) if err != nil { return make([]fr.Element, nbWires), err } @@ -92,7 +89,7 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi // now that we know all inputs are set, defer log printing once all solution.values are computed // (or sooner, if a constraint is not satisfied) - defer solution.printLogs(cs.loggerOut, cs.Logs) + defer solution.printLogs(opt.LoggerOut, cs.Logs) // check if there is an inconsistant constraint var check fr.Element @@ -133,11 +130,11 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi // IsSolved returns nil if given witness solves the R1CS and error otherwise // this method wraps cs.Solve() and allocates cs.Solve() inputs -func (cs *R1CS) IsSolved(witness []fr.Element, hintFunctions []hint.Function) error { +func (cs *R1CS) IsSolved(witness []fr.Element, opt backend.ProverOption) error { a := make([]fr.Element, len(cs.Constraints)) b := make([]fr.Element, len(cs.Constraints)) c := make([]fr.Element, len(cs.Constraints)) - _, err := cs.Solve(witness, a, b, c, hintFunctions) + _, err := cs.Solve(witness, a, b, c, opt) return err } @@ -366,13 +363,6 @@ func (cs *R1CS) FrSize() int { return fr.Limbs * 8 } -// SetLoggerOutput replace existing logger output with provided one -// default uses os.Stdout -// if nil is provided, logs are not printed -func (cs *R1CS) SetLoggerOutput(w io.Writer) { - cs.loggerOut = w -} - // WriteTo encodes R1CS into provided io.Writer using cbor func (cs *R1CS) WriteTo(w io.Writer) (int64, error) { _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written diff --git a/internal/backend/bw6-761/cs/r1cs_sparse.go b/internal/backend/bw6-761/cs/r1cs_sparse.go index 9137ee9f5d..5adf8da3b8 100644 --- a/internal/backend/bw6-761/cs/r1cs_sparse.go +++ b/internal/backend/bw6-761/cs/r1cs_sparse.go @@ -26,7 +26,7 @@ import ( "strings" "text/template" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/backend" "github.com/consensys/gnark/internal/backend/compiled" "github.com/consensys/gnark/internal/backend/ioutils" @@ -59,7 +59,7 @@ func NewSparseR1CS(ccs compiled.SparseR1CS, coefficients []big.Int) *SparseR1CS // solution.values = [publicInputs | secretInputs | internalVariables ] // witness: contains the input variables // it returns the full slice of wires -func (cs *SparseR1CS) Solve(witness []fr.Element, hintFunctions []hint.Function) ([]fr.Element, error) { +func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverOption) ([]fr.Element, error) { // set the slices holding the solution.values and monitoring which variables have been solved nbVariables := cs.NbInternalVariables + cs.NbSecretVariables + cs.NbPublicVariables @@ -76,7 +76,7 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, hintFunctions []hint.Function) } // keep track of wire that have a value - solution, err := newSolution(nbVariables, hintFunctions, cs.Coefficients) + solution, err := newSolution(nbVariables, opt.HintFunctions, cs.Coefficients) if err != nil { return solution.values, err } @@ -92,8 +92,7 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, hintFunctions []hint.Function) solution.nbSolved += len(witness) // defer log printing once all solution.values are computed - // TODO @gbotrel replace stdout by writer set by user, same as in R1CS - defer solution.printLogs(os.Stdout, cs.Logs) + defer solution.printLogs(opt.LoggerOut, cs.Logs) // batch invert the coefficients to avoid many divisions in the solver coefficientsNegInv := fr.BatchInvert(cs.Coefficients) @@ -225,8 +224,8 @@ func (cs *SparseR1CS) solveConstraint(c compiled.SparseR1C, solution *solution, // IsSolved returns nil if given witness solves the R1CS and error otherwise // this method wraps r1cs.Solve() and allocates r1cs.Solve() inputs -func (cs *SparseR1CS) IsSolved(witness []fr.Element, hintFunctions []hint.Function) error { - _, err := cs.Solve(witness, hintFunctions) +func (cs *SparseR1CS) IsSolved(witness []fr.Element, opt backend.ProverOption) error { + _, err := cs.Solve(witness, opt) return err } @@ -327,3 +326,10 @@ func (cs *SparseR1CS) ReadFrom(r io.Reader) (int64, error) { err = decoder.Decode(cs) return int64(decoder.NumBytesRead()), err } + +// SetLoggerOutput replace existing logger output with provided one +// default uses os.Stdout +// if nil is provided, logs are not printed +func (cs *SparseR1CS) SetLoggerOutput(w io.Writer) { + cs.loggerOut = w +} diff --git a/internal/backend/bw6-761/cs/r1cs_test.go b/internal/backend/bw6-761/cs/r1cs_test.go index 5289783da9..7601610373 100644 --- a/internal/backend/bw6-761/cs/r1cs_test.go +++ b/internal/backend/bw6-761/cs/r1cs_test.go @@ -52,9 +52,6 @@ func TestSerialization(t *testing.T) { t.Fatal(err) } - // no need to serialize. - r1cs.SetLoggerOutput(nil) - r1cs2.SetLoggerOutput(nil) { buffer.Reset() t.Log(name) diff --git a/internal/backend/bw6-761/groth16/groth16_test.go b/internal/backend/bw6-761/groth16/groth16_test.go index ca19ed0afe..f84636e32b 100644 --- a/internal/backend/bw6-761/groth16/groth16_test.go +++ b/internal/backend/bw6-761/groth16/groth16_test.go @@ -130,7 +130,7 @@ func BenchmarkProver(b *testing.B) { b.ResetTimer() b.Run("prover", func(b *testing.B) { for i := 0; i < b.N; i++ { - _, _ = bw6_761groth16.Prove(r1cs.(*cs.R1CS), &pk, fullWitness, nil, false) + _, _ = bw6_761groth16.Prove(r1cs.(*cs.R1CS), &pk, fullWitness, backend.ProverOption{}) } }) } @@ -151,7 +151,7 @@ func BenchmarkVerifier(b *testing.B) { var pk bw6_761groth16.ProvingKey var vk bw6_761groth16.VerifyingKey bw6_761groth16.Setup(r1cs.(*cs.R1CS), &pk, &vk) - proof, err := bw6_761groth16.Prove(r1cs.(*cs.R1CS), &pk, fullWitness, nil, false) + proof, err := bw6_761groth16.Prove(r1cs.(*cs.R1CS), &pk, fullWitness, backend.ProverOption{}) if err != nil { panic(err) } @@ -175,7 +175,7 @@ func BenchmarkSerialization(b *testing.B) { var pk bw6_761groth16.ProvingKey var vk bw6_761groth16.VerifyingKey bw6_761groth16.Setup(r1cs.(*cs.R1CS), &pk, &vk) - proof, err := bw6_761groth16.Prove(r1cs.(*cs.R1CS), &pk, fullWitness, nil, false) + proof, err := bw6_761groth16.Prove(r1cs.(*cs.R1CS), &pk, fullWitness, backend.ProverOption{}) if err != nil { panic(err) } diff --git a/internal/backend/bw6-761/groth16/prove.go b/internal/backend/bw6-761/groth16/prove.go index 19cd63abea..c3995877ca 100644 --- a/internal/backend/bw6-761/groth16/prove.go +++ b/internal/backend/bw6-761/groth16/prove.go @@ -27,7 +27,7 @@ import ( "fmt" "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/backend" bw6_761witness "github.com/consensys/gnark/internal/backend/bw6-761/witness" "github.com/consensys/gnark/internal/utils" "math/big" @@ -53,9 +53,7 @@ func (proof *Proof) CurveID() ecc.ID { } // Prove generates the proof of knoweldge of a r1cs with full witness (secret + public part). -// if force flag is set, Prove ignores R1CS solving error (ie invalid witness) and executes -// the FFTs and MultiExponentiations to compute an (invalid) Proof object -func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness bw6_761witness.Witness, hintFunctions []hint.Function, force bool) (*Proof, error) { +func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness bw6_761witness.Witness, opt backend.ProverOption) (*Proof, error) { if len(witness) != int(r1cs.NbPublicVariables-1+r1cs.NbSecretVariables) { return nil, fmt.Errorf("invalid witness size, got %d, expected %d = %d (public - ONE_WIRE) + %d (secret)", len(witness), int(r1cs.NbPublicVariables-1+r1cs.NbSecretVariables), r1cs.NbPublicVariables, r1cs.NbSecretVariables) } @@ -66,8 +64,8 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness bw6_761witness.Witness, hintFu c := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) var wireValues []fr.Element var err error - if wireValues, err = r1cs.Solve(witness, a, b, c, hintFunctions); err != nil { - if !force { + if wireValues, err = r1cs.Solve(witness, a, b, c, opt); err != nil { + if !opt.Force { return nil, err } else { // we need to fill wireValues with random values else multi exps don't do much diff --git a/internal/backend/bw6-761/plonk/plonk_test.go b/internal/backend/bw6-761/plonk/plonk_test.go index 2089e9e37c..b9728c4d75 100644 --- a/internal/backend/bw6-761/plonk/plonk_test.go +++ b/internal/backend/bw6-761/plonk/plonk_test.go @@ -135,7 +135,7 @@ func BenchmarkProver(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - _, err = bw6_761plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness, nil, false) + _, err = bw6_761plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness, backend.ProverOption{}) if err != nil { b.Fatal(err) } @@ -160,7 +160,7 @@ func BenchmarkVerifier(b *testing.B) { b.Fatal(err) } - proof, err := bw6_761plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness, nil, false) + proof, err := bw6_761plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness, backend.ProverOption{}) if err != nil { panic(err) } @@ -184,7 +184,7 @@ func BenchmarkSerialization(b *testing.B) { b.Fatal(err) } - proof, err := bw6_761plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness, nil, false) + proof, err := bw6_761plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness, backend.ProverOption{}) if err != nil { b.Fatal(err) } diff --git a/internal/backend/bw6-761/plonk/prove.go b/internal/backend/bw6-761/plonk/prove.go index ad79236bf8..0a52b7f156 100644 --- a/internal/backend/bw6-761/plonk/prove.go +++ b/internal/backend/bw6-761/plonk/prove.go @@ -38,7 +38,7 @@ import ( "github.com/consensys/gnark/internal/backend/bw6-761/cs" "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/backend" "github.com/consensys/gnark/internal/utils" ) @@ -60,7 +60,7 @@ type Proof struct { } // Prove from the public data -func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_761witness.Witness, hintFunctions []hint.Function, force bool) (*Proof, error) { +func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_761witness.Witness, opt backend.ProverOption) (*Proof, error) { // pick a hash function that will be used to derive the challenges hFunc := sha256.New() @@ -74,8 +74,8 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_761witness.Witnes // compute the constraint system solution var solution []fr.Element var err error - if solution, err = spr.Solve(fullWitness, hintFunctions); err != nil { - if !force { + if solution, err = spr.Solve(fullWitness, opt); err != nil { + if !opt.Force { return nil, err } else { // we need to fill solution with random values diff --git a/internal/backend/compiled/cs.go b/internal/backend/compiled/cs.go index 92e5cbb94a..a8e6c17e6a 100644 --- a/internal/backend/compiled/cs.go +++ b/internal/backend/compiled/cs.go @@ -70,8 +70,5 @@ func (cs *CS) WriteTo(w io.Writer) (n int64, err error) { panic("not implemented // ReadFrom panics func (cs *CS) ReadFrom(r io.Reader) (n int64, err error) { panic("not implemented") } -// SetLoggerOutput panics -func (cs *CS) SetLoggerOutput(w io.Writer) { panic("not implemented") } - // ToHTML panics func (cs *CS) ToHTML(w io.Writer) error { panic("not implemtened") } diff --git a/internal/generator/backend/template/representations/r1cs.go.tmpl b/internal/generator/backend/template/representations/r1cs.go.tmpl index 5b7761f857..16909315e6 100644 --- a/internal/generator/backend/template/representations/r1cs.go.tmpl +++ b/internal/generator/backend/template/representations/r1cs.go.tmpl @@ -2,7 +2,6 @@ import ( "errors" "fmt" "io" - "os" "math/big" "strings" @@ -10,7 +9,7 @@ import ( "github.com/consensys/gnark/internal/backend/ioutils" "github.com/consensys/gnark/internal/backend/compiled" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/backend" "github.com/consensys/gnark-crypto/ecc" "text/template" @@ -23,7 +22,6 @@ import ( type R1CS struct { compiled.R1CS Coefficients []fr.Element // R1C coefficients indexes point here - loggerOut io.Writer } // NewR1CS returns a new R1CS and sets cs.Coefficient (fr.Element) from provided big.Int values @@ -31,7 +29,6 @@ func NewR1CS(cs compiled.R1CS, coefficients []big.Int) *R1CS { r := R1CS{ R1CS: cs, Coefficients: make([]fr.Element, len(coefficients)), - loggerOut: os.Stdout, } for i := 0; i < len(coefficients); i++ { r.Coefficients[i].SetBigInt(&coefficients[i]) @@ -46,10 +43,10 @@ func NewR1CS(cs compiled.R1CS, coefficients []big.Int) *R1CS { // a, b, c vectors: ab-c = hz // witness = [publicWires | secretWires] (without the ONE_WIRE !) // returns [publicWires | secretWires | internalWires ] -func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Function) ([]fr.Element, error) { +func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverOption) ([]fr.Element, error) { nbWires := cs.NbPublicVariables + cs.NbSecretVariables + cs.NbInternalVariables - solution, err := newSolution(nbWires, hintFunctions, cs.Coefficients) + solution, err := newSolution(nbWires, opt.HintFunctions, cs.Coefficients) if err != nil { return make([]fr.Element, nbWires), err } @@ -80,7 +77,7 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi // now that we know all inputs are set, defer log printing once all solution.values are computed // (or sooner, if a constraint is not satisfied) - defer solution.printLogs(cs.loggerOut, cs.Logs) + defer solution.printLogs(opt.LoggerOut, cs.Logs) // check if there is an inconsistant constraint @@ -124,11 +121,11 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi // IsSolved returns nil if given witness solves the R1CS and error otherwise // this method wraps cs.Solve() and allocates cs.Solve() inputs -func (cs *R1CS) IsSolved(witness []fr.Element, hintFunctions []hint.Function) error { +func (cs *R1CS) IsSolved(witness []fr.Element, opt backend.ProverOption) error { a := make([]fr.Element, len(cs.Constraints)) b := make([]fr.Element, len(cs.Constraints)) c := make([]fr.Element, len(cs.Constraints)) - _, err := cs.Solve(witness, a, b, c, hintFunctions) + _, err := cs.Solve(witness, a, b, c, opt) return err } @@ -366,12 +363,6 @@ func (cs *R1CS) FrSize() int { return fr.Limbs * 8 } -// SetLoggerOutput replace existing logger output with provided one -// default uses os.Stdout -// if nil is provided, logs are not printed -func (cs *R1CS) SetLoggerOutput(w io.Writer) { - cs.loggerOut = w -} // WriteTo encodes R1CS into provided io.Writer using cbor func (cs *R1CS) WriteTo(w io.Writer) (int64, error) { diff --git a/internal/generator/backend/template/representations/r1cs.sparse.go.tmpl b/internal/generator/backend/template/representations/r1cs.sparse.go.tmpl index 5421d62b3f..a591ade61e 100644 --- a/internal/generator/backend/template/representations/r1cs.sparse.go.tmpl +++ b/internal/generator/backend/template/representations/r1cs.sparse.go.tmpl @@ -10,7 +10,7 @@ import ( "github.com/consensys/gnark/internal/backend/ioutils" "github.com/consensys/gnark/internal/backend/compiled" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/backend" {{ template "import_fr" . }} ) @@ -42,7 +42,7 @@ func NewSparseR1CS(ccs compiled.SparseR1CS, coefficients []big.Int) *SparseR1CS // solution.values = [publicInputs | secretInputs | internalVariables ] // witness: contains the input variables // it returns the full slice of wires -func (cs *SparseR1CS) Solve(witness []fr.Element, hintFunctions []hint.Function) ([]fr.Element, error) { +func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverOption) ([]fr.Element, error) { // set the slices holding the solution.values and monitoring which variables have been solved @@ -63,7 +63,7 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, hintFunctions []hint.Function) // keep track of wire that have a value - solution, err := newSolution(nbVariables, hintFunctions, cs.Coefficients) + solution, err := newSolution(nbVariables, opt.HintFunctions, cs.Coefficients) if err != nil { return solution.values, err } @@ -79,8 +79,7 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, hintFunctions []hint.Function) solution.nbSolved += len(witness) // defer log printing once all solution.values are computed - // TODO @gbotrel replace stdout by writer set by user, same as in R1CS - defer solution.printLogs(os.Stdout, cs.Logs) + defer solution.printLogs(opt.LoggerOut, cs.Logs) // batch invert the coefficients to avoid many divisions in the solver coefficientsNegInv := fr.BatchInvert(cs.Coefficients) @@ -216,8 +215,8 @@ func (cs *SparseR1CS) solveConstraint(c compiled.SparseR1C, solution *solution, // IsSolved returns nil if given witness solves the R1CS and error otherwise // this method wraps r1cs.Solve() and allocates r1cs.Solve() inputs -func (cs *SparseR1CS) IsSolved(witness []fr.Element, hintFunctions []hint.Function) error { - _, err := cs.Solve(witness, hintFunctions) +func (cs *SparseR1CS) IsSolved(witness []fr.Element, opt backend.ProverOption) error { + _, err := cs.Solve(witness, opt) return err } @@ -320,3 +319,10 @@ func (cs *SparseR1CS) ReadFrom(r io.Reader) (int64, error) { err = decoder.Decode(cs) return int64(decoder.NumBytesRead()), err } + +// SetLoggerOutput replace existing logger output with provided one +// default uses os.Stdout +// if nil is provided, logs are not printed +func (cs *SparseR1CS) SetLoggerOutput(w io.Writer) { + cs.loggerOut = w +} \ No newline at end of file diff --git a/internal/generator/backend/template/representations/tests/r1cs.go.tmpl b/internal/generator/backend/template/representations/tests/r1cs.go.tmpl index fc34031c41..50527d1cff 100644 --- a/internal/generator/backend/template/representations/tests/r1cs.go.tmpl +++ b/internal/generator/backend/template/representations/tests/r1cs.go.tmpl @@ -37,9 +37,6 @@ func TestSerialization(t *testing.T) { t.Fatal(err) } - // no need to serialize. - r1cs.SetLoggerOutput(nil) - r1cs2.SetLoggerOutput(nil) { buffer.Reset() t.Log(name) diff --git a/internal/generator/backend/template/zkpschemes/groth16/groth16.prove.go.tmpl b/internal/generator/backend/template/zkpschemes/groth16/groth16.prove.go.tmpl index 051c7e2706..9b869f5dba 100644 --- a/internal/generator/backend/template/zkpschemes/groth16/groth16.prove.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/groth16/groth16.prove.go.tmpl @@ -9,7 +9,7 @@ import ( "math/big" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/internal/utils" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/backend" ) // Proof represents a Groth16 proof that was encoded with a ProvingKey and can be verified @@ -31,9 +31,7 @@ func (proof *Proof) CurveID() ecc.ID { } // Prove generates the proof of knoweldge of a r1cs with full witness (secret + public part). -// if force flag is set, Prove ignores R1CS solving error (ie invalid witness) and executes -// the FFTs and MultiExponentiations to compute an (invalid) Proof object -func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness {{ toLower .CurveID }}witness.Witness, hintFunctions []hint.Function, force bool) (*Proof, error) { +func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness {{ toLower .CurveID }}witness.Witness, opt backend.ProverOption) (*Proof, error) { if len(witness) != int(r1cs.NbPublicVariables-1+r1cs.NbSecretVariables) { return nil, fmt.Errorf("invalid witness size, got %d, expected %d = %d (public - ONE_WIRE) + %d (secret)", len(witness), int(r1cs.NbPublicVariables-1+r1cs.NbSecretVariables), r1cs.NbPublicVariables, r1cs.NbSecretVariables) } @@ -44,8 +42,8 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness {{ toLower .CurveID }}witness. c := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) var wireValues []fr.Element var err error - if wireValues, err = r1cs.Solve(witness, a, b, c, hintFunctions ); err != nil { - if !force { + if wireValues, err = r1cs.Solve(witness, a, b, c, opt); err != nil { + if !opt.Force { return nil, err } else { // we need to fill wireValues with random values else multi exps don't do much diff --git a/internal/generator/backend/template/zkpschemes/groth16/tests/groth16.go.tmpl b/internal/generator/backend/template/zkpschemes/groth16/tests/groth16.go.tmpl index 6fb83f3c5e..ce9947e887 100644 --- a/internal/generator/backend/template/zkpschemes/groth16/tests/groth16.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/groth16/tests/groth16.go.tmpl @@ -109,7 +109,7 @@ func BenchmarkProver(b *testing.B) { b.ResetTimer() b.Run("prover", func(b *testing.B) { for i := 0; i < b.N; i++ { - _, _ = {{toLower .CurveID}}groth16.Prove(r1cs.(*cs.R1CS), &pk, fullWitness, nil, false) + _, _ = {{toLower .CurveID}}groth16.Prove(r1cs.(*cs.R1CS), &pk, fullWitness, backend.ProverOption{}) } }) } @@ -130,7 +130,7 @@ func BenchmarkVerifier(b *testing.B) { var pk {{toLower .CurveID}}groth16.ProvingKey var vk {{toLower .CurveID}}groth16.VerifyingKey {{toLower .CurveID}}groth16.Setup(r1cs.(*cs.R1CS), &pk, &vk) - proof, err := {{toLower .CurveID}}groth16.Prove(r1cs.(*cs.R1CS), &pk, fullWitness,nil, false) + proof, err := {{toLower .CurveID}}groth16.Prove(r1cs.(*cs.R1CS), &pk, fullWitness,backend.ProverOption{}) if err != nil { panic(err) } @@ -156,7 +156,7 @@ func BenchmarkSerialization(b *testing.B) { var pk {{toLower .CurveID}}groth16.ProvingKey var vk {{toLower .CurveID}}groth16.VerifyingKey {{toLower .CurveID}}groth16.Setup(r1cs.(*cs.R1CS), &pk, &vk) - proof, err := {{toLower .CurveID}}groth16.Prove(r1cs.(*cs.R1CS), &pk, fullWitness,nil, false) + proof, err := {{toLower .CurveID}}groth16.Prove(r1cs.(*cs.R1CS), &pk, fullWitness,backend.ProverOption{}) if err != nil { panic(err) } diff --git a/internal/generator/backend/template/zkpschemes/plonk/plonk.prove.go.tmpl b/internal/generator/backend/template/zkpschemes/plonk/plonk.prove.go.tmpl index 20b6fe6088..6a76b0bc3c 100644 --- a/internal/generator/backend/template/zkpschemes/plonk/plonk.prove.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/plonk/plonk.prove.go.tmpl @@ -14,7 +14,7 @@ import ( {{ template "import_backend_cs" . }} "github.com/consensys/gnark/internal/utils" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/backend" "github.com/consensys/gnark-crypto/fiat-shamir" ) @@ -36,7 +36,7 @@ type Proof struct { } // Prove from the public data -func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness {{ toLower .CurveID }}witness.Witness, hintFunctions []hint.Function, force bool) (*Proof, error) { +func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness {{ toLower .CurveID }}witness.Witness, opt backend.ProverOption) (*Proof, error) { // pick a hash function that will be used to derive the challenges hFunc := sha256.New() @@ -50,8 +50,8 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness {{ toLower .CurveID } // compute the constraint system solution var solution []fr.Element var err error - if solution, err = spr.Solve(fullWitness, hintFunctions); err != nil { - if !force { + if solution, err = spr.Solve(fullWitness, opt); err != nil { + if !opt.Force { return nil, err } else { // we need to fill solution with random values diff --git a/internal/generator/backend/template/zkpschemes/plonk/tests/plonk.go.tmpl b/internal/generator/backend/template/zkpschemes/plonk/tests/plonk.go.tmpl index 8ce7eaf658..943f5484bf 100644 --- a/internal/generator/backend/template/zkpschemes/plonk/tests/plonk.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/plonk/tests/plonk.go.tmpl @@ -115,7 +115,7 @@ func BenchmarkProver(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - _, err = {{toLower .CurveID}}plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness, nil , false) + _, err = {{toLower .CurveID}}plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness,backend.ProverOption{}) if err != nil { b.Fatal(err) } @@ -140,7 +140,7 @@ func BenchmarkVerifier(b *testing.B) { b.Fatal(err) } - proof, err := {{toLower .CurveID}}plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness, nil , false) + proof, err := {{toLower .CurveID}}plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness, backend.ProverOption{}) if err != nil { panic(err) } @@ -166,7 +166,7 @@ func BenchmarkSerialization(b *testing.B) { b.Fatal(err) } - proof, err := {{toLower .CurveID}}plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness, nil, false ) + proof, err := {{toLower .CurveID}}plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness, backend.ProverOption{} ) if err != nil { b.Fatal(err) } diff --git a/std/groth16/verifier_test.go b/std/groth16/verifier_test.go index 94ef2bcecd..c9f3b034c0 100644 --- a/std/groth16/verifier_test.go +++ b/std/groth16/verifier_test.go @@ -80,7 +80,7 @@ func generateBls377InnerProof(t *testing.T, vk *groth16_bls12377.VerifyingKey, p // generate the data to return for the bls12377 proof var pk groth16_bls12377.ProvingKey groth16_bls12377.Setup(r1cs.(*backend_bls12377.R1CS), &pk, vk) - _proof, err := groth16_bls12377.Prove(r1cs.(*backend_bls12377.R1CS), &pk, correctAssignment, nil, false) + _proof, err := groth16_bls12377.Prove(r1cs.(*backend_bls12377.R1CS), &pk, correctAssignment, backend.ProverOption{}) if err != nil { t.Fatal(err) } From 996b46def49f00a0184b2efaff77811f68c69f4e Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Fri, 17 Sep 2021 11:55:21 -0500 Subject: [PATCH 09/15] fix: comment fuzz test --- backend/groth16/fuzz.go | 57 +++++++++++++++++++++-------------------- 1 file changed, 29 insertions(+), 28 deletions(-) diff --git a/backend/groth16/fuzz.go b/backend/groth16/fuzz.go index 093131059b..d761b4b03c 100644 --- a/backend/groth16/fuzz.go +++ b/backend/groth16/fuzz.go @@ -6,39 +6,40 @@ package groth16 import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/frontend" - backend_bls12381 "github.com/consensys/gnark/internal/backend/bls12-381/cs" - witness_bls12381 "github.com/consensys/gnark/internal/backend/bls12-381/witness" - backend_bn254 "github.com/consensys/gnark/internal/backend/bn254/cs" - witness_bn254 "github.com/consensys/gnark/internal/backend/bn254/witness" + // backend_bls12381 "github.com/consensys/gnark/internal/backend/bls12-381/cs" + // witness_bls12381 "github.com/consensys/gnark/internal/backend/bls12-381/witness" + // backend_bn254 "github.com/consensys/gnark/internal/backend/bn254/cs" + // witness_bn254 "github.com/consensys/gnark/internal/backend/bn254/witness" ) +// TODO FIXME @gbotrel func Fuzz(data []byte) int { curves := []ecc.ID{ecc.BN254, ecc.BLS12_381} for _, curveID := range curves { - ccs := frontend.CsFuzzed(data, curveID) - _, s, p := ccs.GetNbVariables() - wSize := s + p - 1 - ccs.SetLoggerOutput(nil) - switch _r1cs := ccs.(type) { - case *backend_bls12381.R1CS: - w := make(witness_bls12381.Witness, wSize) - // make w random - _ = _r1cs.IsSolved(w, nil) - // TODO FIXME @gbotrel - // if nbAssertions == 0 && err != nil && !strings.Contains(err.Error(), "couldn't solve computational constraint") { - // panic("no assertions, yet solving resulted in an error.") - // } - case *backend_bn254.R1CS: - w := make(witness_bn254.Witness, wSize) - // make w random - _ = _r1cs.IsSolved(w, nil) - // TODO FIXME @gbotrel - // if nbAssertions == 0 && err != nil && !strings.Contains(err.Error(), "couldn't solve computational constraint") { - // panic("no assertions, yet solving resulted in an error.") - // } - default: - panic("unrecognized R1CS curve type") - } + frontend.CsFuzzed(data, curveID) + // _, s, p := ccs.GetNbVariables() + // wSize := s + p - 1 + // ccs.SetLoggerOutput(nil) + // switch _r1cs := ccs.(type) { + // case *backend_bls12381.R1CS: + // w := make(witness_bls12381.Witness, wSize) + // // make w random + // _ = _r1cs.IsSolved(w, nil) + // // TODO FIXME @gbotrel + // // if nbAssertions == 0 && err != nil && !strings.Contains(err.Error(), "couldn't solve computational constraint") { + // // panic("no assertions, yet solving resulted in an error.") + // // } + // case *backend_bn254.R1CS: + // w := make(witness_bn254.Witness, wSize) + // // make w random + // _ = _r1cs.IsSolved(w, nil) + // // TODO FIXME @gbotrel + // // if nbAssertions == 0 && err != nil && !strings.Contains(err.Error(), "couldn't solve computational constraint") { + // // panic("no assertions, yet solving resulted in an error.") + // // } + // default: + // panic("unrecognized R1CS curve type") } + // } return 1 } From 8279a29c7bbfa3fe920f6bbdfd417aaef0285ac1 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Fri, 17 Sep 2021 13:48:40 -0500 Subject: [PATCH 10/15] test: added non regression for cs.Println and debugInfo traces --- debug_test.go | 202 ++++++++++++++++++++++++++++++++++++++ frontend/cs_api.go | 5 +- frontend/cs_assertions.go | 2 +- frontend/cs_debug.go | 7 ++ go.mod | 2 +- go.sum | 2 + 6 files changed, 216 insertions(+), 4 deletions(-) create mode 100644 debug_test.go diff --git a/debug_test.go b/debug_test.go new file mode 100644 index 0000000000..ed2fdfc443 --- /dev/null +++ b/debug_test.go @@ -0,0 +1,202 @@ +package gnark + +import ( + "bytes" + "math/big" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/backend" + "github.com/consensys/gnark/backend/groth16" + "github.com/consensys/gnark/backend/plonk" + "github.com/consensys/gnark/frontend" + "github.com/stretchr/testify/require" +) + +// ------------------------------------------------------------------------------------------------- +// test println (non regression) +type printlnCircuit struct { + A, B frontend.Variable +} + +func (circuit *printlnCircuit) Define(curveID ecc.ID, cs *frontend.ConstraintSystem) error { + c := cs.Add(circuit.A, circuit.B) + cs.Println(c, "is the addition") + d := cs.Mul(circuit.A, c) + cs.Println(d, new(big.Int).SetInt64(42)) + bs := cs.ToBinary(circuit.B, 10) + cs.Println("bits", bs[3]) + return nil +} + +func TestPrintln(t *testing.T) { + assert := require.New(t) + + var circuit, witness printlnCircuit + witness.A.Assign(2) + witness.B.Assign(11) + + var expected bytes.Buffer + expected.WriteString("debug_test.go:24 13 is the addition\n") + expected.WriteString("debug_test.go:26 26 42\n") + expected.WriteString("debug_test.go:28 bits 1\n") + + { + trace, err := getGroth16Trace(&circuit, &witness) + assert.NoError(err) + assert.Equal(trace, expected.String()) + } + + { + trace, err := getPlonkTrace(&circuit, &witness) + assert.NoError(err) + assert.Equal(trace, expected.String()) + } +} + +// ------------------------------------------------------------------------------------------------- +// Div by 0 +type divBy0Trace struct { + A, B, C frontend.Variable +} + +func (circuit *divBy0Trace) Define(curveID ecc.ID, cs *frontend.ConstraintSystem) error { + d := cs.Add(circuit.B, circuit.C) + cs.Div(circuit.A, d) + return nil +} + +func TestDivBy0(t *testing.T) { + assert := require.New(t) + + var circuit, witness divBy0Trace + witness.A.Assign(2) + witness.B.Assign(-2) + witness.C.Assign(2) + + { + _, err := getGroth16Trace(&circuit, &witness) + assert.Error(err) + assert.Contains(err.Error(), "constraint is not satisfied: [div] 2/(-2 + 2) == 0") + assert.Contains(err.Error(), "gnark.(*divBy0Trace).Define") + assert.Contains(err.Error(), "gnark/debug_test.go:65") + } + + { + _, err := getPlonkTrace(&circuit, &witness) + assert.Error(err) + assert.Contains(err.Error(), "constraint is not satisfied: [div] 2/(-2 + 2) == 0") + assert.Contains(err.Error(), "gnark.(*divBy0Trace).Define") + assert.Contains(err.Error(), "gnark/debug_test.go:65") + } +} + +// ------------------------------------------------------------------------------------------------- +// Not Equal +type notEqualTrace struct { + A, B, C frontend.Variable +} + +func (circuit *notEqualTrace) Define(curveID ecc.ID, cs *frontend.ConstraintSystem) error { + d := cs.Add(circuit.B, circuit.C) + cs.AssertIsEqual(circuit.A, d) + return nil +} + +func TestNotEqual(t *testing.T) { + assert := require.New(t) + + var circuit, witness notEqualTrace + witness.A.Assign(1) + witness.B.Assign(24) + witness.C.Assign(42) + + { + _, err := getGroth16Trace(&circuit, &witness) + assert.Error(err) + assert.Contains(err.Error(), "constraint is not satisfied: [assertIsEqual] 1 == (24 + 42)") + assert.Contains(err.Error(), "gnark.(*notEqualTrace).Define") + assert.Contains(err.Error(), "gnark/debug_test.go:102") + } + + { + _, err := getPlonkTrace(&circuit, &witness) + assert.Error(err) + assert.Contains(err.Error(), "constraint is not satisfied: [assertIsEqual] 1 == (24 + 42)") + assert.Contains(err.Error(), "gnark.(*notEqualTrace).Define") + assert.Contains(err.Error(), "gnark/debug_test.go:102") + } +} + +// ------------------------------------------------------------------------------------------------- +// Not boolean +type notBooleanTrace struct { + A, B, C frontend.Variable +} + +func (circuit *notBooleanTrace) Define(curveID ecc.ID, cs *frontend.ConstraintSystem) error { + d := cs.Add(circuit.B, circuit.C) + cs.AssertIsBoolean(d) + return nil +} + +func TestNotBoolean(t *testing.T) { + assert := require.New(t) + + var circuit, witness notBooleanTrace + witness.A.Assign(1) + witness.B.Assign(24) + witness.C.Assign(42) + + { + _, err := getGroth16Trace(&circuit, &witness) + assert.Error(err) + assert.Contains(err.Error(), "constraint is not satisfied: [assertIsBoolean] (24 + 42) == (0|1)") + assert.Contains(err.Error(), "gnark.(*notBooleanTrace).Define") + assert.Contains(err.Error(), "gnark/debug_test.go:139") + } + + { + _, err := getPlonkTrace(&circuit, &witness) + assert.Error(err) + assert.Contains(err.Error(), "constraint is not satisfied: [assertIsBoolean] (24 + 42) == (0|1)") + assert.Contains(err.Error(), "gnark.(*notBooleanTrace).Define") + assert.Contains(err.Error(), "gnark/debug_test.go:139") + } +} + +func getPlonkTrace(circuit, witness frontend.Circuit) (string, error) { + ccs, err := frontend.Compile(ecc.BN254, backend.PLONK, circuit) + if err != nil { + return "", err + } + + srs, err := plonk.NewSRS(ccs) + if err != nil { + return "", err + } + pk, _, err := plonk.Setup(ccs, srs) + if err != nil { + return "", err + } + + var buf bytes.Buffer + _, err = plonk.Prove(ccs, pk, witness, backend.WithOutput(&buf)) + return buf.String(), err +} + +func getGroth16Trace(circuit, witness frontend.Circuit) (string, error) { + ccs, err := frontend.Compile(ecc.BN254, backend.GROTH16, circuit) + if err != nil { + return "", err + } + + pk, err := groth16.DummySetup(ccs) + if err != nil { + return "", err + } + + var buf bytes.Buffer + _, err = groth16.Prove(ccs, pk, witness, backend.WithOutput(&buf)) + return buf.String(), err +} diff --git a/frontend/cs_api.go b/frontend/cs_api.go index f5199e2547..7b370b875c 100644 --- a/frontend/cs_api.go +++ b/frontend/cs_api.go @@ -185,11 +185,12 @@ func (cs *ConstraintSystem) Mul(i1, i2 interface{}, in ...interface{}) Variable // Inverse returns res = inverse(v) func (cs *ConstraintSystem) Inverse(v Variable) Variable { v.assertIsSet() - debug := cs.addDebugInfo("inverse", v) // allocate resulting variable res := cs.newInternalVariable() + debug := cs.addDebugInfo("inverse", v, "*", res, " == 1") + cs.addConstraint(newR1C(v, res, cs.one()), debug) return res @@ -203,7 +204,7 @@ func (cs *ConstraintSystem) Div(i1, i2 interface{}) Variable { v1 := cs.Constant(i1) v2 := cs.Constant(i2) - debug := cs.addDebugInfo("div", v1, " / ", v2, " != ", res) + debug := cs.addDebugInfo("div", v1, "/", v2, " == ", res) cs.addConstraint(newR1C(v2, res, v1), debug) diff --git a/frontend/cs_assertions.go b/frontend/cs_assertions.go index 6bed794c4e..466af0d664 100644 --- a/frontend/cs_assertions.go +++ b/frontend/cs_assertions.go @@ -41,7 +41,7 @@ func (cs *ConstraintSystem) AssertIsBoolean(v Variable) { if !cs.markBoolean(v) { return // variable is already constrained } - debug := cs.addDebugInfo("assertIsBoolean", v) + debug := cs.addDebugInfo("assertIsBoolean", v, " == (0|1)") // ensure v * (1 - v) == 0 _v := cs.Sub(1, v) diff --git a/frontend/cs_debug.go b/frontend/cs_debug.go index d737ec9027..b4e98d353a 100644 --- a/frontend/cs_debug.go +++ b/frontend/cs_debug.go @@ -22,7 +22,14 @@ func (cs *ConstraintSystem) addDebugInfo(errName string, i ...interface{}) int { for _, _i := range i { switch v := _i.(type) { case Variable: + if len(v.linExp) > 1 { + sbb.WriteString("(") + } debug.WriteLinearExpression(v.linExp, &sbb) + if len(v.linExp) > 1 { + sbb.WriteString(")") + } + case string: sbb.WriteString(v) case int: diff --git a/go.mod b/go.mod index 715a08287d..54d39f486b 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.16 require ( github.com/consensys/bavard v0.1.8-0.20210806153619-fcffe4ffd871 - github.com/consensys/gnark-crypto v0.5.1-0.20210907174324-9721833081d7 + github.com/consensys/gnark-crypto v0.5.1-0.20210917183421-cb36b2c871c0 github.com/davecgh/go-spew v1.1.1 // indirect github.com/fxamacker/cbor/v2 v2.2.0 github.com/kr/pretty v0.2.0 // indirect diff --git a/go.sum b/go.sum index cdee21cff1..541eec72a4 100644 --- a/go.sum +++ b/go.sum @@ -8,6 +8,8 @@ github.com/consensys/gnark-crypto v0.5.1-0.20210907173531-0ae8b5c38618 h1:vnrIRU github.com/consensys/gnark-crypto v0.5.1-0.20210907173531-0ae8b5c38618/go.mod h1:wAZ9dsKCDVTSIy2KVTik+ZF16GUX9qp96mxFBDl9iAQ= github.com/consensys/gnark-crypto v0.5.1-0.20210907174324-9721833081d7 h1:2k7ImGxDTTY2OpiKjnFDfqc/ir8O54qCwUTnobfDbkM= github.com/consensys/gnark-crypto v0.5.1-0.20210907174324-9721833081d7/go.mod h1:wAZ9dsKCDVTSIy2KVTik+ZF16GUX9qp96mxFBDl9iAQ= +github.com/consensys/gnark-crypto v0.5.1-0.20210917183421-cb36b2c871c0 h1:ODfAG0P/XaGvh1JNZM9tzL2MKVaqFdE7FeATcrdrHB0= +github.com/consensys/gnark-crypto v0.5.1-0.20210917183421-cb36b2c871c0/go.mod h1:wAZ9dsKCDVTSIy2KVTik+ZF16GUX9qp96mxFBDl9iAQ= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= From 9e7f3e8ea722cc6cac360a1b5c0504269bc1ede9 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Fri, 17 Sep 2021 13:57:01 -0500 Subject: [PATCH 11/15] fix: ensure frontend.ConstraintSystem is not modified by compile process --- debug_test.go | 12 ++++++------ frontend/cs_to_r1cs.go | 26 +++++++++++++++++++++----- frontend/cs_to_r1cs_sparse.go | 19 ++++++++++++++----- 3 files changed, 41 insertions(+), 16 deletions(-) diff --git a/debug_test.go b/debug_test.go index ed2fdfc443..270ae3fed5 100644 --- a/debug_test.go +++ b/debug_test.go @@ -79,7 +79,7 @@ func TestDivBy0(t *testing.T) { assert.Error(err) assert.Contains(err.Error(), "constraint is not satisfied: [div] 2/(-2 + 2) == 0") assert.Contains(err.Error(), "gnark.(*divBy0Trace).Define") - assert.Contains(err.Error(), "gnark/debug_test.go:65") + assert.Contains(err.Error(), "debug_test.go:65") } { @@ -87,7 +87,7 @@ func TestDivBy0(t *testing.T) { assert.Error(err) assert.Contains(err.Error(), "constraint is not satisfied: [div] 2/(-2 + 2) == 0") assert.Contains(err.Error(), "gnark.(*divBy0Trace).Define") - assert.Contains(err.Error(), "gnark/debug_test.go:65") + assert.Contains(err.Error(), "debug_test.go:65") } } @@ -116,7 +116,7 @@ func TestNotEqual(t *testing.T) { assert.Error(err) assert.Contains(err.Error(), "constraint is not satisfied: [assertIsEqual] 1 == (24 + 42)") assert.Contains(err.Error(), "gnark.(*notEqualTrace).Define") - assert.Contains(err.Error(), "gnark/debug_test.go:102") + assert.Contains(err.Error(), "debug_test.go:102") } { @@ -124,7 +124,7 @@ func TestNotEqual(t *testing.T) { assert.Error(err) assert.Contains(err.Error(), "constraint is not satisfied: [assertIsEqual] 1 == (24 + 42)") assert.Contains(err.Error(), "gnark.(*notEqualTrace).Define") - assert.Contains(err.Error(), "gnark/debug_test.go:102") + assert.Contains(err.Error(), "debug_test.go:102") } } @@ -153,7 +153,7 @@ func TestNotBoolean(t *testing.T) { assert.Error(err) assert.Contains(err.Error(), "constraint is not satisfied: [assertIsBoolean] (24 + 42) == (0|1)") assert.Contains(err.Error(), "gnark.(*notBooleanTrace).Define") - assert.Contains(err.Error(), "gnark/debug_test.go:139") + assert.Contains(err.Error(), "debug_test.go:139") } { @@ -161,7 +161,7 @@ func TestNotBoolean(t *testing.T) { assert.Error(err) assert.Contains(err.Error(), "constraint is not satisfied: [assertIsBoolean] (24 + 42) == (0|1)") assert.Contains(err.Error(), "gnark.(*notBooleanTrace).Define") - assert.Contains(err.Error(), "gnark/debug_test.go:139") + assert.Contains(err.Error(), "debug_test.go:139") } } diff --git a/frontend/cs_to_r1cs.go b/frontend/cs_to_r1cs.go index 722ecbbfdf..42913c0fbc 100644 --- a/frontend/cs_to_r1cs.go +++ b/frontend/cs_to_r1cs.go @@ -33,11 +33,15 @@ func (cs *ConstraintSystem) toR1CS(curveID ecc.ID) (CompiledConstraintSystem, er // for logs, debugInfo and hints the only thing that will change // is that ID of the wires will be offseted to take into account the final wire vector ordering // that is: public wires | secret wires | internal wires - copy(res.Logs, cs.logs) - copy(res.DebugInfo, cs.debugInfo) // computational constraints (= gates) - copy(res.Constraints, cs.constraints) + for i, r1c := range cs.constraints { + res.Constraints[i] = compiled.R1C{ + L: r1c.L.Clone(), + R: r1c.R.Clone(), + O: r1c.O.Clone(), + } + } // for a R1CS, the correspondance between constraint and debug info won't change, we just copy for k, v := range cs.mDebug { @@ -83,13 +87,25 @@ func (cs *ConstraintSystem) toR1CS(curveID ecc.ID) (CompiledConstraintSystem, er } // we need to offset the ids in logs & debugInfo - for i := 0; i < len(res.Logs); i++ { + for i := 0; i < len(cs.logs); i++ { + res.Logs[i] = compiled.LogEntry{ + Format: cs.logs[i].Format, + ToResolve: make([]compiled.Term, len(cs.logs[i].ToResolve)), + } + copy(res.Logs[i].ToResolve, cs.logs[i].ToResolve) + for j := 0; j < len(res.Logs[i].ToResolve); j++ { _, vID, visibility := res.Logs[i].ToResolve[j].Unpack() res.Logs[i].ToResolve[j].SetVariableID(shiftVID(vID, visibility)) } } - for i := 0; i < len(res.DebugInfo); i++ { + for i := 0; i < len(cs.debugInfo); i++ { + res.DebugInfo[i] = compiled.LogEntry{ + Format: cs.debugInfo[i].Format, + ToResolve: make([]compiled.Term, len(cs.debugInfo[i].ToResolve)), + } + copy(res.DebugInfo[i].ToResolve, cs.debugInfo[i].ToResolve) + for j := 0; j < len(res.DebugInfo[i].ToResolve); j++ { _, vID, visibility := res.DebugInfo[i].ToResolve[j].Unpack() res.DebugInfo[i].ToResolve[j].SetVariableID(shiftVID(vID, visibility)) diff --git a/frontend/cs_to_r1cs_sparse.go b/frontend/cs_to_r1cs_sparse.go index a923679904..68f2d963d3 100644 --- a/frontend/cs_to_r1cs_sparse.go +++ b/frontend/cs_to_r1cs_sparse.go @@ -76,8 +76,6 @@ func (cs *ConstraintSystem) toSparseR1CS(curveID ecc.ID) (CompiledConstraintSyst // logs, debugInfo and hints are copied, the only thing that will change // is that ID of the wires will be offseted to take into account the final wire vector ordering // that is: public wires | secret wires | internal wires - copy(res.ccs.Logs, cs.logs) - copy(res.ccs.DebugInfo, cs.debugInfo) // we mark hint wires are solved // each R1C from the frontend.ConstraintSystem is allowed to have at most one unsolved wire @@ -147,14 +145,25 @@ func (cs *ConstraintSystem) toSparseR1CS(curveID ecc.ID) (CompiledConstraintSyst offsetTermID(&r1c.M[1]) } - // offset IDs in the logs // we need to offset the ids in logs & debugInfo - for i := 0; i < len(res.ccs.Logs); i++ { + for i := 0; i < len(cs.logs); i++ { + res.ccs.Logs[i] = compiled.LogEntry{ + Format: cs.logs[i].Format, + ToResolve: make([]compiled.Term, len(cs.logs[i].ToResolve)), + } + copy(res.ccs.Logs[i].ToResolve, cs.logs[i].ToResolve) + for j := 0; j < len(res.ccs.Logs[i].ToResolve); j++ { offsetTermID(&res.ccs.Logs[i].ToResolve[j]) } } - for i := 0; i < len(res.ccs.DebugInfo); i++ { + for i := 0; i < len(cs.debugInfo); i++ { + res.ccs.DebugInfo[i] = compiled.LogEntry{ + Format: cs.debugInfo[i].Format, + ToResolve: make([]compiled.Term, len(cs.debugInfo[i].ToResolve)), + } + copy(res.ccs.DebugInfo[i].ToResolve, cs.debugInfo[i].ToResolve) + for j := 0; j < len(res.ccs.DebugInfo[i].ToResolve); j++ { offsetTermID(&res.ccs.DebugInfo[i].ToResolve[j]) } From 535e3d214afb90ad563d20bbd09f6f0180060cdc Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Fri, 17 Sep 2021 14:13:33 -0500 Subject: [PATCH 12/15] build: remove dead code, makes staticcheck happier --- frontend/cs.go | 46 ---------------------------------------------- 1 file changed, 46 deletions(-) diff --git a/frontend/cs.go b/frontend/cs.go index 4649c025aa..d28030ec0c 100644 --- a/frontend/cs.go +++ b/frontend/cs.go @@ -21,7 +21,6 @@ import ( "io" "math/big" "path/filepath" - "reflect" "runtime" "sort" "strconv" @@ -335,51 +334,6 @@ func (cs *ConstraintSystem) newVirtualVariable() Variable { return cs.virtual.new(cs, compiled.Virtual) } -type logValueHandler func(name string, tValue reflect.Value) - -func appendName(baseName, name string) string { - if baseName == "" { - return name - } - return baseName + "_" + name -} - -func parseLogValue(input interface{}, name string, handler logValueHandler) { - tVariable := reflect.TypeOf(Variable{}) - - tValue := reflect.ValueOf(input) - if tValue.Kind() == reflect.Ptr { - tValue = tValue.Elem() - } - switch tValue.Kind() { - case reflect.Struct: - switch tValue.Type() { - case tVariable: - handler(name, tValue) - return - default: - for i := 0; i < tValue.NumField(); i++ { - if tValue.Field(i).CanInterface() { - value := tValue.Field(i).Interface() - _name := appendName(name, tValue.Type().Field(i).Name) - parseLogValue(value, _name, handler) - } - } - } - case reflect.Slice, reflect.Array: - if tValue.Len() == 0 { - fmt.Println("warning, got unitizalized slice (or empty array). Ignoring;") - return - } - for j := 0; j < tValue.Len(); j++ { - value := tValue.Index(j).Interface() - entry := "[" + strconv.Itoa(j) + "]" - _name := appendName(name, entry) - parseLogValue(value, _name, handler) - } - } -} - func (cs *ConstraintSystem) buildVarFromWire(pv Wire) Variable { return Variable{pv, cs.LinearExpression(compiled.Pack(pv.id, compiled.CoeffIdOne, pv.visibility))} } From f97fe6b382d497c19a68e2f59bef5c8ed9e76ada Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Mon, 20 Sep 2021 11:25:21 -0500 Subject: [PATCH 13/15] feat: cs.Println now supports structures with Variables --- debug_test.go | 30 ++++--- frontend/cs.go | 52 ----------- frontend/cs_debug.go | 90 ++++++++++++++++++- frontend/cs_to_r1cs_sparse.go | 2 + internal/backend/bls12-377/cs/solution.go | 6 +- internal/backend/bls12-381/cs/solution.go | 6 +- internal/backend/bls24-315/cs/solution.go | 6 +- internal/backend/bn254/cs/solution.go | 6 +- internal/backend/bw6-761/cs/solution.go | 6 +- .../template/representations/solution.go.tmpl | 6 +- 10 files changed, 132 insertions(+), 78 deletions(-) diff --git a/debug_test.go b/debug_test.go index 270ae3fed5..ee8211073a 100644 --- a/debug_test.go +++ b/debug_test.go @@ -26,6 +26,10 @@ func (circuit *printlnCircuit) Define(curveID ecc.ID, cs *frontend.ConstraintSys cs.Println(d, new(big.Int).SetInt64(42)) bs := cs.ToBinary(circuit.B, 10) cs.Println("bits", bs[3]) + cs.Println("circuit", circuit) + cs.AssertIsBoolean(cs.Constant(10)) // this will fail + m := cs.Mul(circuit.A, circuit.B) + cs.Println("m", m) // this should not be resolved return nil } @@ -40,16 +44,16 @@ func TestPrintln(t *testing.T) { expected.WriteString("debug_test.go:24 13 is the addition\n") expected.WriteString("debug_test.go:26 26 42\n") expected.WriteString("debug_test.go:28 bits 1\n") + expected.WriteString("debug_test.go:29 circuit {A: 2, B: 11}\n") + expected.WriteString("debug_test.go:32 m \n") { - trace, err := getGroth16Trace(&circuit, &witness) - assert.NoError(err) + trace, _ := getGroth16Trace(&circuit, &witness) assert.Equal(trace, expected.String()) } { - trace, err := getPlonkTrace(&circuit, &witness) - assert.NoError(err) + trace, _ := getPlonkTrace(&circuit, &witness) assert.Equal(trace, expected.String()) } } @@ -66,7 +70,7 @@ func (circuit *divBy0Trace) Define(curveID ecc.ID, cs *frontend.ConstraintSystem return nil } -func TestDivBy0(t *testing.T) { +func TestTraceDivBy0(t *testing.T) { assert := require.New(t) var circuit, witness divBy0Trace @@ -79,7 +83,7 @@ func TestDivBy0(t *testing.T) { assert.Error(err) assert.Contains(err.Error(), "constraint is not satisfied: [div] 2/(-2 + 2) == 0") assert.Contains(err.Error(), "gnark.(*divBy0Trace).Define") - assert.Contains(err.Error(), "debug_test.go:65") + assert.Contains(err.Error(), "debug_test.go:69") } { @@ -87,7 +91,7 @@ func TestDivBy0(t *testing.T) { assert.Error(err) assert.Contains(err.Error(), "constraint is not satisfied: [div] 2/(-2 + 2) == 0") assert.Contains(err.Error(), "gnark.(*divBy0Trace).Define") - assert.Contains(err.Error(), "debug_test.go:65") + assert.Contains(err.Error(), "debug_test.go:69") } } @@ -103,7 +107,7 @@ func (circuit *notEqualTrace) Define(curveID ecc.ID, cs *frontend.ConstraintSyst return nil } -func TestNotEqual(t *testing.T) { +func TestTraceNotEqual(t *testing.T) { assert := require.New(t) var circuit, witness notEqualTrace @@ -116,7 +120,7 @@ func TestNotEqual(t *testing.T) { assert.Error(err) assert.Contains(err.Error(), "constraint is not satisfied: [assertIsEqual] 1 == (24 + 42)") assert.Contains(err.Error(), "gnark.(*notEqualTrace).Define") - assert.Contains(err.Error(), "debug_test.go:102") + assert.Contains(err.Error(), "debug_test.go:106") } { @@ -124,7 +128,7 @@ func TestNotEqual(t *testing.T) { assert.Error(err) assert.Contains(err.Error(), "constraint is not satisfied: [assertIsEqual] 1 == (24 + 42)") assert.Contains(err.Error(), "gnark.(*notEqualTrace).Define") - assert.Contains(err.Error(), "debug_test.go:102") + assert.Contains(err.Error(), "debug_test.go:106") } } @@ -140,7 +144,7 @@ func (circuit *notBooleanTrace) Define(curveID ecc.ID, cs *frontend.ConstraintSy return nil } -func TestNotBoolean(t *testing.T) { +func TestTraceNotBoolean(t *testing.T) { assert := require.New(t) var circuit, witness notBooleanTrace @@ -153,7 +157,7 @@ func TestNotBoolean(t *testing.T) { assert.Error(err) assert.Contains(err.Error(), "constraint is not satisfied: [assertIsBoolean] (24 + 42) == (0|1)") assert.Contains(err.Error(), "gnark.(*notBooleanTrace).Define") - assert.Contains(err.Error(), "debug_test.go:139") + assert.Contains(err.Error(), "debug_test.go:143") } { @@ -161,7 +165,7 @@ func TestNotBoolean(t *testing.T) { assert.Error(err) assert.Contains(err.Error(), "constraint is not satisfied: [assertIsBoolean] (24 + 42) == (0|1)") assert.Contains(err.Error(), "gnark.(*notBooleanTrace).Define") - assert.Contains(err.Error(), "debug_test.go:139") + assert.Contains(err.Error(), "debug_test.go:143") } } diff --git a/frontend/cs.go b/frontend/cs.go index d28030ec0c..56a8a1cdfa 100644 --- a/frontend/cs.go +++ b/frontend/cs.go @@ -17,14 +17,9 @@ limitations under the License. package frontend import ( - "fmt" "io" "math/big" - "path/filepath" - "runtime" "sort" - "strconv" - "strings" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/backend/hint" @@ -157,53 +152,6 @@ func (cs *ConstraintSystem) NewHint(hintID hint.ID, inputs ...interface{}) Varia return r } -// Println enables circuit debugging and behaves almost like fmt.Println() -// -// the print will be done once the R1CS.Solve() method is executed -// -// if one of the input is a Variable, its value will be resolved avec R1CS.Solve() method is called -func (cs *ConstraintSystem) Println(a ...interface{}) { - var sbb strings.Builder - - // prefix log line with file.go:line - if _, file, line, ok := runtime.Caller(1); ok { - sbb.WriteString(filepath.Base(file)) - sbb.WriteByte(':') - sbb.WriteString(strconv.Itoa(line)) - sbb.WriteByte(' ') - } - - var log compiled.LogEntry - - for i, arg := range a { - if i > 0 { - sbb.WriteByte(' ') - } - if v, ok := arg.(Variable); ok { - v.assertIsSet() - - sbb.WriteString("%s") - // we set limits to the linear expression, so that the log printer - // can evaluate it before printing it - log.ToResolve = append(log.ToResolve, compiled.TermDelimitor) - log.ToResolve = append(log.ToResolve, v.linExp...) - log.ToResolve = append(log.ToResolve, compiled.TermDelimitor) - } else { - sbb.WriteString(fmt.Sprint(arg)) - } - } - sbb.WriteByte('\n') - - // set format string to be used with fmt.Sprintf, once the variables are solved in the R1CS.Solve() method - log.Format = sbb.String() - - cs.logs = append(cs.logs, log) -} - -var ( - bOne = new(big.Int).SetInt64(1) -) - func (cs *ConstraintSystem) one() Variable { return cs.public.variables[0] } diff --git a/frontend/cs_debug.go b/frontend/cs_debug.go index b4e98d353a..d3845f1b1d 100644 --- a/frontend/cs_debug.go +++ b/frontend/cs_debug.go @@ -1,13 +1,101 @@ package frontend import ( + "fmt" + "path/filepath" + "reflect" + "runtime" "strconv" "strings" "github.com/consensys/gnark/internal/backend/compiled" + "github.com/consensys/gnark/internal/parser" ) -// TODO @gbotrel maybe rename to newLog if common with cs.Println +// Println enables circuit debugging and behaves almost like fmt.Println() +// +// the print will be done once the R1CS.Solve() method is executed +// +// if one of the input is a Variable, its value will be resolved avec R1CS.Solve() method is called +func (cs *ConstraintSystem) Println(a ...interface{}) { + var sbb strings.Builder + + // prefix log line with file.go:line + if _, file, line, ok := runtime.Caller(1); ok { + sbb.WriteString(filepath.Base(file)) + sbb.WriteByte(':') + sbb.WriteString(strconv.Itoa(line)) + sbb.WriteByte(' ') + } + + var log compiled.LogEntry + + for i, arg := range a { + if i > 0 { + sbb.WriteByte(' ') + } + if v, ok := arg.(Variable); ok { + v.assertIsSet() + + sbb.WriteString("%s") + // we set limits to the linear expression, so that the log printer + // can evaluate it before printing it + log.ToResolve = append(log.ToResolve, compiled.TermDelimitor) + log.ToResolve = append(log.ToResolve, v.linExp...) + log.ToResolve = append(log.ToResolve, compiled.TermDelimitor) + } else { + printArg(&log, &sbb, arg) + } + } + sbb.WriteByte('\n') + + // set format string to be used with fmt.Sprintf, once the variables are solved in the R1CS.Solve() method + log.Format = sbb.String() + + cs.logs = append(cs.logs, log) +} + +func printArg(log *compiled.LogEntry, sbb *strings.Builder, a interface{}) error { + + count := 0 + counter := func(visibility compiled.Visibility, name string, tValue reflect.Value) error { + count++ + return nil + } + if err := parser.Visit(a, "", compiled.Unset, counter, reflect.TypeOf(Variable{})); err != nil { + return err + } + + if count == 0 { + sbb.WriteString(fmt.Sprint(a)) + return nil + } + sbb.WriteByte('{') + printer := func(visibility compiled.Visibility, name string, tValue reflect.Value) error { + count-- + sbb.WriteString(name) + sbb.WriteString(": ") + sbb.WriteString("%s") + if count != 0 { + sbb.WriteString(", ") + } + + v := tValue.Interface().(Variable) + // we set limits to the linear expression, so that the log printer + // can evaluate it before printing it + log.ToResolve = append(log.ToResolve, compiled.TermDelimitor) + log.ToResolve = append(log.ToResolve, v.linExp...) + log.ToResolve = append(log.ToResolve, compiled.TermDelimitor) + return nil + } + if err := parser.Visit(a, "", compiled.Unset, printer, reflect.TypeOf(Variable{})); err != nil { + return err + } + sbb.WriteByte('}') + + return nil +} + func (cs *ConstraintSystem) addDebugInfo(errName string, i ...interface{}) int { var debug compiled.LogEntry diff --git a/frontend/cs_to_r1cs_sparse.go b/frontend/cs_to_r1cs_sparse.go index 68f2d963d3..44c9a809b5 100644 --- a/frontend/cs_to_r1cs_sparse.go +++ b/frontend/cs_to_r1cs_sparse.go @@ -52,6 +52,8 @@ type sparseR1CS struct { currentR1CDebugID int // mark the current R1C debugID } +var bOne = new(big.Int).SetInt64(1) + func (cs *ConstraintSystem) toSparseR1CS(curveID ecc.ID) (CompiledConstraintSystem, error) { res := sparseR1CS{ diff --git a/internal/backend/bls12-377/cs/solution.go b/internal/backend/bls12-377/cs/solution.go index 6c34e817f1..6a24bb28cb 100644 --- a/internal/backend/bls12-377/cs/solution.go +++ b/internal/backend/bls12-377/cs/solution.go @@ -149,6 +149,8 @@ func (s *solution) printLogs(w io.Writer, logs []compiled.LogEntry) { } } +const unsolvedVariable = "" + func (s *solution) logValue(log compiled.LogEntry) string { var toResolve []interface{} var ( @@ -167,7 +169,7 @@ func (s *solution) logValue(log compiled.LogEntry) string { } isEval = false if missingValue { - toResolve = append(toResolve, "???") + toResolve = append(toResolve, unsolvedVariable) } else { // we have to append our accumulator toResolve = append(toResolve, eval.String()) @@ -205,7 +207,7 @@ func (s *solution) logValue(log compiled.LogEntry) string { toResolve = append(toResolve, s.coefficients[cID].String()) } if !s.solved[vID] { - toResolve = append(toResolve, "???") + toResolve = append(toResolve, unsolvedVariable) } else { toResolve = append(toResolve, s.values[vID].String()) } diff --git a/internal/backend/bls12-381/cs/solution.go b/internal/backend/bls12-381/cs/solution.go index 7ca64a9afd..2744adf6d0 100644 --- a/internal/backend/bls12-381/cs/solution.go +++ b/internal/backend/bls12-381/cs/solution.go @@ -149,6 +149,8 @@ func (s *solution) printLogs(w io.Writer, logs []compiled.LogEntry) { } } +const unsolvedVariable = "" + func (s *solution) logValue(log compiled.LogEntry) string { var toResolve []interface{} var ( @@ -167,7 +169,7 @@ func (s *solution) logValue(log compiled.LogEntry) string { } isEval = false if missingValue { - toResolve = append(toResolve, "???") + toResolve = append(toResolve, unsolvedVariable) } else { // we have to append our accumulator toResolve = append(toResolve, eval.String()) @@ -205,7 +207,7 @@ func (s *solution) logValue(log compiled.LogEntry) string { toResolve = append(toResolve, s.coefficients[cID].String()) } if !s.solved[vID] { - toResolve = append(toResolve, "???") + toResolve = append(toResolve, unsolvedVariable) } else { toResolve = append(toResolve, s.values[vID].String()) } diff --git a/internal/backend/bls24-315/cs/solution.go b/internal/backend/bls24-315/cs/solution.go index 70515ea6ba..f2cc9e8d08 100644 --- a/internal/backend/bls24-315/cs/solution.go +++ b/internal/backend/bls24-315/cs/solution.go @@ -149,6 +149,8 @@ func (s *solution) printLogs(w io.Writer, logs []compiled.LogEntry) { } } +const unsolvedVariable = "" + func (s *solution) logValue(log compiled.LogEntry) string { var toResolve []interface{} var ( @@ -167,7 +169,7 @@ func (s *solution) logValue(log compiled.LogEntry) string { } isEval = false if missingValue { - toResolve = append(toResolve, "???") + toResolve = append(toResolve, unsolvedVariable) } else { // we have to append our accumulator toResolve = append(toResolve, eval.String()) @@ -205,7 +207,7 @@ func (s *solution) logValue(log compiled.LogEntry) string { toResolve = append(toResolve, s.coefficients[cID].String()) } if !s.solved[vID] { - toResolve = append(toResolve, "???") + toResolve = append(toResolve, unsolvedVariable) } else { toResolve = append(toResolve, s.values[vID].String()) } diff --git a/internal/backend/bn254/cs/solution.go b/internal/backend/bn254/cs/solution.go index 07bc4b921a..e46aebcc25 100644 --- a/internal/backend/bn254/cs/solution.go +++ b/internal/backend/bn254/cs/solution.go @@ -149,6 +149,8 @@ func (s *solution) printLogs(w io.Writer, logs []compiled.LogEntry) { } } +const unsolvedVariable = "" + func (s *solution) logValue(log compiled.LogEntry) string { var toResolve []interface{} var ( @@ -167,7 +169,7 @@ func (s *solution) logValue(log compiled.LogEntry) string { } isEval = false if missingValue { - toResolve = append(toResolve, "???") + toResolve = append(toResolve, unsolvedVariable) } else { // we have to append our accumulator toResolve = append(toResolve, eval.String()) @@ -205,7 +207,7 @@ func (s *solution) logValue(log compiled.LogEntry) string { toResolve = append(toResolve, s.coefficients[cID].String()) } if !s.solved[vID] { - toResolve = append(toResolve, "???") + toResolve = append(toResolve, unsolvedVariable) } else { toResolve = append(toResolve, s.values[vID].String()) } diff --git a/internal/backend/bw6-761/cs/solution.go b/internal/backend/bw6-761/cs/solution.go index ebb1d6a669..deceb66fb9 100644 --- a/internal/backend/bw6-761/cs/solution.go +++ b/internal/backend/bw6-761/cs/solution.go @@ -149,6 +149,8 @@ func (s *solution) printLogs(w io.Writer, logs []compiled.LogEntry) { } } +const unsolvedVariable = "" + func (s *solution) logValue(log compiled.LogEntry) string { var toResolve []interface{} var ( @@ -167,7 +169,7 @@ func (s *solution) logValue(log compiled.LogEntry) string { } isEval = false if missingValue { - toResolve = append(toResolve, "???") + toResolve = append(toResolve, unsolvedVariable) } else { // we have to append our accumulator toResolve = append(toResolve, eval.String()) @@ -205,7 +207,7 @@ func (s *solution) logValue(log compiled.LogEntry) string { toResolve = append(toResolve, s.coefficients[cID].String()) } if !s.solved[vID] { - toResolve = append(toResolve, "???") + toResolve = append(toResolve, unsolvedVariable) } else { toResolve = append(toResolve, s.values[vID].String()) } diff --git a/internal/generator/backend/template/representations/solution.go.tmpl b/internal/generator/backend/template/representations/solution.go.tmpl index 46a1630fac..178cf40a31 100644 --- a/internal/generator/backend/template/representations/solution.go.tmpl +++ b/internal/generator/backend/template/representations/solution.go.tmpl @@ -137,6 +137,8 @@ func (s *solution) printLogs(w io.Writer, logs []compiled.LogEntry) { } } +const unsolvedVariable = "" + func (s *solution) logValue(log compiled.LogEntry) string { var toResolve []interface{} @@ -156,7 +158,7 @@ func (s *solution) logValue(log compiled.LogEntry) string { } isEval = false if missingValue { - toResolve = append(toResolve, "???") + toResolve = append(toResolve, unsolvedVariable) } else { // we have to append our accumulator toResolve = append(toResolve, eval.String()) @@ -194,7 +196,7 @@ func (s *solution) logValue(log compiled.LogEntry) string { toResolve = append(toResolve, s.coefficients[cID].String()) } if !s.solved[vID] { - toResolve = append(toResolve, "???") + toResolve = append(toResolve, unsolvedVariable) } else { toResolve = append(toResolve, s.values[vID].String()) } From 59110ae7766a39f2f36d0a8239a320fc679e240b Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Mon, 20 Sep 2021 11:25:53 -0500 Subject: [PATCH 14/15] build: go mod tidy --- go.mod | 1 - go.sum | 9 --------- 2 files changed, 10 deletions(-) diff --git a/go.mod b/go.mod index 54d39f486b..8bd8eaa3b6 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,6 @@ require ( github.com/fxamacker/cbor/v2 v2.2.0 github.com/kr/pretty v0.2.0 // indirect github.com/leanovate/gopter v0.2.9 - github.com/pkg/profile v1.6.0 // indirect github.com/stretchr/testify v1.7.0 gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect diff --git a/go.sum b/go.sum index 541eec72a4..8c70a887cf 100644 --- a/go.sum +++ b/go.sum @@ -1,13 +1,5 @@ github.com/consensys/bavard v0.1.8-0.20210806153619-fcffe4ffd871 h1:gfdz2r/E4uQhD8jDUv2SaWQClfzFuZioHGAzPw7oZng= github.com/consensys/bavard v0.1.8-0.20210806153619-fcffe4ffd871/go.mod h1:Bpd0/3mZuaj6Sj+PqrmIquiOKy397AKGThQPaGzNXAQ= -github.com/consensys/gnark-crypto v0.4.1-0.20210818174051-018b86471fca h1:YuKivJirttUz/FNlAp1dwIiJiYyPOoyno2CoRlfqMNs= -github.com/consensys/gnark-crypto v0.4.1-0.20210818174051-018b86471fca/go.mod h1:5u+nS08qZhHtugNg17dAnCGqbnRCJ6XSdPj0LyFvAOM= -github.com/consensys/gnark-crypto v0.5.0 h1:c+1SOpCPKmw5lKth/hIoRgcw23KSgWnNR/b5M+JRC3k= -github.com/consensys/gnark-crypto v0.5.0/go.mod h1:wAZ9dsKCDVTSIy2KVTik+ZF16GUX9qp96mxFBDl9iAQ= -github.com/consensys/gnark-crypto v0.5.1-0.20210907173531-0ae8b5c38618 h1:vnrIRUFj8afz/QlnRlD+AVLkgkyYj6JD6MNLK6R7dcg= -github.com/consensys/gnark-crypto v0.5.1-0.20210907173531-0ae8b5c38618/go.mod h1:wAZ9dsKCDVTSIy2KVTik+ZF16GUX9qp96mxFBDl9iAQ= -github.com/consensys/gnark-crypto v0.5.1-0.20210907174324-9721833081d7 h1:2k7ImGxDTTY2OpiKjnFDfqc/ir8O54qCwUTnobfDbkM= -github.com/consensys/gnark-crypto v0.5.1-0.20210907174324-9721833081d7/go.mod h1:wAZ9dsKCDVTSIy2KVTik+ZF16GUX9qp96mxFBDl9iAQ= github.com/consensys/gnark-crypto v0.5.1-0.20210917183421-cb36b2c871c0 h1:ODfAG0P/XaGvh1JNZM9tzL2MKVaqFdE7FeATcrdrHB0= github.com/consensys/gnark-crypto v0.5.1-0.20210917183421-cb36b2c871c0/go.mod h1:wAZ9dsKCDVTSIy2KVTik+ZF16GUX9qp96mxFBDl9iAQ= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -22,7 +14,6 @@ github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/leanovate/gopter v0.2.9 h1:fQjYxZaynp97ozCzfOyOuAGOU4aU/z37zf/tOujFk7c= github.com/leanovate/gopter v0.2.9/go.mod h1:U2L/78B+KVFIx2VmW6onHJQzXtFb+p5y3y2Sh+Jxxv8= -github.com/pkg/profile v1.6.0/go.mod h1:qBsxPvzyUincmltOk6iyRVxHYg4adc0OFOv72ZdLa18= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= From f624d36fcd481d4ca90d298a22e465d50cc06935 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Mon, 20 Sep 2021 11:32:31 -0500 Subject: [PATCH 15/15] style: printArg doesn't return error --- frontend/cs_debug.go | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/frontend/cs_debug.go b/frontend/cs_debug.go index d3845f1b1d..b9fb30e9b9 100644 --- a/frontend/cs_debug.go +++ b/frontend/cs_debug.go @@ -55,21 +55,22 @@ func (cs *ConstraintSystem) Println(a ...interface{}) { cs.logs = append(cs.logs, log) } -func printArg(log *compiled.LogEntry, sbb *strings.Builder, a interface{}) error { +func printArg(log *compiled.LogEntry, sbb *strings.Builder, a interface{}) { count := 0 counter := func(visibility compiled.Visibility, name string, tValue reflect.Value) error { count++ return nil } - if err := parser.Visit(a, "", compiled.Unset, counter, reflect.TypeOf(Variable{})); err != nil { - return err - } + // ignoring error, counter() always return nil + _ = parser.Visit(a, "", compiled.Unset, counter, reflect.TypeOf(Variable{})) + // no variables in nested struct, we use fmt std print function if count == 0 { sbb.WriteString(fmt.Sprint(a)) - return nil + return } + sbb.WriteByte('{') printer := func(visibility compiled.Visibility, name string, tValue reflect.Value) error { count-- @@ -88,12 +89,9 @@ func printArg(log *compiled.LogEntry, sbb *strings.Builder, a interface{}) error log.ToResolve = append(log.ToResolve, compiled.TermDelimitor) return nil } - if err := parser.Visit(a, "", compiled.Unset, printer, reflect.TypeOf(Variable{})); err != nil { - return err - } + // ignoring error, printer() doesn't return errors + _ = parser.Visit(a, "", compiled.Unset, printer, reflect.TypeOf(Variable{})) sbb.WriteByte('}') - - return nil } func (cs *ConstraintSystem) addDebugInfo(errName string, i ...interface{}) int {