Skip to content

Commit

Permalink
Merge pull request #280 from ConsenSys/simplify-r1cs-compile
Browse files Browse the repository at this point in the history
simplify post compile phase: shifting wire id is not needed 🙈
  • Loading branch information
gbotrel authored Mar 10, 2022
2 parents 4782151 + 121b111 commit e791d2e
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 187 deletions.
22 changes: 15 additions & 7 deletions frontend/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import (
//
// initialCapacity is an optional parameter that reserves memory in slices
// it should be set to the estimated number of constraints in the circuit, if known.
func Compile(curveID ecc.ID, newCompiler NewBuilder, circuit Circuit, opts ...CompileOption) (CompiledConstraintSystem, error) {
func Compile(curveID ecc.ID, newBuilder NewBuilder, circuit Circuit, opts ...CompileOption) (CompiledConstraintSystem, error) {
// parse options
opt := CompileConfig{}
for _, o := range opts {
Expand All @@ -37,21 +37,21 @@ func Compile(curveID ecc.ID, newCompiler NewBuilder, circuit Circuit, opts ...Co
}
}

// instantiate new compiler
compiler, err := newCompiler(curveID, opt)
// instantiate new builder
builder, err := newBuilder(curveID, opt)
if err != nil {
return nil, fmt.Errorf("new compiler: %w", err)
}

// parse the circuit builds a schema of the circuit
// and call circuit.Define() method to initialize a list of constraints in the compiler
if err = parseCircuit(compiler, circuit); err != nil {
if err = parseCircuit(builder, circuit); err != nil {
return nil, fmt.Errorf("parse circuit: %w", err)

}

// compile the circuit into its final form
return compiler.Compile()
return builder.Compile()
}

func parseCircuit(builder Builder, circuit Circuit) (err error) {
Expand All @@ -60,6 +60,15 @@ func parseCircuit(builder Builder, circuit Circuit) (err error) {
return errors.New("frontend.Circuit methods must be defined on pointer receiver")
}

// parse the schema, to count the number of public and secret variables
s, err := schema.Parse(circuit, tVariable, nil)
if err != nil {
return err
}

// this not only set the schema, but sets the wire offsets for public, secret and internal wires
builder.SetSchema(s)

// leaf handlers are called when encoutering leafs in the circuit data struct
// leafs are Constraints that need to be initialized in the context of compiling a circuit
var handler schema.LeafHandler = func(visibility schema.Visibility, name string, tInput reflect.Value) error {
Expand All @@ -79,11 +88,10 @@ func parseCircuit(builder Builder, circuit Circuit) (err error) {
}
// recursively parse through reflection the circuits members to find all Constraints that need to be allocated
// (secret or public inputs)
s, err := schema.Parse(circuit, tVariable, handler)
_, err = schema.Parse(circuit, tVariable, handler)
if err != nil {
return err
}
builder.SetSchema(s)

// recover from panics to print user-friendlier messages
defer func() {
Expand Down
108 changes: 17 additions & 91 deletions frontend/cs/r1cs/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ func newBuilder(curveID ecc.ID, config frontend.CompileConfig) *r1cs {
// newInternalVariable creates a new wire, appends it on the list of wires of the circuit, sets
// the wire's id to the number of wires, and returns it
func (system *r1cs) newInternalVariable() compiled.LinearExpression {
idx := system.NbInternalVariables
idx := system.NbInternalVariables + system.NbPublicVariables + system.NbSecretVariables
system.NbInternalVariables++
return compiled.LinearExpression{
compiled.Pack(idx, compiled.CoeffIdOne, schema.Internal),
Expand All @@ -97,9 +97,6 @@ func (system *r1cs) newInternalVariable() compiled.LinearExpression {

// AddPublicVariable creates a new public Variable
func (system *r1cs) AddPublicVariable(name string) frontend.Variable {
if system.Schema != nil {
panic("do not call AddPublicVariable in circuit.Define()")
}
idx := len(system.Public)
system.Public = append(system.Public, name)
return compiled.LinearExpression{
Expand All @@ -109,10 +106,7 @@ func (system *r1cs) AddPublicVariable(name string) frontend.Variable {

// AddSecretVariable creates a new secret Variable
func (system *r1cs) AddSecretVariable(name string) frontend.Variable {
if system.Schema != nil {
panic("do not call AddSecretVariable in circuit.Define()")
}
idx := len(system.Secret)
idx := len(system.Secret) + system.NbPublicVariables
system.Secret = append(system.Secret, name)
return compiled.LinearExpression{
compiled.Pack(idx, compiled.CoeffIdOne, schema.Secret),
Expand Down Expand Up @@ -273,14 +267,19 @@ func (system *r1cs) checkVariables() error {
cptPublic--
}
case schema.Secret:
vID -= system.NbPublicVariables
if !secretConstrained[vID] {
secretConstrained[vID] = true
cptSecret--
}
case schema.Internal:
if _, ok := system.MHints[vID]; !mHintsConstrained[vID] && ok {
mHintsConstrained[vID] = true
cptHints--

if _, ok := system.MHints[vID]; ok {
vID -= (system.NbPublicVariables + system.NbSecretVariables)
if !mHintsConstrained[vID] {
mHintsConstrained[vID] = true
cptHints--
}
}
}
}
Expand Down Expand Up @@ -359,88 +358,13 @@ func (cs *r1cs) Compile() (frontend.CompiledConstraintSystem, error) {
ConstraintSystem: cs.ConstraintSystem,
Constraints: cs.Constraints,
}
res.NbPublicVariables = len(cs.Public)
res.NbSecretVariables = len(cs.Secret)

// for Logs, DebugInfo and hints the only thing that will change
// is that ID of the wires will be offseted to take into account the final wire vector ordering
// that is: public wires | secret wires | internal wires

// offset variable ID depeneding on visibility
shiftVID := func(oldID int, visibility schema.Visibility) int {
switch visibility {
case schema.Internal:
return oldID + res.NbPublicVariables + res.NbSecretVariables
case schema.Public:
return oldID
case schema.Secret:
return oldID + res.NbPublicVariables
}
return oldID
}

// we just need to offset our ids, such that wires = [ public wires | secret wires | internal wires ]
offsetIDs := func(l compiled.LinearExpression) {
for j := 0; j < len(l); j++ {
_, vID, visibility := l[j].Unpack()
l[j].SetWireID(shiftVID(vID, visibility))
}
// sanity check
if res.NbPublicVariables != len(cs.Public) || res.NbPublicVariables != cs.Schema.NbPublic+1 {
panic("number of public variables is inconsitent") // it grew after the schema parsing?
}

for i := 0; i < len(res.Constraints); i++ {
offsetIDs(res.Constraints[i].L)
offsetIDs(res.Constraints[i].R)
offsetIDs(res.Constraints[i].O)
}

// we need to offset the ids in the hints
shiftedMap := make(map[int]*compiled.Hint)

// we need to offset the ids in the hints
HINTLOOP:
for _, hint := range cs.MHints {
ws := make([]int, len(hint.Wires))
// we set for all outputs in shiftedMap. If one shifted output
// is in shiftedMap, then all are
for i, vID := range hint.Wires {
ws[i] = shiftVID(vID, schema.Internal)
if _, ok := shiftedMap[ws[i]]; i == 0 && ok {
continue HINTLOOP
}
}
inputs := make([]interface{}, len(hint.Inputs))
copy(inputs, hint.Inputs)
for j := 0; j < len(inputs); j++ {
switch t := inputs[j].(type) {
case compiled.LinearExpression:
tmp := make(compiled.LinearExpression, len(t))
copy(tmp, t)
offsetIDs(tmp)
inputs[j] = tmp
default:
inputs[j] = t
}
}
ch := &compiled.Hint{ID: hint.ID, Inputs: inputs, Wires: ws}
for _, vID := range ws {
shiftedMap[vID] = ch
}
}
res.MHints = shiftedMap

// we need to offset the ids in Logs & DebugInfo
for i := 0; i < len(cs.Logs); i++ {

for j := 0; j < len(res.Logs[i].ToResolve); j++ {
_, vID, visibility := res.Logs[i].ToResolve[j].Unpack()
res.Logs[i].ToResolve[j].SetWireID(shiftVID(vID, visibility))
}
}
for i := 0; i < len(cs.DebugInfo); i++ {
for j := 0; j < len(res.DebugInfo[i].ToResolve); j++ {
_, vID, visibility := res.DebugInfo[i].ToResolve[j].Unpack()
res.DebugInfo[i].ToResolve[j].SetWireID(shiftVID(vID, visibility))
}
if res.NbSecretVariables != len(cs.Secret) || res.NbSecretVariables != cs.Schema.NbSecret {
panic("number of secret variables is inconsitent") // it grew after the schema parsing?
}

// build levels
Expand Down Expand Up @@ -469,6 +393,8 @@ func (cs *r1cs) SetSchema(s *schema.Schema) {
panic("SetSchema called multiple times")
}
cs.Schema = s
cs.NbPublicVariables = s.NbPublic + 1
cs.NbSecretVariables = s.NbSecret
}

func buildLevels(ccs compiled.R1CS) [][]int {
Expand Down
106 changes: 17 additions & 89 deletions frontend/cs/scs/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,27 +105,21 @@ func (system *scs) addPlonkConstraint(l, r, o compiled.Term, cidl, cidr, cidm1,
// newInternalVariable creates a new wire, appends it on the list of wires of the circuit, sets
// the wire's id to the number of wires, and returns it
func (system *scs) newInternalVariable() compiled.Term {
idx := system.NbInternalVariables
idx := system.NbInternalVariables + system.NbPublicVariables + system.NbSecretVariables
system.NbInternalVariables++
return compiled.Pack(idx, compiled.CoeffIdOne, schema.Internal)
}

// AddPublicVariable creates a new Public Variable
func (system *scs) AddPublicVariable(name string) frontend.Variable {
if system.Schema != nil {
panic("do not call AddPublicVariable in circuit.Define()")
}
idx := len(system.Public)
system.Public = append(system.Public, name)
return compiled.Pack(idx, compiled.CoeffIdOne, schema.Public)
}

// AddSecretVariable creates a new Secret Variable
func (system *scs) AddSecretVariable(name string) frontend.Variable {
if system.Schema != nil {
panic("do not call AddSecretVariable in circuit.Define()")
}
idx := len(system.Secret)
idx := len(system.Secret) + system.NbPublicVariables
system.Secret = append(system.Secret, name)
return compiled.Pack(idx, compiled.CoeffIdOne, schema.Secret)
}
Expand Down Expand Up @@ -223,14 +217,19 @@ func (system *scs) checkVariables() error {
cptPublic--
}
case schema.Secret:
vID -= system.NbPublicVariables
if !secretConstrained[vID] {
secretConstrained[vID] = true
cptSecret--
}
case schema.Internal:
if _, ok := system.MHints[vID]; !mHintsConstrained[vID] && ok {
mHintsConstrained[vID] = true
cptHints--
if _, ok := system.MHints[vID]; ok {
vID -= (system.NbPublicVariables + system.NbSecretVariables)
if !mHintsConstrained[vID] {
mHintsConstrained[vID] = true
cptHints--
}

}
}
}
Expand Down Expand Up @@ -307,86 +306,13 @@ func (cs *scs) Compile() (frontend.CompiledConstraintSystem, error) {
ConstraintSystem: cs.ConstraintSystem,
Constraints: cs.Constraints,
}
res.NbPublicVariables = len(cs.Public)
res.NbSecretVariables = len(cs.Secret)

// Logs, DebugInfo and hints are copied, the only thing that will change
// is that ID of the wires will be offseted to take into account the final wire vector ordering
// that is: public wires | secret wires | internal wires

// shift variable ID
// we want publicWires | privateWires | internalWires
shiftVID := func(oldID int, visibility schema.Visibility) int {
switch visibility {
case schema.Internal:
return oldID + res.NbPublicVariables + res.NbSecretVariables
case schema.Public:
return oldID
case schema.Secret:
return oldID + res.NbPublicVariables
default:
return oldID
}
// sanity check
if res.NbPublicVariables != len(cs.Public) || res.NbPublicVariables != cs.Schema.NbPublic {
panic("number of public variables is inconsitent") // it grew after the schema parsing?
}

offsetTermID := func(t *compiled.Term) {
_, VID, visibility := t.Unpack()
t.SetWireID(shiftVID(VID, visibility))
}

// offset the IDs of all constraints so that the variables are
// numbered like this: [publicVariables | secretVariables | internalVariables ]
for i := 0; i < len(res.Constraints); i++ {
r1c := &res.Constraints[i]
offsetTermID(&r1c.L)
offsetTermID(&r1c.R)
offsetTermID(&r1c.O)
offsetTermID(&r1c.M[0])
offsetTermID(&r1c.M[1])
}

// we need to offset the ids in Logs & DebugInfo
for i := 0; i < len(cs.Logs); i++ {
for j := 0; j < len(res.Logs[i].ToResolve); j++ {
offsetTermID(&res.Logs[i].ToResolve[j])
}
}
for i := 0; i < len(cs.DebugInfo); i++ {
for j := 0; j < len(res.DebugInfo[i].ToResolve); j++ {
offsetTermID(&res.DebugInfo[i].ToResolve[j])
}
}

// we need to offset the ids in the hints
shiftedMap := make(map[int]*compiled.Hint)
HINTLOOP:
for _, hint := range cs.MHints {
ws := make([]int, len(hint.Wires))
// we set for all outputs in shiftedMap. If one shifted output
// is in shiftedMap, then all are
for i, vID := range hint.Wires {
ws[i] = shiftVID(vID, schema.Internal)
if _, ok := shiftedMap[ws[i]]; i == 0 && ok {
continue HINTLOOP
}
}
inputs := make([]interface{}, len(hint.Inputs))
copy(inputs, hint.Inputs)
for j := 0; j < len(inputs); j++ {
switch t := inputs[j].(type) {
case compiled.Term:
offsetTermID(&t)
inputs[j] = t // TODO check if we can remove it
default:
inputs[j] = t
}
}
ch := &compiled.Hint{ID: hint.ID, Inputs: inputs, Wires: ws}
for _, vID := range ws {
shiftedMap[vID] = ch
}
if res.NbSecretVariables != len(cs.Secret) || res.NbSecretVariables != cs.Schema.NbSecret {
panic("number of secret variables is inconsitent") // it grew after the schema parsing?
}
res.MHints = shiftedMap

// build levels
res.Levels = buildLevels(res)
Expand Down Expand Up @@ -415,6 +341,8 @@ func (cs *scs) SetSchema(s *schema.Schema) {
panic("SetSchema called multiple times")
}
cs.Schema = s
cs.NbPublicVariables = s.NbPublic
cs.NbSecretVariables = s.NbSecret
}

func buildLevels(ccs compiled.SparseR1CS) [][]int {
Expand Down

0 comments on commit e791d2e

Please sign in to comment.