diff --git a/frontend/cs_to_r1cs_sparse.go b/frontend/cs_to_r1cs_sparse.go index 44c9a809b5..a431f90c68 100644 --- a/frontend/cs_to_r1cs_sparse.go +++ b/frontend/cs_to_r1cs_sparse.go @@ -17,7 +17,11 @@ limitations under the License. package frontend import ( + "crypto/sha256" + "encoding/binary" + "hash" "math/big" + "sort" "sync" "github.com/consensys/gnark-crypto/ecc" @@ -50,6 +54,13 @@ type sparseR1CS struct { solvedVariables []bool currentR1CDebugID int // mark the current R1C debugID + + // map LinearExpression -> Term. The goal is to not reduce + // the same linear expression twice. + record map[string]compiled.Term + + // hash function used to navigate in record + h hash.Hash } var bOne = new(big.Int).SetInt64(1) @@ -73,6 +84,8 @@ func (cs *ConstraintSystem) toSparseR1CS(curveID ecc.ID) (CompiledConstraintSyst solvedVariables: make([]bool, len(cs.internal.variables), len(cs.internal.variables)*2), scsInternalVariables: len(cs.internal.variables), currentR1CDebugID: -1, + record: make(map[string]compiled.Term), + h: sha256.New(), } // logs, debugInfo and hints are copied, the only thing that will change @@ -249,6 +262,76 @@ func popInternalVariable(l compiled.LinearExpression, id int) (compiled.LinearEx return _l, t } +// returns ( b/gcd(b...), gcd(b...) ) +func gcd(b []*big.Int, s *big.Int) { + + s.Set(b[0]) + for i := 0; i < len(b); i++ { + s.GCD(nil, nil, s, b[i]) + } + if s.IsUint64() && s.Uint64() == 0 { + return + } + + // ensure the gcd doesn't depend on the sign + if b[0].Sign() == -1 { + s.Neg(s) + } + for i := 0; i < len(b); i++ { + b[i].Div(b[i], s) + } + +} + +// reduce returns ( l/gcd(l.coefs), gcd(l.coefs) ) +func (scs *sparseR1CS) reduce(l compiled.LinearExpression) (compiled.LinearExpression, big.Int) { + + var s big.Int + + // get the coeffs from the linear expression + coeffs := make([]*big.Int, len(l)) + + for i := 0; i < len(l); i++ { + coeffs[i] = bigIntPool.Get().(*big.Int) + coeffs[i].Set(&scs.coeffs[l[i].CoeffID()]) + } + + // compute gcd + gcd(coeffs, &s) + + // resulting linear expression + _l := make(compiled.LinearExpression, len(l)) + copy(_l, l) + for i := 0; i < len(_l); i++ { + id := scs.coeffID(coeffs[i]) + bigIntPool.Put(coeffs[i]) + _l[i].SetCoeffID(id) + } + return _l, s + +} + +// getKeyPrimitive returns id of l, assuming that l is primitive +func (scs *sparseR1CS) GetKey(primitiveLinExp compiled.LinearExpression) string { + + // sort l to have a unique non ambiguous id + l := make(compiled.LinearExpression, len(primitiveLinExp)) + copy(l, primitiveLinExp) + if !sort.IsSorted(l) { // not sure that helps + sort.Sort(l) + } + + // get the id + var b [8]byte + scs.h.Reset() + for i := 0; i < len(l); i++ { + binary.LittleEndian.PutUint64(b[:], uint64(l[i])) + scs.h.Write(b[:]) + } + return string(scs.h.Sum(nil)) + +} + // pops the constant associated to the one_wire in the cs, which will become // a constant in a PLONK constraint. // @@ -272,7 +355,7 @@ func (scs *sparseR1CS) popConstantTerm(l compiled.LinearExpression) (compiled.Li return l, big.Int{} } -// newTerm creates a new term =1*new_variable and records it in the scs +// newTerm creates a new term =coeff*new_variable and records it in the scs // if idCS is set, uses it as variable id and does not increment the number // of new internal variables created func (scs *sparseR1CS) newTerm(coeff *big.Int, idCS ...int) compiled.Term { @@ -387,34 +470,43 @@ func (scs *sparseR1CS) multiply(t compiled.Term, c *big.Int) compiled.Term { return t } -// split splits a linear expression to plonk constraints -// ex: le = aiwi is split into PLONK constraints (using sums) -// of 3 terms) like this: -// w0' = a0w0+a1w1 -// w1' = w0' + a2w2 -// .. -// wn' = wn-1'+an-2wn-2 -// split returns a term that is equal to aiwi (it's 1xaiwi) -// no side effects on le -func (scs *sparseR1CS) split(a compiled.Term, l compiled.LinearExpression) compiled.Term { +func (scs *sparseR1CS) split(l compiled.LinearExpression) compiled.Term { // floor case - if len(l) == 0 { - return a + if len(l) == 1 { + return l[0] } - // first call - if a == 0 { - return scs.split(l[0], l[1:]) + // check if l is recorded, if so we get it from the record + _l, s := scs.reduce(l) + k := scs.GetKey(_l) + if t, ok := scs.record[k]; ok { + t.SetCoeffID(scs.coeffID(&s)) + return t } - // recursive case - r := l[0] + // find if in the left side the constraint is recorded + for i := len(l) - 1; i > 0; i-- { + ll, _s := scs.reduce(_l[:i]) + _k := scs.GetKey(ll) + if t, ok := scs.record[_k]; ok { + t = scs.multiply(t, &_s) + o := scs.newTerm(bOne) + _o := scs.negate(o) + b := scs.split(_l[i:]) + scs.addConstraint(compiled.SparseR1C{L: t, R: b, O: _o}) + scs.record[k] = o + return scs.multiply(o, &s) + } + } + // else we build the reduction starting from l[0] o := scs.newTerm(bOne) - scs.addConstraint(compiled.SparseR1C{L: a, R: r, O: o}) - o = scs.negate(o) - return scs.split(o, l[1:]) - + _o := scs.negate(o) + a := _l[0] + b := scs.split(_l[1:]) + scs.addConstraint(compiled.SparseR1C{L: a, R: b, O: _o}) + scs.record[k] = o + return scs.multiply(o, &s) } // r1cToSparseR1C splits a r1c constraint @@ -473,7 +565,7 @@ func (scs *sparseR1CS) r1cToSparseR1C(r1c compiled.R1C) { // cL*(r + cR) = toSolve + cO f2 := func() { - rt := scs.split(0, r) + rt := scs.split(r) cRT := scs.multiply(rt, &cL) cK.Mul(&cL, &cR) @@ -489,7 +581,7 @@ func (scs *sparseR1CS) r1cToSparseR1C(r1c compiled.R1C) { // (l + cL)*cR = toSolve + cO f3 := func() { - lt := scs.split(0, l) + lt := scs.split(l) cRLT := scs.multiply(lt, &cR) cK.Mul(&cL, &cR) @@ -504,8 +596,8 @@ func (scs *sparseR1CS) r1cToSparseR1C(r1c compiled.R1C) { // (l + cL)*(r + cR) = toSolve + cO f4 := func() { - lt := scs.split(0, l) - rt := scs.split(0, r) + lt := scs.split(l) + rt := scs.split(r) cRLT := scs.multiply(lt, &cR) cRT := scs.multiply(rt, &cL) @@ -523,7 +615,7 @@ func (scs *sparseR1CS) r1cToSparseR1C(r1c compiled.R1C) { // cL*cR = toSolve + o + cO f5 := func() { - ot := scs.split(0, o) + ot := scs.split(o) cK.Mul(&cL, &cR) cK.Sub(&cK, &cO) @@ -538,8 +630,8 @@ func (scs *sparseR1CS) r1cToSparseR1C(r1c compiled.R1C) { // cL*(r + cR) = toSolve + o + cO f6 := func() { - rt := scs.split(0, r) - ot := scs.split(0, o) + rt := scs.split(r) + ot := scs.split(o) cRT := scs.multiply(rt, &cL) cK.Mul(&cL, &cR) @@ -556,8 +648,8 @@ func (scs *sparseR1CS) r1cToSparseR1C(r1c compiled.R1C) { // (l + cL)*cR = toSolve + o + cO f7 := func() { - lt := scs.split(0, l) - ot := scs.split(0, o) + lt := scs.split(l) + ot := scs.split(o) cRLT := scs.multiply(lt, &cR) cK.Mul(&cL, &cR) @@ -574,9 +666,10 @@ func (scs *sparseR1CS) r1cToSparseR1C(r1c compiled.R1C) { // (l + cL)*(r + cR) = toSolve + o + cO f8 := func() { - lt := scs.split(0, l) - rt := scs.split(0, r) - ot := scs.split(0, o) + + lt := scs.split(l) + rt := scs.split(r) + ot := scs.split(o) cRLT := scs.multiply(lt, &cR) cRT := scs.multiply(rt, &cL) @@ -617,7 +710,7 @@ func (scs *sparseR1CS) r1cToSparseR1C(r1c compiled.R1C) { f10 := func() { res := scs.newTerm(&cS, idCS) - rt := scs.split(0, r) + rt := scs.split(r) cRT := scs.multiply(rt, &cL) cRes := scs.multiply(res, &cR) @@ -634,7 +727,8 @@ func (scs *sparseR1CS) r1cToSparseR1C(r1c compiled.R1C) { // (toSolve + l + cL)*cR = cO f11 := func() { - lt := scs.split(0, l) + + lt := scs.split(l) lt = scs.multiply(lt, &cR) cK.Mul(&cL, &cR) @@ -653,8 +747,9 @@ func (scs *sparseR1CS) r1cToSparseR1C(r1c compiled.R1C) { // => toSolve*r + toSolve*cR + [ l*r + l*cR +cL*r+cL*cR-cO ]=0 f12 := func() { u := scs.newTerm(bOne) - lt := scs.split(0, l) - rt := scs.split(0, r) + + lt := scs.split(l) + rt := scs.split(r) cRLT := scs.multiply(lt, &cR) cRT := scs.multiply(rt, &cL) @@ -681,7 +776,7 @@ func (scs *sparseR1CS) r1cToSparseR1C(r1c compiled.R1C) { // (toSolve + cL)*cR = o + cO f13 := func() { - ot := scs.split(0, o) + ot := scs.split(o) cK.Mul(&cL, &cR) cK.Sub(&cK, &cO) @@ -698,10 +793,10 @@ func (scs *sparseR1CS) r1cToSparseR1C(r1c compiled.R1C) { // (toSolve + cL)*(r + cR) = o + cO // toSolve*r + toSolve*cR+cL*r+cL*cR-cO-o=0 f14 := func() { - ot := scs.split(0, o) + ot := scs.split(o) res := scs.newTerm(&cS, idCS) - rt := scs.split(0, r) + rt := scs.split(r) cK.Mul(&cL, &cR) cK.Sub(&cK, &cO) @@ -718,8 +813,9 @@ func (scs *sparseR1CS) r1cToSparseR1C(r1c compiled.R1C) { // (toSolve + l + cL)*cR = o + cO // toSolve*cR + l*cR + cL*cR-cO-o=0 f15 := func() { - ot := scs.split(0, o) - lt := scs.split(0, l) + + ot := scs.split(o) + lt := scs.split(l) cK.Mul(&cL, &cR) cK.Sub(&cK, &cO) @@ -739,8 +835,9 @@ func (scs *sparseR1CS) r1cToSparseR1C(r1c compiled.R1C) { f16 := func() { // [l*r + l*cR +cL*r+cL*cR-cO] + u = 0 u := scs.newTerm(bOne) - lt := scs.split(0, l) - rt := scs.split(0, r) + + lt := scs.split(l) + rt := scs.split(r) cRLT := scs.multiply(lt, &cR) cRT := scs.multiply(rt, &cL) @@ -757,7 +854,7 @@ func (scs *sparseR1CS) r1cToSparseR1C(r1c compiled.R1C) { // u+o+v = 0 (v = -u - o = [l*r + l*cR +cL*r+cL*cR-cO] - o) v := scs.newTerm(bOne) - ot := scs.split(0, o) + ot := scs.split(o) scs.addConstraint(compiled.SparseR1C{ L: u, R: ot, @@ -792,56 +889,36 @@ func (scs *sparseR1CS) r1cToSparseR1C(r1c compiled.R1C) { switch s { case 0b0000: - // (toSolve + cL)*cR = cO f9() case 0b0001: - // (toSolve + cL)*(r + cR) = cO f10() case 0b0010: - // (toSolve + l + cL)*cR = cO f11() case 0b0011: - // (toSolve + l + cL)*(r + cR) = cO - // => toSolve*r + toSolve*cR + [ l*r + l*cR +cL*r+cL*cR-cO ]=0 f12() case 0b0100: - // (toSolve + cL)*cR = o + cO f13() case 0b0101: - // (toSolve + cL)*(r + cR) = o + cO - // toSolve*r + toSolve*cR+cL*r+cL*cR-cO-o=0 f14() case 0b0110: - // (toSolve + l + cL)*cR = o + cO - // toSolve*cR + l*cR + cL*cR-cO-o=0 f15() case 0b0111: - // (toSolve + l + cL)*(r + cR) = o + cO - // => toSolve*r + toSolve*cR + [ [l*r + l*cR +cL*r+cL*cR-cO]- o ]=0 f16() case 0b1000: - // cL*cR = toSolve + cO f1() case 0b1001: - // cL*(r + cR) = toSolve + cO f2() case 0b1010: - // (l + cL)*cR = toSolve + cO f3() case 0b1011: - // (l + cL)*(r + cR) = toSolve + cO f4() case 0b1100: - // cL*cR = toSolve + o + cO f5() case 0b1101: - // cL*(r + cR) = toSolve + o + cO f6() case 0b1110: - // (l + cL)*cR = toSolve + o + cO f7() case 0b1111: - // (l + cL)*(r + cR) = toSolve + o + cO f8() } @@ -876,7 +953,8 @@ func (scs *sparseR1CS) splitR1C(r1c compiled.R1C) { } else { // cL*(r + cR) = cO - rt := scs.split(0, r) + //rt := scs.split(0, r) + rt := scs.split(r) cRLT := scs.multiply(rt, &cL) cK.Mul(&cL, &cR) @@ -888,7 +966,8 @@ func (scs *sparseR1CS) splitR1C(r1c compiled.R1C) { } else { if len(r) == 0 { // (l + cL)*cR = cO - lt := scs.split(0, l) + //lt := scs.split(0, l) + lt := scs.split(l) cRLT := scs.multiply(lt, &cR) cK.Mul(&cL, &cR) @@ -898,8 +977,10 @@ func (scs *sparseR1CS) splitR1C(r1c compiled.R1C) { } else { // (l + cL)*(r + cR) = cO - lt := scs.split(0, l) - rt := scs.split(0, r) + // lt := scs.split(0, l) + // rt := scs.split(0, r) + lt := scs.split(l) + rt := scs.split(r) cRLT := scs.multiply(lt, &cR) cRT := scs.multiply(rt, &cL) @@ -920,7 +1001,8 @@ func (scs *sparseR1CS) splitR1C(r1c compiled.R1C) { if len(r) == 0 { // cL*cR = o + cO - ot := scs.split(0, o) + //ot := scs.split(0, o) + ot := scs.split(o) cK.Mul(&cL, &cR) cK.Sub(&cK, &cO) @@ -929,8 +1011,10 @@ func (scs *sparseR1CS) splitR1C(r1c compiled.R1C) { } else { // cL * (r + cR) = o + cO - rt := scs.split(0, r) - ot := scs.split(0, o) + // rt := scs.split(0, r) + // ot := scs.split(0, o) + rt := scs.split(r) + ot := scs.split(o) cRT := scs.multiply(rt, &cL) cK.Mul(&cL, &cR) @@ -946,8 +1030,10 @@ func (scs *sparseR1CS) splitR1C(r1c compiled.R1C) { } else { if len(r) == 0 { // (l + cL) * cR = o + cO - lt := scs.split(0, l) - ot := scs.split(0, o) + // lt := scs.split(0, l) + // ot := scs.split(0, o) + lt := scs.split(l) + ot := scs.split(o) cRLT := scs.multiply(lt, &cR) cK.Mul(&cL, &cR) @@ -960,9 +1046,12 @@ func (scs *sparseR1CS) splitR1C(r1c compiled.R1C) { }) } else { // (l + cL)*(r + cR) = o + cO - lt := scs.split(0, l) - rt := scs.split(0, r) - ot := scs.split(0, o) + // lt := scs.split(0, l) + // rt := scs.split(0, r) + // ot := scs.split(0, o) + lt := scs.split(l) + rt := scs.split(r) + ot := scs.split(o) cRT := scs.multiply(rt, &cL) cRLT := scs.multiply(lt, &cR) diff --git a/std/algebra/sw/g1.go b/std/algebra/sw/g1.go index d9cbfe9f9b..0da36e1fa6 100644 --- a/std/algebra/sw/g1.go +++ b/std/algebra/sw/g1.go @@ -198,36 +198,32 @@ func (p *G1Affine) FromJac(cs *frontend.ConstraintSystem, p1 *G1Jac) *G1Affine { // Double double a point in affine coords func (p *G1Affine) Double(cs *frontend.ConstraintSystem, p1 *G1Affine) *G1Affine { - var t, d, c1, c2, c3 big.Int + var t, d, c2, c3 big.Int t.SetInt64(3) d.SetInt64(2) - c1.SetInt64(1) c2.SetInt64(-2) c3.SetInt64(-1) // compute lambda = (3*p1.x**2+a)/2*p1.y, here we assume a=0 (j invariant 0 curve) x2 := cs.Mul(p1.X, p1.X) - cs.Mul(p1.X, p1.X) l1 := cs.Mul(x2, t) l2 := cs.Mul(p1.Y, d) l := cs.Div(l1, l2) // xr = lambda**2-p.x-p1.x - _x1 := cs.Mul(l, l, c1) + _x1 := cs.Mul(l, l) _x2 := cs.Mul(p1.X, c2) _x := cs.Add(_x1, _x2) // p.y = lambda(p.x-xr) - p.y t1 := cs.Mul(p1.X, l) t2 := cs.Mul(l, _x) - l31 := cs.Mul(t1, c1) l32 := cs.Mul(t2, c3) l33 := cs.Mul(p1.Y, c3) - l3 := cs.Add(l31, l32, l33) - p.Y = cs.Mul(l3, 1) + p.Y = cs.Add(t1, l32, l33) //p.x = xr - p.X = cs.Mul(_x, 1) + p.X = _x return p } diff --git a/std/algebra/sw/g1_test.go b/std/algebra/sw/g1_test.go index e2a780d2b8..81bcbed5cd 100644 --- a/std/algebra/sw/g1_test.go +++ b/std/algebra/sw/g1_test.go @@ -24,6 +24,7 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/groth16" + "github.com/consensys/gnark/backend/plonk" "github.com/consensys/gnark/frontend" bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377" @@ -172,27 +173,33 @@ func (circuit *g1DoubleAffine) Define(curveID ecc.ID, cs *frontend.ConstraintSys func TestDoubleAffineG1(t *testing.T) { // sample 2 random points - _a := randomPointG1() - var a, c bls12377.G1Affine - a.FromJacobian(&_a) + _a, _, a, _ := bls12377.Generators() + var c bls12377.G1Affine // create the cs var circuit, witness g1DoubleAffine - r1cs, err := frontend.Compile(ecc.BW6_761, backend.GROTH16, &circuit) - if err != nil { - t.Fatal(err) - } - // assign the inputs + // assign the inputs and compute the result witness.A.Assign(&a) - - // compute the result _a.DoubleAssign() c.FromJacobian(&_a) witness.C.Assign(&c) - - assert := groth16.NewAssert(t) - assert.SolvingSucceeded(r1cs, &witness) + { + r1cs, err := frontend.Compile(ecc.BW6_761, backend.GROTH16, &circuit) + if err != nil { + t.Fatal(err) + } + assert := groth16.NewAssert(t) + assert.SolvingSucceeded(r1cs, &witness) + } + { + r1cs, err := frontend.Compile(ecc.BW6_761, backend.PLONK, &circuit) + if err != nil { + t.Fatal(err) + } + assert := plonk.NewAssert(t) + assert.SolvingSucceeded(r1cs, &witness) + } } @@ -216,22 +223,30 @@ func TestNegG1(t *testing.T) { // sample 2 random points a := randomPointG1() - // create the cs - var circuit, witness g1Neg - r1cs, err := frontend.Compile(ecc.BW6_761, backend.GROTH16, &circuit) - if err != nil { - t.Fatal(err) - } - // assign the inputs + var witness g1Neg witness.A.Assign(&a) - - // compute the result a.Neg(&a) witness.C.Assign(&a) - assert := groth16.NewAssert(t) - assert.SolvingSucceeded(r1cs, &witness) + // create the cs + var circuit g1Neg + { + r1cs, err := frontend.Compile(ecc.BW6_761, backend.GROTH16, &circuit) + if err != nil { + t.Fatal(err) + } + assert := groth16.NewAssert(t) + assert.SolvingSucceeded(r1cs, &witness) + } + { + r1cs, err := frontend.Compile(ecc.BW6_761, backend.PLONK, &circuit) + if err != nil { + t.Fatal(err) + } + assert := plonk.NewAssert(t) + assert.SolvingSucceeded(r1cs, &witness) + } } @@ -265,22 +280,32 @@ func TestScalarMulG1(t *testing.T) { // create the cs var circuit, witness g1ScalarMul circuit.r = r - r1cs, err := frontend.Compile(ecc.BW6_761, backend.GROTH16, &circuit) - if err != nil { - t.Fatal(err) - } - - // assign the inputs witness.A.Assign(&a) - // compute the result var br big.Int _a.ScalarMultiplication(&_a, r.ToBigIntRegular(&br)) c.FromJacobian(&_a) witness.C.Assign(&c) - - assert := groth16.NewAssert(t) - assert.SolvingSucceeded(r1cs, &witness) + { + r1cs, err := frontend.Compile(ecc.BW6_761, backend.GROTH16, &circuit) + if err != nil { + t.Fatal(err) + } + + // assign the inputs + assert := groth16.NewAssert(t) + assert.SolvingSucceeded(r1cs, &witness) + } + { + circuit.r = r + r1cs, err := frontend.Compile(ecc.BW6_761, backend.PLONK, &circuit) + if err != nil { + t.Fatal(err) + } + + assert := plonk.NewAssert(t) + assert.SolvingSucceeded(r1cs, &witness) + } }